Glow模型【图解版加代码】

论文:Glow: Generative Flow with Invertible 1x1 Convolutions

代码:pytorch版本:rosinality/glow-pytorch: PyTorch implementation of Glow (github.com)

正版是TensorFlow版本 openai的

参考csdn文章:Glow-pytorch复现github项目_pytorch glow-CSDN博客
(pytorch进阶之路)NormalizingFlow标准流_normalizing flow-CSDN博客
需要先看一下b站的Flow的讲解Flow-based Generative Model_哔哩哔哩_bilibili P59

本csdn文的目标:跑通代码+理解原理(不包含论文结果部分解读)

目录

1 引言

2 背景:

3 Generative Flow

Glow模块的整体代码:

Block模块:

Flow模块: 

 3.1 Actnorm: scale and bias layer with data dependent initialization

3.2 Invertible 1 1 convolution 可逆1*1卷积

3.3 Affine Coupling Layers 仿射耦合层

train部分


1 引言

基于flow模型改进,提出Glow

2 背景:

之前是基于flow的生成模型,我们的目标是从z(一个普通的分布)拟合到x(真实的分布),理解为从图A变为图B,而且要求这个过程是可逆的。

模型为G(x),目标最大化极大似然(最大似然理解为当参数为变量时,X=x的概率最大化):


也就是最后的这个。即最小化:

其中,flow的意思就是多个G连起来:

最终最大化下面这个,即:


其中,z的分布的选取一般为正态分布,均值为0函数G为双摄可逆函数,,可逆回去。

在计算方面,最后可以等于雅可比行列式的对角线。

3 Generative Flow

flow的每一步都由actnorm(3.1)、一个可逆的1x1卷积(3.2)和一个耦合层(3.3)组成。flow的深度为K,层数为L,下图。

Glow模块的整体代码:

class Glow(nn.Module):
    def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True): #n_flow为K,n_block为L
        super().__init__()
        self.blocks = nn.ModuleList() #blocks层为图b的堆叠
        n_channel = in_channel
        for i in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2 #最后一个Block通道*2
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))

    def forward(self, input):
        log_p_sum = 0
        logdet = 0
        out = input
        z_outs = [] #中间z
        for block in self.blocks:
            out, det, log_p, z_new = block(out) #循环 out
            z_outs.append(z_new)
            logdet = logdet + det #logdet求和
            if log_p is not None:
                log_p_sum = log_p_sum + log_p #log_p求和
        return log_p_sum, logdet, z_outs # 输出log_p和logdet,以及最后的z序列

    def reverse(self, z_list, reconstruct=False):
        for i, block in enumerate(self.blocks[::-1]):#最后一个block去掉
            if i == 0:
                input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
            else:
                input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
        return input

Block模块:

class Block(nn.Module):
    def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
        super().__init__()
        squeeze_dim = in_channel * 4 #扩大4倍
        self.flows = nn.ModuleList()
        for i in range(n_flow): #内部Flow块,一共n_flow块
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
        self.split = split
        if split:
            self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
        else:
            self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)

    def forward(self, input):
        b_size, n_channel, height, width = input.shape
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) #尺寸变小
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) #[b,c,h,2,w,2]变成[b,c,2,2,h,w]
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) #深拷贝重新创建out
        logdet = 0
        for flow in self.flows:
            out, det = flow(out)
            logdet = logdet + det
        if self.split:
            out, z_new = out.chunk(2, 1) #分块,dim=1分2块
            mean, log_sd = self.prior(out).chunk(2, 1)
            log_p = gaussian_log_p(z_new, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
        else:
            zero = torch.zeros_like(out)
            mean, log_sd = self.prior(zero).chunk(2, 1)
            log_p = gaussian_log_p(out, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
            z_new = out
        return out, logdet, log_p, z_new

    def reverse(self, output, eps=None, reconstruct=False):
        input = output
        if reconstruct:
            if self.split:
                input = torch.cat([output, eps], 1)
            else:
                input = eps
        else:
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)
            else:
                zero = torch.zeros_like(input)
                # zero = F.pad(zero, [1, 1, 1, 1], value=1)
                mean, log_sd = self.prior(zero).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z
        for flow in self.flows[::-1]:
            input = flow.reverse(input)

        b_size, n_channel, height, width = input.shape
        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)

        return unsqueezed

