G2 基于生成对抗网络(GAN)人脸图像生成

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

基于生成对抗网络(GAN)人脸图像生成

这周将构建并训练一个生成对抗网络(GAN)来生成人脸图像。

GAN 原理概述

生成对抗网络通过两个神经网络的对抗性结构来实现目标:

  • 生成器(G):输入随机噪声,通过学习数据的分布模式生成类似真实图像的输出。
  • 判别器(D):用来判断输入的图像是真实的还是生成器生成的。

训练过程中,生成器尝试欺骗判别器,生成逼真的图像,而判别器则不断优化,以区分真实图像与生成图像。这种对抗过程最终使生成器的生成能力逐渐逼近真实图像。

环境准备

首先导入相关库并设置随机种子以确保结果的可复现性。

import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

超参数设置

在训练GAN之前,首先定义一些关键的超参数:

  • batch_size:每个批次的样本数。
  • image_size:图像的大小,用于调整输入数据的尺寸。
  • nz:潜在向量大小,即生成器的输入维度。
  • ngfndf:分别控制生成器和判别器中的特征图数量。
  • num_epochs:训练的总轮数。
  • lr:学习率。
batch_size = 128
image_size = 64
nz = 100
ngf = 64
ndf = 64
num_epochs = 50
lr = 0.0002
beta1 = 0.5

数据加载

通过torchvision.datasets.ImageFolder加载数据,并使用 torch.utils.data.DataLoader 进行批量处理。数据加载时,通过转换函数调整图像大小,并对其进行归一化处理。

dataroot = "data/GANdata"
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

网络结构定义

1. 生成器

生成器将随机噪声(潜在向量)通过一系列转置卷积层转换为图像。每层使用ReLU激活函数,最后一层用Tanh激活函数,将输出限制在 [-1, 1]

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

2. 判别器

判别器为卷积网络,通过一系列卷积层提取图像特征。每层使用LeakyReLU激活函数,最终输出一个值(真实为1,生成为0)。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

训练过程

训练分为两个部分:判别器和生成器的更新。

1. 判别器的训练

判别器首先接收真实图像样本,计算输出与真实标签的误差。然后判别器接收生成器生成的假图像,再计算输出与假标签的误差。最终判别器的损失是两者的总和。

output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()

fake = netG(noise)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label.fill_(fake_label))
errD_fake.backward()

2. 生成器的训练

生成器的目标是欺骗判别器,因此其损失函数基于判别器将生成图像误识为真实的概率值。

output = netD(fake).view(-1)
errG = criterion(output, label.fill_(real_label))
errG.backward()

训练监控与可视化

在这里插入图片描述

训练时,我们记录生成器和判别器的损失,并生成一些样本图像来查看生成器的效果。

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig('Generator and Discriminator Loss During Training.png')

在这里插入图片描述

结果可视化

训练结束后,我们将真实图像与生成图像对比,以检验生成器的效果。

plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0)))

plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.savefig('Fake Images.png')
plt.show()

在这里插入图片描述

总结

这周学习构建了一个深度卷积生成对抗网络(DCGAN),用于生成逼真的人脸图像,通过这周学习对对抗网路的构建有了更深的了解与运用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/908991.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

N-155基于springboot,vue宿舍管理系统

开发工具:IDEA 服务器:Tomcat9.0, jdk1.8 项目构建:maven 数据库:mysql5.7 项目采用前后端分离 前端技术:vue3element-plus 服务端技术:springbootmybatis-plus 本项目分为学生、宿舍管理…

友思特应用 | FantoVision边缘计算:多模态传感+AI算法=新型非接触式医疗设备

导读 基于多模态传感技术和先进人工智能技术可有效提升乳腺癌检测的精准性、性价比和效率。友思特 FantoVision 边缘计算机 则为其生物组织数据的高效传输和实时分析提供了坚实基础。 乳腺癌的新型医疗检测方式 乳腺癌是女性面临的最令人担忧的健康问题之一,早期发…

【热门主题】000029 ECMAScript:现代编程的基石

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 【热…

5G时代已来:我们该如何迎接超高速网络?

内容概要 随着5G技术的普及,我们的生活似乎变得更加“科幻”了。想象一下,未来的智能家居将不仅仅是能够听你说“开灯”;它们可能会主动询问你今天心情如何,甚至会推荐你一杯“维他命C芒果榨汁”,帮助你抵御夏天的炎热…

Navigating Net 算法简介

0. Inro \textbf{0. Inro} 0. Inro 1️⃣一些要用到的符号 ( U , dist ⁡ ) (U, \operatorname{dist}) (U,dist)为基础度量空间, S ⊆ U S \subseteq U S⊆U为包含 n ≥ 2 n \geq 2 n≥2个对象的 Input \text{Input} Input​ h ⌈ log ⁡ 2 diam ⁡ ( S ) ⌉ h\lef…

