pytorch基础实践-数据与预处理

文章目录

  • 数据集
    • Fashion-MNIST 数据集
  • 数据预处理
    • 包的导入
    • 在Pytorch中进行 ETL
      • 利用torchvison包获取和处理数据集(E+T)
    • 访问数据集
      • 访问和查看 train_set 中的单个数据
      • 利用 DataLoader 成批访问数据

数据集

Fashion-MNIST 数据集

  • MNIST
    MNIST,Modified National Institute of Standards and Technology database,前面加了“modified ”是因为这个数据集已经是在原始的 NIST 数据集上修改过的版本。

    简单来说 MNIST 就是一个包含了 0-9 十个数字(十个类别)的手写图片数据集,都是灰度图片,每张图片 28x28 像素,每个类别 7000 张图片,一共 70000 张。并且划分了 60000 张图片作为训练集,10000 张图片作为测试集。
    在这里插入图片描述
    MNIST 在图像分类领域非常流行,主要有两个原因:一是这个数据集特别简单,适合新手上手;二是学术圈为了比较各自的算法优劣,会在相同的数据集上训练算法,就是 MNIST。

    MNIST 也有它的问题就是太简单了(图像分类领域的“hello world”),所以有一帮人就开发了 Fashion-MNIST 想要来取代 MNIST。

  • Fashion-MNIST
    Fashion-MNIST 是一个德国的时装公司 Zalando 下面的研究院 Zalando Research 开发的,它用10类服装的图片取代了十类手写数字图片。十个类别分别是:
    在这里插入图片描述
    Fashion-MNIST 的设计理念就是作为 MNIST 的直接取代(direct dropin replacement),就是说以前使用 MNIST 的模型,除了数据集的链接(URL),其他什么都不用改。但替换之后的图像分类有了更高的难度。所以 Fashion-MNIST 和 MNIST 一样,都是灰度图片,28x28 像素,每类 7000 张,一共 70000 张,其中训练集 60000 张,测试集 10000 张。数据集链接
    Fashion-MNIST 是直接从 Zalando 网站上的商品图片提取出来制作的,包括以下7步:① 转换为PNG;② 裁剪;③ 长边缩放为28像素;④ 锐化;⑤ 补足空白;⑥ 取负片;⑦ 取灰度。在这里插入图片描述

数据预处理

通过 PyTorch 的 torchvision 包获取 Fashion-MNIST 数据集。

一般而言,对一个数据集的预处理流程为 ETL,即包含 extract、transform、load 三个步骤。

1.Extract - Get the Fashion-MNIST image data from the data source.Transform - 2.Transform image data into a desirable PyTorch tensor format.
3.Load - Put data into a suitable structure to make it easily accessible.

完成 ETL 流程之后,就可以开始构建和训练深度学习模型。

包的导入

需要把所有需要的PyTorch包导入:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

对各个包的描述如下:

  • torch - The top-level PyTorch package and tensor library.
  • torch.nn - A subpackage that contains modules and extensible classes for building neural networks.
  • torch.optim - A subpackage that contains standard optimization operations like SGD and Adam.
  • torch.nn.functional - A functional interface that contains typical operations used for building neural networks like loss functions and convolutions.
  • torchvision - A package that provides access to popular datasets, model architectures, and image transformations for computer vision.
  • torchvision.transforms - An interface that contains common transforms for image processing.

在Pytorch中进行 ETL

对于 ETL 流程,PyTorch 提供了两个类(class):
在这里插入图片描述
使用 PyTorch 创建自定义的数据集,我们通过创建子类并继承 Dataset 中的函数,来实现 Dataset 的扩展,然后就可以传递给 DataLoader 对象。
在这里插入图片描述
len() 和 getitm() 是其中两个必要的函数,前者的功能是计算数据集的长度,后者的功能是在数据集中按指定的索引编号将数据取出。

利用torchvison包获取和处理数据集(E+T)

利用 torchvision 获取并创建 Fashion-MNIST 数据集的一份实例(instance),这个过程中同时完成的数据集的获取(E)和转化(T),代码如下:

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

