from Discriminator import Discriminator
from Genration import CVAE
import torch
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision.utils import save_image
# 生成0-9数字
def sample_images(epoch):
    """Generate one conditioned sample per digit class and save them as a grid.

    Draws standard-normal latent noise, appends a one-hot digit label to each
    noise vector, decodes it with the CVAE decoder and writes the images to
    ``sample{epoch}.png``.

    Relies on the module-level globals ``device``, ``latent_size`` and
    ``cvae_model``, which are created in the ``__main__`` section.

    Args:
        epoch: Epoch number; used only in the output file name.
    """
    number = 10  # number of images to generate (one per digit 0-9)
    was_training = cvae_model.training
    # Inference only: switch off training-mode behaviour (dropout/batch-norm,
    # if the CVAE has any) and restore the previous mode afterwards.
    cvae_model.eval()
    try:
        with torch.no_grad():  # generation needs no gradients
            # Labels 0..9 (cycled if `number` > 10), one-hot encoded as in training.
            sample_labels = (torch.arange(number) % 10).long().to(device)
            sample_labels_onehot = F.one_hot(sample_labels, num_classes=10).float()
            # Latent noise of shape (number, latent_size) drawn from N(0, 1).
            noise = torch.randn(number, latent_size).to(device)
            # Condition the decoder by appending the one-hot label to the noise.
            decoder_input = torch.cat([noise, sample_labels_onehot], dim=1)
            sample = cvae_model.decode(decoder_input).cpu()
    finally:
        cvae_model.train(was_training)
    # Save as a grid of 28x28 single-channel images, `number // 2` per row.
    save_image(sample.view(number, 1, 28, 28), f'sample{epoch}.png', nrow=number // 2)
def generator_loss(recon_x, x, mu, log_var, discriminator_output):
    """Compute the CVAE-GAN generator objective.

    total = reconstruction MSE (summed) + KL divergence (summed)
            + adversarial BCE (batch mean).

    Args:
        recon_x: Reconstructed images from the CVAE (presumably flattened to
            (batch, features), as in the training loop).
        x: Original images; reshaped to ``recon_x``'s shape before comparison,
            so either flat or (batch, 1, 28, 28) input works.
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.
        discriminator_output: Discriminator probabilities in [0, 1] for the
            generated samples.

    Returns:
        Scalar tensor holding the total generator loss.
    """
    # Reshape x to match recon_x rather than reading the global `input_size`,
    # which only exists when this file runs as a script.
    mse_loss = F.mse_loss(recon_x, x.reshape(recon_x.shape), reduction='sum')
    # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I), summed.
    kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Adversarial term: reward outputs the discriminator labels as real (1).
    # NOTE(review): this is a batch *mean*, while the two terms above are batch
    # *sums*, so the adversarial signal is small relative to reconstruction.
    gan_loss = F.binary_cross_entropy(discriminator_output, torch.ones_like(discriminator_output))
    return mse_loss + kld_loss + gan_loss
def discriminator_loss_acc(real_output, fake_output):
    """Return the discriminator's BCE loss and its classification accuracy.

    Real samples are scored against target 1 and fake samples against target 0.
    A prediction is the output rounded with ``torch.round`` (round half to even,
    so 0.5 counts as "fake"). The overall accuracy is the unweighted mean of the
    real-sample and fake-sample accuracies.

    Args:
        real_output: Discriminator probabilities for real samples.
        fake_output: Discriminator probabilities for generated samples.

    Returns:
        (total_loss, total_acc): the summed BCE loss tensor and a Python float.
    """
    def _bce_against(probs, make_target):
        # BCE of `probs` against a constant target tensor of the same shape.
        return F.binary_cross_entropy(probs, make_target(probs))

    def _hit_rate(probs, expected_label):
        # Fraction of rounded predictions equal to `expected_label`.
        hits = torch.round(probs).eq(expected_label).sum().item()
        return hits / probs.numel()

    loss = _bce_against(real_output, torch.ones_like) + _bce_against(fake_output, torch.zeros_like)
    accuracy = (_hit_rate(real_output, 1) + _hit_rate(fake_output, 0)) / 2
    return loss, accuracy
if __name__ == '__main__':
    # NOTE: these names are module-level globals (an `if` creates no scope);
    # sample_images() reads device/latent_size/cvae_model from here.
    batch_size = 512  # mini-batch size
    epochs = 50  # number of training epochs
    sample_interval = 10  # save sample images every N epochs
    learning_rate = 0.001  # Adam learning rate for both models
    input_size = 784  # flattened MNIST image size (28 * 28)
    num_classes = 10  # number of digit labels
    latent_size = 64  # noise size for sampling; must match the CVAE's latent dim — TODO confirm
    # Load the MNIST training images.
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # image -> tensor
    train_dataset = torchvision.datasets.MNIST(
        root="~/torch_datasets", train=True, transform=transform, download=True
    )  # MNIST training split, downloaded on first use
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )  # shuffled mini-batches of training data
    # Choose the device to run on.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cvae_model = CVAE(input_size, num_classes).to(device)
    dis_model = Discriminator(input_size, num_classes).to(device)
    optimizer_cvae = optim.Adam(cvae_model.parameters(), lr=learning_rate)
    optimizer_dis = optim.Adam(dis_model.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        generator_loss_total = 0
        discriminator_loss_total = 0
        discriminator_acc_total = 0
        for data, labels in train_loader:
            data = data.to(device).view(-1, input_size)
            labels = F.one_hot(labels, num_classes).float().to(device)

            # A single CVAE forward pass per batch, shared by both updates.
            recon_batch, mu, log_var = cvae_model(data, labels)
            real_data = torch.cat([data, labels], dim=1)

            # --- Discriminator update ---
            # detach() keeps d_loss from backpropagating through the CVAE (those
            # gradients would only be discarded) and leaves the CVAE graph intact
            # for the generator's backward pass below.
            optimizer_dis.zero_grad()
            fake_output = dis_model(torch.cat([recon_batch.detach(), labels], dim=1))
            real_output = dis_model(real_data)
            d_loss, d_acc = discriminator_loss_acc(real_output, fake_output)  # discriminator loss and accuracy
            d_loss.backward()
            optimizer_dis.step()

            # --- Generator (CVAE) update, scored by the just-updated discriminator ---
            optimizer_cvae.zero_grad()
            fake_output = dis_model(torch.cat([recon_batch, labels], dim=1))
            g_loss = generator_loss(recon_batch, data, mu, log_var, fake_output)
            g_loss.backward()
            optimizer_cvae.step()

            generator_loss_total += g_loss.item()
            discriminator_loss_total += d_loss.item()
            discriminator_acc_total += d_acc
        generator_loss_avg = generator_loss_total / len(train_loader)
        discriminator_loss_avg = discriminator_loss_total / len(train_loader)
        discriminator_acc_avg = discriminator_acc_total / len(train_loader)
        print('Epoch [{}/{}], Generator Loss: {:.3f}, Discriminator Loss: {:.3f}, Discriminator Acc: {:.2f}%'.format(
            epoch + 1, epochs, generator_loss_avg, discriminator_loss_avg, discriminator_acc_avg * 100))
        if (epoch + 1) % sample_interval == 0:
            sample_images(epoch + 1)