RTDETR论文快速理解和代码快速实现(训练与预测)

文章目录

  • 前言
  • 一、摘要
  • 二、论文目的
  • 三、论文贡献
  • 四、模型结构
    • 1、模型整体结构
    • 2、backbone结构
    • 3、neck结构
    • 4、混合编码器(neck)
  • 五、RTDERT模型训练(data-->train)
    • 1、环境安装
    • 2、训练
      • 1、数据准备
      • 2、数据yaml文件
      • 3、训练代码
      • 4、训练运行结果
    • 3、推理
      • 1、推理代码
      • 2、推理运行结果
  • 总结


前言

最近,我们想比较基于DETR的transformer模型与基于CNN的yolo模型效果,而百度RT-DETR模型声称“在实时目标检测领域打败YOLO”。从数据的角度来看,RT-DETR似乎确实在某些方面超越了YOLO。我选择RT-DETR模型与YOLO模型比较。本篇文章将介绍RT-DETR模型原理–>环境安装–>数据准备–>训练实现–>预测实现。


一、摘要

近期,端到端基于transformer检测器DETRs已有显著性能。然而,DETR的计算成本限制其实际应用,也阻止其无后处理的优势(如:NMS)。在这篇论文,我们首次分析NMS对目标检测的速度与精确率影响,并构建了端到端的speed基准。为了解决这些问题,我们提出RT-DETR模型,据我们所知,这是第一个实时端到端检测模型。特别的,我们设计一个高效混合编码器加工多尺度特征与特征交互和融合,并提出IOU感知查询,通过像解码器提供更高初始目标来提示性能。除此之外,我们提出的检测模型,可使用解码层without retraining灵活调整推理速度,这样可适应多样的实时场景。我们模型RT-DETR-L在coco2017实现了53%的AP和114FPS on T4 gpu,而RT-DETR-X实现54.8%AP和74FPS,超过同规模模型的yolo。此外,我们的 RT-DETR-R50 实现了53.1%的AP和108FPS的速度,准确性优于 DINO-Deformable-DETR-R50 约 2.2% AP,帧率约为其21倍。
在这里插入图片描述

二、论文目的

实时目标检测是一个重要的研究领域,而DETR的高计算成本问题尚未得到有效解决,这限制了DETR的实际应用,并导致无法充分利用其优势(后处理)。换句话说,RTDETR解决问题是

为了实现上述目标,我们重新思考了DETR,并对其关键组件进行了详细分析和实验,以减少不必要的计算冗余。具体而言,我们发现虽然引入多尺度特征有助于加快训练收敛和提高性能[43],但它也导致输入编码器的序列长度显著增加。因此,由于高计算成本,Transformer编码器成为模型的计算瓶颈。为了实现实时目标检测,我们设计了一个高效的混合编码器来替代原始的Transformer编码器。通过解耦多尺度特征的内尺度交互和跨尺度融合,编码器能够高效处理具有不同尺度的特征。此外,先前的研究[35, 20]表明,解码器的对象查询初始化方案对于检测性能至关重要。为了进一步提高性能,我们提出了基于IoU的查询选择方法,通过在训练过程中提供IoU约束,为解码器提供更高质量的初始对象查询。此外,我们提出的检测器支持通过使用不同的解码器层对推理速度进行灵活调节,无需重新训练,这得益于DETR架构中解码器的设计,并有助于实时检测器的实际应用。

三、论文贡献

本论文的主要贡献总结如下:

1、我们提出了第一个实时端到端目标检测器,不仅在准确性和速度方面优于当前最先进的实时检测器,而且不需要后处理,因此推理速度不会延迟并保持稳定;

2、我们详细分析了NMS对实时检测器的影响,并从后处理的角度得出了关于基于CNN的实时检测器的结论;

3、我们提出的IoU-aware查询选择在模型中展现出卓越的性能改进,为改进目标查询的初始化方案提供了新的思路;

4、我们的工作为端到端检测器的实时实现提供了可行的解决方案,所提出的检测器可以通过使用不同的解码器层进行灵活调整模型大小和推理速度,无需重新训练。

四、模型结构

1、模型整体结构

