经典神经网络(9)VAE模型原理及其在MNIST数据集上的应用

经典神经网络(9)VAE模型原理及其在MNIST数据集上的应用

  • 图片生成领域来说,有四大主流生成模型:生成对抗模型(GAN)、变分自动编码器(VAE)、流模型(Flow based Model)、扩散模型(Diffusion Model)。

  • VAE 的encoder是学习一个概率分布,所以VAE也可以随机采样生成图片,但VAE图片还原效果很弱,生成的图像模糊,效果不如diffusion model。但VAE可以减少训练和推理时间,降低GPU硬件要求。

  • 从2022年开始,主要爆火的图片生成模型是Diffusion Model(扩散模型)为主。在Stable Diffusion中利用VAE将原图512x512x3压缩到64x64x4的潜空间(Latent Space),通过在隐式表征(而不是完整图像)上进行扩散,可以在使用更少的内存的同时,减少UNet层数并加速图片的生成。与此同时,我们仍能把结果输入VAE的解码器,从而解码得到高分辨率图像,隐式表征极大降低了训练和推理成本。

    在这里插入图片描述

  • VAE损失函数的推导过程中,同样用到了KL散度的概念,可以参考:信息量、熵、KL散度、交叉熵概念理解

1 自编码器(Auto-encoder,AE)

1.1 自编码器概述

  • 自编码器是一种无监督的神经网络模型,可以用于数据的降维、特征提取和数据重建等任务。

  • 它由编码器和解码器两部分组成(如下图所示):

    • 编码器将输入数据压缩成低维特征向量,即编码Code;
    • 解码器则将低维特征向量还原成原始数据。
    • 在自编码器整个训练过程中,目标是最小化输入数据和重建数据之间的差异,以学习到更加有效的特征表示。
  • 最简单的自动编码器是由线性层构成的,叫做线性自编码器(如下图所示)。

    • 输出层的神经元数量往往与输入层的神经元数量一致;
    • 网络架构往往呈对称性,且中间结构简单、两边结构复杂。

在这里插入图片描述

1.2 自编码器存在的问题

  • 如下图所示,假设有两张训练图片,一张是全月图,一张是半月图,经过训练我们的自编码器模型已经能无损地还原这两张图片。
  • 接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。
  • 但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。
  • 原因是:基本自编码器给定一张图片生成原始图片,从输入到输出都是确定的,没有任何随机的成分,为了使模型表现很好,在不断的迭代训练中,编码器的输出也就是解码器的输入会趋于确定,这样才能让解码器能生成与输入数据更接近的数据,以使损失变得更小。但是这就与生成器的初衷有悖了:
    • 生成器的初衷实际上是为了生成更多全新的数据,而不是为了生成与输入数据更像的数据。

在这里插入图片描述

1.3 AE模型在MNIST数据集上的应用

  • 代码如下所示,ae_original_image.png是原始数据

在这里插入图片描述

  • ae_image_encoder.png是经过编码器,解码器后得到的图片

在这里插入图片描述

  • ae_image.png是将三个编码器得到的编码值进行平均得到的图片,可以看到是模糊且无法辨认的乱码图,这就是自编码器存在的问题。

在这里插入图片描述

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image


class AE(nn.Module):
    def __init__(self):
        super().__init__()
        # 编码器
        self.encoder = nn.Sequential(nn.Linear(784, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 10)
                                     )
        # 解码器
        self.decoder = nn.Sequential(nn.Linear(10, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 784),
                                     )

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # [b,1,28,28] ———> [b,784]
        x = self.encoder(x)      # [b,784] ———> [b,10]
        x = self.decoder(x)      # [b,10] ———> [b,784]
        return x

def train(model, loss_fn, opt, epoch=200):
    for epoch in range(epoch):
        model.train()
        total_loss = 0.0
        for x, _ in trian_dl:
            x = x.to(device)
            y_pre = model(x)
            loss = loss_fn(y_pre, x.reshape(-1, 784))
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss
            break
        print(f'epoch = {epoch + 1}, train loss = {total_loss / len(trian_dl): .4f}')

    torch.save(model.state_dict(), 'ae_model.pth')


