FlashAttention v1 论文解读

论文标题:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

论文地址:https://arxiv.org/pdf/2205.14135

FlashAttention 是一种重新排序注意力计算的算法,它无需任何近似即可加速注意力计算并减少内存占用。所以作为目前LLM的模型加速它是一个非常好的解决方案,本文介绍经典的V1版本。

目前FlashAttention已经推出了V1~V3版本,遗憾的是,FlashAttention V3目前只支持Nvidia Hopper架构的GPU。目前transformers库已经集成了FlashAttention。

【注】穷人玩不起系列。

FlashAttention是用于在训练或推理时加速注意力计算的方法,参考其官方仓库可以看到对于训练精度和显卡还是有较大限制的:

https://github.com/Dao-AILab/flash-attention

带有 CUDA 的 FlashAttention-2 目前支持:

GPU架构 Ampere, Ada, or Hopper GPUs(例如 A100、RTX 3090、RTX 4090、H100)。对Turing GPU(T4、RTX 2080)的支持即将推出,目前请为Turing GPU 使用 FlashAttention 1.x。

数据类型 fp16 和 bf16(bf16 需要Ampere, Ada, or Hopper GPUs)。

标准注意力机制

在介绍FlashAttention前,一定要深入了解标准注意力机制计算原理。

在 Transformer 架构当中,Attention 是整个模型中最重要的运算,而这个 Attention 的运算示意图如下:

图1 标准注意力计算图示

首先我们把 Q Q Q K K K做矩阵相乘,接下来就是除以隐藏层维度的开根号 d \sqrt{d} d ,然后我们会把运算出来的结果 S(Score)丢进 Softmax 函数得到 P,最后 P 再和 V V V做矩阵相乘就会得到 Attention 的输出 O O O

但实际上我们会发现这一连串的运算非常的耗时间,且会使用到非常大量的内存。

在我们的 GPU 架构中,可以把内存简单地分成 HBM(高带宽内存)和 SRAM(静态随机存取存储器)两个部分。

HBM 的内存空间虽然很大,但是它的带宽比较低。

SRAM 的内存空间虽然很小,但是它的带宽非常高。

所以我们常常看到 GPU 的参数,像是 Nvidia RTX 4090 24GB,就是这张 GPU 有大约 24GB 大小的 HBM。而 SRAM 这块又贵又小的内存,就是拿来做运算的。

图2 GPU存储架构与FlashAttention计算示例

因此我们可以看到今天你在GPU 上运行 标准Attention 的流程如下(N:序列长度、d 是 隐藏层维度):

图3 标准注意力机制流程

首先我们会把 Q Q Q K K K从 HBM 拉到 SRAM 运算,接下来把算出来的结果 S S S写回 HBM,然后 GPU 又把 S S S拉到 SRAM 计算 S o f t m a x Softmax Softmax,算出来的 P P P又写回 HBM,最后 P P P V V V从 HBM 写到 SRAM 做矩阵运算,最后输出 O O O写回 HBM。

而实际情况当然没那么简单,我们知道 SRAM 这块内存又贵又小,所以当然不可能直接把整个 Q Q Q或是 K K K加载进 SRAM,而是一小块一小块地加载。所以这样大量的读写导致 Attention 运算速度很慢,而且会有内存碎片化问题。

【注】有了上面的背景之后,我们来看看FlashAttention V1是如何优化的,下面为大家带来FlashAttention V1论文精读。

Abstract

针对Transformer在处理长序列时速度慢、内存消耗大的问题,论文提出了FlashAttention,一种IO感知的精确注意力算法。该算法通过使用平铺(tiling)技术减少GPU内存(HBM)与SRAM之间的内存读写次数,从而降低计算复杂性。

分析显示,FlashAttention减少了HBM访问次数,并优化了SRAM使用。此外,本研究将FlashAttention扩展至块稀疏注意力,实现了比现有近似注意力方法更快的近似注意力算法,为长序列处理提供了高效解决方案。

【注】标准自注意力机制的时间复杂度是 O ( n 2 ∗ d ) O(n^2*d) O(n2d),其中 n n n是序列长度, d d d是隐藏层维度。多头注意力只是把 d d d进行了多头拆分,单头的时间复杂度是 O ( n 2 ∗ d h ) O(n^2*d_h) O(n2dh),其中 d h d_h dh是单头的隐藏层维度,虽然多头之间可以并行计算,但是仍然没有解决平方量的复杂度。

