# Image classification with JAX & Flax

Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing in machine learning research. It provides an API similar to NumPy, which makes it easy to adopt. JAX also includes other functionality that helps with machine learning research, including:

- **Automatic differentiation**. JAX supports forward- and reverse-mode automatic differentiation of numerical functions through functions such as `grad`, `jacfwd`, `jacrev`, and `hessian`.
- **Vectorization**. JAX supports automatic vectorization via the `vmap` function. It also makes it easy to parallelize large-scale data processing across devices via the `pmap` function.
- **JIT compilation**. JAX uses XLA for just-in-time (JIT) compilation and execution of code on GPUs and TPUs.
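The following is a small illustrative sketch of these transformations applied to a toy function. The function `f` and the arrays are made up for the example. `pmap` is left out because it only pays off when several devices are available.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued toy function: f(x) = sum(x ** 3).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])

# Automatic differentiation: reverse mode (grad) and forward mode (jacfwd).
g = jax.grad(f)(x)          # 3 * x**2, i.e. [3., 12., 27.]
g_fwd = jax.jacfwd(f)(x)    # the same gradient, computed in forward mode
h = jax.hessian(f)(x)       # diagonal matrix with 6 * x on the diagonal

# Vectorization: map a per-example dot product over a batch without a Python loop.
batched_dot = jax.vmap(jnp.dot)
a = jnp.ones((4, 3))
b = jnp.arange(12.0).reshape(4, 3)
d = batched_dot(a, b)       # row sums of b, i.e. [3., 12., 21., 30.]

# JIT compilation: trace once, compile with XLA, then reuse the compiled function.
fast_grad = jax.jit(jax.grad(f))
g_jit = fast_grad(x)
```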

In this article, let's look at how you can use JAX and Flax to build a simple convolutional neural network.

### Loading the dataset

We'll use the cats and dogs dataset from Kaggle. Let's start by downloading and extracting it.

Flax doesn't ship with any data loading tools. You can use the data loaders from PyTorch or TensorFlow. In this case, let's load the data using PyTorch. The first step is to define the dataset class.
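Here is a minimal sketch of such a dataset class. It assumes a pandas DataFrame with hypothetical `filename` and `category` columns and loads images with Pillow. The class name `CatsDogsDataset` is illustrative. The class returns NumPy arrays instead of PyTorch tensors because the batches end up in JAX.

```python
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class CatsDogsDataset(Dataset):
    """Hypothetical dataset: reads image paths and integer labels from a DataFrame.

    Assumes a "filename" column (path to an image) and a "category" column
    (0 for cat, 1 for dog). `transform` should return a PIL image (e.g. a resize).
    """

    def __init__(self, df: pd.DataFrame, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row["filename"]).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        # (height, width, channels) float32 array scaled to [0, 1].
        image = np.asarray(image, dtype=np.float32) / 255.0
        return image, int(row["category"])
```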

Next, we create a Pandas DataFrame that will contain the categories.
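The next snippet shows one way to build that DataFrame. It assumes the Kaggle training images sit in a single directory and are named like `cat.0.jpg` and `dog.0.jpg`. The helper `build_dataframe`, the column names, and the mapping of dogs to 1 and cats to 0 are all assumptions made for illustration.

```python
import os
import pandas as pd

def build_dataframe(image_dir: str) -> pd.DataFrame:
    """Hypothetical helper: list .jpg files and derive the category from the name."""
    filenames = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
    # "dog.*" files get label 1, everything else ("cat.*") gets label 0.
    categories = [1 if f.startswith("dog") else 0 for f in filenames]
    return pd.DataFrame({
        "filename": [os.path.join(image_dir, f) for f in filenames],
        "category": categories,
    })
```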

Define a function that will stack the data and return it as NumPy arrays.
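Such a function can be passed to the PyTorch `DataLoader` as its `collate_fn`. The name `custom_collate_fn` below is hypothetical. It stacks a list of `(image, label)` pairs into NumPy arrays instead of PyTorch's default tensors.

```python
import numpy as np

def custom_collate_fn(batch):
    """Hypothetical collate function: (image, label) pairs -> stacked NumPy arrays."""
    images, labels = zip(*batch)
    return np.stack(images), np.array(labels, dtype=np.int32)

# Two 4x4 RGB "images" become one (2, 4, 4, 3) array plus a label vector.
batch = [(np.zeros((4, 4, 3), np.float32), 0), (np.ones((4, 4, 3), np.float32), 1)]
images, labels = custom_collate_fn(batch)
```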

We are now ready to define the training and test data and use them with the PyTorch `DataLoader`. We also define a PyTorch transformation for resizing the images.

### Define a convolutional neural network with Flax

Install Flax to create a simple neural network.

```
pip install flax
```

Networks are created in Flax using the Linen API by subclassing `Module`. All Flax modules are Python dataclasses, which means they get an auto-generated `__init__`. Instead of overriding `__init__`, you should therefore define submodules in `setup()`. Alternatively, you can decorate `__call__` with the `@nn.compact` wrapper to declare submodules inline, which makes the model definition more concise.

### Define loss

The loss can be computed using the Optax package. We one-hot encode the integer labels before passing them to the softmax cross-entropy function. `num_classes` is 2 because we are dealing with two classes.

### Compute metrics

Next, we define a function that uses the loss function above to compute the loss. The same function also computes the accuracy.

### Create training state

A training state holds the model variables, such as the parameters and the optimizer state. The optimizer updates these variables at each iteration. You can subclass `flax.training.train_state.TrainState` to track more data. For example, you might store batch statistics if your model uses batch normalization, or a PRNG key if it uses dropout. For this simple model, the default class will suffice.


### Define training step

In this function, we evaluate the model on a batch of input images using the `apply` method and use the resulting logits to compute the loss. We wrap this computation in a loss function and use `jax.value_and_grad` to evaluate it together with its gradient with respect to the model parameters. The gradients are then used to update the model parameters. Finally, the function uses the `compute_metrics` function defined above to calculate the loss and accuracy.

The function is decorated with `@jax.jit`, so it is traced and compiled just-in-time with XLA for faster computation.

### Define evaluation step

The evaluation function uses `apply` to evaluate the model on the test data.

### Training function

In this function, we apply the training step defined above. We loop through each batch in the data loader and perform an optimization step on each batch. We then use `jax.device_get` to transfer the metrics from the device to the host and compute their mean.

### Evaluate the model

The evaluation function runs the evaluation step and returns the test metrics.

### Train and evaluate the model

We need to initialize the train state before training the model. The function that initializes the state requires a pseudo-random number generator (PRNG) key. Use the `jax.random.PRNGKey` function to obtain a key, then split it to get another key that you'll use for parameter initialization. See the JAX documentation on its PRNG design to learn more.

Pass this key to the `create_train_state` function together with the learning rate and momentum. You can now use the `train_one_epoch` function to train the model and the `eval_model` function to evaluate it.
