pytorch基础2-数据集与归一化

专题链接:https://blog.csdn.net/qq_33345365/category_12591348.html

本教程翻译自微软教程:https://learn.microsoft.com/en-us/training/paths/pytorch-fundamentals/

初次编辑:2024/3/2;最后编辑:2024/3/2


本教程第一篇介绍pytorch基础和张量操作

这是本教程的第二篇,介绍以下内容:

  1. 数据集与数据加载器
  2. 归一化

另外本人还有pytorch CV相关的教程,见专题:

https://blog.csdn.net/qq_33345365/category_12578430.html


数据集与数据加载器 Datasets and Dataloaders

处理数据样本的代码可能会变得复杂且难以维护。通常希望数据集代码与模型训练代码解耦,以获得更好的可读性和模块化性。PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset,它们使您能够使用预加载的数据集以及您自己的数据。Dataset存储样本及其对应的标签,而DataLoader则在Dataset周围包装了一个可迭代对象,以便轻松访问样本。

PyTorch库提供了许多预加载的样本数据集(例如FashionMNIST),它们是torch.utils.data.Dataset的子类,并实现了特定于特殊数据的功能。用于模型原型设计(prototyping)和基准测试(benchmark)的样本包括:

  1. 图像数据集
  2. 文本数据集
  3. 音频数据集

加载数据集

此处使用TorchVision来加载Fashion-MNIST数据集。Fashion-MNIST是Zalando的服装图像数据集,包含60,000个训练样本和10,000个测试样本。每个样本包括一个28×28的灰度图像和一个来自10个类别中的关联标签。

  • 每个图像高度为28个像素,宽度为28个像素,总共有784个像素。
  • 这10个类别告诉图像是什么类型,例如:T恤/上衣、裤子、套头衫、连衣裙、包、短靴等。
  • 灰度像素的值介于0到255之间,用于衡量黑白图像的强度。强度值从白色到黑色递增。例如:白色的颜色值为0,而黑色的颜色值为255。

我们使用以下参数加载FashionMNIST数据集:

  • root 是存储训练/测试数据的路径。
  • train 指定训练或测试数据集。
  • download=True 如果数据在root中不可用,则从互联网下载数据。
  • transformtarget_transform 指定特征和标签的转换。

加载训练和测试数据集的代码如下所示:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

数据集的迭代与可视化

可以像列表一样手动索引Datasets: training_data[index]。此处使用matplotlib对训练数据中的一些样本进行可视化。代码如下所示:

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

得到如下图片:

在这里插入图片描述

使用数据加载器(DataLoaders)准备自定义数据

Dataset 逐个样本检索数据集的特征和标签。在训练模型时,通常希望以“minibatch”的形式传递样本,在每个周期重混洗(reshuffle)数据以减少模型过拟合,并使用Python的多进程加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。**特征(Feature)**是输入,**标签(label)**是输出。我们训练特征,然后训练模型以预测标签。

  • 特征是图像像素中的模式。
  • 标签是10个类别类型:T恤,凉鞋,连衣裙等。

DataLoader是一个可迭代对象,用简单的API抽象了这种复杂性。为了使用DataLoader,我们需要设置以下参数:

  • data 用于训练模型的训练数据,以及用于评估模型的测试数据。
  • batch size 每个批次要处理的记录数。
  • shuffle 按索引随机抽取数据样本。

使用DataLoader处理两个数据集的代码如下所示:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

使用DataLoader进行迭代

我们已将数据集加载到 DataLoader 中,现在可以根据需要迭代数据集。 下面的每次迭代都会返回一个批次的 train_featurestrain_labels(分别包含 batch_size=64 个特征和标签)。由于我们指定了 shuffle=True,在遍历完所有批次后,数据将被混洗(shuffle),以更精细地控制数据加载顺序。

展示图片和标签的代码如下所示:

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
label_name = list(labels_map.values())[label]
print(f"Label: {label_name}")

得到图片:

在这里插入图片描述

归一化

标准化是一种常见的数据预处理技术,用于对数据进行缩放或转换,以确保每个特征都有相等的学习贡献。例如,灰度图像中的每个像素的值都介于0和255之间,这些是特征。如果一个像素值为17,另一个像素为197,则像素重要性的分布将不均匀,因为较高的像素值会偏离学习。标准化改变了数据的范围,而不会扭曲其特征之间的区别。这种预处理是为了避免:

  • 减少预测精度
  • 模型学习困难
  • 特征数据范围的不利分布

