生成对抗网络DCGAN学习

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# 定义生成器模型
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# 定义判别器模型
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()

# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output):
    return loss_fn(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = loss_fn(tf.ones_like(real_output), real_output)
    fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# 定义训练循环
@tf.function  #这个是tensorflow的装饰器,标记后可提升性能,不加此标记也可
def train_step(images):
    # 生成噪声向量
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 使用生成器生成假图片
        generated_images = generator(noise, training=True)

        # 使用判别器判断真假图片
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # 计算损失函数
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 计算梯度并更新生成器和判别器的参数
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch, test_input):

    predictions = model(test_input, training=False)
    print("predictions.shape:", predictions.shape)
    num_images = predictions.shape[0]
    rows = int(num_images ** 0.5) # 计算行数
    cols = num_images // rows # 计算列数
    
    fig = plt.figure(figsize=(8, 8))
    
    for i in range(num_images):
        plt.subplot(rows, cols, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()

# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5

# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50

# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    for i,image_batch in enumerate(dataset):
        print("sub i",i)
        train_step(image_batch)

    print("------------------------------------------------------epoch:", epoch)

    # 每个 epoch 结束后生成并保存一组图像
    if (epoch + 1) % 5 == 0:
        seed = tf.random.normal([BATCH_SIZE, 100])
        generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/57768.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32入门——外部中断

中断系统概述 中断:在主程序运行过程中,出现了特定的中断触发条件(中断源),使得CPU暂停当前正在运行的程序,转而去处理中断程序,处理完成后又返回原来被暂停的位置继续运行中断优先级&#xff…

vue 图片回显标签

第一种 <el-form-item label"打款银行回单"><image-preview :src"form.bankreceiptUrl" :width"120" :height"120"/></el-form-item>// 值为 https://t11.baidu.com/it/app106&fJPEG&fm30&fmtauto&…

Kafka-Broker工作流程

kafka集群在启动时&#xff0c;会将每个broker节点注册到zookeeper中&#xff0c;每个broker节点都有一个controller&#xff0c;哪个controller先在zookeeper中注册&#xff0c;哪个controller就负责监听brokers节点变化&#xff0c;当有分区的leader挂掉时&#xff0c;contro…

Python基本数据类型之散列类型详解

前言&#xff1a; python的基本数据类型可以分为三类&#xff1a;数值类型、序列类型、散列类型&#xff0c;本文主要介绍散列类型。 一、散列类型 散列类型&#xff1a;内部元素无序&#xff0c;不能通过下标取值 1&#xff09;字典&#xff08;dict&#xff09;&#xff…

20230803激活手机realme GT Neo3

20230803激活手机realme GT Neo3 缘起&#xff1a; 新买的手机&#xff1a;realme GT Neo3 需要确认&#xff1a; 1、4K录像&#xff0c;时间不限制。 【以前的很多手机都是限制8/10/12/16分钟】 2、通话自动录音 3、定时开关机。 4、GPS记录轨迹不要拉直线&#xff1a;户外助…

1345:香甜的黄油(Dijkstra)---信息学奥赛一本通

【题目描述】 农夫John发现做出全威斯康辛州最甜的黄油的方法&#xff1a;糖。把糖放在一片牧场上&#xff0c;他知道N&#xff08;1≤N≤500&#xff09;只奶牛会过来舔它&#xff0c;这样就能做出能卖好价钱的超甜黄油。当然&#xff0c;他将付出额外的费用在奶牛上。 农夫Jo…

【秋招】算法岗的八股文之机器学习

目录 机器学习特征工程常见的计算模型总览线性回归模型与逻辑回归模型线性回归模型逻辑回归模型区别 朴素贝叶斯分类器模型 (Naive Bayes)决策树模型随机森林模型支持向量机模型 (Support Vector Machine)K近邻模型神经网络模型卷积神经网络&#xff08;CNN&#xff09;循环神经…

【css】css实现一个简单的按钮

四种链接状态分别是&#xff1a; a:link - 正常的&#xff0c;未访问的链接a:visited - 用户访问过的链接a:hover - 用户将鼠标悬停在链接上时a:active - 链接被点击时 <style> a:link, a:visited {//未访问、访问过background-color: #07c160;//设置背景颜色color: wh…

MFC、Qt、WPF?该用哪个?

MFC、Qt和WPF都是流行的框架和工具&#xff0c;用于开发图形用户界面&#xff08;GUI&#xff09;应用程序。选择哪个框架取决于你的具体需求和偏好。MFC&#xff08;Microsoft Foundation Class&#xff09;是微软提供的框架&#xff0c;使用C编写&#xff0c;主要用于Windows…

牛客网Verilog刷题——VL47

牛客网Verilog刷题——VL47 题目答案 题目 实现4bit位宽的格雷码计数器。 电路的接口如下图所示&#xff1a; 输入输出描述&#xff1a; 信号类型输入/输出位宽描述clkwireIntput1时钟信号rst_nwireIntput1异步复位信号&#xff0c;低电平有效gray_outregOutput4输出格雷码计数…

uni-app:实现表格多选及数据获取

效果&#xff1a; 代码&#xff1a; <template><view><scroll-view scroll-x"true" style"overflow-x: scroll; white-space: nowrap;"><view class"table"><view class"table-tr"><view class&quo…

LeetCode--剑指Offer75(1)

目录 题目描述&#xff1a;剑指 Offer 05. 替换空格&#xff08;简单&#xff09;题目接口解题思路1代码解题思路2代码 PS: 题目描述&#xff1a;剑指 Offer 05. 替换空格&#xff08;简单&#xff09; 请实现一个函数&#xff0c;把字符串 s 中的每个空格替换成"%20&quo…

数字电路(一)

1、例题 1、进行DA数模转换器选型时&#xff0c;一般要选择主要参数有&#xff08; A&#xff09;、转换精度和转换速度。 A、分辨率 B、输出电流 C、输出电阻 D、模拟开关 2、下图所示电路的逻辑功能为&#xff08; B&#xff09; A、与门 B、或门 C、与非门 D、或非门 分析该…

Nodejs中的全局对象

今天我们将探讨Nodejs中的全局对象&#xff0c;这是Nodejs中重要且有趣的知识点。我们将通过生动形象的例子和风趣的风格来深入理解这些概念&#xff0c;并比较Nodejs中的全局对象与前端JavaScript中的全局对象之间的异同点。 全局对象是什么&#xff1f; 在Nodejs环境中&…

IO进程线程day6(2023.8.3)

一、Xmind整理&#xff1a; 进程与线程关系&#xff1a; 二、课上练习&#xff1a; 练习1&#xff1a;pthread_create 功能&#xff1a;创建一个线程 原型&#xff1a; #include <pthread.h> int pthread_create(pthread_t *thread, const pthread_attr_t *attr, vo…

【Unity学习笔记】生命周期

文章目录 脚本的生命周期初始化更新顺序动画更新循环各类事件结束阶段 阶段分析协程返回 总结 官方文档&#xff1a;事件函数的执行顺序 脚本的生命周期 如图&#xff1a; 脚本的生命周期主要经历以下几个阶段&#xff1a; 初始化 初始化阶段&#xff0c;&#xff08;包括初…

JVM之内存结构

1.程序计数器 定义&#xff1a;程序计数器&#xff08;Program Counter Register&#xff09;是JVM中一块较小的内存空间。解释器在解释JVM指令为机器码以供CPU执行时&#xff0c;会去程序计数器当中找到jvm指令的执行地址。 作用&#xff1a;记住下一条jvm指令的执行地址 特…

机器学习笔记 - 使用 YOLOv5、O​​penCV、Python 和 C++ 检测物体

一、YOLO v5简述 YOLO v5虽然已经不是最先进的对象检测器,但是YOLOv5 使用了一个简单的卷积神经网络 CNN架构(相对YOLO v8来讲,不过v8精度是更高了一些),更易理解。这里主要介绍如何轻松使用 YOLO v5来识别图像中的对象。将使用 OpenCV、Python 和 C++ 来加载和调用我们的…

CPU缓存那些事儿

CPU缓存那些事儿 CPU高速缓存集成于CPU的内部&#xff0c;其是CPU可以高效运行的成分之一&#xff0c;本文围绕下面三个话题来讲解CPU缓存的作用&#xff1a; 为什么需要高速缓存&#xff1f;高速缓存的内部结构是怎样的&#xff1f;如何利用好cache&#xff0c;优化代码执行…

Go学习第三天

map的三种声明定义方式 声明map后&#xff0c;一定要make开辟空间&#xff0c;否则会报越界且不能使用 package mainimport "fmt"func main() {// 第一种声明方式// 声明myMap1是一种map类型 key是string value是stringvar myMap1 map[string]string// 判断一下map在…