【深度学习注意力机制系列】—— SKNet注意力机制(附pytorch实现)

SKNet(Selective Kernel Network)是一种用于图像分类和目标检测任务的深度神经网络架构,其核心创新是引入了选择性的多尺度卷积核(Selective Kernel)以及一种新颖的注意力机制,从而在不增加网络复杂性的情况下提升了特征提取的能力。SKNet的设计旨在解决多尺度信息融合的问题,使网络能够适应不同尺度的特征。

1. 核心思想

SKNet的核心思想是**通过选择性地应用不同尺度的卷积核,从而在不同层级上捕捉多尺度特征。**为了实现这一点,SKNet引入了一个选择模块,用于自适应地决定在每个通道上使用哪些尺度的卷积核。这种选择性的多尺度卷积核有助于提升特征表示的能力,使网络更具适应性和泛化能力。

2. 结构

SKNet的结构如下:

在这里插入图片描述

实现机制:

  • split:对特征图进行多分支分离卷积,各分支使用不同的卷积核(感受野不同)进行特征提取。(并未对原始特征图进行拆解分离,只是使用不同的卷积核对原始特征图进行卷积操作)。假设分支为n,则特征图维度变换为 (c, h, w) -> (n, c, h, w),原文中n=2。

  • Fuse:将多个分支的特征图提取结果相加。特征图维度变换为 (n, c, h, w) -> (c, h, w)。再通过全局平均池,特征图维度变换为 (c, h, w) -> (c, 1, 1),然后利用全连接层进行降维(限制了最低维度,通过全连接层生成d×1的向量(图中的z),公式如图中所示(δ表示ReLU激活函数,B表示Batch Noramlization,W是一个d×C的维的)。d的取值是由公式d = max(C/r,L)确定,r是一个缩小的比率(与SENet中相似),L表示d的最小值,原文实验中L的值为32。),再利用两个(或多个,和分支数目相同,原论文中为两个)全连接层进行升维,得到两个(多个)维度同降维前相同的特征图(向量)。在对两个特征向量进行softmax处理。假设分支为n,则特征图维度为 n个(c, 1, 1) ,原文中n=2,即a->(c, 1, 1), b->(c, 1, 1)。

  • select:利用softmax处理后的多个特征向量分别乘以第一步中的多分支提取的特征图结果。特征维度变化为n个(c, 1 ,1) * n 个(c, h ,w) = (n, c, h, w)。最后将n个特征图进行相加。

3. 优势

SKNet的设计在以下几个方面具有优势:

  • 多尺度信息融合

通过选择性地应用不同尺度的卷积核,SKNet能够有效地融合多尺度的特征信息。这有助于网络捕捉不同层次的视觉特征,提高了特征的表征能力。

  • 自适应性

选择模块使网络能够自适应地选择卷积核的尺度,从而适应不同任务和图像的特点。这种自适应性能够使网络在各种场景下都能表现出色。

  • 减少计算成本

尽管引入了多尺度卷积核,但由于选择模块的存在,SKNet只会选择一部分卷积核进行计算,从而减少了计算成本,保持了网络的高效性。

4.代码实现

class SKNet(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32):
        """
        :param in_channels:  输入通道维度
        :param out_channels: 输出通道维度   原论文中 输入输出通道维度相同
        :param stride:  步长,默认为1
        :param M:  分支数
        :param r: 特征Z的长度,计算其维度d 时所需的比率(论文中 特征S->Z 是降维,故需要规定 降维的下界)
        :param L:  论文中规定特征Z的下界,默认为32
        采用分组卷积: groups = 32,所以输入channel的数值必须是group的整数倍
        """
        super(SKNet, self).__init__()
        d = max(in_channels // r, L)  
        self.M = M
        self.out_channels = out_channels
        self.conv = nn.ModuleList() 
        for i in range(M):
            self.conv.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride, padding=1 + i, dilation=1 + i, groups=32, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)))
        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1) 
        self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False),
                                 nn.BatchNorm2d(d),
                                 nn.ReLU(inplace=True))  # 降维
        self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False)  
        self.softmax = nn.Softmax(dim=1) 
    def forward(self, input):
        batch_size = input.size(0)
        output = []
        for i, conv in enumerate(self.conv):
            output.append(conv(input))
        U = reduce(lambda x, y: x + y, output)  
        s = self.global_pool(U)  
        z = self.fc1(s)
        a_b = self.fc2(z) 
        a_b = a_b.reshape(batch_size, self.M, self.out_channels, -1) 
        a_b = self.softmax(a_b) 
        a_b = list(a_b.chunk(self.M, dim=1))  
        a_b = list(map(lambda x: x.reshape(batch_size, self.out_channels, 1, 1),
                       a_b))  
        V = list(map(lambda x, y: x * y, output,
                     a_b))  
        V = reduce(lambda x, y: x + y,
                   V)  
        return V

