yolov8添加注意力机制模块-CBAM

修改

  1. 在tasks.py(路径:ultralytics-main/ultralytics-main - attention/ultralytics/nn/tasks.py)文件中,引入CBAM模块。因为yolov8源码中已经包含CBAM模块,在conv.py文件中(路径:ultralytics-main/ultralytics-main - attention/ultralytics/nn/modules/conv.py),这里就就用自己写了。
  2. 修改tasks.py文件,搜索parse_model。在指定位置添加代码。
            elif m is CBAM:  # todo 源码修改 (增加了elif)
                """
                ch[f]:上一层的
                args[0]:第0个参数
                c1:输入通道数
                c2:输出通道数
                """
                c1, c2 = ch[f], args[0]
                # print("ch[f]:",ch[f])
                # print("args[0]:",args[0])
                # print("args:",args)
                # print("c1:",c1)
                # print("c2:",c2)
                if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                    c2 = make_divisible(c2 * width, 8)
                args = [c1, *args[1:]]

    3.修改yolov8.yaml文件位置(ultralytics-main/ultralytics-main - attention/ultralytics/cfg/models/v8/yolov8.yaml)。修改head模块,修改的内容如下图。

        4.测试打印网络。已经添加成功。

分析

一般来说,注意力机制通常被分为以下基本四大类:

通道注意力 Channel Attention

空间注意力机制 Spatial Attention

时间注意力机制 Temporal Attention

分支注意力机制 Branch Attention

CBAM:通道注意力和空间注意力的集成者

源码解读

这段代码是对通道的注意力。首先经过自适应平均池化层,它会对每个输入通道的空间维度进行全局平均池化,并输出一个具有空间大小为 1x1 的特征图。然后是一个卷积操作,这相当于是对每个通道进行独立的全连接层变换,因为卷积核大小为1。

最后经过Sigmoid函数,将卷积层的输出转换为权重因子,范围在(0, 1)最后,这些权重因子与原始输入x逐元素相乘,以得到加权后的特征图,这一操作实现了注意力机制,允许模型专注于更有信息量的通道。

class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        """Initializes the class and sets the basic configurations and instance variables required."""
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
        return x * self.act(self.fc(self.pool(x)))

下面是一个空间注意力模块,旨在通过对输入特征图加权来强调或抑制某些空间区域。空间注意力通常用于强调图像的重要部分并抑制不重要的部分。

self.cv1 是一个卷积层,有两个输入通道,一个输出通道,和可选的 kernel_size 与 padding。由于 bias=False,这个卷积层不会有偏置参数。两个输入通道对应于输入特征图的均值和最大值。

forward中

  1. torch.mean(x, 1, keepdim=True) 计算输入张量 x 每个样本的通道维度的均值,keepdim=True 表示保持输出张量的维度不变。

  2. torch.max(x, 1, keepdim=True)[0] 计算输入张量 x 每个样本的通道维度的最大值,[0] 是因为 torch.max 返回一个元组,包含最大值和相应的索引。

  3. torch.cat([avg_out, max_out], 1) 将均值和最大值沿通道维度拼接起来,这样每个空间位置都有两个通道:其均值和最大值。

  4. self.cv1(x_cat) 对拼接的结果应用 1x2 卷积,生成一个单通道的特征图,该特征图对应于每个空间位置的注意力权重。

  5. self.act(...) 应用 Sigmoid 激活函数将注意力权重映射到 (0, 1) 范围内。

  6. x * scale 将原始输入 x 与计算得到的空间注意力权重相乘,这样每个空间位置的特征值都会根据其重要性加权,实现了特征重标定。

最终,forward 方法返回的是加权后的输入特征图(对特征图的每个元素值×权值),它突出了输入中更重要的空间区域。

class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))

下面就是CBAM,是上面两个模块的组合,通道注意力和空间注意力。通道注意力专注于哪些通道更重要,而空间注意力则集中在输入特征图中的哪些空间位置更重要。

  • 输入 x 首先通过 self.channel_attention,这个步骤会重新调整每个通道的重要性。
  • 然后,调整通道重要性后的特征图 x 通过 self.spatial_attention,这个步骤会重新调整特征图中每个位置的重要性。
  • 最终,这两个注意力机制的结果被串联起来,形成了最终的输出。

这种结构可以提高网络对于输入特征的逐通道和逐空间位置的重要性评估能力,进而可能提高模型的性能。

class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):
        """Initialize CBAM with given input channel (c1) and kernel size."""
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/409037.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【README 小技巧】在项目README.md 中展示github点赞数量

