R4-LSTM学习笔记

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

LSTM-火灾温度预测

  • 导入数据
  • 数据可视化
  • 设置X、y
  • 构建模型
  • 调用模型
  • 个人总结
    • LSTM 的基本结构
    • 细胞状态(Cell State)
    • LSTM 的优点

导入数据


import tensorflow as tf
import pandas     as pd
import numpy      as np

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
print(gpus)

df_1 = pd.read_csv(r"C:\Users\11054\Desktop\kLearning\R4\woodpine2.csv")
[]

数据可视化


import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams['savefig.dpi'] = 500 #图片像素
plt.rcParams['figure.dpi']  = 500 #分辨率

fig, ax =plt.subplots(1,3,constrained_layout=True, figsize=(14, 3))

sns.lineplot(data=df_1["Tem1"], ax=ax[0])
sns.lineplot(data=df_1["CO 1"], ax=ax[1])
sns.lineplot(data=df_1["Soot 1"], ax=ax[2])
plt.show()

在这里插入图片描述

#构建数据集


dataFrame = df_1.iloc[:,1:]
dataFrame
Tem1CO 1Soot 1
025.00.0000000.000000
125.00.0000000.000000
225.00.0000000.000000
325.00.0000000.000000
425.00.0000000.000000
............
5943295.00.0000770.000496
5944294.00.0000770.000494
5945292.00.0000770.000491
5946291.00.0000760.000489
5947290.00.0000760.000487

5948 rows × 3 columns

设置X、y


# 取前8个时间段的Tem1、CO 1、Soot 1为X,第9个时间段的Tem1为y。
width_X = 8
width_y = 1
X = []
y = []

in_start = 0

for _, _ in df_1.iterrows():
    in_end  = in_start + width_X
    out_end = in_end   + width_y

    if out_end < len(dataFrame):
        X_ = np.array(dataFrame.iloc[in_start:in_end , ])
        X_ = X_.reshape((len(X_)*3))
        y_ = np.array(dataFrame.iloc[in_end  :out_end, 0])

        X.append(X_)
        y.append(y_)

    in_start += 1

X = np.array(X)
y = np.array(y)

X.shape, y.shape
((5939, 24), (5939, 1))
# 归一化
from sklearn.preprocessing import MinMaxScaler

#将数据归一化,范围是0到1
sc       = MinMaxScaler(feature_range=(0, 1))
X_scaled = sc.fit_transform(X)
X_scaled.shape
(5939, 24)
X_scaled = X_scaled.reshape(len(X_scaled),width_X,3)
X_scaled.shape
(5939, 8, 3)
# 划分数据集
X_train = np.array(X_scaled[:5000]).astype('float64')
y_train = np.array(y[:5000]).astype('float64')

X_test  = np.array(X_scaled[5000:]).astype('float64')
y_test  = np.array(y[5000:]).astype('float64')
X_train.shape
(5000, 8, 3)

构建模型


# 多层 LSTM
import tensorflow.python.keras as keras
from tensorflow.python.keras.layers.core import Activation, Dropout, Dense
from tensorflow.python.keras.layers import Flatten, LSTM
model_lstm = keras.Sequential()
model_lstm.add(LSTM(units=64, activation='relu', return_sequences=True,
               input_shape=(X_train.shape[1], 3)))
model_lstm.add(LSTM(units=64, activation='relu'))

model_lstm.add(Dense(width_y))
# 只观测loss数值,不观测准确率,所以删去metrics选项
from tensorflow.python.keras.optimizers import adam_v2
optimizer = adam_v2.Adam(1e-3)
model_lstm.compile(optimizer=optimizer,
                   loss='mean_squared_error')  # 损失函数用均方误差
X_train.shape, y_train.shape
((5000, 8, 3), (5000, 1))
# 注意 此处函数弃用 需要修改源码
#‘tensorflow.python.distribute.input_lib‘ has no attribute 
# DistributedDataset 实际是判断是否为分布式数据,返回值直接改为False
history_lstm = model_lstm.fit(X_train, y_train,
                         batch_size=64,
                         epochs=40,
                         validation_data=(X_test, y_test),
                         validation_freq=1)
           
Epoch 1/40
79/79 [==============================] - 5s 18ms/step - loss: 11953.3770 - val_loss: 4300.0811
Epoch 2/40
79/79 [==============================] - 1s 14ms/step - loss: 154.9082 - val_loss: 996.6027
Epoch 3/40
79/79 [==============================] - 1s 13ms/step - loss: 15.5335 - val_loss: 341.1582
Epoch 4/40
79/79 [==============================] - 1s 13ms/step - loss: 7.5403 - val_loss: 290.5845
Epoch 5/40
79/79 [==============================] - 1s 12ms/step - loss: 7.3599 - val_loss: 265.5993
Epoch 6/40
79/79 [==============================] - 1s 13ms/step - loss: 7.0518 - val_loss: 257.2565
Epoch 7/40
79/79 [==============================] - 1s 14ms/step - loss: 7.2212 - val_loss: 291.9721
Epoch 8/40
79/79 [==============================] - 1s 14ms/step - loss: 6.5968 - val_loss: 214.1582
Epoch 9/40
79/79 [==============================] - 1s 12ms/step - loss: 7.5128 - val_loss: 243.2180
Epoch 10/40
79/79 [==============================] - 1s 14ms/step - loss: 6.6279 - val_loss: 218.3901
Epoch 11/40
79/79 [==============================] - 1s 14ms/step - loss: 6.3576 - val_loss: 220.4561
Epoch 12/40
79/79 [==============================] - 1s 13ms/step - loss: 6.5356 - val_loss: 243.2316
Epoch 13/40
79/79 [==============================] - 1s 12ms/step - loss: 6.8603 - val_loss: 225.0939
Epoch 14/40
79/79 [==============================] - 1s 13ms/step - loss: 7.3385 - val_loss: 154.1982
Epoch 15/40
79/79 [==============================] - 1s 14ms/step - loss: 7.1614 - val_loss: 225.3159
Epoch 16/40
79/79 [==============================] - 1s 13ms/step - loss: 6.5654 - val_loss: 199.5660
Epoch 17/40
79/79 [==============================] - 1s 12ms/step - loss: 6.4847 - val_loss: 176.0666
Epoch 18/40
79/79 [==============================] - 1s 11ms/step - loss: 6.3168 - val_loss: 217.2618
Epoch 19/40
79/79 [==============================] - 1s 12ms/step - loss: 6.6176 - val_loss: 244.9601
Epoch 20/40
79/79 [==============================] - 1s 13ms/step - loss: 6.9601 - val_loss: 229.4581
Epoch 21/40
79/79 [==============================] - 1s 15ms/step - loss: 5.8222 - val_loss: 351.2197
Epoch 22/40
79/79 [==============================] - 1s 13ms/step - loss: 7.2396 - val_loss: 149.9563
Epoch 23/40
79/79 [==============================] - 1s 13ms/step - loss: 7.0510 - val_loss: 273.4801
Epoch 24/40
79/79 [==============================] - 1s 14ms/step - loss: 6.3674 - val_loss: 254.2635
Epoch 25/40
79/79 [==============================] - 1s 13ms/step - loss: 6.6236 - val_loss: 139.3550
Epoch 26/40
79/79 [==============================] - 1s 14ms/step - loss: 5.7532 - val_loss: 296.6612
Epoch 27/40
79/79 [==============================] - 1s 14ms/step - loss: 6.8470 - val_loss: 305.5312
Epoch 28/40
79/79 [==============================] - 1s 13ms/step - loss: 7.1153 - val_loss: 160.2791
Epoch 29/40
79/79 [==============================] - 1s 12ms/step - loss: 5.9563 - val_loss: 235.7691
Epoch 30/40
79/79 [==============================] - 1s 13ms/step - loss: 7.2391 - val_loss: 168.0048
Epoch 31/40
79/79 [==============================] - 1s 13ms/step - loss: 5.8283 - val_loss: 197.9875
Epoch 32/40
79/79 [==============================] - 1s 14ms/step - loss: 5.3628 - val_loss: 279.9405
Epoch 33/40
79/79 [==============================] - 1s 14ms/step - loss: 7.0928 - val_loss: 134.6513
Epoch 34/40
79/79 [==============================] - 1s 14ms/step - loss: 6.9159 - val_loss: 213.4102
Epoch 35/40
79/79 [==============================] - 1s 14ms/step - loss: 5.9036 - val_loss: 190.3418
Epoch 36/40
79/79 [==============================] - 1s 14ms/step - loss: 7.2137 - val_loss: 136.0768
Epoch 37/40
79/79 [==============================] - 1s 14ms/step - loss: 5.9089 - val_loss: 176.8896
Epoch 38/40
79/79 [==============================] - 1s 13ms/step - loss: 6.0480 - val_loss: 153.8418
Epoch 39/40
79/79 [==============================] - 1s 12ms/step - loss: 5.7332 - val_loss: 255.8892
Epoch 40/40
79/79 [==============================] - 1s 14ms/step - loss: 6.2338 - val_loss: 163.2283
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

