什么是长短期记忆网络?

一、概念

        长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决标准RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM通过引入三个门(输入门、遗忘门和输出门)来控制信息的流动。其中,每个门都是一个神经网络层,用于决定哪些信息应该被保留,哪些信息应该被丢弃。LSTM的核心是细胞状态(cell state),它通过这些门的控制来更新和传递信息。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,C_{t-1}为前一个时间步的细胞状态向量,C_{t}为当前时间步的细胞状态变量,f_{t}为当前时间步的遗忘门向量,i_{t}为当前时间步的输入门向量,\bar{C_{t}}为当前时间步的候选细胞状态向量,o_{t}为当前时间步的输出门向量,W_{f},W_{i},W_{C},W_{o}分别为各门的权重矩阵,b_{f},b_{i},b_{C},b_{o}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。LSTM的核心内容包括以下几个部分:

1、遗忘门(Forget Gate)

        遗忘门决定细胞状态中哪些信息需要被遗忘。通过sigmoid激活函数,遗忘门的输出在0到1之间,表示每个细胞状态元素被保留的比例。

f_{t} = \sigma(W_{f} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{f})

2、输入门(Input Gate)

        输入门决定哪些新的信息需要被写入细胞状态。通过sigmoid激活函数,输入门的输出在0到1之间,表示每个候选细胞状态元素被写入的比例。候选细胞状态通过tanh激活函数生成,表示新的信息。

i_{t} = \sigma(W_{i} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{i})

\bar{C}_{t} = tanh(W_{C} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{C})

3、细胞状态更新

        细胞状态结合遗忘门和输入门的结果进行更新。遗忘门的输出与前一个时间步的细胞状态相乘,表示保留的旧信息。输入门的输出与候选细胞状态相乘,表示写入的新信息。两者相加得到当前时间步的细胞状态。

C_{t} = f_{t} \ast C_{t-1}+i_{t} \ast \bar{C}_{t}

4、输出门(Output Gate)

        输出门决定细胞状态的哪些部分将作为输出。通过sigmoid激活函数,输出门的输出在0到1之间,表示每个细胞状态元素被输出的比例。细胞状态通过tanh激活函数进行非线性变换,然后与输出门的输出相乘,得到当前时间步的隐藏状态。

o_{t} = \sigma(W_{o} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{o})

h_{t} = o_{t} \ast tanh(C_{t})

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# 生成正弦波数据
def generate_sine_wave(seq_length, num_samples):
    x = np.linspace(0, num_samples, num_samples)
    y = np.sin(x)
    data = []
    for i in range(len(y) - seq_length):
        data.append(y[i:i+seq_length+1])
    return np.array(data)

# 定义LSTM模型
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

# 超参数设置
seq_length = 50
num_samples = 1000
input_size = 1
hidden_size = 50
output_size = 1
num_layers = 2
batch_size = 64
learning_rate = 0.001
num_epochs = 5
test_size = 0.2  # 测试集占比

# 生成数据
data = generate_sine_wave(seq_length, num_samples)
X = data[:, :-1]
y = data[:, -1]

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)