if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1、读取数据集
    trian_dl = torch.utils.data.DataLoader(
        datasets.MNIST(
            "/root/autodl-fs/data/minist",
            train=True,
            download=False,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                ]
            ),
        ),
        batch_size=256,
        shuffle=True,
        num_workers=8
    )

    # 2、创建AE模型及优化器
    model = AE().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss(reduction='mean')

    # 4、模型训练
    train(model, loss_fn, opt, epoch=100)

    # 5、模型推理
    # 加载模型
    model.load_state_dict(torch.load('ae_model.pth', map_location=device))
    bs = 3
    text_dl = torch.utils.data.DataLoader(
        datasets.MNIST(
            "/root/autodl-fs/data/minist",
            train=False,
            download=False,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                ]
            ),
        ),
        batch_size=bs,
        shuffle=False
    )

    # 获取编码
    for x, y in text_dl:
        model.eval()
        save_image(x, "ae_original_image.png")
        x = x.to(device)
        sample_encoder = model.encoder(x.reshape(-1, 784))
        print(sample_encoder)

        image_encoder = model.decoder(sample_encoder).reshape(-1, 1, 28, 28)
        save_image(image_encoder, "ae_image_encoder.png")

        sample_sum = torch.sum(sample_encoder, dim=0) / 2
        sample = sample_sum.tile(bs, 1)
        image = model.decoder(sample).reshape(-1, 1, 28, 28)

        save_image(image, "ae_image.png")
        break

2 变分自编码器(VAE)

2.1 如何解决自编码器的缺点

  • 我们现在已经知道,自编码器生成的图片是模糊且无法辨认的乱码图。如何解决这个问题呢?
    • 如下图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内。
    • 在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。
    • 然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。

在这里插入图片描述

  • 由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。
    • 不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点
    • 为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。
    • 在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。

在这里插入图片描述

  • 这种将图像编码由离散变为连续的方法,就是变分自编码的核心思想。

2.2 VAE的整体架构概述

  • 如下图所示,与普通自动编码器一样,VAE有编码器Encoder与解码器Decoder两大部分组成,原始图像从编码器输入,经编码器后形成隐式表示(Latent Representation),之后隐式表示被输入到解码器、再复原回原始输入的结构。
  • 然而,与普通AE不同的是,我们不会直接将Encoder编码后的结果传递给Decoder,而是要使得隐式表示满足既定分布(例如:正态分布)。
    • 变分自动编码器的Encoder在输出时,并不会直接输出原始数据的隐式表示,而是会输出从原始数据提炼出的均值 μ 和标准差 σ 。
    • 之后,我们需要建立均值为 μ 、标准差为 σ 的正态分布,并从该正态分布中抽样出隐式表示z,再将隐式表示z输入到Decoder中进行解码。
    • 对隐式表示z而言,它传递给Decoder的就不是原始数据的信息,而只是与原始数据同均值、同标准差的分布中的信息了。

在这里插入图片描述

2.3 VAE的前向过程

2.3.1 前向过程(一个样本产生一个正态分布)

  • 在auto-encoder中,编码器是直接产生一个编码的,但是在VAE中,为了给编码添加合适的噪音,编码器会输出两个编码:

    • 下图有三个样本,因此存在三组均值和标准差,一组均值和标准差只能生成一个正态分布,而一个正态分布中只能抽选一个数字,这是变分自编码器抽样的基本规则。下图中,每个样本经过Encoder后只输出了一组均值和标准差,那z自然只能有一列,隐式空间 ( c 1 , c 2 , c 3 ) (c_1,c_2,c_3) (c1,c2,c3)的结构为(3,1)。

    • 一个是原有编码 ( m 1 , m 2 , m 3 ) (m_1,m_2,m_3) (m1,m2,m3)(均值),注意这里是一个样本会输出一个均值

    • 一个是控制噪音干扰程度的编码 ( σ 1 、 σ 2 、 σ 3 ) (\sigma_1、\sigma_2、\sigma_3) (σ1σ2σ3)(注意这里是方差,下面推导过程用的是标准差),第二个编码其实很好理解,就是为随机噪音 ( e 1 , e 2 , e 3 ) (e_1,e_2,e_3) (e1,e2,e3)分配权重。加上exp的目的是为了保证这个分配的权重是个正值,最后将原编码与噪音编码相加,就得到了VAE在code层的输出结果,即 c i = e σ i ∗ e i + m i c_i=e^{\sigma_i}*e_i+m_i ci=eσiei+mi

    • 得到了VAE在code层的输出结果 c c c后,进入解码器Decoder中,最后得到output。

