2-3 图像分类数据集

MNIST数据集是图像分类任务中广泛使用的数据集之一,但作为基准数据集过于简单,我们将使用类似但更复杂的Fashion-MNIST数据集。

%matplotlib inline
import torch
import torchvision  # pytorch模型关于计算机视觉模型实现的一个库
from torch.utils import data  # 方便读取数据
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display() # 用svg格式来显示我们的图片

读取数据集

通过框架中的内置函数将Fashion-MINST数据集下载并读取到内存中

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
    # train=True 表示我们下载的是训练数据集
    # transform=trans 表示我们拿到数据之后,得到的是pytorch的tensor,而不是一堆图片
    # download=True 表示我们默认从网上下载(如果不想在网上下,可以先在网上下好,存在"../data"目录下,然后不指定download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)
    # 测试数据集不参与训练,用来验证我们训练的好坏
len(mnist_train), len(mnist_test)

在这里插入图片描述
mnist_train有6万张图片,mnist_test有1万张图片。
Fashion-MNIST由 10 10 10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。

mnist_train[0][0].shape # 第一维表示第0个example,第二维表示图片的标号
# 展示第一张图片的形状

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为 1 1 1。 为了简洁起见,本书将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w ( h , w ) (h,w) (h,w)

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。 以下函数用于在数字标签索引及其文本名称之间进行转换

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

以下是训练数据集中前几个样本的图像及其相应的标签。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
# 我们在构建了一个pytorch数据集之后,可以放进一个dataloader里面,然后指定一个batch_size,就可以拿到一个大小为固定数字一个批量的数据。
# next 就是拿到第一个小批量
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
# 画两行,每一行有9张图片,最后每张图片的标号等于从刚才定义的第一个函数里面取得,y是一个数值的标号,我们通过刚才那个函数获得字符串

在这里插入图片描述


读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
# shuffle=True 随机打乱顺序
#  num_workers=get_dataloader_workers()) 需要多少进程,这里是4个进程

我们看一下读取训练数据所需的时间。

timer = d2l.Timer()
for X, y in train_iter:  # 一个一个访问所有的batch
    continue
f'{timer.stop():.2f} sec'

在这里插入图片描述
可知,我们读一次完整数据的时间是1.36秒。
一般读取数据的速度要比训练数据的速度要快


整合所有组件

现在我们定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
    # 返回两个dataloader,一个train,一个test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/775301.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

windows启动Docker闪退Docker desktop stopped

Windows启动Docker闪退-Docker desktop stopped 电脑上很早就安装有Docker了,但是有一段时间都没有启动了,今天想启动启动不起来了,打开没几秒就闪退,记录一下解决方案。仅供参考 首先,参照其他解决方案,本…

对SRS媒体服务器进行漏洞扫描时,SRS的API模块会出现漏洞,如何修补这些漏洞的简单方法

目录 一、引言 1、srs介绍 2、媒体流介绍 3、应用场景 二、SRS的http_api介绍、及漏洞 1、概述 2、http_api模块的作用 (1)提供HTTP API服务 (2)管理和监控SRS服务器 (3)自定义开发 三、漏洞扫描…

昆虫学(书籍学习资料)

包括昆虫分类(上下册)、昆虫生态大图鉴等书籍资料。

搜索+动态规划

刷题刷题刷题刷题 ​​​​​​​​​​​​​​Forgery - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路: 需要两个数组,一个数组全部初始化为".",另一个数组输入数据,每碰到一个“.”就进行染色操作,将其周围的…

Django学习第五天

启动项目命令 python manage.py runserver 图像验证码生成随机字母或者数字 import random from PIL import Image, ImageDraw, ImageFont, ImageFilterdef check_code(width120, height40, char_length5, font_fileZixunHappyBold.ttf, font_size28):code []img Image.new…

在C#/Net中使用Mqtt

net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…

【ARMv8/v9 GIC 系列 5.1 -- GIC GICD_CTRL Enable 1 of N Wakeup Function】

请阅读【ARM GICv3/v4 实战学习 】 文章目录 GIC Enable 1 of N Wakeup Function基本原理工作机制配置方式应用场景小结GIC Enable 1 of N Wakeup Function 在ARM GICv3(Generic Interrupt Controller第三代)规范中,引入了一个名为"Enable 1 of N Wakeup"的功能。…

图像的对数变换

