深入解析PyTorch的DataLoader:参数探秘与使用指南【建议收藏】

引言

当我们深入探索深度学习的世界时,PyTorch作为一个强大且易用的框架,提供了丰富的功能来帮助我们高效地进行模型训练和数据处理。其中,DataLoader是PyTorch中一个非常核心且实用的组件,它负责在模型训练过程中加载和处理数据。通过灵活配置DataLoader的各种参数,我们可以优化数据加载速度,调整数据批次大小,甚至实现自定义的数据处理和抽样策略。在这篇文章中,小编将详细解析DataLoader的每个参数,通过具体的示例代码展示它们的使用场景和效果,帮助你更深入地理解和使用PyTorch进行深度学习模型的开发。

DataLoader的主要参数

PyTorch的DataLoader是一个非常重要的工具,用于在训练神经网络时批量、打乱和并行加载数据。下面我们将详细介绍其各个参数的具体作用和使用场景,并通过示例代码进行详细注释。

主要参数说明

  1. dataset (必需): 用于加载数据的数据集,通常是torch.utils.data.Dataset的子类实例。
  2. batch_size (可选): 每个批次的数据样本数。默认值为1。
  3. shuffle (可选): 是否在每个周期开始时打乱数据。默认为False
  4. sampler (可选): 定义从数据集中抽取样本的策略。如果指定,则忽略shuffle参数。
  5. batch_sampler (可选): 与sampler类似,但一次返回一个批次的索引。不能与batch_sizeshufflesampler同时使用。
  6. num_workers (可选): 用于数据加载的子进程数量。默认为0,意味着数据将在主进程中加载。
  7. collate_fn (可选): 如何将多个数据样本整合成一个批次。通常不需要指定。
  8. drop_last (可选): 如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次。默认为False

DataLoader的dataset参数(必需)

在实例化PyTorch的DataLoader类时,dataset参数是必需的,它指定了要从【哪个数据集对象】里面加载数据。该对象必须是torch.utils.data.Dataset的子类实例。

示例代码:

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# 创建自定义数据集实例
my_data = [1, 2, 3, 4, 5, 6, 7]
my_dataset = MyDataset(my_data)

# 使用DataLoader加载自定义数据集my_dataset
dataloader = DataLoader(dataset=my_dataset)

DataLoader的batch_size参数 (可选)

batch_size参数指定了每个批次的数据样本数。默认值为1。

示例代码:

# 将批次大小设置为3,这意味着每个批次将包含3个数据样本。
dataloader = DataLoader(dataset=my_dataset, batch_size=3)

for data in dataloader:
	print(data)

运行结果
在这里插入图片描述

DataLoader的shuffle参数 (可选)

shuffle参数指定是否在每个周期开始时打乱数据。默认为False。如果设置为True,则在每个周期开始时,数据将被随机打乱顺序。

示例代码:

# shuffle默认为False
dataloader = DataLoader(dataset=my_dataset, batch_size=3)

