【Week-R2】使用LSTM实现火灾预测(tf版本)

【Week-R2】使用LSTM实现火灾预测(tf版本)

  • 一、 前期准备
    • 1.1 设置GPU
    • 1.2 导入数据
    • 1.3 数据可视化
  • 二、数据预处理(构建数据集)
    • 2.1 设置x、y
    • 2.2 归一化
    • 2.3 划分数据集
  • 三、模型创建、编译、训练、得到训练结果
    • 3.1 构建模型
    • 3.2 编译模型
    • 3.3 训练模型
    • 3.4 模型评估
      • 3.4.1 Loss与Accuracy图
      • 3.4.2 调用模型进行预测
      • 3.4.3 查看误差
  • 四、其他
    • 4.1 模块报错:seaborn模块导入错误
    • 4.2 图片实时显示比例不对
    • 4.3 什么是LSTM

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

在这里插入图片描述

一、 前期准备

语言环境:Python3.7.8
编译器选择:VSCode
深度学习环境:TensorFlow
数据集:本地数据集

1.1 设置GPU

本文使用CPU环境

'''
LSTM-实现火灾预测
'''

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0,true)
    tf.config.set_visible_devices([gpu0],"GPU")
    print("GPU: ",gpus)
else:
    print("CPU:")

输出:
在这里插入图片描述

1.2 导入数据

下载数据集文件woodpine2.csv到本地,使用绝对路径进行访问:

# 2.1 导入数据
import  pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_csv("D:\\jupyter notebook\\DL-100-days\\RNN\\woodpine2.csv")
print("df:", df)

输出:
在这里插入图片描述

1.3 数据可视化

# 3.数据可视化 
plt.rcParams['savefig.dpi'] = 500
plt.rcParams['figure.dpi'] = 500
 
fig,ax = plt.subplots(1,3,constrained_layout = True , figsize = (14,3))
sns.lineplot(data=df["Tem1"],ax=ax[0])
sns.lineplot(data=df["CO 1"],ax=ax[1])
sns.lineplot(data=df["Soot 1"],ax=ax[2])
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\3.数据可视化.png")
#plt.show()

输出:
在这里插入图片描述

二、数据预处理(构建数据集)

# 二、数据预处理(构建数据集)
dataFrame = df.iloc[:,1:]
print("dataFrame:", dataFrame)

输出:
在这里插入图片描述

2.1 设置x、y

# 1,设置x、y
# 需要实现:使用1-8时刻段预测9时刻段,则通过下述代码做好长度的确定:
width_X = 8
width_y = 1

X = []
y = []
 
in_start = 0
 
for _,_ in df.iterrows():
    in_end = in_start + width_X
    out_end = in_end + width_y
 
    if out_end < len(dataFrame):
        X_ = np.array(dataFrame.iloc[in_start:in_end,])
        X_ = X_.reshape((len(X_)*3))
        y_ = np.array(dataFrame.iloc[in_end:out_end,0])
 
        X.append(X_)
        y.append(y_)
 
    in_start += 1
 
X = np.array(X)
y = np.array(y)
 
print(X.shape,y.shape)

输出:
在这里插入图片描述

2.2 归一化

# 2,归一化
from sklearn.preprocessing import MinMaxScaler
 
sc = MinMaxScaler(feature_range=(0,1))
X_scaled = sc.fit_transform(X)
print(X_scaled.shape)

X_scaled = X_scaled.reshape(len(X_scaled),width_X,3)
print(X_scaled.shape)

输出:
在这里插入图片描述

2.3 划分数据集

取5000之前的数据作为训练集,5000之后的数据作为验证集:

# 3,划分数据集
# 取5000之前的数据作为训练集,5000之后的数据作为验证集:
X_train = np.array(X_scaled[:5000]).astype('float64')
y_train = np.array(y[:5000]).astype('float64')
 
X_test = np.array(X_scaled[5000:]).astype('float64')
y_test = np.array(y[5000:]).astype('float64')
 
