昇思25天学习打卡营第4天|网络构建|函数式自动微分

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

网络构建

神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

下面我们将构建一个用于Mnist数据集分类的神经网络模型。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
import mindspore
from mindspore import nn, ops

定义模型类

当我们定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

construct意为神经网络(计算图)构建,相关内容详见使用静态图加速。

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

构建完成后,实例化Network对象,并查看其结构。

model = Network()
print(model)

我们构造一个输入数据,直接调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。

model.construct()方法不可直接调用。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits

在此基础上,我们通过一个nn.Softmax层实例来获得预测概率。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

使用nn.Cell作为基类来定义自己的神经网络模型Network。在__init__方法中初始化所需的神经网络层,并在construct方法中定义前向传播过程。

实例化自定义的网络类Network,调用模型实例model处理输入数据X,得到预测输出logits。通过nn.Softmax转换为各分类的概率。

模型层

本节中我们分解上节构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

nn.Flatten

实例化nn.Flatten层,将28x28的2D张量转换为784大小的连续数组。

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)

nn.Dense

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

nn.ReLU¶

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

nn.SequentialCell¶

nn.SequentialCell是一个有序的Cell容器。输入Tensor将按照定义的顺序通过所有Cell。我们可以使用SequentialCell来快速组合构造一个神经网络模型。

seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)

nn.Softmax¶

最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)

nn.Flatten展平输入数据,nn.Dense全连接层,nn.ReLU非线性激活函数,nn.SequentialCell有序地组合这些层,形成一个完整的网络结构。nn.Softmax将网络的原始输出转换为概率分布,以进行分类预测。

模型参数

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

更多内置神经网络层详见mindspore.nn API。

查看模型的结构和参数详情。

面向对象编程: 利用MindSpore的nn.Cell基类,使用面向对象的编程风格来构建和管理网络结构。

模块化: 将不同的神经网络层封装成模块,灵活地组合和重用这些模块组合成完整的网络。

调试和可视化: 查看每层的输出和参数对于调试和理解模型有帮助。

函数式自动微分

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。        

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

compute-graph

在这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

执行计算函数,可以获得计算的loss值。

loss = function(x, y, w, b)
print(loss)

根据计算图构造计算函数和神经网络。定义损失函数(二值交叉熵)计算预测值与目标值之间的损失。

微分函数与梯度计算¶

为了优化模型参数,需要求参数对loss的导数:$\frac{\partial \operatorname{loss}}{\partial w}$$\frac{\partial \operatorname{loss}}{\partial b}$,此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤和𝑏求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))

执行微分函数,即可获得𝑤、𝑏对应的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

使用MindSpore的grad函数,获得指定参数位置grad_position的梯度。

Stop Gradient¶

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

可以看到求得𝑤、𝑏对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

可以看到,求得𝑤、𝑏对应的梯度值与初始function求得的梯度值一致。

阻止某个Tensor对梯度的影响,可以使用ops.stop_gradient接口来实现梯度的截断。

Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

可以看到,求得𝑤、𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

在微分函数grad中,除了主要的输出(如loss)之外,还可能有其他的辅助输出。使用has_aux参数可以满足返回辅助数据的同时不影响梯度计算。

神经网络梯度计算

前述章节主要根据计算图对应的函数介绍了MindSpore的函数式自动微分,但我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤𝑤、𝑏𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

接下来我们实例化模型和损失函数。

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

完成后,由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数。

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)
print(grads)

执行微分函数,可以看到梯度值和前文function求得的梯度值一致。

基于nn.Cell的神经网络模型,可以使用value_and_grad接口结合模型的trainable_params()方法来计算梯度。

自动微分:自动微分简化了梯度计算过程,让开发者可以专注于模型设计而非复杂的数学推导。MindSpore通过value_and_gradgrad支持面向对象的模型定义(继承nn.Cell)和函数式自动微分。

计算图:计算图以图形方式表示了函数的运算流程,使得自动微分能够按图进行反向传播。

梯度计算:MindSpore允许对特定的函数输入位置进行梯度计算(通过grad_position参数),并且提供了ops.stop_gradient来控制哪些部分参与梯度计算。通过has_aux=True,可以在计算梯度的同时返回辅助数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/731796.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

借助AI营销类API,实现自动化的营销流程

借助AI营销类API,企业可以实现自动化的营销流程,提高效率和效果,并节省大量的时间和资源。这些API利用人工智能和机器学习的技术,能够自动化地执行各种营销任务和流程。首先,AI营销类API可以帮助企业实现自动化的市场调…

【鸿蒙】创建第⼀个鸿蒙项⽬

点击 Create Project 配置项目 开发工具界面 工程介绍

探索AI前沿:本地部署GPT-4o,打造专属智能助手!

目录 1、获取API_key 2、开始调用 3、openai连接异常 4、解决方法: 5、调用GPT-4o 1、获取API_key 这里就不多赘述了,大家可以参考下面这篇博客 怎么获取OpenAI的api-key【人工智能】https://blog.csdn.net/qq_51625007/article/details/13763274…

大数据与java哪个好找工作?这篇文章帮你做选择!

大数据与java哪个好找工作?这篇文章帮你做选择! 还在为选择Java开发还是Java大数据而头疼吗?别担心,本文将从就业前景、学习方向、学习内容以及薪资待遇四个方面,为你揭开Java和Java大数据的神秘面纱,帮你做…

ZW3D二次开发_删除草图中的实体

