什么是GAN?

一、基本概念

        生成对抗网络(Generative Adversarial Network,GAN)是一种由两个神经网络共同组成深度学习模型:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗的方式进行训练,生成器尝试伪造逼真的样本数据,而判别器则负责判断输入的数据是真实数据还是生成器伪造出来的数据。理想情况下,判别器对真实样本和生成样本的判断概率都是1/2,意味着判别器已经无法判断生成器生成的数据真假。

二、模型原理

        GAN的模型原理并不复杂。首先,GAN由以下两个子模型组成:

  • 生成器(Generator)从随机噪声中生成数据,目标是欺骗判别器,使其认为生成的数据是真实的。
  • 判别器(Discriminator):判断输入数据是来自真实数据分布还是生成器,目标是正确区分真实数据和生成数据。

        然后,GAN的损失函数是训练的核心,我们需要构建一个合适的损失函数用于衡量生成器和判别器的表现:

  • 生成器损失(G_loss):通常表示为最大化判别器对其生成样本的错误分类概率,也就是判别器判定所有生成数据均为真。
  • 判别器损失(D_loss):由两部分组成,一部分是真实样本的损失(标签为1),另一部分是生成样本的损失(标签为0)。

        最后,我们通过算法设计来交替训练生成器和判别器,例如生成器每训练5个Epoch,我们就训练一次判别器:

  • 训练判别器:提高其区分真实样本和生成样本的能力。
  • 训练生成器:提高其生成真实样本的能力,目标是最大化判别器将其生成样本识别为真实样本的概率。

三、python实现

1、导库

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA

2、数据处理

        这里我们的目标是训练一个生成对抗网络来生成iris数据,使用sklearn的iris数据集训练。这意味着,我们输入给生成器的信息中需要包含类别信息,这样生成器才能生成对应类别的数据样本。当然,这一步不是必要的,在类别不敏感的任务中,只需要生成符合要求的数据即可。

# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target

# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)

# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
# torch.Size([100, 3])
labels = encoder.fit_transform(labels.reshape(-1, 1))

# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)

# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

3、构建生成器

        这里,我们构建一个全连接神经网络。生成器的输入包括随机初始化的x,以及x对应的期望类别,期望类别是可以真实标签,表示生成对应类别下的数据样本。

# 生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, label_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

4、构建判别器

        这里,我们的判别器实际上是一个二分类模型。判别器的输入维度跟生成器一直,都需要考虑类别信息。

# 判别器网络
class Discriminator(nn.Module):
    def __init__(self, input_dim, label_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

5、超参数设置

        值得注意的是,我们分别为生成器和判别器构造一个优化器,从而便于分开训练两个子模型。

# 设置超参数
latent_dim = 100
data_dim = data.shape[1]
label_dim = labels.shape[1]
lr = 0.0002
num_epochs = 200

# 初始化生成器和判别器
generator = Generator(latent_dim, label_dim, data_dim)
discriminator = Discriminator(data_dim, label_dim)

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 损失函数
criterion = nn.BCELoss()

6、模型训练

        这里,我们选择了分开训练生成器和判别器,在一个epoch中,先训练3次生成器,再训练一次判别器。这样的目的是增加生成器的学习时间,从而使得生成的样本更为真实。

# 训练GAN
for epoch in range(num_epochs):
    for i, (real_data, real_labels) in enumerate(dataloader):
        batch_size = real_data.size(0)
        
        # 当前仅训练生成器
        generator.train()
        discriminator.eval()

        # 迭代训练生成器,这里是每个epoch训练3次
        for _ in range(3):
            z = torch.randn(batch_size, latent_dim)
            # 直接使用真实标签即可,这里的标签代表的是样本类别,目的是让模型学习到类别差异
            # 生成器生成的是各对应类别的数据
            fake_data = generator(z, real_labels)
            # 使用判别器对生成的假数据进行分类
            outputs = discriminator(fake_data, real_labels)
            # 基于判别器的结果计算生成器的损失,目标是让判别器认为生成的数据是真实的(标签为1)
            # 如果这里使用的是torch.zeros则生成器的结果将会非常差,几乎无法生成真实数据
            # 这是由于我们的目标是让outputs逼近全1向量,也就是让判别器认为所有生成的数据都是真实的,这样才能让生成样本越来越真实
            g_loss = criterion(outputs, torch.ones(batch_size, 1))

            # 反向传播生成器的梯度
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # 当前仅训练判别器
        generator.eval()
        discriminator.train()

        # 训练判别器,真实样本标签为1,生成样本标签为0
        real_targets = torch.ones(batch_size, 1)
        fake_targets = torch.zeros(batch_size, 1)

        # 真实数据损失
        outputs = discriminator(real_data, real_labels)
        d_loss_real = criterion(outputs, real_targets)
        real_score = outputs

        # 生成假数据,计算损失
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z, real_labels)
        outputs = discriminator(fake_data.detach(), real_labels)
        # 这里的目标与上面生成器部分相反,我们是要让outputs逼近全0向量,也就是全部预测出假数据
        # 所以fake_targets是一个全0向量
        d_loss_fake = criterion(outputs, fake_targets)
        fake_score = outputs

        # 总的判别器损失
        d_loss = d_loss_real + d_loss_fake

        # 反向传播判别器的梯度
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

    if epoch%10==0:
        # 打印损失
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
              f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')

7、生成新数据

        最后,我们使用训练好的GAN中的生成器来生成一批新数据。可以看到,效果不错。

# 生成新数据
num_samples = 100
z = torch.randn(num_samples, latent_dim)
labels = np.array([0, 1, 2] * (num_samples // 3) + [0] * (num_samples % 3))
labels = encoder.transform(labels.reshape(-1, 1))
labels = torch.FloatTensor(labels)
generated_data = generator(z, labels).detach().numpy()

# 降维
pca = PCA(n_components=2)
data_2d = pca.fit_transform(data)
generated_data_2d = pca.transform(generated_data)

# 可视化生成的数据
plt.figure(figsize=(10, 5))
for i in range(3):
    real_class_data = data_2d[iris.target == i]
    generated_class_data = generated_data_2d[np.argmax(labels.numpy(), axis=1) == i]
    plt.scatter(real_class_data[:, 0], real_class_data[:, 1], label=f'Real Class {i}')
    plt.scatter(generated_class_data[:, 0], generated_class_data[:, 1], label=f'Generated Class {i}')
plt.legend()
plt.show()

7aab63ce61704b35afa41a14652ac1e7.png

四、完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA


# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target

# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)

# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
labels = encoder.fit_transform(labels.reshape(-1, 1))

# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)

# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, label_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

# 判别器网络
class Discriminator(nn.Module):
    def __init__(self, input_dim, label_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

# 设置超参数
latent_dim = 100
data_dim = data.shape[1]
label_dim = labels.shape[1]
lr = 0.0002
num_epochs = 200

# 初始化生成器和判别器
generator = Generator(latent_dim, label_dim, data_dim)
discriminator = Discriminator(data_dim, label_dim)

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 损失函数
criterion = nn.BCELoss()

# 训练GAN
for epoch in range(num_epochs):
    for i, (real_data, real_labels) in enumerate(dataloader):
        batch_size = real_data.size(0)
        
        generator.train()
        discriminator.eval()
        # 迭代训练生成器,这里是每个epoch训练3次
        for _ in range(3):
            z = torch.randn(batch_size, latent_dim)
            # 直接使用真实标签即可,这里的标签代表的是样本类别,目的是让模型学习到类别差异
            # 生成器生成的是各对应类别的数据
            fake_data = generator(z, real_labels)
            # 使用判别器对生成的假数据进行分类
            outputs = discriminator(fake_data, real_labels)
            # 基于判别器的结果计算生成器的损失,目标是让判别器认为生成的数据是真实的(标签为1)
            # 如果这里使用的是torch.zeros则生成器的结果将会非常差,几乎无法生成真实数据
            # 这是由于我们的目标是让outputs逼近全1向量,也就是让判别器认为所有生成的数据都是真实的,这样才能让生成样本越来越真实
            g_loss = criterion(outputs, torch.ones(batch_size, 1))

            # 反向传播生成器的梯度
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        generator.eval()
        discriminator.train()
        # 训练判别器,真实样本标签为1,生成样本标签为0
        real_targets = torch.ones(batch_size, 1)
        fake_targets = torch.zeros(batch_size, 1)

        # 真实数据损失
        outputs = discriminator(real_data, real_labels)
        d_loss_real = criterion(outputs, real_targets)
        real_score = outputs

        # 生成假数据,计算损失
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z, real_labels)
        outputs = discriminator(fake_data.detach(), real_labels)
        # 这里的目标与上面生成器部分相反,我们是要让outputs逼近全0向量,也就是全部预测出假数据
        # 所以fake_targets是一个全0向量
        d_loss_fake = criterion(outputs, fake_targets)
        fake_score = outputs

        # 总的判别器损失
        d_loss = d_loss_real + d_loss_fake

        # 反向传播判别器的梯度
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

    if epoch%10==0:
        # 打印损失
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
              f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')

# 生成新数据
num_samples = 100
z = torch.randn(num_samples, latent_dim)
labels = np.array([0, 1, 2] * (num_samples // 3) + [0] * (num_samples % 3))
labels = encoder.transform(labels.reshape(-1, 1))
labels = torch.FloatTensor(labels)
generated_data = generator(z, labels).detach().numpy()

# 降维
pca = PCA(n_components=2)
data_2d = pca.fit_transform(data)
generated_data_2d = pca.transform(generated_data)

# 可视化生成的数据
plt.figure(figsize=(10, 5))
for i in range(3):
    real_class_data = data_2d[iris.target == i]
    generated_class_data = generated_data_2d[np.argmax(labels.numpy(), axis=1) == i]
    plt.scatter(real_class_data[:, 0], real_class_data[:, 1], label=f'Real Class {i}')
    plt.scatter(generated_class_data[:, 0], generated_class_data[:, 1], label=f'Generated Class {i}')
plt.legend()
plt.show()

五、总结

        生成对抗网络是一个很经典的深度学习模型,它在诸多领域中发挥着重要作用。除了超参数调整之外,训练GAN的另一个关键步骤是构造一个合适的训练策略。例如,可以同时训练生成器和判别器,也可以交替训练二者,或者先训练生成器再训练判别器等等。但是,这两个网络是相互博弈的,由于生成器参数是随机初始化的,一开始生成的数据质量往往较差。我们的策略一般是先让生成器变强(通过构造更复杂的网络结构或者更多的训练次数),让生成的数据质量先提升。这样随着训练的迭代,生成的样本越来越逼真,判别器也不得不为了最小化D_loss而提升自身的能力。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/927763.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RNN And CNN通识

CNN And RNN RNN And CNN通识一、卷积神经网络(Convolutional Neural Networks,CNN)1. 诞生背景2. 核心思想和原理(1)基本结构:(2)核心公式:(3)关…

【数据事务】.NET开源 ORM 框架 SqlSugar 系列

.NET开源 ORM 框架 SqlSugar 系列 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列【Code First】.NET开源 ORM 框架 SqlSugar 系列【数据事务…

南昌大学(NCU)羽毛球场地预约脚本

在冬天进行羽毛球运动是一个很好的选择,它能帮助你保持身体活力,增强心肺功能,并促进血液循环。但是室友和师弟师妹反应,学校的羽毛球场地有限,手速慢的根本预约不到场地。 中午12:00准时开放预约&#xff…

三种方式(oss、本地、minio)图片的上传下载

一、OSS 1、前期准备 1.1 注册阿里云账号,开启对象存储oss功能,创建一个bucket(百度教程多的是,跟着创建一个就行,创建时注意存储类型是标准存储,读写权限是公共读) 有的在创建桶时读写属性是…

关于Nginx前后端分离部署spring boot和vue工程以及反向代理的配置说明

最近项目中用到关于Nginx前后端分离部署spring boot和vue工程以及反向代理的配置,总结了一下说明: 1、后端是spring boot工程,端口8000,通过 jar命令启动 nohup java -jar xxx-jsonflow-biz.jar > /usr/local/nohup.out 2>…

debian 11 虚拟机环境搭建过坑记录

目录 安装过程系统配置修改 sudoers 文件网络配置换源安装桌面mount nfs 挂载安装复制功能tab 无法补全其他安装 软件配置eclipse 配置git 配置老虚拟机硬盘挂载 参考 原来去 debian 官网下载了一个最新的 debian 12,安装后出现包依赖问题,搞了半天&…

Android:生成Excel表格并保存到本地

提醒 本文实例是使用Kotlin进行开发演示的。 一、技术方案 org.apache.poi:poiorg.apache.poi:poi-ooxml 二、添加依赖 [versions]poi "5.2.3" log4j "2.24.2"[libraries]#https://mvnrepository.com/artifact/org.apache.poi/poi apache-poi { module…

RK3576技术笔记之一 RK3576单板介绍

第二篇嘛,亮亮我们做出来的板子,3576这个片子的基本功能接口单板都做了,接口数量肯定是比不上3588(PS:这个我们也在做,后续都完成后会发文章),但是比起3568来说还是升级了&#xff0…

SQL进阶技巧:如何寻找同一批用户 | 断点分组应用【最新面试题】

目录 0 问题描述 1 数据准备 2 问题分析 ​编辑 3 小结 0 问题描述 用户登录时间不超过10分钟的视为同一批用户,找出以下用户哪些属于同一批用户(SQL实现) 例如: user_name time a 2024-10-01 09:55 b 2024-10-01 09:57 c 2024-10-01…

数字图像处理(11):RGB转YUV

(1)RGB颜色空间 RGB颜色空间,是一种基于红色、绿色、蓝色三种基本颜色进行混合的颜色空间,通过这三种颜色的叠加,可以产生丰富而广泛的颜色。RGB颜色空间在计算机图像处理、显示器显示、摄影和影视制作等领域具有广泛应…

利用Ubuntu批量下载modis图像(New)

由于最近modis原来批量下载的代码不再直接给出,因此,再次梳理如何利用Ubuntu下载modis数据。 之前的下载代码为十分长,现在只给出一部分,需要自己再补充另一部分。之前的为: 感谢郭师兄的指导(https://blo…

HTTP 长连接(HTTP Persistent Connection)简介

HTTP长连接怎么看? HTTP 长连接(HTTP Persistent Connection)简介 HTTP 长连接(Persistent Connection)是 HTTP/1.1 的一个重要特性,它允许在一个 TCP 连接上发送多个 HTTP 请求和响应,而无需为…

淘宝商品详情主图SKU图价格|品牌监控|电商API接口

淘宝/天猫获得淘宝商品详情 API 返回值说明 item_get-获得淘宝商品详情 taobao.item_get 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中&a…

单片机学习笔记 17. 串口通信-发送汉字

更多单片机学习笔记:单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…

五层网络协议(封装和分用)

目录 七层网络协议五层网络协议封装1.应用层2.传输层3.网络层4.数据链路层5.物理层 分用1. 物理层2.数据链路层3.网络层 IP 协议4.传输层 UDP 协议5.应用层 七层网络协议 网络通信过程中,需要涉及到的细节,其实是非常非常多的,如果要有一个协…

阿里云人工智能平台(PAI)免费使用教程

文章目录 注册新建实例交互式建模(DSW)注册 注册阿里云账号进行支付宝验证 新建实例 选择资源信息和环境信息,填写实例名称 资源类型需要选择公共资源,才能使用资源包进行抵扣。目前每月送250计算时。1 * NVIDIA A10 8 vCPU 30 GiB 1 * 24 GiB1 * NVIDIA V100 8 vCPU 32 Gi…

【实战】Oracle基础之控制文件内容的5种查询方法

关于Jady: ★工作经验:近20年IT技术服务经验,熟悉业务又深耕技术,为业务加持左能进行IT技术规划,右能处理综合性故障与疑难杂症; ★成长历程:网络运维、主机/存储运维、程序/数据库开发、大数…

蓝桥杯第 23 场 小白入门赛

一、前言 好久没打蓝桥杯官网上的比赛了,回来感受一下,这难度区分度还是挺大的 二、题目总览 三、具体题目 3.1 1. 三体时间【算法赛】 思路 额...签到题 我的代码 // Problem: 1. 三体时间【算法赛】 // Contest: Lanqiao - 第 23 场 小白入门赛 …

使用 Pytorch 构建 Vanilla GAN

文章目录 一、说明二、什么是 GAN?三、使用 PyTorch 的简单 GAN(完整解释的代码示例)3.1 配置变量3.2 、PyTorch 加速3.3 构建生成器3.4 构建鉴别器 四、准备数据集五、初始化函数六、前向和后向传递七、执行训练步骤八、结果 一、说明 使用…

【Windows 11专业版】使用问题集合

博文将不断学习补充 I、设置WIN R打开应用默认使用管理员启动 1、WIN R输入 secpol.msc 进入“本地安全策略”。 2、按照如下路径,找到条目: “安全设置”—“本地策略”—“安全选项”—“用户账户控制:以管理员批准模式运行所有管理员” …