MATLAB代码解析:利用DCGAN实现图像数据的生成

摘要

经典代码:利用DCGAN生成花朵

MATLAB官方其实给出了DCGAN生成花朵的示范代码,原文地址:训练生成对抗网络 (GAN) - MATLAB & Simulink - MathWorks 中国

先看看训练效果

训练1周期

训练11周期

训练56个周期

脚本文件 

为了能让各位更好的复现,该代码已打包,下载后解压运行用MATLAB运行"gan.mlx"即可
链接: https://pan.baidu.com/s/1hNYLw1xku2AdKf5CanoFzA?pwd=fb7n 提取码: fb7n 
 

代码详解:

首先是脚本gan:

数据获取
clear all
clc
imageFolder = fullfile("flower_photos");
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
生成器
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];%

layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];
netG = dlnetwork(layersGenerator);
判别器
dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)
    sigmoidLayer];
netD = dlnetwork(layersDiscriminator);
指定训练选项
numEpochs = 500;
miniBatchSize = 128;
learnRate = 0.00008;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
flipProb = 0.35;
validationFrequency = 100;
训练模型
augimds.MiniBatchSize = miniBatchSize;

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");
if canUseGPU
    ZValidation = gpuArray(ZValidation);
end

f = figure;
f.Position(3) = 2*f.Position(3);

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

C = colororder;
lineScoreG = animatedline(scoreAxes,Color=C(1,:));
lineScoreD = animatedline(scoreAxes,Color=C(2,:));
legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Reset and shuffle datastore.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the format "CB" (channel, batch). If a GPU is
        % available, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");

        if canUseGPU
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [L,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
        netG.State = stateG;

        %%show data
        %"epoch"
        %epoch
        %"scoreG-D"
        %[scoreG,scoreD]

        % Update the discriminator network parameters.
        [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvg, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        % Every validationFrequency iterations, display batch of generated
        % images using the held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);

            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the scores plot.
        subplot(1,2,2)
        scoreG = double(extractdata(scoreG));
        addpoints(lineScoreG,iteration,scoreG);

        scoreD = double(extractdata(scoreD));
        addpoints(lineScoreD,iteration,scoreD);

        % Update the title with training progress information.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))

        drawnow
    end
end
生成新图像  
numObservations = 4;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");
if canUseGPU
    ZNew = gpuArray(ZNew);
end

XGeneratedNew = predict(netG,ZNew);

I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

生成器与判别器的设计

脚本gan中以及包含了生成器Generator和判别器Discriminator的结构设计,生成器利用装置卷积对特征进行上采样,最终生成了64*64*3的图像,而判别器则用卷积进行下采样,将输入提取至1*1的格式大小,利用sigmoid作为激活函数,判断输入图像的真假

如何自定义生成对抗网络?很简单,把握上采样和下采样的规模就行,利用MATLAB的DLtool(deep network designer)可以很好的观察到这一点,以刚刚的生成器为例,我们可以观察到,转置卷积后(步幅为2),输出的空间(S)长宽都翻倍,深度对应我们给定的filters数量,因此,我们想要生成特定大小的数据时,修改转置卷积的步幅、卷积核数量以及转置卷积层的数量就行,同时记得在添加的转置卷积层后连接新的BN层和ReLU激活函数。

比如我想生成128*128*3的图片,我只需要将刚刚示例中的其中一个转置卷积核的大小提高至7*7,同时步幅修改成4。或者,我直接添加一层步幅为2的转置卷积层。对于一些数据尺寸为非2倍数问题,如311*171*3,我们可以先生成312*172*3再resize一下,或者你提前将数据预处理成312*172.

注意:定义完网络结构后,要用dlnetwork()函数将layer参数转变成可训练的dlnetwork。

最近比较忙,先在这里停笔了,后面再慢慢补充-24-10-14

数据预处理

自定义模型训练

损失函数与梯度下降

优化器与参数更新

总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/891105.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

centos7 Oracle 11g rac 静默安装(NFS配置共享存储)

1.环境信息准备 注意: 在配置网络时,Oracle RAC的每个节点必须具有至少两个以上的网卡,一张网卡对外提供网络服务,另一张网卡用于各个节点间的通信和心跳检测等。在配置RAC集群的网卡时,如果节点1的公共接口是eth0&…

下一代安全:融合网络和物理策略以实现最佳保护

在当今快速发展的技术环境中,网络和物理安全融合变得比以往任何时候都更加重要。随着物联网 (IoT) 和工业物联网 (IIoT) 的兴起,组织在保护数字和物理资产方面面临着独特的挑战。 本文探讨了安全融合的概念、说明其重要性的实际事件以及整合网络和物理安…

本地装了个pytorch cuda

安装命令选择 pip install torch1.13.1cu116 torchvision0.14.1cu116 torchaudio0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 torch版本查看 python import torch print(torch.__version__) 查看pytorch能否使用cuda import torch# 检查CUDA是否可用…

鸿蒙NEXT开发-动画(基于最新api12稳定版)

注意:博主有个鸿蒙专栏,里面从上到下有关于鸿蒙next的教学文档,大家感兴趣可以学习下 如果大家觉得博主文章写的好的话,可以点下关注,博主会一直更新鸿蒙next相关知识 专栏地址: https://blog.csdn.net/qq_56760790/…

241014-绿联UGOSPro-通过虚拟机访问主机的用户目录及文件夹

