ChatGLM-RLHF(六)-PPO(Proximal Policy Optimization)原理实现代码逐行注释-论文导读

 一,前言

从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO论文。


上一章介绍了论文的核心点,那我们对照原文,看看大神们是怎么写的

摘要

首先对比强化学习几种不同的方法,deep Q-learning、policy gradient methods和natural policy gradient methods。

1,列举之前方法存在的问题,在开发可扩展(适用于大型模型和并行实现)、数据效率高、稳健性强(即在不进行超参数调整的情况下在各种问题上成功)的方法方面还有改进的空间。deep Q-learning(使用函数逼近)在许多简单问题上失败 [1],且理解不足;policy gradient 的数据效率和稳健性差;而信任区域策略优化(TRPO)相对复杂,并且不兼容包含噪声(例如 dropout)或参数共享(在策略和价值函数之间或与辅助任务共享参数)的架构。

2,PPO算法的优点,该算法实现了 TRPO 的数据效率和可靠性,同时仅使用一阶优化。PPO提出了一种具有剪切概率比率的新目标,它形成了策略性能的一种下限估计。为了优化策略,PPO在从策略中采样数据和对采样数据进行多次优化之间交替进行。

3,PPO的实验比较了代理损失函数的各种不同版本的性能,并发现具有剪切概率比率的版本表现最佳。PPO还将 PPO 与文献中的几个以前的算法进行了比较。在连续控制任务中,PPO的表现优于进行比较的算法。在 Atari 上,它的表现显著优于 A2C(在样本复杂度方面),并与 ACER 类似,但它要简单得多。结果显示PPO优于其他在线策略梯度方法,并在样本复杂度、简单性和强时之间取得了有利的平衡。

背景:

一些传统的做法

1,Policy Gradient Methods策略梯度方法 策略梯度方法通过计算策略梯度的估计量并将其插入到随机梯度上升算法中来工作。最常用的梯度估计器具有以下形式: gˆ = Eˆ t h ∇θ log πθ(at | st)Aˆ t i (1) 其中πθ是一种随机策略,Aˆ t 是时间步t处优势函数的估计量。这里,期望Eˆ t [...]表示在一个交替采样和优化的算法中,对有限批量样本的经验平均值。使用自动微分软件的实现是通过构造一个目标函数,其梯度是策略梯度估计器,估计器ˆg是通过对目标函数求导得到的。 L P G(θ) = Eˆ t h log πθ(at | st)Aˆ t i . (2) 虽然在相同的轨迹上执行多个优化步骤很有吸引力,但这样做并没有得到很好的解释,并且在经验上通常会导致破坏性的大型策略更新。

2,Trust Region Methods

信任域方法

在 TRPO [Sch+15b] 中,最大化一个目标函数("代理"目标函数),同时限制策略更新的大小。具体来说,

$ \max_{\theta} \hat{E}_t \left[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} A_t \right] $

$ \text{s.t. } \hat{E}_t \left[ \text{KL}[\pi_{\theta_{\text{old}}}(·|s_t), \pi_{\theta}(·|s_t)] \right] \leq \delta $

其中,$\theta_{\text{old}}$是更新前的策略参数向量。这个问题可以通过对目标函数进行线性近似和对约束进行二次近似,使用共轭梯度算法有效地近似解决。

支持 TRPO 的理论实际上建议使用惩罚而不是约束,即解决无约束优化问题:

$ \max_{\theta} \hat{E}_t \left[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} A_t - \beta \text{KL}[\pi_{\theta_{\text{old}}}(·|s_t), \pi_{\theta}(·|s_t)] \right] $

其中,$\beta$是某个系数。这是因为某个代理目标函数(在状态上计算最大 KL,而不是计算平均值)形成了策略$\pi$性能的下限(即悲观估计)。TRPO 使用硬约束而不是惩罚,因为很难选择一个能够在不同问题上或甚至在单个问题上(在学习过程中特性会改变)表现良好的$\beta$值。因此,为了实现我们的目标,即使用一阶算法模拟 TRPO 的单调改进,实验表明,仅仅选择一个固定的惩罚系数 $\beta$ 并使用 SGD 优化带惩罚的目标函数方程是不够的,需要进行额外的修改。

3,Clipped Surrogate Objective 剪切代理目标函数

