VAE-pytorch代码

 

 

 

import os
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
 
from torchvision import transforms, datasets
from torchvision.utils import save_image
 
from tqdm import tqdm
 
 
class VAE(nn.Module):  # 定义VAE模型
    def __init__(self, img_size, latent_dim):  # 初始化方法
        super(VAE, self).__init__()  # 继承初始化方法
        self.in_channel, self.img_h, self.img_w = img_size  # 由输入图片形状得到图片通道数C、图片高度H、图片宽度W
        self.h = self.img_h // 32  # 经过5次卷积后,最终特征层高度变为原图片高度的1/32
        self.w = self.img_w // 32  # 经过5次卷积后,最终特征层宽度变为原图片高度的1/32
        hw = self.h * self.w  # 最终特征层的尺寸hxw
        self.latent_dim = latent_dim  # 采样变量Z的长度
        self.hidden_dims = [32, 64, 128, 256, 512]  # 特征层通道数列表
        # 开始构建编码器Encoder
        layers = []  # 用于存放模型结构
        for hidden_dim in self.hidden_dims:  # 循环特征层通道数列表
            layers += [nn.Conv2d(self.in_channel, hidden_dim, 3, 2, 1),  # 添加conv
                       nn.BatchNorm2d(hidden_dim),  # 添加bn
                       nn.LeakyReLU()]  # 添加leakyrelu
            self.in_channel = hidden_dim  # 将下次循环的输入通道数设为本次循环的输出通道数
 
        self.encoder = nn.Sequential(*layers)  # 解码器Encoder模型结构
 
        self.fc_mu = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linaer,将特征向量转化为分布均值mu
        self.fc_var = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linear,将特征向量转化为分布方差的对数log(var)
        # 开始构建解码器Decoder
        layers = []  # 用于存放模型结构
        self.decoder_input = nn.Linear(self.latent_dim, self.hidden_dims[-1] * hw)  # linaer,将采样变量Z转化为特征向量
        self.hidden_dims.reverse()  # 倒序特征层通道数列表
        for i in range(len(self.hidden_dims) - 1):  # 循环特征层通道数列表
            layers += [nn.ConvTranspose2d(self.hidden_dims[i], self.hidden_dims[i + 1], 3, 2, 1, 1),  # 添加transconv
                       nn.BatchNorm2d(self.hidden_dims[i + 1]),  # 添加bn
                       nn.LeakyReLU()]  # 添加leakyrelu
        layers += [nn.ConvTranspose2d(self.hidden_dims[-1], self.hidden_dims[-1], 3, 2, 1, 1),  # 添加transconv
                   nn.BatchNorm2d(self.hidden_dims[-1]),  # 添加bn
                   nn.LeakyReLU(),  # 添加leakyrelu
                   nn.Conv2d(self.hidden_dims[-1], img_size[0], 3, 1, 1),  # 添加conv
                   nn.Tanh()]  # 添加tanh
        self.decoder = nn.Sequential(*layers)  # 编码器Decoder模型结构
 
    def encode(self, x):  # 定义编码过程
        result = self.encoder(x)  # Encoder结构,(n,1,32,32)-->(n,512,1,1)
        result = torch.flatten(result, 1)  # 将特征层转化为特征向量,(n,512,1,1)-->(n,512)
        mu = self.fc_mu(result)  # 计算分布均值mu,(n,512)-->(n,128)
        log_var = self.fc_var(result)  # 计算分布方差的对数log(var),(n,512)-->(n,128)
        return [mu, log_var]  # 返回分布的均值和方差对数
 
    def decode(self, z):  # 定义解码过程
        y = self.decoder_input(z).view(-1, self.hidden_dims[0], self.h,
                                       self.w)  # 将采样变量Z转化为特征向量,再转化为特征层,(n,128)-->(n,512)-->(n,512,1,1)
        y = self.decoder(y)  # decoder结构,(n,512,1,1)-->(n,1,32,32)
        return y  # 返回生成样本Y
 
    def reparameterize(self, mu, log_var):  # 重参数技巧
        std = torch.exp(0.5 * log_var)  # 分布标准差std
        eps = torch.randn_like(std)  # 从标准正态分布中采样,(n,128)
        return mu + eps * std  # 返回对应正态分布中的采样值
 
    def forward(self, x):  # 前传函数
        mu, log_var = self.encode(x)  # 经过编码过程,得到分布的均值mu和方差对数log_var
        z = self.reparameterize(mu, log_var)  # 经过重参数技巧,得到分布采样变量Z
        y = self.decode(z)  # 经过解码过程,得到生成样本Y
        return [y, x, mu, log_var]  # 返回生成样本Y,输入样本X,分布均值mu,分布方差对数log_var
 
    def sample(self, n, cuda):  # 定义生成过程
        z = torch.randn(n, self.latent_dim)  # 从标准正态分布中采样得到n个采样变量Z,长度为latent_dim
        if cuda:  # 如果使用cuda
            z = z.cuda()  # 将采样变量Z加载到GPU
        images = self.decode(z)  # 经过解码过程,得到生成样本Y
        return images  # 返回生成样本Y
 
 
