Mistral MOE架构全面解析

从代码角度理解Mistral架构

  • Mistral架构全面解析
    • 前言
    • Mistral 架构分析
      • 分词
      • 网络主干
        • MixtralDecoderLayer
          • Attention
          • MOE
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Mistral架构全面解析

前言

Mixtral-8x7B 大型语言模型 (LLM) 是一种预训练的生成式稀疏专家混合模型。在大多数基准测试中,Mistral-8x7B 的性能优于 Llama 2 70B。

Mixtral 8x7B 是 Mistral AI 全新发布的 MoE 模型,MoE 是 Mixture-of-Experts 的简称,具体的实现就是将 Transformer 中的 FFN 层换成 MoE FFN 层,其他部分保持不变。在训练过程中,Mixtral 8x7B 采用了 8 个专家协同工作,而在推理阶段,则仅需激活其中的 2 个专家。这种设计巧妙地平衡了模型的复杂度和推理成本,即使在拥有庞大模型参数的情况下,也能保证高效的推理性能,使得 MoE 模型在保持强大功能的同时,也具备了更优的实用性和经济性。

  • 在大多数基准测试中表现优于Llama 2 70B
  • 甚至足以击败GPT-3.5上下文窗口为32k
  • 可以处理英语、法语、意大利语、德语和西班牙语
  • 在代码生成方面表现优异

huggingface上给出基本的加载方法

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

在这里插入图片描述

Mistral 架构分析

其它结构和llama的一模一样,知道llama结构的话,省流直接看MOE部分。

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
MixtralDecoderLayer

主干网络的layers层就是由多个MixtralDecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

MixtralDecoderLayer层中又包含了Attention层和MOE层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

在这里插入图片描述

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)

#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states

#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
    outputs += (self_attn_weights,)

if use_cache:
        outputs += (present_key_value,)

return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]

#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = self.o_proj(attn_output)

MOE

MOE层,也就是我们的专家模块,简单来说,主要干的就是通过一个线性层,得到8个专家,从这8个专家中选出最专业的2个,把他们的权重相加,输入到MLP层,得到最终的结果。

  • attention层得到的hidden_states经过控制门(nn.Linear)得到8个输出。(有点像多分类)
  • t通过softmax计算8个输出的概率值
  • 从8个中选择概率值最高的两个专家
  • 概率最高的两个专家进行权重相加,并计算相对概率值
  • 这两个专家输入到MLP层中进行一系列计算得到最后结果

在这里插入图片描述

batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
#这里通过一个线性层,得到8个输出(n_experts),也就是所谓的专家。
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
#这里通过softmax计算8个输出的概率值,如(0.2,0.3,0.0833,0.0833,0.0833,0.0833,0.0833,0.0833)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
#从8个中选择概率值最高的两个专家((0.2,0.3)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
#概率最高的两个专家进行权重相加,并计算相对概率值((0.2,0.3)->(0.4,0.6)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
#初始化最终结果
final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
#掩码
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    #通过掩码找到top2的位置
    idx, top_x = torch.where(expert_mask[expert_idx])

    if top_x.shape[0] == 0:
        continue
        top_x_list = top_x.tolist()
        idx_list = idx.tolist()

        #top2对应的向量
        current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
        #经过mlp
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

        #加到final_hidden_states中
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/252462.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

1.Mybtis-Plus框架基本使用

Mybatis-plus是一个mybatis的增强工具,在mybatis的基础上只做增加不做改变,简化开发 提供通用的`mapper和service` 可以在不编写任何SQL语句的情况下快速实现对单表CRUD、批量、逻辑删除、分页操作 Mybatis-plus提供优秀插件,并对idea中快速开发插件mybatisX也进行功能使用。…

甜酷女孩穿搭 I 时尚与保暖都兼具的羽绒服

这款工装风羽绒服 酷酷的中性风 清新温柔的杏紫两色 采用定制复合面料 顺滑平整硬朗的材质 具有防水功能 下雪下雨天也不用担心哦 90白鹅绒,立领连帽设计 帽子做的是可拆卸 可以切换两种风格 袖口采用可调节魔术贴设计 下摆可调节抽绳设计 处处透着细节…

网络安全——SSH密码攻击实验

一、实验目的要求: 二、实验设备与环境: 三、实验原理: 四、实验步骤:​ 五、实验现象、结果记录及整理: 六、分析讨论与思考题解答: 网络安全-SSH密码攻击实验效果截图: https://downloa…

设计模式(3)--对象结构(3)--组合

1. 意图 将对象组合成树形结构以表示“部分-整体”的层次结构。Composite使得用户对单个对象和组合对象的使用具有一致性。 2. 三种角色 抽象组件(Component)、组合式节点(Composite)、叶节点(Leaf) 3. 优点 3.1 定义了包含基本对象和组合对象的类层次结构。 客户代码中&…

openwrt中taiscale自动安装脚本详解

openwrt中taiscale自动安装脚本详解 一、代码仓库地址 https://github.com/adyanth/openwrt-tailscale-enabler 二、代码仓库中脚本文件详解 主要包含三个脚本分别是etc/init.d/tailscale、usr/bin/tailscale、usr/bin/tailscaled ,接下来逐个分析一下脚本中的具…

【Redis】AOF 基础

