【计算机视觉基础CV】03-深度学习图像分类实战:鲜花数据集加载与预处理详解

本文将深入介绍鲜花分类数据集的加载与处理方式,同时详细解释代码的每一步骤并给出更丰富的实践建议和拓展思路。以实用为导向,为读者提供从数据组织、预处理、加载到可视化展示的完整过程,并为后续模型训练打下基础。


前言

在计算机视觉的深度学习实践中,数据加载和预处理是至关重要的一步。无论你是初学者,还是有一定经验的从业者,都需要深刻理解如何将原始数据转化为神经网络可接受的输入。PyTorch中的torchvision.datasetstorchvision.transforms为我们提供了极大的便利,使图像数据的加载和处理更加高效与简洁。

本文将以“鲜花分类数据集”(一个包含5种不同花卉类别的图像数据集)为例,详细讲述如何使用ImageFolder类进行数据加载,并通过transforms对图像进行预处理和数据增强。我们还会深入讨论数据集结构、训练/验证集划分、代码注释和实践建议,并给出详细说明。


数据集简介与结构

本例使用的鲜花分类数据集共包含5种花:雏菊(daisy)、蒲公英(dandelion)、玫瑰(roses)、向日葵(sunflowers)和郁金香(tulips)。数据量约为:

  • 训练集(train):3306张图像

  • 验证集(val):364张图像

数据已按类别分好目录,每个类别对应一个文件夹,文件夹中存放若干图片文件。结构示意如下:

dataset/flower_datas/
  ├─ train/
  │   ├─ daisy/       # 雏菊类图像若干张
  │   ├─ dandelion/   # 蒲公英类图像若干张
  │   ├─ roses/       # 玫瑰类图像若干张
  │   ├─ sunflowers/   # 向日葵类图像若干张
  │   └─ tulips/       # 郁金香类图像若干张
  └─ val/
      ├─ daisy/
      ├─ dandelion/
      ├─ roses/
      ├─ sunflowers/
      └─ tulips/

这种目录结构非常适合ImageFolder数据集类,它会根据子文件夹的名称自动分配类别标签,从0开始编号。例如:

  • daisy -> 0

  • dandelion -> 1

  • roses -> 2

  • sunflowers -> 3

  • tulips -> 4

这样无需手动编码类别映射,简化了流程。


ImageFolder和transform

ImageFolder简介

ImageFoldertorchvision.datasets中的一个实用数据类,它假设数据按如下规则组织:

  • root/class_x/xxx.png

  • root/class_x/xxy.png

  • root/class_y/xxz.png

  • ...

其中class_xclass_y是类名(字符串),ImageFolder会根据这些类名自动生成类别索引。加载后,每个样本是一个(image, label)二元组,image通常会通过transform转换为Tensorlabel为整数索引。


transforms的数据预处理功能

torchvision.transforms提供多种图像处理方法,用来改变图像格式、尺寸、颜色空间和进行数据增强。例如:

  • ToTensor():将PIL图像或Numpy数组转换为(C,H,W)格式的张量,并将像素值归一化到[0,1]之间。

  • Resize((224,224)):将图像缩放到224x224大小,这通常是预训练模型如ResNet、VGG的标准输入尺寸。

  • RandomHorizontalFlip():随机水平翻转图像,用于数据增强,提高模型对翻转不敏感。

  • Normalize(mean, std):对图像的每个通道进行归一化,使训练更稳定。

你可以根据需求灵活组合多个变换操作,使用transforms.Compose将其串联成流水线。


加载鲜花分类数据集的示例代码

下面的代码示例中,我将详细注释每个步骤,为读者提供清晰的思路。该示例以最基本的ToTensor和Resize为主,读者可按需添加更多transform。

import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt

# 数据集存放路径,根据实际情况修改
flowers_train_path = '../01.图像分类/dataset/flower_datas/train/'
flowers_val_path = '../01.图像分类/dataset/flower_datas/val/'

