基本介绍
今日要实践的模型是GAN,用于图像生成。使用的MNIST手写数字数据集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道,图片已经预先进行了尺寸归一化和中心化处理。本文会先简单介绍GAN图像生成的原理,然后展示自己的运行结果,不作代码展示,最后进行总结。
GAN基本原理
该部分内容来自官方文档,非原创
生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型。最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):
- 生成器的任务是生成看起来像训练图像的“假”图像;
- 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。
GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型𝐺和估计样本是否来自训练数据的判别模型𝐷。在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。
用𝑥代表图像数据,用𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥)需要处理作为二进制文件的大小为1×28×28的图像数据。当来自训练数据时,𝐷(𝑥)数值应该趋近于 ;而当𝑥来自生成器时,𝐷(𝑥)数值应该趋近于0。因此𝐷(𝑥)也可以被认为是传统的二分类器。用𝑧代表标准正态分布中提取出的隐码(隐向量),用𝐺(𝑧):表示将隐码(隐向量)𝑧映射到数据空间的生成器函数。函数𝐺(𝑧)的目标是将服从高斯分布的随机噪声𝑧通过生成网络变换为近似于真实分布 𝑝𝑑𝑎𝑡𝑎(𝑥)的数据分布,我们希望找到 θθ 使得 𝑝𝐺(𝑥;𝜃)和 𝑝𝑑𝑎𝑡𝑎(𝑥)尽可能的接近,其中𝜃代表网络参数。
𝐷(𝐺(𝑧))表示生成器 𝐺 生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷和𝐺在进行一场博弈,𝐷想要最大程度的正确分类真图像与假图像,也就是参数 log𝐷(𝑥);而𝐺试图欺骗𝐷来最小化假图像被识别到的概率,也就是参数log(1−𝐷(𝐺(𝑧)))。因此GAN的损失函数为:
从理论上讲,此博弈游戏的平衡点是𝑝𝐺(𝑥;𝜃)=𝑝𝑑𝑎𝑡𝑎(𝑥),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
- 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
- 生成器通过优化,生成出更加贴近真实数据分布的数据。
- 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2
GAN代码实践
官方给的代码实践是经典的深度学习流程。即数据集预处理,模型搭建,模型训练,模型评估,模型推理。这个流程中重点是模型搭建中的生成器和判别器,这二者是GAN的核心,最好结合代码和原理进行学习理解。详细的可直接参考官方的代码实践,这里给出我自己的运行结果和部分代码
- 数据集部分可视化结果
- 模型训练结果:由于训练时长,这里只训练了12轮,每4轮可视化一次生成的图像
如果将训练过程可视化,训练过程生成的图像的gif图如下:可以看出随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当 epoch
达到100(该数据来自官方文档)以上时,生成的手写数字图片与数据集中的较为相似
- 描绘
D
和G
损失与训练迭代的关系图
- 模型推理结果:通过加载生成器网络模型参数文件来生成图像,可以看出,由于训练轮次的不足,从而导致生成的图像质量非常差,有的甚至看不出是什么数字,像是经历了上百年的风化一样,有时间多训练个几十轮。
总结
GAN的基本思想不难,核心是生成器和判别器,其数学原理是概率统计,公式比较复杂,不同的GAN模型的生成器和判别器是不同的。今天初步了解了GAN的原理,并结合代码对其生成器和判别器有了大概的了解,顺利完成今天的实践,希望后面的不会太难。