在这里插入图片描述

  • 由上图可知,在损失函数方面,除了必要的重构损失外(让输出和输入相近),VAE还另外增添了一个损失函数(上图下方的L2损失函数),我们后面会详细推导此损失函数。

2.3.2 前向过程(一个样本产生多个正态分布)

  • 在变分自动编码器的流程当中,均值和标准差是通过第一个神经网络Encoder训练出来的。我们不可能知道当前样本服从的真实分布的状态,因此这一推断过程自然可以根据不同的规则(Encoder中不同的权重)得出不同的结果。
  • 如下图所示,我们可以令Encoder的输出层存在3个神经元,这样Encoder就会对每一个样本推断出三对不同的均值和标准差。这个行为相当于对样本数据所属的原始分布进行估计,但给出了三个可能的答案。因此现在,在每个样本下,我们就可以基于三个均值和标准差的组合生成三个不同的正态分布了。
  • 隐式空间越大,隐式表示z所携带的信息自然也会越多,自动编码器的表现就可能变得更好,因此在实际使用变分自动编码器的过程中,一个样本上至少都会生成10~100组均值和标准差,隐式表示z的结构一般也是较高维的矩阵

在这里插入图片描述

在这里插入图片描述

2.4 VAE损失函数的推导

2.4.1 从高斯混合模型到VAE

  • VAE的理论基础是高斯混合模型,即任何一个数据的分布,都可以看作是若干高斯分布的叠加。
  • 如图所示,如果P(X)代表一种分布的话,存在一种拆分方法能让它表示成图中若干浅蓝色曲线对应的高斯分布的叠加。

在这里插入图片描述

  • 如下图所示,我们将编码换成一个连续变量z,为了计算方便,我们规定z服从标准正态分布(实际上并不一定要选用)。正如2.3前向过程所示,对于每一个采样点z,会有两个函数 u u u σ \sigma σ,分别决定z对应到的高斯分布的均值和方差,然后在积分域上所有的高斯分布的累加就成为了原始分布P(X)。
    • 我们使用 p p p代表解码器;
    • p ( x ∣ z ) p(x|z) p(xz)代表给定z时解码器产生 x x x的概率;
    • x x x并非一个具体的值,而可以看作是一类数据,比如 x x x可以代表某种风格的手写体数字, p ( x ∣ z ) p(x|z) p(xz)就是生成这些数字的概率,这里的概率也并非一个具体的值,而是某一风格的每个数字对应了一个概率,其输出的是概率分布。

在这里插入图片描述

  • 那么VAE的优化目标是什么呢?其实就是最大化解码器输出x的概率,即最大化 p ( x ) p(x) p(x)

2.4.2 损失函数的推导

我们现在的优化目标就是:最大化编码器输出x的概率
p ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z p(x)=\int_zp(x|z)p(z)dz p(x)=zp(xz)p(z)dz

  • 注意:这里的 p ( z ) p(z) p(z)可以是任意分布,在VAE中我们常常假设 p ( z ) p(z) p(z)服从标准正态分布

为了最大化 p ( x ) p(x) p(x),我们可以采用极大似然估计的方法来进行:
L = ∑ x l o g p ( x ) L=\sum_xlogp(x) L=xlogp(x)

  • 这里的每个x可以理解为代表了某一个风格的手写体,我们的目标是生成手写体数字,因此我们并不会局限其风格。

