分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的可能不一致,如 RandomRotation 将 image 转了 a 度、却将 label 转了 b 度。
MONAI 的 ArrayDataset 实现了这一功能,思路是每次 transform 前先重置一次 random seed。我对 MONAI 定制 transform 的方法不熟,而 torchvision.transforms 的定制接口比较简单,所以考虑基于 PyTorch 实现。需要改两处:
- 扩展 torchvision.transforms.Compose,使之支持多个输入(image、label);
- 一个 wrapper,扩展 transform,使之支持多个输入。
思路也是重置 random seed,参考 [1-4]。
Code
- `to_multi`:将处理单幅图的 transform 扩展成可处理多幅图;
- `MultiCompose`:扩展 `torchvision.transforms.Compose`,可输入多幅图,内部调用 `to_multi` 扩展传入的 transforms。
import random, os
import numpy as np
import torch
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed. NumPy requires 0 <= seed <= 2**32 - 1.
    """
    random.seed(seed)
    # Only affects Python processes started after this point; hash
    # randomisation of the *current* interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly nondeterministic
    # algorithms per run, which defeats deterministic=True above.
    torch.backends.cudnn.benchmark = False
def to_multi(trfm):
    """Wrap a single-input transform so it processes several inputs in sync.

    Every input is transformed with the same random draws. The states of
    Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG (plus the
    CUDA RNGs when CUDA is already initialised) are captured once. They are
    then restored before each input is processed, so a random transform such
    as ``RandomRotation`` applies identical parameters to an image and its
    mask.

    Unlike re-seeding, this does not touch cuDNN flags or environment
    variables. After the call, the RNGs are left exactly as if only the last
    input had been transformed.

    Input:
        trfm: transformation function/object (custom or from
            torchvision.transforms) that takes one input.
    Output:
        _multi_transform: function. Called with one input, it returns
            ``trfm(x)`` unchanged. Called with several, it returns a tuple
            with one result per input.
    """
    def _capture_rng_states():
        # Only read CUDA RNGs if CUDA is already up: querying them would
        # otherwise initialise CUDA (problematic inside DataLoader workers).
        use_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
        return (
            random.getstate(),
            np.random.get_state(),
            torch.get_rng_state(),
            torch.cuda.get_rng_state_all() if use_cuda else None,
        )

    def _restore_rng_states(states):
        py_state, np_state, torch_state, cuda_states = states
        random.setstate(py_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)

    def _multi_transform(*images):
        """images: any number of inputs accepted by ``trfm``."""
        if len(images) == 1:
            return trfm(images[0])
        states = _capture_rng_states()
        res = []
        for img in images:
            # Every input starts from the same RNG state -> same random params.
            _restore_rng_states(states)
            res.append(trfm(img))
        return tuple(res)
    return _multi_transform
class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple inputs.

    Usage is the same as torchvision.transforms.Compose, except that
    ``__call__`` takes any number of images (e.g. an image and its
    segmentation mask). Each transform is wrapped with ``to_multi``, so random
    transforms use the same random parameters for every input.

    Called with one input, it returns a single object. Called with several,
    it returns a tuple with one entry per input.
    """

    def __init__(self, transforms):
        """
        Args:
            transforms: iterable of plain single-input transforms. They are
                wrapped with ``to_multi`` here, so do not pre-wrap them.
        """
        self.transforms = [to_multi(t) for t in transforms]

    def __call__(self, *images):
        # A wrapped transform returns the bare result for a single input but
        # a tuple for several. Keep a tuple internally so the next
        # ``t(*images)`` never unpacks a tensor along its first dimension.
        single = len(images) == 1
        for t in self.transforms:
            out = t(*images)
            images = (out,) if single else out
        return images[0] if single else images
test
测试一致性,用到预处理过的 VerSe’19 数据集、一些工具函数和一个定制 transform:
- VerSe’19 数据集及其预处理见 iTomxy/data/verse;
- `digit_sort_key`:数据文件排序用;
- `get_palette`、`color_seg`、`blend_seg`:可视化用;
- `MyDataset`:重点看其中 `__getitem__` 的 transform 用法,即同时传入 image 和 label;
- `ResizeZoomPad`:一个定制的 transform。
import math
import os, os.path as osp, random
import re
from glob import glob

import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
def digit_sort_key(s, num_pattern=re.compile('([0-9]+)')):
    """Sort key made of the integer values of the digit runs in ``s``.

    Non-digit text is ignored, so an image file and its label file sort to
    the same position as long as their digit runs agree, whatever their
    prefixes/suffixes are.
    """
    key = []
    for chunk in num_pattern.split(s):
        if chunk.isdigit():
            key.append(int(chunk))
    return key
def get_palette(n_classes, pil_format=True):
    """Build a color palette with one RGB color per class id, for visualization.

    The bits of each class id are spread over R, G, B (bit 0 -> R, bit 1 -> G,
    bit 2 -> B, then the next 3 bits one position lower, and so on), starting
    from the most significant bit of each channel.

    Args:
        n_classes: number of colors to generate.
        pil_format: if True, return the flat [r0, g0, b0, r1, ...] list that
            ``PIL.Image.putpalette`` accepts; otherwise a list of (r, g, b).
    """
    flat = []
    for cls_id in range(n_classes):
        r = g = b = 0
        code, shift = cls_id, 7
        while code:
            r |= ((code >> 0) & 1) << shift
            g |= ((code >> 1) & 1) << shift
            b |= ((code >> 2) & 1) << shift
            shift -= 1
            code >>= 3
        flat.extend((r, g, b))
    if pil_format:
        return flat
    return [tuple(flat[k:k + 3]) for k in range(0, len(flat), 3)]
def color_seg(label, n_classes=0):
    """Colorize a segmentation mask for visualization.

    Args:
        label: numpy array of class ids, converted to int32 and then to a PIL
            "L" image (values above 255 presumably do not survive the "L"
            conversion — TODO confirm).
        n_classes: palette size; a value < 1 means "use ceil(max(label)) + 1".
    Returns:
        PIL RGB image.
    """
    num = math.ceil(np.max(label)) + 1 if n_classes < 1 else n_classes
    indexed = Image.fromarray(label.astype(np.int32)).convert("L")
    indexed.putpalette(get_palette(num))
    return indexed.convert("RGB")
def blend_seg(image, label, n_classes=0, alpha=0.7, rescale=False, transparent_bg=True, save_file=""):
    """Overlay a colorized segmentation mask on an image, for visualization.

    Args:
        image: numpy image passed to ``PIL.Image.fromarray``.
        label: integer mask with the same height/width as ``image``.
        n_classes: forwarded to ``color_seg`` (< 1 = infer from ``label``).
        alpha: weight of the colored mask in ``Image.blend``
            (0 = image only, 1 = mask only).
        rescale: min-max stretch ``image`` to [0, 255] and cast to uint8.
        transparent_bg: where ``label == 0``, keep the plain image pixels.
        save_file: if non-empty, the result is also saved to this path.
    Returns:
        PIL RGB image.
    """
    if rescale:
        span = image.max() - image.min()
        if span != 0:
            image = (image - image.min()) / span * 255
        image = np.clip(image, 0, 255).astype(np.uint8)
    base = Image.fromarray(image).convert("RGB")
    overlay = Image.blend(base, color_seg(label, n_classes), alpha)
    if transparent_bg:
        is_background = (label == 0)[:, :, np.newaxis]
        overlay = Image.fromarray(
            np.where(is_background, np.asarray(base), np.asarray(overlay))
        )
    if save_file:
        overlay.save(save_file)
    return overlay
class MyDataset(torch.utils.data.Dataset):
    """Paired (image, label) dataset backed by .npy files.

    See ``__getitem__`` for how a multi-input transform is called: image and
    label are passed together in one call.
    """

    def __init__(self, image_list, label_list, transform=None):
        assert len(image_list) == len(label_list)
        self.image_list, self.label_list = image_list, label_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        raw_image = np.load(self.image_list[index])  # presumably [h, w]
        raw_label = np.load(self.label_list[index])
        # Prepend a channel axis; the image becomes float32, the mask int32.
        image = torch.from_numpy(raw_image).unsqueeze(0).float()
        label = torch.from_numpy(raw_label).unsqueeze(0).int()
        if self.transform is None:
            return image, label
        # Image and label go through the transform together, so a
        # multi-input transform can apply identical random parameters.
        image, label = self.transform(image, label)
        return image, label
class ResizeZoomPad:
    """Aspect-preserving resize followed by zero padding to a fixed size.

    The input is scaled by the largest factor that keeps it within ``size``,
    which preserves its aspect ratio. It is then padded up to exactly
    ``size``; any odd pixel of padding goes to the right/bottom.

    Works on tensors whose last two dims are (H, W), e.g. [C, H, W].
    NOTE: "box", "hamming" and "lanczos" are accepted, but torchvision
    supports them for PIL images only, so they fail on tensors. Use "nearest"
    for label masks so class ids are not interpolated.
    """

    def __init__(self, size, interpolation="bilinear"):
        """
        Args:
            size: target (H, W); an int means a square target.
            interpolation: mode name (case-insensitive) or a torchvision
                ``InterpolationMode``.
        Raises:
            TypeError: if ``size`` is neither an int nor a tuple/list.
        """
        if isinstance(size, int):
            assert size > 0
            self.size = [size, size]
        elif isinstance(size, (tuple, list)):
            assert len(size) == 2 and size[0] > 0 and size[1] > 0
            self.size = size
        else:
            # Previously self.size was silently left unset -> AttributeError later.
            raise TypeError(
                f"size must be an int or an (h, w) tuple/list, got {type(size).__name__}"
            )
        if isinstance(interpolation, str):
            assert interpolation.lower() in {"nearest", "bilinear", "bicubic", "box", "hamming", "lanczos"}
            interpolation = {
                "nearest": F.InterpolationMode.NEAREST,
                "bilinear": F.InterpolationMode.BILINEAR,
                "bicubic": F.InterpolationMode.BICUBIC,
                "box": F.InterpolationMode.BOX,
                "hamming": F.InterpolationMode.HAMMING,
                "lanczos": F.InterpolationMode.LANCZOS
            }[interpolation.lower()]
        self.interpolation = interpolation

    def _fit_size(self, h, w):
        """Return the (h, w) the input is resized to before padding."""
        target_h, target_w = self.size[0], self.size[1]
        scale_h, scale_w = float(target_h) / h, float(target_w) / w
        if scale_h <= scale_w:
            # Height is the binding side: it must land exactly on the target.
            # Recomputing it as h * scale_h can round down (49 * (1 / 49) < 1).
            new_h, new_w = target_h, min(int(w * scale_h), target_w)
        else:
            new_h, new_w = min(int(h * scale_w), target_h), target_w
        # Never ask F.resize for a zero-length side.
        return max(new_h, 1), max(new_w, 1)

    def __call__(self, image):
        """image: tensor with trailing dims (H, W), e.g. [C, H, W]."""
        new_h, new_w = self._fit_size(image.shape[-2], image.shape[-1])
        image = F.resize(image, [new_h, new_w], self.interpolation)
        pad_h = self.size[0] - image.shape[-2]
        pad_w = self.size[1] - image.shape[-1]
        assert pad_h >= 0 and pad_w >= 0
        if pad_h > 0 or pad_w > 0:
            pad_left, pad_right = pad_w // 2, (pad_w + 1) // 2
            pad_top, pad_bottom = pad_h // 2, (pad_h + 1) // 2
            # F.pad expects [left, top, right, bottom].
            image = F.pad(image, [pad_left, pad_top, pad_right, pad_bottom])
        return image
# Collect data files
data_path = os.path.expanduser("~/data/verse/processed-verse19-npy-horizontal")


def list_split_files(split):
    """Return (image_files, label_files) for one split directory under ``data_path``.

    Image slices of a volume are ``<name>_ct/*.npy``; the matching masks are
    expected in the sibling directory ``<name>_seg-vert_msk/*.npy``.
    """
    images, labels = [], []
    split_dir = osp.join(data_path, split)
    for d in os.listdir(split_dir):
        if not d.endswith("_ct"):
            continue
        img_p = osp.join(split_dir, d)
        lab_p = osp.join(split_dir, d[:-3] + "_seg-vert_msk")
        assert osp.isdir(lab_p), lab_p
        img_files = glob(osp.join(img_p, "*.npy"))
        lab_files = glob(osp.join(lab_p, "*.npy"))
        # Otherwise a missing slice only shows up later as a global length
        # mismatch, or worse, as silently shifted image/label pairs.
        assert len(img_files) == len(lab_files), d
        images.extend(img_files)
        labels.extend(lab_files)
    return images, labels


def sort_by_digits(files):
    """Sort paths by the digit runs of their base names (see digit_sort_key)."""
    return sorted(files, key=lambda f: digit_sort_key(osp.basename(f)))


train_images, train_labels = list_split_files("training")
val_images, val_labels = list_split_files("validation")
# Sort file names so images and labels line up index by index
train_images = sort_by_digits(train_images)
train_labels = sort_by_digits(train_labels)
val_images = sort_by_digits(val_images)
val_labels = sort_by_digits(val_labels)
# transform
# MultiCompose wraps every transform with to_multi so it accepts (image, label).
# A bilinear resize of the integer label mask would interpolate class ids and
# create non-existent classes along region borders. So labels get a
# nearest-neighbour resize while images stay bilinear.
_resize_image = ResizeZoomPad((224, 256))
_resize_label = ResizeZoomPad((224, 256), interpolation="nearest")


def resize_image_or_label(x):
    """Resize-and-pad ``x``: floating tensors are images, integer tensors label masks."""
    return _resize_image(x) if x.is_floating_point() else _resize_label(x)


train_trans = MultiCompose([
    resize_image_or_label,
    transforms.RandomRotation(30),  # default NEAREST interpolation, safe for labels
])
# Check: load a batch and visualize image, label and their overlay
check_ds = MyDataset(train_images, train_labels, train_trans)
check_loader = torch.utils.data.DataLoader(check_ds, batch_size=10, shuffle=True)
for images, labels in check_loader:
    print(images.size(), labels.size())
    for i in range(images.size(0)):
        img = images[i][0].numpy()
        lab = labels[i][0].numpy()
        print(np.unique(lab))  # should contain only valid class ids
        # Normalize once to uint8 so all three panels use the same image;
        # guard against constant (e.g. fully padded) slices.
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (255 * (img - lo) / (hi - lo)).astype(np.uint8)
        else:
            img = np.zeros(img.shape, dtype=np.uint8)
        seg_img = np.asarray(blend_seg(img, lab))
        img_rgb = np.asarray(Image.fromarray(img).convert("RGB"))
        lab_rgb = np.asarray(color_seg(lab))
        comb = np.concatenate([img_rgb, lab_rgb, seg_img], axis=1)
        Image.fromarray(comb).save(f"test-dataset-{i}.png")
    break  # one batch is enough for a visual check
效果:
可见,image 和 label 转了同一个随机角度。
Limits
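有些 augmentation 只应对 image 做而不对 label 做,如 ColorJitter,这里没有考虑怎么处理;一种做法是在 MultiCompose 之外单独对 image 调用这类 transform。另外,带插值的几何变换(如 bilinear 的 Resize)若直接作用于 label,会在类别边界处插出不存在的类别值,因此 label 应使用 nearest 插值。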
References
- How to Set Random Seeds in PyTorch and Tensorflow
- ihoromi4/seed_everything.py
- Reproducibility
- What is the max seed you can set up?