FlashAttention V1 学习笔记

Flash Attention 是一种新型的注意力机制,旨在解决传统 Transformer 模型在处理长序列数据时面临的计算和内存效率问题。它通过一系列创新的技术优化,显著提高了注意力机制的计算速度和内存使用效率,同时保持了精确的结果,不依赖于近似计算。

背景&动机

当输入序列较长时,Transformer 的计算过程缓慢且耗费内存,这是因为 self-attention 的时间和内存复杂度会随着序列长度的增加而呈二次增长。标准 Attention 计算的中间结果 S, P(见下文)通常需要通过 HBM 进行存取,两者所需内存空间复杂度为 O ( N 2 ) O(N^2) O(N2)

s e l f − a t t e n t i o n ( x ) = s o f t m a x ( Q K T d ) ⋅ V self-attention(x) = softmax(\frac{Q K^T}{\sqrt{d}})\cdot V selfattention(x)=softmax(d QKT)V

S = Q K T d , P = s o f t m a x ( S ) S = \frac{Q K^T}{\sqrt{d}}, \quad P = softmax(S) S=d QKT,P=softmax(S)

在不考虑 batch size 的前提下,令 N 表示序列长度,d 表示注意力头维度(隐藏层维度 / 注意力头数)。那么,Q、K 和 V 矩阵的 shape 为 [N, d],S 和 P 的 shape 为 [N, N]。

在这里插入图片描述
标准 Attention 计算操作的算法逻辑如上图所示。

  • 首先,需要从 HBM 中加载 Q 和 K 矩阵,时间复杂度 O ( N d ) O(Nd) O(Nd)
  • 然后,计算 S = Q K T S = QK^T S=QKT,并将 S 写回到 HBM,时间复杂度 O ( N 2 ) O(N^2) O(N2)
  • 接着,从 HBM 中加载 S,计算 P = s o f t m a x ( S ) P = softmax(S) P=softmax(S),并将 P 写回到 HBM,时间复杂度 O ( 2 N 2 ) O(2N^2) O(2N2)
  • 最后,从 HBM 中加载 P 和 V 矩阵,计算 O = P V O = PV O=PV,并将 O 写回到 HBM,时间复杂度 O ( N d ) O(Nd) O(Nd)

注意事项:时间复杂度衡量的是对 HBM 访问所需的耗时量级。

因此,标准 Attention 计算对 HBM 访问的时间复杂度为 O ( N d + N 2 ) O(Nd + N^2) O(Nd+N2)。由于部分或大部分操作是 memory-bound(例如,softmax 操作,需要加载矩阵 S 后才可以进行计算),大量的内存访问转化为缓慢的 wall-clock 时间。如果对注意力矩阵进行其他元素运算,例如对 S 进行 mask 或对 P 进行 dropout,则会加剧这一问题。

前置知识

了解这些内容后会帮助理解 Flash Attention 设计的动机以及如何可以加速。

硬件性能

GPU 内存层次结构:现代 GPU 具有多种不同大小和速度的内存形式。例如,A100 GPU 具有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0 TB/s,共有 108 个流多处理器(Stream MultiProcessor),总计 192KB 的 on-chip SRAM,带宽约为 19TB/s。SRAM 比 HBM 快一个数量级,但大小比 HBM 小多个数量级。如下图所示:

在这里插入图片描述
执行过程:GPU 有大量线程来执行操作(称为 kernel)。每个 kernel 加载输入数据到寄存器和 SRAM 进行计算,然后将输出写回到 HBM。

关于 GPU 更多的内容请参考:

  • 《理解 GPU 的底层架构》
  • 《GPU 内存(显存)的理解与基本使用》

FLOPs & MAC

FLOPs 定义了模型核心计算的密集程度,因此模型的计算量 FLOPs 与模型的计算速度有很大关系。学术界有很多使用各种技巧来降低 Transformer FLOPs 的方法,通常将由这些方法改进得到的模型称为 Efficient Transformer,但大多数只关注 FLOPs。

Flash Attention 的作者们发现,Efficient Transformer 虽然能够有效降低模型的 FLOPs,但它们的计算速度并没有显著降低。导致该现象的根本原因是模型的计算速度除了与 FLOPs 有很大关系外,同时与 MAC(Memory Access Cost,存储访问开销)有关。尤其是当计算本身已经很高效的前提下,MAC 的开销更加不能忽略,其开销主要来自两个方面:

  • 从存储中读取数据。
  • 向存储中写入数据。

