VGAN实现视网膜图像血管分割(基于pytorch)

背景介绍

VGAN(Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks)出自2018年的一篇论文,尝试使用生成性对抗网络实现视网膜血管分割的任务,原论文地址:https://arxiv.org/abs/1706.09318
在github上有相应的源码仓库,不过由于版本的原因也会出现一些bug,本篇博客在复现项目的过程中也对源码进行了相应的修改,源码地址: https://github.com/guyuchao/Vessel-wgan-pytorch?tab=readme-ov-file

另一方面,刚好这个项目作为我2023年的最后一个项目,就斗胆当作是2023年编程之旅的回顾,博主是在茫茫知识海洋漂泊的一叶小舟,还有许多的知识尚未学习,希望可以和大家互相交流学习!

2024年,冲鸭!!!!!


前言

生成对抗网络(Generative Adversarial Networks)在提出的时候是为了实现模型创造性的能力,如今在AI图像生成领域已经有非常广阔的应用,例如知名的Midjourney网站,就是通过用户输入的prompt提示,利用GAN的框架生成对应用户想要生成的图片;我自己对于这个模型的名声也是早有耳闻,刚好前一段时间看到了《Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks》这篇文章,内容里探究了把GAN模型应用到视网膜血管分割的领域,刚好可以与我本学期的生物医学创新实践联系在一起。

生成对抗网络介绍

生成对抗网络(Generative Adversarial Network)简称GAN,是深度学习领域的一种重要模型,由Ian Goodfellow在2014年提出。

GAN模型包括两部分:一个是生成器(Generator),另一个是判别器(Discriminator)。这两部分模型相互博弈,共同训练,赋予网络生成特定分布的数据的能力。

1. 生成器(Generator):该部分的目标是生成尽可能真实的数据。例如,如果我们想让网络生成一张风景图片,生成器的目标就是生成一张看上去就像是某个摄影师拍摄的风景照片。

2. 判别器(Discriminator):该部分的目标是尽可能好地区分出真实的和生成的数据。在风景图片的例子中,判别器需要区分出哪些图片是真实的风景照片,哪些是生成器生成的假照片。

两者相互博弈的过程中,判别器会不断提高对真假数据的判断能力,生成器也会不断提高生成数据的逼真度,理想状态下,生成器生成的数据将和真实数据无法区分,判别器对生成器的生成结果的判断是50%,即做出了随机猜测。这样,就完成了GAN的训练过程。

视网膜分割的GAN模型(VGAN)

从上图中可以看出,模型中的GAN的generator是一个U-net形状的网络模型,每一层上采样层都与对称的下采样的输出进行连接,能够很好的处理图像的边缘及其他的细节特征;discriminator是一个多层的下采样的网络模型,最后是输出是实现一个二分类的效果,即{(0,1)}^N,接近0表示判断机器生成的(generator),接近1表示判断为真实的血管分割标签,每层的generator和discriminator都是由基本的block卷积神经网络组成,block的代码构建为:

class block(nn.Module):
    def __init__(self,in_filters,n_filters):
        super(block,self).__init__()
        self.deconv1 = nn.Sequential(
            nn.Conv2d(in_filters, n_filters, 3, stride=1, padding=1),
            nn.BatchNorm2d(n_filters),
            nn.ReLU())
    def forward(self, x):
        x=self.deconv1(x)
        return x

 是用nn.Sequential()连接的包含卷积,标准化和池化的经典卷积层,卷积核为3\times3,步长为1,边缘补充为1,处理的视网膜图片是三通道彩色图盘,第一层的intput_channel为3;下面用pytorch的tensorboard工具对搭建的generator和discriminator进行网络结构的可视化

这里需要注意的是,在原论文中,Discriminator的输入并不是Generator直接生成的图片或者原数据集中的label,而是需要在C通道上与进行分割的原视网膜图片进行合并再进行输入。

损失函数

在训练Generator和训练Discriminator时,使用不同的损失函数,我们最后是使用Generator进行mask的生成,也就是使用Generator输入需要进行视网膜血管分割的图片,输出分割的结果,所以我们更加注重Generator损失函数的设计。

GAN整体的损失函数可以定义为

对于D(Discriminator),在代码中不设计具体的损失函数,从任务设计中可以得出,当输入到D的是真实的标签图像时,我们期望D输出越接近1越好;当输入到D的是Generator生成的图像时,我们期望D输出越接近0越好,基于这种关系,我们直接把D的输出作为损失函数,同时,为了避免GAN模型中常见的断层问题,引入了经典的WGAN方法(Gradient Penalty),即获得一个1-Lipschitz函数,保证GAN模型的训练曲线是足够平滑从而生成稳定的图片,在梯度计算中引入作为梯度惩罚,则D的损失函数可以表示为:

对于G,为适应分割的任务不需要使用隐含空间的向量而是直接获取img的输入,具体地,在代码中使用二分类交叉熵损失函数获取与对应标签的loss值,同时加上来自D的反馈,具体的损失函数为:

数据集预处理

