一些RLHF的平替汇总

卷友们好,我是rumor。

众所周知,RLHF十分玄学且令人望而却步。我听过有的小道消息说提升很大,也有小道消息说效果不明显,究其根本还是系统链路太长自由度太高,不像SFT一样可以通过数据配比、prompt、有限的超参数来可控地调整效果。

但也正是因为它的自由度、以目标为导向的学习范式和性价比更高的标注成本,业内往往认为它会有更高的效果天花板。同时我最近看OpenAI的SuperAlignment计划感受颇深,非常坚定地认为scalable的RLHF(不局限于PPO)就是下一步的大突破所在。

所以我秉着不抛弃不放弃的决心,带大家梳理一下最近的RLHF平替工作,探索如何更稳定地拿到效果。

RLHF链路可以分为两个模块,RM和RL,这两个模块各有各的问题:

  • RM:对准确率和泛化性的要求都很高,不然很容易就被hack到(比如输出某个pattern就给高分)。但业内普遍标注数据的一致率只有70%左右,数据决定效果天花板,如何让RM代表大部分人的判断、且能区分出模型结果的细微差异,难难难。这也是RLHF方法没法规模化起来的主要瓶颈

  • RL:奖励太稀疏(最后一步才拿到句子分数,不像SFT一样有真实的token-level监督信号)、PPO超参数非常多,导致效果很不稳定

针对上述两个模块的问题,学术界大佬们各显神通,大概有以下几种解决方案:

  • 没得商量,不做RL了,选择性保留RM:比如RRHF、DPO,这类方法可以直接在RM数据上优化语言模型,但如果想提升效果,需要用自身模型采样,得再引入一个RM,比如RSO、SCiL、PRO等。又或者直接用RM采样的数据做精调,比如RAFT、Llama2等

  • 用其他RL算法:比如ReMax、Decision Transformer

下面我们就逐一盘盘这些方法以及他们给出的有用结论。

不做RL了

RRHF

RRHF: Rank Responses to Align Language Models with Human Feedback without tears

RRHF是阿里在今年年初(2023.04)发布的工作,它的做法是直接在RM数据山优化LM,让chosen回答的概率大于rejected回答的概率。

3cdb8750d6c378a68a290a0210f68fee.png
RRHF

在具体实现上,就是计算句子的条件概率后加一个ranking loss:

9354be5b214f574336188c07bebd31bc.png
RRHF loss

但在实践中,作者发现只用ranking loss会把模型训崩溃,所以又加了SFT loss。从消融实验可以看到加了rank loss确实对模型效果有一些提升:

e391fe43244a4166662f821d79929ac1.png

最终在HH数据集上,作者提出的RRHF平均得分略好于PPO(-1.02 vs -1.03),效果差距不是太大,但该方法主打一个便捷稳定。

同时作者也在实验中尝试了不同的数据采样策略:

  1. 直接用开源RM的数据

  2. 用自己的模型生成response,用开源RM进行排序,做出新的RM数据

  3. 循环执行2,类似强化的思维不断靠自身采样到更好的答案

最后的结论也比较符合直接,是3>2>1。

Preference Ranking Optimization for Human Alignment

后续阿里(非同作者)在2023.06又提出了一个PRO方法,核心思想跟RRHF接近,但有两个不同:

  1. 选用了更多负例,不止停留在pair-wise

  2. 给不同负例不同的惩罚项(比如分数差的多就拉大一些)

19d31e97ad644cd9e235c40c914dffa4.png
PRO

同时也加上了SFT loss,最终效果比RLHF和RRHF都有些提升。

DPO

Direct Preference Optimization:Your Language Model is Secretly a Reward Model

DPO是斯坦福在2023.05底提出的工作,主打一个硬核,直接从PPO公式推出了一个平替方案,虽然最终loss呈现的思想跟RRHF接近(chosen句子概率>rejected句子概率),但同时带有一个SFT模型的约束,可以保证在不加SFT loss的情况下训练不崩溃(个人猜测)。

