DeepSeek V3 源码:从入门到放弃!

从入门到放弃

花了几天时间,看懂了DeepSeek V3 源码的逻辑。源码的逻辑是不难的,但为什么模型结构需要这样设计,为什么参数需要这样设置呢?知其然,但不知其所以然。除了模型结构以外,模型的训练数据、训练脚本和训练经验,也是DeepSeek V3能够训练出来的关键,但这些是DeepSeek母公司的核心机密,我们无从得知。
因此,看懂了源码,算是入门了DeepSeek V3,因为没有条件知道更多重要细节,因此不得不放弃重现整个模型的训练。

Paper 和源码

Paper URL: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
Code URL: https://github.com/deepseek-ai/DeepSeek-V3

模型逻辑

下面这张图,代表了DeepSeek的核心逻辑。左边是Transformer的逻辑结构,可以认为有N个左边这样的Block结构不断重复,组成Transformer模型。每个Block中,分成两个部分,Attention 和 Feed-Forward Network。对这两个部分使用不同的网络结构,我们就得到了不同的模型。
DeepSeek V3 的 Attention 用的是 Multi-Head Latent Attention(MLA) ,Feed-Forward Network 用的是DeepSeekMoE。
在这里插入图片描述

MLA

Multi-Head Latent Attention(MLA)即多头潜在注意力,是DeepSeek模型中引入的一种创新注意力机制,旨在优化传统多头注意力(Multi-Head Attention,MHA)的计算效率和内存占用。具体介绍如下:

核心创新点

  • 低秩键值压缩
    • KV的低秩压缩:不直接存储原始的Key和Value,而是先将隐藏状态投影到一个更小的压缩潜在向量。在推理时,只需缓存该压缩潜在向量,而不是完整的Key和Value,从而大大降低了KV缓存的存储需求。
    • Query的低秩压缩:对Query也进行低秩压缩,虽然不会减少KV缓存的大小,但可以减少训练时的激活存储需求,进而降低计算成本。
  • 解耦旋转位置嵌入(RoPE)
    • 额外引入“解耦查询”:将查询拆分为两个部分,一部分不经过RoPE变换,代表非位置敏感的特征信息;另一部分专门用于嵌入RoPE位置编码信息。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,减少了计算开销,也减小了KV缓存大小,降低了GPU内存占用,提高了推理速度,特别适用于长序列任务和大规模Transformer。

推理过程中的优化

将上投影矩阵吸收到里面,简化查询计算,并优化注意力分数的计算,减少了计算步骤,提升了计算效率。避免了先计算Value向量,减少了矩阵运算的开销,使推理更快。

整体优势

  • 降低内存占用:通过对键值进行低秩联合压缩以及解耦RoPE等策略,显著减少了KV缓存的存储需求,降低了GPU内存占用。
  • 提高计算效率:减少了训练和推理过程中的计算量,加快了模型的推理速度,在保持甚至提高模型性能的同时,提升了模型的运行效率。
  • 增强模型适应性:特别适用于长序列任务和大规模Transformer模型,能够更好地处理长序列输入,提高模型在各种自然语言处理任务中的表现。

MLA 有物理意义吗?

Multi-Head Latent Attention(MLA)能够起作用主要源于其独特的技术设计,在数学和信息处理层面有清晰的逻辑,不过它是一种抽象的算法概念,并不直接对应具体的物理意义,以下是对其作用原理的分析:

起作用的原因

  • 低秩压缩的有效性
    • 信息浓缩与降噪:通过低秩键值压缩,MLA将高维的Key和Value信息投影到低维的潜在向量空间,这一过程类似于对原始信息进行浓缩,提取出最关键、最具代表性的特征,去除了一些可能的噪声和冗余信息,使得模型能够更聚焦于重要信息,从而提高信息处理的效率和准确性。
    • 减少计算量和存储需求:低秩压缩大大降低了数据的维度,减少了模型训练和推理过程中的计算量和存储需求,使得模型能够更高效地运行,尤其是在处理大规模数据和长序列数据时,这种优势更为明显。
  • 解耦旋转位置嵌入的优势
    • 位置信息与内容信息的分离:传统的位置编码方式将位置信息和内容信息混合在一起进行处理,而MLA的解耦旋转位置嵌入将查询拆分为位置敏感和非位置敏感两部分,使模型能够更清晰地分离和处理位置信息与内容信息,更好地捕捉文本中的长距离依赖关系。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,不仅减少了计算开销,还使得模型能够从更宏观的角度利用位置信息,增强了模型对序列数据整体结构的理解和把握能力。
  • 多头机制的协同作用
    • 捕捉多维度信息:MLA中的多头机制允许模型同时从多个不同的角度和维度去捕捉输入数据中的信息,每个头可以关注到输入序列的不同方面,通过多个头的并行计算和协同工作,模型能够更全面、更深入地理解输入数据,提高模型的表示能力和泛化能力。

