昇思25天学习打卡营第8天 | 保存与加载 使用静态图加速

保存与加载

在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,下面是介绍如何保存与加载模型。

先定义一个模型用:

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    model = nn.SequentialCell(
                nn.Flatten(),
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10))
    return model

保存和加载模型权重

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

model = network()
mindspore.save_checkpoint(model, "model.ckpt")

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

保存和加载MindIR

除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。

nn.GraphCell仅支持图模式。

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

代码实现:

使用静态图加速

背景介绍

AI编译框架分为两种运行模式,分别是动态图模式以及静态图模式。MindSpore默认情况下是以动态图模式运行,但也支持手工切换为静态图模式。两种运行模式的详细介绍如下:

动态图模式

动态图的特点是计算图的构建和计算同时发生(Define by run),其符合Python的解释执行方式,在计算图中定义一个Tensor时,其值就已经被计算且确定,因此在调试模型时较为方便,能够实时得到中间结果的值,但由于所有节点都需要被保存,导致难以对整个计算图进行优化。

在MindSpore中,动态图模式又被称为PyNative模式。由于动态图的解释执行特性,在脚本开发和网络流程调试过程中,推荐使用动态图模式进行调试。 如需要手动控制框架采用PyNative模式,可以通过以下代码进行网络构建:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
ms.set_context(mode=ms.PYNATIVE_MODE)  # 使用set_context进行动态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)

静态图模式

相较于动态图而言,静态图的特点是将计算图的构建和实际计算分开(Define and run)。有关静态图模式的运行原理,可以参考静态图语法支持。

在MindSpore中,静态图模式又被称为Graph模式,在Graph模式下,基于图优化、计算图整图下沉等技术,编译器可以针对图进行全局的优化,获得较好的性能,因此比较适合网络固定且需要高性能的场景。

如需要手动控制框架采用静态图模式,可以通过以下代码进行网络构建:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
ms.set_context(mode=ms.GRAPH_MODE)  # 使用set_context进行运行静态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)

静态图模式的使用场景

MindSpore编译器重点面向Tensor数据的计算以及其微分处理。因此使用MindSpore API以及基于Tensor对象的操作更适合使用静态图编译优化。其他操作虽然可以部分入图编译,但实际优化作用有限。另外,静态图模式先编译后执行的模式导致其存在编译耗时。因此,如果函数无需反复执行,那么使用静态图加速也可能没有价值

有关使用静态图来进行网络编译的示例,请参考网络构建。

静态图模式开启方式

通常情况下,由于动态图的灵活性,我们会选择使用PyNative模式来进行自由的神经网络构建,以实现模型的创新和优化。但是当需要进行性能加速时,我们需要对神经网络部分或整体进行加速。MindSpore提供了两种切换为图模式的方式,分别是基于装饰器的开启方式以及基于全局context的开启方式

基于装饰器的开启方式

MindSpore提供了jit装饰器,可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图,通过图优化等技术提高运行速度。此时我们可以简单的对想要进行性能优化的模块进行图编译加速,而模型其他部分,仍旧使用解释执行方式,不丢失动态图的灵活性。无论全局context是设置成静态图模式还是动态图模式,被jit修饰的部分始终会以静态图模式进行运行。

在需要对Tensor的某些运算进行编译加速时,可以在其定义的函数上使用jit修饰器,在调用该函数时,该模块自动被编译为静态图。需要注意的是,jit装饰器只能用来修饰函数,无法对类进行修饰。jit的使用示例如下:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))

@ms.jit  # 使用ms.jit装饰器,使被装饰的函数以静态图模式运行
def run(x):
    model = Network()
    return model(x)

output = run(input)
print(output)

除使用修饰器外,也可使用函数变换方式调用jit方法,示例如下:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))

def run(x):
    model = Network()
    return model(x)

run_with_jit = ms.jit(run)  # 通过调用jit将函数转换为以静态图方式执行
output = run(input)
print(output)

当我们需要对神经网络的某部分进行加速时,可以直接在construct方法上使用jit修饰器,在调用实例化对象时,该模块自动被编译为静态图。示例如下:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    @ms.jit  # 使用ms.jit装饰器,使被装饰的函数以静态图模式运行
    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
model = Network()
output = model(input)
print(output)

基于context的开启方式

