Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing that is widely used in machine learning research. Its API is similar to NumPy's, which makes it easy to adopt. JAX also includes other features that are useful for machine learning research, including:
- Automatic differentiation. JAX supports forward- and reverse-mode automatic differentiation of numerical functions through functions such as grad, jacrev, jacfwd and hessian.
- Vectorization. JAX supports automatic vectorization via the vmap function. It also makes it easy to parallelize large-scale data processing via the pmap function.
- JIT compilation. JAX uses XLA for Just In Time (JIT) compilation and execution of code on GPUs and TPUs.
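The three features above are all function transformations: each takes a Python function and returns a new one. The minimal sketch below applies them to a toy sum-of-squares function chosen only for illustration (pmap is left out because it needs multiple devices to be interesting).

```python
import jax
import jax.numpy as jnp


def f(x):
    """A scalar-valued function of a vector: the sum of squares."""
    return jnp.sum(x ** 2)


# Automatic differentiation: the gradient of sum(x**2) is 2x and its Hessian is 2I.
grad_f = jax.grad(f)        # reverse-mode gradient
jac_fwd = jax.jacfwd(f)     # forward-mode Jacobian (equal to the gradient for a scalar f)
hess_f = jax.hessian(f)     # matrix of second derivatives

# Vectorization: vmap maps jnp.dot over the leading (batch) axis of both inputs.
batched_dot = jax.vmap(jnp.dot)

# JIT compilation: f is traced once per input shape and compiled with XLA.
fast_f = jax.jit(f)

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))   # [2. 4. 6.]
print(fast_f(x))   # 14.0
```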
In this article, we'll look at how to use JAX and Flax to build a simple convolutional neural network.
Loading the dataset
We'll use the cats and dogs dataset from Kaggle. Let's start by downloading and extracting it.
Next, we create a Pandas DataFrame that will contain the categories.
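A minimal sketch of such a DataFrame, assuming the Kaggle training images follow the `cat.0.jpg` / `dog.0.jpg` naming scheme and that cats are labelled 0 and dogs 1. The `build_dataframe` helper and the `filename`/`category` column names are hypothetical.

```python
import pandas as pd

# Hypothetical label mapping: the file-name prefix ("cat" or "dog") gives the class.
LABELS = {"cat": 0, "dog": 1}


def build_dataframe(filenames):
    """Return a DataFrame with one row per image file and an integer category."""
    categories = [LABELS[name.split(".")[0]] for name in filenames]
    return pd.DataFrame({"filename": filenames, "category": categories})


df = build_dataframe(["cat.0.jpg", "dog.0.jpg", "cat.1.jpg"])
print(df["category"].tolist())  # [0, 1, 0]
```

In practice the file names would come from listing the extracted training directory, for example with `os.listdir`.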
Define a function that will stack the data and return it as NumPy arrays.
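By default, a PyTorch DataLoader collates batches into torch tensors. A common pattern when feeding JAX is to pass a collate function that stacks samples into NumPy arrays instead. The sketch below assumes each sample is an `(image, label)` pair with the image already a NumPy array; the name `numpy_collate` is an assumption, not necessarily the article's.

```python
import numpy as np


def numpy_collate(batch):
    """Recursively stack a list of samples into NumPy arrays."""
    if isinstance(batch[0], np.ndarray):
        # A list of arrays becomes one array with a new leading batch axis.
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # A list of (image, label) pairs is transposed into (images, labels)
        # and each field is collated separately.
        return [numpy_collate(samples) for samples in zip(*batch)]
    # Scalars such as integer labels become a 1-D array.
    return np.array(batch)
```

It would be passed to the loader as `DataLoader(dataset, batch_size=..., collate_fn=numpy_collate)`.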
We are now ready to define the training and test data and use them with the PyTorch DataLoader. We also define a PyTorch transformation for resizing the images.
Define a Convolutional Neural Network with Flax
Install Flax to create a simple neural network.
pip install flax
Networks are created in Flax using the Linen API by subclassing Module. All Flax modules are Python dataclasses, which means they get an __init__ method by default. You should, therefore, override setup() instead of __init__ to initialize the network. Alternatively, you can use the @nn.compact decorator to define submodules inline in __call__, which makes the model definition more concise.
The loss can be computed using the Optax package. We one-hot encode the integer labels before passing them to the softmax cross-entropy function.
num_classes is 2 because we are dealing with two classes.
Next, we define a function that will use the above loss function to compute and return the loss. We also compute the accuracy in the same function.
Create training state
A training state holds the model variables, such as the parameters and the optimizer state. The optimizer updates these variables at each iteration. You can subclass flax.training.train_state.TrainState to track more data, for example the batch statistics of batch normalization layers or the random keys used by dropout layers, if your model includes them. For this simple model, the default class is sufficient.
Define training step
In this function, we evaluate the model on a batch of input images using the apply method and use the resulting logits to compute the loss. We use value_and_grad to evaluate the loss function and its gradient, and the gradients are then used to update the model parameters. Finally, the function uses the compute_metrics function defined above to calculate the loss and accuracy.
The function is decorated with @jax.jit, which traces the function and compiles it just in time for faster execution.
Define evaluation step
The evaluation function uses apply to evaluate the model on the test data.
Define training function
The train_one_epoch function applies the training step defined above. It loops through each batch in the data loader and performs an optimization step for each batch. It then uses jax.device_get to fetch the per-batch metrics from the device and computes their mean.
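A minimal sketch of `train_one_epoch`. To keep it independent of any particular step function, this hypothetical signature takes the jitted training step as a `step_fn` argument that maps `(state, (images, labels))` to `(state, metrics)`.

```python
import jax
import numpy as np


def train_one_epoch(state, dataloader, step_fn):
    """Run step_fn on every batch and return the new state and mean metrics."""
    batch_metrics = []
    for images, labels in dataloader:
        state, metrics = step_fn(state, (images, labels))
        batch_metrics.append(metrics)
    # Copy the per-batch metrics from the accelerator to host memory in one call.
    batch_metrics = jax.device_get(batch_metrics)
    # Unweighted mean over batches (a smaller final batch counts the same).
    epoch_metrics = {
        k: float(np.mean([m[k] for m in batch_metrics])) for k in batch_metrics[0]
    }
    return state, epoch_metrics
```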
Evaluate the model
The eval_model function runs the evaluation step on the test data and returns the test metrics.
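A sketch of `eval_model`, assuming it averages the evaluation step's metrics over all test batches. As with the training loop, the hypothetical `step_fn` argument stands for a jitted step mapping `(state, (images, labels))` to a metrics dict.

```python
import jax
import numpy as np


def eval_model(state, dataloader, step_fn):
    """Evaluate every test batch and return the mean loss and accuracy."""
    batch_metrics = [step_fn(state, (images, labels)) for images, labels in dataloader]
    # Bring the metrics back to the host before averaging with NumPy.
    batch_metrics = jax.device_get(batch_metrics)
    return {k: float(np.mean([m[k] for m in batch_metrics])) for k in batch_metrics[0]}
```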
Train and evaluate the model
We need to initialize the train state before training the model. The function that initializes the state requires a pseudo-random number generator (PRNG) key. Use the jax.random.PRNGKey function to obtain a key and split it to get another key that you'll use for parameter initialization. See the JAX documentation on its PRNG design to learn more.
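JAX keys are explicit values rather than hidden global state, so a key is split into fresh keys instead of being reused. A minimal sketch, assuming a seed of 0:

```python
import jax

rng = jax.random.PRNGKey(0)
# split returns two new keys: keep `rng` for later splits and consume
# `init_rng` once, for example as the key passed to model.init.
rng, init_rng = jax.random.split(rng)
```

Splitting the same seed always yields the same keys, which makes runs reproducible.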
Pass this key to the create_train_state function together with the learning rate and momentum. You can now use the train_one_epoch function to train the model and the eval_model function to evaluate it.
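The loop that alternates training and evaluation for several epochs can be sketched as a small driver. The `fit` helper and its arguments are hypothetical: it assumes `train_epoch_fn(state, loader)` returns `(state, metrics)` and `eval_fn(state, loader)` returns `metrics`, so thin wrappers around train_one_epoch and eval_model fit these shapes.

```python
def fit(state, train_loader, test_loader, train_epoch_fn, eval_fn, num_epochs):
    """Alternate one training epoch and one evaluation pass, num_epochs times."""
    history = []
    for epoch in range(1, num_epochs + 1):
        state, train_metrics = train_epoch_fn(state, train_loader)
        # Evaluate with the parameters reached at the end of this epoch.
        test_metrics = eval_fn(state, test_loader)
        print(
            f"epoch {epoch}: "
            f"train loss {train_metrics['loss']:.4f}, train acc {train_metrics['accuracy']:.4f}, "
            f"test loss {test_metrics['loss']:.4f}, test acc {test_metrics['accuracy']:.4f}"
        )
        history.append({"epoch": epoch, "train": train_metrics, "test": test_metrics})
    return state, history
```

Keeping the per-epoch metrics in `history` makes it easy to plot learning curves afterwards.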