VAE中的“变分”什么

写在前面

        VAE(Variational Autoencoder),中文译为变分自编码器。其中AE(Autoencoder)很好理解。那“变分”指的是什么呢?—其实是“变分推断”。变分推断主要用在VAE的损失函数中,那变分推断是什么,VAE的损失函数又是什么呢?下面我就来说一说!

       可以先看一下 这篇文章,介绍了VAE的代码实现。

一、通俗理解损失函数

        这篇文章已经整体介绍了VAE,这里我详细介绍一下VAE的损失函数:

\mathbf{LOSS=-E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+KL(q(z|x)||p(z))}

        每个变量的说明下面会有介绍,现在我们只关注VAE的损失函数有由两部分组成,第一部分是一个交叉熵,我们称之为“重构项”,其作用是确保训练时输入和输出间的相似性;第二部分是KL散度,叫做“KL散度项”,它其实是一个正则项,主要解决了两个AE模型的痛点,这也是VAE成功并流行的主要原因:

        1.潜在空间的结构化:AE的潜在空间往往是无规则的,这意味着编码器学到的表征可能杂乱无章,不便于后续操作。VAE通过添加KL散度项来惩罚潜在变量分布与预设先验分布(就是p(z),是一个标准高斯分布)之间的偏差,从而迫使潜在空间呈现出一定的结构,使潜在变量的分布更加合理和连贯。说人话就是:VAE可以输入标准高斯分布的采样数据,生成精美的图像。

        2.潜在空间的连续性:KL散度项要求潜在变量 z 的分布  q(z|x) 尽可能接近预设的先验分布  p(z) ,这个先验分布通常选择为标准正态分布。通过这种方式,潜在空间被组织成一个连续、平滑的多维空间,其中每一维上的值都能够自由变动而不产生剧烈变化。这种设计确保了在潜在空间中的小步长移动会导致解码结果的轻微变化,从而实现了连续性。说人话就是:VAE可以通过微调输入的采样数据,一定程度上修改生成图像的属性。这也是造成“抽卡”的原因之一。

        损失函数的这两项可以简单的这么理解,但是它其实是推导出来的,这就说来话长。感兴趣的小伙伴继续往下看。

二、边际似然

1.边际似然的定义

        VAE 是一种生成模型,生成模型的核心任务是计算在给定潜在变量 z 的情况下生成观测数据 x 的概率。我们希望模型能够生成与真实数据分布相似的新数据,这一目标可以通过边际似然 p(x) 来实现。

        其中z就是Latent;x是训练用的图像;p(x)是边际似然,也就是VAE的损失函数

        p(x)可以很好的衡量模型的生成能力。p(x)直接衡量了模型在生成数据方面的整体能力,因为它考虑了所有潜在的隐变量 z 对观测数据 x 的影响。高的p(x)意味着模型可以很好地解释数据,并且在生成新数据时表现出较强的能力。

        具体来说,如果模型的边际似然高,说明模型在所有可能的隐变量 z 下生成观测数据的概率累加起来后非常高,这意味着模型学到了数据的真实分布。

        边际似然 p(x) 表示给定模型情况下生成观测数据 x 的概率,定义为:

p(x)=\int p(x|z)p(z)dz  (1)

        其中,条件概率 p(x∣z):给定潜在变量 z 的情况下,生成观测数据 x 的概率。先验分布 p(z):潜在变量 z 的分布,反映了我们对 z 的先验知识。

2.边际似然的推导

        使用全概率公式,边际似然可以用全概率公式来定义,具体为:

p(x)=\int p(x,z)dz (2)

        这里 p(x,z)是 x 和 z 的联合分布。根据条件概率的定义,联合分布可以表示为:

p(x,z)=p(x|z)p(z) (3)

        因此,我们可以将边际似然表示为:

p(x)=\int p(x|z)p(z)dz (4)

        我们要做的就是最大化p(x),这里多说一句,最大化p(x)的目标是使得模型生成的总体概率分布 p(x) 更接近于真实数据分布。这样,模型生成的新样本就会与训练数据的分布一致。

        直观理解:假设我们在训练一个模型生成手写数字图片。如果真实的数据集中 80% 是“1”,20% 是“2”,那么一个好的生成模型应该能够生成 80% 的“1”和 20% 的“2”。而不是让p(x)趋近于1.

