【深度学习笔记】6_6 通过时间反向传播(back-propagation through time)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.6 通过时间反向传播

在前面两节中,如果不裁剪梯度,模型将无法正常训练。为了深刻理解这一现象,本节将介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。

我们在3.14节(正向传播、反向传播和计算图)中介绍了神经网络中梯度计算与存储的一般思路,并强调正向传播和反向传播相互依赖。正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

6.6.1 定义模型

简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。设时间步 t t t 的输入为单样本 x t ∈ R d \boldsymbol{x}_t \in \mathbb{R}^d xtRd,标签为 y t y_t yt,那么隐藏状态 h t ∈ R h \boldsymbol{h}_t \in \mathbb{R}^h htRh的计算表达式为

h t = W h x x t + W h h h t − 1 , \boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht=Whxxt+Whhht1,

其中 W h x ∈ R h × d \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} WhxRh×d W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是隐藏层权重参数。设输出层权重参数 W q h ∈ R q × h \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} WqhRq×h,时间步 t t t的输出层变量 o t ∈ R q \boldsymbol{o}_t \in \mathbb{R}^q otRq计算为

o t = W q h h t . \boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot=Wqhht.

设时间步 t t t的损失为 ℓ ( o t , y t ) \ell(\boldsymbol{o}_t, y_t) (ot,yt)。时间步数为 T T T的损失函数 L L L定义为

L = 1 T ∑ t = 1 T ℓ ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1t=1T(ot,yt).

我们将 L L L称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。

6.6.2 模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态 h 3 \boldsymbol{h}_3 h3的计算依赖模型参数 W h x \boldsymbol{W}_{hx} Whx W h h \boldsymbol{W}_{hh} Whh、上一时间步隐藏状态 h 2 \boldsymbol{h}_2 h2以及当前时间步输入 x 3 \boldsymbol{x}_3 x3

在这里插入图片描述

图6.3 时间步数为3的循环神经网络模型计算中的依赖关系。方框代表变量(无阴影)或参数(有阴影),圆圈代表运算符

6.6.3 方法

刚刚提到,图6.3中的模型的参数是 W h x \boldsymbol{W}_{hx} Whx, W h h \boldsymbol{W}_{hh} Whh W q h \boldsymbol{W}_{qh} Wqh。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh ∂ L / ∂ W q h \partial L/\partial \boldsymbol{W}_{qh} L/Wqh
根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们依然采用3.14节中表达链式法则的运算符prod。

首先,目标函数有关各时间步输出层变量的梯度 ∂ L / ∂ o t ∈ R q \partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q L/otRq很容易计算:

∂ L ∂ o t = ∂ ℓ ( o t , y t ) T ⋅ ∂ o t . \frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}. otL=Tot(ot,yt).

下面,我们可以计算目标函数有关模型参数 W q h \boldsymbol{W}_{qh} Wqh的梯度 ∂ L / ∂ W q h ∈ R q × h \partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} L/WqhRq×h。根据图6.3, L L L通过 o 1 , … , o T \boldsymbol{o}_1, \ldots, \boldsymbol{o}_T o1,,oT依赖 W q h \boldsymbol{W}_{qh} Wqh。依据链式法则,

∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ . \frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. WqhL=t=1Tprod(otL,Wqhot)=t=1TotLht.

其次,我们注意到隐藏状态之间也存在依赖关系。
在图6.3中, L L L只通过 o T \boldsymbol{o}_T oT依赖最终时间步 T T T的隐藏状态 h T \boldsymbol{h}_T hT。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度 ∂ L / ∂ h T ∈ R h \partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h L/hTRh。依据链式法则,我们得到

∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . \frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. hTL=prod(oTL,hToT)=WqhoTL.

接下来对于时间步 t < T t < T t<T, 在图6.3中, L L L通过 h t + 1 \boldsymbol{h}_{t+1} ht+1 o t \boldsymbol{o}_t ot依赖 h t \boldsymbol{h}_t ht。依据链式法则,
目标函数有关时间步 t < T t < T t<T的隐藏状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h L/htRh需要按照时间步从大到小依次计算:
∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t \frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} htL=prod(ht+1L,htht+1)+prod(otL,htot)=Whhht+1L+WqhotL

将上面的递归公式展开,对任意时间步 1 ≤ t ≤ T 1 \leq t \leq T 1tT,我们可以得到目标函数有关隐藏状态梯度的通项公式

∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . \frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. htL=i=tT(Whh)TiWqhoT+tiL.

