CenterPoint Paper and Code Walkthrough

 

Table of Contents

1. Introduction

2. Paper Structure

3. Code


Paper: https://arxiv.org/pdf/2006.11275.pdf

Code: tianweiy/CenterPoint (github.com)

1. Introduction

CenterPoint is an anchor-free method: it directly predicts object center points and then regresses the box dimensions (w, h, l), which removes the anchor-to-GT matching step (traditional anchor-based methods assign targets by computing the IoU between GTs and anchors). Point-based predictions also make downstream tasks such as tracking easier. The experiments at the end of the paper show that the method learns object heading (rotation) somewhat better: since each box starts from a single point, the model is forced to learn more rotation information. Conversely, anchor-based methods converge more easily because the anchors provide a strong prior.

2. Paper Structure

 

The overall network architecture is very similar to PointPillars; the main change is that the head is anchor-free, so the head is what we focus on here.

In the earlier stages, the point cloud is processed by the VFE, scattered onto the BEV plane, and passed through an FPN-style neck to obtain a feature map of shape [B, C, H, W]. A shared conv then adjusts the channel count, and the features go through five heads (each is simply a stack of convolutions plus a final conv that reduces the channels to the required number), producing reg [B, 2, H, W], height [B, 1, H, W], dim [B, 3, H, W], rot [B, 2, H, W], and hm [B, 8, H, W]. The predicted reg is a sub-pixel offset within a cell, mainly to compensate for the quantization error introduced when the center is rounded to an integer grid position.
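As a minimal sketch (illustrative shapes only, assuming 64 shared channels and two convs per head, which matches the SepHead code later in this post), one regression head looks like this:

import torch
from torch import nn

shared = torch.randn(2, 64, 200, 380)            # [B, C, H, W] after the shared conv
reg_head = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 2, kernel_size=3, padding=1),  # final conv sets the output channels (2 for reg)
)
print(reg_head(shared).shape)                    # torch.Size([2, 2, 200, 380])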

Inference: take the exponential of dim, recover the yaw angle from the sine and cosine in rot, and add reg to the coordinates produced by meshgrid to get absolute coordinates on the feature map. These are concatenated into boxes of shape [B, H*W, 7], sigmoid is applied to hm, and everything goes to post-processing. There, a max is taken over the channel dimension of the heatmap to obtain the score and label of each cell, a per-class score threshold masks out which predictions survive, and NMS then removes redundant boxes. We only discuss the one-stage version here; the paper uses a two-stage version with an additional box-refinement stage. Note: CenterPoint does use NMS.
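A rough decoding sketch for a single BEV cell (all numbers below are made up for illustration; the real values come from the config and the head outputs):

import torch

reg = torch.tensor([0.3, 0.7])                 # sub-cell x/y offset
height = torch.tensor([1.2])                   # z of the box center
dim = torch.tensor([0.5, 0.9, 0.4])            # predicted log dimensions
rot = torch.tensor([0.6, 0.8])                 # (sin, cos) of yaw
xs, ys = 100.0, 50.0                           # meshgrid coordinates of this cell
out_size_factor, voxel_size, pc_range = 2, (0.1, 0.1), (-75.2, -75.2)  # assumed config values

x = (xs + reg[0]) * out_size_factor * voxel_size[0] + pc_range[0]
y = (ys + reg[1]) * out_size_factor * voxel_size[1] + pc_range[1]
l, w, h = torch.exp(dim)                       # undo the log applied to the GT dims
yaw = torch.atan2(rot[0], rot[1])              # recover the angle from sin/cos
box = torch.stack([x, y, height[0], l, w, h, yaw])   # one row of the [B, H*W, 7] boxes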

Training: we first need the GT heatmaps and boxes, so the targets are zero-initialized: hm [B, 8, H, W], anno_box [B, 500, 8], ind [B, 500], mask [B, 500], cat [B, 500]. Since the number of GTs differs from sample to sample, the slots are capped at a maximum of 500 and mask marks which slots actually hold a GT. We then loop over the GTs and draw the heatmap of the corresponding class for each one. The Gaussian radius is determined from the box's footprint so that a minimum IoU overlap is preserved (three cases: inscribed, circumscribed, and crossing; see 说点Cornernet/Centernet代码里面GT heatmap里面如何应用高斯散射核 - 知乎 (zhihu.com) for the derivation), and the author clamps the radius to a minimum value. Next we find which pillar the center falls into, round to an integer, and take the difference as the sub-cell offset. The box dimensions w, h, l are log-transformed and the yaw is encoded as sin/cos to form anno_box; ind is the flat index of the object center within H*W, and cat is the object class. This gives the training example. The Gaussian itself weights each cell by exp(-d^2 / (2*sigma^2)), so cells closer to the center are closer to 1.
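A small worked sketch of the target construction, using the gaussian_radius and draw_umich_gaussian functions from the code below (the box size and center here are hypothetical):

import numpy as np

l_feat, w_feat = 10.0, 6.0                        # GT box footprint measured in feature-map cells
radius = max(2, int(gaussian_radius((l_feat, w_feat), min_overlap=0.1)))  # clamp to the minimum radius

hm_cls = np.zeros((200, 380), dtype=np.float32)   # one class channel of the heatmap
ct = np.array([120.4, 80.7], dtype=np.float32)    # float center (x, y) on the feature map
draw_umich_gaussian(hm_cls, ct, radius)           # peak of 1.0 at the integer center cell
offset = ct - ct.astype(np.int32)                 # the sub-cell offset stored in anno_box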

At this point the GT targets are: hm [B, 8, H, W], anno_box [B, 500, 8], ind [B, 500], mask [B, 500], cat [B, 500].

The model predicts: reg [B, 2, H, W], height [B, 1, H, W], dim [B, 3, H, W], rot [B, 2, H, W], hm [B, 8, H, W].

Sigmoid is applied to the predicted hm. The regression outputs are concatenated into pred_box of shape [B, 8, H, W], reshaped to [B, H*W, 8], and gathered with ind into [B, 500, 8], and the L1 loss is computed against anno_box. The heatmap itself is trained with a CornerNet-style focal loss (FastFocalLoss, or the FocalLossCenterNet variant used in the code below). A toy example of the gather step follows.
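To make the gather step concrete, here is a tiny example of _transpose_and_gather_feat (defined in the code below) with made-up numbers:

import torch

feat = torch.arange(12.).view(1, 3, 2, 2)    # [B=1, C=3, H=2, W=2]
ind = torch.tensor([[3]])                    # flat index 3 -> cell (y=1, x=1)
out = _transpose_and_gather_feat(feat, ind)  # tensor([[[ 3.,  7., 11.]]]), shape [1, 1, 3]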

3. Code

import copy
import logging
import sys
from collections import OrderedDict

import torch
import numpy as np
import torch.nn.functional as F
from torch import nn

from ...ops.iou3d_nms import iou3d_nms_cuda
from ..model_utils import model_nms_utils


class Sequential(torch.nn.Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, given is a small example::

        # Example of using Sequential
        model = Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

        # Example of using Sequential with kwargs(python 3.6+)
        model = Sequential(
                  conv1=nn.Conv2d(1,20,5),
                  relu1=nn.ReLU(),
                  conv2=nn.Conv2d(20,64,5),
                  relu2=nn.ReLU()
                )
    """

    def __init__(self, *args, **kwargs):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        # i = 0
        for module in self._modules.values():
            # print(i)
            input = module(input)
            # i += 1
        return input




def rotate_nms_pcdet(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """
    :param boxes: (N, 7) [x, y, z, l, w, h, theta]
    :param scores: (N)
    :param thresh:
    :return:
    """
    # transform back to pcdet's coordinate (swap l/w and convert the yaw convention)
    boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]]
    boxes[:, -1] = -boxes[:, -1] - np.pi / 2

    order = scores.sort(0, descending=True)[1]  # sort the N boxes by score, descending
    if pre_maxsize is not None:  # keep at most pre_maxsize boxes before NMS
        order = order[:pre_maxsize]

    boxes = boxes[order].contiguous()

    keep = torch.LongTensor(boxes.size(0))

    if len(boxes) == 0:
        num_out =0
    else:
        num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh)

    selected = order[keep[:num_out].cuda()].contiguous()

    if post_max_size is not None:
        selected = selected[:post_max_size]

    return selected 


def kaiming_init(
    module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
):
    assert distribution in ["uniform", "normal"]
    if distribution == "uniform":
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def gaussian_radius(det_size, min_overlap=0.5):
    """
    compute gaussian radius by min_overlap, you can get principle in <<CenterNet :Objects as Points>> paper
    """
    height, width = det_size  #得到高宽

    a1  = 1
    b1  = (height + width)
    c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1  = (b1 + sq1) / 2

    a2  = 4
    b2  = 2 * (height + width)
    c2  = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2  = (b2 + sq2) / 2

    a3  = 4 * min_overlap
    b3  = -2 * min_overlap * (height + width)
    c3  = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3  = (b3 + sq3) / 2
    return min(r1, r2, r3)
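
# Usage sketch (illustrative numbers, mirroring how AssignLabel below calls it):
#   r = gaussian_radius((10.0, 6.0), min_overlap=0.1)   # box footprint of roughly 10 x 6 cells
#   r = max(2, int(r))                                   # clamp to the configured minimum radius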

def gaussian2D(shape, sigma=1):
    """
    compute gaussian
    """
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1, -n:n+1]  # e.g. y: [7, 1], x: [1, 7] for a 7x7 kernel

    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))  # [7, 7], larger the closer to the origin
    h[h < np.finfo(h.dtype).eps * h.max()] = 0  # zero out values below machine precision
    return h


def draw_umich_gaussian(heatmap, center, radius, k=1):
    """
    draw gaussian in heatmap
    """
    diameter = 2 * radius + 1  # kernel side length
    # compute gaussian kernel values
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)  # e.g. a 7x7 matrix for radius 3

    x, y = int(center[0]), int(center[1])  # integer center coordinates

    height, width = heatmap.shape[0:2]

    # clip the kernel if the center is closer to the heatmap border than the radius
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    # the heatmap region to update
    masked_heatmap  = heatmap[y - top:y + bottom, x - left:x + right]
    # the matching region of the gaussian kernel
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]

    # the shape check is only a safeguard
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)  # keep the element-wise maximum
    return heatmap
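
# Usage sketch (illustrative): draw one object onto the heatmap channel of its class.
#   hm = np.zeros((8, 200, 380), dtype=np.float32)
#   draw_umich_gaussian(hm[cls_id], ct, radius)   # ct is the float (x, y) center; the peak value is 1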

def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2)  # 8
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  # ind: [B,500] -> [B,500,1] -> [B,500,8]; flat indices of the objects on the feature map
    feat = feat.gather(1, ind)  # gather along dim 1 (the H*W dimension) using ind
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()  # [B,200,380,8]
    feat = feat.view(feat.size(0), -1, feat.size(3)) # [B,H*W,8]
    feat = _gather_feat(feat, ind)
    return feat

def _circle_nms(boxes, min_radius, post_max_size=83):
    """
    NMS according to center distance, no use now
    """
    keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]

    keep = torch.from_numpy(keep).long().to(boxes.device)

    return keep 


class RegLoss(nn.Module):
  '''Regression loss for an output tensor
    Arguments:
      output (batch x dim x h x w)
      mask (batch x max_objects)
      ind (batch x max_objects)
      target (batch x max_objects x dim)
  '''
  def __init__(self):
    super(RegLoss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    # output: [B,8,200,380] -> pred: [B,500,8]
    # gather predictions by ind; the mask is needed because the number of boxes differs per sample
    pred = _transpose_and_gather_feat(output, ind)
    mask = mask.float().unsqueeze(2)

    # L1 loss: both tensors are [B,500,8]; multiply by the mask, then sum over the B and 500 dims to get an 8-dim loss
    loss = F.l1_loss(pred * mask, target * mask, reduction='none')
    loss = loss / (mask.sum() + 1e-4)
    loss = loss.transpose(2, 0).sum(dim=2).sum(dim=1)
    return loss

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  '''
  def __init__(self):
    super(FastFocalLoss, self).__init__()

  def forward(self, out, target, ind, mask, cat):
    '''
    Arguments:
      out, target: B x C x H x W
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    mask = mask.float()
    gt = torch.pow(1 - target, 4)
    # compute the negative loss over the whole heatmap
    neg_loss = torch.log(1 - out) * torch.pow(out, 2) * gt
    neg_loss = neg_loss.sum()

    pos_pred_pix = _transpose_and_gather_feat(out, ind)  # B x M x C
    pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2))  # B x M x 1
    num_pos = mask.sum()

    # compute positive loss in heatmap
    pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * \
               mask.unsqueeze(2)
    pos_loss = pos_loss.sum()
    if num_pos == 0:
      return - neg_loss
    return - (pos_loss + neg_loss) / num_pos



def neg_loss_cornernet(pred, gt, mask=None):
    """
    Refer to https://github.com/tianweiy/CenterPoint.
    Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory
    Args:
        pred: (B x 8 x h x w)
        gt: (B x 8 x h x w)
        mask: (B x h x w)
    Returns:
    """
    pos_inds = gt.eq(1).float()  # 1 only at GT object centers
    neg_inds = gt.lt(1).float()  # 1 everywhere that is not an object center

    neg_weights = torch.pow(1 - gt, 4)  # [B,8,H,W]; down-weight negatives near a center

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds  # negatives near a center contribute very little loss

    if mask is not None:
        mask = mask[:, None, :, :].float()
        pos_loss = pos_loss * mask
        neg_loss = neg_loss * mask
        num_pos = (pos_inds.float() * mask).sum()
    else:
        num_pos = pos_inds.float().sum()

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos  # normalize the summed loss by the number of positives
    return loss
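
# Usage sketch (illustrative): pred and gt are [B, num_class, H, W] heatmaps; pred should already
# have passed through the clamped sigmoid (see CenterHead._sigmoid below):
#   loss = neg_loss_cornernet(pred_hm, gt_hm)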


class FocalLossCenterNet(nn.Module):
    """
    Refer to https://github.com/tianweiy/CenterPoint
    """
    def __init__(self):
        super(FocalLossCenterNet, self).__init__()
        self.neg_loss = neg_loss_cornernet

    def forward(self, out, target, mask=None):
        return self.neg_loss(out, target, mask=mask)



class AssignLabel(object):
    def __init__(self, **kwargs):
        """Return CenterNet training labels like heatmap, height, offset"""

        self.tasks = kwargs["tasks"] #assigner_cfg.target_assigner.tasks

        assigner_cfg = kwargs["cfg"]

        self.out_size_factor = assigner_cfg.out_size_factor # 2
        self.gaussian_overlap = assigner_cfg.gaussian_overlap # 0.1
        self._max_objs = assigner_cfg.max_objs  # 500
        self._min_radius = assigner_cfg.min_radius # 2
        # tasks
        self.class_names = self.tasks["class_names"]  # list of the eight class names
        self.num_classes = self.tasks["num_class"]  # 8

    def __call__(self, res,  grid_size , voxel_size , pc_range):
        max_objs = self._max_objs   # 500

        feature_map_size = grid_size[:2] // self.out_size_factor  # feature map size (W, H)
        
        draw_gaussian = draw_umich_gaussian
        # each GT box is (x, y, z, l, w, h, yaw, class)
        gt_boxes = res['gt_boxes'].cpu().numpy()  # GT boxes from the data dict, [B, N, 8]
        batch_size = res['batch_size']

        # hm is heatmap
        hms, anno_boxs, inds, masks, cats = [], [], [], [], []

        #jinmu: batch one by one compute now
        for batch_idx in range(batch_size):
            batch_box = gt_boxes[batch_idx, ...]  # [N, 8]
            # N is the max number of objects across the batch; padded rows have class 0,
            # so a non-zero last element marks a real box
            batch_box_mask = batch_box[..., -1] != 0
            if np.all(batch_box_mask == False):
                batch_box_valid_num = 0
            else:  # e.g. batch_box_mask = [1,1,1,1,0,...]; np.where on a 1-D array returns the indices
                batch_box_valid_num = np.where(batch_box_mask)[0].squeeze().max() + 1  # number of valid boxes

            # c, h, w  [8, 200,380]
            hm = np.zeros((len(self.class_names), feature_map_size[1], feature_map_size[0]),
                            dtype=np.float32)
            # [500, 8]
            anno_box = np.zeros((max_objs, 8), dtype=np.float32)
            # [500]
            ind = np.zeros((max_objs), dtype=np.int64)
            mask = np.zeros((max_objs), dtype=np.uint8) # [500]
            cat = np.zeros((max_objs), dtype=np.int64)  # [500]

            # ideally every frame would have the same number of boxes so they could be
            # processed in one shot, but they do not, so a mask marks the valid slots
            num_objs = min(batch_box_valid_num, max_objs)  # number of objects in this frame

            for k in range(num_objs):
                cls_id = batch_box[k][-1] - 1  # class id (classes in gt_boxes start at 1)
                l, w, h = batch_box[k][3], batch_box[k][4], batch_box[k][5]
                # box w, l measured in feature-map cells
                w, l = w / voxel_size[1] / self.out_size_factor, l / voxel_size[0] / self.out_size_factor
                if w > 0 and l > 0:
                    # the gaussian radius is the root of the equation that keeps a minimum overlap
                    # with the GT box (inscribed / circumscribed / crossing cases)
                    radius = gaussian_radius((l, w), min_overlap=self.gaussian_overlap)  # l, w are floats; the overlap hyper-parameter is 0.1
                    radius = max(self._min_radius, int(radius))  # clamp the radius to at least min_radius (2)

                    # center coordinates on the feature map
                    x, y, z = batch_box[k][0], batch_box[k][1], batch_box[k][2]
                    coor_x, coor_y = (x - pc_range[0]) / voxel_size[0] / self.out_size_factor, \
                                        (y - pc_range[1]) / voxel_size[1] / self.out_size_factor

                    ct = np.array([coor_x, coor_y], dtype=np.float32)
                    ct_int = ct.astype(np.int32)  # integer cell the center falls into

                    # throw out not in range objects to avoid out of array area when creating the heatmap
                    # if beyond range, then continue
                    if not (0 <= ct_int[0] < feature_map_size[0] and 0 <= ct_int[1] < feature_map_size[1]):
                        continue 

                    # draw gaussian in heatmap gt
                    draw_gaussian(hm[int(cls_id)], ct, radius)  # draw onto the heatmap of the corresponding class

                    new_idx = k  # index of the k-th object
                    x, y = ct_int[0], ct_int[1]

                    cat[new_idx] = cls_id  # class of this object
                    ind[new_idx] = y * feature_map_size[0] + x  # flat index of the center within H*W
                    mask[new_idx] = 1  # mark this slot as a valid GT
                    rot = batch_box[k][6]
                    # fill regression target, ct - (x,y) is x_offset and y_offset
                    # rot is yaw angle
                    anno_box[new_idx] = np.concatenate(
                        (ct - (x, y), z, np.log(batch_box[k][3:6]),
                        np.sin(rot), np.cos(rot)), axis=None)  # sub-cell x/y offset, z, log(l, w, h), sin/cos of yaw

            hms.append(hm)
            anno_boxs.append(anno_box)
            masks.append(mask)
            inds.append(ind)
            cats.append(cat)

        hms = torch.from_numpy(np.stack(hms)).cuda()  # stack along the batch dimension
        anno_boxs = torch.from_numpy(np.stack(anno_boxs)).cuda()
        inds = torch.from_numpy(np.stack(inds)).cuda()
        cats = torch.from_numpy(np.stack(cats)).cuda()
        masks = torch.from_numpy(np.stack(masks)).cuda()
        # hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500], cat [B,500]
        example = {'hm': hms, 'anno_box': anno_boxs, 'ind': inds, 'mask': masks, 'cat': cats}

        return example


class SepHead(nn.Module):
    """
    this is seqhead that contains actual head like (heatmap) (lxoffset yoffset) (z) (dim) (cos(theta) sin(theta))
    """
    def __init__(
        self,
        in_channels,
        heads,
        head_conv=64,
        final_kernel=1,
        bn=False,
        init_bias=-2.19,
        **kwargs,
    ):
        super(SepHead, self).__init__(**kwargs)

        self.heads = heads  # {'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
        for head in self.heads:  # iterate over head names (dict keys)
            classes, num_conv = self.heads[head]  # (output channels, number of convs) for this head

            fc = Sequential()
            # layers number decided by config
            for i in range(num_conv-1):
                fc.add(nn.Conv2d(in_channels, head_conv,
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))  #
                if bn:
                    fc.add(nn.BatchNorm2d(head_conv))
                fc.add(nn.ReLU())

            # output conv
            fc.add(nn.Conv2d(head_conv, classes,
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))    
            # the hm head gets a fixed bias init; the other heads use Kaiming init
            if 'hm' in head:
                fc[-1].bias.data.fill_(init_bias)
            else:
                for m in fc.modules():
                    if isinstance(m, nn.Conv2d):
                        kaiming_init(m)
            # each head is a small conv stack followed by a final conv that outputs the prediction channels;
            # register it as an attribute so it can be fetched later with getattr(self, head)
            self.__setattr__(head, fc)
        

    def forward(self, x):
        ret_dict = dict()        
        for head in self.heads:
            ret_dict[head] = self.__getattr__(head)(x)
        # ret_dict: reg [B,2,200,380], height [B,1,200,380], dim [B,3,200,380], rot [B,2,200,380], hm [B,8,200,380]
        return ret_dict
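
# Usage sketch (illustrative): build one SepHead and run a dummy feature map through it.
#   heads = {'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
#   head = SepHead(64, heads, bn=True, init_bias=-2.19, final_kernel=3)
#   out = head(torch.randn(1, 64, 200, 380))   # out['reg']: [1, 2, 200, 380], out['hm']: [1, 8, 200, 380]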


class CenterHead(nn.Module):
    def __init__(
        self,
        model_cfg,
        input_channels=[128,],
        num_class=1,
        class_names=None,
        grid_size=[0.32,0.32,0.16],
        point_cloud_range=None,
        predict_boxes_when_training=False,
        logger=None,
        init_bias=-2.19,
        num_hm_conv=2,
    ):
        super(CenterHead, self).__init__()
        assert(len(class_names) == num_class)
        
        tasks = dict(num_class=num_class, class_names=class_names)
        self.label_assigner = AssignLabel(cfg=model_cfg.TARGET_ASSIGNER_CONFIG, tasks=tasks)
        
        self.out_size_factor = model_cfg.TARGET_ASSIGNER_CONFIG.out_size_factor # 2
        self.model_cfg = model_cfg

        self.class_names = [class_names]  # class_names was a flat list; wrap it so it becomes [[...]]
        self.num_classes = [num_class]  # [8]

        self.code_weights = model_cfg.code_weights #[5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0]
        self.weight = model_cfg.weight # 0.25 
        
        self.in_channels = input_channels # 384

        #self.crit = FastFocalLoss()
        self.crit = FocalLossCenterNet()
        self.crit_reg = RegLoss()

        

        common_heads = model_cfg.common_heads #{'reg': [ 2, 2 ],'height': [ 1, 2 ],'dim': [ 3, 2 ],'rot': [ 2, 2 ]}

        self.box_n_dim = 9 if 'vel' in common_heads else 7  # 7
        self.use_direction_classifier = False 

        if not logger:
            logger = logging.getLogger("CenterHead")
        self.logger = logger

        logger.info(
            f"num_classes: {self.num_classes}"
        )

        # a shared convolution 
        share_conv_channel = 64 if "share_conv_channel" not in model_cfg else model_cfg.share_conv_channel # 64
        self.shared_conv = nn.Sequential(
            nn.Conv2d(self.in_channels, share_conv_channel,
            kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(share_conv_channel),
            nn.ReLU(inplace=True)
        )

        self.tasks = nn.ModuleList()
        print("Use HM Bias: ", init_bias)

        for num_cls in self.num_classes:  # [8], so this loop runs exactly once
            heads = copy.deepcopy(common_heads) 
            heads.update(dict(hm=(num_cls, num_hm_conv))) #{'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
            self.tasks.append(
                SepHead(share_conv_channel, heads, bn=True, init_bias=init_bias, final_kernel=3)
            )

        self.frozen_param = model_cfg.FROZON_PARAM
        self.frozen_parameters()

        logger.info("Finish CenterHead Initialization")

    def forward(self, data_dict, *kwargs):

        x = data_dict['spatial_features_2d'] # [B, 384, 200, 380]
        x = self.shared_conv(x)  # reduce channels to 64 first
        ret_dicts = []

        for task in self.tasks:
            ret_dicts.append(task(x))
        # each entry is a dict: reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], hm [B,8,H,W]
        data_dict['centerhead_preds'] = ret_dicts

        return data_dict

    def _sigmoid(self, x):
        y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
        return y

    def loss(self, data_dict, **kwargs):
        # targets built from GT: hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500], cat [B,500]
        example = self.label_assigner(data_dict, kwargs["grid_size"], kwargs["voxel_size"], kwargs["pc_range"])

        # centerhead outputs: reg [B,2,200,380], height [B,1,200,380], dim [B,3,200,380], rot [B,2,200,380], hm [B,8,200,380]
        preds_dicts = data_dict['centerhead_preds']

        assert(len(preds_dicts) == 1)
        # TODO refactor this
        preds_dict = preds_dicts[0]  # the list holds a single task; take its dict

        # apply sigmoid to the heatmap output
        preds_dict['hm'] = self._sigmoid(preds_dict['hm'])  # clamped sigmoid, keeps the focal loss numerically stable
        # hm_loss = self.crit(
        #     preds_dict['hm'], 
        #     example['hm'], 
        #     example['ind'], 
        #     example['mask'], 
        #     example['cat']
        #     )
        
        hm_loss = self.crit(preds_dict['hm'], example['hm'])  # FocalLossCenterNet

        target_box = example['anno_box']
        # not care about vel as not vel now
        if 'vel' in preds_dict:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['vel'], preds_dict['rot']), dim=1)  
        else:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['rot']), dim=1)   

        # Regression loss for offset, height, dimension and rotation; returns an 8-dim loss tensor
        box_loss = self.crit_reg(preds_dict['anno_box'], example['mask'], example['ind'], target_box)
        box_loss = box_loss * box_loss.new_tensor(self.code_weights)  # new_tensor keeps the dtype/device of box_loss

        # anno_box layout is (offset_x, offset_y, z, log l, log w, log h, sin, cos)
        reg_loss = box_loss[:2]
        height_loss = box_loss[2]
        dim_loss = box_loss[3:6]
        rot_loss = box_loss[6:]
        
        loc_loss = box_loss.sum()
        loc_loss *= self.weight

        # total loss
        loss = hm_loss + loc_loss
        #ret = {'loss': loss, 'hm_loss': hm_loss, 'loc_loss':loc_loss, 'loc_loss_elem': box_loss.detach().cpu(), 'num_positive': example['mask'][0].float().sum()}
        # ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss, 
        #         'reg_loss': reg_loss, 'height_loss': height_loss, 
        #         'dim_loss': dim_loss, 'rot_loss': rot_loss}

        ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss}
        
        return ret
    
    def frozen_parameters(self):
        if self.frozen_param:
            for parameter in self.parameters():
                parameter.requires_grad = False

    @torch.no_grad()
    def predict(self, preds_dicts, test_cfg, **kwargs):
        """decode, nms, then return the detection result.
        """

        voxel_size = kwargs["voxel_size"]
        pc_range = kwargs["pc_range"]

        post_center_range = pc_range
        # a dict of head outputs: reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], hm [B,8,H,W]
        preds_dicts = preds_dicts['centerhead_preds']

        if len(post_center_range) > 0:
            post_center_range = torch.tensor(
                post_center_range,
                dtype=preds_dicts[0]['hm'].dtype,
                device=preds_dicts[0]['hm'].device,
            )

        rets = []
        #jinmu now only support one task
        for task_id, preds_dict in enumerate(preds_dicts):
            # convert B C H W to B H W C 
            for key, val in preds_dict.items():
                preds_dict[key] = val.permute(0, 2, 3, 1).contiguous()

            batch_size = preds_dict['hm'].shape[0]
            batch_hm = torch.sigmoid(preds_dict['hm'])

            # exp on the dim output to keep dim > 0
            batch_dim = torch.exp(preds_dict['dim'])  # dim is (l, w, h) predicted in log space

            # cos(theta) and sin(theta)
            batch_rots = preds_dict['rot'][..., 0:1]
            batch_rotc = preds_dict['rot'][..., 1:2]

            # x offset and y offset output
            batch_reg = preds_dict['reg']
            # z output
            batch_hei = preds_dict['height']

            # atan to recover true theta
            batch_rot = torch.atan2(batch_rots, batch_rotc)  # recover yaw from sin and cos

            batch, H, W, num_cls = batch_hm.size()

            # reshape for compute convenient
            batch_reg = batch_reg.reshape(batch, H*W, 2)
            batch_hei = batch_hei.reshape(batch, H*W, 1)

            batch_rot = batch_rot.reshape(batch, H*W, 1)
            batch_dim = batch_dim.reshape(batch, H*W, 3)
            batch_hm = batch_hm.reshape(batch, H*W, num_cls)  # flatten H and W for convenience

            # build the x/y grid coordinates so the lidar-frame x, y can be recovered
            # later with x_offset and y_offset
            ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
            ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()
            xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()

            # x y  + x_offset y_offset to recover continuous x y value
            xs = xs.view(batch, -1, 1) + batch_reg[:, :, 0:1]
            ys = ys.view(batch, -1, 1) + batch_reg[:, :, 1:2]

            xs = xs * self.out_size_factor * voxel_size[0] + pc_range[0]
            ys = ys * self.out_size_factor * voxel_size[1] + pc_range[1]

            # jinmu: ignore this branch since there is no vel output for now
            if 'vel' in preds_dict:
                batch_vel = preds_dict['vel']
                batch_vel = batch_vel.reshape(batch, H*W, 2)
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2)
            else: 
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_rot], dim=2)

            if test_cfg.get('per_class_nms', False):
                pass 
            else:
                rets.append(self.post_processing(batch_box_preds, batch_hm, test_cfg, post_center_range)) 

        assert(len(rets) == 1) # only one task

        return rets[0]

    @torch.no_grad()
    def post_processing(self, batch_box_preds, batch_hm, test_cfg, post_center_range):
        batch_size = len(batch_hm)
        # batch_box_preds [B,H*W,7] batch_hm [B,H*W,8]
        prediction_dicts = []
        for i in range(batch_size):  # process each sample in the batch
            box_preds = batch_box_preds[i]
            hm_preds = batch_hm[i]

            # score and label come from a max over the class dimension of the heatmap
            scores, labels = torch.max(hm_preds, dim=-1)  # max score and its index (the class), both of shape [H*W]

            # score mask keeps predictions above the per-class threshold
            #score_mask = scores > test_cfg.score_threshold 
            score_threshold = torch.tensor(test_cfg.score_threshold)[labels]  # per-cell threshold looked up by predicted class
            score_mask = scores > score_threshold.cuda()  # keep only predictions above their class threshold

            # distance_mask keeps only boxes whose 3d center lies inside post_center_range
            # (not used in the perception post-process code)
            distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) \
                & (box_preds[..., :3] <= post_center_range[3:]).all(1)

            # mask is intersection of two mask
            mask = distance_mask & score_mask 

            # get masked data
            box_preds = box_preds[mask]  # keep only the boxes that pass both masks
            scores = scores[mask]
            labels = labels[mask]

            # get box for nms, each box in [x y z dx dy dz theta] format
            boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]

            # bev rotated box nms
            selected = rotate_nms_pcdet(boxes_for_nms, scores, 
                                thresh=test_cfg.nms.nms_iou_threshold,
                                pre_maxsize=test_cfg.nms.nms_pre_max_size,
                                post_max_size=test_cfg.nms.nms_post_max_size)

            # selected is box mask after nms
            selected_boxes = box_preds[selected]
            selected_scores = scores[selected]
            selected_labels = labels[selected]

            # fill result, selected_boxes: n * 7, selected_scores: n * 1,
            # selected_labels: n * 1
            record_dict = {
                'pred_boxes': selected_boxes,
                'pred_scores': selected_scores,
                'pred_labels': selected_labels + 1
            }

            prediction_dicts.append(record_dict)

        return prediction_dicts 
