Huggingface微调BART的代码示例:WMT16数据集训练新的标记进行翻译

BART模型是用来预训练seq-to-seq模型的降噪自动编码器(autoencoder)。它是一个序列到序列的模型,具有对损坏文本的双向编码器和一个从左到右的自回归解码器,所以它可以完美的执行翻译任务。

如果你想在翻译任务上测试一个新的体系结构,比如在自定义数据集上训练一个新的标记,那么处理起来会很麻烦,所以在本文中,我将介绍添加新标记的预处理步骤,并介绍如何进行模型微调。

因为Huggingface Hub有很多预训练过的模型,可以很容易地找到预训练标记器。但是我们要添加一个标记可能就会有些棘手,下面我们来完整的介绍如何实现它,首先加载和预处理数据集。

加载数据集

我们使用WMT16数据集及其罗马尼亚语-英语子集。load_dataset()函数将从Huggingface下载并加载任何可用的数据集。

 importdatasets
 
 dataset=datasets.load_dataset("stas/wmt16-en-ro-pre-processed", cache_dir="./wmt16-en_ro")

在上图1中可以看到数据集内容。我们需要将其“压平”,这样可以更好的访问数据,让后将其保存到硬盘中。

 defflatten(batch):
     batch['en'] =batch['translation']['en']
     batch['ro'] =batch['translation']['ro']
     
     returnbatch
 
 # Map the 'flatten' function
 train=dataset['train'].map( flatten )
 test=dataset['test'].map( flatten )
 validation=dataset['validation'].map( flatten )
 
 # Save to disk
 train.save_to_disk("./dataset/train")
 test.save_to_disk("./dataset/test")
 validation.save_to_disk("./dataset/validation")

下图2可以看到,已经从数据集中删除了“translation”维度。

标记器

标记器提供了训练标记器所需的所有工作。它由四个基本组成部分:(但这四个部分不是所有的都是必要的)

Models:标记器将如何分解每个单词。例如,给定单词“playing”:i) BPE模型将其分解为“play”+“ing”两个标记,ii) WordLevel将其视为一个标记。

Normalizers:需要在文本上发生的一些转换。有一些过滤器可以更改Unicode、小写字母或删除内容。

Pre-Tokenizers:为操作文本提供更大灵活性处理的函数。例如,如何处理数字。数字100应该被认为是“100”还是“1”、“0”、“0”?

Post-Processors:后处理具体情况取决于预训练模型的选择。例如,将 [BOS](句首)或 [EOS](句尾)标记添加到 BERT 输入。

下面的代码使用BPE模型、小写Normalizers和空白Pre-Tokenizers。然后用默认值初始化训练器对象,主要包括

1、词汇量大小使用50265以与BART的英语标记器一致

2、特殊标记,如和,

3、初始词汇量,这是每个模型启动过程的预定义列表。

 fromtokenizersimportnormalizers, pre_tokenizers, Tokenizer, models, trainers
 
 # Build a tokenizer
 bpe_tokenizer=Tokenizer(models.BPE())
 bpe_tokenizer.normalizer=normalizers.Lowercase()
 bpe_tokenizer.pre_tokenizer=pre_tokenizers.Whitespace()
 
 trainer=trainers.BpeTrainer(
     vocab_size=50265,
     special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
     initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
 )

使用Huggingface的最后一步是连接Trainer和BPE模型,并传递数据集。根据数据的来源,可以使用不同的训练函数。我们将使用train_from_iterator()。

 defbatch_iterator():
     batch_length=1000
     foriinrange(0, len(train), batch_length):
         yieldtrain[i : i+batch_length]["ro"]
         
 bpe_tokenizer.train_from_iterator( batch_iterator(), length=len(train), trainer=trainer )
 
 bpe_tokenizer.save("./ro_tokenizer.json")
 

BART微调

现在可以使用使用新的标记器了。

 fromtransformersimportAutoTokenizer, PreTrainedTokenizerFast
 
 en_tokenizer=AutoTokenizer.from_pretrained( "facebook/bart-base" );
 ro_tokenizer=PreTrainedTokenizerFast.from_pretrained( "./ro_tokenizer.json" );
 ro_tokenizer.pad_token=en_tokenizer.pad_token
 
 deftokenize_dataset(sample):
     input=en_tokenizer(sample['en'], padding='max_length', max_length=120, truncation=True)
     label=ro_tokenizer(sample['ro'], padding='max_length', max_length=120, truncation=True)
 
     input["decoder_input_ids"] =label["input_ids"]
     input["decoder_attention_mask"] =label["attention_mask"]
     input["labels"] =label["input_ids"]
 
     returninput
 
 train_tokenized=train.map(tokenize_dataset, batched=True)
 test_tokenized=test.map(tokenize_dataset, batched=True)
 validation_tokenized=validation.map(tokenize_dataset, batched=True)

上面代码的第5行,为罗马尼亚语的标记器设置填充标记是非常必要的。因为它将在第9行使用,标记器使用填充可以使所有输入都具有相同的大小。

