LLM - Generate With KV-Cache 图解与实践 By GPT-2

目录

一.引言

二.KV-Cache 图解

1.Attention 计算

2.Generate WithOut KV-Cache

3.Generate With KV-Cache

4.Cache Memory Usage

三.KV-Cache 实践

1.WithOut KV-Cache

2.With KV-Cache

3.Compare Efficiency

四.总结


一.引言

LLM 推理中 KV-Cache 是最常见的优化方式,其通过缓存过去的 Keys、Values 从而提高 generate 每一个新 token 的速度,效果明显,是典型的空间换时间的做法,下面通过图示和 GPT-2 实测,看下 KV-Cache 的原理与实践。

二.KV-Cache 图解

1.Attention 计算

- MatMul-1        Q、K 负责计算当前 Token 与 候选 Token 之间的相似度

- Scale        防止 MatMul 值过大,对 MatMul 的值进行 Sqrt(d) 的缩放

- Mask        Causal Mask 时前后 Token 存在逻辑关系,后面的 Token 权重为 0 或很小的数

- SoftMax        权重归一化

- MatMul-2        根据相似度加权平均获取当前 Attention 后的结果

上面的流程简化一下,可以看作是一次 '基于 QK 相似度对 V 的加权平均' 的操作:

2.Generate WithOut KV-Cache

KV Cache 用于推理过程,下面我们以生成 "遥遥领先" 为例示范:

- <s>

生成遥遥领先之前,需要先从起始符 <s> 开始,其遵循前面图中的 Attention 计算公式:

- <s>遥

当前字符为 "<s>遥" 由于 '遥' 在 '<s>' 的后面,所以对于 '<s>' 而言,'遥' 的向量 V 是不会对 '<s>' 的 Attention 结果有影响的,也就是说对于 '<s>' 而言,'遥' 的向量 softmax 后权重是一个极小的接近于 0 的数字,下面有计算过程:

计算的得到的 1x2 的矩阵 1 对应 Batch Size,2 对应 seq_len,还有一个隐含的向量维度 Dim:

由 Att1、Att2 我们可以看到:

- Att1 的生成需要 K1、V1

- Att2 的生成需要 K1、K2、V1、V2

- <s>遥遥

Att3 的生成需要 K1K2K3V1V2V3

- <s>遥遥领

Att4 的生成需要 K1K2K3K4V1V2V3、V4

- <s>遥遥领先 ...

数学归纳法的原理是给定 F(0),F(1),再假设有 F(n-1) 看能否推出 F(n),

后续的生成过程就不再赘述了,根据 Casual Mask 的性质,不难得出:

 AttN 的生成需要 K1、,,,、KNV1、...、VN

3.Generate With KV-Cache

- Output Probability

Generate 生成 Next Token 时基于 Attention 的最后一个结果,举个例子:

'<s>遥遥领' 已经生成,此时需要预测 Next Token,通过 Attention 计算得到 1 x 4 x Dim 的 Attention 矩阵,而预测最终 token 的概率计算只参考最后 Dim 维的向量,即上图红框标注的位置。

- Repeat Counting

基于 Output Probability 的计算流程,再观察最右侧的 Attention 计算结果,我们发现每次计算都有很多的冗余,其实我们只需要获取每一步最后一维的 Tensor 即可,但是我们每一步都在重复计算前面的部分,所以 KV-Cache 应运而生:

 AttN 的生成需要 K1、,,,、KNV1、...、VN

通过缓存每一步的 Keys 和 Values 即可实现高效的 Next Token 的推理。

- Generate Process

通过缓存每一步的 K、V 实现高效的推理,因为我们只需要计算最后一维的向量即可,付出的代价是显存的增加,其与我们生成的 Response Token Length 成线性正比关系。

4.Cache Memory Usage

继续看刚才的示例,此时我们有如下参数:

- batch_size 1 \

- seq_len 4 \

- dim emb_size \ 

由于 Decoder-Only 是多层堆叠的结构,所以还有一个潜在的参数:

- layer_num N \

由于 Multi-Head Attention 是按照 emb_size = head_num x head_dim,所以我们这里直接按照总的 emb_size 计算,不再拆分 head,按照 FP16 计算其缓存的通式:

memory_usage = 2 x bsz x seq_len x  dim x layer_num x byte(FP16)

