文章目录
- 一、前置知识
- 如何查看torchvision的数据集
- 二、代码(附注释)及运行结果
一、前置知识
如何查看torchvision的数据集
(1)打开官网 https://pytorch.org/
pytorch官网
(2)打开torchvision
在Docs下拉后选择torchvision
(3)左侧点击Datasets
本次用的数据集是CIFAR10:
可以看到,要输入的参数有:
root(字符串):数据集的根目录,其中存在 cifar-10-batches-py 目录,如果设置 download 为 True,则数据集将保存在此目录中。
train(bool,可选):如果为 True,则从训练集创建数据集,否则从测试集创建数据集。
transform(callable,可选):接受 PIL 图像并返回转换后版本的函数/转换。例如,transforms.RandomCrop。
target_transform(callable,可选):接受目标并对其进行转换的函数/转换。
download(bool,可选):如果为 True,则从互联网下载数据集并将其放在根目录中。如果数据集已经下载,则不会重新下载。
二、代码(附注释)及运行结果
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 定义导入数据时进行的变换
data_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 创建训练集和测试集
train_set = torchvision.datasets.CIFAR10("./dataset1", train=True, transform=data_transform, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=data_transform, download=True)
# 打印test_set第一个数据
# 结果为:(<PIL.Image.Image image mode=RGB size=32x32 at 0x10C7177C190>, 3)
print(test_set[0])
# 打印test_set数据的类别
# 结果为:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(test_set.classes)
# 将test_set的第一个数据拆分为img和target
img, target = test_set[0]
# 打印test_set第一个数据的img
# 结果为<PIL.Image.Image image mode=RGB size=32x32 at 0x10C7177C190>
print(img)
# 打印test_set第一个数据的target,结果为3
print(target)
# 打印test_set第target个类别
print(test_set.classes[target])
# 创建一个 TensorBoard 的 SummaryWriter 对象,用于记录测试集中的图像
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
运行结果: