YOLOv9有效提点|加入SE、CBAM、ECA、SimAM等几十种注意力机制(一)


专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!!


一、本文介绍

        本文将以SE注意力机制为例,演示如何在YOLOv9种添加注意力机制!


 《Squeeze-and-Excitation Networks》

        SENet提出了一种基于“挤压和激励”(SE)的注意力模块,用于改进卷积神经网络(CNN)的性能。SE块可以适应地重新校准通道特征响应,通过建模通道之间的相互依赖关系来增强CNN的表示能力。这些块可以堆叠在一起形成SENet架构,使其在多个数据集上具有非常有效的泛化能力。


《CBAM:Convolutional Block Attention Module》

        CBAM模块能够同时关注CNN的通道和空间两个维度,对输入特征图进行自适应细化。这个模块轻量级且通用,可以无缝集成到任何CNN架构中,并可以进行端到端训练。实验表明,使用CBAM可以显著提高各种模型的分类和检测性能。


《ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks》

        通道注意力模块ECA,可以提升深度卷积神经网络的性能,同时不增加模型复杂性。通过改进现有的通道注意力模块,作者提出了一种无需降维的局部交互策略,并自适应选择卷积核大小。ECA模块在保持性能的同时更高效,实验表明其在多个任务上具有优势。


《SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks》

        SimAM一种概念简单且非常有效的注意力模块。不同于现有的通道/空域注意力模块,该模块无需额外参数为特征图推导出3D注意力权值。具体来说,SimAM的作者基于著名的神经科学理论提出优化能量函数以挖掘神经元的重要性。该模块的另一个优势在于:大部分操作均基于所定义的能量函数选择,避免了过多的结构调整

适用检测目标:   YOLOv9模块通用改进


二、改进步骤

        以下以SE注意力机制为例在YOLOv9中加入注意力代码,其他注意力机制同理!

 2.1 复制代码

        将SE的代码辅助到models包下common.py文件中。

 2.2 修改yolo.py文件

        在yolo.py脚本的第700行(可能因YOLOv9版本变化而变化)增加下方代码。

        elif m in (SE,):
            args.insert(0, ch[f])

2.3 创建配置文件

        创建模型配置文件(yaml文件),将我们所作改进加入到配置文件中(这一步的配置文件可以复制models  - > detect 下的yaml修改。)。对YOLO系列yaml文件不熟悉的同学可以看我往期的yaml详解教学!

YOLO系列 “.yaml“文件解读-CSDN博客

