深度学习第5天:GAN生成对抗网络

Image Description

☁️主页 Nowl

🔥专栏 《深度学习》

📑君子坐而论道,少年起而行之

​​

在这里插入图片描述

文章目录

  • 一、GAN
    • 1.基本思想
    • 2.用途
    • 3.模型架构
  • 二、具体任务与代码
    • 1.任务介绍
    • 2.导入库函数
    • 3.生成器与判别器
    • 4.预处理
    • 5.模型训练
    • 6.图片生成
    • 7.不同训练轮次的结果对比

一、GAN

1.基本思想

想象一下,市面上有许多仿制的画作,人们为了辨别这些伪造的画,就会提高自己的鉴别技能,然后仿制者为了躲过鉴别又会提高自己的伪造技能,这样反反复复,两个群体的技能不断得到提高,这就是GAN的基本思想

2.用途

我们知道GAN的全名是生成对抗网络,那么它就是以生成为主要任务,所以可以用在这些方面

  • 生成虚拟数据集,当数据集数量不够时,我们可以用这种方法生成数据
  • 图像清晰化,可以将模糊图片清晰化
  • 文本到图像的生成,可以训练文生图模型

GAN的用途还有很多,可以在学习过程中慢慢发现

3.模型架构

GAN的主要结构包含一个生成器和一个判别器,我们先输入一堆杂乱数据(被称为噪声)给生成器,接着让判别器将生成器生成的数据与真实的数据作对比,看是否能判别出来,以此往复训练

在这里插入图片描述

二、具体任务与代码

1.任务介绍

相信很多人都对手写数字数据集不陌生了,那我们就训练一个生成手写数字的GAN,注意:本示例代码需要的运行时间较长,请在高配置设备上运行或者减少训练回合数

在这里插入图片描述

2.导入库函数

先导入必要的库函数,包括torch用来处理神经网络方面的任务,numpy用来处理数据

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import numpy as np

3.生成器与判别器

使用torch定义生成器与判别器的基本结构,这里由于任务比较简单,只用定义线性层就行,再给线性层添加相应的激活函数就行了

# 定义生成器(Generator)和判别器(Discriminator)的简单网络结构
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, noise):
        return self.model(noise)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, image):
        return self.model(image)

4.预处理

这一部分定义了模型参数,加载了数据集,定义了损失函数与优化器,这些是神经网络训练时的一些基本参数

# 定义一些参数
batch_size = 100
learning_rate = 0.0002
epochs = 500

# 加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

# 初始化生成器、判别器和优化器
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 损失函数
criterion = nn.BCELoss()

5.模型训练

这一部分开始训练模型,通过反向传播逐步调整模型的参数,注意模型训练的过程,观察生成器和判别器分别是怎么在训练中互相作用不断提高的

# 训练 GAN
for epoch in range(epochs):
    for data, _ in data_loader:
        data = data.view(data.size(0), -1)
        real_data = Variable(data)
        target_real = Variable(torch.ones(data.size(0), 1))
        target_fake = Variable(torch.zeros(data.size(0), 1))

        # 训练判别器
        optimizer_D.zero_grad()
        output_real = discriminator(real_data)
        loss_real = criterion(output_real, target_real)
        loss_real.backward()

        noise = Variable(torch.randn(data.size(0), 100))
        fake_data = generator(noise)
        output_fake = discriminator(fake_data.detach())
        loss_fake = criterion(output_fake, target_fake)
        loss_fake.backward()

        optimizer_D.step()

        # 训练生成器
        optimizer_G.zero_grad()
        output = discriminator(fake_data)
        loss_G = criterion(output, target_real)
        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{epochs}], Loss D: {loss_real.item()+loss_fake.item()}, Loss G: {loss_G.item()}')

6.图片生成

这一部分再一次随机生成了一些噪声,并把他们传入生成器生成图片,其中包含一些格式转化过程,再通过matplotlib绘图库显示结果

# 生成一些图片
num_samples = 16
noise = Variable(torch.randn(num_samples, 100))
generated_samples = generator(noise)
generated_samples = generated_samples.view(num_samples, 1, 28, 28).detach()

import matplotlib.pyplot as plt
import torchvision.utils as vutils

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(
    np.transpose(
        vutils.make_grid(generated_samples, nrow=4, padding=2, normalize=True).cpu(), (1, 2, 0)
    )
)
plt.show()

7.不同训练轮次的结果对比

在这里插入图片描述

在这里插入图片描述

感谢阅读,觉得有用的话就订阅下《深度学习》专栏吧,有错误也欢迎指出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/246758.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

KylinV10 将项目上传至 Github

KylinV10 将项目上传至 Github 银河麒麟操作系统 V10 是在 Ubuntu 的基础上开发的,所以适用于 Ubuntu 的也适用于 KylinV10 一般上传至 GitHub,有两种方式,一种是 HTTPS,一种是 SSH,但是在 KylinV10 操作系统 HTTPS 的…

大数据----31.hbase安装启动

二.Hbase安装 先前安装: Zookeeper 正常部署 首先保证 Zookeeper 集群的正常部署,并启动之。 三台机器都执行:zkServer.sh startHadoop 正常部署 Hadoop 集群的正常部署并启动。 主节点上进行 :start-all.sh 1.HBase 的获取 一定…

【Python网络爬虫入门教程1】成为“Spider Man”的第一课:HTML、Request库、Beautiful Soup库

Python 网络爬虫入门:Spider man的第一课 写在最前面背景知识介绍蛛丝发射器——Request库智能眼镜——Beautiful Soup库 第一课总结 写在最前面 有位粉丝希望学习网络爬虫的实战技巧,想尝试搭建自己的爬虫环境,从网上抓取数据。 前面有写一…

