1、简介
- PyTorch中如何读取数据主要涉及到两个类,分别为Dataset和Dataloader。
- Dataset:创建可被Pytorch使用的数据集
- Dataloader:向模型传递数据
- 本文主要讲解Dataset的使用方法。
2、Dataset
2.1、查看使用方法
- 打开Anaconda Prompt,进入pytorch虚拟环境(conda activate pytorch),输入下面命令,打开Jupyter。(使用Jupyter输出的结果更加清晰)
- 新建一个文件(可自行选择创建位置)。
- 输入下面指令,按Shift+回车运行。
- 也可以输入下列指令。
- 有下列描述可知,Dataset是一个抽象类,所有的数据集都需要继承这个类。并且所有子类都需要重写__getitem__方法来获取每一个数据的标签。
2.2、应用
- 使用PyCharm打开pytorch项目。如果没有,请参考:PyTorch入门教学——使用PyCharm创建一个PyTorch项目-CSDN博客,创建一个。
- 新建一个python文件。
- 数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip,将下好的数据集放入pytorch项目中。
- 该数据集分为训练数据集和验证数据集。
- 两个数据集中包含了蚂蚁和蜜蜂的图片,可以用来做二分分类,识别图片为蚂蚁还是蜜蜂。
- 打开read_data.py,写入下列代码。
-
from torch.utils.data import Dataset from PIL import Image # 获取图片 import os # 提供一些方法 class MyData(Dataset): # 继承Dataset # 初始化,为整个类提供全局变量 def __init__(self, root_dir, label_dir): self.root_dir = root_dir # 将这些变量赋给类,供后续函数使用 self.label_dir = label_dir self.path = os.path.join(self.root_dir, self.label_dir) # 将两个路径拼接 self.img_path = os.listdir(self.path) # 将路径中的所有文件名转为列表 # 获取图片和标签 def __getitem__(self, item): img_name = self.img_path[item] # 获取每一个图片名称 img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 获取每一张图片的地址 img = Image.open(img_item_path) # 打开图片 label = self.label_dir return img, label # 数量 def __len__(self): return len(self.img_path) root_dir = "Dataset/ReadData/train" # 图片相对路径 ants_label_dir = "ants" bees_label_dir = "bees" ants_dataset = MyData(root_dir, ants_label_dir) bees_dataset = MyData(root_dir, bees_label_dir) train_dataset = ants_dataset + bees_dataset # 整个训练数据集 # 图片展示 img1, label1 = train_dataset[123] img1.show() # 展示蚂蚁图片 img2, label1 = train_dataset[124] img2.show() # 展示蜜蜂图片
- 分别将蚂蚁和蜜蜂的图片提取并展示出来。
-