第12章 PyTorch图像分割代码框架-2

模型模块

本书的第5-9章重点介绍了各种2D3D的语义分割和实例分割网络模型,所以在模型模块中,我们需要做的事情就是将要实验的分割网络写在该目录下。有时候我们可能想尝试不同的分割网络结构,所以在该目录下可以存在多个想要实验的网络模型定义文件。对于PASCAL VOC这样的自然数据集,我们可能想实验Deeplab v3+PSPNetRefineNet等网络的训练效果。代码11-3给出了Deeplab v3+网络封装后的主体部分,完整网络搭建代码可参考本书配套代码对应章节。

代码11-3 Deeplab v3+网络的主体部分

# 定义Deeplab V3+类
class DeepLabHeadV3Plus(nn.Module):
    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()


        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )
    # ASPP
        self.aspp = ASPP(in_channels, aspp_dilate)
    # classifier head
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )


        self._init_weight()
  # forward method
    def forward(self, feature):
        # print(feature['low_level'].shape)
        # print(feature['out'].shape)
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        output_feature = F.interpolate(
            output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
  # weight initilize
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

对于复杂网络搭建,一般都是采用自下而上的搭建方法,先搭建底层组件,再逐步向上封装,对于本例中的Deeplab v3+,可以先分别搭建backbone骨干网络、ASPP和编解码结构,最后再进行封装。

工具函数模块

工具函数是为项目完成各项功能所自定义的辅助函数,可以统一定义在utils文件夹下,根据实际项目的不同,工具函数也各不相同。常用的工具函数包括各种损失函数的定义loss.py、训练可视化函数的定义visualize.py、用于记录训练日志的log.py等。代码11-4给出了一个关于Focal loss损失函数的定义,该损失函数作为工具函数可放在loss.py文件中。

代码11-4 工具函数示例:定义一个Focal loss

# 导入相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义一个Focal loss类
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma


    def forward(self, inputs, targets):
        # Compute cross-entropy loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')


        # Compute the focal loss
        pt = torch.exp(-ce_loss)  
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

配置模块

配置模块是为项目模型训练传入各种参数而进行设置的模块,比如训练数据所在目录、训练所需要的各种参数、训练过程是否需要可视化等。一般来说,我们有两种方式来对项目执行参数进行配置管理,一种是直接在主函数main.py中使用argparse库对参数进行配置,然后再命令行中进行传入;另一种则是单独定义一个config.py或者config.yaml文件来对所有参数进行统一配置。基于argparse库的参数配置管理简单示例如代码11-5所示。

代码11-5 argparser参数配置管理

# 导入argparse库
import argparse
# 创建参数管理器
parser = argparse.ArgumentParser()
# 涉及数据相关的参数管理
parser.add_argument("--data_root", type=str, default='./dataset',
                     help="path to Dataset")
parser.add_argument("--save_root", type=str, default='./',
                     help="path to save result")
parser.add_argument("--dataset", type=str, default='voc',
                     choices=['voc', 'cityscapes', 'ade'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,
                     help="num classes (default: None)")

在上述代码中,我们基于argparse给出了一小部分参数配置管理代码,涉及训练数据相关的部分参数,包括数据读取路径、存放路径、训练所用数据集、分割类别数量等。

主函数模块

主函数模块main.py是项目的启动模块,该模块将定义好的数据和模型模块进行组装,并结合损失函数、优化器、评估方法和可视化等组件,将config.py中配置好的项目参数传入,根据训练-验证的模式,执行图像分割项目模型训练和验证。代码11-6VOC数据集训练验证部分代码。

代码11-6 主函数模块中的训练迭代部分

# 初始化区间损失
interval_loss = 0
while True:  
  # 执行训练
  model.train()
  cur_epochs += 1
  for (images, labels) in train_loader:
    cur_itrs += 1
    images = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.long)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()


    np_loss = loss.detach().cpu().numpy()
    interval_loss += np_loss


    if vis is not None:
      vis.vis_scalar('Loss', cur_itrs, np_loss)
    # 打印训练信息
    if (cur_itrs) % opts.print_interval == 0:
      pass
    # 保存模型
    if (cur_itrs) % opts.val_interval == 0:
      pass
      # 日志记录
      logger.info("Save the latest model to %s" % save_path_checkpoints)
      # 模型验证
      print("validation...")
      model.eval()
      val_score, ret_samples = validate(
        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
        ret_samples_ids=vis_sample_id)
      logger.info("Validation performance: %s", val_score)
      
      # 保存最优模型
      if val_score['mean_dice'] > best_score:  
        best_score = val_score['mean_dice']
        save_ckpt(os.path.join(save_path_checkpoints, 'best_%s_%s_os%d.pth' %
                     (opts.model, opts.dataset, opts.output_stride)))
        logger.info("Save best-performance model so far to %s" % save_path_checkpoints)


      # 训练过程可视化
      if vis is not None:  
        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])


        for k, (img, target, lbl) in enumerate(ret_samples):
          img = (denorm(img) * 255).astype(np.uint8)
          target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
          lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
          concat_img = np.concatenate((img, target, lbl), axis=2)  
          vis.vis_image('Sample %d' % k, concat_img)
          
    scheduler.step()

