图像分割——U-Net论文介绍+代码(PyTorch)

0、概要

原理大致介绍了一下,后续会不断精进改的更加详细,然后就是代码可以对自己的数据集进行一个训练,还会不断完善,相应其他代码可以私信我。

一、论文内容总结

摘要:人们普遍认为,深度网络成功需要数千样本,在本文中,提出一种网络和训练方法,它使用大量数据增强来有效使用现存的样本,我们的体系结构由一个捕获上下文的收缩路径和能够实现精确定位的对称扩展路径组成。我们证明出这个网络可以使用少量图像进行端到端训练,并且在ISBI挑战赛上优先于先前的最佳方法(滑动窗口卷积)。并且我们的网络速度很快。

1介绍

        目前卷积神经网络的具体用途是用在分类任务上,其中对图像的输出是一个单一的类标签。然而,在许多视觉任务中,特别是在生物医学图像处理中,所期望的输出应该包括定位(每个像素都应该分配一个类标签),另外,医学图像数目不是很多。因此,ciresan等人,在一个滑动窗口设置中训练一个网络,通过在每个像素周围提供一个局部区域(补丁)来预测每个像素的类标签,这个网络可以本地化,并且在当时效果还可以,但是这个网络的也有缺陷,很慢,每个网络必须在每个补丁单独运行,而且由于重叠的补丁,会有很多多余的预测,并且补丁的大小,也决定了预测的这个像素点所结合的上下文或者说是感受野的大小,而这个补丁不能太大也不能太小。所以这就是这个网络所存在问题。

        而我们的网络,建立了一个更好的网络,所谓的全卷积网络,我们修改和扩展了这种体系结构,使它可以在很少的训练图像下,产生更精确的分割,网络结构如下图所示。

        主要思想如下(1)编码器-解码器架构(Encoder-Decoder Structure):U-Net采用了经典的编码器-解码器设计。编码器部分通过一系列的卷积和池化操作对输入图像进行下采样,目的是提取出越来越抽象的特征表示。解码器部分则通过上采样操作(例如转置卷积)逐步将这些特征映射回原始输入的空间维度,以便进行像素级别的预测。

(2)跳跃连接:它允许将编码器路径中的特征图与相应解码器层的特征进行合并。具体来说,在每个上采样步骤之后,会将对应编码器层的输出与解码器的输出拼接在一起。这样做的目的是保留局部的精细结构信息,有助于恢复分割结果中的细节,因为编码器的早期层包含更多空间信息但语义信息较少。

(3)对称性:U-Net的结构在视觉上呈现为“U”形,体现了其编码器和解码器的对称性。这种对称不仅体现在网络结构上,也反映在处理图像信息的方式中,从特征提取到细节恢复的完整流程。

(4)端到端学习与像素级预测:U-Net能够直接在每个像素上进行类别预测,实现了端到端的学习,这对于图像分割任务尤为重要。网络的输出与输入图像大小相同,每个像素都有一个类别标签,适用于精确的图像分割任务。

(5)轻量级和高效性

2、网络结构

从上图能很清晰的清楚结构,结果十分简单。

3、训练

利用输入图像及其相应的分割图,利用随机梯度下降来训练网络,由于当时还未有填充的卷积,因此输出图像比输入图像小了一个恒定的边界宽度。后边的一些解释大家可以代码过程,这里介绍起来不是很清楚。

4、数据增强

当只有少量的训练样本可用的时候,数据增强对于教会网络所需的不变性和鲁棒性是非常重要的,对于显微镜图像,我们主要需要位移和旋转不变性,以及对变形和灰度值变化的鲁棒性,而训练样本的随机弹性变形时训练一个很少标注图像的分割网络的观念概念。因此我们采用了相应的方法进行了数据增强。

二、代码结构+解释

一、工程文件中有一个文件夹叫model,里面含有两个文件夹,一个是unet_model,另一个是unet_parts,这两个用来定义模型结构。

(1)unet_parts.py  主要包含常用的一些块

""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入torch相关库

