第G3周:CGAN入门|生成手势图像

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、前置知识

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

  1. 有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
  2. 联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
  3. 可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
  4. 使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

相比于传统的GAN,CGAN的主要异同点包括条件信息的输入、训练稳定性、损失函数、网络结构等,其具体内容为:

  1. 条件信息的输入:CGAN引入了条件变量,使得生成器和判别器都能接收到更多的信息来指导训练过程,这是传统GAN所不具备的。
  2. 训练稳定性:传统GAN在训练过程中容易产生模式崩溃(mode collapse)的问题,而CGAN由于有了额外的条件信息,可以提高训练的稳定性和生成数据的多样性。
  3. 损失函数:虽然CGAN的损失函数仍然保留了传统GAN的对抗损失函数的形式,但额外添加的条件信息使得损失计算更加复杂且有针对性。
  4. 网络结构:在实现上,CGAN可以采用更深更复杂的网络结构,如卷积神经网络,这有助于处理更为复杂的数据类型,比如高分辨率图像。

综上所述,CGAN的核心在于它通过引入条件信息来增强模型的生成能力和可控性,与传统GAN相比,它
提供了更明确的训练目标和更好的生成效果。

二、准备工作

代码知识点

  1. torch:PyTorch库,用于实现深度学习模型和张量计算。
  2. numpy:NumPy库,用于进行数值计算和处理多维数组。
  3. torch.nn:PyTorch的神经网络模块,包含各种神经网络层和损失函数。
  4. torch.optim:PyTorch的优化器模块,包含各种优化算法,如SGD、Adam等。
  5. torchvision:PyTorch的计算机视觉库,包含数据集、预处理、数据增强等工具。
  6. torch.autograd:PyTorch的自动求导模块,用于自动计算梯度。
  7. torchvision.utils:PyTorch计算机视觉工具包,包含一些实用函数,如保存图像、制 作网格等。
  8. torch.utils.tensorboard:PyTorch的TensorBoard接口,用于可视化训练过程。
  9. torchsummary:PyTorch的模型摘要工具,用于显示模型结构及参数数量。
  10. matplotlib.pyplot:Matplotlib库的绘图模块,用于绘制图表和可视化数据。
  11. datetime:Python的日期时间模块,用于处理日期和时间相关的操作。
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt
import datetime

设置随机种子,确保每次运行代码时生成的随机数序列是相同的

torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

1.导入数据

train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="H:/G3周数据集/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           num_workers=6)

2.数据可视化

代码知识点
permute()函数的作用是对tensor进行重新排序或转置
使用permute()进行维度调换主要是因为在深度学习中,不同的模型或操作可能要求输入数据的维度顺序不同。

在PyTorch中,permute()是一个用于改变张量(tensor)形状的函数,它通过重新排列张量的维度来实现这一点。具体来说,permute()函数接收一系列维度的索引作为参数,这些索引指定了新的维度顺序。以下是关于permute()函数的详细信息:

  • 维度调换:通过传递一组新的维度顺序,permute()可以改变张量的形状。我们可以使用x.permute(0, 2, 1, 3)来将其重新排列为(a, c, b, d)。这里的第一个参数0表示保持原始张量的第一个维度不变,2表示将原始的第三个维度移动到第二个位置,1表示将原始的第二个维度移动到第三个位置,最后一个参数3表示保持原始张量的最后一个维度不变
  • 灵活性:与transpose()函数相比,permute()可以处理任意数量的维度,而transpose()通常用于二维矩阵的转置。虽然连续使用transpose()可以实现与permute()相同的效果,但permute()提供了一种更为直观和简洁的方式来调整多维数组的维度顺序。
  • 不变性:使用permute()不会改变张量中的数据,只会改变数据的组织方式。这意味着原始数据保持不变,只是按照新的顺序重新排列而已。

综上所述,permute()函数是一个非常有用的工具,特别是在处理多维数据时,如图像、声音信号等,它可以帮助我们根据需要重新组织数据,以便于进行后续的处理和分析。

def show_images(images):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))

def show_batch(dl):
    for images, _ in dl:
        show_images(images)
        break
