NLP项目实战01--电影评论分类

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):
    def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):
        super(NetWork,self).__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim,output_dim)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self,x):
        embedded = self.embedding(x)
        embedded = embedded.permute(1,0,2) 
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        pooled = self.relu(pooled)
        pooled = self.dropout(pooled)
        
        output = self.fc(pooled)
        return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0
 train_acc = 0 
model.train()
for batch in train_iterator:
        optimizer.zero_grad()
        preds = model(batch.text).squeeze(1)
        loss = criterion(preds,batch.label)
        total_loss += loss.item()

        batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()
        train_acc += batch_acc
        
        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(train_iterator)
    train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()
    valid_loss = 0
    valid_acc = 0
    best_valid_acc = 0
    with torch.no_grad():
        for batch in valid_iterator:
            preds = model(batch.text).squeeze(1)
            loss = criterion(preds,batch.label)
            valid_loss += loss.item()
            batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())
            valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():
    output = saved_model(tensor_text).squeeze(1)
    prediction = torch.round(torch.sigmoid(output)).item()
    probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/233197.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android笔记(十七):PendingIntent简介

PendingIntent翻译成中文为“待定意图”,这个翻译很好地表示了它的涵义。PendingIntent描述了封装Intent意图以及该意图要执行的目标操作。PendingIntent封装Intent的目标行为的执行是必须满足一定条件,只有条件满足,才会触发意图的目标操作。…

HCIP —— BGP 基础 (上)

BGP --- 边界网关协议 (路径矢量协议) IGP --- 内部网关协议 --- OSPF RIP ISIS EGP --- 外部网关协议 --- EGP BGP AS --- 自治系统 由单一的组织或者机构独立维护的网络设备以及网络资源的集合。 因 网络范围太大 需 自治 。 为区分不同的AS&#…

C#,图算法——以邻接节点表示的图最短路径的迪杰斯特拉(Dijkstra)算法C#程序

