机器学习深度学习——常见循环神经网络结构(RNN、LSTM、GRU)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——RNN的从零开始实现与简洁实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

常见循环神经网络结构(RNN、LSTM、GRU)

  • 引言
  • RNN
  • LSTM
    • 门控记忆元
      • 输入门、输出门和遗忘门
      • 候选记忆元
      • 记忆元
      • 隐状态
    • LSTM的简洁实现
  • GRU
    • 结构详解
    • GRU的简洁实现
  • 常用应用方式

引言

之前已经实现讲解并实现过了RNN模型,而LSTM可以弥补RNN的一些缺点,GRU是LSTM的简化版本,这里我们就回顾一下RNN模型,接着循序渐进讲解LSTM和GRU。
CNN和全连接网络的数据表示能力已经很强了,但是我们为啥还需要循环神经网络呢?这是因为现实的问题更复杂,很多数据的输入顺序对于结果都是有很大影响的。如文本数据(尤其是字母和文字的组合),先后顺序具有非常重要的意义,如果打乱,就会无法正确表示原始信息。而相比其他网络,循环神经网络因为具有记忆能力,所以更有效。

RNN

RNN循环神经网络使用torch.nn.RNN()来构建,如下图所示:
在这里插入图片描述
针对t时刻的隐状态,可以由下面公式计算:
h t = φ ( W i h x t + b i h + W h h h t − 1 + b h h ) = φ ( W i h x t + W h h h t − 1 + b h ) 其中: h t 是 t 时刻的隐藏状态; h t − 1 是 t − 1 时刻的隐藏状态 W i h 是输入到隐藏层的权重; W h h 是隐藏层到隐藏层的权重; b i h 是输入到隐藏层的偏置; b h h 是隐藏层到隐藏层的偏置; h_t=φ(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})\\ =φ(W_{ih}x_t+W_{hh}h_{t-1}+b_{h})\\ 其中:h_t是t时刻的隐藏状态;h_{t-1}是t-1时刻的隐藏状态\\ W_{ih}是输入到隐藏层的权重;W_{hh}是隐藏层到隐藏层的权重;\\ b_{ih}是输入到隐藏层的偏置;b_{hh}是隐藏层到隐藏层的偏置; ht=φ(Wihxt+bih+Whhht1+bhh)=φ(Wihxt+Whhht1+bh)其中:htt时刻的隐藏状态;ht1t1时刻的隐藏状态Wih是输入到隐藏层的权重;Whh是隐藏层到隐藏层的权重;bih是输入到隐藏层的偏置;bhh是隐藏层到隐藏层的偏置;
激活函数可以使用ReLU或tanh。
虽然在对序列数据进行建模时,RNN有一定记忆能力,但单纯的RNN会随着递归次数的增加,出现权重指数级爆炸或消失的问题,从而难以捕捉长时间关联,并导师训练时收敛困难。

LSTM

LSTM称为长短期记忆网络,是一种特殊的RNN,主要用于解决长序列训练过程中的梯度消失和爆炸问题,能在长序列中获得更好的分析效果。

门控记忆元

记忆元的目的是为了记录附加的信息,要控制记忆元,我们需要下面的几个门:
1、输出门:用来从单元中输出条目
2、输入门:决定何时将数据读入单元
3、遗忘门:重置单元的内容
接下来来看看如何工作的:

输入门、输出门和遗忘门

当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图:
在这里插入图片描述
上图的σ是代表由sigmoid激活函数的全连接层处理,因此三个门的值都在(0,1)范围内,显然计算方法如下:
I t = σ ( X t W x i + H t − 1 W h i + b i ) O t = σ ( X t W x o + H t − 1 W h o + b o ) F t = σ ( X t W x f + H t − 1 W h f + b f ) I_t=\sigma(X_tW_{xi}+H_{t-1}W_{hi}+b_i)\\ O_t=\sigma(X_tW_{xo}+H_{t-1}W_{ho}+b_o)\\ F_t=\sigma(X_tW_{xf}+H_{t-1}W_{hf}+b_f) It=σ(XtWxi+Ht1Whi+bi)Ot=σ(XtWxo+Ht1Who+bo)Ft=σ(XtWxf+Ht1Whf+bf)

候选记忆元

其计算与上面类似,但是使用tanh来作为激活函数,函数范围为(-1,1),计算方式为:
G t = t a n h ( X t W x g + H t − 1 W h g + b g ) G_t=tanh(X_tW_{xg}+H_{t-1}W_{hg}+b_g) Gt=tanh(XtWxg+Ht1Whg+bg)
如图所示:
在这里插入图片描述

记忆元

在LSTM中,有两个门用于实现一种输入和遗忘的机制:输入门控制采用多少来自候选记忆元的新数据,而遗忘门控制保留多少过去的记忆元的内容。使用按元素乘法,得出:
C t = F t ⨀ C t − 1 + I t ⨀ G t C_t=F_t \bigodot C_{t-1}+I_t \bigodot G_t Ct=FtCt1+ItGt
若遗忘门始终为1且输入门始终为0,则过去的记忆元 将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题, 并更好地捕获序列中的长距离依赖关系。
如下图所示:
在这里插入图片描述

隐状态

最后是计算隐状态,这里就是输出门的作用了。LSTM中,它是记忆元的tanh的门控版本,确保了隐状态的值在(-1,1)之间:
H t = O t ⨀ t a n h ( C t ) H_t=O_t \bigodot tanh(C_t) Ht=Ottanh(Ct)
只要输出门接近1,就能有效将所有记忆换递给预测部分,对于输出门接近0,我们只保留记忆元内的所有信息,而不需要更新隐状态。
那么整体的LSTM图示如下所示:
在这里插入图片描述

LSTM的简洁实现

使用高级API,我们可以直接实例化LSTM模型。这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节:

from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

运行结果:

perplexity 1.1, 48684.5 tokens/sec on cpu
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

运行图片:
在这里插入图片描述

GRU

结构详解

LSTM对很多需要“长期记忆”的任务来说效果显著。但是门控状态太多,导致需要训练更多的参数,使得训练难度加大。因此提出循环门控单元GRU,GRU通过将遗忘门和输入门组合在一起,减少了门的数量,并做了其他改变,在保证记忆能力同时,提升网络训练效率。其组成如下所示:
在这里插入图片描述
而每个GRU单元针对输入进行下面函数的计算:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) 候选隐状态 H t ′ = t a n h ( X t W x h + ( R t ⨀ H t − 1 ) W h h + b h ) 其中 R t ⨀ H t − 1 可以减少以往遗忘状态的影响: 每当 R t 接近 1 时,我们恢复一个传统 R N N 网络; R t 接近 0 时,候选隐状态是以 X t 作为输入的多层感知机的结果 H t = Z t ⨀ H t − 1 + ( 1 − Z t ) ⨀ H t ′ Z t 接近 1 时,模型倾向于保留旧状态; Z t 接近 0 时,倾向于候选隐状态 R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\\ 候选隐状态H_t^{'}=tanh(X_tW_{xh}+(R_t \bigodot H_{t-1})W_{hh}+b_h)\\ 其中R_t \bigodot H_{t-1}可以减少以往遗忘状态的影响:\\ 每当R_t接近1时,我们恢复一个传统RNN网络;\\ R_t接近0时,候选隐状态是以X_t作为输入的多层感知机的结果\\ H_t=Z_t \bigodot H_{t-1}+(1-Z_t) \bigodot H_t^{'}\\ Z_t接近1时,模型倾向于保留旧状态;Z_t接近0时,倾向于候选隐状态 Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)候选隐状态Ht=tanh(XtWxh+(RtHt1)Whh+bh)其中RtHt1可以减少以往遗忘状态的影响:每当Rt接近1时,我们恢复一个传统RNN网络;Rt接近0时,候选隐状态是以Xt作为输入的多层感知机的结果Ht=ZtHt1+(1Zt)HtZt接近1时,模型倾向于保留旧状态;Zt接近0时,倾向于候选隐状态
总之,GRU有以下显著特征:
1、重置门有助于捕获序列中的短期依赖关系
2、更新门有助于捕获序列中的长期依赖关系

GRU的简洁实现

from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

运行结果:

perplexity 1.0, 12581.5 tokens/sec on cpu
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

运行图片:
在这里插入图片描述

常用应用方式

循环神经网络中的不同的输入输出对应情况都有不同的应用方式。其中,一对多的网络结构可以用于图像描述(根据输入的一张图像,自动使用文字描述图像内容);多对一的网络结构可用于文本分类;多对多的网络结构可用于语言翻译。
比如,我们可以用RNN来做手写体分类,可以用LSTM来做中文新闻分类,可以用GRU来进行情感分类等等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/68198.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

4用opencv玩转图像2

opencv绘制文字和几何图形 黑色底图 显示是一张黑色图片 使用opencv画圆形 #画一个圆 cv2.circle(imgblack_img,center(400,400),radius100,color(0,0,255),thickness10) 画实心圆 只需要把thickness-1。 cv2.circle(imgblack_img,center(500,600),radius50,color(0,0,255),t…

JS逆向系列之猿人学爬虫第11题 - app抓取 - so文件协议破解

题目地址 http://match.yuanrenxue.com/match/11这是个app题目,先下载下来安装到测试手机上 安装完成后的app界面长这样 打开之后是这样的: 要求已经简单明了了。 二话不说先反编译app 不出意外的是没出意外,源代码里面没啥混淆,所有东西都展示的明明白白的。 "…

Apache Maven:从构建到部署,一站式解决方案

目录 一、Maven介绍 1. Maven是什么? 2.Maven的作用? 二、Maven仓库介绍 2.1 库的分类 三、Maven安装与配置 3.1 Maven安装 3.2 Maven环境配置 3.3 仓库配置 四、Eclipse与Maven配置 五、Maven项目测试 5.1 新建Maven项目步骤及注意事项 5.…

【LeetCode】236. 二叉树的最近公共祖先、 JZ36 二叉搜索树与双向链表

作者:小卢 专栏:《Leetcode》 喜欢的话:世间因为少年的挺身而出,而更加瑰丽。 ——《人民日报》 236. 二叉树的最近公共祖先 236. 二叉树的最近公共祖先 题目描述: 给定一个二叉树…

为MySQL新增一张performance_schema表 | StoneDB 技术分享会 #4

StoneDB开源地址 https://github.com/stoneatom/stonedb 设计:小艾 审核:丁奇、李浩 编辑:宇亭 作者:王若添 中国科学技术大学-软件工程-在读硕士、StoneDB 内核研发实习生 performance_schema 简介 MySQL 启动后会自动创建四…

中睿天下入选河南省网信系统2023年度网络安全技术支撑单位

