用deepseek学大模型08-长短时记忆网络 (LSTM)

deepseek.com 从入门到精通长短时记忆网络(LSTM),着重介绍的目标函数,损失函数,梯度下降 标量和矩阵形式的数学推导,pytorch真实能跑的代码案例以及模型,数据, 模型应用场景和优缺点,及如何改进解决及改进方法数据推导。

从入门到精通长短时记忆网络 (LSTM)

参考:长短时记忆网络(LSTM)在序列数据处理中的优缺点分析
LSTM


1. LSTM 核心机制

LSTM 通过门控机制(遗忘门、输入门、输出门)和细胞状态(Cell State)解决 RNN 的梯度消失问题。

核心公式(时间步 t t t):

  1. 遗忘门(Forget Gate):
    f t = σ ( W f [ h t − 1 , x t ] + b f ) \mathbf{f}_t = \sigma\left( \mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f \right) ft=σ(Wf[ht1,xt]+bf)
  2. 输入门(Input Gate):
    i t = σ ( W i [ h t − 1 , x t ] + b i ) \mathbf{i}_t = \sigma\left( \mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i \right) it=σ(Wi[ht1,xt]+bi)
    C ~ t = tanh ⁡ ( W C [ h t − 1 , x t ] + b C ) \tilde{\mathbf{C}}_t = \tanh\left( \mathbf{W}_C [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_C \right) C~t=tanh(WC[ht1,xt]+bC)
  3. 细胞状态更新
    C t = f t ⊙ C t − 1 + i t ⊙ C ~ t \mathbf{C}_t = \mathbf{f}_t \odot \mathbf{C}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{C}}_t Ct=ftCt1+itC~t
  4. 输出门(Output Gate):
    o t = σ ( W o [ h t − 1 , x t ] + b o ) \mathbf{o}_t = \sigma\left( \mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o \right) ot=σ(Wo[ht1,xt]+bo)
    h t = o t ⊙ tanh ⁡ ( C t ) \mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{C}_t) ht=ottanh(Ct)

2. 目标函数与损失函数
  • 目标函数:最小化预测与真实值的差异(监督学习)。
  • 损失函数(以分类任务交叉熵为例):
    L = − 1 T ∑ t = 1 T ∑ c = 1 C y ^ t , c log ⁡ ( y t , c ) L = -\frac{1}{T} \sum_{t=1}^T \sum_{c=1}^C \mathbf{\hat{y}}_{t,c} \log(\mathbf{y}_{t,c}) L=T1t=1Tc=1Cy^t,clog(yt,c)
    其中 C C C为类别数, y ^ \mathbf{\hat{y}} y^为真实标签的 one-hot 编码。

3. 梯度下降与数学推导

LSTM 的梯度反向传播通过细胞状态 C t \mathbf{C}_t Ct和门控机制稳定梯度流动。

标量形式推导(以遗忘门 f t \mathbf{f}_t ft为例):
∂ L ∂ f t = ∂ L ∂ h t ⋅ ∂ h t ∂ C t ⋅ ∂ C t ∂ f t \frac{\partial L}{\partial \mathbf{f}_t} = \frac{\partial L}{\partial \mathbf{h}_t} \cdot \frac{\partial \mathbf{h}_t}{\partial \mathbf{C}_t} \cdot \frac{\partial \mathbf{C}_t}{\partial \mathbf{f}_t} ftL=htLCthtftCt
其中:
∂ C t ∂ f t = C t − 1 ⊙ f t ⊙ ( 1 − f t ) \frac{\partial \mathbf{C}_t}{\partial \mathbf{f}_t} = \mathbf{C}_{t-1} \odot \mathbf{f}_t \odot (1 - \mathbf{f}_t) ftCt=Ct1ft(1ft)

矩阵形式推导(链式法则):
∂ L ∂ W f = ∑ t = 1 T ( δ f , t ⋅ [ h t − 1 , x t ] T ) \frac{\partial L}{\partial \mathbf{W}_f} = \sum_{t=1}^T \left( \delta_{f,t} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t]^T \right) WfL=t=1T(δf,t[ht1,xt]T)
其中 δ f , t \delta_{f,t} δf,t为遗忘门的梯度误差:
δ f , t = ∂ L ∂ f t ⊙ σ ′ ( ⋅ ) \delta_{f,t} = \frac{\partial L}{\partial \mathbf{f}_t} \odot \sigma'(\cdot) δf,t=ftLσ()


