昇思25天学习打卡营第18天|Pix2Pix实现图像转换

Pix2Pix概述

Pix2Pix是基于条件生成对抗网络实现的一种深度学习图像转换模型。Pix2Pix是将cGAN应用于有监督的图像到图像翻译,包括生成器和判别器。

基础原理

cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

CGAN的目标损失函数为:

L_{cGAN}(G,D)=E_{(x,y)}[log(D(x,y))]+E_{(x,z)}[log(1-D(x,G(x,z)))]

目标函数是使判别器的损失最大化,而生成器的损失最小化。

pix2pix1

数据准备

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

from mindspore import dataset as ds
import matplotlib.pyplot as plt

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

创建网络

生成器G结构

使用U-Net,它分为两个部分,其中左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径,扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。

pix2pix2

定义UNet Skip Connection Block

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out

基于UNet的生成器

class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

基于PatchGAN的判别器

生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

Pix2Pix的生成器和判别器初始化

实例化Pix2Pix生成器和判别器

import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

训练

包括训练判别器和生成器。训练判别器的目的是最大程度地提高判别图像真伪的概率。训练生成器是希望能产生更好的虚假图像。

代码实现:

import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

推理

from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

总结

Pix2Pix作为GAN的一种变体,再生成图像和扩充数据方面有着重要作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/776666.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++ 附赠课程的知识点记录

(1) 静态变量的赋值 再一个例子: (2) 一般在定义类的赋值运算符函数时, operator ( const A& a ) 函数,应避免自赋值的情况,就是把对象 a 又赋值给 对象a 如同 a a 这样的情况…

类和对象深入理解

目录 static成员概念静态成员变量面试题补充代码1代码2代码3如何访问private中的成员变量 静态成员函数静态成员函数没有this指针 特性 友元友元函数友元类 内部类特性1特性2 匿名对象拷贝对象时的一些编译器优化 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接…

C++ | Leetcode C++题解之第217题存在重复元素

题目&#xff1a; 题解&#xff1a; class Solution { public:bool containsDuplicate(vector<int>& nums) {unordered_set<int> s;for (int x: nums) {if (s.find(x) ! s.end()) {return true;}s.insert(x);}return false;} };

【PB案例学习笔记】-27制作一个控制任务栏显示与隐藏的小程序

写在前面 这是PB案例学习笔记系列文章的第27篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

视频参考帧和重构帧复用

1、 视频编码中的参考帧和重构帧 从下图的编码框架可以看出&#xff0c;每编码一帧需要先使用当前帧CU(n)减去当前帧的参考帧CU&#xff08;n&#xff09;得到残差。同时&#xff0c;需要将当前帧的重构帧CU*&#xff08;n&#xff09;输出&#xff0c;然后再读取重构帧进行预测…

Pandas数据可视化详解:大案例解析(第27天)

系列文章目录 Pandas数据可视化解决不显示中文和负号问题matplotlib数据可视化seaborn数据可视化pyecharts数据可视化优衣库数据分析案例 文章目录 系列文章目录前言1. Pandas数据可视化1.1 案例解析&#xff1a;代码实现 2. 解决不显示中文和负号问题3. matplotlib数据可视化…

HTTP代理服务器:深度解析与应用

“随着互联网的飞速发展&#xff0c;HTTP代理服务器在网络通信中扮演着越来越重要的角色。它们作为客户端和服务器之间的中介&#xff0c;不仅优化了网络性能&#xff0c;还提供了强大的安全性和隐私保护功能。” 一、HTTP代理服务器的概念与作用 HTTP代理服务器是一种能够接…

Qt扫盲-QRect矩形描述类

QRect矩形描述总结 一、概述二、常用函数1. 移动类2. 属性函数3. 判断4. 比较计算 三、渲染三、坐标 一、概述 QRect类使用整数精度在平面中定义一个矩形。在绘图的时候经常使用&#xff0c;作为一个二维的参数描述类。 一个矩形主要有两个重要属性&#xff0c;一个是坐标&am…

前端面试题16(跨域问题)

跨域问题源于浏览器的同源策略&#xff08;Same-origin policy&#xff09;&#xff0c;这一策略限制了来自不同源的“写”操作&#xff08;比如更新、删除数据等&#xff09;&#xff0c;同时也限制了读操作。当一个网页尝试请求与自身来源不同的资源时&#xff0c;浏览器会阻…

