RT-DETR融合[ECCV2024]自调制特征聚合SMFA模块及相关改进思路


RT-DETR使用教程: RT-DETR使用教程

RT-DETR改进汇总贴:RT-DETR更新汇总贴


《SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution》

一、 模块介绍

        论文链接:https://link.springer.com/chapter/10.1007/978-3-031-72973-7_21

        代码链接:https://github.com/Zheng-MJ/SMFANet?tab=readme-ov-file

论文速览:

        基于 Transformer 的修复方法取得了显着的性能,因为 Transformer 的自注意力 (SA) 可以探索非局部信息以获得更好的高分辨率图像重建。然而,关键的点积 SA 需要大量的计算资源。此外,SA 机制的低通特性限制了其捕获局部细节的能力,从而导致平滑的重建结果。为了解决这些问题,作者提出了一个自调制特征聚合 (SMFA) 模块,以协同利用局部和非局部特征交互来实现更准确的重建。具体来说,SMFA 模块采用高效的自我注意近似 (EASA) 分支来对非局部信息进行建模,并使用局部细节估计 (LDE) 分支来捕获局部细节。此外,作者进一步引入了基于部分卷积的前馈网络 (PCFN) 来改进从 SMFA 派生的代表性特征。大量实验表明,所提出的 SMFANet 系列在公共基准数据集上实现了更好的重建性能和计算效率之间的权衡。特别是,与×4 SwinIR-light,SMFANet+ 在五个公共测试集中平均实现了 0.14 dB 的性能提升,并且×运行速度提高 10 倍,模型复杂度仅为 43% 左右(例如 FLOPs)。

总结:一种基于自调制特征聚合模块(SMFA)的高分辨率图像重建方法,实测与其他模块融合有提升。


二、 加入到RT-DETR中

2.1 创建脚本文件

        首先在ultralytics->nn路径下创建blocks.py脚本,用于存放模块代码。

2.2 复制代码        

        复制代码粘到刚刚创建的blocks.py脚本中,如下图所示:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class DMlp(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.conv_0 = nn.Sequential(
            nn.Conv2d(dim,hidden_dim,3,1,1,groups=dim),
            nn.Conv2d(hidden_dim,hidden_dim,1,1,0)
        )
        self.act =nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
 
    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x
 
 
class SMFA(nn.Module):
    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim,dim*2,1,1,0)
        self.linear_1 = nn.Conv2d(dim,dim,1,1,0)
        self.linear_2 = nn.Conv2d(dim,dim,1,1,0)
 
        self.lde = DMlp(dim,2)
 
        self.dw_conv = nn.Conv2d(dim,dim,3,1,1,groups=dim)
 
        self.gelu = nn.GELU()
        self.down_scale = 8
 
        self.alpha = nn.Parameter(torch.ones((1,dim,1,1)))
        self.belt = nn.Parameter(torch.zeros((1,dim,1,1)))
 
    def forward(self, f):
        _,_,h,w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        x_v = torch.var(x, dim=(-2,-1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h,w), mode='nearest')
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)

2.3 更改task.py文件 

       打开ultralytics->nn->modules->task.py,在脚本空白处导入函数。

from ultralytics.nn.blocks import *

        之后找到模型解析函数parse_model(约在tasks.py脚本中940行左右位置,可能因代码版本不同变动),在该函数的最后一个else分支上面增加相关解析代码。

        elif m is SMFA:
            c2 = ch[f]
            args = [ch[f]]

2.4 更改yaml文件 

yam文件解读:YOLO系列 “.yaml“文件解读_yolo yaml文件-CSDN博客

       打开更改ultralytics/cfg/models/rt-detr路径下的rtdetr-l.yaml文件,替换原有模块。(放在该位置仅能插入该模块,具体效果未知。博主精力有限,仅完成与其他模块二次创新融合的测试,结构图见文末,代码见群文件更新。)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
  - [-1, 2, SMFA, []] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]] # cat Y4
  - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]] # cat Y5
  - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1

  - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


 2.5 修改train.py文件

       创建Train_RT脚本用于训练。

