AE——重构数字(Pytorch+mnist)

1、简介

  • AE(自编码器)由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。 
  • 【注】输出的10张数字图像是输入的测试图像的第一批次。

2、代码

  • import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    
    
    # 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
    class AE(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
            self.encoder_hidden_layer = nn.Linear(
                in_features=kwargs["input_shape"], out_features=128
            )  # 编码器隐藏层
            self.encoder_output_layer = nn.Linear(
                in_features=128, out_features=128
            )  # 编码器输出层
            self.decoder_hidden_layer = nn.Linear(
                in_features=128, out_features=128
            )  # 解码器隐藏层
            self.decoder_output_layer = nn.Linear(
                in_features=128, out_features=kwargs["input_shape"]
            )  # 解码器输出层
    
        # 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成
        def forward(self, features):
            activation = self.encoder_hidden_layer(features)
            activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值
            code = self.encoder_output_layer(activation)
            code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内
            activation = self.decoder_hidden_layer(code)
            activation = torch.relu(activation)
            activation = self.decoder_output_layer(activation)
            reconstructed = torch.sigmoid(activation)
            return reconstructed
    
    
    if __name__ == '__main__':
        # 设置批大小、学习周期和学习率
        batch_size = 512
        epochs = 30
        learning_rate = 1e-3
    
        # 载入 MNIST 数据集中的图片进行训练
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量
    
        train_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=True, transform=transform, download=True
        )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True
    
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据
    
        # 在使用定义的 AE 类之前,有以下事情要做:
        # 配置要在哪个设备上运行
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
        # 建立 AE 模型并载入到 CPU 设备
        model = AE(input_shape=784).to(device)
    
        # Adam 优化器,学习率 10e-3
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
        # 使用均方误差(MSE)损失函数
        criterion = nn.MSELoss()
    
        # 在GPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
        # 训练:
        for epoch in range(epochs):
            loss = 0
            for batch_features, _ in train_loader:
                # 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备
                batch_features = batch_features.view(-1, 784).to(device)
    
                # 梯度设置为 0,因为 torch 会累加梯度
                optimizer.zero_grad()
    
                # 计算重构
                outputs = model(batch_features)
    
                # 计算训练重建损失
                train_loss = criterion(outputs, batch_features)
    
                # 计算累积梯度
                train_loss.backward()
    
                # 根据当前梯度更新参数
                optimizer.step()
    
                # 将小批量训练损失加到周期损失中
                loss += train_loss.item()
    
            # 计算每个周期的训练损失
            loss = loss / len(train_loader)
    
            # 显示每个周期的训练损失
            print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
    
        # 用训练过的自编码器提取一些测试用例来重构
        test_dataset = torchvision.datasets.MNIST(
            root="~/torch_datasets", train=False, transform=transform, download=True
        )  # 加载 MNIST 测试数据集
    
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=10, shuffle=False
        )  # 创建一个测试数据加载器
    
        test_examples = None
    
        # 通过循环遍历测试数据加载器,获取一个批次的图像数据
        with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算
            for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据
                batch_features = batch_features[0]  # 获取当前批次的图像数据
                test_examples = batch_features.view(-1, 784).to(
                    device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上
                reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像
                break
    
        # 试着用训练过的自编码器重建一些测试图像
        with torch.no_grad():
            number = 10  # 设置要显示的图像数量
            plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)
            for index in range(number):  # 遍历要显示的图像数量
                # 显示原始图
                ax = plt.subplot(2, number, index + 1)
                plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
                plt.gray()
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
    
                # 显示重构图
                ax = plt.subplot(2, number, index + 1 + number)
                plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
                plt.gray()
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            plt.savefig('reconstruction_results.png')  # 保存图像
            plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/510163.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

篮球竞赛预约平台的设计与实现|Springboot+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW,文末可获取本项目的所有资料。 推荐阅读300套最新项目持续更新中..... 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 2024年56套包含ja…

道本科技智慧合规助力企业转型升级

在当今这个快速变化的商业世界里,企业合规管理已经从一项基本的监管要求转变为推动企业持续发展的关键动力。合规不仅是避免法律麻烦的盾牌,它还充当着引领企业向更高效、更可靠和更可持续方向发展的催化剂。而在实现这一目标的过程中,智慧合…

1区SCI,1个月左右录用,1周见刊,各项指标优秀,强推!

毕业推荐 SSCI • 社科类,分区稳步上升(最快13天录用) IEEE: • 计算机类,1区(TOP),CCF推荐 SCIE • 计算机工程类,CCF推荐(最快16天录用) 计算机类 ● 好刊解读 …

websocket 局域网 webrtc 一对一 多对多 视频通话 的示例

基本介绍 WebRTC(Web Real-Time Communications)是一项实时通讯技术,它允许网络应用或者站点,在不借助中间媒介的情况下,建立浏览器之间点对点(Peer-to-Peer)的连接,实现视频流和&am…

k8s存储学习 emptyDir 卷

官网描述: 对于定义了emptyDir卷的Pod,在Pod被指派到某节点时此卷会被创建。就像其名称所表示的那样,emptyDir卷最初是空的。尽管Pod中的容器挂载emptyDir卷的路径可能相同也可能不容。但这些容器都可以读写emptyDir卷中相同的文件。当Pod因…

OpenHarmony实战:使用宏、std::bind 巧妙实现进出函数日志打印

背景 我们始终渴望了解模块的调用、时序逻辑,每个人都会轻易地想到在函数的入口打印一条进入 enter 相关的日志,在函数的出口打印一条离开 leave 相关的日志。不能有遗漏,我们会复制这条日志到所有关心的函数中,为了表明是哪个模…

桶排序---

1、算法概念 桶排序:一种非比较的排序算法。桶排序采用了一些分类和分治的思想,把元素的值域分成若干段,每一段对应一个桶。在排序的时候,首先把每一个元素放到其对应的桶中,再对每一个桶中的元素分别排序&#xff0c…

【数据库系统工程师】软考2024年5月报名流程及注意事项

2024年5月软考数据库系统工程师报名入口: 中国计算机技术职业资格网(http://www.ruankao.org.cn/) 2024年软考报名时间暂未公布,考试时间上半年为5月25日到28日,下半年考试时间为11月9日到12日。不想错过考试最新消息…

分享一种快速移植OpenHarmony Linux内核的方法

移植概述 本文面向希望将 OpenHarmony 移植到三方芯片平台硬件的开发者,介绍一种借助三方芯片平台自带 Linux 内核的现有能力,快速移植 OpenHarmony 到三方芯片平台的方法。 移植到三方芯片平台的整体思路 内核态层和用户态层 为了更好的解释整个内核…

Unity 使用 IL2CPP 发布项目

一、为什么用 IL2CPP Unity的IL2CPP(Intermediate Language to C)是一个编译技术,它将C#代码转换为C代码,然后再编译成平台相关的二进制代码。IL2CPP提供了几个优点,特别是在性能和跨平台部署方面。以下是IL2CPP的一些…

未来的智能起航:探索AI技术的创业新天地

在科技飞速发展的当今世界,人工智能(AI)已经成为一个热门话题。不再是科幻小说中的概念,AI正逐渐融入我们的生活和工作中,开创了全新的创业市场和机会。人工智能(AI)的飞速发展不仅引领了科技的…

iOS苹果签名共享签名是什么以及如何获取?

哈喽,大家好呀,咕噜淼淼又来和大家见面啦,最近有很多朋友都来向我咨询共享签名iOS苹果IPA共享签名是什么,针对这个问题,淼淼来解答一下大家的疑惑并告诉大家iOS苹果ipa共享签名需要如何获取。 现在苹果签名在市场上的…

boost库搜索引擎

文章目录 0. 前言1. 搜索引擎原理2. 技术栈和项目环境3. 正排索引和倒排索引3.1 正排索引3.2 倒排索引3.3 模拟查找 4. 获取数据源5. 数据清洗5.1 保存路径5.2 解析文件提取标题提取内容构造url 5.4 保存内容 6. 建立索引6.1 建立正排索引6.2 建立倒排索引6.3 构建索引 7. 搜索…

为什么高校都在做数字化转型?智慧校园建设该如何落地?

作为高校的多年合作伙伴,很多时候我们在与各大高校信息处的老师对接和联系的时候,常常听到部分老师在头疼学校的信息化管理,从而提出需求。归纳下来,高校的信息管理主要面对着三大难题: 第一,资源管理效率…

精彩解读:短链接应用全方位探究

title: 精彩解读:短链接应用全方位探究 date: 2024/4/2 17:44:50 updated: 2024/4/2 17:44:50 tags: 短链接定义映射算法原理简洁美化优势工作流程解析安全隐私保护商业营销应用技术趋势发展 1. 短链接的定义和原理 短链接是一种将长网址转换为短网址的服务&#…

SeLinux 的编译逻辑

在Android 11 init进程对Selinux的处理一文中,我们知道,在init进程对Selinux的处理过程中,会将precompiled_sepolicy或者动态编译相关目录下的cil文件得到的compiled_sepolicy写入给内核。那么precompiled_sepolicy文件和cil文件是从哪里来的…

自己动手写数据库:基于哈希的静态索引设计

数据库设计中有一项至关重要的技术难点,那就是给定特定条件进行查询时,我们需要保证速度尽可能快。假设我们有一个 STUDENT 表,表中包含学生名字,年龄,专业等字段,当我们要查询给定年龄数值的记录&#xff…

如何一键展示全平台信息?Python手把手教你搭建自己的自媒体展示平台

前言 灵感源于之前写过的Github中Readme.md中可以插入自己的js图片和动态api解析模块&#xff0c;在展示方面十分的美观&#xff1a; 这方面原理可以简化为&#xff0c;在Markdown中&#xff0c;你可以使用HTML标签来添加图像&#xff0c;就像这样&#xff1a; <tr><…

轻松设置Facebook自动隐藏评论和删除评论功能

Facebook作为海外营销的最大流量平台之一&#xff0c;是很多跨境卖家争夺的市场&#xff0c;希望可以通过Facebook这个全球性的平台来推广自己的产品或服务。身处这个竞争激烈的市场&#xff0c;任何一条负面评论或不当言论出现在你的品牌页面上都可能影响到品牌形象&#xff0…

晶核新手必备攻略,干货满满!

晶核游戏以其独特的玩法和丰富的内容吸引着众多玩家。然而&#xff0c;对于一些追求效率和资源的玩家来说&#xff0c;单开游戏往往难以满足他们的需求。多开游戏成为了一个不错的选择&#xff0c;它能帮助玩家更快地获取资源&#xff0c;提升账号实力。下面将为大家分享一些晶…