【深度学习】Pytorch:加载自定义数据集

本教程将使用 flower_photos 数据集演示如何在 PyTorch 中加载和导入自定义数据集。该数据集包含不同花种的图像,每种花的图像存储在以花名命名的子文件夹中。我们将深入讲解每个函数和对象的使用方法,使读者能够推广应用到其他数据集任务中。

flower_photos/
├── daisy/
│   ├── image1.jpg
│   ├── image2.jpg
└── rose/
     ├── image1.jpg
     ├── image2.jpg
...

环境配置

所需工具和库

pip install torch torchvision matplotlib

导入必要的库

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import pathlib

数据集导入方法

定义数据转换

图像转换在计算机视觉任务中至关重要。通过 transforms 对象,我们可以实现图像大小调整、归一化、随机变换等预处理操作。

# 定义图像转换  
transform = transforms.Compose([  
    transforms.Resize((150, 150)),  # 调整图像大小为 150x150  
    transforms.ToTensor(),  # 将图像转换为 PyTorch 张量  
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化图像数据  
])  

# 数据路径  
data_dir = r"E:\CodeSpace\Deep\data\flower_photos"  

# 使用 ImageFolder 加载数据  
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)  

# 计算训练集和测试集的样本数量(80%和20%的划分)  
train_size = int(0.8 * len(full_dataset))  
test_size = len(full_dataset) - train_size  

# 随机划分数据集  
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  

# 创建数据加载器  
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  

# 获取类别名  
class_names = full_dataset.classes  
print("类别名:", class_names)

显示部分样本图像

可视化样本数据有助于理解数据集结构和数据质量。

# 定义函数以绘制样本图像
def plot_images(images, labels, class_names):
    plt.figure(figsize=(10, 10))
    for i in range(9):  # 绘制前 9 张图像
        plt.subplot(3, 3, i + 1)
        img = images[i].permute(1, 2, 0)  # 将张量维度从 (C, H, W) 转为 (H, W, C)
        plt.imshow(img * 0.5 + 0.5)  # 反归一化处理,恢复到原始像素范围 [0, 1]
        plt.title(class_names[labels[i]])  # 显示类别标签
        plt.axis('off')  # 去掉坐标轴

# 获取部分样本数据用于展示
sample_images, sample_labels = next(iter(train_loader))
plot_images(sample_images, sample_labels, class_names)

自定义数据加载方法

当数据结构复杂或需要额外处理时,可以通过继承 torch.utils.data.Dataset 创建自定义数据加载类。

Dataset 类详解

Dataset 是 PyTorch 中的一个抽象类,用户需要实现以下核心方法:

  1. __init__():初始化方法
    • 传入数据路径和转换方法。
    • 加载所有图像路径并生成类别标签。
  2. __len__():返回数据集大小
    • 指定数据集中样本数量。
  3. __getitem__():根据索引获取样本数据
    • 加载指定位置的图像和标签,并进行必要的转换。

代码实现

class CustomFlowerDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        # 初始化数据集路径和图像转换方法
        self.data_dir = pathlib.Path(data_dir)
        self.transform = transform
        self.image_paths = list(self.data_dir.glob('*/*.jpg'))  # 获取所有图像路径
        self.label_names = sorted(item.name for item in self.data_dir.glob('*/') if item.is_dir())
        self.label_to_index = {name: idx for idx, name in enumerate(self.label_names)}  # 将类别名映射为索引

    def __len__(self):
        # 返回数据集大小
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 根据索引获取图像及其标签
        img_path = self.image_paths[idx]
        label = self.label_to_index[img_path.parent.name]  # 通过父文件夹名获取标签
        image = Image.open(img_path).convert("RGB")  # 确保图像是 RGB 模式
        if self.transform:
            image = self.transform(image)  # 进行图像预处理
        return image, label

# 使用自定义数据集
custom_dataset = CustomFlowerDataset(data_dir, transform=transform)
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)

随机划分数据集

如果你还希望在这个自定义数据集上随机划分训练集和测试集,可以使用 torch.utils.data.random_split。以下是示例代码:

from torch.utils.data import random_split  

# 获取数据集长度  
full_dataset = CustomFlowerDataset(data_dir, transform=transform)  

# 计算训练集和测试集的样本数量(80%和20%的划分)  
train_size = int(0.8 * len(full_dataset))  
test_size = len(full_dataset) - train_size  

# 随机划分数据集  
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  

# 创建数据加载器  
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  

print(f"训练集大小: {len(train_dataset)}, 测试集大小: {len(test_dataset)}")  

数据加载性能优化

  • num_workers 参数:设置并行数据加载线程数。对于多核 CPU,可以显著提高数据加载效率。
  • prefetch_factor 参数:控制每个工作线程预取的批次数量。
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2)

