【深度学习】 UNet详解

UNet 是一种经典的卷积神经网络(Convolutional Neural Network, CNN)架构,专为生物医学图像分割任务设计。该模型于 2015 年由 Olaf Ronneberger 等人在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中首次提出,因其卓越的性能和简单的结构,迅速成为图像分割领域的重要模型


1. 环境搭建

1.1 安装 Python 和相关工具

  1. 安装 Python 3.8 及以上版本
    如果尚未安装 Python,可以从 Python官网 下载并安装。确保安装时勾选“Add Python to PATH”选项。

  2. 安装虚拟环境管理工具
    虚拟环境是管理 Python 项目依赖的好方法。可以使用 venvconda 来创建虚拟环境。我们这里使用 venv,步骤如下:

    # 创建虚拟环境
    python -m venv unet_env
    
    # 激活虚拟环境
    source unet_env/bin/activate  # Linux/Mac
    unet_env\Scripts\activate     # Windows
    

1.2 安装依赖库

  1. 安装 PyTorch
    根据你的硬件选择正确的 PyTorch 版本。如果你的电脑支持 CUDA(GPU 加速),可以使用带 CUDA 的版本,否则使用 CPU 版本:

    # 安装支持 CUDA 11.8 版本的 PyTorch
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    
    # 如果不支持 CUDA,则使用以下命令:
    pip install torch torchvision torchaudio
    
  2. 安装其他依赖
    你还需要一些其他的辅助库:

    pip install numpy opencv-python matplotlib tqdm scikit-learn pillow
    

2. 下载或实现 UNet 模型

2.1 UNet 模型结构详解

UNet 是经典的图像分割网络,其主要特点是由编码器(下采样部分)和解码器(上采样部分)组成。通过跳跃连接,编码器的每一层都将特征图传递到解码器对应层,以保持细节信息。