Flow模块: 

class Flow(nn.Module):
    def __init__(self, in_channel, affine=True, conv_lu=True):
        super().__init__()
        self.actnorm = ActNorm(in_channel)
        if conv_lu:
            self.invconv = InvConv2dLU(in_channel)
        else:
            self.invconv = InvConv2d(in_channel)
        self.coupling = AffineCoupling(in_channel, affine=affine)

    def forward(self, input):
        out, logdet = self.actnorm(input)
        out, det1 = self.invconv(out)
        out, det2 = self.coupling(out)

        logdet = logdet + det1
        if det2 is not None:
            logdet = logdet + det2
            
        return out, logdet

    def reverse(self, output):
        input = self.coupling.reverse(output)
        input = self.invconv.reverse(input)
        input = self.actnorm.reverse(input)

        return input

 3.1 Actnorm: scale and bias layer with data dependent initialization

之前提出批归一化来缓解训练深度模型时遇到的问题。然而,由于批处理归一化(batch normalization)所增加的激活噪声的方差与每个GPU或其他处理单元(PU)的小批(minibatch)大小成反比,因此已知每个PU的小批大小会降低性能。因此,minibatch=1. 我们提出了一个actnorm层(用于激活归一化),它使用每个通道的尺度和偏置参数执行激活的仿射变换,类似于批量归一化。这些参数被初始化,使得每个通道的事后激活具有零均值和给定初始小批量数据的单位方差。这是数据依赖初始化的一种形式(Salimans and Kingma 2016)。初始化后,尺度和偏差被视为独立于数据的常规可训练参数。(没怎么懂,看代码吧)

在Flow模块中的第一层就是ActNorm。这一步其实就是一个标准化,对于input(经过squeezed)【batch,12,32,32 】进行每个通道的标准化,用每个通道,例如3通道计算batch*h*w的均值,【1,12,1,1】,标准差也同样,然后进行(x-均值)/(标准差+1e-6) 标准化。因为要可逆,需要计算det,为系数的log求和,其实就是1/(标准差+1e-6)的log求和。

class ActNorm(nn.Module):
    def __init__(self, in_channel, logdet=True):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) #每个通道有一个值 初始为全0
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) #初始scale为全1

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) #不被更新的参数
        self.logdet = logdet #是否计算logdet

    def initialize(self, input): #改变scale
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)#深度拷贝,[12, 64*32*32]
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )#上面把input分为12通道,每个通道包含64张的图像的一个通道数据,求均值,并转化为[1,12,1,1]
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )#类似的,求std标准差,并转化为[1,12,1,1]

            self.loc.data.copy_(-mean)# loc为负的平均值
            self.scale.data.copy_(1 / (std + 1e-6)) #scale为1 / (std + 1e-6)

    def forward(self, input):#64, 12, 32, 32
        _, _, height, width = input.shape

        if self.initialized.item() == 0: #没操作,为0
            self.initialize(input) #initialized为一个操作,根据input,对loc和scale的赋值
            self.initialized.fill_(1) #操作完了,为1;哈哈哈哈要是我写的话,就是直接创建一个哨兵

        log_abs = logabs(self.scale)#均值的绝对值的log
        logdet = height * width * torch.sum(log_abs)#均值的logabs求和后乘以h*w det为系数log求和,一共h*w个点
        if self.logdet:
            return self.scale * (input + self.loc), logdet #对input每个点使用通道标准化 det为系数log求和
        else:
            return self.scale * (input + self.loc)

    def reverse(self, output):
        return output / self.scale - self.loc

3.2 Invertible 1 1 convolution 可逆1*1卷积

