Pytorch从零开始实战17

Pytorch从零开始实战——生成对抗网络入门

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——生成对抗网络入门
    • 环境准备
    • 模型定义
    • 开始训练
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch1.8+cpu,本次实验目的是了解生成对抗网络。

生成对抗网络(GAN)是一种深度学习模型。GAN由两个主要组成部分组成:生成器和判别器。这两个部分通过对抗的方式共同学习,使得生成器能够生成逼真的数据,而判别器能够区分真实数据和生成的数据。

生成器的任务是生成与真实数据相似的样本。它接收一个随机噪声向量,然后通过深度神经网络将这个随机噪声转换为具体的数据样本。在图像生成的场景中,生成器通常将随机噪声映射为图像。生成器的目标是欺骗判别器,使其无法区分生成的样本和真实的样本。生成器的训练目标是最小化生成的样本与真实样本之间的差异。

判别器的任务是对给定的样本进行分类,判断它是来自真实数据集还是由生成器生成的。它接收真实样本和生成样本,然后通过深度神经网络输出一个概率,表示输入样本是真实样本的概率。判别器的目标是准确地分类样本,使其能够正确地区分真实数据和生成的数据。判别器的训练目标是最大化正确分类的概率。

导入相关包。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

创建文件夹,分别保存训练过程中的图像、模型参数和数据集。

os.makedirs("./images/", exist_ok=True) # 训练过程中图片效果
os.makedirs("./save/", exist_ok=True) # 训练完成时模型保存位置
os.makedirs("./datasets/", exist_ok=True) # 数据集位置

设置超参数。
b1、b2为Adam优化算法的参数,其中b1是梯度的一阶矩估计的衰减系数,b2是梯度的二阶矩估计的衰减系数。
latent_dim表示生成器输入的随机噪声向量的维度。这个噪声向量用于生成器产生新样本。
sample_interval表示在训练过程中每隔多少个batch保存一次生成器生成的样本图像,以便观察生成效果。

epochs = 20
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim=100
img_size=28
channels=1
sample_interval=500

设定图像尺寸并检查cuda,本次使用的设备没有cuda。

img_shape = (channels, img_size, img_size) # (1, 28, 28)
img_area = np.prod(img_shape) # 784
 
## 设置cuda
cuda = True if torch.cuda.is_available() else False
print(cuda) # False

本次使用GAN来生成手写数字,首先下载mnist数据集。

mnist = datasets.MNIST(root='./datasets/', 
                       train=True, 
                       download=True,
                       transform=transforms.Compose([transforms.Resize(img_size),
                                                     transforms.ToTensor(), 
                                                     transforms.Normalize([0.5], [0.5])]))

使用dataloader划分批次与打乱。

dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

len(dataloader) # 938

模型定义

首先定义鉴别器模型,代码中LeakyReLU是ReLU激活函数的变体,它引入了一个小的负斜率,在负输入值范围内,而不是将它们直接置零。这个斜率通常是一个小的正数,例如0.01。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),        
            nn.LeakyReLU(0.2, inplace=True),  
            nn.Linear(512, 256),             
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Linear(256, 1),              
            nn.Sigmoid(),                    
        )
 
    def forward(self, img):
        img_flat = img.view(img.size(0), -1) 
        validity = self.model(img_flat)     
        return validity         # 返回的是一个[0, 1]间的概率

定义生成器模型,用于输出图像。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):      
            layers = [nn.Linear(in_feat, out_feat)]          
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) 
            layers.append(nn.LeakyReLU(0.2, inplace=True))   
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), 
            *block(128, 256),                         
            *block(256, 512),                         
            *block(512, 1024),                       
            nn.Linear(1024, img_area),                
            nn.Tanh()                                
        )
    def forward(self, z):                           
        imgs = self.model(z)                       
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)
        return imgs                                 # 输出为64张大小为(1, 28, 28)的图像

开始训练

创建生成器、判别器对象。

generator = Generator()
discriminator = Discriminator()

定义损失函数。这个其实就是二分类的交叉熵损失。

criterion = torch.nn.BCELoss()

定义优化函数。

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

开始训练,实现GAN训练过程,其中生成器和判别器交替训练,通过对抗过程使得生成器生成逼真的图像,而判别器不断提高对真实和生成图像的判别能力。

