HuggingFace-利用BERT预训练模型实现中文情感分类(下游任务)

准备数据集

使用编码工具

首先需要加载编码工具,编码工具可以将抽象的文字转成数字,便于神经网络后续的处理,其代码如下:

# 定义数据集
from transformers import BertTokenizer, BertModel, AdamW
# 加载tokenizer
token = BertTokenizer.from_pretrained('bert-base-chinese')
print('token', token)

out:
token BertTokenizer(name_or_path=‘bert-base-chinese’, vocab_size=21128, model_max_length=512, is_fast=False, padding_side=‘right’, truncation_side=‘right’, special_tokens={‘unk_token’: ‘[UNK]’, ‘sep_token’: ‘[SEP]’, ‘pad_token’: ‘[PAD]’, ‘cls_token’: ‘[CLS]’, ‘mask_token’: ‘[MASK]’}, clean_up_tokenization_spaces=True), added_tokens_decoder={
0: AddedToken(“[PAD]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
100: AddedToken(“[UNK]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
101: AddedToken(“[CLS]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
102: AddedToken(“[SEP]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
103: AddedToken(“[MASK]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

由上可知bert-base-chinese模型的字典中共有21128个词,编码器编码句子的最大长度为512个词,并且能够看到bert-base-chinese模型所使用的一些特殊符号,例如SEK,PAD等。

这里使用的编码工具是bert-base-chinese,编码工具和预训练模型往往是成对使用的,后续将使用同名的预训练语言模型作为backbone。

编码工具的试算

加载完成编码工具之后可以进行一次试算,观察编码工具的输入和输出,代码如下:


    data = token.batch_encode_plus(batch_text_or_text_pairs=['关注博主,不迷路。','俺要带你上高速。'], truncation=True,
                                   padding='max_length',
                                   max_length=12,
                                   return_tensors='pt',
                                   return_length=True)
# 查看编码输出
for k,v in out.items():
	print(k,v.shape)
# 把编码还原成句子
print(token.decode(out['input_ids'][0]))

out:
input_ids torch.Size([2, 17])
token_type_ids torch.Size([2, 17])
length torch.Size([2])
attention_mask torch.Size([2, 17])
[CLS] 关 注 博 主 , 不 迷 路 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[CLS] 俺 要 带 你 上 高 速 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

编码工具的参数说明

对于编码工具的使用,特别是参数值的含义可以参考下面的两段代码:

"""使用简单的编码"""
# 编码两个句子
out = tokenizer.encode(
    # 句子1
    text = sents[0],
    text_pair = sents[1],

    # 当句子长度大于max_length时进行截断
    truncation=True,

    # 一律补充pad到max_length长度
    padding = 'max_length',
    add_special_tokens = True,
    # 许多大模型的阶段也是使用512作为最终的max_length
    max_length=30,
    return_tensors=None,
)
"""增强的编码函数"""
# 增强的编码函数
out = tokenizer.encode_plus(
    text = sents[0],
    text_pair = sents[1],

    #当句子长度大于max_length时进行截断操作
    truncation = True,

    #一律补零到max_length长度
    padding='max_length',
    max_length=30,
    add_special_tokens=True,

    #可以取值tf,pt,np,默认返回list--->tensorflow,pytorch,numpy
    return_tensors=None,

    #返回token_type_ids
    return_token_type_ids=True,

    #返回attention_mask
    return_attention_mask=True,

    #返回special_tokens_mask 特殊符号标识
    return_special_tokens_mask=True,

    #返回offset_mapping标识每个词的起始和结束位置---》这个参数只能BertTokenizerFast使用
    #return_offsets_mapping=True,

    #返回length 标识长度
    return_length=True
)

从上面的代码中的参数max_length=500可以看出经过编码后的句子的长度一定是12个词的长度。如果源句子超出则会进行截断,如果源句子不足则会进行填充PAD,其运行结果如下:

{'input_ids': tensor([[ 101, 1068, 3800, 1300,  712, 8024,  679, 6837, 6662,  511,  102,    0],
        [ 101,  939, 6206, 2372,  872,  677, 7770, 6862,  511,  102,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'length': tensor([11, 10]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}
input_ids torch.Size([2, 12])
token_type_ids torch.Size([2, 12])
length torch.Size([2])
attention_mask torch.Size([2, 12])
[CLS] 关 注 博 主 , 不 迷 路 。 [SEP] [PAD]
[CLS] 俺 要 带 你 上 高 速 。 [SEP] [PAD] [PAD]

在这里插入图片描述
编码工具首先是对一条完整的句子进行了tokenizer,把句子分成了一个个token。同时,对于不同的编码工具,分词的结果也不一定一致。这里采用的bert-base-chinese编码工具中,它是以字为词,即把每个字当做一个词进行处理。
这些编码的结果对于预训练模型的计算十分重要,在后面将会使用编码器将所有的句子进行编码,用于输入到预训练模型中进行计算。

定义数据集

这里使用的数据集为ChnSentiCorp数据集,Dataset类如下:

# import torch
from datasets import load_dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = load_dataset(path='lansinuote/ChnSentiCorp', split=split)
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']
        return text, label
dataset = Dataset('train')
print(len(dataset))
print(dataset[0])

在上述代码中加载了ChnSentiCorp数据集,并使用Pytorch中的Dataset对象进行封装,利用__getitem__()得到每一条数据,每条数据中包含textlabels两个字段,最后初始化训练数据集并查看训练数据集的长度和第一条数据样例。

out: 9600
('选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 1)

由上面的输出可知训练数据集包括9600条数据,每条数据包含一条评论文本和一个标识,表明这一条评论是好评还是差评。注意:这里的数据集是单纯的原始文本数据,并没有进行编码。

定义计算设备

这里将使用CUDA作为计算设备,这样可以极大加速模型的训练和测试的过程,代码如下:

device = 'cpu'
if torch.cuda.is_available():
  device = 'CUDA'
print('选用的计算设备:',device)

在该段代码中默认使用CPU进行计算,如果存在CUDA的话则选用CUDA作为计算设备。

定义数据整理函数

正如上面所述的那样,ChnSentiCorp数据集中的每一条数据是抽象的文本数据,并没有进行任何的编码操作,而预训练模型是需要编码之后的数据才能进行计算,所以需要一个将文本句子转成编码的过程。
另外,在训练模型时数据集往往很大,如果一条一条地处理则效率会太低,在现实中我们往往一批一批地处理数据,这样可以快速地处理数据集,同时从梯度下降的角度来讲,批数据的梯度方差相较于一条条数据的梯度小,可以让模型更加稳定地更新参数

# 定义批处理函数
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    # 编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents, truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)
    # input_ids:编码之后的数字
    # attention_masks:补0的位置都是0,其他位置都是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)
    # print(data['length'],data['length'].max())
    return input_ids, attention_mask, token_type_ids, labels

在这段代码中,参数data表示一批数据,取出其中的句子和标识,它们都是list类型,在上述代码中会将两者分别赋给sentslabels,然后是使用编码器编码该批句子,在参数中将编码后的结果指定为固定的500个词的大小,与上面的例子同理超出500个词的部分会被截断(这里是通过truncation=True控制),同时少于500个词的句子会被[PAD]填充(这里主要是通过 padding='max_length'控制)。另外,在编码过程中通过 return_tensors='pt'参数,将编码后的结果返回torch中的tensor类型,免去了后面转换数据格式的麻烦(也就是说后面可以通过数据格式转换可以将‘tf’转成‘pt’格式)。
之后取出编码后的结果,并将labels也转成Pytorch中的Tensor格式,再把它们移动到之前已经定义好的计算设备device上,最后把这些数据全部返回,到这里数据整理函数的工作已经全部完成。

数据处理函数的例子

上述定义了数据处理函数,为了实验其效果也可使用下面的例子:(本用例已加狗头保命~)在这里插入图片描述

data = [
    ('选择新大的原因当然不是为了延毕。',1),
    ('笔记本的内存确实小。',0),
    ('宿舍没有风扇。其他都很好。',1),
    ('今天才知道这本书还有第10000卷,真是太屌了。',1),
    ('机器的背面似乎被撕了张什么标签,残胶还在。',0),
    ('为什么有人在校园里尖叫,是疯了还是giao。',0)
]

# 狗头保命版试算
input_ids,attention_mask,token_type_ids,labels = collate_fn(data)
print('input_ids.shape',input_ids.shape)
print('attention_mask.shape',attention_mask.shape)
print('token_type_ids.shape',token_type_ids.shape)
print('labels:',labels)

在该段代码中首先是模拟了一批数据,这批数据中包含4个句子,通过将该批数据输入到整理函数以后,运行结果如下:

input_ids.shape torch.Size([6, 500])
attention_mask.shape torch.Size([6, 500])
token_type_ids.shape torch.Size([6, 500])
labels: tensor([1, 0, 1, 1, 0, 0])

可见编码之后的结果都是确定的500个词的长度,并且每个结果都会被移动到可用的计算设备上,这样可以方便后续的计算。

定义数据加载器

上述代码中定义了数据集和数据整理函数以后,下面我们将定义一个数据加载器DataLoader,它可以使用数据整理函数来完成成批地处理数据集中的数据,通俗来讲每一批的数据我们可以称为batch

# 定义数据加载器并查看数据样例
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

对于上述代码,我们使用了Pytorch提供的工具类定义数据集加载器,其参数说明可参考下图:
在这里插入图片描述

数据加载器的例子

为了更好地使用数据加载器,这里我们查看一批数据样例,将这批数据输入到数据加载器中,可以发现其结果会与数据整理函数的运行结果相似,只不过是句子的数量增多了。

上述代码依次打印了加载器中批次数目、加载器中输入数据的input_ids和掩蔽注意力的形状
attention_mask_shape、词元的ids类型形状token_type_ids_shape以及标签labels

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break
print(len(loader))
print('input_ids', input_ids)
print('attention_mask_shape', attention_mask.shape)
print('token_type_ids_shape', token_type_ids.shape)
print('labels', labels)
  1. input_ids 就是编码后的词
  2. token_type_ids 第一个句子和特殊符号的位置是0,第二个句子的位置是1
  3. attention_mask pad的位置是0,其他位置都是1
  4. special_tokens_mask 特殊符号的位置是1,其他位置都是0

定义模型

因为我们是要利用Huggingface的预训练语言模型,所以需要做两件事情:加载预训练模型PLM以及定义下游任务模型。

【待更新~】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/177001.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

关于AssetBundle禁用TypeTree之后的一些可序列化的问题

1)关于AssetBundle禁用TypeTree之后的一些可序列化的问题 2)启动Unity导入变动的资源时,Singleton ScriptableObject 加载不到 3)Xcode15构建Unity 2022.3的Xcode工程,报错没有兼容的iPhone SDK 这是第361篇UWA技术知识…

NLP学习

参考:NLP发展之路I - 从词袋模型到Transformer - 知乎 (zhihu.com) NLP大致的发展历史。从最开始的词袋模型,到RNN,到Transformers和BERT,再到ChatGPT,NLP经历了一段不断精进的发展道路。数据驱动和不断完善的端到端的…

Spring+Mybatis解析

源码执行流程 通过MapperScan导入MapperScannerRegistrar类MapperScannerRegistrar类实现了ImportBeanDefinitionRegistrar接口,Spring启动会调MapperScannerRegistrar类中的registerBeanDefinitions方法在registerBeanDefinitions方法中注册一个MapperScannerConf…

盖雅绩效应用通过SAP认证并斩获创新方案奖

近日,在「不啻微芒 造炬成阳」为主题的SAP合作伙伴创新大赛上,盖雅工场「G移动绩效创新方案」荣获创新解决方案奖。该方案核心是一款基于SAP SuccessFactors套件及SAP BTP平台的扩展应用,主要针对一线人员绩效管理场景,借助简洁的…

国民新旅游时代,OTA们如何制胜新周期?

文 | 螳螂观察(TanglangFin) 作者 | 图霖 消费全面复苏的大背景下,旅游业正迎来预期中的拐点。 一个显著表现是,旅游消费正在从可选消费转化成必选消费。 国内消费者旅游需求的不降反增,就是最好的印证。 同程研究…

哈希表之开散列的实现

回顾与引出 我们在上一节用闭散列的开放定址法实现了哈希表。不难看出这种方法有明显的缺点:一旦发生哈希冲突,所有的冲突连在一起,容易产生数据“堆积”,即:不同 关键码占据了可利用的空位置,使得寻找某关…

秋招如何准备?有什么建议?

秋招,是毕业生最好的求职渠道,没有之一。尽管还有春招,社招......都不如秋招重要,因为秋招的机会更多..... 如何准备秋招? 1、简历很重要 一个好的简历,就是敲门砖,这是你跟企业HR的第一次亲…

如何使用SD-WAN提升物流供应链网络效率

案例背景 本次分享的物流供应链企业是一家国际性的大型企业,专注于提供全球范围内的物流和供应链解决方案。案例用户在不同国家和地区均设有多个分支机构和办公地点,以支持客户需求和业务运营。 在过去,该企业用户使用传统的MPLS网络来连接各…

【grep】从html表格中快速定位某个数据

文章目录 1 背景2 参考知识2.1 grep2.2 HTML基础语言标签 3 解决方案 1 背景 在html中是一堆表格、图片、文字什么的,想从表格中提取关键词为“GJC”后对应的数字,怎么办呢? 逐个打开html文件,“ctrlF”搜一下,然后复…

直线导轨在自动锁螺丝机的作用及注意事项

直线导轨在自动锁螺丝机中具有重要作用,可以提供精确的导向,使滑块能够沿固定轨迹移动,确保螺丝准确无误地进入螺丝孔并被锁定,因此,选择高品质的直线导轨对于自动锁螺丝机的性能和精度至关重要!那么&#…

拿下!这些证书可以帮你职场晋升!(PMP/CSPM/NPDP)

PMP证书为项目管理道路打好基础,建立规划思维,整合思维,提高解决问题效率。中国也有自己的项目管理认证CSPM,与PMP相比难度较小,可用已获得的证书免考。NPDP认证拓宽视野,帮助项目经理提升技能。 01PMP为项…

常见树种(贵州省):006栎类

摘要:本专栏树种介绍图片来源于PPBC中国植物图像库(下附网址),本文整理仅做交流学习使用,同时便于查找,如有侵权请联系删除。 图片网址:PPBC中国植物图像库——最大的植物分类图片库 一、麻栎 …

【狂神说】CSS3详解

目录 CSS概述什么是CSSCSS发展史快速入门CSS的三种导入方式 2 选择器2.1 基本选择器标签选择器类选择器id选择器 2.2 层次选择器2.3 结构伪类选择器2.4 属性选择器(常用) 3 美化网页元素3.1 为什么要美化网页3.2 字体样式3.3 文本样式 视频课程见链接&am…

口碑好的猫罐头有哪些?宠物店受欢迎的5款猫罐头推荐!

快到双十二啦!铲屎官们是时候给家里猫主子囤猫罐头了。许多铲屎官看大促的各种品牌宣传,看到眼花缭乱,不知道选哪些猫罐头好,胡乱选又怕踩坑。 口碑好的猫罐头有哪些?作为一个经营宠物店7年的老板,活动期间…

c语言编程(模考2)

简答题1 从键盘输入10个数&#xff0c;统计非正数的个数&#xff0c;并且计算非正数的和 #include<stdio.h> int main() {int i,n0,sum0;int a[10];printf("请输入10个数&#xff1a;");for(i0;i<10;i){scanf("%d",&a[i]);}for(i0;i<10…

Android使用Kotlin利用Gson解析多层嵌套Json数据

文章目录 1、依赖2、解析 1、依赖 build.gradle(app)中加入 dependencies { implementation com.google.code.gson:gson:2.8.9 }2、解析 假设这是要解析Json数据 var responseStr "{"code": 200,"message": "操作成功","data&quo…

vue3 iconify 图标几种使用 并加载本地 svg 图标

iconify iconify 与 iconify/vue 使用 下载 pnpm add iconify/vue -D使用 import { Icon } from "iconify/vue";<template><Icon icon"mdi-light:home" style"color: red; font-size: 43px" /><Icon icon"mdi:home-flo…

11.6AOP

一.AOP是什么 是面向切面编程,是对某一类事情的集中处理. 二.解决的问题 三.AOP的组成 四.实现步骤 1.添加依赖(版本要对应): maven仓库链接 2.添加两个注解 3.定义切点 4.通知 5.环绕通知 五.excution表达式 六.AOP原理 1.建立在动态代理的基础上,对方法级别的拦截. 2. …

python实现鼠标实时坐标监测

python实现鼠标实时坐标监测 一、说明 使用了以下技术和库&#xff1a; tkinter&#xff1a;用于创建GUI界面。pyperclip&#xff1a;用于复制文本到剪贴板。pynput.mouse&#xff1a;用于监听鼠标事件&#xff0c;包括移动和点击。threading&#xff1a;用于创建多线程&…

深入浅出 Linux 中的 ARM IOMMU SMMU I

Linux 系统下的 SMMU 介绍 在计算机系统架构中&#xff0c;与传统的用于 CPU 访问内存的管理的 MMU 类似&#xff0c;IOMMU (Input Output Memory Management Unit) 将来自系统 I/O 设备的 DMA 请求传递到系统互连之前&#xff0c;它会先转换请求的地址&#xff0c;并对系统 I…