import os
import time
import matplotlib. pyplot as plt
import numpy as np
import torchvision. transforms as transforms
from torch. utils. data import DataLoader
from torchvision import datasets
import torch. nn as nn
import torch
# Make sure the local dataset directory exists before torchvision downloads
# MNIST into it.
os.makedirs("data", exist_ok=True)

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) computes
# (x - 0.5) / 0.5, mapping them to [-1, 1] -- the same range as the
# generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])

# Training split of MNIST, fetched on first use, wrapped in a shuffling
# loader that yields batches of 64 images.
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Generator architecture:
#   linear1: 100 -> 256
#   linear2: 256 -> 512
#   linear3: 512 -> 28*28
#   reshape: 28x28 -> (1, 28, 28)
class Generator(nn.Module):
    """MLP generator mapping a 100-d latent vector to a 1x28x28 image.

    The final Tanh keeps pixel values in [-1, 1], the same range as MNIST
    images after ``Normalize(0.5, 0.5)``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 28 * 28), nn.Tanh(),
        )

    def forward(self, x):
        """Map a (batch, 100) latent batch to (batch, 1, 28, 28) images."""
        img = self.model(x)
        # Channels-first (N, C, H, W), matching the documented (1, 28, 28)
        # reshape and the layout of real image batches from ToTensor.
        img = img.view(-1, 1, 28, 28)
        return img
class Discriminator(nn.Module):
    """MLP discriminator returning the probability that an image is real.

    Output: a binary-classification probability, squashed into (0, 1) by a
    Sigmoid. LeakyReLU is recommended for the discriminator because the
    generator is hard to train: ReLU turns negative values into 0, which
    leaves no gradient to flow back through them.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512), nn.LeakyReLU(),
            nn.Linear(512, 256), nn.LeakyReLU(),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        """Flatten each image to 784 values and return a (batch, 1) probability."""
        x = x.view(-1, 28 * 28)
        x = self.model(x)
        return x
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate both networks directly on the selected device.
gen = Generator().to(device)
dis = Discriminator().to(device)

# One Adam optimizer per network, both with learning rate 1e-4.
gen_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
dis_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
bce_loss = nn.BCELoss()
def gen_img_plot(model, epoch, test_input):
    """Show a 4x4 grid of images generated from ``test_input``.

    Args:
        model: generator expected to output images with values in [-1, 1]
            (e.g. the Tanh-terminated Generator in this script).
        epoch: current epoch index, shown as the figure title.
        test_input: latent batch fed to ``model``; at most the first 16
            samples are drawn.
    """
    # Plotting is inference only, so skip building an autograd graph.
    with torch.no_grad():
        prediction = model(test_input).cpu().numpy()
    # Drop singleton channel axes (N,1,28,28 or N,28,28,1 -> N,28,28) while
    # always keeping the batch axis, which np.squeeze would drop when N == 1.
    prediction = prediction.reshape(
        prediction.shape[0], *[d for d in prediction.shape[1:] if d != 1]
    )
    fig = plt.figure(figsize=(4, 4))
    fig.suptitle(f'Epoch {epoch}')
    for i in range(min(16, len(prediction))):
        plt.subplot(4, 4, i + 1)
        # Map the [-1, 1] generator output to [0, 1]; MNIST digits are grayscale.
        plt.imshow((prediction[i] + 1) / 2, cmap='gray')
        plt.axis('off')
    plt.show()
    # Release the figure so one is not left open per epoch.
    plt.close(fig)
def train(num_epoch, test_input):
    """Alternately train the discriminator and the generator.

    Uses the module-level ``dataloader``, ``gen``, ``dis``, ``dis_optim``,
    ``gen_optim``, ``bce_loss`` and ``device``. After every epoch the mean
    losses are printed and samples generated from ``test_input`` are plotted.

    Args:
        num_epoch: number of passes over ``dataloader``.
        test_input: fixed latent batch used for the per-epoch plots
            (the script passes a (16, 100) batch).

    Returns:
        ``(D_loss, G_loss)``: per-epoch mean discriminator and generator
        losses as Python floats.
    """
    D_loss = []
    G_loss = []
    count = len(dataloader)  # batches per epoch; the same for every epoch
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        for img, _ in dataloader:
            img = img.to(device)
            size = img.size(0)  # the last batch may be smaller than batch_size
            random_noise = torch.randn(size, 100, device=device)

            # 1. Train the discriminator: real images -> 1, generated -> 0.
            dis_optim.zero_grad()
            real_output = dis(img)
            d_real_loss = bce_loss(real_output, torch.ones_like(real_output))
            d_real_loss.backward()
            gen_img = gen(random_noise)
            # detach() keeps this update from backpropagating into the generator.
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output, torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss + d_fake_loss
            dis_optim.step()

            # 2. Train the generator: push the discriminator to label fakes 1.
            # Gradients this also leaves on dis parameters are cleared by
            # dis_optim.zero_grad() at the start of the next step.
            gen_optim.zero_grad()
            fake_output = dis(gen_img)
            g_loss = bce_loss(fake_output, torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            # Accumulate detached values so no autograd graph is retained;
            # conversion to a Python float happens once per epoch.
            d_epoch_loss += d_loss.detach()
            g_epoch_loss += g_loss.detach()

        d_epoch_loss = float(d_epoch_loss / count)
        g_epoch_loss = float(g_epoch_loss / count)
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch:', epoch)
        print(f'd_epoch_loss= {d_epoch_loss} ')
        print(f'g_epoch_loss= {g_epoch_loss} ')
        gen_img_plot(gen, epoch, test_input)
    return D_loss, G_loss
# Time the full training run, including the per-epoch plotting done by train().
# perf_counter is a monotonic clock, so the measured duration is not affected
# by system clock adjustments the way time.time() is.
start_time = time.perf_counter()
test_input = torch.randn(16, 100, device=device)  # fixed latent batch for progress plots
num_epoch = 50
train(num_epoch, test_input)
end_time = time.perf_counter()
run_time = end_time - start_time
# Report in seconds for short runs, otherwise in minutes.
if run_time < 60:
    print(f' {round(run_time, 2)} s')
else:
    print(f' {round(run_time / 60, 2)} minutes')