get_reference_points in BEVFormerEncoder
@staticmethod
def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):
    """Get the reference points used in SCA and TSA.
    Args:
        H, W: spatial shape of bev.
        Z: height of pillar.
        D: sample D points uniformly from each pillar.
        device (obj:`device`): The device where
            reference_points should be.
    Returns:
        Tensor: reference points used in decoder, has \
            shape (bs, num_keys, num_levels, 2).
    """
    # reference points in 3D space, used in spatial cross-attention (SCA)
    if dim == '3d':
        zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
                            device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
        xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                            device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
        ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                            device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
        ref_3d = torch.stack((xs, ys, zs), -1)
        ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
        ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
        return ref_3d
    # reference points on 2D bev plane, used in temporal self-attention (TSA).
    elif dim == '2d':
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=dtype, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=dtype, device=device)
        )
        ref_y = ref_y.reshape(-1)[None] / H
        ref_x = ref_x.reshape(-1)[None] / W
        ref_2d = torch.stack((ref_x, ref_y), -1)
        ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
        return ref_2d
From the code above, when dim == '3d' the reference points are laid out as follows:
- X direction: W samples evenly spaced from 0.5 to W - 0.5.
- Y direction: H samples evenly spaced from 0.5 to H - 0.5.
- Z direction: num_points_in_pillar samples evenly spaced from 0.5 to Z - 0.5.
Each coordinate is then divided by W, H, or Z respectively, so the values become ratios in (0, 1).
num_points_in_pillar defaults to 4, and the config files also set it to 4.
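As a quick sanity check, the '3d' branch can be run standalone. This is just a minimal sketch with toy sizes (bev_h = bev_w = 4, Z = 8, num_points_in_pillar = 4 are example values here, not the real config):

import torch

H, W, Z, D, bs = 4, 4, 8, 4, 1   # toy sizes; the real values come from the config
zs = torch.linspace(0.5, Z - 0.5, D).view(-1, 1, 1).expand(D, H, W) / Z
xs = torch.linspace(0.5, W - 0.5, W).view(1, 1, W).expand(D, H, W) / W
ys = torch.linspace(0.5, H - 0.5, H).view(1, H, 1).expand(D, H, W) / H
ref_3d = torch.stack((xs, ys, zs), -1)                            # (D, H, W, 3)
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)   # (D, H*W, 3)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)                         # (bs, D, H*W, 3)

print(ref_3d.shape)      # torch.Size([1, 4, 16, 3])
print(ref_3d[0, 0, :2])  # first two grid points of the lowest z-slice:
                         # tensor([[0.1250, 0.1250, 0.0625],
                         #         [0.3750, 0.1250, 0.0625]])

Every coordinate is a ratio in (0, 1), which is exactly what point_sampling below relies on when it scales the points back into pc_range.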
point_sampling in BEVFormerEncoder
# This function must use fp32!!!
@force_fp32(apply_to=('reference_points', 'img_metas'))
def point_sampling(self, reference_points, pc_range, img_metas):
    lidar2img = []
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
    reference_points = reference_points.clone()

    # Map the normalized coordinates back into the point-cloud range. This is why
    # get_reference_points divides by H, W, Z: the points are first kept as ratios in [0, 1].
    reference_points[..., 0:1] = reference_points[..., 0:1] * \
        (pc_range[3] - pc_range[0]) + pc_range[0]
    reference_points[..., 1:2] = reference_points[..., 1:2] * \
        (pc_range[4] - pc_range[1]) + pc_range[1]
    reference_points[..., 2:3] = reference_points[..., 2:3] * \
        (pc_range[5] - pc_range[2]) + pc_range[2]

    # Turn (x, y, z) into homogeneous (x, y, z, 1) so it can be multiplied by the 4x4 matrices.
    reference_points = torch.cat(
        (reference_points, torch.ones_like(reference_points[..., :1])), -1)
    # From here on, reference_points can be treated as points in the lidar coordinate frame.
    reference_points = reference_points.permute(1, 0, 2, 3)
    # num_query equals H*W, the number of BEV grid cells; D equals num_points_in_pillar.
    D, B, num_query = reference_points.size()[:3]
    num_cam = lidar2img.size(1)

    # The points are projected onto every camera, so replicate them num_cam times.
    # reference_points becomes (D, B, num_cam, num_query, 4, 1), ready for matmul with the 4x4 matrices.
    reference_points = reference_points.view(
        D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
    # The camera matrices go from (B, num_cam, 4, 4) to (1, B, num_cam, 1, 4, 4)
    # and then to (D, B, num_cam, num_query, 4, 4).
    lidar2img = lidar2img.view(
        1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
    reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
                                        reference_points.to(torch.float32)).squeeze(-1)
    eps = 1e-5

    # Mask out points behind each camera: after projection their third (depth) component is negative.
    bev_mask = (reference_points_cam[..., 2:3] > eps)
    # Perspective division: divide by depth to get pixel coordinates.
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)

    # Convert pixel coordinates into ratios relative to the image size.
    # NOTE: if the cameras have different image sizes, each camera's own size must be used here.
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

    # Then drop points that fall outside the image field of view.
    bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                & (reference_points_cam[..., 1:2] < 1.0)
                & (reference_points_cam[..., 0:1] < 1.0)
                & (reference_points_cam[..., 0:1] > 0.0))
    if digit_version(TORCH_VERSION) >= digit_version('1.8'):
        bev_mask = torch.nan_to_num(bev_mask)
    else:
        bev_mask = bev_mask.new_tensor(
            np.nan_to_num(bev_mask.cpu().numpy()))

    # Reshape from (D, B, num_cam, num_query, 2) to (num_cam, B, num_query, D, 2).
    reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
    bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
    # At this point reference_points_cam holds, for each camera, the pixel coordinates
    # as image-relative ratios, and bev_mask marks which reference points are valid.
    return reference_points_cam, bev_mask
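The whole projection chain (denormalize into pc_range, make homogeneous, matmul with lidar2img, depth check, perspective division, divide by image size) is easy to verify on a single point. The matrix and image size below are made up for illustration; in the real code both come from img_metas:

import torch

eps = 1e-5
img_h, img_w = 900, 1600                      # hypothetical image size (nuScenes-like)
# Hypothetical lidar2img: identity extrinsics plus a pinhole K with
# fx = fy = 1000, cx = 800, cy = 450. The real matrices come from img_metas['lidar2img'].
lidar2img = torch.tensor([[1000.,    0., 800., 0.],
                          [   0., 1000., 450., 0.],
                          [   0.,    0.,   1., 0.],
                          [   0.,    0.,   0., 1.]])

pt = torch.tensor([2.0, 1.0, 10.0])           # one point already scaled into pc_range (metres)
pt_hom = torch.cat([pt, torch.ones(1)])       # (x, y, z, 1)

cam = lidar2img @ pt_hom                      # (u*d, v*d, d, 1)
assert cam[2] > eps                           # otherwise the point is behind the camera
uv = cam[:2] / cam[2].clamp(min=eps)          # perspective division -> pixel coordinates
uv_ratio = uv / torch.tensor([img_w, img_h])  # normalize to [0, 1] like point_sampling does
print(uv)        # tensor([1000., 550.])
print(uv_ratio)  # tensor([0.6250, 0.6111])

Since both ratios lie in (0, 1) and the depth is positive, this point would get bev_mask = True for that camera.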
SpatialCrossAttention
My understanding is that SpatialCrossAttention is essentially standard Deformable Attention. The only difference is that in the original Deformable Attention the reference points are produced by the network,
whereas here the reference points are obtained by projecting virtual grid points onto the images. With fixed camera parameters, these reference points are therefore fixed.
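To make that concrete, here is a deliberately simplified sketch of the idea (single camera, single scale, single head; all class and variable names are made up, and this is not BEVFormer's actual MSDeformableAttention3D): the projected reference points fix where sampling is centered, and the network only learns offsets and attention weights around them.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDeformSample(nn.Module):
    """Toy single-camera, single-scale deformable sampling around fixed reference points."""
    def __init__(self, embed_dims=256, num_points=4):
        super().__init__()
        self.num_points = num_points
        self.offset_proj = nn.Linear(embed_dims, num_points * 2)  # learned sampling offsets
        self.weight_proj = nn.Linear(embed_dims, num_points)      # learned attention weights

    def forward(self, query, value, ref_points):
        # query:      (bs, num_query, C)  BEV queries
        # value:      (bs, C, h, w)       feature map of one camera
        # ref_points: (bs, num_query, 2)  projected reference points in [0, 1], from point_sampling
        bs, nq, _ = query.shape
        offsets = self.offset_proj(query).view(bs, nq, self.num_points, 2) * 0.05
        weights = self.weight_proj(query).softmax(-1)              # (bs, nq, P)
        # sampling location = fixed reference point + learned offset, mapped to [-1, 1] for grid_sample
        loc = (ref_points[:, :, None, :] + offsets) * 2 - 1        # (bs, nq, P, 2)
        sampled = F.grid_sample(value, loc, align_corners=False)   # (bs, C, nq, P)
        return (sampled * weights[:, None]).sum(-1).permute(0, 2, 1)  # (bs, nq, C)

m = TinyDeformSample()
out = m(torch.rand(1, 2500, 256), torch.rand(1, 256, 32, 88), torch.rand(1, 2500, 2))
print(out.shape)  # torch.Size([1, 2500, 256])

Roughly speaking, the real module additionally uses bev_mask to restrict each camera to the queries that actually project onto it, samples from multiple feature levels and heads, and aggregates the per-camera results back into the BEV queries.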
Part of the forward function of the SpatialCrossAttention module is shown below.
Question: how much benefit do these fixed reference points actually bring? The paper does not seem to ablate this, which feels insufficiently justified.
Moreover, such virtual grid points are clearly not ideal: some locations may contain no real points at all, yet they still project onto the images. Using real points should work better,
for example lidar points, but the BEVFormer paper does not compare against adding lidar.