for epoch in range(epochs):                   # epoch:50
    for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)
 
        imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)
        real_img = Variable(imgs)     # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_label = Variable(torch.ones(imgs.size(0), 1))    ## 定义真实的图片label为1
        fake_label = Variable(torch.zeros(imgs.size(0), 1))    ## 定义假的图片的label为0
 
 
        real_out = discriminator(real_img)            # 将真实图片放入判别器中
        loss_real_D = criterion(real_out, real_label) # 得到真实图片的loss
        real_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 随机生成一些噪声, 大小为(128, 100)
        fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。
        fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片
        loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的loss
        fake_scores = fake_out
        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0
        loss_D.backward()                   # 将误差反向传播
        optimizer_D.step()                  # 更新参数
 
 
        z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 得到随机噪声
        fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片
        output = discriminator(fake_img)                                    ## 经过判别器得到的结果
        ## 损失函数和优化
        loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的loss
        optimizer_G.zero_grad()                                             ## 梯度归0
        loss_G.backward()                                                   ## 进行反向传播
        optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数
 
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
            

在这里插入图片描述
保存模型。

torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

查看最初的噪声图像。
在这里插入图片描述
查看后面生成的图像。
在这里插入图片描述

总结

对于GAN,生成器的任务是从随机噪声生成逼真的数据样本,判别器的任务是对给定的数据样本进行分类,判断其是真实数据还是由生成器生成的。生成器和判别器通过对抗的方式进行训练。在每个训练迭代中,生成器试图生成逼真的样本以欺骗判别器,而判别器努力提高自己的能力,以正确地区分真实和生成的样本。这种对抗训练的动态平衡最终导致生成器生成高质量、逼真的样本。

总之,GAN实现了在无监督情况下学习数据分布的能力,被广泛用于生成逼真图像、视频等数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/335546.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CSS 浮动 定位

文章目录 网页布局的本质浮动如何设置浮动测试浮动 定位相对定位绝对定位测试定位 网页布局的本质 用 CSS 来摆放盒子,把盒子摆放到相应位置。 CSS 提供了三种传统布局方式(简单说就是盒子如何进行排列)。 普通流(标准流&#…

后面的输入框与前面的联动,输入框只能输入正数(不用正则)

概要 提示:这里可以描述概要 前面的输入框是发票金额,后面的输入框是累计发票金额(含本次)--含本次就代表后倾请求的接口的数据(不是保存后返显的-因为保存后返显的是含本次)是不含本次的所以在输入发票金…

从数据角度分析年龄与NBA球员赛场表现的关系【数据分析项目分享】

好久不见朋友们,今天给大家分享一个我自己很感兴趣的话题分析——NBA球员表现跟年龄关系到底大不大?数据来源于Kaggle,感兴趣的朋友可以点赞评论留言,我会将数据同代码一起发送给你。 目录 NBA球员表现的探索性数据分析导入Python…

会话跟踪技术(cookiesession)

文章目录 1、什么是会话跟踪技术2、Cookie2.1、Cookie基本使用2.2、Cookie原理2.3、Cookie使用细节 3、Session3.1、Session基本使用3.2、Session原理3.3、Session使用细节 4、Cookie和Session的对比 1、什么是会话跟踪技术 会话 ​ 用户打开浏览器,访问web服务器的…

在行情一般的情况下,就说说23级应届生如何找java工作

Java应届生找工作,不能单靠背面试题,更不能在简历中堆砌和找工作关系不大的校园实践经历,而是更要在面试中能证明自己的java相关商业项目经验。其实不少应届生Java求职者不是说没真实Java项目经验,而是不知道怎么挖掘,…

DB107S-ASEMI智能LED灯具专用DB107S

编辑:ll DB107S-ASEMI智能LED灯具专用DB107S 型号:DB107S 品牌:ASEMI 封装:DBS-4 最大重复峰值反向电压:1000V 最大正向平均整流电流(Vdss):1A 功率(Pd):50W 芯片个数:4 引…

浅析智能家居企业面临的组网问题及解决方案

在这个快速发展的时代,组网对于企业的发展来说是一个至关重要的环节。 案例背景: 案例企业是一家智能家居制造企业,在不同城市分布有分公司、店铺、工厂,这些点原本都是各自采购网络,与总部进行日常沟通、访问。 现在…

Linux用户与文件的关系和文件掩码(umask)的作用

文章目录 1 前言2 Linux用户与文件的关系3 文件掩码(umask)4 总结 1 前言 阅读本篇文章,你将了解Linux的目录结构,用户与文件的关系,以及文件掩码的作用。为了方便大家理解,本文将通过实例进行演示&#xf…

外卖系统创新:智能推荐与用户个性化体验

外卖系统的日益普及使得用户对于更智能、个性化的体验有着不断增长的期望。在这篇文章中,我们将探讨如何通过智能推荐技术,为用户提供更贴心、更符合口味的外卖选择。我们将使用 Python 和基于协同过滤的推荐算法作为示例,让您更深入地了解智…

【分布式技术】Elastic Stack部署,实操logstash的过滤模块常用四大插件

目录 一、Elastic Stack,之前被称为ELK Stack 完成ELK与Filebeat对接 步骤一:安装nginx做测试 步骤二:完成filebeat二进制部署 步骤三:准备logstash的测试文件filebeat.conf 步骤四:完成实验测试 二、logstash拥有…

解决 java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader 报错

在使用POI导出Excel表格的时候&#xff0c;本地运行导出没问题&#xff0c;但是发布到服务器后提示 “java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader” 下面是pom.xml中的配置 <dependency><groupId>org.apache.poi</groupId><art…

多场景建模:阿里STAR

多场景建模&#xff1a;阿里STAR 阿里提出了Partitioned Normalization、Star Topology FCN、Auxiliary Network应用到多场景建模&#xff0c;在各个场景上面取得不错的效果。 两个场景&#xff1a; 淘宝主页的banner&#xff0c;展示一个商品或者一个店铺或者一个品牌猜你喜欢…

css 3D立体动画效果怎么转这个骰子才能看到5

css 3D立体动画效果怎么转这个骰子才能看到5 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><meta http-equ…

Java和SpringBoot学习路线图

看了一下油管博主Amigoscode的相关视频&#xff0c;提到了Java和SpringBoot的学习路线&#xff0c;相关视频地址为&#xff1a; How To Master Java - Java for Beginners RoadmapSpring Boot Roadmap - How To Master Spring Boot 如下图所示&#xff1a; 当然关于Java和Spr…

SpringBoot 服务注册IP选择问题

问题 有时候我们明明A\B服务都注册成功了&#xff0c;但是相互之间就是访问不了&#xff0c;这大概率是因为注册时选择IP时网卡选错了&#xff0c;当我们本地电脑有多个网卡时&#xff0c;程序会随机选择一个有IPV4的网卡&#xff0c;然后读取IPv4的地址 比如我的电脑有3个网…

铸铁平台使用米字型布局的特点——河北北重

铸铁平台使用米字型布局的特点主要有以下几点&#xff1a; 结构稳定&#xff1a;米字型布局能够使得铸铁平台的结构更加稳定。因为米字型布局将平台的重力均匀分散到四个支撑角上&#xff0c;减小了平台的变形和挠曲程度&#xff0c;使得平台能够承受更大的荷载。 节省空间&am…

伊恩·斯图尔特《改变世界的17个方程》傅里叶变换笔记

主要是课堂的补充&#xff08;yysy&#xff0c;我觉得课堂的教育模式真有够无聊的&#xff0c;PPT、写作业、考试&#xff0c;感受不到知识的魅力。 它告诉我们什么&#xff1f; 空间和时间中的任何模式都可以被看作不同频率的正弦模式的叠加。 为什么重要&#xff1f; 频率分量…

【JavaEE进阶】 SpringBoot配置⽂件

文章目录 &#x1f340;配置⽂件的作⽤&#x1f334;SpringBoot配置⽂件&#x1f38b;配置⽂件的格式&#x1f384;properties配置⽂件&#x1f6a9;properties基本语法&#x1f6a9;读取配置⽂件&#x1f6a9;properties的缺点 &#x1f333;yml配置⽂件yml基本语法&#x1f6…

文件服务FTP

文章目录 一、FTP协议二、VSFTPD服务介绍基础配置匿名用户访问&#xff08;默认开启&#xff09;本地用户访问虚拟用户访问 一、FTP协议 FTP协议&#xff1a;文件传输协议&#xff08;File Transfer Protocol&#xff09; 协议定义了一个在远程计算机系统和本地计算机系统之间…

【Linux的基本指令】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 1、ls 指令 2、 pwd命令 3、cd 指令 4、touch指令 5、mkdir指令&#xff08;重要&#xff09; 6、rmdir指令 && rm 指令&#xff08;重要&#xff09;…