由于最大化L相当于最大化 l o g p ( x ) logp(x) logp(x),因此后续目标调整为最大化 l o g p ( x ) logp(x) logp(x)。我们假设q代表了编码器, q ( z ∣ x ) q(z|x) q(zx)就代表了给定x时编码器产生z的概率
给定任意 x ,其产生不同 z 的概率之和为 1 ,因此: ∫ z q ( z ∣ x ) = 1 而 p ( x ) 和 z 无关,那么 : l o g p ( x ) = ∫ z q ( z ∣ x ) l o g p ( x ) d z ( 公式一 ) 依据联合概率公式 p ( x ) = p ( x , z ) p ( z ∣ x ) = p ( x , z ) q ( z ∣ x ) q ( z ∣ x ) p ( z ∣ x ) 代入 ( 公式一 ) : l o g p ( x ) = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) q ( z ∣ x ) p ( z ∣ x ) ) d z 我们将 l o g 里的乘积拆开,变为两项之和 = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) p ( z ∣ x ) ) d z 给定任意x,其产生不同z的概率之和为1,因此:\\ \int_zq(z|x)=1\\ 而p(x)和z无关,那么:\\ logp(x)=\int_zq(z|x)logp(x)dz(公式一) \\ 依据联合概率公式p(x)=\frac{p(x,z)}{p(z|x)}=\frac{p(x,z)}{q(z|x)}\frac{q(z|x)}{p(z|x)}\\ 代入(公式一):logp(x)=\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)}\frac{q(z|x)}{p(z|x)})dz\\ 我们将log里的乘积拆开,变为两项之和\\ =\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)})dz+\int_zq(z|x)log(\frac{q(z|x)}{p(z|x)})dz 给定任意x,其产生不同z的概率之和为1,因此:zq(zx)=1p(x)z无关,那么:logp(x)=zq(zx)logp(x)dz(公式一)依据联合概率公式p(x)=p(zx)p(x,z)=q(zx)p(x,z)p(zx)q(zx)代入(公式一)logp(x)=zq(zx)log(q(zx)p(x,z)p(zx)q(zx))dz我们将log里的乘积拆开,变为两项之和=zq(zx)log(q(zx)p(x,z))dz+zq(zx)log(p(zx)q(zx))dz
结合KL散度公式,我们可以看出第二项其实就是 K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) KL(q(z|x) || p(z|x)) KL(q(zx)∣∣p(zx))。因为该值为非负项,所以 l o g p ( x ) logp(x) logp(x)不可能小于第一项,我们使用 L b Lb Lb来代表第一项。
l o g p ( x ) = L b + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) p ( z ∣ x ) ) d z = L b + K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 结合式式子 : p ( x ) = ∫ z p ( x ∣ z ) p ( z ) d z 分析如下: 当 p ( x ∣ z ) 不变时, p ( x ) 也不变,从而 l o g p ( x ) 也不变,那么 L b + K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 的值就不会变。 这时如果我们利用 q ( z ∣ x ) 来最大化 L b ,那么 L b 就会增大,而 K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) 的值就会减小。 当我们调节到 q ( z ∣ x ) 与 p ( z ∣ x ) 完全相同时, K L 散度就为 0 , L b 和 l o g p ( x ) 完全一致 logp(x)=Lb+\int_zq(z|x)log(\frac{q(z|x)}{p(z|x)})dz\\ =Lb+KL(q(z|x) || p(z|x))\\ 结合式式子:p(x)=\int_zp(x|z)p(z)dz\\ 分析如下:\\ 当p(x|z)不变时,p(x)也不变,从而log p(x)也不变,那么Lb+KL(q(z|x) || p(z|x))的值就不会变。\\ 这时如果我们利用q(z|x)来最大化Lb,那么Lb就会增大,而KL(q(z|x) || p(z|x))的值就会减小。\\ 当我们调节到q(z|x)与p(z|x)完全相同时,KL散度就为0,Lb和logp(x)完全一致\\ logp(x)=Lb+zq(zx)log(p(zx)q(zx))dz=Lb+KL(q(zx)∣∣p(zx))结合式式子:p(x)=zp(xz)p(z)dz分析如下:p(xz)不变时,p(x)也不变,从而logp(x)也不变,那么Lb+KL(q(zx)∣∣p(zx))的值就不会变。这时如果我们利用q(zx)来最大化Lb,那么Lb就会增大,而KL(q(zx)∣∣p(zx))的值就会减小。当我们调节到q(zx)p(zx)完全相同时,KL散度就为0Lblogp(x)完全一致
在这里插入图片描述

