一,前言
从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO论文。
上一章介绍了论文的核心点,那我们对照原文,看看大神们是怎么写的
摘要
首先对比强化学习几种不同的方法,deep Q-learning、policy gradient methods和natural policy gradient methods。
1,列举之前方法存在的问题,在开发可扩展(适用于大型模型和并行实现)、数据效率高、稳健性强(即在不进行超参数调整的情况下在各种问题上成功)的方法方面还有改进的空间。deep Q-learning(使用函数逼近)在许多简单问题上失败 [1],且理解不足;policy gradient 的数据效率和稳健性差;而信任区域策略优化(TRPO)相对复杂,并且不兼容包含噪声(例如 dropout)或参数共享(在策略和价值函数之间或与辅助任务共享参数)的架构。
2,PPO算法的优点,该算法实现了 TRPO 的数据效率和可靠性,同时仅使用一阶优化。PPO提出了一种具有剪切概率比率的新目标,它形成了策略性能的一种下限估计。为了优化策略,PPO在从策略中采样数据和对采样数据进行多次优化之间交替进行。
3,PPO的实验比较了代理损失函数的各种不同版本的性能,并发现具有剪切概率比率的版本表现最佳。PPO还将 PPO 与文献中的几个以前的算法进行了比较。在连续控制任务中,PPO的表现优于进行比较的算法。在 Atari 上,它的表现显著优于 A2C(在样本复杂度方面),并与 ACER 类似,但它要简单得多。结果显示PPO优于其他在线策略梯度方法,并在样本复杂度、简单性和强时之间取得了有利的平衡。
背景:
一些传统的做法
1,Policy Gradient Methods策略梯度方法 策略梯度方法通过计算策略梯度的估计量并将其插入到随机梯度上升算法中来工作。最常用的梯度估计器具有以下形式: gˆ = Eˆ t h ∇θ log πθ(at | st)Aˆ t i (1) 其中πθ是一种随机策略,Aˆ t 是时间步t处优势函数的估计量。这里,期望Eˆ t [...]表示在一个交替采样和优化的算法中,对有限批量样本的经验平均值。使用自动微分软件的实现是通过构造一个目标函数,其梯度是策略梯度估计器,估计器ˆg是通过对目标函数求导得到的。 L P G(θ) = Eˆ t h log πθ(at | st)Aˆ t i . (2) 虽然在相同的轨迹上执行多个优化步骤很有吸引力,但这样做并没有得到很好的解释,并且在经验上通常会导致破坏性的大型策略更新。
2,Trust Region Methods
信任域方法
在 TRPO [Sch+15b] 中,最大化一个目标函数("代理"目标函数),同时限制策略更新的大小。具体来说,
其中,是更新前的策略参数向量。这个问题可以通过对目标函数进行线性近似和对约束进行二次近似,使用共轭梯度算法有效地近似解决。
支持 TRPO 的理论实际上建议使用惩罚而不是约束,即解决无约束优化问题:
其中,是某个系数。这是因为某个代理目标函数(在状态上计算最大 KL,而不是计算平均值)形成了策略性能的下限(即悲观估计)。TRPO 使用硬约束而不是惩罚,因为很难选择一个能够在不同问题上或甚至在单个问题上(在学习过程中特性会改变)表现良好的值。因此,为了实现我们的目标,即使用一阶算法模拟 TRPO 的单调改进,实验表明,仅仅选择一个固定的惩罚系数 并使用 SGD 优化带惩罚的目标函数方程是不够的,需要进行额外的修改。
3,Clipped Surrogate Objective 剪切代理目标函数
设 ,其中 。TRPO 最大化一个"代理"目标函数
其中,上标指的是保守策略迭代。如果没有约束,最大化 将导致策略更新过大;
方法:
1,PPO提出的主要目标函数如下:
其中,是一个超参数,例如 。这个目标函数的动机如下。min 函数内的第一项是 ,第二项 是通过剪切概率比率来修改代理目标函数,从而消除了将 移出区间的动机。最后,我们取剪切和未剪切目标函数的最小值,因此最终目标函数是未剪切目标函数的下限(即悲观估计)。使用这种方案,只有当改变概率比率会使目标函数变好时,我们才会忽略概率比率的变化,并在概率比率变化使目标函数变差时将其纳入考虑。注意,在 处一阶近似地等于,但是当 离开 时,它们变得不同。
2,自适应 KL 惩罚系数
另一种方法是使用 KL 散度惩罚,并调整惩罚系数,以便在每次策略更新时实现一些目标 KL 散度 。在我们的实验中,PPO发现 KL 惩罚的表现不如剪切代理目标函数,但是PPO在这里包含它是因为它是一个重要的基准。
在这种算法的最简单的实现中,PPO在每次策略更新中执行以下步骤:
• 使用多个 epoch 的小批量 SGD,优化 KL 惩罚目标函数
• 计算
- 如果,则
- 如果 ,则
更新后的用于下一次策略更新。使用这种方案,我们偶尔会看到 KL 散度与显著不同的策略更新,但这种情况很少见, 很快就会调整。上面的参数 1.5 和 2 是启发式选择的,但算法对它们不太敏感。 的初始值是另一个超参数,但在实践中并不重要,因为算法会快速调整它。
3,Adaptive KL Penalty Coefficient 自适应KL惩罚系数
前面几节中的代理损失可以通过对典型策略梯度实现进行微小更改来计算和求导。对于使用自动微分的实现,我们只需构建损失函数 或 ,而不是,并在该目标函数上执行多步随机梯度上升。
大多数计算方差减少优势函数估计器的技术都利用了一个学习的状态值函数 ,例如,广义优势估计 [Sch+15a] 或 [Mni+16] 中的有限时间段估计器。如果使用共享策略和价值函数参数的神经网络架构,则必须使用将策略代理和价值函数误差项组合的损失函数。可以通过添加熵奖励来增强此目标函数,以确保足够的探索,这是过去的工作所建议的 [Wil92; Mni+16]。
将这些项组合,我们得到以下目标函数,每次迭代(近似)最大化:
其中,、 是系数, 表示熵奖励,是一个平方误差损失 。
好,上面就得到了PPO的最后迭代公式
结论:
PPO介绍了近端策略优化,这是一组策略优化方法,使用多个随机梯度上升的 epochs 来执行每个策略更新。这些方法具有信任区域方法的稳定性和可靠性,但实现起来要简单得多,只需要对基本的策略梯度实现进行几行代码更改,并适用于更一般的设置(例如,在策略和价值函数的联合架构下使用),并具有更好的总体性能。
代码:
GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成