SwanLab入门深度学习:BERT IMDB文本情感分类

基于BERT模型的IMDB电影评论情感分类,是NLP经典的Hello World任务之一。

这篇文章我将带大家使用SwanLab、transformers、datasets三个开源工具,完成从数据集准备、代码编写、可视化训练的全过程。

观察了一下,中文互联网上似乎很少有能直接跑起来的BERT训练代码和教程,所以也希望这篇文章可以帮到大家。
在这里插入图片描述

  • 代码:完整代码直接看本文第5节
  • 模型与数据集:百度云,提取码: u9gi
  • 实验过程:BERT-SwanLab
  • SwanLab:https://swanlab.cn
  • transformers:https://github.com/huggingface/transformers
  • datasets:https://github.com/huggingface/datasets

1.环境安装

我们需要安装以下这4个Python库:

transformers>=4.41.0
datasets>=2.19.1
swanlab>=0.3.3

一键安装命令:

pip install transformers datasets swanlab

他们的作用分别是:

  1. transformers:HuggingFace出品的深度学习框架,已经成为了NLP(自然语言处理)领域最流行的训练与推理框架。代码中用transformers主要用于加载模型、训练以及推理。
  2. datasets:同样是HuggingFace出品的数据集工具,可以下载来自huggingface社区上的数据集。代码中用datasets主要用于下载、加载数据集。
  3. swanlab:在线训练可视化和超参数记录工具,官网,可以记录整个实验的超参数、指标、训练环境、Python版本等,并可是化成图表,帮助你分析训练的表现。代码中用swanlab主要用于记录指标和可视化。

本文的代码测试于transformers4.41.0、datasets2.19.1、swanlab==0.3.3,更多库版本可查看SwanLab记录的Python环境。

2.加载BERT模型

BERT模型我们直接下载来自HuggingFace上由Google发布的bert-case-uncased预训练模型。

执行下面的代码,会自动下载模型权重并加载模型:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# 加载预训练的BERT tokenizer
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

如果国内下载比较慢的话,可以在这个百度云(提取码: u9gi)下载后,把bert-base-uncased文件夹放到根目录,然后改写上面的代码为:

model = AutoModelForSequenceClassification.from_pretrained('./bert-base-uncased', num_labels=2)

3.加载IMDB数据集

IMDB数据集(Internet Movie Database Dataset)是自然语言处理(NLP)领域中一个非常著名和广泛使用的数据集,主要应用于文本情感分析任务。

IMDB数据集源自全球最大的电影数据库网站Internet Movie Database(IMDb),该网站包含了大量的电影、电视节目、纪录片等影视作品信息,以及用户对这些作品的评论和评分。
数据集包括50,000条英文电影评论,这些评论被标记为正面或负面情感,用以进行二分类任务。其中,25,000条评论被分配为训练集,另外25,000条则作为测试集。训练集和测试集都保持了平衡的正负样本比例,即各含50%的正面评论和50%的负面评论.

在这里插入图片描述
我们同样直接下载HuggingFace上的imdb数据集,执行下面的代码,会自动下载数据集并加载:

from datasets import load_dataset

# 加载IMDB数据集
dataset = load_dataset('imdb')

如果国内下载比较慢的话,可以在这个百度云(提取码: u9gi)下载后,把imdb文件夹放到根目录,然后改写上面的代码为:

dataset = load_dataset('./imdb')

4.集成SwanLab

因为swanlab已经和transformers框架做了集成,所以将SwanLabCallback类传入到trainercallbacks参数中即可实现实验跟踪和可视化:

from swanlab.integration.huggingface import SwanLabCallback

# 设置swanlab回调函数
swanlab_callback = SwanLabCallback()

...

# 定义Transformers Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    # 传入swanlab回调函数
    callbacks=[swanlab_callback],
)

想了解更多关于SwanLab的知识,请看SwanLab官方文档。

5.开始训练!

训练过程看这里:BERT-SwanLab。

在首次使用SwanLab时,需要去官网注册一下账号,然后在用户设置复制一下你的API Key。

在这里插入图片描述
然后在终端输入swanlab login:

swanlab login

把API Key粘贴进去即可完成登录,之后就不需要再次登录了。

完整的训练代码:

"""
用预训练的Bert模型微调IMDB数据集,并使用SwanLabCallback回调函数将结果上传到SwanLab。
IMDB数据集的1是positive,0是negative。
"""

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from swanlab.integration.huggingface import SwanLabCallback
import swanlab

def predict(text, model, tokenizer, CLASS_NAME):
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits).item()

    print(f"Input Text: {text}")
    print(f"Predicted class: {int(predicted_class)} {CLASS_NAME[int(predicted_class)]}")
    return int(predicted_class)

# 加载IMDB数据集
dataset = load_dataset('imdb')

# 加载预训练的BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# 定义tokenize函数
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

# 对数据集进行tokenization
tokenized_datasets = dataset.map(tokenize, batched=True)

