torch.utils.data

整体架构

平时使用 pytorch 加载数据时大概是这样的:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
	def __init__(self):
		self.data = [1, 2, 3, 4, 5]

	def __getitem__(self, idx):
		return self.data[idx]

	def __len__(self):
		return len(self.data)

def collate_fn(batch):
	return np.array(batch)

dataset = ExampleDataset()  # create the dataset
dataloader = DataLoader(
	dataset=dataset,
	batch_size=2,
	shuffle=True,
	num_workers=4,
	collate_fn=collate_fn
)
for datapoint in dataloader:
	print(datapoint)
  1. 继承 Dataset 类,定义一个迭代器,包含两个魔法方法:__getitem__(self, idx)__len__(self),分别实现如何获取一条数据和如何设定数据长度;
  2. 定义 collate_fn 函数,设定如何组织一个 batch
  3. 实例化 Dataset,并和 collate_fn 一起传入 DataLoader,参数 batch_size 设置批大小、shuffle 设置是否打乱、num_workers 设置并行加载数据的进程数。

然而,背后到底干了什么,我们不清楚,甚至遇到 DataLoader 的如 samplerbatch_samplerworker_init_fn 的其他参数,就会懵逼。那就看一看官方文档,了解一下 torch.utils.data 是如何工作的。


上图是数据加载的整体框架图,官网说 DataLoader 组合datasetsampler,多个 workers 根据 dataset 提供的数据副本sampler 提供的 keys 并行地加载数据,并通过 collate_fn 组成 batch 供用户迭代。需要注意的有:

  1. 每个 worker 持有数据的一个副本,故占用内存主线程内存 * num_workers”;
  2. 即使用户不提供 sampler 对象 (通常不提供),DataLoader 也会根据 shuffle 参数创建一个默认的 sampler 对象;一旦提供了,其前路的 shuffle 参数不能为 True (不提供就好);
  3. 即使用户不提供 batch_sampler 对象 (通常不提供),DataLoader 也会根据 batch_sampler, drop_last 参数创建一个默认的 batch_sampler 对象;一旦提供了,其前路的 shuffle, drop_last 不能为 Truebatch_size 必须为 1 1 1sampler 必须为 None,因为创建 BatchSampler 时已经有了这些参数;

    本质上是把创建 batch_sampler 的活拉出来由用户在 DataLoader 外自定义地做了。

Dataset

分为两种:map-styleiterable-style。前者的数据可通过 [idx or key] 访问,后者的数据只能通过迭代器 next 一个个访问。所以上面架构中的采样器是对于 map-style 数据集说的iterable-style 的数据集的访问顺序由迭代器决定。

Sampler

torch.utils.data.Sampler 的子类或 Iterable,两个例子:

class AccedingSequenceLengthSampler(tu_data.Sampler[int]):
	def __init__(self, data: List[str]) -> None:
		super().__init__()
		self.data = data

	def __len__(self) -> int:
		return len(self.data)

	def __iter__(self) -> Iterator[int]:
		"""
		:return: 实现了按数据长短顺序访问数据集
		"""
		sizes = torch.tensor([len(x) for x in self.data])
		yield from torch.argsort(sizes).tolist()


class AccedingSequenceLengthBatchSampler(tu_data.Sampler[List[int]]):
	def __init__(self, data: List[str], batch_size: int) -> None:
		super().__init__()
		self.data = data
		self.batch_size = batch_size

	def __len__(self) -> int:
		return (len(self.data) + self.batch_size - 1) // self.batch_size

	def __iter__(self) -> Iterator[List[int]]:
		sizes = torch.tensor([len(x) for x in self.data])
		for batch in torch.chunk(torch.argsort(sizes), len(self)):  # 按块遍历
			yield batch.tolist()

Batch

batch_sampler 提供一批下标,取得一批数据后由 collate_fn 将这批数据整合:

if collate_fn is None:
	if self._auto_collation:
		collate_fn = _utils.collate.default_collate
	else:  # self.batch_sampler is None: (batch_size is None) and (batch_sampler is None)
		collate_fn = _utils.collate.default_convert

分两种情况:

  • automatic batching is disabled:调用 default_convert 函数简单地将 NumPy arrays 转化为 PyTorch Tensor;
  • automatic batching is enabled:调用 default_collate 函数,转化会变得复杂一点:
from torch.utils import data as tu_data
import collections