show_batch(train_loader)

输出
在这里插入图片描述
代码知识点
np.prod函数的作用是计算数组中所有元素的乘积
关于np.prod函数的参数,以下是详细的解释:

  • a: 输入的数组,可以是标量、一维向量、二维矩阵或者多维数组。
  • axis: 指定要沿其计算乘积的轴。如果为None(默认值),则计算整个数组的乘积。
  • dtype: 指定结果的数据类型。如果未指定(默认为None),则使用输入数组的数据类型。
  • out: 指定输出结果的数组。如果未指定(默认为None),则创建一个新的数组来存储结果。
  • keepdims: 如果设置为True,则被缩减的轴将作为单例维度保留在结果中。如果为False,则被缩减的轴将完全删除。
  • initial: 指定乘积的初始值。默认值为1。
  • where: 可选参数,用于指定一个布尔数组,结果只有在该数组对应位置为True时才会被设置。

latent_dim代表了潜空间(latent space)的维度,是深度学习中特别是生成模型的一个核心概念

首先,我们来了解为什么需要设置latent_dim。在生成对抗网络(GANs)和其他生成模型中,latent_dim定义了隐空间的大小,这个空间是编码器将输入数据映射到的地方。一个合适的latent_dim对于模型的表达能力至关重要,它决定了模型能够捕捉和表示数据的复杂性的程度。如果latent_dim设置得太低,模型可能会丢失数据的重要特征,导致生成的数据缺乏多样性;而如果设置得太高,则可能会导致模型学习效率低下,增加计算成本。

接下来,我们讨论如何设置latent_dim的数值。理论上,latent_dim的设置应该与数据的内在维数(intrinsic dimension)相匹配,这是流形学习中的一个概念,指的是数据分布的本质复杂性。在实际应用中,确定latent_dim的数值通常需要依赖实验和经验。研究人员会通过实验不同的维度值,观察模型性能的变化,以此来找到最佳的维度设置。例如,在处理简单模式的数据时,可能只需要较低的维度;而对于高度复杂的数据,如高分辨率图像,可能需要更高的维度来捕获细节。

image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))
latent_dim = 100

代码知识点
embedding_dim是一个超参数,它定义了嵌入层中每个离散输入值映射到的连续向量的维度大小

embedding_dim的主要作用是控制嵌入向量的维度,这有助于模型捕捉输入数据中的语义关系和相似性。例如,在自然语言处理(NLP)任务中,单词可以被映射到一个多维空间,其中语义相近的单词在空间中的位置也相近。这样,模型就能够更好地理解和处理语言数据。

设置embedding_dim的方法通常取决于具体任务的复杂性和数据集的特性。一方面,如果embedding_dim设置得太低,可能无法充分捕捉数据中的细微差别;另一方面,如果设置得过高,则可能会导致不必要的计算负担和过拟合的风险。因此,选择合适的embedding_dim需要权衡模型性能和计算效率。

在实际应用中,通常会根据经验或实验来确定embedding_dim的最佳值。例如,对于小型数据集或简单的任务,可能只需要较低的维度;而对于大型数据集或复杂的任务,可能需要更高的维度以获得更好的性能。此外,还可以通过交叉验证等方法来系统地评估不同维度值对模型性能的影响,从而找到最优的设置。

n_classes = 3
embedding_dim = 100

三、构建模型

代码知识点
这段代码是一个权重初始化函数,用于初始化神经网络中的卷积层和批量归一化层的权重。

  1. def weights_init(m): 定义一个名为weights_init的函数,输入参数为m,表示网络中的一个模块。
  2. classname = m.__class__.__name__ 获取模块m的类名,并将其赋值给变量classname
  3. if classname.find('Conv') != -1: 如果classname中包含字符串’Conv’,说明该模块是卷积层。
  4. torch.nn.init.normal_(m.weight, 0.0, 0.02) 使用正态分布初始化卷积层的权重,均值为0,标准差为0.02。
  5. elif classname.find('BatchNorm') != -1: 如果classname中包含字符串’BatchNorm’,说明该模块是批量归一化层。
  6. torch.nn.init.normal_(m.weight, 1.0, 0.02) 使用正态分布初始化批量归一化层的权重,均值为1,标准差为0.02。
  7. torch.nn.init.zeros_(m.bias) 将批量归一化层的偏置项初始化为全零。
