PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ), # 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32
            nn.ConvTranspose2d(
                features_g*2, channels_img, kernel_size=4, stride=2, padding=1,
            ),
            nn.Tanh(),

        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.gen(x)


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("success")
    
if __name__ == "__main__":
    test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )

# comment mnist above and uncomment below if train on CelebA

# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/80232.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Docker容器:docker的资源控制及docker数据管理

文章目录 一.docker的资源控制1.CPU 资源控制1.1 资源控制工具1.2 cgroups有四大功能1.3 设置CPU使用率上限1.4 进行CPU压力测试1.5 设置50%的比例分配CPU使用时间上限1.6 设置CPU资源占用比&#xff08;设置多个容器时才有效&#xff09;1.6.1 两个容器测试cpu1.6.2 设置容器绑…

马上七夕到了,用各种编程语言实现10种浪漫表白方式

目录 1. 直接表白&#xff1a;2. 七夕节表白&#xff1a;3. 猜心游戏&#xff1a;4. 浪漫诗句&#xff1a;5. 爱的方程式&#xff1a;6. 爱心Python&#xff1a;7. 心形图案JavaScript 代码&#xff1a;8. 心形并显示表白信息HTML 页面&#xff1a;9. Java七夕快乐&#xff1a;…

【学习笔记之vue】These dependencies were not found:

These dependencies were not found:方案一 全部安装一遍 我们先浅试一个axios >> npm install axios 安装完报错就没有axios了&#xff0c;验证咱们的想法没有问题&#xff0c;实行&#xff01; ok

city walk结合VR全景,打造新时代下的智慧城市

近期爆火的city walk是什么梗&#xff1f;它其实是近年来备受追捧的城市漫步方式&#xff0c;一种全新的城市探索方式&#xff0c;与传统的旅游观光不同&#xff0c;城市漫步更注重与城市的亲密接触&#xff0c;一步步地感受城市的脉动。其实也是一种自由、休闲的方式&#xff…

【Windows系统编程】03.远线程注入ShellCode

shellcode&#xff1a;本质上也是一段普通的代码&#xff0c;只不过特殊的编程手法&#xff0c;可以在任意环境下&#xff0c;不依赖于原有的依赖库执行。 远程线程 #include <iostream> #include <windows.h> #include <TlHelp32.h>int main(){HANDLE hPr…

stable diffusion基础

整合包下载&#xff1a;秋叶大佬 【AI绘画8月最新】Stable Diffusion整合包v4.2发布&#xff01; 参照&#xff1a;基础04】目前全网最贴心的Lora基础知识教程&#xff01; VAE 作用&#xff1a;滤镜微调 VAE下载地址&#xff1a;C站&#xff08;https://civitai.com/models…

小程序的数据绑定和事件绑定

小程序的数据绑定 1.需要渲染的数据放在index.js中的data里 Page({data: {info:HELLO WORLD,imgSrc:/images/1.jpg,randomNum:Math.random()*10,randomNum1:Math.random().toFixed(2)}, }) 2.在WXML中通过{{}}获取数据 <view>{{info}}</view><image src"{{…

爬虫逆向实战(二)--某某观察城市排行榜

一、数据接口分析 主页地址&#xff1a;某某观察 1、抓包 通过抓包可以发现数据接口是multi 2、判断是否有加密参数 请求参数是否加密&#xff1f; 无请求头是否加密&#xff1f; 无cookie是否加密&#xff1f; 无响应数据是否加密&#xff1f; 通过查看“响应”板块可以…

windows10 安装WSL2, Ubuntu,docker

AI- 通过docker开发调试部署ChatLLM 阅读时长&#xff1a;10分钟 本文内容&#xff1a; window上安装ubuntu虚拟机&#xff0c;并在虚拟机中安装docker&#xff0c;通过docker部署数字人模型&#xff0c;通过vscode链接到虚拟机进行开发调试.调试完成后&#xff0c;直接部署在云…

如何搭建游戏服务器?有哪些操作步骤

