YOLOv8 | 有效涨点,添加GAM注意力机制,使用Wise-IoU有效提升目标检测效果(附报错解决技巧,全网独家)

 目录

摘要

基本原理

通道注意力机制

空间注意力机制

GAM代码实现 

Wise-IoU 

WIoU代码实现

yaml文件编写

完整代码分享(含多种注意力机制)


摘要

人们已经研究了各种注意力机制来提高各种计算机视觉任务的性能。然而,现有方法忽视了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息减少和放大全局交互表示来提高深度神经网络的性能。引入了具有多层感知器的 3D 排列,用于通道注意以及卷积空间注意子模块。在 CIFAR-100 和 ImageNet-1K 上对所提出的图像分类任务机制的评估表明,我们的方法稳定优于最近使用 ResNet 和轻量级 MobileNet 的几种注意力机制。

基本原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图 所示。

GAM结构图
通道注意力机制

通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块
空间注意力机制

在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
GAM代码实现 
class GAM_Attention(nn.Module):
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAM_Attention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )

        self.spatial_attention = nn.Sequential(

            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        # x_channel_att=channel_shuffle(x_channel_att,4) #last shuffle
        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        # out=channel_shuffle(out,4) #last shuffle
        return out

以上代码添加在 ./ultralytics/nn/modules/conv.py 中

Wise-IoU 

Yolov7提出的损失函数是GIoU(Generalized Intersection over Union),能在更广义的层面上计算IoU(Intersection over Union),但是当两个预测框完全重合时,不能反映出实际情况,此时GIoU就要退化为IoU,并且GIoU对每个预测框与真实框均要计算最小外接框,故损失函数计算及收敛速度受到限制。
为了弥补这种遗憾,改进的网络中使用了WIoU(Wise-IoU)作为损失函数。WIoU v3作为边界框回归损失,包含一种动态非单调机制,并设计了一种合理的梯度增益分配,该策略减少了极端样本中出现的大梯度或有害梯度。该损失方法计算更多地关注普通质量的样本,进而提高网络模型的泛化能力和整体性能。

虽然几种主流损失函数都采用静态聚焦机制,但WIoU不仅考虑了方位角、质心距离和重叠面积,还引入了动态非单调聚焦机制。 WIoU应用合理的梯度增益分配策略来评估锚框的质量。WIoU有三个版本。 WIoU v1 设计了基于注意力的预测框损失,WIoU v2 和 WIoU v3 添加了聚焦系数。

wiou原理图

最小的包围盒(绿色)和中心点的连接(红色),其中并集的面积为 Su = wh + wgthgt − WiHi .

WIoU代码实现
def WIoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        dist = torch.exp(self.l2_center / self.l2_box.detach())
        return self._scaled_loss(dist * self.iou)

 下面的代码替换loss.py的class BboxLoss

class BboxLoss(nn.Module):

    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')
        loss_iou=loss.sum()/target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl
yaml文件编写
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, GAM_Attention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13
  #- [-1, 1, GAM_Attention, [512,512]]

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
  #- [-1, 1, GAM_Attention, [256,256]]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  #- [-1, 1, GAM_Attention, [512,512]]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
  #- [-1, 1, GAM_Attention, [1024,1024]]

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
完整代码分享(含多种注意力机制)

内涵SA,CBAM,GAM,ECA等多种注意力机制

链接: https://pan.baidu.com/s/1T9bVifTPCRMv2t7eREsuEw?pwd=nbrt 提取码: nbrt 

报错解决办法

YOLOv8 | 添加注意力机制报错KeyError:已解决,详细步骤-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/460041.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Paimon新版本核心特性和生产实践解读

最近Apche Paimon发布了最新版本0.7.0,在这个版本中,Paimon对一些新特性进行了增强。 Paimon在数据湖领域发展迅速,未来会在整个数据开发领域占有很重要的地位,今天我们来盘点一下当前能力的特点以及在生产环境中的使用情况。 Loo…

【C++】手撕AVL树

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:能直接手撕AVL树。 > 毒鸡汤:放弃自…

react-native使用FireBase实现google登陆

一、前置操作 首先下载这个包 yarn add react-native-google-signin/google-signin 二、Google cloud配置 Google Cloud 去google控制台新建一个android项目,这时候需要用到你自己创建的keystore的sha1值,然后会让你下载一个JSON文件,先保…

【Linux进阶之路】HTTPS = HTTP + S

文章目录 一、概念铺垫1.Session ID2.明文与密文3.公钥与私钥4.HTTPS结构 二、加密方式1. 对称加密2.非对称加密3.CA证书 总结尾序 一、概念铺垫 1.Session ID Session ID,即会话ID,用于标识客户端与服务端的唯一特定会话的标识符。会话,即客…

某鱼弹幕逆向

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,不提供完整代码,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!wx a15018…

Delft3D建模、水动力模拟方法及在地表水环境影响评价中的技术应用

​任博士,长期从事地表水数值模拟研究与实践工作,具有资深的技术底蕴和专业背景。 1、掌握Delft3D的建模流程,包括基础数据的准备、计算网格的制作、模型的调试与率定、计算结果的处理等,熟悉软件的基本操作。 2、熟悉Delft3D网…

java---网络初始

一.局域网和广域网 随着时代的发展,越来越需要计算机之间互相通信,共享软件和数据,即以多个计算机协同工作来完成业务,就有了网络互连。 网络互连:将多台计算机连接在一起,完成数据共享。数据共享本质是网…

开发反应式API

开发反应式API 开发反应式API1 使用SpringWebFlux1.1 Spring WebFlux 简介1.2 编写反应式控制器 2 定义函数式请求处理器3 测试反应式控制器3.1 测试 GET 请求3.2 测试 POST 请求3.3 使用实时服务器进行测试 4 反应式消费RESTAPI4.1 获取资源4.2 发送资源4.3 删除资源4.4 处理错…

基于springboot+vue实现养老服务管理系统项目【项目源码+论文说明】计算机毕业设计

基于springbootvue实现养老服务管理系统演示 摘要 医疗水平和生活水平的不断提高造就了我们现在稳定、发展的社会,带来受益的同时也加重了人口老龄化程度。随着人口老龄化程度的不断加深,越来越多的社会资源在对养老方面注入。那么面对如此快速发展的养…

Go微服务实战——服务的注册与获取(nacos做服务注册中心)

背景 随着访问量的逐渐增大,单体应用结构渐渐不满足需求,在微服务出现之后引用被拆分为一个个的服务,服务之间可以互相访问。初期服务之间的调用只要知道服务地址和端口即可,而服务会出现增减、故障、升级等变化导致端口和ip也变…

在OpenStack架构中,Controller节点的配置(基础)

虚拟机的安装 新建虚拟机,选择自定义 默认选择即可 操作系统的镜像稍后选择 客户及操作系统选择Linux,注意选择centos 7 64位 给虚拟机命名 处理器的配置建议1:2 内存大小选择建议为:4GB 网络连接选择为:NAT 默认即可…

蓝桥杯2022年第十三届省赛真题-灭鼠先锋

LLLV solution1 必输:只有一个格子 手算可以模拟出来~ solution2 OOOO状态下,谁先下谁必输 》问题转化为谁先下满第一排,谁必赢,可以非常容易的模拟出来

Vue.js+SpringBoot开发天沐瑜伽馆管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 瑜伽课程模块2.3 课程预约模块2.4 系统公告模块2.5 课程评价模块2.6 瑜伽器械模块 三、系统设计3.1 实体类设计3.1.1 瑜伽课程3.1.2 瑜伽课程预约3.1.3 系统公告3.1.4 瑜伽课程评价 3.2 数据库设计3.2.…

uniapp实现点击标签文本域中显示标签内容

先上一个效果图 实现的效果有: ①.点击标签时,标签改变颜色并处于可删除状态 ②.切换标签,文本域中出现标签的内容 ③.点击标签右上角的删除可删掉标签,同时清除文本域中标签的内容 ④.可输入内容,切换时不影响输入…

考研C语言复习进阶(5)

目录 1. 为什么使用文件 2. 什么是文件 2.1 程序文件 2.2 数据文件 2.3 文件名 3. 文件的打开和关闭 3.1 文件指针 3.2 文件的打开和关闭 4. 文件的顺序读写 ​编辑 ​编辑 4.1 对比一组函数: ​编辑 5. 文件的随机读写 5.1 fseek 5.2 ftell 5.3 rewind…

tomcat中把项目放在任意目录中的步骤

java web 项目由idea开发&#xff0c;路径如下图所示&#xff1a; 1.在tomcat安装目录conf\Catalina\localhost 里面&#xff0c;编写lesson1.xml文件内容如下&#xff1a; <Context path"/lesson1" docBase"C:\Users\信息技术系\Desktop\2024\学校工作\jav…

基于51单片机的微波炉温度控制器设计[proteus仿真]

基于51单片机的微波炉温度控制器设计[proteus仿真] 温度检测系统这个题目算是课程设计和毕业设计中常见的题目了&#xff0c;本期是一个基于51单片机的微波炉温度控制器设计 需要的源文件和程序的小伙伴可以关注公众号【阿目分享嵌入式】&#xff0c;赞赏任意文章 2&#xff…

【矩阵】240. 搜索二维矩阵 II【中等】

搜索二维矩阵 II 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性&#xff1a;每行的元素从左到右升序排列。每列的元素从上到下升序排列。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,4,7,11,15],[2,5,8,12,19],[3,6,9,16,22…

C++:2024/3/12

作业1&#xff1a;试编程&#xff0c;封装一个类 要求&#xff1a;自己封装一个矩形类(Rect)&#xff0c;拥有私有属性:宽度(width)、高度(height)&#xff0c; 定义公有成员函数: 初始化函数:void init(int w, int h) 更改宽度的函数:set_w(int w) 更改高度的函数:set_h(…

如何打造知识管理平台,只需了解这几点

随着企业的发展&#xff0c;知识资源日益丰富和复杂&#xff0c;如果不加以有效管理和整合&#xff0c;这些知识很可能会被埋没或丢失。打造知识管理平台可以将这些知识资源进行统一存储和分类&#xff0c;便于员工查找和使用&#xff0c;从而充分发挥知识的价值。有很多工具可…