UNet 是一种经典的卷积神经网络(Convolutional Neural Network, CNN)架构,专为生物医学图像分割任务设计。该模型于 2015 年由 Olaf Ronneberger 等人在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中首次提出,因其卓越的性能和简单的结构,迅速成为图像分割领域的重要模型
1. 环境搭建
1.1 安装 Python 和相关工具
-
安装 Python 3.8 及以上版本
如果尚未安装 Python,可以从 Python官网 下载并安装。确保安装时勾选“Add Python to PATH”选项。 -
安装虚拟环境管理工具
虚拟环境是管理 Python 项目依赖的好方法。可以使用venv
或conda
来创建虚拟环境。我们这里使用venv
,步骤如下:# 创建虚拟环境 python -m venv unet_env # 激活虚拟环境 source unet_env/bin/activate # Linux/Mac unet_env\Scripts\activate # Windows
1.2 安装依赖库
-
安装 PyTorch
根据你的硬件选择正确的 PyTorch 版本。如果你的电脑支持 CUDA(GPU 加速),可以使用带 CUDA 的版本,否则使用 CPU 版本:# 安装支持 CUDA 11.8 版本的 PyTorch pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 如果不支持 CUDA,则使用以下命令: pip install torch torchvision torchaudio
-
安装其他依赖
你还需要一些其他的辅助库:pip install numpy opencv-python matplotlib tqdm scikit-learn pillow
2. 下载或实现 UNet 模型
2.1 UNet 模型结构详解
UNet 是经典的图像分割网络,其主要特点是由编码器(下采样部分)和解码器(上采样部分)组成。通过跳跃连接,编码器的每一层都将特征图传递到解码器对应层,以保持细节信息。
以下是 UNet 的详细实现,包含编码器、解码器、跳跃连接以及卷积操作:
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
# 编码器部分
self.encoder1 = self.conv_block(in_channels, 64)
self.encoder2 = self.conv_block(64, 128)
self.encoder3 = self.conv_block(128, 256)
self.encoder4 = self.conv_block(256, 512)
# 底部瓶颈部分
self.bottleneck = self.conv_block(512, 1024)
# 解码器部分
self.upconv4 = self.upconv(1024, 512)
self.decoder4 = self.conv_block(1024, 512)
self.upconv3 = self.upconv(512, 256)
self.decoder3 = self.conv_block(512, 256)
self.upconv2 = self.upconv(256, 128)
self.decoder2 = self.conv_block(256, 128)
self.upconv1 = self.upconv(128, 64)
self.decoder1 = self.conv_block(128, 64)
# 输出层
self.output = nn.Conv2d(64, out_channels, kernel_size=1)
def conv_block(self, in_channels, out_channels):
"""标准的卷积模块,包含两个卷积层和ReLU激活函数"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
def upconv(self, in_channels, out_channels):
"""上采样操作,采用转置卷积(反卷积)"""
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
"""前向传播"""
# 编码器:下采样
enc1 = self.encoder1(x)
enc2 = self.encoder2(F.max_pool2d(enc1, 2))
enc3 = self.encoder3(F.max_pool2d(enc2, 2))
enc4 = self.encoder4(F.max_pool2d(enc3, 2))
# 底部瓶颈
bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))
# 解码器:上采样 + 跳跃连接
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1) # 跳跃连接
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return self.output(dec1)
3. 数据处理
3.1 数据集准备
为了训练 UNet,你需要准备一个图像分割数据集。数据集通常由原始图像(RGB 图像)和每个图像对应的标注图像(Mask)组成。
假设我们有一个目录结构:
dataset/
├── train/
│ ├── images/
│ └── masks/
├── val/
│ ├── images/
│ └── masks/
每个 images
文件夹包含训练图像,而 masks
文件夹包含对应的标注图像。
3.2 数据加载器
在 PyTorch 中,我们可以通过 Dataset
类来自定义数据加载器。以下是一个简单的 SegmentationDataset
类:
import os
import cv2
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = os.path.join(self.image_dir, self.images[idx])
mask_path = os.path.join(self.mask_dir, self.images[idx])
# 读取图像和标签
image = cv2.imread(img_path)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# 应用任何数据增强(如有)
if self.transform:
augmented = self.transform(image=image, mask=mask)
image = augmented['image']
mask = augmented['mask']
# 转换为 Tensor,通道数放到最前面
return torch.tensor(image, dtype=torch.float32).permute(2, 0, 1), torch.tensor(mask, dtype=torch.long)
4. 训练模型
4.1 定义训练过程
我们将训练一个 UNet 模型,使用交叉熵损失函数和 Adam 优化器。训练时,输入的图像和标签将通过 DataLoader 加载。
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
# 超参数
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 数据加载
train_dataset = SegmentationDataset("dataset/train/images", "dataset/train/masks")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# 初始化模型与优化器
model = UNet(in_channels=3, out_channels=2).to(DEVICE) # 假设是二分类
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
criterion = CrossEntropyLoss()
# 训练过程
for epoch in range(EPOCHS):
model.train()
for images, masks in train_loader:
images, masks = images.to(DEVICE), masks.to(DEVICE)
optimizer.zero_grad()
# 前向传播
outputs = model(images)
# 计算损失
loss = criterion(outputs, masks)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")
5. 模型推理
5.1 保存与加载模型
# 保存模型
torch.save(model.state_dict(), "unet_model.pth")
# 加载模型
model.load_state_dict(torch.load("unet_model.pth"))
model.eval()
5.2 单张图片推理
def predict(model, image_path):
image = cv2.imread(image_path)
image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
with torch.no_grad():
output = model(image)
return output.argmax(dim=1).squeeze(0).numpy()
6. 模型优化与改进
为了提高 UNet 的性能,我们可以从以下几个方面进行优化:
6.1 数据增强
在训练过程中引入数据增强技术可以提高模型的泛化能力。使用 Albumentations
库可以实现多种增强方式,例如旋转、翻转、裁剪等:
import albumentations as A
from albumentations.pytorch import ToTensorV2
transform = A.Compose([
A.Resize(256, 256), # 调整尺寸
A.HorizontalFlip(p=0.5), # 随机水平翻转
A.VerticalFlip(p=0.5), # 随机垂直翻转
A.RandomRotate90(p=0.5), # 随机旋转90度
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # 标准化
ToTensorV2() # 转为 Tensor
])
# 在数据集初始化时传入 transform
train_dataset = SegmentationDataset("dataset/train/images", "dataset/train/masks", transform=transform)
6.2 学习率调度器
动态调整学习率可以提高收敛速度。可以使用 PyTorch 提供的学习率调度器,例如 StepLR
或 ReduceLROnPlateau
:
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
for epoch in range(EPOCHS):
model.train()
epoch_loss = 0
for images, masks in train_loader:
images, masks = images.to(DEVICE), masks.to(DEVICE)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# 更新学习率
scheduler.step(epoch_loss / len(train_loader))
print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {epoch_loss / len(train_loader):.4f}")
6.3 混合精度训练
使用混合精度训练可以加速训练并减少显存使用,特别是在 GPU 上。PyTorch 提供了 torch.cuda.amp
模块来实现:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
for epoch in range(EPOCHS):
model.train()
for images, masks in train_loader:
images, masks = images.to(DEVICE), masks.to(DEVICE)
optimizer.zero_grad()
# 自动混合精度
with autocast():
outputs = model(images)
loss = criterion(outputs, masks)
# 反向传播与优化
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
6.4 Dice Loss 或 IoU Loss
交叉熵损失适合分类任务,但在分割任务中,Dice Loss 或 IoU Loss 能更好地处理类别不平衡问题:
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, preds, targets, smooth=1):
preds = torch.sigmoid(preds) # 将输出限制在 [0, 1] 之间
preds = preds.view(-1)
targets = targets.view(-1)
intersection = (preds * targets).sum()
dice = (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)
return 1 - dice
然后在训练中替换损失函数:
criterion = DiceLoss()
6.5 模型改进:加入注意力机制
可以在 UNet 的跳跃连接中加入注意力机制(如 Squeeze-and-Excitation 或 Attention Gates),以提升模型对目标区域的关注能力。
以下是一个基于 SE 模块的示例:
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction=16):
super(SEBlock, self).__init__()
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction),
nn.ReLU(inplace=True),
nn.Linear(in_channels // reduction, in_channels),
nn.Sigmoid()
)
def forward(self, x):
batch, channels, _, _ = x.size()
y = self.global_pool(x).view(batch, channels)
y = self.fc(y).view(batch, channels, 1, 1)
return x * y
将 SEBlock 插入 UNet 的编码器和解码器中。
7. 模型评估
为了评估模型性能,通常需要计算一些分割任务的指标,例如:
- 像素精度 (Pixel Accuracy)
- IoU (Intersection over Union)
- Dice 系数
以下是计算 IoU 和 Dice 系数的代码:
def compute_metrics(preds, labels):
preds = preds > 0.5 # 阈值化
intersection = (preds & labels).sum()
union = (preds | labels).sum()
iou = intersection / union
dice = (2 * intersection) / (preds.sum() + labels.sum())
return iou, dice
在验证集上运行:
model.eval()
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(DEVICE), masks.to(DEVICE)
outputs = model(images)
preds = torch.sigmoid(outputs) > 0.5 # 二值化预测
iou, dice = compute_metrics(preds.cpu(), masks.cpu())
print(f"IoU: {iou:.4f}, Dice: {dice:.4f}")
8. 部署与推理加速
8.1 导出 ONNX
将模型导出为 ONNX 格式以便在推理加速框架中使用:
dummy_input = torch.randn(1, 3, 256, 256).to(DEVICE)
torch.onnx.export(model, dummy_input, "unet_model.onnx", opset_version=11)
8.2 使用 TensorRT 加速推理
可以使用 NVIDIA TensorRT 对 ONNX 模型进行优化并加速推理。具体操作请参考 TensorRT 文档。