LLM Algorithms(1): Flash Attention

目录

Background

Flash Attention

Flash Attention Algorithm

参考


NIPS-2022:  Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • idea:减少资源消耗,提升或保持模型性能。
  • 普通attention的空间复杂度是O(N^2) --》降低到Flash Attention O(N)
  • Exact 结果相等。这不是attention的近似计算,Flash Attention的计算结果和原始方法一致。
  • IO aware. 和传统attention相比,Flash Attention会考虑硬件特性,而不是把它当作黑盒。 

Background

Nvidia GPU (GPU性能指标 = FLOPS / GB/s,FLOPS, GPU计算能力--每秒计算速度;GB/s,GPU内存吞吐量

  1. 2016-P100
  2. 2018-V100
  3. 2020-A100
  4. 2022-H100

多年来,GPU的计算能力(FLOPS)的增长速度比增加内存吞吐量(TB/s)更快。 

这两者需要紧密配合去达到数据处理的最优比,但自从硬件失去了这种平衡,我们必须通过软件来进行补偿。因此需要算法能够感知IO (IO-aware)。根据计算和内存访问比例,一个操作可以分为:

  1. 计算受限型 (e.g. 矩阵乘法)
  2. 内存受限型
    1. Element-wise 逐元素操作: activation, dropout, masking.
    2. Reduction 操作: softmax, layer norm, sum. 

element-wise操作是指在计算时只依赖当前值,比如每个元素都乘以2。而reduction依赖所有值(比如整个矩阵或矩阵的行),比如softmax。 

attention的计算时内存受限的,因为它的大部分计算都是element-wise的。 

尽管masking、softmax和dropout操作占用了大部分时间,但大部分FLOPS都用在矩阵乘法中,虽然他们花的时间不多。即数据太庞大,attention计算内存不足,或者说内存利用效率太低!

可以通过内存调整去加速masking、softmax和dropout这些操作呢,但是具体咋办? 

人们都知道把大矩阵切分成小块,但如何保证切分小块的计算结果=原attention计算结果?  

扩展:在计算机体系结构里,内存不是单一的构建,内存存储都是分层的。一般规则是:Memory IO speed 内存速度越快,成本越高,容量越小。

  1. GPU SRAM,19TB/s (20 MB),Static RAM, 静态随机存储器
  2. GPU HBM,1.5TB/s (40 GB),high Boardwidth memory, 高带宽内存 
  3. GPU DRAM,12.8GB/s (>1 TB),main memory

实际上,要充分利用内存、实现IO-aware,关键在于充分利用静态随机存取存储器 (SPAM)比高带宽内存 (HBM)快得多的事实,确保减少两者之间的通信。

(HBM,这是导致CUDA内存溢出的因素之一) 

Flash Attention

Flash Attention 采样分而治之的思想,将大矩阵切块加载到SRAM中,计算每个分块的m和l值。利用上一轮m和l值结合新的子块迭代计算,最终计算出整个矩阵的树枝。Flash Attention基本上可以归结为两个主要思想:

  •  Tiling (在前向和后向传递中使用) - 简单讲就是将NxN的softmax分数矩阵划分为块。
  • 重新计算(因为每个块的系数不一样,Flash Attention每融合一个小块,就需要调整一下之前块的系数,去保持一致!)
  • 传统attention需要分配完整的NxN矩阵(S, P),这是main需要解决的瓶颈,这也是Flash Attention主要解决的问题,将复杂度从O(N^2)降低到O(N)

整个过程不用存储中间变量S和P矩阵,节省了效率因为Attention 操作最大的问题就是每次操作都要从HBM把数据加载到GPU SRAM,运算结束后又从SRAM复制到HBM。这类似于cpu的寄存器与内存的关系,因此最容易的优化方法就是避免这种数据的来回移动,即编译器行话"kernel fusion"。

Flash Attention Algorithm

假设输入一个一维向量x^{(i)} = [x_1,x_2,...,x_B],对应于QK=Sij相似度矩阵中的一行向量。 

1. softmax分块计算:

  • m(x) = max(xi),这是rowmax 操作这是单个值
  • f(x) = [e^{x_1-m(x),..., e^{x_B-m(x)}}]。对应原公式的\tilde{P}_{ij}then why xi-m(x)?这是为了数值稳定,每个数减去相同的任一常量,其softmax值不变。==》减去最大的元素,保证最大值为e^0=1,因为在0~1之间时,浮点数的精度是最大的。
  • l(x) = \sum_if(x)_i,对应原公式\tilde{l}_{ij}这是rowsum 操作
  • so\!ftmax = \frac{f(x)}{l(x)}, softmax除法可以写成diag(l(x))^{-1},把l(x)拉伸成diag的主要原因是把更新公式写成矩阵乘法的形式

2. Flash Attention每次都是合并两块:previous blocks result + latest block。如何保证每一个小块的合并结果与原有attention结果相同?搞好softmax系数的一致性!

  •  因为each step都需要重新计算m(x) = max(m^{(i)}),而m(x)是变的,前面blocks在合并之前,需要先通过m_i - m_i^{new}修正之前block的系数,\tilde{m}_{ij}是指第ij单个block的max(x),不涉及之前blocks的max值
  • m(x) = m([x^{(1), x^{(2)}}]) = max(m(x^{(1)}, m^{(2)}))
  • f(x) = [e^{m(x^{(1)})-m(x))}f(x^{(1)}, e^{m(x^{(2)})-m(x))}f(x^{(2)})]
  • l(x) = e^{m(x^{(1)})-m(x))}l(x^{(1)}, e^{m(x^{(2)})-m(x))}l(x^{(2)})修正系数m_i - m_i^{new}保持一致,因为这两个blocks的softmax系数不一致,m(x^{(2)})-m(x)保证最新的single block的softmax系数与之前的一致!
  • so\!ftmax = \frac{f(x)}{l(x)}