def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

1.构建生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim), 
            nn.Linear(embedding_dim, 16)            
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  
            nn.LeakyReLU(0.2, inplace=True)  
        )

        self.model = nn.Sequential( 
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  
            nn.ReLU(True),            
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),     
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  
        )

    def forward(self, inputs):
        noise_vector, label = inputs  
        label_output = self.label_conditioned_generator(label)     
        label_output = label_output.view(-1, 1, 4, 4)        
        latent_output = self.latent(noise_vector)     
        latent_output = latent_output.view(-1, 512, 4, 4) 
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

输出
在这里插入图片描述

代码知识点
torchinfo是一个用于PyTorch模型的库,它提供了一种简单的方式来获取模型的各种信息,如参数数量、每层的输出形状等。summary函数是torchinfo库中的一个函数,它可以生成一个模型的摘要,包括模型的每一层的信息。

from torchinfo import summary
summary(generator)

输出
在这里插入图片描述
代码知识点
这段代码的含义是将两个张量(tensor)a和b分别转换为设备上的数据类型,并将它们移动到指定的设备上。

  1. a = torch.ones(100):创建一个形状为(100,)的全1张量a。
  2. b = torch.ones(1):创建一个形状为(1,)的全1张量b。
  3. b = b.long():将张量b的数据类型转换为长整型(long)。
  4. a = a.to(device):将张量a移动到指定的设备上。这里的device是一个变量,表示要移动到的设备,例如GPU或CPU。
  5. b = b.to(device):将张量b也移动到指定的设备上。
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)

2.构建鉴别器

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),     
            nn.Linear(embedding_dim, 3*128*128)         
        )
        
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),      
            nn.LeakyReLU(0.2, inplace=True),             
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),                               
            nn.Dropout(0.4),                            
            nn.Linear(4608, 1),                         
            nn.Sigmoid()                                
        )

    def forward(self, inputs):
        img, label = inputs
        
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        
        concat = torch.cat((img, label_output), dim=1)
        
        output = self.model(concat)
        return output
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

输出
在这里插入图片描述

summary(discriminator)

输出
在这里插入图片描述

a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
c.size()

输出
torch.Size([2, 1])

三、训练模型

1.定义损失函数

代码知识点
nn.BCELoss() 函数用于计算二分类问题中的二元交叉熵损失

nn.BCELoss()
是PyTorch中提供的一个损失函数,主要用于二元分类问题。它计算的是真实标签与模型预测概率之间的二元交叉熵损失(Binary Cross
Entropy
Loss)。这个损失函数衡量的是模型输出概率与实际标签的一致性,其目的是在训练过程中最小化这个损失值,从而提高模型的预测准确性。

在二元分类问题中,模型的输出通常是一个介于0和1之间的概率值,表示某个样本属于正类的概率。nn.BCELoss()
函数在计算损失时,会将模型预测的概率与真实标签结合起来,通过优化这个过程使得模型能够更好地预测样本的类别。在使用nn.BCELoss()时,通常需要在模型的输出层之前添加一个Sigmoid激活函数,以将模型的输出转换为概率值。

具体来说,二元交叉熵损失函数的计算公式为:

  • BCELoss = - (y * log§ + (1 - y) * log(1 - p))

其中,(y) 是真实标签(通常取值为0或1),§
是模型预测样本为正类的概率。这个公式衡量了预测概率与真实标签之间的差异,当预测准确时损失值较低,预测不准确时损失值较高,通过梯度下降等优化算法可以逐步降低这个损失值,从而提升模型的性能。

总的来说,nn.BCELoss() 是深度学习中常用的损失函数之一,特别适用于二元分类问题,帮助模型学习到更好的预测结果。

