OpenAI o1复现:自动构造prm训练数据-OmegaPRM

作者:cmathx
原文:https://zhuanlan.zhihu.com/p/1477078851

openai o1复现中,有个比较关键的问题,怎么样自动化构造prm模型的训练数据?本文主要从代码层面,来解析OmegaPRM原理。

论文

Improve Mathematical Reasoning in Language...[1]

原理

Markov决策过程

图片

 

OmegaPRM

State:对应Markov决策过程中的状态,rollout:对应Markov决策过程中的动作;

  • • step1:初始化root节点state;每个state包含n个扩展rollouts,q+pa作为prompt,进行n次llm生成采样;基于bootstrap采样方法估计Monte Carlo模拟正确答案的概率mc;

  • • step2:从所有节点中,基于UCB1(Explore&&Exploit方法)选取最优的“state和rollout”,添加到PRM训练集;Exploit:alpha ** (1 - mc) * beta ** (len(r) / L),其中:mc表示蒙特卡洛模拟正确答案概率、len(r)表示LLM生成的长度;Explore:c_puct * sqrt(N_sum) / (1 + s.v),其中:N_sum表示所有节点的访问次数,s.v表示当前节点的访问次数,c_puct控制MCTS树的探索程度;

  • • step3:评估最优“state和rollout”,二分rollout的结果,将左半部分纳入到新的state中,并计算新的mc;mc=1,表示state完全包含正确答案,忽略;mc=0,表示state完全没有生成正确答案可能性,添加到叶子节点;mc>0,表示state作为继续探索的节点;

  • • step4:重复step2、step3,直至“探索到足够的样本、无法继续探索”退出;

  • • step5:将叶子节点全部添加到PRM训练集;

PRM模型训练效果

论文的base模型

图片

 

基于OmegaPRM方法合成数据,在MATH数据集,相比base model51%的准确率,OmegaPRM准确率提高到69.4%;

图片

 

其他PRM方法

OmegaPRM:gemini提到的方法;

AlphaMath:qwen提到的方法;

Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations[2]

AlphaMath Almost Zero: Process Supervision without Process[3]

源码来源

https://github.com/openreasoner/openr[4]

源码解析

数据结构

class State:
    def __init__(self, q, pa, a):
        self.q = q #问题
        self.pa = pa #当前step的prompt
        self.a = a #答案
        self.mc = None #基于当前节点,生成正确答案的概率
        self.v = 0 #被访问次数
        self.rollouts = [] #扩展的子节点
        self.rollout_was_visited = [] #扩展的子节点是否被访问

主流程

 # Load the JSON data
    data = load_json_file(json_file_path)
    
    # Process each problem and its final answer
    for i, item in enumerate(data):
        problem = item.get('problem', 'No problem found')
        final_answer = item.get('final_answer', 'No answer found')
        
        # Print to console
        print(f'Problem {i + 1}: {problem}')
        print(f'Final Answer: {final_answer}')
        
        # Log each problem and answer
        logging.info(f'Processed Problem {i + 1}: {problem}')
        logging.info(f'Final Answer: {final_answer}')
        
        # Call getrollout and handle the result
        states = []
        root = State(problem, '', final_answer)
        max_roll_num = 20
        rollouts, corrs = getrollouts(root, max_roll_num)
        mcst = cal_mc_bs(root)
        root.mc = mcst
           
        # 生成root节点
        states.append(root)

        if sum(corrs) > 0 and sum(corrs) < max_roll_num: 
            print('Process annotation ...\n')
            filename = str(i+1) +'_states_list.json'
            # 生成PRM训练数据
            process_annotation(problem, final_answer, states, filename)

蒙特卡洛采样

