【深度学习|Pytorch】torchvision.datasets.ImageFolder详解

ImageFolder详解

  • 1、数据准备
  • 2、ImageFolder类的定义
    • transforms.ToTensor()解析
  • 3、ImageFolder返回对象

1、数据准备

创建一个文件夹,比如叫dataset,将cat和dog文件夹都放在dataset文件夹路径下:
在这里插入图片描述

2、ImageFolder类的定义

class ImageFolder(DatasetFolder):
	def __init__(
	        self,
	        root: str,
	        transform: Optional[Callable] = None,
	        target_transform: Optional[Callable] = None,
	        loader: Callable[[str], Any] = default_loader,
	        is_valid_file: Optional[Callable[[str], bool]] = None,
	    ):

可以看到,ImageFolder类有这几个参数:
root:图片存储的根目录,即存放不同类别图片文件夹的前一个路径。
transform:即对加载的这些图片进行的前处理的方式,这里可以传入一个实例化的torchvision.Compose()对象,里面包含了各种预处理的操作。
target_transform:对图片类别进行预处理,通常来说不会用到这一步,因此可以直接不传入参数,默认图像标签没有变换,如果需要进行标签的处理,同样可以传入一个实例化的torchvision.Compose()对象。
loader:表示图像数据加载的方式,通常采用默认的加载方式,ImageFolder加载图像的方式为调用PIL库,因此图像的通道顺序是RGB而非opencv的BGR
is_valid_file:获取图像文件路径的函数,并且可以检查是否有损坏的文件。
示例代码:

ROOT_TEST = 'dataset' #dataset/cat, dataset/dog
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)

transforms.ToTensor()解析

