详解pytorch中循环神经网络(RNN、LSTM、GRU)的维度

详解pytorch中循环神经网络(RNN、LSTM、GRU)的维度

  • RNN
    • torch.nn.rnn详解
    • RNN输入输出维度
  • LSTM
    • torch.nn.LSTM详解
    • LSTM输入输出维度
  • GRU
    • torch.nn.GRU详解
    • GRU输入输出维度
  • 三种RNN的示例

首先如果你对RNN、LSTM、GRU不太熟悉,可点击查看。

RNN

torch.nn.rnn详解

torch.nn.RNN(input_size,
hidden_size,
num_layers=1,
nonlinearity=‘tanh’,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理
在这里插入图片描述

参数详解

  • input_size – 输入x中预期特征的数量

  • hidden_size – 隐藏状态h中的特征数量

  • num_layers – 循环层数。例如,设置num_layers=2 意味着将两个LSTM堆叠在一起形成堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1

  • nonlinearity– 使用的非线性。可以是’tanh’或’relu’。默认:‘tanh’

  • bias– 如果False,则该层不使用偏差权重b_ih和b_hh。默认:True

  • batch_first – 如果,则输入和输出张量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。请注意,这不适用于隐藏状态或单元状态。默认:False

  • dropout – 如果非零,则在除最后一层之外的每个LSTM层的输出上 引入Dropout层,dropout 概率等于 。默认值:0.0

  • bidirectional – 如果True, 则成为双向LSTM。默认:False

RNN输入输出维度

rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

可以看到输入是xh_0,h_0可以是None。如果batch_size是第0维度,需设置batch_first=True
输出则是outputh_n。h_n存了每一层的t时刻的隐藏状态值

# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, h_0=None):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if h_0 is None:
        h_0 = torch.zeros(num_layers, batch_size, hidden_size)
	...
    return output, h_n

输入:
x的输入维度:(batch_size, sequence_length, input_size) [前提:batch_first=True]
h_0的维度:(D∗num_layers, hidden_size) [可以为None]

输出: output的输出维度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]
h_n的维度:(D∗num_layers, hidden_size)

LSTM

torch.nn.LSTM详解

torch.nn.LSTM(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
proj_size=0,
device=None,
dtype=None)

原理:

参数详解:
相比于RNN多了proj_size参数,少了nonlinearity参数

  • input_size – 输入x中预期特征的数量

  • hidden_size – 隐藏状态h中的特征数量

  • num_layers – 循环层数。例如,设置num_layers=2 意味着将两个LSTM堆叠在一起形成堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1

  • bias– 如果False,则该层不使用偏差权重b_ih和b_hh。默认:True

  • batch_first – 如果,则输入和输出张量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。请注意,这不适用于隐藏状态或单元状态。默认:False

  • dropout – 如果非零,则在除最后一层之外的每个LSTM层的输出上 引入Dropout层,dropout 概率等于 。默认值:0dropout

  • bidirectional – 如果True, 则成为双向LSTM。默认:False

  • proj_size – 如果,将使用具有相应大小投影的LSTM 。默认值:0

LSTM输入输出维度

LSTM= nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = LSTM(input, (h0, c0))

输入是x,此外h_0c_0可以是None。如果batch_size是第0维度,需设置batch_first=True
输出则是output和一个元组(h_n, c_n)

