之前一直以为 Attention 和 RNN 没关系是凭空蹦出来的新概念;以为 Transformer, Encoder, Decoder 这几个概念是绑在一起的。并不尽然。
Encoder 和 Decoder
RNN 里就有 Encoder Decoder 的概念。其中,encoder 接受用户输入,写入 hidden state。Decoder 接受之前时刻的隐状态,并生成 logits。类似的架构也出现在 CNN 图像模型中。
所以,不论如何,只要是数据流长得像 encode, decode 的,都是 Encoder, Decoder
Attention 普遍意义上的注意力机制
上面 RNN 的问题是,decoder 只能拿到 encoder 最后的这个 <end>
位置的 feature,相当于必须串行接收整个输入,不能有注意力地选择输入序列的重点(不能加权)。
所以,我们想实现一个类似全连接的功能,在每个 decode 的位置,给输入序列的隐状态加个系数,共同喂给 decoder。所以,注意力其实就是把上面的这个序列算个系数。
但是怎么能让这个全连接矩阵可训练,可泛化是个问题。注意力机制引入了 Q, K, V 三个概念,其中 K, V 是 n 个 kv pair,Query 表示上图上面的部分,最后,Q 和 K 会两两一组算一个相关系数,然后用相关系数乘上 v,作为注意力输出。
其中,Q, K 表示。一个例子是我看涩图的注意力集中在人脸上,Q = 我; K = 涩图(V 和 K 严格绑定,是另一个空间对 K 的表示)Q,K 算一个相似度赋给 V.
一般 K = V。
自注意力机制
注意力是一个很宽泛的概念,不知道 QKV 是什么,自注意力机制则是规定了 QKV 同源,都是通过原始输入 X X X 乘上线性矩阵 W q , W k , W v W^q, W^k, W^v Wq,Wk,Wv 产生的。
给定输入矩阵
X
X
X(形状为
(
n
,
d
)
(n, d)
(n,d),其中
n
n
n 是序列长度,
d
d
d 是嵌入维度),计算 Query(查询)、Key(键)、Value(值):
Q
=
X
W
Q
,
K
=
X
W
K
,
V
=
X
W
V
Q = X W_Q, \quad K = X W_K, \quad V = X W_V
Q=XWQ,K=XWK,V=XWV
其中:
- W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 是可训练的权重矩阵(形状均为 ( d , d k ) (d, d_k) (d,dk))。
- Q , K , V Q, K, V Q,K,V 的形状均为 ( n , d k ) (n, d_k) (n,dk)。
2. 计算注意力分数(Scaled Dot-Product Attention)
A
=
Q
K
T
d
k
A = \frac{Q K^T}{\sqrt{d_k}}
A=dkQKT
其中:
- K T K^T KT 是 Key 矩阵的转置(形状为 ( d k , n ) (d_k, n) (dk,n)),使得 Q K T QK^T QKT 形状为 ( n , n ) (n, n) (n,n)。
- 1 d k \frac{1}{\sqrt{d_k}} dk1 是缩放因子,防止大数值影响梯度。
3. 计算注意力权重(Softmax 归一化)
α
=
softmax
(
A
)
\alpha = \text{softmax}(A)
α=softmax(A)
其中,
α
\alpha
α 形状为
(
n
,
n
)
(n, n)
(n,n),表示序列中每个位置对其他位置的注意力权重。
4. 计算加权 Value
Z
=
α
V
Z = \alpha V
Z=αV
其中:
- Z Z Z 形状为 ( n , d k ) (n, d_k) (n,dk),即每个输入位置的加权输出。
5. 多头注意力(Multi-Head Attention)
如果使用
h
h
h 个头,每个头分别计算:
Z
i
=
Attention
(
X
W
Q
i
,
X
W
K
i
,
X
W
V
i
)
Z_i = \text{Attention}(X W_{Q_i}, X W_{K_i}, X W_{V_i})
Zi=Attention(XWQi,XWKi,XWVi)
然后将多个头的结果拼接并映射回原始维度:
Z
=
[
Z
1
,
Z
2
,
…
,
Z
h
]
W
O
Z = [Z_1, Z_2, \dots, Z_h] W_O
Z=[Z1,Z2,…,Zh]WO
其中:
- W O W_O WO 是输出投影矩阵(形状为 ( h ⋅ d k , d ) (h \cdot d_k, d) (h⋅dk,d))。
- Z Z Z 形状回到 ( n , d ) (n, d) (n,d)。
Ref
https://zhuanlan.zhihu.com/p/109585084
https://www.cnblogs.com/nickchen121/p/16470710.html
https://www.cnblogs.com/nickchen121/p/16470711.html