5b0a7e2afe2eb4eea10d52b37e7f6005.png
DPO

作者在公开的几个RM数据集上都做了实验,可以发现DPO对超参数的敏感度更低,效果更稳定,且奖励得分优于RLHF。

同时,微软在2023.10月的一篇工作[1]上也对DPO做了进一步的探索。考虑到排序数据成本,他们直接默认GPT4 > ChatGPT > InstructGPT,实验后得到以下结论:

  1. 用DPO在 GPT4 vs InstructGPT 上训练的效果 > 直接在GPT-4数据精调的效果

  2. 先在简单的pair上训练后,再在困难的pair上训练会有更好的效果

RSO

Statistical rejection sampling improves preference optimization

上面介绍了两种ranking思想的loss,具体哪种更好一些呢?DeepMind在2023.09月份的一篇RSO[2]工作中进行了更系统的对比,得到了以下结论:

  1. DPO(sigmoid-norm) loss效果略好,但更重要的是增加SFT约束,可以看表中没加约束的hinge loss效果很差,但加了约束后则能接近DPO

  2. 另外重要的还有采样策略,比如要优化模型A,最好用模型A生产的结果,去做pair标注,再训练A,比用模型B生产的数据训练A更好。这跟RRHF的结论也比较一致,更接近「强化」的思想

e7bba8b68fb7baca4f95fae18f543cef.png
RSO实验结果

同时作者提出了另外一种RSO(Rejection Sampling Optimization)的采样方法,实验发现有2个点左右的提升。

Rejection Sampling + SFT

拒绝采样是一种针对复杂问题的采样策略[3],可以更高效地采样到合适的样本,进行复杂分布的估计。最近也有很多方法,利用RM进行拒绝采样,直接用采样出的数据对模型做SFT。

Llama 2: Open Foundation and Fine-Tuned Chat Models

LLama2就很好地使用了拒绝采样,先问问地训RM,再用RM筛选出当前模型最好的结果进行SFT。论文发出时他们一共把llama2-chat迭代了5轮,前4轮都是用的拒绝采样,只有最后一轮用了PPO,可以看到相比ChatGPT的胜率一直在提升:

7c98a32df087ebd3138d8fcd8cb437a4.png

不过从RLHF v5(no PPO)和RLHF v5(with PPO)来看,RL还是能有很大的效果收益。

这种方法还有很多变体可以探索,比如港大在2023.04提出的RAFT[4],就是选取多个样本进行后续精调。同时采样策略也可以进行一些优化,比如上面提到的RSO。

用其他RL算法

ReMax

ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models

ReMax是港中文在2023.10提出的工作,核心是对RLHF中RL阶段的PPO算法进行了简化。

强化的难点是怎么把多步之后的最终目标转化成模型loss,针对这个问题有不同解决方案,目前OpenAI所使用的RL策略叫PPO[5],是他们自己在2017年提出的一个经典RL算法(OpenAI早期真的做了很多强化的工作)。

但ReMax的作者认为,PPO并不适用于语言模型的场景:

  1. 可以快速拿到句子奖励:传统RL的长期奖励获取可能会比较昂贵,比如必须玩完一局游戏、拿起一个杯子,而RLHF在有了RM后可以快速拿到奖励

  2. 确定性的环境:传统RL中,环境也是变化的,同一个场景+动作可能拿到不同奖励,而在语言模型中,给定上下文和当前结果,下一步的状态也是确定的,RM打分也是确定的

上面两点在传统RL中会造成学习不稳定的问题,因此PPO使用了Actor-Critic网络,即引入一个「助教」来给模型的每一步打分,而作者认为在语言模型上可以省去。

cc6e9a1f6f7d4725d6991b47ef5075a3.png
ReMax