print("当shuffle=False时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

dataloader = DataLoader(dataset=my_dataset, batch_size=3, shuffle=True)

print("当shuffle=True时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

运行结果:
在这里插入图片描述

DataLoader的drop_last参数 (可选)

drop_last参数决定了在数据批次划分时是否丢弃最后一个不完整的批次。当数据集的大小不能被批次大小整除时,最后一个批次的大小可能会小于指定的批次大小。drop_last参数用于控制是否保留这个不完整的批次

使用场景

  1. 当数据集大小不能被批次大小整除时,如果最后一个批次的大小较小,可能会导致模型训练时的不稳定。通过将drop_last设置为True,可以确保每个批次的大小都相同,从而避免这种情况。
  2. 在某些情况下,丢弃最后一个批次可能不会对整体训练效果产生太大影响,但可以减少计算资源的浪费。例如,当数据集非常大时,最后一个不完整的批次可能只包含很少的数据样本,对于整体训练过程的贡献较小。

示例代码:

# drop_last默认为False
dataloader = DataLoader(dataset=my_dataset, batch_size=3)

print("当drop_last=False时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

dataloader = DataLoader(dataset=my_dataset, batch_size=3, drop_last=True)

print("当drop_last=True时,运行结果如下:")
print("*" * 30)
for data in dataloader:
    print(data)
print("*" * 30)

运行结果
在这里插入图片描述
可以看到,当drop_last=True时,最后一个批次的数据tensor([7])被舍弃了。

DataLoader的sampler参数 (可选)

sampler参数定义从数据集中抽取样本的策略。如果指定了sampler,则忽略shuffle参数。它可以是任何实现了__iter__()方法的对象,通常会使用torch.utils.data.Sampler的子类。

示例代码:

from torch.utils.data import SubsetRandomSampler

# 创建一个随机抽样器,只选择索引为偶数的样本 【索引从0开始~】
sampler = SubsetRandomSampler(indices=[i for i in range(0, len(my_dataset), 2)])
dataloader = DataLoader(dataset=my_dataset, sampler=sampler)

for data in dataloader:
    print(data)

运行结果:
在这里插入图片描述

DataLoader的batch_sampler参数 (可选)

batch_sampler参数与sampler类似,但它返回的是一批次的索引,而不是单个样本的索引。不能与batch_sizeshufflesampler同时使用。

示例代码:

from torch.utils.data import BatchSampler
from torch.utils.data import SubsetRandomSampler

# 创建一个随机抽样器,只选择索引为偶数的样本 【索引从0开始~】
sampler = SubsetRandomSampler(indices=[i for i in range(0, len(my_dataset), 2)])

# 创建一个批量抽样器,每个批次包含2个样本
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True)
dataloader = DataLoader(dataset=my_dataset, batch_sampler=batch_sampler)

for data in dataloader:
    print(data)

运行结果:
在这里插入图片描述

DataLoader的num_workers参数 (可选)

num_workers参数指定用于数据加载的子进程数量。默认为0,表示数据将在主进程中加载。增加num_workers的值可以加快数据的加载速度,但也会增加内存消耗。

示例代码:

dataloader = DataLoader(dataset=my_dataset, num_workers=4)

代码解释: 在这个示例中,我们将子进程数量设置为4,这意味着将使用4个子进程并行加载数据,以加快数据加载速度。

DataLoader的collate_fn参数 (可选)

collate_fn参数指定如何将多个数据样本整合成一个批次,通常不需要指定。如果需要自定义批次数据的整合方式,可以提供一个可调用的函数。该函数接受一个样本【列表】作为输入,返回一个批次的数据。

示例代码:

import torch
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# 创建自定义数据集实例
my_data = [1, 2, 3, 4, 5, 6, 7]
my_dataset = MyDataset(my_data)

def my_collate_fn(batch):
    print(type(batch))
    # 将batch中的每个样本转换为pytorch的tensor并都加上10
    return [torch.tensor(data) + 10 for data in batch]

dataloader = DataLoader(dataset=my_dataset, batch_size=2, collate_fn=my_collate_fn)

for data in dataloader:
    print(data)

运行结果:
在这里插入图片描述

结束语

如果本博文对你有所帮助/启发,可以点个赞/收藏支持一下,如果能够持续关注,小编感激不尽~
如果有相关需求/问题需要小编帮助,欢迎私信~
小编会坚持创作,持续优化博文质量,给读者带来更好de阅读体验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/232962.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何利用Axure制作移动端产品原型

Axure是一款专业的快速原型设计工具,作为专业的原型设计工具,Axure 能够快速、高效地创建原型,同时支持多人协作设计和版本控制管理。它已经得到了许多大公司的采用,如IBM、微软、思科、eBay等,这些公司都利用Axure 进…

【Linux】地址空间

本片博客将重点回答三个问题 什么是地址空间? 地址空间是如何设计的? 为什么要有地址空间? 程序地址空间排布图 在32位下,一个进程的地址空间,取值范围是0x0000 0000~ 0xFFFF FFFF 回答三个问题之前我们先来证明地址空…

react中使用react-konva实现画板框选内容

文章目录 一、前言1.1、API文档1.2、Github仓库 二、图形2.1、拖拽draggable2.2、图片Image2.3、变形Transformer 三、实现3.1、依赖3.2、源码3.2.1、KonvaContainer组件3.2.2、use-key-press文件 3.3、效果图 四、最后 一、前言 本文用到的react-konva是基于react封装的图形绘…

Scrum

Scrum是一个用于开发和维持复杂产品的框架,是一个增量的、迭代的开发过程。在这个框架中,整个开发过程由若干个短的迭代周期组成,一个短的迭代周期称为一个Sprint,每个Sprint的建议长度是2到4周(互联网产品研发可以使用1周的Sprin…

序列的Z变换(信号的频域分析)

1. 关于Z变换 2. 等比级数求和 3. 特殊序列的Z变换 4. 因果序列/系统收敛域的特点 5. 例题

力扣 4. 寻找两个正序数组的中位数

题目 给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。 算法的时间复杂度应该为 O(log (mn)) 。 My class Solution {public double findMedianSortedArrays(int[] nums1, int[] nums2) {i…

LLM之Agent(五)| AgentTuning:清华大学与智谱AI提出AgentTuning提高大语言模型Agent能力

​论文地址:https://arxiv.org/pdf/2310.12823.pdf Github地址:https://github.com/THUDM/AgentTuning 在ChatGPT带来了大模型的蓬勃发展,开源LLM层出不穷,虽然这些开源的LLM在各自任务中表现出色,但是在真实环境下作…

按天批量创建间隔分区表(DM8:达梦数据库)

DM8:达梦数据库-按天批量创建间隔分区表 环境介绍1 生成按天批量创建间隔分区表的日志2 整合后的日志信息3 创建成功4 达梦数据库学习使用列表 环境介绍 由于未知原因限制,按天批量创建间隔分区表最大是103行记录,需要反复执行几次,提取日志,整合后最终创建成功; 1 生成按天批…

AGILE-SCRUM

一个复杂的汽车ECU开发。当时开发队伍遍布全球7个国家,10多个地区,需要同时为多款车型定制不同的软件,头疼的地方是: 涉及到多方人员协调,多模块集成和管理不同软件团队使用的设计工具、验证工具,数据、工…

python安装与工具PyCharm

摘要: 周末闲来无事学习一下python!不是你菜鸡,只不过是对手太强了!所以你要不断努力,去追求更高的未来!下面先了解python与环境的安装与工具的配置! python安装: 官网 进入官网下载…

【Linux】输出缓冲区和fflush刷新缓冲区

目录 一、输出缓冲区 1.1 输出缓冲区的使用 1.2 缓冲区的刷新 1.3 输出缓冲区的作用 二、回车换行 一、输出缓冲区 C/C语言,当调用输出函数(如printf()、puts()、fwrite()等)时,会给我们提供默认的缓冲区。这些数据先存…

Python绘制多分类ROC曲线

目录 1 数据集介绍 1.1 数据集简介 1.2 数据预处理 2随机森林分类 2.1 数据加载 2.2 参数寻优 2.3 模型训练与评估 3 绘制十分类ROC曲线 第一步,计算每个分类的预测结果概率 第二步,画图数据准备 第三步,绘制十分类ROC曲线 1 数据集…

C++学习笔记之五(String类)

C 前言getlinelength, sizec_strappend, inserterasefindsubstrisspace, isdigit 前言 C是兼容C语言的,所以C的字符串自然继承C语言的一切字符串,但它也衍生出属于自己的字符串类,即String类。String更像是一个容器,但它与容器还…

uniapp如何制作一个收缩通讯录(布局篇)

html&#xff1a; <view class"search"><view class"search_padding"><u-search change"search" placeholder"请输入成员名称" v-model"keyword"></u-search></view></view> <view…

【面试经典150 | 二叉树】从中序与后序遍历序列构造二叉树

文章目录 写在前面Tag题目来源题目解读解题思路方法一&#xff1a;递归 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法&#xff0c;两到三天更新一篇文章&#xff0c;欢迎催更…… 专栏内容以分析题目为主&#xff0c;并附带一些对于本题涉及到的数据结构等内容…

2020年第九届数学建模国际赛小美赛A题自由泳解题全过程文档及程序

2020年第九届数学建模国际赛小美赛 A题 自由泳 原题再现&#xff1a; 在所有常见的游泳泳姿中&#xff0c;哪一种最快&#xff1f;哪个冲程推力最大&#xff1f;在自由泳项目中&#xff0c;游泳者可以选择他们的泳姿&#xff0c;他们通常选择前面的爬行。然而&#xff0c;游泳…

中文分词演进(查词典,hmm标注,无监督统计)新词发现

查词典和字标注 目前中文分词主要有两种思路&#xff1a;查词典和字标注。 首先&#xff0c;查词典的方法有&#xff1a;机械的最大匹配法、最少词数法&#xff0c;以及基于有向无环图的最大概率组合&#xff0c;还有基于语言模型的最大概率组合&#xff0c;等等。 查词典的方法…

esxi全称“VMware ESXi

esxi全称“VMware ESXi”&#xff0c;是可直接安装在物理服务器上的强大的裸机管理系统&#xff0c;是一款虚拟软件&#xff1b;ESXi本身可以看做一个操作系统&#xff0c;采用Linux内核&#xff0c;安装方式为裸金属方式&#xff0c;可直接安装在物理服务器上&#xff0c;不需…

TCP传输数据的确认机制

实际的TCP收发数据的过程是双向的。 TCP采用这样的方式确认对方是否收到了数据&#xff0c;在得到对方确认之前&#xff0c;发送过的包都会保存在发送缓冲区中。如果对方没有返回某些包对应的ACK号&#xff0c;那么就重新发送这些包。 这一机制非常强大。通过这一机制&#xf…

Redis如何做内存优化?

Redis如何做内存优化&#xff1f; 1、缩短键值的长度 缩短值的长度才是关键&#xff0c;如果值是一个大的业务对象&#xff0c;可以将对象序列化成二进制数组&#xff1b; 首先应该在业务上进行精简&#xff0c;去掉不必要的属性&#xff0c;避免存储一些没用的数据&#xff1…