r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)},其中 $r(\theta_{\text{old}})=1$。TRPO 最大化一个"代理"目标函数

L_{\text{CPI}}(\theta) = \hat{E}_t \left[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} A_t \right] = \hat{E}_t [r_t(\theta)A_t]

其中,L_{\text{CPI}}(\theta)上标指的是保守策略迭代。如果没有约束,最大化 $L_{\text{CPI}}$将导致策略更新过大;

方法:

1,PPO提出的主要目标函数如下:

L_{\text{CLIP}}(\theta) = \hat{E}_t \left[ \min \left( r_t(\theta)A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t \right) \right]

其中,$\epsilon$是一个超参数,例如 $\epsilon=0.2$。这个目标函数的动机如下。min 函数内的第一项是 $L_{\text{CPI}}$,第二项 \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t是通过剪切概率比率来修改代理目标函数,从而消除了将 $r_t$移出区间$[1-\epsilon, 1+\epsilon]$的动机。最后,我们取剪切和未剪切目标函数的最小值,因此最终目标函数是未剪切目标函数的下限(即悲观估计)。使用这种方案,只有当改变概率比率会使目标函数变好时,我们才会忽略概率比率的变化,并在概率比率变化使目标函数变差时将其纳入考虑。注意,$L_{\text{CLIP}}(\theta)$$\theta_{\text{old}}$ 处一阶近似地等于$L_{\text{CPI}}(\theta)$,但是当 $\theta$离开 $\theta_{\text{old}}$ 时,它们变得不同。

2,自适应 KL 惩罚系数

另一种方法是使用 KL 散度惩罚,并调整惩罚系数,以便在每次策略更新时实现一些目标 KL 散度 $d_{\text{target}}$。在我们的实验中,PPO发现 KL 惩罚的表现不如剪切代理目标函数,但是PPO在这里包含它是因为它是一个重要的基准。

在这种算法的最简单的实现中,PPO在每次策略更新中执行以下步骤:

• 使用多个 epoch 的小批量 SGD,优化 KL 惩罚目标函数

$ L_{\text{KLPEN}}(\theta) = \hat{E}_t \left[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} A_t - \beta \text{KL}[\pi_{\theta_{\text{old}}}(·|s_t), \pi_{\theta}(·|s_t)] \right] $

• 计算 $d=\hat{E}_t[\text{KL}[\pi_{\theta_{\text{old}}}(·|s_t), \pi_{\theta}(·|s_t)]]$

- 如果$d<d_{\text{target}}/1.5$,则 $\beta \leftarrow \beta/2$
- 如果 $d>d_{\text{target}}\times 1.5$,则$\beta \leftarrow \beta\times 2$

更新后的$\beta$用于下一次策略更新。使用这种方案,我们偶尔会看到 KL 散度与$d_{\text{target}}$显著不同的策略更新,但这种情况很少见,$\beta$ 很快就会调整。上面的参数 1.5 和 2 是启发式选择的,但算法对它们不太敏感。$\beta$ 的初始值是另一个超参数,但在实践中并不重要,因为算法会快速调整它。

3,Adaptive KL Penalty Coefficient 自适应KL惩罚系数

前面几节中的代理损失可以通过对典型策略梯度实现进行微小更改来计算和求导。对于使用自动微分的实现,我们只需构建损失函数$L_{\text{CLIP}}$$L_{\text{KLPEN}}$,而不是$L_{\text{PG}}$,并在该目标函数上执行多步随机梯度上升。

大多数计算方差减少优势函数估计器的技术都利用了一个学习的状态值函数 $V(s)$,例如,广义优势估计 [Sch+15a] 或 [Mni+16] 中的有限时间段估计器。如果使用共享策略和价值函数参数的神经网络架构,则必须使用将策略代理和价值函数误差项组合的损失函数。可以通过添加熵奖励来增强此目标函数,以确保足够的探索,这是过去的工作所建议的 [Wil92; Mni+16]。

将这些项组合,我们得到以下目标函数,每次迭代(近似)最大化:

$ L_{\text{CLIP}+\text{VF}+S}(\theta) = \hat{E}_t [ L_{\text{CLIP},t}(\theta) - c_1 L_{\text{VF},t}(\theta) + c_2 S[\pi_\theta](s_t)] $

