生成模型 -- GAN

文章目录

  • 1. 生成模型与判别模型
    • 1.1 生成模型
  • 2. VAE
  • 3. GAN
    • 3.1 GAN-生成对抗网络
    • 3.2 GAN-生成对抗网络的训练
      • 3.2.1 判别模型的训练:
      • 3.2.2 生成网络的训练:
  • 4. LeakyReLU
  • 5. GAN代码实例

1. 生成模型与判别模型

生成模型与判别模型
我们前面几章主要介绍了机器学习中的判别式模型,这种模型的形式主要是根据原始图像推测图像具备的一些性质,例如根据数字图像推测数字的名称,根据自然场景图像推测物体的边界;

而生成模型恰恰相反,通常给出的输入是图像具备的性质,而输出是性质对应的图像。这种生成模型相当于构建了图像的分布,因此利用这类模型,我们可以完成图像自动生成(采样)、图像信息补全等工作。

在深度学习之前已经有很多生成模型,但苦于生成模型难以描述难以建模,科研人员遇到了很多挑战,而深度学习的出现帮助他们解决了不少问题。

基于深度学习思想的生成模型——GAN和VAE,以及GAN的变种模型。

1.1 生成模型

  • 生成图片
  • 人脸生成
  • 照片生成
  • 生成卡通人物
  • 图像转换
  • 文本到图片的转换
  • 语义图片到照片的转换
  • 正脸图片生成
  • 生成新的人体姿势
  • 照片到表情的转换
  • 照片编辑
  • 图片混合
  • 超分辨率
  • 图片修复
  • 衣服转换
  • 视频预测
  • 3D 物体生成

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

2. VAE

VAE-Variational Autoencoder
变分自动编码器
想象这样一个网络,输入是一组全部为1的向量,目标是一张猫脸,经过好多好多轮的训练。 我们只要输入这个全部为1的向量就可以得到这张猫的脸。

其实这是因为在训练的过程中,我们通过不断地训练,网络已经将这张猫的图片的参数保存起来了。

在这里插入图片描述

这个工作其实已经可以看出他的意义所在了,通过一个网络,将一个高维空间的脸映射为低维空间的一个向量。

那么如果,我们尝试使用更多的图片。这次我们用one-hot向量而不是全1向量。我们用[1, 0, 0, 0]代表猫,用[0, 1, 0, 0]代表狗。虽然这也没什么问题,但是我们最多只能储存4张图片。

于是,我们可以增加向量的长度和网络的参数,那么我们可以获得更多的图片。

例如,将这个向量定义为四维,采用one-hot的表达方式表达四张不同的脸,那么这个网络就可以表达四个脸。输入不同的数据,他就会输出不同的脸来。

在这里插入图片描述

但是,这样的向量很稀疏。为了解决这个问题,我们想使用实数值向量而不是0,1向量。我们可认为这种实数值向量是原图片的一种编码,这也就引出了编码/解码的概念。

举个例子,[3.3, 4.5, 2.1, 9.8]代表猫,[3.4, 2.1, 6.7, 4.2] 代表狗。

这个已知的初始向量可以作为我们的潜在变量。

如果像我上面一样,随机初始化一些向量去代表图片的编码,这不是一个很好的办法,我们更希望计算机能帮我们自动编码。在auto encoder模型中,我们加入一个编码器,它能帮我们把图片编码成向量。然后解码器能够把这些向量恢复成图片。

在这里插入图片描述

在下面这个图中,我们通过六个因素来描述最终的人脸形状,而这些因素不同的值则代表了不同的特性。

在这里插入图片描述

3. GAN

3.1 GAN-生成对抗网络

什么是生成对抗网络,GAN–Generative Adversarial Network,

  1. 对抗网络有一个生成器(Generator),还有一个判别器 (Discriminator);
  2. 生成器从随机噪声中生成图片,由于这些图片都是生成器臆想出来的,所以我 们称之为 Fake Image;
  3. 生成器生成的照片Fake Image和训练集里的Real Image都会传入判别器,判别器判断他们是 Real 还是 Fake。

