YOLOv9改进策略 | SPPF篇 | 利用RT-DETR的AIFI模块替换SPPFELAN助力小目标检测涨点

 一、本文介绍

本文给大家带来是用最新的RT-DETR模型中的AIFI模块来替换YOLOv9中的SPPFELAN。RT-DETR号称是打败YOLO的检测模型,其作为一种基于Transformer的检测方法,相较于传统的基于卷积的检测方法,提供了更为全面和深入的特征理解,将RT-DETR中的一些先进模块融入到YOLOv9往往能够达到一些特殊的效果。同时欢迎大家订阅本专栏,本专栏每周更新3-5篇最新机制,更有包含我所有改进的文件和交流群提供给大家。同时本专栏目前改进基于yolov9.yaml文件,后期如果官方放出轻量化版本,专栏内所有改进也会同步更新,请大家放心,本文提供三种使用方式,下面图片为yaml1对应的结构图。

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏  

目录

 一、本文介绍

二、RT-DETR的AIFI框架原理

2.1 AIFI的基本原理

三、AIFI的完整代码

四、手把手教你添加AIFI模块

4.1 细节修改教程

4.1.1 修改一

​4.1.2 修改二

4.1.3 修改三 

4.1.4 修改四

4.2 AIFI的yaml文件

4.3 AIFI运行成功截图

五、本文总结 


二、RT-DETR的AIFI框架原理

​​​​

论文地址:RT-DETR论文地址

代码地址:RT-DETR官方下载地址

​​​​


2.1 AIFI的基本原理

RT-DETR模型中的AIFI(基于注意力的内部尺度特征交互)模块是一个关键组件,它与CNN基于的跨尺度特征融合模块(CCFM)一起构成了模型的编码器部分。AIFI的主要思想如下->

  1. 基于注意力的特征处理:AIFI模块利用自我注意力机制来处理图像中的高级特征。自我注意力是一种机制,它允许模型在处理特定部分的数据时,同时考虑到数据的其他相关部分。这种方法特别适用于处理具有丰富语义信息的高级图像特征。

  2. 选择性特征交互:AIFI模块专注于在S5级别(即高级特征层)上进行内部尺度交互。这是基于认识到高级特征层包含更丰富的语义概念,能够更有效地捕捉图像中的概念实体间的联系。与此同时,避免在低级特征层进行相同的交互,因为低级特征缺乏必要的语义深度,且可能导致数据处理上的重复和混淆。

总结:AIFI模块的主要思想其实就是通过自我注意力机制专注于处理高级图像特征,从而提高模型在对象检测和识别方面的性能,同时减少不必要的计算消耗。

​​

AIFI模块的主要作用和特点如下: 

1. 减少计算冗余:AIFI模块进一步减少了基于变体D的计算冗余,这个变体仅在S5级别上执行内部尺度交互。

2. 高级特征的自我注意力操作:AIFI模块通过对具有丰富语义概念的高级特征应用自我注意力操作,捕捉图像中概念实体之间的联系。这种处理有助于随后的模块更有效地检测和识别图像中的对象。

3. 避免低级特征的内部尺度交互:由于低级特征缺乏语义概念,以及存在与高级特征交互时的重复和混淆风险,AIFI模块不对低级特征进行内部尺度交互。

4. 专注于S5级别:为了验证上述观点,AIFI模块仅在S5级别上进行内部尺度交互,这表明模块主要关注于处理高级特征。

没啥好讲的这个AIFI具体的内容大家可以看我的另一篇博客->

RT-DETR回顾:RT-DETR论文阅读笔记(包括YOLO版本训练和官方版本训练)


三、AIFI的完整代码

我们将在“ultralytics/nn/modules”目录下面创建一个文件将其复制进去,使用方法在后面第四章会讲。

import torch
import torch.nn as nn

__all__ = ['AIFI']

class TransformerEncoderLayer(nn.Module):
    """Defines a single layer of the transformer encoder."""

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
        """Initialize the TransformerEncoderLayer with specified parameters."""
        super().__init__()

        self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.fc1 = nn.Linear(c1, cm)
        self.fc2 = nn.Linear(cm, c1)

        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.act = act
        self.normalize_before = normalize_before

    @staticmethod
    def with_pos_embed(tensor, pos=None):
        """Add position embeddings to the tensor if provided."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Performs forward pass with post-normalization."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
        src = src + self.dropout2(src2)
        return self.norm2(src)

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Performs forward pass with pre-normalization."""
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
        return src + self.dropout2(src2)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Forward propagates the input through the encoder module."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class AIFI(TransformerEncoderLayer):
    """Defines the AIFI transformer layer."""

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
        """Initialize the AIFI instance with specified parameters."""
        super().__init__(c1, cm, num_heads, dropout, act, normalize_before)

    def forward(self, x):
        """Forward pass for the AIFI transformer layer."""
        c, h, w = x.shape[1:]
        pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
        # Flatten [B, C, H, W] to [B, HxW, C]
        x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
        return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()

    @staticmethod
    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
        """Builds 2D sine-cosine position embedding."""
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1.0 / (temperature ** omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]


