昇思25天学习打卡营第07天 | 函数式自动微分

昇思25天学习打卡营第07天 | 函数式自动微分

文章目录

  • 昇思25天学习打卡营第07天 | 函数式自动微分
    • 函数与计算图
      • 微分函数与梯度
      • Stop Gradient
      • Auxiliary data
    • 神经网络梯度计算
    • 总结
    • 打卡

神经网络的训练主要使用反向传播算法,首先计算模型预测值(logits)与正确标签(label)之间的loss,然后进行反向传播,通过梯度来更新模型参数从而完成网路的训练。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad

函数与计算图

计算图是图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。
compute-graph
在这个模型中, x x x为输入, z z z为输出, y y y为正确值, w w w b b b是需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

def function(x, y, w, b):
	z = ops.matmul(x, w) + b
	loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

通过执行function获得loss值:

loss = function(x,y,w,b)

微分函数与梯度

为了优化参数 w w w b b b,需要求参数对loss的导数 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss}{\partial b} bloss

可以通过mindspore.grad函数来获得function的微分函数:

grad_fn = mindspore.grad(function, (2, 3))

grads = grad_fn(x, y, w, b)

此处使用了grad的两个入参:

  • fn:待求导的函数;
  • grad_position:指定求导输入位置的索引。

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数只输出loss一项。
如果函数输出多项时,微分函数会求所有输出对参数的导数。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
    
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

此处function_with_logits输出的z会影响梯度。

如果想要实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
    
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

Auxiliary data

Auxiliary data为辅助数据,是函数除第一个输出项外的其他输出。通常loss值为函数的第一个输出,而其它输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,设置为True时,可以自动实现前文中stop_gradient的功能。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

神经网络梯度计算

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

实例化网络和损失函数后,将其封装为一个前向计算函数,用于自动微分:

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

由于使用nn.Cell封装网络模型,其参数为Cell的内部属性,因此不需要指定grad_position参数,直接设置为None

对模型参数求导时,使用weights参数,指定为通过model.trainable_params()方法从Cell中取出的可以求导的参数:

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())

loss, grads = grad_fn(x, y)
print(grads)

总结

这一节从一个简单的线性函数 w x + b wx+b wx+b出发,介绍了网络模型中数学函数的统一表示方法(即计算图),与loss的计算过程。使用gradvalue_and_grad方法可以通过自动微分获取目标函数的微分函数,从而得到参数对loss的梯度,进而优化参数。对于需要输出辅助数据的函数来说,可以通过ops.stop_gradient进行梯度截断,或设置has_aux=True来自动完成。
在通过Cell封装的网络模型中,需要将模型和loss的调用封装为一个前向计算函数,从而进行自动微分。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/774432.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【IC】mismatch model

由于工艺和制造偏差的存在,相同设计参数的器件会存在参数间额差异,称为mismatch,通常用Monte Carlo去仿真多个mismatch叠加对设计的总影响。 器件偏差mismatch是工艺和制造偏差导致的,在Lot to Lot、Wafer to Wafer、Die to Die 以及in die的Device to Deview之间可见。 …

OZON怎么查看竞品数据,OZON怎么找竞品数据

在跨境电商的激烈竞争中,了解和分析竞品数据是每一位卖家优化销售策略、提升市场竞争力的关键步骤。OZON作为俄罗斯领先的电商平台,为卖家提供了丰富的数据分析工具,而萌啦ozon数据作为第三方数据分析平台,更是为卖家提供了更为全…

四、centos7安装nginx

来源网站:山海同行 来源地址:https://shanhaigo.cn 网站简介:一站式编程学习、资源、导航网站 本篇资源:以整理分类并关联本篇地址 本篇地址:https://shanhaigo.cn/courseDetail/1805875642621952000 安装系统centos7 …

CASS7.0按方向和距离绘制图形

1、绘制工具 2、按方向和距离绘制 (1)切换方向 (2)距离输入

【算法:贪心】:贪心算法介绍+基础题(四个步骤);柠檬水找零(交换论证法)

🎁个人主页:我们的五年 🔍系列专栏:C课程学习 🎉欢迎大家点赞👍评论📝收藏⭐文章 前言: 暑假马上就要留校学习算法了,现在先学习一下基本的算法打打基础。本篇要讲的是…

数据库数据修改和删除操作详解

目录 🎃摘要 1. 数据库数据修改概述 2. 数据更新操作 2.1MySQL数据更新示例 3. 数据删除概述 4.使用DELETE进行数据删除 4.1 DELETE的基本语法 4.2 DELETE的使用场景 4.3 DELETE示例 5. 使用TRUNCATE进行数据删除 5.1 TRUNCATE的基本语法 5.2 TRUNCATE的…

Vue3:全局播放背景音乐

说明&#xff1a;一个全局播放的背景音乐&#xff0c;首页无音乐无音乐图标&#xff0c;在首页互动跳转页面并开始播放音乐&#xff0c;切换页面不需暂停音乐也不会重置音乐&#xff0c;可以通过音乐图标控制暂停或播放。 MusicPlay.vue&#xff08;音乐组件&#xff09; <…

