PyTorch DataLoader整理函数详解【collate_fn】

DataLoader 是 PyTorch 中最常用的类之一。 而且,它是你首先学习的内容之一。 该类有很多参数,但最有可能的是,你将使用其中的大约三个参数(dataset、shuffle 和 batch_size)。 今天我想解释一下 collate_fn 的含义—根据我的经验,我发现它让初学者感到困惑。 我们将简要探讨 PyTorch 如何创建批数据,并了解如何根据需要修改默认行为。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、批创建流程

每个深度学习课程中最重要的信息之一是我们批量执行训练/推理。 大多数时候,一个批次只是一些堆叠的数据样本。 但在某些情况下,我们想修改它的创建方式。

首先,让我们研究一下默认情况下会发生什么。 假设我们有以下玩具数据集。 它包含四个示例,每个示例三个功能。

import torch
from torch.utils.data import DataLoader
import numpy as np

data = np.array([
    [0.1, 7.4, 0],
    [-0.2, 5.3, 0],
    [0.2, 8.2, 1],
    [0.2, 7.7, 1]])
print(data)

如果我们向加载程序请求一个批次,我们将看到以下内容(请注意,我设置了 shuffle=False 以消除随机性):

loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch)

# tensor([[ 0.1000,  7.4000,  0.0000],
#         [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)

结果毫不奇怪,但让我们正式描述一下已经做了什么:

  • 加载器从数据集中选择了 2 个样本。
  • 这些样本被转换为张量(2 个大小为 3 的样本)。
  • 创建并返回一个新的张量 (2x3)。

默认设置还允许我们使用字典。 让我们看一个例子:

from pprint import pprint
# now dataset is a list of dicts
dict_data = [
    {'x1': 0.1, 'x2': 7.4, 'y': 0},
    {'x1': -0.2, 'x2': 5.3, 'y': 0},
    {'x1': 0.2, 'x2': 8.2, 'y': 1},
    {'x1': 0.2, 'x2': 7.7, 'y': 10},
]
pprint(dict_data)
# [{'x1': 0.1, 'x2': 7.4, 'y': 0},
# {'x1': -0.2, 'x2': 5.3, 'y': 0},
# {'x1': 0.2, 'x2': 8.2, 'y': 1},
# {'x1': 0.2, 'x2': 7.7, 'y': 10}]

loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
pprint(batch)
# {'x1': tensor([ 0.1000, -0.2000], dtype=torch.float64),
#  'x2': tensor([7.4000, 5.3000], dtype=torch.float64),
#  'y': tensor([0, 0])}

加载器足够聪明,可以正确地从字典列表中重新打包数据。 当你的数据采用 JSONL 格式(我个人更喜欢这种格式而不是 CSV)时,此功能非常方便。

2、自定义collate函数

如果默认规则如此智能,为什么我们需要创建自定义collate规则呢? 默认设置有一个很大的限制——批数据必须处于同一维度。 假设我们有一个 NLP 任务,并且数据是分词后的文本。

# values are token indices but it does not matter - it can be any kind of variable-size data
nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2],
     'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],
     'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2],
     'label':1},
    {'tokenized_input': [1, 17, 2],
     'label':0},
]
loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

上面的代码不会工作并引发错误:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     80         elem_size = len(next(it))
     81         if not all(len(elem) == elem_size for elem in it):
---> 82             raise RuntimeError('each element in list of batch should be of equal size')
     83         transposed = zip(*batch)
     84         return [default_collate(samples) for samples in transposed]

RuntimeError: each element in list of batch should be of equal size

错误消息表明不可能创建非矩形张量。 顺便说一句,可以看到触发错误的是 default_collate函数。

我们可以做什么? 有两种解决方案:

  • 将整个数据集填充到最长的样本。
  • 在批创建期间动态填充。

第一个解决方案可能看起来更简单—只需将所有样本扩展到最长的样本即可。 但有一个问题—我们会浪费内存和计算能力(它们在 GPU 上很昂贵!)来处理 padding,这并不影响结果。 如果我们的数据中有一些长序列,而且大多数序列都相对较短,那就尤其痛苦。 在这种情况下,我们主要是处理填充而不是数据!

如果我们将整个数据集填充到最长的序列,会浪费大量空间!

另一种方法是动态填充数据。 当选择该批的样本时,我们只将它们填充到最长的样本。 如果我们另外按长度对数据进行排序,则填充将是最小的。 如果有一些非常长的序列,它们只会影响它们的批次,而不是整个数据集。