那么我们如何训练网络呢?要达到什么样的目的?

  1. 我们希望生成器生成的图片足够真实,可以骗过判别器;
  2. 我们也希望判别器足够“精明”,可以很好的分别出真图还是生成图;
  3. 最后在训练中,生成器和判别器达到一种“对抗”中的平衡,结束训练。
  4. 这时,我们分离出生成器,它便可以帮助我们“生成”想要的图片。

在这里插入图片描述

我们要明白在使用GAN的时候的2个问题

  1. 我们有什么?
    比如上图,我们有的只是真实采集而来的人脸样本数据集,仅此而已,而且很关键的一点是我们连人脸数据集的类标签都没有,也就是我们不知道那个人脸对应的是谁。
  2. 我们要得到什么?
    至于要得到什么,不同的任务得到的东西不一样,我们只说最原始的GAN目的,那就是我们想通过输入一个噪声,模拟得到一个人脸图像,这个图像可以非常逼真以至于以假乱真。

首先判别模型,就是图中右半部分的网络,直观来看就是一个简单的神经网络结构,输入就是一副图像,输出就是一个概率值,用于判断真假使用(概率值大于0.5那就是真,小于0.5那就是假),真假也不过是人们定义的概率而已。

其次是生成模型,同样也可以看成是一个神经网络模型,输入是一组随机数Z,输出是一个图像,不再是一个数值。

从图中可以看到,会存在两个数据集,一个是真实数据集,另一个是假的数据集.

GAN的目标:

  1. 判别网络的目的:就是能判别出来输入的一张图它是来自真实样本集还是假样本集。假如输入的是真样本,网络输出就接近1,输入的是假样本,网络输出接近0,达到了很好的判别的目的。
  2. 生成网络的目的:生成网络是造样本的,它的目的就是使得自己造样本的能力尽可能强,尽可能的使判别网络没法判断是真样本还是假样本。

生成网络与判别网络的目的正好是相反的,一个说我能判别的好,一个说我让你判别不好。

所以叫做对抗,叫做博弈。

那么最后的结果到底是谁赢呢?

这就要归结到设计者,也就是我们希望谁赢了。

作为设计者的我们,我们的目的是要得到以假乱真的样本,那么很自然的我们希望生成样本赢了,也就是希望生成样本很真,判别网络的能力不足以区分真假样本为止。

3.2 GAN-生成对抗网络的训练

单独交替迭代训练
在这里插入图片描述

3.2.1 判别模型的训练:

假设现在生成网络模型已经有了(当然可能不是最好的生成网络),那么给一堆随机数组,就会得到一堆假的样本集(因为不是最终的生成模型,那么现在生成网络可能就处于劣势,导致生成的样本就不咋地,可能很容易就被判别网络判别出来了说这货是假冒的)。

假设我们现在有了这样的假样本集,而真样本集一直都有,现在我们人为地定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,很明显这里我们就已经默认真样本集所有的类标签都为1,而假样本集的所有类标签都为0.。

所以,我们现在有了真样本集以及它们的label(都是1)、假样本集以及它们的label(都是0)

这样单就判别网络来说,此时问题就变成了一个再简单不过的有监督的二分类问题了,直接送到神经网络模型中训练就可以了。

3.2.2 生成网络的训练:

想想我们的目的,是生成尽可能逼真的样本。
那么原始的生成网络生成的样本,怎么知道它真不真呢?
就是送到判别网络中,所以在训练生成网络的时候,我们需要联合判别网络一起才能达到训练的目的。
把刚才的判别网络串接在生成网络的后面,这样我们就知道真假了,也就有了误差了。

所以对于生成网络的训练其实是对生成-判别网络串接的训练。

对于样本,我们要把生成的假样本的标签都设置为1,也就是认为这些假样本在生成网络训练的时候是真样本。

那么为什么要这样呢?我们想想,是不是这样才能起到迷惑判别器的目的,也才能使得生成的假样本逐渐逼近为真样本。

现在对于生成网络的训练,我们有了样本集(只有假样本集,没有真样本集),有了对应的label(全为1)。

注意,在训练这个串接的网络的时候,一个很重要的操作就是不要更新判别网络的参数,只是把误差一直传, 传到生成网络后更新生成网络的参数。

