基于 CycleGAN 对抗网络的自定义数据集训练

目录

生成对抗网络(GAN)

CycleGAN模型训练

训练数据生成

下载开源项目CycleGAN

配置训练环境

开始训练

模型测试

可视化结果


生成对抗网络(GAN)

        首先介绍一下什么是GAN网络,它是由生成器(Generator)和判别器(Discriminator)组成,二者均是由神经网络构成,通过不断的博弈来提高输出数据质量。

        生成器的目的是学习真实数据的分布,从而能够生成与真实数据相似的新样本。它接收随机噪声作为输入,并通过一系列的神经网络层将其转化为具有特定特征的输出,试图欺骗判别器使其认为生成的数据是真实的。

        判别器则负责区分输入数据是来自真实数据集还是由生成器生成的。它接收数据并输出一个概率值,表示该数据为真实数据的可能性。判别器通过不断学习来提高自己区分真实数据和生成数据的能力

        在训练过程中,生成器和判别器进行对抗性的博弈。生成器努力提高生成数据的质量,以使其能够骗过判别器;而判别器则努力提高自己的鉴别能力,不被生成器欺骗。通过不断地迭代训练,双方的性能逐渐提升,最终达到一种平衡状态,此时生成器能够生成非常逼真的样本,而判别器也具有较高的鉴别能力。

CycleGAN 是由 Jun-Yan Zhu 等人于 2017 年提出的,核心思想是通过两个生成器和两个判别器来实现无监督的图像转换2。它引入了循环一致性损失,确保转换是双向的且在转换前后能够保持图像的一致性。

CycleGAN 论文:https://arxiv.org/abs/1703.10593

上面这个图是该网络实现的风格迁移,感觉这个网络还是挺有意思的,就想着训练一下自己的数据集看下效果,那下面我们直接进入正题吧。

CycleGAN模型训练

注意:目前只尝试过图像对的训练,仅支持包含src和dst的数据集

GitHub项目:CycleGAN-based-train

整体目录架构:

训练数据生成

首先准备自己需要训练的数据集,需要包含源和目标,数据集的格式如下:

其中,O-HAZY NTIRE 2018是根目录,GT是源图像存放路径,hazy是目标图像存放路径

同时请准备好测试样本文件夹test-sample(可自定义),准备的一定要是图像文件夹,暂时不会支持单张图像的测试,格式如下:

数据集准备好后运行main.py文件,需要注意参数设置,具体请查看文件说明

# main.py

import os
import shutil
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use('TkAgg')
from tqdm import tqdm

# ----------------------训练数据路径-----------------------#
#   仅支持包含src和dst的数据集(图像对)
# -------------------------------------------------------#
root = r'O-HAZY NTIRE 2018'
# --------------------------------------------------------#
#       label1:src的路径名  |  label2:dst的路径名
# --------------------------------------------------------#
label1 = 'GT'
label2 = 'hazy'
# -------------------------生成图像可视化-------------------------#
#   !!! 在训练和测试均完成后进行结果检查时仅可设置为True,否则报错  !!!
#   该部分只是对结果的可视化,预测阶段请查看README
# -------------------------------------------------------------#
test = False
# ------------------------测试样本------------------------------#
test_data_path = './test-sample'
# ------------------------测试结果图像保存路径---------------------#
# !!!   里面是已经得到的测试结果和原图     !!!
# -------------------------------------------------------------#
results_path = './results/dehaze_cyclegan/test_latest/images/'


def make_data(src_path, dst_path, label):
    src_path = src_path + f'/{label}/'
    image_files = [f for f in os.listdir(src_path) if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp'))]
    
    with tqdm(total=len(image_files)) as pbar:
        for filename in image_files:
            file_path = os.path.join(src_path, filename)
            if filename.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):
                image = Image.open(file_path)
                target_file = os.path.join(dst_path, filename)
                image.save(target_file)
                
            pbar.update(1)



if __name__ == '__main__':
    if not test:
        # -------------------创建CycleGAN的训练数据路径-----------------------#
        if not os.path.exists('dataset'):
            os.makedirs('dataset')
        if not os.path.exists('dataset/trainA'):
            os.makedirs('dataset/trainA')
        if not os.path.exists('dataset/trainB'):
            os.makedirs('dataset/trainB')

        # --------------------------检查图像对数量----------------------------#
        num_images = len(os.listdir(root + f'/{label1}/'))
        idx = np.arange(1, num_images + 1)
        print(f'查找到{num_images}个图像对')

        make_data(root, 'dataset/trainA/', label1)
        make_data(root, 'dataset/trainB/', label2)

    # ----------------------可视化阶段-----------------------------------#
    else:
        for f in os.listdir(test_data_path):
            fake = f.split('.')[0] + '_fake.png'
            real = f.split('.')[0] + '_real.png'

            fig = plt.figure()
            ax = plt.subplot(1, 2, 1)
            img1 = Image.open(results_path + real)
            plt.imshow(img1)

            ax = plt.subplot(1, 2, 2)
            img2 = Image.open(results_path + fake)
            plt.imshow(img2)

            plt.show()