好吧,但是如何实现呢? 只需创建一个自定义 collate_fn , 这很简单:

from torch.nn.utils.rnn import pad_sequence #(1)

def custom_collate(data): #(2)
    inputs = [torch.tensor(d['tokenized_input']) for d in data] #(3)
    labels = [d['label'] for d in data]

    inputs = pad_sequence(inputs, batch_first=True) #(4)
    labels = torch.tensor(labels) #(5)

    return { #(6)
        'tokenized_input': inputs,
        'label': labels
    }

loader = DataLoader(
  	nlp_data, 
    batch_size=2, 
    shuffle=False, 
    collate_fn=custom_collate
) #(7)

iter_loader = iter(loader)
batch1 = next(iter_loader)
pprint(batch1)
batch2 = next(iter_loader)
pprint(batch2)

# {'label': tensor([0, 0]),
#  'tokenized_input': tensor([
#   [  1,   4,   5,   9,   3,   2,   0,   0,   0],
#   [  1,   7,   3,  14,  48,   7,  23, 154,   2]
# ])}

# {'label': tensor([1, 0]),
#  'tokenized_input': tensor([
#   [  1,  30,  67, 117,  21,  15,   2],
#   [  1,  17,   2,   0,   0,   0,   0]])}

代码说明如下:

  • 我们使用 pad_sequence进行填充
  • Collate 函数要传入单个参数 - 样本列表。 在这种情况下,它将是一个字典列表,但它也可以是一个元组列表等——具体取决于数据集。
  • 当数据出现时,如果格式为“字典列表”,我们需要遍历它并为所有输入和标签创建一个单独的列表。 与此同时, tokenized_input 被转换为一维张量(它是一个整数列表)。
  • 执行填充。
  • 由于标签是整数列表,我们将其转换为张量。
  • 返回格式化的批次。
  • 在加载器中设置我们的自定义整理函数。

正如我们所看到的,批的格式与字典的默认排序规则相同。 我们清楚地看到填充量很小。

3、结束语

创建自定义整理函数可能不是最常见的任务,但你绝对需要知道如何去做。


原文链接:PyTorch collate_fn详解 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/160800.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于JAVA的校园失物招领管理系统

项目编号: S 006 ,文末获取源码。 \color{red}{项目编号:S006,文末获取源码。} 项目编号:S006,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 招领管理模块2.2 寻物管理模块2.3 系…

Linux shell编程学习笔记26:stty(set tty)

之前我们探讨了Linux中的tty,tty命令的主要功能是显示当前使用的终端名称。 如果我们想进一步对tty进行设置,就要用到stty。 stty的功能:显示和修改终端特性(Print or change terminal characteristics)。 1 stty -…

设置 SSH 主机 ***: (details) 连接到 VS Code Server - 重试 2

VS Code Server 一直重试输入密码 问题描述解决方法1、打开命令面板Ctrl shift p2、在输入框中输入Kill3、在弹出框中选择一直重复输入密码的服务器主机号,输入密码即可成功。 问题描述 VSCode 在使用插件 Remote - SSH 连接远程服务器时总是会遇到各种问题&#…

电脑软件:推荐一款非常实用的固态硬盘优化工具

目录 一、软件简介 二、工作原理 三、功能介绍 3.1、优化SSD设置 3.2、查看驱动器信息 3.3、查看SMART数据 3.4、停用Windows事件日志记录 3.5、禁用Windows碎片整理 3.6、时间戳停用 3.7、禁用引导文件的碎片整理 3.8、关闭短名称 四、使用教程 4.1 安装说明 4.…

猜数字优化版(带进度条)

其实就是加了个动态进度条显示加载游戏的流程&#xff0c;这样看上去是不是更有big了hhhh #include<windows.h> #include<iostream> #include<ctime> using namespace std; void menu() {printf("1.开始游戏\n");printf("0.退出游戏\n")…

CTFhub-RCE-综合过滤练习

%0a、%0d、%0D%0A burp 抓包 修改请求为 POST /?127.0.0.1%0als 列出当前目录 返回包 http://challenge-135e46015a30567b.sandbox.ctfhub.com:10800/?ip127.0.0.1%0acd%09*here%0ac%27a%27t%09* _311632412323588.php

基于和声算法优化概率神经网络PNN的分类预测 - 附代码

基于和声算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于和声算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于和声优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神经网络的光滑…

java+ 如何动态配置业务规则组

