GAN原理 代码解读

模型架构

在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch

# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)
    transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transform,
                               download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''
    输入:正态分布随机数噪声(长度为100)
    输出:生成的图片,(1,28,28)
    中间过程:
        linear1: 100 -> 256
        linear2: 256 -> 512
        linear3: 512 -> 28*28
        reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数
        self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),
                                   nn.Linear(256,512),nn.ReLU(),
                                    # 最后一层用tanh激活,将数据压缩到-1到1
                                   nn.Linear(512,28*28),nn.Tanh())
    def forward(self,x):
        img = self.model(x)
        img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)
        return img

定义判别器

'''
    判别器
    输入:(1,28,28)的图片
    输出:二分类的概率值 用sigmoid压缩到0-1之间
    内容:
    判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28,512),nn.LeakyReLU(),
            nn.Linear(512,256),nn.LeakyReLU(),
            nn.Linear(256,1),nn.Sigmoid(),
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.model(x)
        return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

绘图函数

def gen_img_plot(model,epoch,test_input):
    prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpy
    prediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

定义训练函数


def train(num_epoch,test_input):
    D_loss = []
    G_loss = []
    # 训练循环
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader) # 返回批次数
        for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w)
            img = img.to(device)
            size = img.size(0) # 得到一个批次的图片
            random_noise = torch.randn(size,100,device=device) # 生成器的输入

            '''一. 训练判别器'''
            '''用真实图片训练判别器'''
            dis_optim.zero_grad()
            real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果
            # 判别器在真实图像上的损失
            d_real_loss = bce_loss(real_output,
                                   # torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。
                                   torch.ones_like(real_output))
            d_real_loss.backward()

            '''用生成的图片训练判别器'''
            gen_img = gen(random_noise)
            # 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensor
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output,
                                   torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss+d_fake_loss
            dis_optim.step() # 对参数进行优化

            '''二.训练生成器'''
            gen_optim.zero_grad()
            # 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别
            fake_output = dis(gen_img)
            # 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。
            # 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。
            g_loss = bce_loss(fake_output,
                              torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            # 计算一个epoch的损失
            with torch.no_grad(): #  禁止梯度计算和参数更新
                d_epoch_loss +=d_loss
                g_epoch_loss +=g_loss
        # 计算整体loss每个epoch的平均Loss
        with torch.no_grad(): #  禁止梯度计算和参数更新
            d_epoch_loss /= count
            g_epoch_loss /= count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)
            print('Epoch:', epoch)
            print(f'd_epoch_loss={d_epoch_loss}')
            print(f'g_epoch_loss={g_epoch_loss}')
            # 将16个长度为100的噪音输入到生成器并画图
            gen_img_plot(gen,epoch,test_input)

开始训练

'''开始计时'''
start_time = time.time()

'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
num_epoch = 50
train(num_epoch,test_input)

'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/88825.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】动态内存管理(malloc,free,calloc,realloc)-- 详解

一、动态内存分配 定义&#xff1a;动态内存分配 (Dynamic Memory Allocation) 就是指在程序执行的过程中&#xff0c;动态地分配或者回收存储空间的分配内存的方法。动态内存分配不像数组等静态内存分配方法那样&#xff0c;需要预先分配存储空间&#xff0c;而是由系统根据程…

十、pikachu之php反序列化

文章目录 1、php反序列化概述2、实战3、关于Magic function4、__wakeup()和destruct() 1、php反序列化概述 在理解这个漏洞前&#xff0c;首先搞清楚php中serialize()&#xff0c;unserialize()这两个函数。 &#xff08;1&#xff09;序列化serialize()&#xff1a;就是把一个…

基于Jenkins+Git+Ansible 发布PHP 项目-------从小白到大神之路之学习运维第88天

第四阶段提升 时 间&#xff1a;2023年8月25日 参加人&#xff1a;全班人员 内 容&#xff1a; 基于JenkinsGitAnsible 发布PHP 项目 目录 基于JenkinsGitAnsible 发布PHP 项目 一、部署PHP 运行环境 二、主机环境配置 三、Tomcat主机操作&#xff1a; 四、Jenkins主…

【golang】panic函数、recover函数以及defer语句

从panic被引发到程序终止运行的大致过程是什么&#xff1f; 大致过程&#xff1a; 某个函数中的某行代码有意无意地引发了一个panic。这时&#xff0c;初始的panic详情会被建立起来&#xff0c;并且该程序的控制权会立即从从行代码转移至调用其所属函数的那行代码上&#xff…

CentOS系统环境搭建(十七)——elasticsearch设置密码

centos系统环境搭建专栏&#x1f517;点击跳转 elasticsearch设置密码 没有密码是很不安全的一件事&#x1f62d; 文章目录 elasticsearch设置密码1.设置密码2.登录elasticsearch3.登录kibana4.登录elasticsearch-head 1.设置密码 关于Elasticsearch的安装请看CentOS系统环境搭…

SpringBootWeb案例 Part 4

3. 修改员工 需求&#xff1a;修改员工信息 在进行修改员工信息的时候&#xff0c;我们首先先要根据员工的ID查询员工的信息用于页面回显展示&#xff0c;然后用户修改员工数据之后&#xff0c;点击保存按钮&#xff0c;就可以将修改的数据提交到服务端&#xff0c;保存到数据…

【校招VIP】产品思维分析之面试新的功能点设计

考点介绍&#xff1a; 这种题型是面试里出现频度最高&#xff0c;也是难度最大的一种&#xff0c;需要面试者对产品本身的功能、扩展性以及行业都有一定的了解。而且分析时间较短&#xff0c;需要一定的产品能力和回答技巧。 『产品思维分析之面试新的功能点设计』相关题目及解…

java+springboot+vue儿童慈善捐赠管理系统的设计与实现8n9e4

针对用户需求开发与设计&#xff0c;该技术尤其在各行业领域发挥了巨大的作用&#xff0c;有效地促进了“爱相连”儿童慈善管理的发展。然而&#xff0c;由于用户量和需求量的增加&#xff0c;信息过载等问题暴露出来&#xff0c;为改善传统线下管理中的不足&#xff0c;本文将…

Docker搭建LNMP----(超详细)

目录 ​编辑 一、项目环境 1.1 所有安装包下载&#xff1a; 1.3 服务器环境 1.4任务需求 二、Ngin 2.1、建立工作目录 2.2 编写 Dockerfile 脚本 2.3准备 nginx.conf 配置文件 2.4生成镜像 2.5创建自定义网络 2.6启动镜像容器 2.7验证 nginx、 三、Mysql 3.1建立…

传智教育广州校区又又又举行校内招聘会,多名学员被广东民生在线教育招入麾下

数字经济的高速发展以及经济形势的逐渐回暖&#xff0c;带动了企业对数字人才的用人需求增加&#xff0c;近日&#xff0c;传智教育旗下高端IT教育品牌黑马程序员多个校区接到了企业上门招聘的需求&#xff0c;各分校区通过举行校内招聘会&#xff0c;为用人企业和学员搭建了人…

一文速学-让神经网络不再神秘,一天速学神经网络基础-激活函数(二)

前言 思索了很久到底要不要出深度学习内容&#xff0c;毕竟在数学建模专栏里边的机器学习内容还有一大半算法没有更新&#xff0c;很多坑都没有填满&#xff0c;而且现在深度学习的文章和学习课程都十分的多&#xff0c;我考虑了很久决定还是得出神经网络系列文章&#xff0c;…

实验二 tftp 服务器环境搭建

tftp 服务器环境搭建 tftp&#xff08;Trivial File Transfer Protocol&#xff09;即简单文件传输协议是TCP/IP协议族中的一个用来在客户机与服务器之间进行简单文件传输的协议&#xff0c;提供不复杂、开销不大的文件传输服务。端口号为69 【实验目的】 掌握 tftp 环境搭…

【目标检测】“复制-粘贴 copy-paste” 数据增强实现

文章目录 前言1. 效果展示代码说明3. 参考文档4. 不合适点 前言 本文来源论文《Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation》&#xff08;CVPR2020&#xff09;&#xff0c;对其数据增强方式进行实现。 论文地址&#xff1a;https:/…

MediaPlayer音频与视频的播放介绍

作者&#xff1a;向阳逐梦 Android多媒体中的——MediaPlayer&#xff0c;我们可以通过这个API来播放音频和视频该类是Androd多媒体框架中的一个重要组件&#xff0c;通过该类&#xff0c;我们可以以最小的步骤来获取&#xff0c;解码和播放音视频。 它支持三种不同的媒体来源…

Talk | 上海交通大学官同坤:识别任意文本,隐式注意力与字符间蒸馏在文本识别中的应用

本期为TechBeat人工智能社区第525期线上Talk&#xff01; 北京时间8月23日(周三)20:00&#xff0c;上海交通大学博士生—官同坤的Talk已准时在TechBeat人工智能社区开播&#xff01; 他与大家分享的主题是: “隐式注意力与字符间蒸馏在文本识别中的应用”&#xff0c;分享了识别…

使用docker-maven-plugin插件构建镜像并推送至私服Harbor

前言 如下所示&#xff0c;建议使用 Dockerfile Maven 插件&#xff0c;但该插件也停止维护更新了。因此先暂时使用docker-maven-plugin插件。 一、开启Docker服务器的远程访问 1.1 开启2375远程访问 默认的dokcer是不支持远程访问的&#xff0c;需要加点配置&#xff0c;开…

bh002- Blazor hybrid / Maui 保存设置快速教程

1. 建立工程 bh002_ORM 源码 2. 添加 nuget 包 <PackageReference Include"BootstrapBlazor.WebAPI" Version"7.*" /> <PackageReference Include"FreeSql" Version"*" /> <PackageReference Include"FreeSql.…

MyBatis分页插件PageHelper的使用及特殊字符的处理

目录 一、PageHelper简介 1.什么是分页 2.PageHelper是什么 3.使用PageHelper的优点 二、PageHelper插件的使用 原生limit查询 1. 导入pom依赖 2. Mybatis.cfg.xml 配置拦截器 3. 使用PageHelper进行分页 三、特殊字符的处理 1.SQL注入&#xff1a; 2.XML转义&#…

【Linux】【驱动】第一个相对完整的驱动编写

【Linux】【驱动】第一个相对完整的驱动编写 续1.驱动部分的代码2 app 代码3 操作相关的代码 续 这个章节会讲述去直接控制一个GPIO&#xff0c;高低电平。 因为linux不允许直接去操作寄存器&#xff0c;所以在操作寄存器的时候就需要使用到函数&#xff1a;ioremap 和iounma…

线性代数的学习和整理10:各种特殊类型的矩阵(草稿-----未完成 建设ing)

目录 1 图形化分类 1.1对称矩阵 1.2 梯形矩阵 1.3 三角矩阵 1.3.1 上三角矩阵 1.4 对角线矩阵 2 按各自功能分 2.1 等价矩阵 2.2 增广矩阵 2.3 伴随矩阵 2.4 正交矩阵 2.5 正交矩阵 2.6 相似矩阵 1 图形化分类 1.1对称矩阵 1.2 梯形矩阵 1.3 三角矩阵 1.3.1 上…