# 设置模型输入格式
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# 加载预训练的BERT模型
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 设置训练参数
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_first_step=100,
    # 总的训练轮数
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # 单卡训练
)

CLASS_NAME = {0: "negative", 1: "positive"}

# 设置swanlab回调函数
swanlab_callback = SwanLabCallback(project='BERT',
                                   experiment_name='BERT-IMDB',
                                   config={'dataset': 'IMDB', "CLASS_NAME": CLASS_NAME})

# 定义Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    callbacks=[swanlab_callback],
)

# 训练模型
trainer.train()

# 保存模型
model.save_pretrained('./sentiment_model')
tokenizer.save_pretrained('./sentiment_model')

# 测试模型
test_reviews = [
    "I absolutely loved this movie! The storyline was captivating and the acting was top-notch. A must-watch for everyone.",
    "This movie was a complete waste of time. The plot was predictable and the characters were poorly developed.",
    "An excellent film with a heartwarming story. The performances were outstanding, especially the lead actor.",
    "I found the movie to be quite boring. It dragged on and didn't really go anywhere. Not recommended.",
    "A masterpiece! The director did an amazing job bringing this story to life. The visuals were stunning.",
    "Terrible movie. The script was awful and the acting was even worse. I can't believe I sat through the whole thing.",
    "A delightful film with a perfect mix of humor and drama. The cast was great and the dialogue was witty.",
    "I was very disappointed with this movie. It had so much potential, but it just fell flat. The ending was particularly bad.",
    "One of the best movies I've seen this year. The story was original and the performances were incredibly moving.",
    "I didn't enjoy this movie at all. It was confusing and the pacing was off. Definitely not worth watching."
]

model.to('cpu')
text_list = []
for review in test_reviews:
    label = predict(review, model, tokenizer, CLASS_NAME)
    text_list.append(swanlab.Text(review, caption=f"{label}-{CLASS_NAME[label]}"))

if text_list:
    swanlab.log({"predict": text_list})

swanlab.finish()

训练可视化过程:
在这里插入图片描述

训练大概需要6G左右的显存,我在一块3090上跑了,1个epoch大概要12~13分钟时间。

训练的推理结果:

在这里插入图片描述
这里我生成了10个比较简单的测试文本,微调后的BERT模型基本都能答对。

至此,我们顺利完成了用BERT预训练模型微调IMDB数据的训练过程~

相关链接

  • 代码:完整代码直接看本文第5节
  • 模型与数据集:百度云,提取码: u9gi
  • 实验过程:BERT-SwanLab
  • SwanLab:https://swanlab.cn
  • transformers:https://github.com/huggingface/transformers
  • datasets:https://github.com/huggingface/datasets

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635332.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Apache Log4j Server 反序列化命令执行漏洞(CVE-2017-5645)

漏洞复现环境搭建请参考 http://t.csdnimg.cn/MxmId 漏洞版本 Apache Log4j 2.8.2之前的2.x版本 漏洞验证 (1)开放端口4712 漏洞利用 (1)ysoserial工具获取 wget https://github.com/frohoff/ysoserial/releases/download/v0…

强化学习算法

从上图看出,强化学习可以分成价值/策略、随机策略/确定策略、在线策略/离线策略、蒙特卡洛/时间差分这四个维度。这里分析了基础算法中除了在线策略/离线策略以外的其他维度。 (一)基础知识 一、基础概念 重点概念:状态S、动作A、…

浏览器自动化~插件推荐Automa

引言 作为一款现代浏览器,得自动化吧,自主完成那些日复一日的重复性任务,开启音乐啥的不在话下~。而你则可以专注于其他更有意义的事情,如享受音乐带来的愉悦。但如果你对编写脚本一窍不通,又该如何实现这一愿景呢&am…

华为机考入门python3--(28)牛客28-素数伴侣

分类:质数、素数、贪心算法、矩阵 知识点: 素数里除了2,都是奇数 奇奇偶,偶+偶偶 对矩阵求和 sum(map(sum, matrix)) 查找元素 3 在列表中的索引 my_list.index(3) 题目来自【牛客】 质数又称素数,是指…

一种综合评价及决策方法:层次分析法AHP

大家好,层次分析法(Analytic Hierarchy Process,AHP)是一种多准则决策方法,它帮助决策者处理复杂的决策问题,将其分解成层次结构,然后通过两两比较来确定各个层次的因素之间的相对重要性。这种分析方式允许决策者对问题…

【vue与iframe通讯】

vue 与 iframe 通讯 发送数据vue 向 iframe 发送数据iframe 向 vue 发送数据接收信息( vue & iframe 通用) 实现相互通讯通讯流程图实现代码vue 页面iframe页面iframe内部重定向访问地址,更新vue路由 代码下载 前言:vue嵌套iframe实现步骤 发送数据 vue 向 if…

回溯算法05(leetcode491/46/47)

参考资料: https://programmercarl.com/0491.%E9%80%92%E5%A2%9E%E5%AD%90%E5%BA%8F%E5%88%97.html 491. 非递减子序列 题目描述: 给你一个整数数组 nums ,找出并返回所有该数组中不同的递增子序列,递增子序列中 至少有两个元素…

