LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

pip install darts

#在连接处添加注意力机制
class UNetAttention1(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False, attention=False):
        super(UNetAttention1, self).__init__()

        self.model_name = 'UNetAttention1'
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.attention = attention

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

        if self.attention:
            self.attention1 = CBAM(64)
            self.attention2 = CBAM(128)
            self.attention3 = CBAM(256)
            self.attention4 = CBAM(512)



    def forward(self, x):
        x1 = self.inc(x)
        if self.attention:
            x1 = self.attention1(x1) + x1

        x2 = self.down1(x1)
        if self.attention:
            x2 = self.attention2(x2) + x2

        x3 = self.down2(x2)
        if self.attention:
            x3 = self.attention3(x3) + x3

        x4 = self.down3(x3)
        if self.attention:
            x4 = self.attention4(x4) + x4

        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

3.步骤二

部分代码如下:


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings

warnings.filterwarnings("ignore")
import logging

logging.disable(logging.CRITICAL)

####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# create month and year covariate series
year_series = datetime_attribute_timeseries(
    pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),
    attribute="year",
    one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))

####################构建模型##########################
my_model = RNNModel(
    model="LSTM",
    hidden_dim=20,
    dropout=0,
    batch_size=16,
    n_epochs=300,
    optimizer_kwargs={"lr": 1e-3},
    model_name="Air_RNN",
    log_tensorboard=True,
    random_state=42,
    training_length=20,
    input_chunk_length=14,
    force_reset=True,
    save_checkpoints=True,
)


my_model.fit(
    train_transformed,
    future_covariates=covariates,
    val_series=val_transformed,
    val_future_covariates=covariates,
    verbose=True,
)

完整代码下载地址:下载地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/888170.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

增强分析:新时代的数据洞察工具

随着数据科学和人工智能的迅猛发展,分析数据的方式也发生了显著的变化。增强分析(Augmented Analytics)是近年来涌现出的新概念,它将人工智能(AI)、机器学习(ML)和自然语言处理&…

HarmonyOS NEXT - 表单录入组件封装(TextInput)

demo 地址: https://github.com/iotjin/JhHarmonyDemo 组件对应代码实现地址 代码不定时更新,请前往github查看最新代码 HarmonyOS NEXT - 表单录入组件封装(TextInput) 序JhFormInputCellJhFormSelectCellJhLoginTextField 序 鸿蒙next中有两…

java 的三种IO模型(BIO、NIO、AIO)

java 的三种IO模型(BIO、NIO、AIO) 一、BIO 阻塞式 IO(Blocking IO)1.1、BIO 工作机制1.2、BIO 实现单发单收1.3、BIO 实现多发多收1.4、BIO 实现客户端服务端多对一1.5、BIO 模式下的端口转发思想 二、NIO 同步非阻塞式 IO&#…

Android15车载音频之Virtualbox中QACT实时调试(八十八)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【原创干货持续更新中……】🚀 优质视频课程:AAOS车载系统+…

Pikachu- Over Permission-垂直越权

以admin 账号登陆,添加一个用户; 把添加用户的这个请求发送到 repeater; 退出admin,使用普通用户pikachu登陆; 只有查看权限; 使用pikachu 用户的认证信息,替换repeater处管理员创建用户请求的…

六、索引的数据结构

文章目录 1. 为什么使用索引2. 索引及其优缺点2.1 索引概述2.2 优点2.3 缺点3. InnoDB中索引的推演3.1 索引之前的查找3.1.1 在一个页中的查找3.1.2 在很多页中查找3.2 设计索引3.2.1 一个简单的索引设计方案3.2.2 InnoDB中的索引方案3.3 常见索引概念3.3.1 聚簇索引3.3.2 二级…

UDP协议【网络】

文章目录 UDP协议格式 UDP协议格式 16位源端口号:表示数据从哪里来。16位目的端口号:表示数据要到哪里去。16位UDP长度:表示整个数据报(UDP首部UDP数据)的长度。16位UDP检验和:如果UDP报文的检验和出错&…

centos一些常用命令

文章目录 查看磁盘信息使用 df 命令使用 du 命令 查看磁盘信息 使用 df 命令 df(disk free)命令用于显示文件系统的磁盘空间占用情况。 查看所有挂载点的磁盘使用情况: df -h选项说明: -h 参数表示以人类可读的格式&#xff0…

Windows下Jenkins控制台中文乱码

问题描述 问题情况如下图: 环境信息 Windows 11 家庭中文版java 21.0.4 2024-07-16 LTSJenkins 2.452.3 解决方法 增加系统JAVA_TOOL_OPTIONS,并设置值为-Dfile.encodingGBK。 打开设置方法:桌面上右键点击“此电脑”图标,选…

