YOLOv8改进 | 2023注意力篇 | EMAttention注意力机制(附多个可添加位置)

一、本文介绍

本文给大家带来的改进机制是EMAttention注意力机制,它的核心思想是,重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的关系。本文首先给大家提供效果图(由基础版本未作任何修改和修改了本文的改进机制的效果对比图),然后介绍其主要的原理,最后手把手教大家如何添加该注意力机制

推荐指数:⭐⭐⭐⭐⭐

涨点效果:⭐⭐⭐⭐⭐

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备    

训练结果对比图-> 

目录

一、本文介绍

二、EMAttention的框架原理

三、EMAttention的核心代码 

四、手把手教你添加EMAttention

4.1 EMAttention添加步骤

4.1.1 步骤一

4.1.2 步骤二

4.1.3 步骤三

五、EMAttention的yaml文件和运行记录

5.1 EMAttention的yaml版本一(推荐)

5.2 EMAttention的yaml版本二

5.3 推荐EMAttention可添加的位置 

5.4 EMAttention的训练过程截图 

五、本文总结


二、EMAttention的框架原理

官方论文地址: 官方论文地址

官方代码地址: 官方代码地址


主要原理是一个新型的高效多尺度注意力(EMA)这个模块通过重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的关系。

提出的创新点主要包括:

1. 高效多尺度注意力(EMA)模:这是一种新型的注意力机制,专为计算机视觉任务设计,旨在同时减少计算开销和保留每个通道的关键信息。

2. 通道和批次维度的重组:EMA通过重新组织通道维度和批次维度,提高了模型处理特征的能力。

3. 跨维度交互:模块利用跨维度的交互来捕捉像素级别的关系,这在传统的注意力模型中较为少见。

4. 全局信息编码和通道权重校准:EMA模块在并行分支中编码全局信息,用于通道权重的重新校准,增强了特征表示的能力。

这张图片是文章中提出的高效多尺度注意力(EMA)模块的示意图。"g"表示输入通道被分成的组数。"X Avg Pool"和"Y Avg Pool"分别代表一维水平和垂直的全局池化操作。在EMA模块中,输入首先被分组,然后通过不同的分支进行处理:一个分支进行一维全局池化,另一个通过3x3的卷积进行特征提取。两个分支的输出特征之后通过sigmoid函数和归一化操作进行调制,最终通过跨维度交互模块合并,以捕捉像素级的成对关系。经过最终的sigmoid调节后,输出特征映射以增强或减弱原始输入特征,从而得到最终输出。


三、EMAttention的核心代码 

使用方法看章节四。

import torch
from torch import nn

class EMA(nn.Module):
    def __init__(self, channels, factor=32):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)


四、手把手教你添加EMAttention

4.1 EMAttention添加步骤

4.1.1 步骤一

首先我们找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个py文件,名字可以根据你自己的习惯起,然后将EMAttention的核心代码复制进去。

4.1.2 步骤二

之后我们找到'ultralytics/nn/tasks.py'文件,在其中注册我们的EMAttention模块。

首先我们需要在文件的开头导入我们的EMAttention模块,如下图所示->

4.1.3 步骤三

我们找到parse_model这个方法,可以用搜索也可以自己手动找,大概在六百多行吧。 我们找到如下的地方,然后将EMA添加进去即可,模仿我添加即可。

到此我们就注册成功了,可以修改yaml文件中输入EMAttention使用这个模块了。


五、EMAttention的yaml文件和运行记录

下面推荐几个版本的yaml文件给大家,大家可以复制进行训练,但是组合用很多具体那种最有效果都不一定,针对不同的数据集效果也不一样,我不可每一种都做实验,所以我下面推荐了几种我自己认为可能有效果的配合方式,你也可以自己进行组合。


5.1 EMAttention的yaml版本一(推荐)

下面的添加EMAttention是我实验结果的版本。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 1, EMA, []]  # 16

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 1, EMA, []]  # 20

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 1, EMA, []]  # 24

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)


5.2 EMAttention的yaml版本二

添加的版本二具体那种适合你需要大家自己多做实验来尝试。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)


  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)


  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 1, EMA, []]  # 22

  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

5.3 推荐EMAttention可添加的位置 

