darts是一个强大而易用的Python时间序列建模工具包。在github上目前拥有超过7k颗stars。
它主要支持以下任务:
时间序列预测 (包含 ARIMA, LightGBM模型, TCN, N-BEATS, TFT, DLinear, TiDE等等)
时序异常检测 (包括 分位数检测 等等)
时间序列滤波 (包括 卡尔曼滤波,高斯过程滤波)
本文演示使用darts构建N-BEATS模型对 牛奶月销量数据进行预测~
公众号算法美食屋后台回复关键词:源码,获取本文notebook源码和数据集~
!pip install darts
一, 准备数据
首先,你需要准备时间序列数据。如果数据有缺失,需要进行数据填充。
这里示范的是一个每月牛奶销量数据集。
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import darts
from darts import TimeSeries
from darts.dataprocessing.transformers import MissingValuesFiller
from darts.dataprocessing.transformers import Scaler
# 1,读取数据集
df = pd.read_csv('month_milk.csv')
df.columns = ['ds','y']
df = df.sort_values(by='ds')
#df['y'] = df['y'].interpolate(method='linear')
# 2,填充数据集
ts_raw = TimeSeries.from_dataframe(df,time_col='ds',value_cols=['y'])
#df = ts_raw.pd_dataframe() #timeseries转成dataframe
fig, ax1 = plt.subplots(1, 1, figsize=(10,5))
ts_raw.plot(color='brown',ax = ax1)
ax1.set_title('before fill')
#from darts.utils.missing_values import fill_missing_values
#ts = fill_missing_values(ts_raw)
filler = MissingValuesFiller()
ts = filler.transform(ts_raw)
fig, ax2 = plt.subplots(1, 1, figsize=(10,5))
ts.plot(color='cyan', ax = ax2)
ax2.set_title('after fill')
# 3,分割数据集
ts_train, ts_test= ts.split_after(pd.Timestamp("1973-01-01"))
fig, ax3 = plt.subplots(1,1,figsize=(10,5))
ts_train.plot(color='blue',label='train',ax=ax3)
ts_test.plot(color='red',label='test',ax=ax3)
ax3.set_title('after split');
# 4,缩放数据集
scaler = Scaler()
ts_train_scaled = scaler.fit_transform(ts_train)
ts_test_scaled = scaler.transform(ts_test)
ts_scaled = scaler.transform(ts)
二, 定义模型
接下来,定义一个时间序列模型。Darts支持多种模型。
包括统计模型(ARIMA, FFT, ExponentialSmoothing等)
机器学习模型(Prophet,LightGBM等)
神经网络模型(RNN类型,Transformer类型,CNN类型,MLP类型)。
此处我们使用 NBeats模型。
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
import wandb
wandb.login()
from darts.models import NBEATSModel
import pytorch_lightning as pl
from darts.utils.callbacks import TFMProgressBar
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project='MILK')
progress_bar = TFMProgressBar(enable_train_bar_only=True)
early_stopper = pl.callbacks.EarlyStopping(monitor="val_loss",patience=10,
min_delta=1e-5,mode="min")
pl_trainer_kwargs = dict(max_epochs=100,
accelerator='cpu',
callbacks = [progress_bar,early_stopper],
logger = wandb_logger
)
settings = dict(
input_chunk_length=10,
output_chunk_length=3,
generic_architecture=True,
num_stacks=10,
num_blocks=1,
num_layers=4,
layer_widths=512,
save_checkpoints=True,
batch_size=800,
random_state=42,
model_name='nbeats',
force_reset=True
)
settings['pl_trainer_kwargs'] = pl_trainer_kwargs
model = NBEATSModel(**settings)
三,训练模型
model.fit(series = ts_train_scaled, val_series= ts_train_scaled)
model = model.load_from_checkpoint(model_name=model.model_name,best=True)
四,评估模型
# 历史数据逐段回测,使用真实历史数据作为特征,不做滚动预测
ts_test_forecast = model.historical_forecasts(
series=ts_scaled,
start=ts_test_scaled.start_time(),
forecast_horizon=3,
stride=3,
last_points_only=False,
retrain=False,
verbose=True,
)
from darts import concatenate
ts_test_forecast = scaler.inverse_transform(concatenate(ts_test_forecast))
ts_test_true = ts[ts_test_forecast.time_index]
test_score = r2_score(ts_test_true, ts_test_preds)
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 5))
ts_test_true.plot(label="y",color='blue')
ts_test_forecast.plot(label='yhat-forecast',color='cyan')
plt.title(
"yhat-forecast R2-score: {}".format(test_score)
)
plt.legend()
五,使用模型
#滚动预测,使用预测的数据作为后面预测步骤的特征
# (注意:当预测步数 n 小于等于模型的output_chunk_length,无需滚动)
ts_preds = model.predict(n = len(ts_test_scaled),series = ts_train_scaled, num_samples=1)
ts_test_preds = scaler.inverse_transform(ts_preds)
from darts.metrics import r2_score
test_score = r2_score(ts_test_true, ts_test_preds)
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 5))
ts_test_true.plot(label="y",color='blue')
ts_test_forecast.plot(label="yhat-forecast",color='cyan')
ts_test_preds.plot(label='yhat',color='red')
plt.title(
"yhat R2-score: {}".format(test_score)
)
plt.legend()
六,保存模型
model.save('nbeats')
model_loaded = NBEATSModel.load('nbeats') #重新加载