PyTorch|Dataset与DataLoader使用、构建自定义数据集

文章目录

  • 一、Dataset与DataLoader
  • 二、自定义Dataset类
    • (一)\_\_init\_\_函数
    • (二)\_\_len\_\_函数
    • (三)\_\_getitem\_\函数
    • (四)全部代码
  • 三、将单个样本组成minibatch(DataLoader)
    • (一)PyTorch的DataLoader源码
      • 1、DataLoader的参数
      • 2、init函数
      • 3、iter函数
    • (二)使用DataLoader遍历


一、Dataset与DataLoader

PyTorch提供的两个常用数据API:

  • torch.utils.data.Dataset:用于处理单个训练样本,读取数据特征、size、标签等,并且包括数据转换等;
  • torch.utils.data.DataLoader:DataLoader在Dataset周围重载一个可迭代对象,以便轻松访问样本。

官方案例: Fashion-MNIST数据集
torchvision:torch的一个视觉库,将torchvision中的datasets导入进来,就能获得其中的各种数据集

FashionMNIST图像存储在目录img_dir中,标签存储在CSV文件annotations_file中

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

对上述数据集进行可视化:

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

二、自定义Dataset类

  • 构建自定义的Dataset类,需要继承TensorFlow的官方dataset类
  • 自定义Dataset类必须实现三个函数:__init__,__len__和__getitem__

pytorch中的dataset类是在pytorch的torch下的utils之下的data文件夹里有一个dataset.py
在这里插入图片描述

(一)__init__函数

包含图像、注释文件和两个转换:

  • annotations_file:标注文件
  • img_dir:图像目录
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file) #标签存储在CSV文件annotations_file中
    self.img_dir = img_dir #FashionMNIST图像存储在目录img_dir中
    self.transform = transform #图像转换
    self.target_transform = target_transform

(二)__len__函数

返回数据集的样本数(就是img_labels的长度)

def __len__(self):
    return len(self.img_labels)

(三)__getitem_\函数

输入索引index,getitem函数从数据集中加载并返回对应index的一个样本:

def __getitem__(self, idx):
		#img_labels的第index行第0列标注了对应的照片文件名称
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path) #使用read_image将图像转换为张量
    label = self.img_labels.iloc[idx, 1] #从self中的csv数据中检索相应的标签
    #调用转换函数
    if self.transform: 
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label #返回张量图像和相应的标签

(四)全部代码

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

三、将单个样本组成minibatch(DataLoader)

(一)PyTorch的DataLoader源码

1、DataLoader的参数

DataLoader通常是在torch.utils.data下
在这里插入图片描述
常用的参数有:

  • dataset(数据集):需要提取数据的数据集,Dataset对象
  • batch_size(批大小):每一次装载样本的个数,int型
  • shuffle:是否打乱数据顺序
  • sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
  • num_workers:进行数据加载时使用单个进程还是多进程进行加载,多进程意为加载速度更快,一般默认为0,表示使用主进程进行加载
  • collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数,一般用于对于一个batch进行后处理
  • pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
  • drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

2、init函数

主要做了三件事:构建sampler、构建batch_sampler、构建collate_fn

定义属性:
在这里插入图片描述
如果设置了自定义的sampler然后又设置了shuffle=true,这种情况是没有意义的:
(shuffle是官方自定义的一个随机sampler)
在这里插入图片描述
设置了batch_sampler的情况下,就不需要设置batch_size、shuffle、sampler和drop_last了:
在这里插入图片描述
如果没有设置sampler,则先判断数据集类型,如果使用的是map-style(else逻辑),就根据是否设置shuffle来选择pytorch内置的sampler:
在这里插入图片描述
设置了batch_size但是没有设置batch_sampler时,会使用内置的BatchSampler:
在这里插入图片描述
如果没有设置collate_fn,就判断auto_collation是否设置(auto_collation是根据batch_sampler是否是None来设置的,如果batch_sampler不是none,auto_collation就是true),default_collate是将batch作为输入,batch输出,并没有对数据做额外处理:
在这里插入图片描述

