如何根据自己的数据集微调一个 Transformer 模型

将通过 NLP 中最常见的文本分类任务来学习如何在自己的数据集上利用迁移学习(transfer learning)微调一个预训练的 Transformer 模型—— DistilBERT。DistilBERT 是 BERT 的一个衍生版本,它的优点在它的性能与 BERT 相当,但是体积更小、更高效。所以我们可以在几分钟内训练一个文本分类器。

如果你想尝试一下 BERT,那么只需改一下模型的 checkpoint 就可以了。通常,checkpoint 指的是要加载到给定 Transformer 架构中的一系列模型权重。 

数据集

这里我们将使用英文推文情感数据集,这个数据集中包含了:anger,disgust,fear,joy,sadness 和 surprise 六种情感类别。

http://dx.doi.org/10.18653/v1/D18-1404

所以我们的任务是给定一段推文,训练一个可以将其分类成这六种基本情感的其中之一的模型。

现在我们来下载数据集。

图片

为了更好地分析数据,我们可以将 Dataset 对象转成 Pandas DataFrame,然后就可以利用各种高级 API 可视化数据集了。但是这种转换不会改变数据集的底层存储方式(这里是 Apache Arrow)。

图片

从上面可以看到 text 列中的推文在 label 列都有一个整数对应,显然这个整数和六种情绪是一一对应的。那么怎么去将整数映射成文本标签呢?

如果我们观察一下原始数据集中的每列的数据类型。

图片

我们发现 text 列就是普通的 string 类型,label 列是 ClassLabel 类型。ClassLabel 中包含了 names 属性,我们可以利用 ClassLabel 附带的 int2str 方法来将整数映射到文本标签。

图片

现在看起来就清楚多了。

处理任何分类任务之前,都要看一下样本的类别分布是否均衡,不均衡类别分布的数据集在训练损失和评估指标方面可能需要与平衡数据集做不同的处理。

图片

类别分布严重不均衡!joy 和 sadness 类样本数量最多,而 love 和 surprise 类的样本数量几乎要少 5-10 倍。

有好几种方法可以处理类别不均衡问题:

  • 对样本数量少的类别进行随机上采样。

  • 对样本数量少的类别进行随机下采样。

  • 对于样本数量不足的类别收集更多样本。

限于篇幅,我们这里不做任何处理。

最后一件事,也是最重要的。无论是哪个 Transformer 模型,它都有上下文长度限制(maximum context size)。GPT-4 Turbo 的上下文长度已经达到了 128k 个 token!不过 DistilBERT 只有 512。

token 指的是不能再被拆分的文本原子,我们将在后面学习,这里就简单理解为单词就好。

图片

从上图可以看到最长的推文长度也没超过 512,大多数长度在 15 左右。完全符合 DistilBERT 的要求。比模型最长上下文限制还要长的文本需要被截断,如果截断的文本包含关键信息,这可能会导致性能损失,不过我们这里没有这个问题。

分析完数据集之后,别忘了将数据集格式从 DataFrame 转回来。

图片

Token

像 DistilBERT 这样的 Transformer 模型无法接受原始的字符串作为输入,我们必须将文本拆分成一个个 token(这一过程称为 tokenized),然后编码成数值向量表示。

将文本拆分成模型可用的原子单位的步骤称为 tokenization。对于英文来说有 character tokenization 和 word tokenization。我们这里简单地见识一下,不深入探讨。

以英文为例,对于 character tokenization 来说。

  1. 将原始文本拆分成一个个字符,也就是 26 个大小写字母加标点符号。

  2. 建立一个字符到唯一整数映射的映射关系表。

  3. 将字符映射到唯一的整数表示 input_ids。

  4. 将 input_ids 转成 2D 的 one-hot encoding 向量。

图片

character-level tokenization 忽略了文本的结构,将字符串看成是一连串的字符流,尽管这种方法可以处理拼写错误和罕见的单词。其主要缺点是需要从数据中学习单词等语言结构。这需要大量的计算、内存和数据。因此,这种方法在实践中很少使用。

