解析生成对抗网络(GAN):原理与应用

目录

一、引言

二、生成对抗网络原理

(一)基本架构

(二)训练过程

三、生成对抗网络的应用

(一)图像生成

无条件图像生成:

(二)数据增强

(三)风格迁移

四、生成对抗网络训练中的挑战与解决策略

(一)模式崩溃

(二)梯度消失


一、引言

生成对抗网络(GAN)自 2014 年被 Goodfellow 等人提出以来,在深度学习领域引起了广泛的关注和研究热潮。它创新性地引入了一种对抗训练的思想,通过生成器和判别器的相互博弈,使得生成器能够学习到数据的潜在分布,从而生成逼真的样本数据。这种独特的机制使得 GAN 在图像生成、文本生成、音频生成等多个领域展现出了巨大的潜力,为人工智能技术的发展带来了新的突破和方向。

二、生成对抗网络原理

(一)基本架构

GAN 主要由两个核心组件构成:生成器(Generator)和判别器(Discriminator)。

  1. 生成器
    • 生成器的任务是接收一个随机噪声向量 (通常从一个简单的分布,如标准正态分布 N(0,1)采样得到),并通过一系列的神经网络层将其映射为与真实数据相似的生成数据G(z)
    • 例如,在图像生成任务中,生成器的输出将是一张与训练数据集中图像具有相似特征的合成图像。
    • 生成器通常采用多层的反卷积神经网络(Deconvolutional Neural Network)或转置卷积神经网络(Transposed Convolutional Neural Network)结构。以生成64*64其网络结构如下:
      import torch
      import torch.nn as nn
      
      class Generator(nn.Module):
          def __init__(self):
              super(Generator, self).__init__()
              # 输入为 100 维的噪声向量
              self.fc = nn.Linear(100, 4 * 4 * 1024)
              self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)
              self.bn1 = nn.BatchNorm2d(512)
              self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
              self.bn2 = nn.BatchNorm2d(256)
              self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
              self.bn3 = nn.BatchNorm2d(128)
              self.deconv4 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)
      
          def forward(self, x):
              x = self.fc(x)
              x = x.view(-1, 1024, 4, 4)
              x = torch.relu(self.bn1(self.deconv1(x)))
              x = torch.relu(self.bn2(self.deconv2(x)))
              x = torch.relu(self.bn3(self.deconv3(x)))
              x = torch.tanh(self.deconv4(x))
              return x

  2. 判别器
  • 判别器的作用是区分输入的数据是来自真实数据分布还是由生成器生成的数据。它接收真实数据 x 或生成数据 G(z),并输出一个表示数据真实性的概率值  D(x)或D(G(z)) ,取值范围在 0 到  1之间,接近  表示数据更可能是真实的,接近  表示数据更可能是生成的。

判别器通常采用卷积神经网络(Convolutional Neural Network)结构。例如,对于判断  彩色图像的判别器网络结构如下:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.relu(self.bn3(self.conv3(x)))
        x = torch.sigmoid(self.conv4(x))
        return x.view(-1)

(二)训练过程

GAN 的训练过程是一个对抗性的迭代过程。

三、生成对抗网络的应用

(一)图像生成

1.无条件图像生成

GAN 可以用于生成各种类型的图像,如人脸图像、风景图像等。例如,在人脸图像生成任务中,通过在大规模人脸数据集上训练 GAN,生成器能够学习到人脸的各种特征,如五官的形状、肤色、表情等,从而生成全新的、逼真的人脸图像。

代码示例:

# 假设已经定义好生成器 G 和判别器 D,以及相关的优化器和损失函数
# 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # 训练判别器
        # 采样噪声
        z = torch.randn(real_images.shape[0], 100).to(device)
        # 生成假图像
        fake_images = G(z)
        # 计算判别器损失
        real_loss = criterion(D(real_images), torch.ones(real_images.shape[0]).to(device))
        fake_loss = criterion(D(fake_images.detach()), torch.zeros(fake_images.shape[0]).to(device))
        d_loss = (real_loss + fake_loss) / 2
        # 更新判别器参数
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 训练生成器
        # 再次采样噪声
        z = torch.randn(real_images.shape[0], 100).to(device)
        # 生成假图像
        fake_images = G(z)
        # 计算生成器损失
        g_loss = criterion(D(fake_images), torch.ones(fake_images.shape[0]).to(device))
        # 更新生成器参数
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

2.条件图像生成

可以通过在生成器和判别器的输入中加入条件信息,实现条件图像生成。例如,根据给定的文本描述生成相应的图像,或者根据特定的类别标签生成属于该类别的图像。

