FlashAttention(flash-attn)安装
Flash Attention是一种注意力算法,用于减少这一问题,并更有效地缩放基于转换器的模型,从而实现更快的训练和推理。标准注意力机制使用高带宽内存(HBM)来存储、读取和写入密钥、查询和值。很多llm模型运行的时候都需要安装flash_attn,比如Llama3。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Paper: https://arxiv.org/abs/2205.14135
github地址,建议安装之前读一下官方文档。
官方给出的安装依赖:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Make sure that ninja is installed and that it works correctly (e.g. ninja --version then echo $? should return exit code 0). If not (sometimes ninja --version then echo $? returns a nonzero exit code), uninstall then reinstall ninja (pip uninstall -y ninja && pip install ninja). Without ninja, compiling can take a very long time (2h) since it does not use multiple CPU cores. With ninja compiling takes 3-5 minutes on a 64-core machine.
1. 最终的成功版本
按照自己的Pytorch版本、CUDA版本和Python版本来选择对应版本的flash-attn
,否则会默认下载最新版。
个人趟坑后推荐的pytorch版本大于2.2,CUDA版本为1.8。
安装顺序:CUDA Toolkit→PyTorch→FlashAttention
- CUDA Toolkit下载地址:https://developer.nvidia.com/cuda-toolkit-archive
- 特定版本的PyTorch下载地址:https://pytorch.org/get-started/previous-versions/
- FlashAttention下载地址:https://github.com/Dao-AILab/flash-attention/releases/
1.1 安装/更新CUDA Toolkit
最详细的安装教程请务必参考NVIDIA CUDA Installation Guide for Linux,非常保姆的教程,前前后后的依赖和处理都很详尽,但英文的阅读和长度也很感人,以下是我趟出来的一套简单流程。
1.1.1 下载CUDA Toolkit
以终为始我们需要的是cuda-11.8,那么直接去CUDA Toolkit下载地址下载特定版本。
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
1.1.2 安装CUDA Toolkit
#登录管理员,使用root权限
su
# 修改权限
chmod +x cuda_10.2.89_440.33.01_linux.run
#安装
./cuda_10.2.89_440.33.01_linux.run --no-opengl-libs
#退出root
exit
执行后的弹框输入:accept,然后按Enter键继续。第二页选中Driver,按Enter键,再选中Install,按Enter键进行安装。
默认安装路径,那么将会安装在**/usr/local/cuda-11.8/**路径下。
1.1.3 设置环境变量
export PATH=/usr/local/cuda-11.8/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
1.2 安装PyTorch
已经安装好了cuda-11.8后,并加入了环境变量,nvcc -V
显示的是为cuda-11.8则可以正式开始安装PyTorch v2.2.2。(之所是是2.2.2版本也是因为之前的1.2版本太老导致FlashAttention虽然看起来安装成功,但用起来仍然报错。查了很久之后,安装成功的人都推荐PyTorch v2.2.2)
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
尽管官方给出安装PyTorch用conda install
也可,但我没成功/(ㄒoㄒ)/~~,可能是channel的问题,whatever, 还是用pip install
吧。
1.3 安装 FlashAttention
至此,已经安装好了cuda-11.8和PyTorch v2.2.2,python -V
查看当前的Python版本,就可以在FlashAttention下载地址选择对应的whl
文件用pip install
来安装了。
以flash_attn-2.5.3+cu118torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
为例
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.3/flash_attn-2.5.3+cu118torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
pip install flash_attn-2.5.3+cu118torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl