FlashAttention3:“GEMM”就是比较快!

 

阅读文章之前请温习以下四篇文章,避免云里雾里:

轻松读懂FlashAttention上<矩阵分块加载,改写softmax算法>

轻松读懂FlashAttention下

轻松读懂FlashAttention-2<优化循环体,减少非矩阵运算>

GPU的基础认知<GEMM是最佳的选择!最好都改写为矩阵运算!>

加速Hopper GPU

注意力机制是Transformer架构的核心能力,也是大型语言模型和长上下文应用的瓶颈。FlashAttention(和 FlashAttention-2)开创了一种通过最小化内存读/写来加速 GPU 注意力的方法,现在大多数库都使用它来加速 Transformer 训练和推理。

这导致了过去两年上下文LLM长度的大幅增加,从2-4K(GPT-3,OPT)增加到128K(GPT-4),甚至1M(Llama 3)。然而,尽管取得了成功,但 FlashAttention 尚未利用现代硬件中的新功能,FlashAttention-2在H100 GPU上仅实现了35%的理论最大FLOP 利用率。

在这篇博文中介绍了三种主要技术来加快对Hopper GPU的关注:利用 Tensor Core和TMA的异步性。

1) 通过扭曲专业化重叠整体计算和数据移动

2)交错块matmul和softmax操作

3)利用硬件支持实现 FP8 低精度的非连贯处理

快速了解

FlashAttention-3比使用FP16的FlashAttention-2快1.5-2.0倍,高达740 TFLOPS,即H100理论最大FLOPS利用率为 75%。使用FP8时,FlashAttention-3达到接近 1.2 PFLOPS,误差比基线FP8注意小2.6倍。

  • 更高效的GPU利用率:新技术可利用高达75%的H100 GPU最大功能,而之前仅为35%。这导致在训练和运行大型语言模型方面,比以前的版本快得多(1.5-2 倍LLMs)。

  • 以较低的精度获得更好的性能:FlashAttention-3可以处理精度较低的数字FP8,同时保持精度。这样可以实现更快的处理速度,并可能降低内存使用率,从而为运行大规模AI操作的客户节省成本并提高效率。

  • 能够在以下位置LLMs使用更长的上下文:通过加速注意力机制,FlashAttention-3使AI模型能够更有效地处理更长的文本片段。这可以使应用程序能够在不减慢速度的情况下理解和生成更长、更复杂的内容。

GEMM和SOFTMAX

注意力有两个主要操作GEMMs(GEMMs是指广义矩阵乘法General Matrix Multiply),例如注意力机制中Q和K之间以及注意力矩阵P和V之间的矩阵乘法。

GPU上面现代加速器上,非matmul操作比matmul操作慢得多。例如softmax中的指数运算等特殊函数的吞吐量远远低于浮点乘加。这些特殊运算(函数)SF一般是由多功能(计算)单元负责,多功能(计算)单元是独立于浮点乘-加(例如y=wx+b)或矩阵乘加之外。

例如,H100 GPU SXM5具有989TFLOPS的FP16矩阵乘法,但对于特殊的函数SF,只有 3.9TFLOPS的吞吐,吞吐量低 256 倍。

CUDA 编程指南规定,特殊函数的吞吐量为每个时钟周期每个流式多处理器 (SM) 16次操作。将16乘以132SM和1830 Mhz(用于计算 FP16 matmul 的989TFLOPS 的时钟速度)得到 3.9TFLOPS!

假如注意力机制的head维度为128,matmul FLOPS比指数运算多512倍,这意味着与matmul运算相比,花费在指数运算的时间需要比矩阵运算多50%的时间。Matmul在FP8的精度下速度比FP16还要快多两倍,这样一来就被指数运算严重的拖后腿!能有魔法棒实现两者并行么?

上面文绉绉的话翻译成白话就是:GEMM比Softmax快,如何让两者并驾齐驱?

Warp是SM中的基本概念,可以先回去温习下GPU的组成。Warp其实已经做了一些调度的事宜,某些Warp被阻塞,其他翘曲可以运行。

例如存在 2个warpgroup(标记为 1 和 2),每个warpgroup是4个warp 的组),这时候通过使用同步屏障 (bar.sync),以便warpgroup 1首先执行它的GEMM。例如,一次迭代的GEMM1和下一次迭代的 GEMM0。然后warpgroup 2执行它的GEMM,而warpgroup 1执行它的softmax, 等等。这个类似乒乓球的调度方式,确保了两者并驾齐驱。上图相同颜色的为相同的迭代。

这种方式在实践中,调度并不是真的这么妥帖,但是这样的调度可以将 FP16 注意力前向传递从大约 570 TFLOPS提高到620 TFLOPS(头部head 128维,序列长度8K)。

即使在一个Warpgroup中,可以在这个群组运行GEMM的时候运行softmax的某些部分。如下图所示:

<非工科读者跳过!>具体的原理在于在注意力算法中,内部循环(主循环)内的操作具有顺序依赖性,这些依赖性会阻碍单次迭代中的并行化。例如,(本地)softmax 18-19行依赖于第一个 GEMM 的输出,而第二个 GEMM 将其结果作为操作数。实际上,算法 1 的第 17- 21行中的等待语句序列化了softmax 和GEMM的执行。但是可以通过寄存器中的额外缓冲区在迭代之间流水线来打破这些依赖关系。遵循这一思路,FL3提出了以下两阶段GEMM-softmax流水线算法:

<继续>这种流水线将吞吐量从大约620 TFLOPS提高到大约640-660 TFLOPS,用于FP16注意力向前转移,但代价是更高的寄存器压力,因为需要更多的寄存器来容纳GEMM的累加器和softmax的输入/输出。

扩展上述 2 阶段算法,FL3继续提出了一个3阶段变体,该变体将进一步重叠第二个WGMMA与softmax。虽然这种方法提供了更高的 Tensor Core 利用率的潜力,但它需要更多的寄存器。

FP8的量化支持

FP8和FP32在寄存器中的存储布局的不一致给FL3的算法带来了挑战。

对于 FP8 FlashAttention-3, 𝐕在将分片加载到SMEM后进行内核内转置。对于内核内转置,我们利用了LDSM ( ldmatrix ) 和STSM ( stmatrix )指令,它们涉及一系列线程共同加载 SMEM到RMEM,并以 128 字节的粒度存储 RMEM 到 SMEM。

LDSM/STSM指令都是高效的,允许在warpgroup中执行,并且能够在执行内存复制时转置布局。在第一次迭代之后,可以在前一个𝐕切片和当前 𝐊切片的WGMMA运算中,加入下一个𝐕切片的转置。

使用 FP8 (e4m3) 格式,仅使用3位来存储尾数,使用4位来存储指数。这导致比FP16/BF16更高的数值误差。此外,大型模型通常具有异常值,它的量级比大多数其他值大得多,这使得量化变得困难。为了减少 FP8中注意力机制的误差,FL3采用了两种技术:

  • 块量化:为每个块保留一个标量,以便对于每个Q,K,V 将其张量拆分为大小𝐵𝑟×𝑑 𝐵𝑐×𝑑 块,然后独立量化。这种量化可以与注意力之前的操作融合,而不会额外减慢速度。由于FlashAttention-3算法都是基于快进行计算,因此可以缩放每个S块进行量化,而无需计算成本。

  • 利用QuIP的非相干处理,将Q和K与随机正交矩阵相乘,以“分散”异常值并减少量化误差。<不明白可以跳过,后面专栏介绍这种算法>。

在实验中,Q、K、V是由标准正态分布生成的,但0.1%的条目具有较大的量级(模拟异常值),我们发现非相干处理可以将量化误差减少 2.6倍。下表为数值误差比较。

性能对比

下面展示了FlashAttention-3的一些结果,并将其与FlashAttention-2以及 Triton和cuDNN中的实现进行了比较(两者都已经使用了Hopper GPU 的新硬件功能)。对于FP16,FlashAttention-2的加速约为1.6倍至 2.0倍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/801245.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

绘制science研究区域概况图

Python绘制研究区域概况图&#xff1a; import numpy as np import matplotlib.pyplot as plt from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter import cartopy.crs as ccrs import cartopy.feature as cfeature from cartopy.io.shapereader import R…

【螺旋矩阵】python刷题记录

目录 题目&#xff1a; 思路&#xff1a; 上代码&#xff1a; 结果&#xff1a; 大佬题解&#xff1a; 思路&#xff1a; 代码&#xff1a; 结果&#xff1a; 题目&#xff1a; 这题 模拟 来做 思路&#xff1a; 方向转换&#xff1a;使用矩阵&#xff08;4个元素&…

el-table的selection多选表格改为单选

需求场景: 选择表格数据时&#xff0c;需要控制单条数据的操作按钮是否禁用。 效果图: html代码: <div><el-tableref"multipleTable":data"tableData"tooltip-effect"dark"style"width: 100%"selection-change"handl…

【vue+el-table】实现表尾合计行分两行显示,一行显示勾选项之和,一行显示合计,已实现,具体思路解析

效果图&#xff1a; 思路解析&#xff1a; 首先进行了el-table列表的组件封装&#xff0c;很多参数是传进来的。如果是普通的列表&#xff0c;相关参数直接定义就行 1、使用el-table的summary-method处理表尾行 &#xff08;1&#xff09;定义summaryIndex用于指定合计在哪一列…

【Springboot】新增profile环境配置应用启动失败

RT 最近接手了一个新的项目&#xff0c;为了不污染别人的环境&#xff0c;我新增了一个自己的环境配置。结果&#xff0c;在启动的时候总是失败&#xff0c;就算是反复mvn clean install也是无效。 问题现象 卡住无法进行下一步 解决思路 由于之前都是能启动的&#xff0c…

防火墙小试——部分(书接上回)NAT

toop接上回 1.实验拓扑及要求 前情回顾 DMZ区内的服务器&#xff0c;办公区仅能在办公时间内&#xff08;9&#xff1a;00 - 18&#xff1a;00&#xff09;可以访问&#xff0c;生产区的设备全天可以访问. 生产区不允许访问互联网&#xff0c;办公区和游客区允许访问互联网 …

