PyTorch深度学习框架60天进阶学习计划第16天:循环神经网络进阶!

PyTorch深度学习框架60天进阶学习计划 - 第16天:生成对抗网络原理

学习目标

今天我们将深入探讨生成对抗网络(GAN)的基本原理和数学基础,重点解析GAN的minimax博弈公式,推导生成器与判别器的损失函数,分析Wasserstein GAN的改进方案以及DCGAN的架构设计规范。

1. GAN的基本原理

生成对抗网络(Generative Adversarial Networks, GAN)是由Ian Goodfellow在2014年提出的一种生成模型框架。GAN由两个网络组成:

  • 生成器(Generator, G):学习生成逼真的数据样本
  • 判别器(Discriminator, D):学习区分真实数据和生成器生成的数据

这两个网络通过对抗训练相互提升。

1.1 GAN的博弈过程

GAN的训练过程可以看作是一个两人零和博弈:

网络角色目标策略
生成器(G)生成逼真的假样本欺骗判别器最小化判别器正确分类的概率
判别器(D)准确区分真实样本和生成样本最大化判别器正确分类的概率

2. GAN的数学表达:Minimax博弈公式

2.1 经典GAN的价值函数

GAN的核心可以用以下数学公式表示:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

这个公式表达了什么?让我们逐步分解:

  • V ( D , G ) V(D, G) V(D,G) 是价值函数,判别器D试图最大化它,而生成器G试图最小化它
  • E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] Expdata(x)[logD(x)] 表示判别器对真实数据的平均预测概率的对数
  • E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] Ezpz(z)[log(1D(G(z)))] 表示判别器对生成数据的平均预测概率的对数的负值

2.2 损失函数的推导

从minimax公式中,我们可以分别推导出判别器和生成器的损失函数:

判别器的损失函数

判别器的目标是最大化 V ( D , G ) V(D, G) V(D,G),即最小化以下损失函数:

L D = − E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] − E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D = -\mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] - \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] LD=Expdata(x)[logD(x)]Ezpz(z)[log(1D(G(z)))]

在实际实现中,我们通常使用二元交叉熵损失:

L D = − 1 m ∑ i = 1 m [ log ⁡ D ( x ( i ) ) + log ⁡ ( 1 − D ( G ( z ( i ) ) ) ) ] L_D = -\frac{1}{m}\sum_{i=1}^{m}[\log D(x^{(i)}) + \log(1 - D(G(z^{(i)})))] LD=m1i=1m[logD(x(i))+log(1D(G(z(i))))]

生成器的损失函数

生成器的目标是最小化 V ( D , G ) V(D, G) V(D,G),即最小化以下损失函数:

L G = E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_G = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] LG=Ezpz(z)[log(1D(G(z)))]

然而,在训练初期,当生成器的输出质量较差时,判别器可以轻松区分真假样本,这会导致梯度消失问题。因此,实践中通常使用一个非饱和的损失函数:

L G = − E z ∼ p z ( z ) [ log ⁡ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))] LG=Ezpz(z)[logD(G(z))]

在代码实现中:

L G = − 1 m ∑ i = 1 m [ log ⁡ D ( G ( z ( i ) ) ) ] L_G = -\frac{1}{m}\sum_{i=1}^{m}[\log D(G(z^{(i)}))] LG=m1i=1m[logD(G(z(i)))]

3. Wasserstein GAN (WGAN) 改进方案

3.1 传统GAN的问题

传统GAN存在以下问题:

  • 训练不稳定
  • 模式崩溃(Mode Collapse)
  • 梯度消失或爆炸

3.2 Wasserstein距离的引入

Wasserstein GAN引入了Wasserstein距离(也称为Earth Mover’s Distance, EMD)来衡量两个概率分布之间的差异。其数学表达式为:

W ( p r , p g ) = inf ⁡ γ ∈ Π ( p r , p g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x,y) \sim \gamma}[\|x-y\|] W(pr,pg)=γΠ(pr,pg)infE(x,y)γ[xy]

WGAN的价值函数变为:

