Megatron-LM源码系列(八): Context Parallel并行

1. Context Parallel并行原理介绍

megatron中的context并行(简称CP)与sequence并行(简称SP)不同点在于,SP只针对LayernormDropout输出的activation在sequence维度上进行切分,CP则是对所有的input输入和所有的输出activation在sequence维度上进行切分,可以看成是增强版的SP。除了Attention模块以外,其他的模块(Layernorm、Dropout)由于没有多token的处理,在CP并行时都不用任何修改。

为什么Attention模块是个例外? 因为Attention计算过程中每个token的Q(query)要跟同一个sequence中其他token的K(key)和V(value)一起进行计算,存在计算上的依赖,所以通过CP并行后,在计算Attention前要通过allgather通信拿到所有token的KV向量,在反向计算时对应需要通过reduce_scatter分发gradient梯度。

为了减少显存占用,在前向时每个gpu只用保存一部分KV块,反向时通过allgather通信拿到所有的KV数据。KV的通信发生在相邻TP通信组相同位置的rank之间。allgather和reduce_scatter在ring拓扑架构实现时,底层会通过send和recv来进行实现。
在这里插入图片描述

以上图TP2-CP2的transformer网络为例,在Attention前的是CP的通信算子,其他都是TP的通信算子。AG表示allgather, RS表示reduce_scatter, AG/RS表示前向allgather反向reduce_scatter, RS/AG表示前向reduce_scatter反向allgather。

这里TP2对应为[GPU0, GPU1], [GPU2, GPU3], CP2对应为TP组相同位置的rank号,也就是[GPU0, GPU2], [GPU1, GPU3]。CP并行与Ring Attention类似,但是提供了新的OSS与FlashAttention版本,也去除了low-triangle causal masking的冗余计算。

LLM经常由于sequence长度过长导致显存OOM,这时之前的一种方式是通过重计算的方式保存中间的activation产出,全量重计算的劣势会带来30%的计算代价;另外一种方式是扩大TP(tensor parallel)的大小,扩大TP的劣势在于会对tensor切的更小,从而导致linear fc的计算时间变少,从而与通信很难进行计算的掩盖。

通过CP可以更好解决OOM的问题,每个GPU只用处理一部分的sequence, 同时减少CP倍的通信和计算,但保持TP不变,同时activation也会减少CP倍。CP优化的性能参考如下图,在Megatron中(Megatron-Core>=0.5.0 && Transformer Engine >=1.1)通过指定--context-parallel-size可以进行使用。 t o t a l _ g p u _ c o u n t = T P × C P × P P × D P total\_gpu\_count = TP \times CP \times PP \times DP total_gpu_count=TP×CP×PP×DP

在这里插入图片描述

2. 源码

以Megatron-Core 0.5.0为例进行介绍

  • 首先在megatron/arguments.py中定义了--context-parallel-size参数, 同时也要求了world_size能要整除TP*PP*CP
    group.add_argument('--context-parallel-size', type=int, default=1,
                       help='Degree of context parallelism.')
    ....
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \
        'world size ({}) is not divisible by tensor parallel size ({}) times ' \
        'pipeline parallel size ({}) times context parallel size ({})'.format(
        args.world_size, args.tensor_model_parallel_size,
        args.pipeline_model_parallel_size, args.context_parallel_size)
    args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
  • megatron/core/parallel_state.py中初始化通信组时会初始化相关CP通信组, 以TP-PP-DP-CP=8-1-1-2为例,TP通信组为[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15], CP通信组为[0,8],[1,9],[2,10],[3,11],[4,12],[5,13],[6,14],[7,15]
def initialize_model_parallel(...):
    ...
    for i in range(pipeline_model_parallel_size):
        for j in range(data_parallel_size):
            start_rank = (
                i * num_pipeline_model_parallel_groups
                + j * tensor_model_parallel_size * context_parallel_size
            )
            end_rank = (
                i * num_pipeline_model_parallel_groups
                + (j + 1) * tensor_model_parallel_size * context_parallel_size
            )
            for k in range(tensor_model_parallel_size):
                ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
                group = torch.distributed.new_group(
                    ranks, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
                )
                if rank in ranks:
                    _CONTEXT_PARALLEL_GROUP = group
                    _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
  • megatron/core/transformer/custom_layers/transformer_engine.pyTEDotProductAttention会初始化相关CP通信组相关参数,TEDotProductAttention继承自te.pytorch.DotProductAttention,在前向中直接调用父类的的forward函数。
