经典神经网络(11)VQ-VAE模型及其在MNIST数据集上的应用

经典神经网络(11)VQ-VAE模型及其在MNIST数据集上的应用

  • 我们之前已经了解了PixelCNN模型。

    经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

  • 今天,我们了解下DeepMind在2017年提出的一种基于离散隐变量(Discrete Latent variables)的生成模型:VQ-VAE。

  • VQ-VAE采用离散隐变量,而不是像VAE那样采用连续的隐变量。其实VQ-VAE本质上是一种AE,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。

  • 那么,VQ-VAE会被归类到图像生成模型中呢?这是因为VQ-VAE单独训练了一个基于自回归的模型如PixelCNN来学习先验(prior),对VQ-VAE的离散编码空间采样。而不是像VAE那样采用一个固定的先验(标准正态分布)。

  • 此外,VQ-VAE还是一个强大的无监督表征学习模型,它学习的离散编码具有很强的表征能力:

    • OpenAI在2021年发布的文本转图像模型DALL-E就是基于VQ-VAE。
    • 另外,在BEiT中也用VQ-VAE得到的离散编码作为训练目标。
    • 注:推荐下EleutherAI团队的lucidrains(Phil Wang)的github,他开源复现了ViT、AlphaFold 2、DALLE、 DALLE2、imagen等项目
      https://github.com/lucidrains

1 VQ-VAE

1.1 从AE到VQ-VAE

  • AE是一类能够把图片压缩成较短的向量的神经网络模型。

    • AE的编码器编码出来的向量空间是不规整的。也就是说,解码器只认识经编码器编出来的向量,而不认识其他的向量。
    • 如下图,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。
      在这里插入图片描述
  • 只要AE的编码空间比较规整,符合某个简单的数学分布(比如最常见的标准正态分布,如下图所示),那我们就可以从这个分布里随机采样向量,再让解码器根据这个向量来完成随机图片生成了。

    • VAE就是这样一种改进版的AE,它用一些巧妙的方法约束了编码向量z,使得z满足标准正态分布。
    • 训练完成后,我们就可以扔掉编码器,用来自标准正态分布的随机向量和解码器来实现随机图像生成了。

    在这里插入图片描述

  • VQ-VAE的作者认为,VAE的生成图片之所以质量不高,是因为图片被编码成了连续向量。而实际上,把图片编码成离散向量会更加自然。

  • 至于离散编码的原因,作者解释如下:https://avdnoord.github.io/homepage/slides/SANE2017.pdf

    在这里插入图片描述

1.2 VQVAE概述

把图像编码成离散向量后,会带来两个问题:

  • 第一个问题是,神经网络会默认输入满足一个连续的分布,而不善于处理离散的输入。

    • 如果你直接输入0, 1, 2这些数字,神经网络会默认1是一个处于0, 2中间的一种状态。为了解决这一问题,我们可以借鉴NLP中对于离散单词的处理方法。
    • 我们可以把嵌入层加到VQ-VAE的解码器前,这个嵌入层就是embedding space(嵌入空间),也称codebook
    • 注意:其实Encoder编码出来的是二维离散编码,下图画的是一维。

    在这里插入图片描述

  • 另一个问题是离散向量不好采样。

    • VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。
    • VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。
  • VQ-VAE的架构图,如下图所示:

    • 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成latent image(下图zq),也能把latent image(下图zq)变回图像。
    • 训练PixelCNN,让它学习怎么生成latent image(下图zq)
    • 生成(采样)时,先用PixelCNN采样出latent image(下图zq),再用VQ-VAE把latent image(下图zq)翻译成最终的生成图像。

在这里插入图片描述

1.3 VQ-VAE设计细节

1.3.1 关联编码器的输出与解码器的输入

如何关联编码器的输出与解码器的输入呢?

  • 假设嵌入空间codebook已经训练完毕,那么对于编码器的每个输出向量 z e ( x ) ze(x) ze(x),我们需要找出它在嵌入空间里的最近邻 z q ( x ) zq(x) zq(x),把 z e ( x ) ze(x) ze(x)替换成 z q ( x ) zq(x) zq(x)作为解码器的输入。
  • 方式是:求最近邻,即先计算向量与嵌入空间K个向量每个向量的距离,再对距离数组取一个argmin,求出最近的下标(如上图中的shape为[1,7,7]),最后用下标去嵌入空间里取向量,就得到了 z q zq zq(如上图中的shape为[1,32,7,7])。下标构成的多维数组,也正是VQ-VAE的离散编码。

