基于EifficientNet的视网膜病变识别

分析一下代码

model.py

①下面这个方法的作用是:将传入的ch(channel)的个数调整到离它最近的8的整数倍,这样做的目的是对硬件更加友好。

def _make_divisible(ch, divisor=8, min_ch=None):
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch

②定义卷积BAN和激活函数的模块类,groups是用来控制卷积结构使用普通卷积结构还是使用Depwise卷积结构。

class ConvBNActivation(nn.Sequential):
    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.SiLU

        super(ConvBNActivation, self).__init__(nn.Conv2d(in_channels=in_planes,
                                                         out_channels=out_planes,
                                                         kernel_size=kernel_size,
                                                         stride=stride,
                                                         padding=padding,
                                                         groups=groups,
                                                         bias=False),
                                               norm_layer(out_planes),
                                               activation_layer())

③定义InvertedResidualConfig模块,对应每一个MBConv模块的配置参数。

class InvertedResidualConfig:
    def __init__(self,
                 kernel: int,          # 3 or 5
                 input_c: int,
                 out_c: int,
                 expanded_ratio: int,  # 1 or 6
                 stride: int,          # 1 or 2
                 use_se: bool,         # True
                 drop_rate: float,
                 index: str,           # 1a, 2a, 2b, ...
                 width_coefficient: float):
        self.input_c = self.adjust_channels(input_c, width_coefficient)
        self.kernel = kernel
        self.expanded_c = self.input_c * expanded_ratio
        self.out_c = self.adjust_channels(out_c, width_coefficient)
        self.use_se = use_se
        self.stride = stride
        self.drop_rate = drop_rate
        self.index = index

    @staticmethod
    def adjust_channels(channels: int, width_coefficient: float):
        return _make_divisible(channels * width_coefficient, 8)

④定义InvertedResidual模块,即MBConv模块。

