Text-to-SQL小白入门(九)InstructGPT论文:教你如何训练ChatGPT

论文概述

InstructGPT和ChatGPT 的训练流程基本一致 ,ChatGPT是改进后的InstructGPT,比如InstructGPT是基于GPT-3训练,而ChatGPT是基于GPT-3.5训练。

基本信息

  • 英文标题:Training language models to follow instructions with human feedback
  • 中文标题:通过人类反馈的指令训练语言模型
  • 发表时间:2023年3月 arxiv
  • 作者单位:Open AI
  • 论文链接:https://arxiv.org/pdf/2203.02155.pdf
  • 代码链接:GitHub - openai/following-instructions-human-feedback

学习InstructGPT论文之前,想了解了基本的LLM或者RLHF流程,可以看看组织「eosphoros-ai」(今年的8000+star的开源项目DB-GPT的开源社区)提出的LLM+Text2SQL汇总项目:https://github.com/eosphoros-ai/Awesome-Text2SQL,里面也收集了一些微调SFT(lora, qlora, p-tuning等),RLHF相关的论文(比如RLHF,RRHF,RLTF, RRTF, RLAIF等等),目前也有300+的star,持续更新中,欢迎围观使用star!

摘要

背景

使语言模型更大并不能使它们更好地遵循用户的意图。例如,大型语言模型可能生成不真实的(untruthful)有害的(toxic)或对用户没有帮助(not helpful)的输出。

贡献/方法

在本文中,作者展示了一种方法,通过使用人类反馈进行微调,在广泛的任务中使语言模型与用户意图保持一致。

  • 先使用有监督微调SFT
  • 然后收集一批rank排序的模型输出
  • 再使用人类反馈的强化学习rlhf微调
  • 最终得到的模型叫做InstructGPT

结果:参数量小了100倍,性能差不多。 真实性⬆️、有毒⬇️、精度⬇️(轻微)

结果惊艳:

  • 1.3b参数的InstructGPT的模型输出和175b GPT-3的输出很类似。
  • 在公共NLP数据集上,InstructGPT模型显示出真实性的改进和有毒输出生成的减少,同时性能下降最小

结论:

尽管InstructGPT仍然会犯一些简单的错误,但结果表明,根据人类反馈进行微调是使语言模型与人类意图保持一致的一个有希望的方向

结果

API prompt distribution

  • 参数说明:
    • 横坐标是模型参数大小,纵坐标是和175B GPT SFT比较赢的概率(比如绿色的线条,横坐标为175B时候,赢的概率刚好为0.5,此时就是175B GPT SFT vs 175B GPT SFT )
    • GPT就是最普通的模型
    • GPT(prompted)就是给几个例子few-shot
    • SFT 有监督微调
    • PPO 用强化学习
    • PPO-ptx: 在PPO算法期间,使用pretraining mix (但是几乎没有什么效果)
  • 对比的模型是SFT 175B,可以发现的是1.3B PPO或者PPO-ptx已经超过0.5的概率赢175B,说明方法很有效。
  • InstructGPT就是PPO-ptx

论文还在 public NLP dataset进行了实验,InstructGPT模型在公有NLP数据集上有“对齐税”导致性能下降,可能是因为API prompt 训练的原因。

论文还公布了qualitative results,InstructGPT模型泛化能力很强,具体实验参考原论文。

结论

对齐研究alignment research的影响

  • 提高模型对齐度的成本比预训练低。
  • InstructGPT泛化能力强,可以推广到没有监督数据的领域。
  • 通过微调,可以减少性能下降
  • 验证了对齐技术在现实生活中应用

对齐的是什么?

人类偏好,人类价值观 --> 标注者的偏好、OpenAI 研究人员的偏好、API 用户的偏好。

核心方法

RLHF架构图