1 文本格式 using System; using System.Text; using System.Linq; using System.Collections; using System.Collections.Generic; namespace Legalsoft.Truffer.Algorithm { public class Node // : IComparable<Node> { private int vertex, weigh…

CNN发展史脉络 概述图整理

CNN发展史脉络概述图整理&#xff0c;学习心得&#xff0c;供参考&#xff0c;错误请批评指正。 相关论文&#xff1a; LeNet&#xff1a;Handwritten Digit Recognition with a Back-Propagation Network&#xff1b; Gradient-Based Learning Applied to Document Recogniti…

【工具使用-JFlash】如何使用Jflash擦除和读取MCU内部指定扇区的数据

一&#xff0c;简介 在调试的过程中&#xff0c;特别是在调试向MCU内部flash写数据的时候&#xff0c;我们常常要擦除数据区的内容&#xff0c;而不想擦除程序取。那这种情况就需要擦除指定的扇区数据即可。本文介绍一种方法&#xff0c;可以擦除MCU内部Flash中指定扇区的数据…

【小沐学Python】Python实现TTS文本转语音(speech、pyttsx3、百度AI)

文章目录 1、简介2、Windows语音2.1 简介2.2 安装2.3 代码 3、pyttsx33.1 简介3.2 安装3.3 代码 4、ggts4.1 简介4.2 安装4.3 代码 5、pywin326、百度AI7、百度飞桨结语 1、简介 TTS(Text To Speech) 译为从文本到语音&#xff0c;TTS是人工智能AI的一个模组&#xff0c;是人机…

【linux】yum安装时: Couldn‘t resolve host name for XXXXX

yum 安装 sysstat 报错了&#xff1a; Kylin Linux Advanced Server 10 - Os 0.0 B/s | 0 B 00:00 Errors during downloading metadata for repository ks10-adv-os:- Curl error (6): Couldnt resolve host nam…

【摸鱼向】利用Arduino实现自动化切屏

曾几何时&#xff0c;每次背着老妈打游戏的时候都要紧张兮兮地听着爸妈是不是会破门而入&#xff0c;这严重影响了游戏体验&#xff0c;因此&#xff0c;最近想到了用Arduino加上红外传感器来实现自动监测的功能&#xff0c;当有人靠近门口的时候&#xff0c;电脑可以自动执行预…

【文件上传系列】No.2 秒传(原生前端 + Node 后端)

上一篇文章 【文件上传系列】No.1 大文件分片、进度图展示&#xff08;原生前端 Node 后端 & Koa&#xff09; 秒传效果展示 秒传思路 整理的思路是&#xff1a;根据文件的二进制内容生成 Hash 值&#xff0c;然后去服务器里找&#xff0c;如果找到了&#xff0c;说明已经…

pytorch:YOLOV1的pytorch实现

pytorch&#xff1a;YOLOV1的pytorch实现 注&#xff1a;本篇仅为学习记录、学习笔记&#xff0c;请谨慎参考&#xff0c;如果有错误请评论指出。 参考&#xff1a; 动手学习深度学习pytorch版——从零开始实现YOLOv1 目标检测模型YOLO-V1损失函数详解 3.1 YOLO系列理论合集(Y…

【IDEA】解决mac版IDEA,进行单元测试时控制台不能输入问题

我的IDEA版本 编辑VM配置 //增加如下配置&#xff0c;重启IDEA -Deditable.java.test.consoletrue测试效果

风力发电对讲 IP语音对讲终端IP安防一键呼叫对讲 医院对讲终端SV-6005网络音频终端

风力发电对讲 IP语音对讲终端IP安防一键呼叫对讲 医院对讲终端SV-6005网络音频终端 目 录 1、产品规格 2、接口使用 2.1、侧面接口功能 2.2、背面接口功能 2.3、面板接口功能 3、功能使用 1、产品规格 输入电源&#xff1a; 12V&#xff5e;24V的直流电源 网络接口&am…

初级数据结构(二)——链表

文中代码源文件已上传&#xff1a;数据结构源码 <-上一篇 初级数据结构&#xff08;一&#xff09;——顺序表 | NULL 下一篇-> 1、链表特征 与顺序表数据连续存放不同&#xff0c;链表中每个数据是分开存放的&#xff0c;而且存放的位置尤其零散&#…

C# WPF上位机开发(会员管理软件)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 好多同学都认为上位机只是纯软件开发&#xff0c;不涉及到硬件设备&#xff0c;比如听听音乐、看看电影、写写小的应用等等。如果是消费电子&#…

人工智能从 DeepMind 到 ChatGPT ,从 2012 - 2024

本心、输入输出、结果 文章目录 人工智能从 DeepMind 到 ChatGPT &#xff0c;从 2012 - 2024前言2010年&#xff1a;DeepMind诞生2012&#xff5e;2013年&#xff1a;谷歌重视AI发展&#xff0c;“拿下”Hinton2013&#xff5e;2014年&#xff1a;谷歌收购DeepMind2013年&…

调用别人提供的接口无法通过try catch捕获异常(C#),见鬼了

前几天做CA签名这个需求时发现一个很诡异的事情&#xff0c;CA签名调用的接口是由另外一个开发部门的同事(比较难沟通的那种人)封装并提供到我们这边的。我们这边只需要把数据准备好&#xff0c;然后调他封装的接口即可完成签名操作。但在测试过程中&#xff0c;发现他提供的接…

2024年网络安全竞赛-数字取证调查attack817

​ 数字取证调查 (一)拓扑图 服务器场景:FTPServer20221010(关闭链接) 服务器场景操作系统:未知 FTP用户名:attack817密码:attack817 分析attack.pcapng数据包文件,通过分析数据包attack.pcapng找出恶意用户第一次访问HTTP服务的数据包是第几号,将该号数作为Flag值…

力扣刷题总结 字符串(2)【KMP】

&#x1f525;博客主页&#xff1a; A_SHOWY&#x1f3a5;系列专栏&#xff1a;力扣刷题总结录 数据结构 云计算 数字图像处理 28.找出字符串中第一个匹配项的下标mid经典KMP4593重复的子字符串mid可以使用滑动窗口或者KMP KMP章节难度较大&#xff0c;需要深入理解其中…

Uncaught SyntaxError: Unexpected end of input (at manage.html:1:21) 的一个解

关于Uncaught SyntaxError: Unexpected end of input (at manage.html:1:21)的一个解 问题复现 <button onclick"deleteItem(${order.id},hire,"Orders")" >delete</button>报错 原因 函数参数的双引号和外面的双引号混淆了&#xff0c;改成…

数据可视化|jupyter notebook运行pyecharts,无法正常显示“可视化图形”,怎么解决?

前言 本文是该专栏的第39篇,后面会持续分享python数据分析的干货知识,记得关注。 相信有些同学在本地使用jupyter notebook运行pyecharts的时候,在代码没有任何异常的情况下,无论是html还是notebook区域,都无法显示“可视化图形”,界面区域只有空白一片。遇到这种情况,…