print(X_train.shape)

输出:
在这里插入图片描述

三、模型创建、编译、训练、得到训练结果

3.1 构建模型

通过下面代码,构建一个包含两个LSTM层和一个全连接层的LSTM模型。这个模型将接受形状为(X_train.shape[1], 3)的输入,其中X_train.shape[1]是时间步数,3 是每个时间步的特征数。

# 三、构建模型
import keras
from keras.models import Sequential
from keras.layers import Dense,LSTM
 
model_lstm = Sequential()
model_lstm.add(LSTM(units=64,activation='relu',return_sequences=True,input_shape=(X_train.shape[1],3)))
model_lstm.add(LSTM(units=64,activation='relu'))
model_lstm.add(Dense(width_y))
# 通过上述代码,构建了一个包含两个LSTM层和一个全连接层的LSTM模型。这个模型将接受形状为 (X_train.shape[1], 3) 的输入,其中 X_train.shape[1] 是时间步数,3 是每个时间步的特征数。

输出:
在这里插入图片描述

3.2 编译模型

# 四、 编译模型 
model_lstm.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.Adam(1e-3))

3.3 训练模型

history = model_lstm.fit(X_train,y_train,
                    epochs = 40,
                    batch_size = 64,
                    validation_data=(X_test,y_test),
                    validation_freq= 1)

训练输出:
在这里插入图片描述

3.4 模型评估

3.4.1 Loss与Accuracy图

# 六、 模型评估
# 1.Loss与Accuracy图
import matplotlib.pyplot as plt
 
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
 
plt.figure(figsize=(5, 3),dpi=120)
 
plt.plot(history.history['loss'],label = 'LSTM Training Loss')
plt.plot(history.history['val_loss'],label = 'LSTM Validation Loss')
 
plt.title('Training and Validation Accuracy')
plt.legend()
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\1.Loss与Accuracy图.png")
plt.show()

输出:
在这里插入图片描述

3.4.2 调用模型进行预测

# 2.调用模型进行预测
predicted_y_lstm = model_lstm.predict(X_test)
 
y_tset_one = [i[0] for i in y_test]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]
 
plt.figure(figsize=(5,3),dpi=120)
plt.plot(y_tset_one[:1000],color = 'red', label = '真实值')
plt.plot(predicted_y_lstm_one[:1000],color = 'blue', label = '预测值')
 
plt.title('Title')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\2.调用模型进行预测图.png")
plt.show()

输出:
在这里插入图片描述

3.4.3 查看误差


# 3. 查看误差
from  sklearn import metrics
 
RMSE_lstm = metrics.mean_squared_error(predicted_y_lstm,y_test)**0.5
R2_lstm = metrics.r2_score(predicted_y_lstm,y_test)
 
print('均方根误差:%.5f' % RMSE_lstm)
print('R2:%.5f' % R2_lstm)

输出:
在这里插入图片描述

四、其他

4.1 模块报错:seaborn模块导入错误

解决方法如下:
在这里插入图片描述
在这里插入图片描述

4.2 图片实时显示比例不对

改为保存到本地查看:【在plt.show()之前保存
line 34:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\3.数据可视化.png")
plt.show()

line 125:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\1.Loss与Accuracy图.png")
plt.show()

line 143:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\2.调用模型进行预测图.png")
plt.show()

输出:
在这里插入图片描述

4.3 什么是LSTM

LSTM是一种特殊的RNN,能到学习到长期的依赖关系,可以理解为升级版的RNN。
在这里插入图片描述
传统的RNN在处理长序列时存在着“梯度爆炸(/梯度消失)”和“短时记忆”的问题,向RNN中加入遗忘门、输入门及输出门使得困扰RNN的问题得到了一定的解决;
在这里插入图片描述
关于LSTM的实现流程:(1、单输出时间步)单输入单输出、多输入单输出、多输入多输出(2、多输出时间步)单输入单输出、多输入单输出、多输入多输出;
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/688261.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