1.目前草图中的实体不能直接通过id删除,而是通过entityPath实体路径,所以需要将id转化为实体路径。 2.以下示例代码的主要功能为获取草图中的所有实体并删除: int Count;int *idEnts;ZF_CALL(cvxSkInqGeom(&Count, &idEnts));//获取…

《窄门》情不知所起,而一往情深

《窄门》情不知所起,而一往情深 安德烈纪德(1869-1951),法国作家。纪德一生著有小说、剧本、论文、散文、日记、书信多种,主要作品有小说《背德者》《窄门》《田园交响曲》《伪币制造者》等,戏剧《康多尔王…

x64汇编fastcall调用约定

x64汇编环境:只需要在x86基础上对项目属性进行设置,将平台设置为所有平台; 以及在将debug改为x64模式即可: 后续写完代码直接生成项目再使用本地调试器进行运行即可。 fastcall调用约定 在x64架构下,fastcall调用约定…

android倒计时封装(活动进入后台,倒计时依然能正常计时)

public class TimeUtils { /倒计时时长 单位:秒/ public static int COUNT 20*60; /当前做/ private static int CURR_COUNT 0; /预计结束的时间/ private static long TIME_END 0; /计时器/ private static Timer countdownTimer; /显示倒计时的textVi…

大数据学习-大数据介绍

意义 从海量的数据中分析出海量数据背后的价值 需要分析海量的数据,就需要存储、计算和分析 那就需要分布式多台计算机合适的工具来处理数据 工具 特点 大数据的核心工作:从海量的、高增长的、多类别的、信息密度低的数据中挖掘出高质量的结果 数据存储…

STM32通过SPI软件读写W25Q64

文章目录 1. W25Q64 2. 硬件电路 3. W25Q64框架图 4. 软件/硬件波形对比 5. 代码实现 5.1 MyI2C.c 5.2 MyI2C.h 5.3 W25Q64.c 5.4 W25Q64.h 5.5 W25Q64_Ins.h 5.6 main.c 1. W25Q64 对于SPI通信和W25Q64的详细解析可以看下面这篇文章 STM32单片机SPI通信详解-CSDN博…

开发中遇到的一个bug

遇到的报错信息是这样的: java: Annotation processing is not supported for module cycles. Please ensure that all modules from cycle [hm-api,hm-common,hm-service] are excluded from annotation processing 翻译过来就是存在循环引用的情况,导…

C++进阶之AVL树

个人主页:点我进入主页 专栏分类:C语言初阶 C语言进阶 数据结构初阶 Linux C初阶 C进阶​ ​​​​算法 欢迎大家点赞,评论,收藏。 一起努力,一起奔赴大厂 目录 一.前言 二.插入 三.旋转 3.1右旋 …

postman国内外竞争者及使用详解分析

一、postman简介 Postman 是一款广泛使用的 API 开发和测试工具,适用于开发人员和测试人员。它提供了一个直观的界面,用于发送 HTTP 请求、查看响应、创建和管理 API 测试用例,以及自动化 API 测试工作流程。以下是 Postman 的主要功能和特点…

Docker常用操作和命令

文章目录 1、卸载旧版本 2、yum安装Docker CE(社区版) 3、添加镜像加速器 4、docker --version 查看docker版本 5、docker info 或 docker system info 显示 Docker 系统的详细信息,包括容器、镜像、网络等 6、docker search 搜索镜像 …

JVM类加载器与双亲委派机制

通过上一篇Java的类加载机制相信大家已经搞明白了整个类加载从触发时机,接着我们就来看下类加载器,因为类加载机制是有加载器实现的。 类加载器的分类 启动类加载器 Bootstrap ClassLoader 是 Java 虚拟机(JVM)的一部分&#x…

C#调用OpenCvSharp实现图像的直方图均衡化

本文学习基于OpenCvSharp的直方图均衡化处理方式,并使用SkiaSharp绘制相关图形。直方图均衡化是一种图像处理方法,针对偏亮或偏暗的图像,通过调整图像的像素值来增强图像对比度,详细原理及介绍见参考文献1-4。   直方图均衡化第…

【中学教资科目二】02中学课程

02中学课程 第一节 课程概述1.1 课程的分类 第二节 课程组织2.1 课程内容的文本表现形式2.2 课程评价 第三节 基础教育课程改革3.1 基础教育改革的目标3.2 新课改的课程结构 第一节 课程概述 1.1 课程的分类 学校课程有多种类型,其中最利于学生系统掌握人类所取得的…

多维表格/业务库表格大数据量性能瓶颈

先说最终结论:Angular 组件创建性能损耗是当下主要的性能瓶颈 理由: 基于以往编辑器性能优化的经验,编辑器在动态渲染内容时会创建很多壳子组件(也就是Angular 组件),排查的时候就发现如果略这些壳子组件性…

mysql--安装跳过验证修改密码安全加固

安装mysql 配置mysql的yum源 [rootVM-0-14-rockylinux ~]# tee /etc/yum.repos.d/mysql.repo << EOF > [MYSQL] > namemysql > baseurlhttps://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql-5.7-community-el7-x86_64 > gpgcheck0 > EOF yum安装mysq…

海南聚广众达电子商务咨询有限公司抖音电商新标杆

在数字经济的浪潮中&#xff0c;抖音电商正成为一股不可忽视的力量。海南聚广众达电子商务咨询有限公司&#xff0c;作为专注于抖音电商服务的领军企业&#xff0c;凭借其专业的团队和创新的思维&#xff0c;不断助力商家在抖音平台上实现商业价值的最大化。 海南聚广众达电子…