torchvision中的数据集使用
- 1. torchvision中的数据集使用
- 官网文档
- 注意点1 totensor实例化不要忘记加括号
- 注意点2 download可以一直保持为True
- 代码
- 执行结果
- 2. DataLoader的使用
1. torchvision中的数据集使用
官网文档
注意左上角的版本
https://pytorch.org/vision/0.9/
注意点1 totensor实例化不要忘记加括号
totensor实例化不要忘记加括号,否则后面用数据集序列号的时候会报错
注意点2 download可以一直保持为True
download可以一直保持为True,下载一次后指定目录下有下载好的数据集,代码不会重复下载,也可以自己把下载好的数据集压缩包放到指定目录,代码会自动解压缩
代码
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
# 用法1
# 数据下载很慢的话 可以使用迅雷下载,属性里面可以看到迅雷是从多方下载的,速度比较快 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
train_set = datasets.CIFAR10(root='./dataset', train=True, download=True)
test_set = datasets.CIFAR10(root='./dataset', train=False, download=True)
# 下载的数据集是图片类型,可以debug查看数据
print(test_set[0]) # __getitem__ return img, target
print(type(test_set[0]))
img, target = test_set[0]
print(target)
print(test_set.classes[target])
print(img)
# PIL 图片可以直接show函数展示
img.show()
# 用法2
# 将数据集批量调用transforms,使用tensor数据类型
# trans_compose = transforms.Compose([transforms.ToTensor]) # 错误写法 会导致后面报错
trans_compose = transforms.Compose([transforms.ToTensor()])
train_set2 = datasets.CIFAR10(root='./dataset', train=True, transform=trans_compose, download=True)
test_set2 = datasets.CIFAR10(root='./dataset', train=False, transform=trans_compose, download=True)
print(type(test_set2[2]))
img, target = test_set2[0]
print(target)
print(test_set2.classes[target])
print(type(img))
writer = SummaryWriter("logs")
for i in range(10):
img_tensor, target = test_set2[i]
writer.add_image('tensor dataset', img_tensor, i)
writer.close()
执行结果
> p11_torchvision_dataset.py
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1CF47DA9E20>, 3)
<class 'tuple'>
3
cat
<PIL.Image.Image image mode=RGB size=32x32 at 0x1CF47DA9E20>
Files already downloaded and verified
Files already downloaded and verified
<class 'tuple'>
3
cat
<class 'torch.Tensor'>
Process finished with exit code 0