DINO-DETR匈牙利匹配与加噪过程学习记录

今天再来回顾一下DINO中匈牙利匹配与损失函数部分,该部分大致与DETR相似,却又略有不同。
为了查看数据方便,博主将num_query改为20,max_select值也为20。

匈牙利匹配过程

首先是数据送入匈牙利匹配中进行标签匹配过程了。

获取预测的类别,box信息

bs, num_queries = outputs["pred_logits"].shape[:2]
#获取预测值信息
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  
#[batch_size * num_queries, num_classes]  torch.Size([40, 4]) 4为类别数目
out_bbox = outputs["pred_boxes"].flatten(0, 1)  
# [batch_size * num_queries, 4]  torch.Size([40, 4]) 4为xywh数据

获取真实框的类别与box信息

tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])

计算Focal_loss:pos_cost_class与neg_cost_class皆为:torch.Size([40, 4]),得到cost_class为:torch.Size([40, 5]),cost_class为每个query与target的损失。

alpha = self.focal_alpha  #0.25
gamma = 2.0   
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

计算L1距离得到cost_bbox为:torch.Size([40, 5])

cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

计算giou,得到cost-giou为:torch.Size([40, 5])

cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

构成cost矩阵

C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()

中间步骤C:

在这里插入图片描述

最终形成的C:torch.Size([2, 20, 5])

在这里插入图片描述

获取每个batch中对应的标签个数

sizes = [len(v["boxes"]) for v in targets]

使用匈牙利匹配算法进行计算,得出匹配的标签与预测框。

indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

得出的indices是list形式,内部每个元素为tuple,其内为array对应标签id与预测框id。

在这里插入图片描述

将indices转换为tensor向量形式。

return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

返回的indices如下:

在这里插入图片描述

CDN构造过程

CDN是DINO的创新点之一,其是如何构造的呢?举个例子:
设置batch_size=2,第一张图片有一个标注框,第二张图片有4个标注框
开始设置参数dn_number=100,即添加噪声的query有100个,同时要设置对照组,也是100
dn_number=200,注意这里是设置dn_query的个数
随后判断设置多少个对照组,根据每个batch中最大的tgt数目设置dn_group。

		known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
		#[tensor([1], device='cuda:0'), tensor([1, 1, 1, 1], device='cuda:0')]
        batch_size = len(known)#2
        known_num = [sum(k) for k in known]
        #[tensor(1, device='cuda:0'), tensor(4, device='cuda:0')]
        if int(max(known_num)) == 0:
            dn_number = 1
        else:
            if dn_number >= 100:
                dn_number = dn_number // (int(max(known_num) * 2))
                #确定dn_number=25

什么意思呢,就是说总共我每个batch中设置25组即可。
然后总共正样本数为(1+4)x25=125,同理负样本数也是如此,两者加起来总共有250个
随后对标签进行加噪。
分别得到编码后的标签类别与box:(input_label_embed等都需经过embed编码)

input_label_embed:torch.Size([250, 256])
input_bbox_embed:torch.Size([250, 4])

由于我们设置了dn_query数目固定为200,生成dn_query:
初始时全为0,

padding_label = torch.zeros(pad_size, hidden_dim).cuda()#torch.Size([200, 256])
padding_bbox = torch.zeros(pad_size, 4).cuda()#torch.Size([200, 4])
随后复制batch维度:
input_query_label = padding_label.repeat(batch_size, 1, 1)#torch.Size([2, 200, 256])
input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)#torch.Size([2, 200, 4])

可以看到,此时其全部为0,那么如何将我们加噪后的query放进去呢?

if len(known_num):
   map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
   map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
if len(known_bid):
   input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
   input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

第一个判断获取了标识index,第二个判断结合batch_id与indice来进行填充:

举个例子:

input_query_box[0,0]填充input_bbox_embed[0]
input_query_box[1,0]填充input_bbox_embed[1]
input_query_box[1,1]填充input_bbox_embed[2]
input_query_box[1,2]填充input_bbox_embed[3]
input_query_box[1,3]填充input_bbox_embed[4]
input_query_box[0,4]填充input_bbox_embed[5]
input_query_box[1,4]填充input_bbox_embed[6],以此类推

至此便构造出dn_query了,值得一提的是只有最大tgt数目的图像中的query是全部有非零值的,在本次例子中,第一个batch中2x100个query中只有2x25个非零值,而在第二个batch中,全部都是被填充的。

known_bid值如下所示:五个一组

在这里插入图片描述

indices值如下所示,五个一组,配合batch_id可以将加噪后的值填入到query中,共有250个,但其值最大到199,刚好与200对应。

在这里插入图片描述

计算Label Loss

首先看传入的label_loss的参数:

def loss_labels(self, outputs, targets, indices, num_boxes, log=True):

outputs为预测结果:labels:torch.Size([2, 200, 4]) box:torch.Size([2, 200, 4])
targets为真实值
indices为匈牙利匹配结果:

在这里插入图片描述
num_boxes为box的个数,此时为125,需要注意的是在第一次跳入loss_labels时,实际上计算的是DN的损失。

使用dn_query计算loss不易查看(有200个),我们使用匈牙利匹配的结果来查看:
target共有5个,其中第一个batch有一个,第二个batch有4个。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/30366.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Android自动化测试】Ui Automator技术(以对QQ软件自动发说说为例)

文章目录 一、引言二、了解(Android官方文档)1、UiDevice 类2、UI Automator API3、UI Automator 查看器 三、使用1、依赖2、代码 一、引言 描述:UI Automator 是一个界面测试框架,适用于整个系统上以及多个已安装应用间的跨应用…

react---react router 5 基本使用

目录 1.路由介绍 2.路由使用 3.路由组件和一般组件 4.Switch 单一匹配 5.解决二级路由样式丢失的问题 6.路由精准匹配和模糊匹配 7.Redirect路由重定向 1.路由介绍 路由是根据不同的 URL 地址展示不同的内容或页面,在 SPA 应用中,大部分页面结果…

理解Web3公链共识算法的原理与机制

Web3时代带来了去中心化、透明和安全的数字经济发展,而公链的共识算法是实现这一目标的关键。共识算法确保了公链网络中的节点对交易和状态的一致性达成共识,同时防止了恶意行为和双重支付等问题。本文将深入探讨Web3公链共识算法的核心原理与机制。 1.共…

【Uniapp】小程序携带Token请求接口+无感知登录方案2.0