EMAttention是一种即插即用的可替换卷积的模块,其可以添加的位置有很多,添加的位置不同效果也不同,所以我下面推荐几个添加的位,置大家可以进行参考,当然不一定要按照我推荐的地方添加。

  1. 残差连接中:在残差网络的残差连接中加入EMAttention

  2. Neck部分:YOLOv8的Neck部分负责特征融合,这里添加EMAttention可以帮助模型更有效地融合不同层次的特征(yaml文件一和二)

  3. Backbone:可以替换中干网络中的卷积部分

  4. 能添加的位置很多,一篇文章很难全部介绍到,后期我会发文件里面集成上百种的改进机制,然后还有许多融合模块,给大家。


5.4 EMAttention的训练过程截图 

下面是添加了EMAttention的训练截图。

大家可以看下面的运行结果和添加的位置所以不存在我发的代码不全或者运行不了的问题大家有问题也可以在评论区评论我看到都会为大家解答(我知道的)。


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/280877.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

竞赛保研 基于机器学习与大数据的糖尿病预测

文章目录 1 前言1 课题背景2 数据导入处理3 数据可视化分析4 特征选择4.1 通过相关性进行筛选4.2 多重共线性4.3 RFE(递归特征消除法)4.4 正则化 5 机器学习模型建立与评价5.1 评价方式的选择5.2 模型的建立与评价5.3 模型参数调优5.4 将调参过后的模型重…

Peter算法小课堂—浮点数危机

大家先想想下面这个代码运行结果&#xff1a; #include <bits/stdc.h> using namespace std; int main(){double x5.2;double y4.11.1;cout<<(x<y)<<endl;cout<<x-y<<endl;return 0; } 最终发现&#xff0c; &#xff1f;&#xff1f;&…

【数据结构】八、查找

一、基本概念 静态查找&#xff1a;只查找&#xff0c;不改变集合内数据元素 动态查找&#xff1a;有则输出元素&#xff0c;无则添加元素 二、静态查找表 2.1顺序查找 在线性表、链表、树中依次查找 2.2折半查找&#xff08;二分查找&#xff09; 在有序的线性表中&…

条件编译处理多端差异

条件编译https://uniapp.dcloud.net.cn/tutorial/platform.html#%E4%B8%BA%E4%BB%80%E4%B9%88%E9%80%89%E6%8B%A9%E6%9D%A1%E4%BB%B6%E7%BC%96%E8%AF%91%E5%A4%84%E7%90%86%E8%B7%A8%E7%AB%AF%E5%85%BC%E5%AE%B9 <template><view class"container"><…

分类模型评估方法

1.数据集划分 1.1 为什么要划分数据集? 思考&#xff1a;我们有以下场景&#xff1a; 将所有的数据都作为训练数据&#xff0c;训练出一个模型直接上线预测 每当得到一个新的数据&#xff0c;则计算新数据到训练数据的距离&#xff0c;预测得到新数据的类别 存在问题&…

【滑动窗口】C++算法:可见点的最大数目

作者推荐 动态规划 多源路径 字典树 LeetCode2977:转换字符串的最小成本 本题涉及知识点 滑动窗口 LeetCode 1610可见点的最大数目 给你一个点数组 points 和一个表示角度的整数 angle &#xff0c;你的位置是 location &#xff0c;其中 location [posx, posy] 且 point…

【MySQL】事务Transaction

1. 事务的概念 事务是什么 在业务逻辑中使用sql&#xff0c;面对一些较复杂的场景&#xff0c;是需要多个sql语句组合起来实现的。如&#xff1a;银行的转账业务&#xff0c;若客户A要转账100元给客户B&#xff0c;就要两条sql&#xff1a;A余额减100&#xff0c;B余额加100&a…

react-router-dom5升级到6

前言 升级前版本为5.1.2 下载与运行 下载 npm install react-router-dom6运行 运行发现报错: 将node_modules删除&#xff0c;重新执行npm i即可 运行发现如下报错 这是因为之前有引用react-router-dom.min&#xff0c;v6中取消了该文件&#xff0c;所以未找到文件导致报错。…

浅谈数字孪生的应用与发展

1、数字孪生概念 ”数字孪生是充分利用物理模型、传感器更新、运行历史等数据,集成多学科、多物理量、多尺度、多概率的仿真过程,在虚拟空间中完成映射,从而反映相对应的实体装备的全生命周期过程。数字孪生是一种超越现实的概念,可以被视为一个或多个重要的、彼此依赖的装…

Kubernetes集群部署Rook Ceph实现文件存储,对象存储,块存储