微软开发者大会:编程进入自然语言时代、“AI员工”闪亮登场

当地时间周二,美国科技公司微软召开年度Build开发者大会。在CEO纳德拉的带领下,微软各个产品团队再一次展现出惊人的执行力,在发布会上又拿出了接近50个新产品或功能更新。 整场发布会持续了接近两个小时,在这里挑选了一些投资者…

知识表示概述

文章目录 知识表示研究现状技术发展趋势 知识表示 知识是人类在认识和改造客观世界的过程中总结出的客观事实、概念、定理和公理的集合。知识具有不同的分类方式,例如按照知识的作用范围可分为常识性知识与领域性知识。知识表示是将现实世界中存在的知识转换成计算机…

Linux进程的地址空间

Linux进程的地址空间 1. 前言 在编写程序语言的代码时&#xff0c;打印输出一个变量的地址时&#xff0c;这个地址在内存中是以什么形式存在的&#xff1f;一个地址可以存储两个不同的值吗&#xff1f; 运行以下代码&#xff1a; #include <stdio.h> #include <un…

C#-根据日志等级进行日志的过滤输出

文章速览 概要具体实施创建Log系统动态修改日志等级 坚持记录实属不易&#xff0c;希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区&#xff01; 谢谢~ 概要 方便后期对软件进行维护&#xff0c;需要在一些关键处添加log日志输出&#xff0c;但时间长…

vulhub——ActiveMQ漏洞

文章目录 一、CVE-2015-5254(反序列化漏洞)二、CVE-2016-3088&#xff08;任意文件写入漏洞&#xff09;2.1 漏洞原理2.2 写入webshell2.3 写入crontab 三、CVE-2022-41678&#xff08;远程代码执行漏洞&#xff09;方法一方法2 四、CVE-2023-46604&#xff08;反序列化命令执行…

HTML+CSS+JS 扩散登录表单动画

效果演示 Code <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,us…

MAIA:多模态自动化可解释智能体的突破

随着人工智能技术的飞速发展&#xff0c;深度学习模型在图像识别、自然语言处理等领域取得了显著成就。然而&#xff0c;这些模型的“黑箱”特性使得其决策过程难以理解&#xff0c;限制了它们的应用范围和可靠性。为了解决这一问题&#xff0c;研究者们提出了多种模型可解释性…

【机器学习】—机器学习和NLP预训练模型探索之旅

目录 一.预训练模型的基本概念 1.BERT模型 2 .GPT模型 二、预训练模型的应用 1.文本分类 使用BERT进行文本分类 2. 问答系统 使用BERT进行问答 三、预训练模型的优化 1.模型压缩 1.1 剪枝 权重剪枝 2.模型量化 2.1 定点量化 使用PyTorch进行定点量化 3. 知识蒸馏…

[emailprotected](7)父子通信,传递元素内容

目录 1&#xff0c;children 属性2&#xff0c;多个属性 普通对象等&#xff0c;可以通过变量直接传递&#xff0c;那类似 vue 中的 slot 插槽&#xff0c;如何传递元素内容&#xff1f; 1&#xff0c;children 属性 实际上&#xff0c;写在自定义组件标签的内部代码&#xf…

【再探】Java—泛型

Java 泛型本质是参数化类型&#xff0c;可以用在类、接口和方法的创建中。 1 “擦除式”泛型 Java的“擦除式”的泛型实现一直受到开发者的诟病。 “擦除式”的实现几乎只需要在Javac编译器上做出改进即可&#xff0c;不要改动字节码、虚拟机&#xff0c;也保证了以前没有使…

k8s pv 一直是release状态

如下图所示&#xff0c;pv 一直是release状态 这个时候大家可能就会想到现在我的 PVC 被删除了&#xff0c;PV 也变成了 Released 状态&#xff0c;那么我重建之前的 PVC 他们不就可以重新绑定了&#xff0c;事实并不会&#xff0c;PVC 只能和 Available 状态的 PV 进行绑定。…

【华为】将eNSP导入CRT,并解决不能敲Tab问题

华为】将eNSP导入CRT&#xff0c;并解决不能敲Tab问题 eNSP导入CRT打开eNSP&#xff0c;新建一个拓扑右键启动查看串口号关联CRT成功界面 SecureCRT连接华为模拟器ensp,Tab键不能补全问题选择Options&#xff08;选项&#xff09;-- Global Options &#xff08;全局选项&#…

ORB-SLAM2从理论到代码实现(六):Tracking程序详解(上)

1. Tracking框架 Tracking线程流程框图&#xff1a; 各流程对应的主要函数 2. Tracking整体流程图 上面这张图把Tracking.cc讲的特别明白。 tracking线程在获取图像数据后&#xff0c;会传给函数GrabImageStereo、GrabImageRGBD或GrabImageMonocular进行预处理&#xff0c;这…