这段代码是用于生成对抗网络(GAN)的损失函数计算。下面是对每行代码的解释:
adversarial_loss函数在生成对抗网络(GAN)中用于衡量生成器生成的样本与真实样本之间的差异。

  1. adversarial_loss = nn.BCELoss():定义了一个二进制交叉熵损失函数(Binary Cross Entropy Loss),用于衡量生成器和判别器在生成假样本时与真实标签之间的差异。

  2. def generator_loss(fake_output, label)::定义了一个名为generator_loss的函数,该函数接受两个参数:fake_output表示生成器生成的假样本的输出,label表示真实的标签。

  3. gen_loss = adversarial_loss(fake_output, label):调用前面定义的二进制交叉熵损失函数,将生成器生成的假样本输出和真实标签作为输入,计算生成器的损失值,并将结果赋值给变量gen_loss

  4. return gen_loss:返回生成器的损失值。

  5. def discriminator_loss(output, label)::定义了一个名为discriminator_loss的函数,该函数接受两个参数:output表示判别器对样本的输出,label表示真实的标签。

  6. disc_loss = adversarial_loss(output, label):调用前面定义的二进制交叉熵损失函数,将判别器对样本的输出和真实标签作为输入,计算判别器的损失值,并将结果赋值给变量disc_loss

  7. return disc_loss:返回判别器的损失值。

adversarial_loss = nn.BCELoss() 

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

2.定义优化器

代码知识点
在Adam优化器中,betas是一个包含两个数值的元组(beta1, beta2),分别代表
一阶矩估计(梯度的指数加权平均)和二阶矩估计(梯度平方的指数加权平均)的衰减率。

以下是关于Adam优化器中betas参数的一些详细解释:

  • 一阶矩估计(beta1):它反映了历史梯度信息的衰减率。默认情况下,beta1通常设置为0.9,这意味着最近的梯度信息将更加重要,而较旧的梯度信息则会逐渐衰减。
  • 二阶矩估计(beta2):它代表了历史梯度平方信息的衰减率。默认情况下,beta2通常设置为0.999,这有助于更精确地调整学习率。

综上所述,通过适当地调整这些参数,可以更好地控制模型训练过程中的学习动态,进而提高模型的性能。

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3.训练模型

代码知识点
这段代码是一个训练生成对抗网络(GAN)的循环。下面是对每一行代码的解释:

  1. num_epochs = 100:设置训练的总轮数为100。
  2. D_loss_plot, G_loss_plot = [], []:初始化两个列表,用于记录判别器和生成器的损失值。
  3. for epoch in range(1, num_epochs + 1)::开始一个循环,循环次数为总轮数。
  4. D_loss_list, G_loss_list = [], []:在每轮开始时,初始化两个列表,用于记录当前轮次中判别器和生成器的损失值。
  5. for index, (real_images, labels) in enumerate(train_loader)::遍历训练数据加载器中的每个批次。
  6. D_optimizer.zero_grad():将判别器的梯度清零。
  7. real_images = real_images.to(device):将真实图像转移到设备上(例如GPU)。
  8. labels = labels.to(device):将标签转移到设备上。
  9. labels = labels.unsqueeze(1).long():将标签的形状从(batch_size,)变为(batch_size,
    1),并将其转换为长整型。
  10. real_target = Variable(torch.ones(real_images.size(0), 1).to(device)):创建一个与真实图像大小相同的全1张量,表示真实的目标。
  11. fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device)):创建一个与真实图像大小相同的全0张量,表示假的目标。
  12. D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target):计算判别器对真实图像的损失。
  13. noise_vector = torch.randn(real_images.size(0), latent_dim, device=device):生成随机噪声向量。
  14. noise_vector = noise_vector.to(device):将噪声向量转移到设备上。
  15. generated_image = generator((noise_vector, labels)):使用生成器生成图像。
  16. output = discriminator((generated_image.detach(), labels)):将生成的图像输入判别器。
  17. D_fake_loss = discriminator_loss(output, fake_target):计算判别器对生成图像的损失。
  18. D_total_loss = (D_real_loss + D_fake_loss) / 2:计算判别器的总体损失。
  19. D_loss_list.append(D_total_loss):将判别器的损失值添加到列表中。
  20. D_total_loss.backward():计算判别器的梯度。
  21. D_optimizer.step():更新判别器的参数。
  22. G_optimizer.zero_grad():将生成器的梯度清零。
  23. G_loss = generator_loss(discriminator((generated_image, labels)), real_target):计算生成器的损失。
  24. G_loss_list.append(G_loss):将生成器的损失值添加到列表中。
  25. G_loss.backward():计算生成器的梯度。
  26. G_optimizer.step():更新生成器的参数。
  27. print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ((epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list)))):打印当前轮次的判别器和生成器的平均损失值。
  28. D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list))):将判别器的平均损失值添加到损失值列表中。
  29. G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list))):将生成器的平均损失值添加到损失值列表中。
  30. if epoch%10 == 0::如果当前轮次是10的倍数,则执行以下操作。
  31. save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True):保存生成的前50个图像。
  32. torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch)):保存生成器的权重。
  33. torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch)):保存判别器的权重。
