【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现

关于

深度学习中,经常面临图片数据量较小的问题,此时,对数据进行增强,显得比较重要。传统的图片增强方法包括剪切,增加噪声,改变对比度等等方法,但是,对于后端任务的性能提升有限。所以,变分自编码器被用来实现深度数据增强。

变分自编码器的主要缺点在于生成图像过于平滑和模糊,图像细节重建不足。

常见的图像增强方法:https://www.tensorflow.org/tutorials/images/data_augmentation

工具

数据集下载地址: CIFAR-10 and CIFAR-100 datasets

方法实现

加载数据和必要的库函数
import tensorflow.compat.v1.keras.backend as K
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
import tensorflow_datasets as tfds
import keras
from keras.models import Model
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape


xtrain , ytrain = tfds.as_numpy(tfds.load('cifar10',split='train',batch_size=-1,as_supervised=True,))
xtest , ytest = tfds.as_numpy(tfds.load('cifar10',split='test',batch_size=-1,as_supervised=True,))
xtrain = (xtrain.astype('float32'))/255
xtest = (xtest.astype('float32'))/255

height=32
width=32
channels=3
print(f"Train Shape: {xtrain.shape},Test Shape: {xtest.shape}")
plt.imshow(xtrain[0])

编码器模型搭建
input_shape=(height,width,channels)
latent_dims=3072

