HiFT全参数微调新范式---逐层微调

论文链接: https://arxiv.org/abs/2401.15207

HiFT 是一个端到端的层级优化策略。目前论文的结果是原始混合精度的结果,目前最新进展已将混合精度进行了分层适配,微调7B模型的内存需求约为16.87G,13B模型约为31G(batch=1,seq_length=512)

背景

在大语言模型之前,语言模型适配下游任务的首选方法就是全参数微调。随着大语言模型的出现,全参数微调对GPU的内存需求越来越高。以LORA为代表的PEFT(Parameter-Efficient Fine-tuning)方法成为微调大模型的首选。PEFT的方法以较低的GPU内存代价在多个任务上取得了媲美全参数微调的性能。但是已有的研究表明PEFT的方法整体上性能和全参数微调相比仍有差距。

最近的一些工作开始关注内存高效的全参数微调方法。这些方法的特点就是采用原始SGD优化器的思路----无动量优化。类似SGD这样的优化器没有优化器的状态,因此可以极大的降低微调过程中优化器占用的内存。其中代表的就是MeZO和LOMO优化器。MeZO优化器在prompt的微调下性能接近全参数微调,在无prompt的情况下性能和全参数微调差距较大。MEZO能做到在80G内存设备上微调30B的模型。LOMO优化器除了无动量优化器外,采用了融合梯度计算和更新的策略,这样做可以同时降低梯度的内存占用,但是LOMO需要forward两次。经过我们测试,在不使用外部优化技术的情况下,在内存节省上MeZO更有优势。这些零阶优化器虽然可以节省内存,但是类似Adamw这种优化器器训练更加稳定,更容易收敛。使用这些优化器替代被广泛验证过的优化器代价是巨大的。

内存被谁占用了

微调过程中占用内存的主要部分有:模型参数 ,梯度 ,优化器状态,剩余部分(有的地方叫激活状态,这里我们参考混合精度论文的叫法,主要包括,激活状态,图中间变量等)。模型参数由于需要进行前向传播,该部分很难进行优化,或者说我们必须要把它加载到GPU上。标准的全参数微调梯度等于模型参数;优化器状态参数量取决于使用了几阶动量,AdamW使用了二阶动量,状态是模型参数的2倍,SGD为零阶优化器,状态参数为0,SGDM加入了迭代动量,该部分参数和模型参数相等。剩余部分主要是由图的激活状态和中间变量占用,该部分占用的内存和输入呈正相关,也就是输入长句越长,batch size 越大该部分参数占用的内存就越多,也叫动态内存占用。

图片

framework

图1:HiFT 策略的示意图。Group表示对图层的分组操作。bottom2up、top2down 和 random 代表训练策略。灰色表示对应参数处于冻结状态,棕色表示对应参数处于激活状态。K表示组数,n表示给定模型的层数,BP表示通过反向传播进行参数更新。

如何降低内存占用

梯度,优化器状态和神经元的激活都和模型的可训练参数量相关。既然一次更新所有参数有困难,那是否可以一次只更新一层参数?这好比爬山,当路比较宽的时候,大家可以同时走,当路比较窄的时候只能排着队走,只要保证后面的人不掉队保持步调一致就行。实际上模型更新也可以这么做。具体的做法如下:第一步,先将模型进行分组,如上图所示,假设分了K组,K最大为模型的层数。分完组后,在微调过程中,每个step只更新其中一组参数,假设我们选择了自低向上 (bottom2up)的更新方式, 当更新低2组的参数时候,冻结其它层的参数,下一个step更新第三组的参数同时冻结其它层,按照自低向上的顺序依次更新每一层参数,直至模型收敛。这样更新方式有一个问题,如果每个step都对学习率更新,那会导致第K组参数更新幅度过大导致整个模型参数更细幅度不一致出现loss震荡的情况,这就好比排队爬山,如果第一个人永远迈大步,而最后一个人永远迈小步,最后就会出现人员掉队的问题,模型的更新也是同样的道理。因此我们采用了学习率延迟更新的策略,只有当所有层都更细一遍的时候,才对学习率进行更新。

