伪装目标检测之注意力CBAM:《Convolutional Block Attention Module》

论文地址:link
代码:link

摘要

我们提出了卷积块注意力模块(CBAM),这是一种简单而有效的用于前馈卷积神经网络的注意力模块。给定一个中间特征图,我们的模块依次推断沿着两个独立维度的注意力图,通道和空间,然后将这些注意力图与输入特征图相乘,进行自适应特征细化。由于CBAM是一个轻量级和通用的模块,它可以无缝地集成到任何CNN架构中,几乎没有额外开销,并且可以与基础CNN一起端到端地进行训练。我们通过在ImageNet-1K、MS COCO检测和VOC 2007检测数据集上进行大量实验来验证我们的CBAM。 我们的实验证明,在各种模型中,分类和检测性能均有一致的提升,展示了CBAM的广泛适用性。

1.介绍

卷积神经网络(CNN)凭借其丰富的表示能力,显著提升了视觉任务的性能[1,2,3]。为了增强CNN的性能,最近的研究主要探讨了网络的三个重要因素:深度、宽度和基数。。除了这些因素,我们研究了架构设计的另一个方面,即注意力。注意力的重要性在前的文献中得到了广泛研究[12,13,14,15,16,17]。注意力不仅告诉我们要关注哪里,还改善了感兴趣的表示。我们的目标是通过使用注意力机制增强表示能力:专注于重要特征并抑制不必要的特征。在本文中,我们提出了一个新的网络模块,名为“卷积块注意力模块”。由于卷积操作通过将跨通道和空间信息混合在一起提取信息特征,我们采用我们的模块来强调这两个主要维度上的有意义特征:通道和空间轴。为了实现这一点,我们依次应用通道和空间注意力模块(如图1所示),以便每个分支可以分别学习在通道和空间轴上关注“什么”和“哪里”。
主要贡献
1.我们提出了一个简单而有效的注意力模块(CBAM),可以广泛应用于提升CNN的表示能力。
2.我们通过广泛的消融研究验证了我们的注意力模块的有效性。
3.我们验证了在多个基准测试(ImageNet-1K、MS COCO和VOC 2007)上,通过插入我们的轻量级模块,各种网络的性能得到了显著提升。

2.卷积块注意模块

在这里插入图片描述
给定一个中间特征图 F ∈ R C × H × W F \in {R^{C \times H \times W}} FRC×H×W作为输入,CBAM 依次推断出一个1D通道注意力图 M c ∈ R C × 1 × 1 {M_c} \in {R^{C \times 1 \times 1}} McRC×1×1和一个 2D 空间注意力图 M s ∈ R 1 × H × W {M_s} \in {R^{1 \times H \times W}} MsR1×H×W,如图1所示。总体注意力过程可以总结为:
在这里插入图片描述
在这里,符号 ⊗ 表示逐元素乘法。在乘法过程中,注意力值相应地进行广播(复制):通道注意力值沿空间维度广播,反之亦然。F ′′ 是最终的精炼输出。图2描述了每个注意力图的计算过程。以下描述了每个注意力模块的细节。

2.1 Channel attention module

利用特征的通道间关系来生成通道注意力图,特征图的每个通道都被视为特征检测器,通道注意力集中在给定输入图像的情况下“什么”是有意义的,为了有效地计算通道注意力,压缩输入特征图的空间维度,使用平均池化和最大池化特征。利用这两个特征可以极大地提高网络的表示能力,而不是单独使用每个特征。
首先利用平均池化和最大池化操作来聚合特征图的空间信息,生成两个不同的空间上下文描述符, F a v g c F_{avg}^c Favgc F m a x c F_{max}^c Fmaxc,分别表示平均池化特征和最大池化特征。然后,这两个描述符都被转发到共享网络以生成我们的通道注意力图 M c ∈ R C × 1 × 1 {M_c} \in {R^{C \times 1 \times 1}} McRC×1×1 。共享网络由具有一个隐藏层的多层感知器(MLP)组成。为了减少参数开销,隐藏激活大小设置为 R C / r × 1 × 1 R ^{C/r×1×1} RC/r×1×1 ,其中 r 是缩减比率。将共享网络应用于每个描述符后,我们使用逐元素求和来合并输出特征向量。简而言之,通道注意力计算如下:
在这里插入图片描述
其中 σ 表示 sigmoid 函数,W0 ∈ R C / r × C R^{C/r×C} RC/r×C ,W1 ∈ R C × C / r R^{C×C/r} RC×C/r 。请注意,两个输入共享 MLP 权重 W 0 W_0 W0 W 1 W_1 W1,并且 ReLU 激活函数后面跟着 W 0 W_0 W0

