Moonshot AI 新突破:MoBA 为大语言模型长文本处理提效论文速读

前言

在自然语言处理领域,随着大语言模型(LLMs)不断拓展其阅读、理解和生成文本的能力,如何高效处理长文本成为一项关键挑战。近日,Moonshot AI Research 联合清华大学、浙江大学的研究人员提出了一种创新方法 —— 混合块注意力机制(Mixture of Block Attention,MoBA),它将专家混合(Mixture of Experts,MoE)原理应用于注意力机制,为解决长文本处理难题带来了新的思路。

在 Transformer 架构广泛应用的当下,其注意力机制存在明显弊端。在处理长文本时,传统注意力机制需将每个 token 与其他所有 token 进行比较,这使得计算成本随序列长度呈二次方增长。当模型处理长篇文档、多章书籍、法律简报或大型代码库等包含大量文本信息的任务时,这种计算成本会变得难以承受。此前,为解决这一问题,研究人员尝试过多种方法。例如,滑动窗口机制将 token 限制在局部邻域内,虽降低了计算量,但会忽略重要的全局关系;而一些彻底改变基本架构的方法,如用全新结构替代 softmax 注意力机制,往往需要从头开始重新训练模型,难以利用现有的预训练成果。

核心原理

MoBA 的出现有效弥补了上述方法的不足。它的核心在于将输入划分为易于管理的 “块”,并借助可训练的门控系统来确定每个查询 token 相关的块。这种设计遵循 “少结构” 原则,不预先定义哪些 token 应该相互作用,而是由学习到的门控网络做出决策。与固定结构或近似处理的方法不同,MoBA 能让模型自主学习注意力的聚焦点。而且,MoBA 可与现有的基于 Transformer 的模型无缝协作,它作为一种 “插件” 或替代方案,保持与原模型相同的参数数量,避免架构膨胀,同时保留因果掩码,确保自回归生成的准确性。在实际应用中,MoBA 能在稀疏注意力和全注意力之间灵活切换。处理超长输入时,稀疏注意力可提升速度;而在训练的某些层或阶段,若需要全注意力,模型也能切换回标准模式。

从技术细节来看,MoBA 将上下文划分为多个块,每个块包含连续的 token 序列。门控机制通过比较查询 token 与块的池化键表示,计算查询 token 与每个块之间的 “亲和度” 分数,然后选择得分最高的块。这样,只有最相关块中的 token 才会对最终的注意力分布产生影响。同时,包含查询 token 本身的块始终被纳入,以确保局部上下文信息可访问。并且,MoBA 执行因果掩码,防止 token 关注未来位置,维持从左到右的自回归属性。这种基于块的方法大幅减少了 token 比较次数,使计算规模低于二次方,随着上下文长度增加到数十万甚至数百万个 token,效率提升愈发显著。此外,MoBA 与现代加速器和专用内核兼容性良好。研究人员将 MoBA 与 FlashAttention(一种高性能的快速、内存高效的精确注意力库)相结合,根据所选块对查询 - 键 - 值操作进行精心分组,进一步优化了计算流程。实验数据显示,在处理一百万个 token 时,MoBA 相比传统全注意力机制速度提升约 6 倍,凸显了其在实际应用中的优势。

在性能测试方面,MoBA 表现出色。技术报告显示,在多种任务中,MoBA 的性能与全注意力机制相当,但在处理长序列时可显著节省计算资源。在语言建模数据测试中,当序列长度为 8192 或 32768 个 token 时,MoBA 的困惑度与全注意力 Transformer 相近。更为关键的是,当研究人员将上下文长度逐渐扩展到 128000 及更长时,MoBA 仍能保持强大的长上下文理解能力。在 “尾随 token” 评估中,MoBA 能够有效处理长提示末尾附近的 token 预测任务,且预测质量没有明显下降。研究人员还对 MoBA 的块大小和门控策略进行了敏感性探索。实验表明,细化粒度(使用更小的块但选择更多的块)有助于模型更接近全注意力的效果。即使在忽略大部分上下文的情况下,自适应门控也能识别与查询真正相关的块。此外,“混合” 模式展现出一种平衡策略:部分层继续使用 MoBA 提升速度,少数层则恢复全注意力。这种混合方法在监督微调任务中尤为有益,例如当输入中的某些位置在训练目标中被屏蔽时,保留少数上层的全注意力,可使模型保持广泛的上下文覆盖,有助于需要全局视角的任务。

