基于PyTorch搭建你的生成对抗性网络

cd0651e77ef4cd23df4cfc0475b1bb07.jpeg

前言

你听说过GANs吗?还是你才刚刚开始学?GANs是2014年由蒙特利尔大学的学生 Ian Goodfellow 博士首次提出的。GANs最常见的例子是生成图像。有一个网站包含了不存在的人的面孔,便是一个常见的GANs应用示例。也是我们将要在本文中进行分享的。

生成对抗网络由两个神经网络组成,生成器和判别器相互竞争。我将在后面详细解释每个步骤。希望在本文结束时,你将能够从零开始训练和建立自己的生财之道对抗性网络。所以闲话少说,让我们开始吧。

目录

步骤0: 导入数据集

步骤1: 加载及预处理图像

步骤2: 定义判别器算法

步骤3: 定义生成器算法

步骤4: 编写训练算法

步骤5: 训练模型

步骤6: 测试模型

步骤0: 导入数据集

第一步是下载并将数据加载到内存中。我们将使用 CelebFaces Attributes Dataset (CelebA)来训练你的对抗性网络。主要分以下三个步骤:

1. 下载数据集:

https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be7eb6f_processed-celeba-small/processed-celeba-small.zip;

2. 解压缩数据集;

3. Clone 如下 GitHub地址:

https://github.com/Ahmad-shaikh575/Face-Generation-using-GANS

这样做之后,你可以在 colab 环境中打开它,或者你可以使用你自己的 pc 来训练模型。

导入必要的库

#import the neccessary libraries
import pickle as pkl
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets
from torchvision import transforms
import torch
import torch.optim as optim

步骤1: 加载及预处理图像

在这一步中,我们将预处理在前一节中下载的图像数据。

将采取以下步骤:

  1. 调整图片大小

  2. 转换成张量

  3. 加载到 PyTorch 数据集中

  4. 加载到 PyTorch DataLoader 中

# Define hyperparameters
batch_size = 32
img_size = 32
data_dir='processed_celeba_small/'

# Apply the transformations
transform = transforms.Compose([transforms.Resize(image_size)
                                    ,transforms.ToTensor()])
# Load the dataset
imagenet_data = datasets.ImageFolder(data_dir,transform= transform)

# Load the image data into dataloader
celeba_train_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size,
                                          shuffle=True)

图像的大小应该足够小,这将有助于更快地训练模型。Tensors 基本上是 NumPy 数组,我们只是将图像转换为在 PyTorch 中所必需的 NumPy 数组。

然后我们加载这个转换成的 PyTorch 数据集。在那之后,我们将把我们的数据分成小批量。这个数据加载器将在每次迭代时向我们的模型训练过程提供图像数据。

随着数据的加载完成。现在,我们可以预处理图像。

图像的预处理

我们将在训练过程中使用 tanh 激活函数。该生成器的输出范围在 -1到1之间。我们还需要对这个范围内的图像进行缩放。代码如下所示:

def scale(img, feature_range=(-1, 1)):
  '''
  Scales the input image into given feature_range
  '''
    min,max = feature_range
    img = img * (max-min) + min
    return img

这个函数将对所有输入图像缩放,我们将在后面的训练中使用这个函数。

现在我们已经完成了无聊的预处理步骤。

接下来是令人兴奋的部分,现在我们需要为我们的生成器和判别器神经网络编写代码。

步骤2: 定义判别器算法

97127e11857e4e2d08f27fe2e2848c52.png

判别器是一个可以区分真假图像的神经网络。真实的图像和由生成器生成的图像都将提供给它。

我们将首先定义一个辅助函数,这个辅助函数在创建卷积网络层时非常方便。

# helper conv function
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    layers = []
    conv_layer = nn.Conv2d(in_channels, out_channels, 
                           kernel_size, stride, padding, bias=False)
    
    #Appending the layer
    layers.append(conv_layer)
    #Applying the batch normalization if it's given true
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
     # returning the sequential container
    return nn.Sequential(*layers)

这个辅助函数接收创建任何卷积层所需的参数,并返回一个序列化的容器。现在我们将使用这个辅助函数来创建我们自己的判别器网络。