min ⁡ G max ⁡ D ∈ D E x ∼ p d a t a ( x ) [ D ( x ) ] − E z ∼ p z ( z ) [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim p_{data}(x)}[D(x)] - \mathbb{E}_{z \sim p_z(z)}[D(G(z))] GminDDmaxExpdata(x)[D(x)]Ezpz(z)[D(G(z))]

其中 D \mathcal{D} D 是所有1-Lipschitz函数的集合。

3.3 梯度惩罚(Gradient Penalty)实现原理

为了满足Lipschitz约束,WGAN-GP提出了梯度惩罚的方法:

L D = E z ∼ p z ( z ) [ D ( G ( z ) ) ] − E x ∼ p d a t a ( x ) [ D ( x ) ] + λ E x ^ ∼ p x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] L_D = \mathbb{E}_{z \sim p_z(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}(x)}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] LD=Ezpz(z)[D(G(z))]Expdata(x)[D(x)]+λEx^px^[(x^D(x^)21)2]

其中 x ^ \hat{x} x^ 是真实样本和生成样本之间的随机插值:

x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1 - \epsilon) G(z) x^=ϵx+(1ϵ)G(z)

ϵ \epsilon ϵ 是从均匀分布 U [ 0 , 1 ] U[0,1] U[0,1] 采样的随机数。

4. DCGAN架构设计规范

Deep Convolutional GAN (DCGAN) 是GAN在计算机视觉领域的一个重要应用。它提出了一系列架构设计规范:

4.1 DCGAN主要设计规范

规范生成器(G)判别器(D)
池化层使用转置卷积进行上采样,不使用池化层使用带步长的卷积替代池化层进行下采样
批量归一化在除输出层外的所有层使用在除输入层外的所有层使用
激活函数隐藏层使用ReLU激活函数,输出层使用Tanh所有层使用LeakyReLU
全连接层最后一层可以使用全连接层最后一层可以使用全连接层

4.2 批量归一化在生成器中的特殊应用

在生成器中,批量归一化具有以下特殊应用:

  1. 促进网络收敛:通过归一化特征,加速训练过程
  2. 防止模式崩溃:帮助不同的生成样本保持多样性
  3. 减轻内部协变量偏移:保持各层的输入分布相对稳定
  4. 特殊位置:通常不在生成器的输出层应用批量归一化,以保留生成数据的原始分布特性

5. PyTorch实现标准GAN

下面是一个使用PyTorch实现标准GAN的代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子
torch.manual_seed(42)

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
batch_size = 64
z_dim = 100
lr = 0.0002
beta1 = 0.5
epochs = 30
image_size = 64

# 数据集预处理
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# 判别器网络
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            # 输入: 1 x 64 x 64
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x).view(-1, 1).squeeze(1)

# 生成器网络
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.model = nn.Sequential(
            # 输入: z_dim x 1 x 1
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 32 x 32
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
            # 输出: 1 x 64 x 64
        )
    
    def forward(self, z):
        return self.model(z)

# 初始化网络
netG = Generator().to(device)
netD = Discriminator().to(device)

# 权重初始化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

# 设置损失函数和优化器
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# 训练GAN
for epoch in range(epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) 更新判别器: 最大化 log(D(x)) + log(1 - D(G(z)))
        ###########################
        # 训练真实样本
        netD.zero_grad()
        real_images = data[0].to(device)
        batch_size = real_images.size(0)
        labels = torch.full((batch_size,), 1, dtype=torch.float, device=device)
        
        output = netD(real_images)
        errD_real = criterion(output, labels)
        errD_real.backward()
        
        # 训练生成样本
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_images = netG(noise)
        labels.fill_(0)
        
        output = netD(fake_images.detach())
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        
        errD = errD_real + errD_fake
        optimizerD.step()
        
        ############################
        # (2) 更新生成器: 最大化 log(D(G(z)))
        ###########################
        netG.zero_grad()
        labels.fill_(1)  # 生成器希望判别器将生成的图像判为真
        
        output = netD(fake_images)
        errG = criterion(output, labels)
        errG.backward()
        
        optimizerG.step()
        
        # 输出训练状态
        if i % 100 == 0:
            print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}')
    
    # 保存一些生成的图像
    with torch.no_grad():
        fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)
        fake = netG(fixed_noise).detach().cpu()
        img_grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(np.transpose(img_grid, (1, 2, 0)))
        plt.axis('off')
        plt.savefig(f'fake_images_epoch_{epoch}.png')
        plt.close()

