Contrastive Learning for Unpaired Image-to-Image Translation
- 1. 摘要
- 2. 介绍
- 3. 相关工作
- 3.1 图像转换、循环一致性
- 3.2 关系保持
- 3.3 深度网络嵌入中的感知相似性
- 3.4 对比表示学习
- 4. 方法
原文及代码链接 https://github.com/taesungp/contrastive-unpaired-translation
1. 摘要
- 图像转换任务中,输入-输出对应patch内容应该保持一致;
- 使用基于patch的对比学习方法实现单向图像转换;
- 训练数据不成对;
- 该方法促使输入-输出中对应patch映射到特征空间中的一个相似点,输入图像中其他部分为负样本;
- 使用一个多层、基于patch的方法,而不是在整幅图像上进行操作;
- 负样本从输入图像的其他部分进行提取,而不是数据集中的其他部分;
- 该方法可以提高转换质量且缩短训练时间;
2. 介绍
- 图像转换任务本质工作是,在源域与目标域的映射过程中,将源域图像的结构、内容部分与外观部分分离,在转换过程中,内容不变,而外观则使用目标域图像的外观替换;
- CycleGAN中使用对抗损失保持外观部分,使用循环一致损失保持内容部分,但是循环一致损失对源域与目标域的限制较大,要求两域之间的映射关系必须为双射,也就是要求G(x)只会存在一个y与之对应,反之亦然;
- 本文方法通过最大化对应的输入和输出patch之间的互信息仅保持内容部分对应;
- 本文通过对比损失InfoNCE Loss实现对比学习,该函数目的是学习一个嵌入或一个编码器,它将相应的补丁相互关联,同时将它们与其他补丁分离;
- 编码器会关注两个域之间的共性,例如物体的部分和形状,同时对不同之处如动物纹理保持不变:网络目标:输入域图像内容+目标域图像风格;
- 使用多层、基于patch的对比学习方法更有效,此过程中从输入的其他部分提取负样本可以强迫patch更好的保存输入的内容;
- 因为对比学习是在图像内部进行,所以该网络可以在单个图像上进行训练;
3. 相关工作
3.1 图像转换、循环一致性
- 成对图像转换任务中,常使用对抗损失或结合重构损失学习源域到目标域之间的映射关系;
- 不成对图像转换任务中,多使用循环一致损失强制要求源域和目标域之间内容尽可能一致,但是该损失有一个很大的限制:两域之间映射关系为双射。实际情况中类似CycleGAN的实现有一个缺点:如果两域之间图像内容上信息不对应则会很难实现重构;
3.2 关系保持
- 此方法促使输入图像中的关系在输出中有类似反映;
- 可使用关系保持的方法替换循环一致性,但此种方法存在两大缺点:1)依赖整幅图像之间的关系:输入图像中相似patch中存在的关系在输出中依然保持;2)依赖预先定义的距离函数:用于计算patch之间的相似度;
- 本文方法不依赖预定义距离,通过最大化输入-输出对应patch之间的互信息(公共信息)学习一个跨域相似性函数,以此替换循环一致性;
3.3 深度网络嵌入中的感知相似性
- 图像转换任务中,大多工作使用每像素重构度量定义感知距离函数,比如使用l1损失定义,此种方法并不能很好地反映视觉效果,可能会得到模糊结果;
- 成对数据图像转换任务中,近期有一些工作通过在ImageNet数据集上预训练的VGG分类网络定义感知损失,这种方式存在一定缺点:1)预训练网络权重固定,可能并不适合当前提供的数据集;2)可能不适用于不成对数据训练中;
- 本文通过图像之间的互信息提出约束,通过利用数据中的负样本,允许跨域相似函数适应特定的输入和输出域;
3.4 对比表示学习
- 最大化互信息:利用噪声对比估计学习一种嵌入关系,将相关信号聚集在一起;
- 优点:不用预先定义损失函数度量预测性能;
- 使用InfoNCE loss进行对比学习;
- InfoNCE loss 用于计算图像之间的相似度,是一种基于互信息和噪声对比估计的无监督学习方法,可用于自监督学习和基于对比的学习。通过对比图像之间的特征信息来计算相似度,从而使得学习到的特征更加具有判别性和鲁棒性;
4. 方法
- 网络整体为生成对抗网络,且生成器G分为两部分:Genc和Gdec;
对抗损失: - 使用对抗损失保证输入输出的风格相似,视觉上看起来类似;
互信息最大化: - 使用噪声对比估计去最大化输入-输出之间的互信息,以此保持输入-输出之间内容部分对应;
对比学习的目的是将两个信号联系起来:查询信号和对应的正样本,数据集中其他样本为负样本;
查询信号和对应的正样本、N个负样本均被映射为一个K维向量;
将向量归一化到一个单位球上,防止空间崩溃或者膨胀;
对比学习的过程可以理解为一个N+1路分类问题,在N+1个样本中确定1个正样本,其余均为负样本;
通过下列公式计算交叉熵损失函数,表示正样本被选取的概率;
在本文中,查询信号为输出图像中的某个patch,对应的正样本为输入图像中的对应patch,负样本为输入图像中的其他部分;
多层、基于patch的对比学习:
- 无监督学习中,对比学习被用于图像级别或者patch级别;
- 作者注意到输入-输出不仅要在图像级别共享内容,其对应patch也应该共享内容;
- 本文应用多层、基于patch的对比学习;
- 生成器中编码器用于捕捉内容信息,解码器用于合成风格信息;
- 网络层提取特征图中的一个像素点可以看做输入图像中的某个patch,这得益于卷积操作的局部性;
- 编码器结合两层MLP计算交叉熵损失,
总训练目标:
- 1)对抗损失;2)基于patch的交叉熵损失函数;
- PatchNCE(G,H,Y)损失可理解为一致性损失,即G(Y)≈Y: