GAN的原理分析与实例

为了便于理解,可以先玩一玩这个网站:GAN Lab: Play with Generative Adversarial Networks in Your Browser!

GAN的本质:枯叶蝶和鸟。生成器的目标:让枯叶蝶进化,变得像枯叶,不被鸟准确识别。判别器的目标:准确判别是枯叶还是鸟

伪代码: 

案例:

原始数据:

案例结果: 

 案例完整代码:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False

# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223


noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generator


class NetGenerator(nn.Module):
    def __init__(self):
        super(NetGenerator,self).__init__()
        self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行
            nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),
            #转置卷积层:输入特征映射的尺寸会放大,通道数可能会减小,普通卷积层:输入特征映射的尺寸会缩小,但通道数可能会增加
            nn.BatchNorm2d(n_generator_feature * 8),
            nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4
            nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 4),
            nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8
            nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 2),
            nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16
            nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature),
            nn.ReLU(True),      # (n_generator_feature) × 32 × 32
            nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()       # 3 * 96 * 96
        )

    def forward(self, input):
        return self.main(input)


class NetDiscriminator(nn.Module):
    def __init__(self):
        super(NetDiscriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32
            nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16
            nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8
            nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4
            nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()        # 输出一个概率
        )

    def forward(self, input):
        return self.main(input).view(-1)


def train():
    for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96
        real_image = Variable(image)
        #real_image = real_image.cuda()

        if (i + 1) % d_every == 0:  #d_every = 1,每一个batch训练一次discriminator
            optimizer_d.zero_grad()
            output = Discriminator(real_image)      # 尽可能把真图片判为True
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()

            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises).detach()       # 根据噪声生成假图
            fake_output = Discriminator(fake_img)       # 尽可能把假图片判为False
            error_d_fake = criterion(fake_output, fake_labels)
            error_d_fake.backward()
            optimizer_d.step()

        if (i + 1) % g_every == 0:
            optimizer_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises)        # 这里没有detach
            fake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为True
            error_g = criterion(fake_output, true_labels)
            error_g.backward()
            optimizer_g.step()


def show(num):
    fix_fake_imags = Generator(fix_noises)
    fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5

    # x = torch.rand(64, 3, 96, 96)
    fig = plt.figure(1)

    i = 1
    for image in fix_fake_imags:
        ax = fig.add_subplot(8, 8, eval('%d' % i)) #将Figure划分为8行8列的子图网格,并将当前的子图添加到第i个位置。
        # plt.xticks([]), plt.yticks([])  # 去除坐标轴
        plt.axis('off')
        plt.imshow(image.permute(1, 2, 0)) #permute()函数可以对维度进行重排,Matplotlib期望的图像格式是(H, W, C),即高度、宽度、通道
        i += 1
    plt.subplots_adjust(left=None,  # the left side of the subplots of the figure
                        right=None,  # the right side of the subplots of the figure
                        bottom=None,  # the bottom of the subplots of the figure
                        top=None,  # the top of the subplots of the figure
                        wspace=0.05,  # the amount of width reserved for blank space between subplots
                        hspace=0.05)  # the amount of height reserved for white space between subplots)
    plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)
    plt.savefig("images/%dcgan.png" % num)


if __name__ == '__main__':
    transform = tv.transforms.Compose([
        tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecated
        tv.transforms.CenterCrop(96),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数
    ])

    dataset = tv.datasets.ImageFolder(dir, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'

    print('数据加载完毕!')
    Generator = NetGenerator()
    Discriminator = NetDiscriminator()

    optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    true_labels = Variable(torch.ones(batch_size))     # batch_size
    fake_labels = Variable(torch.zeros(batch_size))
    fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))
    noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布

    # if torch.cuda.is_available() == True:
    #     print('Cuda is available!')
    #     Generator.cuda()
    #     Discriminator.cuda()
    #     criterion.cuda()
    #     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
    #     fix_noises, noises = fix_noises.cuda(), noises.cuda()


    #plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]
    plot_epoch = [1,5,10,50,100,200,500,800,1000,1200,1500]

    for i in range(1500):        # 最大迭代次数
        train()
        print('迭代次数:{}'.format(i))
        if i in plot_epoch:
            show(i)


