LLM 可以从简单数据中学习吗?

在 10 月份的一次周会结束后,我提到 SFT 训练后的 Loss 曲线呈现阶梯状,至于为什么,并没有人有合理的解释,加上当时的重心是提升次日留存率,Loss 曲线呈现阶梯状与次日留存率的关系还太远,即使有问题,起码次日留存率是逐渐在提升。

幸运的是,在一次逛论坛时发现了一篇博客 Can LLMs learn from a single example?,也是我这篇博客的标题名称由来,在其基础上结合了公司业务的一些现状和我个人的思考。
图:fastchat 框架在 VTuber 数据集上训练全 BOT 回复,在 3 epoch 上的 Loss 曲线。

可以清楚地看到每个 epoch 的终点——loss 突然向下跳。我们以前也见过类似的损失曲线,但都是由于错误造成的。例如,在评估验证集时,很容易意外地让模型继续学习——这样在验证之后,模型就会突然变得更好。因此,开始寻找训练过程中的错误。

发现该“问题”的时间,恰好与单句重复问题同一时期(9 月份),于是推测是不是 context length 从 2k 变到 4k 所致,以及 Transformers 库和 RoPE 位置编码的问题。在开始逐步修改代码的同时,在 Alignment Lab AI Discord 上看到他人反馈的类似的奇怪 loss 曲线,并且每个回复的人也都在使用 Trainer,这在当时加深了我认为 Transformers 库存在问题的猜测,甚至我还去询问了同事李老师是否有同样的问题,以及 load model 时的 warning。

9 月中旬,老板要求我们加上验证 loss,于是出现了如下图所示的 eval loss 曲线。

eval loss 曲线

该问题在 Discord 上讨论得越来越激烈,也有人反映在不使用 Trainer 的情况下,也会出现阶梯状的 loss 曲线。

查阅资料,看到一种假设:即这些训练曲线实际上显示了过拟合。起初,这似乎是不可能的。这意味着模型正在学习识别来自一个或两个示例的输入。如果回过头来看我们展示的第一条曲线,就会发现 loss 在第二和第三个 epoch 期间,它根本没有学习到任何新东西。因此,除了在第一个 epoch 开始时的初始学习(学习了多轮对话的对齐方式)外,几乎所有表面上的学习都是(根据这一理论)对训练集的记忆。此外,对于每个问题,它只能获得极少量的信号:它对答案的猜测与真实标签的比较。

资料提到了一项实验:使用以下学习率计划,对 Kaggle 模型进行了两个 epoch 的训练:

如今,这种 schedule 并不常见,但莱斯利-史密斯(Leslie Smith)在 2015 年发表的论文《训练神经网络的循环学习率》(Cyclical Learning Rates for Training Neural Networks)中讨论了这种方法,并取得了很大成功。

下面就是我们因此而看到的看起来很疯狂的训练和验证损失曲线:

