MATLAB环境下生成对抗网络系列(11种)

为了构建有效的图像深度学习模型,数据增强是一个非常行之有效的方法。图像的数据增强是一套使用有限数据来提高训练数据集质量和规模的数据空间解决方案。广义的图像数据增强算法包括:几何变换、颜色空间增强、核滤波器、混合图像、随机擦除、特征空间增强、对抗训练、生成对抗网络和风格迁移等内容。增强的数据代表一个分布覆盖性更广、可靠性更高的数据点集,使用增强数据能够有效增加训练样本的多样性,最小化训练集和验证集以及测试集之间的距离。使用数据增强后的数据集训练模型,可以达到提升模型稳定性、泛化能力的效果。

使用生成对抗网络GAN提取原数据集特征,对抗生成新的目标域图像,已成为众多学者在数据增强技术研究中的优选方法。相比于传统的图像数据增强方法,通过基于GAN的生成式建模技术进行数据增强的思路来源于博弈论中的二人零和博弈,由网络中包含的生成器和判别器利用对抗学习的方法来指导网络训练,在两个网络对抗过程中估计原始数据样本的分布并生成与之相似的新数据。

近期的研究在原始生成对抗网络框架的基础上又提出了多种不同的改进方案,通过设计不同的神经网络架构和损失函数等手段不断提升生成对抗网络的性能。生成对抗网络应用在图像数据增强任务上的思想主要是其通过生成新的训练数据来扩充模型的训练数据,通过样本空间的扩充实现图像分类等任务效果的提升。但目前基于GAN的图像数据增强技术普遍存在模型收敛不稳定、生成图像质量低等问题,如何正确引入高频信息,提升图像数据质量是破解这一系列问题的关键。

MATLAB环境配置如下:

  • MATLAB 2021b
  • Deep Learning Toolbox
  • Parallel Computing Toolbox (optional for GPU usage)

目录如下

  • Generative Adversarial Network (GAN) [paper]
  • Least Squares Generative Adversarial Network (LSGAN) [paper]
  • Deep Convolutional Generative Adversarial Network (DCGAN) [paper]
  • Conditional Generative Adversarial Network (CGAN)[paper]
  • Auxiliary Classifier Generative Adversarial Network (ACGAN) [paper]
  • InfoGAN [paper]
  • Adversarial AutoEncoder (AAE)[paper]
  • Pix2Pix[paper]
  • Wasserstein Generative Adversarial Network (WGAN) [paper]
  • Semi-Supervised Generative Adversarial Network (SGAN) [paper]
  • CycleGAN [paper]
  • DiscoGAN [paper]

部分代码如下:

首先,导入相关的mnist手写数字图

load('mnistAll.mat')

然后对训练、测试图像进行预处理

trainX = preprocess(mnist.train_images); 
trainY = mnist.train_labels;%训练标签
testX = preprocess(mnist.test_images); 
testY = mnist.test_labels;%测试标签

preprocess为归一化函数,如下

function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end

然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等

settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1]; 
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;

下面进行编码器初始化,代码还是很容易看懂的

paramsEn.FCW1 = dlarray(initializeGaussian([512,...
     prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));

解码器初始化

paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));

判别器初始化

paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));

%平均梯度和平均梯度平方数组
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];

开始训练

dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
    tic; 
    shuffleid = randperm(size(trainX,2));
    trainXshuffle = trainX(:,shuffleid);
    fprintf('Epoch %d\n',epoch) 
    for i=1:numIterations
        global_iter = global_iter+1;
        idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
        XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');

        [GradEn,GradDe,GradDis] = ...
                dlfeval(@modelGradients,XBatch,...
                paramsEn,paramsDe,paramsDis,settings);

        % 更新判别器网络参数
        [paramsDis,avgG.Dis,avgGS.Dis] = ...
            adamupdate(paramsDis, GradDis, ...
            avgG.Dis, avgGS.Dis, global_iter, ...
            settings.lrD, settings.beta1, settings.beta2);

        % 更新编码器网络参数
        [paramsEn,avgG.En,avgGS.En] = ...
            adamupdate(paramsEn, GradEn, ...
            avgG.En, avgGS.En, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        % 更新解码器网络参数
        [paramsDe,avgG.De,avgGS.De] = ...
            adamupdate(paramsDe, GradDe, ...
            avgG.De, avgGS.De, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        if i==1 || rem(i,20)==0
            progressplot(paramsDe,settings);
            if i==1 
                h = gcf;
                % 捕获图像
                frame = getframe(h); 
                im = frame2im(frame); 
                [imind,cm] = rgb2ind(im,256); 
                % 写入 GIF 文件
                if epoch == 0
                  imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf); 
                else 
                  imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append'); 
                end 
            end
        end
        
    end

    elapsedTime = toc;
    disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")
    epoch = epoch+1;
    if epoch == settings.maxepochs
        out = true;
    end    
end

下面是完整的辅助函数

模型的梯度计算函数

function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...
    dly(settings.latent_dim+1:2*settings.latent_dim)*...
    randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');

%训练判别器
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));

