[阅读笔记20][BTX]Branch-Train-MiX: Mixing Expert LLMs into a Mixture-of-Experts LLM

这篇论文是meta在24年3月发表的,它提出的BTX结构融合了BTM和MoE的优点,既能保证各专家模型训练时的高度并行,又是一个统一的单个模型,可以进一步微调。

这篇论文研究了以高效方法训练LLM使其获得各领域专家的能力,例如写代码、数学推理以及自然知识。现有的融合多个专家模型的方法有Branch-Train-Merge和Mixture-of-Experts,前者BTM各专家模型在不进行任何同步的情况下并行训练,大大提升了训练时的吞吐量,但是缺乏一个统一的模型,导致没法进行后续的SFT和RLHF,这两步是对齐LLM的重要步骤。后者MoE虽然是一个统一的模型,可以进行微调了,但是训练时是各专家模型是完全同步的,并且由于all-to-all通信,随着专家数量增加通信成本也在增加。

这篇论文提出的BTX就是融合了BTM和MoE的优点,弥补了二者的缺点,具体来说,BTX的各个专家模型可以异步的独立训练,大大提高了模型训练时的数据吞吐量,另外BTX是一个统一的模型,所以之后可以对其进行微调。
实现分为三步,首先是Branch,这里取了四个一样的种子模型LLaMA-2 7B,然后其中三个分别在数学数据集、代码数据集、维基百科上预训练,最后剩下的保留LLaMA原始权重。前三个专家模型分别具有数学推理能力、代码能力、世界知识,最后一个专家模型作为通才专家,将通用知识迁移到模型中。第二步是Train,这三个领域专家分别在各自领域数据集(Llemma、CodeLlama、Wikipedia)上预训练,这个过程是并行且互不干扰的。第三步是Mix,也就是将这四个专家模型进行混合,这一步在下一张ppt会详细说明。

具体融合四个专家的过程其实就是把这四个专家的前馈层进行合并,也就是将同层次的四个前馈合并为一个MoE层,下图公式展示了如何合并,整个MoE层输出是各前馈层的加权和。这里使用了Top2路由,对于输入x使用投影矩阵Wl进行投影,然后取值最大的两个专家模型进行混合,混合比例由softmax计算得到。
对于模型的其他部分,例如注意力层、embedding层,BTX混合各专家的方法是直接把对应的模型参数取平均,这个方法比较粗暴,作者给出的解释是这个做法基于一个假设:自注意力层比前馈层更通用化。另外后续微调阶段还会对这些参数进行调整,所以问题不大。

预训练说完了,接下来就是微调了。微调使用的数据仍然来自训练时数据,作者对用于训练四个专家的数据集进行采样得到微调用的数据,采样概率数学是30%,代码是40%,维基百科是10%,LLaMA-2是19%。
下图是训练三个专家模型时使用的三个专业领域数据集及其采样比例。

接下来是结果展示。左图是各模型的训练代价和平均性能对比,x轴是训练所需要的GPU天数,圆圈大小是推理时激活参数数量。右图是不同领域的性能差异,可以看到BTX在各方面都得到了很大的改善,尤其是代码领域,已经接近专用模型CodeLlama了。

这张图反映了每个token来自于哪两个专家,共有四个专家,所以有六种组合,分别用不同的颜色来代表。如果是由领域内专家生成的token,则标上下划线。对于数学、代码、知识问答这三个领域的三个输出,大部分token都是由各自领域专家生成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562088.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

idea项目启动异常:Command line is too long.

项目场景: 提示:这里简述项目相关背景: idea中启动项目报错: 解决方案 在idea 的运行配置中,修改enviroment下的shorten command line 为jar manifest 注: 有时shorten command line 可能不是默认存在的…

Linux实验一:NAT、桥接方式的验证

实验名称:在虚拟机中安装RHEL7,验证NAT、桥接上网方式 实验结果: 创建虚拟机 NAT模式 自动获取IP 手动配置IP 桥接模式 自动获取IP 手动配置IP 总结和分析:

我与C++的爱恋:类和对象(四)

​ ​ 🔥个人主页:guoguoqiang. 🔥专栏:我与C的爱恋 ​ 朋友们大家好!本篇是类和对象的最后一个部分。 一、static成员 声明为static的类成员称为类的静态成员,用static修饰的成员变量,称之…

[阅读笔记29][AgentStudio]A Toolkit for Building General Virtual Agents

这篇论文是24年3月提交的,提出了一个用于agent开发的全流程工具包。 作者提到目前agent开发主要有两个阻碍,一个是缺乏软件基础,另一个是缺乏在真实世界场景中进行评估。针对这两个阻碍,作者涉及了一个开发工具包,包括…

使用立创EDA打开JSON格式的PCB及原理图

一、将PCB和原理图放同一文件夹 并打包成.zip文件 二、打开嘉立创EDA并导入.zip文件 文件 -> 导入 -> 嘉立创EDA标准版/专业版 三、选择.zip文件并选择 “导入文件并提取库” 四、自定义工程路径 完成导入并转换为.eprj文件 五、视频教学 bilibili_使用立创EDA打开JSO…

NLP预训练模型-GPT-3

