生成对抗网络——GAN深度卷积实现(代码+理解)

        本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。

  生成对抗网络—GAN(代码+理解)

http://t.csdnimg.cn/HDfLOicon-default.png?t=N7T8http://t.csdnimg.cn/HDfLO


目录

一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

2. 模型训练时

3. 优化器定义

4. 训练数据

5. 模型结构

(1)生成器        

(2)判别器


一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

import torch
import torch.nn as nn
import argparse
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

# 加载数据
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./others/",
        train=False,
        download=False,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02) # 给定均值和标准差的正态分布N(mean,std)中生成值
        torch.nn.init.constant_(m.bias.data, 0.0)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4  # 原为28*28,现为32*32,两边各多了2
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),    # 调整数据的分布,使其 更适合于 下一层的 激活函数或学习
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # 下采样(图片进行 4次卷积操作,变为ds_size * ds_size尺寸大小)
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity

# 实例化
generator = Generator()
discriminator = Discriminator()

# 初始化参数
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# 交叉熵损失函数
adversarial_loss = torch.nn.BCELoss()

def gen_img_plot(model, epoch, text_input):
    prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16])
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()

# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):
    # 初始化损失值
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)  # 返回批次数
    for i, (imgs, _) in enumerate(dataloader):
        valid = torch.ones(imgs.shape[0], 1)
        fake = torch.zeros(imgs.shape[0], 1)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        z = torch.randn(imgs.shape[0], opt.latent_dim)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # batches_done = epoch * len(dataloader) + i
        # if batches_done % opt.sample_interval == 0:
        #     save_image(gen_imgs.data[:25], "others/images/%d.png" % batches_done, nrow=5, normalize=True)

        # 累计每一个批次的loss
        with torch.no_grad():
            D_epoch_loss += d_loss
            G_epoch_loss += g_loss

        # 求平均损失
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss_.append(D_epoch_loss.item())
        G_loss_.append(G_epoch_loss.item())

        text_input = torch.randn(opt.batch_size, opt.latent_dim)
        gen_img_plot(generator, epoch, text_input)


x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss','D_loss'])
plt.show()

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

        函数 weights_init_normal 用于初始化 模型参数,为什么要 以 均值和标准差 的正态分布中采样的数 为标准?

2. 模型训练时

        这里“d_loss = (real_loss + fake_loss) / 2” 中的 “/ 2” 操作,在 实际训练中 有什么作用?

        由(real_loss + fake_loss) / 2的 得到 的 d_loss 与(real_loss+fake_loss)得到的 d_loss 进行 回溯,两者结果会 有什么不同吗?

3. 优化器定义

        设置 betas=(opt.b1, opt.b2) 有什么 实际的作用?通俗易懂的讲一下

        betas=(opt.b1, opt.b2) 是怎样 更新学习率的?

4. 训练数据

        这里我们用的data为 MNIST,为什么img_size设置为 32,不是 28?

5. 模型结构

(1)生成器        

        解释一下为什么是“Upsample, Conv2d, BatchNorm2d, LeakyReLU ”这种顺序?

(2)判别器

        模型的 基本 运算步骤是什么?其中为什么需要 “Dropout2d( p=0.25, inplace=False)”这一步?

        关于“ds_size” 和 “128 * ds_size ** 2”的实际意义?


                                后续更新 GAN的其他模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/722776.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SSM整合-前后端分离(实现增删改查)

实现增删改查 实现功能03-添加家居信息需求分析/图解思路分析代码实现注意事项和细节 实现功能04-显示家居信息需求分析/图解思路分析代码实现 实现功能05-修改家居信息需求分析/图解思路分析代码实现注意事项和细节 实现功能06-删除家居信息需求分析/图解思路分析代码实现课后…

STM32学习笔记(八)--DMA直接存储器存取详解