print("Training finished!")

6. WGAN-GP的PyTorch实现

下面是WGAN-GP的PyTorch实现示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(42)

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
batch_size = 64
z_dim = 100
lr = 0.0002
beta1 = 0.5
beta2 = 0.9
epochs = 30
image_size = 64
n_critic = 5  # 判别器训练次数
lambda_gp = 10  # 梯度惩罚系数

# 数据集预处理
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# 判别器网络 (Critic)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            # 输入: 1 x 64 x 64
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            # 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),  # 使用Instance Normalization替代Batch Normalization
            nn.LeakyReLU(0.2, inplace=True),
            # 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0),
            # 注意: 没有Sigmoid激活函数,因为WGAN直接输出Wasserstein距离
        )
    
    def forward(self, x):
        return self.model(x).view(-1)

# 生成器网络
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.model = nn.Sequential(
            # 输入: z_dim x 1 x 1
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 32 x 32
            nn.ConvTranspose2d(64, 1, 4, 2, 1),
            nn.Tanh()
            # 输出: 1 x 64 x 64
        )
    
    def forward(self, z):
        return self.model(z)

# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples):
    # 随机权重项: 在真实样本和生成样本之间进行插值
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    # 获取插值样本
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    # 计算判别器对插值样本的输出
    d_interpolates = D(interpolates)
    # 为反向传播创建虚拟标签
    fake = torch.ones(real_samples.size(0), device=device, requires_grad=False)
    # 计算梯度
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # 计算梯度惩罚
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# 初始化网络
netG = Generator().to(device)
netD = Discriminator().to(device)

# 权重初始化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

# 设置优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, beta2))

# 记录生成器和判别器的损失
G_losses = []
D_losses = []

# 训练WGAN-GP
for epoch in range(epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) 更新判别器
        ###########################
        # 训练判别器多次
        for _ in range(n_critic):
            # 配置网络
            netD.zero_grad()
            
            # 训练真实样本
            real_images = data[0].to(device)
            batch_size = real_images.size(0)
            
            # 生成噪声
            noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
            # 生成假样本
            fake_images = netG(noise)
            
            # 计算损失
            real_validity = netD(real_images)
            fake_validity = netD(fake_images.detach())
            
            # 计算梯度惩罚
            gradient_penalty = compute_gradient_penalty(netD, real_images, fake_images.detach())
            
            # 判别器总损失
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
            
            # 反向传播
            d_loss.backward()
            optimizerD.step()
        
        ############################
        # (2) 更新生成器
        ###########################
        netG.zero_grad()
        
        # 生成新的假样本
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_images = netG(noise)
        
        # 判别器评估假样本
        fake_validity = netD(fake_images)
        
        # 生成器损失
        g_loss = -torch.mean(fake_validity)
        
        # 反向传播
        g_loss.backward()
        optimizerG.step()
        
        # 保存损失
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())
        
        # 输出训练状态
        if i % 50 == 0:
            print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}')
    
    # 每个epoch结束后保存一些生成的图像
    with torch.no_grad():
        fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)
        fake = netG(fixed_noise).detach().cpu()
        img_grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(np.transpose(img_grid, (1, 2, 0)), cmap='gray')
        plt.axis('off')
        plt.title(f'WGAN-GP Generated Images - Epoch {epoch}')
        plt.savefig(f'wgan_gp_images_epoch_{epoch}.png')
        plt.close()

print("Training finished!")

# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("wgan_gp_loss_curve.png")
plt.close()

