【Pytorch】nn.RNN、nn.LSTM 和 nn.GRU的输入和输出形状

nn.RNN、nn.LSTM 和 nn.GRU的输入和输出形状

      • 输入形状
        • 通用输入参数
        • 特殊情况(LSTM)
      • 输出形状
        • nn.RNN 和 nn.GRU
        • nn.LSTM
      • 代码示例

输入形状

通用输入参数

这三个模块通常接收以下两种形式的输入:

  • 输入序列:形状为 (seq_len, batch_size, input_size)
    • seq_len:表示序列的长度,即时间步的数量。例如在处理文本时,它可以是句子的单词数量;在处理时间序列数据时,它可以是时间点的数量。
    • batch_size:表示每次输入的样本数量。在训练模型时,通常会将多个样本组成一个批次进行处理,以提高计算效率。
    • input_size:表示每个时间步输入的特征维度。例如,在处理图像序列时,它可以是图像的特征向量维度;在处理文本时,它可以是词向量的维度。
  • 初始隐藏状态:形状为 (num_layers * num_directions, batch_size, hidden_size)
    • num_layers:表示 RNN 层数。如果设置为多层 RNN,信息会在不同层之间依次传递。
    • num_directions:表示 RNN 的方向数,取值为 1(单向 RNN)或 2(双向 RNN)。双向 RNN 会同时考虑序列的正向和反向信息。
    • hidden_size:表示隐藏层的维度,即每个时间步输出的隐藏状态的特征数量。
特殊情况(LSTM)

对于 nn.LSTM,除了初始隐藏状态外,还需要一个初始细胞状态,其形状与初始隐藏状态相同,即 (num_layers * num_directions, batch_size, hidden_size)

输出形状

nn.RNN 和 nn.GRU
  • 输出序列:形状为 (seq_len, batch_size, num_directions * hidden_size)。它包含了每个时间步的隐藏状态输出,其中 num_directions 取决于 RNN 是否为双向。如果是单向 RNN,num_directions 为 1;如果是双向 RNN,num_directions 为 2,输出的特征维度会翻倍。
  • 最终隐藏状态:形状为 (num_layers * num_directions, batch_size, hidden_size)。它表示最后一个时间步的隐藏状态,用于后续的任务,如分类或预测。
nn.LSTM
  • 输出序列:形状同样为 (seq_len, batch_size, num_directions * hidden_size),含义与 nn.RNNnn.GRU 的输出序列类似。
  • 最终隐藏状态和细胞状态:最终隐藏状态和细胞状态的形状均为 (num_layers * num_directions, batch_size, hidden_size)。最终隐藏状态和细胞状态一起保存了 LSTM 在最后一个时间步的信息。

代码示例

import torch
import torch.nn as nn

# 定义参数
input_size = 10
hidden_size = 20
num_layers = 2
batch_size = 3
seq_len = 5
num_directions = 1  # 单向 RNN

# 创建 RNN 模型
rnn = nn.RNN(input_size, hidden_size, num_layers)
# 创建 LSTM 模型
lstm = nn.LSTM(input_size, hidden_size, num_layers)
# 创建 GRU 模型
gru = nn.GRU(input_size, hidden_size, num_layers)

# 生成随机输入序列
input_seq = torch.randn(seq_len, batch_size, input_size)
# 初始化隐藏状态
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)

# 运行 RNN
rnn_output, hn_rnn = rnn(input_seq, h0)
print("RNN 输出序列形状:", rnn_output.shape)
print("RNN 最终隐藏状态形状:", hn_rnn.shape)

# 初始化 LSTM 的细胞状态
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
# 运行 LSTM
lstm_output, (hn_lstm, cn_lstm) = lstm(input_seq, (h0, c0))
print("LSTM 输出序列形状:", lstm_output.shape)
print("LSTM 最终隐藏状态形状:", hn_lstm.shape)
print("LSTM 最终细胞状态形状:", cn_lstm.shape)

# 运行 GRU
gru_output, hn_gru = gru(input_seq, h0)
print("GRU 输出序列形状:", gru_output.shape)
print("GRU 最终隐藏状态形状:", hn_gru.shape)

在上述代码中,我们定义了输入序列和初始隐藏状态,并分别使用 nn.RNNnn.LSTMnn.GRU 对输入序列进行处理,最后打印出它们的输出形状,帮助你更好地理解输入输出形状的特点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/965945.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

docker容器运行时忘了加自动重启命令了,之后如何添加自动重启命令,使其随开机自动重启

要让已有的Docker容器在系统重启后自动启动,可以通过以下步骤设置其重启策略: 步骤 1:查找容器名称或ID docker ps -a找到目标容器的ID或名称。 步骤 2:更新容器的重启策略 使用 docker update 命令直接修改容器的重启策略&am…

第16章 Single Thread Execution设计模式(Java高并发编程详解:多线程与系统设计)

简单来说, Single Thread Execution就是采用排他式的操作保证在同一时刻只能有一个线程访问共享资源。 1.机场过安检 1.1非线程安全 先模拟一个非线程安全的安检口类,旅客(线程)分别手持登机牌和身份证接受工作人员的检查,示例代码如所示。…

深度学习:解码智能的“数字炼金术”

深度学习:解码智能的“数字炼金术” 1943年,当神经科学家沃伦麦卡洛克和数学家沃尔特皮茨在论文中首次提出人工神经元模型时,他们或许没有想到,这个简单的数学公式会在80年后掀起改变人类文明的技术革命。深度学习作为这场革命的…

让文物“活”起来,以3D数字化技术传承文物历史文化!

文物,作为不可再生的宝贵资源,其任何毁损都是无法逆转的损失。然而,当前文物保护与修复领域仍大量依赖传统技术,同时,文物管理机构和专业团队的力量相对薄弱,亟需引入数字化管理手段以应对挑战。 积木易搭…