Introduction

目前许多优化 attention 的方法旨在降低 attention 的计算和内存需求。这些方法专注于减少 FLOP,并且倾向于忽略内存访问 (IO) 的开销。

但是本文认为attention的一个优化方向是使算法具有 IO 感知能力

【注】也就是说,让求注意力的操作尽可能放在SRAM里,而不是频繁的让SRAM与HBM通信。

现代的GPU,计算速度超过了内存IO速度,当读取和写入数据可能占据运行时间的很大一部分时,IO 感知算法对于加速与降内存就变得很重要了。并且深度学习的常见 Python 库(如 PyTorch 和 Tensorflow)目前还不允许对内存访问进行精细控制。

因此,FlashAttention应运而生。

论文提到,为了实现计算注意力时多使用SRAM而少与HBM交换数据,需要克服两点:

  1. 在输入不完整的情况下,计算 S o f t m a x Softmax Softmax
  2. 不存储用于反向传播的中间结果;

FlashAttention

第一招:内核融合(Kernel Fusion)

相信聪明的朋友立刻就能明白,何必这样反复加载和卸载,一次性在SRAM中完成所有计算不就好了?没错,这就是FlashAttention的精髓之一。

FlashAttention就是直接将 Q K V QKV QKV一次性加载到SRAM中完成所有计算,然后再将 O O O写回HBM。

这样大大减少了读写次数,这种一次性完成所有计算的流程被称为内核融合(Kernel Fusion)。

图4 内核融合示意图

第二招:反向重计算(Backward Recomputation)

但是等一下,我们是不是忘了什么?我们直接计算出了 O O O,那么 P P P S S S难道就直接丢弃不存回HBM吗?在进行反向传播时,我们需要从 O O O推回 P P P,再从 P P P推回 S S S,它们都被我们丢弃了,怎么进行反向传播?没错,这就是FlashAttention的第二招,反向重计算(Backward Recomputation)。

因为 P P P S S S这两者实在太占用空间了,所以

在前向传播时, P P P S S S都不会被存储起来。当进行反向传播时,我们就会重新计算一次前向传播,重新计算出 P P P S S S,以便执行反向传播。

所以说:我们执行了2次前向传播和1次反向传播。

这里大家可能又会问:啊这样计算量不是更多了吗,怎么可能会更快?事实上,虽然我们重新计算了一次前向传播,但它不仅帮我们省下了存储P和S的内存空间,还省下了 P P P S S S在HBM和SRAM之间搬运的时间,让我们可以开启更大的batch size,所以总的来说,GPU每秒能处理的数据量依然是大幅增加的。

第三招:Softmax分块(Softmax Tiling)

最后是FlashAttention的最后一招分块(Tiling)。首先我们需要知道注意力机制中的最难搞的就是 S o f t m a x Softmax Softmax函数:

s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N (1) softmax(\{x_1, ..., x_N\}) = \left\{\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\right\}_{i=1}^N \tag1 softmax({x1,...,xN})={j=1Nexjexi}i=1N(1)

主要原因是在计算分母时,我们需要将所有位的exp值加总。但由于SRAM的大小限制,我们不可能一次性计算出所有数值的 S o f t m a x Softmax Softmax,一定是需要一块一块地丢进SRAM进行计算,所以需要将所有中间计算的数值存储在HBM中。

在FP16精度下,最大可以表示65536,而

e 12 = 162754 e^{12} = 162754 e12=162754

为了防止在计算 S o f t m a x Softmax Softmax产生数值溢出,引入了 S a f e − s o f t m a x Safe-softmax Safesoftmax概念,其公式如下:

S a f e − s o f t m a x ( { x 1 , . . . , x N } ) = { e x i − m ∑ j = 1 N e x j − m } i = 1 N (2) Safe-softmax(\{x_1, ..., x_N\}) = \left\{\frac{e^{x_i-m}}{\sum_{j=1}^N e^{x_j-m}}\right\}_{i=1}^N \tag2 Safesoftmax({x1,...,xN})={j=1Nexjmexim}i=1N(2)

在公式(2)中,有如下定义:

x = [ x 1 , . . . , x N ] (3) x=[x_1,...,x_N] \tag3 x=[x1,...,xN](3)

m ( x ) : = m a x ( x ) (4) m(x):=max(x) \tag4 m(x):=max(x)(4)