num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):
    
    D_loss_list, G_loss_list = [], []
    
    for index, (real_images, labels) in enumerate(train_loader):
        
        D_optimizer.zero_grad()
        
        real_images = real_images.to(device)
        labels      = labels.to(device)
        
        
        labels = labels.unsqueeze(1).long()
        
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
        
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), 
            torch.mean(torch.FloatTensor(G_loss_list))))
    
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch%10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

梯度清零的作用是重置梯度信息,确保每个batch的梯度计算是独立的。

在训练神经网络时,梯度清零(通常通过调用优化器的.zero_grad()方法实现)是一个重要的步骤。由于PyTorch等框架中,梯度是累加的,即如果不清零,每次计算的梯度会叠加到之前的梯度上。这会导致不同batch的数据对模型参数更新的贡献混在一起,从而影响训练的准确性。通过梯度清零,我们确保了每个batch的梯度计算是独立于其他batches的,这样每个batch的梯度就能准确地反映出该batch数据对模型的影响。

进行梯度清零的原因是为了避免梯度之间的混合和累积,保持模型更新的准确性

在深度学习训练中,通常我们希望每个batch的梯度更新只基于该batch的数据。如果不进行梯度清零,前一个batch计算出的梯度会影响后一个batch的梯度计算,导致梯度不断累积,最终影响模型参数的更新。此外,梯度清零也与“梯度累加”策略相关,该策略允许我们在硬件资源有限的情况下,通过多次累加小batch的梯度来模拟大batch的训练效果,这样可以在一定程度上缓解显存限制问题。

将真实图像转移到设备上(例如GPU)的原因是为了加快计算速度和提高计算效率

在深度学习训练中,尤其是在处理大规模数据集时,计算量通常非常大。CPU(中央处理单元)负责通用计算任务,而GPU(图形处理单元)则专为并行计算设计,能够同时处理大量数据。因此,当涉及到大量的矩阵运算、卷积运算等操作时,使用GPU可以显著加速这些计算过程。

具体来说,将真实图像转移到设备上的原因包括:

  1. 并行计算能力:GPU拥有成百上千个计算核心,能够同时执行多个计算任务。这意味着在处理图像或其他数据时,GPU可以在相同的时间内完成更多的计算工作。

  2. 提高内存带宽:GPU的内存带宽远高于CPU,这使得数据在GPU内存与计算核心之间的传输更加迅速,减少了数据传输的延迟。

  3. 专门的优化:深度学习框架(如PyTorch、TensorFlow等)通常对GPU进行专门的优化,以充分利用其计算能力。这些优化包括自动分配显存、优化张量操作等。

  4. 加快模型训练:通过将数据和模型参数存储在GPU上,可以加快模型的前向传播和反向传播过程,从而缩短训练时间。

  5. 支持更大规模的模型和数据集:GPU的显存通常比CPU内存大得多,这使得我们可以在有限的硬件资源下训练更大规模的模型和处理更大的数据集。

总之,将真实图像转移到设备上(尤其是GPU)是为了更好地利用GPU的高性能计算能力,加快深度学习模型的训练速度,提高整体的训练效率。这对于处理复杂的神经网络和大规模的数据集至关重要。