# 转换为Tensor
X_train = torch.tensor(X_train.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_train = torch.tensor(y_train.reshape(-1, output_size), dtype=torch.float32)
X_test = torch.tensor(X_test.reshape(-1, seq_length, input_size), dtype=torch.float32)
y_test = torch.tensor(y_test.reshape(-1, output_size), dtype=torch.float32)

# 创建数据加载器
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型、损失函数和优化器
model = LSTMModel(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    model.train()
    for i, (inputs, labels) in enumerate(train_loader):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 测试模型
model.eval()
with torch.no_grad():
    predicted = []
    actual = []
    for inputs, labels in test_loader:
        outputs = model(inputs)
        predicted.extend(outputs.numpy())
        actual.extend(labels.numpy())

# 绘制结果
plt.plot(actual, label='Actual data')
plt.plot(predicted, label='Predicted data')
plt.legend()
plt.show()

四、总结

        LSTM能够捕捉长时间依赖关系,使得模型在处理长序列数据时表现得比标准的RNN更好。但由于LSTM的计算依赖于前一个时间步的输出,这使得这样的网络结构难以并行化,在处理大规模数据时的效率较低。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/961705.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity游戏(Assault空对地打击)开发(1) 创建项目和选择插件

目录 前言 创建项目 插件导入 地形插件 前言 这是游戏开发第一篇,进行开发准备。 创作不易,欢迎支持。 我的编辑器布局是【Tall】,建议调整为该布局,如下。 创建项目 首先创建一个项目,过程略,名字请勿…

996引擎 - NPC-动态创建NPC

996引擎 - NPC-动态创建NPC 创建脚本服务端脚本客户端脚本添加自定义音效添加音效文件修改配置参考资料有个小问题,创建NPC时没有控制朝向的参数。所以。。。自己考虑怎么找补吧。 多重影分身 创建脚本 服务端脚本 Mir200\Envir\Market_Def\test\test001-3.lua -- NPC八门名…

如何看待 OpenAI 的12天“shipmas”发布计划?

openAI的“Shipmas”并非单纯的营销活动,而是在用户增长、技术创新和市场竞争中的综合布局和战略体现。 史上最寒酸的发布会?继十月马斯克在好莱坞电影城高调发布特斯拉三款最新产品(无人出租车、无人巴士、人形机器人)后,十二月,OpenAI CEO 奥特曼宣布 OpenAI 将连续12…

蓝桥杯模拟算法:蛇形方阵

P5731 【深基5.习6】蛇形方阵 - 洛谷 | 计算机科学教育新生态 我们只要定义两个方向向量数组,这种问题就可以迎刃而解了 比如我们是4的话,我们从左向右开始存,1,2,3,4 到5的时候y就大于4了就是越界了&…

第31篇:Python开发进阶:数据可视化与前端集成

第31篇:数据可视化与前端集成 目录 数据可视化概述 什么是数据可视化数据可视化的重要性 Python中的数据可视化库 MatplotlibSeabornPlotlyBokehAltair 数据可视化的基本概念 图表类型设计原则交互性与动态性 与前端框架的集成 前端框架概述Flask与Django集成数据…

240. 搜索二维矩阵||

参考题解:https://leetcode.cn/problems/search-a-2d-matrix-ii/solutions/2361487/240-sou-suo-er-wei-ju-zhen-iitan-xin-qin-7mtf 将矩阵旋转45度,可以看作一个二叉搜索树。 假设以左下角元素为根结点, 当target比root大的时候&#xff…

maven的打包插件如何使用

默认的情况下,当直接执行maven项目的编译命令时,对于结果来说是不打第三方包的,只有一个单独的代码jar,想要打一个包含其他资源的完整包就需要用到maven编译插件,使用时分以下几种情况 第一种:当只是想单纯…

联想拯救者R720笔记本外接显示屏方法,显示屏是2K屏27英寸

晚上23点10分前下单,第二天上午显示屏送到,检查外包装没拆封过。这个屏幕左下方有几个按键,按一按就开屏幕、按一按就关闭屏幕,按一按方便节省时间,也支持阅读等模式。 显示屏是 :AOC 27英寸 2K高清 100Hz…

python:求解偏微分方程(PDEs)

1.偏微分方程基本知识 微分方程是指含有未知函数及其导数的关系式,偏微分方程是包含未知函数的偏导数(偏微分)的微分方程。 偏微分方程可以描述各种自然和工程现象,是构建科学、工程学和其他领域的数学模型主要手段。科学和工程中…

Deepseek技术浅析(二):大语言模型

DeepSeek 作为一家致力于人工智能技术研发的公司,其大语言模型(LLM)在架构创新、参数规模扩展以及训练方法优化等方面都达到了行业领先水平。 一、基于 Transformer 架构的创新 1.1 基础架构:Transformer 的回顾 Transformer 架…

13JavaWeb——SpringBootWeb之事务AOP

1. 事务管理 1.1 事务回顾 在数据库阶段我们已学习过事务了,我们讲到: 事务是一组操作的集合,它是一个不可分割的工作单位。事务会把所有的操作作为一个整体,一起向数据库提交或者是撤销操作请求。所以这组操作要么同时成功&am…

Hive:struct数据类型,内置函数(日期,字符串,类型转换,数学)

struct STRUCT(结构体)是一种复合数据类型,它允许你将多个字段组合成一个单一的值, 常用于处理嵌套数据,例如当你需要在一个表中存储有关另一个实体的信息时。你可以使用 STRUCT 函数来创建一个结构体。STRUCT 函数接受多个参数&…

【Redis】List 类型的介绍和常用命令

1. 介绍 Redis 中的 list 相当于顺序表,并且内部更接近于“双端队列”,所以也支持头插和尾插的操作,可以当做队列或者栈来使用,同时也存在下标的概念,不过和 Java 中的下标不同,Redis 支持负数下标&#x…

鸢尾花书01---基本介绍和Jupyterlab的上手

文章目录 1.致谢和推荐2.py和.ipynb区别3.Jupyterlab的上手3.1入口3.2页面展示3.3相关键介绍3.4代码的运行3.5重命名3.6latex和markdown说明 1.致谢和推荐 这个系列是关于一套书籍,结合了python和数学,机器学习等等相关的理论,总结的7本书籍…

【unity游戏开发之InputSystem——06】PlayerInputManager组件实现本地多屏的游戏(基于unity6开发介绍)

文章目录 PlayerInputManager 简介1、PlayerInputManager 的作用2、主要功能一、PlayerInputManager组件参数1、Notification Behavior 通知行为2、Join Behavior:玩家加入的行为3、Player Prefab 玩家预制件4、Joining Enabled By Default 默认启用加入5、Limit Number Of Pl…

[C语言日寄] 源码、补码、反码介绍

【作者主页】siy2333 【专栏介绍】⌈c语言日寄⌋:这是一个专注于C语言刷题的专栏,精选题目,搭配详细题解、拓展算法。从基础语法到复杂算法,题目涉及的知识点全面覆盖,助力你系统提升。无论你是初学者,还是…

记录 | 基于Docker Desktop的MaxKB安装

目录 前言一、MaxKBStep 1Step2 二、运行MaxKB更新时间 前言 参考文章:如何利用智谱全模态免费模型,生成大家都喜欢的图、文、视并茂的文章! MaxKB的Github下载地址 参考视频:【2025最新MaxKB教程】10分钟学会一键部署本地私人专属…

元宇宙下的Facebook:虚拟现实与社交的结合

随着科技的不断进步,虚拟现实(VR)技术逐渐从科幻走入现实,成为人们探索未来社交方式的重要工具。在这一浪潮中,Facebook(现为Meta)作为全球领先的社交平台,正在积极布局虚拟现实和元…

DeepSeek-R1 本地部署模型流程

DeepSeek-R1 本地部署模型流程 ***************************************************** 环境准备 操作系统:Windows11 内存:32GB RAM 存储:预留 300GB 可用空间 显存: 16G 网络: 100M带宽 ********************************************…

实验三---基于MATLAB的二阶系统动态性能分析---自动控制原理实验课

一 实验目的 1、观察学习二阶控制系统的单位阶跃响应、脉冲响应 2、记录单位阶跃响应曲线、脉冲响应曲线 3、掌握时间响应分析的一般方法 4、掌握系统阶跃响应曲线与传递函数参数的对应关系 二 实验仪器 计算机 MATLAB软件 三 实验内容及步骤 1、作以下二阶系统的单位阶跃响应…