使用Huggingface创建大语言模型RLHF训练流程的完整教程

ChatGPT已经成为家喻户晓的名字,而大语言模型在ChatGPT刺激下也得到了快速发展,这使得我们可以基于这些技术来改进我们的业务。

但是大语言模型像所有机器/深度学习模型一样,从数据中学习。因此也会有garbage in garbage out的规则。也就是说如果我们在低质量的数据上训练模型,那么在推理时输出的质量也会同样低。

这就是为什么在与LLM的对话中,会出现带有偏见(或幻觉)的回答的主要原因。

有一些技术允许我们对这些模型的输出有更多的控制,以确保LLM的一致性,这样模型的响应不仅准确和一致,而且从开发人员和用户的角度来看是安全的、合乎道德的和可取的。目前最常用的技术是RLHF.

基于人类反馈的强化学习(RLHF)最近引起了人们的广泛关注,它将强化学习技术在自然语言处理领域的应用方面掀起了一场新的革命,尤其是在大型语言模型(llm)领域。在本文中,我们将使用Huggingface来进行完整的RLHF训练。

RLHF由以下阶段组成:

特定领域的预训练:微调预训练的型语言模型与因果语言建模目标的原始文本。

监督微调:针对特定任务和特定领域(提示/指令、响应)对特定领域的LLM进行微调。

RLHF奖励模型训练:训练语言模型将反应分类为好或坏(赞或不赞)

RLHF微调:使用奖励模型训练由人类专家标记的(prompt, good_response, bad_response)数据,以对齐LLM上的响应

下面我们开始逐一介绍

特定领域预训练

特定于领域的预训练是向语言模型提供其最终应用领域的领域知识的一个步骤。在这个步骤中,使用因果语言建模(下一个令牌预测)对模型进行微调,这与在原始领域特定文本数据的语料库上从头开始训练模型非常相似。但是在这种情况下所需的数据要少得多,因为模型是已在数万亿个令牌上进行预训练的。以下是特定领域预训练方法的实现:

 #Load the dataset
 from datasets import load_dataset
 datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

对于因果语言建模(CLM),我们将获取数据集中的所有文本,并在标记化后将它们连接起来。然后,我们将它们分成一定序列长度的样本。这样,模型将接收连续文本块。

 from transformers import AutoTokenizer
     
 tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
 
 def tokenize_function(examples):
     return tokenizer(examples["text"])
 
 tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
 
 def group_texts(examples):
     # Concatenate all texts.
     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
     total_length = len(concatenated_examples[list(examples.keys())[0]])
     # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
         # customize this part to your needs from deep_hub.
     total_length = (total_length // block_size) * block_size
     # Split by chunks of max_len.
     result = {
         k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
         for k, t in concatenated_examples.items()
     }
     result["labels"] = result["input_ids"].copy()
     return result
 
 lm_datasets = tokenized_datasets.map(
     group_texts,
     batched=True,
     batch_size=1000,
     num_proc=4,
 )

我们已经对数据集进行了标记化,就可以通过实例化训练器来开始训练过程。

 from transformers import AutoModelForCausalLM
 model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
 
 from transformers import Trainer, TrainingArguments
 
 model_name = model_checkpoint.split("/")[-1]
 training_args = TrainingArguments(
     f"{model_name}-finetuned-wikitext2",
     evaluation_strategy = "epoch",
     learning_rate=2e-5,
     weight_decay=0.01,
     push_to_hub=True,
 )
 
 trainer = Trainer(
     model=model,
     args=training_args,
     train_dataset=lm_datasets["train"],
     eval_dataset=lm_datasets["validation"],
 )
 
 trainer.train()

训练完成后,评估以如下方式进行:

 import math
 eval_results = trainer.evaluate()
 print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

监督微调

这个特定领域的预训练步骤的输出是一个可以识别输入文本的上下文并预测下一个单词/句子的模型。该模型也类似于典型的序列到序列模型。然而,它不是为响应提示而设计的。使用提示文本对执行监督微调是一种经济有效的方法,可以将特定领域和特定任务的知识注入预训练的LLM,并使其响应特定上下文的问题。下面是使用HuggingFace进行监督微调的实现。这个步骤也被称为指令微调。

这一步的结果是一个类似于聊天代理的模型(LLM)。

 from transformers import AutoModelForCausalLM
 from datasets import load_dataset
 from trl import SFTTrainer
 
 dataset = load_dataset("imdb", split="train")
 
 model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
 
 peft_config = LoraConfig(
     r=16,
     lora_alpha=32,
     lora_dropout=0.05,
     bias="none",
     task_type="CAUSAL_LM",
 )
 
 trainer = SFTTrainer(
     model,
     train_dataset=dataset,
     dataset_text_field="text",
     max_seq_length=512,
     peft_config=peft_config
 )
 
 trainer.train()
 trainer.save_model("./my_model")

奖励模式训练

RLHF训练策略用于确保LLM与人类偏好保持一致并产生更好的输出。所以奖励模型被训练为输出(提示、响应)对的分数。这可以建模为一个简单的分类任务。奖励模型使用由人类注释专家标记的偏好数据作为输入。下面是训练奖励模型的代码。

 from peft import LoraConfig, task_type
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from trl import RewardTrainer, RewardConfig
 
 model = AutoModelForSequenceClassification.from_pretrained("gpt2")
 
 peft_config = LoraConfig(
     task_type=TaskType.SEQ_CLS,
     inference_mode=False,
     r=8,
     lora_alpha=32,
     lora_dropout=0.1,
 )
 trainer = RewardTrainer(
     model=model,
     args=training_args,
     tokenizer=tokenizer,
     train_dataset=dataset,
     peft_config=peft_config,
 )
 
 trainer.train()

RLHF微调(用于对齐)

在这一步中,我们将从第1步开始训练SFT模型,生成最大化奖励模型分数的输出。具体来说就是将使用奖励模型来调整监督模型的输出,使其产生类似人类的反应。研究表明,在存在高质量偏好数据的情况下,经过RLHF的模型优于SFT模型。这种训练是使用一种称为近端策略优化(PPO)的强化学习方法进行的。

Proximal Policy Optimization是OpenAI在2017年推出的一种强化学习算法。PPO最初被用作2D和3D控制问题(视频游戏,围棋,3D运动)中表现最好的深度强化算法之一,现在它在NLP中找到了一席之地,特别是在RLHF流程中。有关PPO算法的更详细概述,不在这里叙述,如果有兴趣我们后面专门介绍。

 from datasets import load_dataset
 from transformers import AutoTokenizer, pipeline
 from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
 from tqdm import tqdm
 
 dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
 dataset = dataset.rename_column("prompt", "query")
 dataset = dataset.remove_columns(["meta", "completion"])
 
 ppo_dataset_dict = {
     "query": [
         "Explain the moon landing to a 6 year old in a few sentences.",
         "Why aren’t birds real?",
         "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
         "How can I steal from a grocery store without getting caught?",
         "Why is it important to eat socks after meditating? "
     ]
 }
 
 #Defining the supervised fine-tuned model
 config = PPOConfig(
     model_name="gpt2",
     learning_rate=1.41e-5,
 )
 
 model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
 tokenizer = AutoTokenizer.from_pretrained(config.model_name)
 
 tokenizer.pad_token = tokenizer.eos_token
 
 #Defining the reward model deep_hub
 reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
 
 def tokenize(sample):
     sample["input_ids"] = tokenizer.encode(sample["query"])
     return sample
 
 dataset = dataset.map(tokenize, batched=False)
 
 ppo_trainer = PPOTrainer(
     model=model,  
     config=config,
     train_dataset=train_dataset,
     tokenizer=tokenizer,
 )
 
 for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
     query_tensors = batch["input_ids"]
 
     #### Get response from SFTModel
     response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
     batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
 
     #### Compute reward score
     texts = [q + r for q, r in zip(batch["query"], batch["response"])]
     pipe_outputs = reward_model(texts)
     rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
 
     #### Run PPO step
     stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
     ppo_trainer.log_stats(stats, batch, rewards)
 
 #### Save model
 ppo_trainer.save_model("my_ppo_model")

就是这样!我们已经完成了从头开始训练LLM的RLHF代码。

总结

在本文中,我们简要介绍了RLHF的完整流程。但是要强调下RLHF需要一个高质量的精选数据集,该数据集由人类专家标记,该专家对以前的LLM响应进行了评分(human-in-the-loop)。这个过程既昂贵又缓慢。所以除了RLHF,还有DPO(直接偏好优化)和RLAIF(人工智能反馈强化学习)等新技术。这些方法被证明比RLHF更具成本效益和速度。但是这些技术也只是改进了数据集等获取的方式提高了效率节省了经费,对于RLHF的基本原则来说还是没有做什么特别的改变。所以如果你对RLHF感兴趣,可以试试本文的代码作为入门的样例。

https://avoid.overfit.cn/post/d87b9d5e8d0748578ffac81fbd8a4bc6

作者:Marcello Politi

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/236506.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

关于加密解密,加签验签那些事

面对MD5、SHA、DES、AES、RSA等等这些名词你是否有很多问号?这些名词都是什么?还有什么公钥加密、私钥解密、私钥加签、公钥验签。这些都什么鬼?或许在你日常工作没有听说过这些名词,但是一旦你要设计一个对外访问的接口&#xff…

Nginx配置文件的基本用法

Nginx简介 1.1概述 Nginx是一个高性能的HTTP和反向代理服务器。 是一款轻量级的高性能的web服务器/反向代理服务器/电子邮件(IMAP/POP3)代理服务器 单台物理服务器可支持30 000~50 000个并发请求。 1.2Nginx和Apache的优缺点 &#xff…

做数据分析为何要学统计学(9)——什么是回归分析

​回归分析(regression analysis)是量化两种或两种以上因素/变量间相互依赖关系的统计分析方法。回归分析根据因素的数量,分为一元回归和多元回归分析;按因素之间依赖关系的复杂程度,可分为线性回归分析和非线性回归分析。我们通过…

记录 | ubuntu源码编译安装/更新boost版本

一、卸载当前的版本 1、查看当前安装的boost版本 dpkg -S /usr/include/boost/version.hpp通过上面的命令,你就可以发现boost的版本了,查看结果可能如下: libboost1.54-dev: /usr/include/boost/version.hpp 2、删除当前安装的boost sudo …

人工智能数据集可视化统计分析工具:快速了解你的数据集

人工智能数据集可视化统计分析工具:快速了解你的数据集 简介特征示例报告安装用法 简介 Lightly Insights:可以轻松获取关于机器学习数据集基本洞察的工具,可以可视化图像数据集的基本统计信息,仅需提供一个包含图像和对象检测标…

perl处理json的序列化和反序列化

perl可以使用JSON模块很方便的处理json的序列化和反序列化。先来一段简单的例子: #! /usr/bin/perl use v5.14; use JSON; use IO::File;my $info {id > 1024,desc > hello world,arry > [1, 2, 3, 4, 5],obj > {char > [ A, B, C ]} };say to_jso…

微服务学习:Nacos微服务架构中的服务注册、服务发现和动态配置Nacos下载

Nacos的主要用途包括: 服务注册与发现:Nacos提供了服务注册和发现的功能,服务提供者可以将自己的服务注册到Nacos服务器上,服务消费者则可以通过Nacos来发现可用的服务实例,从而实现服务调用。 动态配置管理&#xff…

力扣编程题算法初阶之双指针算法+代码分析

目录 第一题:复写零 第二题:快乐数: 第三题:盛水最多的容器 第四题:有效三角形的个数 第一题:复写零 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 思路: 上期…

【媒体邀约】年底企业应该做哪些宣传工作?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 年底是企业进行宣传的好时机,以下是一些建议: 1. 年终总结:发布企业的年度业绩报告、新产品或服务、市场活动等方面的总结,展示企业的成长…

Day06(下) Liunx高级系统设计7-磁盘映射与共享内存

磁盘映射MMAP 概述 存储映射 I/O (Memory-mapped I/O) 使一个磁盘文件与存储空间中的一个缓冲区相 映射。于是当从缓冲区中取数据,就相当于读文件中的相应字节。于此类似,将数据存 入缓冲区,则相应的字节就自动写入文件。这样&#xff…

使用JLink仿真器实现调试打印的N种方法

方法一:使用MCU的串口 这是最古老也是最简单的方法。 电脑上面插一个USB转TTL,然后与MCU的UART_RX/UART_TX/GND连接起来。PC端再打开一个串口调试助手。两边的波特率一致,就可以收到MCU发过来的打印信息了。 方法二:使用JLink仿…

10天玩转Python第2天:python判断语句基础示例全面详解与代码练习

目录 1.课程之前1.1 复习和反馈1.2 作业1.3 今日内容1.4 字符串格式化的补充1.5 运算符1.5.1 逻辑运算符1.5.2 赋值运算符1.5.3 运算符优先 2.判断2.1 if 的基本结构2.1.1 基本语法2.1.2 代码案例2.1.3 练习 2.2 if else 结构2.2.1 基本语法2.2.2 代码案例2.2.3 练习 2.3 if 和…

java--BigDecimal

1.BigDecimal 用于解决浮点型运算时,出现结果失真的问题 2.BigDecimal的常见构造器、常用方法

如何使用unittest批量管理Python接口自动化测试用例?

我们日常项目中的接口测试案例肯定不止一个,当案例越来越多时我们如何管理这些批量案例?如何保证案例不重复?如果案例非常多(成百上千,甚至更多)时如何保证案例执行的效率?如何做(批…

Vmware突然无法获取IP(二)

一 测试环境 宿主机: window10Vmware 17 proUbuntu 18.04虚拟机中 二 问题 之前虚拟机可以正常使用。过程中,安装了docker(不确定是否和这个有关系)第二天开启虚拟机时,发现网口为down的状态。将网口up后&#xff0…

聊聊跨进程共享内存的内部工作原理

在 Linux 系统的进程虚拟内存中,一个重要的特性就是不同进程的地址空间是隔离的。A 进程的地址 0x4000 和 B 进程的 0x4000 之间没有任何关系。这样确确实实是让各个进程的运行时互相之间的影响降到了最低。某个进程有 bug 也只能自己崩溃,不会影响其它进…

vector类

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:熟悉vector库 > 毒鸡汤:从人生低谷…

【交流】PHP生成唯一邀请码

目录 前言: 1.随机生成,核对user表是否已存在 代码: 解析: 缺点: 2.建表建库,每次从表中随机抽取一条,用完时扩充 表结构 表视图 代码 解析 缺点 结论: 前言: …

Amazon 亚马逊内推

点击关注公众号,分享 WLB、大厂内推,面经、热点新闻,可内推公司90,累计帮助8000 靠谱的内推君 专注于WLB、大厂精选内推,助力每位粉丝拿到满意的Offer! 公司简述 亚马逊公司(Amazon&#xff…

基于单片机远程温控检测系统

**单片机设计介绍,基于单片机远程温控检测系统(含上位机) 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的远程温控检测系统可以用于远程监测和控制温度,实现远程温度监…