bert分类模型使用

使用 bert-bert-chinese 预训练模型去做分类任务,这里找了新闻分类数据,数据有 20w,来自https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch/tree/master/THUCNews

数据 20w ,18w 训练数据,1w 验证数据, 1w 测试数据,10个类别我跑起来后,预测要7天7夜,于是吧数据都缩小了一些,每个类别抽一些,1800 训练数据,150 验证数据, 150 测试数据,都跑了 1.5 小时, cpu ,电脑 gpu 只有 2g 显存,带不起来

bert- base-chinses 模型下载:bert预训练模型下载-CSDN博客

训练

现在是大模型时代了,这篇文章的代码是利用大模型帮我写的的,通过大模型修正代码,并解释代码一直到可用,代码都写了注释了,整个分类流程就这样,算是一个通用模板了吧

train.py 

# 导入所需的库
import torch
import os
import time
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW


# 定义数据集类,符合高内聚原则
class NewsTitleDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len=128):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                title, label = line.strip().split('\t')
                inputs = tokenizer(title, padding='max_length', truncation=True, max_length=max_len)
                self.data.append({'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'label': int(label)})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        '''在使用DataLoader加载数据进行训练或验证时被调用'''
        return self.data[idx]


# 训练函数(部分代码,实际训练时应包含更多细节如损失计算、模型更新等)
def train_model(model, train_loader, val_loader, optimizer, epochs=3, model_save_path='../output/bert_news_classifier'):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU
    device = torch.device("cpu")
    model.to(device)

    best_val_accuracy = None  # 初始化最优验证集准确率

    # 创建保存目录(如果不存在)
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

    # 训练几次模型
    for epoch in range(epochs):
        model.train()  # 开启训练模式,会更新参数
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)  # 直接通过键名访问'input_ids'
            attention_mask = batch['attention_mask'].to(device)  # 直接通过键名访问'attention_mask'
            labels = batch['label'].to(device)  # 直接通过键名访问'label'

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss  # 获取损失

            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播
            optimizer.step()  # 更新权重

        # 在每个epoch结束时评估模型性能
        model.eval()
        with torch.no_grad():
            val_loss = 0
            correct_predictions = 0
            total_samples = len(val_data)  # 计算总样本数,用于后续计算准确率

            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].squeeze().to(device)

                # 计算logits而不是直接获取loss
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                # logits 是模型对输入数据计算出的未归一化的类别概率分布,通常是一个形状为 (batch_size, num_classes) 的张量
                logits = outputs.logits

                # 手动计算loss(假设labels已转换为one-hot编码或数值标签)
                # 交叉熵损失函数,多分类问题中常用的损失函数,特别适合于处理像BERT这样的预训练模型输出的logits,并且与one-hot编码的目标标签一起使用
                loss_fct = torch.nn.CrossEntropyLoss()
                # labels 是实际的类别标签,需要转换成一个形状为 (batch_size,) 的张量以匹配logits的展开维度。
                # view(-1, model.num_labels) 会将logits展平为 (batch_size * num_classes) 的向量,使得每个样本的每个类别都有一个单独的概率值对应
                loss = loss_fct(logits.view(-1, model.num_labels), labels.view(-1))
                # .item() 方法用于从损失张量提取标量值。
                val_loss += loss.item()

                # 找出每个样本的最大概率对应的类别索引,即模型预测的结果。
                # dim=1 时,表示在第二个维度上找到最大值
                _, preds = torch.max(logits, dim=1)
                correct_predictions += (preds == labels).sum().item()

        val_accuracy = correct_predictions / total_samples
        print(f'Epoch {epoch + 1}, Validation Loss: {val_loss / len(val_loader):.4f}, Accuracy: {val_accuracy * 100:.2f}%')

        # 如果当前验证集上的准确率优于之前保存的最佳模型,则保存当前模型
        if best_val_accuracy is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)  # 保存模型参数


# 定义评估函数
def evaluate_model(model, data_loader):
    device = next(model.parameters()).device
    model.eval()
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs = {key: batch[key].to(device) for key in ['input_ids', 'attention_mask']}
            labels = batch['label'].to(device)

            outputs = model(**inputs)
            _, preds = torch.max(outputs.logits, dim=1)

            correct_predictions += (preds == labels).sum().item()
            total_samples += len(labels)

    return correct_predictions / total_samples


def collate_to_tensors(batch):
    input_ids = torch.tensor([example['input_ids'] for example in batch])
    attention_mask = torch.tensor([example['attention_mask'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'label': labels}


start = time.time()

# 加载预训练的tokenizer和模型
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
with open('../data/class.txt', 'r', encoding='utf8') as f:
    class_labels = f.readlines()
model = BertForSequenceClassification.from_pretrained('../bert-base-chinese', num_labels=len(class_labels))  # 假设class_labels是一个包含所有类别的列表

# 加载训练、验证和测试数据集
train_data = NewsTitleDataset('../data/train.txt', tokenizer)
val_data = NewsTitleDataset('../data/dev.txt', tokenizer)
test_data = NewsTitleDataset('../data/test.txt', tokenizer)

# 创建DataLoader,用于批处理数据
# collate_to_tensors 调用函数,保证模型接受的数据参数类型一定为 pytorch 的张量类型
# shuffle=True 于防止模型过拟合和提高泛化性能至关重要,因为它确保了模型不会因为训练数据的顺序而产生依赖性。
# batch_size 示每次迭代从数据集中取出多少个样本作为一个批次(batch)进行训练。设置合理的批量大小有助于平衡计算效率和内存使用。
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_to_tensors)
val_loader = DataLoader(val_data, batch_size=32, collate_fn=collate_to_tensors)
test_loader = DataLoader(test_data, batch_size=32, collate_fn=collate_to_tensors)

# 设置优化器与学习率
# model.parameters():这是PyTorch中的一个方法,用于获取模型的所有可训练参数。
# lr代表学习率(Learning Rate),它是一个超参数,决定了在每个训练步骤中更新模型参数的幅度大小。给定值 2e-5 表示0.00002
optimizer = AdamW(model.parameters(), lr=2e-5)

# 开始训练
train_model(model, train_loader, val_loader, optimizer, model_save_path='../output/best_bert_news_classifier.pth')

# 测试模型(仅评估,不更新参数)
test_acc = evaluate_model(model, test_loader)
print(f'Test Accuracy: {test_acc * 100:.2f}%')
print(time.time() - start)

运行结果

 

预测

假如只想输入一个文本,直接得到疯了及结果,可以使用一下代码

import torch
from transformers import BertTokenizer
from transformers import BertForSequenceClassification


# 假设 model_state_dict 是从文件加载的模型参数
with open('../data/class.txt', 'r', encoding='utf8') as f:
    class_labels = f.readlines()
model = BertForSequenceClassification.from_pretrained('../bert-base-chinese', num_labels=len(class_labels))  # 初始化模型结构,并指定分类类别数量

# 假设 tokenizer 是您在训练时使用的 BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
# 加载模型参数,训练好输出的模型参数
model.load_state_dict(torch.load('../output/best_bert_news_classifier.pth'))
model.eval()  # 设置模型为评估模式


def predict_news_category(model, tokenizer, text):
    # 对文本进行预处理并编码
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,  # 根据实际情况调整最大长度
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # 将数据传递给模型以获取logits
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # 获取分类结果
    logits = outputs.logits
    _, prediction = torch.max(logits, dim=1)

    # 返回预测类别索引,实际应用中可能需要将其映射回原始类别标签
    return prediction.item()

# 示例:输入一条新闻标题并预测类别
text = "车载大模型是原子弹还是茶叶蛋?"
predicted_category = predict_news_category(model, tokenizer, text)
print(f"预测的新闻类别是:{predicted_category}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/371650.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

挑战!贪吃蛇小游戏的实现(1)

引言 相信大家都玩过贪吃蛇这个游戏! 玩家控制一个不断移动的蛇形角色,在一个封闭空间内移动。随着时间推进,这个蛇形角色会逐渐增长,通常是通过吞食屏幕上出现的物品(如点或者其他标志)来实现。每当贪吃…

JQuery动态插入Bootstrap模态框(Modal)

这里所说的动态插入,是指用JS的append()方式追加元素内容,而不是静态写在HTML里面。 为什么会用到这种方式呢?比如登录框。有些网站在大部分页面都有登录按钮,如果是用Bootstrap的模态框调用的话,常规方式都是写在HTM…

目标检测及相关算法介绍

文章目录 目标检测介绍目标检测算法分类目标检测算法模型组成经典目标检测论文 目标检测介绍 目标检测是计算机视觉领域中的一项重要任务,旨在识别图像或视频中的特定对象的位置并将其与不同类别中的对象进行分类。与图像分类任务不同,目标检测不仅需要…

vue全家桶之状态管理Pinia

一、Pinia和Vuex的对比 1.什么是Pinia呢? Pinia(发音为/piːnjʌ/,如英语中的“peenya”)是最接近pia(西班牙语中的菠萝)的词; Pinia开始于大概2019年,最初是作为一个实验为Vue重新…

详解C++类和对象(上)

文章目录 写在前面1. 类的定义2. 类的访问限定符及封装2.1 类的访问限定符2.2 封装 3. 类的作用域4. 类的实例化5 类的对象大小的计算6. 类成员函数的this指针 写在前面 类和对象这一章节,分为上、中、下三篇文章进行拆分介绍的,本篇文章介绍了类和对象…

LabVIEW与EtherCAT实现风洞安全联锁及状态监测

LabVIEW与EtherCAT实现风洞安全联锁及状态监测 在现代风洞试验中,安全联锁与状态监测系统发挥着至关重要的作用,确保了试验过程的安全性与高效性。介绍了一套基于EtherCAT总线技术和LabVIEW软件开发的风洞安全联锁及状态监测系统。该系统通过实时、可靠…

C++后端开发之Sylar学习二:配置VSCode远程连接Ubuntu开发

C后端开发之Sylar学习二:配置VSCode远程连接Ubuntu开发 没错,我不能像大佬那样直接在Ubuntu上面用Vim手搓代码,只能在本地配置一下VSCode远程连接Ubuntu进行开发咯! 本篇主要是讲解了VSCode如何配置ssh连接Ubuntu,还有…

蓝桥杯每日一题-----数位dp练习

题目 链接 参考代码 写了两个,一个是很久以前写的,一个是最近刚写的,很久以前写的时候还不会数位dp所以写了比较详细的注释,这两个代码主要是设置了不同的记忆数组,通过这两个代码可以理解记忆数组设置的灵活性。 im…

UE4运用C++和框架开发坦克大战教程笔记(十七)(第51~54集)

UE4运用C和框架开发坦克大战教程笔记(十七)(第51~54集) 51. UI 框架介绍UE4 使用 UI 所面临的问题以及解决思路关于即将编写的 UI 框架的思维导图 52. 管理类与面板类53. 预加载与直接加载54. UI 首次进入界面 51. UI 框架介绍 U…

【C++】运算符重载详解

&#x1f497;个人主页&#x1f497; ⭐个人专栏——C学习⭐ &#x1f4ab;点击关注&#x1f929;一起学习C语言&#x1f4af;&#x1f4ab; 目录 导读 1. 为什么需要运算符重载 2. 运算符重载概念 3. 运算符重载示例 3.1 运算符重载 3.2 >或<运算符 4. 运算符重…

2024最新最详细【接口测试总结】

序章 ​ 说起接口测试&#xff0c;网上有很多例子&#xff0c;但是当初做为新手的我来说&#xff0c;看了不不知道他们说的什么&#xff0c;觉得接口测试&#xff0c;好高大上。认为学会了接口测试就能屌丝逆袭&#xff0c;走上人生巅峰&#xff0c;迎娶白富美。因此学了点开发…

分享个前端工具-取色调色工具

这里虽然贴了两个&#xff0c;但推荐 Pipette. PipetteWin22.10.22.zip: https://download.csdn.net/download/rainyspring4540/88799632 图标&#xff1a; 界面&#xff1a; ColorPix https://download.csdn.net/download/rainyspring4540/88799642 图标&#xff1a; 界面…

【Spring】自定义注解 + AOP 记录用户的使用日志

目录 ​编辑 自定义注解 AOP 记录用户的使用日志 使用背景 落地实践 一&#xff1a;自定义注解 二&#xff1a;切面配置 三&#xff1a;Api层使用 使用效果 自定义注解 AOP 记录用户的使用日志 使用背景 &#xff08;1&#xff09;在学校项目中&#xff0c;安防平台…

FastAdmin西陆房产系统(xiluHouse)全开源

应用介绍 一款基于FastAdminThinkPHPUniapp开发的西陆房产管理系统&#xff0c;支持小程序、H5、APP&#xff1b;包含房客、房东(高级授权)、经纪人(高级授权)三种身份。核心功能有&#xff1a;新盘销售、房屋租赁、地图找房、房源代理(高级授权)、在线签约(高级授权)、电子合同…

C#实现坐标系转换

已知坐标系的向量线段AB&#xff0c;旋转指定角度后平移到达坐标AB 获取旋转角度以及新的其他坐标转换。 新建窗体应用程序CoordinateTransDemo&#xff0c;将默认的Form1重命名为FormCoordinateTrans&#xff0c;窗体设计如图&#xff1a; 窗体设计代码如下&#xff1a; 部分…

Redis-缓存问题及解决方案

本文已收录于专栏 《中间件合集》 目录 概念说明缓存问题缓存击穿问题描述解决方案 缓存穿透问题描述解决方案 缓存雪崩问题描述解决方案提高缓存可用性过期时间配置熔断降级 总结提升 概念说明 Redis是一个开源的内存数据库&#xff0c;也可以用作缓存系统。它支持多种数据结构…

前端小案例——动态导航栏文字(HTML + CSS, 附源码)

一、前言 实现功能: 这案例是一个具有动态效果的导航栏。导航栏的样式设置了一个灰色的背景&#xff0c;并使用flex布局在水平方向上平均分配了四个选项。每个选项都是一个li元素&#xff0c;包含一个文本和一个横向的下划线。 当鼠标悬停在选项上时&#xff0c;选项的文本颜色…

华为自动驾驶干不过特斯拉?

文 | AUTO芯球 作者 | 李诞 什么&#xff1f; 华为的智能驾驶方案干不过蔚小理&#xff1f; 特斯拉的智能驾驶[FSD]要甩中国车企几条街&#xff1f; 这华为问界阿维塔刚刚推送“全国都能开”的城区“无图 NCA” 就有黑子来喷了 这是跪久了站不起来了吧 作为玩车14年&…

get通过发送Body传参-工具类

1、调用方式 String url "http://ip/xxx/zh/xxxxx/xxxx/userCode"; //进行url中的对应的参数 url2 url2.replace("ip",bancirili); url2 url2.replace("zh",zh); url2 url2.replace("userCode",userCode);String dateTime xxxx; //组…

04. 【Linux教程】安装 Linux 操作系统

通过前面的小节学习&#xff0c;我们已经对 Linux 操作系统有了简单的了解&#xff0c;同时也在 Windows 下安装了虚拟机软件 VMware &#xff0c;那么本节课我们就介绍下如何使用虚拟机软件安装 Linux 操作系统。 通过第一小节的学习我们知道 Linux 有很多的发行版本&#xf…