基于U-Net的视网膜血管分割(Pytorch完整版)

基于 U-Net 的视网膜血管分割是一种应用深度学习的方法,特别是 U-Net 结构,用于从眼底图像中分割出视网膜血管。U-Net 是一种全卷积神经网络(FCN),通常用于图像分割任务。以下是基于 U-Net 的视网膜血管分割的内容:
框架结构:
在这里插入图片描述
代码结构:
在这里插入图片描述

U-Net分割代码:

unet_model.py

import torch.nn.functional as F
from .unet_parts import *
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)



    def forward(self, x):
        x1 = self.inc(x)

        # 在编码器下采样过程加空间注意力
        # x2 = self.down1(self.sp1(x1))
        # x3 = self.down2(self.sp2(x2))
        # x4 = self.down3(self.sp3(x3))
        # x5 = self.down4(self.sp4(x4))

        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        logits = self.outc(x)
        return logits

if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

unet_parts.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)


    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.conv(x)

trainval.py

from model.unet_model import UNet
from utils.dataset import FundusSeg_Loader

from torch import optim
import torch.nn as nn
import torch
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

train_data_path = "DRIVE/drive_train/"
valid_data_path = "DRIVE/drive_test/"
# hyperparameter-settings
N_epochs = 500
Init_lr = 0.00001

def train_net(net, device, epochs=N_epochs, batch_size=1, lr=Init_lr):
    # 加载训练集
    train_dataset = FundusSeg_Loader(train_data_path, 1)
    valid_dataset = FundusSeg_Loader(valid_data_path, 0)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)
    print('Traing images: %s' % len(train_loader))
    print('Valid  images: %s' % len(valid_loader))

    # 定义RMSprop算法
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # 定义Loss算法
    # BCEWithLogitsLoss会对predict进行sigmoid处理
    # criterion 常被用来定义损失函数,方便调换损失函数
    criterion = nn.BCEWithLogitsLoss()
    # 训练epochs次
    # 求最小值,所以初始化为正无穷
    best_loss = float('inf')
    train_loss_list = []
    val_loss_list = []
    for epoch in range(epochs):
        # 训练模式
        net.train()
        train_loss = 0
        print(f'Epoch {epoch + 1}/{epochs}')
        # SGD
        # train_loss_list = []
        # val_loss_list = []
        with tqdm(total=train_loader.__len__()) as pbar:
            for i, (image, label, filename) in enumerate(train_loader):
                optimizer.zero_grad()
                # 将数据拷贝到device中
                image = image.to(device=device, dtype=torch.float32)
                label = label.to(device=device, dtype=torch.float32)
                # 使用网络参数,输出预测结果
                pred = net(image)
                # print(pred)
                # 计算loss
                loss = criterion(pred, label)
                # print(loss)
                train_loss = train_loss + loss.item()

                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=float(loss.cpu()), epoch=epoch)
                pbar.update(1)

        train_loss_list.append(train_loss / i)
        print('Loss/train', train_loss / i)

        # Validation
        net.eval()
        val_loss = 0
        for i, (image, label, filename) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = net(image)
            loss = criterion(pred, label)
            val_loss = val_loss + loss.item()
            # net.state_dict()就是用来保存模型参数的
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(net.state_dict(), 'best_model.pth')
                print('saving model............................................')

        val_loss_list.append(val_loss / i)
        print('Loss/valid', val_loss / i)
        sys.stdout.flush()
    return val_loss_list, train_loss_list


