Welcome to follow my CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/145368666
GroupQueryAttention (GQA) and KVCache (key-value caching) are two common building blocks of large language models. GQA is a variant of the attention mechanism that groups the query heads so that each group shares the same key (K) and value (V) heads. This keeps the model's ability to attend to the input while cutting the compute and memory cost of the K/V projections, which matters when processing large-scale data and long sequences. KVCache is a caching mechanism that stores the key-value pairs (KV) produced for past tokens; when the model processes a new input (Q), it reads the cached KV directly instead of recomputing it, which significantly speeds up inference. For example, with 8 query heads and 2 KV heads, every 4 query heads share one KV head, so both the KV projections and the KV cache shrink by 4x. GQA and KVCache each improve performance and resource utilization on their own, and combining them further strengthens the model in practice.
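To make the savings concrete, here is a minimal back-of-the-envelope sketch of the per-layer KV-cache size under MHA versus GQA. The head counts, context length, and fp16 assumption below are illustrative only, not taken from a specific model:

```python
# Rough per-layer KV-cache size: 2 (K and V) * batch * seq_len * kv_heads * d_k * bytes per element.
# All numbers are illustrative assumptions: fp16 activations, batch 1, 4096-token context.
def kv_cache_bytes(seq_len, kv_heads, d_k, batch=1, bytes_per_elem=2):
    return 2 * batch * seq_len * kv_heads * d_k * bytes_per_elem

heads, d_model = 32, 4096
d_k = d_model // heads
mha = kv_cache_bytes(seq_len=4096, kv_heads=heads, d_k=d_k)  # MHA: one KV head per query head
gqa = kv_cache_bytes(seq_len=4096, kv_heads=8, d_k=d_k)      # GQA: 8 shared KV heads
print(f"MHA: {mha / 2**20:.0f} MiB/layer, GQA: {gqa / 2**20:.0f} MiB/layer")  # 64 vs. 16 MiB
```

The reduction factor is exactly the group size `heads / kv_heads`, which is also why GQA and KVCache are usually discussed together.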
Simple implementations, going from MHA to GQA and then to GQA + KVCache; for reference:
- GQA: implementing the LLaMA3 network and inference pipeline from scratch
- KVCache: the KV Cache formulas and principles behind GPT (decoder-only) models
Scaled Dot-Product Attention, also known as single-head self-attention, is defined by:
$$Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V$$
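As a quick illustration, this formula maps almost line for line onto PyTorch. This is a minimal sketch; the helper `scaled_dot_product_attention` below is just a local function for this example, not the implementation used in the following sections:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [bs, s, d_k]; scores: [bs, s, s]
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.matmul(F.softmax(scores, dim=-1), v)

q = k = v = torch.randn(2, 10, 64)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([2, 10, 64])
```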
1. MultiHeadAttention
MultiHeadAttention, 43 lines in total:
- `__init__` initialization (10 lines):
  - Inputs: `heads` (number of heads), `d_model` (model dimension), and `dropout` (applied to `scores`).
  - Compute `d_k`, the per-head dimension, i.e. $d_{model} = heads \times d_{k}$.
  - Linear layers for Q, K, V, and O, plus a Dropout layer.
- `attention` (10 lines):
  - `q` has shape `[bs,h,s,d]`; multiplying it with $k^{\top}$ of shape `[bs,h,d,s]` gives `scores` of shape `[bs,h,s,s]`.
  - The mask has shape `[bs,s,s]` and is expanded to `[bs,1,s,s]` via `unsqueeze(1)`.
  - The QKV attention computation, with optional Dropout on the attention weights.
- `forward` (12 lines):
  - The Q/K/V Linear layers produce `[bs,s,h,dk]`, which is transposed to `[bs,h,s,dk]`.
  - Compute `attn`, of shape `[bs,h,s,dk]`.
  - Transpose back to `[bs,s,h,dk]`, call `contiguous()`, merge $h \times d_{k} = d$, then apply the output projection O.
- Test (11 lines):
  - Build data with `torch.randn`.
  - The mask is `torch.tril(torch.ones(bs, s, s))`.

That is:
```python
import math
import torch
import torch.nn.functional as F
from torch import nn


class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention (MultiHeadAttention)
    """
    def __init__(self, heads, d_model, dropout=0.1):  # 10 lines
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):  # 10 lines
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out padding positions so that they become 0 after softmax
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):  # 12 lines
        bs = q.size(0)
        # Linear projections, then split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # Transpose so that the head dimension comes first
        k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        print(f"[Info] attn: {attn.shape}")
        # Concatenate the heads and feed into the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output


def main():
    # Hyperparameters
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # Create a MultiHeadAttention instance
    attention = MultiHeadAttention(h, d, dropout_rate)
    # Create random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # Optional: causal mask, a lower-triangular matrix
    mask = torch.tril(torch.ones(bs, s, s))
    # Test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # Test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # Check that the outputs have the expected shape
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")


if __name__ == '__main__':
    main()
```
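As a sanity check, the static `attention` method can be compared against the built-in fused implementation. This is a sketch appended to the script above, assuming PyTorch 2.0+ where `torch.nn.functional.scaled_dot_product_attention` is available; dropout is left as `None` so both sides are deterministic:

```python
# Sketch: compare MultiHeadAttention.attention against F.scaled_dot_product_attention (PyTorch >= 2.0).
bs, h, s, d_k = 2, 8, 10, 64
q2 = torch.randn(bs, h, s, d_k)
k2 = torch.randn(bs, h, s, d_k)
v2 = torch.randn(bs, h, s, d_k)
mask2 = torch.tril(torch.ones(bs, s, s))

ours = MultiHeadAttention.attention(q2, k2, v2, d_k, mask=mask2, dropout=None)
ref = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=mask2.unsqueeze(1).bool())
torch.testing.assert_close(ours, ref, atol=1e-5, rtol=1e-5)
print("Matches F.scaled_dot_product_attention.")
```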
2. GroupQueryAttention
GroupQueryAttention differs from MHA as follows (see also `torch.nn.functional.scaled_dot_product_attention` for reference):

- `__init__`: add a `kv_heads` parameter (the number of KV heads); the output dimension of the K/V Linear layers (`kv_heads * self.d_k`) changes accordingly.
- `forward`: use `repeat_interleave` to expand the KV heads to match the query heads; everything else stays the same, adding 3 lines.

That is:
```python
import math
import torch
import torch.nn.functional as F
from torch import nn


class GroupQueryAttention(nn.Module):
    """
    Group Query Attention (GQA)
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out padding positions so that they become 0 after softmax
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # Linear projections
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)         # [2, 10, 8, 64]
        k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]
        v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)
        # Repeat the KV heads to match the number of query heads
        group = self.h // self.kv_heads
        k = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = v.repeat_interleave(group, dim=2)
        # Transpose so that the head dimension comes first
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # Concatenate the heads and feed into the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output


def main():
    # Hyperparameters; GQA: 8 // 4 = 2 query heads share each KV head
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # Create a GroupQueryAttention instance
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # Create random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # Optional: causal mask, a lower-triangular matrix
    mask = torch.tril(torch.ones(bs, s, s))
    # Test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # Test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # Check that the outputs have the expected shape
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")


if __name__ == '__main__':
    main()
```
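The saving shows up directly in the projection weights: with `d_model=512`, `heads=8`, and `kv_heads=4`, the K and V projections shrink from `512x512` to `512x256`. A small sketch that just counts the weight parameters (biases ignored):

```python
# Sketch: K/V projection weight counts for MHA vs. GQA (biases ignored).
d, h, kv_heads = 512, 8, 4
d_k = d // h
mha_kv = 2 * d * d                 # k_linear + v_linear in MultiHeadAttention
gqa_kv = 2 * d * (kv_heads * d_k)  # k_linear + v_linear in GroupQueryAttention
print(f"MHA KV weights: {mha_kv}, GQA KV weights: {gqa_kv}, ratio: {mha_kv / gqa_kv:.1f}x")  # 2.0x
```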
3. GQA + KVCache
GroupQueryAttention + KVCache adds a KV cache on top of GQA:

- `forward`: add a `kv_cache` parameter, concatenate `[cached_k, new_k]` (and likewise for V), and also return `new_kv_cache` for the next step, adding 5 lines.
- The test builds `cur_q`/`cur_k`/`cur_v` and `cur_mask` and iterates over the sequence dimension `s`, 8 lines in total.

That is:
```python
import math
import torch
import torch.nn.functional as F
from torch import nn


class GroupQueryAttention(nn.Module):
    """
    Group Query Attention (GQA) with KV Cache
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out padding positions so that they become 0 after softmax
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None, kv_cache=None):
        bs = q.size(0)
        # Linear projections
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)             # [2, 1, 8, 64]
        new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        # Handle the KV cache: prepend cached keys/values along the sequence dimension
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            new_k = torch.cat([cached_k, new_k], dim=1)
            new_v = torch.cat([cached_v, new_v], dim=1)
        # Repeat the KV heads to match the number of query heads
        group = self.h // self.kv_heads
        k = new_k.repeat_interleave(group, dim=2)  # [2, s, 4, 64] -> [2, s, 8, 64]
        v = new_v.repeat_interleave(group, dim=2)
        # Transpose so that the head dimension comes first
        # Last KV-cache step: q -> [2, 8, 1, 64], k -> [2, 8, 10, 64], v -> [2, 8, 10, 64]
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]
        print(f"[Info] attn: {attn.shape}")
        # Concatenate the heads and feed into the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        # Update the KV cache
        new_kv_cache = (new_k, new_v)  # current KV cache
        return output, new_kv_cache


def main():
    # Hyperparameters
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # Create a GroupQueryAttention instance
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # Create random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # Optional: causal mask, a lower-triangular matrix
    mask = torch.tril(torch.ones(bs, s, s))
    # Simulate step-by-step generation to test the KV cache
    print("Testing KV Cache...")
    kv_cache, output = None, None
    for i in range(s):
        cur_q = q[:, i:i+1, :]
        cur_k = k[:, i:i+1, :]
        cur_v = v[:, i:i+1, :]
        cur_mask = mask[:, i:i+1, :i+1]  # q covers i:i+1, k covers :i+1
        output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)
        print(f"Output shape at step {i}:", output.shape)
    # Check that the output has the expected shape
    assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"
    print("Test passed!")


if __name__ == "__main__":
    main()
```
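To confirm that incremental decoding with the cache reproduces a full-sequence forward pass, a minimal consistency check can be added at the end of `main()`. This is a sketch that assumes the tensors `q`, `k`, `v`, `mask` defined above and calls `attention.eval()` to disable dropout, otherwise the two passes are not deterministic:

```python
# Sketch: the concatenated per-step outputs should match a single full-sequence forward pass.
attention.eval()  # disable dropout so both passes are deterministic
with torch.no_grad():
    full_out, _ = attention(q, k, v, mask)  # one pass over all s tokens, no cache
    step_outs, kv_cache = [], None
    for i in range(s):
        out, kv_cache = attention(q[:, i:i+1, :], k[:, i:i+1, :], v[:, i:i+1, :],
                                  mask[:, i:i+1, :i+1], kv_cache)
        step_outs.append(out)
torch.testing.assert_close(torch.cat(step_outs, dim=1), full_out, atol=1e-4, rtol=1e-4)
print("KV Cache output matches the full forward pass.")
```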