跟李沐学AI:InstructGPT论文精读(SFT、RLHF)

原论文:[2203.02155] Training language models to follow instructions with human feedback

原视频:InstructGPT 论文精读【论文精读·48】_哔哩哔哩_bilibili

简介

1. RLHF 的基本概念

RLHF 是一种结合强化学习和人类反馈的训练方法,旨在让模型生成更符合人类期望的输出。它通过引入人类的偏好数据来调整模型的行为,使其不仅能够生成语法正确的文本,还能生成语义上更贴合用户意图的内容

2. RLHF 的主要步骤

RLHF 的流程通常分为以下三个阶段:

(1) 预训练语言模型
  • 在 RLHF 流程开始之前,使用一个预训练的语言模型作为基础。
  • 预训练的目标是让模型掌握大规模语言知识,但此时的模型尚未经过特定任务的优化
(2) 监督微调(Supervised Fine-Tuning, SFT)
  • 在这一阶段,研究人员收集一组高质量的人类生成的指令-响应对(instruction-response pairs),并用这些数据对预训练模型进行监督微调。
  • 这些指令-响应对由人类标注者编写,确保模型学会如何根据明确的指令生成合适的回答
(3) 强化学习优化
  • 强化学习阶段是 RLHF 的核心部分,分为两个子步骤:
    • 奖励模型(Reward Model, RM)训练 :研究人员收集人类对模型输出的偏好数据(例如,标注者选择更优的回答),并用这些数据训练一个奖励模型。奖励模型的作用是为模型生成的回答打分,反映其质量或与人类期望的匹配程度。
    • 策略优化(Policy Optimization) :使用强化学习算法(如 PPO,Proximal Policy Optimization)对模型进行进一步优化。PPO 是一种高效的强化学习算法,能够在保持模型稳定性的同时提升性能。模型会根据奖励模型的反馈不断调整其生成策略,最终生成更符合人类偏好的回答。

3. RLHF 的优势

  • 更高的指令遵从性 :通过引入人类反馈,InstructGPT 能够更好地理解并执行复杂的指令,减少了无关或不准确的回答。
  • 减少有害内容 :RLHF 方法可以帮助模型避免生成有害、偏见或不适当的内容,从而提高安全性。
  • 自然流畅的对话能力 :虽然 InstructGPT 更注重指令执行,但 RLHF 的优化也让模型在对话场景中表现得更加自然。

摘要

扩大模型的规模并不能从本质上让模型遵守用户的意图。如大语言模型可能像用户生成不真实、有害的或者无用的输出。

这篇文章中,作者团队提出了一种将大语言模型在大规模基于人类反馈的微调数据上与用户意图对齐的方法。

从一组标注者编写的提示词和通过OpenAI API提交的提示词开始,作者团队收集了一个数据集,其中包含标注者预期的模型输出结果,作者用这些数据来通过监督学习微调GPT-3。

然后,作者收集了一个模型输出排名的数据集,进一步使用基于人类反馈的强化学习来微调这个监督模型,并将最终得到的模型称为InstructGPT。

导论

大语言模型可以根据提示来完成一系列自然语言处理工作,但是模型经常会出现捏造事实、生成有偏见或有害的文本内容。作者认为这是因为模型训练的目标函数存在错误。模型训练的目标函数是从网络文本中预测下一个词元,这一目标函数与遵循用户提示的目标有一定偏差。这就是语言建模目标没有“对齐”。

为此,作者提出RLHF(Reinforcement Learning form Human Feedback,基于人类反馈的强化学习),以保证模型输出具有帮助性(helpful)、真实性(honest)和无害性(harmless)。

1. 采集各类问题prompt;

2. 标注者对问题进行回答(以模型预期的输出方式进行回答)

3. prompt+回答的文本对讲用于微调GPT-3 (SFT,Self Supervised Learning,有监督微调)

1. 收集一段问题prompt,同时讲模型对这个问题的不同输出进行采样。(模型输出具有随机性,所以对同一个问题会有不同的回答,输出的灵活度取决于解码策略)

2. 标注者对采样的模型输出从好到坏进行排序。

3. 将这些有好坏排序标注的回答数据用于训练一个奖励模型(RM,Reward Model)

这一步可以减轻数据标注成本。

1. 从数据集中随机抽取一个新的提示(prompt),例如“Write a story about frogs”(写一个关于青蛙的故事)。这个提示将作为输入提供给策略模型。

2. 策略模型(Policy Model)根据接收到的提示生成一个输出。在这个例子中,策略模型可能生成一段故事的开头,如“Once upon a time...”(从前……)。 

3. 奖励模型(Reward Model, RM)对策略模型生成的输出进行评估,并计算出一个奖励值(reward)。这个奖励值反映了输出的质量或与人类期望的匹配程度。

