SEnet注意力机制(逐行代码注释讲解)

目录

⒈结构图

⒉机制流程讲解

⒊源码(pytorch框架实现)及逐行解释

⒋测试结果


⒈结构图

左边是我自绘的,右下角是官方论文的。


⒉机制流程讲解

通道注意力机制的思想是,对于输入进来的特征层,我们在每一个通道学习不同的权重,这些权重与不同通道的特征相关,决定了每个通道在任务中的重要性。

对于SENet而言,它会对输入特征层进行这些操作:

①首先对输入特征层做了global average pooling,也就是全局平均池化,全局平均池化将对当前特征层取平均值,显然,高、宽分别为H、W的特征层经过平均池化操作后会得到一个实数,这个实数就是所有输入特征层的平均值;另外,平均池化并不影响通道数,因此,输入为C*H*W的特征经过平均池化后,H和W两个维度被压缩,就将得到只剩下C(也就是通道数)这一个维度的特征层。

②然后,对于平均池化输出的矩阵,进行两次全连接,第一次全连接和第二次是不完全相同的,区别在于:第一次全连接的通道数不完整,而是取原通道数的1/r,也就是这边的C/r,第二次则是用正常的通道数进行全连接。

这样做的目的是——能够减少通道个数从而降低计算量,并在一定程度上防止网络模型过拟合。(我在学习SEnet的结构时,看到第一次全连接减少通道数这个操作时,就有联想到神经网络的另一个trick,叫做dropout,dropout是一种正则化技巧,通过随机让神经网络中的部分神经元暂时失活,从而减少模型的过拟合风险,当时我以为SEnet的第一个全连接层就是运用了这个trick,但后来查阅资料时发现不是这样,dropout是随机减少全连接层中的部分神经元,而SEnet在这里是固定减少特征图的通道数,只能说有些异曲同工之妙吧),刚刚是在分享我学习过程遇到的小问题,现在说回正题,全连接1只取原通道数的1/r以此来减少计算量与防止过拟合,但是全连接2又用回原通道数——这样做是为了输出与原特征层相同的通道数,以便后续的最重要的reweight操作,也就是通过乘法逐通道加权到原先的输入特征层上。

值得注意的是,两个全连接层不是简单的直接相连,而是在全连接1后面经过一个relu激活函数,这是全连接层中很常规的操作,用来对一个全连接层的输出结果进行非线性变换,如果不这样做,所有的全连接层都只是普通的线性组合,这样训练出来的模型无法理解复杂的非线性数据和特征,可想而知这样的模型的检测效果肯定是很差的。

relu激活函数的公式其实很简单:f(x) = max(0, x),在x大于等于零时是线性函数,但当输入为负数时,输出为零,在负数部分截断了线性部分,将其映射到了一个确定的点上,从而实现了非线性变换。

自绘烂图,将就看。

③再然后,需要对全连接2的输出结果映射到sigmoid函数中,sigmoid是很经典的激活函数,它的值域是0到1,画一下函数图像(显然x=0时函数值等于0.5)……然后,它的定义域是整个实数集,值域是0到1,也就是说,全连接2的输出结果映射到sigmoid函数中后,就将得到一组0到1之间的值(因此称此操作为归一化),也就是所谓的不同通道的权重。

公式:

自绘烂图,我真的尽力画了/(ㄒoㄒ)/~~

最后最后,将这组通道权重与原输入2特征层通过乘法逐通道加权,就实现了“增强重要的通道,抑制不重要的通道”,也就是所谓的通道注意力机制

⒊源码(pytorch框架实现)及逐行解释

