AIGC实战——WGAN(Wasserstein GAN)

AIGC实战——WGAN

    • 0. 前言
    • 1. WGAN-GP
      • 1.1 Wasserstein 损失
      • 1.2 Lipschitz 约束
      • 1.3 强制 Lipschitz 约束
      • 1.4 梯度惩罚损失
      • 1.5 训练 WGAN-GP
    • 2. GAN 与 WGAN-GP 的关键区别
    • 3. WGAN-GP 模型分析
    • 小结
    • 系列链接

0. 前言

原始的生成对抗网络 (Generative Adversarial Network, GAN) 在训练过程中面临着模式坍塌和梯度消失等问题,为了解决这些问题,研究人员提出了大量的关键技术以提高GAN模型的整体稳定性,并降低了上述问题出现的可能性。例如 WGAN (Wasserstein GAN) 和 WGAN-GP (Wasserstein GAN-Gradient Penalty) 等,通过对原始生成对抗网络 (Generative Adversarial Network, GAN) 框架进行了细微调整,就能够训练复杂GAN。在本节中,我们将学习 WGANWGAN-GP,两者都对原始 GAN 框架进行了细微调整,以改善图像生成过程的稳定性和质量。

1. WGAN-GP

WGAN (Wasserstein GAN) 是提高 GAN 训练稳定性方面的一次巨大进步,在经过一些简单改动后 GAN 就能够实现以下两个特点:

  • 与生成器的收敛度和生成样本质量相关的损失度量
  • 优化过程的稳定性得到提高

具体来说,WGAN 针对判别器和生成器提出了一种新的损失函数 (Wasserstein Loss),用这种损失函数代替二元交叉熵就可以让 GAN 的收敛更加稳定。
在本节中,我们将构建一个 WGAN-GP (Wasserstein GAN-Gradient Penalty),利用 CelebA 数据集训练模型以生成人脸图像。

1.1 Wasserstein 损失

