昇思MindSpore学习入门-函数式自动微分

函数式自动微分

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口grad和value_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

在这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数。

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

执行计算函数,可以获得计算的loss值。

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数:

,此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤和𝑏求导,因此配置其在function入参对应的位置(2, 3)。

执行微分函数,即可获得𝑤、𝑏对应的梯度。

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

可以看到求得𝑤、𝑏对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

可以看到,求得𝑤、𝑏对应的梯度值与初始function求得的梯度值一致。

Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

grad和value_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

可以看到,求得𝑤、𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

神经网络梯度计算

前述章节主要根据计算图对应的函数介绍了MindSpore的函数式自动微分,但我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤、𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

接下来我们实例化模型和损失函数。

由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数。

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/758640.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

论文解读:【CVPR2024】DUSt3R: Geometric 3D Vision Made Easy

论文“”https://openaccess.thecvf.com/content/CVPR2024/papers/Wang_DUSt3R_Geometric_3D_Vision_Made_Easy_CVPR_2024_paper.pdf 代码:GitHub - naver/dust3r: DUSt3R: Geometric 3D Vision Made Easy DUSt3R是一种旨在简化几何3D视觉任务的新框架。作者着重于…

Java高级重点知识点-17-异常

文章目录 异常异常处理自定义异常 异常 指的是程序在执行过程中,出现的非正常的情况,最终会导致JVM的非正常停止。Java处 理异常的方式是中断处理。 异常体系 异常的根类是 java.lang.Throwable,,其下有两个子类:ja…

实验4 图像空间滤波

1. 实验目的 ①掌握图像空间滤波的主要原理与方法; ②掌握图像边缘提取的主要原理和方法; ③了解空间滤波在图像处理和机器学习中的应用。 2. 实验内容 ①调用 Matlab / Python OpenCV中的函数,实现均值滤波、高斯滤波、中值滤波等。 ②调…

java基于ssm+jsp 多用户博客个人网站

1管理员功能模块 管理员登录,管理员通过输入用户名、密码等信息进行系统登录,如图1所示。 图1管理员登录界面图 管理员登录进入个人网站可以查看;个人中心、博文类型管理、学生博客管理、学生管理、论坛信息、管理员管理、我的收藏管理、留…

Linux多进程和多线程(一)-进程的概念和创建

进程 进程的概念进程的特点如下进程和程序的区别LINUX进程管理 getpid()getppid() 进程的地址空间虚拟地址和物理地址进程状态管理进程相关命令 ps toppstreekill 进程的创建 并发和并行fork() 父子进程执行不同的任务创建多个进程 进程的退出 exit()和_exit() exit()函数让当…

微短剧市场还能火多久?短剧小程序是否有必要搭建?,现在入场到底晚不晚?

我公司在2019年开始都是做软件开发的,从2022到现在(2024)特别深有体会,在2022年的时候我公司还是在全部做外包项目,一年大概遇到了10多个咨询短剧领域的软件定制,但是当时我只是以为是一个影视播放的程序&a…

7.优化算法之分治-快排归并

0.分治 分而治之 1.颜色分类 75. 颜色分类 - 力扣(LeetCode) 给定一个包含红色、白色和蓝色、共 n 个元素的数组 nums ,原地对它们进行排序,使得相同颜色的元素相邻,并按照红色、白色、蓝色顺序排列。 我们使用整数…

推动多模态智能模型发展:大型视觉语言模型综合多模态评测基准

随着人工智能技术的飞速发展,大型视觉语言模型(LVLMs)在多模态应用领域取得了显著进展。然而,现有的多模态评估基准测试在跟踪LVLMs发展方面存在不足。为了填补这一空白,本文介绍了MMT-Bench,这是一个全面的…

Django 模版继承

1&#xff0c;设计母版页 Test/templates/6/base.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><!-- 修正了模板标签的全角字符问题 -->{% block title %}<title>这个是母版页</title>{…

leetCode.93. 复原 IP 地址

leetCode.93. 复原 IP 地址 题目思路&#xff1a; 代码 // 前导零的判断方法&#xff1a;如果第一个数是0&#xff0c;且第二个数还有数据&#xff0c;那就是前导0&#xff0c;要排除的 // 注意跟单个 0 区分开 class Solution { public:vector<string> res;vector<…

Opencv+python模板匹配