到目前为止,我们唯一能完全解释这种情况的方法就是假设是正确的:模型正在快速学习识别实例,即使只看到一次。让我们依次查看 loss 曲线的各个部分:

  • 从第一个 epoch 来看,这是一条非常标准的 loss 曲线。在第一个 10% 的 epoch 中,学习率开始升温,一旦达到温度后,训练和验证 loss 就会迅速降低;然后按照余弦曲线逐渐下降,两者都会放缓。
  • 第二个 epoch 才是我们感兴趣的地方。我们并没有在 epoch 开始时重新 shuffle 数据集,因此第二个 epoch 的第一批数据是学习率仍在预热的时候。这就是为什么在我们展示的第一条 loss 曲线中,没有看到像从 epoch 2 到 epoch 3 那样的直接阶跃变化——这些批次只有在学习率较低时才会出现,所以它学不到太多东西。在 epoch 2 开始 10% 时,训练 loss 急剧下降,因为在第一个 epoch 中看到这些批次时,学习率很高,模型已经知道了它们的样子,因此它可以非常自信地猜出正确答案。但在此期间,验证 loss 会受到影响。这是因为虽然模型变得非常自信,但实际上它的预测能力并没有提高。它只是记住了数据集(早期没有清洗掉训练数据中的保底回复以及一些涉及到公司信息的关键词,模型会输出这些内容,甚至会将原样的超时保底回复输出),但并没有提高泛化能力。过于自信的预测会导致验证损失变大,因为损失函数会对更自信的错误进行更高的惩罚。
  • 曲线的末端是特别有趣的地方。训练 loss 开始变得越来越大,而这是绝对不应该发生的!事实上,我还从未在使用合理的学习率时遇到过这种情况。根据记忆假说,这完全说得通:这些批次是模型在学习率再次下降时看到的,因此它无法有效地记忆这些批次。但模型仍然过于自信,因为它刚刚得到了一大堆几乎完全正确的批次,还没有适应现在看到的批次没有机会学得那么好这一事实。它会逐渐重新校准到一个更合理的置信度水平,但这需要一段时间,因为学习率越来越低。在重新校准的过程中,验证 loss 会再次下降。
    记忆假说很有可能是真的。按照先前小模型时代的训练经验,我们往往需要大量的数据来让模型学习输入分布和模式。使用随机梯度下降法(SGD)导航的损失面太崎岖,无法一下子跳得很远。不过,有些东西可以让损失面变得更平滑,比如使用残差连接,如经典论文《可视化神经网络的损失景观》(Li et al,2018)中所示。

很可能的情况是,预训练的大语言模型在接近最小损失的区域具有极其平滑的损失面,而开源社区所做的大量微调工作都是在这一区域。这是基于最初开发微调通用语言模型的基本前提。简单来说,我们的训练数据并不能够让模型跳出该平滑的损失面,只是让模型记住了 BOT 的回复、以及通过几个数据就让模型学到了说话风格。

如果以上猜测都属实,这不是什么糟糕的事情,拥有一个学习速度非常快、且能够举一反三的模型是一件非常棒的事情。同时,这也佐证了《LIMA:Less Is More for Alignment》、《A Few More Examples May Be Worth Billions of Parameters》、《Maybe only 0.5% Data is Needed: A Preliminary Exploration of Low Training Data Instruction Tuning》等一系列证明少量优质、多样性丰富的指令数据就能让模型有很强指令遵循的论文的有效性。以及最近出现的一系列关于指令数据集子集选择的论文,例如《Smaller Language Models are capable of selecting Instruction-Tuning Training Data for Larger Language Models》、《LESS: Selecting Influential Data for Targeted Instruction Tuning》。这些论文提到经过他们方法挑选出来的子集,在该子集上训练出来的模型比在全量数据集上微调的模型效果要更好。

我统计了从 7 月 到 11 月份所训练模型的 Loss 曲线是否呈现阶梯状,正常表示平滑下降,不正常表示阶梯下降(在每个 epoch 交界处骤降)。早期训练的模型的 loss 曲线都是正常,可惜的是早期的训练数据被删了,无法准确地判断是数据质量的因素,还是基底模型的因素。

早期训练遵循多阶段的方式,即先在 continual pretrain 得到的 base 模型上用 GPT4all 数据集以及一个闲聊场景的对话集进行训练,然后再用高质量的对话数据集再次微调。以此得到的模型表现平常,虽不会犯错,但也没有新意,不能提升平均对话轮数,因此后续我们不再进行 base model -> GPT4all + 闲聊数据集 -> 高质量对话数据集的多段式 SFT,而是直接在 base model 上用高质量对话数据集进行 SFT。在这之后训练的模型的 loss 曲线都是阶梯状,按照记忆假说和先前分析的内容来看,llama2、vicuna-13b-v1.5 等模型的对话、闲聊能力得到了提升(也有可能是 GPT4all 数据集让模型闲聊能力下降),在我们所认为的“高质量”数据集上进行训练,模型只是记住了对话内容,而非真正意义上地学习(训练数据集对于模型来说非常简单)。

PS:我没有否认和贬低这种方式,当模型的“脑容量”(记忆力)大到能够将我们提供的优质回复都记住,并且在合适的场景输出,这在业务上完全没有问题。在复读机问题上,将高质量数据集从 4k 扩充至 26k 后,的确减少了该问题的频次。

