GPT与GAN结合生成图像——VQGAN原理解析

1、前言

这篇文章,我们讲VQ_GAN,这是一个将特征向量离散化的模型,其效果相当不错,搭配Transformer(GPT)或者CLIP使用,达到的效果在当时可谓是令人拍案叫绝!

原论文:Taming Transformers for High-Resolution Image Synthesis (arxiv.org)

参考代码:dome272/VQGAN-pytorch: Pytorch implementation of VQGAN

视频:[GPT与GAN结合生成图像——VQGAN原理解析-哔哩哔哩]

效果演示:

图像生成
在这里插入图片描述

其他任务

在这里插入图片描述

2、VQVAE

VQGAN其实是VQVAE修改过来的,是VQVAE先对VAE中的编码向量离散化。而后,VQGAN就是在VQVAE的基础上进行了一些修改,以提高其生成效果

由于这篇文章讲的是VQGAN,所以不会涉及VQVAE里面的公式推导,我们就直观的理解就行了,后续我看看是否需要补一个VQVAE

3、VQGAN

论文里面提到,VQGAN的出现的动机是针对transformer,由于transformer在NLP(自然语言处理)取得了令人惊讶的效果。所以,就有很多人尝试,是否可以将transformer应用在图像处理领域

在这篇论文之前,已经有人进行尝试,transformer可以应用在图像领域,并且取得了相当不错的效果。然而,相对于NLP,图像处理的难度却比较大,在transformer中,一句话的长度往往不会很长,里面的自注意力机制的计算量仍然可以接收;可图像处理领域的每个像素如果都要做自注意力的话,在低像素的或许还可以接收,但是一旦到了高像素,其计算量往往令人望而生畏。

受VQVAE的启发,作者先把图像通过编码器,编码成维度较低的向量,从而减少自注意力机制的计算量。并且,会把编码后的向量离散化。作者认为,在自然界中,图像本身应该是由一个个离散的量组合而成的,就好比东一个西一个,就拼成了车。

4、VQGAN流程

在这里插入图片描述

首先,从左下角开始,有一张狗的照片(红框),把这张图送给一个卷积编码器( E E E),输出向量 z ^ \hat z z^

接着,初始化一个码本(Codebook Z ∈ R ( n u m , d i m ) Z\in R^{(num,dim)} ZR(num,dim),num是码本有多少行,dim是每行多少维度),把向量 z ^ \hat z z^在像素层面上,都在码本中找到与它最像的一个向量(使用最近邻搜索)。得到 z q z_q zq(图中像素上面的数字代表码本对应位置向量)

把得到的 z q z_q zq,送给解码器G,恢复图像,然后把这张还原的图像和生成的图像,送给卷积判别器D,判断真伪。

这就是整个流程。

我们看图中的码本,码本中对应的向量,分别表示图中那只狗某一块的特征,这种就是特征的离散化,能够让特征充分解耦。

5、VQVAE的损失

VQGAN的目标,就是学习到一个足够好的码本,编码器和解码器。

在讲VQGAN之前,我们先来看VQVAE。

5.1、VQVAE重构损失

这是VQVAE的模型图(与VQGAN相比,少了判别网络D)

在这里插入图片描述

如果你知道VAE或者AE,就应该知道,我们要让编码后再解码得到的图像和原始图像很像,那就说明这两个编码和解码器足够好。所以,我们要让重构的损失最小。即
L r e c = ∣ ∣ x − x ^ ∣ ∣ 2 = ∣ ∣ x − G ( z q ) ∣ ∣ 2 L_{rec} = ||x-\hat x||^2=||x-G(z_q)||^2 Lrec=∣∣xx^2=∣∣xG(zq)2
x ^ \hat x x^表示重构出来的图像, G G G是解码器。

