GRU模块:nn.GRU层

摘要: 

       如果需要深入理解GRU的话,内部实现的详细代码和计算公式就比较重要,中间的一些过程及中间变量的意义需要详细关注。只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等。所以,仔细研究这个模块的实现还是非常有必要的。以此类推,对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

       本文中介绍GRU内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即语句:self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。

       先从输入输出的角度看,即代码中的这一行:output, state = self.rnn(X) 。在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。 

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
       super(Seq2SeqEncoder, self).__init__(**kwargs)
        # 嵌⼊层
       self.embedding = nn.Embedding(vocab_size, embed_size)
       self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
       dropout=dropout)

    def forward(self, X, *args):
    # 输出'X'的形状:(batch_size,num_steps,embed_size)
        X = self.embedding(X)
    # 在循环神经⽹络模型中,第⼀个轴对应于时间步
        X = X.permute(1, 0, 2)
    # 如果未提及状态,则默认为0
        output, state = self.rnn(X)
    # output的形状:(num_steps,batch_size,num_hiddens)
    # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。

GRU模块的框图 

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

GRU的内部实现

上面一节的代码示例,是通过调用API实现的,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)。那么,GRU内部是如何实现的呢?

分为模型、模型参数初始化和隐状态初始化三个部分:

模型定义(模型定义与数学表示式一致,也可以参考上图):

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

  模型参数初始化(权重是从标准差0.01的高斯分布中提取的,超参数num_hiddens定义隐藏单元的数量,偏置项设置为0):

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01
    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three() # 更新⻔参数
    W_xr, W_hr, b_r = three() # 重置⻔参数
    W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

