【NLP练习】中文文本分类-Pytorch实现

中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、准备工作

1. 任务说明

本次使用Pytorch实现中文文本分类。主要代码与文本分类代码基本一致,不同的是本次任务使用了本地的中文数据,数据示例如下:
在这里插入图片描述

2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")   #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
import pandas as pd
 #加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
train_data.head()

输出:
在这里插入图片描述

#构造数据集迭代器
def coustom_data_iter(texts,labels):
 for x,y in zip(texts,labels):
  yield x,y
 
train_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

二、数据预处理

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                 specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
label_name = list(set(train_data[1].values[:]))
print(label_name)

输出:

['FilmTele-Play', 'Alarm-Update', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'Travel-Query', 'Music-Play', 'Video-Play', 'HomeAppliance-Control', 'Calendar-Query', 'TVProgram-Play', 'Other']
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
7

lambda表达式的语法为:lambda arguments: expression
其中arguments是函数的参数,可以有多个参数,用逗号分隔。expression是一个表达式,它定义了函数的返回值。

  • text_pipeline函数: 将原始文本数据转换为整数列表,使用了之前构建的vocab词表和tokenizer分词器函数。具体步骤:
  1. 接受一个字符串x作为输入
  2. 使用tokenizer将其分词
  3. 将每个词在vocab词表中的索引放入一个列表返回
  • label_pipeline函数: 将原始标签数据转换为整数,它接受一个字符串x作为输入,并使用 label_index.index(x) 方法获取x在label_name列表中的索引作为输出。

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [],[],[0]         
    for(_text, _label) in batch:
        #标签列表
        label_list.append(label_pipeline(_label))
        #文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        #偏移量
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list,dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和
    return text_list.to(device), label_list.to(device), offsets.to(device)

#数据加载器
dataloader = DataLoader(
    train_iter,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_batch
)

三、模型构建

1. 搭建模型

#搭建模型
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel,self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小
                                        embed_dim,        # 嵌入的维度
                                        sparse=False)     #
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

2. 初始化模型

#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

#定义训练与评估函数
import time

def train(dataloader):
    model.train()          #切换为训练模式
    total_acc, train_loss, total_count = 0,0,0
    log_interval = 50
    start_time = time.time()
    for idx, (text,label, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                             #grad属性归零
        loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真
        loss.backward()                                   #反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪
        optimizer.step()                                  #每一步自动更新
        
        #记录acc与loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(
                epoch,
                idx,
                len(dataloader),
                total_acc/total_count,
                train_loss/total_count))
            total_acc,train_loss,total_count = 0,0,0
            staet_time = time.time()

def evaluate(dataloader):
    model.eval()      #切换为测试模式
    total_acc,train_loss,total_count = 0,0,0
    with torch.no_grad():
        for idx,(text,label,offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label,label)   #计算loss值
            #记录测试数据
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    
    return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset

# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training

#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None

# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,
                                         [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

                                           
train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    #获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(
        epoch,
        time.time() - epoch_start_time,
        val_acc,
        val_loss))
    print('-' * 69)

输出:

['还有双鸭山到淮阴的汽车票吗13号的' '从这里怎么回家' '随便播放一首专辑阁楼里的佛里的歌' ...
 '黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席' '百事盖世群星星光演唱会有谁' '下周一视频会议的闹钟帮我开开']
|epoch1|  50/ 152 batches|train_acc0.953 train_loss0.00282
|epoch1| 100/ 152 batches|train_acc0.953 train_loss0.00271
|epoch1| 150/ 152 batches|train_acc0.952 train_loss0.00292
---------------------------------------------------------------------
| epoch 1 | time:5.50s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.961 train_loss0.00231
|epoch2| 100/ 152 batches|train_acc0.967 train_loss0.00204
|epoch2| 150/ 152 batches|train_acc0.963 train_loss0.00228
---------------------------------------------------------------------
| epoch 2 | time:5.06s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.975 train_loss0.00173
|epoch3| 100/ 152 batches|train_acc0.973 train_loss0.00177
|epoch3| 150/ 152 batches|train_acc0.972 train_loss0.00166
---------------------------------------------------------------------
| epoch 3 | time:5.07s | valid_acc 0.948 valid_loss 0.003
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.984 train_loss0.00137
|epoch4| 100/ 152 batches|train_acc0.987 train_loss0.00123
|epoch4| 150/ 152 batches|train_acc0.983 train_loss0.00119
---------------------------------------------------------------------
| epoch 4 | time:5.07s | valid_acc 0.950 valid_loss 0.003
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.985 train_loss0.00125
|epoch5| 100/ 152 batches|train_acc0.987 train_loss0.00119
|epoch5| 150/ 152 batches|train_acc0.986 train_loss0.00120
---------------------------------------------------------------------
| epoch 5 | time:5.03s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.985 train_loss0.00118
|epoch6| 100/ 152 batches|train_acc0.989 train_loss0.00114
|epoch6| 150/ 152 batches|train_acc0.985 train_loss0.00120
---------------------------------------------------------------------
| epoch 6 | time:5.40s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.984 train_loss0.00119
|epoch7| 100/ 152 batches|train_acc0.986 train_loss0.00119
|epoch7| 150/ 152 batches|train_acc0.989 train_loss0.00112
---------------------------------------------------------------------
| epoch 7 | time:5.71s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.985 train_loss0.00115
|epoch8| 100/ 152 batches|train_acc0.986 train_loss0.00128
|epoch8| 150/ 152 batches|train_acc0.989 train_loss0.00107
---------------------------------------------------------------------
| epoch 8 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.988 train_loss0.00114
|epoch9| 100/ 152 batches|train_acc0.983 train_loss0.00127
|epoch9| 150/ 152 batches|train_acc0.989 train_loss0.00109
---------------------------------------------------------------------
| epoch 9 | time:5.28s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.986 train_loss0.00115
|epoch10| 100/ 152 batches|train_acc0.987 train_loss0.00117
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00119
---------------------------------------------------------------------
| epoch 10 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