难以直接赋予物理意义的原因及近似理解

  • 抽象的算法概念:MLA是一种基于数学和计算机科学的算法概念,主要用于处理和分析数据中的模式和关系,它不像物理概念那样具有直接可观测的物理实体或现象与之对应,更多地是在数据空间和计算逻辑中发挥作用。
  • 类比物理现象理解:可以进行一些类比来帮助理解。比如低秩压缩类似于物理中的能量聚集,将分散的能量(信息)聚集到关键的“点”上;解耦旋转位置嵌入有点像物理中对不同性质力的分解,将位置信息和内容信息这两种“力”分开处理;多头机制如同多个物理传感器从不同方向和角度对环境进行感知,然后综合这些感知信息来对整个系统进行理解和判断。

DeepSeekMoE

DeepSeekMoE是由深度求索(DeepSeek)研发的基于混合专家系统(Mixture of Experts,MoE)的技术架构,以下是具体介绍:

架构原理

  • 混合专家系统核心:采用MoE架构,核心在于通过动态路由机制,把输入数据分配给最相关的专家处理。比如在自然语言处理中,有的专家专门处理情感分析,有的处理主题建模。
  • 结合多头潜在注意力机制:与MLA相结合,MLA通过引入潜在向量,减少键值缓存(KV cache)需求,提升推理效率。
  • Transformer架构基础:以Transformer架构为基础,每个Transformer块由一个注意力模块和一个前馈网络(FFN)组成,在注意力机制和FFN方面采用创新架构。

技术优势

  • 降低算力需求:MoE的动态分配机制和MLA减少KV缓存需求等特点,使模型在训练和推理时对算力的要求降低。
  • 保持高性能:在参数量减少的情况下仍能保持高性能,例如DeepSeek-V2以236B总参数、21B激活,大致可以达到70B-110B Dense的模型能力。
  • 减少计算量:自研Sparse结构DeepSeekMoE进一步降低了计算量。
  • 长上下文理解能力强:支持超100万token的上下文窗口,显著优于行业平均水平,适用于长文档分析、代码开发等复杂场景的连贯交互。

DeepSeekMoE的物理意义是什么?

DeepSeekMoE作为一种人工智能技术架构,没有严格意义上的物理意义,但可以从一些角度进行类比和理解:

从系统资源分配角度

  • 资源按需分配类比:可以将DeepSeekMoE的专家网络和动态路由机制类比为一个智能电力分配系统。在这个系统中,不同的电器设备(任务)需要不同的电量(计算资源)来运行。专家网络就像不同功率的发电机,而动态路由机制则像是智能电表和分配器,它会根据每个电器设备的实际需求,将电力(计算资源)精准地分配给需要的设备,避免了资源的浪费,提高了整个系统的能源利用效率。
  • 负载均衡类比:类似于在一个大型物流中心,不同的仓库区域(专家)负责存储和处理不同类型的货物(数据)。当有货物运输任务时,调度系统(动态路由)会根据货物的特点和仓库的负载情况,合理地安排货物存储到哪个仓库,确保每个仓库都能在其承载能力范围内高效运作,不会出现某个仓库过度拥挤而其他仓库闲置的情况,实现了负载均衡,提高了物流中心的整体运营效率。

从信息处理角度

  • 多维度信息处理类比:可以把DeepSeekMoE处理信息的过程想象成一个由多个不同专业的侦探(专家)组成的侦探团队在调查一个复杂案件。每个侦探都有自己独特的专业技能和视角,比如有的擅长调查线索,有的擅长分析人物关系,有的擅长破解密码等。当面对案件(输入数据)时,队长(路由器)会根据案件的具体情况,分配合适的侦探去处理相应的部分,最后将各个侦探的调查结果综合起来,形成对整个案件的全面了解和判断,从而更高效地解决复杂问题。
  • 特征提取与融合类比:如同在一个化学实验中,不同的化学试剂(专家)可以与不同的物质发生反应,提取出特定的化学特征。DeepSeekMoE中的专家网络就像这些化学试剂,它们各自对输入数据进行处理,提取出不同的特征。然后通过融合机制,将这些特征像混合化学物质一样进行整合,得到更全面、更有价值的信息,用于后续的分析和决策。

从模型架构角度

  • 积木搭建类比:把DeepSeekMoE的架构比作搭建积木。每个专家网络就像不同形状和功能的积木块,有的积木块负责搭建基础结构,有的负责构建上层建筑,有的负责添加装饰等。路由器则像是搭建者的手,根据要搭建的目标模型的需求,选择合适的积木块进行组合,最终搭建出一个复杂而功能强大的模型结构,实现对各种自然语言处理任务的高效处理。
  • 人体神经系统类比:可以将DeepSeekMoE类比为人体的神经系统。专家网络类似于人体的不同神经细胞或神经中枢,它们各自负责处理特定类型的信息,如视觉神经细胞负责处理视觉信息,听觉神经细胞负责处理听觉信息等。路由器就像神经系统中的神经递质或信号传导机制,它负责将外界的刺激信号(输入数据)准确地传递给相应的神经细胞,并将各个神经细胞处理后的信号进行整合和传递,使人体能够做出协调的反应和决策,实现对外部世界的感知和交互。

