How to load datasets in JAX with TensorFlow

Derrick Mwiti
5 min read

JAX doesn't ship with data loading utilities. This keeps JAX focused on providing a fast tool for building and training machine learning models. Loading data in JAX is done using either TensorFlow or PyTorch. In the Image classification with JAX & Flax tutorial, we saw how to load image data with PyTorch. This article will focus on how to load datasets in JAX using TensorFlow.

Let's dive in!

How to load text data in JAX

Let's use the IMDB dataset from Kaggle to illustrate how to load text datasets with JAX. We'll use the Kaggle Python library to download the data. That requires your Kaggle username and API key, which you can obtain from your Kaggle account page.

The library downloads the data as a zip file. We'll therefore extract it afterward.  
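The extraction step can be sketched with Python's standard library alone. This is a minimal sketch, not the article's original code: the helper name extract_archive and the archive path in the usage note are hypothetical, and the Kaggle download itself is not shown because it needs your own credentials.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path, dest_dir):
    """Extract every member of a zip archive into dest_dir.

    Returns the list of member names so the caller can see what was unpacked.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return archive.namelist()
```

Calling extract_archive("imdb.zip", "data") would unpack a downloaded archive into a data folder; both names are placeholders for whatever the download step produced.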

Next, import the standard data science packages and view a sample of the data.

Clean the text data

Let's do some processing of the data before we load it using TensorFlow. A standard step in text problems is removing stop words. Stop words are common words, such as "a" and "the", that don't help the model identify the polarity of a sentence. NLTK provides a list of stop words, so we can write a function to remove them from the IMDB dataset.
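A stop-word filter might look like the sketch below. To keep it runnable without downloading any corpus, it uses a tiny hand-written stop-word set as a stand-in; the function name remove_stop_words and the STOP_WORDS constant are assumptions made for illustration.

```python
# Tiny illustrative stop-word set. In practice, NLTK's
# nltk.corpus.stopwords.words("english") gives a much fuller list
# after a one-time nltk.download("stopwords").
STOP_WORDS = {"a", "an", "the", "is", "was", "of", "and", "to", "in", "it"}


def remove_stop_words(text, stop_words=STOP_WORDS):
    """Drop every whitespace-separated word that appears in stop_words."""
    kept = [word for word in text.split() if word.lower() not in stop_words]
    return " ".join(kept)


cleaned = remove_stop_words("The movie was a triumph of style")
print(cleaned)  # movie triumph style
```

With the data in a pandas DataFrame, the same function could be applied to the review text column with Series.apply; the column name depends on the CSV you downloaded.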

Label encode the sentiment column

Convert the sentiment column to numerical representation using Scikit-learn's label encoder. This is important because neural networks expect numerical data.  
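As a small sketch of this step, the snippet below encodes a hand-made list of sentiments instead of the real column. The sample values are made up for illustration.

```python
from sklearn.preprocessing import LabelEncoder

# Stand-in for the sentiment column of the IMDB data.
sentiments = ["positive", "negative", "negative", "positive"]

encoder = LabelEncoder()
labels = encoder.fit_transform(sentiments)

# Classes are sorted alphabetically, so negative -> 0 and positive -> 1.
print(list(encoder.classes_))  # ['negative', 'positive']
print(list(labels))            # [1, 0, 0, 1]
```

The fitted encoder can map predictions back to text labels later with encoder.inverse_transform.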

Text preprocessing with TensorFlow

We have converted the sentiment column to a numerical representation. However, the reviews are still in text form. We need to convert them to numbers as well.

We start by splitting the dataset into a training and testing set.  
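One common way to do the split is Scikit-learn's train_test_split. The sketch below runs on synthetic reviews; the 80/20 ratio, the random seed, and the stratification are assumptions rather than the article's settings.

```python
from sklearn.model_selection import train_test_split

# Synthetic stand-ins for the cleaned reviews and encoded sentiments.
reviews = [f"review number {i}" for i in range(10)]
labels = [i % 2 for i in range(10)]

X_train, X_test, y_train, y_test = train_test_split(
    reviews,
    labels,
    test_size=0.2,    # hold out 20% of the rows for testing (assumed ratio)
    random_state=0,   # fixed seed so the split is reproducible
    stratify=labels,  # keep the class balance the same in both splits
)
```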

Next, we use TensorFlow's TextVectorization layer to convert the text data to integer representations. The layer takes the following arguments:

  • standardize specifies how the text is processed. For example, the lower_and_strip_punctuation option lowercases the text and removes punctuation.
  • max_tokens sets the maximum size of the vocabulary.
  • output_mode determines the output of the vectorization layer. Setting it to int outputs integers.
  • output_sequence_length sets the length of the output sequences. Shorter sequences are padded and longer ones truncated, so all sequences have the same length.

```python
import tensorflow as tf

max_features = 5000  # Maximum vocab size.
batch_size = 32
max_len = 512  # Sequence length to pad the outputs to.

vectorize_layer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=max_len,
)
```

Next, apply this layer to the training and testing data.
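Applying the layer takes two steps: adapt builds the vocabulary, and calling the layer maps text to integers. The sketch below uses a few made-up sentences and deliberately small sizes so it runs quickly. It adapts on the training text only, which is an assumption about good practice rather than something the article states.

```python
import numpy as np
import tensorflow as tf

# Made-up stand-ins for the training and testing reviews.
train_texts = ["a great and moving film", "dull plot and weak acting", "great acting"]
test_texts = ["a dull film"]

vectorize_layer = tf.keras.layers.TextVectorization(
    standardize="lower_and_strip_punctuation",
    max_tokens=50,             # small vocabulary for the toy corpus
    output_mode="int",
    output_sequence_length=8,  # short sequences for the toy corpus
)

# Build the vocabulary from the training text only.
vectorize_layer.adapt(tf.data.Dataset.from_tensor_slices(train_texts).batch(2))

# Map each review to a fixed-length row of token ids.
X_train = np.asarray(vectorize_layer(tf.constant(train_texts)))
X_test = np.asarray(vectorize_layer(tf.constant(test_texts)))
```

Each row has exactly output_sequence_length entries, padded with zeros on the right.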

Convert the data to a TensorFlow dataset and create a function to fetch the data in batches. We also convert the data to NumPy arrays because JAX expects NumPy or JAX arrays.  
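A batching helper along these lines could look like the sketch below. The name get_batches and its parameters are hypothetical; it relies on tf.data's as_numpy_iterator to hand back NumPy arrays that JAX can consume directly.

```python
import numpy as np
import tensorflow as tf


def get_batches(features, labels, batch_size, shuffle=False, seed=0):
    """Yield (features, labels) batches as NumPy arrays."""
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(features), seed=seed)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds.as_numpy_iterator()


# Toy data: 10 "reviews" of 4 token ids each.
X = np.arange(40, dtype=np.int32).reshape(10, 4)
y = np.arange(10, dtype=np.int32) % 2

batches = list(get_batches(X, y, batch_size=4))
batch_shapes = [inputs.shape for inputs, _ in batches]
print(batch_shapes)  # [(4, 4), (4, 4), (2, 4)]
```

The last batch is smaller because 10 is not a multiple of 4. Passing drop_remainder=True to batch would discard it, which helps avoid recompiling jitted functions for a new shape.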

The data is now in the right format and ready to be passed to a Flax network. We have already seen how to build a network in Flax.

Let's quickly walk through the rest of the steps required to train neural networks in Flax using this data. Read the image classification with JAX and Flax article for more details about training models in Flax.


First, create a simple neural network in Flax.

Define a function to compute the loss.

Next, define the function to compute the network metrics.

The training state is used to track the network training. It tracks the optimizer and model parameters and can be modified to track other things such as dropout and batch normalization statistics.    

In the training step, we apply the model to the training data to obtain the loss. The loss is then used to compute the gradients that update the model parameters.

The evaluation step applies the model to the testing data to compute the test metrics.

The evaluation function runs the above evaluation step to obtain the evaluation metrics.  

We use the get_train_batches function in the train_epoch function. We loop through the batches, applying the train_step function to each one, then collect the training metrics and return them.

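The final step is to train the network on the training set and evaluate it on the test set. A training state must be created before training the model. This is because JAX expects pure functions, so the parameters and optimizer state are passed into each step explicitly instead of being stored inside the model.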


How to load image data in JAX

Let's now see how we can load image data with TensorFlow. We'll use the popular cats and dogs images from Kaggle. We start by downloading the data.
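One possible way to load a folder of images with TensorFlow is Keras's image_dataset_from_directory utility. The sketch below is an assumption about the approach, not the article's own code. It expects one subfolder per class (for example, one for cats and one for dogs), and the function name and default sizes are hypothetical.

```python
import tensorflow as tf


def load_image_batches(directory, image_size=(128, 128), batch_size=32):
    """Yield (images, labels) NumPy batches from a folder with one subfolder per class."""
    ds = tf.keras.utils.image_dataset_from_directory(
        directory,
        labels="inferred",      # class labels come from the subfolder names
        label_mode="int",       # integer labels, one per image
        image_size=image_size,  # every image is resized to this shape
        batch_size=batch_size,
        shuffle=True,
        seed=0,
    )
    # Scale pixel values from [0, 255] to [0, 1].
    ds = ds.map(lambda images, labels: (images / 255.0, labels))
    return ds.as_numpy_iterator()
```

As with the text data, as_numpy_iterator returns NumPy arrays, so each batch can be passed straight to a jitted JAX function.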
