NeuralForecast 多变量的处理 包括训练和推理

NeuralForecast 多变量的处理 包括训练和推理

flyfish
在这里插入图片描述

在这里插入图片描述

两个excel表格合并后的结果

      unique_id                  ds         y      ex_1      ex_2      ex_3      ex_4
0           HUFL 2016-07-01 00:00:00 -0.041413 -0.500000  0.166667 -0.500000 -0.001370
1           HUFL 2016-07-01 00:15:00 -0.185467 -0.500000  0.166667 -0.500000 -0.001370
2           HUFL 2016-07-01 00:30:00 -0.257495 -0.500000  0.166667 -0.500000 -0.001370
3           HUFL 2016-07-01 00:45:00 -0.577510 -0.500000  0.166667 -0.500000 -0.001370
4           HUFL 2016-07-01 01:00:00 -0.385501 -0.456522  0.166667 -0.500000 -0.001370
...          ...                 ...       ...       ...       ...       ...       ...
403195        OT 2018-02-20 22:45:00 -1.581325  0.456522 -0.333333  0.133333 -0.363014
403196        OT 2018-02-20 23:00:00 -1.581325  0.500000 -0.333333  0.133333 -0.363014
403197        OT 2018-02-20 23:15:00 -1.581325  0.500000 -0.333333  0.133333 -0.363014
403198        OT 2018-02-20 23:30:00 -1.562328  0.500000 -0.333333  0.133333 -0.363014
403199        OT 2018-02-20 23:45:00 -1.562328  0.500000 -0.333333  0.133333 -0.363014
import pandas as pd

from datasetsforecast.long_horizon import LongHorizon
# Change this to your own data to try the model
Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTm2')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])

# X_df contains the exogenous features, which we add to Y_df
X_df['ds'] = pd.to_datetime(X_df['ds'])
Y_df = Y_df.merge(X_df, on=['unique_id', 'ds'], how='left')

print(Y_df.head)
#exit()

# We make validation and test splits
n_time = len(Y_df.ds.unique())
val_size = int(.2 * n_time)
test_size = int(.2 * n_time)
@dataclass
class LongHorizon:
    """
    This Long-Horizon datasets wrapper class, provides
    with utility to download and wrangle the following datasets:    
    ETT, ECL, Exchange, Traffic, ILI and Weather.
    
    - Each set is normalized with the train data mean and standard deviation.
    - Datasets are partitioned into train, validation and test splits.
    - For all datasets: 70%, 10%, and 20% of observations are train, validation, test, 
      except ETT that uses 20% validation.  
    """
    
    source_url: str = 'https://nhits-experiments.s3.amazonaws.com/datasets.zip'

    @staticmethod
    def load(directory: str,
             group: str,
             cache: bool = True) -> Tuple[pd.DataFrame, 
                                          Optional[pd.DataFrame], 
                                          Optional[pd.DataFrame]]:
        """
        
        Downloads and long-horizon forecasting benchmark datasets.

            Parameters
            ----------
            directory: str
                Directory where data will be downloaded.
            group: str
                Group name.
                Allowed groups: 'ETTh1', 'ETTh2', 
                                'ETTm1', 'ETTm2',
                                'ECL', 'Exchange',
                                'Traffic', 'Weather', 'ILI'.
            cache: bool
                If `True` saves and loads 

            Returns
            ------- 
            y_df: pd.DataFrame
                Target time series with columns ['unique_id', 'ds', 'y'].
            X_df: pd.DataFrame
                Exogenous time series with columns ['unique_id', 'ds', 'y']. 
            S_df: pd.DataFrame
                Static exogenous variables with columns ['unique_id', 'ds']. 
                and static variables. 
        """
        if group not in LongHorizonInfo.groups:
            raise Exception(f'group not found {group}')
            
        path = f'{directory}/longhorizon/datasets'
        file_cache = f'{path}/{group}.p'
        
        if os.path.exists(file_cache) and cache:
            df, X_df, S_df = pd.read_pickle(file_cache)
            
            return df, X_df, S_df
        
        LongHorizon.download(directory)
        path = f'{directory}/longhorizon/datasets'
        
        kind = 'M' if group not in ['ETTh1', 'ETTh2'] else 'S'
        name = LongHorizonInfo[group].name
        y_df = pd.read_csv(f'{path}/{name}/{kind}/df_y.csv')
        y_df = y_df.sort_values(['unique_id', 'ds'], ignore_index=True)
        y_df = y_df[['unique_id', 'ds', 'y']]
        X_df = pd.read_csv(f'{path}/{name}/{kind}/df_x.csv')
        X_df = y_df.drop('y', axis=1).merge(X_df, how='left', on=['ds'])
       
        S_df = None
        if cache:
            pd.to_pickle((y_df, X_df, S_df), file_cache)
            
        return y_df, X_df, S_df

    @staticmethod
    def download(directory: str) -> None:
        """
        Download ETT Dataset.
        
        Parameters
        ----------
        directory: str
            Directory path to download dataset.
        """
        path = f'{directory}/longhorizon/datasets/'
        if not os.path.exists(path):
             download_file(path, LongHorizon.source_url, decompress=True)

