一、dataset和dataloader要点说明
在我们搭建自己的网络时,往往需要定义自己的dataset
和dataloader
,将图像和标签数据送入模型。
(1)在我们定义dataset
时,需要继承torch.utils.data.dataset
,再重写三个方法:
init
方法,主要用来定义数据的预处理getitem
方法,数据增强;返回数据的item和labellen
方法,返回数据数量
(2)在我们定义dataloader
时,需要考虑下面几个参数:
dataset
:使用哪个数据集batch_size
:将数据集拆成一组多少个进行训练shuffle
:是否需要打乱数据num_workers
:几个mini_batch并行计算,一般<=你的电脑cpu数目collect_fn
:数据打包方式
(3)通过迭代的方式,按批次,获取dataloader
中的数据
(4)关系图
二、核心代码框架
import os
import cv2
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
# -------------------------------------------------------------#
# 自定义dataset需要继承torch.utils.data.dataset,
# 再重写def __init__,def __len__,def __getitem__三个方法
# -------------------------------------------------------------#
class YourDataset(Dataset):
def __init__(self, root_path):
super(YourDataset, self).__init__()
self.root_path = root_path
#-------------------------------------------------------------------------#
# 获取样本名,以jpg原始图片为参考,修改后缀名为json,png,获取json,png标签文件路径
#-------------------------------------------------------------------------#
self.sample_names = []
jpg_path = os.path.join(os.path.join(self.root_path, "images"),)
for file in os.listdir(jpg_path):
if file.endswith(".jpg"):
self.sample_names.append(os.path.splitext(file)[0]) # 去掉.json
def __len__(self):
#----------------------#
# 返回数据数量
#----------------------#
return len(self.sample_names)
def __getitem__(self, index):
name = self.sample_names[index]
# ----------------------#
# 读取图像
# ----------------------#
img_path = os.path.join(os.path.join(self.root_path, "images"), name + '.jpg')
image = cv2.imread(img_path)
# ----------------------#
# 读取标签
# ----------------------#
label_path = os.path.join(os.path.join(self.root_path, "jsons"), name + '.json')
with open(label_path) as label_file:
points = self.get_data_from_json(label_file)
#----------------------#
# 图像数据增强
#----------------------#
image = self.random_color(image)
#----------------------#
# 标签归一化
#----------------------#
labels = self.convert_labels(points)
return image, labels
# -------------------------------------#
# 图片和标签格式转换后,按批次(batch)打包
# -------------------------------------#
def dataloader_collate_fn(batch):
images = []
labels = []
for img, label in batch:
images.append(transforms.ToTensor()(img))
labels.append(label)
return images, labels
if __name__ == '__main__':
# -------------------------------------#
# 构建dataset
# -------------------------------------#
path = './data/train'
train_dataset = YourDataset(path)
# -------------------------------------#
# 构建Dataloader
# -------------------------------------#
dataset = train_dataset
batch_size = 32
shuffle = True
num_workers = 0
collate_fn = dataloader_collate_fn
sampler = None
train_gen = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,drop_last=True, collate_fn=collate_fn, sampler=sampler)
# ---------------------------------------------#
# 通过迭代的方式,一批一批读取训练集中的图像和标签数据
# ---------------------------------------------#
for iter, batch in enumerate(train_gen):
images, labels = batch