第G7周:Semi-Supervised GAN 理论与实战

 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  🍦 参考文章:365天深度学习训练营-第G7周:Semi-Supervised GAN 理论与实战(训练营内部成员可读)

  🍖 原作者:K同学啊|接辅导、项目定制

 🏡 运行环境:
电脑系统:Windows 10
语言环境:python 3.10
编译器:Pycharm 2022.1.1
深度学习环境:Pytorch  


目录

一、理论知识讲解

二、代码实现

1、配置代码 

 2、初始化权重

3、定义算法模型

4、配置模型

 5、训练模型


一、理论知识讲解

该算法将产生式对抗网络(GAN) 拓展到半监督学习,通过强制判别器D来输出类别标签。我们
在一个数据集上训练一个生成器G以及一个判别器D,输入是N类当中的一个。在训练的时候,判别器D被用于预测输入是属于N+1类中的哪一个,这个N+1是对应了生成器G的输出,这里的判别器
D同时也充当起了分类器C的效果。这种方法可以用于训练效果更好的判别器D,并且可以比普通的GAN产性更加高质量的样本。Semi-Supervised GAN有如下优点:
(1)作者对GANs做了一个新的扩展,允许它同时学习一个生成模型和一个分类器。我们把这个 扩展叫做半监督GAN或SGAN
(2)论文实验结果表明,SGAN在有限数据集比没有生成部分的基准分类器提升了分类性能
(3)论文实验结果表明,SGAN可以显著地提升生成样本的质量并降低生成器的训练时间。 

二、代码实现

1、配置代码 
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)

cuda = True if torch.cuda.is_available() else False
Namespace(n_epochs=2, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=2, latent_dim=100, num_classes=10, img_size=32, channels=1, sample_interval=400)
 2、初始化权重
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
3、定义算法模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label
4、配置模型
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
 5、训练模型
# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
    )
[Epoch 0/2] [Batch 937/938] [D loss: 1.358861, acc: 50%] [G loss: 0.671799]
[Epoch 1/2] [Batch 937/938] [D loss: 1.343094, acc: 50%] [G loss: 0.681119]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/115364.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【蓝桥杯选拔赛真题44】python小蓝晨跑 青少年组蓝桥杯python 选拔赛STEMA比赛真题解析

目录 python小蓝晨跑 一、题目要求 1、编程实现 2、输入输出 二、算法分析

作为一个初学者,该如何入门大模型?

在生成式 AI 盛行的当下,你是否被这种技术所折服,例如输入一段简简单单的文字,转眼之间,一幅精美的图片,又或者是文笔流畅的文字就展现在你的面前。 相信很多人有这种想法,认为生成式 AI 深不可测&#xf…

系列十五、idea全局配置

一、全局Maven配置 IDEA启动页面>Customize>All settings>Build,Execution,Deployment>Build Tools>Maven 二、全局编码配置 IDEA启动页面>Customize>All settings>Editor>File Encodings 三、全局激活DevTools配置 IDEA启动页面>Customize>A…

ASTM F963-23美国玩具安全新标准发布

新标准发布 2023年10月13日,美国材料与试验协会(ASTM)发布了新版玩具安全标准ASTM F963-23。 主要更新内容 与ASTM F963-17相比,此次更新包括:单独描述了基材重金属元素的豁免情况,更新了邻苯二甲酸酯的管控…

表格冻结第二行

在网上找了一圈也没找到说明白的。 像这样的表格下面会有很多行,往下翻的时候会忘记第二行显示的什么内容。 需求:将第二行进行冻结 实现: 1,选中第一和第二行,下拉,点击冻结 2,会显示冻结至…

MinIO多容器配置NGINX代理实践(docker-compose版本)