举例:假设x \in R^6,并且它被分成3块:x^{(1)} = [1,3]x^{(2)} = [2,4]x^{(3)} = [3,2]

我们先计算前两块:

  • m(x^{(1)})=3, f(x^{(1)})=[e^{-2},1], l(x^{(1)})=(e^{-2}+1)
  • m(x^{(2)})=4, f(x^{(2)})=[e^{-2},1], l(x^{(2)})=(e^{-2}+1)

我们根据上面的结果计算前两块的结果:

  • m(x) = max(m(x^{(1)}), m(x^{(2)})) = max(3,4)=4
  • f(x) = [e^{3-4}f(x^{(1)}), e^{4-4}f(x^{(2)})]
  • l(x) = e^{3-4}l(x^{(1)}) + e^{4-4}l(x^{(2)})

为什么上面的结果是正确的呢?首先m(x)应该非常明显,4个数中的最大数肯定就是分成两组后的最大中的较大者。而f(x)计算的核心就是在𝑓(𝑥(1))𝑓(𝑥(1))前乘以𝑒3−4𝑒3−4以及在𝑓(𝑥(2))𝑓(𝑥(2))前乘以𝑒4−4𝑒4−4。l(x)的计算和f(x)是类似的。为什么需要在𝑓(𝑥(1))𝑓(𝑥(1))前乘以𝑒3−4𝑒3−4?因为在计算𝑓(𝑥(1))𝑓(𝑥(1))时最大的数是3,因此前两个数的指数都乘以了𝑒−3𝑒−3。但是现在前4个数的最大是4了,后面两个数的指数乘以了𝑒−4𝑒−4,因此直接合并为[𝑓(𝑥(1)),𝑓(𝑥(2))][𝑓(𝑥(1)),𝑓(𝑥(2))]是不对的,需要把前面两个数再乘以𝑒3−4=𝑒−1𝑒3−4=𝑒−1。而后面两个数本来就乘以了𝑒−4𝑒−4,所以不用变

