部分代码:
# CNN-Transformer
class CNNTransformerEncoder(nn.Module):
    """Conv1d embedding, then positional encoding, then a Transformer encoder.

    A 1-D convolution projects ``input_features`` channels to
    ``embedding_features`` channels. ``PositionalEncoder`` adds position
    information, and a stack of ``n_encoder_layer`` GELU Transformer encoder
    layers processes the result.

    Args:
        input_features: Number of input channels (``in_channels`` of the Conv1d).
        transformer_encoder_heads: Attention heads per encoder layer;
            ``embedding_features`` must be divisible by this value.
        embedding_features: Conv1d output channels, also used as the
            Transformer ``d_model``.
        cnn_kernel_size: Conv1d kernel size. The convolution uses PyTorch's
            defaults (stride 1, no padding), so the sequence length shrinks by
            ``cnn_kernel_size - 1``.
        dim_feedforward_enc: Hidden size of each encoder layer's feed-forward
            sub-network.
        n_encoder_layer: Number of stacked encoder layers.
        dropout: Dropout probability inside each encoder layer. Defaults to
            0.1, which is PyTorch's own default, so existing callers see no
            change.
    """

    def __init__(self, input_features, transformer_encoder_heads,
                 embedding_features, cnn_kernel_size, dim_feedforward_enc,
                 n_encoder_layer, dropout=0.1):
        super().__init__()
        # CNN part.
        # input:  [batch_size, input_features, input_seq_len]
        # output: [batch_size, embedding_features, output_len]
        #         with output_len = input_seq_len - cnn_kernel_size + 1
        self.cnn_embedding = nn.Conv1d(input_features, embedding_features, cnn_kernel_size)
        # Positional encoding over the embedding dimension.
        self.position_embedding = PositionalEncoder(d_model=embedding_features)
        # Transformer encoder.
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_features,
            nhead=transformer_encoder_heads,
            dim_feedforward=dim_feedforward_enc,
            dropout=dropout,
            activation='gelu',
        )
        self.transformer_encoder = nn.TransformerEncoder(
            transformer_encoder_layer, num_layers=n_encoder_layer)

    def forward(self, input_seq):
        """Encode ``input_seq``.

        Args:
            input_seq: Tensor of shape [batch_size, input_features, input_seq_len].

        Returns:
            Sequence-first tensor of shape
            [output_len, batch_size, embedding_features], where
            output_len = input_seq_len - cnn_kernel_size + 1. This assumes
            ``PositionalEncoder`` preserves its input shape (TODO confirm).
        """
        # [batch_size, embedding_features, output_len]
        cnn_embedding_results = self.cnn_embedding(input_seq)
        # PositionalEncoder receives [batch_size, output_len, embedding_features];
        # presumably it expects batch-first input -- TODO confirm against its definition.
        embedding_with_position = self.position_embedding(cnn_embedding_results.permute((0, 2, 1)))
        # The encoder layers use PyTorch's default batch_first=False, so
        # switch to [output_len, batch_size, embedding_features] before encoding.
        encoder_res = self.transformer_encoder(embedding_with_position.permute((1, 0, 2)))
        return encoder_res
项目截图:
数据:
测试集预测对比:
完整代码:
https://mbd.pub/o/works/592982