c#读取XML文件实现晶圆wafermapping显示demo计算电机坐标以便控制电机移动

c#读取XML文件实现晶圆wafermapping显示 功能: 1.读取XML文件,显示mapping图 2.在mapping视图图标移动,实时查看bincode,x,y索引与计算的电机坐标 3.通过设置wafer放在平台的位置x,y轴电机编码值,相机在wafer的中心位置&#…

IDEA删除最近打开的文件记录

IDEA删除最近打开的文件记录 遇见问题:如何删除IDEA中最近打开的文件记录 解决方法 先关闭IDEA 找到 recentProjects.xml 文件 windows 位置:(AppData是隐藏文件夹) 1.C:\Users\电脑用户名\AppData\Roaming\JetBrains\IntelliJIde…

jpa 修改信息拦截

实现目标springbootJPA 哪个人,修改了哪个表的哪个字段,从什么值修改成什么值 Component // 必须加 Slf4j Configurable(autowire Autowire.BY_TYPE) public class AuditingEntityListener {// 线程变量,保存修改前的 objectprivate Thre…

SpringBoot运维中的高级配置

🙈作者简介:练习时长两年半的Java up主 🙉个人主页:程序员老茶 🙊 ps:点赞👍是免费的,却可以让写博客的作者开心好久好久😎 📚系列专栏:Java全栈,…

Navicat16 无限试用 亲测有效

Navicat16 无限试用 亲测有效 亲测有效!!! 吐槽下,有的用不了,有的是图片,更甚者还有收费的,6的一批 粘贴下面的代码,保存到桌面,命名为 trial-navicat16.bat echo off…

Web安全-SQL注入常用函数(二)

★★实战前置声明★★ 文章中涉及的程序(方法)可能带有攻击性,仅供安全研究与学习之用,读者将其信息做其他用途,由用户承担全部法律及连带责任,文章作者不承担任何法律及连带责任。 1、MySQL数据库构成 初始化安装MySQL数据库后(…

从零开始搭建企业管理系统(七):RBAC 之用户管理

RBAC 之用户管理 创建表(Entity)用户表角色表权限表用户角色表关系注解ManyToMany 角色权限表 接口开发UserControllerUserServiceUserServiceImplUserRepository 问题解决update 更新问题懒加载问题JSON 循环依赖问题 根据上一小结对表的设计&#xff0…

基于ssm高校食堂订餐系统论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本高校食堂订餐系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息…

服务器数据恢复-EqualLogic PS存储硬盘坏道导致存储不可用的数据恢复案例

服务器数据恢复环境: 一台DELL EqualLogic PS系列存储,存储中有一组由16块SAS硬盘组成的RAID5。上层是VMFS文件系统,存放虚拟机文件。存储上层分了4个卷。 服务器故障&检测: 存储上有2个硬盘指示灯显示黄色,磁盘出…

51单片机的外部中断的以及相关寄存器的讲解

中断系统 本文主要涉及8051单片机的中断系统的讲解与使用 其中包括中断相关寄存器的介绍与使用以及外部中断初始化的代码分析。 文章目录 中断系统一、 中断的介绍二、 中断结构及相关寄存器2.1 中断源 2.2 中断请求控制器2.2.1 TCON寄存器2.2.2 SCON寄存器2.2.3 中断允许寄存器…

Spark与PySpark(1.概述、框架、模块)

目录 1.Spark 概念 2. Hadoop和Spark的对比 3. Spark特点 3.1 运行速度快 3.2 简单易用 3.3 通用性强 3.4 可以允许运行在很多地方 4. Spark框架模块 4.1 Spark Core 4.2 SparkSQL 4.3 SparkStreaming 4.4 MLlib 4.5 GraphX 5. Spark的运行模式 5.1 本地模式(单机) Local运行模…

智能优化算法应用:基于模拟退火算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于模拟退火算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于模拟退火算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.模拟退火算法4.实验参数设定5.算法结果6.…

2036开关门,1109开关门

一:2036开关门 1.1题目 1.2思路 1.每次都是房间号是服务员的倍数的时候做处理,所以外层(i)枚举服务员1~n,内层(j)枚举房间号1~n,当j % i0时,做处理 2.这个处理指的是&…

module ‘tensorflow‘ has no attribute XXX 报错解决

问题描述: 粘了别人的tensorflow项目,运行总是报错module ‘tensorflow’ has no attribute什么什么 问题解决: 导入tensorflow的代码如下 import tensorflow as tf此时,某个某块报错,比如下面这个 那么就直接把tf.…

【【ZYNQ 7020显示 图片 实验 】】

ZYNQ 7020显示 图片 实验 关键配置 BRAM 因为本次 我想显示的 图片是 400*400 所以在 内部 的 ROM 存储单元选择 了160000 ZYNQ7020的内部资源 最多是 大概 200000左右的 大小 大家可以根据 资源选择合适的像素 此处存放 内部的 图片转文字的COE文件 PLL设置 我选用的是按…

言简意赅的 el-table 跨页多选

步骤一 在<el-table>中:row-key"getRowKeys"和selection-change"handleSelectionChange" 在<el-table-column>中type"selection"那列&#xff0c;添加:reserve-selection"true" <el-table:data"tableData"r…

第二证券:京沪楼市松绑,地产板块强势拉升,京能置业等涨停

地产板块15日盘中强势拉升&#xff0c;到发稿&#xff0c;上实开展、京能置业、大龙地产、新黄浦等涨停&#xff0c;京投开展涨逾6%&#xff0c;保利开展、招商蛇口等涨超3%。 消息面上&#xff0c;12月14日&#xff0c;北京发布调整优化一般住所标准和个人住所告贷政策的告诉…