【深度学习】gan网络原理生成对抗网络

【深度学习】gan网络原理生成对抗网络

GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
在这里插入图片描述

1.下载数据并对数据进行规范

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。

2.生成器的代码

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
               nn.Linear(100, 256),
               nn.ReLU(),
               nn.Linear(256, 512),
               nn.ReLU(),
               nn.Linear(512, 28*28),
               nn.Tanh()
           )

    def forward(self, x):
        img = self.main(x)
        img = img.reshape(-1, 28, 28)
        return img

这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。

3.判别器的代码

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
                    nn.Linear(28*28, 512),
                    nn.LeakyReLU(),
                    nn.Linear(512, 256),
                    nn.LeakyReLU(),
                    nn.Linear(256, 1),
                    nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x

这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。

4. 定义损失函数和优化函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。

5.定义绘图函数

def gen_img_plot(model,test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i]+1)/2)
        plt.axis('off')
    plt.show()

6. 开始训练,并显示出生成器所产生的图像

test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)               # 获得用于训练的mnist图像
        size = img.size(0)                 # 获得1批次数据量大小
        
    # 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据
    random_noise = torch.randn(size, 100, device=device)

    ########################### 先训练判别器 #############################
    dis_opt.zero_grad()
    real_output = dis(img)
    d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实值的loss,也即是真图片与1标签的损失
    d_real_loss.backward()

    gen_img = gen(random_noise)
    fake_output = dis(gen_img.detach())
    d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失
    d_fake_loss.backward()

    d_loss = d_real_loss + d_fake_loss
    dis_opt.step()

    ########################### 下面再训练生成器 #############################
    gen_opt.zero_grad()
    fake_output = dis(gen_img)
    g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
    g_loss.backward()
    gen_opt.step()
    #########################################################################
   with torch.no_grad():
        d_epoch_loss += d_loss
        g_epoch_loss += g_loss
with torch.no_grad():
    d_epoch_loss /= count
    g_epoch_loss /= count
    D_loss.append(d_epoch_loss)
    G_loss.append(g_epoch_loss)
    print('epoch:', epoch)
    gen_img_plot(gen, test_input)

1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。

2.开始 30 轮次的训练循环。在每一轮中:

3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。

4.将图像数据移动到设备(device)上,并获取批次大小。

5.生成随机噪声,作为输入给生成器。

6.训练判别器(D):

  • 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
  • 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
  • 计算总的判别器损失 d_loss,并更新判别器的参数。

7.训练生成器(G):

  • 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
  • 更新生成器的参数。

这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/207226.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用Python进行gRPC接口测试(一)

前言 gRPC 是一个高性能、通用的开源RPC框架,其由 Google 主要面向移动应用开发并基于HTTP/2 协议标准而设计,基于 ProtoBuf(Protocol Buffers) 序列化协议开发,且支持众多开发语言。 自gRPC推出以来,已经广泛应用于各种服务之中…

UI自动化测试工具有哪些优势?

UI自动化测试工具通过提高测试效率、覆盖率,减少测试时间和成本,以及支持持续集成等方式,为软件开发团队提供了一系列重要的优势,有助于提升软件质量和开发效率。 自动化执行:UI自动化测试工具可以模拟用户与应用程序的…

ubuntu22下使用nvidia 2080T显卡部署pytorch

1.直接到NVIDA官网下载相应的驱动,然后安装官方驱动 | NVIDIA 2.下载相应版本cuda,并安装,安装时不安装驱动 3.conda install pytorch2.1.0 torchvision0.16.0 torchaudio2.1.0 pytorch-cuda12.1 -c pytorch -c nvidia 安装pytorch。 安装…

Qt应用开发--国产工业开发板全志T113-i的部署教程

Qt在工业上的使用场景包括工业自动化、嵌入式系统、汽车行业、航空航天、医疗设备、制造业和物联网应用。Qt被用来开发工业设备的用户界面、控制系统、嵌入式应用和其他工业应用,因其跨平台性和丰富的功能而备受青睐。 Qt能够为工业领域带来什么好处: -…

scratch《贪吃蛇》改编版——设计方案

一、设计思路 设计想法来自《贪吃蛇》游戏改编。《贪吃蛇》游戏的背景源自古老的瑞典神话,讲述一条巨蛇在世间蔓延,吞噬一切的传说。游戏的玩法很简单,玩家通过上下左右键控制蛇的方向,使其在地图上移动并吞噬食物,随…

Android 12 及以上授权精确位置和模糊位置

请求位置信息权限 为了保护用户隐私,使用位置信息服务的应用必须请求位置权限。 请求位置权限时,请遵循与请求任何其他运行时权限相同的最佳做法。请求位置权限时的一个重要区别在于,系统中包含与位置相关的多项权限。具体请求哪项权限以及…

