PyTorch训练简单的生成对抗网络GAN

文章目录

    • 原理
    • 代码
    • 结果
    • 参考

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

在这里插入图片描述
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim): # z_dim 噪声的维度
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim), # 28x28 -> 784
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.gen(x)
    
# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)
    [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)

dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()

# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device) # view相当于reshape
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise) # G(z)
        disc_real = disc(real).view(-1) # flatten
        # BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]

        # max log(D(real)) 相当于 min -log(D(real))
        # ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项
        # 得到 min -lnx,这里的x就是我们的real图片
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        # max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))
        # zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项
        # 得到 min -ln(1-x),这里的x就是我们的fake噪声
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))
        # 依然可使用BCEloss来做
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] \ "
                f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )

                step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

在这里插入图片描述
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/77570.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++中List的实现

前言 数据结构中&#xff0c;我们了解到了链表&#xff0c;但是我们使用时需要自己去实现链表才能用&#xff0c;但是C出现了list将这一切皆变为现。list可以看作是一个带头双向循环的链表结构&#xff0c;并且可以在任意的正确范围内进行增删查改数据的容器。list容器一样也是…

【CSS】CSS 布局——常规流布局

<h1>基础文档流</h1><p>我是一个基本的块级元素。我的相邻块级元素在我的下方另起一行。</p><p>默认情况下&#xff0c;我们会占据父元素 100%的宽度&#xff0c;并且我们的高度与我们的子元素内容一样高。我们的总宽度和高度是我们的内容 内边距…

如何发布自己的小程序

小程序的基础内容组件 text&#xff1a; 文本支持长按选中的效果 <text selectable>151535313511</text> rich-text: 把HTML字符串渲染为对应的UI <rich-text nodes"<h1 stylecolor:red;>123</h1>"></rich-text> 小程序的…

2023牛客暑期多校训练营8-C Clamped Sequence II

2023牛客暑期多校训练营8-C Clamped Sequence II https://ac.nowcoder.com/acm/contest/57362/C 文章目录 2023牛客暑期多校训练营8-C Clamped Sequence II题意解题思路代码 题意 解题思路 先考虑不加紧密度的情况&#xff0c;要支持单点修改&#xff0c;整体查询&#xff0…

AUTOSAR NvM Block的三种类型

Native NVRAM block Native block是最基础的NvM Block&#xff0c;可以用来存储一个数据&#xff0c;可以配置长度、CRC等。 Redundant NVRAM block Redundant block就是在Native block的基础上再加一个冗余块&#xff0c;当Native block失效&#xff08;读取失败或CRC校验失…

时序预测 | MATLAB实现基于BiLSTM双向长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价)

时序预测 | MATLAB实现基于BiLSTM双向长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价) 目录 时序预测 | MATLAB实现基于BiLSTM双向长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价)预测结果基本介绍程序设计参考资料 预测结果 基本介绍 Matlab实现BiLST…