在项目README.md 中展示github点赞数量 [![Star History Chart](https://api.star-history.com/svg?reposwujiawei1207537021/wu-lazy-cloud-network&typeDate)](https://star-history.com/#wujiawei1207537021/wu-lazy-cloud-network&Date)效果

【微服务】mybatis typehandler使用详解

目录 一、前言 二、TypeHandler简介 2.1 什么是TypeHandler 2.1.1 TypeHandler特点 2.2 TypeHandler原理 2.3 mybatis自带的TypeHandler 三、环境准备 3.1 准备一张数据表 3.2 搭建一个springboot工程 3.2.1 基础依赖如下 3.2.2 核心配置文件 3.2.3 测试接口 四、T…

java面向对象高级

一、静态 static读作静态,可以用来修饰成员变量,也能修饰成员方法。我们先来学习static修饰成员变量。 1.1 static修饰成员变量 Java中的成员变量按照有无static修饰分为两种:类变量、实例变量。它们的区别如下图所示: 由于静态…

通过底层原理理解Java是值传递还是引用传递?

本文学习目标或者巩固的知识点 参数传递方式 值传递引用传递指针传递 彻底理解Java的值传递和引用传递 从底层的角度分析值传递会发生复制行为 Java的参数传递例子 快手的一面面试曾经问到过此类题目,所以记下此篇加深印象。 问:求下面main方法中的输…

常用状态码

状态码 用于响应中的,表示响应的结果如何 1、200 OK 运行成功 2、404 Not Found 访问的资源没有找到(url的路径) 3、403 Forbidden 请求资源没有权限访问 4、405 Method Not Allowed 你的服务器只支持GET请求,但是你发了个PO…

基于springboot+vue的校园社团信息管理系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

sizeof和strlen的对比及练习题(超详细)

创造不易,可以点点赞吗~ 如有错误,欢迎指出~ sizeof和strlen的对比 sizeof举例说明 sizeof计算int类型数据 (有关于数组)sizeof计算 strlen举例说明 strlen是C语言库函数,求字符串长度 函数原型: size_t s…

前端项目打包体积分析与优化

一、安装依赖分析工具 npm install webpack-bundle-analyz 二、修改webpack.config.js文件 1、导入上面下载的包 2、在plugins里创建实例 三、启动打包命令 npm run build 会弹出如下界面: 四、优化 1、通过CDN导入react-dom文件 修改webpack.config.js文件里…

掌握3个Mock工具,轻松玩转单元测试

公司要求提升单元测试的质量,提高代码的分支覆盖率和行覆盖率,安排我研究单元测试,指定方案分享并在开发部普及开。 单元测试中的Mock的目的 Mock的主要目的是让单元测试Write Once, Run Everywhere. 即编写一次后,可以在任意时…

LeetCode--代码详解 236. 二叉树的最近公共祖先

236. 二叉树的最近公共祖先 题目 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q,最近公共祖先表示为一个节点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大&am…

中国农业无人机行业市场现状分析与投资前景预测研究报告

全版价格:壹捌零零 报告版本:下单后会更新至最新版本 交货时间:1-2天 第一章农业无人机行业发展综述 第一节农业无人机行业定义及分类 一、农业无人机行业的定义 农业无人机是一种无人驾驶的飞行器来帮助优化农业经营,增加作…

找游戏 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C 题目描述 小扇和小船今天又玩起来了数字游戏, 小船给小扇一个正整数 n(1 ≤ n ≤ 1e9),小扇需要找到一个比 n 大的数字 m&a…

C++ //练习 8.8 修改上一题的程序,将结果追加到给定的文件末尾。对同一个输出文件,运行程序至少两次,检验数据是否得以保留。

C Primer(第5版) 练习 8.8 练习 8.8 修改上一题的程序,将结果追加到给定的文件末尾。对同一个输出文件,运行程序至少两次,检验数据是否得以保留。 环境:Linux Ubuntu(云服务器) 工…

dolphinscheduler伪集群部署教程

文章目录 前言一、配置免密登录1. 配置root用户免密登录2. 创建用户2.1 创建dolphinscheduler用户2.2 配置dolphinscheduler用户免密登录2.3 退出dolphinscheduler用户 二、安装准备1. 安装条件2. 安装jdk3. 安装MySQL4. 安装zookeeper4.1 zookeeper单机部署4.2 启动zookeeper4…

14-ATF中对多核的支持

讨论一个系统、一个软件或ATF对多核的支持,其实就是看这个软件,在启动阶段如何区分主核、从核的? 在runtime阶段,是否能把不同核的CPU Data加以区分?是否能区分出cpuid? runtime阶段:主核和从核的区分 在启动阶段,会读取平台函数plat_is_my_cpu_primary来判单,当前是…

java 面向对象-上:类的结构之二

类的设计中,两个重要结构之二:方法 方法 描述类应该具的功能。 比如:Math类:sqrt()\random() \... Scanner类:nextXxx() ... Arrays类:sort() \ binarySearch() \ toString() \ equals() \ ... 1.举例 p…

哈希-数组中两数之和

文章目录 题目解题思路具体步骤 题目 在编程和算法学习中,"两数之和"问题是一个非常经典的问题,它不仅测试了基本的编程能力,还涉及到数据结构的有效使用,特别是哈希表。问题的描述是这样的: 题目&#xff…

1110. 删点成林

1110. 删点成林 关键要点 通过O(1)时间复杂度确认节点是否需要删除 Set to_deleteSet new HashSet<>(); Arrays.stream(to_delete).forEach(to_deleteSet::add); 使用深度优先搜索&#xff08;DFS&#xff09;遍历树 node.left dfs(node.left, s, ans); node.right …

Orange3数据预处理(列选择组件)数据角色及类型描述

在Orange3的文件组件中&#xff0c;datetime、categorical、numeric以及text代表不同种类的数据类型&#xff0c;具体如下&#xff1a; datetime&#xff1a;代表日期和时间类型的数据。通常用于时间序列分析、生存分析和其他需要考虑时间因素的机器学习任务中。例如&#xff0…

英语连读技巧15

1. first one – 第一个 连读听起来就像是&#xff1a;【佛斯湾】 连读的音标为&#xff1a; 例句&#xff1a;I don’t want to be the first one there agin. 发音指导&#xff1a;在“first one”的连读中&#xff0c;"t"和"o"之间的连接几乎消失&a…