免费阅读篇 | 芒果YOLOv8改进113:注意力机制ShuffleAttention:深度卷积神经网络的随机注意力

💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可

该专栏完整目录链接: 芒果YOLOv8深度改进教程

该篇博客为免费阅读内容,YOLOv8+ShuffleAttention改进内容🚀🚀🚀

文章目录

      • 1. ShuffleAttention 论文
      • 2. YOLOv8 核心代码改进部分
      • 2.1 核心新增代码
        • 2.2 修改部分
      • 2.3 YOLOv8-SA 网络配置文件
      • 2.4 运行代码
      • 改进说明


1. ShuffleAttention 论文

请添加图片描述
注意力机制使神经网络能够准确地关注输入的所有相关元素,已成为提高深度神经网络性能的重要组成部分。计算机视觉研究中广泛使用的注意力机制主要有两种:空间注意力和通道注意力,其目的分别是捕获像素级的成对关系和通道依赖性。虽然将它们融合在一起可能会比它们单独的实现获得更好的性能,但它不可避免地会增加计算开销。在本文中,我们提出了一种高效的洗牌注意力(SA)模块来解决这个问题,它采用洗牌单元来有效地结合两种类型的注意机制。具体来说,SA 首先将通道维度分组为多个子特征,然后并行处理它们。然后,对于每个子特征,SA 利用洗牌单元来描述空间和通道维度上的特征依赖性。之后,所有子特征被聚合,并采用“通道洗牌”算子来实现不同子特征之间的信息通信。所提出的 SA 模块高效且有效,例如,SA 针对主干 ResNet50 的参数和计算量分别为 300 vs. 25.56M 和 2.76e-3 GFLOPs vs. 4.12 GFLOPs,并且性能提升超过 1.34% Top-1 准确度方面。对常用基准(包括用于分类的 ImageNet-1k、用于对象检测的 MS COCO 和实例分割)的大量实验结果表明,所提出的 SA 通过实现更高的准确度和更低的模型复杂度,显着优于当前的 SOTA 方法
在这里插入图片描述

具体细节可以去看原论文:https://arxiv.org/pdf/2102.00240.pdf


2. YOLOv8 核心代码改进部分

2.1 核心新增代码

首先在ultralytics/nn/modules文件夹下,创建一个 sa.py文件,新增以下代码

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter

# https://arxiv.org/pdf/2102.00240.pdf
class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, out_channel=512, reduction=16,G=8):
        super().__init__()
        self.G=G
        self.channel=channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid=nn.Sigmoid()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        #group into subfeatures
        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

        #channel_split
        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

        #channel attention
        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
        x_channel=x_0*self.sigmoid(x_channel)

        #spatial attention
        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w
        out=out.contiguous().view(b,-1,h,w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out
2.2 修改部分

在ultralytics/nn/modules/init.py中导入 定义在 sa.py 里面的模块

from .sa import ShuffleAttention

'ShuffleAttention' 加到 __all__ = [...] 里面

第一步:
ultralytics/nn/tasks.py文件中,新增

from ultralytics.nn.modules import ShuffleAttention

然后在 在tasks.py中配置
找到

        elif m is nn.BatchNorm2d:
        args = [ch[f]]

在这句上面加一个

        elif m is ShuffleAttention:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]

2.3 YOLOv8-SA 网络配置文件

新增YOLOv8-SA.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, ShuffleAttention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 运行代码

直接替换YOLOv8-SA.yaml 进行训练即可

到这里就完成了这篇的改进。

改进说明

这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下


🥇🥇🥇
添加博主联系方式:

友好的读者可以添加博主QQ: 2434798737, 有空可以回答一些答疑和问题

🚀🚀🚀


参考

https://github.com/ultralytics/ultralytics

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/463435.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

矩阵中移动的最大次数

文章目录 所属专栏:BFS算法 题目链接 思路如下: 1.首先我们需要从第一列开始遍历,寻找每一个都能够满足条件的位置,将它插入到数组里面 2.第一列遍历完了后我们先判断第一列的数是否都满足条件插入到数组里面,如果数组为空&#…

JAVA13多行文本java14模式变量

