基于 chinese-roberta-wwm-ext 微调训练 6 分类情感分析模型

一、模型和数据集介绍

1.1 预训练模型

chinese-roberta-wwm-ext 是基于 RoBERTa 架构下开发,其中 wwm 代表 Whole Word Masking,即对整个词进行掩码处理,通过这种方式,模型能够更好地理解上下文和语义关联,提高中文文本处理的准确性和效果。

与原始的 BERT 模型相比,chinese-roberta-wwm-ext 在训练数据规模和训练步数上做了一些调整,以进一步提升模型的性能和鲁棒性。并且在大规模无监督语料库上进行了预训练,使其具备强大的语言理解和生成能力。它能够广泛应用于各种自然语言处理任务,如文本分类、命名实体识别、情感分析等。我们可以使用这个模型作为基础,在不同的任务上进行微调和迁移学习,以实现更准确、高效的中文文本处理。

huggingface地址:https://huggingface.co/hfl/chinese-roberta-wwm-ext

进到 huggingface 中下载预训练模型:

在这里插入图片描述

1.2 数据集

数据集采用 SMP2020微博情绪分类评测 ,进入下面链接下载数据集:

https://smp2020ewect.github.io/

在这里插入图片描述

下载后数据分了三种类型,训练集、测试集、验证集

在这里插入图片描述

数据的格式为 JSON 格式,结构如下:

在这里插入图片描述

其中 label 为当前 content 的分类,一共分了 6 种类别,如下所示:

label说明
fear恐惧
neutral无情绪
sad悲伤
surprise惊奇
angry愤怒
happy积极

二、模型微调训练

2.1、处理数据集构建DataLoader

from transformers import BertTokenizer
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
import numpy as np
import torch
import json
import os


class GenDateSet():
    def __init__(self, tokenizer, train_file, val_file, label_dict, max_length=128, batch_size=10):
        self.train_file = train_file
        self.val_file = val_file
        self.max_length = max_length
        self.batch_size = batch_size
        self.label_dict = label_dict
        self.tokenizer = tokenizer

    def gen_data(self, file):
        if not os.path.exists(file):
            raise Exception("数据集不存在")
        input_ids = []
        input_types = []
        input_masks = []
        labels = []
        with open(file, encoding='utf8') as f:
            data = json.load(f)
        if not data:
            raise Exception("数据集不存在")

        # 处理数据
        for index, item in enumerate(data):
            text = item['content']
            label = item['label']
            tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length)
            input_id, types, masks = tokens['input_ids'], tokens['token_type_ids'], tokens['attention_mask']
            input_ids.append(input_id)
            input_types.append(types)
            input_masks.append(masks)
            y_ = self.label_dict[label]
            labels.append([y_])

            if index % 1000 == 0:
                print('处理', index, '条数据')

        # 构建 TensorDataset
        data_gen = TensorDataset(torch.LongTensor(np.array(input_ids)),
                                 torch.LongTensor(np.array(input_types)),
                                 torch.LongTensor(np.array(input_masks)),
                                 torch.LongTensor(np.array(labels)))
        # 打乱
        sampler = RandomSampler(data_gen)
        # 构建 DataLoader
        return DataLoader(data_gen, sampler=sampler, batch_size=self.batch_size)

    def gen_train_data(self):
        # 生成训练集
        return self.gen_data(self.train_file)

    def gen_val_data(self):
        # 生成验证集
        return self.gen_data(self.val_file)

2.2 构建模型迭代训练

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn as nn
from tqdm import tqdm
from gen_datasets import GenDateSet

# 标签结构
label_dict = {
    'fear': 0,
    'neutral': 1,
    'sad': 2,
    'surprise': 3,
    'angry': 4,
    'happy': 5
}

# 预训练模型位置
model_dir = 'D:\\AIGC\\model\\chinese-roberta-wwm-ext'
# 这里暂时使用测试集训练,数据较少
train_file = 'data/usual_train.txt'
# train_file = 'data/usual_test_labeled.txt'
# 验证集
val_file = 'data/virus_eval_labeled.txt'
# 训练模型存储位置
save_model_path = './model/'
# 最大长度
max_length = 128
# 分类数
num_classes = 6
batch_size = 10
epoch = 10

def val(model, device, data):
    model.eval()
    test_loss = 0.0
    acc = 0
    for (input_id, types, masks, y) in tqdm(data):
        input_id, types, masks, y = input_id.to(device), types.to(device), masks.to(device), y.to(device)
        with torch.no_grad():
            y_ = model(input_id, token_type_ids=types, attention_mask=masks)
            logits = y_['logits']
        test_loss += nn.functional.cross_entropy(logits, y.squeeze())
        pred = logits.max(-1, keepdim=True)[1]
        acc += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= len(data)
    return acc / len(data.dataset)


