AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

AIGC实战——条件生成对抗网络

    • 0. 前言
    • 1. CGAN架构
    • 2. 模型训练
    • 3. CGAN 分析
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控制模型生成男性或女性的面部图像;我们可以从潜空间中随机采样一个点,但是不能预知给定潜变量能够生成什么样的图像。在本节中,我们将构建一个能够控制输出的 GAN,即条件生成对抗网络 (Conditional Generative Adversarial Net, GAN)。该模型最早由 MirzaOsindero2014 年提出,是对 GAN 架构的简单改进。

1. CGAN架构

在节中,我们将使用面部数据集中的头发颜色属性来设置 CGAN 的条件。也就是说,我们将能够明确指定是否要生成带有金发的图像。头发颜色标签作为 CelebA 数据集的一部分已在数据集中提供,CGAN 的架构如下图所示。

CGAN 架构

标准 GANCGAN 之间的关键区别在于:在 CGAN 中,我们需要向生成器和判别器传递与标签相关的额外信息。在生成器中,标签信息转化为独热编码 (one-hot) 向量后附加在潜空间样本之后。在判别器中,通过重复独热编码向量填充得到与输入图像相同形状的通道,将标签信息添加为 RGB 图像的额外通道。
CGAN 之所以能够生成指定类型的图像,是因为其判别器可以获得关于图像内容的额外信息,因此生成器必须确保其输出与提供的标签一致,以继续欺骗判别器。如果生成器生成了与图像标签不一致的图像,即使图像非常逼真,判别器会将它们判定为伪造图像,因为图像和标签并不匹配。
在本节所构建的 CGAN 中,因为有两个类别(金发和非金发),独热编码标签的长度是 2。但是,我们也可以根据需要拥有使用多个标签。例如,在 Fashion-MNIST 数据集上训练 CGAN 时,为了输出 10 种不同类型的 Fashion-MNIST 图像,可以通过将长度为 10 的独热编码标签向量并入生成器的输入,并将 10 个额外的独热编码标签通道并入判别器的输入。
综上,我们需要对标准 GAN 架构所进行的修改是,将标签信息与生成器和判别器的现有输入连接起来:

# 图像通道和标签通道分别传递给判别器,并进行连接
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
label_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CLASSES))
x = layers.Concatenate(axis=-1)([critic_input, label_input])
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)

critic = models.Model([critic_input, label_input], critic_output)
print(critic.summary())
# 潜向量和标签类别分别传递给生成器,并在调整形状之前进行连接
generator_input = layers.Input(shape=(Z_DIM,))
label_input = layers.Input(shape=(CLASSES,))
x = layers.Concatenate(axis=-1)([generator_input, label_input])
x = layers.Reshape((1, 1, Z_DIM + CLASSES))(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
generator_output = layers.Conv2DTranspose(
    CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model([generator_input, label_input], generator_output)
print(generator.summary())

2. 模型训练

调整 CGANtrain_step 方法,以令生成器和判别器适应新的输入格式:

    def train_step(self, data):
        # 从数据集中提取图像和标签
        real_images, one_hot_labels = data
        # 将独热编码向量扩展为具有与输入图像相同空间尺寸 (64×64) 的独热编码图像
        image_one_hot_labels = one_hot_labels[:, None, None, :]
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=1)
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=2)

        batch_size = tf.shape(real_images)[0]

        for i in range(self.critic_steps):
            random_latent_vectors = tf.random.normal( shape=(batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                # 生成器接受包含两个输入的列表——随机潜向量和独热编码的标签向量
                fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)
                # 判别器接受包含两个输入的列表——真实/生成图像和独热编码的标签通道
                fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)
                real_predictions = self.critic([real_images, image_one_hot_labels], training=True)

                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images, image_one_hot_labels)
                # 梯度惩罚函数还需要通过独热编码的标签通道传递(由于其流经判别器)
                c_loss = c_wass_loss + c_gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))

        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        with tf.GradientTape() as tape:
            # 生成器训练过程的修改与判别器训练步骤的修改相同
            fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)
            fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))

        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)
        return {m.name: m.result() for m in self.metrics}

3. CGAN 分析

