CBAM注意力机制详解与实现

前言:

在深度学习领域,注意力机制已成为提升模型性能的重要手段之一。CBAM(Convolutional Block Attention Module)作为一种轻量级且高效的注意力机制,被广泛应用于各种卷积神经网络中。

一、CBAM注意力机制概述

1.1 什么是CBAM

CBAM(Convolutional Block Attention Module)是一种卷积块注意力模块,由通道注意力(Channel Attention)和空间注意力(Spatial Attention)两个子模块组成。CBAM的设计目标是通过显式地建模通道和空间两个维度的注意力,提升卷积神经网络的特征表达能力。

1.2 CBAM的结构

CBAM的结构如图所示:

从图中可以看出,CBAM包含两个主要部分:

  1. 通道注意力模块(Channel Attention Module):用于建模通道之间的依赖关系,生成通道注意力图。
  2. 空间注意力模块(Spatial Attention Module):用于建模空间位置之间的依赖关系,生成空间注意力图。

输入特征首先通过通道注意力模块,生成通道注意力图,然后通过空间注意力模块,生成空间注意力图。最终的输出特征是输入特征与两个注意力图的逐元素相乘结果。

二、通道注意力模块(Channel Attention Module)

2.1 通道注意力的原理

通道注意力模块的主要目标是显式地建模通道之间的依赖关系,生成通道注意力图。具体来说,通道注意力模块通过以下步骤实现:

  1. 全局信息聚合:通过全局平均池化(Global Average Pooling, GAP)和全局最大池化(Global Max Pooling, GMP)操作,将输入特征的空间维度压缩为1,生成两个通道描述符。
  2. 特征变换:将两个通道描述符通过共享的多层感知机(MLP)进行特征变换,生成通道注意力图。
  3. 激活:通过Sigmoid函数将通道注意力图的值归一化到[0, 1]范围内。

2.2 通道注意力的数学公式

设输入特征为 F∈RC×H×WF \in \mathbb{R}^{C \times H \times W},通道注意力图 Mc∈RC×1×1M_c \in \mathbb{R}^{C \times 1 \times 1} 的计算公式如下:

Mc(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))M_c(F) = \sigma(MLP(AvgPool(F)) + MLP(MaxPool(F)))

其中:

  • AvgPool(F)AvgPool(F) 和 MaxPool(F)MaxPool(F) 分别表示全局平均池化和全局最大池化操作。
  • MLPMLP 表示多层感知机,通常由两个全连接层组成,中间通过ReLU激活函数。
  • σ\sigma 表示Sigmoid函数,用于将注意力图的值归一化到[0, 1]范围内。

2.3 通道注意力模块的实现

以下是一个简单的通道注意力模块的实现代码(以PyTorch为例):

import torch
import torch.nn as nn

class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

三、空间注意力模块(Spatial Attention Module)

3.1 空间注意力的原理

空间注意力模块的主要目标是显式地建模空间位置之间的依赖关系,生成空间注意力图。具体来说,空间注意力模块通过以下步骤实现:

  1. 通道-wise池化:对输入特征进行通道-wise的最大池化和平均池化,生成两个空间描述符。
  2. 特征融合:将两个空间描述符在通道维度上拼接,然后通过一个卷积层生成空间注意力图。
  3. 激活:通过Sigmoid函数将空间注意力图的值归一化到[0, 1]范围内。

3.2 空间注意力的数学公式

设输入特征为 F∈RC×H×WF \in \mathbb{R}^{C \times H \times W},空间注意力图 Ms∈R1×H×WM_s \in \mathbb{R}^{1 \times H \times W} 的计算公式如下:

Ms(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))M_s(F) = \sigma(f^{7 \times 7}([AvgPool(F); MaxPool(F)]))

其中:

  • AvgPool(F)AvgPool(F) 和 MaxPool(F)MaxPool(F) 分别表示通道-wise的平均池化和最大池化操作。
  • f7×7f^{7 \times 7} 表示一个7x7的卷积层。
  • σ\sigma 表示Sigmoid函数,用于将注意力图的值归一化到[0, 1]范围内。

3.3 空间注意力模块的实现

以下是一个简单的空间注意力模块的实现代码(以PyTorch为例):

import torch
import torch.nn as nn

class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(out)
        return self.sigmoid(out)

四、CBAM模块的集成

4.1 CBAM模块的实现

将通道注意力模块和空间注意力模块结合起来,形成完整的CBAM模块。以下是一个完整的CBAM模块的实现代码(以PyTorch为例):

import torch
import torch.nn as nn

class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttentionModule(kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)
        x = x * self.spatial_attention(x)
        return x

