ultralytics库RT-DETR代码解析

最近读了maskformer以及maskdino的分割头设计,于是想在RT-DETR上做一个分割的改动,所以选择在ultralytics库中对RTDETR进行改进。

本文内容简介:

        1.ultralytics库中RT-DETR模型解析

        2. 对ultralytics库中的RT-DETR模型增加分割头做实例分割


1.ultralytics库中RT-DETR模型解析

从yaml文件中可以看出解码过程是由RTDETRDecoder类实现的,先看看该类的代码:

class RTDETRDecoder(nn.Module):
    export = False  # export mode
    def __init__():
        super().__init__()

    def forward(self, x, batch=None):
        """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
            batch,
            self.nc,
            self.num_queries,
            self.denoising_class_embed.weight,
            self.num_denoising,
            self.label_noise_ratio,
            self.box_noise_scale,
            self.training,
        )

        embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(
            embed,
            refer_bbox,
            feats,
            shapes,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask,
        )
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)

   

首先看看其输入和输出:

        输入:骨干网络得到的三层特征,以640输入为例,分别为[b,256,80,80],[b,256,40,40],[b,256,20,20]

        输出:box信息dec_bboxes,置信度信息dec_scores

再看看__call__执行的具体过程(这里我们先忽略掉CDN的部分):

        其中,主要包含了4个函数,_get_encoder_input函数将输入整理成需要的形状,get_cdn_group添加了类似于DN-DETR的去噪分组方法,_get_decoder_input,decoder进行注意力计算。下边详细来看看这4个函数具体过程。

       (1) _get_encoder_input

        这个函数主要是将feature输入调整成固定的输入,输入为三个特征层的输入,输出为一个合并的特征feats,以及一个包含三个特征层尺寸的列表shaps[[80,80],[40,40],[20,20]]

        具体的过程可由下图表示:

    (2) get_cdn_group(这里先略过)

    (3) _get_decoder_input

         这里的输入就是上文提到的特征feats(b,8400,256)以及shapes[[80.80],[40,40],[20,20]],开头就是根据shapes生成锚点的操作。

	def _get_decoder_input(self, feats, shapes):
		"""
			feats: [b,8400,256]
			shapes: [[80,80],[40,40],[20,20]]
		"""
        bs = feats.shape[0]  # b
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)  # anchor [b,8400,4] valid_mask[b,8400,1]
        features = self.enc_output(valid_mask * feats)  # [b,8400,1]*[b,8400,256]=[b,8400,256]

        enc_outputs_scores = self.enc_score_head(features)  # [b,8400,256]->[b,8400,nc]
        # Query selection
        # (bs, num_queries) DINO中的Mixed Query Selection策略,也就是从最后一个编码器层中选择前K个编码器特征作为先验,以增强解码器查询。
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, 8400, 256) -> (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, 8400, 4)   -> (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # Dynamic anchors + static content
		# 前300的特征经过3个Linear [N 300 256]—>[N 300 4]再加上top_k_anchors得到refer_bbox
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()

        return embeddings, refer_bbox, enc_bboxes, enc_scores

             这里相信看过我文章的小伙伴已经非常熟悉了,通过[w,h]来生成对应的锚点,只不过这里有一点特殊,这里的锚点坐标是归一化后的,另外,针对锚点归一化后的值小于0.01或者大于0.99都是无效的,所以这里维护了一个valid_mask来得到有效的锚点

def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
        """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2) 归一化锚点xy
            # 三个层的值分别为0.05,0.1,0.2
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i) #(1, h, w, 2)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)   # 
        # 限制每个anchor内的值都在[0.01-0.99]之间,在这个区间之外的值设为无效,后面通过masked_fill设为'inf'
        valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float("inf"))
        return anchors, valid_mask

        接下来使用了 DINO中的Mixed Query Selection策略,也就是从特征中中选择前K个编码器特征作为先验,以增强解码器查询。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/945940.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

25. C++继承 1 (继承的概念与基础使用, 继承的复制兼容规则,继承的作用域)

⭐上篇模板文章&#xff1a;24. C模板 2 (非类型模板参数&#xff0c;模板的特化与模板的分离编译)-CSDN博客 ⭐本篇代码&#xff1a;c学习 橘子真甜/c-learning-of-yzc - 码云 - 开源中国 (gitee.com) ⭐标⭐是比较重要的部分 目录 一. 继承的基础使用 1.1 继承的格式 1.2 …

宽带、光猫、路由器、WiFi、光纤之间的关系

1、宽带&#xff08;Broadband&#xff09; 1.1 宽带的定义宽带指的是一种高速互联网接入技术&#xff0c;通常包括ADSL、光纤、4G/5G等不同类型的接入方式。宽带的关键特点是能够提供较高的数据传输速率&#xff0c;使得用户可以享受到稳定的上网体验。 1.2 宽带的作用宽带是…

【AndroidAPP】权限被拒绝:[android.permission.READ_EXTERNAL_STORAGE],USB设备访问权限系统报错

一、问题原因 1.安卓安全性变更 Android 12 的安全性变更&#xff0c;Google 引入了更严格的 PendingIntent 安全管理&#xff0c;强制要求开发者明确指定 PendingIntent 的可变性&#xff08;Mutable&#xff09;或不可变性&#xff08;Immutable&#xff09;。 但是&#xf…

Inno Setup生成exe安装包

Inno Setup生成exe安装包 第一步&#xff1a;创建一个带向导的脚本文件 第二步&#xff1a;直接 Next&#xff0c;不要创建空的脚本文件 第三步&#xff1a;填写相关的应用程序信息 第四步&#xff1a;指定应用程序的安装目录相关的信息 第五步&#xff1a;选择可执行程序和相…

数据库MHA

MHA 什么是MHA -------- MASTER HIGH AVAILABILITY 建立在主从复制基础之上的故障切换到软件系统 主从复制的单点问题&#xff1a; 当主从复制当中&#xff0c;主服务器发生故障&#xff0c;会自动切换到一台从服务器&#xff0c;然后把从服务器升格为主&#xff0c;继续主…

vue2 - Day04 - 插槽、路由

插槽、路由 一、插槽&#xff08;solt&#xff09;1.1 概念1.2 基本用法1.3 分类1.3.1 默认插槽&#xff08;Default Slot&#xff09;例子&#xff1a; 1.3.2 具名插槽&#xff08;Named Slots&#xff09;语法&#xff1a; 1.3.3 作用域插槽&#xff08;Scoped Slots&#xf…

微信小程序:定义页面标题,动态设置页面标题,json

1、常规设置页面标题 正常微信小程序中&#xff0c;设置页面标题再json页面中进行设置&#xff0c;例如 {"usingComponents": {},"navigationBarTitleText": "标题","navigationBarBackgroundColor": "#78b7f7","navi…

【数据可视化-10】国防科技大学录取分数线可视化分析

&#x1f9d1; 博主简介&#xff1a;曾任某智慧城市类企业算法总监&#xff0c;目前在美国市场的物流公司从事高级算法工程师一职&#xff0c;深耕人工智能领域&#xff0c;精通python数据挖掘、可视化、机器学习等&#xff0c;发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

Spring Boot教程之四十一:在 Spring Boot 中调用或使用外部 API

如何在 Spring Boot 中调用或使用外部 API&#xff1f; Spring Boot 建立在 Spring 之上&#xff0c;包含 Spring 的所有功能。它现在越来越受到开发人员的青睐&#xff0c;因为它是一个快速的生产就绪环境&#xff0c;使开发人员能够直接专注于逻辑&#xff0c;而不必费力配置…

L25.【LeetCode笔记】 三步问题的四种解法(含矩阵精彩解法!)

目录 1.题目 2.三种常规解法 方法1:递归做 ​编辑 方法2:改用循环做 初写的代码 提交结果 分析 修改后的代码 提交结果 for循环的其他写法 提交结果 方法3:循环数组 提交结果 3.方法4:矩阵 算法 代码实践 1.先计算矩阵n次方 2.后将矩阵n次方嵌入递推式中 提…

