《动手学深度学习(PyTorch版)》笔记4.7

Chapter4 Multilayer Perceptron

4.7 Forward/Backward Propagation and Computational Graphs

本节将通过一些基本的数学和计算图,深入探讨反向传播的细节。首先,我们将重点放在带权重衰减( L 2 L_2 L2正则化)的单隐藏层多层感知机上。

4.7.1 Forward Propagation

前向传播(forward propagation或forward pass)指的是按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

我们将一步步研究单隐藏层神经网络的机制,为了简单起见,我们假设输入样本是 x ∈ R d \mathbf{x}\in \mathbb{R}^d xRd,并且我们的隐藏层不包括偏置项。这里的中间变量是:

z = W ( 1 ) x , \mathbf{z}= \mathbf{W}^{(1)} \mathbf{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。将中间变量 z ∈ R h \mathbf{z}\in \mathbb{R}^h zRh通过激活函数 ϕ \phi ϕ后,我们得到长度为 h h h的隐藏激活向量:

h = ϕ ( z ) . \mathbf{h}= \phi (\mathbf{z}). h=ϕ(z).

隐藏变量 h \mathbf{h} h也是一个中间变量。假设输出层的参数只有权重 W ( 2 ) ∈ R q × h \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h,我们可以得到输出层变量,它是一个长度为 q q q的向量:

o = W ( 2 ) h . \mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}. o=W(2)h.

假设损失函数为 l l l,样本标签为 y y y,我们可以计算单个数据样本的损失项,

L = l ( o , y ) . L = l(\mathbf{o}, y). L=l(o,y).

根据 L 2 L_2 L2正则化的定义,给定超参数 λ \lambda λ,正则化项为

s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

∥ X ∥ F \|\mathbf{X}\|_F XF表示矩阵的Frobenius范数:
∥ X ∥ F = ∑ i = 1 m ∑ j = 1 n x i j 2 . \|\mathbf{X}\|_F = \sqrt{\sum_{i=1}^m \sum_{j=1}^n x_{ij}^2}. XF=i=1mj=1nxij2 .
最后,模型在给定数据样本上的正则化损失为:

J = L + s . J = L + s. J=L+s.

在下面的讨论中,我们将 J J J称为目标函数(objective function)。

下图是与上述简单网络相对应的计算图,其中正方形表示变量,圆圈表示操作符。

在这里插入图片描述

4.7.2 Backward Propagation

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法,该方法根据链式规则,按相反的顺序从输出层到输入层遍历网络。该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。
假设我们有函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),其中输入和输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z是任意形状的张量。利用链式法则,我们可以计算 Z \mathsf{Z} Z关于 X \mathsf{X} X的导数:

∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) . \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right). XZ=prod(YZ,XY).

在这里,我们使用 prod \text{prod} prod运算符在执行必要的操(如换位和交换输入位置)后将其参数相乘。对于高维张量,我们使用适当的对应项。

在上面的计算图中单隐藏层简单网络的参数是 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2),反向传播的目的是计算梯度 ∂ J / ∂ W ( 1 ) \partial J/\partial \mathbf{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \mathbf{W}^{(2)} J/W(2),计算的顺序与前向传播中执行的顺序相反,具体如下:

∂ J ∂ L = 1    and    ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1 \; \text{and} \; \frac{\partial J}{\partial s} = 1. LJ=1andsJ=1.

∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o ∈ R q . \frac{\partial J}{\partial \mathbf{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \mathbf{o}}\right) = \frac{\partial L}{\partial \mathbf{o}} \in \mathbb{R}^q. oJ=prod(LJ,oL)=oLRq.

∂ s ∂ W ( 1 ) = λ W ( 1 )    ,    ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \mathbf{W}^{(1)}} = \lambda \mathbf{W}^{(1)} \; \text{,} \; \frac{\partial s}{\partial \mathbf{W}^{(2)}} = \lambda \mathbf{W}^{(2)}. W(1)s=λW(1),W(2)s=λW(2).

∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) ∈ R q × h . \frac{\partial J}{\partial \mathbf{W}^{(2)}}= \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(2)}}\right)= \frac{\partial J}{\partial \mathbf{o}} \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}\in \mathbb{R}^{q \times h}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2)Rq×h.

∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o ∈ R h . \frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}\in \mathbb{R}^h. hJ=prod(oJ,ho)=W(2)oJRh.