word tokenization 就是按照单词维度来拆分文本。

图片

其余步骤和 character tokenization 都一样。不过 character tokenization 的词汇表最多只有几百个(对英文来说,26 个大小写字母和标点符号)。但是 word tokenziation 形成的词汇表可能有数千甚至数万之多,尤其是英文这种每个单词还有不同的形式变化的语言。

subword tokenization 可以看成是 character tokenization 和 word tokenization 的折中方法。

NLP 中有不少算法可以实现 subword tokenization,BERT 和 DistilBERT 都是采用 WordPiece 算法。

每个模型都有自己 tokenization 方法,所以要从对应模型的 checkpoint 下载预训练过的 tokenizer。

图片

我们还能获取像最大上下文长度等基本的 tokenizer 信息。

图片

最后一个看起来有点懵,其实在实际工作中我们一般这样做。

图片

首先 input_ids 字段还是 token 对应的整数,但是首尾增加了标识序列开头和结尾的特殊 token:[CLS] 和 [SEP]。

现在再来看看 attention_mask 字段。当批量处理文本时,每个文本的长度都不一样。

  • 如果最长的文本超过模型的最长上下文限制,则直接截断多余的部分。

  • 在其余短文本后面附加 padding token,使它们的长度都一致。

图片

attention mask 为 0 的部分表示对应的 token 是为了扩展长度而引入的 padding token,模型无需理会。

现在对整个数据集进行 tokenization。

图片

模型架构

像 DistilBERT 这样的模型的预训练目标是预测文本序列中的 mask 词,所以我们并不能直接拿来做文本分类任务。像 DistilBERT 这种 encoder-based Transformer 模型架构通常由一个预训练的 body 和对应分类任务的 head 组成。

图片

首先我们将文本进行 tokenization 处理,形成称为 token encodings 的 one-hot 向量。tokenizer 词汇表的大小决定了 token encodings 的维度,通常在 20k-30k。

然后,token encodings 被转成更低维度的 token embeddings 向量,比如 768 维,在 embedding 空间中,意思相近的 token 的 embedding 向量表示的距离也会更相近。

然后 token embeddings 经过一系列的 encoder 层,最终每个 token 都生成了一个 hidden state。

现在我们有两种选择:

将 Transformer 模型视为特征抽取模型,我们不改变原模型的权重,仅仅将 hidden state 作为每个文本的特征,然后训练一个分类模型,比如逻辑回归。

所以我们需要在训练时冻结 body 部分的权重,仅更新 head 的权重。

图片

这样做的好处是即使 GPU 不可用时我们也可以快速训练一个小模型。

让我们先下载模型。

图片

这个模型就会将 token encoding 转成 embedding,再经过若干 encoder 层输出 hidden state。

图片

在分类任务中,习惯用 [CLS] token 对应的 hidden state 作为句子特征,所以我们先写一个特征抽取函数。

图片

然后抽取我们这个数据集的特征。

图片

然后我们可以训练一个逻辑回归模型去预测推文情绪类别。

图片

图片

图片

从混淆矩阵可以看到 anger 和 fear 通常会被误分类成 sadness,love 和 surprise 也总会被误分类成 joy。

微调 Transformer 模型

此时我们不再将预训练的 Transformer 模型当作特征抽取器了,我们也不会将 hidden state 作为固定的特征了,我们会从头训练整个整个 Transformer 模型,也就是会更新预训练模型的权重。

如下图所示,此时 head 部分要可导了,不能使用逻辑回归这样的机器学习算法了,我们可以使用神经网络。

首先我们加载预训练模型,从下方的警告信息可以看到此时模型一部分参数是随机初始化的。

图片

接下来再定义 F1-score 和准确率作为微调模型时的性能衡量指标。

图片

然后就是定义一些训练模型时的超参数设定。

