时间序列预测(九)——门控循环单元网络(GRU)

目录

一、GRU结构

二、GRU核心思想

1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。

2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。

3、候选隐藏状态(Candidate Hidden State):生成当前隐藏状态的候选值

三、GRU 分步演练

1、输入与初始化:

2、计算重置门:

3、计算候选隐藏状态:

4、计算更新门:

5、计算当前隐藏状态:

四、代码实现

1、任务:

2、做法:

3、主要修改点:

4、具体代码:

5、结果


GRU是一种循环神经网络(RNN)的变体,由Cho等人在2014年提出。相比于传统的RNN,GRU引入了门控机制,可以通过该机制来确定应该何时更新隐状态,以及应该何时重置隐状态,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。

一、GRU结构

GRU的结构和基础的RNN相比,并没有特别大的不同,都是一种重复神经网络模块的链式结构,由输入层、隐藏层和输出层组成,其中隐藏层是其核心部分,包含了门控机制相关的计算单元

二、GRU核心思想

与LSTM不同,GRU没有细胞状态,而是直接使用隐藏状态。GRU由两个门控制:更新门(Update Gate)和重置门(Reset Gate)。

1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。
2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。

补充:

3、候选隐藏状态(Candidate Hidden State):生成当前隐藏状态的候选值

三、GRU 分步演练

1、输入与初始化
  • 假设我们有一个输入序列 X=[x1​,x2​,...,xT​],其中 xt​ 是第 t 个时间步的输入。
  • 初始化隐藏状态 h0​,通常为零向量或随机初始化。
2、计算重置门
  • 重置门 rt​ 决定了前一时间步的隐藏状态 ht−1​ 对当前候选隐藏状态 h~t​ 的影响程度。                                      其中 σ 是sigmoid函数,Wr​ 和 Ur​ 是可训练的权重矩阵。
3、计算候选隐藏状态
  • 使用重置门 rt​ 来控制前一时间步的隐藏状态 ht−1​ 的影响。                       其中 ⊙ 表示元素乘法,tanh 是双曲正切函数,W 和 U 是可训练的权重矩阵。
4、计算更新门
  • 更新门 zt​ 决定了当前隐藏状态 ht​ 应该保留多少前一时间步的隐藏状态 ht−1​ 和多少当前候选隐藏状态 h~t​。            其中 Wz​ 和 Uz​ 是可训练的权重矩阵。
5、计算当前隐藏状态
  • 使用更新门 zt​ 来组合前一时间步的隐藏状态 ht−1​ 和当前候选隐藏状态 h~t​。

四、代码实现

1、任务:

根据一个包含道路曲率(Curvature)、车速(Velocity)、侧向加速度(Ay)和方向盘转角(Steering_Angle)真实的数据集,去预测未来的方向盘转角。

2、做法:

提取前5个历史曲率、速度、方向盘转角作为输入特征,同时添加后5个未来曲率(由于车辆的预瞄距离)。目标输出为未来5个方向盘转角。采用GRU网络训练。

3、主要修改点:
  1. 模型定义:将 LSTM 替换为 GRU,并更新模型类名为 GRUModel
  2. 前向传播forward 方法中相应地使用 GRU 的输出。
4、具体代码:
# GRU 模型
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error as mae, r2_score
import matplotlib.pyplot as plt

# 1. 数据预处理
# 读取数据
data = pd.read_excel('input_data_20241010160240.xlsx')  # 替换为你的数据文件路径  

# 提取特征和标签
curvature = data['Curvature'].values
velocity = data['Velocity'].values
steering = data['Steering_Angle'].values

# 定义历史和未来的窗口大小
history_size = 5
future_size = 5

features = []
labels = []
for i in range(history_size, len(data) - future_size):
    # 提取前5个历史的曲率、速度和方向盘转角
    history_curvature = curvature[i - history_size:i]
    history_velocity = velocity[i - history_size:i]
    history_steering = steering[i - history_size:i]

    # 提取后5个未来的曲率(用于预测)
    future_curvature = curvature[i:i + future_size]

    # 输入特征:历史 + 未来曲率
    feature = np.hstack((history_curvature, history_velocity, history_steering, future_curvature))
    features.append(feature)

    # 输出标签:未来5个方向盘转角
    label = steering[i:i + future_size]
    labels.append(label)

# 转换为 NumPy 数组
features = np.array(features)
labels = np.array(labels)

# 归一化
scaler_x = StandardScaler()
scaler_y = StandardScaler()

features = scaler_x.fit_transform(features)
labels = scaler_y.fit_transform(labels)

# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.05)

# 将特征转换为三维张量,形状为 [样本数, 时间序列长度, 特征数]
input_feature_size = history_size * 3 + future_size  # 历史曲率、速度、方向盘转角 + 未来曲率
x_train_tensor = torch.tensor(x_train, dtype=torch.float32).view(-1, 1, input_feature_size)  # [batch_size, seq_len=1, input_size]
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, future_size)  # 输出未来的5个方向盘转角
x_test_tensor = torch.tensor(x_test, dtype=torch.float32).view(-1, 1, input_feature_size)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, future_size)

# 2. 创建GRU模型
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)  # 使用GRU
        self.fc = nn.Linear(hidden_size, output_size)  # 输出层

    def forward(self, x):
        # 前向传播
        out, _ = self.gru(x)  # GRU输出
        out = self.fc(out[:, -1, :])  # 只取最后一个时间步的输出
        return out

# 实例化模型
input_size = input_feature_size  # 输入特征数
hidden_size = 64  # 隐藏层大小
num_layers = 2  # GRU层数
output_size = future_size  # 输出5个未来方向盘转角
model = GRUModel(input_size, hidden_size, num_layers, output_size)

# 3. 设置损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()
    
    # 前向传播
    outputs = model(x_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # 后向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 5. 预测
model.eval()
with torch.no_grad():
    y_pred_tensor = model(x_test_tensor)

y_pred = scaler_y.inverse_transform(y_pred_tensor.numpy())  # 将预测值逆归一化
y_test = scaler_y.inverse_transform(y_test_tensor.numpy())  # 逆归一化真实值

# 评估指标
r2 = r2_score(y_test, y_pred, multioutput='uniform_average')  # 多维输出下的R^2
mae_score = mae(y_test, y_pred)
print(f"R^2 score: {r2:.4f}")
print(f"MAE: {mae_score:.4f}")

# 支持中文
plt.rcParams['font.sans-serif'] = ['SimSun']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

# 绘制未来5个方向盘转角的预测和真实值对比
plt.figure(figsize=(10, 6))
for i in range(future_size):
    plt.plot(range(len(y_test)), y_test[:, i], label=f'真实值 {i+1} 步', color='blue')
    plt.plot(range(len(y_pred)), y_pred[:, i], label=f'预测值 {i+1} 步', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle')
plt.title('未来5个方向盘转角的实际值与预测值对比图')
plt.legend()
plt.grid(True)
plt.show()

# 计算预测和真实方向盘转角的平均值
y_pred_mean = np.mean(y_pred, axis=1)  # 每个样本的5个预测值取平均
y_test_mean = np.mean(y_test, axis=1)  # 每个样本的5个真实值取平均

# 绘制平均值的实际值与预测值对比图
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_test_mean)), y_test_mean, label='真实值(平均)', color='blue')
plt.plot(range(len(y_pred_mean)), y_pred_mean, label='预测值(平均)', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle (平均)')
plt.title('未来5个方向盘转角的平均值对比图')
plt.legend()
plt.grid(True)
plt.show()

# 绘制第1个时间步的实际值与预测值对比图
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_test)), y_test[:, 0], label='真实值 (第1步)', color='blue')
plt.plot(range(len(y_pred)), y_pred[:, 0], label='预测值 (第1步)', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle')
plt.title('未来第1步方向盘转角的实际值与预测值对比图')
plt.legend()
plt.grid(True)
plt.show()

# 计算每个时间步的平均绝对误差
time_steps = y_test.shape[1]
mae_per_step = [mae(y_test[:, i], y_pred[:, i]) for i in range(time_steps)]

# 绘制每个时间步的平均绝对误差
plt.figure(figsize=(10, 6))
plt.bar(range(1, time_steps + 1), mae_per_step, color='orange')
plt.xlabel('时间步')
plt.ylabel('MAE')
plt.title('不同时间步的平均绝对误差')
plt.grid(True)
plt.show()
5、结果

五、总结

GRU是LSTM的简化版本,减少了门的数量,使得训练和推理速度更快。它在许多序列建模任务中表现良好,适用于时间序列预测、自然语言处理等领域。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/901840.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32实现毫秒级时间同步

提起“时间同步”这个概念,大家可能很陌生。一时间搞不清楚是什么意思。 我理解“时间同步”可以解决多个传感器采集数据不同时的问题,让多个传感器同时采集数据。 打个比方。两个人走路,都是100毫秒走一步(频率相同是前提&…

2024数学分析【南昌大学】

计算极限 lim ⁡ n → ∞ 2024 n ( 1 − cos ⁡ 1 n 2 ) n 3 1 + n 2 − n \mathop {\lim }\limits_{n \to \infty } \frac{{\sqrt[n]{{2024}}\left( {1 - \cos \frac{1}{{{n^2}}}} \right){n^3}}}{{\sqrt {1 + {n^2}} - n}} n→∞lim​1+n2 ​−nn2024 ​(1−cosn21​)n3​ …

XJ02、消费金融|消费金融业务模式中的主要主体

根据所持有牌照类型的不同,消费金融服务供给方主要分为商业银行、汽车金融公司、消费金融公司和小贷公司,不同类型机构定位不同、提供消费金融服务与产品类型也各不相同。此外,互联网金融平台也成为中国消费金融业务最重要的参与方之一,虽其并非持牌金融机构,但借助其流量…

React 分装webSocket

背景 AI 实时对话 需要流式数据 React Hooks 写法。新建WebSocket.tsx 放在根目录components import { useCallback, useRef, useState } from react;type MessageHandler (message: MessageEvent) > void; type ErrorHandler (event: Event) > void;export functi…

深度学习(一)基础:神经网络、训练过程与激活函数(1/10)

深度学习基础:神经网络、训练过程与激活函数 引言: 深度学习作为机器学习的一个子领域,近年来在人工智能的发展中扮演了举足轻重的角色。它通过模仿人脑的神经网络结构,使得计算机能够从数据中学习复杂的模式和特征,…

List、Set、数据结构、Collections

一、数据结构 1.1 常用的数据结构 栈 栈:stack,又称堆栈,它是运算受限的线性表,其限制是仅允许在标的一端进行插入和删除操作,不允许在其他任何位置进行添加、查找、删除等操作。 简单的说:采用该结构的集合&#…

【网络协议栈】Tcp协议(下)的可靠性和高效性(超时重传、快速重传、拥塞控制、流量控制)

绪论: 承接上文,上文写到Tcp协议的结构以及对tcp协议的性能优化的滑动窗口,本章我们将继续了解Tcp协议的可靠性和高效性的具体展示。后面我将继续完善网络协议栈的网络层协议敬请期待! 话不多说安全带系好,发车啦(建议…

【AI绘画】Midjourney进阶:对角线构图详解

博客主页: [小ᶻZ࿆] 本文专栏: AI绘画 | Midjourney 文章目录 💯前言💯什么是构图为什么Midjourney要使用构图 💯对角线构图特点应用场景提示词书写技巧测试 💯小结 💯前言 【AI绘画】Midjourney进阶&a…

Linux常用命令1

切换目录 cd [rootlocalhost menge]# cd /[rootlocalhost /]# cd: cd [-L|[-P [-e]] [-]] [目录] 查看当前的目录 pwd 浏览目录内容 ls ls浏览后颜色表示 白色:普通文件 蓝色:目录 红色:压缩包文件 黄色:设备文件 绿…

以 6502 为例讲讲怎么阅读 CPU 电路图

开篇 你是否曾对 CPU 的工作原理充满好奇,以及简单的晶体管又是如何组成逻辑门,进而构建出复杂的逻辑电路实现?本文将以知名的 6502 CPU 的电路图为例,介绍如何阅读 CPU 电路图,并向你演示如何从晶体管电路还原出逻辑…

.NET Core WebApi第2讲:前后端分离,Restful

动态页面:数据流动 / Web服务器 / Ajax / 前后端分离 / restful风格源栈课堂一起帮https://17bang.ren/Code/261 一、Ajax:页面可以局部刷新 1、PPT演示:使用Ajax也无法减小带宽耗用 请求第一个页面,用AJAX从服务器端加载了一个页头。 请求第…

Maven进阶——坐标、依赖、仓库

目录 1.pomxml文件 2. 坐标 2.1 坐标的概念 2.2 坐标的意义 2.3 坐标的含义 2.4 自己项目的坐标 2.5 第三方项目坐标 3. 依赖 3.1 依赖的意义 3.2 依赖的使用 3.3 第三方依赖的查找方法 3.4 依赖范围 3.5 依赖传递和可选依赖 3.5.1 依赖传递 3.5.2 依赖范围对传…

ollama 在 Linux 环境的安装

ollama 在 Linux 环境的安装 介绍 他的存在在我看来跟 docker 的很是相似,他把市面上已经存在的大语言模型集合在一个仓库中,然后通过 ollama 的方式来管理这些大语言模型 下载 # 可以直接通过 http 的方式吧对应的 shell 脚本下载下来,然…

【10天速通Navigation2】(三) :Cartographer建图算法配置:从仿真到实车,从原理到实现

前言 往期内容: 第一期:【10天速通Navigation2】(一) 框架总览和概念解释第二期:【10天速通Navigation2】(二) :ROS2gazebo阿克曼小车模型搭建-gazebo_ackermann_drive等插件的配置和说明 本教材将贯穿nav2的全部内容&#xff0c…

【Android】Kotlin教程(4)

文章目录 1.field2.计算属性3.主构造函数4.次构造函数5.默认参数6.初始化块7.初始化顺序7.延迟初始化lateinit8.惰性初始化 1.field field 关键字通常与属性的自定义 getter 和 setter 一起使用。当你需要为一个属性提供自定义的行为时,可以使用 field 来访问或设置…

Visual Studio2022 Profile 工具使用

本篇研究下Visual Studio自带的性能分析工具,针对C代码,基于Visual Studio2022 文章目录 CPU使用率检测并发可视化工具使用率视图线程视图内核视图并发可视化工具SDK 参考资料 CPU使用率 对于CPU密集型程序,我们可以通过分析程序的CPU使用率…

【MySQL】MySQL数据库中密码加密和查询的解决方案

本篇博客是为了记录自己在遇到password函数无法生效时的解决方案。通过使用AES_ENCRYPT(str,key)和AES_DECRYPT(str,key)进行加密和解密。 一、问题 自己想创建一个user表,user表中有一个password属性列,自己想对密码进行加密后再存入数据库&#xff0c…

基于JAVA+SpringBoot+Vue的华府便利店信息管理系统

基于JAVASpringBootVue的华府便利店信息管理系统 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末附源码下载链接&#x1f3…

GCC 简介

Linux 中的编译器 GCC 的编译原理和使用详解 GCC 简介 GCC(GNU Compiler Collection)是一套由 GNU 开发的编程语言编译器,它支持多种编程语言,包括 C、C、Objective-C、Fortran、Ada 和 Go 等。GCC 是一个开源的工具集&#xff…

STM32F103C8T6 IO 操作

1.开启相关时钟 在 STM32 微控制器中,开启 GPIO 端口的时钟是确保 IO 口可以正常工作的第一步。 查找 RCC 寄存器使能时钟 在 STM32 中,时钟控制的寄存器通常位于 RCC (Reset and Clock Control) 模块中。不同的 STM32 系列(如 STM32F1、STM…