Pytorch从零开始实战18

Pytorch从零开始实战——人脸图像生成

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——人脸图像生成
    • 环境准备
    • 模型定义
    • 开始训练
    • 可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。本次实验的目的是了解并使用DCGAN模型,完成人脸图生成。
第一步,导入常用包

import torch, random, random, os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.cuda.device_count()

设置超参数,其中数据集源于K同学

dataroot = "./data/face"  # 数据路径
batch_size = 128  # 训练过程中的批次大小
image_size = 64   # 图像的尺寸(宽度和高度)
nz  = 100         # z潜在向量的大小(生成器输入的尺寸)
ngf = 64          # 生成器中的特征图大小
ndf = 64          # 判别器中的特征图大小
num_epochs = 50   # 训练的总轮数,如果你显卡不太行,可调小,但是生成效果会随之降低
lr    = 0.0002    # 学习率
beta1 = 0.5       # Adam优化器的Beta1超参数

使用datasets读取数据

dataset = datasets.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),        # 调整图像大小
                               transforms.CenterCrop(image_size),    # 中心裁剪图像
                               transforms.ToTensor(),                # 将图像转换为张量
                               transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量
                                                    (0.5, 0.5, 0.5)),
                           ]))

随机查看五张数据

def plotsample(data):
    fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图
    for i in range(5):
        num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次
        #抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据
        #而展示图像用的imshow函数最常见的输入格式也是3通道
        npimg = torchvision.utils.make_grid(data[num][0]).numpy()
        nplabel = data[num][1] #提取标签 
        #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取
        axs[i].imshow(np.transpose(npimg, (1, 2, 0))) 
        axs[i].set_title(nplabel) #给每个子图加上标签
        axs[i].axis("off") #消除每个子图的坐标轴

plotsample(dataset)

在这里插入图片描述
使用dataloader进行批量划分和打乱

dataloader = torch.utils.data.DataLoader(dataset, 
                                         batch_size=batch_size,  # 批量大小
                                         shuffle=True,           # 是否打乱数据集
                                         num_workers=5 # 使用多个线程加载数据的工作进程数
                                        )

模型定义

深度卷积对抗网络(Deep Convolutional Generative Adversarial Networks, DCGAN)是生成对抗网络的一种模型改进,其将卷积运算的思想引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果。

判别器网络和生成器网络: DCGAN包括两个主要部分,即判别器(Discriminator)和生成器(Generator)。判别器负责评估输入图像是真实图像还是生成图像,而生成器则试图生成逼真的图像以欺骗判别器。

卷积层和批量归一化: DCGAN使用卷积神经网络(CNN)作为判别器和生成器的主要组件,以有效地捕捉图像的空间结构。此外,引入批量归一化来稳定训练过程,加速收敛。

生成器输入和输出: 生成器接收一个随机噪声向量作为输入,通过反卷积(或称为转置卷积)操作生成图像。这使得生成器能够从随机噪声中学到数据分布的特征。

判别器的激活函数和损失函数: 判别器使用Leaky ReLU(带有泄漏的修正线性单元)作为激活函数,以防止梯度消失问题。损失函数采用二元交叉熵,用于判断生成图像和真实图像的相似度。

避免全连接层: DCGAN的设计避免使用全连接层,而是主要使用卷积和反卷积层,以减少模型参数,降低过拟合风险。

其中,反卷积核心思想是通过在输入特征图之间插入一些新的值(通常用零填充),使得输出的尺寸比输入更大。

DCGAN模型主要包括了一个生成网络 G 和一个判别网络 D,生成网络 G 负责生成图像,它接受一个随机的噪声z,通过该噪声生成图像,将生成的图像记为G(z),判别网络D负责判别一张图像是否为真实的,它的输入是x,代表一张图像,输出D(x)表示x为真实图像的概率。

实际上判别网络D是对数据的来源进行一个判别:究竟这个数据是来自真实的数据分布Pd(x)判别为“1”),还是来自于一个生成网络G所产生的一个数据分布Pg(z)(判别为“0”)。所以在整个训练过程中,生成网络G的目标是生成可以以假乱真的图像G(z),当判别网络D无法区分,即D(G(z))=0.5时,便得到了一个生成网络G用来生成图像扩充数据集。
在这里插入图片描述
初始化权重