以下nginx配置 分别将本机的9001端口代理到minio1,minio2,minio3,minio4主机的9001端口。用于minio后台 分别将本机的9000端口代理到minio1,minio2,minio3,minio4主机的9000端口。用于minioApi events {worker_connections 1024; }http {upstream minio_console {server min…

kubernetes集群编排——k8s存储

configmap 字面值创建 kubectl create configmap my-config --from-literalkey1config1 --from-literalkey2config2kubectl get cmkubectl describe cm my-config 通过文件创建 kubectl create configmap my-config-2 --from-file/etc/resolv.confkubectl describe cm my-confi…

研发效能DevOps: Git安装

目录 一、理论 1.Git 2.Git 工具 二、实验 1.Git安装 2.配置Git 3. VS Code加载Git 一、理论 1.Git (1)简介 Git 是一个分布式版本控制及源代码管理工具;Git 可以为你的项目保存若干快照,以此来对整个项目进行版本管理。 Git 是一个…

如何做好网页配色,分享一些配色方案和方法

很多网页设计师在选择网页配色方案时,会纠结于用什么网页UI配色方案来吸引客户的注意力,传达信息。选择正确的颜色是网页设计不可或缺的一部分。本指南将从色彩理论和色彩心理学入手,分享三个网页UI配色的简单步骤。 网页UI配色方法有很多&a…

会声会影2024出来了吗?有哪些新功能?剪辑后音乐剪辑教程

会声会影 2024视频编辑软件,既加入光影、动态特效的滤镜效果,也提供了与色彩调整相关的LUT配置文件滤镜,可选择性大,运用起来更显灵活。会声会影在用户的陪伴下走过20余载,经过上百个版本的优化迭代,已将操…

使用 Python 进行自然语言处理第 4 部分:文本表示

一、说明 本文是在 2023 年 3 月为 WomenWhoCode 数据科学跟踪活动发表的系列文章中。早期的文章位于:第 1 部分(涵盖 NLP 简介)、第 2 部分(涵盖 NLTK 和 SpaCy 库)、第 2 部分(涵盖NLTK和SpaCy库&#xf…

open mp笔记

Open mp在cpu上并行计算, 统一内存访问(OPEN MP pthreads),同一块内存共享多个CPU 非统一内存访问(MPI),每个CPU都有自己对应的内存,通过blus interconnect链接起来,cpu不能直接访问他们的内存,…

el-table 列分页

<template><div><el-table:data"tableData":key"tampTime"style"width: 100%"><el-table-columnprop"name"label"姓名"width"180"></el-table-column><el-table-columnprop&quo…

Contec SolarView Compact < 6.00 远程命令执行漏洞 (CVE-2023-23333)

Contec SolarView Compact < 6.00 远程命令执行漏洞 &#xff08;CVE-2023-23333&#xff09; 免责声明漏洞描述漏洞影响漏洞危害网络测绘Fofa: body"SolarView Compact" 漏洞复现1. 构造poc2. 执行命令id命令pwd命令 免责声明 仅用于技术交流,目的是向相关安全人…

在 Python 中创建奇数列表

我们将在本文中介绍在 Python 中创建奇数列表的不同方法。 Python 中的奇数 定义奇数有两种方法&#xff0c;第一种是整数不能被 2 整除时的情况。另一种是整数除以 2 时余数为 1 的情况。 例如&#xff0c;1、5、9、11、45等都是奇数。 从列表中获取奇数的方法有很多&#x…

【【嵌入式开发 Linux 常用命令系列 10 -- Linux 修改终端下 ls 各种类型文件的显示颜色】

文章目录 Linux 修改终端下各种类型文件的显示颜色LS_COLORS 详细介绍 Linux 修改终端下各种类型文件的显示颜色 在 ~/.bashrc 文件最下面添加如下内容&#xff0c;就可以配置目录、文件、sh类型文件的颜色了。 export LS_COLORSdi1:fi0:*.sh33:$LS_COLORS这句话的意思就是在…

【黑马程序员】SSM框架——SSM整合

文章目录 前言一、SSM 整合1. SSM 整合思路1.1 Spring 整合 MyBatis配置模型数据层标准开发业务层标准开发测试接口事务处理 1.2 Spring 整合 SpringMVCweb 配置类SpringMVC 配置类基于 Restful 的 Controller 开发 2. SSM 整合具体实现2.1 创建工程2.2 SSM 整合SpringMyBatisS…

初识Vue 解决vue在启动时生成的提示

让我为大家简单介绍一下吧&#xff01; Vue是一套用于构建用户界面的渐进式javaScript框架 当我们引入vue.js后 <script src"../js/vue.js"></script>我们发现&#xff0c;当我们打开网页时&#xff0c;控制台会出现以下内容 那我们该怎么解决呢&…

tolua中table.remove怎么删除表中符合条件的数据

tolua中table.remove怎么删除表中符合条件的数据 介绍问题&#xff08;错误方式删除数据&#xff09;正确删除方案从后向前删除递归方式删除插入新表方式 拓展一下总结 介绍 在lua中删除表中符合条件的数据其实很简单&#xff0c;但是有一个顺序问题&#xff0c;因为lua的表中…

机器视觉行业最大的污点是什么?99%机器视觉公司存在测量项目数据造假,很遗憾,本人不没有恪守技术的本分

机器视觉行业最大的污点是什么&#xff1f;99%机器视觉公司存在测量项目数据造假&#xff0c;很遗憾&#xff0c;本人没有恪守技术的本分。 1%是没做过机器视觉测量项目&#xff0c;我们应该具体分析和具体判断&#xff0c;更应该提高自己的认知能力和技术能力。 那我们​在现场…