隐状态初始化函数(此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值都为0

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

最后由一个函数统一起来,实现模型:

model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

小结 

       总体上说,内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。但是,如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/613062.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

实现流程化办公,可了解一下可视化报表开源

当前,实现流程化办公早已成为众多中小企业的发展目标和趋势。可以借助什么样的软件平台实现这一目标?低代码技术平台拥有可视化操作界面、够灵活、易维护等优势特点,在助力企业实现流程化办公、数字化转型方面具有重要的应用价值和推动作用。…

Selenium定位方法汇总及举例

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

arp icmp 等报文格式

ARP报文格式 ARP是一个独立的三层协议,所以ARP报文在向数据链路层传输时不需要经过IP协议的封装,而是直接生成自己的报文,其中包括ARP报头,到数据链路层后再由对应的数据链路层协议(如以太网协议)进行封装…

2022 年全国职业院校技能大赛高职组云计算赛项试卷(容器云)

#需要资源(软件包及镜像)或有问题的,可私聊博主!!! #需要资源(软件包及镜像)或有问题的,可私聊博主!!! #需要资源(软件包…

C#上位机1ms级高精度定时任务

precisiontimer 安装扩展包 添加引用 完整代码 using PrecisionTiming;using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; us…

Android虚拟机机制

目录 一、Android 虚拟机 dalvik/art(6版本后)二、Android dex、odex、oat、vdex、art区别 一、Android 虚拟机 dalvik/art(6版本后) 每个应用都在其自己的进程中运行,都有自己的虚拟机实例。ART通过执行DEX文件可在设…

C——单链表

一.前言 我们在前面已经了解了链表中的双向链表,而我们在介绍链表分类的时候就说过常用的链表只有两种——双向带头循环链表和单向不带头不循环链表。下来我来介绍另一种常用的链表——单向不带头不循环链表也叫做单链表。不清楚链表分类的以及不了解双向链表的可以…

react18【系列实用教程】JSX (2024最新版)

为什么要用 JSX? JSX 给 HTML 赋予了 JS 的编程能力 JSX 的本质 JSX 是 JavaScript 的语法扩展,浏览器本身不能识别,需要通过解析工具(如babel)解析之后才能在浏览器中运行。 bable 官网可以查看解析过程 JSX 的语法 …

杭州打的样,适合全国推广

房地产 昨天,杭州和西安全面解除房地产限购。 在房价跌跌不休的今天,这两大城市取消限购其实并不意外。 尤其是杭州,土地财政依赖全国第一,绷不住很正常。 近十年,杭州依靠于亚运会、G20 和阿里巴巴,涨得飞…

将机械手与CodeSys中的运动学模型绑定

文章目录 1.背景介绍2.选定运动学模型3.机械手各尺寸的对应4.总结4.1.选择正确的运动学模型4.2.注意各个关节旋转的正方向。4.3.编码器零点与机械零点的偏移修正。 1.背景介绍 最近搞到了一台工业机械手,虽然这个机械手有自己的控制程序,但是我们还是想…

Java入门基础学习笔记1——初识java

1、为什么学习java? 几乎统治了服务端的开发;几乎所有的互联网企业都使用;100%国内大中型企业都用;全球100亿的设备运行java。开发岗位薪资高。 Java的流行度很高,商用占有率很高。 可移植性。 2、Java的背景知识 …

【基础算法总结】二分查找一

二分查找一 1. 二分查找2.在排序数组中查找元素的第一个和最后一个位置3.x 的平方根4.搜索插入位置 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励,我们一起努力吧!😃&#x1…

Java入门基础学习笔记12——变量详解

变量详解: 变量里的数据在计算机中的存储原理。 二进制: 只有0和1, 按照逢2进1的方式表示数据。 十进制转二进制的算法: 除二取余法。 6是110 13是1101 计算机中表示数据的最小单元:一个字节(byte&…

【教程向】从零开始创建浏览器插件(三)解决 Chrome 扩展中弹出页面、背景脚本、内容脚本之间通信的问题

第三步:解决 Chrome 扩展中弹出页面、背景脚本、内容脚本之间通信的问题 Chrome 扩展开发中,弹出页面(Popup)、背景脚本(Background Script)、内容脚本(Content Script)各自拥有独立…

word转pdf的java实现(documents4j)

一、多余的话 java实现word转pdf可用的jar包不多,很多都是收费的。最近发现com.documents4j挺好用的,它支持在本机转换,也支持远程服务转换。但它依赖于微软的office。电脑需要安装office才能转换。鉴于没在linux中使用office,本…

hadoop学习---基于Hive的教育平台数据仓库分析案例(二)

衔接第一部分,第一部分请点击:基于Hive的教育平台数据仓库分析案例(一) 后接第三部分,第三部分请点击:基于Hive的教育平台数据仓库分析案例 (三) 意向用户模块(全量分析)&#…

用户体验优化uxo指的是什么?

用户体验优化(User Experience Optimization,简称UXO)是一种专注于改善和提升用户在使用企业产品或服务时的整体感受和体验的过程。简单来说,它旨在通过改进产品或服务的设计和功能,使用户在使用过程中感到更加愉悦、满意和高效。用户体验优化…

桌面怎么分类便签 桌面分类便签设置方法

桌面便签,一直是我工作和学习的好帮手。每当灵感闪现或是有待办事项,我都会随手记录在便签上,它们就像我桌面上的小助手,时刻提醒我不要遗漏任何重要事务。 但便签一多,管理就成了问题。一张张五颜六色的便签贴满了我…

C++ 多态的相关问题

目录 1. 第一题 2. 第二题 3. inline 函数可以是虚函数吗 4. 静态成员函数可以是虚函数吗 5. 构造函数可以是虚函数吗 6. 析构函数可以是虚函数吗 7. 拷贝构造和赋值运算符重载可以是虚函数吗 8. 对象访问普通函数快还是访问虚函数快 9. 虚函数表是什么阶段生成的&…

华为与达梦数据签署全面合作协议

4月26日,武汉达梦数据库股份有限公司(简称“达梦数据”)与华为技术有限公司(简称“华为”)在达梦数据武汉总部签署全面合作协议。 达梦数据总经理皮宇、华为湖北政企业务总经理吕晓龙出席并见证签约;华为湖…