循环神经网络RNN原理与优化

目录

前言

RNN背景

RNN原理

上半部分:RNN结构及按时间线展开图

下半部分:RNN在不同时刻的网络连接和计算过程

LSTM

RNN存在的问题

LSTM的结构与原理

数学表达层面

与RNN对比优势

应用场景拓展

从简易但严谨的代码来看RNN和LSTM

RNN

LSTM


前言

绕循环神经网络(RNN)、注意力机制(Attention)以及相关模型(如 LSTM、Transformer、BERT、GPT 等)在深度学习中的应用展开,介绍了其原理、结构、算法流程和实际应用场景。

RNN背景

RNN 产生的原因深度神经网络(DNN)在处理输入时,每个输入之间相互独立,无法处理序列信息。然而在自然语言处理(NLP)和视频处理等任务中,需要考虑输入元素之间的关联性,因此引入 RNN。以 NLP 中的词性标注任务为例,需处理单词序列才能准确标注词性,仅单独理解每个单词是不够的。

RNN 的结构与公式:RNN 在结构上引入了循环层,其隐藏层状态St不仅取决于当前输入Xt,还与上一时刻的隐藏层状态相关。具体公式为

输出

这样结构使RNN网络能够序列信息进行处理

RNN原理

上半部分:RNN结构及按时间线展开图

RNN结构:

输入层(Input Layer):标记为“x”,接收输入数据。

隐藏层(Hidden Layer):标记为“s”,是RNN的核心部分,包含循环连接。图中显示了权重矩阵“U”(连接输入层和隐藏层)和“W”(隐藏层的循环连接)(输入层的)

输出层(Output Layer):标记为“o”,通过权重矩阵“V”与隐藏层相连(输出层的),产生最终输出。

按时间线展开:

将RNN在时间维度上展开,展示了不同时刻(t-1, t, t+1)的网络状态。每个时刻都有输入x_t、隐藏层状态s_t和输出o_t。权重矩阵“U”、“W”和“V”在不同时刻保持不变,体现了RNN在时间上共享参数的特性。

下半部分:RNN在不同时刻的网络连接和计算过程

t-1时刻

展示了隐藏层状态s的向量形式,s=[s1, s2, ..., sn],其中每个元素代表隐藏层的一个神经元状态。权重矩阵“W”连接了t-1时刻的隐藏层神经元。

t时刻

输入层:输入向量X=[x1, x2, ..., xm],其中m是输入维度。

隐藏层:通过权重矩阵“U”接收输入层的信息,并通过权重矩阵“W”接收t-1时刻的隐藏层状态信息。图中显示了隐藏层的计算过程,即

其中f是激活函数。

输出层:根据隐藏层状态S_t,通过权重矩阵“V”计算输出

其中g是输出层的激活函数。

LSTM

LSTM(Long - Short - Term Memory,长短期记忆网络)是为解决传统循环神经网络(RNN)存在的问题而设计的。

RNN存在的问题

RNN有两个主要问题。一是短期记忆问题,当处理足够长的序列时,它难以将早期时间步的信息传递到后期。比如处理一段文本进行预测时,可能会遗漏开头的重要信息。二是梯度消失问题,在反向传播过程中,梯度随着时间反向传播而缩小。当梯度值变得极小,对神经网络权重更新的贡献就很小,导致早期的层停止学习,这也使得RNN在处理长序列时容易遗忘之前的信息。

LSTM的结构与原理

输入

当前时刻输出保存当前细胞状态(传递给下一个‘细胞’)

LSTM通过引入“细胞状态(cell state)”和“门(gate)”机制来解决上述问题:

细胞状态:就像一条传送带,在整个网络中运行,它可以在序列的不同时间步之间传递信息,使得LSTM能够处理长序列而不容易丢失早期信息。

门:

遗忘门(forget gate):决定从细胞状态中丢弃哪些信息。它读取当前输入和上一时刻隐藏状态,输出一个0 - 1之间的值,1表示“完全保留”,0表示“完全丢弃”。

输入门(input gate):确定要在细胞状态中存储哪些新信息。它包含一个sigmoid层来决定更新哪些值,以及一个tanh层来创建新的候选值向量,这些候选值可能会被添加到细胞状态中。

输出门(output gate):确定LSTM的输出。它首先通过sigmoid层决定细胞状态的哪些部分将被输出,然后将细胞状态通过tanh层(将值映射到 - 1到1之间),并将其与sigmoid层的输出相乘,得到最终的输出。

通过这些机制,LSTM能够更好地处理长序列数据,有选择性地记忆和遗忘信息,有效克服了RNN的短期记忆和梯度消失问题,这也是LSTM在后续的一些自然语言处理、语音识别等领域得到广泛应用的主要原因。

数学表达层面

遗忘门计算:

,其中W_f是权重矩阵,[h_{t - 1},x_t]是上一时刻隐藏状态和当前输入的拼接(‘细胞’传递),b_f是偏置项(截距),sigma是sigmoid激活函数,输出值在0 - 1之间,决定从细胞状态中遗忘的信息比例。

输入门计算:

确定更新值比例,

生成候选值向量,二者后续用于更新细胞状态。

