1 Introduction
自回归模型随着gpt的出现取得很大的成功,还是有很多工程上的问题并不是很适合使用自回归模型:
1)自回归需要的算力太大,满足不了实时性要求:例如在自动驾驶的轨迹预测任务中,如果要用纯自回归的世界模型,耗时太大;
2)要求数据天然有时序性:很多图像任务并没有严格的序列生成的要求;
这个部分开始用隐变量的方式来进行建模。
2 特征提取和线形回归
自动驾驶和机器人中的很多任务,是通过感知的环境输入, 然后进行特征提取,最后用线形回归来预测和生成指令。
但是这种方式因为采用了非常简单的单高斯分布来估计指令,这个时候有几种提高的方式:
1)提高特征的表达能力:
1.1)如果特征提取的模型(一般是transformer)记忆力足够强大,哪怕后面接了单峰高斯估计也能有一个比较好的拟合效果;直觉来说,就是把所有的情况都记住了。高质量的特征能够在一定程度上“预处理”复杂性。
1.2)采用anchor based的query来生成不同的feature,降低拟合难度。
2)提高概率分布的表达能力:
2.1)采用混合高斯叠加的概率分布,来生成复杂的概率分布。
因为提高特征表达能力往往是多模态相关的工作,这里我门进行跳过,更加关注通过提高概率分布的表达能力这个方面。
3 vae 模型
z是隐变量,需要用模型构建z存在的情况下
p
(
x
∣
z
,
θ
)
p(x|z,\theta)
p(x∣z,θ)的概率。
按照note2的内容,loss设计的时候,满足极大似然就可以
l
o
g
P
θ
(
x
)
logP_{\theta}(x)
logPθ(x)。
现在的问题是,每种z都有一定的概率能生成x。可以采用普查的方式,或者采用抽样的方式
l o g P θ ( x ) = 1 D ( z ) ∑ z ∈ D P ( x , z ; θ ) \begin{aligned} logP_{\theta}(x)=\frac{1}{D(z)}\sum_{z \in D}P(x,z;\theta) \end{aligned} logPθ(x)=D(z)1z∈D∑P(x,z;θ)
因为z本身是连续分布,如果采用普查的方式也不太现实,不过我们可以随便搞一个抽样方式q(z)。
l o g P θ ( x ) = l o g ∑ j = 1 k q ( z ( j ) ) q ( z ( j ) ) P ( x , z ; θ ) = l o g E x − q ( z ) P ( x , z ; θ ) q ( z ( j ) ) \begin{aligned} logP_{\theta}(x) & = log\sum_{j=1}^k \frac{q(z^{(j)})}{q(z^{(j)})} P(x,z;\theta) \\ & = logE_{x-q(z)}\frac{P(x,z;\theta)}{q(z^{(j)})} \end{aligned} logPθ(x)=logj=1∑kq(z(j))q(z(j))P(x,z;θ)=logEx−q(z)q(z(j))P(x,z;θ)
对于log这种凸函数,满足
l
o
g
E
[
x
]
>
E
[
l
o
g
(
x
)
]
logE[x]>E[log(x)]
logE[x]>E[log(x)],可以对上面这个式子进行变换
l
o
g
P
θ
(
x
)
=
l
o
g
E
x
−
q
(
z
)
P
(
x
,
z
;
θ
)
q
(
z
(
j
)
)
≥
E
x
−
q
(
z
)
l
o
g
P
(
x
,
z
;
θ
)
q
(
z
(
j
)
)
=
∑
j
=
1
k
q
(
z
(
j
)
)
l
o
g
P
(
x
,
z
;
θ
)
q
(
z
(
j
)
)
=
∑
j
=
1
k
(
q
(
z
(
j
)
)
l
o
g
P
(
x
,
z
;
θ
)
−
q
(
z
(
j
)
)
l
o
g
q
(
z
(
j
)
)
)
=
E
L
B
O
\begin{aligned} logP_{\theta}(x) & = logE_{x-q(z)}\frac{P(x,z;\theta)}{q(z^{(j)})} \\ & \ge E_{x-q(z)}log\frac{P(x,z;\theta)}{q(z^{(j)})} \\ & = \sum_{j=1}^kq(z^{(j)})log\frac{P(x,z;\theta)}{q(z^{(j)})} \\ & = \sum_{j=1}^k(q(z^{(j)})logP(x,z;\theta)-q(z^{(j)})logq(z^{(j)}))=ELBO \end{aligned}
logPθ(x)=logEx−q(z)q(z(j))P(x,z;θ)≥Ex−q(z)logq(z(j))P(x,z;θ)=j=1∑kq(z(j))logq(z(j))P(x,z;θ)=j=1∑k(q(z(j))logP(x,z;θ)−q(z(j))logq(z(j)))=ELBO
我们可以来研究一下随机的抽样概率分布q(z|x)和p(z|x;\theta)的概率分布的差异。
D
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
;
θ
)
)
=
∑
z
q
(
z
∣
x
)
l
o
g
q
(
z
∣
x
)
p
(
z
∣
x
;
θ
)
=
∑
z
q
(
z
∣
x
)
l
o
g
q
(
z
∣
x
)
p
(
z
,
x
;
θ
)
/
p
(
x
;
θ
)
=
∑
z
(
q
(
z
∣
x
)
l
o
g
q
(
z
∣
x
)
+
q
(
z
∣
x
)
l
o
g
p
(
x
;
θ
)
−
q
(
z
∣
x
)
l
o
g
p
(
x
,
z
;
θ
)
)
=
∑
z
q
(
z
∣
x
)
l
o
g
p
(
x
;
θ
)
−
∑
z
(
q
(
z
∣
x
)
l
o
g
q
(
z
∣
x
)
−
q
(
z
∣
x
)
l
o
g
p
(
x
,
z
;
θ
)
)
=
l
o
g
p
(
x
;
θ
)
−
∑
j
=
1
k
(
q
(
z
(
j
)
)
l
o
g
P
(
x
,
z
;
θ
)
−
q
(
z
(
j
)
)
l
o
g
q
(
z
(
j
)
)
)
\begin{aligned} D_{KL}(q(z)||p(z|x;\theta))&=\sum_z q(z|x)log\frac{q(z|x)}{p(z|x;\theta)} \\ & = \sum_z q(z|x)log\frac{q(z|x)}{p(z,x;\theta)/p(x;\theta)} \\ & = \sum_z (q(z|x)logq(z|x)+q(z|x)logp(x;\theta)-q(z|x)logp(x,z;\theta)) \\ & = \sum_z q(z|x)logp(x;\theta) - \sum_z (q(z|x)logq(z|x)-q(z|x)logp(x,z;\theta)) \\ & = logp(x;\theta)- \sum_{j=1}^k(q(z^{(j)})logP(x,z;\theta)-q(z^{(j)})logq(z^{(j)})) \end{aligned}
DKL(q(z)∣∣p(z∣x;θ))=z∑q(z∣x)logp(z∣x;θ)q(z∣x)=z∑q(z∣x)logp(z,x;θ)/p(x;θ)q(z∣x)=z∑(q(z∣x)logq(z∣x)+q(z∣x)logp(x;θ)−q(z∣x)logp(x,z;θ))=z∑q(z∣x)logp(x;θ)−z∑(q(z∣x)logq(z∣x)−q(z∣x)logp(x,z;θ))=logp(x;θ)−j=1∑k(q(z(j))logP(x,z;θ)−q(z(j))logq(z(j)))
这个式子说明了
l
o
g
p
(
x
;
θ
)
−
D
K
L
(
p
(
z
∣
x
;
θ
)
∣
∣
q
(
z
)
)
=
E
L
B
O
logp(x;\theta)-D_{KL}(p(z|x;\theta)||q(z))=ELBO
logp(x;θ)−DKL(p(z∣x;θ)∣∣q(z))=ELBO
现在我们来更新一下极大似然
l
o
g
P
θ
(
x
)
=
∑
z
(
q
(
z
∣
x
)
l
o
g
P
(
x
,
z
;
θ
)
−
q
(
z
∣
x
)
l
o
g
q
(
z
∣
x
)
)
−
D
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
;
θ
)
)
\begin{aligned} logP_{\theta}(x) & = \sum_{z}(q(z|x)logP(x,z;\theta)-q(z|x)logq(z|x)) - D_{KL}(q(z|x)||p(z|x;\theta))\\ \end{aligned}
logPθ(x)=z∑(q(z∣x)logP(x,z;θ)−q(z∣x)logq(z∣x))−DKL(q(z∣x)∣∣p(z∣x;θ))
现在核心的问题来了,
p
(
z
∣
x
;
θ
)
p(z|x;\theta)
p(z∣x;θ)的真值是什么,没有人知道!!!这意味着,我们无法教模型去学这个分布。
这里涉及了一些先验和后验的基本概念。
我们的先验知识是q(z|x)应该是标准正态分布
N
(
0
,
1
)
\mathcal{N}(0,1)
N(0,1),但是实际的q(z|x)没人知道。
q(z|x)既然我们知道是标准正态分布,那么我们就可以用模型来学习,可以表示成
q
(
z
∣
x
,
ϕ
)
q(z|x,\phi)
q(z∣x,ϕ)。
那么
D
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
;
θ
)
)
D_{KL}(q(z|x)||p(z|x;\theta))
DKL(q(z∣x)∣∣p(z∣x;θ))的结果,我们祈祷数据的后验分布是真的符合高斯分布。也就是一般vae计算极大似然的时候只需要考虑ELBO.
对于ELBO,这里只能采用mento carlo的方式进行采样计算
z
(
k
)
=
μ
ϕ
(
x
)
+
σ
ϕ
(
x
)
ϵ
,
ϵ
∼
N
(
0
,
1
)
z^{(k)}=\mu_{\phi}(x)+\sigma_{\phi}(x)ϵ, ϵ \sim \mathcal{N}(0, 1)
z(k)=μϕ(x)+σϕ(x)ϵ,ϵ∼N(0,1)
那么极大似然可以更新成,现在还有个问题,vae的decoder只能给出p(x|z)的概率,所以我们还需要继续修改。
E
L
B
O
=
1
K
∑
k
l
o
g
P
(
x
,
z
(
k
)
;
θ
)
−
l
o
g
q
(
z
(
k
)
∣
x
;
ϕ
)
)
=
1
K
∑
k
l
o
g
P
(
x
∣
z
(
k
)
,
θ
)
+
l
o
g
P
(
z
(
k
)
)
−
l
o
g
q
(
z
(
k
)
∣
x
,
ϕ
)
=
1
K
∑
k
l
o
g
P
(
x
∣
z
(
k
)
,
θ
)
−
D
K
L
(
q
(
z
∣
x
,
ϕ
)
∣
∣
P
(
z
)
)
\begin{aligned} ELBO & =\frac{1}{K} \sum_{k}logP(x,z^{(k)};\theta)-logq(z^{(k)}|x;\phi))\\ & = \frac{1}{K}\sum_k logP(x|z^{(k)},\theta)+logP(z^{(k)})-logq(z^{(k)}|x,\phi) \\ & = \frac{1}{K}\sum_k logP(x|z^{(k)},\theta)-D_{KL}(q(z|x,\phi)||P(z)) \end{aligned}
ELBO=K1k∑logP(x,z(k);θ)−logq(z(k)∣x;ϕ))=K1k∑logP(x∣z(k),θ)+logP(z(k))−logq(z(k)∣x,ϕ)=K1k∑logP(x∣z(k),θ)−DKL(q(z∣x,ϕ)∣∣P(z))
P(z)是标准正态分布,KL散度可以积分直接得到解析解,这里直接给出公式的结果
D
K
L
(
q
ϕ
(
z
∣
x
)
∣
∣
P
(
z
)
)
=
D
K
L
(
N
(
μ
,
σ
)
∣
∣
N
(
0
,
1
)
)
=
1
2
∑
i
(
σ
i
2
+
μ
i
2
−
1
−
l
n
σ
i
2
)
\begin{aligned} D_{KL}(q_{\phi}(z|x)||P(z)) & = D_{KL}(\mathcal{N}(\mu, \sigma)||\mathcal{N}(0, 1)) \\ & = \frac{1}{2}\sum_i(\sigma_i^2+\mu_i^2-1-ln\sigma_i^2) \end{aligned}
DKL(qϕ(z∣x)∣∣P(z))=DKL(N(μ,σ)∣∣N(0,1))=21i∑(σi2+μi2−1−lnσi2)
4 cvae实例分析
aloha作为机器人模仿学习的重要的一项工作[1],在他们的工作中使用了cvae,让我们来看一下它是如何设计的。
4.1 encoder
可以看到 q ( z ∣ x , ϕ ) q(z|x,\phi) q(z∣x,ϕ) 用的是transformer的cls query, 再经过一个mlp,得到 μ , 和 σ \mu,和\sigma μ,和σ
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
4.2 decoder
decoder 这里采用的是标准的detr,先将所有的conditional的信息融合latent_feature先经过transformer encoder,然后用一个query 过一个transformer decoder,得到feature, 然后直接出action,这里输出没有用高斯分布。
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
4.3 loss
loss和极大似然是相反的,所以在
l
o
s
s
=
1
K
∑
k
−
l
o
g
P
(
x
∣
z
(
k
)
,
θ
)
+
D
K
L
(
q
(
z
∣
x
,
ϕ
)
∣
∣
P
(
z
)
)
loss=\frac{1}{K}\sum_k -logP(x|z^{(k)},\theta)+D_{KL}(q(z|x,\phi)||P(z))
loss=K1k∑−logP(x∣z(k),θ)+DKL(q(z∣x,ϕ)∣∣P(z))
也就是说在loss中直接计算
D
K
L
(
q
(
z
∣
x
,
ϕ
)
∣
∣
P
(
z
)
)
D_{KL}(q(z|x,\phi)||P(z))
DKL(q(z∣x,ϕ)∣∣P(z))
对应代码
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
正常来说,
∑
k
−
l
o
g
P
(
x
∣
z
(
k
)
,
θ
)
=
∑
k
−
l
o
g
1
2
π
σ
e
−
(
x
−
μ
)
2
/
(
2
σ
2
)
=
∑
k
(
x
−
μ
)
2
/
(
2
σ
2
)
\begin{aligned} \sum_k-logP(x|z^{(k)},\theta) &=\sum_k-log\frac{1}{\sqrt{2\pi}\sigma}e^{-(x-\mu)^2/(2\sigma^2)}\\ &=\sum_k(x-\mu)^2/(2\sigma^2) \end{aligned}
k∑−logP(x∣z(k),θ)=k∑−log2πσ1e−(x−μ)2/(2σ2)=k∑(x−μ)2/(2σ2)
这里act policy因为vae decoder没有输出
σ
\sigma
σ,默认就是1了。
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict['l1'] = l1
loss_dict['kl'] = total_kld[0]
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
4.4 vae decoder 的重参数化
act policy对于z只采样了一次,没有多重采样。
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
References
[1] https://tonyzhaozh.github.io/aloha/