与 CPU 的情况类似,当需要计算时,将数据从内存中读取并由计算单元进行计算操作。计算完成后,再写回到内存中。

Compute-bound & Memory-bound

在计算机科学中,特别是在性能优化领域,“compute-bound”和“memory-bound”是两个描述程序性能瓶颈的术语。它们指出了程序执行速度受限的主要因素:是处理器的计算能力,还是内存的访问速度。

根据计算的密集程度,可以将操作(operator)分为两类:

  • Compute-bound(计算受限):。一个程序或系统的性能受限于处理器速度,这意味着程序执行的瓶颈在于 CPU 或 GPU 的计算能力,而非数据的输入输出速度或内存访问速度。对于 compute-bound 的程序,增加更多的处理器核心、使用更快的处理器或优化代码中的计算部分可以提高性能。在深度学习中,一个典型的 compute-bound 情况是当模型包含大量的浮点运算,如矩阵乘法和卷积操作。如果处理器无法快速完成这些运算,那么程序的执行就会受到限制。
  • Memory-bound(内存受限):一个程序或系统的性能受限于内存访问速度。在这种情况下,处理器花费大量时间等待数据从内存中读取或写入,而不是执行计算。内存带宽和延迟成为性能瓶颈的主要因素。对于 memory-bound 的程序,提高内存的速度、减少内存访问次数或者优化数据的存储和访问模式可以提高性能。在处理大数据集或者具有复杂数据结构的应用程序中,内存访问模式对性能影响很大,这些程序往往是 memory-bound 的。

Kernel fusion(核函数融合)

Kernel fusion(核函数融合)是一种优化技术,旨在提高计算效率和减少内存访问开销。它主要针对内存密集型操作(memory-bound operations)。在深度学习中,许多操作通常以核函数的形式执行,每个核函数代表一个特定的计算操作,例如矩阵乘法、激活函数应用或者 softmax 计算等等。

核函数融合通过将多个核函数合并或融合成一个更大的核函数来优化计算流程,这有助于减少内存访问的次数。通常情况下,多个操作处理相同的输入数据,如果分开执行,会导致重复的数据加载和写入操作,增加内存访问的开销。通过融合核函数,可以将这些操作合并为一个单一的计算任务,使得输入数据只需加载一次,减少了数据传输的次数,从而提高计算效率。

然而,在某些情况下,即使进行了核函数融合,仍然需要将中间结果写回内存以供反向传播或其他操作使用,这可能会限制融合操作的效果。

Flash Attention V1

Flash Attention V1 考虑如何以较少的 HBM 读写次数计算精确注意力,并且无需为反向传播存储大型中间矩阵(中间激活)。这样既能节省内存,又能以 wall-clock 时间加快速度。

给定 HBM 中的输入矩阵 Q , K , V ∈ R N × d Q, K, V \in \R^{N \times d} Q,K,VRN×d,计算注意力输出 O ∈ R N × d O \in \R^{N \times d} ORN×d,并将其写入到 HBM。但标准 Attention 计算过程存在大量的 HBM 访问 O ( N 2 ) O(N^2) O(N2)。我们的目标是减少 HBM 的访问次数(达到 N 的二次方以下)。Flash Attention V1 采用了平铺和重新计算技术来实现这一目标,Algorithm 1 对此进行了描述,下文将介绍这两种技术,然后再介绍 Flash Attention V1 的计算过程。

在这里插入图片描述
主要思路是将输入的 Q、K 和 V 矩阵分割成更小的块,从相对 SRAM 更慢的 HBM 加载到 SRAM,依次来减少在 HBM 上的读写次数。然后计算这些块的注意力输出,并用正确的归一化因子对其进行缩放。最后将每个块的输出相加。

问题:那么是如何减少在 HBM 上的读写次数呢?

结合对 GPU 内部存储的理解,标准的 Attention 计算公式中,首先从 HBM(片下存储)将 Q 和 K 矩阵加载到片上存储(SRAM),计算 S = Q K T S=QK^T S=QKT,并将 S 矩阵写回到 HBM。接着,又将 S 矩阵从 HBM 加载到片上存储计算 softmax。有没有什么办法可以不写回 S 矩阵,直接计算 softmax 呢?这样就可以减少写回的性能开销。