1.3.2 梯度复制

  • 我们现在能把编码器和解码器拼接到一起,但怎么让梯度从解码器的输入 z q ( x ) zq(x) zq(x)传到 z e ( x ) ze(x) ze(x)?从 z e ( x ) ze(x) ze(x) z q ( x ) zq(x) zq(x)的变换是一个从数组里取值,这个操作无法求导。
  • VQ-VAE使用了一种叫做"straight-through estimator"的技术【即前向传播和反向传播的计算可以不对应】来完成梯度复制。VQ-VAE使用了一种叫做sg(stop gradient,停止梯度)的运算:

s g ( x ) = { x , 前向传播 0 , 反向传播 前向传播时, s g 里的值不变;反向传播时, s g 按值为 0 求导,即此次计算无梯度。 sg(x)=\begin{cases} x, & 前向传播\\ 0,& 反向传播 \end{cases}\\ 前向传播时,sg里的值不变;反向传播时,sg按值为0求导,即此次计算无梯度。 sg(x)={x,0,前向传播反向传播前向传播时,sg里的值不变;反向传播时,sg按值为0求导,即此次计算无梯度。

由于VQ-VAE其实是一个AE,误差函数里应该只有原图像和目标图像的重建误差:
L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z q ( x ) ) ∣ ∣ 2 2 L_{reconstruct}=||x-decoder(z_q(x))||_2^2 Lreconstruct=∣∣xdecoder(zq(x))22
我们现在利用sg运算,设计新的重建误差:
L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 前向传播时,就是拿解码器的输入 z q ( x ) 来算误差: L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + z q ( x ) − z e ( x ) ) ∣ ∣ 2 2 = ∣ ∣ x − d e c o d e r ( z q ( x ) ) ∣ ∣ 2 2 反向传播时,等价于把解码器的梯度全部传给 z e ( x ) : L r e c o n s t r u c t = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 = ∣ ∣ x − d e c o d e r ( z e ( x ) ) ∣ ∣ 2 2 L_{reconstruct}=||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2\\ 前向传播时,就是拿解码器的输入z_q(x)来算误差:\\ L_{reconstruct}=||x-decoder(z_e(x)+z_q(x)-z_e(x))||_2^2\\ =||x-decoder(z_q(x))||_2^2\\ 反向传播时,等价于把解码器的梯度全部传给z_e(x):\\ L_{reconstruct}=||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2\\ =||x-decoder(z_e(x))||_2^2 Lreconstruct=∣∣xdecoder(ze(x)+sg[zq(x)ze(x)])22前向传播时,就是拿解码器的输入zq(x)来算误差:Lreconstruct=∣∣xdecoder(ze(x)+zq(x)ze(x))22=∣∣xdecoder(zq(x))22反向传播时,等价于把解码器的梯度全部传给ze(x)Lreconstruct=∣∣xdecoder(ze(x)+sg[zq(x)ze(x)])22=∣∣xdecoder(ze(x))22
在PyTorch里,(x).detach()就是sg(x),它的值在前向传播时取x,反向传播时取0

# stop gradient
decoder_input = ze + (zq - ze).detach()
# decode
x_hat = decoder(decoder_input)
# l_reconstruct
l_reconstruct = mse_loss(x, x_hat)

1.3.3 优化嵌入空间codebook

嵌入空间的优化目标是什么呢?嵌入空间的每一个向量应该能概括一类编码器输出的向量。因此,嵌入空间的向量应该和其对应编码器输出尽可能接近。
L e = ∣ ∣ z e ( x ) − z q ( x ) ∣ ∣ 2 2 z e ( x ) 是编码器的输出向量, z q ( x ) 是其在嵌入空间的最近邻向量 L_e=||z_e(x)-z_q(x)||_2^2\\ z_e(x)是编码器的输出向量,z_q(x)是其在嵌入空间的最近邻向量 Le=∣∣ze(x)zq(x)22ze(x)是编码器的输出向量,zq(x)是其在嵌入空间的最近邻向量
作者认为,编码器和嵌入向量的学习速度应该不一样快。

