Author: cmathx
Original (Chinese): https://zhuanlan.zhihu.com/p/1477078851
A key problem in reproducing OpenAI o1 is how to automatically construct training data for the PRM (process reward model). This article analyzes how OmegaPRM works, at the level of its source code.
Paper
Improve Mathematical Reasoning in Language Models by Automated Process Supervision[1]
Method
Markov decision process
OmegaPRM
State: corresponds to a state in the Markov decision process; rollout: corresponds to an action in the Markov decision process.
• step 1: Initialize the root state. Each state holds n expanded rollouts: with q + pa as the prompt, draw n samples from the LLM, then estimate mc, the Monte Carlo probability of reaching the correct answer, via bootstrap sampling.
• step 2: Across all nodes, pick the best (state, rollout) pair with UCB1 (an explore-and-exploit rule) and add it to the PRM training set. Exploit term: alpha ** (1 - mc) * beta ** (len(r) / L), where mc is the Monte Carlo probability of a correct answer and len(r) is the length of the LLM generation. Explore term: c_puct * sqrt(N_sum) / (1 + s.v), where N_sum is the total visit count over all nodes, s.v is the visit count of the current node, and c_puct controls how strongly the MCTS tree is explored. (A numeric illustration of this score follows the select() code below.)
• step 3: Evaluate the selected (state, rollout): split the rollout in half, fold the left half into a new state, and compute the new state's mc. mc = 1 means the state always reaches the correct answer, so it is ignored; mc = 0 means the state can no longer produce the correct answer, so it is added as a leaf node; any other mc > 0 keeps the state as a node for further exploration.
• step 4: Repeat steps 2 and 3 until enough samples have been collected or no further exploration is possible.
• step 5: Add all leaf nodes to the PRM training set. (The whole loop is condensed into a code sketch right below.)
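Condensed into code, the loop looks roughly like this (a schematic sketch only; the names mirror the openr source walked through below):

# step 1: root state with sampled rollouts and a bootstrap mc estimate
root = State(q, '', a)
rollouts, corrs = getrollouts(root, n)
root.mc = cal_mc_bs(root)
states, leaf_states = [root], []
while True:
    # step 2: UCB1 selection; the chosen (s.pa, s.mc) pair is written out
    s, rollout, score = select(states)
    if s is None:
        break  # step 4: every (state, rollout) pair is exhausted
    s.add_visit()
    # step 3: binary-search the rollout for the first error
    new_sts, leaf_sts = error_locate(s, rollout)
    states.extend(new_sts)
    leaf_states.extend(leaf_sts)
# step 5: leaf states (mc == 0) also join the PRM training set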
PRM training results
Base model in the paper
With training data synthesized by OmegaPRM, accuracy on the MATH dataset improves from the base model's 51% to 69.4%.
Other PRM methods
OmegaPRM: the method associated with Gemini;
AlphaMath: the method associated with Qwen;
Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations[2]
AlphaMath Almost Zero: Process Supervision without Process[3]
Source code
https://github.com/openreasoner/openr[4]
Source code walkthrough
Data structures
class State:
    def __init__(self, q, pa, a):
        self.q = q      # question
        self.pa = pa    # partial answer so far (the prompt for the current step)
        self.a = a      # ground-truth answer
        self.mc = None  # Monte Carlo estimate of reaching the correct answer from this node
        self.v = 0      # visit count
        self.rollouts = []             # expanded child rollouts
        self.rollout_was_visited = []  # per-rollout flag: has this rollout been selected yet?

    # helpers implied by the calls in the code below (add_rollout, add_visit)
    def add_rollout(self, r):
        self.rollouts.append(r)
        self.rollout_was_visited.append(False)

    def add_visit(self):
        self.v += 1
Main flow
# Load the JSON data
data = load_json_file(json_file_path)

# Process each problem and its final answer
for i, item in enumerate(data):
    problem = item.get('problem', 'No problem found')
    final_answer = item.get('final_answer', 'No answer found')
    # Print to console
    print(f'Problem {i + 1}: {problem}')
    print(f'Final Answer: {final_answer}')
    # Log each problem and answer
    logging.info(f'Processed Problem {i + 1}: {problem}')
    logging.info(f'Final Answer: {final_answer}')
    # Call getrollouts and handle the result
    states = []
    root = State(problem, '', final_answer)
    max_roll_num = 20
    rollouts, corrs = getrollouts(root, max_roll_num)
    mcst = cal_mc_bs(root)
    root.mc = mcst
    # record the root node
    states.append(root)
    # annotate only problems that are neither always solved nor never
    # solved; otherwise every step would get the same trivial label
    if sum(corrs) > 0 and sum(corrs) < max_roll_num:
        print('Process annotation ...\n')
        filename = str(i + 1) + '_states_list.json'
        # generate the PRM training data
        process_annotation(problem, final_answer, states, filename)
Monte Carlo sampling
# Sample node s n times: generate n rollouts with the LLM,
# and record whether each rollout contains the correct answer.
def getrollouts(s, n=5):
    corrs = []
    q = s.q
    pa = s.pa
    for i in range(n):
        re = complete_answer(q, pa)
        s.add_rollout(re)
        # check the answer
        a = s.a
        if check_answer(a, re):
            corrs.append(1)
        else:
            corrs.append(0)
    return s.rollouts, corrs
# Estimate the probability that this state leads to the correct answer,
# by repeatedly scoring random subsets of its rollouts.
def cal_mc_bs(s, bs=5):
    n = len(s.rollouts)
    # note: the subset size is drawn once and reused for all bs rounds, and
    # random.sample draws without replacement, so this is subsampling rather
    # than a classic bootstrap
    subn = max(1, random.randint(n // 2, n))
    mc = 0
    for i in range(bs):
        corr = 0
        sub = random.sample(s.rollouts, subn)
        for r in sub:
            if check_answer(s.a, r):
                corr += 1
        mc += corr * 1.0 / len(sub)
    return mc / bs
# Given a problem, use problem + partial_answer as the prompt and have the
# LLM complete the answer
complete_answer(problem, partial_answer, checkpoint)

# Check whether the LLM response contains the ground-truth answer
check_answer(groundtruth_answer, response)
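Neither helper is shown in the excerpt. Below is a minimal sketch of what they might look like, assuming a Hugging Face transformers text-generation pipeline and a naive containment check; the model name, decoding parameters, and matching rule are all assumptions, not the openr implementation:

from transformers import pipeline

# hypothetical model choice; openr's actual loading code differs
generator = pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')

def complete_answer(problem, partial_answer, checkpoint=None):
    # prompt = question plus the partial answer accumulated so far
    prompt = problem + '\n' + partial_answer
    out = generator(prompt, max_new_tokens=512, do_sample=True, temperature=0.7)
    # the pipeline echoes the prompt, so strip it off
    return out[0]['generated_text'][len(prompt):]

def check_answer(groundtruth_answer, response):
    # naive containment check; a real implementation should normalize the
    # answer (strip LaTeX, whitespace, etc.) before comparing
    return str(groundtruth_answer).strip() in response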
Constructing PRM training data automatically with MCTS
# Generate PRM training data with the MCTS-style loop
def process_annotation(q, a, states, filename='states_list.json'):
    print('++++++')
    it = 0
    leaf_states = []
    while True:
        s, rollout, maxqu = select(states)
        if s is not None and s.pa != '':
            new_data = {
                'q': q,          # ensure q is serializable
                'states': s.pa,  # ensure states is serializable
                'mcs': s.mc,     # ensure mcs is serializable
            }
            # append the new record to the output file
            append_to_json_file(filename, new_data)
        it += 1
        if it > 100:
            break
        # all state-rollout pairs were exhausted
        if s is None:
            break
        print()
        print('[sel]')
        print(s)
        print(' roll=', rollout, ' || qu=', maxqu)
        s.add_visit()
        div_roll_sts, leaf_sts = error_locate(s, rollout)
        if len(div_roll_sts) == 0:
            continue
        states.extend([st for st in div_roll_sts if st is not None and st.pa != ''])
        leaf_states.extend(leaf_sts)

    # add leaf states to the training data
    for s in leaf_states:
        new_data = {
            'q': q,
            'states': s.pa,
            'mcs': s.mc,
        }
        append_to_json_file(filename, new_data)
    print('++++++')
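append_to_json_file is another repo helper not shown here. A plausible minimal version, assuming the output file holds a single JSON list of records (an assumption about the on-disk format):

import json
import os

def append_to_json_file(filename, new_data):
    # read-modify-write the whole list; fine for the small files produced here
    records = []
    if os.path.exists(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            records = json.load(f)
    records.append(new_data)
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=2)

Each record pairs a question with a solution prefix (s.pa) and its Monte Carlo correctness estimate (s.mc), which is exactly the (prefix, value) supervision a PRM is trained on.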
Selecting the best node with UCB1 and adding it to the training set
# Select the current best (state, rollout) pair
# exploitation: prefer states with larger mc and shorter LLM generations
# exploration: prefer rarely-visited states, scaled by the size of the tree
def select(states):
    best_st = None
    best_roll_idx = -1
    best_qu = -1
    for s in states:
        # mcs = cal_mc(s) if s.mc is None else s.mc
        mcs = cal_mc_bs(s) if s.mc is None else s.mc
        # skip states that are always right or always wrong
        if mcs == 0 or mcs == 1.0:
            continue
        for i, r in enumerate(s.rollouts):
            if s.rollout_was_visited[i]:
                continue
            q = Q(r, mcs)
            u = U(s, states)
            qu = q + u
            if qu > best_qu:
                best_st = s
                best_roll_idx = i
                best_qu = qu
    if best_roll_idx != -1:
        best_st.rollout_was_visited[best_roll_idx] = True
        return best_st, best_st.rollouts[best_roll_idx], best_qu
    # nothing left to select: signal exhaustion to the caller
    return None, None, -1
# exploitation: favor (state, rollout) pairs already known to look good;
# Q = alpha ** (1 - mc) * beta ** (len(r) / L)
# 1. larger mc (more likely to reach the correct answer) increases Q;
# 2. shorter rollouts (shorter generations that still reach the correct
#    answer) also increase Q.
def Q(r, mc, alpha=0.5, beta=0.9, L=500):
    part1 = alpha ** (1 - mc)
    part2 = beta ** (len(r) / L)
    Q_value = part1 * part2
    return Q_value
# exploration: encourage trying under-explored options, UCB1-style
# (Upper Confidence Bound 1);
# U = c_puct * sqrt(N_sum) / (1 + s.v)
# 1. s.v: visit count of the current state; fewer visits -> larger U;
# 2. N_sum: total visit count over all states, a proxy for how far the
#    search has progressed, so exploration pressure grows with the tree;
# 3. c_puct: constant controlling the amount of exploration.
def U(s, states, c_puct=0.125):
    N_sum = 0
    for item in states:
        N_sum += item.v
    numerator = math.sqrt(N_sum)
    denominator = 1 + s.v
    U_value = c_puct * (numerator / denominator)
    return U_value

# combined score; here i is a state and ncs is the list of all states
def qu(i, r, mc, ncs):
    q = Q(r, mc)
    u = U(i, ncs)
    return q + u
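A quick numeric illustration with hypothetical values: take a 200-character rollout from a state with mc = 0.6, in a tree with N_sum = 10 total visits of which this state accounts for s.v = 2:

import math

Q_value = 0.5 ** (1 - 0.6) * 0.9 ** (200 / 500)  # exploit term, ≈ 0.727
U_value = 0.125 * math.sqrt(10) / (1 + 2)        # explore term, ≈ 0.132
score = Q_value + U_value                        # ≈ 0.858

As a state accumulates visits, U decays toward zero and the exploit term dominates, so less-visited candidates start winning the argmax in select().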
Evaluating the best node: continue exploring? Nodes that cannot be explored further (always wrong) become leaf nodes and enter the training set
# Evaluate the selected (state, rollout): binary-search the rollout,
# folding the left half into a new state and computing its mc.
def error_locate(s, rollout):
    current_span = rollout
    prev = ''
    divide_roll_pos_st = []
    leaf_st = []
    while True:
        word_count = len(current_span.split())
        if word_count < 2:
            break
        np1, np2 = split_sentence_middle(current_span)
        print('----')
        print(' BS[l]=', np1)
        print(' BS[r]=', np2)
        # binary-search the rollout: the new prompt is the text accepted
        # so far plus the left half of the current span
        st = State(s.q, prev + np1, s.a)
        rollouts, corrs = getrollouts(st)
        # mcst = cal_mc(st)
        mcst = cal_mc_bs(st)
        st.mc = mcst
        # case 1: always correct (not interesting for PRM training)
        if mcst == 1:
            # leaf_st.append(st)
            break
        # case 2: sometimes correct -> the first error lies in the right
        # half, so accept the left half and keep expanding into the right span
        elif mcst > 0:
            current_span = np2
            prev = prev + np1
            divide_roll_pos_st.append(st)
        # case 3: never correct -> from here the LLM has no chance of
        # reaching the correct answer, so expansion terminates (leaf node)
        # and the search narrows into the left span
        elif mcst == 0:
            current_span = np1
            leaf_st.append(st)

    print('----')
    return divide_roll_pos_st, leaf_st
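split_sentence_middle is a repo helper not shown above; a minimal word-level sketch of what it presumably does (the exact tokenization and whitespace handling are assumptions):

def split_sentence_middle(text):
    # split the span into two halves at the middle word boundary;
    # the trailing space keeps prev + np1 concatenations readable
    words = text.split()
    mid = len(words) // 2
    left = ' '.join(words[:mid]) + ' '
    right = ' '.join(words[mid:])
    return left, right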
References
[1] Improve Mathematical Reasoning in Language Models by Automated Process Supervision: https://arxiv.org/abs/2406.06592
[2] Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations: https://arxiv.org/abs/2312.08935
[3] AlphaMath Almost Zero: Process Supervision without Process: https://arxiv.org/abs/2405.03553
[4] https://github.com/openreasoner/openr