深度强化学习(四)SARSA算法
一.SARSA
假设状态空间 S \mathcal{S} S 和动作空间 A \mathcal{A} A 都是有限集, 即集合中元素数量有限。比如, S \mathcal{S} S 中一共有 3 种状态, A \mathcal{A} A 中一共有 4 种动作。那么动作价值函数 Q π ( s , a ) Q_\pi(s, a) Qπ(s,a) 可以表示为一个 3 × 4 3 \times 4 3×4 的表格。该表格与一个策略函数 π ( a ∣ s ) \pi(a \mid s) π(a∣s) 相关联; 如果 π \pi π 发生变化,表格 Q π Q_\pi Qπ 也会发生变化。
我们用表格 q q q 近似 Q π Q_\pi Qπ 。首先初始化 q q q, 可以让它是全零的表格。然后用表格形式的 SARSA 算法更新 q q q,每次更新表格的一个元素。最终 q q q 收敛到 Q π Q_\pi Qπ 。
SARSA 算法由下面的贝尔曼方程推导出 :
Q
π
(
s
t
,
a
t
)
=
E
S
t
+
1
,
A
t
+
1
[
R
t
+
γ
⋅
Q
π
(
S
t
+
1
,
A
t
+
1
)
∣
S
t
=
s
t
,
A
t
=
a
t
]
Q_\pi\left(s_t, a_t\right)=\mathbb{E}_{S_{t+1}, A_{t+1}}\left[R_t+\gamma \cdot Q_\pi\left(S_{t+1}, A_{t+1}\right) \mid S_t=s_t, A_t=a_t\right]
Qπ(st,at)=ESt+1,At+1[Rt+γ⋅Qπ(St+1,At+1)∣St=st,At=at]
我们对贝尔曼方程左右两边做近似:
- 方程左边的 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at) 可以近似成 q ( s t , a t ) 。 q ( s t , a t ) q\left(s_t, a_t\right) 。 q\left(s_t, a_t\right) q(st,at)。q(st,at) 是表格在 t t t 时刻对 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at)做出的估计。
- 方程右边的期望是关于下一时刻状态
S
t
+
1
S_{t+1}
St+1 和动作
A
t
+
1
A_{t+1}
At+1 求的。给定当前状态
s
t
s_t
st, 智能体执行动作
a
t
a_t
at, 环境会给出奖励
r
t
r_t
rt 和新的状态
s
t
+
1
s_{t+1}
st+1 。然后基于
s
t
+
1
s_{t+1}
st+1 做随机抽样,得到新的动作
a ~ t + 1 ∼ π ( ⋅ ∣ s t + 1 ) . \tilde{a}_{t+1} \sim \pi\left(\cdot \mid s_{t+1}\right) . a~t+1∼π(⋅∣st+1).
用观测到的
r
t
、
s
t
+
1
r_t 、 s_{t+1}
rt、st+1 和计算出的
a
~
t
+
1
\tilde{a}_{t+1}
a~t+1 对期望做蒙特卡洛近似, 得到:
r
t
+
γ
⋅
Q
π
(
s
t
+
1
,
a
~
t
+
1
)
.
r_t+\gamma \cdot Q_\pi\left(s_{t+1}, \tilde{a}_{t+1}\right) .
rt+γ⋅Qπ(st+1,a~t+1).
- 进一步把公式 (5.1) 中的
Q
π
Q_\pi
Qπ 近似成
q
q
q, 得到
y ^ t ≜ r t + γ ⋅ q ( s t + 1 , a ~ t + 1 ) . \widehat{y}_t \triangleq r_t+\gamma \cdot q\left(s_{t+1}, \tilde{a}_{t+1}\right) . y t≜rt+γ⋅q(st+1,a~t+1).
把它称作 TD 目标。它是表格在
t
+
1
t+1
t+1 时刻对
Q
π
(
s
t
,
a
t
)
Q_\pi\left(s_t, a_t\right)
Qπ(st,at) 做出的估计。
q
(
s
t
,
a
t
)
q\left(s_t, a_t\right)
q(st,at) 和
y
^
t
\widehat{y}_t
y
t 都是对动作价值
Q
π
(
s
t
,
a
t
)
Q_\pi\left(s_t, a_t\right)
Qπ(st,at) 的估计。由于
y
^
t
\widehat{y}_t
y
t 部分基于真实观测到的奖励
r
t
r_t
rt,我们认为
y
^
t
\widehat{y}_t
y
t 是更可靠的估计, 所以鼓励
q
(
s
t
,
a
t
)
q\left(s_t, a_t\right)
q(st,at) 趋近
y
^
t
\widehat{y}_t
y
t 。更新表格
(
s
t
,
a
t
)
\left(s_t, a_t\right)
(st,at) 位置上的元素:
q
(
s
t
,
a
t
)
←
(
1
−
α
)
⋅
q
(
s
t
,
a
t
)
+
α
⋅
y
^
t
.
q\left(s_t, a_t\right) \leftarrow(1-\alpha) \cdot q\left(s_t, a_t\right)+\alpha \cdot \widehat{y}_t .
q(st,at)←(1−α)⋅q(st,at)+α⋅y
t.
这样可以使得 q ( s t , a t ) q\left(s_t, a_t\right) q(st,at) 更接近 y ^ t \widehat{y}_t y t 。 SARSA 算法用到了这个五元组: ( s t , a t , r t , s t + 1 , a ~ t + 1 ) \left(s_t, a_t, r_t, s_{t+1}, \tilde{a}_{t+1}\right) (st,at,rt,st+1,a~t+1) 。SARSA 算法学到的 q q q 依赖于策略 π \pi π, 这是因为五元组中的 a ~ t + 1 \tilde{a}_{t+1} a~t+1 是根据 π ( ⋅ ∣ s t + 1 ) \pi\left(\cdot \mid s_{t+1}\right) π(⋅∣st+1) 抽样得到的。
训练流程:设当前表格为 q now q_{\text {now }} qnow , 当前策略为 π now \pi_{\text {now }} πnow 每一轮更新表格中的一个元素,把更新之后的表格记作 q new q_{\text {new }} qnew 。
- 观测到当前状态 s t s_t st, 根据当前策略做抽样: a t ∼ π now ( ⋅ ∣ s t ) a_t \sim \pi_{\text {now }}\left(\cdot \mid s_t\right) at∼πnow (⋅∣st) 。
- 把表格
q
now
q_{\text {now }}
qnow 中第
(
s
t
,
a
t
)
\left(s_t, a_t\right)
(st,at) 位置上的元素记作:
q ^ t = q now ( s t , a t ) . \widehat{q}_t=q_{\text {now }}\left(s_t, a_t\right) . q t=qnow (st,at). - 智能体执行动作 a t a_t at 之后, 观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1 。
- 根据当前策略做抽样: a ~ t + 1 ∼ π now ( ⋅ ∣ s t + 1 ) \tilde{a}_{t+1} \sim \pi_{\text {now }}\left(\cdot \mid s_{t+1}\right) a~t+1∼πnow (⋅∣st+1) 。注意, a ~ t + 1 \tilde{a}_{t+1} a~t+1 只是假想的动作, 智能体不予执行。
- 把表格
q
now
q_{\text {now }}
qnow 中第
(
s
t
+
1
,
a
~
t
+
1
)
\left(s_{t+1}, \tilde{a}_{t+1}\right)
(st+1,a~t+1) 位置上的元素记作:
q ^ t + 1 = q now ( s t + 1 , a ~ t + 1 ) . \widehat{q}_{t+1}=q_{\text {now }}\left(s_{t+1}, \tilde{a}_{t+1}\right) . q t+1=qnow (st+1,a~t+1). - 计算 TD 目标和 TD 误差:
y ^ t = r t + γ ⋅ q ^ t + 1 , δ t = q ^ t − y ^ t . \widehat{y}_t=r_t+\gamma \cdot \widehat{q}_{t+1}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=rt+γ⋅q t+1,δt=q t−y t. - 更新表格中
(
s
t
,
a
t
)
\left(s_t, a_t\right)
(st,at) 位置上的元素:
q new ( s t , a t ) ← q now ( s t , a t ) − α ⋅ δ t . q_{\text {new }}\left(s_t, a_t\right) \leftarrow q_{\text {now }}\left(s_t, a_t\right)-\alpha \cdot \delta_t . qnew (st,at)←qnow (st,at)−α⋅δt. - 用某种算法更新策略函数。该算法与 SARSA 算法无关。
二.神经网络形式的SARSA
价值网络:如果状态空间
S
\mathcal{S}
S 是无限集, 那么我们无法用一张表格表示
Q
π
Q_\pi
Qπ, 否则表格的行数是无穷。一种可行的方案是用一个神经网络
q
(
s
,
a
;
w
)
q(s, a ; \boldsymbol{w})
q(s,a;w) 来近似
Q
π
(
s
,
a
)
Q_\pi(s, a)
Qπ(s,a); 理想情况下,
q
(
s
,
a
;
w
)
=
Q
π
(
s
,
a
)
,
∀
s
∈
S
,
a
∈
A
q(s, a ; \boldsymbol{w})=Q_\pi(s, a), \quad \forall s \in \mathcal{S}, a \in \mathcal{A}
q(s,a;w)=Qπ(s,a),∀s∈S,a∈A
训练流程 : 设当前价值网络的参数为
w
n
o
w
\boldsymbol{w}_{\mathrm{now}}
wnow, 当前策略为
π
n
o
w
\pi_{\mathrm{now}}
πnow 每一轮训练用五元组
(
s
t
,
a
t
,
r
t
,
s
t
+
1
,
a
~
t
+
1
)
\left(s_t, a_t, r_t, s_{t+1}, \tilde{a}_{t+1}\right)
(st,at,rt,st+1,a~t+1) 对价值网络参数做一次更新。
-
观测到当前状态 s t s_t st, 根据当前策略做抽样: a t ∼ π now ( ⋅ ∣ s t ) a_t \sim \pi_{\text {now }}\left(\cdot \mid s_t\right) at∼πnow (⋅∣st) 。
-
用价值网络计算 ( s t , a t ) \left(s_t, a_t\right) (st,at) 的价值:
q ^ t = q ( s t , a t ; w now ) . \widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . q t=q(st,at;wnow ). -
智能体执行动作 a t a_t at 之后, 观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1 。
-
根据当前策略做抽样: a ~ t + 1 ∼ π n o w ( ⋅ ∣ s t + 1 ) \tilde{a}_{t+1} \sim \pi_{\mathrm{now}}\left(\cdot \mid s_{t+1}\right) a~t+1∼πnow(⋅∣st+1) 。注意, a ~ t + 1 \tilde{a}_{t+1} a~t+1 只是假想的动作, 智能体不予执行。
-
用价值网络计算 ( s t + 1 , a ~ t + 1 ) \left(s_{t+1}, \tilde{a}_{t+1}\right) (st+1,a~t+1) 的价值:
q ^ t + 1 = q ( s t + 1 , a ~ t + 1 ; w now ) . \widehat{q}_{t+1}=q\left(s_{t+1}, \tilde{a}_{t+1} ; \boldsymbol{w}_{\text {now }}\right) . q t+1=q(st+1,a~t+1;wnow ). -
计算 TD 目标和 TD 误差:
y ^ t = r t + γ ⋅ q ^ t + 1 , δ t = q ^ t − y ^ t . \widehat{y}_t=r_t+\gamma \cdot \widehat{q}_{t+1}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=rt+γ⋅q t+1,δt=q t−y t. -
对价值网络 q q q 做反向传播, 计算 q q q 关于 w \boldsymbol{w} w 的梯度: ∇ w q ( s t , a t ; w now ) \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) ∇wq(st,at;wnow ) 。
-
更新价值网络参数:
w new ← w now − α ⋅ δ t ⋅ ∇ w q ( s t , a t ; w now ) . \boldsymbol{w}_{\text {new }} \leftarrow \boldsymbol{w}_{\text {now }}-\alpha \cdot \delta_t \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . wnew ←wnow −α⋅δt⋅∇wq(st,at;wnow ). -
用某种算法更新策略函数。该算法与 SARSA 算法无关。