时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)

论文地址->SCINet官方论文地址

官方代码地址-> 官方代码下载地址

个人整理的代码地址->免费分享给大家创作不易请大家给文章点点赞

一、本文介绍

这篇文章给大家带来的是关于SCINet实现时间序列滚动预测功能的讲解,SCINet是样本卷积交换网络的缩写(Sample Convolutional Interchange Network),SCINet号称是比现有的卷积模型和基于Transformer的模型准确率都有提升(我实验了几次效果确实不错)。本篇文章讲解的代码是我个人根据官方的代码总结出来的模型结构并且进行改进增加了滚动预测的功能。本篇实战案例中包括->详细的参数讲解、改进方向、数据集介绍、模型框架原理、项目结构、如何训练个人数据集的教程、以及结果分析和结果展示。本篇文章的讲解流程为->

适用对象->适合对精度有比较高要求的学习者

预测类型->单元变量预测、多元变量预测

二、模型框架原理

1.SCINet基本原理

SCINet是一个层次化的降采样-卷积-交互TSF框架,有效地对具有复杂时间动态的时间序列进行建模。通过在多个时间分辨率上迭代提取和交换信息,可以学习到具有增强可预测性的有效表示。此外,SCINet的基础构件,SCI-Block,通过将输入数据/特征降采样为两个子序列,然后使用不同的卷积滤波器提取每个子序列的特征。为了补偿降采样过程中的信息损失,每个SCI-Block内部都加入了两种卷积特征之间的交互学习。

个人总结:SCINet就是在不同的维度上面对数据进行处理进行特征提取工作,从而获得不同层次的特征。这点有点类似于目标检测的YOLO系列模型,一张图片进行不断的缩放和扩大获取不同层次的特征,然后对这些特征进行操作,既节省算力又提高精度。(SCINet引入了一个新的概念时间分辨率大家可以注意一下)

2.SCINet基本组件

 下图为SCINet网络结构图

SCINet采用编码器-解码器架构。编码器是一个分层卷积网络,通过丰富的卷积滤波器捕捉多分辨率下的动态时间依赖性。其基本构件SCI-Block将输入数据或特征降采样为两个子序列,然后用一组卷积滤波器处理每个子序列,从每部分中提取独特但有价值的时间特征。为了补偿降采样中的信息损失,它允许两个子序列之间的交互学习。SCINet通过将多个SCI-Blocks排列成二叉树结构来构建。这种设计的一个显著优势是每个SCI-Block都对整个时间序列有局部和全局视角,从而有助于提取有用的时间特征。经过所有降采样-卷积-交互操作后,将提取的特征重新排列成新的序列表示,并将其加入原始时间序列中,用全连接网络作为解码器进行预测。

个人总结:SCINet就是将多个SCI-Block用二叉树的结构堆叠起来,然后提取不从层次的特征,然后从新排列起来,然后经过一个全连接层进行预测。

改进方案:这里其实有改进的空间,经过我的训练过程我发现这个模型训练时间还是比较长的,就是因为他堆叠多个层的SCI-Block这里我觉得可以结和一些新的结构进行改造的,类似于不进行二叉树的操作,但是将不同层次的特征融合起来,有兴趣的小伙伴可以研究一下,没准能发个论文毕竟文章的简写就是SCI~~。

三、数据集介绍 

这个模型用了两个数据集进行测试,一个是某个公司的话务员接线量一个是油温效果都不错,下面讲解用油温的数据集来进行讲解和结果分析。

数据集的部分截图如下->其具有八列数据‘OT’其中间的关系为化学关系比较固定,为油温度。

四、参数讲解 

模型的全部参数如下->

    parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
    parser.add_argument('--model', type=str, default='Transformer',
                        help='model name, options: [Transformer, Linear, NLinear, DLinear, SCINet, ConvFC, MTSMixer, MTSMatrix, FNet]')

    # data loader
    parser.add_argument('--root_path', type=str, default='./', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='MS',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=32, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=8, help='start token length')
    parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')

    # model
    parser.add_argument('--rev', action='store_true', default=False, help='whether to apply RevIN')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=1, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')

    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')

    # optimization
    parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')

    # GPU
    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--device', type=int, default=0, help='gpu')

模型的详细参数讲解如下-> 

