图像分类的划分数据,dataset和dataloader的实现

目录

1. 介绍

2. 主函数代码

2. utils 模块代码

2.1 划分数据集

2.2 可视化数据集

3. dataset 数据处理

4. collate_fn

5. other 


1. 介绍

图像分类一般来说不需要自定义的dataSet,因为pytorch自定义好的ImageFolder可以解决大部分的需求,更多的dataSet是在图像分割里面实现的

这里 霹雳吧啦Wz 博主提供了一个好的代码,可以进行数据集划分(不需要保存划分后的数据集),然后重新实现了dataSet,并且对dataloader的 collate_fn 方法进行了实现

下面的代码只会对重点的部分做笔记

2. 主函数代码

这里的root传入的是数据集的路径

import os

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image


# 数据集所在根目录,不需要划分trainSet+valSet,这里是完整数据集
root = './data/flower'


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))    # 打印使用的设备

    # 划分训练集 + 验证集
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root, val_rate=0.1,flag=False)

    # 预处理
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 数据处理
    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))

    # 获取数据,测试的时候num_workers 设定为0
    train_loader = DataLoader(train_data_set,batch_size=batch_size,shuffle=True,num_workers=nw,
                              collate_fn=train_data_set.collate_fn)

    # 可视化数据
    plot_data_loader_image(train_loader)


if __name__ == '__main__':
    main()

这里的代码很常规,为了测试,只加载了训练集数据

2. utils 模块代码

 

这里实现了两个功能,划分数据集 + 可视化数据集

2.1 划分数据集

代码都做了注释,这块的内容慢慢调试也很容易理解,之前实现过相似的代码,只不过当时将划分好的数据集保存到不同的目录中,然后用ImageFolder调用的

def read_split_data(root: str, val_rate: float = 0.2, flag: bool = False):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)    # 断言数据集目录是否存在

    # 遍历文件夹,一个文件夹对应一个类别,flower_class = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()

    # 生成类别名称以及对应的数字索引 class_indices={'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    class_indices = dict((k, v) for v, k in enumerate(flower_class))

    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    '''生成json文件
    {
    "0": "daisy",
    "1": "dandelion",
    "2": "roses",
    "3": "sunflowers",
    "4": "tulips"
    }
    '''

    train_images_path = []   # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应label
    val_images_path = []     # 存储验证集的所有图片路径
    val_images_label = []    # 存储验证集图片对应label

    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型

    # 遍历每个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)  #  每个文件夹的路径
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]  # splitext 分离文件名和后缀名

        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)    # 0 1 2 3 4
            else:                     # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))  # 总样本个数
    print("{} images for training.".format(len(train_images_path)))             # 训练集个数
    print("{} images for validation.".format(len(val_images_path)))             # 验证集个数

    plot_image = flag   # 是否绘制图表,默认为 False
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label

2.2 可视化数据集

代码如下,

# 可视化
def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size

    # 载入名称的json文件
    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(batch_size):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(2, batch_size//2+1, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))
        plt.show()

3. dataset 数据处理

代码如下:

from PIL import Image
import torch
from torch.utils.data import Dataset


# 自定义数据集处理
class MyDataSet(Dataset):
    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):  # 返回数据集的个数
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])    # 返回路径下的PIL图像
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':   # 判断是否为 RGB 图像
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:  # transform 对 PIL 读取的图片处理
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

对这里调试的话,可以看到很多信息

 

4. collate_fn

这里的实现如下:

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

下面是之前 blog 里面写的

 

对下面进行调试,发现dataloader其实是加载batch_size 个数的list,其中没有元素是一个tuple,里面存放了图像和label

 

运行发现:将batch_size 个图像放到一个tuple里面,label也是

 

最后,

 

所以最后可视化的结果:

 

5. other 

这里博主提供了一个调试的方法,将这里勾上

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/5810.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springBoot --- mybatisPlus自动生成代码

mybatisPlus自动生成代码mybatisPlus自动生成代码pom.xmlapplication.yml自动生成代码测试主启动类生成目录结果使用插件 --- 版本要求:3.4.0 版本以上pom.xml更新mybatisplus插件版本mp报错‘AutoGenerator()‘ has private access in ‘com.baomidou.mybatisplus.…

Linux系统中使任务后台挂起不停止的命令

在使用远程SSH连接工具时,退出工具时任务也停止,相当于远程连接工具在系统开启了一个Terminal终端,服务也会随着终端的中断而停止。Linux系统也提供了服务基于后台运行的命令,是独立于终端的进程。 nohup Linux nohup Linux no…

【Python】仅7行代码实现自动化天气报时

文章目录前言一、实现步骤二、请求天气接口1.引入库2.读入数据3.钉钉通知天气预报总结前言 早上出门上班前,我总是忘记查看天气预报,以至于通勤路上下雨来了个措手不及。 回想起来,大部分人早上出门前的行为模式是固定的,那么有…

一个基于stream的EPICS IOC应用程序

本文将介绍如何开发一个基于stream的EPICS IOC应用程序,其将作为一个简单的基于消息的设备(用于EPICS stream练习的设备模拟程序_yuyuyuliang00的博客-CSDN博客中最后一个python程序模拟的设备)的IOC控制程序。 1) 按如下步骤建立这个IOC程序…

vb+access大气污染模型系统

数据模型就是按专业的要求,用数字方式描述自然界的事物或现象以及他们的关系。 我们通过对地区的具体数值和情况的观察,对大气质量状况做出分析,建立一个符合当地情况的大气污染模型,用来测量大气污染浓度,并根据污染…