输出
在这里插入图片描述
在这里插入图片描述

四、模型分析

1.加载模型

输出
generator.load_state_dict() 是一个PyTorch函数,用于加载预训练模型的权重。

参数解释:

  • torch.load('./training_weights/generator_epoch_100.pth'):这部分代码是使用 PyTorch 的 torch.load()
    函数从指定的文件路径(‘./training_weights/generator_epoch_100.pth’)中加载预训练模型的权重。
  • strict=False:这是一个可选参数,默认值为 True。当设置为 False 时,如果预训练模型的架构与当前模型的架构不完全匹配,那么 PyTorch
    将只加载匹配的部分,并忽略不匹配的部分。这在迁移学习中非常有用,因为我们可以加载一个预训练模型的一部分权重,而不需要完全匹配整个模型的架构。
generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth'), strict=False)
generator.eval()               

输出
在这里插入图片描述

代码知识点
这段代码的作用是生成一些随机的点,然后对这些点进行插值处理,最后使用一个生成器模型(generator)对插值后的点进行预测。

  1. 导入所需的库和函数:numpy、matplotlib、torch等。
  2. 定义一个函数generate_latent_points,用于生成随机的点。输入参数为潜在空间维度(latent_dim)、样本数量(n_samples)和类别数量(n_classes,默认为3)。函数内部首先生成一个随机数矩阵,然后将其重塑为指定的形状,最后返回这个矩阵。
  3. 定义一个函数interpolate_points,用于在两个点之间进行插值。输入参数为两个点(p1和p2)以及插值步数(n_steps,默认为10)。函数内部首先计算插值比例,然后根据这些比例计算插值后的点,并将它们添加到一个列表中。最后将这个列表转换为numpy数组并返回。
  4. 调用generate_latent_points函数生成两个随机点,并将它们存储在变量pts中。
  5. 调用interpolate_points函数对这两个点进行插值处理,并将结果存储在变量interpolated中。然后将这个张量转换为PyTorch张量,并将其移动到指定的设备(device)上,并将其类型转换为float32。
  6. 初始化一个名为output的变量,用于存储生成器的预测结果。
  7. 对于每个类别(共3个),创建一个全为该类别标签的张量(长度为10),并将其移动到指定的设备上。然后将这个张量的维度扩展为(batch_size,
    1),并将其类型转换为long。接着打印这个张量的大小。
  8. 使用生成器模型对插值后的点进行预测,并将结果存储在变量predictions中。然后将这个张量的维度重新排列,并将其从GPU上移回到CPU上。如果output为空,则将pred赋值给output;否则,将predoutput进行拼接。
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace
from matplotlib import pyplot
from matplotlib import gridspec

def generate_latent_points(latent_dim, n_samples, n_classes=3):
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

def interpolate_points(p1, p2, n_steps=10):
    ratios = linspace(0, 1, num=n_steps)
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)

pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

output = None
for label in range(3):
    labels = torch.ones(10) * label
    labels = labels.to(device)
    labels = labels.unsqueeze(1).long()
    print(labels.size())
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0,2,3,1)
    pred = predictions.detach().cpu()
    if output is None:
        output = pred
    else:
        output = np.concatenate((output,pred))

输出
torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 1])

output.shape

输出
(30, 128, 128, 3)
代码知识点
这段代码的作用是在一个图形窗口中显示多个子图,每个子图显示一个预测结果。具体解释如下:

  1. nrow = 3ncol = 10 分别表示子图的行数和列数。
  2. fig = plt.figure(figsize=(15,4)) 创建一个图形窗口,设置窗口的大小为宽15英寸,高4英寸。
  3. gs = gridspec.GridSpec(nrow, ncol) 使用gridspec模块创建一个网格布局,用于指定子图的位置。
  4. k = 0 初始化一个计数器变量k,用于遍历输出数组output
  5. for i in range(nrow): 循环遍历行数。
  6. for j in range(ncol): 循环遍历列数。
  7. pred = (output[k, :, :, :] + 1 ) * 127.5 对输出数组output的第k个元素进行预处理,将其值范围从[-1, 1]映射到[0, 255]。
  8. pred = np.array(pred) 将预处理后的结果转换为NumPy数组。
  9. ax= plt.subplot(gs[i,j]) 在网格布局中创建子图,并将其赋值给变量ax
  10. ax.imshow(pred.astype(np.uint8)) 在子图中显示预处理后的图像,数据类型为无符号整数(np.uint8)。
  11. ax.set_xticklabels([]) 移除子图的x轴刻度标签。
  12. ax.set_yticklabels([]) 移除子图的y轴刻度标签。
  13. ax.axis('off') 关闭子图的坐标轴。
  14. k += 1 更新计数器变量k的值。
  15. plt.show() 显示图形窗口。
