官方论文地址->官方论文地址
官方代码地址->官方代码地址
个人修改代码->个人修改的代码已经上传CSDN免费下载
一、本文介绍
本文给大家带来是DLinear模型,DLinear是一种用于时间序列预测(TSF)的简单架构,DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测(值得一提的是DLinear的出现是为了挑战Transformer在实现序列预测中有效性)。本文的讲解内容包括:模型原理、数据集介绍、参数讲解、模型训练和预测、结果可视化、训练个人数据集,讲解顺序如下->
预测类型->单元预测、多元预测
适用对象->如果你的配置不是很好这个模型应该很适合你因为参数量很小训练速度很快
二、模型原理
DLinear模型出现是为了调整Transformer的有效性从而存在,Transformer的设计都十分的复杂和需要大量的参数,所以作者提出了一种简单的结构DLinear(参数量我实验过程中确实非常小)。
DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测。
具体地,DLinear如何工作的关键点如下:
-
时间序列分解:DLinear将输入的时间序列分解为两部分——趋势部分和剩余部分。这种分解有助于分别处理时间序列中的长期趋势和短期波动。
-
单层线性网络:对于趋势和剩余序列,DLinear分别使用两个单层的线性网络进行建模。这种简单的架构使得DLinear在处理时间序列时既高效又有效。
-
预测任务:在进行预测时,DLinear结合这两个网络的输出来生成最终的时间序列预测。
总结->可以看出DLinear的核心结构真的十分简单就包括一个分解和两个线性网络进行建模最后经过一个简单的相加就输出了结果。
模型的网络结构图如下所示->
图片分析->可以看到和我们上面讲的一样,数据从输入进来经过两个分支,一个为趋势性一个为剩余序列,然后分别经过一个线性层处理(这里的提到的线性层就是普通的全连接层),然后将结果进行简单的拼接就完成了结果的输出(这就这样的简单模型结果比过程十分复杂的Transformer模型效果要好->我自己实验效果确实要好,我拿2020年的bestpaper和普通的Transformer都进行了对比效果确实要有提升)。
下面的图片是一个简单的线性层(普通的全连接层)提取数据的过程图->
这里把模型的代码结构放出来方便大家根据讲解和代码进行对比。
class moving_avg(nn.Module):
"""
Moving average block to highlight the trend of time series
"""
def __init__(self, kernel_size, stride):
super(moving_avg, self).__init__()
self.kernel_size = kernel_size
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
def forward(self, x):
# padding on the both ends of time series
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
x = torch.cat([front, x, end], dim=1)
x = self.avg(x.permute(0, 2, 1))
x = x.permute(0, 2, 1)
return x
class series_decomp(nn.Module):
"""
Series decomposition block
"""
def __init__(self, kernel_size):
super(series_decomp, self).__init__()
self.moving_avg = moving_avg(kernel_size, stride=1)
def forward(self, x):
moving_mean = self.moving_avg(x)
res = x - moving_mean
return res, moving_mean
class Model(nn.Module):
"""
Decomposition-Linear
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Decompsition Kernel Size
kernel_size = 25
self.decompsition = series_decomp(kernel_size)
self.individual = configs.individual
self.channels = configs.enc_in
if self.individual:
self.Linear_Seasonal = nn.ModuleList()
self.Linear_Trend = nn.ModuleList()
for i in range(self.channels):
self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))
self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))
# Use this two lines if you want to visualize the weights
# self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
# self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
else:
self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)
self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)
# Use this two lines if you want to visualize the weights
# self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
# self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
def forward(self, x):
# x: [Batch, Input length, Channel]
seasonal_init, trend_init = self.decompsition(x)
seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)
if self.individual:
seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)
trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)
for i in range(self.channels):
seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])
trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])
else:
seasonal_output = self.Linear_Seasonal(seasonal_init)
trend_output = self.Linear_Trend(trend_init)
x = seasonal_output + trend_output
return x.permute(0,2,1) # to [Batch, Output length, Channel]
我看论文的内容大比分都是对比实验,因为DLinear的产生就是为了质疑Transformer所以他和各种Transformer的模型进行对比试验,因为本篇文章就是DLinear的实战案例,对比的部分我就不讲了,大家有兴趣可以看看论文内容在最上面我已经提供了链接。
三、数据集介绍
所用到的数据集为某公司的业务水平评估和其它参数具体的内容我就介绍了估计大家都是想用自己的数据进行训练模型,这里展示部分图片给大家提供参考。
四、参数讲解
模型的参数如下(大部分都是一些公共参数并不涉及模型)->
parser = argparse.ArgumentParser(description='DLinearNet Multivariate Time Series Forecasting')
# basic config
parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
parser.add_argument('--show_results', type=bool, default=True, help='Whether show forecast and real results graph')
parser.add_argument('--model', type=str, default='SCINet',help='Model name')
# data loader
parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--features', type=str, default='MS',
help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')
# forecasting task
parser.add_argument('--seq_len', type=int, default=126, help='input sequence length')
parser.add_argument('--label_len', type=int, default=64, help='start token length')
parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')
# model
parser.add_argument('--individual', action='store_true', default=False,
help='DLinear: a linear layer for each variate(channel) individually')
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=1, help='output size')
parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
# optimization
parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
parser.add_argument('--loss', type=str, default='mse', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
# GPU
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--device', type=int, default=0, help='gpu')
模型的详细参数讲解如下(如果你想训练你自己的数据集可以仔细看看)->
参数名称 | 参数类型 | 参数讲解 | |
---|---|---|---|
0 | train | bool | 是否进行训练,如果你单纯只想进行预测设置为False即可, |
1 | rollingforecast | bool | 是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测 |
2 | rolling-data-path | str | 如果进行滚动预测则需要添加新的和训练文件相同格式的数据 |
3 | show_results | bool | 是否保存预测值和真实值的对比 |
4 | model | str | 定义的模型名称 |
5 | root_path | str | 这个才是你文件的路径,不要到具体的文件,到目录级别即可。 |
6 | data_path | str | 这个填写你文件的具体名称。 |
7 | features | str | 这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。 |
8 | target | str | 这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。 |
9 | freq | str | 时间的间隔,你数据集每一条数据之间的时间间隔。 |
10 | checkpoints | str | 训练出来的模型保存路径 |
11 | seq_len | int | 用过去的多少条数据来预测未来的数据 |
12 | label_len | int | 可以理解为更高的权重占比的部分要小于seq_len |
13 | pred_len | int | 预测未来多少个时间点的数据 |
14 | enc_in | int | 你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7 |
15 | dec_in | int | 同上 |
16 | individual | bool | 这个就是我们上面提到的两个线性层,如果为True我们则对每一个通道用单独的线性层处理,False则为所有的通道用一个线性层 |
17 | c_out | int | 这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。 |
18 | dropout | float | 这个应该都理解不说了,丢弃的概率,防止过拟合的。 |
19 | embed | str | 时间特征的编码方式,默认为"timeF" |
20 | activation | str | 激活函数 |
21 | num_workers | int | 线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。 |
22 | train_epochs | int | 训练的次数 |
23 | batch_size | int | 一次往模型力输入多少条数据 |
24 | learning_rate | float | 学习率。 |
25 | loss | str | 损失函数,默认为"mse" |
26 | lradj | str | 学习率的调整方式,默认为"type1" |
27 | use_gpu | bool | 是否使用GPU训练,根据自身来选择 |
28 | gpu | int | GPU的编号 |
五、模型训练和预测
1.项目目录结构
项目的目录构造如下->
其中data为训练用的数据放的地方,layers为模型结构存放的地方,models为训练保存的训练模型,results为可视化结果保存的图片和滚动预测的结果,util为一些工具。
2.模型训练
当我们经过上面的参数讲解之后,我们可以开始训练模型了,控制台输出如下->
3.滚动预测
这里进行滚动预测的控制台输出->
4.结果展示
运行结果后,结果保存到同级目录下(下图为预测值和真实值的对比)->
5.结果分析
可以看到预测值和真实值之间的差距还可以,但是这个模型的参数量少得可怜,不得不得质疑Transformer模型的有效性~
六、训练你个人数据集
这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。
parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集,该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。
parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。
总结
到此本文已经全部讲解完成了,希望能够帮助到大家,在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。
概念理解
15种时间序列预测方法总结(包含多种方法代码实现)
数据分析
时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法
机器学习——难度等级(⭐⭐)
时间序列预测实战(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)
深度学习——难度等级(⭐⭐⭐⭐)
时间序列预测实战(五)基于Bi-LSTM横向搭配LSTM进行回归问题解决
时间序列预测实战(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测
时间序列预测实战(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)
时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)
Transformer——难度等级(⭐⭐⭐⭐)
时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解
时间序列预测模型实战案例(一)深度学习华为MTS-Mixers模型
个人创新模型——难度等级(⭐⭐⭐⭐⭐)
时间序列预测实战(十)(CNN-GRU-LSTM)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测
传统的时间序列预测模型(⭐⭐)
时间序列预测实战(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)
时间序列预测实战(六)深入理解ARIMA包括差分和相关性分析
融合模型——难度等级(⭐⭐⭐)
时间序列预测实战(九)PyTorch实现融合移动平均和LSTM-ARIMA进行长期预测