【PyTorch][chapter-33][transformer-5] MHA MQA GQA, KV-Cache

   主要翻译外网: 解剖Deep Seek 系列,详细见参考部分。


目录:

  1.   Multi-Head Attention (MHA)
  2.   KV-Cache 
  3.   KV-Cache 公式
  4.    Multi-Query Attention(MQA)
  5.    Grouped-Query Attention(GQA)
  6.    Multi-Head Latent Attention
  7.   PyTorch Implementing MHA, MQA, and GQA 


一   Multi-Head Attention (MHA) 

      

    输入:     Q,K,V 通常为x ,其形状为:

                    [batch_size, seq_len, d_model]

     第一步:    进行线性空间变换:

                    q= W^Qh ...(1)

                    k= W^Kh ...(2)

                   v= W^Vh ...(3)

      

     第二步: 子空间投影(projection)

          其中n_h为子空间头数量一般设置为8

     第三步:  做 self-attention

    


二         KV-Cache 

           在Transformer的Decoder推理过程中,由于自注意力机制需要遍历整个先前输入的序列来计算每个新token的注意力权重,这导致了显著的计算负担。随着序列长度的延伸,计算复杂度急剧上升,不仅增加了延迟,还限制了模型处理长序列的能力。因此,优化Decoder的自注意力机制,减少不必要的计算开销,成为提升Transformer模型推理效率的关键所在。

      KV 缓存发生在多个 token 生成步骤中,并且仅发生在解码器中(即,在 GPT 等仅解码器模型中,或在 T5 等编码器-解码器模型的解码器部分中)。BERT 等模型不是生成式的,因此没有 KV 缓存。

    这种自回归行为重复了一些操作,我们可以通过放大在解码器中计算的掩蔽缩放点积注意力计算来更好地理解这一点。
            

     由于解码器是auto-regressive  的,因此在每个生成步骤中,我们都在重新计算相同的先前标记的注意力,而实际上我们只是想计算新标记的注意力。

    这就是 KV 发挥作用的地方。通过缓存以前的 Keys 和 Values,我们可以专注于计算新 token 的注意力。

       为什么这种优化很重要?如上图所示,使用 KV 缓存获得的矩阵要小得多,从而可以加快矩阵乘法的速度。唯一的缺点是它需要更多的 GPU VRAM(如果不使用 GPU,则需要 CPU RAM)来缓存 Key 和 Value 状态。


三   KV-Cache  公式

       基本和MHA 过程差不多,区别是每次输入的是h_t: 第t 时刻的token

           3.1  对当前时刻的输入进行线性变换

                   q_t= W^{Q}h_t \: \: \: \: \: \: \: (1)

                   k_t= W^{K}h_t \: \: \: \: \: \: \: (2)

                   v_t= W^{V}h_t \: \: \: \: \: \: \: (2)

                 h_t \in R^d

                 d: embedding 的维度

          3.2   进行子空间投影

                        

                       q_t =\begin{bmatrix} q_{t,1},q_{t,2},...q_{t,n_h} \end{bmatrix} \, \, \, \, \, \, \, \, (4)

                       k_t =\begin{bmatrix} k_{t,1},k_{t,2},...k_{t,n_h} \end{bmatrix} \, \, \, \, \, \, \, \, (5)

                      v_t =\begin{bmatrix} v_{t,1},v_{t,2},...v_{t,n_h} \end{bmatrix} \, \,\, \, \, \, \, \, \, (6)

                     其中:

                    n_h: attention head 数量

                    d_h: attention head 的维度

                    q_t,k_t,v_t \in R^{d_h*n_h}

                    q_{t,i}:  第i个头,t时刻的查询向量

   3.3  做self-attention

         

   我们把存储的K,V缓存叫做K-V Cache. 对于一个L层的模型,每个t个token 一共需要

2*n_hd_hLt 缓存。

     d_h 是一个head的size ,MLA 就是研究这个size 如何降维降低KV-Cache


四  Multi-Query Attention(MQA)

   为了缓解多头注意力(MHA)中的键值缓存瓶颈问题,Shazeer在2019年提出了多查询注意力(MQA)机制。在该机制中,所有的不同注意力头共享相同的键和值,即除了不同的注意力头共享同一组键和值之外,其余部分与MHA相同。这大大减轻了键值缓存的负担,从而显著加快了解码器的推理速度。然而,MQA会导致质量下降和训练不稳定。

    

    


