迁移学习——CycleGAN

CycleGAN

    • 1.导入需要的包
    • 2.数据加载
      • (1)to_img 函数
      • (2)数据加载
      • (3)图像转换
    • 3.随机读取图像进行预处理
      • (1)函数参数
      • (2)数据路径
      • (3)读取文件列表
      • (4)初始化结果列表
      • (5)随机采样
      • (6)读取和预处理图像
      • (7)返回结果
    • 4.残差网络块
      • (1)构造函数
      • (2)残差块层
      • (3)跳跃连接
    • 5.生成器网络
      • (1)构造函数
      • (2)编码器部分
      • (3)残差块部分
      • (4)解码器部分
      • (5)输出层
      • (6)模型初始化
      • (7)前向传播
    • 6.判别器网络
      • (1)构造函数
      • (2)判别器层
      • (3)全卷积网络部分
      • (4)输出
    • 7.缓存生成器
      • (1)构造函数
      • (2)push_and_pop 方法
    • 8.训练生成对抗网络(GAN)
    • 9.优化器
    • 10.训练循环的迭代次数
    • 11.训练循环
    • 12.训练生成器
    • 13.训练判别器
    • 14.损失打印,存储伪造图片
    • 全部代码

CycleGAN(循环一致性对抗网络),用于实现两个域(例如,风格或主题不同的图像)之间的无监督图像到图像转换。
CycleGAN的核心思想是使用生成器(Generator)和判别器(Discriminator)来学习从源域(source
domain)到目标域(target domain)的映射,同时保持循环一致性,即从目标域映射回源域应该尽可能接近原始源域图像。

1.导入需要的包

from random import randint: 从Python的random模块中导入randint函数,用于生成随机整数。

import numpy as np: 导入Numpy库,并将其重命名为np,以便在代码中使用。
import torch:导入PyTorch库。
torch.set_default_tensor_type(torch.FloatTensor):设置PyTorch的默认Tensor类型为torch.FloatTensor。
import torch.nn as nn:导入PyTorch的神经网络模块,并将其重命名为nn。
import torch.optim as optim:导入PyTorch的优化器模块,并将其重命名为optim。
import torchvision.datasets as datasets: 导入PyTorch的图像数据集模块,并将其重命名为datasets。
import torchvision.transforms as transforms:导入PyTorch的图像变换模块,并将其重命名为transforms。
import os:导入Python的操作系统模块,用于处理文件和目录。
import matplotlib.pyplot as plt:导入matplotlib的Pyplot模块,用于绘图。
import torch.nn.functional as F:导入PyTorch的函数模块,并将其重命名为F。
from torch.autograd import Variable:从PyTorch的自动求导模块中导入Variable类。
from torchvision.utils import save_image: 从PyTorch的图像处理模块中导入save_image函数。
import shutil:导入Python的文件操作模块,用于删除文件和目录。
import cv2: 导入OpenCV库,用于图像处理和计算机视觉。
import random: 导入Python的随机模块。
from PIL import Image:从Pillow库中导入Image类。
import itertools: 导入Python的迭代工具模块。

from random import randint
import numpy as np 
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools

2.数据加载

(1)to_img 函数

out = 0.5 * (x + 1): 将输入张量 x 的值从 [-1, 1] 范围转换到 [0, 1] 范围。这是因为在训练过程中,图像通常会被归一化到 [-1, 1] 范围,而显示图像时需要将其转换回 [0, 1] 范围。
out = out.clamp(0, 1): 确保所有像素值都在 [0, 1] 范围内。clamp 函数将小于0的值设为0,大于1的值设为1。
out = out.view(-1, 3, 256, 256): 将张量 out 的形状重新调整为批次的形状,其中每个样本是一个 3通道(RGB)的 256x256 图像。-1 表示自动计算批次大小。

def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  
    out = out.view(-1, 3, 256, 256)  
    return out

(2)数据加载

data_path = os.path.abspath('D:\probject\pythonProject1\pytorch\CycleGAN\data'):定义了数据的路径,使用os.path.abspath()将相对路径转换为绝对路径。
image_size = 256:指定图像的大小为256x256。
batch_size = 1:定义了批处理的大小为1。

data_path = os.path.abspath('D:\probject\pythonProject1\pytorch\CycleGAN\data')
image_size = 256
batch_size = 1

(3)图像转换

transform = transforms.Compose([: 创建一个由多个图像转换操作组成的管道。
transforms.Resize(int(image_size * 1.12), Image.BICUBIC): 将图像大小调整为原始大小的 1.12 倍。这样做是为了在后续的随机裁剪中提供更多的裁剪选择。
transforms.RandomCrop(image_size): 从调整大小后的图像中随机裁剪出 256x256 像素大小的区域。
transforms.RandomHorizontalFlip(): 以 50% 的概率随机水平翻转图像。
transforms.ToTensor(): 将 PIL 图像转换为 PyTorch 张量。
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)):对图像进行归一化处理,将每个通道的像素值从 [0, 1] 范围转换为 [-1, 1] 范围。

transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), 
                                                  Image.BICUBIC), 
            transforms.RandomCrop(image_size), 
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

