Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现

什么是GAN

GAN是生成对抗网络,将会根据一个随机向量,实现数据的生成(如生成手写数字、生成文本等)。

GAN的训练过程中,需要有一个生成器G和一个鉴别器D.

生成器用于生成数据,鉴定器用于鉴定数据的准确性,其实就是在鉴别数据是人生成的还是机器生成的,因为生成器需要以假乱真。

鉴别器将会与生成器一起训练。鉴别器将会先训练,这样才有适当的能力去鉴定生成器生成数据的准确性。

鉴别器的训练过程中,需要先给它准确的数据,和通过随机向量传入生成器产生的数据(一律视为负样本),并通过损失函数对其进行训练;生成器训练过程中,会先给它一个随机向量进行前向传播,然后让鉴别器判断其正确性,并通过损失函数(不正确的数据意味着有损失)进行训练:

生成器训练过程中,需要先通过随机向量获取其结果,然后让鉴别器进行鉴别,在通过鉴别器的鉴别结果计算损失(如果鉴别器认为这是生成器生成的,则产生损失),最后更新梯度和参数:

训练过程直到生成器拟合训练集(收敛),判别器的输出总是0.5(均方误差损失函数应为0.25)为止.

形象的GAN的例子

想象一场由一位“名画伪造者”和一位“艺术鉴定家”参与的猫捉老鼠游戏。

在这个场景中,名画伪造者(即GAN中的生成器)的目标是创造出一幅足以欺骗艺术鉴定家(即GAN中的判别器)的假画。开始时,伪造者的技艺并不精湛,他制作的假画充满了破绽,很容易被鉴定家一眼识破。

然而,随着伪造者不断尝试和失败,他逐渐从每一次的失败中学习,逐渐提升了自己的技艺。他开始注意到真画的每一个细节,从笔触、色彩到构图,都尽量模仿得惟妙惟肖。每一次的失败都让他更接近成功,他制作的假画也越来越难以辨别真伪。

而艺术鉴定家也不甘示弱。他开始时能够轻易地识别出伪造者的假画,但随着伪造者技艺的提升,他也需要不断提升自己的鉴定能力。他开始深入研究真画的每一个特点,以便更准确地识别出伪造者的假画。

这个过程就像GAN中的训练过程一样。生成器不断尝试生成新的数据(在这里是假画),而判别器则不断尝试区分这些数据是真实的还是生成的。两者在相互竞争的过程中不断提升自己的能力,最终达到了一个平衡状态。

在这个例子中,名画伪造者就是GAN中的生成器,他负责生成新的数据;而艺术鉴定家则是GAN中的判别器,他负责区分数据的真伪。两者在相互竞争的过程中共同进步,使得生成的数据越来越接近真实的数据。

代码实现

本文将以基于MNIST(手写数据集)为数据集,实现一个生成手写数字的GAN模型:

首先创建models.py,用于定义判别器和生成器:

import paddle


# Generator Code
class Generator(paddle.nn.Layer):
    def __init__(self, ):
        super(Generator, self).__init__()
        self.gen = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=100, out_features=256),
            paddle.nn.ReLU(True),
            paddle.nn.Linear(in_features=256, out_features=512),
            paddle.nn.ReLU(True),
            paddle.nn.Linear(in_features=512, out_features=1024),
            paddle.nn.Tanh(),
        )

    def forward(self, x):
        x = self.gen(x)
        out = paddle.reshape(x,[-1,1,32,32])
        return out


# Discriminator Code
class Discriminator(paddle.nn.Layer):
    def __init__(self, ):
        super(Discriminator, self).__init__()
        self.dis = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=1024, out_features=512),
            paddle.nn.LeakyReLU(0.2),
            paddle.nn.Linear(in_features=512, out_features=256),
            paddle.nn.LeakyReLU(0.2),
            paddle.nn.Linear(in_features=256, out_features=1),
            paddle.nn.Sigmoid()
        )

    def forward(self, x):
        x = paddle.reshape(x, [-1, 1024])
        out = self.dis(x)
        return out