http://t.csdnimg.cn/FTSriicon-default.png?t=N7T8http://t.csdnimg.cn/FTSri

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/245970.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ShenYu网关Http服务探活解析

文章目录 网关端服务探活admin端服务探活 Shenyu HTTP服务探活是一种用于检测HTTP服务是否正常运行的机制。它通过建立Socket连接来判断服务是否可用。当服务不可用时&#xff0c;将服务从可用列表中移除。 网关端服务探活 以divide插件为例&#xff0c;看下divide插件是如何获…

谷歌宣布向云计算客户开放 Gemini Pro,开发者可用其构建应用

12 月 14 日消息&#xff0c;美国时间周三&#xff0c;谷歌宣布了一系列升级的人工智能&#xff08;AI&#xff09;功能&#xff0c;旨在为其云计算客户提供更好的服务。这家科技巨头正试图赶上竞争对手&#xff0c;比如微软和 OpenAI&#xff0c;它们都在积极利用人工智能的热…

大数据技术11:Hadoop 原理与运行机制

前言&#xff1a;HDFS &#xff08;Hadoop Distributed File System&#xff09;是 Hadoop 下的分布式文件系统&#xff0c;具有高容错、高吞吐量等特性&#xff0c;可以部署在低成本的硬件上。 一、Hadoop简介 1.1、Hadoop定义 Hadoop 作为一个开源分布式系统基础框架&#x…

Yum仓库架构解析与搭建实践

1.Yum仓库搭建 1.1本地Yum仓库图解 1.2Linux本地仓库搭建 配置本地光盘镜像仓库 1&#xff09;挂载 [roothadoop101 ~]# mount -t iso996 /dev/cdrom/mnt 2&#xff09;查看 [rooothadoop101 ~] # df -h | |grep -i mnt /dev/sr0 4.6G 4.4G 3&#xf…

C语言之文件操作(下)

C语言之文件操作&#xff08;下&#xff09; 文章目录 C语言之文件操作&#xff08;下&#xff09;1. 文件的顺序读写1.1 文件的顺序读写函数1.1.1 字符输入/输出函数&#xff08;fgetc/fputc&#xff09;1.1.2 ⽂本⾏输⼊/输出函数&#xff08;fgets/fputs&#xff09;1.1.3 格…

MYSQL练题笔记-高级字符串函数 / 正则表达式 / 子句-简单3题

这个系列先写了三题&#xff0c;比较简单写在一起。 1.修复表中的名字相关的表和题目如下 看题目就知道是有关字符串函数的&#xff0c;于是在书里查询相关的函数&#xff0c;如下图&#xff0c;但是没有完全对口的函数&#xff0c;所以我还是去百度了。 然后发现结合上面的4个…

多层感知机

目录 一、感知机 1、相关概念介绍 2、&#xff08;单层&#xff09;感知机存在的问题 3、总结 二、多层感知机&#xff08;MLP&#xff09; 1、多层感知机思路 2、激活函数 3、常见的激活函数 4、多类分类 4、总结 三、多层感知机从零开始实现 1、读取数据集 2、初…

[Kubernetes]2. k8s集群中部署基于nodejs golang的项目以及Pod、Deployment详解

一. 创建k8s部署的镜像 1.部署nodejs项目 (1).上传nodejs项目到节点node1 (2).压缩nodejs项目 (3).构建nodejsDockerfile 1).创建nodejsDockerfile 具体可参考:[Docker]十.Docker Swarm讲解,在/root下创建nodejsDockerfile,具体代码如下: FROM node #把压缩文件COPY到镜像的…

掌握iText:轻松处理PDF文档-高级篇-添加页眉和页脚

推荐语 本文介绍了如何使用iText编程库为PDF文档添加自定义的页眉和页脚。通过指定位置、大小、字体和颜色等属性&#xff0c;你可以将文本、图像或其他元素添加到每一页的固定位置&#xff0c;实现专业、可读的自定义页眉和页脚效果。这对于需要批量处理大量PDF文档或需要更精…

modelbox线程爆满宕机bug

序 该bug的解决需要特别感谢张同学。有了大佬的帮助&#xff0c;这个bug才得以解决。 问题现象 modelbox可以进行模型推理&#xff0c;但压测一段时间后&#xff0c;modelbox会宕机&#xff0c;并发生段错误。 “libgomp: Thread creation failed: Resource temporarily una…

