N3 中文文本分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

前面学习了相关自然语言编码,这周进行相关实战

导入依赖库和设置设备

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")  # 忽略警告
# win10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

这段代码导入了必要的库并设置了设备(GPU或CPU)。

数据预处理和词汇表构建

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')  # 返回分词器函数

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # 设置默认索引,如果找不到单词,则会选择默认索引

这里使用torchtext库加载AG_NEWS数据集,定义了一个分词器并构建了词汇表。

数据处理管道

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
text_pipeline('here is the an example')

定义了两个数据处理管道:text_pipeline用于将文本转化为词汇表中的索引序列,label_pipeline用于将标签转化为整数索引。

定义数据加载器

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]

    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # 返回维度dim中输入元素的累计和

    return label_list.to(device), text_list.to(device), offsets.to(device)

dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

定义了一个collate_batch函数用于将一个批次的数据整合在一起,并创建了一个数据加载器。

定义模型

from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定义了一个文本分类模型TextClassificationModel,包括初始化函数、权重初始化和前向传播函数。模型由一个嵌入层和一个线性层组成。

训练和评估函数

import time

def train(dataloader):
    model.train()  # 切换为训练模式
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()  # grad属性归零
        loss = criterion(predicted_label, label)  # 计算网络输出和真实值之间的差距,label为真实值
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                                  total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()  # 切换为测试模式
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)  # 计算loss值
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, train_loss / total_count

定义了训练和评估函数,用于训练模型和评估模型性能。

数据集分割和数据加载器创建

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

EPOCHS = 10  # epoch
LR = 5  # 学习率
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter, test_iter = AG_NEWS()  # 加载数据
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)

split_train_, split_valid_ = random_split(train_dataset,
                                          [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)

加载数据集并将其转换为适用于随机访问的数据集,分割训练集和验证集,并创建相应的数据加载器。

训练和验证模型

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,
                                                        time.time() - epoch_start_time,
                                                        val_acc, val_loss))
    print('-' * 69)

进行训练和验证,在每个epoch结束时打印验证准确率和损失,并根据验证结果调整学习率。

测试模型

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

在测试集上评估模型性能并打印测试准确率。

结果

在这里插入图片描述

总结

这个案例实现了一个完整的文本分类流程,从数据预处理、模型定义到训练和评估。使用torchtext加载数据,并利用PyTorch构建和训练深度学习模型,实现了对AG_NEWS数据集的文本分类任务,达到了90.1%的精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/711564.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

湘潭大学信息与网络安全复习笔记2(总览)

前面的实验和作业反正已经结束了&#xff0c;现在就是集中火力把剩下的内容复习一遍&#xff0c;这一篇博客的内容主要是参考教学大纲和教学日历 文章目录 教学日历教学大纲 教学日历 总共 12 次课&#xff0c;第一次课是概述&#xff0c;第二次和第三次课是密码学基础&#x…

Android入门第68天-自动更新/升级怎么做(生产级实例)

开篇 今天我们进入第68讲。 在第60天左右其实很多同学们已经进入了APP应用开发了,因为60天内容足以让大家踏上正实的Android开发生涯。 随着开发的深入,我们发觉日常工作中无非就是一些组件的嵌套、合理应用。当代码迭代、功能迭代越来越频繁后我们面临着另一个问题,即:…

Vue3 生命周期函数及其与Vue2的对比总结

Vue3 继续保留了 Vue2 的生命周期钩子&#xff0c;但在 Composition API&#xff08;setup 函数&#xff09;中&#xff0c;它们被改为了一组导入函数。以下是它们的对比&#xff1a; Vue2 生命周期钩子和 Vue3 对应的生命周期函数&#xff1a; 在 Vue3 中&#xff0c;所有的…

git 快速将当前目录添加仓储

一、进入目录 git init git add . git commit -m "init" git remote add origin http://192.168.31.104/root/AutoBuildDemo.git 二、登录gitlab&#xff0c;创建项目AutoBuildDemo 最后执行&#xff1a; git push -u origin master

笔记 | 软件工程06-1:软件设计-软件设计基础

1 软件设计概述 1.1 为什么要软件设计 1.2 何为软件设计 何为软件系统的解决方案&#xff1f; 软件设计关注与软件需求的实现问题软件设计是需求分析和软件实现间的桥梁 1.3 软件设计的质量要求 1.4 软件设计的过程 1.4.1 软件体系结构设计 1.4.2 用户界面设计 1.4.3 软件详细…

C++ 18 之 函数的重载

c18函数的重载.cpp #include <iostream> #include <string.h> using namespace std;void fun4(int a) {cout << "int a: "<< a << endl; } void fun4(double a) {cout << "double a: " << a << endl; }v…

yolov10主要特点