于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。其中,β控制了编码器的相对学习速度。作者发现,算法对β的变化不敏感,β取0.1~2.0都差不多。
L e = ∣ ∣ s g [ z e ( x ) ] − z q ( x ) ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ z q ( x ) ] ∣ ∣ 2 2 L_e=||sg[z_e(x)]-z_q(x)||_2^2+\beta||z_e(x)-sg[z_q(x)]||_2^2\\ Le=∣∣sg[ze(x)]zq(x)22+β∣∣ze(x)sg[zq(x)]22

# vq loss
l_embedding = mse_loss(ze.detach(), zq)
# commitment loss
l_commitment = mse_loss(ze, zq.detach())

VQ-VAE总体的损失函数可以写成:
L t o t a l = L r e c o n s t r u c t + L e = ∣ ∣ x − d e c o d e r ( z e ( x ) + s g [ z q ( x ) − z e ( x ) ] ) ∣ ∣ 2 2 + α ∣ ∣ s g [ z e ( x ) ] − z q ( x ) ∣ ∣ 2 2 + β ∣ ∣ z e ( x ) − s g [ z q ( x ) ] ∣ ∣ 2 2 L_{total}=L_{reconstruct} + L_e \\ =||x-decoder(z_e(x)+sg[z_q(x)-z_e(x)])||_2^2 +\alpha||sg[z_e(x)]-z_q(x)||_2^2\\+\beta||z_e(x)-sg[z_q(x)]||_2^2 Ltotal=Lreconstruct+Le=∣∣xdecoder(ze(x)+sg[zq(x)ze(x)])22+α∣∣sg[ze(x)]zq(x)22+β∣∣ze(x)sg[zq(x)]22

# reconstruct loss
l_reconstruct = mse_loss(x, x_hat)
# vq loss
l_embedding = mse_loss(ze.detach(), zq)
# commitment loss
l_commitment = mse_loss(ze, zq.detach())

# total loss
loss = l_reconstruct + \
                l_w_embedding * l_embedding + l_w_commitment * l_commitment

1.3.4 先验模型PixelCNN

  • 训练好VQ-VAE后,还需要训练一个先验模型来完成数据生成,论文中采用PixelCNN模型。
  • 这里我们不再是学习生成原始的pixels,而是学习生成离散编码:
    • 首先,我们需要用已经训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码;
    • 然后用一个PixelCNN来对离散编码进行建模
    • 最后的预测层采用基于softmax的多分类,类别数为embedding空间的大小K。
  • 那么,生成图像的过程就比较简单了,首先用训练好的PixelCNN模型来采样一个离散编码样本(上图中shape为[1, 32, 7, 7]),然后送入VQ-VAE的decoder中,得到生成的图像。
  • 实际上,PixelCNN不是唯一可用的拟合离散分布的模型。我们可以把它换成Transformer,甚至是diffusion模型。

2 VQ-VAE模型在MNIST数据集上的应用

这里使用的模型为Gated PixelCNN模型,具体可参考:

经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

网络结构图如下所示:

在这里插入图片描述

2.1 VQ-VAE模型

VQVAE的编码器和解码器的结构很简单,仅由普通的上/下采样层和残差块组成。

  • 编码器先是有两个3x3卷积+2倍下采样卷积的模块,再有两个残差块(ReLU, 3x3卷积, ReLU, 1x1卷积);
  • 解码器则反过来,先有两个残差块,再有两个3x3卷积+2倍上采样反卷积的模块。
# Reference: https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VQVAE
import os
import time

import cv2
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from GatedPixelCNNDemo import GatedPixelCNN, GatedBlock


