How to load datasets in JAX with TensorFlow

JAX doesn't ship with data loading utilities. This keeps JAX focused on providing a fast tool for building and training machine learning models. Loading data for JAX is therefore typically done with TensorFlow or PyTorch. In the Image classification with JAX & Flax tutorial, we saw how to load image data with PyTorch. This article focuses on how to load datasets in JAX using TensorFlow.
Let's dive in!
How to load text data in JAX
Let's use the IMDB dataset from Kaggle to illustrate how to load text datasets with JAX. We'll use the Kaggle Python library to download the data, which requires your Kaggle username and API key. Head over to https://www.kaggle.com/your_username/account to obtain the API key.
The library downloads the data as a zip file. We'll therefore extract it afterward.
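Here is a minimal sketch of the download-and-extract step. The dataset slug below is an assumption (the article doesn't name the exact dataset identifier), and the `data` destination folder is arbitrary. `kaggle.api.dataset_download_files` is the Kaggle library's dataset download call; the archive is then unpacked with Python's standard `zipfile` module.

```python
import zipfile
from pathlib import Path

# Assumption: this slug points at the "IMDB Dataset of 50K Movie Reviews" on
# Kaggle; the article does not state which identifier it used.
DATASET_SLUG = "lakshmi25npathi/imdb-dataset-of-50k-movie-reviews"


def download_dataset(slug: str, dest: str = "data") -> Path:
    """Download a Kaggle dataset as a zip archive into `dest`."""
    # Importing `kaggle` authenticates immediately, using the KAGGLE_USERNAME and
    # KAGGLE_KEY environment variables or ~/.kaggle/kaggle.json.
    import kaggle

    kaggle.api.dataset_download_files(slug, path=dest, unzip=False)
    # The archive is named after the last part of the slug.
    return Path(dest) / f"{slug.split('/')[-1]}.zip"


def extract_zip(zip_path, dest: str = "data") -> list:
    """Extract every file in the archive and return the extracted names."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return archive.namelist()


# Usage (needs network access and Kaggle credentials):
# archive_path = download_dataset(DATASET_SLUG)
# extract_zip(archive_path)
```

Keeping the extraction in its own function means it can be reused for any other Kaggle archive, such as the cats and dogs images later in the article.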
Next, import the standard data science packages and view a sample of the data.
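A short sketch of loading and inspecting the reviews with Pandas. The file name `IMDB Dataset.csv` and the `review` / `sentiment` column names are assumptions based on the Kaggle IMDB dataset; adjust them to whatever the archive actually contains.

```python
import numpy as np
import pandas as pd


def load_reviews(csv_path: str = "data/IMDB Dataset.csv") -> pd.DataFrame:
    """Read the reviews CSV into a DataFrame (path is an assumed location)."""
    return pd.read_csv(csv_path)


# Usage:
# df = load_reviews()
# df.head()                        # first five rows
# df["sentiment"].value_counts()   # check the class balance
# np.mean(df["review"].str.split().str.len())  # average review length in words
```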

Clean the text data
Let's process the data before loading it with TensorFlow. A standard preprocessing step in text problems is removing stop words. Stop words are common words, such as "a" and "the", that don't help the model identify the polarity of a sentence. NLTK provides a list of stop words, so we can write a function that removes them from the IMDB reviews.
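One way to write that function, as a sketch: `stopwords.words("english")` and `nltk.download("stopwords")` are NLTK's stop-word corpus calls, while the function names here are hypothetical. Lowercasing only for the lookup is a design choice so that "The" and "the" are both removed.

```python
def load_english_stop_words() -> set:
    """Fetch NLTK's English stop-word list (downloaded on first use)."""
    import nltk
    from nltk.corpus import stopwords

    nltk.download("stopwords", quiet=True)
    return set(stopwords.words("english"))


def remove_stop_words(text: str, stop_words: set) -> str:
    """Drop every whitespace-separated token that appears in `stop_words`."""
    # Compare in lowercase so capitalised stop words are removed too.
    kept = [word for word in text.split() if word.lower() not in stop_words]
    return " ".join(kept)


# Usage (needs NLTK, plus network access the first time):
# stop_words = load_english_stop_words()
# df["review"] = df["review"].apply(remove_stop_words, stop_words=stop_words)
```

Passing the stop-word set as an argument keeps `remove_stop_words` independent of NLTK, which also makes it easy to test with a small hand-written set.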
Label encode the sentiment column
Convert the sentiment column to a numerical representation using Scikit-learn's LabelEncoder. This is necessary because neural networks expect numerical data.
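A minimal sketch with `sklearn.preprocessing.LabelEncoder`. The small array stands in for `df["sentiment"]`. LabelEncoder assigns integers to the sorted class names, so "negative" becomes 0 and "positive" becomes 1.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Toy stand-in for df["sentiment"]; pass the real column in practice.
sentiment = np.array(["positive", "negative", "negative", "positive"])

encoder = LabelEncoder()
labels = encoder.fit_transform(sentiment)  # classes are sorted alphabetically

print(encoder.classes_)  # ['negative' 'positive']
print(labels)            # [1 0 0 1]
```

With the DataFrame, the same call would be `df["sentiment"] = encoder.fit_transform(df["sentiment"])`.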

Text preprocessing with TensorFlow
We have converted the sentiment column to a numerical representation. However, the reviews are still in text form. We need to convert them to numbers as well.
We start by splitting the dataset into a training and testing set.
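A sketch of the split using Scikit-learn's `train_test_split`. The 80/20 ratio and `random_state=42` are illustrative choices, not values from the article, and the arrays are placeholders for the reviews and encoded labels.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholders standing in for df["review"] and the encoded labels.
reviews = np.array([f"review number {i}" for i in range(10)])
labels = np.array([0, 1] * 5)

# Hold out 20% of the examples for testing; a fixed random_state makes the
# split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    reviews, labels, test_size=0.2, random_state=42
)
print(len(X_train), len(X_test))  # 8 2
```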
Next, we use TensorFlow's TextVectorization layer to convert the text data to integer representations. The layer takes the following arguments:
- standardize: specifies how the text is processed. For example, the lower_and_strip_punctuation option lowercases the text and removes punctuation.
- max_tokens: the maximum size of the vocabulary.
- output_mode: determines the output of the vectorization layer. Setting it to int outputs integer token indices.
- output_sequence_length: the length every output sequence is padded or truncated to. This ensures that all sequences have the same length.
import tensorflow as tf
max_features = 5000  # Maximum vocabulary size.
batch_size = 32
max_len = 512  # Pad (or truncate) every output sequence to this length.
vectorize_layer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=max_len)
# Build the vocabulary from the training reviews (X_train) only.
vectorize_layer.adapt(X_train, batch_size=None)
Next, apply this layer to the training and testing data.
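Applying the adapted layer is a plain function call. The sketch below is self-contained: it uses three toy reviews and a sequence length of 8 instead of 512 so the output is easy to inspect. `np.asarray` converts the layer's output tensor to a NumPy array.

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins for the real X_train / X_test review arrays.
X_train = np.array(["a great movie", "a terrible plot", "great acting"])
X_test = np.array(["terrible acting"])

vectorize_layer = tf.keras.layers.TextVectorization(
    standardize="lower_and_strip_punctuation",
    max_tokens=5000,
    output_mode="int",
    output_sequence_length=8,  # 512 in the article; shortened for the demo
)
vectorize_layer.adapt(X_train, batch_size=None)

# Each review becomes a fixed-length row of token ids; shorter reviews are
# padded with 0 and longer ones are truncated.
X_train_padded = np.asarray(vectorize_layer(X_train))
X_test_padded = np.asarray(vectorize_layer(X_test))
print(X_train_padded.shape, X_test_padded.shape)  # (3, 8) (1, 8)
```

Only the training reviews are used in `adapt`, so the test set cannot leak into the vocabulary; unseen test words map to the out-of-vocabulary id.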

Convert the data to a TensorFlow dataset and create a function to fetch the data in batches. We also convert the data to NumPy arrays because JAX expects NumPy or JAX arrays.
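A sketch of the batching helpers, assuming vectorized arrays like the ones produced above (random placeholders here). `get_train_batches` is the name the article uses later; `get_test_batches` and the shuffling are assumptions. `as_numpy_iterator()` turns every `tf.Tensor` into a NumPy array, which JAX accepts.

```python
import numpy as np
import tensorflow as tf

batch_size = 32
rng = np.random.default_rng(0)

# Random placeholders for the vectorized reviews and encoded labels.
X_train_padded = rng.integers(0, 5000, size=(100, 512))
y_train = rng.integers(0, 2, size=(100,))
X_test_padded = rng.integers(0, 5000, size=(40, 512))
y_test = rng.integers(0, 2, size=(40,))


def get_train_batches(batch_size=batch_size):
    """Return an iterator of shuffled (inputs, labels) NumPy batches."""
    ds = tf.data.Dataset.from_tensor_slices((X_train_padded, y_train))
    ds = ds.shuffle(len(X_train_padded)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds.as_numpy_iterator()


def get_test_batches(batch_size=batch_size):
    """Return an iterator of (inputs, labels) NumPy batches in a fixed order."""
    ds = tf.data.Dataset.from_tensor_slices((X_test_padded, y_test))
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE).as_numpy_iterator()
```

Each call builds a fresh iterator, so calling `get_train_batches()` once per epoch yields a newly shuffled pass over the data. The last batch is smaller when the dataset size isn't a multiple of the batch size.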
The data is now in the right format to be passed to a Flax network. We have already seen how to build a network in Flax.
Let's quickly walk through the rest of the steps required to train neural networks in Flax using this data. Read the Image classification with JAX & Flax article for more details about training models in Flax.

First, create a simple neural network in Flax.
Define a function to compute the loss.
Next, define the function to compute the network metrics.
The training state tracks the network's training. It holds the model parameters and the optimizer state, and it can be extended to track other things, such as random keys for dropout and batch normalization statistics.
In the training step, we apply the model to obtain the loss. The gradients of this loss with respect to the parameters are then used to update the model parameters.
The evaluation step applies the model to the testing data to compute the test metrics.
The evaluation function runs the above evaluation step to obtain the evaluation metrics.
We use the get_train_batches function in the train_epoch function. It loops through the batches, applies the train_step function to each one, and returns the resulting training metrics.
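The final step is to train the network on the training set and evaluate it on the test set. Before training, we create a training state. This is required because JAX expects pure functions: the model parameters and optimizer state can't be updated in place, so they are passed into every step and returned, updated, inside the training state.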


How to load image data in JAX
Let's now see how we can load image data with TensorFlow. We'll use the popular cats and dogs images from Kaggle. We start by downloading the data.
Next, create a Pandas DataFrame containing the labels and paths to the images.
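A sketch of building that DataFrame. It assumes the Kaggle cats and dogs training folder uses file names like `cat.0.jpg` and `dog.0.jpg`, so the label is the part before the first dot; `build_image_dataframe` is a hypothetical helper. Full paths are stored so the images can later be loaded without a base directory.

```python
import os

import pandas as pd


def build_image_dataframe(image_dir: str) -> pd.DataFrame:
    """List .jpg files and derive each label from the filename prefix."""
    filenames = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
    return pd.DataFrame({
        "filename": [os.path.join(image_dir, f) for f in filenames],
        "label": [f.split(".")[0] for f in filenames],  # "cat" or "dog"
    })


# Usage (the folder name is an assumption about where the archive was extracted):
# df = build_image_dataframe("data/train")
# df["label"].value_counts()
```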
The next step is to define an ImageDataGenerator for scaling the images and performing simple augmentation.
Load the images using the flow_from_dataframe method of these generators. It matches the image paths in the DataFrame to the images we downloaded.
Loop through the training set to confirm that a batch of images is being generated.
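A self-contained sketch of the generator steps. It writes a few random JPEGs to a temporary folder as a stand-in for the Kaggle images; the augmentation values, the 32x32 target size, and the batch size are illustrative. Note that `ImageDataGenerator` is a legacy API in recent TensorFlow/Keras releases, so it may emit deprecation warnings or be unavailable depending on your version.

```python
import os
import tempfile

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image

ImageDataGenerator = tf.keras.preprocessing.image.ImageDataGenerator

# Hypothetical stand-in for the Kaggle images: eight random 64x64 JPEGs.
image_dir = tempfile.mkdtemp()
rows = []
for i in range(8):
    label = "cat" if i % 2 == 0 else "dog"
    path = os.path.join(image_dir, f"{label}.{i}.jpg")
    Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)).save(path)
    rows.append({"filename": path, "label": label})
df = pd.DataFrame(rows)

# Scale pixels to [0, 1]; apply light augmentation to training images only.
train_datagen = ImageDataGenerator(rescale=1.0 / 255, shear_range=0.2,
                                   zoom_range=0.2, horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

flow_args = dict(directory=None,  # filenames are already full paths
                 x_col="filename", y_col="label", target_size=(32, 32),
                 class_mode="binary", batch_size=4)  # labels become 0.0 / 1.0
train_generator = train_datagen.flow_from_dataframe(df, **flow_args)
test_generator = test_datagen.flow_from_dataframe(df, shuffle=False, **flow_args)

# The generators loop forever, so stop after the first batch.
for images, labels in train_generator:
    print(images.shape, labels.shape)  # (4, 32, 32, 3) (4,)
    break
```

The batches are already NumPy arrays, so they can be passed to a Flax model without further conversion.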
The next step is to define a network and pass the data. The steps are similar to what we did for the text data above. Check the Image classification with JAX & Flax article to see how to train convolutional neural networks with JAX and Flax.
How to load CSV data in JAX
You can use Pandas to load CSV data, as we did for the text data at the beginning of the article. Convert the data to NumPy or JAX arrays once preprocessing is done: passing PyTorch or TensorFlow tensors directly to JAX neural networks typically raises an error.
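A minimal sketch of that workflow. The CSV contents and column names are hypothetical; in practice you would pass a file path to `pd.read_csv`. `DataFrame.to_numpy` produces a NumPy array and `jnp.asarray` turns it into a JAX array.

```python
import io

import jax.numpy as jnp
import pandas as pd

# Hypothetical CSV contents; use pd.read_csv("your_file.csv") for a real file.
csv_text = """feature_a,feature_b,target
0.5,1.0,1
1.5,-0.5,0
2.0,0.0,1
"""
df = pd.read_csv(io.StringIO(csv_text))

# Preprocess with Pandas, then hand JAX plain arrays.
features = jnp.asarray(df[["feature_a", "feature_b"]].to_numpy(dtype="float32"))
targets = jnp.asarray(df["target"].to_numpy())
print(features.shape, targets.shape)  # (3, 2) (3,)
```

If the data is already a TensorFlow tensor, `.numpy()` gives a NumPy array; for a PyTorch tensor, `.cpu().numpy()` does the same.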
Final thoughts
This article shows how you can use TensorFlow to load datasets in JAX and Flax applications. We have walked through an example of loading text data with TensorFlow. After that, we discussed loading image and CSV data in JAX.
Resources
- What is JAX?
- Elegy (high-level API for deep learning in JAX & Flax)
- Flax vs. TensorFlow
- JAX loss functions
- Optimizers in JAX and Flax
- Distributed training with JAX & Flax
- How to load datasets in JAX with TensorFlow
- How to use TensorBoard in Flax
- Building convolutional neural networks with JAX and Flax
- LSTM in JAX & Flax