学习笔记:Pytorch利用MNIST数据集训练生成对抗网络(GAN)

2023.8.27

       在进行深度学习的进阶的时候,我发了生成对抗网络是一个很神奇的东西,为什么它可以“将一堆随机噪声经过生成器变成一张图片”,特此记录一下学习心得。

一、生成对抗网络百科

        2014年,还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》(网址: https://arxiv.org/abs/1406.2661),将生成对抗网络引入 深度学习领域。2016年,GAN热潮席卷AI领域顶级会议, 从ICLR到NIPS,大量高质量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。

机器学习的模型可大体分为两类,生成模型( Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 ,通过某种模型来 预测 。生成模型是给定某种隐含信息,来随机产生观 测数据。

GAN百科:

GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)_打灰人的博客-CSDN博客

二、GAN代码

训练代码:

                epoch=1000时的效果就不错啦

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


class Generator(nn.Module):  # 生成器
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


class Discriminator(nn.Module):  # 判别器
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity


def gen_img_plot(model, test_input):
    pred = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((pred[i] + 1) / 2)
        plt.axis('off')
    plt.show(block=False)
    plt.pause(3)  # 停留0.5s
    plt.close()


# 调用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 超参数设置
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 1000

# 数据集载入和数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 训练数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 测试数据 torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数
# test_data = torch.randn(batch_size, latent_dim).to(device)
test_data = torch.FloatTensor(batch_size, latent_dim).to(device)

# 实例化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 开始训练模型
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        batch_size = imgs.shape[0]
        real_imgs = imgs.to(device)

        # 训练判别器
        z = torch.FloatTensor(batch_size, latent_dim).to(device)
        z.data.normal_(0, 1)
        fake_imgs = generator(z)  # 生成器生成假的图片

        real_labels = torch.full((batch_size, 1), 1.0).to(device)
        fake_labels = torch.full((batch_size, 1), 0.0).to(device)

        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        z.data.normal_(0, 1)
        fake_imgs = generator(z)

        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        torch.save(generator.state_dict(), "Generator_mnist.pth")

    print(f"Epoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")

# gen_img_plot(Generator, test_data)
gen_img_plot(generator, test_data)

测试代码:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')


class Generator(nn.Module):  # 生成器
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


# test_data = torch.FloatTensor(128, 100).to(device)
test_data = torch.randn(128, 100).to(device)  # 随机噪声

model = Generator(100).to(device)
model.load_state_dict(torch.load('Generator_mnist.pth'))
model.eval()

pred = np.squeeze(model(test_data).detach().cpu().numpy())

for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow((pred[i] + 1) / 2)
    plt.axis('off')
plt.savefig(fname='image.png', figsize=[5, 5])
plt.show()

三、结果

       在超参数设置 epoch=1000,batch_size=128,lr=0.0001,latent_dim = 100 时,gan生成的权重测的结果如图所示

四,GAN的损失函数曲线

                一开始训练时,我的gan的损失函数的曲线是类似这样的,就是知乎这文章里一样,生成器损失函数的曲线一直发散。首先,这个loss的曲线一看就是网络崩了,一般正常的情况,d_loss的值会一直下降然后收敛,而g_loss的曲线会先增大后减少,最后同样也会收敛。其次,网络拿到手以后先不要训练太多次,容易出现过拟合的情况。

生成对抗网络的损失函数图像如下合理吗? - 知乎

这是训练了10轮的生成器和鉴别器的损失函数值变化吧:

效果如图所示: 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/96307.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java虚拟机内部组成

1、栈区 public class Math {public int compute(){//一个方法对应一块栈帧内存区域int a l;int b 2;int c (a b)*10;return c; } public static void main(String[] args){Math math new, Math() ;math.compute() ;System.out.println("test");}} 栈是先进后出…

centos安装MySQL 解压版完整教程(按步骤傻瓜式安装

一、卸载系统自带的 Mariadb 查看: rpm -qa|grep mariadb 卸载: rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…

使用CSS的@media screen 规则为不同的屏幕尺寸设置不同的样式(响应式图片布局)

当你想要在不同的屏幕尺寸或设备上应用不同的CSS样式时,可以使用 media 规则,特别是 media screen 规则。这允许你根据不同的屏幕特性,如宽度、高度、方向等,为不同的屏幕尺寸设置不同的样式。 具体来说,media screen…

常见前端面试之VUE面试题汇总十一

31. Vuex 有哪几种属性? 有五种,分别是 State、 Getter、Mutation 、Action、 Module state > 基本数据(数据源存放地) getters > 从基本数据派生出来的数据 mutations > 提交更改数据的方法,同步 actions > 像一个装饰器&a…

微信小程序请求接口返回的二维码(图片),本地工具和真机测试都能显示,上线之后不显示问题

请求后端接口返回的图片&#xff1a; 页面展示&#xff1a; 代码实现&#xff1a; :show-menu-by-longpress"true" 是长按保存图片 base64Code 是转为base64的地址 <image class"code" :src"base64Code" alt"" :show-menu-by-long…

C语言练习4(巩固提升)

C语言练习4 选择题 前言 面对复杂变化的世界&#xff0c;人类社会向何处去&#xff1f;亚洲前途在哪里&#xff1f;我认为&#xff0c;回答这些时代之问&#xff0c;我们要不畏浮云遮望眼&#xff0c;善于拨云见日&#xff0c;把握历史规律&#xff0c;认清世界大势。 选择题 …

考研数据结构:第八章 排序

文章目录 一、排序的基本概念二、插入排序2.1插入排序2.1.1算法思想2.1.2算法实现2.1.3算法效率分析2.1.4算法优化——折半插入排序 2.2希尔排序2.2.1算法思想2.2.2代码实现2.2.3算法性能分析 三、交换排序3.1冒泡排序3.1.1算法思想3.1.2代码实现3.1.3算法性能分析 3.2快速排序…

基于Spring Boot 的 Ext JS 应用框架之coworkee

Ext JS 官方提供了一个人员管理的完整应用框架 - coworkee。该框架的显示如下: 该框架的布局特点如下: 布局方式: 左右布局, 左侧导航栏默认收合特点:左侧导航区占用空间小, 工作区较大, 适合没有二级导航栏,工作区需要显示的内容较多的系统。如果导航栏是横向底部,就…

DataTable扩展 列转行方法(2*2矩阵转换)

源数据 如图所示 // <summary>/// DataTable扩展 列转行方法&#xff08;2*2矩阵转换&#xff09;/// </summary>/// <param name"dtSource">数据源</param>/// <param name"columnFilter">逗号分隔 如SDateTime,PM25,PM10…

民族传统文化分享系统uniapp 微信小程序

管理员、用户可通过Android系统手机打开系统&#xff0c;注册登录后可进行管理员后端&#xff1b;首页、个人中心、用户管理、知识分类管理、知识资源管理、用户分享管理、意见反馈、系统管理&#xff0c;用户前端&#xff1b;首页、知识资源、用户分享、我的等。 本系统的使用…

vue create -p dcloudio/uni-preset-vue my-project创建文件报错443

因为使用vue3viteuniappvant4报错&#xff0c;uniapp暂不支持vant4&#xff0c;所以所用vue2uniappvant2 下载uni-preset-vue-master 放到E:\Auniapp\uni-preset-vue-master 在终端命令行创建uniapp vue create -p E:\Auniapp\uni-preset-vue-master my-project

【广州华锐互动】VR党建多媒体互动展厅:随时随地开展党史教育

随着科技的不断发展&#xff0c;虚拟现实(VR)技术已经逐渐渗透到各个领域&#xff0c;其中党建教育尤为受益。为了更好地传承红色基因&#xff0c;弘扬党的优良传统&#xff0c;广州华锐互动推出了VR党建多媒体互动展厅&#xff0c;让广大党员干部和人民群众通过现代科技手段&a…

[ 云计算 | AWS ] Java 应用中使用 Amazon S3 进行存储桶和对象操作完全指南

文章目录 一、前言二、所需 Maven 依赖三、先决必要的几个条件信息四、创建客户端连接五、Amazon S3 存储桶操作5.1. 创建桶5.2. 列出桶 六、Amazon S3 对象操作6.1. 上传对象6.2. 列出对象6.3. 下载对象6.4. 复制、重命名和移动对象6.5. 删除对象6.6. 删除多个对象 七、文末总…

【LeetCode-中等题】114. 二叉树展开为链表

文章目录 题目方法一&#xff1a;前序遍历&#xff08;构造集合&#xff09; 集合&#xff08;构造新树&#xff09;方法二&#xff1a;原地构建方法三&#xff1a;前序遍历--迭代&#xff08;构造集合&#xff09; 集合&#xff08;构造新树&#xff09; 题目 方法一&#x…

Axure设计之日期选择器(年月选择)

在系统中&#xff0c;日期选择器经常会用到&#xff0c;包括日历日期的选择、日期时间的选择和日期范围的选择&#xff0c;一般是下拉列表的形式进行选择。Axure没有自带的日期选择器&#xff0c;下面教大家如何在Axure中制作真实日期选择&#xff08;年月选择&#xff09;效果…

2024毕业设计选题指南【附选题大全】

title: 毕业设计选题指南 - 如何选择合适的毕业设计题目 date: 2023-08-29 categories: 毕业设计 tags: 选题指南, 毕业设计, 毕业论文, 毕业项目 - 如何选择合适的毕业设计题目 当我们站在大学生活的十字路口&#xff0c;毕业设计便成了我们面临的一项重要使命。这不仅是对我们…

数组(个人学习笔记黑马学习)

一维数组 1、定义方式 #include <iostream> using namespace std;int main() {//三种定义方式//1.int arr[5];arr[0] 10;arr[1] 20;arr[2] 30;arr[3] 40;arr[4] 50;//访问数据元素/*cout << arr[0] << endl;cout << arr[1] << endl;cout &l…

GitHub Copilot三连更:能在代码行里直接提问,上下文范围扩展到终端

量子位 | 公众号 QbitAI 就在昨晚&#xff0c;GitHub Copilot迎来了一波不小的更新。 包括&#xff1a; 全新交互体验——代码行中直接召唤聊天功能&#xff0c;不用切界面&#xff0c;主打一个专注&#xff1b; 改善斜杠命令&#xff0c;一键删除&#xff0c;主打快捷操作、…

2023年高教社杯数学建模思路 - 案例:最短时间生产计划安排

文章目录 0 赛题思路1 模型描述2 实例2.1 问题描述2.2 数学模型2.2.1 模型流程2.2.2 符号约定2.2.3 求解模型 2.3 相关代码2.4 模型求解结果 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 最短时…