深度学习落地实战:基于UNet实现血管瘤超声图像分割

前言

大家好,我是机长

本专栏将持续收集整理市场上深度学习的相关项目,旨在为准备从事深度学习工作或相关科研活动的伙伴,储备、提升更多的实际开发经验,每个项目实例都可作为实际开发项目写入简历,且都附带完整的代码与数据集。可通过百度云盘进行获取,实现开箱即用

正在跟新中~

项目背景

(基于UNet实现血管瘤超声图像分割)

血管瘤是一种常见的血管异常疾病,尤其在婴幼儿中具有较高的发病率。超声检查作为一种无创的诊断手段,能够为临床提供血管瘤的位置、形状及累及范围等重要信息,对指导医生进一步治疗至关重要。然而,目前血管瘤病灶的分割主要依赖于专家的人工勾画,这不仅耗时耗力,还容易受到临床经验水平的影响,导致分割结果存在人为误差。因此,利用深度学习技术实现血管瘤超声图像的自动精准分割,成为了一个医学应用与研究的热点方向方向。

项目环境

  • 平台:windows 10
  • 语言环境:python 3.8
  • 编辑器:PyCharm
  • PyThorch版本:1.8

1.创建并跳转到虚拟环境

python -m venv myenv

myenv\Scripts\activate.bat

2. 虚拟环境pip命令安装其他工具包

pip install torch torchvision torchaudio

注:此处只示范安装pytorch,其他工具包安装类似,可通过运行代码查看所确实包提示进行安装

3.pycharm 运行环境配置

进入pytcharm =》点击file =》点击settings=》点击Project:...=》点击 Python Interpreter,进入如下界面

点击add =》点击Existing environment  =》 点击 ... =》选择第一步1创建虚拟环境目录myenv\Scripts\下的python.exe文件点击ok完成环境配置

数据集介绍

数据集分分训练数据集与标签数据集

                

       训练数据样式                                                               标注数据样式

训练数据获取:

私信博主获取

UNet网络介绍

U-Net网络的以其独特的U型结构著称,这种结构由编码器(Encoder)和解码器(Decoder)两大部分组成,非常适合于医学图像分割等任务。下面我将进一步解释您提到的关键点,并补充一些细节。

Encoder(编码器)
  • 结构:编码器的左半部分主要负责特征提取。它通过多个卷积层(通常是3x3的卷积核)和ReLU激活函数来逐层提取图像的特征。在每个卷积层组合之后,通常会添加一个2x2的最大池化层(Max Pooling)来降低特征图的分辨率,并增加感受野。这种下采样操作有助于捕获图像中的全局信息。
  • 作用:编码器通过逐层提取和抽象化图像特征,为后续的分割任务提供丰富的信息。
Decoder(解码器)
  • 结构:解码器的右半部分则负责将编码器提取的特征图恢复到原始图像的分辨率,以便进行像素级别的分类。它首先通过上采样(通常是转置卷积或双线性插值)来增加特征图的尺寸,然后通过特征拼接(Concatenation)将上采样后的特征图与编码器对应层级的特征图进行融合。之后,再通过卷积层和ReLU激活函数进一步处理融合后的特征图。
  • 特征拼接:与FCN(全卷积网络)的逐点相加不同,U-Net使用特征拼接来融合不同层级的特征图。这种拼接方式能够保留更多的特征信息,形成更“厚”的特征图,但同时也需要更多的显存来存储这些特征。
  • 作用:解码器通过逐步上采样和特征融合,将编码器提取的高级特征与低级特征相结合,恢复出精细的图像分割结果。
优点与挑战
  • 优点:U-Net网络结构简洁,但性能强大,特别适用于医学图像分割等任务。其独特的U型结构和特征拼接方式使得网络能够同时捕获全局和局部特征,从而实现高精度的分割。
  • 挑战:尽管U-Net在性能上表现出色,但其对显存的需求也相对较高。特别是当处理高分辨率图像或进行大规模训练时,显存消耗可能会成为限制因素。此外,网络的设计和训练也需要大量的专业知识和经验。

