基于BiTCN双向时间卷积网络实现电力负荷多元时序预测(PyTorch版)

Bidirectional Temporal Convolutional Network \begin{aligned} &\text{\Large \color{#CDA59E}Bidirectional Temporal Convolutional Network}\\ \end{aligned} Bidirectional Temporal Convolutional Network

Bidirectional Temporal Convolutional Network (BiTCN) is a forecasting architecture based on two temporal convolutional networks (TCNs). The first network (‘forward’) encodes future covariates of the time series, whereas the second network (‘backward’) encodes past observations and covariates. This method allows to preserve the temporal information of sequence data, and is computationally more efficient than common RNN methods (LSTM, GRU, …). As compared to Transformer-based methods, BiTCN has a lower space complexity, i.e. it requires orders of magnitude less parameters.

References
-Olivier Sprangers, Sebastian Schelter, Maarten de Rijke (2023). Parameter-Efficient Deep Probabilistic Forecasting. International Journal of Forecasting 39, no. 1 (1 January 2023): 332–45. URL: https://doi.org/10.1016/j.ijforecast.2021.11.011.
-Shaojie Bai, Zico Kolter, Vladlen Koltun. (2018). An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling. Computing Research Repository, abs/1803.01271. URL: https://arxiv.org/abs/1803.01271.
-van den Oord, A., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A. W., & Kavukcuoglu, K. (2016). Wavenet: A generative model for raw audio. Computing Research Repository, abs/1609.03499. URL: http://arxiv.org/abs/1609.03499. arXiv:1609.03499.

在这里插入图片描述

前言

系列专栏:【深度学习:算法项目实战】✨︎
涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、自然语言处理、深度强化学习、大型语言模型和迁移学习。

BiTCN,即双向时间卷积网络(Bidirectional Temporal Convolutional Network),作为深度学习领域极具创新性的神经网络架构,其核心设计亮点在于 “双向卷积” 机制。与传统单向卷积网络仅从单一时间流向挖掘信息不同,BiTCN 能够同时从时间序列的正向与反向进行卷积操作。这意味着在处理电力负荷等时序数据时,它不仅能捕捉到随时间递增方向上数据的变化趋势,诸如负荷随时间逐步上升的白天用电高峰特征;还能敏锐感知反向时间流中蕴含的关键信息,像是捕捉夜间用电量逐渐降低过程中隐藏的规律。如此双向并行的信息采集模式,极大地扩充了可获取信息的边界,有效避免因单向视角局限而遗漏重要特征。

在模型内部结构方面,BiTCN 精心构建了多层卷积层与池化层交替排列的布局。通过卷积层,利用不同尺寸的卷积核精细扫描时间序列,精准提取从局部到全局的各类特征。小尺寸卷积核聚焦于数据细微波动,挖掘短周期内的用电模式变化;大尺寸卷积核则负责勾勒宏观趋势,捕捉如季节更迭引发的长期用电负荷起伏。紧随其后的池化层发挥着下采样功能,在降低数据维度的同时保留核心特征,既减少计算量、提升运算效率,又确保关键信息不流失,为后续深层次的网络处理夯实基础。

文章目录

  • 1. 数据集介绍
  • 2. 数据预处理
  • 3. 数据可视化
  • 4. 构建模型
  • 5. 交叉验证
  • 6. 模型预测
  • 7. 回归拟合图
  • 8. 模型评估

1. 数据集介绍

本文用到的数据集是ETTh1.csv,ETTh1数据集是电力变压器数据集(ETDataset)的一部分,旨在用于长序列时间序列预测问题的研究。该数据集收集了中国两个不同县两年的数据,以预测特定地区的电力需求情况。


import pandas as pd
import matplotlib.pyplot as plt

from neuralforecast.core import NeuralForecast
from neuralforecast.models import BiTCN
from neuralforecast.losses.pytorch import MAE
from neuralforecast.losses.numpy import mae, mse, mape, rmse

from datasetsforecast.long_horizon import LongHorizon
# Change this to your own data to try the model
Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTh1')

2. 数据预处理

Y_df['ds'] = pd.to_datetime(Y_df['ds'])

3. 数据可视化

plt.style.use('ggplot')
plt.plot(Y_df['y'], color='darkorange' ,label='Trend')
plt.show()

在这里插入图片描述

n_time = len(Y_df.ds.unique())
val_size = int(.2 * n_time)
test_size = int(.2 * n_time)

Y_df.groupby('unique_id').head(5)

4. 构建模型

nf = NeuralForecast(
    models = [
        BiTCN(
            h = 1, # Forecasting horizon
            input_size = 24, # Input size
            hidden_size = 64, # Units for the TCN's hidden state size
            dropout = 0.5,
            loss=MAE(),
            valid_loss=MAE(),
            max_steps = 1000, # Number of training iterations
            learning_rate = 1e-3,
            num_lr_decays = -1,
            early_stop_patience_steps = -1,
            val_check_steps = 100, # Compute validation loss every 100 steps
            batch_size = 128,
            random_seed=1234,
        ),
    ],
    freq='H'
)

5. 交叉验证

交叉验证方法 cross_validation 将返回模型在测试集上的预测结果。

Y_hat_df = nf.cross_validation(df=Y_df,
                               val_size=val_size,
                               test_size=test_size,
                               n_windows=None)
   | Name          | Type          | Params | Mode 
---------------------------------------------------------
0  | loss          | MAE           | 0      | train
1  | valid_loss    | MAE           | 0      | train
2  | padder_train  | ConstantPad1d | 0      | train
3  | scaler        | TemporalNorm  | 0      | train
4  | lin_hist      | Linear        | 128    | train
5  | drop_hist     | Dropout       | 0      | train
6  | net_bwd       | Sequential    | 82.9 K | train
7  | drop_temporal | Dropout       | 0      | train
8  | temporal_lin1 | Linear        | 1.6 K  | train
9  | temporal_lin2 | Linear        | 65     | train
10 | output_lin    | Linear        | 65     | train
---------------------------------------------------------
84.7 K    Trainable params
0         Non-trainable params
84.7 K    Total params
0.339     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode
Y_hat_df.head()

6. 模型预测

Y_plot = Y_hat_df.copy() # OT dataset
cutoffs = Y_hat_df['cutoff'].unique()[::1]
Y_plot = Y_plot[Y_hat_df['cutoff'].isin(cutoffs)]

plt.figure(figsize=(20,5))
plt.plot(Y_plot['ds'], Y_plot['y'], label='True')
plt.plot(Y_plot['ds'], Y_plot['BiTCN'], label='BiTCN')
plt.xlabel('Datestamp')
plt.ylabel('OT')
plt.grid()
plt.legend()
plt.savefig('BiTCN.png')

时序预测

7. 回归拟合图

使用 regplot() 函数绘制数据图,拟合预测值与真实值的线性回归图。

plt.figure(figsize=(5, 5), dpi=100)
sns.regplot(x=Y_plot['y'], y=Y_plot['BiTCN'], scatter=True, marker="*", color='orange',line_kws={'color': 'red'})
plt.show()

回归拟合图

8. 模型评估

以下代码使用了一些常见的评估指标:平均绝对误差(MAE)、平均绝对百分比误差(MAPE)、均方误差(MSE)、均方根误差(RMSE)来衡量模型预测的性能。这里我们将调用 neuralforecast.losses.numpy 模块中的 mae, mse, mape, rmse 函数来对模型的预测效果进行评估。

mae = mae(Y_hat_df['y'], Y_hat_df['BiTCN'])
print(f"MAE: {mae:.4f}")

mape = mape(Y_hat_df['y'], Y_hat_df['BiTCN'])
print(f"MAPE: {mape * 100:.4f}%")

mse = mse(Y_hat_df['y'], Y_hat_df['BiTCN'])
print(f"MSE: {mse:.4f}")

rmse = rmse(Y_hat_df['y'], Y_hat_df['BiTCN'])
print(f"RMSE: {rmse:.4f}")
MAE: 0.1239
MAPE: 8.9629%
MSE: 0.0209
RMSE: 0.1444

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/946176.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux C/C++编程-网络程序架构与套接字类型

【图书推荐】《Linux C与C一线开发实践(第2版)》_linux c与c一线开发实践pdf-CSDN博客《Linux C与C一线开发实践(第2版)(Linux技术丛书)》(朱文伟,李建英)【摘要 书评 试读】- 京东图书 (jd.com…

北京某新能源汽车生产及办公网络综合监控项目

北京某新能源汽车是某世界500强汽车集团旗下的新能源公司,也是国内首个获得新能源汽车生产资质、首家进行混合所有制改造、首批践行国有控股企业员工持股的新能源汽车企业,其主营业务包括纯电动乘用车研发设计、生产制造与销售服务。 项目现状 在企业全…

【LeetCode】2506、统计相似字符串对的数目

【LeetCode】2506、统计相似字符串对的数目 文章目录 一、哈希表位运算1.1 哈希表位运算 二、多语言解法 一、哈希表位运算 1.1 哈希表位运算 每个字符串, 可用一个 int 表示. (每个字符 是 int 的一个位) 哈希表记录各 字符组合 出现的次数 步骤: 遇到一个字符串, 得到 ma…

gitlab 还原合并请求

事情是这样的: 菜鸡从 test 分支切了个名为 pref-art 的分支出来,发布后一机灵,发现错了,于是在本地用 git branch -d pref-art 将该分支删掉了。之后切到了 prod 分支,再切出了一个相同名称的 pref-art 分支出来&…

Uncaught ReferenceError: __VUE_HMR_RUNTIME__ is not defined

Syntax Error: Error: vitejs/plugin-vue requires vue (>3.2.13) or vue/compiler-sfc to be present in the dependency tree. 第一步 npm install vue/compiler-sfc npm run dev 运行成功,本地打开页面是空白,控制台报错 重新下载了vue-loa…

LeetCode--排序算法(堆排序、归并排序、快速排序)

排序算法 归并排序算法思路代码时间复杂度 堆排序什么是堆?如何维护堆?如何建堆?堆排序时间复杂度 快速排序算法思想代码时间复杂度 归并排序 算法思路 归并排序算法有两个基本的操作,一个是分,也就是把原数组划分成…

vim里搜索关键字

vim是linux文本编辑器的命令,再vi的基础上做了功能增强 使用方法如下 1. / 关键字, 回车即可, 按n键查找关键字下一个位置 2.? 关键字, 回车即可, 按n键查找关键字下一个位置 3.示例

自学记录鸿蒙API 13:Calendar Kit日历功能从学习到实践

这次的目标是学习和使用HarmonyOS的Calendar Kit功能,特别是最新的API 13版本。Calendar Kit让我感受到了一种与传统开发完全不同的体验——它提供的不只是简单的日历功能,而是一套集创建、查询、更新、删除等强大能力于一体的日程管理服务。 一开始&…

汽车损坏识别检测数据集,使用yolo,pasical voc xml,coco json格式标注,6696张图片,可识别11种损坏类型,识别率89.7%

汽车损坏识别检测数据集,使用yolo,pasical voc xml,coco json格式标注,6696张图片,可识别11种损坏类型损坏: 前挡风玻璃(damage-front-windscreen ) 损坏的门 (damaged-d…

2025年入职/转行网络安全,该如何规划?网络安全职业规划

网络安全是一个日益增长的行业,对于打算进入或转行进入该领域的人来说,制定一个清晰且系统的职业规划非常重要。2025年,网络安全领域将继续发展并面临新的挑战,包括不断变化的技术、法规要求以及日益复杂的威胁环境。以下是一个关…

如何使用 ChatGPT Prompts 写学术论文?

第 1 部分:学术写作之旅:使用 ChatGPT Prompts 进行学术写作的结构化指南 踏上学术写作过程的结构化旅程,每个 ChatGPT 提示都旨在解决特定方面,确保对您的主题进行全面探索。 制定研究问题: “制定一个关于量子计算的社会影响的研究问题,确保清晰并与您的研究目标保持一…

超大规模分类(一):噪声对比估计(Noise Contrastive Estimation, NCE)

NCE损失对应的论文为《A fast and simple algorithm for training neural probabilistic language models》,发表于2012年的ICML会议。 背景 在2012年,语言模型一般采用n-gram的方法,统计单词/上下文间的共现关系,比神经概率语言…

位置编码--RPE

相对位置编码 (Relative Position Encoding, RPE) 1. 相对位置编码 相对位置编码是 Transformer 中的一种改进位置编码方式,它的主要目的是通过直接建模序列中元素之间的相对位置,而不是绝对位置,从而更好地捕捉序列元素之间的依赖关系&#…

2024年12月31日Github流行趋势

项目名称:free-programming-books 项目地址url:https://github.com/EbookFoundation/free-programming-books项目语言:HTML历史star数:344575今日star数:432项目维护者:vhf, eshellman, davorpa, MHM5000, …

mysql下载安装及配置

基本操作参考:https://www.cnblogs.com/zhangkanghui/p/9613844.html ----------------------------------其余常见问题参考下面: 都需要管理员权限 输入命令查看端口号占用,然后kill掉

RoboMIND:多体现基准 机器人操纵的智能规范数据

我们介绍了 RoboMIND,这是机器人操纵的多体现智能规范数据的基准,包括 4 个实施例、279 个不同任务和 61 个不同对象类别的 55k 真实世界演示轨迹。 工业机器人企业 埃斯顿自动化 | 埃夫特机器人 | 节卡机器人 | 珞石机器人 | 法奥机器人 | 非夕科技 | C…

【Spring MVC 核心机制】核心组件和工作流程解析

在 Web 应用开发中,处理用户请求的逻辑常常会涉及到路径匹配、请求分发、视图渲染等多个环节。Spring MVC 作为一款强大的 Web 框架,将这些复杂的操作高度抽象化,通过组件协作简化了开发者的工作。 无论是处理表单请求、生成动态页面&#x…

郑州时空-TMS运输管理系统 GetDataBase 信息泄露漏洞复现

0x01 产品简介 郑州时空-TMS运输管理系统是一款专为物流运输企业设计的综合性管理软件,旨在提高运输效率、降低运输成本,并实现供应链的协同运作。系统基于现代计算机技术和物流管理方法,结合了郑州时空公司的专业经验和技术优势,为物流运输企业提供了一套高效、智能的运输…

电子应用设计方案81:智能AI冲奶瓶系统设计

智能 AI 冲奶瓶系统设计 一、引言 智能 AI 冲奶瓶系统旨在为父母或照顾者提供便捷、准确和卫生的冲奶服务,特别是在夜间或忙碌时,减轻负担并确保婴儿获得适宜的营养。 二、系统概述 1. 系统目标 - 精确调配奶粉和水的比例,满足不同年龄段婴…

Three.js教程004:坐标辅助器与轨道控制器

文章目录 坐标辅助器与轨道控制器实现效果添加坐标辅助器添加轨道控制器完整代码完整代码下载坐标辅助器与轨道控制器 实现效果 添加坐标辅助器 创建坐标辅助器: const axesHelper = new Three.AxesHelper(5);添加到场景中: scene.