文本分类TextRCNN模型(pytorch实现)

文本分类TextRCNN模型

        • RCNN简介
        • TextRCNN模型介绍
        • TextRCNN代码(文本10分类)

RCNN简介

从之前的文章中介绍过RNN的优点是能够捕捉到序列的时序信息,这可能有利于捕获长文本的语义。但是RNN对于文本序列后面的单词获取到的语义会更多,故RNN是一个有偏倚的模型。而CNN能够捕获到全局的信息,它可以很好地确定文本中带有最大池化层的识别性短语。因此,与递归或循环神经网络相比,CNN可以更好地捕捉文本的语义。于是为了综合两者的优点,提出了TextRCNN模型。主要用于解决文本分类问题。

首先,我们应用一个双向的循环结构(LSTMGRU),与传统的基于窗口的神经网络相比,它可以大大减少噪声,从而最大程度地捕捉上下文信息。此外,该模型在学习文本表示时可以保留更大范围的词序。其次,我们使用了一个可以自动判断哪些特性在文本分类中扮演关键角色的池化层(max-pooling),以捕获文本中的关键组件。我们的模型结合了RNN的结构和最大池化层,利用了循环神经模型和卷积神经模型的优点。此外,我们的模型显示了O(n)的时间复杂度,它与文本长度的长度是线性相关的。

TextRCNN模型介绍

如下图是作者提出的模型框架,输入是一个文本sentence,它可以看成是由一系列单词组成的(w1,w2…)。输出是一个概率分布,最大的那个位置对应文章属于的类别K 。

在这里插入图片描述

RCNN整体的模型构建流程如下:

  1. 利用Bi-LSTM获得上下文的信息,类似于语言模型。
  2. Bi-LSTM获得的隐层输出和词向量拼接[fwOutput, wordEmbedding, bwOutput]。
  3. 将拼接后的向量非线性映射到低维。
  4. 向量中的每一个位置的值都取所有时序上的最大值,得到最终的特征向量,该过程类似于max-pool。
  5. softmax分类。

RCNN整体的运算过程为:

  1. 首先将词进行词向量编码,即第一栏中间的 word-embedding
  2. 接着将词向量输入到双向的 RNN(这里的 RNN cell 可以使用 lstm 或者 gru 或者最简单的)双向的RNN能够更好的捕捉上下文关系
  3. 将1和2得到的结果拼接在一起也就是模型构建流程中的2,然后输入到激活函数中(tanh或者sigmod)得到图中所示的 y ( 2 ) y^{(2)} y(2)
  4. 然后对于相同位置的值做最大池化,得到图中的 y ( 3 ) y^{(3)} y(3)
  5. softmax分类或者fc映射,得到图中最终结果。
TextRCNN代码(文本10分类)
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F


class Config(object):
    """配置参数"""

    def __init__(self):
        self.model_name = 'TextRCNN'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 设备
        self.dropout = 1.0  # 随机失活
        self.num_classes = 10  # 类别数
        self.n_vocab = 10000  # 词表大小,在运行时赋值
        self.num_epochs = 10  # epoch数
        self.batch_size = 128  # mini-batch大小
        self.pad_size = 32  # 每句话处理成的长度(短填长切)
        self.learning_rate = 1e-3  # 学习率
        self.embed = 300  # 字向量维度, 若使用了预训练词向量,则维度统一
        self.hidden_size = 256  # lstm隐藏层
        self.num_layers = 1  # lstm层数

