入门微调预训练Transformer模型

大家好,HuggingFace 为众多开源的自然语言处理(NLP)模型提供了强大的支持平台,让这些模型能够通过训练和微调来更好地服务于各种特定的应用场景。在大型语言模型(LLM)迅猛发展的今天,HuggingFace 提供的核心工具,特别是 Trainer 类,极大地优化了 NLP 模型的训练过程,开发者得以更加高效地实现模型定制和优化。

HuggingFace 的 Trainer 类是为 Transformer 模型量身打造的,不仅优化了模型的交互体验,还与 Datasets 和 Evaluate 等库实现了紧密集成,支持更高级的分布式训练,并能无缝对接 Amazon SageMaker 等基础设施服务。通过这种方式,可以更加便捷地进行模型训练和部署。

本文将通过一个实例,展示如何利用 HuggingFace 的 Trainer 类在本地环境中对 BERT 模型进行微调,以处理文本分类任务。并且重点介绍如何使用 HuggingFace 模型中心的预训练模型,而不是深入机器学习的理论基础。

 1.设置

示例将在 SageMaker Studio(https://aws.amazon.com/cn/sagemaker/studio/) 环境下进行操作,利用 ml.g4dn.12xlarge 实例搭载的 conda_python3 内核来完成任务。需要提醒的是,可以选择使用更小型的实例,但这可能会影响训练速度,具体取决于可用的 CPU/工作进程的数量。

使用 HuggingFace 数据集库下载数据集。

import datasets
from datasets import load_dataset

这里指定了训练数据集和评估数据集,会在训练循环中进行使用。

train_dataset = load_dataset("imdb", split="train")
test_dataset = load_dataset("imdb", split="test")
test_subset = test_dataset.select(range(100)) # 取数据的一个子集进行评估

对于任何文本数据,必须指定一个分词器,将数据预处理成模型可以理解的格式。在这种情况下,这里指定了我们使用的 BERT 模型的 HuggingFace 模型中心 ID。

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# 分词文本数据
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

然后使用内置的 map 函数处理我们的训练和评估数据集。

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_test = test_subset.map(tokenize_function, batched=True)

图片

预处理后的数据

2.微调 BERT

数据准备就绪后,利用先前选定的模型ID来加载BERT模型。需要注意的是,针对文本分类任务,还定义了标签的总数。在此案例中设定了两个标签,分别用0和1来表示,0代表负面,1代表正面。

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

接下来在训练循环中,需要定义一个TrainingArguments对象。在这个对象中,可以设置训练过程中的各种参数,比如训练周期的数量、分布式训练的策略等。

from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch", num_train_epochs=1)

对于评估,使用 Evaluate 库内置的评估函数。

import numpy as np
import evaluate
metric = evaluate.load("accuracy")

# 评估函数
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

然后将 TrainingArguments、分词数据集和评估指标函数传递给 Trainer 对象。可以使用 train 方法启动训练运行,这将需要大约 10-15 分钟的时间,具体取决于现有硬件。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test, # 使用测试作为评估
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)
trainer.train()

图片

训练完成

对于推理,可以直接使用微调后的 trainer 对象,并在用于评估的分词测试数据集上进行预测:

trainer.predict(tokenized_test)

图片

输出

在更为实际的应用场景中,可以使用 trainer 对象将模型工件保存到本地目录中。

trainer.save_model("./custom_model")

图片

模型工件

然后可以加载这些模型工件,指定训练的模型类型,并在单个数据点上进行推理。

loaded_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path="custom_model/")

# 样本推理
encoding = tokenizer("I am super delighted", return_tensors="pt")
res = loaded_model(**encoding)
predicted_label_classes = res.logits.argmax(-1)
predicted_label_classes

图片

正面分类

在现实应用场景中,可以将训练好的模型工件部署到像 Amazon SageMaker 这样的服务堆栈上。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/523978.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

配置vscode用于STM32编译,Debug

配置环境参考: Docs 用cubemx配置工程文件,用VScode打开工程文件。 编译的时候会有如下报错: vscode出现process_begin :CreateProcess failed 系统找不到指定文件 解决方案:在你的makefile中加上SHELLcmd.exe就可以了 参考…

数据库系统

三级模式 外模式 数据库的用户使用的局部数据的逻辑结构和特征的描述数据库用户的数据视图,是与某一应用程序有关的数据的逻辑表示。 概念模式 它是数据库中全体数据的逻辑结构和特征的描述;模式是所有用户的公共数据视图。因为数据库是多人共享使用的&…

CICD流水线 发布公用jar到maven私仓

3.1 发布公用jar到Maven私仓 1.选择流水线 2.新建流水线 3.选择模版 4.选择代码仓库 5. 调整构建命令 6.新增一个新阶段为 ”发送通知“,这里以邮件通知为例,保存之后,运行该流水线,对应jar就会自动发到我们私仓,并之…

春招-实战项目冲刺直播课

春招-实战项目冲刺直播课 CCtalk 丰富多元的综合内容平台-专业的知识分享与在线教育平台https://www.cctalk.com/m/group/91161801

CTF之社工-初步收集

题目就一个刷钻网站(假的) 扫描一下目录 发现还有一个登录界面 时间多的可以爆破一下(反正我爆不出来),接着我们下载那个压缩包看看 发现是一个钓鱼小软件 没发现什么有用的信息那我们就去wireshark看看数据包喽&#…