以 1 条样本、生成 512 长度、4096 维向量、32 层堆叠为例:

memory = 2 x 1 x 512 x 4096 x 32 x 2 = 268435456 / 1024 / 1024 = 256.0 MB

所以显存比较极限的场景下,也需要注意 KV-Cache 的显存占用,虽然是随着 SeqLen 线性增长的,但是架不住维度和堆叠的 Decoder 多。 

Tips:

这里解释下计算时前后两个 2 怎么来的,第一个 2 是因为 KV cache 里 K/V 各缓存一次;第二个 2 是因为 FP16 16 位占用 2 个 byte。

三.KV-Cache 实践

接下来我们实践下 KV-Cache,由于是本地实验,所以采用比较小的 GPT-2 作为实验 LLM。

Generate 时主要通过 past_key_values 传递 KV-Cache:

1.WithOut KV-Cache

import time

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def common(in_tokens, model, tokenizer, is_log=False):
    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    i = 0
    st = time.time()
    with torch.no_grad():
        while out_token != token_eos:
            logits, _ = model(in_tokens)
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = torch.cat((in_tokens, out_token), 0)
            text = tokenizer.decode(in_tokens)
            if is_log:
                print(f'step {i} input: {text}', flush=True)
            i += 1
    end = time.time()

    out_text = tokenizer.decode(in_tokens)
    print(f'Input: {in_text}')
    print(f'Output: {out_text}')
    print(f"Total Cost: {end - st} Mean: {(end - st) / i}")

token_id = 198 为 GPT-2 的 <eos>,我们手动停止生成,这里可以看到每一个 token 的预测过程:

- logits 通过 model 计算 logits 概率

- argmax 通过 argmax 获取概率最大的 token

- input_tokens = in_tokens + new_token 持续追加 seq_len 长度

- text 通过 tokenizer decode 即可获取 token 转变后的字符 

2.With KV-Cache

def cache(in_tokens, model, tokenizer, is_log=False):
    # inference
    token_eos = torch.tensor([198])  # line break symbol
    out_token = None
    kvcache = None
    out_text = in_text
    i = 0
    st = time.time()
    with torch.no_grad():
        while out_token != token_eos:
            logits, kvcache = model(in_tokens, past_key_values=kvcache)  # 增加了一个 past_key_values 的参数
            out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
            in_tokens = out_token  # 输出 token 直接作为下一轮的输入,不再拼接
            text = tokenizer.decode(in_tokens)
            if is_log:
                print(f'step {i} input: {text}', flush=True)
            i += 1
            out_text += text
    end = time.time()

    print(f'Input: {in_text}')
    print(f'Output: {out_text}')
    print(f"Total Cost: {end - st} Mean: {(end - st) / i}")

Generate Process: 

3.Compare Efficiency

if __name__ == '__main__':
    local_path = "/LLM/model/gpt2"

    model = GPT2LMHeadModel.from_pretrained(local_path, torchscript=True).eval()

    # tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(local_path)
    in_text = "Cristiano Ronaldo is a"
    in_tokens = torch.tensor(tokenizer.encode(in_text))

    common(in_tokens, model, tokenizer)
    cache(in_tokens, model, tokenizer)

比较 common 和 cache 的生成效果和时间:

- Common

- Cache

- Efficient

生成的结果相同,Cache 只需 Common 耗时的 43% 左右即可完成相同的推理。

Tips:

由于本地测试且 seq 比较短,所以这里就不参考本机显存变化了,需要的话大家可以用前面公式计算一下,这里 GPT-2 的 dim = 768,layer_num = 12。

四.总结

Generate 流程、Attention 计算以及 KV-Cache 的流程大致就这么多,下面总结下:

- 注意 Scale

Attention 计算时有一个 Scale 的操作,图中没有标注,注意不要忘记。

- Generate

生成是一个 token 一个 token 生成的,后一个 token 基于前面的所有 token。

- Q Cache

KV-Cache,有的同学肯定有疑问,为啥不把 Q 也 Cache 了。因为:

 AttN 的生成需要 K1、,,,、KNV1、...、VN

还需要 QN,因此缓存 Q1、Q2 ... 对于计算 AttN 没有意义,而且每一个 Q 都是需要基于前面序列来生成。

- Use Cache

{
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,

...

  "use_cache": true,
  "vocab_size": 32000
}