'''Recurrent Convolutional Neural Networks for Text Classification'''

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.maxpool = nn.MaxPool1d(config.pad_size)
        self.fc = nn.Linear(config.hidden_size * 2 + config.embed, config.num_classes)

    def forward(self, x):
        print(x)
        x, _ = x  # torch.Size([128, 32])
        embed = self.embedding(x)  # [batch_size, seq_len, embeding]=[128,32,300]
        out, _ = self.lstm(embed)  # [batch_size, seq_len, hiden_size*2]=[128,32,512]
        out = torch.cat((embed, out), 2)  # [batch_size, seq_len, hiden_size*2]=[128,32,512+300]
        out = F.relu(out)  # ReLU 激活函数将所有负数变换为零,而保持正数不变 [batch_size, seq_len, hiden_size*2]=[128,32,512+300]
        out = out.permute(0, 2, 1)  # 将张量 out 的第二个和第三个维度进行交换。[128,812,32]
        out = self.maxpool(out).squeeze()  # config.pad_size=32,则maxpool之后,维度为[128,812,1],经过压缩之后变为[128,812]
        out = self.fc(out)  # [128,812]====>[128,10]
        return out

config=Config()
model=Model(config)
print(model)

输出:

Model(
  (embedding): Embedding(10000, 300, padding_idx=9999)
  (lstm): LSTM(300, 256, batch_first=True, dropout=1.0, bidirectional=True)
  (maxpool): MaxPool1d(kernel_size=32, stride=32, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=812, out_features=10, bias=True)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/631011.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python 全栈体系【四阶】(四十五)

第五章 深度学习 十、生成对抗网络(GAN) 1. 图像生成技术概述 1.1 什么是图像生成技术 图像生成技术是指利用机器学习或深度学习等人工智能技术,通过训练模型来生成逼真的图像。这些技术可以根据给定的输入,生成与真实图像相似…

线性系统(二)

线性系统(二) 1.直观理解线性方程组结构2. 不同解的结论3. 更一般的高斯-约旦消元法4.齐次线性方程组 链接: 线性系统(一) 1.直观理解线性方程组结构 长这样,方程就有解,即相交坐标。 长这样,…

《天空之城》观后感

曾经很长一段时间都着迷于《天空之城》这段旋律,一遍一遍不厌其烦地听,静谧而温馨、豪迈却苍凉,各种复杂的感受随着起伏的音符流淌进心里。多年之后才知道这首曲子出自宫崎骏的同名动画电影。说来也有意思,似乎大多数人是通过电影…

如何配置静态住宅IP?

静态住宅IP是指专为家庭网络环境设计的固定IP地址,通常由互联网服务提供商(ISP)为家庭用户提供。这种IP地址在其生命周期中保持不变,除非由于某些外部因素(如ISP更改策略)或用户请求更改。相比于动态IP地址…

css设置滚动条的样式

/* 滚动条样式 *//* 定义滚动条整体的宽度和轨道的背景颜色 */::-webkit-scrollbar {width: 10px;/* 对于垂直滚动条的宽度 */height: 10px;/* 对于水平滚动条的高度,可选 */}/* 定义滚动条轨道的样式 */::-webkit-scrollbar-track {background-color: rgba(0, 0, 0…

全新多语言海外抢单刷单系统源码 订单自动匹配 支持分组 代理后台

安装教程 测试环境:Nginx PHP7.0 MySQL5.6 config/database 修改数据库 设置运行目录public 伪静态thinkphp 后台登录地址:/admin 账号admin 密码admin123 前端出现报错 删除runtime文件夹得缓存文件即可 源码免费下载地址抄笔记 (chaobiji.cn)

机器人非线性系统反馈线性化——Brunovsky标准型

Brunovsky Canonical Form 机器人非线性系统的反馈线性化,特别是涉及到Brunovsky标准型,是现代控制理论中的一个重要话题。反馈线性化是一种非线性控制设计方法,其核心思想是通过设计反馈控制器,将非线性系统转化为线性系统。这种…

windows驱动开发-PCI讨论(一)

前面描述中断的时候,我们曾经多次体积PCI,甚至提供了一些PCI的相关知识,但是整个PCI是一个很大的体系,专门记录这个体系超出了这个系列的范畴,有兴趣的可以到PCI官网了解详细的情况。 但是还是会花费一些时间讨论PCI技…

Python 全栈体系【四阶】(四十四)

第五章 深度学习 九、图像分割 3. 常用模型 3.4 DeepLab 系列 3.4.3 DeepLab v3(2017) 在DeepLab v3中,主要进行了以下改进: 使用更深的网络结构,以及串联不同膨胀率的空洞卷积,来获取更多的上下文信…

Spark RDD案例:统计网站每月访问量

这个项目利用Spark技术,通过统计网站访问记录中的日期信息,实现了对每月访问量的统计和排序。通过分析数据,我们可以了解到不同月份的网站访问情况,为进一步优化网站内容和推广策略提供数据支持。 使用Spark统计网站每月访问量 …

平芯微PW4056HH中文规格书

概述 PW4056HH 是一款完整的采用恒定电流/恒定电压的高压、大电流、单节锂离子电池线性充电 IC。充电电流可达 1A。输入 MAX 低工作电压 3.75V,降低充电功耗,提高效率。 PW4056HH 采用了内部 PMOS 架构,加上防反充电路,不需要外部…

白酒:低酒精度白酒的消费特点与市场前景

低酒精度白酒的消费特点与市场前景是酒类市场的一个重要话题。随着品质意识的提高和消费者口味的多样化,低酒精度白酒逐渐受到越来越多的关注。云仓酒庄豪迈白酒作为白酒的品牌之一,其消费特点和市场前景值得深入探讨。 首先,从消费特点来看…

linux Docker在线/离线服务安装并支持centos7和centos8系统

注:以下内容都是经过测试;能在生产环境使用. 一、centos7版本的docker在线安装 1:运行以下命令,下载docker-ce的yum源。 sudo wget -O /etc/yum.repos.d/docker-ce.repo https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo…

【数据结构】时间、空间复杂度实例分析

跌倒了,就重新站起来,继续向前走;傻坐在地上是没用的。💓💓💓 目录 •✨说在前面 🍋知识点一:算法的效率 • 🌰1.斐波那契数列的第n项 • 🌰2.算法的复杂度…

类和对象、包等知识总结Java

类 类的概念:类是用来对一个实体(对象)进行描述的,主要描述该对象的属性,功能等。 类的定义和实例化 定义 定义类需要用到class关键字 (大驼峰定义)for example:class Dog... 初步了解一下…

Hadoop3:客户端向HDFS写数据流的流程讲解(较枯燥)

一、场景描述 我们登陆HDFS的web端,上传一个大文件。 二、流程图 三、讲解 流程1(Client与NameNode交互) 1、HDFS client创建DistributedFileSystem,通过dfs与NameNode进行2次(一来一回4次)对话&#x…

【UE5.1 角色练习】02-添加慢走、快速跑、蹲伏功能

目录 前言 步骤 一、慢走 二、快速跑 三、蹲伏 前言 在上一篇文章基础上(【UE5.1 角色练习】01-使用小白人蓝图控制商城角色移动)继续实现角色的慢走、快速跑以及蹲伏功能 步骤 一、慢走 1. 打开项目设置,添加一个操作映射&#x…

uniapp地图电子围栏(多边形)绘制和编辑

uniapp地图电子围栏(多边形)绘制和编辑 背景实现思路代码实现注意事项尾巴 背景 最近项目中需要在地图上进行电子围栏的绘制和编辑,这里将实现的思路给大家分享下。由于uniapp官方提供的map组件功能不全,还有在APP端(…

Java面试题:Spring框架除了IOC和AOP,还有哪些好玩的设计模式?

Spring是一个基于Java的企业级应用程序开发框架,它使用了多种设计模式来实现其各种特性和功能。本文将介绍一些在Spring中使用的常见设计模式以及相应的代码示例和说明。 单例模式 单例模式是Spring中最常用的设计模式之一。在ApplicationContext中,Bean…

5.8 5.9_C语言(初阶)

1、函数 1.1 函数是什么 数学中就见过函数,例,f(x)2*x1,如果给一个不同的数x,就能得到一个不同的f(x),这就是数学中的函数,其实在C语言中跟这非常相似。 维基百科中对函数的定义:子程序 在计算机科学中&#xff0c…