Kubernetes集群部署Rook Ceph部署Ceph集群 1. Rook Ceph介绍 Rook Ceph是Rook项目中的一个存储方案&#xff0c;专门针对Ceph存储系统进行了优化和封装。Ceph是一个高度可扩展的分布式存储系统&#xff0c;提供了对象存储、块存储和文件系统的功能&#xff0c;广泛应用于提供…

Spring Data Redis对象缓存序列化问题

相信在项目中&#xff0c;你一定是经常使用 Redis &#xff0c;那么&#xff0c;你是怎么使用的呢&#xff1f;在使用时&#xff0c;有没有遇到同我一样&#xff0c;对象缓存序列化问题的呢&#xff1f;那么&#xff0c;你又是如何解决的呢&#xff1f; Redis 使用示例 添加依…

Stable Diffusion WebUI制作光影文字效果

在huggingface上下载control_v1p_sd15_brightness模型。 将模型放在stable-diffusion-webui\extensions\sd-webui-controlnet\models目录下。 SD参数配置 正向提示词&#xff1a; city,Building,tall building,Neon Light, gentle light shines through, anime style, paint…

AI模型训练【偏差/方差】与【欠拟合/过拟合】

在我们拿到一个数据集&#xff0c;高高兴兴准备训练一个模型时&#xff0c;会遇到欠拟合或过拟合的问题&#xff0c;业内也喜欢用偏差和方差这两指标去定义它们&#xff0c;那这些词什么意思呢&#xff1f;有什么方法能避免/解决 欠拟合和过拟合呢&#xff1f; 这其实是非常非常…

【测试基础】构造测试数据之 MySQL 篇

构造测试数据之 MySQL 篇 作为一名测试工程师&#xff0c;我们经常会构造测试数据进行一些功能验证。为了暴露更多的问题&#xff0c;在测试数据的构造上&#xff0c;我们应该尽可能的构造不同类型的字段数据&#xff0c;且一张表的字段最好不低于 10 10 10 个。 对于 MySQL …

UDP信号多个电脑的信息传输测试、配置指南

最近要做一个东西&#xff0c;关于一个软件上得到的信号&#xff0c;如何通过连接的局域网&#xff0c;将数据传输出去。我没做过相关的东西&#xff0c;但是我想应该和软件连接数据库的过程大致是差不多的&#xff0c;就一个ip和一个端口号啥的。 一.问题思路 多个设备同时连…

自动化测试系列 之 Python单元测试框架unittest

一、概述 什么是单元测试 单元测试是一种软件测试方法&#xff0c;是测试最小的可测试单元&#xff0c;通常是一个函数或一个方法。 在软件开发过程中&#xff0c;单元测试作为一项重要的测试方法被广泛应用。 为什么需要单元测试 单元测试是软件开发中重要的一环&#xf…

微服务系列之分布式事务理论

概述 事务是由一组操作构成的可靠的独立的工作单元&#xff0c;事务具备ACID的特性&#xff0c;即原子性、一致性、隔离性和持久性。 分类 大多数情况下&#xff0c;分类是没有意义的一件事。但是分类可以一定程度上&#xff0c;加深理解。 实现 从实现角度来看&#xff0…

c语言函数篇——递归函数

递归函数的工作原理 递归函数的工作原理基于两个主要部分&#xff1a;基本情况和递归情况。基本情况是函数不再调用自身的条件&#xff0c;当达到基本情况时&#xff0c;递归停止并返回结果。递归情况是函数调用自身的部分&#xff0c;它将问题分解为更小的、相似的子问题。 …

【Matlab】基于遗传算法优化BP神经网络 (GA-BP)的数据时序预测

资源下载&#xff1a; https://download.csdn.net/download/vvoennvv/88682033 一&#xff0c;概述 基于遗传算法优化BP神经网络 (GA-BP) 的数据时序预测是一种常用的机器学习方法&#xff0c;用于预测时间序列数据的趋势和未来值。 在使用这种方法之前&#xff0c;需要将时间序…

微信小程序开发系列-07组件

微信小程序开发系列目录 《微信小程序开发系列-01创建一个最小的小程序项目》《微信小程序开发系列-02注册小程序》《微信小程序开发系列-03全局配置中的“window”和“tabBar”》《微信小程序开发系列-04获取用户图像和昵称》《微信小程序开发系列-05登录小程序》《微信小程序…