Zephyr-7B论文解析及全量训练、Lora训练

文章目录

    • 一、Zephyr:Direct Distillation of LM Alignment
      • 1.1 开发经过
        • 1.1.1 Zephyr-7B-alpha
        • 1.1.2 Zephyr-7B-beta
      • 1.2 摘要
      • 1.3 相关工作
      • 1.4 算法
        • 1.4.1 蒸馏监督微调(dSFT)
        • 1.4.2 基于偏好的AI反馈 (AIF)
        • 1.4.3 直接蒸馏偏好优化(dDPO)
        • 1.4.4 训练细节
      • 1.5 实验
    • 二、alignment-handbook:低成本训练Zephyr
      • 2.1 项目简介
      • 2.2 全量训练
        • 2.2.1 环境配置
        • 2.2.2 SFT训练
        • 2.2.3 DPO训练
      • 2.3 Lora训练
      • 2.4 测试

一、Zephyr:Direct Distillation of LM Alignment

1.1 开发经过

  • 全文参参考: 《Thomas Wolf: Mistral + OpenBMB + HuggingFace 跨越三大洲的大模型开源合作故事》、《最好7B模型再易主!打败700亿LLaMA2,苹果电脑就能跑,还开源免费》
  • 资源:UltraFeedback HuggingFace、UltraFeedback Github 地址 、UltraChat HuggingFace、UltraChat Github 地址
  • 推荐:OpenBMB官网、OpenBMB GitHub、清华NLP GitHub、HuggingFaceH4
1.1.1 Zephyr-7B-alpha

  几个月前,巴黎的一个新团队发布了他们首个模型:Mistral 7B,这个模型体积小巧但性能强劲,在基准测试中的表现超过了所有同类模型,并且还是一个开源项目。

  Hugging Face H4 团队的两名成员在一次小聚中,讨论了用斯坦福大学新发表的 DPO 方法对 Mistral 7B 这个模型进行微调的可能性。随后,他们在 HF hub 上找到了一些公开的数据集,包括由面壁智能和清华大学 NLP 共同支持的 OpenBMB 开源的两个大规模、高质量的微调数据集:UltraFeedback 和 UltraChat。

  1. UltraFeedback:一个大规模、多样化、细粒度 的偏好数据集,构建过程如下:
    • 从UltraChat、ShareGPT、Evol-Instruct、TruthfulQA、FalseQA 和 FLAN等多个资源收集了约64k提示
    • 为了防止奖励模型过度拟合某些文本样式或捕获文本样式和奖励之间的虚假相关性,我们选择具有不同大小、架构和训练数据的各个级别的17个基础模型构建一个模型池,包括LLaMA、Falcon、StarChat、MPT、GPT 和 Bard等。
    • 定义了Helpfulness, Truthfulness, Honesty, Verbalized Calibration , Harmless5项原则,以从不同方面调整模型行为
    • 每条指令,随机采样 4 个模型来完成,并对每个完成的指示,我们随机采样一个原则并将其添加到系统提示中,以调整模型的行为。
    • 最终数据集包括 64k指令、256k 条对话数据以及相应的偏好标注数据、308k高质量反馈,在非社区标注的偏好数据集中,这一数据规模排在首位。对话数据中每条偏好标注均包含 instruction-following, truthfulness, honesty and helpfulness四个方面的细粒度得分与 GPT-4 的注释说明。整个数据集详细技术原理可参考其论文。
      在这里插入图片描述

基于 UltraFeedback,面壁团队还训练了奖励模型UltraRM和批评模型UltraCM来进一步辅助模型评测和模型反馈学习,更多介绍,详见《面壁智能对齐技术UltraFeedback如何让7B模型打败70B LLaMA2》

  • UltraChat: 高质量的对话数据集,包含了 150 余万条多轮指令数据。调用多个 ChatGPT API 相互对话,从而生成多轮对话数据。

  经过几轮实验证明,使用 OpenBMB 两个数据集训练出来的新模型非常强大,是 H4 团队 在伯克利和斯坦福的基准测试中见过的最强模型,并在之后被名名为 Zephyr模型。Zephyr-7B-alpha 的MT-Bench平均得分7.09 ,超越Llama2-70B-Chat。
在这里插入图片描述

  一个基于高质量数据集的 7B 模型就打败了参数十倍之大的 LLaMA2-70B-Chat,这说明底层的数据工作 才是最稀缺的和有时间价值的,这或许是各家各派大模型在 百模大战中的突破口之一。

  另外,Zephyr的效果优于 LLaMA2-70B-Chat,另外一个主要原因是使用了斯坦福大学和CZ Biohub不久前合作提出DPO方法。与传统的PPO强化学习方法不同,DPO方法舍弃了强化学习,要比PPO稳定得多。

  DPO简单解释:要想使模型的输出更加符合人类偏好,一直以来传统方法是用一个奖励模型来微调目标模型。输出得好给奖励,输出不好不给奖励。而DPO的方法绕过了建模奖励函数,相当于直接在偏好数据上优化模型,它解决了人类反馈的强化学习训练难、训练成本高的问题。

1.1.2 Zephyr-7B-beta

  开发二代模型时,他们思考了大模型所用的蒸馏监督微调(dSFT),但用这种方法模型是不对齐的,不能很好地生成符合用户意图的输出。

在这里插入图片描述
  所以团队尝试使用来自AI反馈(AI Feedback,AIF)的偏好数据,用一个“教师模型”对输出进行排名,形成一个数据集,然后应用蒸馏直接偏好优化(dDPO)来训练一个与用户意图对齐的模型,且在微调期间不需要任何额外的抽样。研究人员还测试了不用SFT的效果,结果性能大大降低,说明dSFT步骤至关重要。
在这里插入图片描述

  二代Zephyr-7B-beta,探索了从GPT-4、Claude 2中提取对齐性,然后将其注入小模型中的想法,开发出了将蒸馏直接偏好优化(dDPO)用于小模型的方法,MT-Bench平均得分升高至7.34

在这里插入图片描述
在AlpacaEval上,Zephyr胜率为90.6%,优于ChatGPT(3.5):
在这里插入图片描述

1.2 摘要

  • DPO论文《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》
  • Zephyr论文《Zephyr: Direct Distillation of LM Alignment》

  本文旨在创建一个较小的语言模型,该模型能够更好地对用户意图进行对齐。

  先前的研究表明,对较大模型进行蒸馏监督微调(dSFT)可以显著提高任务准确性。然而,这些模型在自然提示下的响应不够理想。为了改善这一性质,研究者尝试使用来自AI Feedback(AIF)的偏好数据,通过使用由教师模型排名的输出数据集,应用蒸馏直接偏好优化(dDPO)来学习一个具有显著改进的意图对齐的聊天模型。

  这种方法只需要几个小时的训练,而且在微调过程中无需进行额外的采样。最终得到的Zephyr-7B在7B参数模型的聊天基准( chat benchmarks)上取得了新的最先进水平,并且无需人工标注。MT-Bench结果显示,Zephyr-7B超过了LLaMA2-70B-Chat。该系统的代码、模型、数据和教程可在alignment-handbook上找到。

1.3 相关工作

  近年来开源的大规模语言模型不断涌现,如ChatGPT之后出现的LLaMA、RedPajama-INCITE、Falcon、Llama 2、Mistral等模型,为研究社区提供了研究和应用的基础模型。随着开源模型的发展,研究人员研究从大模型中迁移知识来提升小模型性能的方法,这一趋势开始于self-instruct和Alpaca,蒸馏策略如SFT和偏好优化是其中的研究重点。

  为了跟上生成式AI的创新步伐,基准测试和评估LLM的工具也取得了长足发展,比如:

  • 使用强大的LLM(GPT-4、Claude)作为评估器,对模型输出打分或成对排名回复来判断模型的响应。
  • LMSYS chatbot arena:使用众包的方式,通过匿名随机对战来基准测试LLM。模型根据排行榜上的Elo评分进行排名。
  • AlpacaEval:类似LMSYS这种排行榜的方式,成对比较模型,但使用GPT-4和Claude等更大的LLM来替代人类进行评估
  • MTBench:使用GPT-4对不同任务类别的多轮对话进行评分(1-10分),任务类别包括推理、角色扮演、数学、编码、写作、人文、STEM和信息提取等。
  • 其它评测工具:HuggingFace Open LLM leaderbaord、Chain-of-Thought Hub、ChatEvalFastEval等。

  本文最终通过在MTBench、AlpacaEval和HuggingFace OpenLLM排行榜上的评测结果来展示Zephyr模型的效果。

Model performance on MT-Bench

1.4 算法

参考:《Zephyr-7B: Fine-Tuning and Inference with W&B》、《HuggingFace 新作:70亿打败700亿Llama 2,开源模型Zephyr-7B!MAC可跑》

论文旨在使开源大型语言模型与用户意图保持一致,如下图所示整个训练过程分为三步:

1.4.1 蒸馏监督微调(dSFT)

  dSFT(Distilled Supervised Fine-Tuning)通过高质量的指令-响应数据集来教会我们的模型对指令和提示进行响应。Zephyr没有采用传统的在指令-响应数据集上进行监督微调(SFT),而是利用教师模型来生成这些高质量的响应,从而“蒸馏”教师模型的某些能力到我们的模型中,你也可以将其看作一种伪标签法。

  假设你有一个种子提示集合 { x 1 , . . . , x j } \{x_1, ..., x_j\} {x1,...,xj},对于每个种子提示 x i  x_i xi,使用教师模型(GPT-4)对指令做出响应得到 y i y_i yi,同时基于其响应进一步提炼这个指令得到 x ^ i \hat{x}_i x^i,最终得到数据集: C = { ( x ^ i , y i ) , . . . , ( x ^ j , y j ) } C = \{(\hat{x}_i, y_i), ..., (\hat{x}_j, y_j)\} C={(x^i,yi),...,(x^j,yj)}

然后对模型进行指令调优,以优化以下方程:

π d S F T = m a x π   E ( x , y ) ∼ C l o g π ( y ∣ x ) \pi_{dSFT} = \underset{\pi}{max} ~ \underset{(x, y) \sim C}{\mathbb{E}} log \pi(y|x) πdSFT=πmax (x,y)CElogπ(yx)

  • π \pi π:要优化的参数,即学生模型
  • C:教师模型生成的训练数据集,包含提炼后的提示 x ^ i \hat{x}_i x^i和响应 y i y_i yi
  •   E ( x , y ) ∼ C ~ \underset{(x, y) \sim C}{\mathbb{E}}  (x,y)CE:表示从数据集C中采样 x x x y y y

  这个方程的目标是最大化学生模型生成教师模型响应的对数似然概率,即通过使学生模型模仿教师模型的响应,实现知识迁移。

1.4.2 基于偏好的AI反馈 (AIF)

  人类反馈(HF,Human feedback)可以为对齐大语言模型(LLM)提供额外的指导信号,是调整LLM的常见方法。本文使用了蒸馏,所以改为使用教师模型对其他模型生成的输出给出指导,即基于偏好的AI反馈(AIF,AI Feedback through Preferences )。说白了就是用AI反馈(教师模型)代替人类反馈。

  具体来说,参考UltraFeedback中的方法,对于每个提示 x 1 , . . . , x j x_1, ..., x_j x1,...,xj,使用4个模型(Claude、LLaMA、Falcon等)生成响应 ( y i 1 , y i 2 , y i 3 , y i 4 ) (y^1_i, y^2_i, y^3_i, y^4_i) (yi1,yi2,yi3,yi4),然后使用GPT-4作为教师模型对其给出分数 s { 1 , 2 , 3 , 4 } = π T ( ⋅ ∣ x i , y i { 1 , 2 , 3 , 4 } )  s^{\{1, 2, 3, 4\}} = \pi_{T}(\cdot|x_i, y_i^{\{1, 2, 3, 4\}}) s{1,2,3,4}=πT(xi,yi{1,2,3,4}),4 个响应中的最高分的响应称为 y w  y_w yw ,随机一个较低分数的响应称为 y l y_l yl 。这样,我们就从提示列表 { x 1 , . . . , x j } \{x_1, ..., x_j\} {x1,...,xj}中派生出AI反馈数据集 D = { ( x 1 , y 1 w , y 1 l ) , . . . , ( x j , y j w , y j l ) }  D = \{(x_1, y_1^w, y_1^l), ..., (x_j, y_j^w, y_j^l)\} D={(x1,y1w,y1l),...,(xj,yjw,yjl)},这是一个具有较强响应和较弱响应的三元组。

1.4.3 直接蒸馏偏好优化(dDPO)

  直接蒸馏偏好优化 (dDPO,Distilled direct preference optimization)的目标,是通过最大化偏好模型中响应 y w y_w yw对响应 y l y_l yl的优先排列概率,来优化经过dSFT的模型πdSFT。奖励函数由学生语言模型确定

  过去使用人工智能反馈(AI feedback)的工作主要集中在使用强化学习(RL)的方法,比如PPO(Proximal Policy Optimization),通过先训练奖励函数然后从当前策略中抽样来计算更新以优化 θ θ θ。而在DPO中,偏好模型是由奖励函数 r θ ( x , y ) r_θ(x, y) rθ(x,y) 确定的,该函数利用了学生语言模型 π θ π_θ πθ

  DPO的关键观察是用最优语言模型策略π和原始语言模型策略πdSFT,来导出最优的奖励函数。在适当选择偏好模型的情况下,他们证明对于常数β和配分函数Z,有:
r ∗ ( x , y ) = β π * ( y ∣ x ) π dSFT ( y ∣ x ) + β log ⁡ Z ( x ) r^*(x,y) = \beta \frac{\pi_{\text{*}}(y | x)} {\pi_{\text{dSFT}}(y | x)} + \beta\log Z(x) r(x,y)=βπdSFT(yx)π*(yx)+βlogZ(x)

将奖励函数插入偏好模型中,得到的目标函数如下:
π θ = m a x π E ( x , y w , y l ) ∼ D l o g σ ( β l o g π ( y w ∣ x ) π d S F T ( y w ∣ x ) − β l o g π ( y l ∣ x ) π d S F T ( y l ∣ x ) ) \pi_\theta = \underset{\pi}{max} \underset{(x, y_w, y_l)\sim D}{\mathbb{E}} log \sigma (\beta log \frac{\pi(y_w|x)}{\pi_{dSFT} (y_w|x)} - \beta log \frac{\pi(y_l|x)}{\pi_{dSFT} (y_l|x)}) πθ=πmax(x,yw,yl)DElogσ(βlogπdSFT(ywx)π(ywx)βlogπdSFT(ylx)π(ylx))

  与 RLHF 相比,DPO直接从静态偏好数据优化模型,而不需要经过训练的奖励模型。据作者称,DPO 是轻量级的并且更稳定。文中使用的方法称为 dDPO,因为数据集是从早期步骤中提取(distilled)的,利用 AI 提供的偏好标签。

  总结整个训练过程:

  1. 对LLM进行dSFT指令调整,得到模型 π d S F T \pi_{dSFT} πdSFT
  2. 参考UltraFeedback中的方法,从提示列表 { x 1 , . . . , x j } \{x_1, ..., x_j\} {x1,...,xj}中构建AI反馈数据集 D = { ( x 1 , y 1 w , y 1 l ) , . . . , ( x j , y j w , y j l ) }  D = \{(x_1, y_1^w, y_1^l), ..., (x_j, y_j^w, y_j^l)\} D={(x1,y1w,y1l),...,(xj,yjw,yjl)}
  3. 遍历每个AIF三元组 ( x i , y i w , y i l ) } (x_i, y_i^w, y_i^l)\} (xi,yiw,yil)},以实现模型的优化。
    • 计算dSFT 模型的 ( x , y w ) (x, y_w) (x,yw) ( x , y l ) (x, y_l) (x,yl) 概率(仅进行前向计算)。
    • 计算dDPO模型的 ( x , y w ) (x, y_w) (x,yw) ( x , y l ) (x, y_l) (x,yl)的概率
    • 根据目标函数计算损失,反向传播以更新参数,然后重复此过程。
1.4.4 训练细节
  • π d S F T \pi_{dSFT} πdSFT模型训练蚕食为:cosine LR scheduler,最大学习率为2e-5, warmup steps=10%,epoch=1,sequence length=2048,batch size =512。
  • DPO 模型训练蚕食为:linear LR scheduler,最大学习率为5e-7, warmup steps=10%。batch size =32,β=0.1,epoch=3。

  最终的Zephyr-7B模型是在SFT模型(训练1个epcoh)上进行权重初始化,然后进行3个epoch的DPO训练。

1.5 实验

  1. dDPO提升了在对话数据集MT-Bench 和 AlpacaEval 上的效果
    在这里插入图片描述
  2. dDPO提升了在传统任务(Academic Task)上的效果
    在这里插入图片描述
  3. 偏好优化是否必要?
    在表格3中,我们通过以四种不同方式对 Mistral 7B 进行微调,来考察对齐过程中不同步骤的影响:
    在这里插入图片描述
    • dDPO - dSFT:直接在base model上进行1个epoch的DPO训练,数据集为UltraFeedback。可以看到没有第一步SFT训练,模型无法从反馈中学习,表现糟糕。
    • dSFT1:在base model上进行1个epoch的SFT训练,数据集为UltraChat,这一步显著提高模型在两个聊天基准上的得分。
    • dSFT2:先进行dSFT1,然后在UltraFeedback数据集上接着进行1个epoch的SFT训练,模型过拟合
    • dDPO+dSFT:本文的训练策略,dSFT1之后在接着在ltraFeedback数据集上接着进行1个epoch的DPO训练,两个基准上均有显著提升。
  4. 过拟合是否会损失在下游任务上的性能
    • 一轮 DPO 训练后,模型会出现强烈的过拟合,如下图中训练集准确率的完美表现所示。但这并没有损害在 MT-Bench 和 AlpacaEval 上的下游性能。随着训练时间增加,过拟合后,效果居然更好了。研究人员认为这类似于SFT中的过拟合。
    • 如果 SFT 训练超过1 epoch,那么DPO 步骤会引起性能的退化。
    • 最佳模型经过了一轮 SFT训练,三轮DPO 的训练。
      在这里插入图片描述

二、alignment-handbook:低成本训练Zephyr

参考:《如何低成本训练一个可以超越 70B Llama2 的模型 Zephyr-7B》、项目地址《alignment-handbook》

  整个Zephyr的完整训练过程,发布在《alignment-handbook》,环境安装见项目首页。下面对训练过程进行简单介绍。

2.1 项目简介

整个训练过程分两步:

  1. SFT训练:使用 UltraChat 数据集对 Mistral 7B 模型进行SFT训练。
    对于 SFT 训练,我们使用了 UltraChat 数据集,它包含了约 1.6M个 由 GPT3.5 生成的对话。我们最初是在所有数据上进行训练的,但后来发现训练出来的模型性格有点让人讨厌 。因此,我们筛选出了大约 200K 个更有帮助的例子进行训练,筛选后的数据集为ultrachat_200k。
  2. DPO微调:使用UltraFeedback 数据集的预处理版本,对SFT模型进行DPO(直接偏好优化)微调,将其与AI反馈(AI feedback)对齐。
    UltraFeedback 数据集涵盖了各种模型范围。每个回答都由 GPT-4 根据有益性等标准进行了评分,以此来推导 AI 的偏好。一个有趣的发现是,在用DPO的方法时,随着训练时间增加,过拟合后,效果居然更好了。研究人员认为这类似于SFT中的过拟合。

  另外,在所有实验中都使用了 TRL 和 DeepSpeed ZeRO-3:SFTTrainer、DPOTrainer,总计算成本:$500 或在16 x A100 上运行 8 小时,体验demo:zephyr-chat。

  评估方法:我们使用了 LMSYS 提供的优秀工具 MT Bench。这个多轮的基准测试可以评估聊天机器人在创意写作、编码和数学等各个领域的能力。相比其他排行榜,它能提供更准确的关于聊天机器人性能的信息。

最终,项目提供了两种训练方式:

  • Zephyr-7B完整训练:因为是全量训练,所以开启了deepspeed ZERO stage3,环境配置见recipes/accelerate_configs/deepspeed_zero3.yaml。

    # Step 1 - SFT
    ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml
    
    # Step 2 - DPO
    ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml
    
  • Zephyr-7B LoRA训练:微调不需要开启deepspeed,环境配置见recipes/accelerate_configs/multi_gpu.yaml。

    # Step 1 - SFT
    ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_lora.yaml
    
    # Step 2 - DPO
    ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_lora.yaml
    

下面给出训练代码,今天写笔记太累了,还没有跑,有空再补吧。

2.2 全量训练

2.2.1 环境配置

配置文件为recipes/accelerate_configs/deepspeed_zero3.yaml

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
2.2.2 SFT训练
  1. 模型配置文件为 recipes/zephyr-7b-beta/sft/config_full.yaml
# Model arguments
model_name_or_path: mistralai/Mistral-7B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
dataset_mixer:
  HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 2
gradient_checkpointing: true
hub_model_id: zephyr-7b-sft-full
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5  
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/zephyr-7b-sft-full
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 32
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
save_strategy: "no"
save_total_limit: null
seed: 42
tf32: true
  1. SFT训练代码见scripts/run_sft.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
"""

import logging
import random
import sys

import datasets
import torch
import transformers
from transformers import set_seed

from accelerate import Accelerator
from alignment import (
    DataArguments,
    H4ArgumentParser,
    ModelArguments,
    SFTConfig,
    apply_chat_template,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
)
from trl import SFTTrainer


logger = logging.getLogger(__name__)


def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, training_args = parser.parse()

    # Set seed for reproducibility
    set_seed(training_args.seed)

    accelerator = Accelerator()

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
    logger.info(
        f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )

    ################
    # Load tokenizer
    ################
    tokenizer = get_tokenizer(model_args, data_args)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
        for index in random.sample(range(len(raw_datasets["train"])), 3):
            logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")

    #######################
    # Load pretrained model
    #######################
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map(),
        quantization_config=get_quantization_config(model_args),
    )
    logger.info("*** Model loaded! ***")

    ########################
    # Initialize the Trainer
    ########################
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        max_seq_length=training_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        peft_config=get_peft_config(model_args),
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    train_result = trainer.train()
    metrics = train_result.metrics
    max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    if accelerator.is_main_process:
        kwargs = {
            "finetuned_from": model_args.model_name_or_path,
            "dataset": list(data_args.dataset_mixer.keys()),
            "dataset_tags": list(data_args.dataset_mixer.keys()),
            "tags": ["alignment-handbook"],
        }
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

        if training_args.push_to_hub is True:
            logger.info("Pushing to hub...")
            trainer.push_to_hub()

    accelerator.wait_for_everyone()


if __name__ == "__main__":
    main()
2.2.3 DPO训练
  1. 环境配置文件相同
  2. 模型配置文件见recipes/zephyr-7b-beta/dpo/config_full.yaml
# Model arguments
model_name_or_path: alignment-handbook/zephyr-7b-sft-full

# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12

# DPOTrainer arguments
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 1
gradient_checkpointing: true
hub_model_id: zephyr-7b-dpo-full
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: linear
max_length: 1024
max_prompt_length: 512
num_train_epochs: 3
optim: rmsprop
output_dir: data/zephyr-7b-dpo-full
per_device_train_batch_size: 8
per_device_eval_batch_size: 4
push_to_hub: true
save_strategy: "no"
save_total_limit: null
seed: 42
warmup_ratio: 0.1
  1. DPO训练代码见scripts/run_dpo.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys

import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed

from accelerate import Accelerator
from alignment import (
    DataArguments,
    DPOConfig,
    H4ArgumentParser,
    ModelArguments,
    apply_chat_template,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
    is_adapter_model,
)
from peft import PeftConfig, PeftModel
from trl import DPOTrainer


logger = logging.getLogger(__name__)


def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
    model_args, data_args, training_args = parser.parse()

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    # Increase distributed timeout to 3h to enable push to Hub to complete
    accelerator = Accelerator()

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Formatting comparisons with prompt template",
    )

    # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
    for split in ["train", "test"]:
        raw_datasets[split] = raw_datasets[split].rename_columns(
            {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
        )

    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map(),
        quantization_config=get_quantization_config(model_args),
    )

    model = model_args.model_name_or_path
    if is_adapter_model(model, model_args.model_revision):
        # load the model, merge the adapter weights and unload the adapter
        # Note: to run QLora, you will need to merge the based model separately as the merged model in 16bit
        logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}")

        peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)

        model_kwargs = dict(
            revision=model_args.base_model_revision,
            trust_remote_code=model_args.trust_remote_code,
            use_flash_attention_2=model_args.use_flash_attention_2,
            torch_dtype=torch_dtype,
            use_cache=False if training_args.gradient_checkpointing else True,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            **model_kwargs,
        )
        model = PeftModel.from_pretrained(
            base_model, model_args.model_name_or_path, revision=model_args.model_revision
        )
        model.eval()
        model = model.merge_and_unload()
        model_kwargs = None

    ref_model = model
    ref_model_kwargs = model_kwargs

    if model_args.use_peft is True:
        ref_model = None
        ref_model_kwargs = None

    #########################
    # Instantiate DPO trainer
    #########################
    dpo_trainer = DPOTrainer(
        model,
        ref_model,
        model_init_kwargs=model_kwargs,
        ref_model_init_kwargs=ref_model_kwargs,
        args=training_args,
        beta=training_args.beta,
        train_dataset=raw_datasets["train"],
        eval_dataset=raw_datasets["test"],
        tokenizer=tokenizer,
        max_length=training_args.max_length,
        max_prompt_length=training_args.max_prompt_length,
        peft_config=get_peft_config(model_args),
    )

    ###############
    # Training loop
    ###############
    train_result = dpo_trainer.train()
    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
    )
    metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
    dpo_trainer.log_metrics("train", metrics)
    dpo_trainer.save_metrics("train", metrics)
    dpo_trainer.save_state()

    logger.info("*** Training complete ***")

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = dpo_trainer.evaluate()
        max_eval_samples = (
            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
        )
        metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
        dpo_trainer.log_metrics("eval", metrics)
        dpo_trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    dpo_trainer.save_model(training_args.output_dir)
    # Save everything else on main process
    if accelerator.is_main_process:
        kwargs = {
            "finetuned_from": model_args.model_name_or_path,
            "dataset": list(data_args.dataset_mixer.keys()),
            "dataset_tags": list(data_args.dataset_mixer.keys()),
            "tags": ["alignment-handbook"],
        }
        dpo_trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        dpo_trainer.model.config.use_cache = True
        dpo_trainer.model.config.save_pretrained(training_args.output_dir)
        if training_args.push_to_hub is True:
            dpo_trainer.push_to_hub()

    # Ensure we don't timeout on model save / push to Hub
    logger.info("*** Waiting for all processes to finish ***")
    accelerator.wait_for_everyone()

    logger.info("*** Run complete! ***")


if __name__ == "__main__":
    main()

2.3 Lora训练

  Lora训练时的训练代码和模型配置文件与全量训练时完全相同,只是环境配置不一样。因为只是微调,不需要开启ZERO stage3,所以环境配置为recipes/accelerate_configs/multi_gpu.yaml

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

2.4 测试

  测试部分见alignment-handbook项目下的test文件夹,作者还没有上传相关说明文件,大家可以继续跟踪相关进度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/133298.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

赛氪助力全国大学生数学竞赛山东赛区圆满举办

近日,全国大学生数学竞赛山东赛区比赛有序进行,赛氪已连续6年助力本项赛事蓬勃发展。在中国高等教育学会高校竞赛评估与管理体系研究专家工作组发布的《2022全国普通高校大学生竞赛分析报告》中,本赛事荣登观察目录。 全国大学生数学竞赛旨在…

Transforme原理--全局解读

文章目录 作用全局解读 作用 Transformer最初设计用于处理序列数据,特别在NLP(自然语言处理)领域取得了巨大成功 全局解读 Transformer来源于谷歌的一篇经典论文Attention is All you Need 在此使用Transformer在机器翻译中的运用来讲解Transformer。 其中Tran…

Windows11跳过联网激活 跳过登陆操作

1 背景 笔者使用VirtualBox时安装Win11,初始化的配置提示需要注册账户才能进行下一步操作,于是去查了一下发现有办法绕过,方法就是断网oobe\ByPassNRO.cmd,试了一下发现可以,便有了这篇文章。 2 流程 开机之前&…

【完美世界】石昊负伤遭囚禁,无始种惊现,二秃子用柳枝力保石昊

Hello,小伙伴们,我是小郑继续为大家深度解析国漫资讯。 深度爆料完美世界最新预告资讯,《完美世界》第137集预告片已经更新了,这一集的预告片充满了紧张的气氛和精彩的情节。从预告中我们可以看到,石昊的真实身份被天人族知晓&…

3.30每日一题(多元函数微分学)

1、判断连续:再分界点的极限值等于该点的函数值; 如何求极限值: 初步判断:分母都为二次幂开根号,所以分母为一次幂;分子为二次,一般来说整体为0; 如何说明极限为零(常用…

ZYNQ_project:IP_ram_pll_test

例化MMCM ip核,产生100Mhz,100Mhz并相位偏移180,50Mhz,25Mhz的时钟信号。 例化单口ram,并编写读写控制器,实现32个数据的写入与读出。 模块框图: 代码: module ip_top(input …

SpringBoot_01

Spring https://spring.io/ SpringBoot可以帮助我们非常快速的构建应用程序、简化开发、提高效率。 SpringBootWeb入门 需求:使用SpringBoot开发一个web应用,浏览器发起请求/hello后,给浏览器返回字符串"Hello World~~~"。 步骤…

测试人员如何通过AI提高工作效率!

随着AI技术的兴起,像OpenAI推出的ChatGPT、Microsoft发布的Microsoft 365 Copilot、阿里的通义千问、百度的文心一言、华为的盘古大模型等。很多测试人员开始担心,岗位是否会被AI取代?其实取代你的不是AI,而是会使用AI的测试人&am…

基于springboot+vue的校园闲置物品交易系统

运行环境 开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包:Maven 项目介绍 本文从管…

自动驾驶学习笔记(七)——感知融合

#Apollo开发者# 学习课程的传送门如下,当您也准备学习自动驾驶时,可以和我一同前往: 《自动驾驶新人之旅》免费课程—> 传送门 《Apollo Beta宣讲和线下沙龙》免费报名—>传送门 文章目录 前言 感知融合 卡尔曼滤波 融合策略 实…

NtripShare Mos地铁自动化监测终端盒子硬件设计

自动化监测产品到目前为止做了接近一年,在软件层面上,控制终端软件、平台软件、网平差算法都已解决,硬件盒子始终是心里过不去的坎,最终还是没有耐住性子自己做了一把。 选型如下: 1、主板:瑞芯微RK3568主板。 2、外…

解决《荒野大镖客》提示emp.dll文件丢失问题,总结5个修复方法

在当今数字时代,游戏已经成为人们休闲娱乐的重要方式。作为一名游戏爱好者,笔者在近期体验《荒野大镖客》这款游戏时,遇到了一个令人苦恼的问题——emp.dll文件丢失。这个问题让游戏的无法启动进行。本文将围绕这一问题,探讨其原因…

Leetcode2834. 找出美丽数组的最小和

Every day a Leetcode 题目来源:2834. 找出美丽数组的最小和 解法1:贪心 从最小正整数 1 开始枚举,设当前数为 num,如果 nums 里没有 target - num,就说明可以添加 num,依次填满直到有 n 个数即可。 用…

公开数据集:灵长类动物多通道感觉运动皮层电生理学的研究

Nonhuman Primate Reaching with Multichannel Sensorimotor Cortex Electrophysiology. 1 公开数据集网址:https://zenodo.org/records/3854034 目录 General DescriptionPossible usesVariable namesDecoder ResultsVideosSupplementsContact InformationCitation…

java 类和对象 (图文搭配,万字详解!!)

关于java类和对象,我们要掌握几个重点! 1.类的定义方式以及对象的实例化 2.类中的成员变量和成员方法的使用 3.对象的整个初始化过程 4.封装特性 5.代码块 目录 一、面向对象的初步认识 1.1 什么是面向对象 1.2 面向对象与面向过程 1.2.1传统洗…

Python:词法分析(行结构与显式、隐式行拼接)

相关阅读 Pythonhttps://blog.csdn.net/weixin_45791458/category_12403403.html?spm1001.2014.3001.5482 1、逻辑结构 一个Python程序由许多逻辑行组成,字面意义上的一行指的是末尾有换行符(\n),但在不同的情况下,行末尾的换行符(\n)可能有…

语音识别与自然语言处理(NLP):技术前沿与未来趋势

语音识别与自然语言处理(NLP):技术前沿与未来趋势 随着科技的快速发展,语音识别与自然语言处理(NLP)技术逐渐成为人工智能领域的研究热点。这两项技术的结合,使得机器能够更好地理解和处理人类语…

解析html生成Word文档

内容:读取html文件中的文本内容,然后生成Word文档导出。 事例场景:需求开发完成之后需要写文档(代码修改清单),文档内容就是这次需求修改/新增的所有代码,需要列出修改的文件路径以及代码片段&…

Dart笔记:一些代码生成工具站点的介绍

Dart笔记: 一些代码生成工具站点的介绍 作者:李俊才 (jcLee95):https://blog.csdn.net/qq_28550263 邮箱 :291148484163.com 本文地址:https://blog.csdn.net/qq_28550263/article/details/1343…

力扣138:随机链表的复制

力扣138:随机链表的复制 题目描述: 给你一个长度为 n 的链表,每个节点包含一个额外增加的随机指针 random ,该指针可以指向链表中的任何节点或空节点。 构造这个链表的 深拷贝。 深拷贝应该正好由 n 个 全新 节点组成&#xff…