计算output Oi:我们把一个很大的x拆分成长度为B的blocks,用上面的算法计算block 1和block 2,然后合并其结果;接着计算第3块,并将above 结果与第三块合并; ... =》所以,我们在定义时,可以把空块x=[], m(x)=-inf, f(x)=[], l(x)=0,这样我们就可以把第一块block的计算转换成block 1和空块的合并,使得循环可以从第一块开始!

  • O_1 = diag(l_1)^{-1}(0 * 0 + e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j)
  •  O_2 = diag(l_i^{new})^{-1}(diag(l_i)O_ie^{m_i-m_i^{new}} + e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j)

因为Flash Attention不存储中间变量S和P矩阵,所以我们用diag(l_i)O_i反推出之前的PV值,再用e^{m_i-m_i^{new}}修正系数,最后加上第ij块e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j) with single e^{\tilde{m}_{ij}},得到的结果最后再除以diag(l_i^{new})^{-1}保持softmax运算完整性。

参考

Flash Attention论文解读 - 李理的博客

https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/695939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【PR2019】怎样批量添加转场效果及修改默认持续时间

一,设置“交叉溶解”效果到所有素材 选择效果,右击“将所选过渡设置为默认过渡”: 框选所有素材,“Ctrl D”: 每个素材中间有有了交叉溶解的效果: 二,修改效果属性 2.1,单个修…

1.nginx介绍

介绍 是一个高性能的http和反向代理服务器。 特点 占用内存少,并发能力强。 nginx专为性能优化而开发,性能是其最重要的考量,实现上非常注重效率,能经受高负载的考验,有报告表明能支持高达50,000个并发连接数。 基…

拐点已至:企业如何借助AI重塑增长?

2024年的激进增长与AI数智化创新并行,传统策略的功效已经减弱。在这篇文章中,我们将展望并深度探索2024年的6大创新增长策略,包括AI驱动的实验,产品再造,超个性化,自动化运营,短视频和KOL营销等…

力扣hot100: 48. 旋转图像

LeetCode:48. 旋转图像 受到力扣hot100:54. 螺旋矩阵的启发,我们可以对旋转图像按层旋转,我们只需要记录四个顶点,并且本题是一个方阵,四个顶点就能完成图像的旋转操作。 1、逐层旋转 注意到&#xff0…

Java核心: JarIndex的使用

在讲解Java类加载器的时候,我们发现URLClassLoader加载类或资源时通过访问ClassPath下的每一个路径,来确定类是否存在的,假设我们执行的命令是这样的 java -classpath D:\DiveInSpring\target\classes;C:\lib\spring-expression.jar;C:\lib\…

扩展学习|风险管理的文献综述汇总(持续更新向)

一、风险管理发展历程和趋势综述(2007年发表) 文献来源:[1]严复海,党星,颜文虎.风险管理发展历程和趋势综述[J].管理现代化, 2007(2):4.DOI:CNKI:SUN:GLXX.0.2007-02-009. 简介:该文以风险管理发展历程中的大事件为线索, 对风险管…

第1回 最开始的两行代码

当你按下开机键的那一刻,在主板上提前写死的固件程序BIOS会将硬盘启动区中的512(B)的数据,原封不动地复制到内存中的0x7c00这个位置,并跳转到那个位置: 下面我们针对每一步做详细介绍. 开机后初始化指向BIOS CPU中有一个PC寄存器,里面存储这将要执行的指令在内存中的地…

挑战绝对不可能:再证有长度不同的射线

黄小宁 一空间坐标系中有公共汽车A,A中各座位到司机处的距离h是随着座位的不同而不同的变数,例如5号座位到司机处的距离是h3,…h5,…。A移动了一段距离变为汽车B≌A,B中5号座位到司机处的距离h’h3,…h’h5…

C语言详解文件操作

目录 什么是文件? 为什么使用文件? 程序文件和数据文件、文本文件和二进制文件 1.程序文件和数据文件 1.1程序文件 1.2数据文件 2.文本文件和二进制文件 文件的打开和关闭(流、标准流、文件指针和文件的打开与关闭) 1.流和标…

了解常用智能指针

智能指针 1、概念 C中引入智能指针的主要目的是为了解决内存管理的问题,传统的指针(裸指针)在使用时需要手动分配和释放内存,容易出现内存泄漏和悬挂指针等问题。智能指针通过封装裸指针,并提供自动内存管理功能&…