一个猜想:当模型的学习速度如此之快时,灾难性遗忘问题可能会突然变得明显得多。例如,如果一个模型看到了十个非常常见关系的示例,然后又看到了一个不太常见的反例,那么它很可能会记住这个反例,而不仅仅是稍微降低它对原来十个示例的记忆权重。

在 6 月下旬时,老板询问我为什么模型的效果不太好时,我想了想说是灾难性遗忘(找的理由)。现在看来,似乎的确大概率是这个原因。沿着 base model -> GPT4all + 闲聊数据集 -> 高质量数据集训练的路径,希望模型能够不断地进化,但实际上 base model 原先的知识和 GPT4all 数据集中的内容都遗忘得差不多。因此,不要多阶段 SFT,而是将每个阶段的训练数据进行混合,可以减少灾难性遗忘的影响,这或许就是后来尝试数据混合方案后,能够提升次日留存率的一个原因?

此外,我们还需要审视,对于模型来说,什么是高质量的数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/612273.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用单片机的IO引脚直接驱动段码屏

使用单片机的IO引脚直接驱动段码屏,目的是为了降低成本。这种古老的应用,在低功耗产品中比较多见。 如:水表,燃气表等需要电池供电的产品。 下面纯属个人理解,未经测试。 1/3Duty表示LCD共有3个COM引脚,分别占显示周期的1/3 1/2BIAS表示电压0和VCC 1、…

2024年记一次Mingw64-13.2.0编译Qt6.6.3,包含文档编译。

My C Development. 前言:不包含qtwebengine。 一、准备文件 (1)mingw64-13.2.0 下载链接:,ucrt64_13.2_ucrt_posix_rev6_msys2.7z【蓝奏云】。 (2)qt6.6.3源码 下载链接:Downlo…

纯血鸿蒙APP实战开发——一镜到底“页面转场”动画

介绍 本方案做的是页面点击卡片跳转到详情预览的转场动画效果 效果图预览 使用说明 点击首页卡片跳转到详情页,再点击进入路由页面按钮,进入新的路由页面 实现思路 首页使用了一种视觉上看起来像是组件的转场动画,这种转场动画通常是通过…

opencv绘制灰度直方图-------c++

