文本分类-RNN-LSTM

1.前言

        本节介绍RNN和LSTM,并采用它们在电影评论数据集上实现文本分类,会涉及以下几个知识点。

        1. 词表构建:包括数据清洗,词频统计,词频截断,词表构建。

        2. 预训练词向量应用:下载并加载Glove的预训练embedding进行训练,主要是如何把词向量放到nn.embedding层中的权重。

        3. RNN及LSTM构建:涉及nn.RNN和nn.LSTM的使用。

2.任务介绍

        本节采用的数据集是斯坦福大学的大型电影评论数据集(large movie review dataset) https://ai.stanford.edu/~amaas/data/sentiment/

        包含25000个训练样本,25000个测试样本,下载解压后得到aclImdb文件夹,aclImdb下有train和test,neg和pos下分别 有txt文件,txt中为电影评论文本。

         来看看一条具体的样本,train/pos/3_10.txt:

        本节任务就是对这样的一条文本进行处理,输出积极/消极的二分类概率向量。

3.数据模块

        文本任务与图像任务不同,输入不再是像素这样的数值,而是字符串,因此需要将字符串转为矩阵运算可接受的向量形 式。

         为此需要在数据处理模块完成以下步骤:

        a.分词:将一长串文本切分为一个个独立语义的词,英文可用空格来切分。

        b. 词嵌入:词嵌入通常分两步。首先将词字符串转为索引序号,然后索引序号根据词嵌入矩阵(embedding层)取对应的向量。其中词与索引之间的映射关系需要提前构建,这就是词表构建的过程。

        因此,代码开发整体流程:

        1. 编写分词功能函数

        2. 构建词表:对训练数据进行分词,统计词频,并构建词表。例如{'UNK': 0, 'PAD': 1, 'the': 2, '.': 3, 'and': 4, 'a': 5, 'of': 6, 'to': 7, ...}

        3. 编写PyTorch的Dataset,实现分词、词转序号、长度填充/截断序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

        序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

4.词表构建

        参考配套代码a_gen_vocabulary.py,首先编写分词功能函数,分词前做一些简单的数据清洗,例如在标点符号前加入空 格、去除掉不是大小写字母及 .!? 符号的数据。

        接着,写一个词表统计类实现词频统计,和词表字典的创建,代码注释非常详细,这里不赘述。 运行代码,即可完成词频统计,词表的构建,并保存到本地npy文件,在训练及推理过程中使用。

        在词表构建过程中有一个截断数量的超参数需要设置,这里设置为20000,即最多有20000个词的表示,不在字典中的词被归为UNK这个词。

         在这个数据集中,原始词表长度为74952,即通过split切分后,有7万多个不一样的字符串,通常可以通过降序排列,取前面一部分即可。

        代码会输出词频统计图,也可以观察出词频下降的速度以及高频词是哪些。

5.Dataset编写

        参考配套代码aclImdb_dataset.py,getitem中主要做两件事,首先获取label,然后获取文本预处理后的列表,列表中元素是词所对应的索引序号。

        在self.word2index.encode中需要注意设置文本最大长度self.max_len,这是由于需要将所有文本处理到相同长度,长度不足的用词填充,长度超出则截断。

6.模型模块——RNN

        模型的构建相对简单,理论知识在这里不介绍,需要了解和温习的推荐看看《动手学》。这里借助动手学的RNN图片讲解代码的实现。

        在构建的模型RNNTextClassifier中,需要三个子module,分别是:

                1. nn.Embedding:将词序号变为词向量,用于后续矩阵运算

                2. nn.RNN:循环神经网络的实现

                3. nn.Linear:最终分类输出层的实现

        在forward时,流程如下:

                1. 获取词向量

                2. 构建初始化隐藏层,默认为全0

                3. rnn推理获得输出层和隐藏层

                4. fc层输出分类概率:fc层的输入是rnn最后一个隐藏层

        更多关于nn.RNN的参数设置,可以参考官方文档:

        torch.nn.RNN(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)