其中,$c_1$$c_2$ 是系数,$S$ 表示熵奖励,$L_{\text{VF},t}$是一个平方误差损失 $(V_{\theta}(s_t)-V_{\text{targ},t})^2$

好,上面就得到了PPO的最后迭代公式

结论:

PPO介绍了近端策略优化,这是一组策略优化方法,使用多个随机梯度上升的 epochs 来执行每个策略更新。这些方法具有信任区域方法的稳定性和可靠性,但实现起来要简单得多,只需要对基本的策略梯度实现进行几行代码更改,并适用于更一般的设置(例如,在策略和价值函数的联合架构下使用),并具有更好的总体性能。

 代码:

GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/64088.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IDEA常用插件介绍

1.CodeGlance&#xff08;CodeGlance Pro&#xff09; 安装后&#xff0c;重新启动编译器即可。 CodeGlance是一款非常好用的代码地图插件&#xff0c;可以在代码编辑区的右侧生成一个竖向可拖动的代码缩略区&#xff0c;可以快速定位代码的同时&#xff0c;并且提供放大镜功能…

【UE4】多人联机教程(重点笔记)

效果 1. 创建房间、搜索房间功能 2. 根据指定IP和端口加入游戏 步骤 1. 新建一个第三人称角色模板工程 2. 创建一个空白关卡&#xff0c;这里命名为“InitMap” 3. 新建一个控件蓝图&#xff0c;这里命名为“UMG_ConnectMenu” 在关卡蓝图中显示该控件蓝图 打开“UMG_Connec…

如何在页面中嵌入音频和视频?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 嵌入音频⭐ 嵌入视频⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏…

Redis 7.X Linux 环境安装

Redis 简介 作为一名开发人员&#xff0c;想必大家对Redis一定是耳熟能详&#xff0c;因此在此只做简单介绍。 Remote Dictionary Server(远程字典服务)是完全开源的&#xff0c;使用ANSIC语言编写遵守BSD协议&#xff0c;是一个高性能的Key-Value内存数据库&#xff0c;它提…

布隆过滤器,Guava实现布隆过滤器(本地内存),Redis实现布隆过滤器(分布式)

一、前言 利用布隆过滤器可以快速地解决项目中一些比较棘手的问题。如网页 URL 去重、垃圾邮件识别、大集合中重复元素的判断和缓存穿透等问题。不知道从什么时候开始&#xff0c;本来默默无闻的布隆过滤器一下子名声大噪&#xff0c;在面试中面试官问到怎么避免缓存穿透&#…

Vue3 实现产品图片放大器

Vue3 实现类似淘宝、京东产品详情图片放大器功能 环境&#xff1a;vue3tsvite 1.创建picShow.vue组件 <script lang"ts" setup> import {ref, computed} from vue import {useMouseInElement} from vueuse/core/*获取父组件的传值*/ defineProps<{images:…

机器学习实战13-超导体材料的临界温度预测与分析(决策树回归,梯度提升回归,随机森林回归和Bagging回归)

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下机器学习实战13-超导体材料的临界温度预测与分析(决策树回归,梯度提升回归,随机森林回归和Bagging回归)&#xff0c;这几天引爆网络的科技大新闻就是韩国科研团队宣称发现了室温超导材料-LK-99&#xff0c;这种材料…

pytorch中torch.einsum函数的详细计算过程图解

第一次见到 rel_h torch.einsum(“bhwc,hkc->bhwk”, r_q, Rh)这行代码时&#xff0c;属实是懵了&#xff0c;网上找了很多博主的介绍&#xff0c;但都没有详细的说明函数内部的计算过程&#xff0c;看得我是一头雾水&#xff0c;只知道计算结果的维度是如何变化的&#xf…

【积水成渊】CSS磨砂玻璃效果和渐变主题色文字

大家好&#xff0c;我是csdn的博主&#xff1a;lqj_本人 lqj_本人_python人工智能视觉&#xff08;opencv&#xff09;从入门到实战,前端,微信小程序-CSDN博客 最新的uniapp毕业设计专栏也放在下方了&#xff1a; https://blog.csdn.net/lbcyllqj/category_12346639.html?spm1…

Element-UI简介

