从零开始 - 在Python中构建和训练生成对抗网络(GAN)模型

生成对抗网络(GANs)是一种强大的生成模型,可以合成新的逼真图像。通过完整的实现过程,读者将对GANs在幕后的工作原理有深刻的理解。本教程首先导入必要的库并加载将用于训练GAN的Fashion-MNIST数据集。然后,提供了构建GAN核心组件(生成器和判别器模型)的代码示例。接下来的部分解释了如何构建一个组合模型,该模型训练生成器以欺骗判别器,以及如何设计一个训练函数来优化对抗过程。

目录:

1. 导入库和下载数据集

2. 构建生成器模型

3. 构建判别器模型

4. 构建组合模型

5. 构建训练函数

6. 训练和观察结果

  1. 导入库和下载数据集

让我们首先导入本文中将使用的重要库:

from __future__ import print_function, division
from keras.datasets import fashion_mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

在本文中,您将在Fashion-MNIST数据集上训练DCGAN。Fashion-MNIST包含60,000个用于训练的灰度图像和一个包含10,000个图像的测试集。每个28×28的灰度图像与10个类别中的一个标签相关联。Fashion-MNIST旨在作为原始MNIST数据集的直接替代品,用于对比机器学习算法的性能。与三通道的彩色图像相比,灰度图像在一通道上训练卷积网络时需要更少的计算能力,这使您更容易在没有GPU的个人计算机上进行训练。

a43e74d2137f4a31ce4d40fe66ab7a52.jpeg

数据集分为10个时尚类别。类别标签如下:

760b0174d7592e71606bec49bf3407a5.jpeg

您可以使用以下代码加载数据集:

(training_data, _), (_, _) = fashion_mnist.load_data()
X_train = training_data / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

要可视化数据集中的图像,可以使用以下代码:

def visualize_input(img, ax):
    ax.imshow(img, cmap='gray')
    width, height = img.shape
    thresh = img.max()/2.5
    for x in range(width):
        for y in range(height):
            ax.annotate(str(round(img[x][y],2)), xy=(y,x),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if img[x][y]<thresh else="" 'black')=""  =""  
fig = plt.figure(figsize = (12,12))
ax = fig.add_subplot(111)
visualize_input(training_data[3343], ax)We also use batch normalization and a ReLU activation.
For each of these layers, the general scheme is convolution ⇒ batch normalization
⇒ ReLU. We keep stacking up layers like this until we get the final transposed
convolution layer with shape 28 × 28 × 1:

b001bcb6986483ef65aa3f19ef9b657e.jpeg

2. 构建生成器模型

正如我们在前面的文章中所探讨的,GANs由两个主要组件组成,即生成器和判别器。在这一部分中,我们将构建生成器模型,其输入将是一个噪声向量(z)。生成器的架构如下图所示。

第一层是一个全连接层,然后被重新塑造成深而窄的层,在原始的DCGAN论文中,作者将输入重新塑造为4×4×1024。在这里,我们将使用7×7×128。然后,我们使用上采样层将特征映射的维度从7×7加倍到14×14,然后再次加倍到28×28。在这个网络中,我们使用了三个卷积层。我们还将使用批归一化和ReLU激活。

对于每个层,通用方案是卷积 ⇒ 批归一化 ⇒ ReLU。我们不断地堆叠这样的层,直到得到最终的转置卷积层,形状为28×28×1。

4fabaa16f62175b0c474ff334293c279.jpeg

以下是构建上述生成器模型的Keras代码:

def build_generator():


  generator = Sequential()


  generator.add(Dense(6272, activation="relu", input_dim=100)) # Add dense layer
  generator.add(Reshape((7, 7, 128)))  # reshape the image
  generator.add(UpSampling2D()) # Upsampling layer to double the size of the image
  generator.add(Conv2D(128, kernel_size=3, padding="same", activation="relu"))
  generator.add(BatchNormalization(momentum=0.8))
  generator.add(UpSampling2D())


  # convolutional + batch normalization layers
  generator.add(Conv2D(64, kernel_size=3, padding="same", activation="relu"))
  generator.add(BatchNormalization(momentum=0.8))


  # convolutional layer with filters = 1
  generator.add(Conv2D(1, kernel_size=3, padding="same", activation="relu"))
  generator.summary() # prints the model summary


  """
  We don't add upsampling here because the image size of 28 × 28 is 
  equal to the image size in the MNIST dataset. 
  You can adjust this for your own problem.
  """
  noise = Input(shape=(100,))
  fake_image = generator(noise)


  # Returns a model that takes the noise vector as an input and outputs the fake image
  return Model(inputs=noise, outputs=fake_image)

