flowchart LR
    A([Random noise]) --> B[Generator] --> C([Generated image])
    D([Image]) --> E[Discriminator] --> F([Prediction of realness probability])
GANs
Notes on Generative Adversarial Networks (GANs).
Imagine a forger trying to forge £20 notes and the police trying to stop them.
The police learn to spot the fakes, but then the forger improves their technique to make better fakes.
This goes back and forth. With each iteration the forger gets better, and the police then learn to spot these more sophisticated fakes.
This results in a forger (the generator) learning to create convincing fakes, and the police (the discriminator) learning to spot them.
1. GANs
The idea of GANs is that we can train two competing models:
- The generator tries to convert random noise into convincing observations.
- The discriminator tries to predict whether an observation came from the original training dataset or is a “fake”.
We initialise both as random models; the generator outputs noise and the discriminator predicts randomly. We then alternate the training of the two networks so that the generator gets incrementally better at fooling the discriminator, then the discriminator gets incrementally better at spotting fakes.
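To make the alternation concrete, here is a minimal toy sketch of one discriminator step followed by one generator step. Everything in it is an illustrative assumption (tiny Dense models, 1-D "data" drawn from N(4, 1), arbitrary learning rates), not the DCGAN built later in section 2:

import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers

latent_dim = 8
generator = models.Sequential(
    [layers.Dense(16, activation="relu", input_shape=(latent_dim,)), layers.Dense(1)]
)
discriminator = models.Sequential(
    [layers.Dense(16, activation="relu", input_shape=(1,)), layers.Dense(1, activation="sigmoid")]
)
bce = losses.BinaryCrossentropy()
d_opt, g_opt = optimizers.Adam(1e-3), optimizers.Adam(1e-3)

for step in range(1000):
    real = tf.random.normal((64, 1), mean=4.0)  # "real" observations: N(4, 1)
    noise = tf.random.normal((64, latent_dim))

    # 1. Discriminator step: real observations -> label 1, generated -> label 0
    with tf.GradientTape() as tape:
        fake = generator(noise, training=False)
        d_loss = bce(tf.ones((64, 1)), discriminator(real)) + bce(
            tf.zeros((64, 1)), discriminator(fake)
        )
    d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    # 2. Generator step: rewarded when the discriminator outputs 1 for its fakes
    with tf.GradientTape() as tape:
        g_loss = bce(tf.ones((64, 1)), discriminator(generator(noise)))
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))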
2. Building a Deep Convolutional GAN (DCGAN)
We will implement a GAN to generate pictures of bricks.
2.1. Load and pre-process the data
Load image data of lego bricks. We will train a model that can generate novel lego brick images.
tanh vs sigmoid activation
The original data is scaled from [0, 255].
Often we will rescale this to [0, 1] so that we can use sigmoid activation functions.
In this case we will scale to [-1, 1] so that we can use tanh activation functions, which tend to give stronger gradients than sigmoid.
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import (
    layers,
    models,
    callbacks,
    losses,
    utils,
    metrics,
    optimizers,
)
def sample_batch(dataset):
    batch = dataset.take(1).get_single_element()
    if isinstance(batch, tuple):
        batch = batch[0]
    return batch.numpy()
def display_images(
    images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
):
    """Displays n random images from each one of the supplied arrays."""
    if images.max() > 1.0:
        images = images / 255.0
    elif images.min() < 0.0:
        images = (images + 1.0) / 2.0

    plt.figure(figsize=size)
    for i in range(n):
        _ = plt.subplot(1, n, i + 1)
        plt.imshow(images[i].astype(as_type), cmap=cmap)
        plt.axis("off")

    if save_to:
        plt.savefig(save_to)
        print(f"\nSaved to {save_to}")

    plt.show()
Model and data parameters:
DATA_DIR = Path("/Users/gurpreetjohl/workspace/datasets/lego-brick-images")

IMAGE_SIZE = 64
CHANNELS = 1
BATCH_SIZE = 128
Z_DIM = 100
EPOCHS = 100
LOAD_MODEL = False
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.999
LEARNING_RATE = 0.0002
NOISE_PARAM = 0.1
Load and pre-process the training data:
def preprocess(img):
    """Normalize and reshape the images."""
    img = (tf.cast(img, "float32") - 127.5) / 127.5
    return img

training_data = utils.image_dataset_from_directory(
    DATA_DIR / "dataset",
    labels=None,
    color_mode="grayscale",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=42,
    interpolation="bilinear",
)
train = training_data.map(lambda x: preprocess(x))
Found 40000 files belonging to 1 classes.
Some sample input images:
display_images(sample_batch(train))
2.2. Build the Discriminator
The goal of the discriminator is to predict whether an image is real or fake.
This is a supervised binary classification problem, so we can use a CNN architecture with a single output node. We stack Conv2D layers, with BatchNormalization, LeakyReLU and Dropout layers sandwiched in between.
discriminator_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same", use_bias=False)(discriminator_input)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)

x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)

x = layers.Conv2D(256, kernel_size=4, strides=2, padding="same", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)

x = layers.Conv2D(512, kernel_size=4, strides=2, padding="same", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)

x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid", use_bias=False, activation="sigmoid")(x)
discriminator_output = layers.Flatten()(x)  # The shape is already 1x1 so no need for a Dense layer after this

discriminator = models.Model(discriminator_input, discriminator_output)
discriminator.summary()
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 64, 64, 1)] 0
conv2d_5 (Conv2D) (None, 32, 32, 64) 1024
leaky_re_lu_4 (LeakyReLU) (None, 32, 32, 64) 0
dropout_4 (Dropout) (None, 32, 32, 64) 0
conv2d_6 (Conv2D) (None, 16, 16, 128) 131072
batch_normalization_3 (Bat (None, 16, 16, 128) 512
chNormalization)
leaky_re_lu_5 (LeakyReLU) (None, 16, 16, 128) 0
dropout_5 (Dropout) (None, 16, 16, 128) 0
conv2d_7 (Conv2D) (None, 8, 8, 256) 524288
batch_normalization_4 (Bat (None, 8, 8, 256) 1024
chNormalization)
leaky_re_lu_6 (LeakyReLU) (None, 8, 8, 256) 0
dropout_6 (Dropout) (None, 8, 8, 256) 0
conv2d_8 (Conv2D) (None, 4, 4, 512) 2097152
batch_normalization_5 (Bat (None, 4, 4, 512) 2048
chNormalization)
leaky_re_lu_7 (LeakyReLU) (None, 4, 4, 512) 0
dropout_7 (Dropout) (None, 4, 4, 512) 0
conv2d_9 (Conv2D) (None, 1, 1, 1) 8192
flatten_1 (Flatten) (None, 1) 0
=================================================================
Total params: 2765312 (10.55 MB)
Trainable params: 2763520 (10.54 MB)
Non-trainable params: 1792 (7.00 KB)
_________________________________________________________________
2.3. Build the Generator
The purpose of the generator is to turn random noise into convincing images.
The input is a vector sampled from a multivariate Normal distribution, and the output is an image of the same size as the training data.
The discriminator-generator relationship in a GAN is similar to the encoder-decoder relationship in a VAE.
The architecture of the generator is similar to the discriminator but in reverse (like a decoder). We stack Conv2DTranspose layers, with BatchNormalization and LeakyReLU layers sandwiched in between.
Conv2DTranspose
We use Conv2DTranspose layers to scale the image size up.
An alternative would be to use stacks of UpSampling2D and Conv2D layers, i.e. the following serves the same purpose as a Conv2DTranspose layer:

x = layers.UpSampling2D(size=2)(x)
x = layers.Conv2D(256, kernel_size=4, strides=1, padding="same")(x)
The UpSampling2D layer simply repeats each row and column to double the size, then the Conv2D applies a convolution. The idea is similar with Conv2DTranspose, but the extra rows and columns are filled with zeros rather than with repeated existing values.
Conv2DTranspose layers can result in checkerboard pattern artifacts. Both options are used in practice, so it is often helpful to experiment and see which gives better results.
generator_input = layers.Input(shape=(Z_DIM,))

x = layers.Reshape((1, 1, Z_DIM))(generator_input)  # Reshape the input vector so we can apply conv transpose operations to it

x = layers.Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)

x = layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)

x = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)

x = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", use_bias=False)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)

generator_output = layers.Conv2DTranspose(
    CHANNELS,
    kernel_size=4,
    strides=2,
    padding="same",
    use_bias=False,
    activation="tanh",
)(x)

generator = models.Model(generator_input, generator_output)
generator.summary()
Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 100)] 0
reshape (Reshape) (None, 1, 1, 100) 0
conv2d_transpose (Conv2DTr (None, 4, 4, 512) 819200
anspose)
batch_normalization_6 (Bat (None, 4, 4, 512) 2048
chNormalization)
leaky_re_lu_8 (LeakyReLU) (None, 4, 4, 512) 0
conv2d_transpose_1 (Conv2D (None, 8, 8, 256) 2097152
Transpose)
batch_normalization_7 (Bat (None, 8, 8, 256) 1024
chNormalization)
leaky_re_lu_9 (LeakyReLU) (None, 8, 8, 256) 0
conv2d_transpose_2 (Conv2D (None, 16, 16, 128) 524288
Transpose)
batch_normalization_8 (Bat (None, 16, 16, 128) 512
chNormalization)
leaky_re_lu_10 (LeakyReLU) (None, 16, 16, 128) 0
conv2d_transpose_3 (Conv2D (None, 32, 32, 64) 131072
Transpose)
batch_normalization_9 (Bat (None, 32, 32, 64) 256
chNormalization)
leaky_re_lu_11 (LeakyReLU) (None, 32, 32, 64) 0
conv2d_transpose_4 (Conv2D (None, 64, 64, 1) 1024
Transpose)
=================================================================
Total params: 3576576 (13.64 MB)
Trainable params: 3574656 (13.64 MB)
Non-trainable params: 1920 (7.50 KB)
_________________________________________________________________
2.4. Train the GAN
We alternate between training the discriminator and generator. They are not trained simultaneously. We want the generated images to be predicted close to 1 because the generator is good, not because the discriminator is weak.
For the discriminator, we create a training set where some images are real images from the training data and some are outputs from the generator. This is then a supervised binary classification problem.
For the generator, we want a way of scoring each generated image on its realness so that we can optimise this. The discriminator provides exactly this. We pass the generated images through the discriminator to get probabilities. The generator wants to fool the discriminator, so ideally this would be a vector of 1s. So the loss function is the binary crossentropy between these probabilities and a vector of 1s.
class DCGAN(models.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(DCGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer):
        super(DCGAN, self).compile()
        self.loss_fn = losses.BinaryCrossentropy()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_metric = metrics.Mean(name="d_loss")
        self.d_real_acc_metric = metrics.BinaryAccuracy(name="d_real_acc")
        self.d_fake_acc_metric = metrics.BinaryAccuracy(name="d_fake_acc")
        self.d_acc_metric = metrics.BinaryAccuracy(name="d_acc")
        self.g_loss_metric = metrics.Mean(name="g_loss")
        self.g_acc_metric = metrics.BinaryAccuracy(name="g_acc")

    @property
    def metrics(self):
        return [
            self.d_loss_metric,
            self.d_real_acc_metric,
            self.d_fake_acc_metric,
            self.d_acc_metric,
            self.g_loss_metric,
            self.g_acc_metric,
        ]

    def train_step(self, real_images):
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate fake images from the latent vectors
            generated_images = self.generator(
                random_latent_vectors, training=True
            )

            # Evaluate the discriminator on the real and fake images
            real_predictions = self.discriminator(real_images, training=True)
            fake_predictions = self.discriminator(generated_images, training=True)

            # Add random noise to the labels to weaken the discriminator
            real_labels = tf.ones_like(real_predictions)
            real_noisy_labels = real_labels + NOISE_PARAM * tf.random.uniform(
                tf.shape(real_predictions)
            )
            fake_labels = tf.zeros_like(fake_predictions)
            fake_noisy_labels = fake_labels - NOISE_PARAM * tf.random.uniform(
                tf.shape(fake_predictions)
            )

            # Calculate the losses
            d_real_loss = self.loss_fn(real_noisy_labels, real_predictions)
            d_fake_loss = self.loss_fn(fake_noisy_labels, fake_predictions)
            d_loss = (d_real_loss + d_fake_loss) / 2.0

            g_loss = self.loss_fn(real_labels, fake_predictions)

        # Update gradients
        gradients_of_discriminator = disc_tape.gradient(
            d_loss, self.discriminator.trainable_variables
        )
        gradients_of_generator = gen_tape.gradient(
            g_loss, self.generator.trainable_variables
        )
        self.d_optimizer.apply_gradients(
            zip(gradients_of_discriminator, self.discriminator.trainable_variables)
        )
        self.g_optimizer.apply_gradients(
            zip(gradients_of_generator, self.generator.trainable_variables)
        )

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.d_real_acc_metric.update_state(real_labels, real_predictions)
        self.d_fake_acc_metric.update_state(fake_labels, fake_predictions)
        self.d_acc_metric.update_state(
            [real_labels, fake_labels], [real_predictions, fake_predictions]
        )
        self.g_loss_metric.update_state(g_loss)
        self.g_acc_metric.update_state(real_labels, fake_predictions)

        return {m.name: m.result() for m in self.metrics}
# Create a DCGAN
dcgan = DCGAN(
    discriminator=discriminator, generator=generator, latent_dim=Z_DIM
)

dcgan.compile(
    d_optimizer=optimizers.legacy.Adam(
        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
    ),
    g_optimizer=optimizers.legacy.Adam(
        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
    ),
)

dcgan.fit(train, epochs=EPOCHS)
Epoch 1/100
313/313 [==============================] - 459s 1s/step - d_loss: 0.0252 - d_real_acc: 0.9011 - d_fake_acc: 0.9013 - d_acc: 0.9012 - g_loss: 5.3464 - g_acc: 0.0987
Epoch 2/100
313/313 [==============================] - 459s 1s/step - d_loss: 0.0454 - d_real_acc: 0.8986 - d_fake_acc: 0.8997 - d_acc: 0.8992 - g_loss: 5.2642 - g_acc: 0.1002
Epoch 3/100
313/313 [==============================] - 458s 1s/step - d_loss: 0.0556 - d_real_acc: 0.8958 - d_fake_acc: 0.8975 - d_acc: 0.8967 - g_loss: 4.9679 - g_acc: 0.1025
Epoch 4/100
313/313 [==============================] - 467s 1s/step - d_loss: 0.0246 - d_real_acc: 0.9065 - d_fake_acc: 0.9091 - d_acc: 0.9078 - g_loss: 5.1611 - g_acc: 0.0909
Epoch 5/100
313/313 [==============================] - 470s 1s/step - d_loss: 0.0178 - d_real_acc: 0.9067 - d_fake_acc: 0.9088 - d_acc: 0.9078 - g_loss: 5.1731 - g_acc: 0.0912
Epoch 6/100
313/313 [==============================] - 463s 1s/step - d_loss: 0.0314 - d_real_acc: 0.9116 - d_fake_acc: 0.9105 - d_acc: 0.9110 - g_loss: 5.2774 - g_acc: 0.0895
Epoch 7/100
313/313 [==============================] - 559s 2s/step - d_loss: 0.0229 - d_real_acc: 0.9085 - d_fake_acc: 0.9079 - d_acc: 0.9082 - g_loss: 5.3445 - g_acc: 0.0921
Epoch 8/100
313/313 [==============================] - 467s 1s/step - d_loss: -0.0155 - d_real_acc: 0.9161 - d_fake_acc: 0.9161 - d_acc: 0.9161 - g_loss: 5.7091 - g_acc: 0.0839
Epoch 9/100
313/313 [==============================] - 438s 1s/step - d_loss: -0.0077 - d_real_acc: 0.9220 - d_fake_acc: 0.9224 - d_acc: 0.9222 - g_loss: 5.8731 - g_acc: 0.0776
Epoch 10/100
313/313 [==============================] - 468s 1s/step - d_loss: -0.0472 - d_real_acc: 0.9228 - d_fake_acc: 0.9241 - d_acc: 0.9234 - g_loss: 5.9693 - g_acc: 0.0759
Epoch 11/100
313/313 [==============================] - 430s 1s/step - d_loss: -0.0839 - d_real_acc: 0.9404 - d_fake_acc: 0.9424 - d_acc: 0.9414 - g_loss: 6.1212 - g_acc: 0.0576
Epoch 12/100
313/313 [==============================] - 457s 1s/step - d_loss: 0.0431 - d_real_acc: 0.9053 - d_fake_acc: 0.9046 - d_acc: 0.9049 - g_loss: 6.0708 - g_acc: 0.0954
Epoch 13/100
313/313 [==============================] - 448s 1s/step - d_loss: -0.0154 - d_real_acc: 0.9236 - d_fake_acc: 0.9244 - d_acc: 0.9240 - g_loss: 6.3106 - g_acc: 0.0756
Epoch 14/100
313/313 [==============================] - 432s 1s/step - d_loss: -0.0720 - d_real_acc: 0.9320 - d_fake_acc: 0.9342 - d_acc: 0.9331 - g_loss: 6.6509 - g_acc: 0.0658
Epoch 15/100
313/313 [==============================] - 443s 1s/step - d_loss: 0.0057 - d_real_acc: 0.9072 - d_fake_acc: 0.9097 - d_acc: 0.9085 - g_loss: 6.1399 - g_acc: 0.0903
Epoch 16/100
313/313 [==============================] - 427s 1s/step - d_loss: 0.0069 - d_real_acc: 0.9185 - d_fake_acc: 0.9167 - d_acc: 0.9176 - g_loss: 6.3255 - g_acc: 0.0833
Epoch 17/100
313/313 [==============================] - 439s 1s/step - d_loss: -0.0709 - d_real_acc: 0.9220 - d_fake_acc: 0.9239 - d_acc: 0.9230 - g_loss: 6.8108 - g_acc: 0.0761
Epoch 18/100
313/313 [==============================] - 437s 1s/step - d_loss: 0.0373 - d_real_acc: 0.9288 - d_fake_acc: 0.9512 - d_acc: 0.9400 - g_loss: 8.1066 - g_acc: 0.0488
Epoch 19/100
313/313 [==============================] - 1029s 3s/step - d_loss: -0.1154 - d_real_acc: 0.9408 - d_fake_acc: 0.9420 - d_acc: 0.9414 - g_loss: 7.6274 - g_acc: 0.0580
Epoch 20/100
313/313 [==============================] - 5781s 19s/step - d_loss: -0.0431 - d_real_acc: 0.9222 - d_fake_acc: 0.9231 - d_acc: 0.9227 - g_loss: 7.1953 - g_acc: 0.0769
Epoch 21/100
313/313 [==============================] - 2696s 9s/step - d_loss: -0.0542 - d_real_acc: 0.9176 - d_fake_acc: 0.9205 - d_acc: 0.9191 - g_loss: 7.1675 - g_acc: 0.0794
Epoch 22/100
313/313 [==============================] - 1481s 5s/step - d_loss: -0.1424 - d_real_acc: 0.9398 - d_fake_acc: 0.9400 - d_acc: 0.9399 - g_loss: 7.7399 - g_acc: 0.0600
Epoch 23/100
313/313 [==============================] - 439s 1s/step - d_loss: -0.0263 - d_real_acc: 0.9154 - d_fake_acc: 0.9148 - d_acc: 0.9151 - g_loss: 7.3322 - g_acc: 0.0852
Epoch 24/100
313/313 [==============================] - 443s 1s/step - d_loss: -0.1420 - d_real_acc: 0.9406 - d_fake_acc: 0.9430 - d_acc: 0.9418 - g_loss: 8.1502 - g_acc: 0.0570
Epoch 25/100
313/313 [==============================] - 475s 2s/step - d_loss: -0.1650 - d_real_acc: 0.9393 - d_fake_acc: 0.9395 - d_acc: 0.9394 - g_loss: 7.9602 - g_acc: 0.0605
Epoch 26/100
313/313 [==============================] - 447s 1s/step - d_loss: -0.1247 - d_real_acc: 0.9419 - d_fake_acc: 0.9437 - d_acc: 0.9428 - g_loss: 7.7939 - g_acc: 0.0563
Epoch 27/100
313/313 [==============================] - 1356s 4s/step - d_loss: -0.0182 - d_real_acc: 0.9160 - d_fake_acc: 0.9212 - d_acc: 0.9186 - g_loss: 6.9180 - g_acc: 0.0787
Epoch 28/100
313/313 [==============================] - 2219s 7s/step - d_loss: -0.2227 - d_real_acc: 0.9511 - d_fake_acc: 0.9519 - d_acc: 0.9515 - g_loss: 8.1970 - g_acc: 0.0481
Epoch 29/100
313/313 [==============================] - 5807s 19s/step - d_loss: -0.1091 - d_real_acc: 0.9318 - d_fake_acc: 0.9320 - d_acc: 0.9319 - g_loss: 7.5829 - g_acc: 0.0680
Epoch 30/100
313/313 [==============================] - 2511s 8s/step - d_loss: -0.3131 - d_real_acc: 0.9571 - d_fake_acc: 0.9604 - d_acc: 0.9588 - g_loss: 9.8839 - g_acc: 0.0395
Epoch 31/100
313/313 [==============================] - 2768s 9s/step - d_loss: -0.0996 - d_real_acc: 0.9269 - d_fake_acc: 0.9286 - d_acc: 0.9277 - g_loss: 8.3337 - g_acc: 0.0714
Epoch 32/100
313/313 [==============================] - 3046s 10s/step - d_loss: -0.1619 - d_real_acc: 0.9423 - d_fake_acc: 0.9482 - d_acc: 0.9453 - g_loss: 8.2435 - g_acc: 0.0518
Epoch 33/100
313/313 [==============================] - 3478s 11s/step - d_loss: -0.1182 - d_real_acc: 0.9284 - d_fake_acc: 0.9304 - d_acc: 0.9294 - g_loss: 8.1681 - g_acc: 0.0696
Epoch 34/100
313/313 [==============================] - 2776s 9s/step - d_loss: -0.2214 - d_real_acc: 0.9459 - d_fake_acc: 0.9582 - d_acc: 0.9520 - g_loss: 9.9168 - g_acc: 0.0417
Epoch 35/100
313/313 [==============================] - 2724s 9s/step - d_loss: -0.3101 - d_real_acc: 0.9421 - d_fake_acc: 0.9293 - d_acc: 0.9357 - g_loss: 13.2857 - g_acc: 0.0707
Epoch 36/100
313/313 [==============================] - 2648s 8s/step - d_loss: -0.0441 - d_real_acc: 0.8963 - d_fake_acc: 0.8961 - d_acc: 0.8962 - g_loss: 7.5664 - g_acc: 0.1038
Epoch 37/100
313/313 [==============================] - 3262s 10s/step - d_loss: -0.0859 - d_real_acc: 0.9314 - d_fake_acc: 0.9402 - d_acc: 0.9358 - g_loss: 8.3591 - g_acc: 0.0598
Epoch 38/100
313/313 [==============================] - 2612s 8s/step - d_loss: -0.2979 - d_real_acc: 0.9554 - d_fake_acc: 0.9577 - d_acc: 0.9566 - g_loss: 9.2534 - g_acc: 0.0423
Epoch 39/100
313/313 [==============================] - 2235s 7s/step - d_loss: -0.3387 - d_real_acc: 0.9607 - d_fake_acc: 0.9622 - d_acc: 0.9615 - g_loss: 9.9397 - g_acc: 0.0378
Epoch 40/100
313/313 [==============================] - 3453s 11s/step - d_loss: -0.1056 - d_real_acc: 0.9279 - d_fake_acc: 0.9310 - d_acc: 0.9294 - g_loss: 8.9394 - g_acc: 0.0690
Epoch 41/100
313/313 [==============================] - 2316s 7s/step - d_loss: -0.2147 - d_real_acc: 0.9318 - d_fake_acc: 0.9327 - d_acc: 0.9323 - g_loss: 9.0337 - g_acc: 0.0673
Epoch 42/100
313/313 [==============================] - 3134s 10s/step - d_loss: -0.2554 - d_real_acc: 0.9511 - d_fake_acc: 0.9540 - d_acc: 0.9526 - g_loss: 9.5571 - g_acc: 0.0460
Epoch 43/100
313/313 [==============================] - 3933s 13s/step - d_loss: -0.2871 - d_real_acc: 0.9490 - d_fake_acc: 0.9526 - d_acc: 0.9508 - g_loss: 10.3316 - g_acc: 0.0474
Epoch 44/100
313/313 [==============================] - 3248s 10s/step - d_loss: -0.3456 - d_real_acc: 0.9635 - d_fake_acc: 0.9635 - d_acc: 0.9635 - g_loss: 9.8675 - g_acc: 0.0364
Epoch 45/100
313/313 [==============================] - 3043s 10s/step - d_loss: -0.3274 - d_real_acc: 0.9603 - d_fake_acc: 0.9618 - d_acc: 0.9610 - g_loss: 10.4185 - g_acc: 0.0382
Epoch 46/100
313/313 [==============================] - 2706s 9s/step - d_loss: -0.6160 - d_real_acc: 0.9902 - d_fake_acc: 0.9908 - d_acc: 0.9905 - g_loss: 13.0574 - g_acc: 0.0092
Epoch 47/100
313/313 [==============================] - 2453s 8s/step - d_loss: 0.3413 - d_real_acc: 0.8073 - d_fake_acc: 0.8054 - d_acc: 0.8064 - g_loss: 6.5391 - g_acc: 0.1946
Epoch 48/100
313/313 [==============================] - 2898s 9s/step - d_loss: -0.4416 - d_real_acc: 0.9764 - d_fake_acc: 0.9784 - d_acc: 0.9774 - g_loss: 10.8318 - g_acc: 0.0216
Epoch 49/100
313/313 [==============================] - 3358s 11s/step - d_loss: 6.8776 - d_real_acc: 0.1058 - d_fake_acc: 0.9910 - d_acc: 0.5484 - g_loss: 14.8921 - g_acc: 0.0090
Epoch 50/100
313/313 [==============================] - 2940s 9s/step - d_loss: 7.7113 - d_real_acc: 0.0000e+00 - d_fake_acc: 1.0000 - d_acc: 0.5000 - g_loss: 15.4250 - g_acc: 0.0000e+00
Epoch 51/100
313/313 [==============================] - 2983s 10s/step - d_loss: 7.7121 - d_real_acc: 0.0000e+00 - d_fake_acc: 1.0000 - d_acc: 0.5000 - g_loss: 15.4250 - g_acc: 0.0000e+00
Epoch 52/100
313/313 [==============================] - 3458s 11s/step - d_loss: 7.7149 - d_real_acc: 0.0000e+00 - d_fake_acc: 1.0000 - d_acc: 0.5000 - g_loss: 15.4250 - g_acc: 0.0000e+00
Epoch 53/100
313/313 [==============================] - 2404s 8s/step - d_loss: 4.4950 - d_real_acc: 0.3543 - d_fake_acc: 0.8865 - d_acc: 0.6204 - g_loss: 10.6772 - g_acc: 0.1135
Epoch 54/100
313/313 [==============================] - 3297s 11s/step - d_loss: -0.0132 - d_real_acc: 0.9068 - d_fake_acc: 0.9010 - d_acc: 0.9039 - g_loss: 7.9660 - g_acc: 0.0990
Epoch 55/100
313/313 [==============================] - 2486s 8s/step - d_loss: -0.3508 - d_real_acc: 0.9615 - d_fake_acc: 0.9612 - d_acc: 0.9614 - g_loss: 10.2242 - g_acc: 0.0388
Epoch 56/100
313/313 [==============================] - 2995s 10s/step - d_loss: -0.3125 - d_real_acc: 0.9525 - d_fake_acc: 0.9533 - d_acc: 0.9529 - g_loss: 10.4182 - g_acc: 0.0467
Epoch 57/100
313/313 [==============================] - 1791s 6s/step - d_loss: -0.3201 - d_real_acc: 0.9532 - d_fake_acc: 0.9560 - d_acc: 0.9546 - g_loss: 10.4752 - g_acc: 0.0441
Epoch 58/100
313/313 [==============================] - 2792s 9s/step - d_loss: -0.2649 - d_real_acc: 0.9509 - d_fake_acc: 0.9532 - d_acc: 0.9520 - g_loss: 9.3587 - g_acc: 0.0468
Epoch 59/100
313/313 [==============================] - 3665s 12s/step - d_loss: -0.1747 - d_real_acc: 0.9413 - d_fake_acc: 0.9584 - d_acc: 0.9499 - g_loss: 10.1369 - g_acc: 0.0416
Epoch 60/100
313/313 [==============================] - 2493s 8s/step - d_loss: -0.2692 - d_real_acc: 0.9499 - d_fake_acc: 0.9534 - d_acc: 0.9517 - g_loss: 9.7124 - g_acc: 0.0466
Epoch 61/100
313/313 [==============================] - 2293s 7s/step - d_loss: -0.2869 - d_real_acc: 0.9520 - d_fake_acc: 0.9556 - d_acc: 0.9538 - g_loss: 10.2684 - g_acc: 0.0444
Epoch 62/100
313/313 [==============================] - 2865s 9s/step - d_loss: -0.6188 - d_real_acc: 0.9900 - d_fake_acc: 0.9901 - d_acc: 0.9901 - g_loss: 12.9584 - g_acc: 0.0099
Epoch 63/100
313/313 [==============================] - 2301s 7s/step - d_loss: -0.7197 - d_real_acc: 0.9984 - d_fake_acc: 0.9985 - d_acc: 0.9985 - g_loss: 14.5762 - g_acc: 0.0015
Epoch 64/100
313/313 [==============================] - 2404s 8s/step - d_loss: -0.4320 - d_real_acc: 0.9665 - d_fake_acc: 0.9702 - d_acc: 0.9683 - g_loss: 12.0177 - g_acc: 0.0298
Epoch 65/100
313/313 [==============================] - 4723s 15s/step - d_loss: 6.7591 - d_real_acc: 0.9940 - d_fake_acc: 0.1077 - d_acc: 0.5509 - g_loss: 1.4402 - g_acc: 0.8923
Epoch 66/100
313/313 [==============================] - 2726s 9s/step - d_loss: 7.6259 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 67/100
313/313 [==============================] - 3325s 11s/step - d_loss: 7.6250 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 68/100
313/313 [==============================] - 2513s 8s/step - d_loss: 7.6251 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 7.4387e-12 - g_acc: 1.0000
Epoch 69/100
313/313 [==============================] - 3304s 11s/step - d_loss: 3.4076 - d_real_acc: 0.8552 - d_fake_acc: 0.5472 - d_acc: 0.7012 - g_loss: 5.2072 - g_acc: 0.4528
Epoch 70/100
313/313 [==============================] - 2276s 7s/step - d_loss: -0.2424 - d_real_acc: 0.9448 - d_fake_acc: 0.9581 - d_acc: 0.9514 - g_loss: 10.5206 - g_acc: 0.0419
Epoch 71/100
313/313 [==============================] - 3315s 11s/step - d_loss: -0.3093 - d_real_acc: 0.9563 - d_fake_acc: 0.9640 - d_acc: 0.9602 - g_loss: 10.6076 - g_acc: 0.0360
Epoch 72/100
313/313 [==============================] - 2306s 7s/step - d_loss: -0.2440 - d_real_acc: 0.9466 - d_fake_acc: 0.9581 - d_acc: 0.9523 - g_loss: 10.1996 - g_acc: 0.0419
Epoch 73/100
313/313 [==============================] - 3218s 10s/step - d_loss: -0.7206 - d_real_acc: 0.9985 - d_fake_acc: 0.9983 - d_acc: 0.9984 - g_loss: 14.6350 - g_acc: 0.0017
Epoch 74/100
313/313 [==============================] - 3258s 10s/step - d_loss: -0.6281 - d_real_acc: 0.9828 - d_fake_acc: 0.9841 - d_acc: 0.9834 - g_loss: 14.5219 - g_acc: 0.0159
Epoch 75/100
313/313 [==============================] - 2874s 9s/step - d_loss: -0.0555 - d_real_acc: 0.9254 - d_fake_acc: 0.9371 - d_acc: 0.9312 - g_loss: 9.3443 - g_acc: 0.0629
Epoch 76/100
313/313 [==============================] - 2559s 8s/step - d_loss: -0.2825 - d_real_acc: 0.9515 - d_fake_acc: 0.9611 - d_acc: 0.9563 - g_loss: 10.7583 - g_acc: 0.0388
Epoch 77/100
313/313 [==============================] - 3663s 12s/step - d_loss: -0.3945 - d_real_acc: 0.9667 - d_fake_acc: 0.9691 - d_acc: 0.9679 - g_loss: 10.8566 - g_acc: 0.0309
Epoch 78/100
313/313 [==============================] - 2314s 7s/step - d_loss: -0.3953 - d_real_acc: 0.9508 - d_fake_acc: 0.9529 - d_acc: 0.9519 - g_loss: 12.0037 - g_acc: 0.0471
Epoch 79/100
313/313 [==============================] - 2816s 9s/step - d_loss: -0.6059 - d_real_acc: 0.9841 - d_fake_acc: 0.9838 - d_acc: 0.9840 - g_loss: 13.3516 - g_acc: 0.0162
Epoch 80/100
313/313 [==============================] - 3232s 10s/step - d_loss: -0.3555 - d_real_acc: 0.9587 - d_fake_acc: 0.9649 - d_acc: 0.9618 - g_loss: 11.2453 - g_acc: 0.0351
Epoch 81/100
313/313 [==============================] - 3705s 12s/step - d_loss: -0.4501 - d_real_acc: 0.9731 - d_fake_acc: 0.9743 - d_acc: 0.9737 - g_loss: 11.4553 - g_acc: 0.0258
Epoch 82/100
313/313 [==============================] - 2364s 8s/step - d_loss: -0.3827 - d_real_acc: 0.9588 - d_fake_acc: 0.9639 - d_acc: 0.9613 - g_loss: 11.6275 - g_acc: 0.0361
Epoch 83/100
313/313 [==============================] - 1122s 4s/step - d_loss: -0.4355 - d_real_acc: 0.9642 - d_fake_acc: 0.9674 - d_acc: 0.9658 - g_loss: 12.1025 - g_acc: 0.0326
Epoch 84/100
313/313 [==============================] - 4065s 13s/step - d_loss: -0.4456 - d_real_acc: 0.9695 - d_fake_acc: 0.9714 - d_acc: 0.9704 - g_loss: 11.6065 - g_acc: 0.0287
Epoch 85/100
313/313 [==============================] - 4461s 14s/step - d_loss: -0.6405 - d_real_acc: 0.9901 - d_fake_acc: 0.9899 - d_acc: 0.9900 - g_loss: 13.4694 - g_acc: 0.0101
Epoch 86/100
313/313 [==============================] - 2630s 8s/step - d_loss: -0.6431 - d_real_acc: 0.9857 - d_fake_acc: 0.9856 - d_acc: 0.9857 - g_loss: 14.3623 - g_acc: 0.0144
Epoch 87/100
313/313 [==============================] - 2567s 8s/step - d_loss: -0.3870 - d_real_acc: 0.9534 - d_fake_acc: 0.9578 - d_acc: 0.9556 - g_loss: 12.0201 - g_acc: 0.0422
Epoch 88/100
313/313 [==============================] - 2597s 8s/step - d_loss: -0.7624 - d_real_acc: 0.9999 - d_fake_acc: 0.9998 - d_acc: 0.9998 - g_loss: 15.3547 - g_acc: 2.5000e-04
Epoch 89/100
313/313 [==============================] - 1477s 5s/step - d_loss: -0.5787 - d_real_acc: 0.9764 - d_fake_acc: 0.9759 - d_acc: 0.9762 - g_loss: 13.8500 - g_acc: 0.0241
Epoch 90/100
313/313 [==============================] - 522s 2s/step - d_loss: -0.6747 - d_real_acc: 0.9885 - d_fake_acc: 0.9897 - d_acc: 0.9891 - g_loss: 14.5329 - g_acc: 0.0104
Epoch 91/100
313/313 [==============================] - 512s 2s/step - d_loss: 6.4703 - d_real_acc: 0.9892 - d_fake_acc: 0.1438 - d_acc: 0.5665 - g_loss: 2.1184 - g_acc: 0.8562
Epoch 92/100
313/313 [==============================] - 514s 2s/step - d_loss: 7.6245 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 93/100
313/313 [==============================] - 533s 2s/step - d_loss: 7.6249 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 94/100
313/313 [==============================] - 499s 2s/step - d_loss: 7.6236 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 95/100
313/313 [==============================] - 483s 2s/step - d_loss: 7.6240 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 96/100
313/313 [==============================] - 481s 2s/step - d_loss: 7.6248 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 97/100
313/313 [==============================] - 488s 2s/step - d_loss: 7.6247 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 98/100
313/313 [==============================] - 481s 2s/step - d_loss: 7.6263 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 99/100
313/313 [==============================] - 459s 1s/step - d_loss: 7.6235 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
Epoch 100/100
313/313 [==============================] - 4669s 15s/step - d_loss: 7.6231 - d_real_acc: 1.0000 - d_fake_acc: 0.0000e+00 - d_acc: 0.5000 - g_loss: 0.0000e+00 - g_acc: 1.0000
<keras.src.callbacks.History at 0x10f686690>
# Save the final models
generator.save("./models/generator")
discriminator.save("./models/discriminator")
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./models/generator/assets
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: ./models/discriminator/assets
3. Analysing the GAN
We can see some examples of images produced by the GAN.
(I don’t have a GPU so training is slow and I only trained 100 epochs… they’re a bit crap)
# Sample some points in the latent space, from the standard normal distribution
grid_width, grid_height = (10, 3)
z_sample = np.random.normal(size=(grid_width * grid_height, Z_DIM))

# Decode the sampled points
reconstructions = generator.predict(z_sample)

# Draw a plot of decoded images
fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Output the grid of images
for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
We also want to make sure a generative model doesn’t simply recreate images that are already in the training set.
As a sanity check, we plot some generated images and the closest training images (using the L1 distance). This confirms that the generator is able to understand high-level features, even though we didn’t provide anything other than raw pixels, and it can generate examples distinct from those encountered before.
def compare_images(img1, img2):
    return np.mean(np.abs(img1 - img2))

all_data = []
for i in train.as_numpy_iterator():
    all_data.extend(i)
all_data = np.array(all_data)
# Plot the images
r, c = 3, 5
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Generated images", fontsize=20)

noise = np.random.normal(size=(r * c, Z_DIM))
gen_imgs = generator.predict(noise)

cnt = 0
for i in range(r):
    for j in range(c):
        axs[i, j].imshow(gen_imgs[cnt], cmap="gray_r")
        axs[i, j].axis("off")
        cnt += 1

plt.show()
1/1 [==============================] - 0s 80ms/step
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Closest images in the training set", fontsize=20)

cnt = 0
for i in range(r):
    for j in range(c):
        c_diff = 99999
        c_img = None
        for k_idx, k in enumerate(all_data):
            diff = compare_images(gen_imgs[cnt], k)
            if diff < c_diff:
                c_img = np.copy(k)
                c_diff = diff
        axs[i, j].imshow(c_img, cmap="gray_r")
        axs[i, j].axis("off")
        cnt += 1

plt.show()
4. GAN Training Tips
GANs are notoriously difficult to train because there is a balancing act between the generator and discriminator; neither should grow so strong that it overpowers the other.
4.1. Discriminator Overpowers Generator
If the discriminator can always spot the fakes, the signal from the loss function becomes too weak to cause any meaningful improvement in the generator.
In the extreme case, the discriminator distinguishes fakes perfectly, so gradients vanish and no training takes place.
We need to weaken the discriminator in this case. Some possible options are:
- More dropout
- Lower learning rate
- Simplify the discriminator architecture - use fewer layers
- Add noise to the labels when training the discriminator
- Add intentional labelling errors: randomly flip the labels of some images when training the discriminator (see the sketch below)
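For example, label flipping can be done in a couple of lines. This is a sketch; flip_labels and flip_prob are my own illustrative names, not part of the DCGAN implementation above (which uses label noise via NOISE_PARAM instead):

import tensorflow as tf

# Sketch: randomly flip a fraction of the discriminator's labels.
# `labels` is a tensor of 0s and 1s; `flip_prob` is the fraction to corrupt.
def flip_labels(labels, flip_prob=0.05):
    flip_mask = tf.random.uniform(tf.shape(labels)) < flip_prob
    return tf.where(flip_mask, 1.0 - labels, labels)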
4.2. Generator Overpowers Discriminator
If the discriminator is too weak, the generator will learn that it can trick the discriminator using a small sample of nearly identical images. This is known as mode collapse. The generator would map every point in the latent space to this image, so the gradients of the loss function would vanish and it would not be able to recover.
Strengthening the discriminator would not help because the generator would just learn to find a different mode that fools the discriminator with no diversity; it is numb to its input.
Some possible options are:
- Strengthen the discriminator - do the opposite of the previous section
- Reduce the learning rate of both generator and discriminator
- Increase the batch size
4.3. Uninformative Loss
The value of the loss is not meaningful as a measure of the generator’s strength when training.
The loss function is relative to the discriminator, and since the discriminator is also being trained, the goalposts are constantly shifting. Also, we don’t want the loss function to reach 0 or else we may reach mode collapse as described above.
This makes GAN training difficult to monitor.
4.4. Hyperparameters
There are a lot of hyperparameters involved with GANs because we are now training two networks.
Performance is highly sensitive to these hyperparameter choices, and finding a good configuration involves a lot of trial and error.
5. Wasserstein GAN with Gradient Penalty (WGAN-GP)
The Wasserstein GAN replaces the binary cross-entropy loss function with the Wasserstein loss function in both the discriminator and the generator.
This results in two desirable properties:
- A meaningful loss metric that correlates with generator convergence. This allows for better monitoring of training.
- A more stable optimisation process.
5.1. Wasserstein Loss
Revisiting Binary Cross-Entropy Loss
First, recall the binary cross-entropy loss, which is defined as: \[ -\frac{1}{n} \sum_{i=1}^{n} \left( y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right) \]
In a regular GAN, the discriminator compares the predictions for real images \(p_i = D(x_i)\) to the response \(y_i = 1\), and it compares the predictions for generated images \(p_i = D(G(z_i))\) to the response \(y_i = 0\).
So the discriminator loss minimisation can be written as: \[ \min_{D} -\left( E_{x \sim p_X}[\log D(x)] + E_{z \sim p_Z}[\log (1 - D(G(z)))] \right) \]
The generator aims to trick the discriminator into believing the images are real, so it compares the discriminator’s predicted response to the desired response of \(y_i=1\). So the generator loss minimisation can be written as: \[ \min_{G} -E_{z \sim p_Z}[\log D(G(z))] \]
Introducing Wasserstein Loss
In contrast, the Wasserstein loss function is defined as: \[ -\frac{1}{n} \sum_{i=1}^{n} y_i p_i \]
It requires that we use \(y_i=1\) and \(y_i=-1\) as labels rather than 1 and 0. We also remove the final sigmoid activation layer, which constrained the output to the range \([0, 1]\), meaning the output can now be any real number in the range \((-\infty, \infty)\).
Because of these changes, rather than referring to a discriminator which outputs a probability, for WGANs we refer to a critic which outputs a score.
With these changes of labels to -1 and 1, the critic loss minimisation becomes: \[ \min_{D} -\left( E_{x \sim p_X}[D(x)] - E_{z \sim p_Z}[D(G(z))] \right) \]
i.e. it aims to maximise the difference in predictions between real and generated images.
The generator is still trying to trick the critic as in a regular GAN, so it wants the critic to score its generated images as highly as possible. To this end, it still compares to the desired critic response of \(y_i=1\), corresponding to the critic believing the generated image is real. So the generator loss minimisation remains similar, just without the log: \[ \min_{G} -E_{z \sim p_Z}[D(G(z))] \]
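In code, both losses reduce to simple means over the raw critic scores. This is a sketch matching the formulas above; the function names are my own:

import tensorflow as tf

def critic_w_loss(real_scores, fake_scores):
    # Minimising (mean fake score - mean real score) maximises the gap
    # between the critic's scores for real and generated images
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

def generator_w_loss(fake_scores):
    # The generator wants the critic to score its fakes as highly as possible
    return -tf.reduce_mean(fake_scores)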
5.2. Lipschitz Constraint
Allowing the critic to use the range \((-\infty, \infty)\) initially seems a bit counterintuitive - usually we want to avoid large numbers in neural networks otherwise gradients explode!
There is an additional constraint placed on the critic requiring it to be a 1-Lipschitz continuous function. This means that for any two input images, \(x_1\) and \(x_2\), it satisfies the following inequality: \[ \frac{| D(x_1) - D(x_2) |}{| x_1 - x_2 |} \le 1 \]
Recall that the critic is a function \(D\) which converts an image into a scalar prediction. The numerator is the change in predictions, and the denominator is the average pixelwise difference. So this is essentially a limit on how sharply the critic predictions are allowed to change for a given image perturbation.
An in-depth exploration of why the Wasserstein loss requires this constraint is given in the article linked in the references.
Enforcing the Lipschitz Constraint
One crude way of enforcing the constraint, suggested by the original authors (and described as “a terrible way to enforce a Lipschitz constraint”), is to clip the weights to \([-0.01, 0.01]\) after each training batch. However, this diminishes the critic’s ability to learn.
An improved approach is to introduce a gradient penalty term to penalise the gradient norm when it deviates from 1.
5.3. Gradient Penalty Loss
The gradient penalty (GP) loss term encourages the norm of the critic’s gradients towards 1, conforming to the 1-Lipschitz constraint.
The GP loss measures the squared difference between 1 and the norm of the gradient of the predictions with respect to the input images, i.e. \[ loss_{GP} = \left( \left\| \frac{\partial \, predictions}{\partial \, input} \right\| - 1 \right)^2 \]
Calculating this everywhere during the training process would be computationally intensive. Instead we evaluate it at a sample of points. To encourage a balanced mix, we use an interpolated image which is a pairwise weighted average of a real image and a generated image.
The combined loss function is then a weighted sum of the Wasserstein loss and the GP loss: \[ loss = loss_{Wasserstein} + \lambda_{GP} \cdot loss_{GP} \]
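A sketch of how the GP term can be computed, following the standard WGAN-GP recipe with one uniform interpolation weight per image; `critic` is assumed to be a Keras model taking a batch of images:

import tensorflow as tf

def gradient_penalty(critic, real_images, fake_images):
    batch_size = tf.shape(real_images)[0]
    # Interpolated images: a pairwise weighted average of real and fake
    alpha = tf.random.uniform((batch_size, 1, 1, 1))
    interpolated = alpha * real_images + (1.0 - alpha) * fake_images
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic(interpolated, training=True)
    grads = tape.gradient(scores, interpolated)
    # Per-image L2 norm of the gradients; penalise deviation from 1
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return tf.reduce_mean((norms - 1.0) ** 2)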
5.4. Training the WGAN-GP
When training WGANs, we train the critic to convergence and then train the generator. This ensures that the gradients used for the generator update are accurate.
This is a major benefit over traditional GANs, where we must balance the alternating training of generator and discriminator.
We train the critic several times between each generator training step. A typical ratio is 3-5 critic updates for each generator update.
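Putting the pieces together, a train step might look like the following. This is a sketch reusing the loss helpers above; CRITIC_STEPS and the critic/optimiser objects are assumptions, and GP_WEIGHT = 10 is the \(\lambda_{GP}\) value used in the WGAN-GP paper:

CRITIC_STEPS = 3  # assumed ratio: 3-5 critic updates per generator update is typical
GP_WEIGHT = 10.0  # lambda_GP

def wgan_gp_train_step(real_images):
    batch_size = tf.shape(real_images)[0]

    # Several critic updates per generator update
    for _ in range(CRITIC_STEPS):
        z = tf.random.normal((batch_size, Z_DIM))
        with tf.GradientTape() as tape:
            fakes = generator(z, training=True)
            c_loss = critic_w_loss(
                critic(real_images, training=True), critic(fakes, training=True)
            ) + GP_WEIGHT * gradient_penalty(critic, real_images, fakes)
        c_grads = tape.gradient(c_loss, critic.trainable_variables)
        c_optimizer.apply_gradients(zip(c_grads, critic.trainable_variables))

    # One generator update
    z = tf.random.normal((batch_size, Z_DIM))
    with tf.GradientTape() as tape:
        g_loss = generator_w_loss(critic(generator(z, training=True), training=True))
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))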
Note that batch normalisation shouldn’t be used for WGAN-GP because it creates correlation between images of a given batch, which makes the GP term less effective.
5.5. Analysing the WGAN-GP
VAEs tend to produce softer images with blurred colour boundaries, whereas GANs produce sharper, more well-defined images.
However, GANs typically take longer to train and can be more sensitive to hyperparameter choices to reach a satisfactory result.
6. Conditional GAN (CGAN)
Conditional GANs allow us to control the type of image that is generated. For example, should we generate a large or small brick? A male or female face?
6.1. CGAN Architecture
The key difference of a CGAN is that we also pass a one-hot encoded label into each network.
For the generator, we append the one-hot encoded label vector to the random noise input to the generator. For the critic, we append the one-hot encoded label vector to the image. If it has multiple channels, as in an RGB image, the label vector is repeated on each of the channels to fit the required shape.
The critic can now see the label, which means the generator needs to create images that match the generated labels. If the generated image was inconsistent with the generated label, then the critic could use this to easily distinguish the fakes.
The only change required to the architecture is to concatenate the label to the inputs of the discriminator (actual training labels) and the generator (initially randomised vectors of the correct shape).
The images and labels are unpacked during the training step.
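A sketch of the two concatenations, illustrative only: NUM_CLASSES and the exact layer wiring are assumptions based on the description above, reusing the parameters defined earlier in these notes:

NUM_CLASSES = 2  # assumed number of label classes, e.g. two brick types

# Generator: append the one-hot label to the latent input vector
gen_z_input = layers.Input(shape=(Z_DIM,))
gen_label_input = layers.Input(shape=(NUM_CLASSES,))
gen_x = layers.Concatenate(axis=-1)([gen_z_input, gen_label_input])  # shape (Z_DIM + NUM_CLASSES,)

# Critic: append the label as extra constant channels of the image
crit_image_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
crit_label_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CLASSES))
crit_x = layers.Concatenate(axis=-1)([crit_image_input, crit_label_input])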
6.2. Using the CGAN
We can control the type of image generated by passing a particular one-hot encoded label into the input of the generator.
If labels are available for your training data, it is generally a good idea to include them in the inputs even if you don’t intend to create a conditional GAN from it. They tend to improve the quality of the output generated, since the labels act as a highly informative extension of the pixel inputs.
References
- Chapter 4 of Generative Deep Learning by David Foster.
- Wasserstein loss and Lipschitz constraint