【VAE-base】VAE最简单代码实现(纯全连接层实现变分自编码机)

VAE (Variational Autoencoder)

代码:https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
论文:Auto-Encoding Variational Bayes
核心参考1 https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb
工程化不错,但是vae的loss可能不对:https://blog.csdn.net/lsb2002/article/details/134837076

文章目录

  • VAE (`V`ariational `A`uto`e`ncoder)
  • 一、VAE直观理解 (全连接层实现,仅用作实验说明)
    • 1.1 直接推理生成图片 (加载训练好vae模型,随机生成`只使用解码器`)
      • 1.1.1 由随机正态分布直接生成
        • 可视化生成的6张手写字符
      • 1.1.2 完整生成图片代码
    • 1.2 `训练`中模型推理的逻辑 (编码器产生中间变量输入到解码器)
      • 1.2.0 正态分布的`性质`* (补充知识,可跳过)
      • 1.2.1 loss计算函数
        • 为什么2个损失函数,如何推导的? (见后文,先及结论通读)
      • 1.2.2 对应VAE的前向代码
      • 1.2.3 sampling函数 (重参数技巧)
        • 1.2.3.1 sampling函数为什么这样写?推导
      • 1.2.4 vae完整训练代码
        • 输出 (为演示只训练3个,可以多训练几个)
  • 二、VAE原理探究和简易推导
    • 2.1 vae模型训练架构图 (训练目的:重构出训练样本+解码器输入趋近于正态分布)
      • Minimize部分:
    • 2.2 VAE 中的 KL 散度(KL Divergence)从零推导(最大化对数似然函数)
    • 2.3 KL 散度公式的进一步化简 (得到loss函数)!
      • 对应实际实现中的 KL 散度
    • VAE重构损失

一、VAE直观理解 (全连接层实现,仅用作实验说明)

基于手写字符生成案例说明
0. 纯线性层,nn.liner实现(每个神经元连接到上一层的所有神经元)
2. 如何直接使用别人训练好的模型:预训练模型+模型
3. 如何训练vae?

1.1 直接推理生成图片 (加载训练好vae模型,随机生成只使用解码器

随机生成多元的正态分布z, 输入到vae的解码器 (decoder),直接得到结构

def load_pretain_vae(point_path='./checkpoint_8.pth'):
    # 加载预训练的vae
    z_gauss_dims=2
    # 下面是训练时的参数
    vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=z_gauss_dims)
    vae.cuda()  # 先移到CUDA上
    vae.load_state_dict(torch.load(point_path, map_location=torch.device('cuda')))
    print('load pretrain vae')
    sample_num=6
    # 输入一个多元正态分布(训练时候的维度)
    z = torch.randn(sample_num, z_gauss_dims).cuda()
    print('z.shape and z is ',z.shape,z)
    sample = vae.decoder(z).cuda()
    # pytorchg官方保存  n c  h w
    # nrow 表示每行的显示数量(相当于列数)
    save_image(sample.view(sample_num, 1, 28, 28), f'./load_pretain_vae' + '.png',nrow=2)

1.1.1 由随机正态分布直接生成

生成6张图片,输入是随机的z (shape 为 (6,2)), 然后进入到vae的解码器 (decoder)
输出为 (6,764)
在这里插入图片描述

可视化生成的6张手写字符

在这里插入图片描述

1.1.2 完整生成图片代码

