经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

1 PixelCNN

  • PixelCNN是DeepMind团队在论文Pixel Recurrent Neural Networks (16.01)提出的一种生成模型,实际上这篇论文共提出了两种架构:PixelRNNPixelCNN,两者的主要区别是前者用LSTM来建模,而PixelCNN是基于CNN的,相比RNN,CNN计算更高效,我们这里只讨论PixelCNN。

  • PixelCNN借用了NLP里的方法来生成图像。对于自然图像,每个像素值的取值范围为0~255,共256个离散值。PixelCNN模型会根据前i - 1个像素输出第i个像素的概率分布。

  • 训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数

  • 采样时(图像生成时),会根据前i - 1个像素直接从预测的概率分布(多项分布)里采样出第i个像素。

1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大体思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只考虑单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

在这里插入图片描述

但是只输出一个像素的概率分布,这样训练效率太低了。

  • 在训练时,我们可以输入一幅图像,同时让模型输出图像每一点像素的概率分布(如下图所示),这样就能通过每个像素的真值和模型预测的概率分布求交叉熵损失函数,进行并行训练。
  • 我们能这么做的原因是:在训练时,整幅训练图像是已知的,因此我们可以在一次前向传播后得到图像每一处的概率分布。
  • 当然,我们需要找到每个像素都忽略后续像素的信息的方法,即论文中提出的掩码卷积机制,我们后面再讲。

在这里插入图片描述

但是在生成图像(采样)时,还是要一个像素一个像素的生成(如下所示)

  • 在采样时,我们会先根据前i - 1个像素输出第i个像素的概率分布。
  • 然后,我们会从第i个像素的概率分布中进行采样(如下面代码所示)
# 假设颜色取值范围为[0, 7],下面为概率分布
prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])

# 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
# torch.multinomial会从input这个概率分布中,取num_samples个值
pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])

在这里插入图片描述

我们现在已经知道了训练及采样的大体过程。但是,我们现在还是有一个疑问,如何保证训练时候,每个像素都忽略后续像素的信息?

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。

  • 具体来说,PixelCNN使用了两类掩码卷积:
    • 我们把两类掩码卷积分别称为「A类」和「B类」。
    • 二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。
    • A类和B类的唯一区别在于:卷积核的中心像素是否产生贡献
    • CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积

在这里插入图片描述

我们来分析下这样设计的优点:

  • 对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)
    • 经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。
    • 再经过两次B类掩码卷积后,中心像素能够看到左上角大部分像素的信息(如下图所示,我们发现还是会看漏少部分的信息,后面的Gated PixelCNN对此进行了改进)。
    • 这满足PixelCNN的约束。

在这里插入图片描述

  • 如果一直使用A类掩码卷积,每次卷积后中心像素都会看漏一些信息,最终就会导致看漏很多信息

在这里插入图片描述

  • 如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。

总结如下:

  • 逐像素预测只依赖于前面的像素,因此在选择卷积核时要进行掩码操作避免看到未来的值,因此,在第一层预测时可采用掩码卷积A
  • 由于CNN的逐像素预测是多层卷积,所以当第一层结束后,图像缺失部分已经有了预测值,因此在进行下一次/层卷积操作时可以利用当前像素的预测值,因此采用下列掩码卷积B
  • 需要注意的是,这里只考虑了单通道,如果扩展到RGB三个通道时,该如何进行mask呢?

1.1.2 PixelCNN的网络架构

  • 利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。
  • 我们可以用任意一种CNN架构来实现PixelCNN。
  • 下图红色框所示部分是PixelCNN的网络结构,其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

在这里插入图片描述

1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模型

实现PixelCNN,最重要的是实现掩码卷积。

  • 掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。
  • 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import os