【黑马点评】使用RabbitMQ实现消息队列——3.使用Jmeter压力测试,导入批量token,测试异步秒杀下单

3 批量获取用户token,使用jmeter压力测试 3 批量获取用户token,使用jmeter压力测试3.1 需求3.2 实现3.2.1 环境配置3.2.2 修改登录接口UserController和实现类3.2.3 测试类 3.3 使用jmeter进行测试3.4 测试结果3.5 将用户登录逻辑修改回去 3 批量获取用户…

力扣16~20题

题16&#xff08;中等&#xff09;&#xff1a; 思路&#xff1a; 双指针法&#xff0c;和15题差不多&#xff0c;就是要排除了&#xff0c;如果total<target则排除了更小的&#xff08;left右移&#xff09;&#xff0c;如果total>target则排除了更大的&#xff08;rig…

三绕组单相电容电动机的瞬态分析(2)定子三角形接法时起动过程及稳态性能分析

1. 引言 2. 定子接三绕组单相电容电动机的数学模型 3.最佳移相电容计算 4. 仿真分析实例 5. 总结 6. 参考文献 1. 引言 目前,三相供电系统在全世界范围内已是非常普遍了。但是,由于架设输电线成本高,受输、配电系统的限制,城市居民用电和大多数农村及边远地区的用电仍…

JavaWeb——Vue路由(概述、介绍、使用、解决bug)

目录 概述 介绍 使用 解决bug 概述 员工管理的页面已经制作完成。其他页面的制作方式一致。 项目中准备了部门管理的页面组件 DeptView &#xff0c;这样就有了员工管理和部门管理两个页面组件。 在系统中&#xff0c;要实现点击部门管理菜单展示部门管理组件&#xff0c…

线性代数入门指南

在数学的广袤领域中&#xff0c;线性代数犹如一座神秘而又充满魅力的殿堂&#xff0c;等待着初学者去探索。当你踏入线性代数的大门&#xff0c;便开启了一段充满挑战与惊喜的知识之旅。 线性代数是什么呢&#xff1f;简单来说&#xff0c;它是一门研究线性方程组、向量空间、线…

Android Automotive(一)

目录 什么是Android Automotive Android Automotive & Android Android Automotive 与 Android Auto 什么是Android Automotive Android Automotive 是一个基础的 Android 平台&#xff0c;它能够运行预装的车载信息娱乐系统&#xff08;IVI&#xff09;应用程序&#…

分治算法(3)_快速选择_数组中的第K个最大元素

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 分治算法(3)_快速排序_数组中的第K个最大元素 收录于专栏【经典算法练习】 本专栏旨在分享学习算法的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&#…

51单片机的自动制冷系统【proteus仿真+程序+报告+原理图+演示视频】

1、主要功能 该系统由AT89C51/STC89C52单片机LCD1602显示模块温度传感器继电器LED、按键和蜂鸣器等模块构成。适用于车载便携式自动制冷系统、冰箱制冷、温度控制等相似项目。 可实现功能: 1、LCD1602实时显示当前温度 2、温度传感器DS18B20采集温度 3、按键可设置温度的阈…

双向数据库迁移工具:轻松实现 MySQL 与 SQLite 数据互导

项目概述与作用 该项目的核心是实现 MySQL 和 SQLite 两种数据库之间的数据迁移工具。它能够轻松地将 MySQL 数据库中的数据导出为 SQLite 数据库文件&#xff0c;反过来也可以将 SQLite 数据库中的数据上传到 MySQL 数据库中。这个双向迁移工具非常适用于&#xff1a; 数据库备…

日期类的实现(C++)

个人主页&#xff1a;Jason_from_China-CSDN博客 所属栏目&#xff1a;C系统性学习_Jason_from_China的博客-CSDN博客 所属栏目&#xff1a;C知识点的补充_Jason_from_China的博客-CSDN博客 前言 日期类是六个成员函数学习的总结和拓展&#xff0c;是实践的体现 创建文件 构造函…

[C#]使用onnxruntime部署yolov11-onnx实例分割模型

【官方框架地址】 https://github.com/ultralytics/ultralytics.git 【算法介绍】 在C#中使用ONNX Runtime部署YOLOv11-ONNX实例分割模型&#xff0c;涉及到模型的加载、数据预处理、模型推理和后处理几个关键步骤。 首先&#xff0c;需要确保已经安装了ONNX Runtime的NuGe…