class TEDotProductAttention(te.pytorch.DotProductAttention):
    def __init__(...):
        ...
        if te_version >= packaging.version.Version("1.0.0"):
            if getattr(TEDotProductAttention, "cp_stream") is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
                check_initialized=False
            )
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
        ...

    def forward(...):
        ...
            core_attn_out = super().forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type.name,
                **packed_seq_kwargs,
            )
        ...
  • Transformer Engine中DotProductAttention定义在transformer_engine/pytorch/attention.py中,CP相关参数通过attn_kwargs进行传入。接下来会调用FlashAttention的前反向。
class DotProductAttention(torch.nn.Module):
    def __init__(...):
        ...
        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)
        ...

class FlashAttention(torch.nn.Module):
    def forward(...):
        ...
        if context_parallel:
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training, query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic
                )
  • 在FlashAttention中会通过函数attn_forward_func_with_cp进行调用,最终Attn前的all_gather通信是在AttnFuncWithCP中通过send、recv通信来实现的, 执行完通信就执行对应的flash_attention算子的调用。
def attn_forward_func_with_cp(...):
    out = AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
        attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention
    )
    return out
    
class AttnFuncWithCP(torch.autograd.Function):
    def forward(...):
        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)
                    ...
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout=qkv_layout, attn_mask_type="causal",
                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
                                )

3. 参考

  • Context parallelism overview

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/652210.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

06.部署jpress

安装mariadb数据 yum -y install mariadb-server #启动并设置开启自启动 systemctl start mariadb.service systemctl enable mariadb.service数据库准备 [rootweb01 ~]# mysql Welcome to the MariaDB monitor. Commands end with ; or \g. Your MariaDB connection id…

HCIP-Datacom-ARST自选题库_10_其他多选【48道题】

1.为什么说可以通过提高链路带宽容量来提高网络的QoS? 链路带宽的增加减小了拥塞发生的几率从而减少了云包的数量 链路带宽的增加可以增加控制协议的可用带宽 链路带宽的增加意味着更小的延迟和抖动 链路带宽的增加可以支持更高的流量 2.当拥塞发生时&#xff0c;通常会影…

VectorDBBench在windows的调试

VectorDBBench在windows的调试 VectorDBBench是一款向量数据库基准测试工具&#xff0c;支持milvus、Zilliz Cloud、Elastic Search、Qdrant Cloud、Weaviate Cloud 、 PgVector、PgVectorRS等&#xff0c;可以测试其QPS、时延、recall。 VectorDBBench是一款使用python编写的…

如何实现倾斜摄影三维模型OSGB格式轻量化

如何实现倾斜摄影三维模型OSGB格式轻量化 倾斜摄影三维模型以其高精度和真实感受在城市规划、建筑设计和虚拟漫游等领域发挥着重要作用。然而&#xff0c;由于其庞大的数据量和复杂的几何结构&#xff0c;给数据存储、传输和可视化带来了挑战。为了解决这个问题&#xff0c;倾斜…

KT6368A蓝牙芯片AT命令会被透传出去,指令对为什么会被透传出去

一、简介 KT6368A再被连接之后&#xff0c;AT命令会被透传出去。被透传的这组AT命令是符合文档要求&#xff0c;不应被透传&#xff0c;实际却经常被透传。并且可以每次都复现 详细描述 有问题部分的串口数据监控结果如下&#xff1a;其中41 54 2B 42 4D 46 30 41 46 42 43 3…

消费增值:国家支持的消费新零售模型

在当下的消费时代&#xff0c;一个全新的概念——消费增值&#xff0c;正逐渐走进大众视野。它不仅仅是一种消费模式&#xff0c;更是一种全新的财富增长途径。那么&#xff0c;消费增值究竟是什么&#xff1f; 首先&#xff0c;消费增值的本质在于将消费行为与投资行为相结合…

【代码随想录——回溯算法——三周目】

1. 子集2 这题需要先进行排序&#xff0c;和候选人那题类似。防止出现重复的子集。 func subsetsWithDup(nums []int) [][]int {path : make([]int, 0)res : make([][]int, 0)sort.Ints(nums)var dfs func(nums []int, start int)dfs func(nums []int, start int) {res app…

