从TinyZero的数据与源码来理解DeepSeek-R1-Zero的强化学习训练过程

1. 引入

TinyZero(参考1)是伯克利的博士生复现DeepSeek-R1-Zero的代码参仓库,他使用veRL来运行RL强化学习方法,对qwen2.5的0.5B、1.5B、3B等模型进行训练,在一个数字游戏数据集上,达到了较好的推理效果。

下面解读源码中的关键训练逻辑细节。

2. 训练过程

  1. 原始数据

原始数据来自参考2,一共490k条数据,数据中只有两个字段,格式如下:

{
	"nums": [ 95, 11, 56 ],
	"target":28
}

这是一个数字游戏,要求对nums中的数据,进行基础数学运算(+, -, *, /),每个数字只能用一次,最终结果等于target的值。比如上例子,95-11-56=28。

  1. 数据处理

具体源码见参考3,下文仅仅解析关键步骤:

(1)训练集和测试集大小

默认值如下:

parser.add_argument('--train_size', type=int, default=327680)
parser.add_argument('--test_size', type=int, default=1024)

(2)对原始数据添加提示词

下面的dp就是一条原始数据(参考2.1例子):

def make_prefix(dp, template_type):
    target = dp['target']# 取出目标
    numbers = dp['nums']# 取出数字
    # 对于默认模型加的提示词如下
    if template_type == 'base':
        """This works for any base model"""
        prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
Assistant: Let me solve this step by step.
<think>"""
	# 对于qwen-instruct模型加的提示词如下
    elif template_type == 'qwen-instruct':
        """This works for Qwen Instruct Models"""
        prefix = f"""<|im_start|>system\nYou are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    return prefix

(3)对数据进行完整的处理,增加提示词与reward等数据

如下函数中的example就是一条原始数据(参考2.1例子)。

        def process_fn(example, idx):
            question = make_prefix(example, template_type=args.template_type) # 增加提示词,见2.2.2
            solution = {
                "target": example['target'],
                "numbers": example['nums']
            }
            data = {
                "data_source": data_source, # 任务名称,默认为'countdown'
                "prompt": [{
                    "role": "user",
                    "content": question, # 带有提示词的问题
                }],
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": solution # 含有nums和target
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                }
            }
            return data

最终数据为含有prompt和reward_model等字段的json结构。

  1. 训练

从参考4的训练代码中,摘取部分配置如下:

python3 -m verl.trainer.main_ppo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=256 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size=8 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
critic.optim.lr=1e-5 \
critic.model.path=$BASE_MODEL \
critic.ppo_micro_batch_size=8 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['wandb'] \
+trainer.val_before_train=False \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.project_name=TinyZero \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.total_epochs=15 2>&1 | tee verl_demo.log

这条命令是一个典型的 Python 脚本调用,用于训练一个基于 PPO(Proximal Policy Optimization) 算法的模型。

用veRL进行训练(参考5),需要指定 数据、模型、超参数:

(1)数据相关配置

data.train_files=$DATA_DIR/train.parquet:指定训练数据文件路径(Parquet 格式)。

data.val_files=$DATA_DIR/test.parquet:指定验证数据文件路径。

data.train_batch_size=256:训练时的批量大小(batch size)。

data.val_batch_size=1312:验证时的批量大小。

data.max_prompt_length=256:输入提示(prompt)的最大长度。

data.max_response_length=1024:生成响应(response)的最大长度。

(2)Actor 模型配置

actor_rollout_ref.model.path=$BASE_MODEL:指定 Actor 模型的预训练权重路径。

actor_rollout_ref.actor.optim.lr=1e-6:Actor 模型的学习率。

actor_rollout_ref.actor.ppo_mini_batch_size=128:PPO 算法中 Actor 的 mini-batch 大小。

actor_rollout_ref.actor.ppo_micro_batch_size=8:PPO 算法中 Actor 的 micro-batch 大小。

actor_rollout_ref.rollout.log_prob_micro_batch_size=8:Rollout 阶段计算 log probability 的 micro-batch 大小。

actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE:Rollout 阶段的张量并行大小(用于分布式训练)。

actor_rollout_ref.rollout.gpu_memory_utilization=0.4:Rollout 阶段的 GPU 内存利用率。

actor_rollout_ref.ref.log_prob_micro_batch_size=4:参考模型(ref model)计算 log probability 的 micro-batch 大小。

(3)Critic 模型配置

critic.optim.lr=1e-5:Critic 模型的学习率。

critic.model.path=$BASE_MODEL:指定 Critic 模型的预训练权重路径。

critic.ppo_micro_batch_size=8:PPO 算法中 Critic 的 micro-batch 大小。

(4)算法配置

algorithm.kl_ctrl.kl_coef=0.001:KL 散度(Kullback-Leibler divergence)的系数,用于控制策略更新的稳定性。