context模式是一种全局的设置模式。代码示例如下:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
ms.set_context(mode=ms.GRAPH_MODE)  # 使用set_context进行运行静态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)

静态图的语法约束

在Graph模式下,Python代码并不是由Python解释器去执行,而是将代码编译成静态计算图,然后执行静态计算图。因此,编译器无法支持全量的Python语法。MindSpore的静态图编译器维护了Python常用语法子集,以支持神经网络的构建及训练。详情可参考静态图语法支持。

JitConfig配置选项

在图模式下,可以通过使用JitConfig配置选项来一定程度的自定义编译流程,目前JitConfig支持的配置参数如下:

  • jit_level: 用于控制优化等级。
  • exec_mode: 用于控制模型执行方式。
  • jit_syntax_level: 设置静态图语法支持级别,详细介绍请见静态图语法支持。

代码实现

目录

保存与加载

保存和加载模型权重

保存和加载MindIR

代码实现:

使用静态图加速

背景介绍

动态图模式

静态图模式

静态图模式的使用场景

静态图模式开启方式

基于装饰器的开启方式

基于context的开启方式

静态图的语法约束

JitConfig配置选项

代码实现


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/749530.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

grpc学习golang版( 五、多proto文件示例)

系列文章目录 第一章 grpc基本概念与安装 第二章 grpc入门示例 第三章 proto文件数据类型 第四章 多服务示例 第五章 多proto文件示例 第六章 服务器流式传输 文章目录 一、前言二、定义proto文件2.1 公共proto文件2.2 语音唤醒proto文件2.3 人脸唤醒proto文件2.4 生成go代码2.…

最佳Google Chrome扩展和Mozilla Firefox扩展自动解决验证码

在这个信息爆炸的时代,我们每天都要处理大量的在线内容,验证码已成为不可避免的挑战。尽管它们旨在保护网站安全,但也常常成为我们获取信息的障碍。那么,有没有更简单的方法绕过这些验证码呢?答案是肯定的。通过使用一…

恭喜朱雀桥的越南薇妮她牌NFC山竹汁饮料,成为霸王茶姬奶茶主材

朱雀桥NFC山竹汁饮料:荣登霸王茶姬奶茶主材,非遗传承的天然之选 近日,据小编了解到:霸王茶姬欣喜地宣布,成功与朱雀桥达成合作越南薇妮她VINUT牌NFC山竹汁饮料。这款商超产品凭借其卓越的品质与独特的口感&#xff0c…

小项目——MySQL集训(学生成绩录入)

