YOLOv5修改注意力机制CBAM

直接上干货

CBAM注意力机制是由通道注意力机制(channel)和空间注意力机制(spatial)组成。

传统基于卷积神经网络的注意力机制更多的是关注对通道域的分析,局限于考虑特征图通道之间的作用关系。CBAM从 channel 和 spatial 两个作用域出发,引入空间注意力和通道注意力两个分析维度,实现从通道到空间的顺序注意力结构。空间注意力可使神经网络更加关注图像中对分类起决定作用的像素区域而忽略无关紧要的区域,通道注意力则用于处理特征图通道的分配关系,同时对两个维度进行注意力分配增强了注意力机制对模型性能的提升效果。
 

CBAM中的通道注意力机制模块流程图如下。先将输入特征图分别进行全局最大池化和全局平均池化,对特征映射基于两个维度压缩,获得两张不同维度的特征描述。池化后的特征图共用一个多层感知器网络,先通过一个全连接层下降通道数,再通过另一个全连接恢复通道数。将两张特征图在通道维度堆叠,经过 sigmoid 激活函数将特征图的每个通道的权重归一化到0-1之间。将归一化后的权重和输入特征图相乘。

yaml 配置文件如下

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 6  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

common加入以下代码

# CBAM
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # (特征图的大小-算子的size+2*padding)/步长+1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # 2*h*w
        x = self.conv(x)
        # 1*h*w
        return self.sigmoid(x)