在Flow模块中的第二层,根据是否LU,选择是否带LU操作的1*1可逆卷积:

        if conv_lu:
            self.invconv = InvConv2dLU(in_channel)
        else:
            self.invconv = InvConv2d(in_channel)
class InvConv2dLU(nn.Module):
    def __init__(self, in_channel):
        super().__init__()
        weight = np.random.randn(in_channel, in_channel)#[12,12]
        q, _ = la.qr(weight) #qr分解,q为正交矩阵,r为上三角矩阵
        w_p, w_l, w_u = la.lu(q.astype(np.float32))#对于正交矩阵q进行LU分解,p为置换矩阵,l为下三角,u为上三角,PA=LU,P就是把最大元素放在第一行
        w_s = np.diag(w_u)#对角线
        w_u = np.triu(w_u, 1) #去掉对角线,只保留上三角
        u_mask = np.triu(np.ones_like(w_u), 1) #上三角单位阵,不包含对角线
        l_mask = u_mask.T #下三角 不包含对角线

        w_p = torch.from_numpy(w_p) #q置换矩阵p
        w_l = torch.from_numpy(w_l) #q的下三角l
        w_s = torch.from_numpy(w_s.copy()) #q的上三角u的对角线
        w_u = torch.from_numpy(w_u) #q的上三角u的上三角

        self.register_buffer("w_p", w_p)#p不更新
        self.register_buffer("u_mask", torch.from_numpy(u_mask))
        self.register_buffer("l_mask", torch.from_numpy(l_mask))
        self.register_buffer("s_sign", torch.sign(w_s))
        self.register_buffer("l_eye", torch.eye(l_mask.shape[0])) #对角线全1,其余全0
        self.w_l = nn.Parameter(w_l) #更新的
        self.w_s = nn.Parameter(logabs(w_s))
        self.w_u = nn.Parameter(w_u)

    def forward(self, input):
        _, _, height, width = input.shape
        weight = self.calc_weight()#[12,12,1,1] 这里就是1*1卷积了,12种12通道 对应下面的卷积操作
        out = F.conv2d(input, weight) #输出通道数为卷积种类为12
        logdet = height * width * torch.sum(self.w_s)
        return out, logdet

    def calc_weight(self):
        weight = (
            self.w_p
            @ (self.w_l * self.l_mask + self.l_eye) #@为矩阵乘法
            @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))
        )
        return weight.unsqueeze(2).unsqueeze(3)

    def reverse(self, output):
        weight = self.calc_weight()#weight跟上面的weight是同一个 需要先训练上面的那个weight
        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))

自己定义的权重W,(cxc),与输入的tensor h (h x w x c)之间进行卷积计算,因此,log_det的计算为:

但是,detW的计算复杂,为了简化计算复杂度,提出使用LU分解,把W参数化:

P为置换矩阵(不参与更新),L为下三角矩阵(对角线为0),U为上三角矩阵(对角线为0),diag(s)为分解时候的上三角矩阵plu的u的对角线,U仅仅只是u的对角线变为0,这样才符合plu分解,即,W=p*l*u。这样,log_det可以简化为:

对于较大的通道数c,可以大大节省。并且,除了P不参与更新外,L、U、s都参与更新。

也提供了不进行PLU分解的版本:

class InvConv2d(nn.Module):
    def __init__(self, in_channel):
        super().__init__()

        weight = torch.randn(in_channel, in_channel)
        q, _ = torch.qr(weight)
        weight = q.unsqueeze(2).unsqueeze(3)
        self.weight = nn.Parameter(weight)

    def forward(self, input):
        _, _, height, width = input.shape

        out = F.conv2d(input, self.weight)
        logdet = (
            height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
        )

        return out, logdet

    def reverse(self, output):
        return F.conv2d(
            output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)
        )

3.3 Affine Coupling Layers 仿射耦合层

这一层在flow模块中的第三层

仿射耦合层是一种强大的可逆变换,其中正向函数、逆函数和对数行列式的计算效率很高。加性耦合层是s=1和log_det=0的特殊情况。

还是看代码吧。