这种分层更新的方式每个step只有一组参数是需要计算梯度的,剩下的参数都是冻结状态。当K等于模型层数时,训练过程中的峰值训练参数量等于模型中参数量最大的那一层的参数量。可训练参数量的减少,将直接会降低梯度参数的内存使用。由于每个step只更新部分参数,因此没有必要将所有的优化器状态都同时保存在GPU上,所以每个step只有需要进行梯度更新参数的状态会在GPU上,而其它参数的优化器状态都会被保留在CPU上。

图片

图2:在 E2E  数据集上微调 LLaMA2-7B 的 GPU 内存使用情况。序列长度和批量大小分别设置为 512 和 6。#Dtype表示训练时使用的数据类型,其中FP32表示以32位精度对模型进行全参数微调,mixed表示以混合精度进行微调。#Trainable 参数表示微调过程中单个步骤中出现的最大可训练参数数量。#Para表示模型参数占用的内存,#Gra表示梯度占用的内存,#Sta表示优化器状态占用的内存。#PGS 表示模型参数(即#Para)、梯度(即#Gra)和优化器状态(即#Sta)占用的内存总和。残留状态主要包括激活状态、临时缓冲区和不可用的碎片内存。Total 表示微调期间使用的总内存。HiFT的参数K设置为最大,既每组一层参数。 需要说明的是,混合精度的结果是混合精度未适配分层微调下的结果。 最新的结果混合精度适配分层策略后,支持24G设备全参数微调7B模型

以adamw 优化器为例,标准全参数微调下需要GPU的参数量为=。 分层微调下 =.   两者的差距为 = 。 以使用AdamW优化器对LLaMA-7B单精度微调为例,模型参数 约为26.08G,理论上,7B 模型的, 和微调所需的 GPU 内存约为 104.32 GB。LLaMA-7B 有34 层(包括embedding层和head层)。可推断, 和所需GPU显存为约 31.13G。与标准的全参数微调相比节省GPU显存约73.19G。

大模型下混合精度的问题

图2为 使用不同优化器微调LLaMA2-7B模型是的GPU使用情况,需要说明的是,混合精度的结果是混合精度未适配分层微调下的结果。可以看出混合精度下的内存需求比不使用混合精度的更高。解释这一结果我们需要了解混合精度的原理。混合精度虽然会以半精度进行前向传播,但是在更新模型的时候,会备份出32位的权重进行梯度更新,同时梯度也会在32位环境下进行更新,原因是位解决数值下溢的问题,在训练后期,激活函数的梯度会非常小, 甚至在梯度乘以学习率后,值会更加小。如果利用 fp16 来进行参数更新的话,会出现舍入误差(fp16的表示范围有限,超过这个范围的值会被置为0)问题,导致更新无效。

问题来了,为什么混合精度能降低内存使用?结论是混合精度能大幅降低动态内存的使用,既主要是激活神经元的内存占用。以往的模型比如RoBERTa,GPT-2这些模型,模型的参数量较小(相比于大模型说),这个时候,模型占用的参数量十分有限,微调过程中使用内存较多的是和输入相关的动态内存部分。比如RoBERTa-base,以adamw 微调过程中(batch=8,max length=512),模型参数,梯度,优化器状态占据的固定内存约为1.9G.  而动态内存部分约为5G.  小模型时代,微调过程中动态内存占主导,当使用混合精度微调时,由于前向传播半精度该部分内存会大幅降低,batch越大,句子长度越大,混合精度的降低的内存越多,可以抵消多出来的16为模型权重占据的内存。

大模型背景下,因为设备限制,很多情况下无法使用大的batch发挥出混合精度的优势。以65B模型为例,单精度仅模型参数需求内存约为242G,半精度的内存需求约为121G,混合精度下仅模型参数需求的内存为242+121=363G,只有当混合精度降低的动态内存超过121G时候,混合精度的优势才能体现出来,但是大多数情况下,设备的限制,无法使用大的batch size. 根据我们的实验结果看,当微调3B(GPT-Neo)左右的模型时候,在小的batch下(我们设置的batch size 为8,句子长度为512),混合精度已经没有内存优势。当然微调GPT-large时 混合精度仍然有优势。

混合精度适配分层策略

目前论文的结果是原始混合精度的结果,目前最新进展已将混合精度进行了分层适配,batch=1,句子长度为512下,实测结果7B模型大约混合精度下内存需求为16.87G,13B模型为31G。请持续关注我们的工作,代码基于基于hugging face,代码即将开源,可移植性友好,可与LORA等其PEFT的方式融合

