ROCm上情感分析:使用循环神经网络

15.2. 情感分析:使用循环神经网络 — 动手学深度学习 2.0.0 documentation (d2l.ai)

代码

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

class BiRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens,
                 num_layers, **kwargs):
        super(BiRNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 将bidirectional设置为True以获取双向循环神经网络
        self.encoder = nn.LSTM(embed_size, num_hiddens, num_layers=num_layers,
                                bidirectional=True)
        self.decoder = nn.Linear(4 * num_hiddens, 2)

    def forward(self, inputs):
        # inputs的形状是(批量大小,时间步数)
        # 因为长短期记忆网络要求其输入的第一个维度是时间维,
        # 所以在获得词元表示之前,输入会被转置。
        # 输出形状为(时间步数,批量大小,词向量维度)
        embeddings = self.embedding(inputs.T)
        self.encoder.flatten_parameters()
        # 返回上一个隐藏层在不同时间步的隐状态,
        # outputs的形状是(时间步数,批量大小,2*隐藏单元数)
        outputs, _ = self.encoder(embeddings)
        # 连结初始和最终时间步的隐状态,作为全连接层的输入,
        # 其形状为(批量大小,4*隐藏单元数)
        encoding = torch.cat((outputs[0], outputs[-1]), dim=1)
        outs = self.decoder(encoding)
        return outs

embed_size, num_hiddens, num_layers = 100, 100, 2
devices = d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
    if type(m) == nn.LSTM:
        for param in m._flat_weights_names:
            if "weight" in param:
                nn.init.xavier_uniform_(m._parameters[param])
net.apply(init_weights);

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')

embeds = glove_embedding[vocab.idx_to_token]
embeds.shape

net.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = False

lr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
    devices)

#@save
def predict_sentiment(net, vocab, sequence):
    """预测文本序列的情感"""
    sequence = torch.tensor(vocab[sequence.split()], device=d2l.try_gpu())
    label = torch.argmax(net(sequence.reshape(1, -1)), dim=1)
    return 'positive' if label == 1 else 'negative'

predict_sentiment(net, vocab, 'this movie is so great')

predict_sentiment(net, vocab, 'this movie is so bad')

代码解析

这段代码实现了一个用于情感分析的双向循环神经网络(BiRNN)。下面我将逐部分用中文解析它:
1. 导入所需的库和模块:

import torch
from torch import nn
from d2l import torch as d2l

这里导入了PyTorch库、神经网络模块`nn`和基于PyTorch的深度学习库`d2l`(深度学习的一本书)。
2. 加载数据集:

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

加载IMDB电影评论数据集,并用迭代器`train_iter`和`test_iter`进行训练和测试。`vocab`是数据集的词汇表。
3. 定义双向循环神经网络(BiRNN)模型:

class BiRNN(nn.Module):
    ...

创建了一个名为`BiRNN`的类,用于定义双向LSTM模型。模型有一个嵌入层(`embedding`),将词汇映射到向量空间。LSTM层(`encoder`)设定为双向,输出经过全连接层(`decoder`)得到最终的分类结果。
4. 初始化模型参数:

def init_weights(m):
    ...
net.apply(init_weights);

init_weights函数用于模型参数的初始化。`net.apply(init_weights);`使用这个函数来应用参数初始化。
5. 加载预训练的词向量:

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = False

使用GloVe预训练的100维词向量,并将它们复制到嵌入层`net.embedding`。同时设置`requires_grad = False`使得这些词向量在训练中不被更新。
6. 训练模型:

lr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

设置学习率和迭代次数,使用Adam优化器和交叉熵损失函数。用`d2l.train_ch13`函数来训练和评估模型。
7. 定义预测函数:

def predict_sentiment(net, vocab, sequence):
    ...

这个函数用于预测给定文本序列的情感标签(积极或消极)。
8. 使用模型进行预测:

predict_sentiment(net, vocab, 'this movie is so great')
predict_sentiment(net, vocab, 'this movie is so bad')

调用`predict_sentiment`函数分别对两个句子进行情感预测。
整体来看,这段代码主要是利用循环神经网络对电影评论的情感进行分类,它通过加载预训练好的词向量,构建一个双向LSTM网络,并在IMDB评论数据集上进行训练和测试。最后定义了一个实用函数,用于预测输入句子的情感倾向。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/639791.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

sqlserver的查询(三)

目录 10. group by(分组) 11. having(对分组后的信息过滤) 可能从这里开始,执行顺序越来越显得重要了!!! 10. group by(分组) 这个查询相比前面会有一些困难; 格式:group by 字段的集合; 功…

GD32F103RCT6/GD32F303RCT6-UCOSIII底层移植(4)消息队列实验

本文章基于兆易创新GD32 MCU所提供的2.2.4版本库函数开发 后续项目主要在下面该专栏中发布: 手把手教你嵌入式国产化_不及你的温柔的博客-CSDN博客 感兴趣的点个关注收藏一下吧! 电机驱动开发可以跳转: 手把手教你嵌入式国产化-实战项目-无刷电机驱动&am…

Java | Leetcode Java题解之第109题有序链表转换二叉搜索树