关键代码分析:

以下是对 MoBA 库关键代码 MixedAttention 类的分析以及关键代码的摘录与注释:

整体分析

MixedAttention 类是一个自定义的 torch.autograd.Function,用于实现混合块注意力机制。这个类主要包含两个静态方法:forward 和 backward,分别用于前向传播和反向传播。

class MixedAttention(torch.autograd.Function):

    # 前向传播函数
    @staticmethod
    def forward(
        ctx,
        q,  # 查询张量
        k,  # 键张量
        v,  # 值张量
        self_attn_cu_seqlen,  # 自注意力累积序列长度
        moba_q,  # MoBA 查询张量
        moba_kv,  # MoBA 键值张量
        moba_cu_seqlen_q,  # MoBA 查询累积序列长度
        moba_cu_seqlen_kv,  # MoBA 键值累积序列长度
        max_seqlen,  # 最大序列长度
        moba_chunk_size,  # MoBA 块大小
        moba_q_sh_indices,  # MoBA 查询块索引
    ):
        # 保存一些参数,用于后续的反向传播
        ctx.max_seqlen = max_seqlen
        ctx.moba_chunk_size = moba_chunk_size
        ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5)

        # 自注意力计算
        _, _, _, _, self_attn_out_sh, self_attn_lse_hs, _, _ = (
            _flash_attn_varlen_forward(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=self_attn_cu_seqlen,
                cu_seqlens_k=self_attn_cu_seqlen,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                softmax_scale=softmax_scale,
                causal=True,
                dropout_p=0.0,
            )
        )

        # MoBA 注意力计算
        _, _, _, _, moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward(
            q=moba_q,
            k=moba_kv[:, 0],
            v=moba_kv[:, 1],
            cu_seqlens_q=moba_cu_seqlen_q,
            cu_seqlens_k=moba_cu_seqlen_kv,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=moba_chunk_size,
            softmax_scale=softmax_scale,
            causal=False,
            dropout_p=0.0,
        )

        # 转换 lse 形状,从 hs 转换为 sh(遵循传统混合注意力逻辑)
        self_attn_lse_sh = self_attn_lse_hs.t().contiguous()
        moba_attn_lse = moba_attn_lse_hs.t().contiguous()

        # 初始化输出缓冲区,形状与 q 相同
        output = torch.zeros(
            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )

        # 将输出张量展平为二维,便于后续索引操作
        output_2d = output.view(-1, q.shape[2])

        # 计算混合 lse
        # 减去最大 lse 以避免指数爆炸
        max_lse_1d = self_attn_lse_sh.view(-1)
        max_lse_1d = max_lse_1d.index_reduce(
            0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax"
        )
        self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)
        moba_attn_lse = (
            moba_attn_lse.view(-1)
            .sub(max_lse_1d.index_select(0, moba_q_sh_indices))
            .reshape_as(moba_attn_lse)
        )

        # 计算自注意力和 MoBA 注意力的 softmax 结果
        mixed_attn_se_sh = self_attn_lse_sh.exp()
        moba_attn_se = moba_attn_lse.exp()

        # 将 MoBA 注意力结果累加到自注意力结果上
        mixed_attn_se_sh.view(-1).index_add_(
            0, moba_q_sh_indices, moba_attn_se.view(-1)
        )
        mixed_attn_lse_sh = mixed_attn_se_sh.log()

        # 加权自注意力输出
        factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp()  # [ vS, H ]
        self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1)
        output_2d += self_attn_out_sh.reshape_as(output_2d)

        # 加权 MoBA 输出
        mixed_attn_lse = (
            mixed_attn_lse_sh.view(-1)
            .index_select(0, moba_q_sh_indices)
            .view_as(moba_attn_lse)
        )
        factor = (moba_attn_lse - mixed_attn_lse).exp()  # [ vS, H ]
        moba_attn_out = moba_attn_out * factor.unsqueeze(-1)
        raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1])
        output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)

        # 将输出转换为与输入相同的数据类型
        output = output.to(q.dtype)

        # 恢复最大 lse
        mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)

        # 保存中间结果,用于反向传播
        ctx.save_for_backward(
            output,
            mixed_attn_lse_sh,
            q,
            k,
            v,
            self_attn_cu_seqlen,
            moba_q,
            moba_kv,
            moba_cu_seqlen_q,
            moba_cu_seqlen_kv,
            moba_q_sh_indices,
        )

        return output

    # 反向传播函数
    @staticmethod
    def backward(ctx, d_output):
        # 从上下文中获取保存的参数
        max_seqlen = ctx.max_seqlen
        moba_chunk_size = ctx.moba_chunk_size
        softmax_scale = ctx.softmax_scale

        (
            output,
            mixed_attn_vlse_sh,
            q,
            k,
            v,
            self_attn_cu_seqlen,
            moba_q,
            moba_kv,
            moba_cu_seqlen_q,
            moba_cu_seqlen_kv,
            moba_q_sh_indices,
        ) = ctx.saved_tensors

        # 确保输入梯度连续
        d_output = d_output.contiguous()

        # 计算自注意力的梯度
        dq, dk, dv, _ = _flash_attn_varlen_backward(
            dout=d_output,
            q=q,
            k=k,
            v=v,
            out=output,
            softmax_lse=mixed_attn_vlse_sh.t().contiguous(),
            dq=None,
            dk=None,
            dv=None,
            cu_seqlens_q=self_attn_cu_seqlen,
            cu_seqlens_k=self_attn_cu_seqlen,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            softmax_scale=softmax_scale,
            causal=True,
            dropout_p=0.0,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            deterministic=True,
        )

        # 计算 MoBA 注意力的梯度
        headdim = q.shape[-1]
        d_moba_output = (
            d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)
        )
        moba_output = (
            output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)
        )

        mixed_attn_vlse = (
            mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1)
        )

        dmq, dmk, dmv, _ = _flash_attn_varlen_backward(
            dout=d_moba_output,
            q=moba_q,
            k=moba_kv[:, 0],
            v=moba_kv[:, 1],
            out=moba_output,
            softmax_lse=mixed_attn_vlse,
            dq=None,
            dk=None,
            dv=None,
            cu_seqlens_q=moba_cu_seqlen_q,
            cu_seqlens_k=moba_cu_seqlen_kv,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=moba_chunk_size,
            softmax_scale=softmax_scale,
            causal=False,
            dropout_p=0.0,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            deterministic=True,
        )

        # 合并 MoBA 的键和值的梯度
        dmkv = torch.stack((dmk, dmv), dim=1)

        return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None