p ( x ) : = [ e x 1 − m ( x ) , . . . , e x N − m ( x ) ] (5) p(x):=[e^{x_1-m(x)},...,e^{x_N-m(x)}] \tag5 p(x):=[ex1m(x),...,exNm(x)](5)

l ( x ) : = ∑ i p ( x ) i (6) l(x):=\sum_ip(x)_i \tag6 l(x):=ip(x)i(6)

s o f t m a x ( x ) : = p ( x ) l ( x ) (7) softmax(x):=\frac{p(x)}{l(x)} \tag7 softmax(x):=l(x)p(x)(7)

其原理就是,从 x x x中找出最大值 m m m,在计算 S o f t m a x Softmax Softmax时,分子分母同除以 e m e^m em,这样既可以防止数据溢出,也能保证 S o f t m a x Softmax Softmax值保持不变。

【注】类似于归一化。

x = [ x 1 , … , x N , … , x 2 N ] x 1 = [ x 1 , … , x N ] x 2 = [ x N + 1 , … , x 2 N ] m ( x 1 )   p ( x 1 )   l ( x 1 )   m ( x 2 )   p ( x 2 )   l ( x 2 ) m ( x ) : = max ⁡ ( m ( x 1 ) , m ( x 2 ) ) p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] l ( x ) : = e m ( x 1 ) − m ( x ) l ( x 1 ) + e m ( x 2 ) − m ( x ) l ( x 2 ) s o f t m a x ( x ) : = p ( x ) l ( x ) (8) \begin{align*} & x = [x_1, \ldots, x_N, \ldots, x_{2N}] \\ & x^1 = [x_1, \ldots, x_N] \\ & x^2 = [x_{N+1}, \ldots, x_{2N}] \\ & m(x^1) \ p(x^1) \ l(x^1) \ m(x^2) \ p(x^2) \ l(x^2) \\ & m(x) := \max(m(x^1), m(x^2)) \\ & p(x) := [e^{m(x^1)-m(x)} p(x^1), e^{m(x^2)-m(x)} p(x^2)] \\ & l(x) := e^{m(x^1)-m(x)} l(x^1) + e^{m(x^2)-m(x)} l(x^2) \\ & softmax(x) := \frac{p(x)}{l(x)} \end{align*}\tag8 x=[x1,,xN,,x2N]x1=[x1,,xN]x2=[xN+1,,x2N]m(x1) p(x1) l(x1) m(x2) p(x2) l(x2)m(x):=max(m(x1),m(x2))p(x):=[em(x1)m(x)p(x1),em(x2)m(x)p(x2)]l(x):=em(x1)m(x)l(x1)+em(x2)m(x)l(x2)softmax(x):=l(x)p(x)(8)

而本文softmax分块的做法如公式(8)所示。

我们首先将一块数据 x x x中的第一块 x 1 x_1 x1丢进去计算出softmax,这里的 m 1 m_1 m1代表的是这一块加载到SRAM的最大值,所以我们称之为局部最大值。接下来,我们可以根据 m 1 m_1 m1计算出局部softmax。

接下来第二块数据进来时,我们将第一块的最大值 m 1 m_1 m1和第二块的最大值 m 2 m_2 m2取最大值,就可以得到这两块数据的最大值 m ( x ) m(x) m(x)。这个时候定义 p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] p(x) := [e^{m(x^1)-m(x)} p(x^1), e^{m(x^2)-m(x)} p(x^2)] p(x):=[em(x1)m(x)p(x1),em(x2)m(x)p(x2)],再与公式(5)结合,只会出现两种情况:

  1. m ( x 1 ) m(x^1) m(x1)最大,最后可化简为 p ( x ) : = [ e x 1 − m ( x 1 ) , . . . , e x N − m ( x 1 ) ] p(x) := [e^{x_1-m(x^1)},...,e^{x_N-m(x^1)}] p(x):=[ex1m(x1),...,exNm(x1)]
  2. m ( x 2 ) m(x^2) m(x2)最大,最后可化简为 p ( x ) : = [ e x 1 − m ( x 2 ) , . . . , e x N − m ( x 2 ) ] p(x) := [e^{x_1-m(x^2)},...,e^{x_N-m(x^2)}] p(x):=[ex1m(x2),...,exNm(x2)]

l ( x ) l(x) l(x)的计算化简也同理,所以我们只需要将第一块的局部softmax乘上这次更新的数值。如此一来,我们就得到了这两块的局部softmax。