总结

SKNet是一种创新的深度神经网络架构,通过引入选择性的多尺度卷积核和注意力机制,提升了特征提取的能力。其核心结构包括选择模块和SK卷积层,能够有效地融合多尺度信息、自适应地调整卷积核的尺度,并减少计算成本。这使得SKNet在图像分类和目标检测等任务中取得了优越的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/67338.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

opencv基础53-图像轮廓06-判断像素点与轮廓的关系(轮廓内,轮廓上,轮廓外)cv2.pointPolygonTest()

点到轮廓的距离 在 OpenCV 中,函数 cv2.pointPolygonTest()被用来计算点到多边形(轮廓)的最短距离(也 就是垂线距离),这个计算过程又称点和多边形的关系测试。该函数的语法格式为: retval cv2…

Linux —— 基础I/O

一,背景介绍 狭义的文件存放在磁盘上,广义上在Linux下一切皆文件;磁盘上的文件一般为永久存储的外设,本质上对文件的操作,即为对外设的输入和输出(简称I/O);空文件并不是不占磁盘文件…

uC-OS2 V2.93 STM32L476 移植:串口打印篇

前言 前几篇已经 通过 STM32CubeMX 搭建了 NUCLEO-L476RG 的 STM32L476RG 的 裸机工程,下载了 uC-OS2 V2.93 的源码,并把 uC-OS2 的源文件加入 Keil MDK5 工程,通过适配 Systick 系统定时器与 PendSV 实现任务调度,初步让 uC-OS2 …

Facebook商城号最全解析,Facebook Marketplace运营技巧

相信许多做跨境的小伙伴都听说过Facebook商城号。其实所谓的商城并不是Facebook Shop,而是指Facebook Marketplace,它不像前者需要各类入驻条件,只要拥有facebook账号而且所在地区有能够使用marketplace就可以在上面进行商品售卖。这种低成本…

validator入门

validator中文文档地址和英文地址 https://docs.jboss.org/hibernate/validator/4.2/reference/zh-CN/html/validator-gettingstarted.html https://docs.jboss.org/hibernate/validator/6.0/reference/en-US/html_single/#preface自定义hibernate-validator校验 工具类Valid…

设计模式(4)装饰模式

一、介绍: 1、应用场景:把所需的功能按正确的顺序串联起来进行控制。动态地给一个对象添加一些额外的职责,就增加功能来说,装饰模式比生成子类更加灵活。 当需要给一个现有类添加附加职责,而又不能采用生成子类的方法…

谷歌关闭跨域限制.(生成一个开发浏览器),Chrome关闭跨域

(一)、首先找到浏览器在电脑磁盘中的位置,并复制 (二)、复制一个浏览器的快捷方式到桌面(不影响正常浏览器) (三)、chrom鼠标右键属性,修改快捷方式的目标 (四)chrome.exe 后面添加 --disable-web-security --user-data-dir 复制的Chrome浏览…

力扣hot100刷题记录

二刷hot100&#xff0c;坚持每天打卡&#xff01;&#xff01;&#xff01;Today&#xff1a;2023-8-10 1. 两数之和 // 先求差&#xff0c;再查哈希表 public int[] twoSum(int[] nums, int target) {Map<Integer,Integer> map new HashMap<>();for(int i 0;i&…

窥探系列之Mybatis-plus XML分页查询

mybatisPlus分页查总数 Page类在mybatisPlus中用于分页查询&#xff0c;继承Pagination类&#xff0c;Pagination类的searchCount字段控制是否查询总记录数 顺着看哪里用到了searchCount&#xff1a; com.baomidou.mybatisplus.plugins.PaginationInterceptor 是mybatisPlus…

安防监控视频汇聚EasyCVR平台的FLV视频流在VLC中无法播放的原因排查

众所周知&#xff0c;TSINGSEE青犀视频汇聚平台EasyCVR可支持多协议方式接入&#xff0c;包括主流标准协议国标GB28181、RTSP/Onvif、RTMP等&#xff0c;以及厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等。在视频流的处理与分发上&#xff0c;视频监控…