3.随机读取图像进行预处理

(1)函数参数

batch_size: 一个整数,表示每个批次中图像的数量。默认值为1。

def _get_train_data(batch_size=1):

(2)数据路径

train_a_filepath: 训练集A的文件路径。
train_b_filepath: 训练集B的文件路径。

	train_a_filepath = data_path + '\\trainA\\'
    train_b_filepath = data_path + '\\trainB\\'

(3)读取文件列表

train_a_list: 读取训练集A目录中的所有文件名。
train_b_list: 读取训练集B目录中的所有文件名。

   train_a_list = os.listdir(train_a_filepath)
   train_b_list = os.listdir(train_b_filepath)

(4)初始化结果列表

train_a_result: 存储处理后的训练集A图像。
train_b_result: 存储处理后的训练集B图像。

    train_a_result = []
    train_b_result = [] 

(5)随机采样

numlist: 从0到训练集A长度之间的范围中随机采样 batch_size 个索引。

numlist = random.sample(range(0, len(train_a_list)), batch_size)

(6)读取和预处理图像

对于 numlist 中的每个索引 i: 读取训练集A和B中对应的文件名。 使用 PIL.Image.open
打开图像文件,并将其转换为RGB格式。 应用之前定义的 transform 方法对图像进行预处理(包括调整大小、裁剪、翻转和归一化)。
将预处理后的图像添加到 train_a_result 和 train_b_result 列表中。

	for i in numlist:
        a_filename = train_a_list[i]
        a_img = Image.open(train_a_filepath + a_filename).convert('RGB')
        res_a_img = transform(a_img)
        train_a_result.append(torch.unsqueeze(res_a_img, 0))
        
        b_filename = train_b_list[i]
        b_img = Image.open(train_b_filepath + b_filename).convert('RGB')
        res_b_img = transform(b_img)
        train_b_result.append(torch.unsqueeze(res_b_img, 0))
        

(7)返回结果

使用 torch.cattrain_a_resulttrain_b_result
列表中的图像堆叠成一个批次,并返回这两个批次的图像。

4.残差网络块

残差块是一种常用的构建块,用于深度卷积神经网络,特别是在
ResNet(残差网络)架构中。它允许网络在学习过程中保留和利用之前层的信息,通过跳跃连接(shortcut
connections)来解决深层网络训练过程中的梯度消失问题。

(1)构造函数

def __init__(self, in_features): 构造函数接收一个参数 in_features,表示输入特征图的通道数。
super(ResidualBlock, self).__init__(): 调用父类 nn.Module 的构造函数。
self.block_layer: 定义一个顺序模型 nn.Sequential,包含残差块的所有层。

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block_layer = nn.Sequential

(2)残差块层

nn.ReflectionPad2d(1):使用反射填充(padding)来扩展输入张量的边界。这种填充方式在边缘反射输入数据,以保持边缘信息的连续性。
nn.Conv2d(in_features, in_features, 3): 使用 3x3的卷积核进行卷积操作,输入和输出通道数相同。
nn.InstanceNorm2d(in_features):应用实例归一化(Instance Normalization)来对每个样本的特征图进行归一化处理。这与批量归一化(Batch Normalization)不同,它不对整个批次的数据进行归一化,而是对单个样本的特征图进行归一化。
nn.ReLU(inplace=True): 应用 ReLU 激活函数,并设置 inplace=True以便直接修改输入张量,减少内存使用。

(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features))

(3)跳跃连接

return x + self.block_layer(x): 这是残差块的核心,它将输入张量 x 与
self.block_layer(x) 的输出相加,形成跳跃连接。这样,即使 self.block_layer
的输出为零(即网络未能学习到任何东西),输入 x 仍然可以通过跳跃连接直接传递到下一层,从而保持了信息的流通。

	def forward(self, x):
        return x + self.block_layer(x)

5.生成器网络

生成器的目的是将输入图像从一个域转换到另一个域。

(1)构造函数

super(Generator, self).__init__(): 调用父类 nn.Module 的构造函数。
model: 初始化一个列表,用于存储生成器网络中的层。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

(2)编码器部分

nn.ReflectionPad2d(3): 使用反射填充(padding)来扩展输入张量的边界。
nn.Conv2d(3, 64, 7): 使用 7x7 的卷积核将输入图像(3 通道)转换为 64 通道的特征图。
nn.InstanceNorm2d(64):应用实例归一化。
nn.ReLU(inplace=True): 应用 ReLU 激活函数。
for _ in range(2):重复以下层两次,以逐渐减少特征图的尺寸。
nn.Conv2d(in_features, out_features, 3,stride=2, padding=1): 使用 3x3 的卷积核,步长为 2,进行降采样。
nn.InstanceNorm2d(out_features): 应用实例归一化。
nn.ReLU(inplace=True):应用 ReLU 激活函数。

model = [nn.ReflectionPad2d(3), 
                 nn.Conv2d(3, 64, 7), 
                 nn.InstanceNorm2d(64), 
                 nn.ReLU(inplace=True)]

        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 
                                3, stride=2, padding=1), 
            nn.InstanceNorm2d(out_features), 
            nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2