没错!接下来依此类推,我们就可以将整个softmax计算完。而通过这种方式:

我们就不需要将每块计算出来的数值存储在HBM中,我们只需要存储当前的最大值 m ( x ) m(x) m(x)和分母加总值 l ( x ) l(x) l(x)就可以了。

而这两者都非常小,所以可以进一步帮我们节省更多内存空间。

另外,这里还有一个小细节,就是由于softmax计算出来后需要与value state进行矩阵相乘,但同样由于SRAM有限,我们一次只能加载一块进行内核融合运算,所以第一块QKV进去后,它计算出来的O是不准确的。但由于矩阵相乘就是数字相乘,所以同样道理,我们只要在计算到下一块时,使用l和m更新O就可以了。

我们可以看到实际的流程就是这样,蓝色的区域就是HBM,橙色虚线的区域就是SRAM。每次运算时,由于SRAM大小有限,所以我们只加载一部分的Key和Value。红色的字就是我们的第一个block的计算,蓝色的字就是我们的第二个block的计算。

图5 block计算演示图

这边我们可以更深入探讨算法和实现部分。静态随机存取存储器(SRAM)容量较小,当序列长度很长时,根本不可能一次性将如此庞大的查询(query)、键(key)、值(value)状态全部塞进SRAM。

一开始我们会把查询状态(Query State)切成 T r T_r Tr块,键/值状态(Key/Value State)切成 T c T_c Tc块,查询状态块的大小为 ( B r , d ) (B_r, d) (Br,d),键/值状态块的大小为 ( B c , d ) (B_c, d) (Bc,d)。切好的这些块再放入SRAM进行Flash Attention运算。 你可能会好奇 B r B_r Br B c B_c Bc是什么神奇的数字,其实非常简单, M M M是我们SRAM的大小,并且查询(Q)、键(K)、值(V)、输出(O)这四个矩阵大小完全相同,所以当然是 M / 4 d M/4d M/4d啦,这样Q、K、V、O四个矩阵的块加起来不就刚好是 M M M嘛,也就是说刚好填满SRAM。

比如说,假设M = 1000, d = 5。那么块大小为(1000/4*5)= 50。所以一次加载50个q, k, v, o个向量的块,这样可以减少HBM/SRAM之间的读/写次数。

性能

图6 FlashAttention V1性能

我们可以看到 FlashAttention 大大地加速了运算,达到 3 倍以上。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963327.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue - shallowRef 和 shallowReactive

一、shallowRef 和 shallowReactive (一)shallowRef 在 Vue 3 中,shallowRef 是一个用于创建响应式引用的 API,它与 ref 相似,但它只会使引用的基本类型(如对象、数组等)表现为响应式&#xf…

【深度学习】softmax回归的简洁实现

softmax回归的简洁实现 我们发现(通过深度学习框架的高级API能够使实现)(softmax)线性(回归变得更加容易)。 同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。 本节继续使用Fashion-MNIST数据集,并保持批量大小为256。 import torch …

ESP32-c3实现获取土壤湿度(ADC模拟量)

1硬件实物图 2引脚定义 3使用说明 4实例代码 // 定义土壤湿度传感器连接的模拟输入引脚 const int soilMoisturePin 2; // 假设连接到GPIO2void setup() {// 初始化串口通信Serial.begin(115200); }void loop() {// 读取土壤湿度传感器的模拟值int sensorValue analogRead…

【python】python油田数据分析与可视化(源码+数据集)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉专__注👈:专注主流机器人、人工智能等相关领域的开发、测试技术。 【python】python油田数据分析与可视化&#xff08…

代码讲解系列-CV(一)——CV基础框架

文章目录 一、环境配置IDE选择一套完整复现安装自定义cuda算子 二、Linux基础文件和目录操作查看显卡状态压缩和解压 三、常用工具和pipeline远程文件工具版本管理代码辅助工具 随手记录下一个晚课 一、环境配置 pytorch是AI框架用的很多,或者 其他是国内的框架 an…