在这里插入图片描述
参数解释如下:
在这里插入图片描述
因为希望将图片数据集转换为张量,所以在 transform 中使用了 transforms.ToTensor();将此数据集命名为train_set,是因为我们希望将其作为训练数据;另外数据集仅会被下载一次,程序下载之前会检查本地有没有。

之后将获取的 train_set 打包给 DataLoader,使得数据集可以通过 DataLoader 方便的访问和加载(L):

train_loader = torch.utils.data.DataLoader(train_set)

至此已经完成了数据集的 Extract(利用url从网页下载)和 Transform(上面的transforms.ToTensor()),并且已经打包给了 DataLoader,可以通过 DataLoader 来实现 Load,比如设置 batch_size 和 shuffle:

train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
)

访问数据集

首先可以查看数据集中有多少个图片,使用 Python 的 len() 函数:

 >len(train_set)
 60000

查看所有图片的标签,只需要访问 train_set.targets 属性:

> train_set.targets
tensor([9, 0, 0, ..., 3, 0, 5])

如果希望查看数据集中每一个类别有多少个标签(即多少个图片,适用于图片全部有标记的情况),可以用 PyTorch 的 bincount() 函数:

>train_set.targets.bincount()
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])

Fashion-MNIST 数据集中每一类都有 6000 个图片和标签对,这种每一类的样本数量相等的数据集称作 balanced dataset,反之类别之间样本数量不一致的数据集称为 unbalanced dataset。

访问和查看 train_set 中的单个数据

一次只查看单张图片,首先将 train_set 这个对象传递给 Python 内建的 iter() 函数,它会返回一个可以在其上迭代的代表数据流(stream of data)的对象,使我们可以沿数据流访问数据。

接下来再使用 Python 的内建函数 next() 来获取数据流中的下一个数据元素,如此就可以获取数据集中的一个单独数据(因此下面命名变量都是单数形式):

> sample = next(iter(train_set))

> len(sample)
2

> type(sample)
tuple

获取的一个单独数据长度为 2,这是因为数据集是由图片-标签对的形式组成的,每一个 data element 中都包含两个东西,一个是存储图片数据的张量,另一个是其对应的标签。

sample 的数据类型是 tuple,tuple 是Python中的一种 sequence types,是一个可以迭代的顺序不可变的数据序列。
可以用 sequence unpacking 来将其中的图像和标签分别提取出来:

> image, label = sample

和下面这种写法是等效的:

> image = sample[0]
> label = sample[1]

查看数据类型和shape:

> type(image)
torch.Tensor

> type(label)
int

> image.shape
torch.Size([1, 28, 28]) 

> torch.tensor(label).shape
torch.Size([])

Fashion-MNIST 数据集是单通道的灰度图,所一张图片的 tensor shape 就是 1x28x28。把没有用的颜色通道 squeeze 掉:

> image.squeeze().shape
torch.Size([28, 28])

显示出图片和标签:

> plt.imshow(image.squeeze(), cmap="gray")
> torch.tensor(label)
tensor(9)

在这里插入图片描述
标签是“9”,代表靴子,与图片是相符的。

利用 DataLoader 成批访问数据

> batch = next(iter(train_loader))

> len(batch)
2

> type(batch)
list

list 也是一种 Python sequence types,与 tuple 的不同在于 list 是可变序列。

一次访问 10 张图片,则需要给 DataLoader 指定 batch_size:

> display_loader = torch.utils.data.DataLoader(
    train_set, batch_size=10
)

关于 DataLoader 中的“shuffle=True”:如果“shuffle=True”,则每次调用 next() 返回的 batch 都会不同,训练集中的第一组样本将在第一次调用 next() 时返回,这个功能默认是 False。

可以像上面一样对 display_loader 使用 iter() 和 next() 来每次查看 10 张图片:

> batch = next(iter(display_loader))
> print('len:', len(batch))
len: 2

进行 sequence unpacking:

> images, labels = batch

> print('types:', type(images), type(labels))
> print('shapes:', images.shape, labels.shape)
types: <class 'torch.Tensor'> <class 'torch.Tensor'>
shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])

此时返回的图像张量是 [10, 1, 28, 28] 的四阶张量,标签是一个长度为 10 的一阶张量。可以单独查看其中每一个图片和标签:

