GAN对抗生成网络(二)——算法及Python实现

1 算法步骤

上一篇提到的GAN的最优化问题是G^{*}=\arg\min\limits_{G}\max\limits_{D}V(G,D),本文记录如何求解这一问题。

首先为了表示方便,记\max\limits_{D}V(G,D)=L(G),这里让V(G,D)最大的D=D^{*}可视作常量。

第一步,给定初始的G_{0},使用梯度上升找到 D_0^{*},最大化L(G_0)。关于梯度下降,可以参考笔者另一篇文章《BP神经网络原理-CSDN博客》误差反向传播的部分。

第二步,使用梯度下降法,找到G最佳的参数\theta_{G}.其中\eta为学习率。

\theta_{G}\leftarrow \theta_{G}-\eta\frac{\partial{L(G)}}{\theta_{G}}得到G_{1}

 之后这两步交替进行。

这里的L(G)是有max运算的,可以被微分吗?答案是可以的

引用李宏毅老师的例子,f(x)是有max运算的,相当于分段函数,在求微分的时候,根据当前x落在哪个区域决定微分的形式如何。

2 算法与JS散度的关系

上述算法第一步训练D时本质是增大JS散度,第二步训练G时看起来是减小JS散度,但实际上不完全等同。

如下图所示,左侧表示算法第一步根据G_{0}得到了最优的D_0^{*}。当进行到算法第二步,需要根据D_0^{*}找到一个更小的JS散度,如右图所示,G选择了G_{1}从而使得V(G_1, D)<V(G_0, D)。虽然此时JS散度更小,但是由于G_{0}更换成了G_{1}D_0^{*}将更新参数变成D_1^{*},此时JS散度更大了。只能说不能让G更新得太多,否则不能达到减小JS散度的目标。回到文物造假的例子,造假者收到鉴宝者的反馈后,应该微调技艺,而不是彻底更换技艺,否则只能从头来过。

从快速收敛的角度来说,G应该不能更新太过,但是如果太小也忽略了G更好的形式,可能陷入局部最优。

3 实际训练过程

实际训练时,我们是无法计算出真实数据或生成数据实际的期望的,只能通过抽样近似得到期望。因此实际的做法如下:

3.1 第一步,初始化

初始化生成器G和判别器D

3.2 第二步,固定G,训练D

从分布函数(如高斯分布)中随机抽样出m个样本\left \{ z^1,z^2,...z^m \right \}输入给G,输出m个样本\left \{ \tilde{x}^1,\tilde{x}^2,..\tilde{x}^m \right \}G本质上概率分布转化器——将高斯分布的噪声转变成样本的分布。从真实数据中随机抽样出m个样本\left \{ x^1,x^2,...x^m \right \},将二者输入给D。训练D的参数使其接收x_{1}时打出0分,接收x_{2}时打出1分,即最大化\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

建模成分类或回归问题均可。

使用梯度上升法,\theta_d\leftarrow \theta_d+\eta\nabla\tilde{V}(\theta_d)

实际中需要更新多次,使得V值最大。这一步实际上只找到了一个\max\limits_{D}V(G,D)的下限(lower bound),原因是:(1)训练次数不会非常大,没法训练到收敛;(2)即使能收敛,也可能只是一个局部最优解;(3)推导时假设了D可以是任意的函数,即针对不同的x都给出最高的值,但实际中这个假设不成立。

3.3 第三步,固定D,训练G

从分布函数(如高斯分布)中随机抽样出另外m个样本\left \{ z^1,z^2,...z^m \right \}

更新G的参数\theta_G使得下式最小:

\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}

其中第一项与G无关,因此只需要看第二项。

根据上文的讨论,这里一般只训练一次,避免G改变过多,无法收敛。

实践中是将GD合在一起作为一个大的神经网络,前几层是G,后几层是D,中间有一个隐含层是G的输出,就是GAN希望得到的输出。第二步和第三步可分别固定神经网络中的某几层参数不动,训练其它层参数。

4 Python实现

关于GAN的代码,参考了https://github.com/junqiangchen/GAN。项目可以产生数字图片和人脸图片,其中人脸图片的生成使用了GAN的变种——WGAN,之后会专门讨论,本文讨论最原始的GAN模型。