plt.figure(figsize=(5, 3),dpi=120)

plt.plot(history_lstm.history['loss']    , label='LSTM Training Loss')
plt.plot(history_lstm.history['val_loss'], label='LSTM Validation Loss')

plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

调用模型


predicted_y_lstm = model_lstm.predict(X_test)                        # 测试集输入模型进行预测

y_test_one = [i[0] for i in y_test]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]

plt.figure(figsize=(5, 3),dpi=120)
# 画出真实数据和预测数据的对比曲线
plt.plot(y_test_one[:1000], color='red', label='真实值')
plt.plot(predicted_y_lstm_one[:1000], color='blue', label='预测值')

plt.title('Title')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.show()

在这里插入图片描述

from sklearn import metrics
"""
RMSE :均方根误差  ----->  对均方误差开方
R2   :决定系数,可以简单理解为反映模型拟合优度的重要的统计量
"""
RMSE_lstm  = metrics.mean_squared_error(predicted_y_lstm, y_test)**0.5
R2_lstm    = metrics.r2_score(predicted_y_lstm, y_test)

print('均方根误差: %.5f' % RMSE_lstm)
print('R2: %.5f' % R2_lstm)
均方根误差: 12.77608
R2: 0.70679

个人总结

LSTM 的基本结构

LSTM 的核心思想是通过门控机制来控制信息的流动。一个标准的 LSTM 单元包含以下三个门结构:

  1. 遗忘门(Forget Gate)

    • 决定哪些信息需要从细胞状态中丢弃。
    • 计算公式:KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ f_t = \sigma(W…
    • 其中 KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ \sigma \) 是 Sigmoid 激活函数,KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ W_f \) 是权重矩阵,KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ b_f \) 是偏置项,KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ h_{t-1} \) 是前一时间步的隐藏状态,KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ x_t \) 是当前时间步的输入。
  2. 输入门(Input Gate)

    • 决定哪些新信息需要添加到细胞状态中。
    • 计算公式:KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ i_t = \sigma(W…
    • 以及新候选值:KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ \tilde{C_t} = …
  3. 输出门(Output Gate)

    • 决定当前细胞状态的哪些部分将输出。
    • 计算公式:KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ o_t = \sigma(W…

细胞状态(Cell State)

细胞状态 KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ C_t \)是 LSTM 的核心,它通过门控机制来传递长期依赖信息。

  • 更新细胞状态:

    • KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ C_t = f_t \odo…
    • 其中 KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ \odot \) 表示逐元素乘法。
  • 计算当前时间步的隐藏状态:

    • KaTeX parse error: Can't use function '\(' in math mode at position 1: \̲(̲ h_t = o_t \odo…

在这里插入图片描述
在这里插入图片描述

LSTM 的优点

  1. 处理长期依赖问题:通过门控机制,LSTM 能够有效处理序列数据中的长期依赖问题,避免了传统 RNN 的梯度消失和梯度爆炸问题。

  2. 灵活性:LSTM 可以用于各种序列学习任务,如语言模型、语音识别、时间序列预测等。

  3. 性能:在许多序列学习任务中,LSTM 表现出色,尤其是在需要记忆长期信息的场景中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/953122.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uniapp实现H5页面内容居中与两边留白,打造类似微信公众号阅读体验

在 UniApp 中&#xff0c;由于需要兼容多端应用&#xff0c;我们通常使用 rpx 作为尺寸单位。然而&#xff0c;在某些情况下&#xff0c;如需要实现内容居中且两边留白时&#xff0c;直接使用 rpx 可能会带来一些限制。这时&#xff0c;我们可以考虑使用 px 或 rem 等单位&…

网工_网络体系结构

2024.01.09&#xff1a;网络工程学习笔记&#xff08;网工老姜&#xff09; 第1节 网络体系结构 1.1 计算机一切皆011.2 网络协议1.3 协议的分层模型1.4 主机1向主机2发送数据过程1.5 本章小结 1.1 计算机一切皆01 在计算机内部&#xff0c;所有的数据最终都是以01的方式存在的…

CI/CD 流水线

CI/CD 流水线 CI 与 CD 的边界CI 持续集成CD&#xff08;持续交付/持续部署&#xff09;自动化流程示例&#xff1a; Jenkins 引入到 CI/CD 流程在本地或服务器上安装 Jenkins。配置 Jenkins 环境流程设计CI 阶段&#xff1a;Jenkins 流水线实现CD 阶段&#xff1a;Jenkins 流水…

编程题-二分查找

题目&#xff1a; 给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target &#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回 -1 解法一&#xff08;循环遍历查找&#xff09;&#xff…

OOM排查思路

K8S 容器的云原生生态&#xff0c;改变了服务的交付方式&#xff0c;自愈能力和自动扩缩等功能简直不要太好用。 有好的地方咱要夸&#xff0c;不好的地方咱也要说&#xff0c;真正的业务是部署于容器内部&#xff0c;而容器之外&#xff0c;又有一逻辑层 Pod 。 对于容器和…

Github Copilot学习笔记

&#xff08;一&#xff09;Prompt Engineering 利用AI工具生成prompt设计好的prompt结构使用MarkDown语法&#xff0c;按Role, Skills, Constrains, Background, Requirements和Demo这几个维度描述需求。然后收输入提示词&#xff1a;作为 [Role], 拥有 [Skills], 严格遵守 […

在 Rider 中使用 C# 创建 Windows 窗体应用 Winforms

1&#xff0c;创建项目 new solution 创建一个解决方案 2&#xff0c;打开设计器 在 Form1.cs 上右键打开设计器 认识一下 Rider 的界面 参考微软官方的例子&#xff0c;添加如下属性&#xff1a;注&#xff1a;这里 Listbox 的大小设置成 120, 94 失败&#xff0c;默认的是 12…

R数据分析:多分类问题预测模型的ROC做法及解释

有同学做了个多分类的预测模型,结局有三个类别,做的模型包括多分类逻辑回归、随机森林和决策树,多分类逻辑回归是用ROC曲线并报告AUC作为模型评估的,后面两种模型报告了混淆矩阵,审稿人就提出要统一模型评估指标。那么肯定是统一成ROC了,刚好借这个机会给大家讲讲ROC在多…

#Java-集合进阶-Map

1.Map 声明1 1.1 双列集合的特点 单列集合一次只能添加一个元素&#xff0c;双列集合一次可以添加一对元素 例&#xff1a; 小米手机2000华为手机5000苹果手机9000 这三对元素&#xff0c;左边的我们称之为键&#xff0c;右边的称为值。他们是一一对应的关系 所以双列集合中…

IntelliJ IDEA和MAVEN基本操作:项目和缓存存储到非C盘

为了将 IntelliJ IDEA 的所有项目和缓存存储到 C 盘以外的地方&#xff0c;以下是你需要调整的设置和步骤&#xff1a; 1. 更改项目默认存储位置 打开 IntelliJ IDEA。点击顶部菜单的 File > Settings &#xff08;Windows&#xff09;或 IntelliJ IDEA > Preferences &…

【Linux系列】`find / -name cacert.pem` 文件搜索

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

RabbitMQ基础(简单易懂)

RabbitMQ高级篇请看&#xff1a; RabbitMQ高级篇-CSDN博客 目录 什么是RabbitMQ&#xff1f; MQ 的核心概念 1. RabbitMQ 的核心组件 2. Exchange 的类型 3. 数据流向说明 如何安装RabbitQueue&#xff1f; WorkQueue&#xff08;工作队列&#xff09;&#xff1a; Fa…

《Spring Framework实战》5:Spring Framework 概述

欢迎观看《Spring Framework实战》视频教程 Spring 使创建 Java 企业应用程序变得容易。它为您提供一切 需要在企业环境中采用 Java 语言&#xff0c;并支持 Groovy 和 Kotlin 作为 JVM 上的替代语言&#xff0c;并且可以灵活地创建许多 类型的架构。从 Spring Framework 6.0 开…

有限元分析学习——Anasys Workbanch第一阶段笔记(10)桌子载荷案例分析_实际载荷与均布载荷的对比

目录 0 序言 1 桌子案例 2 模型简化 3 方案A 前处理 1&#xff09;分析类型选择 2&#xff09;材料加载 3&#xff09;约束、载荷及接触 4&#xff09;控制网格(网格大小需要根据结果不断调整) 初始计算结果 加密后计算结果 4 方案B、C 前处理 1&#xff09;分析…

Git 基础——《Pro Git》

⭐获取 Git 仓库 获取 Git 仓库有两种方式&#xff1a; 将未进行版本控制的本地目录转换为 Git 仓库。从其他服务器克隆一个已存在的 Git 仓库。 在已存在目录中初始化 Git 仓库 进入目标目录 在 Linux 上&#xff1a;$ cd /home/user/my_project在 macOS 上&#xff1a;$ c…

Java 将RTF文档转换为Word、PDF、HTML、图片

RTF文档因其跨平台兼容性而广泛使用&#xff0c;但有时在不同的应用场景可能需要特定的文档格式。例如&#xff0c;Word文档适合编辑和协作&#xff0c;PDF文档适合打印和分发&#xff0c;HTML文档适合在线展示&#xff0c;图片格式则适合社交媒体分享。因此我们可能会需要将RT…

R语言在森林生态研究中的魔法:结构、功能与稳定性分析——发现数据背后的生态故事!

森林生态系统结构、功能与稳定性分析与可视化研究具有多方面的重要意义&#xff0c;具体如下&#xff1a; 一、理论意义 ●深化生态学理论 通过研究森林生态系统的结构、功能与稳定性&#xff0c;可以深化对生态系统基本理论的理解。例如&#xff0c;生物多样性与生态系统稳定性…

Delphi+SQL Server实现的(GUI)户籍管理系统

1.项目简介 本项目是一个户籍管理系统&#xff0c;用于记录住户身份信息&#xff0c;提供新户登记&#xff08;增加&#xff09;、户籍变更&#xff08;修改&#xff09;、户籍注销&#xff08;删除&#xff09;、户籍查询、曾用名查询、迁户记录查询以及创建备份、删除备份共8…

第2课 “Hello World” 与 print

1 Hello World 2 print 函数解析 2.1 基本用法 2.2 输出多个对象 2.3 使用sep参数 2.4 使用flush参数 2.5 输出到文件 3 格式化输出 3.1 格式化输出整数 3.2 格式化输出16进制整数 3.3 格式化输出浮点数(float) 3.4 格式化输出字符串(string) 3.5 输出列表与字典 …

计算机网络(四)网络层

4.1、网络层概述 简介 网络层的主要任务是实现网络互连&#xff0c;进而实现数据包在各网络之间的传输 这些异构型网络N1~N7如果只是需要各自内部通信&#xff0c;他们只要实现各自的物理层和数据链路层即可 但是如果要将这些异构型网络互连起来&#xff0c;形成一个更大的互…