4. 计算出的奖励值被用于通过PPO(Proximal Policy Optimization,近端策略优化算法)更新策略模型。PPO 是一种高效的强化学习算法,能够帮助策略模型根据奖励信号调整其生成策略,从而在后续迭代中生成更高质量的输出。

方法

数据集来源:

标注人员编写提示词,提示词涉及如下三类:

  • 一般性的任意提示词
  • 多步问答提示词
  • 用户反馈的提示词:每个用户最多会被采集200个提示词,数据将根据用户ID来划分为训练集、测试集、验证集。

基于这些提示词,作者创造了三种不同的数据集:

  • SFT数据集:标注人员基于提示词编写答案
  • RM数据集:标注人员对模型输出的结果进行排序
  • PPO数据集:没有任何人工标签,直接用于RLHF微调

数据标注:招聘了40人进行数据标注,与这些标注人员紧密联系。

模型

1. 在SFT数据集基础上,对GPT-3进行微调,一共训练了16轮。SFT模型仅仅在第一轮就出现了过拟合,但是对后续步骤没有影响,甚至有帮助。

2. RM模型移除了SFT模型中的Softmax层,转为一层线性层将模型输出投影为一个标量奖励。本文使用了6B的模型作为奖励模型。

RM模型的损失函数如下:

对于每个输入Prompt x,我们生成 K 个不同的答案,这里 K=9。我们将这些答案分为两组:一组是排序较高的答案 Yw​,另一组是排序较低的答案 Yl​。

我们的目标是让模型能够正确地识别出哪些答案更好。为此,我们将 Yw​ 和 Yl​ 输入到奖励模型中,分别得到它们的奖励分数 rθ​(x,Yw​) 和 rθ​(x,Yl​)。由于 Yw​ 的排序高于 Yl​,我们希望 Yw​ 的奖励分数也高于 Yl​。

为了实现这一目标,我们计算两个答案的奖励分数之差,并通过sigmoid函数进行转换,以确保结果在0到1之间。然后,我们对这个结果取对数,并添加一个负号,以确保当 Yw​ 的奖励高于 Yl​ 时,损失值为正。最后,我们将这个值乘以1 / (CK_2),即 1 / 36​,以平均化所有可能的配对组合。

3. 强化学习:使用PPO算法对策略进行优化。

结果 

经过SFT+RLHF后的1.3B模型效果好于原有的175B GPT-3模型。 

局限性:

  • 1. 模型的行为与标注者息息相关
  • 2. 模型不是完全的安全

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/976051.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于YOLO11深度学习的运动鞋品牌检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

条款24:若所有参数皆需类型转换,请为此采用 non-member 函数

1.针对隐式转换的情况&#xff0c;可能会出现误用的情况 示例代码 #include <iostream>class Rational { public:Rational(float iNum1 1, float iNum2 2) { fNum iNum1 / iNum2; }~Rational() {}//自定义逻辑const Rational operator * (const Rational& rhs) …

无人机实战系列(番外一)本地图像+Apple ML Depth Pro

这篇文章作为系列文章 “无人机实战系列” 的一篇番外文章&#xff0c;主要测试了下 Apple 推出的一个基于机器学习的单目图像转深度的工具 ml-depth-pro&#xff0c;这个也是我在找这方面工具时意外发现的一个仓库&#xff0c;后期仍然会以 Depth Anything V2 为主线进行记录。…

MySQL数据库连接池泄露导致MySQL Server超时关闭连接

前言 最近做项目&#xff0c;发现老项目出现xxx&#xff0c;这个错误其实很简单&#xff0c;出现在MySQL数据库Server端对长时间没有使用的client连接执行清楚处理&#xff0c;因为是druid数据库&#xff0c;且在github也出现这样的issue&#xff1a;The last packet successf…

人工智能基础知识笔记一:核函数

1、简介 核函数有严格的数学要求&#xff0c;凡满足Mercer定理【参考本文第9章节】的都可以作为核函数。Mercer 定理确保高维:间任意两个向量的内积一定可以被低维空间中两个向量的某种计算表示(多数时候是内积的某换)。本节通过一个例子讲解核函数的使用。 2、核函数定义 设…

本地部署DeepSeek-R1(Ollama+Docker+OpenWebUI知识库)

安装Ollama 打开 Ollama官网 https://ollama.com/下载安装 Ollama服务默认只允许本机访问&#xff0c;修改允许其它主机访问 OLLAMA_HOST0.0.0.0 ollama serve也可以添加系统环境变量 都知道模型体积很大&#xff0c;顺便也通过环境变量修改模型存放位置&#xff0c;我这…

图论算法篇:BFS宽度优先遍历