【Linux】常用命令总结(updating)

1.date2.du&#xff08;disk use&#xff09;3.df&#xff08;disk free&#xff09;4.find5.crontab6.netstat shell命令可以使用man查看命令文档说明&#xff0c;说明界面中可通过b(backward)向上翻页&#xff0c;f(forward)向下翻页&#xff0c;g(go to)跳到说明首页&#x…

五层模型讲解

TCP/IP 模型协议分层: 应用层: HTTP: 超文本传输协议(网站访问web)(Apache、nginx)(IIS)e FTP : 文件传输协议(网络文件传输)&#xff0c; TFTP: 简单文件传输协议(交换机和路由器系统重装)&#xff08;和console线有关&#xff09; SMTP:简单邮件传输协议(发信) POP3:邮…

Azure Repos 仓库管理

从远端仓库克隆到本地 前提:本地要安装git,并且登录了账户 1.在要放这个远程仓库的路径下,打git 然后 git clone https://.. 如果要登录验证,那就验证下 克隆完后,cd 到克隆的路径, 可以用 git branch -a //查看分支名 git status //查看代码状态 删除…

【Linux网络】poll{初识poll / poll接口 / poll vs select / poll开发多客户端echo服务器}

文章目录 1.初识pollpoll与select的主要联系与区别poll的原理poll的优点poll的缺点poll vs select 2.poll开发多客户端echo服务器封装套接字接口Makefile主函数日志服务聊天服务器 1.初识poll poll是Linux系统中的一个系统调用&#xff0c;它用于监控多个文件描述符&#xff08…

qtcrerator,代码屏蔽了,断点还是能进去,新增的代码没反应。编译无报错,无异常

qtcrerator&#xff0c;代码屏蔽了&#xff0c;断点还是能进去&#xff0c;新增的代码没反应。编译无报错&#xff0c;无异常 办法都试过了&#xff0c;没有用。无法生产新的exe 原因&#xff1a;我把工程复制了一份&#xff0c;然后改了工程名 结果分析&#xff1a;编译运行…

Perl之正则表达式

正则表达式(regular expression)描述了一种字符串匹配的模式&#xff0c;可以用来检查一个串是否含有某种子串、将匹配的子串做替换或者从某个串中取出符合某个条件的子串等。   Perl语言的正则表达式功能非常强大&#xff0c;基本上是常用语言中最强大的&#xff0c;很多语言…

vscode常用组件

1.vue-helper 启用后点击右下角注册&#xff0c;可以通过vue组件点击到源码里面 2.【Auto Close Tag】和【Auto Rename Tag】 3.setting---Auto Reveal Exclude vscode跳转node_modules下文件&#xff0c;没有切换定位到左侧菜单目录> 打开VSCode的setting配置&#xff…

Umi.js 项目中使用 Web Worker

1.配置 Umi.js 在 Umi.js 中&#xff0c;需要通过配置来扩展 Webpack 的功能。在项目根目录下修改 config/config.ts 文件&#xff1a; export default defineConfig({chainWebpack(config) {config.module.rule(worker).test(/\.worker\.ts$/).use(worker-loader).loader(wo…

vue2使用MarkDown的回显与生成自定义目录

最终实现效果图&#xff1a; 1.回显markdown 1.1安装mark npm install marked -s 1.2使用 //导入 import {marked} from marked // data(){return{textDatas: "",} },methods: {getData() {//获取数据axios({url: "http://localhost:8889/articles/view/158…

操作系统科普与入门之进程篇

文章目录 ⭐前言一、浅谈OS的各个管理模块对应的硬件资源二、从OS的各个管理模块浅谈进程管理2.1 什么是进程&#xff1f;2.2 我知道进程是啥啦&#xff0c;那么OS怎么知道进程的呢&#xff1f; 三、OS是如何进行进程管理&#xff1f;3.1 进程状态转换3.1.1 创建态3.1.2 运行态…

深度学习5 神经网络

生物神经网络是指人的大脑&#xff0c;这是人工神经网络的技术原型。根据生物神经网络的原理&#xff0c;人们用计算机复现了简化的神经网络。当然&#xff0c;人工神经网络是机器学习的一大分支。 1.基本组成 1.1神 经 元 神经元是神经网络的基本组成。激活函数又称作激励函…

所有权与生命周期:Rust 内存管理的哲学

所有权与生命周期&#xff1a;Rust内存管理的哲学 博主寄语引言&#xff1a;编程语言的内存管理困境与 Rust 的解决方案。所有权基本概念&#xff1a;资源的绝对主权生命周期的理解与应用&#xff1a;编译时的守护神借用与引用的精妙设计&#xff1a;安全与效率的和谐共舞Rust …

VL02N 创建过账时删除订单号显示

VL02N 删除订单号显示 VL02N 交货过账 删除 交货单 & 物料凭证 & 会计凭证 上的订单号值 目录 VL02N 删除订单号显示 目录 交货单订单号值删除物料凭证订单号值删除会计凭证订单号删除 删除BSEG表的订单号值删除ACDOCA表的订单号值 交货单订单号值删除 增强点L…