ddl语句 -- 创建学生信息表 CREATE TABLE students (student_id INT AUTO_INCREMENT PRIMARY KEY COMMENT 学生ID,name VARCHAR(50) NOT NULL COMMENT 学生姓名,gender ENUM(男, 女) NOT NULL COMMENT 性别,class VARCHAR(50) NOT NULL COMMENT 班级,registration_date DATE CO…

【Termius】详细说明MacOS中的SSH的客户端利器Termius

希望文章能给到你启发和灵感~ 如果觉得有帮助的话,点赞+关注+收藏支持一下博主哦~ 阅读指南 开篇说明一、基础环境说明1.1 硬件环境1.2 软件环境二、软件的安装2.1 Termius界面介绍2.1.1 Hosts 主机列表2.1.2 SFTP 文件传输2.1.3 Port ForWarding 端口转发2.1.4 Snippets 片…

想要打造高效活跃的私域社群,这些技巧要知道

对一些企业来说“做社群等于做私域”。 在腾讯提到的私域转化场景中,社群与小程序、官方导购三者并列。 社群连接着品牌和群内用户。品牌通过圈住更多用户,来持续免费触达用户实现变现,用户则是从品牌方手中直接获取更多服务和优惠。那么&a…

LabVIEW中卡尔曼滤波的作用与意义

卡尔曼滤波(Kalman Filter)是一种在控制系统和信号处理领域广泛应用的递推滤波算法,能够在噪声环境下对动态系统的状态进行最优估计。其广泛应用于导航、目标跟踪、图像处理、经济预测等多个领域。本文将详细介绍卡尔曼滤波在LabVIEW中的作用…

手机越用越慢?试试这4个秘籍,让手机流畅如新

智能手机作为日常生活的得力助手,最初总是以惊人的速度和流畅性给我们留下深刻印象。 但你有没有发现,随着时间的推移,手机似乎开始变得不那么敏捷,甚至出现了反应迟缓和卡顿的情况? 别让这个问题困扰你,下面是四个关…

基于springboot、vue影院管理系统

设计技术: 开发语言:Java数据库:MySQL技术:SpringbootMybatisvue 工具:IDEA、Maven、Navicat 主要功能: 影城管理系统的主要使用者分为管理员和用户, 实现功能包括管理员: 首页…

从一道算法题开始,爱上Python编程

Python是一门简单易学、高效强大的编程语言,许多人因为它的便捷性和广泛应用而爱上编程。今天,我将通过一道有趣的算法题,带领大家一步步写出Python代码,并最终解决问题。希望通过这篇文章,能激发大家对Python编程的兴…

vue2+webpack 和 vite+vue3 配置获取环境变量(补充)

相关涉及知识点可看小编该文章: nginx: 部署前端项目的详细步骤(vue项目build打包nginx部署)_前端工程打包部署到nginx-CSDN博客 1.vue2webpack 我们通常会在项目中看到这么两个文件(没有则自己创建,文件名:.env.***) …

WITS核心价值观【创新】篇|从财务中来,到业务中去

「客尊」、「诚信」、「创新」 与「卓越」 是纬创软件的核心价值观。我们秉持诚信态度,致力于成为客户长期且值得信赖的合作伙伴。持续提升服务厚度,透过数字创新实践多市场的跨境交付,助客户保持市场领先地位。以追求卓越的不懈精神&#xf…

C++项目实践学习笔记---DLL

linux守护进程 守护进程或精灵进程(Daemon):以后台服务方式运行的进程,它们不占用终端(Shell),因此不会受终端输入或其他信号(如中断信号)的干扰守护进程有如下特点。 &…

如何做好新闻软文宣发媒体资源筛选?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 媒体宣传加速季,100万补贴享不停,一手媒体资源,全国100城线下落地执行。详情请联系胡老师。 新闻软文宣发是指企业通过创造或利用新闻事件&#xff0c…

【JAVA多线程】JMM,成体系聊一下JAVA线程安全问题

目录 1.什么是JMM 2.现代计算机架构 3.线程安全理论 3.1.造成线程安全问题的原因 3.1.1.内存不可见 3.1.2.重排序 3.2.解决思路 3.2.1.as-if-serial 3.2.2.happen-before 4.线程安全实现 4.1.Synchronized 4.2.volatile 1.什么是JMM JMM,全称为Java Me…

胶质瘤的发病原因及诊断方式有哪些?

胶质瘤,这个听起来有些陌生的名词,实际上是一种起源于神经胶质细胞的常见脑肿瘤。它的发病原因复杂,涉及遗传、环境、年龄及感染等多种因素。 首先,遗传因素在胶质瘤的发病中占据一席之地。某些遗传性疾病,如结节性硬化…

【代码阅读】SSC:Semantic Scan Context for Large-Scale Place Recognition

一、主函数 官方开源的代码提供了四个主函数,其中eval_pair.cpp和eval_top1.cpp是一组,分别用于计算两帧的相似度分数以及一帧点云在所有的51帧点云中相似度最高的25帧的相似度分数。eval_seq.cpp是在eval_top1.cpp的基础上,给了一堆序列&am…

【Java Web】会话管理

目录 一、为什么需要会话管理? 二、会话管理机制 三、Cookie概述 四、HttpSession概述 4.1 HttpSession时效性 一、为什么需要会话管理? HTTP协议在设计之初就是无状态的,所谓无状态就是在浏览器和服务器之间的通信过程中,服务器并…

简过网:上万元的学费,考公到底要不要报个培训班?

考公报不报班一直是很多朋友比较纠结一件事,报班了学费太贵,不报班又怕考不上,如果你也有这种困扰,那么,不妨看看这篇文章! 首先,对于报班VS自学这个问题,小编的建议是:…

LICEcap-开源GIF 屏幕录制工具

LICEcap-开源GIF 屏幕录制工具 开源GIF 屏幕录制工具 下载可以访问:https://www.cockos.com/licecap/ 点击Record,开始录制 点击Stop,停止录制 点击Record,进入该页面 display in animation(在动画中显示) …