【技术追踪】SegGuidedDiff:基于分割引导扩散模型实现解剖学可控的医学图像生成(MICCAI-2024)

  它来了它来了,它带着 mask 做生成了~

  SegGuidedDiff:提出一种用于解剖学可控医学图像生成的扩散模型,在每个采样步骤都遵循多类解剖分割掩码并结合了随机掩码消融训练算法,可助力乳房 MRI 和 腹部/颈部到骨盆 CT 等任务涨点。


论文:Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models
代码:https://github.com/mazurowski-lab/segmentation-guided-diffusion


0、摘要

  扩散模型能够实现高质量的医学图像生成,但在生成的图像中实现解剖约束具有挑战性。
  为此,本文提出了一种基于扩散模型的方法,通过支持解剖可控的医学图像生成,在每个采样步骤中遵循多类解剖分割 mask。
  此外,还引入了一种随机 mask 消融训练算法,以实现对选定的解剖约束组合的调节,同时允许其他解剖区域的灵活性。
  本文将所提出的方法 SegGuidedDiff 与乳腺MRI和腹部/颈部到骨盆CT数据集的现有方法进行了比较,这些数据集具有广泛的解剖目标。
  结果表明,本文的方法在生成图像的忠实性方面达到了SOTA效果,并与一般的解剖现实相符。
  该模型还有额外的好处,即可通过潜在空间的插值来调整生成图像与真实图像的解剖相似性。
  SegGuidedDiff 有许多应用,包括跨模态转译,以及成对数据或反事实数据的生成。(好像很厉害~


1、引言

1.1、DDPM的不足

  DDPM这样的标准生成模型仍然无法创建解剖学上合理的组织(图1),并且这种解剖结构无法精确定制。
  本文提出的解决方案是将不同类型的组织、器官的分割 mask,作为解剖信息先验,以此来生成图像,为网络提供更直接的学习信号,实现解剖真实感。

标准的扩散模型即使生成高质量图像,但无法创建真实的组织:
在这里插入图片描述

1.2、从mask生成图像:图像转译任务难点

  (1)现有模型没有直接实现精确的像素级解剖约束;
  (2)LDM已用于对自然图像的 mask 引导,但其在医学图像上的适应性并不好;
  故,本文实现的是图像空间的扩散模型,转换到潜在空间可能会丢失精确的空间引导;


2、方法

2.1、扩散模型简要概述

  原文略,可参考:【Diffusion综述】医学图像分析中的扩散模型(一)中2.2节;

2.2、在扩散模型中添加分割引导

  主要思想是以分割mask为引导条件生成更符合真实解剖的图像,故本文不直接从非条件分布 p ( x 0 ) {p(x_0)} p(x0) 中采样,而是从 p ( x 0 ∣ m ) {p(x_0|m)} p(x0m) 中采样,其中 x 0 ∈ R c × h × w {x_0 \in \mathbb{R}^{c×h×w}} x0Rc×h×w m ∈ { 0 , . . . , C − 1 } h × w {m \in \{ 0,...,C-1 \}}^{h×w} m{0,...,C1}h×w,C为多类别标注 mask 的类别数,包括背景。

  这样添加引导条件不会改变前向过程 q ( x t ∣ x t − 1 ) {q(x_t|x_{t-1})} q(xtxt1) ,但会修改反向过程 p θ ( x t − 1 ∣ x t , m ) {p_{\theta}(x_{t-1}|x_t,m)} pθ(xt1xt,m) 和噪声预测网络 ϵ θ {{\epsilon}_{\theta}} ϵθ ,损失函数如下:
在这里插入图片描述
  每一个训练 x 0 {x_0} x0 都有一些配对的 mask m {m} m ,在网络中如何实现呢,在去噪过程中,将 m {m} m 个 mask 直接 concat 到 Unet 的输入 x t {x_t} xt 就可以了。

  其原代码实现如下:

def convert_segbatch_to_multiclass(imgs_shape, segmentations_batch, config, device):
    # NOTE: this generic function assumes that segs don't overlap
    # put all segs on same channel
    segs = torch.zeros(imgs_shape).to(device)
    for k, seg in segmentations_batch.items():
        if k.startswith("seg_"):
            seg = seg.to(device)
            segs[segs == 0] = seg[segs == 0]

    if config.use_ablated_segmentations:
        # randomly remove class labels from segs with some probability
        segs = ablate_masks(segs, config)

    return segs
    
def add_segmentations_to_noise(noisy_images, segmentations_batch, config, device):
    """
    concat segmentations to noisy image
    """
    if config.segmentation_channel_mode == "single":
        segs = convert_segbatch_to_multiclass(noisy_images.shape, segmentations_batch, config, device) 
        # concat segs to noise
        noisy_images = torch.cat((noisy_images, segs), dim=1)  # 这里,cat在一起!
        
    elif config.segmentation_channel_mode == "multi":
        raise NotImplementedError

    return noisy_images

2.3、Mask-Ablated训练和采样

  本文作者认为,用于引导生成的 mask 质量非常重要,若 mask 标注不全,可能会误导生成图像,因此,作者希望模型可以简单地填充或推断未提供的目标。
  那咋整呢,提出了一种 mask-ablated 训练(MAT)策略,该策略提供了具有各种数量和类别组合的 mask 示例,供模型在训练过程中学习生成图像。这可以被认为是解剖对象表征的一种自监督学习形式。
  算法中,采用伯努利分布,随机将一些类的 mask 置 0 ,构成各种类别 mask 的组合。

算法流程:
在这里插入图片描述

  其原代码实现如下:

def ablate_masks(segs, config, method="equal_weighted"):
    # randomly remove class label(s) from segs with some probability 
    if method == "equal_weighted":
        """
        # give equal probability to each possible combination of removing non-background classes
        # NOTE: requires that each class has a value in ({0, 1, 2, ...} / 255)
        # which is by default if the mask file was saved as {0, 1, 2 ,...} and then normalized by default to [0, 1] by transforms.ToTensor()
        # num_segmentation_classes
        """
        # 随机将某一类mask置为False,删除
        class_removals = (torch.rand(config.num_segmentation_classes - 1) < 0.5).int().bool().tolist()
        for class_idx, remove_class in enumerate(class_removals):
            if remove_class:
                segs[(255 * segs).int() == class_idx + 1] = 0

    elif method == "by_class":
        class_ablation_prob = 0.3
        for seg_value in segs.unique():
            if seg_value != 0:
                # remove seg with some probability
                if torch.rand(1).item() < class_ablation_prob:
                    segs[segs == seg_value] = 0
    
    else:
        raise NotImplementedError
    return segs

3、实验与结果

3.1、数据集

(1)杜克大学乳腺癌MRI数据集:
  ①100例,T1图像,70例训练,15例测试,保留15例训练集做其他实验;
  ②所有数据有乳腺、血管(BV)、纤维腺/致密组织(FGT)的分割标注,FGT和BV在形状、大小和其他形态特征上具有非常高的变异性,这为生成模型的真实特征捕获提出了挑战;

(2)CT器官:
  ①40例,腹部CT扫描,包括肝脏、膀胱、肺、肾和骨的分割标注;
  ②24例训练,8例测试,保留8例训练集;

  所有生成模型都是在训练集上进行训练的,辅助分割网络是在保留训练集上进行训练的;

3.2、实施细节

  (1)图像大小256×256,归一化到 [0,255];
  (2)正向过程: β t {\beta_t} βt 线性从 0.0001 到 0.02;
  (3)AdamW优化器,余弦调整学习率,初始0.0001,500 linear warm-up steps;
  (4)epoch:400;
  (5)batch size:64;
  (6)显卡:4块 48 GB NVIDIA A6000;

3.3、与现有图像生成模型的比较

  STD为标准模型,MAT则采用了mask-ablated 训练策略:

在这里插入图片描述

3.4、评估生成的图像对输入掩码的忠实度

  使用在真实训练集上训练的辅助分割网络(MONAI UNet),预测从测试集生成的图像的分割mask: m g e n p r e d {m_{gen}^{pred}} mgenpred,计算其与 m {m} m m g e n p r e d {m_{gen}^{pred}} mgenpred 的 Dice 值:

在这里插入图片描述

3.5、评估生成图像质量

  作者认为,FID 这样的基于 CNN 特征的指标无法捕捉到解剖学真实性的全局特征,而这些特征在这些模型生成的图像中可能会有所不同;
  作者利用辅助分割网络在合成的图像上训练,将测试集分为两部分,分别验证在真实图像和合成图像上训练的模型,证明了在合成图像训练的模型表现比真实图像训练的模型差不多(≤ 0.04 Dice):

在这里插入图片描述

3.6、MAT的优势

  MAT的好处是它能够从缺少类的引导 mask 中生成图像:

在这里插入图片描述

3.7、生成的图像与真实图像的可调解剖相似性

  通过在模型的潜在空间中对合成图像和真实图像进行插值来调整由 m {m} m 生成的图像与 x 0 {x_0} x0 的解剖相似性;

  在反向过程中, t = T {t=T} t=T,在 t = t ~ {t=\tilde{t} } t=t~ (本文使用 t ~ = 240 {\tilde{t} = 240} t~=240)时获得一个潜在表示 x t ~ ′ {x_{\tilde{t}}^{\prime}} xt~;使用正向过程,从 x 0 {x_0} x0 获得 t = t ~ {t=\tilde{t} } t=t~ 时的图像 x t ~ {x_{\tilde{t}}} xt~ ,使用 x t ~ λ = ( 1 − λ ) x t ~ + λ x t ~ ′ {x_{\tilde{t}}^{\lambda} = (1-\lambda) x_{\tilde{t}} + \lambda x_{\tilde{t}}^{\prime}} xt~λ=(1λ)xt~+λxt~ 融合这两幅图的特征, λ ∈ ( 0 , 1 ] {\lambda \in (0,1]} λ(0,1] 控制混合特征与真实图像的相似性; x t ~ λ {x_{\tilde{t}}^{\lambda}} xt~λ 接着去噪,获得 x 0 λ {x_{0}^{\lambda}} x0λ

  乳腺MRI中只有FGT+BV对受限,而CT器官中只有骨骼受限:
在这里插入图片描述


  又是羡慕别人diffusion 的一天,怎么拥有一个好用的 diffusion,在线等,挺着急的 (;′⌒`)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/763889.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python中的包和模块

目录 一、包与模块 二、第三方包的安装 2.1 pip install 2.2使用 curl 管道 2.3其他安装方法 三、导入单元的构成 3.1pip的使用 四、模块的缓存 一、包与模块 Python 中除了函数库以外&#xff0c;还有非常多且优秀的第三方库、包、模块。 模块Module&#xff1a;以…

LangChain 开发智能Agent,你学会了吗?

Prompt Enginnering 是打开LLM宝库的一把金钥匙&#xff0c;如果prompt得法&#xff0c;并能将其技巧与某项工作深度结合&#xff0c;那必将大大增效。今天我们来聊聊如何优化Prompt设计、Prompt Template管理等技术和体力活&#xff0c;并赋能老喻干货店的营销活动。 LLM Pro…

基于机器学习的零售商品销售数据预测系统

1 项目介绍 1.1 研究目的和意义 在电子商务日益繁荣的今天&#xff0c;精准预测商品销售数据成为商家提升运营效率、优化库存管理以及制定营销策略的关键。为此&#xff0c;开发了一个基于深度学习的商品销售数据预测系统&#xff0c;该系统利用Python编程语言与Django框架&a…

Java服务器代码远程调试(IDEA版)

Java服务器代码远程调试 配置启动脚本参数配置IDEA远程调试工具操作步骤 注意&#xff1a;远程调试的代码需要与本地代码一致&#xff0c;远程调试目的是解决本地环境无法支持调试的情况下&#xff0c;解决线上&#xff08;测试&#xff09;环境调试问题。 配置启动脚本参数 n…

昇思25天学习打卡营第10天|linchenfengxue

基于MobileNetv2的垃圾分类 通过读取本地图像数据作为输入&#xff0c;对图像中的垃圾物体进行检测&#xff0c;并且将检测结果图片保存到文件中。 MobileNetv2模型原理介绍 MobileNet网络是由Google团队于2017年提出的专注于移动端、嵌入式或IoT设备的轻量级CNN网络&#x…

TikTok直播限流与网络的关系及解决方法

TikTok作为一款热门的社交平台&#xff0c;其直播功能吸引了大量用户。然而&#xff0c;一些用户可能会遇到TikTok直播限流的问题&#xff0c;例如直播过程中出现播放量低、直播画面质量差等情况。那么&#xff0c;TikTok直播限流与所使用的网络线路是否有关系&#xff1f;是否…

TypeScript Project References npm 包构建小实践

npm 包输出 es/cjs 产物 在开发一个 npm 包时&#xff0c;通常需要同时输出 ES 模块和 CommonJS 模块的产物供不同的构建进行使用。在只使用tsc进行产物编译的情况下&#xff0c;我们通常可以通过配置两个独立的 tsconfig.json 配置文件&#xff0c;并在一个 npm script 中 执…

typescript学习回顾(五)

今天来分享一下ts的泛型&#xff0c;最后来做一个练习 泛型 有时候&#xff0c;我们在书写某些函数的时候&#xff0c;会丢失一些类型信息&#xff0c;比如我下面有一个例子&#xff0c;我想提取一个数组的某个索引之前的所有数据 function getArraySomeData(newArr, n:numb…

Mouse Prealbumin ELISA Kit小鼠前白蛋白ELISA试剂盒

前白蛋白&#xff08;PRE&#xff09;是一种由4条相同的多肽链组成的四聚体蛋白。电泳时&#xff0c;它比血清白蛋白的迁移速度更快&#xff0c;PRE可以作为多种疾病患者营养评价的标志物。ICL的Mouse Prealbumin ELISA Kit应用双抗体夹心法测定小鼠样本中前白蛋白水平&#xf…

CentOS7源码安装nginx并编写服务脚本

华子目录 准备下载nginx源码包关闭防火墙关闭selinux安装依赖环境 解压编译安装测试编写服务脚本&#xff0c;通过systemctl实现服务启动与关闭测试 准备 下载nginx源码包 在源码安装前&#xff0c;我们得先下载nginx源码包https://nginx.org/download/这里我下载的是nginx-1…

《梦醒蝶飞:释放Excel函数与公式的力量》8.2 COUNTA函数

8.2 COUNTA函数 COUNTA函数是Excel中用于统计指定区域内所有非空单元格数量的函数。它能够统计数值、文本、错误值以及公式返回的结果&#xff0c;是数据分析中常用的统计工具。 8.2.1 函数简介 COUNTA函数用于统计指定区域中所有非空单元格的数量。它与COUNT函数不同&#…

transformer——多变量预测PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)——transformer多变量预测

写在最前&#xff1a; 在系统地学习了Transformer结构后&#xff0c;尝试使用Transformer模型对DNA序列数据实现二分类&#xff0c;好久前就完成了这个实验&#xff0c;一直拖着没有整理&#xff0c;今天系统的记录一下&#xff0c;顺便记录一下自己踩过的坑 &#xff08;需要…

OpenHarmony开发实战:GPIO控制器接口

功能简介 GPIO&#xff08;General-purpose input/output&#xff09;即通用型输入输出。通常&#xff0c;GPIO控制器通过分组的方式管理所有GPIO管脚&#xff0c;每组GPIO有一个或多个寄存器与之关联&#xff0c;通过读写寄存器完成对GPIO管脚的操作。 GPIO接口定义了操作GP…

Echarts 问题集锦

最近公司集中做统计图表&#xff0c;新手小白&#xff0c;真被Echarts折腾地不轻&#xff0c;怕自己年老记忆衰退&#xff0c;特地做一些记录。以备后面查阅。 1、X轴的 数据显示不全&#xff0c;间或不显示 很奇葩&#xff0c;我发现数据里有一个值为0.0&#xff0c;当这条记…

液压件工厂的MES解决方案:智能生产,高效未来

一、引言 虽然我国液压件行业发展迅速&#xff0c;但是大多数液压件生产企业规模小、自主创新能力不足&#xff0c;大部分液压产品处于价值链中低端。且由于技术、工艺、设备及管理等多方面的限制&#xff0c;高端液压件产品研发生产水平不足&#xff0c;无法形成有效的供给&a…

【linux】虚拟机安装 BCLinux-R8-U4-Server-x86_64

目录 一、概述 1.1移动云Linux系统订阅服务 CLS 1.2 大云天元操作系统BC-Linux 二、安装 一、概述 1.1移动云Linux系统订阅服务 CLS 移动云Linux系统订阅服务 CLS &#xff08;Cloud Linux Service&#xff09;为使用BC-Linux操作系统的用户提供标准维保服务以及高级技术支…

生物墨水的重要特性

生物打印技术正以前所未有的速度发展&#xff0c;为组织工程和再生医学领域带来了革命性的变革。然而&#xff0c;成功打印出功能性的三维结构&#xff0c;并将其应用于人体&#xff0c;离不开生物墨水这一关键材料。主要特性包括&#xff1a; 物理性质 表面张力: 表面张力影…

基于java+springboot+vue实现的社团管理系统(文末源码+Lw)270

摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对信息管理混乱&#xff0c;出错率高&#xff0c;信息安全性差&#…

Linux4(Docker)

目录 一、Docker介绍 二、Docker结构 三、Docker安装 四、Docker 镜像 五、Docker 容器 六、Docker 安装nginx 七、Docker 中的MySQL部署 一、Docker介绍 Docker&#xff1a;是给予Go语言实现的开源项目。 Docker的主要目标是“Build,Ship and Run Any App,Anywhere” 也…

LangChain入门学习笔记(七)—— 使用检索提高生成内容质量

大模型训练使用的数据是开放的、广泛的&#xff0c;因此它显得更加的通用。然而在有些应用场景下&#xff0c;用户需要使用自己的数据使得大模型生成的内容更加贴切&#xff0c;也有时候用户的数据是敏感的&#xff0c;无法提供出来给大模型进行通用性的训练。RAG技术就是一种解…