class DoubleConv(nn.Module): # 继承pytorch中的nn.Moudle类,该类用于构建神经网络中的双卷积块,利用两次连续的卷积操作增强特征表示能力,
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):  # 初始化参数,设置输入特征图参数和输出而整体参数
        super().__init__() # 调用父类的初始化方法,继承父类必要步骤
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # padding=1来保持输出尺寸和输入相同
            nn.BatchNorm2d(out_channels),  # 批量归一化层(BN),加速训练过程,提高模型的稳定性和泛化能力。这里针对的是 out_channels 个通道。
            nn.ReLU(inplace=True), # 应用ReLU激活函数,非线性地增加网络的表达能力
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)  # 返回处理后的输出章

#  旨在通过连续的卷积和非线性提取更高级别的特征表示
class Down(nn.Module):  # 下采样模块
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # 最大池化层
            DoubleConv(in_channels, out_channels)  # 卷积或者是下采样
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):  # 上采样模块
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):  # 初始化参数,选择上采样方式(双线性插值或转置卷积)、定义内部组件。这里选用的是双线性插值来进行上采样
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)# 缩放因子为2,模式是对齐角落的选项
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels) # 用于进一步处理上采样后的特征

    def forward(self, x1, x2):  # 上采样、尺寸调整以及特征融合
        x1 = self.up(x1)  # 上采样特征图
        # input is CHW
        # 计算x1和x2在高度和宽度上的插值
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])    # 对x1进行填充


        x = torch.cat([x2, x1], dim=1) # 沿着维度进行拼接,实现特征融合
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1*1卷积核降维

    def forward(self, x):
        return self.conv(x)

(2)unet的网络结构,这里相对于原版的有一些更改的地方,代码也很简单

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
import torch.nn as nn
import torch.nn.functional as F
from unet_parts import *
# 导入相关库
# 定义了一个完整的U-Net模型
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

二、数据集设定以及图像增强代码

第一个文件夹主要是对数据集进行处理的一个脚本,第二个就是数据集的一个样式或者说规则,在训练过程中,主要相关的代码就是utils中的dataset.py这个脚本,主要作用是根据data路径,然后对数据集进行预处理,翻转这些操作,代码如下所示

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random

class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'Training_Images/*.jpg')) # 查找指定路径下的所有JPEG图片文件
        # 表示在data_path路径下的Training_Images文件夹中寻找扩展名为.jpg的所有文件,glob.glob函数会遍历这个路径并且返回一个包含所有匹配文件路径的列表

    def augment(self, image, flipCode):
        # 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
        flip = cv2.flip(image, flipCode)
        return flip
        
    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        # 根据image_path生成label_path
        label_path = image_path.replace('Training_Images', 'Training_Labels')
        label_path = label_path.replace('.jpg', '.png') # todo 更新标签文件的逻辑
        # 生成对应标签图像的路径

        # 读取训练图片和标签图片
        # print(image_path)
        # print(label_path)
        image = cv2.imread(image_path)  # 读进来后就是numpy数组了
        label = cv2.imread(label_path)
        image = cv2.resize(image, (512, 512))
        label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)
        # 对于label的图像处理时候,明确采用最近邻插值方法来处理尺寸变化,确保标签图像在缩放过程中类别标签不发生模糊,保持其原有的清晰界限。
        # 将数据转为单通道的图片
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # BGR转成二值图
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)

        # 处理标签,将像素值为255的改为1
        if label.max() > 1:
            label = label / 255
        # 随机进行数据增强,为2时不做处理,即
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        return image, label

    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)