以根据类别标签生成图像为例,在生成器的输入中除了噪声向量 ,还加入类别标签的编码向量 ,生成器的网络结构需要进行相应修改,如:

class ConditionalGenerator(nn.Module):
    def __init__(self, num_classes):
        super(ConditionalGenerator, self).__init__()
        # 输入为 100 维噪声向量和类别编码向量
        self.fc = nn.Linear(100 + num_classes, 4 * 4 * 1024)
        # 后续的反卷积层与无条件生成器类似
        self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        #...

    def forward(self, x, y):
        # 拼接噪声向量和类别编码向量
        x = torch.cat([x, y], dim=1)
        x = self.fc(x)
        x = x.view(-1, 1024, 4, 4)
        x = torch.relu(self.bn1(self.deconv1(x)))
        #...
        return x

(二)数据增强

  • 图像数据增强
    • 在图像分类、目标检测等任务中,数据量不足可能导致模型过拟合。GAN 可以用于生成额外的图像数据来扩充数据集。通过在原始图像数据集上训练 GAN,生成与原始图像相似但又有一定变化的图像,如不同角度、光照条件下的图像,从而增加数据的多样性,提高模型的泛化能力。
  • 其他数据类型的数据增强
    • 除了图像数据,GAN 也可以应用于其他数据类型的数据增强,如文本数据。例如,通过生成与原始文本相似的新文本,扩充文本数据集,有助于训练更强大的文本处理模型,如文本分类、机器翻译等模型。

(三)风格迁移

  • 图像风格迁移原理
    • GAN 可以实现图像风格迁移,即将一幅图像的内容与另一幅图像的风格进行融合。其原理是通过定义内容损失和风格损失,利用生成器生成具有目标风格的图像,同时判别器用于判断生成图像的质量和风格一致性。
    • 例如,使用预训练的 VGG 网络来计算内容损失和风格损失。内容损失衡量生成图像与原始内容图像在特征表示上的差异,风格损失衡量生成图像与目标风格图像在风格特征(如纹理、颜色分布等)上的差异。

代码示例实现风格迁移

import torchvision.models as models
import torch.nn.functional as F

# 加载预训练的 VGG 网络
vgg = models.vgg19(pretrained=True).features.eval().to(device)

# 定义内容损失函数
def content_loss(content_features, generated_features):
    return F.mse_loss(content_features, generated_features)

# 定义风格损失函数
def style_loss(style_features, generated_features):
    style_loss = 0
    for s_feat, g_feat in zip(style_features, generated_features):
        # 计算 Gram 矩阵
        s_gram = gram_matrix(s_feat)
        g_gram = gram_matrix(g_feat)
        style_loss += F.mse_loss(s_gram, g_gram)
    return style_loss

# Gram 矩阵计算函数
def gram_matrix(x):
    b, c, h, w = x.size()
    features = x.view(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram.div(b * c * h * w)

四、生成对抗网络训练中的挑战与解决策略

(一)模式崩溃

问题描述

模式崩溃是 GAN 训练中常见的问题之一,表现为生成器生成的样本多样性不足,往往集中在少数几种模式上。例如,在生成人脸图像时,可能生成的人脸都具有相似的特征,而不能涵盖人脸的多种可能形态。

解决策略

Wasserstein GAN(WGAN):WGAN 对 GAN 的损失函数进行了改进,采用 Wasserstein 距离来衡量真实数据分布和生成数据分布之间的差异,而不是传统的 JS 散度。这使得训练过程更加稳定,减少了模式崩溃的发生。其关键代码修改如下:

# 判别器的最后一层不再使用 Sigmoid 激活函数
self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)
# 定义 WGAN 的损失函数
def wgan_loss(real_pred, fake_pred):
    return -torch.mean(real_pred) + torch.mean(fake_pred)

模式正则化:通过在生成器的损失函数中加入正则化项,鼓励生成器生成更多样化的样本。例如,在生成器的损失函数中加入对生成样本的熵约束,使得生成样本的分布更加均匀。

