MNIST内置手写数字数据集的实现

torchvision库

torchivision库是PyTorch中用来处理图像和视频的一个辅助库,接下来我们就会使用torchvision库加载内置的数据集进行分类模型的演示

为了统一数据加载和处理代码,PyTorch提供了两个类用于处理数据加载,他们分别是torch.utils.data.Dataset类和torch.utils.data.DataLoader类,通过这两个类可使数据集加载和预处理代码与模型训练代码脱钩,从而获得更好的代码模块化和代码可读性。torchvision加载的内置图片数据集均继承自torch.utils.data.Dataset类,因此可直接使用加载的内置数据集创建DataLoader.

加载内置图片数据集

PyTorch的内置图片数据集均在torchvision.datasets模块下,包含Caltech、CelebA、CIFAR、Cityscapes、COCO、Fashion-MNIST、ImageNet、MNIST等很多著名的数据集,其中MNIAT数据集是手写数字数据集,这是一个很适合入门者学习使用的小型计算机视觉数据集,它包含0到9的手写数字图片和每一张图片对应的标签。接下来我们就以此数据集为例子进行学习。

import torchvision  # 导入torchvision库
from torchvision.transforms import ToTensor  #做好准备工作,导入所需要的包
import torch
import matplotlib.pyplot as plt
import numpy as np

首先就是对我们所需要的库进行导入。

我对上述的代码进行一下解读,首先导入了torchvision库,从torchvision.transforms模块下导入ToTensor类。torchvision.transforms模块包含了转换函数,使用它可以很方便的对加载的图形进行各种变换,这里用到的ToTensor类,该类的主要作用有以下3点。

  1. 将输入转换为张量
  2. 将读取图片的格式规范为(channel,height,width),这里和我们经常遇到的图片格式有可能会有一些去呗,PyTorch中的图片格式一般是通道数(channel)在前,然后是高度(height)和宽度(width)
  3. 将图片像素的取值范围归一化,规范为0到1的范围内
train_ds=torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds=torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)

通过torchvision.datasets.MNIST方法加载MNIST数据集,方法中的第一个参数为data/表示下载数据集存放的位置,参数train表示是否是训练数据,若为True,则加载训练数据集,若为False,则加载测试数据集;
使用参数transform表示对加载数据的预处理,参数值为ToTensor();
最后一个参数download=True表示将下载此数据集,一旦下载完成后,下一次执行此代码是,将优先从本地文件夹直接加载,如果咱们的计算机不能连接互联网,也可以直接将文件复制到data文件夹中,这样就能从本地直接加载数据了。

现在我们得到了两个数据集,分别是训练数据集和测试数据集,PyTorch还提供了torch.utils.data.DataLoader类用以对数据集做进一步的处理,DataLoader接收数据集,并执行复杂的操作,如小批次处理、多线程、随机打乱等,以便从数据集中获取数据。它接收来自用户的Dataset实例,并使用采样器策略将数据采样为小批次。DataLoader的目的如下

1.使用shuffle参数对数据集做乱序的操作,一般情况下,需要对训练数据集进行乱序的操作,因为原始的数据在样本均衡的情况下肯呢个是按照某种顺序进行排列的,经过顺序打乱之后,数据的排列就会拥有一定的随机性,这样做可以避免出现模型反复依次序学习数据的特征或者学习到的只是数据的次数特征的情况。

2.将数据采样为小批次,可用batch_size参数指定批次大小。首先单个样本训练有一个很大的缺点,就是损失和梯度会受到单个样本的影响,如果样本分布不均匀,或者有错误标注样本,则会引起梯度的巨大震荡,从而导致模型训练效果很差。为了解决这个问题,我们可以考虑使用批量数据训练(也叫做批量梯度下降算法),通过遍历全部数据集算一次损失函数,然后计算损失对各个参数的梯度,并更新参数。这种训练方式没更新一次,参数都要把数据集里所有样本都看一遍,不仅计算开销大,而且计算速度慢。为了克服上述方法的缺点,一般采用的是一种折中手段进行损失函数计算:即把数据分为若干个小的批次,按批次来更新参数,这样,一个批次中的一组数据共同决定了本次梯度的方向,大大降低了参数更新时的梯度方差,下降起来更加稳定,减少了随机性,与单样本训练相比,小批次训练可利用矩阵操作进行有效的梯度计算,计算量也不是很大,对计算机内存的要求也不高。

3.可以充分利用多个子进程加速数据预处理。num_workers参数可以指定子进程的数量