RT-DETR模型由主干网络(backbone)、混合编码器(hybrid encoder)和带有辅助预测头的Transformer解码器组成。模型的整体架构概述如下图所示。具体来说,我们利用主干网络最后三个阶段的输出特征{S3,S4,S5}作为编码器的输入。混合编码器通过内部尺度交互和跨尺度融合,将多尺度特征转换为图像特征序列。随后,采用IoU感知的查询选择机制,从编码器的输出序列中选择固定数量的图像特征作为解码器的初始对象查询。最后,带有辅助预测头的解码器迭代优化对象查询,生成边界框和置信度分数。
在这里插入图片描述
RT-DETR模型架构图显示了主干网络的最后三个阶段{S3,S4,S5}作为编码器的输入。高效的混合编码器通过内部尺度特征交互(AIFI)和跨尺度特征融合模块(CCFM)将多尺度特征转化为图像特征序列。采用IoU感知的查询选择方法,选择固定数量的图像特征作为解码器的初始对象查询。最后,解码器通过辅助预测头迭代优化对象查询,生成边界框和置信度分数
本文最重要是设计AIFI与CCFM结构

2、backbone结构

与YOLO相似,RT-DETR最终会输出三种不同尺寸的特征图,它们相对于输入图像的分辨率下采样倍数分别是 8 倍、16 倍和 32 倍。这与主流的YOLO算法相似。除此之外,在主干结构的其他方面,RT-DETR并没有特别的地方。

3、neck结构

对于颈部网络部分,RT-DETR 采用了一层 Transformer 的 Encoder ,文中这个颈部网络叫做 Efficient Hybrid Encoder,其包括两部分:Attention-based Intra-scale Feature Interaction (AIFI) 和 CNN-based Cross-scale Feature-fusion Module (CCFM),这个AIFI模块有一点值得注意,这个模块只对S5特征图进行处理

对于AIFI模块(如下左图),它首先将二维的 S5 特征拉成向量,然后交给AIFI模块处理,其数学过程就是多头自注意力与 FFN,随后,再将输出Reshape回二维,记作 F5,以便去完成后续的所谓的“跨尺度特征融合”。

对于CCFM模块(如下右图),以YOLO的角度看这个结构的话,这个CCFM模块就是一个FPN/PAN结构。关于CCFM模块中的Fusion文中也给了详细的结构图,是由 2 个1×1 卷积和 N 个 RepBlock 构成的,这里之所以写成 N ,我觉得是因为 RT-DETR 可以进行缩放处理,通过调整 CCFM中RepBlock 的数量和 Encoder 的编码维度分别控制 Hybrid Encoder 的深度和宽度,同时对 backbone 进行相应的调整即可实现检测器的缩放。
在这里插入图片描述

4、混合编码器(neck)

在3已经介绍neck最终结构,而设计neck结构时,作者为了实时性与减少冗余,设计了一些列结构,其原因是注意力机制的改进减少了计算开销,却输入序列的大幅增加仍导致编码器成为计算瓶颈,不太好实时场景中使用。作者分析了多尺度变换器编码器中存在的计算冗余,设计了一系列变种来证明同时进行内部尺度和跨尺度特征交互在计算上效率低下。

在这里插入图片描述
A → B:变体B插入了一个单尺度的Transformer编码器,它使用了一个Transformer块的层。每个尺度的特征共享编码器,进行内部尺度的特征交互,然后将输出的多尺度特征进行连接。
B → C:变体C在B的基础上引入了基于尺度的特征融合,将连接的多尺度特征输入编码器进行特征交互。
C → D:变体D将多尺度特征的内部尺度交互和跨尺度融合解耦。首先,使用单尺度的Transformer编码器进行内部尺度交互,然后利用类似于PANet [21]的结构进行跨尺度融合。
D → E:变体E在D的基础上进一步优化多尺度特征的内部尺度交互和跨尺度融合,采用了我们设计的高效混合编码器。

RT-DETR认为S5特征相对于较浅的S3和S4特征来说,具有更深、更高级和更丰富的语义特征。这些语义特征对于Transformer模型更加重要,因为它们对于区分不同物体的特征非常有用,而浅层特征由于缺乏良好的语义特征并不是很丰富。

五、RTDERT模型训练(data–>train)

我将在此部分介绍环境安装、数据准备格式、训练相关配置与代码、预测相关内容与代码,我也将数据、官网提供权重放在这里,有需要自行下载。

1、环境安装

使用命令安装,如下:

conda create -n yolov8 python=3.8
conda activate yolov8
git clone https://github.com/ultralytics/ultralytics.git
cd ultralytics
pip install -r requirement.txt
pip install ultralytics

使用上面命令安装可能会报错Could not load library libcudnn_cnn_train.so.8 ,解决方法点击这里,建议先安装较低点的torch版本。

2、训练

