生成对抗网络DCGAN学习实践

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# 定义生成器模型
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# 定义判别器模型
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()

# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output):
    return loss_fn(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = loss_fn(tf.ones_like(real_output), real_output)
    fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# 定义训练循环
@tf.function
def train_step(images):
    # 生成噪声向量
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 使用生成器生成假图片
        generated_images = generator(noise, training=True)

        # 使用判别器判断真假图片
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # 计算损失函数
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 计算梯度并更新生成器和判别器的参数
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch, test_input):

    predictions = model(test_input, training=False)
    print("predictions.shape:", predictions.shape)
    num_images = predictions.shape[0]
    rows = int(num_images ** 0.5) # 计算行数
    cols = num_images // rows # 计算列数
    
    fig = plt.figure(figsize=(8, 8))
    
    for i in range(num_images):
        plt.subplot(rows, cols, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()

# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5

# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50

# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    for i,image_batch in enumerate(dataset):
        print("sub i",i)
        train_step(image_batch)

    print("------------------------------------------------------epoch:", epoch)

    # 每个 epoch 结束后生成并保存一组图像
    if (epoch + 1) % 5 == 0:
        seed = tf.random.normal([BATCH_SIZE, 100])
        generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/53711.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Stable Diffusion 开源模型 SDXL 1.0 发布

关于 SDXL 模型,之前写过两篇: Stable Diffusion即将发布全新版本Stable Diffusion XL 带来哪些新东西? 一晃四个月的时间过去了,Stability AI 团队终于发布了 SDXL 1.0。当然在这中间发布过几个中间版本,分别是 SDXL …

xshell连接liunx服务器身份验证不能选择password

ssh用户身份验证不能选择password 只能用public key的解决办法 问题现象 使用密码通过Workbench或SSH方式(例如PuTTY、Xshell、SecureCRT等)远程登录ECS实例时,遇到服务器禁用了密码登录方式错误. 可能原因 该问题是由于SSH服务对应配置文件/etc/ssh/sshd_config中…

4年测试“我“该何去何从?测试还是测试开发?

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 4年测试&#xff…

【IDEA】idea不自动生成target

文章目录 1. 不生成target2. 仅部分文件不生成target2.1. 一般原因就是资源没有设置2.2. 配置编译src/main/java文件夹下的资源文件2.3. 清理缓存(王炸) 3. 参考资料 本文描述idea不生成target的几种情况以及处理方法 1. 不生成target 像下图这样根本就…

Leetcode-每日一题【剑指 Offer 56 - I. 数组中数字出现的次数】

题目 一个整型数组 nums 里除两个数字之外,其他数字都出现了两次。请写程序找出这两个只出现一次的数字。要求时间复杂度是O(n),空间复杂度是O(1)。 示例 1: 输入:nums [4,1,4,6]输出:[1,6] 或 [6,1] 示例 2&#x…

NODEJS笔记

全局对象 global/window console.log/info/warn/error/time/timeEnd process.arch/platform/version/env/kill/pid/nextTick Buffer.alloc(5,abcde) String/toString setTimeout/clearTimeout setInterval/clearInterval setImmediate/clearImmediate process.nextTi…

python包的介绍使用

python包的介绍使用 简单来说python的模块相当于文件,包就相当于文件夹 python包创建后会自动生成 init.py 的文件 然后可以在不同的包下面创建不同的模块 下面是引入模块里面的内容的三种方式 第一种就是引入模块,记住引入包是会报错的 import只能引…

【我们一起60天准备考研算法面试(大全)-第三十天 30/60】【矩阵翻转】【矩阵相乘】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…

Eureka 学习笔记4:EurekaClient

版本 awsVersion ‘1.11.277’ EurekaClient 接口实现了 LookupService 接口&#xff0c;拥有唯一的实现类 DiscoveryClient 类。 LookupService 接口提供以下功能&#xff1a; 获取注册表根据应用名称获取应用根据实例 id 获取实例信息 public interface LookupService<…

【C++】反向迭代器的模拟实现通用(可运用于vector,string,list等模拟容器)

文章目录 前言一、反向迭代器封装&#xff08;reverseiterator&#xff09;1.构造函数1解引用操作.3.->运算符重载4.前置&#xff0c;后置5.前置--&#xff0c;后置--6.不等号运算符重载7.完整代码 二、rbegin&#xff08;&#xff09;以及rend&#xff08;&#xff09;1.rb…

ADSelfService Plus:保护密码安全的最佳解决方案

密码安全是当今数字时代中至关重要的话题。随着互联网和信息技术的迅速发展&#xff0c;我们的生活变得越来越数字化&#xff0c;密码已成为我们生活中不可或缺的一部分。然而&#xff0c;随着各种网络威胁和黑客攻击不断增加&#xff0c;保护我们的密码变得越来越重要。 密码安…

linux基础学习

1.day1 1、修改虚拟机的网络&#xff1b; sudo vim /etc/netplan/*.yaml sudo netplan apply 2.day2 1、VIM配置&#xff1b; 2、安装SSH&#xff0c;调用putty接入终端&#xff1b; 3、shell命令&#xff1b; *&#xff1a;匹配任意长度的字符 &#xff1f;&#xff1a;匹…

为什么VPS是中小型企业的理想选择?

对于中小型企业来说&#xff0c;选择适合自身业务需求的托管方案至关重要。在如今数字化时代&#xff0c;VPS作为一种灵活、高性能的托管解决方案&#xff0c;成为中小型企业的理想选择。作为动态VPS代理产品供应商&#xff0c;我们深知一个高质量、高性能的VPS托管服务&#x…

使用IDEA打jar包的详细图文教程

1. 点击intellij idea左上角的“File”菜单 -> Project Structure 2. 点击"Artifacts" -> 绿色的"" -> “JAR” -> Empty 3. Name栏填入自定义的名字&#xff0c;Output ditectory 选择 jar 包目标目录&#xff0c;Available Elements 里右击…

信息安全:网络安全体系 与 网络安全模型.

信息安全&#xff1a;网络安全体系 与 网络安全模型. 网络安全保障是一项复杂的系统工程&#xff0c;是安全策略、多种技术、管理方法和人员安全素质的综合。一般而言&#xff0c;网络安全体系是网络安全保障系统的最高层概念抽象&#xff0c;是由各种网络安全单元按照一定的规…

Failed to load local font resource:微信小程序加载第三方字体

加载本地字体.ttf 将ttf转换为base64格式&#xff1a;https://transfonter.org/ 步骤如下 将下载后的stylesheet.css 里的font-family属性名字改一下&#xff0c;然后引进页面里就行了&#xff0c;全局样式就放app.scss&#xff0c;单页面就引入单页面 注&#xff1a; .title…

12-3_Qt 5.9 C++开发指南_创建和使用静态链接库

第12章中的静态链接库和动态链接库介绍&#xff0c;都是以UI操作的方式进行&#xff0c;真正在实践中&#xff0c;可以参考UI操作产生的代码来实现同样的功能。 文章目录 1. 创建静态链接库1.1 创建静态链接库过程1.2 静态链接库代码1.2.1 静态链接库可视化UI设计框架1.2.2 qw…

关于element ui 安装失败的问题解决方法、查看是否安装成功及如何引入

Vue2引入 执行npm i element-ui -S报错 原因&#xff1a;npm版本太高 报错信息&#xff1a; 解决办法&#xff1a; 使用命令&#xff1a; npm install --legacy-peer-deps element-ui --save 引入&#xff1a; 在main.js文件中引入 //引入Vue import Vue from vue; //引入…

二阶阻尼弹簧系统的simulink仿真(s函数)

文章目录 前言一.非线性反步法1.原系统对应的s函数脚本文件&#xff08;仅修改模板的初始化函数、导数函数和输出函数三个部分&#xff09;2.控制器对应的s函数脚本文件&#xff08;仅修改模板的初始化函数和输出函数两个部分&#xff09;3.其他参数脚本文件4.输入5.输出&#…

成功拿下25K测试岗,原来字节真的很好进

朋友&#xff0c;作为一个曾经从机械转行到IT的行业的过来人&#xff0c;已在IT行业工作4年&#xff0c;分享一下我的经验&#xff0c;供你参考。 讲真&#xff0c;现在想通过培训班培训几个月就进入IT行业&#xff0c;越来越来难了&#xff1b;如果是在2018年以前&#xff0c;…