动手学深度学习——图像分类数据集(代码详解)

目录

  • 1. 图像分类数据集
    • 1.1 读取数据集
    • 1.2 读取小批量
    • 1.3 整合所有组件
    • 1.4 小结

1. 图像分类数据集

这里采用Fashion-MNIST数据集

  • torchvision:torch类型的可视化包,一般计算机视觉和数据可视化需要使用
  • from torchvision import transforms:该组件经常用于图片的修改(一般数据集中的图片都是PIL格式,使用的时候需要转化为tenser,而在加入函数时常需要转化为nadarry(numpy中的ndarray为多维数组))
  • d2l.use_svg_display():使用什么模式展示图片
%matplotlib inline
import torch
import torchvision #pytorch用于计算机视觉的一个库
from torch.utils import data
from torchvision import transforms #导入对数据操作的模具
from d2l import torch as d2l

d2l.use_svg_display() #使用svg展示图片

1.1 读取数据集

通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中

  • torchvision.datasets:一般用于图像数据集的下载和获取
    eg:
  • torchvision.datasets.FashionMNIST( root=, train=True, transform=, download=True)
    • train:是否为训练集
    • transform:使用什么格式转换(可以从transforms组件中选择)
    • dowload:是否下载对应数据集
    • .FashionMNIST可以更换为其他数据源
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor() #对图片进行预处理,转换为tensor格式

# 下载训练集和测试集,并保存
mnist_train = torchvision.datasets.FashionMNIST(
	root="../data", train=True, transform=trans,download=True)
mnist_train = torchvision.datasets.FashionMNIST(
	root="../data", train=False, transform=trans,download=True)

Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。

# 输出训练集和测试集的大小
len(mnist_train), len(mnist_test)

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1(彩色图像通道数为3)。

# 索引到第一张图片
mnist_train[0][0].shape # 输入图像的通道数、高度和宽度

在这里插入图片描述
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换。

# 获取数据集的标签
def get_fashion_mnist_labels(labels): #@save
	"""返回Fashion-MNIST数据集的文本标签"""
	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_lables[int(i)] for i in labels]

创建一个函数来可视化这些样本。

  • plt.subplots()是一个返回包含图形和轴对象的元组的函数。因此,在使用时fig, ax = plt.subplots(),将此元组解压缩到变量fig和ax。
  • enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中,生成可以遍历的每个元素有对应序号(0, 1, 2, 3…)的enumerate对象。
  • zip()函数用于将多个可迭代对象作为参数,依次将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的对象,里面的每个元素大概为i,(ax,img)的形式。
  • imshow()可以接收二维,三维甚至多维数组。二维默认为一通道即灰度图像,三维需要在第三个维度指定图像通道数(必须是第三维)
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
	"""绘制图像列表"""
	figsize = (num_cols * scale, num_rows * scale)
	
	# 第1个参数是个图,一般不用;第2个axer类似于图片的索引矩阵(行,列)
	_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # axes:轴
	axes = axes.flatten()

	# 遍历生成形如i, (ax, img)形式的enumerate对象
	for i, (ax, img) in enumerate(zip(axes, imgs)):
		if torch.is_tensor(img):
			# 图片张量
			ax.imshow(img.numpy())
			
		else:
			# PIL图片
			ax.imshow(img)
		ax.axes.get_xaxis().set_visible(False) #x轴隐藏
		ax.axes.get_yaxis().set_visible(False) #y轴隐藏
		if titles:
			ax.set_title(title[i]) #显示标题
	return axes

以下是训练数据集中前几个样本的图像及其相应的标签。

  • next() 返回迭代器的下一个项目。
  • next() 函数要和生成迭代器的iter() 函数一起使用。
  • 我们可以通过iter()函数获取这些可迭代对象的迭代器。然后,我们可以对获取到的迭代器不断使⽤next()函数来获取下⼀条数据。
    注:当我们已经迭代完最后⼀个数据之后,再次调⽤next()函数会抛出 StopIteration的异常 ,来告诉我们所有数据都已迭代完成,不⽤再执⾏ next()函数了。
# 使用next()函数获取批量大小为18的训练集的图像和标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))

#显示18张图片,宽度为28,长度为28,总共为2行9列
# 绘制两行图片,每一行有9张图片,并获取标签
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y)); 