​  选择游戏服务器提供商 为确保游戏服务器的稳定运行和及时响应问题&#xff0c;选择一个正规、靠谱的游戏服务器提供商非常重要。 选择服务器操作系统 根据不同游戏的需求&#xff0c;选择适合的操作系统&#xff0c;通常可选择Linux或WindowsServer操作系统。 上传、安装…

【C#】条码管理操作手册

前言&#xff1a;本文档为条码管理系统操作指南&#xff0c;介绍功能使用、参数配置、资源链接&#xff0c;以及异常的解决等。思维导图如下&#xff1a; 一、思维导图 二、功能操作–条码打印&#xff08;客户端&#xff09; 2.1 参数设置 功能介绍&#xff1a;二维码图片样…

Flink-----Yarn应用模式作业提交流程

Yarn应用模式作业提交流程 在Yarn当中又分为Session&#xff0c;PerJob&#xff0c;Application&#xff0c;建议和推荐使用独立集群的&#xff0c;其中就包含PerJob 和Application&#xff0c;但是1.17版本的Flink已将PerJob标记为过时&#xff0c;并且Application可以解决Pe…

Linux0.11内核源码解析-truncate.c

truncate文件只要实现释放指定i节点在设备上占用的所有逻辑块&#xff0c;包括直接块、一次间接块、二次间接块。从而将文件节点对应的文件长度截为0&#xff0c;并释放占用的设备空间。 索引节点的逻辑块连接方式 释放一次间接块 static void free_ind(int dev,int block) {…

基于阿里云 Flink+Hologres 搭建实时数仓

摘要&#xff1a;本文作者阿里云 Hologres 高级研发工程师张高迪&阿里云 Flink 技术内容工程师张英男&#xff0c;本篇内容将为您介绍如何通过实时计算 Flink 版和实时数仓 Hologres 搭建实时数仓。 Tips&#xff1a;点击「阅读原文」免费领取 5000CU*小时 Flink 云资源 背…

CSS基础 知识点总结

一.CSS简介 1.1 CSS简介 ① CSS指的是层叠样式表&#xff0c;用来控制网页外观的一门技术 ② CSS发展至今&#xff0c;经历过CSS1.0 CSS2.0 CSS2.1 CSS3.0这几个版本&#xff0c;CSS3.0是CSS最新版本 1.2 CSS引入方式 ① 在一个页面引入CSS&#xff0c;共有三种方式 外部…

【C语言基础】宏定义的用法详解

&#x1f4e2;&#xff1a;如果你也对机器人、人工智能感兴趣&#xff0c;看来我们志同道合✨ &#x1f4e2;&#xff1a;不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 &#x1f4e2;&#xff1a;文章若有幸对你有帮助&#xff0c;可点赞 &#x1f44d;…

Arduino 入门学习笔记10 使用I2C的OLED屏幕

Arduino 入门学习笔记10 使用I2C的OLED屏幕 一、准备工具二、JMD0.96C-1介绍1. 显示屏参数2. SSD1306驱动芯片介绍&#xff1a; 三、使用Arduino开发步骤1. 安装库&#xff08;1&#xff09;Adafruit_GFX_Library 库&#xff08;2&#xff09;Adafruit_SSD1306 驱动库&#xff…

七夕特辑——3D爱心(可监听鼠标移动)

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 「推荐专栏」&#xff1a; ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄&#xff0c;vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄&#xff…

Lua 语言笔记(一)

1. 变量命名规范 弱类型语言(动态类型语言)&#xff0c;定义变量的时候&#xff0c;不需要类型修饰 而且&#xff0c;变量类型可以随时改变每行代码结束的时候&#xff0c;要不要分号都可以变量名 由数字&#xff0c;字母下划线组成&#xff0c;不能以数字开头&#xff0c;也不…

【数据结构】二叉树

&#x1f407; &#x1f525;博客主页&#xff1a; 云曦 &#x1f4cb;系列专栏&#xff1a;数据结构 &#x1f4a8;吾生也有涯&#xff0c;而知也无涯 &#x1f49b; 感谢大家&#x1f44d;点赞 &#x1f60b;关注&#x1f4dd;评论 文章目录 前言一、树的概念及结构&#x…