LSTM和GRU的介绍以及Pytorch源码解析

介绍一下LSTM模型的结构以及源码,用作自己复习的材料。 

LSTM模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

上次上一篇文章介绍了RNN序列模型,但是RNN模型存在比较严重的梯度爆炸和梯度消失问题。

本文介绍的LSTM模型解决的RNN的大部分缺陷。

首先展示LSTM的模型框架:

下面是LSTM模型的数学推导公式:

\begin{array}{ll} \\ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ \end{array}

h_t表示t时刻的隐藏状态,c_t表示t时刻的记忆细胞状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。

i_t,f_t,g_t,o_t 分别是输入门、遗忘门、单元门和输出门。

这张图片比较好的介绍了各个门之间的交互关系以及输入输出,大家可以放大看一下。

接下来展示GRU的框架模型:

下面是GRU的数学推导公式:

r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}

h_t表示t时刻的隐藏状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。r_t,n_t,z_t分别表示重置门更新门和新建门

上面的图片可以更直观的看到GRU中是如何迭代的。

接下来我们看一下源码中LSTM和GRU类的初始化(只介绍几个重要的参数):

torch.nn.LSTM(self, input_size, hidden_size, num_layers=1,
              bias=True, batch_first=False, dropout=0.0, 
              bidirectional=False, proj_size=0, device=None,
              dtype=None)
torch.nn.GRU(self, input_size, hidden_size, num_layers=1,
             bias=True, batch_first=False, dropout=0.0, 
             bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向LSTM(GRU),默认关闭。

LSTM与GRU都是特殊的RNN,因此输入输出可以参考的上一篇介绍RNN的文章,在这里直接进行代码举例。

lstm1 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
lstm2 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)

gru1 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=True)
gru2 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=False)

tensor1 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)
tensor2 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)

out_lstm1,(hn, cn) = lstm1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_lstm2,(hn, cn) = lstm2(tensor2)  # (batch_size * seq_len * (hidden_size * bidirectional))

out_gru1,h_n = gru1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_gru2,h_n = gru2(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))

print(out_lstm1.shape)  # torch.Size([5, 10, 80])
print(out_lstm2.shape)  # torch.Size([5, 10, 40])

print(out_gru1.shape)  # torch.Size([5, 10, 50])
print(out_gru2.shape)  # torch.Size([5, 10, 25])

维度已经在注释中给大家标注上了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/245212.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【TwinCAT学习笔记 1】TwinCAT开发环境搭建

写在前面 作为技术开发人员,开启任何一项开发工作之前,首先都要搭建好开发环境,所谓磨刀不误砍材工,一定要有耐心,一次不行卸载再装。我曾遇到过一个学生,仅搭建环境就用了两周,这个过程也是一…

数据寻址-偏移寻址(硬核)

目录 一. 基址寻址二. 变址寻址三. 相对寻址四. 硬件如何实现数的"比较" \quad \quad \quad \quad \quad \quad \quad 一. 基址寻址 \quad A就是偏移量 有的用通用寄存器来代替BR专用寄存器的功能 其中 R 0 R_0 R0​的位数是由通用寄存器的总数来判断的, 比如通用寄存…

社交网络分析1:起源发展、不同领域的应用、核心概念

社交网络分析1:社交网络相关定义和概念 写在最前面关于课程 社交网络、社交网络分析社交网络发展阶段(自己感兴趣)1. 社交网络的起源2. 社交网络的演变3. 社交网络的成熟4. 发展阶段补充和展望 2023社交大变革(自己感兴趣的点&…

安装spaCy及语言包下载安装

文章目录 1. spaCy的安装1.1 安装spaCy包方式1 : 通过pip / conda命令安装方式2 : 通过离线导入 1.2 安装语言模型方式1 : 通过pip / conda命令安装方式2 : 通过离线导入 2. 常见问题a. 版本问题 3. 参考文档 关注公众号:『AI学习星球』 回复:遥感图像语…

【Spring】04 国际化

文章目录 1. 定义2. Spring 的支持1) MessageSource接口2) ResourceBundleMessageSource 3. 配置国际化1)配置MessageSource Bean2)创建资源文件3)在Bean中使用国际化消息 4. 使用占位符和参数结语 Spring 为我们提供了…

橘子学K8S01之容器中所谓的隔离

我们一直都在说容器就是一个沙盒,沙盒技术顾名思义就是像一个集装箱一样,把应用(服务,进程之类的)装起来的技术,这样每个进程在自己的沙盒中和其他的沙盒隔离开来,每个沙盒之间存在一个边界使得他们互不干扰&#xff0…

【动手学深度学习】(十三)深度学习硬件

