PyTorch深度学习实战(25)——自编码器

PyTorch深度学习实战(25)——自编码器

    • 0. 前言
    • 1. 自编码器
    • 2. 使用 PyTorch 实现自编码器
    • 小结
    • 系列链接

0. 前言

自编码器 (Autoencoder) 是一种无监督学习的神经网络模型,用于数据的特征提取和降维,它由一个编码器 (Encoder) 和一个解码器 (Decoder) 组成,通过将输入数据压缩到低维表示,然后再重构出原始数据。在本节中,我们将学习如何使用自编码器,以在低维空间表示图像,学习以较少的维度表示图像有助于修改图像,可以利用低维表示来生成新图像。

1. 自编码器

我们已经学习了通过输入图像及其相应标签训练模型来对图像进行分类,进行分类的前提是是拥有带有类别标签的数据集。假设数据集中没有图像对应的标签,如果需要根据图像的相似性对图像进行聚类,在这种情况下,自编码器可以方便地识别和分组相似的图像。
自动编码器将图像作为输入,将其存储在低维空间中,并尝试通过解码过程输出相同图像,而不使用其他标签,因此 AutoEncoder 中的 Auto 表示能够再现输入。但是,如果我们只需要简单的在输出中重现输入,就不需要神经网络了,只需要将输入简单地原样输出即可。自编码器的作用在于它能够以较低维度对图像信息进行编码,因此称为编码器(将图像信息编码至较低维空间中),因此,相似的图像具有相似的编码。此外,解码器致力于根据编码矢量重建原始图像,以尽可能重现输入图像:

自编码器

假设模型输入图像是 MNIST 手写数字图像,模型输出图像与输入图像相同。最中间的网络层是编码层,也称瓶颈层 (bottleneck layer),输入和瓶颈层之间发生的操作表示编码器,瓶颈层和输出之间的操作表示解码器。
通过瓶颈层,我们可以在低维空间中表示图像,也可以重建原始图像,换句话说,利用自编码器中的瓶颈层能够解决识别相似图像以及生成新图像的问题,具体而言:

  • 具有相似瓶颈层值(编码表示,也称潜编码)的图像可能彼此相似
  • 通过改变瓶颈层的节点值,可以改变输出图像。

2. 使用 PyTorch 实现自编码器

本节中,使用 PyTorch 构建自编码器,我们使用 MNIST 数据集训练此网络,MNIST 数据集中是一个手写数字的图像数据集,包含了 6 万个 28x28 像素的训练样本和 1 万个测试样本。

(1) 导入相关库并定义设备:

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
from matplotlib import pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'

(2) 指定图像转换方法:

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.Lambda(lambda x: x.to(device))
])

通过以上代码,将图像转换为张量,对其进行归一化,然后将其传递到设备中。

(3) 创建训练和验证数据集:

trn_ds = MNIST('MNIST/', transform=img_transform, train=True, download=True)
val_ds = MNIST('MNIST/', transform=img_transform, train=False, download=True)

(4) 定义数据加载器:

batch_size = 256
trn_dl = DataLoader(trn_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

(5) 定义网络架构,在 __init__ 方法中定义了使用编码器-解码器架构的 AutoEncoder 类,以及瓶颈层的维度,latent_dimforward 方法,并打印模型摘要信息。

定义 AutoEncoder 类和包含编码器、解码器以及瓶颈层维度的 __init__ 方法:

class AutoEncoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latend_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(True),
            nn.Linear(128, 64), nn.ReLU(True), 
            #nn.Linear(64, 12),  nn.ReLU(True), 
            nn.Linear(64, latent_dim))
        self.decoder = nn.Sequential(
            #nn.Linear(latent_dim, 12), nn.ReLU(True),
            nn.Linear(latent_dim, 64), nn.ReLU(True),
            nn.Linear(64, 128), nn.ReLU(True), 
            nn.Linear(128, 28 * 28), nn.Tanh())

定义前向计算方法 forward

    def forward(self, x):
        x = x.view(len(x), -1)
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.view(len(x), 1, 28, 28)
        return x

打印模型摘要信息:

from torchsummary import summary
model = AutoEncoder(3).to(device)
print(summary(model, (1,28,28)))

模型架构信息输出如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 128]         100,480
              ReLU-2                  [-1, 128]               0
            Linear-3                   [-1, 64]           8,256
              ReLU-4                   [-1, 64]               0
            Linear-5                    [-1, 3]             195
            Linear-6                   [-1, 64]             256
              ReLU-7                   [-1, 64]               0
            Linear-8                  [-1, 128]           8,320
              ReLU-9                  [-1, 128]               0
           Linear-10                  [-1, 784]         101,136
             Tanh-11                  [-1, 784]               0
================================================================
Total params: 218,643
Trainable params: 218,643
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.83
Estimated Total Size (MB): 0.85
----------------------------------------------------------------

从前面的输出中,可以看到 Linear: 2-5 层是瓶颈层,将每张图像都表示为一个 3 维向量;此外,解码器使用瓶颈层中的 3 维向量重建原始图像。

(6) 定义函数在批数据上训练模型 train_batch()

def train_batch(input, model, criterion, optimizer):
    model.train()
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, input)
    loss.backward()
    optimizer.step()
    return loss

(7) 定义在批数据上进行模型验证的函数 validate_batch()

@torch.no_grad()
def validate_batch(input, model, criterion):
    model.eval()
    output = model(input)
    loss = criterion(output, input)
    return loss

(8) 定义模型、损失函数和优化器:

model = AutoEncoder(3).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)

(9) 训练模型:

num_epochs = 20
train_loss_epochs = []
val_loss_epochs = []
for epoch in range(num_epochs):
    N = len(trn_dl)
    trn_loss = []
    val_loss = []
    for ix, (data, _) in enumerate(trn_dl):
        loss = train_batch(data, model, criterion, optimizer)
        pos = (epoch + (ix+1)/N)
        trn_loss.append(loss.item())
    train_loss_epochs.append(np.average(trn_loss))

    N = len(val_dl)
    for ix, (data, _) in enumerate(val_dl):
        loss = validate_batch(data, model, criterion)
        pos = epoch + (1+ix)/N
        val_loss.append(loss.item())
    val_loss_epochs.append(np.average(val_loss))

(10) 可视化训练期间模型的训练和验证损失随时间的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型监测
(11) 使用测试数据集 val_ds 验证模型:

for _ in range(5):
    ix = np.random.randint(len(val_ds))
    im, _ = val_ds[ix]
    _im = model(im[None])[0]
    plt.subplot(121)
    # fig, ax = plt.subplots(1,2,figsize=(3,3)) 
    plt.imshow(im[0].detach().cpu(), cmap='gray')
    plt.title('input')
    plt.subplot(122)
    plt.imshow(_im[0].detach().cpu(), cmap='gray')
    plt.title('prediction')
plt.show()

重建图像
我们可以看到,即使瓶颈层只有三个维度,网络也可以非常准确地重现输入,但是图像并不像预期的那样清晰,主要是因为瓶颈层中的节点数量过少。具有不同瓶颈层大小 (2351050) 的网络训练后,可视化重建的图像如下所示:

def train_aec(latent_dim):
    model = AutoEncoder(latent_dim).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)

    num_epochs = 20
    train_loss_epochs = []
    val_loss_epochs = []

    for epoch in range(num_epochs):
        N = len(trn_dl)
        trn_loss = []
        val_loss = []
        for ix, (data, _) in enumerate(trn_dl):
            loss = train_batch(data, model, criterion, optimizer)
            pos = (epoch + (ix+1)/N)
            trn_loss.append(loss.item())
        train_loss_epochs.append(np.average(trn_loss))

        N = len(val_dl)
        trn_loss = []
        val_loss = []
        for ix, (data, _) in enumerate(val_dl):
            loss = validate_batch(data, model, criterion)
            pos = epoch + (1+ix)/N
            val_loss.append(loss.item())
        val_loss_epochs.append(np.average(val_loss))
    epochs = np.arange(num_epochs)+1
    plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
    plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss')
    plt.title('Training and Test loss over increasing epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid('off')
    plt.show()
    return model

aecs = [train_aec(dim) for dim in [50, 2, 3, 5, 10]]

for _ in range(10):
    ix = np.random.randint(len(val_ds))
    im, _ = val_ds[ix]
    plt.subplot(1, len(aecs)+1, 1)
    plt.imshow(im[0].detach().cpu(), cmap='gray')
    plt.title('input')
    idx = 2
    for model in aecs:
        _im = model(im[None])[0]
        plt.subplot(1, len(aecs)+1, idx)
        plt.imshow(_im[0].detach().cpu(), cmap='gray')
        plt.title(f'prediction\nlatent-dim:{model.latend_dim}')
        idx += 1
plt.show()

请添加图片描述

随着瓶颈层中向量维度的增加,重建图像的清晰度逐渐提高。

小结

自编码器是一种无监督学习的神经网络模型,用于数据的特征提取和降维。它由编码器和解码器组成,通过将输入数据压缩到低维表示,并尝试重构出原始数据来实现特征提取和数据的降维。自编码器的训练过程中,目标是最小化输入数据与重构数据之间的重建误差,以使编码器捕捉到数据的关键特征。自编码器在无监督学习和深度学习中扮演着重要的角色,能够从数据中学习有用的特征,并为后续的机器学习任务提供支持。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/240604.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

招不到人?用C语言采集系统批量采集简历

虽说现在大环境不太好,很多人面临着失业再就业风险,包括企业则面临着招人人,找对口专业难得问题。想要找到适合自己公司的人员,还要得通过爬虫获取筛选简历才能从茫茫人海中找到公司得力干将。废话不多说,直接开整。 1…

Github仓库远程操作——简单版

Github远程操作 github仓库简单的远程操作,更多复杂的功能请参考github官方文档 标题 Github远程操作添加公钥到githubGithub仓库远程操作 远程操作之前,先添加本地的公钥到github 添加公钥到github 创建本地ssh公私钥:使用powershell或者gi…

(1)(1.7) HOTT telemetry

文章目录 前言 1 布线和设置 2 参数说明 前言 Plane-4.0.0(及更高版本)、Copter-4.0.4(及更高版本)和 Rover-4.1.0(及更高版本)支持 Graupner HOTT 遥测技术。 1 布线和设置 与自动驾驶仪的连接可通过…

Jenkins项目部署CICD

目录 什么是CI/CD 常用 CI/CD 工具 主要步骤 1、点击新建任务 2、构建自由风格项目 3、填写内容 ①、General 1)描述 2)丢弃旧的构建 ②、源码管理 1)Repository URL 2)Credentials 3)Branches to build…

破局:国内市场确实存在“消费升级”和“消费降级”,3.0全新新零售商业模式

国内市场确实存在“消费升级”和“消费降级”两个趋势,这是由于不同消费者群体的需求和购买力存在差异。消费升级主要发生在高端市场,消费者愿意为高品质、高价值、高价格的商品和服务付出更多。而消费降级则主要发生在中低端市场,消费者更加…

【教程】Ipa Guard为iOS应用提供免费加密混淆方案

概述:使用ios加固工具对ios代码保护,保护ios项目中的核心代码, #ipagurd年终大促百厂联动暖冬特惠,超多软控件立享惊喜优惠>> ​ 简介 iOS加固保护是直接针对ios ipa二进制文件的保护技术,可以对iOS APP中的可…

git根据commit id强制推送,撤销远程仓库代码

由于将把不用发版的需求合并上去了,现在想撤回,可以根据以下操作进行 注意撤回、强制推送有风险,记得强制撤回前,备份好代码 确保本地仓库中包含你想要推送的 commit: 这里你要经常使用命令进行操作的话,就…

maui下sqlite演示增删改查

数据操作类 有分页 todoitemDatabase.cs: using SQLite; using TodoSQLite.Models;namespace TodoSQLite.Data {public class TodoItemDatabase{SQLiteAsyncConnection Database;public TodoItemDatabase(){}// 初始化数据库连接和表async Task Init(){if (Databa…

Java:TCP 通信方法(基本发送 + 接收)并 实现文件传输且反馈

TCP 通信编程 TCP:是一种可靠的网络协议,再通信两端都建立一个Socket对象。 通信之前要保证连接已经建立。 通过Socket产生IO流进行通信。 创建对象时,会连接服务器,连接不上,会报错。 所以,先运行服务端,再…

Triton算法服务部署:初识与试用【Hello world】

0. 写在前面 Triton Inference Server 是一款开源推理服务软件,可简化 AI 推理。其可以部署来自多个深度学习和机器学习框架的任何 AI 模型,包括 TensorRT、TensorFlow、PyTorch、ONNX、OpenVINO、Python、RAPIDS FIL 等。Triton 支持在 NVIDIA GPU、x8…

【C++】哈希表

文章目录 哈希概念哈希冲突哈希函数哈希表闭散列开散列 开散列与闭散列比较 正文开始前给大家推荐个网站,前些天发现了一个巨牛的 人工智能学习网站, 通俗易懂,风趣幽默,忍不住分享一下给大家。 点击跳转到网站。 哈希概念 顺…

微服务项目部署

启动rabbitmq \RabbitMQ\rabbitmq_server-3.8.2\sbin 找到你的安装路径 找到\sbin路径下执行这些命令即可 rabbitmqctl status //查看当前状态 rabbitmq-plugins enable rabbitmq_management //开启Web插件 rabbitmq-server start //启动服务 rabbitmq-server stop //停止服务…

不需要联网的ocr项目

地址 GitHub - plantree/ocr-pwa: A simple PWA for OCR, based on Tesseract. 协议 mit 界面 推荐理由 可以离线使用,隐私安全

python自动化测试实战 —— 自动化测试框架的实例

软件测试专栏 感兴趣可看:软件测试专栏 自动化测试学习部分源码 python自动化测试相关知识: 【如何学习Python自动化测试】—— 自动化测试环境搭建 【如何学习python自动化测试】—— 浏览器驱动的安装 以及 如何更…

【XR806开发板试用】基于FreeRTOS的SoftAp配网实现

1.环境搭建 由于电脑上之前就有开发其他设备用的ubuntu18.06虚拟机环境,就在此环境基础上进行开发。基本环境搭建参考官方文档进行: 全志XR806开发板开发环境搭建 2.功能实现 2.1设计思路 从官方下载的SDK开发包project/example目录下有基本功能实现…

扫盲运动—字节序

1 大端、小端字节序 术语“大端”和“小端”表示多个字节值的哪一端(小端或大端)存储在该值的起始地址。 大端:将高序字节存储在起始地址,这称为大端(big-endian)字节序小端:将低序字节存储在…

03-详解Nacos注册中心的配置步骤和功能

Nacos注册中心 服务注册到Nacos Nacos是SpringCloudAlibaba的组件也遵循SpringCloud中定义的服务注册和服务发现规范,因此使用Nacos与使用Eureka对于微服务来说并没有太大区别 主要差异就是依赖不同,服务地址不同 第一步: 在父工程cloud-demo模块的pom.xml文件中引入Spring…

现代信号处理实验:MATLAB实现LD算法进行AR估计

MATLAB实现LD算法进行AR估计 利用给定的一组样本数据估计一个平稳随机信号的功率谱密度称为功率谱估计,又称谱估计。谱估计的方法可以分成经典谱估计和现代谱估计。 经典谱估计又称为非参数化的谱估计,分为直接法和间接法。直接法是指直接计算样本数据…

C# WPF上位机开发(增强版绘图软件)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们写过一个绘图软件,不过那个比较简单,主要就是用鼠标模拟pen进行绘图。实际应用中,另外一种使用比较多的…

MySQL笔记-第18章_MySQL8其它新特性

视频链接:【MySQL数据库入门到大牛,mysql安装到优化,百科全书级,全网天花板】 文章目录 第18章_MySQL8其它新特性1. MySQL8新特性概述1.1 MySQL8.0 新增特性1.2 MySQL8.0移除的旧特性 2. 新特性1:窗口函数2.1 使用窗口…