片上存储的空间有限(A100 L2 Cache 最多能够设置 40MB 的持续化数据),无法完全存下 S 矩阵,一个朴素的想法是将 Q 和 K 矩阵分块,然后存下分块后的 S’ 矩阵。但新的问题诞生了,softmax 操作的分母是求和项,需要完整的 S 矩阵。于是,考虑有没有什么办法可以动态地更新 softmax 输出值?平铺技术!

平铺技术(Tiling)

FlashAttention 的核心是对标准的 Attention 操作进行分块计算。对于矩阵乘法来说,可以直接通过分块来达到分块计算的目的,但 self-attention 中存在 softmax 操作,而 softmax 函数的分母项包含与所有元素相关的求和,所以真正难点在于对 softmax 的分块计算。

稳定版 softmax 计算

softmax 的计算公式中含有指数项,当指数项 e x i e^{x_i} exi中的 x i x_i xi较大时, e x i e^{x_i} exi的值也容易很大,从而在计算中出现溢出。为了避免溢出的问题,大多数深度学习框架中都使用了 softmax 的稳定版本。仍以向量 x 为例,稳定版的 softmax 的计算如下:

m ( x ) : = max ⁡ i x i , f ( x ) : = [ e x 1 − m ( x ) … e x B − m ( x ) ] , ℓ ( x ) : = ∑ i f ( x ) i , softmax ⁡ ( x ) : = f ( x ) ℓ ( x ) m(x):=\max _{i} \quad x_{i}, \quad f(x):=\left[\begin{array}{lll}e^{x_{1}-m(x)} & \ldots & e^{x_{B}-m(x)}\end{array}\right], \\ \ell(x):=\sum_{i} f(x)_{i}, \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} m(x):=imaxxi,f(x):=[ex1m(x)exBm(x)],(x):=if(x)i,softmax(x):=(x)f(x)

  • 计算向量 x 中的最大值: m ( x ) = m a x ( [ x 1 , x 2 , … , x B ] ) m(x) = max([x_1, x_2, \ldots, x_B]) m(x)=max([x1,x2,,xB])
  • 将向量 x - m(x)后,再计算 e x i e^{x_i} exi f ( x ) = [ e x 1 − m ( x ) , … , e x B − m ( x ) ] f(x) = [e^{x_1 -m(x)}, \ldots, e^{x_B - m(x)}] f(x)=[ex1m(x),,exBm(x)]
  • 计算 softmax 分母中的求和项: l ( x ) = ∑ i f ( x ) i l(x) = \sum_i f(x)_i l(x)=if(x)i
  • 最后计算 softmax 结果: s o f t m a x ( x ) = f ( x ) l ( x ) softmax(x) = \frac{f(x)}{l(x)} softmax(x)=l(x)f(x)

其中,f(x) 是向量,分母 l(x) 是标量,所以这里的除法是逐元素相除

直觉上,softmax 操作难以分块计算的主要原因是它的分母 l(x) 依赖于输入向量 x 中的每个值。

softmax 动态更新原理

按照论文中的介绍,我们假设输入向量 x 可以切分为两块 [ x ( 1 ) , x ( 2 ) ] [x^{(1)}, x^{(2)}] [x(1),x(2)]。在分块计算中,首先处理 x ( 1 ) x^{(1)} x(1)再处理 x ( 2 ) x^{(2)} x(2)。我们按照上述的稳定版 softmax 对子向量 x ( 1 ) x^{(1)} x(1)计算“局部 softmax”。

  • 计算子向量 x ( 1 ) x^{(1)} x(1)中的最大值: m ( x ( 1 ) ) m(x^{(1)}) m(x(1))
  • 将子向量 x ( 1 ) − m ( x ( 1 ) ) x^{(1)} - m(x^{(1)}) x(1)m(x(1))后,在计算 f ( x ( 1 ) ) = [ e x 1 ( 1 ) − m ( x ( 1 ) ) , … , e x B / 2 ( 1 ) − m ( x ( 1 ) ) ] f(x^{(1)}) = [e^{x_1^{(1)} - m(x^{(1)})}, \ldots, e^{x_{B/2}^{(1)} - m(x^{(1)})}] f(x(1))=[ex1(1)m(x(1)),,exB/2(1)m(x(1))]
  • 计算 softmax 分母中的求和项: l ( x ( 1 ) ) = ∑ i f ( x ( 1 ) ) i l(x^{(1)}) = \sum_i f(x^{(1)})_i l(x(1))=if(x(1))i
  • 最后计算 softmax 结果: s o f t m a x ( x ( 1 ) ) = f ( x ( 1 ) ) l ( x ( 1 ) ) softmax(x^{(1)}) = \frac{f(x^{(1)})}{l(x^{(1)})} softmax(x(1))=l(x(1))f(x(1))

