昇思25天学习打卡营第22天|GAN图像生成

今天是参加昇思25天学习打卡营的第22天,今天打卡的课程是“GAN图像生成”,这里做一个简单的分享。

1.简介

今天来学习“GAN图像生成”,这是一个基础的生成式模型。

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

  • 生成器的任务是生成看起来像训练图像的“假”图像;
  • 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

2.模型架构

  • 模型原理

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 𝐺 和估计样本是否来自训练数据的判别模型 𝐷。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

用 𝑥 代表图像数据,用 𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥) 需要处理作为二进制文件的大小为 1×28×28的图像数据。当 𝑥 来自训练数据时,𝐷(𝑥) 数值应该趋近于 1 ;而当 𝑥 来自生成器时,𝐷(𝑥)𝐷数值应该趋近于 00 。因此 𝐷(𝑥) 也可以被认为是传统的二分类器。

用 𝑧 代表标准正态分布中提取出的隐码(隐向量),用 𝐺(𝑧):表示将隐码(隐向量) 𝑧 映射到数据空间的生成器函数。函数 𝐺(𝑧) 的目标是将服从高斯分布的随机噪声 𝑧 通过生成网络变换为近似于真实分布 𝑝𝑑𝑎𝑡𝑎(𝑥) 的数据分布,我们希望找到 θ 使得 𝑝𝐺(𝑥;𝜃) 和𝑝𝑑𝑎𝑡𝑎(𝑥) 尽可能的接近,其中𝜃 代表网络参数。

𝐷(𝐺(𝑧))表示生成器 𝐺𝐺生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数 log⁡𝐷(𝑥);而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数log⁡(1−𝐷(𝐺(𝑧)))。因此GAN的损失函数为:
在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是𝑝𝐺(𝑥;𝜃)=𝑝𝑑𝑎𝑡𝑎(𝑥),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。
  • 核心代码

生成器代码:

from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器代码:

# 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')
  • 损失函数和优化器
lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

3.小结

今天学了GAN用于图像生成的基本理论和编码方法。GAN模型由生成器(Generative Model)和判别器(Discriminative Model)构成两个相互对抗的模型。生成器负责生成图像进行,判别器用于判定图像真假,通过对抗的模式使得真假判定的结果接近1:1,进而完成训练。这样训练好的生成器即可用于图形生成。这里面着重要掌握对抗网络损失函数的意义,这是的对抗网络能够输出最正确结果的要点。

以上是第22天的学习内容,附上今日打卡记录:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/798876.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Django+Vue3 线上教育平台项目实战】构建课程详情页与集成视频播放功能

文章目录 前言一、课程列表页面a.后端代码b.前端代码 二、课程详情页面a. 视频播放功能的集成1.获取上传视频的链接地址2.集成在前端页面中1>使用vue-alipayer视频播放组件2>使用video标签 b. 页面主要内容展示1.后端代码1>分析表2>核心逻辑 2.前端代码3.效果图 前…

Java中的Filter流:理解与应用

Java中的Filter流:理解与应用 1、字节Filter流1.1 FilterInputStream1.2 FilterOutputStream 2、字符Filter流2.1 FilterReader2.2 FilterWriter 3、使用Filter流的好处 💖The Begin💖点点关注,收藏不迷路💖 在Java的…

算法思想总结:字符串

一、最长公共前缀 . - 力扣&#xff08;LeetCode&#xff09; 思路1&#xff1a;两两比较 时间复杂度mn 实现findcomon返回两两比较后的公共前缀 class Solution { public:string longestCommonPrefix(vector<string>& strs) {//两两比较 string retstrs[0];size…

信通院全景图发布 比瓴科技领跑软件供应链安全,多领域覆盖数字安全服务

近日&#xff0c;中国信息通信研究院在2024全球数字经济大会—数字安全生态建设专题论坛正式发布首期《数字安全护航技术能力全景图》&#xff08;以下简称全景图&#xff09;。 比瓴科技入选软件供应链安全赛道“开发流程安全管控、交互式安全测试、静态安全测试、软件成分分…

【编程概念】生命周期

在进行系统编程时&#xff0c;经常遇到对象的生命周期这一概念。我理解的对象生命周期周期&#xff0c;就是一个对象从创建到销毁的所有状态&#xff0c;对象在不同的状态下会有不同的行为。 生命周期的概念&#xff0c;通常给出现在需要长时间运行的软件中&#xff0c;脚本工…

MySQL里的累计求和

在MySQL中&#xff0c;你可以使用SUM()函数来进行累计求和。如果你想要对一个列进行累计求和&#xff0c;可以使用OVER()子句与ORDER BY子句结合&#xff0c;进行窗口函数的操作。 以下是一个简单的例子&#xff0c;假设我们有一个名为sales的表&#xff0c;它有两个列&#x…

synwit其它应用

一、关于变量数组数量定义 在应用中&#xff0c;定义数组变量或其它变量时&#xff0c;需要注意不要超出MCU的RAM大小。如&#xff1a; 有客户应用SWM320系列是&#xff0c;定义了“uint32_t TEST_Buffer1[65536];”临时变量&#xff0c;编译过程中不会出错&#xff0c;但实际应…

敏捷的两种方式:Kanban和 Scrum

