- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
一、CycleGAN原理
(一)CycleGAN的原理结构:
CycleGAN(循环生成对抗网络)是一种生成对抗网络(GAN),它能够在没有成对训练样本的情况下,将一个域(比如照片中的马)转换成另一个域(比如照片中的斑马)。CycleGAN 主要由两部分组成:两个生成器和一个判别器。生成器的作用是在两个域之间进行转换,而判别器则用于判断输入的图像是真实的还是由生成器生成的。
1. 生成器(Generator)
CycleGAN 包含两个生成器,一个用于将图像从域A转换到域B,另一个则用于将图像从域B转换回域A。每个生成器都是一个神经网络,通常采用卷积神经网络(CNN)的结构。生成器的目标是学习如何将输入图像转换成目标域中的图像,同时欺骗判别器,使其认为生成的图像是真实的。
2. 判别器(Discriminator)
与生成器相对应,CycleGAN 也有两个判别器,一个用于判断图像是否属于域A,另一个用于判断图像是否属于域B。判别器同样通常采用卷积神经网络结构,其目标是能够准确地区分真实图像和由生成器生成的图像。
3. 循环一致性损失(Cycle Consistency Loss)
为了确保在没有成对训练样本的情况下,生成器能够学习到有效的映射,CycleGAN 引入了循环一致性损失。这个损失函数确保当图像从域A转换到域B,然后再转换回域A时,得到的图像与原始图像尽可能相似。同样地,从域B转换到域A,再转换回域B也应该保持循环一致性。
4. 对抗损失(Adversarial Loss)
对抗损失是GAN的核心,它确保生成器能够生成足以欺骗判别器的图像。对于CycleGAN,每个生成器都需要最小化对抗损失,使其生成的图像能够在相应的判别器上获得高分。
5. identity损失(Identity Loss)
除了上述两种损失函数,CycleGAN 还引入了identity损失,以确保当输入图像已经属于目标域时,生成器能够返回原始图像。这有助于保持生成器在训练过程中的稳定性,并防止过度拟合。
6.训练过程:
在训练过程中,CycleGAN 通过不断调整生成器和判别器的参数来最小化上述损失函数。生成器尝试生成越来越能欺骗判别器的图像,而判别器则尝试越来越准确地识别真实图像和生成图像。通过这种对抗过程,生成器最终能够学习到如何在两个域之间进行有效的转换。
7.应用:
CycleGAN 在计算机视觉领域有广泛的应用,如风格迁移、季节变换、照片增强等,它为那些没有成对训练数据的图像转换任务提供了一种有效的解决方案。
(二)CycleGAN和传统GAN的对比
CycleGAN与普通GAN相比,有几个特殊之处,这些特性使得CycleGAN适合于图像到图像的转换任务,尤其是在没有成对训练数据的情况下:
- 无配对数据要求:
- 普通GAN通常需要成对的训练数据,即每个输入图像都有一个对应的输出图像。而CycleGAN不需要这样的成对数据,它可以学习一个域(比如照片)到另一个域(比如画作)的转换,即使没有直接的成对映射。
- 循环一致性损失:
- CycleGAN引入了循环一致性损失(Cycle Consistency Loss),这是其核心的创新之一。这个损失函数确保当图像从源域转换到目标域,然后再转换回源域时,能够尽可能地恢复到原始图像。这样的循环保证了即使在没有成对数据的情况下,转换过程也是合理的。
- 两个生成器和两个判别器:
- CycleGAN包含两个生成器,每个生成器负责一个方向的转换(从域A到域B和从域B到域A)。同时,也有两个判别器,分别用于判断图像是否属于域A或域B。这种结构使得CycleGAN能够同时学习两个域之间的映射。
- 对抗性损失和身份损失:
- 除了循环一致性损失,CycleGAN还结合了对抗性损失和身份损失。对抗性损失确保生成器能够生成足以欺骗判别器的图像,而身份损失确保当输入图像已经属于目标域时,生成器能够返回原始图像,这有助于保持生成器的稳定性。
- 多样化的应用场景:
- 由于不需要成对数据,CycleGAN可以应用于多种不同的图像到图像的转换任务,如风格迁移、季节变换、艺术作品风格化等,而这些任务在普通GAN中很难实现。
- 更强的泛化能力:
- 由于循环一致性损失的设计,CycleGAN在训练过程中学习到了更加泛化的特征表示,这使得它在面对未见过的数据时,也能表现出较好的转换效果。
二、CycleGAN代码分析
第一部分
导入的库的作用的分析
argparse
:
- 这是一个Python的标准库,用于解析命令行参数。在训练或测试CycleGAN时,可以通过命令行传入各种参数,如学习率、批量大小、数据集路径等,
argparse
可以帮助程序解析这些参数。itertools
:
- 这也是Python的标准库之一,它提供了多种迭代操作的函数。在处理数据集或进行模型训练时,可能会用到
itertools
来生成迭代器,例如用于循环遍历数据批次。torchvision.utils
:
- 这个模块包含了多个实用函数,用于处理和展示图像。
save_image
函数用于将Tensor保存为图像文件,make_grid
函数用于将多个图像拼接成一个网格图像,这在可视化训练过程中的图像时非常有用。torch.utils.data
:
- 这个模块提供了数据加载和处理的工具,
DataLoader
类是其中的核心,它允许我们以批量方式加载数据,并提供数据并行处理的功能,这对于实现高效的数据加载非常重要。models
:datasets
:utils
:这三个是自带的数据,都全部抓取,导入模型中
import argparse
import itertools
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from models import *
from datasets import *
from utils import *
import torch
第二部分
代码分析
parser = argparse.ArgumentParser()
这一行代码创建了一个 ArgumentParser
对象,这个对象将用于解析命令行参数。
argparse 是 Python 的一个库,它提供了一个方便的方式来解析命令行参数。ArgumentParser 类是 argparse
库中最重要的类,它设置了解析器的基本行为,并允许开发者添加参数。创建了解析器对象之后,开发者可以通过调用 add_argument() 方法来指定程序需要接受的命令行参数。每个
add_argument() 调用都为一个特定的参数添加了规则,包括参数的名称、类型、帮助信息等。
接下来一段主要就是设置模型的各个参数,接下来着重关注每个参数的作用
b1
和b2
:
- 这两个参数是Adam优化器的超参数,分别代表一阶矩估计的指数衰减率和二阶矩估计的指数衰减率。它们用于计算梯度的一阶矩估计(mean)和二阶矩估计(uncentered
variance),并对学习率进行自适应调整。batch_size
:
- 批量大小,即每次训练时传递给模型的样本数量。批量大小的大小会影响模型的收敛速度和稳定性。
channels
:
- 图像的通道数。对于彩色图像,通常是3(红、绿、蓝通道)。
checkpoint_interval
:
- 检查点间隔,即在训练过程中,每过多少个周期(epoch)保存一次模型的权重。
dataset_name
:
- 数据集的名称。CycleGAN可以用于不同的图像到图像的转换任务,这个参数指定了当前使用的数据集。
decay_epoch
:
- 学习率衰减开始的周期。通常在训练过程中,学习率会随着训练的进行而逐渐减小,
decay_epoch
指定了何时开始衰减。epoch
:
- 当前周期数。这个参数通常在训练开始时设置为0,并在每个周期结束时递增。
img_height
和img_width
:
- 输入图像的高度和宽度。CycleGAN要求所有输入图像具有相同的大小,这些参数指定了图像的尺寸。
lambda_cyc
:
- 循环一致性损失的权重。在CycleGAN中,循环一致性损失用于确保图像在转换回原始域后尽可能接近原始图像。
lambda_cyc
控制这个损失在总损失中的重要性。lambda_id
:
- 身份损失的权重。身份损失确保当输入图像已经属于目标域时,生成器能够返回原始图像。
lambda_id
控制这个损失在总损失中的重要性。lr
:
- 学习率,即模型参数在每次更新时的调整幅度。学习率的选择对模型的训练至关重要,过大的学习率可能导致模型无法收敛,过小的学习率可能导致训练过程缓慢。
n_cpu
:
- 用于数据加载的CPU核心数。在加载数据时,可以使用多个CPU核心来并行处理,以提高数据加载的效率。
n_epochs
:
- 总的训练周期数。一个周期是指模型对整个训练数据集进行一次完整的遍历。
n_residual_blocks
:
- 残差块的数量。在CycleGAN的生成器中,残差块用于构建网络的主干,增加残差块的数量可以增加模型的容量和表达能力。
sample_interval
:
- 采样间隔,即在训练过程中,每过多少个批次保存一次生成的图像样本。这有助于监控训练过程中模型的表现。 这些参数共同决定了CycleGAN模型的训练过程和表现。通过调整这些参数,可以优化模型的性能,并适应不同的训练环境和任务需求。
parser.parse_args()
是 argparse 库中的一个方法,它用于解析命令行参数。在代码片段中,parser
是一个 ArgumentParser 对象,它已经定义了程序可以接受的参数和它们的属性。当调用 parse_args()
方法时,会发生以下几件事情:
1.解析命令行参数: parse_args() 会检查命令行中提供的参数,并根据 parser
对象中定义的规则来解析它们。如果命令行参数的格式正确,它们将被转换成相应的数据类型(例如,字符串、整数、浮点数等)。
2.填充 args 对象:解析后的参数会被填充到一个名为 args 的命名空间对象中。这个对象包含了所有解析后的参数值,可以通过点号操作符访问这些值,例如args.batch_size。
3.提供默认值: 如果在命令行中没有提供某个参数的值,parse_args() 会使用在 add_argument() 调用中指定的默认值。
4. 错误处理: 如果命令行参数的格式不正确或者提供了未定义的参数,parse_args()
会自动打印出错误信息和一个用法提示,并退出程序。
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=50, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=16, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=7.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)
第三部分
代码知识点
在CycleGAN的实现中,使用了三种损失函数来指导模型的训练。这些损失函数分别对应于不同的训练目标,下面是每个损失函数的
解释:
criterion_GAN
(对抗损失):
- 这是一种用于计算生成对抗网络(GAN)中的对抗性损失的函数。在CycleGAN中,每个生成器都试图生成能够欺骗对应判别器的图像。通常,GAN使用二进制交叉熵损失(BCELoss)作为对抗性损失,但在某些情况下,也可以使用均方误差损失(MSELoss)。
torch.nn.MSELoss()
计算的是预测值和目标值之间的均方误差,这种损失函数在图像生成任务中可以提供平滑的梯度,有助于生成器的学习。criterion_cycle
(循环一致性损失):
- 这个损失函数用于确保图像在经过两个生成器(从域A到域B,再从域B回到域A)的转换后,能够尽可能地恢复到原始图像。循环一致性损失使用L1范数(绝对值误差)来计算,因为L1损失对异常值不那么敏感,能够产生更清晰的图像。
torch.nn.L1Loss()
计算的是预测值和目标值之间的平均绝对误差。criterion_identity
(身份损失):
- 身份损失确保当输入图像已经属于目标域时,生成器能够返回原始图像。这个损失函数也使用L1损失来计算,因为它能够帮助生成器学习到保持图像结构的映射。例如,如果我们将一张风景照片作为输入,生成器应该输出与输入非常相似的风景照片,而不是将其转换为另一幅完全不同的图像。
总的来说,这三种损失函数共同作用于CycleGAN的训练过程中,使得生成器能够学习到在不同域之间进行有效转换的同时,保持图像的循环一致性和身份一致性。通过平衡这些损失函数,CycleGAN能够在没有成对训练样本的情况下,实现高质量的图像到图像的转换。
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
cuda = torch.cuda.is_available()
input_shape = (opt.channels, opt.img_height, opt.img_width)
第四部分
这段代码是CycleGAN模型训练脚本的的一部分,它涉及到创建生成器和判别器模型、将模型移动到GPU上(如果可用)、以及加载或初始化模型的权重。下面是分段解释:
创建生成器和判别器模型:
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
这四行代码创建了四个模型:两个生成器G_AB
和G_BA
,以及两个判别器D_A
和D_B
。GeneratorResNet
是生成器的类,它接收输入图像的形状和残差块的数量作为参数。Discriminator
是判别器的类,它接收输入图像的形状作为参数。
将模型和损失函数移动到GPU上:
if cuda:
G_AB = G_AB.cuda()
G_BA = G_BA.cuda()
D_A = D_A.cuda()
D_B = D_B.cuda()
criterion_GAN.cuda()
criterion_cycle.cuda()
criterion_identity.cuda()
这段代码检查是否使用了GPU(cuda
变量为True),如果是,则将创建的模型和损失函数移动到GPU上。这样可以利用GPU加速模型的训练和计算。
加载或初始化模型的权重:
if opt.epoch != 0:
# 加载预训练模型
G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
# 初始化权重
G_AB.apply(weights_init_normal)
G_BA.apply(weights_init_normal)
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)
这段代码检查opt.epoch
是否不等于0,如果不等于0,则从磁盘加载对应周期的预训练模型权重。这些权重保存在以数据集名称和周期数命名的文件中。如果opt.epoch
等于0,则表示从头开始训练,这时会使用weights_init_normal
函数来初始化模型的权重。weights_init_normal
函数通常是一个自定义函数,用于将模型的权重初始化为正态分布。
总的来说,这段代码负责创建CycleGAN所需的模型,并将它们配置为在GPU上运行(如果可用),然后根据是否有预训练的权重来加载或初始化这些模型的权重。
第五部分
这段代码涉及到为CycleGAN的生成器和判别器创建优化器,设置学习率更新调度器,以及定义Tensor类型和重播缓冲区。下面是分段解释:
创建优化器:
optimizer_G = torch.optim.Adam(
itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
这五行代码创建了三个优化器:一个用于两个生成器G_AB
和G_BA
的联合参数(optimizer_G
),以及两个分别用于判别器D_A
和D_B
的优化器(optimizer_D_A
和optimizer_D_B
)。所有优化器都是使用Adam算法,它是一种适用于大规模数据和高维空间的优化算法。lr
参数设置学习率,betas
参数设置Adam算法中的两个超参数,分别是第一和第二矩估计的指数衰减率。
设置学习率更新调度器:
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
这四行代码为每个优化器创建了一个学习率更新调度器。LambdaLR
是一个基于自定义函数的学习率调度器,它允许用户根据迭代次数来调整学习率。LambdaLR
的lr_lambda
参数是一个函数,它根据当前周期数、总周期数和开始衰减的周期数来计算学习率的乘数。LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
是一个函数,它返回学习率乘数。
定义Tensor类型:
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
这行代码定义了一个Tensor类型,如果cuda
为True(表示使用GPU),则使用torch.cuda.FloatTensor
,否则使用torch.Tensor
。这个Tensor类型将在后续代码中用于创建张量。
创建重播缓冲区:
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
这两行代码创建了两个重播缓冲区,用于存储生成的样本。这些缓冲区在训练判别器时用于提供历史的生成样本,以提高训练的稳定性。ReplayBuffer
是一个自定义类,它提供了一个固定大小的存储,用于存储最近生成的样本。
第六部分
这段代码主要涉及到创建数据加载器、定义一个用于保存样本图像的函数,并且在训练过程中定期保存生成的样本。下面是分段解释:
1-7.定义图像转换操作:
transforms_ = [
transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((opt.img_height, opt.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
这行代码定义了一个列表transforms_
,其中包含了多个图像转换操作。这些操作包括:
Resize
:将图像尺寸放大到原始高度的1.12倍,使用双三次插值方法。RandomCrop
:从放大的图像中随机裁剪出原始尺寸的图像。RandomHorizontalFlip
:以一定的概率水平翻转图像。ToTensor
:将图像转换为PyTorch张量。Normalize
:对图像进行归一化处理,将像素值范围从[0, 1]转换为[-1, 1]。
8-12. 创建训练数据加载器:
dataloader = DataLoader(
ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True),
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.n_cpu,
)
这五行代码创建了一个训练数据加载器dataloader
。ImageDataset
是一个自定义类,它从指定的数据集路径加载图像,并应用转换操作。DataLoader
是一个迭代器,它允许我们按批次加载数据,并提供数据混洗、多进程数据加载等功能。
13-18.创建测试数据加载器:
val_dataloader = DataLoader(
ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
batch_size=5,
shuffle=True,
num_workers=1,
)
这五行代码创建了一个测试数据加载器val_dataloader
。它与训练数据加载器类似,但是batch_size
设置为5,num_workers
设置为1,并且ImageDataset
的mode
参数设置为"test",表示加载的是测试集。
19-33. 定义保存样本图像的函数:
def sample_images(batches_done):
# ... (代码逻辑见下文)
这行代码定义了一个名为sample_images
的函数,它接受一个参数batches_done
,表示已经完成的批次数。这个函数的作用是在训练过程中定期保存生成的样本图像。
20-24. 加载并处理测试集的图像:
imgs = next(iter(val_dataloader))
G_AB.eval()
G_BA.eval()
real_A = Variable(imgs["A"].type(Tensor))
fake_B = G_AB(real_A)
real_B = Variable(imgs["B"].type(Tensor))
fake_A = G_BA(real_B)
这些代码行从测试数据加载器中获取下一批图像,并将生成器G_AB
和G_BA
设置为评估模式。然后,它将真实图像real_A
和real_B
转换为PyTorch张量,并通过生成器生成伪造的图像fake_B
和fake_A
。
26-30. 创建图像网格并保存:
real_A = make_grid(real_A, nrow=5, normalize=True)
real_B = make_grid(real_B, nrow=5, normalize=True)
fake_A = make_grid(fake_A, nrow=5, normalize=True)
fake_B = make_grid(fake_B, nrow=5, normalize=True)
image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)
这些代码行将真实图像和伪造图像排列成网格,并将它们拼接成一个大的图像网格。然后,使用save_image
函数将图像网格保存到指定的文件路径。make_grid
函数将多个图像排列成一个网格,save_image
函数将张量保存为图像文件。
第七部分
这段代码是CycleGAN训练脚本的一部分,涉及生成器的训练过程。下面是逐行解释:
1-3. 检查是否为主程序入口:定义一个变量prev_time
来记录开始训练的时间。
if __name__ == '__main__':
prev_time = time.time()
这行代码首先检查当前脚本是否作为主程序入口运行。如果是,则继续执行。
6-7. 循环遍历所有周期和批次:
for epoch in range(opt.epoch, opt.n_epochs):
for i, batch in enumerate(dataloader):
这两个循环遍历了所有的训练周期和每个周期内的批次。opt.epoch
是当前的周期数,opt.n_epochs
是总的周期数。
8-12. 设置模型输入和对抗性地面真值:
# Set model input
real_A = Variable(batch["A"].type(Tensor))
real_B = Variable(batch["B"].type(Tensor))
# Adversarial ground truths
valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
这些代码行设置模型的输入,包括真实图像real_A
和real_B
,以及用于判别器的对抗性地面真值。valid
和fake
分别表示真实图像和生成图像的标签。
13-22. 训练生成器:
# ------------------
# Train Generators
# ------------------
G_AB.train()
G_BA.train()
optimizer_G.zero_grad()
# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total损失
loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
loss_G.backward()
optimizer_G.step()
这些代码行定义了生成器的损失函数,包括身份损失、对抗性损失、循环一致性损失和总损失。然后,它们计算这些损失,反向传播并更新生成器的权重。
这段代码的主要目的是通过训练生成器来学习从源域(A)到目标域(B)的映射,并从目标域(B)到源域(A)的映射。通过这种方式,生成器能够学习到两个域之间的映射关系,从而实现图像到图像的转换。
第八部分
这段代码是CycleGAN训练脚本的一部分,涉及判别器的训练过程。下面是逐行解释:
1-4. 重置判别器A的优化器梯度:
optimizer_D_A.zero_grad()
这行代码将判别器A的优化器中的梯度清零,以便在反向传播时不会累加到之前的梯度上。
5-11. 计算判别器A的损失:
# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
这些代码行计算判别器A的损失,包括真实图像的损失(loss_real
)和生成图像的损失(loss_fake
)。fake_A_
是从fake_A_buffer
中弹出的先前生成的图像批次的副本,以确保判别器不会反向传播梯度回生成器。
12-16. 计算判别器A的总损失,并进行反向传播和权重更新:
loss_D_A.backward()
optimizer_D_A.step()
这行代码计算判别器A的总损失,并执行反向传播,将梯度传播回判别器的权重。然后,使用优化器optimizer_D_A
来更新判别器A的权重。
17-22. 重置判别器B的优化器梯度:
optimizer_D_B.zero_grad()
这行代码将判别器B的优化器中的梯度清零,以便在反向传播时不会累加到之前的梯度上。
23-30. 计算判别器B的损失:
# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
这些代码行计算判别器B的损失,包括真实图像的损失(loss_real
)和生成图像的损失(loss_fake
)。fake_B_
是从fake_B_buffer
中弹出的先前生成的图像批次的副本。
31-35. 计算判别器B的总损失,并进行反向传播和权重更新:
loss_D_B.backward()
optimizer_D_B.step()
这行代码计算判别器B的总损失,并执行反向传播,将梯度传播回判别器的权重。然后,使用优化器optimizer_D_B
来更新判别器B的权重。
36. 计算判别器A和B的总损失,以平衡两个判别器的训练:
loss_D = (loss_D_A + loss_D_B) / 2
这行代码计算判别器A和B的总损失的平均值,以平衡两个判别器的训练过程。
第九部分
这段代码是CycleGAN训练脚本的循环外部分,涉及打印日志、保存图像样本、更新学习率以及保存模型检查点。下面是逐行解释:
1-8. 计算剩余批次数和剩余时间:
batches_done = epoch * len(dataloader) + i
batches_left = opt.n_epochs * len(dataloader) - batches_done
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
prev_time = time.time()
这些代码行计算了到目前为止已经完成的批次数(batches_done
),剩余的批次数(batches_left
),以及剩余的训练时间(time_left
)。prev_time
用于计算时间差,以便估算剩余时间。
9-15. 打印日志:
# Print log
sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
% (
epoch,
opt.n_epochs,
i,
len(dataloader),
loss_D.item(),
loss_G.item(),
loss_GAN.item(),
loss_cycle.item(),
loss_identity.item(),
time_left,
)
)
这行代码使用sys.stdout.write
来打印训练日志,包括当前周期数、批次数、判别器损失、生成器损失以及剩余训练时间。\r
是一个回车符,用于在同一行中打印信息。
16-21. 保存图像样本:
# If at sample interval save image
if batches_done % opt.sample_interval == 0:
sample_images(batches_done)
这行代码检查是否达到了预定的样本间隔(opt.sample_interval
),如果是,则调用sample_images
函数来保存生成的图像样本。
22-25. 更新生成器的学习率:
# Update learning rates
lr_scheduler_G.step()
这行代码更新生成器G_AB
和G_BA
的学习率。lr_scheduler_G
是生成器的学习率调度器,它根据当前周期数和预设的参数来调整学习率。
26-29. 更新判别器的 learning rate:
lr_scheduler_D_A.step()
lr_scheduler_D_B.step()
这行代码更新判别器D_A
和D_B
的学习率。lr_scheduler_D_A
和lr_scheduler_D_B
是判别器的学习率调度器,它们根据当前周期数和预设的参数来调整学习率。
30-34. 保存模型检查点:
if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
# Save model checkpoints
torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
这行代码检查是否达到了预定的检查点间隔(opt.checkpoint_interval
),如果是,则保存生成器G_AB
、G_BA
和判别器D_A
、D_B
的权重状态。权重被保存为.pth
文件,以便在需要时可以加载
三、学习过程中遇到的问题及其解决方案
1. 如果想要在jupyter notebook中运行.py文件的同学,可以使用
%load 文件名.py
将这个输入在cell中并运行就可以将.py文件转化成.ipynb文件
2. 如果遇到ipykernel_launcher.py: error: unrecognized arguments: -f /home/
这个报错问题的同学
这是因为调用parser.parse_args()会读取系统参数:sys.argv[],命令行调用时是正确参数,而在jupyter notebook中调用时,sys.argv的值为ipykrnel_launcher.py:
解决方法是:在代码中加入这段
import sys
sys.argv = ['run.py']
3. 一定要把文件按照要求的目录结构放,不然会影响读取
4. 这个训练超级费算力,我天真的以为我能跑满200轮,放了一天才跑了四十轮,结果还把那个结搞丢了。后来又重新跑了十几轮
四、训练结果
这点已经花了我六个小时来跑了,其实已经能有很不错的结果了,话不多说,上结果!