什么是门控循环单元?

一、概念

        门控循环单元(Gated Recurrent Unit,GRU)是一种改进的循环神经网络(RNN),由Cho等人在2014年提出。GRU是LSTM的简化版本,通过减少门的数量和简化结构,保留了LSTM的长时间依赖捕捉能力,同时提高了计算效率。GRU通过引入两个门(重置门和更新门)来控制信息的流动。与LSTM不同,GRU没有单独的细胞状态,而是将隐藏状态直接作为信息传递的载体,因此结构更简单,计算效率更高。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,r_{t}为当前时间步的重置门向量,z_{t}为当前时间步的更新门向量,\bar{h_{t}}为当前时间步的候选隐藏状态向量,W_{r},W_{z},W_{h}分别为各门的权重矩阵,b_{r},b_{z},b_{h}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。

1、重置门

        重置门控制前一个时间步的隐藏状态对当前时间步的影响。通过sigmoid激活函数,重置门的输出在0到1之间,表示前一个隐藏状态元素被保留的比例。

r_{t} = \sigma(W_{r} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{r})

2、更新门

        更新门控制前一个时间步的隐藏状态和当前时间步的候选隐藏状态的混合比例。通过sigmoid激活函数,更新门的输出在0到1之间,表示前一个隐藏状态元素被保留的比例。

z_{t} = \sigma(W_{z} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{z})

3、候选隐藏状态

        候选隐藏状态结合当前输入和前一个时间步的隐藏状态生成。重置门的输出与前一个隐藏状态相乘,表示保留的旧信息。然后与当前输入一起通过tanh激活函数生成候选隐藏状态。

\bar{h_{t}} = tanh(W_{h} \cdot \left [ r_{t} \ast h_{t-1}, x_{t} \right ] + b_{h})

4、隐藏状态更新

        隐藏状态结合更新门的结果进行更新。更新门的输出与前一个隐藏状态相乘,表示保留的旧信息。更新门的补数与候选隐藏状态相乘,表示写入的新信息。两者相加得到当前时间步的隐藏状态。

h_{t} = (1-z_{t}) \ast h_{t-1} + z_{t} \ast \bar{h_{t}}

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
 
# 设置随机种子
torch.manual_seed(0)
np.random.seed(0)
 
# 生成正弦波数据
timesteps = 1000
sin_wave = np.array([np.sin(2 * np.pi * i / timesteps) for i in range(timesteps)])
 
# 创建数据集
def create_dataset(data, time_step=1):
    dataX, dataY = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        dataX.append(a)
        dataY.append(data[i + time_step])
    return np.array(dataX), np.array(dataY)
 
time_step = 10
X, y = create_dataset(sin_wave, time_step)
 
# 数据预处理
X = X.reshape(X.shape[0], time_step, 1)
y = y.reshape(-1, 1)
 
# 转换为Tensor
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)
 
# 划分训练集和测试集
train_size = int(len(X) * 0.7)
test_size = len(X) - train_size
trainX, testX = X[:train_size], X[train_size:]
trainY, testY = y[:train_size], y[train_size:]
 
# 定义RNN模型
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
 
    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = self.fc(out[:, -1, :])
        return out
 
input_size = 1
hidden_size = 50
output_size = 1
model = GRUModel(input_size, hidden_size, output_size)
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(trainX)
    loss = criterion(outputs, trainY)
    loss.backward()
    optimizer.step()
 
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
 
# 预测
model.eval()
train_predict = model(trainX)
test_predict = model(testX)
train_predict = train_predict.detach().numpy()
test_predict = test_predict.detach().numpy()
 
# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(sin_wave, label='Original Data')
plt.plot(np.arange(time_step, time_step + len(train_predict)), train_predict, label='Training Predict')
plt.plot(np.arange(time_step + len(train_predict), time_step + len(train_predict) + len(test_predict)), test_predict, label='Test Predict')
plt.legend()
plt.show()

四、总结

        GRU的结构比LSTM更简单,只有两个门(重置门和更新门),没有单独的细胞状态。这使得GRU的计算复杂度较低,训练和推理速度更快。通过引入重置门和更新门,GRU也有效地解决了标准RNN在处理长序列时的梯度消失和梯度爆炸问题。然而,在需要更精细的门控制和信息流动的任务中,LSTM的性能可能优于GRU。因此在我们实际的建模过程中,可以根据数据特点选择合适的RNN系列模型,并没有哪个模型能在所有任务中都具有优势。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963515.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于深度学习的输电线路缺陷检测算法研究(论文+源码)

输电线路关键部件的缺陷检测对于电网安全运行至关重要,传统方法存在效率低、准确性不高等问题。本研究探讨了利用深度学习技术进行输电线路关键组件的缺陷检测,目的是提升检测的效率与准确度。选用了YOLOv8模型作为基础,并通过加入CA注意力机…

【LLM-agent】(task6)构建教程编写智能体

note 构建教程编写智能体 文章目录 note一、功能需求二、相关代码(1)定义生成教程的目录 Action 类(2)定义生成教程内容的 Action 类(3)定义教程编写智能体(4)交互式操作调用教程编…

C++游戏开发实战:从引擎架构到物理碰撞

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 1. 引言 C 是游戏开发中最受欢迎的编程语言之一,因其高性能、低延迟和强大的底层控制能力,被广泛用于游戏…

Time Constant | RC、RL 和 RLC 电路中的时间常数

注:本文为 “Time Constant” 相关文章合辑。 机翻,未校。 How To Find The Time Constant in RC and RL Circuits June 8, 2024 💡 Key learnings: 关键学习点: Time Constant Definition: The time constant (τ) is define…