那么bfs算法的大名想必大家都一定听闻过&#xff0c;那么也许有的人在认识我们bfs算法之前是先接触的我们的dfs算法&#xff0c;那么目前我们的算法世界中的两种搜索算法就是我们的dfs和我们的bfs&#xff0c;那么废话不多说&#xff0c;就让我们进入bfs算法的学习 BFS算法原理…

初识.git文件泄露

.git 文件泄露 当在一个空目录执行 git init 时&#xff0c;Git 会创建一个 .git 目录。 这个目录包含所有的 Git 存储和操作的对象。 如果想备份或复制一个版本库&#xff0c;只需把这个目录拷贝至另一处就可以了 这是一种常见的安全漏洞&#xff0c;指的是网站的 .git 目录…

【SpringBoot】【JWT】使用JWT的claims()方法存入Integer类型数据自动转为Double类型

生成令牌时使用Map存入Integer类型数据&#xff0c;将map使用claims方法放入JWT令牌后&#xff0c;取出时变成Double类型&#xff0c;强转报错&#xff1a; 解决&#xff1a; 将Integer转为String后存入JWT令牌&#xff0c;不会被自动转为其他类型&#xff0c;取出后转为Integ…

JVM之JVM的组成

Java 虚拟机&#xff08;JVM&#xff09;是 Java 程序的运行核心&#xff0c;它主要由类加载系统、运行时数据区、执行引擎和本地方法接口这几个关键部分组成。 类加载系统&#xff08;Class Loading System&#xff09; 类加载系统负责在程序运行时动态地将 Java 类加载到 J…

数据库面试题(基础常考!!!)

在数据库领域&#xff0c;无论是日常开发还是面试场景&#xff0c;都有一些高频且重要的问题需要我们深入理解和掌握。本文将对这些常见面试题进行详细阐述&#xff0c;帮助大家更好地应对面试和实际工作中的挑战。 面试题一&#xff1a;三范式详解 什么是三范式 三范式是关…

Linux网络 网络层

IP 协议 协议头格式 4 位版本号(version): 指定 IP 协议的版本, 对于 IPv4 来说, 就是 4. 4 位头部长度(header length): IP 头部的长度是多少个 32bit, 也就是 4 字节&#xff0c;4bit 表示最大的数字是 15, 因此 IP 头部最大长度是 60 字节. 8 位服务类型(Type Of Service):…

uniapp 微信小程序打包之后vendor.js 主包体积太大,解决办法,“subPackages“:true设置不生效

现在是打包的时候&#xff0c;vendor.js 的内容全部打到了主包里面&#xff0c; 说一下我的方法&#xff1a; 1. 通过发行 小程序打包 这样打包的体积是最小的&#xff0c;打包之后打开微信开发工具&#xff0c;然后再上传 2.manifest.json,在“mp-weixin”里添加代码 "…

python-leetcode-N 皇后

51. N 皇后 - 力扣&#xff08;LeetCode&#xff09; class Solution:def solveNQueens(self, n: int) -> List[List[str]]:res []board [[.] * n for _ in range(n)]def is_safe(row, col):for i in range(row):if board[i][col] Q:return Falseif col - (row - i) >…

【蓝桥杯单片机】客观题

一、第十三届省赛&#xff08;一&#xff09; 二、第十三届省赛&#xff08;二&#xff09;

如何进行ERP系统的定制开发?

在当今数字化时代&#xff0c;企业资源规划&#xff08;ERP&#xff09;系统已然成为企业提升管理效能、优化资源配置以及实现精细化管理的关键工具。然而&#xff0c;鉴于不同企业在行业特性、业务流程以及管理需求等方面存在显著差异&#xff0c;通用型的ERP系统往往难以契合…

基于SpringBoot的校园消费点评管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

MySQL数据库——常见慢查询优化方式

大家好&#xff0c;这里是编程Cookbook。本文详细介绍MySQL的慢查询相关概念&#xff0c;分析步骤及其优化方案等。 文章目录 什么是慢查询日志&#xff1f;慢查询日志的相关参数如何启用慢查询日志&#xff1f;方式一&#xff1a;修改配置文件方式二&#xff1a;通过命令动态启…

【前端基础篇】Day 1

总结&#xff1a; 1. Web标准的构成 2. 基本标签 目录 1. Web标准的构成 2. 基本标签 2.1快捷键 2.2.1标题标签 2.2.2段落和换行标签 2.2.3文本格式化标签 2.2.4div和span标签 2.3.1 图像标签和路径 2.3.2路径 2.3.3超链接标签 2.4注释标签 2.5特殊字符 1. Web标准…

【复习】Redis

数据结构 Redis常见的数据结构 String&#xff1a;缓存对象Hash&#xff1a;缓存对象、购物车List&#xff1a;消息队列Set&#xff1a;点赞、共同关注ZSet&#xff1a;排序 Zset底层&#xff1f; Zset底层的数据结构是由压缩链表或跳表实现的 如果有序集合的元素 < 12…