我们使用yolov8集成的RTDETR模型,训练与预测文件大致如下图。
在这里插入图片描述

1、数据准备

实际为yolo数据格式,可按照yolov5或v8格式准备即可。

2、数据yaml文件

其数据yaml文件与yolo差不多,但少了nc且将names变成字典的映射,coco8.yaml内容如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
# Example usage: yolo train data=coco8.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── coco8  ← downloads here (1 MB)


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: C:/Users/Administrator/Desktop/rtdetr/coco128  # dataset root dir
train: images/train  # train images (relative to 'path') 4 images
val: images/train  # val images (relative to 'path') 4 images
test:  # test images (optional)

# Classes
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
  6: train
  7: truck
  8: boat
  9: traffic light
  10: fire hydrant
  11: stop sign
  12: parking meter
  13: bench
  14: bird
  15: cat
  16: dog
  17: horse
  18: sheep
  19: cow
  20: elephant
  21: bear
  22: zebra
  23: giraffe
  24: backpack
  25: umbrella
  26: handbag
  27: tie
  28: suitcase
  29: frisbee
  30: skis
  31: snowboard
  32: sports ball
  33: kite
  34: baseball bat
  35: baseball glove
  36: skateboard
  37: surfboard
  38: tennis racket
  39: bottle
  40: wine glass
  41: cup
  42: fork
  43: knife
  44: spoon
  45: bowl
  46: banana
  47: apple
  48: sandwich
  49: orange
  50: broccoli
  51: carrot
  52: hot dog
  53: pizza
  54: donut
  55: cake
  56: chair
  57: couch
  58: potted plant
  59: bed
  60: dining table
  61: toilet
  62: tv
  63: laptop
  64: mouse
  65: remote
  66: keyboard
  67: cell phone
  68: microwave
  69: oven
  70: toaster
  71: sink
  72: refrigerator
  73: book
  74: clock
  75: vase
  76: scissors
  77: teddy bear
  78: hair drier
  79: toothbrush


# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8.zip

3、训练代码

我们使用命令训练,如下代码:

yolo train model=rtdetr-l.pt data=coco8.yaml epochs=100 imgsz=640 batch=2 amp=False

4、训练运行结果

配置好以上内容即可训练,执行过程如下显示
在这里插入图片描述

3、推理

1、推理代码

这里不在过多介绍推理代码,朋友们可自行查阅。

import cv2
import torch
import numpy as np
from ultralytics.nn.autobackend import AutoBackend

def preprocess(image):
    image = cv2.resize(image, (640, 640))
    image = (image[..., ::-1] / 255.0).astype(np.float32) # BGR to RGB, 0 - 255 to 0.0 - 1.0
    image = image.transpose(2, 0, 1)[None]  # BHWC to BCHW (n, 3, h, w)
    image = torch.from_numpy(image)
    return image

def postprocess(pred, oh, ow, conf_thres=0.25):

    # 输入是模型推理的结果,即300个预测框
    # 1,300,84 [cx,cy,w,h,class*80]
    boxes = []
    for item in pred[0]:
        cx, cy, w, h = item[:4]
        label = item[4:].argmax()
        confidence = item[4 + label]
        if confidence < conf_thres:
            continue
        left    = cx - w * 0.5
        top     = cy - h * 0.5
        right   = cx + w * 0.5
        bottom  = cy + h * 0.5
        boxes.append([left, top, right, bottom, confidence, label])

    boxes = np.array(boxes)
    lr = boxes[:,[0, 2]]
    tb = boxes[:,[1, 3]]
    boxes[:,[0,2]] = ow * lr
    boxes[:,[1,3]] = oh * tb

    return boxes

def hsv2bgr(h, s, v):
    h_i = int(h * 6)
    f = h * 6 - h_i
    p = v * (1 - s)
    q = v * (1 - f * s)
    t = v * (1 - (1 - f) * s)
    
    r, g, b = 0, 0, 0

    if h_i == 0:
        r, g, b = v, t, p
    elif h_i == 1:
        r, g, b = q, v, p
    elif h_i == 2:
        r, g, b = p, v, t
    elif h_i == 3:
        r, g, b = p, q, v
    elif h_i == 4:
        r, g, b = t, p, v
    elif h_i == 5:
        r, g, b = v, p, q

    return int(b * 255), int(g * 255), int(r * 255)

def random_color(id):
    h_plane = (((id << 2) ^ 0x937151) % 100) / 100.0
    s_plane = (((id << 3) ^ 0x315793) % 100) / 100.0
    return hsv2bgr(h_plane, s_plane, 1)

if __name__ == "__main__":
    
    img = cv2.imread("1.jpg")
    oh, ow = img.shape[:2]

    img_pre = preprocess(img)

    # postprocess
    # ultralytics/models/rtdetr/predict.py
    model  = AutoBackend(weights="rtdetr-l.pt")
    names  = model.names
    result = model(img_pre)[0]  # 1,300,84

    boxes  = postprocess(result, oh, ow)

    for obj in boxes:
        left, top, right, bottom = int(obj[0]), int(obj[1]), int(obj[2]), int(obj[3])
        confidence = obj[4]
        label = int(obj[5])
        color = random_color(label)
        cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)
        caption = f"{names[label]} {confidence:.2f}"
        w, h = cv2.getTextSize(caption, 0, 1, 2)[0]
        cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)
        cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)

    cv2.imwrite("infer.jpg", img)
    print("save done")  


注:若下载了文件可直接 python detect.py执行,可得结果

2、推理运行结果

在这里插入图片描述


总结

文章主要是更换backbone(个人觉得不是文章重点),而使用S5在结合作者多个neck模块实验,该neck结构主打消除计算实现实时。
代码可使用百度官网代码,也可使用yolov8自带代码(高效实现)。
后期,我将仿yolov5一键训练与预测,直接使用xml文件格式训练有预测RTDETR文章。

参考博客链接:
https://blog.csdn.net/qq_40672115/article/details/134356250
https://blog.csdn.net/weixin_43694096/article/details/131353118

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/259408.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Nginx快速入门:安装目录结构详解及核心配置解读(二)

0. 引言 上节我们讲解了nginx的应用场景和安装&#xff0c;本节继续针对nginx的各个目录文件进行讲解&#xff0c;让大家更加深入的认识nginx。并通过一个实操案例&#xff0c;带大家来实际认知nginx的核心配置 1. nginx安装目录结构 首先nginx的默认安装目录为&#xff1a;…

原子学习笔记3——使用tslib库

一、tslib介绍 tslib 是专门为触摸屏设备所开发的 Linux 应用层函数库&#xff0c;并且是开源。 tslib 为触摸屏驱动和应用层之间的适配层&#xff0c;它把应用程序中读取触摸屏 struct input_event 类型数据&#xff08;这是输入设备上报给应用层的原始数据&#xff09;并进行…

Elasticsearch常见面试题

文章目录 1.简单介绍下ES&#xff1f;2.简单介绍当前可以下载的ES稳定版本&#xff1f;3.安装ES前需要安装哪种软件&#xff1f;4.请介绍启动ES服务的步骤&#xff1f;5.ES中的倒排索引是什么&#xff1f;6. ES是如何实现master选举的&#xff1f;7. 如何解决ES集群的脑裂问题8…

SpringBoot-XXLJOB提供动态API调度任务

目录 一、项目版本 二、XXL-JOB提供动态API controller层 service层 三、SpringBoot项目 pom model XxlJobUtil-工具类 XXL-JOB是一个分布式任务调度平台&#xff0c;其核心设计目标是开发迅速、学习简单、轻量级、易扩展。现已开放源代码并接入多家公司线上产品线&…

Gateway网关-网关的cors跨域配置

目录 一、跨域 二、解决方案 三、实际测试 3.1 html调用接口 3.2 跨域问题复现 3.3 application文件中配置CORS 3.4 问题解决 一、跨域 跨域问题&#xff1a;浏览器禁止请求的发起者与服务端发生跨域ajax请求&#xff0c;请求被浏览器拦截的问题 跨域&#xff1a;…

第四节TypeScript 声明变量

1、typescript变量声明 变量是一种使用方便的占位符&#xff0c;用于引用计算机内存地址。 我们可以把变量看做存储数据的容器。 typescript变量的命名规则&#xff1a; 变量名称可以包含数字和字母。除了下划线_和美元$符号外&#xff0c;不能包含其它特殊字符&#xff0c…

基于Alpha-Beta剪枝树的井字棋人机博弈系统的实现

这篇文章讨论了算法的基本概念与特性&#xff0c;并介绍了五种常见的算法类型&#xff1a;分治法、动态规划、贪心算法、回溯法和分支限界法。文章以井字棋博弈中的Alpha-Beta剪枝树作为示例&#xff0c;详细解释了该算法的应用和原理。Alpha-Beta剪枝树是一种用于实现游戏AI的…

