人工智能应用-实验8-用生成对抗网络生成数字图像

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

以MNIST 数据集为训练数据,用生成对抗网络生成手写数字 5的图像(编程语言不限,如Python 等)。


🧡🧡代码🧡🧡

import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import time
import pandas

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False) # 批次为1,不打乱数据

# !nvidia-smi
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

#@title 模型
#返回size大小的均值为0,均方误差为1的随机数
def generate_random(size):
    random_data = torch.randn(size)
    return random_data

# def generate_random(size): # 均匀分布的随机数,会产生模式崩溃
#     random_data = torch.rand(size)
#     return random_data

#判别器
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            nn.Linear(784, 200), # 全连接层 784维特征(像素点) => 200维特征
            nn.LeakyReLU(0.02), # 激活层:f(x)=max(ax,x) a
            nn.LayerNorm(200), # 归一化层
            nn.Linear(200, 1), # 全连接层 200维特征(像素点) => 1维标量
            nn.Sigmoid() # 将1维标量缩放结果到0-1之间,以0.5作为二分类结果
        )

        self.loss_function = nn.BCELoss() # 定义损失函数
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001) # 创建优化器,使用Adam梯度下降
        # 计数器和损失记录
        self.counter = 0
        self.loss_list = []

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, inputs, targets):
        outputs = self.forward(inputs)  # 计算网络前向传播输出
        loss = self.loss_function(outputs, targets) # 计算损失值


        self.counter += 1
        if (self.counter % 10 == 0): # 每训练10次记录损失值
            self.loss_list.append(loss.item())
        if (self.counter % 10000 == 0): # 每训练10000次打印进程
            print("counter = ", self.counter)

        self.optimiser.zero_grad() #在反向传播前先把梯度归零
        loss.backward() #反向传播,计算各参数对于损失loss的梯度
        self.optimiser.step()  #根据反向传播得到的梯度,更新模型权重参数

    def plot_loss_process(self):
        df = pandas.DataFrame(self.loss_list, columns=['Discriminator Loss'])
        ax = df.plot(figsize=(12,6), alpha=0.1,
        marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        ax.set_title("Discriminator Loss")


# 生成器
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # 定义神经网络层
        self.model = nn.Sequential(
            nn.Linear(100, 200), # 全连接层 100维噪声 => 200维特征
            nn.LeakyReLU(0.02), # 激活函数
            nn.LayerNorm(200), # 标准化
            nn.Linear(200, 784), # 200维特征 => 784像素特征
            nn.Sigmoid() # 每个像素点缩放到0-1
        )
        # 创建生成器,使用Adam梯度下降
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
        # 计数器和损失记录
        self.counter = 0
        self.loss_list = []

    def forward(self, inputs):
        # 运行模型
        return self.model(inputs)

    def train(self, D, inputs, targets):

        g_output = self.forward(inputs) # 计算网络输出
        d_output = D.forward(g_output) # 输入判别器
        loss = D.loss_function(d_output, targets) # 计算损失值

        self.counter += 1
        if (self.counter % 10 == 0):  # 每训练10次记录损失值
            self.loss_list.append(loss.item())

        # 梯度归零,反向传播,并更新权重
        self.optimiser.zero_grad()
        loss.backward()

        #更新由self.optimiser而不是D.optimiser触发。这样一来,只有生成器的链接权重得到更新
        self.optimiser.step()

    def plot_loss_process(self):
        df = pandas.DataFrame(self.loss_list, columns=['Generator Loss'])
        ax = df.plot(figsize=(12,6), alpha=0.1,
        marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        ax.set_title("Generator Loss")

D = Discriminator()
G = Generator()
D = D.to(device)
G = G.to(device)

#@title train
epochs=1
start_time=time.time()
for epoch in range(epochs):
    print(f"=============Epoch={epoch}============")
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        image_data_tensor=images.view(-1)
        # ==使用真实数据训练判别器, 并标注真实数据为正样本(1)==
        D.train( image_data_tensor, torch.FloatTensor([1.0]).to(device) )

        # ==用生成数据(fake)训练判别器, 并标注生成数据为负样本(0)==
        # 同时使用detach()以避免计算生成器G中的梯度
        D.train( G.forward(generate_random(100).to(device)).detach(), torch.FloatTensor([0.0]).to(device) )

        # ==训练生成器, 让判别器对于生成器的生成数据评分尽可能接近正样本(1)==
        G.train( D, generate_random(100).to(device), torch.FloatTensor([1.0]).to(device) )
print(f"cost all time={(time.time()-start_time)/60} minutes")

# 保存模型
torch.save(D, 'GAN_Digits_D.pt')
torch.save(G, 'GAN_Digits_G.pt')
# 加载模型
D=torch.load('GAN_Digits_D.pt')
G=torch.load('GAN_Digits_G.pt')
G.plot_loss_process()
D.plot_loss_process()
# 生成效果图
f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
    for j in range(3):
        output = G.forward(generate_random(100).to(device))
        output = output.cpu()
        img = output.detach().numpy().reshape(28,28)
        axarr[i,j].imshow(img, interpolation='none', cmap='Blues')



🧡🧡分析结果🧡🧡

数据预处理:
加载数据集:
加载torch库中自带的minst数据集
转换数据:
转为tensor变量(相当于直接除255归一化到值域为(0,1))。
此处不同于CNN和BP网络实验,不再对其进行transforms.Normalize()处理,因为对抗网络中,生成器输入的是一个随机噪声向量,不是预处理后的图像;判别器中,输入的是真实图像和生成图像,而不是预处理后的图像,如果对输入数据进行归一化处理,会改变图像的数值范围,可能会影响判别器的判断结果。

构建对抗网络
构造判别器:
在这里插入图片描述

  • nn.Linear():全连接层,转换特征维度。
  • nn.LeakyReLU(0.02):激活层,激活函数如下,0.02即为negative_slope,用于控制负斜率的角度。相比于不具备负值响应(x<0,则y为0)的传统ReLU,LeakyReLU在负数区间表现的更加平滑,增强非线性表达能力,有助于判别器更好地区分真实样本和真实样本。
    在这里插入图片描述
  • nn.LayerNorm(200):对中间层的输出值进行标准化,让它们均值为0,避免较大值引起的梯度消失。200表示要标准化的维度数目。
  • nn.Sigmoid():将1维标量缩放结果到0-1之间,以0.5作为二分类结果。

构造生成器:
在这里插入图片描述

  • nn.Linear():全连接层,转换特征维度。这里设定输入的随机噪声维度为100,最后输出一张784像素图片。
  • nn.LeakyReLU、nn.LayerNorm、nn.Sigmoid作用同上述类似

选取损失函数:
对于分类问题,损失函数使用二元交叉熵BCELoss()往往比均方误差MSELoss()效果更好。因为它能对正确分类进行奖励,而对错误分类进行惩罚。
由于生成器无需定义损失函数,所以我们只需要修改鉴别器的损失函数即可:

训练和评估
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
每10张图记录1次loss,1次epoch训练60000张图,则1次epoch记录6000次loss,6次epoch记录36000次loss。而1次epoch训练1次生成器,训练2次判别器(1次正样本判别、1次负样本判别),所以生成器loss迭代变化横坐标为36000次,判别器loss迭代变化横坐标为72000次。
在这里插入图片描述
loss迭代变化如下图。
在这里插入图片描述
在这里插入图片描述
从图中整体来看,一开始生成器loss较高,判别器接近0,后面生成器和判别器loss逐渐分布均匀(方差减少,数值大小越来越集中)。

分析生成对抗网络中生成器和判别器的关系
实验中,判别器的loss定义为:区分真实图像和假图像的能力,即loss越小,区分能力越强
而生成器虽然没有直接定义loss,但是利用了判别器的loss,使得判别器对生成器生成的假图像的评分尽可能接近正样本,也即loss越小,生成器生成的假数据越来越接近真实图像。
上述loss的记录迭代次数太多,可能不够直观观察判别器和生成器的相对变化,计算每次epoch的平均loss如下图:
在这里插入图片描述
可以看到,刚开始生成器与判别器的博弈中处于下风,随着训练进行,生成器的loss大幅减少,说明生成器生成的图像越来越逼真,反观判别器loss增大,说明判别器开始处于下风。最后,可以看到两者的loss都趋于平稳,说明此时渐渐达到了博弈平衡,从直观的图像清晰度也能看到,对比训练初期,图像5相比最开始变得比较清晰,但当迭代一定训练次数后,清晰度似乎不再变化了。


🧡🧡实验总结🧡🧡

理论理解:
GAN的核心思想:生成器G和判别器D的一代代博弈

  • 生成器:
    生成网络,通过输入生成图像
  • 判别器:
    二分类网络,将生成器生成图像作为负样本,真实图像作为正样本
  • 优化 判别器D:
    给定G,通过G生成图像产生负样本,并结合真实图像作为正样本来训练D
  • 优化 生成器G:
    给定D,以使得D对G生成图像的评分尽可能接近正样本作为目标来训练G

G和D的训练过程交替进行,这个对抗的过程使得G生成的图像越来越逼真,D辨别的能力也越来越强。

代码实操:

  • 模式崩溃:
    在生成器生成随机数时,若生成的方法不对,可能会导致模式崩溃问题,它指的是生成器倾向于生成相似或重复的样本,而不是多样化的输出(如下图)。
    在这里插入图片描述
    在python中,torch.rand()产生的是0-1之间均匀分布的随机数,很容易导致模式崩溃,因为均匀分布的随机数无法提供足够的多样性,从而使得生成器可能会生成类似的样本。为了解决这个问题,使用torch,randn()函数从高斯分布中抽取随机数,从而增大生成器的多样性。
  • 判断对抗网络模型的收敛情况
    一方面生成器和判别器的损失函数值来监控两者的优化过程,它们的相对变化可以一定程度反映它们的博弈情况,当它们的loss的变化都慢慢趋于平稳时,可以认为模型达到收敛。当然,另一方面,通过观察图像清晰度也是比较直观的方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/639613.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

spark的简单学习一

一 RDD 1.1 RDD的概述 1.RDD&#xff08;Resilient Distributed Dataset&#xff0c;弹性分布式数据集&#xff09;是Apache Spark中的一个核心概念。它是Spark中用于表示不可变、可分区、里面的元素可并行计算的集合。RDD提供了一种高度受限的共享内存模型&#xff0c;即RD…

想学接口测试,不知道那个工具适合?

引言&#xff1a; 接口测试在软件开发中扮演着至关重要的角色&#xff0c;它可以帮助我们验证系统的功能、性能和安全性。而选择适合的工具是进行接口测试的重要一步。本文将从零开始&#xff0c;为你详细介绍如何选择合适的工具&#xff0c;并提供规范的指导。 一、了解接口…

【大数据】MapReduce实战

文章目录 [toc]Word CountMapperReducerrun.sh本地调试 基于白名单的Word CountMapperReducerrun.sh本地调试 文件分发-fileMapperReducerrun.sh -cacheFileMapperReducerrun.sh -cacheArchiveMapperReducerrun.sh 杀死MapReduce Job排序压缩文件mr_ip_lib_python本地调试 个人…

PE文件(六)新增节-添加代码作业

一.手动新增节添加代码 1.当预备条件都满足&#xff0c;节表结尾没有相关数据时&#xff1a; 现在我们将ipmsg.exe用winhex打开&#xff0c;在节的最后新增一个节用于存放我们要增加的数据 注意&#xff1a;飞鸽的文件对齐和内存对齐是一致的 先判断节表末尾到第一个节之间…

《书生·浦语大模型实战营》第一课 学习笔记:书生·浦语大模型全链路开源体系

文章大纲 1. 简介与背景智能聊天机器人与大语言模型目前的开源智能聊天机器人与云上运行模式 2. InternLM2 大模型 简介3. 视频笔记&#xff1a;书生浦语大模型全链路开源体系内容要点从模型到应用典型流程全链路开源体系 4. 论文笔记:InternLM2 Technical Report简介软硬件基础…

Flat Ads获广东电视台报道!CEO林啸:助力更多企业实现业务全球化增长

近日,在广州举行的第四届全球产品与增长展会(PAGC2024)上,Flat Ads凭借其卓越的一站式全球化营销和创新的变现方案大放异彩,不仅吸引了众多业界目光,同时也在展会上斩获了备受瞩目的“金帆奖”,展现了其在全球化营销推广领域的卓越实力和专业服务。 在大会现场,Flat Ads的CEO林…

fyne网格包裹布局

fyne网格包裹布局 与之前的网格布局一样&#xff0c;网格环绕布局以网格模式创建元素排列。但是&#xff0c;此网格没有固定数量的列&#xff0c;而是为每个单元格使用固定大小&#xff0c;然后将内容流到显示项目所需的行数。 layout.NewGridWrapLayout(size) 您可以使用其中…

如何官方查询论文分区,中科院及JCR

中科院分区 有一个小程序&#xff1a;中科院文献情报中心分区表 点2023升级版&#xff0c;输入期刊名 大类1区 JCR分区 进入官方网站 Journal Citation Reports 输入要查询的期刊名&#xff0c;点开 拼命往下拉 这就是根据影响因子的排名&#xff0c;在computer science&am…

Dijkstra算法求最短路径 c++

目录 【问题背景】 【相关知识】 【算法思想】 【算法实现】 【伪代码】 【输入输出】 【代码】 【问题背景】 出门旅游&#xff0c;有些城市之间有公路&#xff0c;有些城市之间则没有&#xff0c;如下图。为了节省经费以及方便计划旅程&#xff0c;希望在出发之前知道…

【iceberg数据一致性】iceberg如何保证高并发数据一致性

在使用iceberg写数据时&#xff0c;一直弄不清楚为什么iceberg写入快&#xff0c;并且能够保证数据的一致性。今天决定搞清楚这个问题&#xff0c;经过查询和理解&#xff0c;写下来。 文件格式 iceberg元数据的文件目前有三个&#xff1a;metadata.json&#xff0c;snap.avro…

MyBatis实用方案,如何使项目兼容多种数据库

系列文章目录 MyBatis缓存原理 Mybatis plugin 的使用及原理 MyBatisSpringboot 启动到SQL执行全流程 数据库操作不再困难&#xff0c;MyBatis动态Sql标签解析 Mybatis的CachingExecutor与二级缓存 使用MybatisPlus还是MyBaits &#xff0c;开发者应该如何选择&#xff1f; 巧…

SVN创建项目分支

目录 背景调整目录结构常规目录结构当前现状目标 调整SVN目录调整目录结构创建项目分支 效果展示 背景 当前自己本地做项目的时候发现对SVN创建项目不规范&#xff0c;没有什么目录结构&#xff0c;趁着创建目录分支的契机&#xff0c;顺便调整下SVN服务器上的目录结构 调整目…

Day36 代码随想录打卡|二叉树篇---翻转二叉树

题目&#xff08;leecode T226&#xff09;&#xff1a; 给你一棵二叉树的根节点 root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 方法&#xff1a; 迭代法 翻转二叉树&#xff0c;即从根节点开始&#xff0c;一一交换每个节点的左右孩子节点&#xff0c;然后…

【Arthas】阿里的线上jvm监控诊断工具的基本使用

关于对运行中的项目做java监测的需求下&#xff0c;Arthas则是一个很好的解决方案。 我们可以用来 1.监控cpu 现成、内存、堆栈 2.排查cpu飚高 造成原因 3.接口没反应 是否死锁 4.接口慢优化 5.代码未按预期执行 是分支不对 还是没提交&#xff1f; 6.线上低级错误 能不能不重启…

伦敦金交易商压箱底的交易技法 居然是……

很多伦敦金交易商&#xff0c;也就是我们常说的伦敦金交易平台&#xff0c;或者伦敦金交易服务提供商&#xff0c;他们会和一些资深的市场分析师合作。另外&#xff0c;一般在这些伦敦金交易商内部&#xff0c;也会有一批高手&#xff0c;他们一边在交易&#xff0c;一边在平台…

【设计模式深度剖析】【3】【创建型】【抽象工厂模式】| 要和【工厂方法模式】对比加深理解

&#x1f448;️上一篇:工厂方法模式 | 下一篇:建造者模式&#x1f449;️ 目录 抽象工厂模式前言概览定义英文原话直译什么意思呢&#xff1f;&#xff08;以运动型车族工厂&#xff0c;生产汽车、摩托产品为例&#xff09; 类图4个角色抽象工厂&#xff08;Abstract Fac…

起底震网病毒的来龙去脉

2010年&#xff0c;震网病毒被发现&#xff0c;引起世界哗然&#xff0c;在后续的10年间&#xff0c;陆陆续续有更多关于该病毒的背景和细节曝光。今年&#xff0c;《以色列时报》和《荷兰日报》又披露了关于此事件的更多信息&#xff0c;基于这些信息&#xff0c;我们重新梳理…

使用 Docker 部署 Jenkins 并设置初始管理员密码

使用 Docker 部署 Jenkins 并设置初始管理员密码 每一次开始&#xff0c;我都特别的认真与胆怯&#xff0c;是因为我期待结局&#xff0c;也能够不会那么粗糙&#xff0c;不会让我失望&#xff0c;所以&#xff0c;就多了些思考&#xff0c;多了些拘束&#xff0c;所以&#xf…

软件测试:功能测试-接口测试-自动化测试-性能测试-验收测试

软件测试的主要流程 一、测试主要的四个阶段 1.测试计划设计阶段&#xff1a;产品立项之后&#xff0c;进行需求分析&#xff0c;需求评审&#xff0c;业务需求评级&#xff0c;绘制业务流程图。确定测试负责人&#xff0c;开始制定测试计划&#xff1b; 2.测试准备阶段&…

不小心丢失mfc140u.dll文件怎么办?mfc140u.dll丢失的解决办法

当您发现mfc140u.dll文件不见了或者受损&#xff0c;别担心&#xff0c;我们可以一起解决这个问题&#xff01;首先&#xff0c;您可能会注意到一个小提示&#xff0c;当您尝试打开某些程序时&#xff0c;屏幕上会跳出一个消息说“找不到mfc140u.dll”或者“mfc140u.dll文件缺失…