昇思25天学习打卡营第23天 | Pix2Pix实现图像转换

内容介绍:

Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器判别器

传统上,尽管此类任务的目标都是相同的从像素预测像素,但每项都是用单独的专用机器来处理的。而Pix2Pix使用的网络作为一个通用框架,使用相同的架构和目标,只在不同的数据上进行训练,即可得到令人满意的结果,鉴于此许多人已经使用此网络发布了他们自己的艺术作品。

cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

具体内容:

1. 导包:

from download import download
from mindspore import dataset as ds
import matplotlib.pyplot as plt
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.nn as nn
from mindspore.common import initializer as init
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
from mindspore import load_checkpoint, load_param_into_net

2. 下载数据集

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

3. 数据显示

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

4. 网络构建

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out

5. UNet生成器

class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

6. PatchGAN判别器

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

7. Pix2Pix生成器和判别器初始化

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

8. 训练

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

9. 推理

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

Pix2Pix的学习过程让我深刻体会到了深度学习在图像处理领域的强大能力。通过训练一对相互竞争的神经网络——生成器与判别器,Pix2Pix能够学习到输入图像与输出图像之间复杂的映射关系。这种端到端的学习方式,无需人工设计复杂的特征提取与转换规则,极大地简化了图像转换的流程,同时也提高了转换结果的质量和多样性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/791411.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

二分法求函数的零点 信友队

题目ID&#xff1a;15713 必做题 100分 时间限制: 1000ms 空间限制: 65536kB 题目描述 有函数&#xff1a;f(x) 已知f(1.5) > 0&#xff0c;f(2.4) < 0 且方程 f(x) 0 在区间 [1.5,2.4] 有且只有一个根&#xff0c;请用二分法求出该根。 输入格式 &#xff08;无…

政安晨:【Keras机器学习示例演绎】(五十三)—— 使用 TensorFlow 决策森林进行分类

目录 简介 设置 准备数据 定义数据集元数据 配置超参数 实施培训和评估程序 实验 1&#xff1a;使用原始特征的决策森林 检查模型 实验 2&#xff1a;目标编码决策森林 创建模型输入 使用目标编码实现特征编码 使用预处理器创建梯度提升树模型 训练和评估模型 实验…

从零开始学习嵌入式----C语言框架梳理与后期规划

目录 一、环境搭建. 二、见解 三、C语言框架梳理 四、嵌入式学习规划流程图&#xff08;学习顺序可能有变&#xff09; 一、环境搭建. C语言是一门编程语言&#xff0c;在学习的时候要准备好环境。我个人比较喜欢用VS,具体怎么安装请百度。学习C语言的时候&#xff0c;切忌…

Spring中如何操作Redis

Spring毕竟是Java中的一个主流框架&#xff0c;如何在这个框架中使用Redis呢&#xff1f; 创建项目并引入相关依赖 然后进行创建。 至此就将Redis的相关依赖引入进来了。 编写Redis配置 将application.properties修改成application.yml 然后编写如下配置&#xff1a; spr…

各向异性含水层中地下水三维流基本微分方程的推导

各向异性含水层中地下水三维流基本微分方程的推导 参考文献&#xff1a; [1] 刘欣怡,付小莉.论连续性方程的推导及几种形式转换的方法[J].力学与实践,2023,45(02):469-474. 文章链接 水均衡的基本思想&#xff1a; ∑ 流 入 − ∑ 流 出 Δ V \sum 流入-\sum 流出\Delta V ∑…

ArduPilot开源飞控之AP_Mount_Siyi

ArduPilot开源飞控之AP_Mount_Siyi 1. 源由2. 框架设计2.1 类和继承2.2 公共方法2.3 保护方法2.4 私有成员和方法2.5 解析状态2.6 重要成员变量 3. 重要方法3.1 AP_Mount_Siyi::init3.2 AP_Mount_Siyi::update3.3 AP_Mount_Siyi::read_incoming_packets3.4 AP_Mount_Siyi::proc…

LabVIEW实现LED显示屏视觉检测

为了满足LED显示屏在生产过程中的严格质量检测需求&#xff0c;引入自动化检测系统是十分必要的。传统人工检测方式存在检测强度高、效率低、准确性差等问题&#xff0c;自动化检测系统则能显著提高检测效率和准确性。视觉检测系统的构建主要包含硬件和软件两个部分。 视觉系统…

深度学习论文: LLaMA: Open and Efficient Foundation Language Models

深度学习论文: LLaMA: Open and Efficient Foundation Language Models LLaMA: Open and Efficient Foundation Language Models PDF:https://arxiv.org/pdf/2302.13971.pdf PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks 1 概述 本文介绍了LLaMA&#xff0…

一二三应用开发平台应用开发示例(7)——文档功能实现示例