我们可以通过将特定的独热编码标签传递到生成器的输入中来控制 CGAN 的输出。例如,要生成一张非金发的人脸图像,我们传入向量 [1, 0];要生成一张金发的人脸图像,我们传入向量 [0, 1]
CGAN 的输出如下图所示。可以看到,在保持随机潜向量不变的情况下,只改变条件标签向量,显然 CGAN 已经学会使用标签向量来控制图像的头发颜色属性,且图像的其余部分几乎没有改变。这证明了 GAN 能够以这种方式组织潜空间中的点,使得各个特征可以相互解耦。

生成结果

如果数据集中有标签可用,即使不一定需要将生成的输出与标签相关联,将它们作为 GAN 的输入通常也可以提高生成图像的质量,我们可以把标签看作是像素输入的信息扩展。

小结

在本节中,构建了一个条件生成对抗网络 (Conditional Generative Adversarial Net, CGAN),通过将标签作为输入传递给判别器和生成器,能够生成可控类别的图像,这是由于标签为网络提供了额外的信息,以便使生成的输出与给定的标签相关联。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/254047.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Nat Med | Tau靶向反义寡核苷酸

今天给同学们分享一篇实验文章“Tau-targeting antisense oligonucleotide MAPTRx in mild Alzheimers disease: a phase 1b, randomized, placebo-controlled trial”,这篇文章发表在Nat Med期刊上,影响因子为82.9。 结果解读: 患者 从201…

防止反编译,保护你的SpringBoot项目

ClassFinal-maven-plugin插件是一个用于加密Java字节码的工具,它能够保护你的Spring Boot项目中的源代码和配置文件不被非法获取或篡改。下面是如何使用这个插件来加密test.jar包的详细步骤: 安装并设置Maven: 首先确保你已经在你的开发环境中…

关于#c语言#的问题:分析递归调用的过程◇画出调用过程各语句执行过程

关于#c语言#的问题&#xff1a;分析递归调用的过程◇画出调用过程各语句执行过程 当涉及到递归调用的过程时&#xff0c;可以通过绘制函数调用栈来分析和理解递归的执行过程。下面是一个示例的C语言递归函数和相应的调用过程&#xff1a; #include <stdio.h>void recurs…

数据结构 AVL树概念以及实现插入的功能(含Java代码实现)