class InvertedResidual(nn.Module):
    def __init__(self,
                 cnf: InvertedResidualConfig,
                 norm_layer: Callable[..., nn.Module]):
        super(InvertedResidual, self).__init__()

        if cnf.stride not in [1, 2]:
            raise ValueError("illegal stride value.")

        self.use_res_connect = (cnf.stride == 1 and cnf.input_c == cnf.out_c)

        layers = OrderedDict()
        activation_layer = nn.SiLU

        if cnf.expanded_c != cnf.input_c:
            layers.update({"expand_conv": ConvBNActivation(cnf.input_c,
                                                           cnf.expanded_c,
                                                           kernel_size=1,
                                                           norm_layer=norm_layer,
                                                           activation_layer=activation_layer)})

        layers.update({"dwconv": ConvBNActivation(cnf.expanded_c,
                                                  cnf.expanded_c,
                                                  kernel_size=cnf.kernel,
                                                  stride=cnf.stride,
                                                  groups=cnf.expanded_c,
                                                  norm_layer=norm_layer,
                                                  activation_layer=activation_layer)})

        if cnf.use_se:
            layers.update({"se": SqueezeExcitation(cnf.input_c,
                                                   cnf.expanded_c)})

        layers.update({"project_conv": ConvBNActivation(cnf.expanded_c,
                                                        cnf.out_c,
                                                        kernel_size=1,
                                                        norm_layer=norm_layer,
                                                        activation_layer=nn.Identity)})

        self.block = nn.Sequential(layers)
        self.out_channels = cnf.out_c
        self.is_strided = cnf.stride > 1

        # 只有在使用shortcut连接时才使用dropout层
        if self.use_res_connect and cnf.drop_rate > 0:
            self.dropout = DropPath(cnf.drop_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        result = self.block(x)
        result = self.dropout(result)
        if self.use_res_connect:
            result += x

        return result

⑤接下来看看EfficientNet是如何实现的,网络参数如下:

代码如下:

class EfficientNet(nn.Module):
    def __init__(self,
                 width_coefficient: float,
                 depth_coefficient: float,
                 num_classes: int = 1000,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2,
                 block: Optional[Callable[..., nn.Module]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None
                 ):
        super(EfficientNet, self).__init__()

        default_cnf = [[3, 32, 16, 1, 1, True, drop_connect_rate, 1],
                       [3, 16, 24, 6, 2, True, drop_connect_rate, 2],
                       [5, 24, 40, 6, 2, True, drop_connect_rate, 2],
                       [3, 40, 80, 6, 2, True, drop_connect_rate, 3],
                       [5, 80, 112, 6, 1, True, drop_connect_rate, 3],
                       [5, 112, 192, 6, 2, True, drop_connect_rate, 4],
                       [3, 192, 320, 6, 1, True, drop_connect_rate, 1]]

        def round_repeats(repeats):
            return int(math.ceil(depth_coefficient * repeats))

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

        adjust_channels = partial(InvertedResidualConfig.adjust_channels,
                                  width_coefficient=width_coefficient)

        bneck_conf = partial(InvertedResidualConfig,
                             width_coefficient=width_coefficient)

        b = 0
        num_blocks = float(sum(round_repeats(i[-1]) for i in default_cnf))
        inverted_residual_setting = []
        for stage, args in enumerate(default_cnf):
            cnf = copy.copy(args)
            for i in range(round_repeats(cnf.pop(-1))):
                if i > 0:
                    cnf[-3] = 1
                    cnf[1] = cnf[2]

                cnf[-1] = args[-2] * b / num_blocks
                index = str(stage + 1) + chr(i + 97)
                inverted_residual_setting.append(bneck_conf(*cnf, index))
                b += 1

        layers = OrderedDict()

        layers.update({"stem_conv": ConvBNActivation(in_planes=3,
                                                     out_planes=adjust_channels(32),
                                                     kernel_size=3,
                                                     stride=2,
                                                     norm_layer=norm_layer)})

        for cnf in inverted_residual_setting:
            layers.update({cnf.index: block(cnf, norm_layer)})

        last_conv_input_c = inverted_residual_setting[-1].out_c
        last_conv_output_c = adjust_channels(1280)
        layers.update({"top": ConvBNActivation(in_planes=last_conv_input_c,
                                               out_planes=last_conv_output_c,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})

        self.features = nn.Sequential(layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        classifier = []
        if dropout_rate > 0:
            classifier.append(nn.Dropout(p=dropout_rate, inplace=True))
        classifier.append(nn.Linear(last_conv_output_c, num_classes))
        self.classifier = nn.Sequential(*classifier)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
train.py

①导入模型

from model import efficientnet_b0 as create_model

②不同版本的EfficientNet网络模型,对应不同的输入图像大小。

img_size = {"B0": 224,
                "B1": 240,
                "B2": 260,
                "B3": 300,
                "B4": 380,
                "B5": 456,
                "B6": 528,
                "B7": 600}
    num_model = "B0"

③实例化模型部分,传入类别个数,然后添加设备。

model = create_model(num_classes=args.num_classes).to(device)

④还有一个参数为是否冻结权重,如果为true,只会微调最后一层的1*1的卷积以及FC全连接层结构,如果为false,就会训练全部的网络结构。

    if args.freeze_layers:
        for name, para in model.named_parameters():
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))
predict.py

①创建模型,num_classes的大小需要根据自己的数据集类别个数进行改变,不仅这里需要改变,训练时也需要改变。

model = create_model(num_classes=5).to(device)

②导入之前训练的模型权重即可。

model_weight_path = "./weights/model-29.pth"

开始训练


1. 在train.py脚本中将--data-path设置成解压后的视网膜病变数据集文件夹的绝对路径。


2.下载预训练权重,根据自己使用的模型下载对应预训练权重。


3. 在train.py脚本中将--weights参数设成下载好的预训练权重路径。


4. 设置好数据集的路径--data-path以及预训练权重的路径--weights就能使用train.py脚本开始训练了(训练过程中会自动生成class_indices.json文件)。


5. 在predict.py脚本中导入和训练脚本中同样的模型,并将model_weight_path设置成训练好的模型权重路径(默认保存在weights文件夹下)。


6. 在predict.py脚本中将img_path设置成你自己需要预测的图片绝对路径。


7. 设置好权重路径model_weight_path和预测的图片路径img_path就能使用predict.py脚本进行预测了。


8. 数据集必须按照视网膜病变数据集的文件结构进行摆放(即一个类别对应一个文件夹),并且将训练以及预测脚本中的num_classes设置成正确的数据类别数。

基于efficientnetb0.pth预训练权重的训练结果如下:

基于efficientnetb7.pth预训练权重的训练结果如下(粉色部分):

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635379.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

爬虫学习--12.MySQL数据库的基本操作(下)

MySQL查询数据 MySQL 数据库使用SQL SELECT语句来查询数据。 语法&#xff1a;在MySQL数据库中查询数据通用的 SELECT 语法 SELECT 字段1&#xff0c;字段2&#xff0c;……&#xff0c;字段n FROM table_name [WHERE 条件] [LIMIT N] 查询语句中你可以使用一个或者多个表&…

Ajax用法总结(包括原生Ajax、Jquery、Axois)

HTTP知识 HTTP&#xff08;hypertext transport protocol&#xff09;协议『超文本传输协议』&#xff0c;协议详细规定了浏览器和万维网服务器之间互相通信的规则。 请求报文 请求行: GET、POST /s?ieutf-8...&#xff08;url的一长串参数&#xff09; HTTP/1.1 请求头…

MySQL数据库单表查询中查询条件的写法

1.使用比较运算符作为查询条件 ; !; >; >; <; <; 如上图所示&#xff0c;可以使用命令select 字段&#xff0c;字段 from 表名 where Gender “M”; 即挑选出Gender “M” 的教师&#xff0c; 如上图所示&#xff0c;可以使用命令select 字段&#xff0c;…

如何使用Android NDK将头像变成“遗像”

看完本文的标题&#xff0c;可能有人要打我。你说黑白的老照片不好吗&#xff1f;非要说什么遗像&#xff0c;我现在就把你变成遗像&#xff01;好了&#xff0c;言归正传。我想大部分人都用过美颜相机或者剪映等软件吧&#xff0c;它们的滤镜功能是如何实现的&#xff0c;有人…

TCP Spurious Retransmission

TCP虚假重传是网络中可能存在丢包问题的一个指标&#xff0c;但并不意味着一定存在丢包。虚假重传指的是当TCP发送方错误地认为原始数据包已丢失时&#xff0c;会不必要地进行重传&#xff0c;而实际上原始数据包已成功到达目的地。 虚假重传可能由多种因素引起&#xff0c;例…

【Linux】进程信号及相关函数/系统调用的简单认识与使用

文章目录 前言一、相关函数/系统调用1. signal2. kill3. abort (库函数)4. raise (库函数)5. alarm 前言 现实生活中, 存在着诸多信号, 比如红绿灯, 上下课铃声…我们在接收到信号时, 就会做出相应的动作. 对于进程也是如此的, 进程也会收到来自 OS 发出的信号, 根据信号的不同…

分布式锁2-Zookeeper分布式锁实战

Zookeeper分布式锁实战 使用curator操作Zookeeper进行实战&#xff1b; curator是什么&#xff1a;Apache Curator包含一套高级API框架和工具类&#xff0c;它 是Apache ZooKeeper 的Java 客户端库。 准备 pom文件引入curtor依赖和zookeeper依赖 <!--curator--> <…

麦克纳母轮(全向)移动机器人集群控制的Simulink/Simscape虚拟仿真平台搭建

麦克纳姆轮是一种常见的全向移动机构&#xff0c;可以使机器人在平面内任意方向平移&#xff0c;同时可以利用差速轮车的属性实现自转&#xff0c;能够在狭窄且复杂多变的环境中自由运行&#xff0c;因而被广泛应用于竞赛机器人和特殊工业机器人场景。 Ps:最新的BYD仰望U8也有一…

graspnet+Astra2相机实现部署

graspnetAstra2相机实现部署 &#x1f680; 环境配置 &#x1f680; ubuntu 20.04Astra2相机cuda 11.0.1cudnn v8.9.7python 3.8.19pytorch 1.7.0numpy 1.23.5 1. graspnet的复现 具体的复现流程可以参考这篇文章&#xff1a;Ubuntu20.04下GraspNet复现流程 这里就不再详细…

C++容器之映射(std::map)

目录 1 概述2 使用实例3 接口使用3.1 construct3.2 assigns3.3 iterators3.4 capacity3.5 access3.6 insert3.7 erase3.8 swap3.9 clear3.10 emplace3.11 emplace_hint3.12 key_comp3.13 value_comp3.14 find/count3.15 upper_bound/upper_bound/equal_range3.16 get_allocator…

数据结构和算法|堆排序系列问题(一)|堆、建堆和Top-K问题

在这里不再描述大顶堆和小顶堆的含义&#xff0c;只剖析原理层面。 主要内容来自&#xff1a;Hello算法 文章目录 1.堆的实现1.1 堆的存储与表示过程1.2 访问堆顶元素1.4元素出堆 2.⭐️建堆2.1 方法一&#xff1a;借助入堆操作实现2.2 ⭐️方法二&#xff1a;通过遍历堆化实现…

Java 多线程抢红包

问题需求 一个人在群里发了1个100元的红包&#xff0c;被分成了8个&#xff0c;群里有10个人一起来抢红包&#xff0c;有抢到的金额随机分配。 红包功能需要满足哪些具体规则呢? 1、被分的人数抢到的金额之和要等于红包金额&#xff0c;不能多也不能少。 2、每个人至少抢到1元…

Ubuntu Nerfstudio安装

https://blog.csdn.net/qq_30565883/article/details/133778529 https://blog.csdn.net/weixin_52581013/article/details/137982846 https://zhuanlan.zhihu.com/p/654394767 1. 结论 因为需要安装tiny-cuda-nn&#xff0c;然而 所以我之前的在笔记本上安装就白费了&#xf…

基于python的k-means聚类分析算法,对文本、数据等进行聚类,有轮廓系数和手肘法检验

K-means算法是一种常见的聚类算法&#xff0c;用于将数据点分成不同的组&#xff08;簇&#xff09;&#xff0c;使同一组内的数据点彼此相似&#xff0c;不同组之间的数据点相对较远。以下是K-means算法的基本工作原理和步骤&#xff1a; 工作原理&#xff1a; 初始化&#x…

QT C++ QTableWidget 演示

本文演示了 QTableWidget的初始化以及单元格值改变时响应槽函数&#xff0c;打印单元格。 并且&#xff0c;最后列不一样,是combobox &#xff0c;此列的槽函数用lambda函数。 在QT6.2.4 MSVC2019 调试通过。 1.界面效果 2.头文件 #ifndef MAINWINDOW_H #define MAINWINDOW…

使用API有效率地管理Dynadot域名,进行域名邮箱的默认邮件转发设置

关于Dynadot Dynadot是通过ICANN认证的域名注册商&#xff0c;自2002年成立以来&#xff0c;服务于全球108个国家和地区的客户&#xff0c;为数以万计的客户提供简洁&#xff0c;优惠&#xff0c;安全的域名注册以及管理服务。 Dynadot平台操作教程索引&#xff08;包括域名邮…

【四数之和】python,排序+双指针

四层循环&#xff1f;&#xff08;doge) 和【三数之和】题目很类似 class Solution:def fourSum(self, nums: List[int], target: int) -> List[List[int]]:nums.sort()#a,b,c,d四个数&#xff0c;先固定两个数&#xff0c;那就是双指针问题了&#xff0c;令ba1&#xff…

【数据结构】【C语言】堆~动画超详细解读!

目录 1 什么是堆1.1 堆的逻辑结构和物理结构1.2 堆的访问1.3 堆为什么物理结构上要用数组?1.4 堆数据上的特点 2 堆的实现2.1 堆类型定义2.2 需要实现的接口2.3 初始化堆2.4 销毁堆2.5 堆判空2.6 交换函数2.7 向上调整(小堆)2.8 向下调整(小堆)2.9 堆插入2.10 堆删除2.11 //堆…

若依解决使用https上传文件返回http路径问题

若依通过HTTPS请求进行文件上传时却返回HTTP的文件链接地址&#xff0c;主要原因是使用了 request.getRequestURL 获取链接地址。 解决办法&#xff1a; 在nginx配置文件location处加上&#xff1a;proxy_set_header X-Forwarded-Scheme $scheme; 然后代码通过request.getHea…

【跳坑日记】暴力解决Ubuntu SSH报错: Failed to start OpenBSD Secure Shell server

报错环境说明&#xff1a; 服务器环境&#xff1a;Ubuntu 20.04 错误内容 最近服务器突然报错&#xff0c;提示如下图信息&#xff1a; 搜素了各种问答&#xff0c;国内的回答大多数是用 ssh-keygen -A命令来解决&#xff0c;但最终也无法登录服务器。 最终搜索到ask ubun…