4.可通过collate_fn参数传递批次数据的处理函数,实现在DataLoader中对批次数据做转换处理

train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds,batch_size=46)

上面代码中分别创建了训练数据和测试数据的DataLoader,并设置他们的批次大小为64,对训练数据设置了shuffle为True;对测试数据,由于仅仅作为测试,没必要做乱序。

DataLoader是可迭代对象,我们观察它返回的数据集的类型,给大家对对DataLoader和MNIST数据集有一个直观的印象

imgs,labels=next(iter(train_dl))#创建生成器,并用next方法返回一个批次的数据
print(imgs.shape)
print(labels.shape)

我们使用iter方法将DataLoader对象创建为生成器,并使用next方法反悔了一个批次的图像(imgs)和对应的一个批次的标签(labels),image.shape为torch.Size([64,1,28,28]),这里的64是批次,我们可以认为这代表64张形状为(1,28,28)的图片,其中1为通道数,28和28分别为高和宽;既然这里有64张图片,那么就对应着有64个标签,也就是labels.shape所显示的torch.Size([64])

结果绘制

# 我们使用Matplotlib来绘制一下前10张的图片
plt.figure(figsize=(20,2))  # 创建一个(10,1)大小的画布
for i,img in enumerate(imgs[:20]):
    npimg=img.numpy()  # 将张量转换为ndarray
    npimg=np.squeeze(npimg)  # 图片形状由(1,28,28)转换为(28,28)
    plt.subplot(1,20,i+1)  # 初始化子图,3个参数表示1行10列的第i+1个子图
    plt.imshow(npimg)  #在子图中绘制单张图片
    plt.axis('off')  # 关闭显示子图坐标

plt.imshow() 是一个用于显示图像的函数,通常用于在 Python 中使用 Matplotlib 库绘制图像。它可以接受一个数组或图像数据,并将其显示为图像。这个函数通常用于可视化图像数据,比如热图、灰度图、彩色图等。plt.imshow() 可以接受一些参数,比如 cmap(颜色映射)、interpolation(插值方法)等,用来控制图像的显示效果。

接下来,我们打印对应的标签

print(labels[:20])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/252833.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

进程通信知识基础【Linux】——下篇

目录 前文 一,命名管道 创建命名管道 1. getline——c库 2. unlink——系统接口 实践代码 common.hpp client.cpp server.cpp Log.cpp 二,共享内存(system V接口) 1. 创建共享内存 shmget接口 2. 删除共享内存 常见…

PMP项目管理 - 相关方管理

系列文章目录 PMP项目管理 - 质量管理 PMP项目管理 - 采购管理 PMP项目管理 - 资源管理 PMP项目管理 - 风险管理 PMP项目管理 - 沟通管理 现在的一切都是为将来的梦想编织翅膀,让梦想在现实中展翅高飞。 Now everything is for the future of dream weaving wing…

【一种用opencv实现高斯曲线拟合的方法】

背景: 项目中需要实现数据的高斯拟合,进而提取数据中标准差,手头只有opencv库,经过资料查找验证,总结该方法。 基础知识: 1、opencv中solve可以实现对矩阵参数的求解; 2、线的拟合就是对多项…

【深度强化学习】确定性策略梯度算法 DDPG

前面讲到如 REINFORCE,Actor-Critic,TRPO,PPO 等算法,它们都是随机性策略梯度算法(Stochastic policy),在广泛的任务上表现良好,因为这类方法鼓励了算法探索,给出的策略是…

探索 Vim:一个强大的文本编辑器

引言: Vim(Vi IMproved)是一款备受推崇的文本编辑器,拥有强大的功能和高度可定制性,提供丰富的编辑和编程体验。本文将探讨 Vim 的基本概念、使用技巧以及为用户带来的独特优势。 简介和发展 1. Vim 的简介和历史 V…

【二分查找】自写二分函数的总结

作者推荐 【动态规划】【广度优先搜索】LeetCode:2617 网格图中最少访问的格子数 本文涉及的基础知识点 二分查找算法合集 自写二分函数 的封装 我暂时只发现两种: 一,在左闭右开的区间寻找最后一个符合条件的元素,我封装成FindEnd函数。…

力扣刷题-二叉树-平衡二叉树