class ResidualBlock(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        tmp = self.relu(x)
        tmp = self.conv1(tmp)
        tmp = self.relu(tmp)
        tmp = self.conv2(tmp)
        return x + tmp


class VQVAE(nn.Module):

    def __init__(self, input_dim, dim, n_embedding):
        super().__init__()
        # 1、编码器
        self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
                                     ResidualBlock(dim),
                                     ResidualBlock(dim)
                          )

        self.vq_embedding = nn.Embedding(n_embedding, dim)
        # 初始化为均匀分布
        self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0 / n_embedding)
        # 2、解码器
        self.decoder = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            ResidualBlock(dim),
            ResidualBlock(dim),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1)
        )
        self.n_downsample = 2

    def forward(self, x):
        # encode [N, 1, 28, 28] -> [N, 32, 7, 7]
        ze = self.encoder(x)

        # ze: [N, C, H, W]
        # embedding [K, C]  [32, 32]
        embedding = self.vq_embedding.weight.data
        N, C, H, W = ze.shape
        K, _ = embedding.shape
        # 求解最近邻
        embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
        ze_broadcast = ze.reshape(N, 1, C, H, W)
        distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
        nearest_neighbor = torch.argmin(distance, 1)
        # make C to the second dim
        zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)
        # stop gradient
        decoder_input = ze + (zq - ze).detach()

        # decode
        x_hat = self.decoder(decoder_input)
        return x_hat, ze, zq

    @torch.no_grad()
    def encode(self, x):
        ze = self.encoder(x)
        embedding = self.vq_embedding.weight.data

        # ze: [N, C, H, W]
        # embedding [K, C]
        N, C, H, W = ze.shape
        K, _ = embedding.shape
        embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
        ze_broadcast = ze.reshape(N, 1, C, H, W)
        distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
        nearest_neighbor = torch.argmin(distance, 1)
        return nearest_neighbor

    @torch.no_grad()
    def decode(self, discrete_latent):
        zq = self.vq_embedding(discrete_latent).permute(0, 3, 1, 2)
        x_hat = self.decoder(zq)
        return x_hat

    # Shape: [C, H, W]
    def get_latent_HW(self, input_shape):
        C, H, W = input_shape
        return (H // 2**self.n_downsample, W // 2**self.n_downsample)

2.2 先验模型

我们已经有了一个普通的PixelCNN模型GatedPixelCNN

  • 需要在整个模型的最前面套一个嵌入层,嵌入层的嵌入个数等于离散编码的个数(color_level),嵌入长度等于模型的特征长度(p)。
  • 由于嵌入层会直接输出一个长度为p的向量,我们还需要把第一个模块的输入通道数改成p
# 继承自我们之前实现的模型GatedPixelCNN
class PixelCNNWithEmbedding(GatedPixelCNN):

    def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
        super().__init__(n_blocks, p, linear_dim, bn, color_level)
        self.embedding = nn.Embedding(color_level, p)
        self.block1 = GatedBlock('A', p, p, bn)

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return super().forward(x)

2.3 两种模型的训练

  • 下面就是常规的训练代码
  • 先训练VQVAE、再训练PixelCNN
def train_vqvae(model: VQVAE,
                img_shape=None,
                device='cuda',
                ckpt_path='./model.pth',
                batch_size=64,
                dataset_type='MNIST',
                lr=1e-3,
                n_epochs=100,
                l_w_embedding=1,
                l_w_commitment=0.25):
    print('batch size:', batch_size)
    dataloader = get_dataloader(dataset_type,
                                batch_size,
                                img_shape=img_shape)
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr)
    mse_loss = nn.MSELoss()
    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0

        for x in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)

            x_hat, ze, zq = model(x)
            # 1、reconstruct loss
            l_reconstruct = mse_loss(x, x_hat)
            # 2、vq loss + commitment loss
            l_embedding = mse_loss(ze.detach(), zq)
            l_commitment = mse_loss(ze, zq.detach())
            # total loss
            loss = l_reconstruct + \
                l_w_embedding * l_embedding + l_w_commitment * l_commitment
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), ckpt_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
    print('Done')


def train_generative_model(vqvae: VQVAE,
                           model,
                           img_shape=None,
                           device='cuda',
                           ckpt_path='./gen_model.pth',
                           dataset_type='MNIST',
                           batch_size=64,
                           n_epochs=50):
    print('batch size:', batch_size)
    dataloader = get_dataloader(dataset_type,
                                batch_size,
                                img_shape=img_shape)
    vqvae.to(device)
    vqvae.eval()
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # 交叉熵损失
    loss_fn = nn.CrossEntropyLoss()
    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x in dataloader:
            current_batch_size = x.shape[0]
            with torch.no_grad():
                x = x.to(device)
                # 1、训练好的VQ-VAE模型对训练图像推理,得到每张图像对应的离散编码
                x = vqvae.encode(x)
            # 2、用一个PixelCNN来对离散编码进行建模
            predict_x = model(x)
            # 3、预测层采用基于softmax的多分类
            loss = loss_fn(predict_x, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), ckpt_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
    print('Done')