输出:

模型准确率为:0.9492

2. 测试指定数据

#测试指定的数据
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")

print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

该文本的类别是: Travel-Query

五、总结

训练神经网络时,可使用梯度裁剪的方法来防止梯度爆炸,使得模型训练更加稳定

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/526505.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【学习心得】Python中的queue模块使用

一、Queue模块的知识点思维导图 二、Queue模块常用函数介绍 queue模块是内置的&#xff0c;不需要安装直接导入就可以了。 &#xff08;1&#xff09;创建一个Queue对象 import queue# 创建一个队列实例 q queue.Queue(maxsize20) # 可选参数&#xff0c;默认为无限大&am…

libVLC 提取视频帧使用QWidget渲染

在前面的文章中&#xff0c;我们使用libvlc_media_player_set_hwnd设置了视频的显示的窗口。 libvlc_media_player_set_hwnd(vlc_mediaPlayer, (void *)ui.widgetShow->winId()); 如果我们想要提取每一帧数据&#xff0c;将数据渲染到QWidget上&#xff0c;该如何操作呢&a…

西圣、漫步者、万魔开放式耳机怎么样?无广真实测评对比推荐

开放式耳机因其独特的音质体验和佩戴舒适度&#xff0c;受到了越来越多消费者的青睐。西圣、漫步者、万魔作为国内知名的耳机品牌&#xff0c;各自都推出了自家的开放式耳机产品&#xff0c;那么&#xff0c;这三款耳机究竟如何呢&#xff1f;身为开放式耳机党的我&#xff0c;…

小白如何挖到自己的第一个漏洞

目录 挖洞公式 个人介绍 我的技术与生活——小站首页 | Hexohttps://xiaoyunxi.wiki/ 漏洞介绍 漏洞详情 如何进行信息收集(最快捷) 方法1(Google Hacking) 0x01 常见Google语法 0x02 不是经常用的关键字 0x03 特殊符号使用 0x04 布尔操作 0x05注意事项 0x06 使用…

【C++ STL算法】sort 排序

文章目录 【 1. 基本原理 】【 2. sort 的应用 】实例 - sort 函数实现 升序排序和降序排序 函数名用法sort (first, last)基于 快速排序&#xff0c;对容器或普通数组中 [ first, last ) 范围内的元素进行排序&#xff0c;默认进行升序排序&#xff08;从小到大&#xff09;。…

在linux上面安装nexus私有maven库

问题 需要在EC2机器上面安装私有maven库。 步骤 更新OS sudo yum update -yJDK11 # 安装JDK11 sudo yum install java-11-amazon-corretto # 查询jdk11位置 sudo alternatives --config java内容如下&#xff1a; There are 2 programs which provide java.Selection …

用动态IP采集数据总是掉线是为什么?该怎么解决?

动态IP可以说是做爬虫、采集数据、搜集热门商品信息中必备的代理工具&#xff0c;但在爬虫的使用中&#xff0c;总是会遇到动态IP掉线的情况&#xff0c;从而影响使用效率&#xff0c;本文将探讨动态IP代理掉线的几种常见原因&#xff0c;并提供解决方法&#xff0c;以帮助大家…

libVLC 提取视频帧使用OpenGL渲染

在上一节中&#xff0c;我们讲解了如何使用QWidget渲染每一帧视频数据。 由于我们不停的生成的是QImage对象&#xff0c;因此对 CPU 负荷较高。其实在绘制这块我们可以使用 OpenGL去绘制&#xff0c;利用 GPU 减轻 CPU 计算负荷&#xff0c;本节讲解使用OpenGL来绘制每一帧视频…

