人脸图像生成(DCGAN)

一、理论基础

1.什么是深度卷积对抗网络(Deep Convolutional Generative Adversarial Network,)
深度卷积对抗网络(Deep Convolutional Generative Adversarial Network,DCGAN)是一种生成对抗网络(GAN)的变体,它结合了深度卷积神经网络(CNN)的特性和生成对抗网络的架构。

生成对抗网络是由生成器(Generator)和判别器(Discriminator)组成的模型。生成器尝试生成与真实数据相似的样本,而判别器则尝试区分生成的样本和真实样本。两者通过博弈的方式不断优化,使生成器生成更逼真的样本。

DCGAN引入了卷积神经网络的结构,以处理图像生成任务。它的主要特点包括:

卷积层替代全连接层: DCGAN中的生成器和判别器都使用卷积层,这有助于模型学习图像中的空间层次特征,从而更好地捕捉图像的结构信息。

批归一化(Batch Normalization): 在生成器和判别器中广泛使用批归一化,有助于加速训练过程,同时提高模型的稳定性和生成效果。

去除全连接层: DCGAN中移除了全连接层,这有助于减少模型参数数量,降低过拟合的风险。

使用Leaky ReLU激活函数: 生成器和判别器中使用Leaky ReLU激活函数,以避免梯度消失的问题,同时引入一定的负斜率,促使模型更容易学习。

DCGAN的目标是通过训练生成器生成逼真的图像,同时训练判别器以有效地区分真实和生成的图像。这种架构的成功应用包括图像生成、图像编辑、图像超分辨率等领域。

2.DCGAN原理
DCGAN(Deep Convolutional Generative Adversarial Network)由GAN进行改进得到,它由两个子网络组成:生成器和判别器。

生成器网络接受一个随机噪声向量作为输入,并尝试生成看起来像真实数据的输出。具体来说,生成器网络通常由多个卷积层和反卷积层组成,这些层将随机噪声转换为具有现实特征的图像。

判别器网络则接受输入并尝试将其分类为“真实”或“生成”。判别器网络通常由多个卷积层组成,这些层将输入转换为具有现实特征的表示形式,并输出一个二进制数字,表示输入是否是真实数据。

在训练过程中,生成器和判别器交替进行预测和生成,以逐渐提高生成器输出的质量。生成器试图生成看起来像真实数据的输出,而判别器则试图将其与真实数据区分开来。通过不断地调整生成器和判别器,最终生成器可以生成非常逼真的数据。

3.DCGAN与GAN相同点与不同点
GAN(Generative Adversarial Network)和DCGAN(Deep Convolutional Generative Adversarial Network)都是生成对抗网络,它们的基本原理是相同的,即通过两个相互对抗的网络(生成器和判别器)来进行无监督的学习,以生成高质量的新数据。但是,DCGAN相对于GAN有一些改进和不同点,主要表现在以下方面:

相同点:
基本原理相同:DCGAN和GAN的基本原理都是生成对抗网络(Generative Adversarial Networks),其中两个子网络(生成器和判别器)在对抗中优化彼此。
判别器结构相同:在DCGAN和GAN中,判别器结构都是相同的,都是使用卷积神经网络(CNN)进行特征提取和分类。
不同点:
网络结构不同:DCGAN将卷积运算的思想引入到生成式模型当中,生成器和判别器模型都使用了卷积层,而GAN使用了全连接层。
训练方法不同:DCGAN使用正交标注(Orthogonal Annotation)来生成伪标签(Pseudo Labels),以用于半监督学习。而GAN通常使用真实标签(True Labels)进行监督学习。
输入输出不同:DCGAN输入的数据通常是3D体积,而GAN输入的数据通常是2D图像。此外,DCGAN输出的数据也是3D体积,而GAN输出的数据是2D图像。
判别器模型不同:在DCGAN中,判别器模型使用卷积步长取代了空间池化,以更好地提取图像特征。而在GAN中,判别器模型通常使用空间池化。
生成器模型不同:在DCGAN中,生成器模型中使用反卷积操作扩大数据维度,以更好地处理高分辨率的图像。而在GAN中,生成器模型通常使用上采样操作。
模型优化不同:在DCGAN中,整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。此外,DCGAN还使用了一些其他的技术,如Batch Normalization、Leaky ReLU激活函数和Tanh输出层,以帮助控制输出范围和提高生成质量。而在GAN中,这些技术并不常见。
4.训练原理
DCGAN的训练原理是基于半监督学习,利用已有的少量标注数据和大量无标注数据来生成新的数据。以下是DCGAN训练原理的主要步骤:

