什么是循环神经网络?

一、概念

        循环神经网络(Recurrent Neural Network, RNN)是一类用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN具有循环连接,可以利用序列数据的时间依赖性。正因如此,RNN在自然语言处理、时间序列预测、语音识别等领域有广泛的应用。

二、核心算法

        RNN的核心思想是通过循环连接来记忆和处理序列数据中的上下文信息。RNN在每个时间步(time step)上接收当前输入和前一个时间步的隐藏状态(hidden state),并输出当前时间步的隐藏状态和输出。RNN的基本结构包括输入层、隐藏层和输出层。隐藏层的状态在每个时间步都会更新,并传递到下一个时间步。具体来说,RNN在时间步 t 的计算过程如下:

1、输入

        获取当前时间步的输入x_{t}

2、隐藏状态更新

        根据当前的输入x_{t}和前一个时间步的隐藏状态h_{t-1},计算当前时间步的隐藏状态h_{t}

h_{t} = \sigma(W_{xh}x_{t}+W_{hh}h_{t-1}+b_{h})

        其中,W_{xh}是输入x到隐藏层的权重矩阵,W_{hh}是隐藏层到隐藏层的权重矩阵,b_{h}是隐藏层的偏置向量,\sigma是激活函数。

3、输出

        根据当前时间步的隐藏状态h_{t}计算当前时间步的输出y_{t}

y_{t} = \phi (W_{hy}h_{t}+b_{y})

        其中,W_{hy}是隐藏层到输出层的权重矩阵,b_{y}是输出层的偏置向量,\phi是输出层的激活函数。

三、python实现

        这里,我们使用一个简单的正弦波数据来演示RNN在时间序列预测中的应用。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(0)
np.random.seed(0)

# 生成正弦波数据
timesteps = 100
sin_wave = np.array([np.sin(2 * np.pi * i / timesteps) for i in range(timesteps)])

# 创建数据集
def create_dataset(data, time_step=1):
    dataX, dataY = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        dataX.append(a)
        dataY.append(data[i + time_step])
    return np.array(dataX), np.array(dataY)

time_step = 10
X, y = create_dataset(sin_wave, time_step)

# 数据预处理
X = X.reshape(X.shape[0], time_step, 1)
y = y.reshape(-1, 1)

# 转换为Tensor
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# 划分训练集和测试集
train_size = int(len(X) * 0.7)
test_size = len(X) - train_size
trainX, testX = X[:train_size], X[train_size:]
trainY, testY = y[:train_size], y[train_size:]

# 定义RNN模型
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out

input_size = 1
hidden_size = 50
output_size = 1
model = RNNModel(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(trainX)
    loss = criterion(outputs, trainY)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 预测
model.eval()
train_predict = model(trainX)
test_predict = model(testX)
train_predict = train_predict.detach().numpy()
test_predict = test_predict.detach().numpy()

# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(sin_wave, label='Original Data')
plt.plot(np.arange(time_step, time_step + len(train_predict)), train_predict, label='Training Predict')
plt.plot(np.arange(time_step + len(train_predict), time_step + len(train_predict) + len(test_predict)), test_predict, label='Test Predict')
plt.legend()
plt.show()

四、总结

        RNN能够处理任意长度的序列数据,这使得它在自然语言处理、时间序列预测、语音识别等领域非常有用。然而,标准RNN在捕捉长时间依赖关系方面表现不佳。为了解决这个问题,大佬们相继提出了长短期记忆网络(LSTM)和门控循环单元(GRU)等RNN的变体,在增加模型的复杂度的同时提高了RNN对长时间依赖关系的捕捉能力。此外,由于RNN的时间步之间存在依赖关系,因此难以进行并行化处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/960900.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python设计模式 - 组合模式

定义 组合模式(Composite Pattern) 是一种结构型设计模式,主要意图是将对象组织成树形结构以表示"部分-整体"的层次结构。这种模式能够使客户端统一对待单个对象和组合对象,从而简化了客户端代码。 组合模式有透明组合…

19.Word:小马-校园科技文化节❗【36】

目录 题目​ NO1.2.3 NO4.5.6 NO7.8.9 NO10.11.12索引 题目 NO1.2.3 布局→纸张大小→页边距:上下左右插入→封面:镶边→将文档开头的“黑客技术”文本移入到封面的“标题”控件中,删除其他控件 NO4.5.6 标题→原文原文→标题 正文→手…

一文讲解Java中Object类常用的方法

在Java中,经常提到一个词“万物皆对象”,其中的“万物”指的是Java中的所有类,而这些类都是Object类的子类; Object主要提供了11个方法,大致可以分为六类: 对象比较: public native int has…

多项日常使用测试,带你了解如何选择AI工具 Deepseek VS ChatGpt VS Claude

多项日常使用测试,带你了解如何选择AI工具 Deepseek VS ChatGpt VS Claude 注:因为考虑到绝大部分人的使用,我这里所用的模型均为免费模型。官方可访问的。ChatGPT这里用的是4o Ai对话,编程一直以来都是人们所讨论的话题。Ai的出现…

Linux下学【MySQL】表的必备操作( 配实操图和SQL语句)

绪论​ “Patience is key in life (耐心是生活的关键)”。本章是MySQL中非常重要且基础的知识----对表的操作。再数据库中表是存储数据的容器,我们通过将数据填写在表中,从而再从表中拿取出来使用,本章主要讲到表的增…

【Java数据结构】了解排序相关算法

基数排序 基数排序是桶排序的扩展,本质是将整数按位切割成不同的数字,然后按每个位数分别比较最后比一位较下来的顺序就是所有数的大小顺序。 先对数组中每个数的个位比大小排序然后按照队列先进先出的顺序分别拿出数据再将拿出的数据分别对十位百位千位…

【全栈】SprintBoot+vue3迷你商城(9)

【全栈】SprintBootvue3迷你商城(9) 往期的文章都在这里啦,大家有兴趣可以看一下 后端部分: 【全栈】SprintBootvue3迷你商城(1) 【全栈】SprintBootvue3迷你商城(2) 【全栈】Spr…

php-phar打包避坑指南2025

有很多php脚本工具都是打包成phar形式,使用起来就很方便,那么如何自己做一个呢?也找了很多文档,也遇到很多坑,这里就来总结一下 phar安装 现在直接装yum php-cli包就有phar文件,很方便 可通过phar help查看…

【数据结构】_顺序表

目录 1. 概念与结构 1.1 静态顺序表 1.2 动态顺序表 2. 动态顺序表实现 2.1 SeqList.h 2.2 SeqList.c 2.3 Test_SeqList.c 3. 顺序表性能分析 线性表是n个具有相同特性的数据元素的有限序列。 常见的线性表有:顺序表、链表、栈、队列、字符串等&#xff1b…

OPencv3.4.1安装及配置教程

来到GitHub上opencv的项目地址 https://github.com/opencv/opencv/releases/tag/3.4.1 以上资源包都是 OpenCV 3.4.1 版本相关资源,它们的区别如下: (1). opencv-3.4.1-android-sdk.zip:适用于 Android 平台的软件开发工具包(SDK…

世上本没有路,只有“场”et“Bravo”

楔子:电气本科“工程电磁场”电气研究生课程“高等电磁场分析”和“电磁兼容”自学”天线“、“通信原理”、“射频电路”、“微波理论”等课程 文章目录 前言零、学习历程一、Maxwells equations1.James Clerk Maxwell2.自由空间中传播的电磁波3.边界条件和有限时域…

ZYNQ-IP-AXI-GPIO

AXI GPIO 可以将 PS 端的一个 AXI 4-Lite 接口转化为 GPIO 接口,并且可以被配置为单端口或双端口,每个通道的位宽可以独立配置。 通过使能三态门可以将端口动态地配置为输入或输出。 AXIGPIO 是 ZYNQ PL 端的一个 IP 核,可以将 AXI-Lite Mas…

20.Word:小谢-病毒知识的科普文章❗【38】

目录 题目​ NO1.2.3文档格式 NO4.5 NO6.7目录/图表目录/书目 NO8.9.10 NO11索引 NO12.13.14 每一步操作完,确定之后记得保存最后所有操作完记得再次删除空行 题目 NO1.2.3文档格式 样式的应用 选中应用段落段落→开始→选择→→检查→应用一个一个应用ctr…

为什么应用程序是特定于操作系统的?[计算机原理]

你把WINDOWS程序复制到MAC上使用,会发现无法运行。你可能会说,MAC是arm处理器,而WINDWOS是X86 处理器。但是在2019年,那时候MAC电脑还全是Intel处理器,在同样的X86芯片上,运行MAC和WINDOWS 程序还是无法互相…

LigerUI在MVC模式下的响应原则

LigerUI是基于jQuery的UI框架,故他也是遵守jQuery的开发模式,但是也具有其特色的侦听函数,那么当LigerUI作为View层的时候,他所发送后端的必然是表单的数据,在此我们以俩个div为例: {Layout "~/View…

BurpSuite--暴力破解

一.弱口令 1. 基本概念 介绍:弱口令(weak password)是指那些容易被他人猜测或通过工具破解的密码。虽然弱口令没有严格的定义,但通常它指的是由简单的数字、字母、常用词语或规律性组合构成的密码。 特点: 密码容易被…

深入探讨防抖函数中的 this 上下文

深入剖析防抖函数中的 this 上下文 最近我在研究防抖函数实现的时候,发现一个耗费脑子的问题,出现了令我困惑的问题。接下来,我将通过代码示例,深入探究这些现象背后的原理。 示例代码 function debounce(fn, delay) {let time…

【PostgreSQL内核学习 —— (WindowAgg(一))】

WindowAgg 窗口函数介绍WindowAgg理论层面源码层面WindowObjectData 结构体WindowStatePerFuncData 结构体WindowStatePerAggData 结构体eval_windowaggregates 函数update_frameheadpos 函数 声明:本文的部分内容参考了他人的文章。在编写过程中,我们尊…

RocketMQ消息是如何存储的?

大家好,我是锋哥。今天分享关于【RocketMQ消息是如何存储的?】面试题。希望对大家有帮助; RocketMQ消息是如何存储的? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 RocketMQ 使用了一个高性能、分布式的消息存储架构…

MongoDB平替数据库对比

背景 项目一直是与实时在线监测相关,特点数据量大,读写操作大,所以选用的是MongoDB。但按趋势来讲,需要有一款国产数据库可替代,实现信创要求。选型对比如下 1. IoTDB 这款是由清华大学主导的开源时序数据库&#x…