前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose
类和 to_multi
包装函数。不过 [1] 没考虑各图用不同 augmentation 的情况,如:
- ColorJitter 只对 image 做,而不对 label 做;
- image 的 resize interpolation 可任选,但 label 只能用
nearest
。
本篇更新写法,支持各图同用、独用 augmentation。
Code
- 对比 [1],主要改变是改写
MultiCompose
类,并将to_multi
吸收入内。 MultiCompose
的用法还是和torchvision.transforms.Compose
几乎一致,不过支持独用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():
"""不用单独的 to_multi 打包了,已并入 MultiCompose"""
raise NotImplementedError
class MultiCompose:
"""扩展 torchvision.transforms.Compose:支持输入多图,
且保证各 augmentation 中所有输入都用同一随机状态(如旋转同一随机角度),
分割任务有用。
"""
# numpy.random.seed range error:
# ValueError: Seed must be between 0 and 2**32 - 1
MIN_SEED = 0 # - 0x8000_0000_0000_0000
MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)
def __init__(self, transforms):
"""输入:一个 list/tuple,
其中每个元素可以是一个 augmentation 对象(transform)/函数,各输入同用;
或一个嵌套的 list/tuple,为每个输入指定独用的 augmentation。
"""
# self.transforms = [to_multi(t) for t in transforms]
no_op = lambda x: x # i.e. identity function
self.transforms = []
for t in transforms:
if isinstance(t, (tuple, list)):
# convert `None` to `no_op` for convenience
self.transforms.append([no_op if _t is None else _t for _t in t])
else:
self.transforms.append(t)
def __call__(self, *images):
for t in self.transforms:
if isinstance(t, (tuple, list)): # 独用
assert len(images) <= len(t) # allow redundant transform
else: # 同用
t = [t] * len(images)
_aug_images = []
_seed = random.randint(self.MIN_SEED, self.MAX_SEED)
for _im, _t in zip(images, t):
seed_everything(_seed)
_aug_images.append(_t(_im))
images = _aug_images
if len(images) == 1:
images = images[0]
return images
Usage & Test
例程沿用 [1],但改一下 augmentation:
train_trans = MultiCompose([
# image 用 bilinear,label 用 nearest
(ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # 独用
transforms.RandomAffine(30, (0.1, 0.1)), # 同用,传一个就行
transforms.RandomHorizontalFlip(), # 同用
# ColorJitter 只对 image 做,label 不做(None)
[transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # 独用
])
- 效果:
References
- pytorch一致数据增强