在公司兢兢业业5年,被新来的自动化测试倒挂了薪资…

去年年中朋友左思右想从工作了 5 年的企业离职,离职原因很简单,待疲了,薪资也没咋涨过,新来的自动化测试钱比 Ta 高一倍。但离职 Ta 还是很忐忑的,在这个公司待得久了,自己会的东西一直是那些,业…

Python3爬虫图片抓取

在上一章中,我们已经学会了如何使用Python3爬虫抓取文字,那么在本章教程中,将通过实例来教大家如何使用Python3爬虫批量抓取图片。注:该网站目前已经更换了图片的请求方式,以下爬虫方法只能作为思路参考,已…

【Linux】进程理解与学习-程序替换

环境:centos7.6,腾讯云服务器Linux文章都放在了专栏:【Linux】欢迎支持订阅 相关文章推荐: 【Linux】冯.诺依曼体系结构与操作系统 【Linux】进程理解与学习Ⅰ-进程概念 【Linux】进程理解与学习Ⅱ-进程状态 【Linux】进程理解与学…

想拿到10k-40k的offer,这些技能必不可少!作为程序员的你了解吗

总结了一份Java架构师的技能树,希望对Java编程的同学有点帮助 Java编程的技术点: ​ 计算机基础 ​ Java高级特性 设计模式 ​ 数据库 分布式系统 ​ 注意:下文主要是我个人的总结方法经验(面试学习和刷题笔记) 01…

aws codedeploy 在ec2实例和autoscaling组上进行蓝绿部署

参考资料 https://docs.amazonaws.cn/codedeploy/latest/userguide/reference-appspec-file-structure-hooks.htmlhttps://docs.amazonaws.cn/zh_cn/codedeploy/latest/userguide/applications.html为 EC2/本地蓝/绿部署创建部署组(控制台) 部署ec2比较…

面试角度看问题:消息队列详解(万字长文,绝对值得一看)

面试角度看问题:消息队列详解前言一、消息队列是什么?二、为什么要使用消息队列?1.解耦2.异步3.削峰三、消息队列有什么缺点?1.系统可用性降低2.系统复杂度提高3.一致性问题四、如何保证消息队列的高可用?1.RabbitMQ 的…

zookeeper

目录 1.软件架构的发展 2.了解zookeeper 2.1概述 2.2zookeeper的应用场景 2.3安装zookeeper 2.4zookeeper客户端命令 3.zookeeper简单操作 3.1zookeeper的数据结构 3.2节点的分类 3.3java代码操作zookeeper节点 3.4zookeeper的watch机制 3.4.1介绍 3.4.2NodeCache…

ERD Online 4.0.11 在线数据库建模、元数据协作平台(免费、私有部署)

ERD Online 是全球第一个开源、免费在线数据建模、元数据管理平台。提供简单易用的元数据设计、关系图设计、SQL查询等功能,辅以版本、导入、导出、数据源、SQL解析、审计、团队协作等功能、方便我们快速、安全的管理数据库中的元数据。 4.0.11 ❝ :memo: fix(erd):…

5亿融资与重磅新品双发布,杉数以智能决策技术变革中国产业运营模式

2023年3月30日,由杉数科技举办的“智能决策重塑增长”2023杉数科技智能决策前沿峰会在北京举行。会上发布了杉数新一轮融资消息,同时,面向零售快消的决策优化产品计划宇宙(Planiverse)与面向工业制造的决策优化产品数弈…

Flink (四) --------- Flink 运行时架构

目录一、系统架构1. 整体构成2. 作业管理器(JobManager)3. 任务管理器(TaskManager)二、作业提交流程1. 高层级抽象视角2. 独立模式(Standalone)3. YARN 集群三、 一些重要概念1. 数据流图(Data…

C的实用笔记36——几种常用的字符串处理API(一)

0、const关键字 1、知识点:const是与存储相关的关键字,用作常量声明,修饰普通变量和指针变量,表示只读。const修饰普通变量:,修饰后变量从可修改的左值变成不可修改的左值 const修饰指针变量:分…

redis源码解析(四)——ziplist

版本:redis - 5.0.4 参考资料:redis设计与实现 文件:src下的ziplist.c ziplist.h 一、基础知识1、压缩列表的各个组成部分及详细说明2、列表节点3、encoding二、连锁更新三、ziplist.hquickList一、基础知识 压缩列表是Redis为了节约内存而开…

陌生人社交软件如何破冰?

据艾媒咨询的数据显示,2020年中国移动社交用户规模已达9.24亿人,预计2022年中国移动社交用户整体突破10亿人。而早在2020年,我国陌生人社交用户规模已经达到了6.49亿人,虽然增速有所放缓,但整体规模还是较为庞大。 艾媒…

操作系统笔记——进程管理

操作系统笔记——进程管理2. 进程管理2.1 进程与线程2.1.1 进程的引入前趋图程序的顺序执行程序的并发执行2.1.2 进程的定义及描述进程的定义进程的特征进程和程序的关系进程与作业的区别进程的组成2.1.3 进程的状态与转换进程的5种基本状态进程的状态的相互转换2.1.4 进程的控…

java常见锁策略分享(包括cas和synchronized的优化)

前言 锁策略学习思维导图: 1.常见锁策略 ① 乐观锁和悲观锁 ● 它们是根据锁冲突的预测,如果预测锁冲突比较小,那就是乐观锁,反之,就是悲观锁. ● 举个例子:高考前夕,我总觉得高考题会很难,然后拼命做各种科目的题,全副武装的去应对高考,而我妈则觉得高考只是人生的一个阶段而…