时间序列预测大模型-TimeGPT

时间序列预测领域正在经历一个非常激动人心的时期。仅在过去的三年里,我们就看到了许多重要的贡献,例如N-BEATS、N-HiTS、PatchTST和TimesNet。

与此同时,大型语言模型 (LLM)最近在 ChatGPT 等应用程序中广受欢迎,因为它们无需进一步训练即可适应各种任务。

这就引出了一个问题:时间序列的基础模型是否可以像自然语言处理一样存在?在大量时间序列数据上预先训练的大型模型是否有可能对未见过的数据产生准确的预测?

由 Azul Garza 和 Max Mergenthaler-Canseco 提出,作者将大模型背后的技术和架构应用于预测领域,成功构建了第一个能够进行零样本推理的时间序列基础模型。

在本文中,我们首先探讨 TimeGPT 背后的架构以及模型的训练方式。然后,我们将其应用于预测项目,以根据其他最先进的方法(如 N-BEATS、N-HiTS 和 PatchTST)评估其性能。

欲了解更多详细信息,请务必阅读原始论文

探索 TimeGPT

  如前所述,TimeGPT 是创建时间序列预测基础模型的首次尝试。

说明如何训练 TimeGPT 以对看不见的数据进行推理。图片由 Azul Garza 和 Max Mergenthaler-Canseco 拍摄,来自TimeGPT-1

从上图我们可以看出,TimeGPT 背后的总体思想是在来自不同领域的大量数据上训练模型,然后对未见过的数据进行零样本推理。

当然,这种方法依赖于迁移学习,这是模型利用训练期间获得的知识解决新任务的能力。

现在,只有当模型足够大并且经过大量数据训练时,这才有效。

训练TimeGPT

为此,作者使用超过 1000 亿个数据点对 TimeGPT 进行了训练,这些数据点全部来自开源时间序列数据。该数据集涵盖广泛的领域,从金融、经济和天气,到网络流量、能源和销售。

请注意,作者没有透露用于管理 1000 亿个数据点的公共数据来源。

这种多样性对于基础模型的成功至关重要,因为它可以学习不同的时间模式,从而更好地泛化。

例如,我们可以预期天气数据具有每日季节性(白天比晚上更热)和每年季节性,而汽车交通数据可以具有每日季节性(白天道路上的汽车多于夜间)和每周季节性。季节性(一周内道路上的汽车多于周末)。

为了确保模型的稳健性和泛化能力,预处理被保持在最低限度。事实上,只填充了缺失的值,其余的保持原始形式。虽然作者没有指定数据插补的方法,但我怀疑使用了某种插值技术,例如线性插值、样条插值或移动平均插值。

然后对模型进行多天的训练,在此期间优化超参数和学习率。虽然作者没有透露训练需要多少天和 GPU,但我们确实知道该模型是在 PyTorch 中实现的,并且它使用 Adam 优化器和学习率衰减策略。

TimeGPT的架构

TimeGPT 利用 Transformer 架构和基于 Google 和多伦多大学 2017 年开创性工作的自注意力机制。

TimeGPT 的架构。输入序列与外生变量一起被馈送到 Transfomer 的编码器,然后解码器生成预测。图片由 Azul Garza 和 Max Mergenthaler-Canseco 拍摄,来自TimeGPT-1。

从上图我们可以看到TimeGPT采用了完整的编码器-解码器Transformer架构。

输入可以包含历史数据窗口以及外源数据,例如准时事件或其他系列。

输入被馈送到模型的编码器部分。然后,编码器内部的注意力机制从输入中学习不同的属性。然后将其馈送到解码器,解码器使用学到的信息来生成预测。当然,当预测序列达到用户设置的预测范围的长度时,预测序列就会结束。

值得注意的是,作者在 TimeGPT 中实现了保形预测,允许模型根据历史误差估计预测区间。

TimeGPT 的功能

考虑到 TimeGPT 是构建时间序列基础模型的首次尝试,因此它具有广泛的功能。

首先,TimeGPT 是一个预先训练的模型,这意味着我们可以生成预测,而无需专门针对我们的数据进行训练。尽管如此,仍然可以根据我们的数据对模型进行微调。