4. PyTorch 代码案例
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 数据生成:正弦波 + 噪声
time = torch.arange(0, 100, 0.1)
data = torch.sin(time) + 0.1 * torch.randn(len(time))

# 转换为序列数据(窗口长度=20)
def create_sequences(data, seq_length=20):
    X, y = [], []
    for i in range(len(data)-seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return torch.stack(X).unsqueeze(-1), torch.stack(y).unsqueeze(-1)

X, y = create_sequences(data)
X_train, y_train = X[:800], y[:800]  # 划分训练集和测试集
X_test, y_test = X[800:], y[800:]

# 定义 LSTM 模型
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        out, (h_n, c_n) = self.lstm(x)  # out: (batch, seq_len, hidden_size)
        out = self.fc(out[:, -1, :])    # 取最后一个时间步
        return out

model = LSTMModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练
epochs = 100
train_loss = []
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # 梯度裁剪
    optimizer.step()
    train_loss.append(loss.item())

# 可视化训练损失
plt.plot(train_loss)
plt.title("Training Loss")
plt.show()

# 预测
model.eval()
with torch.no_grad():
    train_pred = model(X_train)
    test_pred = model(X_test)

# 绘制结果
plt.figure(figsize=(12, 5))
plt.plot(data.numpy(), label="True Data")
plt.plot(range(20, 820), train_pred.numpy(), label="Train Predictions")
plt.plot(range(820, len(data)), test_pred.numpy(), label="Test Predictions")
plt.legend()
plt.show()

5. 应用场景与优缺点
  • 应用场景
    • 时间序列预测(股票价格、天气)
    • 自然语言处理(文本生成、机器翻译)
    • 语音识别
  • 优点
    • 解决长程依赖问题
    • 通过门控机制稳定梯度流动
    • 可处理变长序列
  • 缺点
    • 计算复杂度高(参数多)
    • 对短序列可能过拟合
    • 训练时间较长

6. 改进方法及数学推导
  1. GRU(门控循环单元)
    简化 LSTM,合并遗忘门和输入门:
    z t = σ ( W z [ h t − 1 , x t ] ) \mathbf{z}_t = \sigma(\mathbf{W}_z [\mathbf{h}_{t-1}, \mathbf{x}_t]) zt=σ(Wz[ht1,xt])
    r t = σ ( W r [ h t − 1 , x t ] ) \mathbf{r}_t = \sigma(\mathbf{W}_r [\mathbf{h}_{t-1}, \mathbf{x}_t]) rt=σ(Wr[ht1,xt])
    h ~ t = tanh ⁡ ( W [ r t ⊙ h t − 1 , x t ] ) \tilde{\mathbf{h}}_t = \tanh(\mathbf{W} [\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t]) h~t=tanh(W[rtht1,xt])
    h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t ht=(1zt)ht1+zth~t

  2. 双向 LSTM(Bi-LSTM)
    同时捕捉前向和后向依赖:
    h t → = LSTM ( x t , h t − 1 → ) \overrightarrow{\mathbf{h}_t} = \text{LSTM}(\mathbf{x}_t, \overrightarrow{\mathbf{h}_{t-1}}) ht =LSTM(xt,ht1 )
    h t ← = LSTM ( x t , h t + 1 ← ) \overleftarrow{\mathbf{h}_t} = \text{LSTM}(\mathbf{x}_t, \overleftarrow{\mathbf{h}_{t+1}}) ht =LSTM(xt,ht+1 )
    h t = [ h t → , h t ← ] \mathbf{h}_t = [\overrightarrow{\mathbf{h}_t}, \overleftarrow{\mathbf{h}_t}] ht=[ht ,ht ]

  3. 注意力机制
    增强对关键时间步的关注:
    α t = softmax ( v T tanh ⁡ ( W h h t + W s s ) ) \alpha_t = \text{softmax}(\mathbf{v}^T \tanh(\mathbf{W}_h \mathbf{h}_t + \mathbf{W}_s \mathbf{s})) αt=softmax(vTtanh(Whht+Wss))
    c = ∑ t = 1 T α t h t \mathbf{c} = \sum_{t=1}^T \alpha_t \mathbf{h}_t c=t=1Tαtht


7. 关键改进的数学验证(以 GRU 为例)
  • 梯度稳定性
    GRU 的更新门 z t \mathbf{z}_t zt控制历史信息的保留比例,梯度可沿两条路径传播:
    ∂ h t ∂ h t − 1 = ( 1 − z t ) + z t ⊙ ∂ h ~ t ∂ h t − 1 \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}} = (1 - \mathbf{z}_t) + \mathbf{z}_t \odot \frac{\partial \tilde{\mathbf{h}}_t}{\partial \mathbf{h}_{t-1}} ht1ht=(1zt)+ztht1h~t
    避免传统 RNN 的连乘梯度。

