GRU (门控循环单元 - 基于RNN - 简化LSTM又快又好 - 体现注意力的思想) + 代码实现 —— 笔记3.5《动手学深度学习》

目录

0. 前言

1. 门控隐状态

1.1 重置门和更新门

1.2 候选隐状态

1.3 隐状态

2. 从零开始实现

2.1 初始化模型参数

2.2 定义模型

2.3 训练与预测

3 简洁实现

4. 小结


0. 前言

  • 课程全部代码(pytorch版)已上传到附件
  • 看懂上一篇RNN的所有细节,这里的GRU、后面的LSTM甚至注意力机制都能容易理解,因为RNN是一切序列模型的基础
  • 本章节为原书第9章(现代循环网络),共分为8节,本篇是第1节:门控循环单元 GRU
  • 本节的代码位置为:chapter_recurrent-modern/gru.ipynb
  • 本节的视频链接:56 门控循环单元(GRU)【动手学深度学习v2】_哔哩哔哩_bilibili

在一个序列中:

在 :numref:sec_bptt中, 我们讨论了如何在循环神经网络中计算梯度, 以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。 下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。
  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。
  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示。

在学术界已经提出了许多方法来解决这类问题。 其中最早的方法是"长短期记忆"(long-short-term memory,LSTM) :cite:Hochreiter.Schmidhuber.1997, 我们将在 :numref:sec_lstm中讨论。 门控循环单元(gated recurrent unit,GRU) :cite:Cho.Van-Merrienboer.Bahdanau.ea.2014 是一个稍微简化的变体,通常能够提供同等的效果, 并且计算 :cite:Chung.Gulcehre.Cho.ea.2014的速度明显更快。 由于门控循环单元更简单,我们从它开始解读。

1. 门控隐状态

门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。 这些机制是可学习的,并且能够解决了上面列出的问题。 例如,如果第一个词元非常重要, 模型将学会在第一次观测之后不更新隐状态。 同样,模型也可以学会跳过不相关的临时观测。 最后,模型还将学会在需要的时候重置隐状态。 下面我们将详细讨论各类门控。

1.1 重置门和更新门

我们首先介绍重置门(reset gate)和更新门(update gate)。 我们把它们设计成(0,1)区间中的向量, 这样我们就可以进行凸组合(权重向量的每个元素非负且和为1)。 重置门允许我们控制“可能还想记住”的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本。

我们从构造这些门控开始。 :numref:fig_gru_1 描述了门控循环单元中的重置门和更新门的输入, 输入是由当前时间步的输入和前一时间步的隐状态给出。 两个门的输出是由使用sigmoid激活函数的两个全连接层给出。

:label:fig_gru_1

1.2 候选隐状态

:numref:fig_gru_2说明了应用重置门之后的计算流程。

:label:fig_gru_2

1.3 隐状态

些设计可以帮助我们处理循环神经网络中的梯度消失问题, 并更好地捕获时间步距离很长的序列的依赖关系。 例如,如果整个子序列的所有时间步的更新门都接近于1, 则无论序列的长度如何,在序列起始时间步的旧隐状态都将很容易保留并传递到序列结束。

:numref:fig_gru_3说明了更新门起作用后的计算流。

:label:fig_gru_3

总之,门控循环单元具有以下两个显著特征:

  • 重置门有助于捕获序列中的短期依赖关系;
  • 更新门有助于捕获序列中的长期依赖关系。

2. 从零开始实现

为了更好地理解门控循环单元模型,我们从零开始实现它。 首先,我们读取 :numref:sec_rnn_scratch中使用的时间机器数据集:

In [1]:

import torch
from torch import nn
from d2l import torch as d2l
​
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
 

2.1 初始化模型参数

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

In [2]:

def get_params(vocab_size, num_hiddens, device):  # 不去复习RNN从零实现,后面不太好看懂,函数都在那里定义
    num_inputs = num_outputs = vocab_size  # 和RNN类似,用了独热编码后,输入和输出都是词表大小,多分类问题
​
    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01  # 和RNN类似,均值0方差1初始化后,缩放为0.01
​
    def three():  # 由于三个公式结构差不多,对应位置参数的形状都一样;封装成函数减少重复代码,更优雅
        return (normal((num_inputs, num_hiddens)),  # W_x_学习参数的形状(词表大小,隐层数)
                normal((num_hiddens, num_hiddens)),  # W_h_形状(隐层数,隐层数)
                torch.zeros(num_hiddens, device=device))  # b__的形状(隐层数,)

    # 回看上面的公式
    W_xz, W_hz, b_z = three()  # 更新门参数
    W_xr, W_hr, b_r = three()  # 重置门参数
    W_xh, W_hh, b_h = three()  # 候选隐状态参数

    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)

    # 附加梯度  # GRU多了前6个参数,作为对比,RNN只有“W_xh, W_hh, b_h, W_hq, b_q”后面这5个参数
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)  # 把params中每个可学习参数的"是否计算grad"的属性设为True
    return params
 

2.2 定义模型

现在我们将[定义隐状态的初始化函数]init_gru_state。 与 :numref:sec_rnn_scratch中定义的init_rnn_state函数一样, 此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