当前新出的 LLM 模型都在 config 内置了 use_cache 参数,上面是 Mistral config 中的部分参数,KV-Cache 都是 infer 时默认开启的。

!! 最后感谢下面大佬们的输出:

大模型推理加速:看图学KV Cache

大模型推理性能优化之KV Cache解读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/619884.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

若依-2主1从表(解决了编辑页面的添加按钮失效问题)

1. 3个表的分析&#xff08;表名里不要加“t_”&#xff0c;会出现问题&#xff09; 主表&#xff1a;t_qxk 这是试卷表 主表&#xff1a;t_ques_xk 这是题目表 子表&#xff1a;t_quescxk 这是试卷和题目的关系表&#xff0c;即同时是试卷和题目表的子表。 因为一张试卷可…

给centos机器打个样格式化挂载磁盘(新机器)

文章目录 一、先安装lvm2二、观察磁盘三、磁盘分区四、建PV五、建VG六、创建LV七、在LV上创建文件系统八、挂载到/home&#xff08;1&#xff09;临时挂载&#xff08;2&#xff09;永久挂载 九、最后reboot一下 一、先安装lvm2 yum install lvm2二、观察磁盘 三、磁盘分区 四…

QT 项目打包(为了后期远程实验用)

一、环境准备 1、一个项目工程 二、步骤 1、将编译器设置调整为Release模式 二、对项目重新编译构建 三、可以看到工程目录这个文件夹 打开工程目录文件夹的Release文件夹&#xff0c;我的路径如下 四、新建一个文件夹&#xff0c;将上述路径文件夹下的exe文件复制到新的文…

云相册APP

简介 一款用于云存照片的app&#xff0c;支持批量上传和下载照片。 平台技术 Android客户端&#xff1a;Kotlin 协程 Retrofit Server服务后端&#xff1a;Java SpringBoot 部署云服务器&#xff1a;华为云耀云服务器L实例 下载网址 小鲸鱼相册 Ps: 由于网站域名备案审核…

SQL STRING_SPLIT函数,将指定的分隔符将字符串拆分为子字符串行

文章目录 STRING_SPLIT (Transact-SQL)1、语法2、参数3、样例样例1样例2 STRING_SPLIT (Transact-SQL) STRING_SPLIT 是一个表值函数&#xff0c;它根据指定的分隔符将字符串拆分为子字符串行。 1、语法 STRING_SPLIT ( string , separator [ , enable_ordinal ] ) 2、参数…

ICLR上新 | 强化学习、扩散模型、多模态语言模型,你想了解的前沿方向进展全都有

编者按&#xff1a;欢迎阅读“科研上新”栏目&#xff01;“科研上新”汇聚了微软亚洲研究院最新的创新成果与科研动态。在这里&#xff0c;你可以快速浏览研究院的亮点资讯&#xff0c;保持对前沿领域的敏锐嗅觉&#xff0c;同时也能找到先进实用的开源工具。 今天的“科研上…

AlphaFold3—转录因子预测(实操)

写在前面 我们上一次已经介绍了如何使用AlphaFold3&#xff1a;最新AlphaFold 3&#xff1a;预测所有生物分子结构、相互作用 AlphaFold3可以做什么&#xff1f; 1.AlphaFold服务器可以对以下生物分子类型进行建模&#xff0c;评价其相互结合&#xff1a; 蛋白质 DNA RNA 生…

计算机网络-DHCPv6基础

前面我们学习了IPv6地址可以通过手动配置、无状态自动配置、DHCPv6配置&#xff0c;这里简单学习下DHCPv6的知识点。 一、DHCPv6概述 DHCPv6 (Dynamic Host Configuration Protocol for IPv6) 是一种网络协议&#xff0c;设计用于IPv6网络环境中自动为网络设备分配必要的配置信…

java -jar提示jar中没有主清单属性(no main manifest attribute)

目录 传送门前言排查原因问题1-》jdk17和jdk8共存导致idea的maven插件识别报错问题2-》pom.xml中mainClass下面的skip属性是罪魁祸首 其他办法&#xff08;修改jar包&#xff09; 传送门 SpringMVC的源码解析&#xff08;精品&#xff09; Spring6的源码解析&#xff08;精品&…

InfiniGate自研网关实现四

13.服务发现组件搭建和注册网关连接 以封装 api-gateway-core 为目的&#xff0c;搭建 SpringBoot Starter 组件&#xff0c;用于服务注册发现的相关内容处理。 这里最大的目的在于搭建起用于封装网关算力服务的 api-gateway-core 系统&#xff0c;提供网关服务注册发现能力。…

Mysql 多表查询,内外连接

内连接&#xff1a; 隐式内连接 使用sql语句直接进行多表查询 select 字段列表 from 表1 , 表2 where 条件 … ; 显式内连接 将‘&#xff0c;’改为 inner join 连接两个表的 on select 字段列表 from 表1 [ inner ] join 表2 on 连接条件 … ; select emp.id, emp.name, …

宝塔安装多个版本的PHP,如何设置默认的PHP版本

如何将默认的PHP版本设置为7.3.32&#xff0c; 创建软链接指向7.3版本&#xff0c;关键命令&#xff1a;ln -sf /www/server/php/73/bin/php /usr/bin/php 然后再查看PHP版本验证一下结果 [rootlocalhost ~]# ln -sf /www/server/php/73/bin/php /usr/bin/php [rootlocalho…

Mysql进阶-sql优化篇

sql优化 sql优化insert优化批量插入手动提交事务主键顺序插入大批量插入数据 主键优化数据组织方式页分裂页合并主键设计原则 order by 优化原则 group by优化limit优化count 优化count的几种用法 update优化 sql优化 insert优化 批量插入 Insert into tb_test values(1,Tom…

一文读懂设计模式-单例模式

单例模式&#xff08;Singleton Pattern&#xff09;提供了一种创建对象的最佳方式 单例模式涉及到一个单一的类&#xff0c;该类负责创建自己的对象&#xff0c;同时确保只有单个对象被创建&#xff0c;这个类提供了一种访问其唯一的对象的方式&#xff0c;可以直接访问&…

IPD推行成功的核心要素(四)IPD究竟分几期做更合适?

集成产品开发 IPD体系&#xff08;Integrated Product Developm e nt&#xff09;是产品创新型企业关于产品开发&#xff08;从概念到产品开发、发布直至退市的全过程&#xff09;的一种理念与方法。IPD体系强调以市场需求作为产品开发的驱动力&#xff0c;将产品开发作为一项投…

快手短剧,和爱优腾踏入同一条河流

文丨黄小艺 “我们定制短剧的重心排序分别是抖音、淘宝、快手。”MCN机构从业者周明&#xff08;化名&#xff09;说道&#xff0c;“无论是单条还是品牌冠名剧&#xff0c;我们在快手短剧拿到的收益都相对偏低。” 近期&#xff0c;商业数据派和多家机构创作者沟通后发现&am…

Windows系统安装MongoDB数据库

MongoDB是一个基于分布式文件存储的NoSQL数据库&#xff0c;由C语言编写的。MongoDB的数据存储基本单元是文档&#xff0c;它是由多个键值对有序组合的数据单元&#xff0c;类似于关系数据库中的数据记录。适合存储JSON形式的数据&#xff0c;数据格式自由&#xff0c;不固定。…

区块链共识机制的演进

分布式系统的基本概念 FLP不可能原理和CAP原理 FLP 不可能原理&#xff08;FLP impossibility&#xff09;&#xff1a;在网络可靠&#xff0c;存在节点失效&#xff08;即便只有一个&#xff09;的最小化异步模型系统中&#xff0c;不存在一个可以解决一致性问题的确定性算法…

动手实践DDD领域驱动设计,DDD到底好不好用?真有那么神吗

文章目录 一、到底什么是DDD1、传统的MVC三层架构2、DDD到底解决了什么问题3、DDD四层架构4、为什么需要舍弃MVC而用DDD 二、DDD改造实战1、充血模型2、避免大实体3、Dao改造4、构建防腐层5、抽象中间件6、使用领域服务&#xff0c;封装跨实体业务7、使用设计模式8、改造结果9、…

自然资源-城市更新从立项到开发全流程梳理

自然资源-城市更新从立项到开发全流程梳理 一、城市更新项目分类 &#xff08;一&#xff09;按改造力度划分&#xff1a;整治、改建和拆建 按照改造力度由弱到强&#xff0c;城市更新项目可分为 整治类、改建类和 拆建类三种类型。不同城市命名略有不同&#xff0c;但实质相…