神经网络11-TFT模型的简单示例

Temporal Fusion Transformer (TFT) 是一种用于时间序列预测的深度学习模型,它结合了Transformer架构的优点和专门为时间序列设计的一些优化技术。TFT尤其擅长处理多变量时间序列数据,并且能够捕捉到长期依赖关系,同时通过自注意力机制有效地处理时序特征。TFT的工作原理主要由以下几个部分组成:

1. 输入数据处理

  • 输入特征:TFT的输入是一个多变量时间序列,每个样本包含多个特征(如10个特征,每个特征有240个时间步)。每个时间步的特征值可以是连续的(如温度、股价等),也可以是分类的(如星期几、节假日等)。
  • 静态和时间序列特征:TFT区分了静态特征(例如个体ID、地点)和动态特征(例如时间步上的温度)。静态特征在模型中用于增强个体的预测性能,而动态特征则帮助模型捕捉到随时间变化的模式。

2. Encoder-Decoder架构

TFT采用了编码器-解码器(Encoder-Decoder)的架构,这个架构原本用于序列到序列的任务(如机器翻译),但是在TFT中做了调整:

  • 编码器(Encoder):输入的时间序列通过编码器进行处理,编码器包括一个由自注意力机制和GRU(门控循环单元)组成的结构。自注意力机制能够帮助模型捕捉不同时间步之间的依赖关系,而GRU有助于捕捉短期的时间依赖性。
  • 解码器(Decoder):解码器根据编码器输出的特征以及未来的已知信息来生成预测。解码器可以直接预测下一个时间步的值。

3. 自注意力机制(Self-Attention)

自注意力机制在TFT中用于捕捉时间序列中各个时间步之间的长短期依赖关系。它通过计算每个时间步和其他时间步之间的相关性,自动地给出不同时间步的权重。这样,模型可以根据时间序列中的重要性自适应地调整权重。

4. 门控机制(Gating Mechanisms)

TFT采用了多个门控机制(例如:GRN(Gated Residual Network)和变量选择网络)来控制信息流,避免不必要的计算,并且使得模型更加灵活:

  • 变量选择网络:自动选择哪些输入特征对于当前时间步的预测更为重要,从而提升了模型的性能和可解释性。
  • 门控残差单元(GRN):通过加权处理动态特征和静态特征来提供更丰富的信息。

5. 多尺度时间步(Multi-Scale Temporal Fusion)

TFT不仅利用了全局时间步的特征,还通过多尺度处理能够同时捕捉长期和短期的模式。例如,通过多个不同的时间尺度对历史信息进行融合,从而提升了模型的预测精度。

6. 预测头(Forecasting Head)

在解码器的顶部,TFT有一个预测头,它根据模型输出的时间序列特征来进行实际的预测。它生成一个未来时间步的预测值,通常用于回归任务或二分类任务。

7. 自解释性(Interpretability)

TFT具有一定的可解释性,特别是通过注意力机制变量选择网络,可以分析哪些特征对于模型的预测最重要。这对于需要模型透明度和决策依据的场景非常有用。

8.简单例子:预测股票价格

假设我们有一个简单的时间序列数据集,包含时间步(例如每天的股票价格),并且我们希望预测未来几天的股票价格。假设我们有以下结构的时间序列数据:

  • 每天的股票价格(连续特征)。
  • 每天的交易量(连续特征)。
  • 每天的节假日信息(分类特征,例如是否为节假日)。

我们的目标是基于过去的几个时间步的数据预测未来的股票价格。

模型步骤

  1. 数据预处理:我们将数据准备为适合TFT的格式。TFT需要有时间步特征静态特征目标变量
  2. 模型构建:使用TFT模型进行训练和预测。
  3. 预测:用训练好的模型预测未来的股票价格。

示例代码

假设我们使用pytorch-forecasting这个库来实现TFT模型。这个库为时间序列任务提供了简化的API。

安装必要的库

首先,安装相关库:

pip install pytorch-forecasting pytorch-lightning

构建简单的TFT模型

import pandas as pd
import numpy as np
import torch
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting import Baseline

# 生成模拟数据
np.random.seed(42)

