强化学习-RLHF-PPO入门

一、定义

  1. 强化学习微调分类
  2. RM模型 数据集格式
  3. 训练流程
  4. Reward 模型训练流程(分类模型,积极为1,消极为0) AutoModelForSequenceClassification
  5. Reward 模型训练案例
  6. PPO模型训练流程
  7. PPO模型训练案例

二、实现

  1. 强化学习微调分类
    RLHF:基于人类反馈对语言模型进行强化学习, 分两步:
    1. RM (Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好。
    2 RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本.
    DPO(Direct Preference Optimization): 直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好。
    RLHF主要是进行对齐微调, 目标是将大语言模型的行为与人类的价值观或偏好对齐。
    PPO: (Proximal Policy Optimization,近端策略优化)是一种在强化学习领域广泛使用的算法.

  2. RM模型 数据集格式

{conversations:  [0: 
 {from:  
"human",value:  "国会的转发 美国国会由众议院和参议院组成,每两年换届一次(参议员任期为6年,但参议院选举是错位的。是更常见地转发国会议员还是来自国会外部?"}],
chosen:  {from:  "gpt",value:  "计算推文的政党边际概率,我们可以使用以下代码这表明大多数转发不是来自国会议员,而是来自国会之外。"},
rejected:  {from:  "gpt",value:  "回答问题的第(计算转发国会议员或来自国会以外的人的边际概率"}}

其中chosen 代表是好的回答, rejected代表的是不好的回答

  1. 训练流程
    在这里插入图片描述
    训练reward Model---->PPO模型

  2. Reward 模型训练流程(激励模型为深度学习模型)
    数据处理:

def preprocess_function(examples):
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)

        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
    return new_examples

训练求损失:AutoModelForSequenceClassification 分类模型

model = AutoModelForSequenceClassification.from_pretrained(
    model_config.model_name_or_path, num_labels=1, **model_kwargs
)
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    if not self.use_reward_data_collator:
        warnings.warn(
            "The current compute_loss is implemented for RewardDataCollatorWithPadding,"
            " if you are using a custom data collator make sure you know what you are doing or"
            " implement your own compute_loss method."
        )
    rewards_chosen = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
        return_dict=True,
    )["logits"]
    rewards_rejected = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )["logits"]
    # calculate loss, optionally modulate with margin
    if "margin" in inputs:
        loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
    else:
        loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

    if return_outputs:
        return loss, {
            "rewards_chosen": rewards_chosen,
            "rewards_rejected": rewards_rejected,
        }
    return loss
  1. Reward 模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py

  2. PPO模型训练流程

在这里插入图片描述

步骤:
1. 语言模型预测
2. 激活模型评估(分类模型),1 代表积极,0 代表消极
3. PPO算法优化。
数据:

def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
    sample["query"] = tokenizer.decode(sample["input_ids"])
    return sample
  1. 加载模型, 参考模型(参考模型可以为None)
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
# Get response from gpt2    待训练的模型响应,参考模型响应
response_tensors, ref_response_tensors = ppo_trainer.generate(
    query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
)
batch["response"] = tokenizer.batch_decode(response_tensors)
batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)
  1. 激活模型评估(分类模型),1 代表积极,0 代表消极
2. 获取激励值
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)                      #激励函数
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]   #激励值
  1. PPO算法优化。
#   问题query  、  模型响应   、激励值
#其中上图优化模块,均在step 方法中实现。
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
step内部:
        all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
            self.model,
            queries,
            responses,
            model_inputs,
            response_masks=response_masks,
            return_logits=full_kl_penalty,
        )
        with self.optional_peft_ctx():
            ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                self.model if self.is_peft_model else self.ref_model,
                queries,
                responses,
                model_inputs,
                return_logits=full_kl_penalty,
            )
  1. PPO模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/743624.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

什么概率密度函数?

首先我们来理解一下什么是连续的随机变量,在此之前,我们要先理解什么是随机变量。所谓随机变量就是在一次随机实验中一组可能的值。比如说抛硬币,我们设正面100,反面200,设随机变量为X,那么X{100,200}。 X是…

Java之多线程的实现与应用