创新药集采中选后是否还需进行学术营销推广模式?

药品集采中选给药企带来了大幅降价和以价换量的趋势,同时也成为药企争夺的关键领域。对于是否继续进行推广活动,存在不同的观点和认知。 ▼需要还是不需要? 一些人认为集采后不再需要推广,因为药企无法承担高昂的销售费用&#x…

11.29 知识回顾(视图层、模板层)

一、视图层 1.1 响应对象 响应---》本质都是 HttpResponse -HttpResponse---》字符串 -render----》放个模板---》模板渲染是在后端完成 -js代码是在客户端浏览器里执行的 -模板语法是在后端执行的 -redirect----》重定向 -字符串参数不是…

Java实现socket编程案例

以下是一个基本的Java socket编程案例: 服务端代码: import java.net.*; import java.io.*;public class Server {public static void main(String[] args) throws IOException {ServerSocket serverSocket null;try {serverSocket new ServerSocket…

深度学习今年来经典模型优缺点总结,包括卷积、循环卷积、Transformer、LSTM、GANs等

文章目录 1、卷积神经网络(Convolutional Neural Networks,CNN)1.1 优点1.2 缺点1.3 应用场景1.4 网络图 2、循环神经网络(Recurrent Neural Networks,RNNs)2.1 优点2.2 缺点2.3 应用场景2.4 网络图 3、长短…

生产制造中4种导致产品成本、库存核算差错的问题!(化工/化妆品/生物制剂/混凝土等行业ODOO)

在化工/化妆品/生物制剂/混凝土等行业,因为其生产物料及产成品大都以液体(或散颗粒)形态为主,多以重量为计数方式;且液体(或散颗粒)相较于固体的较大区别就是产品计数上变数较大,固体…

什么是企业资金

我从两个方面来诠释企业资金管理: 1、企业资金管理是什么? 2、企业资金管理包括什么? 一、企业资金管理是什么? 众所周知,每个企业都有对应的财务部门,专门负责管理企业的“钱”,和企业的“帐…

企业软件手机app定制开发趋势|小程序网站搭建

企业软件手机app定制开发趋势|小程序网站搭建 随着移动互联网的快速发展和企业数字化转型的加速,企业软件手机App定制开发正成为一个新的趋势。这种趋势主要是由于企业对于手机App的需求增长以及现有的通用应用不能满足企业特定需求的情况下而产生的。 1.企业软件手…

【数据结构】—AVL树(C++实现)

🎬慕斯主页:修仙—别有洞天 💜本文前置知识: 搜索二叉树 ♈️今日夜电波:Letter Song—ヲタみん 1:36━━━━━━️💟──────── 5:35 …

Ubuntu 20.04 for NVIDIA V100 GPU安装手册

安装Ubuntu 20.04.3 LTS版本 image.png 安装Ubuntu 20.04按照安装提示,仔细选择每一项,基本默认即可。 系统中查看GPU信息 系统安装完成之后,进入系统,使用lspci 命令查询一下GPU是否存在、型号信息是什么。 bpangbobpang:\~$…

C语言中一些有关字符串的常见函数的使用及模拟实现(2)

在编程的过程中,我们经常要处理字符和字符串,为了⽅便操作字符和字符串,C语⾔标准库中提供了\n⼀系列库函数,接下来我们就学习⼀下这些函数。 在上一篇博客中已经讲解了strlen,strcpy,strcmp,st…

XXL-Job详解(二):安装部署

目录 前言环境下载项目调度中心部署执行器部署 前言 看该文章之前,最好看一下之前的文章,比较方便我们理解 XXL-Job详解(一):组件架构 环境 Maven3 Jdk1.8 Mysql5.7 下载项目 源码仓库地址链接: https://github.…

el-drawer抽屉组件弹窗遮挡问题解决

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 1、根据需要,需要在下面窗口里弹出抽屉组件,但出现遮挡问题,如下&…

阿里云刚崩完又崩了?部分地域云数据库控制台访问异常

11月27日,阿里云发布公告:您好!北京时间2023年11月27日 09:16起,阿里云监控发现北京、上海、杭州、深圳、青岛 、香港以及美东、美西地域的数据库产品(RDS、PolarDB、Redis等)的控制台和OpenAPI访问出现异常…

【Openstack Train安装】一、虚拟机创建

Openstack是一个云平台管理的项目,它不是一个软件。这个项目由几个主要的组件组合起来完成一些具体的工作。Openstack是一个旨在为公共及私有云的建设与管理提供软件的开源项目。它的社区拥有超过130家企业及1350位开发者,这些机构与个人将 Openstack作为…