文章:FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness
原始Attention 计算使用gpu存储标准流程
涉及两个gpu存储器:
1)SRAM(static Random Access Memory):静态随机存取存储器
2)HBM(High Bandwidth Memory):高带宽存储器
具体流程
Q,K,V 初始存储在HBM中,S=Q(K的转置),P= softmax(S),O = PV
1)将Q,K从HBM取到SRAM当中,计算S,将S 放到HBM当中。
2)将S从HBM取到SRAM当中,计算P,将P 放到HBM当中。
3P)将S从HBM取到SRAM当中,计算O,将O 放到HBM当中。
FlashAttention旨在避免从 HBM(High Bandwidth Memory)中读取和写入注意力矩阵,这需要做到:
- 目标一:在不访问整个输入的情况下计算softmax函数的缩减;
- 目标二:在后向传播中不能存储中间注意力矩阵.
优化1:
FlashAttention如何实现在不访问整个输入的情况计算softmax大的缩减,标准Attention算法由于要计算softmax,而softmax都是按行来计算的,即在和V做矩阵乘之前,需要让 Q、K 的各个分块完成整一行分块的计算得到Softmax的结果后,再和矩阵V分块做矩阵乘。而在Flash Attention中,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减。
优化2:
在后向传播中不存储中间注意力矩阵,以Flash Attention所提供的算法为例,通过对比标准Attention算法在实现过程中,标准Attention算法的实现需要将计算过程中的S、P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型,因此Flash Attention就提出了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。
在Flash Attention的前向计算算法中我们可以看出,Flash Attention算法并没有将S、P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。即使由于重新计算导致 FLOPS 增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。
1. 存储softmax归一化的系数的含义
在Flash Attention中,"只存储softmax归一化的系数"是指在进行反向传播时,不直接存储整个注意力矩阵,而是只存储用于softmax归一化的系数。这样做的目的是为了减少存储需求,从而节省内存。
在传统的注意力机制中,我们需要计算并存储一个N^2的注意力矩阵,其中N是输入序列的长度。这个矩阵存储了序列中每个元素对其他所有元素的注意力权重。然而,这种方法的存储需求随着序列长度的增加而呈平方级增长,对内存的需求非常大。
Flash Attention通过只存储softmax归一化的系数,避免了存储整个注意力矩阵,从而大大减少了内存需求。这些系数足够用于在反向传播过程中计算梯度,而无需参考完整的注意力矩阵。
2. Flash Attention的优点
Flash Attention的主要优点是它可以显著减少内存需求,同时也加速了计算。这使得模型能够处理更长的序列,或者在内存有限的设备上运行。此外,Flash Attention还保持了与传统注意力机制相同的表现力,因为它仍然能够模拟序列中元素之间的所有对应关系。
Softmax归一化系数解析
1. Softmax归一化
Softmax是一种常用的归一化函数,它可以将一组任意实数转换为一组在(0,1)区间内的实数,且这组实数的总和为1。这使得softmax函数的输出可以被解释为一组概率分布。
2. Softmax归一化系数
在softmax函数中,"归一化系数"通常指的是用于将每个输入值转换为概率的那个系数。具体来说,假设我们有一组输入值x1, x2, ..., xn,那么对应的softmax归一化系数就是1/Z,其中Z是所有输入值经过指数运算后的和,即Z = exp(x1) + exp(x2) + ... + exp(xn)。这个归一化系数保证了softmax函数的输出值的总和为1。
在某些情况下,为了节省存储空间和计算资源,我们可能只会存储这个归一化系数,而不是存储整个softmax函数的输出。例如,在上文提到的Flash Attention中,就只存储了这个归一化系数,而不是存储整个注意力矩阵。