程序员学长 | 快速学习一个算法,GAN

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学习一个算法,GAN

GAN 如何工作?

GAN 由两个部分组成:生成器(Generator)和判别器(Discriminator)。这两个部分通过一种对抗的过程来相互改进和优化。

c03159492fff41b799068bb5cf88bd59.png

生成器(Generator)

生成器的任务是接收一个随机噪声向量,并将其转换为看起来尽可能真实的数据。它通常使用一个深度神经网络来实现,从随机噪声中生成类似于训练数据的样本。

生成器的目标是生成的样本能够骗过判别器,使判别器认为生成的数据是真实的。

判别器(Discriminator)

判别器是一个二分类模型,它的任务是区分真实数据和生成器生成的假数据。

判别器接收一个数据样本,并输出一个概率值,表示该样本是真实数据的概率。

判别器的目标是尽可能准确地将真实数据与生成数据区分开来。

对抗过程

GAN 的训练过程可以被看作是生成器和判别器之间的博弈。具体来说,生成器试图生成逼真的数据以欺骗判别器,而判别器则试图更好地识别生成的假数据

这个过程可以描述为一个最小最大化的优化问题。

4c8b68b8b9184d9092ffda2498ba9173.png

鉴别器 D 想要最大化目标函数,使得 D(x) 接近于 1,D(G(z)) 接近于 0。这意味着鉴别器应该将训练集中的所有图像识别为真实 (1),将所有生成的图像识别为假 (0)。

生成器 (G) 想要最小化目标函数,使得 D(G(z)) 为 1。这意味着生成器试图生成被鉴别器网络分类为 1 的图像。

训练步骤

  1. 初始化生成器和判别器的参数

  2. 判别器训练

    • 从真实数据集中采样一个 mini-batch 的真实数据。

    • 从生成器的噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新判别器的参数,使其能够更好地区分真实数据和生成数据。

      f6487b7759c24cf9aa3570065e6f957b.png

  3. 生成器训练

    861dc76c582746c281d4d47c3e1156c0.png
    • 从噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新生成器的参数,使其生成的数据更能够欺骗判别器。

  4. 重复上述过程,直到生成器生成的数据足够逼真,判别器无法准确区分真实数据和生成数据

GAN 在图像生成、图像修复、超分辨率、风格迁移、数据增强等领域有广泛应用。例如,通过 GAN,可以生成高分辨率的图像,将低分辨率图像转换为高分辨率图像,或者将某种风格的图像转换为另一种风格的图像。

案例分享

下面是使用 GAN 来生成图像的案例。这里我们以手写数字识别数据集为例进行说明。

1.读取数据集

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
import matplotlib.pyplot as plt
import sys
import numpy as np

num_rows = 28
num_cols = 28
num_channels = 1
input_shape = (num_rows, num_cols, num_channels)
z_size = 100
batch_size = 32
(train_ims, _), (_, _) = mnist.load_data()
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)

valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

2.定义生成器

生成器 (D) 在 GAN 中扮演着至关重要的角色,因为它负责生成能够欺骗鉴别器的真实图像。

它是 GAN 中图像形成的主要组件。

在本文中,我们为生成器使用了一种特定的架构,该架构包含一个完全连接 (FC) 层并采用 Leaky ReLU 激活。然而,值得注意的是,生成器的最后一层使用 TanH 激活而不是 LeakyReLU。

def build_generator():
    gen_model = Sequential()
    gen_model.add(Dense(256, input_dim=z_size))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(512))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(1024))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(np.prod(input_shape), activation='tanh'))
    gen_model.add(Reshape(input_shape))

    gen_noise = Input(shape=(z_size,))
    gen_img = gen_model(gen_noise)
    return Model(gen_noise, gen_img)

3.定义鉴别器

在生成对抗网络 (GAN) 中,鉴别器 (D) 通过评估真实性和可能性来执行区分真实图像和生成图像的关键任务。

