基础GAN生成式对抗网络(pytorch实验)

(Generative Adversarial Network)

一、理论

https://zhuanlan.zhihu.com/p/307527293?utm_campaign=shareopn&utm_medium=social&utm_psn=1815884330188283904&utm_source=wechat_session
大佬的文章中的“GEN的本质”部分
在这里插入图片描述

二、实验

1、数据集介绍

采用MNIST数据集,如下是训练集中的一张图片
在这里插入图片描述

2、代码

引入包

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

定义生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()  # 输出范围在 -1 到 1 之间
        )

    def forward(self, x):
        return self.net(x)

定义判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出范围在 0 到 1 之间
        )

    def forward(self, x):
        return self.net(x)

训练
先训练判别器,后训练生成器
训练时先训练判别器:将训练集数据(Training Set)打上真标签(1)和生成器(Generator)生成的假图片(Fake image)打上假标签(0)一同组成batch送入判别器(Discriminator),对判别器进行训练。计算loss时使判别器对真数据(Training Set)输入的判别趋近于真(1),对生成器(Generator)生成的假图片(Fake image)的判别趋近于假(0)。此过程中只更新判别器(Discriminator)的参数,不更新生成器(Generator)的参数。

然后再训练生成器:将高斯分布的噪声z(Random noise)送入生成器(Generator),然后将生成器(Generator)生成的假图片(Fake image)打上真标签(1)送入判别器(Discriminator)。计算loss时使判别器对生成器(Generator)生成的假图片(Fake image)的判别趋近于真(1)。此过程中只更新生成器(Generator)的参数,不更新判别器(Discriminator)的参数。

transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


def train_gan(generator, discriminator, dataloader, num_epochs=25):
   criterion = nn.BCELoss()
   optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
   optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

   for epoch in range(num_epochs):
       for i, (imgs, _) in enumerate(dataloader):
           batch_size = imgs.size(0)

           real_imgs = imgs.view(batch_size, -1)  # 将图像展平成一维

           # 标签
           real_labels = torch.ones(batch_size, 1)
           fake_labels = torch.zeros(batch_size, 1)

           # 训练判别器
           outputs = discriminator(real_imgs)
           d_loss_real = criterion(outputs, real_labels)
           real_score = outputs

           z = torch.randn(batch_size, 100)
           fake_imgs = generator(z)
           outputs = discriminator(fake_imgs.detach())
           d_loss_fake = criterion(outputs, fake_labels)
           fake_score = outputs

           d_loss = d_loss_real + d_loss_fake
           optimizer_d.zero_grad()
           d_loss.backward()
           optimizer_d.step()

           # 训练生成器
           outputs = discriminator(fake_imgs)
           g_loss = criterion(outputs, real_labels)

           optimizer_g.zero_grad()
           g_loss.backward()
           optimizer_g.step()

       print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

       if (epoch+1) % 10 == 0:
           with torch.no_grad():
               fake_imgs = generator(torch.randn(64, 100)).view(-1, 1, 28, 28)
               grid = torchvision.utils.make_grid(fake_imgs, nrow=8, normalize=True)
               plt.imshow(grid.permute(1, 2, 0).cpu())
               plt.title(f'Epoch {epoch+1}')
               plt.show()




generator = Generator()
discriminator = Discriminator()

train_gan(generator, discriminator, dataloader)

3、结果

输入一个随机噪声图像,由生成器能得到如下的图片(训练1step的结果)
在这里插入图片描述

输入一个随机噪声图像,由生成器能得到如下的图片(训练10step的结果)
在这里插入图片描述

拓展

可以看看VAE、CGAN等模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/876261.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java | Leetcode Java题解之第403题青蛙过河

题目&#xff1a; 题解&#xff1a; class Solution {public boolean canCross(int[] stones) {int n stones.length;boolean[][] dp new boolean[n][n];dp[0][0] true;for (int i 1; i < n; i) {if (stones[i] - stones[i - 1] > i) {return false;}}for (int i 1…

Oracle 11gR2打PSU补丁详细教程

1 说明 Oracle的PSU&#xff08;Patch Set Update&#xff09;补丁是Oracle公司为了其数据库产品定期发布的更新包&#xff0c;通常每季度发布一次。PSU包含了该季度内收集的一系列安全更新&#xff08;CPU&#xff1a;Critical Patch Update&#xff09;以及一些重要的错误修…

效率神器来了:AI工具手把手教你快速提升工作效能

随着科技的进步&#xff0c;AI工具已经成为提升工作效率的关键手段。本文将介绍一些实用的AI工具和方法&#xff0c;帮助你自动化繁琐的重复性任务、优化数据管理、促进团队协作与沟通&#xff0c;并提升决策质量。 背景&#xff1a;OOP AI-免费问答学习交流-GPT 自动化重复性任…

【Linux】【Vim】Vim 基础

Vim/Gvim 基础 文本编辑基础编辑操作符命令和位移改变文本重复改动Visual 模式移动文本(复制、粘贴)文本对象替换模式 光标移动以 word 为单位移动行首和行尾行内指定单字符移动到匹配的括号光标移动到指定行滚屏简单查找 /string标记 分屏vimdiff 文本编辑 基础编辑 Normal 模…

【网络安全的神秘世界】渗透测试基础

&#x1f31d;博客主页&#xff1a;泥菩萨 &#x1f496;专栏&#xff1a;Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 渗透测试基础 基于功能去进行漏洞挖掘 1、编辑器漏洞 1.1 编辑器漏洞介绍 一般企业搭建网站可能采用了通用模板&#xff…