本次改进原文《【Uniapp】小程序携带Token请求接口无感知登录方案》,在实际使用过程中我发现以下bug: 若token恰好在用户访问接口时到期,就会直接查询为空,不反映token过期问题(例如:弹窗显示订单查询记录…

【数据库数据恢复】SQL Server数据表结构损坏的数据恢复案例

数据库故障&分析: SQL server数据库数据无法读取。 经过初检,发现SQL server数据库文件无法被读取的原因是因为底层File Record被截断为0,无法找到文件开头,数据表结构损坏。镜像文件的前面几十M空间和中间一部分空间被覆盖掉…

饶派杯XCTF车联网安全挑战赛Reverse GotYourKey

文章目录 一.程序逻辑分析二.线程2的operate方法解析三.找出真flag 一.程序逻辑分析 onCreate方法中判断SDK版本是否>27 然后创建两个线程 第一个线程是接受输入的字符串并发送出去 第二个线程用于接受数据 线程1,就是将字符串转为字节数组发送出去 线程2,作为服务端接受…

springboot动态加载json文件

resources下面的配置文件,application文件修改启动会实时加载新的内容 其他的文件属于静态文件,打包后会把文件打入jar里面,修改静态文件启动不会加载新的内容 Resource areacode nre FileSystemResource("config" File.separa…

STM32——07-STM32定时器Timer

定时器介绍 软件定时 缺点:不精确、占用 CPU 资源 void Delay500ms () //11.0592MHz { unsigned char i , j , k ; _nop_ (); i 4 ; j 129 ; k 119 ; do { do { while ( -- k ); } while ( -- j ); } while ( -- i ); } 定时器工…

Springboot--关于自定义stater的yml无法提示

1.前言 在以前在搭建架构的时候就碰到了类似的情况,在使用EnableConfigurationProperties注解的时候,不管怎样,在项目中引入了该starter的情况下依然不发自动的提示properties里面的属性。 Data ConfigurationProperties(prefix "pro…

vite vs babel+webpack | 创建一个简单的vite项目打包运行

有babel、webpack这些优秀的框架,为什么使用vite? 因为vite编译快,启动快,使用简单,还自带一个热更新重启的服务器,vite能够自动的帮我打包所用到的依赖,有些依赖只有用到才会导入,不用到不会…

开放式耳机和封闭式耳机的区别?开放式耳机到底有哪些优缺点?

开放式耳机从字面意思可以理解为:开放耳朵,不需要入耳就可以听见声音的耳机,所以它和封闭式耳机的最大区别就是不入耳。这种耳机最大的优点就是不压迫不封闭耳道,而且在听耳机音的同时能够及时注意到周围环境的声音,从…

【图神经网络】5分钟快速了解Open Graph Benchmark

10分钟快速了解Open Graph Benchmark Open Graph Benchmark (OGB)安装OGB简单使用节点分类任务数据集链路预测任务数据集图属性预测任务数据集Large-Scale Graph ML Datasets 内容来源 Open Graph Benchmark (OGB) Open Graph Benchmark(OGB)是用于图机…

从一个线上 Android Bug 回看 Fragment 的基础知识

作者:Kotlin上海用户组 公司的项目在最近遇到了一个与 Fragment 有关的线上 crash,导致这个问题的根本原因比较复杂,导致修复方案的可选项非常有限,不过这个问题的背景、crash 点,以及修复过程都非常有趣,值…

【RabbitMQ教程】第四章 —— RabbitMQ - 交换机

💧 【 R a b b i t M Q 教 程 】 第 四 章 — — R a b b i t M Q − 交 换 机 \color{#FF1493}{【RabbitMQ教程】第四章 —— RabbitMQ - 交换机} 【RabbitMQ教程】第四章——RabbitMQ−交换机💧 🌷 仰望天空,妳我亦是…

共创开源生态 | 小米肖翔荣获“2023中国开源优秀人物”奖

6月15-16日,以“开源创新 数字化转型 智能化重构”为主题的“第十八届开源中国・开源世界高峰论坛”在北京成功召开。小米工程师肖翔凭借其在 Apache 基金会的开源贡献及在操作系统领域内的技术突破,荣获“2023中国开源优秀人物”奖。 Xiaomi …

使用VitePress创建个人网站并部署到GitHub

网站在线预览 参考文档: VitePress 创建 GitHub 远程仓库 克隆远程仓库到本地 git clone gitgithub.com:themusecatcher/front-end-notes.git进入 front-end-notes/ 目录,添加 README.md 并建立分支跟踪 echo "# front-end-notes" >>…

nand flash 介绍

flash名称由来 Flash的擦除操作是以block块为单位的,与此相对应的是其他很多存储设备,是以bit位为最小读取/写入的单位,Flash是一次性地擦除整个块:在发送一个擦除命令后,一次性地将一个block,常见的块的大…

FAQ页面在SaaS产品中的应用

随着云计算和软件即服务(SaaS)的快速发展,越来越多的企业选择将业务迁移到云端,以更好地管理和运营他们的业务。在这种背景下,SaaS产品的出现成为了企业管理和运营的新趋势。SaaS产品通过云端的方式,为企业…

Godot 4 源码分析 - 命令行参数

粗看Godot 4的源码&#xff0c;稍微调试一下&#xff0c;发现一大堆的命令行参数。在widechar_main中 Error err Main::setup(argv_utf8[0], argc - 1, &argv_utf8[1]); Main::setup中&#xff0c;各命令行参数加入到List<Stirng> args中&#xff0c;并通过OS::get…

腾讯云服务器地域有什么区别?怎么选比较好

腾讯云服务器地域有什么区别&#xff1f;云服务器地域怎么选择&#xff1f;地域是指云服务器所在机房的地理位置&#xff0c;用户距离地域越近网络延迟越低&#xff0c;速度越快&#xff0c;所以地域就近选择即可。广州上海北京等地域网站域名需要备案&#xff0c;中国香港或其…