机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)

# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)

# 定义GRU模型
class GRUNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUNet, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)
        out = self.fc(out[:, -1, :])
        return out

input_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        loss_list.append(loss.item())
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()

# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/398106.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

公寓报修|公寓报修管理系统|基于springboot公寓报修管理系统设计与实现(源码+数据库+文档)

公寓报修管理系统目录 目录 基于springboot公寓报修管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、住户管理 2、房间管理 3、维修人员管理 4、维修分类管理 5、物品信息管理 6、维修申请管理管理 四、数据库设计 1、实体ER图 五、核心代码 六、…

Vue3快速上手(九)计算属性computed

一、应用场景 vue3里强调尽量让<template>,也就是模版&#xff0c;变得更加简单。所以涉及到转换、计算等操作的&#xff0c;还是建议在<script>标签里进行。如此我们可以用到computed。 二、实际用法 2.1 示例1 一个简单的加法计算 <template><div …

6.2 数据库

本节介绍Android的数据库存储方式--SQLite的使用方法&#xff0c;包括&#xff1a;SQLite用到了哪些SQL语法&#xff0c;如何使用数据库管理操纵SQLitem&#xff0c;如何使用数据库帮助器简化数据库操作&#xff0c;以及如何利用SQLite改进登录页面的记住密码功能。 6.2.1 SQ…

深度学习——概念引入

深度学习 深度学习简介深度学习分类根据网络结构划分&#xff1a;循环神经网络卷积神经网络 根据学习方式划分&#xff1a;监督学习无监督学习半监督学习 根据应用领域划分&#xff1a;计算机视觉自然语言处理语音识别生物信息学 深度学习简介 深度学习&#xff08;Deep Learni…

将Windows电脑右下角的“中”字或“英”字输入法状态隐藏的方法

本文介绍在Windows 11操作系统中&#xff0c;将任务栏右下角的语言栏的“中”、“英”标识加以隐藏、消除的一种或许可行的方法。 最近换了新电脑&#xff0c;终于用上了Windows 11操作系统。但是&#xff0c;默认状态下&#xff0c;在任务栏最右侧&#xff0c;也就是屏幕右下角…

2024最新版Redis安装使用指南

2024最新版Redis安装使用指南 Installation and Usage Guide to the Latest Redis in 2024 By JacksonML 1. 什么是Redis? The open-source, in-memory data store used by millions of developers as a cache, vector database, document database, streaming engine, an…

MSS与cwnd的关系,rwnd又是什么?

慢启动算法是指数递增的 这种指数增长的方式是慢启动算法的一个核心特点&#xff0c;它确保了TCP连接在开始传输数据时能够快速地探测网络的带宽容量&#xff0c;而又不至于过于激进导致网络拥塞。具体来说&#xff1a; 初始阶段&#xff1a;当TCP连接刚建立时&#xff0c;拥…

Prometheus 教程

目录 一、简介二、下载安装1、安装 prometheus2、安装 alertmanager3、安装 grafana4、安装 node_exporter5、安装 mysqld_exporter 一、简介 Prometheus 是一个开源的系统监控和警报工具。它最初由 SoundCloud 开发&#xff0c;并于 2012 年发布为开源项目。Prometheus 专注于…

【leetcode刷题】 93.复原IP地址单层逻辑特殊点总结

这个跟131分割回文串比较类似&#xff0c;但是这里的回溯过程需要注意两个事项&#xff0c;一个是横向深入时要考虑到原字符串中加入“.”所以计数的idx从i2开始。纵向回退时要把用来控制结束时机的pointnum减掉1&#xff0c;再把这时已经加入了“.”的字符串去掉“.”。 判断合…

关于node与node-sass那些事

昨晚找了之前的一个项目想要复习下&#xff0c;结果npm i报错&#xff0c;大致意思就是noda-sass的版本和node的对不上&#xff0c;那怎么办呢&#xff1a; 1.换node版本&#xff0c;那好吧&#xff0c;首先要明白&#xff0c;对应的版本关系 2.然后我开始用nvm换node版本&am…

金晨颜值逆袭。每年一个样,美丽爆表。

♥ 为方便您进行讨论和分享&#xff0c;同时也为能带给您不一样的参与感。请您在阅读本文之前&#xff0c;点击一下“关注”&#xff0c;非常感谢您的支持&#xff01; 文 |猴哥聊娱乐 编 辑|徐 婷 校 对|侯欢庭 微博热议金晨颜值蜕变&#xff01;从20岁清纯到31岁明艳&…

RHEL8提示需要注册才可以yum解决办法

关闭注册以及修改更新远&#xff08;已注册的RHEL8忽略本步骤&#xff09; 原因&#xff1a;因为没注册的红帽子是无法连接到官方的Yum源的 箭头所指的改成0 vi /etc/yum/pluginconf.d/subscription-manager.conf 箭头所指的改成0 cd /etc/yum.repos.d/ wget https://mirro…

docker之安装mongo创建运行环境

目录 一、docker pull 最新资源 二、启动mongo镜像 启动命令查看日志拉取低版本镜像成功启动 三、进入mongo容器 进入容器进入mongo环境查询当前所在库切换库至admin随意切换库 并 创建用户登录用户新增文档数据等 五、总结 版本兼容可备份操作 一、docker pull 最新资源…

CentOS已安装宝塔的情况下手动安装phpMyAdmin

CentOS 7.9.2009&#xff0c;宝塔7.9.4。 服务器中已有MySQL&#xff0c;可能不是通过宝塔安装的&#xff0c;而是手动安装的。用命令行可以正常进入MySQL查看和管理数据&#xff0c;说明已有的MySQL是正常的。在宝塔中点击数据库提示“未安装本地数据库&#xff0c;已隐藏无法…

基于函数计算AIGC生成图应用

目录 基于函数计算部署AIGC应用的主要步骤 创建Stable Diffusion模型的应用 访问应用实现文字生图 函数的查看与管理 基于函数计算部署AIGC应用的主要步骤 用函数计算实现AIGC只要简单的三步&#xff0c;分别是创建应用、运行应用及查看管理。 创建Stable Diffusion模型的应…

【大厂AI课学习笔记】【2.2机器学习开发任务实例】(7)特征构造

特征分析之后&#xff0c;就是特征构造。 特征构造第一步 特征构造往往要进行数据的归一化。 在本案例中&#xff0c;我们将所有的数据&#xff0c;将所有特征区间调整为0~1之间。 如上图。 那么&#xff0c;为什么要进行归一化&#xff0c;又如何将数据&#xff0c;调整为…

《隐私计算简易速速上手小册》第2章:关键技术介绍(2024 最新版)

文章目录 2.1 同态加密2.1.1 基础知识2.1.2 主要案例:云计算数据分析2.1.3 拓展案例 1:医疗数据分析2.1.4 拓展案例 2:金融风险评估2.2 安全多方计算(SMC)2.2.1 基础知识2.2.2 主要案例:跨机构金融数据共享2.2.3 拓展案例 1:医疗研究合作2.2.4 拓展案例 2:跨国界数据交…

拼夕夕 拼多多关键词恢复供应,欢迎骚扰

API接口&#xff08;Application Programming Interface&#xff09;是一种定义了软件组件之间交互的规范。它允许不同的软件系统之间进行通信和数据交换&#xff0c;使得开发者可以利用已有的功能和服务来构建自己的应用程序。 API接口可以分为不同的类型&#xff0c;包括Web…

加固手持平板电脑在铁轨维修上的应用|亿道三防onerugged

随着铁路交通的发展&#xff0c;铁轨维修成为保障行车安全和运营效率的重要环节。在这个关键的维修过程中&#xff0c;专业的设备和工具至关重要。亿道三防onerugged系列产品中的加固手持平板电脑是一款在铁轨维修上应用的理想选择。 首先&#xff0c;这款加固手持平板电脑满足…

全面总结!加速大模型推理的超全指南来了!

2023 年&#xff0c;大型语言模型&#xff08;LLM&#xff09;以其强大的生成、理解、推理等能力而持续受到高度关注。然而&#xff0c;训练和部署 LLM 非常昂贵&#xff0c;需要大量的计算资源和内存&#xff0c;因此研究人员开发了许多用于加速 LLM 预训练、微调和推理的方法…