基础背景知识

  • RLHF方法最早是2017年提出:Deep reinforcement learning from human preferences(2017)
  • 在2020年RLHF文章「Learning to summarize from human feedback(2020」中,RM训练使用了两个模型在相同input情况下的output进行比较,使用交叉熵损失。——InstructGPT使用KL散度
  • PPO算法,也是Open AI 2017年提出的:Proximal policy optimization algorithms(2017),这篇文章的作者「John Schulman」也在InstructGPT作者名单中。

这个图也是经典大图了,RLHF实践参考的范式,RLHF主要分成了3个阶段:

  • 第一阶段:SFT
  • 第二阶段:RM
  • 第三阶段:RL (使用PPO算法:proximal policy optimization 最近策略优化),对第三阶段进行一个简单解释:
    • 输入一个标注数据,模型经过PPO算法输出一个response
    • RM模型对response打分
    • 根据打分score更新PPO策略。

PPO算法具体是什么呢?——(留个坑,后续补上)

详情参考论文:Schulman, J., Wolski, F., Dhariwal, P., Radford, A., and Klimov, O. (2017). Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347.

SFT

数据格式

  • prompt - output

更直观一点,以一个具体的小任务比如Text2SQL为例子,构造的数据集如下所示:

来源知乎文档:Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age  >  56"}

实验参数

参数如下:

  • base model——GPT-3
  • epoch——16
  • lr decay——cosine
  • dropout——0.2

选择最终的SFT模型时,是根据验证集上的RM分数。

惊讶点:

  • 1个epoch后已经过拟合了,但是为了后续的RM分数,还是多跑几轮epoch

RM

数据格式

  • prompy-chosen-rejected

同样的,以Text2SQL任务为例子,构造的数据集如下所示:

{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","chosen": "SELECT count(*) FROM head WHERE age  >  56","rejected":"SELECT COUNT(head_name) FROM head WHERE age > 56;"}

实验参数

  • base model: 是GPT-3 SFT之后的模型,但是去掉了最后一层
    • 因为原始模型输入是prompt,输出是response
    • 现在需要模型输入是prompt + response,输出是score
  • 参数量仅选择的6B大小

为什么RM模型选6B,不是175B?

    • 6B 减少计算量
    • 175B 训练不稳定
  • 标注者,需要对K=4 和 K=9之间的response进行排序,会产生C(k, 2)个两两比较pair
  • 一个epoch中,对所有的C(k, 2)比较对训练,一次传播loss

损失函数:

  • x代表输入的prompt;y_w代表chosen_data; y_l代表rejected_data; D代表实验数据集
  • r_θ(x,y)代表RM模型输入prompt x和response y的输出得分

最后要对奖励归一化,使得平均奖励为0。

RL

数据格式

  • prompt-output

和SFT阶段数据格式一致。

{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age  >  56"}

实验参数

1.RM可以和RL重复多轮迭代——这样构建更多数据,越来越趋近于人类偏好。

  • SFT训练->训练一个RM->训练一个RL->不断重复下面的步骤:
    • 构建RM数据->重新训练一个RM->重新训练一个RL->
    • 构建RM数据->重新训练一个RM->重新训练一个RL->
    • 构建RM数据->重新训练一个RM->重新训练一个RL->

2.实践中,大部分的比较数据来源于SFT的数据,少部分数据来源于RL模型的比较数据。

  • 继2020文章「Learning to summarize from human feedback」之后,作者再次使用PPO对环境中的SFT模型进行了微调。
  • 额外增加了 KL散度。
  • 额外增加了预训练梯度——目的是为了减少在NLP数据集上性能倒退,所以InstructGPT模型 == PPO-ptx

  • π^RL代表学习到的强化学习RL模型; π^SFT代表SFT阶段训练的模型。

为什么用π表示?为什么用除法表示?这就是强化学习的基本概念

从状态State到动作Action的过程就称之为一个策略Policy,一般用π表示(可以理解为一个函数表示),也就是在强化学习阶段需要找到一个关系:a=π(s) 或者是 π(a|s), a 就是action, s就是state

  • D_pretrain代表预训练阶段的数据分布;D_π^RL代表强化学习阶段的数据分布
  • r_θ(x,y)代表RM模型输入prompt x和response y的输出得分
  • β是控制KL奖励的系数; γ是控制预训练梯度的系数,如果是普通的PPO,那么γ=0

数据收集

之前听一个大学教授的讲座,有个观点很有意思:Open AI做大模型为什么比谷歌强,因为包括transformer在内的一些创新模型大多是谷歌研究的,那为什么Open AI在大模型领域为什么比谷歌强?答:因为Open AI在数据清洗,数据质量把控这方面做的很好。——所以数据是相当重要的!

API数据

为了训练本文的最终InstructGPT

prompt dataset 主要由OpenAI 的API获得,用户和API交互,把这些数据收集起来(前提是用户使用的时候就告知数据要被收集),此时的API是早期的InstructGPT模型,并且没有使用用户在生产中使用API的数据。

API数据分布如下,主要有9类。

那么问题来了?早期的InstructGPT模型的训练数据怎么来?

  • 通过人工标注的有监督学习训练得到的

对API收集的数据做了一些处理:

  • 去除重复的提示:通过检查公共前缀(感觉回到了leetcode刷题,求两个字符串的最长公共前缀)
  • 每个用户不超过200条prompt:应该是避免单独个体的偏好
  • 基于用户id,划分train,val,test——这样验证集和测试集就不包含来自训练集中的用户的数据
    • 比如训练数据用id 1, 2, 3, 4的所有数据
    • 测试的数据用id 5的数据。
  • 过滤掉了个人身份信息的数据

人工标注数据

主要是为了训练早期的InstructGPT

标注者被要求手写以下三种类型的prompt:

  • plain:标记人员提出任意的简单任务,同时保证任务的多样性
  • few-shot:标注人员提出一条指令instruction,以及该指令的多个查询/响应对(query/response)
  • user-based:标注人员在OpenAI 提供的API中获取用例,标注人员需要给出这些用例相对应的instruction

数据量级

数据中96%以上是英文,其它20个语种例如中文,法语,西班牙语等加起来不到4%,这可能导致InstructGPT/ChatGPT能进行其它语种的生成时,效果应该远不如英文

  • SFT 数据,大概13k
  • RM 数据,大概33k
  • PPO数据,大概31k

论文还有大量的附录数据详情,可以参考论文原文,比如标注人员分布,数据示例,数据标注等等,不得不说,Open AI数据扎实,正文20页,附录48页,总共68页。

其他文章

Text-to-SQL小白入门(一)综述文章学习

Text-to-SQL小白入门(二)Transformer学习

Text-to-SQL小白入门(三)IRNet:引入中间表示SemQL

Text-to-SQL小白入门(四)指令进化大模型WizardLM

Text-to-SQL小白入门(五)开源代码大模型Code Llama

Text-to-SQL小白入门(六)Awesome-Text2SQL项目介绍

Text-to-SQL小白入门(七)PanGu-Coder2论文——RRTF

Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/197056.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

香港科技大学广州|智能制造学域博士招生宣讲会—华中科技大学专场

时间:2023年12月08日(星期五)15:00 地点:华中科技大学大学生活动中心A座603 报名链接:https://www.wjx.top/vm/mmukLPC.aspx# 宣讲嘉宾: 胡鹏程 副教授 https://facultyprofiles.hkust-gz.edu.cn/faculty-…

RabbitMQ消息队列

简介 MQ(message queue),从字面意思上看就个 FIFO 先入先出的队列,只不过队列中存放的内容是 message 而已,它是一种具有接收数据、存储数据、发送数据等功能的技术服务。 作用:流量削峰、应用解耦、异步处理。 生产者将消息发送…

黑马点评-Feed流的实现方案,基于推拉结合模式实现笔记推送

Feed流实现方案 我们关注了博主之后,当用户发布了动态后我们应该把这些数据推送给粉丝,关注推送也叫作Feed(投喂)流,通过无限下拉刷新获取新的信息 传统的模式内容检索: 粉丝需要主动通过搜索引擎或者是其他方式去查找想看的内容新型Feed流的效果: 系统分析用户到底想看什么,…

okhttp系列-拦截器的执行顺序

1.将拦截器添加到ArrayList final class RealCall implements Call {Response getResponseWithInterceptorChain() throws IOException {//将Interceptor添加到ArrayListList<Interceptor> interceptors new ArrayList<>();interceptors.addAll(client.intercept…

注意力机制(Attention Mechanism)

目录 1. 简介&#xff1a;探索注意力机制的世界 2. 历史背景 3. 核心原理 4. 应用案例 5. 技术挑战与未来趋势 6. 图表和示例 7. Conclusion 1. 简介&#xff1a;探索注意力机制的世界 在当今的人工智能&#xff08;AI&#xff09;和机器学习&#xff08;ML&#xff09;…

戴尔科技推出全新96核Precision 7875塔式工作站

工作站行业一直是快节奏且充满惊喜的。在过去25年中,戴尔Precision一直处于行业前沿,帮助创作者、工程师、建筑师、研究人员等将想法变为现实,并对整个世界产生影响。工作站所发挥的作用至关重要,被视为化不可能为可能的必要工具。如今,人工智能(AI)和生成式AI(GenAI)的浪潮正在…

npm管理发布包-创建与发布

创建与发布 我们可以将自己开发的工具包发布到 npm 服务上&#xff0c;方便自己和其他开发者使用&#xff0c;操作步骤如下 创建文件夹&#xff0c;并创建文件indexjs&#xff0c;在文件中声明函数&#xff0c;使用 module.exports 暴露npm初始化工具包&#xff0c;package.j…

浅谈硬件连通性测试几大优势

硬件连通性测试是确保硬件系统正常运行、提高系统可靠性和降低生产成本的关键步骤。在现代工程和制造中&#xff0c;将连通性测试纳入生产流程是一个明智的选择&#xff0c;有助于确保硬件产品的质量和性能达到最优水平。本文将介绍硬件连通性测试的主要优势有哪些! 一、提高系…

Java基础之集合类

Java基础之集合类 一、集合的框架1.1、集合概述1.2、集合与数组区别1.3、数组的缺点&#xff1a;1.4、常用集合分类1.5、Collection常用方法 二、List集合2.1、ArrayList2.2、LinkedList2.3、Vector2.4、区别 三、Set集合3.1、HashSet集合3.2、LinkedHashSet集合3.3、TreeSet集…

Unity 接入TapADN播放广告时闪退 LZ4JavaSafeCompressor

通过跟踪安卓日志&#xff0c;发现报如下错误 Didnt find class "com.tapadn.lz4.LZ4JavaSafeCompressor" 解决方案&#xff1a; 去掉Minify这边的勾选&#xff0c;再打包即可。

国内高速下载huggingface上的模型

前提 Python版本至少是3.8 安装 安装hugging face官方提供的下载工具 pip install -U huggingface_hub hf-transfer Windows设置环境变量 在当前窗口设置临时环境变量&#xff08;cmd.exe&#xff09; set HF_HUB_ENABLE_HF_TRANSFER 1 你也可以设置永久的环境变量&am…

MySQL基础进阶篇

进阶篇 存储引擎 MySQL体系结构&#xff1a; 存储引擎就是存储数据、建立索引、更新/查询数据等技术的实现方式。存储引擎是基于表而不是基于库的&#xff0c;所以存储引擎也可以被称为表引擎。 默认存储引擎是InnoDB。 相关操作&#xff1a; -- 查询建表语句 show create …

uniapp 导航分类

商品分类数据&#xff0c;包括分类名称和对应的商品列表点击弹出 列表的内容 展示效果如下&#xff1a; 代码展示 ①div部分 <view class"container"><view class"menu-bar"><view class"menu"><view class"menu-sc…

差异性分析方法汇总与pk

在数据研究中&#xff0c;常见的数据关系可以分为四类&#xff0c;分析是相关关系&#xff0c;因果关系、差异关系以及其它。本次所进行研究的关系为差异关系。对于差异性分析方法常见可以分为三类&#xff1a;参数检验、非参数检验以及可视化图形。 一、参数检验 1、参数检验…

Flask Session 登录认证模块

Flask 框架提供了强大的 Session 模块组件&#xff0c;为 Web 应用实现用户注册与登录系统提供了方便的机制。结合 Flask-WTF 表单组件&#xff0c;我们能够轻松地设计出用户友好且具备美观界面的注册和登录页面&#xff0c;使这一功能能够直接应用到我们的项目中。本文将深入探…

Redis(二):常见数据类型:String 和 哈希

引言 Redis 提供了 5 种数据结构&#xff0c;理解每种数据结构的特点对于 Redis 开发运维⾮常重要&#xff0c;同时掌握每 种数据结构的常⻅命令&#xff0c;会在使⽤ Redis 的时候做到游刃有余。 Redis 的命令有上百种&#xff0c;我们不可能全部死记硬背下来&#xff0c;但是…

linaro交叉编译工具链下载与使用笔记

笔记 文章目录 笔记确定目标 &#xff08;aarch64&#xff09;选择版本&#xff08;7.5&#xff09;选择目标&#xff08;aarch64-linux-gnu&#xff09;下载地址工具链&#xff08;gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz&#xff09;编译测试 &#xff08…

Selenium+Python做web端自动化测试框架与实例详解教程

最近受到万点暴击&#xff0c;由于公司业务出现问题&#xff0c;工作任务没那么繁重&#xff0c;有时间摸索seleniumpython自动化测试&#xff0c;结合网上查到的资料自己编写出适合web自动化测试的框架&#xff0c;由于本人也是刚刚开始学习python&#xff0c;这套自动化框架目…

python爬虫实习找工作练习测试(以下内容仅供参考学习)

要求&#xff1a;获取下图指定网站的指定数据 空气质量状况报告-中国环境监测总站 输入&#xff1a;用户输入下载时间范围&#xff0c;格式为2022-10 输出&#xff1a;将更新时间在2022年10月1日到31日之间的文件下载到本地目录&#xff08;可配置&#xff09;&#xff0c;并…

WIFI模块(esp-01s)实现天气预报代码实现

目录 前言 实现图片 一、串口编程的实现 二、发送AT指令 esp01s.c esp01s.h 三、数据处理 1、初始化 2、cjson处理函数 3、核心控制代码 四、修改堆栈大小 前言 实现图片 前面讲解了使用AT指令获取天气与cjson的解析数据&#xff0c;本章综合将时间显示到屏幕 一、…