在代码11-6中,我们展示了一个图像分割项目主函数模块中最核心的训练和验证部分。在训练时,按照指定迭代次数保存模型和对训练过程进行可视化展示。图11-2为训练打印的部分信息。

29f5835d31863a3c0e12336eba35dd8d.png

11-2 VOC训练过程信息

11-3为基于visdom的训练过程可视化展示,包括当前训练配置参数信息,训练损失函数变化曲线、验证集全局准确率、mIoU和类别IoU等指标变化曲线图。

c829adc88b9de5b266170dd6b3b86385.png

11-3 Deeplab v3+训练过程可视化

11-4展示了两组训练过程中验证集的输入图像、标签图像和模型预测图像的对比图。可以看到,基于Deeplab v3+的分割模型在PASCAL VOC 2012上表现还不错。

ea3738bd115f8bb89add3ea4ab2e67b6.png

11-4 验证集模型效果图

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/123265.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

11 # 手写 reduce 方法

reduce 使用 reduce() 方法对数组中的每个元素按序执行一个提供的 reducer 函数,每一次运行 reducer 会将先前元素的计算结果作为参数传入,最后将其结果汇总为单个返回值。 第一次执行回调函数时,不存在“上一次的计算结果”。如果需要回调…

运行obotframework-ride控制台报错module ‘urllib‘ has no attribute ‘Request‘

背景:Python3.8robotframework-ride1.7.3.1,运行报错module urllib has no attribute Request 原因: 解决:升级robotframework-ride到2.0以上。或者降级python到3.7。

基于SSM的演唱会购票系统的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue、HTML 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是…

大数据学习之一文学会Spark【Spark知识点总结】

文章目录 什么是SparkSpark的特点Spark vs HadoopSparkHadoopSpark集群安装部署Spark集群安装部署StandaloneON YARN Spark的工作原理什么是RDDRDD的特点Spark架构相关进程Spark架构原理 Spark实战:单词统计Scala代码开发java代码开发任务提交 Transformation与Acti…

自动控制原理--面试问答题

以下文中的,例如 s_1 为 s下角标1。面试加油! 控制系统的三要素:稳准快。稳,系统最后不能震荡、发散,一定要收敛于某一个值;快,能够迅速达到系统的预设值;准,最后稳态值…

清凉油市场现状及未来发展趋势

清凉油市场一直以其庞大的规模和快速增长的势头受到人们的关注。无论是消费者对健康生活方式的追求,还是中国作为全球最大市场的地位,都为清凉油市场的持续发展注入了强大的动力。随着人们对健康意识的提升和对保健产品需求的增加,清凉油市场…

算法?认识一下啦

一、什么是算法? 算法 ,是对特定问题求解方法和步骤的一种描述。它是有限指令的有限序列,其中每个指令表示一个或多个操作。 算法和程序的关系 算法​是解决问题的一种方法或一个过程,考虑如何将输入转换成输出,一个…

功能更新|Leangoo领歌免费敏捷工具支持SAFe大规模敏捷框架

Leangoo领歌是一款永久免费的专业的敏捷开发管理工具,提供端到端敏捷研发管理解决方案,涵盖敏捷需求管理、任务协同、进展跟踪、统计度量等。 
 Leangoo可以支持敏捷研发管理全流程,包括小型团队敏捷开发,规模化敏捷SAFe&#xf…

