YOLOv11模型改进-模块-引入空间池化模块StripPooling 解决遮挡、小目标

         本篇文章将介绍一个新的改进机制——空间池化模块StripPooling,并阐述如何将其应用于YOLOv11中,显著提升模型性能。首先,我们将解析StripPooling的工作原理,SP模块通过条带池化在水平和垂直方向上捕捉长距离依赖关系,增强全局和局部特征表达。随后,本文将探讨如何将SP模块与YOLOv11相结合,以提升目标检测的性能。

1. 空间池化模块StripPooling (SP) 结构介绍    

        本文介绍了一种新的空间池化策略,称为Strip Pooling,用于场景解析。以下是其结构和作用的简要总结:

        Strip Pooling Module (SPM):包括两个平行路径,分别进行水平和垂直方向的条带池化,然后通过1D卷积层进行特征调整。然后将他们融合在一起经过1x1的卷积和sigmoid,形成权重。

        作用

        1. 捕捉长距离依赖:通过长条形的池化窗口,有效捕捉分布离散的区域之间的关系。

        2. 避免不相关信息干扰:窄条形的池化窗口有助于捕捉局部上下文,防止不相关区域干扰。     

2. YOLOv11与SP的结合

        YOLOv11作为YOLO系列的最新版本,继承了其前辈的高效性和准确性。然而,传统的YOLO模型在处理长距离依赖和复杂背景时仍存在一定的局限性。为了解决这一问题,我们引入了Strip Pooling (SP) 模块,该模块通过长条形的池化窗口,有效地捕捉全局上下文信息。

1.  SP模块集成到YOLOv11的C3K2模块模块中:YOLOv11的C3K2模块模块旨在通过多层卷积和池化操作提取图像中的多尺度特征。我们可以在C3K2模块的某些卷积层后插入SP模块,以优化特征信息的筛选。

2.  SP模块集成到YOLOv11的backbone中: 将SP添加到YOLOv11backbone中的SPPF模块之前,增强SPPF模块的全局上下文信息。

3. 空间池化模块StripPooling (SP) 代码部分

import torch
from torch import nn
import torch.nn.functional as F
# from conv import Conv
# from block import C2f, C3k

from .conv import Conv
from .block import C2f, C3, Bottleneck


class StripPooling(nn.Module):
    """
    Reference:(20, 12), nn.BatchNorm2d, {'mode': 'bilinear', 'align_corners': True}
    """
    def __init__(self, in_channels, pool_size= (20, 12), norm_layer=nn.BatchNorm2d, up_kwargs={'mode': 'bilinear', 'align_corners': True}):
        super(StripPooling, self).__init__()
        self.pool1 = nn.AdaptiveAvgPool2d(pool_size[0])
        self.pool2 = nn.AdaptiveAvgPool2d(pool_size[1])
        self.pool3 = nn.AdaptiveAvgPool2d((1, None))
        self.pool4 = nn.AdaptiveAvgPool2d((None, 1))

        inter_channels = int(in_channels/4)
        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False),
                                norm_layer(inter_channels),
                                nn.ReLU(True))
        self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False),
                                norm_layer(inter_channels),
                                nn.ReLU(True))
        self.conv2_0 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                norm_layer(inter_channels))
        self.conv2_1 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                norm_layer(inter_channels))
        self.conv2_2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                norm_layer(inter_channels))
        self.conv2_3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False),
                                norm_layer(inter_channels))
        self.conv2_4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False),
                                norm_layer(inter_channels))
        self.conv2_5 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                norm_layer(inter_channels),
                                nn.ReLU(True))
        self.conv2_6 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False),
                                norm_layer(inter_channels),
                                nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels*2, in_channels, 1, bias=False),
                                norm_layer(in_channels))
        # bilinear interpolate options
        self._up_kwargs = up_kwargs

    def forward(self, x):
        _, _, h, w = x.size()
        x1 = self.conv1_1(x)
        x2 = self.conv1_2(x)
        x2_1 = self.conv2_0(x1)
        x2_2 = F.interpolate(self.conv2_1(self.pool1(x1)), (h, w), **self._up_kwargs)
        x2_3 = F.interpolate(self.conv2_2(self.pool2(x1)), (h, w), **self._up_kwargs)
        x2_4 = F.interpolate(self.conv2_3(self.pool3(x2)), (h, w), **self._up_kwargs)
        x2_5 = F.interpolate(self.conv2_4(self.pool4(x2)), (h, w), **self._up_kwargs)
        x1 = self.conv2_5(F.relu_(x2_1 + x2_2 + x2_3))
        x2 = self.conv2_6(F.relu_(x2_5 + x2_4))
        out = self.conv3(torch.cat([x1, x2], dim=1))
        return F.relu_(x + out)


