在Pytorch中自定义dataset读取数据

这里使用的是经典的花分类数据集

下载地址:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

下载结束后进行解压,可以得到五种不同种类花的图片,如上图所示

主函数 main


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for step, data in enumerate(train_loader):
        images, labels = data

其中,

train_images_path, train_images_label, val_images_path, val_images_label  = read_split_data ( root )

传入参数 root(就是该数据集所在的路径),没有传入参数val_rate就取其默认值0.2( 即验证集占整个数据集的 20% ), 调用函数 read_split_data

def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label

运行上述代码, 得到 class_indices.json 文件,该文件存储了类别名称以及每个类别对应的索引

设置变量 plot_image 为True,可以将每个类别的样本数以柱状图的形式可视化出来

函数 read_split_data 执行结束后,返回四个列表  : train_images_path 、train_images_label 、val_images_path 和 val_images_label,分别表示训练集的图像和标签路径以及验证集的图像和标签路径,对数据集完成了训练集和验证集的划分!

然后对训练集和验证集中的数据进行数据预处理,比如裁剪、翻转、归一化等等操作

接下来,重点来了!

train_data_set = MyDataSet(images_path=train_images_path,
                           images_class=train_images_label,
                           transform=data_transform["train"])

传入训练集图像的路径列表、标签列表以及数据预处理的方法,对类 MyDataSet 进行初始化,得到类 MyDataSet 的实例对象 train_data_set

MyDataSet 是一个自定义的数据类,代码如下:

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

该类继承类Dataset,主要实现初始化函数__init__( )、计算数据集中样本数量的函数__len__( )、根据索引返回相应的图片和标签的函数__getitem__( ) 以及 collate_fn( ) 函数

我想要重点阐述一下关于函数 collate_fn( ) 函数的作用

collate_fn( ) 函数决定了如何将数据进行打包处理

@staticmethod
def collate_fn(batch):
    # 官方实现的default_collate可以参考
    # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
    images, labels = tuple(zip(*batch))
    images = torch.stack(images, dim=0)
    labels = torch.as_tensor(labels)
    return images, labels

传入函数的参数 batch 是由 (images,labels) 组成的一个个的元组

如果在此处设置batch_size的值为8,那么这个函数就从数据集中获取8张图片以及这8张图片所对应的标签

可以设置断点来看一下:

因为 batch_size 取 8,所以可以看到 batch 是一个长度为8的列表,列表是由8个元组元素组成的,每个元组是由图像和其所对应的标签组成的

最后,通过 DataLoader 从实例化对象 train_data_set 中加载数据,打包成一个一个 batch 送入网络中进行训练

train_loader = torch.utils.data.DataLoader(train_data_set,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=nw,
                                           collate_fn=train_data_set.collate_fn)

这样就可以得到用于加载训练数据的数据加载器 train_loader

可以将 数据加载器 train_loader 传给函数,通过调用函数 plot_data_loader_image 后

plot_data_loader_image(train_loader)

这样就能可视化出数据加载器  train_loader 中的内容,如图所示(此处需要将 num_workers 设置为0)

 

 

 

                                              ……

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/256956.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

zookeeper:启动后占用8080端口问题解决

ZooKeeper是一个分布式的,开放源码的分布式应用程序协调服务。它为分布式应用提供一致性服务的软件,提供的功能包括:配置维护、域名服务、分布式同步、组服务等。 我们经常在运行zookeeper服务时,不需要配置服务端口,…

二叉树的最大深度(LeetCode 104)

文章目录 1.问题描述2.难度等级3.热门指数4.解题思路方法一:深度优先搜索GolangC 方法二:广度优先搜索GolangC 参考文献 1.问题描述 给定一个二叉树 root ,返回其最大深度。 叉树的「最大深度」是指从根节点到最远叶子节点的最长路径上的节…

Apollo开放平台9.0让自动驾驶开发者轻松上手

文章目录 平台架构:基础环境:开始使用:体验心得: 在自动驾驶技术飞速发展的今天,成为这个领域的一名开发者是一次挑战、一次冒险,更是一次心灵之旅。作为这个领域的先锋之一,Apollo开放平台9.0于12月19日发…

TSINGSEE青犀边缘AI计算基于车辆结构化数据的车辆监控方案

随着人工智能技术的不断发展,边缘AI技术逐渐成为智能交通领域的研究热点。其中,基于边缘AI的车辆结构化数据技术与车辆监控系统是实现智能交通系统的重要手段之一。为了满足市场需求,TSINGSEE青犀边缘AI智能分析网关/视频智能分析平台推出了一…

【百度PARL】强化学习笔记

文章目录 强化学习基本知识一些框架Value-based的方法Q表格举个例子 强化的概念TD更新 Sarsa算法SampleSarsa Agent类 On_policy vs off_policy函数逼近与神经网络DQN算法DQN创新点DQN代码实现model.pyalgorithm.pyagent.py总结:举个例子 实战 视频:世界…

【SQL】根据年月,查询月份中每一天的数据量

传入YYYY-MM-01&#xff0c;查询这个月中每一天的数据量&#xff0c;没有数据的天数用0表示 WITH RECURSIVE DateRange AS (SELECT :startDate AS DateUNION ALLSELECT DATE_ADD(Date, INTERVAL 1 DAY) FROM DateRangeWHERE Date < LAST_DAY(:startDate) ) SELECTdr.Date,CO…

docker中如何使用Arthas

docker中如何使用Arthas 一、操作步骤1、首先拷贝arthas包下来&#xff1a;2、其次选中你需要查看的容器ID&#xff1a;3、拷贝arthas程序包到容器目录下&#xff1a;4、进入到容器目录5、进入到第3步映射到容器的路径&#xff0c;并使用ll查看是否存在 arthas-boot.jar6、使用…

