RT-DETR:端到端的实时Transformer检测模型(目标检测+跟踪)

博主一直一来做的都是基于Transformer的目标检测领域,相较于基于卷积的目标检测方法,如YOLO等,其检测速度一直为人诟病。
终于,RT-DETR横空出世,在取得高精度的同时,检测速度也大幅提升。

那么RT-DETR是如何做到的呢?

在研究RT-DETR的改进前,我们先来了解下DETR类目标检测方法的发展历程吧

  • 首先是DETR,该方法作为Transformer在目标检测领域的开山之作,一经推出,便引发了极大的轰动,该方法巧妙的利用Transformer进行特征提取与解码,同时通过匈牙利匹配方法完成预测框与真实框的匹配,避免了NMS等后处理过程。
  • 随后DAB-DETR引入了动态锚框作为查询向量,从而对DETR中的100个查询向量进行了解释。
  • Deformable-DETR针对Transformer中自注意力计算复杂度高的问题,提出可变形注意力计算,即通过可学习的选取少量向量进行注意力计算,大幅的降低了计算量。
  • DN-DETR认为匈牙利匹配的二义性是导致DETR训练收敛慢的原因,因此提出查询降噪机制,即利用先前DAB-DETR中将查询向量解释为锚框的原理,给查询向量添加一些噪声来辅助模型收敛,最终大幅提升了模型的训练速度。
  • DINO则是在DAB-DETR与DN-DETR的基础上进行进一步的融合与改进。
  • H-DETR为使模型获取更多的正样本特征,从而提升检测精度,因此提出混合匹配方法,在训练阶段,包含原始的匈牙利匹配分支与一个一对多的辅助匹配分支,而在推理阶段,则只有一个匈牙利匹配分支。

然而,上述方法尽管已经大幅提升了检测精度,降低了计算复杂度,但其受Transformer本身高计算复杂度的制约,DETR类目标检测方法的实时性始终令人难以满意,尤其是相较于YOLO等单阶段目标检测方法,其检测速度的确差别巨大。

为了解决这个问题,百度提出了RT-DETR,该方法依旧是在DETR的基础上改进生成的,从论文中给出的实验结果来看,该方法无论在检测速度还是检测精度方法都已经超过了YOLOv8,实现了真正的实时性。

在这里插入图片描述

  • 创新点1:高效混合编码器:RT-DETR使用了一种高效的混合编码器,通过解耦尺度内交互和跨尺度融合来处理多尺度特征。这种独特的基于视觉Transformer的设计降低了计算成本,并允许实时物体检测。
  • 创新点2:IoU感知查询选择:RT-DETR通过利用IoU感知的查询选择改进了目标查询初始化。这使得模型能够聚焦于场景中最相关的目标,从而提高了检测精度。
  • 创新点3:自适应推理速度:RT-DETR支持通过使用不同的解码器层来灵活调整推理速度,而无需重新训练。这种适应性便于在各种实时目标检测场景中的实际应用。

RT-DETR的代码有两个,一个是官方提供的代码,但该代码功能比较单一,只有训练与验证,另一个则是集成在YOLOv8中,该代码的设计就比较全面了

环境部署

conda create -n rtdetr python=3.8
conda activate rtdetr
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
cd RT-DETR-main/rtdetr_pytorch  //这个路径根据你自己的改
pip install -r requirement.txt

该算法的环境为pytorch=2.0.1,注意,尽量要用pytorch2以上的版本,否则可能会报错:

AttributeError: module 'torchvision' has no attribute 'disable_beta_transforms_warning'

官方模型训练

参数配置

该算法的配置封装较好,我们只需要修改配置即可:train.py,指定要使用的骨干网络。

