什么是DataLoader
个人理解是,如果Dataset
的所有数据相当于一副扑克牌,DataLoader
就相当于从扑克牌中抽取几张,我们可以规定一次抽取的张数,或者以什么规则进行抽取
DataLoader的使用
查阅官网的文档,主要有这几个参数比较常用
其中dataset
可以用上一篇文章来进行创建
具体的实现方法为
import torchvision
from torch.utils.data import DataLoader
test_dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 这里采用测试集是因为测试集较小,运行较快
test_dataLoader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
for data in test_dataLoader: #从test_dataLoader中取出data
imgs, labels = data
print(labels)
然后我们可以加入tensorBoard
可视化处理
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
download=True)
# 这里采用测试集是因为测试集较小,运行较快
test_dataLoader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=False)
writer = SummaryWriter(log_dir='./logs')
i = 0
for data in test_dataLoader: # 从test_dataLoader中取出data
imgs, labels = data
print(imgs.shape)
writer.add_images('test_loader1', imgs, i) # 注意这是add_images
i = i + 1
writer.close()
注意这里的writer.add_images()
需要加s
否则不能运行