This notebook presents a simple Generative Adversarial Network (GAN) and a more advanced Deep Convolutional GAN (DCGAN). Both are applied to the MNIST dataset.
We will create three sub-graphs for the GAN network: the generator, the discriminator, and the combined GAN model that feeds the generator's output into a frozen discriminator.
import numpy as np
import matplotlib.pyplot as plt
Limit TensorFlow GPU memory usage
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config):
    pass  # init session with allow_growth
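If you are on TensorFlow 2.x, where ConfigProto and Session no longer exist, a rough equivalent sketch is:
# TF 2.x sketch: enable per-GPU memory growth before any tensors are allocated
# for gpu in tf.config.experimental.list_physical_devices('GPU'):
#     tf.config.experimental.set_memory_growth(gpu, True)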
Load the MNIST dataset via the Keras API. We only need the training images; the labels and the test set are ignored.
(x_train_raw, _), (_, _) = tf.keras.datasets.mnist.load_data()
Convert to the -1..1 range to match the tanh output of the generator
x_train = (x_train_raw-127.5) / 127.5
x_train = x_train.reshape([len(x_train), -1])
print('x_train.shape:', x_train.shape)
print('x_train.min():', x_train.min())
print('x_train.max():', x_train.max())
print('x_train:\n', x_train)
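The scaling is trivially invertible if you ever need raw pixel values back; a quick round-trip check (not part of the original flow):
x_restored = np.rint(x_train * 127.5 + 127.5).astype(np.uint8)  # map -1..1 back to 0..255
print('round-trip ok:', (x_restored.reshape(-1, 28, 28) == x_train_raw).all())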
Generator
from tensorflow.keras.layers import InputLayer, Dense # InputLayer for prettier names in TensorBoard
generator = tf.keras.Sequential(name='Generator')
generator.add(InputLayer(input_shape=(100,), name='GenInput')) # random noise input
generator.add(Dense(units=128, activation='elu', name='GenHidden')) # one hidden layer
generator.add(Dense(784, activation='tanh', name='GenOutput')) # MNIST-like output
# generator.compile(...) # no need
generator.summary()
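As a quick sanity check, we can push a single noise vector through the untrained generator and confirm the output shape and tanh range:
z = np.random.normal(size=[1, 100])
sample = generator.predict(z)
print(sample.shape, sample.min(), sample.max())  # expect (1, 784), values inside (-1, 1)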
Optional: see the graph created so far in TensorBoard
# writer = tf.summary.FileWriter(logdir='tf_log', graph=tf.get_default_graph())
# writer.flush()
Discriminator
discriminator = tf.keras.Sequential(name='Discriminator')
discriminator.add(InputLayer(input_shape=(784,), name='DiscInput'))
discriminator.add(Dense(units=128, activation='elu', name='DiscHidden')) # input_dim not needed; InputLayer defines the shape
discriminator.add(Dense(units=1, activation='sigmoid', name='DiscOut'))
discriminator.compile(optimizer='adam', loss='binary_crossentropy')
#discriminator.compile(optimizer=tf.keras.optimizers.Adam(lr=0.002), loss='binary_crossentropy')
discriminator.summary()
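Similarly, the untrained discriminator should already return a sigmoid probability per image:
p = discriminator.predict(x_train[:2])
print(p.shape, p.ravel())  # expect (2, 1), values in (0, 1)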
Optional: see the graph created so far in TensorBoard - this should produce two independent sub-graphs, one for the generator and one for the discriminator
# writer = tf.summary.FileWriter(logdir='tf_log', graph=tf.get_default_graph())
# writer.flush()
GAN Model
discriminator.trainable = False # must make non-trainable before compiling gan_model (doesn't affect the discriminator we compiled earlier)
gan_model = tf.keras.Sequential(name='GAN')
gan_model.add(InputLayer(input_shape=(100,), name='GANInput'))
gan_model.add(generator)
gan_model.add(discriminator)
gan_model.compile(optimizer='adam', loss='binary_crossentropy')
#gan_model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.002), loss='binary_crossentropy')
gan_model.summary()
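To verify the freeze took effect, we can compare trainable vs. frozen weight tensors inside gan_model (a quick check, not in the original flow):
print('trainable weights:', len(gan_model.trainable_weights))      # expect 4 (generator kernels + biases)
print('frozen weights:   ', len(gan_model.non_trainable_weights))  # expect 4 (discriminator)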
Optional: see the graph created so far in TensorBoard - because weights are shared this will create a somewhat messy rendering of three sub-graphs
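# writer = tf.summary.FileWriter(logdir='tf_log', graph=tf.get_default_graph())
# writer.flush()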
Train GAN
n_batch = 200
n_epochs = 20 # <- set this to ~100 to see more reasonable results
losses = {'gen': [], 'disc': []}
indices = np.array(range(len(x_train)))
for e in range(n_epochs):
    np.random.shuffle(indices)
    for i in range(0, len(x_train), n_batch):
        # Generate fake images
        noise = np.random.normal(size=[n_batch, 100])          # shape (n_batch, n_rand)
        x_fake = generator.predict(noise, batch_size=n_batch)  # shape (n_batch, n_data)
        # Pick the next batch of real images
        i_batch = indices[i:i+n_batch]
        x_real = x_train[i_batch]
        # Join real and fake into one batch
        x_all = np.concatenate([x_real, x_fake])
        y_all = np.concatenate([.9 * np.ones([n_batch, 1]),   # use .9 instead of 1 as the discriminator target;
                                np.zeros([n_batch, 1])])      # this is called label smoothing and improves learning
        # Train the discriminator on the mixed real/fake batch
        discriminator.trainable = True   # gets rid of warning messages (doesn't affect training)
        dloss = discriminator.train_on_batch(x_all, y_all)  # trains only the discriminator, doesn't touch the generator
        # Train the generator through the combined model, with the discriminator frozen
        noise = np.random.normal(size=[n_batch, 100])
        y_fake = np.ones([n_batch, 1])
        discriminator.trainable = False  # gets rid of warning messages
        gloss = gan_model.train_on_batch(noise, y_fake)
        losses['disc'].append(dloss)
        losses['gen'].append(gloss)
    print(f'epoch: {e:3} gloss: {gloss:4.2f} dloss: {dloss:4.2f}')
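Optionally, the trained sub-models can be saved with the standard Keras API (filenames here are placeholders):
# generator.save('gan_mlp_generator.h5')
# discriminator.save('gan_mlp_discriminator.h5')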
Plot losses during training
plt.plot(losses['disc'], label='disc_loss')
plt.plot(losses['gen'], label='gen_loss')
plt.title('Losses')
plt.legend();
Helper to show a bunch of MNIST-like images
def show_images(x):
    fig, axes = plt.subplots(nrows=1, ncols=len(x), figsize=[20, 4])
    for i, ax in enumerate(axes):
        ax.imshow(x[i].reshape([28, 28]), cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
Show some real images
real_imgs = x_train[0:10]
show_images(real_imgs)
Show some fakes
noise = np.random.normal(size=[10, 100])
fake_imgs = generator.predict(x=noise, batch_size=10)
show_images(fake_imgs)
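A simple extra experiment, reusing show_images: linearly interpolate between two noise vectors and render the in-between fakes.
z0, z1 = np.random.normal(size=[2, 100])       # two random latent points
alphas = np.linspace(0, 1, 10).reshape(-1, 1)  # 10 interpolation steps
z_path = (1 - alphas) * z0 + alphas * z1       # shape (10, 100)
show_images(generator.predict(z_path, batch_size=10))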
Generate Graphs
This code was used to generate the model graphs in this post
# tf.keras.utils.plot_model(generator, to_file='assets/gan_mlp_generator.png', show_shapes=True)
# tf.keras.utils.plot_model(discriminator, to_file='assets/gan_mlp_discriminator.png', show_shapes=True)
# tf.keras.utils.plot_model(gan_model, to_file='assets/gan_mlp_gan_model.png', show_shapes=True)