3. 构建判别器模型

GANs的第二个主要组件是判别器。判别器只是一个传统的卷积分类器。判别器的输入是28×28×1的图像。我们希望有一些卷积层,然后是输出的全连接层。

与之前一样,我们希望得到一个Sigmoid输出,并且我们需要返回logits。对于卷积层的深度,我们可以从第一层开始使用32或64个过滤器,然后在添加层时将深度加倍。在这个实现中,我们将从64层开始,然后是128,然后是256。对于降采样,我们不使用池化层。相反,我们只使用步幅卷积层进行降采样,类似于Radford等人的实现。

我们还使用批归一化和dropout来优化训练。对于四个卷积层的每一层,通用方案是卷积 ⇒ 批归一化 ⇒ 泄漏的ReLU。

c99ea77aec1203923646688e02c6e1d6.jpeg

现在,让我们构建build_discriminator函数:

def build_discriminator():
  discriminator = Sequential()
  discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(28,28,1), padding="same"))
  discriminator.add(LeakyReLU(alpha=0.2))
  discriminator.add(Dropout(0.25))
  
  discriminator.add(Conv2D(64, kernel_size=3, strides=2,padding="same"))
  discriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))
  discriminator.add(BatchNormalization(momentum=0.8))
  
  discriminator.add(LeakyReLU(alpha=0.2))
  discriminator.add(Dropout(0.25))
  
  discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
  discriminator.add(BatchNormalization(momentum=0.8))
  discriminator.add(LeakyReLU(alpha=0.2))
  discriminator.add(Dropout(0.25))
  
  discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
  discriminator.add(BatchNormalization(momentum=0.8))
  discriminator.add(LeakyReLU(alpha=0.2))
  discriminator.add(Dropout(0.25))
  
  discriminator.add(Flatten())
  discriminator.add(Dense(1, activation='sigmoid'))
  
  img = Input(shape=(28,28,1))
  probability = discriminator(img)
 
  return Model(inputs=img, outputs=probability)

4. 构建组合模型

正如本系列的第二篇文章中所解释的,为了训练生成器,我们需要构建一个包含生成器和判别器的组合网络。组合模型以噪声信号(z)作为输入,并将判别器的预测输出作为虚假或真实输出。

e90e9c2335ae20998fab73b192b20485.jpeg

重要的是要记住,我们希望在组合模型中禁用判别器的训练,正如本系列的第二篇文章中所解释的那样。在训练生成器时,我们不希望判别器更新权重,但我们仍然希望将判别器模型包含在生成器训练中。因此,我们创建一个包含两个模型的组合网络,但在组合网络中冻结判别器模型的权重:

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
discriminator.trainable = False


# Build the generator
generator = build_generator()
z = Input(shape=(100,))
img = generator(z)
valid = discriminator(img)
combined = Model(inputs=z, outputs=valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

5. 构建训练函数

为了训练GAN模型,我们训练两个网络:判别器和我们在前面部分创建的组合网络。让我们构建train函数,该函数接受以下参数:

  • epoch

  • batch size 大小

  • save_interval,以指定多久保存一次结果

def train(epochs, batch_size=128, save_interval=50):
    
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    for epoch in range(epochs):  # Train Discriminator network
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        
        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        g_loss = combined.train_on_batch(noise, valid)
        
        # printing progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %(epoch, d_loss[0], 100*d_loss[1], g_loss))
        
        if epoch % save_interval == 0:
            plot_generated_images(epoch, generator)

我们还将创建另一个函数`plot_generated_images()` 来绘制生成的图像。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10),figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, latent_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch

最后,让我们为训练GAN模型定义重要的变量和参数:

# Input shape
img_shape = (28,28,1)
channels = 1
latent_dim = 100
optimizer = Adam(0.0002, 0.5)


# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# Build the generator
generator = build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(latent_dim,))
img = generator(z)
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
valid = discriminator(img)
# The combined model  (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

6. 训练和观察结果

此时,代码实现已经完成,我们准备开始DCGAN的训练。要训练模型,请运行以下代码行:

train(epochs=1000, batch_size=32, save_interval=50)

这将在1,000个epochs上运行训练,并每50个epochs保存一次图像。当运行`train()` 函数时,训练进度将如下所示:

86d990d67af3b9ee259b9424b3e1e521.jpeg

