Paddle 实现DCGAN

传统GAN

传统的GAN可以看我的这篇文章:Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现-CSDN博客

DCGAN

DCGAN是适用于图像生成的GAN,它的特点是:

  • 只采用卷积层和转置卷积层,而不采用全连接层
  • 在每个卷积层或转置卷积层之间,插入一个批归一化层和ReLU激活函数

转置卷积层

转置卷积层执行的是转置卷积或反卷积的操作,即它是常规卷积层的反向操作。它接收一个低分辨率的输入,然后将其通过转置滤波器升采样到更高的分辨率。

对于一个卷积层,它的输出大小公式是:

o = \frac{i + 2p - k}{s} + 1

其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示卷积核大小(kernel_size),s表示步长(stride)。也就是说:输出大小 = (输入大小 - 卷积核大小 + 2 × 填充数) ÷ 步长 + 1

而对于一个转置卷积层,它的输出大小公式是:

o = s(i-1)-2p+k+u

 其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示反卷积核大小(kernel_size),s表示步长(stride),u表示输出填充(output padding)。也就是说:输出大小 = (输入大小 - 1) * 步长 - 2*填充 + 反卷积大小 + 输出填充

在paddle中,转置卷积层可以这么定义:

paddle.nn.Conv2DTranspose(in_channels, out_channels, kernel_size, stride, padding)

像卷积层一样,反卷积层的in_channels表示输入通道数(如形如(3, 32, 32)的图片张量的通道数就是3),out_channels表示输出通道数(如把(64, 32, 32)变成3通道的彩色图像(3, 32, 32))。 

代码实现

这里我们采用NWPU-RESISC45数据集,从中选择“freeway”(高速公路)作为训练数据,让机器生成高速公路的图片。这个训练数据内有700张256x256的图片,但由于我的电脑显存不足,因此将图片大小设置为64x64.

先写dataset.py:

import paddle
import numpy as np
from PIL import Image
import os


def getAllPath(path):
    return [os.path.join(path, f) for f in os.listdir(path)]


class FreewayDataset(paddle.io.Dataset):

    def __init__(self, transform=None):
        super().__init__()
        self.data = []
        for path in getAllPath('./freeway'):
            img = Image.open(path)
            img = img.resize((64, 64))
            img = np.array(img, dtype=np.float32).transpose((2, 1, 0))
            if transform is not None:
                img = transform(img)
            self.data.append(img)
        self.data = np.array(self.data, dtype=np.float32)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

然后写训练脚本:

from dataset import FreewayDataset
import paddle
from models import Generator, Discriminator
import numpy as np

dataset = FreewayDataset()
dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)

netG = Generator()
netD = Discriminator()

if 1:
    try:
        mydict = paddle.load('generator.params')
        netG.set_dict(mydict)
        mydict = paddle.load('discriminator.params')
        netD.set_dict(mydict)
    except:
        print('fail to load model')

loss = paddle.nn.BCELoss()

optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)

# 最大迭代epoch
max_epoch = 1000

for epoch in range(max_epoch):
    now_step = 0
    for step, data in enumerate(dataloader):
        ############################
        # (1) 更新鉴别器
        ###########################

        # 清除D的梯度
        optimizerD.clear_grad()

        # 传入正样本,并更新梯度
        pos_img = data
        label = paddle.full([pos_img.shape[0], 1, 1, 1], 1, dtype='float32')
        pre = netD(pos_img)
        loss_D_1 = loss(pre, label)
        loss_D_1.backward()

        # 通过randn构造随机数,制造负样本,并传入D,更新梯度
        noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')
        neg_img = netG(noise)
        label = paddle.full([pos_img.shape[0], 1, 1, 1], 0, dtype='float32')
        pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算
        loss_D_2 = loss(pre, label)
        loss_D_2.backward()

        # 更新D网络参数
        optimizerD.step()
        optimizerD.clear_grad()

        loss_D = loss_D_1 + loss_D_2

        ############################
        # (2) 更新生成器
        ###########################

        # 清除D的梯度
        optimizerG.clear_grad()

        noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')
        fake = netG(noise)
        label = paddle.full((pos_img.shape[0], 1, 1, 1), 1, dtype=np.float32, )
        output = netD(fake)
        # 这个写法没有问题,因为这个loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度
        loss_G = loss(output, label)
        loss_G.backward()

        # 更新G网络参数
        optimizerG.step()
        optimizerG.clear_grad()

        now_step += 1

        ###########################
        # 输出日志
        ###########################
        if now_step % 10 == 0:
            print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')

paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

 最后编写图片生成脚本:

import paddle
from models import Generator
import matplotlib.pyplot as plt

# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)

# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格

# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):
    noise = paddle.randn([1, 100, 1, 1], 'float32')
    img = netG(noise)
    img = img.numpy()[0].transpose((2, 1, 0))  # img.numpy():张量转np数组
    img[img < 0] = 0  # 将img中所有小于0的元素赋值为0

    # 显示图片
    ax.imshow(img)
    ax.axis('off')  # 不显示坐标轴

# 显示图像
plt.show()

经过数次训练,最终的效果如下:

这样看来,至少有点高速公路的感觉了。 

参考

通过DCGAN实现人脸图像生成-使用文档-PaddlePaddle深度学习平台

卷积层和反卷积层输出特征图大小计算_输出特征图大小的计算方法-CSDN博客 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/611585.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何在 CentOS 上安装并配置 Redis

如何在 CentOS 上安装并配置 Redis 但是太阳&#xff0c;他每时每刻都是夕阳也都是旭日。当他熄灭着走下山去收尽苍凉残照之际&#xff0c;正是他在另一面燃烧着爬上山巅散烈烈朝晖之时。 ——史铁生 环境准备 本教程将在 CentOS 7 或 CentOS 8 上进行。确保你的系统已更新到最…

自托管站点监控工具 Uptime Kuma 搭建与使用

本文首发于只抄博客&#xff0c;欢迎点击原文链接了解更多内容。 前言 Uptime Kuma 是一个类似 Uptime Robot 的站点监控工具&#xff0c;它可以自托管在自己的 Nas 或者 VPS 上&#xff0c;用来监控各类站点、数据库等 监控类型&#xff1a;支持监控 HTTP(s) / TCP / HTTP(s…

Day 43 1049. 最后一块石头的重量 II 494. 目标和 474.一和零

最后一块石头重量Ⅱ 有一堆石头&#xff0c;每块石头的重量都是正整数。 每一回合&#xff0c;从中选出任意两块石头&#xff0c;然后将它们一起粉碎。假设石头的重量分别为 x 和 y&#xff0c;且 x < y。那么粉碎的可能结果如下&#xff1a; 如果 x y&#xff0c;那么两…

【LLM 论文】Step-Back Prompting:先解决更高层次的问题来提高 LLM 推理能力

论文&#xff1a;Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models ⭐⭐⭐⭐ Google DeepMind, ICLR 2024, arXiv:2310.06117 论文速读 该论文受到的启发是&#xff1a;人类再解决一个包含很多细节的具体问题时&#xff0c;先站在更高的层次上解…

【Git】Github创建远程仓库并与本地互联

创建仓库 点击生成新的仓库 创建成功后会生成一个这样的文件 拉取到本地 首先先确保本地安装了git 可以通过终端使用 git --version来查看是否安装好了git 如果显示了版本信息&#xff0c;说明已经安装好了git&#xff0c;这时候我们就可以进入我们想要clone到问目标文件夹 …

计算机系列之算法分析与设计

21、算法分析与设计 算法是对特定问题求解步骤的一种描述。它是指令的有限序列&#xff0c;其中每一条指令标识一个或多个操作。 它具有有穷性、确定性&#xff08;含义确定、输入输出确定&#xff0c;相同输入相同输出&#xff1b;执行路径唯一&#xff09;、可行性、输入&a…

【SAP ME 38】SAP ME发布WebService配置及应用

更多WebService介绍请参照 【SAP ME 28】SAP ME创建开发组件&#xff08;DC&#xff09;webService 致此一个WebService应用发布成功&#xff0c;把wsdl文件提供到第三方系统调用接口&#xff01; 注意&#xff1a; 在SAP ME官方开发中默认对外开放的接口是WebService接口&am…

01、vue+openlayers6实现自定义测量功能(提供源码)

首先先封装一些openlayers的工具函数&#xff0c;如下所示&#xff1a; import VectorSource from ol/source/Vector; import VectorLayer from ol/layer/Vector; import Style from ol/style/Style; import Fill from ol/style/Fill; import Stroke from ol/style/Stroke; im…

Android GPU渲染SurfaceFlinger合成RenderThread的dequeueBuffer/queueBuffer与fence机制(2)

Android GPU渲染SurfaceFlinger合成RenderThread的dequeueBuffer/queueBuffer与fence机制&#xff08;2&#xff09; 计算fps帧率 用 adb shell dumpsys SurfaceFlinger --list 查询当前的SurfaceView&#xff0c;然后有好多行&#xff0c;再把要查询的行内容完整的传给 ad…