以下是 UNet 的详细实现,包含编码器、解码器、跳跃连接以及卷积操作:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        
        # 编码器部分
        self.encoder1 = self.conv_block(in_channels, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        
        # 底部瓶颈部分
        self.bottleneck = self.conv_block(512, 1024)
        
        # 解码器部分
        self.upconv4 = self.upconv(1024, 512)
        self.decoder4 = self.conv_block(1024, 512)
        
        self.upconv3 = self.upconv(512, 256)
        self.decoder3 = self.conv_block(512, 256)
        
        self.upconv2 = self.upconv(256, 128)
        self.decoder2 = self.conv_block(256, 128)
        
        self.upconv1 = self.upconv(128, 64)
        self.decoder1 = self.conv_block(128, 64)
        
        # 输出层
        self.output = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """标准的卷积模块,包含两个卷积层和ReLU激活函数"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        """上采样操作,采用转置卷积(反卷积)"""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """前向传播"""
        # 编码器:下采样
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2))
        enc3 = self.encoder3(F.max_pool2d(enc2, 2))
        enc4 = self.encoder4(F.max_pool2d(enc3, 2))
        
        # 底部瓶颈
        bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))
        
        # 解码器:上采样 + 跳跃连接
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)  # 跳跃连接
        dec4 = self.decoder4(dec4)
        
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return self.output(dec1)

3. 数据处理

3.1 数据集准备

为了训练 UNet,你需要准备一个图像分割数据集。数据集通常由原始图像(RGB 图像)和每个图像对应的标注图像(Mask)组成。

假设我们有一个目录结构:

dataset/
├── train/
│   ├── images/
│   └── masks/
├── val/
│   ├── images/
│   └── masks/

每个 images 文件夹包含训练图像,而 masks 文件夹包含对应的标注图像。

3.2 数据加载器

在 PyTorch 中,我们可以通过 Dataset 类来自定义数据加载器。以下是一个简单的 SegmentationDataset 类:

import os
import cv2
import torch
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        
        # 读取图像和标签
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        # 应用任何数据增强(如有)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        
        # 转换为 Tensor,通道数放到最前面
        return torch.tensor(image, dtype=torch.float32).permute(2, 0, 1), torch.tensor(mask, dtype=torch.long)

4. 训练模型

4.1 定义训练过程

我们将训练一个 UNet 模型,使用交叉熵损失函数和 Adam 优化器。训练时,输入的图像和标签将通过 DataLoader 加载。

from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

# 超参数
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 数据加载
train_dataset = SegmentationDataset("dataset/train/images", "dataset/train/masks")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# 初始化模型与优化器
model = UNet(in_channels=3, out_channels=2).to(DEVICE)  # 假设是二分类
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
criterion = CrossEntropyLoss()

# 训练过程
for epoch in range(EPOCHS):
    model.train()
    for images, masks in train_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(images)
        
        # 计算损失
        loss = criterion(outputs, masks)
        
        # 反向传播
        loss.backward()
        
        # 更新参数
        optimizer.step()
        
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")

5. 模型推理

5.1 保存与加载模型

# 保存模型
torch.save(model.state_dict(), "unet_model.pth")

# 加载模型
model.load_state_dict(torch.load("unet_model.pth"))
model.eval()

5.2 单张图片推理

def predict(model, image_path):
    image = cv2.imread(image_path)
    image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        output = model(image)
    return output.argmax(dim=1).squeeze(0).numpy()

6. 模型优化与改进

为了提高 UNet 的性能,我们可以从以下几个方面进行优化:


6.1 数据增强

在训练过程中引入数据增强技术可以提高模型的泛化能力。使用 Albumentations 库可以实现多种增强方式,例如旋转、翻转、裁剪等:

import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.Resize(256, 256),          # 调整尺寸
    A.HorizontalFlip(p=0.5),     # 随机水平翻转
    A.VerticalFlip(p=0.5),       # 随机垂直翻转
    A.RandomRotate90(p=0.5),     # 随机旋转90度
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # 标准化
    ToTensorV2()                 # 转为 Tensor
])

# 在数据集初始化时传入 transform
train_dataset = SegmentationDataset("dataset/train/images", "dataset/train/masks", transform=transform)

6.2 学习率调度器

动态调整学习率可以提高收敛速度。可以使用 PyTorch 提供的学习率调度器,例如 StepLRReduceLROnPlateau

from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for images, masks in train_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # 更新学习率
    scheduler.step(epoch_loss / len(train_loader))
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {epoch_loss / len(train_loader):.4f}")

6.3 混合精度训练

使用混合精度训练可以加速训练并减少显存使用,特别是在 GPU 上。PyTorch 提供了 torch.cuda.amp 模块来实现:

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for epoch in range(EPOCHS):
    model.train()
    for images, masks in train_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()

        # 自动混合精度
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, masks)

        # 反向传播与优化
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

6.4 Dice Loss 或 IoU Loss

交叉熵损失适合分类任务,但在分割任务中,Dice Loss 或 IoU Loss 能更好地处理类别不平衡问题:

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, preds, targets, smooth=1):
        preds = torch.sigmoid(preds)  # 将输出限制在 [0, 1] 之间
        preds = preds.view(-1)
        targets = targets.view(-1)
        intersection = (preds * targets).sum()
        dice = (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)
        return 1 - dice

然后在训练中替换损失函数:

criterion = DiceLoss()

6.5 模型改进:加入注意力机制

可以在 UNet 的跳跃连接中加入注意力机制(如 Squeeze-and-Excitation 或 Attention Gates),以提升模型对目标区域的关注能力。

以下是一个基于 SE 模块的示例:

class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        y = self.global_pool(x).view(batch, channels)
        y = self.fc(y).view(batch, channels, 1, 1)
        return x * y

将 SEBlock 插入 UNet 的编码器和解码器中。


7. 模型评估

为了评估模型性能,通常需要计算一些分割任务的指标,例如:

  • 像素精度 (Pixel Accuracy)
  • IoU (Intersection over Union)
  • Dice 系数

以下是计算 IoU 和 Dice 系数的代码:

def compute_metrics(preds, labels):
    preds = preds > 0.5  # 阈值化
    intersection = (preds & labels).sum()
    union = (preds | labels).sum()
    iou = intersection / union
    dice = (2 * intersection) / (preds.sum() + labels.sum())
    return iou, dice

在验证集上运行:

model.eval()
with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        outputs = model(images)
        preds = torch.sigmoid(outputs) > 0.5  # 二值化预测
        
        iou, dice = compute_metrics(preds.cpu(), masks.cpu())
        print(f"IoU: {iou:.4f}, Dice: {dice:.4f}")

8. 部署与推理加速

8.1 导出 ONNX

将模型导出为 ONNX 格式以便在推理加速框架中使用:

dummy_input = torch.randn(1, 3, 256, 256).to(DEVICE)
torch.onnx.export(model, dummy_input, "unet_model.onnx", opset_version=11)

8.2 使用 TensorRT 加速推理

可以使用 NVIDIA TensorRT 对 ONNX 模型进行优化并加速推理。具体操作请参考 TensorRT 文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962158.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【视频+图文详解】HTML基础1-html和css介绍、上网原理

图文详解 html介绍 概念:html是超文本标记语言的缩写,其英文全称为HyperText Markup Language,是用来搭建网站结构的语言,比如网页上的文字,按钮,图片,视频等。html的版本分为1.0、2.0、3.0、…

VT:优化LLM推理过程的记忆与探索

📖标题:LLMs Can Plan Only If We Tell Them 🌐来源:arXiv, 2501.13545 🌟摘要 🔸大型语言模型(LLM)在自然语言处理和推理方面表现出了显著的能力,但它们在自主规划方面…

C++并发编程指南07

文章目录 [TOC]5.1 内存模型5.1.1 对象和内存位置图5.1 分解一个 struct,展示不同对象的内存位置 5.1.2 对象、内存位置和并发5.1.3 修改顺序示例代码 5.2 原子操作和原子类型5.2.1 标准原子类型标准库中的原子类型特殊的原子类型备选名称内存顺序参数 5.2.2 std::a…

日志收集Day007

1.配置ES集群TLS认证: (1)elk101节点生成证书文件 cd /usr/share/elasticsearch ./bin/elasticsearch-certutil cert -out config/elastic-certificates.p12 -pass "" --days 3650 (2)elk101节点为证书文件修改属主和属组 chown elasticsearch:elasticsearch con…

AJAX综合案例——图书管理

黑马程序员视频地址: AJAX-Day02-10.案例_图书管理AJAX-Day02-10.案例_图书管理_总结_V1.0是黑马程序员前端AJAX入门到实战全套教程,包含学前端框架必会的(ajaxnode.jswebpackgit),一套全覆盖的第25集视频&#xff0c…

Linux_线程同步生产者消费者模型

同步的相关概念 同步:在保证数据安全的前提下,让线程能够按照某种特定的顺序访问临界资源,从而有效避免饥饿问题,叫做同步竞态条件:因为时序问题,而导致程序异常,我们称之为竞态条件。 同步的…

Qt u盘自动升级软件

Qt u盘自动升级软件 Chapter1 Qt u盘自动升级软件u盘自动升级软件思路:step1. 获取U盘 判断U盘名字是否正确, 升级文件是否存在。step2. 升级step3. 升级界面 Chapter2 Qt 嵌入式设备应用程序,通过U盘升级的一种思路Chapter3 在开发板上运行的…

拦截器快速入门及详解

拦截器Interceptor 快速入门 什么是拦截器? 是一种动态拦截方法调用的机制,类似于过滤器。 拦截器是Spring框架中提供的,用来动态拦截控制器方法的执行。 拦截器的作用:拦截请求,在指定方法调用前后,根…

信息安全专业优秀毕业设计选题汇总:热点选题

目录 前言 毕设选题 开题指导建议 更多精选选题 选题帮助 最后 前言 大家好,这里是海浪学长毕设专题! 大四是整个大学期间最忙碌的时光,一边要忙着准备考研、考公、考教资或者实习为毕业后面临的升学就业做准备,一边要为毕业设计耗费大量精力。学长给大家整理…

Linux中使用unzip

安装命令 yum install unzip unzip常用选项和参数 选项 说明 -q 隐藏解压过程中的消息输出 -d /path/to/directory 指定解压文件的目标目录 -P password 如果.zip文件被密码保护,使用此选项可以指定打开文件所需的密码 解压命令 unzip 要解压的压缩包unz…

ThreadLocal源码解析

文章目录 一、概述二、get()方法三、set()方法四、可能导致的内存泄漏问题五、remove六、思考:为什么要将ThreadLocalMap的value设置为强引用? 一、概述 ThreadLocal是线程私有的,独立初始化的变量副本。存放在和线程进行绑定的ThreadLocalMa…

批量解密,再也没有任何限制了

有的时候我们在网上下载了PDF文档。发现没有办法进行任何的操作,就连打印权限都没有。今天给大家介绍的这个软件可以一键帮你进行PDF解密,非常方便,完全免费。 PDF智能助手 批量解密PDF文件 这个软件不是很大,只有10MB&#xff…

《LLM大语言模型+RAG实战+Langchain+ChatGLM-4+Transformer》

文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…

车载软件 --- 大一新生入门汽车零部件嵌入式开发

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活…

有效运作神经网络

内容来自https://www.bilibili.com/video/BV1FT4y1E74V,仅为本人学习所用。 文章目录 训练集、验证集、测试集偏差、方差正则化正则化参数为什么正则化可以减少过拟合Dropout正则化Inverted Dropout其他的正则化方法数据增广Early stopping 归一化梯度消失与梯度爆…

【深度优先搜索篇】走迷宫的魔法:算法如何破解迷宫的神秘密码

当你在夜晚孤军奋战时,满天星光以为你而闪烁。 欢迎拜访:羑悻的小杀马特.-CSDN博客 本篇主题:轻轻松松拿捏洛谷走迷宫问题 制作日期:2024.12.31 隶属专栏:C/C题海汇总 首先我…

SQL进阶实战技巧:如何分析浏览到下单各步骤转化率及流失用户数?

目录 0 问题描述 1 数据准备 2 问题分析 3 问题拓展 3.1 跳出率计算 3.2 计算从浏览商品到支付订单的不同路径的用户数,并按照用户数降序排列。 往期精彩 0 问题描述 统计从浏览商品到最终下单的各个步骤的用户数和流失用户数,并计算转化率 用户表结构和…

Autosar-Os是怎么运行的?(内存保护)

写在前面: 入行一段时间了,基于个人理解整理一些东西,如有错误,欢迎各位大佬评论区指正!!! 1.功能概述 以TC397芯片为例,英飞凌芯片集成了MPU模块, MPU模块采用了硬件机…

什么是Maxscript?为什么要学习Maxscript?

MAXScript是Autodesk 3ds Max的内置脚本语言,它是一种与3dsMax对话并使3dsMax执行某些操作的编程语言。它是一种脚本语言,这意味着您不需要编译代码即可运行。通过使用一系列基于文本的命令而不是使用UI操作,您可以完成许多使用UI操作无法完成的任务。 Maxscript是一种专有…

(一)QT的简介与环境配置WIN11

目录 一、QT的概述 二、QT的下载 三、简单编程 常用快捷键 一、QT的概述 简介 Qt(发音:[kjuːt],类似“cute”)是一个跨平台的开发库,主要用于开发图形用户界面(GUI)应用程序,…