darts 时序预测入门

darts是一个强大而易用的Python时间序列建模工具包。在github上目前拥有超过7k颗stars。

它主要支持以下任务:

  • 时间序列预测 (包含 ARIMA, LightGBM模型, TCN, N-BEATS, TFT, DLinear, TiDE等等)

  • 时序异常检测 (包括 分位数检测 等等)

  • 时间序列滤波 (包括 卡尔曼滤波,高斯过程滤波)

本文演示使用darts构建N-BEATS模型对 牛奶月销量数据进行预测~

公众号算法美食屋后台回复关键词:源码,获取本文notebook源码和数据集~

!pip install darts

一,  准备数据

首先,你需要准备时间序列数据。如果数据有缺失,需要进行数据填充。

这里示范的是一个每月牛奶销量数据集。

 
 
import numpy as np 
import pandas as pd 
from matplotlib import pyplot as plt


import darts
from darts import TimeSeries
from darts.dataprocessing.transformers import MissingValuesFiller
from darts.dataprocessing.transformers import Scaler 




# 1,读取数据集
df = pd.read_csv('month_milk.csv')
df.columns = ['ds','y']
df = df.sort_values(by='ds')
#df['y'] = df['y'].interpolate(method='linear')




# 2,填充数据集
ts_raw = TimeSeries.from_dataframe(df,time_col='ds',value_cols=['y'])
#df = ts_raw.pd_dataframe() #timeseries转成dataframe
fig, ax1 = plt.subplots(1, 1, figsize=(10,5))
ts_raw.plot(color='brown',ax = ax1)
ax1.set_title('before fill')


#from darts.utils.missing_values import fill_missing_values
#ts = fill_missing_values(ts_raw)


filler = MissingValuesFiller()
ts = filler.transform(ts_raw)


fig, ax2 = plt.subplots(1, 1, figsize=(10,5))
ts.plot(color='cyan', ax = ax2)
ax2.set_title('after fill')


# 3,分割数据集
ts_train, ts_test= ts.split_after(pd.Timestamp("1973-01-01")) 
fig, ax3 = plt.subplots(1,1,figsize=(10,5))
ts_train.plot(color='blue',label='train',ax=ax3)
ts_test.plot(color='red',label='test',ax=ax3)
ax3.set_title('after split');




# 4,缩放数据集
scaler = Scaler()
ts_train_scaled = scaler.fit_transform(ts_train)
ts_test_scaled = scaler.transform(ts_test)
ts_scaled = scaler.transform(ts)

2724e31a0807564598b405d1427fbf05.png

0a0854fc32b002bd5207824d4847c037.png

f98d248ece22c91ccd55587a1c1c9941.png

二, 定义模型

接下来,定义一个时间序列模型。Darts支持多种模型。

包括统计模型(ARIMA, FFT, ExponentialSmoothing等)

机器学习模型(Prophet,LightGBM等)

神经网络模型(RNN类型,Transformer类型,CNN类型,MLP类型)。

此处我们使用 NBeats模型。

 
 
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
import wandb
wandb.login()
 
 
from darts.models import NBEATSModel 
import pytorch_lightning as pl 
from darts.utils.callbacks import TFMProgressBar
from pytorch_lightning.loggers import WandbLogger


wandb_logger = WandbLogger(project='MILK')
progress_bar = TFMProgressBar(enable_train_bar_only=True)
early_stopper = pl.callbacks.EarlyStopping(monitor="val_loss",patience=10,
        min_delta=1e-5,mode="min")
pl_trainer_kwargs = dict(max_epochs=100, 
     accelerator='cpu',
     callbacks = [progress_bar,early_stopper],
     logger = wandb_logger
) 


settings = dict(
    input_chunk_length=10,
    output_chunk_length=3,
    generic_architecture=True,
    num_stacks=10,
    num_blocks=1,
    num_layers=4,
    layer_widths=512,
    save_checkpoints=True,
    
    batch_size=800,
    random_state=42,
    model_name='nbeats',
    force_reset=True
)


settings['pl_trainer_kwargs'] = pl_trainer_kwargs


model = NBEATSModel(**settings)

三,训练模型

model.fit(series = ts_train_scaled, val_series= ts_train_scaled)
model = model.load_from_checkpoint(model_name=model.model_name,best=True)

