
Handling state in JAX & Flax (BatchNorm and Dropout layers)
Jitting functions in JAX and Flax makes them faster, but requires that the functions have no side effects. This purity requirement introduces a challenge when dealing with stateful items such as model parameters and stateful layers such as batch normalization layers. In this article, we'll create a network with BatchNorm and Dropout layers. After that, we'll see how to supply the random numbers that the Dropout layer needs and how to handle the batch statistics when training the network.
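As a quick illustration of this constraint (a minimal sketch, not from the article's code), any state a jitted function needs must be passed in as an argument and returned as a new value, which is exactly how Flax handles parameters, batch statistics, and dropout random keys:
import jax
import jax.numpy as jnp

@jax.jit
def update_counter(counter):
    # A pure function: the new state is returned instead of mutating a global.
    return counter + 1

counter = jnp.array(0)
counter = update_counter(counter)  # state is threaded in and out explicitly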
Perform standard imports
We kick off by importing standard data science packages that we'll use in this article.
import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from typing import Any
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax
Download the dataset
Let's illustrate how to include BatchNorm and Dropout layers in a Flax network by designing a simple Convolutional Neural Network using the cats and dogs dataset from Kaggle.
Download and extract the dataset.
import wget
wget.download("https://ml.machinelearningnuggets.com/train.zip")
import zipfile
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')
Loading datasets in JAX
Since JAX doesn't ship with data loading tools, we load the dataset using PyTorch. We start by creating a PyTorch Dataset class.
class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))
        if self.transform is not None:
            img = self.transform(img)
        return (img, y_label)
👉 Check our How to load datasets in JAX with TensorFlow tutorial.
Next, create a Pandas DataFrame containing the image path and the labels.
train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")
for idx, i in enumerate(os.listdir("train/")):
    if "cat" in i:
        train_df["label"][idx] = 0
    if "dog" in i:
        train_df["label"][idx] = 1
train_df.to_csv(r'train_csv.csv', index=False, header=True)
Data processing with PyTorch
Next, create a function to stack the dataset and return it as a NumPy array.
def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
We then use PyTorch to create training and testing data loaders.
size_image = 224
batch_size = 64
transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array])
dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)
train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)
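As an optional sanity check (not part of the original code), we can pull one batch from the training loader to confirm that the collate function returns NumPy arrays with the expected shapes:
# One batch of images and labels as NumPy arrays
imgs, labels = next(iter(train_loader))
print(imgs.shape)    # expected: (64, 224, 224, 3)
print(labels.shape)  # expected: (64,)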
Define Flax model with BatchNorm and Dropout
Define the Flax network with the BatchNorm and Dropout layers. In the network, we introduce the training variable to control when the batch stats should be updated; we ensure that they aren't updated during testing. In the BatchNorm layer we set use_running_average to not training: during training the stats stored in batch_stats are not used; instead, the batch statistics of the input are computed. The Dropout layer takes the following arguments:
- rate: the dropout probability.
- deterministic: if True, no mask is applied and the inputs are returned as they are; otherwise, the inputs are masked and scaled by 1 / (1 - rate).
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, training):
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.Dense(features=128)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)
        return x
Create loss function
The next step is to create the loss function. When applying the model, we:
- Pass the batch stats parameters.
- Set training to True.
- Set batch_stats as mutable.
- Set the random number (PRNG key) for the Dropout layer.
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=2)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_loss(params, batch_stats, images, labels):
    # With batch_stats mutable, apply returns the logits and the updated batch statistics.
    logits, updated_state = CNN().apply(
        {'params': params, 'batch_stats': batch_stats}, images, training=True,
        rngs={'dropout': jax.random.PRNGKey(0)}, mutable=['batch_stats'])
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, updated_state)
Compute metrics
The compute metrics function calculates the loss and accuracy and returns them.
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
Create custom Flax training state
Let's create a custom Flax training state that will store the batch stats information. To do that, create a new training state class that subclasses Flax's TrainState.
# initialize weights
model = CNN()
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), training=False)

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict
To define a Flax training state, use TrainState.create and pass the:
- Apply function.
- Model parameters.
- The optimizer function.
- The batch stats.
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(0.01),
    batch_stats=variables['batch_stats'],
)
Training step
In the training step, we compute the gradients of the loss with respect to the model parameters, passing in both the current parameters and the batch statistics. We use the gradients to update the model parameters, carry the updated batch statistics into the new state, and return the new state together with the model metrics. The function is decorated with @jax.jit to make the computation faster.
@jax.jit
def train_step(state, images, labels):
    """Train for a single step."""
    (batch_loss, (logits, updated_state)), grads = jax.value_and_grad(
        compute_loss, has_aux=True)(state.params, state.batch_stats, images, labels)
    state = state.apply_gradients(grads=grads)
    # Carry the freshly computed batch statistics over into the new state.
    state = state.replace(batch_stats=updated_state['batch_stats'])
    metrics = compute_metrics(logits=logits, labels=labels)
    return state, metrics
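As a hedged sketch (the eval_step name and structure are assumptions, not taken from the article), an evaluation step can pass training=False so that BatchNorm uses the running averages stored in batch_stats and Dropout is deterministic; nothing is marked as mutable, so the stored statistics are never updated:
@jax.jit
def eval_step(state, images, labels):
    """Evaluate for a single step; batch_stats are read but never updated."""
    logits = CNN().apply(
        {'params': state.params, 'batch_stats': state.batch_stats},
        images,
        training=False,
    )
    return compute_metrics(logits=logits, labels=labels)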
👉 Check our JAX (What it is and how to use it in Python) tutorial.
Next, define a function that applies the training step for one epoch; a minimal sketch follows the list. The function:
- Loops through the training data.
- Passes each training batch to the training step.
- Obtains the batch metrics.
- Computes the mean to obtain the epoch metrics.
- Returns the new state and metrics.
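Here is a hedged sketch based only on the steps listed above; the name train_one_epoch and its exact structure are assumptions, not the article's own implementation.
def train_one_epoch(state, data_loader):
    """Run train_step over every batch in the loader and average the metrics."""
    batch_metrics = []
    for images, labels in data_loader:
        state, metrics = train_step(state, images, labels)
        batch_metrics.append(metrics)
    # Pull the metrics back to the host and average them over the epoch.
    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        metric: np.mean([m[metric] for m in batch_metrics])
        for metric in batch_metrics[0]
    }
    return state, epoch_metrics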