因为 Redis AOF 的实现有些绕, 就分成 2 篇进行分析, 本篇主要是介绍一下 AOF 的一些特性和依赖的其他函数的逻辑,为下一篇 (Redis AOF 源码) 源码分析做一些铺垫。 AOF 全称: Append Only File, 是 Redis 提供了一种数据保存模式, Redis 默认不开启。 AOF 采用日志的形式来记…

Go标准包之flag命令行参数解析

1.介绍 在 Go中,如果要接收命令行参数,需要使用 flag 包进行解析。不同的参数类型可以通过不同的方法接收。 2.参数接受 2.1 接受方式 使用flag接收参数,可以由以下三种方式接受: 方式一: flag.Type(name,defaultVal,desc)方…

Linux上使用HTTP协议进行数据获取的实战示例

嗨,Linux爱好者们,今天我们要一起探讨一下如何在Linux上进行HTTP协议的数据获取。这不是一项简单的任务,但放心,我会以最简单的语言,结合实例来给大家讲解。 首先,我们需要一个工具,那就是curl…

Python生成器(Generator)(继续更新...)

学习网页: Welcome to Python.orghttps://www.python.org/https://www.python.org/ Python生成器 生成器(Generator)是 Python 的一种特殊类型的迭代器。生成器允许你创建自己的数据流,每次从数据流中获取一个元素,…

医保电子凭证在项目中的集成应用

随着医保电子凭证使用普及,医疗行业的各个场景都要求支持医保码一码通办,在此分享一下,在C#和js中集成医保电子凭证的demo 供有需要的小伙伴参考。 一、项目效果图 在c#中集成医保电子凭证效果 在js中集成医保电子凭证效果 二、主要代码 c#…

Linux_Docker图形化工具Portainer如何安装并结合内网穿透实现远程访问

文章目录 前言1. 部署Portainer2. 本地访问Portainer3. Linux 安装cpolar4. 配置Portainer 公网访问地址5. 公网远程访问Portainer6. 固定Portainer公网地址 前言 本文主要介绍如何本地安装Portainer并结合内网穿透工具实现任意浏览器远程访问管理界面。Portainer 是一个轻量级…

算法竞赛备赛进阶之树形DP训练

目录 1.树的最长路径 2.树的中心 3.数字转换 4.二叉苹果树 5.战略游戏 6.皇宫守卫 树形DP是一种动态规划方法,主要用于解决树形结构的问题。在树形DP中,通常会使用动态规划的思想来求解最优化问题。其核心在于通过不断地分解问题和优化子问题来解决…

【理论篇】SaTokenException: 非Web上下文无法获取Request问题解决 -理论篇

在我们使用sa-token安全框架的时候,有时候会提示:SaTokenException:非Web上下文无法获取Request 错误截图: 在官方网站中,查看常见问题排查: 错误追踪: 跟着源码可以看到如下代码: 从源码中&a…

01 整体代码运行流程

文章目录 01 整体代码运行流程1.1 运行官方 Demo1.2 变量命名规则1.3 多线程1.4 线程锁1.5 SLAM 主类 System 01 整体代码运行流程 1.1 运行官方 Demo 以 stereo_kitti 为例,执行 ./stereo_kitti path_to_vocabulary path_to_settings path_to_sequence./stereo_…

大创项目推荐 深度学习 python opencv 实现人脸年龄性别识别

文章目录 0 前言1 项目课题介绍2 关键技术2.1 卷积神经网络2.2 卷积层2.3 池化层2.4 激活函数:2.5 全连接层 3 使用tensorflow中keras模块实现卷积神经网络4 Keras介绍4.1 Keras深度学习模型4.2 Keras中重要的预定义对象4.3 Keras的网络层构造 5 数据集处理训练5.1 …

W25Q64(模拟SPI)读写数据的简单应用

文章目录 一、W25Q64是什么?二、使用步骤1.硬件1.引脚说明2.硬件连接3.设备ID4.内部框架5.指令集指令集1指令集2 2.软件1.W25Q64引脚定义代码如下(示例):2.W25Q64初始化代码如下(示例):3.W25Q64…

在排序数组中查找元素的第一个和最后一个位置(Java详解)

一、题目描述 给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target,返回 [-1, -1]。 你必须设计并实现时间复杂度为 O(log n) 的算法解决此问题。 示…

Android开发——组合函数、注解与连接Android设备

1、JetPack Compose、组合函数与注解和文本修改 1、JetPack Compose:Jetpack Compose 是由 Google 推出的用于构建 Android 用户界面的现代化工具包。它是一个声明式的 UI 工具包,用于简化 Android 应用程序的用户界面设计和开发。Jetpack Compose 采用…

02.Git常用基本操作

一、基本配置 (1)打开Git Bash (2)配置姓名和邮箱 git config --global user.name "Your Name" git config --global user.email "Your email" 因为Git是分布式版本控制工具,所以每个用户都需要…

02-分组查询group by和having的使用

分组查询 MySQL中默认是对整张表的数据进行操作即整张表为一组, 如果想对每一组的数据进行操作,这个时候我们需要使用分组查询 分组函数的执行顺序: 先根据where条件筛选数据,然后对查询到的数据进行分组,最后也可以采用having关键字过滤取得正确的数据 group by子句 在一条…