28. 深度学习进阶 - LSTM

文章目录


在这里插入图片描述

Hi, 你好。我是茶桁。

我们上一节课,用了一个示例来展示了一下我们为什么要用RNN神经网络,它和全连接的神经网络具体有什么区别。

这节课,我们就着上一节课的内容继续往后讲,没看过上节课的,建议回头去好好看看,特别是对代码的进程顺序好好的弄清楚。

全连接的模型得很仔细的去改变它的结构,然后再给它加很多东西,效果才能变好:

self.linear_with_tanh = nn.Sequential(
    nn.Linear(10, self.hidden_size),
    nn.Tanh(),
    nn.Linear(self.hidden_size, self.hidden_size),
    nn.Tanh(),
    nn.Linear(self.hidden_size, output_size)
)

但是对于RNN模型来说,我们只用了两个函数:

self.rnn = nn.RNN(x_size, hidden_size, n_layers, batch_first=True)
self.out = nn.Linear(hidden_size, output_size) 

这是一个很本质的问题, 也比较重要。为什么RNN的模型这么简单,它的效果比更复杂的全连接要好呢?

这个和我们平时生活中做各种事情其实都很类似,他背后的原因是他的信息保留的更多。RNN模型厉害的本质是在运行的过程中把更多的信息记录下来,而全连接没有记录。

对于RNN模型,还有两个点大家需要注意。

第一个,有一种叫做stacked的RNN的模型。我们RNN模型每一次输出都有一个output和hidden,把outputs和hidden作为它的输入再传给另外一个RNN模型,模型就变得更复杂,理论上可以解决些更复杂的场景。我们把这种就叫做stacked RNN。

Alt text

还有一种形式,Bidirectional RNN,双向RNN。有一个很著名的文本模型Bert, 那个B就是双向的意思。

我们回过头来看上节课我们讲过的两种网络:

h t = σ h ( W h x t + U h h t − 1 + b h ) y t = σ y ( W y h t + b y ) \begin{align*} h_t & = \sigma_h(W_hx_t + U_hh_{t-1} + b_h) \\ y_t & = \sigma_y(W_yh_t + b_y) \end{align*} htyt=σh(Whxt+Uhht1+bh)=σy(Wyht+by)

在这个里面,每一时刻的y_t只和y_{t-1} 有关系,如果把所有的x一次性给到模型的时候,其实我们在这里可以给它加一个东西:

h t = σ h ( W h x t + U h h t − 1 + V h ∗ h t + 1 + b h ) \begin{align*} h_t & = \sigma_h(W_hx_t + U_hh_{t-1} + V_h * h_{t+1} + b_h) \end{align*} ht=σh(Whxt+Uhht1+Vhht+1+bh)

还可以写成这样,那这样的话它实现的就是每一时刻的t既和前一次有关系
和后一刻有关系。这样我们每一次的值不仅和前面有关,还和后面有关。就叫做双向RNN。

Alt text

对于RNN来说,它有一个很严重的问题,就是之前说过的,它的vanishing和exploding的问题会很明显, 也就是梯度消失和爆炸问题。

在这里插入图片描述

想一下,现在如果有一个loss,那它最终的loss是不是对于{x1, x2, …, xn}都有关系,比方说现在要求 ∂ l o s s ∂ w 1 \frac{\partial loss}{\partial w_1} w1loss, 假如说现在h是100, 那这种调用关系就是

∂ l o s s ∂ w 1 = ∂ h 100 ∂ h 99 ⋅ ∂ h 99 ∂ h 98 ⋅ . . . ⋅ ∂ h 0 ∂ w 1 \begin{align*} \frac{\partial loss}{\partial w_1} = \frac{\partial h_{100}}{\partial h_{99}} \cdot \frac{\partial h_{99}}{\partial h_{98}} \cdot ... \cdot \frac{\partial h_{0}}{\partial w_{1}} \end{align*} w1loss=h99h100h98h99...w1h0

loss对于w1求偏导的时候,其实loss最先接受的是离他最近的, 假如说是h100。h100调用了h99,h99调用h98,就这个调用过程,这一串东西会变得很长。

我们之前课程说过一些情况,怎么去解决这个问题呢?对于RNN模型来说梯度爆炸很好解决,就直接设定一个阈值就可以了,起码也是能学习的。

在这里插入图片描述

要讲的是想一种方法怎么样来解决梯度消失的问题。这个梯度消失的解决方法,就叫LSTM。要解决梯度消失,就是要用LSTM: Long Short-Term Memory,长短记忆模型,既能保持长信息,又能保持短信息。

在之前那个很长的过程中,怎么样能够让它不消散呢?LSTM的核心思想是通过门控机制来控制信息的流动和及已的更新,包含了Input Gate, Forget Gate,Cell State以及Output Gate。这些会一起协作来处理序列数据。

其中Input Gate控制着新信息的输入,以及信息对细胞状态的影响。 Forget Gate控制着细胞状态中哪些信息应该被易王,Cell State用于传递信息,是LSTM的核心,Output Gate控制着细胞状态如何影响输出。

这里每一个门控单元都由一个Sigmoid激活函数来控制信息的流动,以及一个Tanh激活函数来确定信息的值。

I n p u t G a t e i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) C ′ t = tanh ⁡ ( W c ⋅ [ h t − 1 , x t ] + b c ) C t = f t ⋅ C t − 1 + i t ⋅ C t ′ F o r g e t G a t e f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) C t = f t ⋅ C t − 1 + i t ⋅ C t ′ O u t p u t G a t e o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) h t = o t ⋅ tanh ⁡ ( C t ) \begin{align*} Input Gate \\ i_t & = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ C't & = \tanh(W_c \cdot [h{t-1}, x_t] + b_c) \\ C_t & = f_t \cdot C_{t-1} + i_t \cdot C'_t \\ Forget Gate \\ f_t & = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ C_t & = f_t \cdot C_{t-1} + i_t \cdot C'_t \\ Output Gate \\ o_t & = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\ h_t & = o_t \cdot \tanh(C_t) \end{align*} InputGateitCtCtForgetGateftCtOutputGateotht=σ(Wi[ht1,xt]+bi)=tanh(Wc[ht1,xt]+bc)=ftCt1+itCt=σ(Wf[ht1,xt]+bf)=ftCt1+itCt=σ(Wo[ht1,xt]+bo)=ottanh(Ct)

其中, h t − 1 h_{t-1} ht1 是前一个时间步的隐藏状态, x t x_t xt 是当前时间步的输入, W i , W f , W o , W c W_i, W_f, W_o, W_c Wi,Wf,Wo,Wc 是权重矩阵, b i , b f , b o , b c b_i, b_f, b_o, b_c bi,bf,bo,bc 是偏置。

Alt text

LSTM输入的是一个序列数据,可以是文本、时间序列,音频信号等等。那每个时间步的输入是序列中的饿一个元素,比如一个单词、一个时间点的观测值等等。

假设我们有一个序列 x = [x1, x2, …, xt], 其中t就代表的是时间步。

xt进来的时候, 之前我们是只接收一个hidden state, 现在我们多接收了一个 C t − 1 C_{t-1} Ct1,这个就是我们的Cell,这一步的 C t − 1 C_{t-1} Ct1其实就是上一步的 C t C_t Ct

在训练开始时,需要初始化LSTM单元的隐藏状态h0和细胞状态c0。通常我们初始化它们为全零向量。

最开始的时候,我们要进入Input Gate, 对于每个时间步t, 计算输入门的激活值 i t i_t it,控制新信息的输入。使用Sigmoid函数来计算输入门的值:

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

然后,计算新的侯选值 C t ′ C'_t Ct, 这是在当前时间步考虑的新信息。使用tanh激活函数来计算侯选值:

C t ′ = t a n h ( W c ⋅ [ h t − 1 , x t ] + b c ) C'_t = tanh(W_c \cdot [h_{t-1}, x_t] + b_c) Ct=tanh(Wc[ht1,xt]+bc)

接下来我们就要更新细胞状态了,细胞状态 C t C_t Ct更新是通过遗忘门 f t f_t ft和输入门 i t i_t it控制的。遗忘门控制着哪些信息应该被遗忘,输入门控制新信息对细胞状态的影响:

C t = f t ⋅ C t − 1 + i t ⋅ C t ′ C_t = f_t \cdot C_{t-1} + i_t \cdot C'_t Ct=ftCt1+itCt

那遗忘门决定哪些信息应该被遗忘,使用的就是Sigmoid函数计算遗忘门的激活值。

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

接着,计算输出门 O t O_t Ot, 控制着细胞状态如何影响输出和隐藏状态。一样,我们还是使用Sigmoid函数计算。

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

使用输出门的值 o t o_t ot来计算最终的隐藏状态 h t h_t ht和输出。 隐藏状态和输出都是根据细胞状态和输出门的值来计算的:

h t = o t ⋅ t a n h ( C t ) h_t = o_t \cdot tanh(C_t) ht=ottanh(Ct)

接下来就容易了,我们迭代重复上述过程,处理序列中的每一个时间步,直到处理完整个序列。

LSTM的输出可以是隐藏状态 h t h_t ht, 也可以是细胞状态 C t C_t Ct, 具体是取决于应用的需求。

后来大家就发现了一种改进的LSTM,其中门控机制允许细胞状态窥视现前的细胞状态的信息,而不仅仅是根据当前时间步的输入和隐藏状态来决定。 这个机制在LSTM单源种引入了额外的权重和连接,以允许细胞状态在门控过程中访问现前的细胞状态,我们称之为窥视孔连接: Peephole connections。

f t = σ ( W f ⋅ [ C t − 1 , h t − 1 , x t ] + b f ) i t = σ ( W i ⋅ [ C t − 1 , h t − 1 , x t ] + b i ) o t = σ ( W o ⋅ [ C t − 1 , h t − 1 , x t ] + b o ) \begin{align*} f_t = \sigma(W_f \cdot [C_{t-1}, h_{t-1}, x_t] + b_f) \\ i_t = \sigma(W_i \cdot [C_{t-1}, h_{t-1}, x_t] + b_i) \\ o_t = \sigma(W_o \cdot [C_{t-1}, h_{t-1}, x_t] + b_o) \\ \end{align*} ft=σ(Wf[Ct1,ht1,xt]+bf)it=σ(Wi[Ct1,ht1,xt]+bi)ot=σ(Wo[Ct1,ht1,xt]+bo)

之前,我们是xt和x_{t-1}决定的f,那现在又把c_{t-1}加上了。就是多加了一些信息。

除此之外它有一个方法GRU,这个是2014年提出来的,Geted Recurrent Unit,它是LSTM的一个简化版本。

它最核心的内容:

h t = ( 1 − z t ) ⋅ h t − 1 + z t ⋅ h t ′ \begin{align*} h_t = (1-z_t) \cdot h_{t-1} + z_t \cdot h'_t \end{align*} ht=(1zt)ht1+ztht

咱们刚刚是 C t = f t ⋅ C t − 1 + i t ⋅ C t ′ C_t = f_t \cdot C_{t-1} + i_t \cdot C'_t Ct=ftCt1+itCt,也就是遗忘加上输入,那我们对过去保留越多的时候,
输入就会越小,那对过去保留越小的时候,输入就会越大。

所以既然f也是1-0,i也是0-1,f大的时候i就小,f小的时候i就大,那么能不能写成f=(1-i)?

于是,GRU就这样实现了, 它其实最核心的就做了这样一件事, f=(1-i)。

z t = σ ( W z ⋅ [ h t − 1 , x t ] ) r t = σ ( W r ⋅ [ h t − 1 , x t ] ) h t ′ = tanh ⁡ ( W ⋅ [ r t ⋅ h t − 1 , x t ] ) h t = ( 1 − z t ) ⋅ h t − 1 + z t ⋅ h t ′ \begin{align*} z_t & = \sigma(W_z \cdot [h_{t-1}, x_t]) \\ r_t & = \sigma(W_r \cdot [h_{t-1}, x_t]) \\ h'_t & = \tanh(W \cdot [r_t \cdot h_{t-1}, x_t]) \\ h_t & = (1-z_t) \cdot h_{t-1} + z_t \cdot h'_t \end{align*} ztrththt=σ(Wz[ht1,xt])=σ(Wr[ht1,xt])=tanh(W[rtht1,xt])=(1zt)ht1+ztht

这个z其实和i是一样的东西,只是原作者为了发表论文方便而改了个名称。

https://arxiv.org/pdf/1406.1078v3.pdf

r t r_t rt是来控制上一时刻的 h t h_t ht在我们此时此刻的重要性、影响程度。那我们可以将 r t ⋅ h t − 1 r_t \cdot h_{t-1} rtht1看成是关于及已的, 1 − z t 1-z_t 1zt也是关于记忆的。

GRU这样做之后有什么好处呢?

原来我们有三个门: f, i, o, 那现在变成了两个,z和r。为什么就更好了呢?我们在PyTorch里面往往用的是GRU。

大家想一下,是不是少了一个门其实就少了一个矩阵?我们看公式的时候, W f W_f Wf是一个数学符号,但是在背后其实是一个矩阵,是一个矩阵的话少了一个矩阵意味着参数就少多了,运算就更快了等等。

但其实这些都不是最关键的,最关键的是减少过拟合了。我们之前的课程中一再强调,过拟合之所以产生,最主要的原因是数据不够或者说是模型太复杂。

但是在现有的数据情况下,为了让数据发挥出最大效力,你把需要训练的模型变简单,参数变少,就没有那么复杂了。

关于RNN模型,我们后面还会介绍一些具体的示例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/233383.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度学习 | 前馈神经网络与反向传播算法

目录 一、Logistic函数 二、前馈神经网络(FNN) 三、反向传播算法(BP算法) ​四、基于前馈神经网络的手写体数字识别 一、Logistic函数 Logistic函数是学习前馈神经网络的基础。所以在介绍前馈神经网络之前,我们首…

消息队列使用指南

介绍 消息队列是一种常用的应用程序间通信方法,可以用来在不同应用程序或组件之间传递数据或消息。消息队列就像一个缓冲区,接收来自发送方的消息,并存储在队列中,等待接收方从队列中取出并处理。 在分布式系统中,消…

对无向图进行邻接矩阵的转化,并且利用DFS(深度优先)和BFS(广度优先)算法进行遍历输出, 在邻接矩阵存储结构上,完成最小生成树的操作。

一 实验目的 1.掌握图的相关概念。 2.掌握用邻接矩阵和邻接表的方法描述图的存储结构。 3.掌握图的深度优先搜索和广度优先搜索遍历的方法及其计算机的实现。 4.理解最小生成树的有关算法 二 实验内容及要求 实验内容&#…

【Angular开发】Angular在2023年之前不是很好

做一个简单介绍,年近48 ,有20多年IT工作经历,目前在一家500强做企业架构.因为工作需要,另外也因为兴趣涉猎比较广,为了自己学习建立了三个博客,分别是【全球IT瞭望】,【架构师酒馆】…

第 119 场 LeetCode 双周赛题解

A 找到两个数组中的公共元素 模拟 class Solution { public:vector<int> findIntersectionValues(vector<int> &nums1, vector<int> &nums2) {unordered_set<int> s1(nums1.begin(), nums1.end()), s2(nums2.begin(), nums2.end());vector<…

C语言进阶之路-数据结构篇

目录 一、学习目标 二、数据结构 1.基本概念 线性关系&#xff1a; 非线性关系&#xff1a; 存储形式 2. 算法分析 2.1 时间复杂度 2.2 空间复杂度 2.3 时空复杂度互换 总结 一、学习目标 了解数据结构的基本概念了解算法的分析方法 二、数据结构 1.基本概念 数据结…

Si24R03—低功耗 SOC 芯片(集成RISC-V内核+2.4GHz无线收发器)

Si24R03是一款高度集成的低功耗SOC芯片&#xff0c;其集成了基于RISC-V核的低功耗MCU和工作在2.4GHz ISM频段的无线收发器模块。 MCU模块具有低功耗、Low Pin Count、宽电压工作范围&#xff0c;集成了13/14/15/16位精度的ADC、LVD、UART、SPI、I2C、TIMER、WUP、IWDG、RTC等丰…

hdlbits系列verilog解答(mt2015_q4)-54

文章目录 一、问题描述二、verilog源码三、仿真结果一、问题描述 本次使用系列文章52和53中实现的子模块,实现以下组合逻辑电路。 二、verilog源码 module top_module (input x, input y, output z);wire [3:0

EPICS modbus 模块数字量读写练习

本文使用modbus slave软件模拟一个受控的modbus设备&#xff0c;此模拟设备提供如下功能&#xff1a; 1、线圈组1&#xff0c;8个线圈&#xff0c;起始地址为0&#xff0c;数量为8&#xff0c;软件设置如下(功能码1)&#xff0c;用于测试功能码5&#xff0c;一次写一个线圈&am…

了解红帽认证

红帽公司成立于1993年&#xff0c;是全球首家收入超10亿美元的开源公司&#xff0c;总部位于美国&#xff0c;分支机构遍布全球。红帽公司作为全球领先的开源和Linux系统提供商&#xff0c;其产品已被业界广泛认可并使用&#xff0c;尤其是RHEL系统在业内拥有超高的Linux系统市…

利用Node.js和cpolar实现远程访问,无需公网IP和路由器设置的完美解决方案

文章目录 前言1.安装Node.js环境2.创建node.js服务3. 访问node.js 服务4.内网穿透4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5.固定公网地址 前言 Node.js 是能够在服务器端运行 JavaScript 的开放源代码、跨平台运行环境。Node.js 由 OpenJS Foundation&#xff0…

Dockerfile的介绍和使用

什么是dockerfile? Dockerfile是一个包含用于组合映像的命令的文本文档。可以使用在命令行中调用任何命令。 Docker通过读取Dockerfile中的指令自动生成映像。 docker build命令用于从Dockerfile构建映像。可以在docker build命令中使用-f标志指向文件系统中任何位置的Dockerf…

c语言:理解和避免野指针

野指针的定义&#xff1a; 野指针是指一个指针变量存储了一个无效的地址&#xff0c;通常是一个未初始化的指针或者指向已经被释放的内存地址。当程序尝试使用野指针时&#xff0c;可能会导致程序崩溃、内存泄漏或者其他不可预测的行为。因此&#xff0c;在编程中需要特别注意…

二叉树前中后序遍历——(非)递归写法

文章目录 前言递归实现非递归实现力扣习题 红色&#xff1a;前序遍历顺序绿色&#xff1a;中序遍历顺序蓝色&#xff1a;后续遍历顺序 前言 二叉树遍历也分为两种 广度优先遍历&#xff08;Breadth-first order&#xff09;&#xff1a;尽可能先访问距离根最近的节点&#x…

未成年人保护成为《蛋仔派对》最高优先级工作,与家长携手保护孩子健康成长

《蛋仔派对》于近日发布致家长的第二封信&#xff0c;信中向社会各界公布了正在推出的三大“防沉迷”举措&#xff0c;严防“冒用成年人账号”等行为&#xff0c;针对家长关心的未成年防沉迷、冒用成年人账号、渠道服充值退款难等问题进行回应。 《蛋仔派对》表示始终把未成年…

多窗口文件管理工具Q-Dir安装以及使用教程

软件介绍 Q-Dir 是一款功能强大的Windows资源管理器&#xff0c;可以非常方便的管理你的各种文件。Q-Dir有4 个窗口&#xff0c;特别适用于频繁在各个目录间跳跃复制粘贴的情况&#xff0c;每个窗口都可以方便的切换目录&#xff0c;以不同颜色区分不同类型的文件&#xff0c;…

分销电商结算设计

概述 分销电商中涉及支付与结算&#xff1b;支付职责是收钱&#xff0c;结算则是出钱给各利益方&#xff1b; 结算核心围绕业务模式涉及哪些费用&#xff0c;以及这些费用什么时候通过什么出资渠道&#xff0c;由谁给到收方利益方&#xff1b; 结算要素组成费用项结算周期出…

持续集成交付CICD:Jenkins配置Nexus制品上传流水线

目录 一、实验 1.Jenkins配置制品上传流水线 二、问题 1.上传制品显示名称有误 一、实验 1.Jenkins配置制品上传流水线 (1) 新建流水线项目 &#xff08;2&#xff09;描述 &#xff08;3&#xff09;添加参数 &#xff08;4&#xff09;查看构建首页 &#xff08;5&…

体验一下使用 ArkUI 进行 HarmonyOS 开发并与 Compose 简单对比

前言 最近几年各个技术公众号和技术群都在唱衰原生安卓开发&#xff0c;疯狂贩卖焦虑。 搞得我也焦虑的不行&#xff0c;在谷歌的 Compose 推出后就赶紧去学&#xff0c;但是又觉得好像 Compose 的热度也不算太高&#xff0c;又去学 Flutter 。 转头两个都还没学明白呢&…

机器学习入门笔记

文章目录 背景具体步骤1.环境搭建2.写个demo1.数据处理2.分割数据集3.用模型训练数据&#xff0c;并得到预测结果4.绘制结果5.评估 背景 最近学习了一些关于机器学习的内容&#xff0c;做个笔记。 具体步骤 1.环境搭建 需要用到的工具&#xff1a;pycharm&#xff0c;anaco…