2.2 Spatial attention module

利用特征的空间关系生成空间注​​意力图。与通道注意力不同,空间注意力关注的是“哪里”,这是信息性的部分,与通道注意力是互补的。为了计算空间注意力,我们首先沿着通道轴应用平均池化和最大池化操作并将它们连接起来以生成有效的特征描述符。沿着通道轴应用池化操作被证明可以有效地突出显示信息区域。在级联特征描述符上,我们应用卷积层来生成空间注​​意力图 M s ( F ) ∈ R H × W {M_s}\left( F \right) \in {R^{H \times W}} Ms(F)RH×W,它对强调或抑制的位置进行编码。下面我们描述详细操作。通过使用两个池化操作来聚合特征图的通道信息,生成两个2D图: F a v g s ∈ R 1 × H × W F_{avg}^s \in {R^{1 \times H \times W}} FavgsR1×H×W F m a x s ∈ R 1 × H × W F_{max}^s \in {R^{1 \times H \times W}} FmaxsR1×H×W。每个表示整个通道的平均池化特征和最大池化特征。然后,它们通过标准卷积层连接和卷积,生成我们的 2D 空间注意力图。简而言之,空间注意力计算如下:
在这里插入图片描述
其中σ表示sigmoid函数,f 7×7表示滤波器尺寸为7×7的卷积运算。

注意力模块的安排。给定输入图像,两个注意力模块(通道和空间)计算互补注意力,分别关注“什么”和“哪里”。考虑到这一点,两个模块可以并行或顺序放置。我们发现顺序排列比并行排列提供更好的结果。对于顺序过程的安排,我们的实验结果表明,通道优先顺序略好于空间优先顺序。
代码实现:

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/488150.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt实现简易的多线程TCP服务器(支持多个客户端连接)附源码

目录 一.UI界面的设计 二.服务器的启动 三.实现自定义的TcpServer类 1.在widget中声明自定义TcpServer类的成员变量 2.在TcpServer的构造函数中对于我们声明的m_widget进行初始化,m_widget我们用于后续的显示消息等,说白了就是主界面的更新显示等 …

为何ChatGPT日耗电超50万度?

看新闻说,ChatGPT每天的耗电量是50万度,国内每个家庭日均的耗电量不到10度,ChatGPT耗电相当于国内5万个家庭用量。 网上流传,英伟达创始人黄仁勋说:“AI的尽头是光伏和储能”,大佬的眼光就是毒辣&#xff…

使用LLaVA模型实现以文搜图和以图搜图

本文将会详细介绍如何使用多模态模型——LLaVA模型来实现以文搜图和以图搜图的功能。本文仅为示例Demo,并不能代表实际的以文搜图和以图搜图的技术实现方案。 1、实现原理 使用多模态模型获取图片的标题和详细描述以文搜图功能:使用ES实现查询匹配&…

深入了解 Linux 中的 MTD 设备:/dev/mtd* 与 /dev/mtdblock*

目录 前言一、什么是MTD子系统?二、 /dev/mtd* 设备文件用途注意事项 三、/dev/mtdblock* 设备文件用途注意事项 三、这两种设备文件的关系四、关norflash的一些小知识 前言 在嵌入式Linux系统的世界里,非易失性存储技术扮演着至关重要的角色。MTD&#…

面试知识汇总——垃圾回收器(分代收集算法)

分代收集算法 根据对象的存活周期,把内存分成多个区域,不同区域使用不同的回收算法回收对象。 对象在创建的时候,会先存放到伊甸园。当伊甸园满了之后,就会触发垃圾回收。 这个回收的过程是:把伊甸园中的对象拷贝到F…

初识redis(一)