# 创建一个简单的时间序列数据集
n_samples = 1000
time_idx = np.tile(np.arange(1, 101), n_samples // 100)  # 100天的周期重复n_samples次
stock_price = np.sin(time_idx * 0.1) + np.random.normal(0, 0.1, len(time_idx))  # 模拟的股票价格
volume = np.random.normal(1000, 100, len(time_idx))  # 模拟的交易量
is_holiday = (time_idx % 7 == 0).astype(int)  # 假设每7天是一个节假日

# 创建DataFrame
data = pd.DataFrame({
    "time_idx": time_idx,
    "stock_price": stock_price,
    "volume": volume,
    "is_holiday": is_holiday
})

# 创建训练和验证集
max_encoder_length = 60  # 用60个时间步预测未来
max_prediction_length = 10  # 预测未来10个时间步

# 将数据转换为TimeSeriesDataSet
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= 80],  # 训练集使用前80天的数据
    time_idx="time_idx",
    target="stock_price",
    group_ids=["is_holiday"],  # 对应的静态特征是节假日
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["is_holiday"],
    time_varying_known_reals=["stock_price", "volume"],  # 动态已知特征
    time_varying_unknown_reals=["stock_price"],  # 动态目标变量
    target_normalizer=GroupNormalizer(groups=["is_holiday"], transformation="softplus"),  # 归一化
)

# 创建数据加载器
train_dataloader = torch.utils.data.DataLoader(training, batch_size=64, shuffle=True)

# 初始化TFT模型
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,  # 隐藏层单元数
    attention_head_size=4,  # 自注意力头数
    dropout=0.1,  # Dropout概率
    hidden_continuous_size=8,  # 连续特征的隐藏层大小
    output_size=1,  # 输出预测的维度
    loss= torch.nn.MSELoss(),  # 使用均方误差损失
)

# 训练模型
import pytorch_lightning as pl
trainer = pl.Trainer(max_epochs=10, gpus=0)  # 设置epochs
trainer.fit(tft, train_dataloader)

# 进行预测
test_data = data[lambda x: x.time_idx > 80]  # 测试集使用剩余的数据
test_dataset = TimeSeriesDataSet.from_dataset(training, test_data)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

predictions = tft.predict(test_dataloader, mode="raw")

代码解析

1.数据生成

我们使用sin函数加上随机噪声来模拟股票价格,并且生成随机的交易量和节假日标记。最终生成的股票数据如下所示:

 2.TimeSeriesDataSet

这是pytorch-forecasting库中的一个数据类,它帮助我们将数据集转化为模型能够理解的格式。我们指定了哪些是时间步特征、静态特征和目标变量,并且定义了时间步长度(60步历史数据用于预测未来10步)。

详细解析:

  1. data[lambda x: x.time_idx <= 80]

    • data 是一个包含时间序列数据的 DataFrame。
    • lambda x: x.time_idx <= 80 是一个条件筛选器,用于选择时间索引 (time_idx) 小于或等于 80 的数据。这样,只会使用前 80 天的数据来训练模型。
  2. time_idx="time_idx"

    • time_idx 指定了表示时间步的列名。在这个例子中,time_idx 表示时间索引(通常是从 1 到 n 的整数,表示每个时间步)。
  3. target="stock_price"

    • target 参数指定了模型要预测的目标变量。在这个例子中,目标变量是 stock_price(股票价格)。
  4. group_ids=["is_holiday"]

    • group_ids 用于指定一个或多个分组特征,这些特征将用于区分不同的时间序列或分组数据。在这里,is_holiday 表示节假日(通常是一个静态变量,指示每个时间步是否为假期)。它会告诉模型如何对待不同的节假日数据。
  5. min_encoder_length=max_encoder_lengthmax_encoder_length=max_encoder_length

    • min_encoder_lengthmax_encoder_length 指定了输入序列(编码器输入)的最小和最大长度。它们用于告诉模型每个输入序列的时间步数。
    • 这两个参数的值相等,表示使用固定长度的历史数据,假设为 max_encoder_length。例如,max_encoder_length=10 表示使用过去的 10 天数据进行预测。
    • max_encoder_lengthmin_encoder_length 应该是整数,表示时间序列的长度。
  6. min_prediction_length=max_prediction_lengthmax_prediction_length=max_prediction_length

    • min_prediction_lengthmax_prediction_length 是预测的时间步长,指定模型需要预测多少步。
    • 这两个参数的值也相等,表示预测的步长是固定的。例如,max_prediction_length=5 表示模型需要预测接下来 5 天的股票价格。
  7. static_categoricals=["is_holiday"]

    • static_categoricals 表示静态类别特征(即在整个时间序列中不变化的特征)。这里 is_holiday 是一个静态类别特征,指示每个时间步是否是节假日。
  8. time_varying_known_reals=["stock_price", "volume"]

    • time_varying_known_reals 指定了动态已知特征,这些特征在时间序列中会随时间变化,并且在训练时已知。在这个例子中,stock_price(股票价格)和 volume(交易量)是动态已知特征。
  9. time_varying_unknown_reals=["stock_price"]

    • time_varying_unknown_reals 列表指定了模型需要预测的动态目标变量(时间步变化的未知特征)。在本例中,stock_price 是需要预测的目标变量,它会随时间变化。
  10. target_normalizer=GroupNormalizer(groups=["is_holiday"], transformation="softplus")

  • target_normalizer 用于对目标变量进行归一化。在这个例子中,使用了 GroupNormalizer,它会基于 is_holiday 这一组特征对目标变量进行归一化。
  • transformation="softplus" 表示使用 Softplus 函数(log(1 + exp(x)))来平滑目标变量的分布,这有助于减小异常值的影响。
