神经网络反向传播的数学原理

如果能二秒内在脑袋里解出下面的问题,本文便结束了。

已知:J=(Xw-y)^T(Xw-y)=\left | \right | Xw-y\left | \right | ^2,其中X\in R^{m\times n},w\in R^{n\times1},y\in R^{m\times1}

求:\frac{\partial J}{\partial X},\frac{\partial J}{\partial w},\frac{\partial J}{\partial y}

到这里,请耐心看完下面的公式推导,无需长久心里建设。

首先,反向传播的数学原理是“求导的链式法则” :

设f和g为x的可导函数,则(f\circ g)'(x)=f'(g(x))g'(x)

接下来介绍

  • 矩阵、向量求导的维数相容原则

  • 利用维数相容原则快速推导反向传播

  • 编程实现前向传播、反向传播

  • 卷积神经网络的反向传播

快速矩阵、向量求导

这一节展示如何使用链式法则、转置、组合等技巧来快速完成对矩阵、向量的求导

一个原则维数相容,实质是多元微分基本知识,没有在课本中找到下列内容,维数相容原则是我个人总结:

维数相容原则:通过前后换序、转置 使求导结果满足矩阵乘法且结果维数满足下式:

如果x\in R^{m \times n},f(x) \in R^1 ,那么\frac{\partial f(x)}{\partial x} \in R^{m \times n}

利用维数相容原则解上例:

step1:把所有参数当做实数来求导,J=(Xw-y)^2

依据链式法则有\frac{\partial J}{\partial X}=2(Xw-y)w, \frac{\partial J}{\partial w}=2(Xw-y)X, \frac{\partial J}{\partial y}=-2(Xw-y)

可以看出除了\frac{\partial J}{\partial y}=-2(Xw-y)\frac{\partial J}{\partial X}\frac{\partial J}{\partial w}的求导结果在维数上连矩阵乘法都不能满足。

step2:根据step1的求导结果,依据维数相容原则做调整:前后换序、转置

依据维数相容原则\frac{\partial J}{\partial X} \in R^{m \times n},但\frac{\partial J}{\partial X} \in R^{m \times n}=2(Xw-y)w(Xw-y)\in R^{m \times 1}X \in R^{m \times n},自然得调整为\frac{\partial J}{\partial X}=2(Xw-y)w^T

同理:\frac{\partial J}{\partial w} \in R^{n \times 1},但\frac{\partial J}{\partial w} \in R^{n \times 1}=2(Xw-y)X(Xw-y) \in R^{m \times 1}X \in R^{m \times n},那么通过换序、转置我们可以得到维数相容的结果2X^T(Xw-y)

对于矩阵、向量求导:

  • “当做一维实数使用链式法则求导,然后做维数相容调整,使之符合矩阵乘法原则且维数相容”是快速准确的策略;

  • “对单个元素求导、再整理成矩阵形式”这种方式整理是困难的、过程是缓慢的,结果是易出错的(不信你试试)。

如何证明经过维数相容原则调整后的结果是正确的呢?直觉!简单就是美...

快速反向传播

神经网络的反向传播求得“各层”参数W和b的导数,使用梯度下降(一阶GD、SGD,二阶LBFGS、共轭梯度等)优化目标函数。

接下来,展示不使用下标的记法(W_{ij},b_iorb_j)直接对W和b求导反向传播链式法则维数相容原则的完美体现,对每一层参数的求导利用上一层的中间结果完成。

这里的标号,参考UFLDL教程 - Ufldl

前向传播:

z^{(l+1)}=W^{(l)}a^{(l)}+b^{(l)}    (公式1)

a^{(l+1)}=f(z^{(l+1)})             (公式2)

z^{(l)}为第 l 层的中间结果,a^{(l)} 为第 l 层的激活值,其中第 l +1层包含元素:输入a^{(l)},参数W^{(l)}b^{(l)},激活函数f(),中间结果z^{(l+1)},输出a^{(l+1)}

设神经网络的损失函数为J(W,b) \in R^1(这里不给出具体公式,可以是交叉熵、MSE等),根据链式法则有:

图片

这里记 \frac{\partial J(W,b)}{\partial {z^{(l+1)}}}=\delta ^{(l+1)},其中\frac{\partial z^{(l+1)}}{\partial W^{(l)}}=a^{(l)}\frac{\partial z^{(l+1)}}{\partial b^{(l)}}=1可由 公式1 得出,a^{(l)}加转置符号(a^{(l)})^T是根据维数相容原则作出的调整。

如何求 \delta ^{(l)}=\frac{\partial J(W,b)}{\partial z^{(l)}} ?可使用如下递推(需根据维数相容原则作出调整):

图片

其中

图片

那么我们可以从最顶层逐层往下,便可以递推求得每一层的\delta ^{(l)}=\frac{\partial J(W,b)}{\partial z^{(l)}}

注意:\frac{\partial a^{(l)}}{\partial z^{(l)}}=f'(z^{(l)})是逐维求导,在公式中是点乘的形式。

反向传播整个流程如下:

1) 进行前向传播计算,利用前向传播公式,得到隐藏层和输出层 的激活值。

