昇思MindSpore学习总结八——模型保存与加载

        在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,接下来将介绍如何保存与加载模型。

1.构建模型

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    model = nn.SequentialCell(
                nn.Flatten(),
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10))
    return model

这里是没有经过训练的,可以直接用上一节训练的模型model。

2、保存和加载权重

2.1 保存

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

mindspore.save_checkpoint(save_objckpt_file_nameintegrated_save=Trueasync_save=Falseappend_dict=Noneenc_key=Noneenc_mode='AES-GCM'choice_func=None**kwargs)

【参数】

  • save_obj (Union[Cell, list, dict]) - 待保存的对象。数据类型可为 mindspore.nn.Cell 、list或dict。若为list,可以是 Cell.trainable_params() 的返回值,或元素为dict的列表(如[{“name”: param_name, “data”: param_data},…],param_name 的类型必须是str,param_data 的类型必须是Parameter或者Tensor);若为dict,可以是 mindspore.load_checkpoint() 的返回值。

  • ckpt_file_name (str) - checkpoint文件名称。如果文件已存在,将会覆盖原有文件。

  • integrated_save (bool) - 在并行场景下是否合并保存拆分的Tensor。默认值: True 。

  • async_save (bool) - 是否异步执行保存checkpoint文件。默认值: False 。

  • append_dict (dict) - 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是int、float、bool、string、Parameter或Tensor类型。默认值: None 。

  • enc_key (Union[None, bytes]) - 用于加密的字节类型密钥。如果值为 None ,那么不需要加密。默认值: None 。

  • enc_mode (str) - 该参数在 enc_key 不为 None 时有效,指定加密模式,目前仅支持 "AES-GCM" , "AES-CBC" 和 "SM4-CBC" 。默认值: "AES-GCM" 。

  • choice_func (function) - 一个用于自定义控制保存参数的函数。函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被保存。 如果返回 False ,则未匹配自定义条件的Parameter不会被保存。默认值: None 。

  • kwargs (dict) - 配置选项字典。

model = network()
mindspore.save_checkpoint(model, "model.ckpt")

运行之后会在同路径下找到一个文件

2.2 加载

        要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

2.2.1 load_checkpoint

mindspore.load_checkpoint(ckpt_file_namenet=Nonestrict_load=Falsefilter_prefix=Nonedec_key=Nonedec_mode='AES-GCM'specify_prefix=Nonechoice_func=None)

【参数】

  • ckpt_file_name (str) - checkpoint的文件名称。

  • net (Cell) - 加载checkpoint参数的网络。默认值: None 。

  • strict_load (bool) - 是否将严格加载参数到网络中。如果是 False ,它将根据相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行强制精度转换,比如将 float32 转换为 float16 。默认值: False 。

  • filter_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 filter_prefix 开头的参数将不会被加载。默认值: None 。

  • dec_key (Union[None, bytes]) - 用于解密的字节类型密钥,如果值为 None ,则不需要解密。默认值: None 。

  • dec_mode (str) - 该参数仅当 dec_key 不为 None 时有效。指定解密模式,目前支持 "AES-GCM" , "AES-CBC" 和 "SM4-CBC" 。默认值: "AES-GCM" 。

  • specify_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 specify_prefix 开头的参数将会被加载。默认值: None 。

  • choice_func (Union[None, function]) - 函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被加载。 如果返回 False ,则匹配自定义条件的Parameter将被删除。默认值: None 。

2.2.2 load_param_into_net

mindspore.load_param_into_net(netparameter_dictstrict_load=False)将参数加载到网络中,返回网络中没有被加载的参数列表。

【参数】

  • net (Cell) - 将要加载参数的网络。

  • parameter_dict (dict) - 加载checkpoint文件得到的字典。

  • strict_load (bool) - 是否将参数严格加载到网络中。如果是 False , 它将以相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行精度转换,比如将 float32 转换为 float16 。默认值: False 。

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

【运行结果】

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。 

3、保存和加载MindIR

        除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

3.1 保存

mindspore.export(net*inputsfile_namefile_format**kwargs)将MindSpore网络模型导出为指定格式的文件。

