Variational Autoencoders
Notes on VAEs.
flowchart LR
    A(Encoder) --> B(z)
    B(z) --> C(Decoder)
Imagine an infinite wardrobe organised by “type” of clothing.
Shoes would be close together, but formal shoes might be closer to the suits and trainers closer to the sports gear. Shirts and t-shirts would be close together. Coats might be nearby; the shirt->coat vector applied to t-shirts might lead you to “invent” gilets.
This encapsulates the idea of using a lower dimensional (2D in this case) latent space to encode the representation of more complex objects.
We could sample from some of the empty spaces to invent new hybrids of clothing. This generative step is decoding the latent space.
1. Autoencoders
The idea of autoencoders (read: self-encoders) is that they learn to simplify the input then reconstruct it; the input and target output are the same.
- The encoder learns to compress high-dimensional input data into a lower dimensional representation called the embedding.
- The decoder takes an embedding and recreates a higher-dimensional image. This should be an accurate reconstruction of the input.
This can be used as a generative model: we can sample new points from the latent space and decode them to generate novel outputs. The goal of training an autoencoder is to learn a meaningful embedding \(z\).
This also makes autoencoders useful as denoising models, because the embedding should retain the salient information but “lose” the noise.
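As a rough sketch of the denoising use case (hypothetical; it assumes the autoencoder, data and hyperparameters that are built later in these notes, and NOISE_FACTOR is an invented parameter), we would corrupt the inputs but keep the clean images as the target:
Code
# Hypothetical denoising setup: noisy inputs, clean targets.
NOISE_FACTOR = 0.2  # assumed value, not from the original notes

x_train_noisy = x_train + NOISE_FACTOR * np.random.normal(size=x_train.shape)
x_train_noisy = np.clip(x_train_noisy, 0.0, 1.0)

# The embedding keeps the salient structure, so reconstructing the clean
# image from its noisy version teaches the model to discard the noise.
autoencoder.fit(x_train_noisy, x_train, epochs=EPOCHS, batch_size=BATCH_SIZE)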
2. Building an Autoencoder
We will implement an autoencoder to learn lower-dimensional embeddings for the Fashion-MNIST dataset.
Code
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow as tf
from tensorflow.keras import (
layers,
models,
datasets,
callbacks,
losses,
optimizers,
metrics,
)
# Parameters
IMAGE_SIZE = 32
CHANNELS = 1
BATCH_SIZE = 100
BUFFER_SIZE = 1000
VALIDATION_SPLIT = 0.2
EMBEDDING_DIM = 2
EPOCHS = 3
2.1. Load and pre-process the data
Scale the pixel values and reshape the images.
Code
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()


def preprocess(images):
    images = images.astype("float32") / 255.0
    images = np.pad(images, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)
    images = np.expand_dims(images, -1)
    return images


x_train = preprocess(x_train)
x_test = preprocess(x_test)
We can see an example from our training set:
Code
plt.imshow(x_train[0])
2.2. Build the Encoder
The encoder compresses the dimensionality of the input down to a smaller embedding dimension.
Code
# Input
encoder_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS), name="encoder_input")

# Conv layers
x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(encoder_input)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)
pre_flatten_shape = tf.keras.backend.int_shape(x)[1:]  # Used by the decoder later

# Output
x = layers.Flatten()(x)
encoder_output = layers.Dense(EMBEDDING_DIM, name="encoder_output")(x)

# Model
encoder = models.Model(encoder_input, encoder_output)
encoder.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) [(None, 32, 32, 1)] 0
conv2d (Conv2D) (None, 16, 16, 32) 320
conv2d_1 (Conv2D) (None, 8, 8, 64) 18496
conv2d_2 (Conv2D) (None, 4, 4, 128) 73856
flatten (Flatten) (None, 2048) 0
encoder_output (Dense) (None, 2) 4098
=================================================================
Total params: 96770 (378.01 KB)
Trainable params: 96770 (378.01 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
2.3. Build the Decoder
The decoder reconstructs the original image from the embedding.
Convolutional Transpose Layers
In a standard convolutional layer with stride=2, the output feature map is half the size of the input.
In a convolutional transpose layer we are increasing the image size: the stride parameter determines how much zero padding is inserted between the pixels of the input, and a kernel is then applied to this "internally padded" image to expand it.
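To make the shape arithmetic concrete, here is a minimal sketch (untrained layers, shapes only, reusing the layers import from above):
Code
# A stride-2 convolution halves the spatial dimensions...
demo_in = tf.zeros((1, 16, 16, 8))
demo_down = layers.Conv2D(8, (3, 3), strides=2, padding="same")(demo_in)
print(demo_down.shape)  # (1, 8, 8, 8)

# ...while a stride-2 transposed convolution doubles them again.
demo_up = layers.Conv2DTranspose(8, (3, 3), strides=2, padding="same")(demo_down)
print(demo_up.shape)  # (1, 16, 16, 8)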
Code
# Input
decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name="decoder_input")

# Reshape the input using the pre-flattening shape from the encoder
x = layers.Dense(np.prod(pre_flatten_shape))(decoder_input)
x = layers.Reshape(pre_flatten_shape)(x)

# Scale the image back up to its original size. These are the reverse of the conv layers applied in the encoder.
x = layers.Conv2DTranspose(128, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)

# Output
decoder_output = layers.Conv2D(
    CHANNELS,
    (3, 3),
    strides=1,
    activation="sigmoid",
    padding="same",
    name="decoder_output",
)(x)

# Model
decoder = models.Model(decoder_input, decoder_output)
decoder.summary()
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) [(None, 2)] 0
dense (Dense) (None, 2048) 6144
reshape (Reshape) (None, 4, 4, 128) 0
conv2d_transpose (Conv2DTr (None, 8, 8, 128) 147584
anspose)
conv2d_transpose_1 (Conv2D (None, 16, 16, 64) 73792
Transpose)
conv2d_transpose_2 (Conv2D (None, 32, 32, 32) 18464
Transpose)
decoder_output (Conv2D) (None, 32, 32, 1) 289
=================================================================
Total params: 246273 (962.00 KB)
Trainable params: 246273 (962.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
2.4. Build the Autoencoder
Combine the encoder and decoder into a single model.
Code
autoencoder = models.Model(encoder_input, decoder(encoder_output))
autoencoder.summary()
Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) [(None, 32, 32, 1)] 0
conv2d (Conv2D) (None, 16, 16, 32) 320
conv2d_1 (Conv2D) (None, 8, 8, 64) 18496
conv2d_2 (Conv2D) (None, 4, 4, 128) 73856
flatten (Flatten) (None, 2048) 0
encoder_output (Dense) (None, 2) 4098
model_1 (Functional) (None, 32, 32, 1) 246273
=================================================================
Total params: 343043 (1.31 MB)
Trainable params: 343043 (1.31 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
2.5. Train the Autoencoder
The autoencoder is trained with the source images as both input and target output.
The loss function is usually chosen as either RMSE or binary cross-entropy between the pixels of the original image and its reconstruction.
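As a rough, hedged sketch of the two options (the model is still untrained at this point, so the numbers only illustrate the API, not model quality; the notes use binary cross-entropy below):
Code
# Compare the two candidate reconstruction losses on a small batch.
# Both operate on pixel values scaled to [0, 1].
batch = x_test[:16]
reconstruction = autoencoder.predict(batch)

bce = losses.BinaryCrossentropy()(batch, reconstruction).numpy()
rmse = tf.sqrt(losses.MeanSquaredError()(batch, reconstruction)).numpy()
print(f"binary cross-entropy: {bce:.4f}, RMSE: {rmse:.4f}")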
Code
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")

autoencoder.fit(
    x_train,
    x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
)
Epoch 1/3
600/600 [==============================] - 35s 58ms/step - loss: 0.2910 - val_loss: 0.2610
Epoch 2/3
600/600 [==============================] - 36s 60ms/step - loss: 0.2569 - val_loss: 0.2561
Epoch 3/3
600/600 [==============================] - 34s 57ms/step - loss: 0.2536 - val_loss: 0.2540
<keras.src.callbacks.History at 0x156e4a310>
3. Analysing the Autoencoder
We can use our trained autoencoder to:
- Reconstruct images
- Analyse embeddings
- Generate new images
3.1. Reconstruct Images Using the Autoencoder
Reconstruct a sample of test images using the autoencoder.
The reconstruction isn’t perfect; some information is lost when reducing down to just 2 dimensions. But it does a surprisingly good job of compressing 32×32 = 1,024 pixel values into just 2 embedding values.
Code
NUM_IMAGES_TO_RECONSTRUCT = 5000
example_images = x_test[:NUM_IMAGES_TO_RECONSTRUCT]
example_labels = y_test[:NUM_IMAGES_TO_RECONSTRUCT]

predictions = autoencoder.predict(example_images)
157/157 [==============================] - 1s 8ms/step
Original images:
Code
def plot_sample_images(images, n=10, size=(20, 3), cmap="gray_r"):
    plt.figure(figsize=size)
    for i in range(n):
        _ = plt.subplot(1, n, i + 1)
        plt.imshow(images[i].astype("float32"), cmap=cmap)
        plt.axis("off")
    plt.show()


plot_sample_images(example_images)
Reconstructed images:
Code
plot_sample_images(predictions)
3.2. Analyse the Embeddings
Each of the images above has been encoded as a 2-dimensional embedding.
We can look at these embeddings to gain some insight into how the autoencoder works.
The embedding vectors for our sample images above:
Code
# Encode the example images
embeddings = encoder.predict(example_images)
print(embeddings[:10])
157/157 [==============================] - 0s 2ms/step
[[ 2.2441912 -2.711683 ]
[ 6.1558456 6.0202003 ]
[-3.787192 7.3368516 ]
[-2.5938551 4.2098355 ]
[ 3.8645594 2.7229536 ]
[-2.0130231 6.0485506 ]
[ 1.2749226 2.1347647 ]
[ 2.8239484 2.898773 ]
[-0.48542604 -1.0869933 ]
[ 0.30643728 -2.6099105 ]]
We can plot the 2D latent space, colouring each point by its label. This shows how similar items are clustered together in latent space.
This is impressive! Remember, we never showed the model the labels when training, so it has learned to cluster images that look alike.
Code
# Colour the embeddings by their label
example_labels = y_test[:NUM_IMAGES_TO_RECONSTRUCT]

# Plot the latent space
figsize = 8
plt.figure(figsize=(figsize, figsize))
plt.scatter(
    embeddings[:, 0],
    embeddings[:, 1],
    cmap="rainbow",
    c=example_labels,
    alpha=0.6,
    s=3,
)
plt.colorbar()
plt.show()
3.3. Generating New Images
We can sample from the latent space and decode these sampled points to generate new images.
First we sample some random points in the latent space:
Code
# Get the range of existing embedding values so we can sample sensible points within the latent space.
embedding_min = np.min(embeddings, axis=0)
embedding_max = np.max(embeddings, axis=0)

# Sample some points
grid_width = 6
grid_height = 3
sample = np.random.uniform(
    embedding_min, embedding_max, size=(grid_width * grid_height, EMBEDDING_DIM)
)
print(sample)
[[ 1.47862929 9.28394749]
[-3.19389344 -3.04713146]
[-0.57161452 -0.35644389]
[10.97632621 -2.12482484]
[ 4.05160668 9.04420005]
[ 9.50105167 5.71270956]
[ 3.24765456 4.95969011]
[-3.68217634 4.52120851]
[-1.7067196 5.87696959]
[ 5.99883565 -2.11597183]
[ 1.84553131 6.04266323]
[ 0.15552252 1.98655625]
[ 3.55479856 2.35587959]
[-0.32278762 6.07537408]
[ 8.98977414 -1.15893539]
[ 2.1476981 4.97819188]
[-2.0896675 3.9166368 ]
[ 6.49229371 -4.75611412]]
We can then decode these sampled points.
Code
# Decode the sampled points
reconstructions = decoder.predict(sample)
1/1 [==============================] - 0s 59ms/step
Code
figsize = 8
plt.figure(figsize=(figsize, figsize))

# Plot the latent space and overlay the positions of the sampled points
plt.scatter(embeddings[:, 0], embeddings[:, 1], c="black", alpha=0.5, s=2)
plt.scatter(sample[:, 0], sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()

# Plot a grid of the reconstructed images which decode those sampled points
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.text(
        0.5,
        -0.35,
        str(np.round(sample[i, :], 1)),
        fontsize=10,
        ha="center",
        transform=ax.transAxes,
    )
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
Now let’s see what happens when we sample the latent space on a regular grid.
Code
# Colour the embeddings by their label (clothing type - see table)
figsize = 12
grid_size = 15
plt.figure(figsize=(figsize, figsize))
plt.scatter(
    embeddings[:, 0],
    embeddings[:, 1],
    cmap="rainbow",
    c=example_labels,
    alpha=0.8,
    s=300,
)
plt.colorbar()

x = np.linspace(min(embeddings[:, 0]), max(embeddings[:, 0]), grid_size)
y = np.linspace(max(embeddings[:, 1]), min(embeddings[:, 1]), grid_size)
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
grid = np.array(list(zip(xv, yv)))

reconstructions = decoder.predict(grid)
# plt.scatter(grid[:, 0], grid[:, 1], c="black", alpha=1, s=10)
plt.show()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
8/8 [==============================] - 0s 9ms/step
3.4. The Limitations of Autoencoders
The latent space exploration above yields some interesting insights into “regular” autoencoders that motivate the use of variational autoencoders to address these shortcomings.
- Different categories occupy varying amounts of area in latent space.
- The latent space distribution is not symmetrical or bounded.
- There are gaps in the latent space.
This makes it difficult for us to sample from this latent space effectively. We could sample a “gap” and get a nonsensical image. If a category (say, trousers) occupies a larger area in latent space, we are more likely to generate images of trousers than of categories which occupy a small area (say, shoes).
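As a quick, rough way to see the uneven coverage (a sketch reusing the embeddings and example_labels computed in section 3.2), we can look at the per-class spread of the embeddings:
Code
# Per-class standard deviation of the 2D embeddings: classes with a larger
# spread occupy more of the latent space and so get sampled more often.
for label in np.unique(example_labels):
    class_points = embeddings[example_labels == label]
    print(label, np.round(np.std(class_points, axis=0), 2))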
4. Variational Autoencoders
If we revisit our wardrobe, rather than assigning each item to a specific location, let’s assign it to a general region of the wardrobe.
And let’s also insist that this region should be as close to the centre of the wardrobe as possible, otherwise we are penalised. This should yield a more uniform latent space.
This is the idea behind variational autoencoders (VAE).
4.1. The Encoder
In a standard autoencoder, each image is mapped directly to one point in the latent space.
In a variational autoencoder, each image is mapped to a multivariate Normal distribution around a point in the latent space. Variational autoencoders assume there is no correlation between latent space dimensions.
So we typically use Normal distributions with a diagonal covariance matrix, meaning the distribution is independent in each dimension. The encoder only needs to map each input to a mean vector and a variance vector; it does not need to worry about covariances.
In practice we choose to map to log variances, because a log variance can take any value in \((-\infty, \infty)\) and is therefore smoother to learn than a variance, which must be positive.
In summary, the encoder maps \(image \rightarrow (z_{mean}, z_{log\_var})\)
We can then sample a point \(z\) from this distribution using:
\[ z = z_{mean} + z_{sigma} \cdot \epsilon \]
where: \[ z_{sigma} = e^{0.5 \cdot z_{log\_var}} \] \[ \epsilon \sim \mathcal{N}(0, I) \]
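A tiny NumPy sketch of this sampling step, with made-up encoder outputs for a single image in a 2D latent space:
Code
# Illustrative values only; in the VAE these come from the encoder.
z_mean = np.array([0.5, -1.0])
z_log_var = np.array([-0.2, 0.4])

z_sigma = np.exp(0.5 * z_log_var)       # convert log variance to standard deviation
epsilon = np.random.standard_normal(2)  # epsilon ~ N(0, I)
z = z_mean + z_sigma * epsilon          # the sampled latent point
print(z)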
4.2. The Decoder
This is identical to the standard autoencoder.
4.3. The Variational Autoencoder
Putting these together, we get the overall architecture:
flowchart LR
    A[Encoder] --> B1(z_mean)
    A[Encoder] --> B2(z_log_var)
    B1(z_mean) --> C[sample]
    B2(z_log_var) --> C[sample]
    C[sample] --> D(z)
    D(z) --> E[Decoder]
Why does this change to the encoder help?
In the standard autoencoder, there is no requirement for the latent space to be continuous. We could sample a point, say \((1, 2)\), and decode it to a well-formed image, but there is no guarantee that a neighbouring point \((1.1, 2.1)\) would decode to a similar, or even intelligible, image.
The “variational” part of the VAE addresses this problem. We now sample from an area around z_mean, so the decoder must ensure that all points in that region produce similar images to keep the reconstruction loss small.
4.4. Reparameterization Trick
Rather than sample directly from a Normal distribution parameterised by z_mean and z_log_var, we can sample epsilon from a standard Normal distribution and manually adjust the sample to correct its mean and variance.
This means gradients can backpropagate freely through the layer. The randomness in the layer is all encapsulated in epsilon, so the partial derivative of the layer output w.r.t. the layer input is deterministic, making backpropagation possible.
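A minimal sketch of the gradient flow (hypothetical values, using the tf import from the top of these notes): because the randomness is isolated in epsilon, TensorFlow can differentiate the sampled z with respect to z_mean and z_log_var.
Code
z_mean = tf.Variable([0.5, -1.0])
z_log_var = tf.Variable([-0.2, 0.4])

with tf.GradientTape() as tape:
    epsilon = tf.random.normal(shape=(2,))          # all randomness lives here
    z = z_mean + tf.exp(0.5 * z_log_var) * epsilon  # reparameterised sample
    loss = tf.reduce_sum(tf.square(z))              # stand-in for any downstream loss

# Both gradients are well-defined, so backpropagation works as usual.
print(tape.gradient(loss, [z_mean, z_log_var]))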
4.5. Loss Function
The loss function of the standard autoencoder was the reconstruction loss between original image and its decoded version.
For VAEs, we add an additional term which encourages points to have small mean and variance by penalising z_mean and z_log_var variables that differ significantly from 0.
This is the Kullback-Leibler (KL) divergence. It measures how much one probability distribution differs from another. We use it to measure how much our Normal distribution, with parameters z_mean and z_log_var, differs from a standard Normal distribution.
For this special case of KL divergence between our Normal distribution and a standard Normal, the closed form solution is: \[ D_{KL}[\mathcal{N}(\mu, \sigma) || \mathcal{N}(0, 1)] = -\frac{1}{2} \sum (1 + \log(\sigma ^2) - \mu ^2 - \sigma ^ 2) \]
So using our variables, we can describe this in code as:
kl_loss = -0.5 * sum(1 + z_log_var - z_mean ** 2 - exp(z_log_var))
This loss is minimised when z_mean=0 and z_log_var=0, i.e. it encourages our distribution towards a standard Normal distribution, using the space around the origin symmetrically and efficiently with few gaps.
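To sanity-check the closed form, here is a quick sketch (arbitrary, made-up parameters) comparing it against a Monte Carlo estimate of the KL divergence, using the numpy and scipy imports from the top of these notes:
Code
# Arbitrary example distribution q = N(mu, sigma^2) with mu=1, log(sigma^2)=0.5
mu, log_var = 1.0, 0.5
sigma = np.exp(0.5 * log_var)

# Closed-form KL divergence against the standard Normal p = N(0, 1)
kl_closed = -0.5 * (1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate: E_q[log q(z) - log p(z)]
samples = np.random.normal(mu, sigma, size=100_000)
kl_mc = np.mean(norm.logpdf(samples, mu, sigma) - norm.logpdf(samples, 0.0, 1.0))

print(kl_closed, kl_mc)  # the two values should agree closely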
The original paper simply summed the reconstruction_loss and the kl_loss. A variant of this includes a hyperparameter \(\beta\) to vary the weight of the KL divergence term. This is called a “\(\beta-VAE\)”:
vae_loss = reconstruction_error + beta * kl_loss
5. Building a Variational Autoencoder (VAE)
5.1. Sampling Layer
We need a sampling layer which allows us to sample \(z\) from the distribution defined by \(z_{mean}\) and \(z_{log\_var}\).
Code
class Sampling(layers.Layer):
    def call(self, z_mean, z_log_var):
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
5.2. Encoder
The encoder incorporates the Sampling layer as the final step. This is what is passed to the decoder.
Code
# Encoder
encoder_input = layers.Input(
    shape=(IMAGE_SIZE, IMAGE_SIZE, 1), name="encoder_input"
)
x = layers.Conv2D(32, (3, 3), strides=2, activation="relu", padding="same")(encoder_input)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)
shape_before_flattening = tf.keras.backend.int_shape(x)[1:]  # the decoder will need this!

x = layers.Flatten()(x)
z_mean = layers.Dense(EMBEDDING_DIM, name="z_mean")(x)
z_log_var = layers.Dense(EMBEDDING_DIM, name="z_log_var")(x)
z = Sampling()(z_mean, z_log_var)

encoder = models.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) [(None, 32, 32, 1)] 0 []
conv2d_3 (Conv2D) (None, 16, 16, 32) 320 ['encoder_input[0][0]']
conv2d_4 (Conv2D) (None, 8, 8, 64) 18496 ['conv2d_3[0][0]']
conv2d_5 (Conv2D) (None, 4, 4, 128) 73856 ['conv2d_4[0][0]']
flatten_1 (Flatten) (None, 2048) 0 ['conv2d_5[0][0]']
z_mean (Dense) (None, 2) 4098 ['flatten_1[0][0]']
z_log_var (Dense) (None, 2) 4098 ['flatten_1[0][0]']
sampling (Sampling) (None, 2) 0 ['z_mean[0][0]',
'z_log_var[0][0]']
==================================================================================================
Total params: 100868 (394.02 KB)
Trainable params: 100868 (394.02 KB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________
5.3. Decoder
The decoder is the same as a standard autoencoder.
Code
# Decoder
decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name="decoder_input")
x = layers.Dense(np.prod(shape_before_flattening))(decoder_input)
x = layers.Reshape(shape_before_flattening)(x)
x = layers.Conv2DTranspose(128, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)

decoder_output = layers.Conv2D(1, (3, 3), strides=1, activation="sigmoid", padding="same", name="decoder_output")(x)

decoder = models.Model(decoder_input, decoder_output)
decoder.summary()
Model: "model_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) [(None, 2)] 0
dense_1 (Dense) (None, 2048) 6144
reshape_1 (Reshape) (None, 4, 4, 128) 0
conv2d_transpose_3 (Conv2D (None, 8, 8, 128) 147584
Transpose)
conv2d_transpose_4 (Conv2D (None, 16, 16, 64) 73792
Transpose)
conv2d_transpose_5 (Conv2D (None, 32, 32, 32) 18464
Transpose)
decoder_output (Conv2D) (None, 32, 32, 1) 289
=================================================================
Total params: 246273 (962.00 KB)
Trainable params: 246273 (962.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
5.4. Autoencoder
Putting the encoder and decoder together.
Code
EPOCHS = 5
BETA = 500
Code
class VAE(models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Call the model on a particular input."""
        z_mean, z_log_var, z = encoder(inputs)
        reconstruction = decoder(z)
        return z_mean, z_log_var, reconstruction

    def train_step(self, data):
        """Step run during training."""
        with tf.GradientTape() as tape:
            z_mean, z_log_var, reconstruction = self(data)
            reconstruction_loss = tf.reduce_mean(
                BETA * losses.binary_crossentropy(data, reconstruction, axis=(1, 2, 3))
            )
            kl_loss = tf.reduce_mean(
                tf.reduce_sum(
                    -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                    axis=1,
                )
            )
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Step run during validation."""
        if isinstance(data, tuple):
            data = data[0]

        z_mean, z_log_var, reconstruction = self(data)
        reconstruction_loss = tf.reduce_mean(
            BETA * losses.binary_crossentropy(data, reconstruction, axis=(1, 2, 3))
        )
        kl_loss = tf.reduce_mean(
            tf.reduce_sum(
                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
                axis=1,
            )
        )
        total_loss = reconstruction_loss + kl_loss

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
Instantiate the VAE model and compile it.
Code
vae = VAE(encoder, decoder)

# optimizer = optimizers.Adam(learning_rate=0.0005)
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.0005)
vae.compile(optimizer=optimizer)
5.5. Training
Train the VAE as before.
Code
vae.fit(
    x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
)
Epoch 1/5
600/600 [==============================] - 37s 61ms/step - total_loss: 160.4693 - reconstruction_loss: 155.9913 - kl_loss: 4.4779 - val_loss: 141.2442 - val_reconstruction_loss: 136.1877 - val_kl_loss: 5.0565
Epoch 2/5
600/600 [==============================] - 34s 57ms/step - total_loss: 135.9397 - reconstruction_loss: 130.9409 - kl_loss: 4.9988 - val_loss: 138.5623 - val_reconstruction_loss: 133.5856 - val_kl_loss: 4.9767
Epoch 3/5
600/600 [==============================] - 34s 56ms/step - total_loss: 134.3719 - reconstruction_loss: 129.3381 - kl_loss: 5.0338 - val_loss: 137.1351 - val_reconstruction_loss: 132.1540 - val_kl_loss: 4.9811
Epoch 4/5
600/600 [==============================] - 34s 56ms/step - total_loss: 133.4455 - reconstruction_loss: 128.3819 - kl_loss: 5.0637 - val_loss: 136.5461 - val_reconstruction_loss: 131.4780 - val_kl_loss: 5.0681
Epoch 5/5
600/600 [==============================] - 34s 57ms/step - total_loss: 132.7808 - reconstruction_loss: 127.6688 - kl_loss: 5.1120 - val_loss: 135.8917 - val_reconstruction_loss: 130.7375 - val_kl_loss: 5.1542
<keras.src.callbacks.History at 0x2c59e9f10>
6. Analysing the VAE
6.1. Reconstruct using the VAE
As before, we can eyeball the reconstructions from our model.
Code
# Select a subset of the test set
n_to_predict = 5000
example_images = x_test[:n_to_predict]
example_labels = y_test[:n_to_predict]

# Create autoencoder predictions and display
z_mean, z_log_var, reconstructions = vae.predict(example_images)
print("Example real clothing items")
plot_sample_images(example_images)
print("Reconstructions")
plot_sample_images(reconstructions)
157/157 [==============================] - 1s 9ms/step
Example real clothing items
Reconstructions
6.2. Generate New Images
We can sample and decode points in the latent space to observe how the model generates images. We note that:
- The latent space has more even coverage and does not stray too far from a standard Normal distribution. If this is not the case, we can vary the \(\beta\) value to give more weight to the KL loss term (a quick check is sketched after this list).
- We do not see as many poorly formed images as we did when sampling a “gap” in a standard autoencoder.
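A quick way to check the first point (a hedged sketch, reusing the z_mean predictions from the previous cell): if the KL term is doing its job, the encoded means should be centred near 0 with a spread of roughly 1 in each dimension.
Code
# Per-dimension moments of the encoded means; a mean close to 0 and an
# order-1 spread indicate the latent space stays close to a standard Normal.
print("mean:", np.mean(z_mean, axis=0))
print("std: ", np.std(z_mean, axis=0))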
Code
# Encode the example images
z_mean, z_var, z = encoder.predict(example_images)

# Sample some points in the latent space, from the standard normal distribution
grid_width, grid_height = (6, 3)
z_sample = np.random.normal(size=(grid_width * grid_height, 2))

# Decode the sampled points
reconstructions = decoder.predict(z_sample)

# Convert original embeddings and sampled embeddings to p-values
p = norm.cdf(z)
p_sample = norm.cdf(z_sample)

# Draw a plot of...
figsize = 8
plt.figure(figsize=(figsize, figsize))

# ... the original embeddings ...
plt.scatter(z[:, 0], z[:, 1], c="black", alpha=0.5, s=2)

# ... and the newly generated points in the latent space
plt.scatter(z_sample[:, 0], z_sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()

# Add underneath a grid of the decoded images
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.text(
        0.5,
        -0.35,
        str(np.round(z_sample[i, :], 1)),
        fontsize=10,
        ha="center",
        transform=ax.transAxes,
    )
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
1/1 [==============================] - 0s 14ms/step
6.3. Analyse the Embeddings
The plots below show the latent space coloured by clothing type. The left plot shows this in terms of z-values and the right in terms of p-values.
The latent space is more continuous with fewer gaps, and different categories take similar amounts of space.
Code
# Colour the embeddings by their label (clothing type - see table)
figsize = 8
fig = plt.figure(figsize=(figsize * 2, figsize))
ax = fig.add_subplot(1, 2, 1)
plot_1 = ax.scatter(
    z[:, 0], z[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=3
)
plt.colorbar(plot_1)
ax = fig.add_subplot(1, 2, 2)
plot_2 = ax.scatter(
    p[:, 0], p[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=3
)
plt.show()
Next we see what happens when we sample from the latent space in a regular grid.
Code
# Colour the embeddings by their label (clothing type - see table)
figsize = 12
grid_size = 15
plt.figure(figsize=(figsize, figsize))
plt.scatter(
    p[:, 0], p[:, 1], cmap="rainbow", c=example_labels, alpha=0.8, s=300
)
plt.colorbar()

x = norm.ppf(np.linspace(0, 1, grid_size))
y = norm.ppf(np.linspace(1, 0, grid_size))
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
grid = np.array(list(zip(xv, yv)))

reconstructions = decoder.predict(grid)
# plt.scatter(grid[:, 0], grid[:, 1], c="black", alpha=1, s=10)
plt.show()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :], cmap="Greys")
8/8 [==============================] - 0s 6ms/step
References
- Chapter 3 of Generative Deep Learning by David Foster.