class MaskConv2d(nn.Module):
    """
        掩码卷积的实现思路:
            在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积
    """
    def __init__(self, conv_type, *args, **kwags):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        # 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2] = 1
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        # 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作
        mask = mask.reshape((1, 1, H, W))
        # register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中
        # 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。
        # 第三个参数表示被注册的变量是否要加入state_dict中以保存下来
        self.register_buffer(name='mask', tensor=mask, persistent=False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来

在这里插入图片描述

  • 我们先实现残差块上图右部分的ResidualBlock,这里添加归一化
class ResidualBlock(nn.Module):
    """
        残差块ResidualBlock
    """
    def __init__(self, h, bn=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(2 * h, h, 1)
        self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
        self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
        self.conv3 = nn.Conv2d(h, 2 * h, 1)
        self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

    def forward(self, x):
        # 1、ReLU + 1×1 Conv + bn
        y = self.relu(x)
        y = self.conv1(y)
        y = self.bn1(y)
        # 2、ReLU + 3×3 Conv(mask B) + bn
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        # 3、ReLU + 1×1 Conv + bn
        y = self.relu(y)
        y = self.conv3(y)
        y = self.bn3(y)
        # 4、残差连接
        y = y + x
        return y
  • 有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。
  • 注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数color_level。
class PixelCNN(nn.Module):
    def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
        self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
        self.residual_blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.residual_blocks.append(ResidualBlock(h, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        # 1、7 × 7 conv(mask A)
        x = self.conv1(x)
        x = self.bn1(x)
        # 2、Multiple residual blocks
        for block in self.residual_blocks:
            x = block(x)
        x = self.relu(x)
        # 3、1 × 1 conv
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x
1.1.3.2 数据集及训练

准备好了模型代码,我们可以编写训练脚本了:

  • PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32
def get_dataloader(batch_size: int):
    dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
                                         train=True,
                                         transform=ToTensor()
                                         )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
    """训练过程"""
    dataloader = get_dataloader(batch_size)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    loss_fn = nn.CrossEntropyLoss()

    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            # 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签
            y = torch.ceil(x * (color_level - 1)).long()
            y = y.squeeze(1)
            predict_y = model(x)
            loss = loss_fn(predict_y, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), model_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')

if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 需要注意的是:MNIST数据集的大部分像素都是0和255
    color_level = 8  # or 256
    # 1、创建PixelCNN模型
    model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)
    # 2、模型训练
    model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'
    train(model, device, model_path)
    # 3、采样
    sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
1.1.3.3 采样
  • 在采样时,我们把x初始化成一个0张量。
  • 之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.
def sample(model, device, model_path, output_path, n_sample=1):
    """
        把x初始化成一个0张量。
        循环遍历每一个像素,输入x,把预测出的下一个像素填入x
    """
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    C, H, W = get_img_shape()  # (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                # 我们先获取模型的输出,再用softmax转换成概率分布
                output = model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值
                # 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]
                pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
                # 最后把新像素填入到生成图像中
                x[:, :, i, j] = pixel
    # 乘255变成一个用8位字节表示的图像
    imgs = x * 255
    imgs = imgs.clamp(0, 255)
    imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(output_path, imgs)

1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响

  • 其中红色预测不受蓝色和绿色通道的影响,只受上下文影响
  • 绿色红色通道和上下文影响,但不受蓝色通道影响;
  • 蓝色通道受上下文、红色通道、绿色通道影响

在这里插入图片描述

更具体地,我们规定一个子像素只由它之前的子像素决定,生成图像时,我们一个子像素一个子像素地生成

  • 如下图所示,对于RGB图像,R子像素由它之前所有像素决定
  • G子像素由它的R子像素和之前所有像素决定,
  • B子像素由它的R、G子像素和它之前所有像素决定。

在这里插入图片描述

如下图所示,由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量

  • 即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。
  • 同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

在这里插入图片描述

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  • 对于通道间的约束,我们要在o, i两个维度上设置掩码,如下图左边所示。
    • 设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3
      • o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:o
      • i1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i
      • 序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。
    • 我们对输入通道组和输出通道组之间进行约束。
    • 对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3
    • 对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3
  • 如下图右边所示,对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

在这里插入图片描述

  • 下面给出三维掩码示意图方便理解:

在这里插入图片描述

2 Gated PixelCNN

2.1 Gated PixelCNN简述

  • 可以参考大神讲解:Gated PixelCNN (sergeiturukin.com)

  • PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,中心像素看不到右上角三个本应该能看到的像素。

在这里插入图片描述

  • 为此,PixelCNN论文的作者又发表了Conditional Image Generation with PixelCNN Decoders(16.06)。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。

在这里插入图片描述

  • 如下图所示,Gated PixelCNN使用了两种卷积,即垂直卷积和水平卷积,来分别维护一个像素上侧的信息和左侧的信息
    • 垂直卷积的结果只是一些临时量
    • 而水平卷积的结果最终会被网络输出
    • 使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

在这里插入图片描述

  • Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
  • 模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。
  • 垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。
  • 两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。
  • 此外,模块右侧还有一个残差连接。

在这里插入图片描述

2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模型

  • 首先,实现垂直卷积和水平卷积
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import os


class VerticalMaskConv2d(nn.Module):
    """
        垂直卷积
    """
    def __init__(self, *args, **kwags):
        super().__init__()
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2 + 1] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class HorizontalMaskConv2d(nn.Module):
    """
        水平卷积
    """
    def __init__(self, conv_type, *args, **kwags):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res
# 垂直卷积
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 0., 0.]]]])
# A类水平卷积
tensor([[[[0., 0., 0.],
          [1., 0., 0.],
          [0., 0., 0.]]]])