那么如果 q ( z ∣ x ) 不变呢?此时当我们增大 p ( x ∣ z ) 时, L b 会增大且 p ( x ) 会增大,即 l o g p ( x ) 也会增大。 由此我们可以得出结论,只要我们最大化 L b 就能使 l o g p ( x ) 最大化。 那么如果q(z|x)不变呢?此时当我们增大p(x|z)时,Lb会增大且p(x)会增大,即log p(x)也会增大。\\ 由此我们可以得出结论,只要我们最大化Lb就能使log p(x)最大化。 那么如果q(zx)不变呢?此时当我们增大p(xz)时,Lb会增大且p(x)会增大,即logp(x)也会增大。由此我们可以得出结论,只要我们最大化Lb就能使logp(x)最大化。
此时我们的优化目标就变成了最大化 L b Lb Lb
L b = ∫ z q ( z ∣ x ) l o g ( p ( x , z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( p ( x ∣ z ) p ( z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g p ( z ) q ( z ∣ x ) d z + ∫ z q ( z ∣ x ) l o g p ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) + ∫ z q ( z ∣ x ) l o g p ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) + E q ( z ∣ x ) l o g p ( x ∣ z ) 我们加负号,转化为最小化问题: L o s s = K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) − E q ( z ∣ x ) l o g p ( x ∣ z ) Lb=\int_zq(z|x)log(\frac{p(x,z)}{q(z|x)})dz\\ =\int_zq(z|x)log(\frac{p(x|z)p(z)}{q(z|x)})dz\\ =\int_zq(z|x)log\frac{p(z)}{q(z|x)}dz+\int_zq(z|x)logp(x|z)dz\\ =-KL(q(z|x)||p(z)) + \int_zq(z|x)logp(x|z)dz\\ =-KL(q(z|x)||p(z)) + E_{q(z|x)}logp(x|z)\\ 我们加负号,转化为最小化问题:\\ Loss=KL(q(z|x)||p(z)) - E_{q(z|x)}logp(x|z) Lb=zq(zx)log(q(zx)p(x,z))dz=zq(zx)log(q(zx)p(xz)p(z))dz=zq(zx)logq(zx)p(z)dz+zq(zx)logp(xz)dz=KL(q(zx)∣∣p(z))+zq(zx)logp(xz)dz=KL(q(zx)∣∣p(z))+Eq(zx)logp(xz)我们加负号,转化为最小化问题:Loss=KL(q(zx)∣∣p(z))Eq(zx)logp(xz)

  • 此时VAE的最终目标就一目了然了,VAE的训练目标有两个:
    • 第一,最大化在 q ( z ∣ x ) q(z|x) q(zx)这个分布下 l o g p ( x ∣ z ) logp(x|z) logp(xz)的期望(L1损失),其中 q ( z ∣ x ) q(z|x) q(zx)为编码器输入 x x x时产生 z z z的概率。假设解码器利用 z z z生成出了 x ’ x’ x,我们就需要使 x ’ x’ x尽可能向 x x x靠近,以最大化 l o g p ( x ∣ z ) logp(x|z) logp(xz)
    • 第二,最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(zx)∣∣p(z))(L2损失),使 q ( z ∣ x ) q(z|x) q(zx)的分布尽量向 p ( z ) p(z) p(z)靠近;

根据上述的两个训练目标,VAE的损失函数也被设计为两个:

  • L1损失目的是输出的 x ’ x’ x尽可能向原始 x x x靠近(即重构损失),我们可以最小化x’和x之间的MSE Loss或者BCE Loss。

  • L2用于最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x)||p(z)) KL(q(zx)∣∣p(z)),VAE假设 q ( z ∣ x ) q(z|x) q(zx)的分布为正态分布,而 p ( z ) p(z) p(z)为标准正态分布。计算两个正态分布之间的KL散度的公式如下(这里就不推导了,直接给出):

    在这里插入图片描述

    由于此处 p ( z ) p(z) p(z)为标准正态分布,因此其 μ 2 μ_2 μ2为0, σ 2 σ_2 σ2为1,那么我们带入后可得
    K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = − 1 2 ( l o g σ 2 − σ 2 − u 2 + 1 ) 其中 σ 为 q ( z ∣ x ) 的标准差, μ 为 q ( z ∣ x ) 的均值 KL(q(z|x)||p(z))=-\frac{1}{2}(log\sigma^2-\sigma^2-u^2+1) \\ 其中σ为q(z|x)的标准差,μ为q(z|x)的均值 KL(q(zx)∣∣p(z))=21(logσ2σ2u2+1)其中σq(zx)的标准差,μq(zx)的均值

