大模型——推理优化——KV Cache

在本文中,我们将详细介绍KV Cache,这是一种大模型推理加速的方法。
正如其名称所示,该方法通过缓存Attention中的K和V来实现推理优化。

一、大模型推理的冗余计算

我们先简单观察一下基于Decoder架构的大模型的生成过程
用户输入“中国的首都”,模型续写得到的输出为“是北京”,模型的生成过程如下:

  1. 将“中国的首都”输入模型,得到每个token的注意力表示(绿色部分)。使用“首都”的注意力表示,预测得到下一个token为“是”(实际还需要将该注意力表示映射成概率分布logits,为了方便叙述,我们忽略该步骤)。

  2. 将“是”拼接到原来的输入,得到“中国的首都是”,将其输入模型,得到注意力表示,使用“是”的注意力表示,预测得到下一个token为“北”。

  3. 将“北”拼接到原来的输入,依此类推,预测得到“京”,最终得到“中国的首都是北京”

存在的问题:
在每一步生成中,仅使用输入序列中的最后一个token的注意力表示,即可预测出下一个token。但模型还是并行计算了所有token的注意力表示,其中产生了大量冗余的计算(包含qkv映射,attention计算等),并且输入的长度越长,产生的冗余计算量越大。

例如:

  1. 在第一步中,我们仅需使用“首都”的注意力表示,即可预测得到“是”,但模型仍然会并行计算出“中国”,“的”这两个token的注意力表示。

  2. 在第二步中,我们仅需使用“是”的注意力表示,即可预测得到“北”,但模型仍然会并行计算“中国”,“的”,“首都”这三个token的注意力表示。

二、Self-Attention过程解析

2.1 公式解析

假设输入序列长度为 n,第 j个token对于整个输入序列的注意力表示如下公式: 

                                    b^{j} = \sum_{i=1}^{n}softmax(q^{j} \cdot k^{i})v^{i}

j个token对于整个输入序列的注意力表示的计算步骤大致如下:

  1. 向量映射:将输入序列中的每个token的词向量分别映射为q,k,v三个向量。

  2. 注意力计算:使用q^{j}分别与每个k进行点乘,得到第j个token对每个token的注意力分数。

  3. 注意力分数归一化:对注意力分数进行softmax,得到注意力权重。

  4. 加权求和:注意力权重与对应的向量v加权求和,最终得到第j个token的注意力表示。

2.2 过程实例

下面将以图像的方式帮助大家更形象地理解Self Attention。

假设:

  • a = a^{1}a^{2}a^{3}a^{4}
  • a^{1}对于整个输入序列a的注意力值是b^{1}

根据上面的Self-Attention公式得出:

 b^{1} = \sum_{i=1}^{4}softmax(q^{1} \cdot k^{i})v^{i}

继续观察a^{2}对于整个输入序列a的注意力b^{2}表示  ,即:
b^{2} = \sum_{i=1}^{4}softmax(q^{2} \cdot k^{i})v^{i}

三、KV Cache

3.1 原理

  • 在推理阶段,当输入长度为 n,我们仅需使用  即可预测出下一个token,但模型却会并行计算出  ,这部分会产生大量的冗余计算。
  • 而实际上b^{n}可直接通过公式b^{n} = \sum_{i=1}^{n}softmax(q^{n} \cdot k^{i})v^{i}算出,即b^{n}的计算只与  q^{n}、所有 k 和  v有关
  • KV Cache的本质是以空间换时间,它将历史输入的token的kv缓存下来,避免每步生成都重新计算历史token的k和 v 以及注意力表示  b^{1}...b^{n-1},而是直接通过b^{n} = \sum_{i=1}^{n}softmax(q^{n} \cdot k^{i})v^{i}的方式计算得到 b^{n} ,然后预测下一个token。

3.2 KV cache过程

第一步生成时,缓存  K,V均为空,输入为“中国的首都”,模型将按照常规方式并行计算:

  1. 并行计算得到每个token对应的  k,v,以及注意力表示b^{1},b^{2},b^{3}  。

  2. 使用 b^{3} 预测下一个token,得到“是”。

  3. 更新缓存,令 K=[k^{1},k^{2},k^{3}],V=[v^{1},v^{2},v^{3}] 。

第二步生成时,计算流程如下:

  1. 仅将“是”输入模型,对其词向量进行映射,得到 q^{4},k^{4},v^{4} 。

  2. 更新缓存,令 K=[k^{1},k^{2},k^{3},k^{4}],V=[v^{1},v^{2},v^{3},v^{4}]  。

  3. 计算  b^{4} = \sum_{i=1}^{4}softmax(q^{4} \cdot k^{i})v^{i},预测下一个token,得到“北”