# B类水平卷积
tensor([[[[0., 0., 0.],
          [1., 1., 0.],
          [0., 0., 0.]]]])
  • 我们现在搭建Gated Block模块,这也是最难理解的一部分。
  • 可以参考的解释:https://segmentfault.com/a/1190000041189859?utm_source=sf-similar-article

在这里插入图片描述

  • # 这里比较难理解,通过对图像进行零填充并裁剪图像底部,可以确保垂直和水平堆栈之间的因果关系
    v_to_h = v[:, :, 0:-1]
    v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
    # 注意到,v和i相加的位置只差了一个单位。
    # 为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。
    # 这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。
    

在这里插入图片描述

  • 维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。
    • v会经过一个垂直掩码卷积和一个门激活函数。
    • h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。
class GatedBlock(nn.Module):

    def __init__(self, conv_type, in_channels, p, bn=True):
        super().__init__()
        self.conv_type = conv_type
        self.p = p
        self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
                                           1)
        self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_output_conv = nn.Conv2d(p, p, 1)
        self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

    def forward(self, v_input, h_input):
        # v代表垂直卷积部分的结果
        v = self.v_conv(v_input)
        v = self.bn1(v)
        # Note: 重点代码
        # 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位
        # 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)
        v_to_h = v[:, :, 0:-1]
        v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
        # 和h相加前,先经过 1×1 conv
        v_to_h = self.v_to_h_conv(v_to_h)
        v_to_h = self.bn2(v_to_h)
        # 分为两份,经过tanh 和 sigmoid
        v1, v2 = v[:, :self.p], v[:, self.p:]
        v1 = torch.tanh(v1)
        v2 = torch.sigmoid(v2)
        v = v1 * v2

        # h代表水平卷积部分的结果
        h = self.h_conv(h_input)
        h = self.bn3(h)
        h = h + v_to_h
        # 分为两份,经过tanh 和 sigmoid
        h1, h2 = h[:, :self.p], h[:, self.p:]
        h1 = torch.tanh(h1)
        h2 = torch.sigmoid(h2)
        h = h1 * h2
        h = self.h_output_conv(h)
        h = self.bn4(h)
        # 在网络的第一层,每个数据是不能看到自己的。
        # 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。
        if self.conv_type == 'B':
            h = h + h_input
        return v, h
  • 最后,我们来用GatedBlock搭出Gated PixelCNN
  • Gated PixelCNN和PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。