# %% Example with a batch of `int`s:
tu_data.default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])

# %% Example with a batch of `str`s:
tu_data.default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']

# %% Example with `Map` inside the batch:
tu_data.default_collate([
	{'A': 0, 'B': 1},
	{'A': 100, 'B': 100}
])
# {'A': tensor([0, 100]), 'B': tensor([1, 100])}, 同 key 的合并了

# %% Example with `NamedTuple` inside the batch:
Point = collections.namedtuple('Point', ['x', 'y'])
tu_data.default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1])), 同 name 的合并了, 大概和 dict 一样吧

# %% Example with `Tuple` inside the batch:
tu_data.default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate

# %% Example with `List` inside the batch:
tu_data.default_collate([[0, 1], [2, 3]])  # [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate, 并没有变成二维 tensor

Multi-process Data Loading

dataset, collate_fn, and worker_init_fn are passed to each worker,大概能说明 batch 是在子进程内部合成的。

有一个需要注意的地方是内存增长问题,当 __get_item__(self, key) 访问数据时,由于 Python 对象的 refcount 机制,数据会不断地复制,从而内存爆炸。但这里说解决 number of workers * size of parent process 问题,就不追究了,反正尽量用 numpy 或 pytorch tensor 吧。
iterable-style datasets 的随机性

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/391850.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux: GDB 调试工具

目录 概念: Linux 下 debug 和 release 的区别: GDB 的使用 : 激活和进入工作模式: 查看文件的内容: 运行调试的文件: 打断点: 查看断点: 删除断点: 禁用断点…

