基于ResNet-attention的负荷预测

一、attention机制

    注意力模型最近几年在深度学习各个领域被广泛使用,无论是图像处理、语音识别还是自然语言处理的各种不同类型的任务中,都很容易遇到注意力模型的身影。从注意力模型的命名方式看,很明显其借鉴了人类的注意力机制。我们来看下面的一张图片。

    图中形象化展示了人类在看到一副图像时是如何高效分配有限的注意力资源的,其中红色区域表明视觉系统更关注的目标,很明显对于图1所示的场景,人们会把注意力更多投入到人的脸部,文本的标题以及文章首句等位置。

   视觉注意力机制是人类视觉所特有的大脑信号处理机制。人类视觉通过快速扫描全局图像,获得需要重点关注的目标区域,也就是一般所说的注意力焦点,而后对这一区域投入更多注意力资源,以获取更多所需要关注目标的细节信息,而抑制其他无用信息。深度学习中的注意力机制的核心就是让网络关注其更需要更重要的地方,注意力机制就是实现网络自适应的一个方式。

    注意力机制的本质就是定位到感兴趣的信息,抑制无用信息,结果通常都是以概率图或者概率特征向量的形式展示,从原理上来说,主要分为空间注意力模型,通道注意力模型,空间和通道混合注意力模型三种。那么今天我们主要介绍通道注意力机制。

1、通道注意力机制

    通道注意力机制最经典的应用就是SENet(Sequeeze and Excitation Net),它通过建模各个特征通道的重要程度,然后针对不同的任务增强或者抑制不同的通道,原理图如下。

 

       在正常的卷积操作后分出了一个旁路分支,首先进行Squeeze操作(即图中Fsq(·)),它将空间维度进行特征压缩,即每个二维的特征图变成一个实数,相当于具有全局感受野的池化操作,特征通道数不变。然后是Excitation操作(即图中的Fex(·)),它通过参数w为每个特征通道生成权重,w被学习用来显式地建模特征通道间的相关性。在文章中,使用了一个2层bottleneck结构(先降维再升维)的全连接层+Sigmoid函数来实现。得到了每一个特征通道的权重之后,就将该权重应用于原来的每个特征通道,基于特定的任务,就可以学习到不同通道的重要性。作为一种通用的设计思想,它可以被用于任何现有网络,具有较强的实践意义。

    综上通道注意力计算公式总结为:

    关于通道注意力机制的原理就介绍到这里,想要了解具体原理的,大家可以参考文献:Squeeze-and-Excitation Networks

二、代码实战