新手高效指南:电子元器件BOM表创建/制作及配单全教程

在科技日新月异的今天&#xff0c;电子产品设计与制造不仅是创新精神的展现&#xff0c;更是对精确度与效率的不懈追求。在这个过程中&#xff0c;一份精细且全面的BOM&#xff08;物料清单&#xff09;犹如一座桥梁&#xff0c;连接着创意与现实世界。BOM不仅细致记录了产品所…

Facebook群发消息API接口的申请流程详解!

Facebook 群发消息api接口如何集成&#xff1f;怎么使用API接口&#xff1f; 在现代社交媒体营销中&#xff0c;群发消息是与客户保持互动的重要工具。Facebook群发消息API接口提供了一种有效的方法来实现这一目标。本文将详细介绍如何申请Facebook群发消息API接口的具体步骤和…

【qt】如何获取本机的IP地址?

需要用到这个类QHostInfo和pro里面添加network模块 用这个类的静态函数forName()来获取该主机名的信息 返回的就是这个类 这个QHostInfo类就包括主机的IP地址信息 用静态函数addresses()来获取 返回的是一个QHostAddress的容器 QList<QHostAddress>addrList hostIn…

深入分析 Android BroadcastReceiver (九)

文章目录 深入分析 Android BroadcastReceiver (九)1. Android 广播机制的扩展应用与高级优化1.1 广播机制的扩展应用1.1.1 示例&#xff1a;有序广播1.1.2 示例&#xff1a;粘性广播1.1.3 示例&#xff1a;局部广播 1.2 广播机制的高级优化1.2.1 示例&#xff1a;使用 Pending…

Gemini for China 大更新,现已上架 Android APP!

官网&#xff1a;https://gemini.fostmar.online/ Android APP&#xff1a;https://gemini.fostmar.online/gemini_1.0.apk 一、Android APP 如果是 Android 设备&#xff0c;则会直接识别到并给下载链接。PC 直接对话即可。 二、聊天记录 现在 Gemini for China&#xff…

mysql8 导入导出工具类,支持windows 和linux

概述 1&#xff09;导入导出工具类 支持windows 和linux&#xff0c;详见第3部分 2&#xff09;导入、导出参数在 dbeaver 中应用&#xff0c;详见第4部分 整理原因: 1&#xff09;中文乱码 --default-character-setutf8 2&#xff09;BLOB 导出后&#xff0c;导入失败 --he…

DatawhaleAI夏令营2024 Task2

#AI夏令营 #Datawhale #夏令营 赛题解析一、Baseline详解1.1 环境配置1.2 数据处理任务理解2.3 prompt设计2.4 数据抽取 二、完整代码总结 赛题解析 赛事背景 在数字化时代&#xff0c;企业积累了大量对话数据&#xff0c;这些数据不仅是交流记录&#xff0c;还隐藏着宝贵的信…

8.13 矢量图层面要素反转面要素渲染(Inverted polygons Renderer)

前言 本章介绍矢量图层面要素反转面要素(Inverted polygons Renderer)的使用说明&#xff1a;文章中的示例代码均来自开源项目qgis_cpp_api_apps 反转面要素(Inverted polygons Renderer) 反转面要素渲染常用于掩膜数据。 反转面要素(Inverted polygons Renderer)是一种渲染方…

软件测试之接口自动化测试实战(完整版)

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 自从看到阿里云性能测试 PTS 接口测试开启免费公测&#xff0c;就想着跟大家分享交流一下如何实现…

使用笔记之-E语言微信支付支付宝支付源代码

首先下载E语言微信支付&支付宝支付源代码 http://www.htsoft.com.cn/download/E_WeiXin_ZhiFuBao_ZhiFu.rar

编译开源车载Linux操作系统AGL

随着汽车行业的智能化和互联化趋势日益明显&#xff0c;车载系统作为汽车的重要组成部分&#xff0c;其性能和功能也受到了越来越多的关注。Linux作为一款开源的操作系统&#xff0c;具有稳定性高、安全性强、可定制性好等优点&#xff0c;因此成为了车载系统领域的热门选择。 …

黄小米-从田间到餐桌的美味之旅

甘肃黄小米颗粒饱满&#xff0c;色泽金黄&#xff0c;富含多种营养成分&#xff0c;如蛋白质、膳食纤维、维生素和矿物质等。其口感香糯&#xff0c;煮粥时香气扑鼻&#xff0c;米油丰富&#xff0c;味道醇厚。由于甘肃地区独特的地理和气候条件&#xff0c;包括充足的日照、较…

SQL 与 NoSQL 数据库:一场关于灵活性与结构的对话

文章目录 引言SQL 数据库&#xff1a;传统之光定义特征优势缺点 NoSQL 数据库&#xff1a;新时代的弹性定义特征优势缺点 何时选择 NoSQL&#xff1f;场景1&#xff1a;海量数据与高并发场景2&#xff1a;灵活性需求场景3&#xff1a;实时数据分析场景4&#xff1a;分布式系统 …