首先我们来回顾一下二元交叉嫡, 在训练 DCGAN 判别器和生成器时采用了这种损失函数:
− 1 n ∑ i = 1 n ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) -\frac 1 n \sum_{i=1}^n(y_ilog(p_i)+(1-y_i)log(1-p_i)) n1i=1n(yilog(pi)+(1yi)log(1pi))
为了训练 GAN 的判别器 D,我们根据以下两者计算损失:真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,以及生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi))与标签 y i = 0 y_i=0 yi=0 之间的误差。因此,对于 GAN 的判别器来说,损失函数最小化的过程可以表示为:
min ⁡ D − ( E x ∼ p X [ log ⁡ D ( x ) ] + E z ∼ p Z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ) \mathop {\min} \limits_{D}-(\mathbb E_{x\sim p_X}[\log D(x)]+\mathbb E_{z\sim p_Z}[\log (1-D(G(z)))]) Dmin(ExpX[logD(x)]+EzpZ[log(1D(G(z)))])
为了训练 GAN 的生成器 G,我们根据生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 的误差计算损失。因此,对于 GAN 的生成器来说,将损失函数最小化的过程可以表示为:
min ⁡ G − ( E z ∼ p Z [ log ⁡ D ( G ( z ) ) ] ) \mathop {\min}\limits_{G}-(\mathbb E_{z\sim p_Z}[\log D(G(z))]) Gmin(EzpZ[logD(G(z))])
接下来,我们比较上述损失函数与 Wasserstein 损失函数。
Wasserstein 损失 (Wasserstein Loss) 是用于 Wasserstein GAN (WGAN) 的一种损失函数。与传统的二元交叉熵损失函数不同,Wasserstein 损失引入了标签 1-1,将判别器的输出从概率值转变为分数 (score),因此,WGAN 的判别器通常也被称为评论家 (critic),并要求判别器是 1-Lipschitz 连续函数。
具体来说,Wasserstein 损失使用标签 y i = 1 y_i=1 yi=1 y i = − 1 y_i=-1 yi=1 代替 y i = 1 y_i=1 yi=1 y i = 0 y_i=0 yi=0,同时还需要移除判别器最后一层的 Sigmoid激活函数,如此一来预测结果 p i p_i pi 就不一定在 [ 0 , 1 ] [0,1] [0,1] 范围内了,它可以是 [ − ∞ , ∞ ] [-∞,∞] [,] 范围内的任何值。Wasserstein 损失的定义如下:
− 1 n ∑ i = 1 n ( y i p i ) -\frac 1 n∑_{i=1}^n(y_ip_i) n1i=1n(yipi)
在训练 WGAN 的判别器 D 时,我们将计算以下损失:判别器对真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = − 1 y_i=-1 yi=1 之间的误差。因此,对于 WGAN 判别器,最小化损失函数的过程可以表示为:
min ⁡ D − ( E x ∼ p X [ D ( x ) ] − E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ D - (\mathbb E_{x\sim p_X}[D(x)] - \mathbb E_{z\sim p_Z}[D(G(z))]) Dmin(ExpX[D(x)]EzpZ[D(G(z))])
换句话说,WGAN 判别器试图最大化其对真实图像的预测和生成图像的预测之间的差异,且真实图像的得分更高。
而对于 WGAN 生成器 G 的训练,我们根据判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 计算损失。因此,对于 WGAN 生成器,最小化损失函数可以表示为:
min ⁡ G − ( E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ G - (\mathbb E_{z\sim p_Z}[D(G(z))]) Gmin(EzpZ[D(G(z))])
换句话说,WGAN 生成器试图生成被判别器以极高分数判定为真实图像的图像(即,令判别器认为它们是真实的)。

1.2 Lipschitz 约束

由于我们允许判别器输出 [ − ∞ , ∞ ] [-∞,∞] [,] 范围内的任意值,而不是按照 Sigmoid 函数那样将输出限制在 [ 0 , 1 ] [0,1] [0,1] 范围内,因此 Wasserstein 损失可能会非常大。因此,为了使 Wasserstein 损失函数正常工作,需要对判别器进行额外约束,即 1-Lipschitz 连续性约束。判别器是一个将图像转换为预测的函数 D,如果对于任意两个输人图像 x 1 x_1 x1 x 2 x_2 x2,判别器函数 D 满足以下不等式,则该函数为 1-Lipschitz 连续:
∣ D ( x 1 ) − D ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ 1 \frac {|D(x_1) - D(x_2)|}{|x_1 - x_2|} ≤ 1 x1x2D(x1)D(x2)1
其中, ∣ x 1 − x 2 ∣ |x_1 - x_2| x1x2 表示两个图像的平均像素之差的绝对值, ∣ D ( x 1 ) − D ( x 2 ) ∣ |D(x_1) - D(x_2)| D(x1)D(x2) 表示判别器预测之间的绝对值。这意味着判别器的预测变化速率在任何情况下都是有界的(即梯度的绝对值不能大于 1)。可以在下图中的 Lipschitz 连续的一维函数中看到,无论将圆锥放在任何位置,曲线都不会进入圆锥内部。换句话说,曲线上任何一点的上升或下降速度都是有限的。

Lipschitz 连续

1.3 强制 Lipschitz 约束

在原始的 WGAN 论文中,作者通过在每个训练结束后将判别器的权重裁剪到一个较小范围内 [ − 0.01 , 0.01 ] [-0.01, 0.01] [0.01,0.01] 来强制执行 Lipschitz 约束。
由于我们裁剪了判别器的权重,判别器的学习能力大大降低,因此,事实上,权重裁剪并不是一种理想的强制 Lipschitz 约束的方式。一个强大的判别器对于 WGAN 的成功至关重要,因为如果没有准确的梯度,生成器无法学习如何调整其权重以产生更好的样本。
因此,研究人员提出了许多其他方法来强制执行 Lipschitz 约束,并提高 WGAN 学习复杂特征的能力。其中一种方法是带有梯度惩罚 (Gradient Penalty) 的 Wasserstein GAN
通过在判别器的损失函数中包含一个梯度惩罚项来直接强制执行 Lipschitz 约束,如果梯度范数偏离 1 时,该项会惩罚模型,从而使训练过程更加稳定。
接下来,将这个额外的梯度惩罚项加入到判别器损失函数中。

1.4 梯度惩罚损失

下图展示了 WGAN-GP 判别器的训练过程,与原始判别器的训练过程进行比较,我们可以看到关键的改进是将梯度惩罚损失作为整体损失函数的一部分,并与来自真实图像和生成图像的 Wasserstein 损失一起使用。

WGAN-GP

梯度惩罚损失衡量了预测关于输入图像的梯度范数与 1 之间的平方差。模型倾向于找到能够使梯度惩罚项最小化的权重,从而鼓励模型符合 Lipschitz 约束。
在训练过程中,每一处的计算梯度是非常困难的,因此WGAN-GP 只在少数几个点处评估梯度。为了确保平衡的,我们使用一组插值图像,在真实图像与伪造图像之间的随机位置逐像素进行插值 (Interpolation) 以生成一些图像。

插值图像

使用 Keras 计算梯度惩罚项:

    def gradient_penalty(self, batch_size, real_images, fake_images):
        # 批数据中的每个图像都会得到一个 0~1 之间的随机数字,存储到向量 alpha 中
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        # 计算一组插值图像
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 使用判别器对每个插值图像进行评分
            pred = self.critic(interpolated, training=True)
        # 计算插值图像 (y_pred) 的预测对于输入 interpolated_samples) 的梯度
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 计算这个向量的 L2 范数(即欧几里得长度)
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        # 函数返回 L2 范数与 1 之差的平方的均值
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

1.5 训练 WGAN-GP

使用 Wasserstein 损失函数的一个优点是,不再需要担心平衡判别器和生成器的训练。事实上,在使用 Wasserstein 损失时,必须在更新生成器之前将判别器训练到收敛,以确保生成器更新的梯度准确无误。这与标准 GAN 相反,标准 GAN 中重要的是不要让判别器变得过强。
因此,使用 Wasserstein GAN,我们可以简单地在生成器更新之间多次训练判别器,以确保它接近收敛。通常每次生成器更新一次,判别器更新三到五次。
了解了 WGAN-GP 的两个关键概念 (Wasserstein 损失和梯度惩罚项)后,使用 Keras 实现 WGAN-GP

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        # 对判别器进行三次更新
        for i in range(self.critic_steps):
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )

            with tf.GradientTape() as tape:
                fake_images = self.generator(
                    random_latent_vectors, training=True
                )
                fake_predictions = self.critic(fake_images, training=True)
                real_predictions = self.critic(real_images, training=True)
                # 计算判别器的 Wasserstein 损失
                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                # 计算梯度惩罚项
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # 判别器损失函数是 Wasserstein 损失和梯度惩罚的加权和
                c_loss = c_wass_loss + c_gp * self.gp_weight
            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            # 更新判别器的权重
            self.c_optimizer.apply_gradients(
                zip(c_gradient, self.critic.trainable_variables)
            )
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_predictions = self.critic(fake_images, training=True)
            # 计算生成器的 Wasserstein 损失
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # 更新生成器的权重
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)
        return {m.name: m.result() for m in self.metrics}

