动手学深度学习(Pytorch版)代码实践 -循环神经网络-57长短期记忆网络(LSTM)

57长短期记忆网络(LSTM

1.LSTM原理

LSTM是专为解决标准RNN的长时依赖问题而设计的。标准RNN在训练过程中,随着时间步的增加,梯度可能会消失或爆炸,导致模型难以学习和记忆长时间间隔的信息。LSTM通过引入一组称为门的机制来解决这个问题:

  1. 输入门(Input Gate):控制有多少新的信息可以传递到记忆单元中。
  2. 遗忘门(Forget Gate):控制当前记忆单元中有多少信息会被保留。
  3. 输出门(Output Gate):控制记忆单元的输出有多少被传递到下一步。

LSTM还引入了一个称为记忆单元(Cell State)的概念,用于携带长期信息。这些门的组合使得LSTM能够选择性地记住或遗忘信息,从而解决了长时依赖问题。
在这里插入图片描述
在这里插入图片描述

2.优点
  1. 解决梯度消失问题:通过门控机制,LSTM能够有效地传递梯度,避免了梯度消失和爆炸的问题。
  2. 捕捉长时依赖LSTM能够记住和利用长时间间隔的信息,这是标准RNN难以做到的。
  3. 灵活性LSTM适用于各种序列数据处理任务,如时间序列预测、语言建模和序列到序列的翻译等。
3.LSTMGRU的区别

GRU(门控循环单元)是另一种解决长时依赖问题的RNN变体。GRULSTM都引入了门控机制,但它们的具体实现有所不同。

  1. 结构简化GRU的结构比LSTM更简单,参数更少,计算效率更高。
  2. 性能对比:在一些任务上,GRULSTM的性能相当,但在某些情况下,GRU可能表现更好,特别是在较小的数据集或较短的序列上。
  3. 门的数量LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(更新门和重置门)。
4.LSTM代码实践
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 设置批量大小和序列步数
batch_size, num_steps = 32, 35
# 加载时间机器数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# 初始化LSTM模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    # 输入输出的维度大小
    num_inputs = num_outputs = vocab_size

    # 正态分布初始化权重
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    # 三个权重参数(用于输入门、遗忘门、输出门和候选记忆元)
    def three():
        return (normal((num_inputs, num_hiddens)),  # 输入到隐藏状态的权重
                normal((num_hiddens, num_hiddens)),  # 隐藏状态到隐藏状态的权重
                torch.zeros(num_hiddens, device=device))  # 偏置

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))  # 隐藏状态到输出的权重
    b_q = torch.zeros(num_outputs, device=device)  # 输出偏置
    # 将所有参数附加到参数列表中
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)  # 设置参数需要梯度
    return params

# 初始化LSTM的隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐藏状态
            torch.zeros((batch_size, num_hiddens), device=device))  # 记忆元

# LSTM前向传播
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state  # 隐藏状态和记忆元
    outputs = []
    for X in inputs:
        # 输入门
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        # 遗忘门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        # 输出门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        # 候选记忆元
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        # 更新记忆元
        C = F * C + I * C_tilda
        # 更新隐藏状态
        H = O * torch.tanh(C)
        # 计算输出
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)  # 返回输出和状态

# 训练和预测模型
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 创建自定义的LSTM模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.3, 34433.0 tokens/sec on cuda:0
# 预测结果示例:time traveller conellace there wardeal that are almost us we hou

# 使用PyTorch的简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)  # 创建LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))  # 创建模型
model = model.to(device)  # 将模型移动到GPU
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 317323.7 tokens/sec on cuda:0
# 预测结果示例:time travelleryou can show black is white by argument said filby

自定义的LSTM模型:

在这里插入图片描述
简洁实现:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/791958.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

55070-001J 同轴连接器

型号简介 55070-001J是Southwest Microwave的连接器。这款连接器外壳和中心接触件采用 BeCu 合金制成,这是一种具有良好导电性和机械性能的铜合金。绝缘珠则使用了 PEEK HT 材料制成,这是一种耐高温、耐化学腐蚀的工程塑料。为了确保连接的可靠性和稳定性…

旷野之间11 - 开源 SORA 已问世!训练您自己的 SORA 模型!

​​​​​ 目前最接近 SORA 的开源模型是 Latte,它采用了与 SORA 相同的 Vision Transformer 架构。Vision Transformer 究竟有何独特之处?它与之前的方法有何不同? Latte 尚未开源其文本转视频训练代码。我们复制了论文中的文本转视频训练代码,并将其提供给任何人使用,…

袋鼠云产品支持全栈信创适配,更加安全可靠、自主可控

随着国产替换的深化,企业对信创产品的需求逐渐融合更丰富的业务诉求以及未来数智规划,正从“同类替换”转向“迭代升级”。 当前,袋鼠云的产品与芯片、服务器、数据库、操作系统、中间件、云平台等主流信创厂商全面兼容适配,为企…

3.相机标定原理及代码实现(opencv)

1.相机标定原理 相机参数的确定过程就叫做相机标定。 1.1 四大坐标系及关系 (1)像素坐标系(单位:像素(pixel)) 像素坐标系是指相机拍到的图片的坐标系,以图片的左上角为坐标原点&a…

工厂人员定位系统介绍及解决方案

着现代工业不断发展,工厂人员定位系统已经广泛的用于各个大型工厂,这主要是因为UWB人员定位系统功能性非常强大,发挥着更加全面的功能优势,满足了不同厂区对工作环境的管理。而它的出现,将人员定位系统达到了更加精准的…