由于激活函数 ϕ \phi ϕ是按元素计算的,计算中间变量 z \mathbf{z} z的梯度需要使用按元素乘法运算符,我们用 ⊙ \odot 表示:

∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) ∈ R h . \frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\in \mathbb{R}^h. zJ=prod(hJ,zh)=hJϕ(z)Rh.

∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) = ( W ( 2 ) ⊤ ∂ J ∂ o ) ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) . \begin{align*} \frac{\partial J}{\partial \mathbf{W}^{(1)}} &= \text{prod}\left(\frac{\partial J}{\partial \mathbf{z}}, \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(1)}}\right) \\ &= \frac{\partial J}{\partial \mathbf{z}} \mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= ({\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}})\odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)}. \end{align*} W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1)=hJϕ(z)x+λW(1)=(W(2)oJ)ϕ(z)x+λW(1).

4.7.3 Training Neural Networks

在训练神经网络时,前向传播和反向传播相互依赖。以上述简单网络为例:一方面,在前向传播期间计算正则项取决于模型参数 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2)的当前值。它们是由优化算法根据最近迭代的反向传播给出的。另一方面,反向传播期间参数的梯度计算,取决于由前向传播给出的隐藏变量 h \mathbf{h} h的当前值。

因此,在训练神经网络时,我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。这带来的影响之一是我们需要保留中间值,直到反向传播完成,这也是训练比单纯的预测需要更多的内存(显存)的原因之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/353581.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SQL注入:盲注

SQL注入系列文章: 初识SQL注入-CSDN博客 SQL注入:联合查询的三个绕过技巧-CSDN博客 SQL注入:报错注入-CSDN博客 目录 什么是盲注? 布尔盲注 手工注入 使用python脚本 使用sqlmap 时间盲注 手工注入 使用python脚本 使…

深兰科技入选亿欧《“制”敬不凡先锋榜·智能机器人Top10》榜单

日前,由亿欧协办的2023工博会工业智能化发展高峰论坛于上海成功举办,会上发布了《2023智能制造:“制”敬不凡先锋者》系列名单。深兰科技凭借在智能机器人开发中的技术创新和模式应用,入选《“制”敬不凡先锋榜——智能机器人Top1…

第3章-python深度学习——(波斯美女)

第3章 神经网络入门 本章包括以下内容: 神经网络的核心组件 Keras 简介 建立深度学习工作站 使用神经网络解决基本的分类问题与回归问题 本章的目的是让你开始用神经网络来解决实际问题。你将进一步巩固在第 2 章第一个示例中学到的知识,还会将学到的…

Linux——软件安装

1、软件包的分类 Linux下的软件包众多,而且几乎都是经GPL授权的,也就是说这些软件都免费,振奋人心吧?而且更棒的是,这些软件几乎都提供源代码(开源的),只要你愿意,就可以…

【GitHub项目推荐--开源PDF 工具】【转载】

12 年历史的 PDF 工具开源了 最近在整理 PDF 的时候,有一些需求普通的 PDF 编辑器没办法满足,比如 PDF 批量合并、编辑等。 于是,我就去 GitHub 上看一看有没有现成的轮子,发现了这个 PDF 神器「PDF 补丁丁」,让人惊…

C++类和对象——构造函数与解析函数介绍

目录 1.构造函数和析构函数 1.构造函数,进行初始化 2.析构函数,进行清理 2.构造函数的分类及调用 1.括号法 注意: 2.显示法 3.隐式转化法 匿名对象 3.拷贝构造函数调用时机 4.构造函数调用规则 1.定义有参构造函数,不…

2023年全球软件开发大会(QCon广州站2023):核心内容与学习收获(附大会核心PPT下载)

在全球化的科技浪潮中,软件开发行业日新月异,持续推动着社会经济的飞速发展。本次峰会以“引领未来,探索无限可能”为主题,聚焦软件开发领域的最新技术、最佳实践和创新思想。来自世界各地的顶级专家、企业领袖和开发者齐聚一堂&a…

【极数系列】Linux环境搭建Flink1.18版本 (03)

文章目录 引言01 Linux部署JDK11版本1.下载Linux版本的JDK112.创建目录3.上传并解压4.配置环境变量5.刷新环境变量6.检查jdk安装是否成功 02 Linux部署Flink1.18.0版本1.下载Flink1.18.0版本包2.上传压缩包到服务器3.修改flink-config.yaml配置4.启动服务5.浏览器访问6.停止服务…

【解决】IntelliJ IDEA 重命名 Shift + F6 失效

