LLM推理优化笔记1:KV cache、Grouped-query attention等

KV cache

对于decoder-only 模型比如现在如火如荼的大模型,其在生成内容的过程中,为了避免冗余计算,会将Transformer里的self-attention的K和V矩阵给缓存起来,这个过程即为KV cache。

在这里插入图片描述

decoder-only模型的生成过程是自回归的(auto-regressive),生成过程中先根据输入生成下一个token,再将生成的token与输入一起生成下一个token,重复这个过程直到遇到停止符号或者达到限定的输出token个数。(gif图来自illustrated-gpt2)
在这里插入图片描述

因为decoder-only模型的生成过程是自回归的,并且decoder的self-attention是causal的,即每一个token的attention计算只与其前面的tokens有关,所以我们每生成一个token时都重复计算了前面出现过的token的attention。为了节省计算量,可以将已经计算过的token的attention矩阵存储下来,每生成下一个token时直接使用存储好的attention矩阵并将新计算的token attention存储起来。(下面图片来自博客,不考虑softmax和scale示意对比KV cache使用)

在这里插入图片描述

在每一步计算时,只需要使用到上一步计算过的K和V矩阵,所以KV cache只会缓存K和V。当然缓存的代价就是需要额外的显存存储:

  • 每缓存一个token,其需要的空间为 2 * precision_in_bytes * head_dim * n_heads * n_layers(式中2是因为缓存K和V两个矩阵,precision_in_bytes是token的存储精度占用字节大小,head_dim是attention的head维度,n_head是attention的head个数,n_layers是transformer的层个数)。
  • 对于16-bit精度的模型以最大上下文长度max_context_length进行批量推理要求的缓存大小2 * 2 * head_dim * n_heads * n_layers * max_context_length * batch_size,比如Llama-2-13B模型对应最大上下文窗口为4096,batch大小为8时要求的缓存显存最多高达25GB左右。

transformers包生成时默认使用KV cache(use_cache=True),我们可以用如下代码去测试一下使用了KV cache以及不使用时的性能差异。

## 代码来自 https://medium.com/@joaolages/kv-caching-explained-276520203249
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
    times.append(time.time() - start)
  print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

Multi-query attention 和Grouped-query attention

Multi-query attention

Multi-query attention(MQA)出自2019年11月的论文《Fast Transformer Decoding: One Write-Head is All You Need》,它让multi-head attention里的多个head共享K和V矩阵,并做试验验了修改之后模型的性能下降不明显,但因为减少了参数,推理时KV cache占用的存储和读取时间都会少很多。

Grouped-query attention

在这里插入图片描述

Grouped-query attention(GQA)出自2023年5月的论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》, 如上图所示,它的共享K和V矩阵介于Multi-query attention(MQA)和Multi-head attention(MHA)之间,通过实验证明GQA可达到类似MQA的速度以及MHA的性能。

Grouped-query attention将query heads划分为G个groups,每一组query heads共享一个key head和value head,将 G Q A − G GQA_{-G} GQAG 记为有G个groups的grouped-query attention,则 G Q A − 1 GQA_{-1} GQA1为Multi-query attention, G Q A − H GQA_{-H} GQAH则等价于Multi-head attention。

论文还提出了一个将Multi-head attention模型转变MQA或GQA模型的方法,其分为两步:

  • 将MHA模型的checkpoint转变成MQA或GQA模型,使用如下图示意的mean pooling将多个K和V矩阵变成单个矩阵(论文做了试验比较选取第一个head、随机初始化、mean pooling,mean pooling的效果是最好的)。
  • 使用少量比例(5%左右)的预训练数据来继续预训练使模型适应新结构。

在这里插入图片描述

关于GQA的组个数选取,论文做了消融实验后对于总head个数为64时G选取的是8,而在Llama2-70B模型也是8(总heads数也为64)。

在这里插入图片描述

实现

不考虑性能的代码示意如下:

from dataclasses import dataclass
import math
import torch
import torch.nn as nn 
from torch.nn import functional as F

