DistilBERT模型训练实战

LLM似乎正在接管世界,但许多人仍然不真正理解他们是如何运作的。 我从事机器学习工作已有几年,并且对自然语言处理和最近的进展非常着迷。

尽管我阅读了大部分随附的论文,但训练这些模型对我来说仍然是一个谜,这就是为什么我决定继续自己训练一个模型,以真正了解它是如何工作的。 我将其与训练问答模型结合起来,但这里仅详细介绍 DistilBERT 模型。

为了让你的生活更轻松,我决定对其工作原理进行简短回顾。 请查看这篇文章中的 distilbert.ipynb文件来查找相关代码。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、为什么选择 DistilBERT

要回答的第一个问题是为什么我选择 DistilBERT 而不是 BERT、ALBERT 和该模型的所有其他变体。 不幸的是,我没有无限的云计算访问权限,只有内存有限的本地 GPU,因此我必须针对模型大小和训练时间而不是性能进行优化。

也就是说,与 BERT 相比,官方的 DistilBERT 性能仅下降了3%,这似乎是一个合理的权衡。 BERT 基础有1.1亿个参数,训练时间为12天,而 DistilBERT 有6600 万个参数,训练时间只有3.5天左右。 原始论文中指出模型减小了 40%,保留了97% 的语言理解能力,速度提高 60%。

我查看了这篇文章中对 BERT、RoBERTA、DistilBERT 和 XLNet 的简短总结和比较,文章在评论中提供了一个很棒的表格,比较了所有模型。

2、数据

我使用 HuggingFace  的 OpenWebText 数据集来训练模型。 它是 OpenAI 的 WebText 数据集的开源版本。 它包含从 Reddit 采样的 8013769 个段落。

HuggingFace 为许多数据集和模型提供了一个令人惊叹的(!!!)界面,我在整个项目中都使用了它。 只需使用以下命令即可下载整个数据集。

from datasets import load_dataset

ds = load_dataset("openwebtext")

然后我继续将数据集以 10 000 个为单位存储在本地,因为这需要一些时间,而且我不想每次都等待。

3、分词(tokenization)

接下来,我们需要为模型训练一个分词器(因为我们无法将自然语言输入到模型中)。 我们可以使用 HuggingFace 的 BertWordPieceTokenizer。 我们只需传递文件的路径,它就会自动完成所有操作。 此外,我们还需要添加特殊标记 PAD(填充)、UNK(未知)、CLS(分类)、SEP(分隔符)和 MSK(掩码)标记。 有关这些标记的解释,请参阅基本 BERT 模型教程。

from tokenizers import BertWordPieceTokenizer

paths = [str(x) for x in Path('data/original').glob('**/*.txt')]

tokenizer = BertWordPieceTokenizer(
        clean_text=True,
        handle_chinese_chars=False,
        strip_accents=False,
        lowercase=True
)
tokenizer.train(files=paths[:10], vocab_size=30_000, min_frequency=2,
                    limit_alphabet=1000, wordpieces_prefix='##',
                    special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])

当我们测试它时,我们得到以下标记并再次解码它们,结果表明标记生成器在每个输入的开头添加了一个 CLS 标记,并在句子后面添加了分隔符标记。 此外,我们还看到标记化输入包含输入 id(每个单词的 id)和注意掩码(告诉模型哪些标记很重要,即如果我们将序列填充到给定长度,它们将为 0)。

tokens = tokenizer('Hello, how are you?')
print(tokens)
# {'input_ids': [2, 21694, 16, 2287, 2009, 1991, 35, 3], 
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

tokenizer.decode(tokens['input_ids']) 
# '[CLS] hello, how are you? [SEP]'

4、数据集和数据加载器

我们可以继续使用自定义的 Dataset 类和 PyTorch 中的 DataLoader 准备要加载到模型中的数据。 数据集类可以在这里找到。 我们基本上加载文件并使用我们的分词器对输入进行编码。

我在数据集中做的另一件事是逐个加载文件。 考虑到内存限制,我必须以这种方式实现它。 它有一些缺点,即你不能以这种方式洗牌数据,因为这会把一切搞乱。 不过,这应该不是什么太大的问题,因为数据集已经根据数据集描述进行了改组。

在训练过程中,模型尝试预测被屏蔽的标记,我们需要对其进行屏蔽。 因此,我屏蔽了(分配 MSK 令牌)15% 的输入,效果非常好。 其中一些基于 DistilBERT 的 HuggingFace 实现,可以在这里找到。

dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][50:70], tokenizer=tokenizer)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