完整的训练保存模型文件

import pandas as pd

from datasetsforecast.long_horizon import LongHorizon
# Change this to your own data to try the model
Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTm2')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])

# X_df contains the exogenous features, which we add to Y_df
X_df['ds'] = pd.to_datetime(X_df['ds'])
Y_df = Y_df.merge(X_df, on=['unique_id', 'ds'], how='left')

print(Y_df.head)
#exit()

# We make validation and test splits
n_time = len(Y_df.ds.unique())
val_size = int(.2 * n_time)
test_size = int(.2 * n_time)

from neuralforecast.core import NeuralForecast
from neuralforecast.models import TSMixer, TSMixerx, NHITS, MLPMultivariate,VanillaTransformer
from neuralforecast.losses.pytorch import MSE, MAE
horizon = 12
input_size = 24
models = [
          VanillaTransformer(h=horizon,
                input_size=input_size,
                max_steps=1,
                val_check_steps=1,
                early_stop_patience_steps=1,
                scaler_type='identity',
                valid_loss=MAE(),
                random_seed=12345678,
                ),  
                                                                      
         ]
nf = NeuralForecast(
    models=models,
    freq='15min')

Y_hat_df = nf.cross_validation(df=Y_df,
                               val_size=val_size,
                               test_size=test_size,
                               n_windows=None
                               )                                 
Y_hat_df = Y_hat_df.reset_index()
nf.save(path='./checkpoints/test_run/',
        model_index=None, 
        overwrite=True,
        save_dataset=True)

完整的推理代码

import pandas as pd
from neuralforecast.core import NeuralForecast
from neuralforecast.models import VanillaTransformer
from neuralforecast.losses.pytorch import MAE

# 示例数据
data = {
    'unique_id': ['HUFL'] * 5,
    'ds': [
        '2016-07-01 00:00:00', '2016-07-01 00:15:00', '2016-07-01 00:30:00', '2016-07-01 00:45:00', '2016-07-01 01:00:00'
    ],
    'y': [-0.041413, -0.185467, -0.257495, -0.577510, -0.385501],
    'ex_1': [-0.5, -0.5, -0.5, -0.5, -0.456522],
    'ex_2': [0.166667, 0.166667, 0.166667, 0.166667, 0.166667],
    'ex_3': [-0.5, -0.5, -0.5, -0.5, -0.5],
    'ex_4': [-0.001370, -0.001370, -0.001370, -0.001370, -0.001370]
}

# 创建 DataFrame
df = pd.DataFrame(data)
df['ds'] = pd.to_datetime(df['ds'])

# 使用 NeuralForecast 库进行预测
horizon = 12
input_size = 24

models = [
    VanillaTransformer(h=horizon,
                       input_size=input_size,
                       max_steps=1,
                       val_check_steps=1,
                       early_stop_patience_steps=1,
                       scaler_type='identity',
                       valid_loss=MAE(),
                       random_seed=12345678)
]