parser.add_argument('--config', '-c', default="/rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )

修改数据集配置文件:RT-DETR-main\rtdetr_pytorch\configs\dataset\coco_detection.yml
修改训练集与测试集路径,同时修改类别数。

在这里插入图片描述

随后便可以开启训练:该文件中指定 epochs

RT-DETR-main\rtdetr_pytorch\configs\rtdetr\include\optimizer.yml

首次训练,需要下载骨干网络的预训练模型

在这里插入图片描述

在这里,博主使用ResNet18作为骨干特征提取网络

训练结果

开始运行,查看GPU使用情况,此时的batch-size=8,可以看到显存占用4.5G左右,相较于博主先前提出的方法或者DINO,其显存占用少了许多,DINObatch-size=2时的显存占用将近16G.

在这里插入图片描述

训练了24轮的结果。

在这里插入图片描述

训练的结果会保存在output文件夹内:

在这里插入图片描述

官方模型推理

在进行模型推理前,需要先导出模型,在官方代码的tools文件夹下有个export_onnx.py文件,只需要指定配置文件与训练好的模型文件:

parser.add_argument('--config', '-c',  default="/rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )
parser.add_argument('--resume', '-r', default="rtdetr_pytorch/tools\output/rtdetr_r18vd_6x_coco\checkpoint0024.pth",type=str, )

导出的文件是onnx格式

ONNX(Open Neural Network Exchange)是一种开放式的文件格式,用于存储和交换训练好的机器学习模型。它使得不同的人工智能框架(如PyTorch、TensorFlow)可以共享模型,促进了模型在不同平台之间的迁移和复用。ONNX文件采用Protobuf序列化技术进行存储,具有高效、紧凑的特点。

在这里插入图片描述
随后开始推理,代码如下:

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor
if __name__ == "__main__":
    ##################
    classes = ['car','truck',"bus"]
    ##################
    # print(onnx.helper.printable_graph(mm.graph))
    #############
    img_path = "1.jpg"
    #############
    im = Image.open(img_path).convert('RGB')
    im = im.resize((640, 640))
    im_data = ToTensor()(im)[None]
    print(im_data.shape)
    size = torch.tensor([[640, 640]])
    sess = ort.InferenceSession("model.onnx")
    import time
    start = time.time()
    output = sess.run(
        # output_names=['labels', 'boxes', 'scores'],
        output_names=None,
        input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
    )
    end = time.time()
    fps = 1.0 / (end - start)
    print(fps)
    
    labels, boxes, scores = output
    draw = ImageDraw.Draw(im)
    thrh = 0.6
    for i in range(im_data.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        print(i, sum(scr > thrh))
        #print(lab)
        print(f'box:{box}')
        for l, b in zip(lab, box):
            draw.rectangle(list(b), outline='red',)
            print(l.item())

            draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )
    #############
    im.save('2.jpg')
    #############

YOLOv8集成RT-DETR训练

在YOLOv8中,给出了YOLO先前的诸多版本,此外还包含RT-DETR
其运行环境与官方的相同,这里就不再赘述了,另外,如果想要了解YOLO及其集成算法的更多功能,可以查看:

https://docs.ultralytics.com/

ultralytics集成了多种算法,已有将YOLO目标检测算法大一统的趋势,涵盖语义分割、目标检测、姿势估计、分类、跟踪等多个任务。

数据集配置

YOLO版本的RT-DETR的数据集支持的数据集格式有多种,这里博主选用的是YOLO格式的

coco
    images
    	train2017
    	val2017
    lables
    	train2017
    	val2017

开始训练

随后在根目录下新建一个run.py文件,文件中写入如下代码:

from ultralytics.models import RTDETR
if __name__ == '__main__':
    model = RTDETR(model='ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    #model.load('rtdetr-l.pt') # 不使用预训练权重可注释掉此行
    model.train(pretrained=True, data='ultralytics\cfg\datasets\cocomine.yaml', epochs=200, batch=16, device=0, imgsz=320, workers=2,cache=False,)

运行报错:

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.

解决方法,这是由于Anconda的torch中的某个文件与环境中的某个文件冲突导致的,找到环境中的文件:
环境路径:

D:\softwares\Anconda\envs\detr\Library\bin

将下面的文件给重命名即可。

在这里插入图片描述

随后便开始训练了,如下:

在这里插入图片描述

至此,RT-DETR的训练过程便完成了。博主设置训练200个epoch,但考虑到接下来的任务,因此训练到一半就停止了,生成的文件存放在run文件中,如下:

在这里插入图片描述

YOLOv8集成RT-DETR推理

在YOLOv8集成的RT-DETR中,其设计就非常完备了,我们只需要新建一个predict.py,里面的内容如下:
这里的images即为一个文件夹,里面可以放入多张图像,save代表保存

model=RTDETR("runs\detect/train\weights/best.pt")
model.predict(source="images",save=True)

推理结果、保存路径与推理速度都会显示在下面

在这里插入图片描述

当然我们还可以指定conf参数,即置信度,可以帮我们筛选一下结果:设置置信度为0.6,此时原本的汽车就不再框选了。

在这里插入图片描述

在这里插入图片描述

视频推理

视频推理也很简单,只需要将原来的图像换为视频即可

model=RTDETR("runs\detect/train\weights/best.pt")
model.predict(source="images/1.mp4",save=True,conf=0.6)

在这里插入图片描述

目标跟踪

在先前的目标跟踪中,都是通过先检测,后跟踪的方式,如采用YOLOv5+DeepSort的方式进行目标跟踪,而在YOLOv8中,他将该功能集成到里面,我们可以直接采用执行跟踪任务的方式完成目标跟踪。

from ultralytics.models import RTDETR
model=RTDETR("runs\detect/train\weights/best.pt")
results = model.track(source="images/1.mp4", conf=0.3, iou=0.5,save=True)

RT-DETR目标跟踪视频

轨迹绘制

from collections import defaultdict

import cv2
import numpy as np
from ultralytics import RTDETR

# Load the YOLOv8 model
model=RTDETR("D:\graduate\programs\yolo8/ultralytics-main/runs\detect/train\weights/best.pt")

# Open the video file
video_path = "images/1.mp4"
cap = cv2.VideoCapture(video_path)

# Store the track history
track_history = defaultdict(lambda: [])

# Loop through the video frames
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()

    if success:
        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True)

        # Get the boxes and track IDs
        boxes = results[0].boxes.xywh.cpu()
        track_ids = results[0].boxes.id.int().cpu().tolist()

        # Visualize the results on the frame
        annotated_frame = results[0].plot()

        # Plot the tracks
        for box, track_id in zip(boxes, track_ids):
            x, y, w, h = box
            track = track_history[track_id]
            track.append((float(x), float(y)))  # x, y center point
            if len(track) > 30:  # retain 90 tracks for 90 frames
                track.pop(0)

            # Draw the tracking lines
            points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(annotated_frame, [points], isClosed=False, color=(230, 230, 230), thickness=10)

        # Display the annotated frame
        cv2.imshow("YOLOv8 Tracking", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    else:
        # Break the loop if the end of the video is reached
        break

# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/667360.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

react路由参数path不再支持正则?比较v5和v6写法的差异性

文章目录 前言v5方式:直接在path参数中,写入对应正则(1)代码详细注释如下(2)页面输出如下,会出现undefined的情况 v6方式: 在路由对象中配置,但只可配动态路由,不可用正则…

在phpstorm2024版里如何使用Jetbrains ai assistant 插件 ?

ai assistant激活成功后,如图 ai assistant渠道:https://web.52shizhan.cn/activity/ai-assistant 在去年五月份的 Google I/O 2023 上,Google 为 Android Studio 推出了 Studio Bot 功能,使用了谷歌编码基础模型 Codey,Codey 是…

C#WPF数字大屏项目实战03--数据内容区域

1、内容区域划分 第一行标题,放了几个文本框 第二行数据,划分成3列布局 2、第1列布局使用UniformGrid控件 最外面放UniformGrid,然后里面放3个GroupBox控件,这3个groupbox都是垂直排列 3、GroupBox控件模板 页面上的3个Group…

【测评|白嫖】雨云宁波新区,2C4G200M,公测期间全免费!

雨云香港三区云服务器,高性能的 Xeon Platinum 处理器 企业级 NVME SSD 高性能云服务器。 一键白嫖链接:https://www.rainyun.com 本篇纯测评,无任何广告,请放心食用!! 本次测评服务器配置如下&#xff1…

系统架构设计师【第9章】: 软件可靠性基础知识 (核心总结)

文章目录 9.1 软件可靠性基本概念9.1.1 软件可靠性定义9.1.2 软件可靠性的定量描述9.1.3 可靠性目标9.1.4 可靠性测试的意义9.1.5 广义的可靠性测试与狭义的可靠性测试 9.2 软件可靠性建模9.2.1 影响软件可靠性的因素9.2.2 软件可靠性的建模方法9.2.3 软件的可靠性模…

十四天学会Vue——Vue 组件化编程(理论+实战)(第四天)

二、 Vue组件化编程 2.1 组件化模式与传统方式编写应用做对比: 传统方式编写应用 依赖关系混乱,不好维护:例如:比如需要引入js1,js2,js3,但是js3需要用到js1、2的方法,所以js1、2…

算能BM1684+FPGA+AI+Camera推理边缘计算盒

搭载算丰智算芯片BM1684,是面向AI推理的边缘计算盒。高效适配市场上所有AI算法,实现视频结构化、人脸识别、行为分析、状态监测等应用,为智慧城市、智慧交通、智慧能源、智慧金融、智慧电信、智慧工业等领域进行AI赋能。 产品规格 处理器芯片…

企业微信接入系列-上传临时素材

企业微信接入系列-上传临时素材 文档介绍上传临时素材写在最后 文档介绍 创建企业群发的文档地址:https://developer.work.weixin.qq.com/document/path/92135,在创建企业群发消息或者群发群消息接口中涉及到上传临时素材的操作,具体文档地址…

数据结构与算法02-排序算法

介绍 排序算法是计算机科学中被广泛研究的一个课题。历时多年,它发展出了数十种算法,这些 算法都着眼于一个问题:如何将一个无序的数字数组整理成升序?先来学习一些“简单排序”,它们很好懂,但效率不如其他…

最佳实践:REST API 的 HTTP 请求参数

HTTP 请求中的请求参数解释 当客户端发起 HTTP 请求 时,它们可以在 URL 末尾添加请求参数(也叫查询参数或 URL 参数)来传递数据。这些参数以键值对的形式出现在 URL 中,方便浏览和操作。 请求参数示例 以下是一些带有请求参数的…

springboot基本使用十二(PageHelper分页查询)

引入依赖&#xff1a; <dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper-spring-boot-starter</artifactId><version>1.3.0</version> </dependency> 通个PageHelper.startPage(page,pageSize)方…

SpringBoot的第二大核心AOP系统梳理

目录 1 事务管理 1.1 事务 1.2 Transactional注解 1.2.1 rollbackFor 1.2.2 propagation 2 AOP 基础 2.1 AOP入门 2.2 AOP核心概念 3. AOP进阶 3.1 通知类型 3.2 通知顺序 3.3 切入点表达式 execution切入点表达式 annotion注解 3.4 连接点 1 事务管理 1.1 事务…

20.Redis之缓存

1.什么是缓存&#xff1f; Redis 最主要的用途,三个方面:1.存储数据(内存数据库)2.缓存 【redis 最常用的场景】3.消息队列【很少见】 缓存 (cache) 是计算机中的⼀个经典的概念. 在很多场景中都会涉及到. 核⼼思路就是把⼀些常⽤的数据放到触⼿可及(访问速度更快)的地⽅, ⽅…

一维时间序列信号的改进小波降噪方法(MATLAB R2021B)

目前国内外对于小波分析在降噪方面的方法研究中&#xff0c;主要有小波分解与重构法降噪、小波阈值降噪、小波变换模极大值法降噪等三类方法。 (1)小波分解与重构法降噪 早在1988 年&#xff0c;Mallat提出了多分辨率分析的概念&#xff0c;利用小波分析的多分辨率特性进行分…

Facebook开户|Facebook做落地页的标准和建议

哈喽呀家人们下午好~今天Zoey来跟大家带来Facebook做落地页的标准和建议&#xff01;需要的家人建议点赞收藏啦&#xff01;&#xff01;用户通过点击你的推广链接、搜索引擎搜索结果页面的快照链接、社交媒体中的网页链接、电子邮件中的链接等进入你网站的特定页面&#xff0c…

Microsoft的Copilot现已登陆Telegram

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

AdroitFisherman模块测试日志(2024/5/28)

测试内容 测试AdroitFisherman分发包中Base64Util模块。 测试用具 Django5.0.3框架&#xff0c;AdroitFisherman0.0.29 项目结构 路由设置 总路由 from django.contrib import admin from django.urls import path,include from Base64Util import urls urlpatterns [path…

OrangePi AIpro--新手上路

目录 一、SSH登录二、安装VNC Sevice&#xff08;经测试Xrdp远程桌面安装不上&#xff09;2.1安装xface桌面2.2 配置vnc服务2.2.1 设置vnc server6-8位的密码2.2.2 创建vnc文件夹&#xff0c;写入xstartup文件2.2.3 给xstartup文件提高权限2.2.4 在安装产生的vnc文件夹创建xsta…

北京仁爱堂李艳波主任如何预约挂号?

北京仁爱堂擅长治疗神经系统疾病&#xff0c;例如&#xff1a;痉挛性斜颈&#xff0c;特发性震颤&#xff0c;眼球震颤&#xff0c;帕金森&#xff0c;眼球震颤等。 北京仁爱堂国医馆是一所集治疗、 预防、保健、养生于一体的传统中医诊所&#xff0c;具有精湛技术和丰富经验的…

ad18学习笔记20:焊盘设置Solder Mask Expansion(阻焊层延伸)

【AD18新手入门】从零开始制造自己的PCB_ad18教程-CSDN博客 Altium Designer绘制焊盘孔&#xff08;Pad孔&#xff09;封装库的技巧&#xff0c;包括原理图封装和PCB封装_哔哩哔哩_bilibili 默认的焊盘中间是有个过孔的&#xff0c;单层焊盘&#xff08;表贴烛盘&#xff09;…