(5)训练器配置

trainer.logger=['wandb']:使用 Weights & Biases(WandB)作为日志记录工具。

+trainer.val_before_train=False:在训练开始前不进行验证。

trainer.default_hdfs_dir=null:HDFS 目录未设置(HDFS 是分布式文件系统)。

trainer.n_gpus_per_node=$N_GPUS:每个节点使用的 GPU 数量。

trainer.nnodes=1:使用的节点数量(单节点训练)。

trainer.save_freq=100:每 100 步保存一次模型。

trainer.test_freq=100:每 100 步进行一次测试。

trainer.project_name=TinyZero:WandB 项目名称。

trainer.experiment_name=$EXPERIMENT_NAME:实验名称。

trainer.total_epochs=15:总训练轮数(epochs)。
  1. 训练效果

用强化学习的方法训练后,能如下所示,输出字段(推理过程),并给出最终结果字段。
在这里插入图片描述

3. 总结

通过具体的数据与处理训练过程,来更好的理解DeepSeek-R1-Zero的强化学习训练方法。

4. 参考

  1. 项目:https://github.com/Jiayi-Pan/TinyZero
  2. 数据:https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4
  3. 数据处理源码:https://github.com/Jiayi-Pan/TinyZero/blob/main/examples/data_preprocess/countdown.py
  4. 训练源码:https://github.com/Jiayi-Pan/TinyZero/blob/main/scripts/train_tiny_zero.sh
  5. veRL:https://verl.readthedocs.io/en/latest/start/quickstart.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963398.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度卷积神经网络实战无人机视角目标识别

本文采用深度卷积神经网络作为核心算法框架&#xff0c;结合PyQt5构建用户界面&#xff0c;使用Python3进行开发。YOLOv8以其高效的实时检测能力&#xff0c;在多个目标检测任务中展现出卓越性能。本研究针对无人机目标数据集进行训练和优化&#xff0c;该数据集包含丰富的无人…

初级数据结构:栈和队列

一、栈 (一)、栈的定义 栈是一种遵循后进先出&#xff08;LIFO&#xff0c;Last In First Out&#xff09;原则的数据结构。栈的主要操作包括入栈&#xff08;Push&#xff09;和出栈&#xff08;Pop&#xff09;。入栈操作是将元素添加到栈顶&#xff0c;这一过程中&#xf…

数据结构 前缀中缀后缀

目录 前言 一&#xff0c;前缀中缀后缀的基本概念 二&#xff0c;前缀与后缀表达式 三&#xff0c;使用栈实现后缀 四&#xff0c;由中缀到后缀 总结 前言 这里学习前缀中缀后缀为我们学习树和图做准备&#xff0c;这个主题主要是对于算术和逻辑表达式求值&#xff0c;这…

笔灵ai写作技术浅析(三):深度学习

笔灵AI写作的深度学习技术主要基于Transformer架构,尤其是GPT(Generative Pre-trained Transformer)系列模型。 1. Transformer架构 Transformer架构由Vaswani等人在2017年提出,是GPT系列模型的基础。它摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN),完全依赖自…

专业的定制版软件,一键操作,无限使用

今天给大家介绍一个专业的PDF转word的小软件&#xff0c;软件只有5.5M。非常小&#xff0c;而且没有文档大小的限制&#xff0c;可以随意使用。 PDFtu PDF转word 软件第一次使用需要安装一下。 安装好之后&#xff0c;我们就能在桌面找到对应的图标&#xff0c;打开就能直接使…

QGIS系列22-如何提取不规则多边形的中心经纬度

今天我们来学习一下啊如何通过QGIS提取不规则多边形的中心经纬度 1、首先我们把不规则的多边形图形导入进QGIS里面去 2、现在打开的图层是不可以编辑的&#xff0c;因此我们还需要转换成可编辑状态&#xff0c;具体是选择图层&#xff0c;右键点击&#xff0c;选择切换编辑模式…

word2vec 实战应用介绍

Word2Vec 是一种由 Google 在 2013 年推出的重要词嵌入模型,通过将单词映射为低维向量,实现了对自然语言处理任务的高效支持。其核心思想是利用深度学习技术,通过训练大量文本数据,将单词表示为稠密的向量形式,从而捕捉单词之间的语义和语法关系。以下是关于 Word2Vec 实战…

数据库安全管理中的权限控制:保护数据资产的关键措施

title: 数据库安全管理中的权限控制:保护数据资产的关键措施 date: 2025/2/2 updated: 2025/2/2 author: cmdragon excerpt: 在信息化迅速发展的今天,数据库作为关键的数据存储和管理中心,已经成为了企业营运和决策的核心所在。然而,伴随着数据规模的不断扩大和数据价值…

【漫话机器学习系列】076.合页损失函数(Hinge Loss)

