VAE-变分自编码器(Variational Autoencoder,VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,结合了概率图模型与神经网络技术,广泛应用于数据生成、表示学习和数据压缩等领域。以下是对VAE的详细解释和理解:

基本概念

1. 自编码器(Autoencoder)

自编码器是一种无监督学习模型,通常用于降维和特征提取。它由两个主要部分组成:

  • 编码器(Encoder):将输入数据映射到一个低维隐变量空间。
  • 解码器(Decoder):从低维隐变量空间重建输入数据。
    自编码器的目标是使重建的数据尽可能与原始输入数据相似。

2. 变分自编码器(VAE)

VAE 是自编码器的一种扩展,它通过引入概率分布的概念来对隐变量空间进行建模。VAE 的目标不仅是重建输入数据,还要使隐变量遵循某种已知的概率分布(通常是标准正态分布)。这样可以通过采样隐变量来生成新数据。

VAE的工作原理

  1. 编码器
    在VAE中,编码器不是直接输出一个隐变量,而是输出隐变量的参数(均值 μ 和标准差 σ)。这些参数定义了隐变量的一个概率分布,通常假设为正态分布 N(μ, σ^2)。

  2. 重新参数化技巧(Reparameterization Trick)
    为了使模型能够通过梯度下降进行训练,VAE引入了重新参数化技巧。通过采样一个标准正态分布的变量 ε ~ N(0, 1),然后进行线性变换得到隐变量 z:
    在这里插入图片描述

这样,采样操作变成了一个确定性的操作,允许梯度反向传播。

  1. 解码器
    解码器接受从上述分布中采样的隐变量 z,并尝试重建输入数据。解码器的目标是最大化重建数据的概率。

损失函数

VAE 的损失函数由两部分组成:

  • 重构损失(Reconstruction Loss):衡量重建数据与原始数据的相似度,通常使用均方误差(MSE)或交叉熵损失。 KL

  • 散度(KL Divergence):衡量隐变量分布与标准正态分布的差异。通过最小化KL散度,使隐变量分布接近标准正态分布。

综合起来,VAE的损失函数为:

在这里插入图片描述

VAE的优点

  1. 生成能力:可以从隐变量空间采样生成新数据,具有良好的生成能力。
  2. 隐变量解释性:通过将隐变量空间约束为标准正态分布,隐变量具有一定的解释性和可操作性。
  3. 无监督学习:VAE是一种无监督学习模型,不需要标签数据即可进行训练。

VAE的缺点

  1. **生成质量有限:**生成数据的质量有时不如GAN(生成对抗网络)等其他生成模型。
  2. **训练复杂:**VAE的训练涉及到复杂的概率推断和优化过程。

总结

变分自编码器通过引入概率分布和重新参数化技巧,使得隐变量具有良好的生成能力和解释性。其核心思想是在保持重建数据质量的同时,使隐变量遵循标准正态分布,从而实现数据生成和表示学习。尽管存在一些缺点,但VAE在许多应用场景中仍然表现出色,并为生成模型的研究提供了重要的理论基础。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# 定义VAE模型
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 定义损失函数
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# 加载MNIST数据集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

# 初始化模型
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=20)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# 训练模型
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

# 开始训练
for epoch in range(1, 11):
    train(epoch)

代码说明

  • 编码器和解码器:编码器将输入图像编码为潜在空间的均值和对数方差,解码器从潜在变量生成重建的图像。
  • Sampling层:这是实现重参数化技巧的关键部分,将均值和对数方差转换为潜在变量。
  • VAE类:组合编码器和解码器,并实现自定义训练步骤,包括计算重建损失和KL散度损失。
  • 数据准备和训练:加载MNIST数据集,对数据进行预处理,然后训练VAE模型。
    这个示例展示了一个简单的VAE模型。根据具体的应用需求,你可能需要调整网络结构和超参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635137.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

璞公英教学平台同时进驻两大云教育平台,让智慧教育“触手可及”!

近日,璞公英教学平台云上服务版图进一步扩大,在中国电信天翼云甄选商城、宁夏教育资源公共服务平台成功上线,为更多学校更多师生提供精细化服务。借助云平台的强大力量,璞公英教学平台将为用户带来前所未有、超越想象的教学体验。…

Java面试八股之进程和线程的区别

Java进程和线程的区别 定义与作用: 进程:在操作系统中,进程是程序执行的一个实例,是资源分配的最小单位。每个进程都拥有独立的内存空间,包括代码段、数据段、堆空间和栈空间,以及操作系统分配的其他资源…

【HarmonyOS4学习笔记】《HarmonyOS4+NEXT星河版入门到企业级实战教程》课程学习笔记(十一)

课程地址: 黑马程序员HarmonyOS4NEXT星河版入门到企业级实战教程,一套精通鸿蒙应用开发 (本篇笔记对应课程第 18 节) P18《17.ArkUI-状态管理Observed 和 ObjectLink》 第一件事:嵌套对象的类型上加上 Observed 装饰器…

推荐一个娱乐网站poki

今天,我要向您介绍一个充满乐趣的娱乐网站——Poki。这是一个集合了众多在线小游戏的平台,适合所有年龄段的玩家。无论您是想在工作间隙放松一下,还是寻找适合家庭聚会时的娱乐活动,Poki都能满足您的需求。所有游戏都无需下载或安…

leetcode_2024年5月19日10:51:26

238.除自身以外各元素的乘积 给你一个整数数组nums,返回数组answer,其中answer[i]等于nums中除nums[i]之外其余各元素的乘积。 题目数据保证数组nums之中任意元素的全部前缀元素和后缀的乘积都在32位整数范围内。 请不要使用除法,且在o&am…

使用神经实现路径表示的文本到向量生成

摘要 矢量图形在数字艺术中得到广泛应用,并受到设计师的青睐,因为它们具有可缩放性和分层特性。然而,创建和编辑矢量图形需要创造力和设计专业知识,使其成为一项耗时的任务。最近在文本到矢量(T2V)生成方面…

大语言模型的工程技巧(二)——混合精度训练

相关说明 这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。 混合精度训练的示例请参考如下链接:regression2chatgpt/ch11_llm/gpt2_lora_optimum.ipynb 本文将讨论如何利用混合…

vue.js状态管理和服务端渲染

状态管理 vuejs状态管理的几种方式 组件内管理状态:通过data,computed等属性管理组件内部状态 父子组件通信:通过props和自定义事件实现父子组件状态的通信和传递 事件总线eventBus:通过new Vue()实例,实现跨组件通…

个人博客网站开发笔记2

文章目录 前言p2 hexo安装与使用安装 Nodejs安装 GitGit Bash的使用,代码克隆Clone p3 写作一级标题二级标题三级标题四级标题五级标题六级标题 前言 现在继续看教程 p2 hexo安装与使用 link 啊有点难受,开幕就是需要自己先安装Nodejs和Git&#xff…

git使用介绍

一、为什么做版本控制(git是版本控制工具) 为了保留之前所以的版本,以便回滚和修改 二、点击安装 三、基础操作 1、初步认识 想要让git对一个目录进行版本控制需要以下步骤: 进入要管理的文件夹进行初始化命令 git init管理…

el-table 组件实现 “合并单元格 + N行数据小计” 功能

目录 需求 - 要实现的效果初始代码代码升级(可供多个表格使用)CommonTable.vue 子组件 使用子组件1 - 父组件 - 图1~图3使用效果展示 使用子组件2 - 父组件 - 图4使用效果展示 注意【代码优化 - 解决bug】 需求 - 要实现的效果 父组件中 info 数据示例 …

Redis篇 浅谈分布式系统

分布式系统 一. 单机架构二.分布式系统引入三.引入更多的应用服务器四.读写分离五.引入缓存服务器六. 将数据库服务器拆分七.微服务架构 一. 单机架构 单机架构,就是用一台服务器,完成所有的工作. 这时候就需要我们引入分布式系统了. 分布式系统是什么含义呢?就是由一台主机服…

MySQL实战——主从异步复制搭建(一主一从)

一、搭建前的准备 主库 192.168.1.76 从库 192.168.1.77 二、搭建 1、编辑配置文件 vi /etc/my.cnf 主库 [mysqld] log-binmysql-bin server-id1 从库 [mysqld] server-id2 2、在主库创建复制用户 create user repl192.168.1.77 identified by repl123; grant replic…

9、QT—SQLite使用小记

前言 开发平台:Win10 64位 开发环境:Qt Creator 13.0.0 构建环境:Qt 5.15.2 MSVC2019 64位 sqlite版本:sqlite3 文章目录 一、Sqlite是什么二、sqlite使用步骤2.1 下载2.2 安装2.3 使用 三、Qt集成sqlite33.1 关键问题3.2 封装sql…

C#, PCANBasicd.dll库读写CAN设备数据

PCAN-Basic是一个简单的 PCAN 系统编程接口。 通过 PCAN-Basic Dll,可以将自己的应用程序连接到设备驱动程序和 PCAN 硬件,以与 CAN 总线进行通信。支持C、C++、C#、Delphi、JAVA、VB、Python等语言。 PCAN-Basic库和驱动下载地址 ​ ​https://www.peak-system.com/filead…

【C#】未能加载文件或程序集“CefSharp.Core.Runtime.dll”或它的某一个依赖项。找不到指定的模块。

欢迎来到《小5讲堂》 这是《C#》系列文章,每篇文章将以博主理解的角度展开讲解。 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 背景错误提示分析原因解决方法Chromium知识点相关文章 背景 最近在使…

LeetCode 131题详解:高效分割回文串的递归与动态规划方法

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

Shell编程之条件判断语句

目录 一、条件判断 1、test命令 2、文件测试 3、整数值比较 4、字符串判断 5、逻辑测试 二、if语句 1、if单分支语句 2、双分支语句 3、多分之语句 4、case 分支语句 一、条件判断 Shell环境根据命令执行后的返回状态值(echo $?)来判断是否执行成…

力扣刷题---1748.唯一元素的和【简单】

题目描述 给你一个整数数组 nums 。数组中唯一元素是那些只出现 恰好一次 的元素。 请你返回 nums 中唯一元素的 和 。 示例 1: 输入:nums [1,2,3,2] 输出:4 解释:唯一元素为 [1,3] ,和为 4 。 示例 2:…

基于BERT的医学影像报告语料库构建

大模型时代,任何行业,任何企业的数据治理未来将会以“语料库”的自动化构建为基石。因此这一系列精选的论文还是围绕在语料库的建设以及自动化的构建。 通读该系列的文章,犹如八仙过海,百花齐放。非结构的提取无外乎关注于非结构…