class CBAM(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

YOLO 的

parse_model 注册

到此完成

后续会给大家讲解YOLOv8怎么修改

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/68170.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Kafka与Zookeeper版本对应关系

文章目录 了解版本对应Kafka安装包Kafka源码包 了解 比如: kafka_2.11-1.1.1.jar包 其中2.11表示的是Scala的版本,因为Kafka服务器端代码完全由Scala语音编写。”-“后面的1.1.1表示的kafka的版本信息。遵循一个基本原则,Kafka客户端版本和服…

MySQL5.7数据库、Navicat Premium1.6可视化工具安装教程【详细教程】

文章目录 一、MySQL、Navicat、注册机地址二、安装(一)、MySQL安装(二)、Navicat Premium安装(三)、集活Navicat Premium 三、遇到的问题1、Are you sure your navicat has not beenpatched/modified befor…

Spring 事务管理

目录 1. 事务管理 1.1. Spring框架的事务支持模型的优势 1.1.1. 全局事务 1.1.2. 本地事务 1.1.3. Spring框架的一致化编程模型 1.2. 了解Spring框架的事务抽象(Transaction Abstraction) 1.2.1. Hibernate 事务设置 1.3. 用事务同步资源 1.3.1…

Malloc动态内存分配

在C语言中我们会使用malloc来动态地分配内存,这样做的一个主要理由是有些数据结构的大小只有在运行时才能确定。例如,如果你正在编写一个程序,需要用户输入一些数据,但你不知道用户会输入多少数据,那么你就需要使用动态…

vue3使用pinia和pinia-plugin-persist做持久化存储

1、安装依赖 pnpm i pinia // 安装 pinia pnpm i pinia-plugin-persist // 安装持久化存储插件2、main.js引入 import App from ./App.vue const app createApp(App)//pinia import { createPinia } from pinia import piniaPersist from pinia-plugin-persist //持久化插件 …

Python中enumerate用法详解

目录 1.简介 2.语法 3.参数 4.返回值 5.详解 6.实例 7.补充 1.简介 enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。 2.语法 以下是 enumerate() 方法的语…

Qt应用开发(基础篇)——堆栈窗口 QStackedWidget

一、前言 QStackedWidget继承于QFrame,QFrame继承于QWidget,是Qt常用的堆栈窗口部件。 框架类QFrame介绍 QStackedWidget堆栈窗口,根据下标切换,一次显示一个小部件,常用于应用界面切换、图片轮询播放等场景。 二、QSt…

通义千问开源模型部署使用

首先可以参考modelScope社区给出的使用文档,已经足够全面 通义千问-7B-Chat 但在按照文档中步骤部署时,还是有些错误问题发生,可以搜索参考的解决方式不多,所以记录下来 个人电脑部署 这里不太建议使用自己的笔记本部署通义千…

exchange partition global index

EXCHANGE PARTITION with a Table having a UNIQUE INDEX and PK Constraint. (Doc ID 1620636.1)​编辑To Bottom In this Document Symptoms Changes Cause Solution References APPLIES TO: Oracle Database - Enterprise Edition - Version 11.2.0.3 and later Oracle Da…

Spring整合MyBatis(详细步骤)

Spring与Mybatis的整合&#xff0c;大体需要做两件事&#xff0c; 第一件事是:Spring要管理MyBatis中的SqlSessionFactory 第二件事是:Spring要管理Mapper接口的扫描 具体的步骤为: 步骤1:项目中导入整合需要的jar包 <dependency><!--Spring操作数据库需要该jar包…

gazebo 导入从blender导出的dae等文件

背景&#xff1a; gazebo 模型库里的模型在我需要完成的任务中不够用&#xff0c;还是得从 solidworks、3DMax, blender这种建模软件里面在手动画一些&#xff0c;或者去他们的库里面在挖一挖。 目录 1 blender 1-1 blender 相关links 1-2 install 2 gazebo导入模型 2-1 g…

湘大 XTU OJ 1308 比赛 题解:循环结束的临界点+朴素模拟

一、链接 比赛 二、题目 题目描述 有n个人要进行比赛&#xff0c;比赛规则如下&#xff1a; 假设每轮比赛的人是m&#xff0c;取最大的k&#xff0c;k2^t且k≤m。这k个人每2人举行一场比赛&#xff0c;胜利者进入一下轮&#xff0c;失败者被淘汰。余下的m-k个人&#xff0…

从Spring源码看创建对象的过程

从Spring源码看创建对象的过程 Spring对于程序员set注入的属性叫做属性的填充、对于set注入之后的处理&#xff08;包括BeanPostProcessor的处理、初始化方法的处理&#xff09;叫做初始化。 研读AbstractBeanFactory类中的doGetBean()方法 doGetBean()方法首先完成的工作是…

mysql基础之触发器的简单使用

1.建立学生信息表 -- 触发器 -- 建立学生信息表 create table s1(id int unsigned auto_increment,name varchar(30),score tinyint unsigned,dept varchar(50),primary key(id) );2.建立学生补考信息表 -- 建立学生补考信息表 create table s2 like s1;3.建立触发器&#xf…

Grafana技术文档-概念-《十分钟扫盲》

Grafana官网链接 Grafana: The open observability platform | Grafana Labs 基本概念 Grafana是一个开源的度量分析和可视化套件&#xff0c;常用于对大量数据进行实时分析和可视化。以下是Grafana的基本概念&#xff1a; 数据源&#xff08;Data Source&#xff09;&#…

【大数据】Flink 详解(一):基础篇

Flink 详解&#xff08;一&#xff09;&#xff1a;基础篇 1、什么是 Flink &#xff1f; Flink 是一个以 流 为核心的高可用、高性能的分布式计算引擎。具备 流批一体&#xff0c;高吞吐、低延迟&#xff0c;容错能力&#xff0c;大规模复杂计算等特点&#xff0c;在数据流上提…

模板的进阶

目录 1.非类型模板参数 2.模板特化 2.1概念 2.2函数模板特化 2.3类模板特化 2.3.1全特化 2.3.2偏特化 3.模板分离编译 3.1什么是分离编译 3.2 模板的分离编译 3.3解决方法 4. 模板总结 1.非类型模板参数 模板参数分类类型形参与非类型形参。 类型形参即&#xff1a…

Python(七十五--总结)列表、字典、元组、集合总结

❤️ 专栏简介&#xff1a;本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中&#xff0c;我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 &#xff1a;本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

关于Object 0 = new Object() 的追魂九连问

文章目录 对象的创建过程对象的组成解析普通对象**结果分析&#xff1a;**给对象添加属性注意事项 补充jvm压缩指针栗子&#xff1a; 对象头包含什么对象怎么定位&#xff1f;**句柄方式和直接引用的优缺点&#xff1a;** 对象怎么分配&#xff1f;为什么hotspot不使用c对象来代…

QT的信号槽的四种写法和五种链接方式

目录 四种信号槽写法&#xff1a; 五种连接方式&#xff1a; 实例&#xff1a; 常见错误及改正&#xff1a; 错误1: 未连接信号与槽 错误2: 信号和槽参数不匹配 错误3: 未使用Q_OBJECT宏 错误4: 跨线程连接未处理 在Qt中&#xff0c;信号&#xff08;Signal&#xff09…