test_dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][10:12], tokenizer=tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)

5、模型

接下来我们必须定义我们的模型,是的,你猜对了,我们在这里也使用 HuggingFace。 它提供了一个令人惊叹的界面,使训练变得非常容易。

from transformers import DistilBertForMaskedLM, DistilBertConfig

config = DistilBertConfig(
    vocab_size=30000,
    max_position_embeddings=514
)
model = DistilBertForMaskedLM(config)

我们使用学习率为 1e-4 的 AdamW 作为优化器并训练 10 个 epoch(这已经花费了很多时间)。 在下面,你可以找到我的训练过程,这是非常基础的代码。

epochs = 10

for epoch in range(epochs):
    loop = tqdm(loader, leave=True)
    
    # set model to training mode
    model.train()
    losses = []
    
    # iterate over dataset
    for batch in loop:
        optim.zero_grad()
        
        # copy input to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # predict
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        # update weights
        loss = outputs.loss
        loss.backward()
        
        optim.step()
        
        # output current loss
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())
        
    print("Mean Training Loss", np.mean(losses))
    losses = []
    loop = tqdm(test_loader, leave=True)
    
    # set model to evaluation mode
    model.eval()
    
    # iterate over dataset
    for batch in loop:
        # copy input to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # predict
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        # update weights
        loss = outputs.loss
        
        # output current loss
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())
    print("Mean Test Loss", np.mean(losses))

6、测试

之后,我们可以运行一些健全性测试来查看模型对某些屏蔽标记的预测。 我们可以再次使用 HuggingFace 创建一个管道,它将为我们处理预测。 我们使用 fill.tokenizer.mask_token 将 MSK 令牌添加到输入中。

from transformers import pipeline

fill = pipeline("fill-mask", model='distilbert', config=config, tokenizer='distilbert_tokenizer')
fill(f'It seems important to tackle the climate {fill.tokenizer.mask_token}.')

此外,我们得到了以下带有置信水平的预测,这些预测似乎都是这句话中合理的下一个标记。

  • change: 0.19
  • crisis: 0.12
  • issues: 0.05
  • issue: 0.04

7、结束语

总而言之,考虑到基础设施的限制,结果相当不错。 显然,我们没有达到与原始模型相当的性能,但如果确实想在应用程序中使用它,你可以使用预训练模型(请参考这里)。


原文链接:DistilBERT模型训练实战 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/189624.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java(七)(Lambda表达式,正则表达式,集合(Collection,Collection的遍历方式))

目录 Lambda表达式 省略写法(要看懂) 正则表达式 语法 案例 正则表达式的搜索替换和分割内容 集合进阶 集合体系结构 Collection Collection的遍历方式 迭代器 增强for循环 Lambda表达式遍历Collection List集合 ArrayList LinkedList 哈希值 HashSet底层原理 …

中东客户亲临广东育菁装备参观桌面型数控机床生产

近日,中东地区的一位重要客户在广东育菁装备有限公司的热情接待下,深入了解了该公司生产的桌面型数控机床。这次会面不仅加强了双方在业务领域的交流,也为中国与中东地区的经济合作描绘出更美好的前景。 在育菁装备公司各部门主要负责人及工作…

赋值,浅拷贝,深拷贝

1.前置知识 数据分为基本类型(String, Number, Boolean, Null, Undefined,Symbol)和引用类型(Object)基本类型:直接存储在栈内存中的数据引用类型:指向改数据的指针变量存储在栈内存中,真实的数据存储在堆内存中引用类型在栈内存…

cephadm部署ceph quincy版本,使用ceph-csi连接

环境说明 IP主机名角色 存储设备 192.168.2.100 master100 mon,mgr,osd,mds,rgw 大于5G的空设备192.168.2.101node101mon,mgr,osd,mds,rgw大于5G的空设备192.168.2.102node102mon,mgr,osd,mds,rgw大于5G的空设备 关闭防火墙 关闭并且禁用selinux 配置主机名/etc/hosts …

【UGUI】中Content Size Fitter)组件-使 UI 元素适应其内容的大小

官方文档:使 UI 元素适应其内容的大小 - Unity 手册 必备组件:Content Size Fitter 通常,在使用矩形变换定位 UI 元素时,应手动指定其位置和大小(可选择性地包括使用父矩形变换进行拉伸的行为)。 但是&a…

什么是 Node.js?

在 Node.js 出现之前,最常见的 JavaScript 运行时环境是浏览器,也叫做 JavaScript 的宿主环境。浏览器为 JavaScript 提供了 DOM API,能够让 JavaScript 操作浏览器环境(JS 环境)。 2009 年初 Node.js 出现了&#xf…