pytest-xdist 进行多进程并发测试

在自动化测试中,运行时间过长往往是令人头疼的问题。你是否遇到过执行 Pytest 测试用例时,整个测试流程缓慢得让人抓狂?别担心,pytest-xdist 正是解决这一问题的利器!它支持多进程并发执行,能够显著加快测试…

广度优先搜索(BFS)算法详解——以走迷宫问题为例

引言:当算法遇见迷宫 想象你置身于一个复杂的迷宫,如何在最短时间内找到出口?这个问题不仅存在于童话故事中,更是计算机科学中经典的路径搜索问题。本文将带你通过走迷宫问题,深入理解广度优先搜索(BFS&am…

kubeadm构建k8s源码阅读环境

目标 前面看了minikube的源码了解到其本质是调用了kubeadm来启动k8s集群,并没有达到最初看代码的目的。 所以继续看看kubeadm的代码,看看能否用来方便地构建源码调试环境。 k8s源码编译 kubeadm源码在k8s源码库中,所以要先克隆k8s源码。之…

BFS算法篇——广度优先搜索,探索未知的旅程(上)

文章目录 前言一、BFS的思路二、BFS的C语言实现1. 图的表示2. BFS的实现 三、代码解析四、输出结果五、总结 前言 广度优先搜索(BFS)是一种广泛应用于图论中的算法,常用于寻找最短路径、图的遍历等问题。与深度优先搜索(DFS&…

baigeiRSA

baigeiRSA 打开附件有两个: 1.import libnumfrom Crypto.Util import numberfrom secret import flag​size 128e 65537p number.getPrime(size)q number.getPrime(size)n p*q​m libnum.s2n(flag)c pow(m, e, n)​print(n %d % n)print(c %d % c)​​2.n…

脚本一键生成管理下游k8s集群的kubeconfig

一、场景 1.1 需要管理下游k8s集群的场景。 1.2 不希望使用默认的cluster-admin权限的config. 二、脚本 **重点参数: 2.1 配置变量。 1、有单独namespace的权限和集群只读权限。 2、自签名的CA证书位置要正确。 2.2 如果配置错误,需要重新…

camera光心检测算法

1.概要 光心检测算法,基于opencv c实现,便于模组厂快速集成到软件工具中,适用于camera模组厂算法评估组装制程镜头与sensor的偏心程度,便于工程师了解制程的问题找出改善方向。 2.技术介绍 下图为camera模组厂抓取的bayer-raw经过…

基于logback+fastjson实现日志脱敏

一、需求背景 日常工作中,必不可免的会将一些敏感信息,如用户名、密码、手机号、身份证号、银行账号等等打印出来,但往往为了安全,这些信息都需要进行脱敏。脱敏实际就是用一些特殊字符来替换部分值。 JSON 和 JSONObject Fastj…

RC5分组加密算法

目录 (1)RC5密钥扩展算法 (2)RC5加密算法 (3)RC5解密算法 RC5分组加密算法 RC5分组密码算法是1994年RSA实验室的RonaldL.Rivest教授发明的。它是参数可变的分组密码算法,三个可变的参数是&a…

GPU — 8 卡 GPU 服务器与 NVLink/NVSwitch 互联技术

目录 文章目录 目录8 卡 GPU 服务器GPU 互联技术分类PCIe 直连PCIe Switch 互联NVLink 互联NVLink 1.0 与 DGX-1 系统NVLink 2.0 与 DGX-1 系统NVSwitch 全互联NVSwitch 1.0 与 DGX-2 系统NVLink 3.0、NVSwitch 2.0 与 DGX A100NVLink 4.0、NVSwitch 3.0 与 DGX H100NVSwitch v…

idea——IDEA2024版本创建Sping项目无法选择Java 8

目录 一、背景二、解决方式(替换创建项目的源地址) 一、背景 IDEA2024创建一个springboot的项目,本地安装的是1.8,但是在使用Spring Initializr创建项目时,发现版本只有17、21、23。 二、解决方式(替换创…

STM32 串口发送与接收

接线图 代码配置 根据上一章发送的代码配置,在GPIO配置的基础上需要再配置PA10引脚做RX接收,引脚模式可以选择浮空输入或者上拉输入,在USART配置串口模式里加上RX模式。 配置中断 //配置中断 USART_ITConfig(USART1, USART_IT_RXNE, ENABLE…

储能系统-系统架构

已更新系列文章包括104、61850、modbus 、单片机等,欢迎关注 IEC61850实现方案和测试-1-CSDN博客 快速了解104协议-CSDN博客 104调试工具2_104协议调试工具-CSDN博客 1 电池储能系统(BESS) 架构 电池储能系统主要包括、电池、pcs、本地控制…

TOTP实现Google Authenticator认证工具获取6位验证码

登录遇到Google认证怎么办? TOTP是什么?(Google Authenticator) TOTP(Time-based One-Time Password)是一种基于时间的一次性密码算法,主要用于双因素身份验证。其核心原理是通过共享密钥和时间同步生成动态密码,具体步骤如下: 共享密钥:服务端与客户端预先共享一个…

清理服务器/docker容器

清理服务器 服务器或docker容器清理空间。 清理conda环境 删除不用的conda虚拟环境: conda env remove --name python38 conda env remove --name python310清理临时目录:/tmp du -sh /tmp # 查看/tmp目录的大小/tmp 目录下的文件通常是可以直接删除…

Naive UI去掉n-select下拉框边框,去掉n-input输入框边框

<template><div><div style"margin-top:10px;width: 100%;"><dade-descriptions><tr><dade-descriptions-item label"代理名称"><dade-input placeholder"代理名称"></dade-input></dade-de…