昇思25天学习打卡营第2天|快速入门

快速入门

  • 操作步骤
    • 1.引入依赖包
    • 2.下载Mnist数据集
    • 3.划分训练集和测试集
    • 4.数据预处理
    • 5.网络构建
    • 6.模型训练
    • 7.保存模型
    • 8.加载模型
    • 9.模型预测

今天通过昇思大模型平台AI实验室提供的在线Jupyter工具,快速入门MindSpore。
目标:通过MindSpore的API快速实现一个简单的深度学习模型。

操作步骤

1.引入依赖包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

2.下载Mnist数据集

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

Mnist数据集目录结构如下:
MNIST_Data
└── train
├── train-images-idx3-ubyte (60000个训练图片)
├── train-labels-idx1-ubyte (60000个训练标签)
└── test
├── t10k-images-idx3-ubyte (10000个测试图片)
├── t10k-labels-idx1-ubyte (10000个测试标签)

3.划分训练集和测试集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

4.数据预处理

使用dataset模块的map操作对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。

def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

5.网络构建

继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

6.模型训练

一个完整的训练过程(step)需要实现以下三步:

  1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  3. 参数优化:将梯度更新到参数上。
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 定义正向计算函数
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 使用value_and_grad通过函数变换获得梯度计算函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

# 定义测试函数,用来评估模型的性能
def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

# 预测,并输出每一轮的loss值和预测准确率(Accuracy)
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

训练结果:
训练结果

7.保存模型

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

8.加载模型

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

9.模型预测

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

打印结果:Predicted: “[7 1 9 8 3 8 7 7 7 9]”, Actual: “[7 1 9 8 3 8 7 7 7 9]”

截图时间
截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/743814.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

互联网应用主流框架整合之Spring Boot运维体系

先准备个简单的系统,配置和代码如下 # 服务器配置 server:# 服务器端口port: 8001# Spring Boot 配置 spring:# MVC 配置mvc:# Servlet 配置servlet:# Servlet 的访问路径path: /sbd# 应用程序配置application:# 应用程序名称name: SpringBootDeployment# 配置数据…

PointCloudLib NDT3D算法实现点云配准 C++版本

0.实现效果 效果不咋好 ,参数不好调整 1.算法原理 3D NDT(Normal Distributions Transform)算法是一种用于同时定位和地图生成(SLAM)的机器人导航算法,特别适用于三维点云数据的配准。以下是关于3D NDT算法的详细解释: 算法原理 点云划分与分布计算:3D NDT算法首先将…

ElementPlus组件与图标按需自动引入

按需自动引入组件 1. 安装ElementPlus和自动导入ElementPlus组件的插件 pnpm install element-plus pnpm install -D unplugin-vue-components unplugin-auto-import 2. vite.config.ts进行修改 import { defineConfig } from vite import vue from vitejs/plugin-vue // …

MySQL索引优化解决方案--索引失效(3)

索引失效情况 最佳左前缀法则:如果索引了多列,要遵循最左前缀法则,指的是查询从索引的最左前列开始并且不跳过索引中的列。不在索引列上做任何计算、函数操作,会导致索引失效而转向全表扫描存储引擎不能使用索引中范围条件右边的…

Windows 根据github上的环境需求,安装一个虚拟环境,安装cuda和torch

比如我们在github上看到一个关于运行环境的需求 Installation xxx系统Python 3.xxx CUDA 9.2PyTorch 1.9.0xxxxxx 最主要的就是cuda和torch,这两个会卡很多环境的安装。 我们重新走一遍环境安装。 首先创建一个虚拟环境 conda create -n 环境名字 python3.xxx…

Tomcat 下载部署到 idea

一、下载Tomcat Tomcat 是Apache 软件基金会(Apache Software Foundation)下的一个核心项目,免费开源、并支持Servlet 和JSP 规范。属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发…

【Text2SQL 论文】MAGIC:为 Text2SQL 任务自动生成 self-correction guideline

论文:MAGIC: Generating Self-Correction Guideline for In-Context Text-to-SQL ⭐⭐⭐ 莱顿大学 & Microsoft, arXiv:2406.12692 一、论文速读 DIN-SQL 模型中使用了一个 self-correction 模块,他把 LLM 直接生成的 SQL 带上一些 guidelines 的 p…

计算机组成原理课程设计报告

