【深度学习笔记】6_7 门控循环单元(GRU)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.7 门控循环单元(GRU)

上一节介绍了循环神经网络中的梯度计算方法。我们发现,当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题。通常由于这个原因,循环神经网络在实际中较难捕捉时间序列中时间步距离较大的依赖关系。

门控循环神经网络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较大的依赖关系。它通过可以学习的门来控制信息的流动。其中,门控循环单元(gated recurrent unit,GRU)是一种常用的门控循环神经网络 [1, 2]。另一种常用的门控循环神经网络则将在下一节中介绍。

6.7.1 门控循环单元

下面将介绍门控循环单元的设计。它引入了重置门(reset gate)和更新门(update gate)的概念,从而修改了循环神经网络中隐藏状态的计算方式。

6.7.1.1 重置门和更新门

如图6.4所示,门控循环单元中的重置门和更新门的输入均为当前时间步输入 X t \boldsymbol{X}_t Xt与上一时间步隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1,输出由激活函数为sigmoid函数的全连接层计算得到。

在这里插入图片描述

图6.4 门控循环单元中重置门和更新门的计算

具体来说,假设隐藏单元个数为 h h h,给定时间步 t t t的小批量输入 X t ∈ R n × d \boldsymbol{X}_t \in \mathbb{R}^{n \times d} XtRn×d(样本数为 n n n,输入个数为 d d d)和上一时间步隐藏状态 H t − 1 ∈ R n × h \boldsymbol{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h。重置门 R t ∈ R n × h \boldsymbol{R}_t \in \mathbb{R}^{n \times h} RtRn×h和更新门 Z t ∈ R n × h \boldsymbol{Z}_t \in \mathbb{R}^{n \times h} ZtRn×h的计算如下:

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) , \begin{aligned} \boldsymbol{R}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xr} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hr} + \boldsymbol{b}_r),\\ \boldsymbol{Z}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xz} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hz} + \boldsymbol{b}_z), \end{aligned} Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),

其中 W x r , W x z ∈ R d × h \boldsymbol{W}_{xr}, \boldsymbol{W}_{xz} \in \mathbb{R}^{d \times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h \boldsymbol{W}_{hr}, \boldsymbol{W}_{hz} \in \mathbb{R}^{h \times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h \boldsymbol{b}_r, \boldsymbol{b}_z \in \mathbb{R}^{1 \times h} br,bzR1×h是偏差参数。3.8节(多层感知机)节中介绍过,sigmoid函数可以将元素的值变换到0和1之间。因此,重置门 R t \boldsymbol{R}_t Rt和更新门 Z t \boldsymbol{Z}_t Zt中每个元素的值域都是 [ 0 , 1 ] [0, 1] [0,1]

6.7.1.2 候选隐藏状态

接下来,门控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。如图6.5所示,我们将当前时间步重置门的输出与上一时间步隐藏状态做按元素乘法(符号为 ⊙ \odot )。如果重置门中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上一时间步的隐藏状态。如果元素值接近1,那么表示保留上一时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输入连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为 [ − 1 , 1 ] [-1, 1] [1,1]

在这里插入图片描述

图6.5 门控循环单元中候选隐藏状态的计算

具体来说,时间步 t t t的候选隐藏状态 H ~ t ∈ R n × h \tilde{\boldsymbol{H}}_t \in \mathbb{R}^{n \times h} H~tRn×h的计算为

H ~ t = tanh ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , \tilde{\boldsymbol{H}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}_{xh} + \left(\boldsymbol{R}_t \odot \boldsymbol{H}_{t-1}\right) \boldsymbol{W}_{hh} + \boldsymbol{b}_h), H~t=tanh(XtWxh+(RtHt1)Whh+bh),

其中 W x h ∈ R d × h \boldsymbol{W}_{xh} \in \mathbb{R}^{d \times h} WxhRd×h W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是权重参数, b h ∈ R 1 × h \boldsymbol{b}_h \in \mathbb{R}^{1 \times h} bhR1×h是偏差参数。从上面这个公式可以看出,重置门控制了上一时间步的隐藏状态如何流入当前时间步的候选隐藏状态。而上一时间步的隐藏状态可能包含了时间序列截至上一时间步的全部历史信息。因此,重置门可以用来丢弃与预测无关的历史信息。

6.7.1.3 隐藏状态

最后,时间步 t t t的隐藏状态 H t ∈ R n × h \boldsymbol{H}_t \in \mathbb{R}^{n \times h} HtRn×h的计算使用当前时间步的更新门 Z t \boldsymbol{Z}_t Zt来对上一时间步的隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1和当前时间步的候选隐藏状态 H ~ t \tilde{\boldsymbol{H}}_t H~t做组合:

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \boldsymbol{H}_t = \boldsymbol{Z}_t \odot \boldsymbol{H}_{t-1} + (1 - \boldsymbol{Z}_t) \odot \tilde{\boldsymbol{H}}_t. Ht=ZtHt1+(1Zt)H~t.

在这里插入图片描述

图6.6 门控循环单元中隐藏状态的计算

值得注意的是,更新门可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如图6.6所示。假设更新门在时间步 t ′ t' t t t t t ′ < t t' < t t<t)之间一直近似1。那么,在时间步 t ′ t' t t t t之间的输入信息几乎没有流入时间步 t t t的隐藏状态 H t \boldsymbol{H}_t Ht。实际上,这可以看作是较早时刻的隐藏状态 H t ′ − 1 \boldsymbol{H}_{t'-1} Ht1一直通过时间保存并传递至当前时间步 t t t。这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。

