CGAN笔记总结第二弹~

CGAN原理与源码分析

  • 一、复习GAN
    • 1.1损失函数
    • 1.2判别器源码
    • 1.3 生成器源码
  • 二、什么是CGAN?
    • 2.1 CGAN原理图
    • 2.2条件GAN的损失函数
    • 2.3 生成器源码
    • 2.4 判别器源码
    • 2.5 训练过程
      • 1)这里的训练顺序
      • 2)为什么先训练判别器后训练生成器呢?
    • 2.6 训练过程运行结果
    • 2.7测试结果
      • 1)测试代码

一、复习GAN

生成式对抗网络(Generative Adversarial Networks)是让两个神经网络进行博弈进行学习。基础结构包含生成器和判别器。生成器的目标是生成与真实图片相似的图片,以假乱真,尽可能地让判别器判断生成的图片是真实的。判别器的目标是能够区分真实图片和生成图片。生成器和判别器通过巧妙地设计损失函数,而结合在一起,在相互对抗中不断调整各自的参数,使得判别器难以判断生成器生成的图片是否真实,从而达到欺骗人眼的效果。
在    插入图片描述

1.1损失函数

在这里插入图片描述

在这里插入图片描述

1.2判别器源码

class Discriminator(nn.Module):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(
						nn.Linear(784,1024),
						nn.LeakyReLU(0.2),
						nn.Dropout(0.3),
						nn.Linear(1024,512),
						nn.LeakyReLU(0.2),
						nn.Dropout(0.3),
						nn.Linear(512,256),
						nn.LeakyReLU(0.2),
						nn.Dropout(0.3),
						nn.Linear(256,1),
						nn.Sigmoid()
 		)
 	def forward(self, x):
 		return self.model(x)

在这里插入图片描述

1.3 生成器源码

class Generator(nn.Module):
	def __init__(self):
		super().__init__()
		self.model = nn.Sequential(
						nn.Linear(100,256),
						nn.LeakyReLU(0.2),
						
						nn.Linear(256,512),
						nn.LeakyReLU(0.2),
						
						nn.Linear(512,1024),
						nn.LeakyReLU(0.2),
					
						nn.Linear(1024,784),
						nn.Tahn()
 		)
 	def forward(self, x):
 		return self.model(x)

在这里插入图片描述

二、什么是CGAN?

CGAN,全称Conditional Generative Aderversarial Networks.与GAN相比,条件GAN加入了额外信息c,从而能够生成指定的手写数字。

2.1 CGAN原理图

在这里插入图片描述

2.2条件GAN的损失函数

在这里插入图片描述
nn.BCELoss()是一个PyTorch中的损失函数,它被用于二分类问题。BCE代表二元交叉熵(Binary Cross Entropy)
这里用到的是二元交叉熵损失函数
D(x)代表的是判别器判别图片是真的概率;

2.3 生成器源码