因此,作者提出用强化中的REINFORCE算法来代替PPO,去掉了Critic模型,但作者在实验中同样发现了梯度方差较大优化不稳定的问题,于是增加了一项bias来降低方差,命名为ReMax算法。

由于资源受限,作者没跑通7B的PPO实验,只对比了1.3B的ReMax和PPO,效果显示ReMax更好一些:

7dcaba854a5e645156fee110c78fffc0.png
ReMax效果

除了效果提升之外,由于去掉了一个要训练的模型,在显存占用和训练速度上都有提升。

Offline RL: Decision Transformer

上面我们说的PPO、REINFORCE都是Online RL,需要一个虚拟环境,通过互动拿到奖励,再进行学习。相对的,Offline RL是指直接拿之前和环境互动的数据来学习。

Aligning Language Models with Offline Reinforcement Learning from Human Feedback

这篇是英伟达在2023.08提出的工作,探索了MLE、用reward做回归、DT(Decision Transformer)三种离线强化算法,最终发现DT的效果更好。

Decision Transformer是一个2021 RL Transformer的开山之作,但NLPer一看就懂:

55e3568db130e1ce58958c176bb7c725.png
Decision Transformer

它的核心思想是把奖励、状态作为输入,让模型预测动作,从而建模三者之间的关系。比如模型训练时见过1分的答案,也见过5分的,那预测时直接输入<reward>5.0让它给出最好的结果。

这样训下来效果居然还不错,也超过了PPO:

47aa44a7c0468f58579c722821dbb3f3.png
DT效果

SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF

没想到的是,英伟达不同团队在2023.10月又推出了一篇SteerLM的工作,与DT的思想类似,但会把奖励分为不同维度,比如质量、帮助性等等。

ce60948c6b77d9a4bce1c3c3702465af.png
SteerLM

具体做法:

  1. 通过人工标注的各个维度打分,训练一个打分模型

  2. 用打分模型对更多数据打分

  3. 精调一个SFT模型,可以做到输入prompt、目标分数,输出符合分数的结果

  4. 用第三步的模型生产更多答案,再打分,如此循环

最终的效果也是好于RLHF(PPO哭晕在厕所):

5b1bf4e731aed57877ce79bec24c03d8.png
SteerLM效果

总结

以上就是我最近关注的RLHF平替方法,虽然可走的路很多,但很难有一个可靠且全面的效果对比,毕竟RLHF本身就难训不稳定,几百条数据下波动几个点很正常,而且无论是自动测评还是人工测评都会带有bias。

但对于资源有限的团队来说,平替方案不失为一种选择。

参考资料

[1]

Contrastive Post-training Large Language Models on Data Curriculum: https://arxiv.org/abs/2310.02263

[2]

Statistical rejection sampling improves preference optimization: https://arxiv.org/pdf/2309.06657.pdf

[3]

理解Rejection Sampling: https://gaolei786.github.io/statistics/reject.html

[4]

RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment: https://arxiv.org/abs/2304.06767

[5]

PPO: https://arxiv.org/pdf/1707.06347.pdf

a0022f25eab886b2218d09c47b40e3cf.jpeg


我是朋克又极客的AI算法小姐姐rumor

北航本硕,NLP算法工程师,谷歌开发者专家

欢迎关注我,带你学习带你肝

一起在人工智能时代旋转跳跃眨巴眼

「彩蛋:你能找到几个同厂不同组的相近工作7f32338193390ffe0eb48641d0dfcda5.png07cddb428b6187bb73e1651d114ed08d.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/163241.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【论文解读】FFHQ-UV:用于3D面部重建的归一化面部UV纹理数据集

【论文解读】FFHQ-UV 论文地址&#xff1a;https://arxiv.org/pdf/2211.13874.pdf 0. 摘要 我们提出了一个大规模的面部UV纹理数据集&#xff0c;其中包含超过50,000张高质量的纹理UV贴图&#xff0c;这些贴图具有均匀的照明、中性的表情和清洁的面部区域&#xff0c;这些都是…

