# 深入理解RNN(一):循环神经网络的核心计算机制

深入理解RNN:循环神经网络的核心计算机制

RNN示意图
在这里插入图片描述

引言

在自然语言处理、时间序列预测、语音识别等涉及序列数据的领域,循环神经网络(RNN)一直扮演着核心角色。尽管近年来Transformer等架构逐渐成为主流,RNN的基本原理和思想依然对于理解深度学习处理序列数据的方式至关重要。本文将深入剖析RNN的核心计算机制,通过公式、代码和直观解释,帮助读者真正掌握这一经典算法的内在逻辑。

RNN的基本思想

传统前馈神经网络的主要局限在于:它无法处理序列数据中的时序依赖关系。每个输入被视为独立的个体,网络无法"记住"之前看到的信息。循环神经网络正是为了解决这一问题而设计的。

RNN的核心思想:通过在网络中引入循环连接,使当前时刻的输出不仅依赖于当前的输入,还依赖于之前时刻的"记忆"。这种设计使得RNN能够保持"状态",从而有效处理序列数据。

RNN的核心计算公式

RNN的计算过程可以用两个关键公式表示:

h t = t a n h ( W h ⋅ h t − 1 + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1} + W_x · x_t + b_h) ht=tanh(Whht1+Wxxt+bh)

先不要被吓到了,脑海中先想到CNN的y=kx+b , CNN里的线性变化。
然后想到y = tan ( kx+b ),引入非线性。
然后就是要引入上一步的信息,所以有了 W_h · h_{t-1} ,所以计算的这里,只是多了上一步的状态信息而已。就想着RNN相比CNN,其实就是多了一个缓冲区,会把上一步的隐藏层的值,加入到这里去计算,

y t = W y ⋅ h t + b y y_t = W_y · h_t + b_y yt=Wyht+by

这是最后输出当前值的步骤,y_t不参与后续的计算,是上一步的隐藏层的信息参与了计算。所以奥秘都在隐藏层里,最后这步的作用你可以理解为和CNN最后的FC层是一个意思

其中:

  • h t h_t ht 是当前时刻t的隐藏状态(即"记忆")
  • h t − 1 h_{t-1} ht1 是前一时刻的隐藏状态
  • x t x_t xt 是当前时刻的输入
  • y t y_t yt 是当前时刻的输出
  • W h W_h Wh, W x W_x Wx, W y W_y Wy 是权重矩阵
  • b h b_h bh, b y b_y by 是偏置项
  • t a n h tanh tanh 是激活函数(也可以使用其他函数如ReLU)

这两个公式理解了还是很简单,基本涵盖了RNN的全部精髓。让我们简要看看每个组成部分的意义。

公式详解:记忆与学习的数学表达

隐藏状态更新(第一个公式)

h t = t a n h ( W h ⋅ h t − 1 + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1} + W_x · x_t + b_h) ht=tanh(Whht1+Wxxt+bh)

这个公式描述了RNN如何更新其"记忆"。我们可以将其拆解为几个关键部分:

  1. 历史信息: W h ⋅ h t − 1 W_h · h_{t-1} Whht1
    • 这部分将前一时刻的隐藏状态 h t − 1 h_{t-1} ht1与权重矩阵 W h W_h Wh相乘
    • W h W_h Wh决定了保留多少历史信息,以及如何将这些信息与当前输入融合
    • 这正是RNN区别于传统神经网络的关键所在

这里可能不好理解,我们仍然可以把 h t − 1 h_{t-1} ht1看成一个变量x,哎,然后这个上一步的隐藏层的信息,我们是不是也要考虑下它如何影响下一步啊,因为每个数/词对下一个数/词的影响肯定是不同的,所以我们也给上一步的信息搞个 k x + b kx+b kx+b,也就是 W h ⋅ h t − 1 + b h ^ W_h · h_{t-1}+b_{\hat{h}} Whht1+bh^,然后放到公式里

h t = t a n h ( W h ⋅ h t − 1 + b h ^ + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1}+b_{\hat{h}}+ W_x · x_t + b_h) ht=tanh(Whht1+bh^+Wxxt+bh)

