Train ResNet in Flax from scratch (Distributed ResNet training)

Derrick Mwiti

Apart from designing custom CNN architectures, you can use architectures that have already been built. ResNet is one such popular architecture. In most cases, you'll achieve better performance by using such architectures. In this article, you will learn how to perform distributed training of a ResNet model in Flax.

Install Flax models

The flaxmodels package provides pre-trained models for JAX and Flax, including:

  • StyleGAN2
  • GPT2
  • VGG
  • ResNet
git clone
pip install -r flaxmodels/training/resnet/requirements.txt

In this project, we will train the model from scratch, meaning that we will not use the pre-trained weights. In a separate article, we have covered how to perform transfer learning with ResNet.

Perform standard imports

With flaxmodels installed, let's import the standard libraries used in this article.

import wget # pip install wget
import zipfile
import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax
import time
from tqdm.notebook import tqdm
import math
from flax import jax_utils

Download dataset

We will train the ResNet model to predict two classes from the cats and dogs dataset. Download and extract the cat and dog images.

wget.download("")
with zipfile.ZipFile('', 'r') as zip_ref:
    zip_ref.extractall()

Loading dataset in Flax

Since JAX and Flax don't ship with any data loaders, we use data loading utilities from PyTorch or TensorFlow. When using PyTorch, we start by creating a dataset class.

class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

Next, create a Pandas DataFrame containing the image paths and labels.

train_df = pd.DataFrame(columns=["img_path","label"])
file_names = os.listdir("train/")
train_df["img_path"] = file_names
for idx, i in enumerate(file_names):
    # use .loc instead of chained indexing so the assignment reaches train_df itself
    if "cat" in i:
        train_df.loc[idx, "label"] = 0
    if "dog" in i:
        train_df.loc[idx, "label"] = 1

train_df.to_csv('train_csv.csv', index=False, header=True)

Data transformation in Flax

Define a function that will stack the data and return it as a NumPy array.

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
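
To see what the collate function hands back, here is a small check on a made-up batch. The two fake 4x4 images and their labels are assumptions for illustration only; the function body is the same as the one above.

```python
import numpy as np

def custom_collate_fn(batch):
    # batch is a list of (image, label) pairs; regroup it into (images, labels)
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels

# Two fake 4x4 RGB "images" with labels 0 (cat) and 1 (dog).
fake_batch = [
    (np.zeros((4, 4, 3), dtype=np.uint8), 0.0),
    (np.ones((4, 4, 3), dtype=np.uint8), 1.0),
]
imgs, labels = custom_collate_fn(fake_batch)
print(imgs.shape, labels.shape)
```

The images come back stacked along a new leading batch axis, and the labels come back as a one-dimensional NumPy array, which is the layout JAX functions can consume directly.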

Create a transformation for resizing the images. Next, apply the transformation to the dataset created earlier.  

size_image = 224

transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array])
dataset = CatsDogsDataset("train","train_csv.csv",transform=transform)

Split this dataset into a training set and a validation set, and create a data loader for each.

batch_size = 32

train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn,shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set,collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)

Instantiate Flax ResNet model

With the data in place, instantiate the Flax ResNet model using the flaxmodels package. The instantiation requires:

  • The desired number of classes.
  • The type of output.
  • The data type.
  • Whether to load pre-trained weights: in this case pretrained=None, because we train from scratch.

import jax.numpy as jnp
import flaxmodels as fm

num_classes = 2
dtype = jnp.float32
model = fm.ResNet50(output='log_softmax', pretrained=None, num_classes=num_classes, dtype=dtype)

Compute metrics

Define the metrics for evaluating the model during training. Let's start by creating the loss function.  

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

Next, define a function that computes and returns the loss and accuracy.

def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
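
The two metrics can be checked by hand on a toy batch. The sketch below is a NumPy re-implementation for illustration, not the Optax code itself; the helper names and the three example rows are assumptions. It also checks that applying log-softmax twice changes nothing, which is why a model created with output='log_softmax' can still be fed to a softmax cross-entropy loss.

```python
import numpy as np

def log_softmax(x):
    # numerically stable log-softmax over the last axis
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def cross_entropy(logits, labels, num_classes=2):
    # mean over the batch of -sum(one_hot * log_softmax(logits))
    onehot = np.eye(num_classes)[labels]
    return -(onehot * log_softmax(logits)).sum(axis=-1).mean()

# Hypothetical model outputs for three images and their true classes.
logits = np.array([[2.0, 0.0], [0.0, 0.0], [0.5, 1.5]])
labels = np.array([0, 1, 0])

loss = cross_entropy(logits, labels)
accuracy = (logits.argmax(-1) == labels).mean()

# Log-softmax outputs are already normalized, so a second pass is a no-op.
idempotent = np.allclose(log_softmax(log_softmax(logits)), log_softmax(logits))
print(loss, accuracy, idempotent)
```

Only the first row is classified correctly here; the second row is a tie that argmax resolves to class 0, and the third row is confidently wrong, which is what drives the loss up.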

Create Flax model training state

Flax provides a training state for storing training information. The training state can be modified to add new information. In this case, we need to alter the training state to add the batch statistics since the ResNet model computes batch_stats.  

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict

We need the model parameters and batch statistics to create the training state. We can obtain both by initializing the model with train=False.

key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), train=False)

The create method of TrainState expects the following parameters:

  • apply_fn: the model's apply function.
  • params: the model parameters, variables['params'].
  • tx: the optimizer, usually defined using Optax.
  • batch_stats: the batch statistics, variables['batch_stats'].

We apply pmap to this function to create a distributed version of the training state. pmap compiles the function for execution on multiple devices, such as multiple GPUs or TPU cores.

import functools
@jax.pmap
def create_train_state(rng):
  """Creates initial `TrainState`."""
  return TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optax.adam(learning_rate=0.01, b1=0.9),
      batch_stats=variables['batch_stats'])
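
A pmapped function expects every argument to carry a leading axis whose length equals the number of participating devices, and it runs once per device on that device's slice. The sketch below illustrates this on a toy function; the function name and array shapes are assumptions for illustration, and on a machine with a single device the leading axis simply has length 1.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

@jax.pmap
def scale(x):
    # runs once per device, each call sees one row of the input
    return x * 2.0

# The leading axis must match the number of devices taking part.
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)
y = scale(x)

# One PRNG key per device, the kind of input a pmapped
# function taking an `rng` argument would receive.
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
print(y.shape, keys.shape)
```

Calling a pmapped state-creation function with one key per device returns a state whose arrays all have that same leading device axis.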

Apply model function

Next, define a parallel model training function. Pass an axis_name so you can use that to aggregate the metrics from all the devices. The function:

  • Computes the loss.
  • Averages the predicted probabilities across all devices using jax.lax.pmean().

When applying the model, we also pass the batch statistics and a PRNG key for dropout. Since this is the training function, the train parameter is True. The batch_stats are also passed to the loss function when computing the gradients. The update_model function applies the computed gradients to update the model parameters.

@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params, batch_stats):
    # mutable=['batch_stats'] makes apply also return the updated batch statistics
    logits, batch_stats = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        train=True,
        rngs={'dropout': jax.random.PRNGKey(0)},
        mutable=['batch_stats'])
    one_hot = jax.nn.one_hot(labels, num_classes)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, (logits, batch_stats)

  # gradients are taken with respect to params only (the first argument of loss_fn)
  (loss, (logits, batch_stats)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params, state.batch_stats)
  # average the class probabilities over all devices
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')

  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

def update_model(state, grads):
  return state.apply_gradients(grads=grads)
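
The role of axis_name is easiest to see in isolation. The following is a minimal sketch, assuming made-up logits of shape (devices, batch, classes) and a hypothetical function name; it averages per-device probabilities with jax.lax.pmean in the same way apply_model does.

```python
import functools
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

@functools.partial(jax.pmap, axis_name='ensemble')
def average_probs(logits):
    probs = jax.nn.softmax(logits)
    # pmean averages over the named device axis; every device receives the same result
    return jax.lax.pmean(probs, axis_name='ensemble')

# Made-up per-device logits: (devices, batch, classes).
logits = jax.random.normal(jax.random.PRNGKey(0), (n_devices, 4, 2))
avg = average_probs(logits)

# The same average computed without pmap, for comparison.
expected = jax.nn.softmax(logits).mean(axis=0)
print(avg.shape)
```

The output keeps the leading device axis, but each slice along it holds the identical averaged probabilities, so reading any single slice is enough.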

TensorBoard in Flax

The next step is to train the ResNet model. However, you might be interested in tracking the training using TensorBoard. In that case, you have to configure TensorBoard. You can write the metrics to TensorBoard using the PyTorch SummaryWriter.    

!rm -rf ./flax_logs/
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
logdir = "flax_logs"
writer = SummaryWriter(logdir)

Train Flax ResNet model

Let's train the ResNet model on the entire training set and evaluate it on a subset of the test set. You can also evaluate it on the whole test set. Replicate the test set to the available devices.

(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0
test_images = np.array(jax_utils.replicate(test_images))
test_labels = np.array(jax_utils.replicate(test_labels))
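
Replication only changes the shape of the data: each array gains a leading axis with one copy per device. The sketch below imitates that effect with plain NumPy so the shapes can be inspected without any accelerator. The helper names, the device count of 2, and the small batch are assumptions for illustration; this is a stand-in for the shape behaviour, not the Flax implementation.

```python
import numpy as np

N_DEVICES = 2  # assumed device count for this illustration

def replicate_like(x, n_devices=N_DEVICES):
    # stand-in: stack one copy of x per device along a new leading axis
    return np.stack([x] * n_devices)

def unreplicate_like(x):
    # stand-in: keep the first device's copy
    return x[0]

batch = np.random.rand(8, 16, 16, 3).astype(np.float32)
replicated = replicate_like(batch)
restored = unreplicate_like(replicated)
print(batch.shape, replicated.shape, restored.shape)
```

The replicated batch has shape (devices, batch, height, width, channels), which is the layout a pmapped function splits along its first axis.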

Create some lists to hold the training and evaluation metrics.

epoch_loss = []
epoch_accuracy = []
testing_accuracy = []
testing_loss = []

Next, define the ResNet model training function. The function does the following:

  • Loops through the training dataset and scales it.
  • Replicates the data on the available devices.
  • Applies the model on the dataset and computes the metrics.
  • Obtains the metrics from the devices using jax_utils.unreplicate.
  • Appends the metrics to a list.
  • Computes the mean of the loss and accuracy to obtain the metrics for each epoch.
  • Applies the model to the test set and obtains the metrics.
  • Appends the test metrics to a list.
  • Writes the training and evaluation metrics to TensorBoard.
  • Prints the training and evaluation metrics.
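
The steps above can be sketched as a small loop. This is one possible arrangement under stated assumptions, not the article's own training function: the names train_one_epoch, evaluate, and log_epoch, the TensorBoard tag names, and all the fake_* stand-ins are hypothetical. The model-specific pieces are passed in as arguments so the sketch runs without JAX, Flax, or a dataset.

```python
import numpy as np

def train_one_epoch(state, loader, apply_fn, update_fn, replicate, unreplicate):
    """One pass over `loader`; returns the new state and the mean loss and accuracy."""
    losses, accuracies = [], []
    for images, labels in loader:
        images = images / 255.0                      # scale pixels to [0, 1]
        images, labels = replicate(images), replicate(labels)
        grads, loss, accuracy = apply_fn(state, images, labels)
        state = update_fn(state, grads)
        losses.append(unreplicate(loss))             # keep one device's copy
        accuracies.append(unreplicate(accuracy))
    return state, float(np.mean(losses)), float(np.mean(accuracies))

def evaluate(state, test_images, test_labels, apply_fn, unreplicate):
    """Apply the model to an already replicated test batch; gradients are discarded."""
    _, loss, accuracy = apply_fn(state, test_images, test_labels)
    return float(unreplicate(loss)), float(unreplicate(accuracy))

def log_epoch(writer, epoch, train_metrics, test_metrics):
    """Write (loss, accuracy) pairs with add_scalar and print them."""
    for split, (loss, accuracy) in (("train", train_metrics), ("test", test_metrics)):
        writer.add_scalar(f"Loss/{split}", loss, epoch)
        writer.add_scalar(f"Accuracy/{split}", accuracy, epoch)
        print(f"epoch {epoch} {split}: loss={loss:.4f} accuracy={accuracy:.4f}")

# --- Stand-ins so the sketch runs on its own (all hypothetical) ---
N_DEVICES = 2

class FakeWriter:
    def __init__(self):
        self.scalars = []
    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))

def fake_replicate(x):
    return np.stack([x] * N_DEVICES)

def fake_unreplicate(x):
    return x[0]

def fake_apply(state, images, labels):
    loss = np.full(N_DEVICES, images.mean())         # 1.0 for all-255 images after scaling
    accuracy = np.full(N_DEVICES, 0.5)
    return 0.1, loss, accuracy                       # "grads" is just a number here

def fake_update(state, grads):
    return state - grads

fake_loader = [(np.full((4, 8, 8, 3), 255.0), np.zeros(4)) for _ in range(3)]
test_images = fake_replicate(np.full((4, 8, 8, 3), 1.0))
test_labels = fake_replicate(np.zeros(4))

writer = FakeWriter()
state = 1.0
for epoch in range(1, 3):
    state, train_loss, train_acc = train_one_epoch(
        state, fake_loader, fake_apply, fake_update, fake_replicate, fake_unreplicate)
    test_metrics = evaluate(state, test_images, test_labels, fake_apply, fake_unreplicate)
    log_epoch(writer, epoch, (train_loss, train_acc), test_metrics)
```

In the setting of this article, the arguments would presumably be apply_model, update_model, jax_utils.replicate, jax_utils.unreplicate, and the SummaryWriter created earlier, with the per-epoch values also appended to the metric lists defined above.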
