LSTM in JAX & Flax (Complete example with code and notebook)

Derrick Mwiti


LSTMs are a class of neural networks used to solve sequence problems such as time series forecasting and natural language processing. LSTMs maintain an internal state that carries information from earlier time steps, which is what makes them useful for these problems. Processing a sequence means looping over every time step. Instead of writing these for loops from scratch, we can use functions from JAX and Flax. In this article, we will build a natural language processing model using LSTMs in Flax.
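The per-time-step for loop maps directly onto `jax.lax.scan`. The sketch below uses a toy one-unit recurrence of our own, not an LSTM. It runs the same step function twice: once with a Python for loop and once with `lax.scan`. Both versions should give identical results. Under `jax.jit`, `scan` compiles a single loop body instead of unrolling one copy of the step per time step.

```python
import jax
import jax.numpy as jnp

def step(h, x_t):
    """One step of a toy recurrence: the new state depends on the old state and the input."""
    h = jnp.tanh(0.5 * h + x_t)
    return h, h  # (new carry, per-step output)

def run_with_loop(xs):
    h = jnp.zeros(())
    outputs = []
    for x_t in xs:  # explicit Python loop over time steps
        h, y = step(h, x_t)
        outputs.append(y)
    return h, jnp.stack(outputs)

@jax.jit
def run_with_scan(xs):
    # lax.scan threads the carry through step() along the leading axis of xs
    return jax.lax.scan(step, jnp.zeros(()), xs)

xs = jnp.array([1.0, -0.5, 0.25, 2.0])
h_loop, ys_loop = run_with_loop(xs)
h_scan, ys_scan = run_with_scan(xs)
```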

If this is your first interaction with JAX, then I recommend that you first go through our JAX introductory tutorials:

Let's get started.

Dataset download

We'll use the movie review dataset from Kaggle. We download the dataset using Kaggle's Python package.

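import os
#Obtain from
import kaggle  # needs your Kaggle API credentials (kaggle.json) to be configured
# In a notebook, the leading "!" runs the Kaggle CLI as a shell command
!kaggle datasets download lakshmi25npathi/imdb-dataset-of-50k-movie-reviews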

Next, extract the dataset.

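import zipfile
# The Kaggle CLI saves the dataset as <dataset-name>.zip in the current directory
with zipfile.ZipFile('imdb-dataset-of-50k-movie-reviews.zip', 'r') as zip_ref:
    zip_ref.extractall('imdb-dataset-of-50k-movie-reviews')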

Load the dataset using Pandas and display a sample of the reviews.

import pandas as pd
df = pd.read_csv("imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv")
df.head()

Data processing with NLTK

The reviews contain characters that do not help predict whether a movie review is negative or positive, such as punctuation marks and special characters. We therefore remove them from the reviews.

We also need to convert the sentiment column into a numerical representation. This is achieved using LabelEncoder from Scikit-learn. Let's import that together with other packages we'll use throughout this article.

import numpy as np 
import pandas as pd 
from numpy import array
import tensorflow as tf
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

The reviews also contain words that are not useful for sentiment prediction. These are common English words such as "the", "at", and "and", known as stopwords. We remove them with the help of the NLTK library. Let's start by defining a function that removes all English stopwords.

# pip install nltk
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

# Build the stopword set once; set membership checks are much faster than list lookups
stop_words = set(stopwords.words('english'))

def remove_stop_words(review):
    """Drop English stopwords from a whitespace-tokenized review."""
    return ' '.join(word for word in review.split() if word not in stop_words)

Apply the function to the review column.

df['review'] = df['review'].apply(remove_stop_words)

Let's also convert the sentiment column to numerical representation.

labelencoder = LabelEncoder()
df = df.assign(sentiment = labelencoder.fit_transform(df["sentiment"]))
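LabelEncoder learns the sorted list of unique class names and replaces each label with its index in that list. As a result, "negative" becomes 0 and "positive" becomes 1. The NumPy sketch below uses a few made-up labels to reproduce the same mapping with `np.unique(..., return_inverse=True)`, which is handy for checking which integer stands for which sentiment.

```python
import numpy as np

sentiments = np.array(["positive", "negative", "negative", "positive", "positive"])

# np.unique returns the sorted class names and, with return_inverse=True,
# the index of each original entry in that sorted array
classes, encoded = np.unique(sentiments, return_inverse=True)

print(classes)   # ['negative' 'positive']
print(encoded)   # [1 0 0 1 1]
```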

Compare the reviews with and without the stop words.

Looking at the third review, we notice that the words this, was and a have been dropped from the sentence. However, we can still see some special characters, such as <br> in the review. Let's resolve that next.
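The article handles the remaining characters with the vectorization layer's standardization in the next step. That step strips the angle brackets and the slash but, as far as we can tell, leaves the letters `br` behind as a token. One optional extra step, not part of the original pipeline, is to replace the HTML line-break tags before vectorizing. The helper `remove_br_tags` below is hypothetical and is shown only as a sketch.

```python
import re

# Matches HTML line-break tags such as <br>, <br/> and <br />
BR_TAG = re.compile(r"<br\s*/?>", flags=re.IGNORECASE)

def remove_br_tags(review):
    """Replace line-break tags with a space, then collapse repeated whitespace."""
    review = BR_TAG.sub(" ", review)
    return " ".join(review.split())

print(remove_br_tags("Great acting.<br /><br />Weak plot."))  # Great acting. Weak plot.
```

It could be applied the same way as the stopword filter, for example `df['review'] = df['review'].apply(remove_br_tags)`.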

Text vectorization with Keras

The review data is still in text form. However, we need to convert it to a numeric representation like the sentiment column. Before we do that, let's split the dataset into a training and testing set.

from sklearn.model_selection import train_test_split
df = df.drop_duplicates()
docs = df['review']
labels = array(df['sentiment'])
X_train, X_test , y_train, y_test = train_test_split(docs, labels , test_size = 0.20, random_state=0)

We use the Keras text vectorization layer to convert the reviews to integer form. This layer also converts the reviews to lowercase and strips punctuation marks. We pass the following parameters:

  • standardize as lower_and_strip_punctuation to convert to lowercase and remove punctuation marks.
  • output_mode to int to get the result as integers. tf_idf would apply the TF-IDF algorithm.
  • output_sequence_length as 50 to get sequences of that length. Sequences longer than the specified length are truncated, while shorter ones are padded with zeros. I found that 50 gives good results; change this number to see how it affects the model's performance.
  • max_tokens as 10,000 to have a vocabulary size of that number. Tweak this number and check how the model's performance changes.

After defining the vectorization layer, we fit it to the training data by calling its adapt method. adapt computes the vocabulary from the provided dataset and, if max_tokens is set, truncates the vocabulary to that size.

import tensorflow as tf
max_features = 10000  # Maximum vocab size.
batch_size = 128
max_len = 50  # Sequence length to pad the outputs to.
vectorize_layer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation', max_tokens=max_features, output_mode='int', output_sequence_length=max_len)
# Build the vocabulary from the training reviews only
vectorize_layer.adapt(X_train)
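To make the parameters above concrete, the plain-Python sketch below mimics what `adapt` and the `int` output mode do:

  • standardize the text;
  • count the tokens;
  • reserve index 0 for padding and index 1 for out-of-vocabulary words;
  • truncate or zero-pad every review to a fixed length.

The helper names `standardize`, `adapt`, and `vectorize` are our own. Keras's exact tokenization and frequency tie-breaking may differ.

```python
import re
import string
from collections import Counter

def standardize(text):
    # Lowercase and strip ASCII punctuation, similar to 'lower_and_strip_punctuation'
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text.lower())

def adapt(texts, max_tokens):
    # Index 0 is reserved for padding and index 1 for out-of-vocabulary words;
    # the remaining slots go to the most frequent tokens
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    most_common = [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return ["", "[UNK]"] + most_common

def vectorize(text, vocab, sequence_length):
    index = {tok: i for i, tok in enumerate(vocab)}
    ids = [index.get(tok, 1) for tok in standardize(text).split()]  # unknown words -> 1
    ids = ids[:sequence_length]                                    # truncate long reviews
    return ids + [0] * (sequence_length - len(ids))               # zero-pad short ones

vocab = adapt(["The movie was great!", "The plot was thin."], max_tokens=5)
print(vocab)                                                      # ['', '[UNK]', 'the', 'was', 'movie']
print(vectorize("The acting was great", vocab, sequence_length=6))  # [2, 1, 3, 1, 0, 0]
```

The two reserved entries are also why the vocabulary returned by get_vocabulary starts with an empty string and `[UNK]`.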

To view the generated vocabulary, call the get_vocabulary method.

vectorize_layer.get_vocabulary()

Convert the training and test data to numerical form using the trained vectorization layer.

X_train_padded =  vectorize_layer(X_train)
X_test_padded =  vectorize_layer(X_test)

Create dataset

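Let's generate and prefetch batches from the training and test sets to make loading them into the LSTM model more efficient. We start by creating a tf.data.Dataset from the padded reviews and their labels with from_tensor_slices, and then group it into batches.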

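training_data = tf.data.Dataset.from_tensor_slices((X_train_padded, y_train))
validation_data = tf.data.Dataset.from_tensor_slices((X_test_padded, y_test))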
training_data = training_data.batch(batch_size)
validation_data = validation_data.batch(batch_size)
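By default, batch keeps a smaller final batch when the number of examples is not a multiple of batch_size, because drop_remainder defaults to False. The NumPy sketch below reproduces that slicing behavior. The generator `batch` is a hypothetical helper, and the zero arrays are placeholders rather than real reviews.

```python
import numpy as np

def batch(features, labels, batch_size):
    """Yield consecutive (features, labels) batches; the last one may be smaller,
    like tf.data's batch() with the default drop_remainder=False."""
    for start in range(0, len(features), batch_size):
        yield features[start:start + batch_size], labels[start:start + batch_size]

features = np.zeros((300, 50), dtype=np.int64)  # placeholder: 300 padded reviews of length 50
labels = np.zeros(300, dtype=np.int64)
print([x.shape for x, _ in batch(features, labels, batch_size=128)])
# [(128, 50), (128, 50), (44, 50)]
```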

Read: Loading datasets in JAX using TensorFlow.

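Next, we shuffle the order of the batches, prefetch one batch ahead, and return the data as an iterable of NumPy arrays.

# pip install tensorflow_datasets
import tensorflow_datasets as tfds
def get_train_batches():
  # Shuffle the already-batched data within a buffer of 3 batches, reshuffling every epoch
  ds = training_data.shuffle(3, reshuffle_each_iteration=True)
  # Prepare the next batch while the current one is being used
  ds = ds.prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)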

Define LSTM model in Flax

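We are now ready to define the LSTM model in Flax. Flax provides two LSTM cells: LSTMCell and OptimizedLSTMCell, a more efficient implementation of the same cell. The initialize_carry function initializes the carry of an LSTM cell, that is, its cell state and hidden state. In the Flax version this article uses, it is called on the cell class and expects:

  • A PRNG key.
  • The batch dimensions.
  • The number of units, that is, the size of the hidden state.

Newer Flax releases changed this API: initialize_carry is called on a cell instance and takes the input shape instead, so the model code below targets an older Flax release.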
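To show what the carry holds and how one LSTM step updates it, here is a minimal LSTM cell written in plain `jax.numpy`. The carry is a pair `(c, h)`, with the cell state first, each of shape `(batch, units)` and starting at zeros. As far as we know, that matches the default zero initialization of the older initialize_carry. The fused weight matrices, the gate order (input, forget, candidate, output), and the helper names are our own choices, not Flax's parameter layout.

```python
import jax
import jax.numpy as jnp

def init_carry(batch_size, units):
    # Cell state c and hidden state h both start as zeros of shape (batch, units)
    return jnp.zeros((batch_size, units)), jnp.zeros((batch_size, units))

def init_params(key, input_dim, units):
    k1, k2 = jax.random.split(key)
    return {
        "W": jax.random.normal(k1, (input_dim, 4 * units)) * 0.1,  # input-to-gates weights
        "U": jax.random.normal(k2, (units, 4 * units)) * 0.1,      # hidden-to-gates weights
        "b": jnp.zeros(4 * units),
    }

def lstm_cell(params, carry, x_t):
    c, h = carry
    z = x_t @ params["W"] + h @ params["U"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    i, f, o = jax.nn.sigmoid(i), jax.nn.sigmoid(f), jax.nn.sigmoid(o)  # input, forget, output gates
    g = jnp.tanh(g)                                                     # candidate cell update
    c = f * c + i * g     # keep part of the old memory, add new information
    h = o * jnp.tanh(c)   # expose a gated view of the memory as the output
    return (c, h), h

params = init_params(jax.random.PRNGKey(0), input_dim=50, units=128)
carry = init_carry(batch_size=4, units=128)
x_t = jnp.ones((4, 50))  # one time step for a batch of 4 embedded tokens
(c, h), y = lstm_cell(params, carry, x_t)
print(c.shape, h.shape)  # (4, 128) (4, 128)
```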

Let's use the setup method to define the LSTM model. The model contains the following layers:

  • An Embed layer whose vocabulary size matches max_tokens in the vectorization layer; its embedding dimension is set to max_len (50).
  • Three LSTM layers built with nn.scan, each processing the sequence in one direction, as set by the reverse argument.
  • A Dense layer with a ReLU activation after each LSTM layer.
  • A final Dense output layer with two units, one per sentiment class.

import jax
from flax import linen as nn

class LSTMModel(nn.Module):
    def setup(self):
        # Vocabulary of max_features tokens, each embedded into a max_len-dimensional vector
        self.embedding = nn.Embed(max_features, max_len)
        # Lift the single-step cell into a layer that loops over the time axis (axis 1)
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                             variable_broadcast="params",  # share the weights across time steps
                             split_rngs={"params": False},
                             in_axes=1,
                             out_axes=1,
                             length=max_len,
                             reverse=False)
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        self.dense4 = nn.Dense(2)

    def __call__(self, x_batch):
        x = self.embedding(x_batch)
        # Zero-initialized (cell, hidden) carry; its size sets the LSTM's number of units
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        x = self.dense1(x)
        x = nn.relu(x)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        x = self.dense2(x)
        x = nn.relu(x)
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        x = self.dense3(x)
        x = nn.relu(x)
        # Classify from the output at the last time step
        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)

We use nn.scan to iterate over the time steps. It expects:

  • scan: the items to be looped over. They must have the same size and are stacked along the scan axis.
  • carry: a value that is updated at each iteration. It must keep the same shape and dtype throughout the iteration.
  • broadcast: a value that is closed over by the loop, so every iteration sees the same value.
  • <axis:int>: the axis along which to scan.
  • split_rngs: whether to split the random number generator at each step.
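The embedded reviews are batch-major, with shape `(batch, time, features)`, so the LSTM layers need to scan along axis 1. In plain JAX, scanning along axis 1 amounts to three steps: move time to the front, run `jax.lax.scan`, and move it back. The sketch below shows this with a toy tanh cell rather than an LSTM. The carry keeps the same `(batch, features)` shape at every step, as the carry rule above requires.

```python
import jax
import jax.numpy as jnp

def rnn_cell(carry, x_t):
    # Toy cell: the new hidden state depends on the previous state and the current input
    h = jnp.tanh(carry + x_t)
    return h, h

def scan_over_time(xs, h0):
    """Scan rnn_cell over axis 1 of a (batch, time, features) array."""
    xs_time_major = jnp.swapaxes(xs, 0, 1)                  # (time, batch, features)
    h_final, ys = jax.lax.scan(rnn_cell, h0, xs_time_major)
    return h_final, jnp.swapaxes(ys, 0, 1)                  # back to (batch, time, features)

xs = jnp.ones((2, 50, 8))   # 2 sequences, 50 time steps, 8 features
h0 = jnp.zeros((2, 8))      # carry: one state per sequence, no time axis
h_final, ys = scan_over_time(xs, h0)
print(ys.shape, h_final.shape)  # (2, 50, 8) (2, 8)
```

Passing `reverse=True` to `jax.lax.scan` would process the time steps from last to first. This is the plain-JAX counterpart of the reverse argument.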

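Optionally, wrapping the cell in nn.remat, for example nn.scan(nn.remat(nn.OptimizedLSTMCell), ...), saves memory when computing long sequences. It recomputes intermediate activations during the backward pass instead of storing them.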
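Flax's nn.remat is the module-level wrapper around JAX's rematerialization transform, `jax.checkpoint` (also available as `jax.remat`). The plain-JAX sketch below uses a toy recurrence of our own. It checks that checkpointing the scan body trades extra compute for memory without changing the gradients. The memory savings themselves are not measured here.

```python
import jax
import jax.numpy as jnp

def step(h, x_t):
    h = jnp.tanh(0.9 * h + x_t)
    return h, h

def loss(xs, remat):
    # With remat=True, intermediates of each step are recomputed in the backward pass
    body = jax.checkpoint(step) if remat else step
    h_final, _ = jax.lax.scan(body, jnp.zeros(()), xs)
    return h_final ** 2

xs = jnp.linspace(-1.0, 1.0, 200)
g_plain = jax.grad(loss)(xs, False)
g_remat = jax.grad(loss)(xs, True)
```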

Compute metrics in Flax

Next, we define a function to compute the loss and accuracy of the network.

import jax
import optax
import jax.numpy as jnp

def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=2)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
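LSTMModel already returns `nn.log_softmax` outputs, and compute_metrics passes them to `optax.softmax_cross_entropy`, which applies its own log-softmax. This double application is harmless because log-softmax is idempotent: normalized log-probabilities already have a log-sum-exp of 0. The sketch below checks this in plain JAX. `softmax_cross_entropy` here is our own re-implementation of the formula, `-sum(labels * log_softmax(logits))`, not optax's code, and the scores are made-up numbers.

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, one_hot_labels):
    # Per-example cross-entropy: -sum(labels * log_softmax(logits))
    return -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)

raw_scores = jnp.array([[2.0, -1.0], [0.3, 0.8]])
log_probs = jax.nn.log_softmax(raw_scores)   # the kind of output LSTMModel returns
labels = jnp.array([0, 1])
one_hot = jax.nn.one_hot(labels, num_classes=2)

loss_from_scores = softmax_cross_entropy(raw_scores, one_hot).mean()
loss_from_log_probs = softmax_cross_entropy(log_probs, one_hot).mean()
accuracy = jnp.mean(jnp.argmax(log_probs, -1) == labels)
```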

Read more: JAX loss functions.

Create training state

The training state applies gradients and updates the parameters and optimizer state. Flax provides train_state for this purpose. We define a function that:

  • Creates an instance of the LSTMModel.
  • Initializes the model to obtain the params by passing a sample of the training data.
  • Creates and returns the training state, pairing the model's apply function and parameters with the Adam optimizer.

from flax.training import train_state

def create_train_state(rng):
  """Creates initial `TrainState`."""
  model = LSTMModel()
  # Initialize with a batch containing one review, shape (1, max_len)
  params = model.init(rng, jnp.array(X_train_padded[:1]))['params']
  tx = optax.adam(learning_rate=0.001, b1=0.9, b2=0.999, eps=1e-07)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

Define training step

The training function does the following:

  • Compute the loss and logits from the model with the apply method.
  • Compute the gradients using value_and_grad.
  • Use the gradients to update the model parameters.
  • Compute the metrics using the function defined earlier.
  • Return the state and metrics.

Decorating the function with @jax.jit compiles it with XLA, so every call after the first runs much faster.

@jax.jit
def train_step(state, text, labels):
  def loss_fn(params):
    logits = LSTMModel().apply({'params': params}, text)
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits, labels=jax.nn.one_hot(labels, num_classes=2)))
    return loss, logits
  # has_aux=True: loss_fn returns (loss, logits) and only the loss is differentiated
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  # Update the parameters and the optimizer state
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, labels)
  return state, metrics
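The same `value_and_grad(..., has_aux=True)` pattern can be tried without Flax or optax. In the sketch below, `model_apply` is a hypothetical linear classifier standing in for `LSTMModel().apply`. A hand-written SGD update stands in for `state.apply_gradients`, and the inputs are toy values. The call returns `((loss, logits), grads)`, with gradients taken with respect to the loss only.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for this sketch

def model_apply(params, x):
    # Hypothetical linear classifier standing in for LSTMModel().apply
    return jax.nn.log_softmax(x @ params["w"] + params["b"])

@jax.jit
def sgd_train_step(params, x, labels):
    def loss_fn(p):
        logits = model_apply(p, x)
        loss = -jnp.mean(jnp.sum(jax.nn.one_hot(labels, 2) * logits, axis=-1))
        return loss, logits  # (value to differentiate, auxiliary output)
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    # Plain SGD in place of state.apply_gradients(grads=grads)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
x = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
labels = jnp.array([0, 1])
losses = []
for _ in range(50):
    params, loss = sgd_train_step(params, x, labels)
    losses.append(float(loss))
```

With all-zero parameters, the first loss is log 2 (about 0.693), and it should fall as the steps proceed.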

Evaluate the Flax model

The eval_step evaluates the model's performance on the test set using Module.apply. It returns the loss and accuracy on the testing set.

The evaluate_model function applies eval_step, copies the metrics from the device to the host with jax.device_get, and uses a tree map to convert each metric to a Python float.

@jax.jit
def eval_step(state, text, labels):
    logits = LSTMModel().apply({'params': state.params}, text)
    return compute_metrics(logits=logits, labels=labels)

def evaluate_model(state, text, test_lbls):
    """Evaluate on the validation set."""
    metrics = eval_step(state, text, test_lbls)
    metrics = jax.device_get(metrics)
    # Convert each 0-d array into a Python float
    metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return metrics
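The conversion step on its own looks like this. The metric values below are placeholders, not results. `jax.tree_util.tree_map` is used because newer JAX releases deprecate the shorter `jax.tree_map` alias.

```python
import jax
import jax.numpy as jnp

metrics = {"loss": jnp.array(0.25), "accuracy": jnp.array(0.5)}  # placeholder device values
metrics = jax.device_get(metrics)                                # copy to host NumPy arrays
metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)    # 0-d arrays -> Python floats
print(metrics)  # {'loss': 0.25, 'accuracy': 0.5}
```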

Create training function

Next, define a function that trains the Flax LSTM model for one epoch. The function applies train_step to each batch in the training data and appends the batch's metrics to a list.

def train_one_epoch(state):
    """Train for 1 epoch on the training set."""
    batch_metrics = []
    for text, labels in get_train_batches():
        state, metrics = train_step(state, text, labels)
        batch_metrics.append(metrics)

    # Copy the per-batch metrics to the host and average them over the epoch
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    return state, epoch_metrics_np

The function obtains the metrics from the device and computes the mean from all the trained batches. This gives the loss and accuracy for one epoch.
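`np.mean` gives every batch the same weight, including the smaller last batch. If you prefer a true per-example average, you can weight each batch by its size. The helper `weighted_epoch_metrics` below is a hypothetical variant, not part of the original code, and the metric values are placeholders. In practice the difference from the plain mean is usually small.

```python
import numpy as np

def weighted_epoch_metrics(batch_metrics, batch_sizes):
    """Average per-batch metrics, weighting each batch by its number of examples."""
    weights = np.asarray(batch_sizes, dtype=np.float64)
    return {
        k: float(np.average([m[k] for m in batch_metrics], weights=weights))
        for k in batch_metrics[0]
    }

# Placeholder values: two full batches and a smaller final batch
batch_metrics = [{"accuracy": 0.5}, {"accuracy": 0.75}, {"accuracy": 1.0}]
print(weighted_epoch_metrics(batch_metrics, [128, 128, 44]))
# {'accuracy': 0.68}  (an unweighted np.mean would give 0.75)
```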

Train LSTM model in Flax

To train the LSTM model, we run train_one_epoch for several epochs and, after each epoch, apply evaluate_model to obtain the test metrics. Before training starts, we call create_train_state. It initializes the model parameters and the optimizer and stores them in a TrainState dataclass.

rng = jax.random.PRNGKey(0)
rng, input_rng, init_rng = jax.random.split(rng, num=3)

state = create_train_state(init_rng)
del init_rng  # Must not be used anymore.

num_epochs = 30
# Evaluate on the first validation batch (batch_size reviews) to keep each epoch fast
(text, test_labels) = next(iter(validation_data))
text = jnp.array(text)
test_labels = jnp.array(test_labels)
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []

def train_model(state):
    for epoch in range(1, num_epochs + 1):
        # Feed the updated state back in so training continues across epochs
        state, train_metrics = train_one_epoch(state)
        test_metrics = evaluate_model(state, text, test_labels)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])
        print(f"Epoch: {epoch}, train loss: {train_metrics['loss']}, train accuracy: {train_metrics['accuracy'] * 100}, test loss: {test_metrics['loss']}, test accuracy: {test_metrics['accuracy'] * 100}")
    return state

trained_model_state = train_model(state)

After each epoch, we print the metrics and append them to a list.

Read more: TensorFlow Recurrent Neural Networks (Complete guide with examples and code)

Visualize LSTM model performance in Flax

You can then use Matplotlib to visualize the metrics stored in these lists. The training curves are not quite smooth, but you can tweak the network architecture, the length of each review, and the vocabulary size to improve performance.
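A minimal plotting sketch follows. `plot_metrics` is a hypothetical helper that draws the per-epoch loss and accuracy curves side by side.

```python
import matplotlib.pyplot as plt

def plot_metrics(training_loss, testing_loss, training_accuracy, testing_accuracy):
    """Plot per-epoch loss and accuracy curves side by side."""
    epochs = range(1, len(training_loss) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(epochs, training_loss, label="train")
    ax_loss.plot(epochs, testing_loss, label="test")
    ax_loss.set(title="Loss", xlabel="Epoch")
    ax_acc.plot(epochs, training_accuracy, label="train")
    ax_acc.plot(epochs, testing_accuracy, label="test")
    ax_acc.set(title="Accuracy", xlabel="Epoch")
    for ax in (ax_loss, ax_acc):
        ax.legend()
    return fig
```

After training, `plot_metrics(training_loss, testing_loss, training_accuracy, testing_accuracy)` followed by `plt.show()` should display both panels. This assumes the training loop has appended one value per epoch to each list.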

Save LSTM model

To save a Flax model checkpoint, use the save_checkpoint function from flax.training.checkpoints. It expects:

  • The directory to save the checkpoint files in (ckpt_dir).
  • The Flax object to be saved (target).
  • The training step or epoch number (step).
  • The prefix of the checkpoint file name (prefix).
  • Whether to overwrite previous checkpoints (overwrite).
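
from flax.training import checkpoints
checkpoints.save_checkpoint(os.path.abspath('lstm_checkpoints'), trained_model_state, step=num_epochs, overwrite=True)  # 'lstm_checkpoints' is a hypothetical directory name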

To restore the saved model, use the restore_checkpoint function.

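loaded_model = checkpoints.restore_checkpoint(os.path.abspath('lstm_checkpoints'), target=state)  # hypothetical directory; must match the one used when saving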

The restored TrainState can be used to make predictions right away.
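A minimal prediction sketch follows. `predict_sentiment` and `CLASS_NAMES` are hypothetical names. The class order assumes LabelEncoder's sorted classes, so 0 means negative and 1 means positive. The function only needs an apply function and parameters, so it should work with the restored state's `apply_fn` and `params`.

```python
import jax.numpy as jnp

CLASS_NAMES = ("negative", "positive")  # LabelEncoder's sorted classes: 0 -> negative, 1 -> positive

def predict_sentiment(apply_fn, params, token_ids):
    """Return a class name for each row of a (batch, max_len) array of token ids."""
    log_probs = apply_fn({"params": params}, token_ids)  # (batch, 2) log-probabilities
    return [CLASS_NAMES[i] for i in jnp.argmax(log_probs, axis=-1).tolist()]
```

For example, `predict_sentiment(loaded_model.apply_fn, loaded_model.params, jnp.array(vectorize_layer(["This movie was wonderful"])))` should return a one-element list. New text must go through the same vectorization layer as the training data.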

Final thoughts

You have learned to solve natural language processing problems with JAX and Flax in this article. In particular, the nuggets you have covered include:

  • How to process text data with NLTK.
  • Text vectorization with Keras.
  • Creating batches of text data with Keras and TensorFlow.
  • How to create LSTM models in JAX and Flax.
  • How to train and evaluate the LSTM model in Flax.
  • Saving and restoring Flax LSTM models.



Derrick Mwiti

Google Developer Expert - Machine Learning