def loss_fn(y, x, mu, log_var):  # 定义损失函数
    recons_loss = F.mse_loss(y, x)  # 重建损失,MSE
    kld_loss = torch.mean(0.5 * torch.sum(mu ** 2 + torch.exp(log_var) - log_var - 1, 1), 0)  # 分布损失,正态分布与标准正态分布的KL散度
    return recons_loss + w * kld_loss  # 最终损失由两部分组成,其中分布损失需要乘上一个系数w
 
 
if __name__ == "__main__":
    total_epochs = 100  # epochs
    batch_size = 64  # batch size
    lr = 5e-4  # lr
    w = 0.00025  # kld_loss的系数w
    num_workers = 8  # 数据加载线程数
    image_size = 32  # 图片尺寸
    image_channel = 1  # 图片通道
    latent_dim = 128  # 采样变量Z长度
    sample_images_dir = "sample_images"  # 生成样本示例存放路径
    train_dataset_dir = "../dataset/mnist"  # 训练样本存放路径
 
    os.makedirs(sample_images_dir, exist_ok=True)  # 创建生成样本示例存放路径
    os.makedirs(train_dataset_dir, exist_ok=True)  # 创建训练样本存放路径
    cuda = True if torch.cuda.is_available() else False  # 如果cuda可用,则使用cuda
    img_size = (image_channel, image_size, image_size)  # 输入样本形状(1,32,32)
 
    vae = VAE(img_size, latent_dim)  # 实例化VAE模型,传入输入样本形状与采样变量长度
    if cuda:  # 如果使用cuda
        vae = vae.cuda()  # 将模型加载到GPU
    # dataset and dataloader
    transform = transforms.Compose(  # 图片预处理方法
        [transforms.Resize(image_size),  # 图片resize,(28x28)-->(32,32)
         transforms.ToTensor(),  # 转化为tensor
         transforms.Normalize([0.5], [0.5])]  # 标准化
    )
    dataloader = DataLoader(  # 定义dataloader
        dataset=datasets.MNIST(root=train_dataset_dir,  # 使用mnist数据集,选择数据路径
                               train=True,  # 使用训练集
                               transform=transform,  # 图片预处理
                               download=True),  # 自动下载
        batch_size=batch_size,  # batch size
        num_workers=num_workers,  # 数据加载线程数
        shuffle=True  # 打乱数据
    )
    # optimizer
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)  # 使用Adam优化器
    # train loop
    for epoch in range(total_epochs):  # 循环epoch
        total_loss = 0  # 记录总损失
        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{total_epochs}", postfix=dict,
                    miniters=0.3)  # 设置当前epoch显示进度
        for i, (img, _) in enumerate(dataloader):  # 循环iter
            if cuda:  # 如果使用cuda
                img = img.cuda()  # 将训练数据加载到GPU
            vae.train()  # 模型开始训练
            optimizer.zero_grad()  # 模型清零梯度
            y, x, mu, log_var = vae(img)  # 输入训练样本X,得到生成样本Y,输入样本X,分布均值mu,分布方差对数log_var
            loss = loss_fn(y, x, mu, log_var)  # 计算loss
            loss.backward()  # 反向传播,计算当前梯度
            optimizer.step()  # 根据梯度,更新网络参数
            total_loss += loss.item()  # 累计loss
            pbar.set_postfix(**{"Loss": loss.item()})  # 显示当前iter的loss
            pbar.update(1)  # 步进长度
        pbar.close()  # 关闭当前epoch显示进度
        print("total_loss:%.4f" %
              (total_loss / len(dataloader)))  # 显示当前epoch训练完成后,模型的总损失
        vae.eval()  # 模型开始验证
        sample_images = vae.sample(25, cuda)  # 获得25个生成样本
        save_image(sample_images.data, "%s/ep%d.png" % (sample_images_dir, (epoch + 1)), nrow=5,
                   normalize=True)  # 保存生成样本示例(5x5)

其中计算KLloss的代码的解释如下:

代码的目标是计算变分自编码器(VAE)中近似后验分布q(z∣x) 和标准正态分布 p(z) 之间的KL散度。KL散度公式的具体计算步骤如下:

1. mu ** 2

计算均值的平方项: μ2 这个项是为了衡量均值偏离零的程度。