# YOLOv9
# Powered bu https://blog.csdn.net/StopAndGoyyy
# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, ADown, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   [-1, 1, SE, [16]],  # 38

   
   
   # detection head

   # detect
   [[31, 34, 38, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]

3.1 训练过程

        最后,复制我们创建的模型配置,填入训练脚本(train_dual)中(不会训练的同学可以参考我之前的文章。),运行即可。

YOLOv9 最简训练教学!-CSDN博客

​​

​​


SE代码

class SE(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

CBAM代码

class CBAMBlock(nn.Module):

    def __init__(self, channel=512, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

ECA代码

class ECAAttention(nn.Module):

    def __init__(self, kernel_size=3):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        y = self.gap(x)  # bs,c,1,1
        y = y.squeeze(-1).permute(0, 2, 1)  # bs,1,c
        y = self.conv(y)  # bs,1,c
        y = self.sigmoid(y)  # bs,1,c
        y = y.permute(0, 2, 1).unsqueeze(-1)  # bs,c,1,1
        return x * y.expand_as(x)

SimAM代码

class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)

如果觉得本文章有用的话给博主点个关注吧!


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/423885.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

翻硬币 刷题笔记

通过模拟观察 我们发现 按一下会改变相邻两个硬币的状态 将硬币排成一排 从对位置下标为5到下标为7 依次翻其本身和其右边的硬币 对5,6,7操作 操作次数为3 此时我们只改变了硬币5和硬币8的状态 因此 每两处不一样的地方,我们想改变…

Java | vscode如何使用命令行运行Java程序

1.在vscode中新建一个终端 2.在终端中输入命令 javac <源文件>此命令执行后&#xff0c;在文件夹中会生成一个与原java程序同名的.class文件。然后输入如下命令&#xff1a; java <源文件名称>这样java程序就运行成功了。&#x1f607;

递归与递推(蓝桥杯 c++)

目录 题目一&#xff1a; 代码&#xff1a; 题目二: 代码&#xff1a; 题目三&#xff1a; 代码&#xff1a; 题目四&#xff1a; 代码&#xff1a; 题目一&#xff1a; 代码&#xff1a; #include<iostream> #include<cstring> using namespace std; int …

HTML和CSS (前端共三篇)【详解】

目录 一、前端开发介绍 二、HTML入门 三、HTML基础标签 四、CSS样式修饰 五、HTML表格标签 六、HTML表单标签 一、前端开发介绍 web应用有BS和CS架构两种&#xff0c;其中我们主要涉及的是BS架构。而BS架构里&#xff0c;B&#xff08;Browser浏览器&#xff09;是客户端的…

一些可以访问gpt的方式

1、Coze扣子是新一代 AI 大模型智能体开发平台。整合了插件、长短期记忆、工作流、卡片等丰富能力&#xff0c;扣子能帮你低门槛、快速搭建个性化或具备商业价值的智能体&#xff0c;并发布到豆包、飞书等各个平台。https://www.coze.cn/ 2、https://poe.com/ 3、插件阿里…

2023国赛样题路由部分【RIP RIPNG ACLRIP ACLRIPNG ISIS NAT64】

RT1串行链路、RT2串行链路、FW1、AC1之间分别运行RIP和RIPng协议&#xff0c;FW1、RT1、RT2的RIP和RIPng发布loopback2地址路由&#xff0c;AC1 RIP发布loopback2地址路由&#xff0c;AC1 RIPng采用route-map匹配prefix-list重发布loopback2地址路由。RT1配置offset值为3的路由…

一文认识蓝牙(验证基于Aduino IDE的ESP32)

1、简介 蓝牙技术是一种无线通信的方式&#xff0c;利用特定频率的波段&#xff08;2.4GHz-2.485GHz左右&#xff09;&#xff0c;进行电磁波传输&#xff0c;总共有83.5MHz的带宽资源。 1.1、背景 蓝牙&#xff08;Bluetooth&#xff09;一词取自于十世纪丹麦国王哈拉尔Haral…

linux的通信方案(SYSTEM V)

文章目录 共享内存(Share Memory)信号队列&#xff08;Message Queue&#xff09;信号量(semaphore) 进程间通信的核心理念&#xff1a;让不同的进程看见同一块资源 linux下的通信方案&#xff1a; SYSTEM V 共享内存(Share Memory) 特点&#xff1a;1.共享内存是进程见通信最…

使用 Docker 部署 Answer 问答平台

1&#xff09;介绍 GitHub&#xff1a;https://github.com/apache/incubator-answer Answer 问答社区是在线平台&#xff0c;让用户提出问题并获得回答。用户可以发布问题并得到其他用户的详细答案、建议或信息。回答可以投票或评分&#xff0c;有助于确定有用的内容。标签和分…

笨办法学 Python3 第五版(预览)(二)

原文&#xff1a;Learn Python the Hard Way, 5th Edition (Early Release) 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 练习 19&#xff1a;函数和变量 现在你将把函数与你从之前练习中了解到的变量结合起来。如你所知&#xff0c;变量给数据片段一个名称&#x…

数据挖掘入门项目二手交易车价格预测之数据分析

文章目录 1. 相关库的引入2. 数据的加载3. 数据概况3.1 统计值查看3.2 查看数据类型 4. 判断缺失值4.1 统计每一列空值的数量4.2 可视化缺失值数量 5. 判断异常值5.1 异常值检测 6. 了解预测值的分布6.1 统计各预测值的分布6.2 总体分布概况6.2 查看预测值的具体频数6.3 查看sk…

基于ssm旅社客房收费管理系统+vue

目 录 目 录 I 摘 要 III ABSTRACT IV 1 绪论 1 1.1 课题背景 1 1.2 研究现状 1 1.3 研究内容 2 2 系统开发环境 3 2.1 vue技术 3 2.2 JAVA技术 3 2.3 MYSQL数据库 3 2.4 B/S结构 4 2.5 SSM框架技术 4 3 系统分析 5 3.1 可行性分析 5 3.1.1 技术可行性 5 3.1.2 操作可行性 5 3…

Git 如何上传本地的所有分支

Git 如何上传本地的所有分支 比如一个本地 git 仓库里定义了两个远程分支&#xff0c;一个名为 origin&#xff0c; 一个名为 web 现在本地有一些分支是 web 远程仓库没有的分支&#xff0c;如何将本地所有分支都推送到 web 这个远程仓库上呢 git push web --all

【ArcGIS超级工具】基于ArcPy的矢量数据批量自动化入库工具

最近&#xff0c;有很多做规划的朋友私信我&#xff0c;想让我帮忙开发一款ArcGIS自动化脚本工具&#xff0c;实现点、线、面的自动化入库操作&#xff0c;帮他们在平时的内业数据处理工作中减少机械式重复性的工作&#xff0c;提高工作效率。为此&#xff0c;我详细了解了下目…

车辆维护和燃油里程跟踪器LubeLogger

什么是 LubeLogger &#xff1f; LubeLogger 是一个自托管、开源、基于网络的车辆维护和燃油里程跟踪器。 LubeLogger 比较适合用来跟踪管理您的汽车的维修、保养、加油的历史记录&#xff0c;比用 Excel 强多了 官方提供了在线试用&#xff0c;可以使用用户名 test 和密码 123…

Covalent Network(CQT)将链下收入引入链上,在全新阶段开启 Token 回购

Covalent Network&#xff08;CQT&#xff09;&#xff0c;是 Web3 领域跨越 225 个链的领先数据索引服务商&#xff0c;通过统一 API 的方式提供结构化数据可用性服务&#xff0c;并正在成为 AI、DeFi、分析和治理等多样化需求的关键参与者。为了支持去中心化技术的采用&#…

Java快读

java的快读 (1)BufferedReader BufferedReader br new BufferedReader(new InputStreamReader(System.in));//定义对象String[] strings br.readLine().split(" ");//读取一行字符串&#xff0c;以空格为分隔转化为字符串数组int n Integer.parseInt(strings[0])…

NUC980 Linux(4.4.289)内核配置SD卡相关参数,设备启动后插入后SD卡没反应

现象:SD卡插入&#xff0c;设备识别不到 原因:1.内核配置问题&#xff1b;2.硬件没有接地&#xff1b; 解决: 1.内核配置 2.硬件上SD卡接地

Java面试——Redis

优质博文&#xff1a;IT-BLOG-CN 一、Redis 为什么那么快 【1】完全基于内存&#xff0c;绝大部分请求是纯粹的内存操作&#xff0c;非常快速。数据存在内存中。 【2】数据结构简单&#xff0c;对数据操作也简单&#xff0c;Redis中的数据结构是专门进行设计的。 【3】采用单线…

IEEE754标准的c语言阐述,以及几个浮点数常量

很多年前&#xff0c;调研过浮点数与整数之间的双射问题&#xff1a; win7 intel x64 cpu vs2013 c语言浮点数精度失真问题 最近重新学习了一下IEEE754标准&#xff0c;也许实际还有很多深刻问题没有被揭示。 计算机程序设计艺术&#xff0c;据说这本书中也有讨论。 参考&…