四、手把手教你添加AIFI模块

4.1 细节修改教程

4.1.1 修改一

我们找到如下的目录'yolov9-main/models'在这个目录下创建一整个文件目录(注意是目录,因为我这个专栏会出很多的更新,这里用一种一劳永逸的方法)文件目录起名modules,然后在下面新建一个文件,将我们的代码复制粘贴进去。


​4.1.2 修改二

然后新建一个__init__.py文件,然后我们在里面添加一行代码(均用红框标记出来了)。注意标记一个'.'其作用是标记当前目录。

​​

​​


4.1.3 修改三 

然后我们找到如下文件''models/yolo.py''在开头的地方导入我们的模块按照如下修改->

(如果你看了我多个改进机制此处只需要添加一个即可,无需重复添加。)

​​​​


4.1.4 修改四

然后我们找到parse_model方法,按照如下修改->

        elif m in {AIFI}:
            c2 = ch[f]
            args = [c2, *args]

到此就修改完成了,复制下面的ymal文件即可运行。


4.2 AIFI的yaml文件

# YOLOv9

# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4
   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3
   # conv down
   [-1, 1, Conv, [256, 3, 2]],  # 4-P3/8
   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5
   # conv down
   [-1, 1, Conv, [512, 3, 2]],  # 6-P4/16
   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7
   # conv down
   [-1, 1, Conv, [512, 3, 2]],  # 8-P5/32
   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, AIFI, []],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # conv-down merge
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # conv-down merge
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # conv down fuse
   [-1, 1, Conv, [256, 3, 2]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # conv down fuse
   [-1, 1, Conv, [512, 3, 2]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # conv down fuse
   [-1, 1, Conv, [512, 3, 2]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37

   # detect
   [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]

4.3 AIFI运行成功截图

附上我的运行记录确保我的教程是可用的。 


五、本文总结 

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv9改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562620.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何30天快速掌握键盘盲打

失业后在家备考公务员,发现了自己不正确的打字方式,决定每天抽出一点时间练习打字。在抖音上看到一些高手的飞速盲打键盘后,觉得使用正确的指法打字是很必要的。 练习打字,掌握正确的键盘指法十分关键。 练习打字的第一步是找到…

基本的SELECT语句及DESC显示表结构

1. SELECT ... 例 : 2. SELECT ... FROM ... (1). SELECT ... : 标识选择哪些列. (2). FROM ... : 标识从哪个表中选取. (3). *通配符 : 选择表中全部列. 例 : 3.列的别名 (1). 空一格. (2). 在列和别名间加入关键字AS. (3). 别名可以使用双引号,以便于在…

【Datawhale LLM学习笔记】一、什么是大型语言模型(LLM)

文章目录 1. 什么是大模型2. 检索增强生成 RAG一、什么是 RAG二、RAG 的工作流程 3. langChain介绍一、什么是 LangChain二、LangChain 的核心组件 4. 开发 LLM 应用的整体流程一、何为大模型开发二、大模型开发的一般流程三、搭建 LLM 项目的流程简析(以知识库助手…

从迷宫问题理解dfs

文章目录 迷宫问题打印路径1思路定义一个结构体要保存所走的路径,就需要使用到栈遍历所有的可能性核心代码 部分函数递归图源代码 迷宫问题返回最短路径这里的思想同上面类似。源代码 迷宫问题打印路径1 定义一个二维数组 N*M ,如 5 5 数组下所示&…

掌握Node Version Manager(nvm):跨平台Node.js版本管理

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

整合阿里云短信服务

1. 申请服务 如图&#xff1a; 申请签名管理和模板管理 2. 进入快速学习和调试 2.1 进入快速学习 2.2 获取依赖和代码实现 3. 具体实现案例 3.1 添加依赖 <dependency><groupId>com.aliyun</groupId><artifactId>dysmsapi20170525</artifact…

9.MMD 基础内容总结及制作成品流程

前期准备 1. 导入场景和模型 在左上角菜单栏&#xff0c;显示里将编辑模型时保持相机和光照勾选上&#xff0c;有助于后期调色 将抗锯齿和各向异性过滤勾掉&#xff0c;可以节省资源&#xff0c;避免bug 在分辨率设定窗口&#xff0c;可以调整分辨率 3840x2160 4k分辨率 1…

Umi.js:登录之后需要手动刷新权限菜单才能渲染

在使用Umi.js开发后台管理页面时&#xff0c;用户登录之后&#xff0c;总是需要手动刷新一次页面&#xff0c;才能够拿到全局状态/权限信息。 问题描述 结合使用umi/plugin-layout和umi/plugin-access&#xff0c;登录进入页面&#xff0c;配置的权限菜单未渲染&#xff0c;需…

【Redis 神秘大陆】005 常见性能优化方式

五、Redis 性能优化 5.1 系统层面的优化 https://github.com/sohutv/cachecloud/blob/main/redis-ecs/script/cachecloud-init.sh initConfig() {# 支持虚拟内存分配sysctl vm.overcommit_memory1# 最大排队连接数设置为 511&#xff0c;一般默认是 128echo 511 >/proc/sy…

openobserve-filebeat配置

优势 rustgolang开发的日志工具组合&#xff0c;自带日志数据存储&#xff0c;简化部署和管理。日志数据可配置保留x天。从日志文件中采集&#xff0c;做到非侵入式日志集中管理。 可从日志内容中提取信息进行报警等二次开发。 下载 openobserve-v0.10.1-windows-amd64 fil…

【题解】NC40链表相加(二)(链表 + 高精度加法)

https://www.nowcoder.com/practice/c56f6c70fb3f4849bc56e33ff2a50b6b?tpId196&tqId37147&ru/exam/oj class Solution {public:// 逆序链表ListNode* reverse(ListNode* head) {// 创建一个新节点作为逆序后链表的头节点ListNode* newHead new ListNode(0);// 当前…

使用51单片机控制T0和T1分别间隔1秒2秒亮灭逻辑

#include <reg51.h>sbit LED1 P1^0; // 设置LED1灯的接口 sbit LED2 P1^1; // 设置LED2灯的接口unsigned int cnt1 0; // 设置LED1灯的定时器溢出次数 unsigned int cnt2 0; // 设置LED2灯的定时器溢出次数// 定时器T0 void Init_Timer0() {TMOD | 0x01;; // 定时器…

代码学习记录49---单调栈

随想录日记part49 t i m e &#xff1a; time&#xff1a; time&#xff1a; 2024.04.20 主要内容&#xff1a;今天开始要学习单调栈的相关知识了&#xff0c;今天的内容主要涉及&#xff1a;柱状图中最大的矩形 84.柱状图中最大的矩形 Topic184.柱状图中最大的矩形 题目&…

Sharding-JDBC笔记1

Sharding-JDBC笔记1 1.分库分表1.1 垂直分库1.2 垂直分表1.3 水平分库1.4 水平分表 2.存在问题2.1 事务一致性2.2 跨节点关联查询2.3 跨节点分页、排序函数2.4 主键避重2.5 公共表 1.分库分表 分库分表就是为了解决由于数据量过大而导致数据库性能降低的问题&#xff0c;将原来…

操作符不存在:sde.st_geometry ^ !sde.st_geometry建议 SQL函 数st_intersects在内联inlining期间

操作符不存在&#xff1a;sde.st_geometry ^ &#xff01;sde.st_geometry建议 SQL函 数st_intersects在内联inlining期间 问题&#xff1a;最近在使用SQL图形处理函数处理图形时&#xff0c;莫名奇妙报如下错误&#xff0c;甚是费解 于是开始四处"寻医问药" 1、nav…

Spark集群的搭建

1.1搭建Spark集群 Spark集群环境可分为单机版环境、单机伪分布式环境和完全分布式环境。本节任务是学习如何搭建不同模式的Spark集群&#xff0c;并查看Spark的服务监控。读者可从官网下载Spark安装包&#xff0c;本文使用的是spark-2.0.0-bin-hadoop2.7.gz。 1.1.1搭建单机版…

“开挂”的WAAP全站防护是云海驰骋的必备

何为攻击&#xff1f; 网络和应用是攻击的两大阵地 网络攻击像僵尸&#xff1a;简单、粗暴、让人猝不及防 显著特征&#xff1a;流量大&#xff0c;并发高 应用攻击像幽灵&#xff1a;复杂、神秘、让人摸不着头脑 显著特征&#xff1a;流量小、隐蔽强 攻击不像“馅饼”&…

OpenHarmony实战开发-组件复用实践。

若开发者的应用中存在以下场景&#xff0c;并成为UI线程的帧率瓶颈&#xff0c;应该考虑使用组件复用机制提升应用性能&#xff1a; 滑动场景下对同一类自定义组件的实例进行频繁的创建与销毁。反复切换条件渲染的控制分支&#xff0c;且控制分支中的组件子树结构比较复杂。 …

SpringBoot3 + Vue3 + Element-Plus + TS 实现动态二级菜单级联选择器

SpringBoot3 Vue3 Element-Plus TS 实现动态二级菜单选择器 1、效果展示1.1 点击效果1.2 选择效果1.3 返回值1.4 模拟后端返回数据 2、前端代码2.1 UnusedList.vue2.2 goodsType.ts2.3 http.ts 3、后端代码3.1 GoodsCategoryController.java3.2 GoodsCategoryService.java3.…

读后感-有效沟通

司内的学习已开展8期&#xff0c;内容主要以如何沟通为主&#xff0c;这里将根据个人的学习体会&#xff0c;对所学内容进行梳理与整合&#xff0c;以期更好地吸收和应用所学知识。 沟通是一门技术&#xff0c;其轨迹可循。自来熟的态度&#xff0c;一上来便滔滔不绝地发表言论…