'''
from: https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb
edit: zengxy+ gpt4o  2024.05.27
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import os



# z_dim 一般维度为2 
#
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        
        # encoder part
        # 神经元中间层的个数与维度数
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
    
    # 训练好后,用作生成还原图像
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        print('decoder  h1.shape is ',h.shape)
        h = F.relu(self.fc5(h))
        print('decoder  h2.shape is ',h.shape)
        return F.sigmoid(self.fc6(h)) 
    
    # 训练时执行逻辑
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
    
    


def load_pretain_vae(point_path='./checkpoint_8.pth'):
    # 加载预训练的vae
    z_gauss_dims=2
    # 下面是训练时的参数
    vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=z_gauss_dims)
    vae.cuda()  # 先移到CUDA上
    vae.load_state_dict(torch.load(point_path, map_location=torch.device('cuda')))
    print('load pretrain vae')
    sample_num=6
    # 输入一个多元正态分布(训练时候的维度)
    z = torch.randn(sample_num, z_gauss_dims).cuda()
    print('z.shape and z is ',z.shape,z)
    sample = vae.decoder(z).cuda()
    # pytorchg官方保存  n c  h w
    save_image(sample.view(sample_num, 1, 28, 28), f'./load_pretain_vae' + '.png',nrow=2)



if __name__ == '__main__':
    load_pretain_vae()

1.2 训练中模型推理的逻辑 (编码器产生中间变量输入到解码器)

训练图片X经过编码器输出,经过重采样,得到Z,z经过编码器得到X^,
然后根据公式计算loss,反向迭代参数

在这里插入图片描述

图来自:抛开数学,轻松学懂 VAE(附 PyTorch 实现) - 周弈帆的文章 -
知乎:

输入训练数据data,
经过vae的前向推理流程:recon_batch, mu, log_var = self.model(data)
然后计算loss:loss = self.model.loss_function(recon_batch, data, mu, log_var)

sim前面有个负号的原因: 在实际实现中,KL
散度是一个非负值。为了使其在损失函数中发挥正则化作用,我们通常会加上负号,从而在整体损失中减去这个值,确保模型在训练过程中不仅关注重构误差,还关注潜在分布的正则化。

	# 局部代码 仅用作说明
    def train(self, epoch,save_image_name="train_sample"):
        """
        训练过程。
        :param epoch: 当前的训练轮数。
        """
        self.model.train() #将导入的模型输入到训练模式
        self.load_minist()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(self.train_loader):
            data = data.cuda()
            self.optimizer.zero_grad()
            recon_batch, mu, log_var = self.model(data)
            loss = self.model.loss_function(recon_batch, data, mu, log_var)
            loss.backward()
            train_loss += loss.item()
            self.optimizer.step()

1.2.0 正态分布的性质* (补充知识,可跳过)

根据正态分布的线性变换性质,我们可以推导出 ( z ) 的分布。具体来说,正态分布的线性变换性质包括以下几个方面:

  1. 线性组合的性质

    • 如果 X ∼ N ( μ X , σ X 2 ) X \sim \mathcal{N}(\mu_X, \sigma_X^2) XN(μX,σX2)
    • Y ∼ N ( μ Y , σ Y 2 ) Y \sim \mathcal{N}(\mu_Y, \sigma_Y^2) YN(μY,σY2)
      那么 a X + b Y ∼ N ( a μ X + b μ Y , a 2 σ X 2 + b 2 σ Y 2 ) aX + bY \sim \mathcal{N}(a\mu_X + b\mu_Y, a^2\sigma_X^2 + b^2\sigma_Y^2) aX+bYN(aμX+bμY,a2σX2+b2σY2)
  2. 加权和的性质

    • 如果 X ∼ N ( μ , σ 2 ) X \sim \mathcal{N}(\mu, \sigma^2) XN(μ,σ2),a和 b是常数,那么 a X + b ∼ N ( a μ + b , a 2 σ 2 ) aX + b \sim \mathcal{N}(a\mu + b, a^2\sigma^2) aX+bN(aμ+b,a2σ2)

1.2.1 loss计算函数

重构损失函数:二值交叉熵损失(Binary Cross Entropy Loss)
KL 散度(KL Divergence)

    def loss_function(self,recon_x, x, mu, log_var):
        """
        VAE的损失函数,包括重构损失和KL散度(KLD:Kullback-Leibler divergence)。
        :param recon_x: 重构的数据。
        :param x: 原始数据。
        :param mu: 编码的均值。
        :param log_var: 编码的对数方差。
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD
为什么2个损失函数,如何推导的? (见后文,先及结论通读)

1.2.2 对应VAE的前向代码

输入训练数据x,
编码器的输出 mu, log_var用于计算z
输出最终的结果 self.decoder(z),

这里的log_var 表示 由网络输出的对数方差,等效于 l o g ( σ 2 ) log(\sigma^2) log(σ2)

# class VAE
    def forward(self, x):
        """
        VAE模型的前向传递。
        :param x: 输入的训练数据数据。
        """
        mu, log_var = self.encoder(x.view(-1, 784))  # 将图像展平处理
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
	

1.2.3 sampling函数 (重参数技巧)

返回值为Z z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2) zN(μ,σ2)
下面的函数实现,等效下面函数实现,又叫重参数技巧

    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        # 用于生成与给定张量具有相同形状和类型的随机数张量,其元素值遵循标准正态分布(均值为0,标准差为1)。
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

可以用下面的公式表示 sampling 函数的操作:
z = mu + ϵ × exp ⁡ ( log ⁡ ( σ 2 ) 2 ) z = \text{mu} + \epsilon \times \exp\left(\frac{\log(\sigma^2)}{2}\right) z=mu+ϵ×exp(2log(σ2))

这里的log_var 表示 由网络输出的对数方差,等效于 l o g ( σ 2 ) log(\sigma^2) log(σ2)
其中 torch.exp(0.5*log_var) = σ \sigma σ 表示为 std
使用 torch.randn_like(std) 生成一个与标准差 std 形状相同、服从标准正态分布(均值 0,方差 1)的随机噪声向量 ϵ \epsilon ϵ

1.2.3.1 sampling函数为什么这样写?推导
  1. ϵ \epsilon ϵ 是从标准正态分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1) 中采样的。
  2. σ \sigma σ 是标准差,因此 ϵ × σ \epsilon \times \sigma ϵ×σ是从标准差为 σ \sigma σ 的正态分布中采样的。

根据正态分布的性质,如果 ϵ ∼ N ( 0 , 1 ) \epsilon \sim \mathcal{N}(0, 1) ϵN(0,1),则 ϵ × σ ∼ N ( 0 , σ 2 ) \epsilon \times \sigma \sim \mathcal{N}(0, \sigma^2) ϵ×σN(0,σ2)

  1. 加上均值 μ \mu μ 后,根据正态分布的线性变换性质:

z = μ + ϵ × σ ∼ N ( μ , σ 2 ) z = \mu + \epsilon \times \sigma \sim \mathcal{N}(\mu, \sigma^2) z=μ+ϵ×σN(μ,σ2)

在这里插入图片描述

1.2.4 vae完整训练代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
'''
from: https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb
edit: zengxy+ gpt4o  2024.05.27
'''
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        """
        VAE模型的初始化。
        :param x_dim: 输入数据的维度。
        :param h_dim1: 第一个隐藏层的维度。
        :param h_dim2: 第二个隐藏层的维度。
        :param z_dim: 潜在变量的维度。
        """
        super(VAE, self).__init__()
        # 编码器部分
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # 生成均值
        self.fc32 = nn.Linear(h_dim2, z_dim)  # 生成对数方差

        # 解码器部分
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """
        编码器功能,用于生成均值和对数方差。
        :param x: 输入数据。
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)
    
    def sampling(self, mu, log_var):
        """
        通过重新参数化技巧从标准正态分布中采样。
        :param mu: 生成的均值。
        :param log_var: 生成的对数方差。
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
    
    def decoder(self, z):
        """
        解码器功能,用于从潜在空间重构输入数据。
        :param z: 潜在变量。
        """
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))
    
    def forward(self, x):
        """
        VAE模型的前向传递。
        :param x: 输入的训练数据数据。
        """
        mu, log_var = self.encoder(x.view(-1, 784))  # 将图像展平处理
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
    

    # @staticmethod
    # 这意味着该方法不依赖于类的实例,也不修改类的状态。它是类的一部分,但是可以在没有创建类的实例的情况下调用,且不需要self参
    def loss_function(self,recon_x, x, mu, log_var):
        """
        VAE的损失函数,包括重构损失和KL散度(KLD:Kullback-Leibler divergence)。
        :param recon_x: 重构的数据。
        :param x: 原始数据。
        :param mu: 编码的均值。
        :param log_var: 编码的对数方差。
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD

class VAE_TrainMnist:
    def __init__(self, model, optimizer,local_data_path='./data/'):
        """
        训练和测试VAE模型的类。
        :param model: VAE模型。
        :param optimizer: 优化器。
        : local_data_path : 本地数据路径
        """
        self.model = model
        self.optimizer = optimizer
        self.local_data_path = local_data_path
        self.batch_size = 128


        # 加载数据集
    def load_minist(self):
        transform = transforms.ToTensor()
        train_dataset = datasets.MNIST(root=self.local_data_path, train=True, transform=transform, download=True)
        test_dataset = datasets.MNIST(root=self.local_data_path, train=False, transform=transform)
        print('len(train_dataset)',len(train_dataset))
        print('len(test_dataset)',len(test_dataset))
        self.train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False)
        print("数据集加载完成!")
    
    def train(self, epoch,save_image_name="train_sample"):
        """
        训练过程。
        :param epoch: 当前的训练轮数。
        """
        self.model.train() #将导入的模型输入到训练模式
        self.load_minist()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(self.train_loader):
            data = data.cuda()
            self.optimizer.zero_grad()
            recon_batch, mu, log_var = self.model(data)
            if batch_idx==0:
               
                print("只在第一次训练时打印,其他时候不打印,防止信息冗余")
                print('输入的训练数据:data.tensor.size',data.size())
                print("输出recon_batch.shape",recon_batch.shape)
                print("输出mu.shape",mu.shape)
                print("输出log_var.shape",log_var.shape)
            loss = self.model.loss_function(recon_batch, data, mu, log_var)
            loss.backward()
            train_loss += loss.item()
            self.optimizer.step()
            if batch_idx % 100 == 0:
                print(f'训练轮次: {epoch} [{batch_idx * len(data)}/{len(self.train_loader.dataset)} '
                      f'({100. * batch_idx / len(self.train_loader):.0f}%)]\t损失: {loss.item() / len(data):.6f}')
                
                save_image(recon_batch.view(self.batch_size, 1, 28, 28), f'./{save_image_name}_{epoch}.png')
                torch.save(self.model.state_dict(), f'checkpoint_{epoch}.pth')
        print(f'====> 训练轮次: {epoch} 平均损失: {train_loss / len(self.train_loader.dataset):.4f}')

    def test(self):
        """
        测试过程。
        """
        self.model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, _ in self.test_loader:
                data = data.cuda()
                recon, mu, log_var = self.model(data)
                test_loss += self.model.loss_function(recon, data, mu, log_var).item()
        test_loss /= len(self.test_loader.dataset)
        print(f'====> 测试集损失: {test_loss:.4f}')
    


if __name__ == '__main__':
    vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
    vae.cuda()
    optimizer = optim.Adam(vae.parameters(), lr=1e-3)
    trainer = VAE_TrainMnist(vae, optimizer,local_data_path='./data/')
    # trainer.load_minist() # 
    for epoch in range(1, 3):
        trainer.train(epoch,save_image_name="train_sample")
        trainer.test()

输出 (为演示只训练3个,可以多训练几个)

len(train_dataset) 60000
len(test_dataset) 10000
数据集加载完成!
只在第一次训练时打印,其他时候不打印,防止信息冗余
输入的训练数据:data.tensor.size torch.Size([128, 1, 28, 28])
输出recon_batch.shape torch.Size([128, 784])
输出mu.shape torch.Size([128, 2])
输出log_var.shape torch.Size([128, 2])
训练轮次: 1 [0/60000 (0%)] 损失: 546.049500
训练轮次: 1 [12800/60000 (21%)] 损失: 183.680786
训练轮次: 1 [25600/60000 (43%)] 损失: 178.807480
训练轮次: 1 [38400/60000 (64%)] 损失: 168.790909
训练轮次: 1 [51200/60000 (85%)] 损失: 167.150360
====> 训练轮次: 1 平均损失: 182.9290
====> 测试集损失: 164.0699
训练轮次: 2 [0/60000 (0%)] 损失: 159.409363
训练轮次: 2 [12800/60000 (21%)] 损失: 161.341766
训练轮次: 2 [25600/60000 (43%)] 损失: 155.624207
训练轮次: 2 [38400/60000 (64%)] 损失: 156.475525
训练轮次: 2 [51200/60000 (85%)] 损失: 157.767914
====> 训练轮次: 2 平均损失: 159.2507
====> 测试集损失: 156.2279

请添加图片描述

二、VAE原理探究和简易推导

核心参考:台大李宏毅教授ppt:https://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2016/Lecture/VAE%20(v5).pdf

2.1 vae模型训练架构图 (训练目的:重构出训练样本+解码器输入趋近于正态分布)

在这里插入图片描述
训练过程中,输入训练图片(batch,图片H*W),
编码器输出最终的结果 recon_batch, 编码器输出中间变量 mu, log_var
其中编码器和解码器的神经网络实现结构有区别,但是输出变量的个数,意义差别不大
NN Encoder输出 c i = exp ⁡ ( σ i ) × e i + m i c_i = \exp(\sigma_i) \times e_i + m_i ci=exp(σi)×ei+mi 表示根据编码器输出来计算 c c c,这里代码z等同,实际代码计算时有差异

Minimize部分:

核心loss, 保证了编码器输出值,经过重采样后近似接近标准正态分布,所以我们再生成的时候,可以直接z~(0,1)采样解码得到生成图像。(下图中sim表示两个分布的相似度,也可以用KL散步来表示)

2.2 VAE 中的 KL 散度(KL Divergence)从零推导(最大化对数似然函数)

变分自编码器的目标是最大化观测数据的对数似然函数,即:

P ( x ) = ∫ P ( z ) P ( x ∣ z )   d z P(x) = \int P(z) P(x|z) \, dz P(x)=P(z)P(xz)dz

其中:

  • P ( z ) P(z) P(z) 是潜在变量 ( z ) 的先验分布, 表示在没有任何观测数据 x 的情况下,对潜在变量 z 的先验知识通常假设,这里为标准正态分布 N ( 0 , I ) \mathcal{N}(0, I) N(0,I)
  • P ( x ∣ z ) P(x|z) P(xz) 是给定潜在变量 z z z 时,生成观测数据 x x x的概率分布。

这整个公式表示的是,观测数据 𝑥的概率可以通过对所有可能的潜在变量 𝑧进行求和(积分)

对数似然函数为:

L = ∑ x log ⁡ P ( x ) L = \sum_x \log P(x) L=xlogP(x)

引入变分分布 q ( z ∣ x ) q(z|x) q(zx)

由于直接计算 P(x)是不可行的,我们引入一个变分分布 ( q(z|x) ) 来近似后验分布 ( P(z|x) )。

对数似然可以写成:

log ⁡ P ( x ) = ∫ q ( z ∣ x ) log ⁡ P ( x )   d z \log P(x) = \int q(z|x) \log P(x) \, dz logP(x)=q(zx)logP(x)dz

使用 ( q(z|x) ) 重写对数似然

通过将对数似然展开并使用 ( q(z|x) ) 重写,我们得到:

log ⁡ P ( x ) = ∫ q ( z ∣ x ) log ⁡ P ( z , x ) P ( z ∣ x )   d z \log P(x) = \int q(z|x) \log \frac{P(z, x)}{P(z|x)} \, dz logP(x)=q(zx)logP(zx)P(z,x)dz

其中 ( P(z, x) = P(x|z) P(z) ),所以公式可以写为:

log ⁡ P ( x ) = ∫ q ( z ∣ x ) log ⁡ P ( x ∣ z ) P ( z ) P ( z ∣ x )   d z \log P(x) = \int q(z|x) \log \frac{P(x|z) P(z)}{P(z|x)} \, dz logP(x)=q(zx)logP(zx)P(xz)P(z)dz

分解对数似然

进一步分解对数似然:

log ⁡ P ( x ) = ∫ q ( z ∣ x ) log ⁡ ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) q ( z ∣ x ) P ( z ∣ x ) )   d z \log P(x) = \int q(z|x) \log \left( \frac{P(x|z) P(z)}{q(z|x)} \frac{q(z|x)}{P(z|x)} \right) \, dz logP(x)=q(zx)log(q(zx)P(xz)P(z)P(zx)q(zx))dz

将对数拆分为两部分:

log ⁡ P ( x ) = ∫ q ( z ∣ x ) log ⁡ P ( x ∣ z ) P ( z ) q ( z ∣ x )   d z + ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) P ( z ∣ x )   d z \log P(x) = \int q(z|x) \log \frac{P(x|z) P(z)}{q(z|x)} \, dz + \int q(z|x) \log \frac{q(z|x)}{P(z|x)} \, dz logP(x)=q(zx)logq(zx)P(xz)P(z)dz+q(zx)logP(zx)q(zx)dz

第一项:

∫ q ( z ∣ x ) log ⁡ P ( x ∣ z ) P ( z ) q ( z ∣ x )   d z \int q(z|x) \log \frac{P(x|z) P(z)}{q(z|x)} \, dz q(zx)logq(zx)P(xz)P(z)dz

证据下界(ELBO)

第二项:

∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) P ( z ∣ x )   d z = D KL ( q ( z ∣ x ) ∥ P ( z ∣ x ) ) \int q(z|x) \log \frac{q(z|x)}{P(z|x)} \, dz = D_{\text{KL}} (q(z|x) \| P(z|x)) q(zx)logP(zx)q(zx)dz=DKL(q(zx)P(zx))

KL 散度

证据下界(ELBO)

由于 KL 散度 D KL ( q ( z ∣ x ) ∥ P ( z ∣ x ) ) D_{\text{KL}} (q(z|x) \| P(z|x)) DKL(q(zx)P(zx))总是非负的,因此我们有:

log ⁡ P ( x ) ≥ E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] − D KL ( q ( z ∣ x ) ∥ P ( z ) ) \log P(x) \geq \mathbb{E}_{q(z|x)} [\log P(x|z)] - D_{\text{KL}} (q(z|x) \| P(z)) logP(x)Eq(zx)[logP(xz)]DKL(q(zx)P(z))

这一不等式被称为证据下界(ELBO),记作 ( L_b ):

L b = E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] − D KL ( q ( z ∣ x ) ∥ P ( z ) ) L_b = \mathbb{E}_{q(z|x)} [\log P(x|z)] - D_{\text{KL}} (q(z|x) \| P(z)) Lb=Eq(zx)[logP(xz)]DKL(q(zx)P(z))

2.3 KL 散度公式的进一步化简 (得到loss函数)!

为了推导 KL 散度,我们需要具体定义 ( q(z|x) ) 和 ( p(z) ) 的形式。在 VAE 中,通常选择以下形式:

  • ( q(z|x) = \mathcal{N}(z; \mu(x), \sigma^2(x)) )
  • ( p(z) = \mathcal{N}(z; 0, I) )

KL 散度的计算公式为:

D KL ( N ( μ , σ 2 ) ∥ N ( 0 , I ) ) = 1 2 ∑ i = 1 d ( σ i 2 + μ i 2 − log ⁡ σ i 2 − 1 ) D_{\text{KL}} \left( \mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, I) \right) = \frac{1}{2} \sum_{i=1}^{d} \left( \sigma_i^2 + \mu_i^2 - \log \sigma_i^2 - 1 \right) DKL(N(μ,σ2)N(0,I))=21i=1d(σi2+μi2logσi21)

具体推导过程如下:

  1. KL 散度的定义

D KL ( q ∥ p ) = ∫ q ( z ) log ⁡ q ( z ) p ( z )   d z D_{\text{KL}}(q \| p) = \int q(z) \log \frac{q(z)}{p(z)} \, dz DKL(qp)=q(z)logp(z)q(z)dz

  1. 代入具体的高斯分布形式

q ( z ∣ x ) = N ( z ; μ , σ 2 ) , p ( z ) = N ( z ; 0 , I ) q(z|x) = \mathcal{N}(z; \mu, \sigma^2), \quad p(z) = \mathcal{N}(z; 0, I) q(zx)=N(z;μ,σ2),p(z)=N(z;0,I)

  1. 计算对数概率密度函数

log ⁡ q ( z ∣ x ) = − 1 2 ( d log ⁡ ( 2 π ) + log ⁡ ∣ Σ ∣ + ( z − μ ) T Σ − 1 ( z − μ ) ) \log q(z|x) = -\frac{1}{2} \left( d \log(2\pi) + \log|\Sigma| + (z - \mu)^T \Sigma^{-1} (z - \mu) \right) logq(zx)=21(dlog(2π)+log∣Σ∣+(zμ)TΣ1(zμ))

log ⁡ p ( z ) = − 1 2 ( d log ⁡ ( 2 π ) + log ⁡ ∣ I ∣ + z T z ) \log p(z) = -\frac{1}{2} \left( d \log(2\pi) + \log|I| + z^T z \right) logp(z)=21(dlog(2π)+logI+zTz)

  1. 代入 KL 散度公式并化简

D KL ( q ( z ∣ x ) ∥ p ( z ) ) = E q ( z ∣ x ) [ log ⁡ q ( z ∣ x ) − log ⁡ p ( z ) ] D_{\text{KL}} \left( q(z|x) \| p(z) \right) = \mathbb{E}_{q(z|x)} \left[ \log q(z|x) - \log p(z) \right] DKL(q(zx)p(z))=Eq(zx)[logq(zx)logp(z)]

= E q ( z ∣ x ) [ − 1 2 ( log ⁡ ∣ Σ ∣ + ( z − μ ) T Σ − 1 ( z − μ ) − log ⁡ ∣ I ∣ − z T z ) ] = \mathbb{E}_{q(z|x)} \left[ -\frac{1}{2} \left( \log|\Sigma| + (z - \mu)^T \Sigma^{-1} (z - \mu) - \log|I| - z^T z \right) \right] =Eq(zx)[21(log∣Σ∣+(zμ)TΣ1(zμ)logIzTz)]

= 1 2 ( tr ( Σ ) + μ T μ − d − log ⁡ ∣ Σ ∣ ) = \frac{1}{2} \left( \text{tr}(\Sigma) + \mu^T \mu - d - \log |\Sigma| \right) =21(tr(Σ)+μTμdlog∣Σ∣)

对于对角协方差矩阵 ( \Sigma = \text{diag}(\sigma^2) ),上式化简为:

= 1 2 ∑ i = 1 d ( σ i 2 + μ i 2 − log ⁡ σ i 2 − 1 ) = \frac{1}{2} \sum_{i=1}^{d} \left( \sigma_i^2 + \mu_i^2 - \log \sigma_i^2 - 1 \right) =21i=1d(σi2+μi2logσi21)

对应实际实现中的 KL 散度

在实际实现中,编码器输出的是对数方差 ( \log(\sigma^2) ),记为 log_var。KL 散度的代码实现为:

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

在这里插入图片描述

VAE重构损失

在VAE中逐像素应用的公式为:
BCE = − ∑ j = 1 M [ x j log ⁡ ( x ^ j ) + ( 1 − x j ) log ⁡ ( 1 − x ^ j ) ] \text{BCE} = -\sum_{j=1}^M \left[ x_j \log(\hat{x}_j) + (1 - x_j) \log(1 - \hat{x}_j) \right] BCE=j=1M[xjlog(x^j)+(1xj)log(1x^j)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/674489.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IPD推行成功的核心要素(八)市场管理与产品规划保证做正确的事情

产品开发管理是“正确地执行项目”,而市场管理及产品规划关注“执行正确的项目”,可以说后者对产品的成功更为关键。要实现产品的持续成功,还得从源头的市场管理抓起。成功的产品开发,必须面向市场需求,由需求牵引创新…

FlyMcu串口下载STLINK Utility

FlyMcu是串口下载 STLINK Utility是STLINK下载 生成hex文件 打开hex文件,点击开始编程 在编程之前,需要配置BOOT引脚,让STM32执行BootLoader,否则点击开始编程,程序会一直卡住。第一步STM32板上有跳线帽&#xf…

SuperSocket 服务器与客户端双向通讯

1、使用AppSession 的Send方法就可以向连接到的客户端发送数据。服务器端代码如下。 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;//引入命名空间 using SuperSocket.Common; using SuperSocket.So…

【机器学习】逻辑回归:原理、应用与实践

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 逻辑回归:原理、应用与实践引言1. 逻辑回归基础1.1 基本概念1.2 Sig…

leetCode-hot100-二分查找专题

二分查找 简介原理分析易错点分析例题33.搜索旋转排序数组34.在排序数组中查找元素的第一个和最后一个位置35.搜索插入位置240.搜索二维矩阵 Ⅱ 简介 二分查找,是指在有序(升序/降序)数组查找符合条件的元素,或者确定某个区间左右…

HTML静态网页成品作业(HTML+CSS)—— 香奈儿香水介绍网页(1个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有1个页面。 二、作品演示 三、代…

关于Acrel-2000E配电室综合监控系统的实际应用分析-安科瑞 蒋静

摘要:“三大工程”指的是保障性住房建设、“平急两用”公共基础设施建设、城中村改造,是我国在建设领域作出的重大决策部署,是根据房地产市场新形势推出的重要举措。其中城中村改造是解决群众急难愁盼问题的重大民生工程,该工程中配电房的建设…

新闻发稿:8个新闻媒体推广中最常见的错误-华媒舍

在数字时代,新闻媒体的推广手段已经越来越多样化。许多媒体在推广过程中常常会犯下一些常见错误。本文将会介绍八个新闻媒体在推广中最常见的错误,并希望能够帮助各位更好地规避这些问题。 1. 缺乏明确的目标受众 在进行推广前,新闻媒体需要…

华为OD机试 - 最大坐标值(Java 2024 D卷 100分)

华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…

将HTML页面中的table表格元素转换为矩形,计算出每个单元格的宽高以及左上角坐标点,输出为json数据

export function huoQuTableElement() {const tableData []; // 存储表格数据的数组let res [];// 获取到包含表格的foreignObject元素const foreignObject document.getElementById(mydctable);if (!foreignObject){return ;}// 获取到表格元素let oldTable foreignObject…

Orange AIpro开箱上手

0.介绍 首先感谢官方给到机会,有幸参加这次活动。 OrangePi AIpro(8T)采用昇腾AI技术路线,具体为4核64位处理器AI处理器,集成图形处理器,支持8TOPS AI算力,拥有8GB/16GB LPDDR4X,可以外接32GB/64GB/128GB/2…

从小众到主流:KOC如何凭借微影响力塑造品牌传播新格局

随着数字化的飞速发展,KOC作为社交媒体上的一股新兴力量,正以其微小的粉丝基数和高度互动性,引发一场微影响力革命。与传统的KOL不同,KOC通常拥有较小的粉丝基数,但却能够凭借高度互动性和真实的消费者体验&#xff0c…

编写一个问卷界面 并用JavaScript来验证表单内容

倘若文章和代码中有任何错误或疑惑&#xff0c;欢迎提出交流哦~ 简单的html和css初始化 今天使用JavaScript来实现对表单输入的验证&#xff0c; 首先写出html代码如下&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset&qu…

FY-SA-20237·8-WhyWeSpin

Translated from the Scientific American, July/August 2023 issue. Why We Spin (我们为什么旋转) Primates may play with reality by twirling around 翻译&#xff1a;灵长类动物有能力通过旋转或旋转运动来操纵或扭曲他们对现实的感知。 解释&#xff1a; “Primates”…

跟着大佬学RE(二)

[ACTF新生赛2020]easyre enc~}|{zyxwvutsrqponmlkjihgfedcba_^]\[ZYXWVUTSRQPONMLKJIHGFEDCBA?><;:9876543210/.-,*)(\0x27&%$# !" v4*F\"N,\"(I? v4list(map(ord,v4)) print(v4) #( v4[i] ! _data_start__[*((char *)v5 i) - 1] ) flaglist(ACTF…

光猫、路由器的路由模式、桥接模式、拨号上网

下面提到的路由器都是家用路由器 一、联网条件 1.每台电脑、路由器、光猫想要上网&#xff0c;都必须有ip地址。 2.电脑获取ip 可以设置静态ip 或 向DHCP服务器(集成在路由器上) 请求ip 电话线上网时期&#xff0c;猫只负责模拟信号和数字信号的转换&#xff0c;电脑需要使…

折半查找二分查找

简介 折半查找也就是二分查找&#xff0c;也可以叫二分法&#xff0c;本质上都是一样的&#xff0c;通过比对中间值与目标值&#xff0c;一次性就能筛掉一半的数字。 举例&#xff1a; 一个猜数字游戏&#xff0c;让你来猜1-100中我选中的数&#xff0c;如果猜中游戏结束&…

EE trade:量化交易需要什么条件才能做

量化交易结合了金融市场知识和计算机科学技术&#xff0c;利用数学和统计模型来进行交易决策。要成功进行量化交易&#xff0c;需要具备以下几个方面的条件&#xff1a; 1. 知识和技能 金融市场知识&#xff1a;需要理解金融市场的基本原理&#xff0c;包括股票、债券、期货、…

学会读书并不简单,如何真正学会读书

一、教程描述 读书是要讲究方法的&#xff0c;否则就会事倍功半&#xff0c;比如&#xff0c;在学习书本上的每一个问题每一章节的时候&#xff0c;首先应当不只看到书面上&#xff0c;而且还要看到书背后的东西&#xff0c;在对书中每一个问题都经过细嚼慢咽&#xff0c;其次…

AI对话聊天软件有哪些?这5款AI软件值得推荐

AI对话聊天软件有哪些&#xff1f;AI对话聊天软件在现代社会中的重要性日益凸显。它们不仅是沟通的工具&#xff0c;更是人们日常生活中的智能助手。通过深度学习和自然语言处理技术&#xff0c;这些软件能够理解我们的意图&#xff0c;提供个性化的建议和服务&#xff0c;让交…