Transforms

数据并不总是以最终处理过的形式呈现,这种形式适合训练机器学习算法。可以使用transforms来操作数据,使其适合训练。

所有 TorchVision 数据集都有两个参数(transform用于修改特征,target_transform用于修改标签),这些参数接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的转换。

FashionMNIST 的特征是 PIL 图像格式,标签是整数。 对于训练,需要将特征转换为归一化的张量,将标签转换为 one-hot 编码的张量。 为了进行这些转换,将使用 ToTensorLambda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

解释上述代码:

ToTensor()

ToTensor 将PIL图像或NumPy ndarray转换为FloatTensor,并将图像的像素强度值缩放到范围 [0., 1.]内。

Lambda transforms

Lambda transforms 应用任何用户定义的lambda函数。在这里,我们定义了一个函数来将整数转换为一个one-hot编码的张量。它首先创建一个大小为10的零张量(数据集中标签的数量),然后调用scatter,根据标签y的索引为其分配值为1。也可以使用torch.nn.functional.one_hot的方式来实现这一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/424269.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024-03-01(金融AI行业与大数据生态圈)

1.金融这一块的算法,不像推荐系统,图像等领域,金融领域的算法都比较成熟了。现在来说门槛低,属于初期阶段,上升期。 2.反欺诈的数据标签比较少,有一种“标签染色”的方法来做反欺诈模型的标签。 3.常用反…

什么是Vue指令?请列举一些常见的Vue指令以及它们的用法

Vue.js 是一款流行的前端框架,它的指令(Directives)是 Vue.js 提供的一种特殊属性,用于在模板中对 DOM 元素进行直接操作。指令通常是以 v- 开头的特殊属性,用于响应式地将数据绑定到 DOM 元素上。 在 Vue 中&#xf…

【NTN 卫星通信】卫星和无人机配合的应用场景

1 场景概述 卫星接入网是一种有潜力的技术,可以为地面覆盖差地区的用户提供无处不在的网络服务。然而,卫星覆盖范围对于位于考古或采矿地点内部/被茂密森林覆盖的村庄/山谷/靠近山丘或大型建筑物的用户可能很稀疏。因此,涉及卫星接入和无人驾…

131. 分割回文串(力扣LeetCode)

文章目录 131. 分割回文串题目描述回溯代码 131. 分割回文串 题目描述 给你一个字符串 s,请你将 s 分割成一些子串,使每个子串都是 回文串 。返回 s 所有可能的分割方案。 回文串 是正着读和反着读都一样的字符串。 示例 1: 输入&#xf…

C#,基于密度的噪声应用空间聚类算法(DBSCAN Algorithm)源代码

1 聚类算法 聚类分析或简单聚类基本上是一种无监督的学习方法,它将数据点划分为若干特定的批次或组,使得相同组中的数据点具有相似的属性,而不同组中的数据点在某种意义上具有不同的属性。它包括许多基于差分进化的不同方法。 E、 g.K-均值…

kafka文件存储机制和消费者

1.broker文件存储机制 去查看真正的存储文件: 在/opt/module/kafka/datas/ 路径下 kafka-run-class.sh kafka.tools.DumpLogSegments --files ./00000000000000000000.index 如果是6415那么这个会存储在563的log文件之中,因为介于6410和10090之间。 2.…

STM32使用FlyMcu串口下载程序与STLink Utility下载程序

文章目录 前言软件链接一、FlyMcu串口下载程序原理优化手动修改跳线帽选项字节其他功能 二、STLink Utility下载程序下载程序选项字节固件更新 前言 本文主要讲解使用FlyMcu配合USART串口为STM32下载程序、使用STLink Utility配合STLink为STM32下载程序,以及这两个…

Stable Diffusion 模型分享:AAM XL (Anime Mix)(动漫截屏风格 XL)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八 下载地址 模型介绍 AAM XL (Anime Mix) 是一个动漫截屏风格的模型,是 AAM - AnyLoRA Anime Mix 模…