(3)残差块部分

for _ in range(9): 重复添加 9 个残差块,这些块是 CycleGAN 生成器的核心,用于学习域之间的映射。

 	for _ in range(9):
            model += [ResidualBlock(in_features)]

(4)解码器部分

out_features = in_features // 2: 准备进行上采样,将特征图的尺寸加倍。
for _ in range(2): 重复以下层两次,以逐渐增加特征图的尺寸。
nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1): 使用 3x3 的转置卷积核,步长为 2,进行上采样。
nn.InstanceNorm2d(out_features): 应用实例归一化。
nn.ReLU(inplace=True): 应用 ReLU 激活函数。

    out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(
                    in_features, out_features, 
                    3, stride=2, padding=1, output_padding=1), 
                nn.InstanceNorm2d(out_features), 
                nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

(5)输出层

nn.ReflectionPad2d(3): 使用反射填充。
nn.Conv2d(64, 3, 7): 使用 7x7的卷积核将特征图转换回 3 通道的图像。
nn.Tanh(): 应用 Tanh 激活函数,将输出值范围映射到 [-1, 1]。

	model += [nn.ReflectionPad2d(3), 
                  nn.Conv2d(64, 3, 7), 
                  nn.Tanh()]

(6)模型初始化

self.gen = nn.Sequential( * model): 将所有层组合成一个顺序模型。

self.gen = nn.Sequential( * model)

(7)前向传播

def forward(self, x): 定义前向传播函数。
x = self.gen(x): 通过生成器网络传递输入 x。
return x: 返回生成器的输出。

	def forward(self, x):
        x = self.gen(x)
        return x 

6.判别器网络

(1)构造函数

super(Discriminator, self).__init__(): 调用父类 nn.Module 的构造函数。
self.dis: 定义一个顺序模型 nn.Sequential,包含判别器网络的所有层。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential

(2)判别器层

nn.Conv2d(3, 64, 4, 2, 1, bias=False): 使用 4x4 的卷积核,步长为2,进行降采样,输入通道数为 3(RGB),输出通道数为 64。
nn.LeakyReLU(0.2, inplace=True): 应用Leaky ReLU 激活函数,设置斜率为 0.2。
for _ in range(3): 重复以下层三次,以逐渐减少特征图的尺寸。
nn.Conv2d(in_features, out_features, 4, 2, 1, bias=False): 使用 4x4 的卷积核,步长为 2,进行降采样。
nn.InstanceNorm2d(out_features): 应用实例归一化。
nn.LeakyReLU(0.2, inplace=True): 应用 Leaky ReLU 激活函数。

(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

(3)全卷积网络部分

nn.Conv2d(256, 512, 4, padding=1): 使用 4x4 的卷积核,不进行降采样,输入通道数为256,输出通道数为 512。
nn.InstanceNorm2d(512): 应用实例归一化。
nn.LeakyReLU(0.2, inplace=True): 应用 Leaky ReLU 激活函数。
nn.Conv2d(512, 1, 4, padding=1):使用 4x4 的卷积核,不进行降采样,输入通道数为 512,输出通道数为 1。

			nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1, 4, padding=1))  

(4)输出

return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1):对判别器输出的特征图进行平均池化操作,然后将其展平为一维向量。这个一维向量将作为最终的判别结果,其长度为 1,表示输入图像的真实性(接近 1表示真实,接近 0 表示假)。

	def forward(self, x):
        x = self.dis(x)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

7.缓存生成器

(1)构造函数

def __init__(self, max_size=50): 定义了一个构造函数 init,用于在创建ReplayBuffer 对象时初始化其属性。
self.max_size = max_size: 初始化缓冲区的大小。
self.data = []: 初始化一个空列表 self.data,用于存储缓存的数据。

class ReplayBuffer():
#     """
#     缓存队列,若不足则新增,否则随机替换
#     """
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

(2)push_and_pop 方法

def push_and_pop(self, data): 定义了一个方法,用于将新数据推入缓冲区,并在需要时弹出旧数据。
to_return = []: 初始化一个空列表 to_return,用于存储从缓冲区中弹出的数据。
for element in data.data:: 遍历传入的数据 data.data 中的每个元素。
element = torch.unsqueeze(element, 0):将每个元素展平为一维张量。这通常是为了确保张量的形状与预期的形状匹配,以便后续的操作可以正确执行。
if len(self.data) < self.max_size:: 如果缓冲区中还没有达到最大容量,则将新元素添加到缓冲区。
self.data.append(element): 将新元素添加到缓冲区。
to_return.append(element): 将新元素添加到 to_return 列表中。
else:: 如果缓冲区已满,则随机替换缓冲区中的一个元素。
if random.uniform(0,1) > 0.5:: 如果随机数大于 0.5,则从缓冲区中随机选择一个元素替换。
i = random.randint(0, self.max_size-1): 随机选择一个索引。
to_return.append(self.data[i].clone()): 将缓冲区中的元素复制并添加到 to_return列表中。
self.data[i] = element: 用新元素替换缓冲区中的元素。
else:: 如果随机数小于或等于 0.5,则直接添加新元素到 to_return 列表中。
to_return.append(element): 将新元素添加到 to_return 列表中。
return Variable(torch.cat(to_return)): 返回 to_return 列表中所有元素的拼接张量。Variable 是一个 PyTorch 类,用于表示可变的张量。torch.cat 函数用于将多个张量拼接在一起。