综上所述,U-Net网络通过其独特的编码器-解码器结构和特征拼接方式,在医学图像分割等领域取得了显著成效。然而,在实际应用中仍需注意显存消耗等挑战,并不断探索更高效的优化方法。

pytorch实现UNet网络

class double_conv2d_bn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super(double_conv2d_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class deconv2d_bn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super(deconv2d_bn, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out


class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.layer1_conv = double_conv2d_bn(1, 8)
        self.layer2_conv = double_conv2d_bn(8, 16)
        self.layer3_conv = double_conv2d_bn(16, 32)
        self.layer4_conv = double_conv2d_bn(32, 64)
        self.layer5_conv = double_conv2d_bn(64, 128)
        self.layer6_conv = double_conv2d_bn(128, 64)
        self.layer7_conv = double_conv2d_bn(64, 32)
        self.layer8_conv = double_conv2d_bn(32, 16)
        self.layer9_conv = double_conv2d_bn(16, 8)
        self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(128, 64)
        self.deconv2 = deconv2d_bn(64, 32)
        self.deconv3 = deconv2d_bn(32, 16)
        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4, 2)

        conv5 = self.layer5_conv(pool4)

        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1, conv4], dim=1)
        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2, conv3], dim=1)
        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3, conv2], dim=1)
        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4, conv1], dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp

自定义加载数据集

class LiverDataset(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None, mode='train'):
        n = len(os.listdir(root + '/images'))

        imgs = []

        if mode == 'train':
            for i in range(n):
                img = os.path.join(root, 'images',
                                   "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])
        else:
            for i in range(n):
                img = os.path.join(root, 'images',
                                   "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)

完整代码

import os

import PIL.Image as Image
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')
#module 'backend_interagg' has no attribute 'FigureCanvas',报错执行pip install matplotlib==3.5.0
import numpy as np
#关于numpy 报错时执行 pip install numpy==1.23.5
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
from tqdm import tqdm

epochs = 1
batch_size = 32
device = 'cpu'
best_model = None
best_loss = 999
save_path = './checkpoints/best_model.pkl'
directory_path = './checkpoints'

# 检查目录是否存在
if not os.path.exists(directory_path):
    # 如果目录不存在,则创建它
    os.makedirs(directory_path)


class double_conv2d_bn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super(double_conv2d_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class deconv2d_bn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super(deconv2d_bn, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out


class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.layer1_conv = double_conv2d_bn(1, 8)
        self.layer2_conv = double_conv2d_bn(8, 16)
        self.layer3_conv = double_conv2d_bn(16, 32)
        self.layer4_conv = double_conv2d_bn(32, 64)
        self.layer5_conv = double_conv2d_bn(64, 128)
        self.layer6_conv = double_conv2d_bn(128, 64)
        self.layer7_conv = double_conv2d_bn(64, 32)
        self.layer8_conv = double_conv2d_bn(32, 16)
        self.layer9_conv = double_conv2d_bn(16, 8)
        self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(128, 64)
        self.deconv2 = deconv2d_bn(64, 32)
        self.deconv3 = deconv2d_bn(32, 16)
        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4, 2)

        conv5 = self.layer5_conv(pool4)

        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1, conv4], dim=1)
        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2, conv3], dim=1)
        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3, conv2], dim=1)
        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4, conv1], dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp


class LiverDataset(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None, mode='train'):
        n = len(os.listdir(root + '/images'))

        imgs = []

        if mode == 'train':
            for i in range(n):
                img = os.path.join(root, 'images',
                                   "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])
        else:
            for i in range(n):
                img = os.path.join(root, 'images',
                                   "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)


x_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(224),
    transforms.Grayscale(num_output_channels=1)
])

y_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(224)
])

liver_dataset = LiverDataset("./data/train", transform=x_transform, target_transform=y_transform)
dataloader = torch.utils.data.DataLoader(liver_dataset, batch_size=batch_size, shuffle=True)

