GRU模块:nn.GRU层的介绍

       如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

       本文中介绍GRU内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。

 

       从输入输出的角度看,在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。 

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
       super(Seq2SeqEncoder, self).__init__(**kwargs)
        # 嵌⼊层
       self.embedding = nn.Embedding(vocab_size, embed_size)
       self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
       dropout=dropout)

    def forward(self, X, *args):
    # 输出'X'的形状:(batch_size,num_steps,embed_size)
        X = self.embedding(X)
    # 在循环神经⽹络模型中,第⼀个轴对应于时间步
        X = X.permute(1, 0, 2)
    # 如果未提及状态,则默认为0
        output, state = self.rnn(X)
    # output的形状:(num_steps,batch_size,num_hiddens)
    # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。

GRU的内部实现

上面一节的代码示例,是通过调用API实现的,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)。那么,GRU内部是如何实现的呢?

分为模型、模型参数初始化和隐状态初始化三个部分:

模型定义(模型定义与数学表示式一致,也可以参考上图):

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

  模型参数初始化(权重是从标准差0.01的高斯分布中提取的,超参数num_hiddens定义隐藏单元的数量,偏置项设置为0):

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01
    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three() # 更新⻔参数
    W_xr, W_hr, b_r = three() # 重置⻔参数
    W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