> images[0].shape
torch.Size([1, 28, 28])

> labels[0]
tensor(9)

一次绘制一批图像,可以使用 torchvision.utils.make_grid() 函数创建一个可以按网格绘制图片的 grid:

> grid = torchvision.utils.make_grid(images, nrow=10)    # nrow指定每行多少列图片

> plt.figure(figsize=(15,15))        # 缩放图像显示大小?
> plt.imshow(grid.permute(1,2,0))    # 这一步让grid符合imshow的要求,不清楚细节

> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])\

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/83469.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

政务中心站至政务中心东站右线盾构本月始发

本报记者 赵鹏 实习记者 池阳 通讯员 董浩程 立秋已过&#xff0c;平谷线“瓜熟蒂落”的日子指日可待。在左线隧道刚刚顺利贯通后&#xff0c;平谷线政务中心站至政务中心东站区间右线隧道已展开盾构组装施工&#xff0c;右线盾构即将于本月内始发&#xff0c;被誉为“地下蛟龙…

如何查看Linux内核版本

如何查看Linux内核版本 uname -r用centos7.0&#xff0c;内核版本就是3.10

Cesium加载ArcGIS Server4490且orgin -400 400的切片服务

Cesium在使用加载Cesium.ArcGisMapServerImageryProvider加载切片服务时&#xff0c;默认只支持wgs84的4326坐标系&#xff0c;不支持CGCS2000的4490坐标系。 如果是ArcGIS发布的4490坐标系的切片服务&#xff0c;如果原点在orgin X: -180.0Y: 90.0的情况下&#xff0c;我们可…

设计模式之组合模式(Composite)的C++实现

1、组合模式的提出 在软件开发过程中&#xff0c;使用者Client过多依赖所操作对象内部的实现结构&#xff0c;如果对象内部的实现结构频繁发生变化&#xff0c;则使用者的代码结构将要频繁地修改&#xff0c;不利于代码地维护和扩展性&#xff1b;组合模式可以解决此类问题。组…

云养猪平台如何开发

随着数字化和智能化的发展&#xff0c;农业行业也逐渐开始融入互联网技术&#xff0c;其中云养猪平台作为新兴的农业数字化解决方案之一&#xff0c;备受关注。本文将探讨如何开发一款具备专业、思考深度和逻辑性的云养猪平台。 一、前期准备阶段&#xff1a; 1.明确目…

SecureCRT 备份Button Bar中所有Button

一、前言 Button Bar功能可以保存一些常用命令避免重复输入&#xff0c;但是有时候secureCRT的button bar经常莫名其妙消失&#xff0c;重装系统或软件后&#xff0c;也都需要重新一个个添加Button&#xff0c;如果能备份就能减少这些费时间的操作 二、备份步骤 在面板Optio…

数字孪生助力智慧水务:科技创新赋能水资源保护

智慧水务中&#xff0c;数字孪生有着深远的作用&#xff0c;正引领着水资源管理和环境保护的创新变革。随着城市化和工业化的不断推进&#xff0c;水资源的可持续利用和管理愈发显得重要&#xff0c;而数字孪生技术为解决这一挑战提供了独特的解决方案。 数字孪生技术&#xf…

typescript 声明文件

作用 1、为已存在js库提供类型信息&#xff0c;这样在ts项目中使用这些库时候&#xff0c;就像用ts一样&#xff0c;会有代码提示、类型保护等机制 2、项目内共享类型&#xff1a;如果多个.ts文件中都用到同一个类型&#xff0c;此时可以创建.d.ts文件提供该类型&#xff0c;…

常见的软件测试用例设计方法有哪些?

常见的软件测试用例设计方法&#xff0c;个人认为主要是下面这6种&#xff1a; 1)流程图法&#xff08;也叫场景法&#xff09; 2)等价类划分法 3)边界值分析 4)判定表 5)正交法 6)错误推测法 这6种常见方法中&#xff0c;我分别按照定义、应用场景、使用步骤、案例讲解这4个部…

# 59. python的类与对象-更新