# 加载已训练的模型
nf = NeuralForecast.load(path='./checkpoints/test_run/')
# 数据准备
Y_df = df[['unique_id', 'ds', 'y']]
X_df = df[['unique_id', 'ds', 'ex_1', 'ex_2', 'ex_3', 'ex_4']]

# 合并数据集
Y_df = Y_df.merge(X_df, on=['unique_id', 'ds'], how='left')

# 进行预测
predictions = nf.predict(Y_df)

# 打印预测结果
print(predictions)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/678917.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

“滴滴打车,用友入账”,YonSuite商旅费控助力企业“降低成本”更进一步

在当今竞争激烈的商业环境中,企业对于成本控制和效率提升的需求日益迫切。特别是在商旅管理方面,如何有效整合资源、优化流程、降低费用,成为了成长型企业关注的焦点。用友YonSuite商旅费控作为用友集团旗下的重要产品,凭借其卓越…

SolidWorks功能强大的三维设计软件下载安装,SolidWorks最新资源获取!

SolidWorks,它凭借出色的三维建模能力,使得设计师们能够轻松构建出复杂且精细的机械模型,大大提升了设计效率和质量。 在机械设计领域,SolidWorks凭借其丰富的工具和特性,让设计师们能够随心所欲地挥洒创意。无论是零…

Linuxftp服务002本地登入

本期主要讲述的是ftp服务中的本地用户登入。 操作系统 CentOS Stream 9 操作步骤 首先我们先建立一个ftp组的用户,并设置密码。 [rootlocalhost ~]# useradd -g ftp wq [rootlocalhost ~]# echo 1 |passwd --stdin wq 更改用户 wq 的密码 。 passwd&#xff1a…

SpringBoot中的WebMvcConfigurationSupport和WebMvcConfigurer

在SpringBoot中可以通过以下两种方式来完成自定义WebMvc的配置: (1)继承WebMvcConfigurationSupport类 (2)实现WebMvcConfigurer接口 通过这两种方式完成的WebMvc配置存在差异,本文将对此作简单说明与区…

Selenium with Python Behave(BDD)

一、简介 Python语言的行为驱动开发,Behavior-driven development,简称BDD. "Behavior-driven development (or BDD) is an agile software development technique that encourages collaboration between developers, QA and non-technical or bu…

顶顶通呼叫中心中间件-区号号码自动加0(mod_cti基于FreeS WITCH)

顶顶通呼叫中心中间件-区号号码自动加0(mod_cti基于FreeSWITCH) 本地区号。如果配置了本地区号,被叫手机号码归属地和本地区号不同会自动加0 一、导入号码归属地 1、下载ccadmin安装包并且把手机号码归宿地解压出来 1、下载ccadmin安装包 Windows版本下载地址&…

小短片创作-理论知识(五)

1、网格体绘制 1.UE5打开Megascan插件的材质混合器,创建混合材质,最多选择3个材质进行混合, 2.通过模式->网格体绘制,进入网格体绘制模式,通过select选择一个平面进行绘制,然后通过paint进行绘制&am…

opencv笔记(13)—— 停车场车位识别

一、所需数据介绍 car1.h5 是训练后保存的模型 class_directionary 是0,1的分类 二、图像数据预处理 对输入图片进行过滤: def select_rgb_white_yellow(self,image): #过滤掉背景lower np.uint8([120, 120, 120])upper np.uint8([255, 255, 255])#…

09、进程和计划任务管理

9.1 查看和控制进程 程序是保存在外部存储介质(如硬盘)中的可执行机器代码和数据的静态集合,而进程是在 CPU 及内存中处于动态执行状态的计算机程序。在 Linux操作系统中,每个程序启动后可以创建一个或多个进程。例如,提供 Web 服务的 httpd …

计算机网络学习记录 网络层 Day4(下)