面试题解,JVM的运行时数据区

一、请简述JVM运行时数据区的组成结构及各部分作用 总览 从线程持有的权限来看 线程私有区 虚拟机栈 虚拟机栈是一个栈结构&#xff0c;由许多个栈帧组成&#xff0c;一个方法分配一个栈帧&#xff0c;线程每执行一个方法时都会有一个栈帧入栈&#xff0c;方法执行结束后栈帧…

代码随想录算法【Day7】

DAY7 454.四数相加II 特点&#xff1a; 1.只用返回元组的个数&#xff0c;而不用返回具体的元组 2.可以不用去重 暴力思路&#xff1a;遍历&#xff0c;这样时间复杂度会达到O(n^4) 标准思路&#xff1a;用哈希法&#xff08;场景&#xff1a;在一个集合里面判断一个元素…

网络渗透测试实验四:CTF实践

1.实验目的和要求 实验目的:通过对目标靶机的渗透过程,了解CTF竞赛模式,理解CTF涵盖的知识范围,如MISC、PPC、WEB等,通过实践,加强团队协作能力,掌握初步CTF实战能力及信息收集能力。熟悉网络扫描、探测HTTP web服务、目录枚举、提权、图像信息提取、密码破解等相关工具…

[羊城杯 2024]不一样的数据库_2

题目描述&#xff1a; 压缩包6 (1).zip需要解压密码&#xff1a; 尝试用ARCHPR工具爆破一下&#xff1a; &#xff08;字典可自行在github上查找&#xff09; 解压密码为&#xff1a;753951 解压得到13.png和Kee.kdbx文件&#xff1a; 二维码图片看上去只缺了正常的三个角&…

JSON结构快捷转XML结构API集成指南

JSON结构快捷转XML结构API集成指南 引言 在当今的软件开发世界中&#xff0c;数据交换格式的选择对于系统的互操作性和效率至关重要。JSON&#xff08;JavaScript Object Notation&#xff09;和XML&#xff08;eXtensible Markup Language&#xff09;是两种广泛使用的数据表…

小程序租赁系统构建指南与市场机会分析

内容概要 在当今竞争激烈的市场环境中&#xff0c;小程序租赁系统正崭露头角&#xff0c;成为企业转型与创新的重要工具。通过这个系统&#xff0c;商户能够快速推出自己的小程序&#xff0c;无需从头开发&#xff0c;节省了大量时间和资金。让我们来看看这个系统的核心功能吧…

单词统计详解---pyhton

有一个.txt的文本文件&#xff0c;对齐单词进行统计&#xff0c;并显示单词重复做多的10个单词 思路&#xff1a; 1将文本文件进行逐行处理&#xff0c;并进行空格分割处理 2新建一个字典&#xff0c;使用get方法将单词一次添加到字典中&#xff0c;并用sorted方法进行排序。…

如何逐步操作vCenter修改DNS服务器?

在vSphere 7中有一个新功能&#xff0c;它允许管理员更改vCenter Server Appliance的FQDN和IP。因此本文将介绍如何轻松让vCenter修改DNS服务器。 vCenter修改DNS以及修改vCenter IP地址 与在部署 vCenter Server Appliance 后&#xff0c;您可以根据需要修改其 DNS 设置和 IP…

[创业之路-225]:《华为闭环战略管理》-4-华为的商业智慧:在价值链中探索取舍之道与企业边界

目录 一、在价值链中探索取舍之道与企业边界 价值链的深刻洞察 取舍之道&#xff1a;有所为&#xff0c;有所不为 垂直整合与横向整合的平衡 企业边界与活动边界的界定 采购与外包的智慧运用 结语 二、企业外部价值流&#xff1a;上游、中游、下游、终端 上游&#xf…

鸿蒙1.2:第一个应用

1、create Project&#xff0c;选择Empty Activity 2、配置项目 project name 为项目名称&#xff0c;建议使用驼峰型命名 Bundle name 为项目包名 Save location 为保存位置 Module name 为模块名称&#xff0c;即运行时需要选择的模块名称&#xff0c;见下图 查看模块名称&…