文章目录 一、CPU和GPU二、更多的芯片1.DSP:数字信号处理2.可编程阵列(FPGA)3.AI ASIC 三、单机多卡并行 一、CPU和GPU 提升CPU利用率 在计算ab之前,需要准备数据 主内存->L3->L2->L1->寄存器(数据只有进入寄存器才可以参与运算) 提升空间和时间的内存…

Linux基本开发工具

编译器和自动化构建工具 一、编译器——gcc、g1. 安装 gcc/g2. 使用3. 链接库4. 拓展命令:od/file/ldd/readelf 二、自动化构建项目——make、makefile1. 介绍2. 使用例子touch——change file timestampsstat——display file or file system status修改时间 .PHON…

Java入门学习笔记二

一、抽象类 当编写一个类时,我们往往会为该类定义一些方法,这些方法是用来描述该类的行为方式,那么这些方法都有具体的方法体。 分析事物时,发现了共性内容,就出现向上抽取。会有这样一种特殊情况,就是功…

国家开放大学 湖南开放大学形成性考核 平时作业 统一资料

试卷代号:1258 房屋建筑混凝土结构设计 参考试题 一、单项选择题(每小题2分,共计40分) 1.( )是将框架结构中的部分跨间布置剪力墙或把剪力墙结构的部分剪力墙抽掉改为框架承重。 A.梁板结构体系 B.框…

区块链的可拓展性研究【05】闪电网络

1.闪电网络:闪电网络是一种基于比特币区块链的 Layer2 扩容方案,它通过建立一个双向支付通道网络,实现了快速、低成本的小额支付。闪电网络的交易速度非常快,可以达到每秒数万笔交易,而且交易费用非常低,几…

Linux---复制、移动文件及目录命令

1. 复制、移动文件及目录命令的使用 命令说明cp拷贝文件、拷贝目录mv移动文件、移动目录、重命名 cp命令效果图: 说明: 拷贝目录需要加上-r选项,-r表示递归拷贝目录及其内容 mv命令移动文件效果图 mv命令移动目录效果图 mv命令重名名效果图 2. 小结 cp&#…

TwinCAT3 Modbus-TCP Client/Server使用

目录 一、环境配置和准备 1、PLC中安装TF6250-Modbus-TCP库 2、勾选TF6250的license 3、PLC工程中添加Tc2_ModbusSrv库文件 4、分别创建测试ModbusTCP测试的Server和Client程序 二、PLC作为Client端 1、设置测试电脑IP地址 2、运行MobusTCP测试工具 3、PLC端程序编写 …

Python 使用 openpyxl 写表格

当前环境:Win10 x64 MS office 2016 Python3.7 openpyxl3.0.9 1 写入表格 from openpyxl import Workbook# 创建一个 workbook workbook_w Workbook()# 获取被激活的 worksheet worksheet_w workbook_w.active# 1 批量插入数据# 设置一行数据 worksheet_w.ap…

鸿蒙ArkTS基础类库——线性容器

容器类库概述 容器类库,用于存储各种数据类型的元素,并具备一系列处理数据元素的方法,作为纯数据结构容器来使用具有一定的优势。容器类采用了类似静态语言的方式来实现,并通过对存储位置以及属性的限制,让每种类型的…

自己动手写数据库: select 查询语句对应查询树的构造和执行

首先我们需要给原来代码打个补丁,在SelectScan 结构体初始化时需要传入 UpdateScan 接口对象,但很多时候我们需要传入的是 Scan 对象,因此我们需要做一个转换,也就是当初始化 SelectScan 时,如果传入的是 Scan 对象&am…

vs2019 配置liblas

vs2019 配置liblas liblas是一个开源的C库,用于读写LAS格式的点云数据。libLAS易于使用,广泛应用于各种点云处理任务。它支持基本的点云数据操作,如读取、写入和编辑LAS文件中的点。 1.源码下载 https://github.com/libLAS/libLAS 2.编译 首…

Docker多平台安装与配置指南

Docker的流行使得它成为开发者和运维人员不可或缺的工具。在本文中,将深入探讨如何在不同平台上安装和配置Docker,旨在为大家提供详尽的指南,确保他们能够顺利地使用这一强大的容器化工具。 Docker基础概念回顾 Docker利用容器技术&#xf…

设计模式-访问者模式

访问者模式是设计模式中行为型模式的一种(其他的还有如创建型、结构型),听说是设计模式中比较难理解的一种,最近项目中用到了该模式,所以今天总结和实践一下。 一、访问者模式要解决的问题: 稳定的数据结…

【系统架构】集群、分布式概念及系统架构演进过程

集群、分布式概念: 对食物没有太高要求的人在肚子饿的时候一般都会选择去兰州拉面、沙县小吃等小饭馆,这类小饭馆有个很显著的特点:洗菜、切菜、炒菜都是同一个人完成,如果厨子不舒服可能饭馆还会歇业。而一些人流量较大的饭馆的分…