使用pytorch构建带梯度惩罚的Wasserstein GAN(WGAN-GP)网络模型

本文为此系列的第三篇WGAN-GP,上一篇为DCGAN。文中仍然不会过多详细的讲解之前写过的,只会写WGAN-GP相对于之前版本的改进点,若有不懂的可以重点看第一篇比较详细。

原理

具有梯度惩罚的 Wasserstein GAN (WGAN-GP)可以解决 GAN 的一些稳定性问题。 具体来说,使用W-loss 作为损失函数替代传统的 BCE 等 loss,并使用梯度惩罚来防止 mode collapse。

  • WGAN-GP 使用了 Wasserstein distance(也成为Earth Mover’s distance, EMD)作为训练 GAN 模型的目标函数,Wasserstein distance is a function of amount and distance,体现的是生成的数据的分布移动到真实数据的分布之间所需的距离与量。
    在这里插入图片描述
    随着判别器训练的越来越好,使用 BCE loss 的话会让鉴别器给出接近于 0 或者接近于 1 的极端值,如下为 sigmoid 曲线,极端值的梯度无限接近于 0,这样判别器就没有太多有用的信息反馈给生成器让它学习,导致梯度消失或 model collapse。使用距离的方式可以有效解决,分布距离再远都不再限制。
    在这里插入图片描述
    在这里插入图片描述
  • BCE loss 本质是一个 minimax game, d 即 discriminator 希望尽可能的 minimize,g 即 generator 希望尽可能的 maximize(意味着造出来的假东西对于鉴别器来说看起来很真实),可以进行如下的简化:
    在这里插入图片描述
    基于 Wasserstein distance 的 W-loss 的的式子与其简化版进行对比:
    在这里插入图片描述
    在 Wasserstein GAN 中不再是 discriminator 了,因为输出不再是 0-1 之间来进行分类,既然不分类了就不是 discriminator 了,而是 critic,所以这里使用 c 代表 critic。critic 希望其尽可能的 maximize,因为希望让 real 和 feak 的距离尽可能的大,起到划清界限的目的;generator 希望其尽可能的minimize,减小两者之间的距离,达到以假乱真的目的。
  • mode collapse 即模式崩溃,当生成器学会从单个类生成特征来欺骗鉴别器时,就会发生 mode collapse(陷入一种模式出不来),跟 cnn 的局部最优是一个概念。这会导致输出出现重复,缺乏多样性和细节。

但在使用 W-loss 训练 GAN 时需要对 critic 有一定的条件 —— critic 需要 1-L(1-Lipschitz)连续:
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ k ∣ x 1 − x 2   ∣ |f(x_1)-f(x_2)|\le k|x_1-x_2\ | f(x1)f(x2)kx1x2 
这里的 k = 1,也就是 critic 的 nn 函数曲线的梯度(斜率)始终在 -1 到 1 之间,即梯度的 L2 范数不超过1:

在这里插入图片描述
如图:
在这里插入图片描述
曲线的每个点的斜率都是在绿色区域内,很显然这个曲线并不符合。像如下这个曲线就是符合的:
在这里插入图片描述
达到 1-L 连续有两种方法:weigh clipping、gradient penalty。

  • weigh clipping 将权重裁剪到固定范围内,从而限制 critic 的学习能力。但是这样有缺点,可能让所有参数走极端,要么取最大值要么取最小值,critic 会非常倾向于学习一个简单的映射函数。
  • gradient penalty 则是添加一个正则项在 loss function 中,相比 weigh clipping 更加柔和对critic参数的限制更加灵活,通常不会导致梯度消失或梯度爆炸问题。
    在这里插入图片描述
    这里的 λ \lambda λ 为超参值,reg 等于 critic 神经网络梯度范数 -1 的平方,即:
    在这里插入图片描述
    当 critic 神经网络梯度范数 >1 时正则化项发挥作用。平方的作用是为了让其偏离越大,惩罚越大。
    这里的 x ^ \hat{x} x^ 为真实数据与生成数据之间随机取样得到的中间数据,随机值 ϵ \epsilon ϵ 作为权重值,假设 ϵ \epsilon ϵ 为0.3,那么 1- ϵ \epsilon ϵ 为0.7。
    在这里插入图片描述