ChatGPT GPT-3是OpenAI开发的一个自然语言处理(NLP)预训练模型。GPT代表“生成式预训练变换器”(Generative Pretrained Transformer)。GPT-3是GPT系列的第三代模型,是一种采用了深度学习技术的强大语言模型&#xff…

驱动开发-windows驱动设计目标

驱动程序和应用程序不一样的,由于其直接运行于windows r0级,故对于开发有更多和更严格的标准,一般会有以下一些常见的设计目标: 安全性、可移植性、可配置性、 可被中断、多处理器安全、可重用 IRP、 支持异步 I/O这些是基本目标。 1. 安全…

【Numpy】对于 Numpy 中 Axis 的理解

文章目录 前言理解轴的两个角度在维度变化方向上计算降维 示例剖析写在最后 前言 Numpy 是 Python 中一个常用科学计算库,常用来表示向量、矩阵以及多维度数组。在 Numpy 中多对某一个维度(轴)进行相应的操作,这一点经常出错。今…

再论图像变化和频率的关系。

我之前是做了一些探讨,但是没说清楚,现在再看这个问题。 我先提出这个问题。 以以为点列为例,先写成傅里叶级数的形式,不过这里不是三角函数形式,而是指数形式,是一样的。 对f(n)求导,就可以观…

【大语言模型LLM】-使用大语言模型搭建点餐机器人

关于作者 行业:人工智能训练师/LLM 学者/LLM微调乙方PM发展:大模型微调/增强检索RAG分享国内大模型前沿动态,共同成长,欢迎关注交流… 大语言模型LLM基础-系列文章 【大语言模型LLM】-大语言模型如何编写Prompt?【大语言模型LL…

C语言—字符指针,指针数组和数组指针详解

字符指针 在指针的类型中我们知道有一种指针类型为字符指针 char* ; int main() {char ch w;char *pc &ch;*pc w;return 0; }还有一种使用方式如下: int main() {const char* pstr "hello world.";//这里是把一个字符串放到pstr指针变量里了吗…

chrome浏览器查看css样式

样式的查看 1.匹配器为灰色文本: 表示非当前选择器 2.样式有划线标识:CSS属性无效或未知 / 属性值无效 / 被其他属性覆盖的属性 3.属性以浅色文本显示且有感叹号提示:属性虽然有效,但由于CSS逻辑而没有任何影响 转自:…

笔试狂刷系列--Day1

大家好,我是LvZi,今天开启新的章节笔试狂刷系列 一.两个数组的交集 1. 题⽬链接: 两个数组的交集 思路分析: 查找两个数组的公共元素,一开始可能想到使用Set,先遍历第一个数组,存储nums1中所有的元素,接着遍历nums2中的所有元素,判断是否在Set之中,但是发现在遍历第二个数组…

神经网络中的神经元和激活函数介绍

文章目录 1、什么是人工神经网络 2、什么是神经元 3、什么是激活函数 线性激活函数 Sigmoid激活函数 双曲正切激活函数 修正线性单元(ReLU)激活函数 Leaky ReLU激活函数 Softmax激活函数 1、什么是人工神经网络 神经网络能够利用多层神经元学习复杂的模…

使用docker打包当前服务器的neo4j环境

Docker 是一个开源的应用容器引擎,它允许开发者将应用程序及其依赖打包到一个可移植的容器中,这样应用程序就可以在任何支持Docker的平台上运行,而无需担心环境差异。 当运行一个Docker容器时,它会加载一个镜像并运行它。Docker在容器内部创建一个隔离的环境,这个环境被称…

Redis学习-Redis的九种数据结构

String (字符串) 虽然redis是用C语言编写,但是redis中的string是redis自己实现的字符串结构,叫Simple Dynamic String简称(SDS),因为redis做为中间件会接受不同语言编写的程序传过来的字符串&a…

Oracle Hint 语法详解

什么是Hint Hint 是 Oracle 提供的一种 SQL 语法,它允许用户在 SQL 语句中插入相关的语法,从而影响 SQL 的执行方式。 因为 Hint 的特殊作用,所以对于开发人员不应该在代码中使用它,Hint 更像是 Oracle 提供给 DBA 用来分析诊断问…

Python中pyside2出现的pyside2 qt platform plugin could be in错误及其解决方法

系统平台:Win10 64bit python版本: python 3.8 使用pip install pyside2安装 pyside2 这是找不到QT平台的插件,这是环境变量QT_QPA_PLATFORM_PLUGIN_PATH出现错误 具体解决方法: 我们可以在每一段程序开始之前设定环境变量&…

pytorch与深度学习

ChatGPT PyTorch是一个由Facebook AI Research Team开发的开源深度学习库,它提供了一个灵活的环境和丰富的API,用于快速且方便地构建、训练和部署深度学习模型。PyTorch在科学界和工业界都收到了广泛的使用,其中包括了学术研究、小型项目和大…

第50篇:算法的硬件实现<一>

Q:本期我们来开始介绍如何使用算法状态机(ASM)图在硬件开发板上实现算法。 A:算法状态机 (Algorithmic State Machine,ASM) 图是描述数字系统控制单元的工作流程图,主要用来描述控制单元的时序操作特性&am…