GAN (Generative Adversarial Network)
I. Theory
https://zhuanlan.zhihu.com/p/307527293?utm_campaign=shareopn&utm_medium=social&utm_psn=1815884330188283904&utm_source=wechat_session
Refer to the "essence of GAN" ("GAN 的本质") section of the article linked above.
II. Experiment
1. Dataset
The MNIST dataset is used. Below is a sample image from the training set.
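A minimal sketch (illustrative, not part of the original post) for loading MNIST with torchvision and displaying a single training image:

from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# load the raw training set without normalization, just for visualization
mnist = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
img, label = mnist[0]                  # one sample: a 1x28x28 tensor and its digit label
plt.imshow(img.squeeze(), cmap='gray')
plt.title(f'Label: {label}')
plt.show()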
2. Code
Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
Define the generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()  # output range is [-1, 1], matching the normalized images
        )

    def forward(self, x):
        return self.net(x)
Define the discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # output range is [0, 1]: probability that the input is real
        )

    def forward(self, x):
        return self.net(x)
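As a quick sanity check (illustrative only, not in the original code), the two networks can be chained on a small batch of noise to confirm the tensor shapes:

z = torch.randn(4, 100)        # 4 noise vectors of dimension 100
fake = Generator()(z)          # shape (4, 784), values in [-1, 1] from the Tanh
prob = Discriminator()(fake)   # shape (4, 1), values in [0, 1] from the Sigmoid
print(fake.shape, prob.shape)  # torch.Size([4, 784]) torch.Size([4, 1])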
Training
Train the discriminator first, then the generator.
Discriminator step: take a batch of training-set images labeled real (1) together with fake images produced by the generator labeled fake (0), and feed them to the discriminator. The loss pushes the discriminator's output toward 1 on real images and toward 0 on the generator's fake images. Only the discriminator's parameters are updated in this step; the generator's parameters are left unchanged.
Generator step: feed Gaussian noise z into the generator, label the resulting fake images as real (1), and pass them to the discriminator. The loss pushes the discriminator's output on these fake images toward 1. Only the generator's parameters are updated in this step; the discriminator's parameters are left unchanged.
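For reference, this alternating procedure corresponds to the standard GAN minimax objective (Goodfellow et al., 2014). Note that the generator step in the code below uses real labels on fake images, i.e. the common non-saturating generator loss, which maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

L_G = -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]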
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # map pixels from [0, 1] to [-1, 1], matching the Tanh output
])
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
def train_gan(generator, discriminator, dataloader, num_epochs=25):
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.size(0)
            real_imgs = imgs.view(batch_size, -1)  # flatten each image to a 784-dim vector
            # labels
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)
            # train the discriminator
            outputs = discriminator(real_imgs)
            d_loss_real = criterion(outputs, real_labels)
            real_score = outputs
            z = torch.randn(batch_size, 100)
            fake_imgs = generator(z)
            outputs = discriminator(fake_imgs.detach())  # detach so gradients don't flow into the generator
            d_loss_fake = criterion(outputs, fake_labels)
            fake_score = outputs
            d_loss = d_loss_real + d_loss_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()
            # train the generator
            outputs = discriminator(fake_imgs)
            g_loss = criterion(outputs, real_labels)  # push the discriminator's output on fakes toward 1
            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')
        if (epoch+1) % 10 == 0:
            with torch.no_grad():
                fake_imgs = generator(torch.randn(64, 100)).view(-1, 1, 28, 28)
                grid = torchvision.utils.make_grid(fake_imgs, nrow=8, normalize=True)
                plt.imshow(grid.permute(1, 2, 0).cpu())
                plt.title(f'Epoch {epoch+1}')
                plt.show()
generator = Generator()
discriminator = Discriminator()
train_gan(generator, discriminator, dataloader)
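After training, the generator alone is enough to produce new digits. A possible way to save it and sample from it is sketched below; the file name generator.pth is an assumption, not something from the original post:

# save only the generator's weights (file name is an assumption)
torch.save(generator.state_dict(), 'generator.pth')

# reload the weights and sample 16 new digits
sampler = Generator()
sampler.load_state_dict(torch.load('generator.pth'))
sampler.eval()
with torch.no_grad():
    samples = sampler(torch.randn(16, 100)).view(-1, 1, 28, 28)
grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.show()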
3. Results
Feeding a random noise vector into the generator produces the image below (result after 1 training step).
Feeding a random noise vector into the generator produces the image below (result after 10 training steps).
Extension
Models such as VAE and CGAN are also worth looking at.