3、iter函数

iter函数返回的是get_iterator的值:
在这里插入图片描述
get_iterator根据num_workers的设置选择对应的内置DataLoaderIter:
在这里插入图片描述

所以可知,iter函数最终返回的是一个dataloaderiter对象,以SingleProcessDataLoaderIter为例,类里有next_data函数:
在这里插入图片描述
SingleProcessDataLoaderIter类是继承了BaseDataLoaderIter类,BaseDataLoaderIter类中的next函数就是使用了子类中的next_data:
在这里插入图片描述

(二)使用DataLoader遍历

根据上述源码分析,就可以对dataloader去迭代iter之后调用next函数来获得每一批次的数据:

  • 通过DataLoader实现对于数据集的遍历,每次遍历会得到一个batch的数据,这里设置batch_size为64:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  • iter函数将train_dataloader变成一个迭代器,使用next函数可以以此从迭代器中生成一个一个的批次:
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述由于batch_size=64,因此最终返回的Feature batch shape以及Labels batch shape均为64。


参考:
PyTorch官方文档:Datasets & DataLoaders
5、深入剖析PyTorch DataLoader源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/520885.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

信息论基础:串联信道

串联信道 大学时候看过一期湖南卫视《快乐大本营》,那时候的主持人是何炅和李湘。节目的一个环节是邀请五名观众上台做猜谜游戏。五人带上耳机,坐在一排椅子上,两两中间隔着挡板,好像并排在一起上厕所。李湘把一部电影的名字写在…

Redis集群三种模式

一、Redis集群的三种模式 Redis有三种模式,分别是主从复制、哨兵模式、cluster 主从复制:主从复制是高可用Redis的基础,哨兵和集群都是在主从复制基础上实现高可用的。主从复制主要实现了数据的多机备份,以及对于读操作的负载均衡和简单的故障…

国家开放大学电大《钢结构》形考任务答案

电大搜题 多的用不完的题库,支持文字、图片搜题,包含国家开放大学、广东开放大学、超星等等多个平台题库,考试作业必备神器。 公众号 答案:更多答案,请关注【电大搜题】微信公众号 答案:更多答案&#x…

【windows】--- nginx 超详细安装并配置教程

目录 一、下载 nginx二、安装三、查看是否安装成功四、配置五、关闭 nginx六 负载均衡七 配置静态资源1. 根目录下的子目录(root)2.完全匹配(alias) 刷新配置(不必重启nginx)八、后端鉴权 一、下载 nginx 打开 nginx 的官网:nginx.org/ &…

K8S基于containerd做容器从harbor拉取镜

实现创建pod时,通过指定harbor仓库里的镜像来运行pod 检查:K8S是不是用containerd做容器运行时,以及containerd的版本是不是小于1.6.22 kubectl get nodes -owide1、如果containerd小于 1.6.22,需要先升级containerd 先卸载旧的…

力扣Lc28---- 557. 反转字符串中的单词 III(java版)-2024年4月06日

1.题目描述 2.知识点 1)用StringBuilder的方法 实现可变字符串结果 最后返回的时候用.toString的方法 2)在Java中使用StringBuilder的toString()方法时,它会返回StringBuilder对象当前包含的所有字符序列的字符串表示。 在我们的例子中,sb是一个Stri…

初心护蕾 珍视青春

(通讯员:赵灿飞 图:杨美、孙红浪) 为进一步加强未成年人合法权益保护工作,提高未成年人的自我安全防范意识和能力,培养未成年人正确的性观念和自我保护意识,促进健康的人际关系&#xff0c…

Debian安装宝塔教程

宝塔面板是一款非常受欢迎的服务器管理软件,它以其强大的功能、简洁的操作界面和丰富的应用生态而闻名。宝塔面板不仅能够帮助用户轻松管理服务器,还能够提供网站、数据库、FTP、备份等多种服务,是服务器管理的得力助手。 宝塔面板的特色 1.…

【Spring】之AOP详解

AOP 什么是AOP? AOP:Aspect Oriented Programming,面向切面编程。 切面指的是某一类特定问题,因此面向切面编程也可以理解为面向特定方法编程。例如,在任何一个系统中,总有一些页面不是用户可以随便访问…

设置你的第一个React应用

目录 一、React入门 1.1 你好React 1.2 创建React 1.3 应用结构 二、总结 2.1 定义组件 2.2 组件源码 三、组件详解 注意事项 3.1 组件三部曲 3.2 组件通信 —— props 3.3 对象数组迭代 —— map() 3.4 事件处理 3.5 钩子函数 —— useState() 初次学习最终效果…

Cortex-M7 内存映射模型

1 前言 如图1所示, Cortex-M7最大支持4GB的内存寻址,并对内存映射(memory map)做了初步的规定,将整个内存空间划分为了多个内存区域(region)。每个内存区域有着既定的内存类型(memory type)和内存属性(memory attribute),这两者决…

AI - ComfyUI过程图(3)

ComfyUI 比 Stable Diffusion WebUI更灵活,而且可以看到处理过程,能增加节点进行后续处理,因而更强大。 看看下面一张图的变化,一开始惨不忍睹。 使用 Ultimate SD Upscale 提升分辨率 超精后脸部有改善: 脸部比较…

递归实现指数型枚举(acwing)

题目描述: 从 1∼n 这 n 个整数中随机选取任意多个,输出所有可能的选择方案。 输入格式: 输入一个整数 n。 输出格式: 每行输出一种方案。 同一行内的数必须升序排列,相邻两个数用恰好 1 个空格隔开。 对于没有…

一周年纪念

文章目录 机缘:命运之门收获---知识之心日常---灵魂之窗成就 — 自我之光憧憬 — 未来之路 机缘:命运之门 “人生是由一连串的选择组成,而真正的成长,往往始于最具挑战性的决定。” —— 这句话恰如其分地概括了我选择跨考计算机的…

自动驾驶执行层 - 线控底盘基础原理(非常详细)

自动驾驶执行层 - 线控底盘基础原理(非常详细) 附赠自动驾驶学习资料和量产经验:链接 1. 前言 1.1 线控的对象 在自动驾驶行业所谓的“感知-定位-决策-执行”的过程中,在末端的执行层,车辆需要自主执行决策层所给出的指令,具体…

2024最全ChatGPT支持GPTs使用教程+Prompt应用预设词教程

使用指南 直接复制使用 可以前往已经添加好Prompt预设的AI系统测试使用(可自定义添加使用) https://ai.sparkaigf.com 现已支持GPTs 雅思写作考官 我希望你假定自己是雅思写作考官,根据雅思评判标准,按我给你的雅思考题和对应…

【多模态融合】MetaBEV 解决传感器故障 3D检测、BEV分割任务

前言 本文介绍多模态融合中,如何解决传感器故障问题;基于激光雷达和相机,融合为BEV特征,实现3D检测和BEV分割,提高系统容错性和稳定性。 会讲解论文整体思路、模型框架、论文核心点、损失函数、实验与测试效果等。 …

Python 基于列表实现的通讯录管理系统(有完整源码)

目录 通讯录管理系统 PersonInformation类 ContactList类 menu函数 main函数 程序的运行流程 完整代码 运行示例 通讯录管理系统 这是一个基于文本的界面程序,用户可以通过命令行与之交互,它使用了CSV文件来存储和读取联系人信息,这…

浅谈Redis和一些指令

浅浅谈一谈Redis的客户端 Redis客户端 Redis也是一个客户端/服务端结构的程序。 MySQL也是一个客户端/服务端结构的程序。 Redis的客户端也有多种形态 1.自带命令行客户端 redis-cli 2.图形化界面的客户端(桌面程序,web程序) 像这样的图形…

3d代理模型怎么转换成标准模型---模大狮模型网

在当今的虚拟世界中,3D建模技术被广泛运用于游戏开发、电影制作、工业设计等领域。在3D建模过程中,有时会遇到需要将代理模型转换成标准模型的情况。模大狮将从理论和实践两方面,介绍如何将3D代理模型转换成标准模型,以帮助读者更…