文章目录 多行文本模式变量 多行文本 在JAVA13中,终于是支持多行文本字面量了。而且最关键的是为了源代码更加美观,还自动去掉了每行文本前面的空格。如下面的例子: public class MultilineStringDemo {public static void main(String[] ar…

在Linux/Ubuntu/Debian中使用windows应用程序/软件

Wine 是一个兼容层,允许你在类 Unix 操作系统(包括 Ubuntu)上运行 Windows 应用程序。 以下是在 Ubuntu 上安装和使用 Wine 的基本步骤: 在 Ubuntu 上安装 Wine: 更新软件包列表: 打开终端并运行以下命令以…

【Miniconda】一文了解conda虚拟环境的作用

【Miniconda】一文了解conda虚拟环境的作用 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅和支持~ &am…

macOS 安装 NetLogo 6.4.0

netlogo 下载地址 NetLogo-6.4.0.dmg参考 netlogo 官网

源码编译部署LAMP

编译部署LAMP 配置apache [rootzyq ~]#: wget https://downloads.apache.org/apr/apr-1.7.4.tar.gz --2023-12-11 14:35:57-- https://downloads.apache.org/apr/apr-1.7.4.tar.gz Resolving downloads.apache.org (downloads.apache.org)... 88.99.95.219, 135.181.214.104…

快递批量查询神器:一键解锁物流数据,高效管理轻松实现!

在这个信息爆炸的时代,拥有「多维度快递查询分析工具」,就意味着拥有了洞悉市场脉动的能力。让您的商业决策更加明智,让您的企业在竞争中脱颖而出。现在就加入我们的行列吧!一键掌握物流数据,洞悉市场脉动,…

npm下载慢换国内镜像地址

1 设置淘宝镜像地址 npm config set registry http://registry.npm.taobao.org 2 查看当前下载地址 npm config get registry 3 其它镜像地址列表: 1. 官方镜像:https://registry.npmjs.org/ 2. 淘宝镜像:https://registry.npm.taobao.o…

【渗透测试】

渗透测试 一、kali信息搜集1、端口扫描常见的端口服务工具: 2、目录扫描工具: 3、CMS系统指纹识别工具: 4、资产收集工具 二、kali漏洞利用1、漏洞数据库2、漏洞扫描工具商业漏洞扫描工具免费漏洞扫描工具国产商业漏洞扫描产品 3、msf反弹连接…

Django验证码(二)

一、生成图片 1.1、说明 通过pillow模板库生成图片,步骤如下 安装pillow模板建立 生成验证码内容 方法建立 生成验证码颜色 方法建立 生成验证码 方法1.2、需要安装 Pillow 库 pip install Pillow==9.3.01.3、生成验证码内容 import randomdef random_str(length=4):"…

HTTPS协议原理

文章目录 1. 什么是https2. 什么是加密3. 为什么要加密3. 常见的加密方式3.1 对称加密3.2 非对称加密 4. 数据指纹5. https工作过程5.1 加密方案5.2 引入证书5.3 完整流程 1. 什么是https 网络协议栈当中,传输层和网络层是在操作系统内部,应用层的请求是…

多站合一的音乐搜索下载助手PHP源码l亲测

源码获取方式 回复:031601 搭建教程: 将源码下载上传至宝塔面板,直接运行即可~ 说明: 该源码进行测试,测试成功源码无加密优化相关其他采集问题。

MySQL语法分类 DQL(1)基础查询

//语法 select 字段列表 from 表名列表 where条件列表 group by分组字段 having 分组后的条件 order by排序 limit 分页限定为了更好的学习这里给出基本表数据用于查询操作 create table student (id int, name varchar(20), age int, sex varchar(5),address varchar(100),ma…

白话模电:3.三极管(考研面试与笔试常考问题)

一、三极管的简单判断 1.判断三极 1)给了图 左边是b,有箭头是e,剩下是c 2)给了电位 b:中间值,e:较近值(离中间值),c:较远值(离中间值) 2.判断流向 bc同向(共同流向“|”或共同流离“|”),e与bc反向 3.判断材料 4.判断类型 5.判断能否构…

261:vue+openlayers 使用setRotation旋转地图

第261个 点击查看专栏目录 本示例介绍演示如何在vue+openlayers中使用setRotation旋转地图。setRotation是view的一个方法,旋转的内容是弧度,这里设置的角度需要将其换算为弧度,即 x*Math.PI/180. 直接复制下面的 vue+openlayers源代码,操作2分钟即可运行实现效果 文章目…

分享一篇Oracle RAC实战安装11G

分享一次很久以前的Oracle rac项目实施。 1、拓扑结构 基础环境是2台H3C的服务器2台3PAR的双活存储,操作系统centos7.2。借用下别人家的拓扑先(这是一套典型的RAC架构)。 2、网卡TEAM操作 以eno51和en052组成Team1组为示例: nm…

Tengine编译安装

首先下载源码 可以去官网:The Tengine Web Server 当前最新版地址:https://tengine.taobao.org/download/tengine-3.1.0.tar.gz 安装编译依赖 yum -y install gcc pcre-devel zlib-devel openssl-devel libxml2-devel \libxslt-devel gd-devel GeoIP…

如何在Windows系统搭建Emby影音平台并实现远程访问本地文件【内网穿透】

文章目录 1.前言2. Emby网站搭建2.1. Emby下载和安装2.2 Emby网页测试 3. 本地网页发布3.1 注册并安装cpolar内网穿透3.2 Cpolar云端设置3.3 Cpolar内网穿透本地设置 4.公网访问测试5.结语 1.前言 在现代五花八门的网络应用场景中,观看视频绝对是主力应用场景之一&…

AI - 支持向量机算法

🧨概念 支持向量机(Support Vector Machine, SVM)是一种强大的机器学习算法,主要用于解决二分类问题。 SVM的核心思想是找到一个超平面,这个超平面能够最好地将数据分为两类,即在保证分类准确的情况下&am…

android studio 连接mumu模拟器调试

1、打开mumu模拟器 2、在Android Studio 中 控制台 cd 到 sdk 目录下 platform-tools 文件夹,有一个adb.exe 可运行程序 一般指令: adb connect 127.0.0.1:7555 但是这个执行在window环境下可能会报错 解决方法是在 adb 之前加 ".\", 问题…