智能时代下,人机交互和虚拟现实的机遇和挑战

智能时代下,人机交互和虚拟现实的机遇和挑战

vue打包时报错文件包过大

1.问题&#xff1a;npm run build 之后出现 2. 翻译之后意思就是某块过大 3. 解决办法&#xff1a;在vite.config.ts文件上添加 build: { chunkSizeWarningLimit: 1600, }, 4.最终打包

部署运行petalinux系统镜像

参考文档《编译 petalinux 系统镜像》编译获取 petalinux 系统镜像&#xff0c;编译生成的各种镜像文件如下&#xff1a; scilogyhunterubuntu1804:~/petalinux/workspace/project0/petalinux$ ls images/linux/ bl31.bin Image pxelinux.cfg rootfs.cpio.gz.u-boot …

1967python多媒体素材管理系统mysql数据库Django结构layUI布局计算机软件工程网页

一、源码特点 python Django多媒体素材管理系统是一套完善的web设计系统mysql数据库 &#xff0c;对理解python编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 开发环境pycharm mysql 5.0 到5.5 依赖包 Dj…

postman调用Grpc

环境&#xff1a; .net6.0 一、准备 安装nuget&#xff1a; Grpc.AspNetCore Google.Protobuf Grpc.Core.Api Grpc.Tools Grpc.AspNetCore.Server.Reflection Program.cs&#xff1a; public class Program{public static void Main(string[] args){var builder WebApplicat…

【Matlab函数分析】绘图函数:colormap查看并设置当前颜色图

&#x1f517; 运行环境&#xff1a;Matlab &#x1f6a9; 撰写作者&#xff1a;左手の明天 &#x1f947; 精选专栏&#xff1a;《python》 &#x1f525; 推荐专栏&#xff1a;《算法研究》 #### 防伪水印——左手の明天 #### &#x1f497; 大家好&#x1f917;&#x1f91…

【PB案例学习笔记】-10 进度条使用

写在前面 这是PB案例学习笔记系列文章的第10篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

2.5D的架构图相比3D有五大不可替代优势

2.5D架构图是一种介于2D和3D之间的图形表现形式&#xff0c;具有以下几个优势&#xff1a; 省时省力&#xff1a;相比于完全的3D架构图&#xff0c;2.5D架构图的制作相对简单&#xff0c;可以节省制作时间和人力成本。它只需要在平面上进行设计和绘制&#xff0c;不需要考虑3D…

域提权漏洞系列分析-Zerologon漏洞分析2

漏洞点⼆&#xff1a;错误设置CFB8模式 建⽴安全通道时&#xff0c;需要使⽤ComputeNetlogonCredential函数对客户端的Netlogon凭据输⼊client challenge和服 务器的Netlogon凭据输⼊server challenge (SC&#xff09;进⾏加密&#xff0c;ComputeNetlogonCredential函数⽀持使…

飞控如何和接收机接线?

飞控如何和接收机接线&#xff1f; 一般遥控都是按照顺序1对1接到飞控的INPUT端口。特别注意&#xff0c;华科尔的接收机&#xff0c;需要把1和2通道条换过来。 具体方法如下&#xff1a; 下面以MC6遥控接收机为例子&#xff1a; 用下面的图的接收机连接线来演示&#xff1a…

【C++】开源:RabbitMQ安装与配置使用(SimpleAmqpClient)

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍。 无专精则不能成&#xff0c;无涉猎则不能通。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c;下次更新不迷路&#x1…

隆道专属商城 | 助力企业跨平台整合优势资源,解决采购寻源比价难题!

数字化采购时代&#xff0c;企业面临着日益激烈的市场竞争&#xff0c;如何优化资源配置、降低采购成本、提高采购效率成为企业追求的核心目标。当前&#xff0c;网上商城凭借其强大的供应链资源整合能力&#xff0c;为企业内部采购商城的搭建提供了独特的优势&#xff0c;已然…

常见SSL证书品牌关系图

常见SSL证书品牌关系图 在SSL证书市场上&#xff0c;有几个主要的品牌和他们之间的复杂关系。以下是一些主要的SSL证书提供商及其关系的简要概述&#xff1a; DigiCert&#xff1a; DigiCert 是最大的SSL证书颁发机构之一。它收购了Symantec的SSL和PKI业务&#xff0c;其中包括…