clc
clear
​
close all
load Train.mat
% load Test.mat
Train.weekend = dummyvar(Train.weekend);
Train.month = dummyvar(Train.month);
Train = movevars(Train,{'weekend','month'},'After','demandLag');
Train.ts = [];
​
​
Train(1,:) =[];
y = Train.demand;
x = Train{:,2:5};
[xnorm,xopt] = mapminmax(x',0,1);
[ynorm,yopt] = mapminmax(y',0,1);
​
xnorm = xnorm(:,1:1000);
ynorm = ynorm(1:1000);
​
k = 24;           % 滞后长度
​
% 转换成2-D image
for i = 1:length(ynorm)-k
​
    Train_xNorm{:,i} = xnorm(:,i:i+k-1);
    Train_yNorm(i) = ynorm(i+k-1);
    Train_y{i} = y(i+k-1);
end
Train_x = Train_xNorm';
​
ytest = Train.demand(1001:1170);
xtest = Train{1001:1170,2:5};
[xtestnorm] = mapminmax('apply', xtest',xopt);
[ytestnorm] = mapminmax('apply',ytest',yopt);
% xtestnorm = [xtestnorm; Train.weekend(1001:1170,:)'; Train.month(1001:1170,:)'];
xtest = xtest';
for i = 1:length(ytestnorm)-k
    Test_xNorm{:,i} = xtestnorm(:,i:i+k-1);
    Test_yNorm(i) = ytestnorm(i+k-1);
    Test_y(i) = ytest(i+k-1);
end
Test_x = Test_xNorm';
x_train = table(Train_x,Train_y');
x_test = table(Test_x);
%% 训练集和验证集划分
% TrainSampleLength = length(Train_yNorm);
% validatasize = floor(TrainSampleLength * 0.1);
% Validata_xNorm = Train_xNorm(:,end - validatasize:end,:);
% Validata_yNorm = Train_yNorm(:,TrainSampleLength-validatasize:end);
% Validata_y = Train_y(TrainSampleLength-validatasize:end);
% 
% Train_xNorm = Train_xNorm(:,1:end-validatasize,:);
% Train_yNorm = Train_yNorm(:,1:end-validatasize);
% Train_y = Train_y(1:end-validatasize);
%% 构建残差神经网络
lgraph = layerGraph();
tempLayers = [
    imageInputLayer([4 24 1],"Name","imageinput")
    convolution2dLayer([3 3],32,"Name","conv","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    batchNormalizationLayer("Name","batchnorm")
    reluLayer("Name","relu")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    additionLayer(2,"Name","addition")
    convolution2dLayer([3 3],32,"Name","conv_1","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    batchNormalizationLayer("Name","batchnorm_1")
    reluLayer("Name","relu_1")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    additionLayer(2,"Name","addition_1")
    convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    batchNormalizationLayer("Name","batchnorm_2")
    reluLayer("Name","relu_2")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    additionLayer(2,"Name","addition_2")
    convolution2dLayer([3 3],32,"Name","conv_3","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    batchNormalizationLayer("Name","batchnorm_3")
    reluLayer("Name","relu_3")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
    additionLayer(2,"Name","addition_4")
    sigmoidLayer("Name","sigmoid")];
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = multiplicationLayer(2,"Name","multiplication");
lgraph = addLayers(lgraph,tempLayers);
​
tempLayers = [
    additionLayer(3,"Name","addition_3")
    fullyConnectedLayer(32,"Name","fc1")
    fullyConnectedLayer(16,"Name","fc2")
    fullyConnectedLayer(1,"Name","fc3")
    regressionLayer("Name","regressionoutput")];
lgraph = addLayers(lgraph,tempLayers);
​
% 清理辅助变量
clear tempLayers;
plot(lgraph);
analyzeNetwork(lgraph);
%% 设置网络参数
maxEpochs = 100;
miniBatchSize = 32;
options = trainingOptions('adam', ...
 'MaxEpochs',maxEpochs, ...
 'MiniBatchSize',miniBatchSize, ...
 'InitialLearnRate',0.005, ...
 'GradientThreshold',1, ...
 'Shuffle','never', ...
 'Plots','training-progress',...
 'Verbose',0);
​
net = trainNetwork(x_train,lgraph ,options);
​
Predict_yNorm = predict(net,x_test);
Predict_y = double(Predict_yNorm);
plot(Test_y)
hold on 
plot(Predict_y)
legend('真实值','预测值')
​

 训练迭代图:

试集预测曲线图

完整代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/14969.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

融云 CTO 岑裕:出海技术前沿探索和排「坑」实践

在本文中,你将看到以下内容: 全球通信网络在接入点、链路加速、服务商、协议等层面的动态演进; 进入到具体市场,禁运国、跨国拦截、区域一致性差等细节“坑点”如何应对; 融云如何从技术侧帮助开发者应对本地化用户体…

Hive与HBase的区别及应用场景

目录: 零、前言一、定义二、区别三、应用场景 零、前言 在学大数据分析的过程中,Hive和HBase是两个非常重要的内容,对于初学者而言容易混淆。所以比较两者区别,能够帮助我们对这两个组件有一个清晰的认识和定位。那么,…

一篇文章看懂MySQL的多表连接(包含左/右/全外连接)

MySQL的多表查询 这是第二次学习多表查询,关于左右连接还是不是很熟悉,因此重新看一下。小目标:一篇文章看懂多表查询!! 这篇博客是跟着宋红康老师学习的,点击此处查看视频,关于数据库我放在了…

大神们分享STM32的学习方法

单片机用处这么广,尤其是STM32生态这么火!如何快速上手学习呢? 第一:你要考虑的是,要用STM32实现什么 为什么使用STM32而不是8051? 是因为51的频率太低,无法满足计算需求?是51的管脚太少,无法…

云HIS(二级医院,乡镇医院,民营医院,标准化HIS医院信息管理系统源码)

传统 HIS(基于医院信息系统) 和云 HIS(基于云计算的医院信息系统)各有优缺点,选择哪种系统需要根据具体情况进行权衡。 传统 HIS 系统通常由医院自行开发和维护,适用于医院内部信息化程度较高、数据安全性…

【软件测试】第1章 软件测试概述

系列文章目录 文章目录 系列文章目录前言第1章 软件测试概述1.1 软件、软件危机和软件工程1.1.1 基本概念1.1.2 软件工程的目标及其一般开发过程1.1.3 软件过程模型 1.2 软件缺陷与软件故障1.2.1 基本概念1.2.2 典型案例 1.3 软件测试的概念1.3.1 软件测试的定义1.3.2 软件测试…

计算机程序安装及使用须知_kaic

安装及使用须知 1 数据库建模程序的使用 本文件夹中的“PowerDesigner建模”目录下包含三个可运行文件TMS1.cdm,TMS.cdm,TMS.pdm分别为TMS系统的实体关系简图、实体关系图和数据库模型,使用PowerDesigner集成开发环境打开任意一个文件即可运…

Linux系统与shell编程第一节课

目录 1.1 Linux发展历史 1.2 什么是linux? 1.3 Linux的发行版 Host-Only(仅主机模式) windows开发 linux服务 区块链, 特点:稳定,安全,可移植性,低资源消耗,开源软…

2023年第十二届数据技术嘉年华(DTC)资料分享

第十二届数据技术嘉年华(DTC 2023)已于4月8日在北京圆满落幕,大会围绕“开源融合数智化——引领数据技术发展,释放数据要素价值”这一主题,共设置有1场主论坛,12场专题论坛,68场主题演讲&#x…

【基础】Kafka -- 日志存储

Kafka -- 日志存储 日志文件目录日志索引偏移量索引时间戳索引 日志清理日志删除基于时间基于日志大小基于日志起始偏移量 日志压缩 日志文件目录 Kafka 中的消息以主题为单位进行基本归类,而每个主题又可以划分为一个或者多个分区。在不考虑多副本的情况下&#x…

【MySQL】插入文件路径,反斜杠消失

系列文章 C#底层库–MySQL脚本自动构建类(insert、update语句生成) 本文链接:https://blog.csdn.net/youcheng_ge/article/details/129179216 C#底层库–MySQL数据库访问操作辅助类(推荐阅读) 本文链接:h…

如何优雅的写个try catch的方式!

软件开发过程中,不可避免的是需要处理各种异常,就我自己来说,至少有一半以上的时间都是在处理各种异常情况,所以代码中就会出现大量的try {...} catch {...} finally {...} 代码块,不仅有大量的冗余代码,而…

07 【Sass语法介绍-控制指令】

1.前言 Sass 为我们提供了很多控制指令,使得我们可以更高效的来控制样式的输出,或者在函数中进行逻辑控制。本节内容我们就来讲解什么是 Sass 控制指令?它能用来做什么?它将使你更方便的编写 Sass 。 2.什么是 Sass 控制指令 控…

Dockere-Compose迁移Gitea部署

Dockere-Compose迁移Gitea部署 ps: 江湖不是打打杀杀,江湖是人情事故。 解释: Gitea:类似于Git的代码版本管理工具。Docker:Docker-Compose: Docker命令: 查看镜像:docker images 删除镜像…

2023年江苏专转本成绩查询步骤

2023年江苏专转本成绩查询时间 2023年江苏专转本成绩查询时间预计在5月初,参加考试的考生,可以关注考试院发布的消息。江苏专转本考生可在规定时间内在省教育考试院网,在查询中心页面中输入准考证号和身份证号进行查询,或者拨…

【u盘提示:驱动器未格式化】如何解决?

u盘虽然使用很方便,可随时拷贝资料到任何有电脑的地方,但它的问题也是比较多的,其中u盘提示驱动器未格式化故障最让人心虚,因为已经无法打开u盘了,里面的资料怎么办,很重要的怎么办,所以今天就教…

LSSANet:一种用于肺结节检测的长、短切片感知网络

文章目录 LSSANet: A Long Short Slice-Aware Network for Pulmonary Nodule Detection摘要方法Long Short Slice GroupingLong Short Slice-Aware Network 实验结果 LSSANet: A Long Short Slice-Aware Network for Pulmonary Nodule Detection 摘要 提出了一个长短片感知网…

《Spring MVC》 第六章 MVC类型转换器、格式化器

前言 介绍MVC类型转换器、格式化器 1、使用场景 <form th:action"{/user/register}" method"post">用户名&#xff1a;<input type"text" name"userName"/><br/>密码&#xff1a;<input type"password&q…

【Access】win 10 / win 11:Access 下载、安装、使用教程(「管理信息系统」实践专用软件)

目录 一、前言 二、卸载 Office 三、下载 Office Tool Plus 四、安装 Office&#xff08;内含 Access&#xff09; &#xff08;1&#xff09;启动 Office Tool Plus &#xff08;2&#xff09;部署 &#xff08;3&#xff09;安装 Office&#xff08;内含 Access&#…

【Arduino SD卡和数据记录教程】

【Arduino SD卡和数据记录教程】 1. 前言2. 工作原理3. Arduino SD 卡模块代码4. Arduino SD卡数据记录1. 前言 在本Arduino教程中,我们将学习如何将SD卡模块与Arduino板一起使用。此外,结合DS3231实时时钟模块,我们将制作一个数据记录示例,将温度传感器的数据存储到SD卡中…