由上式中的指数项可见,当时间步数 T T T 较大或者时间步 t t t 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含 ∂ L / ∂ h t \partial L / \partial \boldsymbol{h}_t L/ht项的梯度,例如隐藏层中模型参数的梯度 ∂ L / ∂ W h x ∈ R h × d \partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} L/WhxRh×d ∂ L / ∂ W h h ∈ R h × h \partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} L/WhhRh×h
在图6.3中, L L L通过 h 1 , … , h T \boldsymbol{h}_1, \ldots, \boldsymbol{h}_T h1,,hT依赖这些模型参数。
依据链式法则,我们有

∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ . \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} WhxLWhhL=t=1Tprod(htL,Whxht)=t=1ThtLxt,=t=1Tprod(htL,Whhht)=t=1ThtLht1.

我们已在3.14节里解释过,每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht被计算和存储,之后的模型参数梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算可以直接读取 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。
举例来说,参数梯度 ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算需要依赖隐藏状态在时间步 t = 0 , … , T − 1 t = 0, \ldots, T-1 t=0,,T1的当前值 h t \boldsymbol{h}_t ht h 0 \boldsymbol{h}_0 h0是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。

小结

  • 通过时间反向传播是反向传播在循环神经网络中的具体应用。
  • 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。

注:本节与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/442803.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C#,排列组合的堆生成法(Heap’s Algorithm for generating permutations)算法与源代码

1 排列组合的堆生成法 堆生成算法用于生成n个对象的所有组合。其思想是通过选择一对要交换的元素&#xff0c;在不干扰其他n-2元素的情况下&#xff0c;从先前的组合生成每个组合。 下面是生成n个给定数的所有组合的示例。 示例&#xff1a; 输入&#xff1a;1 2 3 输出&a…

2024蓝桥杯每日一题(归并排序)

一、第一题&#xff1a;火柴排队 解题思路&#xff1a;归并排序 重点在于想清楚是对哪个数组进行归并排序求逆序对 【Python程序代码】 from math import * n int(input()) a list(map(int,input().split())) b list(map(int,input().split())) na,nb [],[] for …

#onenet网络请求http(GET,POST)

参考博文&#xff1a; POST: https://blog.csdn.net/qq_43350239/article/details/104361153 POST请求&#xff08;用串口助手测试&#xff09;&#xff1a; POST /devices/1105985351/datapoints HTTP/1.1 api-key:AdbrV5kCRsKsRCfjboYOCVcF9FY Host:api.heclouds.com Con…

liteIDE 解决go root报错 go: cannot find GOROOT directory: c:\go

liteIDE环境配置 我使用的liteIDE为 x36 5.9.5版本 。在查看–>选项 中可以看到 LiteEnv&#xff0c;双击LiteEnv &#xff0c;在右侧选择对应系统的env文件&#xff0c;我的是win64系统&#xff0c;所以文件名为win64.env 再双击 win64.env &#xff0c;关闭当前窗口&…

专业的项目管理系统,企智汇!帮助企业提高项目管理效率!

一款专业的项目管理系统&#xff0c;是企智汇项目管理系统&#xff01;企智汇专业做项目管理系统10年&#xff0c;经过10年的打磨&#xff0c;有成熟的项目管理系统功能&#xff0c;它面向各个企业的项目团队&#xff0c;提供数字化、智能化、信息化的项目管理功能&#xff0c;…

【PHP+代码审计】PHP基础——数据类型

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java、PHP】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收…

STM32H750片外QSPI启动配置简要

STM32H750片外QSPI启动配置简要 &#x1f4cd;参考信息源&#xff1a;《STM32H750片外Flash启动(W25Q64JVSIQ)》&#x1f516;本例程基于Keil MDk开发平台。&#x1f341;配置框架&#xff1a; ✨为什么使用要使用QSPI启动方式 不管对于STM32H7系列单片机&#xff0c;还是其他…

VGG:更小的卷积核

“paddingvalid”&#xff0c;表示不对输入数据进行填充(不填充)“paddingsame”&#xff0c;表示对输入数据进行填充&#xff0c;使得输出特征图的大小与输入特征图的大小相同 VGG16 网络架构 Why VGG16 VGG16 卷积层 全连接层层数 1个输出层 更小的卷积核 连续的33卷积…

换手机后日记不见了怎么恢复?换手机日记内容同步方法

曾经&#xff0c;我使用的是一款苹果手机&#xff0c;这部手机陪伴了我整整3年。随着时间的推移&#xff0c;手机内存不够用成为了我面临的一个大问题&#xff0c;因此我决定更换一部新手机——这次我选择了OPPO品牌。在更换手机的过程中&#xff0c;我利用手机搬家软件一键同步…

