LSTM原理解读与实战

在RNN详解及其实战中,简单讨论了为什么需要RNN这类模型、RNN的具体思路、RNN的简单实现等问题。同时,在文章结尾部分我们提到了RNN存在的梯度消失问题,及之后的一个解决方案:LSTM。因此,本篇文章主要结构如下:

  1. LSTM 理解及简单实现
  2. LSTM 实战
  3. 经典 RNN 与 LSTM 对比
  4. 关于梯度消失

LSTM 理解

其实,将 LSTM 与 RNN 说成两个并不可取, LSTM 依然归属于 RNN 之下,相比于使用线性回归方式来处理序列问题, LSTM 其实是设计了一个模块来取代线性回归算法。

LSTM(Long Short-Term Memory),翻译过来是长短期记忆法,其核心思想可以说非常的简单:既然 RNN 只能保存短期的记忆,那我增加一个长期记忆,不就可以解决这个问题了名?因此,LSTM提出了长期记忆和短期记忆,通过调整长期记忆和短期记忆之间的比例,来维持长期记忆的可靠,降低 RNN 的梯度消失问题。可以看到下方结构图中,模型输入由两个升级到三个,分别是当前节点状态 X t \mathbf{X}_{t} Xt,长期记忆: C t − 1 \mathbf{C}_{t-1} Ct1,短期记忆 H t − 1 \mathbf{H}_{t-1} Ht1。输出状态依然是两个:节点当前状态 C t \mathbf{C}_{t} Ct,和节点当前隐藏状态 H t \mathbf{H}_{t} Ht

在这里插入图片描述

那LSTM 是如何实现对长短记忆的控制呢?
这就不得不提众人所知的三个门:

  • 遗忘门:控制保留多少上一时刻的单元节点到当前节点
  • 记忆门:控制将当前时刻的多少信息记忆到节点中
  • 输出门:控制输出多少信息给当前输出

我们在分析三个门之前,我们先了解 这一概念。

从简化图中可以看到, 的感觉类似于电路中的一个开关,当开关按下,信息通过,而开关抬起,信息不再通过。实际也如此类似,是一个全连接层,输入为一个向量,输出为一个位于 [0,1] 之间的值。
我们来设计一个非常简单的遗忘门:每次学习状态之后,都遗忘一定的已学习内容,注意,这里的遗忘门与 LSTM 的遗忘门无关,单纯理解 这一概念。

# 一个线性层 用来计算遗忘多少
gate_linear = nn.Linear(hidden_size, 1)
# 一个线性层 用来学习
study_linear = nn.Linear(hidden_size, hidden_size)
# 此刻 h_t 是上一时刻状态
# 输出为 0 - 1 的值
gate = gate_linear(h_t)
# h_t 经过 study_linear 进行学习
_h_t = study_linear(h_t)
# 在输出结果之前,经过 gate 导致内容受损,遗忘了一定的学习内容
h_t = gate * (_h_t)

可以看到,如果 g a t e gate gate 值为 0,则历史信息均会被遗忘,而如果值为1,则历史信息则会被完全保留,而 gate_linear 网络中的超参数会不断的学习,因此一个可以学习的开关门就出现了。

但是, g a t e gate gate 作为一个浮点型的数据,对于 临时结果矩阵变量 _ h _ t \_h\_t _h_t 而言,其遗忘控制是全局的,也就是,当 g a t e gate gate 为 0 时, 其最终结果 h _ t h\_t h_t 为全 0 矩阵。因此我们应该注意: LSTM 中并不采用这样的大闸门,而是采用对每个变量进行分别控制的小水龙头(神经网络激活函数 nn.Sigmode )

而在 LSTM 中,门主要使用 S i g m o d Sigmod Sigmod 神经网络(再次注意,并非是激活函数,而是 Sigmod 神经网络)来完成。

下方是一个示例代码:

hidden_size = 5
sigmoid = nn.Sigmoid()
# 隐藏状态 为了方便计算,假定全 1
hidden_emb = torch.ones(hidden_size, hidden_size)
# 中间某一层神经网络
model = nn.Linear(hidden_size,hidden_size)
# 获取该层输出,此时尚未被门限制
mid_out = model(hidden_emb)
# 获取一个门 -- 注意:并非一定由该变量所控制
# 比如:也可以由上一时刻的隐藏状态控制
# 代码为: gate = sigmoid(hidden_emb)
gate = sigmoid(mid_out) 
# 得到最终输出
final_out = gate * mid_out

在有了对门的基础知识后,接下来对遗忘门、记忆门、输出门进行分别分析。

遗忘门

遗忘门涉及部分如下图所示:
在这里插入图片描述

其中,下方蓝色表示三个门共用的输入部分,均为 [ h t − 1 \mathbf{h}_{t-1} ht1, X t \mathbf{X}_{t} Xt],需要注意,这里由于三个门之间并不共享权重参数,因此公示虽然接近,但是一共计算了三次,遗忘门被标记为 f t f_t ft, 列出遗忘门公式为:
f t = σ ( W f ∗ [ h t − 1 , X t ] + b f ) f_t = \sigma(\mathbf{W_f} * [\mathbf{h}_{t-1},\mathbf{X}_{t}] + \mathbf{b_f}) ft=σ(Wf[ht1,Xt]+bf)
输出结果为取值范围为 [ 0, 1 ] 的矩阵,主要功能是控制与之相乘的矩阵的遗忘程度
f t f_t ft 与输入的上一长期状态 C t − 1 C_{t-1} Ct1 相乘:
C t ′ = f t ∗ C t − 1 C_t' = f_t * C_{t-1} Ct=ftCt1

一部分的 C t − 1 C_{t-1} Ct1 就这样被遗忘了。

记忆门

记忆门涉及部分如下所示:
在这里插入图片描述

从图中可以看到,记忆门中相乘的两个部分均由 h t − 1 \mathbf{h}_{t-1} ht1 X t \mathbf{X}_{t} Xt 得到,
其中,左侧控制记忆多少的部分,与遗忘门公式基本一致:
i t = σ ( W i ∗ [ h t − 1 , X t ] + b i ) i_t = \sigma(\mathbf{W_i} * [\mathbf{h}_{t-1},\mathbf{X}_{t}] + \mathbf{b_i}) it=σ(Wi[ht1,Xt]+bi)
与遗忘门相通,输出结果为取值范围为 [ 0, 1 ] 的矩阵,主要功能是控制与之相乘的矩阵的记忆程度
而右侧,则更换了激活函数,由 s i g m o i d sigmoid sigmoid 变成了 t a n h tanh tanh
C t ~ = tanh ⁡ ( W c ∗ [ h t − 1 , X t ] + b c ) \tilde{C_t} = \tanh(\mathbf{W_c} * [\mathbf{h}_{t-1},\mathbf{X}_{t}] + \mathbf{b_c}) Ct~=tanh(Wc[ht1,Xt]+bc)
该公式负责的部分可以看做负责短期隐藏状态的更新,取值范围为 [ -1, 1 ]。

最终记忆门更新公式如下:
C t ′ ~ = i t ∗ C t ~ \tilde{C_t'}= i_t * \tilde{C_t} Ct~=itCt~

可以说 C t ′ ~ \tilde{C_t'} Ct~ 是保留了一定内容的短期状态

状态更新

在这里插入图片描述

在通过遗忘门获取到了被遗忘一定内容的长期状态 C t ′ C_t' Ct 和 保留了一定内容的短期状态 C t ′ ~ \tilde{C_t'} Ct~ 之后,可以通过加法直接结合

C t = C t ′ + C t ′ ~ C_t = C_t' + \tilde{C_t'} Ct=Ct+Ct~

输出门

在这里插入图片描述

输出门是三个门中最后一个门,当数据到达这里的时候,我们主要控制将长期状态中的内容 C t C_t Ct 保存一定内容到 h t h_t ht 中,这里不再赘述
o t = σ ( W o ∗ [ h t − 1 , X t ] + b o ) o_t = \sigma(\mathbf{W_o} * [\mathbf{h}_{t-1},\mathbf{X}_{t}] + \mathbf{b_o}) ot=σ(Wo[ht1,Xt]+bo)

h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)

模型总结

可以看到,所有公式的核心部分都是如此的相似:
W c ∗ [ h t − 1 , X t ] + b c \mathbf{W_c} * [\mathbf{h}_{t-1},\mathbf{X}_{t}] + \mathbf{b_c} Wc[ht1,Xt]+bc
而这部分其实又只是简单的线性函数,所以 LSTM 比 RNN 高级的地方其实并不在于某一条公式,而是它调整了数据之间的流动,按照一定的比例进行融合,弱化了长距离下的梯度消失问题。

最后总的来看,LSTM 其实就是一个升级版本的的 RNN,他额外初始化了一个状态 C C C, 用来保存长期的记忆,控制远距离上的参数权重。而输出也基本类似于此。

LSTM 实战

实验说明

实验数据集采用 IMDB 数据集。主要由电影评论构成,长度不均,但是长度在 1000 左右的数据属于常见数据。数据集样本均衡,数共计 50000 个样本,训练和测试各有 25000 个样本,同时训练和测试的正负比例均为 1:1。

根据我们对 RNN 的了解,这样的长度是很难学习到有效的知识的,所以很适合比较 RNN 与 LSTM 之间的区别。

为了方便代码复现,在实现中借助了 torchtext 来完成数据下载及加载。

为了证明模型真的有学习到一定的内容,所以对比实验中部分参数可能存在部分区别,可以在本地调整到同一参数进行细致的对比实验。

模型实现

分析一下由我实现的 LSTM 模型,并以此了解 LSTM 模型。

# 定义基础模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        """
        args:
            input_size: 输入大小
            hidden_size: 隐藏层大小
            num_classes: 最后输出的类别,在这个示例中,输出应该是 0 或者 1
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.fc_i = nn.Linear(input_size + hidden_size, hidden_size)
        self.fc_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.fc_g = nn.Linear(input_size + hidden_size, hidden_size)
        self.fc_o = nn.Linear(input_size + hidden_size, hidden_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.fc_out = nn.Linear(hidden_size, num_classes)
    def forward(self, x):
        # 初始化隐藏状态 -- 短期记忆
        h_t = torch.zeros(x.size(0), x.size(1), self.hidden_size).to(x.device)
        # 初始化隐藏状态 -- 长期记忆
        c_t = torch.zeros(x.size(0), x.size(1), self.hidden_size).to(x.device)
        # 输入与短期记忆相拼接
        combined = torch.cat((x, h_t), dim=2)
        # 记忆门 -- 输出矩阵内容为 0-1 之间的数字
        i_t = self.sigmoid(self.fc_i(combined))
        # 遗忘门 -- 输出矩阵内容为 0-1 之间的数字
        f_t = self.sigmoid(self.fc_f(combined))
        #
        g_t = self.tanh(self.fc_g(combined))
        #  输出门 -- 输出矩阵内容为 0-1 之间的数字
        o_t = self.sigmoid(self.fc_o(combined))
        # 长期状态 =  遗忘门 * 上一时刻的长期状态 + 记忆门* 当前记忆状态
        c_t = f_t * c_t + i_t * g_t
        # 隐藏状态 = 输出门 * 长期状态
        h_t = o_t * self.tanh(c_t)
        # 降维操作 
        h_t = F.avg_pool2d(h_t, (h_t.shape[1],1)).squeeze()
        # 
        out = self.fc_out(h_t)
        return out 

超参数及参数说明

MyLSTM 与 nn.LSTM

名称
learning_rate0.001
batch_size32
epoch6(3)
input_size64
hidden_size128
num_classes2

此时:
MyLSTM 参数量: 99074
nn.LSTM 参数量: 99328

由于我实现的 MyLSTM 与 nn.LSTM 有 254 的参数差,我本人并没能分析出来差别。 nn.LSTM 在实验时大概率比我的 MyLSTM 迭代更快,所以容易较早的过拟合,所以将其训练 epoch 砍半,也就是说 MyLSTM 使用 6 epoch 进行训练,而 nn.LSTM 使用 3 epoch 进行训练。两者可以达到基本相近的效果

另外在代码实现中 nn.LSTM 后面加了一个 nn.Linear 来实现二分类,参数量为 258, 所以 MyLSTM 和 LSTM 相差参数总量为 512。

nn.RNN

名称
learning_rate0.0001
batch_size32
epoch12-18
input_size64
hidden_size128
num_classes2

此时:
nn.RNN 参数量: 25090

由于实验样本长度在 1000 上下, RNN 显示出来了极大的不稳定性,其中, 相较于 LSTM 更容易梯度爆炸、训练 epoch 更多、学习率需要调低等等问题,尽管如此依然不能保证稳定的良好结果。

举例来说,某学生学习阅读理解,要求根据文章内容回答文章的情感倾向,但是学生只喜欢看最后一句话,每次都根据最后一句话来回答问题,那么他基本上是等于瞎猜的,只能学到一点浅薄的知识。

实验结果

MyLSTMnn.LSTMnn.RNN
0.860.800.67

关于梯度问题

  • RNN问题中,总的梯度是不会消失的。即便梯度越传越弱,那也是远处的梯度逐渐消失,而近距离的梯度不会消失,因此,梯度总和不会消失。RNN 梯度消失的真正含义是:梯度被近距离梯度所主导,导致模型难以学到远距离的依赖关系。

  • LSTM 上有多条信息流路径,其中,元素相加的路径的梯度流是最稳定的,而其他路径上与基本的 RNN 相类似,依然存在反复相乘问题。

  • LSTM 刚刚提出时不存在遗忘门。这时候历史数据可以在这条路径上无损的传递,可以将其视为一条 高速公路,类似于 ResNet 中的残差连接。

  • 但是其他路径上, LSTM 与 RNN 并无太多区别,依然会爆炸或者消失。由于总的远距离梯度 = 各个路径的远距离梯度之和,因此只要有一条路的远距离梯度没有消失,总的远距离梯度就不会消失。可以说,LSTM 通过这一条路拯救了总的远距离梯度。

  • 同样,总的远距离梯度 = 各个路径的远距离梯度之和,虽然高速路上的梯度流比较稳定,但是其他路上依然存在梯度消失和梯度爆炸问题。因此,总的远距离梯度 = 正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 依然存在梯度爆炸问题。 但是由于 LSTM 的道路相比经典 RNN 来说非常崎岖, 存在多次激活函数,因此 LSTM 发生梯度爆炸的概率要小得多。实践中通常通过梯度剪裁来优化问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/921562.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Springboot之登录模块探索(含Token,验证码,网络安全等知识)

简介 登录模块很简单,前端发送账号密码的表单,后端接收验证后即可~ 淦!可是我想多了,于是有了以下几个问题(里面还包含网络安全问题): 1.登录时的验证码 2.自动登录的实现 3.怎么维护前后端…

使用Element UI实现前端分页,前端搜索,及el-table表格跨页选择数据,切换分页保留分页数据,限制多选数量

文章目录 一、前端分页1、模板部分 (\<template>)2、数据部分 (data)3、计算属性 (computed)4、方法 (methods) 二、前端搜索1、模板部分 (\<template>)2、数据部分 (data)3、计算属性 (computed)4、方法 (methods) 三、跨页选择1、模板部分 (\<template>)2、…

VMware Workstation 17.6.1

概述 目前 VMware Workstation Pro 发布了最新版 v17.6.1&#xff1a; 本月11号官宣&#xff1a;针对所有人免费提供&#xff0c;包括商业、教育和个人用户。 使用说明 软件安装 获取安装包后&#xff0c;双击默认安装即可&#xff1a; 一路单击下一步按钮&#xff1a; 等待…

探索PyMuPDF:Python中的强大PDF处理库

文章目录 **探索PyMuPDF&#xff1a;Python中的强大PDF处理库**第一部分&#xff1a;背景第二部分&#xff1a;PyMuPDF是什么&#xff1f;第三部分&#xff1a;如何安装这个库&#xff1f;第四部分&#xff1a;至少5个简单的库函数使用方法第五部分&#xff1a;结合至少3个场景…

go语言range的高级用法-使用range来接收通道里面的数据

在 Go 语言中&#xff0c;可以使用 for ... range 循环来遍历通道&#xff08;channel&#xff09;。for ... range 循环会一直从通道中接收值&#xff0c;直到通道关闭并且所有值都被接收完毕。 使用 for ... range 遍历通道 示例代码 下面是一个使用 for ... range 遍历通…

14.C++STL1(STL简介)

⭐本篇重点&#xff1a;STL简介 ⭐本篇代码&#xff1a;c学习/7.STL简介/07.STL简介 橘子真甜/c-learning-of-yzc - 码云 - 开源中国 (gitee.com) 目录 一. STL六大组件简介 二. STL常见算法的简易使用 2.1 swap ​2.2 sort 2.3 binary_search lower_bound up_bound 三…

5G CPE与4G CPE的主要区别有哪些

什么是CPE&#xff1f; CPE是Customer Premise Equipment&#xff08;客户前置设备&#xff09;的缩写&#xff0c;也可称为Customer-side Equipment、End-user Equipment或On-premises Equipment。CPE通常指的是位于用户或客户处的网络设备或终端设备&#xff0c;用于连接用户…

智能安全配电装置在高校实验室中的应用

​ 摘要&#xff1a;高校实验室是科研人员进行科学研究和实验的场所&#xff0c;通常会涉及到大量的仪器设备和电气设备。电气设备的使用不当或者维护不周可能会引发火灾事故。本文将以一起实验室电气火灾事故为例&#xff0c;对事故原因、危害程度以及防范措施进行分析和总结…

深入理解 LMS 算法:自适应滤波与回声消除

深入理解 LMS 算法&#xff1a;自适应滤波与回声消除 在信号处理领域&#xff0c;自适应滤波是一种重要的技术&#xff0c;广泛应用于噪声消除、回声消除和信号恢复等任务。LMS&#xff08;Least Mean Squares&#xff09;算法是实现自适应滤波的经典方法之一。本文将详细介绍…

如何在分布式环境中实现高可靠性分布式锁

目录 一、简单了解分布式锁 &#xff08;一&#xff09;分布式锁&#xff1a;应对分布式环境的同步挑战 &#xff08;二&#xff09;分布式锁的实现方式 &#xff08;三&#xff09;分布式锁的使用场景 &#xff08;四&#xff09;分布式锁需满足的特点 二、Redis 实现分…

socket连接封装

效果&#xff1a; class websocketMessage {constructor(params) {this.params params; // 传入的参数this.socket null;this.lockReconnect false; // 重连的锁this.socketTimer null; // 心跳this.lockTimer null; // 重连this.timeout 3000; // 发送消息this.callbac…

基于RM开发板32学习日记

环境配置 芯片选型 STM32F407IGH6 配置时钟 12 168 模块 Led 引脚选择 比对原理图 可查看 设置为Out_Put输出 三色同时点亮 合为白色光 HAL_GPIO_WritePin(LED_R_GPIO_Port, LED_R_Pin, GPIO_PIN_SET);HAL_GPIO_WritePin(GPIOH, GPIO_PIN_10, GPIO_PIN_SET);GPIOH->ODR…

MacOS下的Opencv3.4.16的编译

前言 MacOS下编译opencv还是有点麻烦的。 1、Opencv3.4.16的下载 注意&#xff0c;我们使用的是Mac&#xff0c;所以ios pack并不能使用。 如何嫌官网上下载比较慢的话&#xff0c;可以考虑在csdn网站上下载&#xff0c;应该也是可以找到的。 2、cmake的下载 官网的链接&…

刷题笔记15

问题描述 小M和小F在玩飞行棋。游戏结束后&#xff0c;他们需要将桌上的飞行棋棋子分组整理好。现在有 N 个棋子&#xff0c;每个棋子上有一个数字序号。小M的目标是将这些棋子分成 M 组&#xff0c;每组恰好5个&#xff0c;并且组内棋子的序号相同。小M希望知道是否可以按照这…

stm32 指定变量存储地址

uint8_t array[10] attribute((at(0x20000000))) 当你使用 attribute((at(地址))) 强制将变量放置在特定地址时&#xff0c;编译器和链接器通常不会自动调整其他变量的地址以避免冲突。这意味着&#xff0c;如果指定的地址已经被其他变量占用&#xff0c;就会发生冲突。 如果…

性能超越Spark 13.3 倍,比某MPP整体快数十秒 | 多项性能指标数倍于主流开源引擎 | 云器科技发布性能测试报告

云器Lakehouse正式发布性能测试报告 &#x1f3c5;离线批处理&#xff1a;在复杂批处理任务中&#xff0c;云器Lakehouse相较Spark表现出13.31倍性能提升。 &#x1f3c5;即席查询&#xff1a;在交互式分析场景下&#xff0c;云器Lakehouse相较Trino表现出9.84倍性能提升。 &am…

NIST 发布后量子密码学转型战略草案

美国国家标准与技术研究所 (NIST) 发布了其初步战略草案&#xff0c;即内部报告 (IR) 8547&#xff0c;标题为“向后量子密码标准过渡”。 该草案概述了 NIST 从当前易受量子计算攻击的加密算法迁移到抗量子替代算法的战略。该草案于 2024 年 11 月 12 日发布&#xff0c;开放…

论文阅读——Performance Evaluation of Passive Tag to Tag Communications(一)

文章目录 摘要一、互耦对监听器标签输入阻抗的影响A. 无限细偶极子互阻抗的理论研究B. 电细偶极子的情况&#xff1a;理论与模拟C. 印刷偶极子的情况&#xff1a;电磁模拟与测量 二、T2T 通信系统的性能评估总结 论文来源&#xff1a;https://ieeexplore.ieee.org/document/970…

IT人员面试重点底层逻辑概念

arrayList的底层原理 ArrayList是个动态数组&#xff0c;实现List接口&#xff0c;主要用来存储数据&#xff0c;如果存储基本类型的数据&#xff0c;如int&#xff0c;long&#xff0c;boolean&#xff0c;short&#xff0c;byte&#xff0c;那只存储它们对应的包装类。 它的…

PyTorch 分布式并行计算

0. Abstract 使用 PyTorch 进行多卡训练, 最简单的是 DataParallel, 仅仅添加一两行代码就可以使模型在多张 GPU 上并行地计算. 但它是比较老的方法, 官方推荐使用新的 Distributed Data Parallel, 更加灵活与强大: 1. Distributed Data Parallel (DDP) 从一个简单的非分布…