【论文复现】LSTM长短记忆网络

LSTM

  • 前言
  • 网络架构
    • 总线
    • 遗忘门
    • 记忆门
    • 记忆细胞
    • 输出门
  • 模型定义
    • 单个LSTM神经元的定义
    • LSTM层内结构的定义
  • 模型训练
  • 模型评估
  • 代码细节
    • LSTM层单元的首尾的处理
    • 配置Tensorflow的GPU版本

前言

LSTM作为经典模型,可以用来做语言模型,实现类似于语言模型的功能,同时还经常用于做时间序列。由于LSTM的原版论文相关版权问题,这里以colah大佬的博客为基础进行讲解。之前写过一篇Tensorflow中的LSTM详解,但是原理部分跟代码部分的联系并不紧密,实践性较强但是如果想要进行更加深入的调试就会出现原理性上面的问题,因此特此作文解决这个问题,想要用LSTM这个有趣的模型做出更加好的机器学习效果😊。

网络架构

LSTM框架图
这张图展示了LSTM在整体结构,下面就开始分部分介绍中间这个东东。

总线

在这里插入图片描述
这条是总线,可以实现神经元结构的保存或者更改,如果就是像上图一样一条总线贯穿不做任何改变,那么就是不改变细胞状态。那么如果想要改变细胞状态怎么办?可以通过来实现,这里的门跟高中生物中学的神经兴奋阈值比较像,用数学来表示就是sigmoid函数或者其他的激活函数,当门的输入达到要求,门就会打开,允许当前门后面的信息“穿过”门改变主线上面传递的信息,如果把每一个神经元看成一个时间节点,那么从上一个时间节点传到下一个时间节点过程中的门的开启与关闭就实现了时间序列数据的信息传递。
在这里插入图片描述

遗忘门

在这里插入图片描述
首先是遗忘门,这个门的作用是决定从上一个神经元传输到当前神经元的数据丢弃的程度,如果经过sigmoid函数以后输出0表示全部丢弃,输出1表示全部保留,这个层的输入是旧的信息和当前的新信息。

σ \sigma σ:sigmoid函数
W f W_f Wf:权重向量
b f b_f bf:偏置项,决定丢弃上一个时间节点的程度,如果是正数,表示更容易遗忘,如果是负数,表示比较容易记忆
h t − 1 h_{t-1} ht1:上一个时刻的输入
x t x_t xt:当前层的输入

记忆门

在这里插入图片描述
接下来是记忆门,这个门决定要记住什么信息,同时决定按照什么程度记住上一个状态的信息。

i t i_t it:在时间步t时刻的输入门激活值,计算方法跟上面的遗忘门是一样的,只是目的不一样,这里是记忆
C ~ t \tilde{C}_{t} C~t:表示上一个时刻的信息和当前时刻的信息的集合,但是是规则化到[-1,1]这个范围内了的

记忆细胞

在这里插入图片描述
有了上面的要记忆的信息和要丢弃的信息,记忆细胞的功能就可以得到实现,用 f t f_t ft这个标量决定上一个状态要遗忘什么,用 i t i_t it这个标量决定上一个状态要记住什么以及当前状态的信息要记住什么。这样就形成了一个记忆闭环了。

输出门

在这里插入图片描述
最后,在有了记忆细胞以后不仅仅不要将当前细胞状态记住,还要将当前的信息向下一层继续传输,实现公式中的状态转移。

o t o_t ot:跟前面的门公式都一样,但是功能是决定输出的程度
h t h_t ht:将输出规范到[-1,1]的区间,这里有两个输出的原因是在构建LSTM网络的时候需要有纵向向上的那个 h t h_t ht,然而在当前层的LSTM的神经元之间还是首尾相接的😍。

模型定义

单个LSTM神经元的定义


# 定义单个LSTM单元
# 定义单个LSTM单元
class My_LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(My_LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # 初始化门的权重和偏置,由于每一个神经元都有自己的偏置,所以在定义单元内部定义
        self.Wf = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bf = nn.Parameter(torch.Tensor(hidden_size))
        self.Wi = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bi = nn.Parameter(torch.Tensor(hidden_size))
        self.Wo = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bo = nn.Parameter(torch.Tensor(hidden_size))
        self.Wg = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bg = nn.Parameter(torch.Tensor(hidden_size))
        # 初始化输出层的权重和偏置
        self.W = nn.Parameter(torch.Tensor(hidden_size, output_size))
        self.b = nn.Parameter(torch.Tensor(output_size))
        
    # 用于计算每一种权重的函数
    def cal_weight(self, input, weight, bias):
        return F.linear(input, weight, bias)
    # x是输入的数据,数据的格式是(batch, seq_len, input_size),包含的是batch个序列,每个序列有seq_len个时间步,每个时间步有input_size个特征
    def forward(self, x):
        # 初始化隐藏层和细胞状态
        h = torch.zeros(1, 1, self.hidden_size).to(x.device)
        c = torch.zeros(1, 1, self.hidden_size).to(x.device)
        # 遍历每一个时间步
        for i in range(x.size(1)):
            input = x[:, i, :].view(1, 1, -1) # 取出每一个时间步的数据
            # 计算每一个门的权重
            f = torch.sigmoid(self.cal_weight(input, self.Wf, self.bf)) # 遗忘门
            i = torch.sigmoid(self.cal_weight(input, self.Wi, self.bi)) # 输入门
            o = torch.sigmoid(self.cal_weight(input, self.Wo, self.bo)) # 输出门
            C_ = torch.tanh(self.cal_weight(input, self.Wg, self.bg)) # 候选值
            # 更新细胞状态
            c = f * c + i * C_
            # 更新隐藏层
            h = o * torch.tanh(c) # 将输出标准化到-1到1之间
        output = self.cal_weight(h, self.W, self.b) # 计算输出
        return output

LSTM层内结构的定义

class My_LSTMNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(My_LSTMNetwork, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = My_LSTM(input_size, hidden_size)  # 使用自定义的LSTM单元
        self.fc = nn.Linear(hidden_size, output_size)  # 定义全连接层

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))  # LSTM层的前向传播
        out = self.fc(out[:, -1, :])  # 全连接层的前向传播
        return out


