使用Pytorch从零开始构建Conditional PixelCNN

条件 PixelCNN

PixelCNN 是 PixelRNN 的卷积版本,它将图像中的像素视为一个序列,并在看到前面的像素后预测每个像素(定义如上和左,尽管这是任意的)。PixelRNN 是图像联合先验分布的自回归模型:
p ( x ) = p ( x 0 ) ∏ p ( x i ∣ x 0 , ⋯   , x i − 1 ) p(x) = p(x_0 ) ∏ p(x_i | x_0, \cdots,x_{i-1} ) p(x)=p(x0)p(xix0,,xi1)
PixelRNN 的训练速度很慢,因为循环无法并行化——即使是小图像也有数百或数千个像素,这对于 RNN 来说是一个相对较长的序列。用掩码卷积替换循环,使卷积滤波器仅看到上方和左侧的像素,从而实现更快的训练(图来自条件 PixelCNN 论文)。
在这里插入图片描述

然而,值得注意的是,最初的 PixelCNN 实现产生的结果比 PixelRNN 更差。在后续论文(使用 PixelCNN 解码器生成条件图像)中推测,结果降级的一个可能原因是 PixelCNN 中的 ReLU 激活与 LSTM 中的门控连接相比相对简单。Conditional PixelCNN 论文随后用门控激活取代了 ReLU:
y = t a n h ( W f ∗ x ) • σ ( W g ∗ x ) y = tanh (W f * x) • σ(W g * x) y=tanh(Wfx)σ(Wgx)
后续论文中提供的另一个可能的原因是,堆叠掩模卷积滤波器会导致盲点,无法捕获预测像素之上的所有像素(论文中的图):
在这里插入图片描述

PixelCNN 与 GAN

PixelCNN 和 GAN 是目前用于生成图像的两种深度学习模型。GAN 最近受到了很多关注,但在很多方面我发现它们的流行是没有根据的。

目前尚不清楚 GAN 实际上试图优化什么目标,因为训练目标的最小值(即愚弄鉴别器)将导致生成器重新创建所有训练图像和/或生成不一定类似于自然图像的对抗性示例。这反映在训练 GAN 的众所周知的困难以及无数的对其进行正则化的技巧上。让两个网络相互对抗以产生训练信号的想法很有趣,并且已经产生了许多好的论文(尤其是 CycleGAN),但我仍然不相信它们除了在社交媒体上发布华丽的帖子之外还有其他用途。

另一方面,PixelCNN 有很好的概率基础。这使得它们不仅可以通过对分布进行采样(从左到右,从上到下,遵循自回归定义)来生成图像,而且还意味着它们可以用于其他任务。例如:作为预筛选网络来检测域外或对抗性示例;用于检测训练集中的异常值;或估计测试中的不确定性。我将在下一篇文章中详细介绍其中一些扩展。

我很想知道是否有人尝试过将 PixelCNN 和 GAN 结合起来。也许 PixelCNN 可以用作解码器的前级或最后阶段(以一些更高级别的学习表示为条件),以避免 GAN 的一些训练困难。

实现

我的实现使用门控块,但为了快速实现,我决定放弃针对盲点问题的双流解决方案(将滤波器分为水平和垂直组件)。有代码可用于解决 Tensorflow 中的盲点问题,并且在 PyTorch 中重写它相当简单。这样,掩蔽就很简单:当前像素下方和右侧的所有内容在滤波器中都被清零,并且在第一层中,当前像素也在滤波器中设置为零。