PyTorch 微调终极指南:第 2 部分 — 提高模型准确性

一、说明 如今&#xff0c;在训练深度学习模型时&#xff0c;通过在自己的数据上微调预训练模型来迁移学习已成为首选方法。通过微调这些模型&#xff0c;我们可以利用他们的专业知识并使其适应我们的特定任务&#xff0c;从而节省宝贵的时间和计算资源。本文分为四个部分&…

如何使用 AT+WEBSERVER 指令实现自定义的 Webserver html 网页配网

开启 AT 固件中的 Webserver 指令和 FS 指令支持 乐鑫官网发布的默认通用 AT 固件不支持 webserver 配网功能&#xff0c; 需要用户自己搭建 esp-at 环境&#xff0c;并在 sdkconfig 中开启 webserver AT 指令 和 FS 指令的支持&#xff0c; 如下图所示&#xff1a; 测试 AT 固…

Elasticsearch 与 OpenSearch:揭开性能差距

作者&#xff1a;George Kobar, Ugo Sangiorgi 对于任何依赖快速、准确搜索数据的组织来说&#xff0c;强大、快速且高效的搜索引擎是至关重要的元素。 对于开发人员和架构师来说&#xff0c;选择正确的搜索平台可以极大地影响你的组织提供快速且相关结果的能力。 在我们全面的…

oracle积累增量和差异增量

积累增量和差异增量&#xff1a; 对于 RMAN 来说&#xff0c;积累增量备份和差异增量备份都是增量备份的一种形式&#xff0c;它们之间的区别在于备份的范围和备份集的方式。 积累增量备份&#xff1a;在进行积累增量备份时&#xff0c;RMAN 会备份自最后一次完全备份或增量备…

FPGA外部触发信号毛刺产生及滤波

1、背景 最近在某个项目中&#xff0c;遇到输入给FPGA管脚的外部触发信号因为有毛刺产生&#xff0c;导致FPGA接收到的外部触发信号数量多于实际值。比如&#xff1a;用某个信号源产生1000个外部触发信号&#xff08;上升沿触发方式&#xff09;给到FPGA输入IO&#xff0c;实际…

PCI 简易通讯控制器有黄色感叹号

一、问题描述 设备管理器中&#xff0c;其他设备中显示 “PCI 简易通讯控制器”驱动未安装&#xff0c;显示黄色感叹号&#xff1a; 二、原因分析 右键该驱动&#xff0c;查看属性ID&#xff0c;显示为&#xff1a; PCI \ VEN_8086&#xff06;DEV_1C3A&#xff06;SUBSYS…

14.3.4 【Linux】使用 LVM thin Volume 让 LVM 动态自动调整磁盘使用率

想像一个情况&#xff0c;你有个目录未来会使用到大约 5T 的容量&#xff0c;但是目前你的磁盘仅有 3T&#xff0c;问题是&#xff0c;接下来的两个月你的系统都还不会超过 3T 的容量&#xff0c; 不过你想要让用户知道&#xff0c;就是他最多有 5T 可以使用就是了&#xff01;…

ECharts 折线图使用相关

一、折线图堆叠设置为不堆叠的方法 官网是这样的&#xff0c;但是不需要这种堆叠形式的如下图&#xff1a; 即&#xff1a;第2条数据值 第1条数据值 第2条数据值 第3条数据值 第2条数据值 第3条数据值 需要改成实际值展示&#xff0c;如下图&#xff1a; 只需要修改stack的…

微信小程序 地图map(电子围栏圆形和多边形)

正常情况下是没有手机上画电子围栏的&#xff0c;公共平台上我也没找到&#xff0c;所以走了一个歪点子&#xff0c;就是给地图添加点击事件&#xff0c;记录点的位置&#xff0c;在画到电子围栏上就是添加电子围栏了&#xff0c;如果只是显示电子围栏就简单了 一、多边形电子…

mybatis-flex探索

mybatis古今未来 最近无意之中发现了一个非常棒的持久层框架mybatis-flex&#xff0c;迫不及待研究了一下 发现简直就是我的梦中情框&#xff0c;之前写ibatis&#xff0c;后来写mybatis&#xff0c;接着写mybatis-plus&#xff0c;接着研究mybatis-flex ibatis ibatis是apa…