GAN(生成对抗网络)是一种常用且效果出色的图像生成模型。这里使用支持条件生成的 cGAN(条件生成对抗网络)。下面介绍一个简单 cGAN 模型的构建与训练过程。
2.1 在 model 文件夹中新建 nets.py 文件
import torch
import torch.nn as nn
# 生成器类
class Generator(nn.Module):
    """Conditional DCGAN generator built from transposed convolutions.

    The class label is embedded into an ``nz``-dimensional vector, reshaped to
    a 1x1 feature map and concatenated with the noise along the channel axis.
    Five transposed convolutions then upsample that input; a
    ``(B, nz, 1, 1)`` noise tensor yields a ``(B, nc, 64, 64)`` image in
    ``[-1, 1]`` (final ``Tanh``).

    Args:
        nz: Number of noise channels (also the label-embedding width).
        nc: Number of output image channels.
        ngf: Base number of generator feature maps.
        num_classes: Number of distinct condition labels.
    """

    def __init__(self, nz=100, nc=3, ngf=128, num_classes=4):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, nz)

        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # every stage is ConvTranspose2d -> BatchNorm2d -> ReLU.
        hidden_stages = (
            (nz + nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        )
        layers = []
        for c_in, c_out, stride, pad in hidden_stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output stage maps to image channels and squashes to [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate images from noise ``z`` conditioned on integer ``labels``."""
        label_vec = self.label_emb(labels)
        # Add two trailing singleton axes so the embedding can be stacked
        # with the noise along the channel dimension.
        label_map = label_vec[:, :, None, None]
        stacked = torch.cat((z, label_map), dim=1)
        return self.main(stacked)
# 判别器类
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The class label is embedded into a single ``img_size x img_size`` plane
    that is stacked onto the image as one extra channel (hence ``nc + 1``
    input channels for the first convolution). Three stride-2 convolutions
    followed by a 4x4 valid convolution and a sigmoid produce a map of
    real/fake probabilities; for 64x64 inputs the output is ``(B, 1, 5, 5)``.

    Args:
        nc: Number of image channels.
        ndf: Base number of discriminator feature maps.
        num_classes: Number of distinct condition labels.
        img_size: Height/width of the (square) input images. Must match the
            spatial size of the images passed to ``forward``.
    """

    def __init__(self, nc=3, ndf=64, num_classes=4, img_size=64):
        super().__init__()
        self.img_size = img_size
        # The label plane has exactly ONE channel. Sizing the embedding as
        # nc * 64 * 64 (as before) made the reshape in forward() fail for any
        # nc != 1, because the element counts no longer matched.
        self.label_emb = nn.Embedding(num_classes, img_size * img_size)
        self.main = nn.Sequential(
            nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score ``img`` as real/fake given its integer class ``labels``."""
        label_plane = self.label_emb(labels).view(
            labels.size(0), 1, self.img_size, self.img_size
        )
        x = torch.cat([img, label_plane], 1)
        return self.main(x)
2.2 新建 cGAN_net.py
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
# ===========================
# Conditional DCGAN 实现
# ===========================
class cDCGAN:
    """Training harness for a conditional DCGAN.

    Bundles the training data loader, the generator/discriminator pair, their
    Adam optimizers and StepLR schedulers, and provides one-epoch training
    plus checkpoint save/load.

    Args:
        data_root: Root directory passed to ``datasets.ImageFolder``.
        batch_size: Mini-batch size used by the data loader.
        device: Device the networks and batches are moved to.
        latent_dim: Length of the generator's noise vector.
        num_classes: Number of condition labels.
    """

    def __init__(self, data_root, batch_size, device, latent_dim=100, num_classes=4):
        self.device = device
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Data
        self.train_loader = self.get_dataloader(data_root)

        # Networks
        self.generator = self.build_generator().to(device)
        self.discriminator = self.build_discriminator().to(device)
        self.generator.apply(self.weights_init)
        self.discriminator.apply(self.weights_init)

        # Loss and optimizers
        self.criterion = nn.BCELoss()
        self.optimizer_G = Adam(self.generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        self.optimizer_D = Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

        # Halve the learning rate every 10 scheduler steps.
        self.scheduler_G = StepLR(self.optimizer_G, step_size=10, gamma=0.5)
        self.scheduler_D = StepLR(self.optimizer_D, step_size=10, gamma=0.5)

    def get_dataloader(self, data_root):
        """Build the shuffled training loader over an ImageFolder dataset."""
        transform = transforms.Compose([
            # Resize(128) only fixes the SHORTER edge; the centre crop makes
            # every sample exactly 128x128 so batches can be stacked. Square
            # inputs are unaffected by the crop.
            transforms.Resize(128),
            transforms.CenterCrop(128),
            transforms.ToTensor(),
            # Map each RGB channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = datasets.ImageFolder(root=data_root, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=8, pin_memory=True, persistent_workers=True)

    @staticmethod
    def weights_init(model):
        """Initialise Conv2d/Linear weights from N(0, 0.02) and zero their biases."""
        if isinstance(model, (nn.Conv2d, nn.Linear)):
            nn.init.normal_(model.weight.data, 0.0, 0.02)
            if model.bias is not None:
                nn.init.constant_(model.bias.data, 0)

    def train_step(self, epoch, step, num_epochs):
        """Train both networks for one pass over the training loader.

        Args:
            epoch: Current epoch number; drives the label-smoothing schedule
                and is echoed in the progress line.
            step: Running batch counter; incremented once per batch.
            num_epochs: Total number of epochs (used for logging only).

        Returns:
            Tuple ``(G_losses, D_losses, step)`` with per-batch generator and
            discriminator losses as floats and the updated counter.
        """
        self.generator.train()
        self.discriminator.train()
        G_losses, D_losses = [], []
        num_batches = len(self.train_loader)

        # Epoch-dependent label smoothing: the real target moves from 1.0
        # towards a floor of 0.9, the fake target from 0.0 towards a cap of 0.1.
        valid_value = max(0.9, 1 - 0.0001 * epoch)
        fake_value = min(0.1, 0.0001 * epoch)

        for i, (real_img, labels) in enumerate(self.train_loader):
            real_img = real_img.to(self.device)
            labels = labels.to(self.device)
            batch_size = real_img.size(0)

            valid = torch.full((batch_size, 1), valid_value, device=self.device)
            fake = torch.full((batch_size, 1), fake_value, device=self.device)

            # ========== Discriminator ==========
            real_pred = self.discriminator(real_img, labels)
            # Extra per-sample jitter of up to 0.1 on the real targets.
            d_real_loss = self.criterion(real_pred, valid - 0.1 * torch.rand_like(valid))

            noise = torch.randn(batch_size, self.latent_dim, device=self.device)
            # Random class plus a random offset in {-1, 0, 1}, clamped back
            # into the valid label range.
            gen_labels = torch.randint(0, self.num_classes, (batch_size,), device=self.device)
            gen_labels = gen_labels + torch.randint(-1, 2, (batch_size,), device=self.device)
            gen_labels = torch.clamp(gen_labels, 0, self.num_classes - 1)

            gen_img = self.generator(noise, gen_labels)
            # detach(): the discriminator update must not back-propagate
            # into the generator.
            fake_pred = self.discriminator(gen_img.detach(), gen_labels)
            d_fake_loss = self.criterion(fake_pred, fake + 0.1 * torch.rand_like(fake))

            d_loss = (d_real_loss + d_fake_loss) / 2
            self.optimizer_D.zero_grad()
            d_loss.backward()
            self.optimizer_D.step()
            d_loss_value = d_loss.item()
            D_losses.append(d_loss_value)

            # ========== Generator ==========
            gen_pred = self.discriminator(gen_img, gen_labels)
            g_loss = self.criterion(gen_pred, valid)
            self.optimizer_G.zero_grad()
            g_loss.backward()
            self.optimizer_G.step()
            g_loss_value = g_loss.item()
            G_losses.append(g_loss_value)

            # Log the already-extracted floats instead of formatting tensors.
            print(f'第 {epoch}/{num_epochs} 轮, Batch {i + 1}/{num_batches}, '
                  f'D Loss: {d_loss_value:.4f}, G Loss: {g_loss_value:.4f}')
            step += 1

        return G_losses, D_losses, step

    def build_generator(self):
        """Create the generator network."""
        return Generator(latent_dim=self.latent_dim, num_classes=self.num_classes)

    def build_discriminator(self):
        """Create the discriminator network."""
        return Discriminator(num_classes=self.num_classes)

    def load_model(self, model_path):
        """Restore networks, optimizers and (if present) schedulers.

        Args:
            model_path: Path of a checkpoint written by ``save_model``.

        Returns:
            The ``epoch`` value stored in the checkpoint.
        """
        checkpoint = torch.load(model_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

        # save_model() writes the scheduler states, so restore them as well;
        # .get() tolerates checkpoints that lack these keys.
        scheduler_G_state = checkpoint.get('scheduler_G_state_dict')
        if scheduler_G_state is not None:
            self.scheduler_G.load_state_dict(scheduler_G_state)
        scheduler_D_state = checkpoint.get('scheduler_D_state_dict')
        if scheduler_D_state is not None:
            self.scheduler_D.load_state_dict(scheduler_D_state)

        epoch = checkpoint['epoch']
        print(f"加载了模型权重,起始训练轮次为 {epoch}")
        return epoch

    def save_model(self, epoch, save_path):
        """Write networks, optimizers, schedulers and ``epoch`` to ``save_path``."""
        torch.save({
            'epoch': epoch,
            'scheduler_G_state_dict': self.scheduler_G.state_dict(),
            'scheduler_D_state_dict': self.scheduler_D.state_dict(),
            'generator_state_dict': self.generator.state_dict(),
            'optimizer_G_state_dict': self.optimizer_G.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_D_state_dict': self.optimizer_D.state_dict(),
        }, save_path)
        print(f"模型已保存至 {save_path}")
# ===========================
# 生成器
# ===========================
class Generator(nn.Module):
    """Upsample-and-convolve conditional generator.

    The label embedding (width ``num_classes``) is concatenated with the
    noise vector, projected by a linear layer to a 256-channel 8x8 feature
    map, and then doubled in resolution four times, giving a
    ``(B, img_channels, 128, 128)`` output in ``[-1, 1]`` (final ``Tanh``).

    Args:
        latent_dim: Length of the noise vector.
        num_classes: Number of condition labels.
        img_channels: Number of output image channels.
    """

    def __init__(self, latent_dim=100, num_classes=4, img_channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.init_size = 8
        projected_features = 256 * self.init_size * self.init_size
        self.l1 = nn.Linear(latent_dim + num_classes, projected_features)

        stages = [nn.BatchNorm2d(256)]
        # Each hidden stage: 2x upsample -> 3x3 conv -> batch norm -> LeakyReLU.
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            stages.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Final stage maps to image channels and squashes to [-1, 1].
        stages.extend([
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, img_channels, 3, padding=1),
            nn.Tanh(),
        ])
        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate images from ``noise`` conditioned on integer ``labels``."""
        # Move the labels onto whichever device holds the embedding table.
        target_device = self.label_emb.weight.device
        condition = self.label_emb(labels.to(target_device))
        joint = torch.cat((noise, condition), dim=1)
        feature_map = self.l1(joint).view(
            joint.size(0), 256, self.init_size, self.init_size
        )
        return self.conv_blocks(feature_map)
# ===========================
# 判别器
# ===========================
class Discriminator(nn.Module):
    """Conditional discriminator for 128x128 images.

    The label is embedded into ``img_channels`` values that are broadcast
    over the image's spatial size and stacked onto the image, doubling the
    channel count. Four stride-2 convolutions reduce a 128x128 input to
    512 x 8 x 8 features, which a linear layer plus sigmoid turns into one
    real/fake probability per sample.

    Args:
        img_channels: Number of image channels.
        num_classes: Number of condition labels.
    """

    def __init__(self, img_channels=3, num_classes=4):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, img_channels)

        # First stage has no batch norm; the rest are conv -> BN -> LeakyReLU.
        layers = [
            nn.Conv2d(img_channels * 2, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.model = nn.Sequential(*layers)

        classifier_head = [nn.Linear(512 * 8 * 8, 1), nn.Sigmoid()]
        self.output_layer = nn.Sequential(*classifier_head)

    def forward(self, img, labels):
        """Return a ``(B, 1)`` real/fake probability for each image."""
        # Move the labels onto whichever device holds the embedding table.
        target_device = self.label_embedding.weight.device
        condition = self.label_embedding(labels.to(target_device))
        height, width = img.size(2), img.size(3)
        # Broadcast each embedding value over the full spatial extent.
        condition_planes = condition[:, :, None, None].expand(-1, -1, height, width)
        stacked = torch.cat((img, condition_planes), dim=1)
        features = self.model(stacked)
        flattened = features.view(stacked.size(0), -1)
        return self.output_layer(flattened)
2.3 新建 cGAN_trainer.py
import os
import torch
import argparse
from cGAN_net import cDCGAN
from utils import plot_loss, plot_result
import time
# Request a single OpenMP thread via the process environment.
# NOTE(review): this assignment runs after `import torch`; presumably the
# variable has to be set before the OpenMP runtime is initialised to take
# effect in this process - confirm, and move it above the torch import if so.
os.environ['OMP_NUM_THREADS'] = '1'
def main(args):
    """Train the conditional DCGAN, optionally resuming from a checkpoint.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``data_root``,
            ``batch_size``, ``latent_dim``, ``load_model``, ``save_dir``,
            ``epochs``, ``save_freq`` and ``save_interval``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    model = cDCGAN(data_root=args.data_root, batch_size=args.batch_size, device=device, latent_dim=args.latent_dim)

    # Step the schedulers the model already owns. They are the ones written
    # into checkpoints by save_model(); a second StepLR pair created here
    # would leave the saved scheduler state permanently at its initial value.
    scheduler_G = model.scheduler_G
    scheduler_D = model.scheduler_D

    start_epoch = 0
    if args.load_model and os.path.exists(args.load_model):
        # Checkpoints are saved with `epoch + 1` (number of completed epochs),
        # which is already the zero-based index of the next epoch to run.
        # Adding another 1 here would silently skip one epoch on every resume.
        start_epoch = model.load_model(args.load_model)

        # Restore scheduler state: first from the checkpoint itself, then from
        # the legacy side-car files, otherwise keep the fresh schedulers.
        checkpoint = torch.load(args.load_model, map_location=device)
        scheduler_G_path = f"{args.load_model}_scheduler_G.pt"
        scheduler_D_path = f"{args.load_model}_scheduler_D.pt"
        if 'scheduler_G_state_dict' in checkpoint and 'scheduler_D_state_dict' in checkpoint:
            scheduler_G.load_state_dict(checkpoint['scheduler_G_state_dict'])
            scheduler_D.load_state_dict(checkpoint['scheduler_D_state_dict'])
            print(f"成功恢复调度器状态:{args.load_model}")
        elif os.path.exists(scheduler_G_path) and os.path.exists(scheduler_D_path):
            scheduler_G.load_state_dict(torch.load(scheduler_G_path))
            scheduler_D.load_state_dict(torch.load(scheduler_D_path))
            print(f"成功恢复调度器状态:{scheduler_G_path}, {scheduler_D_path}")
        else:
            print("未找到调度器状态文件,使用默认调度器设置")
        print(f"从第 {start_epoch} 轮继续训练...")
    print(f"开始训练,从第 {start_epoch + 1} 轮开始...")

    # Output directories
    log_dir = os.path.join(args.save_dir, 'log')
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # One preview image per class; built once rather than every epoch and
    # sized from the model instead of a hard-coded [0, 1, 2, 3].
    sample_labels = torch.arange(model.num_classes, device=device)

    D_avg_losses, G_avg_losses = [], []
    for epoch in range(start_epoch, args.epochs):
        G_losses, D_losses, _ = model.train_step(epoch, step=0, num_epochs=args.epochs)

        # Per-epoch mean losses (0.0 when the loader yielded no batches).
        D_avg_loss = sum(D_losses) / len(D_losses) if D_losses else 0.0
        G_avg_loss = sum(G_losses) / len(G_losses) if G_losses else 0.0
        D_avg_losses.append(D_avg_loss)
        G_avg_losses.append(G_avg_loss)

        plot_loss(start_epoch, args.epochs, D_avg_losses, G_avg_losses, epoch + 1, save=True,
                  save_dir=log_dir)

        # Save generated samples every `save_freq` epochs.
        if (epoch + 1) % args.save_freq == 0:
            z = torch.randn(len(sample_labels), args.latent_dim, device=device)
            plot_result(model.generator, z, sample_labels, epoch + 1, save_dir=log_dir)

        # Advance the schedulers BEFORE checkpointing so the saved scheduler
        # state corresponds to the saved epoch count.
        scheduler_G.step()
        scheduler_D.step()

        # Save a checkpoint every `save_interval` epochs.
        if (epoch + 1) % args.save_interval == 0:
            timestamp = int(time.time())
            save_path = os.path.join(args.save_dir, f"cgan_epoch_{epoch + 1}_{timestamp}.pth")
            model.save_model(epoch + 1, save_path)
            print(f"第 {epoch + 1} 轮的模型已保存,保存路径为 {save_path}")
if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows, registered in
    # this order.
    cli_options = (
        ('--data_root', str, 'data/crop128', "数据集根目录"),
        ('--save_dir', str, './chkpt/cgan_model', "保存模型的目录"),
        ('--load_model', str, None, "要加载的模型路径(可选)"),
        ('--epochs', int, 1000, "训练的轮数"),
        ('--save_interval', int, 10, "保存模型检查点的间隔(按轮数)"),
        ('--batch_size', int, 64, "训练的批次大小"),
        ('--device', str, 'cuda', "使用的设备(如 cuda 或 cpu)"),
        ('--latent_dim', int, 100, "生成器的潜在空间维度"),
        ('--save_freq', int, 1, "每隔多少轮保存一次生成结果(默认: 1)"),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()
    main(args)
结果分析:
2.4 中间结果可视化处理
新建 utils.py,编写用于绘制中间生成结果和损失曲线图的函数,代码如下:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
def denorm(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range."""
    return ((x + 1) / 2).clamp(0, 1)
def plot_loss(start_epoch, num_epochs, d_losses, g_losses, num_epoch, save=False, save_dir='celebA_cDCGAN_results/', show=False):
    """Plot discriminator and generator loss curves.

    Args:
        start_epoch: Epoch index of the first entry in the loss lists; also
            the left limit of the x axis.
        num_epochs: Total number of epochs; the right limit of the x axis.
        d_losses: Per-epoch discriminator losses.
        g_losses: Per-epoch generator losses.
        num_epoch: Current epoch number, shown in the x-axis label.
        save: Whether to write the figure to ``save_dir``.
        save_dir: Directory for the saved figure (created if missing).
        show: Whether to display the figure.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(start_epoch, num_epochs)

    # Only set the y range when there is data and it is not all zero;
    # np.max raises on an empty sequence and a (0, 0) range is degenerate.
    peaks = [np.max(series) for series in (g_losses, d_losses) if len(series) > 0]
    if peaks:
        upper = max(peaks) * 1.1
        if upper > 0:
            ax.set_ylim(0, upper)

    ax.set_xlabel(f'Epoch {num_epoch + 1}')
    ax.set_ylabel('Loss values')

    # Place the i-th loss at x = start_epoch + i so the curve lines up with
    # the x axis when training was resumed (start_epoch > 0). Plotting
    # against the bare list index would put it left of the visible range.
    ax.plot(range(start_epoch, start_epoch + len(d_losses)), d_losses, label='Discriminator')
    ax.plot(range(start_epoch, start_epoch + len(g_losses)), g_losses, label='Generator')
    ax.legend()

    if save:
        os.makedirs(save_dir, exist_ok=True)
        # Fixed file name: each call overwrites the previous plot.
        fig.savefig(os.path.join(save_dir, 'cDCGAN_losses_epoch.png'))

    if show:
        plt.show()
    # Always release the figure; this is called once per epoch.
    plt.close(fig)
def plot_result(generator, z, labels, epoch, save_dir=None, show=False):
    """Generate images with ``generator`` and save and/or show them in a row.

    Args:
        generator: Generator module, called as ``generator(z, labels)``.
        z: Noise tensor.
        labels: Label tensor.
        epoch: Current epoch number, used in the output file name.
        save_dir: If given, the figure is saved as ``epoch_{epoch}.png``
            in this directory (created if missing).
        show: Whether to display the figure.
    """
    # Run in eval mode without gradients, then put the module back into
    # whatever mode it was in before (not unconditionally into train mode).
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_images = generator(z, labels)
    finally:
        generator.train(was_training)

    # Undo the [-1, 1] normalisation and move to host memory for plotting.
    gen_images = denorm(gen_images).cpu()

    # squeeze=False keeps `axes` two-dimensional, so a single image no longer
    # breaks indexing (subplots would otherwise return a bare Axes object).
    fig, axes = plt.subplots(1, len(gen_images), figsize=(15, 15), squeeze=False)
    for axis, image in zip(axes[0], gen_images):
        # Reorder channels-first to channels-last for imshow.
        axis.imshow(image.permute(1, 2, 0).numpy())
        axis.axis('off')

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f'epoch_{epoch}.png')
        fig.savefig(save_path)

    if show:
        plt.show()
    plt.close(fig)
执行 cGAN_trainer.py 文件,完成模型训练。