在完成生成网络训练后,我们就可以根据目前新的生成网络再对先前的那些噪声Z生成新的假样本了。

并且训练后的假样本应该是更真了才对。

所有这样我们又有了新的真假样本集,这样又可以重复上述过程了。

我们把这个过程称作为单独交替训练

4. LeakyReLU

Relu的输入值为负的时候,输出始终为0,其一阶导数也始终为0,这样会导致神经元不能更新参数,也就是神经元不学习了,这种现象叫做“Dead Neuron”。

为了解决Relu函数这个缺点,在Relu函数的负半区间引入一个泄露(Leaky)值,所以称为Leaky Relu函数。即ReLU在取值小于零部分没有梯度,LeakyReLU在取值小于0部分给一个很小的梯度。
在这里插入图片描述

5. GAN代码实例

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("./images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=2000, batch_size=32, sample_interval=200)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/85535.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UE4 材质学习笔记

CheapContrast与CheapContrast_RGB都是提升对比度的,一个是一维输入,一个是三维输入,让亮的地方更亮,暗的地方更暗,不像power虽然也是提升对比度,但是使用过后的结果都是变暗或者最多不变(值为1…

Mybatis简单入门

星光下的赶路人star的个人主页 夏天就是吹拂着不可预期的风 文章目录 1、Mybatis介绍1.1 JDBC痛点1.2 程序员的诉求1.3 Mybatis简介 2、数据准备2.1 数据准备2.2 建工程2.3 Employee类2.4 Mybatis的全局配置2.5 编写要执行的SQL2.6 编写java程序2.7 稍微总结一下流程 3、解决属…

什么是安全测试报告,怎么获得软件安全检测报告?

安全测试报告 软件安全测试报告:是指测试人员对软件产品的安全缺陷和非法入侵防范能力进行检查和验证的过程,并对软件安全质量进行整体评估,发现软件的缺陷与 bug,为开发人员修复漏洞、提高软件质量奠定坚实的基础。 怎么获得靠谱…

Hadoop学习一(初识大数据)

目录 一 什么是大数据? 二 大数据特征 三 分布式计算 四 Hadoop是什么? 五 Hadoop发展及版本 六 为什么要使用Hadoop 七 Hadoop vs. RDBMS 八 Hadoop生态圈 九 Hadoop架构 一 什么是大数据? 大数据是指无法在一定时间内用常规软件工具对其内…

昌硕科技、世硕电子同步上线法大大电子合同

近日,世界500强企业和硕联合旗下上海昌硕科技有限公司(以下简称“昌硕科技”)、世硕电子(昆山)有限公司(以下简称“世硕电子”)的电子签项目正式上线。上线仪式在上海浦东和硕集团科研大楼举行&…

渗透测试方法论

文章目录 渗透测试方法论1. 渗透测试种类黑盒测试白盒测试脆弱性评估 2. 安全测试方法论2.1 OWASP TOP 102.3 CWE2.4 CVE 3. 渗透测试流程3.1 通用渗透测试框架3.1.1 范围界定3.1.2 信息搜集3.1.3 目标识别3.1.4 服务枚举3.1.5 漏洞映射3.1.6 社会工程学3.1.7 漏洞利用3.1.8 权…

Java课题笔记~ SpringBoot基础配置

二、基础配置 1. 配置文件格式 问题导入 框架常见的配置文件有哪几种形式? 1.1 修改服务器端口 http://localhost:8080/books/1 >>> http://localhost/books/1 SpringBoot提供了多种属性配置方式 application.properties server.port80 applicati…

jmeter HTTP请求默认值

首先,打开JMeter并创建一个新的测试计划。 右键单击测试计划,选择"添加" > “配置元件” > “HTTP请求默认值”。 在HTTP请求默认值中,您可以设置全局的HTTP请求属性,例如: 服务器地址&#xff1a…

神经网络简单理解:机场登机

目录 神经网络简单理解:机场登机 ​编辑 激活函数:转为非线性问题 ​编辑 激活函数ReLU 通过神经元升维(神经元数量):提升线性转化能力 通过增加隐藏层:增加非线性转化能力​编辑 模型越大,…

uniapp日期选择组件优化

<uni-forms-item label="出生年月" name="birthDate"><view style="display: flex;flex-direction: row;align-items: center;height: 100%;"><view class="" v-

【图论】最短路的传送问题

一.分层图问题&#xff08;单源传送&#xff09; &#xff08;1&#xff09;题目 P4568 [JLOI2011] 飞行路线 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) &#xff08;2&#xff09;思路 可知背景就是求最短路问题&#xff0c;但难点是可以使一条路距离缩短至0&#xf…

一、数学建模之线性规划篇

1.定义 2.例题 3.使用软件及解题 一、定义 1.线性规划&#xff08;Linear Programming&#xff0c;简称LP&#xff09;是一种数学优化技术&#xff0c;线性规划作为运筹学的一个重要分支&#xff0c;专门研究在给定一组线性约束条件下&#xff0c;如何找到一个最优的决策&…

绿盾客户端字体库文件被加密了,预览不了

环境: 绿盾客户端7.0 Win10 专业版 问题描述: 绿盾客户端字体库文件被加密了,预览不了 预览不了 解决方案 1.打开控制台 2.进入规则中心 3.找到对应的操作员类型 4.选择自定义程序 5.右键新建程序,填最开始获得的程序名,可执行程序选择本地程序,我本地没有这个…

pytest之parametrize参数化

前言 我们都知道pytest和unittest是兼容的&#xff0c;但是它也有不兼容的地方&#xff0c;比如ddt数据驱动&#xff0c;测试夹具fixtures&#xff08;即setup、teardown&#xff09;这些功能在pytest中都不能使用了&#xff0c;因为pytest已经不再继承unittest了。 不使用dd…

PHP8中自定义函数-PHP8知识详解

1、什么是函数&#xff1f; 函数&#xff0c;在英文中的单词是function&#xff0c;这个词语有功能的意思&#xff0c;也就是说&#xff0c;使用函数就是在编程的过程中&#xff0c;实现一定的功能。即函数就是实现一定功能的一段特定代码。 在前面的教学中&#xff0c;我们已…

如何进行在线pdf转ppt?在线pdf转ppt的方法

在当今数字化时代&#xff0c;PDF文件的广泛应用为我们的工作和学习带来了巨大的便利。然而&#xff0c;有时候我们可能需要将PDF转换为PPT文件&#xff0c;以便更好地展示和分享内容。在线PDF转PPT工具因其操作简便、高效而备受欢迎。如何进行在线pdf转ppt呢?接下来&#xff…

kafka--技术文档-基本概念-《快速了解kafka》

学习一种新的消息中间键&#xff0c;卡夫卡&#xff01;&#xff01;&#xff01; 官网网址 Apache Kafka 基本概念 Kafka是一种开源的分布式流处理平台&#xff0c;由Apache软件基金会开发&#xff0c;用Scala和Java编写。它是一个高吞吐量的分布式发布订阅消息系统&#xf…

c++ qt--信号与槽(二) (第四部分)

c qt–信号与槽(二) &#xff08;第四部分&#xff09; 一.信号与槽的关系 1.一对一 2.一对多 3.多对一 4.多对多 还可以进行传递 信号->信号->槽 一个信号控制多个槽的例子&#xff08;通过水平滑块控制两个组件&#xff09; 1.应用的组件 注意这里最下面的组件…

(五)Docker 安装 redis镜像+启动redis容器(超详细)

输入&#xff1a;su root命令&#xff0c;切换到root 1、启动Docker 启动&#xff1a;sudo systemctl start docker 停止&#xff1a;systemctl stop docker 重启&#xff1a;systemctl restart docker 查看docker运行状态&#xff08;显示绿色代表正常启动&#xff09;&#x…

WPF 项目中 MVVM模式 的简单例子说明

一、概述 MVVM 是 Model view viewModel 的简写。MVVM模式有助于将应用程序的业务和表示逻辑与用户界面清晰分离。 几个概念的说明&#xff1a; model :数据&#xff0c;界面中需要的数据&#xff0c;最好不要加逻辑代码view : 视图就是用户看到的UI结构 xaml 文件viewModel …