def reconstruct(model, x, device, dataset_type='MNIST'):
    model.to(device)
    model.eval()
    with torch.no_grad():
        x_hat, _, _ = model(x)
    n = x.shape[0]
    n1 = int(n**0.5)
    x_cat = torch.concat((x, x_hat), 3)
    x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)
    x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)
    cv2.imwrite(f'work_dirs/vqvae_reconstruct_{dataset_type}.jpg', x_cat)
class MNISTImageDataset(Dataset):

    def __init__(self, img_shape=(28, 28)):
        super().__init__()
        self.img_shape = img_shape
        self.mnist = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist')

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, index: int):
        img = self.mnist[index][0]
        pipeline = transforms.Compose(
            [transforms.Resize(self.img_shape),
             transforms.ToTensor()])
        return pipeline(img)


def get_dataloader(type,
                   batch_size,
                   img_shape=None,
                   dist_train=False,
                   num_workers=0,
                   **kwargs):
    if type == 'MNIST':
        if img_shape is not None:
            dataset = MNISTImageDataset(img_shape)
        else:
            dataset = MNISTImageDataset()
    if dist_train:
        sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                sampler=sampler,
                                num_workers=num_workers)
        return dataloader, sampler
    else:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                num_workers=num_workers)
        return dataloader



cfg = dict(dataset_type='MNIST',
                  img_shape=(1, 28, 28),
                  dim=32,
                  n_embedding=32,
                  batch_size=32,
                  n_epochs=20,
                  l_w_embedding=1,
                  l_w_commitment=0.25,
                  lr=2e-4,
                  n_epochs_2=50,
                  batch_size_2=32,
                  pixelcnn_n_blocks=15,
                  pixelcnn_dim=128,
                  pixelcnn_linear_dim=32,
                  vqvae_path='./model_mnist.pth',
                  gen_model_path='./gen_model_mnist.pth')


if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img_shape = cfg['img_shape']
    # 初始化模型
    vqvae = VQVAE(img_shape[0], cfg['dim'], cfg['n_embedding'])

    gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
                                      cfg['pixelcnn_dim'],
                                      cfg['pixelcnn_linear_dim'], True,
                                      cfg['n_embedding'])
    # 1. Train VQVAE
    train_vqvae(vqvae,
                img_shape=(img_shape[1], img_shape[2]),
                device=device,
                ckpt_path=cfg['vqvae_path'],
                batch_size=cfg['batch_size'],
                dataset_type=cfg['dataset_type'],
                lr=cfg['lr'],
                n_epochs=cfg['n_epochs'],
                l_w_embedding=cfg['l_w_embedding'],
                l_w_commitment=cfg['l_w_commitment'])

    # 2. Test VQVAE by visualizaing reconstruction result
    vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
    dataloader = get_dataloader(cfg['dataset_type'],
                                16,
                                img_shape=(img_shape[1], img_shape[2]))
    img = next(iter(dataloader)).to(device)
    reconstruct(vqvae, img, device, cfg['dataset_type'])

    # 3. Train Generative model (Gated PixelCNN)
    vqvae.load_state_dict(torch.load(cfg['vqvae_path']))

    train_generative_model(vqvae,
                           gen_model,
                           img_shape=(img_shape[1], img_shape[2]),
                           device=device,
                           ckpt_path=cfg['gen_model_path'],
                           dataset_type=cfg['dataset_type'],
                           batch_size=cfg['batch_size_2'],
                           n_epochs=cfg['n_epochs_2'])
    # 4. Sample VQVAE
    vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
    gen_model.load_state_dict(torch.load(cfg['gen_model_path']))
    sample_imgs(vqvae,
                gen_model,
                cfg['img_shape'],
                device=device,
                n_sample=1,
                dataset_type=cfg['dataset_type'])

2.4 图像生成(采样)