if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("数据个数:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)

三、训练代码

这部分就是训练的一整个过程。大

from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
from tqdm import tqdm


def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # 加载训练集
    isbi_dataset = ISBI_Loader(data_path)
    per_epoch_num = len(isbi_dataset) / batch_size
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # 定义RMSprop算法
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # 定义Loss算法
    criterion = nn.BCEWithLogitsLoss()
    # best_loss统计,初始化为正无穷
    best_loss = float('inf')
    # 训练epochs次
    with tqdm(total=epochs*per_epoch_num) as pbar:
        for epoch in range(epochs):
            # 训练模式
            net.train()
            # 按照batch_size开始训练
            for image, label in train_loader:
                optimizer.zero_grad()
                # 将数据拷贝到device中
                image = image.to(device=device, dtype=torch.float32)
                label = label.to(device=device, dtype=torch.float32)
                # 使用网络参数,输出预测结果
                pred = net(image)
                # 计算loss
                loss = criterion(pred, label)
                # print('{}/{}:Loss/train'.format(epoch + 1, epochs), loss.item())
                # 保存loss值最小的网络参数
                if loss < best_loss:
                    best_loss = loss
                    torch.save(net.state_dict(), 'best_model.pth')
                # 更新参数
                loss.backward()
                optimizer.step()
                pbar.update(1)


if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道1,分类为1。
    net = UNet(n_channels=1, n_classes=1)  # todo edit input_channels n_classes
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 指定训练集地址,开始训练
    data_path = r"D:\新建文件夹 (3)\VOCdevkit3000\VOCdevkit\VOC2007\data" # todo 修改为你本地的数据集位置
    train_net(net, device, data_path, epochs=50, batch_size=4)

四、总结

大致可能有些粗劣的介绍了U-Net的相关原理,以及代码,给出的代码可以训练,如果有需要完整工程文件的可以私信我。有错误的地方希望批评指正,感谢感谢

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/718352.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

全面了解三大 AI 绘画:Midjourney、Stable Diffusion、DALL·E 的区别和特点

大家好&#xff0c;我是设计师阿威 在当前&#xff0c;比较流行的 AI 绘画软件主要有三个&#xff0c;分别是&#xff1a;StabilityAI 公司的 Stable Diffusion&#xff0c;OpenAI 公司的 DALLE2&#xff0c;以及更为大众所熟知的&#xff0c;Leap Motion公司创始人 David Hol…

大前端 业务架构 插件库 设计模式 属性 线程

大前端 业务架构 插件库 适配模式之(多态)协议1对多 抽象工厂模式 观察者模式 外观模式 装饰模式之参考catagory 策略模式 属性

单片机建立自己的库文件(4)

文章目录 前言一、新建自己的外设文件夹1.新建外设文件夹&#xff0c;做项目好项目文件管理2.将之前写的.c .h 文件添加到文件夹中 二、在软件中添加项目 .c文件2.1 编译工程保证没问题2. 修改项目列表下的名称 三、在软件项目中添加 .h文件路径四、实际使用测试总结 前言 提示…

性能测试、负载测试、压力测试、稳定性测试简单区分【超详细】

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 性能测试是一个总称&#xff0c;可细分为性能测试、负载测试、压力测试、稳定性测试。 性能测试…

大量用户中招,远控木马已经潜伏各类在线会议平台

从 2023 年 12 月开始&#xff0c;研究人员发现有攻击者创建虚假 Skype、Google Meet 和 Zoom 网站来进行恶意软件传播。攻击者为安卓用户投递 SpyNote 远控木马&#xff0c;为 Windows 用户投递 NjRAT 和 DCRAT 远控木马。 攻击行动概述 攻击者在单个 IP 地址上部署了所有的虚…

LabVIEW电表改装与校准仿真系统

LabVIEW开发的电表改装与校准仿真实验平台不仅简化了传统的物理实验流程&#xff0c;而且通过虚拟仿真提高了实验的效率和安全性。该平台通过模拟电表改装与校准的各个步骤&#xff0c;允许学生在没有实际硬件的情况下完成实验&#xff0c;有效地结合了理论学习和实践操作。 项…

RAG未来的出路

总有人喊RAG已死,至少看目前不现实。 持这个观点的人,大多是Long context派,老实说,这派人绝大多数不甚理解长上下文的技术实现点,就觉得反正context越长,越牛B,有点饭圈化 ,当然我并不否认长上下文对提升理解力的一些帮助,就是没大家想的那么牛B而已(说个数据,达到…

Hazelcast 分布式缓存 在Seatunnel中的使用

1、背景 最近在调研seatunnel的时候&#xff0c;发现新版的seatunnel提供了一个web服务&#xff0c;可以用于图形化的创建数据同步任务&#xff0c;然后管理任务。这里面有个日志模块&#xff0c;可以查看任务的执行状态。其中有个取读数据条数和同步数据条数。很好奇这个数据…

Playwright鼠标悬浮元素定位方法

优点&#xff1a;你把鼠标点烂&#xff0c;把它从20楼丢下去&#xff0c;元素定位就在那&#xff0c;他不动&#xff0c;我说的偶像&#xff01; F12打开浏览器的调试页面 点击源代码Sources 右侧找到事件监听器断点&#xff08;Event Listener breakpoints&#xff09;&#…

Excel 常用技巧(六)

Microsoft Excel 是微软为 Windows、macOS、Android 和 iOS 开发的电子表格软件&#xff0c;可以用来制作电子表格、完成许多复杂的数据运算&#xff0c;进行数据的分析和预测&#xff0c;并且具有强大的制作图表的功能。由于 Excel 具有十分友好的人机界面和强大的计算功能&am…

分享:大数据信用报告查询哪家好?

在现代社会&#xff0c;个人信用报告对于个人信用评估、贷款申请以及金融服务的获取至关重要。随着大数据技术的发展&#xff0c;越来越多的平台提供了便捷的大数据信用报告查询服务。那么&#xff0c;到底应该选择哪家平台来查询大数据信用报告呢?以下是一些选择标准和推荐。…

标准立项 | 给水中试基地建设导则

结合近几年在已设计、建设和运维的不同规模的给水中试基地&#xff0c;凝练建设实践中所获得的实际经验和关键指标及参数&#xff0c;编制《给水中试基地建设导则》&#xff0c;以填补标准空白&#xff0c;统一建设标准。

LabVIEW共享变量

共享变量简介 LabVIEW​为​创建​分布​式​应用使用​共享​变量​可以简化​此类​应用的编程。​ 借助​共享​变量&#xff0c;​您​可以​在​同​一个​程序​框​图​的​不同​循环​之间​或者网络上​的​不同VI之间​共享​数据。与LabVIEW中的许多​其他数据​共…

GPT-4o的视觉识别能力,将绕过所有登陆的图形验证码

知识星球&#x1f517;除了包含技术干货&#xff1a;《Java代码审计》《Web安全》《应急响应》《护网资料库》《网安面试指南》还包含了安全中常见的售前护网案例、售前方案、ppt等&#xff0c;同时也有面向学生的网络安全面试、护网面试等。 我们来看一下市面上常见的图形验证…

在Qt编写的exe或者dll中设置版本号

1.背景 在别人编写的exe或者动态库中&#xff0c;通过右键–》属性–》详细信息中&#xff0c;通常都有版本信息&#xff1a; 那我们自己编译出来的Qt程序&#xff0c;如何设置这些版本号呢&#xff1f; 2.解决方案 参考【.pro文件中设置版本等信息】&#xff0c;只要在工…

50etf期权交易规则杠杆怎么计算?

今天带你了解50etf期权交易规则杠杆怎么计算&#xff1f;近年来&#xff0c;期权交易在股票市场中变得愈发流行&#xff0c;其中50ETF期权备受关注。作为一种金融衍生品&#xff0c;50ETF期权为投资者提供了更灵活的投资方式和更多的策略选择。 50etf期权交易规则杠杆怎么计算&…

介绍并改造一个作用于Anki笔记浏览器的插件

在Anki的笔记浏览器窗口中&#xff0c;作为主体部分的表格在对获取到的笔记进行排序时&#xff0c;最多只能有一个排序字段&#xff0c;在设定笔记的排序字段后&#xff0c;没法将表格中的笔记按其他字段进行排序。要满足这个需求&#xff0c;可以使用Advanced Browser插件&…

spring框架(SSM)

Spring Framework系统架构 Spring框架是一个开源的企业级Java应用程序框架&#xff0c;它为开发Java应用程序提供了一个全方位的解决方案。Spring的核心优势在于它的分层架构&#xff0c;这使得开发者可以灵活选择使用哪些模块而无需引入不需要的依赖。下面是Spring框架的一些关…

Linux 下VS Code 弹出 快速修复,导致 BackSpace 无法删除

最近在Linux下使用VSCode&#xff0c;发现有错误的代码选中了无法删除 这个时候&#xff0c;你按BackSpace 是无法删除的&#xff0c;很恼火&#xff01; 把这些禁用了之后&#xff0c;就不会弹出这个框&#xff0c;这样可以顺利选中删除&#xff01; 感觉这个是不是vs code 插…

刷题笔记2:用位运算找“只出现一次的一个数”

1. & 和 | 的基本操作 137. 只出现一次的数字 II - 力扣&#xff08;LeetCode&#xff09; 先对位运算的操作进行复习&#xff1a; 1、>> 右移操作符 移位规则&#xff1a;⾸先右移运算分两种&#xff1a; 1. 逻辑右移&#xff1a;左边⽤0填充&#xff0c;右边丢…