这里需要特别说一下ToTensor()这个函数的作用,刚接触深度学习的我那时以为只是单纯的将图像的ndarray和PIL格式转成Tensor格式,后来查看了一下源码之后发现,事情并没有这么简答!

   """Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

这是关于ToTensor()函数的注解,这里明确指出了ToTensor()可以将PIL和ndarray格式的图像数据转成Tensor并缩放它们的值,这里的缩放他们的值的意思在下面也指出了,即将[0, 255]的像素值域归一化[0, 1.0],并且图像转换成Tensor格式之后,维度的顺序也会发生一点变化,从一开始的HWC变成了CHW的排列方式。

3、ImageFolder返回对象

以第一部分为例,我们用一个val_dataset接收了ImageFolder的返回值,那么这个Val_dataset对象里面包含了什么呢:
val_dataset.classes:存放着根目录下的子文件夹的名称(类别名称)的列表。
val_dataset.class_to_idx:存放着类别名称和各自的索引,字典类型。
val_dataset.extensions:存放着ImageFolder可以读取的图像格式名称,元组类型。
val_dataset.targets:存放着根目录下每一张图的类别索引。
val_dataset.transform:我们提供的transform的方式。
val_dataset.imgs:存放着根目录下每一张图的路径和类别索引。元组列表类型。
以上是关于这个ImageFolder返回的对象的属性的解析。

此外,我们可以通过一个for循环来遍历整个val_dataset的所有图像数据,其中val_dataset[i]是一个元组类型的数据,val_dataset[i][0]代表了前处理后的图像数据,类型为tensor,以AlexNet为例,此时的tensor应该是3 * 224 * 224的维度。val_dataset[i][1]代表了图像的类别索引。
完整示例代码:

import torch
from AlexNet import AlexNet
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# ROOT_TRAIN = 'D:/pycharm/AlexNet/data/train'
ROOT_TEST = 'dataset'

# 将图像的像素值归一化到[-1,1]之间
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)

# 如果有NVIDA显卡,转到GPU训练,否则用CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 模型实例化,将模型转到device
model = AlexNet().to(device)

# 加载train.py里训练好的模型
model.load_state_dict(torch.load(r'save_model/model_best.pth'))

# 结果类型
classes = [
    "cat",
    "dog"
]

# 把Tensor转化为图片,方便可视化
show = ToPILImage()

# 进入验证阶段
model.eval()
for i in range(10):
    x, y = val_dataset[i][0], val_dataset[i][1]
    # show():显示图片
    # show(x).show()
    # torch.unsqueeze(input, dim),input(Tensor):输入张量,dim (int):插入维度的索引,最终扩展张量维度为4维
    x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)
    with torch.no_grad():
        pred = model(x)
        # argmax(input):返回指定维度最大值的序号
        # 得到预测类别中最高的那一类,再把最高的这一类对应classes中的那一类
        predicted, actual = classes[torch.argmax(pred[0])], classes[y]
        # 输出预测值与真实值
        print(f'predicted:"{predicted}", actual:"{actual}"')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/515080.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C# WPF编程-元素绑定

C# WPF编程-元素绑定 将元素绑定到一起绑定表达式绑定错误绑定模式代码创建绑定移除绑定使用代码检索绑定多绑定绑定更新绑定延时 绑定到非元素对象Source属性RelativeSource属性DataContent属性 数据绑定是一种关系,该关系告诉WPF从源对象提取一下信息,…

296个地级市GDP相关数据集(2000-2023年)

01、数据简介 GDP,即国内生产总值(Gross Domestic Product),是指一个国家或地区所有常住单位在一定时期内生产活动的最终成果。 名义GDP,也称货币GDP,是指以生产物品和劳务的当年销售价格计算的全部最终产…

OpenHarmony实战:CMake方式组织编译的库移植

以double-conversion库为例,其移植过程如下文所示。 源码获取 从仓库获取double-conversion源码,其目录结构如下表: 表1 源码目录结构 名称描述double-conversion/cmake/CMake组织编译使用到的模板double-conversion/double-conversion/源…

南京大学提出用于大模型生成的动态温度采样法,简单有效!

在自然语言处理(NLP)的领域,大语言模型(LLMs)已经在各种下游语言任务中展现出了卓越的性能。这些任务包括但不限于问答、摘要、机器翻译等。LLMs的强大能力在于其生成的文本质量和多样性。为了控制生成过程&#xff0c…

力扣由浅至深 每日一题.22 移除链表元素

迄今为止的生命里 —— 24.4.4 移除链表元素 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 示例 1: 输入:head [1,2,6,3,4,5,6], val 6 输出:[1,2…

【子集回溯】Leetcode 78. 子集 90. 子集 II

【子集回溯】Leetcode 78. 子集 90. 子集 II 78. 子集90. 子集 II ---------------&#x1f388;&#x1f388;78. 子集 题目链接&#x1f388;&#x1f388;------------------- 78. 子集 class Solution {List<List<Integer>> result new ArrayList<>()…

基于约束求解器对“火影忍者Online”进行智能布阵

文章目录 1. 游戏背景2. 确定决策边界3. 布阵数据3.1 追击状态3.2 角色信息3.3 个性化要求 4. 智能布阵模型4.1 主要的决策变量4.2 约束条件&#xff08;含辅助决策变量&#xff09;4.3 目标函数及求解 1. 游戏背景 今天将以“火影忍者Online”为案例&#xff0c;写一个智能布…

STM32工程 如何设置堆栈大小(Heap和Stack)

方法1&#xff1a;通过CubeMX、CubeIDE 配置 方法2&#xff1a;直接在启动文件中修改 &#xff08;适合所有Keil工程&#xff09; Heap、Stack的值大小&#xff0c;不管使用哪种开发环境&#xff0c;它俩都肯定在启动文件中。 可以通过CtrlF&#xff0c;搜索: Heap&#xff0…

【Linux】从零认识文件操作

送给大家一句话&#xff1a; 要相信&#xff0c;所有的不美好都是为了迎接美好&#xff0c;所有的困难都会为努力让道。 —— 简蔓《巧克力色微凉青春》 开始理解基础 IO 吧&#xff01; 1 前言2 知识回顾3 理解文件3.1 进程和文件的关系3.2 文件的系统调用openwrite文件 fd 值…

STL常用容器(2)---vector容器

1.1 vector基本概念 功能&#xff1a; vector数据结构和数组非常相似&#xff0c;也称为单端数组 vector与普通数组区别&#xff1a; 不同之处在于数组是静态空间&#xff0c;而vector可以动态扩展 动态扩展&#xff1a; 并不是在原空间之后的续接的新空间&#xff0c;而…

如何从 Android 和 iPhone 中的 SIM 卡恢复已删除的联系人 [新]

在手机上&#xff0c;我们经常添加联系人&#xff0c;而很少关心联系人是存储在SIM卡中还是手机中。当我们错误删除SIM卡联系人&#xff0c;或者不当取出插入的SIM卡插入新手机时&#xff0c;那些因业务需要而添加的联系人就会消失。这可能会令人沮丧和困惑。因此&#xff0c;您…

UniApp 应用发布到苹果商店指南

&#x1f680; 想要让你的 UniApp 应用在苹果商店亮相吗&#xff1f;别着急&#xff0c;让我来带你一步步完成这个重要的任务吧&#xff01;在这篇博客中&#xff0c;我将详细介绍如何将 UniApp 应用顺利发布到苹果商店&#xff0c;让你的应用跻身于苹果生态之中。 引言 &…

Python向带有SSL/TSL认证服务器发送网络请求小实践(附并发http请求实现asyncio+aiohttp)

1. 写在前面 最近工作中遇到这样的一个场景&#xff1a;给客户发送文件的时候&#xff0c;为保证整个过程中&#xff0c;文件不会被篡改&#xff0c;需要在发送文件之间&#xff0c; 对发送的文件进行签名&#xff0c; 而整个签名系统是另外一个团队做的&#xff0c; 提供了一…

银行数字化转型导师坚鹏:银行数字化转型必知的3大客户分析维度

银行数字化转型需要进行客户分析&#xff0c;如何进行客户分析呢&#xff1f;银行数字化转型导师坚鹏认为至少从客户需求分析、客户画像分析、客户购买行为分析3个维度进行客户分析。 1.客户需求分析 银行数字化转型需要了解客户需求&#xff0c;不同年龄段的客户有不同的需求…

游戏APP如何提高广告变现收益的同时,保证用户留存率?

APP广告变现对接第三方聚合广告平台主要通过SDK文档对接&#xff0c;一些媒体APP不具备专业运营广告变现的对接能力和资源沉淀&#xff0c;导致APP被封控&#xff0c;设置列入黑名单&#xff0c;借助第三方聚合广告平台进行商业化变现是最佳选择。#APP广告变现# 接入第三方平台…

VGG网络模型

VGG网络模型 VGG的网络架构VGG16VGG19 特点总结时间关系AlexNet和VGG相似之处AlexNet和VGG不同之处启发与影响总结 VGG&#xff08;Visual Geometry Group&#xff09;是由牛津大学的 Visual Geometry Group 提出的一个深度卷积神经网络模型&#xff0c;它在2014年的ImageNet大…

哲♂学家带你深♂入了解动态顺序表

前言&#xff1a; 最近本哲♂学家学习了顺序表&#xff0c;下面我给大家分享一下关于顺序表的知识。 一、什么是顺序表 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构&#xff0c;一般情况下采用数组存储。在数组 上完成数据的增删查改。 顺序表&#xff…

动态规划刷题(算法竞赛、蓝桥杯)--乌龟棋(线性DP)

1、题目链接&#xff1a;[NOIP2010 提高组] 乌龟棋 - 洛谷 #include <bits/stdc.h> using namespace std; const int M41; int f[M][M][M][M],num[351],g[5],n,m,x; //f[a][b][c][d]表示放a个1b个2c个3d个4的总得分 int main(){scanf("%d %d",&n,&m)…

创新指南|贝恩的产品经理RAPID框架:解决问题的分步指南,使决策过程既高效又民主

您是否曾发现自己陷入项目的阵痛之中&#xff0c;决策混乱、角色不明确、团队成员之间的冲突不断升级&#xff1f;作为产品经理&#xff0c;驾驭这艘船穿过如此汹涌的水域可能是令人畏惧的。应对这些挑战的关键在于采用清晰、结构化的决策方法。输入贝恩的 RAPID 框架&#xff…

软件测试用例(2)

具体的设计方法 -- 黑盒测试 因果图 因果图是一种简化的逻辑图, 能直观地表明程序的输入条件(原因)和输出动作(结果)之间的相互关系. 因果图法是借助图形来设计测试用例的一种系统方法, 特别适用于被测试程序具有多种输入条件, 程序的输出又依赖于输入条件的各种情况. 因果图…