模型训练

history = model.fit(trainX, trainY, batch_size=64, epochs=50, 
                    validation_split=0.1, verbose=2)
print('compilation time:', time.time()-start)

模型评估

为了更加直观展示,这里用画图的方法进行结果展示。

fig3 = plt.figure(figsize=(20, 15))
plt.plot(np.arange(train_size+1, len(dataset)+1, 1), scaler.inverse_transform(dataset)[train_size:], label='dataset')
plt.plot(testPredictPlot, 'g', label='test')
plt.ylabel('price')
plt.xlabel('date')
plt.legend()
plt.show()

代码细节

LSTM层单元的首尾的处理

  • 首部:由于第一个节点不用接受来自上一个节点的输入,不需要有输入,当然也有一些是添加标识。

  • 尾部:由于已经进行到当前层的最后一个节点,因此输出只需要向下一层进行传递而不用向下一个节点传递,添加标识也是可以的。

配置Tensorflow的GPU版本

这一篇写的比较好,我自己的硬件环境如下图所示,需要的可以借鉴一下,当然也可以在我提供的代码链接直接用我给的environment.yml一键构建环境😃。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/646709.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

空间注意力机制

第一步是沿着通道维度进行最大池化和平均池化,比如下面3*3的特征图,有3个通道。 第二步新特征图进行拼接并经过卷积调整通道数 第三步经过Sigmoid函数后乘到输入上 代码: class SpatialAttention(layers.Layer):def __init__(self):super(S…

弱类型解析

php中 转化为相同类型后比较 先判断数据类型后比较数值 var_dump("asdf"0);#bool(true) var_dump("asdf"1);#bool(false) var_dump("0asdf"0);#bool(true) var_dump("1asdf"1);#bool(true)1、md5撞库 例&#xff1a; <?php incl…

stm32-USART串口外设

配置流程 初始化配置 1.开启时钟&#xff08;打开USART和GPIO的时钟&#xff09; void RCC_AHBPeriphClockCmd(uint32_t RCC_AHBPeriph, FunctionalState NewState); void RCC_APB2PeriphClockCmd(uint32_t RCC_APB2Periph, FunctionalState NewState); void RCC_APB1Periph…

YOLOv10 论文学习

论文链接&#xff1a;https://arxiv.org/pdf/2405.14458 代码链接&#xff1a;https://github.com/THU-MIG/yolov10 解决了什么问题&#xff1f; 实时目标检测是计算机视觉领域的研究焦点&#xff0c;目的是以较低的延迟准确地预测图像中各物体的类别和坐标。它广泛应用于自动…

Ansible02-Ansible Modules模块详解

目录 写在前面4. Ansible Modules 模块4.1 Ansible常用模块4.1.1 Command模块4.1.2 shell模块4.1.3 scrpit模块4.1.4 file模块4.1.5 copy模块4.1.6 lineinfile模块4.1.7 systemd模块4.1.8 yum模块4.1.9 get_url模块4.1.10 yum_repository模块4.1.11 user模块4.1.12 group模块4.…

【每日刷题】Day49

【每日刷题】Day49 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 110. 平衡二叉树 - 力扣&#xff08;LeetCode&#xff09; 2. 501. 二叉搜索树中的众数 - 力扣&…

JavaSE 学习记录

1. Java 内存 2. this VS super this和super是两个关键字&#xff0c;用于引用当前对象和其父类对象 this 关键字&#xff1a; this 关键字用于引用当前对象&#xff0c;即调用该关键字的方法所属的对象。 主要用途包括&#xff1a; 在类的实例方法中&#xff0c;通过 this …

H3CNE-7-TCP和UDP协议

TCP和UDP协议 TCP&#xff1a;可靠传输&#xff0c;面向连接 -------- 速度慢&#xff0c;准确性高 UDP&#xff1a;不可靠传输&#xff0c;非面向连接 -------- 速度快&#xff0c;但准确性差 面向连接&#xff1a;如果某应用层协议的四层使用TCP端口&#xff0c;那么正式的…

DragonKnight CTF2024部分wp

DragonKnight CTF2024部分wp 最终成果 又是被带飞的一天&#xff0c;偷偷拷打一下队里的pwn手&#xff0c;只出了一题 这里是我们队的wp web web就出了两个ez题&#xff0c;确实很easy&#xff0c;只是需要一点脑洞(感觉)&#xff0c; ezsgin dirsearch扫一下就发现有ind…

让大模型变得更聪明三个方向

让大模型变得更聪明三个方向 随着人工智能技术的飞速发展&#xff0c;大模型在多个领域展现出了前所未有的能力&#xff0c;但它们仍然面临着理解力、泛化能力和适应性等方面的挑战。那么&#xff0c;如何让大模型变得更聪明呢&#xff1f; 方向一&#xff1a;算法创新 1.1算…

【ML Olympiad】预测地震破坏——根据建筑物位置和施工情况预测地震对建筑物造成的破坏程度

文章目录 Overview 概述Goal 目标Evaluation 评估标准 Dataset Description 数据集说明Dataset Source 数据集来源Dataset Fields 数据集字段 Data Analysis and Visualization 数据分析与可视化Correlation 相关性Hierarchial Clustering 分层聚类Adversarial Validation 对抗…

linux系统部署Oracle11g:netca成功启动后1521端口未能启动问题

一、问题描述 执行netca命令&#xff0c;进入图形化界面&#xff0c;进行Oracle端口监听设置 #终端输入命令 netca 最终提示设置成功&#xff1a; 但是我们进行下一步“创建数据库”的时候会报错&#xff0c;说数据库端口1521未开启&#xff01; 二、问题处理 使用命令查看开…

【Python特征工程系列】一文教你使用PCA进行特征分析与降维(案例+源码)

这是我的第287篇原创文章。 一、引言 主成分分析&#xff08;Principal Component Analysis, PCA&#xff09;是一种常用的降维技术&#xff0c;它通过线性变换将原始特征转换为一组线性不相关的新特征&#xff0c;称为主成分&#xff0c;以便更好地表达数据的方差。 在特征重要…

【kubernetes】陈述式资源管理的kubectl命令合集

目录 前言 一、K8s 资源管理操作方式 1、声明式资源管理方式 2、陈述式资源管理方式 二、陈述式资源管理方式 1、kubectl 命令基本语法 2、查看基本信息 2.1 查看版本信息 2.2 查看资源对象简写 2.3 配置kubectl命令自动补全 2.4 查看node节点日志 2.5 查看集群信息…

Windows下安装配置深度学习环境

Windows下安装配置深度学习环境 1. 准备工作 1.1 环境准备 操作系统&#xff1a;win10 22H2 GPU&#xff1a;Nvidia GeForce RTX 3060 12G 1.2 安装Nvidia驱动、cuda、cuDNN 下载驱动需要注册并登录英伟达账号。我这里将下面用到的安装包放到了百度网盘&#xff0c;可以关注微信…

【Linux杂货铺】进程通信

目录 &#x1f308; 前言&#x1f308; &#x1f4c1; 通信概念 &#x1f4c1; 通信发展阶段 &#x1f4c1; 通信方式 &#x1f4c1; 管道&#xff08;匿名管道&#xff09; &#x1f4c2; 接口 ​编辑&#x1f4c2; 使用fork来共享通道 &#x1f4c2; 管道读写规则 &…

智能家居完结 -- 整体设计

系统框图 前情提要: 智能家居1 -- 实现语音模块-CSDN博客 智能家居2 -- 实现网络控制模块-CSDN博客 智能家居3 - 实现烟雾报警模块-CSDN博客 智能家居4 -- 添加接收消息的初步处理-CSDN博客 智能家居5 - 实现处理线程-CSDN博客 智能家居6 -- 配置 ini文件优化设备添加-CS…

fastadmin 树状菜单展开,合并;简要文件管理系统界面设计与实现

一&#xff0c;菜单合并效果图 源文件参考&#xff1a;fastadmin 子级菜单展开合并、分类父级归纳 - FastAdmin问答社区 php服务端&#xff1a; public function _initialize() {parent::_initialize();$this->model new \app\admin\model\auth\Filetype;$this->admin…

粤嵌—2024/5/21—打家劫舍(✔)

代码实现&#xff1a; int rob(int *nums, int numsSize) {if (numsSize 1) {return nums[0];}if (numsSize 2) {return fmax(nums[0], nums[1]);}int dp[numsSize];dp[0] nums[0];dp[1] fmax(nums[0], nums[1]);for (int i 2; i < numsSize; i) {dp[i] fmax(dp[i - 1…

东方通TongWeb结合Spring-Boot使用

一、概述 信创需要; 原状:原来的服务使用springboot框架,自带的web容器是tomcat,打成jar包启动; 需求:使用东方通tongweb来替换tomcat容器; 二、替换步骤 2.1 准备 获取到TongWeb7.0.E.6_P7嵌入版 这个文件,文件内容有相关对应的依赖包,可以根据需要来安装到本地…