Scaled_dot_product_attention(SDPA)使用详解

在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释。

1. Attention计算公式

Attention的计算主要涉及三个矩阵:Q、K、V。我们先不考虑multi-head attention,只考虑one head的self attention。在大模型的prefill阶段,这三个矩阵的维度均为N x d,N即为上下文的长度;在decode阶段,Q的维度为1 x d, KV还是N x d。然后通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(d QKT)V
在真正使用attention的时候,我们往往采用multi-head attention(MHA)。MHA的计算公式和one head attention基本一致,它改变了Q、K、V每一行的定义:将维度d的向量分成h组变成一个h x dk的矩阵,Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k Nhdk的三维矩阵(不考虑batch维)。分别将Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k hNdk的三维矩阵。此时的三个矩阵就是具有h个头的Q、K、V,我们就可以按照self attention的定义计算h个头的attention值。

不过,在真正进行大模型推理的时候就会发现KV Cache是非常占显存的,所以大家尝试各种手段压缩KV Cache,具体可以参考《大模型推理–KV Cache压缩》。一种手段就是将MHA替换成group query attention(GQA),这块在torch2.5以上的SDPA中也已经得到了支持。

2. SDPA伪代码

在SDPA的注释中,给出了伪代码:

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
	L, S = query.size(-2), key.size(-2)
	scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
	
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
	if is_causal:
    	 assert attn_mask is None
         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         attn_bias.to(query.dtype)

    if attn_mask is not None:
    	 if attn_mask.dtype == torch.bool:
         	 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
         else:
             attn_bias += attn_mask

    if enable_gqa:
    	 key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
         value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    
	return attn_weight @ value

可以看出,我们实际在使用SDPA时除了query、key和value之外,还有另外几个参数:attn_mask、dropout_p、is_causal、scale和enable_gqa。scale就是计算Attention时的缩放因子,一般无需传递。dropout_p表示Dropout概率,在推理阶段也不需要传递,不过官方建议如下输入:dropout_p=(self.p if self.training else 0.0)。我们着重看一下另外三个参数在使用时该如何设置。

先看enable_gqa。前面提到GQA是一种KV Cache压缩方法,MHA的KV和Q一样,也会有h个头,GQA则将KV的h个头进行压缩来减小KV Cache的大小。比如Qwen2-7B-Instruct这个模型,Q的h等于28,KV的h等于4,相当于把KV Cache压缩到之前的七分之一。GQA虽然压缩了KV Cache,但是真正要计算Attention的时候还是需要对齐KV与Q的head数,所以我们可以看到HF Transformer库中的qwen2.py在Attention计算时会有一个repeat_kv的操作,目的就是将QKV的head数统一。在torch2.5以后的版本中,我们无需再手动去执行repeat_kv,直接将SDPA的enable_gqa设置为True即可自动完成repeat_kv,而且速度比自己去做repaet_kv还要更快。

attn_mask和is_causal两个参数的作用相同,目的都是要给softmax之前的QKT矩阵添加mask。只不过attn_mask是自己在外面构造mask矩阵,is_causal则是根据大模型推理的阶段属于prefill还是decode来进行设置。通过看伪代码可以看出,SDPA会首先构造一个L x S的零矩阵attn_bias,L表示Q的上下文长度,S表示KV Cache的长度。在prefill阶段,L和S相等,在decode阶段,L为1,S还是N。所以在prefill阶段,attn_bias就是一个N x N的矩阵,将is_causal设置为True时就会构造一个下三角为0,上三角为负无穷的矩阵作为attn_bias,然后将其加到QKT矩阵上,这样就实现了因果关系的Attention计算。在decode阶段,attn_bias就是一个1 x N的向量,此时可以将is_causal设置为False,attn_bias始终为0就不会对 Q K T QK^T QKT行向量产生影响,表示KV Cache所有的行都参与计算,因果关系保持正确。

