【强化学习】DPO(Direct Preference Optimization)算法学习笔记

【强化学习】DPO(Direct Preference Optimization)算法学习笔记

  • RLHF与DPO的关系
  • KL散度
  • Bradley-Terry模型
  • DPO算法流程
  • 参考文献

RLHF与DPO的关系

  • DPO(Direct Preference Optimization)和RLHF(Reinforcement Learning from Human Feedback)都是用于训练和优化人工智能模型的方法,特别是在大型语言模型的训练中
  • DPO和RLHF都旨在通过人类的反馈来优化模型的表现,它们都试图让模型学习到更符合人类偏好的行为或输出
  • RLHF通常涉及三个阶段:全监督微调(Supervised Fine-Tuning)、奖励模型(Reward Model)的训练,以及强化学习(Reinforcement Learning)的微调
  • DPO是一种直接优化模型偏好的方法,不需要显式地定义奖励函数,而是通过比较不同模型输出的结果,选择更符合人类偏好的结果作为训练目标,主要是通过直接最小化或最大化目标函数来实现优化,利用偏好直接指导优化过程,而不依赖于强化学习框架
    在这里插入图片描述

KL散度

  • KL散度(Kullback-Leibler divergence),也被称为相对熵,是衡量两个概率分布P和Q差异的一种方法
  • 公式: K L ( P ∣ ∣ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) \mathrm{KL}(P||Q)=\sum_xP(x)\log\left(\frac{P(x)}{Q(x)}\right) KL(P∣∣Q)=xP(x)log(Q(x)P(x))
  • KL散度是不对称的, K L ( P ∣ ∣ Q ) ! = K L ( Q ∣ ∣ P ) KL(P||Q)!=KL(Q||P) KL(P∣∣Q)!=KL(Q∣∣P)

在这里插入图片描述

Bradley-Terry模型

  • Bradley-Terry模型是一种用于比较成对对象并确定相对偏好或能力的方法。这种模型特别适用于对成对比较数据进行分析,从而对一组对象进行排序

  • P ( i > j ) = α i α i + α j P(i{>}j)=\frac{\alpha_i}{\alpha_i{+}\alpha_j} P(i>j)=αi+αjαi

  • α i \alpha_i αi表示第 i i i个元素的能力参数,且大于0。 P ( i > j ) P(i>j) P(i>j)表示第 i i i个元素战胜第 j j j个元素的概率

  • Bradley-Terry模型的参数通常通过最大似然估计(MLE)来确定
    在这里插入图片描述

  • sigmoid函数: σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1

  • loss函数的化简
    L o s s = − E ( x , y w , y l ) ∼ D [ ln ⁡ e x p ( r ( x , y w ) ) e x p ( r ( x , y w ) ) + e x p ( r ( x , y l ) ) ] = − E ( x , y w , y l ) ∼ D [ ln ⁡ 1 1 + e x p ( r ( x , y l ) − r ( x , y w ) ) ] = − E ( x , y w , y l ) ∼ D [ ln ⁡ σ ( r ( x , y w ) − r ( x , y l ) ) ] \begin{aligned}Loss &=-\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln\frac{exp(r(x,y_{w}))}{exp(r(x,y_{w}))+exp(r(x,y_{l}))}] \\ &= -\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln\frac{1}{1 + exp(r(x,y_{l})- r(x,y_{w}))}] \\ &= -\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln \sigma(r(x,y_{w})-r(x,y_{l}))] \end{aligned} Loss=E(x,yw,yl)D[lnexp(r(x,yw))+exp(r(x,yl))exp(r(x,yw))]=E(x,yw,yl)D[ln1+exp(r(x,yl)r(x,yw))1]=E(x,yw,yl)D[lnσ(r(x,yw)r(x,yl))]

  • loss函数的目标是优化LLM输出的 y w y_w yw,经过reward计算的得分尽可能的大于 y w y_w yw经过reward计算的得分

在这里插入图片描述