灰度直方图 cv::Mat opencvTool::calculateHistogram(const cv::Mat& image) {// 如果输入图像尚未处于灰度级,请将其转换为灰度级cv::Mat grayscale_image;if (image.channels() > 1){cv::cvtColor(image, grayscale_image, cv::COLOR_BGR2GRAY);}else{gra…

求一个B站屏蔽竖屏视频的脚本

求一个B站屏蔽竖屏视频的脚本 现在B站竖屏竖屏越来越多了,手机还好点给我一个按钮,选择不喜欢,但是我一般都用网页版看视屏,网页版不给我选择不喜欢的按钮,目测大概1/4到1/3的视频都是竖屏视频。 目前网页版唯一的进…

使用AudioCraft(MusicGen)生成音乐

AudioCraft 是一个 PyTorch 库,用于音频生成的深度学习研究。AudioCraft 包含 AudioGen 和 MusicGen 两个最先进的人工智能生成模型的推理和训练代码,用于生成高质量的音频。 MusicGen 是一种简单可控的音乐生成模型,它使用Meta 20K 小时的授权音乐来进行训练,能够生成与文…

SM4在线解密工具(支持GCM模式)

SM4在线解密工具(支持GCM模式)

spring boot参数验证注解@NotNull、@NotBlank和@NotEmpty区别

目录 前言说明举例 前言 使用spring boot参数验证是常常会使用NotNull、NotBlank和NotEmpty三个判断是否不为空的注解,中文都有不能为空的意思,大部分使用者都傻傻分清它们之间到底有什么区别。今天就让咱们来一起探索它们之间的不同吧。 说明 注解名…

rngd: Error writing /dev/tpm0

检查数据库时发现messages中一直有rngd报错,rngd一直未配置,直接关闭了 /var/log/messages-20240414:Apr 11 04:59:49 hydb2 rngd: Error writing /dev/tpm0 /var/log/messages-20240414:Apr 12 07:31:39 hydb2 rngd: Error writing /dev/tpm0 /var/log…

深度学习之前馈神经网络

1.导入常用工具包 #在终端中输入以下命令就可以安装工具包 pip install numpy pip install pandas Pip install matplotlib注: numpy是科学计算基础包 pandas能方便处理结构化数据和函数 matplotlib主要用于绘制图表。 #导包的代码: import numpy as n…

怎样的跨网软件,可以实现网间数据的安全收发?

网络隔离已是较为常见的网络安全保护措施,比如防火墙、网闸、VLAN,云桌面虚拟环境等方面进行隔离。像一些科技研发型企业,不仅仅是内外网隔离,甚至还划分办公网、研发网、测试网、生产网等,防止研发资料、设计资料等敏…

【机器学习300问】85、Adam梯度下降优化算法的原理是什么?

Adam优化算法取了两个算法名称的首字母——Adaptive Moment Estimation的缩写,结合了Momentum算法和RMSprop算法的优点。在Momentum中,会计算前一时刻的梯度,并将其用于当前时刻的梯度更新;而RMSprop会对梯度的大小进行自适应调整…

二叉树的遍历(前序 中序 后序)

一、前序遍历 顺序为: 根-->左子树---->右子树 先访问根节点,再递归进入根节点的左子树(通过递归不断往下遍历),直到访问的节点没有左子树,此时递归进入其右子树(通过递归进行相同操作&a…

vue布局设置——使用 el-drawer 打造个性化 Admin 后台布局设置

在前端开发中,我们常常需要为 admin 后台构建灵活且个性化的布局设置。今天,我要分享的是如何利用 el-drawer 来实现这样一个有趣的功能。 首先,我们来看一下主要的设置参数: 1. theme: 用于定义主题,可以根据需求切换…

Java入门基础学习笔记15——强制类型转换

大范围类型的变量是否可以赋值给小范围类型的变量呢? IDEA直接报错。直接报错,是提醒你有问题。但是我非常进行类型转换。 非要强行赋值呢? 强制类型转换,强行将类型范围大的变量,数据赋值给类型范围小的变量。 数据…

若依生成树表和下拉框选择树表结构(在其他页面使用该下拉框输入)

1.数据库表设计 生成树结构的主要列是id列和parent_id列,后者指向他的父级 2.来到前端代码生成器页面 导入你刚刚写出该格式的数据库表 3.点击编辑,来到字段 祖籍列表是为了好找到直接父类,不属于代码生成器方法,需要后台编…

数据挖掘(二)数据预处理

前言 基于国防科技大学 丁兆云老师的《数据挖掘》 数据挖掘 数据挖掘(一)数据类型与统计 2、数据预处理 2.1数据清理 缺失值处理: from sklearn.impute import SimpleImputer# 创建一个SimpleImputer对象,指定缺失值的处理策略…

day07beef-xss之根据beef-xss获取cookies

1.安装 apt-get update apt-get install beef-xss 若报错运行不了尝试 apt remove ruby apt remove beef-xss apt-get install ruby apt-get install ruby-dev libpcap-dev gem install eventmachine apt-get install beef-xss 2.运行 beef-xss 运行成功会自动弹出浏览框。 攻…

WM Transaction Code 仓库管理模块事务代码大全

1.1 LE-WM 仓库管理 Warehouse Management 仓库管理事务码 描述 LB01 Create Transfer Requirement 创建转储需求 LB02 Change transfer requirement 修改转储需求 LB03 Display Transfer Requirement 显示转储需求 LB10 TRs for Storage Type 按仓储类型的转储请求 …

一次完整的GC流程

Java堆中内存区分 Java的堆由新生代(Young Generation)和老年代(Old Generation)组成。新生代存放新分配的对象,老年代存放长期存在的对象。 新生代(Young)由年轻区(Eden&a…