这是一种非常朴素的想法,但是,这里有个问题,那就是里面的 z q z_q zq z ^ \hat z z^在码本中最近邻搜索弄出来,这种最近邻匹配的方法是没有办法把梯度传递会编码器E那边的。于是,作者提出了straight-through estimator,具体做法如下,我们令
z q = z ^ + s g ( z q − z ^ ) (1) z_q = \hat z+ sg(z_q-\hat z)\tag{1} zq=z^+sg(zqz^)(1)
其中,里面的sg就是停止梯度的意思,也就是当反向传播的时候,括号里面那一项梯度不计。

于是,便有
s g = { s g = 1 ; 正向传播 s g = 0 ; 反向传播 sg=\left\{\begin{matrix}sg = 1;正向传播\\sg=0;反向传播\end{matrix}\right. sg={sg=1;正向传播sg=0;反向传播
当正向传播,把 s g = 1 sg=1 sg=1代入式(1),等式成立;反向传播的时候, s g = 0 sg=0 sg=0,会导致直接传梯度到 z ^ \hat z z^

也就是说,当正向传播时,有损失
L r e c = ∣ ∣ x − G ( z ^ + s g ( z q − z ^ ) ) ∣ ∣ 2 = ∣ ∣ x − G ( z q ) ∣ ∣ 2 L_{rec}=||x-G(\hat z+ sg(z_q-\hat z))||^2=||x-G(z_q)||^2 Lrec=∣∣xG(z^+sg(zqz^))2=∣∣xG(zq)2
反向传播时,有
L r e c = ∣ ∣ x − G ( z ^ + s g ( z q − z ^ ) ) ∣ ∣ 2 = ∣ ∣ x − G ( z ^ ) ∣ ∣ 2 L_{rec}=||x-G(\hat z+ sg(z_q-\hat z))||^2=||x-G(\hat z)||^2 Lrec=∣∣xG(z^+sg(zqz^))2=∣∣xG(z^)2
或许你会想,为什么可以这样做,这样做真的可以收敛吗?是可以的!

试想一下,当 z ^ \hat z z^通过与码本中找到最相近的向量替代原来的向量,得到 z q z_q zq,换句话说, z ^ \hat z z^ z q z_q zq是近似的,那么其更新方向也是近似相等的。

5.2、码本损失

我们要构造一个足够好的码本,去表示图像的离散特征。而我们知道 z ^ \hat z z^是编码器编码图像得到的特征,那么理所应当的,我们只需要让
L c o d e = z i ∈ Z ∣ ∣ E ( x ) − z q ∣ ∣ 2 2 L_{code}=_{z_i\in Z}||E(x)-z_q||_2^2 Lcode=ziZ∣∣E(x)zq22
z q z_q zq是像素点,在码本的对应最近邻向量。

作者认为,编码器 E E E和码本向量不应该以一样的速率优化,码本的是要学习把自己的向量与编码器的向量尽量的接近,码本的学习速率必须要快于编码器,否则码本自己优化,而不是向着编码器的方向优化。

所以将其拆分成两项
L c o d e = ∣ ∣ s g ( E ( x ) ) − z q ∣ ∣ 2 2 + β ∣ ∣ E ( x ) − s g ( z q ) ∣ ∣ 2 2 L_{code}=||sg(E(x))-z_q||_2^2+\beta ||E(x)-sg(z_q)||_2^2 Lcode=∣∣sg(E(x))zq22+β∣∣E(x)sg(zq)22
β \beta β是学习速率。取值 0.1 0.1 0.1 2.0 2.0 2.0之间,但是作者经过实验发现, β \beta β的取值对结果的影响很小,几乎没有。在VQVAE中, β = 0.25 \beta=0.25 β=0.25

5.3、总损失

故而,我们得到VQVAE的总损失函数
L V Q = L r e c + L c o d e \mathcal{L}_{VQ}=L_{rec}+L_{code} LVQ=Lrec+Lcode

6、VQGAN损失

在这里插入图片描述

6.1、感知损失

与VQVAE相比,VQGAN的作者首先把里面的重构损失 L r e c L_{rec} Lrec换成感知损失(perceptual loss)

所谓的感知损失,在一般请看下,就是把真实的图像,和解码器复原的图像,一起送给一个神经网络,比如VGG16,把这两张图像经过VGG16,都编码成特征向量,然后计算特征向量的差别,比如
L p e r = ∣ ∣ V G G ( x ) − V G G ( x ^ ) ∣ ∣ 2 (2) L_{per}=||VGG(x)-VGG(\hat x)||_2\tag{2} Lper=∣∣VGG(x)VGG(x^)2(2)
这只是举个例子,在文章中VQGAN的代码中,比这个复杂一点,它是在很多层都进行都去计算式(2)。

另外,值得注意的是,虽然论文里面写的是把重构损失换成感知损失,但是在本文上面的代码中,其实两种损失都用到了。我个人觉得也没什么不妥的,很显然重构损失是在图像层面的差异,而感知损失是特征向量的差异,所以两者加起来应当不会有什么问题。

6.2、判别网络的损失

VQGAN比VQVAE多了一个判别网络,故而加上一个判别网络的损失,以优化参数让解码器G生成的图像更好。公式如下(这是GAN的基本公式,在此不过多赘述)
L G A N ( { E , G , Z } , D ) = [ log ⁡ D ( x ) + log ⁡ ( 1 − D ( x ^ ) ) ] \mathcal{L}_{GAN}(\{E,G,Z\},D)=[\log D(x)+\log(1-D(\hat x))] LGAN({E,G,Z},D)=[logD(x)+log(1D(x^))]
因此,最终的损失函数如下
L = min ⁡ E , G , Z max ⁡ D E x ∼ p ( x ) [ L V Q ( E , G , Z ) + λ L G A N ( { E , G , Z } , D ) ] L=\min\limits_{E,G,Z}\max\limits_{D}\mathbb{E}_{x\sim p(x)}\left[\mathcal{L}_{VQ}(E,G,Z)+\lambda\mathcal{L}_{GAN}(\{E,G,Z\},D)\right] L=E,G,ZminDmaxExp(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)]
其中, λ \lambda λ是动态变化的,其公式如下
λ = ∇ G L [ L r e c ] ∇ G L [ L G A N ] + δ \lambda = \frac{\nabla_{G_L}[\mathcal{L_{rec}}]}{\nabla_{G_L}[\mathcal{L}_{GAN}]+\delta} λ=GL[LGAN]+δGL[Lrec]
论文里面, δ = 1 0 − 6 \delta=10^{-6} δ=106 ∇ G L \nabla_{G_L} GL是关于解码器最后一层求梯度。

7、GPT及图像生成

在VQGAN里面,当训练好之后,就会得到一个训练好的编码器,解码器,以及码本。

可是,我们该如何生成图像呢?就是依靠transformer,换句话中,作者在实验的时候,其实用的是GPT2

以下为具体流程(以单张图像为例):

首先,从训练图像中,采样出一张图像。送给编码器,得到编码向量,并按像素,寻找在码本中的最近邻。但是,得到的最近邻我们不要它的向量值,只要对应的索引。

于是,我们得到的就是一行索引。比如indexs=【1,5,9,3,5,1,10,20】。

接着,只需要按照GPT的训练步骤,随机掩掉一部分值,比如indexs_mask=【1,?,?,3,5,?,10,?】

掩掉的这一部分(也就是问号),写入一些随机值,然后把indexs_mask送给GPT,让其预测出index。更准确的说,其实就是让它预测那些被掩码掉的部分,以这种方式,学习到索引之间的关系。

在这个过程中,VQGAN的参数固定不变,只训练GPT,训练完成后,就可以依靠GPT,随机初始化一个开始值,然后一点点的预测出后面的索引,得到了索引后,送给解码器,得到图像。

8、结束

其实VQGAN可以配合CLIP模型使用,达到文生图的效果。

以上,就是VQGAN的全部内容了,如有问题,还望指出。阿里嘎多!

系。

在这个过程中,VQGAN的参数固定不变,只训练GPT,训练完成后,就可以依靠GPT,随机初始化一个开始值,然后一点点的预测出后面的索引,得到了索引后,送给解码器,得到图像。

8、结束

其实VQGAN可以配合CLIP模型使用,达到文生图的效果。

以上,就是VQGAN的全部内容了,如有问题,还望指出。阿里嘎多!

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/568601.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LTD271次升级 | 网站/小程序可设访问IP的黑白名单 • 官微中心支持PDF等办公文件预览与并分享 • 订单退款显示更详尽明细

1、新增IP访问限制功能; 2、订单新增交易号显示与退款明细显示; 3、自定义地址增加四级地区; 4、Android版App优化文件功能; 5、已知问题修复与优化; 01 官微中心 1) 新增IP限制访问功能 允许或者禁止某些 IP 或…

uniapp项目中禁止横屏 ,app不要自动旋转 -,保持竖屏,uniapp取消重力感应

uniapp项目中禁止横屏 ,app不要自动旋转 -,保持竖屏,uniapp取消重力感应 1.适用于移动端,安卓和IOS,当即使手机打开了自动旋转的按钮,设置如下的代码后,页面依旧保持竖屏。 步骤一&#xff1a…

【深度学习】yolo-World,数据标注,zeroshot,目标检测

仓库:https://github.com/AILab-CVC/YOLO-World 下载权重: 仓库下载和环境设置 下载仓库:使用以下命令从 GitHub 上克隆仓库: git clone --recursive https://github.com/AILab-CVC/YOLO-World.git创建并激活环境&#xff1a…

程序猿成长之路之数据挖掘篇——朴素贝叶斯

朴素贝叶斯是数据挖掘分类的基础,本篇文章将介绍一下朴素贝叶斯算法 情景再现 以挑选西瓜为例,西瓜的色泽、瓜蒂、敲响声音、触感、脐部等特征都会影响到西瓜的好坏。那么我们怎么样可以挑选出一个好的西瓜呢? 分析过程 既然挑选西瓜有多个…

DaPy:实现数据分析与处理

DaPy:实现数据分析与处理 DaPy是一个用于数据分析和处理的Python库,它提供了一系列强大的工具和功能,使开发者能够高效地进行数据清洗、转换和分析。本文将深入解析DaPy库的特点、功能以及使用示例,帮助读者了解如何利用DaPy库处理…

贪心算法在单位时间任务调度问题中的应用

贪心算法在单位时间任务调度问题中的应用 一、引言二、问题描述与算法设计三、算法证明四、算法实现与效率分析五、C语言实现示例六、结论 一、引言 单位时间任务调度问题是一类经典的优化问题,旨在分配任务到不同的时间槽中,使得某种性能指标达到最优。…

【QT进阶】Qt http编程之实现websocket server服务器端

往期回顾 【QT进阶】Qt http编程之json解析的简单介绍-CSDN博客 【QT进阶】Qt http编程之nlohmann json库使用的简单介绍-CSDN博客 【QT进阶】Qt http编程之websocket的简单介绍-CSDN博客 【QT进阶】Qt http编程之实现websocket server服务器端 一、最终效果 通过ip地址和端口…

万界星空科技电机行业MES+商业电机行业开源MES+项目合作

要得出mes系统解决方案在机电行业的应用范围,我们先来看一下传统机电行业的管理难题: 1、 产品标准化程度较低,制造工艺复杂,生产周期较长,产品质量不稳定; 2、 自动化程度低,大多数工序以手工…

【视频异常检测】Open-Vocabulary Video Anomaly Detection 论文阅读

Open-Vocabulary Video Anomaly Detection 论文阅读 AbstractMethod3.1. Overall Framework3.2. Temporal Adapter Module3.3. Semantic Knowledge Injection Module3.4. Novel Anomaly Synthesis Module3.5. Objective Functions3.5.1 Training stage without pseudo anomaly …

电子信息制造工厂5G智能制造数字孪生可视化平台,推进数字化转型

电子信息制造工厂5G智能制造数字孪生可视化平台,推进数字化转型。5G智能制造数字孪生可视化平台利用5G网络的高速、低延迟特性,结合数字孪生技术和可视化界面,为电子信息制造工厂提供了一种全新的生产管理模式。不仅提升生产效率,…

设计模式(三):抽象工厂模式

设计模式(三):抽象工厂模式 1. 抽象工厂模式的介绍2. 抽象工厂模式的类图3. 抽象工厂模式的实现3.1 创建摩托车的接口3.2 创建摩托车的具体实现3.3 创建汽车的接口3.4 创建汽车的具体产品3.5 创建抽象工厂3.6 创建具体工厂3.7 创建工厂生成器…

Fisher判别示例:鸢尾花(iris)数据(R)

先读取iris数据,再用程序包MASS(记得要在使用MASS前下载好该程序包)中的线性函数lda()作判别分析: data(iris) #读入数据 iris #展示数据 attach(iris) #用变量名绑定对应数据 library(MASS) #加载MASS程序包 ldlda(Species~…

《ElementPlus 与 ElementUI 差异集合》el-select 显示下拉列表在 Cesium 场景中无法监听关闭

前言 仅在 Element UI 时有此问题,Element Plus 由于内部结构差异较大,不存在此问题。详见《el-select 差异点,如:高、宽、body插入等》; 问题 点击空白处,下拉列表可监听并关闭;但在 Cesium…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之五 简单进行车牌检测和识别

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之五 简单进行车牌检测和识别 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之五 简单进行车牌检测和识别 一、简单介绍 二、简单进行车牌检测和识别实现原理 …

鸿蒙(HarmonyOS)性能优化实战-Swiper高性能开发

背景 在应用开发中,Swiper 组件常用于翻页场景,比如:桌面、图库等应用。Swiper 组件滑动切换页面时,基于按需加载原则通常会在下一个页面将要显示时才对该页面进行加载和布局绘制,这个过程包括: 如果该页面…

解决VSCode中“#include错误,请更新includePath“问题

目录 1、问题原因 2、解决办法 1、问题原因 在编写C程序时,想引用头文件但是出现如下提示: (1)首先检查要引用的头文件是否存在,位于哪里。 (2)如果头文件存在,在编译时提醒VSCo…

【iOS】类与对象底层探索

文章目录 前言一、编译源码二、探索对象本质三、objc_setProperty 源码探索四、类 & 类结构分析isa指针是什么类的分析元类元类的说明 五、著名的isa走位 & 继承关系图六、objc_class & objc_objectobjc_class结构superClassbitsclass_rw_tclass_ro_tro与rw的区别c…

关于Modbus TCP 编码及解码方式分析

一.Modbus TCP 基本概念 1.基本概念 ①Coil和Register   Modbus中定义的两种数据类型。Coil是位(bit)变量;Register是整型(Word,即16-bit)变量。 ②Slave和Master与Server和Client   同一种设备在不同…

BUUCTF——[RoarCTF 2019]Easy Java

BUUCTF——[RoarCTF 2019]Easy Java 1.既然是登录框嘛,不得随便输入个弱口令,进行尝试 2.使用弱口令爆破了一下,直接就是429,无果 3.查看版本信息 4.帮助文档这里测试啦任意文件读取,无果 5.知道服务器的名称是openresty 6.…

jvm知识点总结(一)

JVM的跨平台 java程序一次编写到处运行。java文件编译生成字节码,jvm将字节码翻译成不同平台的机器码。 JVM的语言无关性 JVM只是识别字节码,和语言是解耦的,很多语言只要编译成字节码,符合规范,就能在JVM里运行&am…