flash attention作者Tri Dao发布了flash attention 2,性能为flash attention的2倍。
优化点主要如下:
一、减少 non-matmul FLOPs
A00中由于tensor core的存在,使得gpu对于浮点矩阵运算吞吐很高,如FP16/BF16可以达到312 TFLOPs/s,而对于非矩阵乘的浮点运算吞吐较低,如FP32只有19.5 TFLOPs/s。因此作者调整算法以减少非矩阵乘的浮点运算。
如图1-1,基线算法计算O2的时候会对O1进行放缩,先乘上之前的sum L1,再除以新的sum L2。
二、并行模式
基线对于CTA的分块逻辑为启动batch_size * num_head个CTA,每个CTA执行一个batch里的一个head,那么当seq_len很长的场景,batch_size一般会比较小,这个时候无法充分利用所有的SM,所以作者调整了并行模型,一个batch里的一个head也会被多个CTA执行。
基线算法中外层循环是对K,内层循环对Q,作者交换了这个循环,对外层循环进行并发。
综合一,二之后的算法流程如图2-1
三、warp分块
基线warp分块如图3-1,一个CTA所有warp都load Q,但是对K分块,这个时候计算S和P并没有啥问题,但是对计算O的时候,会导致warp之间对O执行一次reduce sum。