model = Unet()
criterion = torch.nn.BCELoss()
#criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    train_bar = tqdm(dataloader)
    for x, y in train_bar:
        optimizer.zero_grad()
        inputs = x.to(device)
        labels = y.to(device)
        outputs = model(inputs)
        labels=labels.bool().float()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print("【EPOCH: 】%s" % str(epoch + 1))
    print("训练损失为%s" % str(epoch_loss))

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model = model.state_dict()

    # 在训练结束保存最优的模型参数
    if epoch == epochs - 1:
        # 保存模型
        torch.save(best_model, save_path)

print('Finished Training')

plt.figure('测试一张图片')
pil_img = Image.open('./data/train/images/1.png')
np_img = np.array(pil_img)
plt.imshow(np_img)
plt.show()

plt.figure('测试一张蒙版图片')
pil_img = Image.open('./data/train/mask/1.png')
np_img = np.array(pil_img)
plt.imshow(np_img)
plt.show()

开启训练

训练结果展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/802374.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

cpp 强制转换

一、static_cast static_cast 是 C 中的一个类型转换操作符&#xff0c;用于在类的层次结构中进行安全的向上转换&#xff08;从派生类到基类&#xff09;或进行不需要运行时类型检查的转换。它主要用于基本数据类型之间的转换、对象指针或引用的向上转换&#xff08;即从派生…

【Redis】集群

文章目录 一、集群是什么&#xff1f;二、 Redis集群分布式存储为什么redis集群的最大槽数是16384&#xff08;不太懂&#xff09;redis的集群主节点数量基本不可能超过1000个 三、 配置集群&#xff08;三主三从&#xff09;3.1 配置config文件3.2 启动六台redis3.2 通过redis…

铜管和铝管、铝管和铝管焊接操作介绍

一、部分品牌冰箱、空调采用铜铝管或铝铝管之间的连接方式&#xff0c;连接方式有以下两种&#xff1a; 1、洛克环&#xff1a;是方便简单的方式&#xff0c;但其需从德国采购&#xff0c;成本过于高昂而且采购周期长&#xff1b; 2、铜铝异种材料钎焊技术&#xff1a;国内可…

怎样在 PostgreSQL 中优化对大表的索引创建和维护的性能开销?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 怎样在 PostgreSQL 中优化对大表的索引创建和维护的性能开销&#xff1f;一、理解大表和索引的概念&am…

[C++]——同步异步日志系统(7)

同步异步日志系统 一、日志器管理模块&#xff08;单例模式&#xff09;1.1 对日志器管理器进行设计1.2 实现日志器管理类的各个功能1.3. 设计一个全局的日志器建造者1.4 测试日志器管理器的接口和全局建造者类 二、宏函数和全局接口设计2.1 新建一个.h,文件,文件里面放我们写的…

小欧吃苹果-OPPO 2024届校招正式批笔试题-数据开发(C卷)

在处理这个问题前&#xff0c;先看一个经典的贪心算法题目。信息学奥赛一本通&#xff08;C版&#xff09;在线评测系统http://ybt.ssoier.cn:8088/problem_show.php?pid1320 注意移动纸牌的贪心策略并不是题目中给出的移动次序&#xff1a;第1堆纸牌9<10&#xff0c;因为是…

几何相关计算

目录 一、 判断两个矩形是否相交 二、判断两条线段是否相交 三、判断点是否在多边形内 四、垂足计算 五、贝塞尔曲线 六、坐标系 一、 判断两个矩形是否相交 当矩形1的最大值比矩形2的最小值都小&#xff0c;那矩形1和矩形2一定不相交&#xff0c;其他同理。 struct Po…

【STM32】按键控制LED光敏传感器控制蜂鸣器(江科大)

一、按键控制LED LED.c #include "stm32f10x.h" // Device header/*** 函 数&#xff1a;LED初始化* 参 数&#xff1a;无* 返 回 值&#xff1a;无*/ void LED_Init(void) {/*开启时钟*/RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA, ENAB…

醇香之旅:探索红酒的无穷魅力

