Pytorch:torch.utils.data.DataLoader()

如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297

DataLoader加载和迭代数据集

Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

from torch.utils.data import DataLoader

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )

以U-Net中的代码为例:
具体详见:U-Net代码复现

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. 数据集

**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )

2. 对数据进行批处理

batch_size (int, optional)how many samples per batch to load(default: 1).

3. 在 CUDA 张量上加载数据

pin_memory(bool, optional)If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。

4.允许多进程

num_workers (int, optional)how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。

以下是我看到的一个解释,原文链接:https://blog.csdn.net/vonct/article/details/130263743
下面说一下个人的理解,在初始化 dataloader对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)

每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个(即一帧),并将其放到该worker独有的内存队列中。 要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。

当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。

这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解

在这里插入图片描述

5.合并数据集

collate_fn (Callable, optional)merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:

def __getitem__(self, idx):
        name = self.ids[idx]
        mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
        img_file = list(self.images_dir.glob(name + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)

6.数据采样

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.

sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。

比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。

此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。

因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

原文链接:https://blog.csdn.net/vonct/article/details/130263743

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/199689.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ArcGIS Pro、Python、USLE、INVEST模型等多技术融合的生态系统服务构建生态安全格局

近年来,由于社会经济的快速发展和人口增长,社会活动对环境的压力不断增大,人地矛盾加剧。虽然全球各国在生态环境的建设和保护上已取得不少成果,但还是未从根本上转变生态环境的恶化趋势;生态破坏、环境退化、生物多样…

rss服务搭建记录

layout: post title: RSS subtitle: vps搭建RSS服务 date: 2023-11-27 author: Sprint#51264 header-img: img/post-bg-universe.jpg catalog: true tags: - 折腾 文章目录 引言RSShub-dockerRSS-radarFreshrssFluent reader获取fever api配置Fluent Reader同步 结语 引言 一个…

【代码】数据驱动的多离散场景电热综合能源系统分布鲁棒优化算法matlab/yalmip+cplex/gurobi

程序名称:数据驱动的多离散场景电热综合能源系统分布鲁棒优化算法 实现平台:matlab-yalmip-cplex/gurobi 代码简介:数据驱动的分布鲁棒优化算法。考虑四个离散场景,模型采用列与约束生成(CCG)算法进行迭代求解,场景分…

C#键盘钩子(Hook)拦截器的使用

引言 键盘钩子(Hook)是一种机制,允许程序捕获和处理操作系统中的键盘输入。在C#中,我们可以使用键盘钩子来创建一个拦截器,用于拦截特定的键盘事件并执行自定义操作。本文将介绍如何使用C#开发一个键盘钩子拦截器,并给出一些示例代…

苹果手机如何格式化?五个步骤快速掌握!

如果手机出现异常情况,例如运行缓慢、频繁崩溃,又或者想将手机出售、转让给他人,那么将手机格式化可以有助于解决问题。苹果手机如何格式化?本文将为您介绍解决方法,只需要五个步骤就能搞定,帮助您快速掌握…

一维数组,逆序存放并输出【样例输入】20 30 10 50 40 90 80 70【样例输出】70 80 90 40 50 10 30 20

一维数组&#xff0c;逆序存放并输出 【样例输入】 20 30 10 50 40 90 80 70 【样例输出】 70 80 90 40 50 10 30 20 以下是使用C语言编写的将一维数组逆序存放并输出的示例代码&#xff1a; #include <stdio.h>void reverseArray(int arr[], int size) {int start…

Xilinx Zynq-7000系列FPGA多路视频处理:图像缩放+视频拼接显示,提供工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐FPGA图像处理方案FPGA图像缩放方案FPGA视频拼接叠加融合方案推荐 3、设计思路详解HLS 图像缩放介绍Video Mixer介绍 4、vivado工程介绍PL 端 FPGA 逻辑设计PS 端 SDK 软件设计 5、工程移植说明vivado版本不一致处理FPGA型号不一致处理其他…

全网日志智能聚合及问题根因分析

1 日志关联分析的挑战 随着各行各业数字化转型的不断深入&#xff0c;网络承载了人们日常生活所需的政务、金融、娱乐等多方面的业务系统&#xff0c;已经成为影响社会稳定运行、关系国计民生的重要基础设施资源。哪怕网络发生及其微小的故障&#xff0c;也可能带来难以估量的…

【预测爆款不用愁,有服饰RFID小助手】

时尚服饰行业库存成本高&#xff0c;数据不精准&#xff0c;爆款服饰一直抓不住&#xff0c;增加库存滞销风险难脱逃&#xff0c;给服饰零售企业带来极大困扰。 帮您提前预测爆款服饰小塔服饰RFID系统 小塔RFID系统作为服饰新零售小助手&#xff0c;通过RFID系统与硬件结合&a…

在Springboot中操作Redis——五大数据类型

在Java中操作Redis Redis的Java客户端 前面我们讲解了Redis的常用命令&#xff0c;这些命令是我们操作Redis的基础&#xff0c;那么我们在java程序中应该如何操作Redis呢&#xff1f;这就需要使用Redis的Java客户端&#xff0c;就如同我们使用JDBC操作MySQL数据库一样。 Red…

重温 re:Invent,分享十年成长:我和 re:Invent的故事

文章目录 前言背景我和re:Invent的交际历届峰会主题2012 突破技术垄断2013 革新数据服务2014 更好用的云服务2015 打通最后一-公里2016 迈向云上数据湖时代2017 重构云计算基础2018 云能力的再进化2019 赋能企业云架构服务2020 推动行业数据库服务的演进2021 无可比拟的云架构2…

pdf文件编辑,[增删改查]

pdf文件是投标文件中必不可少的格式&#xff0c;传统的方式先编辑word格式&#xff0c;最后生成pdf&#xff0c;但是有时候需要直接编辑pdf文件&#xff0c;编辑pdf的工具无疑 “adobe acrobat dc”是最好用的之一了 1.把图片文件添加到pdf指定位置&#xff0c;例如把一张图片添…

API网关

API网关的作用 下图显示了详细信息。 步骤 1 - 客户端向 API 网关发送 HTTP 请求。 步骤 2 - API 网关解析并验证 HTTP 请求中的属性。 步骤 3 - API 网关执行允许列表/拒绝列表检查。 步骤 4 - API 网关与身份提供商对话以进行身份​​验证和授权。 步骤 5 - 将速率限制规…

亚马逊云科技Aurora MySQL在复制性能提升上的不断优化和尝试

前言 Amazon Aurora是亚马逊云科技自研的云原生关系数据库&#xff0c;它在提供和开源数据库MySQL、PostgreSQL的完好兼容性同时&#xff0c;也能够提供和商业数据库媲美的性能和可用性。 Aurora的性能提升不仅包含应用读写吞吐量的提升&#xff0c;也包含复制延迟的降低。一个…

echart 柱状图-bar

业务场景一 效果 业务组件调用代码 <template> <barCom :domId"1" :title"barComProps.title" :xAxisData"barComProps.xAxisData" :yAxisProps"barComProps.yAxisProps" :seriseData"barComProps.serise…

在数据库中进行表内容的修改(MYSQL)

根据表中内容&#xff0c;用命令语句创建数据库&#xff0c;表格&#xff0c;以及插入&#xff0c;修改&#xff0c;删除表格中的内容。 创建数据库&#xff1a;zrzy mysql> create database zrzy; 引用zrzy数据库&#xff1a; mysql> use zrzy; 创建student_info表&…

【EasyExcel实践】导出多个sheet到多个excel文件,并压缩到一个zip文件

文章目录 前言正文一、项目依赖二、封装表格实体和Sheet实体2.1 表格实体2.2 Sheet实体 三、核心实现3.1 核心实现之导出为输出流3.2 web导出3.3 导出为字节数组 四、调试4.1 构建调试用的实体类4.2 控制器调用4.3 测试结果 五、注册大数转换器&#xff0c;长度大于15时&#x…

XML Schema中的attributeFormDefault

XML Schema中的attributeFormDefault属性&#xff0c;用以指定元素的属性默认是否必须带有命名空间前缀。 attributeFormDefault属性可以取值qualified或unqualified&#xff0c;默认值是unqualified。 当取值为qualified时&#xff0c;表示属性必须用命名空间作为前缀&#x…

线性可分SVM摘记

线性可分SVM摘记 0. 线性可分1. 训练样本到分类面的距离2. 函数间隔和几何间隔、(硬)间隔最大化3. 支持向量 \qquad 线性可分的支持向量机是一种二分类模型&#xff0c;支持向量机通过核技巧可以成为非线性分类器。本文主要分析了线性可分的支持向量机模型&#xff0c;主要取自…

命令模式 rust和java实现

文章目录 命令模式介绍javarustrust仓库 命令模式 命令模式&#xff08;Command Pattern&#xff09;是一种数据驱动的设计模式。请求以命令的形式包裹在对象中&#xff0c;并传给调用对象。调用对象寻找可以处理该命令的合适的对象&#xff0c;并把该命令传给相应的对象&…