if __name__ == "__main__":
    # 选择设备cuda
    device = torch.device('cuda')
    # 加载网络,图片单通道1,分类为1。
    net = UNet(n_channels=3, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 开始训练
    val_loss_list, train_loss_list = train_net(net, device)
    # 保存loss值到txt文件
    fileObject1 = open('train_loss.txt', 'w')
    for train_loss in train_loss_list:
        fileObject1.write(str(train_loss))
        fileObject1.write('\n')
    fileObject1.close()
    fileObject2 = open('val_loss.txt', 'w')
    for val_loss in val_loss_list:
        fileObject2.write(str(val_loss))
        fileObject2.write('\n')
    fileObject2.close()
    # 我这里迭代了5次,所以x的取值范围为(0,5),然后再将每次相对应的5损失率附在x上
    x = range(0, N_epochs)
    y1 = val_loss_list
    y2 = train_loss_list
    # 两行一列第一个
    plt.subplot(1, 1, 1)
    plt.plot(x, y1, 'r.-', label=u'val_loss')
    plt.plot(x, y2, 'g.-', label =u'train_loss')
    plt.title('loss')
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.savefig("accuracy_loss.jpg")
    plt.show()

predict.py

import numpy as np
import torch
import cv2
from model.unet_model import UNet

from utils.dataset import FundusSeg_Loader
import copy
from sklearn.metrics import roc_auc_score

model_path='./best_model.pth'
test_data_path = "DRIVE/drive_test/"
save_path='./results/'

if __name__ == "__main__":
    test_dataset = FundusSeg_Loader(test_data_path,0)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
    print('Testing images: %s' %len(test_loader))
    # 选择设备CUDA
    device = torch.device('cuda')
    # 加载网络,图片单通道,分类为1。
    net = UNet(n_channels=3, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    print(f'Loading model {model_path}')
    net.load_state_dict(torch.load(model_path, map_location=device))

    # 测试模式
    net.eval()
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    pred_list = []
    label_list = []
    for image, label, filename in test_loader:
        image = image.to(device=device, dtype=torch.float32)
        pred = net(image)
        # Normalize to [0, 1]
        pred = torch.sigmoid(pred)
        pred = np.array(pred.data.cpu()[0])[0]
        pred_list.append(pred)
        # ConfusionMAtrix
        pred_bin = copy.deepcopy(pred)
        label = np.array(label.data.cpu()[0])[0]
        label_list.append(label)
        pred_bin[pred_bin >= 0.5] = 1
        pred_bin[pred_bin < 0.5] = 0
        tp += ((pred_bin == 1) & (label == 1)).sum()
        tn += ((pred_bin == 0) & (label == 0)).sum()
        fn += ((pred_bin == 0) & (label == 1)).sum()
        fp += ((pred_bin == 1) & (label == 0)).sum()
        # 保存图片
        pred = pred * 255
        save_filename = save_path + filename[0] + '.png'
        cv2.imwrite(save_filename, pred)
        print(f'{save_filename} done!')
    # Evaluaiton Indicators
    precision = tp / (tp + fp)   # 预测为真并且正确/预测正确样本总和
    sen = tp / (tp + fn)    # 预测为真并且正确/正样本总和
    spe = tn / (tn + fp)
    acc = (tp + tn) / (tp + tn + fp + fn)
    f1score = 2 * precision * sen / (precision + sen)
    # auc computing
    pred_auc = np.stack(pred_list, axis=0)
    label_auc = np.stack(label_list, axis=0)
    auc = roc_auc_score(label_auc.reshape(-1), pred_auc.reshape(-1))
    print(f'Precision: {precision} Sen: {sen} Spe:{spe} F1-score: {f1score} Acc: {acc} AUC: {auc}')

dataset.py

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset

# import random
# from PIL import Image
# import numpy as np


class FundusSeg_Loader(Dataset):
    def __init__(self, data_path, is_train):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.tif'))
        self.labels_path = glob.glob(os.path.join(data_path, 'label/*.tif'))
        self.is_train = is_train
        print(self.imgs_path)
        print(self.labels_path)
    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        if self.is_train == 1:
            label_path = image_path.replace('image', 'label')
            label_path = label_path.replace('training', 'manual1')
        else:
            label_path = image_path.replace('image', 'label')
            label_path = label_path.replace('test.tif', 'manual1.tif')
            
        
        # 读取训练图片和标签图片
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)

        # image = np.array(image)
        # label = np.array(label)
        # label = cv2.imread(label_path)
        # image = cv2.resize(image, (600,400))
        # label = cv2.resize(label, (600,400))
        # 转为单通道的图片
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        # label = Image.fromarray(label)
        # label = label.convert("1")
        # reshape()函数可以改变数组的形状,并且原始数据不发生变化。
        image = image.transpose(2, 0, 1)
        # image = image.reshape(1, label.shape[0], label.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # 处理标签,将像素值为255的改为1
        if label.max() > 1:
            label[label > 1] = 1

        return image, label, image_path[len(image_path)-12:len(image_path)-4]

    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)

visual.py

import numpy as np
import matplotlib.pyplot as plt
import pylab as pl

from mpl_toolkits.axes_grid1.inset_locator import inset_axes
data1_loss =np.loadtxt("E:\\code\\UNet_lr00001\\train_loss.txt",dtype=str )
data2_loss = np.loadtxt("E:\\code\\UNet_lr00001\\val_loss.txt",dtype=str)
x = range(0,10)
y = data1_loss[:, 0]
x1 = range(0,10)
y1 = data2_loss[:, 0]
fig = plt.figure(figsize = (7,5))    #figsize是图片的大小`
ax1 = fig.add_subplot(1, 1, 1) # ax1是子图的名字`
pl.plot(x,y,'g-',label=u'Dense_Unet(block layer=5)')
# ‘'g‘'代表“green”,表示画出的曲线是绿色,“-”代表画的曲线是实线,可自行选择,label代表的是图例的名称,一般要在名称前面加一个u,如果名称是中文,会显示不出来,目前还不知道怎么解决。
p2 = pl.plot(x, y,'r-', label = u'train_loss')
pl.legend()
#显示图例
p3 = pl.plot(x1,y1, 'b-', label = u'val_loss')
pl.legend()
pl.xlabel(u'epoch')
pl.ylabel(u'loss')
plt.title('Compare loss for different models in training')

