用GAN生成奖杯

数据集链接:https://pan.baidu.com/s/19Uxc2ELiMG3acUtLeSTDTA?pwd=wsyw
提取码:wsyw

我设置的图片大小为128*128,如果内存爆炸可以将batch_size调小,epoch我设置的2000,我感觉其实1000也够了。代码如下:

import argparse
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.utils import save_image


def args_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epoches", type=int, default=2000, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=128, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=3, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=50, help="interval between image sampling")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--type", type=str, default='DCGAN', help="The type of DCGAN")
    return parser.parse_args()

class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), self.img_shape[0], self.img_shape[1], self.img_shape[2])
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


class Generator_CNN(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator_CNN, self).__init__()
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))  # 100 ——> 128 * 8 * 8 = 8192
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1),
            nn.Tanh()
        )
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator_CNN(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator_CNN, self).__init__()
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block
        self.model = nn.Sequential(
            *discriminator_block(img_shape[0], 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        ds_size = img_shape[1] // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())  # 128 * 2 * 2 ——> 1

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        # print('out:', out.shape)
        validity = self.adv_layer(out)
        return validity


def train():
    opt = args_parse()
    transform = transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
    data = datasets.ImageFolder('./dataset', transform=transform)
    train_loader = torch.utils.data.DataLoader(
        data,
        batch_size=opt.batch_size,
        shuffle=True)
    img_shape = (opt.channels, opt.img_size, opt.img_size)
    # Construct generator and discriminator
    if opt.type == 'DCGAN':
        generator = Generator_CNN(opt.latent_dim, img_shape)
        discriminator = Discriminator_CNN(img_shape)
    else:
        generator = Generator(opt.latent_dim, img_shape)
        discriminator = Discriminator(img_shape)
    adversarial_loss = torch.nn.BCELoss()
    cuda = True if torch.cuda.is_available() else False
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=(opt.lr * 8 / 9), betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    for epoch in range(opt.n_epoches):
        for i, (imgs, _) in enumerate(train_loader):
            # adversarial ground truths
            valid = torch.ones(imgs.shape[0], 1).type(Tensor)
            fake = torch.zeros(imgs.shape[0], 1).type(Tensor)
            real_imgs = imgs.type(Tensor)
            #############    Train Generator    ################
            optimizer_G.zero_grad()
            # sample noise as generator input
            z = torch.tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))).type(Tensor)
            # Generate a batch of images
            gen_imgs = generator(z)
            # G-Loss
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()
            #############  Train Discriminator ################
            optimizer_D.zero_grad()
            # D-Loss
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G Loss: %f]"
                % (epoch, opt.n_epoches, i, len(train_loader), d_loss.item(), g_loss.item())
            )
            batches_done = epoch * len(train_loader) + i
            os.makedirs("images_3_2", exist_ok=True)
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "images_3_2/%d.png" % (batches_done), nrow=5, normalize=True)


if __name__ == '__main__':
    train()

实验效果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/754550.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

信创认证 | Smartbi Insight V11成功适配申威3231处理器

在信息技术飞速发展的浪潮中,软硬件的深度融合与协同发展已成为推动行业创新的关键因素。 近日,思迈特商业智能与数据分析软件[简称:Smartbi Insight]V11在统信服务器操作系统V20和中电科申泰信息科技有限公司产品申威3231处理器环境下完成适…

CAN和CANFD数据写入.asc文件的dll