2022年09月 C/C++(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题&#xff1a;统计误差范围内的数 统计一个整数序列中与指定数字m误差范围小于等于X的数的个数。 时间限制&#xff1a;5000 内存限制&#xff1a;65536 输入 输入包含三行&#xff1a; 第一行为N&#xff0c;表示整数序列的长度(N < 100); 第二行为N个整数&#xff0c;…

把握数据要素,做数字化时代的弄潮儿

截至2022年6月&#xff0c;我国网民规模已经达到了10.51亿&#xff0c;人均上网时间达到了每周29.5个小时&#xff0c;并且这部分人群使用手机上网的比例为99.6%。如果把工作、睡眠以及其他的必要的时间算上的话&#xff0c;可以发现通过手机上网已经成为了人们日常中的一部分。…

浅谈人工智能技术与物联网结合带来的好处

物联网是指通过互联网和各种技术将设备进行连接&#xff0c;实时采集数据、交互信息的网络&#xff0c;对设备实现智能化自动化感知、识别和控制&#xff0c;给人们带来便利。 人工智能是计算机科学的一个分支&#xff0c;旨在研究和开发能够模拟人类智能的技术和方法。人工智能…

SpringBoot的配置文件以及日志设置

在使用SpringBoot开发的过程中我们通常会用到配置文件来设置配置信息 以及使用日志来进行记录我们的操作&#xff0c;方便我们对错误的定位 配置文件的作用在于&#xff1a;设置端口&#xff0c;设置数据库连接信息&#xff0c;设置日志等等 在SpringBoot中&#xff0c;配置…

【LangChain概念】了解语言链️:第2部分

一、说明 在LangChain的帮助下创建LLM应用程序可以帮助我们轻松地链接所有内容。LangChain 是一个创新的框架&#xff0c;它正在彻底改变我们开发由语言模型驱动的应用程序的方式。通过结合先进的原则&#xff0c;LangChain正在重新定义通过传统API可以实现的极限。 在上一篇博…

SpringBoot携带Jre绿色部署项目

文章目录 SpringBoot携带Jre绿色部署运行项目1. 实现步骤2. 自测项目文件目录及bat文件内容&#xff0c;截图如下&#xff1a;2-1 项目文件夹列表&#xff1a;2-2. bat内容 3. 扩展&#xff1a; 1.6-1.8版本的jdk下载 SpringBoot携带Jre绿色部署运行项目 说明&#xff1a; 实…

【Python】Web学习笔记_flask(5)——会话cookie对象

HTTP是无状态协议&#xff0c;一次请求响应结束后&#xff0c;服务器不会留下对方信息&#xff0c;对于大部分web程序来说&#xff0c;是不方便的&#xff0c;所以有了cookie技术&#xff0c;通过在请求和响应保温中添加cookie数据来保存客户端的状态。 html代码&#xff1a; …

css鼠标样式 cursor: pointer

cursor: none; cursor:not-allowed; 禁止选择 user-select: none; pointer-events:none;禁止触发事件, 该样式会阻止默认事件的发生&#xff0c;但鼠标样式会变成箭头

什么文件传输协议才能保障跨国文件传输安全又稳定

在当今的全球化时代&#xff0c;跨国文件传输是一种常见而又重要的需求&#xff0c;无论是个人还是企业&#xff0c;都需要通过网络来分享和交换各种类型和大小的文件。但是&#xff0c;跨国文件传输也面临着许多挑战和风险&#xff0c;如何选择一个合适的文件传输协议&#xf…

CUDA计算超时(TDR)和阻塞界面问题的处理参考方法

本文提供一种解决单个英伟达独立显卡(终端用户常见的情形)上计算密集导致程序崩溃和电脑界面卡死的问题参考方法,采取降低效率和花费更多时间的思路来解决崩溃和卡顿的问题,即让CPU占有率不是一直100%,也不会因为被TDR机制打断。 如上图,在GPU-Z软件中看到“GPU Load”没…

W5500-EVB-PICO 做UDP Server进行数据回环测试(七)

前言 前面我们用W5500-EVB-PICO 开发板在TCP Client和TCP Server模式下&#xff0c;分别进行数据回环测试&#xff0c;本章我们将用开发板在UDP Server模式下进行数据回环测试。 UDP是什么&#xff1f;什么是UDP Server&#xff1f;能干什么&#xff1f; UDP (User Dataqram P…

自动切换HTTP爬虫ip助力Python数据采集

在Python的爬虫世界里&#xff0c;你是否也被网站的IP封锁问题困扰过&#xff1f;别担心&#xff0c;我来教你一个终极方案&#xff0c;让你的爬虫自动切换爬虫ip&#xff0c;轻松应对各种封锁和限制&#xff01;快来跟我学&#xff0c;让你的Python爬虫如虎添翼&#xff01; 首…

解锁暑假云端生活:铁威马NAS助你打造个性化体验

暑假转眼过半&#xff0c;大家一定度过一段非常美好的时光吧。朋友圈被去各地旅游的、看各种演唱会的、各种各样的观影读后感刷屏...生活很精彩&#xff0c;但如何高效地管理、享受和分享自己的文件、照片和影音内容成为困扰我们的难题。在这方面&#xff0c;铁威马NAS成为了越…

使用python对图像加噪声

加上雨点噪声 import cv2 import numpy as npdef get_noise(img, value10):#生成噪声图像>>> 输入&#xff1a; img图像value 大小控制雨滴的多少 >>> 返回图像大小的模糊噪声图像noise np.random.uniform(0, 256, img.shape[0:2])# 控制噪声水平&#xff…