NLP项目实战01之电影评论分类

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):
    def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):
        super(NetWork,self).__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim,output_dim)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self,x):
        embedded = self.embedding(x)
        embedded = embedded.permute(1,0,2) 
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        pooled = self.relu(pooled)
        pooled = self.dropout(pooled)
        
        output = self.fc(pooled)
        return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0
 train_acc = 0 
model.train()
for batch in train_iterator:
        optimizer.zero_grad()
        preds = model(batch.text).squeeze(1)
        loss = criterion(preds,batch.label)
        total_loss += loss.item()

        batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()
        train_acc += batch_acc
        
        loss.backward()
        optimizer.step()

    average_loss = total_loss / len(train_iterator)
    train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()
    valid_loss = 0
    valid_acc = 0
    best_valid_acc = 0
    with torch.no_grad():
        for batch in valid_iterator:
            preds = model(batch.text).squeeze(1)
            loss = criterion(preds,batch.label)
            valid_loss += loss.item()
            batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())
            valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():
    output = saved_model(tensor_text).squeeze(1)
    prediction = torch.round(torch.sigmoid(output)).item()
    probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/230504.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】一维数组的 K-Means 聚类算法理解

刚看了这个算法,理解如下,放在这里,备忘,如有错误的地方,请指出,谢谢 需要做聚类的数组我们称之为【源数组】 需要一个分组个数K变量来标记需要分多少个组,这个数组我们称之为【聚类中心数组】…

Kafka快速实战以及基本原理详解

文章目录 一、Kafka介绍为什么要用Kafka 二、Kafka快速上手实验环境单机服务体验 三、理解Kakfa的消息传递机制四、Kafka集群服务五、理解服务端的Topic、Partition和Broker七、Kafka集群的整体结构八、Kraft集群Kraft集群简介配置Kraft集群 一、Kafka介绍 ChatGPT对于Apache …

人体关键点检测1:人体姿势估计数据集

人体关键点检测1:人体姿势估计数据集 目录 人体关键点检测1:人体姿势估计数据集 1.人体姿态估计 2.人体姿势估计数据集 (1)COCO数据集 (2)MPII数据集 (3)Human3.6M &#xf…

CENTOS 7 添加黑名单禁止IP访问服务器

一、通过 firewall 添加单个黑名单 只需要把ip添加到 /etc/hosts.deny 文件即可,格式 sshd:$IP:deny vim /etc/hosts.deny# 禁止访问sshd:*.*.*.*:deny# 允许的访问sshd:.*.*.*:allowsshd:.*.*.*:allow 二、多次失败登录即封掉IP,防止暴力破解的脚本…

【解决办法】Pycharm中新添加或者导入项目文件名红色!

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、问题描述二、问题原因三、解决办法 一、问题描述 Pycharm的代码中添加新的文件夹,发现文件夹下的文件名是红色的,如下图: …

视频推拉流直播点播EasyDSS平台点播文件加密存储的实现方法

视频推拉流直播点播系统EasyDSS平台,可提供流畅的视频直播、点播、视频推拉流、转码、管理、分发、录像、检索、时移回看等功能,可兼容多操作系统,还能支持CDN转推,具备较强的可拓展性与灵活性,在直播点播领域具有广泛…

设计模式基础——概述(1/2)

