多头注意力
上篇文章中我们了解了词编码和位置编码,接下来我们介绍Transformer中的核心模块——多头注意力。
自注意力
首先回顾下注意力机制,注意力机制允许模型为序列中不同的元素分配不同的权重。而自注意力中的"自"表示输入序列中的输入相互之间的注意力,即通过某种方式计算输入序列每个位置相互之间的相关性。具体的推导可以看这篇文章。
对于Transformer编码器来说,给定一个输入序列
(
x
1
,
.
.
.
,
x
n
)
(x_1,...,x_n)
(x1,...,xn),这里假设是输入序列中第i个位置所对应的词嵌入。自注意力产生了一个新的相同长度的嵌入
(
y
1
,
.
.
.
,
y
2
)
(y_1,...,y_2)
(y1,...,y2),其中每个
y
i
y_i
yi是所有
x
j
x_j
xj的加权和(包括本身):
y
i
=
∑
j
α
i
j
x
j
y_i=\sum_j\alpha_{ij}x_j
yi=j∑αijxj 其中,
α
i
j
\alpha_{ij}
αij是注意力权重,有
∑
j
α
i
j
=
1
\sum_j\alpha_{ij}=1
∑jαij=1。
计算注意力的方式有很多种,最高效的是点积注意力,即两个输入之间做点积。点积的结果是一个实数范围内的标量,结果越大代表两个向量越相似。这是计算两个输入之间的注意力分数,将某个token与所有的输入进行计算,就可以得到n个注意力分数,经过Softmax归一化就可以得到权重向量
α
\alpha
α,其中
α
i
j
\alpha_{ij}
αij表示两个输入i和j之间的相关度(权重系数):
α
i
j
=
s
o
f
t
m
a
x
(
s
c
o
r
e
(
x
i
,
x
j
)
)
\alpha_{ij}=softmax(score(x_i,x_j))
αij=softmax(score(xi,xj)) 得到了这些权重之后,就可以按照上面的公式对所有输入加权得到
y
i
y_i
yi。
Transformer中的注意力会更加复杂一点,主要体现在两点:Q,K,V和缩放点积机制和多头注意力。
缩放点积注意力
Transformer的多头注意力模块有三个输入:
- Query: 与所有的输入进行比较,为当前关注的点。
- Key:作为与Query进行比较的角色,用于计算和Query之间的相关性。
- Value:用于计算当前注意力关注点的输出,根据注意力权重对不同的Value进行加权和。
这三个输入都是由原始输入映射而来的,为了生成这三种不同的角色,Transformer分别引入了三个权重矩阵 W Q , W K , W V W^Q,W^K,W^V WQ,WK,WV,分别将每个输入投影到不同角色query,key和value表示: q = x W Q ; k = x W K ; v = x W V q=xW^Q; k=xW^K; v=xW^V q=xWQ;k=xWK;v=xWV Query和Key是用于比较的,Value是用于提取特征的。通过将输入映射到不同的角色,使模型具有更强的学习能力。看到一个比较直观的解释:如果把注意力过程类比成搜索的话,那么假设在百度中输入"自然语言处理是什么",那么Query就是这个搜索的语句;Key相当于检索到的网页的标题;Value就是网页的内容。 现在我们用Q和K矩阵计算亲和度矩阵sim matrix,矩阵中第i,j个位置,即指原序列第i个token和第j个token之间的亲和度(点积结果): s c o r e ( q i , k j ) = q i ⋅ k j score(q_i,k_j)=q_i\cdot k_j score(qi,kj)=qi⋅kj 点积的结果是一个标量,但这个结果可能非常大(不管是正的还是负的),这会使得softmax函数值进入一个导数非常小的区域。需要对这个注意力得分进行缩放,缩放使得分布更加平滑。一种缩放的方法是把点积结果除以一个和嵌入大小相关的因子(factor)。注意这是在传递给softmax之前进行的。Transformer的做法是除以query和key向量维度的平方根: s c o r e ( q i , k j ) = q i ⋅ k j d score(q_i,k_j)=\frac{q_i\cdot k_j}{\sqrt d} score(qi,kj)=dqi⋅kj 这是计算具体两个token向量的亲和度,我们可以用矩阵相乘实现批量操作,考虑一下维度的问题:包含batch大小和序列长度,输入的完整维度是(batch_size, seq_len, embed_dim)。我们得到query,key和value也是相同维度的,只是经过了不同的线性变换。那么亲和度矩阵的计算方法为: S e l f A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V SelfAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d_k})V SelfAttention(Q,K,V)=softmax(dkQKT)V
多头注意力
上面介绍的缩放点积注意力把原始的x映射到不同的空间后,去做注意力。每次映射相当于是在特定空间中去建模特定的语义交互关系,类似卷积中的多通道可以得到多个特征图,那么多个注意力可以得到多个不同方面的语义交互关系。可以让模型更好地关注到不同位置的信息,捕捉到输入序列中不同依赖关系和语义信息。有助于处理长序列、解决语义消歧、句子表示等任务,提高模型的建模能力。Transformer中使用多头注意力实现这一点。
对于多头注意力中的每个头i,都有自己不同的query,key和value矩阵,假设Q,K矩阵的维度是
d
k
d_k
dk,V矩阵维度是
d
v
d_v
dv,那么每个头的权重矩阵是
W
i
Q
∈
R
d
×
d
k
,
W
i
K
∈
R
d
×
d
k
,
,
W
i
K
∈
R
d
×
d
v
W_i^Q\in R^{d×d_k},W_i^K\in R^{d×d_k},,W_i^K\in R^{d×d_v}
WiQ∈Rd×dk,WiK∈Rd×dk,,WiK∈Rd×dv,将他们与reshape后的输入相乘,得到的每个头的Q,K,V矩阵维度应该是
Q
∈
R
s
q
l
e
n
∗
d
k
,
K
∈
R
s
q
l
e
n
∗
d
k
,
V
∈
R
s
q
l
e
n
∗
d
v
Q\in R^{sqlen*d_k},K\in R^{sqlen*d_k},V\in R^{sqlen*d_v}
Q∈Rsqlen∗dk,K∈Rsqlen∗dk,V∈Rsqlen∗dv其中,sqlen是序列长度,
d
k
d_k
dk是每个头的维度。
得到这些多头注意力的组合以后,再把它们拼接起来,然后通过一个线性变化映射回原来的维度,保证输入和输出的维度一致:
M
u
l
t
i
H
e
a
d
A
t
t
e
n
t
i
o
n
(
X
)
=
c
o
n
c
a
t
(
h
e
a
d
1
,
.
.
.
,
h
e
a
d
h
)
W
o
MultiHeadAttention(X)=concat(head_1,...,head_h)W^o
MultiHeadAttention(X)=concat(head1,...,headh)Wo
h
e
a
d
i
=
S
e
l
f
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
head_i=SelfAttention(Q,K,V)
headi=SelfAttention(Q,K,V)
Q
=
X
W
i
Q
;
K
=
X
W
I
K
;
V
=
X
W
I
V
Q=XW_i^Q;K=XW_I^K;V=XW_I^V
Q=XWiQ;K=XWIK;V=XWIV
下面是一个三个头的注意力示意图,在原论文中,d = 512,有h = 8个注意力头。由于每个头维度的减少,总的计算量和正常维度的单头注意力一样(8 × 64 = 512)。 给出一个pytorch实现的例子,在forward方法中,首先利用三个线性变换分别计算query,key,value矩阵。接着拆分成多个头,传给attention方法计算多头注意力,然后合并多头注意力的结果。最后经过一个用作拼接的线性层。
在下一篇文章中,我会继续讲解Transformer中剩下的其他组件。
class MultiHeadAttention(nn.Module):
def __init__(
self,
d_model: int = 512,
n_heads: int = 8,
dropout: float = 0.1,
) -> None:
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_key = d_model // n_heads # dimension of every head
self.q = nn.Linear(d_model, d_model) # query matrix
self.k = nn.Linear(d_model, d_model) # key matrix
self.v = nn.Linear(d_model, d_model) # value matrix
self.concat = nn.Linear(d_model, d_model) # output
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: Tensor, is_key: bool = False) -> Tensor:
batch_size = x.size(0)
# x (batch_size, seq_len, n_heads, d_key)
x = x.view(batch_size, -1, self.n_heads, self.d_key)
if is_key:
# (batch_size, n_heads, d_key, seq_len)
return x.permute(0, 2, 3, 1)
# (batch_size, n_heads, seq_len, d_key)
return x.transpose(1, 2)
def merge_heads(self, x: Tensor) -> Tensor:
x = x.transpose(1, 2).contiguous().view(x.size(0), -1, self.d_model)
return x
def attenion(
self,
query: Tensor,
key: Tensor,
value: Tensor,
mask: Tensor = None,
keep_attentions: bool = False,
):
scores = torch.matmul(query, key) / math.sqrt(self.d_key)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# weights (batch_size, n_heads, q_length, k_length)
weights = self.dropout(torch.softmax(scores, dim=-1))
# (batch_size, n_heads, q_length, k_length) x (batch_size, n_heads, v_length, d_key) -> (batch_size, n_heads, q_length, d_key)
# assert k_length == v_length
# attn_output (batch_size, n_heads, q_length, d_key)
attn_output = torch.matmul(weights, value)
if keep_attentions:
self.weights = weights
else:
del weights
return attn_output
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
mask: Tensor = None,
keep_attentions: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Args:
query (Tensor): (batch_size, q_length, d_model)
key (Tensor): (batch_size, k_length, d_model)
value (Tensor): (batch_size, v_length, d_model)
mask (Tensor, optional): mask for padding or decoder. Defaults to None.
keep_attentions (bool): whether keep attention weigths or not. Defaults to False.
Returns:
output (Tensor): (batch_size, q_length, d_model) attention output
"""
query, key, value = self.q(query), self.k(key), self.v(value)
query, key, value = (
self.split_heads(query),
self.split_heads(key, is_key=True),
self.split_heads(value),
)
attn_output = self.attenion(query, key, value, mask, keep_attentions)
del query
del key
del value
# Concat
concat_output = self.merge_heads(attn_output)
# the final liear
# output (batch_size, q_length, d_model)
output = self.concat(concat_output)
return output