目录
- 1. 图像分类数据集
- 1.1 读取数据集
- 1.2 读取小批量
- 1.3 整合所有组件
- 1.4 小结
1. 图像分类数据集
这里采用Fashion-MNIST数据集
torchvision
:torch类型的可视化包,一般计算机视觉和数据可视化需要使用from torchvision import transforms
:该组件经常用于图片的修改(一般数据集中的图片都是PIL格式,使用的时候需要转化为tenser,而在加入函数时常需要转化为nadarry(numpy中的ndarray为多维数组))d2l.use_svg_display()
:使用什么模式展示图片
%matplotlib inline
import torch
import torchvision #pytorch用于计算机视觉的一个库
from torch.utils import data
from torchvision import transforms #导入对数据操作的模具
from d2l import torch as d2l
d2l.use_svg_display() #使用svg展示图片
1.1 读取数据集
通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中
torchvision.datasets
:一般用于图像数据集的下载和获取
eg:torchvision.datasets.FashionMNIST( root=, train=True, transform=, download=True)
:
- train:是否为训练集
- transform:使用什么格式转换(可以从transforms组件中选择)
- dowload:是否下载对应数据集
- .FashionMNIST可以更换为其他数据源
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor() #对图片进行预处理,转换为tensor格式
# 下载训练集和测试集,并保存
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans,download=True)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans,download=True)
Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。
# 输出训练集和测试集的大小
len(mnist_train), len(mnist_test)
每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1(彩色图像通道数为3)。
# 索引到第一张图片
mnist_train[0][0].shape # 输入图像的通道数、高度和宽度
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换。
# 获取数据集的标签
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_lables[int(i)] for i in labels]
创建一个函数来可视化这些样本。
plt.subplots()
是一个返回包含图形和轴对象的元组的函数。因此,在使用时fig, ax = plt.subplots(),将此元组解压缩到变量fig和ax。enumerate()
函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中,生成可以遍历的每个元素有对应序号(0, 1, 2, 3…)的enumerate对象。zip()
函数用于将多个可迭代对象作为参数,依次将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的对象,里面的每个元素大概为i,(ax,img)的形式。imshow()
可以接收二维,三维甚至多维数组。二维默认为一通道即灰度图像,三维需要在第三个维度指定图像通道数(必须是第三维)
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
# 第1个参数是个图,一般不用;第2个axer类似于图片的索引矩阵(行,列)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # axes:轴
axes = axes.flatten()
# 遍历生成形如i, (ax, img)形式的enumerate对象
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False) #x轴隐藏
ax.axes.get_yaxis().set_visible(False) #y轴隐藏
if titles:
ax.set_title(title[i]) #显示标题
return axes
以下是训练数据集中前几个样本的图像及其相应的标签。
next()
返回迭代器的下一个项目。next()
函数要和生成迭代器的iter()
函数一起使用。- 我们可以通过
iter()
函数获取这些可迭代对象的迭代器。然后,我们可以对获取到的迭代器不断使⽤next()
函数来获取下⼀条数据。
注:当我们已经迭代完最后⼀个数据之后,再次调⽤next()
函数会抛出 StopIteration的异常 ,来告诉我们所有数据都已迭代完成,不⽤再执⾏next()
函数了。
# 使用next()函数获取批量大小为18的训练集的图像和标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
#显示18张图片,宽度为28,长度为28,总共为2行9列
# 绘制两行图片,每一行有9张图片,并获取标签
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
1.2 读取小批量
为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
# 训练集需要设置shuffle=True打乱顺序
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
我们看一下读取训练数据所需的时间。
timer = d2l.Timer() #调用Timer函数,测试速度
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec' #输出读取数据所用的秒数,精度为2位小数
1.3 整合所有组件
定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。
torchvision.transforms
是pytorch中的图像预处理包,一般用Compose把多个步骤整合到一起。insert
函数是一种用于列表的内置函数。这个函数的作用是在一个列表中的指定位置,插入一个元素。
transforms中的函数 | 功能 |
---|---|
Resize | 把给定的图片resize到given size |
Normalize | 用均值和标准差归一化张量图像 |
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
# 转换为tensor
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
# compose整合步骤
trans = transforms.Compose(trans)
# 下载训练集和测试集,将小批量样本返回到train_iter中,用于之后的训练
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
1.4 小结
- Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。
- 我们将高度h像素,宽度w像素图像的形状记为h×w或(h,w)。
- 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。