attn_mask作用和is_causal一样,但是需要我们自行构造,如果你对如何构造不了解建议就使用is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None。不过,如果prefill按照chunk来执行也即chunk_prefill阶段,我们会发现is_causal设置为True时的attn_bias设置的不正确,我们不是从左上角开始构造下三角矩阵,而是要从右下角开始构造下三角矩阵,这种情况下我们可以从外面自行构造attn_mask矩阵代替SDPA的构造。attn_mask有两种构造方式,一种是bool类型,True的位置会保持不变,False的位置会置为负无穷;一种是float类型,会直接将attn_mask加到SDPA内部的attn_bias上,和bool类型一样,我们一般是构造一个下三角为0上三角为负无穷的矩阵。总结来说,绝大多数情况下我们只需要设置is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None即可。如果推理阶段引入了chunk_prefill,则我们需要自行构造attn_mask,但是要注意构造的attn_mask矩阵是从右下角开始的下三角矩阵。

3. SDPA实现(翻译自SDPA注释)

目前SDPA有三种实现:

  1. 基于FlashAttention-2的实现;
  2. Memory-Efficient Attention(facebook xformers);
  3. Pytorch版本对上述伪代码的c++实现(对应MATH后端)。

针对CUDA后端,SDPA可能会调用经过优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。所有实现方式默认都是启用的,SDPA会尝试根据输入自动选择最优的实现方式。为了对使用哪种实现方式提供更细粒度的控制,torch提供了以下函数来启用和禁用各种实现方式:

  1. torch.nn.attention.sdpa_kernel:一个上下文管理器,用于启用或禁用任何一种实现方式;
  2. torch.backends.cuda.enable_flash_sdp:全局启用或禁用FlashAttention
  3. torch.backends.cuda.enable_mem_efficient_sdp:全局启用或禁用memory efficient attention
  4. torch.backends.cuda.enable_math_sdp:全局启用或禁用PyTorch的C++实现。

每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现方式,请使用torch.nn.attention.sdpa_kernel 禁用PyTorch 的C++实现。如果某个融合实现方式不可用,将会发出警告,说明该融合实现方式无法运行的原因。由于融合浮点运算的特性,此函数的输出可能会因所选择的后端内核而异。C++ 实现支持torch.float64,当需要更高精度时可以使用。对于math后端,如果输入是torch.half或torch.bfloat16类型,那么所有中间计算结果都会保持为torch.float类型。

4. SDPA使用示例

首先强调一点,灌入SDPA的QKV都是做过转置的,也即维度为batch x head x N x d,在老版本的torch中还需要QKV都是contiguous的,新版本下无此要求。SDPA注释中还给了两个示例,我们在此也给出:

# Optionally use the context manager to ensure one of the fused kernels is run
 query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
 key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
 value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
 	F.scaled_dot_product_attention(query,key,value)

上述示例中,给定的输入为batch等于32,head等于8,上下文长度128,embedding维度64,然后通过sdpa_kernel选择使用FlashAttention。
示例二:

# Sample for GQA for llama3
query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with sdpa_kernel(backends=[SDPBackend.MATH]):
	F.scaled_dot_product_attention(query,key,value,enable_gqa=True)

示例二演示了GQA的用法,给定的query head数为32,key和value均为8,此时我们可以通过enable_gqa选项来实现对GQA的支持,此外代码还通过sdpa_kernel选项使用了MATH后端。

5. 参考

  1. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  2. Memory-Efficient Attention
  3. Grouped-Query Attention
  4. Attention Is All You Need

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/984351.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DeepSeek R1-32B医疗大模型的完整微调实战分析(全码版)

DeepSeek R1-32B微调实战指南 ├── 1. 环境准备 │ ├── 1.1 硬件配置 │ │ ├─ 全参数微调:4*A100 80GB │ │ └─ LoRA微调:单卡24GB │ ├── 1.2 软件依赖 │ │ ├─ PyTorch 2.1.2+CUDA │ │ └─ Unsloth/ColossalAI │ └── 1.3 模…