题目----力扣--移除链表元素

题目 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,6,3,4,5,6], val 6 输出&#xff1a;[1,2,3,4,5]示例 2&#xff1a; 输入&…

智慧公厕:让厕所管理变得更智慧、高效、舒适!

公共厕所是城市的重要组成部分&#xff0c;但常常被忽视。它们的管理和养护往往面临着许多问题&#xff0c;例如卫生状况不佳、环境畏畏缩缩、设施老旧等。为了解决这些问题&#xff0c;智慧公厕应运而生。智慧公厕是一种全方位的应用解决方案&#xff0c;将科技与公共厕所管理…

我在洛杉矶采访到了亚马逊云全球首席信息官CISO(L11)!

在本次洛杉矶举办的亚马逊云Re:Inforce全球安全大会中&#xff0c;小李哥作为亚马逊大中华区开发者社区和自媒体代表&#xff0c;跟着亚马逊云安全产品团队采访了亚马逊云首席信息安全官(CISO)CJ Moses、亚马逊副总裁Eric Brandwine和亚马逊云首席高级安全工程师Becky Weiss。 …

搜索的未来:OpenAI 的 GPT 如何彻底改变行业

搜索的未来&#xff1a;OpenAI 的 GPT 如何彻底改变行业 概述 搜索引擎格局正处于一场革命的风口浪尖&#xff0c;而 OpenAI 的 GPT 处于这场变革的最前沿。最近出现了一种被称为“im-good-gpt-2-chatbot”的神秘聊天机器人&#xff0c;以及基于 ChatGPT 的搜索引擎的传言&am…

【C++ 内存管理】深拷贝和浅拷贝你了解吗?

文章目录 1.深拷贝2.浅拷贝3.深拷贝和浅拷贝 1.深拷贝 &#x1f34e; 深拷⻉: 是对对象的完全独⽴复制&#xff0c;包括对象内部动态分配的资源。在深拷⻉中&#xff0c;不仅复制对象的值&#xff0c;还会复制对象所指向的堆上的数据。 特点&#xff1a; &#x1f427;① 复制对…

DCDC中MOS半桥的自举电容,自举电阻问题

一个免费的翻译英文文章的网站&#xff0c;可以将英文数据手册翻译为中文&#xff08;需要挂梯子&#xff0c;不收费&#xff0c;无广告&#xff0c;不需要注册&#xff09;&#xff0c;链接如下&#xff1a; Google 翻译 翻译效果&#xff1a; // 104电容是0.1uf&#xff1b…

Spring AOP(2)

目录 Spring AOP详解 PointCut 切面优先级Order 切点表达式 execution表达式 切点表达式示例 annotation 自定义注解MyAspect 切面类 添加自定义注解 Spring AOP详解 PointCut 上面代码存在一个问题, 就是对于excution(* com.example.demo.controller.*.*(..))的大量重…

Tomcat中服务启动失败,如何查看启动失败日志?

1. 查看 localhost.log 这个日志文件通常包含有关特定 web 应用的详细错误信息。运行以下命令查看 localhost.log 中的错误&#xff1a; sudo tail -n 100 /opt/tomcat/latest/logs/localhost.YYYY-MM-DD.log请替换 YYYY-MM-DD 为当前日期&#xff0c;或选择最近的日志文件日…

【notepad++】使用

1 notepad 下载路径 https://notepad-plus.en.softonic.com/download 2 设置护眼模式 . 设置——语言格式设置——前景色——黑色 . 背景色——RGB &#xff1a;199 237 204 . 勾选“使用全局背景色”、“使用全局前景色” . 保存并关闭

YOLOv5改进 | 注意力机制 | 理解全局和局部信息的SE注意力机制

在深度学习目标检测领域&#xff0c;YOLOv5成为了备受关注的模型之一。本文给大家带来的是能够理解全局和局部信息的SE注意力机制。文章在介绍主要的原理后&#xff0c;将手把手教学如何进行模块的代码添加和修改&#xff0c;并将修改后的完整代码放在文章的最后&#xff0c;方…

RAG查询改写方法概述

在RAG系统中&#xff0c;用户的查询是丰富多样的&#xff0c;可能存在措辞不准确和缺乏语义信息的问题。这导致使用原始的查询可能无法有效检索到目标文档。 因此&#xff0c;将用户查询的语义空间与文档的语义空间对齐至关重要&#xff0c;目前主要有查询改写和嵌入转换两种方…