Zero initialization:零初始化最后一个卷积。这样每个仿射耦合层最初执行一个恒等函数,这有助于训练非常深的网络。也就是说,网络一开始输入等于输出,因为F为0和H接近。

class ZeroConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()
        self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # scale变成可以训练的 [1,12,1,1]

    def forward(self, input):
        out = F.pad(input, [1, 1, 1, 1], value=1) # 填充数值为1 从[64,512,32,32]变为[64,512,34,34]
        out = self.conv(out) #通道数变回 从512变回12 初始输出全为0,因为权重为0
        out = out * torch.exp(self.scale * 3) #0乘以1还是0
        return out

class AffineCoupling(nn.Module):
    def __init__(self, in_channel, filter_size=512, affine=True):
        super().__init__()
        self.affine = affine
        self.net = nn.Sequential(
            nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_size, filter_size, 1),
            nn.ReLU(inplace=True),
            ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),#如果仿射,输出通道数为12,否则为6
        )
        self.net[0].weight.data.normal_(0, 0.05)#初始化权重,对于第一个Conv2d
        self.net[0].bias.data.zero_()

        self.net[2].weight.data.normal_(0, 0.05)#初始化权重,对于第二个Conv2d
        self.net[2].bias.data.zero_()

    def forward(self, input):
        in_a, in_b = input.chunk(2, 1)#分块,对于dim=1,通道分为2块,这应该就是上下两块 [6,6]

        if self.affine:
            log_s, t = self.net(in_a).chunk(2, 1)#6 6 输出初始都为0
            # s = torch.exp(log_s)
            s = torch.sigmoid(log_s + 2) #图中的F
            # out_a = s * in_a + t
            out_b = (in_b + t) * s #生成下面那个 t为图中的H,有所不同的是计算顺序

            logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)

        else: #不生成F
            net_out = self.net(in_a) #图中的H 通道数为6
            out_b = in_b + net_out #直接生成下面 通道数相同
            logdet = None

        return torch.cat([in_a, out_b], 1), logdet #上面的那块不变,

    def reverse(self, output):
        out_a, out_b = output.chunk(2, 1) #上面的out拆分,第一个其实没有变

        if self.affine:
            log_s, t = self.net(out_a).chunk(2, 1) #由于第一个没有变,生成的这两个块与上面是一样的
            # s = torch.exp(log_s)
            s = torch.sigmoid(log_s + 2)#生成的F与上面也是一样的
            # in_a = (out_a - t) / s
            in_b = out_b / s - t #先除以F后减t

        else:
            net_out = self.net(out_a) #由于第一个没有变 生成的F没有变
            in_b = out_b - net_out #直接减掉就好

        return torch.cat([out_a, in_b], 1)

代码实现部分,关于s的生成注释掉的部分与视频中讲解的一致,属于标准形式,后面用sigmoid生成openai代码中也是如此。

至此,Flow模块已经完成。论文方法部分也结束了。

在Flow模块外部还有squeezed操作,把图像切分为4块后,拼起来,通道变为12后再送入Flow块。后面还有一个split操作。这形成一个Block块。如果需要split操作,则输出的一半作为z,另一半作为out送到下游。看代码

首先,高斯分布的概率密度函数:

对于此概率密度函数取对数log,以e为底:注意下面的输入,log_sd是对标准差取对数,其中mean和log_sd都是可以训练的。 

def gaussian_log_p(x, mean, log_sd):
    return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)
