Handling state in JAX & Flax (BatchNorm and Dropout layers)
Jitting functions in JAX and Flax makes them faster, but requires that they have no side effects. This constraint introduces a challenge when dealing with stateful items such as model parameters and stateful layers such as batch normalization. In this article, we'll create a network with BatchNorm and Dropout layers. After that, we'll see how to handle generating the random numbers for the Dropout layer and updating the batch statistics when training the network.
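The core requirement is that a jitted function is pure: any state it needs, such as model parameters, batch statistics, or random keys, must be passed in as an argument, and any updated state must be returned as an output. Here is a minimal sketch of that pattern (not part of the tutorial's model, just an illustration):
import jax

@jax.jit
def sgd_step(params, grads, lr=0.01):
    # State goes in as an argument and the updated copy comes back out,
    # instead of being mutated in place inside the jitted function.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
BatchNorm and Dropout follow the same principle: the running batch statistics and the dropout random key are handled explicitly rather than hidden inside the layer.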
Perform standard imports
We kick off by importing standard data science packages that we'll use in this article.
import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from typing import Any
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax
Download the dataset
Let's illustrate how to include BatchNorm and Dropout layers in a Flax network by designing a simple Convolutional Neural Network using the cats and dogs dataset from Kaggle.
Download and extract the dataset.
import wget
wget.download("https://ml.machinelearningnuggets.com/train.zip")
import zipfile
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')
Loading datasets in JAX
Since JAX doesn't ship with data loading tools, load the dataset using PyTorch. We start by creating a PyTorch Dataset class.
class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))
        if self.transform is not None:
            img = self.transform(img)
        return (img, y_label)
👉 Check our How to load datasets in JAX with TensorFlow tutorial.
Next, create a Pandas DataFrame containing the image path and the labels.
train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")
for idx, img_name in enumerate(train_df["img_path"]):
    if "cat" in img_name:
        train_df.loc[idx, "label"] = 0
    if "dog" in img_name:
        train_df.loc[idx, "label"] = 1

train_df.to_csv(r'train_csv.csv', index=False, header=True)
Data processing with PyTorch
Next, create a function to stack the dataset and return it as a NumPy array.
def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
We then use PyTorch to create training and testing data loaders.
size_image = 224
batch_size = 64

transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array])

dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)
train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)
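As a quick sanity check (assuming the dataset has been downloaded and extracted as above), you can pull one batch and confirm that the images come back as NumPy arrays in NHWC layout, which is what the Flax convolution layers defined below expect:
imgs, labels = next(iter(train_loader))
# Expected shapes: (64, 224, 224, 3) for the images and (64,) for the labels.
print(imgs.shape, labels.shape)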
Define Flax model with BatchNorm and DropOut
Define the Flax network with the BatchNorm and Dropout layers. In the network, we introduce a training argument to control when the batch statistics should be updated; we ensure that they aren't updated during testing. In the BatchNorm layer, we set use_running_average to not training, so during training the stats stored in batch_stats are not used; instead, the batch statistics of the input are computed. The Dropout layer takes the following:
- rate: the dropout probability.
- deterministic: if False, inputs are scaled and masked; if True, no mask is applied and the inputs are returned as they are.
class CNN(nn.Module):

    @nn.compact
    def __call__(self, x, training):
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.Dense(features=128)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)
        return x
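Because the network contains a BatchNorm layer, initializing it produces two variable collections: params and batch_stats. As a quick sketch (this anticipates the initialization step shown later in the article):
model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, size_image, size_image, 3]), training=False)
# Expected collections: 'params' and 'batch_stats'.
print(variables.keys())
The batch_stats collection holds the running mean and variance that BatchNorm accumulates during training, and it is what we need to carry through the training loop below.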
Create loss function
The next step is to create the loss function. When applying the model, we:
- Pass the batch stats parameters.
- Set training to True.
- Set batch_stats as mutable.
- Pass the random number generator for the Dropout layer.
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=2)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_loss(params, batch_stats, images, labels):
    # Mark batch_stats as mutable and pass an rng for the Dropout layer.
    logits, updates = CNN().apply({'params': params, 'batch_stats': batch_stats},
                                  images, training=True,
                                  rngs={'dropout': jax.random.PRNGKey(0)},
                                  mutable=['batch_stats'])
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, updates)
Compute metrics
The compute_metrics function calculates the loss and accuracy and returns them.
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
Create custom Flax training state
Let's create a custom Flax training state that will store the batch stats information. To do that, create a new training state class that subclasses Flax's TrainState.
# Initialize the model variables (weights and batch stats)
model = CNN()
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), training=False)

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict
To define a Flax training state, use TrainState.create and pass:
- The apply function.
- The model parameters.
- The optimizer function.
- The batch stats.
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(0.01),
    batch_stats=variables['batch_stats'],
)
Training step
In the training step, we compute the loss and its gradients with respect to the model parameters, passing in the current parameters and batch statistics. We use the gradients to update the model parameters, write the updated batch statistics back into the state, and return the new state and the training metrics. The function is decorated with @jax.jit
to make the computation faster.
@jax.jit
def train_step(state, images, labels):
    """Train for a single step."""
    (batch_loss, (logits, updates)), grads = jax.value_and_grad(
        compute_loss, has_aux=True)(state.params, state.batch_stats, images, labels)
    state = state.apply_gradients(grads=grads)
    # Write the updated batch statistics back into the training state.
    state = state.replace(batch_stats=updates['batch_stats'])
    metrics = compute_metrics(logits=logits, labels=labels)
    return state, metrics
👉 Check our JAX (What it is and how to use it in Python) tutorial.
Next, define a function that applies the training step for one epoch. The function:
- Loops through the training data.
- Passes each training batch to the training step.
- Obtains the batch metrics.
- Computes the mean to obtain the epoch metrics.
- Returns the new state and metrics.
def train_one_epoch(state, dataloader):
    """Train for 1 epoch on the training set."""
    batch_metrics = []
    for cnt, (images, labels) in enumerate(dataloader):
        images = images / 255.0
        state, metrics = train_step(state, images, labels)
        batch_metrics.append(metrics)
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    return state, epoch_metrics_np
Evaluation step
We pass the test images and labels to the model in the evaluation step and obtain the evaluation metrics. The function is also jitted to take advantage of JAX's fast computation. During evaluation, we set training to False so that the stored batch statistics are used and dropout is disabled. In this step, we also pass the batch stats and the random number generator for the Dropout layer.
@jax.jit
def eval_step(batch_stats, params, images, labels):
    logits = CNN().apply({'params': params, 'batch_stats': batch_stats}, images,
                         training=False, rngs={'dropout': jax.random.PRNGKey(0)})
    return compute_metrics(logits=logits, labels=labels)
The evaluate_model function applies eval_step to the test data and returns the evaluation metrics.
def evaluate_model(state, test_imgs, test_lbls):
    """Evaluate on the validation set."""
    metrics = eval_step(state.batch_stats, state.params, test_imgs, test_lbls)
    metrics = jax.device_get(metrics)
    metrics = jax.tree_map(lambda x: x.item(), metrics)
    return metrics
Train Flax model
To train the model, we define another function that calls train_one_epoch for each epoch. Let's start by defining the model evaluation data.
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0
Set up TensorBoard in Flax
You can log the model metrics by writing the scalars to TensorBoard using PyTorch's SummaryWriter.
from torch.utils.tensorboard import SummaryWriter

logdir = "flax_logs"
writer = SummaryWriter(logdir)
👉 Check our TensorBoard tutorial (Deep dive with examples and notebook) tutorial.
Train model
We can also append the metrics to a list and visualize them with Matplotlib.
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []
Next, define the training function that will:
- Train the Flax model for the specified number of epochs.
- Evaluate the model on the test data.
- Append the metrics to a list.
- Write model metrics to TensorBoard.
- Print the metrics on every epoch.
- Return the trained model state.
def train_model(epochs):
    # Carry the training state across epochs so that parameter and
    # batch-stats updates accumulate.
    model_state = state
    for epoch in range(1, epochs + 1):
        model_state, train_metrics = train_one_epoch(model_state, train_loader)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])
        test_metrics = evaluate_model(model_state, test_images, test_labels)
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])
        writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
        writer.add_scalar('Loss/test', test_metrics['loss'], epoch)
        writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
        writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)
        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, training accuracy: {train_metrics['accuracy'] * 100}, validation loss: {test_metrics['loss']}, validation accuracy: {test_metrics['accuracy'] * 100}")
    return model_state
Run the training function to train the model.
trained_model_state = train_model(30)
Save Flax model
The save_checkpoint function saves a Flax model. It expects:
- The directory to save the model checkpoint.
- The trained Flax model, in this case trained_model_state.
- The step number.
- The model's prefix.
- Whether to overwrite existing checkpoints.
from flax.training import checkpoints
ckpt_dir = 'model_checkpoint/'
checkpoints.save_checkpoint(ckpt_dir=ckpt_dir,
                            target=trained_model_state,
                            step=100,
                            prefix='flax_model',
                            overwrite=True
                            )
Load Flax model
The restore_checkpoint function loads a saved Flax model from the saved location.
loaded_model = checkpoints.restore_checkpoint(
    ckpt_dir=ckpt_dir,
    target=state,
    prefix='flax_model'
)
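The restored state has the same structure as the one passed as target, so it can be used just like the trained state. For example (a quick check, reusing the evaluate_model function defined above):
evaluate_model(loaded_model, test_images, test_labels)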
Evaluate Flax model
Run the evaluate_model function to check the performance of the model on the test data.
evaluate_model(trained_model_state, test_images, test_labels)
Visualize Flax model performance
To visualize the performance of the Flax model, you can plot the metrics using Matplotlib or load TensorBoard and check the scalars tab.
%load_ext tensorboard
%tensorboard --logdir={logdir}
👉 Check our How to use TensorBoard in JAX & Flax tutorial.
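For the Matplotlib option, here is a minimal sketch using the metric lists populated in train_model:
epochs_range = range(1, len(training_loss) + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, training_loss, label='training loss')
plt.plot(epochs_range, testing_loss, label='validation loss')
plt.xlabel('epoch')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs_range, training_accuracy, label='training accuracy')
plt.plot(epochs_range, testing_accuracy, label='validation accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()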
Final thoughts
In this article, you have seen how to build networks in Flax containing BatchNorm and Dropout layers. You have also seen how to adjust the training process to accommodate these layers. Specifically, you have learned:
- How to define Flax models with BatchNorm and Dropout layers.
- How to create a custom Flax training state.
- How to train and evaluate a Flax model with BatchNorm and Dropout layers.
- How to save and load a Flax model.
- How to evaluate the performance of a Flax model.
Interested in diving deeper into JAX and Flax? Here are some more resources from our blog:
- What is JAX?
- Flax vs. TensorFlow
- JAX loss functions
- Transfer learning with JAX & Flax
- Activation functions in JAX and Flax
- Optimizers in JAX and Flax
- How to load datasets in JAX using TensorFlow
- Building Convolutional Neural Networks in JAX and Flax
- Distributed training in JAX
- Using TensorBoard in JAX and Flax
- LSTM in JAX & Flax
- Elegy (High-level API for deep learning in JAX & Flax)
Follow us on LinkedIn, Twitter, GitHub, and subscribe to our blog, so you don't miss a new issue.