Dataset类实践
蚂蚁蜜蜂分类数据集和下载链接https://download.pytorch.org/tutorial/hymenoptera_data.zip
Dataset:提供一种方式去获取数据及其lable
-
Q:如何获取每个数据及其lable
重写构造方法和获取标签方法
-
Q:告诉我们总共有多少数据
重写len方法
代码示例:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir): # 获取所有图片的地址
self.root_dir = root_dir # self设置成全局变量
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx): # 获取图片及其标签
img_name = self.img_path[idx] # 获取文件名称
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 文件绝对路径
img = Image.open(img_item_path) # 获取文件详细信息
label = self.label_dir # 获取标签
return img, label # 返回图片和标签
def __len__(self):
return len(self.img_path)
# 创建示例测试
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir) # 获取蚂蚁的数据集
bees_dataset = MyData(root_dir, ants_label_dir) # 获取蜜蜂的数据集
#获取整个train数据集
train_dataset = ants_dataset + bees_dataset
在控制台中进行测试
- 对数据集中图片的相关操作
获取蚂蚁数据集
- 查找当前数据集中第一个图片名称
- 图片名称拼接(进行路径和标签的拼接)
- 读取图片相应信息
实例化对象
- 返回该对象 image和label
- 结果
- 改变ants_dataset[],展示第二张图
获取蜜蜂的数据集
获取整个train数据集(蚂蚁+蜜蜂)
train_dataset = ants_dataset + bees_dataset
# 进行长度测试
len(train_dataset)
248
len(ants_dataset)
124
len(bees_dataset)
124