五  Grouped Query Attention — (GQA)

     分组查询注意力(GQA)通过在多头注意力(MHA)和多查询注意力(MQA)之间引入一定数量的查询头子组(少于总注意力头的数量),每个子组有一个单独的键头和值头,从而实现了一种插值。与MQA相比,随着模型规模的增加,GQA在内存带宽和容量上保持了相同比例的减少。中间数量的子组导致了一个插值模型,该模型的质量高于MQA但推理速度快于MHA。很明显,只有一个组的GQA等同于MQA。

        


六  PyTorch Implementing MHA, MQA, and GQA 

num_kv_heads 和 num_heads  一样的时候就是 MHA

num_kv_heads= 1                 就是MQA

num_kv_heads<num_heads  就是GQA

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 21 15:02:18 2025

@author: chengxf2
"""

import torch.nn as nn
import torch.nn.functional as F
import torch


def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    # 获取 key 的维度大小,用于缩放
    d_k = query.size(-1)
    # 计算点积注意力得分
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # 如果提供了 mask,将其应用到得分上
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # 对得分进行 softmax 操作,得到注意力权重
    p_attention = F.softmax(scores, dim=-1)
    # 如果提供了 dropout,应用 dropout
    if dropout is not None:
        p_attention = dropout(p_attention)
    # 使用注意力权重对 value 进行加权求和
    return torch.matmul(p_attention, value)


class  Attention(nn.Module):
    def __init__(self, d_model=512,num_heads=8, num_kv_heads=2,dropout=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim =  d_model//num_heads
        self.num_kv_heads = num_kv_heads
        assert self.num_heads%self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads//self.num_kv_heads
        #Linear
        self.query =   nn.Linear(d_model, self.head_dim * self.num_heads)
        self.key =     nn.Linear(d_model,   self.head_dim * self.num_kv_heads)
        self.value =   nn.Linear(d_model, self.head_dim * self.num_kv_heads)
        #输出
        self.proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)
    
    def forward(self, inputs):
        
        batch, seq_len, d_model = inputs.shape
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)
        # shape = (B, seq_len, num_heads, head_dim)
        q = q.view(batch, seq_len, -1,  self.head_dim)
        k = k.view(batch, seq_len, -1 , self.head_dim)  
        v = v.view(batch, seq_len, -1,  self.head_dim)
     
        print("default q.shape",q.shape)
        print("default k.shape",k.shape)
        print("default v.shape",v.shape)
        # Grouped Query Attention
        #[batch, seq_len, num_kv_heads, head_dim]->[batch, seq_len, num_heads, head_dim]
        if self.num_kv_heads != self.num_heads:
           k = torch.repeat_interleave(k, self.num_queries_per_kv, dim=2)
           v = torch.repeat_interleave(v, self.num_queries_per_kv, dim=2)
        # shape = (B, num_heads, seq_len, head_dim) 
        k = k.transpose(1, 2)  
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        print("q.shape",q.shape)
        print("k.shape",k.shape)
        print("v.shape",v.shape)

        output = scaled_dot_product_attention(
            q,
            k,
            v,  # order impotent
            None,
            self.attn_dropout,
        )
        print("v.shape",v.shape)
        print("output.shape",output.shape)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        # final projection into the residual stream
        output = self.proj(output)
        return output
net = Attention()
batch_size =2
seq_len = 5
d_model =512
x = torch.randn(batch_size,seq_len, d_model)
net(x)

        
        
        


七  Multi-Head Latent Attention — (MLA)

Multi-Head Latent Attention (MLA) achieves superior performance than MHA, as well as significantly reduces KV-cache boosting inference efficiency. Instead of reducing KV-heads as in MQA and GQA, MLA jointly compresses the Key and Value into a latent vector.

Low-Rank Key-Value Joint Compression

Instead of caching both the Key and Value matrices, MLA jointly compresses them in a low-rank vector which allows caching fewer items since the compression dimension is much less compared to the output projection matrix dimension in MHA.

Comparison of Deepseek’s new Multi-latent head attention with MHA, MQA, and GQA.

参考:

https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf

缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA - 科学空间|Scientific Spaces

DeepSeek-V3 Explained 1: Multi-head Latent Attention | Towards Data Science

https://medium.com/@zaiinn440/mha-vs-mqa-vs-gqa-vs-mla-c6cf8285bbec

https://medium.com/@zaiinn440/mha-vs-mqa-vs-gqa-vs-mla-c6cf8285bbec

deepseek技术解读(1)-彻底理解MLA(Multi-Head Latent Attention)-CSDN博客

怎么加快大模型推理?10分钟学懂VLLM内部原理,KV Cache,PageAttention_哔哩哔哩_bilibili

https://medium.com/@joaolages/kv-caching-explained-276520203249

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/980028.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot 流式响应豆包大模型对话能力

当Spring Boot遇见豆包大模型&#xff1a;一场流式响应的"魔法吟唱"仪式 一、前言&#xff1a;关于流式响应的奇妙比喻 想象一下你正在火锅店点单&#xff0c;如果服务员必须等所有菜品都备齐才一次性端上来&#xff0c;你可能会饿得把菜单都啃了。而流式响应就像贴…

记录Liunx安装Jenkins时的Package ‘jenkins‘ has no installation candidate

1、确保是否安装了Java&#xff0c;如果没有&#xff0c;可通过以下命令进行安装&#xff1a; sudo apt update sudo apt install openjdk-21-jre2、安装Jenkins sudo apt update sudo apt install jenkins执行sudo apt install jenkins时&#xff0c;可能会出现 意思是&…

Windows用户如何零成本迁移Sketch项目?2025实测方案推荐

在设计领域&#xff0c;Sketch一直是UI/UX设计师的不二之选。它凭借简洁的界面、强大的矢量绘图功能深受设计师们的喜爱。尽管有着广泛的应用和众多优势&#xff0c;但Sketch仅支持MacOS系统&#xff0c;这对于Windows用户来说是一个巨大的限制。 然而&#xff0c;随着设计需求…

通过百度构建一个智能体

通过百度构建一个智能体 直接可用,我不吝啬算力 首先部署一个模型,我们选用deepseek14 构建智能体思考步骤,甚至多智能体; from openai import OpenAIclass Agent:def __init__(self, api_key, base_url, model

【K8S】Kubernetes 基本架构、节点类型及运行流程详解(附架构图及流程图)

Kubernetes 架构 k8s 集群 多个 master node 多个 work nodeMaster 节点&#xff08;主节点&#xff09;&#xff1a;负责集群的管理任务&#xff0c;包括调度容器、维护集群状态、监控集群、管理服务发现等。Worker 节点&#xff08;工作节点&#xff09;&#xff1a;实际运…

FFmpeg-chapter2-C++中的线程

1 常规的线程 一般常规的线程如下所示 // CMakeProject1.cpp: 定义应用程序的入口点。 //#include "CMakeProject1.h" #include <thread> using namespace std;void threadFunction(int index) {for (int i 0; i < 1000; i){std::cout << "Th…

GitCode 助力 JeeSite:开启企业级快速开发新篇章

项目仓库&#xff08;点击阅读原文链接可直达前端仓库&#xff09; https://gitcode.com/thinkgem/jeesite 企业级快速开发的得力助手&#xff1a;JeeSite 快速开发平台 JeeSite 不仅仅是一个普通的后台开发框架&#xff0c;而是一套全面的企业级快速开发解决方案。后端基于 …

EasyRTC:支持任意平台设备的嵌入式WebRTC实时音视频通信SDK解决方案

随着互联网技术的飞速发展&#xff0c;实时音视频通信已成为各行各业数字化转型的核心需求之一。无论是远程办公、在线教育、智慧医疗&#xff0c;还是智能安防、直播互动&#xff0c;用户对低延迟、高可靠、跨平台的音视频通信需求日益增长。 一、WebRTC与WebP2P&#xff1a;实…

【Qt】MVC设计模式

目录 一、搭建MVC框架 二、创建数据库连接单例类SingleDB 三、数据库业务操作类model设计 四、control层&#xff0c;关于model管理类设计 五、view层即为窗口UI类 一、搭建MVC框架 里面的bin、lib、database文件夹以及sqlite3.h与工程后缀为.pro文件的配置与上次发的文章…

使用C#控制台调用本地部署的DeepSeek

1、背景 春节期间大火的deepseek&#xff0c;在医疗圈也是火的不要不要的。北京这边的医院也都在搞“deepseek竞赛”。友谊、北医三院等都已经上了&#xff0c;真是迅速啊&#xff01; C#也是可以进行对接&#xff0c;并且非常简单。 2、具体实现 1、使用Ollama部署DeepSeek…

接口测试工具:postman详解

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 Postman 是一款功能强大的 API 开发和测试工具&#xff0c;以下是一些高级用法的详细介绍和操作步骤。 一、环境和全局变量 环境变量允许你设置特定于环境&#…

ERP系统的库存模块业务逻辑及设计

传统上通常将“库存管理”理解为对物料的进、出、存的业务管理&#xff0c;但这种理解在ERP系统中是不全面的。 APICS词汇中对库存的定义是“以支持生产、维护、操作和客户服务为目的而存储的各种物料&#xff0c;包括原材料和在制品、维修件和生产消耗、成品和备件等”​。库…

软件安全性测试类型分享,第三方软件测试机构如何进行安全性测试?

在数字化时代&#xff0c;软件的安全性至关重要&#xff0c;因此软件产品安全性测试必不可少。软件安全性测试是指针对软件系统的漏洞、弱点及其他安全隐患进行评估和检测的过程。它旨在发现潜在的安全问题&#xff0c;以保护软件和用户的利益。通过系统化的测试&#xff0c;企…

自由学习记录(40)

virtual的重写能力&#xff0c;&#xff0c;这在剥离Player方法和成员变量的时候&#xff0c;起的作用很灵活&#xff0c;敌人默认可以继承这些规则&#xff0c;但只是默认&#xff0c;自己要修改的话和原来不会产生半点联系&#xff0c;这个确实厉害 Cinemachine Virtual Came…

神经网络|(十一)|神经元和神经网络

【1】引言 前序已经了解了基本的神经元知识&#xff0c;相关文章链接为&#xff1a; 神经网络|(一)加权平均法&#xff0c;感知机和神经元-CSDN博客 神经网络|(二)sigmoid神经元函数_sigmoid函数绘制-CSDN博客 神经网络|(三)线性回归基础知识-CSDN博客 把不同的神经元通过…

【Python】基础语法三

> 作者&#xff1a;დ旧言~ > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;了解Python的函数、列表和数组。 > 毒鸡汤&#xff1a;有些事情&#xff0c;总是不明白&#xff0c;所以我不会坚持。早安! > 专栏选自&#xff…

PHP使用Redis实战实录2:Redis扩展方法和PHP连接Redis的多种方案

PHP使用Redis实战实录系列 PHP使用Redis实战实录1&#xff1a;宝塔环境搭建、6379端口配置、Redis服务启动失败解决方案PHP使用Redis实战实录2&#xff1a;Redis扩展方法和PHP连接Redis的多种方案 Redis扩展方法和PHP连接Redis的多种方案 一、Redis扩展方法二、php操作Redis…

kubernetes 初学命令

基础命令 kubectl 运维命令常用&#xff1a; #查看pod创建过程以及相关日志 kubectl describe pod pod-command -n dev #查看某个pod&#xff0c;以yaml格式展示结果 kubectl get pod nginx -o yaml #查看pod 详情 以及对应的集群IP地址 kubectl get pods -o wide 1. kubetc…

[C++_] set | map | unordered_map

前文回顾&#xff1a; 【C】详解 set | multiset 【C】关联容器探秘&#xff1a;Map与Multimap详解 在 C 中&#xff0c;map 和 unordered_map 都是存储键值对的关联容器&#xff0c;但它们的实现和特性有显著区别。如下&#xff1a; 1. 底层实现与有序性 map 基于红黑树&a…

【计算机网络】TCP三次握手,四次挥手以及SYN,ACK,seq,以及握手次数理解

TCP三次握手图解 描述 第一次握手&#xff1a;客户端请求建立连接&#xff0c;发送同步报文(SYN1)&#xff0c;同时随机一个seqx作为初始序列号&#xff0c;进入SYN_SENT状态&#xff0c;等待服务器确认 第二次握手&#xff1a;服务端收到请求报文&#xff0c;如果同意建立连接…