在我们探讨YOLOv10之前&#xff0c;让我们回顾一下YOLO的发展历程。YOLO在实时目标检测领域一直是先驱&#xff0c;兼顾速度和准确性。从YOLOv1到YOLOv9&#xff0c;每个版本在架构、优化和数据增强方面都引入了显著的改进。然而&#xff0c;随着模型的发展&#xff0c;某些限制…

拦截器 之 用户登录判断

spring boot 拦截器的实现需要有两步&#xff1a; 自定义一个拦截器 package com.example.demo.common;import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpSession; import org.springfra…

Zadig vs. Jenkins 详细比较

01、Zadig vs. Jenkins&#xff1a;关于时代的选择 最近官方公众号发布了一篇名为 《是时候和 Jenkins 说再见了》的文章&#xff0c;引起了社区的广泛关注和讨论。作为曾经最被广泛使用的持续构建交付工具&#xff0c;Jenkins 的江湖地位似乎被挑战了。评论中有一条被高度点赞…

【字符串函数】

1.strlen的使⽤和模拟实现 size_t strlen ( const char * str ); 1.字符串以 \0 作为结束标志&#xff0c;strlen函数返回的是在字符串中 \0 前⾯出现的字符个数&#xff08;不包 含 \0 )。 2.参数指向的字符串必须要以 \0 结束。 3.注意函数的返回值为size_t&#xff0c;是⽆…

Denoising Prior Driven Deep Neural Network for Image Restoration

之所以能够检索到这篇论文是想看看该论文是如何利用多尺度相似性解决图像去噪问题&#xff0c;除了摘要和结论&#xff0c;论文中两次提到这个术语。next section是指section 4。然后整个section 4&#xff0c;根本没有提多尺度的事儿&#xff0c;更别说解决了。又看了一下The …

excel中按多列进行匹配并对数量进行累加

公司的生产计划是按订单下发&#xff0c;但不同订单的不同产品中可能有用到相同的配件&#xff0c;按单1对1时&#xff0c;对计算机十分友好&#xff0c;但对于在配件库检料的工人来说就比较麻烦&#xff0c;上百条产品里可能会有多条都是相同的产品&#xff0c;首先考虑的办法…

[面试题]Java【并发】

[面试题]Java【基础】[面试题]Java【虚拟机】[面试题]Java【并发】[面试题]Java【集合】[面试题]MySQL 因为 Java 并发涉及到的内容会非常多&#xff0c;本面试题可能很难覆盖到所有的知识点&#xff0c;所以推荐 《Java并发编程的艺术》 。 Java 线程 线程 通知 等待 线…

Linux C编译器从零开发一

基础程序汇编 test.c int main() {return 42; } 查看反汇编 cc -o test test.c objdump -d -M intel test 0000000000001129 <main>:1129: f3 0f 1e fa endbr64 112d: 55 push rbp112e: 48 89 e5 mov rbp,rsp1131: b…

抽象工厂模式(大话设计模式)C/C++版本

抽象工厂模式 C 参考&#xff1a;https://www.cnblogs.com/Galesaur-wcy/p/15927110.html #include <iostream> using namespace std;// 抽象产品Department ,定义具体产品的公共接口 class Department { public:virtual ~Department() default;virtual void Insert()…

Rust 实战丨并发构建倒排索引

引言 继上篇 Rust 实战丨倒排索引&#xff0c;本篇我们将参考《Rust 程序设计&#xff08;第二版&#xff09;》中并发编程篇章来实现高并发构建倒排索引。 本篇主要分为以下几个部分&#xff1a; 功能展示&#xff1a;展示我们最终实现的 2 个工具的效果&#xff08;构建索…

大话C语言:第24篇 预处理

1 C语言编译流程 C语言的编译流程包括&#xff1a; 预编译&#xff1a;将.c 中的头文件展开、宏展开&#xff0c;生成的文件是.i 文件。gcc指令&#xff1a;gcc -E file.c -o file.i 编译&#xff1a;将预处理之后的.i 文件生成 .s 汇编文件。gcc指令&#xff1a;gcc -S file…

逢3必过报数游戏-第13届蓝桥杯省赛Python真题精选

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第84讲。 逢3必过报数游戏&…

直流放大器

一&#xff0c;概念及存在问题 集成电路主要由半导体材料构成&#xff0c;其内部适合用二极管&#xff0c;三极管等类型的元器件制作&#xff0c;而不适用电容&#xff0c;电感和变压器&#xff0c;因此集成放大电路内部多个放大电路之间通常采用直接耦合。直接耦合电路除了可…

文生视频新王登场:Luma官宣免费、电影级大片生成,Sora?可灵?SD3.0?(内附网址)

✨点击这里✨&#xff1a;&#x1f680;原文链接&#xff1a;&#xff08;更好排版、视频播放、社群交流、最新AI开源项目、AI工具分享都在这个公众号&#xff01;&#xff09; 文生视频新王登场&#xff1a;Luma官宣免费、电影级大片生成&#xff0c;Sora&#xff1f;可灵&am…