【办公类-25-01】20240302 UIBOT上传 ”班级主页-育儿知识(家园小报)“

作品展示: 一、背景需求: 本学期制作了 “育儿知识(家园小报)”合并A4内容 【办公类-22-08】周计划系列(4)“育儿知识(家园小报)“ (2024年调整版本)-CSDN博…

【Qt学习笔记】(四)Qt窗口

Qt窗口 1 菜单栏1.1 创建菜单栏1.2 在菜单栏中添加菜单1.3 创建菜单项1.4 在菜单项之间添加分割线1.5 给菜单项添加槽函数1.6 给菜单项添加快捷键 2 工具栏2.1 创建工具栏2.2 设置停靠位置2.3 设置浮动属性2.4 设置移动属性2.5 添加 Action 3 状态栏3.1 状态栏的创建3.2 在状态…

【高级数据结构】Trie树

原理 介绍 高效地存储和查询字符串的数据结构。所以其重点在于:存储、查询两个操作。 存储操作 示例和图片来自:https://blog.csdn.net/qq_42024195/article/details/88364485 假设有这么几个字符串:b,abc,abd&…

基于springboot+vue的高校教师科研管理系统

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

Redis高级特性和应用(发布、订阅、Stream、慢查询、Pipeline、事务、Lua)

Redis高级特性和应用 发布和订阅 Redis提供了基于“发布/订阅”模式的消息机制,此种模式下,消息发布者和订阅者不进行直接通信,发布者客户端向指定的频道( channel)发布消息,订阅该频道的每个客户端都可以收到该消息。 操作命令 Redis主要…

手写 Attention 迷你LLaMa2——LLM实战

https://github.com/Yuezhengrong/Implement-Attention-TinyLLaMa-from-scratch 1. Attention 1.1 Attention 灵魂10问 你怎么理解Attention? Scaled Dot-Product Attention中的Scaled: 1 d k \frac{1}{\sqrt{d_k}} dk​ ​1​ 的目的是调节内积&…

OpenAI划时代大模型——文本生成视频模型Sora作品欣赏(十三)

Sora介绍 Sora是一个能以文本描述生成视频的人工智能模型,由美国人工智能研究机构OpenAI开发。 Sora这一名称源于日文“空”(そら sora),即天空之意,以示其无限的创造潜力。其背后的技术是在OpenAI的文本到图像生成模…

BUGKU bp

打开环境,他提示了弱密码top1000,随便输入密码123抓包爆破 发现长度都一样,看一下响应发现一段js代码,若r值为{code: bugku10000},则会返回错误,通过这一句“window.location.href success.php?coder.cod…

【软考】设计模式之访问者模式

目录 1. 说明2. 应用场景3. 结构图4. 构成5. java示例5.1 喂动物5.1.1 抽象访问者5.1.2 具体访问者5.1.3 抽象元素5.1.4 具体元素5.1.5 对象结构5.1.6 客户端类5.1.7 结果示例 5.2 超市销售系统5.2.1 业务场景5.2.2 业务需求类图5.2.3 抽象访问者5.2.4 具体访问者5.2.5 抽象元素…

实战:Oracle Weblogic 11g配置无密码启动,启动关闭脚本,修改节点内存

导读 上篇博文介绍了Oracle Weblogic 11g的安装部署,本文介绍Weblogic安装后的基本配置 包括:设置weblogic启动关闭的无密码验证,启动关闭脚本,修改默认的节点内存。 1、配置无密码启动 [weblogicw1 base_domain]$ cd servers/ […

C#,无监督的K-Medoid聚类算法(K-Medoid Algorithm)与源代码

1 K-Medoid算法 K-Medoid(也称为围绕Medoid的划分)算法是由Kaufman和Rousseeuw于1987年提出的。中间点可以定义为簇中的点,其与簇中所有其他点的相似度最小。 K-medoids聚类是一种无监督的聚类算法,它对未标记数据中的对象进行聚…

计算机网络|Socket

文章目录 Socket并发socket Socket Socket是一种工作在TCP/IP协议栈上的API。 端口用于区分不同应用,IP地址用于区分不同主机。 以下是某一个服务器的socket代码。 其中with是python中的一个语法糖,代表当代码块离开with时,自动对s进行销毁…