How to load datasets in JAX with TensorFlow
JAX doesn't ship with data loading utilities. This keeps JAX focused on providing a fast tool for building and training machine learning models. Loading data in JAX is done using either TensorFlow or PyTorch. In the Image classification with JAX & Flax tutorial, we saw how to load image data with PyTorch. This article will focus on how to load datasets in JAX using TensorFlow.
Let's dive in!
How to load text data in JAX
Let's use the IMDB dataset from Kaggle to illustrate how to load text datasets with JAX. We'll use the Kaggle Python library to download the data. That requires your Kaggle username and API key. Head over to https://www.kaggle.com/your_username/account to obtain the API key.
The library downloads the data as a zip file. We'll therefore extract it afterward.
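The extraction step can be sketched with Python's built-in zipfile module. The helper name extract_archive and the file names in the usage comment are hypothetical, and the download call is left out because it needs your Kaggle credentials.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path, target_dir):
    """Extract every file in zip_path into target_dir and return the member names."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
        return archive.namelist()


# Example call, assuming the download produced a file named imdb.zip (hypothetical name):
# extract_archive("imdb.zip", "data")
```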
Next, import the standard data science packages and view a sample of the data.
Clean the text data
Let's process the data before loading it with TensorFlow. A standard step in text problems is to remove stop words. Stop words are common words, such as "the", that don't help the model identify the polarity of a sentence. NLTK provides a list of stop words. We can, therefore, write a function to remove them from the IMDB dataset.
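A minimal sketch of such a function is shown here. To keep it runnable without downloading anything, it uses a small hand-written stop-word set, which is an assumption for illustration; with NLTK and its stopwords corpus installed, nltk.corpus.stopwords.words("english") gives a much fuller list.

```python
# Tiny stand-in stop-word set (an assumption for illustration only).
STOP_WORDS = {"a", "an", "and", "is", "it", "of", "the", "this", "to", "was"}


def remove_stop_words(review, stop_words=STOP_WORDS):
    """Return the review with stop words removed (case-insensitive match)."""
    kept = [word for word in review.split() if word.lower() not in stop_words]
    return " ".join(kept)


cleaned = remove_stop_words("The plot of this movie was a delight")
print(cleaned)  # plot movie delight
```

On a pandas DataFrame, the same function could be applied to the text column with apply, assuming the column is named review.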
Label encode the sentiment column
Convert the sentiment column to numerical representation using Scikit-learn's label encoder. This is important because neural networks expect numerical data.
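As a small illustration, the snippet below encodes a hand-made list of sentiments rather than the real column. LabelEncoder sorts the class names, so negative maps to 0 and positive maps to 1.

```python
from sklearn.preprocessing import LabelEncoder

# Stand-in for the sentiment column (made-up values for illustration).
sentiments = ["positive", "negative", "positive", "negative"]

encoder = LabelEncoder()
labels = encoder.fit_transform(sentiments)

print(encoder.classes_.tolist())  # ['negative', 'positive']
print(labels.tolist())            # [1, 0, 1, 0]
```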
Text preprocessing with TensorFlow
We have converted the sentiment column to a numerical representation. However, the reviews are still in text form. We need to convert them to numbers as well.
We start by splitting the dataset into a training and testing set.
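The split might look like the sketch below, which uses Scikit-learn's train_test_split on made-up data. The 80/20 ratio and the random_state value are assumptions, not settings taken from the original tutorial.

```python
from sklearn.model_selection import train_test_split

# Made-up reviews and labels standing in for the IMDB columns.
reviews = [f"review {i}" for i in range(10)]
labels = [i % 2 for i in range(10)]

X_train, X_test, y_train, y_test = train_test_split(
    reviews, labels, test_size=0.2, random_state=0
)
print(len(X_train), len(X_test))  # 8 2
```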
Next, we use TensorFlow's TextVectorization layer to convert the text data to integer representations. The layer expects:
standardize: specifies how the text data is processed. For example, the lower_and_strip_punctuation option lowercases the text and removes punctuation.
max_tokens: sets the maximum size of the vocabulary.
output_mode: determines the output of the vectorization layer. Setting it to int outputs integer token indices.
output_sequence_length: sets the length of the output sequence. Shorter sequences are padded and longer ones truncated, so all sequences have the same length.
```python
import tensorflow as tf

max_features = 5000  # Maximum vocab size.
batch_size = 32
max_len = 512  # Sequence length to pad the outputs to.

vectorize_layer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=max_len,
)
vectorize_layer.adapt(X_train, batch_size=None)
```
Next, apply this layer to the training and testing data.
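Here is a self-contained sketch of that step on two made-up training sentences and one test sentence. The small max_tokens and output_sequence_length values are chosen only to keep the output readable.

```python
import numpy as np
import tensorflow as tf

# Made-up sentences standing in for the IMDB reviews.
train_texts = ["great movie loved it", "terrible plot bad acting"]
test_texts = ["great acting"]

vectorize_layer = tf.keras.layers.TextVectorization(
    standardize="lower_and_strip_punctuation",
    max_tokens=50,
    output_mode="int",
    output_sequence_length=6,
)
# The vocabulary is learned from the training text only.
vectorize_layer.adapt(tf.constant(train_texts))

# Calling the layer maps each sentence to a fixed-length row of token indices.
X_train_ids = np.asarray(vectorize_layer(tf.constant(train_texts)))
X_test_ids = np.asarray(vectorize_layer(tf.constant(test_texts)))
print(X_train_ids.shape, X_test_ids.shape)  # (2, 6) (1, 6)
```

Index 0 is reserved for padding and index 1 for out-of-vocabulary words, so the two-word test sentence becomes two vocabulary indices followed by four zeros.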
Convert the data to a TensorFlow dataset and create a function to fetch the data in batches. We also convert the data to NumPy arrays because JAX expects NumPy or JAX arrays.
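One way to sketch this is with tf.data.Dataset.from_tensor_slices and as_numpy_iterator. The arrays below are made-up stand-ins for the vectorized reviews, and the signature of get_train_batches (dataset, batch size, and seed as arguments) is an assumption rather than the original tutorial's definition.

```python
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# Stand-in data: 10 already-vectorized reviews of length 6 with binary labels.
X_train = np.arange(60, dtype=np.int32).reshape(10, 6)
y_train = np.array([0, 1] * 5, dtype=np.int32)

train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))


def get_train_batches(dataset, batch_size=4, seed=0):
    """Return an iterator over shuffled (inputs, labels) batches as NumPy arrays."""
    batched = dataset.shuffle(buffer_size=1024, seed=seed).batch(batch_size)
    return batched.as_numpy_iterator()


for inputs, labels in get_train_batches(train_ds):
    # NumPy batches convert directly to JAX arrays.
    inputs, labels = jnp.asarray(inputs), jnp.asarray(labels)
    print(inputs.shape, labels.shape)
```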
The data is now in the right format and ready to be passed to a Flax network. We have already seen how to build a network in Flax.
Let's quickly walk through the rest of the steps required to train neural networks in Flax using this data. Read the image classification with JAX and Flax article for more details about training models in Flax.
First, create a simple neural network in Flax.
Define a function to compute the loss.
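For binary sentiment labels, one reasonable choice is binary cross-entropy computed from raw logits. The sketch below assumes the network outputs a single logit per review; the original tutorial's exact loss may differ.

```python
import jax
import jax.numpy as jnp


def binary_cross_entropy(logits, labels):
    """Mean binary cross-entropy computed from raw logits."""
    labels = labels.astype(logits.dtype)
    log_p = jax.nn.log_sigmoid(logits)       # log P(label = 1)
    log_not_p = jax.nn.log_sigmoid(-logits)  # log P(label = 0)
    return -jnp.mean(labels * log_p + (1.0 - labels) * log_not_p)


loss = binary_cross_entropy(jnp.array([0.0, 0.0]), jnp.array([1, 0]))
print(float(loss))  # about 0.693, i.e. log(2)
```

A logit of zero corresponds to a probability of 0.5, so the loss equals log(2) regardless of the labels.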
Next, define the function to compute the network metrics.
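Accuracy is one such metric. The sketch below assumes a single logit per review, so a logit above zero counts as a positive prediction; the function name compute_metrics and its dictionary return format are illustrative choices.

```python
import jax.numpy as jnp


def compute_metrics(logits, labels):
    """Return a dictionary with the fraction of correctly classified examples."""
    predictions = (logits > 0).astype(jnp.int32)  # logit > 0 means P(positive) > 0.5
    return {"accuracy": jnp.mean(predictions == labels)}


logits = jnp.array([2.0, -1.0, 0.5, -3.0])
labels = jnp.array([1, 0, 0, 0])
metrics = compute_metrics(logits, labels)
print(float(metrics["accuracy"]))  # 0.75
```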
The training state is used to track the network training. It tracks the optimizer and model parameters and can be modified to track other things such as dropout and batch normalization statistics.
In the training step, we apply the model to obtain the loss. The loss is then used to compute the gradients that update the model parameters.
The evaluation step applies the model to the testing data to compute the test metrics.
The evaluation function runs the above evaluation step to obtain the evaluation metrics.
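The two evaluation pieces can be sketched as follows. A hand-written linear model stands in for the Flax network, and the names eval_step and evaluate_model as well as their signatures are assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def eval_step(params, inputs, labels):
    """Score one batch with a stand-in linear model; no parameters are updated."""
    logits = inputs @ params["w"] + params["b"]
    predictions = (logits > 0).astype(jnp.int32)
    return {"accuracy": jnp.mean(predictions == labels)}


def evaluate_model(params, batches):
    """Average per-batch metrics; assumes all batches have the same size."""
    batch_metrics = [jax.device_get(eval_step(params, x, y)) for x, y in batches]
    return {k: float(np.mean([m[k] for m in batch_metrics])) for k in batch_metrics[0]}


params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.0)}
batches = [
    (jnp.array([[2.0, 0.0], [0.0, 2.0]]), jnp.array([1, 0])),
    (jnp.array([[1.0, 0.0], [3.0, 0.0]]), jnp.array([1, 0])),
]
print(evaluate_model(params, batches))  # {'accuracy': 0.75}
```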
We use the get_train_batches function in the train_epoch method. We loop through the batches, applying the train_step method to each one. We then collect the training metrics and return them.
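The structure of that loop is sketched below with plain NumPy. Both get_train_batches and train_step are simplified stand-ins (the real ones wrap the TensorFlow dataset and the Flax update), so only the shape of train_epoch should be read as representative.

```python
import numpy as np


def get_train_batches(inputs, labels, batch_size):
    """Yield consecutive (inputs, labels) batches; a stand-in for the TensorFlow pipeline."""
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size], labels[start:start + batch_size]


def train_step(state, inputs, labels):
    """Stand-in step: counts updates and reports a dummy loss for the batch."""
    new_state = {"step": state["step"] + 1}
    metrics = {"loss": float(np.mean((inputs.sum(axis=1) - labels) ** 2))}
    return new_state, metrics


def train_epoch(state, inputs, labels, batch_size):
    """Apply train_step to every batch and return the state plus averaged metrics."""
    batch_metrics = []
    for batch_inputs, batch_labels in get_train_batches(inputs, labels, batch_size):
        state, metrics = train_step(state, batch_inputs, batch_labels)
        batch_metrics.append(metrics)
    epoch_metrics = {
        key: float(np.mean([m[key] for m in batch_metrics])) for key in batch_metrics[0]
    }
    return state, epoch_metrics


inputs = np.ones((10, 3), dtype=np.float32)
labels = np.full(10, 3.0, dtype=np.float32)
state, metrics = train_epoch({"step": 0}, inputs, labels, batch_size=4)
print(state["step"], metrics)  # 3 {'loss': 0.0}
```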
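The final step is to train the network on the training set and evaluate it on the test set. A training state must be created before training the model. JAX expects pure functions, so the parameters and optimizer state cannot live inside the model and are instead passed in and returned explicitly.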
How to load image data in JAX
Let's now see how we can load image data with TensorFlow. We'll use the popular cats and dogs images from Kaggle. We start by downloading the data.