4.2 CBAM模块的应用

CBAM模块可以插入到任何深度卷积神经网络中,以提升模型的特征表达能力。以下是一个将CBAM模块插入到ResNet中的示例:

import torch
import torch.nn as nn
from torchvision.models import resnet50

class ResNet50WithCBAM(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet50WithCBAM, self).__init__()
        self.resnet = resnet50(pretrained=True)
        self.cbam = CBAM(2048, reduction_ratio=16, kernel_size=7)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = self.cbam(x)
        x = self.resnet.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

五、CBAM模块的优势

5.1 提升特征表达能力

CBAM模块通过显式地建模通道和空间两个维度的注意力,能够显著提升模型的特征表达能力。通道注意力模块能够关注重要的通道特征,而空间注意力模块能够关注重要的空间位置特征。

5.2 轻量级设计

CBAM模块的设计非常轻量级,不会显著增加模型的计算量和参数量。这使得CBAM模块可以轻松地插入到各种深度卷积神经网络中,而不会对模型的性能产生负面影响。

5.3 即插即用

CBAM模块具有即插即用的特点,可以轻松地插入到任何深度卷积神经网络中。这使得CBAM模块在实际应用中非常方便,无需对模型进行复杂的修改。

六、CBAM模块的应用场景

6.1 图像分类

CBAM模块可以插入到各种图像分类模型中,如ResNet、VGG、DenseNet等,以提升模型的分类性能。

6.2 目标检测

CBAM模块可以插入到各种目标检测模型中,如Faster R-CNN、YOLO、SSD等,以提升模型的检测性能。

6.3 语义分割

CBAM模块可以插入到各种语义分割模型中,如DeepLab、PSPNet、U-Net等,以提升模型的分割性能。

七、总结

CBAM(Convolutional Block Attention Module)作为一种轻量级且高效的注意力机制,通过显式地建模通道和空间两个维度的注意力,显著提升了卷积神经网络的特征表达能力。CBAM模块的设计非常轻量级,具有即插即用的特点,可以轻松地插入到各种深度卷积神经网络中。在实际应用中,CBAM模块被广泛应用于图像分类、目标检测和语义分割等任务中,取得了显著的效果,接下来将插入到其他模型中应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/979493.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GCN从理论到实践——基于PyTorch的图卷积网络层实现

Hi,大家好,我是半亩花海。图卷积网络(Graph Convolutional Network, GCN)是一种处理图结构数据的深度学习模型。它通过聚合邻居节点的信息来更新每个节点的特征表示,广泛应用于社交网络分析、推荐系统和生物信息学等领…

给虚拟机配置IP

虚拟机IP这里一共有三个地方要设置,具体说明如下: (1)配置vm虚拟机网段 如果不进行设置,每次启动机器时都可能是随机的IP,不方便我们后续操作。具体操作是:点击编辑→虚拟网络编辑器 选择VMne…

【免费】YOLO[笑容]目标检测全过程(yolo环境配置+labelimg数据集标注+目标检测训练测试)

一、yolo环境配置 这篇帖子是我试过的,非常全,很详细【cudaanacondapytorchyolo(ultralytics)】 yolo环境配置 二、labelimg数据集标注 可以参考下面的帖子,不过可能会出现闪退的问题,安装我的流程来吧 2.1 labelimg安装 label…

mapbox基础,使用geojson加载heatmap热力图层

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️heatmap热力图层样式二、🍀使用geojs…

Python 课堂点名桌面小程序

一、场景分析 闲来无事,老婆说叫我开发一个课堂点名桌面小程序,给她在课堂随机点名学生问问题。 人生苦短,那就用 Python 给她写一个吧。 二、依赖安装 因为要用到 excel,所以安装两个依赖: pip install openpyxl…

蓝桥杯 路径之谜

路径之谜 题目描述 小明冒充 XX 星球的骑士,进入了一个奇怪的城堡。 城堡里边什么都没有,只有方形石头铺成的地面。 假设城堡地面是 nnnn 个方格。如下图所示。 按习俗,骑士要从西北角走到东南角。可以横向或纵向移动,但不能斜着走…

在鸿蒙HarmonyOS手机上安装hap应用

一、下载工具 安装hap包需要用到小工具 。 二、解压到目录后,进入该文件夹,打开命令行,如下图 三、将下载好的hap包放入刚才解压的文件夹内(假设hap包文件名为app.hap) 四、连接好手机和电脑,手机需要打…

Android APK组成编译打包流程详解