智能AI名片-Pushmall推贴SCRM数字名片的发展趋势

智能AI名片-Pushmall推贴SCRM数字名片的发展趋势 基于相识靠铺人脉相互引荐,共享人脉资源,众筹共创赋能交友、商务实现大众创业,灵活创收的智能AI名片平台。帮助企业实现成员管理与客户资源管理。功能说明 1、搜索查询:个人信息与…

1 时间序列模型入门: LSTM

0 前言 循环神经网络(Recurrent Neural Network,RNN)是一种用于处理序列数据的神经网络。相比一般的神经网络来说,他能够处理序列变化的数据。比如某个单词的意思会因为上文提到的内容不同而有不同的含义,RNN就能够很好…

lv11 嵌入式开发 C工程与寄存器封装 10

目录 1 C语言工程介绍 1.1 工程模板介绍 1.2 启动代码分析 2 C语言实现LED实验 2.1 C语言如何实现寄存器读写 2.2 实现LED实验 2.3 练习 1 C语言工程介绍 1.1 工程模板介绍 工程目录,后续代码都会利用到这个目录 interface.c 写了一个main函数的框架 int …

nginx反向代理解决跨域前端实践

需求实现 本地请求百度的一个搜索接口,用nginx代理解决跨域思路:前端和后端都用nginx代理到同一个地址8080,这样访问接口就不存在跨域限制 本地页面 查询一个百度搜索接口,运行在http://localhost:8035 index.js const path …

Stable-Diffusion——Windows部署教程

Windows 参考文章:从零开始,手把手教你本地部署Stable Diffusion Webui AI绘画(非最新版) 一键脚本安装 默认环境安装在项目路径的venv下 conda create -n df_env python3.10安装pytorch:(正常用国内网就行) python -…

【Unity实战】切换场景加载进度和如何在后台异步加载具有庞大世界的游戏场景,实现无缝衔接(附项目源码)

文章目录 最终效果前言一、绘制不同的场景二、切换场景加载进度1. 简单实现2. 优化 三、角色移动和跳跃控制四、添加虚拟摄像机五、触发器动态加载场景六、最终效果参考源码完结 最终效果 前言 观看本文后,我的希望你对unity场景管理有更好的理解,并且能…

Java抽象类和接口(1)

🐵本篇文章将对抽象类和接口相关知识进行讲解 一、抽象类 先来看下面的代码: class Shape {public void draw() {System.out.println("画");} } class Cycle extends Shape {public void draw() {System.out.println("圆形");} } …

飞翔的鸟小游戏

第一步是创建项目 项目名自拟 第二步创建个包名 来规范class 再创建一个包 来存储照片 如下: package game; import java.awt.*; import javax.swing.*; import javax.imageio.ImageIO;public class Bird {Image image;int x,y;int width,height;int size;doub…

一个超强算法模型实战案例!

哈喽,大家周末愉快,今儿不了很多的原理性内容。准备和大家一起实现一个开源且重要的项目:MNIST数字分类机器学习。 大概介绍下:MNIST数字分类项目旨在使用机器学习技术来构建一个模型,能够自动识别手写数字的图像。这个项目是一个经典的图像分类任务,常用于入门级机器学…

基于51单片机超市快递寄存自动柜设计源程序

一、系统方案 1、本设计采用这51单片机作为主控器。 2、存包,GSM短信取件码。 3、液晶1620显示。 4、矩阵键盘输入取件码,完成取包。 二、硬件设计 原理图如下: 三、单片机软件设计 1、首先是系统初始化 /******************************…

MFC 绘制单一颜色三角形、渐变颜色边框三角形、渐变填充三角形、正弦函数曲线实例

MFC 绘制三种不同圆形以及绘制正弦函数曲线 本文使用visual Studio MFC 平台实现绘制单一颜色圆形、渐变颜色边框圆形、渐变填充圆形以及绘制三角函数正弦函数曲线. 关于基础工程的创建请参考 01-Visual Studio 使用MFC 单文档工程绘制单一颜色直线和绘制渐变颜色的直线 02-vis…

吉他初学者学习网站搭建系列(1)——目录

文章目录 背景文章目录功能网站地址网站展示展望 背景 这个系列是对我最近周末搭建的吉他工具类平台YUERGS的总结。我个人业余爱好是自学吉他,我会在这个平台中动手集成我认为很有帮助的一些工具,来提升我的吉他水平和音乐素养,希望也可以帮…