- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
- 🍖 原作者:K同学啊|接辅导、项目定制
CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。
CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。
CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:
有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。
一、前期工作
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128
train_transform = transforms.Compose([
transforms.Resize(128),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])
train_dataset = datasets.ImageFolder(root="H:/G3/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=6)
def show_images(dl):
for images, _ in dl:
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
break
show_images(train_loader)
二、构建模型
latent_dim = 100
n_classes = 3
embedding_dim = 100
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_conditioned_generator = nn.Sequential(
nn.Embedding(n_classes, embedding_dim),
nn.Linear(embedding_dim, 16)
)
self.latent = nn.Sequential(
nn.Linear(latent_dim, 4*4*512),
nn.LeakyReLU(0.2, inplace=True)
)
self.model = nn.Sequential(
nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
nn.ReLU(True),
nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
nn.ReLU(True),
nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
nn.ReLU(True),
nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
nn.ReLU(True),
nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, inputs):
noise_vector, label = inputs
label_output = self.label_conditioned_generator(label)
label_output = label_output.view(-1, 1, 4, 4)
latent_output = self.latent(noise_vector)
latent_output = latent_output.view(-1, 512, 4, 4)
concat = torch.cat((latent_output, label_output), dim=1)
image = self.model(concat)
return image
generator = Generator().to(device)
generator.apply(weights_init)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)
import torch
import torch.nn as nn
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.label_condition_disc = nn.Sequential(
nn.Embedding(n_classes, embedding_dim),
nn.Linear(embedding_dim, 3*128*128)
)
self.model = nn.Sequential(
nn.Conv2d(6, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),
nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),
nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),
nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Flatten(),
nn.Dropout(0.4),
nn.Linear(4608, 1),
nn.Sigmoid()
)
def forward(self, inputs):
img, label = inputs
label_output = self.label_condition_disc(label)
label_output = label_output.view(-1, 3, 128, 128)
concat = torch.cat((img, label_output), dim=1)
output = self.model(concat)
return output
a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
三、训练模型及可视化
这一部分主要定义初始化权重,构建鉴别器和生成器。
# 定义损失函数
adversarial_loss = nn.BCELoss()
def generator_loss(fake_output, label):
gen_loss = adversarial_loss(fake_output, label)
return gen_loss
def discriminator_loss(output, label):
disc_loss = adversarial_loss(output, label)
return disc_loss
learning_rate = 0.0002
G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []
# 循环进行训练
for epoch in range(1, num_epochs + 1):
# 初始化每轮训练中判别器和生成器损失的临时列表
D_loss_list, G_loss_list = [], []
# 遍历训练数据加载器中的数据
for index, (real_images, labels) in enumerate(train_loader):
# 清空判别器的梯度缓存
D_optimizer.zero_grad()
# 将真实图像数据和标签转移到GPU(如果可用)
real_images = real_images.to(device)
labels = labels.to(device)
# 将标签的形状从一维向量转换为二维张量(用于后续计算)
labels = labels.unsqueeze(1).long()
# 创建真实目标和虚假目标的张量(用于判别器损失函数)
real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))
# 计算判别器对真实图像的损失
D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
# 从噪声向量中生成假图像(生成器的输入)
noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
noise_vector = noise_vector.to(device)
generated_image = generator((noise_vector, labels))
# 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)
output = discriminator((generated_image.detach(), labels))
D_fake_loss = discriminator_loss(output, fake_target)
# 计算判别器总体损失(真实图像损失和假图像损失的平均值)
D_total_loss = (D_real_loss + D_fake_loss) / 2
D_loss_list.append(D_total_loss)
# 反向传播更新判别器的参数
D_total_loss.backward()
D_optimizer.step()
# 清空生成器的梯度缓存
G_optimizer.zero_grad()
# 计算生成器的损失
G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
G_loss_list.append(G_loss)
# 反向传播更新生成器的参数
G_loss.backward()
G_optimizer.step()
# 打印当前轮次的判别器和生成器的平均损失
print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
(epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
torch.mean(torch.FloatTensor(G_loss_list))))
# 将当前轮次的判别器和生成器的平均损失保存到列表中
D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
if epoch%10 == 0:
# 将生成的假图像保存为图片文件
save_image(generated_image.data[:50], './sample_%d' % epoch + '.png', nrow=5, normalize=True)
# 将当前轮次的生成器和判别器的权重保存到文件
torch.save(generator.state_dict(), './generator_epoch_%d.pth' % (epoch))
torch.save(discriminator.state_dict(), './discriminator_epoch_%d.pth' % (epoch))
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot as plt, gridspec
import numpy as np
# Assuming 'generator' and 'device' are defined earlier in your code
generator.load_state_dict(torch.load('./generator_epoch_100.pth'), strict=False)
generator.eval()
interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)
label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()
predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100
plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()
代码中的操作将预测结果的值加1(这样所有的值都变为非负数),然后乘以127.5,最终得到的值就在0到255之间。