Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算

Mindspore框架:CycleGAN模型实现图像风格迁移算法

Mindspore框架CycleGAN模型实现图像风格迁移|(一)CycleGAN神经网络模型构建
Mindspore框架CycleGAN模型实现图像风格迁移|(二)实例数据集(苹果2橘子)
Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算
Mindspore框架CycleGAN模型实现图像风格迁移|(四)CycleGAN模型训练
Mindspore框架CycleGAN模型实现图像风格迁移|(五)CycleGAN模型推理与资源下载

1. 损失函数计算

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成。

在这里插入图片描述
运算流程:
在这里插入图片描述

CycleGAN网络运转流程:图中苹果图片 𝑥 经过生成器 𝐺得到伪橘子 𝑌̂ ,然后将伪橘子 𝑌̂ 结果送进生成器 𝐹又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂ 与原苹果图片 𝑥一起计算出循环一致损失。

对生成器 𝐺 及其判别器 𝐷𝑌:
x-> 𝐺(𝑥)
目标损失函数定义为:
在这里插入图片描述
其中 𝐺试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥),而 𝐷𝑌的目标是区分翻译样本 𝐺(𝑥) 和真实样本 𝑦,生成器的目标是最小化这个损失函数以此来对抗判别器。即 在这里插入图片描述

对生成器G到F
x-> 𝐺(𝑥) ->F( 𝐺(𝑥))
在这里插入图片描述

这种循环损失计算,会捕捉这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。

2.损失函数实现

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss

生成器网络和判别器网络的优化器:

# 构建生成器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 构建判别器优化器
optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

3. 模型前向计算损失的过程

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

# 生成器
def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

# 判别器
def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

4.计算梯度和反向传播

from mindspore import value_and_grad

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/802148.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JAVA 异步编程(异步,线程,线程池)一

目录 1.概念 1.1 线程和进程的区别 1.2 线程的五种状态 1.3 单线程,多线程,线程池 1.4 异步与多线程的概念 2. 实现异步的方式 2.1 方式1 裸线程&#xff08;Thread&#xff09; 2.1 方式2 线程池&#xff08;Executor&#xff09; 2.1.1 源码分析 2.1.2 线程池创建…

新的“SCALE”软件允许为 AMD GPU 原生编译 CUDA 应用程序

虽然已经有各种努力&#xff0c;如HIPIFY来帮助将CUDA源代码转换为AMD GPU的可移植C代码&#xff0c;然后是之前AMD资助的ZLUDA&#xff0c;允许CUDA二进制文件通过CUDA库的直接替代品在AMD GPU上运行&#xff0c;但有一个新的竞争者&#xff1a;SCALE。SCALE现在作为GPGPU工具…

超算网络体系架构-资源层-平台层-服务层-应用层

目录 超算网络体系架构 我国超算基础设施 超算互联网相关标准研制方面 技术架构 资源层 基础资源 芯片多样 体系异构 高效存储 高速互连 资源池化 可隔离 可计量 互联网络 高带宽 低时延 高安全 平台层 算力接入 资源管理 算力调度 用户管理 交易管理 模…

基于springboot和mybatis的RealWorld后端项目实战二之实现tag接口