此组件可以看作是一个二元分类问题。

为了解决此任务,我们可以采用一个简化的网络架构,该架构由全连接层 (FC)、Leaky ReLU 激活和 Dropout 层组成。值得一提的是,鉴别器的最后一层包括 FC 层,后跟 Sigmoid 激活。Sigmoid 激活函数产生所需的分类概率。

def build_discriminator():
    disc_model = Sequential()
    disc_model.add(Flatten(input_shape=input_shape))
    disc_model.add(Dense(512))
    disc_model.add(LeakyReLU(alpha=0.2))
    disc_model.add(Dense(256))
    disc_model.add(LeakyReLU(alpha=0.2))
    disc_model.add(Dense(1, activation='sigmoid'))

    disc_img = Input(shape=input_shape)
    validity = disc_model(disc_img)
    return Model(disc_img, validity)

4.计算损失函数

我们可以使用二元交叉熵损失来实现生成器和鉴别器。

# discriminator
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

z = Input(shape=(z_size,))

# generator
img = generator(z)

disc.trainable = False

validity = disc(img)

# combined model
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')

5.优化损失

def intialize_model():
    disc= build_discriminator()
    disc.compile(loss='binary_crossentropy',
        optimizer='sgd',
        metrics=['accuracy'])

    generator = build_generator()

    z = Input(shape=(z_size,))
    img = generator(z)

    disc.trainable = False

    validity = disc(img)

    combined = Model(z, validity)
    combined.compile(loss='binary_crossentropy', optimizer='sgd')
    return disc, generator, combined

6. 模型训练

def train(epochs, batch_size=128, sample_interval=50):
    # load images
    (train_ims, _), (_, _) = mnist.load_data()
    # preprocess
    train_ims = train_ims / 127.5 - 1.
    train_ims = np.expand_dims(train_ims, axis=3)

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    # training loop
    for epoch in range(epochs):

        batch_index = np.random.randint(0, train_ims.shape[0], batch_size)
        imgs = train_ims[batch_index]
    # create noise
        noise = np.random.normal(0, 1, (batch_size, z_size))
    # predict using a Generator
        gen_imgs = gen.predict(noise)
    # calculate loss functions
        real_disc_loss = disc.train_on_batch(imgs, valid)
        fake_disc_loss = disc.train_on_batch(gen_imgs, fake)
        disc_loss_total = 0.5 * np.add(real_disc_loss, fake_disc_loss)

        noise = np.random.normal(0, 1, (batch_size, z_size))

        g_loss = full_model.train_on_batch(noise, valid)
   
    # save outputs every few epochs
        if epoch % sample_interval == 0:
            one_batch(epoch)

7.生成手写数字

使用 MNIST 数据集,我们可以创建一个实用函数,使生成器为一组图像生成预测。

此函数生成随机声音,将其提供给生成器,运行它以显示生成的图像并将其保存在特殊文件夹中。建议定期运行此实用函数,例如每 200 个周期运行一次,以监控网络进度。

