0 英文缩写
- FA: Flash Attention
- HBM:High Bandwidth Memory,高带宽显存
0 论文
[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
中文:FlashAttention:一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法
科研团队:斯坦福大学计算机系+纽约州立大学布法罗分校
发表时间:20220527
1 背景:
- 背景1:应用广泛:Transformer 模型在图像分类、自然语言处理等分支领域中逐渐成为最为常见的架构
- 背景2:模型扩展:随着技术不断进步,Transformer 模型在尺寸和深度等方面都进一步拓展
- 背景3:算法复杂度特征:核心模块自注意力机制(self attention)的时间复杂度和存储复杂度,均与输入长度(一般即为处理的序列长度)的平方成正比
结合背景123,可以发现更大的模型在更长的上下文背景上还存在着一定的挑战。
-
背景4:计算读写开销:论文GPU内不同存储系统的速度举例如下:
- GPU SRAM 读写(I/O)速度19 TB/s
- GPU HBM 读写(I/O)速度 1.5 TB/s
2 相关方案
在此背景之下,有人提出一些近似自注意力的方法,旨在减少注意力计算和内存需求。
- 稀疏近似
- 低秩近似
- 它们的组合
缺点:尽管这些方法可以将计算降低到线性或接近线性,但它们过于关注降低每秒所执行的浮点运算次数(FLops),换句话说更倾向于单纯降低计算复杂度。忽略来自内存访问(IO)的开销。不能实现更高且更有实用价值的计算加速范式。
3 传统Attention
(更详细的推导过程和描述可以参考前文)
Attention机制
其核心为计算输入向量的相关程度
,例如在翻译过程中,不同的英文对中文的依赖程度不同,Attention机制
通常可以进行如下描述
3.1 输入输出定义
- 输入1: Q Q Q 序列(query),其中 { Q = ( q 1 q 2 q 3 ⋮ q m ) ⏟ d k } m ∈ R m × d k , q i ∈ R 1 × d k ∣ i ∈ 1 , 2 , … , m } \left\{Q=\underbrace{\left(\begin{array}{c}q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_m \end{array}\right)}_{d_{k}}\} m \in\mathbb{R}^{m\times d_k}, q_{i}\in\mathbb{R}^{1\times d_k} \mid i\in 1,2, \ldots, m\right\} ⎩ ⎨ ⎧Q=dk q1q2q3⋮qm }m∈Rm×dk,qi∈R1×dk∣i∈1,2,…,m⎭ ⎬ ⎫
- 输入2: K K K 序列 (key),其中 { K = ( k 1 k 2 k 3 ⋮ k m ) ⏟ d k } m ∈ R m × d k , k i ∈ R 1 × d k ∣ i = 1 , 2 , … , m } \left\{K=\underbrace{\left(\begin{array}{c}k_1 \\ k_2 \\ k_3 \\ \vdots \\ k_m\end{array}\right)}_{d_{k}}\} m\in\mathbb{R}^{m\times d_k}, k_{i}\in \mathbb{R}^{1\times d_k} \mid i=1,2, \ldots, m\right\} ⎩ ⎨ ⎧K=dk k1k2k3⋮km }m∈Rm×dk,ki∈R1×dk∣i=1,2,…,m⎭ ⎬ ⎫
- 输入3: V V V 序列 (value) ,其中 { V = ( v 1 v 2 v 3 ⋮ v m ) ⏟ d v } m ∈ R m × d v , v i ∈ R 1 × d v ∣ i = 1 , 2 , … , m } \left\{V=\underbrace{\left(\begin{array}{c}v_1 \\ v_2 \\ v_3 \\ \vdots \\ v_m\end{array}\right)}_{d_{v}}\} m\in\mathbb{R}^{m\times d_v}, v_{i}\in \mathbb{R}^{1\times d_v} \mid i=1,2, \ldots, m\right\} ⎩ ⎨ ⎧V=dv v1v2v3⋮vm }m∈Rm×dv,vi∈R1×dv∣i=1,2,…,m⎭ ⎬ ⎫
- 输出为$\text { Attention }(Q, K, V) $ 向量,计算公式:
Attention ( Q , K , V ) ∈ R m × d v = softmax ( Q K T d k ) V \text { Attention }(Q, K, V) \in\mathbb R^{m \times d_{v}}=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V Attention (Q,K,V)∈Rm×dv=softmax(dkQKT)V
3.2 算法解析
第一步:矩阵乘法
为什么可以计算得到不同输入向量之间的得分
矩阵乘法
假设共有十个输入向量,每个向量的长度为512,也即为 m = 10 m=10 m=10, d k = 512 d_k=512 dk=512
Q = ( q 1 [ 0 ] ⋯ q 1 [ d k ] ⋮ ⋯ ⋮ q 10 [ 0 ] ⋯ q 10 [ 511 ] ) = ( q 1 ⃗ ⋮ q 10 ⃗ ) Q=\left(\begin{array}{ccc} q_{1}[0] & \cdots & q_{1}[d_k] \\ \vdots & \cdots & \vdots \\ q_{10}[0] & \cdots & q_{10}[511] \end{array}\right) = \left(\begin{array}{c}\vec{q_{1}}\\\vdots\\ \vec{q_{10}} \end{array}\right) Q= q1[0]⋮q10[0]⋯⋯⋯q1[dk]⋮q10[511] = q1⋮q10
K = ( k 1 [ 0 ] ⋯ k 1 [ 511 ] ⋮ ⋯ ⋮ k 10 [ 0 ] ⋯ k 10 [ 511 ] ) = ( k 1 ⃗ ⋮ k 10 ⃗ ) K=\left(\begin{array}{ccc}k_{1}[0] & \cdots & k_{1}[511] \\\vdots & \cdots & \vdots \\k_{10}[0] & \cdots & k_{10}[511]\end{array}\right) = \left(\begin{array}{c}\vec{k_{1}}\\\vdots\\ \vec{k_{10}} \end{array}\right) K= k1[0]⋮k10[0]⋯⋯⋯k1[511]⋮k10[511] = k1⋮k10
相乘结果如下
Q
⋅
K
T
∈
R
m
×
m
=
(
q
1
⃗
⋮
q
10
⃗
)
⋅
(
k
1
⃗
T
⋯
k
10
⃗
T
)
(
q
1
⃗
⋅
k
1
⃗
T
⋯
q
1
⃗
⋅
k
10
⃗
T
⋮
⋯
⋮
q
10
⃗
⋅
k
1
⃗
T
⋯
q
10
⃗
⋅
k
10
⃗
T
)
=
(
s
1
−
1
⋯
s
1
−
10
⋮
⋯
⋮
s
10
−
1
⋯
s
10
−
10
)
Q \cdot K^T \in \mathbf{R}^{m\times m}= \left(\begin{array}{c}\vec{q_{1}}\\\vdots\\ \vec{q_{10}} \end{array}\right) \cdot \left(\vec{k_{1}}^T\cdots \vec{k_{10}}^T\right) \left(\begin{array}{ccc} \vec{q_{1}}\cdot\vec{k_{1}}^T & \cdots & \vec{q_{1}}\cdot\vec{k_{10}}^T \\\vdots & \cdots & \vdots \\\vec{q_{10}}\cdot\vec{k_{1}}^T& \cdots & \vec{q_{10}}\cdot\vec{k_{10}}^T\end{array}\right) =\left(\begin{array}{ccc}s_{1-1} & \cdots & s_{1-10} \\\vdots & \cdots & \vdots \\s_{10-1} & \cdots & s_{10-10}\end{array}\right)
Q⋅KT∈Rm×m=
q1⋮q10
⋅(k1T⋯k10T)
q1⋅k1T⋮q10⋅k1T⋯⋯⋯q1⋅k10T⋮q10⋅k10T
=
s1−1⋮s10−1⋯⋯⋯s1−10⋮s10−10
矩阵 S S S中的每一个元素通过分别来自于 Q \mathbf{Q} Q 和 K \mathbf{K} K的两个向量的点乘得到的,通过最原始的矩阵定义,可以得知两个向量的点乘意味着一个向量在另一个向量的投影,也可以李继伟表示向量 q i ⃗ \vec{q_{i}} qi, k j ⃗ \vec{k_j} kj的相似程度
第二步:scaling与归一化
除以一个数字 d k \sqrt{d_{k}} dk的意义是:
- 因为如果 d k d_k dk太大,点乘的值太大,如果不做scaling,结果就没有加法注意力好。
- 为了不让输入太大,导致
softmax函数
被推动到非常平缓的区域。
将得到scaling后的相似度进行Softmax
操作,假定Scaling之后相似度矩阵为
(
s
1
−
1
′
⋯
s
1
−
m
′
⋮
⋯
⋮
s
m
−
1
′
⋯
s
m
−
m
′
)
=
(
s
1
−
1
/
d
k
⋯
s
1
−
m
/
d
k
⋮
⋯
⋮
s
m
−
1
/
d
k
⋯
s
m
−
m
/
d
k
)
\left(\begin{array}{ccc}s'_{1-1} & \cdots & s'_{1-m} \\\vdots & \cdots & \vdots \\ s'_{m-1} & \cdots & s'_{m-m}\end{array}\right) = \left(\begin{array}{ccc}s_{1-1}/\sqrt{d_{k}} & \cdots & s_{1-m}/\sqrt{d_{k}} \\\vdots & \cdots & \vdots \\s_{m-1}/\sqrt{d_{k}} & \cdots & s_{m-m}/\sqrt{d_{k}}\end{array}\right)
s1−1′⋮sm−1′⋯⋯⋯s1−m′⋮sm−m′
=
s1−1/dk⋮sm−1/dk⋯⋯⋯s1−m/dk⋮sm−m/dk
进行归一化
(
s
1
−
1
′
′
⋯
s
1
−
m
′
′
⋮
⋯
⋮
s
m
−
1
′
′
⋯
s
m
−
m
′
)
=
(
e
s
1
−
1
′
∑
i
=
1
m
e
s
1
−
i
′
⋯
e
s
1
−
m
′
∑
i
=
1
m
e
s
1
−
i
′
⋮
⋯
⋮
e
s
m
−
1
′
∑
i
=
1
m
e
s
m
−
i
′
⋯
e
s
m
−
m
′
∑
i
=
1
m
e
s
m
−
i
′
)
\left(\begin{array}{ccc}s''_{1-1} & \cdots & s''_{1-m} \\\vdots & \cdots & \vdots \\ s''_{m-1} & \cdots & s'_{m-m}\end{array}\right) = \left(\begin{array}{ccc}\frac{e^{s'_{1-1}}} {\sum_{i=1}^{m} e^{s'_{1-i}} } & \cdots & \frac{e^{s'_{1-m}}} {\sum_{i=1}^{m} e^{s'_{1-i}} } \\\vdots & \cdots & \vdots \\ \frac{e^{s'_{m-1}}} {\sum_{i=1}^{m} e^{s'_{m-i}} } & \cdots & \frac{e^{s'_{m-m}}} {\sum_{i=1}^{m} e^{s'_{m-i}} } \end{array}\right)
s1−1′′⋮sm−1′′⋯⋯⋯s1−m′′⋮sm−m′
=
∑i=1mes1−i′es1−1′⋮∑i=1mesm−i′esm−1′⋯⋯⋯∑i=1mes1−i′es1−m′⋮∑i=1mesm−i′esm−m′
如此实现一横行的加权和为1,不同的 v i v_i vi 向量获得的加权综合为1
第三步:加权输出
针对计算出来的权重
α
i
\alpha_{i}
αi,通过权重对
V
V
V中所有的values
进行加权求和计算,得到Attention向量
(
s
1
−
1
′
⋯
s
1
−
m
′
⋮
⋯
⋮
s
m
−
1
′
⋯
s
m
−
m
′
)
(
v
1
⃗
⋮
v
m
⃗
)
\left(\begin{array}{ccc}s'_{1-1} & \cdots & s'_{1-m} \\\vdots & \cdots & \vdots \\ s'_{m-1} & \cdots & s'_{m-m}\end{array}\right)\left(\begin{array}{c}\vec{v_{1}}\\\vdots\\ \vec{v_{m}} \end{array}\right)
s1−1′⋮sm−1′⋯⋯⋯s1−m′⋮sm−m′
v1⋮vm
3.3 读写IO伪代码
#########Standard Attention Implementation
Require: Matrices Q, K, V ∈ R^{N×d} in HBM.
1: Load Q, K by blocks from HBM, compute S = QK^{T}, write S to HBM.
2: Read S from HBM, compute P = softmax(S), write P to HBM.
3: Load P and V by blocks from HBM, compute O = PV, write O to HBM.
4: Return O.
3.3 关于Attention的总结
- 采用点乘注意力,这种注意力机制对于加法注意力而言,更快,同时更节省空间。
- 把
attention抽象
为对value
的每个表示(token
)进行加权,而加权的weight
就是attention weight
,而attention weight
就是根据query
和key
计算得到,其意义为:为了用value
求出query
的结果, 根据query
和key
来决定注意力应该放在value
的哪部分。
4 Flash Attention
4.1 背景分析
在标准注意力实现中,注意力的性能主要受限于内存带宽,是内存受限的。频繁地从HBM中读写 R N × N \mathbb{R}^{N \times N} RN×N的矩阵是影响性能的主要瓶颈。稀疏近似和低秩近似等近似注意力方法虽然减少了计算量FLOPs,但对于内存受限的操作,运行时间的瓶颈是从HBM中读写数据的耗时,减少计算量并不能有效地减少运行时间(wall-clock time)。针对内存受限的标准注意力,Flash Attention是IO感知的,目标是避免频繁地从HBM中读写数据。
4.2 解决方案
从GPU显存分级来看,SRAM的读写速度比HBM高一个数量级,但内存大小要小很多。通过kernel融合的方式,将多个操作融合为一个操作,利用高速的SRAM进行计算,可以减少读写HBM的次数,从而有效减少内存受限操作的运行时间。但SRAM的内存大小有限,不可能一次性计算完整的注意力,因此必须进行分块计算,使得分块计算需要的内存不超过SRAM的大小。
问题一:为什么要进行分块计算呢?
内存受限 --> 减少HBM读写次数 --> kernel融合 --> 满足SRAM的内存大小 --> 分块计算
因此分块大小block_size不能太大,否则会导致存储内容踢出。
问题二:分块计算的难点是什么呢?
注意力机制的计算过程是“矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于softmax的分块计算。由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大。论文中也是重点对softmax的分块计算进行了阐述。
tiling的主要思想是分块计算注意力。分块计算的难点在于softmax的分块计算,softmax与矩阵 K K K 的列是耦合的,通过引入了两个额外的统计量 m ( x ) m(x) m(x), l ( x ) l(x) l(x)来进行解耦,实现了分块计算。需要注意的是,可以利用GPU多线程同时并行计算多个block的softmax。为了充分利用硬件性能,多个block的计算不是串行(sequential)的, 而是并行的。
4.3 前向算法伪代码:Softmax的IO缩减
一个简单的例子实现分块计算Softmax
对向量 A = [ 1 , 2 , 3 , 4 ] A = [1,2,3,4] A=[1,2,3,4] 计算Softmax,分成两块 A 1 = [ 1 , 2 ] A_1 = [1,2] A1=[1,2] 和 A 2 = [ 3 , 4 ] A_2 = [3,4] A2=[3,4] 进行计算。 计算block1和block2:
block1
m
1
=
m
a
x
(
[
1
,
2
]
)
=
2
f
1
=
[
e
1
−
m
1
,
e
2
−
m
1
]
=
[
e
−
1
,
e
0
]
l
1
=
∑
f
1
=
e
−
1
+
e
0
o
1
=
f
1
l
1
=
[
e
−
1
,
e
0
]
e
−
1
+
e
0
=
[
e
−
1
e
−
1
+
e
0
,
e
0
e
−
1
+
e
0
]
m_1 = max([1,2]) = 2\\ f_1 = [e^{1-m_1},e^{2-m_1}] = [e^{-1},e^0]\\ l_1 = \sum f_1 = e^{-1} + e^0\\ o_1 = \frac{f_1}{l_1} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right]
m1=max([1,2])=2f1=[e1−m1,e2−m1]=[e−1,e0]l1=∑f1=e−1+e0o1=l1f1=e−1+e0[e−1,e0]=[e−1+e0e−1,e−1+e0e0]
block2
m
2
=
m
a
x
(
[
3
,
4
]
)
=
4
f
2
=
[
e
3
−
m
2
,
e
4
−
m
2
]
=
[
e
−
1
,
e
0
]
l
2
=
∑
f
2
=
e
−
1
+
e
0
o
2
=
f
2
l
2
=
[
e
−
1
,
e
0
]
e
−
1
+
e
0
=
[
e
−
1
e
−
1
+
e
0
,
e
0
e
−
1
+
e
0
]
m_2 = max([3,4]) = 4\\ f_2 = [e^{3-m_2},e^{4-m_2}] = [e^{-1},e^0]\\ l_2 = \sum f_2 = e^{-1} + e^0\\ o_2 = \frac{f_2}{l_2} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right]
m2=max([3,4])=4f2=[e3−m2,e4−m2]=[e−1,e0]l2=∑f2=e−1+e0o2=l2f2=e−1+e0[e−1,e0]=[e−1+e0e−1,e−1+e0e0]
合并得到完整的softmax结果:
m
=
m
a
x
(
m
a
x
1
,
m
a
x
2
)
=
4
f
=
[
e
m
1
−
m
f
1
,
e
m
2
−
m
∗
f
2
]
=
[
e
−
3
,
e
−
2
,
e
−
1
,
e
0
]
l
=
e
m
1
−
m
l
1
,
e
m
2
−
m
∗
l
2
=
e
−
3
+
e
−
2
+
e
−
1
+
e
0
o
=
f
l
=
[
e
−
1
,
e
0
]
e
−
1
+
e
0
=
[
e
−
1
e
−
1
+
e
0
,
e
0
e
−
1
+
e
0
]
m = max(max_1,max_2) = 4\\ f = \left[e^{m_1-m}f_1,e^{m_2-m}*f_2\right] = \left[e^{-3},e^{-2},e^{-1},e^0\right]\\ l = e^{m_1-m}l_1,e^{m_2-m}*l_2 = e^{-3}+e^{-2}+e^{-1}+e^0\\ o = \frac{f}{l} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right]
m=max(max1,max2)=4f=[em1−mf1,em2−m∗f2]=[e−3,e−2,e−1,e0]l=em1−ml1,em2−m∗l2=e−3+e−2+e−1+e0o=lf=e−1+e0[e−1,e0]=[e−1+e0e−1,e−1+e0e0]
算法伪代码
备注:这是在在忽略mask和dropout的情况下,简化分析Flash Attention算法的前向计算过程
作用分析:
在Flash Attention的前向计算算法中可以看出,FlashAttention实现在不访问整个输入的情况下计算softmax,实现IO的较大缩减,标准Attention算法由于要计算softmax,而softmax都是按行来计算的,即在和 V \mathbf{V} V做矩阵乘之前,需要让 Q \mathbf{Q} Q、 K \mathbf{K} K 的各个分块完成整一行分块的计算得到Softmax的结果后,再和矩阵 V \mathbf{V} V分块做矩阵乘。而在Flash Attention中,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减。
4.4 后向回传伪代码
将前文的前向计算抽象成如下模型,便于后文的引用
S
=
τ
Q
K
⊤
∈
R
N
×
N
S
masked
=
M
A
S
K
(
S
)
∈
R
N
×
N
P
=
softmax
(
S
masked
)
∈
R
N
×
N
P
dropped
=
dropout
(
P
,
p
drop
)
∈
R
N
×
N
O
=
P
dropped
V
∈
R
N
×
d
\begin{gathered} S=\tau Q K^{\top} \in \mathbb{R}^{N \times N} \\ S^{\text {masked }}=M A S K(S) \in \mathbb{R}^{N \times N} \\ P=\operatorname{softmax}\left(S^{\text {masked }}\right) \in \mathbb{R}^{N \times N} \\ P^{\text {dropped }}=\operatorname{dropout}\left(P, p_{\text {drop }}\right) \in \mathbb{R}^{N \times N} \\ O=P^{\text {dropped }} V \in \mathbb{R}^{N \times d} \end{gathered}
S=τQK⊤∈RN×NSmasked =MASK(S)∈RN×NP=softmax(Smasked )∈RN×NPdropped =dropout(P,pdrop )∈RN×NO=Pdropped V∈RN×d
在标准注意力实现中,后向传递计算
Q
\mathbf{Q}
Q,
K
\mathbf{K}
K,
V
\mathbf{V}
V的梯度时,需要用到中间矩阵
S
∈
R
N
×
N
\mathbf{S}\in\mathbb{R}^{N\times N}
S∈RN×N,
P
∈
R
N
×
N
\mathbf{P}\in\mathbb{R}^{N\times N}
P∈RN×N。Flash Attention没有保存这两个矩阵,而是保存了两个统计量
m
(
x
)
m(x)
m(x),
l
(
x
)
l(x)
l(x),在后向传递时进行重计算。
在反向传递过程中, 需要计算损失函数 ϕ \phi ϕ 对 O \mathbf{O} O, Q \mathbf{Q} Q, K \mathbf{K} K, V \mathbf{V} V 的梯度。在给定 d O ∈ R N × d d \mathbf{O} \in \mathbb{R}^{N \times d} dO∈RN×d 的情况下, 计算梯度 d Q ∈ R N × d d\mathbf{Q}\in \mathbb{R}^{N \times d} dQ∈RN×d, d K ∈ R N × d d\mathbf{K}\in \mathbb{R}^{N \times d} dK∈RN×d, d V ∈ R N × d d\mathbf{V} \in \mathbb{R}^{N \times d} dV∈RN×d 。其中, d O d\mathbf{O} dO, d Q d\mathbf{Q} dQ, d K d\mathbf{K} dK, d V d\mathbf{V} dV 分别表示为 ∂ ϕ ∂ O \frac{\partial \phi}{\partial \mathbf{O}} ∂O∂ϕ, ∂ ϕ ∂ Q \frac{\partial \phi}{\partial \mathbf{Q}} ∂Q∂ϕ, ∂ ϕ ∂ K \frac{\partial \phi}{\partial \mathbf{K}} ∂K∂ϕ, ∂ ϕ ∂ V \frac{\partial \phi}{\partial \mathbf{V}} ∂V∂ϕ
计算 d V d\mathbf{V} dV
梯度
d
V
d\mathbf{V}
dV 是容易计算的。由
O
=
P
V
\mathbf{O}=\mathbf{P} \mathbf{V}
O=PV,基于矩阵求导算法和链式法则, 得到矩阵形式的梯度
d
V
=
P
⊤
d
O
d\mathbf{V}=\mathbf{P}^{\top} d \mathbf{O}
dV=P⊤dO 。在元素形式上,有:
d
v
j
=
∑
i
P
i
j
d
o
i
=
∑
i
e
(
q
i
⊤
k
j
)
L
i
d
o
i
d \mathbf{v}_j=\sum_i \mathbf{P}_{i j} d \mathbf{o}_i=\sum_i \frac{e^{(\mathbf{q}_i^{\top} k_j)}}{L_i} d \mathbf{o}_i
dvj=i∑Pijdoi=i∑Lie(qi⊤kj)doi
之前已经计算好
L
i
L_i
Li,就可以通过反复累加的方式计算得到
d
v
j
d \mathbf{v}_j
dvj 。
计算 d Q d\mathbf{Q} dQ, d K d\mathbf{K} dK
梯度
d
Q
d\mathbf{Q}
dQ,
K
\mathbf{K}
K 的计算是略微复杂的。首先要计算
d
P
d\mathbf{P}
dP,
d
S
d\mathbf{S}
dS 。由
O
=
P
V
\mathbf{O}=\mathbf{P} \mathbf{V}
O=PV,得到矩阵形式的梯度
d
P
=
d
O
V
⊤
d\mathbf{P}=d\mathbf{O} \mathbf{V}^{\top}
dP=dOV⊤ 。在元素形式上,有:
d
P
i
j
=
d
o
i
⊤
v
j
d \mathbf{P}_{i j}=d \mathbf{o}_i^{\top} \mathbf{v}_j
dPij=doi⊤vj
有
P
i
:
=
softmax
(
S
i
:
)
\mathbf{P}_{i:}=\operatorname{softmax}\left(\mathbf{S}_{i:}\right)
Pi:=softmax(Si:) (表示
i
i
i的一整行)。基于
y
=
softmax
(
x
)
y=\operatorname{softmax}(x)
y=softmax(x) 的雅各比矩阵为
diag
(
y
)
−
y
y
⊤
\operatorname{diag}(y)-y y^{\top}
diag(y)−yy⊤ 。可以得到:
d
S
i
:
=
(
diag
(
P
i
:
)
−
P
i
:
P
i
:
⊤
)
d
P
i
:
=
P
i
:
∘
d
P
i
:
−
(
P
i
:
⊤
d
P
i
:
)
P
i
:
d \mathbf{S}_{i:}=\left(\operatorname{diag}\left(\mathbf{P}_{i:}\right)-\mathbf{P}_{i:} P_{i:}^{\top}\right) d \mathbf{P}_{i:}=\mathbf{P}_{i:} \circ d \mathbf{P}_{i:}-\left(P_{i:}^{\top} d \mathbf{P}_{i:}\right) \mathbf{P}_{i:}
dSi:=(diag(Pi:)−Pi:Pi:⊤)dPi:=Pi:∘dPi:−(Pi:⊤dPi:)Pi:
其中 ∘ \circ ∘ 表示逐点相乘。
可以定义:
D
i
=
P
i
:
⊤
d
P
i
:
=
∑
j
e
q
i
⊤
k
j
L
i
d
o
i
⊤
v
j
=
d
o
i
⊤
∑
j
e
q
i
⊤
k
j
L
i
v
j
=
d
o
i
⊤
o
i
D_i=P_{i:}^{\top} d P_{i:}=\sum_j \frac{e^{q_i^{\top} k_j}}{L_i} d o_i^{\top} v_j=d o_i^{\top} \sum_j \frac{e^{q_i^{\top} k_j}}{L_i} v_j=d o_i^{\top} o_i
Di=Pi:⊤dPi:=j∑Lieqi⊤kjdoi⊤vj=doi⊤j∑Lieqi⊤kjvj=doi⊤oi
将该定义代回到上式中, 可以得到:
d
S
i
:
=
P
i
:
∘
d
P
i
:
−
D
i
P
i
:
d S_{i:}=P_{i:} \circ d P_{i:}-D_i P_{i:}
dSi:=Pi:∘dPi:−DiPi:
因此,梯度
d
S
d\mathbf{S}
dS 可以表示为以下形式:
d
S
i
j
=
P
i
j
d
P
i
j
−
D
i
P
i
j
=
P
i
j
(
d
P
i
j
−
D
i
)
d \mathbf{S}_{i j}=\mathbf{P}_{i j} d \mathbf{P}_{i j}-\mathbf{D}_i \mathbf{P}_{i j}=\mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right)
dSij=PijdPij−DiPij=Pij(dPij−Di)
在计算得到
d
P
i
j
d \mathbf{P}_{i j}
dPij,
d
S
i
j
d \mathbf{S}_{i j}
dSij 后, 可以计算
d
Q
d\mathbf{Q}
dQ,
d
K
d\mathbf{K}
dK 。有前向计算公式
S
i
j
=
q
i
⊤
k
j
\mathbf{S}_{i j}=\mathbf{q}_i^{\top} \mathbf{k}_j
Sij=qi⊤kj, 可以得到:
d
q
i
=
∑
j
d
S
i
j
k
j
=
∑
j
P
i
j
(
d
P
i
j
−
D
i
)
k
j
=
∑
j
e
(
q
i
⊤
k
j
)
L
i
(
d
o
i
⊤
v
j
−
D
i
)
k
j
d
k
j
=
∑
i
d
S
i
j
q
i
=
∑
i
P
i
j
(
d
P
i
j
−
D
i
)
q
i
=
∑
i
e
(
q
i
⊤
k
j
)
L
i
(
d
o
i
⊤
v
j
−
D
i
)
q
i
\begin{gathered} d \mathbf{q}_i=\sum_j d \mathbf{S}_{i j} \mathbf{k}_j=\sum_j \mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) \mathbf{k}_j=\sum_j \frac{e^{(\mathbf{q}_i^{\top} \mathbf{k}_j)}}{\mathbf{L}_i}\left(d \mathbf{o}_i^{\top} \mathbf{v}_j-\mathbf{D}_i\right) \mathbf{k}_j \\ d \mathbf{k}_j=\sum_i d \mathbf{S}_{i j} \mathbf{q}_i=\sum_i \mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) \mathbf{q}_i=\sum_i \frac{e^{(\mathbf{q}_i^{\top} \mathbf{k}_j)}}{\mathbf{L}_i}\left(d \mathbf{o}_i^{\top} \mathbf{v}_j-\mathbf{D}_i\right) \mathbf{q}_i \end{gathered}
dqi=j∑dSijkj=j∑Pij(dPij−Di)kj=j∑Lie(qi⊤kj)(doi⊤vj−Di)kjdkj=i∑dSijqi=i∑Pij(dPij−Di)qi=i∑Lie(qi⊤kj)(doi⊤vj−Di)qi
与前向计算类似,在计算得到 L i \mathbf{L}_i Li 后, 就可以通过反复累加的方式计算得到 d q i d \mathbf{q}_i dqi, d k j d \mathbf{k}_j dkj, d v j d \mathbf{v}_j dvj 。避免了实例化矩阵 P \mathbf{P} P, S \mathbf{S} S,节省了显存,后向传递的显存复杂度为 O ( N ) O(N) O(N) 。
作用分析
对比标准Attention算法的实现过程中,其需要将计算中的 S \mathbf{S} S、 P \mathbf{P} P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型;
Flash Attention算法中,其并没有将 S \mathbf{S} S、 P \mathbf{P} P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。即使由于重新计算导致 FLOPS 增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。
Flash Attention实现了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。
5 总结
- FA尽可能避免从HBM中读取和写入注意力矩阵,做到了:
- 在不访问整个输入的情况下计算softmax函数的IO缩减;
- 在后向传播中不存储中间注意力矩阵
- 通过减少GPU内存读取/写入,FlashAttention的运行速度比PyTorch标准注意力快 2-4 倍,所需内存减少5-20倍。
6 参考文献
[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
论文分享:新型注意力算法FlashAttention - 知乎
FlashAttention:加速计算,节省显存, IO感知的精确注意力 - 知乎