隐状态初始化函数(此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值都为0

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

最后由一个函数统一起来,实现模型:

model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

       总体上说,内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。但是,如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/602861.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

值得推荐的多款iPaaS工具

当今企业面临着日益复杂的数据和系统集成挑战,为了提高业务效率和灵活性,许多企业转向了iPaaS工具(Integration Platform as a Service,即集成平台即服务)。iPaaS工具可以帮助企业轻松地连接和集成各种应用程序、数据和…

如何切换PHP版本

如果服务器上安装了多个php,可能会导致默认的php版本错误,无法启动swoole等服务, 查看命令行的php版本方法:https://q.crmeb.com/thread/9921 解决方法如下,选一个即可: 一、切换命令行php版本&#xff…

servlet-会话(cookie与session)

servlet会话技术 会话技术cookie创建Cookieindex.jspCookieServlet 获取Cookieindex.jspshowCookie session创建sessionindex.jsplogin.jspLoginServlet 获取sessionRedurectServket 清除会话login.jspClearItmeServlet 会话技术 两种会话:cookie,sessi…

在 Linux 中创建文件

目录 ⛳️推荐 前言 使用 touch 命令创建一个新的空文件 使用 echo 命令创建一个新文件 使用 cat 命令创建新文件 测试你的知识 ⛳️推荐 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到…

利用Python简单操作MySQL数据库,轻松实现数据读写

PyMySQL是Python编程语言中的一个第三方模块,它可以让Python程序连接到MySQL数据库并进行数据操作。它的使用非常简单,只需要安装PyMySQL模块,然后按照一定的步骤连接到MySQL数据库即 可。本文将介绍PyMySQL的安装、连接MySQL数据库、创建表、…

嗨动PDF编辑器V1.60版本发布,有哪些亮点值得注意!

嗨动PDF编辑器V1.60版发布,有哪些亮点值得注意呢? 在数字信息爆炸的时代,PDF文档以其跨平台、易于阅读和保持格式统一的特性,成为了工作、学习和生活中的常客。但很多时候,我们收到的PDF文档只是“只读”的&#xff0…

什么是香草看涨期权?香草看涨期权有哪些特点?

什么是香草看涨期权?香草看涨期权有哪些特点? 香草看涨期权,通常也称为香草期权,是金融市场上的一种金融衍生品,由券商或金融机构推出。它允许投资者以较小的费用获取相应股票市值的收益权,主要用于风险管…

6款好用的数据恢复软件推荐【不收费】+【收费】

日常办公和学习中,总有一些小粗心鬼会不小心误删了自己的重要文件,或者是由于设备故障导致数据丢失。如果需要进行数据恢复,那么可以试试数据恢复工具,只需要自己再电脑中操作,就可以帮助找回数据文件,下面…

基于随机森林与支持向量机的高光谱图像分类(含python代码)

目录 一、背景 二、代码实现 三、项目代码 一、背景 基于深度学习的教程(卷积神经网络)详见:基于卷积神经网络的高光谱图像分类详细教程(含python代码)-CSDN博客 在高光谱图像分类领域,随机森林&#…

CAT Game Builder

CAT游戏生成器是在Unity中制作游戏的更简单、更快的方法。使用CAT的模块化条件、动作、触发器和集成游戏系统,您可以创建原型、演示或完整游戏,而无需额外编程。 CAT Game Builder是一个完全可扩展的框架,专为专业团队设计,但对初学者来说足够容易掌握。CAT Game Builder使…

2024蓝桥杯CTF writeUP--Theorem

密码方向的签到题,根据题目已知n、e和c,并且p和q是相邻的素数,可以考虑分解。 通过prevprime函数分解n,然后 RSA解密即可: from Crypto.Util.number import long_to_bytes import gmpy2 import libnumfrom sympy im…

大语言模型LLM入门篇

大模型席卷全球,彷佛得模型者得天下。对于IT行业来说,以后可能没有各种软件了,只有各种各样的智体(Agent)调用各种各样的API。在这种大势下,笔者也阅读了很多大模型相关的资料,和很多新手一样&a…

1.数据结构---顺序表

ArrayList 在new的时候并没有进行内存的分配 此时才进行内存分配 两个结论: 第一次Add的时候分配大小为10的内存 扩容是1.5倍扩容

如何修复显示器或笔记本电脑屏幕的黄色色调?这里提供几种方法

序言 如果你的笔记本电脑屏幕呈淡黄色,则可以启用夜灯功能。该问题也可能源于连接松散的显示电缆、损坏的显卡驱动程序或错误配置的显示器设置。以下是一些故障排除步骤,你可以尝试解决此问题。 禁用夜间模式 夜间模式功能旨在减少显示器的蓝色色调,使屏幕看起来更温暖,…

光伏设备数据交互模硬件接口要求

模组的弱电接口采用26(间距2.54mm)双排插针作为连接件,模组与电能表的硬件接口示意图如 图1所示(模组正视图方向),接口定义说明见表3。模组外接插座和插头采用凤凰端子结构,接口示意 图应符合附…

网贷大数据查询要怎么保证准确性?

相信现在不少人都听说过什么是网贷大数据,但还有很多人都会将它跟征信混为一谈,其实两者有本质上的区别,那网贷大数据查询要怎么保证准确性呢?本文将为大家总结几点,感兴趣的朋友不妨去看看。 想要保证网贷大数据查询的准确度&am…

差动绕组电流互感器过电压保护器ACTB

安科瑞薛瑶瑶18701709087/17343930412 电流互感器在运行中如果二次绕组开路或一次绕组流过异常电流,都会在二次侧产生数千伏甚至上万伏的过电压。这不仅会使CT和二次设备损坏,也严重威胁运行人员的生命安全,并造成重大经济损失。采用电流互感…

SpringBoot多数据源配置

🍓 简介:java系列技术分享(👉持续更新中…🔥) 🍓 初衷:一起学习、一起进步、坚持不懈 🍓 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正🙏 🍓 希望这篇文章对你有所帮助,欢…

Git知识点总结

目录 1、版本控制 1.1什么是版本控制 1.2常见的版本控制工具 1.3版本控制分类 2、集中版本控制 SVN 3、分布式版本控制 Git 2、Git与SVN的主要区别 3、软件下载 安装:无脑下一步即可!安装完毕就可以使用了! 4、启动Git 4.1常用的Li…

CentOS 7 :虚拟机网络环境配置+ 安装gcc(新手进)

虚拟机安装完centos的系统却发现无法正常联网,咋破! 几个简单的步骤: 一、检查和设置虚拟机网络适配器 这里笔者使用的桥接模式,朋友们可以有不同的选项设置 二、查看宿主机的网络 以笔者的为例,宿主机采用wlan上网模…