在kaggle中用GPU使用CGAN生成指定mnist手写数字

文章目录

  • 1项目介绍
  • 2参考文章
  • 3代码的实现过程及对代码的详细解析
    • 独热编码
    • 定义生成器
    • 定义判别器
    • 打印我们的引导信息
    • 模型训练
    • 迭代过程中生成的图片
    • 损失函数的变化
  • 4总结
  • 5 模型相关的文件

1项目介绍

在GAN的基础上进行有条件的引导生成图片cgan

2参考文章

GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

3代码的实现过程及对代码的详细解析


import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image

独热编码


# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):
    return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要

torch.eye(10)函数的作用是生成一个10*10的对角矩阵
该函数的作用是得到第x个位置为1的独热编码,如果传入为列表,则得到一个矩阵
在这里插入图片描述

 
transform =transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5)])
 #minist数据集中的图片数据的维度是[batch_size, 1, 28, 28],其中batch_size是每个批次的图像数量。这个数据集中的每个图像都是28x28像素的灰度图像,因此它们只有一个通道
dataset = torchvision.datasets.MNIST('data',
                                     train=True,
                                     transform=transform,
                                     target_transform=one_hot,
                                     download=True)
#这里target_transform参数的作用是对标签进行转换。在这个例子中,它的作用是将标签转换为one-hot编码。
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)
 

定义生成器


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        #因此,这个函数的输入张量维度为[batch_size, 10]和[batch_size, 100],输出张量维度为[batch_size, 1, 1, 1]。
        self.linear1 = nn.Linear(10, 128 * 7 * 7)
        self.bn1 = nn.BatchNorm1d(128 * 7 * 7)
        self.linear2 = nn.Linear(100, 128 * 7 * 7)
        self.bn2 = nn.BatchNorm1d(128 * 7 * 7)
        #这个函数的作用是将一个输入张量进行反卷积操作,得到一个输出张量。
        #nn.ConvTranspose2d函数的作用是将一个256通道的输入张量转换为一个128通道的输出张量,使用3x3的卷积核进行卷积操作,并在卷积操作后进行1像素的padding
        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)
 
    def forward(self, x1, x2):
        x1 = F.relu(self.linear1(x1))
        x1 = self.bn1(x1)
        x1 = x1.view(-1, 128, 7, 7)
        x2 = F.relu(self.linear2(x2))
        x2 = self.bn2(x2)
        x2 = x2.view(-1, 128, 7, 7)
        #将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量
        x = torch.cat([x1, x2], axis=1)
        x = F.relu(self.deconv1(x))
        #形状变为为[64, 128, 7, 7]的张量
        x = self.bn3(x)
        x = F.relu(self.deconv2(x))
        #形状变为为[64, 64, 14, 14]的张量
        x = self.bn4(x)
         # 形状变为为[64, 1, 28, 28]的张量
        x = torch.tanh(self.deconv3(x))
      
        return x

生成器对数据的处理过程:
这个函数对于输入张量[64, 1, 28, 28]的维度变化过程如下:
输入张量维度为[64, 1, 28, 28]
经过线性变换和ReLU激活函数处理后,得到两个形状为[64, 128 * 7 * 7]的张量
将两个张量分别通过BatchNorm1d进行归一化处理
将两个处理后的结果reshape成形状为[64, 128, 7, 7]的张量
将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量
经过反卷积操作得到输出张量,维度为[64, 1, 28, 28]

定义判别器


# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(10, 1*28*28)
        self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值
 
    def forward(self, x1, x2):
    #leak_relu激活函数:它在输入小于0时返回一个小的斜率,而在输入大于等于0时返回输入本身
        x1 =F.leaky_relu(self.linear(x1))
        x1 = x1.view(-1, 1, 28, 28)
        #torch.cat([x1, x2], axis=1)函数将张量x1和张量x2沿着第二个维度(即列)拼接起来
        x = torch.cat([x1, x2], axis=1)
        #处理过后变为(64,2,28,28)
        x = F.dropout2d(F.leaky_relu(self.conv1(x)))
        #维度变为(64,64,13,13)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)))
        #维度变为(64,128,6,6)
        x = self.bn(x)
        x = x.view(-1, 128*6*6)
        #最后键位了64*1(同时把值映射到0~1之间)
        x = torch.sigmoid(self.fc(x))
        return x


# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
 
# 损失计算函数
loss_function = torch.nn.BCELoss()
 
# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
 
# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):
	#生成器生成取片,label_input为输入的引导信息,noise_input为随机的噪声点
    predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())
    #numpy.squeeze()函数的作用是去掉矩阵里维度为1的维度。
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow((predictions[i] + 1) / 2, cmap='gray')
        plt.axis("off")
    from IPython.display import FileLink
    plt.savefig('data/img/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


import os 
os.makedirs("data/img")

打印我们的引导信息

noise_seed = torch.randn(16, 100, device=device)
 
label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)