四,评估模型

 
 
# 历史数据逐段回测,使用真实历史数据作为特征,不做滚动预测
ts_test_forecast = model.historical_forecasts(
    series=ts_scaled,
    start=ts_test_scaled.start_time(),
    forecast_horizon=3,
    stride=3,
    last_points_only=False,
    retrain=False,
    verbose=True,
)
 
 
from darts import concatenate
ts_test_forecast  = scaler.inverse_transform(concatenate(ts_test_forecast))
ts_test_true = ts[ts_test_forecast.time_index]
test_score = r2_score(ts_test_true, ts_test_preds)
 
 
import matplotlib.pyplot  as plt 
plt.figure(figsize=(8, 5))
ts_test_true.plot(label="y",color='blue')
ts_test_forecast.plot(label='yhat-forecast',color='cyan')
plt.title(
    "yhat-forecast R2-score: {}".format(test_score)
)
plt.legend()

9a18ae60157acae54c79a44dda143c4a.png

五,使用模型

#滚动预测,使用预测的数据作为后面预测步骤的特征 
# (注意:当预测步数 n 小于等于模型的output_chunk_length,无需滚动)
ts_preds = model.predict(n = len(ts_test_scaled),series = ts_train_scaled, num_samples=1)
ts_test_preds  = scaler.inverse_transform(ts_preds)
 
 
from darts.metrics import r2_score 
test_score = r2_score(ts_test_true, ts_test_preds)
 
 
import matplotlib.pyplot  as plt 
plt.figure(figsize=(8, 5))
ts_test_true.plot(label="y",color='blue')
ts_test_forecast.plot(label="yhat-forecast",color='cyan')
ts_test_preds.plot(label='yhat',color='red')
plt.title(
    "yhat R2-score: {}".format(test_score)
)
plt.legend()

0dd218379c73417198932b8ab071c98c.png

六,保存模型

 
 
model.save('nbeats')
model_loaded = NBEATSModel.load('nbeats') #重新加载

4872ae24e820f45e89741844e10f1021.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/695544.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【CS.OS】操作系统如何使用分页和分段技术管理内存

1000.5.CS.OS.1.3-基础-内存管理-操作系统如何使用分页和分段技术管理内存-Created: 2024-06-09.Sunday10:24 操作系统的内存管理是一个复杂而关键的功能,它确保了程序可以高效、安全地运行。虚拟内存管理是其中一个重要的概念,它通过分页和分段技术来实…

2024-6-9

今日安排: 学校的课程作业windows SEH 机制简单入门windows 用户态 pwn / 内核态入门 计网实验报告 && 网安实验报告继续审计 nf_tables 源码,主要看 active 相关逻辑。复现 CVE-2022-32250 这个漏洞【 && iptables 相关学习】♥♥♥♥…

文章解读与仿真程序复现思路——电力自动化设备EI\CSCD\北大核心《计及电力不平衡风险的配电网分区协同规划》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

缓存更新策略中级总结

背景 看到好些人在写更新缓存数据代码时,先删除缓存,然后再更新数据库,而后续的操作会把数据再装载的缓存中。然而,这个是逻辑是错误的。试想,两个并发操作,一个是更新操作,另一个是查询操作…

说说Lambda架构

Lambda架构由Storm的作者Nathan Marz提出,其设计目的在于提供一个能满足大数据库系统关键特性的架构,包括高容错、低延迟、可扩展等。其整合离线批处理和实时流处理,融合不可变形、读写分离和复杂隔离性等原则,集成Hadoop、Kafka、…

【C#线程设计】2:backgroundWorker

实现: (1).控件:group Box,text Box,check Box,label,botton,richtextbox 控件拉取见:https://blog.csdn.net/m0_74749240/article/details/139409510?spm1…

html+CSS+js部分基础运用19

1. 应用动态props传递数据,输出影片的图片、名称和描述等信息【要求使用props】,效果图如下: 2.在页面中定义一个按钮和一行文本,通过单击按钮实现放大文本的功能。【要求使用$emit()】 代码可以截图或者复制黏贴放置在“实验…

红黑树/红黑树迭代器封装(C++)

本篇将会较为全面的讲解有关红黑树的特点,插入操作,然后使用代码模拟实现红黑树,同时还会封装出红黑树的迭代器。 在 STL 库中的 set 和 map 都是使用红黑树封装的,在前文中我们讲解了 AVL树,对于红黑树和 AVL 树来说&…