110 平衡二叉树 给定一个二叉树,判断它是否是高度平衡的二叉树。 本题中,一棵高度平衡二叉树定义为:一个二叉树每个节点 的左右两个子树的高度差的绝对值不超过1。 示例 1: 给定二叉树 [3,9,20,null,null,15,7] 返回 true 。 给定二叉树 [1…

AUTOSAR ComM模块配置以及代码

ComM模块配置以及代码执行流程 1、基本的一个通道的配置列表 ComMNmVariant 概念的个人理解: FULL: 完全按照AUTOSAR NM方式进行调用 LIGHT :设置一个超时时间,在请求停止通信的时候开始计时,超时之后才会进入FULLCOM…

运维实践|采集MySQL数据出现many connection errors

文章目录 问题出现问题分析当前环境问题分析 解决方案1 检查调度事件任务是否开启2 开启调度事件任务3 创建一张日志表4 创建函数存储过程5 创建事件定时器6 开启事件调度任务7 检查核实是否创建 总结 问题出现 最近在做OGG结构化数据采集工作,在数据采集过程中&am…

将博客搬至微信公众号了

一、博客搬家通知 各位码友们好,大家是不是基本很少看 CSDN 了呢?平时开发是不都依靠着 chatGPT 来解决工作中的技术问题了,不过我觉得在工作中的使用场景各式各样的,具体问题还得自己具体去梳理逻辑,再考虑用什么样的…

C语言:求和1+1/2-1/3+1/4-1/5+……-1/99+1/100

#include<stdio.h> int main() {int i 0;double sum 0.0;int flag 1;for (i 1;i < 100;i){sum 1.0 / i * flag;flag -flag;}printf("sum%lf\n", sum);return 0; }

SpringIOC之@Primary

博主介绍&#xff1a;✌全网粉丝5W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

力扣刷题-二叉树-找树左下角的值

513 找树左下角的值 给定一个二叉树的 根节点 root&#xff0c;请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 示例 1&#xff1a; 示例 2&#xff1a; 思路 层序遍历 直接层序遍历&#xff0c;因为题目说了是最底层&#xff0c;最左边的值&a…

【数据结构—队列的实现】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 一、队列 1.1队列的概念及结构 二、队列的实现 2.1头文件的实现—Queue.h 2.2源文件的实现—Queue.c 2.3源文件的测试—test.c 三、测试队列实际数据的展示 3.…

mysql使用st_distance_sphere函数报错Incorrect arguments to st_distance_sphere

前言 最近使用空间点位查询数据时函数报错Incorrect arguments to st_distance_sphere报错。 发现问题 因为之前是没有问题的&#xff0c;所以把问题指向了数据&#xff0c;因为是外部数据&#xff0c;不是通过系统打点获取&#xff0c;发现是因为经纬度反了&#xff0c;loc…

VRRP(虚拟路由冗余协议)

一.VRRP简介 1.VRRP是什么 Virtual route Redundancy Protocol&#xff0c;也叫虚拟路由器冗余协议。 利用VRRP&#xff0c;一组路由器协同工作&#xff0c;单只有一个处于Master状态&#xff0c;处于该状态的路由器&#xff08;的接口&#xff09;承担实际的数据流量转发任…

MapReduce序列化实例代码

1 &#xff09;需求&#xff1a;统计每个学号该月的超市消费、食堂消费、总消费 2 &#xff09;输入数据格式 序号 学号 超市消费 食堂消费 18 202200153105 8.78 12 3 &#xff09;期望输出格式 key &#xff08;学号&#xff09; value &#xff08; bean 对象&#xf…

二分查找算法的概念、原理、效率以及使用C语言循环和数组的简单实现

二分查找的概念 二分查找也称折半查找&#xff08;Binary Search&#xff09;&#xff0c;它是一种效率较高的查找方法。但是&#xff0c;折半查找要求线性表必须采用顺序存储结构&#xff0c;而且表中元素按关键字有序排列。 实现原理 首先&#xff0c;假设表中元素是按升序…

深度学习项目实战:垃圾分类系统

简介&#xff1a; 今天开启深度学习另一板块。就是计算机视觉方向&#xff0c;这里主要讨论图像分类任务–垃圾分类系统。其实这个项目早在19年的时候&#xff0c;我就写好了一个版本了。之前使用的是python搭建深度学习网络&#xff0c;然后前后端交互的采用的是java spring …

VLAN间的通讯---三层交换

一.三层交换 1.概念 使用三层交换技术实现VLAN间通信 三层交换二层交换 三层转发 2.基于CEF的MLS CEF是一种基于拓补转发的模型 转发信息库&#xff08;FIB&#xff09;临接关系表 转发信息库&#xff08;FIB&#xff09;可以理解为路由表 邻接关系表可以理解为MAC地址表…