print(label_seed)
tensor([1, 3, 5, 4, 9, 3, 0, 0, 1, 3, 4, 5, 9, 2, 3, 7])

模型训练

D_loss = []
G_loss = []


# 训练循环
for epoch in range(150):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader.dataset)
    # 对全部的数据集做一次迭代
    #dataloader中的图像是四维的。在for循环中,每次迭代会返回一个batch_size大小的数据
    #其中每个数据都是一个四维张量,形状为[batch_size, channels, height, width]
    for step, (img, label) in enumerate(dataloader):
        img = img.to(device)
        label = label.to(device)
        size = img.shape[0]
        random_noise = torch.randn(size, 100, device=device)
 
        d_optim.zero_grad()
 
        real_output = dis(label, img)
        d_real_loss = loss_function(real_output,
                                    torch.ones_like(real_output, device=device)
                                    )
        #torch.ones_like(real_output, device=device)函数的作用是生成一个与real_output形状相同的张量,其中所有元素都为1。                         
        d_real_loss.backward() #求解梯度
 
        # 得到判别器在生成图像上的损失
        gen_img = gen(label,random_noise)
        fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果
        d_fake_loss = loss_function(fake_output,
                                    torch.zeros_like(fake_output, device=device))
        d_fake_loss.backward()
 
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()  # 优化
 
        # 得到生成器的损失
        g_optim.zero_grad()
        fake_output = dis(label, gen_img)
        g_loss = loss_function(fake_output,
                               torch.ones_like(fake_output, device=device))
        g_loss.backward()
        g_optim.step()
 
        with torch.no_grad():
            d_epoch_loss += d_loss.item()
            g_epoch_loss += g_loss.item()
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        if epoch % 10 == 0:
            print('Epoch:', epoch)
            generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)
    print("epoch:{}/150".format(epoch))
 
plt.plot(D_loss, label='D_loss')
plt.plot(G_loss, label='G_loss')
plt.legend()
plt.show()

迭代过程中生成的图片

迭代1次
在这里插入图片描述
迭代10次
在这里插入图片描述
迭代20次
在这里插入图片描述

迭代30次
在这里插入图片描述

迭代40次
在这里插入图片描述

迭代150次
在这里插入图片描述

损失函数的变化

在这里插入图片描述

4总结

cGAN相比于GAN而言,将label的信息通过一系列的卷积操作和图像的信息融合在一起,然后放进模型进行训练,让我们的模型能和label相匹配的图像,从而在我们给出制定的数字label时能够生成对应的数字图片,实现了引导的过程。

5 模型相关的文件

模型的相关文件:提取码(ujki)

本模型是放在kaggle中运行的,kaggle的部署流程请参考:在kaggle中用GPU训练模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/93833.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UbuntuDDE 23.04发布,体验DeepinV23的一个新选择

UbuntuDDE 23.04发布,体验DeepinV23的一个新选择 昨晚网上搜索了一圈,无意看到邮箱一条新闻,UbuntuDDE 23.04发布了 因为前几天刚用虚拟机安装过,所以麻溜的从网站下载了ISO文件,安装上看看。本来没多想,…

什么是字体图标(Icon Font)?如何在网页中使用字体图标?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 字体图标(Icon Font)⭐ 如何在网页中使用字体图标⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅&a…

9.4 集成功率放大电路

OTL、OCL 和 BTL 电路均有各种不同输出功率和不同电压增益的集成电路。应当注意,在使用 OTL 电路时,需外接输出电容。为了改善频率特性,减小非线性失真,很多电路内部还引入深度负反馈。这里以低频功放为例。 一、集成功率放大电路…

【C++】详解声明和定义

2023年8月28日,周一下午 研究了一个下午才彻底弄明白... 写到晚上才写完这篇博客。 目录 声明和定义的根本区别结构体的声明和定义声明结构体 定义结构体类的声明和定义函数的定义和声明声明函数 定义函数变量声明和定义声明变量定义变量 声明和定义的根本区别 …

全景图像生成算法

摘要 全景图像生成是计算机视觉领域的一个重要研究方向。本文对五种经典的全景图像生成算法进行综述,包括基于相机运动估计的算法、基于特征匹配的算法、基于图像切割的算法、基于多项式拟合的算法和基于深度学习的算法。通过对这些算法的原理、优缺点、适用场景等…

14-redis

一 Redis概述 1 为什么要用NoSQL 单机Mysql的美好年代 在90年代,一个网站的访问量一般都不大,用单个数据库完全可以 轻松应付。在那个时候,更多的都是静态网页,动态交互类型的网站不多。 遇到问题: 随着用户数的增长…