class Block(nn.Module):
    def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
        super().__init__()
        squeeze_dim = in_channel * 4 #扩大4倍
        self.flows = nn.ModuleList()
        for i in range(n_flow): #内部Flow块,一共n_flow块
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
        self.split = split
        if split:#对于split,输入,输出的通道数不同
            self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
        else:
            self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)

    def forward(self, input):
        b_size, n_channel, height, width = input.shape
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) #尺寸变小
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) #[b,c,h,2,w,2]变成[b,c,2,2,h,w]
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) #深拷贝重新创建out [b, c*4, h//2, w//2]
        logdet = 0
        for flow in self.flows:
            out, det = flow(out)
            logdet = logdet + det
        if self.split:#如果split的话,flow的out一半是z,另一半用来生成log_p指标,
            out, z_new = out.chunk(2, 1) #通道分块,dim=1分2块 6,6
            mean, log_sd = self.prior(out).chunk(2, 1) #6,6 mean, log_sd都是可学习的
            log_p = gaussian_log_p(z_new, mean, log_sd) #z_new的分布为高斯分布的概率的log 这就是z是高斯分布的关键
            log_p = log_p.view(b_size, -1).sum(1)#求和
        else:
            zero = torch.zeros_like(out)
            mean, log_sd = self.prior(zero).chunk(2, 1)
            log_p = gaussian_log_p(out, mean, log_sd)#out的分布为高斯分布的概率的log
            log_p = log_p.view(b_size, -1).sum(1)
            z_new = out
        return out, logdet, log_p, z_new

    def reverse(self, output, eps=None, reconstruct=False): #reverse的输入,如果是最后一层,output和eps都是z_list,其他层的话output为out,eps为z
        input = output
        if reconstruct: #是否重建
            if self.split:
                input = torch.cat([output, eps], 1) #如果split了,【out,z】
            else:
                input = eps #z
        else: #如果不需要重建
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)
            else:
                zero = torch.zeros_like(input)
                # zero = F.pad(zero, [1, 1, 1, 1], value=1)
                mean, log_sd = self.prior(zero).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z
        for flow in self.flows[::-1]:
            input = flow.reverse(input)

        b_size, n_channel, height, width = input.shape
        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)
        return unsqueezed

最后Glow模型:

class Glow(nn.Module):
    def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True): #n_flow为K,n_block为L
        super().__init__()
        self.blocks = nn.ModuleList() #blocks层为图b的堆叠
        n_channel = in_channel
        for i in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2 #最后一个Block通道*2
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))

    def forward(self, input):
        log_p_sum = 0
        logdet = 0
        out = input
        z_outs = [] #中间z
        for block in self.blocks:
            out, det, log_p, z_new = block(out) #循环 out
            z_outs.append(z_new)
            logdet = logdet + det #logdet求和
            if log_p is not None:
                log_p_sum = log_p_sum + log_p #log_p求和
        return log_p_sum, logdet, z_outs # 输出log_p和logdet,以及最后的z序列

    def reverse(self, z_list, reconstruct=False):
        for i, block in enumerate(self.blocks[::-1]):#最后一个block去掉
            if i == 0:
                input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
            else:
                input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
        return input

train部分

from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi

import argparse

import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
import torch.utils.data
from torchvision.datasets import CIFAR10
from torchvision import datasets, transforms, utils

from model import Glow

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description="Glow trainer")
parser.add_argument("--iter1", default=200000, type=int, help="maximum iterations") # 迭代周期
parser.add_argument("--n_flow", default=32, type=int, help="number of flows in each block")
parser.add_argument("--n_block", default=4, type=int, help="number of blocks")
parser.add_argument("--no_lu", action="store_true", help="use plain convolution instead of LU decomposed version")
parser.add_argument("--affine", action="store_true", help="use affine coupling instead of additive")
parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
parser.add_argument("--temp", default=0.7, type=float, help="temperature of sampling")
parser.add_argument("--n_sample", default=20, type=int, help="number of samples")


def data_tr_1(x):
    x = x.resize((64, 64))
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5
    x = x.transpose((2, 0, 1))
    x = torch.from_numpy(x)
    return x


def sample_data():
    transform = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    dataset = CIFAR10('./data', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    #test_set = CIFAR10('./data', train=False, transform=transform, download=True)
    #test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
    #dataset = datasets.ImageFolder(path, transform=transform)
    #loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    loader = iter(loader)

    while True:
        try:
            yield next(loader)

        except StopIteration:
            loader = DataLoader(
                dataset, shuffle=True, batch_size=64, num_workers=4
            )
            loader = iter(loader)
            yield next(loader)


def calc_z_shapes(n_channel, input_size, n_block):
    '''
    每一个block之后输出的z_shape
    input:(3,64,64)
    [(6, 32, 32), (12, 16, 16), (48, 8, 8)]
    '''
    z_shapes = []
    for i in range(n_block - 1):
        input_size //= 2 #size 两倍变小
        n_channel *= 2 # 通道两倍变大
        z_shapes.append((n_channel, input_size, input_size))

    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))

    return z_shapes