显然上述计算得到的 softmax 结果并不是子向量 x ( 1 ) x^{(1)} x(1)的最终结果。首先,减去的最大值是整个向量 x 的最大值,而不是子向量 x ( 1 ) x^{(1)} x(1)的最大值。另外,求和项是整个向量 x 的求和项,而不仅仅是子向量中所有元素的求和。正因该计算得到的 softmax 结果不是最终结果,所以称其为“局部的”。

那么,在计算最后一个分块时,我们是可以拿到整个向量的最大值(在遍历每个区块时,记录下最大值 M),并且也可以计算得到整个向量的求和项(同样在遍历每个区块时,累加这些区块的求和项 l a l l l_{all} lall)。

M n e w = m a x ( [ M , m ( x ( 2 ) ) ] ) , l a l l = e M − M n e w l a l l + e m ( x ( 2 ) ) − M n e w l ( x ( 2 ) ) M_{new} = max([M, m(x^{(2)})]), \quad l_{all} = e^{M - M_{new}} l_{all} + e^{m(x^{(2)}) - M_{new}} l(x^{(2)}) Mnew=max([M,m(x(2))]),lall=eMMnewlall+em(x(2))Mnewl(x(2))

那么如何将 l ( x ( 2 ) ) l(x^{(2)}) l(x(2))从“局部”更新成“全局”呢?按照计算公式将 m ( x ( 2 ) ) m(x^{(2)}) m(x(2))替换成全局的最大值 M n e w M_{new} Mnew。简而言之,当需要把某个 l 更新为“全局”时,只要将其乘以 e m − M e^{m - M} emM,其中 m 表示当前 l 的最大值,M 表示全局最大值。在最后一个分块将 M 和 l a l l l_{all} lall分别更新至“全局”后,我们就能直接更新 softmax 值。

在这动态更新的过程中,我们用到了如下变量:

  • x ( 2 ) x^{(2)} x(2)的局部 softmax 值.
  • x ( 2 ) x^{(2)} x(2)的局部求和项 l ( x ( 2 ) ) l(x^{(2)}) l(x(2))
  • x ( 2 ) x^{(2)} x(2)的局部最大值 m ( x ( 2 ) ) m(x^{(2)}) m(x(2))
  • 全局最大值 M。
  • 全局求和项 l a l l l_{all} lall

更新的过程中不需要用到 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2)。然而,再反向将 x ( 1 ) x^{(1)} x(1)从“局部”更新成“全局”。这就是 Flash Attention 中对 softmax 峙进行动态更新的本质。实际上一个增量计算的过程:首先计算第一个分块的局部 softmax 值,然后存储该局部 softmax 值、当前的全局最大值和全局求和项。当处理完最后一个分块后,得到真正的全局最大值和全局求和项,再反过来更新所有的分块。

论文中的原始计算公式如下所示:

m ( x ) = m ( [ x ( 1 ) x ( 2 ) ] ) = max ⁡ ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) ,   f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] ℓ ( x ) = ℓ ( [ x ( 1 ) x ( 2 ) ] ) = e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) , softmax ⁡ ( x ) = f ( x ) ℓ ( x ) . \begin{array}{l}m(x)=m\left(\left[x^{(1)} x^{(2)}\right]\right)=\max \left(m\left(x^{(1)}\right), m\left(x^{(2)}\right)\right), \ f(x)=\left[e^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right) \quad e^{m\left(x^{(2)}\right)-m(x)} f\left(x^{(2)}\right)\right] \\ \ell(x)=\ell\left(\left[x^{(1)} x^{(2)}\right]\right)=e^{m\left(x^{(1)}\right)-m(x)} \ell\left(x^{(1)}\right)+e^{m\left(x^{(2)}\right)-m(x)} \ell\left(x^{(2)}\right), \quad \operatorname{softmax}(x)=\frac{f(x)}{\ell(x)} .\end{array} m(x)=m([x(1)x(2)])=max(m(x(1)),m(x(2))), f(x)=[em(x(1))m(x)f(x(1))em(x(2))m(x)f(x(2))](x)=([x(1)x(2)])=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2)),softmax(x)=(x)f(x).