细胞状态更新:

是逐元素相乘,即结合遗忘门输出、上一时刻细胞状态、输入门输出和候选值向量来更新细胞状态。

输出门计算:

决定输出比例,

得到最终隐藏状态输出。

与RNN对比优势

长期依赖处理:RNN受限于梯度消失难以保持长期依赖,LSTM通过门控机制控制细胞状态信息流,能有效保存和传递长距离信息,比如在处理长篇小说文本时,可记住开头人物关系等信息用于后续情节理解和生成。

学习效率:RNN因梯度问题早期层学习困难,LSTM通过门控灵活控制信息流动,更高效学习,在训练时间和收敛速度上表现更好,在语音识别任务中,可更快学习到语音序列中的特征模式。

应用场景拓展

自然语言处理:除常见的文本生成、机器翻译、情感分析,在文本摘要提取中,能抓住长文本关键信息;在命名实体识别中,准确识别不同类型实体。

时间序列预测:在金融领域,预测股票价格、汇率等波动;在能源领域,预测电力负荷、能源消耗等,利用其对时间序列中长短期信息的捕捉能力提高预测准确性。

视频处理:分析视频帧序列,用于动作识别、视频内容理解与生成,如判断视频中人物动作类别,生成符合逻辑的视频字幕等。

从简易但严谨的代码来看RNN和LSTM

通过pytorch框架定义只有一个’细胞RNNLSTM,进一步理解这两个网络架构应用

RNN

import torch
import torch.nn as nn

# 定义RNN模型
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out
        
# 示例参数
input_size = 10
hidden_size = 20
num_layers = 1
output_size = 5
batch_size = 3
seq_length = 8

# 创建输入数据
x = torch.randn(batch_size, seq_length, input_size)

# 实例化RNN模型
model = SimpleRNN(input_size, hidden_size, num_layers, output_size)

# 前向传播
output = model(x)
print(output.shape)

说明

定义了一个简单的SimpleRNN类继承自nn.Module。在构造函数中,初始化了 RNN 层和全连接层。nn.RNN指定了输入维度input_size、隐藏层维度hidden_size、层数num_layers,并设置batch_first=True表示输入数据的形状为(batch_size, seq_length, input_size)。

forward方法中,首先初始化隐藏状态h0,然后将输入数据x和初始隐藏状态传入 RNN 层,获取输出out。最后将 RNN 最后一个时间步的输出传入全连接层得到最终输出。

LSTM

import torch
import torch.nn as nn

# 定义LSTM模型
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out
        
# 示例参数
input_size = 10
hidden_size = 20
num_layers = 1
output_size = 5
batch_size = 3
seq_length = 8

# 创建输入数据
x = torch.randn(batch_size, seq_length, input_size)

# 实例化LSTM模型
model = SimpleLSTM(input_size, hidden_size, num_layers, output_size)

# 前向传播
output = model(x)
print(output.shape)

说明

1.定义了SimpleLSTM类,同样继承自nn.Module。构造函数中初始化了 LSTM 层和全连接层,nn.LSTM的参数设置与 RNN 类似。

2.forward方法里,除了初始化隐藏状态h0,还初始化了细胞状态c0,然后将输入x、h0和c0传入 LSTM 层,获取输出out,最后经全连接层得到最终结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/973209.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mac arm架构使用 Yarn 全局安装 Vue CLI

dgqdgqdeMacBook-Pro spid-admin % vue --version zsh: command not found: vue要使用 Yarn 安装 Vue CLI,你可以执行以下命令: yarn global add vue/cli这个命令会全局安装 Vue CLI,让你可以使用 vue 命令创建、管理 Vue.js 项目。以下是一…

TensorFlow深度学习实战(8)——卷积神经网络

TensorFlow深度学习实战(8)——卷积神经网络 0. 前言1. 全连接网络的缺陷2. 卷积神经网络2.1 卷积神经网络的基本概念2.2 TensorFlow 中的卷积层2.3 TensorFlow 中的池化层2.4 卷积神经网络总结 3. 构建卷积神经网络3.1 LeNet3.2 使用 TensorFlow 实现 L…

.NET + Vue3 的前后端项目在IIS的发布

目录 一、发布准备 1、安装 IIS 2、安装 Windows Hosting Bundle(.NET Core 托管捆绑包) 3、安装 IIS URL Rewrite 二、项目发布 1、后端项目发布 2、前端项目发布 3、将项目部署到 IIS中 三、网站配置 1、IP配置 2、防火墙配置 3、跨域配置…

电脑想安装 Windows 11 需要开启 TPM 2.0 怎么办?

尽管 TPM 2.0 已经内置在许多新电脑中,但很多人并不知道如何激活这一功能,甚至完全忽略了它的存在。其实,只需简单的几步操作,你就能开启这项强大的安全特性,为你的数字生活增添一层坚固的防护屏障。无论你是普通用户还…

嵌入式开发岗位认识

目录 1.核心定义2.岗位方向3.行业方向4.技术方向5.工作职责6.核心技能7.等级标准8.优势与劣势9.市场薪资10. 发展路径11. 市场趋势12. 技术趋势 1.核心定义 嵌入式系统: 以应用为中心,以计算机技术为基础,软硬件可裁剪的专用计算机系统 特点…

爱普生 SG-8101CE 可编程晶振在笔记本电脑的应用

在笔记本电脑的精密架构中,每一个微小的元件都如同精密仪器中的齿轮,虽小却对整体性能起着关键作用。如今的笔记本电脑早已不再局限于简单的办公用途,其功能愈发丰富多样。从日常轻松的文字处理、网页浏览,到专业领域中对图形处理…

Python VsCode DeepSeek接入

Python VsCode DeepSeek接入 创建API key 首先进入DeepSeek官网,https://www.deepseek.com/ 点击左侧“API Keys”,创建API key,输出名称为“AI” 点击“创建",将API key保存,复制在其它地方。 在VsCode中下载…

基于eBPF的全栈可观测性系统:重新定义云原生环境诊断范式

引言:突破传统APM的性能桎梏 某头部电商平台采用eBPF重构可观测体系后,生产环境指标采集性能提升327倍:百万QPS场景下传统代理模式CPU占用达63%,而eBPF直采方案仅消耗0.9%内核资源。核心业务的全链路追踪时延从900μs降至18μs&a…

java项目之风顺农场供销一体系统的设计与实现(源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的风顺农场供销一体系统的设计与实现。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 风顺农场供销…

Spring MVC 的核心以及执行流程

Spring MVC的核心 Spring MVC是Spring框架中的一个重要模块,它采用了经典的MVC(Model-View-Controller)设计模式。 MVC是一种软件架构的思想,它将软件按照模型(Model)、视图(View)…

SQLMesh 系列教程6- 详解 Python 模型

本文将介绍 SQLMesh 的 Python 模型,探讨其定义、优势及在企业业务场景中的应用。SQLMesh 不仅支持 SQL 模型,还允许通过 Python 编写数据模型,提供更高的灵活性和可编程性。我们将通过一个电商平台的实例,展示如何使用 Python 模…

苍穹外卖知识点

导入依赖 Component Aspect public class MyselfAspect{Before("excution(* com.services.*.(..))")public myBefore(JointPoint jointPoint){System.out.println("在前面执行");} }只要注意如何使用Before注解就行了,里面存放的是*&#xff…

MySQL系列之身份鉴别(安全)

导览 前言Q:如何保障MySQL数据库身份鉴别的有效性一、有效性检查 1. 用户唯一2. 启用密码验证3. 是否存在空口令用户4. 是否启用口令复杂度校验5. 是否设置口令的有效期6. 是否限制登录失败尝试次数7. 是否设置(超过尝试次数)锁定的最小时长…

OneNote手机/平板“更多笔记本”中有许多已经删掉或改名的,如何删除

问题描述: OneNote 在手机或平板上添加“更多笔记本”中,有许多已经删掉或改名的笔记本!如何删除? OR:如何彻底删除OneNote中的笔记本? 处理做法: 这个列表对应365里面的【最近打开】&#…

区块链共识机制深度揭秘:从PoW到PoS,谁能主宰未来?

区块链的技术背后,最大的挑战之一就是如何让多个分布在全球各地的节点在没有中心化管理者的情况下达成一致,确保数据的一致性和安全性。这一切都依赖于区块链的核心——共识机制。共识机制不仅决定了区块链的安全性、效率和去中心化程度,还对…

观察者模式说明(C语言版本)

观察者模式主要是为了实现一种一对多的依赖关系,让多个观察者对象同时监听某一个主题对象。这个主题对象在状态发生变化时,会通知所有观察者对象,使它们能够自动更新自己。下面使用C语言实现了一个具体的应用示例,有需要的可以参考…

Linux System V - 消息队列与责任链模式

概念 消息队列是一种以消息为单位的进程间通信机制,允许一个或多个进程向队列中发送消息,同时允许一个或多个进程从队列中接收消息。消息队列由内核维护,具有以下特点: 异步通信:发送方和接收方不需要同时运行&#x…

微信小程序客服消息接收不到微信的回调

微信小程序客服消息,可以接收到用户进入会话事件的回调,但是接收不到用户发送消息的回调接口。需要在微信公众平台,把转发消息给客服的开关关闭。需要把这个开关关闭,否则消息会直接发送给设置的客服,并不会走设置的回…

pycharm社区版有个window和arm64版本,到底下载哪一个?还有pycharm官网

首先pycharm官网是这一个。我是在2025年2月16日9:57进入的网站。如果网站还没有更新的话,那么就往下滑一下找到 community Edition,这个就是社区版了免费的。PyCharm:适用于数据科学和 Web 开发的 Python IDE 适用于数据科学和 Web 开发的 Python IDE&am…

风险价值VaR、CVaR与ES

风险价值VaR、CVaR与ES 一、VaR风险价值1. VaR的定义及基本概念2.VaR的主要性质3.风险价值的优缺点 二、CVaR条件风险价值与ES预期损失1.CVaR的基本概念2.性质3.ES预期损失 一、VaR风险价值 1. VaR的定义及基本概念 20年前,JP的大佬要每天下午收盘后的4:15在桌上看…