input_img= Input(shape=input_shape, name='encoder_input')
x=Conv2D(128, 4, padding='same', activation='relu',strides=2)(input_img)
x=Conv2D(256, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(512, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(1024, 4, padding='same', activation='relu',strides=2)(x)
conv_shape = K.int_shape(x)
x=Flatten()(x)
x=Dense(3072, activation='relu')(x)
z_mean=Dense(latent_dims, name='latent_mean')(x)
z_sigma=Dense(latent_dims, name='latent_sigma')(x)

def sampler(args):
  z_mean, z_sigma = args
  eps = K.random_normal(shape=(K.shape(z_mean)[0], K.int_shape(z_mean)[1]))
  return z_mean + K.exp(z_sigma / 2) * eps


z = Lambda(sampler, output_shape=(latent_dims, ), name='z')([z_mean, z_sigma])

encoder = Model(input_img, [z_mean, z_sigma, z], name='encoder')
print(encoder.summary())

 解码器模型构建
decoder_input = Input(shape=(latent_dims, ), name='decoder_input')
x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
x = Conv2DTranspose(256, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(128, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(3, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
decoder = Model(decoder_input, x, name='decoder')
decoder.summary()
z_decoded = decoder(z)

class CustomLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        
        # Reconstruction loss (as we used sigmoid activation we can use binarycrossentropy)
        recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        
        # KL divergence
        kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mean) - K.exp(z_sigma), axis=-1)
        return K.mean(recon_loss + kl_loss)

    # add custom loss to the class
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        return x

 

整体模型构建
y = CustomLayer()([input_img, z_decoded])

vae = Model(input_img, y, name='vae')
vae.compile(optimizer='adam', loss=None)
vae.summary()

 

模型训练

history=vae.fit(xtrain, verbose=2, epochs = 100, batch_size = 64, validation_split = 0.2)
 训练可视化
f = plt.figure(figsize=(10,7))
f.add_subplot()
#Adding Subplot
plt.plot(history.epoch, history.history['loss'], label = "loss") # Loss curve for training set
plt.plot(history.epoch, history.history['val_loss'], label = "val_loss") # Loss curve for validation set

plt.title("Loss Curve",fontsize=18)
plt.xlabel("Epochs",fontsize=15)
plt.ylabel("Loss",fontsize=15)
plt.grid(alpha=0.3)
plt.legend()
plt.savefig("VAE_Loss_Trial5.png")
plt.show()

 中间编码特征可视化
mu, _, _ = encoder.predict(xtest)
#Plot dim1 and dim2 for mu
plt.figure(figsize=(10, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=ytest, cmap='brg')
plt.xlabel('dim 1')
plt.ylabel('dim 2')
plt.colorbar()
plt.show()
plt.savefig("VAE_Colourbar_Trial5.png")

 

数据增强生成
#RANDOM GENERATION
def generate():
    n=20
    figure = np.zeros((width *2 , height * 10, channels))

#Create a Grid of latent variables, to be provided as inputs to decoder.predict
#Creating vectors within range -5 to 5 as that seems to be the range in latent space

    for k in range(2):
        for l in range(10):
            z_sample =random.rand(3072)
            z_out=np.array([z_sample])
            x_decoded = decoder.predict(z_out)
            digit = x_decoded[0].reshape(width, height, channels)
            figure[k * width: (k + 1) * width,
                    l * height: (l + 1) * height] = digit

    plt.figure(figsize=(10, 10))
#Reshape for visualization
    fig_shape = np.shape(figure)
    figure = figure.reshape((fig_shape[0], fig_shape[1],3))

    plt.imshow(figure, cmap='gnuplot2')
    plt.show()  
 
    plt.savefig("VAE_imagesgen_Trial5.png")

解码器图像重建
#IMAGE RECONSTRUCT USING TEST SET IMGS
def reconstruct():
    num_imgs = 6
    rand = np.random.randint(1, xtest.shape[0]-6) 

    xtestsample = xtest[rand:rand+num_imgs]
    x_encoded = np.array(encoder.predict(xtestsample))
    latent_xtest=x_encoded[2]
    x_decoded = decoder.predict(latent_xtest)

    rows = 2 # defining no. of rows in figure
    cols = 3 # defining no. of colums in figure
    cell_size = 1.5
    f = plt.figure(figsize=(cell_size*cols,cell_size*rows*2)) # defining a figure 
    f.tight_layout()
    for i in range(rows):
        for j in range(cols): 
            f.add_subplot(rows*2,cols, (2*i*cols)+(j+1)) # adding sub plot to figure on each iteration
            plt.imshow(xtestsample[i*cols + j]) 
            plt.axis("off")
        
        for j in range(cols): 
            f.add_subplot(rows*2,cols,((2*i+1)*cols)+(j+1)) # adding sub plot to figure on each iteration
            plt.imshow(x_decoded[i*cols + j]) 
            plt.axis("off")

    f.suptitle("Autoencoder Results - Cifar10",fontsize=18)
    plt.savefig("VAE_imagesrecons_Trial5.png")

    plt.show()

 

代码获取

已经附在文章底部,自行拿取。

项目开发,相关问题咨询,欢迎交流沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/515752.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

运算符规则

console.log(null undefined) null和undefined都是原始类型,然后把这两个转换为数字。是0NaN.看规则有一个NaN的话就得到NaN. console.log({} []); 把{}和[]转换为原始类型分别为和[Object Object]。然后特殊情况有字符串,那就拼接字符串返回[Object…

Redis数据库——群集(主从、哨兵)

目录 前言 一、主从复制 1.基本原理 2.作用 3.流程 4.搭建主动复制 4.1环境准备 4.2修改主服务器配置 4.3从服务器配置(Slave1和Slave2) 4.4查看主从复制信息 4.5验证主从复制 二、哨兵模式——Sentinel 1.定义 2.原理 3.作用 4.组成 5.…

【Java多线程(3)】线程安全问题和解决方案

目录 一、线程安全问题 1. 线程不安全的示例 2. 线程安全的概念 3. 线程不安全的原因 二、线程不安全的解决方案 1. synchronized 关键字 1.1 synchronized 的互斥特性 1.2 synchronized 的可重入特性 1.3 死锁的进一步讨论 1.4 死锁的四个必要条件(重点&…

Golang 内存管理和垃圾回收底层原理(一)

一、这篇文章我们来聊聊Golang内存管理和垃圾回收,主要注重基本底层原理讲解,进一步实战待后续文章 1、这篇我们来讨论一下Golang的内存管理 先上结构图 从图我们来讲Golang的基本内存结构,内存结构可以分为:协程缓存、中央缓存…

vue3+eachrts饼图轮流切换显示高亮数据

<template><div class"charts-box"><div class"charts-instance" ref"chartRef"></div>// 自定义legend 样式<div class"charts-note"><span v-for"(items, index) in data.dataList" cla…

jdbc连SQL server,显示1433端口连接失败解决方法

Exception in thread "main" com.microsoft.sqlserver.jdbc.SQLServerException: 通过端口 1433 连接到主机 localhost 的 TCP/IP 连接失败。错误:“connect timed out。请验证连接属性。确保 SQL Server 的实例正在主机上运行&#xff0c;且在此端口接受 TCP/IP 连接…

移动WEB开发之rem适配布局

一、rem 基础 rem 单位 rem (root em)是一个相对单位&#xff0c;类似于em&#xff0c;em是父元素字体大小。不同的是rem的基准是相对于html元素的字体大小。比如&#xff0c;根元素&#xff08;html&#xff09;设置font-size12px; 非根元素设置width:2rem; 则换成px表示就是2…

页面自适应

后续整理下自适应的集中方法 地址

【数据库】MySQL InnoDB存储引擎详解 - 读书笔记

MySQL InnoDB存储引擎详解 - 读书笔记 InnoDB 存储引擎概述InnoDB 存储引擎的版本InnoDB 体系架构内存缓冲池LRU List、Free List 和 Flush List重做日志缓冲&#xff08;redo log buffer&#xff09;额外的内存池 存储结构表空间系统表空间独立表空间通用表空间undo表空间临时…

如何彻底删除node和npm

如何彻底删除node和npm 前言&#xff1a; 最近做个项目把本地的node更新了&#xff0c;之前是v10.14.2更新至v16.14.0 &#xff0c;想着把之前的项目起来下&#xff0c;执行npm install 结果启动不了&#xff0c;一直报npm版本不匹配需要更新本地库异常… 找了几天发现是npm 和…

基于JAVA的汽车售票网站论文

摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对汽车售票信息管理混乱&#xff0c;出错率高&#xff0c;信息安全性差…

从零到百万富翁:ChatGPT + Pinterest

原文&#xff1a;Zero to Millionaire Online: ChatGPT Pinterest 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 在社交媒体上赚取百万美元 - 逐步指南&#xff0c;如何在线赚钱版权 献给&#xff1a; 我将这本书&#xff0c;“从零到百万富翁在线&#xff1a;Chat…

Netty经典32连问

文章目录 1、Netty是什么&#xff0c;它的主要特点是什么&#xff1f;2、Netty 应用场景了解么&#xff1f;3、Netty 核心组件有哪些&#xff1f;分别有什么作用&#xff1f;4、Netty的线程模型是怎样的&#xff1f;如何优化性能&#xff1f;5、EventloopGroup了解么?和 Event…

【QT入门】 无边框窗口设计之实现圆角窗口

往期回顾&#xff1a; 【QT入门】对无边框窗口自定义标题栏并实现拖动和拉伸效果-CSDN博客 【QT入门】 自定义标题栏界面qss美化按钮功能实现-CSDN博客 【QT入门】 无边框窗口设计之实现窗口阴影-CSDN博客 【QT入门】 无边框窗口设计之实现圆角窗口 我们实际用到的很多窗口&am…

装饰工程管理系统|基于Springboot的装饰工程管理系统设计与实现(源码+数据库+文档)

装饰工程管理系统-项目立项子系统目录 目录 基于Springboot的装饰工程管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员功能实现 &#xff08;2&#xff09;合同报价管理 &#xff08;3&#xff09;装饰材料总计划管理 &#xff08;4&#xff0…

基本线段树以及相关例题

1.线段树的概念 线段树是一种二叉树&#xff0c;也就是对于一个线段&#xff0c;我们会用一个二叉树来表示。 这个其实就是一个线段树&#xff0c;我们会将其每次从中间分开&#xff0c;其左孩子就是左边的集合的和&#xff0c;其右孩子就是右边集合的和&#xff1b; 我们可以…

前端返回 List<Map<String, Object>>中的vaue值里面包含一个Bigdecimal类型,序列化时小数点丢失,如何解决?

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

Windows 11安装kb5035853补丁时,提示错误0x800f0922,并且弹出“某些操作未按计划进行,不必担心,正在撤消更改。请不要关机”

Windows 11安装kb5035853补丁时&#xff0c;提示错误0x800f0922&#xff0c;并且还在重启后弹出“某些操作未按计划进行&#xff0c;不必担心&#xff0c;正在撤消更改。请不要关机”&#xff0c;按微软官方的作法是&#xff1a;https://learn.microsoft.com/zh-cn/windows/rel…

基于SpringBoot和Vue的音乐在线交流网站的设计和实现【附源码】

1、系统演示视频&#xff08;演示视频&#xff09; 2、需要交流和学习请联系

【PLC一体机】GX Works2编程控制步进电机正反转

今天博主和大家分享一下在GX Works2 中对PLC一体机编程&#xff0c;实现步进电机自动正反转的程序。 程序如下&#xff1a; 该程序有几个重要的地方和大家分享一下&#xff1a; 1、程序以中间寄存器M4开头。 博主想通过PLC一体机上的触摸屏控制程序是否运行&#xff0c;因此…