class Discriminator(nn.Module):

    def __init__(self, conv_dim):
        super(Discriminator, self).__init__()

        self.conv_dim = conv_dim

        #32 x 32
        self.cv1 = conv(3, self.conv_dim, 4, batch_norm=False)
        #16 x 16
        self.cv2 = conv(self.conv_dim, self.conv_dim*2, 4, batch_norm=True)
        #4 x 4
        self.cv3 = conv(self.conv_dim*2, self.conv_dim*4, 4, batch_norm=True)
        #2 x 2
        self.cv4 = conv(self.conv_dim*4, self.conv_dim*8, 4, batch_norm=True)
        #Fully connected Layer
        self.fc1 = nn.Linear(self.conv_dim*8*2*2,1)
        

    def forward(self, x):
        # After passing through each layer
        # Applying leaky relu activation function
        x = F.leaky_relu(self.cv1(x),0.2)
        x = F.leaky_relu(self.cv2(x),0.2)
        x = F.leaky_relu(self.cv3(x),0.2)
        x = F.leaky_relu(self.cv4(x),0.2)
        # To pass throught he fully connected layer
        # We need to flatten the image first
        x = x.view(-1,self.conv_dim*8*2*2)
        # Now passing through fully-connected layer
        x = self.fc1(x)
        return x

步骤3: 定义生成器算法

d30da4cfe3a34a4e28baacfc833e75f8.png

正如你们从图中看到的,我们给网络一个高斯矢量或者噪声矢量,它输出 s 中的值。图上的“ z”表示噪声,右边的 G (z)表示生成的样本。

与判别器一样,我们首先创建一个辅助函数来构建生成器网络,如下所示:

def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    layers = []
    convt_layer = nn.ConvTranspose2d(in_channels, out_channels, 
                           kernel_size, stride, padding, bias=False)
    
    # Appending the above conv layer
    layers.append(convt_layer)

    if batch_norm:
        # Applying the batch normalization if True
        layers.append(nn.BatchNorm2d(out_channels))
     
    # Returning the sequential container
    return nn.Sequential(*layers)

现在,是时候构建生成器网络了! !

class Generator(nn.Module):
    
    def __init__(self, z_size, conv_dim):
        super(Generator, self).__init__()

        self.z_size = z_size
        
        self.conv_dim = conv_dim
        
        #fully-connected-layer
        self.fc = nn.Linear(z_size, self.conv_dim*8*2*2)
        #2x2
        self.dcv1 = deconv(self.conv_dim*8, self.conv_dim*4, 4, batch_norm=True)
        #4x4
        self.dcv2 = deconv(self.conv_dim*4, self.conv_dim*2, 4, batch_norm=True)
        #8x8
        self.dcv3 = deconv(self.conv_dim*2, self.conv_dim, 4, batch_norm=True)
        #16x16
        self.dcv4 = deconv(self.conv_dim, 3, 4, batch_norm=False)
        #32 x 32

    def forward(self, x):
        # Passing through fully connected layer
        x = self.fc(x)
        # Changing the dimension
        x = x.view(-1,self.conv_dim*8,2,2)
        # Passing through deconv layers
        # Applying the ReLu activation function
        x = F.relu(self.dcv1(x))
        x= F.relu(self.dcv2(x))
        x= F.relu(self.dcv3(x))
        x= F.tanh(self.dcv4(x))
        #returning the modified image
        return x

为了使模型更快地收敛,我们将初始化线性和卷积层的权重。根据相关研究论文中的描述:所有的权重都是从0中心的正态分布初始化的,标准差为0.02

我们将为此目的定义一个功能如下:

def weights_init_normal(m):
    classname = m.__class__.__name__
    # For the linear layers
    if 'Linear' in classname:
        torch.nn.init.normal_(m.weight,0.0,0.02)
        m.bias.data.fill_(0.01)
    # For the convolutional layers
    if 'Conv' in classname or 'BatchNorm2d' in classname:
        torch.nn.init.normal_(m.weight,0.0,0.02)

现在我们将超参数和两个网络初始化如下:

# Defining the model hyperparamameters
d_conv_dim = 32
g_conv_dim = 32
z_size = 100   #Size of noise vector


D = Discriminator(d_conv_dim)
G = Generator(z_size=z_size, conv_dim=g_conv_dim)
# Applying the weight initialization
D.apply(weights_init_normal)
G.apply(weights_init_normal)