我们对门控循环单元的设计稍作总结:

  • 重置门有助于捕捉时间序列里短期的依赖关系;
  • 更新门有助于捕捉时间序列里长期的依赖关系。

6.7.2 读取数据集

为了实现并展示门控循环单元,下面依然使用周杰伦歌词数据集来训练模型作词。这里除门控循环单元以外的实现已在6.2节(循环神经网络)中介绍过。以下为读取数据集部分。

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

6.7.3 从零开始实现

我们先介绍如何从零开始实现门控循环单元。

6.7.3.1 初始化模型参数

下面的代码对模型参数进行初始化。超参数num_hiddens定义了隐藏单元的个数。

num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)

def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)
    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))
    
    W_xz, W_hz, b_z = _three()  # 更新门参数
    W_xr, W_hr, b_r = _three()  # 重置门参数
    W_xh, W_hh, b_h = _three()  # 候选隐藏状态参数
    
    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])

6.7.3.2 定义模型

下面的代码定义隐藏状态初始化函数init_gru_state。同6.4节(循环神经网络的从零开始实现)中定义的init_rnn_state函数一样,它返回由一个形状为(批量大小, 隐藏单元个数)的值为0的Tensor组成的元组。

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

下面根据门控循环单元的计算表达式定义模型。

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)
        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)
        H_tilda = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H,)

6.7.3.3 训练模型并创作歌词

我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。

num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

我们每过40个迭代周期便根据当前训练的模型创作一段歌词。

d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)

输出:

epoch 40, perplexity 149.477598, time 1.08 sec
 - 分开 我不不你 我想你你的爱我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你
 - 不分开 我想你你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我
epoch 80, perplexity 31.689210, time 1.10 sec
 - 分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
 - 不分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
epoch 120, perplexity 4.866115, time 1.08 sec
 - 分开 我想要这样牵着你的手不放开 爱过 让我来的肩膀 一起好酒 你来了这节秋 后知后觉 我该好好生活 我
 - 不分开 你已经不了我不要 我不要再想你 我不要再想你 我不要再想你 不知不觉 我跟了这节奏 后知后觉 又过
epoch 160, perplexity 1.442282, time 1.51 sec
 - 分开 我一定好生忧 唱着歌 一直走 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的
 - 不分开 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活

6.7.4 简洁实现

在PyTorch中我们直接调用nn模块中的GRU类即可。

lr = 1e-2 # 注意调整学习率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes)

输出:

epoch 40, perplexity 1.022157, time 1.02 sec
 - 分开手牵手 一步两步三步四步望著天 看星星 一颗两颗三颗四颗 连成线背著背默默许下心愿 看远方的星是否听
 - 不分开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情走的太快就像龙卷风 不能承受我已无处
epoch 80, perplexity 1.014535, time 1.04 sec
 - 分开始想像 爸和妈当年的模样 说著一口吴侬软语的姑娘缓缓走过外滩 消失的 旧时光 一九四三 在回忆 的路
 - 不分开始爱像  不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好