vue3 vite项目安装eslint

npm install eslint -D 安装eslint库 npx eslint --init 初始化配置,按项目实际情况选 自动生成eslint.config.js,可以添加自定义rules 安装ESLint插件 此时打开vue文件就会标红有问题的位置 安装prettier npm install prettier eslint-config-pr…

【五.LangChain技术与应用】【10.LangChain ChatPromptTemplate(下):复杂场景下的应用】

凌晨两点的西二旗,你盯着监控大屏上跳动的错误日志,智能客服系统在流量洪峰中像纸船一样摇晃。用户骂声塞满弹窗:“等了十分钟就这?”“刚才说的怎么不认了?”“我要人工!!”——这时候你需要的不只是ChatPromptTemplate,而是给对话系统装上航天级操控台。 一、模板组…

javascrip网页设计案例,SuperSlide+bootstrap+html经典组合

概述 JavaScript作为一种强大的脚本语言,在网页设计领域发挥着举足轻重的作用,能够为网页赋予丰富的交互性与动态功能。以下通过具体案例来深入理解其应用。​ 假设要打造一个旅游网站,该网站具备诸多实用功能。在响应式设计方面&#xff0…

python量化交易——金融数据管理最佳实践——使用qteasy大批量自动拉取金融数据

文章目录 使用数据获取渠道自动填充数据QTEASY数据拉取功能数据拉取接口refill_data_source()数据拉取API的功能特性多渠道拉取数据实现下载流量控制实现错误重试日志记录其他功能 qteasy是一个功能全面且易用的量化交易策略框架, Github地址在这里。使用它&#x…

Vite 6 升级指南:CJS 和 ESM 的爱恨情仇

Vite 6 升级指南:CJS 和 ESM 的爱恨情仇 前言:老朋友 CJS,新欢 ESM 说到 JavaScript 模块化,CJS 和 ESM 这俩货简直像是两代人的思维碰撞。前者是 Node.js 的老朋友,后者是前端新时代的宠儿。Vite 6 直接把 CJS 踢出…

FreeRTOS任务状态查询

一.任务相关API vTaskList(),创建一个表格描述每个任务的详细信息 char biaoge[1000]; //定义一个缓存 vTaskList(biaoge); //将表格存到这缓存中 printf("%s /r/n",biaoge); 1.uxTaskPriorityGet(&#xf…

【3】VS Code 新建上位机项目---C#窗体与控件开发

【3】VS Code 新建上位机项目---C#窗体与控件开发 1 窗体1.1 新建窗体1.2 windows程序与窗口的关系1.3 windows程序的事件1.3.1 定义事件与处理事件1.3.2 关联事件1.3.3 Windows窗体对象创建与显示(show与ShowDialog())2 控件与容器2.1 常用容器2.1.1 Groupbox2.1.2 Pannel2.1.…

AI编程: 一个案例对比CPU和GPU在深度学习方面的性能差异

背景 字节跳动正式发布中国首个AI原生集成开发环境工具(AI IDE)——AI编程工具Trae国内版。 该工具模型搭载doubao-1.5-pro,支持切换满血版DeepSeek R1&V3, 可以帮助各阶段开发者与AI流畅协作,更快、更高质量地完…

ubuntu 20.04下ZEDmini安装使用

提前安装好显卡驱动和cuda,如果没有安装可以参考我的这两篇文章进行安装: ubuntu20.04配置YOLOV5(非虚拟机)_ubuntu20.04安装yolov5-CSDN博客 ubuntu20.04安装显卡驱动及问题总结_乌班图里怎么备份显卡驱动-CSDN博客 还需要提前…

2025数据存储技术风向标:解析数据湖与数据仓库的实战效能差距

