YOLOV5添加 ECA CA SE CBAM 等八种注意力机制(小白可用)

目录

CBAM注意力机制原理及代码实现

代码实现

 yaml文件

修改后的结构图

SE注意力机制

SE结构图

完整代码实现

报错


⭐欢迎大家订阅我的专栏一起学习⭐

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/Q4aka

http://t.csdnimg.cn/sVHxv

💡魔改网络、复现论文、优化创新💡

原理

尽管Ultralytics 推出了最新版本的 YOLOv8 模型。但YOLOv5作为一个经典的目标检测的算法,仍经常被提及。注意力机制是提高模型性能最热门的方法之一。本次将介绍几种常见的注意力机制,这些注意力机制在大多数的数据集上均能有效的提升目标检测的精度/召回率/准确率。

CBAM注意力机制原理及代码实现
CBAM注意力机制结构图


CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,它能够增强网络对输入特征的关注度,提高网络性能。CBAM 主要包含两个子模块:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

以下是CBAM注意力机制的基本原理:

1. 通道注意力模块(Channel Attention Module):
输入:经过卷积层的特征图。
处理步骤:
对每个通道进行全局平均池化,得到通道的全局平均值。
通过两个全连接层,将全局平均值映射为两个权重向量(一个用于缩放,一个用于偏置)。
将这两个权重向量与原始特征图相乘,以加权调整每个通道的重要性。

2. 空间注意力模块(Spatial Attention Module):**
输入:通道注意力模块的输出。
处理步骤:
     对每个通道的特征图进行分别的最大池化和平均池化,得到两个空间特征图。
     将这两个空间特征图相加,通过一个卷积层产生一个权重图。
     将原始特征图与权重图相乘,以加权调整每个空间位置的重要性。

3. 整合:
   将通道注意力模块和空间注意力模块的输出相乘,得到最终的注意力增强特征图。
   将这个注意力增强的特征图传递给网络的下一层进行进一步处理。

CBAM的关键优势在于它能够同时考虑通道和空间信息,有助于网络更好地理解和利用输入特征。这种注意力机制有助于提高网络在视觉任务上的性能,使其能够更有针对性地关注重要的特征。

代码实现
class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
 
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))
 
 
class SpatialAttention(nn.Module):
    """Spatial-attention module."""
 
    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()
 
    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
 
 
class CBAM(nn.Module):
    """Convolutional Block Attention Module."""
 
    def __init__(self, c1, kernel_size=7):  # ch_in, kernels
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)
 
    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))
 yaml文件
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
修改后的结构图

改进后完整的YOLOv5结构图
SE注意力机制

卷积神经网络建立在卷积运算的基础上,它通过在局部感受野内将空间和通道信息融合在一起来提取信息特征。为了提高网络的表示能力,最近的几种方法已经显示了增强空间编码的好处。在这项工作中,我们专注于通道关系,并提出了一种新颖的架构单元,我们将其称为“挤压和激励”(SE)块,它通过显式建模通道之间的相互依赖性来自适应地重新校准通道方面的特征响应。我们证明,通过将这些块堆叠在一起,我们可以构建在具有挑战性的数据集上具有极好的泛化能力的 SENet 架构。至关重要的是,我们发现 SE 模块能够以最小的额外计算成本为现有最先进的深度架构带来显着的性能改进。 SENets 构成了我们 ILSVRC 2017 分类提交的基础,该分类提交赢得了第一名,并将 top-5 错误率显着降低至 2.251%,与 2016 年获胜条目相比相对提高了约 25%。

我们通过引入一个新的架构单元(我们将其称为“挤压和激励”(SE)块)来研究架构设计的另一个方面 - 通道关系。我们的目标是通过显式建模网络卷积特征通道之间的相互依赖性来提高网络的表示能力。为了实现这一目标,我们提出了一种允许网络执行特征重新校准的机制,通过该机制它可以学习使用全局信息来选择性地强调信息丰富的特征并抑制不太有用的特征。

SE结构图
模型结构


