pytorch实现变分自编码器

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution),生成与输入数据相似的新样本。VAE 可以用于数据生成、降维、异常检测等任务。

VAE 的关键思想是在传统的自编码器(Autoencoder)的基础上,引入了变分推断(Variational Inference)和概率模型,使得网络能够学习到数据的潜在分布,而不仅仅是数据的映射。

VAE 的结构:

  1. 编码器(Encoder):将输入数据映射到潜在空间的分布。不同于传统的自编码器直接将数据映射到一个固定的潜在向量,VAE 通过输出潜在变量的均值和方差来描述一个概率分布,这样潜在空间中的每个点都有一个概率分布。
  2. 潜在空间(Latent Space):表示数据的潜在特征。在 VAE 中,潜在空间的表示是一个分布而不是固定的值。通常,采用正态分布来作为潜在空间的先验分布。
  3. 解码器(Decoder):从潜在空间的样本中重构输入数据。解码器通过将潜在空间的点映射回数据空间来生成样本。

VAE 的目标函数:

VAE 的目标是最大化变分下界(Variational Lower Bound,简称 ELBO),即通过优化以下两部分的加权和:

  • 重构误差(Reconstruction Loss):衡量生成的数据和输入数据之间的差异,通常使用均方误差(MSE)或交叉熵(Cross-Entropy)。
  • KL 散度(KL Divergence):衡量潜在空间的分布与先验分布(通常是标准正态分布)之间的差异。

其最终的目标是使生成的数据尽可能接近真实数据,同时使潜在空间的分布接近先验分布。

优点:

  • VAE 能够生成具有多样性的样本,尤其适用于图像、音频等数据的生成。
  • 潜在空间通常具有良好的结构,可以进行插值、样本生成等操作。

应用:

  • 生成任务:如图像生成、文本生成等。
  • 数据重构:如去噪、自编码等。
  • 半监督学习:VAE 可以结合有标签和无标签的数据进行训练,提升模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


# 生成圆形图像的函数(使用PyTorch)
def generate_circle_image(size=64):
    image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像
    center = size // 2
    radius = size // 4
    for y in range(size):
        for x in range(size):
            if (x - center) ** 2 + (y - center) ** 2 <= radius ** 2:
                image[0, y, x] = 1  # 在圆内的点设置为白色
    return image


# 生成方形图像的函数(使用PyTorch)
def generate_square_image(size=64):
    image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像
    padding = size // 4
    image[0, padding:size - padding, padding:size - padding] = 1  # 设置方形区域为白色
    return image


# 自定义数据集:圆形和方形图像
class ShapeDataset(Dataset):
    def __init__(self, num_samples=1000, size=64):
        self.num_samples = num_samples
        self.size = size
        self.data = []
        # 生成数据:一半是圆形图像,一半是方形图像
        for i in range(num_samples // 2):
            self.data.append(generate_circle_image(size))
            self.data.append(generate_square_image(size))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx].float()  # 直接返回 PyTorch Tensor 格式的数据


# VAE模型定义
class VAE(nn.Module):
    def __init__(self, latent_dim=2):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # 编码器
        self.fc1 = nn.Linear(64 * 64, 400)
        self.fc21 = nn.Linear(400, latent_dim)  # 均值
        self.fc22 = nn.Linear(400, latent_dim)  # 方差

        # 解码器
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 64 * 64)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x.view(-1, 64 * 64)))
        return self.fc21(h1), self.fc22(h1)  # 返回均值和方差

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3)).view(-1, 1, 64, 64)  # 重构图像

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# 损失函数:重构误差 + KL 散度
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 64 * 64), x.view(-1, 64 * 64), reduction='sum')
    # KL 散度
    return BCE + 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar)


# 设置超参数
batch_size = 128
epochs = 10
latent_dim = 2
learning_rate = 1e-3

# 数据加载
train_loader = DataLoader(ShapeDataset(num_samples=2000), batch_size=batch_size, shuffle=True)

# 创建模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# 训练模型
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.6f}')

    print(f'Train Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

# 测试并显示一些真实图像和生成的图像
def test():
    model.eval()
    with torch.no_grad():
        # 获取一批真实的图像(原始图像)
        real_images = next(iter(train_loader))[:64]  # 只取前64个图像
        real_images = real_images.cpu().numpy()

        # 从潜在空间随机生成一些样本
        sample = torch.randn(64, latent_dim).to(device)
        generated_images = model.decode(sample).cpu().numpy()

        # 显示真实图像和生成的图像,分别标明
        fig, axes = plt.subplots(8, 8, figsize=(8, 8))
        axes = axes.flatten()

        for i in range(64):
            if i < 32:  # 前32个显示真实图像
                axes[i].imshow(real_images[i].squeeze(), cmap='gray')
                axes[i].set_title('Real', fontsize=8)
            else:  # 后32个显示生成图像
                axes[i].imshow(generated_images[i - 32].squeeze(), cmap='gray')
                axes[i].set_title('Generated', fontsize=8)

            axes[i].axis('off')

        plt.tight_layout()
        plt.show()

# 训练模型
for epoch in range(1, epochs + 1):
    train(epoch)

# 训练完成后,显示生成的图像
test()

解释:

  1. 真实图像 (real_images):我们通过 next(iter(train_loader)) 获取一批真实图像,并将其转换为 NumPy 数组,以便 matplotlib 显示。
  2. 生成图像 (generated_images):通过模型生成的图像,使用 decode() 方法生成潜在空间的样本。
  3. 图像展示:前 32 张图像展示真实图像,后 32 张图像展示生成的图像。每个图像上方都有 RealGenerated 标注。

结果:

  • 前32个图像:显示真实图像,并标注为 Real
  • 后32个图像:显示通过训练后的 VAE 生成的图像,并标注为 Generated

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/964958.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java进阶1——JVM

java进阶——JVM 1、JVM概述 作用 Java 虚拟机就是二进制字节码的运行环境&#xff0c;负责装载字节码到其内部&#xff0c;解释/编译为对 应平台上的机器码指令行&#xff0c;每一条 java 指令&#xff0c;java 虚拟机中都有详细定义&#xff0c;如怎么取操 作数&#xff0c…

DeepSeek各版本说明与优缺点分析

DeepSeek各版本说明与优缺点分析 DeepSeek是最近人工智能领域备受瞩目的一个语言模型系列&#xff0c;其在不同版本的发布过程中&#xff0c;逐步加强了对多种任务的处理能力。本文将详细介绍DeepSeek的各版本&#xff0c;从版本的发布时间、特点、优势以及不足之处&#xff0…

视频融合平台EasyCVR无人机场景视频压缩及录像方案

安防监控视频汇聚EasyCVR平台在无人机场景中发挥着重要的作用&#xff0c;通过高效整合视频流接入、处理与分发等功能&#xff0c;为无人机视频数据的实时监控、存储与分析提供了全面支持&#xff0c;广泛应用于安防监控、应急救援、电力巡检、交通管理等领域。 EasyCVR支持GB…

【力扣】240.搜索二维矩阵 II

题目 我的代码 class Solution { public:bool searchMatrix(vector<vector<int>>& matrix, int target) {for(int i0;i<matrix.size();i){for(int j0;j<matrix[0].size();j){if(targetmatrix[i][j]){return true;}else if(target<matrix[i][j]){brea…

数据库备份、主从、集群等配置

数据库备份、主从、集群等配置 1 MySQL1.1 docker安装MySQL1.2 主从复制1.2.1 主节点配置1.2.2 从节点配置1.2.3 创建用于主从同步的用户1.2.4 开启主从同步1.2.4 主从同步验证 1.3 主从切换1.3.1 主节点设置只读&#xff08;在192.168.1.151上操作&#xff09;1.3.2 检查主从数…

intra-mart实现简易登录页面笔记

一、前言 最近在学习intra-mart框架&#xff0c;在此总结下笔记。 intra-mart是一个前后端不分离的框架&#xff0c;开发时主要用的就是xml、html、js这几个文件&#xff1b; xml文件当做配置文件&#xff0c;html当做前端页面文件&#xff0c;js当做后端文件&#xff08;js里…

Beans模块之工厂模块注解模块CustomAutowireConfigurer

博主介绍&#xff1a;✌全网粉丝5W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

javaEE-8.JVM(八股文系列)

目录 一.简介 二.JVM中的内存划分 JVM的内存划分图: 堆区:​编辑 栈区:​编辑 程序计数器&#xff1a;​编辑 元数据区&#xff1a;​编辑 经典笔试题&#xff1a; 三,JVM的类加载机制 1.加载: 2.验证: 3.准备: 4.解析: 5.初始化: 双亲委派模型 概念: JVM的类加…

【多线程】线程池核心数到底如何配置?

&#x1f970;&#x1f970;&#x1f970;来都来了&#xff0c;不妨点个关注叭&#xff01; &#x1f449;博客主页&#xff1a;欢迎各位大佬!&#x1f448; 文章目录 1. 前置回顾2. 动态线程池2.1 JMX 的介绍2.1.1 MBeans 介绍 2.2 使用 JMX jconsole 实现动态修改线程池2.2.…

js-对象-JSON

JavaScript自定义对象 JSON 概念: JavaScript Object Notation&#xff0c;JavaScript对象标记法. JSON 是通过JavaScript 对象标记法书写的文本。 由于其语法简单&#xff0c;层次结构鲜明&#xff0c;现多用于作为数据载体&#xff0c;在网络中进行数据传输. json中属性名(k…

基于 SpringBoot3 的 SpringSecurity6 + OAuth2 自定义框架模板

&#x1f516;Gitee 项目地址&#xff1a; 基于SpringBoot3的 SpringSecurity6 OAuth2 自定义框架https://gitee.com/MIMIDeK/MySpringSecurityhttps://gitee.com/MIMIDeK/MySpringSecurityhttps://gitee.com/MIMIDeK/MySpringSecurity

大模型综述一镜到底(全文八万字) ——《Large Language Models: A Survey》

论文链接&#xff1a;https://arxiv.org/abs/2402.06196 摘要&#xff1a;自2022年11月ChatGPT发布以来&#xff0c;大语言模型&#xff08;LLMs&#xff09;因其在广泛的自然语言任务上的强大性能而备受关注。正如缩放定律所预测的那样&#xff0c;大语言模型通过在大量文本数…

Django视图与URLs路由详解

在Django Web框架中&#xff0c;视图&#xff08;Views&#xff09;和URLs路由&#xff08;URL routing&#xff09;是Web应用开发的核心概念。它们共同负责将用户的请求映射到相应的Python函数&#xff0c;并返回适当的响应。本篇博客将深入探讨Django的视图和URLs路由系统&am…

位置-速度双闭环PID控制详解与C语言实现

目录 概述 1 控制架构解析 1.1 级联控制结构 1.2 性能对比 2 数学模型 2.1 位置环(外环) 2.2 速度环(内环) 3 C语言完整实现 3.1 控制结构体定义 3.2 初始化函数 3.3 双环计算函数 4 参数整定指南 4.1 整定步骤 4.2 典型参数范围 5 关键优化技术 5.1 速度前馈 …

亚博microros小车-原生ubuntu支持系列:22 物体识别追踪

背景知识 跟上一个颜色追踪类似。也是基于opencv的&#xff0c;不过背后的算法有很多 BOOSTING&#xff1a;算法原理类似于Haar cascades (AdaBoost)&#xff0c;是一种很老的算法。这个算法速度慢并且不是很准。MIL&#xff1a;比BOOSTING准一点。KCF&#xff1a;速度比BOOST…

低至3折,百度智能云千帆宣布全面支持DeepSeek-R1/V3调用

DeepSeek-R1和 DeepSeek-V3模型已在百度智能云千帆平台上架 。 出品|产业家 新年伊始&#xff0c;百度智能云又传来新动作 。 2月3日百度智能云宣布&#xff0c; DeepSeek-R1和 DeepSeek-V3模型已在百度智能云千帆平台上架&#xff0c;同步推出超低价格方案&#xff0c;并…

Deepseek技术浅析(四):专家选择与推理机制

DeepSeek 是一种基于**专家混合模型&#xff08;Mixture of Experts, MoE&#xff09;**的先进深度学习架构&#xff0c;旨在通过动态选择和组合多个专家网络&#xff08;Expert Networks&#xff09;来处理复杂的任务。其核心思想是根据输入数据的特征&#xff0c;动态激活最合…

go运算符

内置运算符 算术运算符关系运算符逻辑运算符位运算符赋值运算符 算术运算符 注意&#xff1a; &#xff08;自增&#xff09;和–&#xff08;自减&#xff09;在 Go 语言中是单独的语句&#xff0c;并不是运算符 package mainimport "fmt"func main() {fmt.Printl…

分享2款 .NET 开源且强大的翻译工具

前言 对于程序员而言永远都无法逃避和英文打交道&#xff0c;今天大姚给大家分享2款 .NET 开源、功能强大的翻译工具&#xff0c;希望可以帮助到有需要的同学。 STranslate STranslate是一款由WPF开源的、免费的&#xff08;MIT License&#xff09;、即开即用、即用即走的翻…

技术书籍写作与编辑沟通指南

引言 撰写技术书籍不仅仅是知识的输出过程&#xff0c;更是与编辑团队紧密合作的协同工作。优秀的技术书籍不仅依赖作者深厚的技术背景&#xff0c;还需要精准的表达、流畅的结构以及符合出版要求的编辑润色。因此&#xff0c;如何高效地与编辑沟通&#xff0c;确保书籍质量&a…