# 自定义权重初始化函数,作用于netG和netD
def weights_init(m):
    # 获取当前层的类名
    classname = m.__class__.__name__
    # 如果类名中包含'Conv',即当前层是卷积层
    if classname.find('Conv') != -1:
        # 使用正态分布初始化权重数据,均值为0,标准差为0.02
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    # 如果类名中包含'BatchNorm',即当前层是批归一化层
    elif classname.find('BatchNorm') != -1:
        # 使用正态分布初始化权重数据,均值为1,标准差为0.02
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        # 使用常数初始化偏置项数据,值为0
        nn.init.constant_(m.bias.data, 0)

定义生成器网络

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入为Z,经过一个转置卷积层
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),  # 批归一化层,用于加速收敛和稳定训练过程
            nn.ReLU(True),  # ReLU激活函数
            # 输出尺寸:(ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 输出尺寸:(ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 输出尺寸:(ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 输出尺寸:(ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()  # Tanh激活函数
            # 输出尺寸:3 x 64 x 64
        )
        
    def forward(self, input):
        return self.main(input)

netG = Generator().to(device)
netG.apply(weights_init) # 使用 "weights_init" 函数来随机初始化所有权重
print(netG)

在这里插入图片描述
定义判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # 定义判别器的主要结构,使用Sequential容器将多个层按顺序组合在一起
        self.main = nn.Sequential(
            # 输入大小为3 x 64 x 64
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出大小为(ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出大小为(ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出大小为(ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出大小为(ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        # 将输入通过判别器的主要结构进行前向传播
        return self.main(input)

# 创建判别器模型
netD = Discriminator().to(device)
netD.apply(weights_init) # 使用 "weights_init" 函数来随机初始化所有权重
print(netD)

在这里插入图片描述

开始训练

定义损失函数和优化算法

# 损失函数
criterion = nn.BCELoss()

# 创建用于可视化生成器进程的潜在向量批次
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1.
fake_label = 0.

# 为生成器(G)和判别器(D)设置Adam优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

对于每个dataloader中的atch,会进行以下步骤:
1.更新判别器网络: 通过最大化判别器对真实图像和生成图像的损失来训练判别器。这包括计算对真实图像的损失和对生成图像的损失,然后通过梯度反向传播来更新判别器的参数。
2.更新生成器网络: 通过最大化生成器在生成图像上的损失来训练生成器。生成器的目标是欺骗判别器,使其无法区分生成的图像和真实图像。同样,通过梯度反向传播来更新生成器的参数。
3.记录损失值,输出训练统计信息,并定期保存生成器在固定噪声上的输出图像。

img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数

print("Starting Training Loop...")  # 输出训练开始的提示信息
# 对于每个epoch(训练周期)
for epoch in range(num_epochs):
    # 对于dataloader中的每个batch
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) 更新判别器网络:最大化 log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## 使用真实图像样本训练
        netD.zero_grad()  # 清除判别器网络的梯度
        # 准备真实图像的数据
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)  # 创建一个全是真实标签的张量
        # 将真实图像样本输入判别器,进行前向传播
        output = netD(real_cpu).view(-1)
        # 计算真实图像样本的损失
        errD_real = criterion(output, label)
        # 通过反向传播计算判别器的梯度
        errD_real.backward()
        D_x = output.mean().item()  # 计算判别器对真实图像样本的输出的平均值

        ## 使用生成图像样本训练
        # 生成一批潜在向量
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # 使用生成器生成一批假图像样本
        fake = netG(noise)
        label.fill_(fake_label)  # 创建一个全是假标签的张量
        # 将所有生成的图像样本输入判别器,进行前向传播
        output = netD(fake.detach()).view(-1)
        # 计算判别器对生成图像样本的损失
        errD_fake = criterion(output, label)
        # 通过反向传播计算判别器的梯度
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # 计算判别器对生成图像样本的输出的平均值
        # 计算判别器的总损失,包括真实图像样本和生成图像样本的损失之和
        errD = errD_real + errD_fake
        # 更新判别器的参数
        optimizerD.step()

        ############################
        # (2) 更新生成器网络:最大化 log(D(G(z)))
        ###########################
        netG.zero_grad()  # 清除生成器网络的梯度
        label.fill_(real_label)  # 对于生成器成本而言,将假标签视为真实标签
        # 由于刚刚更新了判别器,再次将所有生成的图像样本输入判别器,进行前向传播
        output = netD(fake).view(-1)
        # 根据判别器的输出计算生成器的损失
        errG = criterion(output, label)
        # 通过反向传播计算生成器的梯度
        errG.backward()
        D_G_z2 = output.mean().item()  # 计算判别器对生成器输出的平均值
        # 更新生成器的参数
        optimizerG.step()
        
        # 输出训练统计信息
        if i % 400 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # 保存损失值以便后续绘图
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # 通过保存生成器在固定噪声上的输出来检查生成器的性能
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1

在这里插入图片描述

可视化

查看训练过程

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

在这里插入图片描述

查看生成的图像

# 创建一个大小为8x8的图形对象
fig = plt.figure(figsize=(8, 8))
# 不显示坐标轴
plt.axis("off")
# 将图像列表img_list中的图像转置并创建一个包含每个图像的单个列表ims
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
# 使用图形对象、图像列表ims以及其他参数创建一个动画对象ani
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# 将动画以HTML形式呈现
HTML(ani.to_jshtml())

在这里插入图片描述
对比一下真实图像和生成的图像

# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))

# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 绘制上一个时期生成的假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

在这里插入图片描述

总结

DCGAN是生成对抗网络的一种应用,包含生成器和判别器,通过对抗训练的方式,使得生成器能够生成逼真的数据,而判别器则学会区分真实数据和生成数据。总之,DCGAN的设计使得生成对抗网络在图像生成领域取得了显著的进展,促进了后续对GAN的研究和发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/364638.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux下gcc的使用与程序的翻译

gcc和程序的翻译过程 gcc介绍程序的翻译过程预编译编译汇编链接 命令行式宏定义 gcc介绍 gcc是一款编译C语言编译器,可以把我们用vim写的代码编译成可执行程序。编译C用g进行编译,C的文件后缀是test.cc或test.cpp或test.cxx 如果要安装g就执行以下命令 …

一文详解docker swarm

文章目录 1、简介1.1、涉及到哪些概念?1.2、需要注意什么? 2、集群管理2.1、创建集群2.2、将节点加入集群2.3、查看集群状态。2.4、将节点从集群中移除2.5、更新集群2.6、锁定/解锁集群 3、节点管理4、服务部署4.1、准备4.2、服务管理4.2.1、常用命令4.2…

TCP 连接掉线自动重连

文章目录 TCP 连接掉线自动重连定义使用连接效果 TCP 接收数据时防止掉线。TCP 连接掉线自动重连。多线程环境下TCP掉线自动重连。 欢迎讨论更好的方法! TCP 连接掉线自动重连 定义 定义一个类,以编写TCP连接函数Connect(),并且&#xff1a…

分发糖果[困难]

优质博文:IT-BLOG-CN 一、题目 n个孩子站成一排。给你一个整数数组ratings表示每个孩子的评分。你需要按照以下要求,给这些孩子分发糖果: 【1】每个孩子至少分配到1个糖果。 【2】相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩…

JavaScript基础五对象 内置对象 Math.random()

内置对象-生成任意范围随机数 Math.random() 随机数函数, 返回一个0 - 1之间,并且包括0不包括1的随机小数 [0, 1) 如何生成0-10的随机数呢? Math.floor(Math.random() * (10 1)) 放大11倍再向下取整 如何生成5-10的随机数&…

【智能算法】11种混沌映射算法+2种智能算法示范【鲸鱼WOA、灰狼GWO算法】

1 主要内容 混沌映射算法是我们在智能算法改进中常用到的方法,本程序充分考虑改进算法应用的便捷性,集成了11种混合映射算法,包括Singer、tent、Logistic、Cubic、chebyshev、Piecewise、sinusoidal、Sine、ICMIC、Circle、Bernoulli&#xf…

css实现按钮边框旋转

先上效果图 本质&#xff1a;一个矩形在两个矩形互相重叠遮盖形成的缝隙中旋转形成&#xff0c;注意css属性z-index层级关系 直接上代码 <div class"bg"><div class"button">按钮</div></div><style>.bg {width: 100%;heigh…

数字图像处理(实践篇)四十一 OpenCV-Python 使用sift算法检测图像上的特征点实践

目录 一 涉及的函数 二 实践 2004年,D.Lowe在论文Distinctive Image Features from Scale-Invariant Keypoints中提出了一种新算法,即尺度不变特征变换 (SIFT),该算法提取关键点并计算其描述符。SIFT提取图像的局部特征,在尺度空间寻找极值点,并提取出其位置尺度和方向…

绝地求生:“龙腾”通行证和新空投任务内容一览:二十级依然有图纸!

大家好&#xff0c;27.2版本终于更新完了&#xff0c;先为大家带来这次龙腾通行证的详细内容&#xff0c;显放上详细的兑换点数大家可以慢慢看。 省流: 通行证分支3仍然可解锁图纸和500G-COIN奖励&#xff0c;空投任务也可以通过做很简单的游戏任务70代币兑换获得1张图纸。 这次…

main函数中参数argc和argv用法解析

1 基础 argc 是 argument count 的缩写&#xff0c;表示传入main函数的参数个数&#xff1b; argv 是 argument vector 的缩写&#xff0c;表示传入main函数的参数序列或指针&#xff0c;并且第一个参数argv[0]一定是程序的名称&#xff0c;并且包含了程序所在的完整路径&…

vue 发布自己的npm组件

1、在项目任意位置创建index.ts文件 2、导入要到处的组件&#xff0c;使用vue提供的install 功能全局挂在&#xff1b; import GWButton from "/views/GWButton.vue"; import GWAbout from "/views/AboutView.vue";const components {GWButton,GWAbout, …

canvas路径剪裁clip(图文示例)

查看专栏目录 canvas实例应用100专栏&#xff0c;提供canvas的基础知识&#xff0c;高级动画&#xff0c;相关应用扩展等信息。canvas作为html的一部分&#xff0c;是图像图标地图可视化的一个重要的基础&#xff0c;学好了canvas&#xff0c;在其他的一些应用上将会起到非常重…

python执行linux系统命令的三种方式

前言 这是我在这个网站整理的笔记,有错误的地方请指出&#xff0c;关注我&#xff0c;接下来还会持续更新。 作者&#xff1a;神的孩子都在歌唱 1. 使用os.system 无法获取命令执行后的返回信息 import osos.system(ls)2. 使用os.popen 能够获取命令执行后的返回信息 impor…

什么是SDN-软件定义网络

知识改变命运&#xff0c;技术就是要分享&#xff0c;有问题随时联系&#xff0c;免费答疑&#xff0c;欢迎联系&#xff01; 厦门微思网络​​​​​​ https://www.xmws.cn 华为认证\华为HCIA-Datacom\华为HCIP-Datacom\华为HCIE-Datacom Linux\RHCE\RHCE 9.0\RHCA\ Oracle O…

力扣之2648.生成 斐波那契数列(yield)

/*** return {Generator<number>}*/ var fibGenerator function*() {let a 0,b 1;yield 0; // 返回 0&#xff0c;并暂停执行yield 1; // 返回 1&#xff0c;并暂停执行while(true) {yield a b; // 返回 a b&#xff0c;并暂停执行[a, b] [b, a b]; // 更新 a 和 …

大力说视频号第五课:千粉才能直播带货?如何做到

腾讯对直播带货设置了条件&#xff1a;要么你的账号得有1000个粉丝&#xff0c;要么你得开通视频号小店。 对于那些还没有开通视频号小店的用户&#xff0c;他们要想申请直播带货&#xff0c;就必须完成微信实名认证、拥有1000个以上的有效关注者&#xff0c;并缴纳一笔保证金。…

船员投保的数学模型(MATLAB求解)

1.问题描述 劳动工伤事故&#xff0c;即我们平时所说的“工伤事故”&#xff0c;也称职业伤害&#xff0c;是指劳动者在生产岗位上&#xff0c;从事与生产劳动有关的工作中发生的人身伤害事故、急性中毒事故或职业病。船员劳动工伤事故是指船员在船舶生产岗位上&#xff0c;从…

《基于“源启+”的应用重构白皮书》

当前&#xff0c;行业数字化转型驶入“深水区”&#xff0c;全新的市场竞争格局对行业发展提出更高的要求&#xff0c;企业高质量发展需要借助新架构新应用重新定义数字生产力&#xff0c;重塑商业模式与市场核心竞争力。 在中国电子主办&#xff0c;中电金信承办的“数字原生向…

SpringBoot实战(二十六)集成SFTP

目录 一、SFTP简介二、SpringBoot 集成2.1 Maven 依赖2.2 application.yml 配置2.3 DemoController.java 接口2.4 SftpService.java2.5 DemoServiceImpl.java 实现类2.6 SftpUtils.java 工具类2.7 执行结果1&#xff09;上传文件2&#xff09;下载文件3&#xff09;重命名文件&…

go并发编程-介绍与Goroutine使用

1. 并发介绍 进程和线程 A. 进程是程序在操作系统中的一次执行过程&#xff0c;系统进行资源分配和调度的一个独立单位。B. 线程是进程的一个执行实体,是CPU调度和分派的基本单位,它是比进程更小的能独立运行的基本单位。C.一个进程可以创建和撤销多个线程;同一个进程中的多个…