def calc_loss(log_p, logdet, image_size, n_bins):
    # log_p = calc_log_p([z_list])
    n_pixel = image_size * image_size * 3

    loss = -log(n_bins) * n_pixel
    loss = loss + logdet + log_p

    return (
        (-loss / (log(2) * n_pixel)).mean(),
        (log_p / (log(2) * n_pixel)).mean(),
        (logdet / (log(2) * n_pixel)).mean(),
    )


def train(args, model, optimizer):
    dataset = iter(sample_data())
    n_bins = 2.0 ** args.n_bits # 10bit

    z_sample = [] #中间初始值z?
    z_shapes = calc_z_shapes(3, image_size, n_block)
    for z in z_shapes:
        z_new = torch.randn(n_sample, *z) * temp # n_sample为batch
        z_sample.append(z_new.to(device)) #[-2, 3]左右

    with tqdm(range(iter1)) as pbar:
        for i in pbar:
            image, _ = next(dataset)
            image = image.to(device)

            image = image * 255 # [0, 255]
            if args.n_bits < 8: #5
                image = torch.floor(image / 2 ** (8 - args.n_bits)) #[0,31]
            image = image / n_bins - 0.5 #[-0.5, 2.6]

            if i == 0:
                with torch.no_grad():
                    log_p, logdet, _ = model.module(image + torch.rand_like(image) / n_bins)
                    continue
            else:
                log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins) #加噪声
            logdet = logdet.mean()

            loss, log_p, log_det = calc_loss(log_p, logdet, image_size, n_bins)
            model.zero_grad()
            loss.backward()
            # warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
            warmup_lr = args.lr
            optimizer.param_groups[0]["lr"] = warmup_lr
            optimizer.step()

            pbar.set_description(
                f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
            )

            if i % 100 == 0:
                with torch.no_grad():
                    utils.save_image(
                        model_single.reverse(z_sample).cpu().data,
                        f"sample/{str(i + 1).zfill(6)}.png",
                        normalize=True,
                        nrow=10,
                        range=(-0.5, 0.5),
                    )

            if i % 10000 == 0:
                torch.save(
                    model.state_dict(), f"checkpoint/model_{str(i + 1).zfill(6)}.pt"
                )
                torch.save(
                    optimizer.state_dict(), f"checkpoint/optim_{str(i + 1).zfill(6)}.pt"
                )


if __name__ == "__main__":
    args = parser.parse_args()
    print(args)

    image_size = 64
    n_flow = args.n_flow
    n_block = args.n_block
    n_sample = args.n_sample
    temp = args.temp
    iter1 = args.iter1

    model_single = Glow(3, n_flow, n_block, affine=args.affine, conv_lu=not args.no_lu)
    model = nn.DataParallel(model_single)
    # model = model_single
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    train(args, model, optimizer)

数据集我选用的是cifar10.batch size设置为64,其余都是原本的默认值。loss为log_p与logdet相加后取负,也就是目标为最大化log_p, 使输出逐渐为高斯分布,logdet使得可逆后回去Image。

最后生成的结果

由于我的数据并没有分类存放,导致学习到的特征比较混乱,而且我也只是跑通代码理解原理而已。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/671901.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

spoon工具的常用基础操作

一些常用转换工具 1、emp表输入->excel表输出 emp表输入&#xff0c;可以进行预览查看数据有没有过来excel表输出 成功执行后&#xff0c;可以到保存的excel位置进行查看。 2、excel输入->表输出 运行转换后可以在oracle进行查看是否有成功创建这个表 3、对部门最高…