import torch
from torch import nn
from torchsummary import summary
 
 
class SEAttention(nn.Module):
    def __init__(self, inputs, ratio=4):
        super(SEAttention, self).__init__()  # 调用父类构造方法
        _, c, _, _ = inputs.size()# NCHW
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear1 = nn.Linear(c, c // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(c // ratio, c, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        n, c, _, _ = inputs.size()
        x = self.avgpool(inputs).view(n, c)#nchw,池化加reshape压缩维度
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        x = x.view(n, c, 1, 1) #reshape还原维度
        return inputs * x
 
 
#这边是测试代码,用summary类总结网络模型层
inputs = torch.randn(32, 512, 26, 26)  # NCHW
my_model = SEAttention(inputs)
outputs = my_model(inputs)
summary(my_model.cuda(), input_size=(512, 26, 26))

 解释:

①依赖包为torch,以及torch里的nn模块(导入这个纯粹是省得还要用torch.nn去调用nn的类或方法),summary类是用来测试的,需要提前下载,命令为->pip install torchsummary

②从整体来看,我们运用封装思想将整个模块封装为类,且这个类继承于nn.Moudule这个类,这个类共两部分,

__init__函数用来对实例化对象进行初始化,在python中这个函数属于类的魔术方法。

#代码逐行解释:
def __init__(self, inputs, ratio=4):#self必须写,inputs接收输入张量,ratio是通道衰减因子
        super(SEAttention, self).__init__()  # super关键字调用父类(即nn.Moudule类)的构造方法
        _, c, _, _ = inputs.size()#获取张量的形状(即NCHW),该模块只关注参数C,其余用占位符忽略
        self.avgpool = nn.AdaptiveAvgPool2d(1)#nn模块的自适应二维平均池化,参数1等同于全局平均池化
        self.linear1 = nn.Linear(c, c // ratio, bias=False)#nn模块的全连接,这里输入c,输出c//ratio,bias是偏置参数,网络层是否有偏置,默认存在,若bias=False,则该网络层无偏置,图层不会学习附加偏差
        self.relu = nn.ReLU(inplace=True)#nn模块的ReLU激活函数,inplace=True表示要用引用传递(即地址传递),估计可以减少张量的内存占用(因为值传递要拷贝一份)
        self.linear2 = nn.Linear(c // ratio, c, bias=False)#同全连接1,但输入输出相反
        self.sigmoid = nn.Sigmoid()#nn模块的Sigmoid函数

forward函数进行前向传播,用初始化好的网络模型对输入特征层进行一系列加工。

#代码逐行解释:
def forward(self, inputs):#self必须写,inputs接收输入特征张量
        n, c, _, _ = inputs.size()#获取张量形状(即NCHW),HW被忽略
        x = self.avgpool(inputs).view(n, c)#nchw,池化加view方法重塑(reshape)张量形状,因为全连接层之间的张量必须是二维的(一个输入维度一个输出维度),view的参数是(n,c)表示只保留这两个维度
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)#上面这四行直接调用初始化好的网络层即可
        x = x.view(n, c, 1, 1) #reshape还原维度,因为要和原输入特征相乘,不重塑形状不同无法相乘
        return inputs * x#和原输入特征层相乘

⒋测试结果

感觉summary类没有很好使。。。有些关键网络层的变换没有体现出来,这里是少了最后reshape的一层,但无伤大雅罢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/163111.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于STM32的多组外部中断(EXTI)的优化策略与应用

在某些嵌入式应用中,可能需要同时处理多个外部中断事件。STM32系列微控制器提供了多组外部中断线(EXTI Line),可以同时配置和使用多个GPIO引脚作为外部中断触发器。为了有效管理和处理多组外部中断,我们可以采取一些优…

Linux非阻塞等待示例

Linux非阻塞等待实例 非阻塞等待的意义:简单的多进程编程示例代码解释 非阻塞等待的意义: 非阻塞等待在多进程编程中的意义主要体现在提高系统的响应性、实现异步任务执行、动态任务管理和多任务协同工作等方面。它允许父进程在等待子进程退出的同时&…

优化|优化求解器自动调参

原文信息:MindOpt Tuner: Boost the Performance of Numerical Software by Automatic Parameter Tuning 作者:王孟昌 (达摩院决策智能实验室MindOpt团队成员) 一个算法开发者,可能会幻想进入这样的境界:算…

C++题目练习第二十天__有效的完全平方数

题目链接: 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 题目: 给你一个正整数 num 。如果 num 是一个完全平方数,则返回 true ,否则返回 false 。 完全平方数 是一个可以写成某个整数的平方的整数。…

Android 弹出自定义对话框

Android在任意Activity界面弹出一个自定义的对话框,效果如下图所示: 准备一张小图片,右上角的小X图标64*64,close_icon.png,随便找个小图片代替; 第一步:样式添加,注意:默认在value…

2023年中职“网络安全“—Web 渗透测试②

2023年中职“网络安全“—Web 渗透测试② Web 渗透测试任务环境说明:1.访问http://靶机IP/web1/,获取flag值,Flag格式为flag{xxx};2.访问http://靶机IP/web2/,获取flag值,Flag格式为flag{xxx};3.访问http://靶机IP/web…

git拉取普通idea Java项目module没有build的问题

在不断完成一个项目的时候,会有不断新加的module,我们用git拉取时会发生没有识别新module的情况。 解决方法是右键项目名称,然后点击Open Module Settings 接下来,点击Module,加号,新建Module的名字就是在g…

深度学习乳腺癌分类 计算机竞赛

文章目录 1 前言2 前言3 数据集3.1 良性样本3.2 病变样本 4 开发环境5 代码实现5.1 实现流程5.2 部分代码实现5.2.1 导入库5.2.2 图像加载5.2.3 标记5.2.4 分组5.2.5 构建模型训练 6 分析指标6.1 精度,召回率和F1度量6.2 混淆矩阵 7 结果和结论8 最后 1 前言 &…

sqli-labs关卡18(基于http头部报错盲注)通关思路

文章目录 前言一、靶场通关需要了解的知识点1、什么是http请求头2、为什么http头部可以进行注入 二、靶场第十八关通关思路1、判断注入点2、爆数据库名3、爆数据库表4、爆数据库列5、爆数据库关键信息 总结 前言 此文章只用于学习和反思巩固sql注入知识,禁止用于做…

若依框架数据源切换为pg库

一 切换数据源 在ruoyi-admin项目里引入pg数据库驱动 <dependency><groupId>org.postgresql</groupId><artifactId>postgresql</artifactId><version>42.2.18</version> </dependency>修改配置文件里的数据源为pg spring:d…

测不准原理

测不准原理 算符的对易关系 commutation relation 测不准原理的矢量推导 Schwarz inequality: 设对易关系&#xff1a; 设一个新态&#xff1a; 投影&#xff1a; 那么有&#xff1a; 代回Schwarz inequality 即可证明&#xff1a;

2023OceanBase年度发布会后,有感

很荣幸收到了OceanBase邀请&#xff0c;于本周四&#xff08;11月16日&#xff09;参加了OceanBase年度发布会并参加了DBA老友会&#xff0c;按照理论应该我昨天&#xff08;星期五&#xff09;就回到成都了&#xff0c;最迟今天白天就该把文章写出来了&#xff0c;奈何媳妇儿买…

zsh和ohmyzsh安装指南+插件推荐

文章目录 1. 安装指南2. 插件配置指南3. 参考信息 1. 安装指南 1. 安装 zsh sudo apt install zsh2. 安装 Oh My Zsh 国内访问GitHub sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"这将安装 Oh My Zsh 和所…

75基于matlab的模拟退火算法优化TSP(SA-TSP),最优路径动态寻优,输出最优路径值、路径曲线、迭代曲线。

基于matlab的模拟退火算法优化TSP(SA-TSP)&#xff0c;最优路径动态寻优&#xff0c;输出最优路径值、路径曲线、迭代曲线。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 75matlab模拟退火算法TSP问题 (xiaohongshu.com)

联想系列台式机Win11系统改Win7系统BIOS设置步骤

联想最新一代的台式机默认操作系统Win11&#xff0c;采用UEFIGPT启动模式&#xff0c;并且开启了安全启动功能&#xff0c;一般用户不能直接将Win11改成Win7&#xff0c;如果需要更改操作系统&#xff0c;是需要再BIOS菜单中关闭安全启动功能的&#xff0c;并且把启动模式设置成…

Linux CentOS7 添加网卡

一台主机中安装多块网卡&#xff0c;有许多优势。可以实现多项功能。 为了学习网卡参数的设置&#xff0c;可以为主机添加多块网卡。与添加磁盘一样&#xff0c;要在VMware中设置。利用图形化方式或命令行查看或设置网卡。本文仅作一初步讨论。有关网络参数的设置不在讨论之列…

【Web】Ctfshow SSRF刷题记录1

核心代码解读 <?php $url$_POST[url]; $chcurl_init($url); curl_setopt($ch, CURLOPT_HEADER, 0); curl_setopt($ch, CURLOPT_RETURNTRANSFER, 1); $resultcurl_exec($ch); curl_close($ch); ?> curl_init()&#xff1a;初始curl会话 curl_setopt()&#xff1a;会…

C语言进阶第十课 --------文件的操作

作者前言 &#x1f382; ✨✨✨✨✨✨&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f382; ​&#x1f382; 作者介绍&#xff1a; &#x1f382;&#x1f382; &#x1f382; &#x1f389;&#x1f389;&#x1f389…

HarmonyOS开发Java与ArkTS如何抉择

在“鸿蒙系统实战短视频App 从0到1掌握HarmonyOS”视频课程中&#xff0c;很多学员来问我&#xff0c;在HarmonyOS开发过程中&#xff0c;面对Java与ArkTS&#xff0c;应该选哪样&#xff1f; 本文详细分析Java与ArkTS在HarmonyOS开发过程的区别&#xff0c;力求解答学员的一些…

【我和Python算法的初相遇】——体验递归的可视化篇

&#x1f308;个人主页: Aileen_0v0 &#x1f525;系列专栏:PYTHON数据结构与算法学习系列专栏&#x1f4ab;"没有罗马,那就自己创造罗马~" 目录 递归的起源 什么是递归? 利用递归解决列表求和问题 递归三定律 递归应用-整数转换为任意进制数 递归可视化 画…