DeepSeek Janus-Pro:多模态AI模型的突破与创新

近年来,人工智能领域取得了显著的进展,尤其是在多模态模型(Multimodal Models)方面。多模态模型能够同时处理和理解文本、图像等多种类型的数据,极大地扩展了AI的应用场景。DeepSeek(DeepSeek-V3 深度剖析:…

w188校园商铺管理系统设计与实现

🙊作者简介:多年一线开发工作经验,原创团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹赠送计算机毕业设计600个选题excel文…

DeepSeek R1本地化部署 Ollama + Chatbox 打造最强 AI 工具

🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 Ollama 🦋 下载 Ollama🦋 选择模型🦋 运行模型🦋 使用 && 测试 二:🔥 Chat…

图漾相机——Sample_V1示例程序

文章目录 1.SDK支持的平台类型1.1 Windows 平台1.2 Linux平台 2.SDK基本知识2.1 SDK目录结构2.2 设备组件简介2.3 设备组件属性2.4 设备的帧数据管理机制2.5 SDK中的坐标系变换 3.Sample_V1示例程序3.1 DeviceStorage3.2 DumpCalibInfo3.3 NetStatistic3.4 SimpleView_SaveLoad…

DeepSeek 遭 DDoS 攻击背后:DDoS 攻击的 “千层套路” 与安全防御 “金钟罩”

当算力博弈升级为网络战争:拆解DDoS攻击背后的技术攻防战——从DeepSeek遇袭看全球网络安全新趋势 在数字化浪潮席卷全球的当下,网络已然成为人类社会运转的关键基础设施,深刻融入经济、生活、政务等各个领域。从金融交易的实时清算&#xf…

小程序项目-购物-首页与准备

前言 这一节讲一个购物项目 1. 项目介绍与项目文档 我们这里可以打开一个网址 https://applet-base-api-t.itheima.net/docs-uni-shop/index.htm 就可以查看对应的文档 2. 配置uni-app的开发环境 可以先打开这个的官网 https://uniapp.dcloud.net.cn/ 使用这个就可以发布到…

深入解析Python机器学习库Scikit-Learn的应用实例

深入解析Python机器学习库Scikit-Learn的应用实例 随着人工智能和数据科学领域的迅速发展,机器学习成为了当下最炙手可热的技术之一。而在机器学习领域,Python作为一种功能强大且易于上手的编程语言,拥有庞大的生态系统和丰富的机器学习库。其…

大模型训练(5):Zero Redundancy Optimizer(ZeRO零冗余优化器)

0 英文缩写 Large Language Model(LLM)大型语言模型Data Parallelism(DP)数据并行Distributed Data Parallelism(DDP)分布式数据并行Zero Redundancy Optimizer(ZeRO)零冗余优化器 …

在亚马逊云科技上用Stable Diffusion 3.5 Large生成赛博朋克风图片(上)

背景介绍 在2024年的亚马逊云科技re:Invent大会上提前预告的Stable Diffusion 3.5 Large,现在已经在Amazon Bedrock上线了!各位开发者们现在可以使用该模型,根据文本提示词文生图生成高质量的图片,并且支持多种图片风格生成&…

java练习(5)

ps:题目来自力扣 给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外,这…

【力扣】438.找到字符串中所有字母异位词

AC截图 题目 思路 我一开始是打算将窗口内的s子字符串和p字符串都重新排序&#xff0c;然后判断是否相等&#xff0c;再之后进行窗口滑动。不过缺点是会超时。 class Solution { public:vector<int> findAnagrams(string s, string p) {vector<int> vec;if(s.siz…

DeepSeek回答禅宗三重境界重构交易认知

人都是活在各自心境里&#xff0c;有些话通过语言去交流&#xff0c;还是要回归自己心境内在的&#xff0c;而不是靠外在映射到股票和技术方法&#xff1b;比如说明天市场阶段是不修复不接力节点&#xff0c;这就是最高视角看整个市场&#xff0c;还有哪一句话能概括&#xff1…

C++进阶: 红黑树及map与set封装

红黑树总结整理 红黑色概述&#xff1a; 红黑树整理与AVL树类似&#xff0c;但在对树的平衡做控制时&#xff0c;AVL树会比红黑树更严格。 AVL树是通过引入平衡因子的概念进行对树高度控制。 红黑树则是对每个节点标记颜色&#xff0c;对颜色进行控制。 红黑树控制规则&…

BUUCTF_[网鼎杯 2020 朱雀组]phpweb(反序列化绕过命令)

打开靶场&#xff0c;,弹出上面的提示,是一个警告warning,而且页面每隔几秒就会刷新一次,根据warning中的信息以及信息中的时间一直在变,可以猜测是date()函数一直在被调用 查看页面源代码&#xff0c;没有什么有用的信息 Burp抓包一下 调用了date()函数并回显在页面上,参数fu…

pandas(二)读取数据

一、读取数据 示例代码 import pandaspeople pandas.read_excel(../002/People.xlsx) #读取People数据 print(people.shape) # 打印people表的行数、列数 print(people.head(3)) # 默认打印前5行,当前打印前3行 print("") print(people.tail(3)) # 默…

智慧物业管理系统实现社区管理智能化提升居民生活体验与满意度

内容概要 智慧物业管理系统&#xff0c;顾名思义&#xff0c;是一种将智能化技术融入社区管理的系统&#xff0c;它通过高效的手段帮助物业公司和居民更好地互动与沟通。首先&#xff0c;这个系统整合了在线收费、停车管理等功能&#xff0c;让居民能够方便快捷地完成日常支付…