CSP-202305-2-矩阵运算
关键点总结:改变矩阵计算顺序优化时间复杂度
通过先计算 K T × V K ^ T \times V KT×V 而不是先计算 Q × K T Q \times K ^ T Q×KT,有效地减少了计算时间,特别是在处理长序列时。这种优化通常在数据维度一不等时有显著效果,特别是当序列长度显著大于向量维度时。
1.原始的计算顺序
在Transformer的自注意力机制中,给定矩阵 Q Q Q(查询), K K K(键)和 V V V(值),计算首先涉及到以下步骤:
- 计算 Q × K T Q \times K ^ T Q×KT(查询和键的点积),得到注意力得分矩阵。
- 将注意力得分矩阵乘以 V V V(值矩阵),得到加权的值,这是最终的输出。
原始计算的时间复杂度主要由 Q × K T Q \times K ^ T Q×KT 的计算决定,这个操作的时间复杂度为 O ( n 2 ⋅ d ) O(n ^ 2 \cdot d) O(n2⋅d),其中 n n n 是序列长度(例如,句子中的单词数量或向量数量), d d d 是向量的维度。当 n n n 很大时,这个操作非常耗时。
2.代码中的计算顺序
代码中采取了不同的计算顺序:
-
首先,它通过与 W W W 相乘来调整 Q Q Q 中的每个元素(这对应于自注意力机制中的缩放操作,但在这个特定的实现中, W W W 似乎用于不同的目的,比如加权或转换,这并不是标准的自注意力机制的一部分)。
-
然后,它计算 K T × V K ^ T \times V KT×V,这个操作的时间复杂度为 O ( n ⋅ d 2 ) O(n \cdot d ^ 2) O(n⋅d2),因为它是在矩阵 K T K ^ T KT(维度 d × n d \times n d×n)和矩阵 V V V(维度 n × d n \times d n×d)之间进行的。
-
最后,它计算调整后的 Q Q Q 与 K T × V K ^ T \times V KT×V 的结果,时间复杂度为 O ( n ⋅ d 2 ) O(n \cdot d ^ 2) O(n⋅d2)。
3.时间复杂度比较
当 n > d n > d n>d(即,序列长度大于向量维度)时,代码中的计算顺序比原始计算顺序更有效率。原始方法的复杂度主要是由序列长度的平方决定的,而代码中的方法将这个平方项降低到了 n n n 和 d d d 的乘积,这在大多数实际情况下会减少计算量,尤其是在处理长序列时。
解题思路
搞清楚上面的点后,本质上就是简单的矩阵乘法,留意本题关于矩阵点乘计算规则的定义即可。
完整代码
#include <iostream>
#include <vector>
#include <string>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
long long n, d;
cin >> n >> d;
vector<vector<long long>>Q(n, vector<long long>(d));
vector<vector<long long>>K_T(d, vector<long long>(n));
vector<vector<long long>>V(n, vector<long long>(d));
vector<long long>W(n);
// 输入Q
for (int i = 0; i < n; i++)
{
for (int j = 0; j < d; j++)
{
cin >> Q[i][j];
}
}
// 输入K_T
for (int i = 0; i < n; i++)
{
for (int j = 0; j < d; j++)
{
cin >> K_T[j][i];
}
}
// 输入V
for (int i = 0; i < n; i++)
{
for (int j = 0; j < d; j++)
{
cin >> V[i][j];
}
}
// 输入W
for (int i = 0; i < n; i++)
{
cin >> W[i];
}
// 计算 W * Q
for (int i = 0; i < n; i++)
{
for (int j = 0; j < d; j++)
{
Q[i][j] *= W[i];
}
}
// 计算 K_T * V
vector<vector<long long>>T1(d, vector<long long>(d));
for (int i = 0; i < d; i++)
{
for (int j = 0; j < d; j++)
{
for (int k = 0; k < n; k++)
{
T1[i][j] += K_T[i][k] * V[k][j];
}
}
}
// 计算 Q * T1
vector<vector<long long>>T2(n, vector<long long>(d));
for (int i = 0; i < n; i++)
{
for (int j = 0; j < d; j++)
{
for (int k = 0; k < d; k++)
{
T2[i][j] += Q[i][k] * T1[k][j];
}
}
}
for (const auto& it : T2) {
for (const auto& jt : it) {
cout << jt << " ";
}
cout << endl;
}
return 0;
}