17.3.1.3 灰度

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 灰度的算法主要有以下三种: 1、最大值法: 原图像:颜色值color(R,G,B&a…

wayland(xdg_wm_base) client 使用 dmabuf 最简实例

文章目录 前言一、zwp_linux_dmabuf_v1 协议二、wayland client 使用 zwp_linux_dmabuf_v1 协议传递dma-buf代码实例1. wayland_dmabuf.c 代码实例2. xdg-shell-protocol.c 和 xdg-shell-client-protocol.h3. linux-dmabuf-unstable-v1-client-protocol.h 和 linux-dmabuf-unst…

如何在JavaScript中使用大于和小于运算符

在你的 JavaScript 程序中,你经常需要比较两个值,以确定一个是否大于另一个或小于另一个。这就是大于和小于运算符派上用场的地方。 在本文中,我们将通过代码示例更详细地介绍如何使用这些运算符。 (本文内容参考:ja…

Stable Diffusion 模型下载:Beautiful Realistic Asians(美丽真实的亚洲人)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八案例九案例十 下载地址 模型介绍 Beautiful Realistic Asians(BRA)模型是由作者自己训练…

阿里云服务器租用价格 2024年新版活动报价及租用收费标准

2024年最新阿里云服务器租用费用优惠价格表,轻量2核2G3M带宽轻量服务器一年61元,折合5元1个月,新老用户同享99元一年服务器,2核4G5M服务器ECS优惠价199元一年,2核4G4M轻量服务器165元一年,2核4G服务器30元3…

NumPyML 源码解析(四)

numpy-ml\numpy_ml\neural_nets\utils\__init__.py """ 神经网络特定的常见辅助函数。neural_nets.utils 模块包含神经网络特定的辅助函数,主要用于处理 CNNs。 """# 从当前目录下的 utils 模块中导入所有内容 from .utils import *…

天锐绿盾|防泄密系统|计算机文件数据\资料安全管理软件

“天锐绿盾”似乎是一款专注于防泄密和计算机文件数据/资料安全管理的软件。在信息安全日益受到重视的今天,这样的软件对于保护企业的核心数据资产和防止敏感信息泄露至关重要。 通用地址:www.drhchina.com 防泄密系统的主要功能通常包括: 文…

组合数的计算

1.由定义式直接算&#xff1a;n!/m!*(n-m)! #include <iostream> using namespace std; long long combine(long long m,long long n ){long long result1;for(int i1;i<n1;i){//n!result*i;}for(int i1;i<m1;i){//n!/m!result/i;}for(int i1;i<n-m1;i){//n!/(…

红蓝对抗:网络安全领域的模拟实战演练

引言&#xff1a; 随着信息技术的快速发展&#xff0c;网络安全问题日益突出。为了应对这一挑战&#xff0c;企业和组织需要不断提升自身的安全防护能力。红蓝对抗作为一种模拟实战演练方法&#xff0c;在网络安全领域得到了广泛应用。本文将介绍红蓝对抗的概念、目的、过程和…

问卷设计初探:题目类型概览与注意事项梳理

问卷法常被人们应用于社会调查中&#xff0c;它能反馈出最真实的社会信息。所以&#xff0c;很多企业为了最大程度地了解市场&#xff0c;也经常使用问卷调查法进行研究。不过&#xff0c;想要发挥出问卷法的最大用处&#xff0c;前提是要将问卷设计规范并且可量化。 想要设计…

【漏洞复现】企语iFair协同管理系统任意文件读取漏洞

Nx01 产品简介 企语iFair协同管理系统是一款专业的协同办公软件&#xff0c;该管理系统兼容性强&#xff0c;适合多种企业类型。 Nx02 漏洞描述 企语iFair协同管理系统存在任意文件读取漏洞&#xff0c;未经身份认证的攻击者可以通过此漏洞获取服务器敏感信息。 Nx03 产品主页…

租房招聘|在线租房和招聘平台|基于Springboot的在线租房和招聘平台设计与实现(源码+数据库+文档)

在线租房和招聘平台目录 目录 基于Springboot的在线租房和招聘平台设计与实现 一、前言 二、系统功能设计 三、系统实现 1、房屋管理 2、招聘管理 3、平台资讯管理 4、平台资讯类型管理 四、数据库设计 1、实体ER图 六、论文参考 七、最新计算机毕设选题推荐 八、源…

小白必看,总结前端所有主流的构建工具,webpack / vite / roollup / esbuild,包含源码,建议关注+收藏

前言 本篇文章旨在总结前端常见的构建工具&#xff0c;构建工具是前端工程化中的重要的组成部分。 在实际项目中&#xff0c;我们初始化项目&#xff0c;一般是使用脚手架命令一键生成的&#xff0c;比如说使用 create-vue 初始化 vue 项目的时候&#xff0c;就会默认使用 vi…

力扣 123. 买卖股票的最佳时机 III

题目来源&#xff1a;https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-iii/description/ C题解&#xff1a;动态规划。至多买卖两次&#xff0c;这意味着可以买卖一次&#xff0c;可以买卖两次&#xff0c;也可以不买卖。 一天一共就有四个状态&#xff1a; 第…

报文鉴别、实体鉴别

目录 鉴别 1 报文鉴别 1.1 用数字签名进行鉴别&#xff08;原理&#xff09; 可保证机密性的数字签名 1.2 密码散列函数 MD5 算法 MD5 算法计算步骤 安全散列算法 SHA-1 1.3 用报文鉴别码实现报文鉴别 用报文鉴别码 MAC 鉴别报文 使用已签名的报文鉴别码 MAC 对报…

论文阅读:MotionNet基于鸟瞰图的自动驾驶联合感知和运动预测

MotionNet: Joint Perception and Motion Prediction for Autonomous Driving Based on Bird’s Eye View Maps MotionNet&#xff1a;基于鸟瞰图的自动驾驶联合感知和运动预测 论文地址&#xff1a;MotionNet: Joint Perception and Motion Prediction for Autonomous Drivi…

【AI视野·今日CV 计算机视觉论文速览 第294】Mon, 22 Jan 2024

AI视野今日CS.CV 计算机视觉论文速览 Mon, 22 Jan 2024 Totally 64 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computer Vision Papers Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data Authors Lihe Yang, Bingyi Kang, Zilong Huang, X…

vue3-组合式 API

什么是组合式 API&#xff1f; 组合式 API (Composition API) 是一系列 API 的集合&#xff0c;使我们可以使用函数而不是声明选项的方式书写 Vue 组件。它是一个概括性的术语&#xff0c;涵盖了以下方面的 API&#xff1a; 响应式 API&#xff1a;例如 ref() 和 reactive()&a…

如何在UI自动化测试中加入REST API的操作

1、问题 当我们描述一个“好的自动化测试用例”时&#xff0c;经常出现标准是&#xff1a; 精确 自动化测试用例应该测试一件事&#xff0c;只有一件事。与测试用例无关的应用程序的某个部分中的错误不应导致测试用例失败。 独立 自动化测试用例不应该受测试套件中任何其他…