目录
生成对抗网络(GAN)
CycleGAN模型训练
训练数据生成
下载开源项目CycleGAN
配置训练环境
开始训练
模型测试
可视化结果
生成对抗网络(GAN)
首先介绍一下什么是GAN网络,它是由生成器(Generator)和判别器(Discriminator)组成,二者均是由神经网络构成,通过不断的博弈来提高输出数据质量。
生成器的目的是学习真实数据的分布,从而能够生成与真实数据相似的新样本。它接收随机噪声作为输入,并通过一系列的神经网络层将其转化为具有特定特征的输出,试图欺骗判别器使其认为生成的数据是真实的。
判别器则负责区分输入数据是来自真实数据集还是由生成器生成的。它接收数据并输出一个概率值,表示该数据为真实数据的可能性。判别器通过不断学习来提高自己区分真实数据和生成数据的能力。
在训练过程中,生成器和判别器进行对抗性的博弈。生成器努力提高生成数据的质量,以使其能够骗过判别器;而判别器则努力提高自己的鉴别能力,不被生成器欺骗。通过不断地迭代训练,双方的性能逐渐提升,最终达到一种平衡状态,此时生成器能够生成非常逼真的样本,而判别器也具有较高的鉴别能力。
CycleGAN 是由 Jun-Yan Zhu 等人于 2017 年提出的,核心思想是通过两个生成器和两个判别器来实现无监督的图像转换2。它引入了循环一致性损失,确保转换是双向的且在转换前后能够保持图像的一致性。
CycleGAN 论文:https://arxiv.org/abs/1703.10593
上面这个图是该网络实现的风格迁移,感觉这个网络还是挺有意思的,就想着训练一下自己的数据集看下效果,那下面我们直接进入正题吧。
CycleGAN模型训练
注意:目前只尝试过图像对的训练,仅支持包含src和dst的数据集
GitHub项目:CycleGAN-based-train
整体目录架构:
训练数据生成
首先准备自己需要训练的数据集,需要包含源和目标,数据集的格式如下:
其中,O-HAZY NTIRE 2018是根目录,GT是源图像存放路径,hazy是目标图像存放路径
同时请准备好测试样本文件夹test-sample(可自定义),准备的一定要是图像文件夹,暂时不会支持单张图像的测试,格式如下:
数据集准备好后运行main.py文件,需要注意参数设置,具体请查看文件说明
# main.py
import os
import shutil
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')
from tqdm import tqdm
# ----------------------训练数据路径-----------------------#
# 仅支持包含src和dst的数据集(图像对)
# -------------------------------------------------------#
root = r'O-HAZY NTIRE 2018'
# --------------------------------------------------------#
# label1:src的路径名 | label2:dst的路径名
# --------------------------------------------------------#
label1 = 'GT'
label2 = 'hazy'
# -------------------------生成图像可视化-------------------------#
# !!! 在训练和测试均完成后进行结果检查时仅可设置为True,否则报错 !!!
# 该部分只是对结果的可视化,预测阶段请查看README
# -------------------------------------------------------------#
test = False
# ------------------------测试样本------------------------------#
test_data_path = './test-sample'
# ------------------------测试结果图像保存路径---------------------#
# !!! 里面是已经得到的测试结果和原图 !!!
# -------------------------------------------------------------#
results_path = './results/dehaze_cyclegan/test_latest/images/'
def make_data(src_path, dst_path, label):
src_path = src_path + f'/{label}/'
image_files = [f for f in os.listdir(src_path) if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp'))]
with tqdm(total=len(image_files)) as pbar:
for filename in image_files:
file_path = os.path.join(src_path, filename)
if filename.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):
image = Image.open(file_path)
target_file = os.path.join(dst_path, filename)
image.save(target_file)
pbar.update(1)
if __name__ == '__main__':
if not test:
# -------------------创建CycleGAN的训练数据路径-----------------------#
if not os.path.exists('dataset'):
os.makedirs('dataset')
if not os.path.exists('dataset/trainA'):
os.makedirs('dataset/trainA')
if not os.path.exists('dataset/trainB'):
os.makedirs('dataset/trainB')
# --------------------------检查图像对数量----------------------------#
num_images = len(os.listdir(root + f'/{label1}/'))
idx = np.arange(1, num_images + 1)
print(f'查找到{num_images}个图像对')
make_data(root, 'dataset/trainA/', label1)
make_data(root, 'dataset/trainB/', label2)
# ----------------------可视化阶段-----------------------------------#
else:
for f in os.listdir(test_data_path):
fake = f.split('.')[0] + '_fake.png'
real = f.split('.')[0] + '_real.png'
fig = plt.figure()
ax = plt.subplot(1, 2, 1)
img1 = Image.open(results_path + real)
plt.imshow(img1)
ax = plt.subplot(1, 2, 2)
img2 = Image.open(results_path + fake)
plt.imshow(img2)
plt.show()
下载开源项目CycleGAN
这一步如果下载了我上传的GitHub仓库的可以直接跳过,因为我已经将该项目放置在仓库里面,不需要重复下载。当然如果没有下载,请继续往下看
方式一:git clone GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch
方式二:百度网盘:pytorch-CycleGAN-and-pix2pix
链接:https://pan.baidu.com/s/1WC-kEonwm7bFujO72GZAcQ 提取码:jsw2
配置训练环境
在终端打开pytorch-CycleGAN-and-pix2pix,输入以下命令
pip install -r requirements.txt
开始训练
同样的,在终端打开该项目,输入以下指令:
python train.py --dataroot ./dataset --name dehaze_cyclegan --model cycle_gan
其中,只有 --name 是可改参数,可以自己命名模型的名称,但是修改后一定要与测试时的名称一致,请一定注意这一点
此外,如果在训练过程中出现“OSError: [WinError 1455] 页面文件太小,无法完成操作”报错信息,这是由于训练环境所在磁盘的虚拟内存不足导致,调整方法如下:
最后一步选择训练环境所在的磁盘进行修改即可
训练过程截图
模型测试
在终端打开该项目,输入以下指令:
cp ./checkpoints/dehaze_cyclegan/latest_net_G_A.pth ./checkpoints/dehaze_cyclegan/latest_net_G.pth
python test.py --dataroot ./test-sample --name dehaze_cyclegan --model test --no_dropout --direction AtoB
这里需要注意的是 --dataroot 是测试样本,可以自己调整路径,同时注意模型名称是否与训练的一致,不一致请修改
生成的结果会保存在results文件夹下,目录结构如下:
其中,fake是生成图像,real是原图像,同时所有图像尺寸均会被调整为
可视化结果
运行main.py文件,需要设置3个参数:test、test_data_path、results_path(test=True),详情请查看具体文件
我想要实现图像加雾,但是这个效果看起来一般吧,也有可能是图像数据对和训练轮次太少了。但不管怎么说,终究还是成功了嘛。