def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

8.训练生成对抗网络(GAN)

fake_A_buffer = ReplayBuffer(): 创建了一个名为 fake_A_buffer 的 ReplayBuffer实例。ReplayBuffer是一个用于缓存和随机替换数据的结构,在训练循环中用于缓存生成器生成的假图像,以便在后续的训练步骤中用于训练判别器。
fake_B_buffer = ReplayBuffer(): 创建了一个名为 fake_B_buffer 的 ReplayBuffer实例。这个缓冲区的作用与 fake_A_buffer 类似,用于缓存从生成器 netG_B2A 生成的假图像。

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

netG_A2B = Generator(): 创建了一个名为 netG_A2B 的 Generator 实例。Generator是一个用于生成新图像的神经网络,在这里,它将从域 A 生成域 B 的图像。
netG_B2A = Generator(): 创建了一个名为 netG_B2A 的 Generator 实例。这个生成器将从域 B生成域 A 的图像。
netD_A = Discriminator(): 创建了一个名为 netD_A 的 Discriminator实例。Discriminator 是一个用于判断图像是否真实的神经网络,在这里,它用于判断 A 类图像是否真实。
netD_B = Discriminator(): 创建了一个名为 netD_B 的 Discriminator实例。这个判别器用于判断 B 类图像是否真实。

netG_A2B = Generator()
netG_B2A = Generator()
netD_A = Discriminator()
netD_B = Discriminator()

criterion_GAN = torch.nn.MSELoss(): 定义了一个名为 criterion_GAN 的 MSELoss
损失函数。这个损失函数用于计算 GAN 损失,即判别器对真实图像和假图像的预测之间的差异。
criterion_cycle = torch.nn.L1Loss(): 定义了一个名为 criterion_cycle 的 L1Loss损失函数。这个损失函数用于计算循环一致性损失,即生成器生成的图像与其输入图像之间的差异。
criterion_identity = torch.nn.L1Loss(): 定义了一个名为 criterion_identity 的 L1Loss损失函数。这个损失函数用于计算身份损失,即生成器生成的图像与其输入图像之间的差异。

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

d_learning_rate = 3e-4 : 定义了判别器的学习率。
g_learning_rate = 3e-4:定义了生成器的 learning rate。
optim_betas = (0.5, 0.999): 定义了优化器的超参数betas,这是用于计算梯度下降的动量项的值。

d_learning_rate = 3e-4  
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)

9.优化器

g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=d_learning_rate): 创建了一个名为 g_optimizer 的Adam 优化器实例。Adam 是一种常用的优化算法,用于调整神经网络的权重。这里,itertools.chain函数用于将两个生成器的参数合并为一个单一的迭代器,以便于一起优化。lr 参数指定了学习率,它用于控制权重更新的速度。

da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate):创建了一个名为 da_optimizer 的 Adam 优化器实例,用于训练判别器 netD_A。

db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate):创建了一个名为 db_optimizer 的 Adam 优化器实例,用于训练判别器 netD_B。

g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), 
                                         netG_B2A.parameters()), 
            lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)

10.训练循环的迭代次数

num_epochs = 100: 定义了训练循环的迭代次数。epoch是一个训练周期,在这个周期内,所有数据都会被遍历一次。在这里,训练循环将执行 100 个周期。

num_epochs = 100

11.训练循环

for epoch in range(num_epochs):: 开始一个循环,该循环将执行指定的次数(由 num_epochs定义)。
real_a, real_b = _get_train_data(batch_size): 从数据集中获取一批真实图像real_a 和 real_b。
target_real = torch.full((batch_size,), 1).float():创建一个全为 1 的张量 target_real,用于指示真实图像。
target_fake =torch.full((batch_size,), 0).float(): 创建一个全为 0 的张量target_fake,用于指示假图像。
g_optimizer.zero_grad():清除生成器的梯度,以便于下一次前向传播和反向传播时不会累积梯度。

for epoch in range(num_epochs): 

    real_a, real_b = _get_train_data(batch_size)
    target_real = torch.full((batch_size,), 1).float()
    target_fake = torch.full((batch_size,), 0).float()
    
	g_optimizer.zero_grad()

12.训练生成器

same_B = netG_A2B(real_b).float(): 使用生成器 netG_A2B 从真实图像 real_b生成相似的图像 same_B。

loss_identity_B = criterion_identity(same_B, real_b) * 5.0: 计算same_B 和 real_b 之间的身份损失,并乘以 5.0 以增加其权重。

same_A = netG_B2A(real_a).float(): 使用生成器 netG_B2A 从真实图像 real_a生成相似的图像 same_A。

loss_identity_A = criterion_identity(same_A, real_a) * 5.0: 计算same_A 和 real_a 之间的身份损失,并乘以 5.0 以增加其权重。

fake_B = netG_A2B(real_a).float(): 使用生成器 netG_A2B 从真实图像 real_a 生成假图像fake_B。

pred_fake = netD_B(fake_B).float(): 使用判别器 netD_B 判断 fake_B 是否为假图像。

loss_GAN_A2B = criterion_GAN(pred_fake, target_real): 计算判别器对 fake_B的预测和真实图像的损失,即 GAN 损失。

fake_A = netG_B2A(real_b).float(): 使用生成器 netG_B2A 从真实图像 real_b 生成假图像fake_A。

pred_fake = netD_A(fake_A).float(): 使用判别器 netD_A 判断 fake_A 是否为假图像。

loss_GAN_B2A = criterion_GAN(pred_fake, target_real): 计算判别器对 fake_A的预测和真实图像的损失,即 GAN 损失。

recovered_A = netG_B2A(fake_B).float(): 使用生成器 netG_B2A 从假图像 fake_B生成恢复的图像 recovered_A。

loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0: 计算recovered_A 和 real_a 之间的循环一致性损失,并乘以 10.0 以增加其权重。

recovered_B = netG_A2B(fake_A).float(): 使用生成器 netG_A2B 从假图像 fake_A生成恢复的图像 recovered_B。

loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0: 计算recovered_B 和 real_b 之间的循环一致性损失,并乘以 10.0 以增加其权重。

loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB): 将所有损失加在一起,得到生成器的总损失。

loss_G.backward(): 对总损失进行反向传播,计算每个参数的梯度。

g_optimizer.step():会对生成器的所有参数进行梯度更新,以最小化生成器损失函数。

# 第一步:训练生成器
    same_B = netG_A2B(real_b).float()
    loss_identity_B = criterion_identity(same_B, real_b) * 5.0   
    same_A = netG_B2A(real_a).float()
    loss_identity_A = criterion_identity(same_A, real_a) * 5.0
    
    fake_B = netG_A2B(real_a).float()
    pred_fake = netD_B(fake_B).float()
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
    fake_A = netG_B2A(real_b).float()
    pred_fake = netD_A(fake_A).float()
    loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
    recovered_A = netG_B2A(fake_B).float()
    loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0
    recovered_B = netG_A2B(fake_A).float()
    loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0  
    loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + 
              loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
    loss_G.backward()    
    g_optimizer.step()

13.训练判别器

da_optimizer.zero_grad(): 清除判别器 A 的梯度,以便于下一次前向传播和反向传播时不会累积梯度。

pred_real = netD_A(real_a).float(): 使用判别器 A 来判断真实图像 real_a 是否为真实图像。

loss_D_real = criterion_GAN(pred_real, target_real): 计算判别器 A对真实图像的预测和真实图像的损失,即 GAN 损失。

fake_A = fake_A_buffer.push_and_pop(fake_A): 从 fake_A_buffer 中获取一批fake_A 图像,这些图像是从生成器 A 生成的假图像。

pred_fake = netD_A(fake_A.detach()).float(): 使用判别器 A 来判断 fake_A是否为假图像。由于 fake_A 是从 fake_A_buffer 中获取的,它已经与生成器的梯度解耦,因此不需要梯度信息。

loss_D_fake = criterion_GAN(pred_fake, target_fake): 计算判别器 A 对fake_A 的预测和假图像的损失,即 GAN 损失。

loss_D_A = (loss_D_real + loss_D_fake) * 0.5: 将判别器 A的真实图像损失和假图像损失加在一起,得到判别器 A 的总损失。

loss_D_A.backward(): 对判别器 A 的总损失进行反向传播,计算每个参数的梯度。
da_optimizer.step(): 使用之前计算的梯度来更新判别器 A 的参数。

   # 第二步:训练判别器
    # 训练判别器A
    da_optimizer.zero_grad()
    pred_real = netD_A(real_a).float()
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_A = fake_A_buffer.push_and_pop(fake_A)
    pred_fake = netD_A(fake_A.detach()).float()
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_A = (loss_D_real + loss_D_fake) * 0.5
    loss_D_A.backward()
    da_optimizer.step()
    # 训练判别器B
    db_optimizer.zero_grad()
    pred_real = netD_B(real_b)
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_B = fake_B_buffer.push_and_pop(fake_B)
    pred_fake = netD_B(fake_B.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_B = (loss_D_real + loss_D_fake) * 0.5
    loss_D_B.backward()
	db_optimizer.step()

14.损失打印,存储伪造图片

print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}' .format(epoch, loss_G.data.item(), loss_D_A.data.item(), loss_D_B.data.item())):打印当前训练周期(epoch)的损失,包括生成器损失(loss_G)和两个判别器损失(loss_D_A 和 loss_D_B)。
if (epoch + 1) % 20 == 0 or epoch == 0:: 检查当前训练周期是否是 20的倍数,或者是否是第一个周期。如果是,则执行以下操作。
b_fake = to_img(fake_B.data): 将判别器 B的输入(fake_B)转换回图像格式。
a_fake = to_img(fake_A.data): 将判别器 A的输入(fake_A)转换回图像格式。
a_real = to_img(real_a.data): 将真实图像 A 转换回图像格式。
b_real = to_img(real_b.data): 将真实图像 B 转换回图像格式。
save_image(a_fake,'../tmp/a_fake.png'): 将 a_fake 图像保存到文件 …/tmp/a_fake.png。
save_image(b_fake, '../tmp/b_fake.png'): 将 b_fake 图像保存到文件…/tmp/b_fake.png。
save_image(a_real, '../tmp/a_real.png'): 将 a_real图像保存到文件 …/tmp/a_real.png。
save_image(b_real, '../tmp/b_real.png'):将 b_real 图像保存到文件 …/tmp/b_real.png。

 #损失打印,存储伪造图片
    print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'
      .format(epoch, loss_G.data.item(), loss_D_A.data.item(), 
              loss_D_B.data.item()))
    if (epoch + 1) % 20 == 0 or epoch == 0:  
        b_fake = to_img(fake_B.data)
        a_fake = to_img(fake_A.data)
        a_real = to_img(real_a.data)
        b_real = to_img(real_b.data)
        save_image(a_fake, '../tmp/a_fake.png') 
        save_image(b_fake, '../tmp/b_fake.png') 
        save_image(a_real, '../tmp/a_real.png') 
        save_image(b_real, '../tmp/b_real.png') 