正交标注:首先,对一个标记的3D体积进行两个正交切片的标注,即正交标注。这样可以减少标注的负担,并且利用了不同方向的切片提供的互补信息。
注册模块:通过注册模块,将正交标注传播到整个体积,生成了伪标签。注册模块利用了正交标注的信息,将其传播到整个体积,从而生成了伪标签,用于半监督学习的训练过程中。
生成器训练:使用已经生成的伪标签和未标注的数据来训练生成器模型。生成器模型的目标是最小化判别器模型的输出,即让判别器无法区分生成器和真实数据之间的差异。
判别器训练:使用已经生成的伪标签和真实数据来训练判别器模型。判别器模型的目标是最小化生成器模型的输出,即让生成器无法生成看起来像真实数据的输出。
迭代训练:不断地迭代训练生成器和判别器模型,直到生成器能够生成看起来像真实数据的输出,并且判别器能够准确地区分生成数据和真实数据。
DCGAN的训练原理是基于半监督学习,通过生成伪标签和使用注册模块来减少标注负担,从而利用未标注数据来生成新的数据。在整个训练过程中,生成器和判别器相互对抗,以提高生成器的性能和生成数据的质量。

二、前期准备 

import os
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets as dset
import torchvision.utils as vutils
from torchvision.utils import save_image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
 
 
manualSeed = 999  # 随机种子
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
 
# 超参数配置
dataroot   = "D:/GAN-Data"  # 数据路径
batch_size = 128                   # 训练过程中的批次大小
n_epochs   = 5                     # 训练的总轮数
img_size   = 64                    # 图像的尺寸(宽度和高度)
nz         = 100                   # z潜在向量的大小(生成器输入的尺寸)
ngf        = 64                    # 生成器中的特征图大小
ndf        = 64                    # 判别器中的特征图大小
beta1      = 0.5                   # Adam优化器的Beta1超参数
beta2      = 0.2                   # Adam优化器的Beta1超参数
lr         = 0.0002                # 学习率
 
# 创建数据集
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),        # 调整图像大小
                               transforms.CenterCrop(img_size),    # 中心裁剪图像
                               transforms.ToTensor(),                # 将图像转换为张量
                               transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量
                                                    (0.5, 0.5, 0.5)),]))
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,  # 批量大小
                                         shuffle=True)           # 是否打乱数据集
# 选择要在哪个设备上运行代码
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print("使用的设备是:",device)
# 绘制一些训练图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],
           padding=2,
           normalize=True).cpu(),(1,2,0)))
 
 

三、定义模型 

criterion = nn.BCELoss()

fixed_noise = torch.randn(64, nz, 1, 1,device=device)
real_label = 1.
fake_label = 0.