下载开源项目CycleGAN

这一步如果下载了我上传的GitHub仓库的可以直接跳过,因为我已经将该项目放置在仓库里面,不需要重复下载。当然如果没有下载,请继续往下看

方式一:git clone GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch

方式二:百度网盘:pytorch-CycleGAN-and-pix2pix

链接:https://pan.baidu.com/s/1WC-kEonwm7bFujO72GZAcQ        提取码:jsw2

配置训练环境

终端打开pytorch-CycleGAN-and-pix2pix,输入以下命令

pip install -r requirements.txt

开始训练

同样的,在终端打开该项目,输入以下指令:

python train.py --dataroot ./dataset --name dehaze_cyclegan --model cycle_gan

其中,只有 --name 是可改参数,可以自己命名模型的名称,但是修改后一定要与测试时的名称一致,请一定注意这一点

此外,如果在训练过程中出现“OSError: [WinError 1455] 页面文件太小,无法完成操作”报错信息,这是由于训练环境所在磁盘虚拟内存不足导致,调整方法如下:

最后一步选择训练环境所在的磁盘进行修改即可

训练过程截图

模型测试

在终端打开该项目,输入以下指令:

cp ./checkpoints/dehaze_cyclegan/latest_net_G_A.pth ./checkpoints/dehaze_cyclegan/latest_net_G.pth
python test.py --dataroot ./test-sample --name dehaze_cyclegan --model test --no_dropout --direction AtoB

这里需要注意的是 --dataroot 是测试样本,可以自己调整路径,同时注意模型名称是否与训练的一致,不一致请修改

生成的结果会保存在results文件夹下,目录结构如下:

其中,fake是生成图像,real是原图像,同时所有图像尺寸均会被调整为256\times 256

可视化结果

运行main.py文件,需要设置3个参数:test、test_data_path、results_path(test=True),详情请查看具体文件

我想要实现图像加雾,但是这个效果看起来一般吧,也有可能是图像数据对和训练轮次太少了。但不管怎么说,终究还是成功了嘛。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/875510.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【CTF Web】BUUCTF Upload-Labs-Linux Pass-13 Writeup(文件上传+PHP+文件包含漏洞+PNG图片马)

Upload-Labs-Linux 1 点击部署靶机。 简介 upload-labs是一个使用php语言编写的,专门收集渗透测试和CTF中遇到的各种上传漏洞的靶场。旨在帮助大家对上传漏洞有一个全面的了解。目前一共20关,每一关都包含着不同上传方式。 注意 1.每一关没有固定的…

Modbus协议02:存储区简介

视频链接:【2】Modbus协议存储区说明_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV11G4y1W7pU?p2&vd_sourceb5775c3a4ea16a5306db9c7c1c1486b5 1.为什么需要存储区、存储区类型及代号 2.Modbus存储区范围及地址模型

SLM561A​​系列 60V 10mA到50mA线性恒流LED驱动芯片 为智能家居照明注入新活力

SLM561A系列选型参考: SLM561A10ae-7G SOD123 SLM561A15ae-7G SOD123 SLM561A20ae-7G SOD123 SLM561A25ae-7G SOD123 SLM561A30ae-7G SOD123 SLM561A35ae-7G SOD123 SLM561A40ae-7G SOD123 SLM561A45ae-7G SOD123 SLM561A50ae-7G SOD123 …

在Webmin上默认状态无法正常显示 Mariadb V11.02及以上版本

OS: Armbian OS 24.5.0 Bookworm Mariadb V11.02及以上版本 Webmin:V2.202 非常小众的问题,主要是记录一下。 如题 Webmin 默认无法正常显示 Mariadb V11.02及以上版本 如果对 /etc/webmin/mysql/config 文件作相应调整就可以再现Mariadb管理界面。 路径…

AI prompt(提示词)

# 好用的用于学习的AI提示词 ## 费曼学习法 请使用费曼学习法,用简单的语言解释(量子力学)是什么,并提供一个简单的例子来说明它如何应用 ## 帕累托法则(80/20原则) 将(量子力学)最…

09_Tensorflow2图像处理大赏:让你的图片笑出AI感,惊艳朋友圈!

1. 图像处理案例 1.1 逆时针旋转90度 import tensorflow as tf import matplotlib.pyplot as plt import matplotlib.cm as cm import numpy import osdef show_pic(pic,name,cmapNone):显示图像plt.imshow(pic,cmapcmap) plt.axis(off) # 打开坐标轴为 on # 设置图像标题…

