
Image classification with JAX & Flax

Derrick Mwiti


Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing in machine learning research. It provides a NumPy-like API, which makes it easy to adopt, and adds several capabilities that are useful for machine learning research. They include:

  • Automatic differentiation. JAX supports forward- and reverse-mode automatic differentiation of numerical functions through functions such as grad, hessian, jacfwd and jacrev.
  • Vectorization. JAX supports automatic vectorization via the vmap function and makes it easy to parallelize computation across multiple devices via the pmap function.
  • JIT compilation. JAX uses XLA for Just In Time (JIT) compilation and execution of code on GPUs and TPUs.

In this article, let's look at how you can use JAX and Flax to build a simple convolutional neural network.


Loading the dataset

We'll use the cats and dogs dataset from Kaggle. Let's start by downloading and extracting it.

Flax doesn't ship with any data loading tools. You can use the data loaders from PyTorch or TensorFlow. In this case, let's load the data using PyTorch. The first step is to define the dataset class.
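Here is a minimal sketch of such a dataset class. The class name CatsDogsDataset, the annotation-file layout, and the folder paths are illustrative assumptions, not the exact code from the original notebook.

import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class CatsDogsDataset(Dataset):
    """Reads images listed in a CSV annotation file and returns (image, label) pairs."""

    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # Column 0 holds the file name, column 1 the integer label.
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        label = int(self.annotations.iloc[index, 1])
        if self.transform is not None:
            img = self.transform(img)
        return img, label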

Next, we create a Pandas DataFrame that will contain the categories.
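A sketch of that step is shown below. The extraction folder name and the 80/20 split are assumptions; labels are inferred from the file-name prefix (0 for cats, 1 for dogs), as in the Kaggle dataset.

import os

import pandas as pd

data_dir = "train"  # assumed extraction folder for the Kaggle archive
files = os.listdir(data_dir)
labels = [0 if name.startswith("cat") else 1 for name in files]
df = pd.DataFrame({"filename": files, "label": labels})

# Hold out a small test split and save both splits as annotation files.
test_df = df.sample(frac=0.2, random_state=42)
train_df = df.drop(test_df.index)
train_df.to_csv("train.csv", index=False)
test_df.to_csv("test.csv", index=False)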

Define a function that will stack the data and return it as NumPy arrays.
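A possible collate function looks like the following; the name custom_collate_fn is illustrative.

import numpy as np

def custom_collate_fn(batch):
    """Stack a batch of (PIL image, label) pairs into NumPy arrays for JAX."""
    images, labels = zip(*batch)
    # Channels-last float arrays scaled to [0, 1].
    images = np.stack([np.asarray(img, dtype=np.float32) / 255.0 for img in images])
    labels = np.asarray(labels, dtype=np.int32)
    return images, labels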

We are now ready to define the training and test data and use that with the PyTorch DataLoader. We also define a PyTorch transformation for resizing the images.
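Putting the pieces together, a sketch of the data pipeline could look like this, assuming the CatsDogsDataset class and custom_collate_fn above and a 224×224 resize; the batch size is an illustrative choice.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.Resize((224, 224))])

# Both splits read from the same extracted folder, using the CSV files created above.
train_dataset = CatsDogsDataset("train", "train.csv", transform=transform)
test_dataset = CatsDogsDataset("train", "test.csv", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=custom_collate_fn)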

Define Convolution Neural Network with Flax

Install Flax to create a simple neural network.

pip install flax

Networks are created in Flax using the Linen API by subclassing Module. All Flax modules are Python dataclasses, which means they already define their own __init__. You should therefore override setup() to initialize the network instead. Alternatively, you can use the @nn.compact decorator to define the model inline and keep the definition more concise.
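Here is a minimal CNN sketch in the compact style; the layer widths and kernel sizes are illustrative choices, not necessarily the exact architecture from the original notebook.

from flax import linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)  # two classes: cat and dog
        return x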

Define loss

The loss can be computed using the Optax package. We one-hot encode the integer labels before passing them to the softmax cross-entropy function. num_classes is 2 because we are dealing with two classes.
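A sketch of the loss function, with the helper name cross_entropy_loss chosen for illustration:

import jax
import optax

def cross_entropy_loss(logits, labels):
    one_hot = jax.nn.one_hot(labels, num_classes=2)
    return optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()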

Compute metrics

Next, we define a function that will use the above loss function to compute and return the loss. We also compute the accuracy in the same function.
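A minimal version of such a helper:

import jax.numpy as jnp

def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {"loss": loss, "accuracy": accuracy}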

Create training state

A training state holds the model variables, such as the parameters and the optimizer state. These variables are updated at each iteration by the optimizer. You can subclass flax.training.train_state.TrainState to track additional data; you might want to do that to track dropout state or batch statistics if your model includes those layers. For this simple model, the default class suffices.

Managing Parameters and State — Flax documentation

Read more: Optimizers in JAX and Flax
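A sketch of create_train_state, assuming SGD with momentum as the optimizer and 224×224 RGB inputs:

import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng, learning_rate, momentum):
    cnn = CNN()
    # Initialize parameters with a dummy batch of the expected input shape.
    params = cnn.init(rng, jnp.ones([1, 224, 224, 3]))["params"]
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)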

Define training step

In this function, we run the model on a batch of input images using the apply method. We use the resulting logits to compute the loss, and value_and_grad to evaluate both the loss function and its gradients. The gradients are then used to update the model parameters. Finally, we use the compute_metrics function defined above to calculate the loss and accuracy.

The function is decorated with the @jax.jit decorator to trace it and compile it just-in-time for faster computation.

jax.value_and_grad — JAX documentation
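A sketch of the training step, using the helpers defined above (the name train_step is illustrative):

import jax

@jax.jit
def train_step(state, images, labels):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        return cross_entropy_loss(logits, labels), logits

    # Return both the loss (with the logits as auxiliary output) and the gradients.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels)
    return state, metrics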

Define evaluation step

The evaluation function will use apply to evaluate the model on the test data.

flax.linen package — Flax documentation
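A corresponding evaluation step might look like this (eval_step is an illustrative name):

import jax

@jax.jit
def eval_step(params, images, labels):
    logits = CNN().apply({"params": params}, images)
    return compute_metrics(logits, labels)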

Evaluate the model

The evaluation function runs the evaluation step and returns the test metrics.
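For example, a small wrapper such as the following, which pulls the metrics back to the host:

import jax

def eval_model(params, images, labels):
    metrics = eval_step(params, images, labels)
    metrics = jax.device_get(metrics)
    return metrics["loss"], metrics["accuracy"]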

Train and evaluate the model

We need to initialize the train state before training the model. The function that initializes the state requires a pseudo-random number generator (PRNG) key. Use the jax.random.PRNGKey function to obtain a key and split it to get a second key that you'll use for parameter initialization. Follow this link to learn more about JAX PRNG design.

JAX PRNG Design — JAX documentation

Pass this key to the create_train_state function together with the learning rate and momentum. You can now train the model and use the eval_model function to evaluate the model.
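A sketch of the full loop is shown below; the learning rate, momentum, and number of epochs are illustrative, and the loaders and helpers are the ones sketched earlier.

import jax
import numpy as np

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate=0.01, momentum=0.9)

num_epochs = 10
history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

for epoch in range(1, num_epochs + 1):
    batch_metrics = []
    for images, labels in train_loader:
        state, metrics = train_step(state, images, labels)
        batch_metrics.append(jax.device_get(metrics))

    train_loss = np.mean([m["loss"] for m in batch_metrics])
    train_accuracy = np.mean([m["accuracy"] for m in batch_metrics])

    # Evaluate on a batch from the test loader.
    test_images, test_labels = next(iter(test_loader))
    test_loss, test_accuracy = eval_model(state.params, test_images, test_labels)

    history["train_loss"].append(train_loss)
    history["train_accuracy"].append(train_accuracy)
    history["test_loss"].append(float(test_loss))
    history["test_accuracy"].append(float(test_accuracy))

    print(f"epoch {epoch}: train loss {train_loss:.4f}, train accuracy {train_accuracy:.4f}, "
          f"test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")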

Model performance

As training runs, we print the training and validation metrics for each epoch. You can also use the collected metrics to plot training and validation curves.
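For instance, assuming the per-epoch metrics were collected into the history dictionary from the loop above, the loss curves can be plotted with Matplotlib:

import matplotlib.pyplot as plt

plt.plot(history["train_loss"], label="training loss")
plt.plot(history["test_loss"], label="validation loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()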


Final thoughts

In this article, we have seen how to set up a simple neural network with Flax and train it on the CPU. The next article will focus on how to train the network using GPUs and TPUs. Subscribe to have the issue delivered to your inbox.

Check out the Image Augmentation in JAX tutorial to learn how to augment the images and possibly improve the model's performance.

Kaggle Notebook

Check out more JAX articles.