3.TFT模型

我们通过TemporalFusionTransformer.from_dataset构建了TFT模型,设置了学习率、隐藏层大小等超参数。使用均方误差(MSE)作为损失函数来训练模型。 

详细解析:

1. TemporalFusionTransformer.from_dataset(training, ...):

  • TemporalFusionTransformer.from_dataset() 是一个类方法,它从给定的数据集(training)中自动配置模型的各个部分,如输入特征、目标变量、编码器长度等。
  • trainingTimeSeriesDataSet,包含了时间序列数据和相关的特征。这个方法会根据该数据集自动处理输入特征和目标,设置模型结构。

2. learning_rate=0.001:

  • learning_rate 是模型优化器的学习率,决定了模型参数在训练时的更新步幅。较小的学习率(例如 0.001)通常能帮助优化过程更稳定,但训练速度可能会变慢。

3. hidden_size=16:

  • hidden_size 是模型中每个隐藏层的单元数。这里设置为 16,表示每个隐藏层的神经元数量。较大的隐藏层大小有助于模型捕捉更多的复杂模式,但也可能增加计算复杂性和过拟合的风险。

4. attention_head_size=4:

  • attention_head_size 指定了自注意力机制中多头注意力(Multi-Head Attention)机制的头数。这里设置为 4,意味着模型将通过 4 个不同的“头”来计算注意力权重,从而捕捉不同方面的信息。这是 Transformer 模型的关键特性,可以帮助模型同时关注输入的多个部分。

5. dropout=0.1:

  • dropout 是防止过拟合的一种技术,表示在训练过程中会随机丢弃 10% 的神经元,以减少模型对特定神经元的依赖。这里设置为 0.1,表示在训练时有 10% 的概率会忽略某些神经元的输出。

6. hidden_continuous_size=8:

  • hidden_continuous_size 是连续特征的隐藏层大小。时间序列中的一些特征(例如股价、交易量等)可能是连续变量,这个参数指定了这些连续特征的隐藏表示大小。设置为 8,表示模型将使用 8 个单位来表示这些连续特征。

7. output_size=1:

  • output_size 指定了模型的输出维度。在时间序列预测任务中,通常输出一个预测值(如股票价格),所以这里设置为 1,表示模型输出一个数值。

8. loss= torch.nn.MSELoss():

  • loss 是指定优化目标的损失函数。在回归任务(如股票价格预测)中,常使用均方误差(MSE)作为损失函数。torch.nn.MSELoss() 会计算模型预测值与实际值之间的均方误差。
4.训练

使用pytorch_lightning中的Trainer进行模型训练,设置训练的epoch数为10。 

详细解析

  1. import pytorch_lightning as pl

    • 导入 PyTorch Lightning 库,PyTorch Lightning 是一个用于简化 PyTorch 模型训练过程的高级框架。它封装了很多 PyTorch 中繁琐的训练步骤,使得训练过程更清晰、更易于管理。
  2. trainer = pl.Trainer(max_epochs=10, gpus=0)

    • 创建一个 Trainer 对象,这是 PyTorch Lightning 中用于训练模型的主要接口。这里使用了以下参数:
      • max_epochs=10:指定训练的最大轮数(epochs)。训练将在 10 个 epochs 后停止。
      • gpus=0:指定使用的 GPU 数量,0 表示不使用 GPU(即在 CPU 上训练)。如果你的机器上有可用的 GPU,可以将其设置为 gpus=1 或更多。
  3. trainer.fit(tft, train_dataloader)

    • fit() 方法用于训练模型。在此处,tft 是创建的 TemporalFusionTransformer 模型,train_dataloader 是训练数据的 DataLoader,包含了批次化的训练数据。
    • fit() 方法会根据训练数据自动进行模型的前向传播、反向传播和参数更新。