重新计算(ReComputation)

在标准的注意力机制中,前向传播过程会存储中间激活(包括注意力矩阵以及中间激活值),从而用于反向传播的梯度计算。FlashAttention 通过使用重新计算来避免存储大量的中间激活。

核心思想:在前向传播期间,将注意力输出矩阵(O)和 softmax 归一化统计数据(M 和 ℓ \ell )存储起来。在反向传播阶段,通过使用这些值和在 SRAM 中的输入块重新计算注意力矩阵 S 和 P,而不必在 HBM 中存储中间激活值(S 和 P)。

这种方式类似于选择性梯度检查点技术,梯度检查点技术通常涉及在计算图中选择性地存储中间值或梯度信息,以便在反向传播过程中计算梯度。然而,梯度检查点技术通常需要在内存中保存一些计算中间状态,这会增加内存占用,并且在某些情况下可能会影响计算的速度。

Flash Attention 的重新计算不需要牺牲速度以换取内存。尽管重新计算涉及更多的浮点运算,但由于减少了 HBM 读写次数,反向传播的速度反而得到提升。

计算过程

以算法 1 的伪代码为例,Flash Attention V1 的输入包括:

  • 存储于 HBM 中的 Q , K , V ∈ R N × d Q, K, V \in \R^{N \times d} Q,K,VRN×d
  • SRAM 的大小 M。

问题:如何进行分块,以及块的大小该如何设定呢?

博客 - FlashAttention图解(如何加速Attention)中以 GPT2 和 A100 进行举例,A100 的 SRAM 大小为 192KB = 196608B,对应算法 1 中的 M,GPT2 中 N =1024,d = 64。Q、K 和 V 矩阵的 shape 为 1024 x 64,中间结果 S、P 的 shape 为 1024 x 1024。

初始化部分

  • 第 1 行:根据 SRAM 的大小 M和注意力头维度 d 计算合适的分块大小, B c = ⌈ M 4 d ⌉ = ⌈ 196608 / ( 4 × 64 ) ⌉ = 768 ; B r = m i n ( B c , d ) = m i n ( 768 , 64 ) = 64 B_c = \lceil \frac{M}{4d} \rceil = \lceil 196608 / (4 \times 64) \rceil = 768; \quad B_r = min(B_c, d) = min(768, 64) = 64 Bc=4dM=196608/(4×64)⌉=768;Br=min(Bc,d)=min(768,64)=64
  • 第 2 行:初始化平铺技术中用来计算动态 softmax 的辅助变量 l = ( 0 ) N ∈ R N , m = ( − ∞ ) N ∈ R N l = (0)_N \in \R^N, m = (-\infin)_N \in \R^N l=(0)NRN,m=()NRN,存放在 HBM。N 维向量 l 用来记录每个位置的求和项,N 维向量 m 用来记录每个位置的最大值。同时,也初始化输出矩阵 O = ( 0 ) N × d ∈ R N × d O = (0)_{N \times d} \in \R^{N \times d} O=(0)N×dRN×d
  • 第 3 行:计算 T c = ⌈ 1024 / 768 ⌉ = 2 ; T r = ⌈ 1024 / 64 ⌉ = 16 T_c = \lceil 1024 / 768 \rceil = 2; \quad T_r = \lceil 1024 / 64 \rceil = 16 Tc=1024/768=2;Tr=1024/64=16。可以理解为将完整的 Q、K 和 V 矩阵加载到 SRAM 的次数。然后将 Q 矩阵按照 T r T_r Tr拆分成 Q 1 , … , Q T r Q_1, \ldots, Q_{T_r} Q1,,QTr个子矩阵(维度为 B r × d B_r \times d Br×d),将 K 和 V 矩阵按照 T c T_c Tc拆分成 K 1 , … , K T c K_1, \ldots, K_{T_c} K1,,KTc V 1 , … , V T c V_1, \ldots, V_{T_c} V1,,VTc个子矩阵(维度为 B c × d B_c \times d Bc×d)。
  • 第 4 行:将 O 矩阵按照 T r T_r Tr拆分成 O 1 , … , O T r O_1, \ldots, O_{T_r} O1,,OTr(维度为 B r × d B_r \times d Br×d),将l 和 m 按照 T r T_r Tr拆分成 l 1 , … , l T r l_1, \ldots, l_{T_r} l1,,lTr m 1 , … , m T r m_1, \ldots, m_{T_r} m1,,mTr向量(维度为 B r B_r Br)。