代码关键部分解释

  • 前向传播 (forward)

    • 分别计算自注意力和 MoBA 注意力的结果。
    • 对注意力分数进行处理,包括形状转换、归一化等操作,以避免指数爆炸。
    • 将自注意力和 MoBA 注意力的结果进行加权合并,得到最终的输出。
    • 保存中间结果,用于后续的反向传播。
  • 反向传播 (backward)

    • 根据前向传播保存的中间结果,计算自注意力和 MoBA 注意力的梯度。
    • 最终返回各个输入张量的梯度。

小结

通过这种方式,MixedAttention 类实现了 MoBA 混合块注意力机制,通过将上下文划分为块并进行选择性的注意力计算,有效减少了计算量,提升了处理长文本的效率。

总结

总体而言,MoBA 非常适合处理涉及大量上下文的任务,如长篇文档阅读理解、大规模代码补全以及需要完整对话历史的多轮对话系统。它在提高效率的同时,性能损失极小,为大规模训练大语言模型提供了一种极具吸引力的方法。虽然目前 MoBA 主要应用于文本领域,但研究人员认为,其底层机制在其他数据模态中也具有应用潜力。只要序列长度足够长,引发计算或内存问题,将查询分配给块 “专家” 的思路就有望缓解瓶颈,同时保持处理关键全局依赖关系的能力。随着语言应用中的序列长度持续增长,像 MoBA 这样的方法可能会在推动神经语言建模的可扩展性和成本效益方面发挥关键作用,为人工智能的发展注入新的活力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/974945.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大语言模型推理能力从何而来?

