Welcome to follow my CSDN: https://spike.blog.csdn.net/
This article: https://spike.blog.csdn.net/article/details/135936848
VQ-VAE is the Vector Quantized Variational AutoEncoder. Its innovation is a vector quantization (VQ) layer that maps the continuous encoder output into a discrete latent space. The VQ layer consists of a set of learnable vectors called the codebook. Each encoder output vector is replaced by its nearest vector in the codebook, which performs the quantization. In this way, VQ-VAE encodes an image into discrete vectors.
The strength of VQ-VAE is that it generates high-quality data while making the representation discrete. Discreteness suits naturally discrete modalities such as language, reasoning, and planning, and discrete vectors are easier for downstream models to process. Training covers three parts: the encoder, the decoder, and the codebook. The encoder and decoder are trained to minimize the reconstruction error, i.e. to make the reconstructed image as close as possible to the original. The codebook is trained to minimize the codebook loss, which moves each codebook vector towards the encoder outputs nearest to it. In addition, a commitment loss trains the encoder and keeps the encoded vectors from jumping back and forth between codebook vectors.
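A minimal sketch of the quantization step (toy shapes and illustrative names, not the code used later in this post): each continuous encoder output is replaced by its nearest codebook entry.
import torch

z_e = torch.randn(4, 64)       # 4 continuous encoder outputs, 64-dim each
codebook = torch.randn(8, 64)  # codebook of 8 learnable vectors (an nn.Embedding in practice)

dist = torch.cdist(z_e, codebook)  # pairwise Euclidean distances, shape (4, 8)
indices = dist.argmin(dim=1)       # discrete codes, shape (4,)
z_q = codebook[indices]            # quantized vectors, shape (4, 64)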
Paper: Neural Discrete Representation Learning, NIPS 2017
- Paper - Neural Discrete Representation Learning (VQ-VAE), brief paper notes
Select the GPU or CPU device:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Info] device: {device}")
CIFAR10 dataset (ToTensor maps pixels to [0, 1], and Normalize with mean 0.5 and std 1.0 shifts them to [-0.5, 0.5]):
training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))
                                 ]))
validation_data = datasets.CIFAR10(root="data", train=False, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))
                                   ]))
Variance of the training data, used later to scale the reconstruction loss into a normalized MSE (the training curve below is labeled "Smoothed NMSE"):
data_variance = np.var(training_data.data / 255.0)
print(f"[Info] data_variance: {data_variance}")
# recon_error = F.mse_loss(data_recon, data) / data_variance scales the image-reconstruction loss
Loss:
The total loss is actually composed of three components:
- reconstruction loss: optimizes the decoder and the encoder
- codebook loss: because gradients bypass the embedding, a dictionary-learning algorithm uses an $l_2$ error to move the embedding vectors $e_i$ towards the encoder output
- commitment loss: since the volume of the embedding space is dimensionless, it can grow arbitrarily if the embeddings $e_i$ do not train as fast as the encoder parameters, so a commitment loss is added to make sure the encoder commits to an embedding
In symbols (equation (3) in the paper, where $\beta$ corresponds to commitment_cost below):
$$L = \log p(x \mid z_q(x)) + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2$$
sg = stop gradient
The codebook weights are initialized from a uniform distribution:
self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
The distance computation uses the squared Euclidean distance, expanded as $\lVert x - e \rVert^2 = \lVert x \rVert^2 + \lVert e \rVert^2 - 2\,x^\top e$:
# Flatten input to a 2-D tensor (B*H*W, embedding_dim)
flat_input = inputs.view(-1, self._embedding_dim)
# Calculate squared Euclidean distances (cosine distance would be an alternative)
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
             + torch.sum(self._embedding.weight**2, dim=1)
             - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
Convert to one-hot format:
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)  # nearest-neighbor indices
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)  # convert to one-hot form
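For reference, torch.nn.functional.one_hot builds the same matrix in one call (a standalone toy example, not the notebook's code):
import torch
import torch.nn.functional as F

indices = torch.tensor([2, 0, 1])
one_hot = F.one_hot(indices, num_classes=4).float()
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])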
Multiplying the one-hot matrix (encodings) by the codebook (self._embedding.weight) yields the quantized features:
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
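Because encodings is one-hot, this matmul is just a row lookup; a toy check of the equivalence (illustrative, not the notebook's code):
import torch
import torch.nn.functional as F

codebook = torch.arange(12.0).view(4, 3)  # 4 codes of dimension 3
indices = torch.tensor([3, 0])
one_hot = F.one_hot(indices, num_classes=4).float()
assert torch.equal(torch.matmul(one_hot, codebook), codebook[indices])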
Latent losses: e_latent_loss (the commitment loss) updates inputs, i.e. the encoder network, while q_latent_loss updates the codebook (self._embedding), as follows:
- detach() is the stop-gradient (sg) operation
# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs) # commitment loss
q_latent_loss = F.mse_loss(quantized, inputs.detach()) # detach = stop gradient
loss = q_latent_loss + self._commitment_cost * e_latent_loss
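A quick sanity check of where gradients flow (toy tensors, my own example): detaching quantized routes e_latent_loss to the encoder side only, and detaching inputs routes q_latent_loss to the codebook side only.
import torch
import torch.nn.functional as F

inputs = torch.randn(5, 4, requires_grad=True)     # stands in for the encoder output
quantized = torch.randn(5, 4, requires_grad=True)  # stands in for the codebook output

F.mse_loss(quantized.detach(), inputs).backward()
print(inputs.grad is not None, quantized.grad is None)  # True True: only the encoder side is updated

inputs.grad = None
F.mse_loss(quantized, inputs.detach()).backward()
print(inputs.grad is None, quantized.grad is not None)  # True True: only the codebook side is updated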
Gradient copying (the straight-through estimator):
# Gradient-copying trick: quantized receives the same gradient as inputs
quantized = inputs + (quantized - inputs).detach()  # adding a detached constant keeps the encoder-decoder path differentiable
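A tiny demonstration of the trick (toy example, not from the notebook): the forward value equals the quantized tensor, while the backward gradient flows to x as if the quantizer were the identity.
import torch

x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)
q = torch.round(x)           # stand-in for a non-differentiable quantizer
q_st = x + (q - x).detach()  # forward: q, backward: identity w.r.t. x
q_st.sum().backward()
print(q_st.detach())  # tensor([0., 1., 1.])
print(x.grad)         # tensor([1., 1., 1.]) - gradients pass straight through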
Perplexity computation:
# Perplexity is the exponentiated entropy of codebook usage, used as a diagnostic:
# higher perplexity = higher entropy = more codes actively used, a sign of thorough training
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
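The bounds are easy to verify (toy example, my own): if all vectors collapse onto one code the perplexity is 1; with perfectly uniform usage of K codes it is K, so the maximum here is num_embeddings = 512.
import torch

def perplexity(encodings):
    avg_probs = torch.mean(encodings, dim=0)
    return torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

collapsed = torch.tensor([[1., 0., 0., 0.]] * 8)  # only one code ever used
uniform = torch.eye(4).repeat(2, 1)               # all 4 codes used equally
print(perplexity(collapsed))  # ~1.0
print(perplexity(uniform))    # ~4.0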
The encoder output has num_hiddens (128) channels; _pre_vq_conv (a 1x1 Conv2d) remaps it to embedding_dim (64) channels, i.e.:
def forward(self, x):
    # print(f"[Info] x: {x.shape}")
    z = self._encoder(x)
    # print(f"[Info] z: {z.shape}")
    z = self._pre_vq_conv(z)  # adjust the channel dimension
    # print(f"[Info] z: {z.shape}")
    loss, quantized, perplexity, _ = self._vq_vae(z)
    # print(f"[Info] quantized: {quantized.shape}")
    x_recon = self._decoder(quantized)
    # print(f"[Info] x_recon: {x_recon.shape}")
    return loss, x_recon, perplexity
Feature shapes: z is the quantizer's inputs and must match the shape of quantized; the two stride-2 convolutions in the encoder reduce 32x32 images to 8x8 feature maps:
[Info] x: torch.Size([256, 3, 32, 32])
[Info] z: torch.Size([256, 128, 8, 8])
[Info] z: torch.Size([256, 64, 8, 8])
[Info] quantized: torch.Size([256, 64, 8, 8])
[Info] x_recon: torch.Size([256, 3, 32, 32])
EMA update trick (used by VectorQuantizerEMA to update the codebook in place of the codebook loss):
if self.training:
    self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                             (1 - self._decay) * torch.sum(encodings, 0)

    # Laplace smoothing of the cluster size
    n = torch.sum(self._ema_cluster_size.data)
    self._ema_cluster_size = (
            (self._ema_cluster_size + self._epsilon)
            / (n + self._num_embeddings * self._epsilon) * n)

    dw = torch.matmul(encodings.t(), flat_input)
    self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

    self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
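In formulas (Appendix A.1 of the paper), with decay $\gamma$, $n_i^{(t)}$ the number of encoder outputs assigned to code $i$ at step $t$, and $z_{i,j}^{(t)}$ those outputs:
$$N_i^{(t)} = \gamma N_i^{(t-1)} + (1-\gamma)\, n_i^{(t)}, \qquad m_i^{(t)} = \gamma m_i^{(t-1)} + (1-\gamma) \sum_j z_{i,j}^{(t)}, \qquad e_i^{(t)} = \frac{m_i^{(t)}}{N_i^{(t)}}$$
The Laplace smoothing of the cluster sizes in the code is an implementation detail on top of these updates; it keeps rarely used codes from causing a division by zero.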
As training progresses, the perplexity rises steadily from 1 to 423.421 at 15000 iterations (for a codebook of 512 entries the upper bound is 512).
Loss smoothing:
train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7)
train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 7)
Visualize the codebook distribution with UMAP:
import umap.umap_ as umap

proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu())
plt.scatter(proj[:, 0], proj[:, 1], alpha=0.3)
Complete source code:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Copyright (c) 2024. All rights reserved.
Created by C. L. Wang on 2024/1/30
"""
from __future__ import print_function
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import umap.umap_ as umap
from scipy.signal import savgol_filter
from six.moves import xrange
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
# --------------- Data --------------- #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Info] device: {device}")

training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))
                                 ]))
validation_data = datasets.CIFAR10(root="data", train=False, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))
                                   ]))

data_variance = np.var(training_data.data / 255.0)
print(f"[Info] data_variance: {data_variance}")
# --------------- Data --------------- #
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        # Initialize the codebook from a uniform distribution
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC (channel dimension last)
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input to a 2-D tensor (B*H*W, embedding_dim)
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate squared Euclidean distances (cosine distance would be an alternative)
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding: nearest-neighbor indices, then one-hot
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten: the one-hot matmul selects the embedding vectors
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # commitment loss
        q_latent_loss = F.mse_loss(quantized, inputs.detach())  # detach = stop gradient
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through trick: adding a detached constant keeps the encoder-decoder path differentiable
        quantized = inputs + (quantized - inputs).detach()

        # Perplexity (exponentiated entropy of codebook usage); higher means more codes in use
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)
class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens // 2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        x = self._conv_2(x)
        x = F.relu(x)
        x = self._conv_3(x)
        return self._residual_stack(x)
class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        # Pre-processing layer
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        # Feature extraction
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        # Transposed convolution (upsampling)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4,
                                                stride=2, padding=1)
        # Transposed convolution (upsampling)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)
        x = self._conv_trans_1(x)
        x = F.relu(x)
        return self._conv_trans_2(x)
class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                    (self._ema_cluster_size + self._epsilon)
                    / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
# --------------- Parameters --------------- #
batch_size = 256
num_training_updates = 15000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25
decay = 0.99
learning_rate = 1e-3
# --------------- Parameters --------------- #

training_loader = DataLoader(training_data,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=True)
validation_loader = DataLoader(validation_data,
                               batch_size=32,
                               shuffle=True,
                               pin_memory=True)
# --------------- Model --------------- #
class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0.0):
        super(Model, self).__init__()
        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        # 1x1 convolution: num_hiddens channels -> embedding_dim channels
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x):
        # print(f"[Info] x: {x.shape}")
        z = self._encoder(x)
        # print(f"[Info] z: {z.shape}")
        z = self._pre_vq_conv(z)  # adjust the channel dimension
        # print(f"[Info] z: {z.shape}")
        loss, quantized, perplexity, _ = self._vq_vae(z)
        # print(f"[Info] quantized: {quantized.shape}")
        x_recon = self._decoder(quantized)
        # print(f"[Info] x_recon: {x_recon.shape}")
        return loss, x_recon, perplexity


model = Model(num_hiddens, num_residual_layers, num_residual_hiddens,
              num_embeddings, embedding_dim,
              commitment_cost, decay).to(device)
# --------------- Model --------------- #
# --------------- Training --------------- #
model_path = "model.ckpt"
if not os.path.exists(model_path):
    model.train()
    train_res_recon_error = []
    train_res_perplexity = []
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)
    for i in xrange(num_training_updates):
        (data, _) = next(iter(training_loader))
        data = data.to(device)
        optimizer.zero_grad()

        vq_loss, data_recon, perplexity = model(data)
        recon_error = F.mse_loss(data_recon, data) / data_variance  # scale the reconstruction loss by the data variance
        loss = recon_error + vq_loss
        loss.backward()
        optimizer.step()

        train_res_recon_error.append(recon_error.item())
        train_res_perplexity.append(perplexity.item())

        if (i + 1) % 100 == 0:
            print('%d iterations' % (i + 1))
            print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
            print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
            print()
    torch.save(model.state_dict(), model_path)

    train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7)
    train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 7)

    f = plt.figure(figsize=(16, 8))
    ax = f.add_subplot(1, 2, 1)
    ax.plot(train_res_recon_error_smooth)
    ax.set_yscale('log')
    ax.set_title('Smoothed NMSE.')
    ax.set_xlabel('iteration')

    ax = f.add_subplot(1, 2, 2)
    ax.plot(train_res_perplexity_smooth)
    ax.set_title('Smoothed Average codebook usage (perplexity).')
    ax.set_xlabel('iteration')
else:
    model.load_state_dict(torch.load(model_path))
# --------------- Training --------------- #
# --------------- Reconstruction --------------- #
model.eval()

(valid_originals, _) = next(iter(validation_loader))
valid_originals = valid_originals.to(device)

# encode -> quantize -> decode
vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
_, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
valid_reconstructions = model._decoder(valid_quantize)

(train_originals, _) = next(iter(training_loader))
train_originals = train_originals.to(device)

vq_output_train = model._pre_vq_conv(model._encoder(train_originals))
_, train_quantize, _, _ = model._vq_vae(vq_output_train)
train_reconstructions = model._decoder(train_quantize)


def show(img):
    npimg = img.numpy()
    fig = plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)


show(make_grid(valid_reconstructions.cpu().data) + 0.5)
show(make_grid(valid_originals.cpu() + 0.5))
# --------------- Reconstruction --------------- #
# --------------- Codebook --------------- #
proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu())
plt.scatter(proj[:, 0], proj[:, 1], alpha=0.3)
# --------------- Codebook --------------- #
References:
- Source code: vq-vae
- CSDN - VQ-VAE Explained: Neural Discrete Representation Learning
- CSDN - scipy.signal.savgol_filter
- PyTorch - saving_loading_models
- StackOverflow - module 'umap' has no attribute 'UMAP'