(二)梯度消失

  • 问题描述
    • 在 GAN 训练初期,当判别器的性能非常好时,生成器的梯度可能会变得非常小,导致生成器难以更新参数,无法有效地学习到数据的分布。这是因为判别器能够很容易地区分真实数据和生成数据,使得生成器的损失函数接近饱和,梯度趋近于 。
  • 解决策略
    • 梯度惩罚(Gradient Penalty):在判别器的损失函数中加入梯度惩罚项,限制判别器的梯度大小,使得判别器不会过于强大,从而缓解生成器的梯度消失问题。例如,在 WGAN-GP(Wasserstein GAN with Gradient Penalty)中,梯度惩罚项的计算如下:
      def gradient_penalty(critic, real, fake, device):
          BATCH_SIZE, C, H, W = real.shape
          # 随机采样插值系数
          alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
          # 计算插值数据
          interpolated_images = real * alpha + fake * (1 - alpha)
          # 计算判别器对插值数据的输出
          mixed_scores = critic(interpolated_images)
          # 计算梯度
          gradient = torch.autograd.grad(
              inputs=interpolated_images,
              outputs=mixed_scores,
              grad_outputs=torch.ones_like(mixed_scores),
              create_graph=True,
              retain_graph=True,
          )[0]
          # 计算梯度惩罚项
          gradient = gradient.view(gradient.shape[0], -1)
          gradient_norm = gradient.norm(2, dim=1)
          gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
          return gradient_penalty

    • 使用 Leaky ReLU 激活函数:在判别器和生成器的网络中使用 Leaky ReLU 激活函数替代传统的 ReLU 激活函数。Leaky ReLU 允许在负半轴有一个较小的斜率,从而避免了在某些情况下神经元完全不激活导致的梯度消失问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/927387.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

零售餐饮收银台源码

收银系统早已成为门店经营的必备软件工具,因为各个连锁品牌有自己的经营模式,自然对收银系统需求各有不同,需要有相应的功能模块来实现其商业模式。 1. 适用行业 收银系统源码适用于零售、餐饮等行业门店,如商超便利店、水果生鲜…

我的第一个创作纪念日 —— 梦开始的地方

前言 时光荏苒,转眼间,我已经在CSDN这片技术沃土上耕耘了365天 今天,我迎来了自己在CSDN的第1个创作纪念日,这个特殊的日子不仅是对我过去努力的肯定,更是对未来持续创作的激励 机缘 回想起初次接触CSDN,那…

mac终端自定义命令打开vscode

