网上一些描述LSTM文章看的云里雾里,只是介绍LSTM 的结构,并没有说明原理。我这里用通俗易懂的话来描述一下。
我们先来复习一些RNN的核心公式:
h
t
=
t
a
n
h
(
W
h
h
t
−
1
+
W
x
x
t
+
b
h
)
h_t = tanh(W_h h_{t-1} + W_x x_t + b_h)
ht=tanh(Whht−1+Wxxt+bh)
y
t
=
W
y
h
t
+
b
y
y_t = W_y h_t + b_y
yt=Wyht+by
我们可以注意到 最终输出依赖于x(当前输入) 和 之前的状态 H t H_t Ht, H t H_t Ht是上一个轮隐藏状态,包含了之前输入计算得到的信息, 作为记忆力给下一轮输出使用,这就造成了一个问题, 随着序列的增长, H t H_t Ht 包含的更早之前的信息就会越来越少,所以可以称 H t H_t Ht 为短期记忆。 LSTM 就是在rnn的基础上,结合短期记忆增加了长期记忆,所以叫长短期记忆。
我们来看LSTM 的结构图
这里我们依然可以看到
H
t
H_t
Ht 结构,这里依然是短期记忆。$$C_t 则是长期记忆。
另外值得注意的是σ代表sigmod 函数,sigmod将结果映射到0-1 空间。用于遗忘一些不重要的输入(也就是权重比较低的x),我们来看一下长期记忆和短期记忆是如何形成的,
-
遗忘门(Forget Gate)
- 作用:决定遗忘多少过去的信息
- 公式: f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
- 其中
σ
(sigmoid)被用于将输出压缩到[0,1]
,从而控制信息的遗忘程度。
-
输入门(Input Gate)
- 作用:决定当前时刻的新信息有多少被加入到细胞状态(Cell State)
- 公式:
- i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi) (输入门权重)
- C ~ t = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC) (候选信息)
- C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ft∗Ct−1+it∗C~t (更新长期记忆状态)
-
输出门(Output Gate)
- 作用:决定 LSTM 该输出多少信息
- 公式:
- o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
- h t = o t ∗ tanh ( C t ) h_t = o_t * \tanh(C_t) ht=ot∗tanh(Ct) (更新短期记忆)
这里我们可以清晰的看到,长短期记忆模型的核心在于使用sigmod来实现遗忘一些信息,使得 C t C_t Ct能保存更长序列的核心内容,也就是长期记忆。 由于 H t H_t Ht 直接受输出门(Output Gate)控制,它的值可能会随着时间快速变化,因此它更偏向短期信息的存储。
总结:
- H t H_t Ht 是 LSTM 每个时间步的输出,它会被传递到下一个时间步,也可以用于最终的预测。
- 由于 H t H_t Ht 直接受输出门(Output Gate)控制,它的值可能会随着时间快速变化,因此它更偏向短期信息的存储。
- 在某些情况下,( h_t ) 可能会丢失远程依赖信息,类似于 RNN 里的信息传递方式。
🔹 为什么 C t C_t Ct 代表长期记忆?
- C t C_t Ct 是 细胞状态(Cell State),它通过遗忘门(Forget Gate) 和 输入门(Input Gate) 来更新信息。
- 遗忘门可以选择保留一部分过去的信息,让 C t C_t Ct 可以跨多个时间步存储重要信息,而不会像 H t H_t Ht 那样频繁变化。
- 这样,LSTM 解决了普通 RNN 梯度消失 的问题,使得模型可以记住更长时间的依赖关系。