引言
Grouped Query Attention(GQA,分组查询注意力)和多头注意力机制(Multi-Head Attention,MHA)都是Transformer模型中用于捕获输入序列中不同位置之间关系的注意力机制。然而,它们在实现方式和计算复杂度上有所不同。下面我将详细介绍它们的原理以及它们之间的区别。
1. 多头注意力机制(MHA)
1.1 概念
多头注意力机制(Multi-Head Attention,MHA)是Transformer模型中的核心组件。它通过并行的多组注意力机制,让模型能够在不同的子空间中关注序列的不同方面,从而加强模型的表达能力。
1.2 工作原理
输入表示:给定输入序列 X ∈ R n × d model \mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}} X∈Rn×dmodel, n n n是序列长度, d model d_{\text{model}} dmodel是隐藏维度。
线性投影:通过线性变换将输入 X \mathbf{X} X映射为查询( Q \mathbf{Q} Q)、键( K \mathbf{K} K)和值( V \mathbf{V} V):
Q = X W Q , K = X W K , V = X W V \mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V Q=XWQ,K=XWK,V=XWV
拆分多头:将 Q \mathbf{Q} Q、 K \mathbf{K} K、 V \mathbf{V} V沿着隐藏维度拆分成 h h h个头,每个头的维度为 d k = d model / h d_k = d_{\text{model}} / h dk=dmodel/h。
计算每个头的注意力:
head i = Attention ( Q i , K i , V i ) \text{head}_i = \text{Attention}(\mathbf{Q}_i, \mathbf{K}_i, \mathbf{V}_i) headi=Attention(Qi,Ki,Vi)
其中,
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left( \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}} \right) \mathbf{V} Attention(Q,K,V)=softmax(dkQKT)V
拼接和线性变换:将所有头的输出拼接起来,再通过一个线性层:
Z = Concat ( head 1 , … , head h ) W O \mathbf{Z} = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \mathbf{W}_O Z=Concat(head1,…,headh)WO
其中 W O ∈ R d model × d model \mathbf{W}O \in \mathbb{R}^{d{\text{model}} \times d_{\text{model}}} WO∈Rdmodel×dmodel。
1.3 特点
每个头有独立的Q、K和V的投影矩阵,能够在不同的子空间中捕获输入序列的不同特征。提高模型的多样性和表达能力。
2. 分组查询注意力(Grouped Query Attention,GQA)
2.1 概念
分组查询注意力(Grouped Query Attention,GQA)是一种改进的注意力机制,旨在降低模型的参数量和计算复杂度,特别适用于资源受限的环境,如移动设备上的应用。
2.2 工作原理
分组思想:GQA将多头注意力机制中的多个注意力头分为 g g g个组,每组共享一个查询投影矩阵,但仍然拥有独立的键和值投影矩阵。
输入表示:与MHA相同,输入 X ∈ R n × d model \mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}} X∈Rn×dmodel。
线性投影:
查询:每组内的头共享查询投影矩阵,共有
g
g
g个查询投影矩阵:
Q
(
j
)
=
X
W
Q
(
j
)
,
j
=
1
,
2
,
…
,
g
\mathbf{Q}^{(j)} = \mathbf{X} \mathbf{W}_Q^{(j)}, \quad j = 1, 2, \dots, g
Q(j)=XWQ(j),j=1,2,…,g
键和值:每个头仍然有独立的键和值投影矩阵,总共有
h
h
h个键和值投影矩阵。
计算每个头的注意力:
对于第 j j j组,第 i i i个头:
head ( j , i ) = Attention ( Q ( j ) , K ( j , i ) , V ( j , i ) ) \text{head}{(j, i)} = \text{Attention}(\mathbf{Q}^{(j)}, \mathbf{K}{(j, i)}, \mathbf{V}_{(j, i)}) head(j,i)=Attention(Q(j),K(j,i),V(j,i))
拼接和线性变换:与MHA类似,将所有头的输出拼接起来,通过线性层输出。
2.3 特点
减少参数量:由于查询投影矩阵在组内共享,参数量较MHA有所减少。
降低计算复杂度:共享查询减少了计算量,特别是在查询投影的部分。
折衷方案:在保持一定表达能力的同时,降低了模型的资源消耗。
3. MHA 和 GQA 的区别
3.1 查询投影矩阵(Q)的共享
MHA:每个头都有独立的查询、键和值投影矩阵。
GQA:查询投影矩阵在每个组内共享,键和值投影矩阵仍然是独立的。
3.2 参数量和计算复杂度
参数量
MHA:总参数量与头数 h h h成正比,因为每个头都有独立的投影矩阵。
GQA:参数量减少,因为查询投影矩阵共享,参数量与组数 g g g和头数 h h h有关。
计算复杂度
MHA:计算复杂度较高,需要计算所有独立投影。
GQA:由于查询投影减少,计算量有所降低,效率更高。
3.3 表达能力
MHA:每个头完全独立,具有最大的表达灵活性,能够在不同的子空间中捕获多样化的特征。
GQA:在组内头的查询受限于共享的投影矩阵,可能会略微降低表达能力,但通过保留独立的键和值投影矩阵,仍然能够捕获丰富的特征。
3.4 应用场景
MHA:适用于对模型性能要求高、资源相对充足的场景,如服务器端的模型训练和推理。
GQA:适用于资源受限的场景,如移动设备、嵌入式系统等,追求在降低资源消耗的同时保持较好的模型性能。
4. 直观理解
MHA:想象每个注意力头都有自己独立的“视角”,从查询、键和值三个方面独立观察输入序列。
GQA:在GQA中,每组内的注意力头共享“视角”(查询),但仍然可以通过自己的键和值关注不同的信息。这有点像一组人看着同一张地图(查询),但关注不同的地标(键和值)。
5. 总结
多头注意力机制(MHA)
特点:每个头都有独立的查询、键和值投影矩阵,最大化模型的表达能力。
优点:能够捕获输入序列中丰富的特征,适用于对性能要求高的场景。
缺点:参数量大,计算复杂度高,对资源要求较高。
分组查询注意力(GQA)
特点:在组内共享查询投影矩阵,减少参数量和计算量。
优点:在降低资源消耗的同时,尽可能保持模型的性能,适用于资源受限的场景。
缺点:由于共享查询,可能会影响模型的表达能力。