%训练编码器和解码器
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));

%对于每个网络,计算关于损失函数的梯度
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end

提取数据函数

function x = gatext(x)
x = gather(extractdata(x));
end

GPU深度学习数组wrapper函数

function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end

权重初始化函数

function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2
    sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end

dropout函数

function dly = dropout(dlx,p)
if nargin < 2
    p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end

编码器函数

function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end

解码器函数

function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end

判别器函数

function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任
《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/384433.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

人生就是一场大考

感觉人生就是一场大考&#xff0c;最重要的高考&#xff0c;而这个考试要到49岁左右才有得分。如&#xff0c;当初事业的选择&#xff0c;爱情伴侣的选择&#xff0c;出去在外发展的选择……这都决定了你以后人生高考的走向。现在的你过得怎么样&#xff0c;也是你在人生这场高…

Java实现河南软件客服系统 JAVA+Vue+SpringBoot+MySQL

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统管理人员2.2 业务操作人员 三、系统展示四、核心代码4.1 查询客户4.2 新增客户跟进情况4.3 查询客户历史4.4 新增服务派单4.5 新增客户服务费 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的河…

机器学习系列5-特征组合、简化正则化

1.特征组合 1.1特征组合&#xff1a;编码非线性规律 我们做出如下假设&#xff1a;蓝点代表生病的树。橙色的点代表健康的树。 您可以绘制一条直线将生病的树与健康的树清晰地分开吗&#xff1f;不可以。这是一个非线性问题。您绘制的任何线条都无法很好地预测树的健康状况…

1【算法】——最大子数组问题(maximum subarray)

一.问题描述 假如我们有一个数组&#xff0c;数组中的元素有正数和负数&#xff0c;如何在数组中找到一段连续的子数组&#xff0c;使得子数组各个元素之和最大。 二.问题分析 分治法求解&#xff1a; 初始状态&#xff1a; low0&#xff1b;highA.length-1&#xff1b;mid&am…

GO 的 Web 开发系列(五)—— 使用 Swagger 生成一份好看的接口文档

经过前面的文章&#xff0c;已经完成了 Web 系统基础功能的搭建&#xff0c;也实现了 API 接口、HTML 模板渲染等功能。接下来要做的就是使用 Swagger 工具&#xff0c;为这些 Api 接口生成一份好看的接口文档。 一、写注释 注释是 Swagger 的灵魂&#xff0c;Swagger 是通过…

leaflet 显示自己geoserver发布的中国地图

安装vscode 安装 通义灵码 问题&#xff1a; 用leaflet显示一个wms地图 修改下代码&#xff0c;结果如下&#xff1a; 例子代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport&q…

【HarmonyOS 4.0 应用开发实战】ArkTS 快速入门之常用属性

个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名大三在校生&#xff0c;喜欢AI编程&#x1f38b; &#x1f43b;‍❄️个人主页&#x1f947;&#xff1a;落798. &#x1f43c;个人WeChat&#xff1a;hmmwx53 &#x1f54a;️系列专栏&#xff1a;&#x1f5bc;️…

《统计学简易速速上手小册》第5章:回归分析(2024 最新版)

文章目录 5.1 线性回归基础5.1.1 基础知识5.1.2 主要案例&#xff1a;员工薪资预测5.1.3 拓展案例 1&#xff1a;广告支出与销售额关系5.1.4 拓展案例 2&#xff1a;房价与多个因素的关系 5.2 多元回归分析5.2.1 基础知识5.2.2 主要案例&#xff1a;企业收益与多因素关系分析5.…

【小沐学GIS】基于WebGL绘制三维数字地球Earth(OpenGL)