你一手常数项合并,咔,公式就出来了

h t = t a n h ( W h ⋅ h t − 1 + W x ⋅ x t + b h ) h_t = tanh(W_h · h_{t-1} + W_x · x_t + b_h) ht=tanh(Whht1+Wxxt+bh)

  1. 当前输入: W x ⋅ x t W_x · x_t Wxxt

    • 当前时刻的输入 x t x_t xt与权重矩阵 W x W_x Wx相乘
    • W x W_x Wx决定了网络如何解释当前输入的重要性
  2. 非线性变换: t a n h ( . . . ) tanh(...) tanh(...)

    • 将线性组合通过 t a n h tanh tanh激活函数进行非线性变换
    • t a n h tanh tanh将值压缩到[-1,1]范围,帮助稳定网络动态
    • 这种非线性是神经网络表达复杂函数的关键

输出层计算(第二个公式)

y t = W y ⋅ h t + b y y_t = W_y · h_t + b_y yt=Wyht+by

这个公式描述了RNN如何基于当前隐藏状态生成输出:

  1. 隐藏状态 h t h_t ht包含了直到当前时刻的所有相关信息的"摘要"
  2. 权重矩阵 W y W_y Wy将这个隐藏状态映射到所需的输出维度
  3. 输出 y t y_t yt可以是多种形式,取决于任务类型(如分类概率、预测值等)

没错,另一个 k x + b kx+b kx+b,不是吗?

RNN的维度分析

不用看,实践会告诉你答案,你会在你以后的代码实践中对维度有更深刻的理解

为了更好地理解RNN的计算过程,我们需要明确各个参数的维度:

假设:

  • 输入维度: x t ∈ R d i n x_t \in \mathbb{R}^{d_{in}} xtRdin
  • 隐藏状态维度: h t ∈ R d h h_t \in \mathbb{R}^{d_h} htRdh
  • 输出维度: y t ∈ R d o u t y_t \in \mathbb{R}^{d_{out}} ytRdout

则各权重矩阵的维度为:

  • W x ∈ R d h × d i n W_x \in \mathbb{R}^{d_h \times d_{in}} WxRdh×din
  • W h ∈ R d h × d h W_h \in \mathbb{R}^{d_h \times d_h} WhRdh×dh
  • W y ∈ R d o u t × d h W_y \in \mathbb{R}^{d_{out} \times d_h} WyRdout×dh
  • b h ∈ R d h b_h \in \mathbb{R}^{d_h} bhRdh
  • b y ∈ R d o u t b_y \in \mathbb{R}^{d_{out}} byRdout

这种维度设计确保了矩阵乘法的兼容性,同时也反映了数据在网络中的流动方式。

RNN的直观解释

抛开前面的数学公式,我们可以用更直觉的方式理解RNN的工作原理:

  1. 记忆机制:想象RNN有一个"记事本"(隐藏状态),它会在每个时间步更新这个记事本
  2. 选择性记忆:不是所有信息都同等重要,权重矩阵决定记住什么、忘记什么
  3. 信息混合:RNN将之前的记忆与新的观察结合起来,产生更新的理解
  4. 输出决策:基于当前的"记忆状态",RNN做出当前时刻的判断或预测

Python实现:手写一个简单RNN

让我们通过Python代码实现一个简单的RNN,以更好地理解其计算过程:

import numpy as np

class SimpleRNN:
    def __init__(self, input_size, hidden_size, output_size):
        """初始化RNN参数"""
        # 初始化权重矩阵(使用随机值)
        self.Wx = np.random.randn(hidden_size, input_size) * 0.01  # 输入到隐藏
        self.Wh = np.random.randn(hidden_size, hidden_size) * 0.01  # 隐藏到隐藏
        self.Wy = np.random.randn(output_size, hidden_size) * 0.01  # 隐藏到输出
        
        # 初始化偏置项
        self.bh = np.zeros((hidden_size, 1))  # 隐藏层偏置
        self.by = np.zeros((output_size, 1))  # 输出层偏置
        
        # 保存尺寸信息
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
    
    def forward(self, x_sequence, h0=None):
        """前向传播过程"""
        # x_sequence形状: (seq_length, input_size, 1)
        seq_length = len(x_sequence)
        
        # 如果没有提供初始隐藏状态,则初始化为零
        if h0 is None:
            h0 = np.zeros((self.hidden_size, 1))
        
        # 保存所有时间步的隐藏状态和输出(用于反向传播)
        h = np.zeros((seq_length+1, self.hidden_size, 1))
        y = np.zeros((seq_length, self.output_size, 1))
        
        h[0] = h0  # 设置初始隐藏状态
        
        # 按时间步前向传播
        for t in range(seq_length):
            # 更新隐藏状态: h_t = tanh(W_h·h_{t-1} + W_x·x_t + b_h)
            h[t+1] = np.tanh(
                np.dot(self.Wh, h[t]) + 
                np.dot(self.Wx, x_sequence[t]) + 
                self.bh
            )
            
            # 计算输出: y_t = W_y·h_t + b_y
            y[t] = np.dot(self.Wy, h[t+1]) + self.by
        
        return y, h[1:]  # 返回所有输出和隐藏状态
    
    def predict(self, x_sequence):
        """使用模型进行预测"""
        y, _ = self.forward(x_sequence)
        return y

# 示例:如何使用这个RNN
if __name__ == "__main__":
    # 创建一个输入维度为3,隐藏层大小为5,输出维度为2的RNN
    rnn = SimpleRNN(input_size=3, hidden_size=5, output_size=2)
    
    # 创建一个序列数据:3个时间步,每步是一个3维向量
    seq_data = [
        np.array([[0.1], [0.2], [0.3]]),  # x_1
        np.array([[0.2], [0.3], [0.4]]),  # x_2
        np.array([[0.3], [0.4], [0.5]])   # x_3
    ]
    
    # 前向传播
    outputs, hidden_states = rnn.forward(seq_data)
    
    print("输出序列形状:", len(outputs), "x", outputs[0].shape)
    print("第一个时间步的输出:\n", outputs[0])
    print("最后一个时间步的隐藏状态:\n", hidden_states[-1])

RNN的缺点与改进版本

尽管RNN的设计非常优雅,但它存在一些严重的局限性:

  1. 梯度消失/爆炸问题:在长序列上,梯度要么趋近于零(无法学习),要么爆炸(不稳定)
  2. 长期依赖问题:基本RNN难以捕捉长距离的依赖关系
  3. 信息覆盖:新信息可能完全覆盖旧信息,导致"遗忘"重要的历史信息

为了解决这些问题,研究者提出了多种RNN的变体:

  1. LSTM (长短期记忆网络):引入了"门"机制,可以选择性地记住或忘记信息
  2. GRU (门控循环单元):LSTM的简化版本,性能相近但计算更高效
  3. 双向RNN:同时考虑过去和未来的信息,适用于有完整序列的场景

这些改进版本的核心计算公式更为复杂,后面有机会我们都摸一下,但基本思想与原始RNN相同:通过更新隐藏状态来保持对序列的"记忆"。

RNN在实际项目中的应用

RNN及其变体广泛应用于各种序列处理任务,至今RNN都在时序任务上仍有一席之地,但是那是另一个故事了。

总结:RNN的核心要点

  1. RNN的本质是一种带有循环连接的神经网络,使其能够处理序列数据
  2. 核心计算公式体现了RNN如何结合历史信息和当前输入
  3. 隐藏状态是RNN的"记忆",它随着序列处理不断更新
  4. 权重共享是RNN的关键特性,使其能够处理任意长度的序列
  5. 梯度问题是基本RNN的主要缺陷,导致了LSTM等改进版本的出现

尽管Transformer等新型架构在许多任务上已经超越了RNN,理解RNN的核心计算机制仍然是掌握序列模型的重要基础。RNN简洁的设计和直观的计算过程,体现了序列学习的基本原理,这些原理在更复杂的模型中依然适用。

哎,我上来就是一手 k x + b kx+b kx+b

参考资源

  1. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.
  2. Karpathy, A. The Unreasonable Effectiveness of Recurrent Neural Networks. http://karpathy.github.io/2015/05/21/rnn-effectiveness/
  3. Olah, C. Understanding LSTM Networks. http://colah.github.io/posts/2015-08-Understanding-LSTMs/

关于作者:是个逗比

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/984275.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度学习实战车道线检测

深度学习实战车道线检测 这里写目录标题 车道线原理整体架构设计核心原理步骤1. 特征提取(骨干网络)2. 特征融合3. 车道线表示与分类4. 损失函数5. 后处理 速度优势的来源 软件实现安装环境与文件说明实验测试 结束语 车道线原理 Lane - Detection是一种…

【redis】五种数据类型和编码方式

文章目录 五种数据类型编码方式stringhashlistsetzset查询内部编码 五种数据类型 字符串:Java 中的 String哈希:Java 中的 HashMap列表:Java 中的 List集合:Java 中的 Set有序集合:除了存 member 之外,还有…

Next.js Server Action 提交 vs 前端 Fetch 提交:核心区别与优劣分析

在使用 Next.js 开发时,开发者经常会面临一个问题:前端的数据提交应该直接 Fetch 调用 API 还是使用 Next.js 提供的 Server Action 提交? 本文将深度解析: ✅ Server Action 提交数据的工作原理✅ 前端 Fetch 提交数据的优缺点…

DeepSeek开启AI办公新模式,WPS/Office集成DeepSeek-R1本地大模型!

从央视到地方媒体,已有多家媒体机构推出AI主播,最近杭州文化广播电视集团的《杭州新闻联播》节目,使用AI主持人进行新闻播报,且做到了0失误率,可见AI正在逐渐取代部分行业和一些重复性的工作,这一现象引发很…

混合存储HDD+SSD机型磁盘阵列,配上SSD缓存功能,性能提升300%

企业日常运行各种文件无处不在,文档、报告、视频、应用数据......面对成千上万的文件,团队之间需要做到无障碍协作,员工能够即时快速访问、共享处理文件。随着业务增长,数字化办公不仅需要大容量,快速高效的文件访问越…

【AI】什么是Embedding向量模型?我们应该如何选择?

我们之前讲的搭建本地知识库,基本都是使用检索增强生成(RAG)技术来搭建,Embedding模型则是RAG的核心,同时也是大模型落地必不可少的技术。那么今天我们就来聊聊Embedding向量模型: 一、Embedding模型是什么? Embedding模型是一种将离散数据(如文本、图像、用户行为等)…

Java在小米SU7 Ultra汽车中的技术赋能

目录 一、智能驾驶“大脑”与实时数据 场景一:海量数据的分布式计算 场景二:实时决策的毫秒级响应 场景三:弹性扩展与容错机制 技术隐喻: 二、车载信息系统(IVI)的交互 场景一:Android Automo…

【Python 数据结构 8.串】

目录 一、串的基本概念 1.串的概念 2.获取串的长度 3.串的拷贝 4.串的比较 5.串的拼接 6.串的索引 二、Python中串的使用 1.串的定义 2.串的拼接 3.获取串的长度 4.获取子串位置 5.获取字符串的索引 6.字符串的切片 7.字符串反转 8.字符串的比较 9.字符串的赋值 三、实战 1.344…

计算机视觉cv2入门之图像的读取,显示,与保存

在计算机视觉领域,Python的cv2库是一个不可或缺的工具,它提供了丰富的图像处理功能。作为OpenCV的Python接口,cv2使得图像处理的实现变得简单而高效。 示例图片 目录 opencv获取方式 图像基本知识 颜色空间 RGB HSV 图像格式 BMP格式 …

LLM 学习(二 完结 Multi-Head Attention、Encoder、Decoder)

文章目录 LLM 学习(二 完结 Multi-Head Attention、Encoder、Decoder)Self-Attention (自注意力机制)结构多头注意力 EncoderAdd & Norm 层Feed Forward 层 EncoderDecoder的第一个Multi-Head AttentionMasked 操作Teacher Fo…

006-获取硬件序列号

获取硬件序列号 我将从跨平台角度系统讲解如何通过C获取硬件序列号的核心技术&#xff0c;并提供可移植性代码实现。 一、处理器序列号获取 Windows平台 #include <windows.h> #include <intrin.h>std::string GetCPUSerial_Win() {DWORD cpuInfo[2] { 0 };__c…

GDB调试技巧:多线程案例分析(保姆级)

在软件开发的复杂世界里&#xff0c;高效的调试工具是解决问题的关键利器。今天&#xff0c;我们将深入探讨强大的调试工具 ——GDB&#xff08;GNU Debugger&#xff09;。GDB 为开发者提供了一种深入程序内部运行机制、查找错误和优化性能的有效途径。让我们一同开启 GDB 的调…

OSPF的各种LSA类型,多区域及特殊区域

一、OSPF的LSA类型 OSPF&#xff08;开放最短路径优先&#xff09;协议使用多种LSA&#xff08;链路状态通告&#xff09;类型来交换网络拓扑信息。以下是主要LSA类型的详细分类及其作用&#xff1a; 1. Type 1 LSA&#xff08;路由器LSA Router LSA&#xff09; 生成者&…

JavaScript系列06-深入理解 JavaScript 事件系统:从原生事件到 React 合成事件

JavaScript 事件系统是构建交互式 Web 应用的核心。本文从原生 DOM 事件到 React 的合成事件&#xff0c;内容涵盖&#xff1a; JavaScript 事件基础&#xff1a;事件类型、事件注册、事件对象事件传播机制&#xff1a;捕获、目标和冒泡阶段高级事件技术&#xff1a;事件委托、…

字节跳动C++客户端开发实习生内推-抖音基础技术

智能手机爱好者和使用者&#xff0c;追求良好的用户体验&#xff1b; 具有良好的编程习惯&#xff0c;代码结构清晰&#xff0c;命名规范&#xff1b; 熟练掌握数据结构与算法、计算机网络、操作系统、编译原理等课程&#xff1b; 熟练掌握C/C/OC/Swift一种或多种语言&#xff…

MySQL进阶-关联查询优化

采用左外连接 下面开始 EXPLAIN 分析 EXPLAIN SELECT SQL_NO_CACHE * FROM type LEFT JOIN book ON type.card book.card; 结论&#xff1a;type 有All ,代表着全表扫描&#xff0c;效率较差 添加索引优化 ALTER TABLE book ADD INDEX Y ( card); #【被驱动表】&#xff0…

ai之qwq 32B部署在 linux 与拓展使用在web参考

linux部署 Linux 命令行&#xff1a; curl -fsSL https://ollama.com/install.sh | sh2 将Ollama设置为系统启动时自动运行&#xff08;建议&#xff09; 创建系统用户和用户组 sudo useradd -r -s /bin/false -U -m -d /usr/share/ollama ollamasudo usermod -a -G ollama $…

景联文科技:以精准数据标注赋能AI进化,构筑智能时代数据基石

在人工智能技术席卷全球的浪潮中&#xff0c;高质量数据已成为驱动AI模型进化的核心燃料。作为全球领先的AI数据服务解决方案提供商&#xff0c;景联文科技深耕数据标注领域多年&#xff0c;以技术为基、以专业为本&#xff0c;致力于为全球客户提供全场景、高精度、多模态的数…

C语言_数据结构总结4:不带头结点的单链表

纯C语言代码&#xff0c;不涉及C 0. 结点结构 typedef int ElemType; typedef struct LNode { ElemType data; //数据域 struct LNode* next; //指针域 }LNode, * LinkList; 1. 初始化 不带头结点的初始化&#xff0c;即只需将头指针初始化为NULL即可 void Init…

IDEA 基础配置: maven配置 | 服务窗口配置

文章目录 IDEA版本与MAVEN版本对应关系maven配置镜像源插件idea打开服务工具窗口IDEA中的一些常见问题及其解决方案IDEA版本与MAVEN版本对应关系 查找发布时间在IDEA版本之前的dea2021可以使用maven3.8以及以前的版本 比如我是idea2021.2.2 ,需要将 maven 退到 apache-maven-3.…