移动安全赋能化工能源行业智慧转型

随着我国能源化工企业的不断发展&#xff0c;化工厂中经常存在火灾爆炸的危险&#xff0c;特别是生产场所&#xff0c;约有80%以上生产场所区域存在爆炸性物质。而目前我国化工危险场所移动通信设备的普及率高&#xff0c;但是对移动通信设备的安全防护却有所忽视&#xff0c;包…

国自然基金的检索

&#xff08;1&#xff09;网址 跳转国自然基金网址&#xff1a;https://www.nsfc.gov.cn/ &#xff08;2&#xff09;查询入口 &#xff08;3&#xff09;进行查询

antdv 穿梭框

antd的穿梭框的数据貌似只接收key和title&#xff0c;而且必须是字符串&#xff08;我测试不是字符串的不行&#xff09;&#xff0c; 所以要把后端返回的数据再处理一下得到我们想要的数据 除了实现简单的穿梭框功能&#xff0c;还想要重写搜索事件&#xff0c;想达到的效果是…

【96】write combine机制介绍

前言 这篇文章主要介绍了write combine的机制 一、write combine的试验 1.系统配置 &#xff08;1&#xff09;、CPU&#xff1a;11th Gen Intel(R) Core(TM) i7-11700 2.50GHz &#xff08;2&#xff09;、GPU&#xff1a;XX &#xff08;3&#xff09;、link status&am…

Towards Graph Contrastive Learning: A Survey and Beyond

目录 Towards Graph Contrastive Learning- A Survey and Beyond摘要IntroductionPRELIMINARY符号说明GNN对比学习下游任务 GCL自监督学习增强策略基于规则随机扰动或mask子图采样图扩散 基于学习图结构学习图对抗训练图合理化 对比模式同尺度对比全局上下文局部 跨尺度对比局部…

构建体育直播平台源码:深度解析数据分析模块的核心展示内容

在现代的体育直播平台中&#xff0c;数据分析展示已经成为不可或缺的一部分。如下参考借助“东莞梦幻网络科技”提供的体育直播源码&#xff0c;打造的平台&#xff0c;并通过表格形式为用户列出以下数据分析内容&#xff1a; 1、积分排名&#xff1a;反映了各支队伍在赛季中的…

HIK录像机GB28181对接相机不在线问题随笔

一、问题现象 【设备信息】型号&#xff1a;DS-8664N-I16-V3 V4.63.000 build 230412 【问题现象】HIK录像机使用GB28181对接异常相机无法正常上线&#xff0c;对接HIK相机可以正常上线。 【现场拓扑】现场拓扑如下 NVR侧使用固定公网IP地址。IPC侧使用家用宽带的方式&…

玩转ChatGPT:最全学术论文提示词分享【上】

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 在当今数字时代&#xff0c;人工智能&#xff08;AI&#xff09;技术正迅速改变各行各业的运作方式。特别是&#xff0c;OpenAI的ChatGPT等语言模型以其强大的文本生成能力&#xff0c;…

从零开始实现自己的串口调试助手(8)-循环发送

循环发送 准备 创建槽函数 设置QSpinBox的最大值 注意&#xff1a; // 我们不能在qt的ui线程中延时&#xff0c;否则将导致页面刷新问题 //QThread::msleep(ui->spinBox->text().toInt());//设置下次发送时间间隔 定时器实现 关联信号与槽: //添加自动换行定…

C++学习/复习13--list概述

一、list概念 1.带头双向链表 2.构造函数 3.迭代器&#xff08;其迭代器需尤其注意&#xff09; 4、size 5.front/back 6.插入删除 删除时的迭代器失效 由于list的节点特殊&#xff0c;既有数据又有指针&#xff0c;其实现需要节点/迭代器/list各成一类再组合

UI框架与MVC模式详解(1)——逻辑与数据分离