# 展示最终生成的图像
with torch.no_grad():
    noise = torch.randn(16, z_dim, 1, 1, device=device)
    generated_images = netG(noise).detach().cpu()
    
    # 去归一化
    generated_images = (generated_images + 1) / 2
    
    # 创建图像网格
    img_grid = torchvision.utils.make_grid(generated_images, nrow=4, padding=2, normalize=False)
    
    plt.figure(figsize=(10, 10))
    plt.imshow(np.transpose(img_grid, (1, 2, 0)), cmap='gray')
    plt.axis('off')
    plt.title("Final WGAN-GP Generated Images")
    plt.savefig("final_wgan_gp_images.png")
    plt.show()

7. GAN的训练流程图

以下是GAN训练的基本流程图:
在这里插入图片描述
在这里插入图片描述

8. WGAN-GP与标准GAN的对比

特性标准GANWGAN-GP
损失函数二元交叉熵Wasserstein距离 + 梯度惩罚项
判别器/评论家输出概率值(0~1)实数值(Wasserstein距离)
最后一层激活函数Sigmoid无(线性输出)
参数裁剪不需要通过梯度惩罚实现Lipschitz约束
优化器推荐Adam(β1=0.5, β2=0.999)Adam(β1=0.5, β2=0.9)
训练稳定性容易不稳定更加稳定
模式崩溃常见问题大幅减轻
梯度消失容易发生基本解决
判别器训练次数通常1:1通常5:1(判别器:生成器)
归一化层BatchNorm推荐使用LayerNorm或InstanceNorm
训练速度相对较快相对较慢(需要多次训练判别器)
超参数敏感度较高较低
理论基础JS散度Wasserstein距离

9. 批量归一化在GAN中的特殊应用分析

批量归一化(Batch Normalization)在GAN中具有重要作用,尤其在生成器中有特殊的应用方式:

9.1 批量归一化的基本原理

批量归一化通过以下公式对每个批次的数据进行标准化:

x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵ xiμB

其中, μ B \mu_B μB σ B 2 \sigma_B^2 σB2 分别是批次的均值和方差, ϵ \epsilon ϵ 是防止除零的小常数。

然后,通过可学习的参数 γ \gamma γ β \beta β 调整标准化后的分布:

y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β

9.2 生成器中批量归一化的特殊应用

位置选择

在生成器网络中,批量归一化层的位置有以下特殊考虑:

  1. 输入层后不使用:生成器接收随机噪声作为输入,这些噪声通常已经是标准正态分布,不需要额外的归一化。

  2. 输出层前不使用:输出层通常使用Tanh激活函数将值映射到[-1,1]区间,输出前不应用批量归一化以保持生成数据的自然分布。

  3. 中间层广泛使用:在中间层广泛使用BatchNorm可以稳定训练过程,防止梯度消失和爆炸。
    在这里插入图片描述

训练模式与评估模式的区别

在GAN中,批量归一化层在训练和评估模式下的行为有重要区别:

  1. 训练模式(train()):使用当前批次的统计值进行归一化
  2. 评估模式(eval()):使用整个训练过程累积的统计值进行归一化

在GAN训练中,正确切换这两种模式至关重要。在生成样本时,必须使用评估模式,确保输出的一致性。

应对小批量大小的策略

当批量大小较小时,BatchNorm可能导致统计不稳定。在GAN中,尤其是高分辨率图像生成时,可以考虑以下替代方案:

  1. 实例归一化(Instance Normalization):对每个样本独立归一化
  2. 层归一化(Layer Normalization):对每个特征通道独立归一化
  3. 组归一化(Group Normalization):将通道分组后归一化
批量归一化与条件GAN

在条件GAN(Conditional GAN)中,批量归一化可以结合条件信息:

class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # 初始化为1
        self.embed.weight.data[:, num_features:].zero_()  # 初始化为0

    def forward(self, x, y):
        out = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        return out

这种条件批量归一化允许生成器根据标签条件调整其特征统计,从而生成特定类别的样本。

10. 实现DCGAN的最佳实践

下面是实现DCGAN时的一些最佳实践:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 超参数
batch_size = 128
z_dim = 100
lr = 0.0002
beta1 = 0.5
epochs = 25
image_size = 64
nc = 3  # 彩色图像的通道数

# 数据集预处理
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# 加载CIFAR10数据集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

# DCGAN判别器网络
class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入: nc x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: ndf x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

