I recently read the segmentation-head designs of MaskFormer and Mask DINO and wanted to try a segmentation modification on RT-DETR, so I chose to modify the RT-DETR implementation in the ultralytics library.
Outline of this post:
1. A walkthrough of the RT-DETR model in the ultralytics library
2. Adding a segmentation head to the ultralytics RT-DETR for instance segmentation
1. A walkthrough of the RT-DETR model in the ultralytics library
The yaml file shows that decoding is implemented by the RTDETRDecoder class, so let's look at its code first:
```python
class RTDETRDecoder(nn.Module):
    export = False  # export mode

    def __init__(self):  # signature and layer construction omitted for brevity
        super().__init__()

    def forward(self, x, batch=None):
        """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
            batch,
            self.nc,
            self.num_queries,
            self.denoising_class_embed.weight,
            self.num_denoising,
            self.label_noise_ratio,
            self.box_noise_scale,
            self.training,
        )

        embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(
            embed,
            refer_bbox,
            feats,
            shapes,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask,
        )
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)
```
First, its inputs and outputs:
Input: the three feature levels produced by the backbone; with a 640 input they are [b,256,80,80], [b,256,40,40] and [b,256,20,20]
Output: the box predictions dec_bboxes and the confidence scores dec_scores
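To make the inference-time output shape concrete, here is a minimal sketch of the final concatenation in forward. It assumes nc=80 and an eval-mode decoder that stacks only its last layer's predictions, so the leading dimension is 1 and can be squeezed away:

```python
import torch

bs, num_queries, nc = 2, 300, 80
# at inference the decoder stacks only the final layer's predictions,
# hence the leading dimension of 1
dec_bboxes = torch.rand(1, bs, num_queries, 4)
dec_scores = torch.randn(1, bs, num_queries, nc)

# (1, bs, 300, 4) -> (bs, 300, 4); scores go through sigmoid -> (bs, 300, nc)
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
print(y.shape)  # torch.Size([2, 300, 84])
```

So each of the 300 queries yields a 4-dim box plus nc per-class confidences in one row.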
Now let's walk through what the forward pass actually does (ignoring the CDN part for now). It is built around four functions: _get_encoder_input reshapes the inputs into the required form, get_cdn_group adds a DN-DETR-style denoising group, _get_decoder_input builds the initial queries and reference boxes, and decoder performs the attention computation. Let's look at each of the four in detail.
(1) _get_encoder_input
This function reshapes the feature inputs into a fixed form. It takes the three feature levels and returns a single merged feature tensor feats, plus a list shapes holding the three level sizes: [[80,80],[40,40],[20,20]].
Concretely, each level is flattened from (b, 256, h, w) to (b, h*w, 256) and the three levels are concatenated along the token dimension.
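A minimal sketch of this flatten-and-concatenate step (assuming all levels already have 256 channels, as they do after the input projection):

```python
import torch

def get_encoder_input_sketch(x):
    """Sketch of _get_encoder_input: flatten each level to (b, h*w, c)
    and concatenate the levels along the token dimension."""
    feats, shapes = [], []
    for feat in x:
        b, c, h, w = feat.shape
        # (b, c, h, w) -> (b, h*w, c)
        feats.append(feat.flatten(2).permute(0, 2, 1))
        shapes.append([h, w])
    # (b, 80*80 + 40*40 + 20*20, c) = (b, 8400, c) for a 640 input
    return torch.cat(feats, 1), shapes

x = [torch.randn(2, 256, s, s) for s in (80, 40, 20)]
feats, shapes = get_encoder_input_sketch(x)
print(feats.shape)  # torch.Size([2, 8400, 256])
print(shapes)       # [[80, 80], [40, 40], [20, 20]]
```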
(2) get_cdn_group (skipped for now)
(3) _get_decoder_input
The inputs here are the merged features feats (b,8400,256) and shapes [[80,80],[40,40],[20,20]] mentioned above. The function starts by generating anchors from shapes.
```python
def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
    """
    feats: [b, 8400, 256]
    shapes: [[80, 80], [40, 40], [20, 20]]
    (the CDN handling of dn_embed/dn_bbox is omitted here)
    """
    bs = feats.shape[0]  # b
    # anchors: (1, 8400, 4), valid_mask: (1, 8400, 1)
    anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
    features = self.enc_output(valid_mask * feats)  # (1,8400,1)*(b,8400,256) -> (b,8400,256)
    enc_outputs_scores = self.enc_score_head(features)  # (b,8400,256) -> (b,8400,nc)

    # Query selection
    # (bs*num_queries,) DINO's Mixed Query Selection strategy: select the top-K encoder
    # features (ranked by best class score) as priors to enhance the decoder queries
    topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
    # (bs*num_queries,)
    batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

    # (bs, 8400, 256) -> (bs, num_queries, 256)
    top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
    # (1, 8400, 4) -> (bs, num_queries, 4)
    top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

    # Dynamic anchors + static content
    # the top-300 features pass through a 3-layer MLP, (N,300,256) -> (N,300,4),
    # and top_k_anchors is added to give refer_bbox
    refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

    enc_bboxes = refer_bbox.sigmoid()
    enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

    embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
    if self.training:
        refer_bbox = refer_bbox.detach()
        if not self.learnt_init_query:
            embeddings = embeddings.detach()
    return embeddings, refer_bbox, enc_bboxes, enc_scores
```
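The flattened topk_ind/batch_ind indexing above can be checked in isolation. This small sketch reproduces the same advanced-indexing trick outside the class:

```python
import torch

bs, num_tokens, nc, num_queries = 2, 8400, 80, 300
features = torch.randn(bs, num_tokens, 256)
enc_outputs_scores = torch.randn(bs, num_tokens, nc)

# top-K tokens per image by best class score, flattened to one long index vector
topk_ind = torch.topk(enc_outputs_scores.max(-1).values, num_queries, dim=1).indices.view(-1)
# matching batch indices: [0]*300 + [1]*300 + ...
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, num_queries).view(-1)

# paired advanced indexing picks features[b, t] for every (b, t) pair
top_k_features = features[batch_ind, topk_ind].view(bs, num_queries, -1)
print(top_k_features.shape)  # torch.Size([2, 300, 256])
```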
Readers of my earlier posts will find this familiar: anchors are generated from each level's [w,h]. Two details are special here. First, the anchor coordinates are normalized. Second, any anchor whose normalized value falls below 0.01 or above 0.99 is treated as invalid, so a valid_mask is maintained to keep only the valid anchors.
```python
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
    """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
    anchors = []
    for i, (h, w) in enumerate(shapes):
        sy = torch.arange(end=h, dtype=dtype, device=device)
        sx = torch.arange(end=w, dtype=dtype, device=device)
        grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

        valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
        grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2) normalized anchor centers
        # wh is 0.05, 0.1 and 0.2 for the three levels respectively
        wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)  # (1, h, w, 2)
        anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

    anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
    # keep only anchors whose values all lie in (0.01, 0.99); the rest are
    # marked invalid and set to 'inf' via masked_fill below
    valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)  # (1, h*w*nl, 1)
    anchors = torch.log(anchors / (1 - anchors))  # inverse sigmoid
    anchors = anchors.masked_fill(~valid_mask, float("inf"))
    return anchors, valid_mask
```
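To see the numbers, here is a standalone replica of the method above (no class state, and the TORCH_1_10 branch dropped), run on two tiny 4x4 and 2x2 levels:

```python
import torch

def generate_anchors_sketch(shapes, grid_size=0.05, eps=1e-2):
    """Standalone replica of _generate_anchors for illustration."""
    anchors = []
    for i, (h, w) in enumerate(shapes):
        sy = torch.arange(h, dtype=torch.float32)
        sx = torch.arange(w, dtype=torch.float32)
        grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij")
        grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)
        # normalized cell centers in (0, 1)
        grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=torch.float32)
        wh = torch.ones_like(grid_xy) * grid_size * (2.0**i)  # doubles per level
        anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))
    anchors = torch.cat(anchors, 1)  # (1, sum(h*w), 4)
    valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)
    anchors = torch.log(anchors / (1 - anchors))  # inverse sigmoid (logit)
    return anchors.masked_fill(~valid_mask, float("inf")), valid_mask

anchors, valid_mask = generate_anchors_sketch([[4, 4], [2, 2]])
print(anchors.shape)  # torch.Size([1, 20, 4])
# a logit-space anchor maps back to its normalized box via sigmoid
print(anchors[0, 0].sigmoid())  # tensor([0.1250, 0.1250, 0.0500, 0.0500])
```

Note that the anchors are returned in logit space, which is why refer_bbox in _get_decoder_input is passed through sigmoid after the box-head offsets are added.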
Next comes DINO's Mixed Query Selection strategy: the top-K encoder features are selected from the flattened features and used as priors to enhance the decoder queries.
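Putting query selection and box refinement together, here is a minimal end-to-end sketch on reduced sizes (1000 tokens, 50 queries); the 3-layer MLP standing in for enc_bbox_head is a hypothetical placeholder, not the library's actual head:

```python
import torch
import torch.nn as nn

bs, num_tokens, nc, num_queries = 2, 1000, 80, 50
features = torch.randn(bs, num_tokens, 256)
anchors = torch.randn(1, num_tokens, 4)  # stand-in for logit-space anchors
enc_score_head = nn.Linear(256, nc)
# hypothetical 3-layer MLP standing in for enc_bbox_head
enc_bbox_head = nn.Sequential(
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 4),
)

scores = enc_score_head(features)  # (bs, 1000, nc)
# Mixed Query Selection: top-K tokens per image by best class score
topk_ind = torch.topk(scores.max(-1).values, num_queries, dim=1).indices  # (bs, 50)
batch_ind = torch.arange(bs).unsqueeze(-1)  # (bs, 1), broadcasts over queries
top_k_features = features[batch_ind, topk_ind]  # (bs, 50, 256)
top_k_anchors = anchors[:, topk_ind.view(-1)].view(bs, num_queries, -1)

# dynamic anchors + static content: offsets predicted from the content
# features are added in logit space, then squashed into [0, 1]
refer_bbox = (enc_bbox_head(top_k_features) + top_k_anchors).sigmoid()
print(refer_bbox.shape)  # torch.Size([2, 50, 4])
```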