class Generator(nn.Module):
    def __init__(self, num_channel=1, nz=100, nc=10, ngf=64):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 110 x 1 x 1
            nn.ConvTranspose2d(nz + nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 特征维度 (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 特征维度 (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 特征维度 (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 特征维度 (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, num_channel, 4, 2, 1, bias=False),
            nn.Tanh()
            # 特征维度. (num_channel) x 64 x 64
        )
        self.apply(weights_init)

    def forward(self, input_z, onehot_label):
        input_ = torch.cat((input_z, onehot_label), dim=1)
        n, c = input_.size()
        input_ = input_.view(n, c, 1, 1)
        return self.main(input_)

在生成器,
随机向量z是100维的,
额外信息c是10维的,(因为手写数字包含0-9,一共10类)
在这里,采用直接拼接的方式,最终形成了110维的输入

2.4 判别器源码

class Discriminator(nn.Module):
    def __init__(self, num_channel=1, nc=10, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 (num_c3
            # channel+nc) x 64 x 64  1*64*64的图像和10维的类别   10维类别先转换成10*64*64    然后合并就是11*64*64
            # 输入通道  输出通道   卷积核的大小   步长  填充
            #原始输入张量:b 11 64  64
            nn.Conv2d(num_channel + nc, ndf, 4, 2, 1, bias=False),   #b  64  32  32
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  #b   64*2   16  16
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),    #b   64*4   8    8
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),    #b   64*8    4    4
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        #b   1    1    1      其实就是一个数值,区间在正无穷到负无穷之间
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def forward(self, images, onehot_label):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        h, w = images.shape[2:]
        n, nc = onehot_label.shape[:2]
        label = onehot_label.view(n, nc, 1, 1) * torch.ones([n, nc, h, w]).to(device)
        input_ = torch.cat([images, label], 1)
        return self.main(input_)

在判别器中,输入的数据有
图片x,(可能是来自真实数据集的样本,也可能是来自生成器生成的虚假样本) 维度是1HW
额外信息c,维度是10维,变换到10 * 1 * 1,将后两维进行复制 变换为10 * H * W的张量;
最终拼接在一起,构成11 * H * W的输入。

2.5 训练过程


MODEL_G_PATH = "./"
LOG_G_PATH = "Log_G.txt"
LOG_D_PATH = "Log_D.txt"
IMAGE_SIZE = 64
BATCH_SIZE = 128
WORKER = 1
LR = 0.0002
NZ = 100
NUM_CLASS = 10
EPOCH = 50

data_loader = loadMNIST(img_size=IMAGE_SIZE, batch_size=BATCH_SIZE)  #原始图片宽高是28*28的,给改变成64*64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))

g_writer = LossWriter(save_path=LOG_G_PATH)
d_writer = LossWriter(save_path=LOG_D_PATH)

fix_noise = torch.randn(BATCH_SIZE, NZ, device=device)
fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
fix_input_c = onehot(fix_input_c, NUM_CLASS)

img_list = []
G_losses = []
D_losses = []
iters = 0

print("开始训练>>>")
for epoch in range(EPOCH):

    print("正在保存网络并评估...")
    save_network(MODEL_G_PATH, netG, epoch)
    with torch.no_grad():
        fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()
        images = recover_image(fake_imgs)
        full_image = np.full((5 * 64, 5 * 64, 3), 0, dtype="uint8")
        for i in range(25):
            row = i // 5
            col = i % 5
            full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]
            # !!!!!!!!!!!!!!
            #每一轮次结束后,这里只展示了一批图片的前25张。
        plt.imshow(full_image)
        #plt.show()
        plt.imsave("{}.png".format(epoch), full_image)

    for data in data_loader:
       
        netD.zero_grad()
        real_imgs, input_c = data   #这里的input_c其实就是数据集每一批中的每个图片对应的标签
        input_c = input_c.to(device)
        input_c = onehot(input_c, NUM_CLASS).to(device)

        # 1.1 来自数据集的样本
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        #上面的torch.full是生成一维的 b_size这么多的,填充值为1.的张量
        # real_label = 1.
        # fake_label = 0.

        # 使用判别器对真实数据集样本做判断
        #!!!!!!!!!!!!!
        #output应该是判别器判别一批真图片真实的概率
        output = netD(real_imgs, input_c).view(-1)   
        errD_real = criterion(output, label)
        #!!!!!!
        #errD_real是判别器识别真图片的误差,为了训练判别器判别真图片为真
        errD_real.backward()
        D_x = output.mean().item()   
        #!!!!!!!
        #D_x就是判别器判别一批真图片为真的概率的平均值

        
        # 1.2 生成随机向量   这一步想要训练判别器是否能够识别出是虚假图片
        noise = torch.randn(b_size, NZ, device=device)
        # 生成随机标签
        input_c = (torch.rand(b_size, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
        input_c = onehot(input_c, NUM_CLASS)

        # 来自生成器生成的样本
        fake = netG(noise, input_c)
        label.fill_(fake_label)

        # real_label = 1.
        # fake_label = 0.
        # 使用判别器对生成器生成样本做判断
        #!!!!!!!!!!!
        #output应该是判别器判别一批假图片真实的概率
        output = netD(fake.detach(), input_c).view(-1)  
        errD_fake = criterion(output, label)
        # 对判别器进行梯度回传
        errD_fake.backward()
        #!!!!!!
        #errD_fake是判别器识别假图片的误差,为了训练判别器判别假图片为假
        D_G_z1 = output.mean().item()
        #!!!!!!!!!!!!
        #D_G_z1就是判别器判别一批假图片为真的概率的平均值
        errD = errD_real + errD_fake
        #!!!!!!
        #errD是判别器识别真实图片和假图片的误差和
        # 更新判别器
        optimizerD.step()


        
       
        netG.zero_grad()
        # 对于生成器训练,令生成器生成的样本为真,
        label.fill_(real_label)

        # real_label = 1.
        # fake_label = 0.
        #!!!!!!!!!!!
        #output应该是判别器判别一批假图片真实的概率
        output = netD(fake, input_c).view(-1)
        # 对生成器计算损失
        errG = criterion(output, label)
         #!!!!!!
        #errG是判别器识别假图片的误差,但是是为了训练生成器生成假图片,以假乱真
        # 因为这里判别器的角度label真实应该是0,但是站在生成器的角度,label真实应该是1,即生成器希望生成的虚假图片让判别器识别的时候,会误以为1才比较好,即误以为是真实的图片
        # 所以生成器交叉熵也是越小越好
        # 对生成器进行梯度回传
        errG.backward()
        D_G_z2 = output.mean().item()
        #!!!!!!!!!!!!
        #D_G_z2就是判别器判别一批假图片为真的概率的平均值
        # 更新生成器
        optimizerG.step()

        # 输出损失状态
        if iters % 5 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, EPOCH, iters % len(data_loader), len(data_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            d_writer.add(loss=errD.item(), i=iters)
            g_writer.add(loss=errG.item(), i=iters)

        # 保存损失记录
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        iters += 1

1)这里的训练顺序

这里训练的顺序是
先拿真实图片训练判别器,
再拿假图片训练判别器,
最后,拿假图片让判别器判断,来训练生成器。

2)为什么先训练判别器后训练生成器呢?

试想,假如先训练生成器,但是刚开始判别器还没有判别能力,所以达不到训练生成器,帮助生成器能越来越生成逼真的假图片。
所以,需要先训练判别器,让判别器先具有初步的判别能力,才能训练生成器,帮助生成器能够生成逼真的假图片。

2.6 训练过程运行结果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
#errD是判别器识别真实图片和假图片的误差和,是为了训练判别器能够判别真假图片
#errG是判别器识别假图片的误差,但是是为了训练生成器生成假图片,以假乱真
#D_x就是判别器判别一批真图片为真的概率的平均值,训练判别器识别真图片
#D_G_z1就是判别器判别一批假图片为真的概率的平均值,训练判别器识别假图片
#D_G_z2就是判别器判别一批假图片为真的概率的平均值,训练生成器生成逼真的假图片

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

2.7测试结果

在这里插入图片描述

1)测试代码


NZ = 100
NUM_CLASS = 10
BATCH_SIZE = 10
DEVICE = "cpu"

netG = Generator()
netG = restore_network("./", "49", netG)
fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
fix_input_c = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
device = "cuda" if torch.cuda.is_available() else "cpu"
fix_input_c = onehot(fix_input_c, NUM_CLASS)
fix_input_c = fix_input_c.to(device)
fix_noise = fix_noise.to(device)
netG = netG.to(device)
#fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()



#fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
full_image = np.full((10 * 64, 10 * 64, 3), 0, dtype="uint8")
for num in range(10):
    input_c = torch.tensor(np.ones(10, dtype="int64") * num)
    input_c = onehot(input_c, NUM_CLASS)
    fix_noise = fix_noise.to(device)
    input_c = input_c.to(device)
    fake_imgs = netG(fix_noise, input_c).detach().cpu()
    images = recover_image(fake_imgs)
    for i in range(10):
        row = num
        col = i % 10
        full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]

plt.imshow(full_image)
plt.show()
plt.imsave("hah.png", full_image)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/235967.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python学习笔记-类

1 定义类 类是函数的集合,class来定义类 pass并没有实际含义,只是为了代码能执行通过,不报错而已,相当于在代码种占一个位置,后续完善 类是对象的加工厂 2.创建对象 carCar()即是创建对象的过程 3、类的成员 3.1 实例…

PC 机与单片机通信(RS232 协议)

PC 机与单片机通信(RS232 协议) 目录: 1、单片机串口通信的应用 2、PC控制单片机IO口输出 3、单片机控制实训指导及综合应用实例 4、单片机给计算机发送数据: [实验任务] 单片机串口通信的应用,通过串口,我们的个人电脑和单…

操作系统大会 openEuler Summit 2023即将召开,亮点不容错过

【12月11日,北京】数字化、智能化浪潮正奔涌而来。操作系统作为数字基础设施的底座,已经成为推动产业数字化、智能化发展的核心力量,为数智未来提供无限可能。12月15-16日,以“崛起数字时代 引领数智未来”为主题的操作系统大会 &…

双向链表(数据结构与算法)

✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ 🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿&#x1…

香港身份(户口)大放水!23年香港优才计划、高才通计划申请数据公开!24年冲!

香港身份(户口)大放水!23年香港优才计划、高才通计划申请数据公开!24年冲! 近期香港入境处公布了各项人才入境计划申请及审批数字,: 截止今年10月31日一共有18.4万宗申请各类人才输入计划,获批人…

IntelliJ IDEA无公网远程连接Windows本地Mysql数据库提高开发效率

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《Linux》《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,…

智能优化算法应用:基于蝴蝶算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于蝴蝶算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于蝴蝶算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.蝴蝶算法4.实验参数设定5.算法结果6.参考文献7.MA…

【深度学习】注意力机制(三)

本文介绍一些注意力机制的实现,包括EMHSA/SA/SGE/AFT/Outlook Attention。 【深度学习】注意力机制(一) 【深度学习】注意力机制(二) 目录 一、EMHSA(Efficient Multi-Head Self-Attention)…

logstash插件简单介绍

logstash插件 输入插件(input) Input:输入插件。 Input plugins | Logstash Reference [8.11] | Elastic 所有输入插件都支持的配置选项 SettingInput typeRequiredDefaultDescriptionadd_fieldhashNo{}添加一个字段到一个事件codeccodecNoplain用于输入数据的…

可学习超图拉普拉斯算子代码

python版本:3.6。sklearn版本:scikit-learn0.19 问题1:ERROR: Could not build wheels for ecos, scs, which is required to install pyproject.toml-based projects| 解决办法:cvxpy安装过程中遇到的坑_ecos 2.0.7.post1 cp37 …

Terraform实战(二)-terraform创建阿里云资源

1 初始化环境 1.1 创建初始文件夹 $ cd /data $ mkdir terraform $ mkdir aliyun terraform作为terraform的配置文件夹,内部的每一个.tf,.tfvars文件都会被加载。 1.2 配置provider 创建providers.tf文件,配置provider依赖。 provider…

LeetCode 每日一题 Day 9 ||简单dp

70. 爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1 阶 1 阶2 阶 示例 2&am…

智能井盖传感器怎么有效监测井盖位移

随着城市化进程的加速推进,城市基础设施的安全与维护问题日益凸显,引发了社会的广泛关注。其中井盖作为城市地下设施的重要一环,其安全问题时刻影响着市民的幸福生活。近年来智能井盖传感器的发展与应用为实时监测井盖位移提供了全新的解决方…

嵌入式开发按怎样的路线学习较好?

嵌入式开发按怎样的路线学习较好? 在开始前我有一些资料,是我根据自己从业十年经验,熬夜搞了几个通宵,精心整理了一份「嵌入式从专业入门到高级教程工具包」,点个关注,全部无偿共享给大家!&…

BigData之Google Hadoop中间件安装

前言 Hadoop / Zookeeper / Hbase 因资源有限 这三个都是安装在同一台Centos7.9的机器上 但通过配置 所以在逻辑上是distributed模式 1 Java安装 1.1 下载java11 tar/opt/java/jdk-11.0.5/ 1.2 环境配置修改 文件/etc/profile export JAVA_HOME/opt/java/jdk-11.0.5/ e…

网络层重点协议——IP协议详解

✏️✏️✏️今天给大家分享的是网络层的重点协议——IP协议。 清风的CSDN博客 🛩️🛩️🛩️希望我的文章能对你有所帮助,有不足的地方还请各位看官多多指教,大家一起学习交流! ✈️✈️✈️动动你们发财的…

解决vue3 动态引入报错问题

之前这样写的,能使用,但是有警告 警告,查了下,是动态引入的问题,看到说要用glob 然后再我的基础上,稍微 改了下,就可以了: 最后打印了下,modules[../../components/flowc…

每日一练【无重复字符的最长子串】

一、题目描述 给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串 的长度。 二、题目解析 算法思想:移动窗口的思想去解决。 那为什么要用这个方法解决呢? 我们首先用暴力的思路去遍历一遍,我们遍历到deabc后&#xff…

外包干了3个月,技术退步明显.......

先说一下自己的情况,大专生,18年通过校招进入武汉某软件公司,干了接近4年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了四年的功能测…

图文教程:从0开始安装stable-diffusion

现在AI绘画还是挺火,Midjourney虽然不错,但是对于我来说还是挺贵的。今天我就来安一下开源的AI绘画stable-diffusion,它的缺点就是对电脑的要求比较高,尤其是显卡。 话不多说开搞。 访问sd的github,https://github.com/AUTOMATIC…