# 定义数据预处理
# 这里的transforms主要包括:
# 1. ToTensor():将PIL图片或numpy数组转为Tensor,并将像素值归一化到[0,1]区间。
# 2. Resize((224,224)):将所有图片大小统一为224x224,以匹配后续卷积神经网络的输入要求。
# 对于实际训练,更建议加入数据增强手段(如随机裁剪、翻转、归一化等),
# 但本例先展示基本流程。
dataset_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224,224))
])

# 使用ImageFolder加载训练集和验证集
# ImageFolder会扫描指定目录下的子文件夹,并以子文件夹名称作为类别。
flowers_train = ImageFolder(root=flowers_train_path, transform=dataset_transform)
flowers_val = ImageFolder(root=flowers_val_path, transform=dataset_transform)

# 打印样本数量
print("训练集样本数:", len(flowers_train))
print("验证集样本数:", len(flowers_val))

# flowers_train.classes属性包含类别名称列表,如['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
print("类别名称列表:", flowers_train.classes)

# 获取单个样本进行查看
# __getitem__(index)返回(img, label),img是Tensor,label是int
sample_index = 3000
sample_img, sample_label = flowers_train[sample_index]

print("样本索引:", sample_index)
print("类别标签索引:", sample_label, "类别名称:", flowers_train.classes[sample_label])
print("图像Tensor尺寸:", sample_img.shape)  # 期望为[3,224,224]

# 可视化图像
# Matplotlib的imshow要求图像为(H,W,C),而Tensor是(C,H,W),需要permute调整维度顺序。
plt.imshow(sample_img.permute(1,2,0))
plt.title(flowers_train.classes[sample_label])
plt.show()

代码输出: 


关于训练集、验证集和测试集的说明

本数据集中已提前将数据分为trainval两个目录:

  • train/:训练集,用于模型训练过程中反向传播和参数更新。

  • val/:验证集,用于在训练中间进行性能评估,不参与参数更新,仅用于选择超参数或判断训练是否过拟合。

有些数据集还会提供test/测试集,用于最终评估模型在未知数据上的表现,但本例中未提供,如有需要可自行分割数据或从其他来源获取。


DataLoader的引入

仅有ImageFolder还不够,为了在训练时批量读取数据并进行迭代,我们通常会将数据集对象传入DataLoader中。

DataLoader的作用是:

  • 按指定的batch_size从Dataset中抽取样本构成mini-batch。

  • 可设置shuffle=True来随机打乱样本顺序,防止模型记住样本顺序。

  • 使用num_workers参数并行加速数据加载。

示例(可选代码):

from torch.utils.data import DataLoader

