VAE——生成数字(Pytorch+mnist)

1、简介

  • VAE(变分自编码器)同样由编码器和解码器组成,但与AE不同的是,VAE通过引入隐变量并利用概率分布来学习潜在表示。
  • VAE的编码器学习将输入数据映射到潜在空间的概率分布的参数,而不是直接映射到确定性的潜在表示。
  • VAE的解码器则通过从编码器学得的概率分布中采样,从而生成样本。
  • VAE的训练目标既包括最小化重构误差,也包括最大化编码器输出的潜在空间与单位高斯分布之间的KL散度,以促使学得的潜在表示更接近于标准正态分布。
  • VAE可以生成更连续、更具表现力的样本,并且具有更强的概率建模能力。
  • 本文利用VAE,输入数字图像。训练后,生成新的数字图像。
    • (100个epochs的结果)
  • 【注】本文案例输出的是随机的64个数字。

2、代码

  • import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torch.nn.functional as F
    from torchvision.utils import save_image
    
    
    # 变分自编码器
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
    
            # 编码器层
            self.fc1 = nn.Linear(input_size, 512)  # 编码器输入层
            self.fc2 = nn.Linear(512, latent_size)
            self.fc3 = nn.Linear(512, latent_size)
    
            # 解码器层
            self.fc4 = nn.Linear(latent_size, 512)  # 解码器输入层
            self.fc5 = nn.Linear(512, input_size)  # 解码器输出层
    
        # 编码器部分
        def encode(self, x):
            x = F.relu(self.fc1(x))  # 编码器的隐藏表示
            mu = self.fc2(x)  # 潜在空间均值
            log_var = self.fc3(x)  # 潜在空间对数方差
            return mu, log_var
    
        # 重参数化技巧
        def reparameterize(self, mu, log_var):  # 从编码器输出的均值和对数方差中采样得到潜在变量z
            std = torch.exp(0.5 * log_var)  # 计算标准差
            eps = torch.randn_like(std)  # 从标准正态分布中采样得到随机噪声
            return mu + eps * std  # 根据重参数化公式计算潜在变量z
    
        # 解码器部分
        def decode(self, z):
            z = F.relu(self.fc4(z))  # 将潜在变量 z 解码为重构图像
            return torch.sigmoid(self.fc5(z))  # 将隐藏表示映射回输入图像大小,并应用 sigmoid 激活函数,以产生重构图像
    
        # 前向传播
        def forward(self, x):  # 输入图像 x 通过编码器和解码器,得到重构图像和潜在变量的均值和对数方差
            mu, log_var = self.encode(x.view(-1, input_size))
            z = self.reparameterize(mu, log_var)
            return self.decode(z), mu, log_var
    
    
    # 使用重构损失和 KL 散度作为损失函数
    def loss_function(recon_x, x, mu, log_var):  # 参数:重构的图像、原始图像、潜在变量的均值、潜在变量的对数方差
        MSE = F.mse_loss(recon_x, x.view(-1, input_size), reduction='sum')  # 计算重构图像 recon_x 和原始图像 x 之间的均方误差
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # 计算潜在变量的KL散度
        return MSE + KLD  # 返回二进制交叉熵损失和 KLD 损失的总和作为最终的损失值
    
    
    def sample_images(epoch):
        with torch.no_grad():  # 上下文管理器,确保在该上下文中不会进行梯度计算。因为在这里只是生成样本而不需要梯度
            number = 64
            sample = torch.randn(number, latent_size).to(device)  # 生成一个形状为 (64, latent_size) 的张量,其中包含从标准正态分布中采样的随机数
            sample = model.decode(sample).cpu()  # 将随机样本输入到解码器中,解码器将其映射为图像
            save_image(sample.view(number, 1, 28, 28), f'sample{epoch}.png')  # 将生成的图像保存为文件
    
    
    if __name__ == '__main__':
        batch_size = 512  # 批次大小
        epochs = 100  # 学习周期
        sample_interval = 10  # 保存结果的周期
        learning_rate = 0.001  # 学习率
        input_size = 784  # 输入大小
        latent_size = 64  # 噪声大小
    
        # 载入 MNIST 数据集中的图片进行训练
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量
    
        train_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=True, transform=transform, download=True
        )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True
    
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据
    
        # 在使用定义的 AE 类之前,有以下事情要做:
        # 配置要在哪个设备上运行
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
        # 建立 VAE 模型并载入到 CPU 设备
        model = VAE().to(device)
    
        # Adam 优化器,学习率
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
        # 训练
        for epoch in range(epochs):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = data.to(device)  # 将输入数据移动到设备(GPU 或 CPU)上
    
                optimizer.zero_grad()  # 进行反向传播之前,需要将优化器中的梯度清零,以避免梯度的累积
    
                # 重构图像 recon_batch、潜在变量的均值 mu 和对数方差 log_var
                recon_batch, mu, log_var = model(data)
    
                loss = loss_function(recon_batch, data, mu, log_var)  # 计算损失
                loss.backward()  # 计算损失相对于模型参数的梯度
                train_loss += loss.item()
    
                optimizer.step()  # 更新模型参数
    
            train_loss = train_loss / len(train_loader)  # # 计算每个周期的训练损失
            print('Epoch [{}/{}], Loss: {:.3f}'.format(epoch + 1, epochs, train_loss))
    
            # 每10次保存图像
            if (epoch + 1) % sample_interval == 0:
                sample_images(epoch + 1)
    
            # 每训练10次保存模型
            if (epoch + 1) % sample_interval == 0:
                torch.save(model.state_dict(), f'vae{epoch + 1}.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/502885.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

git最常用的命令与快捷操作说明

git最常用的命令与快捷操作说明 最常用的git三条命令1、git add .2、git commit -m "推送注释"3、git push origin 远程分支名:本地分支名 其他常用命令本地创建仓库分支删除本地指定分支切换本地分支合并本地分支拉取远程仓库指定分支代码过来合并推送代码到远程分支…

c语言:vs2022写一个一元二次方程(包含虚根)

求一元二次方程 的根&#xff0c;通过键盘输入a、b、c&#xff0c;根据△的值输出对应x1和x2的值(保留一位小数)(用if语句完成)。 //一元二次方程的实现 #include <stdio.h> #include <math.h> #include <stdlib.h> int main() {double a, b, c, delta, x1…

day3 wsl下启动第一个nest项目(java转ts全栈/3R教室)

背景&#xff1a;准备先找个nestjs模板项目&#xff08;kuizou大佬的nest-vben-admin&#xff09;看看大体情况&#xff0c;但发现win下还是问题还真挺多&#xff0c;受不了了今天一定要把wsl环境安装好。。。 比如如下明显就是win环境导致的错误&#xff0c;估计wsl下应该没问…

给大家推荐一个系统运维管理神器jeecat

一个可以当堡垒机又可以当成系统运维管理软件的神器&#xff0c;不仅支持堡垒机的全部功能&#xff0c;还实现以系统为维度的全方位授权管控&#xff0c;有效避免信息泄露、删库跑路等危险操作&#xff0c;作公司的安全运维管控神器。 无须任何插件&#xff0c;只须浏览器的web…

蓝桥杯骗分小技巧

写在前面 由于本人第一次参加的是cpp组&#xff0c;第二次参加的python组&#xff0c;所以一些技巧都是关于cpp和python的 先上圣经 贪心骗样例&#xff0c;暴力出奇迹。 暴搜挂着机&#xff0c;打表出省一。 数学先打表&#xff0c;DP看运气。穷举TLE&#xff0c;打表UKE。 模…

系统需求分析报告(原件获取)

第1章 序言 第2章 引言 2.1 项目概述 2.2 编写目的 2.3 文档约定 2.4 预期读者及阅读建议 第3章 技术要求 3.1 软件开发要求 第4章 项目建设内容 第5章 系统安全需求 5.1 物理设计安全 5.2 系统安全设计 5.3 网络安全设计 5.4 应用安全设计 5.5 对用户安全管理 …

【Qt】QMainWindow

目录 一、概念 二、菜单栏 2.1 创建菜单栏 2.2 在菜单栏中添加菜单 2.3 创建菜单项 2.4 在菜单项之间添加分割线 三、工具栏 3.1 创建工具栏 3.2 设置停靠位置 3.3 设置浮动属性 3.4 设置移动属性 四、状态栏 4.1 状态栏的创建 4.2 显示实时消息 4.3 显示永久消…

vue3+vite模版框架 tabs右键刷新时丢失路由参数

问题&#xff1a; 标题栏的tabs的右键&#xff1a;刷新时&#xff0c;没有保存上一个页面传递过来的参数 分析&#xff1a; TagView.vue刷新事件 function refreshSelectedTag(view: TagView) {console.log(|--执行刷新, view)tagsViewStore.delCachedView(view);const {full…

Cookie/Session

1.Cookie HTTP 协议自身是属于 "无状态" 协议. "无状态" 的含义指的是: 默认情况下 HTTP 协议的客户端和服务器之间的这次通信, 和下次通信之间没有直接的联系. 但是实际开发中, 我们很多时候是需要知道请求之间的关联关系的. 例如登陆网站成功后, 第二…

004 高并发内存池_ThreadCache设计

​&#x1f308;个人主页&#xff1a;Fan_558 &#x1f525; 系列专栏&#xff1a;高并发内存池 &#x1f339;关注我&#x1f4aa;&#x1f3fb;带你学更多知识 文章目录 前言文章重点一、设计FreeList自由链表结构二、定制对齐映射规则三、完成申请Allocate与释放Deallocate…

数据结构:链表的双指针技巧

文章目录 一、链表相交问题二、单链表判环问题三、回文链表四、重排链表结点 初学双指针的同学&#xff0c;请先弄懂删除链表的倒数第 N 个结点。 并且在学习这一节时&#xff0c;不要将思维固化&#xff0c;认为只能这样做&#xff0c;这里的做法只是技巧。 一、链表相交问题 …

【免费获取】【下片神器】IDM非主流网站视频免费下载神器IDM+m3u8并解决idm下载失败问题 idm下载器超长免费试用

当你浏览一个网站&#xff0c;看到一个喜欢的视频&#xff0c;不知道如何下载的时候&#xff0c;本教程或许可以帮你点小忙。大部分的主流网站都有专门的下载工具&#xff0c;本篇教程主要针对的是一些非主流的小网站。 我们的下载方法就是利用IDM&#xff08;Internet Downlo…

npm卸载不掉的解决方案

不管怎么重装重启都报错 真服了&#xff0c;npm卸载不掉绝对是有缓存存在&#xff0c;用where npm查到d盘 实际上根本不在这个地方&#xff0c;这个是我安装的6.14.12版本的npm的地方&#xff0c;我说我怎么怎么重装怎么导包都不行呢&#xff0c;偷偷隐藏在这个目录里面&#…

Unity 学习日记 13.地形系统

下载源码 UnityPackage 1.地形对象Terrain 目录 1.地形对象Terrain 2.设置地形纹理 3.拔高地形地貌 4. 绘制树和草 5.为地形加入水 6.加入角色并跑步 7.加入水声 右键创建3D地形&#xff1a; 依次对应下面的按钮 || 2.设置地形纹理 下载资源包 下载资源包后&#x…

【Ubuntu】文件和目录的增、删、改、查操作

这里写目录标题 (一)文件和目录类命令的使用1、目录与文件的增加&#xff08;1&#xff09;目录的增加 :&#xff08;2&#xff09;文件的增加 2、目录与文件的删除&#xff08;1&#xff09;目录和文件的删除 3、目录与文件的修改&#xff08;1&#xff09;mv命令 4、目录与文…

【跟着CHATGPT学习硬件外设 | 01】SPI

文章目录 &#x1f680; 概念揭秘关键精华&#x1f31f; 秒懂案例生活类比实战演练 &#x1f50d; 原理与工作流程探秘步骤1&#xff1a;初始化SPI接口步骤2&#xff1a;主设备启动通信步骤3&#xff1a;主设备发送数据步骤4&#xff1a;从设备接收数据步骤5&#xff1a;从设备…

Zookeeper(九)客户端的启动流程

目录 一 ZooKeeper会话的创建与连接1.1 会话的创建1.1.1 ClientWatchManager1.1.2 ConnectStringParser1.1.3 HostProvider1.1.4 ClientCnxn 1.2 会话的连接1.2.1 SendThread1.2.2 eventThread 二 ZooKeeper会话的响应2.1 接受服务端响应 三 ClientCnxn 详解3.1 Packet3.2 队列…

一文彻底搞懂并发容器

文章目录 1. 什么是并发容器2. 并发容器的分类 1. 什么是并发容器 并发容器是一种用于多线程环境的数据结构&#xff0c;它们能够有效地处理并发访问和修改的问题。在多线程应用程序中&#xff0c;多个线程可能会同时访问和修改共享的数据结构&#xff0c;这可能会导致数据不一…

一文教你如何轻松领取阿里云优惠券

随着云计算技术的飞速发展&#xff0c;越来越多的企业和个人选择使用阿里云作为他们的云服务提供商。为了吸引更多的用户上云&#xff0c;阿里云推出了各种优惠券和促销活动。本文将教大家如何轻松领取阿里云优惠券&#xff0c;以便在购买阿里云产品和服务时享受更多优惠。 一、…

激发数据潜力:企业数据中台的策略性构建与优化_光点科技

在信息时代&#xff0c;数据是企业价值链中不可或缺的一环。构建一个策略性的企业数据中台不仅能够整合分散的数据资源&#xff0c;还能提高决策效率和业务敏捷性。本文聚焦于如何策略性地构建和优化数据中台&#xff0c;以便企业能够最大化地利用数据资源&#xff0c;推动企业…