代码逻辑

整体 - Transformer

下面这段代码是典型的 Transformer 实现,核心可以看 forward 函数逻辑:

  1. 进行 Embeding;
  2. 经过各个 Block;
  3. 归一化并输出。
    对应的代码:
# 通过嵌入层将输入标记转换为向量表示
h = self.embed(tokens)
# 依次通过每个Transformer块进行处理
for layer in self.layers:
    h = layer(h, start_pos, freqs_cis, mask)
# 对输出进行层归一化,并取最后一个时间步的输出
h = self.norm(h)[:, -1]
# 通过输出投影层得到对数概率
logits = self.head(h)

完整代码:

# 定义Transformer类,继承自PyTorch的nn.Module类
class Transformer(Module):
    """
    Transformer模型,包含位置嵌入、多个层以及输出投影。

    属性:
        max_seq_len (int): Transformer允许的最大序列长度。
        embed (nn.Module): 用于输入标记的嵌入层,将输入的标记转换为向量表示。
        layers (torch.nn.ModuleList): 存储多个Transformer块的列表,每个块包含多头注意力和前馈网络。
        norm (nn.Module): 层归一化层,在所有Transformer块之后应用,用于稳定训练。
        head (nn.Module): 输出投影层,将模型的输出映射到词汇表大小,用于预测下一个标记。
        freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。
    """
    def __init__(self, args):
        """
        初始化Transformer模型。

        参数:
            args: 模型参数对象,包含Transformer的各种参数,如词汇表大小、维度、层数等。
        """
        # 获取全局变量world_size和rank,分别表示分布式训练中的进程总数和当前进程的编号
        global world_size, rank
        # 如果分布式训练已初始化,则获取进程总数,否则默认为1
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        # 如果分布式训练已初始化,则获取当前进程编号,否则默认为0
        rank = dist.get_rank() if dist.is_initialized() else 0
        # 根据参数设置线性层的数据类型
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        # 调用父类的初始化方法
        super().__init__()
        # 保存最大序列长度
        self.max_seq_len = args.max_seq_len
        # 初始化嵌入层,将输入标记转换为向量表示
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        # 初始化一个空的ModuleList,用于存储Transformer块
        self.layers = torch.nn.ModuleList()
        # 循环创建指定数量的Transformer块,并添加到layers列表中
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        # 初始化层归一化层
        self.norm = RMSNorm(args.dim)
        # 初始化输出投影层,将模型的输出映射到词汇表大小
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        # 预计算旋转位置嵌入所需的复指数值,并将其注册为缓冲区,不参与模型参数的更新
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens, start_pos=0):
        """
        Transformer模型的前向传播过程。

        参数:
            tokens (torch.Tensor): 输入的标记ID张量,形状为 (batch_size, seq_len)。
            start_pos (int, 可选): 旋转位置嵌入的起始位置,默认为0。

        返回:
            torch.Tensor: 对数概率张量,形状为 (batch_size, vocab_size),表示每个标记的预测概率。
        """
        # 获取输入序列的长度
        seqlen = tokens.size(1)
        # 通过嵌入层将输入标记转换为向量表示
        h = self.embed(tokens)
        # 从预计算的复指数值中截取当前序列所需的部分
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
        # 初始化掩码为None
        mask = None
        # 如果序列长度大于1,则创建一个上三角掩码,用于屏蔽未来的标记
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        # 依次通过每个Transformer块进行处理
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        # 对输出进行层归一化,并取最后一个时间步的输出
        h = self.norm(h)[:, -1]
        # 通过输出投影层得到对数概率
        logits = self.head(h)
        # 如果使用分布式训练,则收集所有进程的对数概率
        if world_size > 1:
            # 创建一个列表,用于存储所有进程的对数概率
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            # 收集所有进程的对数概率
            dist.all_gather(all_logits, logits)
            # 将所有进程的对数概率拼接在一起
            logits = torch.cat(all_logits, dim=-1)
        return logits

单个 - Block

在这里插入图片描述
核心代码非常简单MLA(attention) + MOE(Feed-Forward Network):

# 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
# 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接
x = x + self.ffn(self.ffn_norm(x))

全部代码:

