AI开发:生成式对抗网络入门 模型训练和图像生成 -Python 机器学习

阶段1:GAN是个啥?

生成式对抗网络(Generative Adversarial Networks, GAN),名字听着就有点“对抗”的意思,没错!它其实是两个神经网络互相斗智斗勇的游戏:

  1. 生成器(Generator):负责造假,生成一些以假乱真的数据。
  2. 判别器(Discriminator):负责打假,判断数据是真还是假。

想象一下,生成器是个假币制造商,判别器是个验钞机。假币制造商不断提升造假能力,验钞机也不断升级打假技巧。最终的目标是生成的假币足以以假乱真,让验钞机无法区分。

生成式对抗网络(GAN)是一种由 Ian Goodfellow 和他的团队在2014年提出的深度学习模型。GAN 本质上是一种用于生成与真实数据分布相似的“新数据”的方法,常用于图像生成、风格转换和数据增强等任务。

一、GAN 的基本概念
1. 两个网络:生成器(Generator)和判别器(Discriminator)

GAN 的核心思想是利用两个神经网络相互对抗:

  • 生成器 (G): 学习生成接近真实数据的“假数据”。其目标是“骗过”判别器,使其认为假数据是真的。
  • 判别器 (D): 学习区分真实数据和生成器生成的假数据。其目标是提高“识别假数据的能力”。

两者形成了一种动态博弈:

  • 生成器不断改进以生成更逼真的数据。
  • 判别器不断改进以更准确地区分真假数据。

最终目标:生成器生成的数据和真实数据难以区分,判别器无法给出明确的判断。

2. 训练目标

GAN 的训练目标可以通过以下损失函数来描述:

  • 判别器的损失:最大化真实数据的得分,最小化假数据的得分。
  • 生成器的损失:最小化判别器对假数据的判断分数(即尽量骗过判别器)。

数学公式为:

这里:

  • D(x)D(x) 表示判别器给真实数据 xx 的打分。
  • G(z)G(z) 表示生成器根据随机噪声 zz 生成的假数据。
3. GAN 的对抗过程

训练过程通常分为两步:

  1. 更新判别器: 让判别器学习如何区分真实和假数据。
  2. 更新生成器: 让生成器学习生成更真实的数据,以骗过判别器。

二、直观例子:警察与造假者

你可以将 GAN 的训练过程类比为“警察(判别器)与造假者(生成器)”之间的较量:

  • 一开始,造假者技术拙劣,警察很容易识破假币。
  • 随着时间推移,造假者的造假技术逐渐提高,而警察也在不断升级检测手段。
  • 最终,假币与真币变得极为相似,警察几乎无法分辨。

阶段2:从头写个最简单的GAN

import torch
import torch.nn as nn
import torch.optim as optim

# 1. 生成器(Generator):简单的全连接网络
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),  # 输入 100 维噪声,输出中间隐藏层 256 维
            nn.ReLU(),            # 激活函数 ReLU,增加非线性
            nn.Linear(256, 784),  # 隐藏层输出 784 维数据(28x28 图像展平后)
            nn.Tanh()             # 将输出限制到 [-1, 1],方便后续训练
        )
    
    def forward(self, z):
        return self.model(z)

# 2. 判别器(Discriminator):另一个简单的全连接网络
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),       # 输入 784 维数据(展平的图像)
            nn.LeakyReLU(0.2),         # 激活函数,允许小负值(更鲁棒)
            nn.Linear(256, 1),         # 输出一个值(真 or 假)
            nn.Sigmoid()               # 输出概率,范围 [0, 1]
        )
    
    def forward(self, x):
        return self.model(x)


# 初始化网络
G = Generator()
D = Discriminator()

# 优化器
optimizer_G = optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = optim.Adam(D.parameters(), lr=0.0002)

# 损失函数:二分类交叉熵
criterion = nn.BCELoss()

 代码释疑:

这段代码实现了生成式对抗网络(GAN)的生成器(Generator)和判别器(Discriminator),并为它们设置了优化器和损失函数。以下是对相关内容的详细解释,帮助你理解各个部分的功能。


1. Generator 类:生成器

生成器的作用是生成假数据,用来骗过判别器。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),  # 输入 100 维噪声,输出中间隐藏层 256 维
            nn.ReLU(),            # 激活函数 ReLU,增加非线性
            nn.Linear(256, 784),  # 隐藏层输出 784 维数据(28x28 图像展平后)
            nn.Tanh()             # 将输出限制到 [-1, 1],方便后续训练
        )
    
    def forward(self, z):
        return self.model(z)
关键点:
  • 输入:

    • 生成器的输入是一个随机噪声 z,形状为 [batch_size, 100]
    • 噪声是生成器的起点,让它从随机性中学习目标数据分布。
  • 输出:

    • 输出 784 个值,对应一张 28x28 的图像展平(如 MNIST 数据)。
    • 使用 Tanh 将输出限制在 [-1, 1] 区间,通常是为了和真实数据的归一化范围一致。

2. Discriminator 类:判别器

判别器的作用是判断输入数据是真实的还是生成的。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),       # 输入 784 维数据(展平的图像)
            nn.LeakyReLU(0.2),         # 激活函数,允许小负值(更鲁棒)
            nn.Linear(256, 1),         # 输出一个值(真 or 假)
            nn.Sigmoid()               # 输出概率,范围 [0, 1]
        )
    
    def forward(self, x):
        return self.model(x)
关键点:
  • 输入:

    • 输入是展平的 28x28 图像(784 维),可以是真实数据或生成器的假数据。
  • 输出:

    • 输出是一个概率值,0 表示假,1 表示真。
    • 使用 Sigmoid 将值映射到 [0, 1] 区间。
  • LeakyReLU:

    • 激活函数 LeakyReLU(0.2) 在输入为负值时保留一定斜率(0.2),解决 ReLU 的“死区”问题,使训练更稳定。

3. 优化器

优化器用于更新模型的参数,使损失函数逐渐减小。

optimizer_G = optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = optim.Adam(D.parameters(), lr=0.0002)
  • Adam 优化器:

    • 一种改进的梯度下降算法,适用于深度学习模型,尤其是 GAN。
    • 自动调整学习率,提高收敛速度。
  • 学习率 (lr=0.0002):

    • 学习率设置为 0.0002,是 GAN 训练中一个常见的经验值。
  • 目标:

    • optimizer_G 优化生成器的参数,使其生成更逼真的数据。
    • optimizer_D 优化判别器的参数,使其更好地区分真假数据。

4. 损失函数:BCELoss

BCELoss 是二分类交叉熵损失函数,用于计算判别器和生成器的损失。

criterion = nn.BCELoss()
什么是交叉熵?

交叉熵是一种用来衡量两个概率分布相似度的损失函数,公式如下:

  • yiy_i:真实标签(1 表示真,0 表示假)。
  • pip_i:模型预测的概率值(判别器的输出)。
在 GAN 中的作用:
  1. 判别器的损失:

    • 判别器的目标是区分真实数据和生成器生成的假数据。
    • 对于真实数据,y = 1;对于假数据,y = 0
    • 损失函数让判别器尽量输出接近真实标签的概率。
  2. 生成器的损失:

    • 生成器的目标是让判别器认为假数据是真实的。
    • 生成器通过 GAN 的损失函数间接影响判别器的输出,目标是让判别器输出 y = 1

5. 上述代码小结

  • 生成器 (G): 学习生成逼真的假数据。
  • 判别器 (D): 学习区分真实数据和假数据。
  • 损失函数 (BCELoss): 衡量模型输出概率和目标标签之间的差异。
  • 优化器 (Adam): 调整模型参数,使损失函数最小化。

在训练过程中:

  • 生成器试图最小化生成器的损失。
  • 判别器试图最大化判别器的准确率。

这段代码是 GAN 的基础骨架,你可以在此基础上进行实验,比如用它来生成 MNIST 图像!

题外话:

PyTorch(简称 torch)是一个流行的开源深度学习框架,它提供了许多用于构建和训练神经网络的功能。它特别以易用性、灵活性和性能而著称,是机器学习和深度学习领域的常用工具之一。下面,我们来了解一下 PyTorch 的作用,以及在这段 GAN 代码中它是如何发挥作用的。

1. PyTorch 的基本功能

PyTorch 提供了以下几个关键功能:

  • 张量(Tensor):

    • PyTorch 中的核心数据结构是张量(torch.Tensor),类似于 NumPy 的数组,但是张量支持 GPU 加速。
    • 张量是神经网络中的数据载体,存储输入数据、权重、偏置等。
  • 自动求导(Autograd):

    • PyTorch 提供自动求导功能,能够计算神经网络中每一层的梯度,简化了反向传播算法的实现。
    • 当你定义模型并传入数据后,PyTorch 会自动计算损失函数的梯度,并更新模型的参数。
  • 构建和训练神经网络:

    • 使用 torch.nn 提供的模块,可以方便地构建神经网络的各层(如全连接层、卷积层、激活函数等)。
    • torch.optim 提供了优化算法(如 SGD、Adam)来训练模型。
  • GPU 加速:

    • PyTorch 可以利用 GPU(如 CUDA)来加速计算。你可以将张量和模型移动到 GPU 上,这样就能提高训练速度。

阶段3:它们怎么斗起来?

核心是两步:

  1. 训练判别器:真图片标为1,假图片标为0,看看它能不能区分真伪。
  2. 训练生成器:假图片骗过判别器,努力让判别器给它打1分。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据加载(MNIST 数据集)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)

# 训练循环
epochs = 10
for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # ========== 1. 训练判别器 ==========
        # 真数据
        real_imgs = real_imgs.view(real_imgs.size(0), -1)  # 展平图片
        real_labels = torch.ones(real_imgs.size(0), 1)  # 真图片标签为1
        
        # 假数据
        z = torch.randn(real_imgs.size(0), 100)  # 随机噪声
        fake_imgs = G(z)
        fake_labels = torch.zeros(real_imgs.size(0), 1)  # 假图片标签为0
        
        # 判别器的预测和损失
        real_preds = D(real_imgs)
        fake_preds = D(fake_imgs.detach())  # 假图片不更新生成器
        loss_real = criterion(real_preds, real_labels)
        loss_fake = criterion(fake_preds, fake_labels)
        loss_D = loss_real + loss_fake
        
        # 优化判别器
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        
        # ========== 2. 训练生成器 ==========
        z = torch.randn(real_imgs.size(0), 100)
        fake_imgs = G(z)
        fake_preds = D(fake_imgs)
        loss_G = criterion(fake_preds, real_labels)  # 欺骗判别器的损失
        
        # 优化生成器
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # 打印进度
        if i % 200 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{i}/{len(dataloader)}], "
                  f"D Loss: {loss_D.item():.4f}, G Loss: {loss_G.item():.4f}")

这段代码实现了一个基本的 生成式对抗网络(GAN) 训练过程,使用 MNIST 数据集 生成与真实手写数字类似的图像。执行这段代码会产生以下几个结果:

1. 数据加载(MNIST 数据集)

首先,代码通过 torchvision 中的 datasets.MNIST 加载了 MNIST 数据集。这个数据集包含了 60,000 张手写数字的训练图像和 10,000 张测试图像(这里只使用了训练集)。数据被转换为 PyTorch 张量并做了标准化处理,使每个像素值在 [-1, 1] 之间。然后,DataLoader 将数据划分为批次(batch),每次加载 64 张图像。

2. 训练循环

接下来,代码进入训练循环,在每个 epoch 中,它会进行以下操作:

(1)训练判别器(Discriminator)
  • 真数据:

    • 从 MNIST 数据集中提取实际的手写数字图像,将图像展平为 784 维(28x28 的像素展平)。
    • 创建真实标签,所有真实图像的标签为 1
  • 假数据:

    • 从随机噪声 z(100 维的向量)中生成假图像。
    • 创建假的标签,所有生成的假图像标签为 0
  • 判别器损失:

    • 判别器会分别计算它对真实数据和假数据的预测,使用二元交叉熵损失 BCELoss 计算真实数据和假数据的损失。
    • loss_real 是判别器对真实图像的损失,loss_fake 是对假图像的损失,最终判别器的总损失是两者之和 loss_D
  • 优化判别器:

    • 使用 optimizer_D.zero_grad() 清除先前的梯度,进行反向传播并更新判别器的参数。
(2)训练生成器(Generator)
  • 生成假图像:
    • 使用随机噪声 z 通过生成器生成一批假图像。
  • 生成器损失:
    • 生成器的目标是欺骗判别器,让它认为生成的假图像是真实的。因此,生成器的损失是判别器对这些假图像的预测(希望是 1)的损失,即 loss_G
  • 优化生成器:
    • 使用 optimizer_G.zero_grad() 清除先前的梯度,进行反向传播并更新生成器的参数。

3. 打印进度

每训练 200 个批次,代码会打印出当前 epoch 和 step 的进度,并显示判别器和生成器的损失:

Epoch [1/10], Step [0/938], D Loss: 0.6881, G Loss: 0.7014
Epoch [1/10], Step [200/938], D Loss: 0.6834, G Loss: 0.7102
...

实际运行效果:

执行结果

  1. 训练输出:

    • 在训练过程中,随着生成器和判别器的不断优化,你会看到输出的 D Loss(判别器损失)和 G Loss(生成器损失)。初始时,这两个损失通常较大,因为模型还没有学会如何生成和判断图像。
    • 随着训练的进行,损失会逐渐减小,表示生成器和判别器在相互博弈中逐渐变得更强。
  2. 图像生成:

    • 由于 GAN 的训练是一个对抗过程,因此每个 epoch 训练后,生成器的输出图像会逐渐接近真实图像的分布。
    • 生成器在训练中会变得越来越善于生成逼真的手写数字图像,直到它能够生成看起来很像 MNIST 数据集中的真实数字。

小结

  • 判别器:学习区分真实和假图像,给出图像是“真”还是“假”的概率。
  • 生成器:学习生成越来越像真实手写数字的图像,目的是“欺骗”判别器,使判别器认为生成的假图像是真实的。

执行完这段代码后,生成器(G)会经过 10 个 epoch 的训练,逐步学会生成类似 MNIST 手写数字的图像。你可以根据损失值的变化和生成的图像的质量,观察训练过程的进展。


阶段4:GAN生成的图像是啥样?

每训练一段时间,我们让生成器画个画,看看它有没有长进:

import matplotlib.pyplot as plt

def show_images(generator, num_images=16):
    z = torch.randn(num_images, 100)  # 随机噪声
    fake_imgs = generator(z).view(num_images, 1, 28, 28)  # 恢复图片形状
    fake_imgs = (fake_imgs + 1) / 2.0  # 把值范围从 [-1, 1] 变到 [0, 1]
    grid = torch.cat([fake_imgs[i] for i in range(num_images)], dim=2).squeeze(0)
    plt.imshow(grid.detach().numpy(), cmap='gray')
    plt.axis('off')  # 不显示坐标轴
    plt.savefig("generated_images.png", bbox_inches='tight')  # 保存图像到文件
    plt.close()  # 关闭图形窗口

 这是最终生成地图像:

局部放大

是不是可以联想到:生成式对抗网络的应用场景相当广泛,比如半导体晶圆缺陷检测领域,医学影像疾病识别领域等等。


阶段5:GAN训练的问题

GAN不是一帆风顺的,训练GAN像哄熊孩子:生成器和判别器常常互相欺负对方导致训练不稳定。
怎么办?我们可以尝试改进:

  1. 改网络结构:比如用更强大的卷积网络。
  2. 改损失函数:比如使用Wasserstein GAN。
  3. 调参:改动学习率、优化器等等。

这就是生成式对抗网络的基础啦,希望它的斗智斗勇能让你觉得有趣!你也可以试试用它生成其他类型的数据,比如音乐、画作或者文字!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/928444.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS开发中,如何高效定位并分析内存泄露相关问题

HarmonyOS开发中,如何高效定位并分析内存泄露相关问题 (1)Allocation的应用调试方式Memory泳道Native Allocation泳道 (2)Snapshot(3)ASan的应用使用约束配置参数使能ASan方式一方式二 启用ASanASan检测异常码 (4)HWASan的应用功能介绍约束条件使能HWASan方式一方式…

【Python】Selenium模拟在输入框里,一个字一个字地输入文字

我们平常在使用Selenium模拟键盘输入内容,常用的是用send_keys来在输入框上输入字: 基本的输入方式: input_element driver.find_element(By.ID, searchBox) input_element.send_keys("我也爱你") #给骚骚的自己发个骚话不过这种…

泷羽sec学习打卡-shell命令6

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都 与本人无关,切莫逾越法律红线,否则后果自负 关于shell的那些事儿-shell6 if条件判断for循环-1for循环-2实践是检验真理的唯一标准 if条件判断 创建…

【ArkTS】使用AVRecorder录制音频 --内附录音机开发详细代码

系列文章目录 【ArkTS】关于ForEach的第三个参数键值 【ArkTS】“一篇带你读懂ForEach和LazyForEach” 【小白拓展】 【ArkTS】“一篇带你掌握TaskPool与Worker两种多线程并发方案” 【ArkTS】 一篇带你掌握“语音转文字技术” --内附详细代码 【ArkTS】技能提高–“用户授权”…

数据分析案例-笔记本电脑价格数据可视化分析

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

系统监控——分布式链路追踪系统

摘要 本文深入探讨了分布式链路追踪系统的必要性与实施细节。随着软件架构的复杂化,传统的日志分析方法已不足以应对问题定位的需求。文章首先解释了链路追踪的基本概念,如Trace和Span,并讨论了其基本原理。接着,文章介绍了SkyWa…

游戏引擎学习第25天

Git: https://gitee.com/mrxiao_com/2d_game 今天的计划 总结和复述: 这段时间的工作已经接近尾声,虽然每次编程的时间只有一个小时,但每一天的进展都带来不少收获。尽管看起来似乎花费了很多时间,实际上这些日积月累的时间并未…

GaussDB TPOPS 搭建流程记录

目录 前言 环境准备 安装前准备 安装TPOPS 总结 前言 由于工作需要,准备将现有Oracle数据切换至GaussDB数据库。在这里记录一下安装GaussDB数据库过程踩的坑。 首先,我装的是线下版本,需要先装一个GaussDB轻量化管理平台(…

Web网页设计作业成品源码分享(持续更新)

🎉Web前端大作业专栏推荐 📚Web前端期末大作业源码分享 ✍️html网页设计、web前后端网站制作、大学生网页设计作业、个人网站制作、jQuery网站设计、uniapp小程序、vue网站设计、node.js网站设计、网页成品模板、期末大作业,各种设计应有尽有…

facebook欧洲户开户条件有哪些又有何优势?

在当今数字营销时代,Facebook广告已成为企业推广产品和服务的重要渠道。而为了更好地利用这一平台,广告主们需要理解不同类型的Facebook广告账户。Facebook广告账户根据其属性可分为多种类型,包括个人广告账户、企业管理(BM&#…

Qt 2D绘图之三:绘制文字、路径、图像、复合模式

参考文章链接: Qt 2D绘图之三:绘制文字、路径、图像、复合模式 绘制文字 除了绘制图形以外,还可以使用QPainter::darwText()函数来绘制文字,也可以使用QPainter::setFont()设置文字所使用的字体,使用QPainter::fontInfo()函数可以获取字体的信息,它返回QFontInfo类对象…

一种多功能调试工具设计方案开源

一种多功能调试工具设计方案开源 设计初衷设计方案具体实现HUB芯片采用沁恒微CH339W。TF卡功能网口功能SPI功能IIC功能JTAG功能下行USB接口 安路FPGA烧录器功能Xilinx FPGA烧录器功能Jlink OB功能串口功能RS232串口RS485和RS422串口自适应接口 CAN功能烧录器功能 目前进度后续计…

【C++】深入优化计算题目分析与实现

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 💯前言💯第一题:圆的计算我的代码实现代码分析改进建议改进代码 老师的代码实现代码分析可以改进的地方改进代码 💯第二题:对齐输出我的代码实现…

动手学深度学习10.5. 多头注意力-笔记练习(PyTorch)

本节课程地址:多头注意力代码_哔哩哔哩_bilibili 本节教材地址:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation 本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>multihead-attention.ipynb 多头注…

大R玩家流失预测在休闲社交游戏中的应用

摘要 预测玩家何时会离开游戏为延长玩家生命周期和增加收入贡献创造了独特的机会。玩家可以被激励留下来,战略性地与公司组合中的其他游戏交叉链接,或者作为最后的手段,通过游戏内广告传递给其他公司。本文重点预测休闲社交游戏中高价值玩家…

软件质量保证——单元测试之白盒技术

笔记内容及图片整理自XJTUSE “软件质量保证” 课程ppt,仅供学习交流使用,谢谢。 程序图 程序图定义 程序图P(V,E),V是节点的集合(节点是程序中的语句或语句片段),E是有向边的集合…

Node.js:开发和生产之间的区别

Node.js 中的开发和生产没有区别,即,你无需应用任何特定设置即可使 Node.js 在生产配置中工作。但是,npm 注册表中的一些库会识别使用 NODE_ENV 变量并将其默认为 development 设置。始终在设置了 NODE_ENVproduction 的情况下运行 Node.js。…

Zookeeper的通知机制是什么?

大家好,我是锋哥。今天分享关于【Zookeeper的通知机制是什么?】面试题。希望对大家有帮助; Zookeeper的通知机制是什么? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Zookeeper的通知机制主要通过Watcher实现,它是Zookeeper客…

Request method ‘POST‘ not supported(500)

前端路径检查 查看前端的请求路径地址、请求类型、方法名是否正确,结果没问题 后端服务检查 查看后端的传参uri、传参类型、方法名,结果没问题 nacos服务名检查 检查注册的服务是否对应(我这里是后端的服务名是‘ydlh-gatway’,服务列表走…

ESP32-S3模组上跑通ES8388(13)

接前一篇文章:ESP32-S3模组上跑通ES8388(12) 二、利用ESP-ADF操作ES8388 2. 详细解析 上一回解析了es8388_init函数中的第6段代码,本回继续往下解析。为了便于理解和回顾,再次贴出es8388_init函数源码,在…