分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测
目录
- 分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测
- 分类效果
- 基本介绍
- 模型描述
- 程序设计
- 参考资料
分类效果
基本介绍
MATLAB实现CNN-BiLSTM-Attention多输入分类预测,CNN-BiLSTM结合注意力机制多输入分类预测。
模型描述
Matlab实现CNN-BiLSTM-Attention多变量分类预测
1.data为数据集,格式为excel,12个输入特征,输出四个类别;
2.MainCNN_BiLSTM_AttentionNC.m为主程序文件,运行即可;
注意程序和数据放在一个文件夹,运行环境为Matlab2020b及以上。
4.注意力机制模块:
SEBlock(Squeeze-and-Excitation Block)是一种聚焦于通道维度而提出一种新的结构单元,为模型添加了通道注意力机制,该机制通过添加各个特征通道的重要程度的权重,针对不同的任务增强或者抑制对应的通道,以此来提取有用的特征。该模块的内部操作流程如图,总体分为三步:首先是Squeeze 压缩操作,对空间维度的特征进行压缩,保持特征通道数量不变。融合全局信息即全局池化,并将每个二维特征通道转换为实数。实数计算公式如公式所示。该实数由k个通道得到的特征之和除以空间维度的值而得,空间维数为H*W。其次是Excitation激励操作,它由两层全连接层和Sigmoid函数组成。如公式所示,s为激励操作的输出,σ为激活函数sigmoid,W2和W1分别是两个完全连接层的相应参数,δ是激活函数ReLU,对特征先降维再升维。最后是Reweight操作,对之前的输入特征进行逐通道加权,完成原始特征在各通道上的重新分配。
程序设计
- 完整程序和数据获取方式1:同等价值程序兑换;
- 完整程序和数据获取方式2:私信博主获取。
%% CNN模型建立
layers = [
imageInputLayer([size(input,1) 1 1]) %输入层参数设置
convolution2dLayer(3,16,'Padding','same')%卷积层的核大小、数量,填充方式
reluLayer %relu激活函数
fullyConnectedLayer(384) % 384 全连接层神经元
fullyConnectedLayer(384) % 384 全连接层神经元
fullyConnectedLayer(1) % 输出层神经元
regressionLayer]; % 添加回归层,用于计算损失值
%% 模型训练与测试
options = trainingOptions('adam', ...
'MaxEpochs',20, ...
'MiniBatchSize',16, ...
'InitialLearnRate',0.005, ...
'GradientThreshold',1, ...
'Verbose',false,...
'Plots','training-progress',...
'ValidationData',{testD,targetD_test'});
% 训练
net = trainNetwork(trainD,targetD',layers,options);
————————————————
版权声明:本文为CSDN博主「机器学习之心」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
tempLayers = multiplicationLayer(2, "Name", "multiplication"); % 点乘的注意力
lgraph = addLayers(lgraph, tempLayers); % 将上述网络结构加入空白结构中
tempLayers = [
sequenceUnfoldingLayer("Name", "sequnfold") % 建立序列反折叠层
flattenLayer("Name", "flatten") % 网络铺平层
bilstmLayer(6, "Name", "bilstm", "OutputMode", "last") % BiLSTM层
fullyConnectedLayer(num_class) % 全连接层
softmaxLayer % 损失函数层
classificationLayer]; % 分类层
lgraph = addLayers(lgraph, tempLayers); % 将上述网络结构加入空白结构中
lgraph = connectLayers(lgraph, "seqfold/out", "conv_1"); % 折叠层输出 连接 卷积层输入;
lgraph = connectLayers(lgraph, "seqfold/miniBatchSize", "sequnfold/miniBatchSize");
参考资料
[1] https://blog.csdn.net/kjm13182345320/article/details/129943065?spm=1001.2014.3001.5501
[2] https://blog.csdn.net/kjm13182345320/article/details/129919734?spm=1001.2014.3001.5501