大模型增量预训练新技巧:解决灾难性遗忘

大家好,目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧。

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

文章目录

    • 技术交流群
    • 用通俗易懂方式讲解系列
    • 块扩展方法
    • 实验细节
    • 讨论分析
    • 写在最后

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型面试与技术交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

资料1
在这里插入图片描述

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 面试了字节大模型算法岗(实习),快被问哭了。。。。

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块后增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear)权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的 个Transformer块分成 组,每组中包含 个Transformer块,对于每组后添加 个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()

# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))

layer_cnt = 0

output = {}
for i in range(original_layers):
    for k in ckpt:
        if ('layers.' + str(i) + '.') in k:
            output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]
    layer_cnt += 1
    if (i+1) % split == 0:
        for k in ckpt:
            if ('layers.' + str(i) + '.') in k:
                if 'down_proj' in k or 'o_proj' in k:
                    output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])
                else:
                    output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]
        layer_cnt += 1
    
assert layer_cnt==layers
for k in ckpt:
    if not 'layers' in k:
        output[k] = ckpt[k]

torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 、 、 ,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时。

在ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当。

图片

在代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部,需要分开插入。

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/369494.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Github 2024-02-04 开源项目日报 Top9

根据Github Trendings的统计,今日(2024-02-04统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目6Ruby项目1HTML项目1C项目1Go项目1TypeScript项目1 Windows 终端、控制台和命令行存储库 创建周期…

2024年新版全国行政区划代码

嗨喽,大家好,我是小码哥,今天免费将2024年全国行政区划代码分享给大家,已经整理成sql和excel文件,方便大家直接使用,文章末尾直接获取。 01-数据来源 根据国家统计局官网统计查询,目前全国最新…

mermaid使用指南+notion使用实例-持续更新中

最近一个月了吧,发现Notion插入图片的功能坏了,直接paste会404,本地上传也不行。电脑本地版和手机端都插不了图片,很头疼。解决方法也简单,用图床,放链接。 付费版我用的七牛,结合PicGo&#x…

C#,雅各布斯塔尔—卢卡斯(Jacobsthal Lucas Number)的算法与源代码

1 雅各布斯塔尔序列 雅各布斯塔尔序列是一个与斐波那契序列类似的加法序列,由递归关系JnJn-12Jn-2定义,初始项J00,J11。序列中的一个数字称为雅可布沙尔数。它们是卢卡斯序列Un(P,Q)的一种特殊类型&#x…

SmartX 在保险(2023):服务 40+ 客户,聚焦信创转型与高性能数据库场景

更新内容: 更新 SmartX 超融合在保险行业的覆盖范围与部署规模。更新保险客户超融合应用情况。新增 Nutanix 国产化替代、高性能数据库构建与验证、企业云原生转型等场景实践。更多超融合金融核心生产业务场景实践,欢迎阅读文末电子书。 近两年来&…

[Python] 什么是逻辑回归模型?使用scikit-learn中的LogisticRegression来解决乳腺癌数据集上的二分类问题

什么是线性回归和逻辑回归? 线性回归是一种用于解决回归问题的统计模型。它通过建立自变量(或特征)与因变量之间的线性关系来预测连续数值的输出。线性回归的目标是找到一条直线(或超平面),使得预测值与观…

华为数通方向HCIP-DataCom H12-821题库(单选题:421-440)

第421题 以下关于IS-IS中路由器分类的描述,错误的是哪一项? A、Level-1路由器无法与Level-2路由器建立邻接关系 B、华为路由器上配置IS-IS时,缺省时,路由器全局Level为Level-1-2 C、Level-2的LSDB只包含Level-2路由器所在区域的路由信息 D、Level-1路由器可以和Level-1-2路…

vite打包原理

vite 工程化开发:打包工具 启动速度很快 核心原理还是webpack 把webpack封装了,把webpack对象封装了 和vue2整体结构几乎一致 webpack两种模式:开发&生产 代码打包编译,本地起一个web服务器实时预览编译后的结果 build 命令模…

Go协程揭秘:轻量、并发与性能的完美结合

目录 1. Go协程简介什么是Go协程?Go协程与线程的比较Go协程的核心优势 2. Go协程的基本使用创建并启动Go协程使用匿名函数创建Go协程Go协程与主函数 3. Go协程的同步机制1. 通道 (Channels)2. sync.WaitGroup3. 互斥锁 (sync.Mutex) 4. Go协程的高级用法1. 选择器 (…

jss/css/html 相关的技术栈有哪些?

js 的技术组件有哪些?比如 jQuery vue 等 常见的JavaScript技术组件: jQuery: jQuery是一个快速、小巧且功能丰富的JavaScript库,用于简化DOM操作、事件处理、动画效果等任务。 React: React是由Facebook开发的用于构…

【大数据实时数据同步】超级详细的生产环境OGG(GoldenGate)12.2实时异构同步Oracle数据部署方案(上)

系列文章目录 【大数据实时数据同步】超级详细的生产环境OGG(GoldenGate)12.2实时异构同步Oracle数据部署方案(上) 【大数据实时数据同步】超级详细的生产环境OGG(GoldenGate)12.2实时异构同步Oracle数据部署方案(中) 【大数据实时数据同步】超级详细的生产环境OGG(GoldenGate…

海量数据处理商用短链接生成器平台 - 2

第二章 短链平台项目创建git代码管理开发分层规范 第1集 短链平台实战-Maven聚合工程创建微服务项目 **简介:Maven聚合工程创建微服务项目实战 ** Maven聚合工程拆分 dcloud-common 公共依赖包 dcloud-app FlinkKafka实时计算 dcloud-account 账号流量包微服务 dc…

程序报错无法打开源文件stdafx.h

在运行代码时,代码中头文件突然报错程序无法打开源文件stdafx.h include “stdafx.h”,编译器就说无法打开源文件,直接上干货解决方法是: 1.打开项目 ->项目属性(最后一个)-> C/C ->常规, 2在附…

【MySQL】——数据定义

🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL&#xff1a…

第二十二回 横海郡柴进留宾 景阳冈武松打虎-大模型ChatGLM2-6B新手速通

柴进说这人叫武松,排名老二。宋江说江湖上听说过武二郎的名字,幸会幸会,就拉着武松的手,一起喝酒吃饭。 武松是家喻户晓的打虎英雄,现在最流行的是大模型! 大模型ChatGLM2-6B新手速通! 人工智能…

成都爱尔林江院长解读儿童青少年为什么一定要进行医学验光配镜

根据国家卫健委数据显示:我国青少年儿童总体近视率为52.7%、高度近视人口超3000万。近视学生中,有10%为高度近视,且占比随年级升高而增长。 近视孩子之多,孩子视力发展备受关注。戴镜进行近视防控十分必要,且眼镜不可随意验配! 成…

超越体量:TinyLlama用1.1B参数实现大模型级性能

引言 随着人工智能技术的快速发展,大型语言模型(LLM)在全球范围内受到瞩目。但与此同时,另一类模型正在逐渐崭露头角:参数规模较小的语言模型。这类模型在计算资源受限的环境下显示出巨大潜力,特别是在智能…

SpringbootV2.6整合Knife4j 3.0.3 问题记录

参考 https://juejin.cn/post/7249173717749940284 近期由于升级到springboot2.6X,所以服务端很多组件都需要重新导入以及解决依赖问题。 下面就是一个很经典的问题了, springboot2.6与knife4j的整合。 版本对应 springboot2.6与knife4j 3.0.3 坑 …

STM32外部中断(红外传感器与旋转编码器计数案例)

文章目录 一、介绍部分简介中断系统中断执行流程STM32中断NVIC基本结构NVIC优先级分组外部中断外部中断简介外部中断基本结构外部中断的流程AFIOEXTI框图 相关外设介绍旋转编码器介绍硬件电路对射式红外传感器 二、代码实现对射式红外传感器计次连接电路封装红外传感器与中断函…

JDK17中的密封类sealed和permits使用指南:什么是Java中的sealed和permits?

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …