Transformer - Feature Preprocessing
flyfish
Raw data
train_data.values
[[ 5.827 2.009 1.599 0.462 4.203 1.34 30.531]
[ 5.76 2.076 1.492 0.426 4.264 1.401 30.46 ]
[ 5.76 1.942 1.492 0.391 4.234 1.31 30.038]
[ 5.76 1.942 1.492 0.426 4.234 1.31 27.013]
[ 5.693 2.076 1.492 0.426 4.142 1.371 27.787]
[ 5.492 1.942 1.457 0.391 4.112 1.279 27.717]
[ 5.358 1.875 1.35 0.355 3.929 1.34 27.646]
[ 5.157 1.808 1.35 0.32 3.807 1.279 27.084]
[ 5.157 1.741 1.279 0.355 3.777 1.218 27.787]
[ 5.157 1.808 1.35 0.426 3.777 1.188 27.506]
[ 5.157 1.808 1.315 0.391 3.777 1.249 27.857]
[ 5.157 1.942 1.35 0.426 3.807 1.279 27.013]
[ 5.09 1.942 1.279 0.391 3.807 1.279 25.044]
[ 5.224 2.009 1.457 0.533 3.807 1.249 24.551]
[ 5.291 1.808 1.457 0.426 3.777 1.218 23.566]
[ 5.358 1.942 1.492 0.462 3.807 1.31 21.526]
[ 5.358 1.942 1.492 0.462 3.868 1.279 21.948]
[ 5.492 2.009 1.492 0.462 3.929 1.34 21.456]
[ 5.492 1.942 1.492 0.426 3.929 1.34 22.792]
[ 5.492 2.076 1.492 0.497 3.99 1.31 21.034]
[ 5.626 2.143 1.528 0.533 4.051 1.371 21.174]
[ 5.961 2.344 1.67 0.604 4.234 1.492 20.823]
[ 6.162 2.411 1.777 0.604 4.325 1.523 21.174]
[ 6.631 2.478 1.99 0.746 4.66 1.675 21.174]
[ 7.167 2.947 2.132 0.782 5.026 1.858 22.792]
[ 7.502 3.215 2.239 0.888 5.33 1.98 23.848]
[ 7.703 3.349 2.487 1.031 5.269 1.919 24.34 ]
......
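Normalize the data with sklearn's fit and transform. The scaler (StandardScaler, see the configuration below) is fitted on the training split only and then applied to the whole dataset: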
train_data = df_data[border1s[0]:border2s[0]]
self.scaler.fit(train_data.values)
data = self.scaler.transform(df_data.values)
The normalized data is what the model is trained on:
transform_data [[ 0.6156 -1.3896 -0.991 ... 1.1402 -0.9535 2.907 ]
[ 0.5294 -1.2429 -1.2435 ... 1.2172 -0.6099 2.8853]
[ 0.5294 -1.5362 -1.2435 ... 1.1794 -1.1224 2.7561]
...
[ 5.6959 5.6479 11.0826 ... -0.82 0.933 1.5291]
[ 7.1602 6.0879 13.2628 ... -0.897 1.1076 1.5077]
[ 6.8156 5.3546 14.3529 ... -0.82 1.2765 1.5508]]
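StandardScaler standardizes each column independently, z = (x - mean) / std, where mean and std come only from the training rows (border1s[0]:border2s[0]). The validation and test periods are therefore scaled with training statistics, which is why the last rows of transform_data can reach values above 10: those periods lie far outside the range of the training rows. The following is a minimal NumPy sketch of the same computation, assuming sklearn's default settings (with_mean=True, with_std=True, population standard deviation with ddof=0); the helpers fit_standardizer and standardize and the toy arrays are hypothetical.

```python
import numpy as np


def fit_standardizer(train: np.ndarray):
    """Hypothetical helper: per-column mean and population std (ddof=0)."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    # Like StandardScaler, leave zero-variance columns unscaled instead of dividing by 0
    std = np.where(std == 0, 1.0, std)
    return mean, std


def standardize(data: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Hypothetical helper: z = (x - mean) / std, applied column-wise."""
    return (data - mean) / std


# Toy series: the first 3 rows play the role of the training split
full = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [10.0, 100.0]])
train = full[:3]

mean, std = fit_standardizer(train)  # statistics from the training rows only
scaled = standardize(full, mean, std)  # applied to every row, including the last one
```

The last toy row is far above the training range, so its scaled value is large (about 9.8), mirroring the large values at the end of transform_data.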
inverse_transform maps the normalized data back to the original values:
def inverse_transform(self, data):
    return self.scaler.inverse_transform(data)
inverse_transform_data: [[ 5.827 2.009 1.599 ... 4.203 1.34 30.531]
[ 5.76 2.076 1.492 ... 4.264 1.401 30.46 ]
[ 5.76 1.942 1.492 ... 4.234 1.31 30.038]
...
[ 9.779 5.224 6.716 ... 2.65 1.675 26.028]
[10.918 5.425 7.64 ... 2.589 1.706 25.958]
[10.65 5.09 8.102 ... 2.65 1.736 26.099]]
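Because the model is trained on standardized values, its outputs are also in standardized units; inverse_transform undoes the scaling with x = z * std + mean so results can be read in the original units. Note that a scaler fitted on all 7 columns expects 7 columns when inverting. A minimal NumPy sketch of the round trip, using a few values copied from train_data.values above; inverse_standardize is a hypothetical helper, not the sklearn API.

```python
import numpy as np


def inverse_standardize(scaled: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Hypothetical helper: undo z = (x - mean) / std column-wise."""
    return scaled * std + mean


# First three rows, first two columns of train_data.values shown above
raw = np.array([[5.827, 2.009], [5.76, 2.076], [5.76, 1.942]])

mean, std = raw.mean(axis=0), raw.std(axis=0)  # population std, as StandardScaler uses
scaled = (raw - mean) / std
restored = inverse_standardize(scaled, mean, std)  # back to the original values
```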
......
Configuration
seq_len:24
label_len:12
pred_len:24
set_type:0
features:M
target:OT
scale:True
timeenc:1
freq:h
root_path:./dataset/ETT-small/
data_path:ETTm1.csv
scaler:StandardScaler()
data_x is the training data, i.e. the normalized rows data[border1:border2] for set_type 0 (train):
data_x:[[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00 1.1402e+00 -9.5346e-01
2.9070e+00]
[ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.2172e+00 -6.0995e-01
2.8853e+00]
[ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00 1.1794e+00 -1.1224e+00
2.7561e+00]
[ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 1.1794e+00 -1.1224e+00
1.8305e+00]
[ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.0633e+00 -7.7889e-01
2.0673e+00]
[ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00 1.0254e+00 -1.2970e+00
2.0459e+00]
[ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00 7.9439e-01 -9.5346e-01
2.0242e+00]
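The borders that select data_x come from the split in __read_data__: 50% train, 20% test, and the rest validation. The validation and test slices start seq_len rows early so that their first sample still has a full input window. A sketch of that arithmetic, pulled out into a hypothetical helper split_borders; the row count 1000 is an arbitrary example, not the size of ETTm1.

```python
def split_borders(n_rows: int, seq_len: int, train_ratio: float = 0.5, test_ratio: float = 0.2):
    """Hypothetical helper mirroring the border computation in __read_data__."""
    num_train = int(n_rows * train_ratio)
    num_test = int(n_rows * test_ratio)
    num_vali = n_rows - num_train - num_test
    # Start positions: val/test begin seq_len rows before their own period
    border1s = [0, num_train - seq_len, n_rows - num_test - seq_len]
    # End positions (exclusive) for train, val, test
    border2s = [num_train, num_train + num_vali, n_rows]
    return border1s, border2s


border1s, border2s = split_borders(1000, seq_len=24)
# set_type 0 (train) selects rows border1s[0]:border2s[0]
```

With set_type=1 (val), the slice would be rows 476 to 800, whose first 24 rows belong to the training period and only serve as model input.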
Encoding the timestamps
Original values
df_stamp['date'].values: [
'2016-07-01T00:00:00.000000000' '2016-07-01T00:15:00.000000000'
'2016-07-01T00:30:00.000000000' '2016-07-01T00:45:00.000000000'
'2016-07-01T01:00:00.000000000' '2016-07-01T01:15:00.000000000'
'2016-07-01T01:30:00.000000000' '2016-07-01T01:45:00.000000000'
'2016-07-01T02:00:00.000000000' '2016-07-01T02:15:00.000000000'
'2016-07-01T02:30:00.000000000' '2016-07-01T02:45:00.000000000'
'2016-07-01T03:00:00.000000000' '2016-07-01T03:15:00.000000000'
'2016-07-01T03:30:00.000000000' '2016-07-01T03:45:00.000000000'
'2016-07-01T04:00:00.000000000' '2016-07-01T04:15:00.000000000'
'2016-07-01T04:30:00.000000000' '2016-07-01T04:45:00.000000000'
'2016-07-01T05:00:00.000000000' '2016-07-01T05:15:00.000000000'
......
After encoding (timeenc=1, freq='h'):
data_stamp: [[-0.5 0.1667 -0.5 -0.0014]
[-0.5 0.1667 -0.5 -0.0014]
[-0.5 0.1667 -0.5 -0.0014]
[-0.5 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
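With freq='h', each timestamp becomes 4 values: hour of day, day of week, day of month, and day of year, each scaled to [-0.5, 0.5]. ETTm1 is sampled every 15 minutes but freq is 'h', so four consecutive rows share the same encoding. The sketch below assumes the repo's time_features helper uses the formulas written here; they reproduce the printed rows (for example 2016-07-01 is a Friday, dayofweek 4, giving 4/6 - 0.5 ≈ 0.1667, and it is day 183 of leap year 2016, giving 182/365 - 0.5 ≈ -0.0014). hourly_time_features is a hypothetical name.

```python
import numpy as np
import pandas as pd


def hourly_time_features(dates: pd.DatetimeIndex) -> np.ndarray:
    """Hypothetical re-implementation: shape (len(dates), 4), values in [-0.5, 0.5]."""
    hour_of_day = np.asarray(dates.hour) / 23.0 - 0.5
    day_of_week = np.asarray(dates.dayofweek) / 6.0 - 0.5  # Monday = 0
    day_of_month = (np.asarray(dates.day) - 1) / 30.0 - 0.5
    day_of_year = (np.asarray(dates.dayofyear) - 1) / 365.0 - 0.5
    return np.stack([hour_of_day, day_of_week, day_of_month, day_of_year], axis=1)


# The first 8 timestamps of ETTm1, spaced 15 minutes apart
dates = pd.date_range("2016-07-01 00:00", periods=8, freq="15min")
feats = hourly_time_features(dates)
print(np.round(feats, 4))
```

Unlike the repo's time_features, which returns (n_features, n_rows) and is transposed in __read_data__, this sketch returns (n_rows, n_features) directly.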
s_begin: 0
s_end: 24
r_begin: 12
r_end: 48
s_begin: 1
s_end: 25
r_begin: 13
r_end: 49
......
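The printed indices follow from __getitem__: the encoder window covers rows [s_begin, s_end), and the decoder window [r_begin, r_end) starts label_len rows before s_end, so its first 12 rows repeat the last 12 rows of the encoder window. In the dumps below, the first row of seq_y equals the 13th row of seq_x. A sketch of the index arithmetic and of __len__, using the hypothetical helpers window_indices and num_windows on toy data:

```python
import numpy as np


def window_indices(index: int, seq_len: int, label_len: int, pred_len: int):
    """Hypothetical helper: the same arithmetic as Dataset_Custom.__getitem__."""
    s_begin = index
    s_end = s_begin + seq_len
    r_begin = s_end - label_len
    r_end = r_begin + label_len + pred_len
    return s_begin, s_end, r_begin, r_end


def num_windows(n_rows: int, seq_len: int, pred_len: int) -> int:
    """Hypothetical helper: the same formula as Dataset_Custom.__len__."""
    return n_rows - seq_len - pred_len + 1


data = np.arange(60 * 7).reshape(60, 7)  # toy stand-in for data_x (60 rows, 7 features)
s_begin, s_end, r_begin, r_end = window_indices(0, seq_len=24, label_len=12, pred_len=24)
seq_x = data[s_begin:s_end]  # 24 encoder rows
seq_y = data[r_begin:r_end]  # 12 overlapping rows + 24 future rows
```

The last valid index, num_windows(...) - 1, yields r_end equal to the number of rows, so no window runs past the end of data_x.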
seq_x: [[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00 1.1402e+00 -9.5346e-01
2.9070e+00]
[ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.2172e+00 -6.0995e-01
2.8853e+00]
[ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00 1.1794e+00 -1.1224e+00
2.7561e+00]
[ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 1.1794e+00 -1.1224e+00
1.8305e+00]
[ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.0633e+00 -7.7889e-01
2.0673e+00]
[ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00 1.0254e+00 -1.2970e+00
2.0459e+00]
[ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00 7.9439e-01 -9.5346e-01
2.0242e+00]
[-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.7271e+00 6.4040e-01 -1.2970e+00
1.8522e+00]
[-2.4573e-01 -1.9762e+00 -1.7460e+00 -1.6280e+00 6.0253e-01 -1.6405e+00
2.0673e+00]
[-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.4268e+00 6.0253e-01 -1.8094e+00
1.9814e+00]
[-2.4573e-01 -1.8295e+00 -1.6611e+00 -1.5260e+00 6.0253e-01 -1.4659e+00
2.0888e+00]
[-2.4573e-01 -1.5362e+00 -1.5785e+00 -1.4268e+00 6.4040e-01 -1.2970e+00
1.8305e+00]
[-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00 6.4040e-01 -1.2970e+00
1.2280e+00]
[-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00 6.4040e-01 -1.4659e+00
1.0771e+00]
[-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00 6.0253e-01 -1.6405e+00
7.7573e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 6.4040e-01 -1.1224e+00
1.5150e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 7.1740e-01 -1.2970e+00
2.8063e-01]
[ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00 7.9439e-01 -9.5346e-01
1.3008e-01]
[ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 7.9439e-01 -9.5346e-01
5.3889e-01]
[ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00 8.7139e-01 -1.1224e+00
9.5210e-04]
[ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00 9.4839e-01 -7.7889e-01
4.3791e-02]
[ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01 1.1794e+00 -9.7496e-02
-6.3613e-02]
[ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01 1.2942e+00 7.7075e-02
4.3791e-02]
[ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01 1.7171e+00 9.3304e-01
4.3791e-02]]
seq_y: [[-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00 6.4040e-01 -1.2970e+00
1.2280e+00]
[-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00 6.4040e-01 -1.4659e+00
1.0771e+00]
[-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00 6.0253e-01 -1.6405e+00
7.7573e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 6.4040e-01 -1.1224e+00
1.5150e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 7.1740e-01 -1.2970e+00
2.8063e-01]
[ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00 7.9439e-01 -9.5346e-01
1.3008e-01]
[ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 7.9439e-01 -9.5346e-01
5.3889e-01]
[ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00 8.7139e-01 -1.1224e+00
9.5210e-04]
[ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00 9.4839e-01 -7.7889e-01
4.3791e-02]
[ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01 1.1794e+00 -9.7496e-02
-6.3613e-02]
[ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01 1.2942e+00 7.7075e-02
4.3791e-02]
[ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01 1.7171e+00 9.3304e-01
4.3791e-02]
[ 2.3382e+00 6.6367e-01 2.6662e-01 -4.1824e-01 2.1791e+00 1.9636e+00
5.3889e-01]
[ 2.7688e+00 1.2503e+00 5.1909e-01 -1.1793e-01 2.5628e+00 2.6506e+00
8.6202e-01]
[ 3.0272e+00 1.5436e+00 1.1043e+00 2.8720e-01 2.4858e+00 2.3071e+00
1.0126e+00]
[ 2.6827e+00 1.1037e+00 6.8662e-01 8.3219e-02 2.2169e+00 2.1325e+00
6.4660e-01]
[ 2.6827e+00 1.3970e+00 6.8662e-01 2.8720e-01 2.2561e+00 4.0246e+00
6.4660e-01]
[ 2.9411e+00 1.3970e+00 6.0168e-01 3.8636e-01 2.5628e+00 3.1630e+00
8.4030e-01]
[ 2.9411e+00 1.6903e+00 6.8662e-01 4.8835e-01 2.4858e+00 2.3071e+00
8.6202e-01]
[ 2.8549e+00 1.3970e+00 6.8662e-01 3.8636e-01 2.4479e+00 2.3071e+00
9.0486e-01]
[ 2.7105e-01 8.1033e-01 1.0217e+00 6.8951e-01 -4.3503e-01 -4.3537e-01
1.9465e-01]
[-1.5960e-01 5.1701e-01 8.5414e-01 6.8951e-01 -6.2815e-01 -6.0995e-01
3.2378e-01]
[-2.4573e-01 3.7035e-01 6.8662e-01 6.8951e-01 -5.8902e-01 -4.3537e-01
4.0976e-01]
[-3.3186e-01 -5.0961e-01 -4.8842e-01 -1.1793e-01 -7.0514e-01 -2.6644e-01
-1.4198e+00]
[-1.0196e+00 -2.1629e-01 -2.3595e-01 -3.1908e-01 -7.8214e-01 -7.7889e-01
-1.0970e+00]
[-1.0196e+00 -6.5627e-01 -5.7100e-01 -4.1824e-01 -7.4301e-01 -6.0995e-01
-1.0110e+00]
[-8.4735e-01 -2.1629e-01 -3.2089e-01 -2.1709e-01 -6.2815e-01 -2.6644e-01
-7.7413e-01]
[-8.4735e-01 -5.0961e-01 -2.3595e-01 -1.1793e-01 -6.2815e-01 -7.7889e-01
-5.3729e-01]
[-5.0283e-01 -2.1629e-01 -6.8426e-02 -2.1709e-01 -4.3503e-01 -9.7496e-02
-3.2187e-01]
[ 9.8790e-02 7.7033e-02 4.3415e-01 -1.1793e-01 -2.0530e-01 2.4601e-01
-7.7413e-01]
[ 1.8492e-01 7.7033e-02 3.5157e-01 -2.1709e-01 -5.1305e-02 2.4601e-01
-7.5241e-01]
[ 7.0170e-01 2.2369e-01 8.5414e-01 8.3219e-02 1.4055e-01 2.4601e-01
-4.9415e-01]
[ 5.2944e-01 -2.1629e-01 4.3415e-01 -2.1709e-01 1.7968e-01 -9.7496e-02
-2.7903e-01]
[ 3.5718e-01 2.2369e-01 3.5157e-01 -2.1709e-01 6.3558e-02 5.8952e-01
-3.8644e-01]
[-1.3641e+00 -1.0962e+00 -2.0811e+00 -1.4268e+00 -1.6617e-01 7.6410e-01
-3.8644e-01]
[-1.3641e+00 -5.0961e-01 -1.0736e+00 -1.4268e+00 -2.4317e-01 4.2059e-01
-2.1447e-01]]
seq_x_mark: [[-0.5 0.1667 -0.5 -0.0014]
[-0.5 0.1667 -0.5 -0.0014]
[-0.5 0.1667 -0.5 -0.0014]
[-0.5 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.4565 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.413 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]]
seq_y_mark: [[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3696 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.3261 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2826 0.1667 -0.5 -0.0014]
[-0.2391 0.1667 -0.5 -0.0014]
[-0.2391 0.1667 -0.5 -0.0014]
[-0.2391 0.1667 -0.5 -0.0014]
[-0.2391 0.1667 -0.5 -0.0014]
[-0.1957 0.1667 -0.5 -0.0014]
[-0.1957 0.1667 -0.5 -0.0014]
[-0.1957 0.1667 -0.5 -0.0014]
[-0.1957 0.1667 -0.5 -0.0014]
[-0.1522 0.1667 -0.5 -0.0014]
[-0.1522 0.1667 -0.5 -0.0014]
[-0.1522 0.1667 -0.5 -0.0014]
[-0.1522 0.1667 -0.5 -0.0014]
[-0.1087 0.1667 -0.5 -0.0014]
[-0.1087 0.1667 -0.5 -0.0014]
[-0.1087 0.1667 -0.5 -0.0014]
[-0.1087 0.1667 -0.5 -0.0014]
[-0.0652 0.1667 -0.5 -0.0014]
[-0.0652 0.1667 -0.5 -0.0014]
[-0.0652 0.1667 -0.5 -0.0014]
[-0.0652 0.1667 -0.5 -0.0014]
[-0.0217 0.1667 -0.5 -0.0014]
[-0.0217 0.1667 -0.5 -0.0014]
[-0.0217 0.1667 -0.5 -0.0014]
[-0.0217 0.1667 -0.5 -0.0014]]
Code
import os

import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

from utils.timefeatures import time_features  # helper from the project repo (module path assumed)


class Dataset_Custom(Dataset):
    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size = [seq_len, label_len, pred_len]
        # info
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        # Reorder columns to: date, other features, target
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # print(cols)
        # num_train = int(len(df_raw) * 0.7)
        # print("num_train:",num_train)
        # num_test = int(len(df_raw) * 0.2)
        # Split used here: 50% train, 20% test, the remaining 30% validation
        num_train = int(len(df_raw) * 0.5)
        print("num_train:", num_train)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        # val/test start seq_len rows early so their first sample has a full input window
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]  # every column except 'date'
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit the scaler on the training split only, then transform the whole series
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
            #--------------------------------------------------------------------
            print("train_data.values", train_data.values)
            print("transform_data", data)
            inverse_transform_data = self.inverse_transform(data)
            print("inverse_transform_data:", inverse_transform_data)
            #--------------------------------------------------------------------
        else:
            data = df_data.values

        # .copy() avoids pandas' SettingWithCopyWarning on the assignment below
        df_stamp = df_raw[['date']][border1:border2].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # Raw calendar integers
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            # Scaled calendar features; time_features returns (n_features, n_rows), hence the transpose
            print("df_stamp['date'].values:", df_stamp['date'].values)
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp
        print("data_stamp:", data_stamp)
        print('\n'.join(['%s:%s' % item for item in self.__dict__.items()]))

    def __getitem__(self, index):
        # Encoder window: seq_len rows starting at index
        s_begin = index
        s_end = s_begin + self.seq_len
        # Decoder window: the last label_len rows of the encoder window followed by pred_len future rows
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len
        print("s_begin:", s_begin)
        print("s_end:", s_end)
        print("r_begin:", r_begin)
        print("r_end:", r_end)

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        print("seq_x.shape:", seq_x.shape)
        print("seq_y.shape:", seq_y.shape)
        print("seq_x_mark.shape:", seq_x_mark.shape)
        print("seq_y_mark.shape:", seq_y_mark.shape)
        print("seq_x:", seq_x)
        print("seq_y:", seq_y)
        print("seq_x_mark:", seq_x_mark)
        print("seq_y_mark:", seq_y_mark)
        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of start indices that leave room for seq_len inputs plus pred_len targets
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        # Map normalized values back to the original scale
        return self.scaler.inverse_transform(data)
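Compare the shapes: the batches from the DataLoader (batch size 1) versus the single samples returned by __getitem__: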
for i, (batch_x, _, _, _): torch.Size([1, 24, 7])
for i, (_, batch_y, _, _): torch.Size([1, 36, 7])
for i, (_, _, batch_x_mark, _): torch.Size([1, 24, 4])
for i, (_, _, _, batch_y_mark): torch.Size([1, 36, 4])
seq_x.shape: (24, 7)
seq_y.shape: (36, 7)
seq_x_mark.shape: (24, 4)
seq_y_mark.shape: (36, 4)
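Each dimension has a source: 24 = seq_len, 36 = label_len + pred_len = 12 + 24, 7 = all feature columns (features='M'), and 4 = the time features for freq='h'. The leading 1 is the batch size: PyTorch's default collate function stacks the per-sample arrays along a new first axis. The sketch below imitates that stacking with NumPy instead of torch, using zero-filled stand-ins; fake_sample is a hypothetical helper, and batch_size=1 is assumed from the printed shapes.

```python
import numpy as np

seq_len, label_len, pred_len = 24, 12, 24
n_features, n_time_feats = 7, 4
batch_size = 1  # assumed from the leading 1 in the printed torch.Size values


def fake_sample():
    """Hypothetical stand-in for one __getitem__ result (zeros with the printed shapes)."""
    return (np.zeros((seq_len, n_features)),
            np.zeros((label_len + pred_len, n_features)),
            np.zeros((seq_len, n_time_feats)),
            np.zeros((label_len + pred_len, n_time_feats)))


samples = [fake_sample() for _ in range(batch_size)]
# Stack each of the 4 fields along a new axis 0, similar to what default collate does
batch_x, batch_y, batch_x_mark, batch_y_mark = (np.stack(field) for field in zip(*samples))
```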
During training, batch_x, batch_y, batch_x_mark, and batch_y_mark are passed in as arguments:
outputs, batch_y = self._predict(batch_x, batch_y, batch_x_mark, batch_y_mark)
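import torch


def _predict(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
    # decoder input: zeros for the pred_len positions to be forecast...
    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
    # ...prefixed by the label_len known steps taken from batch_y
    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

    # encoder - decoder
    def _run_model():
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            outputs = outputs[0]
        return outputs

    # ... (the rest of _predict, which calls _run_model() and returns outputs, batch_y, is not shown)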
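batch_y is turned into dec_inp: the first label_len steps of batch_y are kept as known context, and the last pred_len steps are replaced by zeros, so the decoder does not see the values it has to predict.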
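A NumPy analogue of the two torch lines that build dec_inp, useful for checking the shapes without a GPU; np.zeros_like and np.concatenate stand in for torch.zeros_like and torch.cat, and make_decoder_input is a hypothetical name. The random batch_y only mimics the (1, 36, 7) shape.

```python
import numpy as np


def make_decoder_input(batch_y: np.ndarray, label_len: int, pred_len: int) -> np.ndarray:
    """Hypothetical helper: keep the first label_len steps, zero out the last pred_len steps."""
    placeholder = np.zeros_like(batch_y[:, -pred_len:, :])
    return np.concatenate([batch_y[:, :label_len, :], placeholder], axis=1)


rng = np.random.default_rng(0)
batch_y = rng.standard_normal((1, 36, 7))  # (batch, label_len + pred_len, features)
dec_inp = make_decoder_input(batch_y, label_len=12, pred_len=24)
```

dec_inp has the same shape as batch_y; only its last 24 steps differ, and the true values for those steps remain in batch_y.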
For a Vanilla Transformer model, the inputs map as follows:
x_enc = batch_x
x_mark_enc = batch_x_mark
x_dec = dec_inp
x_mark_dec = batch_y_mark