【参数】

  • net (Union[Cell, function]) - MindSpore网络结构。

  • inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 Dataset 时,将会把数据预处理行为同步保存起来。需要手动调整batch的大小,当前仅支持获取 Dataset 的 image 列。

  • file_name (str) - 导出模型的文件名称。

  • file_format (str) - MindSpore目前支持导出”AIR”,”ONNX”和”MINDIR”格式的模型。

    • AIR - Ascend Intermediate Representation。一种Ascend模型的中间表示格式。推荐的输出文件后缀是”.air”。

    • ONNX - Open Neural Network eXchange。一种针对机器学习所设计的开放式的文件格式。推荐的输出文件后缀是”.onnx”。

    • MINDIR - MindSpore Native Intermediate Representation for Anf。一种MindSpore模型的中间表示格式。推荐的输出文件后缀是”.mindir”。

  • kwargs (dict) - 配置选项字典。

    • enc_key (byte) - 用于加密的字节类型密钥,有效长度为16、24或者32。

    • enc_mode (Union[str, function]) - 指定加密模式,当设置 enc_key 时启用。

      • 对于 ‘AIR’和 ‘ONNX’格式的模型,当前仅支持自定义加密导出。

      • 对于 ‘MINDIR’格式的模型,支持的加密选项有: ‘AES-GCM’, ‘AES-CBC’, ‘SM4-CBC’和用户自定义加密算法。默认值: "AES-GCM"

      • 关于使用自定义加密导出的详情,请查看 教程。

    • dataset (Dataset) - 指定数据集的预处理方法,用于将数据集的预处理导入MindIR。

    • obf_config (dict) - 模型混淆配置选项字典。

      • type (str) - 混淆类型,目前支持动态混淆,即 ‘dynamic’ 。

      • obf_ratio (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串 "small" 、 "medium" 、 "large" 。"small" 、"medium" 、"large" 分别对应于 0.1、0.3、0.6。

      • customized_func (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置(请查看 动态混淆教程 中的 my_func())。如果设置了 customized_func ,那么在使用 load 接口导入模型的时候,需要把这个函数也传入。

      • obf_random_seed (int) - 混淆随机种子,是一个取值范围为(0, 9223372036854775807]的整数,不同的随机种子会使模型混淆后的结构不同。如果用户设置了 obf_random_seed ,那么在部署混淆模型的时候,需要在调用 mindspore.nn.GraphCell 接口中传入 obf_random_seed 。需要注意的是,如果用户同时设置了 customized_func 和 obf_random_seed ,那么后一种模式将会被采用。

    • custom_func (function) - 用户自定义的导出策略的函数。该函数会在网络导出时,对模型使用该函数进行自定义处理。需要注意,当前仅支持对 format 为 MindIR 的文件使用 custom_func ,且自定义函数仅接受一个代表 MindIR 文件 Proto 对象的入参。当使用 custom_func 对模型进行修改时,需要保证修改后模型的正确性,否则可能导致模型加载失败或功能错误。默认值: None 。

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

3.2 加载

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。nn.GraphCell仅支持图模式。

mindspore.nn.GraphCell(graphparams_init=Noneobf_random_seed=None)

运行从MindIR加载的计算图。

此功能仍在开发中。目前 GraphCell 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。

【参数】

  • graph (FuncGraph) - 从MindIR加载的编译图。

  • params_init (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值: None 。

  • obf_random_seed (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 mindspore.obfuscate_model() 。如果导入的 graph 是一个经过混淆的模型,那么须提供 obf_random_seed 。 obf_random_seed 的取值范围是(0, 9223372036854775807]。默认值: None 。

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

 这里时间改了一下,之前差8个小时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/767942.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【RocketMQ】记录一次RocketMQ消费延迟问题排查思路

文章目录 背景问题排查Consumer负载均衡机制订阅关系的一致 背景 业务团队反馈使用我提供的RocketMQ集群,上游生产的消息,部分消息,消费程序需要等1分钟,甚至几分钟后,才能收到。 问题排查 见怪不怪,大部…

TapData 医疗美容行业数字化白皮书上线

随着颜值经济的发展,人们追求“美”的热情也不断高涨,而作为近年来被资本频频下注的赛道之一,医疗美容逐渐走进大众视野。 一方面,社会对外貌的重视程度不断提升,越来越多的人愿意通过医疗美容来提升自身形象&#xf…

后端加前端Echarts画图示例全流程(折线图,饼图,柱状图)

本文将带领读者通过一个完整的Echarts画图示例项目,演示如何结合后端技术(使用Spring Boot框架)和前端技术(使用Vue.js或React框架)来实现数据可视化。我们将实现折线图、饼图和柱状图三种常见的数据展示方式&#xff…

【python】PyQt5事件机制、定时器原理分析和实战演练

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

哨兵1SAR空间数据包协议数据单元文档(五)

《哨兵1SAR空间数据包协议数据单元》文档对数据包的结构进行了详细描述,并提供了用户数据的格式和解码算法。 原文链接: 哨兵1SAR空间数据包协议数据单元文档英文版 同系列中的其他文章篇链接: 哨兵1SAR空间数据包协议数据单元文档(一) 哨兵1…

亚太万人eVTOL展!2024深圳eVTOL将于9月登陆鹏城

2024年以来,北京、上海等十多个省市,先后发布了鼓励低空经济发展的行动方案,其中,eVTOL(电动垂直起降航空器)成为低空经济最火热的细分赛道。2023年,中国eVTOL产业规模达9.8亿元,同比…

【C++】相机标定源码笔记- RGB 相机与 ToF 深度传感器校准类

类的设计目标是为了实现 RGB 相机与 ToF 深度传感器之间的高精度校准,从而使两种类型的数据能够在同一个坐标框架内被整合使用。这在很多场景下都是非常有用的,比如在3D重建、增强现实、机器人导航等应用中,能够提供更丰富的场景信息。 -----…

64位Office API声明语句第120讲

跟我学VBA,我这里专注VBA, 授人以渔。我98年开始,从源码接触VBA已经20余年了,随着年龄的增长,越来越觉得有必要把这项技能传递给需要这项技术的职场人员。希望职场和数据打交道的朋友,都来学习VBA,利用VBA,起码可以提高…

工厂模式之简单工厂模式

文章目录 工厂模式工厂模式分为工厂模式的角色简单工厂模式案例代码定义一个父类,三个子类定义简单工厂客户端使用输出结果 工厂模式 工厂模式属于创造型的模式,用于创建对象。 工厂模式分为 简单工厂模式:定义一个简单工厂类,根…

Vue3的模板语法插值表达式用法

在template中输入“5 3” &#xff0c;是没有运算能力的&#xff0c;只会把字符直接显示出来&#xff0c;代码如下&#xff1a; <template><view>这是demo</view><view>5 3</view><navigator open-type"navigateBack"><vi…

永劫无间国服延迟高、报错、卡顿的处理措施一览

永劫无间国服延迟高、报错、卡顿怎么办&#xff1f;快速解决办法分享 第一个办法&#xff1a;改善延迟 如果是一直遇到永劫无间国服延迟高、报错、卡顿的问题&#xff0c;重启游戏也不管用的话&#xff0c;那应该就是网络问题&#xff0c;玩家可以启动雷神&#xff0c;让其快速…

服了!DELETE 同一行记录也会造成死锁!!

1 问题背景 “哥们&#xff0c;又双叒叕写了个死锁&#xff0c;秀啊&#xff01;&#x1f60f;” 就算是经常写死锁的同学看到估计都会有点懵&#xff0c;两条一模一样的 DELETE 语句怎么会产生死锁呢&#xff1f; 2 MySQL 锁回顾 看到这里的靓仔肯定对 MySQL 的锁非常了解&…

红酒与建筑:品味历史与艺术的交汇

在时间的长河中&#xff0c;红酒与建筑都是人类智慧的结晶&#xff0c;它们各自承载着历史的厚重与艺术的韵味。当这两者交汇时&#xff0c;仿佛是一场穿越时空的对话&#xff0c;将我们带入一个既古老又现代、既深沉又温柔的世界。今天&#xff0c;就让我们一起走进这个奇妙的…

企业消费采购成本和员工体验如何实现“鱼和熊掌“的兼得?

有企业说企业消费采购成本和员工体验的关系好比是“鱼和熊掌”&#xff0c;无法兼得&#xff1f; 要想控制好成本就一定要加强管控&#xff0c;但是加强管控以后&#xff0c;就会很难让员工获得满意的体验度。如果不加以管控&#xff0c;员工自由度增加了&#xff0c;往往就很难…

为什么要在成像应用中使用图像采集卡?

达到最大产量是工业和工厂自动化的关键标准之一。提高传感器分辨率和帧速率有助于实现这一目标&#xff0c;但也使带宽达到极限&#xff0c;并提出了新的传输问题。当前高带宽接口(如10GigE、相机直接与PC连接和嵌入式系统)的实现促使成像应用的许多用户询问如何以最佳配置最优…

Day63 代码随想录打卡|回溯算法篇---电话号码的字母组合

题目&#xff08;leecode T17&#xff09;&#xff1a; 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 方法&#xff1a;…

CCS方形低角度光源实现更均匀的照射

机器视觉系统中&#xff0c;光源设计作为图像成像效果的关键&#xff0c;今天的光源系列分享——FPQ3系列。 特点&#xff1a; ・从4个方向以低角度照射均匀扩射光的方型低角度光源。 ・实现比上一代产品2倍的高输出。白色和蓝色的亮度提高至2倍&#xff0c;红色的亮度提高至…

app单页下载页源码带管理后台

新版带后台管理APP应用下载页,自动识别安卓苹果下载页&#xff0c;带管理后台&#xff0c;内置带3套App下载模板带中文模板/英文模板随时切换。 app单页下载页源码带管理后台

保存到redis中的token乱码了

示图&#xff1a; 原因是缓存保存到redis需要序列化操作&#xff0c;没有序列化会出现这样的问题 序列化redis第一步&#xff1a; package com.abliner.test.configure.redis;import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annota…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【17】认证服务01

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【17】认证服务01 环境搭建验证码倒计时短信服务邮件服务验证码短信形式&#xff1a;邮件形式&#xff1a; 异常机制MD5参考 环境搭建 C:\Windows\System32\drivers\etc\hosts 192.168.…