有关计算机组成原理课程设计报告如下题材,其包含报告和代码 16位ALU设计 动态LED动态显示屏设计 海明码编码设计编码流水传输 单周期MIPS控制器设计 阵列乘法器设计 Cache映射机制与逻辑实现 我们以6x6位阵列乘法器设计的报告为例为大家讲述一下我们这次的课程设计 …

网络爬虫Xpath开发工具的使用

开发人员在编写网络爬虫程序时若遇到解析网页数据的问题,则需要花费大量的时间编 写与测试路径表达式,以确认是否可以解析出所需要的数据。为帮助开发人员在网页上直接 测试路径表达式是否正确,我们在这里推荐一款比较好用的 XPath 开发工…

关于关闭防火墙后docker启动不了容器

做项目的时候遇到个怪事,在Java客户端没办法操作redis集群。反复检查了是否运行,端口等一系列细节的操作,结果都不行。 根据提示可能是Linux的防火墙原因。于是去linux关闭了防火墙。 关闭后果不其然 可以操作reids了,可是没想到另…

浏览器断点调试(用图说话)

浏览器断点调试(用图说话) 1、开发者工具2、添加断点3、查看变量值 浏览器断点调试 有时候我们需要在浏览器中查看 html页面的js中的变量值。1、开发者工具 打开浏览器的开发者工具 按F12 ,没反应的话按FnF12 2、添加断点 3、查看变量值

高考填报志愿攻略,5个步骤选专业和院校

在高考完毕出成绩的时候,很多人会陷入迷茫中,好像努力了这么多年,却不知道怎么规划好未来。怎么填报志愿合适?在填报志愿方面有几个内容需要弄清楚,按部就班就能找到方向,一起来了解一下正确的步骤吧。 第…

【C语言】解决C语言报错:Dangling Pointer

文章目录 简介什么是Dangling PointerDangling Pointer的常见原因如何检测和调试Dangling Pointer解决Dangling Pointer的最佳实践详细实例解析示例1:释放内存后未将指针置为NULL示例2:返回指向局部变量的指针示例3:指针悬空后继续使用示例4&…

37岁,被裁员,失业三个月,被面试官嫌弃“太水”:就这也叫10年以上工作经验?

今年部门要招两个自动化测试,这几个月我面试了几十位候选人。发现一个很奇怪的现象,面试中一问到元素定位、框架api、脚本编写之类的,很多候选人都对答如流。但是一问到实际项目,比如“项目中UI自动化和接口自动化如何搭配使用&am…

【研究】国内外大模型公司进展

2022年11月,OpenAI推出基于GPT-3.5的ChatGPT后,引发全球AI大模型技术开发与投资热潮。AI大模型性能持续快速提升。以衡量LLM的常用评测标准MMLU为例,2021年底全球最先进大模型的MMLU 5-shot得分刚达到60%,2022年底超过70%&#xf…

JAVA小知识29:IO流(上)

IO流是指在计算机中进行输入和输出操作的一种方式,用于读取和写入数据。IO流主要用于处理数据传输,可以将数据从一个地方传送到另一个地方,例如从内存到硬盘,从网络到内存等。IO流在编程中非常常见,特别是在文件操作和…

正版软件 | Copywhiz 6:革新您的文件复制、备份与管理体验

在数字化时代,文件管理的效率直接影响到我们的生产力。Copywhiz 6 最新版本,带来了前所未有的文件处理能力,让复制、备份和组织文件变得轻而易举。 智能选择,只复制更新内容 Copywhiz 6 的智能选择功能,让您只需几次点…

10--7层负载均衡集群

前言:动静分离,资源分离都是在7层负载均衡完成的,此处常被与四层负载均衡比较,本章这里使用haproxy与nginx进行负载均衡总结演示。 1、基础概念详解 1.1、负载均衡 4层负载均衡和7层负载均衡是两种常见的负载均衡技术&#xff…

docker 容器设置中文环境

1.容器中安装和设置 1.1.进入容器查看已有语言包 locale -a 默认情况下: 1.2 安装中文语言环境 如果没有zh_CN.utf8就安装。 方式1: #直接安装中文语言包 apt-get install -y language-pack-zh-hans 方式2: #安装中文语言环境 apt-g…

小白学python(第二天)

哈喽,各位小伙伴们我们又见面了,昨天的文章吸收得如何?可有不懂否?如有不懂可以在品论区留言哦,废话不多说,开始今天的内容。 字符及字符串的续讲 字符:英文字母,阿拉伯数字&#x…