def main():
    # 加载 tokenizer 和  model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=num_classes)
    # 加载数据集
    dateset = GenDateSet(tokenizer, train_file, val_file, label_dict, max_length, batch_size)
    # 训练集
    train_data = dateset.gen_train_data()
    # 验证集
    val_data = dateset.gen_val_data()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    model = model.to(device)
    # 优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    best_acc = 0.0
    for epoch_index in range(epoch):
        batch_epoch = 0
        for (input_id, types, masks, y) in tqdm(train_data):
            input_id, types, masks, y = input_id.to(device), types.to(device), masks.to(device), y.to(device)
            # 前向传播
            outputs = model(input_id, token_type_ids=types, attention_mask=masks, labels=y)
            # 梯度清零
            optimizer.zero_grad()
            # 计算 loss
            loss = outputs.loss
            # 反向传播
            loss.backward()
            optimizer.step()
            batch_epoch += 1
            if batch_epoch % 10 == 0:
                print('Train Epoch:', epoch_index, ' , batch_epoch: ', batch_epoch, ' , loss = ', loss.item())

        # 评估准确度
        acc = val(model, device, val_data)
        print('Train Epoch:', epoch_index, ' val acc = ', acc)
        # 存储 best model
        if best_acc < acc:
            # torch.save(model.state_dict(), save_model_path)
            model.save_pretrained("./model")
            tokenizer.save_pretrained("./model")
            best_acc = acc

if __name__ == '__main__':
    main()

运行之后可以看到训练进度:

在这里插入图片描述

训练中可以看到验证集的准确率,等待训练结束后,可以在 model 下看到保存的模型和 tokenizer 文件:

在这里插入图片描述

三、模型测试

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

label_dict = {
    0: '恐惧',
    1: '无情绪',
    2: '悲伤',
    3: '惊奇',
    4: '愤怒',
    5: '积极'
}
model_dir = "./model"
num_classes = 6
max_length = 128


def main():
    # 加载预训练模型和分词器
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=num_classes)

    while True:
        text = input("请输入内容: \n ")
        if not text or text == "":
            continue
        if text == "q":
            break

        encoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=max_length)

        input_ids = torch.tensor([encoded_input['input_ids']])
        token_type_ids = torch.tensor([encoded_input['token_type_ids']])
        attention_mask = torch.tensor([encoded_input['attention_mask']])
        # 前向传播
        y_ = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        output = y_['logits'][0]
        pred = output.max(-1, keepdim=True)[1][0].item()
        print('预测结果:', label_dict[pred])

if __name__ == '__main__':
    main()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/39820.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DuiLib中的list控件以及ListContainerElement控件

文章目录 前言1、创建list控件2、创建 ListContainerElement 元素&#xff0c;并添加到 List 控件中,这里的ListContainerElement用xml来表示3、在 ListContainerElement 元素中添加子控件 1、List控件2、ListContainerElement控件 前言 在 Duilib 中&#xff0c;List 控件用于…

Python 集合 add()函数使用详解,集合添加元素

「作者主页」&#xff1a;士别三日wyx 「作者简介」&#xff1a;CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」&#xff1a;小白零基础《Python入门到精通》 add函数使用详解 1、元素的顺序2、可以添加的元素类型3、添加重复的元素4、一次只…

Python爬虫学习笔记:1688商品详情API 开发API接口文档

1688API接口是阿里巴巴集团推出的一种开放平台&#xff0c;提供了丰富的数据接口、转换工具以及开发资源&#xff0c;为开发者提供了通用的应用接口及大量数据资源&#xff0c;支持开发者在1688上进行商品搜索、订单管理、交易报表及物流等方面的操作。 1688API接口主要包含以…

Unity游戏源码分享-单车骑行游戏

Unity游戏源码分享-单车骑行游戏 项目地址&#xff1a;https://download.csdn.net/download/Highning0007/88057717

MySQL之DML和DDL

1、显示所有职工的基本信息&#xff1a; 2、查询所有职工所属部门的部门号&#xff0c;不显示重复的部门号。 3、求出所有职工的人数。 4、列出最高工和最低工资。 5、列出职工的平均工资和总工资。 6、创建一个只有职工号、姓名和参加工作的新表&#xff0c;名为工作日期表。 …

react报错信息

报错信息 render函数里dom不能直接展示obj对象 取变量记得要有{} https://segmentfault.com/q/1010000009619339 这样在写的时候就已经执行方法了&#xff0c;所以此处用箭头函数&#xff08;&#xff09;》{}才会在点击时执行或者 遍历数据使用map来遍历&#xff0c;使用forea…

TCP和UDP的区别

连接&#xff1a;TCP 是面向连接的传输层协议&#xff0c;传输数据前先要建立连接&#xff1b;UDP 是不需要连接&#xff0c;即刻传输数据。首部开销&#xff1a;TCP 首部长度较长&#xff0c;首部在没有使用「选项」字段时是 20 个字节&#xff0c;如果使用了「选项」字段则会…

概率论的学习和整理17:EXCEL的各种期望,方差的公式

目录 1 总结 1.1 本文目标总结方法 1.2 总结一些中间关键函数 2 均值和期望 2.1 求均值的公式 2.2 求随机变量期望的公式 2.3 求随机变量期望的朴素公式 3 方差 3.1 确定数的方差 3.2 统计数的方差公式 3.3 随机变量的方差公式 3.4 EXCEL提供的直接计算方差的公式 …

