昇思25天学习打卡营第1天 | 快速入门

打卡截图
今天开始学习Mindspore框架,首先需要引入数据集,以Mnist数据集为例:

处理数据集

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

数据集 下载完后通过下面的代码获取训练集和测试集,train_dataset表示训练集,test_dataset表示测试集。

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

可以通过打印的方式来查看数据集的列名

print(train_dataset.get_col_names())

接下来对数据进行预处理

vision.Rescale(1.0 / 255.0, 0)指的是将图像的像素值缩放在[0,1]范围内,通过将每个像素值除以255实现。
​​​​​​在这里插入图片描述

vision.Normalize(mean=(0.1307,), std=(0.3081,))是对数据进行归一化处理,mean就是均值,std是标准差,每个像素值减去均值并除以标准差。

vision.HWC2CHW()将图像布局从高度-宽度-通道改为通道-高度-宽度。

label_transform = transforms.TypeCast(mindspore.int32) 将数据类型改为mindspore的类型

dataset = dataset.map(image_transforms, ‘image’)

dataset = dataset.map(label_transform, ‘label’)将数据预处理应用到数据集上。

dataset = dataset.batch(batch_size) 将数据集划分每batchsize为一批。

def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

可以用create_tuple_iterator 或create_dict_iterator对数据进行遍历,create_tuple_iterator是创建一个元组迭代器,create_dict_iterator是创建一个字典迭代器

网络构建
mindspore.nn类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。

输入x首先被flatten层展平,然后通过全连接-relu-sequential得到最后的分数

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            #将28*28个输入节点映射到512个节点
            nn.Dense(28*28, 512),
            #激活函数
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

模型训练
在模型训练中,一个完整的训练过程(step)需要实现以下三步:

1.正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。

2.反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。

3.参数优化:将梯度更新到参数上。

MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:

1.定义正向计算函数。

2.使用value_and_grad通过函数变换获得梯度计算函数。

3.定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

# Instantiate loss function and optimizer
#实例化了一个交叉熵损失函数,这是在多类别分类问题中常用的损失函数,用于衡量模型输出(logits)和真实标签之间的差异。
loss_fn = nn.CrossEntropyLoss()
#实例化了一个随机梯度下降(SGD)优化器,学习率为0.01(1e-2),用于更新模型的权重。
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 1. Define forward function
#这个函数接受数据和标签作为输入,通过模型计算出logits,然后使用损失函数计算损失值。它返回损失值和logits。
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 2. Get gradient function
#这个函数为前向传播函数创建了一个梯度函数。value_and_grad函数会返回一个新函数,该函数在计算损失值的同时也会计算梯度。None表示不关心输入数据的梯度,optimizer.parameters指定了需要计算梯度的参数,has_aux=True表示forward_fn会返回一个包含主输出(损失值)和一个辅助输出(logits)的元组。
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 3. Define function of one-step training
#这个函数执行一个训练步骤,包括计算损失和梯度,以及应用梯度更新模型的权重。grad_fn返回损失值和梯度,然后使用优化器更新模型参数。
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss
#这个函数执行整个训练过程。首先,它获取数据集的大小,然后将模型设置为训练模式。接着,它遍历数据集的批次,并调用train_step函数进行训练。每100个批次,它会打印当前的损失值和批次编号。
def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    #enumerate用于遍历序列(如列表、元组或字符串)并返回序列中的元素以及它们的索引。enumerate函数通常用于在循环中同时获取元素和它们的索引。
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
        #这行代码将损失值从MindSpore的Tensor格式转换为NumPy数组格式,以便可以打印出来。同时,它将当前的批次编号赋值给变量current。
            loss, current = loss.asnumpy(), batch
            #这行代码使用Python的格式化字符串功能(f-string)来打印损失值和批次信息。loss:>7f表示将损失值打印为至少7位小数的浮点数,并在必要时进行右对齐。current:>3d和size:>3d分别表示将当前批次编号和数据集大小打印为至少3位整数的数字,并在必要时进行右对齐。
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

训练完就该测试了,下面是测试函数

def test(model, dataset, loss_fn):
#获取数据集的批次数,这是通过调用数据集的 get_dataset_size 方法得到的。
    num_batches = dataset.get_dataset_size()
#将模型设置为评估模式。在评估模式下,模型中的某些层(如批量归一化层和dropout层)的行为会与训练模式不同。
    model.set_train(False)