Dataset 类扩展建议

  1. 支持多格式数据读取:通过扩展 __getitem__() 来支持其他格式如 PNG、BMP。
  2. 数据过滤:在 __init__() 中根据文件名或元数据筛选特定样本。
  3. 标签增强:为每个样本生成附加信息,例如图像的元数据或分布特征。

数据集的使用方法

遍历数据集

模型训练前需要遍历数据集以加载图像和标签:

for images, labels in custom_loader:
    # images 是图像张量,labels 是对应的类别标签
    print(f"图像张量大小: {images.shape}, 标签: {labels}")

模型输入

数据集加载完成后可直接用于模型训练:

import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络模型
model = nn.Sequential(
    nn.Flatten(),  # 将输入张量展平成一维
    nn.Linear(150*150*3, 128),  # 输入层到隐藏层的全连接层
    nn.ReLU(),  # 激活函数
    nn.Linear(128, len(class_names))  # 输出层,类别数量等于花的种类数
)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器

# 示例训练过程
for epoch in range(2):  # 简单训练两轮
    for images, labels in custom_loader:
        outputs = model(images)  # 前向传播计算输出
        loss = criterion(outputs, labels)  # 计算损失

        optimizer.zero_grad()  # 梯度清零
        loss.backward()  # 反向传播计算梯度
        optimizer.step()  # 更新模型参数

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

模型评估

加载后的数据集也可用于验证模型性能:

correct = 0
total = 0
model.eval()  # 设置模型为评估模式
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"模型准确率: {accuracy:.2f}%")

方法对比与扩展

ImageFolder vs 自定义 Dataset

  • ImageFolder:适合简单目录结构,快速加载标准图像数据。
  • 自定义 Dataset:更适合复杂数据结构及自定义逻辑,例如多模态数据处理。

提高模型泛化能力

  • 数据增强:通过 transforms.RandomHorizontalFlip()transforms.ColorJitter() 等方法增加数据多样性。
  • 归一化技巧:根据数据集的特性调整 meanstd 参数。

总结

本教程详细讲解了如何在 PyTorch 中加载和导入 flower_photos 数据集,结合不同方法的讲解使你能根据项目需求灵活选择适合的数据加载方案。同时,我们探讨了优化和扩展方法,希望这些内容能为你的深度学习项目提供有力支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/952863.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

导出文件,能够导出但是文件打不开

背景: 在项目开发中,对于列表的查询,而后会有导出功能,这里导出的是一个excell表格。实现了两种,1.导出的文件,命名是前端传输过去的;2.导出的文件,命名是根据后端返回的文件名获取的…

马斯克的Grok-2 Beta APP在苹果应用商店上限了,Grok-2安装尝鲜使用教程

马斯克的Grok-2 Beta APP 已经上线苹果商城了,移动端的Grok挺好用的!无需登录即可使用! (文末有安装教程) 实测之后,Grok-2 绘画方面个人感觉比GPT-4的绘画还要强一些。而且速度还挺快,可以多次…

《机器学习》——sklearn库中CountVectorizer方法(词频矩阵)

CountVectorizer方法介绍 CountVectorizer 是 scikit-learn 库中的一个工具,它主要用于将文本数据转换为词频矩阵,而不是传统意义上的词向量转换,但可以作为词向量转换的一种基础形式。用于将文本数据转换为词频矩阵,它是文本特征…

CV 图像处理基础笔记大全(超全版哦~)!!!

一、图像的数字化表示 像素 数字图像由众多像素组成,是图像的基本构成单位。在灰度图像中,一个像素用一个数值表示其亮度,通常 8 位存储,取值范围 0 - 255,0 为纯黑,255 为纯白。例如,一幅简单的…

支持向量回归(SVR:Support Vector Regression)用于A股数据分析、预测

简单说明 支持向量回归是一种用来做预测的数学方法,属于「机器学习」的一种。 它的目标是找到一条「最合适的线」,能够大致描述数据点的趋势,并允许数据点离这条线有一定的误差(不要求所有点都完全落在这条线上)。 可以把它想象成:找到一条「宽带」或「隧道」,大部分…

ollama教程(window系统)

前言 在《本地大模型工具哪家强?对比Ollama、LocalLLM、LM Studio》一文中对比了三个常用的大模型聚合工具优缺点,本文将详细介绍在window操作系统下ollama的安装和使用。要在 Windows 上安装并使用 Ollama,需要依赖 NVIDIA 显卡&#xff0c…

Flink系统知识讲解之:容错与State状态管理

Flink系统知识之:容错与State状态管理 状态在Flink中叫作State,用来保存中间计算结果或者缓存数据。根据是否需要保存中间结果,分为无状态计算和有状态计算。对于流计算而言,事件持续不断地产生,如果每次计算都是相互…

DolphinScheduler自身容错导致的服务器持续崩溃重大问题的排查与解决

01 问题复现 在DolphinScheduler中有如下一个Shell任务: current_timestamp() { date "%Y-%m-%d %H:%M:%S" }TIMESTAMP$(current_timestamp) echo $TIMESTAMP sleep 60 在DolphinScheduler将工作流执行策略设置为并行: 定时周期调度设置…

ASP.NET Core 实现微服务 - Elastic APM

这次要给大家介绍的是Elastic APM ,一款应用程序性能监控组件。APM 监控围绕对应用、服务、容器的健康监控,对接口的调用链、性能进行监控。在我们实施微服务后,由于复杂的业务逻辑,服务之间的调用会像蜘蛛网一样复杂。有了调用链…

25/1/12 嵌入式笔记 学习esp32

了解了一下位选线和段选线的知识: 位选线: 作用:用于选择数码管的某一位,例如4位数码管的第1位,第2位) 通过控制位选线的电平(高低电平),决定当前哪一位数码管处于激活状…

IMX6U Qt 开发环境

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一、交叉编译 1. 安装通用 ARM 交叉编译工具链 2. 安装 Poky 交叉编译工具链 二、编译出厂源码 1. U-boot 2. 内核和模块 3. 编译出厂 Qt GUI 综合 Demo 前言…

【Oracle专栏】2个入参,生成唯一码处理

Oracle相关文档,希望互相学习,共同进步 风123456789~-CSDN博客 1.背景 业务需要:2个参数,如 aidbankid ,两个值是联合主键,需要生成一个固定唯一码,长度有限制32位,为了…

跨界融合:人工智能与区块链如何重新定义数据安全?

引言:数据安全的挑战与现状 在信息化驱动的数字化时代,数据已成为企业和个人最重要的资产之一。然而,随着网络技术的逐步优化和数据量的爆发式增长,数据安全问题也愈变突出。 数据安全现状:– 数据泄露驱动相关事件驱…

给DevOps加点料:融入安全性的DevSecOps

从前,安全防护只是特定团队的责任,在开发的最后阶段才会介入。当开发周期长达数月、甚至数年时,这样做没什么问题;但是现在,这种做法现在已经行不通了。 采用 DevOps 可以有效推进快速频繁的开发周期(有时…

CDP中的Hive3之Hive Metastore(HMS)

CDP中的Hive3之Hive Metastore(HMS) 1、CDP中的HMS2、HMS表的存储(转换)3、HWC授权 1、CDP中的HMS CDP中的Hive Metastore(HMS)是一种服务,用于在后端RDBMS(例如MySQL或PostgreSQL&a…

【算法】判断一个链表是否为回文结构

问: 给定一个单链表的头节点head,请判断该链表是否为回文结构 例: 1 -> 2 -> 1返回true;1 -> 2 -> 2 -> 1返回true;15 -> 6 -> 15返回true 答: 笔试:初始化一个栈用来…

Python双指针

双指针 双指针:在区间操作时,利用两个下标同时遍历,进行高效操作 双指针利用区间性质可以把 O ( n 2 ) O(n^2) O(n2) 时间降低到 O ( n ) O(n) O(n) 反向扫描 反向扫描: l e f t left left 起点,不断往右走&…

VMware虚拟机安装Home Assistant智能家居平台并实现远程访问保姆级教程

目录 前言 1. 安装Home Assistant 前言 本文主要介绍如何在windows 10 上用VMware Workstation 17 Pro搭建 Home Assistant OS Host os version:Windows 10 Pro, 64-bit (Build 19045.5247) 10.0.19045 VMware version:VMware Workstation 17 Pro 1. 安装Home …

【MySQL】SQL菜鸟教程(一)

1.常见命令 1.1 总览 命令作用SELECT从数据库中提取数据UPDATE更新数据库中的数据DELETE从数据库中删除数据INSERT INTO向数据库中插入新数据CREATE DATABASE创建新数据库ALTER DATABASE修改数据库CREATE TABLE创建新表ALTER TABLE变更数据表DROP TABLE删除表CREATE INDEX创建…

【Java回顾】Day5 并发基础|并发关键字|JUC全局观|JUC原子类

JUC全称java.util.concurrent 处理并发的工具包(线程管理、同步、协调) 一.并发基础 多线程要解决什么问题?本质是什么? CPU、内存、I/O的速度是有极大差异的,为了合理利用CPU的高性能,平衡三者的速度差异,解决办法…