2. torch.exp(log_var)

对数方差取指数,以获得实际的方差: exp⁡(log⁡(σ2))=σ2 这个项衡量方差的大小。

3. - log_var

减去对数方差: −log⁡(σ2) 这个项衡量分布的扩展程度。

4. - 1

减去 1,是KL散度公式中的常数项,用于归一化。

将这些项加在一起:

μ2+exp⁡(log⁡(σ2))−log⁡(σ2)−1

5. torch.sum(..., 1)

对所有维度求和,计算单个样本的KL散度: ∑(μ2+σ2−log⁡(σ2)−1) 这一步是将每个样本的所有维度的KL散度加起来。

6. 0.5 * ...

乘以 0.5,因KL散度公式中有系数 0.5: 0.5×∑(μ2+σ2−log⁡(σ2)−1)

7. torch.mean(..., 0)

对所有样本取平均,得到最终的KL散度损失: mean(0.5×∑(μ2+σ2−log⁡(σ2)−1))

整个公式的作用是计算出近似后验分布 q(z∣x) 和标准正态分布 p(z) 之间的KL散度,该散度表示了两个分布之间的差异。这种损失通常用于变分自编码器(VAE)训练中,确保生成的潜在变量分布接近标准正态分布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/750677.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于盲信号处理的声音分离-基于改进的信息最大化的ICA算法

基于信息最大化的ICA算法的主要依据是使输入端与输出端的互信息达到最大,且输出各个分量之间的相关性最小化,即输出各个分量之间互信息量最小化,其算法的系统框图如图所示。 基于信息最大化的ICA算法的主要依据是使输入端与输出端的互信息达到…

java基于ssm+jsp 弹幕视频网站

1前台首页功能模块 弹幕视频网站,在弹幕视频网站可以查看首页、视频信息、商品信息、论坛信息、我的、跳转到后台、购物车、客服等内容,如图1所示。 图1前台首页界面图 登录,通过登录填写账号、密码等信息进行登录操作,如图2所示…

Sparse4D v1

Sparse4D: Multi-view 3D Object Detection with Sparse Spatial-Temporal Fusion 单位:地平线 GitHub:https://github.com/HorizonRobotics/Sparse4D 论文:https://arxiv.org/abs/2211.10581 时间:2022-11 找博主项目讨论方…

【MotionCap】conda 链接缺失的cuda库

conda 安装的环境不知道为啥python 环境里的 一些cuda库是空的要自己链接过去。ln 前面是已有的,后面是要新创建的 ln -s <path to the file/folder to be linked> cuda 有安装 libcublas 已经在cuda中 (base) zhangbin@ubuntu-server:~/miniconda3/envs/ai-mocap/lib/…

ARM芯片架构(RTOS)

前言&#xff1a;笔记韦东山老师的rtos教程&#xff0c;连接放在最后 #ARM介绍 arm芯片属于精简指令集risc&#xff0c;所用的指令比较简单&#xff0c;ARM架构是一种精简指令集&#xff08;RISC&#xff09;架构&#xff0c;广泛应用于移动设备、嵌入式系统、物联网等领域。AR…

40.设计HOOK引擎的好处

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 上一个内容&#xff1a;39.右键弹出菜单管理游戏列表 以 39.右键弹出菜单管理游戏列表 它的代码为基础进行修改 效果图&#xff1a; 实现步骤&#xff1a; 首…

吴恩达2022机器学习专项课程C2W3:2.27 选修_数据倾斜

目录 处理不平衡数据集1.分类需求描述2.计算精确率和召回率 权衡精确率和召唤率1.手动调整阈值2.F1分数 总结 处理不平衡数据集 1.分类需求描述 如果你在处理一个机器学习应用&#xff0c;其中正例和负例的比例&#xff08;用于解决分类问题&#xff09;非常不平衡&#xff0…

图像大小调整(缩放)

尺寸调整前尺寸调整前 1、背景介绍 在深度学习中&#xff0c;将图像调整到固定尺寸&#xff08;如28x28像素&#xff09;的操作是非常常见的&#xff0c;尤其是在处理诸如图像分类、物体检测和图像分割等任务时。这种操作有几个重要原因&#xff1a; 标准化输入&#xff1a;许…

MYSQL 四、mysql进阶 5(InnoDB数据存储结构)

一、数据库的存储结构&#xff1a;页 索引结构给我们提供了高效的索引方式&#xff0c;不过索引信息以及数据记录都是保存在文件上的&#xff0c;确切说时存储在页结构中&#xff0c;另一方面&#xff0c;索引是在存储引擎中实现的&#xff0c;Mysql服务器上的存储引擎负责对表…

当中年男人的觉越来越少 他突然半夜买台电脑(30+岁仿真工程师买电脑心得)

仿真工程师的购机分析&#xff0c;游戏本、移动工作站还是台式机&#xff1f; 认清自己的需求。 现状。现在有一个19年买的华为matebook14、i5第八代低压U&#xff0c;8G内存。还好有SSD当虚拟内存&#xff0c;要不开网页估计都得卡住。媳妇还有台i7、16G的matebook&#xff…

MC进样管PFA塑料管NEPTUNE Plus多接收等离子质谱仪配套管子

PFA进样管可适配Neptune plus多接收器等离子质谱仪&#xff08;MC-ICP-MS&#xff09;&#xff0c;广泛应用于地球化学、核保障、环境科学、金属组学领域&#xff0c;在生物、物理、化学、材料等多个学科的交叉方向也有良好的应用前景。 外观半透明&#xff0c;便于观察管内情…

基于LangChain构建RAG应用

前言 Hello&#xff0c;大家好&#xff0c;我是GISer Liu&#x1f601;&#xff0c;一名热爱AI技术的GIS开发者&#xff0c;上一篇文章中我们详细介绍了RAG的核心思想以及搭建向量数据库的完整过程&#xff1b;&#x1f632; 本文将基于上一篇文章的结果进行开发&#xff0c;主…

最长回文串

描述&#xff1a; 最长回文串 思路&#xff1a; 统计每个字母出现次数&#xff0c;如果是偶数&#xff0c;ret x;如果是存在奇数的话&#xff0c;就可以放在中间&#xff0c;ret 1. 代码&#xff1a; class Solution { public:int hash[200];int longestPalindrome(str…

Elasticsearch8.x聚合查询全面指南:从理论到实战

聚合查询的概念 聚合查询&#xff08;Aggregation Queries&#xff09;是Elasticsearch中用于数据汇总和分析的查询类型。它不同于普通的查询&#xff0c;而是用于执行各种聚合操作&#xff0c;如计数、求和、平均值、最小值、最大值、分组等。 聚合查询的分类 分桶聚合&…

绘唐3是免费的吗?

绘唐科技是一家中国电子信息产品制造商和供应商&#xff0c;成立于2005年。公司主要经营智能硬件、智能穿戴设备、智能家居设备和智能交通设备等领域的产品开发和销售。绘唐科技拥有强大的研发团队和制造能力&#xff0c;能够为客户提供定制化的产品解决方案。 绘唐科技的产品种…

【Spring】Spring学习笔记

Spring数据库 Spring JDBC 环境准备 创建Spring项目, 添加以下依赖 H2 Database: 用于充当嵌入式测试数据库JDBC API: 用于连接数据库Lombok: 用于简化pojo的编写 然后添加配置文件: spring.output.ansi.enabledALWAYS spring.datasource.username*********** spring.dataso…

3d怎么把歪的模型摆正?---模大狮模型网

在进行3D建模过程中&#xff0c;有时候会遇到模型出现歪曲或者旋转不正确的情况&#xff0c;这可能会影响到后续的设计和渲染效果。因此&#xff0c;学会将歪曲的模型摆正是一个非常重要的技巧。模大狮将介绍几种常用的方法&#xff0c;帮助您有效地将歪曲的3D模型摆正&#xf…

抖音团购达人实战营,抖音团购达人从0-1教程(11节课)

课程目录&#xff1a; 1-团购达人先导课1.mp4 2-账号措建.mp4 2-账号搭建_1.mp4 3-开通团购达人_1.mp4 4-账号养号涨粉套路_1.mp4 5-团购选品正确姿势_1.mp4 6-短视频之混剪课_1.mp4 7-短视频之图文课_1.mp4 8-短视频之口播课_1.mp4 9-短视频运营策略_1.mp4 10-团购…

纯血鸿蒙Beta版本发布,中国华为,站起来了!

2024年6月21日至23日&#xff0c;华为开发者大会2024&#xff08;HDC 2024&#xff09;于东莞盛大举行。 此次大会不仅在会场设置了包括鸿蒙原生应用、统一生态统一互联等在内的11个展区&#xff0c;以供展示HarmonyOS NEXT的强大实力&#xff0c;还对外宣布了HarmonyOS的最新进…

初探 YOLOv8(训练参数解析)

文章目录 1、前言2、Backbone网络3、YOLOv8模型训练代码3.1、模型大小选择3.2、训练参数设置 4、训练参数说明5、目标检测系列文章 1、前言 YOLO 因为性能强大、消耗算力较少&#xff0c;一直以来都是实时目标检测领域的主要范式。该框架被广泛用于各种实际应用&#xff0c;包…