print(D)
print()
print(G)

输出结果大致如下:

b0dbbadcdafb3a9f8bf28da76d19bd8f.png

判别器损失:

根据 DCGAN Research Paper 论文中描述:

        判别器总损失 = 真图像损失 + 假图像损失,即:d_loss = d_real_loss + d_fake_loss。

       不过,我们希望鉴别器输出1表示真正的图像和0表示假图像,所以我们需要设置的损失来反映这一点。

我们将定义双损失函数。一个是真正的损失,另一个是假的损失,如下:

def real_loss(D_out,smooth=False):
    
    batch_size = D_out.size(0)
    if smooth:
        labels = torch.ones(batch_size)*0.9
    else:
        labels = torch.ones(batch_size)
    
    labels = labels.to(device)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(D_out.squeeze(), labels)
    return loss

def fake_loss(D_out):

    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)
    labels = labels.to(device)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(D_out.squeeze(), labels)
    return loss

生成器损失:

根据 DCGAN Research Paper 论文中描述:

        生成器的目标是让判别器认为它生成的图像是真实的。

现在,是时候为我们的网络设置优化器了:

lr = 0.0005
beta1 = 0.3
beta2 = 0.999 # default value
# Optimizers
d_optimizer = optim.Adam(D.parameters(), lr, betas=(beta1, beta2))
g_optimizer = optim.Adam(G.parameters(), lr, betas=(beta1, beta2))

我将为我们的训练使用 Adam 优化器。因为它目前被认为是对GANs最有效的。根据上述介绍论文中的研究成果,确定了超参数的取值范围。他们已经尝试了它,这些被证明是最好的!超参数设置如下:

步骤4: 编写训练算法

我们必须为我们的两个神经网络编写训练算法。首先,我们需要初始化噪声向量,并在整个训练过程中保持一致。

# Initializing arrays to store losses and samples
samples = []
losses = []

# We need to initilialize fixed data for sampling
# This would help us to evaluate model's performance
sample_size=16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

对于判别器:

我们首先将真实的图像输入判别器网络,然后计算它的实际损失。然后生成伪造图像并输入判别器网络以计算虚假损失。

在计算了真实和虚假损失之后,我们对其进行求和,并采取优化步骤进行训练。

# setting optimizer parameters to zero
# to remove previous training data residue
d_optimizer.zero_grad()

# move real images to gpu memory
real_images = real_images.to(device)

# Pass through discriminator network
dreal = D(real_images)

# Calculate the real loss
dreal_loss = real_loss(dreal)

# For fake images

# Generating the fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()

# move z to the GPU memory
z = z.to(device)

# Generating fake images by passing it to generator
fake_images = G(z)

# Passing fake images from the disc network        
dfake = D(fake_images)
# Calculating the fake loss
dfake_loss = fake_loss(dfake)

#Adding both lossess
d_loss = dreal_loss + dfake_loss
# Taking the backpropogation step
d_loss.backward()
d_optimizer.step()

对于生成器:

对于生成器网络的训练,我们也会这样做。刚才在通过判别器网络输入假图像之后,我们将计算它的真实损失。然后优化我们的生成器网络。

## Training the generator for adversarial loss
#setting gradients to zero
g_optimizer.zero_grad()

# Generate fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()
# moving to GPU's memory
z = z.to(device)

# Generating Fake images
fake_images = G(z)

# Calculating the generator loss on fake images
# Just flipping the labels for our real loss function
D_fake = D(fake_images)
g_loss = real_loss(D_fake, True)

# Taking the backpropogation step
g_loss.backward()
g_optimizer.step()

步骤5: 训练模型

现在我们将开始100个epoch的训练: D

经过训练,损失的图表看起来大概是这样的:

297583a673349a29e8b1e16942b38a73.png

我们可以看到,判别器 Loss 是相当平滑的,甚至在100个epoch之后收敛到某个特定值。而生成器的Loss则飙升。

我们可以从下面步骤6中的结果看出,60个时代之后生成的图像是扭曲的。由此可以得出结论,60个epoch是一个最佳的训练节点。

步骤6: 测试模型

10个epoch之后:

96d12f9703a12ac0e77a0c2f87c3f146.png

20个epoch之后:

dc196f60f97d770c6d0c79c8f4d85ae7.png

30个epoch之后:

33365e822eda268d5e048d296cbea7be.png

40个epoch之后:

0198316b876a53571d5550f9f97556d5.png

50个epoch之后:

bdcf6ad6c3a1977c55cdd7f6fd9a4816.png

60个epoch之后:

4152c319441182ae8907a6b354fb6a44.png

70个epoch之后:

b48254e7675b521a66d059da034db13b.png

80个epoch之后:

e567311618525ee9d2e67e0292fe067d.png

90个epoch之后:

c933353f759c001a2d1cd8d9ec81f3ba.png

100个epoch之后:

0f83fe27e164d062d3afad75e7960e75.png

总结

我们可以看到,训练一个生成对抗性网络并不意味着它一定会产生好的图像。

从结果中我们可以看出,训练40-60个 epoch 的生成器生成的图像相对比其他更好。

您可以尝试更改优化器、学习速率和其他超参数,以使其生成更好的图像!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/149594.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Apache Pulsar 技术系列 - 基于 Pulsar 的海量 DB 数据采集和分拣

导语 Apache Pulsar 是一个多租户、高性能的服务间消息传输解决方案,支持多租户、低延时、读写分离、跨地域复制、快速扩容、灵活容错等特性。本文是 Pulsar 技术系列中的一篇,主要介绍 Pulsar 在海量DB Binlog 增量数据采集、分拣场景下的应用。 前言…

阿里影业S1财报解读:优质内容叠加整合效益,转动增长飞轮

从《消失的她》到《长安三万里》再到《孤注一掷》,市场对阿里影业半年报好成绩已有所预期。 11月13日,阿里影业发布2023/24半年度业绩。根据财报,报告期内(4月1日至9月30日),阿里影业实现收入人民币26.16亿…

深入理解SqueezeSegV3点云分割

文章:Squeezesegv3: Spatially-adaptive convolution for efficient point-cloud segmentation 代码:https://github.com/chenfengxu714/SqueezeSegV3 一、摘要 激光雷达点云分割是许多应用中的一个重要问题。对于大规模点云分割,一般是投…

【算法每日一练]-图论(保姆级教程 篇2(topo排序,并查集,逆元))#topo排序 #最大食物链 #游走 #村村通

今天讲topo排序 目录 题目:topo排序 思路: 题目:最大食物链 解法一: 解法二: 记忆化 题目:村村通 思路: 前言:topo排序专门处理DAG(有向无环图) 题目…

15个顶级元宇宙游戏

元宇宙游戏是可让数百万玩家在一个虚拟世界中相互互动,允许你按照自己的节奏玩游戏,并根据自己的条件推广自己的品牌。 而且,这些游戏中的大多数都涉及虚拟 NFT,它们是完全独特的和虚拟的。在 Facebook 将品牌重新命名为“Meta”…

Spring 国际化:i18n 如何使用

1、i18n概述 国际化也称作i18n,其来源是英文单词 internationalization的首末字符i和n,18为中间的字符数。由于软件发行可能面向多个国家,对于不同国家的用户,软件显示不同语言的过程就是国际化。通常来讲,软件中的国…

11月第2周榜单丨飞瓜数据B站UP主排行榜榜单(B站平台)发布!

飞瓜轻数发布2023年11月6日-11月12日飞瓜数据UP主排行榜(B站平台),通过充电数、涨粉数、成长指数、带货数据等维度来体现UP主账号成长的情况,为用户提供B站号综合价值的数据参考,根据UP主成长情况用户能够快速找到运营…

【JUC】六、辅助类

文章目录 1、CountDownLatch减少计数2、CyclicBarrier循环栅栏3、Semaphore信号灯 本篇整理JUC的几个同步辅助类: 减少计数:CountDownLatch循环栅栏:CyclicBarrier信号灯:Semaphore 1、CountDownLatch减少计数 案例:6…

基于opencv+tensorflow+神经网络的智能银行卡卡号识别系统——深度学习算法应用(含python、模型源码)+数据集(二)

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 训练集图片处理1)数据加载2)图像处理 2. 测试图片处理1)图像读取2)图像处理 相关其它博客工程源代码下载其它资料下载 前言 本项目基于从网络获取的多种银行卡数据…

