Handling state in JAX & Flax (BatchNorm and DropOut layers)

Derrick Mwiti
8 min read

Jitting functions in JAX makes them faster, but it requires that the functions be pure, that is, free of side effects. This introduces a challenge when dealing with stateful items such as model parameters and stateful layers such as batch normalization. In this article, we'll create a network with BatchNorm and Dropout layers. After that, we'll see how to generate the random numbers the Dropout layer needs and how to update the batch statistics when training the network.
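Here is a minimal illustration of that constraint (the names trace_log, double, and increment_counter are purely illustrative). A Python side effect inside a jitted function runs only while JAX traces the function, not on every call, so state has to be passed in as an argument and returned as an output instead.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def double(x):
    # Python side effect: executed only while JAX traces the function
    trace_log.append("traced")
    return x * 2

double(jnp.ones(3))
double(jnp.ones(3))  # same shape and dtype, so the cached compiled version is reused
print(len(trace_log))  # 1

# the pure alternative: take the state as an input and return the new state
@jax.jit
def increment_counter(count):
    return count + 1

count = jnp.array(0)
for _ in range(3):
    count = increment_counter(count)
print(int(count))  # 3
```

Flax applies the same idea to models: parameters and batch statistics go into apply as inputs, and the updated statistics come back as outputs.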

Perform standard imports

We kick off by importing standard data science packages that we'll use in this article.

import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from typing import Any
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn  
from flax.training import train_state
import optax

Download the dataset

Let's illustrate how to include BatchNorm and Dropout layers in a Flax network by designing a simple convolutional neural network for the cats and dogs dataset from Kaggle.

Download and extract the dataset.

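import wget
# download train.zip from the Kaggle cats and dogs dataset, e.g. with wget.download(url)

import zipfile
# extract the archive into the current directory
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')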

Loading datasets in JAX

Since JAX doesn't ship with data loading tools, we load the dataset using PyTorch, starting with a PyTorch Dataset class.

class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = int(self.annotations.iloc[index, 1])  # 0 = cat, 1 = dog

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

Next, create a Pandas DataFrame containing the image paths and their labels.

img_paths = os.listdir("train/")
# file names look like "cat.0.jpg" or "dog.0.jpg": label cats 0 and dogs 1
labels = [0 if "cat" in name else 1 for name in img_paths]
train_df = pd.DataFrame({"img_path": img_paths, "label": labels})

train_df.to_csv("train_csv.csv", index=False, header=True)

Data processing with PyTorch

Next, create a collate function that stacks a batch of samples and returns the images and labels as NumPy arrays.

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
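As a quick sanity check, the sketch below feeds the collate function two hypothetical samples made of random 224x224 RGB arrays and integer labels, mimicking what the DataLoader passes to it. The output is one image array with a leading batch dimension and one label array.

```python
import numpy as np

def custom_collate_fn(batch):
    # batch is a list of (image, label) pairs; zip(*batch) separates images from labels
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels

# two fake samples with labels 0 (cat) and 1 (dog)
fake_batch = [
    (np.random.rand(224, 224, 3).astype(np.float32), 0),
    (np.random.rand(224, 224, 3).astype(np.float32), 1),
]
imgs, labels = custom_collate_fn(fake_batch)
print(imgs.shape, labels)  # (2, 224, 224, 3) [0 1]
```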

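We then use PyTorch to create the training and validation data loaders.

size_image = 224
batch_size = 64

# resize every image and convert it to a (height, width, channels) NumPy array,
# the NHWC layout that the Flax model expects
transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array,
])
dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)
# split the 25,000 labeled images into 20,000 for training and 5,000 for validation
train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)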

Define Flax model with BatchNorm and DropOut

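Define the Flax network with the BatchNorm and Dropout layers. The network takes a training argument that controls whether the batch statistics are updated and whether dropout is applied, so that neither happens during evaluation.

In the BatchNorm layer, we set use_running_average to not training. During training it is False, so the layer normalizes with the statistics of the current batch and updates the running averages stored in the batch_stats collection (when that collection is marked mutable). During evaluation it is True, so the stored running averages are used instead.

The Dropout layer takes the following arguments:

  • rate: the dropout probability.
  • deterministic: whether dropout is disabled. If it is False, inputs are randomly masked and the remaining values are scaled by 1 / (1 - rate). If it is True, inputs are returned unchanged. Passing not training turns dropout on only during training.
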
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x, training):
    x = nn.Conv(features=128, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  
    x = nn.Dense(features=256)(x)
    x = nn.Dense(features=128)(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.Dropout(0.2, deterministic=not training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=2)(x)
    x = nn.log_softmax(x)
    return x

Create loss function

The next step is to create the loss function. When applying the model, we:

  • Pass the batch statistics along with the parameters.
  • Set training to True.
  • Mark batch_stats as mutable so that apply also returns the updated statistics.
  • Pass a PRNG key for the Dropout layer through the rngs argument.

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=2)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_loss(params, batch_stats, images, labels):
  # mutable=['batch_stats'] makes apply return (logits, updated_variables);
  # the fixed PRNGKey(0) means the same dropout mask is reused at every step
  logits, updates = CNN().apply(
      {'params': params, 'batch_stats': batch_stats}, images, training=True,
      rngs={'dropout': jax.random.PRNGKey(0)}, mutable=['batch_stats'])
  loss = cross_entropy_loss(logits=logits, labels=labels)
  # updates has the form {'batch_stats': ...}
  return loss, (logits, updates)

Compute metrics

The compute_metrics function calculates the loss and accuracy and returns them.

def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics

Create custom Flax training state

Let's create a custom Flax training state that will store the batch stats information. To do that, create a new training state class that subclasses Flax's TrainState.

# initialize weights
model = CNN()
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), training=False)

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict

To define a Flax training state, use TrainState.create and pass:

  • The apply function.
  • The model parameters.
  • The optimizer.
  • The batch statistics.

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(0.01),
    batch_stats=variables['batch_stats'],
)
Training step

In the training step, we compute the gradients of the loss with respect to the model parameters. compute_loss also returns the logits and the updated batch statistics as auxiliary outputs. We use the gradients to update the model parameters, store the new batch statistics in the state, and return the new state and the metrics. The function is decorated with @jax.jit to compile it and make the computation faster.

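@jax.jit
def train_step(state, images, labels):
  """Train for a single step."""
  (batch_loss, (logits, updates)), grads = jax.value_and_grad(compute_loss, has_aux=True)(
      state.params, state.batch_stats, images, labels)
  # update the parameters with the gradients
  state = state.apply_gradients(grads=grads)
  # store the running statistics returned by the mutable apply call
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = compute_metrics(logits=logits, labels=labels)
  return state, metrics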

Next, define a function that applies the training step for one epoch. The function:

  • Loops through the training data.
  • Passes each training batch to the training step.
  • Obtains the batch metrics.
  • Computes their mean to obtain the epoch metrics.
  • Returns the new state and metrics.
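A minimal sketch of such a function follows, assuming a step function with the same (state, images, labels) -> (state, metrics) signature as train_step. The names train_one_epoch, fake_step, and fake_loader are hypothetical; the stand-ins only exist so the loop can be exercised without the full dataset.

```python
import jax
import jax.numpy as jnp
import numpy as np

def train_one_epoch(state, loader, step_fn):
    """Apply step_fn to every (images, labels) batch and average the batch metrics."""
    batch_metrics = []
    for images, labels in loader:
        state, metrics = step_fn(state, images, labels)
        batch_metrics.append(metrics)
    # move the per-batch metrics to the host, then average each metric over the epoch
    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        name: float(np.mean([m[name] for m in batch_metrics]))
        for name in batch_metrics[0]
    }
    return state, epoch_metrics

# usage with a stand-in step function and loader
def fake_step(state, images, labels):
    return state + 1, {'loss': jnp.mean(images), 'accuracy': jnp.mean(labels)}

fake_loader = [(np.full((2, 3), 1.0), np.array([1, 1])),
               (np.full((2, 3), 3.0), np.array([0, 0]))]
state, epoch_metrics = train_one_epoch(0, fake_loader, fake_step)
print(state)                   # 2
print(epoch_metrics['loss'])   # 2.0
```

In the real training loop, state would be the TrainState, loader would be train_loader, and step_fn would be train_step.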