其中,生成器将接收一个长度为100的张量(随机向量),输出一个长度为1024的张量(生成的图片);鉴别器将接收一个长度为1024的张量(图片) ,输出长度为1的张量(鉴别结果)

然后创建main.py,用于训练:

import paddle
import matplotlib.pyplot as plt
from models import Generator, Discriminator
import numpy as np

dataset = paddle.vision.datasets.MNIST(mode='train',
                                       transform=paddle.vision.transforms.Compose([
                                           paddle.vision.transforms.Resize((32, 32)),
                                           paddle.vision.transforms.Normalize([0], [255])
                                       ]))

dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)

netG = Generator()
netD = Discriminator()

if 1:
    try:
        mydict = paddle.load('generator.params')
        netG.set_dict(mydict)
        mydict = paddle.load('discriminator.params')
        netD.set_dict(mydict)
    except:
        print('fail to load model')

optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)

# 最大迭代epoch
max_epoch = 10

for epoch in range(max_epoch):
    now_step = 0
    for step, (data, label) in enumerate(dataloader):
        ############################
        # (1) 更新鉴别器
        ###########################

        # 清除D的梯度
        optimizerD.clear_grad()

        # 传入正样本,并更新梯度
        pos_img = data
        label = paddle.full([pos_img.shape[0], 1], 1, dtype='float32')
        pre = netD(pos_img)
        loss_D_1 = paddle.nn.functional.mse_loss(pre, label)
        loss_D_1.backward()

        # 通过randn构造随机数,制造负样本,并传入D,更新梯度
        noise = paddle.randn([pos_img.shape[0], 100], 'float32')
        neg_img = netG(noise)
        label = paddle.full([pos_img.shape[0], 1], 0, dtype='float32')
        pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算
        loss_D_2 = paddle.nn.functional.mse_loss(pre, label)
        loss_D_2.backward()

        # 更新D网络参数
        optimizerD.step()
        optimizerD.clear_grad()

        loss_D = loss_D_1 + loss_D_2

        ############################
        # (2) 更新生成器
        ###########################

        # 清除D的梯度
        optimizerG.clear_grad()

        noise = paddle.randn([pos_img.shape[0], 100], 'float32')
        fake = netG(noise)
        label = paddle.full((pos_img.shape[0], 1), 1, dtype=np.float32, )
        output = netD(fake)
        # 这个写法没有问题,因为这个mse_loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度
        loss_G = paddle.nn.functional.mse_loss(output, label)
        loss_G.backward()

        # 更新G网络参数
        optimizerG.step()
        optimizerG.clear_grad()

        now_step += 1

        ###########################
        # 输出日志
        ###########################
        if now_step % 100 == 0:
            print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')

paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

如果是第一次训练或不使用原有训练参数,可以将if 1改成if 0.

接下来创建use.py,用于生成图片:

import paddle
from models import Generator
import matplotlib.pyplot as plt

import paddle
from models import Generator
import matplotlib.pyplot as plt
import numpy as np

# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)

# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格

# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):
    noise = paddle.randn([1, 100], 'float32')
    img = netG(noise)
    img = img.numpy()[0][0]  # img.numpy():张量转np数组
    img[img < 0] = 0  # 将img中所有小于0的元素赋值为0
    img = np.clip(img, 0, 1)  # 将img中所有小于0的元素设为0,大于1的设为1(如果需要)

    # 显示图片
    ax.imshow(img)
    ax.axis('off')  # 不显示坐标轴

# 显示图像
plt.show()

进行多轮训练后,生成结果:

 

可以看到,它很好的生成了我们想要的图片。

GANs

但是,我们这个模型只能随机产生数字,还不能生成指定的数字(如让机器生成一个1).为了解决这个问题,我们可以针对每一个数字生成一个对应的GAN,所有这样的GAN组合起来,就是GANs. 这里不展开讲解。

参考

MNIST数据集下用Paddle框架的动态图模式玩耍经典对抗生成网络(GAN)-使用文档-PaddlePaddle深度学习平台