其次,该模型支持外生变量来预测我们的目标,并且它可以处理多元预测任务。

最后,通过使用保形预测,TimeGPT 可以估计预测间隔。这反过来又允许模型执行异常检测。基本上,如果数据点超出 99% 置信区间,则模型会将其标记为异常。

请记住,所有这些任务都可以通过零样本推理或一些微调来实现,这是时间序列预测领域范式的根本转变。

现在我们对 TimeGPT、它的工作原理以及训练方式有了更深入的了解,让我们看看该模型的实际应用。

使用 TimeGPT 进行预测

现在让我们将 TimeGPT 应用于预测任务,并将其性能与其他模型进行比较。

请注意,在撰写本文时,TimeGPT 只能通过 API 访问,并且处于封闭测试阶段。我提交了请求并获得了两周免费访问该模型的许可。要获取令牌并访问模型,您必须访问他们的网站。

如前所述,该模型是根据来自公开数据的 1000 亿个数据点进行训练的。由于作者没有指定实际使用的数据集,我认为在已知的基准数据集(例如ETT或Weather)上测试模型是不合理的,因为模型可能在训练期间看到了这些数据。

因此,我为本文编译并开源了自己的数据集。

具体来说,我策划了从 2020 年 1 月 1 日到 2023 年 10 月 12 日期间博客上的每日浏览量。我还添加了两个外生变量:一个表示发布新文章的日期,另一个表示发布新文章的日期。我在美国度假,因为我的大多数观众都住在那里。

该数据集现已在GitHub上公开提供,最重要的是,我们确信 TimeGPT 并未使用该数据进行训练

导入库并读取数据

自然的第一步是导入此实验的库。

import pandas as pd
import numpy as np
import datetime
import matplotlib.pyplot as plt

from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS, NBEATS, PatchTST

from neuralforecast.losses.numpy import mae, mse

from nixtlats import TimeGPT

%matplotlib inline

然后,为了访问 TimeGPT 模型,我们从文件中读取 API 密钥。请注意,我没有将 API 密钥分配给环境变量,因为访问仅限两周。

with open("data/timegpt_api_key.txt", 'r') as file:
        API_KEY = file.read()

然后,我们就可以读取数据了。

df = pd.read_csv('data/medium_views_published_holidays.csv')
df['ds'] = pd.to_datetime(df['ds'])

df.head()

我们数据集的前五行。

从上图可以看出,数据集的格式与我们使用 Nixtla 的其他开源库时的格式相同。

我们有一个unique_id列来标记不同的时间序列,但在我们的例子中,我们只有一个序列。

y列代表我博客上的每日浏览量,published是一个简单的标志,用于标记发布新文章 (1) 或未发布文章 (0) 的日期。直观上我们知道,当新内容发布时,浏览量通常会在一段时间内增加。

最后, is_holiday列表示美国是否有假期。直觉是,在假期,访问我博客的人会更少。

现在,让我们可视化我们的数据并寻找可辨别的模式。

published_dates = df[df['published'] == 1]

fig, ax = plt.subplots(figsize=(12,8))

ax.plot(df['ds'], df['y'])
ax.scatter(published_dates['ds'], published_dates['y'], marker='o', color='red', label='New article')
ax.set_xlabel('Day')
ax.set_ylabel('Total views')
ax.legend(loc='best')

fig.autofmt_xdate()


plt.tight_layout()

我的博客的每日浏览量。

从上图中,我们已经可以看到一些有趣的行为。首先,请注意,红点表示一篇新发表的文章,它们几乎立即出现访问高峰。

我们还注意到 2021 年的活动减少,这反映在我博客的每日浏览量减少上。最后,在 2023 年,我们注意到文章发表后访问量出现了一些异常峰值。

放大数据,我们还发现了明显的每周季节性。

我的博客的每日浏览量。在这里,我们看到明显的每周季节性,周末参观的人较少。

从上图中,我们现在可以看到周末访问博客的访问者比工作日少。

考虑到所有这些,让我们看看如何使用 TimeGPT 来进行预测。

使用 TimeGPT 进行预测

首先,我们将数据集分为训练集和测试集。在这里,我将为测试集保留 168 个时间步,对应于 24 周的每日数据。