在这里插入图片描述
在这里插入图片描述

全部代码

from random import randint
import numpy as np 
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools   
 def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  
    out = out.view(-1, 3, 256, 256)  
    return out

# 数据加载 
data_path = os.path.abspath('D:\probject\pythonProject1\pytorch\CycleGAN\data')
image_size = 256
batch_size = 1

transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), 
                                                  Image.BICUBIC), 
            transforms.RandomCrop(image_size), 
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
def _get_train_data(batch_size=1):
    
    train_a_filepath = data_path + '\\trainA\\'
    train_b_filepath = data_path + '\\trainB\\'
    
    train_a_list = os.listdir(train_a_filepath)
    train_b_list = os.listdir(train_b_filepath)
    
    train_a_result = []
    train_b_result = [] 
    
    numlist = random.sample(range(0, len(train_a_list)), batch_size)
    
    for i in numlist:
        a_filename = train_a_list[i]
        a_img = Image.open(train_a_filepath + a_filename).convert('RGB')
        res_a_img = transform(a_img)
        train_a_result.append(torch.unsqueeze(res_a_img, 0))
        
        b_filename = train_b_list[i]
        b_img = Image.open(train_b_filepath + b_filename).convert('RGB')
        res_b_img = transform(b_img)
        train_b_result.append(torch.unsqueeze(res_b_img, 0))
        
    return torch.cat(train_a_result, dim=0), torch.cat(train_b_result, dim=0)

# """
# 残差网络block
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block_layer = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features))
        
    def forward(self, x):
        return x + self.block_layer(x)
# 生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
             
        model = [nn.ReflectionPad2d(3), 
                 nn.Conv2d(3, 64, 7), 
                 nn.InstanceNorm2d(64), 
                 nn.ReLU(inplace=True)]

        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 
                                3, stride=2, padding=1), 
            nn.InstanceNorm2d(out_features), 
            nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2

        for _ in range(9):
            model += [ResidualBlock(in_features)]

        out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(
                    in_features, out_features, 
                    3, stride=2, padding=1, output_padding=1), 
                nn.InstanceNorm2d(out_features), 
                nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

        model += [nn.ReflectionPad2d(3), 
                  nn.Conv2d(64, 3, 7), 
                  nn.Tanh()]

        self.gen = nn.Sequential( * model)
        
    def forward(self, x):
        x = self.gen(x)
        return x 
# 判别器 

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1, 4, padding=1))        
        
    def forward(self, x):
        x = self.dis(x)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)


class ReplayBuffer():
#     """
#     缓存队列,若不足则新增,否则随机替换
#     """
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []
        
    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))
    
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

netG_A2B = Generator()
netG_B2A = Generator()
netD_A = Discriminator()
netD_B = Discriminator()

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

d_learning_rate = 3e-4  # 3e-4
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)

g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), 
                                         netG_B2A.parameters()), 
            lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)

