一、FashionMNIST数据集简介
FashionMNIST
数据集,是一款作为经典的MNIST
数据集的现代替代品的数据集,用来做衣物分类问题,由Zalando(一家德国的在线时尚零售商)发布。
- 该数据集含有
10种类别,共70000个灰度图像
。包含60000个训练集样本, 和10000个测试集样本。 - 每张图像以
28x28像素
的分辨率提供。
二、数据大小
由于每张图像都是单通道灰度图,因此每个像素可以用一个字节(0-255)来表示。所以,计算单个图像的大小:28像素 x 28像素 = 784字节。
算出整个数据集的大小:
- 训练集大小:60,000张 x 784字节 = 47,040,000字节(约44.77MB)
- 测试集大小:10,000张 x 784字节 = 7,840,000字节(约7.48MB)
总的来说,FashionMNIST数据集大约占用52.25MB的磁盘空间。这个值实际上可能略有不同,取决于存储文件的格式(例如压缩,pytorch下载的压缩包仅26M)。
三、十分类
FashionMNIST包含以下10个类别:
- T-shirt/top (T恤/上衣)
- Trouser (裤子)
- Pullover (套头衫)
- Dress (连衣裙)
- Coat (外套)
- Sandal (凉鞋)
- Shirt (衬衫)
- Sneaker (运动鞋)
- Bag (包包)
- Ankle boot (踝靴)
四、获取方式
Python中最简单的获取方式是通过使用torchvision
或tensorflow.keras.datasets
等库,如下是用torchvision
获取FashionMNIST的一个示例代码:
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
# 加载FashionMNIST数据集
train_data = FashionMNIST(root='./dataset/', train=True, transform=transforms.ToTensor(), download=True)
test_data = FashionMNIST(root='./dataset/', train=False, transform=transforms.ToTensor(), download=True)
- 注意: 这里设置的是默认会从"./dataset/"目录加载FashionMNIST数据集,如果没有则会自动下载。第一次运行时download要设成True,如果已下载过,后续可以设置成False。
五、观察数据
# 显示一下第一张图片
plt.imshow(train_data[0][0].squeeze(),cmap=plt.cm.binary)
# 显示一下前100张图片
plt.figure(figsize=(10,10))
for i in range(10*10):
## 在当前图下生成子图 5*5个图
plt.subplot(10,10,i+1)
plt.xticks([])
plt.yticks([])
plt.imshow(train_data[i][0].squeeze(), cmap=plt.cm.binary)
了解完数据后,我们开始盘他。