参数名称参数类型参数讲解
0trainbool是否进行训练,如果你单纯只想进行预测设置为False即可,
1rollingforecastbool是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测
2rolling-data-pathstr如果进行滚动预测则需要添加新的和训练文件相同格式的数据
3modelstr定义的模型名称
4root_pathstr这个才是你文件的路径,不要到具体的文件,到目录级别即可。
5data_pathstr这个填写你文件的具体名称。
6featuresstr这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
7targetstr这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。
8freqstr时间的间隔,你数据集每一条数据之间的时间间隔。
9checkpointsstr训练出来的模型保存路径
10seq_lenint用过去的多少条数据来预测未来的数据
11label_lenint可以理解为更高的权重占比的部分要小于seq_len
12pred_lenint预测未来多少个时间点的数据
13enc_inint你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
14dec_inint同上
15c_outint这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
16d_modelint用于设置模型的维度,默认值为512。可以根据需要调整该参数的数值来改变模型的维度
15n_headsint用于设置模型中的注意力头数。默认值为8,表示模型会使用8个注意力头,我建议和的输入数据的总体保持一致,列如我输入的是8列数据不用刨去时间的那一列就输入8即可。
17e_layersint用于设置编码器的层数
18d_layersint用于设置解码器的层数
19s_layersstr用于设置堆叠编码器的层数
20dropoutfloat这个应该都理解不说了,丢弃的概率,防止过拟合的。
21embedstr时间特征的编码方式,默认为"timeF"
22activationstr激活函数
23num_workersint线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
24train_epochsint训练的次数
25batch_sizeint一次往模型力输入多少条数据
26learning_ratefloat学习率。
27lossstr     损失函数,默认为"mse"
28lradjstr     学习率的调整方式,默认为"type1"
29use_gpubool是否使用GPU训练,根据自身来选择
30gpuintGPU的编号

五、项目结构 

项目的构造目录如下->

其中data用于方法训练数据,layers用于存放模型,models用于存放训练的保存结果,results用于存放模型的预测结果为CSV的输出格式文件,util用于存放一些工具。 

六、训练和预测

1.训练模型

经过参数的讲解我们已经定义好了所有的参数,可以开始训练了,我的完整main.py文件调好参数的内容如下->

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SCINet Multivariate Time Series Forecasting')
    # basic config
    parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
    parser.add_argument('--model', type=str, default='SCINet',help='Model name')

    # data loader
    parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='MS',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=32, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=8, help='start token length')
    parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')

    # model
    parser.add_argument('--rev', action='store_true', default=False, help='whether to apply RevIN')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=1, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')

    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')

    # optimization
    parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')

    # GPU
    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--device', type=int, default=0, help='gpu')
    args = parser.parse_args()
    Exp = SCINetinitialization
    # setting record of experiments
    setting = 'predict-{}-data-{}'.format(args.model, args.data_path[:-4])

    SCI = SCINetinitialization(args)  # 实例化模型
    if args.train:
        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(args.model))
        SCI.train(setting)
    print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(args.model))
    SCI.predict(setting, True)

我们进行执行控制台开始输出训练结果->

 训练完成后,模型保存到该目录下->

2.开始预测 

训练完成之后,开始进行预测,控制台会进行结果的输出,最后会生成结果文件。

结果文件输出在如下目录->

3.1结果展示 

下面的图片是预测值和真实值的对比图,大家可以看出预测结果还是非常不错的。 

下面的图片为MAE的损失图->

3.2结果分析 

可以看出虽然预测结果还可以接受,但是其中存在明显的数据滞后性,这个问题其实是时间序列预测的通病,目前想要解决两种方法:

  • 一种方法是通过损失精度然后进行数据的预处理操作
  • 另一种就是结合其它能够处理数据滞后性的模型进行模型融合的操作。

后期我也会进行模型融合尝试如果大家需要可以在评论区留言想要看和其它什么模型结合。

七、训练你个人数据集 

这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。

    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')

如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。 

    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')

其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。 

总结  

到此本文已经全部讲解完成了,希望能够帮助到大家,如果你用我的代码可能会存在一些Bug但是肯定不影响运行,如果大家有发现任何的bug可以和我私信沟通,或者评论区留言我可以,大家可以进行探讨。在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。

时间序列预测:深度学习、机器学习、融合模型、创新模型实战案例(附代码+数据集+原理介绍)

时间序列预测模型实战案例(十)(个人创新模型)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测模型实战案例(六)深入理解机器学习ARIMA包括差分和相关性分析

