文章目录
- 0、遇到的问题
- 1、配置环境 & 导入数据
- 2、定义模型
- 3、训练模型
- 4、什么是DCGAN?
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
本文环境:
系统环境:
语言:Python 3.7.8
编辑器:VSCode
深度学习框架:torch 1.13.1
0、遇到的问题
ImportError:找不到 IPython 模块。
解决方法:在运行该程序的 Python 环境中安装该模块,例如执行 `pip install ipython`。
1、配置环境 & 导入数据
print("*****************1.1 导入第三方库***************")
import os
import random

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as utils

# IPython is only needed to render the training animation inline in a
# notebook. When the script is run directly (e.g. from VSCode) in an
# environment without IPython, a hard import aborts the whole run with
# ImportError, so fall back to a pass-through stand-in instead.
try:
    from IPython.display import HTML
except ImportError:
    def HTML(data):
        """Stand-in for IPython.display.HTML outside IPython: return *data* unchanged."""
        return data

# Seed Python's and PyTorch's RNGs so runs are reproducible.
randomSeed = 999  # random seed
print("Random Seed: ", randomSeed)
random.seed(randomSeed)
torch.manual_seed(randomSeed)
# Ask PyTorch to use deterministic implementations of its ops.
torch.use_deterministic_algorithms(True)
print("\n")
print("*****************1.2 设置超参数***************")
# Root folder of the training images.
dataroot = "D:/jupyter notebook/DL-100-days/GAN/G2/"
# Batching and image geometry: images are resized/cropped to image_size x image_size.
batch_size, image_size = 128, 64
# Network sizes: nz is the length of the latent vector fed to the generator;
# ngf / ndf are the base feature-map counts of the generator / discriminator.
nz, ngf, ndf = 100, 64, 64
# Optimisation: number of epochs, Adam learning rate and Adam beta1.
num_epochs, lr, beta1 = 50, 0.0002, 0.5
print("\n")
print("*****************1.3 导入数据***************")
# Build the dataset. Only the image tensor (index 0 of each batch) is used
# by this script; ImageFolder's class labels are ignored.
dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map [0, 1] pixel values to [-1, 1], the generator's Tanh output range.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# Build the data loader.
# Setting num_workers to 5 raised an error in the original run. On Windows,
# DataLoader workers are started with "spawn" and re-import this script, so
# worker processes only work when the top-level code is guarded by
# `if __name__ == "__main__":` -- likely the cause; confirm on your platform.
# With a flat script, keep 0 (load batches in the main process).
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=0)

# Choose the device to run on.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("使用的设备是:", device)

# Plot some training images. make_grid and imshow work on CPU tensors, so the
# batch is not moved to `device` just to be moved back again.
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
grid = utils.make_grid(real_batch[0][:24], padding=2, normalize=True)
plt.imshow(np.transpose(grid, (1, 2, 0)))
# savefig does not create missing folders; make sure the output folder exists.
os.makedirs("./GAN/G2", exist_ok=True)
plt.savefig("./GAN/G2/训练图像.png")
print("\n")
输出结果:
2、定义模型
print("*****************2. 定义模型******************")
print("*****************2.1 初始化权重***************")


def weights_init(m):
    """Re-initialise a single module in place, DCGAN style.

    Intended to be passed to ``nn.Module.apply`` so that it visits every
    submodule. If the module's class name contains "Conv" (this also matches
    ConvTranspose2d), its weight is drawn from N(0, 0.02). If the class name
    contains "BatchNorm", its weight is drawn from N(1, 0.02) and its bias is
    set to 0. Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)


print("\n")
print("*****************2.2 定义生成器***************")


class Generator(nn.Module):
    """DCGAN generator.

    Maps a latent batch shaped (N, nz, 1, 1) to RGB images shaped
    (N, 3, 64, 64) with values in [-1, 1] (Tanh output). The first transposed
    convolution (kernel 4, stride 1, no padding) turns the 1x1 input into a
    4x4 map. Each following one (kernel 4, stride 2, padding 1) doubles the
    spatial size, while the channels go ngf*8 -> ngf*4 -> ngf*2 -> ngf -> 3.
    """

    def __init__(self):
        super().__init__()
        # Output widths of the four ConvTranspose/BatchNorm/ReLU stages.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first stage projects 1x1 -> 4x4; later stages upsample by 2.
            stride, padding = (1, 0) if idx == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Output stage: project to 3 channels and squash into [-1, 1].
        layers += [nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh()]
        # Same layer order (and therefore the same indices) as listing them literally.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


netG = Generator().to(device)
netG.apply(weights_init)  # apply the DCGAN initialisation to every submodule
print("netG: ", netG)
print("\n")
print("*****************2.3 定义鉴别器***************")


class Descriminator(nn.Module):
    """DCGAN discriminator (class-name spelling kept as-is for existing callers).

    Takes RGB images shaped (N, 3, 64, 64) and returns sigmoid scores shaped
    (N, 1, 1, 1) in (0, 1). Training treats 1 as "real" and 0 as "fake".
    Each of the first four convolutions (kernel 4, stride 2, padding 1) halves
    the spatial size: 64 -> 32 -> 16 -> 8 -> 4. The last one (kernel 4,
    stride 1, no padding) collapses the 4x4 map to a single score.
    """

    def __init__(self):
        super().__init__()
        # First stage: no BatchNorm, straight to LeakyReLU.
        layers = [nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
                  nn.LeakyReLU(0.2, inplace=True)]
        # Three Conv/BatchNorm/LeakyReLU stages, doubling the channels each time.
        for mult in (1, 2, 4):
            c_in, c_out = ndf * mult, ndf * mult * 2
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Output stage: one channel squashed into (0, 1).
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


netD = Descriminator().to(device)
netD.apply(weights_init)  # apply the DCGAN initialisation to every submodule
print("netD: ", netD)
print("\n")
3、训练模型
print("*****************3 训练模型***************")
print("*****************3.1 定义训练参数***************")
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# A fixed batch of 64 latent vectors, reused to snapshot the generator's output.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# BCE targets: 1.0 marks a real image, 0.0 a generated one.
real_label, fake_label = 1., 0.
# Both networks use Adam with the same learning rate and betas.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)
print("\n")
print("*****************3.2 训练模型***************")
img_list = []  # image grids generated from fixed_noise at checkpoints
G_loss = []    # generator loss, one entry per iteration
D_loss = []    # discriminator loss, one entry per iteration
iters = 0      # global iteration counter across all epochs
print("Starting training loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # ---- (1) Update D: minimise BCE(D(x), 1) + BCE(D(G(z)), 0) ----
        netD.zero_grad()
        # Loss on an all-real batch.
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # actual batch size (a final batch may be smaller)
        label = torch.full((b_size,),
                           real_label,
                           dtype=torch.float,
                           device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        # Loss on an all-fake batch. `fake` is detached so this backward pass
        # stops at the discriminator and does not touch the generator.
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()  # gradients add to those from errD_real
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---- (2) Update G: minimise BCE(D(G(z)), 1), i.e. fool the freshly updated D ----
        netG.zero_grad()
        label.fill_(real_label)  # fakes are labelled "real" for the generator's loss
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if i % 400 == 0:
            print("[%d/%d][%d/%d]\tLossD: %.4f\tLossG: %.4f\tD(x): %.4f\tD(G(z)): %.4f/%.4f"
                  % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        G_loss.append(errG.item())
        D_loss.append(errD.item())
        # Snapshot the generator on fixed_noise every 500 iterations and at the very end.
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(fake, padding=2, normalize=True))
        iters += 1
print("\n")
训练结果打印:
print("*****************3.3 训练过程可视化***************")
# savefig does not create missing folders; make sure the output folder exists.
os.makedirs("./GAN/G2", exist_ok=True)

# Loss curves of both networks over all iterations.
plt.figure(figsize=(10, 5))
plt.title("Generator & Discriminator Loss During Training")
plt.plot(G_loss, label="G")
plt.plot(D_loss, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("./GAN/G2/训练过程可视化.png")
plt.show()

# Animation of the generator's output on fixed_noise across the snapshots.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)]
       for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# to_jshtml is a method and must be called. HTML(ani.to_jshtml) passed the
# bound method itself rather than an HTML string, which IPython's HTML rejects
# as non-text. Render once and reuse the result.
ani_html = HTML(ani.to_jshtml())

# Real images next to the generator's last snapshot. make_grid/imshow work on
# CPU tensors, so the batch is not moved to `device` first.
real_batch = next(iter(dataloader))
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real_grid = utils.make_grid(real_batch[0][:64], padding=5, normalize=True)
plt.imshow(np.transpose(real_grid, (1, 2, 0)))
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.savefig("./GAN/G2/Real & Fake images.png")
plt.show()
print("\n")
# In a Jupyter cell, a bare expression on the last line is rendered inline;
# when run as a plain script this line has no effect.
ani_html
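这里总是报错,不知道怎么回事?
(原因:`ani.to_jshtml` 是一个方法,必须加括号调用。`HTML(ani.to_jshtml)` 传入的是方法对象本身,而不是 HTML 字符串;IPython 的 `HTML` 只接受文本,因此报错。应改为 `HTML(ani.to_jshtml())`。另外,在 VSCode 中以普通脚本运行时,`HTML(...)` 对象不会显示任何内容。它只有作为 Jupyter 单元格的最后一个表达式(或通过 `display()`)时才会被渲染。)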
4、什么是DCGAN?
网络结构:
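查看网络结构可以发现,DCGAN 由生成器和鉴别器两部分组成。生成器由一串转置卷积(ConvTranspose2d + BatchNorm + ReLU)构成,把 nz 维的随机噪声逐步上采样成 64×64 的图像。鉴别器由一串卷积(Conv2d + BatchNorm + LeakyReLU)构成,把图像逐步下采样成一个 0~1 之间的“真实”概率。因此两者的结构近似互为逆过程。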
以本文为例:生成器的输入是随机噪声,而不是人脸图片。鉴别器同时接收真实的人脸图片和生成器生成的图片,并判断它们的真假。训练时两者相互对抗:鉴别器努力分辨真假,生成器努力骗过鉴别器。最终,生成器学会生成逼真的新人脸。
以上是个人理解。
更详细的说明参考:【pytorch】DCGAN实战教程(官方教程翻译)