4.1 使用新版tensorflow需要修改的地方

原始的代码直接运行是不通的,需要做一些调整;原始代码采用的是旧版Tensorflow(V1),如果安装了新版TensorFlow(V2)也需要做调整;有些包如果安装的新版同样不支持部分API,需要替换。具体如下表所示

问题调整方法备注
部分文件路径不对

调整路径,例如

from GAN.face_model import WGAN_GPModel

调整为

from GAN.genface.face_model import WGAN_GPModel
其他几处不再赘述
imresize报错

 例如

from scipy.misc import imresize

调整为

from skimage.transform import resize
最新版本scipy不支持此函数,将
imresize(test_image, (init_width * scale_factor, init_height * scale_factor))

替换为

resize(
Image.fromarray(test_image).resize(init_width * scale_factor, init_height * scale_factor))
imsave
报错

例如

scipy.misc.imsave(path, merge_img)

调整为

import cv2

cv2.imwrite(path, merge_img * 255)

最新版本scipy不支持此函数,替换成cv2。个人认为最后应该乘255,因为原始数据是0~1的数据,直接存会存成几乎黑白的图片,需要还原
使用新版tensorflow的问题
import tensorflow as tf

替换为

import tensorflow.compat.v1 as tf


tf.compat.v1.disable_eager_execution()
新版tensorflow提供了向下兼容的compat.v1的使用方式,统一替换即可。同时要取消eager_execution模式,新版默认是“即时计算”模式,如果兼容旧版则应取消该模式。

4.2 GAN的代码解析

代码位置:gan/GAN/genmnist/mnist_model.py, class名为GANModel

4.2.1 Generator

定义在_GAN_generator函数中,总结为以下要点:

(1)含有五层网络,除最后一层,其他层在进入下一层之前都用batch_normalization归一化+relu激活函数

g4 = tf.contrib.layers.batch_norm(g4, epsilon=1e-5, is_training=self.phase, scope='bn4')
g4 = tf.nn.relu(g4)

(2)每一层都定义w和b,使用truncated_normal,即截断异常值的正态分布

tf.truncated_normal_initializer

(3)第1~2层使用全连接层,即使用w乘输入,并加上b偏置

tf.matmul(g1, g_w2) + g_b2

(4)第3~4层使用反卷积运算。是卷积运算的逆过程,关于反卷积的介绍笔者正在整理

tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, strides, strides, 1], padding='SAME')

(5)第5层使用卷积运算,并使用tanh激活函数

g5 = convolution_2d(g4, g_w5)

4.2.2 Discriminator

与Generator类似,简述如下:

(1)共4层,其中1、2层使用卷积,3、4层使用全连接

(2)卷积后使用平均池化

d1 = average_pool_2x2(d1)

(3)最后一层使用sigmoid将输出控制在0~1之间

out = tf.nn.sigmoid(out_logit)

4.2.3 损失函数

Generator的损失函数为

-tf.reduce_mean(tf.log(self.D_fake))

对应前文提到的\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}。注意这里是用-\frac{1}{m}\sum_{i=1}^m{log{D(G(z^i))}},方向是一样的,之后笔者会讨论他们的区别。

Discriminator的损失函数为

-tf.reduce_mean(tf.log(self.D_real) + tf.log(1 - self.D_fake))

对应前文提到的\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

4.2.5 训练

定义D和G的训练函数,使得各自损失函数最小化。

trainD_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.d_loss, var_list=D_vars)
trainG_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.g_loss, var_list=G_vars)

先让D预训练30次,然后D和G交替训练。为什么先让D预训练30次?笔者认为D本质上就是个图片分类器,可以不依赖于G,比较好训练,预训练可以加快收敛速度。

训练时使用feed“喂数据”

feed_dict={self.X: batch_xs, self.Z: z_batch, self.phase: 1}

其中self.X表示真实的图片,self.Z表示噪声,self.phase表示batchnorm训练阶段还是测试阶段。

4.2.6 预测

生成随机噪声Z之后,喂给G,即可生成图片

outimage = self.Gen.eval(feed_dict={self.Z: z_batch, self.phase: 1}, session=sess)