代码

model.py

from torch import nn

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

class Critic(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Critic, self).__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)

train.py

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

def get_gradient(crit, real, fake, epsilon):
    # Mix the images together
    mixed_images = real * epsilon + fake * (1 - epsilon)

    # Calculate the critic's scores on the mixed images
    mixed_scores = crit(mixed_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        # These other parameters have to do with the pytorch autograd engine works
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

def gradient_penalty(gradient):
    # Flatten the gradients so that each row captures one image
    gradient = gradient.view(len(gradient), -1)

    # Calculate the magnitude of every row
    gradient_norm = gradient.norm(2, dim=1)

    # Penalize the mean squared distance of the gradient norms from 1
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

def get_gen_loss(crit_fake_pred):
    gen_loss = -1. * torch.mean(crit_fake_pred)
    return gen_loss

def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
    return crit_loss

cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            crit_fake_pred = crit(fake.detach())
            crit_real_pred = crit(real)

            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake.detach(), epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        gen_opt.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1

在这里插入图片描述

代码讲解

网络模型与上一篇的DCGAN没有变动。
在这里插入图片描述
这个模块进行梯度计算,即上文原理中正则项公式里面的梯度L2范数里的梯度。首先计算真实数据与生成数据之间随机取样的混合数据,然后输入 critic,最后计算出其梯度。
在这里插入图片描述
梯度惩罚模块,即上文原理中的整个正则项公式,梯度范数 -1 的平方。
在这里插入图片描述
critic 的 loss function 公式如下,generator 因为和真实数据无关,且与正则项也无关,所以只有中间一项。
在这里插入图片描述————————————————————————————————————————————

总之,WGAN-GP 不一定要提高 GAN 的整体性能,但会很好的提高稳定性并避免模式崩溃。

下一篇条件生成GAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/512051.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【WEEK6】 【DAY2】DQL查询数据-第二部分【中文版】

2024.4.2 Tuesday 接上文【WEEK6】 【DAY1】DQL查询数据-第一部分【中文版】 目录 4.4.连接查询4.4.1.JOIN 对比4.4.2.七种JOIN4.4.3.例4.4.3.1.本例中INNER JOIN和RIGHT JOIN结果相同4.4.3.2.LEFT JOIN4.4.3.3.查询缺考的同学4.4.3.4.思考题:查询参加了考试的同学信…

Visual Studio安装下载进度为零已解决

因为在安装pytorch3d0.3.0时遇到问题,提示没有cl.exe,VS的C编译组件,可以添加组件也可以重装VS。查了下2019版比2022问题少,选择了安装2019版,下面是下载安装时遇到的问题记录,关于下载进度为零网上有三类解…

redis的哈希Hash

哈希是一个字符类型的字段和值的映射表,简单来说就是一个键值对的集合。 查看里面的name或者age在不在里面,0说明已经删了的 用来获取person里的键

[C#]使用OpencvSharp去除面积较小的连通域

【C介绍】 关于opencv实现有比较好的算法,可以参考这个博客OpenCV去除面积较小的连通域_c#opencv 筛选小面积区域-CSDN博客 但是没有对应opencvsharp实现同类算法,为了照顾懂C#编程同学们,因此将 去除面积较小的连通域算法转成C#代码。 方…

Java获取IP地址以及MAC地址(附Demo)

目录 前言1. IP及MAC2. 特定适配器 前言 需要获取客户端的IP地址以及MAC地址 import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class test {public static void main(String[] args) {try {// 执行命令Process process…

Nginx在Kubernetes集群中的进阶应用

简介 在现代DevOps环境中,Nginx作为负载均衡器与Kubernetes的Ingress资源的结合,为应用程序提供了强大的路由和安全解决方案。本文将深入探讨如何利用Nginx的灵活性和功能,实现高效、安全的外部访问控制,以及如何配置Ingress以优…

智能小车测速(3.26)

模块介绍: 接线: VCC -- 3.3V 不能接5V,否则遮挡一次会触发3次中断 OUT -- PB14 测速原理: cubeMX设置: PB14设置为gpio中断 打开定时器2,时钟来源设置为内部时钟,设置溢出时间1s&#xff0c…

上位机图像处理和嵌入式模块部署(qmacvisual图像清晰度)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 做过isp的同学都知道,图像处理里面有一个3a,即自动曝光、自动白平衡和自动对焦。其中自动对焦这个,就需要用输入…

qt通过setProperty设置样式表笔记

在一个pushbutton里面嵌套两个label即可,左侧放置图片label,右侧放置文字label,就如上图所示; 但是这时的hover,press的伪状态是没有办法“传递”给里面的控件的,对btn的伪状态样式表的设置,是不…

IP SSL的应用与安装

IP SSL,即互联网协议安全套接字层,它是一种为网络通信提供安全及数据完整性的安全协议。在网络传输过程中,IP SSL可以对数据进行加密,这样即便数据在传输途中被截取,没有相应的解密密钥也无法解读内容。这一过程如同将…

防抖节流面试

1、防抖 1.1、条件 1、高频 2、耗时(比如console不算) 3、以最后一次调用为准 刷到个神评论,回城是防抖,技能cd是节流 1.2、手写 传参版本 function debounce(fn,delay){let timerreturn function(...args){//返回函数必须是普…

动态规划详解(Dynamic Programming)

目录 引入什么是动态规划?动态规划的特点解题办法解题套路框架举例说明斐波那契数列题目描述解题思路方式一:暴力求解思考 方式二:带备忘录的递归解法方式三:动态规划 推荐练手题目 引入 动态规划问题(Dynamic Progra…

QT子窗口关闭时自动释放及注意事项

先说方法,很简单,有如下API函数可用: testDialog->setAttribute( Qt::WA_DeleteOnClose, true ); 他的官方解释如下: 最后,说一个注意事项: 最近写python程序比较多,回过头来&a…

OPPO VPC 实践探索

01 概述 一年前(20年6月),OPPO云网络技术底座开始支持VPC方案,解决了用户担心的云上安全和虚拟实例的性能问题。我们称这个版本为VPC1.0,其采用了先进的智能网卡加速和VXLAN隧道隔离技术,实现了VPC从无到有的突破。 然而由于业务快…

爬虫部署平台crawlab使用说明

Crawlab 是一个基于 Go 语言的分布式网络爬虫管理平台,它支持 Python、Node.js、Jar、EXE 等多种类型的爬虫。 Crawlab 提供了一个可视化的界面,并且可以通过简单的配置来管理和监控爬虫程序。 以下是 Crawlab 的一些主要优点: 集中管理&am…

绿联 安装Mysql数据库

绿联 安装Mysql数据库 1、镜像 mysql:5.7 数据库5.7.x系列。 mysql:8 数据库8.x.x系列,安装方式相同。 2、安装 2.1、拉取镜像 拉取5.7.x版本的镜像。 2.2、基础设置 重启策略:第三或第四项均可。 2.3、网络 桥接即可。 2.4、命令 在原有的“mys…

概率论基础——拉格朗日乘数法

概率论基础——拉格朗日乘数法 概率论是机器学习和优化领域的重要基础之一,而拉格朗日乘数法与KKT条件是解决优化问题中约束条件的重要工具。本文将简单介绍拉格朗日乘数法的基本概念、应用以及如何用Python实现算法。 1. 基本概念 拉格朗日乘数法是一种用来求解…

EPSON机器人仿真实战攻略:从设置通信到运行调试一网打尽!

EPSON机器人 仿真测试深度教程 机器人还没到,怎么提前验证写好得机器人程序? 强大的仿真功能来了!本文详细深入的介绍了仿真的功能,一步步教会你如何仿真! 请先关注公众号收藏,防止走丢! 需要先设置电脑与控制器通信的虚拟连接,设置-电脑与控制器通信-增加-选择连接…

第27篇:T触发器实现4位计数器

Q:本篇我们用T触发器实现时序逻辑电路--计数器。 A:T触发器(Toggle Flip-Flop)只有一个信号输入端,在时钟有效边沿到来时,输入有效信号则触发器翻转,否则触发器保持不变,因此T触发器…

C++之结构体初始化10种写法总结(二百六十六)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…