# 定义一个Transformer块类,继承自PyTorch的nn.Module类
class Block(Module):
    """
    Transformer块,结合了注意力层和前馈网络层。

    属性:
        attn (nn.Module): 注意力层(采用多头潜在注意力机制,即MLA),用于捕捉输入序列中不同位置之间的依赖关系。
        ffn (nn.Module): 前馈网络层(可以是多层感知机MLP或者混合专家模型MoE),对注意力层的输出进行非线性变换。
        attn_norm (nn.Module): 用于注意力层的层归一化层,对输入到注意力层的数据进行归一化处理,稳定训练过程。
        ffn_norm (nn.Module): 用于前馈网络层的层归一化层,对输入到前馈网络层的数据进行归一化处理。
    """
    def __init__(self, layer_id, args):
        """
        初始化Transformer块。

        参数:
            layer_id (int): 当前块在Transformer模型中的层索引,用于确定使用哪种前馈网络结构。
            args: 模型参数对象,包含了块的各种参数,如维度、层数等。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 初始化注意力层,使用多头潜在注意力机制(MLA)
        self.attn = MLA(args)
        # 根据当前层的索引来决定使用MLP还是MoE作为前馈网络
        # 如果当前层索引小于密集层的数量,则使用MLP
        # 否则使用混合专家模型(MoE)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        # 初始化用于注意力层的层归一化层
        self.attn_norm = RMSNorm(args.dim)
        # 初始化用于前馈网络层的层归一化层
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x, start_pos, freqs_cis, mask=None):
        """
        Transformer块的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量,包含了序列的特征信息。
            start_pos (int): 序列中的起始位置,用于旋转位置嵌入。
            freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。
            mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置,避免模型关注到不应该关注的信息。

        返回:
            torch.Tensor: 经过当前Transformer块计算后的输出张量。
        """
        # 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        # 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接
        x = x + self.ffn(self.ffn_norm(x))
        return x

Attention 模块

经典的QKV计算公式。解释可以自行搜索,或者参考:Transformer结构和注意力机制
在这里插入图片描述
和传统的QKV相比,可以认为是做了压缩,主要是为了减小 KV Cache。
在这里插入图片描述
代码,就是做了一堆下采样上采样和矩阵的组合变换。最终目的是减少计算量和显存使用量。

# 定义多头注意力层类,继承自PyTorch的nn.Module类
class MLA(Module):
    """
    多头注意力层(MLA)。

    属性:
        dim (int): 输入特征的维度。
        n_heads (int): 注意力头的数量。
        n_local_heads (int): 分布式系统中本地注意力头的数量。
        q_lora_rank (int): 查询(query)的低秩投影的秩。
        kv_lora_rank (int): 键(key)和值(value)的低秩投影的秩。
        qk_nope_head_dim (int): 非位置相关的查询/键投影的维度。
        qk_rope_head_dim (int): 旋转位置编码的查询/键投影的维度。
        qk_head_dim (int): 查询/键投影的总维度。
        v_head_dim (int): 值投影的维度。
        softmax_scale (float): 注意力计算中softmax函数的缩放因子。
    """
    def __init__(self, args):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入特征的维度
        self.dim = args.dim
        # 保存注意力头的数量
        self.n_heads = args.n_heads
        # 计算分布式系统中本地注意力头的数量
        self.n_local_heads = args.n_heads // world_size
        # 保存查询的低秩投影的秩
        self.q_lora_rank = args.q_lora_rank
        # 保存键和值的低秩投影的秩
        self.kv_lora_rank = args.kv_lora_rank
        # 保存非位置相关的查询/键投影的维度
        self.qk_nope_head_dim = args.qk_nope_head_dim
        # 保存旋转位置编码的查询/键投影的维度
        self.qk_rope_head_dim = args.qk_rope_head_dim
        # 计算查询/键投影的总维度
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        # 保存值投影的维度
        self.v_head_dim = args.v_head_dim

        # 如果查询的低秩投影的秩为0,直接使用列并行线性层进行查询投影
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        # 否则,使用低秩分解的方式进行查询投影
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        # 对输入进行线性变换得到键和值的低秩表示
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        # 对键和值的低秩表示进行归一化
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # 对归一化后的键和值进行线性变换
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        # 对多头注意力的输出进行行并行线性变换
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        # 计算softmax函数的缩放因子
        self.softmax_scale = self.qk_head_dim ** -0.5
        # 如果最大序列长度大于原始序列长度,对缩放因子进行调整
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        # 如果注意力实现方式为朴素方式
        if attn_impl == "naive":
            # 注册键缓存
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            # 注册值缓存
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        # 否则
        else:
            # 注册键值缓存
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            # 注册位置编码缓存
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x, start_pos, freqs_cis, mask=None):
        """
        多头注意力层(MLA)的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量,形状为 (batch_size, seq_len, dim)。
            start_pos (int): 序列中用于缓存的起始位置。
            freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置编码。
            mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置。

        返回:
            torch.Tensor: 输出张量,形状与输入相同。
        """
        # 获取输入张量的批次大小、序列长度
        bsz, seqlen, _ = x.size()
        # 计算序列的结束位置
        end_pos = start_pos + seqlen
        # 如果查询的低秩投影的秩为0,直接通过线性层得到查询
        if self.q_lora_rank == 0:
            q = self.wq(x)
        # 否则,通过低秩分解的方式得到查询
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        # 调整查询的形状,将其划分为多个头
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # 将查询划分为非位置相关部分和位置相关部分
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # 对位置相关部分应用旋转位置编码
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        # 通过线性层得到键和值的低秩表示
        kv = self.wkv_a(x)
        # 将键和值的低秩表示划分为低秩部分和位置编码部分
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # 对位置编码部分应用旋转位置编码
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        # 如果注意力实现方式为朴素方式
        if attn_impl == "naive":
            # 将非位置相关部分和位置相关部分拼接得到完整的查询
            q = torch.cat([q_nope, q_pe], dim=-1)
            # 对键和值的低秩表示进行归一化和线性变换
            kv = self.wkv_b(self.kv_norm(kv))
            # 调整键和值的形状,将其划分为多个头
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            # 将键和值划分为非位置相关部分和值部分
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            # 将非位置相关部分和位置编码部分拼接得到完整的键
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            # 将键存入缓存
            self.k_cache[:bsz, start_pos:end_pos] = k
            # 将值存入缓存
            self.v_cache[:bsz, start_pos:end_pos] = v
            # 计算注意力分数
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        # 否则
        else:
            # 获取键和值的线性变换层的权重
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            # 调整权重的形状
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            # 计算非位置相关部分的注意力分数
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            # 将键和值的低秩表示归一化后存入缓存
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            # 将位置编码部分存入缓存
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            # 计算注意力分数
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

        # 如果存在掩码,将掩码加到注意力分数上
        if mask is not None:
            scores += mask.unsqueeze(1)
        # 对注意力分数应用softmax函数
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

        # 如果注意力实现方式为朴素方式
        if attn_impl == "naive":
            # 通过注意力分数和值缓存计算输出
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        # 否则
        else:
            # 通过注意力分数和键值缓存计算中间结果
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            # 通过中间结果和权重计算输出
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        # 对输出进行线性变换
        x = self.wo(x.flatten(2))
        return x

代码解释总结

这段代码定义了一个多头注意力层(MLA)类。在初始化时,根据传入的参数设置各种维度、低秩投影的秩等,并初始化相应的线性层和归一化层,同时根据注意力实现方式注册不同的缓存。在前向传播过程中,对输入进行处理得到查询、键和值,应用旋转位置编码,根据不同的注意力实现方式计算注意力分数,最后通过注意力分数和缓存得到输出并进行线性变换。

Feed-Forward Network

在这里插入图片描述
这个MoE分成两个部分,左边是一些可以分享的专家,就是每次都需要去计算的,右边的是根据分数来选择的。如何选择,是通过一个门控机制来选择。这个门控是如何设计的,代码里有实现,但论文和代码都没有对它物理意义的解释。门控机制,简单来说,就是设计了一个网络,来选出K个候选。

核心逻辑:

  1. 通过门控确定本轮要用到的本地专家:weights, indices = self.gate(x)
  2. 用选择的每个本地专家进行计算:y[idx] += expert(x[idx]) * weights[idx, top, None]
  3. 用共享专家进行计算:z = self.shared_experts(x)
  4. 将本地专家的输出和共享专家的输出相加,并恢复到原始形状: return (y + z).view(shape)

全部代码:

# 定义混合专家(Mixture-of-Experts, MoE)模块类,继承自PyTorch的nn.Module类
class MoE(nn.Module):
    """
    混合专家(Mixture-of-Experts, MoE)模块。

    属性:
        dim (int): 输入特征的维度。
        n_routed_experts (int): 模型中专家的总数。
        n_local_experts (int): 在分布式系统中本地处理的专家数量。
        n_activated_experts (int): 每个输入激活的专家数量。
        gate (nn.Module): 门控机制,用于将输入路由到不同的专家。
        experts (nn.ModuleList): 专家模块列表,包含多个专家网络。
        shared_experts (nn.Module): 共享专家模块,应用于所有输入。
    """
    def __init__(self, args):
        """
        初始化MoE模块。

        参数:
            args: 模型参数对象,包含MoE模块的相关参数。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入特征的维度
        self.dim = args.dim
        # 确保专家总数能被分布式系统中的进程数整除
        assert args.n_routed_experts % world_size == 0, f"专家数量必须能被进程数整除 (进程数={world_size})"
        # 保存模型中专家的总数
        self.n_routed_experts = args.n_routed_experts
        # 计算本地处理的专家数量
        self.n_local_experts = args.n_routed_experts // world_size
        # 保存每个输入激活的专家数量
        self.n_activated_experts = args.n_activated_experts
        # 计算本地专家在所有专家中的起始索引
        self.experts_start_idx = rank * self.n_local_experts
        # 计算本地专家在所有专家中的结束索引
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        # 初始化门控机制
        self.gate = Gate(args)
        # 初始化专家模块列表,本地负责的专家使用Expert模块,其他位置置为None
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        # 初始化共享专家模块
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x):
        """
        MoE模块的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量。

        返回:
            torch.Tensor: 经过专家路由和计算后的输出张量。
        """
        # 保存输入张量的原始形状
        shape = x.size()
        # 将输入张量展平为二维张量,方便后续处理
        x = x.view(-1, self.dim)
        # 通过门控机制得到每个输入分配到各个专家的权重和对应的专家索引
        weights, indices = self.gate(x)
        # 初始化输出张量,形状与输入相同,初始值全为0
        y = torch.zeros_like(x)
        # 统计每个专家被分配到的输入数量
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        # 遍历本地负责的专家
        for i in range(self.experts_start_idx, self.experts_end_idx):
            # 如果该专家没有被分配到输入,则跳过
            if counts[i] == 0:
                continue
            # 获取当前专家模块
            expert = self.experts[i]
            # 找出分配到当前专家的输入的索引
            idx, top = torch.where(indices == i)
            # 将这些输入通过当前专家模块进行计算,并乘以对应的权重,累加到输出张量中
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        # 将输入通过共享专家模块进行计算
        z = self.shared_experts(x)
        # 如果使用分布式训练,对本地专家的输出进行全局归约操作
        if world_size > 1:
            dist.all_reduce(y)
        # 将本地专家的输出和共享专家的输出相加,并恢复到原始形状
        return (y + z).view(shape)