3.边际似然的挑战

        但是计算边际似然通常是一个复杂且困难的任务,原因包括:

        (1)高维积分:在实际的应用中,潜在变量 z 通常是高维的。例如,如果 z 是 100 维的向量,那么积分就需要在 100 维的空间上进行。这种高维积分是非常复杂的,解析解几乎不可能得到。

        (2)分布形式复杂:在生成模型中,条件分布 p(x∣z)和先验分布 p(z) 可能并不是简单的概率分布。例如,p(x∣z) 可能由一个深度神经网络参数化,计算时需要经过非线性激活函数和复杂的网络结构,这会让这个积分无法直接求解。

        (3)数值计算的困难:计算边际似然时,需要对 z 的所有可能值进行积分,也就是计算出在所有潜在表示 z 上,生成数据 x 的所有可能性。现实中,z 的范围非常大,即使是连续的,也可能取值无穷多个,直接求解所有 z 的可能性几乎是不可能的。

        举个例子,假设我们有一个简单的生成模型,其中:p(z) 是标准正态分布N(0,I)。p(x∣z) 是由一个深度神经网络生成的图像。直接计算边际似然意味着我们需要知道所有 z 的取值如何影响 x。如果 z 是 100 维向量,那么在 R^{100} 空间上对 z 进行积分(或采样)需要极大的计算资源。神经网络的非线性使得每个 p(x∣z) 的计算都很复杂,最终让直接计算积分变得不可行。

        为了解决上面的问题,让模型可以正常训练,我们引入变分推断。

三、变分推断

1.变分推断的定义

        变分推断是一种通过引入近似分布来解决无法直接计算复杂积分的问题的方法。在生成模型中,我们的目标是最大化观测数据的边际似然 p(x):

p(x)=\int p(x|z)p(z)dz (5)

        如前所述,这个积分通常很难直接计算,因此我们引入一个 近似后验分布(也叫变分分布,就是训练时模型的输出 q(z∣x),来代替无法直接求解的真实后验 p(z∣x)。变分推断的目标是让 q(z∣x) 尽可能地接近真实的 p(z∣x)。

\mathbf{p(x)=\int p(x|z)p(z)dz=\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz} (6)

        通过这种重写,我们引入了 q(z∣x) 作为一个权重,这样我们可以在期望的形式下进行优化。我们现在有一个可以计算的表达式:

\mathbf{\mathit{log}p(x)=\mathit{log}\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz} (7)

        尽管重写了表达式,计算 p(x)依然困难,因为积分本身依然难解。因此,我们应用 Jensen 不等式(log是凸函数),将对数操作从积分外移到期望内部(这里的期望是由积分转化来的):

\mathbf{\mathit{log}p(x)=\mathit{log}\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz\geq E_{q(z|x)} \left [ log\frac{p(x|z)p(z)}{q(z|x)} \right ] } (8)

        其中,Eq(z∣x)[⋅]表示在 q(z∣x) 分布下对 z 取期望。这一不等式说明,我们得到了一个对数边际似然的下界,即变分下界 (ELBO)。

2.变分下界ELBO

        式子(8)右边的表达式即为变分下界(Evidence Lower Bound,),通常记作 ELBO,至此我们的目标也变成了最大化ELBO,从而间接地最大化边际似然 p(x)。式子(8)可以写成:

\mathbf{ELBO=E_{q(z|x)} \left [ log\frac{p(x|z)p(z)}{q(z|x)} \right ] } (9)

        式子(9)右边可以展开成:

\mathbf{ELBO=E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+E_{q(z|x)}\left [ \textit{log}p(z) \right ]-E_{q(z|x)}\left [ \textit{log}q(z|x) \right ]} (10)

        因为KL散度公式:

\mathbf{KL(q(z|x)||p(z))=E_{q(z|x)}[log\frac{q(z|x)}{p(z)}]=E_{q(z|x)}[\textit{log}q(z|x)]-E_{q(z|x)}[\textit{log}p(z)]}(11)

        可以看到,式子(10)右边的第二项和第三项可以用KL散度代替:

\mathbf{-KL(q(z|x)||p(z))=E_{q(z|x)}[\textit{log}p(z)]-E_{q(z|x)}[\textit{log}q(z|x)]}(12)

        最终,ELBO 可以写成如下式子,这也是VAE需要优化的损失函数:

\mathbf{ELBO=E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]-KL(q(z|x)||p(z))} (13)

        ELBO 公式展示了两个部分:

        重构项:表示模型生成数据的能力。

        KL 散度项:作为正则化项,控制 q(z∣x) 和 p(z) 之间的差异。最小化这个项有助于使近似后验 q(z∣x) 尽量接近先验 p(z),从而促进模型的泛化能力。p(z)一般被设置成标准高斯分布。

