大模型工程师学习日记(十五):Hugging Face 模型微调训练(基于 BERT 的中文评价情感分析)

1. datasets 库核心方法

1.1. 列出数据集 使用 d atasets 库,你可以轻松列出所有 Hugging Face 平台上的数据集:

from datasets import list_datasets
 # 列出所有数据集
all_datasets = list_datasets()
 print(all_datasets)

1.2. 加载数据集 你可以通过 load_dataset 方法加载任何数据集:

from datasets import load_dataset
 # 加载GLUE数据集
dataset = load_dataset("glue", "mrpc")
 print(dataset)

1.3. 加载磁盘数据 你可以加载本地磁盘上的数据:

from datasets import load_from_disk
 # 从本地磁盘加载数据集
dataset = load_from_disk("./my_dataset")
 print(dataset)

2. 分词工具与文字编码 2.1. 加载字典和分词工具 你可以使用 AutoTokenizer 自动加载分词工具:

from transformers import AutoTokenizer
 # 加载中文BERT模型的分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

2.1. 批量编码句子 使用分词器,你可以批量对文本进行编码:

# 批量编码句子
sentences = ["我爱自然语言处理", "Hugging Face 很强大"]
 encoded_inputs = tokenizer(sentences, padding=True, truncation=True, 
return_tensors="pt")
 print(encoded_inputs)

3. 模型微调的基本概念与流程

微调是指在预训练模型的基础上,通过进一步的训练来适应特定的下游任务。BERT 模型通过预训练来 学习语言的通用模式,然后通过微调来适应特定任务,如情感分析、命名实体识别等。微调过程中,通 常冻结 BERT 的预训练层,只训练与下游任务相关的层。本课件将介绍如何使用 BERT 模型进行情感分 析任务的微调训练。

4. 加载数据集 情感分析任务的数据通常包括文本及其对应的情感标签。使用 Hugging Face 的 datasets 库可以轻松地 加载和处理数据集

from datasets import load_dataset
 # 加载数据集
dataset = load_dataset('csv', data_files="data/ChnSentiCorp.csv")
 # 查看数据集信息
print(dataset)

4.1 数据集格式

Hugging Face 的 datasets 库支持多种数据集格式,如 CSV、JSON、TFRecord 等。在本案例中,使用 CSV 格式,CSV 文件应包含两列:一列是文本数据,另一列是情感标签。

4.2 数据集信息

加载数据集后,可以查看数据集的基本信息,如数据集大小、字段名称等。这有助于我们了解数据的分 布情况,并在后续步骤中进行适当的处理。

5. 制作 Dataset 加载数据集后,需要对其进行处理以适应模型的输入格式。这包括数据清洗、格式转换等操作。

加载数据集后,需要对其进行处理以适应模型的输入格式。这包括数据清洗、格式转换等操作。
from datasets import Dataset
 # 制作 Dataset
 dataset = Dataset.from_dict({
 'text': ['位置尚可,但距离海边的位置比预期的要差的多', '5月8日付款成功,当当网显示5月10
日发货,可是至今还没看到货物,也没收到任何通知,简不知怎么说好!!!', '整体来说,本书还是不错
的。至少在书中描述了许多现实中存在的司法系统方面的问题,这是值得每个法律工作者去思考的。尤其是让
那些涉世不深的想加入到律师队伍中的年青人,看到了社会特别是中国司法界真实的一面。缺点是:书中引用
了大量的法律条文和司法解释,对于已经是律师或有一定工作经验的法律工作者来说有点多余,而且所占的篇
幅不少,有凑字数的嫌疑。整体来说还是不错的。不要对一本书提太高的要求。'],
 'label': [0, 1, 1]  # 0 表示负向评价,1 表示正向评价
})
 # 查看数据集信息
print(dataset)

5.1 数据集字段

在制作 Dataset 时,需定义数据集的字段。在本案例中,定义了两个字段: text (文本)和 label (情感标签)。每个字段都需要与模型的输入和输出匹配。

5.2 数据集信息

制作 Dataset 后,可以通过 dataset.info 等方法查看其大小、字段名称等信息,以确保数据集的正确 性和完整性。

6. vocab 字典操作

在微调 BERT 模型之前,需要将模型的词汇表(vocab)与数据集中的文本匹配。这一步骤确保输入的 文本能够被正确转换为模型的输入格式。

from transformers import BertTokenizer
 # 加载 BERT 模型的 vocab 字典
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
 # 将数据集中的文本转换为 BERT 模型所需的输入格式
dataset = dataset.map(lambda x: tokenizer(x['text'], return_tensors="pt"), 
batched=True)
 # 查看数据集信息
print(dataset)

6.1 词汇表(vocab)

BERT 模型使用词汇表(vocab)将文本转换为模型可以理解的输入格式。词汇表包含所有模型已知的单 词及其对应的索引。确保数据集中的所有文本都能找到对应的词汇索引是至关重要的。

6.2 文本转换

使用 tokenizer 将文本分割成词汇表中的单词,并转换为相应的索引。此步骤需要确保文本长度、特殊 字符处理等都与 BERT 模型的预训练设置相一致。

7. 下游任务模型设计

在微调 BERT 模型之前,需要设计一个适应情感分析任务的下游模型结构。通常包括一个或多个全连接 层,用于将 BERT 输出的特征向量转换为分类结果。

from transformers import BertModel
 import torch.nn as nn
 class SentimentAnalysisModel(nn.Module):
 def __init__(self):
 super().__init__()
 self.bert = BertModel.from_pretrained('bert-base-chinese')
 self.drop_out = nn.Dropout(0.3)
 self.linear = nn.Linear(768, 2)  # 假设情感分类为二分类
def forward(self, input_ids, attention_mask):
 _, pooled_output = self.bert(
 input_ids=input_ids,
 attention_mask=attention_mask,
 return_dict=False
 )
 output = self.drop_out(pooled_output)
 return self.linear(output)

7.1 模型结构

下游任务模型通常包括以下几个部分: BERT 模型:用于生成文本的上下文特征向量。 Dropout 层:用于防止过拟合,通过随机丢弃一部分神经元来提高模型的泛化能力。 全连接层:用于将 BERT 的输出特征向量映射到具体的分类任务上。

7.2 模型初始化

使用 接层。初始化时,需要根据下游任务的需求,定义合适的输出维度。 BertModel.from_pretrained() 方法加载预训练的 BERT 模型,同时也可以初始化自定义的全连

8. 自定义模型训练

模型设计完成后,进入训练阶段。通过数据加载器(DataLoader)高效地批量处理数据,并使用优化器 更新模型参数

from torch.utils.data import DataLoader
 from transformers import AdamW
 # 实例化 DataLoader
 data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
 # 初始化模型和优化器
model = SentimentAnalysisModel()
 optimizer = AdamW(model.parameters(), lr=5e-5)
 # 训练循环
for epoch in range(3):  # 假设训练 3 个 epoch
 model.train()
 for batch in data_loader:
 optimizer.zero_grad()
 outputs = model(input_ids=batch['input_ids'], 
attention_mask=batch['attention_mask'])
 loss = nn.CrossEntropyLoss()(outputs, batch['labels'])
 loss.backward()
 optimizer.step()

8.1 数据加载

使用 DataLoader 实现批量数据加载。 DataLoader 自动处理数据的批处理和随机打乱,确保训练的高 效性和数据的多样性。

8.2 优化器

AdamW 是一种适用于 BERT 模型的优化器,结合了 Adam 和权重衰减的特点,能够有效地防止过拟合。

8.3 训练循环

训练循环包含前向传播(forward pass)、损失计算(loss calculation)、反向传播(backward pass)、参数更新(parameter update)等步骤。每个 epoch 都会对整个数据集进行一次遍历,更新 模型参数。通常训练过程中会跟踪损失值的变化,以判断模型的收敛情况。

9. 最终效果评估与测试

在模型训练完成后,加载模型训练权重测试其效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/982350.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

高考數學。。。

2024上 具体来说,直线的参数方程可以写为: x1t y−t z1t 二、简答题(本大题共5小题,每小题7分,共35分。) 12.数学学习评价不仅要关注结果评价,也要关注过程评价。简要说明过程评价应关注哪几个方面。…

Seurat - Guided Clustering Tutorial官方文档学习及复现

由于本人没有使用过Seurat4.0,而是直接使用的最新版。所以本文都是基于Seurat5.2.0(截止2025/3/6)来进行撰写。 参考的官方教程来进行学习(上图中的 Guided tutorial-2.700 PBMCs),肯定没有官方文档那么全面…

(undone) MIT6.S081 Lec14 File systems 学习笔记

url: https://mit-public-courses-cn-translatio.gitbook.io/mit6-s081/lec14-file-systems-frans Why Interesting 从一个问题开始:既然你每天都使用了文件系统,XV6的文件系统与你正在使用的文件系统有什么区别。接下来我会点名: 学生回答…

【C++进阶学习】第一讲——继承(下)---深入挖掘继承的奥秘

目录 1.隐藏 1.1隐藏的概念 1.2隐藏的两种方式 2.继承与友元 3、继承与静态成员 4.单继承和多继承 4.1单继承 4.2多继承 5.菱形继承 问题1:冗余性 问题2:二义性 6.虚拟继承 7.总结 1.隐藏 1.1隐藏的概念 在 C 中,继承是一种机制…

UI自动化:利用百度ocr识别解决图形验证码登录问题

相信大家在做自动化测试过程中都遇到过图形验证码的问题,最近我也是遇到了,网上搜了很多方法,最简单的方法无非就是去掉图形验证码或者设置一个万能验证码,但是这个都需要开发来帮忙解决,对于我们这种自学的人来说就不…

C/C++蓝桥杯算法真题打卡(Day1)

一、LCR 018. 验证回文串 - 力扣(LeetCode) 算法代码: class Solution { public:bool isPalindrome(string s) {int n s.size();// 处理一下s为空字符的情况if (n 0) {return true; // 修正拼写错误}// 定义左右指针遍历字符串int left …

蓝桥杯备考:动态规划路径类DP之矩阵的最小路径和

如题,要求左上角到右下角的最短路径,我们还是老样子按顺序做 step1:确定状态表示 f[i][j]表示(1,1)到(i,j)的最短距离 step2 :推导状态表达方程 step3:确定填表顺序,应该是从上到下,从左到右 step4:初始化 step5 找结果&#…

18类创新平台培育入库!长沙经开区2025年各类科技创新平台培育申报流程时间材料及申报条件

长沙经开区打算申报企业研发中心、技术创新中心、工程技术研究中心、新型研发机构、重点实验室、概念验证中心和中试平台、工程研究中心、企业技术中心、制造业创新中心、工业设计中心等创新平台的可先备案培育入库,2025年各类平台的认定将从培育库中优先推荐&#…

CyberDefenders----WebStrike Lab

WebStrike Lab 实验室链接 简介: 公司网络服务器上发现了一个可疑文件,在内联网中发出警报。开发团队标记了异常,怀疑存在潜在的恶意活动。为了解决这个问题,网络团队捕获了关键网络流量并准备了一个 PCAP 文件以供审查。您的任务是分析提供的 PCAP 文件以发现文件的出现…

【python】gunicorn配置

起因:因为cpu利用率低导致我去缩容,虽然缩容之后cpu利用率上升维持在60%左右,但是程序响应耗时增加了。 解释:因为cpu干这件活本身不累,但在干这件活的时候不能去干其他事情,导致并发的请求不能及时响应&am…

SSE vs WebSocket:AI 驱动的实时通信抉择

引言 近年来,基于 Transformer 的大模型推动了 AI 产业的飞速发展,同时带来了新的技术挑战: 流式传输 vs 批量返回:大模型生成的长文本若需一次性返回,会显著影响用户体验,实时推送成为必需。语音交互需求:语音助手要求毫秒级响应,而非等待用户完整输入后再返回结果。…

基于海思soc的智能产品开发(芯片sdk和linux开发关系)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 随着国产化芯片的推进,在soc领域,越来越多的项目使用国产soc芯片。这些soc芯片,通常来说运行的os不是linux&…

各种DCC软件使用Datasmith导入UE教程

3Dmax: 先安装插件 https://www.unrealengine.com/zh-CN/datasmith/plugins 左上角导出即可 虚幻中勾选3个插件,重启引擎 左上角选择文件导入即可 Blender导入Datasmith进UE 需要两个插件, 文章最下方链接进去下载安装即可 一样的,直接导出,然后UE导入即可 C4D 直接保存成…

VMware Fusion虚拟机Mac版安装Ubuntu系统

介绍 Ubuntu操作系统是一个基于Linux内核的桌面和服务器操作系统。它由Canonical公司开发和维护,是最受欢迎的Linux操作系统之一。Ubuntu操作系统以简洁、直观和易用为设计原则,提供了友好的图形界面,支持多种语言和自定义设置,用…

发行思考:全球热销榜的频繁变动

几点杂感: 1、单机游戏销量与在线人数的衰退是剧烈的,有明显的周期性,而在线游戏则稳定很多。 如去年的某明星游戏,最高200多万在线,如今在线人数是48名,3万多。 而近期热门的是MH,在线人数8…

AI赋能科研绘图与数据可视化高级应用

在科研成果竞争日益激烈的当下,「一图胜千言」已成为高水平SCI期刊的硬性门槛——数据显示很多情况的拒稿与图表质量直接相关。科研人员普遍面临的工具效率低、设计规范缺失、多维数据呈现难等痛点,因此科研绘图已成为成果撰写中的至关重要的一个环节&am…

thingsboard edge 在windows 环境下的配置

按照官方文档:Installing ThingsBoard Edge on Windows | ThingsBoard Edge,配置好java环境和PostgreSQL。 下载对应的windows 环境下的tb-edge安装包。下载附件 接下来操作具体如下 步骤1,需要先在thingsboard 服务上开启edge 权限 步骤2…

最硬核DNS详解

1、是什么 DNS(域名系统)是互联网的一项服务,它作为将域名和IP地址相互映射的一个分布式数据库,能够使人更方便地访问互联网。DNS协议基于UDP协议,使用端口号53。 2、域名服务器类型 域名服务器在DNS体系中扮演着不…

CentOS 7 安装Nginx-1.26.3

无论安装啥工具、首先认准了就是官网。Nginx Nginx官网下载安装包 Windows下载: http://nginx.org/download/nginx-1.26.3.zipLinxu下载 wget http://nginx.org/download/nginx-1.26.3.tar.gzLinux安装Nginx-1.26.3 安装之前先安装Nginx依赖包、自行选择 yum -y i…

基于国产芯片的AI引擎技术,打造更安全的算力生态 | 京东零售技术实践

近年来,随着国产AI芯片的日益崛起,基于国产AI芯片的模型适配、性能优化以及应用落地是国产AI应用的一道重要关卡。如何在复杂的京东零售业务场景下更好地使用国产AI芯片,并保障算力安全,是目前亟需解决的问题。对此,京…