【深度学习实战(11)】搭建自己的dataset和dataloader

一、dataset和dataloader要点说明

在我们搭建自己的网络时,往往需要定义自己的datasetdataloader,将图像和标签数据送入模型。
(1)在我们定义dataset时,需要继承torch.utils.data.dataset,再重写三个方法:

  • init方法,主要用来定义数据的预处理
  • getitem方法,数据增强;返回数据的item和label
  • len方法,返回数据数量

(2)在我们定义dataloader时,需要考虑下面几个参数:

  • dataset :使用哪个数据集
  • batch_size:将数据集拆成一组多少个进行训练
  • shuffle:是否需要打乱数据
  • num_workers:几个mini_batch并行计算,一般<=你的电脑cpu数目
  • collect_fn:数据打包方式

(3)通过迭代的方式,按批次,获取dataloader中的数据

(4)关系图

在这里插入图片描述

二、核心代码框架

import os
import cv2
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


# -------------------------------------------------------------#
#   自定义dataset需要继承torch.utils.data.dataset,
#   再重写def __init__,def __len__,def __getitem__三个方法
# -------------------------------------------------------------#
class YourDataset(Dataset):
    def __init__(self,  root_path):
        super(YourDataset, self).__init__()
        self.root_path = root_path
        #-------------------------------------------------------------------------#
        #   获取样本名,以jpg原始图片为参考,修改后缀名为json,png,获取json,png标签文件路径
        #-------------------------------------------------------------------------#
        self.sample_names = []
        jpg_path = os.path.join(os.path.join(self.root_path, "images"),)
        for file in os.listdir(jpg_path):
            if file.endswith(".jpg"):
                self.sample_names.append(os.path.splitext(file)[0]) # 去掉.json

    def __len__(self):
        #----------------------#
        #   返回数据数量
        #----------------------#
        return len(self.sample_names)

    def __getitem__(self, index):
        name = self.sample_names[index]

        # ----------------------#
        #   读取图像
        # ----------------------#
        img_path = os.path.join(os.path.join(self.root_path, "images"), name + '.jpg')
        image = cv2.imread(img_path)
        # ----------------------#
        #   读取标签
        # ----------------------#
        label_path = os.path.join(os.path.join(self.root_path, "jsons"), name + '.json')
        with open(label_path) as label_file:
            points = self.get_data_from_json(label_file)
        #----------------------#
        #   图像数据增强
        #----------------------#
        image = self.random_color(image)
        #----------------------#
        #   标签归一化
        #----------------------#
        labels = self.convert_labels(points)
        return image,  labels

# -------------------------------------#
#   图片和标签格式转换后,按批次(batch)打包
# -------------------------------------#
def dataloader_collate_fn(batch):
    images = []
    labels = []
    for img, label in batch:
        images.append(transforms.ToTensor()(img))
        labels.append(label)
    return images, labels


if __name__ == '__main__':
    # -------------------------------------#
    #   构建dataset
    # -------------------------------------#
    path = './data/train'
    train_dataset = YourDataset(path)

    # -------------------------------------#
    #   构建Dataloader
    # -------------------------------------#
    dataset = train_dataset
    batch_size = 32
    shuffle = True
    num_workers = 0
    collate_fn = dataloader_collate_fn
    sampler = None
    train_gen = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,drop_last=True, collate_fn=collate_fn, sampler=sampler)
    # ---------------------------------------------#
    #   通过迭代的方式,一批一批读取训练集中的图像和标签数据
    # ---------------------------------------------#
    for iter, batch in enumerate(train_gen):
        images,  labels = batch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562644.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机体系结构

体系结构 CPU&#xff1a;运算器和控制器 运算器&#xff1a;进行算术和逻辑运算控制器&#xff1a; 输入设备&#xff1a;鼠标、键盘、显示器、磁盘、网卡等输出设备&#xff1a;显卡&#xff0c;磁盘、网卡、打印机等存储器&#xff1a;内存&#xff0c;掉电易失总线&#xf…

刷题DAY59 | LeetCode 503-下一个更大元素II 42-接雨水

503 下一个更大元素II&#xff08;medium&#xff09; 给定一个循环数组 nums &#xff08; nums[nums.length - 1] 的下一个元素是 nums[0] &#xff09;&#xff0c;返回 nums 中每个元素的 下一个更大元素 。 数字 x 的 下一个更大的元素 是按数组遍历顺序&#xff0c;这个…

基于SpringBoot + Vue实现的奖学金管理系统设计与实现+毕业论文+答辩PPT

介绍 角色:管理员、学院负责人、学校负责人、学生 管理员:管理员登录进入高校奖助学金系统的实现可以查看系统首页、个人中心、学生管理、学院负责人管理、学校负责人管理、奖学金类型管理、奖学金申请管理、申请提交管理、系统管理等信息 学院负责人:学院负责人登录系统后&am…

14年电赛题--风洞实验--基于STM32与串口屏

前言&#xff1a; 经过三天两夜的比赛&#xff0c;最终我们还是取得了不错的成绩&#xff0c;只有第4问出了一点点问题&#xff0c;球没吹到最顶端。当时我们以为这个是最简单的问题&#xff0c;只要目标值给大点就没问题。但最终还是败在了这一问上&#xff0c;电压不够没吹到…

[已解决]react打包部署

react打包部署 问题 npm install 命令无反应 思路 换成 yarn install 安装完hadoop的环境后&#xff0c;使用node的yarn会报错&#xff1a; 我们在cmd使用where yarn&#xff0c;如下&#xff1a; 看你想保留哪一个&#xff0c;我平时node用的多&#xff0c;就把hadoop的y…