def sample_imgs(vqvae: VQVAE,
                gen_model,
                img_shape,
                n_sample=81,
                device='cuda',
                dataset_type='MNIST'):
    vqvae = vqvae.to(device)
    vqvae.eval()
    gen_model = gen_model.to(device)
    gen_model.eval()

    C, H, W = img_shape
    H, W = vqvae.get_latent_HW((C, H, W))
    input_shape = (n_sample, H, W)
    # 初始化为0
    x = torch.zeros(input_shape).to(device).to(torch.long)
    with torch.no_grad():
        # 逐像素预测
        for i in range(H):
            for j in range(W):
                output = gen_model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 从概率分布中采样
                pixel = torch.multinomial(prob_dist, 1)
                x[:, i, j] = pixel[:, 0]
    # 解码
    imgs = vqvae.decode(x)

    imgs = imgs * 255
    imgs = imgs.clip(0, 255)
    imgs = einops.rearrange(imgs,
                            '(n1 n2) c h w -> (n1 h) (n2 w) c',
                            n1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(f'work_dirs/vqvae_sample_{dataset_type}.jpg', imgs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/700726.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

网页五子棋对战项目测试(selenium+Junit5)

目录 网页五子棋对战项目介绍 网页五子棋对战测试的思维导图​ 网页五子棋对战的UI自动化测试 测试一:测试注册界面 测试二:测试登陆界面 测试三:测试游戏大厅界面 测试四:测试游戏房间界面以及观战房间界面 测试五&#…

使用开源的zip.cpp和unzip.cpp实现压缩包的创建与解压(附源码)

目录 1、使用场景 2、压缩包的创建 3、压缩包的解压 4、CloseZipZ和CloseZipU两接口的区别 5、开源zip.cpp和unzip.cpp文件的下载 VC++常用功能开发汇总(专栏文章列表,欢迎订阅,持续更新...)https://blog.csdn.net/chenlycly/article/details/124272585C++软件异常排…

Redis集群(5)

集群原理 节点通信 通信流程 在分布式存储系统中,维护节点元数据(如节点负责的数据、节点的故障状态等)是关键任务。常见的元数据维护方式分为集中式和P2P方式。Redis集群采用P2P的Gossip协议,这种协议的工作原理是节点之间不断…

使用 FormCreate 快速创建仿真页面

在现代前端开发中,快速创建和迭代仿真页面是提高开发效率和用户体验的关键。FormCreate 是一个强大的工具,它通过 JSON 生成具有动态渲染、数据收集、验证和提交功能的表单组件,支持多种 UI 框架。本文将介绍如何使用 FormCreate 快速创建一个…

svg图标封装--基于vue2适配uniapp全端

第一步&#xff1a;新建svg目录 在static目录下新建svg目录,后将所有svg图标都放到此文件夹 第二步&#xff1a;封装注册全局组件 (注意&#xff1a;在根目录下新建components文件夹) 代码实现&#xff1a; <template><!-- svg图标 --><image :style"{ …

【python】OpenCV—Background Estimation(15)

文章目录 中值滤波中值滤波得到图像背景移动侦测 学习来自 OpenCV基础&#xff08;14&#xff09;OpenCV在视频中的简单背景估计 中值滤波 中值滤波是一种非线性平滑技术&#xff0c;主要用于数字信号处理&#xff0c;特别是在图像处理中去除噪声。 一、定义与原理 定义&am…

MATLAB算法实战应用案例精讲-【数模应用】多分类Logit分析(附python、R语言和MATLAB代码实现)

目录 算法原理 成对类别有序logit 簇族数据中的超散布性 条件独立性检验 SPSS-有序多分类Logistic回归 SPSSAU 参照项设置 案例应用 代码实现 R语言 逻辑回归 决策树 随机森林 支持向量机 评价分类的准确性 MATLAB python 算法原理 成对类别有序logit libr…

maven基本操作和配置(idea版基础版)

写在前面&#xff1a;为一位朋友写的一个博客&#xff0c;有需要都可以查看&#xff01; 一、maven是什么&#xff1f; 一句话&#xff1a;管理依赖工具&#xff0c;统一项目结构便于开发&#xff0c;把项目开发和管理的过程抽象成对象模型来管理&#xff08;pom模型&#xf…

Milvus 2.4 向量库安装部署

1、linux 已有docker环境 2、安装fio命令 yum install -y fio 2、mkdir test-data fio --rwwrite --ioenginesync --fdatasync1 --directorytest-data --size2200m --bs2300 --namemytest ctrlc 3、lscpu 4、docker -v 6、安装docker compose组件 yum -y install python3-…

八、C语言:操作符详解

一、移位操作符 1.1左移操作 左边丢弃&#xff0c;右边补0 1.2右移操作 算数右移&#xff1a;右边丢弃&#xff0c;左边补原符号位 逻辑右移&#xff1a;右边丢弃&#xff0c;左边补0 int main() {int a -1;int b a >> 1;printf("b%d\n",b);return 0; } 原码…

贪吃蛇小游戏简单制作-C语言

文章目录 游戏背景介绍实现目标适合人群所需技术浅玩Window API什么是API控制台程序窗口大小,名称设置 Handle(句柄)获取句柄 坐标结构体设置光标位置 光标属性获取光标属性设置光标属性 按键信息获取 贪吃蛇游戏设计游戏前的初始化设置窗口的大小和名称本地化设置 宽字符Waht …

采用PHP开发的一套(项目源码)医疗安全(不良)事件报告系统源码:统计分析,持续整改,完成闭环管理

采用PHP开发的一套&#xff08;项目源码&#xff09;医疗安全&#xff08;不良&#xff09;事件报告系统源码&#xff1a;统计分析&#xff0c;持续整改&#xff0c;完成闭环管理 医疗安全确实是医疗领域中不容忽视的重要问题。医院不良安全事件&#xff0c;即医疗质量安全不良…

宋街宣传活动-循环利用,绿色生活

善于善行回收团队是一支致力于推动环保事业&#xff0c;积极倡导和实践绿色生活的志愿者队伍。我们的宗旨是通过回收再利用&#xff0c;减少资源浪费&#xff0c;降低环境污染&#xff0c;同时提高公众的环保意识&#xff0c;共同构建美丽和谐的家园。 善于善行志愿团队于2024年…

免费、无广告、界面简洁、简单好用的轻量级思维导图软件

一、简介 1、一款免费、无广告、界面简洁、简单好用的轻量级思维导图软件。它目前支持 Windows、MacOS 平台。其中 Windows 版大小在 104MB 左右&#xff08;UWP 应用&#xff09;&#xff0c;Mac 版大小在 167MB 左右。 二、下载 1、下载地址&#xff1a; MindAtom官网&#…

【保姆级讲解下QT6.3】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

用户和权限

Linux的root用户 无论是Windows、MacOS、Linux均采用多用户的管理模式进行权限管理 超级管理员: root用户拥有最大的系统操作权限(不建议长期使用root用户&#xff0c;避免带来系统损坏)普通用户的权限: 一般在其HOME目录内是不受限的,在HOME目录外仅有只读和执行权限&#x…

go-zero整合Excelize并实现Excel导入导出

go-zero整合Excelize并实现Excel导入导出 本教程基于go-zero微服务入门教程&#xff0c;项目工程结构同上一个教程。 本教程主要实现go-zero框架整合Excelize&#xff0c;并暴露接口实现Excel模板下载、Excel导入、Excel导出。 go-zero微服务入门教程&#xff1a;https://blo…

【深度学习】AI换脸,EasyPhoto: Your Personal AI Photo Generator【一】

论文&#xff1a;https://arxiv.org/abs/2310.04672 文章目录 摘要IntroductionTraining Process3 推理过程3.1 面部预处理3.3 第二扩散阶段3.4 多用户ID 4 任意ID5 实验6 结论 下篇文章进行实战。 摘要 稳定扩散Web UI&#xff08;Stable Diffusion Web UI&#xff0c;简称…

MYSQL八、MYSQL的SQL优化

一、SQL优化 sql优化是指&#xff1a;通过对sql语句和数据库结构的调整&#xff0c;来提高数据库查询、插入、更新和删除等操作的性能和效率。 1、插入数据优化 要一次性往数据库表中插入多条记录&#xff1a; insert into tb_test values(1,tom); insert into tb_tes…

CyberDAO:引领Web3时代的DAO社区文化

致力于Web3研究和孵化 CyberDAO自成立以来&#xff0c;致力于推动Web3研究和孵化&#xff0c;吸引了来自技术、资本、商业、应用与流量等领域的上千名热忱成员。我们为社区提供多元的Web3产品和商业机会&#xff0c;触达行业核心&#xff0c;助力成员捕获Web3.0时代的红利。 目…