DPO算法流程

  • DPO通过比较不同输出的偏好,构建一个目标函数,该函数直接反映人类的偏好,通常使用排序损失函数(例如Pairwise Ranking Loss),该函数用来衡量模型在用户偏好上的表现
  • DPO优化过程:使用梯度下降等优化算法,直接最小化或最大化目标函数。通过不断调整模型参数,使得模型生成的输出更加符合用户的偏好
    在这里插入图片描述
  • 基准模型一般指经过SFT有监督微调后的模型
  • DPO的目标是尽可能得到多的奖励,同时使得新训练的 模型尽可能与基准模型分布一致

DPO训练目标的化简

在这里插入图片描述
上图中第一步利用的是KL散度的定义,之所以式子中没有KL散度中的 P ( π ( y ∣ x ) ) P(\pi(y|x)) P(π(yx)),是因为KL散度可以理解成是一个概率比值的log的期望,在这里这个概率以期望的形式放到式子左边的期望中了

  • 求最大值 通过在式中加上负号转化为求最小值,并同时除以 β \beta β
  • DPO原论文中的推导过程

在这里插入图片描述

  • 继续推导

在这里插入图片描述
在这里插入图片描述

  • 求解reward函数的表达式,将reward函数的表达式代入loss函数中

在这里插入图片描述

  • DPO loss损失函数的表达形式

在这里插入图片描述

  • logZ(x)项被抵消,于是可以转而用最大似然估计MLE直接在这个概率模型上直接优化LM,去得到希望的最优的π*
    个人理解的一知半解 有时间还是得去看看原论文

参考文献

  1. DPO (Direct Preference Optimization) 算法讲解
  2. Direct Preference Optimization(DPO)学习笔记
  3. DPO原论文 Direct Preference Optimization: Your Language Model is Secretly a Reward Model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/664738.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

KMPlayer v2024.4.25.13 官方版 (万能播放器)

前言 KMPlaye通过各种插件扩展KMP可以支持层出不穷的新格式。KMPlaye强大的插件功能,直接从Winamp继承的插件功能,能够直接使用Winamp的音频,输入,视觉效果插件,而通过独有的扩展能力,只要你喜欢&#xff…

【linux-imx6ull-设备树点灯】

目录 1. 设备树简介1.1 编译-引用1.2 设备树文件结构1.3 设备树节点介绍1.3.1 特殊节点chosen 1.4 节点内容追加 2. 设备树常用OF操作函数2.1 节点寻找类2.2 属性提取类2.3 其它常用类 4. 设备树下LED实验4.1 实验简介4.2 添加LED设备节点4.3 获取设备节点并提取属性4.3.1 获取…

国内类似ChatGPT的大模型应用有哪些?发展情况如何了

第一部分:几个容易混淆的概念 很多人,包括很多粉丝的科技博主,经常把ChatGPT和预训练大模型混为一谈,因此有必要先做一个澄清。预训练大语言模型属于预训练大模型的一类,而ChatGPT、文心一言又是预训练大语言模型的一个…

【Linux】Linux基本指令3

目录 1.date指令 2.cal指令 3.find指令:(灰常重要) -name 4.grep指令——行文本过滤工具 5.zip/unzip指令: 6.tar指令(重要):打包/解包,不打开它,直接看内容 7.bc…

SpringBoot六种API请求参数读取方式

SpringBoot六种API请求参数读取方式 同步请求和异步请求 同步: 指单线程依次做几件事异步: 指多线程同时做几件事 同步请求: 指客户端浏览器只有一个主线程, 此线程负责页面的渲染和发出请求等操作, 如果此主线程发出请求的话则停止渲染而且会清空页面显示的内容 直到服务器响…

3d渲染的常用概念和技术,渲染100邀请码1a12

之前我们介绍了3D渲染的基本原理和流程,这次说下几个常用概念和技术。 3D渲染中涉及到很多专业的概念和技术,它们决定了渲染质量和效果,常用的有以下几个。1、光线追踪 光线追踪是一些专业渲染器(如V-Ray和Corona等)…

算法思想总结:哈希表

一、哈希表剖析 1、哈希表底层:通过对C的学习,我们知道STL中哈希表底层是用的链地址法封装的开散列。 2、哈希表作用:存储数据的容器,插入、删除、搜索的时间复杂度都是O(1),无序。 3、什么时…

Android HIDL接口添加

