探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)
KV缓存
在推理的每一步中,只对模型输出的最后一个标记感兴趣,因为已经有了之前的标记。然而,模型需要访问所有先前的标记来决定输出哪个标记,因为它们构成了它的上下文(或“提示”)。
这是一种使模型在推理过程中对已经看到的标记进行更少计算的方法。解决办法就是KV缓存!
在Transformer的推理过程中, 增量且顺序地获取查询向量。将其乘以 Key 向量即可得到每个 token 与先前生成的 token 及其自身的注意力矩阵。然后,在取softmax之后, 乘以值向量以获得自注意力分数。最后有另一个输出投影矩阵,用于转换下一组多头注意力层的注意力分数。这个计算重复多次,然后得到词汇表中所有单词的概率分布
在上图中, 可以看到Transformer的推论。标记 TOKEN 1 到 TOKEN 4 按顺序出现,因为注意力计算 TOKEN 4 取决于所有先前的标记。
-
在紫色矩阵中, 可以看到 Q 和 K 矩阵乘法随着注意力矩阵一起增长,但 K 和 V 值矩阵对于所有先前的标记保持相同。另外,如图所示, 不需要已经计算出的注意力分数(需要注意的是, 可能需要波束搜索来获得它们,但这里 只考虑贪婪采样),所以 可以扔掉它们。深紫色矩阵实际上为零,因为它是因果矩阵,因此第一个标记从不关注第四个标记,并且它们被屏蔽。
-
因此 可以缓存 K 和 V 矩阵,因为它们不会改变。但是, 无法缓存 Q 矩阵。这是因为 Q 矩阵随着每个新标记而变化。查询矩阵是标记正在查找的内容,键矩阵是标记包含的内容,值矩阵是当前标记和前一个标记是否对词汇表中的标记感兴趣。
-
此外,可以借助电影数据库来理解查询(query)、键(key)和值(value)的概念。假设你想看一部能让你发笑,并且最后有一个“谁是凶手”环节的电影(这是查询)。那么首先,我们会在数据库中查询一部能让我们发笑的电影,这将是一部喜剧片(这是键)。然后,我们会得到一系列喜剧电影的推荐(这是值)。在那之后,电影数据库会获取到电影应该是“谁是凶手”类型或属于惊悚片类型的信息。然后,电影数据库将寻找喜剧和惊悚类型的电影(这是更新后的键),并且借助之前缓存的喜剧电影推荐,我们可以搜索那些同时也是惊悚片的电影(值)。
因此, 可以缓存喜剧类型和所有喜剧电影推荐,以便当新信息出现(惊悚类型)时, 可以缩小搜索范围并提高效率。
KV 缓存对于高效推理至关重要,因为 增量存储键和值矩阵并缓存它们,以便可以更快地计算未来的注意力分数。
def repeat_kv(x: torch.Tensor, n_rep: int)-> torch.Tensor:
batch_size, seq_len, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
else:
return (
# (B, seq_len, n_kv_heads, 1, head_dim)
x[:, :, :, None, :]
.expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
.reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
)
KV 缓存的一些问题
KV缓存一般存储在连续的内存中。如果有多个并行请求,那么它们需要单独存储,这会浪费内存,并可能导致 OOM(内存不足)错误。而且,每个请求的提示几乎相同(特别是像“你是一个有用的助手…”这样的系统提示),因此一次又一次地将它们存储在连续的内存中效率很低。
-
静态模型权重消耗了近 65% 的 VRAM 内存,而 KV 缓存则消耗了近 30%,因为它会因多个请求而增大且内存使用效率低下。并且,如果将 KV 缓存存储在连续的内存中,那么在一些服务之后需要将其取消分配以适应最近的 KV 缓存
-
如果想要生成具有一些初始响应的并行多个响应,那么需要为每个生成的响应单独存储它们在连续的内存中,这会浪费很多空间。此外,使用诸如束搜索(beam search)这样的高级技术时,会根据生成的的未来累积概率来选择最有可能的。在这里,需要回溯并关闭一些路径,因此对于束搜索中的每个方向,如果分配了一个新的连续内存,那么它将消耗大量内存,效率很低。
-
GPU在矩阵乘法方面已经变得非常擅长,但这些系统的记忆仍然有限,因此受内存限制。KV缓存可以帮助,因为它可以帮助更快地获取键和值矩阵以进行计算。但在内存有限的情况下,需要提出更好的内存管理方法。
系列博客
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328