2.4.3 VAE的重参数化

蓝色节点为采样节点,左侧由于采样的存在无法对 μ μ μ σ σ σ求导,反向传播无法进行,而右侧由于采用的重参数化技巧,灰色节点全部打通,使得网络能够正常进行反向传播。

在这里插入图片描述

2.4.4 VAE和GAN的区别、VAE的本质

可以参考:变分自编码器(一):原来是这么一回事 - 科学空间

2.5 VAE在MNIST数据集上的应用

  • 解码器输出的不是方差 σ 2 σ^2 σ2 ,而是对数方差 l o g σ 2 logσ^2 logσ2,详见下面代码的encoder函数,这么做的原因就是,神经网络的输出是 [−∞,+∞] 的任意数值,但是方差不可能为负数,所以对方差取对数以满足神经网络输出值域的要求。
  • 重参数化详见reparameter函数。
  • 训练时候不要忘了损失函数的KL散度部分,利用上面推导的公式计算。这里由于模型的输出是对数方差 l o g σ 2 logσ^2 logσ2而不是方差,所以原始的计算公式需要做一个转换,如下:

K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = − 1 2 ( l o g σ 2 − σ 2 − u 2 + 1 ) 【其中 σ 为 q ( z ∣ x ) 的标准差, μ 为 q ( z ∣ x ) 的均值】 = − 1 2 ( l o g _ v a r − e l o g _ v a r − m u 2 + 1 ) KL(q(z|x)||p(z))\\ =-\frac{1}{2}(log\sigma^2-\sigma^2-u^2+1)【其中σ为q(z|x)的标准差,μ为q(z|x)的均值】 \\ =-\frac{1}{2}(log\_var-e^{log\_var}-mu^2+1) KL(q(zx)∣∣p(z))=21(logσ2σ2u2+1)【其中σq(zx)的标准差,μq(zx)的均值】=21(log_varelog_varmu2+1)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

class VAE(nn.Module):
    def __init__(self):
        super().__init__()

        # encoder
        self.encoder_layer = nn.Sequential(nn.Linear(784, 256),
                                           nn.ReLU(),
                                           nn.Linear(256, 128),
                                           )
        self.fc1 = nn.Linear(128, 10)  # 均值
        self.fc2 = nn.Linear(128, 10)  # log方差

        # decoder
        self.decoder = nn.Sequential(nn.Linear(10, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 784),
                                     )

    def encoder(self, x):
        x = F.relu(self.encoder_layer(x))  # [b,784] ———> [b,128]
        mu = self.fc1(x)                   # [b,128] ———> [b,10]
        log_var = self.fc2(x)              # [b,128] ———> [b,10]
        return mu, log_var

    def reparameter(self, mu, log_var):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 重参数化技巧 z = mu(均值) + eps(0-1正态分布) *  sigma(方差)
        std = torch.sqrt(torch.exp(log_var))
        eps = torch.rand_like(std)
        z = mu + eps * std     # [b, 10]
        return z

    def forward(self, x):
        x = x.view(-1, 28 * 28)                # [b,1,28,28] ———> [b,784]
        mu, log_var = self.encoder(x)          # [b,784] ———> [b,10]  mu  sigma
        z = self.reparameter(mu, log_var)      # [b,10] ———> [b,10]
        x = self.decoder(z)                    # [b,10] ———> [b,784]
        x_hat = x.reshape(-1, 1, 28, 28)
        return x_hat, mu, log_var

