pytorch一致数据增强

分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的可能不一致,如 RandomRotation 将 image 转了 a 度、却将 label 转了 b 度。

MONAI 有个 ArrayDataset 实现了这功能,思路是每次 transform 前都重置一次 random seed 先。对 monai 订制 transform 的方法不熟,torchvision.transforms 的订制接口比较简单,考虑基于 pytorch 实现。要改两个东西:

  • 扩展 torchvison.transforms.Compose,使之支持多个输入(image、label);
  • 一个 wrapper,扩展 transform,使之支持多输入。

思路也是重置 random seed,参考 [1-4]。

Code

  • to_multi:将处理单幅图的 transform 扩展成可处理多幅;
  • MultiCompose:扩展 torchvision.transforms.Compose,可输入多幅图。内部调用 to_multi 扩展传入的 transforms。
import random, os
import numpy as np
import torch

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def to_multi(trfm):
    """wrap a transform to extend to multiple input with synchronised random seed
    Input:
        trfm: transformation function/object (custom or from torchvision.transforms)
    Output:
        _multi_transform: function
    """
    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    min_seed = 0 # - 0x8000_0000_0000_0000
    max_seed = min(2**32 - 1, 0xffff_ffff_ffff_ffff)
    def _multi_transform(*images):
        """images: [C, H, W]"""
        if len(images) == 1:
            return trfm(images[0])
        _seed = random.randint(min_seed, max_seed)
        res = []
        for img in images:
            seed_everything(_seed)
            res.append(trfm(img))
        return tuple(res)

    return _multi_transform


