【Python时序预测系列】基于LSTM实现单变量时间序列预测(源码)

一、引言

前文回顾:

【Python时序预测系列】基于Holt-Winters方法实现单变量时间序列预测(源码)

【Python时序预测系列】基于ARIMA法实现单变量时间序列预测(源码)

【Python时序预测系列】基于SARIMA实现单变量时间序列预测(源码)

        LSTM(Long Short-Term Memory,长短期记忆)是一种常用的循环神经网络(Recurrent Neural Network,RNN)架构,用于处理和建模时间序列数据。相比于传统的RNN,LSTM具有更强的记忆能力和长期依赖建模能力,能够有效地处理长序列和解决梯度消失/爆炸的问题。

        LSTM的核心思想是引入了称为"门"的结构,以控制信息的流动和记忆的保留。LSTM单元由一个输入门(input gate)、一个遗忘门(forget gate)和一个输出门(output gate)组成,每个门都有一个可学习的权重参数。这些门决定了输入数据的处理方式,以及过去记忆的保留和遗忘。

        在LSTM中,输入门控制新输入数据的加入,遗忘门控制过去记忆的遗忘,输出门控制输出的生成。通过这些门的调节,LSTM可以选择性地记住或忘记过去的信息,并将当前输入和过去的记忆相结合,产生新的输出。这种机制使得LSTM能够有效地处理长期依赖关系,从而在许多任务中取得了很好的效果,如语言建模、机器翻译、语音识别等。

        本文以"国际航空乘客"数据集为例,使用LSTM进行单变量单步预测。

二、实现过程

导入相关的库

import warnings
warnings.filterwarnings('ignore')
import pandas as pd
from statsmodels.tsa.holtwinters import ExponentialSmoothing
import matplotlib.pyplot as plt

2.1 读取数据集

# 读取数据集
data = pd.read_csv('international-airline-passengers.csv')
# 将日期列转换为日期时间类型
data['Month'] = pd.to_datetime(data['Month'])
# 将日期列设置为索引
data.set_index('Month', inplace=True)

data:

图片

2.2 划分数据集

# 拆分数据集为训练集和测试集
train_size = int(len(data) * 0.8)
train_data = data[:train_size]
test_data = data[train_size:]

# 绘制训练集和测试集的折线图
plt.figure(figsize=(10, 6))
plt.plot(train_data, label='Training Data')
plt.plot(test_data, label='Testing Data')
plt.xlabel('Year')
plt.ylabel('Passenger Count')
plt.title('International Airline Passengers - Training and Testing Data')
plt.legend()
plt.show()

共144条数据,8:2划分:训练集115,测试集29。

训练集和测试集:

图片

2.3 归一化

# 将数据归一化到 0~1 范围
scaler = MinMaxScaler()
train_data_scaler = scaler.fit_transform(train_data.values.reshape(-1, 1))
test_data_scaler = scaler.transform(test_data.values.reshape(-1, 1))

2.4 构造数据集

# 定义滑动窗口函数
def create_sliding_windows(data, window_size):
    X, Y = [], []
    for i in range(len(data) - window_size):
        X.append(data[i:i+window_size])
        Y.append(data[i+window_size])
    return np.array(X), np.array(Y)

# 定义滑动窗口大小
window_size = 12

# 创建滑动窗口数据集
X_train, Y_train = create_sliding_windows(train_data_scaler, window_size)
X_test, Y_test = create_sliding_windows(test_data_scaler, window_size)