HTB:Alert[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用ffuf对alert.htb域名进行子域名FUZZ 使用go…

小红的合数寻找

A-小红的合数寻找_牛客周赛 Round 79 题目描述 小红拿到了一个正整数 x,她希望你在 [x,2x] 区间内找到一个合数,你能帮帮她吗? 一个数为合数,当且仅当这个数是大于1的整数,并且不是质数。 输入描述 在一行上输入一…

Linux环境下的Java项目部署技巧:安装 Mysql

查看 myslq 是否安装: rpm -qa|grep mysql 如果已经安装,可执行命令来删除软件包: rpm -e --nodeps 包名 下载 repo 源: http://dev.mysql.com/get/mysql80-community-release-el7-7.noarch.rpm 执行命令安装 rpm 源(根据下载的…

基于springboot+vue的哈利波特书影音互动科普网站

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

在React中使用redux

一、首先安装两个插件 1.Redux Toolkit 2.react-redux 第一步:创建模块counterStore 第二步:在store的入口文件进行子模块的导入组合 第三步:在index.js中进行store的全局注入 第四步:在组件中进行使用 第五步:在组件中…

记录 | 基于MaxKB的文字生成视频

目录 前言一、安装SDK二、创建视频函数库三、调试更新时间 前言 参考文章:如何利用智谱全模态免费模型,生成大家都喜欢的图、文、视并茂的文章! 自己的感想 本文记录了创建文字生成视频的函数库的过程。如果想复现本文,需要你逐一…

Redis|前言

文章目录 什么是 Redis?Redis 主流功能与应用 什么是 Redis? Redis,Remote Dictionary Server(远程字典服务器)。Redis 是完全开源的,使用 ANSIC 语言编写,遵守 BSD 协议,是一个高性…

安全防护前置

就业概述 网络安全工程师/安全运维工程师/安全工程师 安全架构师/安全专员/研究院(数学要好) 厂商工程师(售前/售后) 系统集成工程师(所有计算机知识都要会一点) 学习目标 前言 网络安全事件 蠕虫病毒--&…

开源2 + 1链动模式AI智能名片S2B2C商城小程序视角下从产品经营到会员经营的转型探究

摘要:本文聚焦于开源2 1链动模式AI智能名片S2B2C商城小程序,深入探讨在其应用场景下,企业从产品经营向会员经营转型的必要性与策略。通过分析如何借助该平台优化会员权益与价值,解决付费办卡的接受度问题,揭示其在提升…

让banner.txt可以自动读取项目版本

文章目录 1.sunrays-dependencies1.配置插件2.pluginManagement统一指定版本 2.common-log4j2-starter1.banner.txt使用$ 符号取出2.查看效果 1.sunrays-dependencies 1.配置插件 <!-- 为了让banner.txt自动获取版本号 --><plugin><groupId>org.apache.mave…

音视频多媒体编解码器基础-codec

如果要从事编解码多媒体的工作&#xff0c;需要准备哪些更为基础的内容&#xff0c;这里帮你总结完。 因为数据类型不同所以编解码算法不同&#xff0c;分为图像、视频和音频三大类&#xff1b;因为流程不同&#xff0c;可以分为编码和解码两部分&#xff1b;因为编码器实现不…

openmv运行时突然中断并且没断联只是跟复位了一样

就是 # 内存不足时硬件复位 except MemoryError as me: print("Memory Error:", me) pyb.hard_reset() # 内存不足时硬件复位 很有可能是你的代码加了内存溢出的复位&#xff0c;没加的话他会报错的

Redis集群理解以及Tendis的优化

主从模式 主从同步 同步过程&#xff1a; 全量同步&#xff08;第一次连接&#xff09;&#xff1a;RDB文件加缓冲区&#xff0c;主节点fork子进程&#xff0c;保存RDB&#xff0c;发送RDB到从节点磁盘&#xff0c;从节点清空数据&#xff0c;从节点加载RDB到内存增量同步&am…

77-《欧耧斗菜》

欧耧斗菜 欧耧斗菜&#xff08;学名&#xff1a;Aquilegia vulgaris L. &#xff09;是毛茛科耧斗菜属植物&#xff0c;株高30-60厘米。基生叶有长柄&#xff0c;基生叶及茎下部叶为二回三出复叶&#xff0c;小叶2-3裂&#xff0c;裂片边缘具圆齿。最上部茎生叶近无柄。聚伞花序…

为AI聊天工具添加一个知识系统 之83 详细设计之24 度量空间之1 因果关系和过程:认知金字塔

本文要点 度量空间 在本项目&#xff08;为AI聊天工具添加一个知识系统 &#xff09;中 是出于对“用”的考量 来考虑的。这包括&#xff1a; 相对-位置 力用&#xff08;“相”&#xff09;。正如 法力&#xff0c;相关-速度 体用 &#xff08;“体”&#xff09;。例如 重…