from ultralytics.models import RTDETR
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

if __name__ == '__main__':
    model = RTDETR(model='ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    # model.load('yolov8n.pt')
    model.train(data='./data.yaml', epochs=2, batch=1, device='0', imgsz=640, workers=2, cache=False,
                amp=True, mosaic=False, project='runs/train', name='exp')

         在train.py脚本中填入修改好的yaml路径,运行即可训。

三、相关改进思路(2024/11/16日群文件)

        根据SMFA模块特性,可如图加入到HGBlock、RepNCSPELAN4、RepC3等模块中,代码见群文件,结构如图。自研模块与该模块融合代码及yaml文件见群文件。

 ⭐另外,融合上百种深度学习改进模块的YOLO项目仅79.9(含百种改进的v9),RTDETR79.9,含高性能自研模型,更易发论文,代码每周更新,欢迎点击下方小卡片加我了解。⭐

⭐⭐平均每个文章对应4-6个二创及自研融合模块⭐⭐


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/915012.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

postman变量和脚本功能介绍

1、基本概念——global、collection、environment 在postman中,为了更好的管理各类变量、测试环境以及脚本等,创建了一些概念,包括:globals、collection、environment。其实在postman中,最上层还有一个Workspaces的概…

计算机网络常见面试题(一):TCP/IP五层模型、TCP三次握手、四次挥手,TCP传输可靠性保障、ARQ协议

文章目录 一、TCP/IP五层模型(重要)二、应用层常见的协议三、TCP与UDP3.1 TCP、UDP的区别(重要)3.2 运行于TCP、UDP上的协议3.3 TCP的三次握手、四次挥手3.3.1 TCP的三次握手3.3.2 TCP的四次挥手3.3.3 随机生成序列号的原因 四、T…

约束(MYSQL)

not null(非空) unique(唯一) default(默认约束,规定值) 主键约束primary key(非空且唯一) auto_increment(自增类型) 复合主键 check&#xff08…

Cent OS-7的Apache服务配置

WWW是什么? WWW(World Wide Web,万维网)是一个全球性的信息空间,其中的文档和其他资源通过URL标识,并通过HTTP或其他协议访问。万维网是互联网的一个重要组成部分,但它并不是互联网的全部。互联…

【C++】类与对象的基础概念

目录: 一、inline 二、类与对象基础 (一)类的定义 (二)访问限定符 (三)类域 (四)实例化概念 正文 一、inline 在C语言的学习过程中,大家肯定了解过宏这个概…

matlab实现主成分分析方法图像压缩和传输重建

原创 风一样的航哥 航哥小站 2024年11月12日 15:23 江苏 为了研究图像的渐进式传输技术,前文提到过小波变换,但是发现小波变换非常适合传输缩略图,实现渐进式传输每次传输的数据量不一样,这是因为每次变换之后低频成分大约是上一…

python成长技能之网络编程

文章目录 一、初识Socket1.1 什么是 Socket?1.2 socket的基本操作1.3 socket常用函数 二、基于UDP实现客户端与服务端通信三、基于TCP实现客户端与服务端通信四、使用requests模块发送http请求 一、初识Socket 1.1 什么是 Socket? Socket又称"套接字",…

ROM修改进阶教程------安卓14 安卓15去除app签名验证的几种操作步骤 详细图文解析

在安卓14 安卓15的固件中。如果修改了系统级别的app。那么就会触发安卓14 15的应用签名验证。要么会导致修改的固件会进不去系统,或者进入系统有bug。博文将从几方面来解析去除安卓14 15应用签名验证的几种方法。 💝💝💝通过博文了解: 1💝💝💝-----安卓14去除…

[Docker#6] 镜像 | 常用命令 | 迁移镜像 | 压缩与共享

目录 Docker 镜像是什么 生活案例 为什么需要镜像 镜像命令详解 实验 1.一些操作 1. 遍历查看镜像 2. 查看镜像仓库在本地的存储信息 进入镜像存储目录 查看 repositories.json 文件 3. 镜像过滤 4. 下载镜像时的分层 实战一:离线迁移镜像 实战二&…

「QT」几何数据类 之 QVector3d 三维向量类

✨博客主页何曾参静谧的博客📌文章专栏「QT」QT5程序设计📚全部专栏「VS」Visual Studio「C/C」C/C程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「PK」Parasolid…

人工智能(AI)对于电商行业的变革和意义

![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/402a907e12694df5a34f8f266385f3d2.png#pic_center> 🎓作者简介:全栈领域优质创作者 🌐个人主页:百锦再新空间代码工作室 📞工作室:新空间代…

物联网设备研究——分配推理负载的联合学习方法

概述 物联网(IoT)的最新发展导致人工智能模型被嵌入到传感器和智能手机等终端设备中。这些模型是根据每个设备的存储容量和计算能力定制的,但重点是在终端侧进行本地推理,以降低通信成本和延迟。 然而,与部署在边缘服…

CentOS Stream 9设置静态IP

CentOS Stream 9设置静态IP CentOS Stream 9作为CentOS Stream发行版的下一个主要版本,已经发布有一段时间,但与目前广泛使用的CentOS7有较大区别。安装试用Stream 9的过程中,就发现设置静态IP的方式和CentOS7/8差别较大,在此记录…

【嵌入式】ESP32开发(一)ESP-IDF概述

文章目录 1 前言2 IDF环境配置3 在VS Code中使用IDF3.1 使用ESP-IDF例程3.2 底部按钮的作用【重要!】3.3 高级用法4 ESP-IDF框架分析5 从零开始创建一个项目5.1 组件(component)6 主要参考资料7 遇到的一些问题与解决办法8 对于ESP-IDF开发的一些感受1 前言 对于ESP32的开发…

基于Multisim水箱水位控制系统仿真电路(含仿真和报告)

【全套资料.zip】水箱水位控制系统仿真电路Multisim仿真设计数字电子技术 文章目录 功能一、Multisim仿真源文件二、原理文档报告资料下载【Multisim仿真报告讲解视频.zip】 功能 1.在水箱内的不同高度安装3根金属棒,以感知水位变化情况, 液位分1&…

解读Nature:Larger and more instructable language models become less reliable

目录 Larger and more instructable language models become less reliable 核心描述 核心原理 创新点 举例说明 大模型训练,微调建议 Larger and more instructable language models become less reliable 这篇论文的核心在于对大型语言模型(LLMs)的可靠性进行了深入…

zabbix监控端界面时间与服务器时间不对应

1. 修改系统时间 # tzselect Please select a continent, ocean, "coord", or "TZ".1) Africa2) Americas3) Antarctica4) Asia5) Atlantic Ocean6) Australia7) Europe8) Indian Ocean9) Pacific Ocean 10) coord - I want to use geographical coordina…

ubuntu20.04安装FLIR灰点相机BFS-PGE-16S2C-CS的ROS驱动

一、Spinnaker 安装 1.1Spinnaker 下载 下载地址为: https://www.teledynevisionsolutions.com/support/support-center/software-firmware-downloads/iis/spinnaker-sdk-download/spinnaker-sdk–download-files/?pnSpinnakerSDK&vnSpinnakerSDK 在上述地址中…

Windows配置JDK

1、解压 下载以后解压,放在一个没有中文路径和没有空格的目录,如下图: 2、配置Java环境 1)、点击左下角windows图标,输入huanjing(或者path),打开环境变量配置 如图: …

Unity教程(十八)战斗系统 攻击逻辑

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程(零)Unity和VS的使用相关内容 Unity教程(一)开始学习状态机 Unity教程(二)角色移动的实现 Unity教程(三)角色跳跃的实现 Unity教程&…