Android APK(Android Package)是 Android 应用的安装包文件,其组成和打包流程涉及多个步骤和文件结构。以下是详细的说明: 一、APK 的组成 APK 是一个 ZIP 格式的压缩包,包含应用运行所需的所有文件。解压后主要包含以…

自然语言处理:词频-逆文档频率

介绍 大家好,博主又来给大家分享知识了。本来博主计划完成稠密向量表示的内容分享后,就开启自然语言处理中文本表示的讲解。可在整理分享资料的时候,博主发现还有个知识点,必须得单独拎出来好好说道说道。 这就是TF-IDF&#xf…

esp8266 rtos sdk开发环境搭建

1. 安装必要的工具 1.1 安装 Git Git 用于从远程仓库克隆代码,你可以从Git 官方网站下载 Windows 版本的安装程序。安装过程中可保持默认设置,安装完成后,在命令提示符(CMD)或 PowerShell 中输入git --version&#…

pytest下放pytest.ini文件就导致报错:ERROR: file or directory not found: #

pytest下放pytest.ini文件就导致报错:ERROR: file or directory not found: # 如下: 项目文件目录如下: pytest.ini文件内容: [pytest] addopts -v -s --alluredir ./allure-results # 自动添加的命令行参数:# -…

Blender调整最佳渲染清晰度

1.渲染采样调高 512 2.根据需要 开启AO ,开启辉光 , 开启 屏幕空间反射 3.调高分辨率 4096x4096 100% 分辨率是清晰度的关键 , 分辨率不高 , 你其他参数调再高都没用 4.世界环境开启体积散射 , 可以增强氛围感 5.三点打光法 放在模型和相机45夹角上 白模 白模带线条 成品

Vllm进行Qwen2-vl部署(包含单卡多卡部署及爬虫请求)

1.简介 阿里云于今年9月宣布开源第二代视觉语言模型Qwen2-VL,包括 2B、7B、72B三个尺寸及其量化版本模型。Qwen2-VL具备完整图像、多语言的理解能力,性能强劲。 相比上代模型,Qwen2-VL 的基础性能全面提升,可以读懂不同分辨率和…

xr-frame 3D Marker识别,扬州古牌坊 3D识别技术稳定调研

目录 识别物体规范 3D Marker 识别目标文件 map 生成 生成任务状态解析 服务耗时: 对传入的视频有如下要求: 对传入的视频建议: 识别物体规范 为提高Marker质量,保证算法识别效果,可参考Marker规范文档 Marker规…

Windows环境下SuperMapGIS 11i 使用达梦数据库

1. 环境介绍: 1.1. 操作系统: windows server 2019 1.2. GIS 软件: 1.2.1. GIS 桌面 supermap-idesktopx-11.3.0-windows-x64-bin 下载链接:SuperMap技术资源中心|为您提供全面的在线技术服务 安装教程:绿色版&…

redis的下载和安装详解

一、下载redis安装包 进入redis官网查看当前稳定版本: https://redis.io/download/发现此时的稳定版本是6.2.4, 此时可以去这个网站下载6.2.4稳定版本的tar包。 暂时不考虑不在windows上使用redis,那样将无法发挥redis的性能 二、上传tar…

Prometheus + Grafana 监控

Prometheus Grafana 监控 官网介绍:Prometheus 是一个开源系统 监控和警报工具包最初由 SoundCloud 构建。自 2012 年成立以来,许多 公司和组织已经采用了 Prometheus,并且该项目具有非常 活跃的开发人员和用户社区。它现在是一个独立的开源…

使用Semantic Kernel:对DeepSeek添加自定义插件

SemanticKernel介绍 Semantic Kernel是一个SDK,它将OpenAI、Azure OpenAI等大型语言模型与C#、Python和Java等传统编程语言集成在一起。Semantic Kernel通过允许您定义插件来实现这一点。 为什么需要添加插件? 大语言模型虽然具有强大的自然语言理解和…

Grok3使用体验与模型版本对比分析

文章目录 Grok的功能DeepSearch思考功能绘画功能Grok 3的独特功能 Grok 3的版本和特点与其他AI模型的比较 最新新闻:Grok3被誉为“地球上最聪明的AI” 最近,xAI公司正式发布了Grok3,并宣称其在多项基准测试中展现了惊艳的表现。据官方消息&am…

QT——c++界面编程库

非界面编程 QT编译的时候,依赖于 .pro 配置文件: SOURCES: 所有需要参与编译的 .cpp 源文件 HEADERS:所有需要参与编译的.h 头文件 QT:所有需要参与编译的 QT函数库 .pro文件一旦修改,注意需要键盘按 ctrls 才能加载最新的配置文…