# DCGAN生成器网络
class Generator(nn.Module):
    def __init__(self, nc=3, ngf=64, nz=100):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入是一个nz维度的噪声向量,进入转置卷积
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 状态尺寸: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 状态尺寸: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 状态尺寸: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 状态尺寸: (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # 状态尺寸: nc x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# 权重初始化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# 初始化网络
netG = Generator(nc, 64, z_dim).to(device)
netD = Discriminator(nc, 64).to(device)

# 初始化权重
netG.apply(weights_init)
netD.apply(weights_init)

# 打印模型结构
print(netG)
print(netD)

# 设置损失函数和优化器
criterion = nn.BCELoss()

# 使用Adam优化器,按照DCGAN论文中的建议设置参数
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# 创建固定的噪声向量,用于可视化生成器的进展
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

# 创建真和假的标签
real_label = 1.
fake_label = 0.

# 用于保存训练过程中的损失值
G_losses = []
D_losses = []
img_list = []

print("开始训练...")

# 训练循环
for epoch in range(epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) 更新判别器: 最大化 log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## 用真实图像训练判别器
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        ## 用生成的假图像训练判别器
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) 更新生成器: 最大化 log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # 生成器希望判别器将假图像判为真
        
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        
        optimizerG.step()

        # 输出训练状态
        if i % 50 == 0:
            print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')
        
        # 保存损失,用于以后绘图
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # 检查生成器如何处理固定噪声向量
        if (i % 500 == 0) or ((epoch == epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))

    # 在每个epoch结束后保存模型
    torch.save(netG.state_dict(), f'dcgan_generator_epoch_{epoch}.pth')
    torch.save(netD.state_dict(), f'dcgan_discriminator_epoch_{epoch}.pth')

    # 在每个epoch结束后显示生成的图像
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
        img_grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(np.transpose(img_grid, (1, 2, 0)))
        plt.axis('off')
        plt.title(f'DCGAN Generated Images - Epoch {epoch}')
        plt.savefig(f'dcgan_images_epoch_{epoch}.png')
        plt.close()

print("训练完成!")

# 绘制生成器和判别器的损失曲线
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("dcgan_loss_plot.png")
plt.close()

# 显示训练进程中生成的图像
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
plt.savefig("dcgan_generation_animation.png")
plt.close()

# 展示最终生成的图像
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Final DCGAN Generated Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.savefig("final_dcgan_generated_images.png")
plt.close()

print("所有结果已保存!")

11. GAN训练中的常见问题及解决方案

以下是GAN训练中常见的问题及其解决方案:

问题原因解决方案
模式崩溃生成器只学习生成有限种类的样本1. 使用WGAN或WGAN-GP
2. 小批量判别(Minibatch Discrimination)
3. 特征匹配(Feature Matching)
4. 在判别器中添加噪声
判别器过强判别器学习速度太快,导致生成器没有有效梯度1. 降低判别器的学习率
2. 减少判别器的更新频率
3. 添加标签平滑(Label Smoothing)
4. 在生成器损失中添加辅助任务
训练不稳定损失函数波动大或不收敛1. 使用WGAN或WGAN-GP
2. 梯度裁剪或梯度惩罚
3. 调整学习率
4. 使用适当的架构设计
梯度消失在训练初期,判别器可以轻易区分真假样本1. 使用WGAN
2. 使用非饱和生成器损失
3. 标签翻转(Label Flipping)
梯度爆炸网络权重更新过大1. 梯度裁剪
2. 权重归一化
3. 调整批量大小
4. 使用适当的初始化

12. GAN中的损失函数比较

下表比较了不同GAN变体中的损失函数:

在这里插入图片描述

13. 结论与实践建议

GAN是深度学习领域的一个重要创新,它为生成模型带来了革命性的变化。通过本节的学习,我们深入理解了GAN的数学基础、损失函数推导、改进方案以及实现技巧。