十_信号11 - 函数sigsetjmp() 和 siglongjmp()

也就是说&#xff0c;正常情况下&#xff0c;当捕捉到一个信号&#xff0c;并调用该信号的信号处理程序时&#xff0c;被捕捉的信号会被加入到当前进程的信号屏蔽字中&#xff0c;以防止在本次信号处理程序还没有完成的时候&#xff0c;再次触发该信号&#xff0c; 发生重入。 …

罕见!史诗级“大堵船”

新加坡港口的停泊延误时间已延长至7天&#xff0c;积压的集装箱数量达到惊人的450000标准箱&#xff0c;远超新冠疫情暴发时期的数轮高点。业内认为&#xff0c;近期东南亚恶劣的天气情况加剧了该区域港口拥堵。 5月31日&#xff0c;上海航运交易所&#xff08;下称“航交所”…

针对硅基氮化镓高电子迁移率晶体管(GaN-HEMT)的准物理等效电路模型,包含基板中射频漏电流的温度依赖性

来源&#xff1a;Quasi-Physical Equivalent Circuit Model of RF Leakage Current in Substrate Including Temperature Dependence for GaN-HEMT on Si&#xff08;TMTT 23年&#xff09; 摘要 该文章提出了一种针对硅基氮化镓高电子迁移率晶体管&#xff08;GaN-HEMT&…

【算法】理解堆排序

堆排序&#xff0c;无疑与堆这种数据结构有关。在了解堆排序之前&#xff0c;我们需要先了解堆的建立与维护方法。 堆 堆&#xff08;二插堆&#xff09;可以用一种近似的完全二叉树来表示&#xff0c;该二叉树除了叶子结点之外&#xff0c;其余节点均具有两个子女&#xff0c…

HCIP--RIP协议的实验 + RIP笔记

RIP实验&#xff1a; 实验思路&#xff1a; 1.规划IP&#xff0c;配置环回&#xff0c;接口IP 2.在3个路由器上跑通rip; 2.在边界路由器上用rip协议 设置缺省路由&#xff1b; [r3]rip [r3-rip-1]default-route originate 3.在r1、r2的主干接口上设置路由汇总 RIPV2手工汇…

MySQL数据库的约束

MySQL对于数据库存储的数据, 做出一些限制性要求, 就叫做数据库的"约束". 在每一列的 列名, 类型 后面加上"约束". 一. not null (非空) 指定某列不能存储null值. 二. unique (唯一) 保证这一列的每行必须有唯一值. 我们可以看到, 给 table 的 sn 列插…

Ubuntu系统配置DDNS-GO【笔记】

DDNS-GO 是一个基于 Go 语言的动态 DNS (DDNS) 客户端&#xff0c;用于自动更新你的 IP 地址到 DNS 记录上。这对于经常变更 IP 地址的用户&#xff08;如使用动态 IP 的家庭用户或者小型服务器&#xff09;非常有用。 此文档实验环境为&#xff1a;ubuntu20.04.6。 在Ubuntu…

基于Django的博客系统之登录增加忘记密码(八)

需求 描述&#xff1a; 用户忘记密码时&#xff0c;提供一种重置密码的方法&#xff0c;以便重新获得账户访问权限。规划&#xff1a; 创建一个包含邮箱输入字段的表单&#xff0c;用于接收用户的重置密码请求。用户输入注册时使用的邮箱地址&#xff0c;系统发送包含重置密码…

量产导入 | 芯片测试介绍可靠性测试

作者:桃芯科技链接:https://picture.iczhiku.com/weixin/message1583129221975.html半导体芯片的defects、Faults 芯片在制造过程中,会出现很多种不同类型的defects,比如栅氧层针孔、扩散工艺造成的各种桥接、各种预期外的高阻态、寄生电容电阻造成的延迟等等,如下面图(1)…

Spring高手之路19——Spring AOP注解指南