如下图所示,在epoch = 0时,图像只是随机噪声,没有明确的模式或有意义的数据。到了第50个epoch,图案已经开始形成。

80fb00ada0dc22c60488b9d4fda559aa.jpeg

在训练过程的后期,到了第1,000个epoch,您可以看到清晰的形状,可能能够猜测输入到GAN模型的训练数据的类型。

49de38a46bd9065cb03bb8125b1a990e.jpeg

再快进到第10,000个epoch,您会发现生成器已经非常擅长重新创建训练数据集中不存在的新图像。

de6db2898ea32036dd85c216a275c842.jpeg

·  END  ·

HAPPY LIFE

aeccadbe0b4d2dc12a3db6eea9e70b49.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/286556.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024年【通信安全员ABC证】复审考试及通信安全员ABC证操作证考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 通信安全员ABC证复审考试根据新通信安全员ABC证考试大纲要求&#xff0c;安全生产模拟考试一点通将通信安全员ABC证模拟考试试题进行汇编&#xff0c;组成一套通信安全员ABC证全真模拟考试试题&#xff0c;学员可通过…

优化模型:matlab二次规划

1.二次规划 1.1 二次规划的定义 若某非线性规划的目标函数为自变量 x x x的二次函数&#xff0c;且约束条件全是线性的&#xff0c;则称这种规划模型为二次规划。 1.2 二次规划的数学模型 min ⁡ 1 2 x T H x f T x \min \frac{1}{2}\boldsymbol{x}^{\boldsymbol{T}}\bolds…

【Matlab】RF随机森林时序预测算法(附代码)

资源下载&#xff1a; https://download.csdn.net/download/vvoennvv/88692249 一&#xff0c;概述 随机森林的基本思想是利用多个决策树对时序数据进行预测&#xff0c;其中每个决策树都使用不同的随机抽样方式选择训练数据&#xff0c;以减小过拟合的风险。 随机森林时序预测…

【计算机毕业设计】SSM实验室设备管理

项目介绍 本项目为后台管理系统&#xff0c;分为管理员、老师、学生三种角色&#xff1b; 管理员角色包含以下功能&#xff1a; 信息管理&#xff1a;用户管理&#xff1b; 基础管理&#xff1a;实验室管理,实验室申请记录,设备管理,设备记录管理,耗材管理,耗材记录管理等功能…

【Java进阶篇】字符串常量、字符串常量池详解

字符串常量、字符串常量池详解 ✔️字符串常量池是如何实现的?✔️字符串常量从哪来的? ✔️字符串常量是什么时候进入到字符串常量池的? ✔️字符串常量池是如何实现的? 字符串常量池 (String Constant Pool) 是Java中一块特殊的内存区域&#xff0c;用于存储字符串常量。…

C语言——操作符

一、算数操作符 1、(加操作符) 用于将两个数相加&#xff0c;例&#xff1a;3 3结果为6 2、-(减操作符) 用于将两个数相减&#xff0c;例&#xff1a;3 - 3结果为0 3、*(乘操作符) 用于将两个数相乘&#xff0c;例&#xff1a;3 * 3结果为9 4、/(除操作符) 用于将两个…

HashMap使用-LeetCode做题总结 454. 四数相加 II

454. 四数相加 II 最初思路优化思路Java语法增强for的使用场景 最初思路 枚举&#xff0c;因为是要计算有多少个元组&#xff0c;所以每个元素肯定都要遍历到&#xff0c;所以干脆算出所有元组的和。 我想用四个for循环加&#xff0c;但是失败。 优化思路 参考力扣 四数相加为…

创建VLAN及VLAN间通信

任务1、任务2、任务3实验背景&#xff1a; 在一家微型企业中&#xff0c;企业的办公区域分为两个房间&#xff0c;一个小房间为老板办公室&#xff0c;一个大房间为开放办公室&#xff0c;财务部和销售部的员工共同使用这个办公空间。我们需要通过VLAN的划分&#xff0c;使老板…

聊聊我使用亚马逊鲲鹏系统注册买家号的心得

想和大家聊一下我最近用了个挺好用的工具&#xff0c;就是亚马逊鲲鹏系统。以前我总是烦恼要一个一个手动注册亚马逊账号&#xff0c;真是麻烦。但有了这个系统&#xff0c;简直是方便到不行&#xff01; 首先&#xff0c;它有个全自动批量注册账号的功能&#xff0c;你只需要提…

Python爬取今日头条热门文章