在实践中,我建议遵循以下原则:

  1. 从简单开始:先实现标准GAN,了解其基本原理和训练行为
  2. 选择合适的架构:根据任务选择适当的网络架构,DCGAN是一个良好的起点
  3. 使用改进的损失函数:考虑使用WGAN-GP等改进的损失函数提高训练稳定性
  4. 批量归一化应用:在生成器中恰当使用批量归一化,但注意输出层前不要使用
  5. 监控训练过程:定期生成样本并检查,及时调整超参数


清华大学全三版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。

怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/983472.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

系统架构设计师—系统架构设计篇—微服务架构

文章目录 概述优势挑战 概述 微服务是一种架构风格,将单体应用划分成一组小的服务,服务之间相互协作,实现业务功能,每个服务运营在独立的进程中,服务间采用轻量级的通信机制协作(通常是HTTP/JSON&#xff0…

静态时序分析STA——2. 数字单元库-(2)

参考文献 [1]Static Timing Analysis for Nanometer Designs A Practical Approach [2]静态时序分析圣经翻译计划 三.组合逻辑单元的时序模型 对于一个两输入与门的时序弧,两个时序弧均为正单边类型(positive_unate)。这意味着对于 NLDM 模型…

Mysql的卸载安装配置以及简单使用

MySQL其它问题已经更新在:MySQL完善配置---可视化-CSDN博客 一、卸载 ①控制面板卸载 ②C盘隐藏项目>ProgramData>mysql相关文件夹,还有Program file下的MySQL文件夹 ③开始菜单栏搜索>服务,找到MySQL相关服务删除,如果再…

第五课:Express框架与RESTful API设计:技术实践与探索

在使用Node.js进行企业应用开发,常用的开发框架Express,其中的中间件、路由配置与参数解析、RESTful API核心技术尤为重要,本文将深入探讨它们在应用开发中的具体使用方法,最后通过Postman来对开发的接口进行测试。 一、Express中…

基于Django的协同过滤算法养老新闻推荐系统的设计与实现

基于Django的协同过滤算法养老新闻推荐系统(可改成普通新闻推荐系统使用) 开发工具和实现技术 Pycharm,Python,Django框架,mysql8,navicat数据库管理工具,vue,spider爬虫&#xff0…

Facebook 的隐私保护数据存储方案研究

Facebook 的隐私保护数据存储方案研究 在这个信息爆炸的时代,数据隐私保护已成为公众关注的热点。Facebook,作为全球最大的社交媒体平台之一,承载着海量用户数据,其隐私保护措施和数据存储方案对于维护用户隐私至关重要。本文将深…

World of Warcraft [CLASSIC] BigFoot BiaoGe

World of Warcraft [CLASSIC] BigFoot BiaoGe 金团表格插件 设置60秒拍卖装备时间 ALT 鼠标左键,点击装备,弹出对话框,填写 1)拍卖时间默认60秒,起拍价, 2)点击【开始拍卖】 团队所有安装了…

Docker和DockerCompose基础教程及安装教程

Docker的应用场景 Web 应用的自动化打包和发布。自动化测试和持续集成、发布。在服务型环境中部署和调整数据库或其他的后台应用。从头编译或者扩展现有的 OpenShift 或 Cloud Foundry 平台来搭建自己的 PaaS 环境。 CentOS Docker 安装 使用官方安装脚本自动安装 安装命令…

题解:洛谷 AT_dp_c Vacation

题目https://www.luogu.com.cn/problem/AT_dp_c设 表示对于前 天&#xff0c;以 项目结尾能获得的最大价值。 则&#xff1a; 答案为&#xff1a;。 实现 #include<bits/stdc.h> using namespace std; #define int long long int n,dp[100005][3]; signed main(){i…

通义万相2.1:开启视频生成新时代

文章摘要&#xff1a;通义万相 2.1 是一款在人工智能视频生成领域具有里程碑意义的工具&#xff0c;它通过核心技术的升级和创新&#xff0c;为创作者提供了更强大、更智能的创作能力。本文详细介绍了通义万相 2.1 的背景、核心技术、功能特性、性能评测、用户反馈以及应用场景…

ubuntu 20.04 C++ 源码编译 cuda版本 opencv4.5.0