题目: 题解: class Solution {ListNode globalHead;public TreeNode sortedListToBST(ListNode head) {globalHead head;int length getLength(head);return buildTree(0, length - 1);}public int getLength(ListNode head) {int ret 0;while (head…

kimi :系统框架 实力学习

学海无涯,你,准备好了吗? 学习一个新的嵌入式系统架构,你"只"需要 - 1 - 手册/速查函数(对于比较大的架构,F12往往返回多个结果,增加混乱); 2 - 源代码和VS&am…

20.有序性与内存屏障

文章目录 有序性与内存屏障1.重排序1.1.编译器重排序1.2.CPU重排序1.2.1.指令级重排序1.2.2.内存系统重排序1.3.As-if-Serial规则 2.内存屏障2.1.硬件层面的内存屏障2.1.2.写屏障2.1.3.读屏障2.1.4.全屏障 2.2.硬件层的内存屏障作用2.3.案例 有序性与内存屏障 有序性 与 可见性…

混合组网VS传统网络:智能硬件混合组网优劣势浅要解析

智能硬件混合组网是一种利用多种通信技术相结合的方法,以实现更灵活、更可靠的网络连接。通过蓝牙、Wi-Fi、LoRa、4G相互之间的不同通讯方式,根据应用场景的不同以及现场实际环境,优选最佳物联网混合组网方案,以达到部署最便捷性价…

云曦2024年春季学期期中考复现

目录 Web Web_SINGIN 简简单单的文件上传 好玩的PHP 渗透的本质 简简单单的sql re baby_re easy xor Crypto easy_rsa Rsa2 Crypto_Singin Pwn pwn_Sing Misc easy_singin Xjpg 流量分析1 流量分析3 流量分析2 Web Web_SINGIN 1.使用右键检查&#xff0c…

IMU内参标定(理论)

1、内参标定标定什么? 生产零偏、标度因数误差、安装误差 2、现象是什么? 零偏现象:即使没有任何运动或旋转,IMU传感器仍然会输出一个非零的信号。零偏是一个恒定的误差,导致测量值始终偏离实际值。对于加速度计&am…

DolphinDB 携手九鞅科技,助力固收投研效能飞跃

随着金融市场开放的广度与深度不断拓宽,金融产品呈现出多样化的发展态势,其中债券投资组合凭借其低风险性、高流动性与稳健的收益表现,逐渐成为投资理财领域备受瞩目的焦点。投资经理不仅需要了解哪些债券值得投资,更要对债券投资…

【GESP试卷】2024年03月Scratch四级试卷

2024年GESP03月认证Scratch四级试卷 分数:100 题数:27 一、单选题(共15题,每题2分,共30分) 010203040506070809101112131415CDBBACBCDCDADBA 1、小杨的父母最近刚刚给他买了一块华为手表,他说手表上跑的是鸿蒙&…

朋友正确交往方式,以及保留有效沟通,才是对朋友的尊重!

人生就像一列火车,从生命之初驶向生命的终点,路途上有很多站点,每一个站点都会遇到不同的人,结交各式各样的朋友,中间有人下车,有人上车,有人与你走着走着就散了,有人偶有相见却已是…

Qt 科目一考试系统(有源码)

项目源码和资源:科目一考试系统: qt实现科目一考试系统 一.项目概述 该项目是一个基于Qt框架开发的在线考试系统,主要实现了考试题目的随机抽取、考试时间限制、成绩统计等功能。用户可以通过界面操作进行考试,并查看自己的考试成绩。 二.技…

计算机网络之应用层知识点总结

6.1 网络应用模型 (1)应用层概述 (2)网络应用模型的介绍 客户/服务器(C/S)模型 P2P模型 6.2 域名解析系统DNS (1)DNS系统介绍 (2)域名 (3&#…

AI爆文写作:标题需要什么?情绪炸裂,态度要激烈,行为要夸张!

现在这个传播环境下,在公域中,轻声细语,慢慢的说,无法吸引到注意,没有人搭理。 标题要需要情绪张扬,态度激烈,行为夸张,大声喧闹。 唐韧的用户群是互联网产品经理,阅读量…

小猫咪的奇幻冒险:一个简单的Python小游戏

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、游戏简介与演示 二、游戏开发与运行 1. 环境搭建 2. 代码解析 3. 加速机制 三、游戏…

油猴插件刷学习通

油猴插件刷学习通 edge浏览器 浏览器进入这个网址https://microsoftedge.microsoft.com/addons/search/%E6%B2%B9%E7%8C%B4tampermonkey?hlzh-CN。 点我自动进入 点那个绿色的,点击获取 油猴插件下载在了这里 找到油猴图标,获取新脚本。 安装 …

DPDK实践之(1)dpdk基础使用

DPDK实践之(1)dpdk基础使用 Author: Once Day Date: 2024年5月19日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文档可参考专栏:Linux基础知识_Once…

Unity Terrain Adjust插件使用教程

一、Terrain Adjust插件介绍 二、插件下载以及导入 1、官方下载地址:Terrain Adjust 2、积分下载地址:Terrain Adjust 下载好之后,回到Unity当中,导入下载好之后的unitypackage包 三、插件使用 1、在使用之前一定要在场景中新…

【数据结构】二叉树的功能实现

文章目录 关于二叉树的创建如何创建二叉树实现二叉树的前、中、后序遍历层序遍历 关于二叉树的创建 在笔者的上一篇文章中堆进行了一个详细介绍,而二叉树是以堆为基础进行创建,它与堆的显著不同是 堆像是一个线性结构,堆的结构往往是一个数…

刷题之寻找重复数(leetcode)

寻找重复数 这题实际上就是变形的环形链表Ⅱ&#xff0c;下标为index的下一个元素是nums[index]&#xff0c;下下一个元素是nums[nums[index]] class Solution { public:int findDuplicate(vector<int>& nums) {int fast0;int slow0;while(1){fastnums[nums[fast]]…