I recently ran into a strange problem while working with the VOT dataset, so I am writing it down here.
The original code is:
```python
import torch


def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        # _check_use_shared_memory() is a helper defined elsewhere in the same module (not shown)
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
            # print(len(batch), batch[0].shape)  # batch is a list, so it has no .shape
            # print(out.shape)
        return torch.stack(batch, 1, out=out)
    # ... the remaining branches (numbers, dicts, lists, ...) are omitted in this excerpt
```
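In principle, once this code has run, `out.shape` should agree with the size of `storage`: `out` should be a 1-D tensor with `numel` elements. Under PyTorch 2.1.1, however, the two do not match. Changing the code as follows makes it run correctly. The only change is appending `.view(-1)` to the line that creates `out`: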
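```python
import torch


def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        # _check_use_shared_memory() is a helper defined elsewhere in the same module (not shown)
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            # The fix: flatten `out` so it is a 1-D tensor with `numel` elements
            out = batch[0].new(storage).view(-1)
            # print(len(batch), batch[0].shape)  # batch is a list, so it has no .shape
            # print(out.shape)
        return torch.stack(batch, 1, out=out)
    # ... the remaining branches (numbers, dicts, lists, ...) are omitted in this excerpt
```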
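You may prefer not to depend on the private `storage()._new_shared(...)` call at all. One alternative is to allocate `out` with exactly the shape that `torch.stack(batch, 1)` will produce, then move it into shared memory with the public `Tensor.share_memory_()`. With the shape already correct, `torch.stack` writes straight into `out` and never has to reshape it. This is a minimal sketch, not the original code, and I have not tested it inside the actual training pipeline. The name `collate_stack1_shared` and its `use_shared_memory` parameter are hypothetical: the flag is passed in explicitly instead of calling `_check_use_shared_memory()`.

```python
import torch


def collate_stack1_shared(batch, use_shared_memory=False):
    """Stack equally shaped tensors along dim=1, optionally into shared memory.

    Hypothetical variant of ltr_collate_stack1: the shared-memory decision is
    passed in as `use_shared_memory` instead of calling _check_use_shared_memory().
    """
    elem = batch[0]
    out = None
    if use_shared_memory:
        # torch.stack(batch, 1) inserts the batch axis at position 1,
        # e.g. 8 tensors of shape (3, 4) -> (3, 8, 4).
        shape = list(elem.shape)
        shape.insert(1, len(batch))
        # Allocate `out` with exactly that shape (same dtype/device as elem) and
        # move its storage to shared memory, so torch.stack can write into it
        # without having to resize it.
        out = elem.new_empty(shape).share_memory_()
    return torch.stack(batch, 1, out=out)


frames = [torch.randn(3, 4) for _ in range(8)]
stacked = collate_stack1_shared(frames, use_shared_memory=True)
print(stacked.shape)  # torch.Size([3, 8, 4])
print(stacked.is_shared())  # True
```

This version computes the target shape itself rather than flattening `out` with `.view(-1)`. As a result, `out` already has its final shape before `torch.stack` runs.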