SE 构建块的基本结构如图所示。对于任何给定的变换 Ftr : X → U, X ∈ RH′×W ′×C′ , U ∈ RH×W ×C ,,我们可以构建相应的 SE 块来执行特征重新校准。特征 U 首先通过挤压操作,该操作聚合跨空间维度 H × W 的特征图以生成通道描述符。该描述符嵌入了通道特征响应的全局分布,使得来自网络的全局感受野的信息能够被其较低层利用。接下来是激励操作,其中通过基于通道依赖性的自门机制为每个通道学习特定于样本的激活,控制每个通道的激励。然后对特征图 U 进行重新加权以生成 SE 块的输出,然后可以将其直接馈送到后续层中。

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class SEConv(nn.Module):
    # channel attentive convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, reduction=16):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        self.se = SELayer(c2, reduction)
 
    def forward(self, x):
        residual = x
        out = self.bn(self.conv(x))
        out = self.se(out)
        out += residual
        return self.act(out)
class SEBottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = SEConv(c1, c_, 1, 1)
        self.cv2 = SEConv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
 
    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

其他注意力机制未完待续...

完整代码实现

链接: https://pan.baidu.com/s/1LYdIwI71ZLCo4GqM6L5PBg?pwd=4yr5 提取码: 4yr5 

yaml文件记得选择对应的注意力机制

报错

如果报错,查看

解决Yolov5的RuntimeError: result type Float can‘t be cast to the desired output type long int 问题_yolov5 runtimeerror: result type float can't be ca-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/453253.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【考研数学】打基础 - 武忠祥《基础篇》vs 李永乐《复习全书》

这个我太有同感了,但是我现在能分清了,绝对给大家讲明白 大家疑惑的是,武忠祥老师的基础阶段为什么有人推荐高等数学基础篇,有人推荐复习全书。所以到底该用哪一个。 其实这是因为武忠祥老师在两个机构干过,在有道的…

前端之用HTML弄一个古诗词

将进酒 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>将进酒</title><h1><big>将进酒</big> 君不见黄河之水天上来</h1><table><tr><td ><img…

C#无法给PLC写入数据原因分析

一、背景 1.1 概述 C#中无法给PLC写入数据的原因有很多&#xff0c;这里分享网络端口号被占用导致无法写入的确认方法 1.2 环境 ①使用三菱PLC ②C#通过网口与PLC进行通讯 二、现象 1.1 代码 通过HslCommunication连接PLC时&#xff0c;连接返回成功&#xff0c;写入返回失败 …

了解O2OA(翱途)开发平台中的VIP应用

使用O2OA(翱途)开发平台可以非常方便地进行项目的业务需求开发与实施&#xff0c;O2OA(翱途)开发平台并不限制实现的系统类型&#xff0c;所以能实现的系统很多&#xff0c;最终呈现的项目成果也是多样性的&#xff0c;可能是OA系统&#xff0c;可能是人力资源管理系统&#xf…

AI写作,一键生成原创文章真高效

AI写作&#xff0c;一键生成原创文章真高效!在当今信息爆炸的时代&#xff0c;写作已成为一项重要的技能。然而&#xff0c;对于许多人来说&#xff0c;写作并不容易。有时我们会面临写作灵感枯竭、时间不足或者缺乏写作能力的困扰。但是&#xff0c;随着人工智能技术的不断进步…

3、设计模式之工厂模式2(Factory)

一、什么是工厂模式 工厂模式属于创建型设计模式&#xff0c;它用于解耦对象的创建和使用。通常情况下&#xff0c;我们创建对象时需要使用new操作符&#xff0c;但是使用new操作符创建对象会使代码具有耦合性。工厂模式通过提供一个公共的接口&#xff0c;使得我们可以在不暴露…

TS271IDT运算放大器芯片中文资料PDF数据手册引脚图图片参数价格功能

产品描述&#xff1a; TS271 是一款低成本、低功耗的单通道运算放大器&#xff0c;设计用于采用单电源或双电源供电。该运算放大器采用意法半导体硅栅CMOS工艺&#xff0c;具有出色的消耗-速度比。该放大器非常适合低功耗应用。 电源可通过引脚 8 和 4 之间连接的电阻器进行外…

Linux:深入文件系统

一、Inode 我们使用ls -l的时候看到的除了看到文件名&#xff0c;还看到了文件元数据。 [rootlocalhost linux]# ls -l 总用量 12 -rwxr-xr-x. 1 root root 7438 "9月 13 14:56" a.out -rw-r--r--. 1 root root 654 "9月 13 14:56" test.c 每行包含7列&…

Win11系统启动VMware上虚拟机蓝屏解决办法

背景 最近有在做一个项目的过程中需要使用虚拟机&#xff0c;用原来装好的的Vmware14打开虚拟机&#xff0c;直接蓝屏了&#xff0c;尝试了如下几种方法来解决&#xff0c;最好用的就是第二种&#xff0c;直接下载最新版本(在软件管家中直接下载)。 虚拟机 目前常用的虚拟机软…