本次项目使用的数据集是经典的视网膜血管分割数据集,含有20张训练集和20张测试集,每张视网膜眼底的图片是584\times565\times3像素的三通道彩色tif格式图片,对应的标签是584\times565像素单通道灰度tif图片

对img的预处理包括随机改变图片的亮度、对比度和色相,图片像素标准化、转换成tensor数据的经典图片训练格式[B,C,H,W];

对label的预处理包括图片像素标准化、转换成tensor数据的经典图片训练格式[B,C,H,W];

同时对img和label的预处理包括随机裁剪图片高和宽为512\times512像素大小,随机水平翻转和垂直翻转;

下面是对训练数据中的进行预处理后的结果可视化

训练过程

一开始打算把本项目放到colab上跑或者服务器上跑,在此之前抱着试一试的态度先用本地的显卡1080ti加4G显存跑了50个epoch,结果竟然能跑得动!于是就先跑了300个epoch,显卡没崩,结果保存在./pth中。

对于GAN的结果,除了使用传统的评估方法外,也会对训练过程的结果进行输出可视化看看结果有没有生成奇奇怪怪的图像从而停止训练重新调整,不过对于本次项目,Generator的输入不包含隐含空间z,而且加入了WGAN进行约束,所以生成的图像基本上是比较完整的;作为视觉上直观地对GAN的效果进行评估,我们每跑50个epoch对应输出测试集生成的分割图像,保存在对应的文件夹名称里面,跑完全部epoch后,对应把D和G训练好的checkpoint也保存在./pth路径下。

300个epoch跑了一个多小时,显卡的散热吹风机感觉可以起飞了,不过结果其实可圈可点,为了不浪费训练好的资料,后面又写了一个re_train脚本,读取前一次训练好的模型权重再进行训练,又花了一个小时跑了300个epoch,结果保存在./pth2目录下,结构与./pth中文件相同,所以综合起来一共训练了600个epoch。

结果分析

 先对比每50个epoch生成的图像,选取测试集中的第一张图片

对于测试集的预处理也包括随机裁剪像素512\times512像素大小与随机水平翻转与垂直翻转,所以生成的图像包含有不同的方向和裁剪风格。

从第50个epoch到第600个epoch结果,可以明显的看出图片质量提升的效果,说明Generator的学习是非常有效果的,没有出现GAN中经常出现的图片断层效果。

对比一下对应的眼底原图和血管分割标签图:

从视觉上来看,GAN生成图像与原图像的分割标签是比较接近的。

接下来我们使用训练好的Generator模型进行量化的对比:

绘制对应的PR曲线与ROC曲线:

从两个曲线的效果可以看出,训练出来的模型在测试集上也具有比较好的效果,本次项目使用的VGAN在处理视网膜血管分割的任务上体现出比较好的性能。

参考文献

  1. Son, J., Park, S.J., & Jung, K. (2017). Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks. ArXiv, abs/1706.09318.
  2. Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein GAN. ArXiv, abs/1701.07875.

彩蛋

我们都知道GAN以创作能力而闻名,那我们试一下用上面训练好的模型接受随机初始化满足正态分布的z隐含空间的数据会输出怎么样的图像

嗯...看来想要GAN生成像样的图片,还是需要再训练机制里面下手

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/316199.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用通俗易懂的方式讲解:十分钟读懂 Stable Diffusion 运行原理

AIGC 热潮正猛烈地席卷开来,可以说 Stable Diffusion 开源发布把 AI 图像生成提高了全新高度,特别是 ControlNet 和 T2I-Adapter 控制模块的提出进一步提高生成可控性,也在逐渐改变一部分行业的生产模式。惊艳其出色表现,也不禁好…

大语言模型下载,huggingface和modelscope加速

huggingface 下载模型 如果服务器翻墙了,不用租机器 如果服务器没翻墙,可以建议使用下面的方式 可以租一台**autodl**不用显卡的机器,一小时只有1毛钱,启动学术加速,然后下载,下载完之后,用scp…

Java重修第五天—面向对象2

通过学习本篇文章可以掌握如下知识 static;设计单例;继承。 之前文章我们已经对面向对象进行了入门学习,这篇文章我们就开始深入了解面向对象设计。 static 我们定义了一个 Student类,增加姓名属性:name &#xff1…

Spring Task 任务调度工具

大家好我是苏麟 , 今天聊聊Spring Task 任务调度工具 Spring Task Spring Task 是Spring框架提供的任务调度工具,可以按照约定的时间自动执行某个代码逻辑。 定位:定时任务框架 作用:定时自动执行某段Java代码 什么是定时任务 ? 通过时…

009-Zynq基操之如何去玩转PL向PS的中断(对新手友好,走过路过千万不要错过)

文章目录 前言一、PL-PS的中断是啥?二、PL-PS端中断详细步骤1.ZYNQ核配置2.PS端中断函数配置3.需要拓展多个中断函数 总结 前言 本设计跟我的ZYNQ实战合集专栏中的脉冲触发电路有关系,也正好趁这个机会讲述一下PL-PS的中断系统,如何去触发中…

为什么不直接public,多此一举用get、set,一文给你说明白

