ICML2024
paper
code
Intro
基于状态表征的model-based强化学习方法一般需要学习状态转移模型以及奖励模型。现有方法都是将二者联合训练但普遍缺乏对如何平衡二者之间的比重进行研究。本文提出的HarmonyDream便是通过自动调整损失系数来维持任务间的和谐,即在世界模型学习过程中保持观测状态建模和奖励建模之间的动态平衡。
Method
算法基于DreamV2的世界模型架构:
- Representation model: z t ∼ q θ ( z t ∣ z t − 1 , a t − 1 , o t ) z_{t}\sim q_{\theta }( z_{t}\mid z_{t- 1}, a_{t- 1}, o_{t}) zt∼qθ(zt∣zt−1,at−1,ot)
- Observation model: o ^ t ∼ p θ ( o ^ t ∣ z t ) \hat{o}_t\sim p_\theta(\hat{o}_t\mid z_t) o^t∼pθ(o^t∣zt)
- Transition model: z ^ t ∼ p θ ( z ^ t ∣ z t − 1 , a t − 1 ) \hat{z} _t\sim p_\theta ( \hat{z} _t\mid z_{t- 1}, a_{t- 1}) z^t∼pθ(z^t∣zt−1,at−1)
- Reward model: r ^ t ∼ p θ ( r ^ t ∣ z t ) . \hat{r}_t\sim p_\theta\left(\hat{r}_t\mid z_t\right). r^t∼pθ(r^t∣zt).
所有参数通过联合训练以下三个损失函数
- Observation loss: L o ( θ ) = − log p θ ( o t ∣ z t ) \mathcal{L}_o(\theta)=-\log p_\theta(o_t\mid z_t) Lo(θ)=−logpθ(ot∣zt)
- Reward loss: L r ( θ ) = − log p θ ( r t ∣ z t ) \mathcal{L}_{r}(\theta)=-\log p_{\theta}(r_{t}\mid z_{t}) Lr(θ)=−logpθ(rt∣zt)
- Dynamics loss: L d ( θ ) = K L [ q θ ( z t ∣ z t − 1 , a t − 1 , o t ) ∥ p θ ( z ^ t ∣ z t − 1 , a t − 1 ) ] \mathcal{L}_{d}(\theta)=KL[q_{\theta}(z_{t}\mid z_{t-1},a_{t-1},o_{t})\parallel p_{\theta}(\hat{z}_{t}\mid z_{t-1},a_{t-1})] Ld(θ)=KL[qθ(zt∣zt−1,at−1,ot)∥pθ(z^t∣zt−1,at−1)],
对三种目标加权后便是最终优化目标
L
(
θ
)
=
w
o
L
o
(
θ
)
+
w
r
L
r
(
θ
)
+
w
d
L
d
(
θ
)
.
\mathcal{L}(\theta)=w_o\mathcal{L}_o(\theta)+w_r\mathcal{L}_r(\theta)+w_d\mathcal{L}_d(\theta).
L(θ)=woLo(θ)+wrLr(θ)+wdLd(θ).
HarmonyDream提出动态加权方法,
L
(
θ
,
σ
o
,
σ
r
,
σ
d
)
=
∑
i
∈
{
o
,
r
,
d
}
H
^
(
L
i
(
θ
)
,
σ
i
)
(5)
=
∑
i
∈
{
o
,
r
,
d
}
1
σ
i
L
i
(
θ
)
+
log
(
1
+
σ
i
)
.
\begin{aligned} \mathcal{L}(\theta,\sigma_{o},\sigma_{r},\sigma_{d})& =\sum_{i\in\{o,r,d\}}\hat{\mathcal{H}}(\mathcal{L}_{i}(\theta),\sigma_{i}) \\ &&\text{(5)} \\ &=\sum_{i\in\{o,r,d\}}\frac{1}{\sigma_{i}}\mathcal{L}_{i}(\theta)+\log{(1+\sigma_{i})}. \end{aligned}
L(θ,σo,σr,σd)=i∈{o,r,d}∑H^(Li(θ),σi)=i∈{o,r,d}∑σi1Li(θ)+log(1+σi).(5)
其中
σ
i
\sigma_i
σi由
σ
i
=
exp
(
s
i
)
>
0
\sigma_i=\exp(s_i)>0
σi=exp(si)>0表示,源码中
s
s
s为一个可梯度回传的参数且初始化为0。
self.harmony_s1 = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) #reward
self.harmony_s2 = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) # image
self.harmony_s3 = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) # kl
结果
相较于DreamerV2提升明显。结合DreamerV3的效果也很好。