思路 1. 实现在页面上的动态配置规则组&#xff08;2张数据表枚举类serviceimplaction&#xff09; 2. 从数据库中表staffmoverules&#xff08;规则明细表&#xff09;或者staffmovetyperule&#xff08;规则组表&#xff09; &#xff0c;根据传入类型&#xff0c;取出规则编…

【开源】基于JAVA的快递管理系统

项目编号&#xff1a; S 007 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S007&#xff0c;文末获取源码。} 项目编号&#xff1a;S007&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 数据中心模块2.2 快递类型模块2.3 快…

YOLOv8-Seg改进:渐近特征金字塔网络(AFPN)

🚀🚀🚀本文改进:AFPN通过融合两个相邻的Low-Level特征来启动的,并渐进地将High-Level特征纳入融合过程,提升分割能力。 🚀🚀🚀AFPN小目标分割首选,暴力涨点 🚀🚀🚀YOLOv8-seg创新专栏:http://t.csdnimg.cn/KLSdv 学姐带你学习YOLOv8,从入门到创新,轻…

c++中的String

文章目录 String定义对象的方式成员函数operatorbegin/endsizecapacityreserversizeoperator/append/push_backoperator[]/at String String是一个类模版&#xff0c;可以定义一个字符/字符串对象。 字符顺序表 定义对象的方式 定义方式有很多重要的就这几种 string s1;stri…

配置环境-insightface-torch

1. 创建环境&#xff1a;conda create -n insightface2 python3.8 2.安装pytorch: 我的cuda 是 11.3 然后进入 pytorch 官网查找对应cuda 版本 pytorch 安装 建议使用 pip # CUDA 11.3 conda install pytorch1.12.1 torchvision0.13.1 torchaudio0.12.1 cudatoolkit11.3 -…

UE5 C++报错:is not currently enabled for Live Coding

解决办法&#xff1a; 再次打开项目&#xff0c;以此法打开&#xff1a;

贝叶斯AB测试

AB测试是用来评估变更效果的有效方法&#xff0c;但很多时候会运行大量AB测试&#xff0c;如果能够在测试中复用之前测试的结果&#xff0c;将有效提升AB测试的效率和有效性。原文: Bayesian AB Testing[1] 随机实验&#xff0c;又称AB测试&#xff0c;是行业中评估因果效应的既…

Windos操作系统下的Zookeeper安装图文教程

凯哥已经准备好最新版本3.9.1且已经配置好了。既获取到配置好的。 获取到凯哥准备的安装后&#xff0c;只需要修改一下配置。将解压包解压后&#xff0c;找到conf文件&#xff0c;里面有个zoo.cfg配置文件。如下图&#xff1a; 下载后conf文件夹内容 打开zoo.cfg配置文件后&a…

Vue项目

使用vs打开Vue项目 运行&#xff1a; 终端命令&#xff1a; npm run serve初次打开的是App.vue页面 修改服务器端口 vue.config.js const { defineConfig } require(vue/cli-service) module.exports defineConfig({transpileDependencies: true,devServer: { //增…

计算机网络——WLAN简解

1. WLAN的发展历程 ❓ WLAN和WIFI有什么区别。 &#x1f604; 具体来说&#xff0c;WALN是抽象的概念&#xff0c;代表这无线局域网这一类技术&#xff0c;而WIFI则是具体的具体技术标准&#xff0c;虽然在生活中&#xff0c;二者的表现是强相关的&#xff08;因为是使用的wifi…

C_11微机原理

一、单项选择题&#xff08;本大题共 15 小题&#xff0c;每小题 3分&#xff0c;共45分。在每小题给出的四个备选项中&#xff0c;选出一个正确的答案。&#xff09; .EXE 文件产生在&#xff08;&#xff09;之后。 A.汇编 B. 编辑 C.用软件转换 D.连接 2,十进制-61的8位二进…

调整COSWriter解决X-easypdf / PDFBOX生成大量数据时OOM问题

背景 业务需要生成一个15W数据左右的PDF交易报表。希望我们写在一个文件里&#xff0c;不拆分成多个PDF文件。 使用的技术组件 <dependency><groupId>wiki.xsx</groupId><artifactId>x-easypdf-pdfbox</artifactId><version>2.11.10<…

python中的字典

字典&#xff1a; 1.字典是一种可变容器模型&#xff0c;可以存储任意类型的对象&#xff0c;比如字符串&#xff0c;数字&#xff0c;元组等其他容 器模型 形式&#xff1a; d{key1&#xff1a;value1&#xff0c;key2&#xff1a;value2} 解释&#xff1a; 1.其中key1代表一…