YOLOv8改进算法之添加CA注意力机制

1. CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:

4cf2f4e3554f46929a6bbb5db3e79108.png

1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:

  • C 是通道数,表示特征图中的不同特征通道。
  • H 是高度,表示特征图的垂直维度。
  • W 是宽度,表示特征图的水平维度。

2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:

  • 在宽度方向上的平均池化得到特征映射 [C, H, 1]
  • 在高度方向上的平均池化得到特征映射 [C, 1, W]

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。

3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。

4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。

5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:

  • 一个分支得到特征层 [C, 1, H]
  • 另一个分支得到特征层 [C, 1, W]

6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1][C, 1, W]

7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。

8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。

 

2. YOLOv8添加CA注意力机制

加入注意力机制,在ultralytics包中的nn包的modules里添加CA注意力模块,我这里选择在conv.py文件中添加CA注意力机制。

4334c4e3f5ae4109b59163920b5157a9.png

CA注意力机制代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

CA注意力机制的注册和引用如下:

 ultralytics/nn/modules/_init_.py文件中:

b5965c8994e84095842122c96556e16b.jpeg

  ultralytics/nn/tasks.py文件夹中:

823386aa86634340aff6aa1701b3507e.png

 在tasks.py中的parse_model中添加如下代码:

        elif m in {CoordAtt}:
            args=[ch[f],*args]

新建相应的yolov8s-CA.yaml文件,代码如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 8], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 15], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在main.py文件中进行训练:

if __name__ == '__main__':

    # 使用yaml配置文件来创建模型,并导入预训练权重.
    model = YOLO('ultralytics/cfg/models/v8/yolov8s-CA.yaml')
    # model.load('yolov8n.pt')
    model.train(**{'cfg': 'ultralytics/cfg/default.yaml', 'data': 'dataset/data.yaml'})

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/467464.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

敏捷开发——elementUI/Vue使用/服务器部署

1. 创建vue项目 2. 安装element-ui组件库 npm i -S element-ui或 npm install element-ui3. 在main.js中导入element-ui组件 import ElementUI from element-ui import element-ui/lib/theme-chalk/index.css Vue.use(ElementUI)4. 运行 npm run serve后可以使用 ctrc终止进…

一个 Java8 的坑坑了我 2 小时试错...

背景 趁着失业的间隙想要重温一下Flink相关的学习,当前一切就绪之后,想要用我的 mac运行一个 flink版 helloworld 来验证整体环境是否 OK的时候出现了如下问题,这个问题我未曾遇到过,如下: Failed to write core dump…

【iOS】Blocks

文章目录 前言一、什么是Blocks二、Blocks模式1.Block语法2.Block类型变量3.截获自动变量值4.__block说明符5.截获的自动变量 三、Blocks的实现1.Block的实质__main_block_impl_0Block对象的实现结构体初始化 2.截获自动变量值3.__block说明符4.Block存储域5.__block变量存储域…

LM studio使用gemmar聊天小试

通过LM studio可以方便的使用各种模型,使用LM提供的chat界面或者是使用python代码。 试试代码 在windows下使用python简单一试,例子直接复制LM界面上的代码: 用pip安装 openai包在LM界面 Start Server 需要安装 openai包。 本地电脑是I7…

ArcGIS巧思制作3D景观地图

John Nelson 又制作了一个制图教程视频,我原以为只是一个简单的局部场景DEM夸张实现的3D地图。 不过细看以后…… 还就是比较简单的3D场景地图,操作不难,但是 John Nelson 就是天才。 为什么? 他使用 ArcGIS Pro,在普通的3D地图中,不仅仅是图层混合制作地形效果,还巧妙的…

GPT实战系列-LangChain的Prompt提示模版构建

GPT实战系列-LangChain的Prompt提示模版构建 LangChain GPT实战系列-LangChain如何构建基通义千问的多工具链 GPT实战系列-构建多参数的自定义LangChain工具 GPT实战系列-通过Basetool构建自定义LangChain工具方法 GPT实战系列-一种构建LangChain自定义Tool工具的简单方法…

LLM 构建Data Multi-Agents 赋能数据分析平台的实践之②:数据治理之一