train = df[:-168]
test = df[-168:]

然后,我们的预测范围为 7 天,因为我有兴趣预测一整周的每日观看次数。

现在,该 API 不附带交叉验证的实现。因此,我们创建自己的循环来一次生成七个预测,直到我们对整个测试集进行预测。

future_exog = test[['unique_id', 'ds', 'published', 'is_holiday']]

timegpt = TimeGPT(token=API_KEY)

timegpt_preds = []

for i in range(0, 162, 7):

    timegpt_preds_df = timegpt.forecast(
        df=df.iloc[:1213+i],
        X_df = future_exog[i:i+7],
        h=7,
        finetune_steps=10,
        id_col='unique_id',
        time_col='ds',
        target_col='y'
    )
    
    preds = timegpt_preds_df['TimeGPT']
    
    timegpt_preds.extend(preds)

在上面的代码块中,请注意,我们必须传递外生变量的未来值。这很好,因为它们是静态变量。我们知道未来的假期日期,博客作者也知道他计划何时发表文章。

另请注意,我们使用finetune_steps参数根据数据微调 TimeGPT 。

循环完成后,我们可以将预测添加到测试集中。同样,TimeGPT 一次生成 7 个预测,直到获得 168 个预测,以便我们可以评估其预测下周每日观看次数的能力。

test['TimeGPT'] = timegpt_preds

test.head()

来自 TimeGPT 的预测。

使用 N-BEATS、N-HiTS 和 PatchTST 进行预测

现在,让我们应用其他方法来看看专门在我们的数据集上训练这些模型是否可以产生更好的预测。

对于这个实验,如前所述,我们使用 N-BEATS、N-HiTS 和 PatchTST。

horizon = 7

models = [NHITS(h=horizon,
               input_size=5*horizon,
               max_steps=50),
         NBEATS(h=horizon,
               input_size=5*horizon,
               max_steps=50),
         PatchTST(h=horizon,
                 input_size=5*horizon,
                 max_steps=50)]

然后,我们初始化NeuralForecast对象并指定数据的频率,在本例中为每天。

nf = NeuralForecast(models=models, freq='D')

然后,我们在 7 个时间步长的 24 个窗口上运行交叉验证,以获得与 TimeGPT 使用的测试集一致的预测。

preds_df = nf.cross_validation(
    df=df, 
    static_df=future_exog , 
    step_size=7, 
    n_windows=24
)

然后,我们可以简单地将 TimeGPT 的预测添加到这个新的preds_df DataFrame 中,以获得包含所有模型预测的单个 DataFrame。

preds_df['TimeGPT'] = test['TimeGPT']

包含所有模型预测的数据框。

伟大的!我们现在准备评估每个模型的性能。

评估

在测量性能指标之前,让我们可视化测试集上每个模型的预测。

可视化每个模型的预测。

首先,我们看到每个模型之间有很多重叠。然而,我们确实注意到 N-HiTS 预测了两个在现实生活中未实现的峰值。此外,PatchTST 似乎经常被低估。然而,TimeGPT 似乎通常与实际数据重叠得很好。

当然,评估每个模型性能的唯一方法是测量性能指标。在这里,我们使用平均绝对误差(MAE)和均方误差(MSE)。此外,我们将预测四舍五入为整数,因为小数对于博客的日常访问者来说没有意义。

preds_df = preds_df.round({
    'NHITS': 0,
    'NBEATS': 0,
    'PatchTST': 0,
    'TimeGPT': 0
})

data = {'N-HiTS': [mae(preds_df['NHITS'], preds_df['y']), mse(preds_df['NHITS'], preds_df['y'])],
       'N-BEATS': [mae(preds_df['NBEATS'], preds_df['y']), mse(preds_df['NBEATS'], preds_df['y'])],
       'PatchTST': [mae(preds_df['PatchTST'], preds_df['y']), mse(preds_df['PatchTST'], preds_df['y'])],
       'TimeGPT': [mae(preds_df['TimeGPT'], preds_df['y']), mse(preds_df['TimeGPT'], preds_df['y'])]}

metrics_df = pd.DataFrame(data=data)
metrics_df.index = ['mae', 'mse']

