基于transformers框架实践Bert系列6-完形填空

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解完形填空应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)(注意:这是关键输出,本次任务就需要获取该值,可以取出那个被mask掉的token,获取其前几个,取score最高的(当然也可以使用top_k或者top_p方式获取一定随机性)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:ChnSentiCorp(该数据集本身是做情感分类,但是我们只需要取其text部分即可)
2)模型权重使用:bert-base-chinese

2.3 任务说明

完形填空其实就是在一段文字中mask掉几个字,让模型能够自动填充字。这里本身就是bert模型做预训练是所做的事情之一,因此就是让数据给模型做训练的过程。

2.4 实现关键

1)数据集结构是一个带有text和label两列的数据,我们只需要获取到text部分即可。
在这里插入图片描述
2)随机mask掉部分数据,这个本身也是bert的训练过程,因此在transforms框架中DataCollatorForLanguageModeling已经实现了,你也可以自己实现随机mask掉你的数据进行训练

3 关键代码

3.1 数据集处理

数据集不需要做过多处理,只需要将text部分进行tokenizer,并制定max_length和truncation即可

def process_function(datas):
    tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)
    return tokenized_datas
new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2 模型加载

model = BertForMaskedLM.from_pretrained(model_path)

注意:这里使用的是transformers中的BertForMaskedLM,该类对bert模型进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层。另外使用AutoModelForMaskedLM也可以,其实AutoModel最终返回的也是BertForMaskedLM,它是根据你config中的model_type去匹配的。
这里列一下BertForMaskedLM的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加增加了BertOnlyMLMHead,BertOnlyMLMHead其实就是一个二层神经网络,一层是BertPredictionHeadTransform(包括linear+geluAct+ln),一层是decoder(hidden_size*vocab_size大小的linear)。
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
# 将输出结果outputs取第一个返回值,也就是last_hidden_state
sequence_output = outputs[0]
# 将last_hidden_state输入到cls层中,获得最终结果(预测的score和词)
prediction_scores = self.cls(sequence_output)

3.3 自动并随机mask数据

关键代码在于DataCollatorForLanguageModeling,该类会实现自动mask。参考torch_mask_tokens方法。

trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=new_datasets["train"],
                  data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),
                  )

4 整体代码

"""
基于BERT做完形填空
1)数据集来自:ChnSentiCorp
2)模型权重使用:bert-base-chinese
"""
# step 1 引入数据库
from datasets import DatasetDict
from transformers import TrainingArguments, Trainer, BertTokenizerFast, BertForMaskedLM, DataCollatorForLanguageModeling, pipeline

model_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/ChnSentiCorp"

# step 2 数据集处理
datasets = DatasetDict.load_from_disk(data_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)


def process_function(datas):
    tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)
    return tokenized_datas


new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)


# step 3 加载模型
model = BertForMaskedLM.from_pretrained(model_path)


# step 4 创建TrainingArguments
# 原先train是9600条数据,batch_size=32,因此每个epoch的step=300
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=32,  # 训练时的batch_size
                               num_train_epochs=1,              # 训练轮数
                               logging_steps=30,                # log 打印的频率
                               )


# step 5 创建Trainer
trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=new_datasets["train"],
                  # 自动MASK关键所在,通过DataCollatorForLanguageModeling实现自动MASK数据
                  data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),
                  )


# Step 6 模型训练
trainer.train()

# step 7 模型评估
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
str = datasets["test"][3]["text"]
str = str.replace("方便","[MASK][MASK]")
results = pipe(str)
# results[0][0]["token_str"]
print(results[0][0]["token_str"]+results[1][0]["token_str"])

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/641850.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++入门的基础知识

c入门 C是在C的基础之上,容纳进去了面向对象编程思想,并增加了许多有用的库,以及编程范式等。熟悉C语言之后,对C学习有一定的帮助,本章节主要目标: 补充C语言语法的不足,以及C是如何对C语言设计…

ClickHouse vs. Elasticsearch: 计数聚合的工作原理

本文字数:7875;估计阅读时间:20 分钟 审校:庄晓东(魏庄) 介绍 在另一篇博客文章中,我们对 ClickHouse 和 Elasticsearch 在大规模数据分析和可观测性用例中的性能进行了比较,特别是对…

k8s-helloword部署一个应用

k8s-helloword部署一个应用 快速部署一个pod命令 部署一个名为 test-nginx Pod 方式一:使用 kubectl run kubectl run test-nginx --imagenginx然后使用 kubectl get pod 查看,kubectl get pod 是查看默认名称空间下的Pod 如果想要跟详细的查看这个…