num_epochs = 100
for epoch in range(num_epochs): 

    real_a, real_b = _get_train_data(batch_size)
    target_real = torch.full((batch_size,), 1).float()
    target_fake = torch.full((batch_size,), 0).float()
    
    g_optimizer.zero_grad()
    
    # 第一步:训练生成器
    same_B = netG_A2B(real_b).float()
    loss_identity_B = criterion_identity(same_B, real_b) * 5.0   
    same_A = netG_B2A(real_a).float()
    loss_identity_A = criterion_identity(same_A, real_a) * 5.0
    
    fake_B = netG_A2B(real_a).float()
    pred_fake = netD_B(fake_B).float()
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
    fake_A = netG_B2A(real_b).float()
    pred_fake = netD_A(fake_A).float()
    loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
    recovered_A = netG_B2A(fake_B).float()
    loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0
    recovered_B = netG_A2B(fake_A).float()
    loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0  
    loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + 
              loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
    loss_G.backward()    
    g_optimizer.step()
    
    
    # 第二步:训练判别器
    # 训练判别器A
    da_optimizer.zero_grad()
    pred_real = netD_A(real_a).float()
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_A = fake_A_buffer.push_and_pop(fake_A)
    pred_fake = netD_A(fake_A.detach()).float()
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_A = (loss_D_real + loss_D_fake) * 0.5
    loss_D_A.backward()
    da_optimizer.step()
    # 训练判别器B
    db_optimizer.zero_grad()
    pred_real = netD_B(real_b)
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_B = fake_B_buffer.push_and_pop(fake_B)
    pred_fake = netD_B(fake_B.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_B = (loss_D_real + loss_D_fake) * 0.5
    loss_D_B.backward()
    db_optimizer.step()
    
    
    #损失打印,存储伪造图片
    print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'
      .format(epoch, loss_G.data.item(), loss_D_A.data.item(), 
              loss_D_B.data.item()))
    if (epoch + 1) % 20 == 0 or epoch == 0:  
        b_fake = to_img(fake_B.data)
        a_fake = to_img(fake_A.data)
        a_real = to_img(real_a.data)
        b_real = to_img(real_b.data)
        save_image(a_fake, '../tmp/a_fake.png') 
        save_image(b_fake, '../tmp/b_fake.png') 
        save_image(a_real, '../tmp/a_real.png') 
        save_image(b_real, '../tmp/b_real.png') 
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/746476.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于redisson实现tomcat集群session共享

目录 1、环境 2、修改server.xml 3、修改context.xml 4、新增redisson配置文件 5、下载并复制2个Jar包到Tomcat Lib目录中 6、 安装redis 7、配置nginx负载均衡 8、配置测试页面 9、session共享测试验证 前言&#xff1a; 上篇中&#xff0c;Tomcat session复制及ses…

观测云 VS 开源自建

观测云是一款面向全技术栈的监控观测一体化产品方案&#xff0c;具备强大而丰富的功能&#xff0c;目标是帮助最终用户提升监控观测的能力&#xff0c;化繁为简&#xff0c;轻松的构建起完整的监控观测体系。同时能够帮助整个企业的开发技术团队从统一的观测能力上获得完整的收…

ONLYOFFICE 文档开发者版 8.1:API 更新

随着版本 8.1 新功能的发布&#xff0c;我们更新了编辑器、文档生成器和插件的 API&#xff0c;并添加了 Office API 板块。阅读下文了解详情。 ​ ONLYOFFICE 文档是什么 ONLYOFFICE 文档是一个功能强大的文档编辑器&#xff0c;支持处理文本文档、电子表格、演示文稿、可填写…

探索ChatGPT在程序员日常工作的多种应用

引言 在现代科技迅猛发展的今天&#xff0c;人工智能的应用已经深入到我们生活和工作的各个方面。作为程序员&#xff0c;我们时常面临大量繁杂的任务&#xff0c;从代码编写、错误调试到项目管理和团队协作&#xff0c;每一项都需要花费大量的时间和精力。近年来&#xff0c;…

