Swin-UMamba
- Venue
- Abstract
- Contributions
- Method
- Overall framework of Swin-UMamba
- 1. Mamba-based VSS block
- 2. Integrating ImageNet-based pretraining
- 3. Swin-UMamba decoder
- 4. Swin-UMamba†: Swin-UMamba with a Mamba-based decoder
- Experiments
- Code practice
- Further reading
Venue
Despite the "Swin" in its name, Swin-UMamba has nothing to do with the Swin Transformer; there is no Swin-related code in the repository.
Published
arXiv, Feb 5, 2024
Code: https://github.com/JiarunLiu/Swin-UMamba
Abstract
Accurate medical image segmentation requires integrating multi-scale information, from local features to global dependencies. However, modeling long-range global information is challenging for existing methods: convolutional neural networks (CNNs) are limited by their local receptive fields, while vision transformers (ViTs) suffer from the high quadratic complexity of their attention mechanism. Recently, Mamba-based models have attracted great attention for their impressive ability to model long sequences. Several studies have shown that these models can outperform popular vision models across a variety of tasks, offering higher accuracy, lower memory consumption, and less computational burden. However, existing Mamba-based models are mostly trained from scratch and do not explore the power of pretraining, which has proven highly effective for data-efficient medical image analysis. This paper introduces Swin-UMamba, a new Mamba-based model designed specifically for medical image segmentation that leverages ImageNet-based pretraining. The experimental results reveal the important role of ImageNet-based pretraining in boosting the performance of Mamba-based models: Swin-UMamba clearly outperforms CNNs, ViTs, and the latest Mamba-based models. Notably, on the abdominal MRI, endoscopy, and microscopy datasets, Swin-UMamba beats its closest counterpart, U-Mamba, by an average score of 3.58%.
Contributions
- To the best of our knowledge, we are the first to study the impact of pretrained Mamba-based networks on medical image segmentation. Our experiments verify that ImageNet-based pretraining plays an important, and sometimes crucial, role for Mamba-based networks in medical image segmentation.
- We propose a new Mamba-based network named Swin-UMamba for medical image segmentation, designed specifically to leverage the power of pretrained models. In addition, we propose a variant structure, Swin-UMamba†, with fewer network parameters and lower FLOPs while maintaining competitive performance.
- Our results show that both Swin-UMamba and Swin-UMamba† outperform previous segmentation models, including CNNs, ViTs, and the latest Mamba-based models, by a clear margin, highlighting the effectiveness of ImageNet-based pretraining and of the proposed architectures for medical image segmentation.
Method
Overall framework of Swin-UMamba
In the framework figure, the blue boxes mark the modules that are initialized with the pretrained weights.
1. Mamba-based VSS block
Explanatory figure from VMamba:
Explanatory figure from VM-UNet:
1. The VSS block is explained in detail in VM-UNet. Unlike ViT's patch scanning, it unrolls the feature map along four scan directions and then merges the results (see the sketch after this list).
2. The proposed Swin-UMamba uses four stages of VSS blocks in its encoder, while the proposed Swin-UMamba† also uses VSS blocks in its decoder.
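To make the four-direction scanning concrete, here is a minimal sketch of the "cross-scan": the 2D feature map is unrolled row-major, column-major, and in the reverse of each, giving four 1D sequences that are processed by separate selective-scan branches and merged back. It is a simplification of the SS2D.forward_core code shown later in this post, not the exact implementation.

```python
import torch

# Toy feature map: batch 1, 2 channels, 3x3 spatial grid
B, C, H, W = 1, 2, 3, 3
x = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)
L = H * W

row_major = x.view(B, C, L)                               # scan left-to-right, top-to-bottom
col_major = x.transpose(2, 3).contiguous().view(B, C, L)  # scan top-to-bottom, left-to-right
scans = torch.stack([row_major, col_major], dim=1)        # (B, 2, C, L)
scans = torch.cat([scans, torch.flip(scans, dims=[-1])], dim=1)  # add both reversed orders

print(scans.shape)  # torch.Size([1, 4, 2, 9]) -> four directional sequences per channel
```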
2. Integrating ImageNet-based pretraining
The previous part introduced the basic VSS block; this part explains how Swin-UMamba uses VSS blocks and how the pretrained model is plugged into the network.
1. The Swin-UMamba encoder consists of 5 stages. The first is the **stem stage**: a convolutional layer with a 7×7 kernel, padding 3, and stride 2 for 2× downsampling, followed by instance normalization. This first stage differs from VMamba because a gradual downsampling schedule is preferred, with 2× downsampling per stage. The second stage uses a patch embedding layer with patch size 2×2, keeping the feature resolution at 1/4 of the original image, which matches the embedded features in VMamba. The remaining stages follow the VMamba-Tiny design: each consists of a patch merging layer for 2× downsampling plus several VSS blocks for high-level feature extraction. Unlike ViTs, no position embedding is used in Swin-UMamba because of the causal nature of VSS blocks [24].
2. The numbers of VSS blocks in stage 2 through stage 5 are {2, 2, 9, 2}. The feature dimension doubles after each stage, giving D = {48, 96, 192, 384, 768} (the stage layout is summarized in a short sketch after this list). The VSS blocks and patch merging layers are initialized with the ImageNet-pretrained weights of VMamba-Tiny, as shown in Fig. 1. Notably, the patch embedding block is not initialized with pretrained weights because its patch size and input channels differ.
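As a quick concrete check of the stage layout described above, the following sketch instantiates the stem and lists the stage configuration, using the same values as the full implementation at the end of this post:

```python
import torch
import torch.nn as nn

in_chans, stem_dim = 3, 48
# Stage 1 (stem): 7x7 conv, stride 2, padding 3 -> 2x downsampling, then instance norm
stem = nn.Sequential(
    nn.Conv2d(in_chans, stem_dim, kernel_size=7, stride=2, padding=3),
    nn.InstanceNorm2d(stem_dim, eps=1e-5, affine=True),
)
# Stage 2: 2x2 patch embedding (another 2x downsampling) -> 1/4 resolution, 96 channels.
# Stages 2-5: VSS blocks with patch merging in between; dims double each stage.
depths = [2, 2, 9, 2]          # number of VSS blocks per stage
dims = [96, 192, 384, 768]     # stage dims -> overall D = {48, 96, 192, 384, 768}

x = torch.randn(1, in_chans, 256, 256)
print(stem(x).shape)  # torch.Size([1, 48, 128, 128]) -> 1/2 resolution after the stem
```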
Because the pretrained weights are loaded module by module rather than as a whole-network checkpoint, following the code takes noticeably more effort than usual.
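What this per-module loading looks like in practice is sketched below. It is condensed from the (commented-out) load_pretrained_ckpt function near the end of this post; the checkpoint path and key names are taken from that function, so treat them as assumptions if your checkpoint layout differs.

```python
import re
import torch

def load_vmamba_tiny_into_encoder(model, ckpt_path="./data/pretrained/vmamba/vmamba_tiny_e292.pth"):
    """Copy matching VMamba-Tiny weights into model.vssm_encoder, module by module."""
    skip_prefixes = ("patch_embed", "norm.", "head.")  # patch embed / final norm / classifier head are not reused
    ckpt = torch.load(ckpt_path, map_location="cpu")["model"]
    state = model.state_dict()
    for k, v in ckpt.items():
        if k.startswith(skip_prefixes):
            continue
        kr = f"vssm_encoder.{k}"
        # VMamba stores downsampling inside each stage; this encoder keeps it in a separate list
        kr = re.sub(r"layers\.(\d+)\.downsample", r"downsamples.\1", kr)
        if kr in state and state[kr].shape == v.shape:
            state[kr] = v
    model.load_state_dict(state)
    return model
```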
3. Swin-UMamba decoder
Part 1 above covered the VSS block and part 2 covered the encoder and the use of pretrained weights; this part describes the decoder.
1. Extra convolution blocks with residual connections process the skip-connection features. The authors implement these with UnetResBlock from the MONAI framework; I have not fully worked through that class yet (a rough sketch of what it does follows this list).
2. An extra segmentation head at each scale is used for deep supervision.
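For intuition, MONAI's UnetResBlock is essentially a two-layer conv-norm-activation block with a residual shortcut (a 1×1 projection when channels or stride change). The sketch below is a simplified stand-in rather than MONAI's exact code, together with the per-scale 1×1 segmentation heads used for deep supervision:

```python
import torch
import torch.nn as nn

class ResConvBlock(nn.Module):
    """Simplified stand-in for monai's UnetResBlock: conv-norm-act x2 + residual shortcut."""
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        self.norm1 = nn.InstanceNorm2d(out_ch, affine=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm2 = nn.InstanceNorm2d(out_ch, affine=True)
        self.act = nn.LeakyReLU(0.01, inplace=True)
        self.proj = None  # 1x1 projection on the shortcut when channels or resolution change
        if in_ch != out_ch or stride != 1:
            self.proj = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride),
                nn.InstanceNorm2d(out_ch, affine=True),
            )

    def forward(self, x):
        res = x if self.proj is None else self.proj(x)
        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        return self.act(out + res)

# Deep supervision: one 1x1 conv segmentation head per decoder scale (K classes, here K=4)
heads = nn.ModuleList([nn.Conv2d(c, 4, kernel_size=1) for c in (48, 96, 192, 384)])
print(ResConvBlock(96, 48)(torch.randn(1, 96, 64, 64)).shape)  # torch.Size([1, 48, 64, 64])
```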
4. Swin-UMamba†: Swin-UMamba with a Mamba-based decoder
Parts 1-3 covered the VSS block, the encoder with pretrained weights, and the decoder; this part introduces the Swin-UMamba† variant.
1. Following VMamba [24], a 4×4 patch embedding layer directly projects the input image from H×W×C to a feature map of shape H/4 × W/4 × 96. Notably, the last patch expanding block in Swin-UMamba† is a 4× upsampling operation that mirrors the 4× patch embedding layer; the remaining patch expanding layers perform 2× upsampling (one possible implementation of these expanding layers is sketched after this paragraph). The skip connections from the input image and from the 2×-downsampled features in Swin-UMamba are removed because they have no corresponding decoder blocks. In addition, deep supervision is applied at the 1×, 1/4×, 1/8×, and 1/16× resolutions, with an extra segmentation head (a 1×1 convolution mapping the high-dimensional features to K classes) at each scale. With all these modifications, the number of network parameters drops from 40M to 27M, and FLOPs on the AbdomenMRI dataset drop from 58.4G to 15.0G.
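The 2× and 4× patch expanding layers are not spelled out in this post, so here is one common way to implement them (the Swin-UNet style: a linear channel expansion followed by a pixel-shuffle-like rearrangement). Whether Swin-UMamba† uses exactly this design is an assumption on my part, but it reproduces the upsampling factors described above:

```python
import torch
import torch.nn as nn
from einops import rearrange

class PatchExpand2D(nn.Module):
    """Swin-UNet-style patch expanding: expand channels, then rearrange them into space."""
    def __init__(self, dim, scale=2):
        super().__init__()
        self.scale = scale
        self.out_dim = dim // 2 if scale == 2 else dim  # 2x halves C; the final 4x block keeps C
        self.expand = nn.Linear(dim, scale * scale * self.out_dim, bias=False)
        self.norm = nn.LayerNorm(self.out_dim)

    def forward(self, x):                               # x: (B, H, W, C)
        x = self.expand(x)                              # (B, H, W, scale*scale*out_dim)
        x = rearrange(x, "b h w (p1 p2 c) -> b (h p1) (w p2) c",
                      p1=self.scale, p2=self.scale, c=self.out_dim)
        return self.norm(x)

x = torch.randn(1, 8, 8, 96)
print(PatchExpand2D(96, scale=2)(x).shape)  # torch.Size([1, 16, 16, 48])
print(PatchExpand2D(96, scale=4)(x).shape)  # torch.Size([1, 32, 32, 96])
```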
Experiments
The performance and scalability of Swin-UMamba are evaluated on three different medical image segmentation datasets covering organ segmentation, instrument segmentation, and cell segmentation. The datasets span a range of resolutions and imaging modalities, giving insight into the model's efficacy and adaptability across medical imaging scenarios.
The paper stresses that the primary goal is to evaluate the impact of pretrained models on medical image segmentation, not merely to chase state-of-the-art (SOTA) performance.
1. As shown in Tables 1, 2, and 3, Swin-UMamba and Swin-UMamba† deliver significant improvements over the baseline methods on all three datasets: AbdomenMRI (abdominal MRI organ segmentation), Endoscopy (surgical instrument segmentation), and Microscopy (cell segmentation).
2. The authors find that ImageNet pretraining matters a great deal for both Swin-UMamba and Swin-UMamba†, bringing higher segmentation accuracy, more stable convergence, less overfitting, better data efficiency, and lower computational overhead. For example, compared with training from scratch, ImageNet pretraining improves Swin-UMamba's DSC on Endoscopy by a substantial 13.08%, likely because the Endoscopy dataset is small and the model is therefore more prone to overfitting.
3. Swin-UMamba converges on AbdomenMRI with only about 1/10 of the training iterations needed by the baseline methods, which saves compute during training.
4. ImageNet pretraining can also be critical for convergence stability: without it, **Swin-UMamba†** struggles to converge properly on AbdomenMRI, whereas with it **Swin-UMamba†** surpasses all baselines using fewer parameters and FLOPs, further underscoring the importance of pretraining for Mamba-based models.
Code practice
- Environment setup
conda create -n your_env_name python=3.10.13
conda activate your_env_name
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
pip install causal-conv1d==1.1.1
pip install mamba-ssm
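After installing, here is a quick sanity check (my own snippet, not from the repo) that the CUDA kernels import and a small Mamba block runs on the GPU; both causal-conv1d and mamba-ssm are CUDA-only, so this needs a working GPU setup:

```python
import torch
from causal_conv1d import causal_conv1d_fn                      # CUDA kernel from causal-conv1d
from mamba_ssm import Mamba
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())

# Tiny forward pass through a Mamba block as a smoke test
m = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).cuda()
y = m(torch.randn(2, 32, 64, device="cuda"))
print(y.shape)  # torch.Size([2, 32, 64])
```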
- Full model code
import re
import time
import math
import numpy as np
from functools import partial
from typing import Optional, Union, Type, List, Tuple, Callable, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock
class PatchEmbed2D(nn.Module):
r""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
super().__init__()
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = self.proj(x).permute(0, 2, 3, 1)
if self.norm is not None:
x = self.norm(x)
return x
class PatchMerging2D(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
B, H, W, C = x.shape
SHAPE_FIX = [-1, -1]
if (W % 2 != 0) or (H % 2 != 0):
print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
SHAPE_FIX[0] = H // 2
SHAPE_FIX[1] = W // 2
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
if SHAPE_FIX[0] > 0:
x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
        x = x.view(B, H//2, W//2, 4 * C)  # B H/2 W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class SS2D(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=3,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
dropout=0.,
conv_bias=True,
bias=False,
device=None,
dtype=None,
**kwargs,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv2d = nn.Conv2d(
in_channels=self.d_inner,
out_channels=self.d_inner,
groups=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
self.act = nn.SiLU()
self.x_proj = (
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
)
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
del self.x_proj
self.dt_projs = (
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
)
self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
del self.dt_projs
self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)
self.selective_scan = selective_scan_fn
self.out_norm = nn.LayerNorm(self.d_inner)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else None
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
dt_proj.bias._no_reinit = True
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
# S4D real initialization
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
if copies > 1:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=1, device=None, merge=True):
# D "skip" parameter
D = torch.ones(d_inner, device=device)
if copies > 1:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
def forward_core(self, x: torch.Tensor):
B, C, H, W = x.shape
L = H * W
K = 4
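        # Cross-scan: build four 1D sequences from the 2D map --
        # [0] row-major, [1] column-major, [2] reversed row-major, [3] reversed column-major.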
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds, z=None,
delta_bias=dt_projs_bias,
delta_softplus=True,
return_last_state=False,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
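        # Undo the scan orders so all four outputs are aligned row-major:
        # flip the two reversed sequences back, then transpose the column-major results.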
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
def forward(self, x: torch.Tensor, **kwargs):
B, H, W, C = x.shape
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
x = x.permute(0, 3, 1, 2).contiguous()
x = self.act(self.conv2d(x)) # (b, d, h, w)
y1, y2, y3, y4 = self.forward_core(x)
assert y1.dtype == torch.float32
y = y1 + y2 + y3 + y4
y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
y = self.out_norm(y)
y = y * F.silu(z)
out = self.out_proj(y)
if self.dropout is not None:
out = self.dropout(out)
return out
class VSSBlock(nn.Module):
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
attn_drop_rate: float = 0,
d_state: int = 16,
**kwargs,
):
super().__init__()
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
self.drop_path = DropPath(drop_path)
def forward(self, input: torch.Tensor):
x = input + self.drop_path(self.self_attention(self.ln_1(input)))
return x
class VSSLayer(nn.Module):
""" A basic layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
d_state=16,
**kwargs,
):
super().__init__()
self.dim = dim
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList([
VSSBlock(
hidden_dim=dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
attn_drop_rate=attn_drop,
d_state=d_state,
)
for i in range(depth)])
        if True:  # is this really applied? Yes, but it is overridden later by VSSMEncoder._init_weights
def _init_weights(module: nn.Module):
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
p = p.clone().detach_() # fake init, just to keep the seed ....
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
self.apply(_init_weights)
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class VSSMEncoder(nn.Module):
def __init__(self, patch_size=4, in_chans=3, depths=[2, 2, 9, 2],
dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,
norm_layer=nn.LayerNorm, patch_norm=True,
use_checkpoint=False, **kwargs):
super().__init__()
self.num_layers = len(depths)
if isinstance(dims, int):
dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
self.embed_dim = dims[0]
self.num_features = dims[-1]
self.dims = dims
self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
norm_layer=norm_layer if patch_norm else None)
# WASTED absolute position embedding ======================
self.ape = False
if self.ape:
self.patches_resolution = self.patch_embed.patches_resolution
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
self.layers = nn.ModuleList()
self.downsamples = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = VSSLayer(
dim=dims[i_layer],
depth=depths[i_layer],
d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
)
self.layers.append(layer)
if i_layer < self.num_layers - 1:
self.downsamples.append(PatchMerging2D(dim=dims[i_layer], norm_layer=norm_layer))
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module):
"""
        out_proj.weight, which was initialized above in VSSBlock, gets re-initialized here by the nn.Linear branch;
        no fc.weight is found in any of the model parameters;
        no nn.Embedding is found in any of the model parameters;
        so the VSSBlock initialization is effectively a no-op.
        Conv2d is not initialized!
"""
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward(self, x):
x_ret = []
x_ret.append(x)
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for s, layer in enumerate(self.layers):
x = layer(x)
x_ret.append(x.permute(0, 3, 1, 2))
if s < len(self.downsamples):
x = self.downsamples[s](x)
return x_ret
class SwinUMamba(nn.Module):
def __init__(
self,
in_chans=3,
out_chans=1,
feat_size=[48, 96, 192, 384],
drop_path_rate=0,
layer_scale_init_value=1e-6,
hidden_size: int = 768,
norm_name = "instance",
res_block: bool = True,
spatial_dims=2,
deep_supervision: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.in_chans = in_chans
self.out_chans = out_chans
self.drop_path_rate = drop_path_rate
self.feat_size = feat_size
self.layer_scale_init_value = layer_scale_init_value
self.stem = nn.Sequential(
nn.Conv2d(in_chans, feat_size[0], kernel_size=7, stride=2, padding=3),
nn.InstanceNorm2d(feat_size[0], eps=1e-5, affine=True),
)
self.spatial_dims = spatial_dims
self.vssm_encoder = VSSMEncoder(patch_size=2, in_chans=feat_size[0])
self.encoder1 = UnetrBasicBlock(
spatial_dims=spatial_dims,
in_channels=self.in_chans,
out_channels=self.feat_size[0],
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder2 = UnetrBasicBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[0],
out_channels=self.feat_size[1],
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder3 = UnetrBasicBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[1],
out_channels=self.feat_size[2],
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder4 = UnetrBasicBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[2],
out_channels=self.feat_size[3],
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder5 = UnetrBasicBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[3],
out_channels=self.hidden_size,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.decoder5 = UnetrUpBlock(
spatial_dims=spatial_dims,
in_channels=self.hidden_size,
out_channels=self.feat_size[3],
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder4 = UnetrUpBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[3],
out_channels=self.feat_size[2],
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder3 = UnetrUpBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[2],
out_channels=self.feat_size[1],
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder2 = UnetrUpBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[1],
out_channels=self.feat_size[0],
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder1 = UnetrBasicBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[0],
out_channels=self.feat_size[0],
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
# deep supervision support
self.deep_supervision = deep_supervision
self.out_layers = nn.ModuleList()
for i in range(4):
self.out_layers.append(UnetOutBlock(
spatial_dims=spatial_dims,
in_channels=self.feat_size[i],
out_channels=self.out_chans
))
def forward(self, x_in):
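        # x_in: (B, in_chans, H, W); the stem gives (B, 48, H/2, W/2).
        # vssm_encoder returns [48@1/2, 96@1/4, 192@1/8, 384@1/16, 768@1/32];
        # encoder1-5 re-embed x_in and the first four of these as skip features,
        # then decoder5..decoder1 upsample back with skip connections.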
x1 = self.stem(x_in)
vss_outs = self.vssm_encoder(x1)
enc1 = self.encoder1(x_in)
enc2 = self.encoder2(vss_outs[0])
enc3 = self.encoder3(vss_outs[1])
enc4 = self.encoder4(vss_outs[2])
enc_hidden = self.encoder5(vss_outs[3])
dec3 = self.decoder5(enc_hidden, enc4)
dec2 = self.decoder4(dec3, enc3)
dec1 = self.decoder3(dec2, enc2)
dec0 = self.decoder2(dec1, enc1)
dec_out = self.decoder1(dec0)
if self.deep_supervision:
feat_out = [dec_out, dec1, dec2, dec3]
out = []
for i in range(4):
pred = self.out_layers[i](feat_out[i])
out.append(pred)
else:
out = self.out_layers[0](dec_out)
return out
@torch.no_grad()
def freeze_encoder(self):
for name, param in self.vssm_encoder.named_parameters():
if "patch_embed" not in name:
param.requires_grad = False
@torch.no_grad()
def unfreeze_encoder(self):
for param in self.vssm_encoder.parameters():
param.requires_grad = True
# def load_pretrained_ckpt(
# model,
# ckpt_path="./data/pretrained/vmamba/vmamba_tiny_e292.pth"
# ):
# print(f"Loading weights from: {ckpt_path}")
# skip_params = ["norm.weight", "norm.bias", "head.weight", "head.bias",
# "patch_embed.proj.weight", "patch_embed.proj.bias",
# "patch_embed.norm.weight", "patch_embed.norm.weight"]
# ckpt = torch.load(ckpt_path, map_location='cpu')
# model_dict = model.state_dict()
# for k, v in ckpt['model'].items():
# if k in skip_params:
# print(f"Skipping weights: {k}")
# continue
# kr = f"vssm_encoder.{k}"
# if "downsample" in kr:
# i_ds = int(re.findall(r"layers\.(\d+)\.downsample", kr)[0])
# kr = kr.replace(f"layers.{i_ds}.downsample", f"downsamples.{i_ds}")
# assert kr in model_dict.keys()
# if kr in model_dict.keys():
# assert v.shape == model_dict[kr].shape, f"Shape mismatch: {v.shape} vs {model_dict[kr].shape}"
# model_dict[kr] = v
# else:
# print(f"Passing weights: {k}")
# model.load_state_dict(model_dict)
# return model
# def get_swin_umamba_from_plans(
# plans_manager: int,
# dataset_json: dict,
# configuration_manager: None,
# num_input_channels: int,
# deep_supervision: bool = True,
# use_pretrain: bool = True
# ):
# label_manager = plans_manager.get_label_manager(dataset_json)
# model = SwinUMamba(
# in_chans=num_input_channels,
# out_chans=3,
# feat_size=[48, 96, 192, 384],
# deep_supervision=deep_supervision,
# hidden_size=768,
# )
# if use_pretrain:
# model = load_pretrained_ckpt(model)
# return model
# Quick smoke test: selective_scan_fn is CUDA-only, so run this on a GPU
x = torch.randn(12, 3, 256, 256).cuda()
net = SwinUMamba(3, 1).cuda()
print(net(x).shape)  # expected: torch.Size([12, 1, 256, 256])
Further reading
- Read U-Mamba
Jun Ma, Feifei Li, and Bo Wang. U-Mamba: Enhancing long-range dependency for biomedical image segmentation. arXiv preprint arXiv:2401.04722, 2024.
- Read VMamba
Yue Liu, Yunjie Tian, Yuzhong Zhao, Hongtian Yu, Lingxi Xie, Yaowei Wang, Qixiang Ye, and Yunfan Liu. VMamba: Visual state space model. arXiv preprint arXiv:2401.10166, 2024.
- Read SegMamba
Zhaohu Xing, Tian Ye, Yijun Yang, Guang Liu, and Lei Zhu. SegMamba: Long-range sequential modeling Mamba for 3D medical image segmentation. arXiv preprint arXiv:2401.13560, 2024.
- Read Vision Mamba
Lianghui Zhu, Bencheng Liao, Qian Zhang, Xinlong Wang, Wenyu Liu, and Xinggang Wang. Vision Mamba: Efficient visual representation learning with bidirectional state space model. arXiv preprint arXiv:2401.09417, 2024.