# 将数据集转换为 LSTM 模型所需的形状(样本数,时间步长,特征数)
X_train = np.reshape(X_train, (X_train.shape[0], window_size, 1))
X_test = np.reshape(X_test, (X_test.shape[0], window_size, 1)

滑动窗口12

训练集:

【1-12】【13】

【2-13】【14】

...

【102-113】【114】

【103-114】【115】

X_train:(103,12,1)

Y_train:(103,1)

经过滑动窗口构造的数据集,新的训练集数据数量(103)比原始训练集(115)少一个滑动窗口数量(12)

因此,实际训练值只有103条,是训练的13-115的部分。

测试集:

【116-127】【128】

【117-128】【129】

...

【131-142】【143】

【132-143】【144】

X_test:(17,12,1)

Y_test:(17,1)

经过滑动窗口构造的数据集,新的测试集数据数量(17)比原始训测试集(29)少一个滑动窗口数量(12)

因此,实际预测值只有17个,是预测的128-144的部分,如果想预测116-128的部分,可以取训练集的最后12个数进行预测。

2.5 建立模拟合模型进行预测

# 构建 LSTM 模型
model = Sequential()
model.add(LSTM(50, activation='relu', input_shape=(window_size, 1)))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')

# 训练 LSTM 模型
model.fit(X_train, Y_train, epochs=100, batch_size=32)

# 使用 LSTM 模型进行预测
train_predictions = model.predict(X_train)
test_predictions = model.predict(X_test)

# 反归一化预测结果
train_predictions = scaler.inverse_transform(train_predictions)
test_predictions = scaler.inverse_transform(test_predictions)

predictions:

图片

2.6 预测效果展示

# 绘制测试集预测结果的折线图
plt.figure(figsize=(10, 6))
plt.plot(test_data, label='Actual')
plt.plot(list(test_data.index)[-17:], test_predictions, label='Predicted')
plt.xlabel('Month')
plt.ylabel('Passengers')
plt.title('Actual vs Predicted')
plt.legend()
plt.show()

测试集真实值与预测值:

图片

# 绘制原始数据、训练集预测结果和测试集预测结果的折线图
plt.figure(figsize=(10, 6))
plt.plot(data, label='Actual')
plt.plot(list(train_data.index)[window_size:train_size], train_predictions, label='Training Predictions')
plt.plot(list(test_data.index)[-(len(test_data)-window_size):], test_predictions, label='Testing Predictions')
plt.xlabel('Year')
plt.ylabel('Passenger Count')
plt.title('International Airline Passengers - Actual vs Predicted')
plt.legend()
plt.show()

原始数据、训练集预测结果和测试集预测结果:

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/355752.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

超声波清洗机买哪款比较好?四款公认好用超声波清洗机

超声波清洗机好用吗?好多人都说是普通的清洁工具买回家就是浪费钱,真心不建议购买,但其实,手动清洗眼镜的话会比较容易损坏镜片,一副眼镜比较普通的也要上几百了,而且眼镜是我们日常生活中经常会使用的&…

介绍TCP/IP

TCP/IP(传输控制协议/互联网协议)是一种用于数据通信的基本通信协议,它是互联网的基础。TCP/IP指的是一组规则和过程,它规定了如何在网络上发送和接收数据。这个协议族由两个主要部分组成:传输控制协议(TCP…

C#实现多种图片格式转换(例如转换成图标图像ICO)

1,目的: 实现多种图片格式的相互转换,图片大小可自定义等。 2,知识点: 转换成图标图像(ico)时,需要获取图像句柄,然后根据句柄生成Ico图像,否则生成的图像不能作为应用的图标使用。 IntPtr hwd bitmap.GetHicon();…

MongoDB之概述、命令

基础知识 是什么 概念 分布式文件存储数据库,提供高可用、可扩展、易部署的数据存储解决方案。 结构 BSON存储类型 类似JSON的一种二进制存储格式。相比于JSON,提供更丰富的类型支持。 优点是灵活,缺点是空间利用率不佳。 类型说明解释…

python爬虫demo——爬取历史平均房价

简单爬取历史房价 需求 爬取的网站汇聚数据的城市房价 https://fangjia.gotohui.com/ 功能 选择城市 https://fangjia.gotohui.com/fjdata-3 需要爬取年份的数据,等等 https://fangjia.gotohui.com/years/3/2018/ 使用bs4模块 使用bs4模块快速定义需要爬取的…

基于springboot+微信小程序+vue实现的校园二手商城项目源码

介绍 校园二手商城,架构:springboot微信小程序vue 软件架构 软件架构说明 系统截图 技术选型 技术版本说明Spring Boot2.1.6MVC核心框架Spring Security oauth22.1.5认证和授权框架MyBatis3.5.0ORM框架MyBatisPlus3.1.0基于mybatis,使用…

生成对抗网络

目录 1.GAN的网络组成 2.损失函数解释说明 2.1 BCEloss 2.2整体代码 1.GAN的网络组成 2.损失函数解释说明 2.1 BCEloss 损失函数 import torch from torch import autogradinput autograd.Variable(torch.tensor([[1.9072,1.1079,1.4906],[-0.6584,-0.0512,0.7608],[-0.0…

【嵌入式移植】5、U-Boot源码分析2—make nanopi_neo2_defconfig

U-Boot源码分析2—make nanopi_neo2_defconfig 1 概述2 nanopi_neo2_defconfig3 编译过程分析3.1 编译目标3.2 scripts_basic3.2.1 prefix src定义3.2.2 PHONY3.2.3 __build3.2.4 fixdep3.3 objscripts/kconfig 1 概述 上一章中,对Makefile相关源码进行了初步分析&…

Vue-cli脚手架将组件挂载到全局

局部引用组件,直接将组件引入,注册组件即可,这篇文章讲组件挂载到全局的方法! main.js文件 将组件引入main.js文件中,并且注册 使用方法 在需要的地方使用组件即可 BaoGit.Vue代码 <template><div><a href"https://gitee.com/ah-ah-bao"><img …

【机器学习】正则化

正则化是防止模型过拟合的方法&#xff0c;它通过对模型的权重进行约束来控制模型的复杂度。 正则化在损失函数中引入模型复杂度指标&#xff0c;利用给W加权值&#xff0c;弱化了数据的噪声&#xff0c;一般不正则化b。 loss(y^,y)&#xff1a;模型中所有参数的损失函数&…

PID校正

一、Introduction to PID Control PID控制是一种应用非常广泛的控制算法。小到控制一个元件的温度&#xff0c;大到控制无人机的飞行姿态和飞行速度等等&#xff0c;都可以使用PID控制。PID(proportion integration differentiation)其实就是指比例&#xff0c;积分&#xff0…

Python tkinter (10) ——Combobox控件

本文主要是Python tkinter Combobox下拉控件介绍及使用示例。 tkinter系列文章 python tkinter窗口简单实现 Python tkinter (1) —— Label标签 Python tkinter (2) —— Button标签 Python tkinter (3) —— Entry标签 Python tkinter (4) —— Text控件 Python tkinte…

Springboot入门教程详解

Springboot入门教程详解 博客主页&#xff1a;划水的阿瞒的博客主页 欢迎关注&#x1f5b1;点赞&#x1f380;收藏⭐留言✒ 系列专栏&#xff1a;Springboot入门教程详解首发时间&#xff1a;&#x1f39e;2024年1月29日&#x1f3a0; 如果觉得博主的文章还不错的话&#xff0c…

【DeepLearning-10】yolo.py文件关键代码parse_model(d, ch)函数

这段代码功能是根据提供的配置字典&#xff08;d&#xff09;和输入通道列表&#xff08;ch&#xff09;来解析并构建一个YOLOv5模型。函数的核心工作是遍历模型的每一层&#xff0c;并根据配置创建相应的神经网络层。 我们可以在函数中为新增模块配置构造参数设置。 函数中 f…

MyBatis 如何整合 Druid 连接池?

Mybatis 如何整合 Druid 数据连接池呢&#xff1f;首先打开创建的 Maven 工程&#xff0c;找到 pom.xml 文件&#xff0c;添加 Druid 依赖。 <!--druid连接池--> <dependency><groupId>com.alibaba</groupId><artifactId>druid</artifactId&…

MySQL-窗口函数 简单易懂

窗口函数 考查知识点&#xff1a; • 如何用窗口函数解决排名问题、Top N问题、前百分之N问题、累计问题、每组内比较问题、连续问题。 什么是窗口函数 窗口函数也叫作OLAP&#xff08;Online Analytical Processing&#xff0c;联机分析处理&#xff09;函数&#xff0c;可…

HCIP复习课(mpls实验)

1、IP配置&#xff1a; R1&#xff1a; R2&#xff1a; R3&#xff1a; R4&#xff1a; R5&#xff1a; R6&#xff1a; R7&#xff1a; R8&#xff1a; 2、rip&#xff0c;ospf配置&#xff1a; R2&#xff1a; R3&#xff1a; R4&#xff1a; R5&#xff1a; R6&#xff1a…

MySQL知识点总结(一)——一条SQL的执行过程、索引底层数据结构、一级索引和二级索引、索引失效、索引覆盖、索引下推

MySQL知识点总结&#xff08;一&#xff09;——一条SQL的执行过程、索引底层数据结构、一级索引和二级索引、索引失效、索引覆盖、索引下推 一条SQL的执行过程索引底层数据结构为什么不使用二叉树&#xff1f;为什么不使用红黑树?为什么不使用hash表&#xff1f;为什么不使用…

屏蔽系统热键/关机/注入 Winlogon(中)

1 前言 在新的内容开始前&#xff0c;我想整理一些旧文&#xff0c;这一框题展示了在以前的系统上实现在用户关机/重启/注销时弹出对话框的功能。为什么需要先讲这个部分&#xff1f;因为这一部分需要拦截的函数是截至 Win 8 系统&#xff0c;微软所采用的关机/重启等途径上的…

海外推广是企业必须面临和重视的问

随着中国半导体国际化进程的加快&#xff0c;越来越多的企业开始走向海外市场&#xff0c;对于企业出海来说&#xff0c;想要最大限度的提高曝光度&#xff0c;提升企业核心竞争力&#xff0c;做好海外推广是企业必须面临和重视的问题。萨科微(www.slkoric.com)半导体积极布局海…