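torchvision datasets
The dataset API provided by torchvision is convenient to use. If the download is slow when run from PyCharm, you can paste the download URL into a download manager such as Xunlei instead (some of the URLs are listed in the torchvision source code). The code is as follows: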
import torchvision
train_set = torchvision.datasets.CIFAR10("./Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10("./Dataset", train=False, download=True)
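If the archive is fetched manually with an external downloader, one way to check that it arrived intact before placing it in ./Dataset is to compute its MD5 checksum and compare it with the value listed next to the URL in the torchvision source. The sketch below uses only the standard library; the helper name file_md5 is hypothetical.

```python
import hashlib


def file_md5(path, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of the file at `path`, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read fixed-size chunks so a large archive does not have to fit in memory
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, print(file_md5("./Dataset/cifar-10-python.tar.gz")) prints the checksum, assuming the archive keeps its original filename. With a matching archive already in the dataset root, a call with download=True should reuse the existing file rather than fetch it again.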
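Each sample in the CIFAR10 dataset is returned as a tuple: the first element is the picture as a PIL image, and the second is the sample's label (its class), stored as an integer index into train_set.classes. The code is as follows: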
import torchvision
train_set = torchvision.datasets.CIFAR10("./Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10("./Dataset", train=False, download=True)
print(train_set[0])  # (PIL image, label index) tuple
print(train_set.classes)  # list of class names
img, target = train_set[0]
print(img)  # the PIL image object
print(target)  # the integer label index
print(train_set.classes[target])  # the class name for that index
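The (image, label) tuple comes from indexing: writing train_set[0] calls the dataset's __getitem__ method. The toy class below is a simplified stand-in, not torchvision's implementation; the name ToyDataset and its string "images" are made up for illustration.

```python
class ToyDataset:
    """Minimal stand-in for an image classification dataset."""

    classes = ["airplane", "automobile", "bird"]

    def __init__(self, images, targets):
        self.images = images    # placeholder objects standing in for images
        self.targets = targets  # integer indices into `classes`

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Indexing returns one (image, target) pair
        return self.images[index], self.targets[index]


toy_set = ToyDataset(["img_a", "img_b"], [2, 0])
img, target = toy_set[0]
print(img, target, toy_set.classes[target])
```

Storing the label as an index and keeping the names in a separate classes list is what makes the train_set.classes[target] lookup work.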
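Applying transforms to the dataset
Note that you only need to pass the transform object when calling the dataset API. Because dataset_transforms is an instance of the Compose class, it can be passed directly as the transform argument. The code is as follows: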
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Compose chains the listed transforms; here only ToTensor is applied
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10("./Dataset", train=True, transform=dataset_transforms, download=True)
test_set = torchvision.datasets.CIFAR10("./Dataset", train=False, transform=dataset_transforms, download=True)
writer = SummaryWriter("logs")
# Log the first 10 training images to TensorBoard
for i in range(10):
    img, target = train_set[i]  # img is now a tensor rather than a PIL image
    writer.add_image("train_set_img", img, i)
writer.close()
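To make the role of the transform more concrete, the sketch below imitates in plain Python what the pipeline does conceptually: a Compose-like object calls each transform in order, and a ToTensor-like step rescales 8-bit pixel values to the range [0, 1] and reorders the layout from height x width x channel to channel x height x width. ToyCompose and to_chw_float are hypothetical names that operate on nested lists; the real classes work on PIL images and tensors.

```python
class ToyCompose:
    """Apply a list of callables one after another."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


def to_chw_float(pixels_hwc):
    """Convert nested lists [H][W][C] of 0-255 ints to [C][H][W] floats in [0, 1]."""
    height = len(pixels_hwc)
    width = len(pixels_hwc[0])
    channels = len(pixels_hwc[0][0])
    return [
        [[pixels_hwc[h][w][c] / 255 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


pipeline = ToyCompose([to_chw_float])
# A 1 x 2 RGB image: one red pixel and one green pixel
image = [[[255, 0, 0], [0, 255, 0]]]
tensor_like = pipeline(image)
print(tensor_like)
```

The channel-first layout is also what writer.add_image expects by default, which is why the transformed images can be logged without further conversion.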
The result displayed in TensorBoard (started with tensorboard --logdir=logs) is as follows:
DataLoader for torchvision datasets