Gan论文阅读笔记

GAN论文阅读笔记

2014年老论文了,主要记录一些重要的东西。论文链接如下:

Generative Adversarial Nets (neurips.cc)

文章目录

  • GAN论文阅读笔记
    • 出发点
    • 创新点
    • 设计
    • 训练代码
    • 网络结构代码
    • 测试代码

出发点

Deep generative models have had less of an impact, due to the difficulty of approximating many intractable probabilistic computations that arise in maximum likelihood estimation and related strategies, and due to difficulty of leveraging the benefits of piecewise linear units in the generative context.

​ 当时的生成模型效果不佳在于近似许多棘手的概率计算十分困难,如最大似然估计等。除此之外,把利用分段线性单元运用到生成场景中也有困难。于是作者提出新的生成模型:GAN。

​ 我的理解是,当时的生成模型都是去学习模型生成数据的分布,比如确定方差,确定均值之类的参数,然而这种方法十分难以学习,而且计算量大而复杂,作者考虑到这一点,对生成模型采用端到端的学习策略,不去学习生成数据的分布,而是直接学习模型,只要这个模型的生成结果能够逼近Ground-Truth,那么就可以直接用这个模型代替分布去生成数据。这是典型的黑箱思想。

创新点

adiscriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles.

创新点1:提出对抗学习策略:提出两个model之间相互对抗,相互抑制的策略。一个model名为生成器Generator,一个model名为判别器Discriminator,生成器尽可能生成接近真实的数据,判别器尽可能识别出生成器数据是Fake。

In this article, we explore the special case when the generative model generates samples by passing random noise through a multilayer perceptron, and the discriminative model is also a multilayer perceptron.

创新点2:当两个model都使用神经网络时,可以运用反向传播和Dropout等算法进行学习,这样就可以避免使用马尔科夫链。

设计

To learn the generator’s distribution pgover data x, we define a prior on input noise variables pz(z), then represent a mapping to data space as G(z; θg), where G is a differentiable function represented by a multilayer perceptron with parameters θg. We also define a second multilayer perceptron D(x; θd) that outputs a single scalar. D(x) represents the probability that x came from the data rather than pg.

1.输入:为了让生成器G生成的数据分布pg与真实数据分布x接近,策略是给G输入一个噪音变量z,然后学习参数θg,这个θg是G网络权重。因此,G可以被写作:G(z;θg)。
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \underset{G}{min}\underset{D}{max}V(D, G) =\mathbb{E}_{x \sim p_{data}(x)}\left[ logD(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[log(1 - D(G(z)))\right] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
2.对抗性损失函数:从代码可知,对抗性损失是两个BCELoss的和,V尽可能使D(x)更大,在此基础上尽可能使G(z)更小。这是有先后顺序的,在后面会做说明。

在代码中可知,先人为生成两个标签,第一个标签是用torch.ones生成的全为1的矩阵,形状为(batch,1)。其中batch是输入噪声的batch,第二维度只是一个数字——1,这个标签用于判别器D的BCELoss中,代入BCELoss即可得到上面对抗性损失中左侧的期望。第二个标签是用torch.zeors生成的全为0的矩阵,形状同理为(batch,1),运用于生成器G的BCELoss中,代入即可得到对抗性损失的右侧期望。

we alternate between k steps of optimizing D and one step of optimizing G.

This results in D being maintained near its optimal solution, so long as G changes slowly enough.

3.D与G的训练有先后顺序:判别器D先于生成器G训练,而且要求先对D训练k步,再为G训练1步,这就保证G的训练比D足够慢。

如果生成器G足够强大,那么判别器无法再监测生成器,也就没有对抗的必要了。相反,如果判别器D太过于强大,那么生成器也训练地十分缓慢。

在这里插入图片描述

4.算法图如上。

训练代码

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from Model import generator
from Model import discriminator

import os

if not os.path.exists('gan_train.py'):  # 报错中间结果
    os.mkdir('gan_train.py')


def to_img(x):  # 将结果的-0.5~0.5变为0~1保存图片
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 96
num_epoch = 200
z_dimension = 100


# 数据预处理
img_transform = transforms.Compose([
    transforms.ToTensor(),  # 图像数据转换成了张量,并且归一化到了[0,1]。
    transforms.Normalize([0.5], [0.5])  # 这一句的实际结果是将[0,1]的张量归一化到[-1, 1]上。前面的(0.5)均值, 后面(0.5)标准差,
])
# MNIST数据集
mnist = datasets.MNIST(
    root='./data', train=True, transform=img_transform, download=True)
# 数据集加载器
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

D = discriminator()  # 创建生成器
G = generator()  # 创建判别器
if torch.cuda.is_available():  # 放入GPU
    D = D.cuda()
    G = G.cuda()

criterion = nn.BCELoss()  # BCELoss 因为可以当成是一个分类任务,如果后面不加Sigmod就用BCEWithLogitsLoss
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)  # 优化器
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)  # 优化器

# 开始训练
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):  # img[96,1,28,28]
        G.train()
        num_img = img.size(0)  # num_img=batchsize
        # =================train discriminator
        img = img.view(num_img, -1)  # 把图片拉平,为了输入判别器 [96,784]
        real_img = img.cuda()  # 装进cuda,真实图片

        real_label = torch.ones(num_img).reshape(num_img, 1).cuda()  # 希望判别器对real_img输出为1 [96,1]
        fake_label = torch.zeros(num_img).reshape(num_img, 1).cuda()  # 希望判别器对fake_img输出为0  [96,1]

        # 先训练鉴别器
        # 计算真实图片的loss
        real_out = D(real_img)  # 将真实图片输入鉴别器 [96,1]
        d_loss_real = criterion(real_out, real_label)  # 希望real_out越接近1越好 [1]
        real_scores = real_out  # 后面print用的

        # 计算生成图片的loss
        z = torch.randn(num_img, z_dimension).cuda()  # 创建一个100维度的随机噪声作为生成器的输入 [96,1]
        #   这个z维度和生成器第一个Linear第一个参数一致
        # 避免计算G的梯度
        fake_img = G(z).detach()  # 生成伪造图片 [96,748]
        fake_out = D(fake_img)  # 给判别器判断生成的好不好 [96,1]

        d_loss_fake = criterion(fake_out, fake_label)  # 希望判别器给fake_out越接近0越好 [1]
        fake_scores = fake_out  # 后面print用的

        d_loss = d_loss_real + d_loss_fake

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 训练生成器
        # 计算生成图片的loss
        z = torch.randn(num_img, z_dimension).cuda()  # 生成随机噪声 [96,100]

        fake_img = G(z)  # 生成器伪造图像 [96,784]
        output = D(fake_img)  # 将伪造图像给判别器判断真伪 [96,1]
        g_loss = criterion(output, real_label)  # 生成器希望判别器给的值越接近1越好 [1]

        # 更新生成器
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print(
                f'Epoch [{epoch}/{num_epoch}], d_loss: {d_loss.cpu().detach():.6f}, g_loss: {g_loss.cpu().detach():.6f}',
                f'D real: {real_scores.cpu().detach().mean():.6f}, D fake: {fake_scores.cpu().detach().mean():.6f}')
    if epoch == 0:  # 保存图片
        real_images = to_img(real_img.detach().cpu())
        save_image(real_images, './img_gan/real_images.png')

    fake_images = to_img(fake_img.detach().cpu())
    save_image(fake_images, f'./img_gan/fake_images-{epoch + 1}.png')

    G.eval()
    with torch.no_grad():
        new_z = torch.randn(batch_size, 100).cuda()
        test_img = G(new_z)
        print(test_img.shape)
        test_img = to_img(test_img.detach().cpu())
        test_path = f'./test_result/the_{epoch}.png'
        save_image(test_img, test_path)

# 保存模型
torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

网络结构代码

import torch
from torch import nn


# 判别器 判别图片是不是来自MNIST数据集
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),  # 784=28*28
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
            #   sigmoid输出这个生成器是或不是原图片,是二分类
        )

    def forward(self, x):
        x = self.dis(x)
        return x


# 生成器 生成伪造的MNIST数据集
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),  # 输入为100维的随机噪声
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            #   生成器输出的特征维和正常图片一样,这是一个可参考的点
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x


class FinetuneModel(nn.Module):
    def __init__(self, weights):
        super(FinetuneModel, self).__init__()
        self.G = generator()
        base_weights = torch.load(weights)

        model_parameters = dict(self.G.named_parameters())
        #   不是对model进行named_parameters,而是对model里面的具体网络进行named_parameters取出参数,否则取出的是model冗余的参数去测试
        pretrained_weights = {k: v for k, v in base_weights.items() if k in model_parameters}

        new_state_dict = {k: pretrained_weights[k] for k in model_parameters.keys()}
        self.G.load_state_dict(new_state_dict)

    def forward(self, input):
        output = self.G(input)
        return output

测试代码

import os
import sys
import numpy as np
import torch
import argparse
import torch.utils.data
from PIL import Image
from Model import FinetuneModel
from Model import generator
from torchvision.utils import save_image

parser = argparse.ArgumentParser("GAN")
parser.add_argument('--save_path', type=str, default='./test_result')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=2)
parser.add_argument('--model', type=str, default='generator.pth')

args = parser.parse_args()
save_path = args.save_path
os.makedirs(save_path, exist_ok=True)


def to_img(x):  # 将结果的-0.5~0.5变为0~1保存图片
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


def main():
    if not torch.cuda.is_available():
        print("no gpu device available")
        sys.exit(1)

    model = FinetuneModel(args.model)
    model = model.to(device=args.gpu)
    model.eval()

    z_dimension = 100

    with torch.no_grad():
        for i in range(100):
            z = torch.randn(96, z_dimension).cuda()  # 创建一个100维度的随机噪声作为生成器的输入 [96,100]
            output = model(z)
            print(output.shape)
            u_name = f'the_{i}.png'
            print(f'processing {u_name}')
            u_path = save_path + '/' + u_name
            output = to_img(output.cpu().detach())
            save_image(output, u_path)


if __name__ == '__main__':
    main()

本文毕

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/233211.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C/C++ 判断str1能不能由str2里面的字符构成,如果可以,返回true;否则,返回false

题目: 给两个字符串:str1和str2,判断str1能不能由str2里面的字符构成。 如果可以,返回true; 否则,返回false。 限制: str2 中的每个字符只能在str1中使用一次。 示例 1: 输入:str1 "a&q…

CSS3技巧36:让内容垂直居中的三种方式

让内容垂直居中,是一个很重要的应用情景,在很多场合都会需要。这也是面试的时候,一些考官喜欢拿来初面的小题目。 这里,小结下让内容垂直居中的三种方式。 当然,读者如果有更好的方法,也可以提出来。 基本…

使用Java实现汉诺塔问题

文章目录 汉诺塔问题 今天和大家来看看汉诺塔问题,这也是一个经典的算法 汉诺塔问题 分治算法经典问题:汉诺塔问题 汉诺塔的传说 汉诺塔:汉诺塔(又称河内塔)问题是源于印度一个古老传说的益智玩具。大梵天创造世界的…

面试必考精华版Leetcode875. 爱吃香蕉的珂珂

题目&#xff1a; 代码(首刷看解析&#xff09;&#xff1a; class Solution { public:int minEatingSpeed(vector<int>& piles, int h) {int low 1;int high 0;for(int pile:piles){highmax(high,pile);}int k high;while(low<high){int speed (high-low)/2l…

『 MySQL数据库 』聚合统计

文章目录 前言 &#x1f951;&#x1f95d; 聚合函数&#x1f353; COUNT( ) 查询数据数量&#x1f353; SUM( ) 查询数据总和&#x1f353; AVG( ) 查询数据平均值&#x1f353; MAX( ) 查询数据最大值&#x1f353; MIN( ) 查询数据最小值 &#x1f95d; 数据分组GROUP BY子句…

期待一下elasticsearch还未发布的8.12版本,由lucene底层带来的大幅度提升

现在是北京时间23年12月10日。当前es最新版本还是es8.11版本。我们可以期待一下不久的将来&#xff0c;es的8.12版本看到大幅度的检索性能提升。受益于 Lucene 9.9版本&#xff0c;内核带来的大幅提升&#xff01; 此次向量检索利用底层指令fma会性能提升5%。并且还提供了向量点…

零一万物模型折腾笔记:官方 Yi-34B 模型基础使用

当争议和流量都消失后&#xff0c;或许现在是个合适的时间点&#xff0c;来抛开情绪、客观的聊聊这个 34B 模型本身&#xff0c;尤其是实践应用相关的一些细节。来近距离看看这个模型在各种实际使用场景中的真实表现和对硬件的性能要求。 或许&#xff0c;这会对也想在本地私有…

NLP项目实战01--电影评论分类

介绍&#xff1a; 欢迎来到本篇文章&#xff01;在这里&#xff0c;我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言&#xff0c;我们将关注情感分析任务&#xff0c;即通过分析电影评论的情感来判断评论是正面的、负面的。 展示&#xff1a; 训练展示如下…

Android笔记(十七):PendingIntent简介

PendingIntent翻译成中文为“待定意图”&#xff0c;这个翻译很好地表示了它的涵义。PendingIntent描述了封装Intent意图以及该意图要执行的目标操作。PendingIntent封装Intent的目标行为的执行是必须满足一定条件&#xff0c;只有条件满足&#xff0c;才会触发意图的目标操作。…

HCIP —— BGP 基础 (上)

BGP --- 边界网关协议 &#xff08;路径矢量协议&#xff09; IGP --- 内部网关协议 --- OSPF RIP ISIS EGP --- 外部网关协议 --- EGP BGP AS --- 自治系统 由单一的组织或者机构独立维护的网络设备以及网络资源的集合。 因 网络范围太大 需 自治 。 为区分不同的AS&#…

C#,图算法——以邻接节点表示的图最短路径的迪杰斯特拉(Dijkstra)算法C#程序

1 文本格式 using System; using System.Text; using System.Linq; using System.Collections; using System.Collections.Generic; namespace Legalsoft.Truffer.Algorithm { public class Node // : IComparable<Node> { private int vertex, weigh…

CNN发展史脉络 概述图整理

CNN发展史脉络概述图整理&#xff0c;学习心得&#xff0c;供参考&#xff0c;错误请批评指正。 相关论文&#xff1a; LeNet&#xff1a;Handwritten Digit Recognition with a Back-Propagation Network&#xff1b; Gradient-Based Learning Applied to Document Recogniti…

【工具使用-JFlash】如何使用Jflash擦除和读取MCU内部指定扇区的数据

一&#xff0c;简介 在调试的过程中&#xff0c;特别是在调试向MCU内部flash写数据的时候&#xff0c;我们常常要擦除数据区的内容&#xff0c;而不想擦除程序取。那这种情况就需要擦除指定的扇区数据即可。本文介绍一种方法&#xff0c;可以擦除MCU内部Flash中指定扇区的数据…

【小沐学Python】Python实现TTS文本转语音(speech、pyttsx3、百度AI)

文章目录 1、简介2、Windows语音2.1 简介2.2 安装2.3 代码 3、pyttsx33.1 简介3.2 安装3.3 代码 4、ggts4.1 简介4.2 安装4.3 代码 5、pywin326、百度AI7、百度飞桨结语 1、简介 TTS(Text To Speech) 译为从文本到语音&#xff0c;TTS是人工智能AI的一个模组&#xff0c;是人机…

【linux】yum安装时: Couldn‘t resolve host name for XXXXX

yum 安装 sysstat 报错了&#xff1a; Kylin Linux Advanced Server 10 - Os 0.0 B/s | 0 B 00:00 Errors during downloading metadata for repository ks10-adv-os:- Curl error (6): Couldnt resolve host nam…

【摸鱼向】利用Arduino实现自动化切屏

曾几何时&#xff0c;每次背着老妈打游戏的时候都要紧张兮兮地听着爸妈是不是会破门而入&#xff0c;这严重影响了游戏体验&#xff0c;因此&#xff0c;最近想到了用Arduino加上红外传感器来实现自动监测的功能&#xff0c;当有人靠近门口的时候&#xff0c;电脑可以自动执行预…

【文件上传系列】No.2 秒传(原生前端 + Node 后端)

上一篇文章 【文件上传系列】No.1 大文件分片、进度图展示&#xff08;原生前端 Node 后端 & Koa&#xff09; 秒传效果展示 秒传思路 整理的思路是&#xff1a;根据文件的二进制内容生成 Hash 值&#xff0c;然后去服务器里找&#xff0c;如果找到了&#xff0c;说明已经…

pytorch:YOLOV1的pytorch实现

pytorch&#xff1a;YOLOV1的pytorch实现 注&#xff1a;本篇仅为学习记录、学习笔记&#xff0c;请谨慎参考&#xff0c;如果有错误请评论指出。 参考&#xff1a; 动手学习深度学习pytorch版——从零开始实现YOLOv1 目标检测模型YOLO-V1损失函数详解 3.1 YOLO系列理论合集(Y…

【IDEA】解决mac版IDEA,进行单元测试时控制台不能输入问题

我的IDEA版本 编辑VM配置 //增加如下配置&#xff0c;重启IDEA -Deditable.java.test.consoletrue测试效果

风力发电对讲 IP语音对讲终端IP安防一键呼叫对讲 医院对讲终端SV-6005网络音频终端

风力发电对讲 IP语音对讲终端IP安防一键呼叫对讲 医院对讲终端SV-6005网络音频终端 目 录 1、产品规格 2、接口使用 2.1、侧面接口功能 2.2、背面接口功能 2.3、面板接口功能 3、功能使用 1、产品规格 输入电源&#xff1a; 12V&#xff5e;24V的直流电源 网络接口&am…