目标检测算法之RT-DETR

RT-DETR算法理解

  • Background
  • Model Architecture
    • Efficient Hybrid Encoder
    • Uncertainty-minimal Query Selection
  • 总结

Background

Real-time Detection Transformer(RT-DETR)是一个基于tranformer的实时推理目标检测模型。RT-DETR是2023年百度发布的一个新目标检测模型,它兼顾了速度和精度俩个特性,在速度上超越yolo,同时仍保持不低于yolo模型的精度。其分别从encoder部分、query选择俩个方面进行改进,保持了模型的精度,同时提高了模型的推理速度。
在这里插入图片描述
论文地址:https://arxiv.org/pdf/2304.08069
代码地址:https://github.com/lyuwenyu/RT-DETR

Model Architecture

在这里插入图片描述
模型的结构如上图所示,输出图片经过Backbone进行特征提取,获取三个特征图 S 3 、 S 4 、 S 5 S_3、S_4、S_5 S3S4S5。然后将它们输入Efficient Hybrid Encoder层。Efficient Hybrid Encoder层对特征图 S 5 S_5 S5做AIFI获得特征图 F 5 F_5 F5,然后通过CCFF结合 S 3 、 S 4 、 F 5 S_3、S_4、F_5 S3S4F5输出。然后用Uncertainty-minimal Query Selection选取query,再和Encoder的输出一起输入decoder中,最后输出检测结果。

Efficient Hybrid Encoder

作者分析了特征图自交互的情况,认为低级特征具备丰富的图像语义,交互的需求不大。同时通过实验验证了这一观点。这里的出发点是从缩短输入的AIFI的长度出发,由于计算复杂度与长度的平方成正比,由于高级特征的长度较小,所以计算量较少,同时能够验证低级特征交互是不必要,那么就可以较少这一部分的计算。
整个Efficient Hybrid Encoder模块可以用公式表达出来,即 Q = K = V = F l a t t e n ( C 5 ) F 5 = R e s h a p e ( A I F I ( Q , K , V ) ) O = C C F F ( { S 3 , S 4 , F 5 } ) \begin{align*}Q =& K=V = Flatten(C_5)\\F_5 = &Reshape(AIFI(Q,K,V))\\O=&CCFF(\{S_3,S_4,F_5\})\end{align*} Q=F5=O=K=V=Flatten(C5)Reshape(AIFI(Q,K,V))CCFF({S3,S4,F5})这里就是将 C 5 C_5 C5打平,然后输入AIFI中,AIFI是一个普通的transformer encoder模块,然后复原获得特征图 F 5 F_5 F5。然后将三个特征图输入CCFF模块中。官方的CCFF图看起来有些许不明显,所以这里重新画了一下这块,可能让读者更好地了解CCFF,具体见下图。
在这里插入图片描述
CCFF模块其实就是类似于yolo neck中的FPN+PAN,用于融合不同尺度的特征图。这里主要了解一下Fusion的结构,论文中给出了fusion的结构图,具体如下
在这里插入图片描述
Fusion的结构采用了CSP的方法,将输入的特征concat后用1x1的卷积分成了俩份,然后一边经过RepBlock,另一边直接与RepBlock输出直接concat,然后经过flatten层输出。
接下来结合一下源码分析一下CCFF的结构,下面的代码来自hybrid_encoder.py

        inner_outs = [proj_feats[-1]] #获取特征图F5
        for idx in range(len(self.in_channels) - 1, 0, -1): #总共俩层,即idx为2,1
            feat_high = inner_outs[0] #第一次遍历为F5
            feat_low = proj_feats[idx - 1] #第一次遍历为S4
            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)#这一部分就是图中的黄色模块,由1x1的卷积+BN层+SiLU组成,第一次遍历时处理F5
            inner_outs[0] = feat_high
            upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest') #第一次遍历对经过lateral_conv的F5做上采样
            inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1)) #这里就是论文中的fusion模块
            inner_outs.insert(0, inner_out)   #相信集合图形可以很好地理解,第二次的遍历对着图就可以了

        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1): #这里也是遍历俩次
            feat_low = outs[-1] #获得FPN的最后一层输出
            feat_high = inner_outs[idx + 1] #第二次lateral_conv的输出 
            downsample_feat = self.downsample_convs[idx](feat_low) #上采样
            out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1)) #经过fusion模块
            outs.append(out) #这里也是分析了第一次遍历,第二次也是类似的

Uncertainty-minimal Query Selection

作者分析认为,以往选择query时未同时考虑分类和回归的结果,所以导致模型的预测结果中,并不是分类和回归都是最优。所以它为了降低这种不确定性,在query的选择中加入整个因素,即衡量不确定性定义为 U ( x ^ ) U(\hat{x}) U(x^),其中 U ( x ^ ) = ∣ ∣ P ( x ^ ) − C ( x ^ ) ∣ ∣ U(\hat{x}) = ||P(\hat{x})-C(\hat{x})|| U(x^)=∣∣P(x^)C(x^)∣∣其中 x ^ \hat{x} x^为encoder的输出, P P P位置预测, C C C指分类预测。
然后在最后的损失中加上 U U U,即 L ( x ^ , y ^ , y ) = L b o x ( b ^ , b ) + L c l s ( U ( x ^ ) , c ^ , c ) \mathcal{L}(\hat{x},\hat{y},y) = \mathcal{L} _{box}(\hat{b},b)+ \mathcal{L} _{cls}(U(\hat{x}),\hat{c},c) L(x^,y^,y)=Lbox(b^,b)+Lcls(U(x^),c^,c)这里的思想其实就是做了一个分类和回归的对齐,核心上就是分类分数高回归结果也要准。在源码的具体实现中,采用了VFL的方法,VFL公式具体如下 V F L ( p , q ) = { − q ( q log ⁡ ( p ) + ( 1 − q ) log ⁡ ( 1 − p ) ) q > 0 − α p γ log ⁡ ( 1 − p ) q = 0 VFL(p,q)=\left\{\begin{matrix}-q(q\log(p)+(1-q)\log(1-p))&q>0\\ -\alpha p^{\gamma}\log(1-p) &q=0\end{matrix}\right. VFL(p,q)={q(qlog(p)+(1q)log(1p))αpγlog(1p)q>0q=0其中 q q q为预测框的iou, p p p则为分类概率。
源码中的实现如下

    def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)

        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
        ious = torch.diag(ious).detach()

        src_logits = outputs['pred_logits']
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = ious.to(target_score_o.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        pred_score = F.sigmoid(src_logits).detach()
        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
        
        loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {'loss_vfl': loss}

总结

对RT-DETR的encoder部分,整体看下来像是yolo的backbone+neck。RT-DETR的核心还是在增速上,所以这里它的优化思想是值得借鉴的,但是yolo结构跟DETR结构之间的界限越来越模糊了。对query的优化上,只是做了对齐,使其选择的query更加精确。整体而言模型的创新不大。虽然DETR提倡的是NMS-Free,但是对于某些对精装度要求较高的任务中,如果阈值设置过低,导致最后得出的框过多,仍然需要借助NMS的方法去改进。设置过高则存在丢框的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/752627.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【MTK平台】如何学习Bluedroid A2DP Code

一 Bluedroid A2DP架构图 备注: vendor/mediatek/proprietary/packages/modules/Bluetooth/system/audio_a2dp_hw/src 目录下编译生成audio.a2dp.default.so,主要实现a2dp做为设备的功能 二 A2DP File Hierarchy ModuleFileDescriptionAudio HAL (hardware/libhardware/…

小程序发布必须进行软件测试吗?测试内容有哪些?

在如今移动互联网时代,小程序已成为许多企业广泛采用的一种营销手段,然而,发布小程序之前进行充分的软件测试是至关重要的,因为它不仅可以确保小程序的质量,还可以避免潜在的风险和损失。 在进行小程序发布前进行软件…

【大模型】大模型微调方法总结(四)

1. P-Tuning v1 1.背景 大模型的Prompt构造方式严重影响下游任务的效果。比如:GPT-3采用人工构造的模版来做上下文学习(in context learning),但人工设计的模版的变化特别敏感,加一个词或者少一个词,或者变…

DIY:在您的 PC 上本地使用 Stable Diffusion AI 模型生成图像

前言 随着DALL-E-2和Midjourney的发布,您可能听说过最近 AI 生成艺术的繁荣。这些人工智能模型如何在几秒钟内创造性地生成逼真的图像,这绝对是令人兴奋的。您可以在这里查看其中的一些:DALL-E-2 gallery和Midjourney gallery 但是这些模型…

DAY16-力扣刷题

1.不同的二叉搜索树2 95. 不同的二叉搜索树 II - 力扣(LeetCode) 给你一个整数 n ,请你生成并返回所有由 n 个节点组成且节点值从 1 到 n 互不相同的不同 二叉搜索树 。可以按 任意顺序 返回答案。 方法一:回溯 class Solutio…

聚观早报 | iPhone 16核心硬件曝光;三星Galaxy全球新品发布会

聚观早报每日整理最值得关注的行业重点事件,帮助大家及时了解最新行业动态,每日读报,就读聚观365资讯简报。 整理丨Cutie 6月28日消息 iPhone 16核心硬件曝光 三星Galaxy全球新品发布会 苹果正多方下注布局AI商店 黄仁勋2024年薪酬3400…

Kotlin设计模式:深入理解桥接模式

Kotlin设计模式:深入理解桥接模式 在软件开发中,随着系统需求的不断增长和变化,类的职责可能会变得越来越复杂,导致代码难以维护和扩展。桥接模式(Bridge Pattern)是一种结构型设计模式,它通过…

Nest 的 IoC 机制

后端系统中,会有很多对象: Controller 对象:接收 http 请求,调用 Service,返回响应 Service 对象:实现业务逻辑 Repository 对象:实现对数据库的增删改查 此外,还有数据库链接对…

【吊打面试官系列-MyBatis面试题】MyBatis 框架的缺点?

大家好,我是锋哥。今天分享关于 【MyBatis 框架的缺点?】面试题,希望对大家有帮助; MyBatis 框架的缺点? 1、SQL 语句的编写工作量较大,尤其当字段多、关联表多时,对开发人员编写 SQL 语句的功底…

工作备忘录哪个好用 好用的工作备忘录

在繁忙的工作环境中,备忘录就像是我手中的一把利剑,助我斩断杂乱的思绪,让工作变得井井有条。每当任务堆积如山,或是灵感与琐事交织时,我总会依赖我的备忘录来帮我理清头绪。 想象一下,你正忙于一个大型项…

小区物业管理收费系统源码小程序

便捷、透明、智能化的新体验 一款基于FastAdminUniApp开发的一款物业收费管理小程序。包含房产管理、收费标准、家属管理、抄表管理、在线缴费、业主公告、统计报表、业主投票、可视化大屏等功能。为物业量身打造的小区收费管理系统,贴合物业工作场景,轻…

RabbitMQ实践——搭建单人聊天服务

大纲 创建Core交换器用户登录发起聊天邀请接受邀请聊天实验过程总结代码工程 经过之前的若干节的学习,我们基本掌握了Rabbitmq各个组件和功能。本文我们将使用之前的知识搭建一个简单的单人聊天服务。 基本结构如下。为了避免Server有太多连线导致杂乱,下…

【MySQL基础篇】概述及SQL指令:DDL及DML

数据库是一个按照数据结构来组织、存储和管理数据的仓库。以下是对数据库概念的详细解释:定义与基本概念: 数据库是长期存储在计算机内的、有组织的、可共享的、统一管理的大量数据的集合。 数据库不仅仅是数据的简单堆积,而是遵循一定的规则…

可用的搜索引擎

presearchhttps://presearch.com/yandexhttps://yandex.com/ 以上,目前均不需科学上网。

GEOS学习笔记(一)

下载编译GEOS 从Download and Build | GEOS (libgeos.org)下载geos-3.10.6.tar.bz2 使用cmake-3.14.0版本配置VS2015编译 按默认配置生成VS工程文件 编译后生成geos.dll,geos_c.dll 后面学习使用C接口进行编程

PCB在工业领域的应用以及人工智能的影响。

什么是pcb呢? PCB,全称Printed Circuit Board,中文名称为印制电路板,也被称为印刷线路板或印制板1。这是一种重要的电子部件,主要由绝缘基板、连接导线和装配焊接电子元器件的焊盘组成。PCB的主要作用是作为电子元器件的支撑体和电气连接的载体,它能够简化电子产品的装配…

三分钟快速搭建基于FastAPI的AI Agent应用!

点击下方“JavaEdge”,选择“设为星标” 第一时间关注技术干货! 免责声明~ 任何文章不要过度深思! 万事万物都经不起审视,因为世上没有同样的成长环境,也没有同样的认知水平,更「没有适用于所有人的解决方案…

【鸿蒙学习笔记】页面和自定义组件生命周期

官方文档:页面和自定义组件生命周期 目录标题 [Q&A] 都谁有生命周期? [Q&A] 什么是组件生命周期? [Q&A] 什么是组件?组件生命周期 [Q&A] 什么是页面生命周期? [Q&A] 什么是页面?页面生…

代码随想录算法训练营第五十二天| [KC]100. 岛屿的最大面积、101. 孤岛的总面积、102. 沉没孤岛、103. 水流问题

[KamaCoder] 100. 岛屿的最大面积 [KamaCoder] 100. 岛屿的最大面积 文章解释 题目描述 给定一个由 1(陆地)和 0(水)组成的矩阵,计算岛屿的最大面积。岛屿面积的计算方式为组成岛屿的陆地的总数。岛屿由水平方向或垂直…

开放式耳机哪个牌子好?2024热门红榜开放式耳机测评真实篇!

当你跟朋友们聊天时,他们经常抱怨说长时间戴耳机会令耳朵感到不适,后台也有很多人来滴滴我,作为一位致力于开放式耳机的测评博主,在对比了多款开放式耳机之后,你开放式耳机在保护听力方面确实有用。开放式的设计有助于减轻耳道内的…