目录 一、设计模式的定义 二、设计模式的三大类别 三、设计模式的原则 四、主要设计模式目录 4.1 创建型模式(Creational Patterns) 4.2 结构型模式(Structural Patterns) 4.3 行为型模式(Behavioral Patterns&…

5.10 Windows驱动开发:摘除InlineHook内核钩子

在笔者上一篇文章《内核层InlineHook挂钩函数》中介绍了通过替换函数头部代码的方式实现Hook挂钩,对于ARK工具来说实现扫描与摘除InlineHook钩子也是最基本的功能,此类功能的实现一般可在应用层进行,而驱动层只需要保留一个读写字节的函数即可…

Numpy数组的重塑,转置与切片 (第6讲)

Numpy数组的重塑,转置与切片 (第6讲)         🍹博主 侯小啾 感谢您的支持与信赖。☀️ 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ�…

如何将腾讯混元大模型AI接入自己的项目里(中国版本ChatGPT)

如何将腾讯混元大模型AI接入自己的项目里 一、腾讯混元大模型API二、使用步骤1、接口2、请求参数3、请求参数示例4、接口 返回示例 三、 如何获取appKey和uid1、申请appKey:2、获取appKey和uid 四、重要说明 一、腾讯混元大模型API 基于腾讯混元大模型AI的智能文本对话AI机器人…

uniapp 数组添加不重复元素

一、效果图 二、代码 //点击事件rightBtn(sub, index) {console.log(sub, index)//uniapp 数组添加不重复元素if (this.selectList.includes(sub.type)) {this.selectList this.selectList.filter((item) > {return item ! sub.type;});} else {this.selectList.push(sub.t…

长城之上的无人机:文化遗产的守护者

长城之上的无人机:文化遗产的守护者 在八达岭长城景区,两架无人机分别部署在了长城的南、北楼两点。根据当前的保护焦点和需求,制定了5条无人机综合巡查航线,以确保长城景区的所有开放区域都能得到有效监管。每天,无人…

【ITK库学习】使用itk库进行图像滤波ImageFilter:二阶微分

目录 1、itkRecursiveGaussianImageFliter 递归高斯滤波器2、itkLapacianRecursiveGaussianImageFiter 拉普拉斯高斯滤波器 1、itkRecursiveGaussianImageFliter 递归高斯滤波器 该类用于计算具有高斯核近似值的 IIR 卷积的基类。 该类是递归滤波器的基类,它与高斯…

MyBatis `saveBatch` 性能调优详解

文章目录 1. 引言2. MyBatis saveBatch 简介3. 常见性能问题3.1 SQL 语句拼接3.2 参数传递3.3 数据库连接数 4. MyBatis saveBatch 性能调优4.1 使用批量插入语句4.1.1 代码示例 4.2 使用MyBatis的foreach标签4.2.1 代码示例 4.3 使用VALUES构造器4.3.1 代码示例 4.4 调整批量大…

linux下部署frp客户端服务端-内网穿透

简介 部署在公司内部局域网虚拟机上的服务需要在外网能够访问到,这不就是内网穿透的需求吗,之前通过路由器实现过,现在公司这块路由器不具备这个功能了,目前市面上一些主流的内网穿透工具有:Ngrok,Natapp&…

【WPF.NET开发】WPF中的对话框

目录 1、消息框 2、通用对话框 3、自定义对话框 实现对话框 4、打开对话框的 UI 元素 4.1 菜单项 4.2 按钮 5、返回结果 5.1 模式对话框 5.2 处理响应 5.3 非模式对话框 Windows Presentation Foundation (WPF) 为你提供了自行设计对话框的方法。 对话框是窗口&…

HarmonyOS(鸿蒙操作系统)与Android系统 各自特点 架构对比 各自优势

综合对比 HarmonyOS(鸿蒙操作系统)是由华为开发的操作系统,旨在跨多种设备和平台使用。HarmonyOS的架构与谷歌开发的广泛使用的Android操作系统有显著不同。以下是两者之间的一些主要比较点: 设计理念和使用案例: Harm…

屏蔽百度首页推荐和热搜的实战方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

C++STL的string(超详解)

文章目录 前言C语言的字符串 stringstring类的常用接口string类的常见构造string (const string& str);string (const string& str, size_t pos, size_t len npos); capacitysize和lengthreserveresizeresize可以删除数据 modify尾插插入字符插入字符串 inserterasere…

二分查找|前缀和|滑动窗口|2302:统计得分小于 K 的子数组数目

作者推荐 贪心算法LeetCode2071:你可以安排的最多任务数目 本文涉及的基础知识点 二分查找算法合集 题目 一个数组的 分数 定义为数组之和 乘以 数组的长度。 比方说,[1, 2, 3, 4, 5] 的分数为 (1 2 3 4 5) * 5 75 。 给你一个正整数数组 nums 和一个整数…