我们经常玩匹配图像或者找相似&#xff0c;opencv可以很好实现这个简单的小功能。 模板是被查找目标的图像&#xff0c;查找模板在原始图像中的哪个位置的过程就叫模板匹配。OpenCV提供的matchTemplate()方法就是模板匹配方法&#xff0c;其语法如下&#xff1a; result cv2.…

【活动感想】筑梦之旅·AI共创工坊 workshop 会议回顾

目录 &#x1f30a;1. 会议详情 &#x1f30a;2. 会议回顾 &#x1f30d;2.1 主持人开场 &#x1f30d;2.2 元甲-小当家 AI 驱动的创意儿童营养早餐料理机&今天吃什么App &#x1f30d;2.3 Steven- A l 心理疗愈认知 &#x1f30d;2.4 伯棠-诸子百家(xExperts)-多智能…

私有部署Twikoo评论系统

原文&#xff1a;https://blog.c12th.cn/archives/12.html 前言 以前用 MongoDB Vercel 搭建 Twikoo 老是有点小问题&#xff0c;所以就放弃了。无意中看到可以用 docker 来搭建&#xff0c;正好有台服务器可以尝试下。 私有部署 Twikoo 版本要求 1.6.0 或以上 &#xff0c; …

AMD Anti-Lag 2抗延迟技术落地 CS2首发、延迟缩短95%

AMD发布了全新重磅驱动程序Adrenalin 24.6.1版本&#xff0c;包括首发落地Anti-Lag 2抗延迟技术、优化支持新游戏、升级支持HYPR-Tune、支持新操作系统、优化AI加速与开发、扩展支持Agility SDK、修复已知Bug&#xff0c;等等。 一、Anti-Lag 2 今年5月份刚宣布&#xff0c;重…

【计算机毕业设计】基于Springboot的智能物流管理系统【源码+lw+部署文档】

包含论文源码的压缩包较大&#xff0c;请私信或者加我的绿色小软件获取 免责声明&#xff1a;资料部分来源于合法的互联网渠道收集和整理&#xff0c;部分自己学习积累成果&#xff0c;供大家学习参考与交流。收取的费用仅用于收集和整理资料耗费时间的酬劳。 本人尊重原创作者…

信号与系统-实验6-离散时间系统的 Z 域分析

一、实验目的 1、掌握 z 变换及其性质&#xff1b;了解常用序列的 z 变换、逆 z 变换&#xff1b; 2、掌握利用 MATLAB 的符号运算实现 z 变换&#xff1b; 3、掌握利用 MATLAB 绘制离散系统零、极点图的方法&#xff1b; 4、掌握利用 MATLAB 分析离散系统零、极点的方法&a…

kicad第三方插件安装问题

在使用KICAD时想安装扩展内容&#xff0c;但是遇到下载失败&#xff0c;因为SSL connect error。 因为是公司网络&#xff0c;我也不是很懂&#xff0c;只能另寻他法。找到如下方法可以曲线救国。 第三方插件包目录 打开存放第三方插件存放目录&#xff0c;用于存放下载插件包…

vue3+vite+nodejs,通过接口的形式请求后端打包(可打包全部或指定打包组件)

项目地址https://gitee.com/sybb011016/test_build 打包通过按钮的形式请求接口&#xff0c;让后端进行打包&#xff0c;后端使用express-generator搭建模版。前端项目就在npm init vuelatest基础上添加了路由 如果只想打包AboutView组件&#xff0c;首先修改后端接口。 //打…

Linux如何安装openjdk1.8

文章目录 Centosyum安装jdk和JRE配置全局环境变量验证ubuntu使用APT(适用于Ubuntu 16.04及以上版本)使用PPA(可选,适用于需要特定版本或旧版Ubuntu)Centos yum安装jdk和JRE yum install java-1.8.0-openjdk-devel.x86_64 安装后的目录 配置全局环境变量 vim /etc/pr…

运营商、银行、国企等单位开发岗24届Offer薪资与福利汇总

本文介绍24届校园招聘中&#xff0c;地理信息科学&#xff08;GIS&#xff09;专业硕士研究生所得Offer的整体薪资情况、福利待遇等。 在2024届秋招与春招中&#xff0c;我累计投递了170余个单位&#xff0c;获得17个Offer&#xff1b;平均每投递10个简历才能获得1个Offer。说句…