代码解释总结

这段代码定义了一个混合专家(MoE)模块。在初始化时,根据传入的参数设置专家的数量、门控机制、专家模块列表和共享专家模块。在前向传播过程中,首先通过门控机制将输入路由到不同的专家,然后对本地负责的专家进行计算并累加结果,同时将输入通过共享专家模块进行计算,最后将两部分结果相加并恢复原始形状。如果使用分布式训练,还会对本地专家的输出进行全局归约操作。

Gate

不考虑分组路由来看看它的核心逻辑,实际上就是线下变换,然后激活,选择K个极值(如果用了分组,就是选择K的方式发生了一些变化):

# 通过线性变换计算每个输入对应各个专家的分数
scores = linear(x, self.weight)
# 根据评分函数类型对分数进行处理
if self.score_func == "softmax":
    scores = scores.softmax(dim=-1, dtype=torch.float32)
else:
    scores = scores.sigmoid()
# 选择分数最高的若干专家
indices = torch.topk(scores, self.topk, dim=-1)[1]
# 根据选择的专家索引,从原始分数中获取对应的权重
weights = scores.gather(1, indices)
# 如果评分函数是sigmoid,对权重进行归一化
if self.score_func == "sigmoid":
    weights /= weights.sum(dim=-1, keepdim=True)