可以理解为将矩阵 Q 和 O 沿着行方向切分成 T r T_r Tr块,将向量 l 和 m 分为 T r T_r Tr块。将矩阵 K 和 V 沿着列方向分为 T c T_c Tc块。

在这里插入图片描述
动态 softmax 计算过程

  • 第 5 至第 6 行:每次外循环 j(一共循环 T c = 2 T_c = 2 Tc=2次),将 K j , V j K_j, V_j Kj,Vj从 HBM 加载到 on-chip SRAM 上,这两个分块矩阵的维度为 B c × d = 768 × 64 B_c \times d = 768 \times 64 Bc×d=768×64
  • 第 7 至第 14 行:每次内循环 i(一共循环 T r = 16 T_r = 16 Tr=16次)。
    • 第 8 行:从 HBM 加载 Q i , O i , ζ i , m i Q_i, O_i, \zeta_i, mi Qi,Oi,ζi,mi到 on-chip SRAM。
    • 第 9 行:使用第 8 行加载的 Q i Q_i Qi以及外循环中已经加载的 K j K_j Kj,计算分块的注意力分数 S i j = Q i K j T ∈ R B r × B c S_{ij} = Q_iK_j^T \in \R^{B_r \times B_c} Sij=QiKjTRBr×Bc
    • 第 10 行:对分块的注意力分数 S i j S_{ij} Sij,计算它当前分块的局部最大值、局部求和项以及局部 softmax 值。
    • 第 11 行:在 on-chip SRAM 上计算新的局部最大值和局部求和项。
    • 第 12 行:计算当前分块的输出矩阵 O i O_i Oi,并写回到 HBM。并且在该步骤可以多行一同计算,每一个小分块 s i j s_{ij} sij B r B_r Br行,但行与行之间的数据不会有任何交互,真正分块的意义是在列上(每一行表示当前位置与其他位置的相似程度,因此行与行之间没有交互)。
    • 第 13 行:将新的局部最大值和局部求和项覆盖原值,并写回到 HBM。

论文中的定理

定理 1:FlashAttention 的 FLOPs 为 O ( N 2 d ) O(N^2d) O(N2d),除了 input 和 output 外,额外需要的内存为O(N)。

影响 FLOPs 的主要是矩阵乘法,在一次循环中:

  • 算法 1 第 9 行:计算 Q i K j T ∈ R B r × B c Q_iK_j^T \in \R^{B_r \times B_c} QiKjTRBr×Bc,由于 Q i ∈ R B r × d , K j ∈ R B c × d Q_i \in \R^{B_r \times d}, K_j \in \R^{B_c \times d} QiRBr×d,KjRBc×d,因此一次计算需要的 FLOPs 为 O ( B r B c d ) O(B_rB_cd) O(BrBcd)
  • 算法 1 第 12 行:计算 P ~ i j V j ∈ R B r × d \tilde{P}_{ij}V_j \in \R^{B_r \times d} P~ijVjRBr×d,由于 P ~ i j ∈ R B r × B c , V j ∈ R B c × d \tilde{P}_{ij} \in \R^{B_r \times B_c}, V_j \in \R^{B_c \times d} P~ijRBr×Bc,VjRBc×d,因此一次计算需要的 FLOPs 为 O ( B r B c d ) O(B_rB_cd) O(BrBcd)

上述计算循环的总次数为 T c T r = [ N B c ] [ N B r ] T_cT_r = [\frac{N}{B_c}][\frac{N}{B_r}] TcTr=[BcN][BrN],因此总的 FLOPs 为:

O ( N 2 B c B r B r B c d ) = O ( N 2 d ) O(\frac{N^2}{B_cB_r}B_rB_cd) = O(N^2d) O(BcBrN2BrBcd)=O(N2d)