ctfshow web入门 php特性 web108--web115

web108 ereg函数相当于而preg_match()函数 ereg函数的漏洞&#xff1a;00截断。%00截断及遇到%00则默认为字符串的结束 strrev函数就是把字符串倒过来 就是说intval处理倒过来的传参c0x36d&#xff08;877&#xff09;?ca%00778 web109 异常处理类 通过异常处理类Excepti…

国外动态住宅IP代理的五大优势

在当前的数字时代&#xff0c;随着网络监控和地理限制的日益增强&#xff0c;国外动态住宅IP代理成为了解决这些问题的关键工具。它不仅能够保护用户的个人隐私&#xff0c;提供高度的匿名性&#xff0c;还能轻松突破访问限制&#xff0c;让全球的内容触手可及。对于数据科学家…

Vue+OpenLayers7入门到实战:OpenLayers加载WFS服务的要素资源数据并叠加到地图上

返回《Vue+OpenLayers7》专栏目录:Vue+OpenLayers7入门到实战 前言 本章讲解如何使用OpenLayers6加载WFS服务的要素资源数据并叠加到地图上的功能。 WFS规范介绍 WFS是基于地理要素级别的数据共享和数据操作,WFS规范定义了若干基于地理要素(Feature)级别的数据操作接口…

PL2303HXA自2012已停产,请联系供货商的解决方法

一、概述 Win10不支持PL2303HXA usb串口&#xff0c;当管理员做配置的时候发现无法查看端口信息&#xff0c;设备管理器提示&#xff1a;PL2303 HXA自2012已停产&#xff0c;请联系供货商。PL2303 是Prolific 公司生产的一种高度集成的RS232-USB接口转换器&#xff0c;可提供一…

(学习日记)2024.04.10:UCOSIII第三十八节:事件实验

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

微信多账号如何聚合聊天?

看这篇文章的你是否有以下烦恼&#xff1a; 1.微信账号太多&#xff0c;每次都要拿很多手机&#xff0c;管理起来很乱 2.微信号多&#xff0c;需要很多员工来管理&#xff0c;人工费用高 3.多个微信打开后会造成微信登陆界面过多&#xff0c;切换操作十分不方便 4.当微信多…

创建型模式--4.抽象工厂模式【弗兰奇一家】

1. 奔向大海 在海贼世界中&#xff0c;位于水之都的弗兰奇一家是由铁人弗兰奇所领导的以拆船为职业的家族&#xff0c;当然了他们的逆向工程做的也很好&#xff0c;会拆船必然会造船。船是海贼们出海所必备的海上交通工具&#xff0c;它由很多的零件组成&#xff0c;从宏观上看…

【案例·增】拼接字符串,增加字符型数据记录

问题描述&#xff1a; MySQL中的数据库表存在字符型(String)字段&#xff0c;要为其拼接信息以达成数据新增&#xff0c;可以使用 SQL 中的CONCAT()、CONCAT_WS()函数来处理 案例&#xff1a; #拼接字符串 SELECT CONCAT(HELLO, World);#举例&#xff1a;student学生表。sno…

为什么多数游戏服务端是用 C++ 来写呢,是历史原因还是性能方面的考虑?

游戏服务端开发语言的选择往往受到多方面因素的影响&#xff0c;其中C被广泛应用&#xff0c;这一现象的背后既包含了历史演进的原因&#xff0c;也凸显出性能至上的技术考量。 历史沿革 自上世纪80年代起&#xff0c;C语言便以其对C语言的兼容性、面向对象的特性以及对系统资…

搞懂LLM中的Token,看这一篇就够了

一、Token是什么&#xff1f; 1.1 Token一定表示一个汉字么&#xff1f; 下面这个例子中我让ChatGPT按照Token的划分粒度将“我喜欢苹果”进行倒序输出。 这个例子说明&#xff1a;Token既可能是一个单词&#xff0c;例如“苹果”&#xff0c;也可能是一个汉字&#xff0c;例…

新版本v24.1发布,私有网盘也能创建互联网对外分享链接了

24年龙年到来之际&#xff0c;我们发布了v24.1版本&#xff0c;其中商业版优先响应了客户期待已久的外网分享特性&#xff0c;现在客户既可以创建内网登录账号可用的分享链接&#xff0c;也可以安全地创建互联网可访问的公共链接&#xff0c;产品应用场景得到了延伸。而社区版则…

T-GATE:交叉注意力使文本到图像扩散模型中的推理变得麻烦

今天给大家带来最近一项非常有意思的研究通过计算出&#xff0c;交叉注意力的输出会趋向于一个固定点。 然后在保真度提升阶段忽略文本条件不仅可以减少计算的复杂性。并提出了TGATE这种简单而无需训练的方法&#xff0c;用于高效地生成图像。 TGATE的做法是&#xff0c;一旦交…