在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2
实现了对scaled_dot_product_attention
约2X左右的加速。
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention
的实现,一个是使用Pytorch官方的scaled_dot_product_attention
接口:
import time
import torch
import torch.nn.functional as F
def main():
repeat = 100
device = torch.device("cuda:0")
dtype = torch.float16
query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
scale_factor = 0.125
ori_time_list = []
for _ in range(repeat):
torch.cuda.synchronize(device=device)
time_start = time.perf_counter()
# 原始Self-Attention实现
res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ value
torch.cuda.synchronize(device=device)
time_end = time.perf_counter()
ori_time_list.append(time_end - time_start)
fa_time_list = []
for _ in range(repeat):
torch.cuda.synchronize(device=device)
time_start = time.perf_counter()
with torch.backends.cuda.sdp_kernel(enable_math=False):
# 使用Pytorch官方提供的FA实现
res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)
torch.cuda.synchronize(device=device)
time_end = time.perf_counter()
fa_time_list.append(time_end - time_start)
diff = (res - res_fa).abs().max()
ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]
avg_ratio = sum(ratio[1:]) / len(ratio[1:])
print(f"max diff: {diff}")
print(f"avg speed up ratio: {avg_ratio}")
if __name__ == '__main__':
main()
执行以上代码,终端输出如下:
max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118
这里使用的设备是RTX4070
,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention
接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention
的计算结果和直接调用scaled_dot_product_attention
接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的(后面找了个A100的GPU试了下,发现FP32也是有加速的)。