Transformer - 特征预处理

Transformer - 特征预处理

flyfish
在这里插入图片描述
原始数据

train_data.values 
[[ 5.827  2.009  1.599  0.462  4.203  1.34  30.531]
 [ 5.76   2.076  1.492  0.426  4.264  1.401 30.46 ]
 [ 5.76   1.942  1.492  0.391  4.234  1.31  30.038]
 [ 5.76   1.942  1.492  0.426  4.234  1.31  27.013]
 [ 5.693  2.076  1.492  0.426  4.142  1.371 27.787]
 [ 5.492  1.942  1.457  0.391  4.112  1.279 27.717]
 [ 5.358  1.875  1.35   0.355  3.929  1.34  27.646]
 [ 5.157  1.808  1.35   0.32   3.807  1.279 27.084]
 [ 5.157  1.741  1.279  0.355  3.777  1.218 27.787]
 [ 5.157  1.808  1.35   0.426  3.777  1.188 27.506]
 [ 5.157  1.808  1.315  0.391  3.777  1.249 27.857]
 [ 5.157  1.942  1.35   0.426  3.807  1.279 27.013]
 [ 5.09   1.942  1.279  0.391  3.807  1.279 25.044]
 [ 5.224  2.009  1.457  0.533  3.807  1.249 24.551]
 [ 5.291  1.808  1.457  0.426  3.777  1.218 23.566]
 [ 5.358  1.942  1.492  0.462  3.807  1.31  21.526]
 [ 5.358  1.942  1.492  0.462  3.868  1.279 21.948]
 [ 5.492  2.009  1.492  0.462  3.929  1.34  21.456]
 [ 5.492  1.942  1.492  0.426  3.929  1.34  22.792]
 [ 5.492  2.076  1.492  0.497  3.99   1.31  21.034]
 [ 5.626  2.143  1.528  0.533  4.051  1.371 21.174]
 [ 5.961  2.344  1.67   0.604  4.234  1.492 20.823]
 [ 6.162  2.411  1.777  0.604  4.325  1.523 21.174]
 [ 6.631  2.478  1.99   0.746  4.66   1.675 21.174]
 [ 7.167  2.947  2.132  0.782  5.026  1.858 22.792]
 [ 7.502  3.215  2.239  0.888  5.33   1.98  23.848]
 [ 7.703  3.349  2.487  1.031  5.269  1.919 24.34 ]
 ......

通过sklearn的fit和transform将数据规范化

train_data = df_data[border1s[0]:border2s[0]]
self.scaler.fit(train_data.values)
data = self.scaler.transform(df_data.values)

规划化后的数据就是将要训练的数据

transform_data [[ 0.6156 -1.3896 -0.991  ...  1.1402 -0.9535  2.907 ]
 [ 0.5294 -1.2429 -1.2435 ...  1.2172 -0.6099  2.8853]
 [ 0.5294 -1.5362 -1.2435 ...  1.1794 -1.1224  2.7561]
 ...
 [ 5.6959  5.6479 11.0826 ... -0.82    0.933   1.5291]
 [ 7.1602  6.0879 13.2628 ... -0.897   1.1076  1.5077]
 [ 6.8156  5.3546 14.3529 ... -0.82    1.2765  1.5508]]

可以通过inverse_transform将数据还原

def inverse_transform(self, data):
    return self.scaler.inverse_transform(data)
inverse_transform_data: [[ 5.827  2.009  1.599 ...  4.203  1.34  30.531]
 [ 5.76   2.076  1.492 ...  4.264  1.401 30.46 ]
 [ 5.76   1.942  1.492 ...  4.234  1.31  30.038]
 ...
 [ 9.779  5.224  6.716 ...  2.65   1.675 26.028]
 [10.918  5.425  7.64  ...  2.589  1.706 25.958]
 [10.65   5.09   8.102 ...  2.65   1.736 26.099]]
 ......

配置

seq_len:24
label_len:12
pred_len:24
set_type:0
features:M
target:OT
scale:True
timeenc:1
freq:h
root_path:./dataset/ETT-small/
data_path:ETTm1.csv
scaler:StandardScaler()

data_x是训练数据

data_x:[[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00  1.1402e+00 -9.5346e-01
   2.9070e+00]
 [ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.2172e+00 -6.0995e-01
   2.8853e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00  1.1794e+00 -1.1224e+00
   2.7561e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  1.1794e+00 -1.1224e+00
   1.8305e+00]
 [ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.0633e+00 -7.7889e-01
   2.0673e+00]
 [ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00  1.0254e+00 -1.2970e+00
   2.0459e+00]
 [ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00  7.9439e-01 -9.5346e-01
   2.0242e+00]

时间数据的编码
具体看这里
原值

df_stamp['date'].values: [
 '2016-07-01T00:00:00.000000000' '2016-07-01T00:15:00.000000000'
 '2016-07-01T00:30:00.000000000' '2016-07-01T00:45:00.000000000'
 '2016-07-01T01:00:00.000000000' '2016-07-01T01:15:00.000000000'
 '2016-07-01T01:30:00.000000000' '2016-07-01T01:45:00.000000000'
 '2016-07-01T02:00:00.000000000' '2016-07-01T02:15:00.000000000'
 '2016-07-01T02:30:00.000000000' '2016-07-01T02:45:00.000000000'
 '2016-07-01T03:00:00.000000000' '2016-07-01T03:15:00.000000000'
 '2016-07-01T03:30:00.000000000' '2016-07-01T03:45:00.000000000'
 '2016-07-01T04:00:00.000000000' '2016-07-01T04:15:00.000000000'
 '2016-07-01T04:30:00.000000000' '2016-07-01T04:45:00.000000000'
 '2016-07-01T05:00:00.000000000' '2016-07-01T05:15:00.000000000'
 ......

编码之后

data_stamp: [[-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
s_begin: 0
s_end: 24
r_begin: 12
r_end: 48

s_begin: 1
s_end: 25
r_begin: 13
r_end: 49

......
seq_x: [[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00  1.1402e+00 -9.5346e-01
   2.9070e+00]
 [ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.2172e+00 -6.0995e-01
   2.8853e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00  1.1794e+00 -1.1224e+00
   2.7561e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  1.1794e+00 -1.1224e+00
   1.8305e+00]
 [ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.0633e+00 -7.7889e-01
   2.0673e+00]
 [ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00  1.0254e+00 -1.2970e+00
   2.0459e+00]
 [ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00  7.9439e-01 -9.5346e-01
   2.0242e+00]
 [-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.7271e+00  6.4040e-01 -1.2970e+00
   1.8522e+00]
 [-2.4573e-01 -1.9762e+00 -1.7460e+00 -1.6280e+00  6.0253e-01 -1.6405e+00
   2.0673e+00]
 [-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.4268e+00  6.0253e-01 -1.8094e+00
   1.9814e+00]
 [-2.4573e-01 -1.8295e+00 -1.6611e+00 -1.5260e+00  6.0253e-01 -1.4659e+00
   2.0888e+00]
 [-2.4573e-01 -1.5362e+00 -1.5785e+00 -1.4268e+00  6.4040e-01 -1.2970e+00
   1.8305e+00]
 [-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00  6.4040e-01 -1.2970e+00
   1.2280e+00]
 [-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00  6.4040e-01 -1.4659e+00
   1.0771e+00]
 [-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00  6.0253e-01 -1.6405e+00
   7.7573e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  6.4040e-01 -1.1224e+00
   1.5150e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  7.1740e-01 -1.2970e+00
   2.8063e-01]
 [ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00  7.9439e-01 -9.5346e-01
   1.3008e-01]
 [ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  7.9439e-01 -9.5346e-01
   5.3889e-01]
 [ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00  8.7139e-01 -1.1224e+00
   9.5210e-04]
 [ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00  9.4839e-01 -7.7889e-01
   4.3791e-02]
 [ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01  1.1794e+00 -9.7496e-02
  -6.3613e-02]
 [ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01  1.2942e+00  7.7075e-02
   4.3791e-02]
 [ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01  1.7171e+00  9.3304e-01
   4.3791e-02]]
seq_y: [[-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00  6.4040e-01 -1.2970e+00
  1.2280e+00]
[-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00  6.4040e-01 -1.4659e+00
  1.0771e+00]
[-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00  6.0253e-01 -1.6405e+00
  7.7573e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  6.4040e-01 -1.1224e+00
  1.5150e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  7.1740e-01 -1.2970e+00
  2.8063e-01]
[ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00  7.9439e-01 -9.5346e-01
  1.3008e-01]
[ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  7.9439e-01 -9.5346e-01
  5.3889e-01]
[ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00  8.7139e-01 -1.1224e+00
  9.5210e-04]
[ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00  9.4839e-01 -7.7889e-01
  4.3791e-02]
[ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01  1.1794e+00 -9.7496e-02
 -6.3613e-02]
[ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01  1.2942e+00  7.7075e-02
  4.3791e-02]
[ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01  1.7171e+00  9.3304e-01
  4.3791e-02]
[ 2.3382e+00  6.6367e-01  2.6662e-01 -4.1824e-01  2.1791e+00  1.9636e+00
  5.3889e-01]
[ 2.7688e+00  1.2503e+00  5.1909e-01 -1.1793e-01  2.5628e+00  2.6506e+00
  8.6202e-01]
[ 3.0272e+00  1.5436e+00  1.1043e+00  2.8720e-01  2.4858e+00  2.3071e+00
  1.0126e+00]
[ 2.6827e+00  1.1037e+00  6.8662e-01  8.3219e-02  2.2169e+00  2.1325e+00
  6.4660e-01]
[ 2.6827e+00  1.3970e+00  6.8662e-01  2.8720e-01  2.2561e+00  4.0246e+00
  6.4660e-01]
[ 2.9411e+00  1.3970e+00  6.0168e-01  3.8636e-01  2.5628e+00  3.1630e+00
  8.4030e-01]
[ 2.9411e+00  1.6903e+00  6.8662e-01  4.8835e-01  2.4858e+00  2.3071e+00
  8.6202e-01]
[ 2.8549e+00  1.3970e+00  6.8662e-01  3.8636e-01  2.4479e+00  2.3071e+00
  9.0486e-01]
[ 2.7105e-01  8.1033e-01  1.0217e+00  6.8951e-01 -4.3503e-01 -4.3537e-01
  1.9465e-01]
[-1.5960e-01  5.1701e-01  8.5414e-01  6.8951e-01 -6.2815e-01 -6.0995e-01
  3.2378e-01]
[-2.4573e-01  3.7035e-01  6.8662e-01  6.8951e-01 -5.8902e-01 -4.3537e-01
  4.0976e-01]
[-3.3186e-01 -5.0961e-01 -4.8842e-01 -1.1793e-01 -7.0514e-01 -2.6644e-01
 -1.4198e+00]
[-1.0196e+00 -2.1629e-01 -2.3595e-01 -3.1908e-01 -7.8214e-01 -7.7889e-01
 -1.0970e+00]
[-1.0196e+00 -6.5627e-01 -5.7100e-01 -4.1824e-01 -7.4301e-01 -6.0995e-01
 -1.0110e+00]
[-8.4735e-01 -2.1629e-01 -3.2089e-01 -2.1709e-01 -6.2815e-01 -2.6644e-01
 -7.7413e-01]
[-8.4735e-01 -5.0961e-01 -2.3595e-01 -1.1793e-01 -6.2815e-01 -7.7889e-01
 -5.3729e-01]
[-5.0283e-01 -2.1629e-01 -6.8426e-02 -2.1709e-01 -4.3503e-01 -9.7496e-02
 -3.2187e-01]
[ 9.8790e-02  7.7033e-02  4.3415e-01 -1.1793e-01 -2.0530e-01  2.4601e-01
 -7.7413e-01]
[ 1.8492e-01  7.7033e-02  3.5157e-01 -2.1709e-01 -5.1305e-02  2.4601e-01
 -7.5241e-01]
[ 7.0170e-01  2.2369e-01  8.5414e-01  8.3219e-02  1.4055e-01  2.4601e-01
 -4.9415e-01]
[ 5.2944e-01 -2.1629e-01  4.3415e-01 -2.1709e-01  1.7968e-01 -9.7496e-02
 -2.7903e-01]
[ 3.5718e-01  2.2369e-01  3.5157e-01 -2.1709e-01  6.3558e-02  5.8952e-01
 -3.8644e-01]
[-1.3641e+00 -1.0962e+00 -2.0811e+00 -1.4268e+00 -1.6617e-01  7.6410e-01
 -3.8644e-01]
[-1.3641e+00 -5.0961e-01 -1.0736e+00 -1.4268e+00 -2.4317e-01  4.2059e-01
 -2.1447e-01]]
seq_x_mark: [[-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]]
seq_y_mark: [[-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]]

代码

class Dataset_Custom(Dataset):
    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size [seq_len, label_len, pred_len]
        # info
        if size == None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # print(cols)
        # num_train = int(len(df_raw) * 0.7)
        # print("num_train:",num_train)
        # num_test = int(len(df_raw) * 0.2)
        
        num_train = int(len(df_raw) * 0.5)
        print("num_train:",num_train)
        num_test = int(len(df_raw) * 0.2)
        
        
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
            
            #--------------------------------------------------------------------
            print("train_data.values",train_data.values)
            print("transform_data",data)
            inverse_transform_data=self.inverse_transform(data)
            print("inverse_transform_data:",inverse_transform_data)
             #--------------------------------------------------------------------
            
            
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        
  
        if self.timeenc == 0:
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
            data_stamp = df_stamp.drop(['date'], 1).values
        elif self.timeenc == 1:
            print("df_stamp['date'].values:",df_stamp['date'].values)
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp
        print("data_stamp:",data_stamp)
        

        print('\n'.join(['%s:%s' % item for item in self.__dict__.items()]) )


    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len
        
        print("s_begin:",s_begin)
        print("s_end:",s_end)
        print("r_begin:",r_begin)
        print("r_end:",r_end)

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        
        print("seq_x.shape:",seq_x.shape)
        print("seq_y.shape:",seq_y.shape)
        print("seq_x_mark.shape:",seq_x_mark.shape)
        print("seq_y_mark.shape:",seq_y_mark.shape)
        print("seq_x:",seq_x)
        print("seq_y:",seq_y)
        print("seq_x_mark:",seq_x_mark)
        print("seq_y_mark:",seq_y_mark)

        

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)

对比下形状

 for i, (batch_x, , , ): torch.Size([1, 24, 7])
 for i, (, batch_y, , ): torch.Size([1, 36, 7])
 for i, (, , batch_x_mark, ): torch.Size([1, 24, 4])
 for i, (, , , batch_y_mark): torch.Size([1, 36, 4])
seq_x.shape: (24, 7)
seq_y.shape: (36, 7)
seq_x_mark.shape: (24, 4)
seq_y_mark.shape: (36, 4)

训练数据用

先用batch_x, batch_y, batch_x_mark, batch_y_mark作为参数

 outputs, batch_y = self._predict(batch_x, batch_y, batch_x_mark, batch_y_mark)
def _predict(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
    # decoder input
    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
    # encoder - decoder

    def _run_model():

        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            outputs = outputs[0]
        return outputs

batch_y变dec_inp
假如是Vanilla Transformer模型
输入的对应关系如下

 x_enc			= batch_x
 x_mark_enc 	= batch_x_mark
 x_dec 			= dec_inp
 x_mark_dec 	= batch_y_mark

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/570815.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AndroidStudio中虚拟机(AVD)无法启动,出现unable to locate adb错误

1.检查Android SDK Platform-Tools是否安装(个人是通过这个方法解决的) 首先通过File-Project Structure-Project SDK检查SDK有没有被选中 步骤:打开file -> settings ,搜索SDK 之后点击"-",在点击Apply进行安装 2.可能是驱动的问题 电脑…

牛客NC179 长度为 K 的重复字符子串【simple 哈希,滑动窗口 C++、Java、Go、PHP】

题目 题目链接: https://www.nowcoder.com/practice/eced9a8a4b6c42b79c95ae5625e1d5fd 思路 哈希统计每个字符出现的次数。没在窗口内的字符要删除参考答案C class Solution {public:/*** 代码中的类名、方法名、参数名已经指定,请勿修改&#xff0c…

python(一)

一、字面量 字面量:在代码中,被写下来的固定的值,称之为字面量。 Python中常用的有6种值(数据)的类型: 二、注释 注释的分类: 单行注释:以#开头,#右边的所有文字当作说明,而不是真正要执行的程序&#…

2024新算法爱情进化算法(LEA)和经典灰狼优化器(GWO)进行无人机三维路径规划设计实验

简介: 2024新算法爱情进化算法(LEA)和经典灰狼优化器(GWO)进行无人机三维路径规划设计实验。 无人机三维路径规划的重要意义在于确保飞行安全、优化飞行路线以节省时间和能源消耗,并使无人机能够适应复杂环…

多模态模型

转换器成功作为构建语言模型的一种方法,促使 AI 研究人员考虑同样的方法是否对图像数据也有效。 研究结果是开发多模态模型,其中模型使用大量带有描述文字的图像进行训练,没有固定的标签。 图像编码器基于像素值从图像中提取特征,…

C++笔记:类和对象(一)->封装

类和对象 认识类和对象 先来回忆一下C语言中的类型和变量,类型就像是定义了数据的规则,而变量则是根据这些规则来实际存储数据的容器。类是我们自己定义的一种数据类型,而对象则是这种数据类型的一个具体实例。类就可以理解为类型&#xff0c…

vue2知识点————(父子通信)

vue2的知识点,更多前端知识在主页,还有其他知识会持续更新 vue组件 在Vue.js 2.x中,父子组件之间的通信是非常常见的情况,Vue提供了多种方法来实现这种通信。 Props 父向子通信 Props 是父组件向子组件传递数据的一种方式。通过…

Java的八大基本数据类型和 println 的介绍

前言 如果你有C语言的基础,这部分内容就会很简单,但是会有所不同~~ 这是我将要提到的八大基本数据类型: 注意,Java的数据类型是有符号的!!!和C语言不同,Java不存在无符号的数据。 整…

【电控笔记5.8】数字滤波器设计流程频域特性

数字滤波器设计流程&频域特性 2HZ : w=2pi2=12.56 wc=2*pi*5; Ts=0.001; tf_lpf =

【行为型模式】解释器模式

一、解释器模式概述 解释器模式定义:给分析对象定义一个语言,并定义该语言的文法表示,再设计一个解析器来解释语言中的句子。也就是说,用编译语言的方式来分析应用中的实例。这种模式实现了文法表达式处理的接口,该接口…

设计模式——状态模式19

状态模式是一种行为设计模式, 允许一个对象在其内部状态改变时改变它的行为,对象看起来好像修改了它的类。状态模式的核心是状态与行为绑定,不同的状态对应不同的行为。 设计模式,一定要敲代码理解 状态行为抽象 //在某种状态下&…

数据库——实 验 8 SQL 编程

1.T-SQL 语言简介 SQL Server 使用的语言称作 Transact-SQL, 它不仅包括基本 SQL 操作的内容,如 SQL 的数据查询功能和数据操作功能等,还有一般程序设计的能力。 2. 局部变量和全局变量的概念 1)局部变量 局部变量是一个能够拥有特定数据类型的对…

多目标应用:基于非支配排序粒子群优化算法NSPSO求解无人机三维路径规划(MATLAB代码)

一、无人机多目标优化模型 无人机三维路径规划是无人机在执行任务过程中的非常关键的环节,无人机三维路径规划的主要目的是在满足任务需求和自主飞行约束的基础上,计算出发点和目标点之间的最佳航路。 1.1路径成本 无人机三维路径规划的首要目标是寻找…

html网页在展示时,监听网络是否断网,如果断网页面暂停点击响应

序言: 集合百家之所长,方著此篇文章,废话少说,直接上代码,找好你的测试网页,进行配置,然后复制粘贴代码,就可以了。 1.css文件内容 #newbody{display: none;width: 100%;height: 9…

【用户投稿】Apache SeaTunnel 2.3.3+Web 1.0.0版本安装部署

项目概要 Apache SeaTunnel 是一个分布式、高性能、易扩展的数据集成平台,用于实时和离线数据处理,支持多种数据源之间的数据迁移和转换。 其中,Apache-seatunnel-web-1.0.0-bin.tar.gz和apache-seatunnel-2.3.3-bin.tar.gz代表了 Apache SeaTunnel Web…

python语言实现语音合成(文字转语音)

python语言实现语音合成(文字转语音) 在Python中实现文本到语音——语音朗读功能,可以使用pyttsx3库。pyttsx3库的安装和使用也相对简单,但在控制语音的暂停、继续和停止功能方面可能存在一定的困难。 首先,您需要安装…

北航计算机软件技术基础课程作业笔记【4】

题目&#xff08;好像以前没加&#xff09; 二叉树与哈希表 作业 1.二叉树前序遍历结果 二叉树结构为 代码实现中序后序推理前序表达式 #include <iostream> #include <stack> #include <string> #include <vector> #include <deque> ​ // …

H800算力低至5.99元/卡时!抢鲜体验LLaMA3最佳实践就在潞晨云

由Meta发布的LLaMA3 8B和LLaMA3 70B的&#xff0c;将开源AI大模型推向新的高度。在多个基准测试上的表现均大幅超过已有竞品&#xff0c;成为AI应用的最新优选。 潞晨云现已上架 LLaMA3 8B和LLaMA3 70B从推理到微调和预训练的实践教程。 提供免费测试代金券&#xff0c;限时特…

yolov8 区域多类别计数

yolov8 区域多类别计数 1. 基础2. 计数功能2.1 计数模块2.2 判断模块 3. 初始代码4. 实验结果5. 完整代码6. 源码 1. 基础 本项目是在 WindowsYOLOV8环境配置 的基础上实现的&#xff0c;测距原理可见上边文章 2. 计数功能 2.1 计数模块 在指定区域内计数模块 region_point…