In [3]:

def init_gru_state(batch_size, num_hiddens, device):  # 隐藏状态初始化,和RNN没区别
    return (torch.zeros((batch_size, num_hiddens), device=device), )

现在我们准备[定义门控循环单元模型], 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

In [4]:

def gru(inputs, state, params):  # 每个batch的前向传播,类似于rnn()
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state  # 传入的隐状态
    outputs = []
    # inputs的形状:(时间步数量,批量大小,词表大小); 是3d的tensor,到这里时input已经做了转置,时间步数维提前
    for X in inputs:  # 每次拿一个时间步出来,X形状:(批量大小,词表大小)
        # “X @ W_xz”约等于“torch.mm(X, W_xz)”做矩阵乘法,但torch.mm只能做2d的,@可以基于PyTorch广播到高维
        # CNN用的是tanh()激活函数,GRU用的是sigmoid()激活,都是对输出的隐藏表示进行归一化
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)  # “+ b_z”用到广播机制
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)  #  W_x_,W_h_和b__的形状在three()函数里有定义
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)  # 候选隐状态参数; *是元素相乘(Hadamard积)
        H = Z * H + (1 - Z) * H_tilda  # 最终的隐状态更新
        Y = H @ W_hq + b_q  # 输出函数
        outputs.append(Y)  # 拼接每个时间步的输出成list
    return torch.cat(outputs, dim=0), (H,)  # 隐藏状态做成tuple,为的是后面LSTM有两个隐藏状态H
 

2.3 训练与预测

训练和预测的工作方式与 :numref:sec_rnn_scratch完全相同。 训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

In [5]:

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)  # RNNModelScratch()类很泛用,改一后三个输入就是GRU啦
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 31198.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

3 简洁实现

高级API包含了前文介绍的所有配置细节, 所以我们可以直接实例化门控循环单元模型。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

In [6]:

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)  # pytorch框架里直接掉包,使用GRU类
model = d2l.RNNModel(gru_layer, len(vocab))  # 从零实现的代码,也能兼容pytorch框架(简洁实现)
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
# 看结果,用框架比从零实现快了9倍,因为这GRU从零实现的小矩阵乘法更多
perplexity 1.0, 271470.1 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

4. 小结

  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。
  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。
  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/921194.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java 基于SpringBoot+vue框架的老年医疗保健网站

大家好,我是Java徐师兄,今天为大家带来的是Java Java 基于SpringBootvue框架的老年医疗保健网站。该系统采用 Java 语言开发,SpringBoot 框架,MySql 作为数据库,系统功能完善 ,实用性强 ,可供大…

JavaWeb-表单-07

表单标签 介绍 code: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>HTML-表单</title> &…

计算机网络socket编程(4)_TCP socket API 详解

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 计算机网络socket编程(4)_TCP socket API 详解 收录于专栏【计算机网络】 本专栏旨在分享学习计算机网络的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&…

向量数据库FAISS之五:原理(LSH、PQ、HNSW、IVF)

1.Locality Sensitive Hashing (LSH) 使用 Shingling MinHashing 进行查找 左侧是字典&#xff0c;右侧是 LSH。目的是把足够相似的索引放在同一个桶内。 LSH 有很多的版本&#xff0c;很灵活&#xff0c;这里先介绍第一个版本&#xff0c;也是原始版本 Shingling one-hot …

Django启用国际化支持(2)—实现界面内切换语言:activate()

文章目录 ⭐注意⭐1. 配置项目全局设置&#xff1a;启用国际化2. 编写视图函数3. 配置路由4. 界面演示5、扩展自动识别并切换到当前语言设置语言并保存到Session设置语言并保存到 Cookie ⭐注意⭐ 以下操作依赖于 Django 项目的国际化支持。如果你不清楚如何启用国际化功能&am…

【初阶数据结构与算法】线性表之栈的定义与实现(含源码和有效的括号练习)

文章目录 一、栈的概念与结构1.栈的概念与操作2.栈的底层结构选型 二、栈的实现1.栈结构的定义2. 栈的初始化和销毁栈的初始化栈的销毁 3.栈的扩容与入栈栈的扩容入栈 4.判断栈是否为空和出栈判断栈是否为空出栈 5.取栈顶元素和获取栈中有效元素个数取栈顶元素获取栈中有效元素…

基于Spring Boot+Unipp的博物馆预约小程序(协同过滤算法、二维码识别)【原创】

&#x1f388;系统亮点&#xff1a;协同过滤算法、二维码识别&#xff1b; 一.系统开发工具与环境搭建 1.系统设计开发工具 后端使用Java编程语言的Spring boot框架 项目架构&#xff1a;B/S架构 运行环境&#xff1a;win10/win11、jdk17 前端&#xff1a; 技术&#xff1a;框…

Python 快速入门(上篇)❖ Python基础知识

Python 基础知识 Python安装**运行第一个程序:基本数据类型算术运算符变量赋值操作符转义符获取用户输入综合案例:简单计算器实现Python安装** Linux安装: yum install python36 -y或者编译安装指定版本:https://www.python.org/downloads/source/ wget https://www.pyt…