多线程 创建进程方式: (1)继承Thread类 class Main {public static void main(String[] args) { MyThread01 myThread01new MyThread01(); myThread01.start(); while(true){System.out.println("main方法的run()方法正在运行")…

Vue 3 中处理文件上传和响应式更新

Vue 3 中处理文件上传和响应式更新 一、前言1.创建文件上传组件2.解释代码3.在主应用中使用文件上传组件4.总结 一、前言 在现代 web 开发中,文件上传是一个常见需求。本文将详细介绍如何在 Vue 3 中处理文件上传,并确保上传后的文件列表能够响应式更新…

AI视频教程下载-定制GPT:使用您的数据创建一个定制聊天GPT

Custom GPTs_ Create a Custom ChatGPT with Your Data 构建一个定制的GPT,与您自己的数据进行聊天。添加文档,生成图像,并集成API和Zapier。 这门全面的Udemy课程专为那些渴望学习如何创建自己定制版ChatGPT的人设计,以满足他们…

C++:C与C++混合编程

混合编程 为什么需要混合编程 (1)C有很多优秀成熟项目和库,丢了可惜,重写没必要,C程序里要调用 (2)庞大项目划分后一部分适合用C,一部分适合用C (3)其他情况,如项目组一部分人习惯用C,一部分习惯用C 为什么…

HarmonyOS角落里的知识:“开发应用沉浸式效果”

概述 典型应用全屏窗口UI元素包括状态栏、应用界面和底部导航条。开发应用沉浸式效果主要指通过调整状态栏、应用界面和导航条的显示效果来减少状态栏导航条等系统界面的突兀感,从而使用户获得最佳的UI体验。 图1 界面元素示意图 开发应用沉浸式效果主要要考虑如下…

心灵馆咨询系统小程序心理咨询平台聊天咨询

心灵馆咨询系统小程序:解锁你的心灵密码 💖 心灵之旅的导航者 在繁忙的现代生活中,我们时常会面临各种压力与困惑。心灵馆咨询系统小程序,如同一位贴心的导航者,引领我们探索内心的世界,寻找真正的自我。 …

DDP(Differential Dynamic Programming)算法举例

DDP(Differential Dynamic Programming)算法 基本原理 DDP(Differential Dynamic Programming)是一种用于求解非线性最优控制问题的递归算法。它基于动态规划的思想,通过线性化系统的动力学方程和二次近似代价函数,递归地优化控制策略。DDP的核心在于利用局部二次近似来…

北大医院副院长李建平:用AI解决临床心肌缺血预测的难点、卡点和痛点

2024年6月14日,第六届北京智源大会在中关村展示中心开幕,海内外的专家学者围绕人工智能关键技术路径和应用场景,展开了精彩演讲与尖峰对话。在「智慧医疗和生物系统:影像、功能与仿真」论坛上,北京大学第一医院副院长、…

[经典]原型资源:蚂蚁金服UI模版部件库

部件库预览链接: https://d3ttsx.axshare.com 支持版本: Axrure RP 8 文件大小: 30MB 文档内容介绍 基本部件:表单样式:12款、数据样式:10款、服务样式:6款、导航:5款、业务组件:7款、 模板…

区块链技术与数字货币

1.起源 ➢中本聪(Satoshi Nakamoto), 2008 ➢比特币:一种点对点的电子现金系统 2.分布式账本技术原理 1.两个核心技术: ➢以链式区块组织账本数据实现账本数据的不可篡改 ➢分布式的可信记账机制 2.共识机制:由谁记账 ➢目的: ⚫ 解…

鸿蒙开发系统基础能力:【@ohos.hiTraceMeter (性能打点)】

性能打点 本模块提供了追踪进程轨迹,度量程序执行性能的打点能力。本模块打点的数据供hiTraceMeter工具分析使用。 说明: 本模块首批接口从API version 8开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 impor…

AcWing算法基础课笔记——状态压缩DP:蒙德里安的梦想

状态压缩DP 状态是整数,但把它看成二进制数,二进制中每一位是0或1表示不同的情况。 蒙德里安的梦想 291. 蒙德里安的梦想 - AcWing题库 题目 求把 NM𝑁𝑀 的棋盘分割成若干个 1212 的长方形,有多少种方案。 例如…

Java面试题:聚簇索引和非聚簇索引

聚簇索引和非聚簇索引 聚簇索引(聚集索引) 将数据的存储和索引放在一块,索引结构的叶子节点保存了行数据 索引字段必须存在,且只能存在一个 非聚集索引(二级索引) 将数据和索引分开存储,索引结构的叶子节点关联的是对应的主键 索引字段可以存在多个 索引的选取规则 如果…

2024 年 8 款最佳建筑 3D 渲染软件

你现在使用的3D 渲染软件真得适合你吗? 在建筑和室内渲染当中,市面上有许多3D渲染软件可供选择。然而,并不是每款软件都适合你的需求。本指南将重点介绍2024年精选的8款最佳建筑3D渲染软件,帮助你了解不同的选项,并选…

第100+13步 ChatGPT学习:R实现决策树分类

基于R 4.2.2版本演示 一、写在前面 有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。 答曰:可!用GPT或者Kimi转一下就得了呗。 加上最近也没啥内容写了,就帮各位搬运一下吧。 二、R代码实现决策树分类 (…

SSM宠物领养系统-计算机毕业设计源码08465

目 录 摘要 1 绪论 1.1课题背景及意义 1.2研究现状 1.3ssm框架介绍 1.3论文结构与章节安排 2 宠物领养系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据流程 3.3.2 业务流程 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分析 …

大模型管理平台:one-api使用指南

大模型相关目录 大模型,包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步,扬帆起航。 大模型应用向开发路径:AI代理工作流大模型应用开发实用开源项目汇总大模…

Go 实现SFTP连接服务

我们将SFTP连接和处理逻辑,以及登录账户信息封装,这样可以在不同的地方重用代码,并且可以轻松地更改登录凭据。下面我将演示如何使用Go语言中的结构体来封装这些信息,并实现一个简单的SFTP服务器: package mainimport…

信息系统项目管理师 | 新一代信息技术

关注WX:CodingTechWork 物联网 定义 The Internet of Things是指通过信息传感设备,按约定的协议,将任何物品与互联网连接,进行信息交互和通信,以实现智能化识别。定位、跟踪、监控和管理的一种网络。物联网主要解决…