揭示 Wasserstein 生成对抗网络的潜力:生成建模的新范式

导 读

Wasserstein 生成对抗网络 (WGAN) 作为一项关键创新而出现,解决了经常困扰传统生成对抗网络 (GAN) 的稳定性和收敛性的基本挑战。

由 Arjovsky 等人于2017 年提出,WGAN 通过利用 Wasserstein 距离彻底改变了生成模型的训练,提供了一个强大的框架,可以提高生成样本的质量和多样性。

本文深入探讨了 WGAN 的概念基础、优势和实际意义,说明了它们在更广泛的生成建模背景下的重要性。

有需要的朋友关注公众号【小Z的科研日常】,获取更多内容

01、WGAN的概念框架

WGAN 与其前辈的区别在于用 Wasserstein 距离代替 Jensen-Shannon 散度作为其损失函数。

瓦瑟斯坦距离,直观地理解为推土机距离,量化了将一种概率分布转换为另一种概率分布所需的最小成本。

该指标赋予 WGAN 在训练过程中更平滑、更可靠的梯度信号,即使在真实数据分布和生成数据分布不重叠的情况下,也有助于生成更高质量的样本。

与传统 GAN 的一个重要区别是取代了判别器。与将输入分类为真实或虚假的判别器不同,WGAN 框架中的批评者评估真实样本和生成样本的分布之间的 Wasserstein 距离。

这种从分类到估计的转变标志着生成模型处理学习过程的方式发生了根本性变化,从而实现了更细致、更有效的训练动态。

02、相比于传统GAN的优势与挑战

WGAN 提供了几个引人注目的优势,可以解决传统 GAN 框架的局限性。

首先,它们表现出改进的训练稳定性,降低了对超参数设置和架构选择的敏感性。这种稳定性源于 Wasserstein 距离的特性,即使真实分布和生成分布之间没有重叠,它也能提供有用的梯度信息——这是一个可能阻碍传统 GAN 训练的常见问题。

此外,WGAN 还缓解了模式崩溃问题,即生成器学习产生有限范围的输出,从而无法捕获真实数据分布的多样性的现象。Wasserstein 距离的连续且更有意义的损失景观鼓励生成器探索更广泛的输出,从而增强生成样本的多样性。

WGAN 中损失度量的可解释性也代表了重大进步。与传统 GAN(判别器的准确性不一定与生成样本的质量相关)不同,WGAN 中的批评者损失提供了更直接的收敛性衡量标准,为训练过程和生成数据的质量提供了有价值的见解。

尽管有其优点,WGAN 也带来了新的挑战,主要与计算效率有关。WGAN 的最初实现需要权重裁剪来强制执行 Lipschitz 约束,这对于 Wasserstein 距离的理论属性至关重要。

然而,权重裁剪可能会导致优化困难和容量利用率不足。为了解决这个问题,引入带有梯度惩罚的 WGAN (WGAN-GP) 提出了一种替代方法来强制实施 Lipschitz 约束,而无需进行权重裁剪,从而提高训练稳定性和模型性能。

03、代码

为 Wasserstein 生成对抗网络 (WGAN) 创建完整的代码示例涉及几个步骤,包括定义生成器和批评者的模型架构、准备合成数据集、训练模型以及通过指标和图评估性能。

此示例将说明使用 TensorFlow 和 Keras 的基本实现,并使用简单的合成数据集以便于理解。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

def build_critic():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ])
    return model

def build_generator(latent_dim):
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding='same', activation='sigmoid'),
    ])
    return model

class WGAN(keras.Model):
    def __init__(self, critic, generator, latent_dim):
        super(WGAN, self).__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.critic_loss_tracker = keras.metrics.Mean(name="critic_loss")
        self.generator_loss_tracker = keras.metrics.Mean(name="generator_loss")

    @property
    def metrics(self):
        return [self.critic_loss_tracker, self.generator_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

       # 在潜在空间中随机取样
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # 将它们解码为假图像
        generated_images = self.generator(random_latent_vectors)

        # 将它们与真实图像相结合
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # 组合标签,辨别真假图像
        labels = tf.concat(
            [tf.ones((batch_size, 1)), -tf.ones((batch_size, 1))], axis=0
        )
        # 在标签中添加随机噪音--重要技巧!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # 训练批评家
        with tf.GradientTape() as tape:
            predictions = self.critic(combined_images)
            d_loss = self.d_loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.critic.trainable_variables)
        self.d_optimizer.apply_gradients(
            zip(grads, self.critic.trainable_variables)
        )

       # 在潜在空间中随机取样
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # 组装 "所有真实图像 "的标签
        misleading_labels = -tf.ones((batch_size, 1))

        # 训练生成器(通过评论家模型)
        with tf.GradientTape() as tape:
            predictions = self.critic(self.generator(random_latent_vectors))
            g_loss = self.g_loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_variables)
        )

       # 更新指标
        self.critic_loss_tracker.update_state(d_loss)
        self.generator_loss_tracker.update_state(g_loss)

        return {
            "critic_loss": self.critic_loss_tracker.result(),
            "generator_loss": self.generator_loss_tracker.result(),
        }