#初始化三个变量,total 用于存储总样本数,test_loss 用于存储总损失,correct 用于存储预测正确的样本数。
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
#通过模型传递输入数据,得到预测结果 pred(logits)
        pred = model(data)
        #更新总样本数。
        total += len(data)
#计算当前批次的损失,并累加到 test_loss 中。asnumpy() 将损失从MindSpore的Tensor格式转换为NumPy数组格式。
        test_loss += loss_fn(pred, label).asnumpy()
#计算当前批次中预测正确的样本数,并累加到 correct 中。pred.argmax(1) 获取每个样本的最可能类别(即概率最高的类别),然后与真实标签 label 进行比较,得到一个布尔数组,其中 True 表示预测正确,False 表示预测错误。.asnumpy().sum() 将布尔数组转换为NumPy数组并计算其中 True 的数量,即预测正确的样本数。
        correct += (pred.argmax(1) == label).asnumpy().sum()
#计算平均损失,通过将总损失除以批次数得到。
    test_loss /= num_batches
#计算准确率,通过将预测正确的样本数除以总样本数得到。
    correct /= total
#打印测试结果,包括准确率和平均损失。{(100*correct):>0.1f}% 将准确率转换为百分比形式,并保留一位小数;{test_loss:>8f} 将平均损失打印为至少8位小数的浮点数。
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    

训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的损失(loss)值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

保存模型

模型训练完成后,需要将其参数进行保存。

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载保存的权重分为两步:

1.重新实例化模型对象,构造模型。

2.加载模型参数,并将其加载至模型上。

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
#使用mindspore.load_param_into_net函数将加载的参数字典param_dict中的参数值加载到之前实例化的模型model中。这个函数返回两个值,param_not_load是一个列表,包含了未能加载的参数名称,而第二个返回值是一个布尔值,指示是否所有参数都成功加载。
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

加载后的模型可以直接用于预测推理。

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
# 对于每个样本,argmax函数会返回得分最高的类别的索引,即预测的标签。argmax(1)表示沿着第二个维度(索引为1的维度,即类别维度)取最大值的索引。
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/739819.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

qt.qpa.xcb: could not connect to display问题解决

1、问题描述 以服务器pi5作为远程解释器,本地win11使用vscode远程调试视觉时报错如下: qt.qpa.xcb: could not connect to display qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "xxxxx" even though it was …

EasyExcel 无法读取图片?用poi写了一个工具类

在平时的开发中,经常要开发 Excel 的导入导出功能。一般使用 poi 或者 EasyExcel 开发,使用 poi 做 excel 比较复杂,大部分开发都会使用 EasyExcel 因为一行代码就能实现导入和导出的功能。但是 EasyExcel 不支持图片的读的操作,本…

CRMEB 多商户Java版v1.6公测版发布,付费会员上线,立即体验

新版本来袭!CRMEB 多商户Java版v1.6正式发布! 在v1.6新版本中,我们带来了付费会员体系,这将让商业模式更加灵活多元,新增加的移动端商家管理,也让运营触手可及,更加便捷,还有商家端员…

(2011-2022年) 全国各省快递业务量与快递业务收入面板数据

中国快递业近年来随着电子商务的蓬勃发展而迅速壮大,成为现代生活中不可或缺的一部分。快递业务量与收入的面板数据为我们提供了一个观察中国快递市场繁荣与多元化的窗口。 数据来源 中国统计年鉴 参考文献 胡润哲, 魏君英, 陈银娥. 数字经济发展对农村居民服务…

sheng的学习笔记-聚类(Clustering)

ai目录 sheng的学习笔记-AI目录-CSDN博客 基础知识 什么是聚类 在“无监督学习”(unsupervised learning)中,训练样本的标记信息是未知的,目标是通过对无标记训练样本的学习来揭示数据的内在性质及规律,为进一步的数据分析提供基础。此类学…

智慧安防/边缘计算EasyCVR视频汇聚网关:EasySearch无法探测到服务器如何处理?

安防监控EasyCVR智能边缘网关/视频汇聚网关/视频网关属于软硬一体的边缘计算硬件,可提供多协议(RTSP/RTMP/国标GB28181/GAT1400/海康Ehome/大华/海康/宇视等SDK)的设备接入、音视频采集、视频转码、处理、分发等服务,系统具备实时…

都说HCIE“烂大街”了,说难考都是假的?

在网络技术领域,华为认证互联网专家(HCIE)长期以来被视为一项高端认证,代表着专业技能和知识水平。 然而,近几年来,考证的重视度直线上升,考HCIE的人越来越多了,考过的人好像也越来越…

桌面编辑器ONLYOFFICE 功能多样性快来试试吧!