JavaEE 初阶篇-深入了解 I/O 流(FileInputStream 与 FileOutputStream 、Reader 与 Writer)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 I/O 流概述 2.0 文件字节输入流(FileInputStream) 2.1 创建 FileInputStream 对象 2.2 读取数据 2.3 关闭流 3.0 文件字节输出流(FileOutputStream) 3.1 创建 Fi…

代码随想录第42天|416. 分割等和子集

416. 分割等和子集 416. 分割等和子集 - 力扣&#xff08;LeetCode&#xff09; 代码随想录 (programmercarl.com) 动态规划之背包问题&#xff0c;这个包能装满吗&#xff1f;| LeetCode&#xff1a;416.分割等和子集_哔哩哔哩_bilibili 给你一个 只包含正整数 的 非空 数组…

知识加油站:数字阅览室全天候满足师生阅读需求

在知识经济时代&#xff0c;阅读已成为获取信息、提升素养、拓宽视野的重要途径。但在传统知识的海洋里&#xff0c;每一本书都是一座孤岛&#xff0c;每一个思想是一股潮流。传统的纸质阅读已经无法完全满足现代人快速、便捷、多样化的学习和阅读需求。因此&#xff0c;数字阅…

C++ 右值引用

1.左值引用和右值引用的概念 什么是左值&#xff1f;什么是左值引用&#xff1f; 左值是一个表示数据的表达式(如变量名或解引用的指针)&#xff0c;我们可以获取它的地址可以对它赋值&#xff0c;左值可以出现赋值符号的左边&#xff0c;右值不能出现在赋值符号左边。定义时co…

查看上一次错误的方法$err,hr,到底是什么意思

在《windows核心编程》或者《windows via C/C》一书中&#xff0c;提到过查看函数错误的方法&#xff0c;可以在watch窗口中输入"$err,hr"&#xff0c;来显示。比如下面一个程序 #include <Windows.h> int APIENTRY wWinMain(_In_ HINSTANCE hInstance,_In_op…

FPGA - ZYNQ Cache一致性问题

什么是Cache&#xff1f; Cache是一种用来提高计算机运行速度的一种技术。它是一种小而快的存储设备&#xff0c;位于CPU与内存之间&#xff0c;用于平衡高速设备与低速设备之间的速度差异。Cache可以存储常用的数据或指令&#xff0c;以便CPU更快地获取&#xff0c;从而减少对…

基于肿瘤相关成纤维细胞的前列腺癌患者分层研究(多组学)

Integrating single-cell and bulk RNA sequencing data unveils antigen presentation and process-related CAFS and establishes a predictive signature in prostate cancer https://pubmed.ncbi.nlm.nih.gov/38221616/#full-view-affiliation-3 文章思路学习&#xff1a…

YOLOv9改进策略 | SPPF篇 | 利用RT-DETR的AIFI模块替换SPPFELAN助力小目标检测涨点

一、本文介绍 本文给大家带来是用最新的RT-DETR模型中的AIFI模块来替换YOLOv9中的SPPFELAN。RT-DETR号称是打败YOLO的检测模型&#xff0c;其作为一种基于Transformer的检测方法&#xff0c;相较于传统的基于卷积的检测方法&#xff0c;提供了更为全面和深入的特征理解&#x…

如何30天快速掌握键盘盲打

失业后在家备考公务员&#xff0c;发现了自己不正确的打字方式&#xff0c;决定每天抽出一点时间练习打字。在抖音上看到一些高手的飞速盲打键盘后&#xff0c;觉得使用正确的指法打字是很必要的。 练习打字&#xff0c;掌握正确的键盘指法十分关键。 练习打字的第一步是找到…

基本的SELECT语句及DESC显示表结构

1. SELECT ... 例 : 2. SELECT ... FROM ... (1). SELECT ... : 标识选择哪些列. (2). FROM ... : 标识从哪个表中选取. (3). *通配符 : 选择表中全部列. 例 : 3.列的别名 (1). 空一格. (2). 在列和别名间加入关键字AS. (3). 别名可以使用双引号&#xff0c;以便于在…

【Datawhale LLM学习笔记】一、什么是大型语言模型(LLM)

文章目录 1. 什么是大模型2. 检索增强生成 RAG一、什么是 RAG二、RAG 的工作流程 3. langChain介绍一、什么是 LangChain二、LangChain 的核心组件 4. 开发 LLM 应用的整体流程一、何为大模型开发二、大模型开发的一般流程三、搭建 LLM 项目的流程简析&#xff08;以知识库助手…

从迷宫问题理解dfs

文章目录 迷宫问题打印路径1思路定义一个结构体要保存所走的路径&#xff0c;就需要使用到栈遍历所有的可能性核心代码 部分函数递归图源代码 迷宫问题返回最短路径这里的思想同上面类似。源代码 迷宫问题打印路径1 定义一个二维数组 N*M &#xff0c;如 5 5 数组下所示&…

掌握Node Version Manager(nvm):跨平台Node.js版本管理

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…

整合阿里云短信服务

1. 申请服务 如图&#xff1a; 申请签名管理和模板管理 2. 进入快速学习和调试 2.1 进入快速学习 2.2 获取依赖和代码实现 3. 具体实现案例 3.1 添加依赖 <dependency><groupId>com.aliyun</groupId><artifactId>dysmsapi20170525</artifact…

9.MMD 基础内容总结及制作成品流程

前期准备 1. 导入场景和模型 在左上角菜单栏&#xff0c;显示里将编辑模型时保持相机和光照勾选上&#xff0c;有助于后期调色 将抗锯齿和各向异性过滤勾掉&#xff0c;可以节省资源&#xff0c;避免bug 在分辨率设定窗口&#xff0c;可以调整分辨率 3840x2160 4k分辨率 1…