Python数据加密:保障信息安全的最佳实践

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 随着信息技术的发展&#xff0c;数据安全成为越来越重要的议题。在Python中&#xff0c;有多种方法可以用于数据加密&#xff0c;以确保敏感信息在传输和存储过程中不被泄露或篡改。本文将详细介绍Python中数据加…

智能优化算法应用:基于梯度算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于梯度算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于梯度算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.梯度算法4.实验参数设定5.算法结果6.参考文献7.MA…

服务器数据恢复-EMC存储raid5磁盘物理故障离线的数据恢复案例

服务器数据恢复环境&故障&#xff1a; 一台emc某型号存储服务器&#xff0c;存储服务器上组建了一组raid5磁盘阵列&#xff0c;阵列中有两块磁盘作为热备盘使用。存储服务器在运行过程中有两块磁盘出现故障离线&#xff0c;但是只有一块热备盘激活&#xff0c;最终导致该ra…

文件操作入门指南

目录 一、为什么使用文件 二、什么是文件 2.1 程序文件 2.2 数据文件 2.3 文件名 三、文件的打开和关闭 3.1 文件指针 3.2 文件的打开和关闭 四、文件的顺序读写 ​编辑 &#x1f33b;深入理解 “流”&#xff1a; &#x1f342;文件的顺序读写函数介绍&#xff1a; …

系列十四(面试)、谈谈你对StackOverflowError的理解?

一、StackOverflowError 1.1、概述 StackOverflowError是栈内存溢出的意思。栈中主要存储的是8种基本数据类型 引用类型 实例方法&#xff0c;栈的空间也是有限的&#xff0c;当存储进栈中的容量大于栈的最大容量时&#xff0c;就会报StackOverflowError的错误。 1.2、案例 …

如何入门 GPT 并快速跟上当前的大语言模型 LLM 进展?

入门GPT 首先说第一个问题&#xff1a;如何入门GPT模型&#xff1f; 最直接的方式当然是去阅读官方的论文。GPT模型从2018年的GPT-1到现在的GPT-4已经迭代了好几个版本&#xff0c;通过官方团队发表的论文是最能准确理清其发展脉络的途径&#xff0c;其中包括GPT模型本身和一…

最详细手把手教你安装 Vivado2017.4

软件下载 官网可下载各个版本 百度网盘链接 Vivado2017.4 License 软件安装 解压缩安装包&#xff0c;双击运行安装程序 xsetup.exe&#xff1a; 忽略软件更新&#xff0c;点击 Continue&#xff1a; 点击 Next&#xff1a; 全部勾选 I Agree&#xff0c;点击 Next&#x…

从0到1打造一款WebStyle串口调试工具

Tip&#xff1a;No Ego Some programmers have a huge problem: their own ego. But there is no time for developing an ego. There is no time for being a rockstar. Who is it who decides about your quality as programmer? You? No. The others? Probably. But can …

Python (十二) NumPy操作

程序员的公众号&#xff1a;源1024&#xff0c;获取更多资料&#xff0c;无加密无套路&#xff01; 最近整理了一波电子书籍资料&#xff0c;包含《Effective Java中文版 第2版》《深入JAVA虚拟机》&#xff0c;《重构改善既有代码设计》&#xff0c;《MySQL高性能-第3版》&…

程序员的20大Git面试问题及答案

文章目录 1.什么是Git&#xff1f;2.Git 工作流程3.在 Git 中提交的命令是什么&#xff1f;4.什么是 Git 中的“裸存储库”&#xff1f;5.Git 是用什么语言编写的&#xff1f;6.在Git中&#xff0c;你如何还原已经 push 并公开的提交&#xff1f;7.git pull 和 git fetch 有什么…

计算机网络(3):数据链路层

数据链路层属于计算机网络的低层。 数据链路层使用的信道主要有以下两种类型&#xff1a; (1)点对点信道。这种信道使用一对一的点对点通信方式。 (2)广播信道。这种信道使用一对多的广播通信方式。广播信道上连接的主机很多&#xff0c;因此必须使用专用的共享信道协议来协调这…

制作PPT找了一个校徽是方形的,如何裁剪为圆形的。

问题描述&#xff1a;制作PPT找了一个校徽是方形的&#xff0c;如何裁剪为圆形的。 问题解决&#xff1a;使用一个在线圆形裁剪软件即可。 网址为&#xff1a; https://crop-circle.imageonline.co/cn/#google_vignette

css实现边框彩虹跑马灯效果

效果展示 代码实战 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-…