《PyTorch深度学习实践》第十二讲循环神经网络基础

一、RNN简介

1、RNN网络最大的特点就是可以处理序列特征,就是我们的一组动态特征。比如,我们可以通过将前三天每天的特征(是否下雨,是否有太阳等)输入到网络,从而来预测第四天的天气。
       我们可以看RNN的网络结构如下:

二、RNN cell用法

import torch

batch_size = 1 # 批处理大小
seq_len = 3 # 序列长度
input_size = 4 # 输入维度
hidden_size = 2 # 隐藏层维度

cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

# (seq, batch, features)
dataset = torch.randn(seq_len, batch_size, input_size)
print(dataset)
hidden = torch.zeros(batch_size, hidden_size)
print(hidden)

for idx, input in enumerate(dataset):
    print( '=' * 20, idx, '=' * 20)
    print( 'Input size: ', input.shape)
    hidden = cell(input, hidden)
    print( 'outputs size: ', hidden.shape)
    print(hidden)

三、RNN用法

import torch

batch_size = 1 # 批处理大小
seq_len = 3 # 序列长度
input_size = 4 # 输入维度
hidden_size = 2 # 隐藏层维度
num_layers = 4  # 隐藏层数量

cell = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

# (seqLen, batchSize, inputSize)
inputs = torch.randn(seq_len, batch_size, input_size)
hidden = torch.zeros(num_layers, batch_size, hidden_size)
out, hidden = cell(inputs, hidden)

print( 'Output size:', out.shape)
print( 'Output:', out)
print( 'Hidden size: ', hidden.shape)
print( 'Hidden: ', hidden)

四、Embedding

把input变为稠密的数据

代码:

import torch

# parameters
num_class = 4
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5

# 准备数据集
idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # (batch, seq_len)
y_data = [3, 1, 2, 3, 2]    # (batch * seq_len)

inputs = torch.LongTensor(x_data)   # Input should be LongTensor: (batchSize, seqLen)
labels = torch.LongTensor(y_data)   # Target should be LongTensor: (batchSize * seqLen)

# 构建模型
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = torch.nn.Embedding(input_size, embedding_size)
        self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, num_class)

    def forward(self, x):
        hidden = torch.zeros(num_layers, x.size(0), hidden_size)
        x = self.emb(x)  # (batch, seqLen, embeddingSize)
        x, _ = self.rnn(x, hidden)  # 输出(𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆, 𝒔𝒆𝒒𝑳𝒆𝒏, hidden_size)
        x = self.fc(x)  # 输出(𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆, 𝒔𝒆𝒒𝑳𝒆𝒏, 𝒏𝒖𝒎𝑪𝒍𝒂𝒔𝒔)
        return x.view(-1, num_class)  # reshape to use Cross Entropy: (𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆×𝒔𝒆𝒒𝑳𝒆𝒏, 𝒏𝒖𝒎𝑪𝒍𝒂𝒔𝒔)

net = Model()

# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

# 训练模型
for epoch in range(15):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    _, idx = outputs.max(dim=1)
    idx = idx.data.numpy()
    print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

 运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/420836.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++ Algorithm Tutorial (1)

中文版 c算法入门教程(1)_c怎么学习算法-CSDN博客 Cis a powerful and widely used programming language, and for those who want to delve deeper into programming and algorithms, mastering Cis an important milestone. This article will take you step by step to und…

electron+vue3全家桶+vite项目搭建【28】封装窗口工具类【2】窗口组,维护窗口关系

文章目录 引入实现效果思路主进程模块渲染进程模块测试效果 引入 demo项目地址 窗口工具类系列文章: 封装窗口工具类【1】雏形 我们思考一下窗口间的关系,窗口创建和销毁的一些动作,例如父子窗口,窗口组合等等,还有…

labview数组精讲

题主经过写文章一段时间的发现,许多同学对该软件的理解和编程能力是不太一样的,有些知识相对一些同学较为简单,但是有些同学提问就比较困难。那么针对这个问题,题主打算出一期说白话系列的专栏,在该栏目中用最通俗的大白话和例子去让大家深刻了解这个软件的功能和摸透他的…

【center-loss 中心损失函数】 原理及程序解释(更新中)

文章目录 前言问题引出open-set问题抛出 解决方法softmax函数、softmax-loss函数解决代码(center_loss.py)原理程序解释 代码运用 如何梯度更新首先了解一下基本的梯度下降算法然后 补充:外围知识模型 前言 学习一下: 中心损失函…

报错问题解决django.db.utils.OperationalError: (1049, “Unknown database ‘ mxshop‘“)

开发环境:ubuntu22.04 pycharm 功能:django连接使用mysql数据库,各项配置看似正常 报错: django.db.utils.OperationalError: (1049, "Unknown database mxshop") 分析检查原因: Setting的配置文件内&…

分布式版本控制系统 git

学习目标: 提示:了解及 学会使用 git 1、 含义: Git(读音为/gɪt/)是一个 开源的分布式版本控制系统 ,可以有效、高速地处理从很小到非常大的项目版本管理。 git 也是Linus Torvalds为了帮助管理Linux内核…

信号系统之快速傅里叶变换

1 使用复数DFT的实数DFT 本文的主题,如何使用 FFT 计算真正的 DFT? 由于 FFT 是一种用于计算复数 DFT 的算法,因此了解如何将实数 DFT 数据输入和输出复数 DFT 格式非常重要。图 12-1 比较了实数 DFT 和复数 DFT 存储数据的方式。实数 DFT …

VUE3 页面路由 router

VUE3 页面路由 router 1.理解路由流程1.1新建工程1.2 理解路由1.3 结果1.4补充 2.路由传递参数2.1.新建工程2.2.打开项目2.3.新建路由2.4 路由传递参数 3.路由嵌套配置3.1 基础3.2 子二级布局文件3.3 配置子二级路由3.4 父级页面链接路由3.5 结果3.6 细节:子二级目录…

String类的使用

String常用的构造方法 String的源码 内部是一个数组和hash值,涉及到常量池后续补充(常量池:存储相同的字符时只会存储一租) String的比较 equals()与:String里面为我们提供了许多方法,可直接调用&#xf…

Rocky Linux 运维工具 firewall-cmd

一、firewall-cmd​的简介 ​​firewall-cmd​是基于firewalld的防火墙管理工具。用户可以使用它来配置、监控和管理防火墙规则,包括开放端口、设置服务规则等。 二、firewall-cmd​​的参数说明 序号参数描述1​​–zone指定防火墙区域2–add-portxxx/tcp允许特定…

spring boot3解决跨域的几种方式

⛰️个人主页: 蒾酒 🔥系列专栏:《spring boot实战》 🌊山高路远,行路漫漫,终有归途。 目录 1.前言 2.何为跨域 3.跨域问题出现特征 4.方式一:使用 CrossOrigin 注解 5.方式二:自定义…

erp项目采购模块新增商品,商品价格计算demo

1:因为此前erp项目中的采购模块添加商品,实时计算单个商品预估价格及全部商品总额折磨了我很久,所以今天闲来无事优化一下,使用另一种思维来做一个类似demo,作为此前总结反思 2:原来项目模块是这样的 3&am…

达梦数据库DM8安装与配置

达梦数据库安装与配置 1 安装前准备 1.1 新建 dmdba 用户 创建用户所在的组,命令如下: groupadd dinstall创建用户,命令如下: useradd -g dinstall -m -d /home/dmdba -s /bin/bash dmdba修改用户密码,命令如下&am…

(每日持续更新)jdk api之OutputStreamWriter基础、应用、实战

博主18年的互联网软件开发经验,从一名程序员小白逐步成为了一名架构师,我想通过平台将经验分享给大家,因此博主每天会在各个大牛网站点赞量超高的博客等寻找该技术栈的资料结合自己的经验,晚上进行用心精简、整理、总结、定稿&…

MyBatis 学习(五)之 高级映射

目录 1 association 和 collection 介绍 2 案例分析 3 一对一关联和一对多关联 4 参考文档 1 association 和 collection 介绍 在之前的 SQL 映射文件中提及了 resultMap 元素的 association 和 collection 标签,这两个标签是用来关联查询的,它们的属…

中小企业“数智未来”行动|ZStack Cloud 荣获“推荐方案”奖

2月29日,以“数智未来 共创数字时代新篇章”为主题的中小企业“数智未来”行动在京成功举办,本次活动由中央广播电视总台央视频和中国中小企业协会作为联合观察单位,带来了一系列帮助中小企业成就业务新价值和数智化升级的优秀产品和方案&…

从入门到精通的Android进阶学习笔记整理,你有过迷茫吗

面试题 一般Android面试分为两部分:Java部分和Android部分,下面说一下自己面试过程遇到的一些具体题目和一些相关知识点。 一 JAVA相关 1)JAVA基础 1.java基本数据类型有哪些,int, long占几个字节 2. 和 equals有什…

Mint_21.3 drawing-area和goocanvas的FB笔记(二)

一、goocanvas安装 Linux mint 21.3 库中带有 libgoocanvas-2.0-dev, 用sudo apt install libgoocanvas-2.0-dev 安装,安装完成后,检查一个 /usr/lib/x86_64-linux-gnu 下是否有libgoocanvas.so的软件链接。如果没有,或是 .so.x 等类似后面…

QT Widget: 自定义Widget组件及创建和使用动静态库

学习自定义Widget组件,书中的案例: // 自定义QmyBattery组件 // QmyBattery.c #include "qmybattery.h"QmyBattery::QmyBattery(QWidget *parent) : QWidget(parent) {}/** 1.QPainter的viewport()与window()分别代表着物理坐标与逻辑坐标区域…

类加载的基本流程

⭐ 作者:小胡_不糊涂 🌱 作者主页:小胡_不糊涂的个人主页 📀 收录专栏:JavaEE 💖 持续更文,关注博主少走弯路,谢谢大家支持 💖 类加载 1. 加载2. 验证3. 准备4. 解析5. 初…