Python私教张大鹏 Vue3整合Vue Router之编程式导航

除了使用 <router-link> 创建 a 标签来定义导航链接&#xff0c;我们还可以借助 router 的实例方法&#xff0c;通过编写代码来实现。 导航到不同的位置 注意: 下面的示例中的 router 指代路由器实例。在组件内部&#xff0c;你可以使用 $router 属性访问路由&#xff…

spool 管道 小文件 mknod

Spool File In SQL*PLUS in Multiple Small Files ? (Doc ID 2152654.1)​编辑To Bottom In this Document Goal Solution APPLIES TO: Oracle Database - Enterprise Edition - Version 10.2.0.1 to 12.1.0.2 [Release 10.2 to 12.1] Oracle Database Cloud Schema Service…

城镇污水处理设施运维服务认证

初次申请认证时需提交的文件/资料 1、通用文件/资料(证明文件复印件需签字盖公章) ☐ 营业执照复印件、统一社会信用代码/组织机构代码证复印件 ☐ 增值税一般纳税人资格证复印件&#xff0c;或其他增值税一般纳税人资格认定文件复印件 ☐ 资质 或 许可证 复印件&#x…

DNS协议 | NAT技术 | 代理服务器

目录 一、DNS协议 1、DNS背景 2、DNS协议 域名 域名解析 二、NAT技术 1、NAT技术 2、NAPT技术 3、NAT技术的缺陷 三、代理服务器 1、正向代理服务器 2、反向代理服务器 一、DNS协议 域名系统&#xff08;Domain Name System&#xff0c;缩写&#xff1a;DNS&#…

Vue TypeScript 实战:掌握静态类型编程

title: Vue TypeScript 实战&#xff1a;掌握静态类型编程 date: 2024/6/10 updated: 2024/6/10 excerpt: 这篇文章介绍了如何在TypeScript环境下为Vue.js应用搭建项目结构&#xff0c;包括初始化配置、创建Vue组件、实现状态管理利用Vuex、配置路由以及性能优化的方法&#x…

vue2自定义指令

本节目标 快速入门v-loading 快速入门 指令对比 基本语法 使用: v-指令名"指令值"定义: 通过 directives 局部定义或者全局定义通过事件对象 el 可以拿到指令所在元素通过形参 binding 可以拿到指令的传值通过update钩子, 可以监听指令值的变化,进行更新操作 局部…

2024浙江省三支一扶报名流程!超详细图解!

2024浙江省三支一扶报名流程&#xff01;超详细图解&#xff01; 浙江省高校毕业生“三支一扶”报名即将开始&#xff0c;准备报考的同学们做好准备&#xff1a; &#x1f534;重点时间安排&#xff1a; 1、网络报名&#xff1a;6月11日9:00至6月18日17:00 2、资格审核&…

速卖通店铺防关联该怎么做?

大家都知道&#xff0c;想要进行多账号操作必须一再小心&#xff0c;否则会有很大的关联风险&#xff0c;而账号关联所带来的后果是卖家绝对不能轻视的&#xff0c;严重的话会导致封号&#xff0c;这样一来自己前期的辛苦运营就全都打水漂了&#xff0c;因此防关联很重要&#…

C++对象池设计与实现

目录 一、对象池简介 1.1 池化技术 1.2 什么是对象池 1.3 对象池分配策略 二、C new和delete运算符重载 三、实现一个对象池框架 3.1 策略接口 四、实现几种对象池的分配策略 4.1 数组策略 4.2 堆策略 ​编辑 4.3 栈策略 4.4 区块策略 一、对象池简介 1.1 池化技…

【C语言】插入排序(经典算法,建议收藏!!!)

目录 1、原理2、代码展示3、解析代码4、适用场景 1、原理 插入排序&#xff08;Insertion Sort&#xff09;是一种简单直观的排序算法&#xff0c;其原理可以简述如下&#xff1a; 1.分已排序区间和未排序区间: 将数组分为已排序区间和未排序区间。初始时&#xff0c;已排序区…