计算机网络学习记录 网络层 Day4 (下) 你好,我是Qiuner. 为记录自己编程学习过程和帮助别人少走弯路而写博客 这是我的 github https://github.com/Qiuner ⭐️ ​ gitee https://gitee.com/Qiuner 🌹 如果本篇文章帮到了你 不妨点个赞吧~ 我…

期权懂题库免费!期权开户测试难吗?多少分算合格通过?

今天带你了解期权懂题库免费!期权开户测试难吗?多少分算合格通过?期权开户测试通常要求投资者达到一定的合格分数,以确保他们具备足够的理解和知识来参与期权交易。 期权开户测试难吗? 期权开户测试的难度因人而异&am…

PW1558A规格探秘:为何它是电源系统不可或缺的6A双向保护芯片?

描述 PW1558A 是一款先进的 28V 6A 额定双向负载开关, 提供过载、 短路、 输入电压浪涌、 过大冲击电流和过热保护, 为系统供电。 内置的 24mΩ超低 RDS(ON)电源开关有助于减少正常操作期间的功率损耗。 该设备具有两个输入/输出端口 VBUS1 和 VBUS2&…

LSDFi协议赛道4大稳定币项目,以bitget钱包为例

纵览 LSDfi 生态繁荣的基石,LSD 稳定币赛道全解析 近期有许多建立在流动性质押通证的稳定币借贷协议开始出现在大众眼里,今天文章就要带大家来一一了解这些 LSDfi 协议究竟是如何争夺这块诱人的大饼。 LybraFinanceLSD 它透过抵押stETH/ETH 铸造&#…

二叉树系列题

OJ104:二叉树的最大深度 1.题目 2.注意 这里要用left和right接收递归的结果,如果不接收,直接用递归来比较,会出现效率问题。 3.参考代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* str…

【深入理解计算机系统第3版】补码加法

感觉这部分有点难,所以稍微整理记一下。 抱歉中英混合,来回切换输入法真的很折磨人。 负溢出 正常 正溢出 以4位补码加法为例,理解下表(书中P64) 补码最大值Tmax 2^3 - 1 7, 补码最小值Tmin -2^3 -8 xyz x yz z mod 2^4zU2Tw(z)溢…

[书生·浦语大模型实战营]——Lagent AgentLego 智能体应用搭建实现效果

1.完成 Lagent Web Demo 使用,并在作业中上传截图 使用插件 不使用插件: 2.完成 AgentLego 直接使用部分,并在作业中上传截图 原图 结果 3.完成 AgentLego WebUI 使用,并在作业中上传截图。 4.使用 Lagent 或 AgentLego …

Python数据分析案例46——电力系统异常值监测(自编码器,孤立森林,SVMD)

案例背景 多变量的时间序列的异常值监测一直是方兴未艾的话题,我总能看到不少的同学要做什么时间序列预测,然后做异常值监测,但是很多同学都搞不清楚他们的区别。 这里要简单解释一下,时间序列预测是有监督的模型,而…

使用API有效率地管理Dynadot域名,使用API创建文件夹管理域名

关于Dynadot Dynadot是通过ICANN认证的域名注册商,自2002年成立以来,服务于全球108个国家和地区的客户,为数以万计的客户提供简洁,优惠,安全的域名注册以及管理服务。 Dynadot平台操作教程索引(包括域名邮…

contenteditable实现插入标签的输入框功能(Vue3版)

需求:实现一个简易的函数编辑器 点击参数能够往输入框插入标签点击函数能够往输入框插入文本删除能够把标签整体删除输入的参数能够获取到其携带的信息 插入文本 /*** description 点击函数展示到输入框*/ const getValue ({ item, type }: any) > {// 创建…

计算机网络之crc循环冗余校验、子网划分、rip协议路由转发表、时延计算、香浓定理 奈氏准则、TCP超时重传 RTO

crc循环冗余校验 异或运算 : 相同得0,相异得1 从多项式获取除数 在原数据的末端补0 , 0的个数等于最高次项的阶数 如果最后结果的有效位数较少时,前面应该补0,补到个数与阶位相同 子网划分 子网掩码:用于识别IP地址中的网络号和主机号的…