#针对节点s进行n次采样,基于LLM生成n个rollouts,并给出每个rollout是否包含正确答案;
def getrollouts(s, n = 5):
  corrs = []
  q = s.q
  pa = s.pa
  for i in range(n):
    re = complete_answer(q, pa)
    s.add_rollout(re)
    #check the answer
    a = s.a
    if check_answer(a, re):
      corrs.append(1)
    else:
      corrs.append(0)
  return s.rollouts, corrs
  
  #蒙特卡洛采样,并给出包含正确答案的概率
  def cal_mc_bs(s, bs = 5):
    n = len(s.rollouts)
    subn = max(1,random.randint(n//2, n))
    mc = 0
    for i in range(bs):
    corr = 0
        sub = random.sample(s.rollouts, subn)
        for r in sub:
            if check_answer(s.a, r):
                corr += 1
        mc += corr * 1.0 / len(sub)
    return mc / bs 
  
  #针对问题problem,使用problem+partial_answer作为prompt,进行LLM生成
  complete_answer(problem, partial_answer, checkpoint)
  #LLM生成的response是否包含正确答案groundtruth_answer
  check_answer(groundtruth_answer, response)

基于mcts方法自动构造prm训练数据

#基于MCTS方法生成PRM训练数据
def process_annotation(q, a, states, filename = 'states_list.json'):
   print('++++++')
   it = 0
   leaf_states = []
   while True:
       s, rollout, maxqu = select(states)
       if s is not None and s.pa!='':
           new_data = {
               'q': q,           # Ensure q is serializable
               'states': s.pa, # Ensure states is serializable
               'mcs': s.mc        # Ensure mcs is serializable
           }
           # Call the function to append the new data
           append_to_json_file(filename, new_data)
           it += 1
           if it > 100:
               break
       # all state-rolls pairs were exhausted
       if s is None:
           break
       print()
       print('[sel]')
       print(s)
       print('  roll=',rollout,' || qu=', maxqu)
       
       s.add_visit()
       div_roll_sts,leaf_sts = error_locate(s, rollout)
       if len(div_roll_sts)==0:
           continue
       
       states.extend([s for s in div_roll_sts if s!=None and s.pa != ''])
       leaf_states.extend(leaf_sts)
   #
   ## add leaf states to data
   for s in leaf_states:
       new_data = {
           'q': q,           # Ensure q is serializable
           'states': s.pa, # Ensure states is serializable
            'mcs': s.mc        # Ensure mcs is serializable
       }
       # Call the function to append the new data
       append_to_json_file(filename, new_data)
   print('++++++')

基于UCB1方法,选择最优的节点,纳入到训练集

#选择当前最优的节点
#exploitation:使用“更大的mc、更短的llm生成”节点;
#exploration:探索“未充分访问的、更大的树探索程度”节点;
def select(states):
    best_st = None
    best_roll_idx = -1
    best_qu = -1
    for s in states:
        # mcs = cal_mc(s) if s.mc is None else s.mc
        mcs = cal_mc_bs(s) if s.mc is None else s.mc
        if mcs == 0 or mcs==1.0:
            continue
        for i,r in enumerate(s.rollouts):
            if s.rollout_was_visited[i]:
                continue
            q = Q(r, mcs)
            u = U(s,states)
            qu = q + u
            if qu > best_qu:
                best_st = s
                best_roll_idx = i
                best_qu = qu
                  #
    if best_roll_idx != -1:
        best_st.rollout_was_visited[best_roll_idx] = True
    return best_st,best_st.rollouts[best_roll_idx],best_qu

#exploitation:倾向于选择已知表现好的状态和rollout;
#alpha ** (1 - mc) * beta ** (len(r) / L)
#1. 鼓励使用更大mc(生成包含正确答案可能性更大);
#2. 更短rollout(更短的生成,更可能推理出正确答案)的节点,
def Q(r, mc, alpha  = 0.5, beta = 0.9, L = 500):
    part1 = alpha ** (1 - mc)
    part2 = beta ** (len(r) / L)
    Q_value = part1 * part2
    return Q_value

#exploration:鼓励尝试未充分探索的选项,使用UCB1算法(Upper Confidence Bound 1);
#c_puct * sqrt(N_sum) / (1 + s.v)
#1. s.v:当前状态访问次数,鼓励探索访问次数较少的节点;
#2. N_sum:所有状态的访问次数总和,表示搜索过程的广度和深度,即鼓励更大的搜索树探索程度;
#3. c_puct:控制探索程度的常数;
def U(s, states, c_puct = 0.125):
    N_sum = 0
    for item in states:
        N_sum += item.v
    numerator = math.sqrt(N_sum)
    denominator = 1 + s.v
    U_value = c_puct * (numerator / denominator)
    return U_value

def qu(i, r, mc, ncs):
    q = Q(r, mc)
    u = U(i, ncs)
    return q+u

评估最优节点,是否继续探索?无法探索(完全错误)作为叶子节点,纳入到训练集

#评估最优“state和rollout”,二分rollout的结果,将左半部分纳入到新的state中,并计算新的mc;
def error_locate(s, rollout):
    current_span = rollout
    prev = ''
    divide_roll_pos_st = []
    leaf_st = []
    while True:
        word_count = len(current_span.split())
        if word_count < 2:
            break
        np1, np2 = split_sentence_middle(current_span)
        print('----')
        print(' BS[l]=', np1)
        print(' BS[r]=', np2)
        #二分LLM生成结果rollout,新的prompt:已有生成结果+左半部分
        st = State(s.q, prev + np1, s.a)
        rollouts, corrs = getrollouts(st)
        # mcst = cal_mc(st)
        mcst = cal_mc_bs(st)
        st.mc = mcst
        # case 1: always correct (we are not interested in this kind of state)
        if mcst == 1:
         # leaf_st.append(st)
            break
        # case 2: right span(继续扩展节点)
        elif mcst > 0:
            current_span = np2
            prev = prev + np1
            divide_roll_pos_st.append(st)
        # case 3: left span(这里LLM生成完全没有可能包含正确答案,因此节点扩展terminated)
        elif mcst == 0:
            current_span = np1
            leaf_st.append(st)
        
    #
    print('----')
    return divide_roll_pos_st,leaf_st
引用链接

[1] Improve Mathematical Reasoning in Language...: https://arxiv.org/abs/2406.06592
[2] Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations: https://arxiv.org/abs/2312.08935
[3] AlphaMath Almost Zero: Process Supervision without Process: https://arxiv.org/abs/2405.03553
[4] https://github.com/openreasoner/openr: github.com/openreasoner/openr

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/898921.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Discuz | 起尔开发 传奇开服表游戏公益服发布论坛网站插件

Discuz | 起尔开发 传奇开服表游戏公益服发布论坛网站插件 插件下载&#xff1a;源码 - 起尔开发的插件下载 演示地址&#xff1a;discuz.72jz.com 标黄和非标黄自动分开 在标黄时间内显示在上面置顶&#xff0c;标黄过期后自动显示在下面白色区域。 后台可以设置非标黄默认…

四、多线程带来的的⻛险-线程安全

4.1 观察线程不安全 运行以下代码&#xff1a; package demo02;public class Test {private static int count 0;public static void main(String[] args) throws Exception {Thread t1 new Thread(() -> {for (int i 0; i < 50_000; i) {count;}});Thread t2 new …

通过Docker Compose构建自己的Java项目

通过Docker Compose构建自己的Java项目 前置条件 安装了Docker,未安装的请移步:CentOS7 / CentOS8 安装 Docker-ce安装了Docker-Compose,未安装的请移步:在CentOS7、CentOS8系统下安装Docker Compose1. 配置阿里云镜像仓库 为了提高Docker镜像的下载速度,我们可以配置阿…

版本工具报错:Error Unity Version Control

NotConfiguredClientException: Unity VCS client is not correctly configured for the current user:Client config file.

python 爬虫 入门 三、登录以及代理。

目录 一、登录 &#xff08;一&#xff09;、登录4399 1.直接使用Cookie 2.使用账号密码进行登录 可选观看内容&#xff0c;使用python对密码进行加密&#xff08;无结果代码&#xff0c;只有过程分析&#xff09; 二、代理 免费代理 后续&#xff1a;协程&#xff0c;…

TitanIDE:解锁编程教学新范式

在高校软件工程类课程教育中&#xff0c;传统编程教学方式正面临着多重痛点&#xff1a; 环境配置繁琐&#xff1a;软件工程类课程往往需要学生自行配置复杂的开发环境。但是&#xff0c;学校硬件设备条件差异、软件兼容性问题等因素&#xff0c;导致学生学习效率低下&#xf…

热销王西圣H1头戴式耳机—全平台售罄断货:揭秘抢购潮究其原因?

西圣xisem作为国内平价享轻奢的领军品牌&#xff0c;就在今年它家的头戴式蓝牙耳机性价比标杆—西圣H1&#xff0c;凭借其发烧级的千元音质、降噪与满级的旗舰配置性能&#xff0c;不仅惊艳了整个耳机圈&#xff0c;还在仅仅的几个月内&#xff0c;西圣H1头戴式耳机已经火爆断货…

python 使用gradio启动程序报错

问题一&#xff1a;localhost is not accessible 解决办法&#xff1a; export no_proxy"localhost,127.0.0.1,::1"

C#学习笔记(三)

C#学习笔记&#xff08;三&#xff09; 第 二 章 命名空间和类、数据类型、变量和代码规范二、类的组成和使用分析1. 基本概念2. 类的内容组成3. 方法的初步理解 第 二 章 命名空间和类、数据类型、变量和代码规范 二、类的组成和使用分析 1. 基本概念 类是程序的基本单元&a…

PostgreSQL中触发器递归的处理 | 翻译

许多初学者在某个时候都会陷入触发器递归的陷阱。通常&#xff0c;解决方案是完全避免递归。但对于某些用例&#xff0c;您可能必须处理触发器递归。本文将告诉您有关该主题需要了解的内容。如果您曾经被错误消息“超出堆栈深度限制”所困扰&#xff0c;那么这里就是解决方案。…

Javascript算法——二分查找

1.数组 1.1二分查找 1.搜索索引 开闭matters&#xff01;&#xff01;&#xff01;[left,right]与[left,right) /*** param {number[]} nums* param {number} target* return {number}*/ var search function(nums, target) {let left0;let rightnums.length-1;//[left,rig…

大话网络协议:从OSI七层模型说开去

时至今日,互联网已经是大家日常生活中不可或缺的一部分,购物、点餐、刷剧、网课,已经融入了我们生活的方方面面。但网络具体是怎么工作的呢? 特别是我们具体从事软件研发、ICT行业的同学,理解和掌握这个我们产品运行的基础设施尤为必要。 本文,我们会力争用最简单易懂的…

秋季猫咪疯狂掉毛,宠物空气净化器有用吗?性价比高的该怎么选?

我家猫真的是换季就变掉毛怪&#xff0c;整只猫“虚胖”了一大圈不止&#xff0c;在阳光下可以看见非常多飘在空气中的浮毛。浮毛到处乱飞&#xff0c;沉积在黑色的衣服上&#xff0c;就形成白色的薄膜。自从养猫后&#xff0c;我再也没穿过深色的衣服。 现在每天都给它梳毛&am…

Linux文件的查找和打包以及压缩

文件的查找 文件查找的用处&#xff0c;在我们需要文件但却又不知道文件在哪里的时候 文件查找存在着三种类型的查找 1、which或whereis&#xff1a;查找命令的程序文件位置 2、locate&#xff1a;也是一种文件查找&#xff0c;但是基于数据库的查找 3、find&#xff1a;针…

Vue.js 学习总结(9)—— Vue 3 组件封装技巧

1、需求说明 需求背景&#xff1a;日常开发中&#xff0c;我们经常会使用一些UI组件库诸如and design vue、element plus等辅助开发&#xff0c;提升效率。有时我们需要进行个性化封装&#xff0c;以满足在项目中大量使用的需求。错误示范&#xff1a;基于a-modal封装一个自定…

【AIGC半月报】AIGC大模型启元:2024.10(下)

【AIGC半月报】AIGC大模型启元&#xff1a;2024.10&#xff08;下&#xff09; (1) Janus&#xff08;两面神&#xff09;&#xff08;DeepSeek 1.3B多模态大模型&#xff09;(2) Stable Diffusion 3.5&#xff08;StabilityAI文生图大模型&#xff09;(3) Mochi 1&#xff08;…

Python文件操作(读取、写入、修改和删除)

目录 一、文件的读取 二、文件的写入 三、文件的修改 四、文件的删除 Python是一种功能强大的编程语言&#xff0c;文件操作是编程中常见的需求。本文将详细介绍Python中的文件操作&#xff0c;包括文件的读取、写入、修改和删除&#xff0c;帮助读者掌握Python文件操作的基…

分布式系统之异步与消息队列(MQ)(原理+代码实战一文讲清!)

异步 什么是异步 异步编程是一种编程范式&#xff0c;它允许程序在等待操作完成&#xff08;如等待网络响应、文件读写等&#xff09;时继续执行其他任务。这种编程方式对于提高程序的性能和响应性至关重要&#xff0c;尤其是在处理耗时操作或在资源受限的环境中。下面我将更…

山东以“八策并举”确保人民满意学前教育“普惠落地”

10月19日-22日&#xff0c;2024年中国学前教育研究会学术年会在山东国际会展中心召开。年会围绕“优质普惠可持续——加强学前教育高质量发展的法治保障”主题&#xff0c;通过5场主旨报告、28个园所观摩、10个分论坛交流研讨&#xff0c;为2200余名嘉宾提供智慧盛宴。成为近年…

URP学习四

一.Bilt To RTHandle feature代码&#xff1a; 二.DistortTunnel 只有个飞机却有很多太空场景。因为设置了其他pass来渲染背景 队列添加3个Pass&#xff1a; 第一个Pass把颜色图进行输出 第二个Pass&#xff1a;创建了个纹理 加了个扰动&#xff0c;把纹理进行输出 第三个pas…