在训练 WGAN-GP 之前,需要注意的最后一点是判别器不应该使用批量归一化。这是因为批归一化会在同一批图像之间创建相关性,从而使梯度惩罚损失的效果降低。实验证明,即使在判别器中没有批归一化, WGAN-GP 仍然可以输出出色的结果。

2. GAN 与 WGAN-GP 的关键区别

总而言之,标准 GANWGAN-GP 之间存在以下:

  • WGAN-GP 使用 Wasserstein 损失
  • WGAN-GP 使用 1 表示真实图像标签,使用 -1 表示伪造图像的标签
  • 判别器的最后一层没有使用 sigmoid 激活
  • 在判别器的损失函数中包含梯度惩罚项
  • 每训练一次生成器更新权重,需要多次训练判别器
  • 判别器中没有批归一化层

3. WGAN-GP 模型分析

训练 25epoch 后,WGAN-GP 模型的生成器能够生成合理图像:

面部生成结果

该模型已经学习到了面部的重要高级特征,且没有出现模式坍塌的迹象。
如果我们将 WGAN-GP 的输出与变分自编码器 (Variational Autoencoder, VAE) 的输出进行比较,可以看到 WGAN-GP 生成的图像通常更清晰。总的来说,VAE 倾向于产生颜色边界模糊的图像,而 GAN 产生的图像更加清晰合理。GAN 通常比 VAE 更难训练,需要更长的时间才能获得满意的数据质量。

小结

在本节中,我们学习了如何使用 Wasserstein 损失函数以解决经典 GAN 训练过程中的模式坍塌和梯度消失等问题,使得 GAN 的训练更加可预测和可靠。WGAN-GP 通过在损失函数中添加一个令梯度范数指向 1 的项,为训练过程施加 1-Lipschitz 约束。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/235419.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch深度学习实战——人群计数

PyTorch深度学习实战——人群计数 0. 前言1. 人群计数1.1 基本概念1.2 CRSNet 架构 2. 使用 CSRNet 实现人群计数2.1 模型分析2.2 数据集分析2.3 模型构建与训练 相关链接 0. 前言 人群计数是指通过图像或视频分析技术,对给定场景中的人群数量进行估计和统计的过程…

【MATLAB】REMD信号分解+FFT+HHT组合算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 TVFEMDFFTHHT组合算法是一种结合了总体变分模态分解(TVFEMD)、傅里叶变换(FFT)和希尔伯特-黄变换(HHT)的信号分解方…

多维时序 | MATLAB实现RIME-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测

多维时序 | MATLAB实现RIME-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测 目录 多维时序 | MATLAB实现RIME-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 MATLAB实现RIME-CNN-…

微信小程序:用map()将对象数组中的某一项组合成新数组

