大模型增量预训练新技巧-解决灾难性遗忘

大模型增量预训练新技巧-解决灾难性遗忘

机器学习算法与自然语言处理 2024年03月21日 00:02 吉林

以下文章来源于NLP工作站 ,作者刘聪NLP

NLP工作站.

AIGC前沿知识分享&落地经验总结

转载自 | NLP工作站

作者 | 刘聪NLP

目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧,详见:

  • 如何更好地继续预训练-Continue PreTraining

  • 领域大模型-训练Trick&落地思考

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

 

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的L个Transformer块分成N组,每组中包含M=L/N个Transformer块,对于每组后添加P个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()

# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))

layer_cnt = 0

output = {}
for i in range(original_layers):
    for k in ckpt:
        if ('layers.' + str(i) + '.') in k:
            output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]
    layer_cnt += 1
    if (i+1) % split == 0:
        for k in ckpt:
            if ('layers.' + str(i) + '.') in k:
                if 'down_proj' in k or 'o_proj' in k:
                    output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])
                else:
                    output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]
        layer_cnt += 1
    
assert layer_cnt==layers
for k in ckpt:
    if not 'layers' in k:
        output[k] = ckpt[k]

torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 P=1,M=4,N=8,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时

ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当

图片

代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部需要分开插入

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/783689.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于three.js的数字孪生项目,慢如老牛,7条优化技术。

基于three.js的数字孪生项目慢如老牛可能有以下几个地方可以提升: 优化模型加载: 数字孪生项目通常涉及复杂的3D模型,加载大型模型可能会导致性能下降。可以尝试使用压缩模型、使用LOD(Level of Detail)技术根据距离…

you should not run configure as root, 升级tar出错

为了能用 tar 支持 zstd 的压/解缩包命令,需要升级 tar 到 1.3 以上,下面是下载和编译、安装命令: wget https://mirrors.aliyun.com/gnu/tar/tar-1.32.tar.bz2 tar -jxvf tar-1.32.tar.bz2 cd tar-1.32 ./configure make make install但在执…

Pandas 学习笔记(四)--CSV文件

CSV文件 CSV(Comma-Separated Values,逗号分隔值,有时也称为字符分隔值,因为分隔字符也可以不是逗号),其文件以纯文本形式存储表格数据(数字和文本)。 读取与写入 读取csv文件 i…

202406 CCF-GESP Python 三级试题及详细答案注释

202406 CCF-GESP Python 三级试题及详细答案注释 1 单选题(每题 2 分,共 30 分)第 1 题 小杨父母带他到某培训机构给他报名参加CCF组织的GESP认证考试的第1级,那他可以选择的认证语言有几种?( ) A. 1 B. 2 C. 3 D. 4答案:C解析:目前CCF组织的GESP认证考试有C++、Pyth…

Java语言程序设计——篇二(1)

Java语言基础 数据类型关键字与标识符关键字标识符 常量与变量1、常量2、变量 类型转换自动类型转换强制类型转换 数据类型 数据的基本要素数据的性质(数据结构)数据的取值范围(字节大小)数据的存储方式参与的运算 Java是一门强类…

权力之望怎么注册账号创建角色 权利之网角色账号注册教程

权力之望是一款全新的大型MMORPG游戏,拥有9把独特武器和56种职业组合,并搭配了超炫酷的战斗画面,全程采用低俯视角游戏,让玩家能体验到更强的操作感和爽快感。这款游戏主打高养成自由度玩家可以自由更换武器进行战斗,还…

MySQL之表的约束(下)

自增长 auto_increment:当对应的字段,不给值,会自动的被系统触发,系统会从当前字段中已经有的最大值 1操作,得到一个新的不同的值。通常和主键搭配使用,作为逻辑主键。 自增长的特点: 1. 任何一…

SSM慢性病患者健康管理系统-计算机毕业设计源码04877

目 录 摘要 1 绪论 1.1 研究意义 1.2研究目的 1.3论文结构与章节安排 2 慢性病患者健康管理系统系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分…

python语句性能分析

1、for语句性能优于while import timeif __name__ __main__:start_time time.time()for i in range(10 ** 8):passend_time time.time()run_time end_time - start_timeprint(run_time)i 0start_time time.time()while i < 10 ** 8:i 1end_time time.time()run_tim…

