目录
1. 介绍
2. 主函数代码
2. utils 模块代码
2.1 划分数据集
2.2 可视化数据集
3. dataset 数据处理
4. collate_fn
5. other
1. 介绍
图像分类一般来说不需要自定义的dataSet,因为pytorch自定义好的ImageFolder可以解决大部分的需求,更多的dataSet是在图像分割里面实现的
这里 霹雳吧啦Wz 博主提供了一个好的代码,可以进行数据集划分(不需要保存划分后的数据集),然后重新实现了dataSet,并且对dataloader的 collate_fn 方法进行了实现
下面的代码只会对重点的部分做笔记
2. 主函数代码
这里的root传入的是数据集的路径
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image
# 数据集所在根目录,不需要划分trainSet+valSet,这里是完整数据集
root = './data/flower'
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device)) # 打印使用的设备
# 划分训练集 + 验证集
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root, val_rate=0.1,flag=False)
# 预处理
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# 数据处理
train_data_set = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
batch_size = 8
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers'.format(nw))
# 获取数据,测试的时候num_workers 设定为0
train_loader = DataLoader(train_data_set,batch_size=batch_size,shuffle=True,num_workers=nw,
collate_fn=train_data_set.collate_fn)
# 可视化数据
plot_data_loader_image(train_loader)
if __name__ == '__main__':
main()
这里的代码很常规,为了测试,只加载了训练集数据
2. utils 模块代码
这里实现了两个功能,划分数据集 + 可视化数据集
2.1 划分数据集
代码都做了注释,这块的内容慢慢调试也很容易理解,之前实现过相似的代码,只不过当时将划分好的数据集保存到不同的目录中,然后用ImageFolder调用的
def read_split_data(root: str, val_rate: float = 0.2, flag: bool = False):
random.seed(0) # 保证随机结果可复现
assert os.path.exists(root), "dataset root: {} does not exist.".format(root) # 断言数据集目录是否存在
# 遍历文件夹,一个文件夹对应一个类别,flower_class = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
# 排序,保证顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引 class_indices={'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
class_indices = dict((k, v) for v, k in enumerate(flower_class))
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
'''生成json文件
{
"0": "daisy",
"1": "dandelion",
"2": "roses",
"3": "sunflowers",
"4": "tulips"
}
'''
train_images_path = [] # 存储训练集的所有图片路径
train_images_label = [] # 存储训练集图片对应label
val_images_path = [] # 存储验证集的所有图片路径
val_images_label = [] # 存储验证集图片对应label
every_class_num = [] # 存储每个类别的样本总数
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
# 遍历每个文件夹下的文件
for cla in flower_class:
cla_path = os.path.join(root, cla) # 每个文件夹的路径
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported] # splitext 分离文件名和后缀名
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
every_class_num.append(len(images))
# 按比例随机采样验证样本
val_path = random.sample(images, k=int(len(images) * val_rate))
for img_path in images:
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
val_images_path.append(img_path)
val_images_label.append(image_class) # 0 1 2 3 4
else: # 否则存入训练集
train_images_path.append(img_path)
train_images_label.append(image_class)
print("{} images were found in the dataset.".format(sum(every_class_num))) # 总样本个数
print("{} images for training.".format(len(train_images_path))) # 训练集个数
print("{} images for validation.".format(len(val_images_path))) # 验证集个数
plot_image = flag # 是否绘制图表,默认为 False
if plot_image:
# 绘制每种类别个数柱状图
plt.bar(range(len(flower_class)), every_class_num, align='center')
# 将横坐标0,1,2,3,4替换为相应的类别名称
plt.xticks(range(len(flower_class)), flower_class)
# 在柱状图上添加数值标签
for i, v in enumerate(every_class_num):
plt.text(x=i, y=v + 5, s=str(v), ha='center')
# 设置x坐标
plt.xlabel('image class')
# 设置y坐标
plt.ylabel('number of images')
# 设置柱状图的标题
plt.title('flower class distribution')
plt.show()
return train_images_path, train_images_label, val_images_path, val_images_label
2.2 可视化数据集
代码如下,
# 可视化
def plot_data_loader_image(data_loader):
batch_size = data_loader.batch_size
# 载入名称的json文件
json_path = './class_indices.json'
assert os.path.exists(json_path), json_path + " does not exist."
json_file = open(json_path, 'r')
class_indices = json.load(json_file)
for data in data_loader:
images, labels = data
for i in range(batch_size):
# [C, H, W] -> [H, W, C]
img = images[i].numpy().transpose(1, 2, 0)
# 反Normalize操作
img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
label = labels[i].item()
plt.subplot(2, batch_size//2+1, i+1)
plt.xlabel(class_indices[str(label)])
plt.xticks([]) # 去掉x轴的刻度
plt.yticks([]) # 去掉y轴的刻度
plt.imshow(img.astype('uint8'))
plt.show()
3. dataset 数据处理
代码如下:
from PIL import Image
import torch
from torch.utils.data import Dataset
# 自定义数据集处理
class MyDataSet(Dataset):
def __init__(self, images_path: list, images_class: list, transform=None):
self.images_path = images_path
self.images_class = images_class
self.transform = transform
def __len__(self): # 返回数据集的个数
return len(self.images_path)
def __getitem__(self, item):
img = Image.open(self.images_path[item]) # 返回路径下的PIL图像
# RGB为彩色图片,L为灰度图片
if img.mode != 'RGB': # 判断是否为 RGB 图像
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
label = self.images_class[item]
if self.transform is not None: # transform 对 PIL 读取的图片处理
img = self.transform(img)
return img, label
@staticmethod
def collate_fn(batch):
# 官方实现的default_collate可以参考
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)
labels = torch.as_tensor(labels)
return images, labels
对这里调试的话,可以看到很多信息
4. collate_fn
这里的实现如下:
@staticmethod
def collate_fn(batch):
# 官方实现的default_collate可以参考
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)
labels = torch.as_tensor(labels)
return images, labels
下面是之前 blog 里面写的
对下面进行调试,发现dataloader其实是加载batch_size 个数的list,其中没有元素是一个tuple,里面存放了图像和label
运行发现:将batch_size 个图像放到一个tuple里面,label也是
最后,
所以最后可视化的结果:
5. other
这里博主提供了一个调试的方法,将这里勾上