idea+maven+tomcat+spring 创建一个jsp项目

概述&#xff1a;我真服了&#xff0c;这个垃圾学校还在教jsp&#xff0c;这种技术我虽然早会了&#xff0c;但是之前搞的大多都是springboot web类型的&#xff0c;这里我就复习一下&#xff0c;避免以后忘记这种垃圾技术 第一步&#xff1a;创建maven项目 第二步&#xff1a…

Java设计模式:桥接模式

❤ 作者主页&#xff1a;欢迎来到我的技术博客&#x1f60e; ❀ 个人介绍&#xff1a;大家好&#xff0c;本人热衷于Java后端开发&#xff0c;欢迎来交流学习哦&#xff01;(&#xffe3;▽&#xffe3;)~* &#x1f34a; 如果文章对您有帮助&#xff0c;记得关注、点赞、收藏、…

【STL】deque双端开口容器

1.关于deque容器说明 deque容器与vector容器差不多&#xff0c;但deque是双端开口容器&#xff0c;可以在两端插入和删除元素 push_front( )//在头部插入 push_back( )//在尾部插入 pop_front( )//在头部删除 pop_back( ) //在尾部删除 其他相应函数与vector差不多&#…

Pytorch入门-TensorBoard

文章目录 WhatHOWSummaryWriter What TensorBoard是TensorFlow自带的一个强大的可视化工具&#xff0c;也是一个Web应用程序套件。 torch.utils.tensorboard 是 PyTorch 提供的一个用于将标量、图像、直方图和其他信息记录到 TensorBoard 中的实用程序包。TensorBoard 是 Tenso…

Pycharm的Project Structure (项目结构)

文章目录 一、Sources二、Tests三、Exeluded四、Namespace packages五、Templates六、Resources 一、Sources 源代码根目录&#xff1a;包含项目的主要源代码&#xff0c;它会在这个目录下搜索代码&#xff0c;然后自动补全和只能提示都通过这里的代码提供。若项目运行自定义代…

rabbitmq-spring-boot-start配置使用手册

rabbitmq-spring-boot-start配置使用手册 文章目录 1.yaml配置如下2.引入pom依赖如下2.1 引入项目resources下libs中的jar包依赖如下2.2引入maven私服依赖如下 3.启动类配置如下4.项目中测试发送消息如下5.项目中消费消息代码示例6.mq管理后台交换机队列创建及路由绑定关系如下…

Java算法总结之冒泡排序(详解)

程序代码园发文地址&#xff1a;Java算法总结之冒泡排序&#xff08;详解&#xff09;-程序代码园小说,Java,HTML,Java小工具,程序代码园,http://www.byqws.com/ ,Java算法总结之冒泡排序&#xff08;详解&#xff09;http://www.byqws.com/blog/3145.html?sourcecsdn 冒泡排序…

(day 7)JavaScript学习笔记(函数1)

概述 这是我的学习笔记&#xff0c;记录了JavaScript的学习过程&#xff0c;我是有一些Python基础的&#xff0c;因此在学习的过程中不自觉的把JavaScript的代码跟Python代码做对比&#xff0c;以便加深印象。在写博客的时候我会尽量详尽的记录每个知识点。如果你完全没接触过J…

社媒营销会遇到的封号问题

最近看到很多的直播都在教伙伴们怎么用linked in 以及facebook开发客户&#xff0c;比如如何添加联系人&#xff0c;如何用添加好友的话术&#xff0c;如何加群&#xff0c;如何分析客户的背景等等。 有的主播只是讲一些表面的东西&#xff0c;有的主播可能是真的肚子里有货&a…

【DataWhale学习】用免费GPU线上跑StableDiffusion项目实践

用免费GPU线上跑SD项目实践 ​ DataWhale组织了一个线上白嫖GPU跑chatGLM与SD的项目活动&#xff0c;我很感兴趣就参加啦。之前就对chatGLM有所耳闻&#xff0c;是去年清华联合发布的开源大语言模型&#xff0c;可以用来打造个人知识库什么的&#xff0c;一直没有尝试。而SD我…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:StepperItem)

用作Stepper组件的页面子组件。 说明&#xff1a; 该组件从API Version 8开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 支持单个子组件。 接口 StepperItem() 属性 参数名参数类型参数描述prevLabelstring设置左侧文本按钮内…