1 bmm (batch matrix multiplication)
- 批量矩阵乘法,用于同时处理多个矩阵的乘法
bmm
的输入是两个 3D 张量(batch of matrices),形状分别为(batch_size, n, m)
和(batch_size, m, p)
bmm
输出的形状是(batch_size, n, p)
2 mm
mm
是标准的矩阵乘法操作,用于两个二维矩阵相乘mm
仅适用于 2D 张量,输入的形状分别是(n, m)
和(m, p)
- 输出的形状是
(n, p)