【文本分类】bert二分类

import os
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm

# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }


# 训练函数
def train_model(model, train_loader, optimizer, device, num_epochs=3):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1} Loss: {total_loss / len(train_loader)}")


# 评估函数
def evaluate_model(model, val_loader, device):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    report = classification_report(true_labels, predictions)
    print(f"Validation Accuracy: {accuracy}")
    print("Classification Report:")
    print(report)


# 模型保存函数
def save_model(model, tokenizer, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")


# 模型加载函数
def load_model(output_dir, device):
    tokenizer = BertTokenizer.from_pretrained(output_dir)
    model = BertForSequenceClassification.from_pretrained(output_dir)
    model.to(device)
    print(f"Model loaded from {output_dir}")
    return model, tokenizer


# 推理预测函数
def predict(texts, model, tokenizer, device, max_length=128):
    model.eval()
    encodings = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()
        predictions = torch.argmax(logits, dim=1).cpu().numpy()

    return predictions, probabilities


# 主函数
def main():
    # 配置参数
    config = {
        "train_batch_size": 16,
        "val_batch_size": 16,
        "learning_rate": 5e-5,
        "num_epochs": 5,
        "max_length": 128,
        "device_id": 7,  # 指定 GPU ID
        "model_dir": "model",
        "local_model_path": "roberta_tiny_model",  # 指定本地模型路径,如果为 None 则使用预训练模型
        "pretrained_model_name": "uer/chinese_roberta_L-12_H-128",  # 预训练模型名称
    }

    # 设置设备
    device = torch.device(f"cuda:{config['device_id']}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 加载分词器和模型
    tokenizer = BertTokenizer.from_pretrained(config["local_model_path"])
    model = BertForSequenceClassification.from_pretrained(config["local_model_path"], num_labels=2)
    model.to(device)

    # 示例数据
    train_texts = ["This is a great product!", "I hate this service."]
    train_labels = [1, 0]
    val_texts = ["Awesome experience.", "Terrible product."]
    val_labels = [1, 0]

    # 创建数据集和数据加载器
    train_dataset = CustomDataset(train_texts, train_labels, tokenizer, config["max_length"])
    val_dataset = CustomDataset(val_texts, val_labels, tokenizer, config["max_length"])
    train_loader = DataLoader(train_dataset, batch_size=config["train_batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["val_batch_size"])

    # 定义优化器
    optimizer = AdamW(model.parameters(), lr=config["learning_rate"])

    # 训练模型
    train_model(model, train_loader, optimizer, device, num_epochs=config["num_epochs"])

    # 评估模型
    evaluate_model(model, val_loader, device)

    # 保存模型
    save_model(model, tokenizer, config["model_dir"])

    # 加载模型
    loaded_model, loaded_tokenizer = load_model(config["model_dir"], "cpu")

    # 推理预测
    new_texts = ["I love this!", "It's the worst."]
    predictions, probabilities = predict(new_texts, loaded_model, loaded_tokenizer,  "cpu")
    for text, pred, prob in zip(new_texts, predictions, probabilities):
        print(f"Text: {text}")
        print(f"Predicted Label: {pred} (Probability: {prob})")


if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/950363.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32烧写失败之Contents mismatch at: 0800005CH (Flash=FFH Required=29H) !

一)问题:用ULINK2给STM32F103C8T6下载程序,下载方式设置如下: 出现下面两个问题: 1)下载问题界面如下: 这个错误的信息大概可以理解为,在0x08000063地址上读取到flash存储为FF&am…

vscode通过ssh连接服务器实现免密登录

一、通过ssh连接服务器 1、打开vscode,进入拓展(CtrlShiftX),下载拓展Remote - SSH。 2、点击远程资源管理器选项卡,选择远程(隧道/SSH)类别。 3、点击SSH配置。 4、在中间上部分弹出的配置文件…

在Nvidia Jetson ADX Orin中使用TensorRT-LLM运行llama3-8b

目录 背景:步骤 1.获取模型权重第 2 步:准备第 3 步:构建 TensorRT-LLM 引擎 背景: 大型语言模型 (LLM) 推理的关键瓶颈在于 GPU 内存资源短缺。因此,各种加速框架主要强调减少峰值 GPU 内存使…

Unity Shader学习日记 part4 Shader 基础结构

其实在这一篇之前,应该还有一个关于坐标空间转换的内容,但是内容囤积的有些多,就先把Shader的基础结构先记录一下。 笔记主要记录在代码中,所以知识点主要是图和代码的展示。 Unity Shader分类 在Unity中,Shader的种…

特征点检测与匹配——MATLAB R2022b

特征点检测与匹配在计算机视觉中的作用至关重要,它为图像处理、物体识别、增强现实等领域提供了坚实的基础。 目录 Harris角点检测 SIFT(尺度不变特征变换) SURF(加速稳健特征) ORB(Oriented FAST and Rotated BRIEF) 总结 特征点检测与匹配是计算机视觉中的一项基…

Airflow:HttpSensor实现API驱动数据流程

数据管道工作流通常依赖于api来访问、获取和处理来自外部系统的数据。为了处理这些场景,Apache Airflow提供了HttpSensor,这是一个内置的Sensor,用于监视HTTP请求的状态,并在满足指定条件时触发后续任务。在这篇博文中&#xff0c…

图数据库 | 17、高可用分布式设计(上)

我们在前面的文章中,探索了多种可能的系统扩展方式,以及每种扩展方式的优劣。 本篇文章将通过具体的架构设计方案来对每一种方案的设计、投入产出比、各项指标与功能,以及孰优孰劣等进行评价。 在设计高性能、高可用图数据库的时候&#xf…

JAVA学习记录1

文章为个人学习记录,仅供参考,如有错误请指出。 什么是JAVA? JAVA是一种高级的编程语言,可以用于开发大部分场景的软件,但主要用于服务器的开发。 什么是JDK? 类似于python使用PyCharm来编写代码&#…

css中的部分文字特性

文章目录 一、writing-mode二、word-break三、word-spacing;四、white-space五、省略 总结归纳常见文字特性,后续补充 一、writing-mode 默认horizontal-tbwriting-mode: vertical-lr; 从第一排开始竖着排,到底部再换第二排,文字与文字之间从…

Android wifi常见问题及分析

参考 Android Network/WiFi 那些事儿 前言 本文将讨论几个有意思的网络问题,同时介绍 Android 上常见WiFi 问题的分析思路。 网络基础Q & A 一. 网络分层缘由 分层想必大家很熟悉,是否想过为何需要这样分层? 网上大多都是介绍每一层…

【C语言】_指针与数组

目录 1. 数组名的含义 1.1 数组名与数组首元素的地址的联系 1.3 数组名与首元素地址相异的情况 2. 使用指针访问数组 3. 一维数组传参的本质 3.1 代码示例1:函数体内计算sz(sz不作实参传递) 3.2 代码示例2:sz作为实参传递 3…

IDEA 字符串拼接符号“+”位于下一行的前面,而不是当前行的末尾

效果图 IDEA 默认效果是“历史效果”,经过修改后为“预期效果” 设置方式 在设置中找到Editor > Code Style > Java > Wrapping and Braces > Binary expressions > 勾选 Operation sign on next line 即可实现。具体设置如图。

牛客网刷题 ——C语言初阶(2分支和循环-for)——打印菱形

1. 题目描述 用C语言在屏幕上输出以下图案: 2. 思路 我是先上手,先把上半部分打印出来,然后慢慢再来分析,下面这是我先把整个上半部分打印出来,因为空格不方便看是几个,这里先用&代替空格了 然后这里…

C# 整型、浮点型 数值范围原理分析

总目录 前言 一、整型、浮点型 数值范围列表 二、什么是大小、范围 在上面的列表中,每个数据类型都有自己的Range (范围) 和 Size (大小)。 1. 范围 范围好理解,就是对应数据类型的数据范围,如 sbtyte 的数据范围是 -128~127,超…

安装vue脚手架出现的一系列问题

安装vue脚手架出现的一系列问题 前言使用 npm 安装 vue/cli2.权限问题及解决方法一:可以使用管理员权限进行安装。方法二:更改npm全局安装路径 前言 由于已有较长时间未进行 vue 项目开发,今日着手准备开发一个新的 vue 项目时,在…

Qt 5.14.2 学习记录 —— 칠 QWidget 常用控件(2)

文章目录 1、Window Frame2、windowTitle3、windowIcon4、qrc机制5、windowOpacity 1、Window Frame 在运行Qt程序后,除了用户做的界面,最上面还有一个框,这就是window frame框。对于界面的元素,它们的原点是Qt界面的左上角或win…

数据结构大作业——家谱管理系统(超详细!完整代码!)

目录 设计思路: 一、项目背景 二、功能分析 查询功能流程图: 管理功能流程图: 三、设计 四、实现 代码实现: 头文件 结构体 函数声明及定义 创建家谱树头结点 绘制家谱树(打印) 建立右兄弟…

springboot参数注解

在Spring Boot中,创建RESTful API时,通常会使用Spring MVC提供的注解来声明请求参数。以下是一些常用的注解及其用途: 1. RequestBody 用途:用于将HTTP请求的body部分绑定到方法参数上,通常用于接收JSON或XML格式的数…

uniapp实现在card卡片组件内为图片添加长按保存、识别二维码等功能

在原card组件的cover属性添加图片的话&#xff0c;无法在图片上面绑定 show-menu-by-longpress"true"属性&#xff0c;通过将图片自定义添加可使用该属性。 代码&#xff1a; <uni-card title"标题" padding"10px 0" :thumbnail"avata…

【Springer斯普林格出版,Ei稳定,往届快速见刊检索】第四届电子信息工程、大数据与计算机技术国际学术会议(EIBDCT 2025)

第四届电子信息工程、大数据与计算机技术国际学术会议&#xff08;EIBDCT 2025&#xff09;将于2025年2月21-23日在中国青岛举行。该会议主要围绕电子信息工程、大数据、计算机技术等研究领域展开讨论。会议旨在为从事相关科研领域的专家学者、工程技术人员、技术研发人员提供一…