Elegy (High-level API for deep learning in JAX & Flax)

Derrick Mwiti

Training deep learning networks in Flax takes several steps. It involves creating the following functions:

  • Model definition.
  • Compute metrics.
  • Training state.
  • Training step.
  • Training and evaluation function.
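To see the kind of boilerplate involved, here is a minimal sketch of those pieces written by hand in plain JAX for a hypothetical toy problem: fitting y = 2x + 1 with a single linear unit. The names init_params, model_apply, train_step and compute_metrics are illustrative assumptions, and a real Flax project would usually keep parameters and optimizer state in a TrainState rather than a plain dictionary.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value for plain SGD

def init_params():
    # Training state: here just the parameters of one linear unit
    return {"w": jnp.zeros(()), "b": jnp.zeros(())}

def model_apply(params, x):
    # Model definition: y_hat = w * x + b
    return params["w"] * x + params["b"]

def loss_fn(params, x, y):
    # Mean squared error between predictions and targets
    return jnp.mean((model_apply(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Training step: loss and gradients, then an SGD update of every parameter
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

def compute_metrics(params, x, y):
    # Compute metrics: only the loss in this toy example
    return {"loss": loss_fn(params, x, y)}

# Training function: loop over epochs on the full (tiny) dataset
x = jnp.linspace(-1.0, 1.0, 32)
y = 2.0 * x + 1.0
params = init_params()
for epoch in range(200):
    params, loss = train_step(params, x, y)
metrics = compute_metrics(params, x, y)
```

Each of these functions has to be written, compiled and wired together manually, which is the work Elegy's Model, fit and evaluate take over.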

Flax and JAX give you fine-grained control over how deep learning networks are defined and trained, but that control comes with more verbose code. Enter Elegy, a high-level API for creating deep learning networks in JAX whose interface closely resembles Keras.

Let's look at how to use Elegy to define and train deep learning networks in Flax.

Data pre-processing

To make this illustration concrete, we'll use the IMDB movie reviews dataset from Kaggle to train an LSTM network in Flax.

Read more: LSTM in JAX & Flax (Complete example with code and notebook)

The first step is to download and extract the data.

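import zipfile
# The Kaggle CLI needs an API token (kaggle.json).
# Obtain it from https://www.kaggle.com/username/account
# The leading ! runs the command from a notebook cell; drop it in a terminal.
!kaggle datasets download lakshmi25npathi/imdb-dataset-of-50k-movie-reviews
with zipfile.ZipFile('imdb-dataset-of-50k-movie-reviews.zip', 'r') as zip_ref:
    # Extract into the folder the CSV is read from in the next step
    zip_ref.extractall('imdb-dataset-of-50k-movie-reviews')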

Next, we define the following processing steps:

  • Remove stopwords from the reviews.
  • Encode the sentiment labels as integers.
  • Split the data into training and testing sets.
  • Clean the data by lowercasing it and removing punctuation and other special characters.
  • Convert the text to a numerical representation using the Keras TextVectorization layer.
  • Convert the data to a TensorFlow dataset.

import numpy as np
import pandas as pd
import nltk
import tensorflow as tf
import tensorflow_datasets as tfds
from nltk.corpus import stopwords
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

df = pd.read_csv("imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv")

# Download the NLTK stopword list once and keep it in a set for fast lookups
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

def remove_stop_words(review):
    # Keep only the words that are not English stopwords
    return ' '.join(word for word in review.split() if word not in stop_words)

df['review'] = df['review'].apply(remove_stop_words)
# Encode the "negative"/"positive" labels as 0/1
labelencoder = LabelEncoder()
df = df.assign(sentiment=labelencoder.fit_transform(df["sentiment"]))
df = df.drop_duplicates()
docs = df['review']
labels = np.array(df['sentiment'])
X_train, X_test, y_train, y_test = train_test_split(docs, labels, test_size=0.20, random_state=0)

max_features = 10000  # Maximum vocab size.
batch_size = 128
max_len = 50  # Sequence length to pad the outputs to.
vectorize_layer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=max_len)
# Build the vocabulary from the training reviews before using the layer
vectorize_layer.adapt(X_train.to_numpy())
X_train_padded = vectorize_layer(X_train.to_numpy())
X_test_padded = vectorize_layer(X_test.to_numpy())

training_data = tf.data.Dataset.from_tensor_slices((X_train_padded, y_train))
validation_data = tf.data.Dataset.from_tensor_slices((X_test_padded, y_test))
training_data = training_data.batch(batch_size)
validation_data = validation_data.batch(batch_size)

def get_train_batches():
    # Shuffle batches, repeat for 3 passes, and prefetch the next batch
    ds = training_data.shuffle(3, reshuffle_each_iteration=True)
    ds = ds.repeat(3)
    ds = ds.prefetch(1)
    # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(ds)
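The TextVectorization layer used above standardizes each review, maps every token to an integer id from a frequency-ranked vocabulary, and truncates or pads the result to output_sequence_length. A simplified pure-Python sketch of the same idea follows. The helper names are hypothetical, and the real layer differs in details such as its exact punctuation pattern; like the Keras layer, the sketch reserves id 0 for padding and id 1 for out-of-vocabulary tokens.

```python
import re
from collections import Counter

def standardize(text):
    # Lowercase and strip punctuation (an approximation of 'lower_and_strip_punctuation')
    return re.sub(r"[^\w\s]", "", text.lower())

def build_vocab(texts, max_tokens):
    # Rank tokens by frequency; ids 0 and 1 are reserved for padding and unknown words
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    vocab = ["", "[UNK]"] + [w for w, _ in counts.most_common(max_tokens - 2)]
    return {w: i for i, w in enumerate(vocab)}

def vectorize(text, vocab, max_len):
    # Map tokens to ids (1 for unknown), then truncate and right-pad with 0
    ids = [vocab.get(tok, 1) for tok in standardize(text).split()][:max_len]
    return ids + [0] * (max_len - len(ids))

vocab = build_vocab(["Great movie!", "Terrible movie."], max_tokens=10)
print(vectorize("A great, great movie", vocab, max_len=6))  # [1, 3, 3, 2, 0, 0]
```

"movie" appears twice in the toy corpus, so it gets the first free id (2), while "a" was never seen and becomes the unknown id 1.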

Model definition in Elegy

Start by installing Elegy, Flax, and JAX.

pip install -U elegy flax jax jaxlib

Next, define the LSTM model.

import jax
import jax.numpy as jnp
import elegy as eg
from flax import linen as nn

class LSTMModel(nn.Module):
    def setup(self):
        self.embedding = nn.Embed(max_features, max_len)
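        # Apply the LSTM cell once per time step (axis 1), sharing parameters across steps
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                             variable_broadcast="params",
                             split_rngs={"params": False},
                             in_axes=1,
                             out_axes=1)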
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        self.dense4 = nn.Dense(2)
    def __call__(self, x_batch):
        x = self.embedding(x_batch)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        x = self.dense1(x)
        x = nn.relu(x)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        x = self.dense2(x)
        x = nn.relu(x)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        x = self.dense3(x)
        x = nn.relu(x)
        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)
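nn.scan lifts the LSTM cell so that it runs once per time step while a (cell, hidden) carry is threaded through the sequence. The same pattern can be sketched in plain JAX with jax.lax.scan. The gate equations below are the standard LSTM ones and the random weights are illustrative assumptions; this is not Flax's OptimizedLSTMCell implementation.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, carry, x_t):
    # One LSTM time step: input, forget, candidate and output gates
    c, h = carry
    z = x_t @ params["wi"] + h @ params["wh"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (c, h), h

def run_lstm(params, xs, hidden_size):
    # xs: (batch, time, features); lax.scan iterates over the leading axis
    batch = xs.shape[0]
    carry = (jnp.zeros((batch, hidden_size)), jnp.zeros((batch, hidden_size)))
    step = lambda carry, x_t: lstm_step(params, carry, x_t)
    carry, hs = jax.lax.scan(step, carry, jnp.swapaxes(xs, 0, 1))
    # Move time back to axis 1 so outputs are (batch, time, hidden)
    return carry, jnp.swapaxes(hs, 0, 1)

def init_lstm(key, in_features, hidden_size):
    k1, k2 = jax.random.split(key)
    return {
        "wi": 0.1 * jax.random.normal(k1, (in_features, 4 * hidden_size)),
        "wh": 0.1 * jax.random.normal(k2, (hidden_size, 4 * hidden_size)),
        "b": jnp.zeros((4 * hidden_size,)),
    }

params = init_lstm(jax.random.PRNGKey(0), in_features=8, hidden_size=16)
xs = jnp.ones((4, 50, 8))  # 4 sequences, 50 time steps, 8 features each
(c, h), hs = run_lstm(params, xs, hidden_size=16)
print(hs.shape)  # (4, 50, 16)
```

The last output hs[:, -1] equals the final hidden state h, which is the same "last time step" that LSTMModel reads with x[:, -1] before its final dense layer.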

Let's now create an Elegy model from the network above. The loss, metrics, and optimizer are passed to the constructor, much like the arguments of Keras's compile step, so there is no separate compilation call.

import optax

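model = eg.Model(
    module=LSTMModel(),
    # The network outputs log-probabilities; integer labels are 0/1
    loss=eg.losses.Crossentropy(),
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.adam(1e-3),  # learning rate is an illustrative choice
)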
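Because LSTMModel ends with nn.log_softmax, its outputs are log-probabilities. Below is a minimal JAX sketch of what a sparse cross-entropy loss and an accuracy metric compute on such outputs; the function names and the toy logits are hypothetical, and this is not Elegy's implementation.

```python
import jax
import jax.numpy as jnp

def cross_entropy(log_probs, labels):
    # Negative log-probability of the true class, averaged over the batch
    true_log_probs = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -jnp.mean(true_log_probs)

def accuracy(log_probs, labels):
    # Fraction of examples whose highest-scoring class matches the label
    return jnp.mean(jnp.argmax(log_probs, axis=-1) == labels)

logits = jnp.array([[2.0, 0.5], [0.1, 1.5], [1.0, 3.0]])
labels = jnp.array([0, 1, 0])
log_probs = jax.nn.log_softmax(logits)
print(cross_entropy(log_probs, labels), accuracy(log_probs, labels))
```

log_softmax is idempotent, so applying it again to log_probs returns the same values. That is why a loss that treats its inputs as logits still produces the correct value on this network's outputs.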

Elegy model summary

Like in Keras, we can print the model's summary.
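A model summary mostly reports the parameter shapes and counts of each layer. Flax stores parameters as a nested dictionary (a pytree), so a rough total can be sketched with jax.tree_util. The params dictionary below is a hypothetical stand-in holding only two of LSTMModel's layers: the embedding, whose shape is (max_features, max_len) = (10000, 50), and the final two-unit dense layer fed by the 64-unit dense layer before it.

```python
import jax
import jax.numpy as jnp

def count_params(params):
    # Sum the number of elements in every array leaf of the pytree
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

params = {
    "embedding": {"embedding": jnp.zeros((10000, 50))},
    "dense4": {"kernel": jnp.zeros((64, 2)), "bias": jnp.zeros((2,))},
}
print(count_params(params))  # 500130
```

The embedding alone holds 10000 x 50 = 500,000 parameters, which is why it tends to dominate the summary of small text models like this one.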


Distributed training in Elegy

To train models in a distributed manner in plain Flax, we have to write parallel versions of our model training functions ourselves.

Read more: Distributed training with JAX & Flax

With Elegy, calling the model's distributed method is enough:

model = model.distributed()
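Data-parallel training in JAX typically shards each batch across devices, computes gradients on every device, and averages them. The sketch below shows that pattern with jax.pmap and jax.lax.pmean for a hypothetical linear model. It assumes the batch size is divisible by the number of devices and is not a description of how distributed() works internally.

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def shard(batch):
    # Split the leading (batch) axis into (n_devices, per_device_batch)
    return batch.reshape((n_devices, -1) + batch.shape[1:])

def loss_fn(w, x, y):
    # Mean squared error of a hypothetical linear model
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def parallel_grad(w, x, y):
    # Each device computes gradients on its own shard of the batch...
    grads = jax.grad(loss_fn)(w, x, y)
    # ...and the gradients are then averaged across all devices
    return jax.lax.pmean(grads, axis_name="devices")

x = jnp.ones((8, 3))
y = jnp.ones((8,))
w = jnp.zeros((3,))
# Replicate the weights so every device holds its own copy
w_replicated = jnp.broadcast_to(w, (n_devices,) + w.shape)
grads = parallel_grad(w_replicated, shard(x), shard(y))
```

With equally sized shards, the averaged gradient equals the gradient computed on the full batch, so every replica applies the same update and the copies stay in sync.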

Keras-like callbacks in Flax

Elegy supports callbacks similar to Keras callbacks. In this case, we train the model with the following callbacks:

callbacks = [
    eg.callbacks.TensorBoard("summaries"),
    eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True),
    eg.callbacks.EarlyStopping(monitor="val_loss", patience=10),
]
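The EarlyStopping callback stops training once the monitored value (val_loss here) has not improved for patience consecutive epochs. The rule can be sketched in plain Python as follows. The function name is hypothetical, and this simplified version ignores options such as a minimum-improvement threshold; it is not Elegy's code.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

print(early_stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.95], patience=3))  # 4
```

The best loss (0.8) occurs at epoch 1, and the next three epochs fail to beat it, so training stops at epoch 4.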

Read more: TensorBoard tutorial (Deep dive with examples and notebook)

Train Elegy models

Elegy provides the fit method for training models. The method supports the following data sources:

  • TensorFlow Dataset.
  • PyTorch DataLoader.
  • Elegy DataLoader.
  • Python generators.

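history = model.fit(
    training_data,
    epochs=100,  # illustrative; EarlyStopping may end training sooner
    validation_data=validation_data,
    callbacks=callbacks,
)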
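Of the supported sources, a Python generator is the easiest to write by hand. The sketch below yields (inputs, labels) NumPy batches indefinitely and reshuffles after every pass; the function name is hypothetical. With an endless generator, fit also needs to know how many batches make up an epoch (Keras calls this steps_per_epoch; check Elegy's fit documentation for the exact argument).

```python
import numpy as np

def batch_generator(x, y, batch_size, shuffle=True, seed=0):
    """Yield (inputs, labels) batches forever, reshuffling after each pass."""
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]

x = np.arange(10).reshape(10, 1)
y = np.arange(10)
gen = batch_generator(x, y, batch_size=4, shuffle=False)
xb, yb = next(gen)
print(xb.shape, yb)  # (4, 1) [0 1 2 3]
```

The last batch of each pass is smaller (here 2 examples) because 10 is not a multiple of 4.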

Evaluate Elegy models

To evaluate Elegy models, use the evaluate method.

model.evaluate(validation_data)

Visualize Elegy model with TensorBoard

Since we applied the TensorBoard callback, we can view the performance of the model in TensorBoard.

%load_ext tensorboard
%tensorboard --logdir summaries

Read more: How to use TensorBoard in Flax

Plot model performance with Matplotlib

We can also plot the performance of the model using Matplotlib.

import matplotlib.pyplot as plt

def plot_history(history):
    n_plots = len(history.history.keys()) // 2
    plt.figure(figsize=(14, 24))
    for i, key in enumerate(list(history.history.keys())[:n_plots]):
        metric = history.history[key]
        val_metric = history.history[f"val_{key}"]

        plt.subplot(n_plots, 1, i + 1)
        plt.plot(metric, label=f"Training {key}")
        plt.plot(val_metric, label=f"Validation {key}")
        plt.legend(loc="lower right")
        plt.title(f"Training and Validation {key}")
    plt.show()

plot_history(history)

Making predictions with Elegy models

Like Keras, Elegy provides the predict method for making predictions.

(text, test_labels) = next(iter(validation_data))

y_pred = model.predict(jnp.array(text))
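Since the network outputs log-probabilities for the two classes, predicted labels come from an argmax over the last axis. With the LabelEncoder used earlier, 0 means "negative" and 1 means "positive", because LabelEncoder sorts class names alphabetically. The sketch below uses made-up NumPy arrays as hypothetical stand-ins for y_pred and test_labels.

```python
import numpy as np

def to_labels(log_probs):
    # Class with the highest log-probability for each review
    return np.argmax(log_probs, axis=-1)

# Hypothetical stand-ins for model.predict output and the true labels
y_pred = np.log(np.array([[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]))
test_labels = np.array([0, 1, 1])

pred_labels = to_labels(y_pred)
accuracy = np.mean(pred_labels == test_labels)
print(pred_labels, accuracy)
```

np.exp(y_pred) recovers the class probabilities if you need confidence scores instead of hard labels.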

Saving and loading Elegy models

Like Keras models, Elegy models can be saved to disk and loaded back, ready to make predictions.

# You could call `model.save` here, but `ModelCheckpoint` has already serialized the model
# model.save("model")

# current model reference
print("current model id:", id(model))

# load model from disk
model = eg.load("models/high-level")

# new model reference
print("new model id:    ", id(model))

# check that it works!
model.evaluate(validation_data)
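Elegy handles serialization for you. For comparison, a bare-bones way to persist just a parameter pytree is the standard-library pickle module. This is an illustrative assumption, not the format Elegy writes, and pickle files should only be loaded from trusted sources.

```python
import os
import pickle
import tempfile

import numpy as np

def save_params(params, path):
    # Serialize the parameter pytree (nested dicts of arrays) to disk
    with open(path, "wb") as f:
        pickle.dump(params, f)

def load_params(path):
    # Restore the pytree exactly as it was saved
    with open(path, "rb") as f:
        return pickle.load(f)

params = {"dense": {"kernel": np.ones((3, 2)), "bias": np.zeros(2)}}
path = os.path.join(tempfile.mkdtemp(), "params.pkl")
save_params(params, path)
restored = load_params(path)
```

Saving only the parameters keeps the file small, but you then need the model definition in code to use them again, which is the convenience that loading a whole Elegy model avoids.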

Final thoughts

This article has been a quick dive into Elegy, a high-level JAX API that you can use to build and train Flax networks. Elegy's API closely mirrors Keras, with similar functionality for:

  • Model training.
  • Making predictions.
  • Creating callbacks.
  • Defining model loss and metrics.



Derrick Mwiti

Google Developer Expert - Machine Learning