第三步生成时,计算流程如下:

  1. 仅将“北”输入模型,对其词向量进行映射,得到q^{5},k^{5},v^{5}  。

  2. 更新缓存,令 K=[k^{1},k^{2},k^{3},k^{4},k^{5}],V=[v^{1},v^{2},v^{3},v^{4},v^{5}]   。

  3. 计算 b^{5} = \sum_{i=1}^{5}softmax(q^{5} \cdot k^{i})v^{i} ,预测下一个token,得到“京”

 

上述生成流程中,只有在第一步生成时,模型需要计算所有token的 k,v ,并且缓存下来。
此后的每一步,仅需计算当前token的 q^{n},k^{n},v^{n} ,更新缓存 K,V,然后使用 q^{n},K,V 即可算出当前token的注意力表示,最后用来预测一下个token。 

3.3 代码修改

这里参考gpt2里面的代码实现

https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)  # 当前token对应的query
key = self._split_heads(key, self.num_heads, self.head_dim)  # 当前token对应的key
value = self._split_heads(value, self.num_heads, self.head_dim)  # 当前token对应的value

if layer_past is not None:
    past_key, past_value = layer_past  # KV Cache
    key = torch.cat((past_key, key), dim=-2)  # 将当前token的key与历史的K拼接
    value = torch.cat((past_value, value), dim=-2)  # 将当前token的value与历史的V拼接

if use_cache is True:
    present = (key, value)
else:
    present = None

# 使用当前token的query与K和V计算注意力表示
if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

#参考
 图解大模型推理优化之KV Cache

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/358558.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springboot本地测试

文章目录 本地测试引入依赖进入StudentMapper右键点击生成 项目结构 本地测试 引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope> </d…

【劳德巴赫 Trace32 高阶系列 3 -- trace32 svf 文件操作命令】

请阅读【嵌入式开发学习必备专栏 之 Trace32 系列 】 文章目录 Trace32 SVF 文件操作命令JTAG.PROGRAM.autoJTAG.PROGRAM.SVF命令参数介绍IRPREIRPOSTDRPREDRPOSTInitStateIgnoreTDOVerbose使用示例Trace32 SVF 文件操作命令 JTAG.PROGRAM.auto Format: JTAG.PROGRAM.</

mfc140.dll找不到了要怎么解决?教你多种修复mfc140.dll的方法

遭遇 mfc140.dll 文件缺失的状况时&#xff0c;首要任务是保持冷静&#xff0c;并深入理解问题所在&#xff0c;随后按照科学的方法来应对这一挑战。本篇文章概述了多种应对策略&#xff0c;从适合新手的基本步骤到针对有技术基础用户的高级方案&#xff0c;各种手段都能有效地…

[Bug] [OpenAI] [TypeError: fetch failed] { cause: [Error: AggregateError] }

[Bug] [OpenAI] [TypeError: fetch failed] { cause: [Error: AggregateError] } ubuntu20 win10 edge浏览器访问 服务器部署 页面打开后想使用chatgpt报错了 rootcoal-pasi1cmp:/www/wwwroot/ChatGPT-Next-Web# PORT3000 yarn start yarn run v1.22.19 warning package.json:…

多场景建模:腾讯3MN

3MN: Three Meta Networks for Multi-Scenario and Multi-Task Learning in Online Advertising Recommender Systems 背景 推荐领域的多场景多任务学习&#xff1a;维护单模型即可节省资源也可节省人力&#xff1b;各个场景的数据共享&#xff0c;理论上面学习是更加充分的 …

基于ldap实现登录认证

最近开发的应用需要外协人员实现登录认证&#xff0c;外协人员的密码等信息已经录入到ldap, 需要连接ldap进行登录认证。下面先介绍一下登录的网络旅程图。 一.nginx实现AES加密 nginx请求处理入口&#xff08;前端请求为json格式&#xff09; location /aes {default_type te…

leetcode常见错误