def train(model, epoch=200):
    for epoch in range(epoch):
        model.train()
        total_loss = 0.0
        total_bce_loss = 0.0
        total_kl_loss = 0.0
        for x, y in trian_dl:
            x, y = x.to(device), y.to(device)
            x_hat, mu, log_var = model(x)
            # 3-1、bce_loss
            bce_loss = F.binary_cross_entropy(torch.sigmoid(x_hat.view(-1, 784)), x.view(-1, 784), reduction='sum')
            # 3-2、kl_loss
            kl_loss = torch.sum(-0.5 * (log_var - torch.exp(log_var) - mu ** 2 + 1))
            loss = bce_loss + kl_loss
            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss
            total_bce_loss += bce_loss
            total_kl_loss += kl_loss

        print(f'epoch = {epoch + 1},bce_loss = {(total_bce_loss / len(trian_dl)):.4f}, kl_loss = {(total_kl_loss / len(trian_dl)):.4f}, train loss = {total_loss / len(trian_dl):.4f}')

    torch.save(model.state_dict(), 'vae_model.pth')

if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1、读取数据集
    trian_dl = torch.utils.data.DataLoader(
        datasets.MNIST(
            "/root/autodl-fs/data/minist",
            train=True,
            download=False,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                ]
            ),
        ),
        batch_size=256,
        shuffle=True,
        num_workers=12
    )


    # 2、创建VAE模型及优化器
    model = VAE().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.001)

    # 4、模型训练
    train(model, epoch=300)

    # 5、模型推理
    # 加载模型
    model.load_state_dict(torch.load('vae_model.pth', map_location=device))


    model.eval()
    z = torch.randn(3, 10)  # 生成一个形状为 (3, 10) 的随机数张量
    image = model.decoder(z).reshape(-1, 1, 28, 28)

    save_image(image, "vae_image_random.png")

在这里插入图片描述

参考:
图片来自李宏毅老师的教程视频:https://www.bilibili.com/video/av15889450/?p=33
https://kexue.fm/archives/5253
https://blog.csdn.net/weixin_42491648/article/details/132384913
http://www.gwylab.com/note-vae.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/652647.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【最优化方法】实验一 熟悉MATLAB基本功能

实验一  熟悉MATLAB基本功能 实验的目的和要求:在本次实验中,通过亲临使用MATLAB,对该软件做一全面了解并掌握重点内容。 实验内容: 1、全面了解MATLAB系统 2、实验常用工具的具体操作和功能 学习建…

【基础篇-Day8:JAVA字符串的学习】

目录 1、常用API2、String类2.1 String类的特点2.2 String类的常见构造方法2.3 String类的常见面试题:2.3.1 面试题一:2.3.2 面试题二:2.3.3 面试题三:2.3.4 面试题四: 2.4 String类字符串用于比较的方法2.5 String类字…

基坑气膜:建筑工地环保新利器—轻空间

随着城市化进程的加快,建筑行业的飞速发展带来了严重的环境问题,如噪音和粉尘污染,给人们的生活带来诸多不便。为了解决这些问题,建筑行业一直在探索更为环保和高效的施工方式。近年来,基坑气膜技术逐渐崭露头角&#…

【国信华源:以专业服务,协助水利厅抵御强暴雨】

5月18日-19日,广西出现入汛以来最强暴雨天气过程,钦州、防城港、北海、南宁等地出现特大暴雨,多地打破降雨量极值。国信华源技术团队积极行动驻守一线,为打好山洪灾害防御的提前战、主动战提供了技术支撑。 5月17日18时&#xff0…

SOAR-Top 10安全剧本最佳实践-百度网盘下载

概述: SOAR(Security Orchestration,Automation and Response安全编排自动化响应),Gartner 对 SOAR 的最新描述性定义(摘自 Gartner 报告《Hype Cycle on Threat-Facing Technologies, 2018》) 是:SOAR 是一系列技术的…

基于SpringBoot+Vue在线动漫信息平台设计和实现(源码+LW+部署讲解)

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 🌹推荐一个人…

使用nexus搭建的nodejs私库,定期清理无用的npm组件,彻底释放磁盘空间