定理 2:如果 SRAM 的大小 M 满足 d ≤ M ≤ N d d \leq M \leq Nd dMNd。标准 Attention 对 HBM 访问的次数为 Ω ( N d + N 2 ) \Omega(Nd + N^2) Ω(Nd+N2),而 FlashAttention 对 HBM 访问的次数为 O ( N 2 d 2 M − 1 ) O(N^2d^2M^{-1}) O(N2d2M1)

需要从 HBM 读取的数据有:

  • 算法 1 第 6 行:每次循环读取的 K j , V j K_j,V_j Kj,Vj的空间复杂度都为 Θ ( M ) \Theta(M) Θ(M),总复杂度为 Θ ( N d ) \Theta(Nd) Θ(Nd)
  • 算法 1 第 8 行:每次循环读取的 Q i , O i Q_i, O_i Qi,Oi的空间复杂度都为 Θ ( N d ) \Theta(Nd) Θ(Nd),循环次数为 T c = ⌈ N B c ⌉ T_c = \lceil \frac{N}{B_c} \rceil Tc=BcN,总复杂度为 Θ ( N d M ) \Theta(\frac{Nd}{M}) Θ(MNd)

因此,FlashAttention 对 HBM 总访问次数的复杂度为:

Θ ( N d + N d T c ) = Θ ( N d T c ) = Θ ( N 2 d 2 M − 1 ) \Theta(Nd + NdT_c) = \Theta(NdT_c) = \Theta(N^2d^2M^{-1}) Θ(Nd+NdTc)=Θ(NdTc)=Θ(N2d2M1)

当 M 越接近 Nd 时,FlashAttention 的总复杂度就近似 Θ ( N d ) \Theta(Nd) Θ(Nd),远比标准 Attention 快。并且 A100 显卡的 SRAM 大小为 192KB,远大于 d,因此 FlashAttention 的总复杂度要低于标准 Attention。

在这里插入图片描述

Q & A 相关

核函数融合在 Flash Attention 中的作用是什么?

在 Flash Attention 中,核函数融合的作用是将多个操作融合到一个 CUDA 核函数中执行。这意味着在 Flash Attention 算法中,输入从 HBM 加载到内存中,然后在 GPU 上执行所有计算步骤(矩阵乘法、softmax 等),最终将结果写回 HBM。核函数融合避免了反复读取和写入 HBM 的开销,提高效率。

参考资料

  • 论文地址:https://arxiv.org/pdf/2205.14135.pdf
  • FlashAttention V1论文粗读
  • FlashAttention 的速度优化原理是怎样的?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/534102.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

超好用的iframe的postMessage穿参

前言 ❝ 跨域,简单来说是指不同域之间相互请求资源,例如AJAX请求,浏览器根据同源策略对响应结果进行拦截,这是浏览器对JavaScript实施的安全限制。所谓同源是指相同的域名、协议和端口,只要其中一项不同就为跨域 ❞ 背…

C++11异常:到底是怎么个异常

目录​​​​​​​ 一、C/C如何处理错误 1.1C语言传统的处理错误的方式 1.2C异常概念 二、异常的使用 2.1异常的抛出和捕获 2.2try/catch的使用 2.3异常安全 2.4异常的重新抛出 2.5异常的规范 三、服务器开发中异常体系的模拟 一、C/C如何处理错误 1.1C语言传统的处…

SEO关键词布局时如何查找用户爱搜索的关键词?

在关键词布局中,确定完核心词后,就要考虑在网站关键词扩展时,找到用户爱搜索的词,只有在网站页面关键词布局时,布局用户爱搜索的词,才能够使用户在搜索时网站的页面能够有机会出现在用户的搜索结果页&#…

蓝桥杯算法题:栈(Stack)

这道题考的是递推动态规划,可能不是很难,不过这是自己第一次靠自己想出状态转移方程,所以纪念一下: 要做这些题目,首先要把题目中会出现什么状态给找出来,然后想想他们的状态可以通过什么操作转移&#xf…

个人成长秘籍:参加六西格玛绿带培训的好处

在当今竞争激烈的商业环境中,追求卓越与持续改进已成为企业和个人成功的关键。六西格玛绿带培训,作为一种全面提升管理技能和工作效率的培训课程,不仅帮助企业优化流程、提高质量和效率,也为个人职业发展开辟了一条光明大道。张驰…

cog predict docker unknown flag: --file

如图: 使用cog predict -i image“link-to-image” 出现docker unknown flag: --file的问题。 解决方法(对我可行):切换cog版本。 这个是我一开始的cog安装命令(大概是下的最新版?)&#xff1…