class MaskedConv(nn.Conv2d):
    def __init__(self,mask_type,in_channels,out_channels,kernel_size,stride=1):
        """
        mask_type: 'A' for first layer of network, 'B' for all others
        """
        super(MaskedConv,self).__init__(in_channels,out_channels,kernel_size,
                                        stride,padding=kernel_size//2)
        assert mask_type in ('A','B')
        mask = torch.ones(1,1,kernel_size,kernel_size)
        mask[:,:,kernel_size//2,kernel_size//2+(mask_type=='B'):] = 0
        mask[:,:,kernel_size//2+1:] = 0
        self.register_buffer('mask',mask)

    def forward(self,x):
        self.weight.data *= self.mask
        return super(MaskedConv,self).forward(x)

门控 ResNet 块的实现稍微复杂一些:PixelCNN 在网络的两半之间有快捷连接,就像 U-Net 一样;PyTorch 允许模块的前向方法仅在输入是变量时才接受多个输入;由于网络前半部分的特征图不是变量,因此它们必须与其他输入(前一层的特征)连接起来。使用条件向量可以避免这种情况,因为它是一个变量(在本例中为类标签)。

class GatedRes(nn.Module):
    def __init__(self,in_channels,out_channels,n_classes,kernel_size=3,stride=1,
                 aux_channels=0):
        super(GatedRes,self).__init__()
        self.conv = MaskedConv('B',in_channels,2*out_channels,kernel_size,
                               stride)
        self.y_embed = nn.Linear(n_classes,2*out_channels)
        self.out_channels = out_channels
        if aux_channels!=2*out_channels and aux_channels!=0:
            self.aux_shortcut = nn.Sequential(
                nn.Conv2d(aux_channels,2*out_channels,1),
                nn.BatchNorm2d(2*out_channels,momentum=0.1))
        if in_channels!=out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,1),
                nn.BatchNorm2d(out_channels,momentum=0.1))
        self.batchnorm = nn.BatchNorm2d(out_channels,momentum=0.1)

    def forward(self,x,y):
        # check for aux input from first half of net stacked into x
        if x.dim()==5:
            x,aux = torch.split(x,1,dim=0)
            x = torch.squeeze(x,0)
            aux = torch.squeeze(x,0)
        else:
            aux = None
        x1 = self.conv(x)
        y = torch.unsqueeze(torch.unsqueeze(self.y_embed(y),-1),-1)
        if aux is not None:
            if hasattr(self,'aux_shortcut'):
                aux = self.aux_shortcut(aux)
            x1 = (x1+aux)/2
        # split for gate (note: pytorch dims are [n,c,h,w])
        xf,xg = torch.split(x1,self.out_channels,dim=1)
        yf,yg = torch.split(y,self.out_channels,dim=1)
        f = torch.tanh(xf+yf)
        g = torch.sigmoid(xg+yg)
        if hasattr(self,'shortcut'):
            x = self.shortcut(x)
        return x+self.batchnorm(g*f)

我不确定在阅读原始论文时将批量归一化放在哪里,所以我将它放在我认为有意义的地方:在添加剩余连接之前。

实现这两个类后,整个网络就相对容易了。PyTorch 方案将所有内容定义为 的子类nn.Module,初始化所有层/操作/等。在构造函数中,然后在forward方法中将它们连接在一起可能会很混乱。如果您有大量快捷连接并且想要使用任意深度的循环对模型进行编码,则尤其如此。

注意:为了能够保存/恢复模型,您必须将图层存储在一个ModuleList而不是常规列表中。不过,附加和索引此列表在其他方面是相同的。

class PixelCNN(nn.Module):
    def __init__(self,in_channels,n_classes,n_features,n_layers,n_bins,
                 dropout=0.5):
        super(PixelCNN,self).__init__()

        self.layers = nn.ModuleList()
        self.n_layers = n_layers

        # Up pass
        self.input_batchnorm = nn.BatchNorm2d(in_channels,momentum=0.1)
        for l in range(n_layers):
            if l==0:  # start with normal conv
                block = nn.Sequential(
                    MaskedConv('A',in_channels+1,n_features,kernel_size=7),
                    nn.BatchNorm2d(n_features,momentum=0.1),
                    nn.ReLU())
            else:
                block = GatedRes(n_features, n_features, n_classes)
            self.layers.append(block)

        # Down pass
        for _ in range(n_layers):
            block = GatedRes(n_features, n_features,n_classes,
                             aux_channels=n_features)
            self.layers.append(block)

        # Last layer: project to n_bins (output is [-1, n_bins, h, w])
        self.layers.append(
            nn.Sequential(nn.Dropout2d(dropout),
                          nn.Conv2d(n_features,n_bins,1),
                          nn.LogSoftmax(dim=1)))

    def forward(self,x,y):
        # Add channel of ones so network can tell where padding is
        x = nn.functional.pad(x,(0,0,0,0,0,1,0,0),mode='constant',value=1)

        # Up pass
        features = []
        i = -1
        for _ in range(self.n_layers):
            i += 1
            if i>0:
                x = self.layers[i](x,y)
            else:
                x = self.layers[i](x)
            features.append(x)

        # Down pass
        for _ in range(self.n_layers):
            i += 1
            x = self.layers[i](torch.stack((x,features.pop())),y)

        # Last layer
        i += 1
        x = self.layers[i](x)
        assert i==len(self.layers)-1
        assert len(features)==0
        return x

MNIST 实际上是黑白的,因此我将标签离散为仅 4 个灰度级,以便计算交叉熵损失。在自然图像上,输出级别的数量显然需要更高。网络中的所有层都有 200 个特征。对于数据增强,我使用了 +/-5 度的随机旋转和最近邻采样。对于训练,我使用 Adam,学习率为 10 -4,dropout 率为 0.9。

更高的特征数量(比 MNIST 所需的特征数量更多)和更高的 dropout 是训练时间与正则化之间的权衡。这是一个在论文中很少提及的技巧,但有助于避免过度拟合——我只在一篇关于视频中动作识别训练的论文中看到过它,其中由于高维度与当前数据集大小,过度拟合是一个问题可用的。

我有一个 GTX1070 GPU,所以我没有运行任何类型的超参数优化:猜测合理的超参数并使模型工作的能力很大程度上说明了 Adam + 批量归一化 + dropout 的稳健性。学习率肯定可以更高,但这会产生更有趣的 GIF。

结果

在这里插入图片描述

上面的 gif 显示了整个训练过程中每个epochs后生成的一批 50 张图像(每类 5 个示例),从看似随机的涂鸦到类似于实际数字的东西。这是最佳epochs的结果:
在这里插入图片描述
这项工作的动机是看看条件 PixelCNN 是否也可以在类之间生成合理的示例。这是通过调节软标签而不是单热编码标签来完成的。

让我们尝试一下我所期望的容易混淆的数字对:(1,7), (3,8), (4,9), (5,6)
在这里插入图片描述
生成的类间示例并不像正常示例那样真实。模型可能需要一些额外的训练信号(例如来自分类器网络的教师强制)才能沿着图像流形进行插值。这有点令人失望,因为我曾希望生成类间示例可能允许使用学习的混合形式(而不是平均图像)。显然,进一步测试这个想法将需要更多的 GPU 来生成批量输入,所以无论如何,它目前超出了我的范围。

本文的完整代码可在Github代码库中查看。

本博文译自 jrbtaylor 的博客。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/186685.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

命令执行总结

之前做了一大堆的题目 都没有进行总结 现在来总结一下命令执行 我遇到的内容 这里我打算按照过滤进行总结 依据我做过的题目 过滤system 下面是一些常见的命令执行内容 system() passthru() exec() shell_exec() popen() proc_open() pcntl_exec() 反引号 同shell_exec() …

【从删库到跑路 | MySQL总结篇】数据库基础(增删改查的基本操作)

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【MySQL学习专栏】🎈 本专栏旨在分享学习MySQL的一点学习心得,欢迎大家在评论区讨论💌 重点放前面&am…

通过ros系统中websocket中发送sensor_msgs::Image数据给web端显示(二)

通过ros系统中websocket中发送sensor_msgs::Image数据给web端显示(二) mp4媒体流数据 #include <ros/ros.h> #include <signal.h> #include <sensor_msgs/Image.h> #include <message_filters/subscriber.h> #include <message_filters/synchroniz…

计算机组成原理。3-408

1.动态存储和静态存储 2.双端口RAM 注意&#xff1a;cpu通过地址线和数据线读写数据时&#xff0c;不能同时写&#xff0c;但可以同时读&#xff0c;也不能一边读一边写。 3.多体并行存储器 分为高位存储和低位存储 小结 4.磁盘存储器的组成 5.磁盘的性能指标 磁盘读写寻道…

ddns-go部署在linux虚拟机

ddns-go部署ubuntu1804 1.二进制部署 1.虚拟机部署 1.下载linux的x86二进制包 wget https://github.com/jeessy2/ddns-go/releases/download/v5.6.3/ddns-go_5.6.3_linux_x86_64.tar.gz2.解压 tar -xzf ddns-go_5.6.3_linux_x86_64.tar.gz3.拷贝执行文件到PATH下&#xff0c…

【Kotlin精简】第9章 Kotlin Flow

1 前言 上一章节我们学习了Kotlin的协程【Kotlin精简】第8章 协程&#xff0c;我们知道 协程实质是对线程切换的封装&#xff0c;能更加安全实现异步代码同步化&#xff0c;本质上协程、线程都是服务于并发场景下&#xff0c;其中协程是协作式任务&#xff0c;线程是抢占式任务…

【opencv】计算机视觉:实时目标追踪

目录 前言 解析 深入探究 前言 目标追踪技术对于民生、社会的发展以及国家军事能力的壮大都具有重要的意义。它不仅仅可以应用到体育赛事当中目标的捕捉&#xff0c;还可以应用到交通上&#xff0c;比如实时监测车辆是否超速等&#xff01;对于国家的军事也具有一定的意义&a…

Python武器库开发-前端篇之Html基础语法(二十九)

前端篇之Html基础语法(二十九) HTML 元素 HTML元素指的是HTML文档中的标签和内容。标签用于定义元素的类型&#xff0c;而内容则是元素所包含的内容。HTML元素由开始标签和结束标签组成&#xff0c;也可以是自闭合标签。 例如&#xff0c;下面是一个叫做<p>的HTML元素…

laravel8安装多应用多模块(笔记三)

先安装laravel8 Laravel 安装&#xff08;笔记一&#xff09;-CSDN博客 一、进入项目根目录安装 laravel-modules composer require nwidart/laravel-modules 二、 大于laravel5需配置provider&#xff0c;自动生成配置文件 php artisan vendor:publish --provider"Nwid…

Apollo接入配置中心 -- 源码分析之如何获取配置

全文参考&#xff1a;https://mp.weixin.qq.com/s/G5BV5BIdOtB3LlxNsr4ZDQ https://blog.csdn.net/crystonesc/article/details/106630412 https://www.cnblogs.com/deepSleeping/p/14565774.html 背景&#xff1a;近期在接入行内配置中心&#xff0c;因此对配置的加载接入有了…

p12 63.删除无头结点无头指针的循环链表中所有值为x的结点 桂林电子科技大学2015年 (c语言代码实现)注释详解

本题代码如下 void delete(linklist* L, int x) {lnode* p *L, * q *L;while (p->next ! q)// 从第一个结点开始遍历链表&#xff0c;直到尾结点的前一个结点{if (p->next->data x)//判断是否等于x{lnode* r p->next;//将r指向x的位置p->next r->next;…

JoyT的科研之旅第一周——科研工具学习及论文阅读收获

CiteSpace概述 CiteSpace 是一个用于可视化和分析科学文献的工具&#xff0c;它专门针对研究者进行文献回顾和趋势分析。CiteSpace 的核心功能是创建文献引用网络&#xff0c;这些网络揭示了研究领域内各个文献之间的相互关系。使用 CiteSpace 可以为论文研究做出贡献的几种方…

Linux常用命令——bind命令

在线Linux命令查询工具 bind 显示或设置键盘按键与其相关的功能 补充说明 bind命令用于显示和设置命令行的键盘序列绑定功能。通过这一命令&#xff0c;可以提高命令行中操作效率。您可以利用bind命令了解有哪些按键组合与其功能&#xff0c;也可以自行指定要用哪些按键组合…

【Proteus仿真】【STM32单片机】智能垃圾桶设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真STM32单片机控制器&#xff0c;使用报警模块、LCD1602液晶模块、按键模块、人体红外传感器、HCSR04超声波、有害气体传感器、SG90舵机等。 主要功能&#xff1a; 系统运行后&…

python基础-numpy

numpy中shape (1,X) 和 &#xff08;X&#xff0c;&#xff09;的区别 参考 首先放结论&#xff1a;shape(x,)是一维数组&#xff0c;ndim1,[1,2,3,…x] ;shape(1,x)是二维&#xff1f;数组&#xff0c;ndim2,[[1,2,3,…n]] >>> import numpy as np >>> a…

常见树种(贵州省):016杜鹃、含笑、桃金娘、金丝桃、珍珠花、观光木

摘要&#xff1a;本专栏树种介绍图片来源于PPBC中国植物图像库&#xff08;下附网址&#xff09;&#xff0c;本文整理仅做交流学习使用&#xff0c;同时便于查找&#xff0c;如有侵权请联系删除。 图片网址&#xff1a;PPBC中国植物图像库——最大的植物分类图片库 一、杜鹃 …

记录华为云服务器(Linux 可视化 宝塔面板)-- 安全组篇

文章目录 前言安全组说明安全组的特性安全组的应用场景 进入安全组添加基本规则添加自定义规则如有启发&#xff0c;可点赞收藏哟~ 前言 和windows防火墙类似&#xff0c;安全组是一种虚拟防火墙&#xff0c;具备状态检测和数据包过滤功能&#xff0c;可以对进出云服务器的流量…

5种主流API网关技术选型,yyds!

API网关是微服务项目的重要组成部分&#xff0c;今天来聊聊API网关的技术选型&#xff0c;有理论&#xff0c;有实战。 不 BB&#xff0c;上文章目录&#xff1a; 1 API网关基础 1.1 什么是API网关 API网关是一个服务器&#xff0c;是系统的唯一入口。 从面向对象设计的角度…

Linux加强篇002-部署Linux系统

目录 前言 1. shell语言 2. 执行命令的必备知识 3. 常用系统工作命令 4. 系统状态检测命令 5. 查找定位文件命令 6. 文本文件编辑命令 7. 文件目录管理命令 前言 悟已往之不谏&#xff0c;知来者之可追。实迷途其未远&#xff0c;觉今是而昨非。舟遥遥以轻飏&#xff…

测试用例的缝缝补补

&#x1f4d1;打牌 &#xff1a; da pai ge的个人主页 &#x1f324;️个人专栏 &#xff1a; da pai ge的博客专栏 ☁️山水速疾来去易&#xff0c;襄樊镇固永难开 ☁️定位页面的元素 参数:抽象类By里…