7.模型模块——LSTM

        RNN是神经网络中处理时序任务最为经典的设计,但是其也存在一些缺点,例如梯度消失和梯度爆炸,以及长期依赖问 题。

        当序列很长时,RNN模型很难捕捉到远距离的依赖关系,导致模型预测不准确。

        为此,带门控机制的RNN涌现,包括GRU(Gated Recurrent Unit,门控循环单元)和LSTM(Long Short-Term Memory,长短期记忆网络),其中LSTM应用最广,这里直接跳过GRU。         LSTM模型引入了三个门(input gate、forget gate和output gate),用于控制输入、输出和遗忘的流动,允许模型有选择性地忘记或记住一些信息。

        input gate用于控制输入的流动

        forget gate用于控制遗忘的流动

        output gate用于控制输出的流动

        相较于RNN,除了输出隐藏层向量h,还输出记忆层向量c,不过对于下游使用,不需要关心向量c的存在。 同样地,借助《动手学》中的LSTM示意图来理解代码。

        在这里,借鉴《动手学》的代码,采用的LSTM为双向LSTM,这里简单介绍双向循环神经网络的概念。

         双向循环神经网络(Bidirectional Recurrent Neural Network,Bi-RNN)同时考虑前向和后向的上下文信息,前向层和后向层的输出在每个时间步骤上都被连接起来,形成了一个综合的输出,这样可以更好地捕捉序列中的上下文信息。

        在pytorch代码中,只需要将bidirectional设置为True即可,

        nn.LSTM(embed_size, num_hiddens, num_layers=num_layers, bidirectional=True)。

        当采用双向时,需要注意output矩阵的shape为 [ sequence length , batch size ,2×hidden size]

        更多关于nn.LSTM的参数设置,可以参考官方文档:torch.nn.LSTM(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

        详细参考:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

8.embedding预训练加载

        模型构建好之后,词向量的embedding层是随机初始化的,要从头训练具备一定逻辑关系的词向量表示是费时费力的, 通常可以采用在大规模预料上训练好的词向量矩阵。

        这里可以参考斯坦福大学的GloVe(Global Vectors for Word Representation)预训练词向量。

        GloVe是一种无监督学习算法,用于获取单词的向量表示,GloVe预训练词向量可以有效地捕捉单词之间的语义关系,被广泛应用于自然语言处理领域的各种任务,例如文本分类、命名实体识别和机器翻译等。

        Glove有四大类,根据数据量不同进行区分,相同数据下又根据向量长度分

        a.Wikipedia 2014 + Gigaword 5 (6B tokens, 400K vocab, uncased, 50d, 100d, 200d, & 300d vectors, 822 MB download): glove.6B.zip

        b.Common Crawl (42B tokens, 1.9M vocab, uncased, 300d vectors, 1.75 GB download): glove.42B.300d.zip

        c.Common Crawl (840B tokens, 2.2M vocab, cased, 300d vectors, 2.03 GB download): glove.840B.300d.zip

        d.Twitter (2B tweets, 27B tokens, 1.2M vocab, uncased, 25d, 50d, 100d, & 200d vectors, 1.42 GB download): glove.twitter.27B.zip

         在这里,采用Wikipedia 2014 + Gigaword 5 中的100d,即词向量长度为100,向量的token数量有6B。

        下载好的GloVe词向量矩阵是一个txt文件,一行是一个词和词向量,中间用空格隔开,因此加载该预训练词向量矩阵可以这样。

        原始GloVe预训练词向量有40万个词,在这里只关心词表中有的词,因此可以在加载字典时加一行过滤,即在词表中的词,才去获取它的词向量。

        在本案例中,词表大小是2万,根据匹配,只有19720个词在GloVe中找到了词向量,其余的词向量就需要随机初始化。

        获取GloVe预训练词向量字典后,需要把词向量放到embedding层中的矩阵,对弈embedding层来说,一行是一个词的词向量,因此通过词表的序号找到对应的行,然后把预训练词向量放进去即可,代码如下:

9.训练及实验记录

        准备好了数据和模型,接下来按照常规模型训练即可。

        这里将会做一些对比实验,包括模型对比:

         a.RNN vs LSTM

        b.有预训练词向量 vs 无预训练词向量

       c. 冻结预训练词向量 vs 放开预训练词向量

        具体指令如下,推荐放到bash文件中,一次性跑

        实验结果如下所示:

        1. RNN整体不work,经过分析发现设置的文本token长度太长,导致RNN梯度消失,以至于无法训练。调整 text_max_len为50后,train acc=0.8+, val=0.62,整体效果较差。

         2. 有了预训练词向量要比没有预训练词向量高出10多个点。

         3. 放开词向量训练,效果会好一些,但是不明显。

        补充实验:将RNN模型的文本最长token数量设置为50,其余保持不变,得到的三种embedding方式的结果如下:

        结论:

        1. LSTM较RNN在长文本处理上效果更好

        2. 预训练词向量在小样本数据集上很关键,有10多个点的提升

        3. 放开与冻结embedding层训练,效果差不多

10.小结

        本小节通过电影影评数据集实现文本分类任务,通过该任务可以了解:

        1. 文本预处理机制:包括清洗、分词、词频统计、词表构建、词表截断、UNK与PAD特殊词设定等。

        2. 预训练词向量使用:包括GloVe的下载及加载、nn.embedding层的设置 。

        3. RNN系列网络模型使用:大致了解循环神经网络的输入/输出是如何构建,如何配合fc层实现文本分类。

         4. RNN可接收的文本长度有限:文本过长,导致梯度消失,文本过短,导致无法捕获更多文本信息,因此推荐采用 LSTM等门控机制的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/752696.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Terraform基础概念一

Terraform基础概念一 1.Infrastructure-as-Code(IaC)概念1.1 IaC优势1.2 IaC工具1.3 IaC的两种方式 2.Terraform基础概念2.1 Terraform工作原理2.2 Terraform 工作流 3.总结 1.Infrastructure-as-Code(IaC)概念 基础设施即代码(Infrastructure-as-Code,…

告别 “屎山” 代码,务必掌握这14 个 SpringBoot 优化小妙招

插: AI时代,程序员或多或少要了解些人工智能,前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家(前言 – 人工智能教程 ) 坚持不懈,越努力越幸运,大家…

flash-Attention2安装和使用

flash-Attention2安装和使用 文章目录 flash-Attention2安装和使用写在前面解决方案 写在前面 就怕你不知道怎么查 pytorch、cuda 的版本 配置cuda:vim ~/.bashrc export CUDA_HOME/usr/local/cuda/ export PATH$PATH:$CUDA_HOME/bin export LD_LIBRARY_PATH$LD_LIB…

鉴源实验室·基于MQTT协议的模糊测试研究

作者 | 张渊策 上海控安可信软件创新研究院工控网络安全组 来源 | 鉴源实验室 社群 | 添加微信号“TICPShanghai”加入“上海控安51fusa安全社区” 随着物联网技术的快速发展,越来越多的设备加入到互联网中,形成了庞大的物联网系统。这些设备之间的通信…

【Sklearn-线性回归驯化】史上最为全面的预测分析的基石-线性回归大全

【Sklearn-驯化】史上最为全面的预测分析的基石-线性回归大全 本次修炼方法请往下查看 🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地! 🎇 免费获取相关内容文档关注&…

Java8新特性stream的原理和使用

这是一种流式惰性计算&#xff0c;整体过程是&#xff1a; stream的使用也异常方便&#xff0c;可以对比如List、Set之类的对象进行流式计算&#xff0c;挑出最终想要的结果&#xff1a; List<Timestamp> laterTimes allRecords.stream().map(Record::getTime).filter…

电脑音频剪辑怎么操作?分享六个简单的音频剪辑技巧【常用】

音频剪辑的需求越来越多&#xff0c;大多数短视频中的音乐都是大家后期制作的&#xff0c;主要目的就就是让视频观看起来更有趣。音频剪辑的方法有很多&#xff0c;比较好用的可以借助第三方音频剪辑软件。操作简单&#xff0c;对没有任何剪辑经验的小白用户来说十分友好。 本文…

java简易计算器(多种方法)

parseDouble() 方法属于 java.lang.Double 类。它接收一个字符串参数&#xff0c;其中包含要转换的数字表示。如果字符串表示一个有效的 double&#xff0c;它将返回一个 double 值。 应用场景 parseDouble() 方法在以下场景中非常有用&#xff1a; 从用户输入中获取数字&a…

VUE大屏的开发过程(纯前端)

写在前面&#xff0c;博主是个在北京打拼的码农&#xff0c;工作多年做过各类项目&#xff0c;最近心血来潮在这儿写点东西&#xff0c;欢迎大家多多指教。 对于文章中出现的任何错误请大家批评指出&#xff0c;一定及时修改。有任何想要讨论和学习的问题可联系我&#xff1a;1…

2.4G无线通信芯片数据手册解读:Ci24R1南京中科微

今天&#xff0c;我非常荣幸地向您介绍这款引领行业潮流的2.4G射频芯片&#xff1a;Ci24R1。这款芯片&#xff0c;不仅是我们技术的结晶&#xff0c;更是未来无线通信的璀璨明星。 首先&#xff0c;让我们来谈谈Ci24R1的“速度”。2.4G射频芯片&#xff0c;凭借其卓越的数据传输…

Python基于逻辑回归分类模型、决策树分类模型、随机森林分类模型和XGBoost分类模型实现乳腺癌分类预测项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 在当今医疗健康领域&#xff0c;乳腺癌作为威胁女性健康的主要恶性肿瘤之一&#xff0c;其早期诊断与精…

OpenHarmony开发实战:HDF驱动开发流程

概述 HDF&#xff08;Hardware Driver Foundation&#xff09;驱动框架&#xff0c;为驱动开发者提供驱动框架能力&#xff0c;包括驱动加载、驱动服务管理、驱动消息机制和配置管理。并以组件化驱动模型作为核心设计思路&#xff0c;让驱动开发和部署更加规范&#xff0c;旨在…

Redis-Bitmap位图及其常用命令详解

1.Redis概述 2.Bitmap Bitmap 是 Redis 中的一种数据结构&#xff0c;用于表示位图&#xff08;bit array&#xff09;。 它通常用于处理大规模数据集中每个元素的状态&#xff0c;比如用户的在线/离线状态&#xff08;每个用户对应一个位&#xff0c;表示在线&#xff08;1&a…

[数据结构】——七种常见排序

文章目录 前言 一.冒泡排序二.选择排序三.插入排序四.希尔排序五.堆排序六.快速排序hoare挖坑法前后指针快排递归实现&#xff1a;快排非递归实现&#xff1a; 七、归并排序归并递归实现&#xff1a;归并非递归实现&#xff1a; 八、各个排序的对比图 前言 排序&#xff1a;所谓…

Mac中的xshell、xftp

ROYAL TSX 插件式支持远程连接linux、支持命令行、支持ftp、支持远程windows桌面。 免费版就足够使用了。&#xff08;支持维护一个Connections文件夹&#xff09; 需要在本地创建一个文件夹&#xff0c;用以保存链接信息 使用方法

Bytebase 2.20.0 - 支持为工单事件配置飞书个人通知

&#x1f680; 新功能 支持 Databricks。支持 SQL Server 的 TLS/SSL 连接。支持为工单事件配置飞书个人通知。支持限制用户注册的邮箱域名。 &#x1f514; 重大变更 将分类分级同步设置从数据库配置移至工作空间的全局配置。 SQL 编辑器只读模式下只允许执行 Redis 的只读…

抖音外卖服务商申请全域外卖系统源码部署,如何保证竞争力?

随着本地生活市场规模的逐渐扩大&#xff0c;多家互联网公司在加大投入力度的同时&#xff0c;也在不断调整其市场竞争策略&#xff0c;作为国内头部社交平台的抖音也不例外。就在近日&#xff0c;抖音发布了关于新增《【到家外卖】内容服务商开放准入公告》的意见征集通知&…

OSI七层模型TCP/IP四层面试高频考点

OSI七层模型&TCP/IP四层&面试高频考点 1 OSI七层模型 1. 物理层&#xff1a;透明地传输比特流 在物理媒介上传输原始比特流&#xff0c;定义了连接主机的硬件设备和传输媒介的规范。它确保比特流能够在网络中准确地传输&#xff0c;例如通过以太网、光纤和无线电波等媒…

SCI二区复现|体育场观众优化算法(SSO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.代码获取 1.背景 2024年&#xff0c;M Nemati受到体育场观众的行为对比赛中球员行为的影响启发&#xff0c;提出了体育场观众优化算法&#xff08;Stadium Spectators Optimizer, SSO&#xff09;。 2.算法…

2023年第十四届蓝桥杯JavaB组省赛真题及全部解析(下)

承接上文&#xff1a;2023年第十四届蓝桥杯JavaB组省赛真题及全部解析&#xff08;下&#xff09;。 目录 七、试题 G&#xff1a;买二赠一 八、试题 H&#xff1a;合并石子 九、试题 I&#xff1a;最大开支 十、试题 J&#xff1a;魔法阵 题目来自&#xff1a;蓝桥杯官网…