目录 安装 常用组件 Container 布局容器 Button 按钮 MessageBox 弹框 Form 表单验证 element-ui是一个前端的ui框架&#xff0c;封装了很多已经写好的ui组件&#xff0c;例如表单组件&#xff0c;布局组件&#xff0c;表格组件.......是一套桌面端组件。 Element - 网站…

【Winform学习笔记(七)】Winform无边框窗体拖动功能

Winform无边框窗体拖动功能 前言正文1、设置无边框模式2、无边框窗体拖动方法1、通过Panel控件实现窗体移动2、通过窗体事件实现窗体移动3、调用系统API实现窗体移动4、重写WndProc()实现窗体移动 前言 在本文中主要介绍 如何将窗体设置成无边框模式、以及实现无边框窗体拖动功…

【设计模式——学习笔记】23种设计模式——迭代器模式Iterator(原理讲解+应用场景介绍+案例介绍+Java代码实现)

文章目录 案例引入介绍基础介绍应用场景登场角色 案例实现案例一实现 案例二实现 迭代器模式在JDK源码中的应用总结文章说明 案例引入 编写程序展示一个学校院系结构: 需求是这样&#xff0c;要在一个页面中展示出学校的院系组成&#xff0c;一个学校有多个学院&#xff0c;一…

运放电路笔记3-加/减法运算电路

一、反相加法运算电路 反相加法运算电路如下&#xff1a; 根据电路图可知道&#xff1a; V- V 0V 设 Vi1 V18 Vi2 V20 求得输出电压Vo的值如下&#xff1a; ( Vo - (V-) )/R26 ((V-) - Vi1)/R27 ((V-) - Vi2)/R30 Vo - (V-) ((V-) - Vi1)*R26/R27 ((V-) - Vi2)*R26/R3…

无涯教程-Lua - nested语句函数

Lua编程语言允许在另一个循环中使用一个循环。以下部分显示了一些示例来说明这一概念。 nested loops - 语法 Lua中嵌套for循环语句的语法如下- for init,max/min value, increment dofor init,max/min value, incrementdostatement(s)endstatement(s) end Lua编程语言中的…

前沿分享-可降解体内微型机器人

大概是这样的&#xff0c;通过外部磁场的应用&#xff0c;微型机器人可以在微流体通道内进行远程控制&#xff0c;便于快速准确地运送到目标点。 在研究中&#xff0c;该团队通过将具有高生物相容性和超顺磁性的氧化铁纳米颗粒内化到从人鼻甲骨中提取的干细胞中&#xff0c;开发…

Maven介绍-下载-安装-使用-基础知识

Maven介绍-下载-安装-使用-基础知识 Maven的进阶高级用法可查看这篇文章&#xff1a; Maven分模块-继承-聚合-私服的高级用法 文章目录 Maven介绍-下载-安装-使用-基础知识01. Maven1.1 初识Maven1.1.1 什么是Maven1.1.2 Maven的作用 02. Maven概述2.1 Maven介绍2.2 Maven模型…

自动化处理,web自动化测试处理多窗口+切换iframe框架页总结(超细整理)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 web 自动化之处理…

STM32——LED内容补充(寄存器点灯及反转的原理)

文章目录 点灯流程开时钟配置IO关灯操作灯反转宏定义最后给自己说 本篇文章使用的是STM32F103xC系列的芯片&#xff0c;四个led灯在PE2,PE3,PE4,PE5上连接 点灯流程 1.开时钟 2.配置IO口 &#xff08;1&#xff09;清零指定寄存器位 &#xff08;2&#xff09;设置模式为推挽输…

pl/sql函数如何返回多行数据

用游标即可&#xff1a; SQL code ? 1 2 3 4 5 6 7 8 9 10 11 12 Create or REPLACE FUNCTION getCursorList( P_USER_ID_I IN VARCHAR2 --接收输入参数 ) RETURN SYS_REFCURSOR AS P_RESULT_SET_O SYS_REFCURSOR…

大数据-玩转数据-Flink-Transform(上)

一、Transform 转换算子可以把一个或多个DataStream转成一个新的DataStream.程序可以把多个复杂的转换组合成复杂的数据流拓扑. 二、基本转换算子 2.1、map&#xff08;映射&#xff09; 将数据流中的数据进行转换, 形成新的数据流&#xff0c;消费一个元素并产出一个元素…