在这里插入图片描述
在这里插入图片描述

这种基于 U-Net 的方法已在医学图像分割领域取得了一些成功,特别是在视网膜图像处理中。通过深度学习的方法,这种技术能够更准确地提取视网膜血管,为眼科医生提供辅助诊断和治疗的信息。
如有疑问,请评论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/190746.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

什么是分布式锁?Redis实现分布式锁详解

目录 前言&#xff1a; 分布式系统买票示例 引入redis做分布式锁 引入过期时间 引入校验id 引入lua脚本 过期时间续约问题 redlock算法 小结&#xff1a; 前言&#xff1a; 在分布式系统中&#xff0c;涉及多个主机访问同一块资源&#xff0c;此时就需要锁来做互斥控制…

[论文阅读]CBAM——代码实现和讲解

CBAM 论文网址&#xff1a;CBAM 论文代码&#xff1a;CBAM 本文提出了一种卷积块注意力模块&#xff08;CBAM&#xff09;&#xff0c;它是卷积神经网络&#xff08;CNN&#xff09;的一种轻量级、高效的注意力模块。该模块沿着通道和空间两个独立维度依次推导注意力图&#x…

ehr人力资源管理系统(实际项目源码)

eHR人力资源管理系统&#xff1a;功能强大的人力资源管理工具 随着企业规模的不断扩大和业务需求的多样化&#xff0c;传统的人力资源管理模式已无法满足现代企业的需求。eHR人力资源管理系统作为一种先进的管理工具&#xff0c;能够为企业提供高效、准确、实时的人力资源管理…

GPT实战系列-GPT训练的Pretraining,SFT,Reward Modeling,RLHF

GPT实战系列-GPT训练的Pretraining&#xff0c;SFT&#xff0c;Reward Modeling&#xff0c;RLHF 文章目录 GPT实战系列-GPT训练的Pretraining&#xff0c;SFT&#xff0c;Reward Modeling&#xff0c;RLHFPretraining 预训练阶段Supervised FineTuning &#xff08;SFT&#x…

基于51单片机交通灯夜间模式+紧急模式_易懂版_(仿真+代码_报告_讲解)