class Bottleneck_SP(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters."""
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = StripPooling(c_, (20, 12), nn.BatchNorm2d, {'mode': 'bilinear', 'align_corners': True})
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Applies the YOLO FPN to input data."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class C3k(C3):
    """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
        """Initializes the C3k module with specified channels, number of layers, and configurations."""
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
        self.m = nn.Sequential(*(Bottleneck_SP(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))

# 在c3k=True时,使用Bottleneck_SP特征融合,为false的时候我们使用普通的Bottleneck提取特征
class C3k2_SP(C2f):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
        """Initializes the C3k2 module, a faster CSP Bottleneck with 2 convolutions and optional C3k blocks."""
        super().__init__(c1, c2, n, shortcut, g, e)
        self.m = nn.ModuleList(
            C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in
            range(n)
        )


if __name__ =='__main__':

    SP = StripPooling(256, (20, 12), nn.BatchNorm2d, {'mode': 'bilinear', 'align_corners': True})
    #创建一个输入张量
    batch_size = 8
    input_tensor=torch.randn(batch_size, 256, 64, 64 )
    #运行模型并打印输入和输出的形状
    output_tensor =SP(input_tensor)
    print("Input shape:",input_tensor.shape)
    print("0utput shape:",output_tensor.shape)

 4. 将空间池化模块StripPooling 引入到YOLOv11中

第一: 将下面的核心代码复制到D:\bilibili\model\YOLO11\ultralytics-main\ultralytics\nn路径下,如下图所示。

第二:在task.py中导入StripPooling 包

第三:在task.py中的模型配置部分下面代码

 1. 第一个改进,直接在backbone中添加FRFN

elif m is StripPooling:
    args = [ch[f]]

2. 第二个改进,与C3k2模块结合生成C3k2_SP

第四:将模型配置文件复制到YOLOV11.YAMY文件中

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 2, StripPooling, []] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 2, StripPooling, []] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1, 2, StripPooling, [] ]# 19 (P4/16-medium)

  - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)



# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2_SP, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2_SP, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2_SP, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2_SP, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2_SP, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2_SP, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2_SP, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2_SP, [1024, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)

第五:运行成功


from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld

if __name__=="__main__":


    # 使用自己的YOLOv11.yamy文件搭建模型并加载预训练权重训练模型
    model = YOLO(r"D:\bilibili\model\YOLO11\ultralytics-main\ultralytics\cfg\models\11\yolo11_strippooling.yaml")\
        .load(r'D:\bilibili\model\YOLO11\ultralytics-main\yolo11n.pt')  # build from YAML and transfer weights

    results = model.train(data=r'D:\bilibili\model\ultralytics-main\ultralytics\cfg\datasets\VOC_my.yaml',
                          epochs=100, imgsz=640, batch=8)



 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/896009.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何在线查看近8年的建筑覆盖变化

我们在《谷歌发布建筑数据,高度误差达惊人的1.5米》一文中介绍了谷歌2.5D建筑数据用途、制作方法以及数据下载方式。 现在我们演示下如何在线查看近8年的建筑物覆盖、建筑物质心和建筑物高度的变化。 历史建筑覆盖在线查看 2.5D建筑演变数据集包含2016年至2023年…

碰一碰支付怎么推广的?小白也能看懂的教程来了!

作为支付宝和微信全面力推的一个项目,碰一碰支付的热度可谓是节节攀升,连带着与之相关的话题,如碰一碰支付是什么和碰一碰支付怎么推广的等也因此成为了人们关注的焦点。那么,本期,我们就来详细地解答一下这两个问题。…

2024台州赛CTFwp

备注: 解题过程中,关键步骤不可省略,不可含糊其辞、一笔带过。解题过程中如是自己编写的脚本,不可省略,不可截图(代码字体可以调小;而如果代码太长,则贴关键代码函数)。…

I.MX6U 字符设备驱动开发指南

目录 一、引言 二、字符设备驱动的基本概念 1.字符设备的定义 2.字符设备驱动的作用 三、I.MX6U 字符设备驱动开发步骤 1.确定设备信息 2.编写设备驱动代码 3.编译和加载驱动 4.测试设备驱动 四、实例分析 1.确定设备信息 2.编写设备驱动代码 3.编译和加载驱动 4.…

iOS 回到主线程刷新UI

在iOS 里面,项目打开就会运行一个主线程,所有的UI都在主线程里进行.其他网络请求或者耗时操作理论上也可以在主线程运行,但是如果太耗时,那么就会影响主线程其他UI.所以需要开字线程来进行耗时操作,子线程进行完耗时操作之后,如果项目需求有需要刷新UI,或者改变UI,一定得回到主…

音频/视频提取器:Python和moviepy实现

在这篇博客中,我们将深入探讨一个使用Python和wxPython构建的音频/视频提取器应用程序。这个应用程序允许用户从视频文件中提取音频,或者从音频文件中截取特定时间段。让我们逐步分析这个程序的功能和实现。 C:\pythoncode\new\MP3towav.py 全部代码 import wx import os imp…