winform 等待加载窗体

winform 等待加载窗体 当我们查询sql语句或处理大量的数据时,为了防止界面假死状态,可以加一个等待窗体过渡一下。 1. 新建一个主窗体,一个等待窗体frmLoading 2. 给等待窗体增加一个动态图片 3. 在主窗体中调用 namespace winformLoading…

解析以及探讨数据库技术及其应用

一,引言 数据库作为信息时代的基石,是一种用于高效存储、管理和检索大量结构化数据的系统。它的核心价值在于提供了一种可靠且可扩展的方式,将复杂多样的数据按照特定结构和规则组织起来,以便于不同用户和应用程序进行访问和使用。…

揭秘!接口自动化测试应该做什么?

在软件开发过程中,接口测试是一个至关重要的环节,它确保了系统或组件之间的数据交换、传递和控制管理过程以及相互逻辑依赖关系的正确性。传统的瀑布软件流程中,测试人员在做某个系统的手工功能测试时,会首先从业务人员或开发人员…

vitepress系列-04-规整sideBar左侧菜单导航

规整左侧菜单导航 新建navConfig.ts 文件用来管理左侧导航菜单: 将于其他的配置分开,避免config.mts太大 在config目录下,新建 sidebarModules文件目录用来左侧导航菜单 按模块进行分类: 在config下新建sidebarConfig.ts文件&…

Visual Studio 配置代码风格审查工具cpplint

文章目录 一、Visual Studio 配置代码风格审查工具cpplint1、安装2、运行3、集成到Visual Studio4、集成到Git 前言 cpplint是一个用于检查C代码风格的工具,它可以帮助我们发现潜在的编码问题,提高代码质量。cpplint遵循Google的C编码规范,通…

【c++练习】求3个长方柱的体积

【问题描述】编写一个基于对象数组的程序,用成员函数实现多个功能,求3个长方柱的体积。要求用成员函数实现以下功能: 1、由键盘分别输入3个长方柱的长、宽、高; 2、计算长方柱的体积; 3、输出3个长方柱的体积。 【…

【白菜基础】蛋白组学之生信分析(1)

刚换了一个新课题组,新老板的研究方向为蛋白组学,从未接触过蛋白组学的我准备找一组模拟数据进行生信分析的入门学习。 蛋白组学数据挖掘流程图,参考公众号:蛋白质组学数据挖掘思路解析 (qq.com) 一、认识数据 我们组的数据主要…

csdn博客自定义模块:显示实时天气、日历、随机语录代码

目录 1.样式说明2.效果展示3.代码下载 1.样式说明 vip会员或者博客专家可以自定义模块代码,比如我博客的样式,有这几部分组成: 灯笼祝福(我这里是龙年快乐,可以自定义更改任何字)、滚动欢迎语&#xff08…

Ubuntu下TexStudio如何兼容中文

怎么就想起来研究一下这个? 我使用大名鼎鼎的3Blue1Brown数学动画引擎Manim,制作了一个特别小的动画视频克里金插值。在视频中,绘制文字时,Manim使用到了texlive texlive-latex-extra这些库。专业的关系,当年的毕设没…

服装商城小程序设计分享,服装商城设计分享,自助建站模板分享

在当今数字化的时代,服装商城小程序的设计成为了提升用户购物体验的关键。下面,我将分享一些关于服装商城小程序设计的要点和思路。 首先,界面的简洁与美观至关重要。简洁的布局能让用户轻松找到所需商品,避免繁琐的操作流程。同时…

如何水出第一篇SCI:SCI发刊历程,从0到1全过程经验分享!!!

如何水出第一篇SCI:SCI发刊历程,从0到1全路程经验分享!!! 详细的改进教程以及源码,戳这!戳这!!戳这!!!B站:Ai学术叫叫兽e…

血细胞检测数据集 | 用于血细胞计数+检测的小规模数据集_已经整理成VOC格式_总共410张图

项目应用场景 面向血细胞检测计数数据集,已经整理成 VOC 格式,可以直接用于目标检测算法的训练,如 YOLO 等目标检测算法的训练。血细胞检测数据集图片质量好,可直接训练出一个血细胞检测模型,或者作为血细胞检测数据集…

SOLIDWORKS在教育领域的应用

随着科技的飞速发展和数字化浪潮的推进,SOLIDWORKS作为一款强大的三维设计软件,其应用领域已经不仅局限于工程设计和制造行业,还逐渐渗透到教育领域中,成为培养学生实践能力和创新思维的重要工具。本文将探讨SOLIDWORKS在教育领域…

数组排序(Comparator)

题目 import java.util.Arrays; import java.util.Comparator; import java.util.Scanner; public class Main {public static void main(String[] args) {Scanner sc new Scanner(System.in);int n sc.nextInt();sc.nextLine();Integer[] res new Integer[n1];//使用Integ…

idea Springboot 电影推荐系统LayUI框架开发协同过滤算法web结构java编程计算机网页

一、源码特点 springboot 电影推荐系统是一套完善的完整信息系统,结合mvc框架和LayUI框架完成本系统springboot dao bean 采用协同过滤算法进行推荐 ,对理解JSP java编程开发语言有帮助系统采用springboot框架(MVC模式开发)&…