算法与数据结构——时间复杂度详解与示例(C#,C++)

文章目录 1. 算法与数据结构概述2. 时间复杂度基本概念3. 时间复杂度分析方法4. 不同数据结构的时间复杂度示例5. 如何通过算法优化来提高时间复杂度6. C#中的时间复杂度示例7. 总结 算法与数据结构是计算机科学的核心&#xff0c;它们共同决定了程序的性能和效率。在实际开发中…

大模型产品的“命名经济学”:名字越简单,产品越火爆?

文 | 智能相对论 作者 | 陈泊丞 古人云&#xff1a;赐子千金&#xff0c;不如教子一艺&#xff1b;教子一艺&#xff0c;不如赐子一名。 命名之妙&#xff0c;玄之又玄。 早两年&#xff0c;大模型爆火&#xff0c;本土厂商在大模型产品命名上可谓下足了功夫&#xff0c;引…

C#+uni-app医院HIS预约挂号系统源码 看病挂号快人一步

​​​​​​​ 提到去大型医院机构就诊时&#xff0c;许多人都感到恐惧。有些人一旦走进医院的门诊大厅&#xff0c;就感到迷茫&#xff0c;既无法理解导医台医生的建议&#xff0c;也找不到应该去哪个科室进行检查。实际上&#xff0c;就医也是一门学问&#xff0c;如何优化…

【CS.DS】数据结构 —— 图:深入了解三种表示方法之邻接表(Adjacency List)

文章目录 1 概念2 无向图的邻接表2.1 示例2.2 Mermaid 图示例2.3 C实现2.3.1 简单实现2.3.2 优化封装 2.4 总结 3 有向图的邻接表3.1 示例3.2 C实现3.3 总结 4 邻接图的遍历5 拓展补充References 数据结构 1 概念 优点&#xff1a;空间效率高&#xff0c;适合稀疏图。动态性强…

Win10,Win11电脑重装系统怎么操作,简单一步搞定【保姆级教程】

电脑重装系统怎么操作&#xff1f;电脑使用时间长了&#xff0c;就会出现系统崩溃、病毒感染或者是系统文件损坏等问题。这个时候我们就可以对电脑进行系统重装&#xff0c;也就是恢复电脑出厂设置。现在市面上有很多系统重装工具可以帮助我们解决难题&#xff0c;如果您是电脑…

自定义 Django 管理界面中的多对多内联模型

1. 问题背景 在 Django 管理界面中&#xff0c;用户可以使用内联模型来管理一对多的关系。但是&#xff0c;当一对多关系是多对多时&#xff0c;Django 提供的默认内联模型可能并不适合。例如&#xff0c;如果存在一个产品模型和一个发票模型&#xff0c;并且产品和发票之间是…

Java文件操作小项目-带GUI界面统计文件夹内文件类型及大小

引言 在Java编程中&#xff0c;文件操作是一项基本且常见的任务。我们经常需要处理文件和文件夹&#xff0c;例如读取、写入、删除文件&#xff0c;或者遍历文件夹中的文件等。本文将介绍如何使用Java的File类和相关API来统计一个文件夹中不同类型文件的数量和大小。 准备工作…

数据分析python基础实战分析

数据分析python基础实战分析 安装python&#xff0c;建议安装Anaconda 【Anaconda下载链接】https://repo.anaconda.com/archive/ 记得勾选上这个框框 安装完后&#xff0c;然后把这两个框框给取消掉再点完成 在电脑搜索框输入"Jupyter"&#xff0c;牛马启动&am…

Vitis Accelerated Libraries 学习笔记--OpenCV 安装指南

目录 1. 简介 2. 安装过程 2.1 安装准备 2.2 编译并安装 XRT 2.2.1 下载 XRT 源码 2.2.2 安装依赖项 2.2.3 构建 XRT 2.2.4 打包 DEB 2.2.5 安装 XRT 2.3 编译并安装 OpenCV 2.3.1 下载 OpenCV 源码 2.3.2 创建目录 2.3.3 设置环境变量 2.3.4 构建 opencv 3. 总…

【STM32】看门狗

1.看门狗简介 看门狗起始就是一个定时器&#xff0c;从功能上说它可以让微控制器在程序发生意外&#xff08;程序进入死循环或跑飞&#xff09;的时候&#xff0c;能重新恢复到系统刚上电状态&#xff0c;以保障系统出问题的时候可以重启一次。说的简单一点&#xff0c;看门狗…

加速业务布局,30年老将加盟ATFX,掌舵运营新篇章

全球领先的差价合约经纪商ATFX日前宣布了一项重大人事任命&#xff0c;聘请业界资深人士约翰博格(John Bogue)为机构业务运营总监。约翰博格是一名行业老将&#xff0c;曾在差价合约界深耕三十余载。伴随其加入ATFX&#xff0c;相信他的深厚专业知识和从业经验将为ATFX机构业务…

HarmonyOS NEXT Developer Beta1配套相关说明

一、版本概述 2024华为开发者大会&#xff0c;HarmonyOS NEXT终于在万千开发者的期待下从幕后走向台前。 HarmonyOS NEXT采用全新升级的系统架构&#xff0c;贯穿HarmonyOS全场景体验的底层优化&#xff0c;系统更流畅&#xff0c;隐私安全能力更强大&#xff0c;将给您带来更高…

数据集的未来:如何利用亮数据浏览器提升数据采集效率

目录 一、跨境电商的瓶颈1、技术门槛2、语言与文化差异3、网络稳定性4、验证码处理和自动识别5、数据安全6、法规和合规 二、跨境电商现在是一个合适的商机吗&#xff1f;三、数据集与亮数据浏览器1、市场分析2、价格监控3、产品开发4、供应链优化5、客户分析 四、亮数据浏览器…

Jenkins流水线发布,一篇就解决你的所有疑惑

这次搭建的项目比较常规,前端是react写的,后端是springboot,并且由于是全栈开发,所以是在同一个项目中。接下来我演示下怎么用jenkins进行自动化发布。 1.jenkins必装插件 这里用到的是jenkinsFile主要是基于Groovy这个沙盒,有些前置插件。这里使用maven进行打包,所以需…

如何提高项目风险的处理效率?5个重点

提高项目风险的处理效率&#xff0c;有助于迅速识别和应对风险&#xff0c;减少风险导致的延误&#xff0c;降低成本&#xff0c;提升项目质量&#xff0c;确保项目按时交付。如果项目风险处理效率较低&#xff0c;未能及时发现和处理风险&#xff0c;导致问题累积&#xff0c;…

浏览器扩展V3开发系列之 chrome.runtime 的用法和案例

【作者主页】&#xff1a;小鱼神1024 【擅长领域】&#xff1a;JS逆向、小程序逆向、AST还原、验证码突防、Python开发、浏览器插件开发、React前端开发、NestJS后端开发等等 chrome.runtime API 提供了一系列的方法和事件&#xff0c;可以通过它来管理和维护 Chrome 扩展的生命…