【C++】认识C++(前言)

🦄个人主页:小米里的大麦-CSDN博客 🎏所属专栏:C_小米里的大麦的博客-CSDN博客 🎁代码托管:C: 探索C编程精髓,打造高效代码仓库 (gitee.com) ⚙️操作环境:Visual Studio 2022 目录 一、本节概述 二、什么是C 三、C发展史 四…

苏茵茵:以时尚之名,诠释品质生活

在女性追求个性化与自我表达的今天,时尚早已超越了简单的穿着打扮,它成为女性展现自我风格、彰显独特魅力的重要方式。从广泛的兴趣爱好到精心雕琢的个人风格,每一处细节都闪耀着女性对个性独特与自我表达的深切渴望。正是这股不可阻挡的潮流…

Unity6 + UE5.4 PSO缓存实践记录

题图(取自COD冷战的着色器编译提示) PSO(管线状态对象 Pipeline State Object)是伴随现代图形API(DirectX12、Vulkan、Metal)而出现的概念,它本质上是单次绘制时渲染管线所处的状态信息的集合&…

Blender渲染太慢怎么办?blender云渲染已开启

动画行业蓬勃发展,动画制作软件亦持续推陈出新,当制作平台日益丰富,创作难度降低,创作效率提升,如何高效完成复杂动画的渲染就成了从业者更关心的问题。 云渲染技术的出现,无疑为动画制作者提供了前所未有…

kafka原理剖析及实战演练

一、消息系统概述 一)消息系统按消息发送模型分类 1、peer-to-peer(单播) 特点: 一般基于pull或polling接收消息发送对队列中的消息被一个而且仅仅一个接收者所接收,即使有多个接收者在同一队列中侦听同一消息即支持异…

利用熵权法进行数值评分计算——算法过程

1、概述 在软件系统中,研发人员常常遇上需要对系统内的某种行为/模型进行评分的情况。例如根据系统的各种漏洞情况对系统安全性进行评分、根据业务员最近操作系统的情况对业务员工作状态进行打分等等。显然研发人员了解一种或者几种标准评分算法是非常有利于开展研…

中控室控制台处在自动状态什么意思

在现代工业和智能控制系统中,中控室控制台作为集中控制和管理各种设备、系统和流程的核心,扮演着至关重要的角色。当提到中控室控制台处在自动状态时,这通常意味着控制台已经切换到一种高度智能化的工作模式,能够自动调整和管理各…

【SQL】百题计划:SQL判断条件OR的使用。

【SQL】百题计划-20240912 Select name, population, area from World where area>3000000 or population > 25000000;

品读 Java 经典巨著《Effective Java》90条编程法则,第4条:通过私有构造器强化不可实例化的能力

文章目录 【前言】欢迎订阅【品读《Effective Java》】系列专栏java.lang.Math 类的设计经验总结 【前言】欢迎订阅【品读《Effective Java》】系列专栏 《Effective Java》是 Java 开发领域的经典著作,作者 Joshua Bloch 以丰富的经验和深入的知识,全面…

网络运输层之(1)TCP协议基础

网络运输层之(1)TCP协议基础 Author: Once Day Date: 2024年9月12日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章可参考专栏: 通信网络技术_Once-Day的博客-…

cv2.bitwise_or 提取ROI区域

原图如下所示,想提取圆形ROI区域,红色框 img np.ones(ori_img.shape, dtype"uint8") img img * 255 cv2.circle(img, (50,50), 50, 0, -1) self.bitwiseOr cv2.bitwise_or(ori_img, circle)使用一个和原图尺寸一致的图像做mask,图白圆黑 以…

通信工程学习:什么是PC永久连接、SPC软永久连接

一、PC永久连接 PC(Permanent Connection)永久连接是一种由网管系统通过网管协议建立的长期稳定的连接方式。在ASON(自动交换光网络)中,PC永久连接沿袭了传统光网络的连接建立形式,其特点主要包括&#xff…

视频监控平台是如何运作的?EasyCVR视频汇聚平台的高效策略与实践

随着科技的飞速发展,视频监控平台在社会安全、企业管理、智慧城市构建等领域发挥着越来越重要的作用。一个高效的视频监控平台,不仅依赖于先进的硬件设备,更离不开强大的视频处理技术作为支撑。这些平台集成了多种先进的视频技术,…

微波无源器件 OMT 2 倍频程带宽紧凑十字转门OMT

摘要: 一个64%瞬态带宽的可变比例十字转门OMT用于在所谓的延伸C频带卫星链接被提出。所体术的结构通过在四个输出矩形波导结处添加一个拓宽的单阶梯来克服现在的实际带宽限制。这个明智的(judicious)调整,和减高度波导和E面弯头的和功率合成器的使用,保证…