class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple input.
    Usage is the same as torchvision.transforms.Compose. This class will wrap input
    transforms with `to_multi` to support simultaneous multiple transformation.
    This can be useful when simultaneously transforming images & segmentation masks.
    """
    def __init__(self, transforms):
        """transforms should be wrapped by `to_multi`"""
        self.transforms = [to_multi(t) for t in transforms]

    def __call__(self, *images):
        for t in self.transforms:
            images = t(*images)
        return images

test

测试一致性,用到预处理过的 verse’19 数据集、一些工具函数、一个订制 transform:

  • verse’19 数据集及预处理见 iTomxy/data/verse;
  • digit_sort_key:数据文件排序用;
  • get_palettecolor_segblend_seg:可视化用;
  • MyDataset:看其中 __getitem__ 的 transform 用法,即同时传入 image 和 label;
  • ResizeZoomPad:一个订制的 transform;
import os, os.path as osp, random
from glob import glob
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F


def digit_sort_key(s, num_pattern=re.compile('([0-9]+)')):
    """natural sort,数据排序用"""
    return [int(text) for text in num_pattern.split(s) if text.isdigit()]


def get_palette(n_classes, pil_format=True):
    """创建调色盘,可视化用"""
    n = n_classes
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3

    if pil_format:
        return palette

    res = []
    for i in range(0, len(palette), 3):
        res.append(tuple(palette[i: i+3]))
    return res


def color_seg(label, n_classes=0):
    """segmentation mask 上色,可视化用"""
    if n_classes < 1:
        n_classes = math.ceil(np.max(label)) + 1
    label_rgb = Image.fromarray(label.astype(np.int32)).convert("L")
    label_rgb.putpalette(get_palette(n_classes))
    return label_rgb.convert("RGB")


def blend_seg(image, label, n_classes=0, alpha=0.7, rescale=False, transparent_bg=True, save_file=""):
    """融合 image 和其 segmentation mask,可视化用"""
    if rescale:
        denom = image.max() - image.min()
        if 0 != denom:
            image = (image - image.min()) / denom * 255
        image = np.clip(image, 0, 255).astype(np.uint8)
    img_pil = Image.fromarray(image).convert("RGB")
    lab_pil = color_seg(label, n_classes)
    blended_image = Image.blend(img_pil, lab_pil, alpha)
    if transparent_bg:
        blended_image = Image.fromarray(np.where(
            (0 == label)[:, :, np.newaxis],
            np.asarray(img_pil),
            np.asarray(blended_image)
        ))
    if save_file:
        blended_image.save(save_file)
    return blended_image


class MyDataset(torch.utils.data.Dataset):
    """订制 dataset,看 __getitem__ 处 transform 的调法"""
    def __init__(self, image_list, label_list, transform=None):
        assert len(image_list) == len(label_list)
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, index):
        img = np.load(self.image_list[index]) # [h, w]
        lab = np.load(self.label_list[index])
        img = torch.from_numpy(img).unsqueeze(0).float() # -> [c=1, h, w]
        lab = torch.from_numpy(lab).unsqueeze(0).int()
        if self.transform is not None:
            img, lab = self.transform(img, lab) # 同时传入 image、label
        return img, lab


class ResizeZoomPad:
    """订制 resize"""
    def __init__(self, size, interpolation="bilinear"):
        if isinstance(size, int):
            assert size > 0
            self.size = [size, size]
        elif isinstance(size, (tuple, list)):
            assert len(size) == 2 and size[0] > 0 and size[1] > 0
            self.size = size

        if isinstance(interpolation, str):
            assert interpolation.lower() in {"nearest", "bilinear", "bicubic", "box", "hamming", "lanczos"}
            interpolation = {
                "nearest": F.InterpolationMode.NEAREST,
                "bilinear": F.InterpolationMode.BILINEAR,
                "bicubic": F.InterpolationMode.BICUBIC,
                "box": F.InterpolationMode.BOX,
                "hamming": F.InterpolationMode.HAMMING,
                "lanczos": F.InterpolationMode.LANCZOS
            }[interpolation.lower()]
        self.interpolation = interpolation

    def __call__(self, image):
        """image: [C, H, W]"""
        scale_h, scale_w = float(self.size[0]) / image.size(1), float(self.size[1]) / image.size(2)
        scale = min(scale_h, scale_w)
        tmp_size = [ # clipping to ensure size
            min(int(image.size(1) * scale), self.size[0]),
            min(int(image.size(2) * scale), self.size[1])
        ]
        image = F.resize(image, tmp_size, self.interpolation)
        assert image.size(1) <= self.size[0] and image.size(2) <= self.size[1]
        pad_h, pad_w = self.size[0] - image.size(1), self.size[1] - image.size(2)
        if pad_h > 0 or pad_w > 0:
            pad_left, pad_right = pad_w // 2, (pad_w + 1) // 2
            pad_top, pad_bottom = pad_h // 2, (pad_h + 1) // 2
            image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom))
        return image


# 读数据文件
data_path = os.path.expanduser("~/data/verse/processed-verse19-npy-horizontal")
train_images, train_labels, val_images, val_labels = [], [], [], []
for d in os.listdir(osp.join(data_path, "training")):
    if d.endswith("_ct"):
        img_p = osp.join(data_path, "training", d)
        lab_p = osp.join(data_path, "training", d[:-3]+"_seg-vert_msk")
        assert osp.isdir(lab_p)
        train_labels.extend(glob(os.path.join(lab_p, "*.npy")))
        train_images.extend(glob(os.path.join(img_p, "*.npy")))
for d in os.listdir(osp.join(data_path, "validation")):
    if d.endswith("_ct"):
        img_p = osp.join(data_path, "validation", d)
        lab_p = osp.join(data_path, "validation", d[:-3]+"_seg-vert_msk")
        assert osp.isdir(lab_p)
        val_labels.extend(glob(os.path.join(lab_p, "*.npy")))
        val_images.extend(glob(os.path.join(img_p, "*.npy")))

# 数据文件名排序
train_images = sorted(train_images, key=lambda f: digit_sort_key(os.path.basename(f)))
train_labels = sorted(train_labels, key=lambda f: digit_sort_key(os.path.basename(f)))
val_images = sorted(val_images, key=lambda f: digit_sort_key(os.path.basename(f)))
val_labels = sorted(val_labels, key=lambda f: digit_sort_key(os.path.basename(f)))

# transform
# 用 MultiCompose,其内部调用 to_multi 将 transforms wrap 成支持多输入的
train_trans = MultiCompose([
    ResizeZoomPad((224, 256)),
    transforms.RandomRotation(30),
])

# 测试:读数据,可试化 image 和 label
check_ds = MyDataset(train_images, train_labels, train_trans)
check_loader = torch.utils.data.DataLoader(check_ds, batch_size=10, shuffle=True)
for images, labels in check_loader:
    print(images.size(), labels.size())
    for i in range(images.size(0)):
        # print(i, end='\r')
        img = images[i][0].numpy()
        lab = labels[i][0].numpy()
        print(np.unique(lab))
        seg_img = blend_seg(img, lab)
        img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)
        img = np.asarray(Image.fromarray(img).convert("RGB"))
        lab = np.asarray(color_seg(lab))
        comb = np.concatenate([img, lab, seg_img], axis=1)
        Image.fromarray(comb).save(f"test-dataset-{i}.png")
    break

效果:
test-dataset-7.png
可见,image 和 label 转了同一个随机角度。

Limits

有些 augmentations 是只对 image 做而不对 label 做的,如 ColorJitter,这里没有考虑怎么处理。

References

  1. How to Set Random Seeds in PyTorch and Tensorflow
  2. ihoromi4/seed_everything.py
  3. Reproducibility
  4. What is the max seed you can set up?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/234702.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【概率方法】朗之万动力学 Langevin Dynamics

目前我们了解到采样方法有很多种&#xff0c;按照从朴素到高效的演变顺序大致是 反函数采样蒙特卡洛模拟&#xff08;求统计量&#xff09;接受-拒绝采样MCMC HM 算法Gibbs 采样 接上一篇文章&#xff0c;Gibbs 采样能在有条件分布 p ( x d ′ ∣ x − d ) p(\mathbf{x}_{d…

头歌-Python 基础

第1关&#xff1a;建模与仿真 1、 建模过程&#xff0c;通常也称为数学优化建模(Mathematical Optimization Modeling)&#xff0c;不同之处在于它可以确定特定场景的特定的、最优化或最佳的结果。这被称为诊断一个结果&#xff0c;因此命名为▁▁▁。 填空1答案&#xff1a;决…

【数据挖掘】国科大苏桂平老师数据库新技术课程作业 —— 第四次作业

云数据库研究 云计算与云数据库背景 云计算&#xff08;cloud computing&#xff09;是 IT 技术发展的最新趋势&#xff0c;正受到业界和学术界的广泛关注。云计算是在分布式处理、并行处理和网格计算等技术的基础上发展起来的&#xff0c;是一种新兴的共享基础架构的方法。它…

大数据技术7:基于StarRocks统一OALP实时数仓

前言&#xff1a; 大家对StarRocks 的了解可能不及 ClickHouse或者是远不及 ClickHouse 。但是大家可能听说过 Doris &#xff0c;而 StarRocks 实际上原名叫做 Doris DB &#xff0c;他相当于是一个加强版的也就是一个 Doris ,也就是说 Doris 所有的功能 StarRocks 都是有的&a…

this.$emit(‘update:isVisible‘, false)作用

这个写是不是很新颖&#xff0c;传父组件传值&#xff01;这是什么鬼。。。 假设你有以下逻辑业务。在A页面弹出一个组件B&#xff0c;A组件里面使用B组件&#xff0c;是否展示B组件你使用的是baselineShow变量控制&#xff01; <BaselineData :isVisible.sync"basel…

SQL命令---修改字段的排列位置

介绍 使用sql语句表字段的排列顺序。 命令 alter table 表名 modify 字段名1 数据类型 first|after 字段名2;例子 将a表中的age字段改为表的第一个字段。 alter table a modify age int(12) first;下面是执行命令后的表结构&#xff1a; 将a表中的age字段放到name字段之…

ELK简单介绍二

学习目标 能够部署kibana并连接elasticsearch集群能够通过kibana查看elasticsearch索引信息知道用filebeat收集日志相对于logstash的优点能够安装filebeat能够使用filebeat收集日志并传输给logstash kibana kibana介绍 Kibana是一个开源的可视化平台,可以为ElasticSearch集群…

linux 15day apache apache服务安装 httpd服务器 安装虚拟主机系统 一个主机 多个域名如何绑定

目录 一、apache安装二、访问控制总结修改默认网站发布目录 三、虚拟主机 一、apache安装 [rootqfedu.com ~]# systemctl stop firewalld [rootqfedu.com ~]# systemctl disable firewalld [rootqfedu.com ~]# setenforce 0 [rootqfedu.com ~]# yum install -y httpd [rootqfe…

用23种设计模式打造一个cocos creator的游戏框架----(十二)状态模式

1、模式标准 模式名称&#xff1a;状态模式 模式分类&#xff1a;行为型 模式意图&#xff1a;允许一个对象在其内部状态改变时改变它的行为。对象看起来似乎修改了它的类。 结构图&#xff1a; 适用于&#xff1a; 1、一个对象的行为决定于它的状态&#xff0c;并且它必须…

记录一次云原生线上服务数据迁移全过程

文章目录 背景迁移方案调研迁移过程服务监控脚本定时任务暂停本地副本服务启动&#xff0c;在线服务下线MySQL 数据迁移Mongo 数据迁移切换新数据库 ip 本地服务启动数据库连接验证服务打包部署服务重启前端恢复正常监控脚本定时任务启动旧服务器器容器关闭 迁移总结 背景 校园…

代码随想录二刷 |二叉树 |101. 对称二叉树

代码随想录二刷 &#xff5c;二叉树 &#xff5c;101. 对称二叉树 题目描述解题思路 & 代码实现递归法迭代法使用队列使用栈 题目描述 101.对称二叉树 给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 示例 1&#xff1a; 输入&#xff1a;root [1,2,2,…

adb unauthorized 踩坑记录

给Realme X7 Pro 安装Root后&#xff0c;发现adb连接设备呈现unauthorized 状态&#xff1a; 在Google以后&#xff0c;尝试了很多方案&#xff0c;均无效&#xff0c;尝试的方案如下&#xff1a; 重启手机&#xff0c;电脑。不行撤销调试授权&#xff0c;开关usb调试&#xf…

持续集成交付CICD:Jenkins配置Nexus制品发布

目录 一、实验 1.Jenkins配置Nexus制品发布 一、实验 1.Jenkins配置Nexus制品发布 &#xff08;1&#xff09;策略 发布其实就是下载制品&#xff0c;然后将制品发送到目标主机&#xff0c;最后通过脚本或者指令启动程序。 &#xff08;2&#xff09;安装Maven Artifact …

基于JavaWeb+SSM+Vue马拉松报名系统微信小程序的设计和实现

基于JavaWebSSMVue马拉松报名系统微信小程序的设计和实现 源码获取入口Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码获取入口 Lun文目录 1系统概述 1 1.1 研究背景 1 1.2研究目的 1 1.3系统设计思想 1 2相关技术 2 2.…

SQL命令---添加新字段

介绍 使用sql语句为表添加新字段。 命令 alter table 表名 add 新字段名 数据类型;例子 向a表中添加name字段&#xff0c;类型为varchar(255)。 alter table a add name varchar(255);下面是执行添加有的表结构&#xff1a;

react之项目打包,本地预览,路由懒加载,打包体积分析以及如何配置CDN

react之项目打包,本地预览,路由懒加载,打包体积分析以及如何配置CDN 一、项目打包二、项目本地预览三、路由懒加载四、打包体积分析五、配置CDN 一、项目打包 执行命令 npm run build根目录下生成的build文件夹 及时打包后的文件 二、项目本地预览 1.全局安装本地服务包 npm…

内存分配器

实现分配器需要考虑的问题 空闲块的组织方式&#xff1a;如何记录现有的空闲块空闲块的选择&#xff1a;如何选择一个合适的空闲块空闲块的分割&#xff1a;选择了一个合适的空闲块后如何处理空闲块内部的剩余部分空闲块的合并&#xff1a;如何处理一个刚刚被释放的块&#xf…

Python sqlalchemy使用

基本结构 #!/usr/bin/python3 # -*- coding:utf-8 -*- """ author: JHC file: base_db.py time: 2023/6/19 21:34 desc: """ from sqlalchemy import create_engine,text from sqlalchemy.orm import sessionmaker,scoped_session from contex…

计算机服务器中了Mallox勒索病毒怎么解密,Mallox勒索病毒解密步骤

计算机网络技术的不断发展与应用&#xff0c;为企业的生产运营提供了坚实的基础&#xff0c;大大提高了企业的生产与工作效率&#xff0c;但随之而来的网络安全威胁也在不断增加。在本月&#xff0c;云天数据恢复中心接到了很多企业的求助&#xff0c;企业的计算机服务器遭到了…

Nacos源码解读12——Nacos中长连接的实现

短连接 VS 长连接 什么是短连接 客户端和服务器每进行一次HTTP操作&#xff0c;就建立一次连接&#xff0c;任务结束就中断连接。 长连接 客户端和服务器之间用于传输HTTP数据的TCP连接不会关闭&#xff0c;客户端再次访问这个服务器时&#xff0c;会继续使用这一条已经建立…