一、背景 昨天我们整理了一篇关于docker私库,如何定期清理以释放磁盘空间的文章。 虽然也提及了npm前端应用的组件该如何定期清理的,本文是对它作一个补充说明。 前文也看到了,npm组件占用的blob空间为180多GB,急需清理。 二、…

K8s证书过期处理

问题描述 本地有一个1master2worker的k8s集群,今天启动VMware虚拟机之后发现api-server没有起来,docker一直退出,这个集群是使用kubeadm安装的。 于是kubectl logs查看了日志,发现证书过期了 解决方案: 查看证书 #…

vue3 部署后修改配置文件

前端项目部署之后,运维可以自行修改配置文件里的接口IP,达到无需再次打包就可以使用的效果 vue2如何修改请看vue 部署后修改配置文件(接口IP)_vue部署后修改配置文件-CSDN博客 使用前提: vite搭建的vue3项目 使用setu…

IND-ID-CPA 和 IND-ANON-ID-CPA Game

Src: https://eprint.iacr.org/2017/967.pdf

WGCLOUD部署好后,怎么登录WGCLOUD界面

WGCLOUD的server启动完成后,我们在浏览器里输入URL,如下 http://[server主机IP]:9999 注意默认端口就是9999,如果修改过,那么把端口改成自己的实际端口 这样就可以看到登录页面了,默认账号密码是:admin/…

2951. 找出峰值

找出数组中的峰值 给你一个下标从 0 开始的数组 mountain 。你的任务是找出数组 mountain 中的所有 峰值。 以数组形式返回给定数组中 峰值 的下标,顺序不限 。 注意 峰值 是指一个严格大于其相邻元素的元素。数组的第一个和最后一个元素 不 是峰值。 示例 1 …

VSCODE常用插件记录

重点提名: back & ForthBookmarksC/ChighlightSSH FS //SSH插件

《精通Stable Diffusion AI绘画:基础技巧、实战案例与海量资源一站式学习》

随着人工智能技术的迅猛发展,AI绘画已经成为了一个炙手可热的话题。特别是在设计、艺术和创意领域,AI绘画工具的出现无疑为创作者们带来了更多的可能性和便利。《Stable Diffusion AI绘画从提示词到模型出图》这本书,就是一本深入解析Stable …

【IDEA】Redis可视化神器

在开发过程中,为了方便地管理 Redis 数据库,我们可能会使用一些数据库可视化插件。这些插件通常可以帮助你在 IDE 中直观地查看和管理 Redis 数据库,包括查看键值对、执行命令、监视数据库活动等。 IDEA作为IDE界的Jenkins,本身自…

SAP 根据报错消息号快速定位问题

通常用户在业务的操作过程中,经常会遇到报错信息,有些报错是系统控制抛出的信息,但是有些报错的信息是根据不同地点业务场景对填写的数据进行判断校验,然后给出的报错信息,正常情况报错信息一般是有文本,或…

PyTorch 错误 RuntimeError: CUDA error: device-side assert triggered

训练数据的时候出现 RuntimeError:CUDA error:device-side assert triggeredCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.For debugging consider passing CUDA_LAUNCH_BLOCKING1.Compile with …

2024年统计、数据分析与大数据技术国际会议(SDBT 2024)

2024年统计、数据分析与大数据技术国际会议(SDBT 2024) 2024 International Conference on Statistics, Data Analysis, and Big Data Technology 【重要信息】 大会地点:广州 大会时间:2024年7月22日 大会官网:http…

【EI会议】2024年雷达、电子与通信工程国际会议(ICREC 2024)

2024年雷达、电子与通信工程国际会议 2024 International Conference on Radar, Electronics and Communication Engineering 【1】会议简介 2024年雷达、电子与通信工程国际会议即将在深圳隆重召开。深圳,这座充满活力的现代化都市,以其卓越的科技创新…

【Git】使用tortoiseGit

参考视频 【TortoiseGit常用的基本使用教程】 https://www.bilibili.com/video/BV193411h7FP/?share_sourcecopy_web&vd_source77e36f24add8dc77c362748ffb980148 拉取远程代码 创建分支 拉取远端dev分支的代码: 先创建本地的dev分支: 拉取&…