【效率最高的耦合方式】 以实际的例子来说明&#xff0c;更容易理解些。 这里从上到下&#xff0c;从左到右共有8个显示项&#xff0c;如果只需要显示这8个&#xff0c;不会做任何改变&#xff0c;数据固定&#xff0c;那么我们只需要最常规的思路去写就好&#xff0c;这是最…

【小海实习日记】Git使用规范

1.Git使用流程 1.1 从master分支拉一个分支&#xff0c;命名要符合规范且清晰。 1.2 commit到本地&#xff0c;push 到远端。 1.3 在Gitlab创建MR&#xff0c;选择develp分支。 1.4 如果要修改的话&#xff0c;先把Gitlab上的MR修改为Draft(修改态)&#xff0c;然后在本地修改代…

SwiftUI获取用户的位置信息(CLLocationManager,CLLocationManagerDelegate)

本篇文章将会介绍一下在SwiftUI中如何通过CorLocation框架获取用户的位置信息&#xff0c;因为获取位置信息属于用户的隐私信息&#xff0c;所以需要在Info.plist文件里面加上访问位置权限的说明。 关于位置信息&#xff0c;可以请求两种级别的许可&#xff1a;always和when i…

图的创建和BFS,DFS遍历

图的创建 图是一种用于表示对象及其关系的抽象数据结构&#xff0c;由节点&#xff08;也称为顶点&#xff09;和连接节点的边组成。图可以分为有向图&#xff08;Directed Graph&#xff09;和无向图&#xff08;Undirected Graph&#xff09;&#xff0c;以及加权图&…

Redis的事务与关系型数据库事务有何不同?

引言&#xff1a;关于 Redis 的事务很多人可能都是一知半解&#xff0c;大多数人只了解数据库的事务&#xff0c;并且是单体事务&#xff0c;对于 Redis 事务和常见关系型数据库的事务的区别还没有去了解过&#xff0c;本文就来详细进行介绍。 题目 Redis的事务与关系型数据库…

git【工具软件】分布式版本控制工具软件

一、Git 的介绍 git软件的作用&#xff1a;管理软件开发项目中的源代码文件。 常用功能&#xff1a; 仓库管理、文件管理、分支管理、标签管理、远程操作 功能指令&#xff1a; add&#xff0c;commit&#xff0c;log&#xff0c;branch&#xff0c;tag&#xff0c;remote…

UE5 Mod Support 思路——纯蓝图

原创作者&#xff1a;Chatouille 核心功能 “Get Blueprint Assets”节点&#xff0c;用于加载未来的mod。用基础类BP_Base扩展即可。打包成补丁&#xff0c;放到Content\Paks目录下&#xff0c;即可让游戏访问到内容。 与文中所写不同的地方 5.1或者5.2开始&#xff0c;打…

mysql optimizer_switch : 查询优化器优化策略深入解析

码到三十五 &#xff1a; 个人主页 在 MySQL 数据库中&#xff0c;查询优化器是一个至关重要的组件&#xff0c;它负责确定执行 SQL 查询的最有效方法。为了提供DBA和开发者更多的灵活性和控制权&#xff0c;MySQL 引入了 optimizer_switch 系统变量。这个强大的工具允许用户开…

自动驾驶仿真(高速道路)LaneKeeping

前言 A high-level decision agent trained by deep reinforcement learning (DRL) performs quantitative interpretation of behavioral planning performed in an autonomous driving (AD) highway simulation. The framework relies on the calculation of SHAP values an…

Go微服务: 基于使用场景理解分布式之二阶段提交

概述 二阶段提交&#xff08;Two-Phase Commit&#xff0c;2PC&#xff09;是一种分布式事务协议&#xff0c;用于在分布式系统中确保多个参与者的操作具有原子性即所有参与者要么全部提交事务&#xff0c;要么全部回滚事务&#xff0c;以维持数据的一致性它分为两个阶段进行&…