epoch 120, perplexity 1.147843, time 1.04 sec
 - 分开都靠我 你拿着球不投 又不会掩护我 选你这种队友 瞎透了我 说你说 分数怎么停留 所有回忆对着我进攻
 - 不分开球我有多烦恼多 牧草有没有危险 一场梦 我面对我 甩开球我满腔的怒火 我想揍你已经很久 别想躲 说你
epoch 160, perplexity 1.018370, time 1.05 sec
 - 分开爱上你 那场悲剧 是你完美演出的一场戏 宁愿心碎哭泣 再狠狠忘记 你爱过我的证据 让晶莹的泪滴 闪烁
 - 不分开始 担心今天的你过得好不好 整个画面是你 想你想的睡不著 嘴嘟嘟那可爱的模样 还有在你身上香香的味道

小结

  • 门控循环神经网络可以更好地捕捉时间序列中时间步距离较大的依赖关系。
  • 门控循环单元引入了门的概念,从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态。
  • 重置门有助于捕捉时间序列里短期的依赖关系。
  • 更新门有助于捕捉时间序列里长期的依赖关系。

参考文献

[1] Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

[2] Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/443093.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue 下载的插件从哪里上传?npm发布插件详细记录

文章参考&#xff1a; 参考文章一&#xff1a; 封装vue插件并发布到npm详细步骤_vue-cli 封装插件-CSDN博客 参考文章二&#xff1a; npm发布vue插件步骤、组件、package、adduser、publish、getElementsByClassName、important、export、default、target、dest_export default…