在这里插入图片描述

1.2 读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers(): #@save
	"""使用4个进程来读取数据"""
	return 4

# 训练集需要设置shuffle=True打乱顺序	
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
							 num_workers=get_dataloader_workers())

我们看一下读取训练数据所需的时间。

timer = d2l.Timer() #调用Timer函数,测试速度
for X, y in train_iter:
	continue
f'{timer.stop():.2f} sec' #输出读取数据所用的秒数,精度为2位小数

在这里插入图片描述

1.3 整合所有组件

定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

  • torchvision.transforms是pytorch中的图像预处理包,一般用Compose把多个步骤整合到一起。
  • insert函数是一种用于列表的内置函数。这个函数的作用是在一个列表中的指定位置,插入一个元素。
transforms中的函数功能
Resize把给定的图片resize到given size
Normalize用均值和标准差归一化张量图像
def load_data_fashion_mnist(batch_size, resize=None):  #@save
	"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
	# 转换为tensor
	trans = [transforms.ToTensor()]

	
	if resize:
		trans.insert(0, transforms.Resize(resize))
	# compose整合步骤
	trans = transforms.Compose(trans)

	# 下载训练集和测试集,将小批量样本返回到train_iter中,用于之后的训练
	mnist_train = torchvision.datasets.FashionMNIST(
		root="../data", train=True, transform=trans, download=True)
	mnist_test = torchvision.datasets.FashionMNIST(
		root="../data", train=False, transform=trans, download=True)
	return (data.DataLoader(mnist_train, batch_size, shuffle=True,
							num_workers=get_dataloader_workers()),
			data.DataLoader(mnist_test, batch_size, shuffle=False,
							num_workers=get_dataloader_workers()))

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
	print(X.shape, X.dtype, y.shape, y.dtype)
	break

在这里插入图片描述

1.4 小结

  • Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。
  • 我们将高度h像素,宽度w像素图像的形状记为h×w或(h,w)。
  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/40427.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于单片机的智能台灯 灯光控制系统人体感应楼梯灯系统的设计与实现

功能介绍 以STM32单片机作为主控系统;主通过光敏采集当前光线强度;通过PMW灯光调节电路,我们可以根据不同的光线亮度,进行3挡调节;通过人体红外检测当前是否有人;通过不同光线情况下使用PWM脉冲电路进行调节…

13matlab数据分析多项式的求值(matlab程序)

1.简述 统计分析常用函数 求最大值 max 和 sum 积 prod 平均值:mean 累加和:cumsum 标准差:std 方差:var 相关系数:corrcoef 排序:sort 四则运算 1.多项式的加减运算就是所对应的系数向量的加减运算&#…

使用jQuery的ajax提交图片信息

1 设置图片id(html) 首先,定义上传图片的id,根据上传文件的id获取图片信息: 注:图片的id应该设置在input标签里面 2 发送ajax请求(js) var formData new FormData(); formData.ap…

day40-Mybatis(resultMap拓展)