前言 DeepSeek R1采用强化学习进行后训练,通过奖励机制和规则引导模型生成结构化思维链(CoT),从而显著提升了推理能力。这一创新方法使得DeepSeek R1能够在无需大量监督数据的情况下,通过自我进化发展出强大的推理能力…

最新本地部署 DeepSeekR1 蒸馏\满血量化版 + WebOpenUI 完整教程(Ubuntu\Linux系统\Ollama)

测试机为6133CPU(40Cores)256G D44*4090D 24G 一种方法是部署蒸馏版Distill模型。一种是部署Huggingface上unsloth的量化版模型 Ollama及模型安装 1.下载并安装ollama curl -fsSL https://ollama.com/install.sh | sh如果下载不动可以试试挂梯子或者再试几次 挂代理代码&…

PySide6学习专栏(四):用多线程完成复杂计算任务

如果计程序中要处理一个非常庞大的数据集中的数据,且数据处理计算很复杂,造成数据处理占用大量时间和CPU资源,如果不用多线程,仅在主进程中来处理数据,将会使整个程序卡死,必须采用多线程来处理这些数据是唯…

路由基本配置

学习目标 • 根据拓扑图进行网络布线。 • 清除启动配置并将路由器重新加载为默认状态。 • 在路由器上执行基本配置任务。 • 配置并激活以太网接口。 • 测试并检验配置。 • 思考网络实施方案并整理成文档。 任务 1:网络布线 使用适当的电缆类型连接网络设备。…

STM32MP157A单片机移植Linux驱动深入版

需求整理 在Linux设备树中新增leds节点&#xff0c;其有3个gpio属性&#xff0c;分别表示PE10对应led1&#xff0c;PF10对应led2&#xff0c;PE8对应led3&#xff0c;设备树键值对如下&#xff1a; leds { led1-gpio <&gpioe 10 0>; led2-gpio &l…

瑞芯微RV1126部署YOLOv8全流程:环境搭建、pt-onnx-rknn模型转换、C++推理代码、错误解决、优化、交叉编译第三方库

目录 1 环境搭建 2 交叉编译opencv 3 模型训练 4 模型转换 4.1 pt模型转onnx模型 4.2 onnx模型转rknn模型 4.2.1 安装rknn-toolkit 4.2.2 onn转成rknn模型 5 升级npu驱动 6 C++推理源码demo 6.1 原版demo 6.2 增加opencv读取图片的代码 7 交叉编译x264 ffmepg和op…

如何为自己的 PDF 文件添加密码?在线加密 PDF 文件其实更简单

随着信息泄露和数据安全问题的日益突出&#xff0c;保护敏感信息变得尤为重要。加密 PDF 文件是一种有效的手段&#xff0c;可以确保只有授权用户才能访问或修改文档内容。本文将详细介绍如何使用 CleverPDF 在线工具为你的 PDF 文件添加密码保护&#xff0c;确保其安全性。 为…

蓝桥杯核心内容

核心内容 数学 质数与筛质数&#xff0c;分解质因数 分解质因数 所有的数都可以写成有限个数相乘质数&#xff1a;可以写成1✖本身&#xff08;如131✖13&#xff09;合数&#xff1a;ab1✖...✖bn-》把乘数里面是合数的再分&#xff08;如b3是合数-》b3c1✖c2&#xff09;进…

七星棋牌源码高阶技术指南:6端互通、200+子游戏玩法深度剖析与企业级搭建实战(完全开源)

在棋牌游戏行业高速发展的今天&#xff0c;如何构建一个具备高并发、强稳定性与多功能支持的棋牌游戏系统成为众多开发者和运营团队关注的焦点。七星棋牌全开源修复版源码 凭借其 六端互通、200子游戏玩法、多省区本地化支持&#xff0c;以及 乐豆系统、防沉迷、比赛场、AI智能…

【学习笔记】【SpringCloud】MybatisPlus 基础使用

目录 一、使用 MybatisPlus 基本步骤 1. 引入 MybatisPlus 依赖 2. 定义Mapper接口并继承BaseMapper 二、MybatisPlus 常用配置 三、自定义SQL 四、IService 接口 1. 批量新增的效率问题 2. 配置方式 五、插件功能 1. 分页插件 一、使用 MybatisPlus 基本步骤 1. 引…

QT 引入Quazip和Zlib源码工程到项目中,无需编译成库,跨平台,压缩进度

前言 最近在做项目时遇到一个需求&#xff0c;需要将升级的文件压缩成zip&#xff0c;再进行传输&#xff1b; 通过网络调研&#xff0c;有许多方式可以实现&#xff0c;例如QT私有模块的ZipReader、QZipWriter&#xff1b;或者第三方库zlib或者libzip或者quazip等&#xff1…

在高流量下保持WordPress网站的稳定和高效运行

随着流量的不断增加&#xff0c;网站的稳定和高效运行变得越来越重要&#xff0c;特别是使用WordPress搭建的网站。流量过高时&#xff0c;网站加载可能会变慢&#xff0c;甚至崩溃&#xff0c;直接影响用户体验和网站正常运营。因此&#xff0c;我们需要采取一些有效的措施&am…

linux 安装启动zookeeper全过程及遇到的坑

1、下载安装zookeeper 参考文章&#xff1a;https://blog.csdn.net/weixin_48887095/article/details/132397448 2、启动失败 1、启动失败JAVA_HOME is not set and java could not be found in PATH 已安装 JAVA 配置了JAVA_HOME,还是报错解决方法&#xff1a;参考&#xf…

投资组合风险管理

投资组合风险管理 市场风险 信用风险流动性风险风险指标收益率波动率最大回撤 α \alpha α&#xff08;詹森指数&#xff09;&#xff0c; β \beta β卡玛比率月胜率上/下行捕获比夏普比率索提诺比率经风险调整的收益率&#xff08;&#x1d440;2&#xff09;特雷诺比率信息…

MySQL八股学习笔记

文章目录 一、MySQL结构1.宏观结构1.1.Server层1.2.存储引擎层 2.建立链接-连接器3.查询缓存4.解析SQL-解析器&#xff08;1&#xff09;词法分析&#xff08;2&#xff09;语法分析 5.执行SQL5.1.预处理器 prepare5.2.优化器 optimize5.3.执行器 execute&#xff08;1&#xf…

在windows下安装windows+Ubuntu16.04双系统(下)

这篇文章的内容主要来源于这篇文章&#xff0c;为正式安装windowsUbuntu16.04双系统部分。在正式安装前&#xff0c;若还没有进行前期准备工作&#xff08;1.分区2.制作启动u盘&#xff09;&#xff0c;见《在windows下安装windowsUbuntu16.04双系统(上)》 二、正式安装Ubuntu …

自然语言处理NLP 04案例——苏宁易购优质评论与差评分析

上一篇文章&#xff0c;我们爬取了苏宁易购平台某产品的优质评价和差评&#xff0c;今天我们对优质评价与差评进行分析 selenium爬取苏宁易购平台某产品的评论-CSDN博客 目录 1. 数据加载 2. 中文分词 3. 停用词处理 4. 数据标注与合并 5. 数据集划分 6. 文本特征提取 …

最新版本Exoplayer扩展FFmpeg音频软解码保姆级教程

ExoPlayer 是一个开源的 Android 媒体播放库&#xff0c;由 Google 开发和维护&#xff0c;用于替代 Android 系统自带的 MediaPlayer。它提供了更强大的功能、更好的性能和更高的灵活性&#xff0c;适用于各种复杂的媒体播放场景。所以被广泛用于各种播放器场景。 最近项目中…

华为昇腾910b服务器部署DeepSeek翻车现场

最近到祸一台HUAWEI Kunpeng 920 5250&#xff0c;先看看配置。之前是部署的讯飞大模型&#xff0c;发现资源利用率太低了。把5台减少到3台&#xff0c;就出了他 硬件配置信息 基本硬件信息 按照惯例先来看看配置。一共3块盘&#xff0c;500G的系统盘&#xff0c; 2块3T固态…

【操作系统】操作系统概述

操作系统概述 1.1 操作系统的概念1.1.1 操作系统定义——什么是OS&#xff1f;1.1.2 操作系统作用——OS有什么用&#xff1f;1.1.3 操作系统地位——计算机系统中&#xff0c;OS处于什么地位&#xff1f;1.1.4 为什么学操作系统&#xff1f; 1.2 操作系统的历史1.2.1 操作系统…