nrow = 3
ncol = 10

fig = plt.figure(figsize=(15,4))
gs = gridspec.GridSpec(nrow, ncol) 

k = 0
for i in range(nrow):
    for j in range(ncol):
        pred = (output[k, :, :, :] + 1 ) * 127.5
        pred = np.array(pred)  
        ax= plt.subplot(gs[i,j])
        ax.imshow(pred.astype(np.uint8))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1   

plt.show()

输出
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/446654.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue3 ref获取子组件显示 __v_skip : true 获取不到组件的方法 怎么回事怎么解决

看代码 问题出现了 当我想要获取这个组件上的方法时 为什么获取不到这个组件上的方法呢 原來: __v_skip: true 是 Vue 3 中的一个特殊属性,用于跳过某些组件的渲染。当一个组件被标记为 __v_skip: true 时,Vue 将不会对该组件进行渲染&am…

ABAP接口-RFC连接(ABAP TO ABAP)

目录 ABAP接口-RFC连接(ABAP TO ABAP)创建ABAP连接RFC函数的调用 ABAP接口-RFC连接(ABAP TO ABAP) 创建ABAP连接 事务代码:SM59 点击创建,填写目标名称,选择连接类型: 填写主机名…

打卡--MySQL8.0 一(单机部署)

一路走来,所有遇到的人,帮助过我的、伤害过我的都是朋友,没有一个是敌人。如有侵权,请留言,我及时删除! MySQL 8.0 简介 MySQL 8.0与5.7的区别主要体现在:1、性能提升;2、新的默认…

02-app端文章查看,静态化freemarker,分布式文件系统minIO

app端文章查看,静态化freemarker,分布式文件系统minIO 1)文章列表加载 1.1)需求分析 文章布局展示 1.2)表结构分析 ap_article 文章基本信息表 ap_article_config 文章配置表 ap_article_content 文章内容表 三张表关系分析 1.3)导入文章数据库 1.3.1)导入数据…

【vue.js】文档解读【day 2】 | 响应式基础

如果阅读有疑问的话,欢迎评论或私信!! 本人会很热心的阐述自己的想法!谢谢!!! 文章目录 响应式基础声明响应式状态(属性)响应式代理 vs 原始值声明方法深层响应性DOM 更新时机有状态方法 响应式…

电脑数据安全新防线:文件备份的终极指南

一、数据守护者的使命:文件备份的重要性 在数字化日益普及的今天,电脑已成为我们日常生活和工作的必备工具,文件作为我们储存、交流和处理信息的主要载体,其重要性不言而喻。然而,无论是由于硬件故障、软件崩溃&#…

Autosar教程-Mcal教程-GPT配置教程

3.3GPT配置、生成 3.3.1 GPT配置所需要的元素 GPT实际上就是硬件定时器,需要配置的元素有: 1)定时器时钟:定时器要工作需要使能它的时钟源 2)定时器分步:时钟源进到定时器后可以通过分频后再给到定时器 定时器模块选择:MCU有多个定时器模块,需要决定使用哪个定时器模块作…

【动态规划】代码随想录算法训练营第三十九天 |62.不同路径,63.不同路径II(待补充)

62.不同路径 1、题目链接:. - 力扣(LeetCode) 2、文章讲解:代码随想录 3、题目: 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右…

JavaScript高级Ⅱ(全面版)