【建议收藏】Kubernetes 网络策略入门:概念、示例和最佳实践,附云原生资料

目录 摘要 一、Kubernetes 网络策略组件 二、实施网络策略 示例 1:在命名空间中限制流量 示例 2:允许特定 Pod 的流量 示例 3:在单个策略中组合入站和出站规则 示例 4:阻止对特定 IP 范围的出站流量 三、Kubernetes 网络策…

R包开发-2.2:在RStudio中使用Rcpp制作R-Package(更新于2023.8.23)

目录 4-添加C函数 5-编辑元数据 6-启用Roxygen,执行文档化。 7-单元测试 8-在自己的计算机上安装R包: 9-程序发布 参考: 为什么要写这篇文章的更新日期?因为R语言发展很快,很多函数或者方式,现在可以使…

stm32串口通信(PC--stm32;中断接收方式;附proteus电路图;开发方式:cubeMX)

单片机型号STM32F103R6: 最后实现的效果是,开机后PC内要求输入1或0,输入1则打开灯泡,输入0则关闭灯泡,输入其他内容则显示错误,值得注意的是这个模拟的东西只能输入英文 之所以用2个LED灯是因为LED电阻粗略一算就是1…

Linux(实操篇一)

Linux实操篇 Linux(实操篇一)1. 常用基本命令1.1 帮助命令1.1.1 man获得帮助信息1.1.2 help获得shell内置命令的帮助信息1.1.3 常用快捷键 1.2 文件目录类1.2.1 pwd显示当前 工作目录的绝对路径1.2.2 ls列出目录的内容1.2.3 cd切换目录1.2.4 mkdir创建一个新的目录1.2.5 rmdir删…

Bigemap在路桥行业是怎么应用的?

选择Bigemap的原因: 奥维下架了,后来了解到的bigemap,于是测试了这款软件 使用场景: 下载影像、矢量路网做前期策划,下载完数据后导出cad ,做一些标注,最终出图下载等高线,作为前期选址依据 …

java八股文面试[多线程]——线程的生命周期

笔试题:画出线程的生命周期,各个状态的转换。 5.等待队列(本是Object里的方法,但影响了线程) 调用obj的wait(), notify()方法前,必须获得obj锁,也就是必须写在synchronized(obj) 代码段内。与等待队列相关的步骤和图 …

股票预测和使用LSTM(长期-短期-记忆)的预测

一、说明 准确预测股市走势长期以来一直是投资者和交易员难以实现的目标。虽然多年来出现了无数的策略和模型,但有一种方法最近因其能够捕获历史数据中的复杂模式和依赖关系而获得了显着的关注:长短期记忆(LSTM)。利用深度学习的力…

[MyBatis系列⑥]注解开发

🍃作者简介:准大三本科网络工程专业在读,持续学习Java,努力输出优质文章 ⭐MyBatis系列①:增删改查 ⭐MyBatis系列②:两种Dao开发方式 ⭐MyBatis系列③:动态SQL ⭐MyBatis系列④:核心…

Postman API测试之道:不止于点击,更在于策略

引言:API测试的重要性 在当今的软件开发中,API已经成为了一个不可或缺的部分。它们是软件组件之间交互的桥梁,确保数据的流动和功能的实现。因此,对API的测试显得尤为重要,它不仅关乎功能的正确性,还涉及到…

android framework之Applicataion启动流程分析

Application启动流程分析 启动方式一:通过Launcher启动app 启动方式二:在某一个app里启动第二个app的Activity. 以上两种方式均可触发app进程的启动。但无论哪种方式,最终通过通过调用AMS的startActivity()来启动application的。 根据上图…

论文解读 | ScanNet:室内场景的丰富注释3D重建

原创 | 文 BFT机器人 大型的、有标记的数据集的可用性是为了利用做有监督的深度学习方法的一个关键要求。但是在RGB-D场景理解的背景下,可用的数据非常少,通常是当前的数据集覆盖了一小范围的场景视图,并且具有有限的语义注释。 为了解决这个问题&#…

数据仓库一分钟

简介 数据仓库(Data Warehouse)简称DW或DWH,是数据库的一种概念上的升级,可以说是为满足新需求设计的一种新数据库,而这个数据库是需容纳更多的数据,更加庞大的数据集,从逻辑上讲数据仓库和数据…

Midjourney API 的对接和使用

“ 阅读本文大概需要 4 分钟。 ” 在人工智能绘图领域,想必大家听说过 Midjourney 的大名吧。 Midjourney 以其出色的绘图能力在业界独树一帜。无需过多复杂的操作,只要简单输入绘图指令,这个神奇的工具就能在瞬间为我们呈现出对应的图像。无…

Git企业开发控制理论和实操-从入门到深入(七)|企业级开发模型

前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助。 高质量博客汇总 然后就是博主最近最花时间的一个专栏…