cannal的使用

搭建MySQL 安装canal 1.新建文件夹logs, 新建文件canal.properties instance.properties docker.compose.yml instance.properties ################################################# ## mysql serverId , v1.0.26 will autoGen # canal.instance.mysql.slaveId0# enable g…

【话题:工作生活】2022年工作总结--疫情下的上海,疫情中的我。

现在是阳历2023年11月27日星期一,我再次开始撰写自己的年终工作总结。希望再过1、2个月,这份年终总结能够出炉,与大家相遇。 给自己定个小目标,年终的工作生活总结坚持写10年。我2017年毕业,之后就开始写每年的年终总结…

MathJax的基本使用

一、引言 MathJax引擎是一个开源的JavaScript库,它允许Web开发者在网页中嵌入高质量的数学公式。通过利用Web的最新技术,MathJax引擎可以解析LaTeX、MathML和AsciiMath等数学标记语言,并将其渲染为可视化的数学公式,这些公式可以…

【智能制造-1】涂胶解决方案

平面或立体转角处的等距点胶,既是技术上的难点也是实现上的痛点。 如何更好地保证拐角点胶的均匀性? 1、位置同步输出算法(PSO):可以在点胶阀设定频率不变的情况下实现恒速等距点胶,完美解决拐角堆胶问题…

jsonpath在线解析器网址

jsonpath在线解析器网址:https://jsonpath.com/

梦想CAD 在线编辑软件

前言 有用户集成网页CAD之后,需要提取图纸的各种信息和数据,下面我们讲一下Web版CAD如何在前端直接提取修改和转换后的图纸信息,没有集成过在线CAD的小伙伴可以先看一下快速入门(https://help.mxdraw.com/?pid32) 在…

Vue2.x实现商城购物车

1.实现购物车页面 在页面中显示购物车中的商品信息,并能进行数量增减及商品删除操作,购物车中金额也随商品数量的变化而变化 2.创建cart.html页面 创建cart.html页面,在其中创建Vue实例,实例中首先准备一些商品信息以供显示&a…

逆向入门:为CTF国赛而战day05day06

用的汉化版的 昨天做了一道题目,然后下了那个apkide改之理,就没了 今天再来一题。 我发现:ascii表要好好学。这里#号是35就被写到题目里去了。 CTF reverse 不一样的flag_ctf reverse flag.bin-CSDN博客

云原生技术:开启你的数字王国

在科技领域的飞速进步中,云计算已经成为了现代企业和个人不可或缺的技术。在这股云计算的热潮中,"云原生"这一概念正逐步成为焦点。云原生的话题越来越普及,无论是在日常生活中还是在专业工作场合,这个术语都频繁出现。…

1. VirtualBox安装CentOS

安装 VirtualBox 地址:https://www.virtualbox.org/wiki/Downloads 版本: 6.1和7.0+版本都可以 安装: windows上安装需要admin权限,右键菜单选中 “Run as administrator” 安装 CentOS 6.10 地址:https://vault.centos.org/6.10/isos/x86_64/ 版本: 如果不需要GUI,选择…

无人新零售引领的创新浪潮

无人新零售引领的创新浪潮 在数字化时代加速演进的背景下,无人新零售作为商业领域的一股新兴力量,正以其独特的高效性和便捷性重塑着传统的购物模式,开辟了一条充满创新潜力的发展道路。 依托人脸识别、物联网等尖端技术,无人新…

8.排序(直接插入排序、希尔排序、选择排序、堆排序、冒泡排序、快速排序、归并排序)的模拟实现

1.排序的概念及其运用 1.1排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性:假定在待排序的记录序列中,存在多个具有相同的关键字的记录…

node.js-入门

定义 Node.js是一个跨平台Javascript运行环境,使开发者可以搭建服务器端的Javascript应用程序 作用:使用Node.js编写服务器端程序 1)编写数据接口,提供网页资源浏览功能等 2)前端工程化:为后续学习Vue和…

DSP笔记8-通用GPIO

电源类 AD引脚类 系统相关JTAG 时钟 GPIO (general purpose input output)复用, 复用,I/O引脚,外设的功能引脚, 88个GPIO引脚,通用的输入输出口,功能复用的。 GPIO特点 输入电平与TTL电平兼容。>2.0V…