6.2 通过构建情感分类器训练词向量

        在上一节中,我们简要地了解了词向量,但并没有去实现它。在本节中,我们将下载一个名为IMDB的数据集(其中包含了评论),然后构建一个用于计算评论的情感是正面、负面还是未知的情感分类器。在构建过程中,还将为 IMDB 数据集中存在的词进行词向量的训练。我们将使用一个名为 torchtext 的库,这个库使下载、向量化文本和批处理等许多过程变得更加容易。训练情感分类器将包括以下步骤。

  1. 下载 IMDB 数据并对文本分词;
  2. 建立词表;
  3. 生成向量的批数据;
  4. 使用词向量创建网络模型;
  5. 训练模型。

6.2.1 下载 IMDB 数据并对文本分词

        对于与计算机视觉相关的应用,我们使用过 torchvision 库。它提供了许多实用功能,并帮助我们构建计算机视觉应用程序。同样,有一个名为 torchtext 的库,它也是 PyTorch 的一部分,它与 PyTorch 一起工作,通过为文本提供不同的数据加载器和抽象,简化了许多与自然语言处理相关的活动。在本书写作时,torchtext 没有包含在 PyTorch 包内,需要独立安装。可以在计算机的命令行中运行以下代码来安装torchtext:

pip install torchtext

        安装完成后就可以使用它了。torchtext 提供了两个重要的模块:torchtext.data和torchtext.datasets。

        1. torchtext.data

        torchtext.data 实例定义了一个名为 Field 的类,它可以用来定义数据如何读取和分词。让我们看一下使用它来准备 IMDB 数据集的示例:

from torchtext import data
TEXT = data.Field(lower=True, batch_first=True, fix_length=20)
LABEL = data.Field(sequential=False)

        在上述代码中,我们定义了两个 Field 对象,一个用于实际的文本,另一个用于标签数据。对于实际的文本,我们期望 torchtext 将所有文本都小写并对文本分词,同时将其修整为最大长度为20。如果我们正在为生产环境构建应用程序,则可以将长度修正为更大的数字。当然对于当前练习的例子,20的长度够用了。Field 的构造函数还接受另一个名为 tokenize 的参数,该参数默认使用str.split 函数。还可以指定spaCy作为参数或任何其他分词器。我们的例子将使用 str.split。

        2. torchtext.datasets

        torchtext.datasets 实例提供了使用不同数据集的封装,如IMDB、TREC(问题分类)、语言建模(WikiText-2)和一些其他数据集。我们将使用 torch.datasets 下载 IMDB 数据集并将其拆分为 train 和 test 数据集。以下代码执行此操作,当第一次运行它时,可能需要几分钟,具体取决于网络连接速度,因为它是从 Internet 上下载 IMDB 数据集的:

train,test=datasets.IMDB.splits(TEXT,LABEL)

        之前的数据集的 IMDB 类抽象出了下载、分词和将数据库拆分为 train 和 test 数据集涉及的所有复杂度。train.fields 包含一个字典,其中 TEXT 是键,值是 LABEL。让我们看看 train.fields 和 train 集合的每个元素:

print('train.fields',train.fields)

        从这些结果中可以看到,单个元素包含了一个字段 text 和表示 text 的所有 token,以及包含了文本标签的字段 label。现在已准备好对 IMDB 数据集进行批处理了。

6.2.2 构建词表

        当为 thor_review 创建独热编码时,同时创建了一个作为词表的 word2idx 字典,它包含文档中唯一词的所有细节。torchtext 实例使处理更加容易。在加载数据后,可以调用 build_vocab 并传入负责为数据构建词表的必要参数。以下代码说明了如何构建词表:

TEXT.build_vocab(train,vectors=GloVe(name=,6B,dim=300),max_size=10000,min_freq=10)
LABEL.build_vocab(train)

        在上述代码中,传入了需要构建词表的 train 对象,并让它使用维度为 300 的预训练词向量来初始化向量。当使用预训练权重训练情感分类器时,build_vocab 对象只是下载并创建稍后将使用的维度。max_size 实例限制了词表中词的数量,而min_freg删除了出现不超过10 次的词,其中 10是可配置的。
        当词汇表构建完成后,我们就可以获得例如词频、词索引和每个词的向量表示等不同的值。下面的代码演示了如何访问这些值:

print(TEXT.vocab.freqs)

        以下代码演示了如何访问结果:

print(TEXT.vocab.vectors)

        使用 stoi 访问包含词及其索引的字典。

6.2.3 生成向量的批数据

        torchtext 提供了 BucketIterator,它有助于批处理所有文本并将词替换成词的索引。BucketIterator 实例带有许多有用的参数,如batch_size、device(GPU或CPU)和 shuffle (是否必须对数据进行混洗)。下面的代码演示了如何为 train 和 test 数据集创建生成批处理的迭代器:

train_iter, test_iter = data.BucketIterator.splits((train, test),
batch_size=128,device=-1,shuffle=True)
#device = -1 表示使用 cpu,设置为 None 时使用 gpu.

        上述代码为 train 和 test 数据集提供了一个 BucketIterator 对象。以下代码将说明如何创建 batch 并显示 batch 的结果:

batch = next(iter(train_iter))
batch.text

        从上面代码段的结果中,可以看到文本数据如何转换为 batch_size * fix_len (即128x20) 大小的矩阵。

6.2.4 使用词向量创建网络模型

        我们之前简要地讨论过词向量。在本节中,我们将创建作为网络架构的一部分的词向量,并训练整个模型用以预测每个评论的情感。在训练结束时,将得到一个情感分类器模型,以及 IMDB 数据集的词向量。以下代码演示了如何使用词向量创建用于情感预测的网络架构:

class EmbNet(nn.Module):
    def _init_(self,emb_size,hidden_sizel,hidden_size2 = 400):
        super()._init_()
        self.embedding = nn.Embedding(emb_size,hidden_sizel)
        self.fc = nn.Linear(hidden_size2,3)
    def forward(self,x):
        embeds = self.embedding(x).view(x.size(0),-1)
        out = self.fc(embeds)
        return F.log_softmax(out, dim = -1)

        在上述代码中,EmbNet 创建了情感分类模型。在_init_函数中,我们使用两个参数初始化了 nn.Embedding 类的一个对象,它接收两个参数,即词表的大小和希望为每个单词创建的维度。由于限制了唯一单词的数量,因此词表的大小将为10,000,并且我们可以从一个小的向量尺寸(比如10)开始。为了快速运行程序,有必要使用个小尺寸的向量值,但是当试图为生产系统构建应用程序时,请使用大尺寸的词向量。我们还有一个线性层,将词向量映射到情感的类别(如正面、负面或未知)。
        forward 函数确定了输入数据的处理方式。对于批量大小为 32 以及最大长度为 20 个词的句子,输入形状为 32x20。第一个 embedding 层充当查找表,用相应的词向量替换掉每个词。对于向量维度 10,当每个词被其相应的词向量替换时,输出形状变成了 32x20x10。view 函数将使 embedding 层的结果变得扁平。传递给 view 函数的第一个参数将保持维数不变。在我们的例子中,我们不希望组合来自不同批次的数据,因此保留第一个维数并将张量中的其余值扁平化。在应用 view 函数后,张量形状变为 32x200。全连接层将扁平化的词向量映射到类别的编号。定义了网络后就可以像往常一样训练它了。

6.2.5 训练模型

        训练模型与在构建图像分类器时看到的非常类似,因此将使用相同的函数。我们把批数据传入模型并计算输出和损失,然后优化包括词向量权重在内的模型权重。以下代码执行此操作:

def fit(epoch,model,data_loader,phase=,training,,volatile=False):
    if phase == 'training':
        model.train()
    if phase == 'validation':
        model.eval()
        volatile = True 
    running_loss = 0.0
    running_correct = 0
    for batch_idx r batch in enumerate(data_loader):
        text, target = batch.text r batch.label
        if is_cuda:
            text,target = text.cuda(), target.cuda()
        if phase =='training':
            optimizer.zero_grad()
        output = model(text)
        loss = F.nll_loss(output,target)
        running loss += F.nll loss(output,target,size_average=False).data[0]
        preds = output.data.max(dim=1,keepdim=True)[1]
        running_correct += preds.eq(target.data.view_as(preds)).cpu().sum()
        if phase == 'training': 
            loss.backward()
            optimizer.step()
        loss = running_loss/len(data_loader.dataset)
        accuracy = 100. * running_correct/len(data_loader.dataset)
        print(f'{phase} loss is (loss:(5}.{2}} and {phase} accuracy is 
            {running_correct}/{len(data_loader.dataset)}{accuracy:{10}.{4}}')
        return loss,accuracy
    train_losses,train_accuracy = [],[]
    val_losses,val_accuracy = [],[]
    train_iter.repeat = False
    test_iter.repeat = False
    for epoch in range(1,10):
        epoch_loss,epoch_accuracy = fit(epoch,model,train_iter,phase='training')
        val_epoch_loss,val_epoch_accuracy = 
            fit(epoch,model,test_iter,phase='validation')
        train_losses.append(epoch_loss)
        train_accuracy.append(epoch_accuracy)
        val_losses.append(val_epoch_loss)
        val_accuracy.append(val_epoch_accuracy)

        在上述代码中,通过传入为批处理数据创建的 BucketIterator 对象来调用 fit 方法。默认情况下,迭代器不会停止生成批数据,因此必须将 BucketIterator 对象的 repeat 变量设置为 False。如果不将 repeat 变量设置为 False,那么 fit 函数将无限地运行。模型训练10轮后得到的验证准确率约为70%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/746919.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Windows上PyTorch3D安装踩坑记录

直入正题,打开命令行,直接通过 pip 安装 PyTorch3D : (python11) F:\study\2021-07\python>pip install pytorch3d Looking in indexes: http://mirrors.aliyun.com/pypi/simple/ ERROR: Could not find a version that satisfies the requirement p…

JS(JavaScript)入门指南(DOM、事件处理、BOM、数据校验)

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。 玉阶生白露,夜久侵罗袜。 却下水晶帘,玲珑望秋月。 ——《玉阶怨》 文章目录 一、DOM操作1. D…

从零开始做题:有手就行

1 题目 2 解题 ARPHCR工具破解 得到flag DASCTF{2b3767763885a019b65bbfe9d1136c3b}

从零开始学docker(四)-安装mysql及主从配置(一)

mysql MySQL是一个关系型数据库管理系统,由瑞典MySQL AB 公司开发,属于 Oracle 旗下产品。MySQL 是最流行的关系型数据库管理系统之一,在 WEB 应用方面,MySQL是最好的 RDBMS (Relational Database Management System,关…

仿Photoshop利用曲线对图像调整亮度与色彩

曲线调整是Photoshop的最常用的重要功能之一。对于一个RGB图像, 可以对R, G, B 通道进行独立的曲线调整,即,对三个通道分别使用三条曲线(Curve)。还可以再增加一条曲线对 三个通道进行整体调整。 因此,对一个图像&a…

C++初学者指南-2.输入和输出---流输入和输出

C初学者指南-2.输入和输出—流输入和输出 文章目录 C初学者指南-2.输入和输出---流输入和输出1.定制输入/输出1.1 示例:点坐标输入/输出1.2 流操作符1.3(一部分)标准库流类型 2. 工具2.1 用getline读取行 2.2 用ignore进行跳转2.3 格式化操作…

武汉星起航:全球化舞台,中国跨境电商品牌力与竞争力双提升

随着全球化步伐的加快和数字技术的迅猛发展,跨境出口电商模式已经成为中国企业海外拓展的重要战略选择。这一模式不仅为中国的中小型企业提供了进军全球市场的机会,更为它们在全球舞台上展示独特的竞争优势提供了强有力的支撑。武汉星起航将从市场拓宽、…

STL迭代器的基础应用

STL迭代器的应用 迭代器的定义方法: 类型作用定义方式正向迭代器正序遍历STL容器容器类名::iterator 迭代器名常量正向迭代器以只读方式正序遍历STL容器容器类名::const_iterator 迭代器名反向迭代器逆序遍历STL容器容器类名::reverse_iterator 迭代器名常量反向迭…

问界M9累计大定破10万,创中国豪车新纪录

ChatGPT狂飙160天,世界已经不是之前的样子。 更多资源欢迎关注 6月26日消息,华为常务董事、终端BG董事长、智能汽车解决方案BU董事长余承东今日宣布,问界M9上市6个月,累计大定突破10万辆。 这一成绩,也创造了中国市场…

5款名不见经传的小众软件,简单好用

​ 我们在使用一些流行的软件的时候,往往会忽略一些知名度不高但是功能非常强大的软件,有的是因为小众,有的是因为名不见经传,总之因为不出名,有许多的好用的软件都不为大众所知道。 1.桌面美化——Win10 Widgets ​…

为什么需要对数据质量问题进行根因分析?根因分析该怎么做?

在当今的商业环境中,数据已成为企业决策的核心。然而,数据的价值高度依赖于其质量。低质量的数据不仅会降低分析的准确性,还可能导致错误的决策,从而影响企业的竞争力和市场表现。因此,识别和解决数据质量问题是数据管…

c#关键字 ArgumentOutOfRangeException .? IEnumerable string.Join

c# ArgumentOutOfRangeException ArgumentOutOfRangeException 是 C# 中表示某个参数值超出了方法或属性定义的有效范围时引发的一个异常。这个异常通常在尝试访问数组、集合、字符串等的无效索引,或者当传递给方法或属性的参数不在其有效范围内时发生。 例如&…

浅学JVM

一、基本概念 目录 一、基本概念 二、JVM 运行时内存 1、新生代 1.1 Eden 区 1.2. ServivorFrom 1.3. ServivorTo 1.4 MinorGC 的过程 (复制- >清空- >互换) 1.4.1:eden 、servicorFrom 复制到ServicorTo,年龄1 …

K8S集群进行分布式负载测试

使用K8S集群执行分布式负载测试 本教程介绍如何使用Kubernetes部署分布式负载测试框架,该框架使用分布式部署的locust 产生压测流量,对一个部署到 K8S集群的 Web 应用执行负载测试,该 Web 应用公开了 REST 格式的端点,以响应传入…

C++用Crow实现一个简单的Web程序,实现动态页面,向页面中输入数据并展示

Crow是一个轻量级、快速的C微框架,用于构建Web应用程序和RESTful API。 将处理前端页面的POST请求以添加数据的逻辑添加到 /submit 路由中,并添加了一个新的路由 / 用于返回包含输入框、按钮和表格的完整页面。当用户向表格添加数据时,JavaS…

创新与责任并重!中国星坤连接器的可持续发展战略!

在当今全球化的商业环境中,企业的社会责任、技术创新和产品质量是企业可持续发展的三大支柱。中国星坤正是这样一家企业,它在电子连接技术领域以其卓越的技术创新、坚定的环保责任和严格的生产品控而著称。本文将深入探讨星坤科技如何通过其FAE技术团队的…

浏览器扩展V3开发系列之 chrome.contextMenus 右键菜单的用法和案例

【作者主页】:小鱼神1024 【擅长领域】:JS逆向、小程序逆向、AST还原、验证码突防、Python开发、浏览器插件开发、React前端开发、NestJS后端开发等等 chrome.contextMenus 允许开发者向浏览器的右键菜单添加自定义项。 在使用 chrome.contextMenus 之前…

CMN-700(1)CMN-700概述

本章介绍CMN-700,这是用于AMBA5 CHI互连,且可根据需要定制的网格拓扑结构。 1. 关于CMN‐700 CMN‐700是一种可配置扩展的一致性互连网络,旨在满足高端网络和企业计算应用中使用的一致性网络系统的功率、性能和面积(PPA)要求。支持1-256个处…

ES6深潜指南:解锁JavaScript类与继承的高级技巧,让您的代码更加优雅

前言 随着前端技术的迅猛发展,JavaScript已经成为构建现代Web应用不可或缺的编程语言。ES6(ECMAScript 2015)引入了许多期待已久的特性,其中类(Classes)和继承机制的引入,极大地增强了JavaScrip…

gc.log中 CMS-concurrent-abortable-preclean

问题 在gc日志中看到 2024-06-26T16:16:07.5040800: 64690272.666: [CMS-concurrent-abortable-preclean-start]CMS: abort preclean due to time 2024-06-26T16:16:12.5530800: 64690277.716: [CMS-concurrent-abortable-preclean: 1.052/5.049 secs] [Times: user1.33 sys0…