目录 ONLYOFFICE 桌面编辑器 8.1 ONLYOFFICE介绍 主要功能和特点 使用场景 1.PDF编辑器 2.幻灯片版式 3.编辑,审阅和查看模式 4.隐藏连接到云版块 5.RTL语言支持和本地化选项 6.媒体播放器 7、其他新功能 8.下载 总结 ONLYOFFICE 桌面编辑器 8.1 官网地…

新火种AI|OpenAI CTO表示:未来将会有越来越多的人被AI所取代...

作者:小岩 编辑:彩云 对于“AI是否能最终取代人类进行工作”这件事儿,很多学者持有否定态度。大家普遍认为,即便如今诞生了ChatGPT,Claude等强大的AI工具,它们也只能够解决一些格式化,重复化的…

【力扣C++】爬楼梯

假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示例 2&#x…

鸿蒙开发系统基础能力:【@ohos.hidebug (Debug调试)】

Debug调试 说明: 本模块首批接口从API version 8开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 使用hidebug,可以获取应用内存的使用情况,包括应用进程的静态堆内存(native heap)信息…

想要将视频做二维码,试试这个方法吧

视频内容做成二维码用于分享内容的一种常用方式,而且通过二维码来分享视频内容与传统方式相比也更加的方便,用户只需要扫描二维码就可以观看视频内容,在很多的使用场景中的都有应用。那么如何操作能够快速制作一个视频二维码呢? …

算法设计与分析:动态规划法求扔鸡蛋问题 C++

目录 一、实验目的 二、问题描述 三、实验要求 四、算法思想和实验结果 1、动态规划法原理: 2、解决方法: 2.1 方法一:常规动态规划 2.1.1 算法思想: 2.1.2 时间复杂度分析 2.1.3 时间效率分析 2.2 方法二:动态规划加…

光伏发电项目是如何提高开发效率的?

随着全球对可再生能源需求的持续增长,光伏发电项目的高效开发成为关键。本文将深入探讨如何在实际操作中提高光伏发电项目的开发效率。 一、优化选址流程 1、数据收集与分析:利用卫星地图和遥感技术,收集目标区域的光照资源、地形地貌、阴影…

Win11 docker build拉取镜像失败(无法访问镜像仓库)

目录 遇到的问题: 修改docker配置 写了一个dockerfile(基于python的镜像)文件,在生成时,一直报错,换了好几个仓库,都是不行(包括阿里、南大、官网、网易、Azure中国镜像等都不行) 遇到的问题: 连接超时…

视频云存储平台LntonCVS国标视频平台功能和应用场景详细介绍

LntonCVS国标视频融合云平台基于先进的端-边-云一体化架构设计,以轻便的部署和灵活多样的功能为特点。该平台不仅支持多种通信协议如GB28181、RTSP、Onvif、海康SDK、Ehome、大华SDK、RTMP推流等,还能兼容各类设备,包括IPC、NVR和监控平台。在…

Inception_V2_V3_pytorch

Inception_V2_V3_pytorch 在上一节我们已经精度了Inception_V2_V3这篇论文,本篇我们将用pyorch复现论文中的网络结构! 从论文中我们可以知道InceptionV3的主要改进为: 5 * 5卷积分解为2个3 * 3卷积核分解为不对称卷积滤波器组 我们可将GoogL…

【专利】一种光伏产品缺陷检测AI深度学习算法

申请号CN202410053849.9公开号(公开)CN118037635A申请日2024.01.12申请人(公开)超音速人工智能科技股份有限公司发明人(公开)张俊峰(总); 叶长春(总); 廖绍伟 摘要 本发明公开一种光伏产品缺陷检测AI深度…

区块链实验室(37) - 交叉编译百度xuperchain for arm64

纠结了很久,终于成功编译xuperchain for arm64。踩到1个坑,说明如下。 1、官方文档是这么说的:go语言版本推荐1.5-1.8 2、但是同一个页面,又是这么说的:不推荐使用1.11之前的版本。 3、问题来了:用什么版本…

ONLYOFFICE 编辑器8.1,一个功能全面的编辑器

目录 官网地址:ONLYOFFICE - 企业在线办公应用软件 | ONLYOFFICE 一、PDF编辑 二、PPT播放 1. 多样化的幻灯片样式与布局 2. 强大的文本编辑与格式化功能 3. 丰富的图形与图表插入功能 4. 灵活的过渡效果与动画设置 5. 舒适的呈现与演讲辅助功能 6. 便捷的团…