修改pom.xml 新增tag数据表 SET FOREIGN_KEY_CHECKS0;-- ---------------------------- -- Table structure for tags -- ---------------------------- DROP TABLE IF EXISTS tags; CREATE TABLE tags (id bigint(20) NOT NULL AUTO_INCREMENT,name varchar(255) NOT NULL,PR…

VBA学习(21):遍历文件夹(和子文件夹)中的文件

很多时候&#xff0c;我们都想要遍历文件夹中的每个文件&#xff0c;例如在工作表中列出所有文件名、对每个文件进行修改。VBA给我们提供了一些方式&#xff1a;&#xff08;1&#xff09;Dir函数&#xff1b;&#xff08;2&#xff09;File System Object。 使用Dir函数 Dir…

2024年大数据高频面试题(中篇)

文章目录 Kafka为什么要用消息队列为什么选择了kafkakafka的组件与作用(架构)kafka为什么要分区Kafka生产者分区策略kafka的数据可靠性怎么保证ack应答机制(可问:造成数据重复和丢失的相关问题)副本数据同步策略ISRkafka的副本机制kafka的消费分区分配策略Range分区分配策略…

三级域名能申请SSL证书吗?

在当今互联网时代&#xff0c;SSL证书已经成为了保障网站安全的重要工具&#xff0c;企业会为网站部署SSL证书来实现HTTPS加密以保护传输数据安全。然而随着业务的增长以及交易规模的扩大&#xff0c;为了更好的管理业务和内容&#xff0c;企业会在主域名的基础上划分二级域名&…

GitHub 令牌泄漏, Python 核心资源库面临潜在攻击

TheHackerNews网站消息&#xff0c;软件供应链安全公司 JFrog 的网络安全研究人员称&#xff0c;他们发现了一个意外泄露的 GitHub 令牌&#xff0c;可授予 Python 语言 GitHub 存储库、Python 软件包索引&#xff08;PyPI&#xff09;和 Python 软件基金会&#xff08;PSF&…

【RabbitMQ】一文详解消息可靠性

目录&#xff1a; 1.前言 2.生产者 3.数据持久化 4.消费者 5.死信队列 1.前言 RabbitMQ 是一款高性能、高可靠性的消息中间件&#xff0c;广泛应用于分布式系统中。它允许系统中的各个模块进行异步通信&#xff0c;提供了高度的灵活性和可伸缩性。然而&#xff0c;这种通…

网络准入控制设备是什么?有哪些?网络准入设备臻品优选

小李&#xff1a;“小张&#xff0c;最近公司网络频繁遭遇外部攻击&#xff0c;我们得加强一下网络安全了。” 小张&#xff1a;“是啊&#xff0c;我听说实施网络准入控制是个不错的选择。但具体什么是网络准入控制设备&#xff1f;我们有哪些选择呢&#xff1f;” 小李微笑…

2024Datawhale AI夏令营---Inclusion・The Global Multimedia Deepfake Detection--学习笔记

赛题背景&#xff1a; 其实总结起来就是一句话&#xff0c;这个项目是基于目前的深度伪装技术&#xff0c;就是通过大量人脸的原数据集进行模型训练之后&#xff0c;能够生成伪造的人脸视频。这项目就是教我们如何去实现这个DeepFake技术。 Task1:了解Deepfake和跑通baseline …

Python项目部署到Linux生产环境(uwsgi+python+flask+nginx服务器)

1.安装python 我这里是3.9.5版本 安装依赖&#xff1a; yum install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gcc make -y 根据自己的需要下载对应的python版本&#xff1a; cd local wget https://www.python.org/ftp…

开发实战经验分享:互联网医院系统源码与在线问诊APP搭建

作为一名软件开发者&#xff0c;笔者有幸参与了多个互联网医院系统的开发项目&#xff0c;并在此过程中积累了丰富的实战经验。本文将结合我的开发经验&#xff0c;分享互联网医院系统源码的设计与在线问诊APP的搭建过程。 一、需求分析 在开发任何系统之前&#xff0c;首先要…

成像光谱遥感技术中的AI革命:ChatGPT

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境&#xff0c;是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型&#xff0c;在理解和生成人类语言方面表现出了非凡的能力&#xff0c;ChatGPT在遥感中的应用&#xff0c;人工智能在…

【STM32】RTT-Studio中HAL库开发教程三:IIC通信--AHT20

文章目录 一、I2C总线通信协议二、AHT20传感器介绍三、STM32CubeMX配置硬件IIC四、RTT中初始化配置五、具体实现代码六、实验现象 一、I2C总线通信协议 使用奥松的AHT20温湿度传感器&#xff0c;对环境温湿度进行采集。AHT20采用的是IIC进行通信&#xff0c;可以使用硬件IIC或…

2. KNN分类算法与鸢尾花分类任务

鸢尾花分类任务 1. 鸢尾花分类步骤1.1 分析问题&#xff0c;搞定输入和输出1.2 每个类别各采集50朵花1.3 选择一种算法&#xff0c;完成输入到输出的映射1.4 第四步&#xff1a;部署&#xff0c;集成 2. KNN算法原理2.1 基本概念2.2 核心理念2.3 训练2.4 推理流程 3. 使用 skle…

Word参考文献交叉引用

前言 Word自带交叉引用功能&#xff0c;可在正文位置引用文档内自动编号的段落&#xff0c;同时创建超链接&#xff0c;适用于参考文献的引用。使用此方法对参考文献进行引用后&#xff0c;当参考文献的编号发生变化时&#xff0c;只需要更新域即可与正文中的引用相对应。下文…

vue3+TS从0到1手撸后台管理系统

1.路由配置 1.1路由组件的雏形 src\views\home\index.vue&#xff08;以home组件为例&#xff09; 1.2路由配置 1.2.1路由index文件 src\router\index.ts //通过vue-router插件实现模板路由配置 import { createRouter, createWebHashHistory } from vue-router import …

【15】Android基础知识之Window(一)

概述 这篇文章纠结了很久&#xff0c;在想需要怎么写&#xff1f;因为window有关的篇幅&#xff0c;如果需要讲起来那可太多了。从层级&#xff0c;或是从关联&#xff0c;总之不是很好开口。这次也下定决心&#xff0c;决定从浅入深的讲讲window这个东西。 Window Window是…

鸿蒙特色物联网实训室

一、 引言 在当今这个万物皆可连网的时代&#xff0c;物联网&#xff08;IoT&#xff09;正以前所未有的速度改变着我们的生活和工作方式。它如同一座桥梁&#xff0c;将实体世界与虚拟空间紧密相连&#xff0c;让数据成为驱动决策和创新的关键力量。随着物联网技术的不断成熟…