因为工作需要,需要做一些硬件不是CANoe的上位机(比如说周立功CAN,NI-CAN),上位机需要有记录数据的功能,所以用Qt制作了一个记录数据的dll,方便重复使用(因为有的客户指定了编程软件,…

51循迹小车(蓝牙+循迹+超声波+舵机+避障L298N)

基本驱动 L298N电机驱动模块负责供电和控制电机驱动 将电池12V供电接到12V供电上,作为输入。单片机及其他器件供电可以使用5V供电,这里的GND都接到一起。 输出A和输出B接到电机上,负责给电机供电和控制电机。 通道A使能和通道B使能以及逻…

【Windows下使用vckpg下载protoc之后环境变量问题】

使用vcpkg进行下载的protoc: vcpkg install protobuf protobuf:x64-windows 检查protoc版本时出现问题: “protoc”不是内部或外部命令,也不是可运行程序或批处理文件 尝试添加系统环境变量后没有反应。 这个时候找到vckpg下的packages目录…

如何利用ChatGPT寻找科研创新点?分享5个有效实践技巧

欢迎关注:智写AI,为大家带来最酷最有效的智能AI学术科研写作攻略。关于使用ChatGPT等AI学术科研的相关问题可以和作者七哥交流:yida985 地表功能最强大的高级学术专业版已经开放,拥有全球领先的GPT学术科研应用,有兴趣…

鸿蒙开发设备管理:【@ohos.brightness (屏幕亮度)】

屏幕亮度 该模块提供屏幕亮度的设置接口。 说明: 本模块首批接口从API version 7开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import brightness from ohos.brightness;brightness.setValue setValue(value: number):…

【Linux】网络编程套接字

一、预备知识 1.1 理解源IP地址和目的IP地址 在IP数据报的头部中,有两个IP地址,分别叫做源IP地址和目的IP地址。 源IP地址和目的IP地址是网络通信中常用的两个概念,他们代表了通信中的两个节点。 源IP地址是指发起通信的节点的IP地址&#…

在WSL Ubuntu中启用root用户的SSH服务

在 Ubuntu 中,默认情况下 root 用户是禁用 SSH 登录的,这是为了增加系统安全性。 一、修改配置 找到 PermitRootLogin 行:在文件中找到 PermitRootLogin 配置项。默认情况下,它通常被设置为 PermitRootLogin prohibit-password 或…

老生常谈问题之什么是缓存穿透、缓存击穿、缓存雪崩?举个例子你就彻底懂了!!

老生常谈问题之什么是缓存穿透、缓存击穿、缓存雪崩?举个例子你就彻底懂了!! 缓存穿透发生场景解决方案 缓存击穿解决方案 缓存雪崩发生场景解决方案 总结三者区分三者原因三者解决方案 想象一下,你开了一家便利店,店里…

FastAPI教程I

本文参考FastAPI教程https://fastapi.tiangolo.com/zh/tutorial 第一步 import uvicorn from fastapi import FastAPIapp FastAPI()app.get("/") async def root():return {"message": "Hello World"}if __name__ __main__:uvicorn.run(&quo…

从我邮毕业啦!!!

引言 时间过的好快,转眼间就要从北邮毕业了,距离上一次月度总结又过去了两个月,故作本次总结。 PS: https://github.com/WeiXiao-Hyy/blog整理了后端开发的知识网络,欢迎Star! 毕业🎓 6月1号完成了自己的…

Windows server 2016.2019 .NET Framework 3.5安装包、安装步骤

windows server2019 操作系统 安装 sqlserver2008时提示缺少 .NET Frameword 3.5, 在功能里选择 .NET Frameword 3.5安装报错, 下载安装包,下载地址 https://download.csdn.net/download/qq445829096/89450429这里指定备份源路径 安装包解…

多供应商食品零售商城系统的会员营销设计和实现

在多供应商食品零售商城系统中,会员营销是提升用户粘性和增加销售的重要手段。一个有效的会员营销系统能够帮助平台更好地了解用户需求,提供个性化服务,进而提高用户满意度和忠诚度。本文将详细探讨多供应商食品零售商城系统的会员营销设计与…

2毛钱不到的2A同步降压DCDC电压6V频率1.5MHz电感2.2uH封装SOT23-5芯片MT3520B

前言 2A,2.3V-6V输入,1.5MHz 同步降压转换器,批量价格约0.18元 MT3520B 封装SOT23-5 丝印AS20B5 特征 高效率:高达 96% 1.5MHz恒定频率操作 2A 输出电流 无需肖特基二极管 2.3V至6V输入电压范围 输出电压低至 0.6V PFM 模式可在…

MySQL进阶-索引-使用规则-索引失效情况一(索引列运算,字符串不加引号,头部模糊匹配)

文章目录 1、索引列运算1.1、查询表tb_user1.2、查看tb_user的索引1.3、查询 phone177999900151.4、执行计划 phone177999900151.5、查询 substring(phone,10,2) 151.6、执行计划 substring(phone,10,2) 15 2、字符串不加引号2.1、查询 phone177999900152.2、执行计划 phone177…

JAVA-矩阵置零

给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 思路: 找到0的位置,把0出现的数组的其他值夜置为0 需要额外空间方法: 1、定义两个布尔数组标记二维数组中行和列…

axios之CancelToken取消请求

从 v0.22.0 开始,Axios 支持以 fetch API 方式—— AbortController 取消请求 此 API 从 v0.22.0 开始已被弃用,不应在新项目中使用 官网链接 1. 背景 最近项目中遇到一个场景,当连续触发一个请求时,如果是同一个接口&#xf…

【仿真建模-anylogic】开发规范

Author:赵志乾 Date:2024-06-28 Declaration:All Right Reserved!!! 0. 说明 实际模型开发过程中,对遇到的问题进行总结归纳出以下开发规范,仅供参考! 1. 强制性规范 1…

加密与安全_Java 加密体系 (JCA) 和 常用的开源密码库

文章目录 Java Cryptography Architecture (JCA)开源国密库国密算法对称加密(DES/AES⇒SM4)非对称加密(RSA/ECC⇒SM2)散列(摘要/哈希)算法(MD5/SHA⇒SM3) 在线生成公钥私钥对,RSA公私钥生成参考…

单目操作符

目录 ! --- 逻辑反操作 & --- 取地址操作符 * --- 间接访问操作符(解引用操作符) sizeof --- 操作数的类型长度(单位为字节) ~ --- 对一个数的补码二进制按位取反 前置和前置-- 后置和后置-- (类型) --- 强制类型转换…