1.打开终端配置文件 open -e ~/.bash_profile终端安装了zsh,那么配置文件是.zshrc(打开zsh配置,这里举🌰使用zsh) sudo open -e ~/.zshrc 2.在zshrc配置文件中添加新的脚本(这里的code就是快捷命令可以进…

计算帧率、每秒过多少次

1、c #include <iostream> #include <opencv2/opencv.hpp> #include <string> #include <thread> #include <atomic>using namespace std;const int NUM_THREADS 1; // 线程数量std::atomic<int> frameCounts[NUM_THREADS]; // 每个线程…

【在Linux世界中追寻伟大的One Piece】读者写者问题与读写锁

目录 1 -> 读者写者问题 1.1 -> 什么是读者写者问题 1.2 -> 读者写者与生产消费者的区别 1.3 -> 如何理解读者写者问题 2 -> 读写锁 2.1 -> 读写锁接口 3 -> 读者优先(Reader-Preference) 4 -> 写者优先(Writer-Preference) 1 -> 读者写者…

PS的功能学习(修复、画笔)

混合器画笔工具 就像&#xff0c;电子毛笔 关键功能有两个&#xff0c;自带一个混合器色板 清理画笔是全清&#xff0c;换一支新的毛笔&#xff0c;执行完之后在判断是否载入画笔 载入画笔就是把前景色上的颜色进行叠加处理&#xff0c;重新混入当前的混合色 &#xff08;…

centos 7 离线安装postgis插件

前一段时间记录了下如何在centos7中离线安装postgresql&#xff0c;因为工作需要&#xff0c;我不仅要安装postgresql&#xff0c;还需要安装postgis插件&#xff0c;这篇文章记录下postgis插件的安装过程。 1. 安装前的参考 如下的链接都是官网上的链接&#xff0c;对你安装p…

Vue 90 ,Element 13 ,Vue + Element UI 中 el-switch 使用小细节解析,避免入坑(获取后端的数据类型自动转变)

目录 前言 在开发过程中&#xff0c;我们经常遇到一些看似简单的问题&#xff0c;但有时正是这些细节问题让我们头疼不已。今天&#xff0c;我就来和大家分享一个我在开发过程中遇到的 el-switch 使用的小坑&#xff0c;希望大家在使用时能够避免。 一. 问题背景 二. 问题分…

同时使用Tmini和GS2两个雷达

24.12.02 要求&#xff1a;同时使用两个雷达。 问题在于:两个雷达都是ydlidar&#xff0c;使用同一个包。 因此同时启动GS2.launch和Tmini.launch会调用同一个功能节点&#xff0c;使用同一个cpp文件。 方法&#xff1a;新建一个cpp节点。 但同时保持在同一个坐标系&#xff0…

高等数学函数的性质

牛顿二项公式 ( x y ) n ∑ k 0 n C n k ⋅ x n − k y k (xy)^n\stackrel{n}{\sum\limits_{k0}}C^k_n\sdot x^{n-k}y^k (xy)nk0∑​n​Cnk​⋅xn−kyk. 映射 f : X → Y f:X\rightarrow Y f:X→Y&#xff0c; f f f 为 X X X 到 Y Y Y 的映射。 f f f 是一个对应关系&am…

【MySQL】深度学习数据库开发技术:mysql事务穿透式解析

前言&#xff1a;本节内容开始讲解事务。 博主计划用三节来讲解事务。 本篇为第一节&#xff0c; 主要解释什么是事务&#xff0c; 事务有什么用。 以及事物的基本操作和异常退出回滚情况。 下面不多说&#xff0c;友友们&#xff0c; 开始学习吧&#xff01; ps&#xff1a;本…

Swift解题 | 求平面上同一条直线的最多点数

文章目录 前言摘要问题描述解题思路Swift 实现代码代码分析示例测试与结果时间复杂度空间复杂度总结关于我们 前言 本题由于没有合适答案为以往遗留问题&#xff0c;最近有时间将以往遗留问题一一完善。 149. 直线上最多的点数 不积跬步&#xff0c;无以至千里&#xff1b;不积…

使用Ansible自动化部署Zabbix6监控

1、获取Ansible离线部署包 链接&#xff1a;https://pan.baidu.com/s/1EjI02Ni8m9J4eJeBcJ-ZUQ?pwdzabx 提取码&#xff1a;zabx 2、安装Ansible wget -O /etc/yum.repos.d/epel.repo https://mirrors.aliyun.com/repo/epel-7.repo yum -y install ansible3、修改hosts文件…

lua闭包Upvalue

闭包 lua任何函数都是闭包&#xff0c;闭包至少带1个upValue&#xff1b; CClosure是使用Lua提供的lua_pushcclosure这个C-Api加入到虚拟栈中的C函数&#xff0c;它是对LClosure的一种C模拟 如string.gmatch就是cclosure 定义&#xff1a; #define ClosureHeader \CommonH…

二叉搜索树之遍历

二叉搜索树是一种重要的数据结构&#xff0c;它的每个节点最多有两个子节点&#xff0c;称为左子节点和右子节点。 二叉搜索树的特性是&#xff1a;对于树中的每个节点&#xff0c;其左子树中的所有节点的值都小于该节点的值&#xff0c;而右子树中的所有节点的值都大于该节点…

Java基础访问修饰符全解析

一、Java 访问修饰符概述 Java 中的访问修饰符用于控制类、方法、变量和构造函数的可见性和访问权限&#xff0c;主要有四种&#xff1a;public、protected、default&#xff08;无修饰符&#xff09;和 private。 Java 的访问修饰符在编程中起着至关重要的作用&#xff0c;它…

安心护送转运平台小程序

安心护送转运平台小程序是一款基于FastAdminThinkPHPUniapp开发的非急救救护车租用转运平台小程序系统&#xff0c;可以根据运营者的业务提供类似短途接送救护服务&#xff0c;重症病人转运服务&#xff0c;长途跨省护送服务。

人工智能技术在外骨骼机器人中的应用,发展历程与原理介绍

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下 人工智能技术在外骨骼机器人中的应用&#xff0c;发展历程与原理介绍 。外骨骼机器人是一种 套在人体外部的可穿戴机器人装置 &#xff0c;旨在增强人类的身体能力和运动功能。其独特之处在于能够与人体紧密配合&a…

类型转换与IO流:C++世界的变形与交互之道

文章目录 前言&#x1f384;一、类型转换&#x1f388;1.1 隐式类型转换&#x1f388;1.2 显式类型转换&#x1f381;1. C 风格强制类型转换&#x1f381;2. C 类型转换操作符 &#x1f388;1.3 C 类型转换操作符详解&#x1f381;1. static_cast&#x1f381;2. dynamic_cast&…

如何手搓一个智能宠物喂食器

背景 最近家里的猫胖了&#xff0c;所以我就想做个逗猫棒。找了一圈市场上的智能逗猫棒&#xff0c;运行轨迹比较单一&#xff0c;互动性不足。 轨迹单一&#xff0c;活动范围有限 而我希望后续可以结合人工智能物联网&#xff0c;通过摄像头来捕捉猫的位置&#xff0c;让小…