介绍torch.matmul之前先介绍torch.mm函数, mm和matmul都是torch中矩阵乘法函数,mm只能作用于二维矩阵,matmul可以作用于二维也能作用于高维矩阵
mm函数使用
x = torch.rand(4, 9)
y = torch.rand(9, 8)
print(torch.mm(x,y).shape)
torch.Size([4, 8])
matmul函数使用
- 1 二维乘二维,结果和mm函数一样
x = torch.rand(4,9)
y = torch.rand(9,8)
print(torch.matmul(x,y).shape)
torch.Size([4, 8])
- 2 高维乘法 3 维 乘 2 维
将x的第0维提出来,剩下的就是二维矩阵乘法得到 9,4,8
x = torch.rand(9,4,9)
y = torch.rand(9,8)
print(torch.matmul(x,y).shape)
torch.Size([9, 4, 8])
- 3 高维矩阵乘法 3维 乘 3维
两种情况
1)x和y的0维一样直接提出来,剩下的是二维矩阵乘法结果是 9,5,6
2)x和y的第0维不一样,如果是有一个第0维是1则可以直接扩展成跟另一个矩阵的维度一样,然后直接提出来,剩下的二维矩阵乘法,下方x的矩阵第0维是1,y的是9,直接把x矩阵第0维扩展成9,即可跟下方第一个操作相同
x = torch.rand(9, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x,y).shape)
torch.Size([9,5,6])
x = torch.rand(1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x,y).shape)
torch.Size([9,5,6])
torch.Size([9,5,6])
-
- 高维矩阵乘法 4维 乘 3 维
根据上方总结的规则,下方同样做法,多余的一维或几维,直接提出来,剩下的同维度矩阵直接计算,如果是1就扩展成与之相对另一个矩阵的相同的数,如果不同也不为1,就直接报错
- 高维矩阵乘法 4维 乘 3 维
x = torch.rand(6, 9, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x,y).shape)
torch.Size([6,9,5,6])
x = torch.rand(6, 1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x,y).shape)
torch.Size([6,9,5,6])
x = torch.rand(6, 9, 1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x,y).shape)
torch.Size([6,9,9,5,6])
x = torch.rand(6, 9,8,9, 1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x,y).shape)
torch.Size([6,9,8,9,9,5,6])
参考资料
torch.mm()&torch.matmul()
torch.matmul()用法介绍
pytorch官方文档
torch中点积,叉积和卷积