输入: x的输入维度:(batch_size, sequence_length, input_size)`
[前提:batch_first=True]

输出: output的输出维度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]

具体可参考官方文档:nn.LSTM
在这里插入图片描述

GRU

torch.nn.GRU详解

torch.nn.GRU(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理:
在这里插入图片描述

参数详解:
与上文LSTM相比,缺少了proj_size参数,与RNN相比也缺少了nonlinearity参数

GRU输入输出维度

gru= nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = gru(input, h0)

与RNN一致见上文,相比LSTM少了c_n

三种RNN的示例

import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2, batch_first=True) # (input_size, hidden_size, num_layer)
lstm = nn.LSTM(10, 20, 2, batch_first=True)
gru = nn.GRU(10, 20, 2, batch_first=True)

input = torch.randn(5, 3, 10)  # (batchsize, seq, input_size)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)

output_rnn, h_n = rnn(input)
output_lstm, (hn, cn) = lstm(input)
output_gru, h_n2 = gru(input)
print("输入维度:", input.shape)
print(f"RNN 输出维度:{output_rnn.shape}, h_n维度:{h_n.shape}" )
print("LSTM 输出维度:", output_lstm.shape)
print("GRU 输出维度:", output_gru.shape)


"""
输入维度: torch.Size([5, 3, 10])
RNN 输出维度:torch.Size([5, 3, 20]), h_n维度:torch.Size([2, 5, 20])
LSTM 输出维度: torch.Size([5, 3, 20])
GRU 输出维度: torch.Size([5, 3, 20])
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/627433.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2万字实操入门案例之在Springboot框架下用Mybatis简化JDBC开发实现基础的操作MySQL之预编译SQL主键返回增删改查

环境准备 准备数据库表 use mybatis;-- 部门管理 create table dept(id int unsigned primary key auto_increment comment 主键ID,name varchar(10) not null unique comment 部门名称,create_time datetime not null comment 创建时间,update_time datetime not null comme…

做一个犬榜小程序,警示社会。

2024-5-15新闻:流浪狗咬死3岁男童。 相关链接:3岁男童被恶犬咬伤离世 母亲发声 急寻狗主讨公道_中华网 恶狗的主人终于付出代价 链接:这回,恶狗的主人终于付出代价 - 知乎 7岁女童家门口玩耍被恶犬咬伤头面部,多处撕…

【SQL每日一练】获取PADS公司用户名称和各职业总数并根据格式输出

文章目录 题目一、解析二、题解1.MySQL 题目 生成以下两个结果集: 1、查询 OCCUPATIONS 表中所有名字,紧跟每个职业的第一个字母作为括号(即:括在括号中),并按名字顺序排序。例如:AnActorName…

OpenCV-android-sdk配置及使用(NDK)

opencv官网下载Android版Releases - OpenCV 下载好OpenCV-android-sdk并解压好,然后新建一个jni文件夹测试,测试项目目录结构如下: ├── jni │ ├── Android.mk │ ├── Application.mk │ └── test.cpp Application.mk: APP_STL := c++_static APP_CPP…

Cow Exhibition G的来龙去脉

[USACO03FALL] Cow Exhibition G - 洛谷 曲折经过 爆搜 一开始没什么好的想法&#xff0c;就针对每头奶牛去or不去进行了爆搜。 #include <cstdio> #include <algorithm> using namespace std;#define maxn 405 int iq[maxn], eq[maxn]; int ans; int n;void d…

Stm32串口搭配DMA实现自定义printf、scanf

前言:本文仅供学习参考使用&#xff0c;主要目的是让大家快速使用串口调试&#xff0c;文章所提及的GCC适用于Clion&#xff0c;Vscode等第三方编辑器的用户。作者有时间会继续更新^_^ 一、GCC环境 1、标准库 (1)、使用方法 在主函数while(1)初始化中&#xff0c;添加Seria…

【复试分数线】综合性985历年分数线汇总(第四弹)

国家线和34所自划线 可以看作是考研上岸最最最基础的门槛。真正决定你能不能进入复试的还要看院线&#xff08;复试分数线&#xff09;&#xff01;今天我将分析考信号的除C9、工科类985的其他7所985近三年复试分数线&#xff08;不包括2024&#xff09;&#xff0c;大家可以参…

Web前端开发 - 4 - CSS3动画

CSS3动画 一、 设计2D变换二、 设计3D变换三、 设计过渡动画四、设计帧动画 一、 设计2D变换 transform : none | <transform-function> /* <transform-function> 设置变换函数&#xff0c;可以是一个或多个变换函数列表。函数包括: martrix(x缩放,x倾斜,y倾斜,y…

玩转Matlab-Simscape(初级)-01-从一个简单模型开始学习之旅

** 玩转Matlab-Simscape&#xff08;初级&#xff09;- 01 - 从一个简单模型开始学习之旅 ** 目录 玩转Matlab-Simscape&#xff08;初级&#xff09;- 01 - 从一个简单模型开始学习之旅 前言一、从模板开始建模二、建模一个简单的连杆2.1 建模2.2 生成子系统 总结 前言 在产…

如何在 Windows 11/10 中恢复已删除的分区

在将重要数据存储在计算机上之前&#xff0c;许多用户会创建分区以更好地组织和管理他们的文件。此分区可以在内部硬盘驱动器或外部存储设备上创建。但是&#xff0c;有时可能会意外删除分区。如果发生这种情况&#xff0c;您可能想知道是否可以在不丢失任何信息的情况下恢复已…

适用于 Windows 8/10/11 的 10 大 PC 迁移工具:电脑克隆迁移软件

当您发现自己拥有一台新的 PC 或笔记本电脑时&#xff0c;PC 迁移变得至关重要。将数据从旧计算机传输到新计算机的过程似乎令人生畏&#xff0c;尤其是如果您是第一次这样做。迁移过程中数据丢失的潜在风险加剧了焦虑。为确保文件和系统设置的无缝无忧传输&#xff0c;使用专为…

初识C语言——第二十天

do while ()循环 do 循环语句; while(表达式); 句式结构&#xff1a; 执行过程&#xff1a; do while循环的特点&#xff1a; 代码练习&#xff1a; 二分法算法&#xff1a; int main() {int arr[] { 0,1,2,3,4,5,6,7,8,9};int k 7;//查找数字7&#xff0c;在arr这个数组…

prompt工程策略(三:使用 LLM 防护围栏创建系统提示)

原文&#xff1a;我是如何赢得GPT-4提示工程大赛冠军的 原文的原文&#xff1a; How I Won Singapore’s GPT-4 Prompt Engineering Competition &#xff01;&#xff01;本内容仅适用于具有 System Prompt&#xff08;系统提示&#xff09;功能的 LLM。具有这一功能的最著名 …

Python sort() 和 sorted() 的区别应用实例详解

大家好&#xff0c;今天针对 Python 中 sort() 和 sorted() 之间的区别&#xff0c;来一个实例详细解读。sort — 顾名思义就是排序的意思&#xff0c;它可以接收的对象为可迭代的数据类型。今天以列表为例子演示两者的不同点、相同点&#xff0c;以及其中一些常用的高级参数使…

【3dmax笔记】022:文件合并、导入、导出

文章目录 一、合并二、导入三、导出四、注意事项一、合并 只能合并 max 文件(高版本能够合并低版本模型,低版本不能合并高版本的模型)。点击【文件】→【导入】→【合并】: 选择要合并的文件,后缀名为3dmax默认的格式,max文件。 二、导入 点击【文件】→【导入】→【导…

NSSCTF中的1zjs、作业管理系统、finalrce、websign、简单包含、Http pro max plus

目录 [LitCTF 2023]1zjs [LitCTF 2023]作业管理系统 [SWPUCTF 2021 新生赛]finalrce exec()函数&#xff1a;php中exec介绍及使用_php exec-CSDN博客​​​​​​ 资料参考&#xff1a;RCE(远程命令执行)绕过总结_rce绕过-CSDN博客 [UUCTF 2022 新生赛]websign [鹏城杯 …

【校园论坛系统】分站式后台,多城市圈子论坛,校园圈子交流平台,二手发布市场,校园圈子论坛系统

简述 校园论坛系统是为学生们提供一个交流、分享信息、互相帮助的平台。它通常包括了各种分类的版块&#xff0c;例如学习交流、社团活动、二手交易、失物招领等等。用户可以在论坛上发帖&#xff0c;回复他人的帖子&#xff0c;也可以私信其他用户。此外&#xff0c;管理员还…

只用了三天就入门了Vue3?

"真的我学Vue3&#xff0c;只是为了完成JAVA课设" 环境配置 使用Vue3要去先下载Node.js。 就像用Python离不开pip包管理器一样。 Node.js — Run JavaScript Everywhere (nodejs.org) 下完Node.js去学习怎么使用npm包管理器&#xff0c;放心你只需要学一些基础的…

【数据结构】数据结构大汇总 {数据结构的分类总结:定义和特性、实现方式、操作与复杂度、适用场景、相关算法、应用实例}

一、线性结构 1.1 顺序表 定义和特性&#xff1a;顺序表是一种线性表的存储结构&#xff0c;它采用一段地址连续的存储单元依次存储线性表中的元素。顺序表具有随机访问的特性&#xff0c;即可以通过元素的下标直接访问元素。 实现方式&#xff1a;顺序表可以通过数组来实现&…

React Native 之 原生组件和核心组件(二)

原生组件 在 Android 开发中是使用 Kotlin 或 Java 来编写视图&#xff1b;在 iOS 开发中是使用 Swift 或 Objective-C 来编写视图。在 React Native 中&#xff0c;则使用 React 组件通过 JavaScript 来调用这些视图。在运行时&#xff0c;React Native 为这些组件创建相应的 …