Java项目实战II基于Java+Spring Boot+MySQL的网上摄影工作室(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着互联网…

【Android 系统中使用CallStack类来追踪获取和操作调用栈信息】

Android系统CallStack类的使用 定义使用方法使用场景注意事项应用举例 定义 在 Android 系统中,CallStack 类是一个用于获取和操作调用栈信息的工具类。这个类通常用于调试和日志记录,以帮助开发者了解函数调用的顺序和位置。以下是您提供的代码片段的解…

IBM服务器修改IMM的IP方法

服务器设备:IBM x3550 M4 Server IMM默认IP地址:192.168.70.125 用户名:USERID 密码:PASSW0RD(注意是零0) 1.服务器开机按F1进入BIOS界面 2.进入System Settings 3.进入Integrated Management Module 4.…

【数据分享】1901-2023年我国省市县镇四级的逐年最高气温数据(免费获取/Shp/Excel格式)

之前我们分享过1901-2023年1km分辨率逐月最高气温栅格数据和Excel和Shp格式的省市县镇四级逐月最高气温数据,原始的逐月最高气温栅格数据来源于彭守璋学者在国家青藏高原科学数据中心平台上分享的数据!基于逐月数据我们采用求年平均值的方法得到逐年最高…

【前端】Vue3实现图片标点

前言 公司的业务要求可以在图片的位置上面进行标点,然后在现场对汽车桌椅可以实现按照标点进行质量检测。 技术栈 Vue3:https://cn.vuejs.org/index.htmlAnt Design Vue4.x:https://www.antdv.com/docs/vue/introduce-cn 图像标点 将画布…

FP7209M太阳能升压恒流一体测试板,带短路保护功能,软启动时间可调,应用于太阳能吸塑灯箱 商场便利店户外门头侧挂招牌广告牌led灯箱

太阳能灯箱用于城市主要街道、停车场、宾馆、旅游区、等夜间人群活动较多的公共场所照明的设备 太阳能广告灯箱凭借独特的设计理念为广告行业开辟一个全新的领域。不仅具有广告原有的宣传作用,还点亮了都市,小区的景观环境。在不需要架电线,电…

JS渗透(安全)

JS逆向 基本了解 作用域: 相关数据值 调用堆栈: 由下到上就是代码的执行顺序 常见分析调试流程: 1、代码全局搜索 2、文件流程断点 3、代码标签断点 4、XHR提交断点 某通js逆向结合burp插件jsEncrypter 申通快递会员中心-登录 查看登录包…

Imperva 数据库与安全解决方案

Imperva是网络安全解决方案的专业提供商,能够在云端和本地对业务关键数据和应用程序提供保护。公司成立于 2002 年,拥有稳定的发展和成功历史并于 2014 年实现产值1.64亿美元,公司的3700多位客户及300个合作伙伴分布于全球各地的90多个国家。…

工业网络监控中的IP保护与软件授权革新

未来的智能工厂离不开稳定而高效的通信网络,这些网络在支撑生产流程的同时,也面临着复杂的管理与安全挑战。PROCENTEC推出了一系列硬件和软件产品,如Atlas、Mercury和Osiris,以提供全面的网络监控和故障排除能力。然而&#xff0c…

基于springboot+vue实现的网上预约挂号管理系统 (源码+L文+ppt)4-104

结合现有六和医院网上预约挂号管理系统的特点,应用新技术,构建了六和医院网上预约挂号管理系统。首先从需求出发,对目前传统的六和医院网上预约挂号管理进行了详细的了解和分析。根据需求分析结果,对系统进行了设计,并…

QT for android 问题总结(QT 5.15.2)

1.配置好的sdk,显示设置失败 Android SDK Command-line Tools run. Android Platform-Tools installed. Command-line Tools (latest) 版本过高导致报错 ,下载一个低版本的latest ,替换掉之前latest中的文件。即可,latest 路径如…

NAS端最强音乐库,多平台服务支持。海康存储部署『Navidrome』

NAS端最强音乐库,多平台服务支持。海康存储部署『Navidrome』 哈喽小伙伴们好,我是Stark-C~ 对于我们NAS用户,我们总是喜欢将自己喜欢的音乐资源通过下载的方式保存在本地,不过海康存储目前对因音乐的支持和管理实在过于薄弱&am…

【论文阅读笔记】Wavelet Convolutions for Large Receptive Fields

1.论文介绍 Wavelet Convolutions for Large Receptive Fields 大感受野的小波卷积 2024 EECV Paper Code 2.摘要 近年来,人们试图通过增加卷积神经网络(ConvolutionalNeuralNets,CNNs)的核尺寸来模拟视觉变换器(V…

2024年最新10款顶级项目管理软件排行

项目管理软件在现代项目管理中扮演着至关重要的角色,它不仅仅是一个工具,更是一种高效、系统化的方法来管理和优化项目流程,帮助项目经理和团队成员快速了解项目状态,加速项目进展。 进度猫 进度猫是一款以甘特图为向导的轻量级…

SAP ABAP开发学习——RFC

目录 RFC接口 定义 调用过程 RFC的通信 RFC通信情况 RFC接口系统 RFC的通信模式 RFC版本 RFC调用方式 Web Service接口 SAP创建Web Service示例 远程目标的维护 创建远程目标 外部系统访问设置 RFC的调用 RFC接口 定义 调用过程 RFC的通信 RFC通信情况 RFC接…