(1)配置步骤1.配置RCC外设时钟 开启DMA外设2.初始化DMA外设 调用DMA_Init 外设存储器站点的起始地址 数据宽度 地址是否自增 方向 传输计数器 是否需要自动重装 选择触发源 通道优先级3.开启DMA控制 4.开启触发信号输出(如果需要硬件触发&…

Termius for Mac/Win:跨平台多协议远程管理利器

Termius for Mac/Win是一款备受瞩目的跨平台多协议远程管理软件,以其卓越的性能、丰富的功能和便捷的操作体验,赢得了广大用户的青睐。无论是在企业IT管理、系统维护,还是个人远程连接、文件传输等方面,Termius都展现出了出色的实…

Python神经影像数据的处理和分析库之nipy使用详解

概要 神经影像学(Neuroimaging)是神经科学中一个重要的分支,主要研究通过影像技术获取和分析大脑结构和功能的信息。nipy(Neuroimaging in Python)是一个强大的 Python 库,专门用于神经影像数据的处理和分析。nipy 提供了一系列工具和方法,帮助研究人员高效地处理神经影…

C语言---自定义类型:结构体

结构体回顾 结构体 自定义的类型:结构体、联合体、枚举 结构是一些值的集合,这些值成为成员变量,结构的每个成员可以是不同类型的变量 //描述一本书:书名、作者、定价、书号//结构体类型---类似于整型、浮点型 struct Book {c…

照度计仪器校准检测需要注意哪些因素?通常选择什么校准机构?

照度计是计量中光学领域常见的一类计量器具,一般是用于测量光照影响的微量变化,在实验室和机构中,都有广泛运用。常规的照度计在仪器校准检测中,误差主要因素是外界光线干扰,以及温湿度变化和稳压直流电源的电压变化差…

FPGA早鸟课程第二弹 | Vivado 设计静态时序分析和实际约束

在FPGA设计领域,时序约束和静态时序分析是提升系统性能和稳定性的关键。社区推出的「Vivado 设计静态时序分析和实际约束」课程,旨在帮助工程师们掌握先进的设计技术,优化设计流程,提高开发效率。 课程介绍 关于课程 权威认证&…

MyBatis系列四: 动态SQL

动态SQL语句-更复杂的查询业务需求 官方文档基本介绍案例演示if标签应用实例where标签应用实例choose/when/otherwise应用实例foreach标签应用实例trim标签应用实例[使用较少]set标签应用实例[重点]课后练习 上一讲, 我们学习的是 MyBatis系列三: 原生的API与配置文件详解 现在…

据APO Research(阿谱尔)统计,2023年全球乳酸企业产能约119.3万吨

乳酸又称 2-羟基丙酸,一种天然有机酸,分子式是 C3H6O3。是自然界中最为广泛存在的羟基酸,于 1780 年被瑞典科学家 Scheele 首次发现。乳酸是自然界最小的手性分子,以两种立体异构体的形式存在于自然界中,即左旋型 L-乳…

定制化物联网设备:开启智能生活新篇章

随着科技的进步,物联网(IoT)已成为我们日常生活和工作中不可或缺的一部分。从智能家居到工业自动化,物联网设备以其独特的功能和特性,极大地提高了我们的生活质量和工作效率。然而,在众多的物联网设备中&am…

思科配置路由器,四台主机互相ping通

一、如图配置 PC4和PC5用来配置路由器,各ip、接口如图所示。 二、配置各主机ip、子网掩码SNM、默认网关DGW (一)、PC0 (二)、PC1 (三)、PC2 (四)、PC3 三、 配置路由器Router0 (期间报错是打错了字母) Router>en Router#configure terminal Enter configurat…

使用 Vue CLI 脚手架生成 Vue 项目

最近我参与了一个前端Vue2的项目。尽管之前也有过参与Vue2项目的经验,但对一些前端Web技术并不十分熟悉。这次在项目中遇到了很多问题,所以我决定借此机会深入学习Vue相关的技术栈。然而,直接开始深入钻研这些技术可能会显得枯燥,…

[图解]建模相关的基础知识-12

1 00:00:00,650 --> 00:00:06,200 我们看,下面这个,你看f里面定义域是编号 2 00:00:06,410 --> 00:00:09,040 值域是工号,各只有一个元素 3 00:00:11,850 --> 00:00:14,340 所以这些就没有了 4 00:00:14,610 --> 00:00:19,640…

vue+echarts实现tooltip轮播

效果图如下: 实现步骤如下: 定义一个定时器 timer:null, len: 0,页面一加载就清空定时器,此操作是为了防止重复加载时会设置多个定时器在setOption后设置定时器 this.myChart.clear() this.myChart.setOption(option); this.autoShowTool…

vue.js有哪几种甘特图库?Vue.js的5大甘特图库分享!

vue.js有哪几种甘特图库?Vue.js的5大甘特图库分享! 如今,软件市场为任何复杂程度的项目提供了各种现成的计划和调度工具,但这些解决方案可能包含过多的功能或缺乏一些必要的功能。这就是为什么许多公司更愿意投资开发基于网络的定制解决方案…

【C++】拷贝构造函数、拷贝赋值函数与析构函数

C中的拷贝构造函数、拷贝赋值函数与析构函数详解 一、拷贝构造函数(Copy Constructor)二、拷贝赋值函数(Copy Assignment Operator)三、析构函数(Destructor)四、总结 在C中,拷贝构造函数、拷贝…

Docker私有化仓库Harbor安装流程

1.搭建Docker私有仓库主要有以下几种方式 使用Docker官方提供的Registry镜像:Docker官方提供了一个用于构建私有镜像仓库的Registry镜像,只需将镜像下载并运行容器,然后暴露5000端口即可使用。可以通过修改Docker的配置文件daemon.json&#…

具备人工智能标记的书签应用Hoarder

什么是 Hoarder ? Hoarder 是一款可自托管的书签应用程序(链接、笔记和图像),具有基于人工智能的自动标记和全文搜索功能。适合数据囤积者使用。 软件特点: 🔗 为链接添加书签、做简单的笔记并存储图像。⬇…

简单介绍vim

文章目录 前言一、Vim的特点二、安装Vim三、设置Vim配置文件的位置:编辑配置文件:添加配置选项:保存并退出编辑器:快速配置验证设置: 总结 前言 Vim是一款强大的文本编辑器,被广泛用于各种编程和文本编辑任…

大咖专栏 | AI 时代下,我们可以拥有怎样的数据库?

Hi,各位朋友们,我是 KaiwuDB 高级架构师赵衎衎。 KaiwuDB 始于万物互联时代下千万条数据洪流中,我们持续打磨构造了更加灵活兼容的分布式多模架构,实现了海量异构数据高性能、低成本的集中管理… …这些底层特性都在为后续提供更…