# 为生成器(G)和判别器( D)设置Adam优化器
optimizerD = optim.Adam(netD.parameters(),lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(),lr=lr, betas=(beta1, 0.999))
img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数
print("Starting Training Loop...")
for epoch in range(num_epochs ):
    for i, data in enumerate(dataloader, 0):
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach( )).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()
        
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        if i % 400 == 0:
            print('[%d/%d][%d/%d]\tLoss_D:%.4f\tLoss_G:%.4f\tD(x):%.4f\tD(G(z)):%.4f / %.4f'
                 % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        if(iters % 500 == 0) or ((epoch == num_epochs - 1) and(i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
plt.figure(figsize=(10, 5))
plt.title('Generator and Discriminator Loss During Training')
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/610220.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机通信SCI期刊推荐,JCR1区,IF=6+,审稿快,无版面费!

一、期刊名称 Computer Communications 二、期刊简介概况 期刊类型:SCI 学科领域:计算机科学 影响因子:6 中科院分区:3区 出版方式:订阅模式/开放出版 版面费:选择开放出版需支付$2300 三、期刊征稿…

STM32中的ICACHE是什么有什么用?如何使用?

什么是ICACHE? icache是一种用于缓存指令的存储器,其目的是提高CPU执行指令的效率。 在计算机系统中,icache(指令缓存)是处理器核心内部的一个关键组件,它专门用来存储最近使用过的指令。当CPU需要执行一个…

Bean的作用域

Bean的作用域 Bean的作用域是指在Spring整个框架中的某种行为模式,比如singleton单例作用域,就表示Bean在整个spring中只有一份,它是全局共享的,那么当其他人修改了这个值时,那么另一个人读取到的就是被修改的值 Bea…

每日OJ题_记忆化搜索②_力扣62. 不同路径(三种解法)

目录 力扣62. 不同路径 解析代码1_暴搜递归(超时) 解析代码2_记忆化搜索 解析代码3_动态规划 力扣62. 不同路径 62. 不同路径 难度 中等 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器…

element-ui skeleton 组件源码分享

今日简单分享 skeleton 骨架屏组件源码,主要从以下四个方面来讲解: 1、skeleton 组件的页面结构 2、skeleton 组件的属性 3、skeleton item 组件的属性 4、skeleton 组件的 slot 一、skeleton 组件的页面结构 二、skeleton 组件的属性 2.1 animate…

BS架构 数据权限--字段级权限 设计与实现

一、需求场景 1. 销售发货场景 销售出库单上 有 商品名称、发货数量、单价、总金额 等信息。 销售人员 关注 上述所有信息,但 仓管人员 不需要知道 单价、总金额 信息。 2. 配方、工艺保密 场景 配方研发人员 掌握核心配方, 但 交给车间打样、生产时…

一款免费的PDF转换工具分享

最近在吾爱上发现一款PDF免费转换工具,支持多种格式转换,试了一下,还不错 最重要的是免费,不用开会员转换,也没有限制(文末有工具地址) ps:转换完成后看一下是否符合,可能会有些许…

即插即用模块:Convolutional Triplet注意力模块(论文+代码)

目录 一、摘要 二、创新点总结 三、代码详解 论文:https://arxiv.org/pdf/2010.03045v2 代码:https://github. com/LandskapeAI/triplet-attention 一、摘要 由于注意机制具有在通道或空间位置之间建立相互依赖关系的能力,近年来在各种计…

Web功能测试之表单、搜索测试

初入职场接触功能测试老是碰到以下情况不知道怎么写测试用例: 一个界面很多搜索条件怎么写用例? 下拉框测试如何考虑测试点? 上传要考虑哪些验证点?...... 所以这篇主要是整理关于web测试之表单、搜索测试的相关要点。 一、表…

小程序(三)

十三、自定义组件 (二)数据方法声明位置 在js文件中 A、数据声明位置:data中 B、方法声明位置methods中,这点和普通页面不同! Component({/*** 组件的属性列表*/properties: {},/*** 组件的初始数据*/data: {isCh…

离线使用evaluate

一、目录 步骤demorouge-n 含义 二、实现 步骤 离线使用evaluate: 1. 下载evaluate 文件:https://github.com/huggingface/evaluate/tree/main2. 离线使用 路径/evaluate-main/metrics/rougedemo import evaluate离线使用evaluate: 1. 下载evaluate 文件&…

# 从浅入深 学习 SpringCloud 微服务架构(十五)

从浅入深 学习 SpringCloud 微服务架构(十五) 一、SpringCloudStream 的概述 在实际的企业开发中,消息中间件是至关重要的组件之一。消息中间件主要解决应用解耦,异步消息,流量削锋等问题,实现高性能&…

Java中使用RediSearch进行高效数据检索

RediSearch是一款构建在Redis上的搜索引擎,它为Redis数据库提供了全文搜索、排序、过滤和聚合等高级查询功能。通过RediSearch,开发者能够在Redis中实现复杂的数据搜索需求,而无需依赖外部搜索引擎。本文将介绍如何在Java应用中集成并使用Red…

3. 多层感知机算法和异或门的 Python 实现

前面介绍过感知机算法和一些简单的 Python 实践,这些都是单层实现,感知机还可以通过叠加层来构建多层感知机。 2. 感知机算法和简单 Python 实现-CSDN博客 1. 多层感知机介绍 单层感知机只能表示线性空间,多层感知机就可以表示非线性空间。…

切割链表 问题的讲解和实现(带哨兵位)

一:题目 二:思路讲解 先 将小于x的放进一个新的链表,将≥x的也放进另一个新链表。 然后 第一个新链表的尾节点的next链接到第二个新链表的哨兵节点的next,因为本身不存在哨兵位。 最后 链接完成后的链表的最后一个节点的next一…

python数据分析——数据的选择和运算

数据的选择和运算 前言一、数据选择NumPy的数据选择一维数组元素提取示例 多维数组行列选择、区域选择示例 花式索引与布尔值索引布尔索引示例一示例二 花式索引示例一示例二 Pandas数据选择Series数据获取DataFrame数据获取列索引取值示例一示例二 取行方式示例loc() 方法示例…

天地图2024版正式启用!首次开放多时相影像专题,可直接查看历史影像

近日,国家基础地理信息中心正式发布了天地图2024版。 新版本更新2米分辨率遥感影像794万平方公里、优于1米分辨率遥感影像655万平方公里,在线服务2米分辨率遥感影像覆盖全部陆地国土、优于1米遥感影像覆盖陆地国土达到98%,同时新增多时相遥感…

航空科技:探索飞机引擎可视化技术的新视界

随着航空技术的飞速发展,飞机引擎作为航空器最为关键的部件之一,其性能直接影响到飞机的安全性、经济性和环保性。因此,飞机引擎可视化技术的应用日益成为航空行业研究和发展的热点。 通过图扑将复杂的飞机引擎结构和工作原理以直观、生动的…

Linux中的日志系统简介

在的Linux系统上使用的日志系统一般为rsyslogd。rsyslogd守护进程既能接收用户进程输出的日志,又能接收内核日志。用户进程是通过调用syslog函数生成系统日志的。该函数将日志输出到一个UNIX本地域socket类型(AF_UNIX)的文件/dev/log中&#…

动联再掀创新风潮!P92 Max智能POS机惊艳发布

当下,智能支付与零售行业正经历着深刻变革,移动支付、无人支付等新型支付方式在我国广泛应用,显著优化了消费者的支付体验,同时也为零售行业带来新的发展契机。动联,凭借其在身份认证领域的深厚技术底蕴与创新精神&…