torch.einsum
是一种使用爱因斯坦求和约定的操作,能够灵活地对张量执行各种线性代数运算(如点积、矩阵乘法、张量缩并等)。下面从公式写法、各部分意义和具体示例三个方面详细解释。
1. 公式写法
torch.einsum
的基本调用格式如下:
result = torch.einsum('subscripts', *operands)
其中:
'subscripts'
:索引表示的字符串,定义了张量的维度以及它们之间的关系。*operands
:参与计算的一个或多个张量。
索引表示的结构:
'input_subscripts->output_subscripts'
input_subscripts
:用逗号分隔的字符串,每组字母表示对应输入张量的维度索引。->
(可选):指示输出张量的维度索引。如果省略->
,默认保留未求和的维度。output_subscripts
:指定结果张量的维度索引。
核心规则:
- 相同索引表示要对该维度进行相乘后求和。
- 不同索引表示要保留的维度。
- 未出现在
output_subscripts
中的索引会进行求和并被消除。
2. 每部分的意义
(1) 输入索引
- 描述每个输入张量的维度。
- 例如:
'ij,jk->ik'
表示两个矩阵:- 第一个矩阵 AA 的维度索引为
ij
,形状为(I, J)
。 - 第二个矩阵 BB 的维度索引为
jk
,形状为(J, K)
。
- 第一个矩阵 AA 的维度索引为
j
是公共索引,对它进行求和。
(2) 输出索引
- 指定哪些维度保留在结果张量中。
- 未出现在输出中的索引会被自动求和消除。
- 例如:
'ij,jk->ik'
:保留i
和k
,对j
求和。