《昇思25天学习打卡营第6天 | 函数式自动微分》

《昇思25天学习打卡营第6天 | 函数式自动微分》

目录

  • 《昇思25天学习打卡营第6天 | 函数式自动微分》
    • 函数式自动微分
    • 简单的单层线性变换模型
    • 函数与计算图
    • 微分函数与梯度计算
    • Stop Gradient

函数式自动微分

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口grad和value_and_grad。

简单的单层线性变换模型

我们通过学习使用一个简单的单层线性变换模型来了解

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

compute-graph
在这个模型中, 𝑥
为输入, 𝑦
为正确值, 𝑤
和 𝑏
是我们需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

执行计算函数,可以获得计算的loss值。

loss = function(x, y, w, b)
print(loss)

Tensor(shape=[], dtype=Float32, value= 0.914285)

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数: ∂ loss ⁡ ∂ w \frac{\partial \operatorname{loss}}{\partial w} wloss ∂ loss ⁡ ∂ b \frac{\partial \operatorname{loss}}{\partial b} bloss,此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对 w w w b b b求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))

执行微分函数,即可获得 𝑤
、 𝑏
对应的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]]),
Tensor(shape=[3], dtype=Float32, value= [ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]))

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]]),
Tensor(shape=[3], dtype=Float32, value= [ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]))

可以看到求得 w w w b b b对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/760502.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于ssm口红商城管理的设计与实现

一、👨‍🎓网站题目 口红商城项目可以提供更加便捷和高效的购物方式。消费者可以在家中使用电脑或手机随时随地购物,避免了传统购物方式中需要花费时间和精力去实体店铺购物的麻烦。此外,口红商城项目还提供了更多的选择和更低的…

JavaSE (Java基础):面向对象(上)

8 面向对象 面向对象编程的本质就是:以类的方法组织代码,以对象的组织(封装)数据。 8.1 方法的回顾 package com.oop.demo01;// Demo01 类 public class Demo01 {// main方法public static void main(String[] args) {int c 10…

2023年的Facebook营销:超级完整指南

Facebook营销不是可选的,是必须的。Facebook是世界上使用最多的社交平台,每天吸引22.9亿活跃用户。 它也不全是度假照片和虚张声势。对于53.2% 的 16-24 岁互联网用户,社交媒体是他们进行品牌研究的主要来源。而且,66% 的 Facebo…

GoSync+华为智能穿戴使用指导

GoSync官方简介: GoSync 是一款免费应用程序,主要用于将您的可穿戴设备中的步行、跑步、骑自行车和游泳等活动数据同步到您的 Google Fit 和其他健身平台。在开始同步数据之前,您需要将您的可穿戴设备账户与您的健身平台账户连接起来。在创建…

WEB攻防【5】——JS项目/Node.js框架安全/识别审计/验证绕过

1、test.php和test.html对比 #知识点: 1、原生JS&开发框架-安全条件 2、常见安全问题-前端验证&未授权 #详细点: 1、什么是Js渗透测试? 在Javascript中也存在变量和函数,当存在可控变量及函数调用即可参数漏洞 JS开发的WEB应用和PHP、java.NET…

WebRtc实现1V1音视频通话

简介 WebRTC,名称源自网页实时通信(Web Real-Time Communication)的缩写,是一个支持网页浏览器进行实时语音通话或视频聊天的技术,是谷歌 2010 年以 6820 万美元收购 Global IP Solutions 公司而获得的一项技术。 WebR…

STM32之四:TIM定时器(1-基本定时器)