前提条件是安装好了cuda和cudnn 点击下载&#xff1a; opencv_contrib4.5.0 opencv 4.5.0 解压重命名后 进入opencv目录&#xff0c;创建build目录 “CUDA_ARCH_BIN ?” 这里要根据显卡查询一下,我的cuda是11&#xff0c;显卡1650&#xff0c;所以是7.5 查询方法1&#xff1…

【人工智能】Open WebUI+ollama+deepSeek-r1 本地部署大模型与知识库

目录 一 、命令行下载安装 二、运行 三、添加开机自启服务 ollama serve 四、重新加载配置、重启ollama server 五、查看模型文件信息 六、 添加open-webui 七、 配置open webui 八、创建自己知识库 九、网络加密优化 十、大工告成&#xff0c;大家如果有问题可以私信…

DeepSeek R1-7B 医疗大模型微调实战全流程分析(全码版)

DeepSeek R1-7B 医疗大模型微调实战全流程指南 目录 环境配置与硬件优化医疗数据工程微调策略详解训练监控与评估模型部署与安全持续优化与迭代多模态扩展伦理与合规体系故障排除与调试行业应用案例进阶调优技巧版本管理与迭代法律风险规避成本控制方案文档与知识传承1. 环境配…

Android Studio右上角Gradle 的Task展示不全

Android Studio 版本如下&#xff1a;Android Studio lguana|2023.21, 发现Gradle 的Tasks阉割严重&#xff0c;如下图&#xff0c;只显示一个other 解决方法如下&#xff1a;**Setting>Experimental>勾选Configure all gradle tasks during Gradle Sync(this can make…

[HTTP协议]应用层协议HTTP从入门到深刻理解并落地部署自己的云服务(2)

标题&#xff1a;[HTTP协议]应用层协议HTTP从入门到深刻理解并落地部署自己的云服务(2) 水墨不写bug 文章目录 一、无法拷贝类(class uncopyable)的设计解释&#xff1a;重要思想&#xff1a;使用示例 二、锁的RAII设计解释重要考虑使用示例 三、基于RAII模式和互斥锁的的日志…

YOLOv8改进SPFF-LSKA大核可分离核注意力机制

YOLOv8改进------------SPFF-LSKA 1、LSAK.py代码2、添加YAML文件yolov8_SPPF_LSKA.yaml3、添加SPPF_LSKA代码4、ultralytics/nn/modules/__init__.py注册模块5、ultralytics/nn/tasks.py注册模块6、导入yaml文件训练 1、LSAK.py代码 论文 代码 LSKA.py添加到ultralytics/nn/…

系统架构设计师—系统架构设计篇—特定领域软件体系结构

文章目录 概述领域分类垂直域水平域 系统模型基本活动参与角色 概述 特定领域软件架构&#xff08;Domain Specific Software Architecture&#xff0c;DSSA&#xff09;是在一个特定应用领域中&#xff0c;为一组应用提供组织结构参考的标准团建体系结构。 领域分类 垂直域…

第四次CCF-CSP认证(含C++源码)

第四次CCF-CSP认证 第一道&#xff08;easy&#xff09;思路及AC代码 第二道&#xff08;easy&#xff09;思路及AC代码遇到的问题 第三道&#xff08;mid&#xff09;思路及AC代码 第一道&#xff08;easy&#xff09; 题目链接 思路及AC代码 这题就是将这个矩阵旋转之后输出…

软考中级-数据库-3.3 数据结构-树

定义:树是n(n>=0)个结点的有限集合。当n=0时称为空树。在任一非空树中,有且仅有一个称为根的结点:其余结点可分为m(m>=0)个互不相交的有限集T1,T2,T3...,Tm…,其中每个集合又都是一棵树,并且称为根结点的子树。 树的相关概念 1、双亲、孩子和兄弟: 2、结点的度:一个结…

Django小白级开发入门

1、Django概述 Django是一个开放源代码的Web应用框架&#xff0c;由Python写成。采用了MTV的框架模式&#xff0c;即模型M&#xff0c;视图V和模版T。 Django 框架的核心组件有&#xff1a; 用于创建模型的对象关系映射为最终用户设计较好的管理界面URL 设计设计者友好的模板…