SpringBoot测试类启动web环境-下篇

一、响应状态 1.MockMvcResultMatchers 说明:模拟结果匹配。 package com.forever;import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.autoconfigure.web.servlet.AutoC…

软件测试|MySQL LIKE:深入了解模糊查询

简介 在数据库查询中,模糊查询是一种强大的技术,可以用来搜索与指定模式匹配的数据。MySQL数据库提供了一个灵活而强大的LIKE操作符,使得模糊查询变得简单和高效。本文将详细介绍MySQL中的LIKE操作符以及它的用法,并通过示例演示…

软件测试/测试开发丨接口测试Mock实战练习学习笔记

点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27857 一、Rewrite 1.1、Rewrite 原理 1.2、Rewrite 实战 Tools → Rewrite 勾选 Enable Rewrite 点击下方 Add 按钮新建一个重写的规则 在右侧编辑重…

JVM之jinfo虚拟机配置信息工具

jinfo虚拟机配置信息工具 1、jinfo jinfo(Configuration Info for Java)的作用是实时地查看和调整虚拟机的各项参数。 使用jps -v 可以查看虚拟机启动时显示指定的参数列表,但是如果想知道未被显示指定的参数的系统默认值,除 …

blender动画制作全流程软件

blender官网下载地址 Download — blender.org blender菜单中英文对照表 blender常用快捷键: ~切换视图 z切换着色模式 shiftA新建物体 tab进入编辑模式 在编辑模式下: 1编辑点 2编辑线 3编辑面 shfit空格弹出所有快捷键 游标一般配合标注使用 常用:G移动物体…

1214. 波动数列

题目&#xff1a; 1214. 波动数列 - AcWing题库 思路&#xff1a;dp dp划分递归 转自&#xff1a; AcWing 1214. 波动数列&#xff08;有公式详细推导&#xff09; - AcWing 代码&#xff1a; #include <iostream> #include <cstring> #include <algori…

Stable Diffusion WebUI扩展sd-webui-controlnet之Canny

什么是Canny? 简单来说,Canny是计算机视觉领域的一种边缘检测算法。 关于Canny算法大家可以去看我下面这篇博客,里面详细介绍了Canny算法的原理以及代码演示。 OpenCV竟然可以这样学!成神之路终将不远(十五)_maxminval opencv-CSDN博客文章浏览阅读111次。14 图像梯度…

【Orangepi Zero2 全志H616】驱动串口实现Tik Tok—VUI(语音交互)

一、编程实现语音和开发板通信 wiringpi库源码demo.c 二、基于前面串口的代码修改实现 uartTool.huartTool.cuartTest.c 三、ADB adb控制指令 四、手机接入Linux热拔插相关 a. 把手机接入开发板 b. 安装adb工具&#xff0c;在终端输入adb安装指令&#xff1a; sudo apt-g…

freeswich学习

写在前面 因为所在部分主要负责公司客服业务&#xff0c;需要了解freeswich相关内容&#xff0c;所以这里将学习内容记录下。 1&#xff1a;安装freesswich freeswich是一个实现了软交换协议的开源软件&#xff0c;可以对对接运营上的通话线路&#xff0c;实现拨打电话。 安…

【C++】万字一文全解【继承】及其特性__[剖析底层化繁为简](20)

前言 大家好吖&#xff0c;欢迎来到 YY 滴C系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; 目录 一.继承&复用&组合的区别1&…

ClickHouse开发系列

一、 ClickHouse详解、安装教程_clickhouse源码安装 二、ClickHouse 语法详解_clickhouse讲解 三、ClickHouse SQL 操作语句详解 四、ClickHouse 高级教程—官方原版 五、ClickHouse主键索引最佳实践 六、MySQL与ClickHouse集成 七、ClickHouse 集成MongoDB、Re…

洛谷P4185 离线+并查集

好题&#xff0c;发现没有强制在线&#xff0c;可以离线操作 排序之后带集合点数的并查集就好了 #include<bits/stdc.h> using namespace std; const int N 1e510; int n,m; int p[N],sz[N];int find(int x){if(x!p[x])p[x] find(p[x]);return p[x]; } struct Node{in…