5. 预测

通过predict函数,模型基于测试集进行预测。

详细解析

  1. test_data = data[lambda x: x.time_idx > 80]

    • test_data 选择 datatime_idx 大于 80 的数据,作为测试集。这表示测试集包含从第 81 天开始的数据。
  2. test_dataset = TimeSeriesDataSet.from_dataset(training, test_data)

    • TimeSeriesDataSet.from_dataset() 方法根据训练集 training 和测试集 test_data 创建一个新的测试数据集。这个方法会从训练集的配置中继承必要的参数(如时间索引、特征列等),然后将其应用于测试数据集。
  3. test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

    • 使用 DataLoader 来批次化测试集。与训练集不同,测试集通常不需要进行打乱,因此 shuffle=False
    • batch_size=64 表示每个批次加载 64 个样本。
  4. predictions = tft.predict(test_dataloader, mode="raw")

    • predict() 方法用于对测试集进行预测。它将返回模型对测试集的预测结果。
    • mode="raw" 指定了返回原始的预测结果(而不是经过后处理或归一化的结果)。你可以选择不同的模式来控制返回的预测格式,例如可以选择返回预测的概率、标签等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/918609.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学习threejs,使用TWEEN插件实现动画

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.PLYLoader PLY模型加…

世界坐标系、相机坐标系、图像物理坐标系、像素平面坐标系

坐标系及其转换在计算机视觉领域占据核心地位。理解如何从一个坐标系转换到另一个坐标系&#xff0c;不仅是理论上的需要&#xff0c;也是实际应用中不可或缺的技能。 一、世界坐标系的定义 世界坐标系是一个全局的坐标系统&#xff0c;用于定义场景中物体的位置。在这个坐标…

03-axios常用的请求方法、axios错误处理

欢迎来到“雪碧聊技术”CSDN博客&#xff01; 在这里&#xff0c;您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者&#xff0c;还是具有一定经验的开发者&#xff0c;相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导&#xff0c;我将…

Redis/Codis性能瓶颈揭秘:网卡软中断的影响与优化

目录 现象回顾 问题剖析 现场分析 解决方案 总结与反思 1.调整中断亲和性&#xff08;IRQ Affinity&#xff09;&#xff1a; 2.RPS&#xff08;Receive Packet Steering&#xff09;和 RFS&#xff08;Receive Flow Steering&#xff09;&#xff1a; 近期&#xff0c;…

openwebui使用

文章目录 1、feature2、安装使用2.1 安装过程2.2 安装好后 1、feature 可以加载多个大模型 同时回复 模型问答: 使用vLLM框架部署模型&#xff0c;再使用Open WebUI直接进行模型问答 多模型支持: 多模型回复比对&#xff08;Qwen2-72B-Instruct, llama3-70b-8192, mixtral-8x7…

汽车资讯新引擎:Spring Boot技术领航

3系统分析 3.1可行性分析 通过对本汽车资讯网站实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本汽车资讯网站采用SSM框架&#xff0c;JAVA作为开发语言&#…

应用系统开发(12) Zync中实现数字相敏检波

在 Xilinx Zynq 系列(如 Zynq-7000 或 Zynq UltraScale+)中实现数字相敏检波(DSP,Digital Synchronous Detection)可以通过硬件(PL部分,FPGA逻辑)和软件(PS部分,ARM Cortex-A 处理器)的协同工作来实现。以下是一个详细的设计方法,包括基本原理和 Zynq 的实现步骤。…

《译文》2024年11月数维杯国际大学生数学建模挑战赛题目

# 赛题正式发布 2024年第十届数维杯国际大学生数学建模挑战赛顺利开赛&#xff0c;竞赛开始时间为北京时间2024年11月15日09:00至北京时间2024年11月19日09:00&#xff0c;共计4天&#xff0c;竞赛题目正式发布&#xff0c;快来一起围观&#xff0c;你认为今年的哪个题目更具有…

apk反编译修改教程系列-----apk应用反编译中AndroidManifest.xml详细代码释义解析 包含各种权限 代码含义【二】

💝💝💝💝在上期博文中解析了一个常规apk中 AndroidManifest.xml的权限以及代码。应粉丝需求。这次解析一个权限较高的apk。这款apk是一个家长管控的应用。需求的各种权限较高。而且通过管控端可以设置控制端的app隐藏与否。 通过博文了解💝💝💝💝 1💝💝…