前言 引用的是这本书的原话 Redis[1]是一种基于键值对(key-value)的NoSQL数据库,与很多键值对数据库不同的是,Redis中的值可以是由string(字符串)、hash(哈希)、list(列…

Android15功能和 API 概览

Android 15 面向开发者引入了一些出色的新功能和 API。以下部分总结了这些功能,以帮助您开始使用相关 API。 如需查看新增、修改和移除的 API 的详细列表,请参阅 API 差异报告。如需详细了解新的 API,请访问 Android API 参考文档&#xff0…

Selenium 自动化 —— 定位页面元素

更多内容请关注我的 Selenium 自动化 专栏: 入门和 Hello World 实例使用WebDriverManager自动下载驱动Selenium IDE录制、回放、导出Java源码浏览器窗口操作切换浏览器窗口 使用 Selenium 做自动化,我们不仅仅是打开一个网页,这只是万里长…

Stable Diffusion 进阶教程 - 二次开发(制作您的文生图应用)

目录 1. 引言 2. 基于Rest API 开发 2.1 前置条件 2.2 代码实现 2.3 效果演示 2.4 常见错误 3. 总结 1. 引言 Stable Diffusion作为一种强大的文本到图像生成模型,已经在艺术、设计和创意领域引起了广泛的关注和应用。然而,对于许多开发者来说&#xff…

时序预测 | Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测

时序预测 | Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测 目录 时序预测 | Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测(完整源码和数据…

DRC检查及丝印的调整

DRC检查及丝印的调整 综述:本文主要讲述AD软件中DRC检查、丝印的调整以及logo的添加的相关步骤,附加logo添加的脚本链接和大量操作图片,使步骤详细直观。 1. 点击“工具”→“设计规则检查”→“运行DRC”。(一开始可以只开启电…

利用云手机技术,开拓海外社交市场

近年来,随着科技的不断进步,云手机技术逐渐在海外社交营销领域崭露头角。其灵活性、成本效益和全球性特征使其成为海外社交营销的利器。那么,究竟云手机在海外社交营销中扮演了怎样的角色呢? 首先,云手机技术能够消除地…

LLM - 大语言模型的指令微调(Instruction Tuning) 概述

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/137009993 大语言模型的指令微调(Instruction Tuning)是一种优化技术,通过在特定的数据集上进一步训练大型语言模型(LLMs)&a…

javaWeb个人日记(博客)管理系统

一、简介 在快节奏的生活中,记录生活点滴、感悟和思考是一种重要的方式。基于此,我设计了一个基于JavaWeb的个人日记本系统,旨在帮助用户轻松记录并管理自己的日记。该系统包括登录、首页、日记列表、写日记、日记分类管理和个人中心等功能&…

mysql - 缓存

缓存 InnoDB存储引擎在处理客户端的请求时,当需要访问某个页的数据时,就会把完整的页的数据全部加载到内存中,也就是说即使我们只需要访问一个页的一条记录,那也需要先把整个页的数据加载到内存中。将整个页加载到内存中后就可以…

命令模式(请求与具体实现解耦)

目录 前言 UML plantuml 类图 实战代码 模板 Command Invoker Receiver Client 前言 命令模式解耦了命令请求者(Invoker)和命令执行者(receiver),使得 Invoker 不再直接引用 receiver,而是依赖于…

Java基础--128陷阱

问题引入 Integer a 123; Integer b 123; System.out.println(ab); 结果为true。 但是如果代码如下 Integer a 1230;Integer b 1230;System.out.println(ab); 这个的结果就是false。 问题解决 当Integer a 123时,其实他底层自动转换成了Integer a Inte…

Learn OpenGL 29 延迟着色法

延迟着色法 我们现在一直使用的光照方式叫做正向渲染(Forward Rendering)或者正向着色法(Forward Shading),它是我们渲染物体的一种非常直接的方式,在场景中我们根据所有光源照亮一个物体,之后再渲染下一个物体,以此类推。它非常…

网络安全-文件包含

一、php://input 我们先来看一个简单的代码 <meta charset"utf8"> <?php error_reporting(0); $file $_GET["file"]; if(stristr($file,"php://filter") || stristr($file,"zip://") || stristr($file,"phar://&quo…

Windows如何搭建 ElasticSearch 集群

单机 & 集群 单台 Elasticsearch 服务器提供服务&#xff0c;往往都有最大的负载能力&#xff0c;超过这个阈值&#xff0c;服务器 性能就会大大降低甚至不可用&#xff0c;所以生产环境中&#xff0c;一般都是运行在指定服务器集群中。 除了负载能力&#xff0c;单点服务器…