全球移动通信(2G/3G/4G/5G)频谱分布情况

一、概述 随着通信技术的不断发展&#xff0c;全球各国都在积极推进2G、3G、4G、5G网络的建设和应用。根据FCC统计&#xff0c;目前全球移动通信频谱分布如下&#xff1a; 二、分布 &#xff08;一&#xff09;俄罗斯 2G&#xff1a;主要使用900MHz和1800MHz两个频段。其中&…

Postman接口测试之Postman常用的快捷键

作为一名IT程序猿&#xff0c;不懂一些工具的快捷方式&#xff0c;应该会被鄙视的吧。收集了一些Postman的快捷方式&#xff0c;大家一起动手操作~ 简单操作 xc 请求 操作MAC系统windows系统请求网址 ⌘L Ctrl L 保存请求 ⌘S Ctrl S 保存请求为 ⇧⌘S Ctrl Shift S发送…

云原生之深入解析Kubernetes集群发生网络异常时如何排查

一、Pod 网络异常 网络不可达&#xff0c;主要现象为 ping 不通&#xff0c;其可能原因为&#xff1a; 源端和目的端防火墙&#xff08;iptables, selinux&#xff09;限制&#xff1b; 网络路由配置不正确&#xff1b; 源端和目的端的系统负载过高&#xff0c;网络连接数满…

如何搭建一个买衣服的微信小程序商城

随着移动互联网的普及&#xff0c;微信小程序商城已经成为众多商家开展线上业务的重要平台。本文将介绍如何搭建一个卖衣服的微信小程序商城&#xff0c;帮助您实现线上业务的拓展。 第一步&#xff1a;登录乔拓云平台进入商城后台管理页面 在浏览器中搜索乔拓云平台并登录&a…

广汽本田售后服务技术技能竞赛总决赛

从2001年开始&#xff0c;广汽本田售后服务技术技能大赛年年举办&#xff0c;层层筛选、择优选拔&#xff0c;以赛促练全面提升服务领域一线人员的业务能力。 广汽本田售后服务技术技能大赛在赛程设置、考核内容和形式等方面进行了全面升级与强化&#xff0c;创新突破服务精英的…

科士达新能源荣获CTC国检集团“产品碳足迹证书”

2023年12月14-15日&#xff0c;“2023光伏行业年度大会”在江苏省宿迁市召开&#xff0c;行业主管部门、行业组织、知名专家和光伏企业等代表莅临现场。科士达新能源受邀出席&#xff0c;并在同期举办的”光伏产品碳足迹与碳中和研讨会”上&#xff0c;荣获CTC国检集团“产品碳…

WebMvcConfigurer接口详解及使用方式(Spring-WebMvc)

简介 如下图所示WebMvcConfigurer是spring-webmvc jar包下的一个接口&#xff0c;spring-webmvc jar包又来源于spring-boot-starter-web&#xff0c;所以要使用WebMvcConfigurer要引入spring-boot-starter-web依赖。WebMvcConfigurer接口提供了常用的web应用拦截方法。通过实现…

Elasticsearch 索引生命周期和翻滚 (rollover) 策略

Elasticsearch 是搜索引擎中的摇滚明星&#xff0c;它的蓬勃发展在于使你的数据井井有条且速度快如闪电。 但当你的数据成为一场摇滚音乐会时&#xff0c;管理其生命周期就变得至关重要。 正确使用索引生命周期管理 (ILM) 和 rollover 策略&#xff0c;你的后台工作人员可确保顺…

有损编码——Wyner-Ziv理论

有损编码是一种在信息传输和存储中常见的编码技术&#xff0c;其主要目标是通过牺牲一定的信息质量&#xff0c;以换取更高的压缩效率。相比于无损编码&#xff0c;有损编码可以在保证一定程度的信息还原的前提下&#xff0c;使用更少的比特数来表示信息。Wyner-Ziv理论是一种重…

台湾虾皮卖什么比较好

虾皮&#xff08;Shopee&#xff09;平台在台湾地区广受欢迎&#xff0c;吸引了大量的消费者和卖家。该平台上有许多热销产品类别&#xff0c;这些产品在台湾市场上具有巨大的销售潜力。在本文中&#xff0c;我们将介绍虾皮平台上一些热门的产品类别&#xff0c;并提供一些建议…

记录一次云服务器被攻击事件

今天去登录华为云平台的时候&#xff0c;发现服务器的cpu涨到了百分之九十九&#xff0c;这个也太不正常了&#xff0c;我自己就只部署了一个页面&#xff0c;怎么会飚这么高呢&#xff1f; 然后&#xff0c;我就去找原因&#xff0c;使用top命令&#xff0c;去查看到底是谁占用…

基于Java SSM框架实现宠物医院信息管理系统项目【项目源码】

基于java的SSM框架实现宠物医院信息管理系统演示 java简介 Java语言是在二十世纪末由Sun公司发布的&#xff0c;而且公开源代码&#xff0c;这一优点吸引了许多世界各地优秀的编程爱好者&#xff0c;也使得他们开发出当时一款又一款经典好玩的小游戏。Java语言是纯面向对象语言…

高精度红蜡3D打印加工服务珠宝首饰3D打印微型医疗器械3D打印-CASAIM

随着科技的飞速发展&#xff0c;3D打印技术已经逐渐渗透到各个领域&#xff0c;成为现代制造业的重要组成部分。而在众多的3D打印材料中&#xff0c;高精度红蜡作为一种具有优异性能的材料&#xff0c;适合对精度要求高的小尺寸模型&#xff0c;用于快速铸造&#xff0c;如珠宝…