TCP/IP详解——ICMP协议,Ping程序,Traceroute程序,IP源站选路选项

文章目录 一、ICMP 协议1. ICMP 概念2. ICMP 重定向3. ICMP 差错检测4. ICMP 错误报告/差错报文5. ICMP 差错报文的结构6. ICMP 源站抑制差错7. ICMP 数据包格式8. ICMP 消息类型和编码类型9. ICMP 应用-Ping10. ICMP 应用-Tracert11. BSD 对 ICMP 报文的处理12. 总结 PING 程序…

数据结构:队列

数据结构&#xff1a;队列 文章目录 数据结构&#xff1a;队列1.队列常用操作&#xff1a;2.队列的实现3.队列典型应用 ***「队列 queue」是一种遵循先入先出规则的线性数据结构。***队列模拟了排队现象&#xff0c;即新来的人不断加入队列尾部&#xff0c;而位于队列头部的人逐…

Visual studio+Qt开发环境搭建以及注意事项和打开qt的.pro项目

下载qt-然后安装5.14.2_msvc2017 不知道安装那个就全选5.14.2的父级按钮 https://download.qt.io/archive/qt/5.14/5.14.2/ 安装Visual studio,下载直接下一步就行 配置Visual studio的qt环境 在线安装-重启Visual studio会自动安装 离线安装-关闭Visual studio点击安装 关闭…

a16z:加密行业2024趋势“无缝用户体验”

近日&#xff0c;知名加密投资机构a16z发布了“Big ideas 2024”&#xff0c;列出了加密行业在 2024 年几个具备趋势的“大想法”&#xff0c;其中 Seamless UX&#xff08;无缝用户体验&#xff09;赫然在列。 从最为直观的理解上&#xff0c;Seamless UX 是在强调用户在使用产…

路由器原理

目录 一.路由器 1.路由器的转发原理 2.路由器的工作原理 二.路由表 1.路由表的形成 2.路由表表头含义 直连&#xff1a; 非直连&#xff1a; 静态 静态路由的配置 负载均衡&#xff08;浮动路由&#xff09; 默认路由 动态 三.交换与路由对比 一.路由器 1.路由器…

独立完成软件的功能的测试(4)

独立完成软件的功能的测试&#xff08;4&#xff09; &#xff08;12.14&#xff09;&#xff08;功能测试>头条项目实战&#xff09; 项目总体概述 项目背景和定位&#xff1a;一款汇聚科技咨询&#xff0c;技术文章和问答交流的用户移动终端产品&#xff0c;用户可以通过…

STM32在CTF中的应用和快速解题

题目给的是bin文件&#xff0c;基本上就是需要我们手动修复的固件逆向。 如果给的是hex文件&#xff0c;我们可能需要使用MKD进行动态调试 主要还是以做题为目的 详细的可以去看文档&#xff1a;https://pdf1.alldatasheet.com/datasheet-pdf/view/201596/STMICROELECTRONIC…

微服务学习:Gateway服务网关

一&#xff0c;Gateway服务网关的作用&#xff1a; 路由请求&#xff1a;Gateway服务网关可以根据请求的URL或其他标识符将请求路由到特定的微服务。 负载均衡&#xff1a;Gateway服务网关可以通过负载均衡算法分配请求到多个实例中&#xff0c;从而平衡各个微服务的负载压力。…

一入二出热电阻温度信号隔离变送器

一入二出热电阻温度信号隔离变送器 用于测量铂热电阻Pt10,Pt100,Pt1000,Cu50,Cu100的热电阻传感器的小型仪器设备。广泛应用于工业测量温度系统&#xff0c;是降低成本且有效的测量方式。 型号&#xff1a;JSD TARZ-1002系列 我们来看下有什么特点&#xff1a; ◆小体积&#x…

天猫数据分析平台-天猫销售数据查询软件-11月天猫平台冲锋衣市场销售运营数据分析

随着气温逐渐下降&#xff0c;保暖服饰迎来热销&#xff0c;冲锋衣的需求大增。如今冲锋衣已经不仅仅是户外运动的装备&#xff0c;还成为很多年轻人的日常穿搭和时尚的追求。 新的穿搭趋势也带来了巨大的市场机会。据公开数据显示&#xff0c;中国有冲锋衣生产及经营企业超过8…