def one_batch(epoch):
    r, c = 5, 5
    noise_model = np.random.normal(0, 1, (r * c, z_size))
    gen_images = gen.predict(noise_model)

    # Rescale images 0 - 1
    gen_images = gen_images*(0.5) + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_images[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%d.png" % epoch)
    plt.close()

在我们的实验中,我们使用 32 的批次大小对 GAN 进行了大约 10,000 个时期的训练。

为了跟踪训练的进度,我们每 200 个时期保存一次生成的图像,并将它们存储在名为 “images” 的指定文件夹中。

disc, gen, full_model = intialize_model()
train(epochs=10000, batch_size=32, sample_interval=200)

现在,我们来检查一下不同阶段的 GAN 模拟结果。

初始化、5000 个 epoch 以及 10000 个 epoch 的最终结果。

最初,我们以随机噪声作为生成器的输入。

2e954b9c4c274f3e98d6ff8c9ba2be87.png

经过 5000 个时期的训练后,我们可以观察到生成的图形开始类似于 MNIST 数据集。

1a0dc57d0de14beebed11deee9879f68.png

经过 10,000 个时期的训练后,我们获得以下输出。

7c9ae539548c463db3d7b6a41fe4db86.png

可以看到,这些生成的图像与手写数字数据已经非常相似了。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/793194.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Windows电脑安装Python结合内网穿透轻松搭建可公网访问私有网盘

文章目录 前言1.本地文件服务器搭建1.1.Python的安装和设置1.2.cpolar的安装和注册 2.本地文件服务器的发布2.1.Cpolar云端设置2.2.Cpolar本地设置 3.公网访问测试4.结语 前言 本文主要介绍如何在Windows系统电脑上使用python这样的简单程序语言,在自己的电脑上搭建…

TF卡病毒是什么?如何防范和应对?

在存储芯片及存储卡领域,TF卡病毒是一个备受关注的话题。在本文中,拓优星辰将详细解释TF卡病毒的含义、来源以及如何防范和应对这一问题,帮助客户更好地了解和处理TF卡病毒的风险。 1. TF卡病毒的含义 TF卡病毒是指针对TF存储卡(T…

【案例】python集成OCR识别工具调研

目录 一、前言二、Tesseract_OCR2.1、安装过程2.2、python代码使用三、PaddleOCR3.1、安装过程3.2、python代码使用四、EasyOCR五、ddddOCR六、CnOCR七、总结一、前言 因项目需要OCR识别能力,且要支持私有化部署。本文将对比市场一些开源的OCR识别工具,从中选择适合项目需要…

逻辑回归(纯理论)

1.什么是逻辑回归? 逻辑回归是一种常用的统计学习方法,主要用于解决分类问题。尽管名字中包含"回归",但它实际上是一种分类算法 2.为什么机器学习需要使用逻辑回归 1.二元分类 这是逻辑回归最基本和常见的用途。它可以预测某个事…

【备战秋招】——算法题目训练和总结day3

【备战秋招】——算法题目训练和总结day3😎 前言🙌BC149简写单词题解思路分析代码分享: dd爱框框题解思路分析代码分享: 除2!题解思路分析代码分享: 总结撒花💞 😎博客昵称&#xff…

多周期路径的约束与设置原则

本节将回顾工具检查建立保持时间的原则,接下来介绍设置多周期后的检查原则。多周期命令是设计约束中常用的一个命令,用来修改默认的建立or保持时间的关系。基本语法如下 默认的建立时间与保持时间的检查方式 DC工具计算默认的建立保持时间关系是基于时钟…

EXSI 实用指南 2024 -编译环境 Mac OS 安装篇(一)

1. 引言 在现代虚拟化技术的快速发展中,VMware ESXi 作为领先的虚拟化平台,凭借其高性能、稳定性和丰富的功能,广泛应用于企业和个人用户。ESXi 能有效地提高硬件资源利用率,并简化 IT 基础设施的管理。然而,如何在 V…

RK3568平台(显示篇)主屏副屏配置

一.主屏副屏配置 目前在RK3568平台上有两路HDMIOUT输出,分别输出到两个屏幕上,一路配置为主屏,一路配置为副屏。 硬件原理图: &hdmi0_in_vp2 {status "okay"; };&hdmi1_in_vp0 {status "okay"; }…

idea修改全局配置、idea中用aliyun的脚手架,解决配置文件中文乱码

idea修改全局配置 idea中用aliyun的脚手架,创建springBoot项目 解决配置文件中文乱码

基于springboot+mybatis学生管理系统

基于springbootmybatis学生管理系统 简介: 题目虽然是学生管理系统,但功能包含(学生,教师,管理员),项目基于springboot2.1.x实现的管理系统。 编译环境 : jdk 1.8 mysql 5.5 tomcat 7 框架 : springboot…

p15 p16 c语言实现三子棋

具体的实现代码 game.c #include "game.h"void InitBoard(char board[ROW][COL], int row, int col) {int i 0;int j 0;for (i 0; i < row; i) {for (j 0; j < col; j) {board[i][j] ;}} }void DisplayBoard(char board[ROW][COL], int row, int col) …

springboot系列九: 接收参数相关注解

文章目录 基本介绍接收参数相关注解应用实例PathVariableRequestHeaderRequestParamCookieValueRequestBodyRequestAttributeSessionAttribute 复杂参数基本介绍应用实例 自定义对象参数-自动封装基本介绍应用实例 基本介绍 1.SpringBoot 接收客户端提交数据 / 参数会使用到相…

二进制二维数组与装箱问题

装箱问题&#xff08;Bin Packing Problem&#xff09;是一类经典的优化问题&#xff0c;其目标是将一系列项目&#xff08;通常具有不同的体积或重量&#xff09;分配到尽量少的箱子中&#xff0c;使得每个箱子的容量不被超出。这种问题在物流、资源分配、内存管理等领域有广泛…

LinkedList----源码分析

源码介绍 public class LinkedList<E>extends AbstractSequentialList<E>implements List<E>, Deque<E>, Cloneable, java.io.Serializable{} 添加过程中的操作&#xff1a; 当创建LinkedList类时&#xff0c;会调用其空参构造方法&#xff0c;将其参…

第一关:Linux基础知识

Linux基础知识目录 前言LinuxInternStudio 关卡1. InternStudio开发机介绍2. SSH及端口映射2.1 什么是SSH&#xff1f;2.2 如何使用SSH远程连接开发机&#xff1f;2.2.1 使用密码进行SSH远程连接2.2.2 配置SSH密钥进行SSH远程连接2.2.3 使用VScode进行SSH远程连接 2.3. 端口映射…

6-5,web3浏览器链接区块链(react+区块链实战)

6-5&#xff0c;web3浏览器链接区块链&#xff08;react区块链实战&#xff09; 6-5 web3浏览器链接区块链&#xff08;调用读写合约与metamask联动&#xff09; 6-5 web3浏览器链接区块链&#xff08;调用读写合约与metamask联动&#xff09; 这里就是浏览器端和智能合约的交…

论文阅读【时空+大模型】ST-LLM(MDM2024)

论文阅读【时空大模型】ST-LLM&#xff08;MDM2024&#xff09; 论文链接&#xff1a;Spatial-Temporal Large Language Model for Traffic Prediction 代码仓库&#xff1a;https://github.com/ChenxiLiu-HNU/ST-LLM 发表于MDM2024&#xff08;Mobile Data Management&#xf…

回归损失和分类损失

回归损失和分类损失是机器学习模型训练过程中常用的两类损失函数&#xff0c;分别适用于回归任务和分类任务。 回归损失函数 回归任务的目标是预测一个连续值&#xff0c;因此回归损失函数衡量预测值与真实值之间的差异。常见的回归损失函数有&#xff1a; 均方误差&#xff…

srs直播内网拉流带宽飙升问题记录

问题背景 srs部署在云服务器上&#xff0c;32核cpu&#xff0c;64G内存&#xff0c;带宽300M. 客户端从srs拉流&#xff0c;发现外网客户端拉流&#xff0c;cpu和带宽都正常。然而内网客户端拉流&#xff0c;拉流人数超过5人以上&#xff0c;带宽就会迅速飙升。 排查 用srs…

Pandas数学函数大揭秘:让数据处理变得如此简单高效,轻松玩转数据分析新纪元!

1.导包 # 导包 import numpy as np import pandas as pd2.聚合函数 df pd.DataFrame(datanp.random.randint(0,100,size(5,3))) df01203550281552376231419335895434679917 # 列非空元素的数量 df.count()0 5 1 5 2 5 dtype: int64# 行非空元素的数量 df.count(ax…