概述 数据治理不仅是产业数字化转型的基石,更是推动产业向更高层次、更精细化、更智能的方向发展的重要引擎。通过科学有效的数据治理实践,产业能够在数字化进程中实现数据驱动的决策与行动,最终达到转型升级的战略目标。 一、数据治理在产业…

800万像素车载摄像头的一些思考

1. 800万像素摄像头与算力、算法以及数据的关系 随着800万像素摄像头在2021款理想One上首次量产应用,800万像素摄像头的议论热潮 再次兴起。有一个话题大家普遍很关注,那就是800万像素摄像头与算力、算法以及数据之 间的关系, 例如&#xf…

深入理解快速排序

一、快速排序 快速排序是冒泡排序的一种改进算法,相比于冒泡排序效率更优。 算法过程分析: 通过采用分治策略,围绕一个 x 将原始数组划分为两个子数组,使得前一个子数组的元素≤ x ≤ 后一个子数组元素,对两个子数组进…

体验OceanBase OBD V2.5.0 组件内扩容和组件变更

背景 OBD 是OceanBase的命令行部署工具,在 obd V2.5.0 版本之前,其主要功能主要是部署各类组件,例如 oceanbase-ce,obproxy-ce,obagent 等。然而,它并不支持组件的变更操作以及组件内部的扩缩容调整。具体来说: 1、若…

#每天一道面试题# 什么是MySQL的回表查询

MySQL中的索引按照物理存储的方式分为聚集索引和非聚集索引; 聚集索引索引和数据存储在一起,B树的叶子节点就是表数据,如果通过聚集索引查询数据,直接就可以查询出我们想要的数据;非聚集索引B树的叶子节点存储的是主键…

流畅的python--小技巧总结

对于python菜鸟来说,只看基本教程后的结果就是看是看过了,但依然不会用,遇事先百度; 此文整理了一些python区别于js的一些小技巧(鄙人前端学py),可以快速高效实现功能,当个笔记&…

【嵌入式学习收徒,高薪offer等你来!!!】

有粉丝问了一个问题,说他今年要毕业了,投了好多简历都石沉大海,感觉好多公司都不招人了,想问一下现在究竟是不是如此,不清楚我当年毕业的时候是怎么样的。 我先不直接回答这个问题,先来看一组数据&#xf…

Day 1.数据结构----单向链表(无头单向链表)

数据结构 如何组织存储数据 程序 数据结构 算法 MVC:软件设计结构 M:数据的管理(数据结构) V:视图,数据的反映及人机交互 C:逻辑控制 单向链表 有头链表:第一个链表结点中…

山景BP1048 升级狗烧写

1.打开MVAssistant_BP10xx工具,在芯片型号栏中选择B1X系列。 2.模式选择 选 M2.仅升级Flash SH(可选) 3 .Code数据选择SDK编译好的bin文件 4.const数据选择编译好的提示音bin文件。 5.点击升级狗下载。 6. 如下图所示,出现提示为正在给升级狗正在下载程…

Machine Learning ---- Feature Scaling

目录 一、What is feature scaling:: 二、Why do we need to perform feature scaling? 三、How to perform feature scaling: 1、Normalization: 2、Mean normalization: 3、Standardization (data needs to follow a normal distribution): 一、What is featur…

salesforce生产环境如何删除触发器

由于生产环境不能直接删除触发器,所以需要在sandbox中先让触发器inactive再部署到生产环境,就可以让触发器失效了。

人物百度百科如何创建?人物类词条编辑指南

创建人物百度百科是一项既具有挑战性的工作。下面,伯乐网络传媒就来给大家详细介绍如何创建人物百度百科,包括准备工作、创建步骤以及常见问题解答。 一、创建人物百度百科的准备工作 1. 人物百科词条创建要求 百度百科对创建人物词条有一定的要求&…

谷歌google adsense广告申请提示:网站已下线或无法访问

自己在运营网站时,想在网站上挂google adsense广告,但是申请很多次,收到的邮件都是您需要先纠正一些问,登陆google adsense后台显示,网站已下线或无法访问。 重新申请多次问题依旧,我在想为什么国外无法访…

Python命名空间和作用域,让你的代码逻辑更清晰!

关于Python,我们前面的基础部分,基本也说完了,包括我们也讲了高阶特性,面向对象编程。现在我来补充一个知识:命名空间和作用域。 这是Python两个重要的概念,它们决定了变量的可见性和访问范围。理解命名空…