循环神经网络的变体模型-LSTM、GRU

一.LSTM(长短时记忆网络)

1.1基本介绍

长短时记忆网络(Long Short-Term Memory,LSTM)是一种深度学习模型,属于循环神经网络(Recurrent Neural Network,RNN)的一种变体。LSTM的设计旨在解决传统RNN中遇到的长序列依赖问题,以更好地捕捉和处理序列数据中的长期依赖关系。

下面是LSTM的内部结构图

LSTM

LSTM为了改善梯度消失,引入了一种特殊的存储单元,该存储单元被设计用于存储和提取长期记忆。与传统的RNN不同,LSTM包含三个关键的门(gate)来控制信息的流动,这些门分别是遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。

LSTM的结构允许它有效地处理和学习序列中的长期依赖关系,这在许多任务中很有用,如自然语言处理、语音识别和时间序列预测。由于其能捕获长期记忆,LSTM成为深度学习中重要的组件之一。

1.2 主要组成部分和工作原理

首先我们先弄明白LSTM单元中的每个符号的含义。每个黄色方框表示一个神经网络层,由权值,偏置以及激活函数组成;每个粉色圆圈表示元素级别操作;箭头表示向量流向;相交的箭头表示向量的拼接;分叉的箭头表示向量的复制。
图中元素的节点信息

以下是LSTM的主要组成部分和工作原理:

  1. 细胞状态(Cell State):
    细胞状态是LSTM网络的主要存储单元,用于存储和传递长期记忆。细胞状态在序列的每一步都会被更新。在LSTM中,细胞状态负责保留网络需要记住的信息,以便更好地处理长期依赖关系。在每个时间步,LSTM通过一系列的操作来更新细胞状态。这些操作包括遗忘门、输入门和输出门的计算。细胞状态在这些门的帮助下动态地保留和遗忘信息。
    细胞状态

  2. 遗忘门(Forget Gate):
    遗忘门决定哪些信息应该被遗忘,从而允许网络丢弃不重要的信息。它通过一个sigmoid激活函数生成一个介于0和1之间的值,用于控制细胞状态中信息的丢失程度。
    遗忘门的计算过程如下:
    2.1 输入:
    上一时刻的隐藏状态(或者是输入数据的向量)
    当前时刻的输入数据
    2.2 计算遗忘门的值:
    将上一时刻的隐藏状态和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层(通常称为遗忘门层)得到介于0和1之间的值。
    这个值表示细胞状态中哪些信息应该被保留(接近1),哪些信息应该被遗忘(接近0)。
    2.3 遗忘操作:
    将上一时刻的细胞状态与遗忘门的输出相乘,以决定保留哪些信息。
    2.4数学表达式如下:
    遗忘门的输出:
    遗忘门

其中:
W f 和 b f 是遗忘门的权重矩阵和偏置向量。 W_f 和 b_f是遗忘门的权重矩阵和偏置向量。 Wfbf是遗忘门的权重矩阵和偏置向量。
h t − 1 ​是上一时刻的隐藏状态。 h_{t−1}​ 是上一时刻的隐藏状态。 ht1是上一时刻的隐藏状态。
x t 是当前时刻的输入数据。 x_t是当前时刻的输入数据。 xt是当前时刻的输入数据。
σ 是 s i g m o i d 激活函数。 σ 是sigmoid激活函数。 σsigmoid激活函数。

遗忘门的输出 ft 决定了细胞状态中上一时刻信息的保留程度。这个机制允许LSTM网络在处理时间序列数据时更有效地记住长期依赖关系。

  1. 输入门(Input Gate):
    输入门负责确定在当前时间步骤中要添加到细胞状态的新信息。类似于遗忘门,输入门使用sigmoid激活函数产生一个介于0和1之间的值,表示要保留多少新信息,并使用tanh激活函数生成一个新的候选值。
    在这里插入图片描述输入门的计算过程如下:
(1)输入门的输出计算:
    将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示要保留的新信息的程度。
(2)生成新的候选值:
	将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。
	通过一个带有tanh激活函数的全连接层得到一个新的候选值(介于-1和1之间)。
