欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145056912
Llama 3 是 Meta 公司发布的开源大型语言模型,包括具有 80 亿和 700 亿参数的预训练和指令微调的语言模型,支持广泛的应用场景。在多个行业标准基准测试中展示了最先进的性能,特别是在推理、代码生成和指令遵循方面表现出色,超过了同等规模的商业模型。
Llama 3 Paper: The Llama 3 Herd of Models,Llama3 模型群
参考:大模型训练 RLHF 阶段的 PPO/DPO 策略公式与源码
1. Llama 3 Loss
Llama 3 主要包括 2个阶段,即 预训练阶段(Pre-Training) 和 后训练阶段(Post-Training)。
Llama 3 的网络架构,如下:
1.1 预训练阶段(Pre-Training)
预训练阶段(Pre-Training),包括:
- 初始预训练(Initial Pre-Training):使用 AdamW 优化器对 Llama 3 405B 进行预训练,峰值学习率为
8
×
1
0
−
5
8 × 10^{−5}
8×10−5;线性预热步数为 8,000 步,采用余弦学习率调度,在 1,200,000 步内衰减至
8
×
1
0
−
7
8 × 10^{−7}
8×10−7 ;为了提高训练稳定性,在训练初期使用较小的 BatchSize,在后续逐步增加以提高效率。具体:
- 初始使用 Batch Size 是 4M Tokens 和 序列(Sequences)长度是 4,096 (4K) Tokens,预训练 252M 个 Token。
- 将 Batch Size 和序列长度,增加至 8M Tokens (batch size) 和 8,192 (8K) Tokens (sequences),预训练 2.87T 个 Token。
- 再次将批量大小增加到 16M。
- 降低到损失值的突刺(Spikes),不需要进行干预以纠正模型训练的发散(Divergence)。
- 长上下文预训练(Long Context Pre-Training):使用 800B 个训练 Token,上下文长度增加到 6 个阶段,从 8K 的上下文窗口开始,达到 128K。在长上下文预训练中,自注意力层的计算量,随着序列长度的平方增长。评估模型适应长上下文的标准:
- 模型在 短上下文 评估中的性能是否完全恢复。
- 在特定长度中,模型是否能够完美解决 大海捞针(needle in a haystack) 任务。
- 退火阶段(Annealing):最后 400M 个 token 预训练,线性的(Linearly) 将学习率退火到 0,同时,保持上下文长度 128K Token。使用少量高质量的代码和数学数据,进行退火,提高预训练模型在关键基准测试上的性能。退火阶段,在 8B 模型中效果明显,但是,在 405B 模型中改进较小。
损失函数都是 交叉熵损失,其中 w i w_{i} wi 是第 i i i 个词, N N N 是序列长度, p p p 是概率(softmax), C C C 是 类别数(也就是输出维度),即:
L o s s = − 1 N ∑ i = 1 N l o g p ( w i ∣ w 1 , w 2 , . . . , w i − 1 ) L o s s = − 1 N ∑ i = 1 N ∑ j = 1 C y i j l o g ( p i j ) \begin{align} Loss &= -\frac{1}{N}\sum_{i=1}^{N}log \ p(w_{i}|w_{1},w_{2},...,w_{i-1}) \\ Loss &= -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}y_{ij}log(p_{ij}) \end{align} LossLoss=−N1i=1∑Nlog p(wi∣w1,w2,...,wi−1)=−N1i=1∑Nj=1∑Cyijlog(pij)
这 2 个公式的含义是一样的,第 1 个是选择 w i w_{i} wi ,第 2 个是通过 y i j = 1 y_{ij}=1 yij=1 确定 w i w_{i} wi 。
1.2 后训练阶段(Post-Training)
后训练阶段(Post-Training),包括:
- 奖励模型(Reward Modeling, RM):偏好数据(Preference data) 包含 3 个经过排序的响应,即 编辑后(edited) > 所选(chosen) > 被拒(rejected)。
Reward Model 训练 Loss 函数:
L
o
s
s
=
−
l
o
g
σ
(
r
ϕ
(
x
,
y
w
i
n
)
−
r
ϕ
(
x
,
y
l
o
s
s
)
)
Loss = -log \ \sigma(r_{\phi}(x,y_{win}) - r_{\phi}(x,y_{loss}))
Loss=−log σ(rϕ(x,ywin)−rϕ(x,yloss))
源码:
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-
监督微调 (Supervised Finetuning, SFT):使用奖励模型对于人工标注提示进行拒绝采样(Rejection Sampling)。 SFT 的 Loss 函数与 PreTraining 阶段一致,数据略有不同。
-
直接偏好优化 (Direct Preference Optimization, DPO),其中,在 Llama3 中,学习率 L R = 1 0 − 5 LR=10^{-5} LR=10−5,超参数 β = 0.1 \beta=0.1 β=0.1, α = 0.2 \alpha=0.2 α=0.2,即:
L o s s D P O = L o s s D P O + α L o s s N L L = − l o g σ ( β l o g π θ ( y w i n ∣ x ) π r e f ( y w i n ∣ x ) − β l o g π θ ( y l o s e ∣ x ) π r e f ( y l o s e ∣ x ) ) − α l o g π θ ( y w i n ∣ x ) \begin{align} Loss_{DPO} &= Loss_{DPO} + \alpha Loss_{NLL} \\ &= - log \ \sigma(\beta log\frac{\pi_{\theta}(y_{win}|x)}{\pi_{ref}(y_{win}|x)} - \beta log\frac{\pi_{\theta}(y_{lose}|x)}{\pi_{ref}(y_{lose}|x)}) - \alpha log \ \pi_{\theta}(y_{win}|x) \end{align} LossDPO=LossDPO+αLossNLL=−log σ(βlogπref(ywin∣x)πθ(ywin∣x)−βlogπref(ylose∣x)πθ(ylose∣x))−αlog πθ(ywin∣x)
DPO 训练的改进如下:
- 在 DPO 损失中,在 所选(chosen) 和 被拒(rejected) 的响应中,屏蔽 特殊格式令牌(Special Formatting Tokens)。学习这些标记,导致模型出现不期望的行为,例如尾部重复或突然生成终止标记。可能是由于 DPO 损失的对比性质,在选定和拒绝的响应中都存在常见标记,这导致了冲突的学习目标,因为模型需要同时增加和减少这些标记的出现概率。
- DPO 损失 增加 NLL (Negative Log-Likelihood, 负对数似然) 损失,避免 所选(chosen) 响应的对数概率下降。参考论文 Iterative Reasoning Preference Optimization 。
- 模型平均(Model Averaging):训练多个模型,在不同的数据集、不同的初始化、不同的学习率或其他超参数设置下进行训练,对于模型权重进行平均,可以使用简单的算术平均,也可以使用加权平均。
- 迭代轮次(Iterative Rounds):划分 6 个轮次应用上述方法(RM、SFT、DPO、MA)。在每个轮次中,收集新的偏好标注和 SFT数据,以及从最新模型中采样合成数据。
Llama 3 的网络参数:
2. 补充内容
前置概念:
- PreTraining 阶段 与 SFT 阶段 的差异
- logits 含义
- logps 含义
1.1 PreTraining 阶段 与 SFT 阶段 的差异
PreTraining 与 SFT 在训练过程中,没有任何区别,主要区别在于数据的组成形式上,包括 6 点,即:
- PreTraining 样本数据都是满编 4K / 8K;SFT 样本数据保持不变,原始多长就是多长。
- SFT 引入 PreTraining 阶段中未见过的
special_token
,让模型学习全新的语义。 - SFT 让模型 重新学习 最重要的
eos_token
,停止生成;PreTraining 阶段eos_token
只是作为样本的一部分,无法停止生成。 - SFT 借助
special_token
,把语料切分成不同的角色,标配包括system
、user
、assistant
,根据业务需求也可以自定义。 - SFT 的
prompt
部分不做loss
反传,原因是 prompt 的同质化严重,如果不做loss_mask
,同样的一句话会被翻来覆去的学。如果保证每条prompt
都是独一无二的,也可以省略prompt
的loss_mask
部分。 - SFT 的 session 数据(多轮对话),可以每一个
answer
都计算loss
,也可以只对最后一轮的answer
计算loss
。
两者的训练目的也不一样,PreTraining 是在背书,纯粹的学习知识;SFT 则是在做题,学习的是指令 follow 能力。
整体的数据源码,参考 Llama-Factory 的 src/llamafactory/data/preprocess.py
文件,包括不同阶段的数据处理,即 pt
、sft
、rm
、kto
等。
其中,PreTraining 数据源码(src/llamafactory/data/processors/pretrain.py
),
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False) # add_special_tokens=False
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
其中,SFT 的 prompt mask 源码(src/llamafactory/data/processors/supervised.py
),参考:
IGNORE_INDEX = -100
if train_on_prompt:
source_label = source_ids
elif template.efficient_eos:
# mask 掉 prompt(source) 部分,保留 answer(target) 部分
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if mask_history and turn_idx != 0: # train on the last turn only
# 只训练最后 1 轮
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
PreTraining 阶段的数据:
- 数据组成形式:
- 输入 input:
<bos> X1 X2 X3
- 标签 labels:
X1 X2 X3 </s>
- 输入 input:
- 典型的 Decoder 架构的数据训练方式
SFT 阶段的数据:
- 数据组成形式:
- 输入 input:
<bos> prompt response
- 标签 labels:
-100 ... -100 response </s>
- 输入 input:
- labels 的重点在于prompt部分的被 -100 所填充
训练过程源码,参考 transformers/src/transformers/models/llama/modeling_llama.py
,即:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
其中,shift_logits
和 shift_labels
数据:
输入文本 | <bos> | 我 | 是 | 一 | 个 | 中 | 国 | 人 | <eos> |
---|---|---|---|---|---|---|---|---|---|
shift_logits | <bos> | 我 | 是 | 一 | 个 | 中 | 国 | 人 | |
shift_labels | 我 | 是 | 一 | 个 | 中 | 国 | 人 | <eos> |
1.2 logits 含义
模型输出的数值部分,被称为 logits,其英文含义是 log-odds
,即对数几率,即:logits = self.lm_head(hidden_states)
。
logits
是模型在 应用激活函数(如 softmax)之前 的原始输出。假设有一个分类问题,模型的输出层有多个神经元,每个神经元对应一个类别。这些神经元的输出值就是 logits。例如,对于一个三分类问题,模型输出层的三个神经元输出值可能是[2.1, -1.5, 0.7],这就是 logits。logits
可以看作是每个类别的 对数几率,表示某个事件发生的几率的对数。在分类问题中,假设某个类别发生的概率是p
,那么该类别的几率是p/(1 - p)
。对数几率就是log(p/(1 - p))
。模型输出的logits
通过一些变换(如softmax
函数)来近似地表示这种对数几率。在 softmax 函数中,使用 e 为底数,将 对数几率(logits) 转换为概率分布,使得输出值在 0 到 1 之间,并且所有类别的概率之和为 1。
Softmax 公式,其中 x
就是 logits
,即:
s
o
f
t
m
a
x
(
X
)
=
e
x
∑
i
=
1
n
e
x
i
softmax(X) = \frac{e^x}{\sum_{i=1}^{n}e^{x_{i}}}
softmax(X)=∑i=1nexiex
Sigmoid 公式,其中 x
就是 logits
,即:
σ
(
x
)
=
1
1
+
e
−
x
\sigma(x)=\frac{1}{1+e^{-x}}
σ(x)=1+e−x1
例如源码:
logits = self.lm_head(hidden_states)
logits = logits.float()
1.3 logps 含义
logps
是 log probabilities
的缩写,即对数概率。具体来说,logps
是模型输出的对数概率分布,通过对 logits
应用 softmax
函数,并且取 对数(log)
得到的。
源码测试:
- 初始化数据
logits
和labels
import torch
from torch import nn
import torch.nn.functional as F
torch.manual_seed(42)
logits = torch.randn(3, 4) # 模拟 logits,3个样本,维度是4
# 真实标签,即4个维度中,正确的是0位置、1位置、1位置,也就是说让 logits 的这些位置的对数几率最大。
labels = torch.tensor([0, 1, 1])
print(f"[Info] logits: {logits}")
print(f"[Info] labels: {labels}")
[Info] logits: tensor([[ 0.3367, 0.1288, 0.2345, 0.2303],
[-1.1229, -0.1863, 2.2082, -0.6380],
[ 0.4617, 0.2674, 0.5349, 0.8094]])
[Info] labels: tensor([0, 1, 1])
- 计算
logps
使用log_softmax
,即 softmax -> log,先概率化,再转换成对数(负值)
logps = torch.log(nn.Softmax(dim=1)(logits))
print(f"[Info] logps1: {logps}")
logps = F.log_softmax(logits, dim=1)
print(f"[Info] logps2: {logps}")
[Info] logps1: tensor([[-1.2849, -1.4928, -1.3871, -1.3912],
[-3.5008, -2.5643, -0.1698, -3.0160],
[-1.4621, -1.6565, -1.3889, -1.1144]])
[Info] logps2: tensor([[-1.2849, -1.4928, -1.3871, -1.3912],
[-3.5008, -2.5643, -0.1698, -3.0160],
[-1.4621, -1.6565, -1.3889, -1.1144]])
- 再调用 负对数似然(Negative Log-Likelihood,NLL) Loss,把多个样本的 logps 均值,再计算负值,因为负负得正,即 loss 值(正值),优化越来越小
v_nll = - (-1.2849 + -2.5643 + -1.6565) / 3 # 0,1,1 位置
print(f"[Info] NLL1: {v_nll}")
v_nll = nn.NLLLoss()(torch.log(nn.Softmax(dim=1)(logits)), labels)
print(f"[Info] NLL2: {v_nll}")
[Info] NLL1: 1.8352333333333333
[Info] NLL2: 1.8352112770080566
- 直接使用
CrossEntropyLoss
函数,代替这些操作,即CrossEntropyLoss = Softmax+log+NLLLoss
v_nll = nn.CrossEntropyLoss()(logits, labels)
print(f"[Info] NLL3: {v_ce}")
[Info] NLL3: 1.835211157798767
参考 PPL (Perplexities,困惑度) 源码 (scripts/stat_utils/cal_ppl.py
):
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()
perplexities.extend(sentence_logps.exp().tolist())
参考:
- 知乎 - 详解 PyTorch 的损失函数:NLLLoss()和CrossEntropyLoss()
- 知乎 - 大模型Pretrain和SFT阶段的Loss分析
- 大模型训练(SFT)实践总结