LLM 注意力机制
- LLM 注意力机制
- 1. 注意力机制类型概述
- 2.Group Query Attention
- 3.FlashAttention
- 4. PageAttention
1. 注意力机制类型概述
注意力机制最早来源于Transformer,Transformer中的注意力机制分为2种 Encoder中的 全量注意力机制和 Decoder中的带mask的注意力机制。这两种注意力机制 都是 MultiHeadAttention 由Key,Query, Value 三个矩阵组成。
由于经典的MHA的计算时间和缓存占用量都是O(n^2)级别的(n是序列长度),这就意味着如果序列长度变成原来的 2 倍,显存占用量就是原来的 4 倍,计算时间也是原来的 4 倍。当然,假设并行核心数足够多的情况下,计算时间未必会增加到原来的 4 倍,但是显存的 4 倍却是实实在在的无可避免,这也是之前微调 Bert 的时候时不时就来个 OOM 的原因了。
所以不少工作致力于研究 降低Attention的计算复杂度和缓存大小,从而使复杂度从O(n^2) 降低到O(nlogn) 甚至O(n).
- 稀疏attention: SparseAttention,Longformer
- Reformer,Linformer:
- Linear Attention 思想: Q K t QK^t QKt 这一步我们得到一个nn的矩阵,就是这一步决定了Attention的复杂度是O(n^2);如果没有Softmax,那么就是三个矩阵连乘 Q K t V QK^{t}V QKtV,而矩阵乘法 是满足结合率的,所以我们可以先算 K t V K^{t}V KtV,得到矩阵dd,然后再用Q左乘它,由于d<<n,所以这样算大致的复杂度只有O(n)(就是Q左乘的那一步占主导)。也就说,去掉softmax的Attention的复杂度可以降低到最理想的线性级别。这显然是我们的终极追求:Linear Attention,复杂度为线性级别的Attention.
优化计算量和缓存后,LLM时代,推理速度加速的成为一个问题,于是针对推理慢的开始进行如下优化:
- IO传输瓶颈: 斯坦福团队发现 影响推理速度的瓶颈不在于计算量,而是IO传输。于是提出了减少IO传输的FlashAttention 1/2/3. FlashAttention论文的目标是尽可能高效地使用SRAM来加快计算速度。
- GPU显存瓶颈:研究人眼引入了 PagedAttention,这是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,PagedAttention 允许在非连续的内存空间中存储连续的 key 和 value 。具体来说,PagedAttention 将每个序列的 KV cache 划分为块,每个块包含固定数量 token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。
- 减少推理缓存: **GQA(group query attention)**分组注意力机制,在MQA基础上,增加多组Key,Value(但是不是全量),每个head独立拥有Query。
2.Group Query Attention
在自回归解码的标准做法是缓存序列中先前标记的键(K)和值(V) 对,从而加快注意力计算速度。然而,随着上下文窗口或批量大小的增加,多头注意力 (MHA)模型中与 KV 缓存大小相关的内存成本显着增长,所以随着上下文窗口增加,KV缓存大小成为瓶颈,为了扩展上下文,减少注意力机制的计算量和缓存大小,从而研究者开始对全量注意力机制的优化进行研究,目前主流的注意力机制主要分为3种:
- MHA(multi-head attention)全量注意力机制,每个head 独立拥有K Q V。
- MQA(multi-query attention)多查询注意力机制,多个head共享1组Key,Value,每个head独立拥有Query。
- 由于只是用一个 key 和value,大大加快解码推断的速度,但是可能导致质量下降。
- 目前ChatGLM2-6B使用的是这个
- GQA(group query attention)分组注意力机制,在MQA基础上,增加多组Key,Value(但是不是全量),每个head独立拥有Query。
- LLaMA2 和 Mistral采用的是这个。
- 属于1和2的折中,KV个数在1-head 中间.
- GQA 变体在大多数评估任务上的表现与 MHA 基线相当,并且平均优于 MQA 变体
3.FlashAttention
4. PageAttention
tobe added