【数据预处理2】数据预处理——数据标准化

数据标准化 1. 什么是标准化&#xff1f;   数据标准化是一个常用的数据预处理操作&#xff0c;目的是将不同规格的数据转换到统一规格或不同分布的数据转换到某个特定范围&#xff0c;以减少规模、特征、分布差异等对模型的影响。这种操作也叫作无量纲化。   除了用作模型…

【【萌新的SOC学习之 VDMA 彩条显示实验之一】】

萌新的SOC学习之 VDMA 彩条显示实验之一 实验任务 &#xff1a; 本章的实验任务是 PS写彩条数据至 DDR3 内存中 然后通过 VDMA IP核 将彩条数据显示在 RGB LCD 液晶屏上 下面是本次实验的系统框图 VDMA 通过 HP接口 与 PS端的 DDR 存储器 进行交互 因为 VDMA 出来的是 str…

【数据预处理3】数据预处理 - 归一化和标准化

处理数据之前&#xff0c;通常会使用一些转换函数将「特征数据」转换成更适合「算法模型」的特征数据。这个过程&#xff0c;也叫数据预处理。 比如&#xff0c;我们在择偶时&#xff0c;有身高、体重、存款三个特征&#xff0c;身高是180、体重是180、存款是180000&#xff1…

SpringBoot 整合 Freemarker

通过 Freemarker 模版&#xff0c;我们可以将数据渲染成 HTML 网页、电子邮件、配置文件以及源代码等。 Freemarker 不是面向最终用户的&#xff0c;而是一个 Java 类库&#xff0c;我们可以将之作为一个普通的组件嵌入到我们的产品中。 Freemarker 模版后缀为 .ftl(FreeMarke…

python算法例10 整数转换为罗马数字

1. 问题描述 给定一个整数&#xff0c;将其转换为罗马数字&#xff0c;要求返回结果的取值范围为1~3999。 2. 问题示例 4→Ⅳ&#xff0c;12→Ⅻ&#xff0c;21→XⅪ&#xff0c;99→XCIX。 3. 代码实现 def int_to_roman(num):val [1000, 900, 500, 400,100, 90, 50, 40…

【DevOps】Git 图文详解(四):Git 使用入门

Git 图文详解&#xff08;四&#xff09;&#xff1a;Git 使用入门 1.创建仓库2.暂存区 add3.提交 commit 记录4.Git 的 “指针” 引用5.提交的唯一标识 id&#xff0c;HEAD~n 是什么意思&#xff1f;6.比较 diff 1.创建仓库 创建本地仓库的方法有两种&#xff1a; 一种是创建…

(Matalb时序预测)PSO-BP粒子群算法优化BP神经网络的多维时序回归预测

目录 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 亮点与优势&#xff1a; 二、实际运行效果&#xff1a; 三、部分程序&#xff1a; 四、完整程序数据说明文档下载&#xff1a; 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 本代码基于Matalb平…

Java Swing算术我最棒

内容要求 1) 本次程序设计是专门针对 Java 课程的,要求使用 Java 语言进行具有一定代码量的程序开发。程序的设计要结合一定的算法&#xff0c;在进行代码编写前要能够设计好自己的算法。 本次程序设计涉及到 Java 的基本语法&#xff0c;即课堂上所介绍的变量、条件语句、循…

vuedraggable拖拽列表设置某一条元素禁止被拖拽

直接上代码 <draggable filter".unDrag"><div class"unDrag">不能拖拽</div><div class"canDrag">可以拖拽</div> </draggable>一、设置filter 在draggable节点的属性filter设置不可拖拽的class名&#…

3D全景视角,足不出户感知真实场景的魅力