图片

图3:(a)显示了RoBERTa-base在不同策略下的HiFT性能。B2U、T2D和RAN分别代表使用策略bottom2up、top2bottom和随机微调策略。(b)表示RoBERTa-base在不同分组设置下的HiFT性能。m表示分组设置中每组的层数。

图3显示了采用不同的微调策略和设置不同组队模型性能的影响。结论影响可以忽略不计。 一个有意思的现象是微调的顺序对模型性能几乎没有影响。

VS MeZO和LOMO

  1. 我们的方法和MeZO的在下游任务的比较参请考原始论文,我们的方法比MEZO具有明显的性能优势。请持续关注我们的最新论文获取最新的结果。

  2. 和LOMO相比,在不使用其它内存节省技术下,使用e2e 数据,对llama2-7B 测试,batch为1,最大句子长度为512。混合精度下,LOMO的峰值内存为21.57GB,HiFT为16.87GB;单精度下:LOMO 60.06GB,HiFT 29.73GB

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/367364.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Fluent的小bug处理:后处理截面显示存在漏洞

最近发现的Fluent的bug,关于后处理截面显示不完整的问题。 1 现象 在使用六面体核心类型单元(包括四面体-六面体核心和多面体-六面体核心)进行网格划分的时候,可能会在截面上不能完整捕捉单元形状及其分布状态,导致做…

后端——go系统学习笔记(不断更新中......)

数组 固定大小 初始化 arr1 : [3]int{1, 2, 3} arr2 : [...]int{1, 2, 3} var arr3 []int var arr4 [4]int切片 长度是动态的 初始化 arr[0:3] slice : []int{1,2,3} slice : make([]int, 10)len和cap len是获取切片、数组、字符串的长度——元素的个数cap是获取切片的容量—…

