【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/500187.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

婚恋交友APP小程序H5源码交付-支持二开!实名制交友,可服务器审核,亦可后台自己审核!同城交友,多人语音!

一、需求分析 在征婚交友网站开发初期,需求分析是至关重要的环节。这需要深入了解目标用户的需求和期望,包括他们的年龄、职业、兴趣爱好、交友条件等方面。通过收集和分析这些信息,开发团队可以明确网站的目标用户,并为他们提供…

东特科技现已加入2024第13届国际生物发酵产品与技术装备展

参展企业介绍 温州东特科技有限公司是一家集设计、生产、销售及服务为一体的卫生级流体设备企业。专业从事各种乳食品、制药、化工、啤酒设备、不锈钢卫生级阀门,管件,视镜,及非标配件定制等产品的销售与服务。先进的设计理念专业的技术优势一…

深度思考:雪花算法snowflake分布式id生成原理详解

雪花算法snowflake是一种优秀的分布式ID生成方案,其优点突出:它能生成全局唯一且递增的ID,确保了数据的一致性和准确性;同时,该算法灵活性强,可自定义各部分bit位,满足不同业务场景的需求&#…

使用html实现图片相册展示设计

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>图片&#xff08;相册&#xff09;展示设计</title><link rel"stylesheet" href"./style.css"> </head> <b…

SQL--报错注入(join无列名注入)

SQL报错注入 平时在做SQL题时&#xff0c;如果发生语法的错误时&#xff0c;就会产生报错&#xff0c;报错的信息就会显示在前端 报错注入大多是利用函数会报错的特性&#xff0c;将需要的信息通过报错信息回显出来 报错注入函数&#xff08;后面主要的还有一个floor函数暂时…

虚拟机 centos 安装后与主机 ip保持一致

1、安装时 网络模式 悬着自动检测 &#xff08;桥接&#xff09; 2、打开网络 这里如果没有打开 就去 编辑 该文件。把ONBOOTno 改为yes 后 vim /etc/sysconfig/network-scripts/ifcfg-ens160 刷新配置 systemctl restart network 再查看addr 与本机 192.168.31.xx 在同…

Python Flask Web框架初步入门

前言 flask基础 搭建flask服务器 定义html 使用templates模板定义页面的html html页面编写 render_template传参变量 定义图片 创建static目录&#xff0c;存入图片 html编写 flask入门 网站多域名 网站之间超链接跳转 入门案例 将centos的rpm包下载链接集成到自…

13-API风格(下):RPCAPI介绍

RPC在Go项目开发中用得也非常多&#xff0c;需要我们认真掌握。 RPC介绍 根据维基百科的定义&#xff0c;RPC&#xff08;Remote Procedure Call&#xff09;&#xff0c;即远程过程调用&#xff0c;是一个计算机通信协议。该协议允许运行于一台计算机的程序调用另一台计算机…

发光二极管限流电阻对电路性能有哪些影响?

目录 1.控制电流 2.稳定电压&#xff08;亮度控制&#xff09; 3.功耗控制&#xff08;保护元件&#xff09; 4.节能控制 发光二极管&#xff08;LED&#xff09;限流电阻在电路中对电路性能主要有以下几个影响&#xff1a; 1.控制电流 限流电阻通过限制电流&#xff0c;使LED工…

WebGIS开发

1.准备工作 高德开发API注册账号&#xff0c;创建项目拿到key和密钥 高德key 2.通过JS API引入高德API <html><head><meta charset"utf-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><metaname&quo…

pdf在浏览器上无法正常加载的问题

一、背景 觉得很有意思给大家分享一下。事情是这样的&#xff0c;开发给我反馈说&#xff0c;线上环境接口请求展示pdf异常&#xff0c;此时碰巧我前不久正好在ingress前加了一层nginx&#xff0c;恰逢此时内心五谷杂陈&#xff0c;思路第一时间便放在了改动项。捣鼓了好久无果…

上位机图像处理和嵌入式模块部署(qmacvisual非opencv算法编写)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 我们都知道&#xff0c;qmacvisual本身依赖于qtopencv的实现。大部分的界面都是依赖于qt的实现。图像算法部分&#xff0c;也是大部分都依赖于open…

【最新版RabbitMQ3.13】Linux安装基于源码构建的RabbitMQ教程

前言 linux环境 安装方式有三种&#xff0c;我们这里使用源码安装 Linux下rpm、yum和源码三种安装方式简介 个人语雀首发教程&#xff1a;https://www.yuque.com/wzzz/java/kl2zn22b42svsc6b csdn地址: https://blog.csdn.net/u013625306/article/details/137151862 安装版本…

机器学习模型之逻辑回归

逻辑回归是一种常用的分类算法&#xff0c;尤其适用于二分类问题。逻辑回归的核心思想是通过对数几率函数&#xff08;logistic function&#xff09;将线性回归的输出映射到概率空间&#xff0c;从而实现分类。 逻辑回归的原理&#xff1a; 逻辑回归模型使用对数几率函数&am…

wireshark创建显示过滤器实验简述

伯克利包过滤是一种在计算机网络中进行数据包过滤的技术&#xff0c;通过在内核中插入过滤器程序来实现对网络流量的控制和分析。 在数据包细节面板中创建显示过滤器&#xff0c;显示过滤器可以在wireshark捕获数据之后使用。 实验拓扑图&#xff1a; 实验基础配置&#xff1…

Unity-通过AB包使用SpriteAtlas图集(基于unity2018)

项目遇到了一个性能问题&#xff0c;需要优化UI。其中就涉及UI的合批问题&#xff0c;其中自然而然就会关联到图集的概念。旧版图集&#xff0c;Legacy Atlas&#xff0c;还没有太研究。今天主要看一下SpriteAtlas怎么使用的。 因为我们项目资源工程和Runtime是分离的&#xf…

旅游管理系统|基于springBoot旅游管理系统设计与实现(附项目源码+论文)

基于springBoot旅游管理系统设计与实现 一、摘要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本旅游管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助…

国产暴雨AI服务器X3418开启多元自主可控新篇章

在当前数字化转型的大潮中&#xff0c;算力作为新质生产力的重要动力引擎&#xff0c;对推动经济社会发展起着关键作用。尤其在人工智能领域&#xff0c;随着高性能、安全可控的AI算力需求持续攀升&#xff0c;国产化服务器的研发与应用显得尤为迫切。 作为国内专业的算力基础…

Netty核心原理剖析与RPC实践6-10

Netty核心原理剖析与RPC实践6-10 06-粘包拆包问题&#xff1a;如何获取一个完整的网络包 本节课开始我们将学习 Netty 通信过程中的编解码技术。编解码技术这是实现网络通信的基础&#xff0c;让我们可以定义任何满足业务需求的应用层协议。在网络编程中&#xff0c;我们经常…

Java八股文(高并发,分布式,JUC)

Java八股文の高并发&#xff0c;分布式&#xff0c;JUC 高并发&#xff0c;分布式&#xff0c;JUC 高并发&#xff0c;分布式&#xff0c;JUC 对于高并发的系统&#xff0c;为了提高用户的体验&#xff0c;你是如何做的&#xff1f;或者说高并发的技巧你用过哪些&#xff1f; …