在浩渺的饮品世界里&#xff0c;红酒如同一颗璀璨的星辰&#xff0c;闪烁着诱人的光芒。它以其不同的醇香和深邃的韵味&#xff0c;吸引着无数人的目光。今天&#xff0c;就让我们一起踏上这场醇香之旅&#xff0c;探索雷盛红酒所带来的无穷魅力。 一、初识红酒的醇香 当我们…

去除重复字母

题目链接 去除重复字母 题目描述 注意点 s 由小写英文字母组成1 < s.length < 10^4需保证 返回结果的字典序最小&#xff08;要求不能打乱其他字符的相对位置&#xff09; 解答思路 本题与移掉 K 位数字类似&#xff0c;需要注意的是&#xff0c;并不是每个字母都能…

张量分解(4)——SVD奇异值分解

&#x1f345; 写在前面 &#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;这里是hyk写算法了吗&#xff0c;一枚致力于学习算法和人工智能领域的小菜鸟。 &#x1f50e;个人主页&#xff1a;主页链接&#xff08;欢迎各位大佬光临指导&#xff09; ⭐️近…

01 机器学习概述

目录 1. 基本概念 2. 机器学习三要素 3. 参数估计的四个方法 3.1 经验风险最小化 3.2 结构风险最小化 3.3 最大似然估计 3.4 最大后验估计 4. 偏差-方差分解 5. 机器学习算法的类型 6. 数据的特征表示 7. 评价指标 1. 基本概念 机器学习&#xff08;Machine Le…

【python】PyQt5的窗口界面的各种交互逻辑实现,轻松掌控图形化界面程序

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

C# modbus 图表

控件&#xff1a;chart1(图表)&#xff0c;cartesianChart1(第三方添加图表)&#xff0c;timer(时间) 添加第三方&#xff1a; 效果&#xff1a;图标会根据连接的温度&#xff0c;湿度用timer时间进行改变 Chart1控件样式&#xff1a;Series添加线条&#xff0c;颜色&#xf…

【算法】LRU缓存

难度&#xff1a;中等 题目&#xff1a; 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;…

2024牛客暑期多校训练营1 A题 解题思路

前言&#xff1a; 今年和队友报了牛客暑期多校比赛&#xff0c;写了一下午结果除了签到题之外只写出了一道题&#xff08;A&#xff09;&#xff0c;签到题没什么好说的&#xff0c;其他题我也没什么好说的&#xff08;太菜了&#xff0c;根本写不出来&#xff09;&#xff0c;…

SAP ABAP性能优化

1.前言 ABAP作为SAP的专用的开发语言&#xff0c;衡量其性能的指标主要有以下两个方面&#xff1a; 响应时间&#xff1a;对于某项特定的业务请求&#xff0c;系统在收到请求后需要多久返回结果 吞吐量&#xff1a;在给定的时间能&#xff0c;系统能够处理的数据量 2. ABAP语…

FFMPEG录屏入门指南【转载】

文章非原创&#xff0c;为防失联而转载&#xff1a;【原创】FFMPEG录屏入门指南 - 博客园 (cnblogs.com) 【原创】FFMPEG录屏入门指南 最近部门内部在做技术分享交流&#xff0c;需要将内容录制成视频存档。很自然的想到了去网上找一些录屏的软件&#xff0c;试过了几款诸如屏幕…

昇思25天学习打卡营第13天|CycleGAN 图像风格迁移互换全流程解析

目录 数据集下载和加载 可视化 构建生成器 构建判别器 优化器和损失函数 前向计算 计算梯度和反向传播 模型训练 模型推理 数据集下载和加载 使用 download 接口下载数据集&#xff0c;并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install dow…

LabVIEW设备检修信息管理系统

开发了基于LabVIEW设计平台开发的设备检修信息管理系统。该系统应用于各种设备的检修基地&#xff0c;通过与基地管理信息系统的连接和数据交换&#xff0c;实现了本地检修工位数据的远程自动化管理&#xff0c;提高了设备的检修效率和安全性。 项目背景 现代设备运维过程中信…