近年来&#xff0c;随着科技的快速发展&#xff0c;普通的平面静态视角已经无法满足我们了&#xff0c;不管是视角框架的限制还是片面的环境展示&#xff0c;都不足以让我们深入了解场景环境。随着VR全景技术的日益成熟&#xff0c;3D全景技术的出现为我们提供了全新的视觉体验…

uni-app(1)pages. json和tabBar

第一步 在HBuilderX中新建项目 填写项目名称、确定目录、选择模板、选择Vue版本&#xff1a;3、点击创建 第二步 配置pages.json文件 pages.json是一个非常重要的配置文件&#xff0c;它用于配置小程序的页面路径、窗口表现、导航条样式等信息。 右键点击pages&#xff0c;按…

Kafka(四)消费者消费消息

文章目录 如何确保不重复消费消息&#xff1f;消费者业务逻辑重试消费者提交自定义反序列化类消费者参数配置及其说明重要的参数session.time.ms和heartbeat.interval.ms和group.instance.id增加消费者的吞吐量消费者消费的超时时间和poll()方法的关系 消费者消费逻辑启动消费者…

遗传算法GA-算法原理与算法流程图

本站原创文章&#xff0c;转载请说明来自《老饼讲解-BP神经网络》bp.bbbdata.com 目录 一、遗传算法流程图 1.1. 遗传算法流程图 二、遗传算法的思想与机制 2.1 遗传算法的思想 2.2 遗传算法的机制介绍 三、 遗传算法的算法流程 3.1 遗传算法的算法…

PXE高效批量网络装机

目录 一.PXE 1. 系统装机的三种引导方式 2. 系统安装过程 3. 光盘安装相关文件 4. PXE简介 5. 实现过程 6. PXE优点 二.PXE实现过程 1. 实验准备 2. 搭建DHCP服务器 3. 配置TFTP服务器 4. 准备pxelinu.0文件 5. 挂载镜像准备内核、驱动文件 6. 手写配置文件 7. 准…

强烈 推荐 13 个 Web前端在线代码IDE

codesandbox.io&#xff08;国外&#xff0c;提供免费空间&#xff09; 网址&#xff1a;https://codesandbox.io/ CodeSandbox 专注于构建完整的 Web 应用程序&#xff0c;支持多种流行的前端框架和库&#xff0c;例如 React、Vue 和 Angular。它提供了一系列增强的功能&…

springboot项目中获取业务功能的导入数据模板文件

场景: 在实际业务场景中,经常会遇到某些管理功能需要数据导入共功能,但既然是导入数据,肯定会有规则限制,有规则就会有数据模板,但这个模板一般是让客户自己下载固定规则模板,而不是让客户自己随便上传模板。下面介绍直接下载模板 一、下载模板示例 1、在项目的…

信安.网络安全.UDP协议拥塞

第一部分 如何解决UDP丢包问题 一、UDP 报文格式 每个 UDP 报文分为 UDP 报头和 UDP 数据区两部分。报头由 4 个 16 位长&#xff08;2 字节&#xff09;字段组成&#xff0c;分别说明该报文的源端口、目的端口、报文长度和校验值。UDP 报文格式如图所示。 UDP 报文中每个…

前端性能优化之LightHouse

优质博文&#xff1a;IT-BLOG-CN 一、LightHouse环境搭建 LightHouse是一款由Google开发的开源工具&#xff0c;用于评估Web应用程序的性能和质量。可以将其看作是一个Chrome扩展程序运行&#xff0c;或从命令行运行。为LightHouse提供一个需要审查的网址&#xff0c;它将针对…

基于django水果蔬菜生鲜销售系统

基于django水果蔬菜生鲜销售系统 摘要 基于Django的水果蔬菜生鲜销售系统是一种利用Django框架开发的电子商务平台&#xff0c;旨在提供高效、便捷的购物体验&#xff0c;同时支持水果蔬菜生鲜产品的在线销售。该系统整合了用户管理、产品管理、购物车、订单管理等核心功能&…