第12章 PyTorch图像分割代码框架-1

从本章开始,本书将会进行深度学习图像分割的实战阶段。PyTorch作为目前最为流行的一款深度学习计算框架,在计算机视觉和图像分割任务中已经广泛使用。本章将介绍基于PyTorch的深度学习图像分割代码框架,在总体框架的基础上,基于PASCAL VOC 2012数据集,分别介绍预处理模块、数据导入模块、模型模块、工具函数模块、配置模块、主函数模块、推理模块和部署模块等。每个模块都会在基本的代码结构基础上配以简单的代码示例,帮助读者快速掌握深度学习图像分割的代码范式。

图像分割代码总体框架

深度学习项目的代码框架和范式整体较为固定,一般按照机器学习的pipeline都会有:数据预处理、数据导入、模型搭建、训练、验证、测试和部署等流程。图像分割作为经典的深度学习应用方向之一,也不例外地遵循上述范式。根据深度学习图像分割项目的特点,一个相对完整的图像分割代码框架应包含:预处理模块、数据导入模块、模型模块、工具函数模块、配置模块、主函数模块、推理模块和部署模块。

PyTorch是目前最为流行的深度学习计算框架,无论在学术界还是工业界,都被广大用户作为首选用来进行项目搭建的框架。基于PyTorch的深度学习图像分割代码框架结构如图1所示。

181d4d6bd81ca2d3862db5926c3fbcbc.png

1是一个典型的深度学习图像分割代码框架,也是一个分割项目的工程示例。左侧为代码目录结构,与虚线对应的右侧部分则是具体的模块和功能文档名称。除了前述的8个核心模块之外,图中还补充了基于Jupyter Notebookexperiment.ipynb实验文档,方便开发者在将代码写成工程文件前进行一些尝试性的实验和探索。README.md文档则是关于该项目的说明文档,具体可包括项目的基本情况介绍,训练、验证和测试方法等。需要注意的是,图11-1仅是一个典型的参考代码框架,具体的目录结构可能会因开发者的习惯和项目的特殊需要而有所不同。但不论什么样的深度学习图像分割项目,数据、模型和包含训练验证的主函数这三个部分,一定是不可或缺的。

本章将基于经典的语义分割数据集PASCAL VOC 2012,针对上述8个代码模块,在给出基本代码模板的基础上,进而给出完整的代码实现方式。关于PASCAL VOC 2012数据集相关信息,读者可直接参考本书第10章内容。

预处理模块

数据预处理模块主要用于定义一些图像预处理的功能函数,包括像数据标准化方法、图像转换方法、图像重采样方法等。预处理模块不是必须的,具体结合每个分割项目的图像数据集实际情况来定,对于PASCAL VOC 2012这样成熟的公开数据集,可能一般不需要太多预处理的流程,但对于个人收集的原始数据、一些医学影像或者遥感影像等特定领域的图像数据,可能需要一定的预处理过程。预处理模块需要用户结合个人数据集实际情况酌情进行定义。在图像分割项目中,如果我们的图像数据集需要预处理,我们可以在preprocess.py文件中定义预处理过程。

数据导入模块

PyTorch提供了数据导入的标准化代码模板,可以直接借助于torch.utils.data模块下的数据对象类Dataset来完成数据的定义和读取,然后再借助于数据导入类DataLoaders将数据按批次导入到模型中。基于Dataset的数据读取框架如代码11-1所示。

# 导入数据对象类Dataset
from torch.utils.data import Dataset
# 基于Dataset的数据定义和读取框架
class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
        
    def __len__(self):
        # return examples size
        return len(self.images)

其中CustomDataset为自定义的数据对象类,该类继承于Dataset类,下面包括三个方法:__init__为类初始化方法,包含了数据的存放路径、输入图像、输出图像以及数据变换等属性的定义;__getitem__定义了图像读取和变换方法,一般可通过pillow库的Image.Open先完成读取,再用torchvision库下的transform完成数据变换;__len__方法则通过调用内置函数len返回数据的长度。按照该数据读取框架,我们可以定义PASCAL VOC 2012数据集的读取方法,定义voc.py文件如下。

# 导入相关模块
import numpy as np
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset


# 定义VOC数据集类
class VOCSegmentation(Dataset):
    """
    A PyTorch Dataset class for the VOC Segmentation dataset.
    """
    def __init__(self, root, image_set='train', transform=None):
        """
        Initialize the dataset.
        Args:
            root (str): Path to the dataset root directory.
            image_set (str): The image set to use ('train' or 'val').
            transform (callable, optional): A function/transform to apply to the images and masks.
        """
        self.root = Path(root)
        self.transform = transform
        self.image_set = image_set
        base_dir = 'VOCdevkit/VOC2012'
        voc_root = self.root / base_dir
        image_dir = voc_root / 'JPEGImages'
        if not voc_root.is_dir():
            raise RuntimeError('Dataset not found.')


        mask_dir = voc_root / 'SegmentationClass'
        splits_dir = voc_root / 'ImageSets/Segmentation'
        split_f = splits_dir / f"{image_set.rstrip()}.txt"


        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]


        self.images = [image_dir / f"{x}.jpg" for x in file_names]
        self.masks = [mask_dir / f"{x}.png" for x in file_names]
        assert (len(self.images) == len(self.masks))


    def __getitem__(self, index):
        """
        Get an item from the dataset.
        Args:
            index (int): Index of the item to get.   
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img, target = self.transform(img, target)
        return img, target


    def __len__(self):
        return len(self.images)

如代码11-2所示,我们通过pathlib库来定义数据路径,通过pillowImage.open函数来读取图像并进行同步转换。除此之外,VOC数据集掩码还需要单独进行颜色编码的接码,所以实际操作时要单独定义voc_map函数以及在VOCSegmentation类中补充一个掩码图像的解码方法decode_target。完整代码可参考本书配套代码对应章节。

需要特别说明的是,torchvision库中提供的transform模块提供了各种图像变换方法,也就是我们通常所说的在线数据增强(Online Data Augmentation)。在线数据增强是指在训练过程中,每次读取一个样本时都会进行数据增强操作。也就是说,数据增强是在每个小批量(batch)的数据上实时进行的。在线数据增强可以通过数据转换的方式,在每个训练迭代中生成多个不同的数据样本,以增加训练集的多样性,但不实际增加训练数据的数量。常见的在线数据增强操作包括随机裁剪(random cropping)、翻转(flipping)、旋转(rotation)、缩放(scaling)等。与在线数据增强对应的是离线数据增强(Offline Data Augmentation)。离线数据增强是指在训练开始之前,将原始数据集进行增强,并将增强后的数据保存为新的训练集。然后,在训练过程中,使用增强后的训练集进行模型训练。离线数据增强的好处是可以节省训练时间,因为数据增强只需在训练开始之前完成一次。常见的离线数据增强操作包括扩充数据集,例如通过旋转、平移、缩放等方式生成新的图像。常用的离线数据增强库包括imgaugalbumentationsAugmentor等。在代码11-2中,我们在初始化方法里面提供数据transform方式,同步接收imgtarget作为输入进行在线数据增强,以增强训练样本的多样性。当然了,这需要我们对torchvision中的transform方法稍微进行改动。

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/105265.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Fabric.js 讲解官方demo:Stickman

本文简介 戴尬猴,我是德育处主任 Fabric.js 官网有很多有趣的Demo,不仅可以帮助我们了解其功能,还可以为我们提供创意灵感。其中,Stickman是一个非常有趣的例子。 先看看效果图 从上图可以看出,在拖拽圆形时&#xf…

yyds,Elasticsearch Template自动化管理新索引创建

一、什么是Elasticsearch Template? Elasticsearch Template是一种将预定义模板应用于新索引的功能。在索引创建时,它可以自动为新索引应用已定义的模板。Template功能可用于定义索引的映射、设置和别名等。它是一种自动化管理索引创建的方式&#xff0…

CSDN学院 < 华为战略方法论进阶课 > 正式上线!

目录 你将收获 适用人群 课程内容 内容目录 CSDN学院 作者简介 你将收获 提升职场技能提升战略规划的能力实现多元化发展综合能力进阶 适用人群 主要适合公司中高层、创业者、产品经理、咨询顾问,以及致力于改变现状的学员。 课程内容 本期课程主要介绍华为…

非侵入式负荷检测与分解:电力数据挖掘新视角

电力数据挖掘 概述案例背景分析目标分析过程数据准备数据探索缺失值处理 属性构造设备数据周波数据模型训练 性能度量推荐阅读 主页传送门:📀 传送 概述 摘要:本案例将根据已收集到的电力数据,深度挖掘各电力设备的电流、电压和功…

03.MySQL事务及存储引擎笔记

事务 查看/设置事务 select autocommit; --查看当前数据库的事务状态,1表示开启,0表示关闭 set autocommit 0; --关闭自动事务提交采用关闭自动事务提交我们就可以手动进行事务提交,但是这种设置方式是对整个数据库起作用,一些可…

MongoDB 的集群架构与设计

一、前言 MongoDB 有三种集群架构模式,分别为主从复制(Master-Slaver)、副本集(Replica Set)和分片(Sharding)模式。 Master-Slaver 是一种主从复制的模式,目前已经不推荐使用。Re…

互联网产品说明书指南,附撰写流程与方法

产品说明书,对于普通产品而言,再常见不过。药物、电器、电子产品等产品在正式出售时,往往都会附带一份产品说明书,以此告诉用户这个产品的功能与特性,并指导用户如何来使用这个产品。 产品说明书 那么,对于…

Python遍历删除列表元素的一个奇怪bug

假定有一个Python列表,比如[CFFEX.IF, CFFEX.TS,SHFE.FU],现在需要将其中带‘CFFEX’前缀的所有元素都删除。在使用列表推导式一行代码搞定之前,用了一种最朴素的遍历删除方法,结果出现了意想不到的的问题。复盘了下,结…

软件测试进阶篇----自动化测试脚本开发

自动化测试脚本开发 一、自动化测试用例开发 1、用例设计需要注意的点 2、设计一条测试用例 二、脚本开发过程中的技术 1、线性脚本开发 2、模块化脚本开发(封装线性代码到方法或者类中。在需要的地方进行调用) 3、关键字驱动开发:selen…

React之如何捕获错误

一、是什么 错误在我们日常编写代码是非常常见的 举个例子,在react项目中去编写组件内JavaScript代码错误会导致 React 的内部状态被破坏,导致整个应用崩溃,这是不应该出现的现象 作为一个框架,react也有自身对于错误的处理的解…

windows安装数据库MySQL

windows安装数据库MySQL 文章目录 windows安装数据库MySQL一、MySQL官网下载压缩包二、在D盘新建文件夹D:\MySQL,将下载的压缩包解压到该文件夹下三、配置环境变量四、通过命令行模式安装、启用、配置SQL服务 一、MySQL官网下载压缩包 下载地址:https:/…

NLP:从头开始的文本矢量化方法

一、说明 NLP 项目使用文本,但机器学习算法不能使用文本,除非将其转换为数字表示。这种表示通常称为向量,它可以应用于文本的任何合理单位:单个标记、n-gram、句子、段落,甚至整个文档。 在整个语料库的统计 NLP 中&am…

解决Windows出现找不到mfcm90u.dll无法打开软件程序的方法

今天,我非常荣幸能够在这里与大家分享关于mfc90u.dll丢失的5种解决方法。在我们日常使用电脑的过程中,可能会遇到一些软件或系统错误,其中之一就是mfc90u.dll丢失。那么,mfc90u.dll究竟是什么文件呢?接下来&#xff0c…

十八、字符串(3)

本章概要 正则表达式 基础创建正则表达式量词CharSequencePattern 和 Matcherfinde()组(Groups)start() 和 end()Pattern 标记split()替换操作reset()正则表达式与 Java I/0 正则表达式 很久之前,_正则表达式_就已经整合到标准 Unix 工具…

算法通关村第三关-青铜挑战数组专题

本期大纲 线性表基础线性表概念数组概念 数组的基本操作数组创建和初始化查找一个元素增加一个元素修改一个元素删除一个元素 小题一道 - - 单调数组问题小题一道 - - 数组合并问题小结 线性表基础 线性表概念 我们先搞清楚几个基本概念,在很多地方会看到线性结构…

cmd命令快速打开MATLAB

文章目录 复制快捷方式添加 -nojvm打开 复制快捷方式 添加 -nojvm 打开 唯一的缺点是无法使用plot,这一点比不上linux系统,不过打开速度还是挺快的。

计算机网络 第四章网络层

文章目录 1 网络层的功能2 数据交换方式:电路交换3 数据交换方式:报文交换4 数据交换方式:分组交换5 数据交换方式:数据报方式6 数据交换方式:虚电路方式及各种方式对比7 路由算法及路由协议8 IP数据报的概念和格式9 I…

疫情集中隔离

系列文章目录 进阶的卡莎C++_睡觉觉觉得的博客-CSDN博客数1的个数_睡觉觉觉得的博客-CSDN博客双精度浮点数的输入输出_睡觉觉觉得的博客-CSDN博客足球联赛积分_睡觉觉觉得的博客-CSDN博客大减价(一级)_睡觉觉觉得的博客-CSDN博客小写字母的判断_睡觉觉觉得的博客-CSDN博客纸币(…

WebSocket协议:5分钟从入门到精通

一、内容概览 WebSocket的出现,使得浏览器具备了实时双向通信的能力。本文由浅入深,介绍了WebSocket如何建立连接、交换数据的细节,以及数据帧的格式。此外,还简要介绍了针对WebSocket的安全攻击,以及协议是如何抵御类…

EasyAR使用

EazyAR后台管理,云定位服务 建模 需要自行拍摄360度视频,后台上传,由EazyAR工作人员完成构建。 标注数据 需要在unity安装EazyAR插件,在unity场景编辑后,上传标注数据。 uinity标注数据 微信小程序中使用&#x…