1 runtime error: load of null pointer of type ‘std::_Bit_type‘ (aka ‘unsigned long‘) (stl_bvector&#xff09; 力扣&#xff1a;runtime error: load of null pointer of type ‘std::_Bit_type‘ (aka ‘unsigned long‘) (stl_bvector&#xff09;_runtime error…

GitLab 中国发行版如何设置镜像拉取策略?

最近在用极狐GitLab&#xff08;极狐GitLab 可以理解为 GitLab 在中国的发行版&#xff09; CI/CD 的时候遇到一个问题&#xff1a;CI/CD 中有一个 stage 需要拉取 dockerhub 上的镜像&#xff0c;但是由于 dockerhub 在国内的访问不是很顺畅&#xff0c;经常发生 timeout 的情…

方法阻塞的解决方案之一

1、简单使用 一个h一个cpp文件 #pragma once #include <iostream> #include <thread> #include <atomic> #include <chrono> #include <string>class Person {public:struct dog {std::string name;int age;};public:void a(std::atomic<bo…

链表——超详细

一、无头单向非循环链表 1.结构&#xff08;两个部分&#xff09;&#xff1a; typedef int SLTDataType; typedef struct SListNode {SLTDataType data;//数据域struct SListNode* next;//指针域 }SLNode; 它只有一个数字域和一个指针域&#xff0c;里面数据域就是所存放的…

一些著名的软件都用什么语言编写?

1、操作系统 Microsoft Windows &#xff1a;汇编 -> C -> C 备注&#xff1a;曾经在智能手机的操作系统&#xff08;Windows Mobile&#xff09;考虑掺点C#写的程序&#xff0c;比如软键盘&#xff0c;结果因为写出来的程序太慢&#xff0c;实在无法和别的模块合并&…

宠物空气净化器适合养猫家庭吗?猫用空气净化器品牌推荐!

养宠物的家庭都了解到&#xff0c;宠物掉毛是一个令人头痛的问题。即使我们及时清理地面&#xff0c;也很难跟上宠物掉毛的速度。飘散的毛发不仅让家里显得不整洁&#xff0c;还可能对家人的呼吸健康产生影响&#xff0c;甚至引起过敏反应。此外&#xff0c;猫咪每天上厕所&…

vue3之echarts3D环柱图-间隔版

vue3之echarts3D环柱图-间隔版 效果&#xff1a; 版本 "echarts": "^5.4.1", "echarts-gl": "^2.0.9" 核心代码&#xff1a; <template><div class"content"><div ref"eCharts" class"c…

零基础自学C语言|内存函数

&#x1f50d;memcpy的使用与模拟实现 格式如下&#xff1a; void* memcpy(void* destination, const void* source, size_t num); 函数memcpy从source的位置开始向后复制num个字节的数据到destination指向的内存位置。这个函数在遇到\0的时候并不会停下来。如果source和de…

Flask 入门1:一个简单的 Web 程序

1. 关于 Flask Flask诞生于2010年&#xff0c; Armin Ronacher的一个愚人节玩笑。不过现在已经是一个用python语言基于Werkzeug工具箱编写的轻量级web开发框架&#xff0c;它主要面向需求简单&#xff0c;项目周期短的小应用。 Flask本身相当于一个内核&#xff0c;其他几乎所…

基于SpringBoot实现的AI智能大数据医疗诊断平台

系统介绍 系统演示 微信关注视频号&#xff1a;【全栈小白】&#xff0c;查看演示视频 基于SpringBoot实现的AI智能大数据医疗诊断平台&#xff0c;主要包含六个大模块&#xff1a;系统管理、居民医保信息、药物信息管理、居民健康信息、居民就诊信息和我的预约信息。项目启…

Python中的递归函数是什么

Python 递归函数 递归的特性&#xff1a; 1.调用自身函数 2.有一个结束条件 3.递归效率不高&#xff0c;可能会导致栈溢出(函数调用是通过栈这种数据结构实现的&#xff0c;每进入一个函数调用&#xff0c;栈就会增加一层栈帧&#xff0c;函数每返回&#xff0c;栈就会减少…

C++11—— lambda表达式与包装器

C11—— lambda表达式与包装器 文章目录 C11—— lambda表达式与包装器一、 lambda表达式lambda表达式产生的意义lambda表达式语法函数对象与lambda表达式 二、 包装器functionfunction产生的意义function的用法function使用的例子 bind调整参数顺序固定绑定参数 一、 lambda表…

解锁Web3:数字未来的大门

随着科技的不断推进&#xff0c;我们正站在数字时代的新门槛上。Web3&#xff0c;作为互联网的下一个演进阶段&#xff0c;正在逐渐揭开数字未来的面纱。本文将深入探讨Web3的本质、对社会的影响以及在数字时代中所扮演的关键角色。 什么是Web3&#xff1f; Web3是互联网发展的…

git仓库批量备份

git的mirror参数 在git中&#xff0c;--mirror是一个用于克隆和推送操作的参数。它用于创建一个镜像仓库&#xff0c;包含了源仓库的所有分支、标签和提交历史记录。 当使用git clone --mirror <source-repo>命令时&#xff0c;会创建一个完全相同的镜像仓库&#xff0…