前言 今日头条文章收益是没有任何门槛&#xff0c;只要是你发布文章&#xff0c;每篇文章的阅读量超过1000就能有收益&#xff0c;阅读量越多收益越高。于是乎我就有了个大胆的想法。何不利用Python爬虫&#xff0c;爬取热门文章&#xff0c;然后完成自动化发布文章呢&#xf…

77 Python开发-批量FofaSRC提取POC验证

目录 本课知识点:学习目的:演示案例:Python开发-某漏洞POC验证批量脚本Python开发-Fofa搜索结果提取采集脚本Python开发-教育SRC报告平台信息提取脚本 涉及资源: 本课知识点: Request爬虫技术&#xff0c;lxml数据提取&#xff08;把一些可以用的或者有价值的数据进行提取和保…

十二星座、社交做人守信用程度指数。

双子座&#xff08;95&#xff05; &#xff09;&#xff1b;天蝎座&#xff08;92&#xff05; &#xff09;&#xff1b;处女座&#xff08;90&#xff05; &#xff09; 金牛座&#xff08;85&#xff05; &#xff09;&#xff1b;狮子座&#xff08;85&#xff05; &#…

07. HTTP接口请求重试怎么处理?

目录 1、前言 2、实现方式 2.1、循环重试 2.2、递归重试 2.3、Spring Retry 2.4、Resilience4j 2.5、http请求网络工具内置重试方式 2.6、自定义重试工具 2.7、并发框架异步重试 2.8、消息队列 3、小结 1、前言 HTTP接口请求重试是指在请求失败时&#xff0c;再次发…

[python]matplotlib

整体图示 .ipynb 转换md时候图片不能通知携带&#xff0c;所有图片失效&#xff0c;不过直接运行代码可以执行 figure figure,axes与axis import matplotlib.pyplot as plt figplt.figure() fig2plt.subplots() fig3,axsplt.subplots(2,2) plt.show()<Figure size 640x480 …

C++模板进阶操作 ---非类型模板参数、模板的特化以及模板的分离编译

本专栏内容为&#xff1a;C学习专栏&#xff0c;分为初阶和进阶两部分。 通过本专栏的深入学习&#xff0c;你可以了解并掌握C。 &#x1f493;博主csdn个人主页&#xff1a;小小unicorn ⏩专栏分类&#xff1a;C &#x1f69a;代码仓库&#xff1a;小小unicorn的代码仓库&…

计算机网络复习1

概论 文章目录 概论计算机网络的组成功能分类性能指标&#xff08;搞清楚每个时延的具体定义&#xff09;分层结构协议、接口和服务服务的分类ISO/OSITCP/IP两者的不同 计算机网络的组成 组成部分&#xff1a;硬件&#xff0c;软件和协议&#xff08;协议&#xff1a;传输数据…

C++ stack使用、模拟实现、OJ题

目录 一、介绍 二、常用函数 三、模拟实现 四、OJ练习题 1、最小栈 2、栈的压入、弹出序列 3、逆波兰表达式(后缀转中缀) 4、中缀转后缀思路 5、用栈实现队列 一、介绍 stack是一种容器适配器&#xff0c;专门用在具有后进先出操作的上下文环境中&#xff0c;其删除…

二叉树的前序遍历 、二叉树的最大深度、平衡二叉树、二叉树遍历【LeetCode刷题日志】

目录 一、二叉树的前序遍历 方法一&#xff1a;全局变量记录节点个数 方法二&#xff1a;传址调用记录节点个数 二、二叉树的最大深度 三、平衡二叉树 四、二叉树遍历 一、二叉树的前序遍历 方法一&#xff1a;全局变量记录节点个数 计算树的节点数: 函数TreeSize用于…

[情商-5]:用IT直男擅长的流程图阐述高情商聊天过程与直男聊天过程

目录 一、目标与主要思想的差别 二、高情商聊天与直男聊天的流程图 1. 发起谈话主题Topic 2. 分析谈话的主题和内容 3. 确定谈话目的&#xff1a;解决问题还是情绪交流 4. 倾听&#xff1a;站在自己的角度倾听、捕获、理解对方的情绪状态与情绪诉求 5. 同理心&#xff1…

探索 CodeWave低代码技术的魅力与应用

目录 前言1 低代码平台2 CodeWave简介3 CodeWave 的独特之处3.1 高保真还原交互视觉需求3.2 擅长复杂应用开发3.3 支持应用导出&独立部署3.4 金融级安全要求3.5 可集成性高3.6 可拓展性强 4 平台架构和核心功能4.1 数据模型设计4.2 页面设计4.3 逻辑设计4.4 流程设计4.5 接…