[目录] 文章目录 59. python的类与对象-更新1.面向对象编程2.什么是类3.什么是对象4.如何描述对象5.对象的属性和方法6.Python中的类7.type()函数查看数据类型8.类在Python中的应用9.总结 【正文】 59. python的类与对象-更新 1.面向对象编程 本节内容特别抽象&#xff0c;初…

动手学深度学习-pytorch版本(二):线性神经网络

参考引用 动手学深度学习 1. 线性神经网络 神经网络的整个训练过程&#xff0c;包括: 定义简单的神经网络架构、数据处理、指定损失函数和如何训练模型。经典统计学习技术中的线性回归和 softmax 回归可以视为线性神经网络 1.1 线性回归 回归 (regression) 是能为一个或多个…

[软件工具]精灵标注助手目标检测数据集格式转VOC或者yolo

有时候我们拿到一个数据集发现是xml文件格式如下&#xff1a; <?xml version"1.0" ?> <doc><path>C:\Users\Administrator\Desktop\test\000000000074.jpg</path><outputs><object><item><name>dog</name>…

SSL证书如何使用?SSL保障通信安全

由于SSL技术已建立到所有主要的浏览器和WEB服务器程序中&#xff0c;因此&#xff0c;仅需安装数字证书或服务器证书就可以激活功能了。SSL证书主要是服务于HTTPS&#xff0c;部署证书后&#xff0c;网站链接就由HTTP开头变为HTTPS。 SSL安全证书主要用于发送安全电子邮件、访…

Numpy入门(4)— 保存和导入文件

NumPy保存和导入文件 4.1 文件读写 NumPy可以方便的进行文件读写&#xff0c;如下面这种格式的文本文件&#xff1a; # 使用np.fromfile从文本文件a.data读入数据 # 这里要设置参数sep &#xff0c;表示使用空白字符来分隔数据 # 空格或者回车都属于空白字符&#xff0c;读…

【仿写tomcat】五、响应静态资源(访问html页面)、路由支持以及多线程改进

访问html页面 如果我们想访问html页面其实就是将本地的html文件以流的方式响应给前端即可&#xff0c;下面我们对HttpResponseServlet这个类做一些改造 package com.tomcatServer.domain;import com.tomcatServer.utils.ScanUtil;import java.io.IOException; import java.io…

MySQL的Json类型字段IN查询分组和优化方法

前言 MySQL从5.7的版本开始支持Json后&#xff0c;我时常在设计表格时习惯性地添加一个Json类型字段&#xff0c;用做列的冗余。毕竟Json的非结构性&#xff0c;存储数据更灵活&#xff0c;比如接口请求记录用于存储请求参数&#xff0c;因为每个接口入参不一致&#xff0c;也…

【TypeScript】基础类型

安装 Node.js 环境 https://nodejs.org/en 终端中可以查到版本号即安装成功。 然后&#xff0c;终端执行npm i typescript -g安装 TypeScript 。 查到版本号即安装成功。 字符串类型 let str:string "Hello"; console.log(str);终端中先执行tsc --init&#xf…

第二届人工智能与智能信息处理技术国际学术会议(AIIIP 2023)

第二届人工智能与智能信息处理技术国际学术会议&#xff08;AIIIP 2023&#xff09; 2023 2nd International Conference on Artificial Intelligence and Intelligent Information Processing 第二届人工智能与智能信息处理技术国际学术会议&#xff08;AIIIP 2023&#xf…

ATTCK实战系列——红队实战(一)

目录 搭建环境问题 靶场环境 web 渗透 登录 phpmyadmin 应用 探测版本 写日志获得 webshell 写入哥斯拉 webshell 上线到 msf 内网信息收集 主机发现 流量转发 端口扫描 开启 socks 代理 服务探测 getshell 内网主机 浏览器配置 socks 代理 21 ftp 6002/700…

CentOS 8.5修改安装包镜像源

1 备份原配置 cd /etc/yum.repos.d mkdir backup mv *.repo backup/2 下载镜像源 2.1 使用wget下载 wget http://mirrors.aliyun.com/repo/Centos-8.repo2.2 使用curl下载 我是安装的最小版本的系统&#xff0c;默认只有curl curl使用方法&#xff1a;https://www.ruanyife…