batch_size = 32
# 定义训练集和验证集的DataLoader
train_loader = DataLoader(flowers_train, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(flowers_val, batch_size=batch_size, shuffle=False, num_workers=2)

# 测试一下加载结果
images, labels = next(iter(train_loader))
print("一个batch的图像尺寸:", images.shape)  # [batch_size, 3, 224, 224]
print("对应的标签:", labels)  # 张量形式,如tensor([0, 1, 3, ...])


有了DataLoader,我们在训练模型时,就可以轻松迭代数据:

for epoch in range(1):
    for batch_images, batch_labels in train_loader:
        # 在这里将batch_images, batch_labels输入模型进行训练
        print("一个batch的图像尺寸:", batch_images.shape)  # [batch_size, 3, 224, 224]
        print("对应的标签:", batch_labels)  # 张量形式,如tensor([0, 1, 3, ...])
        pass
        break


我们可以打印一下第一个batch 和最后一个batch的标签

batch_count = 0
first_batch_images, first_batch_labels = None, None
last_batch_images, last_batch_labels = None, None

for epoch in range(1):
    for batch_images, batch_labels in train_loader:
        batch_count += 1
        
        # 保存第一个batch
        if batch_count == 1:
            first_batch_images, first_batch_labels = batch_images, batch_labels
            print("第一个batch的图像尺寸:", batch_images.shape)
            print("第一个batch的标签:", batch_labels)

        # 每次循环都会更新last_batch
        last_batch_images, last_batch_labels = batch_images, batch_labels
        
    break  # 只进行一次epoch的训练,移除这行会进行多个epoch的训练

# 打印最后一个batch
print("最后一个batch的图像尺寸:", last_batch_images.shape)
print("最后一个batch的标签:", last_batch_labels)

# 打印总共的batch数量
print("总共的batch数量:", batch_count)


数据增强策略的拓展

实际训练中,为提高模型的泛化能力,我们常加入数据增强操作。这些操作对训练集图像进行随机变换,如随机剪裁、翻转、颜色抖动、归一化等。这样模型不会过度记忆特定图像的像素分布,而会学习更有泛化性的特征。

一个常用的transform示例:

# 定义训练集的图像预处理流程
train_transform = transforms.Compose([
    # 随机裁剪并缩放图像到224x224的尺寸,裁剪的区域大小是随机的
    transforms.RandomResizedCrop(224),  
    
    # 随机进行水平翻转,用于数据增强,提升模型的泛化能力
    transforms.RandomHorizontalFlip(),
    
    # 将图像转换为Tensor类型,PyTorch要求输入为Tensor格式
    transforms.ToTensor(),
    
    # 进行图像的标准化处理。根据ImageNet数据集的均值和标准差进行归一化,
    # 使得不同的通道(RGB)具有相同的尺度,便于训练。
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# 定义验证集的图像预处理流程
val_transform = transforms.Compose([
    # 将图像的最短边缩放到256像素,保持长宽比例不变
    transforms.Resize(256),  
    
    # 从缩放后的图像中进行中心裁剪,裁剪出224x224的区域,这样图像的尺寸就一致了
    transforms.CenterCrop(224),
    
    # 将图像转换为Tensor类型,PyTorch要求输入为Tensor格式
    transforms.ToTensor(),
    
    # 进行图像的标准化处理。根据ImageNet数据集的均值和标准差进行归一化,
    # 使得不同的通道(RGB)具有相同的尺度,便于训练。
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# 使用定义的transform对训练集和验证集进行图像预处理
# flowers_train_path和flowers_val_path是训练集和验证集图像所在的路径
flowers_train = ImageFolder(flowers_train_path, transform=train_transform)  # 训练集
flowers_val = ImageFolder(flowers_val_path, transform=val_transform)  # 验证集

在此示例中,Normalize的参数是使用ImageNet数据集的均值和标准差,这在使用ImageNet预训练模型时是常规操作。对于自定义数据集,你也可以先统计本数据集的均值和方差,再进行归一化。


我们可以打印一下变化前后的图像区别

import os
import random
import numpy as np  # 需要导入numpy
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.datasets import ImageFolder

# 定义训练集的图像预处理流程
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 定义图像数据集路径
train_image_folder = '/Users/coyi/PycharmProjects/coyi_pythonProject/01.图像分类/dataset/flower_datas/train/'

# 使用ImageFolder加载数据集
dataset = ImageFolder(train_image_folder, transform=None)

# 随机选取一张图片
random_idx = random.randint(0, len(dataset) - 1)
image, label = dataset[random_idx]

# 显示原始图像
plt.figure(figsize=(5,5))
plt.title("Original Image")
plt.imshow(image)
plt.axis('off')  # 不显示坐标轴
plt.show()

# 应用train_transform变换
transformed_image = train_transform(image)

# 反标准化(Undo normalization)以恢复图片的原始视觉效果,因为训练的时候需要标准化
inv_normalize = transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1/0.229, 1/0.224, 1/0.225])
unnormalized_image = inv_normalize(transformed_image)

# 将Tensor转回PIL图像进行显示
unnormalized_image = unnormalized_image.permute(1, 2, 0).numpy()  # 转换为HWC格式
unnormalized_image = np.clip(unnormalized_image, 0, 1)  # 限制值在[0, 1]之间,以符合视觉输出

# 显示变换后的图像
plt.figure(figsize=(5,5))
plt.title("Transformed Image")
plt.imshow(unnormalized_image)
plt.axis('off')  # 不显示坐标轴
plt.show()

输出: 

备注: 为了显示图片,我对处理后的图片进行了反标准化,实际上训练的时候是不需要反标准化的


为什么要反标准化?

标准化是一个常见的预处理步骤,目的是让模型训练时更稳定,通常是将像素值转换到均值为0、标准差为1的范围。这可以帮助模型更好地收敛,并且消除不同通道(例如RGB)的尺度差异。

然而,标准化后的图像不适合直接用于可视化,因为它们的像素值已经不在[0, 1]的范围内,可能会变成负数或大于1。反标准化的目的是恢复图像的原始视觉效果,让它们的像素值回到原始的视觉范围。

不反标准化可以吗?

在可视化时不反标准化是可以的,但你会看到经过标准化后的图像没有直观的可视化效果,因为图像的像素值会偏离 [0, 1] 的可视化范围。这会导致显示的图像看起来可能是“失真”的,例如图像会变得非常暗、非常亮,或者有一些不自然的颜色。

简而言之:

反标准化是为了恢复图像的原始视觉效果,使得图像显示更符合人类的感知。

• **np.clip()**是为了确保图像的像素值在[0, 1]范围内,符合图像显示的要求。

示例:

假设标准化之后,你得到了一个像素值为 -0.5 或 1.5 的图像像素。这时,如果不进行 np.clip(),直接用 matplotlib 显示,可能会看到图像出现异常的颜色或显示不出来。而通过 np.clip(),将这些像素值限制在[0, 1]的范围内,可以确保图像能正确显示。


类别分布与标签可解释性

flowers_train.classesflowers_val.classes可以查看类名列表。例如:

这意味着模型预测结果中的label=0代表daisy,label=1代表dandelion,以此类推。当我们预测模型输出为label=3时,就可以将其解释为sunflowers。这种可读性非常有助于后期分析和调试。

如果想查看具体每类样本数量,可手动统计,例如:

 

通过查看类别分布,我们可了解数据是否偏斜(某些类样本过多或过少),从而采取相应措施(如类均衡采样、权重平衡等)。


实战建议和下一步计划

  1. 数据准备完成后做什么? 通常下一步就是定义和加载模型(如预训练的ResNet18),然后编写训练循环对模型进行微调或从头训练。在训练循环中,train_loader提供批数据,val_loader则用于评估模型在验证集上的表现。

  2. 调试DataLoader是否正确工作: 在正式训练前,尝试可视化几个batch的数据样本,确保图像大小、颜色正确,标签映射无误。如果出现图像显示不正确或标签偏移,及时检查目录结构和transform流程。

  3. 善用数据增强: 当验证集精度停滞不前或出现过拟合时,尝试加入更多数据增强手段(如RandomRotationColorJitterRandomGrayscale等)提升泛化性能。

  4. 硬件加速: 在加载大规模数据时,合理增加num_workers可以提高数据读取速度(依赖操作系统和硬件条件)。同时,如果是分布式训练,也需考虑分布式Sampler和合适的数据划分策略。

  5. 定制Dataset: 如果你的数据不遵循ImageFolder的结构,也可以自行定义Dataset类,通过实现__len____getitem__方法来自定义数据加载流程。但对像本例这样已按类分文件夹的数据集,ImageFolder无疑是最简单高效的方案。


小结

在本文中,我们从零出发,详细介绍了如何使用PyTorch的ImageFoldertransforms加载和预处理鲜花分类数据集。主要点包括:

  • 数据集组织结构:子文件夹命名为类名,便于ImageFolder自动识别类别。

  • 使用transforms对图像进行ToTensor和Resize等变换,以满足神经网络输入要求。

  • 通过可视化样本和打印类别信息确认数据加载的正确性。

  • 引入DataLoader批量采样和迭代数据,为后续训练循环奠定基础。

  • 展望数据增强、Normalize以及预训练模型迁移学习等实战技巧。

数据加载与预处理是深度学习项目不可或缺的步骤。掌握这些技能,能够让你在模型开发和实验中更加得心应手。未来你可以尝试更多高级技巧,如自定义transforms、对数据集进行统计分析、探索更复杂的增强策略和分布式数据加载方法。

达成这些基础后,你就可以开始定义模型(如使用torchvision.models.resnet18(pretrained=True)加载预训练模型)、设置损失函数(如CrossEntropyLoss)、选择优化器(如Adam或SGD),并在训练循环中快速迭代提升模型性能。

希望本文介绍,能为你对CV数据加载与预处理的理解添砖加瓦,帮助你在图像分类任务中迈出稳健的一步。


如果你遇到了什么问题,或者想了解某些方面的知识,欢迎在评论区留言

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/939885.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity-Editor扩展GUI基本实现一个可拖拉放的格子列表

短短几百行代码,好吧,又是“参考”了国外的月亮 操作,还真地挺自然的。。。。。。国外的实现有点小牛 拖拉,增加+ 一个Element 鼠标左键长按,可以出提示 鼠标右键,清除Element, 有点小bug,不是很自然地完全清除, using System.Collections; using System.Collecti…

修改vscode中emmet中jsx和tsx语法中className的扩展符号从单引号到双引号 - HTML代码补全 - 单引号双引号

效果图 实现步骤 文件 > 首选项 > 设置搜索“”在settings.json中修改,增加 "emmet.syntaxProfiles": {"html": {"attr_quotes": "single"},"jsx": {"attr_quotes": "double","…

首批|云轴科技ZStack成为开放智算产业联盟首批会员单位

近日 ,在Linux基金会AI & Data及中国开源软件推进联盟的指导之下,开放智算产业联盟成立大会在北京成功召开。在大会上,联盟首次公布了组织架构并颁发了首批会员单位证书。凭借ZStack AIOS平台智塔和在智算领域的技术创新,云轴…

HTN 78A3 6V~140V输入,3A实地异步降压变换器

1、特征 3A降压,内置250mΩ高侧功率管 输入电压范围:6V~140V 脉冲跳跃模式使得轻载下高效率 最高1MHZ可编程开关频率 COT纹波电压控制架构 欠压保护、过流保护和过热关断保护 无铅封装,ESOP8 2、应用 二轮电瓶车 太阳能系统 高压电池组 …

以太网帧、IP数据报图解

注:本文为 “以太网帧、IP数据报”图解相关文章合辑。 未整理去重。 以太网帧、IP数据报的图解格式(包含相关例题讲解) Rebecca.Yan已于 2023-05-27 14:13:19 修改 一、基础知识 UDP 段、IP 数据包,以太网帧图示 通信过程中&…

汽车IVI中控开发入门及进阶(三十九):ADAS的车道线检测

概述: 自动驾驶汽车中确保驾驶员和乘客安全环境的重要系统之一是高级驾驶员辅助系统(ADAS)。自适应巡航控制、自动制动/转向、车道保持系统、盲点辅助、车道偏离警告系统和车道检测都是ADAS的示例。车道检测向车辆的智能系统显示特定于车道线结构几何特征的信息,以显示车道…

Liveweb视频汇聚平台支持WebRTC协议赋能H.265视频流畅传输

随着科技的飞速发展和网络技术的不断革新,视频监控已经广泛应用于社会各个领域,成为现代安全管理的重要组成部分。在视频监控领域,视频编码技术的选择尤为重要,它不仅关系到视频的质量,还直接影响到视频的传输效率和兼…

搭建一个简单的Web服务器(Apache2.4)

一、下载安装包 Apache服务器的官方下载地址:http://httpd.apache.org windows系统的安装软件下载地址:https://www.apachelounge.com/download/ 下载的是免安装的版本:httpd-2.4.62-240904-win64-VS17.zip 解压后将 Apache24 拷贝到 D 盘&…

NTFS 文件搜索库

NTFS 文件搜索库 中文 | English 一个快速搜索NTFS卷文件的库 在这里插入图片描述 特性 快速扫描 NTFS 格式驱动器上的所有文件实时快速同步文件变更(创建, 更名, 删除)支持通配符查询文件名或文件路径重启自动更新文件变动, 无需重新进行全盘扫描 API描述 初始化并指定…

Ubuntu下C语言操作kafka示例

目录 安装kafka: 安装librdkafka consumer Producer 测试运行 安装kafka: Ubuntu下Kafka安装及使用_ubuntu安装kafka-CSDN博客 安装librdkafka github地址:GitHub - confluentinc/librdkafka: The Apache Kafka C/C library $ apt in…

JavaScript基础学习

1.一元运算符 正负号,自增运算符 2.比较运算符 判断相等用 字符串比较,比较的是ASC码 尽量不要比较小数,小数有精度 不同类型比较会有隐式转换 3.运算符优先级 4.分支语句 除了0,其余数字都是真。除了空字符串&#xff0c…

SQL进阶技巧:如何计算商品需求与到货队列表进出计划?

目录 0 需求描述 1 数据准备 2 问题分析 3 小结 累计到货数量计算 出货数量计算 剩余数量计算 0 需求描述 假设现有多种商品的订单需求表 DEMO_REQUIREMENT,以及商品的到货队列表 DEMO_ARR_QUEUE,要求按照业务需要,设计一个报表&#…

“调用函数”多出一个None

Python中函数缺省返回空值,试图打印没有return的函数都将是None。 (笔记模板由python脚本于2024年12月20日 19:59:03创建,本篇笔记适合会自定义python函数的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Fre…

Flask内存马学习

文章目录 参考文章环境搭建before_request方法构造内存马after_request方法构造内存马errorhandler方式构造内存马add_url_rule方式构造内存马 参考文章 https://www.mewo.cc/archives/10/ https://www.cnblogs.com/gxngxngxn/p/18181936 前人栽树, 后人乘凉 大佬们太nb了, …

在Win11系统上安装Android Studio

诸神缄默不语-个人CSDN博文目录 下载地址:https://developer.android.google.cn/studio?hlzh-cn 官方安装教程:https://developer.android.google.cn/studio/install?hlzh-cn 点击Next,默认会同时安装Android Studio和Android虚拟机&#…

Python读取Excel批量写入到PPT生成词卡

一、问题的提出 有网友想把Excel表中的三列数据,分别是:单词、音标和释义分别写入到PPT当中,每一张PPT写一个单词的内容。这种批量操作是python的强项,尤其是在办公领域,它能较好地解放双手,读取Excel表后…

【CC2530开发基础篇】继电器模块使用

一、前言 1.1 开发背景 本实验通过使用CC2530单片机控制继电器的吸合与断开,深入了解单片机GPIO的配置与应用。继电器作为一种常见的电气控制元件,广泛用于自动化系统中,用于控制大功率负载的开关操作。在本实验中,将通过GPIO口…

ChatGPT生成接口测试用例(二)

5.1.4 自动生成测试数据 测试数据的生成通常是接口测试的一个烦琐任务。ChatGPT可以帮助测试团队生成测试数据,包括各种输入和它们的组合。测试人员可以描述他们需要的数据类型和范围,ChatGPT可以生成符合要求的测试数据,从而减轻测试人员的负…

空天地遥感数据识别与计算--数据分析如何助力农林牧渔、城市发展、地质灾害监测等行业革新

在科技飞速发展的时代,遥感数据的精准分析已经成为推动各行业智能决策的关键工具。从无人机监测农田到卫星数据支持气候研究,空天地遥感数据正以前所未有的方式为科研和商业带来深刻变革。然而,对于许多专业人士而言,如何高效地处…

使用ZLMediaKit 开源项目搭建RTSP 服务器

ZLMediaKit 是啥? ZLMediaKit是国人开发的开源C流媒体服务器,同SRS一样是主流的流媒体服务器。 ZLToolKit是基于C11的高性能服务器框架,和ZLMediaKit是同一个作者,ZLMediaKit正是使用该框架开发的。 官网 ZLMediaKit开源地址&…