设计模式探索:代理模式

1. 什么是代理模式 定义 代理模式是一种结构型设计模式&#xff0c;通过为其他对象提供一种代理以控制对这个对象的访问。代理对象在客户端和实际对象之间起到中介作用&#xff0c;可以在不改变真实对象的情况下增强或控制对真实对象的访问。 目的 代理模式的主要目的是隐…

着急,为啥AI叫好不叫座啊?

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 李彦宏在2024世界人工智能大会上说&#xff1a; 没有应用&#xff0c;光有基础模型&#xff0c;不管是开源还是闭源都一文不值&#xff0c;所以我从去年下半年开始讲&#xff0c;大家不要卷模型了&#xff0c;要去…

MySQL---事务管理

1.关于事务 理解和学习事务&#xff0c;不能只站在程序猿的角度来理解事务&#xff0c;而是要站在使用者&#xff08;用户&#xff09;的角度来理解事务。 比如支付宝转账&#xff0c;A转了B100块前&#xff0c;在程序猿的角度来看&#xff0c;是两条update操作&#xff0c;A …

PCDN技术如何提高内容分发效率?(贰)

PCDN技术通过以下方式提高内容分发效率: 1.利用用户设备作为分发节点:与传统的 CDN技术主要依赖中心化服务器不同&#xff0c; PCDN技术利用用户的设备作为内容分发的节点。当用户下载内容时&#xff0c;他们的设备也会成为内容分发的一部分&#xff0c;将已下载的内容传递给其…

项目部署_持续集成_Jenkins

1 今日内容介绍 1.1 什么是持续集成 持续集成&#xff08; Continuous integration &#xff0c; 简称 CI &#xff09;指的是&#xff0c;频繁地&#xff08;一天多次&#xff09;将代码集成到主干 持续集成的组成要素 一个自动构建过程&#xff0c; 从检出代码、 编译构建…

树状数组实现 查找逆序对

题意&#xff1a; 输入一个整数n。 接下来输入一行n个整数 。 1< < n ,且每个数字只会出现一次 题解&#xff1a; 按每个数字的大小存入树状数组 #include<bits/stdc.h> using namespace std; #define ll long long const int N10000; int arr[N]; ll a[N];…

Java中关于构造代码块和静态代码块的解析

构造代码块 特点&#xff1a;优先于构造方法执行,每new一次,就会执行一次 public class Person {public Person(){System.out.println("我是无参构造方法");}{System.out.println("我是构造代码块"); //构造代码块} }public class Test {public stati…

golang与以太坊交互

文章目录 golang与以太坊交互什么是go-ethereum与节点交互前的准备使用golang与以太坊区块链交互查询账户的余额使用golang生成以太坊账户使用golang生成以太坊钱包使用golang在账户之间转移eth安装使用solc和abigen生成bin和abi文件生成go文件使用golang在测试网上部署智能合约…

GD32MCU如何实现掉电数据保存?

大家在GD32 MCU应用时&#xff0c;是否会碰到以下应用需求&#xff1a;希望在MCU掉电时保存一定的数据或标志&#xff0c;用以记录一些关键的数据。 以GD32E103为例&#xff0c;数据的存储介质可以选择内部Flash或者备份数据寄存器。 如下图所示&#xff0c;片内Flash具有10年…

【综合能源】计及碳捕集电厂低碳特性及需求响应的综合能源系统多时间尺度调度模型

目录 1 主要内容 2 部分程序 3 实现效果 4 下载链接 1 主要内容 本程序是对《计及碳捕集电厂低碳特性的含风电电力系统源-荷多时间尺度调度方法》方法复现&#xff0c;非完全复现&#xff0c;只做了日前日内部分&#xff0c;并在上述基础上改进升级为电热综合电源微网系统&…

力扣习题--找不同

目录 前言 题目和解析 1、找不同 2、 思路和解析 总结 前言 本系列的所有习题均来自于力扣网站LeetBook - 力扣&#xff08;LeetCode&#xff09;全球极客挚爱的技术成长平台 题目和解析 1、找不同 给定两个字符串 s 和 t &#xff0c;它们只包含小写字母。 字符串 t…