论文:Glow: Generative Flow with Invertible 1x1 Convolutions
代码:pytorch版本:rosinality/glow-pytorch: PyTorch implementation of Glow (github.com)
正版是TensorFlow版本 openai的
参考csdn文章:Glow-pytorch复现github项目_pytorch glow-CSDN博客
(pytorch进阶之路)NormalizingFlow标准流_normalizing flow-CSDN博客
需要先看一下b站的Flow的讲解Flow-based Generative Model_哔哩哔哩_bilibili P59
本csdn文的目标:跑通代码+理解原理(不包含论文结果部分解读)
目录
1 引言
2 背景:
3 Generative Flow
Glow模块的整体代码:
Block模块:
Flow模块:
3.1 Actnorm: scale and bias layer with data dependent initialization
3.2 Invertible 1 1 convolution 可逆1*1卷积
3.3 Affine Coupling Layers 仿射耦合层
train部分
1 引言
基于flow模型改进,提出Glow
2 背景:
之前是基于flow的生成模型,我们的目标是从z(一个普通的分布)拟合到x(真实的分布),理解为从图A变为图B,而且要求这个过程是可逆的。
模型为G(x),目标最大化极大似然(最大似然理解为当参数为变量时,X=x的概率最大化):
也就是最后的这个。即最小化:
其中,flow的意思就是多个G连起来:
最终最大化下面这个,即:
其中,z的分布的选取一般为正态分布,均值为0函数G为双摄可逆函数,,可逆回去。
在计算方面,最后可以等于雅可比行列式的对角线。
3 Generative Flow
flow的每一步都由actnorm(3.1)、一个可逆的1x1卷积(3.2)和一个耦合层(3.3)组成。flow的深度为K,层数为L,下图。
Glow模块的整体代码:
class Glow(nn.Module):
def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True): #n_flow为K,n_block为L
super().__init__()
self.blocks = nn.ModuleList() #blocks层为图b的堆叠
n_channel = in_channel
for i in range(n_block - 1):
self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
n_channel *= 2 #最后一个Block通道*2
self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))
def forward(self, input):
log_p_sum = 0
logdet = 0
out = input
z_outs = [] #中间z
for block in self.blocks:
out, det, log_p, z_new = block(out) #循环 out
z_outs.append(z_new)
logdet = logdet + det #logdet求和
if log_p is not None:
log_p_sum = log_p_sum + log_p #log_p求和
return log_p_sum, logdet, z_outs # 输出log_p和logdet,以及最后的z序列
def reverse(self, z_list, reconstruct=False):
for i, block in enumerate(self.blocks[::-1]):#最后一个block去掉
if i == 0:
input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
else:
input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
return input
Block模块:
class Block(nn.Module):
def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
super().__init__()
squeeze_dim = in_channel * 4 #扩大4倍
self.flows = nn.ModuleList()
for i in range(n_flow): #内部Flow块,一共n_flow块
self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
self.split = split
if split:
self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
else:
self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)
def forward(self, input):
b_size, n_channel, height, width = input.shape
squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) #尺寸变小
squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) #[b,c,h,2,w,2]变成[b,c,2,2,h,w]
out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) #深拷贝重新创建out
logdet = 0
for flow in self.flows:
out, det = flow(out)
logdet = logdet + det
if self.split:
out, z_new = out.chunk(2, 1) #分块,dim=1分2块
mean, log_sd = self.prior(out).chunk(2, 1)
log_p = gaussian_log_p(z_new, mean, log_sd)
log_p = log_p.view(b_size, -1).sum(1)
else:
zero = torch.zeros_like(out)
mean, log_sd = self.prior(zero).chunk(2, 1)
log_p = gaussian_log_p(out, mean, log_sd)
log_p = log_p.view(b_size, -1).sum(1)
z_new = out
return out, logdet, log_p, z_new
def reverse(self, output, eps=None, reconstruct=False):
input = output
if reconstruct:
if self.split:
input = torch.cat([output, eps], 1)
else:
input = eps
else:
if self.split:
mean, log_sd = self.prior(input).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = torch.cat([output, z], 1)
else:
zero = torch.zeros_like(input)
# zero = F.pad(zero, [1, 1, 1, 1], value=1)
mean, log_sd = self.prior(zero).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = z
for flow in self.flows[::-1]:
input = flow.reverse(input)
b_size, n_channel, height, width = input.shape
unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)
return unsqueezed
Flow模块:
class Flow(nn.Module):
def __init__(self, in_channel, affine=True, conv_lu=True):
super().__init__()
self.actnorm = ActNorm(in_channel)
if conv_lu:
self.invconv = InvConv2dLU(in_channel)
else:
self.invconv = InvConv2d(in_channel)
self.coupling = AffineCoupling(in_channel, affine=affine)
def forward(self, input):
out, logdet = self.actnorm(input)
out, det1 = self.invconv(out)
out, det2 = self.coupling(out)
logdet = logdet + det1
if det2 is not None:
logdet = logdet + det2
return out, logdet
def reverse(self, output):
input = self.coupling.reverse(output)
input = self.invconv.reverse(input)
input = self.actnorm.reverse(input)
return input
3.1 Actnorm: scale and bias layer with data dependent initialization
之前提出批归一化来缓解训练深度模型时遇到的问题。然而,由于批处理归一化(batch normalization)所增加的激活噪声的方差与每个GPU或其他处理单元(PU)的小批(minibatch)大小成反比,因此已知每个PU的小批大小会降低性能。因此,minibatch=1. 我们提出了一个actnorm层(用于激活归一化),它使用每个通道的尺度和偏置参数执行激活的仿射变换,类似于批量归一化。这些参数被初始化,使得每个通道的事后激活具有零均值和给定初始小批量数据的单位方差。这是数据依赖初始化的一种形式(Salimans and Kingma 2016)。初始化后,尺度和偏差被视为独立于数据的常规可训练参数。(没怎么懂,看代码吧)
在Flow模块中的第一层就是ActNorm。这一步其实就是一个标准化,对于input(经过squeezed)【batch,12,32,32 】进行每个通道的标准化,用每个通道,例如3通道计算batch*h*w的均值,【1,12,1,1】,标准差也同样,然后进行(x-均值)/(标准差+1e-6) 标准化。因为要可逆,需要计算det,为系数的log求和,其实就是1/(标准差+1e-6)的log求和。
class ActNorm(nn.Module):
def __init__(self, in_channel, logdet=True):
super().__init__()
self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) #每个通道有一个值 初始为全0
self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) #初始scale为全1
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) #不被更新的参数
self.logdet = logdet #是否计算logdet
def initialize(self, input): #改变scale
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)#深度拷贝,[12, 64*32*32]
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)#上面把input分为12通道,每个通道包含64张的图像的一个通道数据,求均值,并转化为[1,12,1,1]
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)#类似的,求std标准差,并转化为[1,12,1,1]
self.loc.data.copy_(-mean)# loc为负的平均值
self.scale.data.copy_(1 / (std + 1e-6)) #scale为1 / (std + 1e-6)
def forward(self, input):#64, 12, 32, 32
_, _, height, width = input.shape
if self.initialized.item() == 0: #没操作,为0
self.initialize(input) #initialized为一个操作,根据input,对loc和scale的赋值
self.initialized.fill_(1) #操作完了,为1;哈哈哈哈要是我写的话,就是直接创建一个哨兵
log_abs = logabs(self.scale)#均值的绝对值的log
logdet = height * width * torch.sum(log_abs)#均值的logabs求和后乘以h*w det为系数log求和,一共h*w个点
if self.logdet:
return self.scale * (input + self.loc), logdet #对input每个点使用通道标准化 det为系数log求和
else:
return self.scale * (input + self.loc)
def reverse(self, output):
return output / self.scale - self.loc
3.2 Invertible 1 1 convolution 可逆1*1卷积
在Flow模块中的第二层,根据是否LU,选择是否带LU操作的1*1可逆卷积:
if conv_lu:
self.invconv = InvConv2dLU(in_channel)
else:
self.invconv = InvConv2d(in_channel)
class InvConv2dLU(nn.Module):
def __init__(self, in_channel):
super().__init__()
weight = np.random.randn(in_channel, in_channel)#[12,12]
q, _ = la.qr(weight) #qr分解,q为正交矩阵,r为上三角矩阵
w_p, w_l, w_u = la.lu(q.astype(np.float32))#对于正交矩阵q进行LU分解,p为置换矩阵,l为下三角,u为上三角,PA=LU,P就是把最大元素放在第一行
w_s = np.diag(w_u)#对角线
w_u = np.triu(w_u, 1) #去掉对角线,只保留上三角
u_mask = np.triu(np.ones_like(w_u), 1) #上三角单位阵,不包含对角线
l_mask = u_mask.T #下三角 不包含对角线
w_p = torch.from_numpy(w_p) #q置换矩阵p
w_l = torch.from_numpy(w_l) #q的下三角l
w_s = torch.from_numpy(w_s.copy()) #q的上三角u的对角线
w_u = torch.from_numpy(w_u) #q的上三角u的上三角
self.register_buffer("w_p", w_p)#p不更新
self.register_buffer("u_mask", torch.from_numpy(u_mask))
self.register_buffer("l_mask", torch.from_numpy(l_mask))
self.register_buffer("s_sign", torch.sign(w_s))
self.register_buffer("l_eye", torch.eye(l_mask.shape[0])) #对角线全1,其余全0
self.w_l = nn.Parameter(w_l) #更新的
self.w_s = nn.Parameter(logabs(w_s))
self.w_u = nn.Parameter(w_u)
def forward(self, input):
_, _, height, width = input.shape
weight = self.calc_weight()#[12,12,1,1] 这里就是1*1卷积了,12种12通道 对应下面的卷积操作
out = F.conv2d(input, weight) #输出通道数为卷积种类为12
logdet = height * width * torch.sum(self.w_s)
return out, logdet
def calc_weight(self):
weight = (
self.w_p
@ (self.w_l * self.l_mask + self.l_eye) #@为矩阵乘法
@ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))
)
return weight.unsqueeze(2).unsqueeze(3)
def reverse(self, output):
weight = self.calc_weight()#weight跟上面的weight是同一个 需要先训练上面的那个weight
return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
自己定义的权重W,(cxc),与输入的tensor h (h x w x c)之间进行卷积计算,因此,log_det的计算为:
但是,detW的计算复杂,为了简化计算复杂度,提出使用LU分解,把W参数化:
P为置换矩阵(不参与更新),L为下三角矩阵(对角线为0),U为上三角矩阵(对角线为0),diag(s)为分解时候的上三角矩阵plu的u的对角线,U仅仅只是u的对角线变为0,这样才符合plu分解,即,W=p*l*u。这样,log_det可以简化为:
对于较大的通道数c,可以大大节省。并且,除了P不参与更新外,L、U、s都参与更新。
也提供了不进行PLU分解的版本:
class InvConv2d(nn.Module):
def __init__(self, in_channel):
super().__init__()
weight = torch.randn(in_channel, in_channel)
q, _ = torch.qr(weight)
weight = q.unsqueeze(2).unsqueeze(3)
self.weight = nn.Parameter(weight)
def forward(self, input):
_, _, height, width = input.shape
out = F.conv2d(input, self.weight)
logdet = (
height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
)
return out, logdet
def reverse(self, output):
return F.conv2d(
output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)
)
3.3 Affine Coupling Layers 仿射耦合层
这一层在flow模块中的第三层
仿射耦合层是一种强大的可逆变换,其中正向函数、逆函数和对数行列式的计算效率很高。加性耦合层是s=1和log_det=0的特殊情况。
还是看代码吧。
Zero initialization:零初始化最后一个卷积。这样每个仿射耦合层最初执行一个恒等函数,这有助于训练非常深的网络。也就是说,网络一开始输入等于输出,因为F为0和H接近。
class ZeroConv2d(nn.Module):
def __init__(self, in_channel, out_channel, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)
self.conv.weight.data.zero_()
self.conv.bias.data.zero_()
self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # scale变成可以训练的 [1,12,1,1]
def forward(self, input):
out = F.pad(input, [1, 1, 1, 1], value=1) # 填充数值为1 从[64,512,32,32]变为[64,512,34,34]
out = self.conv(out) #通道数变回 从512变回12 初始输出全为0,因为权重为0
out = out * torch.exp(self.scale * 3) #0乘以1还是0
return out
class AffineCoupling(nn.Module):
def __init__(self, in_channel, filter_size=512, affine=True):
super().__init__()
self.affine = affine
self.net = nn.Sequential(
nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(filter_size, filter_size, 1),
nn.ReLU(inplace=True),
ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),#如果仿射,输出通道数为12,否则为6
)
self.net[0].weight.data.normal_(0, 0.05)#初始化权重,对于第一个Conv2d
self.net[0].bias.data.zero_()
self.net[2].weight.data.normal_(0, 0.05)#初始化权重,对于第二个Conv2d
self.net[2].bias.data.zero_()
def forward(self, input):
in_a, in_b = input.chunk(2, 1)#分块,对于dim=1,通道分为2块,这应该就是上下两块 [6,6]
if self.affine:
log_s, t = self.net(in_a).chunk(2, 1)#6 6 输出初始都为0
# s = torch.exp(log_s)
s = torch.sigmoid(log_s + 2) #图中的F
# out_a = s * in_a + t
out_b = (in_b + t) * s #生成下面那个 t为图中的H,有所不同的是计算顺序
logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
else: #不生成F
net_out = self.net(in_a) #图中的H 通道数为6
out_b = in_b + net_out #直接生成下面 通道数相同
logdet = None
return torch.cat([in_a, out_b], 1), logdet #上面的那块不变,
def reverse(self, output):
out_a, out_b = output.chunk(2, 1) #上面的out拆分,第一个其实没有变
if self.affine:
log_s, t = self.net(out_a).chunk(2, 1) #由于第一个没有变,生成的这两个块与上面是一样的
# s = torch.exp(log_s)
s = torch.sigmoid(log_s + 2)#生成的F与上面也是一样的
# in_a = (out_a - t) / s
in_b = out_b / s - t #先除以F后减t
else:
net_out = self.net(out_a) #由于第一个没有变 生成的F没有变
in_b = out_b - net_out #直接减掉就好
return torch.cat([out_a, in_b], 1)
代码实现部分,关于s的生成注释掉的部分与视频中讲解的一致,属于标准形式,后面用sigmoid生成openai代码中也是如此。
至此,Flow模块已经完成。论文方法部分也结束了。
在Flow模块外部还有squeezed操作,把图像切分为4块后,拼起来,通道变为12后再送入Flow块。后面还有一个split操作。这形成一个Block块。如果需要split操作,则输出的一半作为z,另一半作为out送到下游。看代码
首先,高斯分布的概率密度函数:
对于此概率密度函数取对数log,以e为底:注意下面的输入,log_sd是对标准差取对数,其中mean和log_sd都是可以训练的。
def gaussian_log_p(x, mean, log_sd):
return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)
class Block(nn.Module):
def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
super().__init__()
squeeze_dim = in_channel * 4 #扩大4倍
self.flows = nn.ModuleList()
for i in range(n_flow): #内部Flow块,一共n_flow块
self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
self.split = split
if split:#对于split,输入,输出的通道数不同
self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
else:
self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)
def forward(self, input):
b_size, n_channel, height, width = input.shape
squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) #尺寸变小
squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) #[b,c,h,2,w,2]变成[b,c,2,2,h,w]
out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) #深拷贝重新创建out [b, c*4, h//2, w//2]
logdet = 0
for flow in self.flows:
out, det = flow(out)
logdet = logdet + det
if self.split:#如果split的话,flow的out一半是z,另一半用来生成log_p指标,
out, z_new = out.chunk(2, 1) #通道分块,dim=1分2块 6,6
mean, log_sd = self.prior(out).chunk(2, 1) #6,6 mean, log_sd都是可学习的
log_p = gaussian_log_p(z_new, mean, log_sd) #z_new的分布为高斯分布的概率的log 这就是z是高斯分布的关键
log_p = log_p.view(b_size, -1).sum(1)#求和
else:
zero = torch.zeros_like(out)
mean, log_sd = self.prior(zero).chunk(2, 1)
log_p = gaussian_log_p(out, mean, log_sd)#out的分布为高斯分布的概率的log
log_p = log_p.view(b_size, -1).sum(1)
z_new = out
return out, logdet, log_p, z_new
def reverse(self, output, eps=None, reconstruct=False): #reverse的输入,如果是最后一层,output和eps都是z_list,其他层的话output为out,eps为z
input = output
if reconstruct: #是否重建
if self.split:
input = torch.cat([output, eps], 1) #如果split了,【out,z】
else:
input = eps #z
else: #如果不需要重建
if self.split:
mean, log_sd = self.prior(input).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = torch.cat([output, z], 1)
else:
zero = torch.zeros_like(input)
# zero = F.pad(zero, [1, 1, 1, 1], value=1)
mean, log_sd = self.prior(zero).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = z
for flow in self.flows[::-1]:
input = flow.reverse(input)
b_size, n_channel, height, width = input.shape
unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)
return unsqueezed
最后Glow模型:
class Glow(nn.Module):
def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True): #n_flow为K,n_block为L
super().__init__()
self.blocks = nn.ModuleList() #blocks层为图b的堆叠
n_channel = in_channel
for i in range(n_block - 1):
self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
n_channel *= 2 #最后一个Block通道*2
self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))
def forward(self, input):
log_p_sum = 0
logdet = 0
out = input
z_outs = [] #中间z
for block in self.blocks:
out, det, log_p, z_new = block(out) #循环 out
z_outs.append(z_new)
logdet = logdet + det #logdet求和
if log_p is not None:
log_p_sum = log_p_sum + log_p #log_p求和
return log_p_sum, logdet, z_outs # 输出log_p和logdet,以及最后的z序列
def reverse(self, z_list, reconstruct=False):
for i, block in enumerate(self.blocks[::-1]):#最后一个block去掉
if i == 0:
input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
else:
input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
return input
train部分
from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi
import argparse
import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
import torch.utils.data
from torchvision.datasets import CIFAR10
from torchvision import datasets, transforms, utils
from model import Glow
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser(description="Glow trainer")
parser.add_argument("--iter1", default=200000, type=int, help="maximum iterations") # 迭代周期
parser.add_argument("--n_flow", default=32, type=int, help="number of flows in each block")
parser.add_argument("--n_block", default=4, type=int, help="number of blocks")
parser.add_argument("--no_lu", action="store_true", help="use plain convolution instead of LU decomposed version")
parser.add_argument("--affine", action="store_true", help="use affine coupling instead of additive")
parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
parser.add_argument("--temp", default=0.7, type=float, help="temperature of sampling")
parser.add_argument("--n_sample", default=20, type=int, help="number of samples")
def data_tr_1(x):
x = x.resize((64, 64))
x = np.array(x, dtype='float32') / 255
x = (x - 0.5) / 0.5
x = x.transpose((2, 0, 1))
x = torch.from_numpy(x)
return x
def sample_data():
transform = transforms.Compose(
[
transforms.Resize(64),
transforms.CenterCrop(64),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
)
dataset = CIFAR10('./data', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
#test_set = CIFAR10('./data', train=False, transform=transform, download=True)
#test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
#dataset = datasets.ImageFolder(path, transform=transform)
#loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
loader = iter(loader)
while True:
try:
yield next(loader)
except StopIteration:
loader = DataLoader(
dataset, shuffle=True, batch_size=64, num_workers=4
)
loader = iter(loader)
yield next(loader)
def calc_z_shapes(n_channel, input_size, n_block):
'''
每一个block之后输出的z_shape
input:(3,64,64)
[(6, 32, 32), (12, 16, 16), (48, 8, 8)]
'''
z_shapes = []
for i in range(n_block - 1):
input_size //= 2 #size 两倍变小
n_channel *= 2 # 通道两倍变大
z_shapes.append((n_channel, input_size, input_size))
input_size //= 2
z_shapes.append((n_channel * 4, input_size, input_size))
return z_shapes
def calc_loss(log_p, logdet, image_size, n_bins):
# log_p = calc_log_p([z_list])
n_pixel = image_size * image_size * 3
loss = -log(n_bins) * n_pixel
loss = loss + logdet + log_p
return (
(-loss / (log(2) * n_pixel)).mean(),
(log_p / (log(2) * n_pixel)).mean(),
(logdet / (log(2) * n_pixel)).mean(),
)
def train(args, model, optimizer):
dataset = iter(sample_data())
n_bins = 2.0 ** args.n_bits # 10bit
z_sample = [] #中间初始值z?
z_shapes = calc_z_shapes(3, image_size, n_block)
for z in z_shapes:
z_new = torch.randn(n_sample, *z) * temp # n_sample为batch
z_sample.append(z_new.to(device)) #[-2, 3]左右
with tqdm(range(iter1)) as pbar:
for i in pbar:
image, _ = next(dataset)
image = image.to(device)
image = image * 255 # [0, 255]
if args.n_bits < 8: #5
image = torch.floor(image / 2 ** (8 - args.n_bits)) #[0,31]
image = image / n_bins - 0.5 #[-0.5, 2.6]
if i == 0:
with torch.no_grad():
log_p, logdet, _ = model.module(image + torch.rand_like(image) / n_bins)
continue
else:
log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins) #加噪声
logdet = logdet.mean()
loss, log_p, log_det = calc_loss(log_p, logdet, image_size, n_bins)
model.zero_grad()
loss.backward()
# warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
warmup_lr = args.lr
optimizer.param_groups[0]["lr"] = warmup_lr
optimizer.step()
pbar.set_description(
f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
)
if i % 100 == 0:
with torch.no_grad():
utils.save_image(
model_single.reverse(z_sample).cpu().data,
f"sample/{str(i + 1).zfill(6)}.png",
normalize=True,
nrow=10,
range=(-0.5, 0.5),
)
if i % 10000 == 0:
torch.save(
model.state_dict(), f"checkpoint/model_{str(i + 1).zfill(6)}.pt"
)
torch.save(
optimizer.state_dict(), f"checkpoint/optim_{str(i + 1).zfill(6)}.pt"
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
image_size = 64
n_flow = args.n_flow
n_block = args.n_block
n_sample = args.n_sample
temp = args.temp
iter1 = args.iter1
model_single = Glow(3, n_flow, n_block, affine=args.affine, conv_lu=not args.no_lu)
model = nn.DataParallel(model_single)
# model = model_single
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
train(args, model, optimizer)
数据集我选用的是cifar10.batch size设置为64,其余都是原本的默认值。loss为log_p与logdet相加后取负,也就是目标为最大化log_p, 使输出逐渐为高斯分布,logdet使得可逆后回去Image。
最后生成的结果
由于我的数据并没有分类存放,导致学习到的特征比较混乱,而且我也只是跑通代码理解原理而已。