【笔记】打卡01 | 初学入门

初学入门:01-02

  • 01 基本介绍
  • 02 快速入门
    • 处理数据集
    • ==网络构建==
    • 模型训练
    • 保存模型
    • 加载模型
    • 打卡-时间

01 基本介绍

MindSpore Data(数据处理层)
ModelZoo(模型库)
MindSpore Science(科学计算),包含了业界领先的数据集、基础模型、预置高精度模型和前后处理工具
MindSpore Insight(可视化调试调优工具),能够可视化地查看训练过程、优化模型性能、调试精度问题、解释推理结果

02 快速入门

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

处理数据集

下载Mnist数据集

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

在这里插入图片描述

训练集、测试集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

在这里插入图片描述
列名:图片 和 对应标签(分类)

数据处理流水线(Data Processing Pipeline)

参数:数据集、batch_size

def datapipe(dataset, batch_size):
    image_transforms = [                    
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

首先,数据变换(Transforms):1、对输入数据(即图片)2、对输出(即标签);
然后,map对图像数据及标签进行变换处理;
最后,将处理好的数据集打包为大小为64的batch

train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

对数据集进行迭代访问

for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break

网络构建

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

mindspore.nn类是构建所有网络的基类,也是网络的基本单元。

  • 自定义网络时,可以继承nn.Cell
  • __init__包含所有网络层的定义
  • construct(类似前向传播??)包含数据(Tensor)的变换过程。

模型训练

定义损失函数、优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

一个完整的训练过程(step)需要实现以下三步:

1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
3. 参数优化:将梯度更新到参数上。

定义正向计算函数。

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

使用value_and_grad通过函数变换获得梯度计算函数。

grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

one-step training

def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

定义测试函数:用来评估模型的性能。

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程需多轮(epoch)训练数据集

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

在这里插入图片描述

保存模型

模型训练完成后,需要保存其参数。

mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

加载保存的权重

# 1、重新实例化模型对象,构造模型
model = Network()
# 加载模型参数,并将其加载至模型上。
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

打卡-时间

from datetime import datetime
import pytz
# 设置时区为北京时区
beijing_tz = pytz.timezone('Asia/shanghai')
# 获取当前时间,并转为北京时间
current_beijing_time = datetime.now(beijing_tz)
# 格式化时间输出
formatted_time = current_beijing_time.strftime('%Y-%m-%d %H:%M:%S')
print("当前北京时间:",formatted_time,'your name')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/728992.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Chromium 调试指南2024 Mac篇 - 调试 Chromium(三)

1.引言 在完成了环境准备和成功编译Chromium之后,下一步就是进行调试工作。调试是软件开发过程中必不可少的环节,通过调试可以定位和修复代码中的问题,验证新功能的正确性,并确保整个项目的稳定性和高效性。 由于Chromium项目的…

【html】如何利用hbuilderX 开发一个自己的app并安装在手机上运行

引言: 相信大家都非常想开发一款自己的apk,手机应用程序,今天就教大家,如何用hbuilderX 开发一个自己的app并安装在手机上运行。 步骤讲解: 打开hbuilderX ,选择新建项目 2.选择5app,想一个名字&#x…

每天写java到期末考试(6.21)--集合4--练习--6.20

练习1&#xff1a; 正常写集合 bool类 代码&#xff1a; import QM_Fx.Student;import java.util.ArrayList;public class test {public static void main(String[] args) {ArrayList<Student> listnew ArrayList<>();//2.创建学生对象Student s1new Student(&quo…

从媒体网站的频道划分看媒体邀约的分类?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 媒体宣传加速季&#xff0c;100万补贴享不停&#xff0c;一手媒体资源&#xff0c;全国100城线下落地执行。详情请联系胡老师。 在我们举行活动的时候&#xff0c;通常会邀请媒体到现场来…

基于Python爬虫的城市天气数据可视化分析

基于Python爬虫的城市天气数据可视化分析 一、项目简介二、项目背景三、Python语言简介四、网络爬虫简介五、数据可视化简介六、天气数据爬取与存储6.1 获取目标网页6.2 发送请求6.3 提取数据6.4 保存数据七、天气数据可视化7.1 天气现象轮播图7.2 历史温度分布图7.3 历史风向分…

2134名女性,0感染!艾滋病预防药传出大消息,只需半年注射一次,药厂股价应声暴涨

内容提要 美国生物制药公司吉利德科学公布了Lenacapavir预防艾滋病毒的实验结果&#xff0c;显示出100%有效性。或将为艾滋病预防带来新选择。 文章正文 当地时间周四&#xff08;6月20日&#xff09;&#xff0c;美国生物制药公司吉利德科学在其官网公布一则重磅实验结果&am…

使用mysql的binlog进行数据恢复

1.mysql安装环境 在你本地电脑windows上建一个和生产环境一样的mysql版本 我的是 mysql5.7.43 安装教程可以自行上网搜&#xff08;这里不做介绍&#xff09; 可参考&#xff1a; 1.1安装路径 我的mysql安装路径&#xff1a; D:\mysql\mysql-5.7.43-winx64\bin * 1.2my.in…

官方解决 maven 上传jar包到中央仓库401

目录 目录背景&#xff1a;报错现象一句话解决一句话原因问题描述前车之鉴&#xff1a;尝试无果的记录尝试1&#xff1a;询问 GPT 未果尝试2&#xff1a;搜索引擎未果 正确原因分析&#xff1a;佐证 解决方案&#xff08;三步完成&#xff09;&#xff1a;1. 打开 sonatype2. 生…

MacOS设备远程登录配置结合内网穿透实现异地ssh远程连接

文章目录 前言1. MacOS打开远程登录2. 局域网内测试ssh远程3. 公网ssh远程连接MacOS3.1 MacOS安装配置cpolar3.2 获取ssh隧道公网地址3.3 测试公网ssh远程连接MacOS 4. 配置公网固定TCP地址4.1 保留一个固定TCP端口地址4.2 配置固定TCP端口地址 5. 使用固定TCP端口地址ssh远程 …

【C++】类和对象(四)拷贝构造、赋值运算符重载

文章目录 四、拷贝构造函数干嘛的&#xff1f;写拷贝构造函数的注意事项正确写法 不显示定义拷贝构造函数的情况浅拷贝:one:示例&#xff1a;内置类型:two:示例&#xff1a;自定义类型一个提问 深拷贝 五、赋值运算符重载运算符重载函数原型注意调用时的两种书写方式完整实现代…

Ant design Vue 表格中显示不同的状态(多条件显示)

比如&#xff1a;后端一个字段有多种状态&#xff1a; 那么后端接口会返回&#xff1a;0 或者 1 或者 2 其中一个&#xff0c;前端需要展示的是对应的文字&#xff0c;像简单的只有两个状态的可以直接在列里面操作&#xff1a; {title: 状态,dataIndex: usable,customRender: …

国产测温速度快且功耗低的温度传感芯片MY18E20可Pin-Pin替换DS18B20

MY18E20是一款国产高精度可编程的数字模拟混合信号温度传感芯片&#xff1b;感温原理基于CMOS半导体PN节温度与带隙电压的特性关系&#xff0c;经过小信号放大、模数转换、数字校准补偿后&#xff0c;数字总线输出&#xff0c;具有精度高、一致性好、测温快、功耗低、可编程配置…

剖析影响大米码垛机性能的关键因素

在现代化粮食加工产业链中&#xff0c;大米码垛机以其高效、精准的自动化操作&#xff0c;成为提升生产效率、降低劳动强度的得力助手。然而&#xff0c;要想充分发挥大米码垛机的性能优势&#xff0c;我们必须深入了解影响其性能的关键因素。星派将深入剖析这些关键因素&#…

1.XSS-反射型(get)

反射型XSS或不持久型XSS 含义&#xff1a;交互的数据一般不会被存在在数据库里面,只是简单的把用户输入的数据反射给浏览器&#xff0c;一次性&#xff0c;所见即所得。 进行测试一下页面是否存在xss漏洞&#xff0c;进行测试一下 "<>6666输入一些字符或者符号&am…

YOLOv10改进 | Neck | 添加双向特征金字塔BiFPN【含二次独家创新】

&#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 专栏目录&#xff1a;《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有40篇内容&#xff0c;内含各种Head检测头、损失函数Loss、B…

【面试干货】Java的基础类型和字节大小

【面试干货】Java的基础类型和字节大小 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在Java编程语言中&#xff0c;有八种基本数据类型&#xff0c;它们分别是&#xff1a;布尔型&#xff08;boolean&#xff09;、字节型&#xff08;byt…

前端架构(含演进历程、设计内容、AI辅助设计、架构演进历程)

前端架构的演进历程 前端架构的设计内容 技术选型(库、工具、标准规范、性能、安全、扩展性 )设计模式及代码组织(模块化、分层架构、数据结构)构建与部署(性能优化、自动化、集成、测试) 项目体量 小型项目&#xff1a;对于小型项目&#xff0c;前端架构需要关注模块化、组件…

可视化大屏开发涉及到的8大技术栈,收藏起来不迷路。

可视化大屏开发设计涉及到的技术栈包括&#xff1a; 前端开发技术&#xff1a; HTML、CSS、JavaScript、TypeScript等&#xff0c;用于构建用户界面和实现交互效果。 数据可视化库&#xff1a; 如D3.js、ECharts、Highcharts等&#xff0c;用于将数据以图表、地图、仪表盘等…

Hibernate 框架进行对象关系映射(ORM)

Hibernate是一个广泛使用的Java对象关系映射&#xff08;ORM&#xff09;框架&#xff0c;它通过将Java类与数据库表关联起来&#xff0c;使得开发人员可以使用面向对象的编程方式进行数据库操作。Hibernate的主要目标是消除冗长的JDBC代码和手动处理SQL的需求&#xff0c;从而…

在C#中对 JSON进行序列化和反序列化处理

概述&#xff1a;在现代软件开发领域&#xff0c;不同系统和平台之间的数据交换是不可或缺的方面。JSON&#xff08;JavaScript 对象表示法&#xff09;因其轻量级、人类可读和易于解析的特性而成为一种无处不在的数据格式。使用 C# &#x1f680;编程的 JSON 序列化和反序列化…