CISCN2024 RE 后两道 wp 复现

5. gdb_debug 其实逻辑还是挺简单的,当时没认真做 伪代码还算清晰 几个循环的加密之后判断密文 难点是前面有随机数参与加密,不过可以猜测随机数是不变的。 第一段加密 flag异或一组随机数,这里可以在异或的位置下条件断点,用…

windows信息收集和提权

目录 手动收集 工具收集 windows本地内核提权 本地提权 根据windows去找需要的exp进行利用 提权后结合mimikatz使用 msf提权 简单提权 生成后门 上线 BypassUAC绕过UAC提权 msf带的bypassuac模块可以尝试提权 Bypassuac提权命令操作 提权成功 ​local_exploi…

【python】随机森林预测汽车销售

目录 引言 1. 数据收集与预处理 2. 划分数据集 3. 构建随机森林模型 4. 模型训练 5. 模型评估 6. 模型调优 数据集 代码及结果 独热编码 随机森林模型训练 特征重要性图 混淆矩阵 ROC曲线 引言 随机森林(Random Forest)是一种集成学习方法…

全网最炸裂的5款SD涩涩模型!身体真的是越来越不好了!建议收藏,晚上自己偷偷打开看!

很多人说,**自从学了SD后,身体一天不如一天了。**今天就再接再厉,给大家推荐5个涩涩的模型。 【身材调节器】 顾名思义,这个lora可以帮你在出图时凹出想要的S型曲线。出图效果的大小由lora的权重来设定,权重越小越贫穷…

本地开发微信小程序,使用巴比达内网穿透

在微信小程序开发的热潮中,开发者常面临的一个挑战是如何在复杂的网络环境下测试和调试内网环境中的服务。巴比达正为这一难题提供了一条解决方案,极大简化了微信小程序与内网服务器之间通信的流程,加速了开发迭代周期。 以往,开…

3D渲染模型是白色的?问题出在以下6点,教你快速解决!

你的3D模型渲染出来是不是黑色、白色、粉色或者扭曲的? 出现这种情况很有可能是你的贴图纹理丢失或损坏! 幸运的是,有一些常见的方法可以解决此问题并恢复纹理。在本文中,小编将分享如何排查和解决不同方案下的纹理问题。通常问…

小红书的大模型有点「怂」

大模型发展至今,各个公司都根据自己的业务发布了各自的大模型。而小红书作为种草类产品的代表,自研的大模型一直引而不发,还在内测阶段。 AI以及自研大模型的持续火热,让以原创内容为主导的小红书坐不住了。 近期,据多…

本周六!上海场新能源汽车数据基础设施专场 Meetup 来了

本周六下午 14:30 新能源汽车数据基础设施专场 Meetup 在上海,点击链接报名 🎁 到场有机会获得 Greptime 和 AutoMQ 的精美文创周边哦~ 🔮 会后还有观众问答 & 抽奖环节等你来把神秘礼物带回家~ 🧁 更…

傅里叶级数的3D表示 包括源码

傅里叶级数的3D表示 包括源码 flyfish 傅里叶级数的基本形式 : y ( t ) ∑ n 1 , 3 , 5 , … N 4 A n π sin ⁡ ( n π T t ) y(t) \sum_{n1,3,5,\ldots}^{N} \frac{4A}{n\pi} \sin\left(\frac{n\pi}{T} t\right) y(t)n1,3,5,…∑N​nπ4A​sin(Tnπ​t) 其中&…

WPF/C#:在WPF中如何实现依赖注入

前言 本文通过 WPF Gallery 这个项目学习依赖注入的相关概念与如何在WPF中进行依赖注入。 什么是依赖注入 依赖注入(Dependency Injection,简称DI)是一种设计模式,用于实现控制反转(Inversion of Control&#xff0…

Apache Hadoop完全分布式集群搭建指南

Hadoop发行版本较多,Cloudera版本(Cloudera’s Distribution Including Apache Hadoop,简称CDH)收费版本通常用于生产环境,这里用开源免费的Apache Hadoop原始版本。 下载:Apache Hadoop 版本下载&#x…

华盈生物-PhenoCycler-超多靶标揭示组织空间位置和互作关系

华盈生物获得美国Akoya公司认证的PhenoCycler-Fusion(原CODEX)空间单细胞蛋白组技术服务商,并进入该技术的全球CRO服务提供者网络:CRO Service Providers | Akoya Biosciences为科研工作者提供为更具有精准医学特色的服务&#xf…

Failed to start mysql.service:Unit mysql.service not found(100%成功解决问题)

问题:在ubuntu中安装mysql后,启动mysql报错 你能看到我这篇文章,是非常幸运的! 安装mysql-server: 然后启动mysql报错: 那要怎么处理呢?easy,跟着下面的步骤操作就好了 首先,先卸…

算法训练营day28--134. 加油站 +135. 分发糖果+860.柠檬水找零+406.根据身高重建队列

一、 134. 加油站 题目链接:https://leetcode.cn/problems/gas-station/ 文章讲解:https://programmercarl.com/0134.%E5%8A%A0%E6%B2%B9%E7%AB%99.html 视频讲解:https://www.bilibili.com/video/BV1jA411r7WX 1.1 初见思路 得模拟分析出…

php快速入门

前言 php是一门脚本语言,可以访问服务器,对数据库增删查改(后台/后端语言) 后台语言:php,java,c,c,python等等 注意:php是操作服务器,不能直接在…