【飞桨PaddlePaddle】四天搞懂生成对抗网络(一)——通俗理解经典GAN_四天搞懂生成对抗网络(一)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/607879.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

车载测试___面试题和答案归纳

车载面试题 一、实车还在设计开发阶段&#xff0c;大部分测试通过什么测试&#xff1f; 答案&#xff1a;通过台架和仿真来完成的 二、测试部分划分&#xff1f; 测试部门是分为自研&#xff0c;系统&#xff0c;验收&#xff0c;自研部门是开发阶段测试&#xff0c;系统部门…

95、动态规划-编辑距离

递归暴力解法 递归方法的基本思想是考虑最后一个字符的操作&#xff0c;然后根据这些操作递归处理子问题。 递归函数定义&#xff1a;定义一个递归函数 minDistance(i, j)&#xff0c;表示将 word1 的前 i 个字符转换成 word2 的前 j 个字符所需的最小操作数。 递归终止条件…

llama.cpp制作GGUF文件及使用

llama.cpp的介绍 llama.cpp是一个开源项目&#xff0c;由Georgi Gerganov开发&#xff0c;旨在提供一个高性能的推理工具&#xff0c;专为在各种硬件平台上运行大型语言模型&#xff08;LLMs&#xff09;而设计。这个项目的重点在于优化推理过程中的性能问题&#xff0c;特别是…

(七)JSP教程——session对象

浏览器和Web服务器之间的交互通过HTTP协议来完成&#xff0c;HTTP协议是一种无状态的协议&#xff0c;服务器端无法保留浏览器每次与服务器的连接信息&#xff0c;无法判断每次连接的是否为同一客户端。为了让服务器端记住客户端的连接信息&#xff0c;可以使用session对象来记…

Java毕设之基于springboot的医护人员排班系统

运行环境 开发语言:java 框架:springboot&#xff0c;vue JDK版本:JDK1.8 数据库:mysql5.7(推荐5.7&#xff0c;8.0也可以) 数据库工具:Navicat11 开发软件:idea/eclipse(推荐idea) 系统详细实现 医护类型管理 医护人员排班系统的系统管理员可以对医护类型添加修改删除以及…

Error: error:0308010C:digital envelope routines::unsupported 问题如何解决

Error: error:0308010C:digital envelope routines::unsupported 通常与 Node.js 的加密库中对某些加密算法的支持有关。这个错误可能是因为 Node.js 的版本与某些依赖库不兼容导致的。特别是在 Node.js 17 版本中&#xff0c;默认使用 OpenSSL 3&#xff0c;而一些旧的加密方式…

电商核心技术揭秘53:社群营销的策略与实施

相关系列文章 电商技术揭秘相关系列文章合集&#xff08;1&#xff09; 电商技术揭秘相关系列文章合集&#xff08;2&#xff09; 电商技术揭秘相关系列文章合集&#xff08;3&#xff09; 电商技术揭秘四十一&#xff1a;电商平台的营销系统浅析 电商技术揭秘四十二&#…

Python轻量级Web框架Flask(13)—— Flask个人博客项目

0、前言: ★这部分内容是基于之前Flask学习内容的一个实战项目梳理内容,没有可以直接抄下来跑的代码,是学习了之前Flask基础知识之后,再来看这部分内容,就会对Flask项目开发流程有更清楚的认知,对一些开发细节可以进一步的学习。项目功能,通过Flask制作个人博客。项目架…

YOLOv9中模块总结补充|RepNCSPELAN4详图

专栏地址&#xff1a;目前售价售价69.9&#xff0c;改进点70 专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;助力高效涨点&#xff01;&#xff01;&#xff01; 1. RepNCSPELAN4详图 RepNCSPELAN4是YOLOv9中的特征提取-融合模块&#xff0c;类似前几…

制造业的智慧进化:机器学习与人工智能的全方位渗透

&#x1f9d1; 作者简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

libcity 笔记:添加自定义dataset