手机自动化测试:4.通过appium inspector 获取相关app的信息,以某团为例,点击,搜索,获取数据等。

0.使用inspector时,一定要把不相关的如weditor啥的退出去,否则,净是事。 1.从0开始的数据获取 第一个位置,有时0.0.0.0,不可以的话,你就用这个。 第二个位置,抄上。 直接点击第三个启动。不要…

论文阅读:Indoor Scene Layout Estimation from a Single Image

项目地址:https://github.com/leVirve/lsun-room/tree/master 发表时间:2018 icpr 场景理解,在现实交互的众多方面中,因其在增强现实(AR)等应用中的相关性而得到广泛关注。场景理解可以分为几个子任务&…

Makefile:从零开始入门Makefile

目录 1.前言 2.Makefile的简单介绍 3.Makefile中的指令规则 4.Makefile的执行流程 5.Makefile中的变量类型 6.Makefile中的模式匹配 7.Makefile中的函数 8.Makefile补充知识 前言 在Linux中编译CPP文件,我们能够使用GCC命令进行编译,但当项目文件多且繁杂…

如何利用pandas解析html的表格数据

如何利用pandas解析html的表格数据 我们在编写爬虫的过程中,经常使用的就是parsel、bs4、pyquery等解析库。在博主的工作中经常的需要解析表格形式的html页面,常规的写法是,解析table表格th作为表头,解析td标签作为表格的行数据 …

网站不收录的原因

随着互联网的发展,越来越多的网站被创建和更新,然而,并不是所有的网站都能被搜索引擎收录。有时候,这些网站会因为各种原因而被搜索引擎排除在搜索结果之外。下面我们来探讨一下网站不收录的原因。 首先,网站不收录可能…

贪心算法学习三

例题一 解法(贪⼼): 贪⼼策略: ⽤尽可能多的字符去构造回⽂串: a. 如果字符出现偶数个,那么全部都可以⽤来构造回⽂串; b. 如果字符出现奇数个,减去⼀个之后,剩下的…

12.【Orangepi Zero2】基于orangepi_Zero_2 Linux的智能家居项目

基于orangPi Zero 2的智能家居项目 需求及项目准备 语音接入控制各类家电,如客厅灯、卧室灯、风扇回顾二阶段的Socket编程,实现Sockect发送指令远程控制各类家电烟雾警报监测, 实时检查是否存在煤气泄漏或者火灾警情,当存在警情时…

Robust Tiny Object Detection in Aerial Images amidst Label Noise

文章目录 AbstractIntroductionRelated WorkMethodsClass-aware Label CorrectionUpdateFilteringTrend-guided Learning StrategyTrend-guided Label ReweightingRecurrent Box RegenerationExperimentpaper Abstract 精确检测遥感图像中的小目标非常困难,因为这类目标视觉信…

关于目前ggrcs包的报错解决方案

目前有不少粉丝私信我说使用ggrcs包出现如下错误 我查看了一下,目前报错来源于新版本后的RMS包,主要是预测函数的报错,这个只能等R包作者来修复这个错误。目前需要急用的话,我提供了一个方案,请看下面视频操作 关于目前…

外部排序快速入门详解:基本原理,败者树,置换-选择排序,最佳归并树

文章目录 外部排序1.最基本的外部排序原理2.外部排序的优化2.1 败者树优化方法2.2 置换-选择排序优化方法2.3 最佳归并树 外部排序 为什么要学习外部排序? 答: 在处理数据的过程中,我们需要把磁盘(外存)中存储的数据拿到内存中处理…

通过 Python+Nacos实现微服务,细解微服务架构

shigen坚持更新文章的博客写手,擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。记录成长,分享认知,留住感动。 个人IP:shigen 背景 一直以来的想法比较多,然后就用Python编写各种代码脚本。很多…

在 Ubuntu 中安装 Docker

在 Ubuntu 中安装 Docker 首先,更新你的 Ubuntu 系统。 1、更新 Ubuntu 打开终端,依次运行下列命令: sudo apt update sudo apt upgrade sudo apt full-upgrade 2、添加 Docker 库 首先,安装必要的证书并允许 apt 包管理器…