【MyBatisPlus·最新教程】包含多个改造案例,常用注解、条件构造器、代码生成、静态工具、类型处理器、分页插件、自动填充字段

文章目录 一、MyBatis-Plus简介二、快速入门1、环境准备2、将mybatis项目改造成mybatis-plus项目&#xff08;1&#xff09;引入MybatisPlus依赖&#xff0c;代替MyBatis依赖&#xff08;2&#xff09;配置Mapper包扫描路径&#xff08;3&#xff09;定义Mapper接口并继承BaseM…

【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南

文章目录 PyTorch 全面解析2.1 PyTorch 的发展历程2.2 PyTorch 的核心特点2.3 PyTorch 的应用场景 TensorFlow 全面解析3.1 TensorFlow 的发展历程3.2 TensorFlow 的核心特点3.3 TensorFlow 的应用场景 Keras 全面解析4.1 Keras 的发展历程4.2 Keras 的核心特点4.3 Keras 的应用…

Chrome 浏览器 131 版本新特性

Chrome 浏览器 131 版本新特性 一、Chrome 浏览器 131 版本更新 1. 在 iOS 上使用 Google Lens 搜索 自 Chrome 126 版本以来&#xff0c;用户可以通过 Google Lens 搜索屏幕上看到的任何图片或文字。 要使用此功能&#xff0c;请访问网站&#xff0c;并点击聚焦时出现在地…

2.不同语音ai任务dataset类写法

主流语音任务 语音数据读取基本原则 直接保存语音会将该对象保存在内存中&#xff08;Dataset类调用__getitem__方法&#xff09; 所以一般保存这些数据的存储路径文档&#xff08;表单&#xff09;而不是数据的直接copy&#xff08;不然占用内存太大了&#xff09; 通常用nump…

K8S + Jenkins 做CICD

前言 这里会做整体CICD的思路和流程的介绍&#xff0c;会给出核心的Jenkins pipeline脚本&#xff0c;最后会演示一下 实验/实操 结果 由于整体内容较多&#xff0c;所以不打算在这里做每一步的详细演示 - 本文仅作自己的实操记录和日后回顾用 要看保姆式教学的可以划走了&…

力扣 LeetCode 701. 二叉搜索树中的插入操作(Day10:二叉树)

解题思路&#xff1a; 全部插入到叶子节点即可 class Solution {public TreeNode insertIntoBST(TreeNode root, int val) {if (root null) {TreeNode node new TreeNode(val);return node;}if (root.val < val) {root.right insertIntoBST(root.right, val);}if (root…

2024年11月22日Github流行趋势

项目名称&#xff1a;twenty 项目维护者&#xff1a;charlesBochet, lucasbordeau, Weiko, FelixMalfait, bosiraphael 项目介绍&#xff1a;正在构建一个由社区驱动的现代Salesforce替代方案。 项目star数&#xff1a;22,938 项目fork数&#xff1a;2,413 项目名称&#xff1…

Qt之QMainWidget相关

QMainWindow 继承于QWidget的子类 自带一个菜单栏,一个工具栏,可以设置状态栏与铆钉部件 菜单栏:QMenuBar 注意:一个窗口最多一个菜单栏 API: 创建 QMenuBar(parent) 获取QMainWindow自带的菜单栏 QMenuBar* menuBar() 添加菜单:QMenu addMenu(QMenu *menu); 菜单添加活动:QAct…

【深度学习之一】2024最新pytorch+cuda+cudnn下载安装搭建开发环境

兵马未动&#xff0c;粮草先行。作为深度学习的初学者&#xff0c;快速搭建一个属于自己的开发环境就是头等大事&#xff0c;可以让我们节省许多的时间。这一期我们主要讲一讲2024年最新pytorchcudacudnn下载安装搭建开发环境&#xff0c;以及安装过程中可能遇到的一些问题以及…

SQL 复杂查询

目录 复杂查询 一、目的和要求 二、实验内容 &#xff08;1&#xff09;查询出所有水果产品的类别及详情。 查询出编号为“00000001”的消费者用户的姓名及其所下订单。&#xff08;分别采用子查询和连接方式实现&#xff09; 查询出每个订单的消费者姓名及联系方式。 在…

如何在 UniApp 中实现 iOS 版本更新检测

随着移动应用的不断发展&#xff0c;保持应用程序的更新是必不可少的&#xff0c;这样用户才能获得更好的体验。本文将帮助你在 UniApp 中实现 iOS 版的版本更新检测和提示&#xff0c;适合刚入行的小白。我们将分步骤进行说明&#xff0c;每一步所需的代码及其解释都会一一列出…

ssm面向品牌会员的在线商城小程序

摘要 随着Internet的发展&#xff0c;人们的日常生活已经离不开网络。未来人们的生活与工作将变得越来越数字化&#xff0c;网络化和电子化。它将是直接管理面向品牌会员的在线商城小程序的最新形式。本小程序是以面向品牌会员的在线商城管理为目标&#xff0c;使用 java技术制…