metrics_df.style.highlight_min(color='lightgreen', axis=1)

每个模型的性能指标。在这里,TimeGPT 是冠军模型,因为它实现了最低的 MAE 和 MSE。

从上图可以看出,TimeGPT 是冠军模型,因为它实现了最低的 MAE 和 MSE,其次是 N-BEATS、PatchTST 和 N-HiTS。

这是一个令人兴奋的结果,因为 TimeGPT 从未见过这个数据集,并且只进行了几个步骤的微调。虽然这不是一个详尽的实验,但我相信它确实展示了预测领域潜在的基础模型的一瞥。

我对TimeGPT的个人看法

虽然我对 TimeGPT 的简短实验被证明是令人兴奋的,但我必须指出,原始论文在许多重要领域仍然含糊不清。

同样,我们不知道使用哪些数据集来训练和测试模型,因此我们无法真正验证 TimeGPT 的性能结果,如下所示。

Azul Garza 和 Max Mergenthaler-Canseco 的原始论文中报告的 TimeGPT 性能结果

从上表中我们可以看到,TimeGPT 在每月和每周的频率上表现最好,N-HiTS 和 Temporal Fusion Transformer (TFT) 通常排名第二或第三。话又说回来,因为我们不知道使用了哪些数据,所以我们无法验证这些指标。

在如何训练模型以及如何调整模型来处理时间序列数据方面也缺乏透明度。

我相信该模型是用于商业用途,这解释了为什么该论文缺乏重现 TimeGPT 的细节。这并没有什么问题,但论文缺乏可重复性是科学界担心的问题。

尽管如此,我还是希望这能够激发时间序列基础模型的新工作和研究,并且我们最终会看到这些模型的开源版本,就像我们看到法学硕士发生的情况一样。

结论

TimeGPT 是时间序列预测的第一个基础模型。

它利用 Transformer 架构,并在 1000 亿个数据点上进行了预训练,以对新的未见数据进行零样本推理。

结合保形预测技术,该模型无需在特定数据集上进行训练即可生成预测区间并执行异常检测。

我仍然相信每个预测问题都需要独特的方法,因此请务必测试 TimeGPT 以及其他模型。

谢谢阅读!我希望您喜欢它并学到新东西!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/110244.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

自动化项目实战 [个人博客系统]

自动化博客项目 用户注册登录验证效验个人博客列表页博客数量不为 0 博客系统主页写博客 我的博客列表页效验 刚发布的博客的标题和时间查看 文章详情页删除文章效验第一篇博客 不是 "自动化测试" 注销退出到登录页面,用户名密码为空 用户注册 Order(1)Parameterized…

Spring Cloud之API网关(Zuul)

目录 Zuul 简介 功能 工作流程 搭建 1.引入依赖 2.添加注解 3.路由转发 4.测试 实现原理 EnableZuulProxy注解 ZuulServlet FilterProcessor Zuul内置过滤器 常用配置 Zuul 简介 zuul是SpringCloud子项目的核心组件之一,可以作为微服务架构中的API网…

【C】C语言文件操作

1.为什么使用文件 我们前面学习结构体时,写通讯录的程序,当通讯录运行起来的时候,可以给通讯录中增加、删除数据,此时数据是存放在内存中,当程序退出的时候,通讯录中的数据自然就不存在了,等下…

超全整理,Jmeter性能测试-脚本error报错排查/分布式压测(详全)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 性能脚本error报错…

域名系统 DNS

DNS 概述 域名系统 DNS(Domain Name System)是因特网使用的命名系统,用来把便于人们使用的机器名字转换成为 IP 地址。域名系统其实就是名字系统。为什么不叫“名字”而叫“域名”呢?这是因为在这种因特网的命名系统中使用了许多的“域(domain)”&#x…

S5PV210裸机(九):ADC

本文主要探讨210的ADC相关知识。 ADC ADC:模数转换(模拟信号转数字信号) 量程:模拟电压信号范围(210为0~3.3V) 精度:若10二进制位来表示精度(210为10位或12位),量…

线性代数 第三章 向量