@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension
    n_kv_heads: int = 12 # grouped-query的group个数

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Perform repeat of kv heads along a particular dimension.
    hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
    n_rep: amount of repetitions of kv_n_heads
    Unlike torch.repeat_interleave, this function avoids allocating new memory.
    from https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py#L47
    llama2里的写法差不多https://github.com/meta-llama/llama/blob/llama_v2/llama/model.py#L164C1-L165C1
    """
    if n_rep == 1:
        return hidden
    (b, s, kv_n_heads, d) = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)


## adapt from https://github.com/karpathy/nanoGPT/blob/master/model.py
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
    

### multi-query
class MultiQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_embd//self.n_head, self.n_embd//self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = repeat_kv(k.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)
                # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
    
### grouped-query attention
class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_kv_heads*config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_kv_heads = config.n_kv_heads
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_kv_heads*self.n_embd//self.n_head, self.n_kv_heads*self.n_embd//self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = repeat_kv(k.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)

                # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

Sliding Window Attention

Mistral 7B使用Sliding Window Attention(SWA)来减少KV cache的内存占用,每次计算attention时,只考虑固定窗口大小W内的信息。对于位置i的隐状态,只会考虑在其前面i-W到i的窗口内的隐状态信息,如下图示意所示,所以对于在第k层的位置i来说,最多可以访问到 W × k W\times k W×k个tokens。在Mistral 7B里,W=4096,层数为32,所以理论上的attention范围近似为131K。

在这里插入图片描述

因为使用固定attention窗口,所以Mistral 7B使用滚动(rolling) buffer cache, cache大小固定为W,在时刻t的K和V存储在cache的第i mod W个位置,也就是说如果位置i比W大,cache中原先存储的值会被覆盖掉。下图是W=3时的示意。
在这里插入图片描述

参考资料

  1. 看图学KV Cache

  2. Transformer Inference Arithmetic

  3. Transformers KV Caching Explained(其gif动画有助于加深理解)

  4. KV caching内存增长

  5. KV cache 是chatbot 规模化的一大工程挑战

  6. Techniques for KV Cache Optimization in Large Language Models

  7. KV cache quantization

  8. Inference Optimization

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/797134.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

单元测试实施最佳方案(背景、实施、覆盖率统计)

1. 什么是单元测试&#xff1f; 对于很多开发人员来说&#xff0c;单元测试一定不陌生 单元测试是白盒测试的一种形式&#xff0c;它的目标是测试软件的最小单元——函数、方法或类。单元测试的主要目的是验证代码的正确性&#xff0c;以确保每个单元按照预期执行。单元测试通…

Android C++系列:Linux网络(三)协议格式

1. 数据包封装 传输层及其以下的机制由内核提供,应用层由用户进程提供(后面将介绍如何使用 socket API编写应用程序),应用程序对通讯数据的含义进行解释,而传输层及其以下 处理通讯的细节,将数据从一台计算机通过一定的路径发送到另一台计算机。应用层 数据通过协议栈发到…

搜索引擎中的相关性模型

一、什么是相关性模型&#xff1f; 相关性模型主要关注的是query和doc的相关性。例如给定query&#xff0c;和1000个doc&#xff0c;找到哪个doc是好query最相关的。 二、为什么需要相关性模型&#xff1f; 熟悉es的应该都熟悉BM25相关性算法。它是一个很简单的相关性算法。我…

nginx的四层负载均衡实战

目录 1 环境准备 1.1 mysql 部署 1.2 nginx 部署 1.3 关闭防火墙和selinux 2 nginx配置 2.1 修改nginx主配置文件 2.2 创建stream配置文件 2.3 重启nginx 3 测试四层代理是否轮循成功 3.1 远程链接通过代理服务器访问 3.2 动图演示 4 四层反向代理算法介绍 4.1 轮询&#xff0…

安全防御,防火墙配置NAT转换智能选举综合实验

一、实验拓扑图 二、实验需求 1、办公区设备可以通过电信链路和移动链路上网(多对多的NAT&#xff0c;并且需要保留一个公网IP不能用来转换) 2、分公司设备可以通过总公司的移动链路和电信链路访问到Dmz区的http服务器 3、多出口环境基于带宽比例进行选路&#xff0c;但是&…

python:使用matplotlib库绘制图像(四)

作者是跟着http://t.csdnimg.cn/4fVW0学习的&#xff0c;matplotlib系列文章是http://t.csdnimg.cn/4fVW0的自己学习过程中整理的详细说明版本&#xff0c;对小白更友好哦&#xff01; 四、条形图 1. 一个数据样本的条形图 条形图&#xff1a;常用于比较不同类别的数量或值&…

STM32之六:SysTick系统滴答定时器

目录 1. SysTick简介 2. 时钟来源 3. SysTick寄存器 3.1 CTRL—SysTick控制及状态寄存器 3.2 RELOAD—SysTick重装载数值寄存器 3.3 CURRENT—SysTick当前数值寄存器 4. systick系统定时器配置 5. 延时函数实现 5.1 延时函数编写步骤 5.2 微秒级延时函数delay_us 5.…

代理模式(大话设计模式)C/C++版本

代理模式 C #include <iostream> using namespace std;class Subject // Subject 定义了RealSubject和Proxy的共用接口..这样就在任何使用RealSubject的地方都可以使用Proxy { public:virtual void func(){cout << "Subject" << endl;} };class R…

Leetcode—3011. 判断一个数组是否可以变为有序【中等】(__builtin_popcount()、ranges::is_sorted())

2024每日刷题&#xff08;144&#xff09; Leetcode—3011. 判断一个数组是否可以变为有序 O(n)复杂度实现代码 class Solution { public:bool canSortArray(vector<int>& nums) {// 二进制数位下1数目相同的元素就不进行组内排序// 只进行分组// 当前组的值若小于…

全栈物联网项目:结合 C/C++、Python、Node.js 和 React 开发智能温控系统(附代码示例)

1. 项目概述 本文详细介绍了一个基于STM32微控制器和AWS IoT云平台的智能温控器项目。该项目旨在实现远程温度监控和控制,具有以下主要特点: 使用STM32F103微控制器作为主控芯片,负责数据采集、处理和控制逻辑采用DHT22数字温湿度传感器,精确采集环境温湿度数据通过ESP8266 W…

Android Spinner

1. Spinner Spinner是下拉列表&#xff0c;如图3-14所示&#xff0c;通常用于为用户提供选择输入。Spinner有一个重要的属性&#xff1a;spinnerMode&#xff0c;它有2种情况&#xff1a; 属性值为dropdown时&#xff0c;表示Spinner的数据下拉展示&#xff0c;如图1&#xf…

GenAl如何改变 DevOps 中的软件测试?

TestComplete 是一款自动化UI测试工具&#xff0c;这款工具目前在全球范围内被广泛应用于进行桌面、移动和Web应用的自动化测试。 TestComplete 集成了一种精心设计的自动化引擎&#xff0c;可以自动记录和回放用户的操作&#xff0c;方便用户进行UI&#xff08;用户界面&…

快速使用BRTR公式出具的大模型Prompt提示语

Role:文章模仿大师 Background: 你是一位文章模仿大师&#xff0c;擅长分析文章风格并进行模仿创作。老板常让你学习他人文章后进行模仿创作。 Attention: 请专注在文章模仿任务上&#xff0c;提供高质量的输出。 Profile: Author: 一博Version: 1.0Language: 中文Descri…

SpringCloud第三篇(服务中心与OpenFeign)

p 文章目录 一、服务中心二、Nacos注册中心 一、服务中心 在上一章我们实现了微服务拆分&#xff0c;并且通过Http请求实现了跨微服务的远程调用。不过这种手动发送Http请求的方式存在一些问题。 试想一下&#xff0c;假如商品微服务被调用较多&#xff0c;为了应对更高的并发…

韦东山嵌入式linux系列-具体单板的 LED 驱动程序

笔者使用的是STM32MP157的板子 1 怎么写 LED 驱动程序&#xff1f; 详细步骤如下&#xff1a; ① 看原理图确定引脚&#xff0c;确定引脚输出什么电平才能点亮/熄灭 LED ② 看主芯片手册&#xff0c;确定寄存器操作方法&#xff1a;哪些寄存器&#xff1f;哪些位&#xff1f;…

pytorch-pytorch之LSTM

目录 1. nn.LSTM2. nn.LSTMCell 1. nn.LSTM 初始化函数输入参数与RNN相同&#xff0c;分别是input_size&#xff0c;hidden_size和num_layer foward函数也与RNN类似&#xff0c;只不过返回值除了out外&#xff0c;ht变为(ht,ct) 代码见下图&#xff1a; 2. nn.LSTMCell 初…

基于与STM32的加湿器之旋转编码器驱动

1.简介 旋转编码器&#xff0c;也被称为轴编码器或脉冲编码器&#xff08;SPC&#xff09;&#xff0c;是一种将旋转的机械位移量转换为电气信号的传感器&#xff0c;其信号可用于检测位置、速度等。 2.工作原理 旋转编码器的工作原理主要基于光电转换或磁电转换。以光电式旋转…

电子签章 签到 互动 打卡 创意印章 支持小程序 H5 App

电子签章 签到 互动 打卡 创意印章 支持小程序 H5 App 定制化

华为防火墙nat和智能选路配置

要求&#xff1a; 7&#xff0c;办公区设备可以通过电信链路和移动链路上网(多对多的NAT&#xff0c;并且需要保留一个公网IP不能用来转换) 8&#xff0c;分公司设备可以通过总公司的移动链路和电信链路访问到Dmz区的http服务器 9&#xff0c;多出口环境基于带宽比例进行选路&…

k8s集群新增节点

目前集群状态 如K8S 集群搭建中规划的集群一样 Masternode01node02IP192.168.100.100192.168.100.101192.168.100.102OSCent OS 7.9Cent OS 7.9Cent OS 7.9 目前打算新增节点node03 Masternode01node02node03IP192.168.100.100192.168.100.101192.168.100.102192.168.100.1…