如图所示,两种方式; 方式1: 通过Files中的Other Locations 添加主机ip,随后输入主机的用户名及密码即可系统及文件加载可能需要一段时间,有点卡,加载完应该就可以点击访问了 方式2: 通过命令行直接ssh/sftp userna…

【C++网络编程】(一)Linux平台下TCP客户/服务端程序

文章目录 Linux平台下TCP客户/服务端程序服务端客户端相关头文件介绍 Linux平台下TCP客户/服务端程序 图片来源:https://subingwen.cn/linux/socket/ 下面实现一个Linux平台下TCP客户/服务端程序:客户端向服务器发送:“你好,服务…

网络资源模板--Android Studio 实现简易新闻App

目录 一、项目演示 二、项目测试环境 三、项目详情 四、完整的项目源码 一、项目演示 网络资源模板--基于Android studio 实现的简易新闻App 二、项目测试环境 三、项目详情 登录页 用户输入: 提供账号和密码输入框,用户可以输入登录信息。支持“记…

[ComfyUI]Flux:国漫经典!斗破苍穹古熏儿之绮梦流光模型来袭

在数字艺术和创意领域,FLUX以其独特的虚实结合技术,已经成为艺术家和设计师们手中的利器。今天,我们激动地宣布,FLUX推出了一款全新的ComfyUI版本——Flux,它将国漫经典《斗破苍穹》中的古熏儿之绮梦流光模型完美融合&…

第十四章 RabbitMQ延迟消息之延迟队列

目录 一、引言 二、死信队列 三、核心代码实现 四、运行效果 五、总结 一、引言 什么是延迟消息? 发送者发送消息时指定一个时间,消费者不会立刻收到消息,而是在指定时间后收到消息。 什么是延迟任务? 设置在一定时间之后才…

Qt入门教程:创建我的第一个小程序

本章教程,主要介绍如何编写一个简单的QT小程序。主要是介绍创建项目的过程。 一、打开QT软件编辑器 这里使用的是QT5.14.2版本的,安装教程参考以往教程:https://blog.csdn.net/qq_19309473/article/details/142907096 二、创建项目 到这里&am…

使用Docker部署nextjs应用

最近使用nextjs网站开发,希望使用docker进行生产环境的部署,减少环境的依赖可重复部署操作。我采用的是Dockerfile编写应用镜像方式 docker-compose实现容器部署的功能。 Docker Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器…

【大模型问答测试】大模型问答测试脚本实现(第一版)

背景 公司已经做了一段时间的大模型,每次测试或者回归的时候都需要针对问答进行测试回归,耗费大量的时间与精力,因此结合产品特点,开发自动化脚本替代人工的操作,提升测试回归效率 设计 使用pythonrequestExcel进行…

Android笔记(二十四)基于Compose组件的MVVM模式和MVI模式的实现

仔细研究了一下MVI(Model-View-Intent)模式,发现它和MVVM模式非常的相识。在采用Android JetPack Compose组件下,MVI模式的实现和MVVM模式的实现非常的类似,都需要借助ViewModel实现业务逻辑和视图数据和状态的传递。在这篇文章中&#xff0c…

易我数据恢复软件,一键找回你的重要资料!

我们生活在数字时代,数据对我们来说超级重要。工作文件、学习资料,还有照片视频,这些东西要是没了或者不小心删了,那得多烦人啊。幸好现在科技发达,有了数据恢复软件,就像给我们数据上了一把安全锁。市面上…

一篇闪击常用放大器电路(学习笔记)

文章目录 声明概念名词经典电路分析反向放大器同向放大器加法器减法器积分电路微分电路差分放大电路电流->电压转换电路电压->电流转换电路 虚短与虚断一、虚短二、虚断 一些碎碎念 声明 ​ 本文是主要基于以下两篇博客所做的笔记: 模电四:基本放…

IT招聘乱象的全面分析

近年来,IT行业的招聘要求似乎越来越苛刻,甚至有些不切实际。许多企业在招聘时,不仅要求前端工程师具备UI设计能力,还希望后端工程师精通K8S服务器运维,更有甚至希望研发经理掌握所有前后端框架和最新开发技术。这种招聘…

MySQL基本语法、高级语法知识总结以及常用语法案例

MySQL基本语法总结 MySQL是一种广泛使用的关系型数据库管理系统,其基本语法涵盖了数据库和数据表的创建、查询、修改和删除等操作。 一、数据库操作 创建数据库(CREATE DATABASE) 语法:CREATE DATABASE [IF NOT EXISTS] databa…

最新PHP礼品卡回收商城 点卡回收系统源码_附教程

最新PHP礼品卡回收商城 点卡回收系统源码_附教程 各大电商平台优惠券秒杀拼团限时折扣回收商城带余额宝 1、余额宝理财 2、回收、提现、充值、新订单语音消息提醒功能 3、带在线客服 4、优惠券回收功能 源码下载:https://download.csdn.net/download/m0_66047…

Android实现App内直接预览本地PDF文件

在App内实现直接预览pdf文件,而不是通过调用第三方软件,如WPS office等打开pdf。 主要思路:通过PhotoView将pdf读取为图片流进行展示。 一、首先,获取对本地文件读取的权限 在AndrooidManifest.xml中声明权限,以及页…

Windows,MySQL主从复制搭建

前提:windows环境,同一个服务器安装多个相同版本的mysql数据库 多个MySQL服务搭建完成后,下面我们进行主从复制的相关配置 1.主数据库 执行指令 #创建用户 CREATE USER slavelocalhost IDENTIFIED BY 123456;#授权 GRANT REPLICATION SLA…