Table of Contents
- 4.6 GRPO Training Process
- 4.6.1 How GRPO Works
- 4.6.2 Setting Up the Reference Model
- 4.6.3 Sampling Questions from the Training Set
- 4.6.4 Generating G Outputs with the Old Policy Model
- 4.6.5 Scoring Each Output with the Reward Model (RM)
- 4.6.6 Gradient Updates from the Objective Function
Posts in the [Reproducing DeepSeek-R1 with Open R1] series:
[Reproducing DeepSeek-R1 with Open R1] Part 1: Running SFT end to end (step-by-step walkthrough)
[Reproducing DeepSeek-R1 with Open R1] Part 2: No GPU? Train a model anyway! Running Open R1 on Colab (with source code)
[Reproducing DeepSeek-R1 with Open R1] Part 3: Background knowledge
[Reproducing DeepSeek-R1 with Open R1] Part 4: Running GRPO end to end!
[Reproducing DeepSeek-R1 with Open R1] Part 5: Line-by-line deep dive into the SFT source code
[Reproducing DeepSeek-R1 with Open R1] Part 6: GRPO source code structure
[Reproducing DeepSeek-R1 with Open R1] Part 7: How GRPO works, the training flow, and a source code deep dive
4.6 GRPO Training Process
We will pick out some of the key parts of the code and analyze them.
4.6.1 How GRPO Works
Before diving into the source code, let us briefly review how GRPO works.
For an introduction to reinforcement learning, see the post [The Technology Behind DeepSeek-R1] Part 3: Reinforcement Learning (RL).
The core idea is illustrated in the figure below:
Core motivation: in many real applications the reward is only given as a single score at the end of the sequence (outcome supervision), or as a few local scores along the way (process supervision). Either way, the reward signal is discrete and rather sparse, so making a value network learn a per-token value may not be worth the cost. If instead we sample several outputs $o_1, o_2, \dots, o_G$ for the same question $q$, score them with the reward model to obtain the corresponding rewards $r_1, r_2, \dots, r_G$, and compare these rewards with one another, we can infer much more easily which outputs are better. From this comparison we can give a relative score to all tokens of each output, with no need to learn an explicit value function.
This trick is especially useful for mathematical reasoning and problem solving, where we routinely generate multiple candidate answers to the same question $q$, some right and some wrong, or of varying quality. We compare their rewards within the group to obtain relative differences, and use the relative advantage as the signal for updating the policy.
Key point 1: group sampling and relative rewards
The "group" is central to GRPO: for one question $q$ we sample $G$ outputs $o_1, o_2, \dots, o_G$, send the whole group to the reward model (or a rule-based scorer) to get the reward scores $r = \{r_1, r_2, \dots, r_G\}$, and first normalize $r$ (subtract the mean, divide by the standard deviation) to obtain each output's standing within the group. This gives the relative reward $r'_i$, which is then assigned as the advantage of every token of that output. In short: generate several answers, compare them jointly, and update according to the ranking or score gaps. This captures the relative quality of answers to the same question more directly and simply, without an explicit value network that has to estimate values at every intermediate step.
Key point 2: efficient policy optimization without a value network
Because we no longer fit a value function at every token, we save a large amount of memory: there is no need to maintain a Critic model as large as the Actor. This is not just a saving in storage but also a noticeable speed-up in training. Of course, GRPO introduces a new cost: for every question we must sample a whole group of outputs rather than a single one, which means extra compute at generation time. The idea is also somewhat similar to self-consistency sampling.
The concrete pipeline is as follows:
How the group-relative advantage $A'_{i,t}$ is computed:
We first normalize each output $o_i$'s reward, $r'_i = \dfrac{r_i - \mathrm{mean}(\mathbf{r})}{\mathrm{std}(\mathbf{r})}$, and then set $A'_{i,t} = r'_i$. That is, all tokens of output $o_i$ share the same score $r'_i$: the output's quality is measured relative to the average of its group, without an external value network to split or interpolate values over time steps. This gives us a value-network-free advantage function whose core idea is comparison and ranking within the group.
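To make this concrete, here is a minimal sketch (plain PyTorch, not the Open R1 source) of turning one group's rewards into per-token advantages under outcome supervision; the reward values and completion lengths are made up for illustration:

```python
import torch

# Rewards from the reward model for the G outputs sampled for the same question q
rewards = torch.tensor([1.0, 0.0, 1.0, 0.5])      # r_1 ... r_G, here G = 4

# Group-relative reward: normalize within the group (subtract mean, divide by std)
rel_rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(rel_rewards)                                 # ≈ tensor([ 0.783, -1.305,  0.783, -0.261])

# Outcome supervision: every token of output o_i shares the same advantage r'_i
completion_lengths = [12, 7, 9, 11]                # lengths of the G completions (illustrative)
advantages = [torch.full((n,), r.item()) for r, n in zip(rel_rewards, completion_lengths)]
```

The small `1e-4` added to the standard deviation matches the epsilon used later in the trainer code, and guards against division by zero when all rewards in a group happen to be identical.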
If process supervision is used instead, i.e. every key step of the reasoning is scored, things differ slightly: each step then carries a local reward, which can be accumulated over the time sequence (or otherwise folded) into per-token advantages.
Process supervision vs. outcome supervision: step-wise rewards vs. final rewards
- Outcome supervision: a single reward is given only when the output sequence ends (e.g. the answer is right or wrong, or a score). GRPO assigns this same $r$ to every token in the sequence.
- Process supervision: intermediate reasoning steps are also scored (e.g. +1 for a correct step, -1 for a wrong one). The rewards at the different time points are collected, accumulated onto each token or step, and then normalized within the group, as sketched below.
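For the process-supervision case, here is a minimal sketch of one common choice (the GRPO paper sums, for each token, the normalized rewards of all steps that end at or after that token); the step boundaries and scores below are made up, and the group-level normalization of step rewards is omitted to keep the sketch short:

```python
import torch

# One output's reasoning steps, each with a local reward and the token index where it ends
step_rewards = torch.tensor([1.0, -1.0, 1.0])      # e.g. +1 for a correct step, -1 for a wrong one
step_end_tokens = [5, 11, 18]                      # illustrative step boundaries
num_tokens = 19

# Token t gets the sum of the rewards of every step that ends at or after t
advantages = torch.zeros(num_tokens)
for reward, end in zip(step_rewards, step_end_tokens):
    advantages[: end + 1] += reward

print(advantages)   # tokens 0-5 -> 1.0, tokens 6-11 -> 0.0, tokens 12-18 -> 1.0
```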
A natural question is how groups are formed inside a batch. In practice a batch contains several questions $q$, and for each question we generate $G$ answers: with batch size $B$ and $G$ candidates per question, one generation pass has to produce $B \times G$ candidates, each of which is sent to the reward model to get its score $r_i$. The inference cost is non-trivial, and a large $G$ noticeably increases the number of generations, but in exchange we no longer need a value network at all.
Extension: iterative RL, updating the reward model with a replay mechanism
When GRPO is used in practice with a learned reward model RM, the samples the RM has to judge become harder and harder as the policy grows stronger, so the RM itself needs to be updated too. This yields an iterative RL loop: use the current RM to guide one round of policy updates, then use data generated by the new policy to update the RM. To avoid catastrophic forgetting, part of the old data can be kept (a replay buffer) so that the RM is always trained on a mix of old and new data and does not completely forget the characteristics of earlier problems, as sketched below.
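This iterative loop is not implemented in the Open R1 GRPO trainer itself; the sketch below only illustrates the replay idea for retraining a learned RM, and all names in it are made up:

```python
import random

class RMReplayBuffer:
    """Keeps a bounded pool of old (prompt, completion, label) samples for reward-model retraining."""

    def __init__(self, capacity: int = 10_000):
        self.capacity = capacity
        self.samples = []

    def add(self, new_samples):
        self.samples.extend(new_samples)
        self.samples = self.samples[-self.capacity:]   # drop the oldest once full

    def training_batch(self, fresh_samples, old_ratio: float = 0.5):
        # Mix fresh samples (generated by the current policy) with replayed old ones,
        # so the RM keeps seeing the distribution it was originally trained on.
        n_old = min(int(len(fresh_samples) * old_ratio), len(self.samples))
        return fresh_samples + random.sample(self.samples, k=n_old)
```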
Next, we walk through the source code following the pipeline above.
4.6.2 Setting Up the Reference Model
# Reference model
if is_deepspeed_zero3_enabled():
    self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif not is_peft_model(model):
    # If PEFT configuration is not provided, create a reference model based on the initial model.
    self.ref_model = create_reference_model(model)
else:
    # If PEFT is used, the reference model is not needed since the adapter can be disabled
    # to revert to the initial model.
    self.ref_model = None
This snippet shows how the reference model (ref_model), used later for the KL regularization term, is created or configured under three different conditions:

- DeepSpeed ZeRO-3 enabled: is_deepspeed_zero3_enabled() checks whether DeepSpeed's ZeRO-3 zero-redundancy optimizer is active. If so, a causal language model is loaded from the pretrained checkpoint as the reference model, initialized with the parameters in model_init_kwargs (under ZeRO-3 the weights are sharded across processes, so the reference model cannot simply be copied from the live model).
- Not a PEFT model: is_peft_model(model) checks whether the current model is a PEFT (Parameter-Efficient Fine-Tuning) model. If it is not, create_reference_model(model) builds a frozen copy of the initial model to serve as the reference.
- PEFT model: if PEFT is used, no separate reference model is needed, because disabling the adapters reverts the model to its initial state; self.ref_model is therefore set to None.
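For the PEFT case, the reference log-probabilities are later obtained from the same model with its adapters temporarily disabled. A hedged sketch of that pattern (the helper `get_per_token_logps` stands in for the trainer's `_get_per_token_logps`, and the real code also unwraps the accelerator-wrapped model):

```python
import torch

def reference_logps(model, ref_model, input_ids, attention_mask, logits_to_keep, get_per_token_logps):
    """Return per-token log-probs of the reference policy."""
    with torch.inference_mode():
        if ref_model is not None:
            # Non-PEFT (or ZeRO-3) case: a separate frozen reference model
            return get_per_token_logps(ref_model, input_ids, attention_mask, logits_to_keep)
        # PEFT case: with the adapters disabled, `model` behaves like the original base model
        with model.disable_adapter():
            return get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
```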
4.6.3 Sampling Questions from the Training Set
The sampler is the RepeatRandomSampler class, and the input data is prepared mainly by the _prepare_inputs function.
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
    ...  # body elided here; its generation and reward-scoring parts are shown in 4.6.4 and 4.6.5

class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset N times.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        repeat_count (`int`):
            Number of times to repeat each index.
        seed (`Optional[int]`):
            Random seed for reproducibility (only affects this sampler).

    Example:
    ```python
    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
    >>> list(sampler)
    [2, 2, 0, 0, 3, 3, 1, 1]
    ```
    """

    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()  # Create a local random generator
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indexes = [
            idx
            for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
            for _ in range(self.repeat_count)
        ]
        return iter(indexes)

    def __len__(self):
        return self.num_samples * self.repeat_count


def _get_train_sampler(self) -> Sampler:
    # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
    # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
    # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
    # preventing discrepancies in group formation.
    return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)

def _get_eval_sampler(self, eval_dataset) -> Sampler:
    # Same logic as above, applied to the evaluation dataset.
    return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
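As a quick check of the grouping behaviour: the repeat count passed in is self.num_generations (G), so every G consecutive batch elements share the same prompt. A small illustrative run (values made up):

```python
sampler = RepeatRandomSampler(["q0", "q1", "q2", "q3"], repeat_count=3)
print(list(sampler))
# One possible output: [1, 1, 1, 3, 3, 3, 0, 0, 0, 2, 2, 2]
# Each index appears repeat_count times in a row, so a batch of B*G samples contains
# B groups of G completions that can be reward-normalized against each other.
```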
4.6.4 Generating G Outputs with the Old Policy Model
Still inside _prepare_inputs, the G outputs are generated via self.llm.generate(), followed by a series of post-processing steps:
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
    # First, have main process load weights if needed
    if self.state.global_step != self._last_loaded_step:
        self._move_model_to_vllm()
        self._last_loaded_step = self.state.global_step

    # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
    all_prompts_text = gather_object(prompts_text)
    if self.accelerator.is_main_process:
        outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
        completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
    else:
        completion_ids = [None] * len(all_prompts_text)
    # Broadcast the completions from the main process to all processes, ensuring each process receives its
    # corresponding slice.
    completion_ids = broadcast_object_list(completion_ids, from_process=0)
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    completion_ids = completion_ids[process_slice]

    # Pad the completions, and concatenate them with the prompts
    completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
    completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
    prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
    # Regular generation path
    with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
        prompt_completion_ids = unwrapped_model.generate(
            prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
        )

    # Compute prompt length and extract completion ids
    prompt_length = prompt_ids.size(1)
    prompt_ids = prompt_completion_ids[:, :prompt_length]
    completion_ids = prompt_completion_ids[:, prompt_length:]
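The process_slice bookkeeping at the end of the vLLM branch is easy to picture with toy numbers: every process contributed len(prompts) prompts to the gathered list, so after the broadcast each process keeps only the completions that belong to its own prompts. A toy illustration, assuming 2 processes with 4 local prompts each:

```python
# Toy illustration of the slicing logic (assumed: 2 processes, 4 prompts per process)
all_completion_ids = [f"completion_{i}" for i in range(2 * 4)]   # gathered + broadcast result

for process_index in range(2):
    local = all_completion_ids[process_index * 4 : (process_index + 1) * 4]
    print(process_index, local)
# 0 ['completion_0', 'completion_1', 'completion_2', 'completion_3']
# 1 ['completion_4', 'completion_5', 'completion_6', 'completion_7']
```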
4.6.5 Scoring Each Output with the Reward Model (RM)
- Reward model initialization
# Reward functions
if not isinstance(reward_funcs, list):
    reward_funcs = [reward_funcs]
for i, reward_func in enumerate(reward_funcs):
    if isinstance(reward_func, str):
        reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
            reward_func, num_labels=1, **model_init_kwargs
        )
self.reward_funcs = reward_funcs

# Reward weights
if args.reward_weights is not None:
    if len(args.reward_weights) != len(reward_funcs):
        raise ValueError(
            f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
            f"functions ({len(reward_funcs)})"
        )
    self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
else:
    self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
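Besides a reward-model checkpoint path (the isinstance(reward_func, str) branch above), reward_funcs may also contain plain Python callables; this is how Open R1 plugs in its rule-based rewards (accuracy, format, and so on). Below is a minimal illustrative function with the expected signature, assuming plain-text (non-conversational) completions; the actual Open R1 reward implementations are more involved:

```python
import re

def format_reward(prompts, completions, **kwargs):
    """Toy rule-based reward: 1.0 if the completion contains a <think>...</think> block, else 0.0.

    The trainer calls each reward function with `prompts`, `completions`, and the remaining
    dataset columns as keyword arguments, and expects one float per completion.
    """
    pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
    return [1.0 if pattern.search(completion) else 0.0 for completion in completions]

# Usage (illustrative): GRPOTrainer(..., reward_funcs=["path/to/reward-model", format_reward])
```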
- Computing the reward scores
The computation is implemented in _prepare_inputs. The main functional blocks are marked in the code comments, and the whole procedure maps one-to-one onto the principle described in the previous subsection:
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
    zip(self.reward_funcs, self.reward_processing_classes)
):
    if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
        if is_conversational(inputs[0]):
            messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
            texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
        else:
            texts = [p + c for p, c in zip(prompts, completions)]
        reward_inputs = reward_processing_class(
            texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
        )
        reward_inputs = super()._prepare_inputs(reward_inputs)
        with torch.inference_mode():
            rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
    else:
        # Repeat all input columns (but "prompt" and "completion") to match the number of generations
        keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
        reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
        output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
        rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)

# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# Slice to keep only the local part of the data
process_slice = slice(
    self.accelerator.process_index * len(prompts),
    (self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]
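To tie this back to the formula in 4.6.1, here is a tiny worked example of the normalization above with made-up rewards (B = 2 prompts, num_generations = 4):

```python
import torch

num_generations = 4                                           # G
rewards = torch.tensor([1.0, 0.0, 1.0, 0.5,                   # group for prompt 1
                        0.0, 0.0, 1.0, 0.0])                  # group for prompt 2 (B*G = 8 in total)

mean_g = rewards.view(-1, num_generations).mean(dim=1)        # tensor([0.6250, 0.2500])
std_g = rewards.view(-1, num_generations).std(dim=1)          # tensor([0.4787, 0.5000])

mean_g = mean_g.repeat_interleave(num_generations)            # expand back to length B*G
std_g = std_g.repeat_interleave(num_generations)

advantages = (rewards - mean_g) / (std_g + 1e-4)
print(advantages)
# ≈ tensor([ 0.783, -1.305,  0.783, -0.261, -0.500, -0.500,  1.500, -0.500])
```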
4.6.6 Gradient Updates from the Objective Function
In the compute_loss function, the loss is computed from each completion's advantage score (advantages), with a KL regularization term added.
The gradient update itself is carried out in the Trainer's main train() loop; see the earlier post in this series: [Reproducing DeepSeek-R1 with Open R1] Part 5: Line-by-line deep dive into the SFT source code.
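Written out, the per-token objective that compute_loss implements is the GRPO objective with the PPO-style clipping inactive: this trainer performs a single policy update per batch of generations, so $\pi_{\theta_{\mathrm{old}}} = \pi_\theta$ and the importance ratio is 1 in value while still carrying the gradient of $\log \pi_\theta$:

$$
\mathcal{L}_{i,t}
= -\left[
\frac{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(o_{i,t}\mid q, o_{i,<t})}\,A'_{i,t}
\;-\;\beta\,\mathbb{D}_{\mathrm{KL}}\big[\pi_\theta\,\big\|\,\pi_{\mathrm{ref}}\big]_{i,t}
\right],
\qquad
\mathbb{D}_{\mathrm{KL}}\big[\pi_\theta\,\big\|\,\pi_{\mathrm{ref}}\big]_{i,t}
= \frac{\pi_{\mathrm{ref}}(o_{i,t}\mid q, o_{i,<t})}{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}
- \log\frac{\pi_{\mathrm{ref}}(o_{i,t}\mid q, o_{i,<t})}{\pi_\theta(o_{i,t}\mid q, o_{i,<t})} - 1.
$$

The second term is the k3 estimator of the per-token KL divergence to the reference model; the per-token losses are masked with completion_mask, averaged over each completion's tokens, and then averaged over the batch.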
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")
    # Compute the per-token log probabilities for the model
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

    # Compute the KL divergence between the model and the reference model
    ref_per_token_logps = inputs["ref_per_token_logps"]
    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

    # x - x.detach() allows for preserving gradients from x
    advantages = inputs["advantages"]
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - self.beta * per_token_kl)
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

    # Log the metrics
    completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
    self._metrics["completion_length"].append(completion_length)

    mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
    self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

    return loss
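The factor exp(per_token_logps - per_token_logps.detach()) deserves a note: its value is always exactly 1, but its gradient with respect to the parameters equals the gradient of per_token_logps, which turns advantage * ratio into the standard policy-gradient term. A tiny standalone demonstration of the trick:

```python
import torch

logp = torch.tensor([-1.2, -0.7], requires_grad=True)   # stand-in for per-token log-probs
advantages = torch.tensor([0.8, -1.3])

ratio = torch.exp(logp - logp.detach())                  # value is exactly 1 everywhere
loss = -(ratio * advantages).sum()
loss.backward()

print(ratio)       # tensor([1., 1.], grad_fn=<ExpBackward0>)
print(logp.grad)   # tensor([-0.8000,  1.3000])  ->  d(loss)/d(logp) = -advantage
```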