Image classification with JAX & Flax

Derrick Mwiti

Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing that is widely used in machine learning research. Its API closely mirrors NumPy, which makes it easy to adopt. JAX also offers several features that are useful for machine learning research, including:

  • Automatic differentiation. JAX supports forward- and reverse-mode automatic differentiation of numerical functions through functions such as jacrev, grad, hessian, and jacfwd.
  • Vectorization. JAX supports automatic vectorization via the vmap function. It also makes it easy to parallelize computation across multiple devices via the pmap function.
  • JIT compilation. JAX uses XLA for just-in-time (JIT) compilation and execution of code on CPUs, GPUs, and TPUs.
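As a quick illustration of these three features, here is a minimal sketch. The function f and the input values are made up for the example and are not part of the original walkthrough.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar-valued function: sum of squares.
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Reverse-mode gradient of f. Analytically, d/dx sum(x^2) = 2x.
grad_f = jax.grad(f)(x)

# Hessian of f. Analytically, this is 2 times the identity matrix.
hess_f = jax.hessian(f)(x)

# vmap applies f to every row of a batch without a Python loop.
batch = jnp.stack([x, 2.0 * x])
batched_f = jax.vmap(f)(batch)

# jit traces f once and compiles it with XLA; later calls reuse the compiled code.
fast_f = jax.jit(f)
result = fast_f(x)
```

The transformations compose, so something like jax.jit(jax.vmap(jax.grad(f))) is also valid.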

In this article, let's look at how you can use JAX and Flax to build a simple convolutional neural network.  

Loading the dataset

We'll use the cats and dogs dataset from Kaggle. Let's start by downloading and extracting it.

Flax doesn't ship with any data loading tools. You can use the data loaders from PyTorch or TensorFlow. In this case, let's load the data using PyTorch. The first step is to define the dataset class.

Next, we create a Pandas DataFrame that will contain the categories.

Define a function that stacks the samples in a batch and returns them as NumPy arrays.
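A collate function along these lines could look like the sketch below. The name numpy_collate is hypothetical, and the sketch assumes each sample is an (image, label) pair whose image is already an array in height x width x channels layout.

```python
import numpy as np


def numpy_collate(batch):
    """Stack a list of (image, label) samples into two NumPy arrays."""
    images, labels = zip(*batch)
    # Stack images along a new leading batch axis: (N, H, W, C).
    images = np.stack([np.asarray(img) for img in images])
    labels = np.asarray(labels)
    return images, labels


# Illustrative batch of four fake 8x8 RGB images with alternating labels.
fake_batch = [(np.zeros((8, 8, 3), dtype=np.float32), i % 2) for i in range(4)]
imgs, lbls = numpy_collate(fake_batch)
```

A function with this shape can be passed to the PyTorch DataLoader through its collate_fn argument, so that batches arrive as NumPy arrays instead of PyTorch tensors.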

We are now ready to define the training and test data and use that with the PyTorch DataLoader. We also define a PyTorch transformation for resizing the images.

Define a Convolutional Neural Network with Flax

Install Flax to create a simple neural network.

pip install flax

Networks are created in Flax using the Linen API by subclassing Module. All Flax modules are Python dataclasses, which means __init__ is generated for them automatically. You should, therefore, override setup() rather than __init__ to initialize the network. Alternatively, you can use the compact decorator (nn.compact) on __call__ to declare layers inline, which makes the model definition more concise.

Define loss

The loss can be computed using the Optax package. We one-hot encode the integer labels before passing them to the softmax cross-entropy function. num_classes is 2 because we are dealing with two classes.

Compute metrics

Next, we define a function that will use the above loss function to compute and return the loss. We also compute the accuracy in the same function.  

Create training state

A training state holds the model variables, such as the parameters and the optimizer state. These variables are updated at each iteration by the optimizer. You can subclass TrainState to track more data, for example the dropout state or batch statistics if you include those layers in your model. For this simple model, the default class will suffice.

Managing Parameters and State — Flax documentation

Read more: Optimizers in JAX and Flax

Define training step

In this function, we evaluate the model on a batch of input images using the apply method and use the resulting logits to compute the loss. We use value_and_grad to evaluate the loss function and its gradient with respect to the parameters in a single pass. The gradients are then used to update the model parameters. Finally, we use the compute_metrics function defined above to calculate the loss and accuracy.

The function is decorated with jax.jit, which traces the function and compiles it just in time for faster computation.

jax.value_and_grad — JAX documentation

Define evaluation step

The evaluation function uses apply to evaluate the model on the test data. Unlike the training step, it computes no gradients and leaves the parameters unchanged.

flax.linen package — Flax documentation

Training function

In this function, we apply the training step we defined above. We loop through each batch in the data loader and perform one optimization step per batch. We then use jax.device_get to copy the metrics from the device to host memory and compute their mean over the epoch.
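The epoch loop might be sketched like this. Passing train_step in as an argument is an assumption made so the snippet can run on its own, and fake_train_step is a made-up stand-in that only mimics the return shape of a real training step.

```python
import jax
import jax.numpy as jnp
import numpy as np


def train_one_epoch(state, dataloader, train_step):
    """Run one pass over the data and return the mean metrics."""
    batch_metrics = []
    for images, labels in dataloader:
        state, metrics = train_step(state, images, labels)
        batch_metrics.append(metrics)
    # Copy all metrics from the accelerator to host memory as NumPy values.
    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        key: float(np.mean([m[key] for m in batch_metrics]))
        for key in batch_metrics[0]
    }
    return state, epoch_metrics


def fake_train_step(state, images, labels):
    # Stand-in: counts steps and reports dummy metrics derived from the batch.
    return state + 1, {"loss": jnp.mean(images), "accuracy": jnp.mean(labels)}


fake_batches = [
    (jnp.full((2, 4), 1.0), jnp.array([1.0, 0.0])),
    (jnp.full((2, 4), 3.0), jnp.array([1.0, 1.0])),
]
final_state, epoch_metrics = train_one_epoch(0, fake_batches, fake_train_step)
```

Collecting the metrics first and calling jax.device_get once at the end avoids forcing a device-to-host transfer on every batch.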

jax.device_get — JAX documentation

Evaluate the model

The evaluation function runs the evaluation step and returns the test metrics.

Train and evaluate the model

We need to initialize the train state before training the model. The function that initializes the state requires a pseudo-random number generator (PRNG) key. Use the PRNGKey function to obtain a key, then split it to get a second key that you'll use for parameter initialization. Follow this link to learn more about JAX PRNG Design.

JAX PRNG Design — JAX documentation
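Creating and splitting a key looks roughly like this. The seed value 0 and the variable names are arbitrary choices for the example.

```python
import jax

# Create a root key from an integer seed.
rng = jax.random.PRNGKey(0)

# Split it: keep one key for later use and dedicate one to initialization.
rng, init_rng = jax.random.split(rng)

# Keys are explicit inputs, so the same key always yields the same numbers.
sample_a = jax.random.normal(init_rng, (3,))
sample_b = jax.random.normal(init_rng, (3,))
```

Since JAX has no hidden global random state, reusing a key reproduces the same values. Split the key whenever you need fresh randomness.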

Pass this key to the create_train_state function together with the learning rate and momentum. You can now use the train_one_epoch function to train the model and the eval_model function to evaluate the model.  