【CSAPP】-cachelab实验

目录 实验目的与要求 实验设备与软件环境 实验过程与结果&#xff08;可贴图&#xff09; 操作异常问题与解决方案 实验总结 实验目的与要求 1、掌握应用程序性能的优化方法&#xff1b; 2、理解存储器层次结构在程序运行过程中所起的重要作用&#xff1b; 3、让学生更好…

一网统管/视频汇聚/安防监控平台EasyCVR启动后无法访问是什么原因?

智慧城市/一网统管/视频汇聚/安防监控平台EasyCVR兼容性强&#xff0c;支持多协议接入&#xff0c;包括国标GB/T 28181协议、GA/T 1400协议、部标JT808协议、RTMP、RTSP/Onvif协议、海康Ehome、海康SDK、大华SDK、华为SDK、宇视SDK、乐橙SDK、萤石云SDK等&#xff0c;并能对外分…

科普文:jvm实战(六)搞懂各个版本JDK和GC

jdk6&#xff0c;7&#xff0c;8三个版本的内存模型 如图所示 JDK 1.6、1.7、1.8 的内存模型演变过程&#xff0c;其实这个内存模型就是 JVM 运行时数据区依照JVM虚拟机规范的具体实现过程。 JDK 1.6&#xff1a;程序计数器、Java虚拟机栈、本地方法栈、堆、方法区[永久代]&am…

C++系列-String(四)String初步的模拟实现

&#x1f308;个人主页&#xff1a;羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” 下面的这些是我们这篇文章将要实现的String的功能&#xff1a; #pragma once #include<iostream> #include<assert.h> using namespace std;namespace bit {class…

下载程序到仿真

第一步&#xff0c;新建工程 第二步&#xff0c;设备组态 第三步&#xff0c;地址分配 需要注意的是&#xff0c;分配地址的范围&#xff0c;是CPU决定的。 关于常见数据类型 下载与仿真 一般安装好博图会自带。 PLCSIM/PLCSIM Advanced PLCSIM普通仿真 PLCSIM Advanced高级…

Spark 分布式弹性计算集(RDD)相关概念介绍

目录 一、概述 二、RDD的核心概念 2.1 Partition 2.2 Partitioner 2.3 RDD的依赖关系 2.4 Stage 2.5 PreferredLocation 2.6 CheckPoint 三、RDD的持久化 3.1 概述 3.2 概念 3.3 RDD持久化级别 3.3.1 MEMORY_ONLY 3.3.2 MEMORY_AND_DISK 3.3.3 MEMORY_ONLY_SER …

昇思第18天打卡|ShuffleNet图像分类

ShuffleNet网络介绍 ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型&#xff0c;和MobileNet, SqueezeNet等一样主要应用在移动端&#xff0c;所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作&#xff1a;Pointw…

使用命令行修改Ubuntu 24.04的网络设置

Ubuntu里&#xff0c;使用命令行下修改IP地址&#xff0c;网上有很多方案&#xff0c;我最终觉得这个方案&#xff08;使用Netplan&#xff09;最好&#xff0c;最根本&#xff0c;记录下来备查 1.使用命令ip link show 查看Ubuntu上可以使用的网络接口名称 2.查找Netplan的配…

「Java开发指南」如何用MyEclipse完成Spring Web Flow 2.0搭建?

本教程将引导您完成Spring Web Flow的软件组件生成&#xff0c;这是Spring的一个项目&#xff0c;用于简化Web应用程序的开发。虽然Spring Web Flow与Spring MVC兼容&#xff0c;但Spring Web Flow使用流而不是控制器来实现应用程序的Web层。在本教程中&#xff0c;您将学习如何…

放大镜案例

放大镜 <!DOCTYPE html> <html lang"zh-cn"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>商品放大镜</title><link rel&qu…

可控硅整流自动恒流充电器设计制作

该充电器除可为各种镍镉电池充电外&#xff0c;也可为干电池充电。其充电电流可调。充电终止电压由RP1预先确定。 工作原理 电路原理见图1。开始充电时&#xff0c;电池组两端电压较低&#xff0c;不足以使晶体管VT导通。由RC组成的移相电路给可控硅提供触发电流。移相角度由R…