&#x1f37a;三维数字地球系列相关文章如下&#x1f37a;&#xff1a;1【小沐学GIS】基于C绘制三维数字地球Earth&#xff08;OpenGL、glfw、glut&#xff09;第一期2【小沐学GIS】基于C绘制三维数字地球Earth&#xff08;OpenGL、glfw、glut&#xff09;第二期3【小沐学GIS】…

最短路径算法

1. Dijkstra算法 在正数权重的有向图中求解某个源点到其余各个顶点的最短路径一般可以采用迪杰斯特拉算法&#xff08;Dijkstra算法&#xff09;。 1.1 适用场景 单源最短路径权重都为正 1.2 伪代码 1.3 示例 问题描述&#xff1a; 计算节点0到节点4的最短路径&#xff0c…

arduino uno R3驱动直流减速电机(蓝牙控制)

此篇博客用于记录使用arduino驱动直流减速电机的过程&#xff0c;仅实现简单的功能&#xff1a;PID调速、蓝牙控制 1、直流减速电机简介2、DRV8833电机驱动模块简介3、HC-05蓝牙模块简介电机转动测试4、PID控制5、蓝牙控制电机 1、直流减速电机简介 我在淘宝购买的电机&#x…

C语言整型数据类型..

1.整型数据类型 在C语言中 根据数据范围从小到大依次为char、short、int、long、long long 但是对于整型来说 为什么有这么多类型呢 我们得先说字节的本质&#xff1a; 计算机是通过晶体管的开关状态来记录数据 晶体管通常8个为一组 称为一个字节 而晶体管由两种状态 分别是开…

交叉熵损失函数(Cross-Entropy Loss)的基本概念与程序代码

交叉熵损失函数&#xff08;Cross-Entropy Loss&#xff09;是机器学习和深度学习中常用的损失函数之一&#xff0c;用于分类问题。其基本概念如下&#xff1a; 1. 基本解释&#xff1a; 交叉熵损失函数衡量了模型预测的概率分布与真实概率分布之间的差异。在分类问题中&…

春节假期:思考新一年的发展思路

春节假期是人们放松身心、享受家庭团聚的时刻&#xff0c;但除了走亲戚、玩、吃之外&#xff0c;我们确实也需要思考新的一年的发展思路。以下是一些建议&#xff0c;帮助您在春节假期中为新的一年做好准备&#xff1a; 回顾过去&#xff0c;总结经验&#xff1a;在春节期间&a…

大华智慧园区综合管理平台/emap/devicePoint RCE漏洞

免责声明&#xff1a;文章来源互联网收集整理&#xff0c;请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;所产生的一切不良后果与文章作者无关。该…

【十六】【C++】stack的常见用法和练习

stack的常见用法 C标准库中的stack是一种容器适配器&#xff0c;它提供了后进先出&#xff08;Last In First Out, LIFO&#xff09;的数据结构。stack使用一个底层容器进行封装&#xff0c;如deque、vector或list&#xff0c;但只允许从一端&#xff08;顶部&#xff09;进行…

一周学会Django5 Python Web开发-Django5操作命令

锋哥原创的Python Web开发 Django5视频教程&#xff1a; 2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~共计11条视频&#xff0c;包括&#xff1a;2024版 Django5 Python we…

第8讲个人中心页面搭建实现

个人中心页面搭建实现 <template><view class"user_center"><!-- 用户信息开始 --><view class"user_info_wrap"><!--获取头像--><button class"user_image"></button> <view class"user_n…

14.盔甲?装甲?装饰者模式!

人类的军工发展史就是一场矛与盾的追逐&#xff0c;矛利则盾坚&#xff0c;盾愈坚则矛愈利。在传统的冶金工艺下&#xff0c;更坚固的盾牌和盔甲往往意味着更迟缓笨重的运动能力和更高昂的移动成本。从战国末期的魏武卒、秦锐士&#xff0c;到两宋之交的铁浮图、重步兵&#xf…

Roop的安装教程

roop插件的安装&#xff0c;并不容易 并且最好就是在电脑本地完成&#xff0c;因为涉及到C、visual studio软件&#xff0c;并且还需要在电脑本地放置一些模型&#xff0c;用autoDL其实也有镜像&#xff0c;但是需要数据扩容至少100G&#xff0c;烧钱。。。 电脑本地&#xff0…