目录 1. STM32有哪些定时器 2. 基本定时器 2.1 基本定时器主要功能 2.2 基本定时器的框图 2.2.1 时钟输入 2.2.2 至DAC 2.2.3 至时基单元(重点) 2.2.4 影子寄存器 2.2.5 基本定时器寄存器说明 2.2.5.1 控制寄存器1(TIMx_CR1&#x…

压缩包怎么解压,解压压缩包不损坏文件

常见格式: ZIP:最常见的压缩文件格式之一,支持跨平台。RAR:另一种常见的压缩文件格式,通常压缩率比ZIP高,但不如ZIP普及。7Z:来自7-Zip的压缩格式,支持更高的压缩率和一些高级特性。…

设计NOR Flash(FMSC接口)的Flashloader(MCU: stm32f4)

目录 概述 1 硬件 1.1 MCU和S29GL128P10TFI01控制电路 1.1.1 S29GL128P10TFI01 1.1.2 MCU与NOR Flash接口 1.2 STM32F4的FSMC接口 1.2.1 时序信号 1.2.2 外部存储器接口信号 2 Flash leader功能实现 2.1 框架结构介绍 2.2 S29GL128P10TFI01的Flash leader框架 2.3 N…

人工智能期末复习笔记(更新中)

分类问题 分类:根据已知样本的某些特征,判断一个新的样本属于哪种已知的样本类 垃圾分类、图像分类 怎么解决分类问题 分类和回归的区别 1. 逻辑回归分类 用于解决分类问题的一种模型。根据数据特征或属性,计算其归属于某一类别 的概率P,…

66、基于长短期记忆 (LSTM) 网络对序列数据进行分类

1、基于长短期记忆 (LSTM) 网络对序列数据进行分类的原理及流程 基于长短期记忆(LSTM)网络对序列数据进行分类是一种常见的深度学习任务,适用于处理具有时间或序列关系的数据。下面是在Matlab中使用LSTM网络对序列数据进行分类的基本原理和流…

XJTUSE-数据结构-homework1

任务 1 题目: 排序算法设计: 需要写Selection、Shell、Quicksort 和 Mergesort四种排序算法,书上讲述比较全面而且不需要进行额外的优化,下面我简要地按照自己的理解讲述。 Selection(选择排序)&#xff…

HarmonyOS Next开发学习手册——单选框 (Radio)

Radio是单选框组件,通常用于提供相应的用户交互选择项,同一组的Radio中只有一个可以被选中。具体用法请参考 Radio 。 创建单选框 Radio通过调用接口来创建,接口调用形式如下: Radio(options: {value: string, group: string})…

Linux常用工具使用方式

目录 常用工具: 安装包管理工具: 查找含有关键字的软件包 安装软件 安装文件传输工具 安装编辑器 C语言编译器 C编译器 安装调试器 安装项目版本管理工具 cmake 卸载软件 安装jsoncpp 安装boost库 安装mariadb 安装tree(让目录…

Python28-3 朴素贝叶斯分类算法

朴素贝叶斯算法简介 朴素贝叶斯(Naive Bayes)算法是一种基于贝叶斯定理的分类算法。它广泛应用于文本分类、垃圾邮件检测和情感分析等领域。该算法假设特征之间是独立的,这个假设在实际情况中可能并不完全成立,但Naive Bayes在许…

java笔记(30)——反射的 API 及其 使用

文章目录 反射1. 什么是反射2. 获取class字段(字节码文件对象)方式1方式2方式3应用 3. 获取构造方法和权限修饰符前期准备获取所有的公共构造方法获取所有的构造方法获取无参构造方法获取一个参数的构造方法获取一个参数的构造方法获取两个参数的构造方法…

Java面试题--JVM大厂篇之G1 GC的分区管理方式如何减少应用线程的影响

目录 引言: 正文: 1. 区域划分(Region) 2. 并行和并发回收 3. 区域优先回收(Garbage First) 4. 可预测的停顿时间 5. 分阶段回收 6. 复制和压缩 实际效果: 场景举例 1. 减少单次GC的影响 2. 支持高并发环境 3. 优…

数学建模(1):期末大乱炖

1 概述!! 1.1 原型和模型 原型:客观存在的研究对象称为原型,也称为“系统”、“过程”。 机械系统、电力系统、化学反应过程、生产销售过程等都是原型; 研究原型的结构和原理, 从而进行优化、预测、评价…

一区算法MPA|海洋捕食者算法原理及其代码实现(Matlab/Python))

Matlab/Python: 本文KAU将介绍一个2020年发表在1区期刊ESWA上的优化算法——海洋捕食者算法 (Marine Predators Algorithm,MPA)[1] 该算法由Faramarzi等于2020年提出,其灵感来源于海洋捕食者之间不同的觅食策略、最佳相遇概率策略、海洋记…

【MySQL】Linux下MySQL的目录结构、用户、权限与角色

一、Linux下MySQL的目录结构 1、MySQL相关目录 数据库文件存放路径:/var/lib/mysql数据库命令存放路径:/user/bin和/user/sbin配置文件目录:/usr/share/mysql-8.0/、/usr/share/mysql/和/etc/my.cnf 2、假设我们创建了一个数据库dbtest1&a…