不过笔者对这里的phase有些疑问,是否应该设置为0?恕笔者对Tensorflow不熟,代码解析有些走马观花,没有深究细节以及为什么这么写,等功力提高再回过头来优化。

至此,原始GAN的算法以及Python实现已介绍完毕,下一篇笔者将拓展讨论一些细节并介绍GAN的变种。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/947228.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JAVA(二)【未完】

数据类型与变量 数据类型&#xff1a;基本数据类型&#xff1a;整型&#xff1a;byte short int long 浮点型&#xff1a;float double char 布尔型&#xff1a;boolean 引用数据类型&#xff1a;数组 类 接口 枚举类型 long b 10l;System.out.println(b);System.out.printl…

C语言day5:shell脚本

一、练习题1 定义一个find函数&#xff0c;查找ubuntu和root的gid并使用变量接收结果 二、练习题2 定义一个数组&#xff0c;写一个函数完成对数组的冒泡排序 三、练习题3 使用break求1-100中的质数&#xff08;质数&#xff1a;只能被1和它本身整除&#xff0c;如&#xff1a;…

R语言6种将字符转成数字的方法,写在新年来临之际

咱们临床研究中&#xff0c;拿到数据后首先要对数据进行清洗&#xff0c;把数据变成咱们想要的格式&#xff0c;才能进行下一步分析&#xff0c;其中数据中的字符转成数字是个重要的内容&#xff0c;因为字符中常含有特殊符号&#xff0c;不利于分析&#xff0c;转成数字后才能…

C语言面的向对象编程(OOP)

如果使用过C、C#、Java语言&#xff0c;一定知道面向对象编程&#xff0c;这些语言对面向对象编程的支持是语言级别的。C语言在语言级别不支持面向对象&#xff0c;那可以实现面向对象吗&#xff1f;其实面向对象是一种思想&#xff0c;而不是一种语言&#xff0c;很多初学者很…

C++ 基础思维导图(一)

目录 1、C基础 IO流 namespace 引用、const inline、函数参数 重载 2、类和对象 类举例 3、 内存管理 new/delete 对象内存分布 内存泄漏 4、继承 继承权限 继承中的构造与析构 菱形继承 1、C基础 IO流 #include <iostream> #include <iomanip> //…

聊聊前端框架中的process.env,env的来源及优先级(next.js、vue-cli、vite)

在平时开发中&#xff0c;常常使用vue、react相关脚手架创建项目&#xff0c;在项目根目录可以创建.env、.env.[mode]&#xff08;mode为development、production、test)、.env.local等文件&#xff0c;然后在项目中就可以通过process.env来访问相关的环境变量了。 下面针对如下…

基于云架构Web端的工业MES系统:赋能制造业数字化变革

基于云架构Web端的工业MES系统:赋能制造业数字化变革 在当今数字化浪潮席卷全球的背景下,制造业作为国家经济发展的重要支柱产业,正面临着前所未有的机遇与挑战。市场需求的快速变化、客户个性化定制要求的日益提高以及全球竞争的愈发激烈,都促使制造企业必须寻求更加高效、智…

LeetCode算法题——螺旋矩阵ll

题目描述 给你一个正整数n&#xff0c;生成一个包含1到n2所有元素&#xff0c;且元素按顺时针顺序螺旋排列的n x n正方形矩阵matrix 。 示例 输入&#xff1a;n 3 输出&#xff1a;[[1,2,3],[8,9,4],[7,6,5]]题解 思路&#xff1a; 将整个过程分解为逐圈填充的过程&#xf…

MySQL 01 02 章——数据库概述与MySQL安装篇

一、数据库概述 &#xff08;1&#xff09;为什么要使用数据库 数据库可以实现持久化&#xff0c;什么是持久化&#xff1a;数据持久化意味着将内存中的数据保存到硬盘上加以“固化”持久化的主要作用是&#xff1a;将内存中的数据存储在关系型数据库中&#xff0c;当然也可以…

GPU 进阶笔记(四):NVIDIA GH200 芯片、服务器及集群组网

大家读完觉得有意义记得关注和点赞&#xff01;&#xff01;&#xff01; 1 传统原厂 GPU 服务器&#xff1a;Intel/AMD x86 CPU NVIDIA GPU2 新一代原厂 GPU 服务器&#xff1a;NVIDIA CPU NVIDIA GPU 2.1 CPU 芯片&#xff1a;Grace (ARM)2.2 GPU 芯片&#xff1a;Hopper/B…