一.HIDL介绍 HIDL的全称是HAL interface definition language(硬件抽象层接口定义语言),是Android Framework 与Android HAL之间的接口。HIDL 旨在用于进程间通信 (IPC),进程之间的通信 采用 Binder 机制。 二.HIDL 与AIDL 的对…

客户文章|难能可贵,非模式生物的功能研究与创新

菜豆(Phaseolus vulgaris),又名四季豆、芸豆、油豆角,是全球第一大豆类蔬菜,我国是世界上最主要的菜豆生产国和销售国。在田间生产过程中,菜豆常面临着各种生物和非生物逆境的胁迫,对其产量品质…

FOC - BLDC六步换相驱动原理

文章目录 1 . 前言2 . 电机旋转原理3 . BLDC特点4 . BLDC反电动势投影位置5 . BLDC换相时刻6 . BLDC换相注意事项7 . 小结 【全文大纲】 : https://blog.csdn.net/Engineer_LU/article/details/135149485 1 . 前言 无刷直流电机在这里区分为两种,一是永磁无刷直流电…

【Linux-LCD 驱动】

Linux-LCD 驱动 ■ Framebuffer 简称 fb■ LCD 驱动程序编写■ 1、LCD 屏幕 IO 配置■ 2、LCD 屏幕参数节点信息修改■ 3、LCD 屏幕背光节点信息■ 4、使能 Linux logo 显示 ■ 设置 LCD 作为终端控制台■ 1、设置 uboot 中的 bootargs■ 2、修改/etc/inittab 文件 ■ LCD 背光…

python前端streamlit模型部署

简单介绍使用前端streamlit框架快速部署本地模型: 1、模型训练: import pandas as pd # 流程整合 from sklearn.pipeline import make_pipeline, Pipeline # 数据处理 from sklearn.impute import SimpleImputer from sklearn.preprocessing import Min…

探索 Android Studio 中的 Gemini:加速 Android 开发的新助力

探索 Android Studio 中的 Gemini:加速 Android 开发的新助力 在 Gemini 时代的下一篇章中,Gemini融入了更多产品中,Android Studio 正在使用 Gemini 1.0 Pro 模型,使 Android 开发变得更快、更简单。 Studio Bot 现已更名为 And…

深度学习知识与心得

目录 深度学习简介 传统机器学习 深度学习发展 感知机 前馈神经网络 前馈神经网络(BP网络) 深度学习框架讲解 深度学习框架 TensorFlow 一个简单的线性函数拟合过程 卷积神经网络CNN(计算机视觉) 自然语言处理NLP Wo…

C# WinForm —— 23 Timers.Timer 组件介绍与使用

1. 简介 System.Timers.Timer 计时器 轻量 每隔一段时间触发Elapsed事件,执行操作(不是由UI线程执行的),即使事件中执行了比较耗时的操作,也不会造成 UI 失去响应 如果要获取服务器的计时功能的话,可以使用System.Timers.Timer …

unity2020打包webGL时卡进程问题

我使用的2020.3.0f1c1,打包发布WEB版的时候会一直卡到asm2wasm.exe这个进程里,而且CPU占用率90%以上。 即使是打包一个新建项目的空场景也是同样的问题,我尝试过一直卡在这里会如何,结果还真打包成功了。只是打包一个空场景需要20…

C++(入门基础版本)

1,什么是C C 是一种通用的、面向对象的编程语言,是 C 语言的一个超集,也就是说,任何有效的 C 程序都是有效的 C 程序。C 通过添加诸如类和对象、继承和多态等概念,扩展了 C 语言的功能,使其更适用于大型软…

CSS学习笔记目录

CSS学习笔记之基础教程(一) CSS学习笔记之基础教程(二) CSS学习笔记之中级教程(一) CSS学习笔记之中级教程(二) CSS学习笔记之中级教程(三) CSS学习笔记之高级…

国产身份域管架构图集合(信创政策AD域替换必看)

几类典型架构 双机架构 单点单机房 集群架构 多点单机房 两地三中心架构 多点多机房 多地分布式架构 多点多机房 全栈信创方案架构,欢迎探讨交流~

[数据集][目标检测]喝水检测数据集VOC+YOLO格式995张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):995 标注数量(xml文件个数):995 标注数量(txt文件个数):995 标注类别…