一、技术演进的十字路口 当前全球数据量正以每年65%的复合增长率激增,IDC预测到2027年企业将面临日均处理500TB数据的挑战。在这样的背景下,传统数据仓库与新兴数据湖的博弈进入白热化阶段。Gartner最新报告显示,采用混合架构的企业数据运营效…

Spring(1)——mvc概念,部分常用注解

1、什么是Spring Web MVC? Spring MVC 是一种基于 Java 的实现了 MVC(Model-View-Controller,模型 - 视图 - 控制器)设计模式的 Web 应用框架,它是 Spring 框架的一个重要组成部分,用于构建 Web 应用程序。…

PY32MD320单片机 QFN32封装,内置多功能三相 NN 型预驱。

PY32MD320单片机是普冉半导体的一款电机专用MCU,芯片采用了高性能的 32 位 ARM Cortex-M0 内核,主要用于电机控制。PY32MD320嵌入高达 64 KB Flash 和 8 KB SRAM 存储器,最高工作频率 48 MHz。PY32MD320单片机的工作温度范围为 -40 ~ 105 ℃&…

《OkHttp:工作原理 拦截器链深度解析》

目录 一、OKHttp 的基本使用 1. 添加依赖 2. 发起 HTTP 请求 3. 拦截器(Interceptor) 4. 高级配置 二、OKHttp 核心原理 1. 责任链模式(Interceptor Chain) 2. 连接池(ConnectionPool) 3. 请求调度…

HeidiSQL:一款免费的数据库管理工具

HeidiSQL 是一款免费的图形化数据库管理工具,支持 MySQL、MariaDB、Microsoft SQL、PostgreSQL、SQLite、Interbase 以及 Firebird,目前只能在 Windows 平台使用。 HeidiSQL 的核心功能包括: 免费且开源,所有功能都可以直接使用。…

C/C++蓝桥杯算法真题打卡(Day3)

一、P8598 [蓝桥杯 2013 省 AB] 错误票据 - 洛谷 算法代码&#xff1a; #include<bits/stdc.h> using namespace std;int main() {int N;cin >> N; // 读取数据行数unordered_map<int, int> idCount; // 用于统计每个ID出现的次数vector<int> ids; …

【2025软考高级架构师】——软件工程(2)

摘要 本文主要介绍了软件工程中常见的多种软件过程模型&#xff0c;包括瀑布模型、原型模型、V模型、W模型、迭代与增量模型、螺旋模型、构件组装模型、基于构件的软件工程&#xff08;CBSE&#xff09;、快速应用开发&#xff08;RAD&#xff09;、统一过程/统一开发方法和敏…

【Vue3 Element UI - Plus + Tyscript 实现Tags标签输入及回显】

Vue3 Element Plus TypeScript 实现 Tags 标签输入及回显 在开发后台管理系统或表单页面时&#xff0c;动态标签&#xff08;Tags&#xff09; 是一个常见的功能需求。用户可以通过输入框添加标签&#xff0c;并通过关闭按钮删除标签&#xff0c;同时还需要支持标签数据的提…

Easysearch 使用 AWS S3 进行快照备份与还原:完整指南及常见错误排查

Easysearch 可以使用 AWS S3 作为远程存储库&#xff0c;进行索引的快照&#xff08;Snapshot&#xff09;备份和恢复。同时&#xff0c;Easysearch 内置了 S3 插件&#xff0c;无需额外安装。以下是完整的配置和操作步骤。 1. 在 AWS S3 上创建存储桶 登录 AWS 控制台&#x…

【CSS3】筑基篇

目录 复合选择器后代选择器子选择器并集选择器交集选择器伪类选择器 CSS 三大特性继承性层叠性优先级 背景属性背景色背景图背景图平铺方式背景图位置背景图缩放背景图固定背景复合属性 显示模式显示模式块级元素行内元素行内块元素 转换显示模式 结构伪类选择器结构伪类选择器…