latent_dim = 128

# 准备数据集
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# 实例化批评者和生成器模型
critic = build_critic()
generator = build_generator(latent_dim)

# 实例化 WGAN 模型
wgan = WGAN(critic=critic, generator=generator, latent_dim=latent_dim)

# 编译 WGAN 模型
wgan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    d_loss_fn=keras.losses.MeanSquaredError(),
    g_loss_fn=keras.losses.MeanSquaredError(),
)

wgan.fit(x_train, batch_size=32, epochs=100)

def generate_and_save_images(model, epoch, test_input):
    predictions = model.generator(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

# 生成潜在点
random_latent_vectors = tf.random.normal(shape=(16, latent_dim))
generate_and_save_images(wgan, 0, random_latent_vectors)
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5405 - generator_loss: 2.4530
Epoch 99/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5408 - generator_loss: 2.4463
Epoch 100/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5384 - generator_loss: 2.4411

此代码提供了使用简单数据集通过 TensorFlow 和 Keras 实现 WGAN 的基础框架。对于实际应用程序,您可能需要调整数据集、架构和训练参数以满足您的特定需求。

04、结论

Wasserstein 生成对抗网络代表了生成建模领域的重大飞跃。通过将 Wasserstein 距离集成到 GAN 框架中,WGAN 为训练生成模型提供了更稳定、可靠和可解释的方法。

尽管存在与计算需求和 Lipschitz 约束的执行相关的挑战,但 WGAN 及其后续迭代(如 WGAN-GP)所带来的进步继续影响着生成模型的发展。

随着该领域研究的进展,WGAN 有望进一步释放生成模型在从图像合成到自然语言生成等众多应用中的潜力,预示着人工智能驱动的创造力和创新的新时代。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/422154.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何在群晖Docker运行本地聊天机器人并结合内网穿透发布到公网访问

文章目录 1. 拉取相关的Docker镜像2. 运行Ollama 镜像3. 运行Chatbot Ollama镜像4. 本地访问5. 群晖安装Cpolar6. 配置公网地址7. 公网访问8. 固定公网地址 随着ChatGPT 和open Sora 的热度剧增,大语言模型时代,开启了AI新篇章,大语言模型的应用非常广泛,包括聊天机…

Tokenize Anything via Prompting论文解读

文章目录 前言一、摘要二、引言三、模型结构图解读四、相关研究1、Vision Foundation Models2、Open-Vocabulary Segmentation3、Zero-shot Region Understanding 五、模型方法解读1、Promptable TokenizationPre-processingPromptable segmentationConcept predictionZero-sho…

STM32标准库开发—实时时钟(BKP+RTC)

BKP配置结构 注意事项 BKP基本操作 时钟初始化 RCC_APB1PeriphClockCmd(RCC_APB1Periph_PWR, ENABLE);RCC_APB1PeriphClockCmd(RCC_APB1Periph_BKP, ENABLE);PWR_BackupAccessCmd(ENABLE);//设置PWR_CR的DBP,使能对PWR以及BKP的访问读写寄存器操作 uint16_t ArrayW…

LeetCode--72

72. 编辑距离 给你两个单词 word1 和 word2, 请返回将 word1 转换成 word2 所使用的最少操作数 。 你可以对一个单词进行如下三种操作: 插入一个字符删除一个字符替换一个字符 示例 1: 输入:word1 "horse", word2 …

Mysql与StarRocks语法上的不同

🐓 序言 StarRocks 是新一代极速全场景 MPP (Massively Parallel Processing) 数据库。StarRocks 的愿景是能够让用户的数据分析变得更加简单和敏捷。用户无需经过复杂的预处理,可以用StarRocks 来支持多种数据分析场景的极速分析。 🐓 语法…

STL容器之string类

文章目录 STL容器之string类1、 什么是STL2、STL的六大组件3、string类3.1、string类介绍3.2、string类的常用接口说明3.2.1、string类对象的常见构造3.2.2、string类对象的容量操作3.2.3、string类对象的访问及遍历操作3.2.4、 string类对象的修改操作3.2.5、 string类非成员函…

springBoot整合Redis(二、RedisTemplate操作Redis)

Spring-data-redis是spring大家族的一部分,提供了在srping应用中通过简单的配置访问redis服务,对reids底层开发包(Jedis, JRedis, and RJC)进行了高度封装,RedisTemplate提供了redis各种操作、异常处理及序列化,支持发布订阅&…

支持向量机算法(带你了解原理 实践)

引言 在机器学习和数据科学中,分类问题是一种常见的任务。支持向量机(Support Vector Machine, SVM)是一种广泛使用的分类算法,因其出色的性能和高效的计算效率而受到广泛关注。本文将深入探讨支持向量机算法的原理、特点、应用&…

Unity(第二十一部)动画的基础了解(感觉不了解其实也行)

1、动画组件老的是Animations 动画视频Play Automatically 是否自动播放Animate Physics 驱动方式,勾选后是物理驱动Culling Type 剔除方式 默认总是动画化就会一直执行下去,第二个是基于渲染播放(离开镜头后不执行), …

蓝桥杯倒计时 43天 - 前缀和,单调栈

最大数组和 算法思路&#xff1a;利用前缀和化简 for 循环将 n^2 简化成 nn&#xff0c;以空间换时间。枚举每个 m&#xff0c;m是删除最小两个数&#xff0c;那k-m就是删除最大数&#xff0c;m<k&#xff0c;求和最大的值。暴力就是枚举 m-O(n)&#xff0c;计算前 n-(k-m)的…

Revit-二开之创建TextNote-(1)

Revit二开之创建TextNote TextNode在Revit注释模块中&#xff0c;具体位置如图所示 图中是Revit2018版本 【Revit中的使用】 Revit 中的操作是点击上图中的按钮在平面视图中点击任意放置放置就行&#xff0c; 在属性中可以修改文字 代码实现 创建TextNode ExternalComm…

有趣的CSS - 故障字体效果

大家好&#xff0c;我是 Just&#xff0c;这里是「设计师工作日常」&#xff0c;今天分享的是用 css 实现一个404故障字体效果。 《有趣的css》系列最新实例通过公众号「设计师工作日常」发布。 目录 整体效果核心代码html 代码css 部分代码 完整代码如下html 页面css 样式页面…

2024年全国乙卷高考理科数学备考:十年选择题真题和解析

今天距离2024年高考还有三个多月的时间&#xff0c;今天我们来看一下2014~2023年全国乙卷高考理科数学的选择题&#xff0c;从过去十年的真题中随机抽取5道题&#xff0c;并且提供解析。后附六分成长独家制作的在线练习集&#xff0c;科学、高效地反复刷这些真题&#xff0c;吃…

Linux上搭建并使用ffmpeg(Java)

关于MacOs和Windows系统上使用ffmpeg就不多说了&#xff0c;有很多相关文章&#xff0c;今天给大家分享一个在Linux环境下使用Java语言来使用ffmpeg 一、首先去官网下载一个Linux对应的ffmpeg包 1、进入ffmpeg官网&#xff1a;官网 2、点击左侧导航栏Download 3、选择Linux对…

什么是人才储备?如何做人才储备?

很多小伙伴都会有企业面试被拒的情况&#xff0c;然后HR会告诉你&#xff0c;虽然没有录用你&#xff0c;但是你进入了他们的人才储备库&#xff0c;那么这个储备库有什么作用和特点呢&#xff1f;我们如何应用人才测评系统完善人才储备库呢&#xff1f; 人才储备一般有以下三…

软考重点题解析-基础知识

1.加密技术&#xff1a;分为对称加密技术&#xff1a;文件的加密和解密使用相同的密钥 和 非对称加密技术&#xff1a;加密和解密不同的密钥&#xff0c;分别是公开密钥和私有密钥。 例题&#xff1a;若A,B两人分别在认证机构&#xff08;CA&#xff09;M,N处获得证书&…

liunx安装jdk、redis、nginx

jdk安装 下载jdk,解压。 sudo tar -zxvf /usr/local/jdk-8u321-linux-x64.tar.gz -C /usr/local/ 在/etc/profile文件中的&#xff0c;我们只需要编辑一下&#xff0c;在文件的最后加上java变量的有关配置&#xff08;其他内容不要动&#xff09;。 export JAVA_HOME/usr/l…

云轴科技ZStack与华东师范大学共建产教融合基地

近日&#xff0c;上海云轴信息科技有限公司&#xff08;云轴科技ZStack&#xff09;与华东师范大学上海国际首席技术官学院宣布&#xff0c;共同打造产教融合基地&#xff0c;以促进人才培养与产业需求的全方位融合。这一举措旨在深化教育与产业的合作关系&#xff0c;培养更多…

Maven编译报processing instruction can not have PITarget with reserveld xml name

在java项目中&#xff0c;平时我们会执行mvn clean package命令来编译我们的java项目&#xff0c;可是博主今天执行编译时突然报了 processing instruction can not have PITarget with reserveld xml name 这个错&#xff0c;网上也说法不一&#xff0c;但是绝大绝大部分是因…

Yii2中如何使用scenario场景,使rules按不同运用进行字段验证

Yii2中如何使用scenario场景&#xff0c;使rules按不同运用进行字段验证 当创建news新闻form表单时&#xff1a; 添加新闻的时候执行create动作。 必填字段&#xff1a;title-标题&#xff0c;picture-图片&#xff0c;description-描述。 这时候在model里News.php下rules规则…