CGAN——生成0-9数字图像(Tensorflow+mnist)

1、简介

  • 传统的GAN或者其他的GAN都是通过一堆的训练数据,最后训练出了生成网络,随机输入噪声最后产生的数据是这些训练数据类别中之一,无法提前预测生成的是哪个类别。
  • 如果需要定向指定生成某些数据,比如想生成飞机,数字9等指定类别的图片,就需要利用CGAN——条件生成对抗网络
  • 本文利用CGAN,输入带有标签的数字图像,训练后,再生成对应标签的图像。
  • 数据集:mnist
  • 框架:tensorflow

2、代码

  • import numpy as np
    from keras.models import Sequential, Model
    from keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape
    from keras.layers import Input, Embedding, Flatten, multiply, Dropout
    from keras.datasets import mnist
    from keras.optimizers import Adam
    import matplotlib.pyplot as plt
    import matplotlib
    
    
    # 条件对抗生成网络
    class CGAN():
        def __init__(self):
            # 写入输入维度
            self.img_rows = 28  # 行
            self.img_cols = 28  # 列
            self.img_channel = 1  # 通道数
            self.img_shape = (self.img_rows, self.img_cols, self.img_channel)  # 尺寸
    
            self.num_classes = 10  # 类别数
            self.latent_dim = 100  # 噪声大小
    
            optimizer = Adam(0.0002, 0.5)  # 优化器,学习率0.0002
    
            self.generator = self.build_generator()  # 构建生成器
            self.discriminator = self.build_discriminator()  # 构建判别器
            # 判别器训练的配置
            self.discriminator.compile(loss=['binary_crossentropy'],  # 二进制交叉熵损失函数
                                       optimizer=optimizer,
                                       metrics=['accuracy'])
    
            # 联合训练,固定判别器
            self.discriminator.trainable = False
            noise = Input(shape=(100,))
            label = Input(shape=(1,))
            img = self.generator([noise, label])  # 生成的图像
            valid = self.discriminator([img, label])  # 判别生成的图像
            self.combined = Model([noise, label], valid)
            self.combined.compile(loss=['binary_crossentropy'],  # 二进制交叉熵损失函数
                                  optimizer=optimizer,
                                  metrics=['accuracy'])
    
        # 生成器
        def build_generator(self):
            model = Sequential()  # 定义网络层
    
            # 第一层
            model.add(Dense(256, input_dim=self.latent_dim))  # 全连接层,256个神经元,输入维度100
            model.add(LeakyReLU(alpha=0.2))  # 激活层
            model.add(BatchNormalization(momentum=0.8))  # BN层,动量0.8
    
            # 第二层
            model.add(Dense(512))  # 全连接层
            model.add(LeakyReLU(alpha=0.2))  # 激活层
            model.add(BatchNormalization(momentum=0.8))  # BN层,动量0.8
    
            # 第三层
            model.add(Dense(1024))  # 全连接层
            model.add(LeakyReLU(alpha=0.2))  # 激活层
            model.add(BatchNormalization(momentum=0.8))  # BN层,动量0.8
    
            # 输出层
            model.add(Dense(np.prod(self.img_shape), activation='tanh'))  # 计算图像尺寸,激活函数tanh
            model.add(Reshape(self.img_shape))  # Reshape层,输入的是噪声,需要的是图像,转换为图像
    
            model.summary()  # 记录参数情况
    
            # 定义输入
            noise = Input(shape=(self.latent_dim,))  # 生成器的输入维度
            label = Input(shape=(1,), dtype='int32')  # 生成器的标签维度,1维,类型int
    
            # 使输入Y和X的维度一致。将10个种类的label映射到latent_dim维度
            label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))  # 输入维度,输出维度,转换的变量label
            # Flatten() 将100维转化为(None, 100),这里None会随着batch而改变
    
            # 合并噪声和类别
            model_input = multiply([noise, label_embedding])  # 合并方法:对应位置相乘
    
            # 预测模型输出
            img = model(model_input)  # 生成图片
    
            return Model([noise, label], img)  # [输入],输出。输入按noise和label,合并由内部完成
    
        # 判别器
        def build_discriminator(self):
            model = Sequential()  # 定义网络层
    
            # 第一层
            model.add(Dense(512, input_dim=np.prod(self.img_shape)))  # 全连接层,512个神经元,输入维度784
            model.add(LeakyReLU(alpha=0.2))  # 激活层
    
            # 第二层
            model.add(Dense(512))  # 全连接层
            model.add(LeakyReLU(alpha=0.2))  # 激活层
            model.add(Dropout(0.4))  # Dropout层,防止过拟合,提高泛化性
    
            # 第三层
            model.add(Dense(512))  # 全连接层
            model.add(LeakyReLU(alpha=0.2))  # 激活层
            model.add(Dropout(0.4))  # Dropout层,防止过拟合,提高泛化性
    
            # 输出层
            model.add(Dense(1, activation='sigmoid'))
    
            model.summary()  # 记录参数情况
    
            # 定义输入
            img = Input(shape=self.img_shape)  # 输入图片
            label = Input(shape=(1,), dtype='int32')  # 输入标签
    
            # 使输入Y和X的维度一致。Flatten() 将100维转化为(None, 784),这里None会随着batch而改变
            label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))  # 输入维度,输出维度,转换的变量label
            flat_img = Flatten()(img)
    
            # 将图片和类别合并
            model_input = multiply([flat_img, label_embedding])  # 合并方法:对应位置相乘
    
            # 模型输出结果
            validity = model(model_input)  # 获取输出概率结果
    
            return Model([img, label], validity)  # [输入],输出
    
        # 训练
        def train(self, epochs, batch_size=128, sample_interval=50):
            # 获取数据集
            (X_train, Y_train,), (_, _) = mnist.load_data()  # 下载数据集,空的表示不要测试集
    
            # 将获取的图像转化为-1到1
            X_train = (X_train.astype(np.float32) - 127.5) / 127.5
            X_train = np.expand_dims(X_train, axis=3)  # 扩展维度,在第三维扩展。将60000*28*28的图片扩展为60000*28*28*1
    
            # 将标签大小变为60000*1
            Y_train = Y_train.reshape(-1, 1)  # -1自动计算第0维的维度空间数
    
            # 写入 真实输出 与 虚假输出
            valid = np.ones((batch_size, 1))  # 每行为一张图片
            fake = np.zeros((batch_size, 1))  # 每行为一张图片
            # imgs shape(batch_size, 28, 281)
            # labels shape(32, 1)
    
            for epoch in range(epochs):
                # 训练判别器
                # 从0~60000随机获取batch_size个索引数
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs, labels = X_train[idx], Y_train[idx]  # 获取图像和对应标签
    
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))  # 产生随机噪声
    
                gen_imgs = self.generator.predict([noise, labels])  # 生成虚假图片
    
                # 损失
                d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
                d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
                # 训练生成器
                sample_label = np.random.randint(0, 10, batch_size).reshape(-1, 1)  # 随机生成样本标签
    
                # 固定判别器,训练生成器——在联合模型中
                g_loss = self.combined.train_on_batch([noise, sample_label], valid)  # 生成器随机生成的图像和随机产生的标签,骗过判别器
    
                # 绘制进度图
                print("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], d_loss[1] * 100, g_loss[0]))
    
                # 每50次保存图像
                if (epoch + 1) % sample_interval == 0:
                    self.sample_images(epoch)
    
                # 每训练5000次保存模型
                if (epoch + 1) % 5000 == 0:
                    self.save_models(epoch)
    
        def sample_images(self, epoch):
            r, c = 2, 5  # 输出 2行5列的10张指定图像
            noise = np.random.normal(0, 1, (r * c, 100))
            sampled_labels = np.arange(0, 10).reshape(-1, 1)
    
            gen_imgs = self.generator.predict([noise, sampled_labels])
    
            # Rescale images -1
            gen_imgs = 0.5 * gen_imgs + 0.5
            fig, axs = plt.subplots(r, c)
            cnt = 0
            for i in range(r):
                for j in range(c):
                    axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                    axs[i, j].set_title("Digit: %d" % sampled_labels[cnt])
                    axs[i, j].axis('off')
                    cnt += 1
            fig.savefig(f"images/sd{epoch+1}.png")  # 文件路径和代码文件同目录
            plt.close()
    
        def save_models(self, epoch):
            self.generator.save(f"models/generator_epoch_{epoch+1}.h5")
            self.discriminator.save(f"models/discriminator_epoch_{epoch+1}.h5")
            self.combined.save(f"models/combined_epoch_{epoch+1}.h5")
    
    
    if __name__ == '__main__':
        matplotlib.use('TkAgg')  # 设置后端为TkAgg
        cgan = CGAN()
        # 训练轮数20000,一次处理32张图片,每200保存一次生成的已知标签的生成图像
        cgan.train(epochs=20000, batch_size=32, sample_interval=200)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/460095.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

云计算 3月14号 (TCP三次握手和四次挥手)

1.TCP三次握手和四次挥手 1.TCP的传输过程: Seq 序列号 保障传输过程可靠。 ACK (确认消息) SYN (在建立TCP连接的时候使用) FIN (在关闭TCP连接的时候使用) 3.TCP建立连接的过程&…

ES解析word内容为空的问题和直接使用Tika解析文档的方案

导言 在上一篇文章最后,我们虽然跑通了ES文件搜索的全部流程,但是仍然出现了1个大的问题:ES7.3实测无法索引docx和doc文档,content有值但是无法解析到附件成为可读的可搜索的内容,附件内容为空(附件中根本…

Microsoft Remote Desktop Mac

Microsoft Remote Desktop是一款功能强大的远程连接工具,允许用户从远程位置连接到另一台计算机,实现跨设备的无缝协作。无论是在不同的设备之间共享文件、应用程序和其他资源,还是远程访问工作站和服务器,Microsoft Remote Deskt…

Unity开发一个FPS游戏之二

在之前的文章中,我介绍了如何开发一个FPS游戏,添加一个第一人称的主角,并设置武器。现在我将继续完善这个游戏,打算添加敌人,实现其智能寻找玩家并进行对抗。完成的效果如下: fps_enemy_demo 下载资源 首先是设计敌人,我们可以在网上找到一些好的免费素材,例如在Unity…

人机交互三原则,网络7层和对应的设备、公钥私钥

人机交互三原则 heo Mandel提出了人机交互的三个黄金原则,它们强调了相似的设计目标,分别是: 简单总结为:控负持面–>空腹吃面 1,用户控制 2,减轻负担 3,保持界面一致 置用户于控制之下&a…

DHCP在企业网的部署及安全防范

学习目标: 1. DHCP能够解决什么问题? 2. DHCP服务器如何部署? 3. 私接设备会带来什么问题以及如何防范? 给DHCP服务器配置地址: 地址池: DHCP有2种分配模式:全局分配和接口分配 DHCP enable

Upload-labs靶场

文件漏洞上传进行复现 环境搭建--->搭建好环境如下: 打开第一关,尝试文件上传漏洞 根据界面提示,选择一个文件(.php文件)进行上传,发现无法上传 根据提示是指使用js对不合法文件进行了检查,…

传输层的UDP协议

1. UDP协议报文格式 1.1 16位端口号 UDP协议报文中,端口号占2个字节,包括 源端口号 和 目的端口号。 1.2 16位UDP长度 UDP报文长度为2个字节 ,即UDP数据报长度为0~65535,也就是64kb。 1.3 16位UDP检验和 数据在网络传输的…

Python图像处理指南:PIL与OpenCV的比较【第135篇—PIL】

Python图像处理指南:PIL与OpenCV的比较 图像处理在计算机视觉和图像识别等领域中扮演着至关重要的角色。Python作为一种功能强大且易于学习的编程语言,提供了多种库供图像处理使用。在本文中,我们将比较两个最流行的Python图像处理库&#x…

基于正点原子潘多拉STM32L496开发板的简易示波器

一、前言 由于需要对ADC采样性能的评估,重点在于对原波形的拟合性能。 考虑到数据的直观性,本来计划采集后使用串口导出,并用图形做数据拟合,但是这样做的效率低下,不符合实时观察的需要,于是将开发板的屏幕…

云计算2主从数据库

设置主从数据库的目的是将数据库1和数据库2分别建在两个虚拟机上,并实现数据互通访问 首先准备两个虚拟机,这里示例ip分别为: 192.168.200.10;192.168.200.20 修改主机名,一个是mysql1,一个是mysql2&#x…

【每日一问】手机如何开启USB调试?

一、背景 当电脑跟手机之间需要进行交互的时候,可以考虑使用usb进行连接。那么手机如何开启USB调试呢? 二、操作步骤: 思路: 步骤1:手机开启开发者模式 步骤2:在开发者模式中,开启“USB调试”…

【AAAI 2024】解锁深度表格学习(Deep Tabular Learning)的关键:算术特征交互

近日,阿里云人工智能平台PAI与浙江大学吴健、应豪超老师团队合作论文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在国际人工智能顶会AAAI-2024上发表。本项工作聚焦于深度表格学习中的一个核心问题:在处理结构化表格…

html5的css使用display: flex进行div居中的坑!

最近做项目的时候,有个需求,一个高度宽度不确定的Div在另一个Div内上下左右居中。 然后以前上下居中用的都是很繁琐的,就打算去百度搜索一个更优秀的方法。 百度AI自己给我一个例子: /* div在容器里居中显示,设置外容…

单片机学到什么程度才可以去工作?

单片机学到什么程度才可以去工作? 如果没有名校或学位的加持,你还得再努力一把,才能从激烈的竞争中胜出。以下这些技能可以给你加分,你看情况学,不同行业对这些组件会有取舍: . Cortex-M内核:理解MCU内核各部件的工作机制&#…

windows的maven 低版本如何切换到高版本

要升级到 Maven 3.9.x 版本,可以按照以下步骤操作: 下载 Maven 3.9.x: 访问 Maven 的官方网站(https://maven.apache.org/download.cgi)并下载 Maven 3.9.x 版本的压缩包。选择与你的操作系统兼容的版本。 2. 解压缩 Maven 3.9.x…

一、MySQL基础学习

目录 1、MySQL启动2、MySQL客户端连接3、SQL3.1、SQL语句分类3.2、DDL(数据库定义语言)3.2.1、操作数据库3.2.2、操作数据表 3.3、DML(数据库操作语言)3.3.1、增加 insert into3.3.2、删除 delete3.3.3、修改 update 3.4、DQL&…

移动云COCA架构实现算力跃升,探索人工智能新未来

近期,随着OpenAI正式发布首款文生视频模型Sora,标志着人工智能大模型在视频生成领域有了重大飞跃。Sora模型不仅能够生成逼真的视频内容,还能够模拟物理世界中的物体运动与交互,其核心在于其能够处理和生成具有复杂动态与空间关系…

逆序对的数量 刷题笔记

思路 使用归并排序 在每次返回时 更新增加答案数 因为归并排序的两个特点 第一 使用双指针算法 第二 层层返回 从局部有序合并到整体有序 例如 {4 ,1 ,2 ,3} 划分到底层是四个数组 {4},{1},{3}, {…

【算法杂货铺】二分算法

目录 🌈前言🌈 📁 朴素二分查找 📂 朴素二分模板 📁 查找区间端点处 细节(重要) 📂 区间左端点处模板 📂 区间右端点处模板 📁 习题 1. 35. 搜索插入位…