【UGUI】背包的交互01(道具信息跟随鼠标+道具信息面板显示)

详细程序逻辑过程 初始化物品栏&#xff1a; 在 Awake 方法中&#xff0c;通过标签找到提示框和信息面板。 循环生成10个背包格子&#xff0c;并为每个格子设置图标和名称。 为每个格子添加 UInterMaager232 脚本&#xff0c;以便处理交互事件。 关闭提示框和信息面板&#…

Docker: ubuntu系统下Docker的安装

安装依赖 操作系统版本 Ubuntu Kinetic 22.10Ubuntu Jammy 24.04 (LTS)Ubuntu Jammy 22.04 (LTS)Ubuntu Focal 20.04 (LTS)Ubuntu Bionic 18.04 (LTS) CPU架构支持 ARMx86_64 查看我们的系统版本信息 uname -a通过该命令查得cpu架构是x86_64的&#xff1b; cat /etc/*re…

Nacos 配置中心变更利器:自定义标签灰度

作者&#xff1a;柳遵飞 配置中心被广泛使用 配置中心是 Nacos 的核心功能之一&#xff0c;接入配置中心&#xff0c;可以实现不重启线上应用的情况下动态改变程序的运行期行为&#xff0c;在整个软件生命周期中&#xff0c;可以极大降低了软件构建及部署的成本&#xff0c;提…

基于RK3568J多网口电力可信物联网关解决方案

前言 随着工业物联网的普及和功能越来越强大&#xff0c;边缘计算网关应运而生。 边缘计算有效降低了云端服务器的负载、大大降低了带宽的占用&#xff0c;同时也为本地化的区域自治提供了便利条件。 边缘计算网关&#xff0c;完美地发挥了“边”与“端” 结合优势&#xff0c…

无人机飞手入门指南

无人机飞手入门指南旨在为初学者提供一份全面的学习路径和实践建议&#xff0c;帮助新手快速掌握无人机飞行技能并了解相关法规知识。以下是一份详细的入门指南&#xff1a; 一、了解无人机基础知识 1. 无人机构造&#xff1a;了解无人机的组成部分&#xff0c;如机身、螺旋桨…

网络传输:网卡、IP、网关、子网掩码、MAC、ARP、路由器、NAT、交换机

目录 网卡IP网络地址主机地址子网子网掩码网关默认网关 MACARPARP抓包分析 路由器NATNAPT 交换机 网卡 网卡(Network Interface Card&#xff0c;简称NIC)&#xff0c;也称网络适配器。 OSI模型&#xff1a; 1、网卡工作在OSI模型的最后两层&#xff0c;物理层和数据链路层。物…

多账号登录管理器(淘宝、京东、拼多多等)

目录 下载安装与运行 解决什么问题 功能说明 目前支持的平台 功能演示 登录后能保持多久 下载安装与运行 下载、安装与运行 语雀 解决什么问题 多个账号的快捷登录与切换 功能说明 支持多个电商平台支持多个账号的登录保持支持快捷切换支持导入导出支持批量删除支持…

你可以通过以下步骤找到并打开 **Visual Studio 开发者命令提示符**:

你可以通过以下步骤找到并打开 Visual Studio 开发者命令提示符&#xff1a; 1. 通过开始菜单查找 打开 开始菜单&#xff08;点击屏幕左下角的 Windows 图标&#xff09;。在搜索框中输入 Developer Command Prompt。你应该看到以下几种选项&#xff08;具体取决于你的 Visu…

版本控制【Git Bash】【Gitee】

目录 一、什么是版本控制&#xff1f; 二、版本控制的种类&#xff1a; 1、本地版本控制 2、集中版本控制 3、分布式版本控制 三、下载Git Bash 四、Git Bash 配置 五、Git Bash使用 1、切换目录&#xff1a;cd 2.查看当前文件路径&#xff1a;pwd 3.列出当前目录下文件…

大数据-226 离线数仓 - Flume 优化配置 自定义拦截器 拦截原理 拦截器实现 Java

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; Java篇开始了&#xff01; 目前开始更新 MyBatis&#xff0c;一起深入浅出&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff0…

【3D Slicer】的小白入门使用指南九

定量医学影像临床研究与实践 任务 定量成像教程 定量成像是从医学影像中提取定量测量的过程。 本教程基于两个定量成像的例子构建: - 形态学:缓慢生长肿瘤中的小体积变化 - 功能:鳞状细胞癌中的代谢活动 第1部分:使用变化跟踪模块测量脑膜瘤的小体积变化第2部分:使用PET标…