docker相关问题解决(file exists、not a directory

背景 以下环境为wsl file exists 缓存没删干净 docker-compose down -v not a directory flags: 0x5000: not a directory: unknown: Are you trying to mount a directory onto a file (or vice-versa)? 明明我确定报错指示的位置就是文件而不是文件夹...相当神奇的错误 …

【lesson2】定长内存池的实现

文章目录 介绍定长内存池的设计定长内存池的实现需要成员变量需要的成员函数定长内存池结构定长内存池Delete(释放空间)的实现定长内存池New(申请空间)的实现 定长内存池的实现完整版 介绍 作为程序员(C/C)我们知道申请内存使用的…

谷歌产品大更新:Bard可生成图像;文生音乐平台等5大免费功能

2月2日,谷歌在官网对生成式AI产品进行了大更新,包括类ChatGPT聊天助手Bard可以通过文本提示生成图像; 全新的文生音乐平台MusicFX;新的文生图像平台ImageFX;新的文本扩写平台TextFX;在谷歌地图中增加生成式…

Open3D 深度图像转点云

目录 一、算法原理1、算法过程2、主要函数3、算法源码二、代码实现三、结果展示1、深度图像2、点云四、测试数据

Python详细教程

一、Python简历 Python 是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。 Python 的设计具有很强的可读性,相比其他语言经常使用英文关键字,其他语言的一些标点符号,它具有比其他语言更有特色语法结构。 Python 是一种解…

客户端和服务端的简介

Client 和 Server 客户端(Client) 或称用户端,是指与服务器相对应,为客户提供本地服务的程序。除了一些只在本地运行的应用程序之外,一般安装在客户机上,需要与服务端互相配合运行。例如:下载 Q…

jvm基础篇之垃圾回收[3](垃圾回收器)

文章目录 分代GC代取划分原因垃圾回收器组合关系年轻代-Serial垃圾回收器老年代-SerialOld垃圾回收器年轻代-ParNew垃圾回收器老年代-CMS垃圾回收器年轻代-Parallel Scavenge垃圾回收器老年代-Parallel Old垃圾回收器 G1垃圾回收器G1内存结构G1回收方式年轻代回收混合回收FULL …

安全通信设置:使用 OpenSSL 为 Logstash 和 Filebeat 提供 SSL 证书

在为 Elasticsearch 采集数据时,我们经常使用到 Filebeat 及 Logstash。在我们之前的很多教程中,我们通常不为 Filebeat 和 Logstash 之前的通信做安全配置。 如何为 Filebeat 及 Logstash 直接建立安全的链接?这个在很多的情况下是非常有用的…

2024美赛A题七鳃鳗种群复杂系统动力学模型完整成品论文和代码

经过不懈的努力,2024美赛A题完整成品论文和代码已完成,代码为A题全部4问的代码,论文包括摘要、问题重述、问题分析、模型假设、符号说明、模型的建立和求解(问题1七鳃鳗种群竞争模型的建立和求解、问题2种群优势劣势评估模型的建立…

很多人不看好造车新势力,我却坚信他们一定会成功

最近我国出现了很多新能源汽车品牌,除了理小蔚之外,最近爆火的华为与赛力斯合作的问界以及小米借用北汽生产的小米SU7汽车。可能是于大嘴和雷布斯营销过度了,引起了很多网民的质疑,更是引来了汽车大佬长安董事长的担忧。朱董事长说…

如何取消隐藏Excel中的行?这里提供详细步骤

取消隐藏Microsoft Excel电子表格中的所有行就像按下键盘快捷键或使用功能区上的按钮一样简单。我们将向你展示如何操作。 如何使用快捷方式取消隐藏Excel中的所有行 若要在电子表格中显示隐藏行,请使用Microsoft Excel启动电子表格。然后,访问包含隐藏…

超详细Anconda pytorch cuda cuDNN安装及介绍(李沐老师视频环境)

零、准备知识阶段 ⇲ 显卡驱动、CUDA、cuDNN之间联系以及安装配置 在配置PyTorch的过程中,显卡驱动、CUDA、cuDNN三者之间的关系、作用以及在众多版本中如何搭配一直困扰着我。虽然网上资料很多,但各说其词,即使最终迈过种种坑成功运行&…

快充协议的奥秘:工作原理与特性比较

文章目录 一、 前言二、快充协议1.公有协议1.1 PD协议介绍发展史USB-IF组织PPS快充协议 1.2 QC协议介绍发展史 1.3 PE协议介绍 2.私有协议2.1 VOOC 闪充2.2 FCP/SCP 闪充2.3 FlashCharge闪充2.4 MIChargeTurbo闪充2.5 AFC闪充2.6 mCharge快充 三、总结 一、 前言 最近&#xf…

【深度学习】数据归一化/标准化 Normalization/Standardization

目录 一、实际问题 二、归一化 Normalization 三、归一化的类型 1. Min-max normalization (Rescaling) 2. Mean normalization 3.Z-score normalization (Standardization) 4.非线性归一化 4-1 对数归一化 4-2 反正切函数归一化 4-3 小数定标标准化(Demi…

echarts中绘制3D三维地球

简介 echarts中的三维地球,需要用到世界地图json数据,我把json文件放到我的资源中,有需要的自行下载。 安装插件 // 安装echats npm install echarts --save npm install echarts-gl --save 项目中引用 1,引入安装的echarts…

Http请求Cookie失效问题

Http请求Cookie失效问题记录 一、问题现象 在开发功能的过程中,业务依赖cookie进行取之,项目进行交互时会对前端http请求携带的cookies进行解析操作,但在自测调试对过程中出现账户的授权失效的报错问题。 二、问题排查 用arthas进行代码方…

【国产MCU】-CH32V307-GPIO控制:输入与输出

GPIO控制:输入与输出 文章目录 GPIO控制:输入与输出1、GPIO简单介绍2、驱动API介绍3、GPIO配置代码实现3.1 GPIO配置为输出3.2 GPIO配置为输入CH32V307的GPIO口可以配置成多种输入或输出模式,内置可关闭的上拉或下拉电阻,可以配置成推挽或开漏功能。GPIO口还可以复用成其他…

【24美赛思路已出】2024年美赛A~F题解题思路已出 | 无偿自提

A题:资源可用性和性别比例 问题一: 涉及当灯鱼种群的性别比例发生变化时,对更大的生态系统产生的影响。为了分析这个问题,可以采用以下的数学建模思路:建立灯鱼种群模型: 首先,建立一个灯鱼种群…