app抓包 chrome://inspect/#devices

一、前言&#xff1a; 1.首先不支持flutter框架&#xff0c;可支持ionic、taro 2.初次需要翻墙 3.app为debug包&#xff0c;非release 二、具体步骤 1.谷歌浏览器地址&#xff1a;chrome://inspect/#devices qq浏览器地址&#xff1a;qqbrowser://inspect/#devi…

Lombok:Java开发者的代码简化神器【后端 17】

Lombok&#xff1a;Java开发者的代码简化神器 在Java开发中&#xff0c;我们经常需要编写大量的样板代码&#xff0c;如getter、setter、equals、hashCode、toString等方法。这些代码虽然基础且必要&#xff0c;但往往占据了大量开发时间&#xff0c;且容易在属性变更时引发错误…

华为OD机试 - 计算误码率(Python/JS/C/C++ 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试真题&#xff08;Python/JS/C/C&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;私信哪吒&#xff0c;备注华为OD&#xff0c;加入华为OD刷题交流群&#xff0c;…

怎么将几个pdf合成为一个?把几个PDF合并成为一个的8种方法

怎么将几个pdf合成为一个&#xff1f;将多个PDF文件合并成一个整体可以显著提高信息整合的效率&#xff0c;并简化文件的管理与传递。例如&#xff0c;将不同章节的电子书合成一本完整的书籍&#xff0c;或者将多个部门的报告整合成一个统一的文档&#xff0c;可以使处理流程变…

CCS811二氧化碳传感器详解(STM32)

目录 一、介绍 二、传感器原理 1.原理图 2.引脚描述 3.工作原理介绍 三、程序设计 main.c文件 ccs811.h文件 ccs811.c文件 四、实验效果 五、资料获取 项目分享 一、介绍 CCS811模块是一种气体传感器&#xff0c;可以测量环境中TVOC(总挥发性有机物质)浓度和eCO2…

6.接口测试加密接口(Jmeter/工具/函数助手对话框、Beanshell脚本)

一、接口测试加密接口&#xff0c;签名接口 1.加密算法&#xff1a; 可以解密的&#xff1a; 对称式加密&#xff08;私钥加密&#xff09;&#xff1a;AES&#xff0c;DES&#xff0c;Base64 https://www.bejson.com 非对称加密&#xff08;双…

Fisco Bcos 2.11.0通过网络和本地二进制文件搭建单机节点联盟链网络(搭建你的第一个区块链网络)

Fisco Bcos 2.11.0通过网络和本地二进制文件搭建单机节点联盟链网络(搭建你的第一个区块链网络) 文章目录 Fisco Bcos 2.11.0通过网络和本地二进制文件搭建单机节点联盟链网络(搭建你的第一个区块链网络)前言一、Ubuntu依赖安装二、创建操作目录, 下载build_chain.sh脚本2.1 先…

Linux-Swap分区使用与扩容

一、背景 在Linux系统中&#xff0c;swap空间&#xff08;通常称为swap分区&#xff09;是一个用于补充内存资源的重要组件。当系统的物理RAM不足时&#xff0c;Linux会将一部分不经常使用的内存页面移动到硬盘上的swap空间中&#xff0c;这个过程被称为分页&#xff08;paging…

【JavaEE初阶】多线程(4)

欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 目录 线程安全的 第四个原因 代码举例: 分析原因 解决方法 方法1 方法2 wait(等待)和notify(通知) wait和sleep区别 线程安全的 第四个原因 内存可见性,引起的线程安全问…

CloudXR 套件扩展 XR 工作流

NVIDIA为开发者提供了一个先进的平台&#xff0c;开发者可以在该平台上使用全新NVIDIA CloudXR 套件来创建可扩展、品牌化的定制扩展现实&#xff08;XR&#xff09;产品。 NVIDIA CloudXR 套件基于全新架构而打造&#xff0c;是扩展XR生态的重要工具。它为开发者、专业人士和…

彻底理解浅拷贝和深拷贝

目录 浅拷贝实现 深拷贝实现自己手写 浅拷贝 浅拷贝是指创建一个新对象&#xff0c;这个对象具有原对象属性的精确副本 基本数据类型&#xff08;如字符串、数字等&#xff09;&#xff0c;在浅拷贝过程中它们是通过值传递的&#xff0c;而不是引用传递&#xff0c;修改值并不…

Git项目管理工具

分布式版本控制系统

数据集 wider person 户外密集行人检测 >> DataBall

数据集 wider person 用于野外密集行人检测的多样化数据集 行人检测 目标检测 户外密集行人检测的多样化数据集 WiderPerson: A Diverse Dataset for Dense Pedestrian Detection in the Wild article{zhang2019widerperson, Author {Zhang, Shifeng and Xie, Yiliang and Wa…

常用环境部署(二十)——docker部署OpenProject

一、安装Docker及Docker-compose https://blog.csdn.net/wd520521/article/details/112609796 二、docker拉取OpenProject镜像 1、拉取镜像 docker pull openproject/openproject:14 注意&#xff1a; 拉取镜像的时候会有超时的现象出现&#xff0c;大家重新拉取几次就行…

链式二叉树的基本操作(C语言版)

目录 1.二叉树的定义 2.创建二叉树 3.递归遍历二叉树 1&#xff09;前序遍历 2&#xff09;中序遍历 3&#xff09;后序遍历 4.层序遍历 5.计算节点个数 6.计算叶子节点个数 7.计算第K层节点个数 8.计算树的最大深度 9.查找值为x的节点 10.二叉树的销毁 从二叉树…