Hinge Loss损失函数 Hinge Loss&#xff08;合页损失&#xff09;&#xff0c;也叫做合页损失函数&#xff0c;广泛用于支持向量机&#xff08;SVM&#xff09;等分类模型的训练过程中。它主要用于二分类问题&#xff0c;尤其是支持向量机中的优化目标函数。 定义与公式 对于…

openmv的端口被拆分为两个 导致电脑无法访问openmv文件系统解决办法 openmv USB功能改动 openmv驱动被更改如何修复

我之前误打误撞遇到一次&#xff0c;直接把openmv的全部端口删除卸载然后重新插上就会自动重新装上一个openmv端口修复成功&#xff0c;大家可以先试试不行再用下面的方法 全部卸载再重新插拔openmv 要解决OpenMV IDE中出现的两个端口问题&#xff0c;可以尝试以下步骤&#x…

洛谷P1403 [AHOI2005] 约数研究

题目链接&#xff1a;P1403 [AHOI2005] 约数研究 - 洛谷 | 计算机科学教育新生态 题目难度&#xff1a;普及一 题目分析&#xff1a;本题很明显是要你求从i到n的质因数个数之和&#xff0c;如果采用暴力肯定是超时的&#xff0c;故我的想法是采用埃氏筛法来求时间复杂度为&…

elasticsearch8.15 高可用集群搭建(含认证Kibana)

文章目录 1.资源配置2.系统参数优化3.JDK17安装4.下载&安装ES 8.155.生成ES的证书(用于ES节点之间进行安全数据传输)6.修改ES 相关配置文件7.创建es用户并启动8.配置ES的账号和密码(用于ES服务端和客户端)9.下载和安装Kibana10.编辑Kibana配置文件11.启动Kiabana12.访问Kia…

MATLAB中的IIR滤波器设计

在数字信号处理中&#xff0c;滤波器是消除噪声、提取特征或调整信号频率的核心工具。其中&#xff0c;无限脉冲响应&#xff08;IIR&#xff09;滤波器因其低阶数实现陡峭滚降的特性&#xff0c;被广泛应用于音频处理、通信系统和生物医学工程等领域。借助MATLAB强大的工具箱&…

数据结构:优先级队列—堆

一、优先级队列 1、优先级队列概念 优先级队列&#xff0c;听名字我们就知道他是一种队列&#xff0c;队列在前面我们已经学习过了&#xff0c;它是一种先进先出的数据结构&#xff0c;但是在特殊的情况下&#xff0c;我们我们队列中元素是带有一定优先级的&#xff0c;它需要…

北大:三阶段学习优化多模态推理问答

&#x1f4d6;标题&#xff1a;ReasVQA: Advancing VideoQA with Imperfect Reasoning Process &#x1f310;来源&#xff1a;arXiv, 2501.13536 &#x1f31f;摘要 &#x1f538;视频问答&#xff08;VideoQA&#xff09;是一项具有挑战性的任务&#xff0c;需要理解视频中…

从零开始:用Qt开发一个功能强大的文本编辑器——WPS项目全解析

文章目录 引言项目功能介绍1. **文件操作**2. **文本编辑功能**3. **撤销与重做**4. **剪切、复制与粘贴**5. **文本查找与替换**6. **打印功能**7. **打印预览**8. **设置字体颜色**9. **设置字号**10. **设置字体**11. **左对齐**12. **右对齐**13. **居中对齐**14. **两侧对…

Jason配置环境变量

jason官网 https://jason-lang.github.io/ https://github.com/jason-lang/jason/releases 步骤 安装 Java 21 或更高版本 安装 Visual Studio Code 根据操作系统&#xff0c;请按照以下具体步骤操作 视窗 下载 Jason 的最新版本&#xff0c;选择“jason-bin-3.3.0.zip”…

机器学习--概览

一、机器学习基础概念 1. 定义 机器学习&#xff08;Machine Learning, ML&#xff09;&#xff1a;通过算法让计算机从数据中自动学习规律&#xff0c;并利用学习到的模型进行预测或决策&#xff0c;而无需显式编程。 2. 与编程的区别 传统编程机器学习输入&#xff1a;规…

如何使用SliverGrid组件

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了SliverList组件相关的内容&#xff0c;本章回中将介绍SliverGrid组件.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1 概念介绍 我们在本章回中介绍的SliverGrid组件是一种网格类组件&#xff0c;主要用来…

大模型培训讲师老师叶梓分享:DeepSeek多模态大模型janus初探

以下视频内容为叶梓分享DeepSeek多模态大模型janus的部署&#xff0c;并验证其实际效果&#xff0c;包括图生文和文生图两部分。 叶梓老师人工智能培训分享DeepSeek多模态大模型janus初探 DeepSeek 的多模态大模型 Janus 是一款强大的 AI 模型&#xff0c;专注于图像和文本的多…