「iOS」——YYModel学习

iOS学习 前言优势使用方法简单的Model与JSON互转多样化的数据类型交换容器类数据交换 model中包含其他model白名单与黑名单 总结 前言 YYModel是YYKit的高效组件之一,在实际场景中的非常实用,在项目中使用MVC架构时,可以简化数据处理。在性能…

OpenShift 4 - 云原生备份容灾 - Velero 和 OADP 基础篇

《OpenShift 4.x HOL教程汇总》 说明: 本文主要说明能够云原生备份容灾的开源项目 Velero 及其红帽扩展项目 OADP 的概念和架构篇。操作篇见《OpenShift 4 - 使用 OADP 对容器应用进行备份和恢复(附视频) 》 Velero 和 OADP 包含的功能和模…

028.爬虫浏览器-抓取shadowRoot下的内容

一、什么是Shadow DOM Shadow DOM是一种在web开发中用于封装HTML标记、样式和行为的技术,以避免组件间的样式和脚本冲突。它允许开发者将网页的一部分隐藏在一个独立的作用域内,从而实现更加模块化和可维护的代码结构 二、js操作Shadow DOM // 获取宿…

邮件营销文案设计:打造个性化内容的步骤?

邮件营销文案写作技巧与方法?外贸邮件营销怎么撰写? 一个成功的邮件营销文案设计不仅能吸引客户的注意力,还能有效提升转化率。MailBing将详细探讨如何通过一系列步骤,打造出既个性化又高效的邮件营销文案设计。 邮件营销文案设…

监控易监测对象及指标之:Microsoft Message Queue(MSMQ)监控

监控易是一款强大的监控工具,能够实时监控各类IT设施和应用程序的性能指标。对于Microsoft Message Queue(简称MSMQ)的监控,监控易提供了详尽的指标,以确保企业能够准确掌握消息队列的运行状况。 在MSMQ的监控中&#…

24、darkhole_2

难度 高(个人感觉属于中) 目标 root权限2个flag 基于VMware启动(我这里启动只能选择我wifi的那个网卡才能获取到ip地址,nat的没获取到不知道为什么。) kali 192.168.1.122 靶机 192.168.1.170 信息收集 端口扫描 只开…

【python】OpenCV—Fun Mirrors

文章目录 1、准备工作2、原理介绍3、代码实现4、效果展示5、参考 1、准备工作 pip install vacm2、原理介绍 在OpenCV中,VCAM 库是一个用于简化创建三维曲面、定义虚拟摄像机、设置参数以及进行投影任务的工具。它特别适用于实现如哈哈镜等图像变形效果。 一、VC…

用你的手机/电脑运行文生图方案

随着ChatGPT和Stable Diffusion的发布,最近一两年,生成式AI已经火爆全球,已然成为移动互联网后一个重要的“风口”。就图片/视频生成领域来说,Stable Diffusion模型发挥着极其重要的作用。由于Stable Diffusion模型参数量是10亿参…

读者写者问题与读写锁

读者写者问题 读者写者 vs 生产消费 重点是有什么区别 读者写者问题如何理解 重点理解读者和写者如何完成同步 下面是一段伪代码:公共部分 uint32_t reader_count 0; lock_t count_lock; lock_t writer_lock; Reader // 加锁 lock(count_lock); if(reader_c…

Java | Leetcode Java题解之第492题构造矩形

题目: 题解: class Solution {public int[] constructRectangle(int area) {int w (int) Math.sqrt(area);while (area % w ! 0) {--w;}return new int[]{area / w, w};} }

5G物联网主机引领企业数字化转型

在当今这个信息化高度发展的时代,企业的竞争力很大程度上取决于其能否快速适应市场变化并高效地进行内部管理。郑州龙兴物联科技有限公司凭借其先进的5G物联网技术,推出了为企业量身定制的5G物联网主机,该设备充分利用其多协议、多接口的特点…

ESP32-C3 入门笔记04:gpio_key 按键 (ESP-IDF + VSCode)

1.GPIO简介 ESP32-C3是QFN32封装,GPIO引脚一共有22个,从GPIO0到GPIO21。 理论上,所有的IO都可以复用为任何外设功能,但有些引脚用作连接芯片内部FLASH或者外部FLASH功能时,官方不建议用作其它用途。 通过开发板的原…

【Vue】Vue3.0 (十二)、watchEffect 和watch的区别及使用

上篇文章: 【Vue】Vue3.0 (十二)、watch对ref定义的基本类型、对象类型;reactive定义的对象类型的监视使用 🏡作者主页:点击! 🤖Vue专栏:点击! ⏰️创作时间&…

数据仓库基础概念

数据仓库 概念 数据仓库(Data Warehouse, DW)是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合。它是为满足企业决策分析需求而设计的。 面向主题:数据仓库围绕特定的主题组织数据,例如“销售”或“人力资源”&am…