一、运算 加法、数乘、内积 施密特正交化 二、线性表出 概念:如果,则称可由线性表出(k不要求不全为0) 判定: 非齐次线性方程组有解无关,相关 如果两个向量组可以互相线性表出,则称这两个…

Xilinx 7 系列 1.8V LVDS 和 2.5V LVDS 信号之间的 LVDS 兼容性

如果通过LVDS进行接口,可以按照以程图中的步骤操作,以确保满足正确使用LVDS的所有要求。 40191 - 7 系列 - 1.8V LVDS 和 2.5V LVDS 信号之间的 LVDS 兼容性 与LVDS兼容驱动器和接收器连接时,7系列LVDS和LVDS_25输入和输出应该不存在兼容性问…

案例分析真题-系统建模

案例分析真题-系统建模 2009年真题 【问题1】 【问题2】 【问题3】 2012年真题 【问题1】 【问题2】 【问题3】 2014年真题 【问题1】 【问题2】 骚戴理解:这个题目以前经常考,不知道今年会不会考,判断的话就是看加工有没有缺少输入和输出&a…

C++面试——多线程详解

C11提供了语言层面上的多线程&#xff0c;包含在头文件<thread>中。它解决了跨平台的问题&#xff0c;提供了管理线程、保护共享数据、线程间同步操作、原子操作等类。C11 新标准中引入了5个头文件来支持多线程编程&#xff0c;如下图所示&#xff1a; 多进程与多线程 多…

SSH 无密登录设置

1 &#xff09; 配置 ssh &#xff08;1&#xff09;基本语法 ssh 另一台电脑的 IP 地址&#xff08;2&#xff09;ssh 连接时出现 Host key verification failed 的解决方法 [libaihadoop102 ~]$ ssh hadoop103 ➢ 如果出现如下内容 Are you sure you want to continue c…

MAC缓解WebUI提示词反推

当前环境信息&#xff1a; 在mac上安装好stable diffusion后&#xff0c;能做图片生成了之后&#xff0c;遇到一些图片需要做提示词反推&#xff0c;这个时候需要下载一个插件&#xff0c;参考&#xff1a; https://gitcode.net/ranting8323/stable-diffusion-webui-wd14-tagg…

mac 安装homebrew ,golang

mac 安装homebrew ,golang 安装homebrew安装golang选择 apple arm 版本安装配置环境变量 安装homebrew /bin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh)"回车执行指令后&#xff0c;根据提示操作。具体包括以下提示操作&am…

ios 代码上下文截屏之后导致的图片异常问题

业务场景&#xff0c;之前是直接将当前的collectionview截长屏操作&#xff0c;第一次截图会出现黑色部分原因是视图未完全布局&#xff0c;原因是第一次使用了Masonry约束然后再截图的时候进行了frame赋值&#xff0c;可以查看下Masonry约束和frame的冲突&#xff0c;全部修改…

除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂…

Qt 重写QSlider简单实现滑动解锁控件(指定百分比回弹效果)

组件效果图&#xff1a; 应用场景&#xff1a; 用于滑动解锁相关场景&#xff0c;Qt的控件鼠标监听机制对于嵌入式设备GUI可触摸屏依旧可用。 实现方式&#xff1a; 主要是通过继承QSlider以及搭配使用QStyleOptionSlider来实现效果。 注意细则&#xff1a; QStyleOption…

【IDEA】每个方法之间如何设置分隔线

修改后效果&#xff1a; 各个方法之间出现了分隔线

洞察运营机会的数据分析利器

这套分析方法包括5个分析工具&#xff1a; 用“描述性统计”来快速了解数据的整体特点。用“变化分析”来寻找数据的问题和突破口。用“指标体系”来深度洞察变化背后的原因。用“相关性分析”来精确判断原因的影响程度。用“趋势预测”来科学预测未来数据的走势&#xff0c;

浅谈js代码的封装方法(2023.10.30)

常见的js代码封装方法 2023.10.30 需求1、js代码封装的优缺点2、js代码封装方式2.1 方式一&#xff1a;function function declarations2.1.1 示例 2.2 方式二&#xff1a;class2.2.1 class declarations2.2.2 Class expressions 2.3 变量函数2.4 变量闭包匿名函数2.5 闭包函数…