本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为AirPassengers
目录
数据集简介
1.步骤一
2.步骤二
3.步骤三
4.步骤四
数据集简介
AirPassengers
数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名。该数据集记录了1949年1月至1960年12月国际航空公司每月的乘客总数(单位:千人),共144个月度观测值,具有明显的上升趋势和年度季节性,因此常被用作时间序列预测的入门示例。
1.训练结果
2.步骤一
安装darts库:
pip install darts
# Adds an attention mechanism at the skip connections.
class UNetAttention1(nn.Module):
    """U-Net style encoder/decoder with optional attention on the skip features.

    The building blocks ``DoubleConv``, ``Down``, ``Up``, ``OutConv`` and
    ``CBAM`` are defined outside this block and are used as-is.

    Args:
        n_channels: Channel count handed to the first ``DoubleConv``.
        n_classes: Channel count handed to the final ``OutConv``.
        bilinear: Forwarded to every ``Up`` stage. When true, the widths of
            the deepest encoder stage and of the first three decoder stages
            are halved (``factor = 2``).
        attention: When true, a ``CBAM`` module is created for each of the
            four encoder feature maps that feed the skip connections, and its
            output is added back onto that feature map.
    """

    # Class-level default: activation checkpointing stays off until
    # use_checkpointing() is called on the instance.
    _checkpointing = False

    def __init__(self, n_channels, n_classes, bilinear=False, attention=False):
        super().__init__()
        self.model_name = 'UNetAttention1'
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.attention = attention

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder path.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        # One attention module per skip feature map (64/128/256/512 channels).
        if self.attention:
            self.attention1 = CBAM(64)
            self.attention2 = CBAM(128)
            self.attention3 = CBAM(256)
            self.attention4 = CBAM(512)

    def _run(self, stage, *inputs):
        """Call ``stage`` on ``inputs``, via activation checkpointing if enabled."""
        if self._checkpointing:
            # Imported here so the block does not rely on a file-level import.
            from torch.utils.checkpoint import checkpoint

            return checkpoint(stage, *inputs, use_reentrant=False)
        return stage(*inputs)

    def forward(self, x):
        """Run the encoder, the optional skip attention, and the decoder.

        Args:
            x: Input tensor.
                # assumes layout (batch, n_channels, H, W) - TODO confirm

        Returns:
            The output of the final ``OutConv`` stage (the logits).
        """
        x1 = self._run(self.inc, x)
        if self.attention:
            # Residual attention: the refined map is added to the original,
            # and the sum is used both downstream and as the skip feature.
            x1 = self.attention1(x1) + x1
        x2 = self._run(self.down1, x1)
        if self.attention:
            x2 = self.attention2(x2) + x2
        x3 = self._run(self.down2, x2)
        if self.attention:
            x3 = self.attention3(x3) + x3
        x4 = self._run(self.down3, x3)
        if self.attention:
            x4 = self.attention4(x4) + x4
        x5 = self._run(self.down4, x4)

        # Each decoder stage receives the running tensor and a skip feature.
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for the encoder/decoder stages.

        ``torch.utils.checkpoint`` is a module, not a callable, so wrapping
        sub-modules with ``torch.utils.checkpoint(...)`` raises ``TypeError``.
        Instead, a flag is set here and ``forward`` routes ``inc``, the
        ``down*`` and ``up*`` stages and ``outc`` through
        ``torch.utils.checkpoint.checkpoint``. The sub-modules themselves are
        left in place, so their names and ``state_dict`` keys do not change.
        """
        self._checkpointing = True
3.步骤二
部分代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
#################### Data preparation ####################
# Load the AirPassengers series. The original dataset was introduced by Box,
# Jenkins and Reinsel in "Time Series Analysis: Forecasting and Control".
series = AirPassengersDataset().load()

# Split into training and validation sets. The split point can be a concrete
# date (as here) or a proportion.
split_point = pd.Timestamp("19590101")
train, val = series.split_after(split_point)

# Normalize the series. The scaler is fitted on the training portion only, so
# nothing from the validation set influences the fit.
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed, series_transformed = (
    transformer.transform(part) for part in (val, series)
)

# Build the year and month covariates over a 1000-step index that starts where
# the target series starts and uses the same frequency.
covariate_index = pd.date_range(
    start=series.start_time(), freq=series.freq_str, periods=1000
)
year_series = Scaler().fit_transform(
    datetime_attribute_timeseries(
        covariate_index, attribute="year", one_hot=False
    )
)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)

# Stack the scaled year component with the one-hot month components, then cut
# the covariates at the same point as the target series.
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(split_point)
#################### Model construction ####################
# Settings for the darts RNNModel, collected in one place so they are easy to
# read and tweak.
lstm_settings = {
    "model": "LSTM",
    "hidden_dim": 20,
    "dropout": 0,
    "batch_size": 16,
    "n_epochs": 300,
    "optimizer_kwargs": {"lr": 1e-3},
    "model_name": "Air_RNN",
    "log_tensorboard": True,
    "random_state": 42,
    "training_length": 20,
    "input_chunk_length": 14,
    "force_reset": True,
    "save_checkpoints": True,
}
my_model = RNNModel(**lstm_settings)

# Fit on the scaled training series. The same covariates series is supplied
# for both the training and the validation inputs.
validation_inputs = {
    "val_series": val_transformed,
    "val_future_covariates": covariates,
}
my_model.fit(
    train_transformed,
    future_covariates=covariates,
    verbose=True,
    **validation_inputs,
)
完整代码下载地址:下载地址