Elegy (High-level API for deep learning in JAX & Flax)

Derrick Mwiti

Training a deep learning network in Flax takes several steps. You have to write each of the following pieces yourself:

  • Model definition.
  • Compute metrics.
  • Training state.
  • Training step.
  • Training and evaluation function.

Flax and JAX give more control in defining and training deep learning networks. However, this comes with more verbosity. Enter Elegy. Elegy is a high-level API for creating deep learning networks in JAX. Elegy's API is like the one in Keras.

Let's look at how to use Elegy to define and train deep learning networks in Flax.

Data pre-processing

To make this illustration concrete, we'll use the movie review data from Kaggle to create an LSTM network in Flax.

Read more: LSTM in JAX & Flax (Complete example with code and notebook)

The first step is to download and extract the data.

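import os
import zipfile

# Kaggle API credentials are required for the download.
# Obtain from https://www.kaggle.com/username/account
import kaggle

# Notebook shell command: download the dataset archive.
!kaggle datasets download lakshmi25npathi/imdb-dataset-of-50k-movie-reviews

# Extract into the folder that the CSV file is read from later.
with zipfile.ZipFile('imdb-dataset-of-50k-movie-reviews.zip', 'r') as zip_ref:
    zip_ref.extractall('imdb-dataset-of-50k-movie-reviews')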

Next, we define the following processing steps:

  • Split the data into a training and testing set.
  • Remove stopwords from the data.
  • Clean the data by removing punctuation and other special characters.
  • Convert the data to a TensorFlow dataset.
  • Convert the data to a numerical representation using the Keras TextVectorization layer.

import numpy as np
import pandas as pd
from numpy import array
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

df = pd.read_csv("imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv")
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')  # Fetch the stop word list on first use.
# Build the stop word set once; set lookups are much faster than list lookups.
stop_words = set(stopwords.words('english'))

def remove_stop_words(review):
    """Return the review with English stop words removed."""
    kept_words = [word for word in review.split() if word not in stop_words]
    return ' '.join(kept_words)

df['review'] = df['review'].apply(remove_stop_words)
labelencoder = LabelEncoder()
df = df.assign(sentiment = labelencoder.fit_transform(df["sentiment"]))
df = df.drop_duplicates()
docs = df['review']
labels = array(df['sentiment'])
X_train, X_test , y_train, y_test = train_test_split(docs, labels , test_size = 0.20, random_state=0)
max_features = 10000  # Maximum vocab size.
batch_size = 128
max_len = 50 # Sequence length to pad the outputs to.
vectorize_layer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=max_len,
)
# Build the vocabulary from the training reviews before vectorizing.
vectorize_layer.adapt(X_train)
X_train_padded = vectorize_layer(X_train)
X_test_padded = vectorize_layer(X_test)
training_data = tf.data.Dataset.from_tensor_slices((X_train_padded, y_train))
validation_data = tf.data.Dataset.from_tensor_slices((X_test_padded, y_test))
training_data = training_data.batch(batch_size)
validation_data = validation_data.batch(batch_size)
def get_train_batches():
  ds = training_data.prefetch(1)
  ds = ds.repeat(3)
  ds = ds.shuffle(3, reshuffle_each_iteration=True)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)
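
To make the vectorization step less of a black box, here is a minimal pure-Python sketch of what an integer text vectorizer does: lowercase the text, strip punctuation, map each word to an index, then pad or truncate to a fixed length. It only mimics the behavior of the Keras layer and is not its implementation. The helper names (`standardize`, `build_vocab`, `vectorize`) and the toy reviews are made up for illustration. Index 0 is reserved for padding and index 1 for out-of-vocabulary words, the same convention the Keras layer uses in 'int' mode.

```python
import string
from collections import Counter

PAD, OOV = 0, 1  # Reserved indices: padding and out-of-vocabulary.

def standardize(text):
    """Lowercase the text and strip ASCII punctuation."""
    return text.lower().translate(str.maketrans('', '', string.punctuation))

def build_vocab(texts, max_tokens):
    """Map the most frequent words to indices 2, 3, ... (0 and 1 are reserved)."""
    counts = Counter(word for text in texts for word in standardize(text).split())
    most_common = counts.most_common(max_tokens - 2)
    return {word: index for index, (word, _) in enumerate(most_common, start=2)}

def vectorize(text, vocab, max_len):
    """Convert text to a fixed-length list of integer ids."""
    ids = [vocab.get(word, OOV) for word in standardize(text).split()]
    ids = ids[:max_len]                        # Truncate long sequences.
    return ids + [PAD] * (max_len - len(ids))  # Pad short sequences.

# Hypothetical toy corpus standing in for the training reviews.
toy_reviews = ["Great movie, great cast, great fun!", "Terrible movie."]
vocab = build_vocab(toy_reviews, max_tokens=10)
encoded = vectorize("Great plot, boring movie", vocab, max_len=6)
print(encoded)
```

Words that never appeared in the toy corpus ("plot", "boring") map to the out-of-vocabulary index, and the trailing zeros are padding up to `max_len`.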

Model definition in Elegy

Start by installing Elegy, Flax, and JAX.

pip install -U elegy flax jax jaxlib

Next, define the LSTM model.

import jax
import jax.numpy as jnp
import elegy as eg
from flax import linen as nn

class LSTMModel(nn.Module):
    def setup(self):
        self.embedding = nn.Embed(max_features, max_len)
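        # Scan the LSTM cell over the time axis (axis 1), sharing parameters across steps.
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                               variable_broadcast="params",
                               split_rngs={"params": False},
                               in_axes=1,
                               out_axes=1)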
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        self.dense4 = nn.Dense(2)
    def __call__(self, x_batch):
        x = self.embedding(x_batch)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        x = self.dense1(x)
        x = nn.relu(x)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        x = self.dense2(x)
        x = nn.relu(x)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        x = self.dense3(x)
        x = nn.relu(x)
        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)
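
The pattern `(carry, hidden), x = self.lstm1((carry, hidden), x)` threads a state pair through every time step and collects one output per step. The scalar, pure-Python sketch below illustrates that idea with the standard LSTM gate equations. It is not Flax's implementation, and the names `lstm_cell_step`, `scan_lstm`, and the weight dictionary layout are invented for this example.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_cell_step(state, x, w):
    """One LSTM step on scalars. state is (c, h): cell state and hidden state."""
    c, h = state
    i = sigmoid(w["wi"] * x + w["ui"] * h + w["bi"])    # Input gate.
    f = sigmoid(w["wf"] * x + w["uf"] * h + w["bf"])    # Forget gate.
    g = math.tanh(w["wg"] * x + w["ug"] * h + w["bg"])  # Candidate cell value.
    o = sigmoid(w["wo"] * x + w["uo"] * h + w["bo"])    # Output gate.
    new_c = f * c + i * g
    new_h = o * math.tanh(new_c)
    return (new_c, new_h), new_h  # (updated state, output for this step)

def scan_lstm(initial_state, inputs, w):
    """Apply the cell to each time step, reusing the same weights throughout."""
    state, outputs = initial_state, []
    for x in inputs:
        state, y = lstm_cell_step(state, x, w)
        outputs.append(y)
    return state, outputs

# All-zero weights make every gate 0.5 and the candidate 0, which is easy to check by hand.
zero_weights = {name: 0.0 for name in
                ["wi", "ui", "bi", "wf", "uf", "bf", "wg", "ug", "bg", "wo", "uo", "bo"]}
final_state, outputs = scan_lstm((1.0, 0.0), [0.3, -0.2, 0.5], zero_weights)
print(final_state, outputs)
```

The model above feeds only the last time step, `x[:, -1]`, to the final classification layer, which corresponds to `outputs[-1]` in this sketch.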

Let's now create an Elegy model using the above network. As you can see, the loss and metrics are defined like in Keras. The model compilation is done in the constructor, so you don't have to do this manually.

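import optax

# The loss, metric, and optimizer settings below are assumed, illustrative choices.
model = eg.Model(
    module=LSTMModel(),
    loss=eg.losses.Crossentropy(),
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.adam(1e-3),
)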

Elegy model summary

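Like in Keras, we can print the model's summary. In Elegy, the summary method is called with a sample batch of inputs so that it can infer the layer shapes.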


Distributed training in Elegy

To train models in a distributed manner in Flax, we define parallel versions of our model training functions.

Read more: Distributed training with JAX & Flax

With Elegy, however, we only need to call the model's distributed method.

model = model.distributed()
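
Data-parallel training, the usual idea behind this kind of one-line switch, splits each batch across devices, computes a gradient per shard, and averages the results before the update. The pure-Python sketch below illustrates that averaging for a one-parameter linear model. It makes no claim about how Elegy implements `distributed` internally, and the helper names and the device count are hypothetical.

```python
def gradient(w, batch):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def shard(batch, num_devices):
    """Split a batch into equally sized shards, one per (hypothetical) device."""
    size = len(batch) // num_devices
    return [batch[k * size:(k + 1) * size] for k in range(num_devices)]

def data_parallel_gradient(w, batch, num_devices):
    """Compute one gradient per shard, then average them across devices."""
    shard_grads = [gradient(w, s) for s in shard(batch, num_devices)]
    return sum(shard_grads) / num_devices

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # Samples of y = 2x.
w = 0.5
full = gradient(w, batch)
parallel = data_parallel_gradient(w, batch, num_devices=2)
print(full, parallel)
```

With equally sized shards, the averaged gradient equals the full-batch gradient, so the parameter update matches what a single device would compute.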

Keras-like callbacks in Flax

Elegy supports callbacks similar to Keras callbacks. In this case, we train the model with the following callbacks:

callbacks = [
    eg.callbacks.TensorBoard("summaries"),
    eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True),
    eg.callbacks.EarlyStopping(monitor='val_loss', patience=10),
]
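
The `patience` argument follows the same idea as in Keras: training stops once the monitored value has failed to improve for that many consecutive epochs. Below is a minimal pure-Python sketch of that logic, assuming lower values are better, as for `val_loss`. The `EarlyStopper` class and the loss values are made up for illustration and are not Elegy's implementation.

```python
class EarlyStopper:
    """Track a monitored loss and signal when it stops improving."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.wait = 0  # Epochs since the last improvement.

    def update(self, value):
        """Record one epoch's value; return True if training should stop."""
        if value < self.best:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience

# Hypothetical validation losses, one per epoch.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.64, 0.67, 0.69, 0.70]
stopper = EarlyStopper(patience=3)
stopped_epoch = None
for epoch, loss in enumerate(val_losses):
    if stopper.update(loss):
        stopped_epoch = epoch
        break
print(stopped_epoch, stopper.best)
```

Note that the improvement at epoch 5 resets the counter, so training continues until three further epochs pass without a new best value.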

Read more: TensorBoard tutorial (Deep dive with examples and notebook)

Train Elegy models

Elegy provides the fit method for training models. The method supports the following data sources:

  • TensorFlow Dataset,
  • PyTorch DataLoader,
  • Elegy DataLoader, and
  • Python generators.
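
As an illustration of the last option, a Python generator only needs to yield one batch at a time. The sketch below is pure Python and independent of Elegy. The function name `batch_generator`, the toy data, and the (inputs, labels) tuple layout are assumptions, so check Elegy's documentation for the exact batch format that fit expects.

```python
import random

def batch_generator(inputs, labels, batch_size, seed=0):
    """Yield shuffled (inputs, labels) batches forever, reshuffling every pass."""
    rng = random.Random(seed)
    indices = list(range(len(inputs)))
    while True:
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            chosen = indices[start:start + batch_size]
            yield [inputs[i] for i in chosen], [labels[i] for i in chosen]

# Toy data: five already vectorized reviews and their sentiment labels.
toy_inputs = [[2, 3, 0], [4, 1, 0], [5, 6, 7], [1, 1, 0], [8, 2, 0]]
toy_labels = [1, 0, 1, 0, 1]

batches = batch_generator(toy_inputs, toy_labels, batch_size=2)
first_pass = [next(batches) for _ in range(3)]  # 5 samples -> batches of 2, 2, 1.
print([len(x) for x, _ in first_pass])
```

Because the generator never ends, the training loop has to be told how many batches make up an epoch; in Keras this is the `steps_per_epoch` argument.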
