Einsun 简介
ein 就是爱因斯坦的ein,sum就是求和。einsum就是爱因斯坦求和约定,其实作用就是把求和符号省略。
B = torch.einsum("ij->i", A)
einsum接收的第一个参数为einsum表达式,-> 符号就相当于要把->前面的张量变成->后面的张量,后面的参数为要被操作的张量。
这里的 i,j 是指代A的下标,也可以换成其他字母。
你可以直接看->后面相对于前面少了什么,那么就是对前面的哪些维度做了求和操作
实例:
1、对二维张量进行某一维度进行求和:
A = torch.Tensor(range(2*3)).view(2, 3)
# 对第一个维度'i'进行求和(也就是对列进行求和)
B = torch.einsum("ij->j", A)
# 对第二个维度'j'进行求和(也就是对行进行求和)
C = torch.einsum("ij->i", A)
print(A)
print(B)
print(C)
要被操作的张量A:
tensor([[0., 1., 2.],
[3., 4., 5.]])
1)首先看B:
ij -> j
那么相当于对A的 i 维度进行了求和:
tensor([3., 5., 7.])
2)看C:
ij -> i
那么相当于对A的 j 维度进行了求和:
tensor([3., 5., 7.])
tensor([ 3., 12.])
2、对张量维度进行变换
D = torch.einsum("ij->ji", A)
tensor([[0., 3.],
[1., 4.],
[2., 5.]])
3、张量乘法
高维数组相乘的运算规则看:不同维度的矩阵相乘的维度结果(高维数组相乘的运算规则)、时间复杂度计算_不同维度矩阵乘法-CSDN博客
示例1:
A = torch.Tensor(range(2*3)).view(2, 3)
B = torch.Tensor(range(2*3)).view(2, 3).view(3, 2)
C = torch.einsum("ij,jk->ik", A, B)
print(A)
print(B)
print(C)
通过写表达式的方式可以看到具体做了什么操作:
第一步,写出数学表达式:
第二步,补充求和符号
由于右侧相比于左侧少了一个j,所以相当于对第j维度进行求和
整体来看整个步骤就是:
tensor([[0., 1., 2.],
[3., 4., 5.]])
tensor([[0., 1.],
[2., 3.],
[4., 5.]])
tensor([[10., 13.],
[28., 40.]])
示例2:
在ViT中,下面q与k相乘的代码:
weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
可以替换为:
q = q * scale
k = k * scale
k_transposed = k.transpose(1, 2)
weight = torch.bmm(q, k_transposed)
具体的运算步骤如下:
一个函数打天下,einsum - 知乎
矩阵操作万能函数 einsum 详细解析(通法教你如何看懂并写出einsum表达式)-CSDN博客