假设我们把libcity/data/dataset/trajectory_dataset.py复制一份到libcity/data/dataset/dataset_subclass/GeolifeDM_dataset.py&#xff0c;里面内容不变&#xff0c;只是把class的名字换了 那其他需要修改哪些内容&#xff0c;使得这个dataset生效呢 libcity/data/dataset/d…

国内唯一!阿里云荣膺MongoDB“2024年度DBaaS认证合作伙伴奖”

近日&#xff0c;在MongoDB用户大会纽约站上&#xff0c;阿里云荣膺MongoDB“2024年度DBaaS认证合作伙伴奖”。这是阿里云连续第五年斩获MongoDB合作伙伴奖项&#xff0c;也是唯一获此殊荣的中国云厂商。 MongoDB是当今全球最受欢迎的非关系型数据库之一。凭借灵活的模式和丰富…

Web3探索加密世界:如何避免限制并增加空投成功的几率

今天分享空投如何避免限制以提高效率&#xff0c;增加成功几率&#xff0c;首先我们来了解什么是空投加密&#xff0c;有哪些空投类型。 一、什么是空投加密&#xff1f; 加密货币空投是一种营销策略&#xff0c;包括向用户的钱包地址发送免费的硬币或代币。 加密货币项目使用…

【Linux】深浅睡眠状态超详解!!!

1.浅度睡眠状态【S】&#xff08;挂起&#xff09; ——S (sleeping)可中断睡眠状态 进程因等待某个条件&#xff08;如 I/O 完成、互斥锁释放或某个事件发生&#xff09;而无法继续执行。在这种情况下&#xff0c;进程会进入阻塞状态&#xff0c;在阻塞状态下&#xff0c;进程…

Python批量修改图片文件名中的指定名称

批量处理图像时&#xff0c;图片名有时需要统一&#xff0c;本教程仅针对图片中名如&#xff1a;0001x4.png&#xff0c;批量将图片名中的x4去除&#xff0c;只留下0001.png的情况。 如果想要按照原图片顺序批量修改图片名&#xff0c;参考其它博文&#xff1a;按照原顺序批量…

Joplin:自由、安全、多功能的笔记应用

什么是 Joplin&#xff1f; Joplin是一款免费、开源的笔记和待办事项应用程序&#xff0c;可以处理整理到笔记本中的大量笔记。这些笔记是可搜索的&#xff0c;可以直接从应用程序或从您自己的文本编辑器中复制、标记和修改。笔记采用Markdown 格式 功能亮点 功能丰富&#x…

科技云报道:从亚运到奥运,大型国际赛事共赴“云端”

科技云报道原创。 “广播电视转播技术拯救了奥运会”前奥委会主席萨马兰奇这句话广为流传。 奥运会、世界杯、亚运会这样的全球大型体育赛事不仅是体育竞技的盛宴&#xff0c;也是商业盛宴&#xff0c;还是技术与人文的融合秀。随着科技的进步&#xff0c;技术在体育赛事中扮…

google地图js,添加标记,以及infowindow信息弹窗

&#xff08;谷歌地图版本V3&#xff09; var contentString "<div classdevinfo><P>设备ID: BJ-20240507</p> <P>设备状态: 正常</p> <P>通讯信号: 89% </p> <P>设备位置: 中国</p> <P>剂量率: 988</p&…

MATLAB 自定义实现点云随机抽稀方法(66)

MATLAB 自定义实现点云随机抽稀方法(66) 一、算法介绍二、算法实现1.代码2.结果三、数据链接一、算法介绍 MATLAB虽然提供了点云随机抽稀的内置函数,但是我们也可以自己实现这个功能,有助于理解,下面是具体的实现效果和代码(直接复制粘贴即可使用): 使用提供的数据直接…

Matlab如何导出高质量论文插图?科研效率UpUp第8期

当你用Matlab绘制了一张论文插图&#xff1a; 想要所见即所得&#xff0c;原封不动地将其保存下来&#xff0c;该怎么操作呢&#xff1f; 虽说以前总结过7种方法&#xff08;Matlab导出论文插图的7种方法&#xff09;&#xff0c;但要说哪一种可以满足上面的要求&#xff0c;想…