class GatedPixelCNN(nn.Module):

    def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.block1 = GatedBlock('A', 1, p, bn)
        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(GatedBlock('B', p, p, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(p, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        v, h = self.block1(x, x)
        for block in self.blocks:
            v, h = block(v, h)
        x = self.relu(h)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x

2.2.2 数据集、训练及采样

  • 数据集、训练及采样和PixelCNN一模一样,不再赘述。
def get_dataloader(batch_size: int):
    dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
                                         train=True,
                                         transform=ToTensor()
                                         )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
    """训练过程"""
    dataloader = get_dataloader(batch_size)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    loss_fn = nn.CrossEntropyLoss()

    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            # 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的
            y = torch.ceil(x * (color_level - 1)).long()
            y = y.squeeze(1)
            predict_y = model(x)
            loss = loss_fn(predict_y, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), model_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')


def get_img_shape():
    return (1, 28, 28)


def sample(model, device, model_path, output_path, n_sample=1):
    """
        把x初始化成一个0张量。
        循环遍历每一个像素,输入x,把预测出的下一个像素填入x
    """
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    C, H, W = get_img_shape()  # (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                # 我们先获取模型的输出,再用softmax转换成概率分布
                output = model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值
                # 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)
                pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
                # 最后把新像素填入生成图像
                x[:, :, i, j] = pixel

    imgs = x * 255
    imgs = imgs.clamp(0, 255)
    imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(output_path, imgs)


if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    color_level = 8  # or 256
    # 1、创建GatedPixelCNN模型
    model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)
    # 2、模型训练
    model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'
    train(model, device, model_path, batch_size=1)
    # 3、采样
    sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/687900.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode算法】第110题:平衡二叉树

目录 一、题目描述 二、初次解答 三、官方解法 四、总结 一、题目描述 二、初次解答 1. 思路:从上而下访问二叉树的节点,递归判定当前节点的左子树和右子树的高度差是否为0、-1或1,从而判定其是否是平衡二叉树。 2. 代码: int…

【Web API DOM11】节点操作

一:DOM节点 1 什么是DOM节点 DOM树里每一个内容都称为节点 2 DOM节点分类 元素节点 属性节点:a标签的href、img标签的src等 文本节点:标签中的文字 上图为整个DOM树,每个标签、以及标签属性、文本内容构成了DOM树 二&#…

代码随想录算法训练营day43

题目:1049. 最后一块石头的重量 II 、494. 目标和、474.一和零 参考链接:代码随想录 1049. 最后一块石头的重量 II 思路:本题石头是相互粉碎,粉碎后剩下的重量就是两块石头之差,我们可以想到,把石头分成…

从零开始手把手Vue3+TypeScript+ElementPlus管理后台项目实战四(引入Axios,并调用第一个接口)

RealWorld接口综述 本项目调用的是RealWorld项目的开放接口。 接口文档如下: https://main--realworld-docs.netlify.app/docs/specs/backend-specs/endpoints https://main--realworld-docs.netlify.app/docs/specs/frontend-specs/swagger RealWorld 是一个适…

Day45 代码随想录打卡|二叉树篇---路径总和

题目(leecode T112): 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径,这条路径上所有节点值相加等于目标和 targetSum 。如果存在,返回 true ;…

LeetCode刷题之HOT100之最小路径和

2024/6/7 今天天气转晴,将栀子花移动到二楼阳台,愿它好!昨天准备做完这题再回去,太晚了感觉很疲惫,做不下去,今天早上来把它做了。 1、题目描述 2、逻辑分析 昨天上午做过一个跳格子的题目,也…

设计软件有哪些?效果工具篇(2),渲染100邀请码1a12

这次我们继续介绍一些渲染效果和后期处理的工具。 1、Krakatoa Krakatoa是由Thinkbox Software开发的强大的粒子渲染器,可用于Autodesk 3ds Max等软件。它专注于处理大规模粒子数据,提供了高效的渲染解决方案,适用于各种特效、粒子系统和模…

配音方面目前可以用AIGC替代吗?( 计育韬老师高校公益巡讲答疑实录2024)

这是计育韬老师第 8 次开展面向全国高校的新媒体技术公益巡讲活动了。而在每场讲座尾声,互动答疑环节往往反映了高校师生当前最普遍的运营困境,特此计老师在现场即兴答疑之外,会尽量选择有较高价值的提问进行文字答疑梳理。 *本轮巡讲主题除了…

李飞飞解读创业方向:「空间智能」

在AI领域,李飞飞教授一直是一个举足轻重的存在。她的研究和见解不仅推动了计算机视觉的发展,更对人工智能的未来方向产生了深远的影响。在最近的一次演讲中,李飞飞详细解读了她对于「空间智能」的见解。本文将对她的演讲内容进行详细解读&…

第一周:计算机网络概述(上)

一、计算机网络基本概念 1、计算机网络通信技术计算机技术 计算机网络就是一种特殊的通信网络,其特殊之处就在于它的信源和信宿就是计算机。 2、什么是计算机网络 在计算机网络中,我们把这些计算机统称为“主机”(上图中所有相连的电脑和服…

大学信息资源管理试题及答案,分享几个实用搜题和学习工具 #职场发展#微信

人工智能技术的发展正逐渐改变着我们的生活,学习如何运用这些技术将成为大学生的必备素养。 1.彩虹搜题 这是个微信公众号 算法持续优化,提升搜题效果。每一次搜索都更精准,答案更有价值。 下方附上一些测试的试题及答案 1、在SpringMVC配…

衰老过程中肠道菌群变化及其对老年抑郁和认知下降的影响

谷禾健康 编辑在老龄化过程中,生理功能逐渐衰退,伴随着多种疾病的发生,对老年人的身心健康构成重大威胁。 衰老是一个渐进、持续的过程,受到多种因素的影响,包括遗传、饮食、运动、生活方式等生理因素,也有…

【Linux】进程(8):Linux真正是如何调度的

大家好,我是苏貝,本篇博客带大家了解Linux进程(8):Linux真正是如何调度的,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 目录 之前我们讲过,在大…

IP地址相同,是否意味着身处同一地点?深入探究IP地址的奥秘

IP地址一样是不是证明在一个地方?在数字化时代的今天,网络已经深入到我们生活的每个角落。IP地址作为网络连接的标识符,其重要性不言而喻。然而,当我们遇到两个或多个设备拥有相同IP地址的情况时,很多人会自然地认为这…

智慧视觉怎么识别视频?智慧机器视觉是通过什么步骤识别视频的?

智慧视觉功能怎么识别视频?智慧视觉是搭载在智能设备比如手机、AI盒子、机器视觉系统上的一个应用程序或特性,采用计算机视觉和人工智能的技术来识别图像或视频中的内容。如果想了解视频识别,就要明白智慧视觉功能会涉及的以下几个关键步骤和…

Linux性能优化实战

Linux性能优化实战 33 | 关于 Linux 网络,你必须知道这些(上)如何提高系统并发?(8条)如何理解分布式?如何理解云计算?如何理解微服务?TCP/IP 网络栈如何分层?…

ctfshow-web入门-命令执行(web37-web40)

目录 1、web37 2、web38 3、web39 4、web40 命令执行&#xff0c;需要严格的过滤 1、web37 使用 php 伪协议&#xff1a; ?cphp://input post 写入我们希望执行的 php 代码&#xff1a; <?php system(tac f*);?> 拿到 flag&#xff1a;ctfshow{5c555d9a-6f55…

FactoryTalk View Site Edition的VBA基本应用

第一节 在VBA中标签的读取和写入 本例要达到的目标是通过FactoryTalk View Site Edition&#xff08;以下简称SE&#xff09;的VBA来访问PLC中的下位标签&#xff0c;并实现标签的读写。 1.准备工作 打开SE&#xff0c;选择应用程序类型&#xff08;本例是Site Edition Netwo…

NSSCTF-Web题目7

目录 [SWPUCTF 2022 新生赛]ez_rce 1、题目 2、知识点 3、思路 ​编辑 [MoeCTF 2022]baby_file 1、题目 2、知识点 3、思路 [SWPUCTF 2022 新生赛]ez_rce 1、题目 2、知识点 ThinkPHP V5 框架漏洞的利用&#xff0c;命令执行 由于ThinkPHP5在处理控制器传参时&#xff…

SpringBoot项目启动后访问网页显示“Please sign in“

SpringBoot启动类代码如下 SpringBoot项目启动后访问网页显示"Please sign in"&#xff0c;如图 这是一个安全拦截页面&#xff0c;即SpringSecurity认证授权页面&#xff0c;因为SecurityAutoConfiguration是Spring Boot提供的安全自动配置类&#xff0c;也就是说它…