图片

全部就绪后,就可以训练模型了,我们这里训练 2 个 epoch。

可以看到仅仅训练了 2 个 epoch,模型在验证集上的 F1-score 就达到了 93%。

我们再看一下模型在验证集上的混淆矩阵。

图片

图片

可以看到此时的混淆矩阵已经十分接近对角矩阵了,比之前的好多了。

最后我们看一下微调过的模型是如何预测推文情绪的。

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/329744.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity3d C#实现场景编辑/运行模式下3D模型XYZ轴混合一键排序功能(含源码工程)

前言 在部分场景搭建中需要整齐摆放一些物品(如仓库中的货堆、货架等),因为有交互的操作在单个模型上,每次总是手动拖动模型操作起来也是繁琐和劳累。 在这背景下,我编写了一个在运行或者编辑状态下都可以进行一键排序…

【嘉立创EDA-PCB设计指南】3.网络表概念解读+板框绘制

前言:本文对网络表概念解读板框绘制(确定PCB板子轮廓) 网络表概念解读 在本专栏的上一篇文章【嘉立创EDA-PCB设计指南】2,将设计的原理图转为了PCB,在PCB界面下出现了所有的封装,以及所有的飞线属性&…

从0开始python学习-48.pytest框架之断言

目录 1. 响应进行断言 1.1 在yaml用例中写入断言内容 1.2 封装断言方法 1.3 在执行流程中加入断言判断内容 2. 数据库数据断言 2.1 在yaml用例中写入断言内容 2.2 连接数据库并封装执行sql的方法 2.3 封装后校验方法是否可执行 2.4 使用之前封装的断言方法&#xff0c…

austin-admin 消息推送平台前端项目依赖低代码平台Amis 怎么使用

austin-admin 消息推送平台前端项目🔥依赖低代码平台Amis 怎么使用 收到一个通知,要将部署一个开源的消息系统 :austin的前端开源:https://gitee.com/zhongfucheng/austin-admin 本地运行 1、使用npm或者yarn这些咯 yarn yarn start2、使用…

【LabVIEW FPGA入门】FPGA中的数学运算

数值控件选板上的大部分数学函数都支持整数或定点数据类型,但是需要请注意,避免使用乘法、除法、倒数、平方根等函数,此类函数比较占用FPGA资源,且如果使用的是定点数据或单精度浮点数据仅适用于FPGA终端。 1.整数运算 支持的数…

pyechart基础

pyecharts - A Python Echarts Plotting Library built with love. 全局配置项 初识全局配置组件 Note: 配置项章节应该配合图表类型章节中的 example 阅读。 全局配置项可通过 set_global_opts 方法设置 InitOpts:初始化配置项 class pyecharts.options.InitO…

Java顺序表(2)

🐵本篇文章将对ArrayList类进行讲解 一、ArrayList类介绍 上篇文章我们对顺序表的增删查改等方法进行了模拟实现,实际上Java提供了ArrayList类,而在这个类中就包含了顺序表的一系列方法,这样在用顺序表解决问题时就不用每次都去实…

【C++干货铺】红黑树 (Red Black Tree)

个人主页点击直达:小白不是程序媛 C系列专栏:C干货铺 代码仓库:Gitee 目录 前言 红黑树的概念 红黑树的性质 红黑树结点的定义 红黑树的插入操作 插入新的结点 检查规则进行改色 情况一 情况二 情况三 插入完整代码 红黑树的验…

SpringMVC参数接收见解4

# 4.参数接收Springmvc中,接收页面提交的数据是通过方法形参来接收: 处理器适配器调用springmvc使用反射将前端提交的参数传递给controller方法的形参 springmvc接收的参数都是String类型,所以spirngmvc提供了很多converter(转换…

【数据结构】归并排序的两种实现方式与计数排序

前言:在前面我们讲了各种常见的排序,今天我们就来对排序部分收个尾,再来对归并排序通过递归和非递归的方法进行实现,与对计数排序进行简单的学习。 💖 博主CSDN主页:卫卫卫的个人主页 💞 👉 专栏…

Android Matrix绘制PaintDrawable设置BitmapShader,手指触点为圆心scale放大原图,Kotlin

Android Matrix绘制PaintDrawable设置BitmapShader,手指触点为圆心scale放大原图,Kotlin 在 Android基于Matrix绘制PaintDrawable设置BitmapShader,以手指触点为中心显示原图的圆切图,Kotlin(4)-CSDN博客 的…

2001-2022年上市公司企业财务绩效、公司价值、并购绩效数据(ROA、ROE、TOBINQ变化)

2001-2022年上市公司企业财务绩效、公司价值、并购绩效数据(ROA、ROE、TOBINQ变化) 1、时间:2001-2022年 2、指标:证券代码、统计截止日期、证券简称、行业代码、行业名称、年份、、总资产净利润率B、净资产收益率(ROE)B、托宾Q…

【方法】如何压缩zip格式文件?

zip是一种常见的压缩文件格式,能够高效打包文件便于存储和传输,那zip格式的压缩文件要如何压缩呢? 压缩zip文件需要用到解压缩软件,比如常见的WinRAR、7-Zip软件都可以压缩zip格式。下面一起来看看具体如何操作。 一、使用WinRAR…

日期处理第一篇--优雅好用的Java日期工具类Joda-Time

日常开发中,处理时间和日期是很常见的需求。基础的java内置工具类只有Date和Calendar,但是这些工具类的api使用并不是很方便和强大,于是就诞生了Joda-Time这个专门处理日期时间的库。 简介 Joda-Time提供了Java日期处理的优雅的替代品&…

IntelliJ IDEA 拉取gitlab项目

一、准备好Gitlab服务器及项目 http://192.168.31.104/root/com.saas.swaggerdemogit 二、打开 IntelliJ IDEA安装插件 打开GitLab上的项目,输入项目地址 http://192.168.31.104/root/com.saas.swaggerdemogit 弹出输入登录用户名密码,完成。 操作Comm…

【昕宝爸爸小模块】图文源码详解什么是线程池、线程池的底层到底是如何实现的

➡️博客首页 https://blog.csdn.net/Java_Yangxiaoyuan 欢迎优秀的你👍点赞、🗂️收藏、加❤️关注哦。 本文章CSDN首发,欢迎转载,要注明出处哦! 先感谢优秀的你能认真的看完本文&…

发送HTTP POST请求并处理响应

发送HTTP POST请求并处理响应是Web开发中的常见任务。在Go语言中,可以使用net/http包来发送HTTP POST请求并处理响应。 以下是一个示例代码,演示了如何发送HTTP POST请求并处理响应: go复制代码 package main import ( "b…

代码随想录算法训练营day10|232.用栈实现队列、225.用队列实现栈

理论基础 232.用栈实现队列 225. 用队列实现栈 理论基础 了解一下 栈与队列的内部实现机智,文中是以C为例讲解的。 文章讲解:代码随想录 232.用栈实现队列 大家可以先看视频,了解一下模拟的过程,然后写代码会轻松很多。 题目链…

Maven 依赖传递和冲突、继承和聚合

一、依赖传递和冲突 1.1 Maven 依赖传递特性 1.1.1 概念 假如有三个 Maven 项目 A、B 和 C,其中项目 A 依赖 B,项目 B 依赖 C。那么我们可以说 A 依赖 C。也就是说,依赖的关系为:A—>B—>C, 那么我们执行项目 …

性能优化-一文宏观理解OpenCL

本文主要对OpenCL做一个整体的介绍、包括环境搭建、第一个OpenCL程序、架构、优化策略,希望对读者有所收获。 🎬个人简介:一个全栈工程师的升级之路! 📋个人专栏:高性能(HPC)开发基础…