0目录 Mybatis-resultMap拓展 1.2.3 1.数据库字段和javabean实体类属性不一致时 解决方案1:将sql语句中给予别名(别名同javabean中实体类保持一致) 解决方案2:使用resultMap 2.两表关联(用户表和角色表关联查询&…

【算法基础:数据结构】2.2 字典树/前缀树 Trie

文章目录 知识点cpp结构体模板 模板例题835. Trie字符串统计❤️❤️❤️❤️❤️(重要!模板!)143. 最大异或对😭😭😭😭😭(Trie树的应用) 相关题目…

【UE4 塔防游戏系列】06-炮塔发射子弹攻击敌人

效果 步骤 1. 新建一个Actor蓝图类,命名为“TotalBulletsCategory”,用来表示子弹蓝图总类,后面会有很多不同类型的子弹会继承该类 打开“TotalBulletsCategory”,添加粒子系统组件、盒体碰撞组件和发射物移动组件 调整发射物重力…

电压放大器在超声波焊接中的作用以及应用

电压放大器是一种运用于电子设备中的信号放大器,主要作用是将小信号放大为更高幅度的信号。在超声波焊接中,电压放大器起到了重要的作用,它可以将从传感器采集到的微小信号放大为能够被检测和处理的合适大小的信号。 超声波焊接是现代工业生产…

使用shell监控应用运行状态通过企业微信接收监控通知

目的:编写shell脚本来监控应用服务运行状态,若是应用异常则自动重启应用通过企业微信接收监控告警通知 知识要点: 使用shell脚本监控应用服务使用shell脚本自动恢复异常服务通过企业微信通知接收监控结果shell脚本使用数组知识,…

二次元少女-InsCode Stable Diffusion 美图活动一期

一、 Stable Diffusion 模型在线使用地址: https://inscode.csdn.net/inscode/Stable-Diffusion 二、模型相关版本和参数配置: 模型版本:chilloutmix_NiPrunedFp32Fix.safetensors 采样方法(Sampler)Sampling method:DPM SDE …

xpath下载安装——Python爬虫xpath插件下载安装(2023.7亲测可用!!)

目录 1.免费下载插件链接(若失效评论区留言发送最新链接)(2023.7亲测可用) 2.安装插件 (1)打开chrome浏览器页面,点击:右上角三个点 > 扩展程序 > 管理拓展程序 &#xff…

2023-7-19-第二十式迭代器模式

🍿*★,*:.☆( ̄▽ ̄)/$:*.★* 🍿 💥💥💥欢迎来到🤞汤姆🤞的csdn博文💥💥💥 💟💟喜欢的朋友可以关注一下&#xf…

TortoiseGit 入门指南12:创建标签

前面的文章不止一次的提到过 标签 (Tag),我们在《TortoiseGit 入门指南08:浏览引用以及在引用间切换》一文中知道,标签 是一种 引用;还知道每个提交都对应着一个 SHA-1 值,而引用就是 SHA-1 的一…

SuperGlue学习记录之最优传输

在进行最优传输相关理论的学习过程中,找到SuperGlue这篇论文,该篇论文通过最优传输来完成特征点的匹配过程。 SuperGlue结构 先来看一下其结构: 首先将两张图片送入特征提取网络,通过卷积网络提取出特征,主要有四个值…

Generative Adversarial Network

Goodfellow,2014年 文献阅读笔记--GAN--Generative Adversarial NetworkGAN的原始论文-组会讲解_gan英文论文_Flying Warrior的博客-CSDN博客 启发:如何看两个数据是否来自同一个分布? 在统计中,two sample test。训练一个二分类的分类器,如果能分开这两个数据,说明来自…

网络安全—信息安全—黑客技术(学习笔记)

一、什么是网络安全? 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域,都…

[深度学习入门]什么是神经网络?[神经网络的架构、工作、激活函数]

目录 一、前言二、神经网络的架构——以手写数字识别三、神经网络的工作1、单输入单输出感知器函数2、二维输入参数3、三维输入参数 四、激活函数1、激活函数2、ReLU激活函数3、非线性激活函数(1)二输入二输出的神经网络的架构(2)…

计算机网络 day8 动态路由 - NAT - SNAT实验 - VMware的网卡的3种模式

目录 动态路由:IGP 和 EGP 参考网课:4.6.1 路由选择协议概述_哔哩哔哩_bilibili ​编辑 IGP(Interior Gateway Protocol)内部网关协议: EGP(Interior Gateway Protocol)外部网关协议&#x…

使用模板创建【vite+vue3+ts】项目出现 “找不到模块‘vue‘或其相应的类型声明” 的解决方案

问题描述 项目前台需要使用Vue3Ts来写一个H5应用,然后我用模板创建 npm create vitelatest vue3-vant-mobile -- --template vue-ts创建完后进入HelloWorld.vue,两眼一黑 解决办法一 npm i --save-dev types/node然后在tsconfig.json的"compi…

软件测试银行项目面试过程

今天参加了一场比较正式的面试,汇丰银行的视频面试。在这里把面试的流程记录一下,结果还不确定,但是面试也是自我学习和成长的过程,所以记录下来大家也可以互相探讨一下。 请你做一下自我介绍?(汇丰要求英…

Kind | Kubernetes in Docker 把k8s装进docker!

有点像杰克船长的黑珍珠 目录 零、说明 一、安装 安装 Docker 安装 kubectl 安装 kind 二、创建/切换/删除集群 创建 切换 删除 将镜像加载到 kind 群集中 零、说明 官网:kind Kind: Kubernetes in Docker 的简称。kind 是一个使用 Docker 容…