大模型推理优化之 KV Cache

原文:https://zhuanlan.zhihu.com/p/677660376

目录

收起

KV Cache 定义

KV Cache 原理

KV Cache 实现

KV Cache 显存占用分析

KV Cache 优化方法

在语言模型推理的过程中,性能优化一直是一个备受关注的话题。LLM(Large Language Models)的出现使得自然语言处理取得了显著的进展,但随之而来的是庞大的模型和复杂的计算过程,因此推理效率的提升变得至关重要。在这个背景下,KV Cache(键-值缓存)成为了一项被广泛应用的推理优化技术。

KV Cache 定义

KV Cache,即键-值缓存,是一种用于存储键值对数据的缓存机制。在语言模型的推理过程中,经常需要多次访问相同的数据,而KV Cache通过将这些数据缓存到内存中,提供了快速的数据访问速度,从而加速推理过程。该技术仅应用于解码阶段。如 decode only 模型(如 GPT3、Llama 等)、encode-decode 模型(如 T5)的 decode 阶段,像 Bert 等非生成式模型并不适用。

KV Cache 原理

推理过程:给定一个问题,模型会输出一个回答。生成回答的过程每次只生成一个 token,输出的 token会和问题拼接在一起,再次作为输入传给模型,这样不断重复直至生成终止符停止。

GPT-4推理过程图

下图是Scaled dot-product attention 有无 KV Cache 优化计算过程的比较。一般情况下,在每个生成步骤中,都会重新计算之前token的注意力,而实际上我们只想计算新 token 的注意力。而采用 KV Cache 方法后,会把之前 Token的 KV 值存下来,新 token 预测时只需要从缓存中读取结果就可以了。

动图封面

Scaled dot-product attention 有无 KV Cache 比较,图片来源:https://medium.com/@joaolages/kv-caching-explained-276520203249

KV Cache 实现

huggingface的 transformer库已经实现了 KV cache,在推理时新增了past_key_values,设置 use_cache=True 或 config.use_cache=True 就可以了。

