作者:黄奕桐、沈雯婷、艾宝乐、王昂、九丰
摘要
我们提出了长序列训练方案 FlashSequence 并集成在 PAI-TorchAcc (阿里云机器学习平台开发的Pytorch上的大模型训练加速框架)中,该方案能够支持SORA类超长序列模型的高效训练。在两机 16 卡 A100 上,FlashSequence 能够训练 1M 的长序列模型,并达到了 51.7%的 MFU,接近占据 E2E 95%时间的 FlashAttention 53.5%的 MFU。
一、横空出世的 SORA
SORA 介绍
SORA 是一个文生视频的模型,可以根据输入的文本生成对应的视频。
图1: SORA,这个图的核心部分来自:https://openai.com/research/video-generation-models-as-world-simulators。加上了text encoder和DiT blocks。
SORA 在训练时输入的视频可以看成是若干帧图像, 通过visual encoder得到spatial tempral patches 并 flatten成一维作为transformer tokens,同时输入的文本通过 text encoder 生成 embed,两者送入 diffusion transformer (DiT) 进行训练。
DiT 模型
图2: DiT模型的网络结构,来自:https://arxiv.org/abs/2212.09748
图 2 是 DiT 模型的网络结构,可以看到,DiT 模型和 LLAMA 等 LLM 模型的结构上大体上相同,DiT 模型多了对于输入 latent 的 patchify 处理(转换为 LLM 模型需要的 tokens 输入)、在 DecoderLayer 中增加了与文本输入的交互等。基本上 LLM 有的 Multi-Head Self-Attention、Pointwise Feadforward(MLP)结构,DiT 模型也有。
从整体上来看,DiT 模型与 LLM 模型的区别不大,从计算上来看,主要的计算量还是在 Attention 和 MLP 部分,区别在于 Attention 部分通常不会使用 casual mask,导致计算量较大;从显存上看,文本交互的部分会引入额外的显存使用。
训练需求
和LLM只有一个模型不同的是,文生视频模型由多个模型组成,包括:对文本进行编码的 text encoder、对视频进行编码的 visual encoder、DiT、用于推理的和 DiT 模型相同大小的 EMA 模型等。一般来说,text encoder为一个LM模型,预计在几 B 左右;visual encoder 通常为包含conv的VAE,参数量较小;DiT 为一个中小规模的模型,在 1B ~ 30B 左右。其中text encoder, visual encoder 通常为 pretrained 模型,在一些场景下无需进行训练。整体的模型参数量在 10B ~ 60B 左右,与一般的 LLM 模型差别不大。
同时,与一般的 LLM 模型不同的是,文生视频模型的输入 token 数在几十K到几M之间,因此文生视频模型训练的核心挑战是对长序列的高效支持。
二、 workload 分析
计算量
对于 text encoder 和 DiT 模型,主要计算量为,其中 L 表示 decoder layer 层数、b 表示 micro batch size,s 表示 sequence length,h 表示 hidden dim size,表示 linear 层的系数,表示 attention 层的系数(在一般的非 causal attention 上前向是 4,后向是 8)。在 s 比较大时,如几十K到几M 时,其他的算子如 element wise 算子在这里可以忽略。text encoder 的 s 通常比较固定的,且远小于视频输入的 token 数。对于 VAE 模型,其主要计算量为一系列卷积操作,在长序列场景,计算量一般小于 DiT 模型。
因此,在视频输入的 token 数 s 比较大时,如几十K到几M 时,整个训练的计算量绝大部分在 DiT 模型上。同时,在 s 比较大的时候,attention 部分的计算量会按照 token 数平方增长,而 linear 层的增长只是线性增长,因此,attention 部分的计算量会逐渐成为整个训练过程中的瓶颈。在 1M 场景,attention 部分的计算时间可以达到 E2E 时间的 95%。
显存使用
如前所述,文生视频整体的模型参数量在 10B ~ 60B 左右,与一般的 LLM 模型差别不大。在使用 7B 的 text encoder 和 7B 的 DiT 模型时,模型常驻的显存大概在 130GB(包含各个模型的参数和 DiT 部分的 optimizer state)。
在文生视频的训练过程中,通常只会训练 DiT 部分,而 text encoder 和 VAE 部分通常不会进行训练,所以显存使用主要在 DiT 的训练部分。text encoder 和 VAE 部分的显存主要考虑临时的 tensor 使用会不会导致 OOM。
在目前的 LLM 模型中,通常会使用 FlashAttention 进行性能和显存优化。在使用了 FlashAttention 之后,DiT 部分的显存使用为,其中 为一个 Decoder Layer 的显存使用系数,取决于 DiT 模型的实现,这个值会有所变化,但是由于 text 的输入,通常会比普通的 LLM 模型大,例如,一般的 LLM 模型可以是 34,而 DiT 模型会达到 60 ~ 70。在 tokens 数为 1M 、micro batch size 为 1 的场景下,7B 的 DiT 模型总的 activation 显存使用可以达到 8000 GB 以上。
可以看到,相比于计算量按照 token 数 s 平方增长,显存量是按照 token 数 s 线性增长的,这也为后续显存优化提供了参考。
三、FlashSequence 长序列训练方案
FlashSequence
基于 workload 分析,我们提出了 FlashSequence 这一解决方案:
-
分布式策略:
-
为了切分中小规模模型的参数,FlashSequence 使用了 FSDP 这一分布式策略,同时,FlashSequence 在 FSDP 外面嵌套使用了 DP 提升多机拓展性。
-
为了切分长序列训练场景下的 activation,FlashSequence 使用了 context parallel 对 sequence 维度进行切分,同时,FlashSequence 提出 2D context parallel 的方案减少 context parallel 跨机的通信开销。我们还去除了使用 context parallel 之后带来的冗余重复计算。
-
显存优化策略:
-
FlashSequence 通过使用 CPU offloading 将 activation offload 到 CPU 内存上减少显存,同时极大减少了 gradient checkpoint(GC)带来的额外重算计算量。CPU offloading 的策略在长序列场景下数据传输时间能够和计算时间完全 overlap,相比 GC 能够在不影响 E2E 时间的情况下减少显存。
-
为了避免 CPU 内存 OOM 和减少一部分 offloading 时间,FlashSequence 使用了 selective GC,selective GC 会优先选择显存计算比高的部分。
-
FlashSequence 还使用了 PyTorch expandable allocator 解决长序列场景下显存碎片过多的问题。
分布式策略
整体思路
在模型中存在两种类型的 tensor,一种是参数相关的包括模型参数、optimizer state、gradients,另一种是 activation。由于长序列场景下参数和 activation 都是不可忽视的,为了避免 OOM,我们需要同时切分参数和 activation。例如,常驻的参数和 optimizer state 可以达到 130GB,而 activation 在 1M 场景下可以达到 8000GB 以上。
参数切分
在参数的切分方面,我们存在多种选择,比如 TP、PP、FSDP 等,但是由于模型本身规模不是特别大,同时考虑到计算和通信的 overlap 情况,FlashSequence 选择了 FSDP 这种参数切分策略。不同于 TP 和 PP,FSDP 的通信除了第一个 layer 的 allgather 之外都能和计算 overlap,没有和计算 overlap 的通信时间在 FSDP 较小的情况下通常可以忽略。虽然 TP 也可以同时切分 activation,但是 TP 会引入无法 overlap 的通信,同时 PP 需要比较大的 gradient accumulation steps 才能掩盖 bubble。
activation 切分
对于 activation,DiT blocks 输入的 shape 为 [batch, sequence, hidden_dim]。由于长序列场景 activation 非常大,所以 micro batch size 通常为 1,这一维度无法切分。在 sequence 维度的切分目前存在 context parallel 如DeepSpeed-Ulysses 和 Ring Attention,以及 Megatron 的 sequence parallel。在 hidden_dim 的维度的切分主要是 Megatron 的 tensor parallel(以切分 weight 的方式实现对 activation 的 hidden dim 维度的切分)。纯粹的 tensor parallel 在 layer norm 等部分还是需要全量的 tensor,这一点在长序列场景是不可接受的。通常目前的主流做法是 Megatron 的 TP-SP 切分方式,这种切分方式和 context parallel 一样可以完整切分 layer 内的 activation。
对于 TP-SP 的切分方式,通信量为,其中t为 TP-SP 的数目,L 为 layer 数、s 为 sequence、b 为 micro batch size、h 为 hidden dim。对于 context parallel,以 DeepSpeed-Ulysses 为例,通信量为,其中为模型参数量,前面一项是对模型参数的 all reduce 通信,后面一项是对 self attn 的 q、k、v、out 的 alltoall 通信。对于DeepSpeed-Ulysses,模型参数的 all reduce 可以被计算 overlap(类似 DDP),而后面不能 overlap 的通信小于 TP-SP 的切分方式。
从上面的对比可以看出,DeepSpeed-Ulysses 不能 overlap 的通信理论上是小于 TP-SP 的(即使考虑 TP-SP 后向通信可以 overlap)。同时,我们使用 FSDP 切分参数之后也不再需要 TP 对模型参数进行切分。在这种场景下面,FSDP 只是一种切分模型参数的分布式策略,其数据并行的含义被弱化了,不再是开启多少 FSDP 读取多少不同的数据样本,只需要保证 context parallel 的一个 group 内读取相同的数据即可。
综上所述,FSDP+context parallel 的方式优于 TP-SP 的切分方式。同时context parallel 还可以使用 Ring Attention 的方式进一步减少不能 overlap 的通信。
FSDP+DP
在文生视频这种中小规模模型的场景下,FSDP 不需要开很大就可以避免 OOM,在 7B 及以下规模,使用 FSDP=8 就足够满足显存使用需求,同时还能使用高速的机内带宽进行通信。
在更多的卡数下,FSDP 的拓展性会存在一些问题,为了避免这些问题,FlashSequence 进行了 DP 和 FSDP 的嵌套,在外层使用 DP,在内层使用 FSDP。虽然 FSDP 和 DP 的通信都能被计算 overlap,但是 DP 的通信量小于 FSDP,同时 DP 只在计算时间更长的后向进行通信,所以,DP 相比于 FSDP 拥有更好的多机拓展性。在使用了 DP+FSDP 的组合之后,不只能满足参数切分的需求,同时提升了多机的拓展性。
Context Parallel
context parallel 的好处是只在 attention 部分和 transformer 模型之后引入了额外通信,在其他的部分比如 MLP 均不需要额外的通信,而且 gradients 的同步使用 DP+FSDP 就可以完成。同时,在 context parallel 的作用域之内 activation 和计算可以被均匀切分。
目前的 context parallel 都是在一开始就对 sequence 维度进行切分。唯一的区别在于 attention 部分的处理,DeepSpeed-Ulysses 会将 sequence 维度的切分转换为 head 维度的切分再进行 attention 的计算,而 RingAttention 会依然保留 sequence 维度的切分对 attention 的计算进行特殊处理。
DeepSpeed-Ulysses
图3:DeepSpeed-Ulysses,来自:https://arxiv.org/abs/2309.14509
如图 3 所示,DeepSpeed-Ulysses 会对 q、k、v 分别进行 all to all,将 sequence 维度的切分转换为 head 维度的切分再进行 attention 的计算,然后再对 attention 的输出进行 all to all,将 head 维度的切分转换回 sequence 维度的切分。由于 attention 的计算在 head 维度是并行的,所以这样操作之后不需要对 attention 的计算进行额外处理。可以看到,DeepSpeed-Ulysses 切分的是 head 维度,所以这使得DeepSpeed-Ulysses 的并行数目最多开到 head 的大小。
单个 layer 内 DeepSpeed-Ulysses 的通信和计算对比为:,其中 F 为 GPU 计算 FLOPS,B 为 alltoall 通信带宽。可以看到,随着 s 的变大,DeepSpeed-Ulysses 的通信占比会逐渐降低,最终达到一个可以忽略的程度,在 seq len = 256K 单机 8 卡的场景下,DeepSpeed-Ulysses 的通信时间在 E2E 的时间占比已经低于 1%,在 seq len=64K 的场景下也只有 2%~ 3%。但是,在涉及到跨机通信时,DeepSpeed-Ulysses 的通信开销由于机间通信带宽较低会变得不可忽视。在 256K 的场景下 2 机 16 卡会达到 10%以上。
Ring Attention
图4:Ring Attention,来自:https://arxiv.org/abs/2310.01889
如图 4 所示,Ring Attention 的实现过程中会保持 sequence 维度的切分。Ring Attention 会以 ring 的方式发送和接收其他 device 上的 k 和 v,同时计算本地的 q、k、v 分块的 attention,对输出进行一些矫正保证正确性。这种方式可以使得计算和通信能够 overlap 起来。
Ring Attention 计算和通信 overlap 的理论条件是:考虑前向的一个小的 Attention,通信量为 k 和 v:4bsh,计算量为:,所以计算能够掩盖通信的条件为:,其中 F 为 GPU 计算 FLOPS,B 为 send/recv 通信带宽。在实际运行过程中,还需要考虑 Flash Attention 的计算利用率和 send/recv 的带宽利用率,根据机器和算子性能的不同,在涉及跨机通信时,在 A100 上面下单卡需要 24K 的序列长度才能 overlap。
可以看到,Ring Attention 的优势是通信能够和计算 overlap,但是需要保证 s 切分后单 GPU 卡上的句子长度满足 overlap 条件。
2D context parallel
对于 context parallel,由于只有 attention 部分存在通信,所以我们需要考虑的只是 attention 部分的处理。在 attention 部分,activation 的 shape 为 [batch, sequence, heads, head_dim],由于维度的大小关系,在这其中 sequence、heads 和 head_dim 是可以进行切分的 sequence 和 heads 的切分分别代表了 Ring Attention 和 Ulysses。head_dim 维度由于是矩阵乘的 contracted 维度,这种维度的切分一般不可避免会引入无法 overlap 的 allreduce 或者 allgather 等通信算子,这会使得通信量大于 Ulysses 的 alltoall。
除此之外,我们还可以同时切分 sequence 维度和 heads 维度。在这种情况下,我们只需要进行一部分通信量较少的 alltoall 通信将一部分 sequence 维度转换为 head 维度,同时,针对剩余的 sequence 维度的切分,可以使用可以 overlap 的 send/recv 通信进行处理。由于 alltoall 的跨机性能较差同时 send/recv 的通信时间可以被计算 overlap,FlashSequence 让外层 alltoall 的通信使用机内的 nvlink 进行通信,内层的 send/recv 使用机间带宽进行通信。我们称这种 context parallel 为 2D context parallel。
2D context parallel 相比 DeepSpeed-Ulysses 可以减少没有 overlap 的 alltoall 时间,相比 Ring Attention 可以在单机 tokens 数较小时减少 send/recv 的次数和 attention 的计算时间,使得 send/recv 和计算可以 overlap。这种设计在 context parallel 涉及跨机通信时会显著减少没有和计算 overlap 的通信时间在 E2E 中的占比,在 seqlen = 256K、2 机 16 卡的场景可以将 DeepSpeed-Ulysses 的通信时间从 10%以上减少到低于 1%。
分布式策略的冗余计算优化
在上面我们提到使用 context parallel 对 sequence 维度进行切分,但是这个切分是存在边界的,一般情况下我们会在 activation 的 shape 转换为 transformer 需要的 shape 之后(比如 DiT 模型的 patchify 之后)才对sequence 维度进行切分。由于 context parallel 需要 group 内的 device 读取相同的数据,这就会导致从 dataloader 读取样本到 sequence 维度切分之间在 group 内的 device 进行的是相同的计算。这一部分在 SORA 模型中通常是 visual encoder 和 text encoder 模型,分别负责对视频和文本进行编码。这些计算在中小长度的序列长度下占比比较高,取决于具体模型实现和序列长度,可以达到 20%甚至 70%。
为了为了去除这一部分的冗余计算,我们可以让 context parallel group 内的 device 读取不同的数据,在需要 sequence 维度切分时进行一个 context parallel 大小的 loop 遍历,依次对前面不同 device 读取的数据进行 broadcast ,使得 transformer 的部分输入的数据一样。这样处理之后,VAE+text encoder 的时间占比会减少到之前的 1/context parallel size,带来 E2E 性能提升。
显存优化策略
使用分布式策略可以进行模型参数和 activation 的切分以减少显存,但是分布式策略的切分会引入通信开销,在更多卡参与切分时,这些开销会逐渐变得不可忽视。例如 activation 在 1M 场景下可以达到 8000GB 以上,使用 80GB 的 GPU 就需要至少 100 张卡,这是不可接受的。因此,我们还需要一些显存优化策略来进一步减少显存。
在目前的实践中,gradient checkpoint(GC)是较为常见的策略,GC 的重点是选择合适的重算部分以减少额外的计算开销。CPU offloading 在 DeepSpeed 中通常是对参数进行 offload,但是在长序列场景,我们发现 CPU offloading 在 activation 上相比 GC 也能带来明显的性能提升。显存碎片在长序列场景也会经常遇到,经常会出现 PyTorch reserve 了 10 几 GB 的显存却无法分配一个几百 MB 的 tensor,进而导致 OOM。
Selective GC
gradient checkpoint(GC)的思想是在前向过程中不保留 activation,在后向时重新运行一次前向生成 activation。在使用 GC 的过程中,最主要的问题是选择好重算的部分。目前主流的做法是对整个 decoder layer 进行 GC(full GC)或者对 Attention 部分进行 GC(Megatron selective GC)。
但是,如前所述,在长序列场景,attention 部分占据了绝大部分的计算,重算 attention 的开销很大。同时,与较小序列不同的是,在长序列场景,MLP 部分也是可以考虑进行 GC 的,在 1M 场景,MLP 的 E2E 占比已经低于 5%,重算的开销较小。
FlashSequence 会优先选择显存计算比高的部分。按照模型中的算子 FLOPS以及算子节省的显存量,FlashSequence 会选择依次节省显存收益较大的部分。
CPU Offloading
CPU Offloading 的思想是将部分 tensor 从显存传输到 CPU 内存上,在需要时再 prefetch 回来。在 DeepSpeed 中,这一技术通常只在参数上使用,这是因为之前的场景 offload activation 会有比较大的 PCIe 传输开销。然而,在长序列场景,如上面所述,相比于计算量按照 token 数 s 平方增长,显存量是按照 token 数 s 线性增长的,这就使得在长序列场景,计算的时间会逐渐超过 offload activation 的 PCIe 传输时间。在 64K 场景,offload 一层 decoder layer activation 的 PCIe 传输时间可能需要 2 ~ 3 层 layer 的计算进行 overlap,而在超过 256K 的场景,offload 一层 decoder layer activation 的 PCIe 传输时间仅需一层 layer 的计算就可以 overlap。在不同的模型下,这个 overlap 的 layer 的数目会有所区别,但是随着序列长度 s 的增长,最终都会达到一个可用的状态,比如在 64K 上使用 offloading 就可以无损减少多层 decoder layer 的 activation 显存占用。
以一层 decoder layer 的 activation 作为 offload 的粒度,offload 一层 decoder layer 可以达到和 GC 一层 decoder layer 类似的显存减少量,同时在长序列场景,offloading 的传输时间能够被计算时间 overlap,相当于在 E2E 性能无损的情况下减少了显存,相比于 GC 能够减少额外的计算开销。
虽然 offloading 在长序列场景拥有比 GC 更好的性能表现,offloading 本身也存在一些问题:
-
较短的序列长度需要多层 layer 的计算才能 overlap 传输时间,当然这个在更长的序列长度上不是问题。
-
offloading 需要使用 CPU 的 pinned 内存,而 CPU 的内存虽然有 1TB ~ 2TB,但是在长序列场景,8 张卡的 offloading 所需要使用的内存总量会很快超过 CPU 的内存。这可以通过结合部分 selective GC 进行解决。
-
offloading 会和跨机通信(RDMA 也会使用一部分 PCIe 资源)竞争 PCIe,这种影响在前向计算中比较明显。但是在使用了 DP+FSDP 和 2D context parallel 的组合之后,大部分通信都是使用 nvlink,机间通信也能够被计算 overlap,所以对 E2E 的性能影响不大。
基于上述问题,FlashSequence 在优先 CPU offloading 的同时使用 selective GC,避免 CPU 内存 OOM 的同时减少重算的 FLOPS。
显存碎片
在长序列场景,一个 tensor 的显存使用可以达到几百 MB 甚至几 GB,在这种场景下,PyTorch 的 caching allocator 会导致比较多的显存碎片。
图5: caching allocator的显存分配情况
图6: expandable allocator的显存分配情况
图 5 是某个长序列场景 OOM 时的显存使用情况,其中空白部分是还没分配但是被 PyTorch reserve 的显存。这个 OOM 本来是不应该出现的,因为这个时候请求分配的 tensor 只需要 500 多 MB 的显存,而 PyTorch reserve 的未分配显存有 7.5GB。但是因为 PyTorch reserve 的未分配显存都是不连续的(大的空白是 200 多 MB),所以导致了 OOM。这个显存碎片问题在更长的序列场景会更加常见,有时候可以达到 10 ~ 20GB 的显存碎片。
在 PyTorch 的 2.2 及以上版本,引入了expandable 的 allocator,这个 allocator 可以在有更大显存分配请求的情况下拓展已有的空闲显存块,进而减少原始 caching allocator 的显存碎片。从图 6 中可以看到显存碎片低于 1GB。在大部分长序列场景下,expandable allocator 的显存碎片都比 caching allocator 的小,同时在我们场景下性能基本没有变化。
计算优化
FlashAttention 优化
FlashAttention是DiT 模型中attention部分的常用优化手段,FlashAttention的前向计算量为FLOPS,后向的计算量是FLOPS(FlashAttention 在后向存在部分重算),如前文对计算量的分析,随着序列长度的增长,attention部分在端到端的训练时间中甚至占比到95%以上。因此,FlashAttention的计算性能,也成为整个训练任务最为dominate的部分。
我们对TriDao版本的FlashAttention2在不同序列长度下的性能做了A100上的kernel的性能测试,以batch-size=1, hidden-dim=128为例。通过图 7 和图 8 的性能数字,可以看到:
1. 序列长度和N_HEADS同时太大(N_CTX>=512K and N_HEADS>=8)或太小(N_CTX<=4K 或 N_CTX<=32K and N_HEADS<=4),都会造成一定程度的性能损失。这个性能数字也可以指导对N_HEADS和N_CTX的分片;
2. 根据FlashAttention的性能,我们可以预估一次迭代的训练时长。如1M下,按照前向227TFLOPS/s,后向191TFLOPS/s计算,一个layer计算batch-size=1, hidden-size=4096的FlashAttention计算时间为,约5min,L个layer的FlashAttention的用#n_gpus并行计算时间为5Lmin/n_gpus。
图7: FlashAttention在A100上面不同序列长度和HEAD大小下前向的性能
图8: FlashAttention在A100上面不同序列长度和HEAD大小下后向的性能
不同的 Attention 实现
前文中我们预估了1M序列长度下FlashAttention的计算时间,可以看到由于序列长度平方项的计算量的存在,一轮迭代的时间在分钟级别,导致模型训练的速度非常慢。有很多降低Attention二次方序列长度计算量的工作,提升Transformer效率的工作,主要包含以下几种:Linear Attention、Sparse Attention、Mamba、Compress Memory。这些算法由于计算量的减少,在模型的效果方面和原始Transformer存在差异,后续我们将对这方面的工作进行探索,并集成到系统中。
-
Linear Attention:
线性Attention的核心思想是用kernel函数代替softmax,然后通过矩阵乘法的结合律,将序列长度维度的两层循环减少为一层循环,从而将Attention的计算量从序列长度的平方项降低为线性项。线性Attention减少计算量也会造成模型的精度损失,取决于核函数的设计。
相关的工作有Transformers are RNNs, RMKW, Linformer, Lightning, DiJiang等。
-
Sparse Attention:
Sparse Attention通过让每个token对应的向量只跟部分token对应的向量(可见域)计算相关度,使Attention矩阵计算变得稀疏。在Sparse Attention中,如何选择有相关性的元素进行计算,成为影响模型精度的关键。相关工作探索了固定可见域,如OpenAI 2019的Sparse Transformers,与输入数据相关的可见域,如ICLR 2024的Transformer-VQ等多种算法。
-
Mamba
Mamba基于状态空间模型(SSM, State space models),结合了RNN(表达隐藏状态和输入关系,去掉非线性激活函数)和CNN(并行训练),并根据输入动态调整模型的选择性参数,包括当前输入和历史状态信息对输出的影响系数、无关信息的过滤参数等,并通过硬件感知算法来优化计算效率,将Transoformer的计算效率变成序列长度线性相关。
相关的工作包括Mamba,Mamba在视觉模型上的应用如ZigMA, ICLR 24 Diffusion SSM等。
-
Compress Memory
Compress Memory是一种将长序列切分为一个个segment,将历史segment的信息编码到一个固定大小的memory中,将当前segment的attention和memory信息concat到一起,从而将计算复杂度降低到1/n_segments,而memory占用的空间为固定大小,以此方式可以计算无限长度的序列的attention。相关工作如Infini-attention,ICAE等。
四、实验
图9: 不同sequence length和context parallel(CP)下的MFU
为了衡量我们提出的 FlashSequence 的解决方案,我们以纯 Ulysses+FSDP 并使用 full GC 作为 baseline,其中 full GC 的 layer 数依据显存使用决定,少于模型 layer 数。图 9 展示了在 A100 上不同sequence length和context parallel(CP)下的MFU,可以看到,FlashSequence 的 MFU 相比 baseline 平均提高了 11.75%,性能平均提升了 23.3%。同时,FlashSequence 的方案在长序列场景可以获得和 FlashAttention 接近的 MFU,比如在 1M、CP=16 的场景下,FlashSequence 的 MFU 为 51.7%,接近占据 E2E 95%时间的 FlashAttention 53.5%的 MFU。
五、总结与展望
PAI-TorchAcc(Torch Accelerator)是阿里云机器学习平台开发的Pytorch上的大模型训练加速框架。PAI-TorchAcc 通过进行分布式优化、计算优化、显存优化等,为包括 SORA 模型在内的Pytorch上的模型提供高效训练支持。
目前,FlashSequence 已经集成到 PAI-TorchAcc 中,并在开源的 DiT 模型上验证了效果。此外,由于 SORA 所使用的 DiT 类模型结构与 LLM 模型基本类似, FlashSequence 也可以应用在大部分长序列训练场景。后续我们会陆续开源这些工作。
同时,目前长序列的最主要瓶颈都在 FlashAttention 的计算上面,如何优化 FlashAttention 的计算将成为长序列场景下的主要问题。由于 FlashAttention 的计算量按照 token 数平方增长,未来更可能的优化方向是探索计算量更低的 Attention 实现比如线性的 Attention,同时低精度如 FP8 训练、稀疏训练等也都是一些可以探索的方向。
【招聘】最后,如果你对大模型训练加速技术感兴趣,欢迎加入到我们的团队中。目前研究型实习生和社招都在火热招聘中,欢迎投递简历到研究型实习生 - 基于负载与硬件特性协同的大模型训练加速技术研究 或邮箱 wenting.swt@alibaba-inc.com