下面就是训练的过程:

 fromtransformersimportBartForConditionalGeneration
 fromtransformersimportSeq2SeqTrainingArguments, Seq2SeqTrainer
 
 model=BartForConditionalGeneration.from_pretrained(  "facebook/bart-base" )
 
 training_args=Seq2SeqTrainingArguments(
     output_dir="./",
     evaluation_strategy="steps",
     per_device_train_batch_size=2,
     per_device_eval_batch_size=2,
     predict_with_generate=True,
     logging_steps=2,  # set to 1000 for full training
     save_steps=64,  # set to 500 for full training
     eval_steps=64,  # set to 8000 for full training
     warmup_steps=1,  # set to 2000 for full training
     max_steps=128, # delete for full training
     overwrite_output_dir=True,
     save_total_limit=3,
     fp16=False, # True if GPU
 )
 
 trainer=Seq2SeqTrainer(
     model=model,
     args=training_args,
     train_dataset=train_tokenized,
     eval_dataset=validation_tokenized,
 )
 
 trainer.train()

过程也非常简单,加载bart基础模型(第4行),设置训练参数(第6行),使用Trainer对象绑定所有内容(第22行),并启动流程(第29行)。上述超参数都是测试目的,所以如果要得到最好的结果还需要进行超参数的设置,我们使用这些参数是可以运行的。

推理

推理过程也很简单,加载经过微调的模型并使用generate()方法进行转换就可以了,但是需要注意的是对源 (En) 和目标 (RO) 序列使用适当的分词器。

总结

虽然在使用自然语言处理(NLP)时,标记化似乎是一个基本操作,但它是一个不应忽视的关键步骤。HuggingFace的出现可以方便的让我们使用,这使得我们很容易忘记标记化的基本原理,而仅仅依赖预先训练好的模型。但是当我们希望自己训练新模型时,了解标记化过程及其对下游任务的影响是必不可少的,所以熟悉和掌握这个基本的操作是非常有必要的。

本文代码:https://avoid.overfit.cn/post/6a533780b5d842a28245c81bf46fac63

作者:Ala Alam Falaki

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/8194.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

游戏运营专员的职责有哪些?提高游戏收入的关键是什么?

游戏运营是将一款游戏平台推入市场&#xff0c;通过对平台的运作&#xff0c;使用户从接触、认识、再到了解实际线上的一种操作、最终成为这款游戏平台的忠实玩家的这一过程。同时通过一系列的营销手段达到提高线上人数&#xff0c;刺激消费增长利润的目的。 游戏运营专员的职…

Go 连接池的设计与实现

为什么需要连接池 如果不用连接池&#xff0c;而是每次请求都创建一个连接是比较昂贵的&#xff0c;因此需要完成3次tcp握手 同时在高并发场景下&#xff0c;由于没有连接池的最大连接数限制&#xff0c;可以创建无数个连接&#xff0c;耗尽文件描述符 连接池就是为了复用这…

高效的实现金蝶云星空ERP与自研MES系统数据集成

一、项目背景 随着企业数字化转型的不断深入&#xff0c;数据集成变得愈发重要。金蝶云星空ERP与自研MES系统之间的数据集成是企业提高管理效率、降低运营成本的关键。为了实现这一目标&#xff0c;企业选择了轻易云数据集成平台进行数据集成。 二、项目实施过程 低耦合、高内…

二叉树的前序遍历(力扣144)

目录 题目描述&#xff1a; 解法一&#xff1a;递归法 解法二&#xff1a;迭代法 解法三&#xff1a;Morris 遍历 二叉树的前序遍历 题目描述&#xff1a; 给你二叉树的根节点 root &#xff0c;返回它节点值的 前序 遍历。 示例 1&#xff1a; 输入&#xff1a;root […

Unity反编译:AssetStudio资源浏览器及代码查看器

前言 假如你手上有Unity发布出来的exe文件、apk文件或者webGL文件&#xff0c;但就是没有工程源文件&#xff0c;那么&#xff0c;如何从这些文件里面一窥究竟呢&#xff1f;这就需要资源提取工具以及代码反编译工具&#xff01; 本文所涉软件【文中附有下载链接】&#xff1…

【接口测试工具】Eolink Apikit 快速入门教程

Eolink Apikit 下载安装【官方版】&#xff1a;https://www.eolink.com/apikit 发起 API 测试 进入 API 文档详情页&#xff0c;点击上方 测试 标签&#xff0c;进入 API 测试页&#xff0c;系统会根据 API 文档自动生成测试界面并且填充测试数据。 填写请求参数 首先填写好请…

【创作赢红包】python学习——【第七弹】

前言 上一篇文章 python学习——【第六弹】中介绍了 python中的字典操作&#xff0c;这篇文章接着学习python中的可变序列 集合 集合 1&#xff1a; 集合是python语言提供的内置数据结构&#xff0c;具有无序性&#xff08;集合中的元素无法通过索引下标访问&#xff0c;并且…

UDP协议详解

目录 UDP协议报文结构 端口号 报文长度 校验和 生成校验和的算法 MD5的特点 UDP协议报文结构 UDP会把载荷数据(也就是通过 UDP socekt,send方法拿来的数据基础上,再前面拼装(相当于字符串拼接此处是二进制的)上几个字节的报头 UDP报头里包含了一些特定的属性,这些属性携带…

阿里云linux云服务器 安装指定版本node.js

我们在实例管理中找到自己的服务器 然后点击右侧的 远程连接 接着点击理解登录 进入命令窗口 我们在这上面输入 curl -h阿里云的服务器都还是最好会有 curl的 然后 我们输入 curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.34.0/install.sh | bash下把nvm下下…

量化注意事项和模型设计思想

量化的注意事项 1、量化检测器时&#xff0c;尽量不要对Detect Head进行量化&#xff0c;一旦进行量化可能会引起比较大的量化误差&#xff1b; 2、量化模型时&#xff0c;模型的First&Second Layer也尽可能不进行量化&#xff08;精度损失具有随机性&#xff09;&#xf…

【软件设计师06】数据结构与算法基础

数据结构与算法基础 考点&#xff1a;数组与矩阵、线性表、广义表、树与二叉树、图、排序与查找、算法基础与常见的算法 1. 数组 数组类型存储地址计算一维度数组a[n]a[i]的存储地址为ai*len二维数组a[m][n]a[i][j]的存储地址&#xff1b;按行存储&#xff1a;a(i*nj)*len&a…

Spring原理学习(二):Bean的生命周期和Bean后处理器

〇、前言 倘若是为了面试&#xff0c;请背下来下面这段&#xff1a; spring的bean的生命周期主要是创建bean的过程&#xff0c;一个bean的生命周期主要是4个步骤&#xff1a;实例化、属性注入、初始化、销毁。但是对于一些复杂的bean的创建&#xff0c;spring会在bean的生命周期…

如何搭建chatGPT4.0模型-国内如何用chatGPT4.0

国内如何用chatGPT4.0 在国内&#xff0c;目前可以通过以下途径使用 OpenAI 的 ChatGPT 4.0&#xff1a; 自己搭建模型&#xff1a;如果您具备一定的技术能力&#xff0c;可以通过下载预训练模型和相关的开发工具包&#xff0c;自行搭建 ChatGPT 4.0 模型。OpenAI提供了相关的…

旅游心得Traveling Experience

前言 加油 原文 旅游心得常用会话 ❶ Share photos of the trip with friends. 与朋友分享旅游的照片。 ❷ We’ll go to the Great Wall, if you prefer. 你如果愿意的话,我们去长城。 ❸ Would you go to the church or the synagogue or the mosque? 你会去教堂,犹太…

二结(4.11)IO流学习

FIle类只能对文件本身操作&#xff0c;不能读写文件里面存储的数据 文件保存的位置叫路径&#xff0c;而数据传输叫IO流 Java I/O流&#xff08;Input/Output stream&#xff09;在Java应用程序中用于读取和写入数据&#xff0c;可分为基本流和高级流两类 关于什么是输出流、…

CSC中加学者交换项目申报即将开始

3月31日&#xff0c;国家留学基金委&#xff08;CSC&#xff09;发布了2023-2024年度中加学者交换项目遴选通知。根据通知精神&#xff0c;选派规模&#xff1a;100人月&#xff0c;留学及资助期限&#xff1a;4-12个月&#xff0c;网上报名及申请受理时间为2023年4月11日至6月…

SpringCloud学习6(Spring Cloud Alibaba)断路器Sentinel熔断降级

文章目录服务熔断降级Sentinel高并发请求模拟&#xff08;这里我们使用contiperf来进行测试&#xff09;修改tomcat配置最大线程数引入测试依赖编写测试代码服务雪崩服务雪崩的容错方案&#xff08;隔离、超时、限流、熔断、降级&#xff09;隔离机制&#xff1a;超时机制&…

Baumer工业相机堡盟工业相机如何设置网口的IP地址(工业相机连接的网口设置IP地址步骤)

Baumer工业相机堡盟工业相机如何设置网口的IP地址&#xff08;工业相机连接的网口设置IP地址步骤&#xff09;Baumer工业相机Baumer工业相机设置网络端口IP地址匹配设置网络端口IP地址和工业相机IP地址匹配第一次打开CameraExplorer软件确认问题为IP地址不匹配问题打开网络连接…

C++ - 继承 | 菱形继承

之前的文章中我们简要的讲述了C中继承部分的知识&#xff0c;但是还没有完全的讲完&#xff0c;在本文中将会讲到菱形继承的问题。 复杂的菱形继承 单继承&#xff1a;一个子类只有一个直接父类时称这个继承关系为单继承。 多继承&#xff1a;一个子类有两个或以上直接父类时…

最新阿里、腾讯、华为、字节等大厂的薪资和职级对比,看看你差了多少...

互联网大厂新入职员工各职级薪资对应表(技术线)~ 最新阿里、腾讯、华为、字节跳动等大厂的薪资和职级对比 上面的表格不排除有很极端的收入情况&#xff0c;但至少能囊括一部分同职级的收入。这个表是“技术线”新入职员工的职级和薪资情况&#xff0c;非技术线(如产品、运营、…