注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图
6.6 通过时间反向传播
在前面两节中,如果不裁剪梯度,模型将无法正常训练。为了深刻理解这一现象,本节将介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。
我们在3.14节(正向传播、反向传播和计算图)中介绍了神经网络中梯度计算与存储的一般思路,并强调正向传播和反向传播相互依赖。正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。
6.6.1 定义模型
简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。设时间步 t t t 的输入为单样本 x t ∈ R d \boldsymbol{x}_t \in \mathbb{R}^d xt∈Rd,标签为 y t y_t yt,那么隐藏状态 h t ∈ R h \boldsymbol{h}_t \in \mathbb{R}^h ht∈Rh的计算表达式为
h t = W h x x t + W h h h t − 1 , \boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht=Whxxt+Whhht−1,
其中 W h x ∈ R h × d \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} Whx∈Rh×d和 W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} Whh∈Rh×h是隐藏层权重参数。设输出层权重参数 W q h ∈ R q × h \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} Wqh∈Rq×h,时间步 t t t的输出层变量 o t ∈ R q \boldsymbol{o}_t \in \mathbb{R}^q ot∈Rq计算为
o t = W q h h t . \boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot=Wqhht.
设时间步 t t t的损失为 ℓ ( o t , y t ) \ell(\boldsymbol{o}_t, y_t) ℓ(ot,yt)。时间步数为 T T T的损失函数 L L L定义为
L = 1 T ∑ t = 1 T ℓ ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1t=1∑Tℓ(ot,yt).
我们将 L L L称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。
6.6.2 模型计算图
为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态 h 3 \boldsymbol{h}_3 h3的计算依赖模型参数 W h x \boldsymbol{W}_{hx} Whx、 W h h \boldsymbol{W}_{hh} Whh、上一时间步隐藏状态 h 2 \boldsymbol{h}_2 h2以及当前时间步输入 x 3 \boldsymbol{x}_3 x3。
6.6.3 方法
刚刚提到,图6.3中的模型的参数是
W
h
x
\boldsymbol{W}_{hx}
Whx,
W
h
h
\boldsymbol{W}_{hh}
Whh 和
W
q
h
\boldsymbol{W}_{qh}
Wqh。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度
∂
L
/
∂
W
h
x
\partial L/\partial \boldsymbol{W}_{hx}
∂L/∂Whx、
∂
L
/
∂
W
h
h
\partial L/\partial \boldsymbol{W}_{hh}
∂L/∂Whh和
∂
L
/
∂
W
q
h
\partial L/\partial \boldsymbol{W}_{qh}
∂L/∂Wqh。
根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们依然采用3.14节中表达链式法则的运算符prod。
首先,目标函数有关各时间步输出层变量的梯度 ∂ L / ∂ o t ∈ R q \partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q ∂L/∂ot∈Rq很容易计算:
∂ L ∂ o t = ∂ ℓ ( o t , y t ) T ⋅ ∂ o t . \frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}. ∂ot∂L=T⋅∂ot∂ℓ(ot,yt).
下面,我们可以计算目标函数有关模型参数 W q h \boldsymbol{W}_{qh} Wqh的梯度 ∂ L / ∂ W q h ∈ R q × h \partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} ∂L/∂Wqh∈Rq×h。根据图6.3, L L L通过 o 1 , … , o T \boldsymbol{o}_1, \ldots, \boldsymbol{o}_T o1,…,oT依赖 W q h \boldsymbol{W}_{qh} Wqh。依据链式法则,
∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ . \frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. ∂Wqh∂L=t=1∑Tprod(∂ot∂L,∂Wqh∂ot)=t=1∑T∂ot∂Lht⊤.
其次,我们注意到隐藏状态之间也存在依赖关系。
在图6.3中,
L
L
L只通过
o
T
\boldsymbol{o}_T
oT依赖最终时间步
T
T
T的隐藏状态
h
T
\boldsymbol{h}_T
hT。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度
∂
L
/
∂
h
T
∈
R
h
\partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h
∂L/∂hT∈Rh。依据链式法则,我们得到
∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . \frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. ∂hT∂L=prod(∂oT∂L,∂hT∂oT)=Wqh⊤∂oT∂L.
接下来对于时间步
t
<
T
t < T
t<T, 在图6.3中,
L
L
L通过
h
t
+
1
\boldsymbol{h}_{t+1}
ht+1和
o
t
\boldsymbol{o}_t
ot依赖
h
t
\boldsymbol{h}_t
ht。依据链式法则,
目标函数有关时间步
t
<
T
t < T
t<T的隐藏状态的梯度
∂
L
/
∂
h
t
∈
R
h
\partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h
∂L/∂ht∈Rh需要按照时间步从大到小依次计算:
∂
L
∂
h
t
=
prod
(
∂
L
∂
h
t
+
1
,
∂
h
t
+
1
∂
h
t
)
+
prod
(
∂
L
∂
o
t
,
∂
o
t
∂
h
t
)
=
W
h
h
⊤
∂
L
∂
h
t
+
1
+
W
q
h
⊤
∂
L
∂
o
t
\frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t}
∂ht∂L=prod(∂ht+1∂L,∂ht∂ht+1)+prod(∂ot∂L,∂ht∂ot)=Whh⊤∂ht+1∂L+Wqh⊤∂ot∂L
将上面的递归公式展开,对任意时间步 1 ≤ t ≤ T 1 \leq t \leq T 1≤t≤T,我们可以得到目标函数有关隐藏状态梯度的通项公式
∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . \frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. ∂ht∂L=i=t∑T(Whh⊤)T−iWqh⊤∂oT+t−i∂L.
由上式中的指数项可见,当时间步数
T
T
T 较大或者时间步
t
t
t 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含
∂
L
/
∂
h
t
\partial L / \partial \boldsymbol{h}_t
∂L/∂ht项的梯度,例如隐藏层中模型参数的梯度
∂
L
/
∂
W
h
x
∈
R
h
×
d
\partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}
∂L/∂Whx∈Rh×d和
∂
L
/
∂
W
h
h
∈
R
h
×
h
\partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}
∂L/∂Whh∈Rh×h。
在图6.3中,
L
L
L通过
h
1
,
…
,
h
T
\boldsymbol{h}_1, \ldots, \boldsymbol{h}_T
h1,…,hT依赖这些模型参数。
依据链式法则,我们有
∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ . \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} ∂Whx∂L∂Whh∂L=t=1∑Tprod(∂ht∂L,∂Whx∂ht)=t=1∑T∂ht∂Lxt⊤,=t=1∑Tprod(∂ht∂L,∂Whh∂ht)=t=1∑T∂ht∂Lht−1⊤.
我们已在3.14节里解释过,每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度
∂
L
/
∂
h
t
\partial L/\partial \boldsymbol{h}_t
∂L/∂ht被计算和存储,之后的模型参数梯度
∂
L
/
∂
W
h
x
\partial L/\partial \boldsymbol{W}_{hx}
∂L/∂Whx和
∂
L
/
∂
W
h
h
\partial L/\partial \boldsymbol{W}_{hh}
∂L/∂Whh的计算可以直接读取
∂
L
/
∂
h
t
\partial L/\partial \boldsymbol{h}_t
∂L/∂ht的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。
举例来说,参数梯度
∂
L
/
∂
W
h
h
\partial L/\partial \boldsymbol{W}_{hh}
∂L/∂Whh的计算需要依赖隐藏状态在时间步
t
=
0
,
…
,
T
−
1
t = 0, \ldots, T-1
t=0,…,T−1的当前值
h
t
\boldsymbol{h}_t
ht(
h
0
\boldsymbol{h}_0
h0是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。
小结
- 通过时间反向传播是反向传播在循环神经网络中的具体应用。
- 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。
注:本节与原书基本相同,原书传送门