How to Generate Images with Variational Autoencoders (VAE) (Create VAE from scratch using Keras and TensorFlow)

Derrick Mwiti

An autoencoder takes an input image and compresses it into a low-dimensional representation, i.e., a latent vector. This vector is then used to reconstruct the original image. Regular autoencoders are therefore trained to output the same image they receive as input. Variational AutoEncoders (VAEs), in contrast, can generate new images that follow the same distribution as the training images.

VAEs work as follows:

  • Map an input into a distribution over the latent space
  • Pick a point from the distribution in the latent space
  • Decode the sampled point and compute the reconstruction and KL Divergence errors.

The reconstruction error is the same as the one used in standard autoencoders. The KL Divergence error measures how far the distribution that the encoder produces over the latent space is from a standard normal distribution.
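
As a rough illustration of the KL Divergence term, the sketch below computes the closed-form KL divergence between a diagonal Gaussian, described by a mean and a log variance per latent dimension, and the standard normal distribution. The function name kl_to_standard_normal and the sample values are assumptions made for this example; the formula matches the kl_loss expression used in the loss function later in this article.

```python
import math

def kl_to_standard_normal(mean, log_var):
    """KL( N(mean, exp(log_var)) || N(0, 1) ), summed over latent dimensions."""
    return -0.5 * sum(
        1 + lv - m ** 2 - math.exp(lv)
        for m, lv in zip(mean, log_var)
    )

# A distribution that already matches the standard normal has zero divergence.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))

# Moving the mean away from zero increases the penalty.
print(kl_to_standard_normal([1.0, -1.0], [0.0, 0.0]))
```

The penalty grows as the mean drifts from zero or the variance drifts from one, which is what keeps the latent space close to the normal distribution.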

During training, the KL Divergence term pushes the latent distribution of a Variational AutoEncoder toward the standard normal distribution. You can, therefore, pick a point from the normal distribution, and the decoder will create a new image that resembles the training data.

VAE architecture. Image by author

Prepare data for the VAE model

Let's illustrate how to build a VAE model in Keras using the Fruits and Vegetables Image Recognition Dataset.

First, let's get the usual imports out of the way.

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.layers import Input, Dense, Lambda, Conv2D, Flatten, Reshape, Conv2DTranspose, BatchNormalization, LeakyReLU, Dropout

Next, load the training and validation set.

base_dir = '/kaggle/input/fruit-and-vegetable-image-recognition/train'
batch_size = 32
img_size = 128

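# Assumption: the original arguments are missing here, so an 80/20 split of base_dir is used
training_set = tf.keras.utils.image_dataset_from_directory(
  base_dir, validation_split=0.2, subset="training", seed=42,
  image_size=(img_size, img_size), batch_size=batch_size)
validation_set = tf.keras.utils.image_dataset_from_directory(
  base_dir, validation_split=0.2, subset="validation", seed=42,
  image_size=(img_size, img_size), batch_size=batch_size)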

Since we don't need the target variables, we create training and validation data without them.

# Collect every batch of images (labels are discarded) into a single array
x_train = np.concatenate([x.numpy() for x, _ in training_set], axis=0)
x_test = np.concatenate([x.numpy() for x, _ in validation_set], axis=0)

Here's what the dataset looks like.

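class_names = training_set.class_names
plt.figure(figsize=(10, 10))
for images, labels in training_set.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    # Show each image with its class name as the title
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")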

Finally, let's normalize the data, as is standard practice when training deep learning models.

# Normalize pixel values between 0 and 1
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

The next step is to create the building blocks for the VAE model.

Create VAE encoder in Keras

The VAE encoder outputs a mean and a log variance. As you can see below, it's a regular convolutional encoder defined using the Keras Functional API, with two Dense output layers instead of one.

latent_dim dictates the number of dimensions in the latent space. You can tweak this number to see how it affects the model's performance.

shape_before_flattening stores the shape of the tensor x before it is flattened. It will be used later in the decoder network to reshape the flattened tensor back to the original shape of the feature maps.

The output of z_mean represents the mean of the normal distribution that generates the latent representation z. The output z_log_var represents the log variance of the normal distribution that produces the latent representation z.

# Define input shape and latent dimension
latent_dim = 2
input_shape = (img_size, img_size, 3)
# Encoder network
inputs = Input(shape=input_shape)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
x = Conv2D(32, (3, 3), activation='relu', strides=(2, 2), padding='same')(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
shape_before_flattening = K.int_shape(x)
x = Flatten()(x)
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)
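
To make shape_before_flattening concrete, here is a small sketch of the shape arithmetic for the encoder above. It relies on the rule that a convolution with padding='same' produces ceil(input_size / stride) outputs along each spatial dimension. The helper name same_padding_output is made up for this illustration.

```python
import math

def same_padding_output(size, stride):
    """Spatial output size of a convolution that uses padding='same'."""
    return math.ceil(size / stride)

img_size = 128
# (filters, stride) for each Conv2D layer in the encoder above
encoder_layers = [(16, 1), (32, 2), (128, 1), (256, 1)]

height = width = img_size
channels = 3
for filters, stride in encoder_layers:
    height = same_padding_output(height, stride)
    width = same_padding_output(width, stride)
    channels = filters

shape_before_flattening = (None, height, width, channels)
print(shape_before_flattening)      # (None, 64, 64, 256)
print(height * width * channels)    # 1048576 values after Flatten
```

Only the layer with strides=(2, 2) halves the spatial size, so the flattened tensor is large. Adding more strided layers would shrink it.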

Create sampling function

As noted earlier, we need a way to sample from the normal distribution. This is the purpose of the sampling function.

# Sampling function
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim))
    return z_mean + K.exp(z_log_var / 2) * epsilon

# Reparameterization trick
z = Lambda(sampling)([z_mean, z_log_var])

Epsilon is a random sample drawn from the standard normal distribution. Since it's random noise, it has no trainable parameters. The network therefore learns the weights that produce the mean and the log variance, from which the standard deviation is computed as exp(z_log_var / 2).

The reparameterization trick makes it possible to compute gradients, which cannot flow through a random sampling operation directly. Reparameterization moves the randomness into epsilon, so z becomes a deterministic, differentiable function of z_mean and z_log_var. The Lambda layer wraps the sampling function so that it can be used as a layer in the model.
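
The following NumPy sketch mirrors the sampling function outside of Keras to show what the trick does. The variable values for the mean and log variance are assumptions chosen for illustration. With epsilon held fixed, z is an ordinary function of the mean and log variance, which is what allows gradients to reach them.

```python
import numpy as np

rng = np.random.default_rng(0)

z_mean = np.array([[1.0, -2.0]])      # example encoder output (assumed values)
z_log_var = np.array([[0.0, -1.0]])   # log variance per latent dimension

# All randomness lives in epsilon, drawn from the standard normal distribution.
epsilon = rng.standard_normal(size=(10000, 2))

# z = mean + standard deviation * epsilon, where std = exp(log_var / 2)
std = np.exp(z_log_var / 2)
z = z_mean + std * epsilon

# The samples follow N(z_mean, std^2) up to sampling noise.
print(z.mean(axis=0))
print(z.std(axis=0), std[0])
```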

Create VAE decoder in Keras

The decoder creates an image from the sampled latent vector. It performs upsampling of the low dimensional latent vector.

# Decoder network
decoder_input = Input(K.int_shape(z)[1:])
x = Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
x = Reshape(shape_before_flattening[1:])(x)
x = Conv2DTranspose(128, (2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(64, (2, 2), activation='relu', padding='same', strides=(2, 2))(x)
x = Conv2DTranspose(32, (2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(16, (2, 2), activation='relu', padding='same')(x)
x = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

The decoder's input shape is the shape of the z tensor. The input is then passed to a dense layer. The number of units in this dense layer is a product. Let's examine what it means.

shape_before_flattening is the shape of the output tensor from the last convolutional layer in the encoder before flattening. The output tensor has shape (batch_size, height, width, channels). shape_before_flattening[1:] corresponds to the dimensions (height, width, channels). The Dense layer in the decoder network outputs a tensor of shape (batch_size, num_features), where num_features is the product of the dimensions (height, width, channels) of the last convolutional layer output tensor in the encoder. Therefore, np.prod(shape_before_flattening[1:]) computes the value of num_features, which is the number of features that the Dense layer should output.

Let's take a simple example where the output tensor shape of the last convolutional layer in the encoder network is (None, 8, 8, 64), where None is the batch size, and 8, 8, and 64 are the height, width, and number of channels, respectively. The number of features in this tensor can be computed using np.prod(shape_before_flattening[1:]). This represents the product of all elements in shape_before_flattening except the batch size, which is None. np.prod(shape_before_flattening[1:]) is, therefore, the same as:

num_features = shape_before_flattening[1] * shape_before_flattening[2] * shape_before_flattening[3]

num_features = 8 * 8 * 64 = 4096

In this case, 4096 becomes the number of units in the dense layer of the decoder.

The tensor is then reshaped into the same shape as the output of the final convolutional layer in the encoder network by Reshape(shape_before_flattening[1:])(x). Using the same example above, flattening in the encoder gives the shape (None, 4096). The goal of Reshape(shape_before_flattening[1:])(x) is to get back the 3D feature map from before flattening, in this case (8, 8, 64), so the output of the Reshape layer will be (None, 8, 8, 64). In other words, the Reshape layer unflattens the tensor, recovering the 3D feature map from its 1D representation.
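
The flatten and reshape round trip can be checked with plain NumPy. This is only a sketch using the example shape (None, 8, 8, 64) from the text and a made-up batch of 2 feature maps. In the real decoder, the values come from the Dense layer rather than from the encoder's feature maps, but the shapes behave the same way.

```python
import numpy as np

shape_before_flattening = (None, 8, 8, 64)   # example shape from the text

# Product of every dimension except the batch size
num_features = int(np.prod(shape_before_flattening[1:]))
print(num_features)                           # 4096

# A stand-in batch of 2 feature maps, flattened and then restored
feature_maps = np.arange(2 * num_features, dtype="float32").reshape((2, 8, 8, 64))
flattened = feature_maps.reshape((2, num_features))               # like Flatten
restored = flattened.reshape((2,) + shape_before_flattening[1:])  # like Reshape

print(flattened.shape, restored.shape)        # (2, 4096) (2, 8, 8, 64)
print(np.array_equal(feature_maps, restored)) # True
```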

The original full-resolution image is obtained through a sequence of Conv2DTranspose layers. The layer with strides=(2, 2) performs convolution and upsampling at the same time. The aim is to get a final output tensor of the shape (None, img_size, img_size, 3).
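
As a quick check that the decoder reaches the target shape, the sketch below traces the spatial size through the Conv2DTranspose layers. It uses the rule that a transposed convolution with padding='same' multiplies the spatial size by its stride, and assumes img_size = 128, so the feature maps are 64 x 64 after the Reshape layer. The helper name transpose_same_output is hypothetical.

```python
def transpose_same_output(size, stride):
    """Spatial output size of Conv2DTranspose with padding='same'."""
    return size * stride

height = width = 64        # spatial size after Reshape when img_size = 128
# (filters, stride) for each Conv2DTranspose layer in the decoder above
decoder_layers = [(128, 1), (64, 2), (32, 1), (16, 1)]

for filters, stride in decoder_layers:
    height = transpose_same_output(height, stride)
    width = transpose_same_output(width, stride)

# The final Conv2D keeps the spatial size and maps to 3 colour channels.
output_shape = (None, height, width, 3)
print(output_shape)        # (None, 128, 128, 3)
```

The single strided layer in the decoder undoes the single strided layer in the encoder, which is why the output matches the input size.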

Define the VAE model in Keras

With all the building blocks in place, the next step is to define the Keras VAE model. Passing an input image to the encoder produces the mean, the log variance, and a sample from the latent space. The sample is passed to the decoder to obtain an image.

# Define the VAE model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
decoder = Model(decoder_input, x, name='decoder')
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')

encoder(inputs) produces the output of the encoder: z_mean, z_log_var, and z. The decoder expects z, which is the latent representation of the input image. encoder(inputs)[2] gives z because it's the value at index 2. z is then passed to the decoder, producing outputs, an approximation of the original input tensor.

Plot the VAE model

To visualize the VAE, you can use:



tf.keras.utils.plot_model(vae, "model.png", show_shapes=True)

You can print the layer-by-layer summaries of the encoder and the decoder with encoder.summary() and decoder.summary().

Define the VAE loss function

The VAE loss function combines the reconstruction loss and the KL Divergence loss.

Let's define the two loss functions and add them to the VAE model.

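# Define the VAE loss function
# Reconstruction loss: mean squared error scaled by the number of pixel values
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))
reconstruction_loss *= input_shape[0] * input_shape[1] * input_shape[2]
# KL Divergence between the encoder's distribution and the standard normal
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=1)
B = 1000  # weight of the reconstruction loss relative to the KL loss
vae_loss = K.mean(B * reconstruction_loss + kl_loss)
# Attach the combined loss to the model
vae.add_loss(vae_loss)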

We can add the metrics in the same way.

vae.add_metric(kl_loss, name="kl_loss")
vae.add_metric(reconstruction_loss, name="reconstruction_loss")
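
To see how the two loss terms combine numerically, here is a NumPy sketch that repeats the same arithmetic on small random arrays. The tiny sizes and the random stand-ins for the reconstructions and encoder outputs are assumptions for illustration only; no model is involved.

```python
import numpy as np

rng = np.random.default_rng(0)
img_size, latent_dim, batch = 4, 2, 3          # tiny assumed sizes

inputs = rng.random((batch, img_size, img_size, 3))
outputs = rng.random((batch, img_size, img_size, 3))   # stand-in reconstructions
z_mean = rng.standard_normal((batch, latent_dim))
z_log_var = rng.standard_normal((batch, latent_dim))

# Mean squared error over all pixel values, scaled by the values per image
reconstruction_loss = np.mean((inputs.ravel() - outputs.ravel()) ** 2)
reconstruction_loss *= img_size * img_size * 3

# One KL value per image in the batch
kl_loss = -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=1)

B = 1000
vae_loss = np.mean(B * reconstruction_loss + kl_loss)
print(reconstruction_loss, kl_loss.shape, vae_loss)
```

Because B is large, the reconstruction term dominates the total. Lowering B gives the KL term more influence over the latent space.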

Train the VAE model

The final step is to compile and train the VAE model.

vae.compile(optimizer='adam')
# No targets are passed because the loss was attached with add_loss
vae.fit(x_train, epochs=500, batch_size=batch_size, validation_data=(x_test, None))

Run predictions on the VAE model

Next, run some predictions using the test images.

# Reconstruct the test images with the VAE
decoded_imgs = vae.predict(x_test)
# Display the original and reconstructed images
n = 10 # number of images to display
plt.figure(figsize=(20, 4))
for i in range(n):
    # Display the original image
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(img_size, img_size,3))
    # plt.gray()

    # Display the reconstructed image
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(img_size, img_size,3))
    # plt.gray()

The network reconstructs the test images, producing new images that approximate the originals.

You can also try making some predictions using pure random noise to see if the network can generate images from that. The code below feeds uniformly distributed random pixel values through the full VAE.

# Create random noise images with the same shape as the training images
num_samples = 10
random_latent_vectors = np.random.random((num_samples, img_size, img_size, 3))

decoded_imgs = vae.predict(random_latent_vectors)
# Display the noise inputs and the generated images
n = 10 # number of images to display
plt.figure(figsize=(20, 4))
for i in range(n):
    # Display the noise input
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(random_latent_vectors[i].reshape(img_size, img_size,3))
    # plt.gray()

    # Display the generated image
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(img_size, img_size,3))
    # plt.gray()
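
Another way to generate images, in line with the idea of picking a point from the normal distribution in the latent space, is to sample latent vectors directly and pass them to the decoder alone. The sketch below only builds the latent vectors with NumPy. The final call is left as a comment because it assumes the trained decoder model and latent_dim = 2 from earlier in this article.

```python
import numpy as np

latent_dim = 2      # same value as in the encoder definition
num_samples = 10

rng = np.random.default_rng(42)
# One latent vector per image, drawn from the standard normal distribution
random_latent_vectors = rng.standard_normal(size=(num_samples, latent_dim))
print(random_latent_vectors.shape)   # (10, 2)

# With the trained decoder from this article, the generated images would be:
# generated_imgs = decoder.predict(random_latent_vectors)
```

This skips the encoder entirely, so the decoder receives inputs with the shape it was built for.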

Try tweaking the network parameters to see if it can generate different images from pure noise.

Final thoughts

In this article, you have learned how to create a Variational AutoEncoder in Keras and generate images from pure noise. Check out the Kaggle notebook to play with the code and the references to dive deeper into the topic.  


Kaggle Notebook

Auto-Encoding Variational Bayes

An Introduction to Variational Autoencoders



Derrick Mwiti Twitter

Google Developer Expert - Machine Learning