J029 51单片机交通灯_易懂版__夜间紧急(仿真代码_报告_讲解&#xff09; 51单片机交通灯_易懂版_ 1 **讲解视频&#xff1a;**2 **功能要求**3 **仿真图&#xff1a;**4 **程序设计&#xff1a;**5 **设计报告**6 **资料清单&&下载链接&#xff1a;****资料下载链接&am…

超好玩C++控制台打飞机小游戏,附源码

我终于决定还是把这个放出来。 视频在这&#xff1a;https://v.youku.com/v_show/id_XNDQxMTQwNDA3Mg.html 具体信息主界面上都有写。 按空格暂停&#xff0c;建议暂停后再升级属性。 记录最高分的文件进行了加密。 有boss&#xff08;上面视频2分47秒&#xff09;。 挺好…

2023年【N1叉车司机】新版试题及N1叉车司机作业考试题库

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 N1叉车司机新版试题参考答案及N1叉车司机考试试题解析是安全生产模拟考试一点通题库老师及N1叉车司机操作证已考过的学员汇总&#xff0c;相对有效帮助N1叉车司机作业考试题库学员顺利通过考试。 1、【多选题】《中华…

Linux系统编程 day05 进程控制

Linux系统编程 day05 进程控制 1. 进程相关概念2. 创建进程3. exec函数族4. 进程回收 1. 进程相关概念 程序就是编译好的二进制文件&#xff0c;在磁盘上&#xff0c;占用磁盘空间。程序是一个静态的概念。进程即使启动了的程序&#xff0c;进程会占用系统资源&#xff0c;如内…

【服务器能干什么】二十分钟搭建一个属于自己的 RSS 服务

如果大家不想自己捣鼓,只是想尝尝鲜,可以在下面留言,我后台帮大家开几个账号玩一玩。 哔哩哔哩【高清版本可以点击去吐槽到 B 站观看】:【VPS服务器到底能干啥】信息爆炸的年代,如何甄别出优质的内容?你可能需要自建一个RSS服务!_哔哩哔哩_bilibili 前言 RSS 服务 市…

MySQL表连接

文章目录 MySQL内外连接1.内连接2.外连接&#xff08;1&#xff09;左外连接&#xff08;2)右外连接 3.简单案例 MySQL内外连接 1.内连接 内连接的SQL如下&#xff1a; SELECT ... FROM t1 INNER JOIN t2 ON 连接条件 [INNER JOIN t3 ON 连接条件] ... AND 其他条件;说明一下…

2.5 逆矩阵

一、逆矩阵的注释 假设 A A A 是一个方阵&#xff0c;其逆矩阵 A − 1 A^{-1} A−1 与它的大小相同&#xff0c; A − 1 A I A^{-1}AI A−1AI。 A A A 与 A − 1 A^{-1} A−1 会做相反的事情。它们的乘积是单位矩阵 —— 对向量无影响&#xff0c;所以 A − 1 A x x A^{…

QT基础开发笔记

用VS 写QT &#xff0c;设置exe图标的方法&#xff1a; 选定工程--》右键--》添加---》资源--》 QString 字符串用法总结说明 Qt QString 增、删、改、查、格式化等常用方法总结_qstring 格式化-CSDN博客 总结来说&#xff1a; QString 的 remove有两种用法&#xff0c;&am…

【C++】类和对象(下篇)

这里是目录 构造函数&#xff08;续&#xff09;构造函数体赋值初始化列表 explicit关键字隐式类型转换 static成员友元友元函数友元类 内部类匿名对象匿名对象的作用const引用匿名对象 构造函数&#xff08;续&#xff09; 构造函数体赋值 在创建对象时&#xff0c;编译器通…

02、Tensorflow实现手写数字识别(数字0-9)

02、Tensorflow实现手写数字识别&#xff08;数字0-9&#xff09; 开始学习机器学习啦&#xff0c;已经把吴恩达的课全部刷完了&#xff0c;现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣&#xff0c;作为入门的素材非常合适。 基于Tensorflow 2.10.0与pycharm 1…

SASS的导入文件详细教程

文章目录 前言导入SASS文件使用SASS部分文件默认变量值嵌套导入原生的CSS导入后言 前言 hello world欢迎来到前端的新世界 &#x1f61c;当前文章系列专栏&#xff1a;Sass和Less &#x1f431;‍&#x1f453;博主在前端领域还有很多知识和技术需要掌握&#xff0c;正在不断努…

电子学会C/C++编程等级考试2022年12月(二级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数组逆序重放 将一个数组中的值按逆序重新存放。例如,原来的顺序为8,6,5,4,1。要求改为1,4,5,6,8。输入 输入为两行:第一行数组中元素的个数n(1输出 输出为一行:输出逆序后数组的整数,每两个整数之间用空格分隔。样例输入 …

Linux:docker基础操作(3)

docker的介绍 Linux&#xff1a;Docker的介绍&#xff08;1&#xff09;-CSDN博客https://blog.csdn.net/w14768855/article/details/134146721?spm1001.2014.3001.5502 通过yum安装docker Linux&#xff1a;Docker-yum安装&#xff08;2&#xff09;-CSDN博客https://blog.…

Mac 最佳使用指南

官方 Mac 使用手册如何在macOS系统安装根证书mac Terminal config proxy 【mac 终端配置代理】iPhone 安装 iOS 17公测版&#xff08;Public Beta)macOS 最佳命令行客户端&#xff1a;iTermMac 配置与 Linux 互信Mac mini 外接移动硬盘无法写入或者无法显示的解决方法如何在 ma…

2016年五一杯数学建模B题能源总量控制下的城市工业企业协调发展问题解题全过程文档及程序

2016年五一杯数学建模 B题 能源总量控制下的城市工业企业协调发展问题 原题再现 能源是国民经济的重要物质基础,是工业企业发展的动力&#xff0c;但是过度的能源消耗&#xff0c;会破坏资源和环境&#xff0c;不利于经济的可持续发展。目前我国正处于经济转型的关键时期&…

牛客网刷题笔记四 链表节点k个一组翻转

NC50 链表中的节点每k个一组翻转 题目&#xff1a; 思路&#xff1a; 这种题目比较习惯现在草稿本涂涂画画链表处理过程。整体思路是赋值新的链表&#xff0c;用游离指针遍历原始链表进行翻转操作&#xff0c;当游离个数等于k时&#xff0c;就将翻转后的链表接到新的链表后&am…