(3)更新细胞状态的操作:
	将输入门的输出与新的候选值相乘,得到要添加到细胞状态的新信息。
  1. 输出门(Output Gate):
    输出门(Output Gate)在LSTM中控制细胞在特定时间步上的输出。输出门使用sigmoid激活函数产生介于0和1之间的值,这个值决定了在当前时间步细胞状态中有多少信息被输出。同时,输出门的输出与细胞状态经过tanh激活函数后的值相乘,产生最终的LSTM输出。

输出门的计算过程如下:

输出门的输出计算:
    将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。
    这个值表示在当前时间步细胞状态中有多少信息要输出。
生成最终的LSTM输出:
	将当前时刻的细胞状态经过tanh激活函数,得到介于-1和1之间的值。
	将输出门的输出与tanh激活函数的细胞状态相乘,产生最终的LSTM输出。

在这里插入图片描述

1.3 LSTM的基础代码实现

以下是一个基础的实现,其中包括多层双向LSTM的前向传播。请注意,这个实现仍然是一个简化版本,实际应用中可能需要更多的调整和优化。

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def lstm_cell(xt, a_prev, c_prev, parameters):
    # 从参数中提取权重和偏置
    Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]

    # 合并输入和上一个时间步的隐藏状态
    concat = np.concatenate((a_prev, xt), axis=0)

    # 遗忘门
    ft = sigmoid(np.dot(Wf, concat) + bf)
    
    # 输入门
    it = sigmoid(np.dot(Wi, concat) + bi)
    
    # 更新细胞状态
    cct = tanh(np.dot(Wc, concat) + bc)
    c_next = ft * c_prev + it * cct
    
    # 输出门
    ot = sigmoid(np.dot(Wo, concat) + bo)
    
    # 更新隐藏状态
    a_next = ot * tanh(c_next)
    
    # 保存计算中间结果,以便反向传播
    cache = (xt, a_prev, c_prev, a_next, c_next, ft, it, ot, cct)
    
    return a_next, c_next, cache

def lstm_forward(x, a0, parameters):
    n_x, m, T_x = x.shape
    n_a = a0.shape[0]
    a = np.zeros((n_a, m, T_x))
    c = np.zeros_like(a)
    caches = []
    
    a_prev = a0
    c_prev = np.zeros_like(a_prev)
    
    for t in range(T_x):
        xt = x[:, :, t]
        a_next, c_next, cache = lstm_cell(xt, a_prev, c_prev, parameters)
        a[:,:,t] = a_next
        c[:,:,t] = c_next
        caches.append(cache)
        a_prev = a_next
        c_prev = c_next
    
    return a, c, caches

def lstm_model_forward(x, parameters):
    caches = []
    a = x
    c_list = []
    
    for layer in parameters:
        a, c, layer_cache = lstm_forward(a, np.zeros_like(a[:, :, 0]), layer)
        caches.append(layer_cache)
        c_list.append(c)
    
    return a, c_list, caches

def dense_layer_forward(a, parameters):
    W = parameters["W"]
    b = parameters["b"]
    z = np.dot(W, a) + b
    a_next = sigmoid(z)
    return a_next, z

def model_forward(x, parameters_lstm, parameters_dense):
    a_lstm, c_list, caches_lstm = lstm_model_forward(x, parameters_lstm)
    
    a_dense = a_lstm[:, :, -1]
    z_dense_list = []
    
    for layer_dense in parameters_dense:
        a_dense, z_dense = dense_layer_forward(a_dense, layer_dense)
        z_dense_list.append(z_dense)
    
    return a_dense, c_list, caches_lstm, z_dense_list

# 示例数据和参数
np.random.seed(1)
x = np.random.randn(10, 5, 3)  # 10个样本,每个样本5个时间步,每个时间步3个特征

# LSTM参数
parameters_lstm = [
    {"Wf": np.random.randn(5, 8), "bf": np.random.randn(5, 1),
     "Wi": np.random.randn(5, 8), "bi": np.random.randn(5, 1),
     "Wo": np.random.randn(5, 8), "bo": np.random.randn(5, 1),
     "Wc": np.random.randn(5, 8), "bc": np.random.randn(5, 1)},
    {"Wf": np.random.randn(3, 8), "bf": np.random.randn(3, 1),
     "Wi": np.random.randn(3, 8), "bi": np.random.randn(3, 1),
     "Wo": np.random.randn(3, 8), "bo": np.random.randn(3, 1),
     "Wc": np.random.randn(3, 8), "bc": np.random.randn(3, 1)}
]

# Dense层参数
parameters_dense = [
    {"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)},
    {"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)}
]

# 进行正向传播
a_dense, c_list, caches_lstm, z_dense_list = model_forward(x, parameters_lstm, parameters_dense)

# 打印输出形状
print("a_dense.shape:", a_dense.shape)

二.GRU(门控循环单元)

GRU

2.1 GRU的基本介绍

门控循环单元(GRU,Gated Recurrent Unit)是一种用于处理序列数据的循环神经网络(RNN)变体,旨在解决传统RNN中的梯度消失问题,并提供更好的长期依赖建模。GRU引入了门控机制,类似于LSTM,但相对于LSTM,GRU结构更加简单。

GRU包含两个门:更新门(Update Gate)和重置门(Reset Gate)。这两个门允许GRU网络决定在当前时间步更新细胞状态的程度以及如何利用先前的隐藏状态。

重置门(Reset Gate)的计算:

通过一个sigmoid激活函数计算重置门的输出。重置门决定了在当前时间步,应该忽略多少先前的隐藏状态信息。

更新门(Update Gate)的计算:

通过一个sigmoid激活函数计算更新门的输出。更新门决定了在当前时间步,应该保留多少先前的隐藏状态信息。

候选隐藏状态的计算:

通过tanh激活函数计算一个候选的隐藏状态。

新的隐藏状态的计算:

通过更新门和候选隐藏状态计算新的隐藏状态。

2.2 GRU的代码实现

以下是使用PyTorch库实现基本的门控循环单元(GRU)的代码。PyTorch提供了GRU的高级API,可以轻松实现和使用。下面是一个简单的例子:

import torch
import torch.nn as nn

# 定义GRU模型
class SimpleGRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(SimpleGRU, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size)

    def forward(self, x, hidden=None):
        output, hidden = self.gru(x, hidden)
        return output, hidden

# 示例数据和模型参数
input_size = 3
hidden_size = 5
seq_len = 1  # 序列长度
batch_size = 1

# 创建GRU模型
gru_model = SimpleGRU(input_size, hidden_size)

# 将输入数据转换为PyTorch的Tensor
x = torch.randn(seq_len, batch_size, input_size)

# 前向传播
output, hidden = gru_model(x)

# 打印输出形状
print("Output shape:", output.shape)
print("Hidden shape:", hidden.shape)

以下是使用NumPy库实现基本的门控循环单元(GRU)的代码。这个实现是一个简化版本,其中包含更新门和重置门的计算,以及候选隐藏状态和新的隐藏状态的计算。

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def gru_cell(a_prev, x, parameters):
    # 从参数中提取权重和偏置
    W_r = parameters["W_r"]
    b_r = parameters["b_r"]
    W_z = parameters["W_z"]
    b_z = parameters["b_z"]
    W_a = parameters["W_a"]
    b_a = parameters["b_a"]

    # 计算重置门
    r_t = sigmoid(np.dot(W_r, np.concatenate([a_prev, x])) + b_r)

    # 计算更新门
    z_t = sigmoid(np.dot(W_z, np.concatenate([a_prev, x])) + b_z)

    # 计算候选隐藏状态
    tilde_a_t = tanh(np.dot(W_a, np.concatenate([r_t * a_prev, x])) + b_a)

    # 计算新的隐藏状态
    a_t = (1 - z_t) * a_prev + z_t * tilde_a_t

    # 保存计算中间结果,以便反向传播
    cache = (a_prev, x, r_t, z_t, tilde_a_t, a_t)

    return a_t, cache

# 示例数据和参数
np.random.seed(1)
a_prev = np.random.randn(5, 1)  # 上一时刻的隐藏状态
x = np.random.randn(3, 1)  # 当前时刻的输入数据

# GRU参数
parameters = {
    "W_r": np.random.randn(5, 8),
    "b_r": np.random.randn(5, 1),
    "W_z": np.random.randn(5, 8),
    "b_z": np.random.randn(5, 1),
    "W_a": np.random.randn(5, 8),
    "b_a": np.random.randn(5, 1)
}

# 单个GRU单元的前向传播
a_t, cache = gru_cell(a_prev, x, parameters)

# 打印输出形状
print("a_t.shape:", a_t.shape)

本文参考了以下链接:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/330281.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发(四)UIAbility和Page交互

通过上一篇的学习,相信大家对UIAbility已经有了初步的认知。在上篇中,我们最后实现了一个小demo,从一个UIAbility调起了另外一个UIAbility。当时我提到过,暂不实现比如点击EntryAbility中的控件去触发跳转,而是在Entry…

自动驾驶预测-决策-规划-控制学习(5):图像分割与语义分割入门

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 论文题目:Evolution of Image Segmentation using Deep Convolutional Neural Network: A Survey前言:图像分割与语义分割一、图像分割是什么…

vue3 实现简单计数器示例——一个html文件展示vue3的效果

目的&#xff1a;作为一个新手开发&#xff0c;我想使用 Vue 3 将代码封装在 HTML 文件中时&#xff0c;进行界面打开展示。 一、vue计数示例 学了一个简单计数器界面展示&#xff0c;代码如下&#xff1a; <!DOCTYPE html> <html lang"en"><head&…

嵌入式-Stm32-江科大基于标准库的GPIO的八种模式

文章目录 一&#xff1a;GPIO输入输出原理二&#xff1a;GPIO基本结构三&#xff1a;GPIO位结构四&#xff1a;GPIO的八种模式道友&#xff1a;相信别人&#xff0c;更要一百倍地相信自己。 &#xff08;推荐先看文章&#xff1a;《 嵌入式-32单片机-GPIO推挽输出和开漏输出》…

宏集干货丨探索物联网HMI的端口转发和NAT功能

来源&#xff1a;宏集科技 工业物联网 宏集干货丨探索物联网HMI的端口转发和NAT功能 原文链接&#xff1a;https://mp.weixin.qq.com/s/zF2OqkiGnIME6sov55cGTQ 欢迎关注虹科&#xff0c;为您提供最新资讯&#xff01; #工业自动化 #工业物联网 #HMI 前 言 端口转发和NAT功…

Qt纯代码实现UI界面

1.相关信息 设置编辑框内容的字体样式&#xff0c;包括加粗、下划线、斜体、蓝色、红色、黑色 2.界面展示 3.相关代码 #include "dialog.h" #include <QHBoxLayout> #include <QVBoxLayout> #include <QCheckBox> #include <QRadioButton> …

【软件测试学习笔记6】Linux常用命令

格式 command [-options] [parameter] command 表示的是命令的名称 []表示是可选的&#xff0c;可有可无 [-options]&#xff1a;表示的是命令的选项&#xff0c;可有一个或多个&#xff0c;也可以没有 [parameter]&#xff1a;表示命令的参数&#xff0c;可以有一个或多…

清晰光谱空间:全自动可调波长系统的高光谱成像优势

高光谱成像技术 高光谱成像技术是一种捕获和分析宽波长信息的技术&#xff0c;能够对材料和特征进行详细的光谱分析和识别。高光谱成像技术的实现通过高光谱相机&#xff0c;其工作原理是使用多个光学传感器或光学滤波器分离不同波长的光&#xff0c;并捕获每个波段的图像&…

前端:布局(用于div中有多行元素,一行只显示四个,最左或最右要紧贴父div,最顶层和最底层也要紧贴父div)

效果 一、flex实现 html <!DOCTYPE html> <html><head><title>Flexbox Layout</title><style>.container {display: flex;flex-wrap: wrap;justify-content: space-between;gap: 10px;border: 1px solid red;}.box {flex: 1 0 calc(25% …

rsync全面讲解

rsync 是一个常用的 Linux 应用程序&#xff0c;用于文件同步。 它可以在本地计算机与远程计算机之间&#xff0c;或者两个本地目录之间同步文件&#xff08;但不支持两台远程计算机之间的同步&#xff09;。它也可以当作文件复制工具&#xff0c;替代cp和mv命令。 它名称里面…

逆向使用webpack打包的网站

webpack webpack 是 JavaScript 应用程序的模块打包器,可以把开发中的所有资源&#xff08;图片、js文件、css文件等&#xff09;都看成模块&#xff0c;通过loader&#xff08;加载器&#xff09;和 plugins &#xff08;插件&#xff09;对资源进行处理&#xff0c;打包成符…

JRTP实时音视频传输(2)-使用TCP通信的案例

1.创建自己的demo 先将example1拷贝为myclienttcp.cpp和myservertcp.cpp cp example1.cpp myclienttcp.cpp cp example1.cpp myservertcp.cpp 改写jrtplib/JRTPLIB/examples/CMakeLists.txt&#xff0c;添加myclienttcp和myservertcp编译 重新生成Makefile并编译 sudo cmak…

plc红绿灯程序

引言&#xff1a; PLC&#xff08;Programmable Logic Controller&#xff0c;可编程逻辑控制器&#xff09;是一种用于工业自动化控制的电子设备。西门子的SIMATIC S7-200是这类设备的一个流行系列&#xff0c;广泛应用于小型至中等规模的自动化项目中。它具有以下特点&#…

pytorch学习(一)线性模型

文章目录 线性模型 pytorch是一个基础的python的科学计算库&#xff0c;它有以下特点&#xff1a; 类似于numpy&#xff0c;但是它可以使用GPU可以用它来定义深度学习模型&#xff0c;可以灵活的进行深度学习模型的训练和使用 线性模型 线性模型的基本形式为&#xff1a; f ( x…

推荐一款性价比高的USB 协议分析仪

最近在入门学习USB 协议&#xff0c;USB 协议是出了名的晦涩难懂&#xff0c;调试过程中如果没有合适的工具帮助分析&#xff0c;就像电工没有电表笔一样&#xff0c;难以诊断各种奇难杂症。 于是网上找了一下USB 协议分析仪&#xff0c;一看价格超过3位数的就不考虑了&#x…

Java关键字static和final

一、final关键字是什么&#xff1f; 1、final可以用来修饰的结构&#xff1a;类、方法、变量 2、final用来修饰一个类&#xff1a;此类不能被其它类继承。当我们需要让一个类永远不被继承&#xff0c;此时就可以用final修饰&#xff0c;但要注意&#xff1a;final类中所有的成…

ArcGIS Pro 如何新建布局

你是否已经习惯了在ArcGIS中数据视图和布局视图之间来回切换&#xff0c;到了ArcGIS Pro中却找不到二者之间切换的按钮&#xff0c;即使新建布局后却发现地图怎么却是一片空白。 这一切的一切都是因为ArcGIS Pro的功能框架完全不同&#xff0c;这里为大家介绍一下在ArcGIS Pro…

微信小程序(五)下拉刷新

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1. 下拉刷新 2. 下拉页面背景颜色 3. 设置是否可滚动 4. 设置导航栏模式 源码&#xff1a;(实际上不能加注释但这里为了方便解释就加上了) index.json {//默认模式&#xff0c;另一种自定义模式是custom//自定义…

课表排课小程序怎么制作?多少钱?

在当今的数字化时代&#xff0c;无论是购物、支付、点餐&#xff0c;还是工作、学习&#xff0c;都离不开各种各样的微信小程序。其中&#xff0c;课表排课小程序就是许多教育机构和学校必不可少的工具。那么课表排课小程序怎么制作呢&#xff1f;又需要多少钱呢&#xff1f; …

RK3399平台入门到精通系列讲解(USB篇)UDC 层 usb_gadget_probe_driver 接口分析

🚀返回总目录 文章目录 一、UDC:usb_gadget_probe_driver函数分析二、usb_gadget_driver 结构详细介绍三、usb_udc 结构详细介绍一、UDC:usb_gadget_probe_driver函数分析 UDC层的一项基本任务是向上层提供usb_gadget_probe_driver()接口函数。 上层调用者为composite.c中…