概述 在完成文件夹配置工作后&#xff0c;接下来配置文档管理系统最核心的管理对象“文档”。 依旧是使用平台低代码配置工作来配置&#xff0c;配置流程跟文件夹的配置是相同的&#xff0c;以下简要说明&#xff0c;重点是新涉及到的功能或注意点。 创建实体 配置模型属性 …

【Dison夏令营 Day 15】 Python 鸡蛋捕手

在本次课程中&#xff0c;我们使用 Python 创建了经典的 "抓蛋 "游戏。在这个游戏中&#xff0c;每抓到一个鸡蛋就能赢得 10 分&#xff0c;而每掉落一个鸡蛋就会损失一条命。 小时候&#xff0c;我们都玩过 "抓鸡蛋 "游戏。我们使用海龟软件包在 Python …

【linux】阿里云centos配置邮件服务

目录 1.安装mailx服务 2./etc/mail.rc 配置增加 3.QQ邮箱开启smtp服务&#xff0c;获取授权码 4.端口设置&#xff1a;Linux 防火墙开放端口-CSDN博客 5.测试 1.安装mailx服务 yum -y install mailx 2./etc/mail.rc 配置增加 #邮件发送人 set from924066173qq.com #阿里…

windows USB 设备驱动开发-USB电源管理(一)

符合通用串行总线 (USB) 规范的 USB 设备的电源管理功能具有一组丰富而复杂的电源管理功能。 请务必了解这些功能如何与 Windows 驱动程序模型 (WDM) 交互&#xff0c;特别是 Microsoft Windows 如何调整标准 USB 功能以支持系统唤醒体系结构。 基于内核模式驱动程序框架的 US…

android13 rom 开发总纲说明

1. 这里是文章总纲&#xff0c;可以在这里快速找到需要的文章。 2. 文章一般是基于标准的android13&#xff0c;有一些文章可能会涉及到具体平台&#xff0c;例如全志&#xff0c;瑞芯微等一些平台。 3.系统应用 3.1系统应用Launcher3桌面相关&#xff1a; 3.2系统应用设置S…

NLP入门——词袋语言模型的搭建、训练与预测

卷积语言模型实际上是取了句子最后ctx_len个词作为上下文输入模型来预测之后的分词。但更好的选择是我们做一个词袋&#xff0c;将所有分词装在词袋中作为上下文&#xff0c;这样预测的分词不只根据最后ctx_len个分词&#xff0c;而是整个词袋中的所有分词。 例如我们的序列是&…

绝对值不等式运用(C++)

货仓选址 用数学公式表达题意&#xff0c;假设有位置a1~an,假设选址在x位置处&#xff0c;则有&#xff1a; 如何让这个最小&#xff0c;我们把两个式子整合一下&#xff0c;利用绝对值不等式&#xff1a; 我们知道&#xff1a; 如下图所示&#xff1a;到A&#xff0c;B两点&…

【深度学习入门篇 ②】Pytorch完成线性回归!

&#x1f34a;嗨&#xff0c;大家好&#xff0c;我是小森( &#xfe61;ˆoˆ&#xfe61; )&#xff01; 易编橙终身成长社群创始团队嘉宾&#xff0c;橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。 易编橙&#xff1a;一个帮助编程小…

简单状压dp(以力扣464为例)

目录 1.状态压缩dp是啥&#xff1f; 2.题目分析 3.解题思路 4.算法分析 5.代码分析 6.代码一览 7.结语 1.状态压缩dp是啥&#xff1f; 顾名思义&#xff0c;状态压缩dp就是将原本会超出内存限制的存储改用更加有效的存储方式。简而言之&#xff0c;就是压缩dp的空间。 …

昇思MindSpore学习笔记6-05计算机视觉--SSD目标检测

摘要&#xff1a; 记录MindSpore AI框架使用SSD目标检测算法对图像内容识别的过程、步骤和方法。包括环境准备、下载数据集、数据采样、数据集加载和预处理、构建模型、损失函数、模型训练、模型评估等。 一、概念 1.模型简介 SSD目标检测算法 Single Shot MultiBox Detecto…

WebRTC群发消息API接口选型指南!怎么用?

WebRTC群发消息API接口安全性如何&#xff1f;API接口怎么优化&#xff1f; WebRTC技术在现代实时通信中占据了重要地位。对于需要实现群发消息功能的应用程序来说&#xff0c;选择合适的WebRTC群发消息API接口是至关重要的。AokSend将详细介绍WebRTC群发消息API接口的选型指南…

大数据开发者如何快速熟悉新公司业务

作为一名大数据开发工程师,进入一家新公司后快速熟悉业务是至关重要的。 目录 1. 了解产品形态故事1:电商平台的数据分析故事2:金融科技的风控系统故事3:社交媒体的推荐算法 2. 了解业务流程故事1:物流配送系统的优化故事2:医疗保险的理赔流程故事3:银行的贷款审批流程 3. 走…