通过上述内容,您可全面掌握 LSTM 的理论基础、实际实现及优化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/972208.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

力扣 买卖股票的最佳时机

贪心算法典型例题。 题目 做过股票交易的都知道,想获取最大利润,就得从最低点买入,最高点卖出。这题刚好可以用暴力,一个数组中找到最大的数跟最小的数,然后注意一下最小的数在最大的数前面即可。从一个数组中选两个数…

mysql的rpm包安装

(如果之前下载过mariadb,使用yum remove mariadb卸载,因为mariadb与rpm包安装的mysql有很多相似的组件和文件,会发生冲突,而源码包安装的mysql不会,所以不用删除源码包安装myqsl,只删除mariadb就可以&#…

vue3 子组件属性响应性丢失分析总结(四)

一、先看例子&#xff1a; <script setup lang"ts"> import { onMounted, reactive, ref, watch } from vue; import Test from /components/Test.vue;let a {a:"a"};const aRef ref(a);var aReactive reactive(a);let bObj "B";cons…

Jenkins同一个项目不同分支指定不同JAVA环境

背景 一些系统应用,会为了适配不同的平台,导致不同的分支下用的是不同的gradle,导致需要不同的JAVA环境来编译,比如a分支需要使用JAVA11, b分支使用JAVA17。 但是jenkins上,一般都是Global Tool Configuration 全局所有环境公用一个JAVA_HOME。 尝试过用 Build 的Execut…

实现可拖拽的 Ant Design Modal 并保持下层 HTML 可操作性

前言 在开发复杂的前端界面时&#xff0c;我们常常需要一个可拖拽的弹窗&#xff08;Modal&#xff09;&#xff0c;同时又希望用户能够在弹窗打开的情况下操作下层的内容。Ant Design 的 Modal 组件提供了强大的功能&#xff0c;但默认情况下&#xff0c;弹窗会覆盖整个页面&…

网络安全三件套

一、在线安全的四个误解     Internet实际上是个有来有往的世界&#xff0c;你可以很轻松地连接到你喜爱的站点&#xff0c;而其他人&#xff0c;例如黑客也很方便地连接到你的机器。实际上&#xff0c;很多机器都因为自己很糟糕的在线安全设置无意间在机器和系统中留下了“…

Jetson Agx Orin平台JP6.0-r36.3版本修复了vi模式下的原始图像损坏(线条伪影)

1.问题描述 这是JP-6.0 GA/ l4t-r36.3.0的一个已知问题 通过vi模式捕获的图像会导致异常线条 参考下面的快照来演示这些线伪影 这个问题只能通过VI模式进行修复,不应该通过LibArgus看到。 此外,这是由于内存问题。 由于upstream已经将属性名称更改为“dma-noncoherent”…

封装neo4j的持久层和服务层

目录 持久层 mp 模仿&#xff1a; 1.抽取出通用的接口类 2.创建自定义的repository接口 服务层 mp 模仿&#xff1a; 1.抽取出一个IService通用服务类 2.创建ServiceImpl类实现IService接口 3.自定义的服务接口 4.创建自定义的服务类 工厂模式 为什么可以使用工厂…

Spring Boot (maven)分页2.0版本

前言&#xff1a; 通过实践而发现真理&#xff0c;又通过实践而证实真理和发展真理。从感性认识而能动地发展到理性认识&#xff0c;又从理性认识而能动地指导革命实践&#xff0c;改造主观世界和客观世界。实践、认识、再实践、再认识&#xff0c;这种形式&#xff0c;循环往…

Docker安装Minio对象存储

介绍 MinIO 是一种对象存储解决方案&#xff0c;提供与Amazon Web Services S3兼容的API并支持所有核心S3功能。MinIO可部署在任何地方&#xff1a;公共云或私有云、裸机基础设施、编排环境和边缘基础设施。 详情参见官方文档&#xff1a;MinIO Object Storage for Container…

BERT 大模型

BERT 大模型 EmbeddingTransformer预微调模块预训练任务 BERT 特点 : 优点 : 在语言理解相关任务中表现很好缺点 : 更适合 NLU 任务&#xff0c;不适合 NLG 任务 BERT 架构&#xff1a;双向编码模型 : Embedding 模块Transformer 模块预微调模块 Embedding Embedding 组成 …

cmake:定位Qt的ui文件

如题。在工程中&#xff0c;将h&#xff0c;cpp&#xff0c;ui文件放置到不同文件夹下&#xff0c;会存在cmake找不到ui文件&#xff0c;导致编译报错情况。 cmake通过指定文件路径&#xff0c;确保工程找到ui文件。 标识1&#xff1a;ui文件保存路径。 标识2&#xff1a;添加…

DFS算法篇:理解递归,熟悉递归,成为递归

1.DFS原理 那么dfs就是大家熟知的一个深度优先搜索&#xff0c;那么听起来很高大尚的一个名字&#xff0c;但是实际上dfs的本质就是一个递归&#xff0c;而且是一个带路径的递归&#xff0c;那么递归大家一定很熟悉了&#xff0c;大学c语言课程里面就介绍过递归&#xff0c;我…

H5自适应响应式代理记账与财政咨询服务类PbootCMS网站模板 – HTML5财务会计类网站源码下载

(H5自适应)响应式代理记账财政咨询服务类pbootcms网站模板 html5财务会计类网站源码下载 为了提升系统安全&#xff0c;请将后台文件admin.php的文件名修改一下。修改之后&#xff0c;后台登录地址就是&#xff1a;您的域名/您修改的文件名.php 模板特点&#xff1a; 1&#x…

嵌入式音视频开发(二)ffmpeg音视频同步

系列文章目录 嵌入式音视频开发&#xff08;零&#xff09;移植ffmpeg及推流测试 嵌入式音视频开发&#xff08;一&#xff09;ffmpeg框架及内核解析 嵌入式音视频开发&#xff08;二&#xff09;ffmpeg音视频同步 嵌入式音视频开发&#xff08;三&#xff09;直播协议及编码器…

工业自动化丨工业控制系统五层架构以及PLC、SCADA系统、DCS系统,从零基础到精通,收藏这篇就够了!

工业控制系统通常是几种类型控制系统的总称&#xff0c;包括监控和数据采集&#xff08;SCADA&#xff09;系统、分布式控制系统&#xff08;DCS&#xff09;和可编程逻辑控制器&#xff08;PLC&#xff09;以及其它控制系统。 【右下角**点赞、**转发、在看&#xff0c;为企业…

✨1.HTML、CSS 和 JavaScript 是什么?

✨✨ HTML、CSS 和 JavaScript 是构建网页的三大核心技术&#xff0c;它们相互协作&#xff0c;让网页呈现出丰富的内容、精美的样式和交互功能。以下为你详细介绍&#xff1a; &#x1f98b;1. HTML&#xff08;超文本标记语言&#xff09; 定义&#xff1a;HTML 是一种用于描…

MySQL基本操作——包含增删查改(环境为Ubuntu20.04,MySQL5.7.42)

1.库的操作 1.1 创建数据库 语法&#xff1a; 说明&#xff1a; 大写的表示关键字 [] 是可选项 CHARACTER SET: 指定数据库采用的字符集 COLLATE: 指定数据库字符集的校验规则 1.2 创建案例 创建一个使用utf8字符集的db1数据库 create database db1 charsetutf8; …

【项目】基于STM32F103C8T6的四足爬行机器人设计与实现(源码工程)

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;专__注&#x1f448;&#xff1a;专注主流机器人、人工智能等相关领域的开发、测试技术。 【项目】基于STM32F103C8T6的四足爬行机器人设计与…

IIS asp.net权限不足

检查应用程序池的权限 IIS 应用程序池默认使用一个低权限账户&#xff08;如 IIS_IUSRS&#xff09;&#xff0c;这可能导致无法删除某些文件或目录。可以通过以下方式提升权限&#xff1a; 方法 1&#xff1a;修改应用程序池的标识 打开 IIS 管理器。 在左侧导航树中&#x…