对数变换在图像处理中通常有以下作用: 因为对数曲线在像素值较低的区域斜率较大,像素值较高的区域斜率比较低,所以图像经过对数变换之后,在较暗的区域对比度将得到提升,因而能增强图像暗部的细节。图像的傅里叶频谱其…

Ubuntu多显示器设置不同缩放比例

Ubuntu多显示器设置不同缩放比例 设备问题解决方案 设备 笔记本屏幕分辨率为2560 \times 1600,外接显示器的分辨率为3840 \times 2160。 问题 Ubuntu默认的显示器设置中,缩放仅能选择100%,200%,300%,400%。假…

C++中的引用——引用做函数参数

作用:函数传参时,可以利用引用的技术让形参修饰实参 优点:可以简化指针修改实参 示例: 1.值传递 运行结果: 2.地址传递 运行结果: 3.引用传递 运行结果:

ES6模块化学习

1. 回顾:node.js 中如何实现模块化 node.js 遵循了 CommonJS 的模块化规范。其中: 导入其它模块使用 require() 方法 模块对外共享成员使用 module.exports 对象 模块化的好处: 大家都遵守同样的模块化规范写代码&#xff…

一对一服务,定制化小程序:NetFarmer助力企业精准触达用户

在当今这个日新月异的数字化时代,小程序以其独特的魅力和广泛的应用场景,正逐步成为企业出海战略中的璀璨明星。NetFarmer,作为业界领先的数字化出海服务商,不仅深谙HubSpot营销自动化的精髓,更在小程序领域展现了卓越…

【UE5.3】笔记8 添加碰撞,检测碰撞

添加碰撞 打开BP_Food,添加Box Collision组件,与unity类似: 调整Box Collision的大小到刚好包裹物体,通过调整缩放和盒体范围来控制大小,一般先调整缩放找个大概大小,然后调整盒体范围进行微调。 碰撞检测 添加好碰撞…

CTF常用sql注入(二)报错注入(普通以及双查询)

0x05 报错注入 适用于页面无正常回显,但是有报错,那么就可以使用报错注入 基础函数 floor() 向下取整函数 返回小于或等于传入参数的最大整数。换句话说,它将数字向下取整到最接近的整数值。 示例: floor(3.7) 返回 3 floor(-2…

Python脚本:将Word文档转换为Excel文件

引言 在文档处理中,我们经常需要将Word文档中的内容转换成其他格式,如Excel,以便更好地进行数据分析和报告。针对这一需求,我编写了一个Python脚本,能够批量处理指定目录下的Word文档,将其内容结构化并转换…

pandas,dataframe使用笔记

目录 新建一个dataframe不带列名带列名 dataframe添加一行内容查看dataframe某列的数据类型新建dataframe时设置了列名,则数据类型为object dataframe的保存保存为csv文件保存为excel文件 dataframe属于pandas 新建一个dataframe 不带列名 df pd.DataFrame() 带…

【C++】unordered系列容器的封装

你很自由 充满了无限可能 这是很棒的事 我衷心祈祷你可以相信自己 无悔地燃烧自己的人生 -- 东野圭吾 《解忧杂货店》 unordered系列的封装 1 unordered_map 和 unordered_set2 改造哈希桶2.1 模版参数2.2 加入迭代器 3 上层封装3.1 unordered_set3.2 unordered_map 4 面…

C++基础22 字符串与字符数组及其相关操作

这是《C算法宝典》C基础篇的第22节文章啦~ 如果你之前没有太多C基础,请点击👉C基础,如果你C语法基础已经炉火纯青,则可以进阶算法👉专栏:算法知识和数据结构👉专栏:数据结构啦 ​ 目…

c++重定向输出和输出(竞赛讲解)

1.命令行重定向 在命令行中指定输出文件 指令 .\重定向学习.exe > 1.txt 效果 命令行输入和输出 指令 .\重定向学习.exe < 2.txt > 1.txt 效果 代码 #include<bits/stdc++.h> using namespace std; int n; int main(){cin>>n;for(int i=0;i<n;i…

4、SSD主控

简述 主控是个片上系统&#xff0c;由硬件和固件组成一个功能完整的系统&#xff1b;上文所述的FTL就属于主控的固件范畴。主控闪存构成了整个SSD&#xff0c;在闪存确定的情况下&#xff0c;主控就反映了各家SSD的差异。实时上各家SSD的差异也主要反应在主控上&#xff0c;毕…