政府指导89元保330万 “聊惠保”2024年度正式上线!

11月15日,“聊惠保”2024年度启动仪式在聊城市融媒体中心举行。市政府领导,省直、市直相关部门单位和共保体成员单位负责同志参加仪式。“聊惠保”2024年度正式上线!“聊惠保”项目组为聊城市医疗救助困难群体捐赠“聊惠保”2024年度团体保险…

python基础练习题库实验八

文章目录 前言题目1代码 题目2代码 题目3代码 总结 前言 🎈关于python小题库的这模块我已经两年半左右没有更新了,主要是在实习跟考研,目前已经上岸武汉某211计算机,目前重新学习这门课程,也做了一些新的题目 &#x…

LeetCode34-34. 在排序数组中查找元素的第一个和最后一个位置

&#x1f517;:代码随想录:二分查找的算法讲解:有关left<right和left<right的区别 class Solution {public int[] searchRange(int[] nums, int target) {int nnums.length;int l0,hn-1;if(numsnull){return null; }if(n0){return new int[]{-1,-1}; }if(target&l…

阿里云99元ECS云服务器老用户也能买,续费同价!

阿里云近日宣布了2023年的服务器优惠活动&#xff0c;令用户们振奋不已。最引人瞩目的消息是&#xff0c;阿里云放开了老用户的购买资格&#xff0c;99元服务器也可以供老用户购买&#xff0c;并且享受续费的99元优惠。此外&#xff0c;阿里云还推出了ECS经济型e实例&#xff0…

8年经验的软件工程师建议

我希望在职业生涯早期就开始做的事情和我希望以不同的方式做的事情。 大家好&#xff0c;我已经做了八年半的软件工程师。这篇文章来源于我最近对自己在职业生涯中希望早点开始做的事情以及希望以不同方式做的事情的自我反思。 我在这里分享的对任何希望提高和进步到高级甚至…

Java远程操作Linux服务器命令

Java可以通过SSH协议远程连接Linux服务器&#xff0c;然后使用JSch库或者Apache Commons Net库来执行远程Linux命令。以下是一个使用JSch库的示例代码&#xff1a; import com.jcraft.jsch.*;public class RemoteCommandExecutor {private String host;private String user;pr…

2023年亚太杯数学建模思路 - 复盘:人力资源安排的最优化模型

文章目录 0 赛题思路1 描述2 问题概括3 建模过程3.1 边界说明3.2 符号约定3.3 分析3.4 模型建立3.5 模型求解 4 模型评价与推广5 实现代码 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 描述 …

ubuntu 20通过docker安装onlyoffice,并配置https访问

目录 一、安装docker &#xff08;一&#xff09;更新包列表和安装依赖项 &#xff08;二&#xff09;添加Docker的官方GPG密钥 &#xff08;三&#xff09;添加Docker存储库 &#xff08;四&#xff09;安装Docker &#xff08;五&#xff09;启动Docker服务并设置它随系…

2024年上半年:加密领域迎来无限机遇与重大突破!

2024年上半年将成为加密行业发展的关键时期&#xff0c;一系列重大事件和计划将为这一领域带来深远的影响。这些举措不仅有望吸引更多机构投资者和资金流入加密市场&#xff0c;还将进一步提升比特币的认可度和流动性&#xff0c;推动整个行业迈向新的阶段。 SEC批准比特币现货…

消息通讯——MQTT WebHookSpringBoot案例

目录 EMQX WebHook介绍EMQX WebHook是什么EMQX WebHook配置项如何使用EMQX WebHook配置WebHook配置事件推送参数详解 SpringBoot集成Webhook实现客户端断连监控1. 实现前提2. 代码实现接口3. 监听结果 总结 EMQX WebHook介绍 EMQX WebHook是什么 EMQX WebHook 是由 emqx_web_…

大语言模型概述|亚马逊这些互联网公司为什么花巨资训练自己的模型?

2023年可谓是大语言模型元年&#xff0c;OpenAI、亚马逊、谷歌等互联网公司争先恐后推出了自己的大语言模型&#xff1a;GPT-4、Titan、PaLM 2&#xff0c;还有亚马逊即将推出的第二个大语言模型Olympus等等。这一革命性技术如今已经在全球范围内引发了广泛的讨论和关注&#x…