AI辅助研发对医药、汽车和电子行业的影响

目录 1、医药行业&#xff1a;加速药物发现与开发 2、汽车行业&#xff1a;创新设计与生产 3、电子行业&#xff1a;优化芯片设计 4、展望未来 AI辅助研发正逐渐改变我们的世界&#xff0c;特别是在医药、汽车和电子等关键行业中的应用&#xff0c;已经展现出其巨大潜力。本…

支付宝开放平台,证书验签出错,请确认charset参数放在了URL查询字符串中且各参数值使用charset参数指示的字符集编码

#小李子9479# 支付宝证书接入方式的时候出现如下错误 验签出错&#xff0c;请确认charset参数放在了URL查询字符串中且各参数值使用charset参数指示的字符集编码 产生的原因分析&#xff1a; 1。验证签名里面有中文 2。在json_encode的时候使用了JSON_UNESCAPED_UNICODE参数…

【字符串】【贪心】【 树状数组】2193. 得到回文串的最少操作次数

作者推荐 视频算法专题 本文涉及知识点 字符串 贪心 树状数组 分类讨论 LeetCode2193. 得到回文串的最少操作次数 给你一个只包含小写英文字母的字符串 s 。 每一次 操作 &#xff0c;你可以选择 s 中两个 相邻 的字符&#xff0c;并将它们交换。 请你返回将 s 变成回文串…

mysql-DBA(2)-日志-数据库复制

1.mysqlbinlog 查看日志-精确查找-增量备份 1.1查看日志 mysqlbinlog binlog.000003 -vv --base64auto | less //两种都有 base64加密和看的懂的明文 mysqlbinlog binlog.000003 -vv --base64never | less //不显示 mysqlbinlog binlog.000003 -vv --base64decode-rows …

Flask python开发篇: 写一个简单的接口

第一步&#xff1a;新建flask项目 参考使用pycharm新建一个项目 打开pycharm&#xff0c;根据下面图中箭头顺序&#xff0c;新建一个flask的项目&#xff1b; 第二步&#xff1a;运行项目&#xff0c; 安装成功以后&#xff0c;会有个app.py文件&#xff0c;打开以后&#…

STM32CubeIDE基础学习-STM32CubeIDE软件程序下载方法

STM32CubeIDE基础学习-STM32CubeIDE软件代码下载方法 文章目录 STM32CubeIDE基础学习-STM32CubeIDE软件代码下载方法前言第1章 代码下载第2章 下载器固件更新总结 前言 编写完代码&#xff0c;一般都会选择在线下载程序的方式进行验证该程序是否正确&#xff0c;如果发现结果和…

回收站选址(CCF 201912-2)解题思路

分析 把x,y坐标拼接成一个字符串&#xff08;x,y&#xff09;作为Set的key&#xff0c;保存到Set中&#xff0c;遍历Set&#xff0c;取出坐标&#xff0c;然后判断上下左右四个点是否在Set中&#xff0c;如果在&#xff0c;进而判断&#xff0c;四个角是否在Set中&#xff0c;…

搜索引擎推广6种策略助你站在市场顶峰的有效方法-华媒舍

如何有效推广自己的产品或服务变得至关重要。6种引擎霸屏推广策略是帮助你站在市场顶峰的有效方法&#xff0c;下面将逐一介绍这些策略。 1.搜索引擎优化&#xff08;SEO&#xff09; 搜索引擎优化是提升网站在搜索引擎结果页&#xff08;SERP&#xff09;中排名的策略。通过…

股票价格预测项目

项目介绍 背景 股票价格预测一直是金融领域的热点问题。准确的预测可以帮助投资者作出更明智的决策。本项目旨在使用机器学习技术&#xff0c;特别是长短期记忆网络&#xff08;LSTM&#xff09;&#xff0c;来预测股票价格。 目标 开发一个基于LSTM的股票价格预测模型。使…

Mamba-minimal Mamba的最小限度实现 (二)

文章目录 链接导入所需包class ModelArgsclass Mambadef __ init __def forward class ResidualBlockclass RNSNorm文本生成demo manba的简单最小限度实现&#xff0c;和原始论文实现 state-spaces/mamba (github.com)相比&#xff0c;为了可读性对参数没有很好的初始化&#…

c++ 常用的STL

前言 写这篇博客目的是为了记录在刷算法题中使用过的STL&#xff0c;因为有些不太常用的会遗忘。这篇博客只是作为笔记&#xff0c;不是详细的STL&#xff0c;因此只会对常用方法说明&#xff0c;不会详细介绍。此外在后面用到新的STL内容时会再补充。 列队 基础列队 基本列…