敏捷方法通过提供灵活、迭代的项目管理方法&#xff0c;改变了软件开发。敏捷方法中最著名的框架是 Kanban 和 Scrum。虽然这两种方法都旨在提高生产力和效率&#xff0c;但它们的运作原则和实践却截然不同。 在本文中&#xff0c;我们将深入探讨 Kanban 和 Scrum 的起源、主要…

[NSSRound#4 SWPU]1zweb

非预期解&#xff1a; 输入/flag&#xff0c;点击查看 预期解&#xff1a; upload.php <?php if ($_FILES["file"]["error"] > 0){echo "上传异常"; } else{$allowedExts array("gif", "jpeg", "jpg"…

STFT:解决音频-视频零样本学习 (ZSL) 中的挑战

传统的监督学习方法需要大量的标记训练实例来进行训练,视听零样本学习的任务是利用音频和视频模态对对象或场景进行分类&#xff0c;即使在没有可用标记数据的情况下。为了解决传统监督方法的限制&#xff0c;提出了广义零样本学习&#xff08;Generalized Zero-Shot Learning,…

掌握微信自动化操作,从此高效办公,效率直线上升!

你是不是每次回复客户消息&#xff0c;都要复制话术再粘贴发给不同的客户&#xff1f;每次统计微信数据都要手动统计很费时间&#xff1f; 试试这个多微管理神器&#xff0c;让你可以实现微信自动化操作&#xff0c;效率直线上升&#xff01; 1、自动通过好友并打招呼 系统可…

【Java】Idea运行JDK1.8,Build时中文内容GBK UTF-8编码报错一堆方块码

问题描述 在Windows系统本地运行一个JDK1.8的项目时&#xff0c;包管理用的Gradle&#xff0c;一就编码报错&#xff08;所有的中文内容&#xff0c;包括中文注释、中文的String字面量&#xff09;&#xff0c;但程序还是正常运行。具体如下&#xff1a; 解决 1. Idea更改编…

Java学习 - Spring 讲解

前言 为了解决我们开发者在 J2EE 开发时所遇到的众多问题&#xff0c;Rob Johnson 等人发起了 Spring 框架项目。Spring 是一个开源的 J2EE 应用程序框架&#xff0c;是针对 Bean 的生命周期进行管理的轻量级容器。它既可以单独用于构建程序&#xff0c;也能和当前众多的 Web …

《昇思25天学习打卡营第18天|基于MobileNetv2的垃圾分类》

MobileNetV2是一种轻量级的深度神经网络&#xff0c;设计用于移动和嵌入式设备。它的核心思想是通过深度可分离卷积&#xff08;Depthwise Separable Convolutions&#xff09;和倒残差结构&#xff08;Inverted Residuals&#xff09;来减少计算复杂度和模型参数量。其主要特点…

什么叫图像的双边滤波,并附利用OpenCV和MATLB实现双边滤波的代码

双边滤波&#xff08;Bilateral Filtering&#xff09;是一种在图像处理中常用的非线性滤波技术&#xff0c;主要用于去噪和保边。它在空间域和像素值域上同时进行加权&#xff0c;既考虑了像素之间的空间距离&#xff0c;也考虑了像素值之间的相似度&#xff0c;从而能够有效地…

WPF学习(4) -- 数据模板

一、DataTemplate 在WPF&#xff08;Windows Presentation Foundation&#xff09;中&#xff0c;DataTemplate 用于定义数据的可视化呈现方式。它允许你自定义如何展示数据对象&#xff0c;从而实现更灵活和丰富的用户界面。DataTemplate 通常用于控件&#xff08;如ListBox、…

[GXYCTF2019]BabySQli

原题目描述&#xff1a;刚学完sqli&#xff0c;我才知道万能口令这么危险&#xff0c;还好我进行了防护&#xff0c;还用md5哈希了密码&#xff01; 我看到是个黑盒先想着搞一份源码 我dirsearch明明扫到了.git&#xff0c;算了直接注入试试看 随便输入了两个东西&#xff0c…

赛氪网荣获2024年中国高校计算机教育大会合作伙伴荣誉

2024年7月13日&#xff0c;在黑龙江哈尔滨召开的“2024年中国高校计算机教育大会&#xff08;CCEC2024&#xff09;”&#xff0c;环球赛乐&#xff08;北京&#xff09;科技有限公司(以下简称”赛氪网“)凭借其在高等教育与科技创新领域的卓越贡献&#xff0c;荣幸地获得了本次…

安卓onNewIntent 什么时候执行

一.详细介绍 onNewIntent 方法 onNewIntent 是 Android 中 Activity 生命周期的一部分。它在特定情况下被调用&#xff0c;主要用于处理新的 Intent&#xff0c;而不是创建新的 Activity 实例。详细介绍如下&#xff1a; 使用场景 singleTop 启动模式&#xff1a; 如果一个 Ac…

6.S081的Lab学习——Lab11: Network

文章目录 前言Network提示&#xff1a;实现e1000_transmit的一些提示&#xff1a;实现e1000_recv的一些提示&#xff1a; 解析 总结 前言 一个本硕双非的小菜鸡&#xff0c;备战24年秋招。打算尝试6.S081&#xff0c;将它的Lab逐一实现&#xff0c;并记录期间心酸历程。 代码下…