YOLOv8-seg Segmentation Code Explained (Part 2): Train

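Preface

  This article consists mainly of source code with annotations. It walks through how each step, from the model output to the loss computation, is implemented.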

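Workflow overview

I. Selecting valid anchors
  Taking a 640x640 input as an example, the model produces 8400 anchors in total. Each anchor has its own detection output (4+n: 4 box values plus n class scores) and segmentation output (32 mask coefficients), but not every anchor takes part in the loss computation.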
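The figure of 8400 anchors can be checked with a few lines of arithmetic. This is a minimal sketch that assumes the three heads use strides 8, 16 and 32, which matches the 80x80, 40x40 and 20x20 feature maps listed in the raw model output later in this article.

```python
# Assumed strides of the three heads for a 640x640 input.
imgsz = 640
strides = (8, 16, 32)

# Every cell of every feature map contributes exactly one anchor point.
grid_sizes = [imgsz // s for s in strides]
num_anchors = sum(g * g for g in grid_sizes)

print(grid_sizes, num_anchors)  # [80, 40, 20] 8400
```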

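  An anchor is valid, and takes part in the loss computation, when it meets the following conditions:
(1) the anchor point lies inside a gt_bbox;
(2) it ranks in the top 10 for that gt_bbox by a combined score, which fuses the IoU with the predicted class_score;
(3) if it satisfies the first two conditions for several gt_bboxes, it is kept only for the gt_bbox with which it has the highest IoU (see select_highest_overlaps).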

II. Loss computation
  The loss has 4 parts. Denote the cross-entropy loss by $f(x,y)$.

(1) $L_{\text{box}}$

$$l_i=w_i(1-\text{IoU}_i)$$

$$L_{\text{box}}=\sum{l_i} / \sum{w_i}$$

  Here $w_i$ is again the score that fuses the IoU with the classification score. It serves as the weight, and dividing by $\sum w_i$ turns the sum into a weighted mean, which is the final loss. In the code, $\text{IoU}_i$ is computed as CIoU.
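As a small numeric illustration of the weighted mean, here is a sketch in plain Python. The function name and the sample values are made up for illustration; the `max(..., 1.0)` guard mirrors `target_scores_sum = max(target_scores.sum(), 1)` from the class loss section of the code.

```python
def weighted_box_loss(ious, weights):
    """L_box = sum(w_i * (1 - IoU_i)) / sum(w_i), with the denominator floored at 1."""
    numerator = sum(w * (1.0 - iou) for iou, w in zip(ious, weights))
    denominator = max(sum(weights), 1.0)
    return numerator / denominator


# Three positive anchors (illustrative values only).
ious = [0.9, 0.5, 0.7]
weights = [0.8, 0.2, 0.5]
loss = weighted_box_loss(ious, weights)
print(round(loss, 4))  # 0.22
```

The anchor with the largest weight (0.8) dominates the sum, so a well-aligned anchor with a poor box costs more than a weakly aligned one.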

(2) $L_{\text{seg}}$

$$l_i=f(\text{mask}_i, \text{gt\_mask}_i) / \text{area}$$

$$L_{\text{seg}}=\text{mean}(l)$$

  The mask is obtained in the same way as in predict and is then compared with the label using cross-entropy. area is the area of the corresponding gt_bbox (computed from image-normalized coordinates in the code).

(3) $L_{\text{cls}}$

$$L_{\text{cls}}=\text{sum}(f(\text{cls},\text{gt\_cls})) / \sum{w_i}$$

  This is a cross-entropy loss, averaged in a way similar to $L_{\text{box}}$.

(4) $L_{\text{dfl}}$

  $L_{\text{dfl}}$ also drives the bounding box to converge. It goes back to the output Tensor(b,4,8400) of the DFL module, whose values are the distances from the anchor point to the top-left and bottom-right corners of the box. The gt_bbox is converted to the same form and the loss is computed on it.

$$l=f(x, \text{floor}(y))\times(\text{ceil}(y)-y) + f(x, \text{ceil}(y))\times (y - \text{floor}(y))$$

  In the formula above, $x$ is the model output (16 logits) for one of the 4 coordinate values of an anchor, and $y$ is the corresponding gt value. A simple example makes the loss easier to understand. For $y=3.7$:

$$l=f(x,3)\times 0.3+f(x,4)\times 0.7$$

  For each coordinate value the model outputs 16 probabilities that sum to 1; they are weighted by 0~15 and summed to give the final coordinate value. So for $y=3.7$ the ideal output assigns probability 0.3 to bin 3 and 0.7 to bin 4, giving $0.3 \times 3+0.7 \times 4=3.7$. The loss uses the two weights and the two cross-entropy terms to pull the classification result toward 3 and 4 to different degrees.
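The $y=3.7$ example can be reproduced in plain Python. This is a minimal sketch, not the library code: the helper names are hypothetical, the list of 16 logits stands in for one coordinate of one anchor, and the "ideal" logits are built by hand so that the softmax puts 0.3 on bin 3 and 0.7 on bin 4.

```python
import math


def cross_entropy(logits, target_index):
    """Cross-entropy of softmax(logits) against a single integer class."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target_index]


def dfl(logits, y):
    """l = f(x, floor(y)) * (ceil(y) - y) + f(x, ceil(y)) * (y - floor(y))."""
    left = int(y)          # floor(y) for y >= 0
    right = left + 1       # same convention as tr = tl + 1 in _df_loss
    w_left = right - y
    w_right = 1.0 - w_left
    return cross_entropy(logits, left) * w_left + cross_entropy(logits, right) * w_right


def expected_value(logits):
    """Softmax over the bins, then a weighted sum with the bin indices 0..15."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return sum(i * e / total for i, e in enumerate(exps))


# Hand-built logits: P(3) = 0.3, P(4) = 0.7, all other bins close to 0.
ideal = [-50.0] * 16
ideal[3] = math.log(0.3)
ideal[4] = math.log(0.7)
uniform = [0.0] * 16

print(round(expected_value(ideal), 4))
print(dfl(ideal, 3.7) < dfl(uniform, 3.7))
```

The ideal distribution decodes to roughly 3.7 and has a lower loss than the uniform one, which is the behaviour described above.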

Code and details

0. Raw model output

preds: (tuple:3)
	0(feats): (list:3)
		0: (Tensor:(b,64+cls_n,80,80))
		1: (Tensor:(b,64+cls_n,40,40))
		2: (Tensor:(b,64+cls_n,20,20))
	1(pred_masks): (Tensor:(b,32,8400))
	2(proto): (Tensor:(b,32,160,160)) 

cls_n=1

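1. Getting the raw labels

"""
gt_labels: (Tensor:(b,8,1))
gt_bboxes: (Tensor:(b,8,4))
mask_gt: (Tensor:(b,8,1))

The 8 here is the maximum number of objects per image (max_num_obj). Using a uniform count makes the later matrix operations easier.
Images with fewer objects are padded with all-zero coordinates, and mask_gt is a 0/1 matrix that records which entries are real objects.
"""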
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
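The padding described in the docstring can be sketched without tensors. This is an illustrative stand-in for `self.preprocess`, not its actual implementation: `pad_targets` is a hypothetical helper and each row is assumed to be `[cls, x1, y1, x2, y2]`.

```python
def pad_targets(boxes_per_image):
    """Pad every image's rows with zeros up to the largest object count in the batch."""
    max_num_obj = max((len(rows) for rows in boxes_per_image), default=0)
    padded, mask_gt = [], []
    for rows in boxes_per_image:
        fill = [[0.0] * 5 for _ in range(max_num_obj - len(rows))]
        image_rows = [list(r) for r in rows] + fill
        padded.append(image_rows)
        # A row is a real object when its box coordinates sum to more than zero,
        # which is the same test as gt_bboxes.sum(2, keepdim=True).gt_(0).
        mask_gt.append([1 if sum(r[1:]) > 0 else 0 for r in image_rows])
    return padded, mask_gt


batch = [
    [[0, 10.0, 20.0, 110.0, 220.0], [0, 300.0, 300.0, 400.0, 380.0]],  # two objects
    [[0, 50.0, 60.0, 90.0, 120.0]],                                    # one object
]
padded, mask_gt = pad_targets(batch)
print(mask_gt)  # [[1, 1], [1, 0]]
```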

2. Processing and converting the raw model output

"""Merge feats and pred_masks and permute their dimensions"""
pred_scores: (Tensor:(b,8400,1))
pred_distri: (Tensor:(b,8400,64))
pred_masks: (Tensor:(b,8400,32))

"""Convert pred_distri into box predictions pred_bboxes: (Tensor:(b,8400,4))"""
pred_bboxes = self.bbox_decode(anchor_points, pred_distri)

def bbox_decode(self, anchor_points, pred_dist):
    """Decode predicted object bounding box coordinates from anchor points and distribution."""
    if self.use_dfl:
        b, a, c = pred_dist.shape  # batch, anchors, channels
        """
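        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device), i.e. the vector 0~15
        This means each decoded value in pred_dist lies between 0 and 15
        From dist2bbox below, a box can extend at most 15 cells on each side of its anchor: the 20x20 output (15 x 32 = 480 px per side) can cover objects spanning the whole image, while on the 40x40 output a box spans at most 30 cells (480 px)
        The coordinates computed here are still in feature-map units and have not yet been scaled by the stride to the 640x640 coordinate system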
        """
        pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
    return dist2bbox(pred_dist, anchor_points, xywh=False)

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
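To make the two decoding steps concrete, the sketch below decodes a single anchor in plain Python: softmax over 16 bins, expectation over 0..15, then ltrb distances to an xyxy box. The helper names are hypothetical, and the anchor position (10.5, 10.5) assumes anchor points sit at cell centers with a 0.5 offset.

```python
import math


def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def decode_anchor(anchor_xy, logits_ltrb, stride):
    """Decode 4 x 16 logits into an xyxy box in input-image pixels."""
    # Expected bin index per side: softmax over the bins, then dot with 0..15.
    dist = [sum(i * p for i, p in enumerate(softmax(side))) for side in logits_ltrb]
    left, top, right, bottom = dist
    ax, ay = anchor_xy
    box = (ax - left, ay - top, ax + right, ay + bottom)  # feature-map units
    return tuple(c * stride for c in box)                 # scale to pixels


def one_hot_logits(index, size=16):
    """A large logit on one bin, so the softmax is numerically one-hot."""
    return [100.0 if i == index else 0.0 for i in range(size)]


# One anchor on the 20x20 map (stride 32) predicting distances 2, 3, 4, 5.
logits = [one_hot_logits(2), one_hot_logits(3), one_hot_logits(4), one_hot_logits(5)]
box = decode_anchor((10.5, 10.5), logits, stride=32)
print(box)
```

In the article's code the multiplication by the stride happens later, when `pred_bboxes * stride_tensor` is passed to the assigner; it is folded into the helper here only to keep the example short.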

3. Getting the labels and outputs used to compute the loss

_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
	pred_scores.detach().sigmoid(), 
	(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
	anchor_points * stride_tensor, 
	gt_labels, 
	gt_bboxes, 
	mask_gt
)

# forward of self.assigner (TaskAlignedAssigner in ultralytics)
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
	"""
	Args:
	    pd_scores (Tensor): shape(bs, num_total_anchors, num_classes), sigmoid(pred_scores)
	    pd_bboxes (Tensor): shape(bs, num_total_anchors, 4), coordinates scaled to the 640x640 input
	    anc_points (Tensor): shape(num_total_anchors, 2), coordinates scaled to the 640x640 input
	    gt_labels (Tensor): shape(bs, n_max_boxes, 1)
	    gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
	    mask_gt (Tensor): shape(bs, n_max_boxes, 1)
	
	Returns:
	    target_labels (Tensor): shape(bs, num_total_anchors)
	    target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
	    target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
	    fg_mask (Tensor): shape(bs, num_total_anchors)
	    target_gt_idx (Tensor): shape(bs, num_total_anchors)	
	"""

	"""
	mask_pos: (Tensor:(b,8,8400)), 0/1 matrix, 1 for the top-10 scoring anchors of each object
	align_metric: (Tensor:(b,8,8400)), metric fusing IoU and classification score
	overlaps: (Tensor:(b,8,8400)), IoU
	"""
	mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt)

	"""
	mask_pos: (Tensor:(b,8,8400)), result after removing duplicate assignments
	fg_mask: (Tensor:(b,8400)), fg_mask=mask_pos.sum(-2), whether the anchor is matched to a gt
	target_gt_idx: (Tensor:(b,8400)), target_gt_idx=mask_pos.argmax(-2), index of the object matched to the anchor
	"""
	target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
	
	# Assigned target
	"""
	target_labels: (Tensor:(b,8400))
	target_bboxes: (Tensor:(b,8400,4))
	target_scores: (Tensor:(b,8400,cls_n)), one-hot label
	"""
	target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
	
	# Normalize
	align_metric *= mask_pos
	pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
	pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
	norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)  # b, a, 1
	target_scores = target_scores * norm_align_metric
	
	return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
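The Normalize step at the end of forward rescales each object's alignment metrics so that its best anchor ends up with a score equal to the best IoU for that object. A minimal sketch for one gt row, using plain lists and a hypothetical helper name (the rows are assumed to be already multiplied by mask_pos):

```python
def normalize_scores(align_row, overlap_row, eps=1e-9):
    """Rescale one object's alignment metrics: best anchor -> best IoU of that object."""
    max_align = max(align_row)
    max_overlap = max(overlap_row)
    return [a * max_overlap / (max_align + eps) for a in align_row]


# One object, three anchors (illustrative values only).
align_row = [0.0, 0.2, 0.4]
overlap_row = [0.0, 0.7, 0.8]
norm = normalize_scores(align_row, overlap_row)
print([round(v, 3) for v in norm])  # [0.0, 0.4, 0.8]
```

In the tensor code the result is then reduced with `.amax(-2)` over the object axis, so each anchor keeps a single value that scales its one-hot target_scores.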

3.1 get_pos_mask

def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
	"""
	mask_in_gts: (Tensor:(b,8,8400))
	0/1 matrix, 1 if the anchor point lies inside a given gt_bbox
	(mask_in_gts*mask_gt) is referred to here as the candidate anchors
	"""
	mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
	
	"""
	align_metric: (Tensor:(b,8,8400)), metric fusing IoU and classification score
	overlaps: (Tensor:(b,8,8400)), IoU
	The metrics are computed only for candidate anchors; all other positions are 0
	"""
	align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
	
	"""
	mask_pos: (b,8,8400)
	0/1 matrix, 1 for the top-10 scoring candidate anchors
	"""
	mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
	mask_pos = mask_topk * mask_in_gts * mask_gt

	return mask_pos, align_metric, overlaps

(1) Selecting the anchors inside a gt_bbox

def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    n_anchors = xy_centers.shape[0]
    bs, n_boxes, _ = gt_bboxes.shape
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
    return bbox_deltas.amin(3).gt_(eps)
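For a single anchor and a single box, the test above reduces to "the smallest of the four edge distances is positive". A plain-Python sketch with a hypothetical helper name:

```python
def anchor_in_box(anchor_xy, box_xyxy, eps=1e-9):
    """True when the anchor point lies strictly inside the xyxy box."""
    ax, ay = anchor_xy
    x1, y1, x2, y2 = box_xyxy
    # Distances from the anchor to the left, top, right and bottom edges.
    deltas = (ax - x1, ay - y1, x2 - ax, y2 - ay)
    return min(deltas) > eps


print(anchor_in_box((100.0, 100.0), (50.0, 50.0, 150.0, 150.0)))  # True
print(anchor_in_box((50.0, 100.0), (50.0, 50.0, 150.0, 150.0)))   # False, on the edge
print(anchor_in_box((4.0, 4.0), (0.0, 0.0, 0.0, 0.0)))            # False, padded box
```

The last call shows why zero-padded boxes never produce candidates by themselves: no point can be strictly inside a box of zero size.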

(2) Computing the box metrics

def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
	"""Compute alignment metric given predicted and ground truth bounding boxes."""
	na = pd_bboxes.shape[-2]
	mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
	overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
	bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
	
	ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
	ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
	ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
	# Get the scores of each grid for each gt cls
	"""
	bbox_scores: (Tensor:(b,8,8400))
	Record the cls_score of the correct class for each candidate anchor
	"""
	bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w
	
	# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
	pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
	gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
	overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
	
	"""alpha=0.5, beta=6"""
	align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
	return align_metric, overlaps
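With alpha=0.5 and beta=6, the IoU term dominates the combined score. The scalar sketch below (hypothetical helper, illustrative values) shows that a confident class score cannot make up for a poorly overlapping box:

```python
def align_metric(cls_score, iou, alpha=0.5, beta=6.0):
    """Combined score: cls_score ** alpha * iou ** beta."""
    return (cls_score ** alpha) * (iou ** beta)


good_box = align_metric(cls_score=0.81, iou=0.9)   # moderate class score, tight box
poor_box = align_metric(cls_score=0.99, iou=0.5)   # confident class score, loose box
print(good_box > poor_box)  # True
```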

(3) Selecting the top-10 scoring anchors

def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
    """
    Select the top-k candidates based on the given metrics.

    Args:
        metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                          max_num_obj is the maximum number of objects, and h*w represents the
                          total number of anchor points.
        largest (bool): If True, select the largest values; otherwise, select the smallest values.
        topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                            topk is the number of top candidates to consider. If not provided,
                            the top-k values are automatically computed based on the given metrics.

    Returns:
        (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
    """

    # (b, max_num_obj, topk)
    topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
    if topk_mask is None:
        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
    # (b, max_num_obj, topk)
    topk_idxs.masked_fill_(~topk_mask, 0)

    # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
    """count_tensor: (b,8,8400)"""
    count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
    """ones: (b,8,1)"""
    ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
    for k in range(self.topk):
        # Expand topk_idxs for each value of k and add 1 at the specified positions
        count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
    # filter invalid bboxes
    """The invalid boxes removed here are in fact the fake (padded) objects indicated by mask_gt"""
    count_tensor.masked_fill_(count_tensor > 1, 0)
    
    return count_tensor.to(metrics.dtype)
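The effect of the top-k step, and of multiplying the result by mask_in_gts in get_pos_mask, can be seen on one gt row. This is a list-based sketch with a hypothetical helper, using k=5 instead of 10 so the example stays small:

```python
def topk_mask(metrics, k):
    """Return a 0/1 list marking the k largest entries of metrics."""
    order = sorted(range(len(metrics)), key=lambda i: metrics[i], reverse=True)
    mask = [0] * len(metrics)
    for i in order[:k]:
        mask[i] = 1
    return mask


# One object, six anchors. Anchors 0 and 3 are outside the box, so their metric is 0.
metrics = [0.0, 0.31, 0.02, 0.0, 0.47, 0.12]
in_gt = [0, 1, 1, 0, 1, 1]

top = topk_mask(metrics, k=5)
mask_pos = [t * g for t, g in zip(top, in_gt)]
print(top)       # [1, 1, 1, 0, 1, 1]
print(mask_pos)  # [0, 1, 1, 0, 1, 1]
```

When fewer than k anchors have a non-zero metric, top-k is padded with zero-metric anchors (anchor 0 here); multiplying by the in-box mask removes them again.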

3.2 select_highest_overlaps

When an anchor matches multiple objects, only the object with the highest IoU (overlaps) is kept as 1; the others are set to zero.

def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        overlaps (Tensor): shape(b, n_max_boxes, h*w)

    Returns:
        target_gt_idx (Tensor): shape(b, h*w)
        fg_mask (Tensor): shape(b, h*w)
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
    """
    # (b, n_max_boxes, h*w) -> (b, h*w)
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
        fg_mask = mask_pos.sum(-2)
    # Find each grid serve which gt(index)
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
    return target_gt_idx, fg_mask, mask_pos
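The same conflict resolution, written with nested lists instead of tensors, may be easier to follow. This is a sketch with a hypothetical helper name; rows are gt boxes and columns are anchors.

```python
def resolve_conflicts(mask_pos, overlaps):
    """Keep, for every anchor claimed by several objects, only the object with the highest IoU."""
    n_boxes, n_anchors = len(mask_pos), len(mask_pos[0])
    resolved = [row[:] for row in mask_pos]
    for a in range(n_anchors):
        if sum(mask_pos[g][a] for g in range(n_boxes)) > 1:
            best = max(range(n_boxes), key=lambda g: overlaps[g][a])
            for g in range(n_boxes):
                resolved[g][a] = 1 if g == best else 0
    fg_mask = [sum(resolved[g][a] for g in range(n_boxes)) for a in range(n_anchors)]
    # Like argmax(-2): the first index wins on ties, so background anchors map to 0.
    target_gt_idx = [max(range(n_boxes), key=lambda g: resolved[g][a]) for a in range(n_anchors)]
    return target_gt_idx, fg_mask, resolved


# Two objects, three anchors; anchor 1 is claimed by both objects.
mask_pos = [[1, 1, 0],
            [0, 1, 1]]
overlaps = [[0.6, 0.4, 0.0],
            [0.0, 0.7, 0.5]]
target_gt_idx, fg_mask, resolved = resolve_conflicts(mask_pos, overlaps)
print(target_gt_idx, fg_mask, resolved)
```

Anchor 1 goes to the second object because its IoU there (0.7) is higher than with the first (0.4).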

3.3 get_targets

def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
    """
    Compute target labels, target bounding boxes, and target scores for the positive anchor points.

    Args:
        gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                            batch size and max_num_obj is the maximum number of objects.
        gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
        target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
                                anchor points, with shape (b, h*w), where h*w is the total
                                number of anchor points.
        fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
                          (foreground) anchor points.

    Returns:
        (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
            - target_labels (Tensor): Shape (b, h*w), containing the target labels for
                                      positive anchor points.
            - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
                                      for positive anchor points.
            - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
                                      for positive anchor points, where num_classes is the number
                                      of object classes.
    """

    # Assigned target labels, (b, 1)
    batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
    target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
    target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

    # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
    target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]

    # Assigned target scores
    target_labels.clamp_(0)

    # 10x faster than F.one_hot()
    target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
                                dtype=torch.int64,
                                device=target_labels.device)  # (b, h*w, 80)
    target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

    fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
    target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

    return target_labels, target_bboxes, target_scores

4. class loss

target_scores_sum = max(target_scores.sum(), 1)
# cls loss
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

self.bce = nn.BCEWithLogitsLoss(reduction='none')
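For reference, the per-element quantity computed by a binary cross-entropy on logits, and the normalization used above, can be sketched in plain Python. The helper names are hypothetical; the targets are soft values, as target_scores are after scaling by the normalized alignment metric.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cls_loss(logits, targets):
    """Sum of element-wise BCE divided by max(sum(targets), 1)."""
    total = sum(bce_with_logits(x, t) for x, t in zip(logits, targets))
    return total / max(sum(targets), 1.0)


# Two anchors, one class: a positive with soft target 0.5 and a background anchor.
value = cls_loss([0.0, 0.0], [0.5, 0.0])
print(round(value, 4))
```

Note that every anchor contributes to the sum, including background anchors whose target is 0, while the denominator counts only the (soft) positive targets.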

5. bbox loss

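# target_bboxes must stay in input-image coordinates because the mask loss below reuses it, so divide by the stride only once, in the call
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, target_scores, target_scores_sum, fg_mask)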

def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
	"""IoU loss."""
	weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
	iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
	loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
	
	# DFL loss
	if self.use_dfl:
		target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
		loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
		loss_dfl = loss_dfl.sum() / target_scores_sum
	else:
		loss_dfl = torch.tensor(0.0).to(pred_dist.device)
	
	return loss_iou, loss_dfl

@staticmethod
def _df_loss(pred_dist, target):
    """Return sum of left and right DFL losses."""
    # Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    tl = target.long()  # target left
    tr = tl + 1  		# target right
    wl = tr - target  	# weight left
    wr = 1 - wl  		# weight right
    return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
            F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)

6. mask loss

"""Downsample to 160x160"""
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
    masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

for i in range(batch_size):
    if fg_mask[i].sum():
        mask_idx = target_gt_idx[i][fg_mask[i]]
        if self.overlap:
            """gt_mask of each valid anchor (n,160,160)"""
            gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
        else:
            gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
        xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]  # boxes normalized to 0~1
        marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)  # normalized box area (w * h)
        mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)  # boxes in mask pixels
        loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea)  # seg
    else:
        loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss


def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
    """Mask loss for one image."""
    pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,160,160) -> (n,160,160)
    """loss:(n,160,160)"""
    loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
    """Average each anchor's cropped loss over the map, divide by the area of its box, then average over anchors"""
    return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
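The per-anchor part of single_mask_loss can be traced on a tiny example. This is a sketch under stated assumptions, not the library code: nested lists replace tensors, the helper names are hypothetical, 2 prototypes of size 2x2 replace the 32 prototypes of size 160x160, and cropping is modelled as "pixels outside the box contribute 0".

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def single_anchor_mask_loss(coeffs, proto, gt_mask, box_xyxy, area):
    """coeffs: k floats; proto: k maps (h x w); gt_mask: h x w of 0/1; box in mask pixels."""
    h, w = len(gt_mask), len(gt_mask[0])
    x1, y1, x2, y2 = box_xyxy
    total = 0.0
    for r in range(h):
        for c in range(w):
            if x1 <= c < x2 and y1 <= r < y2:  # crop to the box
                # Mask logit = linear combination of the prototypes at this pixel.
                logit = sum(k * p[r][c] for k, p in zip(coeffs, proto))
                total += bce_with_logits(logit, gt_mask[r][c])
    # Mean over the full map, then divide by the normalized box area.
    return total / (h * w) / area


proto = [[[1.0, 0.0], [0.0, 1.0]],
         [[0.0, 1.0], [1.0, 0.0]]]
gt_mask = [[1.0, 0.0],
           [1.0, 0.0]]
# The box covers the left column, i.e. half of the map, so its normalized area is 0.5.
loss = single_anchor_mask_loss([0.0, 0.0], proto, gt_mask, (0, 0, 1, 2), area=0.5)
print(round(loss, 4))
```

Dividing by the area compensates for the mean being taken over the whole map: a small box would otherwise contribute far less than a large one.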

7. Loss fusion

"""
box=7.5, cls=0.5, dfl=1.5
"""
loss[0] *= self.hyp.box  # box gain
loss[1] *= self.hyp.box / batch_size  # seg gain
loss[2] *= self.hyp.cls  # cls gain
loss[3] *= self.hyp.dfl  # dfl gain

return loss.sum() * batch_size, loss.detach()

8. IoU details

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).

    Args:
        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
                               (x1, y1, x2, y2) format. Defaults to True.
        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
    """

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

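Although the docstring says this function computes the IoU between one box box1(1,4) and multiple boxes box2(n,4), it can also compute the IoU between box1(n,4) and box2(n,4) pairwise (with the same n).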


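(1) IoU

$$\text{IoU}=\frac{S_I}{S_1+S_2-S_I}$$

where $S_1$ and $S_2$ are the areas of the two boxes and $S_I$ is the area of their intersection.

(2) GIoU

$$\text{GIoU}=\text{IoU}-\frac{S_C-S_U}{S_C},\quad S_U=S_1+S_2-S_I$$

where $S_C$ is the area of the smallest box enclosing both boxes and $S_U$ is the area of their union, matching `iou - (c_area - union) / c_area` in the code. When the 2 boxes do not intersect, GIoU additionally measures the distance between them: the closer they are, the larger the GIoU.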


(3) DIoU

$$\text{DIoU}=\text{IoU}-\frac{d^2}{c^2}$$

where $d$ is the distance between the centers of the two boxes and $c$ is the diagonal length of the smallest enclosing box.

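[Figure: cases in which GIoU degrades to IoU while DIoU can still tell them apart. Green boxes are targets and red boxes are predictions.]

[Figure: first row GIoU, second row DIoU. Black boxes are anchors, green boxes are targets, and blue and red boxes are predictions.]

GIoU usually enlarges the predicted box until it overlaps the target box, whereas DIoU directly minimizes the distance between the center points and therefore converges faster.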

(4) CIoU

$$\text{CIoU}=\text{DIoU}-\frac{v^2}{1-\text{IoU}+v}$$

$$v=\left(\frac{\arctan(w_2/h_2)}{\pi/2}-\frac{\arctan(w_1/h_1)}{\pi/2}\right)^2$$

When two candidates for box1 have the same area but different aspect ratios, both lie inside box2, and their center distances are the same, DIoU cannot tell them apart, while CIoU can further optimize the aspect ratio of the predicted box.
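The four formulas can be compared side by side on concrete boxes. The function below is a scalar re-statement of them for two xyxy boxes, written as a sketch for experimentation rather than a replacement for bbox_iou; the function name and the sample boxes are made up for illustration.

```python
import math


def box_iou_variants(b1, b2, eps=1e-7):
    """Return (IoU, GIoU, DIoU, CIoU) for two xyxy boxes."""
    w1, h1 = b1[2] - b1[0], b1[3] - b1[1]
    w2, h2 = b2[2] - b2[0], b2[3] - b2[1]
    inter_w = max(min(b1[2], b2[2]) - max(b1[0], b2[0]), 0.0)
    inter_h = max(min(b1[3], b2[3]) - max(b1[1], b2[1]), 0.0)
    inter = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union

    cw = max(b1[2], b2[2]) - min(b1[0], b2[0])  # enclosing box width
    ch = max(b1[3], b2[3]) - min(b1[1], b2[1])  # enclosing box height
    c_area = cw * ch + eps
    giou = iou - (c_area - union) / c_area

    c2 = cw ** 2 + ch ** 2 + eps                # enclosing diagonal squared
    d2 = ((b2[0] + b2[2] - b1[0] - b1[2]) ** 2
          + (b2[1] + b2[3] - b1[1] - b1[3]) ** 2) / 4  # center distance squared
    diou = iou - d2 / c2

    v = (4 / math.pi ** 2) * (math.atan(w2 / (h2 + eps)) - math.atan(w1 / (h1 + eps))) ** 2
    alpha = v / (v - iou + 1 + eps)
    ciou = diou - v * alpha
    return iou, giou, diou, ciou


target = (0.0, 0.0, 2.0, 2.0)
near = box_iou_variants(target, (3.0, 0.0, 5.0, 2.0))  # disjoint, close
far = box_iou_variants(target, (6.0, 0.0, 8.0, 2.0))   # disjoint, further away
same = box_iou_variants(target, target)
print(near)
print(far)
```

Both disjoint boxes have an IoU of 0, yet GIoU and DIoU are higher for the closer one, which is the property that gives a useful training signal when boxes do not overlap.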
