Handling state in JAX & Flax (BatchNorm and DropOut layers)

Derrick Mwiti

JIT-compiling functions with jax.jit makes them faster but requires that the functions have no side effects. This constraint is a challenge when dealing with state, such as model parameters, and with stateful layers, such as batch normalization layers. In this article, we'll create a network with the BatchNorm and DropOut layers. After that, we'll see how to supply the random number generator key that the DropOut layer needs and how to handle the batch statistics when training the network.
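
To make the no-side-effects constraint concrete, here is a minimal sketch that is not part of the original article. The function name noisy_running_mean and the 0.9 / 0.1 weighting are illustrative assumptions. The point is only that state and randomness enter as arguments and leave as return values, which is the same pattern the rest of the article applies to batch statistics and the DropOut key.

```python
import jax
import jax.numpy as jnp

@jax.jit
def noisy_running_mean(running_mean, x, key):
    """Pure function: state and randomness are explicit inputs and outputs."""
    # Split the key so the caller gets a fresh key back for the next call.
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, x.shape)
    # Hypothetical update rule, used only to have some state to carry around.
    new_mean = 0.9 * running_mean + 0.1 * jnp.mean(x + noise)
    return new_mean, key

running_mean = jnp.zeros(())
key = jax.random.PRNGKey(0)
x = jnp.ones((4,))

# The caller threads the state through every call explicitly.
running_mean, key = noisy_running_mean(running_mean, x, key)
print(running_mean)
```

Calling the function twice with the same inputs gives the same result, which is what lets JAX trace and compile it safely.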

Perform standard imports

We kick off by importing standard data science packages that we'll use in this article.

import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from typing import Any
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn  
from flax.training import train_state
import optax

Download the dataset

Let's illustrate how to include BatchNorm and DropOut layers in a Flax network by designing a simple Convolutional Neural Network using the cats and dogs dataset from Kaggle.

Download and extract the dataset.

import wget 

import zipfile
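with zipfile.ZipFile('train.zip', 'r') as zip_ref:
    # Assumes the archive unpacks into a train/ folder in the working directory
    zip_ref.extractall('.')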

Loading datasets in JAX

Since JAX doesn't ship with data loading tools, load the dataset using PyTorch. We start by creating a PyTorch Dataset class.

class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

Next, create a Pandas DataFrame containing the image path and the labels.

train_df = pd.DataFrame(columns=["img_path","label"])
train_df["img_path"] = os.listdir("train/")
for idx, i in enumerate(train_df["img_path"]):
    # Use .loc to avoid chained assignment, which may silently write to a copy
    if "cat" in i:
        train_df.loc[idx, "label"] = 0
    if "dog" in i:
        train_df.loc[idx, "label"] = 1

train_df.to_csv('train_csv.csv', index=False, header=True)

Data processing with PyTorch

Next, create a function to stack the dataset and return it as a NumPy array.

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
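
As a quick check of what the collate function returns, the sketch below runs it on a made-up batch. The 8x8 image size and the alternating labels are arbitrary assumptions, and plain Python floats stand in for the label tensors that the Dataset class produces.

```python
import numpy as np

def custom_collate_fn(batch):
    # Same logic as the function above, repeated so the example runs on its own.
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels

# Four fake 8x8 RGB "images" with alternating labels (made-up data).
fake_batch = [
    (np.zeros((8, 8, 3), dtype=np.uint8), float(i % 2)) for i in range(4)
]

imgs, labels = custom_collate_fn(fake_batch)
print(imgs.shape, labels.shape)
```

The images come back stacked along a new leading batch axis, which is the layout the Flax model expects.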

We then use PyTorch to create training and testing data loaders.

size_image = 224
batch_size = 64

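# Assumption: resize every image to size_image x size_image and convert it
# to a NumPy array with shape (height, width, channels)
transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array,
])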
dataset = CatsDogsDataset("train","train_csv.csv",transform=transform)
train_set, validation_set = torch.utils.data.random_split(dataset,[20000,5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn,shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set,collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)

Define Flax model with BatchNorm and DropOut

Define the Flax network with the BatchNorm and DropOut layers. In the network, we introduce the training argument to control when the batch stats should be updated. We ensure that they aren't updated during testing.

During training, use_running_average is set to False in the BatchNorm layer, meaning that the running statistics stored in batch_stats are not used for normalization; the statistics of the current batch are computed and used instead. During evaluation, use_running_average is True, so the stored statistics are used.

The DropOut layer takes the following:

  • The rate, which is the dropout probability.
  • Whether it's deterministic. If deterministic is False, the inputs are masked and the surviving values are scaled by 1 / (1 - rate). If it is True, the inputs are not masked and are returned as they are.

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x, training):
    x = nn.Conv(features=128, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  
    x = nn.Dense(features=256)(x)
    x = nn.Dense(features=128)(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.Dropout(0.2, deterministic=not training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=2)(x)
    x = nn.log_softmax(x)
    return x

Create loss function

The next step is to create the loss function. When applying the model, we:

  • Pass the parameters and the batch stats.
  • Set training to True.
  • Set batch_stats as mutable so that the updated statistics are returned.
  • Pass the random number generator key for the DropOut layer.

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=2)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_loss(params, batch_stats, images, labels):
  # With mutable=['batch_stats'], apply returns (output, updated collections).
  logits, updates = CNN().apply(
      {'params': params, 'batch_stats': batch_stats},
      images, training=True,
      rngs={'dropout': jax.random.PRNGKey(0)},
      mutable=['batch_stats'])
  loss = cross_entropy_loss(logits=logits, labels=labels)
  return loss, (logits, updates)
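
Note that the loss function above passes jax.random.PRNGKey(0) on every call, so the DropOut layer draws the same mask at every step. One common alternative, sketched here as an assumption rather than as the article's method, is to derive a fresh key per step with jax.random.fold_in. The helper name dropout_key_for_step and the base seed are hypothetical.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)  # assumed base seed

def dropout_key_for_step(step):
    # fold_in mixes an integer into the key, giving a distinct key per step.
    return jax.random.fold_in(base_key, step)

key_0 = dropout_key_for_step(0)
key_1 = dropout_key_for_step(1)

# Keep-masks drawn with keep probability 0.8, matching a dropout rate of 0.2.
mask_0 = jax.random.bernoulli(key_0, 0.8, (16,))
mask_1 = jax.random.bernoulli(key_1, 0.8, (16,))

# The same step always reproduces the same mask.
mask_0_again = jax.random.bernoulli(dropout_key_for_step(0), 0.8, (16,))
```

In a training step, the step counter stored in the training state could play the role of the integer, so no extra key has to be carried around.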

Compute metrics

The compute_metrics function calculates the loss and accuracy and returns them.

def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics

Create custom Flax training state

Let's create a custom Flax training state that will store the batch stats information. To do that, create a new training state class that subclasses Flax's TrainState.

# initialize weights
model = CNN()
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), training=False)

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict

To define a Flax training state, use TrainState.create and pass the:

  • Apply function.
  • Model parameters.
  • The optimizer function.
  • The batch stats.

state = TrainState.create(
    apply_fn = model.apply,
    params = variables['params'],
    tx = optax.sgd(0.01),
    batch_stats = variables['batch_stats'],
)

Training step

In the training step, we compute the loss and its gradients with respect to the model parameters; the batch statistics are passed in alongside the parameters but are not differentiated. We use the gradients to update the model parameters and return the new state and model metrics. The function is decorated with @jax.jit to make the computation faster.

@jax.jit
def train_step(state, images, labels):
  """Train for a single step."""
  (batch_loss, (logits, updates)), grads = jax.value_and_grad(compute_loss, has_aux=True)(state.params, state.batch_stats, images, labels)
  # Update the parameters and store the new batch statistics in the state.
  state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
  metrics = compute_metrics(logits=logits, labels=labels)
  return state, metrics

Next, define a function that applies the training step for one epoch. The function:

  • Loops through the training data.
  • Passes each training batch to the training step.
  • Obtains the batch metrics.
  • Computes the mean to obtain the epoch metrics.
  • Returns the new state and metrics.

def train_one_epoch(state, dataloader):
    """Train for 1 epoch on the training set."""
    batch_metrics = []
    for cnt, (images, labels) in enumerate(dataloader):
        images = images / 255.0
        state, metrics = train_step(state, images, labels)
        # Collect the metrics of every batch so they can be averaged below
        batch_metrics.append(metrics)

    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    return state, epoch_metrics_np
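
The averaging at the end of the epoch is a plain mean over the per-batch dictionaries. A small sketch with invented numbers:

```python
import numpy as np

# Made-up per-batch metrics, standing in for what train_step returns.
batch_metrics = [
    {'loss': 0.9, 'accuracy': 0.50},
    {'loss': 0.7, 'accuracy': 0.75},
    {'loss': 0.5, 'accuracy': 1.00},
]

# Average each metric across batches, key by key.
epoch_metrics = {
    k: np.mean([m[k] for m in batch_metrics])
    for k in batch_metrics[0]
}
print(epoch_metrics)
```

This is an unweighted mean, so a smaller final batch counts as much as a full one. With 20,000 training images and a batch size of 64 the last batch holds only 32 images, which makes the epoch metrics a close approximation rather than an exact per-sample average.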

Evaluation step

We pass the test images and labels to the model in the evaluation step and obtain the evaluation metrics. The function is also jitted to take advantage of JAX's fast computation. During evaluation, set training to False so that the DropOut layer is disabled and the BatchNorm layer uses the stored batch stats instead of updating them. In this step, we also pass the batch stats and the random number generator for the DropOut layer.

@jax.jit
def eval_step(batch_stats, params, images, labels):
    logits = CNN().apply({'params': params,'batch_stats': batch_stats}, images, training=False,rngs={'dropout': jax.random.PRNGKey(0)})
    return compute_metrics(logits=logits, labels=labels)

The evaluate_model function applies the eval_step to the test data and returns the evaluation metrics.

def evaluate_model(state, test_imgs, test_lbls):
    """Evaluate on the validation set."""
    metrics = eval_step(state.batch_stats,state.params, test_imgs, test_lbls)
    metrics = jax.device_get(metrics) 
    # jax.tree_map is deprecated in recent JAX releases; use jax.tree_util.tree_map
    metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return metrics
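
The last two steps of evaluate_model turn device arrays into plain Python numbers, which is convenient for printing and logging. A minimal sketch with made-up metric values:

```python
import jax
import jax.numpy as jnp

# Made-up metrics, standing in for the output of eval_step.
metrics = {'loss': jnp.array(0.25), 'accuracy': jnp.array(0.5)}

# Copy the arrays from the accelerator to host memory as NumPy arrays.
metrics = jax.device_get(metrics)

# Apply .item() to every leaf of the dictionary to get Python floats.
metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
print(metrics)
```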

Train Flax model

To train the model, we define another function that calls train_one_epoch. Let's start by defining the evaluation data, which here is the first batch from the validation loader.

(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

Set up TensorBoard in Flax

You can log the model metrics to TensorBoard by writing them as scalars with a SummaryWriter.

from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
logdir = "flax_logs"
writer = SummaryWriter(logdir)

Train model

We can also append the metrics to a list and visualize them with Matplotlib.

training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []

Next, define the training function that will:

  • Train the Flax model for the specified number of epochs.
  • Evaluate the model on the test data.
  • Append the metrics to a list.
  • Write model metrics to TensorBoard.
  • Print the metrics on every epoch.
  • Return the trained model state

def train_model(epochs):
    # Start from the initial state and carry the updated state across epochs
    current_state = state
    for epoch in range(1, epochs + 1):
        current_state, train_metrics = train_one_epoch(current_state, train_loader)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])

        test_metrics = evaluate_model(current_state, test_images, test_labels)
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])
        writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
        writer.add_scalar('Loss/test', test_metrics['loss'], epoch)
        writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
        writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)
        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, training accuracy: {train_metrics['accuracy'] * 100}, validation loss: {test_metrics['loss']}, validation accuracy: {test_metrics['accuracy'] * 100}")
    return current_state

Run the training function to train the model.

trained_model_state = train_model(30)

Save Flax model

The save_checkpoint function saves a Flax model. It expects:

  • The directory to save the model checkpoint.
  • The Flax trained model, in this case trained_model_state.
  • The model's prefix.
  • Whether to overwrite existing models.
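
from flax.training import checkpoints
ckpt_dir = 'model_checkpoint/'
# The step and prefix values below are illustrative assumptions
checkpoints.save_checkpoint(ckpt_dir=ckpt_dir, target=trained_model_state, step=100, prefix='flax_model', overwrite=True)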

Load Flax model

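The restore_checkpoint function loads a saved Flax model from the saved location.

# The target and prefix values below are illustrative assumptions
loaded_model = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=state, prefix='flax_model')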

Evaluate Flax model

Run the evaluate_model function to check the performance of the model on test data.

evaluate_model(trained_model_state,test_images, test_labels)

Visualize Flax model performance

To visualize the performance of the Flax model, you can plot the metrics using Matplotlib or load TensorBoard and check the scalars tab.

%load_ext tensorboard 
%tensorboard --logdir={logdir}
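
The Matplotlib route mentioned above can be sketched as follows. The numbers in the two lists are placeholders rather than results from this model; in practice you would plot the training_loss and testing_loss lists filled during training. The Agg backend line is an assumption for running outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line inside a notebook
import matplotlib.pyplot as plt

# Placeholder values, not real training results.
training_loss = [0.69, 0.60, 0.52]
testing_loss = [0.70, 0.63, 0.58]

epochs = range(1, len(training_loss) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, training_loss, label="training loss")
ax.plot(epochs, testing_loss, label="validation loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.legend()
```

The same pattern applies to the accuracy lists.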

Final thoughts

In this article, you have seen how to build networks in Flax containing BatchNorm and DropOut layers. You have also seen how to adjust the training process to cater to these new layers. Specifically, you have learned:

  • How to define Flax models with BatchNorm and DropOut layers.
  • How to create a custom Flax training state.
  • How to train and evaluate a Flax model with BatchNorm and DropOut layers.
  • How to save and load a Flax model.
  • How to evaluate the performance of a Flax model.