近日,河南省委网信办发布了“河南省网信系统2023年度网络安全技术支撑单位名单”,中睿天下凭借出色的网络安全技术能力和优势成功入选。 本次遴选由河南省委网信办会同国家计算机网络与信息安全管理中心河南分中心(以下简称安全中心河南分中心…

高斯模糊与图像处理(Gaussian Blur)

高斯模糊在图像处理中的用途及其广泛,除了常规的模糊效果外,还可用于图像金字塔分解、反走样、高低频分解、噪声压制、发光效果等等等等。正因为高斯模糊太基础,应用太广泛,所以需要尽可能深入认识这个能力,避免在实际…

AttentionFreeTransformer 源码解析(一):AFTFull、AFTSimple、AFTLocal

我觉得源码写的很好懂,我就不加注释了,直接上计算流程图。 AFTFull class AFTFull(nn.Module):def __init__(self, max_seqlen, dim, hidden_dim64):super().__init__()max_seqlen: the maximum number of timesteps (sequence length) to be fed indim…

DP(区间DP)

石子合并 设有 N 堆石子排成一排,其编号为 1,2,3,…,N。 每堆石子有一定的质量,可以用一个整数来描述,现在要将这 N 堆石子合并成为一堆。 每次只能合并相邻的两堆,合并的代价为这两堆石子的质量之和,合并后与这两堆…

cesium学习记录06-视图、场景与相机

一、视图(Viewer) viewer是cesium的核心类,是一切的开端。通过new Cesium.Viewer(container, options)来创建一个Viewer对象,而通过这个 Viewer对象,可以添加图层、实体、相机控制等,以及设置一些全局属性…

esp8266使用arduinoJson与tft_espi库发生冲突解决方法

esp8266使用arduinoJson与tft_espi库发生冲突解决方法 arduinoJson与tft_espi库发生冲突解决方法下载arduinoJson5.0版本的,不要用最新版本 示范代码: // Copyright Benoit Blanchon 2014 // MIT License // // Arduino JSON library // https://git…

Unity游戏源码分享-仿帝国时代游戏Demo-uRTS Toolkit

Unity游戏源码分享-仿帝国时代游戏Demo-uRTS Toolkit 游戏的架构值得参考 项目地址:https://download.csdn.net/download/Highning0007/88189905

Pycharm 双击启动失败?

事故 双击 Pycharm 后,出现加载工程,我不想加载这个工程,就点击了弹出的 cancle 取消按钮。然后再到桌面双击 Pycharm 却发现无法启动了。哪怕以管理员权限运行也没用,就是不出界面。 原因未知 CtrlshiftESC 打开后台&#xff…

易服客工作室:如何创建有用的内容日历

利用技巧和工具优化您的内容营销效率和效果。创建一个内容日历,您的整个团队都会从中受益! 欢迎来到熙熙攘攘、瞬息万变的内容营销世界,在这里,截止日期到来的速度比喝咖啡的猎豹还要快。 现在,想象一下在没有地图、…

MapBox加载不同风格

初始化MapBox地图: var map new mapboxgl.Map({container: map,zoom: 3,center: [105, 34],//此处更改地图风格style: mapbox://styles/mapbox/satellite-v9,hash: false,});1.户外地图(mapbox://styles/mapbox/basic-v9)新版:&a…

设计模式之模板方法

一、概述 定义一个操作中的算法的骨架,将一些步骤延迟到子类中。 TemplateMethod使得子类可以不改变一个算法的结构即可重新定义该算法的某些特定步骤。 二、适用性 1.一次性实现一个算法的不变的部分,并将可变的行为留给子类来实现。 2.各子类中公共…

YOLOv5修改注意力机制CBAM

直接上干货 CBAM注意力机制是由通道注意力机制(channel)和空间注意力机制(spatial)组成。 传统基于卷积神经网络的注意力机制更多的是关注对通道域的分析,局限于考虑特征图通道之间的作用关系。CBAM从 channel 和 sp…

Kafka与Zookeeper版本对应关系

文章目录 了解版本对应Kafka安装包Kafka源码包 了解 比如: kafka_2.11-1.1.1.jar包 其中2.11表示的是Scala的版本,因为Kafka服务器端代码完全由Scala语音编写。”-“后面的1.1.1表示的kafka的版本信息。遵循一个基本原则,Kafka客户端版本和服…

MySQL5.7数据库、Navicat Premium1.6可视化工具安装教程【详细教程】

文章目录 一、MySQL、Navicat、注册机地址二、安装(一)、MySQL安装(二)、Navicat Premium安装(三)、集活Navicat Premium 三、遇到的问题1、Are you sure your navicat has not beenpatched/modified befor…

Spring 事务管理

目录 1. 事务管理 1.1. Spring框架的事务支持模型的优势 1.1.1. 全局事务 1.1.2. 本地事务 1.1.3. Spring框架的一致化编程模型 1.2. 了解Spring框架的事务抽象(Transaction Abstraction) 1.2.1. Hibernate 事务设置 1.3. 用事务同步资源 1.3.1…