时间序列预测模型实战案例(五)基于双向LSTM横向搭配单向LSTM进行回归问题解决

时间序列预测模型实战案例(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

时间序列预测模型实战案例(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

【全网首发】(MTS-Mixers)(Python)(Pytorch)最新由华为发布的时间序列预测模型实战案例(一)(包括代码讲解)实现企业级预测精度包括官方代码BUG修复Transform模型

时间序列预测模型实战案例(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

最后希望大家工作顺利学业有成!

​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/130305.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

devops完整搭建教程(gitlab、jenkins、harbor、docker)

devops完整搭建教程&#xff08;gitlab、jenkins、harbor、docker&#xff09; 文章目录 devops完整搭建教程&#xff08;gitlab、jenkins、harbor、docker&#xff09;1.简介&#xff1a;2.工作流程&#xff1a;3.优缺点4.环境说明5.部署前准备工作5.1.所有主机永久关闭防火墙…

HTTPS的工作流程

. HTTPS是什么&#xff1f; https是应用层中的一个协议&#xff0c;是在http协议的基础上引入的一个加密层。 为什么需要HTTPS 由于http协议内容都是按照文本的方式明文传输的&#xff0c;这就导致传输过程中会出现一些被篡改的情况。运营商劫持事件最开始百度&#xff0c;…

比较PID控制和神经网络控制在机器人臂上的应用

机器人臂是自动化领域中常见的机器人形式&#xff0c;其精确控制对于实现复杂任务具有重要意义。在机器人臂的控制中&#xff0c;PID控制和神经网络控制是两种常用的控制方法。本文将比较PID控制和神经网络控制在机器人臂控制方面的应用&#xff0c;包括控制原理、优缺点以及在…

【广州华锐互动】太空探索VR模拟仿真教学系统

随着科技的不断发展&#xff0c;人类对宇宙的探索欲望愈发强烈。火星作为距离地球最近的行星之一&#xff0c;自然成为了人类关注的焦点。近年来&#xff0c;火星探测取得了一系列重要成果&#xff0c;为人类了解火星提供了宝贵的信息。然而&#xff0c;实地考察火星仍然面临着…

C++——基础

初学C的时候&#xff0c;有没有想过&#xff0c;为什么C支持重载&#xff0c;而C不支持重载呢&#xff1f;&#xff1f; 其实&#xff0c;一个程序运行起来都要经过四步骤 预处理编译汇编链接 预处理阶段会经过去注释&#xff0c;宏替换&#xff0c;头文件展开&#xff0c;条…

Liunx终极环境搭建

华子目录 网络服务准备工作安装RHEL9系统部署RHEL9操作系统虚拟网络编辑器配置RHEL9系统系统中的设置更换yum源修改主机名关闭selinux&#xff0c;firewalld设置静态ip &#xff08;网络配置&#xff09; 网络服务 准备工作 以下为RHEL9镜像资源&#xff0c;有需要的博友们可…

Ubuntu(WSL) mysql8.0.31 源码安装

要在 Ubuntu 上使用调试功能安装 MySQL 8.0 的源码&#xff0c;可以按照以下详细步骤进行操作&#xff1a; 1. 更新系统 首先&#xff0c;确保你的 Ubuntu 系统是最新的。运行以下命令更新系统软件包&#xff1a; sudo apt update sudo apt upgrade 2. 下载 MySQL 源码 访…

ChatGPT Plus的Vision升级是一个改变游戏规则的创举

内容来源&#xff1a;0xluffy_eth ChatGPT Plus的Vision升级是一个改变游戏规则的创举&#xff01; 现在每个用户都可以以每月20美元的价格雇用自己的个人数字助理实习生&#xff0c;具备VISION&#xff01; 以下是10个惊人的例子&#xff08;&#xff09; 1&#xff0c; 我…

Blender--》点线面操作及其面操作的详解

接下来我会在three.js专栏中分享关于3D建模知识的文章&#xff0c;如果学习three朋友并且想了解和学习3D建模&#xff0c;欢迎关注本专栏&#xff0c;关于这款3D建模软件blender的安装&#xff0c;我在前面的文章已经讲解过了&#xff0c;如果不了解的朋友可以去考考古&#xf…

this和super

文章目录 this用法普通的直接引用区分形参与实参 super用法普通的直接引用区分子类与父类同名的属性或方法 this和super 与 构造方法总结 this this引用表示当前对象对象的引用。 用法 普通的直接引用 public class Test {int a ;int b;public Test() {this.b 0;} }调用当…

mysql的sql_mode参数

msql修改了这个参数&#xff0c;首先mysql需要重新才能生效&#xff0c;还有就是java连接的springboot项目也需要重新启动。之前是遇到了下面的这个报错。只需要把sql_mode设置为空&#xff0c;重启mysql和服务就行 报错 In aggregated query without GROUP BY, expression #1…

使用 pubsub-js 进行消息发布订阅

npm 包地址 github 包地址 pubsub-js 是一个轻量级的 JavaScript 基于主题的消息订阅发布库 &#xff0c;压缩后小于1b。它具有使用简单、性能高效、支持多平台等优点&#xff0c;可以很好地满足各种需求。 功能特点&#xff1a; 无依赖同步解耦ES3 兼容。pubsub-js 能够在…

Vatee万腾外汇数字化策略:Vatee科技决策力的未来引领

在外汇市场&#xff0c;Vatee万腾通过其前瞻性的外汇数字化策略&#xff0c;正引领着科技决策的未来。这一数字化策略的崭新愿景为投资者提供了更智慧、更高效的外汇投资体验&#xff0c;成为科技决策领域的翘楚。 Vatee万腾的外汇数字化策略是科技决策力未来引领的典范。通过运…

C# PaddleInference.PP-HumanSeg 人像分割 替换背景色

效果 项目 VS2022.net4.8OpenCvSharp4Sdcb.PaddleInference 包含4个分割模型 modnet-hrnet_w18 modnet-mobilenetv2 ppmatting-hrnet_w18-human_512 ppmattingv2-stdc1-human_512 代码 using OpenCvSharp; using Sdcb.PaddleInference; using System; using System.Col…

Springboot SpringCloudAlibaba Nacos 项目搭建

依赖版本&#xff1a; spring-boot&#xff1a;2.3.12.RELEASE spring-cloud-alibaba&#xff1a;2.2.7.RELEASE spring-cloud&#xff1a;Hoxton.SR12 nacos&#xff1a;2.0.3 1.部署搭建Nacos注册中心 Linux Nacos 快速启动_nacos linux快速启动-CSDN博客 2.构建项目 源码地…

STM32——STM32F4系统架构

文章目录 前言STM32F4XX系统架构 前言 本篇文章为STM32F4系列的系统架构&#xff0c;因为最近在学习F4的板子&#xff0c;暂时先更F4的&#xff0c;有需要F1的后续再更新。 主系统由 32 位多层 AHB 总线矩阵构成&#xff0c;可实现以下部分的互连&#xff1a; STM32F4XX系统架…

19. 深度学习 - 用函数解决问题

文章目录 Hi&#xff0c; 你好。我是茶桁。 上一节课&#xff0c;我们从一个波士顿房价的预测开始写代码&#xff0c;写到了KNN。 之前咱们机器学习课程中有讲到KNN这个算法&#xff0c;分析过其优点和缺点&#xff0c;说起来&#xff0c;KNN这种方法比较低效&#xff0c;在数…

万能在线预约小程序系统源码 适合任何行业在线预约小程序+预约到店模式 带完整的搭建教程

大家好啊&#xff0c;源码小编又来给大家分享啦&#xff01;随着互联网的发展和普及&#xff0c;越来越多的服务行业开始使用在线预约系统以方便客户和服务管理。例如&#xff0c;美发店、健身房、餐厅等都可以通过在线预约系统提高服务效率&#xff0c;减少等待时间&#xff0…

开机自启动笔记本的小键盘

虽然电脑开机次数不多&#xff0c;但每次开机都要摁下小键盘的开关&#xff0c;好烦 终于忍不住了&#xff1a; 将下面文件命名为 XXX.bat echo off rem 禁用批处理文件中的命令回显&#xff0c;以使输出更整洁rem 查询注册表中 "InitialKeyboardIndicators" 的值 r…

赛氪中西部外语翻译大赛入榜2023国内翻译赛事发展评估报告

中西部外语翻译大赛入选中国外文局CATTI项目管理中心和中国外文界平台联合发布《2023国内翻译赛事发展评估报告》 近日&#xff0c;中国外文局CATTI项目管理中心和中国外文界平台联合发布了《2023国内翻译赛事发展评估报告》&#xff0c;报告对国内主流外语翻译赛事进行了问卷调…