IntelliJ IDEA 重命名 Shift F6 失效 问题解决 问题 Idea 重命名 Shift F6 ,一直没反应 解决 调查发现原因是微软新版的输入法冲突了。需要设置【使用以前版本的微软拼音输入法】解决兼容性。 设置 -> 时间和语言 -> 区域 -> 语言选项 -> 键盘选项…

【开源】基于JAVA语言的就医保险管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 科室档案模块2.2 医生档案模块2.3 预约挂号模块2.4 我的挂号模块 三、系统展示四、核心代码4.1 用户查询全部医生4.2 新增医生4.3 查询科室4.4 新增号源4.5 预约号源 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVue…

SpringBoot 配置类解析

全局流程解析 配置类解析入口 postProcessBeanDefinitionRegistry逻辑 processConfigBeanDefinitions逻辑 执行逻辑解析 执行入口 ConfigurationClassPostProcessor.processConfigBeanDefinitions()方法中的do while循环体中 循环体逻辑 parse方法调用链 doProcessConfigurat…

C#,计算几何,随机点集之三角剖分的德劳内(Delaunay)算法的源代码

一、三角剖分Delaunay算法简介 点集的三角剖分(Triangulation),对数值分析(比如有限元分析)以及图形学来说,都是极为重要的一项预处理技术。尤其是Delaunay三角剖分,由于其独特性,关…

03 Redis之命令(基本命令+Key命令+String型Value命令与应用场景)

Redis 根据命令所操作对象的不同,可以分为三大类:对 Redis 进行基础性操作的命令,对 Key 的操作命令,对 Value 的操作命令。 3.1 Redis 基本命令 一些可选项对大小写敏感, 所以应尽量将redis的所有命令大写输入 首先通过 redis-…

Nodejs前端学习Day3_准备工作

妈的,这几天真tm冷,前天上午还下了一整天的雪,大雪 文章目录 前言一、Node.js简介1.1何为1.2有什么 二、Node.js可以做什么三、学习路线四、下载nodejs4.1小坑记录4.2LTS和Current版本的不同 五、什么是终端六、在nodejs中执行js代码七、powe…

【PyTorch】n卡驱动、CUDA Toolkit、cuDNN全解安装教程

文章目录 GPU、NVIDIA Graphics Drivers、CUDA、CUDA Toolkit和cuDNN的关系使用情形判断仅仅使用PyTorch使用torch的第三方子模块 安装NVIDIA Graphics Drivers(可跳过)前言Linux法一:图形化界面安装(推荐)法二&#x…

用友U8接口-部署和简要说明(1)

概括 本专栏文章目的说明对目前用友U8ERP接口介绍对底层接口二次封装的介绍 说明 过去发布过介绍U8接口文章简介,参考以下链接。 U8接口开发方式 本专栏文章与下面的HTTP接口相辅相成,主要是写给正在使用,或未来使用本套接口的开发人员&am…

PhpStorm调试docker容器中的php项目

背景 已经通过docker容器启动了一个web服务&#xff0c;并在宿主机可以访问http://localhost:8080访问网页。 现在想使用phpstorm打断点调试代码。 方法 1. 容器内安装xdebug 进入容器 docker exec -it <container-name> bash为php安装xdebug拓展 apt install php8…

市场复盘总结 20240122

仅用于记录当天的市场情况&#xff0c;用于统计交易策略的适用情况&#xff0c;以便程序回测 短线核心&#xff1a;不参与任何级别的调整&#xff0c;采用龙空龙模式 昨日主题投资 连板进级率 6/39 15.3% 二进三&#xff1a; 进级率低 0% 最常用的二种方法&#xff1a; 方法…

数据湖技术之发展现状篇

一. 大数据处理架构&#xff1a; 大数据处理架构的发展过程具体可以分为三个主要阶段&#xff1a;批处理架构、混合处理架构&#xff08;Lambda、Kappa架构&#xff09;、湖仓一体。首先是随着Hadoop生态相关技术的大量应用&#xff0c;批处理架构应运而生&#xff0c;借助离线…

用C语言实现贪吃蛇游戏!!!

前言 大家好呀&#xff0c;我是Humble&#xff0c;不知不觉在CSND分享自己学过的C语言知识已经有三个多月了&#xff0c;从开始的C语言常见语法概念说到C语言的数据结构今天用C语言实现贪吃蛇已经有30余篇博客的内容&#xff0c;也希望这些内容可以帮助到各位正在阅读的小伙伴…