HTML静态网页成品作业(HTML+CSS+JS)——和平精英介绍设计制作(4个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;使用Javacsript代码实现图片轮播&#xff0c;共有4个页面。 二、作品…

Pytorch学习 day08(最大池化层、非线性激活层)

最大池化层 最大池化&#xff0c;也叫上采样&#xff0c;是池化核在输入图像上不断移动&#xff0c;并取对应区域中的最大值&#xff0c;目的是&#xff1a;在保留输入特征的同时&#xff0c;减小输入数据量&#xff0c;加快训练。参数设置如下&#xff1a; kernel_size&#…

微信加好友频繁会被封号吗?

微信加好友频繁会被封号吗&#xff1f; 微信规定,每个人每天最多可以加20个好友&#xff0c;但一天之内如果频繁加好友&#xff0c;微信可能会出现异常提示&#xff0c;需要暂停好友添加操作。 面对微信上突如其来的大量好友申请&#xff0c;一定要谨慎处理&#xff0c;以免被…

Golang搭建grpc环境

简介 OS : Windows 11 Golang 版本: go1.22.0 grpc : 1.2 protobuffer: 1.28代理 没有代理国内环境下载不了库七牛CDN &#xff08;试过可用&#xff09; go env -w GOPROXYhttps://goproxy.cn,direct阿里云代理(运行grpc时下载包出现报错 ): go env -w GOPROXYhttps://mirr…

CCProxy代理服务器地址的设置步骤

目录 前言 一、下载和安装CCProxy 二、启动CCProxy并设置代理服务器地址 三、验证代理服务器设置是否生效 四、使用CCProxy进行代理设置的代码示例 总结 前言 CCProxy是一款常用的代理服务器软件&#xff0c;可以帮助用户实现网络共享和上网代理。本文将详细介绍CCProxy…

IntelliJ IDEA 2020.2.4试用方法

打开idea&#xff0c;准备好ide-eval-resetter压缩包。 将准备好的压缩包拖入idea中 选中弹窗中的自动重置选项&#xff0c;并点击重置 查看免费试用时长

[数据集][目标检测]变电站缺陷检测数据集VOC+YOLO格式8307张17类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;8307 标注数量(xml文件个数)&#xff1a;8307 标注数量(txt文件个数)&#xff1a;8307 标注…

汽车大灯汽车尾灯破裂裂纹破损破洞掉角崩角等问题能修复吗?修复后需要注意什么?

汽车灯罩破损修复后&#xff0c;车主需要注意以下几点&#xff1a; 检查修复效果&#xff1a;修复完成后&#xff0c;车主应该仔细检查灯罩的修复效果&#xff0c;确保破损部分已经被填补并恢复原有的透明度和光泽。如果修复效果不理想&#xff0c;需要及时联系维修店进行处理…

问题:前端获取long型数值精度丢失,后面几位都为0

文章目录 问题分析解决 问题 通过接口获取到的数据和 Postman 获取到的数据不一样&#xff0c;仔细看 data 的第17位之后 分析 该字段类型是long类型问题&#xff1a;前端接收到数据后&#xff0c;发现精度丢失&#xff0c;当返回的结果超过17位的时候&#xff0c;后面的全…

什么是工业级物联网智能网关?如何远程控制PLC?

在这个信息爆炸的时代&#xff0c;物联网技术已经逐渐渗透到我们生活的方方面面&#xff0c;而工业级物联网智能网关作为连接工业设备和云端的重要桥梁&#xff0c;更是引领着工业4.0时代的浪潮。那么&#xff0c;究竟什么是工业级物联网智能网关呢&#xff1f;今天&#xff0c…

git删除comimit提交的记录

文章目录 本地的删除远程同步修改上次提交更多详情阅读 本地的删除 例如我的提交历史如下 commit 58211e7a5da5e74171e90d8b90b2f00881a48d3a Author: test <test36nu.com> Date: Fri Sep 22 20:55:38 2017 0800add d.txtcommit 0fb295fe0e0276f0c81df61c4fd853b7a00…

详解DNS服务

华子目录 概述产生原因作用连接方式 因特网的域名结构拓扑分类域名服务器类型划分 DNS域名解析过程分类解析图图过程分析注意 搭建DNS域名解析服务器概述安装软件bind服务中的三个关键文件 配置文件分析主配置文件共4部分组成区域配置文件作用区域配置文件示例分析正向解析反向…

STM32代码调试时遇到的一些error和warning

持续更新 ERROR WARNING 1.Note: object file renamed from “xxx.o“ to “xxx_1.o“ 出现下面这些warning可能的原因&#xff1a; &#xff08;1&#xff09;没有将头文件加入到main.c中&#xff0c;检查一下在编译。 &#xff08;2&#xff09;修改源文件路径的时候忘记…

python学习the sixth day

python函数进阶 一、函数多返回值 二、函数的多种参数使用 1.位置参数 2.关键字参数 3.缺省参数 设置默认值&#xff0c;必须放在最后面 4.不定长参数 4.总结 三、匿名函数 1.函数作为参数传递 这是计算逻辑的传递&#xff0c;而非数据的传递 2.lambda匿名函数 python文件操…

【电路笔记】-PNP晶体管

PNP晶体管 文章目录 PNP晶体管1、概述2、PNP晶体管电路示例3、PNP晶体管识别1、概述 PNP 晶体管与我们在上一篇教程中看到的 NPN 晶体管器件完全相反。 在这种类型的 PNP 晶体管结构中,两个互连的二极管相对于之前的 NPN 晶体管是相反的。 这会产生正-负-正类型的配置,箭头…

JAVA实战开源项目:智能停车场管理系统(Vue+SpringBoot)

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容A. 车主端功能B. 停车工作人员功能C. 系统管理员功能1. 停车位模块2. 车辆模块3. 停车记录模块4. IC卡模块5. IC卡挂失模块 三、界面展示3.1 登录注册3.2 车辆模块3.3 停车位模块3.4 停车数据模块3.5 IC卡档案模块3.6 IC卡挂…

Android Studio下载gradle超时问题解决

方法一 1. 配置根目录的setting.gradle.kts文件 pluginManagement {repositories {maven { urluri ("https://www.jitpack.io")}maven { urluri ("https://maven.aliyun.com/repository/releases")}maven { urluri ("https://maven.aliyun.com/repos…

基于springboot的家庭装修报价系统设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 相关技术 3 1.1 SpringBoot框架 3 1.2 ECharts 3 1.3 Vue框架 3 1.4 Bootstrap框架 3 1.5 JQuery技术 4 1.6 Ajax技术 4 1.7 本章小结 4 2 系统分析 5 2.1 需求分析 5 2.2 非功能需求 7 2.3 本章小结 8 3 系统设计 9 3.1 系统总体设计 9 …

Python学习日记之学习turtle库(上 篇)

一、初步认识turtle库 turtle 库是 Python 语言中一个很流行的绘制图像的函数库&#xff0c;想象一个小乌龟&#xff0c;在一个横 轴为 x、纵轴为 y 的坐标系原点&#xff0c;(0,0)位置开始&#xff0c;它根据一组函数指令的控制&#xff0c;在这个平面 坐标系中移动&#xff0…