最大化 ELBO 的意义:

        优化目标:最大化 ELBO 实际上是希望在重构能力和潜在分布的正则化之间取得平衡。通过调整这两个部分,可以确保模型既能够良好地重构输入数据,又能够学习到有意义的潜在空间。

        间接最大化边际似然:由于 ELBO 是边际似然的下界,最大化 ELBO 也会使得边际似然 p(x) 的值增加。

        ELBO 在 VAE 中扮演着至关重要的角色,它将生成模型的目标与优化过程结合起来,使得模型能够在重构能力和潜在空间的正则化之间找到最佳平衡。通过最大化 ELBO,VAE 能够学习到有效的潜在表示,从而生成新样本。

四、代码实现中的公式

        这篇文章介绍了VAE的代码实现,其中的损失函数是ELBO的具体实现,我们来看一下,具体是怎么实现的。

        我们的目标是最大化ELBO,相当于最小化其负值,因此 VAE 的损失函数可以表示为:

\mathbf{LOSS=-E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+KL(q(z|x)||p(z))}   (14)

1.重构项

        交叉熵的定义为:

H(p,q)=-E_p[log\textbf{q}]   (15)

        如果我们将 p(x∣z) 视为模型生成 x 的概率分布(对应代码中的recon_x,即模型的输出),而将真实数据的分布视为 q(x)(对应代码中的x,即GT),则ELBO的第一项可以写成:

\mathbf{E_{q(z|x)}[\textit{log}p(x|z)]=-H(x,q(z|x))}  (16)

        最大化 ELBO 的第一项(重构项)实际上是最小化交叉熵损失,代码如下:

BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

2.KL 散度项

        对于高斯分布 q(z|x)=N(\mu ,\sigma ^2)和标准正态分布p(x)=N(0,1),我们可以将 KL散度计算分解为以下几个步骤:

        (1)KL散度 的公式为:

\mathbf{KL\left [ q(z|x)||p(z) \right ]=\int q(z|x)log(\frac{q(z|x)}{p(z)})dz}(17)

        解释一下变量的意义:

        q(z∣x):这是给定输入 x时隐变量 z 的后验分布,通常由编码器生成。

        p(z):这是隐变量 z 的先验分布,通常是标准高斯分布 N(0,1)。

        比率 \mathbf{\frac{q(z|x)}{p(z)}}:这个比率表示后验分布与先验分布的相对关系,反映了后验分布相较于先验分布的“信息量”。

        对数项\mathbf{log(\frac{q(z|x)}{p(z)})}:量化了 q(z∣x) 相较于 p(z) 的信息增益。正值表示后验分布相对于先验分布的增加的信息,而负值则表示信息的损失。

        积分:通过对所有可能的 z进行积分,KL散度 计算了整个后验分布与先验分布之间的差异。

        (2)将q(z|x)=N(\mu ,\sigma ^2)p(z)=N(0,1)带入(17

KL\left [ q(z|x)||p(z) \right ]=\int N(\mu ,\sigma ^2)log(\frac{N(\mu ,\sigma ^2)}{N(0,1)})dz(18)

        (3)高斯分布的公式: 高斯分布的概率密度函数为:

\mathbf{N(z;\mu ,\sigma ^2)=\frac{1}{\sqrt{2\pi \sigma ^2}}exp[-\frac{(z-\mu )^2}{2\sigma ^2}]} (19)

        而标准正态分布为:

\mathbf{N(z;0,1)=\frac{1}{\sqrt{2\pi }}exp(-\frac{z^2}{2})}    (20)

        (4)计算 KL散度: 将这些代入 K散度的公式中,最终可以简化得到:

KL(q(z|x)||p(z))=-\frac{1}{2}(1+log(\sigma ^2)-\mu ^2-\sigma ^2)  (21)

        (5)简化: 进一步简化后,得到:

KL(q(z|x)||p(z))=-0.5(log(\sigma ^2)+1-\mu ^2-\sigma ^2)  (22)

        (6)用对数方差表示: 在实现中,通常使用对数方差 log(\sigma ^2) 来计算,这样可以避免数值稳定性问题,最终得到的 KL散度公式是:

KL(q(z|x)||p(z))=-0.5(1+log(\sigma ^2)-\mu ^2-\sigma ^2)(23)

        KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        其中log_var 是对数方差,使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu 是均值;\sigma ^2=exp(log\sigma ^2)

五、总结

        1.VAE中的“变分”指的是“变分推断”;

        2.VAE的损失函数值最大化边际似然;

        3.最大化边际似然几乎做不到,所以使用变分推断来简化计算;

        4.使用变分推断后,训练通过最大化ELBO实现;

        5.ELBO有两项:重构项和KL散度项。重构项的作用是确保训练时输入和输出间的相似性,就是传统的损失函数常用的东西;KL散度项是一个正则项,能确保潜在空间的结构化和连续性。

        VAE就介绍到这,关注不迷路(*^__^*) 

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/902481.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

第十二部分 Java Stream、File

第十二部分 Java Stream、File 12.1 Java Stream流 12.1.1体验Stream流 案例需求 按照下面的要求完成集合的创建和遍历 创建一个集合,存储多个字符串元素把集合中所有以"张"开头的元素存储到一个新的集合把"张"开头的集合中的长度为3的元素存…

OpenTelemetry 实际应用

介绍 OpenTelemetry“动手”指南适用于想要开始使用 OpenTelemetry 的人。 如果您是 OpenTelemetry 的新手,那么我建议您从OpenTelemetry 启动和运行帖子开始,我在其中详细介绍了 OpenTelemetry。 OpenTelemetry开始改变可观察性格局,它提供…

开挖 Domain - 前奏

WPF App 主机配置 Microsot.Extension.Hosting 一键启动(配置文件、依赖注入,日志) // App.xaml.cs 中定义 IHost private readonly IHost _host Host.CreateDefaultBuilder().ConfigureAppConfiguration(c > {_ c.SetBasePath(Envi…

JVM(HotSpot):GC之G1垃圾回收器

文章目录 一、简介二、工作原理三、Young Collection 跨代引用四、大对象问题 一、简介 1、适用场景 同时注重吞吐量(Throughput)和低延迟(Low latency),默认的暂停目标是 200 ms超大堆内存,会将堆划分为…

CentOS 7 上安装 MySQL 8.0 教程

🌟 你好 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

C#使用log4net结合sqlite数据库记录日志

0 前言 为什么要把日志存到数据库里? 因为结构化的数据库存储的日志信息,可以写专门的软件读取历史日志信息,通过各种条件筛选,可操作性极大增强,有这方面需求的开发人员可以考虑。 为什么选择SQLite? …

node和npm

背景(js) 1、为什么js能操作DOM和BOM? 原因:每个浏览器都内置了DOM、BOM这样的API函数 2、浏览器中的js运行环境? v8引擎:负责解析和执行js代码 内置API:由运行环境提供的特殊接口,只能在所…

Java面向对象编程高阶(一)

Java面向对象编程高阶(一) 一、关键字static1、static修饰属性2、静态变量与实例变量的对比3、static修饰方法4、什么时候将属性声明为静态的?5、什么时候将属性声明为静态的?6、代码演示 一、关键字static static用来修饰的结构…

从0到1学习node.js(npm)

文章目录 一、NPM的生产环境与开发环境二、全局安装三、npm安装指定版本的包四、删除包 五、用npm发布一个包六、修改和删除npm包1、修改2、删除 一、NPM的生产环境与开发环境 类型命令补充生产依赖npm i -S uniq-S 等效于 --save -S是默认选项npm i -save uniq包的信息保存在…

首席数据官和首席数据分析官

根据分析人士的预测,首席数据官(CDO)和首席数据分析官(CDAO)必须更有效地展示他们对企业和AI项目的价值,以保障其在高管层的地位。Gartner的最新报告指出,CDO和CDAO在AI时代需要重新塑造自身定位…

HDFS异常org.apache.hadoop.hdfs.protocol.NSQuotaExceededException

HDFS异常org.apache.hadoop.hdfs.protocol.NSQuotaExceededException 异常信息: Hive:org.apache.hadoop.hdfs.protocol.NSQuotaExceededException: The NameSpace quota (directories and files) of directory /xxxdir is exceeded: quota10000 file count15001N…

【Python】为Pandas加速(适合Pandas中级开发者)

非常好的一篇文章,解决问题的方式和思路层层递进,透彻深刻。 Pandas是个好工具,好工具要用正确高效的方式使用,才能发挥出万钧之力。 英文水平较高可直接阅读原文。Fast, Flexible, Easy and Intuitive: How to Speed Up Your p…

linux创建自定义服务部署项目

1.进入linux单元服务文件夹 cd /etc/systemd/system/ 2.创建一个文件以.service结尾的文件 C#(.Net Core)、 Java、Python等语言,都可以通过linux自定义服务来部署项目,实现守护进程、实现开机自启 2.1例如创建my.service文件 这里以部署python项目为…

新华三H3CNE网络工程师认证—OSPF路由协议

OSPF是典型的链路状态路由协议,是目前业内使用非常广泛的IGP协议之一。本博客将对OSPF路由协议进行总结。 OSPF目前针对IPv4协议使用的是OSPFVersion2(RFC2328); 针对IPv6协议使用OSPFVersion3(RFC2740)。如无特殊说明本章后续所指的OSPF均为OSPF Versi…

HBase2.4.17 修改znode后master初始化失败

正常运行的hbase服务,修改zookeeper.znode.parent后,重启。hbase master服务可以启动成功,但是仅有meta表上线,且hbase:meta表中元数据丢失。仅残留table:state列的值,其他列的值全部丢失。 有大佬知道是怎么回事嘛

(二十四)、在 k8s 中部署自己的 jar 镜像(以 springcloud web 项目为例)

文章目录 1、环境陈述2、前期准备2.1、将一个 SpringCloud 微服务运行 以 jar 方式运行2.2、为 SpringCloud 项目生成 Docker 镜像2.3、推送镜像2.4、从宿主机访问 k8s(minikube) 发布的 redis 服务2.5、k8s(minikube) 部署mysql 3、本期关键3.1、打 jar 包需要修改的地方3.2、…

Anchor DETR:Transformer-Based目标检测的Query设计

写在前面 文中指出之前DETR-like算法存在以下问题: 之前DETR-liked检测算法里,object query是一组可学习的嵌入表示(就是一组256-d的向量),缺乏明确的物理意义,不能解释它们会关注什么地方。每个object q…

禾川SV-X2E A伺服驱动器参数设置——脉冲型

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家!人工智能学习网站 前言: 大家好,我是上位机马工,硕士毕业4年年入40万,目前在一家自动化公司担任…

PHPOK 4.8.338 后台任意文件上传漏洞(CVE-2018-12941)复现

PHPOK企业站(简称PHPOK)建设系统是一套基于PHP和MySQL构建的高效企业网站建设方案之一,全面针对企业网(以展示为中心)进行合理的设计规划。 PHPOK是一套开源免费的建站系统,可以在遵守LGPL协议的基础上免费使用。系统具…