CentOS目录详解

在centos中&#xff0c;最顶层的目录称作根目录&#xff0c; 用/表示。/目录下用户可以再创建目录&#xff0c;但是有一些目录随着系统创建就已经存在&#xff0c;接下来重点介绍几个常用目录。 /bin&#xff08;binary&#xff09;包含了许多所有用户都可以访问的可执行文件&a…

PostgreSQL MVCC的弊端优化方案

我们之前的博客文章“我们最讨厌的 PostgreSQL 部分”讨论了大家最喜欢的 DBMS 多版本并发控制 (MVCC) 实现所带来的问题。其中包括版本复制、表膨胀、索引维护和真空管理。本文将探讨针对每个问题优化 PostgreSQL 的方法。 尽管 PostgreSQL 的 MVCC 实现是 Oracle 和 MySQL 等…

如何在Appium中使用AI定位

当我们在写自动化测试脚本的时候&#xff0c;传统情况下一定要知道元素的属性&#xff0c;如id、name、class等。那么通过AI的方式定位元素可能就不需要知道元素的属性&#xff0c;评价人对元素的判断来定位&#xff0c;比如&#xff0c;看到一个搜索框&#xff0c;直接使用ai:…

【无标题】使用html2canvas和jspdf生成的pdf在不同大小的屏幕下文字大小不一样

问题&#xff1a;使用html2canvas和jspdf生成的pdf在不同大小的屏幕下文字大小不一样&#xff0c;在mac下&#xff0c;一切正常&#xff0c;看起来很舒服&#xff0c;但是当我把页面放在扩展屏幕下&#xff08;27寸&#xff09;&#xff0c;再生成一个pdf&#xff0c;虽然排版一…

【算法基础:数据结构】2.3 并查集

文章目录 并查集算法原理&#xff08;重要&#xff01;⭐&#xff09; 经典例题836. 合并集合&#xff08;重要&#xff01;模板&#xff01;⭐&#xff09;837. 连通块中点的数量&#xff08;维护连通块大小的并查集&#xff09;240. 食物链&#xff08;维护额外信息的并查集&…

【AI绘画】AI绘画乐趣:稳定增强扩散技术展现

目录 前言一、Stable Diffusion是什么&#xff1f;二、安装stable-diffusion-webui1. python安装2. 下载模型3. 开始安装&#xff1a;4. 汉化&#xff1a;5. 模型使用&#xff1a;6. 下载新模型&#xff1a;7. 基础玩法 三、总结 前言 本文将借助stable-diffusion-webui项目来…

【数据结构与算法】哈夫曼编码(最优二叉树)实现

哈夫曼编码 等长编码&#xff1a;占的位置一样 变长编码&#xff08;不等长编码&#xff09;&#xff1a;经常使用的编码比较短&#xff0c;不常用的比较短 最优&#xff1a;总长度最短 最优的要求&#xff1a;占用空间尽可能短&#xff0c;不占用多余空间&#xff0c;且不…

网络版计算器

本次我们实现一个服务器版本的简单的计算器&#xff0c;通过自己在应用层定制协议来将数据传输出去。 协议代码 此处我们为了解耦&#xff0c;采用了两个类&#xff0c;分别表示客户端的请求和服务端的响应。 Request class Request { public:Request(){}Request(int x, int…

Unity 任意数据在Scene窗口Debug

任意数据在Scene窗口Debug &#x1f354;效果&#x1f96a;食用方法 &#x1f354;效果 如下所示可以很方便的把需要Debug的数据绘制到Scene中&#xff08;普通的Editor脚本只能够对MonoBehaviour进行Debug&#xff09; &#x1f96a;食用方法 &#x1f4a1;. 新建脚本继承Z…

实例018 类似windows xp的程序界面

实例说明 在Windows XP环境下打开控制面板&#xff0c;会发现左侧的导航界面很实用。双击展开按钮&#xff0c;导航栏功能显示出来&#xff0c;双击收缩按钮&#xff0c;导航按钮收缩。下面通过实例介绍此种主窗体的设计方法。运行本例&#xff0c;效果如图1.18所示。 ​编辑…

如何顺势而为,让ChatGPT为教育所用?

恐惧和回避无法阻挡科技的浪潮&#xff0c;教育与AI的深度融合时代已经到来&#xff0c;如何把AI当做工具&#xff0c;把其成为教育的机会而非威胁&#xff0c;是教育体系未来不得不得面对的新变化。 接受ChatGPT作为一种教学辅助工具&#xff0c;成为教师的朋友或者帮手&…

Leetcode每日一题:979. 在二叉树中分配硬币(2023.7.14 C++)

目录 979. 在二叉树中分配硬币 题目描述&#xff1a; 实现代码与解析&#xff1a; dfs&#xff08;后序遍历&#xff09; 原理思路&#xff1a; 979. 在二叉树中分配硬币 题目描述&#xff1a; 给定一个有 N 个结点的二叉树的根结点 root&#xff0c;树中的每个结点上都对…