使用分析 使用map()方法来遍历 info 数组中的每个元素,并整合每一个对象中的某一项进行新数组的重组 效果展示 这里是查询对象数组中的全部name值 原始数据 提取出name的数组 核心代码 var infos items.map(item > item.name); 完整代码(用微信小程…

基于hadoop下的spark安装

目录 简介 安装准备 spark安装 配置文件配置 简介 Spark主要⽤于⼤数据的并⾏计算,⽽Hadoop在企业主要⽤于⼤数据的存储(⽐如HDFS、Hive和HBase 等),以及资源调度(Yarn)。但是也有很多公司也在使⽤MR2进…

Web server failed to start. Port 8888 was already in use.

端口占用 强制终止占用端口的进程 获取占用端口的进程ID(PID):在终端或命令提示符中运行以下命令以查找占用端口的进程ID: ①在 Unix/Linux/Mac 上:lsof -i :8888 ②在 Windows 上:netstat -ano | findstr …

HTML面试题---专题二

文章目录 一、前言二、解释input标签中占位符属性的用途三、如何在 HTML 中设置复选框或单选按钮的默认选中状态?四、表单输入字段中必填属性的用途是什么?五、如何使用 HTML 创建表格?六、解释a标签中目标属性的用途七、如何创建一个点击后会…

Java飞翔的小鸟

一、项目分析 创建一个窗口和画板,把画板放到窗口上,在画板上绘画图片 (2)让小鸟在画面中动起来,可以上下飞 (3)让地面和管道动起来 (4)碰撞检测 (5&#xf…

Nginx 优化与防盗链

目录 配置Nginx隐藏版本号 Nginx隐藏版本号的方法 修改配置文件法 修改源码法 修改用户与组 设置缓存时间 日志切割 连接超时 更改进程数 配置网页压缩 配置防盗链 fpm参数优化 总结:nginx优化 配置Nginx隐藏版本号 可以使用 Fiddler 工具抓取数据包&…

【Citespace】从Citespace开始的引文可视化分析

CiteSpace 译“引文空间”,是一款着眼于分析科学分析中蕴含的潜在知识,是在科学计量学、数据可视化背景下逐渐发展起来的引文可视化分析软件。由于是通过可视化的手段来呈现科学知识的结构、规律和分布情况,因此也将通过此类方法分析得到的可…

巧用ChatGPT高效搞定Excel数据分析【文末送书-04】

文章目录 一.巧用ChatGPT高效搞定Excel数据分析1. ChatGPT简介2. 安装所需工具2.1 Python2.2 OpenAI GPT库 3. 与ChatGPT交互进行数据分析4. 利用ChatGPT进行筛选和排序5. ChatGPT的局限性和注意事项6. ChatGPT与数据可视化7. ChatGPT与进阶数据分析任务 二. 结论&文末福利…

Windows安装Maven

一、Maven 是什么? Maven 是一个项目管理和整合工具。Maven 为开发者提供了一套完整的构建生命周期框架。开发团队几乎不用花多少时间就能够自动完成工程的基础构建配置,因为 Maven 使用了一个标准的目录结构和一个默认的构建生命周期。 在有多个开发团…

软件开发安全指南

2.1.应用系统架构安全设计要求 2.2.应用系统软件功能安全设计要求 2.3.应用系统存储安全设计要求 2.4.应用系统通讯安全设计要求 2.5.应用系统数据库安全设计要求 2.6.应用系统数据安全设计要求 软件开发全资料获取:点我获取

用Java实现根据数据库中的数量,生成年月份+序号递增

在日常开发中,经常会遇到根据年月日和第几号文件生成对应的编号,今天给大家提供一个简单的工具类 public static final Long CODE1L;/*** param select 数据库中数据总数* return*/public static String SubjectNo(Long select){// 在总数的基础上1&…

c2-C语言--指针

1.用一级指针遍历一维数组 结论 buf[i]<>*(buf i) <> *(p i)<> p[i] #include <stdio.h>int main(){int buf[5] {10,20 ,30 ,40,50}; //buf[0] --- int // buf --&buf[0] ----int *int *p buf;//&buf[0] --- &*(buf0)printf(&quo…

统一存储、全闪阵列、分布式NAS,企业级存储概述

Infortrend普安科技即将迎来公司成立30周年华诞。Infortrend普安科技从无到有&#xff0c;由小做强&#xff0c;为全球用户提供高性能、高可靠、高扩展、环保节能的存储解决方案&#xff0c;在存储领域造就了一段品牌佳话。从1993年成立伊始&#xff0c;Infortrend一直致力于企…

云服务器哪家便宜?亚马逊云科技按需选实例够便宜

随着云计算的迅猛发展&#xff0c;越来越多的企业和个人开始关注云服务器的选择。在众多云服务提供商中&#xff0c;亚马逊云科技&#xff08;Amazon Web Services&#xff0c;AWS&#xff09;凭借其强大的基础设施和丰富的服务&#xff0c;备受业界青睐。本文聚焦一个备受关注…

【lesson3】数据库表的操作

文章目录 创建修改修改表名增加表类型修改表的某一类型的类型修改表某一类型的类型名 删除删除表的某一列删除表 查看查看表信息查看表内容 创建 建表指令&#xff1a; 查看是否建表成功&#xff1a; 查看表的具体信息&#xff1a; 修改 修改表名 法一&#xff1a;修改…

基础宠物商店管理系统(Java)大一程序设计

一.开发环境 Windows 11 -- JDK 21 -- IDEA 2021.3.3 二.需求 三.代码部分 //创建一个宠物类&#xff0c;被另外两类继承public class Pet {private String name;private int age;private String gender;private double cost0;//买进价格private double sellprice0;//卖出价…

hdlbits系列verilog解答(mt2015_q4a)-52

文章目录 一、问题描述二、verilog源码三、仿真结果 一、问题描述 本次我们实现一个简单的组合逻辑输出。 z (x^y) & x 模块声明&#xff1a; module top_module (input x, input y, output z); 二、verilog源码 module top_module (input x, input y, output z);assig…