vite6+vue3+ts+prettier+eslint9配置前端项目(后台管理系统、移动端H5项目通用配置)

很多小伙伴苦于无法搭建一个规范的前端项目&#xff0c;导致后续开发不规范&#xff0c;今天给大家带来一个基于Vite6TypeScriptVue3ESlint9Prettier的搭建教程。 目录 一、基础配置1、初始化项目2、代码质量风格的统一2.1、配置prettier2.2、配置eslint2.3、配置typescript 3、…

ESLint+Prettier的配置

ESLintPrettier的配置 安装插件 ​​​​​​ 在settings.json中写下配置 {// tab自动转换标签"emmet.triggerExpansionOnTab": true,"workbench.colorTheme": "Default Dark","editor.tabSize": 2,"editor.fontSize": …

Cyber Security 101-Web Hacking-JavaScript Essentials(JavaScript 基础)

任务1&#xff1a;介绍 JavaScript &#xff08;JS&#xff09; 是一种流行的脚本语言&#xff0c;它允许 Web 开发人员向包含 HTML 和 CSS&#xff08;样式&#xff09;的网站添加交互式功能。创建 HTML 元素后&#xff0c;您可以通过 JS 添加交互性&#xff0c;例如验证、on…

《机器学习》从入门到实战——逻辑回归

目录 一、简介 二、逻辑回归的原理 1、线性回归部分 2、逻辑函数&#xff08;Sigmoid函数&#xff09; 3、分类决策 4、转换为概率的形式使用似然函数求解 5、对数似然函数 ​编辑 6、转换为梯度下降任务 三、逻辑回归拓展知识 1、数据标准化 &#xff08;1&#xf…

JDK8源码分析Jdk动态代理底层原理

本文侧重分析JDK8中jdk动态代理的源码&#xff0c;若是想看JDK17源码分析可以看我的这一篇文章 JDK17源码分析Jdk动态代理底层原理-CSDN博客 两者之间有着略微的差别&#xff0c;JDK17在JDK8上改进了不少 目录 源码分析 过程 生成的代理类大致结构 本文侧重分析JDK8中jdk…

ZYNQ初识6(zynq_7010)clock时钟IP核

基于板子的PL端无时钟晶振&#xff0c;需要从PS端借用clock1&#xff08;50M&#xff09;晶振 接下去是自定义clock的IP核封装&#xff0c;为后续的simulation可以正常仿真波形&#xff0c;需要注意顶层文件的设置&#xff0c;需要将自定义的IP核对应的.v文件设置为顶层文件&a…

深度学习模型格式转换:pytorch2onnx(包含自定义操作符)

将PyTorch模型转换为ONNX&#xff08;Open Neural Network Exchange&#xff09;格式是实现模型跨平台部署和优化推理性能的一种常见方法。PyTorch 提供了多种方式来完成这一转换&#xff0c;以下是几种主要的方法&#xff1a; 一、静态模型转换 使用 torch.onnx.export() t…

GPU 进阶笔记(一):高性能 GPU 服务器硬件拓扑与集群组网

记录一些平时接触到的 GPU 知识。由于是笔记而非教程&#xff0c;因此内容不求连贯&#xff0c;有基础的同学可作查漏补缺之用 1 术语与基础 1.1 PCIe 交换芯片1.2 NVLink 定义演进&#xff1a;1/2/3/4 代监控1.3 NVSwitch1.4 NVLink Switch1.5 HBM (High Bandwidth Memory) 由…

在Unity中用Ab包加载资源(简单好抄)

第一步创建一个Editor文件夹 第二步编写BuildAb&#xff08;这个脚本一点要放在Editor中因为这是一个编辑器脚本&#xff0c;放在其他地方可能会报错&#xff09; using System.IO; using UnityEditor; using UnityEngine;public class BuildAb : MonoBehaviour {// 在Unity编…

【贪心算法】贪心算法七

贪心算法七 1.整数替换2.俄罗斯套娃信封问题3.可被三整除的最大和4.距离相等的条形码5.重构字符串 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f…