接上文 JavaScript高级Ⅰ JavaScript高级Ⅰ(自认为很全面版)-CSDN博客 目录 第2章 DOM编程 2.1 DOM编程概述 2.1.4 案例演示(商品全选) 2.1.5 dom操作内容 代码演示: 运行效果: 2.1.6 dom操作属性 代码演示: 运行效果: 2…

H264/265编码参数2: Profile Level Tier

profile和level profile和level是视频编码中两个很重要的概率,中文一般叫做档次和级别。 在MPEG2标准里边,按不同的压缩比分成五个档次,按视频清晰度分为四个级别,如下图所示: 档次和级别共有 20 种组合,…

2024年【化工自动化控制仪表】考试总结及化工自动化控制仪表作业考试题库

题库来源:安全生产模拟考试一点通公众号小程序 2024年【化工自动化控制仪表】考试总结及化工自动化控制仪表作业考试题库,包含化工自动化控制仪表考试总结答案和解析及化工自动化控制仪表作业考试题库练习。安全生产模拟考试一点通结合国家化工自动化控…

day-18 长度最小的子数组

运用队列的思维&#xff0c;求出每种满足题意的子数组长度&#xff0c;最小的即为答案&#xff0c;否则返回0 code class Solution {public int minSubArrayLen(int target, int[] nums) {int l0,r0;int ansInteger.MAX_VALUE;int total0;while(r<nums.length){totalnums[r…

C++:类和对象(三)——拷贝构造函数和运算符重载

目录 一、拷贝构造函数 1.概念 2.特性 二、赋值运算符重载 1.运算符重载 2.赋值运算符重载 &#xff08;1&#xff09;注意的点&#xff1a; &#xff08;2&#xff09;赋值运算符不允许被重载为全局函数&#xff0c;只能重载为类的成员函数 &#xff08;3&#xff09;…

STM32单片机示例:ETH_LAN8742_DHCP_NonOS_Poll_H743

文章目录 目的基础说明关键配置关键代码示例链接总结 目的 以太网是比较常用到的功能&#xff0c;STM32系列单片机使用CubeMX配置使用以太网功能比非常方便。不过对于H7系列来说需要使能 DCache 才能启用LwIP&#xff0c;启用Cache后又会带来一些需要特别注意的事情。这篇文章…

HarBor私有镜像仓库安装部署

环境准备 #>>> redis $ yum -y install redis $ systemctl enable --now redis $ vim /etc/redis.conf modify: bind <ipaddress> $ systemctl restart redis#>>> nfs $ yum -y install nfs-utils $ mkdir -p /data/harbor $ vi /etc/exports /data/h…

简介:CMMI软件能力成熟度集成模型

前言 CMMI是英文Capability Maturity Model Integration的缩写。 CMMI认证简称软件能力成熟度集成模型&#xff0c;是鉴定企业在开发流程化和质量管理上的国际通行标准&#xff0c;全球软件生产标准大都以此为基点&#xff0c;并都努力争取成为CMMI认证队伍中的一分子。 对一个…

动静态库

inode inode用于管理文件属性和内容 一个文件只能有一个inode&#xff0c;一个inode可以对应多个文件名 Linux进程中&#xff0c;打开的每一个文件都有对应的文件inode属性和文件页缓冲区&#xff08;内存和磁盘的缓冲区&#xff09; 软硬链接 硬链接 多个文件指向同一个i…

Gradle模块化最佳实践

一&#xff0c;模块化的原因及意义 模块化是一种将大型的软件系统拆分成相互独立的模块的方法。具有以下优势&#xff1a; 代码复用&#xff1a;不同的模块可以共享相同的代码。这样可以避免重复编写相同的代码&#xff0c;提高开发效率。模块独立性&#xff1a;每个模块都可…

【安装教程】windows下安装Faiss-GPU

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 【安装教程】windows下安装Faiss-GPU 查看安装指令 查看安装指令 登录网站&#xff1a;https://anaconda.org/ &#xff0c; 然后搜索faiss-gpu会进入如下界面&#xff0c;或…

Vue 3中的reactive:响应式状态的全面管理

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…