2) 对输出层(第 l 层),计算残差:\delta ^{(l)}=\frac{\partial J(W,b)}{\partial z^{(l)}}(不同损失函数,结果不同,这里不给出具体形式)

3) 对于l-1,l-2,\cdot \cdot \cdot ,2的隐藏层,计算:

图片

4) 计算各层参数W^{(l)},b^{(l)}偏导数:

图片

编程实现

大部分开源library(如:caffe,Kaldi/src/{nnet1,nnet2})的实现通常把W^{(l)},b^{(l)}作为一个layer,激活函数f()作为一个layer(如:sigmoid、relu、softplus、softmax)。

反向传播时分清楚该层的输入、输出即能正确编程实现,如:

图片

(1)式AffineTransform/FullConnected层,以下是伪代码:

图片

注: out_diff = \frac{\partial J}{\partial z^{(l+1)}}是上一层(Softmax 或 Sigmoid/ReLU的 in_diff)已经求得:

图片

(2)式激活函数层(以Sigmoid为例)

图片

注:out_diff = \frac{\partial J}{\partial a^{(l+1)}}是上一层AffineTransform的in_diff,已经求得,

图片

在实际编程实现时,in、out可能是矩阵(通常以一行存储一个输入向量,矩阵的行数就是batch_size),那么上面的C++代码就要做出变化(改变前后顺序、转置,把函数参数的Vector换成Matrix,此时Matrix out_diff 每一行就要存储对应一个Vector的diff,在update的时候要做这个batch的加和,这个加和可以通过矩阵相乘out_diff*input(适当的转置)得到。

如果熟悉SVD分解的过程,通过SVD逆过程就可以轻松理解这种通过乘积来做加和的技巧。

丢掉那些下标记法吧!

卷积层求导

卷积怎么求导呢?实际上卷积可以通过矩阵乘法来实现(是否旋转无所谓的,对称处理,caffe里面是不是有image2col),当然也可以使用FFT在频率域做加法。

那么既然通过矩阵乘法,维数相容原则仍然可以运用,CNN求导比DNN复杂一些,要做些累加的操作。具体怎么做还要看编程时选择怎样的策略、数据结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/162962.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

利用 Pandoc + ChatGPT 优雅地润色论文,并保持 Word 公式格式:Pandoc将Word和LaTeX文件互相转化

论文润色完美解决方案:Pandoc 与 ChatGPT 的强强联合 写在最前面其他说明 一、通过 Pandoc 将 Word 转换为 LaTeX 的完整指南步骤 1: 安装 PandocWindows:macOS:Linux: 步骤 2: 准备 Word 文档步骤 3: 转换文档步骤 4: 检查并调整输出步骤 5: 编译 LaTeX 文档总结 二…

分类预测 | Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测

分类预测 | Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测 目录 分类预测 | Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测(完整源码和数据) 2.多…

机器学习第9天:决策树分类

文章目录 机器学习专栏 介绍 基本思想 使用代码 深度探索 优点 估计概率 训练算法 CART成本函数 实例数与不纯度 正则化 在鸢尾花数据集上训练决策树 机器学习专栏 机器学习_Nowl的博客-CSDN博客 介绍 作用:分类 原理:构建一个二叉树&#…

【信息安全】浅谈IDOR越权漏洞的原理、危害和防范:直接对象引用导致的越权行为

前言 ┌──────────────────────────────────┐ │ 正在播放《越权访问》 - Hanser │ ●━━━━━━─────── 00:00 / 03:05 │ ↻ ◁ ❚❚ ▷ ⇆ └───────────────────────────────…

Java - 位运算的基本原理和用途

Java - 位运算的基本原理和用途 前言一. Java 位运算基本操作1.1 按位与 &1.2 按位或 |1.3 按位异或 ^1.4 按位取反 ~1.5 位移运算1.5.1 左移运算符 <<1.5.2 右移运算符 >>1.5.3 无符号右移运算符 >>> 二. 位运算实际运用2.1 判断奇偶性&#xff08;&…

Vue3 插槽 v-slot

插槽 视频链接&#xff1a;尚硅谷vue-插槽章节 不使用插槽的情况下 结果&#xff1a; 1 默认插槽 在子组件中只能有一个 结果&#xff1a; 2 具名插槽 #b是v-slot:b 的缩写 顾名思义就是指着名字去插入 结果&#xff1a; 3 作用域插槽 可以传递数据的插槽&#…

【SQL server】数据库、数据表的创建

创建数据库 --如果存在就删除 --所有的数据库都存在sys.databases当中 if exists(select * from sys.databases where name DBTEST)drop database DBTEST--创建数据库 else create database DBTEST on --数据文件 (nameDBTEST,--逻辑名称 字符串用单引号filenameD:\DATA\DBT…

【c++】——类和对象(中)——实现完整的日期类(优化)万字详细解疑答惑

作者:chlorine 专栏:c专栏 赋值运算符重载()()():实现完整的日期类(上) 我走的很慢&#xff0c;但我从不后退。 【学习目标】 日期(- - --)天数重载运算符 日期-日期 返回天数 对日期类函数进行优化(不符合常理的日期&#xff0c;负数&#xff0c;const成员)c中重载输入cin和输…

C/C++数据结构之中缀表达式转换为后缀表达式,删除堆栈元素

在这篇博客中&#xff0c;我们将深入分析一个使用C编写的栈和表达式计算程序。该程序不仅实现了基本的栈操作&#xff0c;还提供了中缀表达式转后缀表达式和删除堆栈中的元素等实用功能。通过逐一讲解每个函数的功能&#xff0c;我们将更全面地理解这个程序的实现。 资源获取&a…

python算法例15 合并数字

1. 问题描述 给出n个数&#xff0c;将这n个数合并成一个数&#xff0c;每次只能选择两个数a、b合并&#xff0c;合并需要消耗的能量为ab&#xff0c;输出将n个数合并成一个数后消耗的最小能量。 2. 问题示例 给出[1&#xff0c;2&#xff0c;3&#xff0c;4]&#xff0c;返回…

【漏洞复现】通达oa 前台sql注入

漏洞描述 通达OA(Office Automation)是一款企业级协同办公软件,旨在为企业提供高效、便捷、安全、可控的办公环境。它涵盖了企业日常办公所需的各项功能,包括人事管理、财务管理、采购管理、销售管理、库存管理、生产管理、办公自动化等。通达OA支持PC端和移动端使用,可以…

比亚迪刀片电池与特斯拉4680电池比较

1 电池材料 比亚迪刀片电池采用的磷酸铁锂LFP&#xff08;LiFePO4&#xff09;&#xff0c;特斯拉的4680电池采用的三元锂。 磷酸铁锂&#xff1a;循环寿命长&#xff0c;安全性能好&#xff0c;价格低廉&#xff0c;但是能量密度低&#xff0c;导电性能差&#xff0c;低温表现…

【机器学习算法】机器学习:支持向量机(SVM)

转载自&#xff1a; 【精选】机器学习&#xff1a;支持向量机&#xff08;SVM&#xff09;-CSDN博客 1.概述 1.1&#xff0c;概念 支持向量机&#xff08;SVM&#xff09;是一类按监督学习方式对数据进行二元分类的广义线性分类器&#xff0c;其决策边界是对学习样本求解的最…

window系统vscode 编译wvp前端代码

下载代码 wvp-GB28181-pro: WEB VIDEO PLATFORM是一个基于GB28181-2016标准实现的网络视频平台&#xff0c;负责实现核心信令与设备管理后台部分&#xff0c;支持NAT穿透&#xff0c;支持海康、大华、宇视等品牌的IPC、NVR、DVR接入。支持国标级联&#xff0c;支持rtsp/rtmp等…

应用场景丨迭代市政综合管廊监测系统建设

市政综合管廊是指在城市地下建造的隧道空间&#xff0c;将市政、电力、通讯、燃气、给排水等各种管线集于一体&#xff0c;实施统一规划、设计、建设和管理。综合管廊有利于解决反复开挖路面、架空线网密集、管线事故频发等问题&#xff0c;是保障城市运行的重要基础设施和“生…

CAS源码工程搭建记录

CAS源码工程搭建 1.下载2.gradle下载源改为阿里云&#xff0c;解决下载慢的问题3.解决保存 1.下载 git clone -b 5.3.x https://gitee.com/mirrors/CAS.git如果下载的是压缩包&#xff0c;导入工程会保存&#xff0c;因为builder.gradle的第20行开始有取git信息&#xff0c;如…

【Web】PHP反序列化的一些trick

目录 ①__wakeup绕过 ②加号绕过正则匹配 ③引用绕过相等 ④16进制绕过关键词过滤 ⑤Exception绕过 ⑥字符串逃逸 要中期考试乐(悲) ①__wakeup绕过 反序列化字符串中表示属性数量的值 大于 大括号内实际属性的数量时&#xff0c;wakeup方法会被绕过 &#xff08;php5-p…

汽车级低压差稳压器LDO LM317BD2TR4G原理、参数及应用

LM317BD2TR4G主要功能特性分析 &#xff1a; LM317BD2TR4G 低漏 (LDO) 线性电压稳压器是一款可调 3 端子正向 LDO 电压器&#xff0c;能够在 1.2 V 至 37 V 的输出电压范围内提供 1.5 A 以上的电流。此电压稳压器使用非常简便&#xff0c;仅需两个外部电阻即可设置输出电压。另…

sentinel 网关

网关简介 大家都都知道在微服务架构中&#xff0c;一个系统会被拆分为很多个微服务。那么作为客户端要如何去调用这么多的微服务呢&#xff1f;如果没有网关的存在&#xff0c;我们只能在客户端记录每个微服务的地址&#xff0c;然后分别去调用。 这样的架构&#xff0c;会存在…

组合模式 rust和java的实现

文章目录 组合模式介绍实现javarsut 组合模式 组合模式&#xff08;Composite Pattern&#xff09;&#xff0c;又叫部分整体模式&#xff0c;是用于把一组相似的对象当作一个单一的对象。组合模式依据树形结构来组合对象&#xff0c;用来表示部分以及整体层次。这种类型的设计…