文章目录 1. 封装性(Encapsulation)2. 验证与逻辑处理3. 计算属性(Computed Properties)4. **跟踪变化(Change Tracking)5. 懒加载与延迟初始化(Lazy Initialization)6. 兼容性与未来…

【Leetcode】2182. 构造限制重复的字符串

文章目录 题目思路代码 题目 2182. 构造限制重复的字符串 问题:给你一个字符串 s 和一个整数 repeatLimit ,用 s 中的字符构造一个新字符串 repeatLimitedString ,使任何字母 连续 出现的次数都不超过 repeatLimit 次。你不必使用 s 中的全…

高效便捷的远程管理利器——Royal TSX for Mac软件介绍

Royal TSX for Mac是一款功能强大、操作便捷的远程管理软件。无论是远程桌面、SSH、VNC、Telnet还是FTP,用户都可以通过Royal TSX轻松地远程连接和管理各种服务器、计算机和网络设备。 Royal TSX for Mac提供了直观的界面和丰富的功能,让用户能够快速便…

新版云进销存ERP销售库存仓库员工管理系统源码

新版云进销存ERP销售库存仓库员工管理系统源码 系统介绍:2022版本,带合同报价单打印,修复子账号不显示新加客户的BUG,还有其他方面的优化。 简单方便。 功能强大,系统采用phpMYSQL开发,B/S架构,方便随地使用…

红队打靶练习:HOLYNIX: V1

目录 信息收集 1、arp 2、netdiscover 3、nmap 4、nikto whatweb 目录探测 1、gobuster 2、dirsearch 3、dirb 4、feroxbuster WEB sqlmap 1、爆库 2、爆表 3、爆列 4、爆字段 后台登录 1、文件上传 2、文件包含 3、越权漏洞 反弹shell 提权 总结 信息…

SpringSecurity入门demo(二)表单认证

上一篇博客集成 Spring Security,使用其默认生效的 HTTP 基本认证保护 URL 资源,下面使用表单认证来保护 URL 资源。 一、默认表单认证: 代码改动:自定义WebSecurityConfig配置类 package com.security.demo.config; import or…

最新AI绘画Midjourney绘画提示词Prompt大全

一、Midjourney绘画工具 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭…

国内开源环境漫谈

我国开源软件产业相较于欧美发达国家而言起步相对较晚,开源项目很少超过五年,开发者较年轻。国外很多开源项目都是10年以上的规划与投入。在开源社区发展初期、发展期、协作期、结晶期与流行期的五个阶段中,中国的开源社区平台大多处于前三个…

imgaug库指南(14):从入门到精通的【图像增强】之旅

引言 在深度学习和计算机视觉的世界里,数据是模型训练的基石,其质量与数量直接影响着模型的性能。然而,获取大量高质量的标注数据往往需要耗费大量的时间和资源。正因如此,数据增强技术应运而生,成为了解决这一问题的…

【服务器数据恢复】服务器硬盘磁头损坏的数据恢复案例

服务器硬盘故障: 一台服务器上raid阵列上有两块硬盘出现故障,用户方已经将故障硬盘送到其他机构检测过,其中一块硬盘已经开盘,检测结果是盘片损伤严重;另一块硬盘尚未开盘,初步判断也存在硬件故障&#xff…

代码随想录 Leetcode160. 相交链表

题目: 代码(首刷看解析 2024年1月13日): class Solution { public:ListNode *getIntersectionNode(ListNode *headA, ListNode *headB) {ListNode *A headA, *B headB;while (A ! B) {A A ! nullptr ? A->next : headB;B B ! nullpt…

Legion R7000 2021(82JW)原装出厂Win10/WIN11系统预装OEM系统镜像

LENOVO联想拯救者R7000 2021款(82JW)笔记本电脑原厂Windows10/11系统 链接:https://pan.baidu.com/s/1m_Ql5qu6tnw62PbpvXB0hQ?pwd6ek4 提取码:6ek4 原装出厂系统自带所有驱动、出厂主题壁纸、系统属性专属联机支持标志、系统属性专属联想的LOGO标…

金蝶云星空和吉客云单据接口对接

金蝶云星空和吉客云单据接口对接 对接系统:吉客云 吉客云是基于“网店管家”十五年电商ERP行业和技术积累基础上顺应产业发展需求,重新定位、全新设计推出的换代产品,从业务数字化和组织数字化两个方向出发,以构建流程的闭环为依归…

阿里云服务部署docker容器

1.1 为什么要用docker 问题 开发、测试、生产环境不统一,造成项目测试、部署时产生问题 解决方案 使用容器化技术,将环境和项目一起发送给测试、部署人员,测试人数和运维人员直接使用发过 来的环境和项目进行操作,避免环境不统一…

详解Skywalking 服务Overview页面的参数含义(适合小白)

本文针对刚刚接触skywalking的同学,重点讲解服务Overview页面中各个参数的含义,为大家快速上手skywalking会起到帮助作用! 最重要的三个指标 Service Apdex(数字):当前服务的评分 Successful Rate(数字&a…