past_key_values( Cacheor tuple(tuple(torch.FloatTensor)), optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the past_key_valuesreturned by the model at a previous stage of decoding, when use_cache=Trueor config.use_cache=True.
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "Llama-2-7b-chat-hf"

device = "cuda:7" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
      start = time.time()
      input = tokenizer("What is KV caching?", return_tensors="pt").to(device)
      outputs = model.generate(**input, use_cache=use_cache, max_new_tokens=1000, temperature=0.00001)
      times.append(time.time() - start)
  print(f"{'With' if use_cache else 'Without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

执行结果如下所示:

With KV caching: 8.946 +- 0.011 seconds
Without KV caching: 58.68 +- 0.012 seconds

从结果可以看出使用 KV Cache 方法进行大模型推理,推理速度增加了6.56倍,差异巨大。

KV Cache 显存占用分析

假设输入的序列长度是 𝑚,输出序列长度是 𝑛 , 𝑏 为数据批次大小, 𝑙 为层数, ℎ 为隐向量维度,以 FP16(2bytes) 来保存,那么 KV Cache的峰值显存占用大小为 𝑏(𝑚+𝑛)ℎ∗𝑙∗2∗2=4𝑏𝑙ℎ(𝑚+𝑛) ,第一个 2 代表 K、V,第二个 2 代表 2bytes。可见随着批次大小和长度的增加,KV Cache 的显存占用也会快速增大。

KV Cache 优化方法

这里主要介绍下 Multi Query Attention 和 Grouped Query Attention。

Multi Query Attention

Multi-query attention is identical except that the different heads share a single set of keys and values.

MQA 和 MHA的区别是:每个头共享相同的 K、V 权重而不共享Q的权重。

Grouped Query Attention

Grouped-query attention divides query heads intoG groups, each of which shares a single key head and value head.

分组注意力将查询头分为 G 组,每组共享一个键头和值头。GQA-G 是指有 G 组的分组查询。GQA-1,有一个组,因此有一个键头和值头,等同于 MQA,而 GQA-H,组数等于头数,等同于 MHA。

MHA、MQA、GQA比较可参考下图。

Multi Head Attention、Multi Query Attention、Grouped Query Attention 比较

使用MHA、MQA、GQA进行KV Cache 显存占用情况比较

MHA: 𝑏(𝑚+𝑛)ℎ∗𝑙∗2∗2=4𝑏𝑙ℎ(𝑚+𝑛) ;

MQA: 4𝑏𝑙ℎ(𝑚+𝑛)/𝐻 , 𝐻 代表头数;

GQA: 4𝑏𝑙ℎ(𝑚+𝑛)∗𝐺/𝐻 , 𝐻 代表头数, 𝐺 代表分组数;

MQA、GQA Huggingface 库都有实现,具体见llm_tutorial_optimization。

如果觉得本文对您有帮助,麻烦点个小小的赞,谢谢大家啦~

Transformers KV Caching Explained
Huggingface-llama2
Fast Transformer Decoding: One Write-Head is All You Need
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

编辑于 2024-01-16 21:52・IP 属地江苏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/572026.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于stm32的UART高效接收DMA+IDLE编程示例

目录 基于stm32的UART高效接收DMAIDLE编程示例实验目的场景使用原理图UART的三种编程方式IDLE程序设计串口配置配置中断配置DMA代码片段本文中使用的测试工程 基于stm32的UART高效接收DMAIDLE编程示例 本文目标:基于stm32_h5的freertos编程示例 按照本文的描述&am…

训练营第三十三天贪心(第五部分重叠区间问题)

训练营第三十三天贪心(第五部分重叠区间问题) 435.无重叠区间 力扣题目链接 题目 给定一个区间的集合 intervals ,其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量,使剩余区间互不重叠 。 示例 1: 输入: …

MySQL多版本并发控制mvcc原理浅析

文章目录 1.mvcc简介1.1mvcc定义1.2mvcc解决的问题1.3当前读与快照读 2.mvcc原理2.1隐藏字段2.2版本链2.3ReadView2.4读视图生成原则 3.rc和rr隔离级别下mvcc的不同 1.mvcc简介 1.1mvcc定义 mvcc(Multi Version Concurrency Control),多版本并发控制,是…

Kubeedge:Metamanager源码速读(不定期更新)

Kubeedge源码版本:v1.15.1 在看Metamanager之前,先看一下Metamanager源码的目录结构(位于edge/pkg下)和官方文档: 目录结构如下面的两张图所示。请忽略绿色的文件高亮,这是Jetbrains goland对未提交修改的…

【教程】MySQL数据库学习笔记(五)——约束(持续更新)

写在前面: 如果文章对你有帮助,记得点赞关注加收藏一波,利于以后需要的时候复习,多谢支持! 【MySQL数据库学习】系列文章 第一章 《认识与环境搭建》 第二章 《数据类型》 第三章 《数据定义语言DDL》 第四章 《数据操…

【求助】西门子S7-200PLC定时中断+数据归档的使用

前言 已经经历了种种磨难来记录我的数据(使用过填表程序、触摸屏的历史记录和数据归档)之后,具体可以看看这篇文章:🚪西门子S7-200PLC的数据归档怎么用?,出现了新的问题。 问题的提出 最新的…

25 - MOV 指令

---- 整理自B站UP主 踌躇月光 的视频 文章目录 1. 指令系统设计2. MOV 指令3. 实现3.1 CPU 电路图3.2 代码实现3.3 实验过程3.4 实验结果3.5 实验工程 1. 指令系统设计 指令 IR 8 位程序状态字 4 位微程序周期 4 位 2. MOV 指令 MOV A, 5; 立即寻址 MOV A, B; 寄存器寻址 MO…

基于PaddlePaddle平台训练物体分类——猫狗分类

学习目标: 在百度的PaddlePaddle平台训练自己需要的模型,以训练一个猫狗分类模型为例 PaddlePaddle平台: 飞桨(PaddlePaddle)是百度开发的深度学习平台,具有动静统一框架、端到端开发套件等特性&#xf…

tailwindcss在使用cdn引入静态html的时候,vscode默认不会提示问题

1.首先确保vscode下载tailwind插件:Tailwind CSS IntelliSense 2.需要在根目录文件夹创建一个tailwind.config.js文件 export default {theme: {extend: {// 可根据需要自行配置,空配置项可以正常使用},}, }3.在html文件的标签中引入配置文件&#xf…

程序员到架构师,除了代码,还有文档和图

文章目录 前言一、书面设计文档文档应该作为代码和口头交流的补充文档应该注意鲜活 二、图——架构讨论的直观语言总结 前言 作为人类,我们天生就被视觉所吸引。在这个信息爆炸的时代,从精炼的代码到清晰的文档,再到直观的图,我们…

【数据结构】串(String)

文章目录 基本概念顺序存储结构比较当前串与串s的大小取子串插入删除其他构造函数拷贝构造函数扩大数组空间。重载重载重载重载[]重载>>重载<< 链式存储结构链式存储结构链块存储结构 模式匹配朴素的模式匹配算法(BF算法)KMP算法字符串的前缀、后缀和部分匹配值nex…

Parade Series - CoreAudio Reformating

// 获得音频播放设备格式信息CComHeapPtr<WAVEFORMATEX> pDeviceFormat;pAudioClient->GetMixFormat(&pDeviceFormat);constexpr int REFTIMES_PER_SEC 10000000; // 1 reference_time 100nsconstexpr int REFTIMES_PER_MILLISEC 10000;// Microsoftif (p…

Golang | Leetcode Golang题解之第49题字母异位词分组

题目&#xff1a; 题解&#xff1a; func groupAnagrams(strs []string) [][]string {mp : map[[26]int][]string{}for _, str : range strs {cnt : [26]int{}for _, b : range str {cnt[b-a]}mp[cnt] append(mp[cnt], str)}ans : make([][]string, 0, len(mp))for _, v : ra…

Alibaba 的fastjson源码详解

一、概述 Fastjson 是阿里巴巴开源的一个 Java 工具库&#xff0c;它常常被用来完成 Java 的对象与 JSON 格式的字符串的相互转化。 Fastjson 可以操作任何 Java 对象&#xff0c;即使是一些预先存在的没有源码的对象。 二、源码分析 1.首先以fastjson-1.2.70为例&#xff0c;…

nodejs

334 先下载zip文件&#xff0c;然后加上.zip,可以看到两个文件 在user中可以看到 输入即可得到flag。 335. 这里提到eval函数&#xff0c;eval中可以执行js代码&#xff0c;可以尝试使用这个函数进行测试 payload&#xff08;显示当前目录下的文件和文件夹列表&#xff09; …

基于emp的mysql查询

SQL命令 结构化查询语句&#xff1a;Structured Query Language 结构化查询语言是高级的非过程化变成语言&#xff0c;允许用户在高层数据结构上工作。是一种特殊目的的变成语言&#xff0c;是一种数据库查询和程序设计语言&#xff0c;用于存取数据以及查询、更新和管理关系数…

Python 网络与并发编程(四)

文章目录 协程Coroutines协程的核心(控制流的让出和恢复)协程和多线程比较协程的优点协程的缺点 asyncio实现协程(重点) 协程Coroutines 协程&#xff0c;全称是“协同程序”&#xff0c;用来实现任务协作。是一种在线程中&#xff0c;比线程更加轻量级的存在&#xff0c;由程…

android脱壳第二发:grpc-dumpdex加修复

上一篇我写的dex脱壳&#xff0c;写到银行类型的app的dex修复问题&#xff0c;因为dex中被抽取出来的函数的code_item_off 的偏移所在的内存&#xff0c;不在dex文件范围内&#xff0c;所以需要进行一定的修复&#xff0c;然后就停止了。本来不打算接着搞得&#xff0c;但是写了…

基础SQL DCL语句

DCL是数据控制语言&#xff0c;用来管理数据库用户&#xff0c;还有控制用户的访问权限 1.用户的查询 MySQL的用户信息存储在mysql数据库中&#xff0c;查询用户时&#xff0c;我们需要使用这个数据库。 后面&#xff0c;还有很多数据&#xff0c;因为篇幅的问题&#xff0c;就…

【FFmpeg】音视频录制 ② ( 使用 Screen Capturer Recorder 软件生成 ffmpeg 可录制的音视频设备 )

文章目录 一、使用 Screen Capturer Recorder 软件生成音视频设备1、设备查找问题 - 引入 Screen Capturer Recorder 软件2、下载安装 Screen Capturer Recorder 软件3、验证 Screen Capturer Recorder 生成的设备 一、使用 Screen Capturer Recorder 软件生成音视频设备 1、设…