文章目录 1. 背景2. 基于AspectJ注解来实现AOP3. XML实现和注解实现AOP的代码对比4. AOP通知讲解5. AOP时序图 1. 背景 在现代软件开发中&#xff0c;面向切面编程&#xff08;AOP&#xff09;是一种强大的编程范式&#xff0c;允许开发者跨越应用程序的多个部分定义横切关注点…

数据隐私重塑:Web3时代的隐私保护创新

随着数字化时代的不断深入&#xff0c;数据隐私保护已经成为了人们越来越关注的焦点之一。而在这个数字化时代的新篇章中&#xff0c;Web3技术作为下一代互联网的代表&#xff0c;正在为数据隐私保护带来全新的创新和可能性。本文将深入探讨数据隐私的重要性&#xff0c;Web3时…

解锁数据宝藏:高效查找算法揭秘

代码下载链接&#xff1a;https://gitee.com/flying-wolf-loves-learning/data-structure.git 目录 一、查找的原理 1.1 查找概念 1.2 查找方法 1.3平均查找长度 1.4顺序表的查找 1.5 顺序表的查找算法及分析 1.6 折半查找算法及分析 1.7 分块查找算法及分析 1.8 总结…

很多人讲不明白HTTPS,但是我能

很多人讲不明白HTTPS&#xff0c;但是我能 今天我们用问答的形式&#xff0c;来彻底弄明白HTTPS的过程 下面的问题都是 小明和小丽两个人通信为例 可以把小明想象成服务端&#xff0c;小丽想象成客户端 1. https是做什么用的&#xff1f; 答&#xff1a;数据安全传输用的。…

数学建模 —— 聚类分析(3)

目录 一、聚类分析概述 1.1 常用聚类要素的数据处理 1.1.1 总和标准化 1.1.2 标准差标准化 1.1.3 极大值标准化 1.1.4 极差的标准化 1.2 分类 1.2.1 快速聚类法&#xff08;K-均值聚类&#xff09; 1.2.2 系统聚类法&#xff08;分层聚类法&#xff09; 二、分类统计…

Ubuntu18.04安装pwntools报错解决方案

报错1&#xff1a;ModuleNotFoundError: No module named ‘setuptools_rust’ 报错信息显示ModuleNotFoundError: No module named setuptools_rust&#xff0c;如下图所示 解决方案&#xff1a;pip install setuptools_rust 报错2&#xff1a;pip版本低 解决方案&#xff…

【数据结构(邓俊辉)学习笔记】图02——搜索

文章目录 0. 概述1. 广度优先搜索1.1 策略1.2 实现1.3 可能情况1.4 实例1.5 多联通1.6 复杂度1.7 最短路径 2. 深度优先搜索2.1 算法2.2 框架2.3 细节2.4 无向边2.5 有向边2.6 多可达域2.7 嵌套引理 3 遍历算法的应用 0. 概述 此前已经介绍过图的基本概念以及它在计算机中的表…

设计模式(十四)行为型模式---访问者模式(visitor)

文章目录 访问者模式简介分派的分类什么是双分派&#xff1f;结构UML图具体实现UML图代码实现 优缺点 访问者模式简介 访问者模式&#xff08;visitor pattern&#xff09;是封装一些作用于某种数据结构中的元素的操作&#xff0c;它可以在不改变这个数据结构&#xff08;实现…

Visual Studio Installer 点击闪退

Visual Studio Installer 点击闪退问题 1. 问题描述2. 错误类型3. 解决方法4. 结果5. 说明6. 参考 1. 问题描述 重装了系统后&#xff08;系统版本&#xff1a;如下图所示&#xff09;&#xff0c;我从官方网站&#xff08;https://visualstudio.microsoft.com/ ) 下载了安装程…

Three.js-实现加载图片并旋转

1.实现效果 2. 实现步骤 2.1创建场景 const scene new THREE.Scene(); 2.2添加相机 说明&#xff1a; fov&#xff08;视场角&#xff09;&#xff1a;视场角决定了相机的视野范围&#xff0c;即相机可以看到的角度范围。较大的视场角表示更广阔的视野&#xff0c;但可能…