为啥要有avl树 avl树是在二叉搜索树下的一种进阶形式,是为了防止二叉搜索树在极端情况下产生的链表化的场景,从而在二叉搜索树的基础上,加上了某些条件来阻止这种极端情况的产生,但不是保证完全平衡,而是放开了一定的条件,使得这种情况不那么难以满足.(条件:左右子树的高度差的…

【【UART 传输数据实验】】

UART 传输数据实验 通信方式在日常的应用中一般分为串行通信&#xff08;serial communication&#xff09;和并行通信&#xff08;parallel communication&#xff09;。 我们再来了解下串行通信的特点。串行通信是指数据在一条数据线上&#xff0c;一比特接一比特地按顺序传…

Python 爬虫开发完整环境部署,爬虫核心框架安装

Python 爬虫开发完整环境部署 前言&#xff1a; ​ 关于本篇笔记&#xff0c;参考书籍为 《Python 爬虫开发实战3 》 笔记做出来的一方原因是为了自己对 Python 爬虫加深认知&#xff0c;一方面也想为大家解决在爬虫技术区的一些问题&#xff0c;本篇文章所使用的环境为&#x…

Esxi中的AlmaLinux硬盘扩容

Esxi中的AlmaLinux硬盘扩容 通过本文能学习到 虚拟机中的AlmaLinux硬盘扩容 本文主要包括3部分内容&#xff1a; 1. 需要进行扩容的原因 2. 写这篇文章的目的 3. 扩容实操需要进行扩容的原因 近日&#xff0c;使用Jenkins部署时&#xff0c;出现镜像向Nexus私服推送镜像时…

展示一段比较简单的人工智能自动做模型的程序

人工智能是一种模拟或模仿人类智能的技术。它通过使计算机系统具有一定的认知能力和学习能力&#xff0c;使其能够自动完成一系列复杂的任务。人工智能可以在各个领域应用&#xff0c;包括图像识别、语音识别、自然语言处理、机器学习等。人工智能还可以用于解决各种问题&#…

互联网加竞赛 python 机器视觉 车牌识别 - opencv 深度学习 机器学习

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于python 机器视觉 的车牌识别系统 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;3分 &#x1f9ff; 更多资…

高通平台开发系列讲解(AI篇)SNPE工作流程介绍

文章目录 一、转换网络模型二、量化2.1、选择量化或非量化模型2.2、使用离线TensorFlow或Caffe模型2.3、使用非量化DLC初始化SNPE2.4、使用量化DLC初始化SNPE三、准备输入数据四、运行加载网络沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇章主要介绍SNPE模型工作…

Spring01

一、Spring概述 自 2004年 4 月&#xff0c;Spring 1.0 版本正式发布以来&#xff0c;Spring 已经步入到了第 5 个大版本&#xff0c;也就是我们常说的 Spring 5。 Spring的基础是Spring Framework&#xff0c;其功能有&#xff1a; 1、IoC (控制反转)&#xff0c;Spring 两大…

鸿蒙 Ark ui 实战登录界面请求网络实现教程

团队介绍 作者&#xff1a;徐庆 团队&#xff1a;坚果派 公众号&#xff1a;“大前端之旅” 润开鸿生态技术专家&#xff0c;华为HDE&#xff0c;CSDN博客专家&#xff0c;CSDN超级个体&#xff0c;CSDN特邀嘉宾&#xff0c;InfoQ签约作者&#xff0c;OpenHarmony布道师&…

18个非技术面试题

请你自我介绍一下你自己&#xff1f; 这道面试题是大家在以后面试过程中会常被问到的&#xff0c;那么我们被问到之后&#xff0c;该如果回答呢&#xff1f;是说姓名&#xff1f;年龄&#xff1f;还是其他什么&#xff1f; 最佳回答提示&#xff1a; 一般人回答这个问题往往会…

接口测试--参数实现MD5加密签名规则

最近有个测试接口需求&#xff0c;接口有签名检查&#xff0c;签名规范为将所有请求参数按照key字典排序并连接起来进行md5加密&#xff0c;格式是&#xff1a;md5(bar2&baz3&foo1),得到签名&#xff0c;将签名追加到参数末尾。由于需要对参数进行动态加密并且做压力测…

基于Python实现的一个书法字体风格识别器源码,通过输入图片,识别出图片中的书法字体风格,采用Tkinter实现GUI界面

项目描述 本项目是一个书法字体风格识别器&#xff0c;通过输入图片&#xff0c;识别出图片中的书法字体风格。项目包含以下文件&#xff1a; 0_setting.yaml&#xff1a;配置文件&#xff0c;包含书法字体风格列表、图片调整大小的目标尺寸等设置。1_Xy.py&#xff1a;预处理…

vue3 插槽slot

插槽是子组件中的提供给父组件使用的一个占位符&#xff0c;用 <slot> 表示&#xff0c;父组件可以在这个占位符中填充任何模板代码&#xff0c;如 HTML、组件等&#xff0c;填充的内容会替换子组件的<slot> 元素。<slot> 元素是一个插槽出口 (slot outlet)&…

Web前端-HTML(初识)

文章目录 1.认识WEB1.1 认识网页&#xff0c;网站1.2 思考 2. 浏览器&#xff08;了解&#xff09;2.1 五大浏览器2.2 查看浏览器占有的市场份额 3. Web标准&#xff08;重点&#xff09;3.1 Web 标准构成结构表现行为 1.认识WEB 1.1 认识网页&#xff0c;网站 网页主要由文字…

ADB:获取坐标

命令&#xff1a; adb shell getevent | grep -e "0035" -e "0036" adb shell getevent -l | grep -e "0035" -e "0036" 这一条正确&#xff0c;但是&#xff0c;grep给过滤了&#xff0c;导致没有输出 getevent -c 10 //输出10条信息…

Linux---远程登录、远程拷贝命令

1. 远程登录、远程拷贝命令的介绍 命令说明ssh远程登录scp远程拷贝 2. ssh命令的使用 ssh是专门为远程登录提供的一个安全性协议&#xff0c;常用于远程登录&#xff0c;想要使用ssh服务&#xff0c;需要安装相应的服务端和客户端软件&#xff0c;当软件安装成功以后就可以使…

雷士、书客、松下护眼台灯怎么样?多方位测评对比爆料!

灯光对于我们而言&#xff0c;重要性不言而喻&#xff0c;不管是办公、休闲&#xff0c;还是学习阅读&#xff0c;都离不开它的存在&#xff0c;尤其的在夜晚的时候。所以一般来说都会准备一盏桌面照明的台灯&#xff0c;目前而言最受欢迎的台灯就是护眼台灯了&#xff0c;因为…