Image classification with JAX & Flax

Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing in machine learning research. Its API is similar to NumPy's, which makes it easy to adopt. JAX also includes other features that are useful for machine learning research. They include:
- Automatic differentiation. JAX supports forward- and reverse-mode automatic differentiation of numerical functions with functions such as jacrev, grad, hessian and jacfwd.
- Vectorization. JAX supports automatic vectorization via the vmap function. It also makes it easy to parallelize large-scale data processing via the pmap function.
- JIT compilation. JAX uses XLA for Just In Time (JIT) compilation and execution of code on GPUs and TPUs.
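The three features listed above can be sketched in a few lines. This is a minimal illustration, not code from the original article; the function f and the input values are made up for the example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar-valued function: sum of squares.
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Automatic differentiation: the gradient of sum(x^2) is 2x.
g = jax.grad(f)(x)

# Vectorization: apply f to every row of a batch without a Python loop.
batch = jnp.stack([x, 2 * x])
batched = jax.vmap(f)(batch)

# JIT compilation: XLA compiles f the first time it is called.
fast_f = jax.jit(f)
y = fast_f(x)
```

The transformations compose, so something like jax.jit(jax.vmap(jax.grad(f))) is also valid.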
In this article, let's look at how you can use JAX and Flax to build a simple convolutional neural network.
Loading the dataset
We'll use the cats and dogs dataset from Kaggle. Let's start by downloading and extracting it.
Flax doesn't ship with any data loading tools. You can use the data loaders from PyTorch or TensorFlow. In this case, let's load the data using PyTorch. The first step is to define the dataset class.
Next, we create a Pandas DataFrame that will contain the categories.
Define a function that will stack the data and return it as NumPy arrays.
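A stacking function of this kind might look like the sketch below. It follows a pattern commonly used when feeding PyTorch data loaders into JAX; the name numpy_collate and the toy samples are assumptions for illustration, not the article's original code.

```python
import numpy as np


def numpy_collate(batch):
    """Stack a list of samples into NumPy arrays.

    Handles NumPy arrays, tuples/lists of samples (recursively), and scalars.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # Turn [(image, label), ...] into [images, labels].
        return [numpy_collate(samples) for samples in zip(*batch)]
    return np.array(batch)


# Toy samples standing in for (image, label) pairs from a dataset.
samples = [
    (np.zeros((4, 4, 3), dtype=np.float32), 0),
    (np.ones((4, 4, 3), dtype=np.float32), 1),
]
images, labels = numpy_collate(samples)
```

A function like this can be passed to the PyTorch DataLoader through its collate_fn argument so that batches arrive as NumPy arrays rather than PyTorch tensors.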
We are now ready to define the training and test data and use that with the PyTorch DataLoader. We also define a PyTorch transformation for resizing the images.
Define a convolutional neural network with Flax
Install Flax to create a simple neural network.
pip install flax
Networks are created in Flax using the Linen API by subclassing Module. All Flax modules are Python dataclasses. This means that they get a generated __init__ by default. You should, therefore, override setup() instead to initialize the network. However, you can use the nn.compact decorator to make the model definition more concise.
Define loss
The loss can be computed using the Optax package. We one-hot encode the integer labels before passing them to the softmax cross-entropy function. num_classes is 2 because we are dealing with two classes.
Compute metrics
Next, we define a function that will use the above loss function to compute and return the loss. We also compute the accuracy in the same function.
Create training state
A training state holds the model variables such as parameters and optimizer state. These variables are modified at each iteration using the optimizer. You can subclass flax.training.train_state.TrainState to track more data. You might want to do that for tracking the state of dropout and batch statistics if you include those layers in your model. For this simple model, the default class will suffice.

Define training step
In this function, we evaluate the model on a batch of input images using the apply method and use the resulting logits to compute the loss. We use value_and_grad to evaluate the loss function and its gradient with respect to the parameters. The gradients are then used to update the model parameters. Finally, the function uses the compute_metrics function defined above to calculate the loss and accuracy.
The function is decorated with @jax.jit, which traces the function and compiles it just-in-time for faster computation.

Define evaluation step
The evaluation function will use apply to evaluate the model on the test data.

Evaluate the model
The evaluation function runs the evaluation step and returns the test metrics.
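One way to write such a function is to loop over the test batches and average the per-batch metrics, as in this sketch. The signature is an assumption: here the evaluation step is passed in as the hypothetical argument eval_step_fn so the block runs on its own, and the simple mean assumes all batches have the same size.

```python
import jax
import numpy as np


def eval_model(eval_step_fn, params, test_batches):
    """Run eval_step_fn on every batch and average the resulting metrics."""
    batch_metrics = []
    for images, labels in test_batches:
        metrics = eval_step_fn(params, images, labels)
        # Copy the metrics from the accelerator to host memory.
        batch_metrics.append(jax.device_get(metrics))
    return {
        name: float(np.mean([m[name] for m in batch_metrics]))
        for name in batch_metrics[0]
    }
```

If the last batch is smaller than the others, weighting each batch by its size gives a more accurate average.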
Train and evaluate the model
We need to initialize the train state before training the model. The function that initializes the state requires a pseudo-random number generator (PRNG) key. Use the PRNGKey function to obtain a key and split it to get another key that you'll use for parameter initialization. See the JAX PRNG design notes to learn more.
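Creating and splitting a key can be sketched as follows. The seed value 0 and the variable names are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Create a root key from an integer seed.
rng = jax.random.PRNGKey(0)

# Split it: keep one key for later use, use the other for initialization.
rng, init_rng = jax.random.split(rng)

# The same key always produces the same random values.
a = jax.random.normal(init_rng, (3,))
b = jax.random.normal(init_rng, (3,))
```

JAX keys are explicit values rather than hidden global state, which is why a key is split whenever fresh randomness is needed.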

Pass this key to the create_train_state function together with the learning rate and momentum. You can now train the model and use the eval_model function to evaluate the model.
Model performance
While the training is happening, we print the training and validation metrics. You can also use the resulting metrics to plot training and validation charts.


Final thoughts
In this article, we have seen how to set up a simple neural network with Flax and train it on the CPU. The next article will focus on how to train the network using TPUs and GPUs. Subscribe to have the issue delivered to your inbox.
Check out the Image Augmentation in JAX tutorial to learn how to augment the images and possibly improve the model's performance.