HTML静态网页成品作业(HTML+CSS)——宠物狗介绍网页(3个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有3个页面。 二、作品演示 三、代…

如何禁止U盘拷贝文件|禁止U盘使用的软件有哪些

禁止U盘拷贝文件的方法有很多,比如使用注册表、组策略编辑器等,但这些方法都适合个人,不适合企业,因为企业需要对下属多台电脑进行远程管控,需要方便、省时、省力的方法。目前来说,最好的方法就是使用第三方…

RDP方式连接服务器上传文件方法

随笔 目录 1. RDP 连接服务器 2. 为避免rdp 访问界面文字不清晰 3. 本地上传文件到服务器 1. RDP 连接服务器 # mstsc 连接服务器step1: 输入mstscstep2: 输入 IP, username, passwd 2. 为避免rdp 访问界面文字不清晰 解决方法: 3. 本地上传文件到服务器 step…

Java进阶学习笔记13——抽象类

认识抽象类: 当我们在做子类共性功能抽取的时候,有些方法在父类中并没有具体的体现,这个时候就需要抽象类了。在Java中,一个没有方法体的方法应该定义为抽象方法,而类中如果有抽象方法,该类就定义为抽象类…

ASP+ACCESS基于WEB网上留言板

摘要 本文概述了ACCESS数据库及其相关的一些知识,着重论述ACCESS数据库和ASP的中间技术,构建一个简单的留言板。具体的实现是构造一个留言板系统,能很方便的和同学沟通和交流。留言板具有功能强大、使用方便的特点。用户以个人的身份进入&am…

jenkins+sonarqube部署与配置过程

1、部署jenkins(本文不做说明) 2、部署sonarqube(docker-compose) version: "2.1"services:sonarqube:image: sonarqube:9.9.4-communitycontainer_name: sonarqubedepends_on:- dbports:- 9000:9000networks:- sonarnetenvironment:SONARQU…

集合、Collection接口特点和常用方法

1、集合介绍 对于保存多个数据使用的是数组,那么数组有不足的地方。比如, 长度开始时必须指定,而且一旦制定,不能更改。 保存的必须为同一类型的元素。 使用数组进行增加/删除元素的示意代码,也就是比较麻烦。 为…

深入理解CPU缓存一致性

存储体系结构 速度快的存储硬件成本高、容量小,速度慢的成本低、容量大。为了权衡成本和速度,计算机存储分了很多层次,有寄存器、L1 cache、L2 cache、L3 cache、主存(内存)和硬盘等。 根据程序的空间局部性和时间局…

【qt】标准型模型 下

标准型模型 一.前言二.预览数据1.获取表头2.获取数据项 三.保存文件1.文件对话框获取保存文件名2.用文件名初始化文件对象3.打开文件对象4.用文件对象初始化文本流5.写入数据 四.格式1.居右2.居中3.居左4.粗体 五.模型的信号1.解决粗体action问题2.状态栏显示信息 六.总结 一.前…

visual studio 2022 ssh 主机密钥算法失败问题解决

 Solution - aengusjiang 问题: I follow the document, then check sshd_config, uncomment“HostKey /etc/ssh/ssh_host_ecdsa_key” maybe need add the key algorithms: #HostKeyAlgorithms ssh-ed25519[Redacted][Redacted]rsa-sha2-256,rsa-sha2-512 Ho…

【C++初阶】vector

✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ 🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿&#x1…

交叉熵损失函数计算过程(tensorflow)

交叉熵损失函数通常用于多类分类损失函数计算。计算公式如下: P为真实值,Q为预测值。 使用tensorflow计算 import tensorflow as tf import keras# 创建一个示例数据集 # 假设有3个样本,每个样本有4个特征,共2个类别 # 目标标签…

【退役之重学Java】关于B+树索引

一、为什么使用索引 一条数据可能有很多字段,数据量比较大,挨个查询效率极差故使用索引,提高查询性能和加快数据检索速度。同时还可以帮助优化排序、分组和连接操作,提高数据库系统的整体性能和响应速度。 二、为什么要用 B 树 B树…

【全开源】点餐小程序系统源码(ThinkPHP+FastAdmin+UniApp)

基于ThinkPHPFastAdminUniApp开发的点餐微信小程序,类似肯德基,麦当劳,喜茶等小程序多店铺模式,支持子商户模式,提供全部前后台无加密源代码和数据库,支持私有化部署。 革新餐饮行业的智慧点餐解决方案 一…

设计模式—23种设计模式重点 表格梳理

设计模式的核心在于提供了相关的问题的解决方案,使得人们可以更加简单方便的复用成功的设计和体系结构。 按照设计模式的目的可以分为三大类。创建型模式与对象的创建有关;结构型模式处理类或对象的组合;行为型模式对类或对象怎样交互和怎样…

视频怎么转换成二维码图片?视频做成二维码播放的方法

怎样在电脑上制作可以播放视频的二维码呢?很多日常生活中,很多的场景或者物品都会有自己的二维码,其他人通过扫码就可以获取对应的内容。有很多场景下会把视频转换二维码,通过扫码在手机上查看视频内容,比如产品介绍、…

408数据结构-图的基本概念 自学知识点整理

*第六章个人感觉是最难的,请各位抓稳扶手,系好安全带。 图的定义 通俗来讲,一个图由一些点和连接这些点的若干边组成,边的两头必须都有顶点,否则不是图。 注:G: Graph; V: Vertex; …