# 对权重进行缩放
weights *= self.route_scale
return weights.type_as(x), indices
# 定义门控机制类,用于在混合专家(MoE)模型中对输入进行路由
class Gate(nn.Module):
    """
    混合专家(MoE)模型中用于输入路由的门控机制。

    属性:
        dim (int): 输入特征的维度。
        topk (int): 每个输入激活的顶级专家数量。
        n_groups (int): 用于路由的分组数量。
        topk_groups (int): 输入将被路由到的分组数量。
        score_func (str): 评分函数,取值为 'softmax' 或 'sigmoid'。
        route_scale (float): 路由权重的缩放因子。
        weight (torch.nn.Parameter): 门控机制的可学习权重。
        bias (Optional[torch.nn.Parameter]): 门控机制的可选偏置项。
    """
    def __init__(self, args):
        """
        初始化门控机制模块。

        参数:
            args: 模型参数对象,包含门控机制的相关参数。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入特征的维度
        self.dim = args.dim
        # 保存每个输入激活的顶级专家数量
        self.topk = args.n_activated_experts
        # 保存用于路由的分组数量
        self.n_groups = args.n_expert_groups
        # 保存输入将被路由到的分组数量
        self.topk_groups = args.n_limited_groups
        # 保存评分函数类型
        self.score_func = args.score_func
        # 保存路由权重的缩放因子
        self.route_scale = args.route_scale
        # 初始化可学习权重
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # 根据输入特征维度决定是否初始化偏置项
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x):
        """
        门控机制的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量。

        返回:
            Tuple[torch.Tensor, torch.Tensor]: 路由权重和选择的专家索引。
        """
        # 通过线性变换计算每个输入对应各个专家的分数
        scores = linear(x, self.weight)
        # 根据评分函数类型对分数进行处理
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        # 保存原始分数,后续计算权重时使用
        original_scores = scores
        # 如果存在偏置项,将其加到分数上
        if self.bias is not None:
            scores = scores + self.bias
        # 如果分组数量大于1,进行分组路由操作
        if self.n_groups > 1:
            # 调整分数的形状,以便按组处理
            scores = scores.view(x.size(0), self.n_groups, -1)
            # 根据是否有偏置项,计算每个组的分数表示
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            # 选择分数最高的若干组
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            # 创建掩码,用于屏蔽未选择的组
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            # 将屏蔽组的分数设为负无穷
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
        # 选择分数最高的若干专家
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        # 根据选择的专家索引,从原始分数中获取对应的权重
        weights = original_scores.gather(1, indices)
        # 如果评分函数是sigmoid,对权重进行归一化
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        # 对权重进行缩放
        weights *= self.route_scale
        return weights.type_as(x), indices

代码解释总结

这段代码定义了一个门控机制(Gate)类,用于在混合专家(MoE)模型中对输入进行路由。在初始化时,根据传入的参数设置各种属性,如输入维度、激活专家数量、分组数量等,并初始化可学习的权重和偏置项。在前向传播过程中,首先计算每个输入对应各个专家的分数,然后根据评分函数类型进行处理,接着根据分组情况进行分组路由操作,选择激活的专家并计算对应的权重,最后返回路由权重和选择的专家索引。

B站大牛详解视频链接

B站大牛有详细视频讲解:
https://www.bilibili.com/video/BV1RtNLeqEeu/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/982867.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

mapbox进阶,模仿百度,简单实现室内楼层切换

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️fill-extrusion三维填充图层样式1.4 ☘…

【Bert系列模型】

目录 一、BERT模型介绍 1.1 BERT简介 1.2 BERT的架构 1.2.1 Embedding模块 1.2.2 双向Transformer模块 1.2.3 预微调模块 1.3 BERT的预训练任务 1.3.1 Masked Language Model (MLM) 1.3.2 Next Sentence Prediction (NSP) 1.4 预训练与微调的关系 1.5 小结 二、BERT…

Linux | Vim 鼠标不能右键粘贴、跨系统复制粘贴

注&#xff1a;本文为 “ Vim 中鼠标右键粘贴、跨系统复制粘贴问题解决方案” 相关文章合辑。 未整理去重。 Linux 入门&#xff1a;vim 鼠标不能右键粘贴、跨系统复制粘贴 foryouslgme 发布时间 2016 - 09 - 28 10:24:16 Vim 基础 命令模式&#xff08;command - mode&…

使用查询,休眠-唤醒方式,POLL方式,异步通知方式,读取输入设备信息

查询方式&#xff1a; APP调用open函数时&#xff0c;传入“O_NONBLOCK”表示非阻塞&#xff0c;就可以以非阻塞方式&#xff0c;也就是查询方式用read函数去读取&#xff0c;如果没有数据的话&#xff0c;就会立刻返回一个错误。 如果我们打开这个文件时没有传入“NONBLOCK”参…

【Java篇】算术如诗,逻辑似梦:Java 编程中的运算符探寻

文章目录 Java 运算符&#xff1a;在计算与逻辑之中追寻编程的哲理1.前言2. 算术运算符2.1 基本四则运算符&#xff1a;加减乘除&#xff08; - * / %&#xff09;2.2 除法与取余2.3 增量运算符&#xff08; --&#xff09;2.4 自增/自减运算符 3. 关系运算符3.1 关系运算符 4.…

Ae 效果详解:VR 转换器

Ae菜单&#xff1a;效果/沉浸式视频/VR 转换器 Immersive Video/VR Converter VR 转换器 VR Converter效果能够在 2D、球面投影、立方图、球形图等格式之间转换&#xff0c;并支持调整摄像机视角&#xff0c;适用于 VR 视频格式适配、画面校正和动画视角调整等&#xff0c;确保…

无显示器安装访问树莓派3B+

一、硬件准备 树莓派3B&#xff0c;适配器&#xff08;供电&#xff09;&#xff0c;读卡器和SD卡 二、软件下载及安装 安装过程都是默认选项&#xff0c;一直点击下一步即可&#xff0c;在选择安装路径时可以改到你自己想装的盘里。 1.树莓派系统镜像 官网地址&#xff1…

Vue3路由组件和一般组件 切换路由时组件挂载和卸载 路由的工作模式

路由组件和一般组件 路由组件 一般放到pages或view目录 一般组件 一般放到component目录 切换路由 切换路由时&#xff0c;组件和执行挂载和卸载 路由的工作模式 Hash模式 缺点 1.不美观&#xff0c;路径带#号 优点 1.兼容性好 一般适用于管理系统 History模式 缺点…

多线程初阶(一)

文章目录 1.线程和进程的区别2.创建线程2.1Thread类2.2Runnable接口2.3匿名类创建Thread子类对象创建后台线程 3.Thread常⻅⽅法4.中断线程4.1中断标记&#xff08;Interrupt Flag&#xff09;4.2调⽤ interrupt() ⽅法 5.线程状态 1.线程和进程的区别 1.进程中包含线程&#…

1.3 Spring Boot原理解析

Spring Boot通过起步依赖&#xff08;如spring-boot-starter-parent和spring-boot-starter-web&#xff09;简化项目配置&#xff0c;减少版本冲突和依赖配置代码量。它采用“约定大于配置”的设计思想&#xff0c;通过SpringBootApplication注解&#xff08;包含SpringBootCon…

⭐算法OJ⭐N-皇后问题 II【回溯剪枝】(C++实现)N-Queens II

⭐算法OJ⭐N-皇后问题【回溯剪枝】&#xff08;C实现&#xff09;N-Queens 问题描述 The n-queens puzzle is the problem of placing n n n queens on an n n n \times n nn chessboard such that no two queens attack each other. Given an integer n, return the num…

第6章 定时器计数器

目录 6.1 定时计数器的结构框图 6.2 定时器的控制字 6.2.1 TMOD&#xff1a;工作方式控制寄存器 6.2.2 定时/计数器控制寄存器TCON 6.3 定时/计数器的4种工作方式 6.3.1 方式0、方式1&#xff08;13位、16位定时计数方式&#xff09; 6.3.2 方式2(常数自动重装入) 6.3.3 方…

JavaWeb基础一(Tomcat、Maven)

前言 web开发 web开发&#xff1a;Web开发是指在万维网或私有网络上创建和维护网站的工作。它包括网页设计、网页编程、数据库管理等多方面的技术。Web开发可以分为前端开发和后端开发&#xff0c;前端主要关注用户界面和用户体验&#xff0c;而后端则处理服务器、应用程序和…

写一写idea中使用tomcat启动activiti过程

一 环境 tomcat 9.0.62 activiti的war包版本 7.1.0.M6 二 操作 官网下载&#xff1a;https://www.activiti.org/get-started 2.1 先在idea中编辑配置 2.2 点击加号然后选择tomcat本地进行确认 2.3 点击部署之后下边小加号 选择第二个之后就是选择自己想要使用tomcat启动的…

基于开源库编写MQTT通讯

目录 1. MQTT是什么&#xff1f;2. 开发交互UI3. 服务器核心代码4. 客户端核心代码5. 消息订阅与发布6. 通讯测试7. MQTT与PLC通讯最后. 核心总结 1. MQTT是什么&#xff1f; MQTT&#xff08;Message Queuing Terlemetry Transport&#xff09;消息队列遥测协议&#xff1b;是…

MAVEN手动配置(阿里云)全教程

介于网上各种各样的MAVEN配置过程中方法大致相同却细节参差不齐&#xff0c;我总结了我遇见的一些问题&#xff0c;来完全的解决MAVEN手动配置的全过程&#xff0c;以及分享解决小毛病的经验。 所需材料&#xff1a; MAVEN3.9.9&#xff08;下载适合自己的版本即可&#xff09…

从0到1入门Linux

一、常用命令 ls 列出目录内容 cd切换目录mkdir创建新目录rm删除文件或目录cp复制文件或目录mv移动或重命名文件和目录cat查看文件内容grep在文件中查找指定字符串ps查看当前进程状态top查看内存kill终止进程df -h查看磁盘空间存储情况iotop -o直接查看比较高的磁盘读写程序up…

pytest结合allure

Allure 一、文档二、指令三、装饰器3.1 allure.step装饰器3.2 allure.description装饰器3.3 allure.title装饰器3.4 allure.link、allure.issue 和 allure.testcase装饰器3.5 allure.epic、allure.feature 和 allure.story装饰器3.6 allure.severity装饰器 一、文档 allure文档…

Dockerfile 深入浅出:从基础到进阶全解析

Dockerfile 深入浅出&#xff1a;从基础到进阶全解析 各位同学&#xff0c;大家好&#xff01;欢迎来到今天的 Dockerfile 课程。Docker 技术在当今的软件开发和部署领域可以说是非常热门&#xff0c;而 Dockerfile 作为构建 Docker 镜像的关键文件&#xff0c;掌握它对于我们…

大模型巅峰对决:DeepSeek vs GPT-4/Claude/PaLM-2 全面对比与核心差异揭秘

文章目录 一、架构设计深度解剖1.1 核心架构对比图谱1.2 动态MoE架构实现架构差异分析表 二、训练策略全面对比2.1 训练数据工程对比2.2 分布式训练代码对比DeepSeek混合并行实现GPT-4 Megatron实现对比 2.3 关键训练参数对比 三、性能表现多维评测3.1 基准测试全景对比3.2 推理…