Transfer learning with JAX & Flax

Derrick Mwiti


Training large neural networks can take days or weeks. Once these networks are trained, you can reuse their weights for new tasks, an approach known as transfer learning. Fine-tuning a pre-trained network in this way usually gives good results in a fraction of the time it would take to train from scratch. Let's look at how you can fine-tune a pre-trained ResNet network in JAX and Flax.

Install JAX ResNet

We'll use ResNet checkpoints provided by the jax-resnet library.

pip install jax-resnet

Let's import it together with other packages used in this article.

# pip install flax
import os
import warnings
from functools import partial

import numpy as np
import pandas as pd
from PIL import Image
import jax
import jax.numpy as jnp
import optax
import flax
from flax import linen as nn
from flax.core import FrozenDict, frozen_dict
from flax.training import train_state
from jax_resnet import pretrained_resnet, slice_variables, Sequential
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline

# ignore harmless warnings
warnings.filterwarnings("ignore")

Download dataset

We will fine-tune the ResNet model to predict two classes from the cats and dogs dataset. Download and extract the cat and dog images.

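pip install wget
import wget
wget.download("")  # dataset URL
import zipfile
# extract the downloaded archive into the current directory
with zipfile.ZipFile('', 'r') as zip_ref:
    zip_ref.extractall('.')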

Data loading in JAX

JAX doesn't ship with data loading utilities, so we use existing data loaders from TensorFlow or PyTorch. Here, we use PyTorch to load the image data.

The first step is to create a PyTorch Dataset class.

class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

Data processing

Next, create a Pandas DataFrame with the image paths and labels.

train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")
for idx, i in enumerate(train_df["img_path"]):
    # Derive the label from the filename: 0 for cat, 1 for dog
    if "cat" in i:
        train_df.loc[idx, "label"] = 0
    if "dog" in i:
        train_df.loc[idx, "label"] = 1

train_df.to_csv('train_csv.csv', index=False, header=True)
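The labelling rule can be isolated in a function and checked on a couple of filenames. This is a minimal sketch: `label_from_filename` is a hypothetical helper, not part of the original code, and like the loop above it assumes the class name appears in each filename.

```python
from typing import Optional


def label_from_filename(name: str) -> Optional[int]:
    """Hypothetical helper: 0 for cat, 1 for dog, None if neither word appears."""
    label = None
    if "cat" in name:
        label = 0
    # Checked second, so "dog" wins if both words appear, as in the loop above
    if "dog" in name:
        label = 1
    return label


print([label_from_filename(n) for n in ["cat.0.jpg", "dog.12.jpg"]])  # [0, 1]
```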

Define a collate function that stacks a batch of samples and returns the images and labels as NumPy arrays.

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
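To see what the collate function returns, you can call it on a hand-made batch. The sketch below assumes each sample is a (height, width, channels) uint8 NumPy image plus an integer label; the zero-filled images are placeholders, not real data.

```python
import numpy as np


def custom_collate_fn(batch):
    # batch is a list of (image, label) pairs; zip(*batch) splits it into two tuples
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels


# Four placeholder 224x224 RGB images (HWC, uint8) with alternating labels
fake_batch = [(np.zeros((224, 224, 3), dtype=np.uint8), i % 2) for i in range(4)]
imgs, labels = custom_collate_fn(fake_batch)
print(imgs.shape, labels.shape)  # (4, 224, 224, 3) (4,)
```

The images come out in NHWC order (batch, height, width, channels), which is the layout Flax convolutions expect by default.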

Let's also resize the images to ensure they are the same size. Define the size in a configuration dictionary. We'll use the other config variables later.

config = {
    'NUM_LABELS': 2,
    'BATCH_SIZE': 32,
    'N_EPOCHS': 5,
    'LR': 0.001,
    'IMAGE_SIZE': 224,
    'WEIGHT_DECAY': 1e-5,
}

Resize the images using PyTorch transforms. Next, use the CatsDogsDataset class to define the training and validation data loaders.

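# Resize every image to IMAGE_SIZE x IMAGE_SIZE, then turn the PIL image into an HWC uint8 NumPy array
transform = transforms.Compose([
    transforms.Resize((config["IMAGE_SIZE"], config["IMAGE_SIZE"])),
    np.array,
])
dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)
# The split sizes must add up to the dataset length: 20,000 training and 5,000 validation images
train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn, shuffle=True, batch_size=config["BATCH_SIZE"])
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn, shuffle=False, batch_size=config["BATCH_SIZE"])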

ResNet model definition

The pre-trained ResNet checkpoints were trained on ImageNet, which has 1,000 classes, whereas our dataset has only two. We therefore use the ResNet as a backbone and define a custom classification head.

Create head network

Create a head network whose output matches the problem, in this case binary image classification, so the final layer has two outputs.

class Head(nn.Module):
    '''head model'''
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)

    # nn.compact is required because the layers are defined inline in __call__
    @nn.compact
    def __call__(self, inputs, train: bool):
        output_n = inputs.shape[-1]
        x = self.batch_norm_cls(use_running_average=not train)(inputs)
        x = nn.Dropout(rate=0.25)(x, deterministic=not train)
        x = nn.Dense(features=output_n)(x)
        x = nn.relu(x)
        x = self.batch_norm_cls(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        x = nn.Dense(features=config["NUM_LABELS"])(x)
        return x

Combine ResNet backbone with head

Combine the pre-trained ResNet backbone with the custom head you created above.

class Model(nn.Module):
    '''Combines backbone and head model'''
    backbone: Sequential
    head: Head
    def __call__(self, inputs, train: bool):
        x = self.backbone(inputs)
        # average pool layer
        x = jnp.mean(x, axis=(1, 2))
        x = self.head(x, train)
        return x
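The line `jnp.mean(x, axis=(1, 2))` performs global average pooling: with NHWC feature maps, averaging over the height and width axes leaves one value per channel. The quick shape check below uses random numbers in place of real backbone activations; the 7 × 7 × 2048 shape corresponds to ResNet-50's last stage for 224 × 224 inputs.

```python
import jax
import jax.numpy as jnp

# Stand-in for the backbone output: batch of 2, 7x7 grid, 2048 channels (NHWC)
feature_maps = jax.random.normal(jax.random.PRNGKey(0), (2, 7, 7, 2048))

# Global average pooling: average over height (axis 1) and width (axis 2)
pooled = jnp.mean(feature_maps, axis=(1, 2))
print(pooled.shape)  # (2, 2048)

# The same value computed by hand for the first image and first channel
manual = feature_maps[0, :, :, 0].sum() / (7 * 7)
print(jnp.allclose(pooled[0, 0], manual, atol=1e-5))
```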

Load pre-trained ResNet 50

Next, create a function that loads the pre-trained ResNet model. Omit the last two layers of the network because we have defined a custom head. The function returns the truncated ResNet backbone and the matching subset of the pre-trained variables, which is extracted with the slice_variables function.

def get_backbone_and_params(model_arch: str):
    '''
    Get backbone and params
    1. Loads pretrained model (resnet50)
    2. Get model and param structure except last 2 layers
    3. Extract the corresponding subset of the variables dict
    INPUT : model_arch
    RETURNS backbone , backbone_params
    '''
    if model_arch == 'resnet50':
        resnet_tmpl, params = pretrained_resnet(50)
        model = resnet_tmpl()
    else:
        raise NotImplementedError
    # get model & param structure for backbone (everything except the last 2 layers)
    start, end = 0, len(model.layers) - 2
    backbone = Sequential(model.layers[start:end])
    backbone_params = slice_variables(params, start, end)
    return backbone, backbone_params

Get model and variables

Use the above function to create the final model. Define a function that:

  • Initializes the network's input.
  • Obtains the ResNet backbone and its parameters.
  • Passes the input to the backbone and gets the output.
  • Initializes the network's head.
  • Creates the final model using backbone and head.
  • Combines the parameters from backbone and head.
def get_model_and_variables(model_arch: str, head_init_key: int):
    '''
    Get model and variables
    1. Initialise inputs(shape=(1,image_size,image_size,3))
    2. Get backbone and params
    3. Apply backbone model and get outputs
    4. Initialise head
    5. Create final model using backbone and head
    6. Combine params from backbone and head
    INPUT model_arch, head_init_key
    RETURNS  model, variables
    '''
    inputs = jnp.ones((1, config['IMAGE_SIZE'], config['IMAGE_SIZE'], 3), jnp.float32)
    backbone, backbone_params = get_backbone_and_params(model_arch)
    key = jax.random.PRNGKey(head_init_key)
    backbone_output = backbone.apply(backbone_params, inputs, mutable=False)
    head_inputs = jnp.ones((1, backbone_output.shape[-1]), jnp.float32)
    head = Head()
    head_params = head.init(key, head_inputs, train=False)
    # final model
    model = Model(backbone, head)
    # Group variables by collection, then by sub-model name
    variables = FrozenDict({
        'params': {
            'backbone': backbone_params['params'],
            'head': head_params['params']
        },
        'batch_stats': {
            'backbone': backbone_params['batch_stats'],
            'head': head_params['batch_stats']
        }
    })
    return model, variables

All backbone variables are stored under the key backbone, and all head variables under head. You can choose a different name, but it must be used consistently, because the freezing step later selects the parameters to freeze by this name.

Next, use the function defined above to create the model.

model, variables = get_model_and_variables('resnet50', 0)

Zero gradients

Since we are applying transfer learning, we need to ensure that the backbone is not updated. Otherwise, the pre-trained weights would be overwritten during training. We want to keep them as they are and use the backbone as a feature extractor. To achieve this, we freeze all parameters whose top-level name starts with backbone: an Optax transformation replaces their updates with zeros, so they do not change during training.

def zero_grads():
    '''
    Optax transformation that replaces every update with zeros,
    so the parameters it is applied to never change
    '''
    def init_fn(_):
        return ()
    def update_fn(updates, state, params=None):
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
    return optax.GradientTransformation(init_fn, update_fn)
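def create_mask(params, label_fn):
    def _map(params, mask, label_fn):
        for k in params:
            if label_fn(k):
                # Freeze this whole subtree (e.g. everything under 'backbone')
                mask[k] = 'zero'
            else:
                if isinstance(params[k], (dict, FrozenDict)):
                    # Nested module: recurse into its parameters
                    mask[k] = {}
                    _map(params[k], mask[k], label_fn)
                else:
                    # Leaf parameter outside the frozen subtree: train it
                    mask[k] = 'adam'
    mask = {}
    _map(params, mask, label_fn)
    return frozen_dict.freeze(mask)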
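To see the label tree `create_mask` builds, here is the same recursive labelling written for plain Python dictionaries and run on a made-up parameter tree. `make_labels` and the parameter names are illustrative assumptions; the real tree comes from `variables['params']`.

```python
def make_labels(params, label_fn):
    """Return a tree of 'zero'/'adam' labels shaped like params (plain dicts only)."""
    labels = {}
    for key, value in params.items():
        if label_fn(key):
            # A matching key labels its whole subtree at once
            labels[key] = "zero"
        elif isinstance(value, dict):
            # Recurse into nested modules
            labels[key] = make_labels(value, label_fn)
        else:
            # Leaf parameter outside the frozen part: train it
            labels[key] = "adam"
    return labels


toy_params = {
    "backbone": {"layers_0": {"kernel": 1.0}},
    "head": {"Dense_0": {"kernel": 2.0, "bias": 0.0}},
}
print(make_labels(toy_params, lambda s: s.startswith("backbone")))
# {'backbone': 'zero', 'head': {'Dense_0': {'kernel': 'adam', 'bias': 'adam'}}}
```

Because 'backbone' receives a single label instead of a subtree, the label tree is a prefix of the parameter tree; optax.multi_transform accepts such prefix trees and applies the label to every parameter underneath.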

Define Flax optimizer

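Create an optimizer that updates only the head and leaves the backbone untouched. This is done with optax.multi_transform, which applies a different transformation to each group of parameters: AdamW to the parameters labelled 'adam' and zero_grads to those labelled 'zero'.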

adamw = optax.adamw(
    learning_rate=config["LR"],
    b1=0.9, b2=0.999,
    eps=1e-6, weight_decay=1e-2
)
optimizer = optax.multi_transform(
    {'adam': adamw, 'zero': zero_grads()},
    create_mask(variables['params'], lambda s: s.startswith('backbone'))
)

Define Flax loss function

Next, define the loss function.

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=config["NUM_LABELS"])
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

When computing the loss during training, set train to True. You also have to:

  • Pass the current batch_stats.
  • Provide a PRNG key for the dropout layers.
  • Mark batch_stats as mutable so the BatchNorm running statistics can be updated.
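def compute_loss(params, batch_stats, images, labels):
    # With mutable=['batch_stats'], apply returns (logits, updated_variables),
    # where updated_variables is a dict of the form {'batch_stats': ...}
    logits, batch_stats = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        train=True,
        # A fixed key reuses the same dropout mask on every step
        rngs={'dropout': jax.random.PRNGKey(0)},
        mutable=['batch_stats'],
    )
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, batch_stats)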

Compute Flax metrics

Using the loss function, define a function that will return the loss and accuracy during training.

def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
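Accuracy is the fraction of examples whose highest logit matches the label. In this hand-made example (values chosen only for illustration), two of the three predictions are correct, so the accuracy is 2/3.

```python
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0],   # predicts class 0
                    [0.0, 3.0],   # predicts class 1
                    [1.0, 0.0]])  # predicts class 0
labels = jnp.array([0, 1, 1])

# argmax over the last axis gives the predicted class for each row
predictions = jnp.argmax(logits, -1)
accuracy = jnp.mean(predictions == labels)
print(predictions, accuracy)
```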

Create Flax training state

Flax provides a TrainState class that bundles the model's apply function, its parameters, the optimizer and the optimizer state. Subclass it to also store the batch_stats used by the BatchNorm layers.

class TrainState(train_state.TrainState):
    batch_stats: FrozenDict
state = TrainState.create(
    apply_fn = model.apply,
    params = variables['params'],
    tx = optimizer,
    batch_stats = variables['batch_stats'],
)

Training step

The training step receives the images and labels, computes the gradients of the loss with respect to the model parameters, and applies them. It then returns the new state and the batch metrics.

@jax.jit
def train_step(state: TrainState, images, labels):
  """Train for a single step."""
  (batch_loss, (logits, batch_stats)), grads = jax.value_and_grad(compute_loss, has_aux=True)(
      state.params, state.batch_stats, images, labels)
  # Update the parameters, then store the new BatchNorm running statistics
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats['batch_stats'])
  metrics = compute_metrics(logits=logits, labels=labels)
  return state, metrics

To train the network for one epoch, loop through the training data while applying the training step.

def train_one_epoch(state, dataloader):
    """Train for 1 epoch on the training set."""
    batch_metrics = []
    for images, labels in dataloader:
        # Scale pixel values from [0, 255] to [0, 1]
        images = images / 255.0
        state, metrics = train_step(state, images, labels)
        batch_metrics.append(metrics)

    # Move the per-batch metrics to the host and average them over the epoch
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    return state, epoch_metrics_np
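The dictionary comprehension at the end of `train_one_epoch` averages each metric over the batches of the epoch. Here it is run on a few hand-written metric dictionaries (illustrative values, not training results):

```python
import numpy as np

# Hypothetical per-batch metrics, in the format returned by compute_metrics
batch_metrics_np = [
    {"loss": 0.9, "accuracy": 0.50},
    {"loss": 0.7, "accuracy": 0.75},
    {"loss": 0.5, "accuracy": 1.00},
]

# One mean per metric name, taken across all batches
epoch_metrics_np = {
    k: np.mean([metrics[k] for metrics in batch_metrics_np])
    for k in batch_metrics_np[0]
}
print(epoch_metrics_np)
```

This is an unweighted mean over batches, so a smaller final batch would count as much as a full one. With 20,000 training images and a batch size of 32, every batch is full, so the distinction does not matter here.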

Evaluation step

The evaluation step accepts the test images and labels, passes them through the network, and returns the evaluation metrics. During evaluation, set the train parameter to False: dropout is then disabled and the BatchNorm layers use the stored batch_stats, so you still need to pass them. A dropout PRNG key is passed as well, although it is not used when train is False.

def eval_step(batch_stats, params, images, labels):
    # train=False: dropout is disabled and BatchNorm uses its running averages
    logits = model.apply({'params': params, 'batch_stats': batch_stats}, images, train=False, rngs={'dropout': jax.random.PRNGKey(0)})
    return compute_metrics(logits=logits, labels=labels)


def evaluate_model(state, test_imgs, test_lbls):
    """Evaluate on the validation set."""
    metrics = eval_step(state.batch_stats, state.params, test_imgs, test_lbls)
    metrics = jax.device_get(metrics)
    metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return metrics

Train ResNet model in Flax

Train the ResNet model by applying the train_one_epoch function for the desired number of epochs. Because only the head is trained, a few epochs are usually enough.

Set up TensorBoard in Flax

To monitor model training via TensorBoard, you can write the training and validation metrics to TensorBoard.

from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
logdir = "flax_logs"
writer = SummaryWriter(logdir)

Train model

Define a function to train and evaluate the model while writing the metrics to TensorBoard.

(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []

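def train_model(epochs):
    model_state = state  # start from the TrainState created above
    for epoch in range(1, epochs + 1):
        # Continue from the state returned by the previous epoch
        model_state, train_metrics = train_one_epoch(model_state, train_loader)

        test_metrics = evaluate_model(model_state, test_images, test_labels)
        writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
        writer.add_scalar('Loss/test', test_metrics['loss'], epoch)
        writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
        writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)
        # Keep a per-epoch history for plotting with Matplotlib
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])
        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, training accuracy: {train_metrics['accuracy'] * 100}, validation loss: {test_metrics['loss']}, validation accuracy: {test_metrics['accuracy'] * 100}")
    return model_state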

Run the training function.

trained_model_state = train_model(config["N_EPOCHS"])

Save Flax model

Use checkpoints.save_checkpoint from flax.training to save a trained Flax model.

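from flax.training import checkpoints
# Recent Flax versions (Orbax backend) may reject relative checkpoint paths
ckpt_dir = os.path.abspath('model_checkpoint/')
checkpoints.save_checkpoint(ckpt_dir=ckpt_dir, target=trained_model_state, step=config["N_EPOCHS"], overwrite=True)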

Load saved Flax model

A saved Flax model is loaded using the restore_checkpoint method.

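# target=state restores the saved values into the same TrainState structure
loaded_model = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=state)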

Evaluate Flax ResNet model

To evaluate the restored model, pass it to the evaluate_model function together with the test images and labels.

evaluate_model(loaded_model,test_images, test_labels)

Visualize model performance

You can check the network's performance via TensorBoard or plot the metrics using Matplotlib.
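To plot the metrics with Matplotlib, a small helper like the one below draws one metric per epoch for training and validation. `plot_history` is a hypothetical helper, and it assumes the `training_loss`, `testing_loss`, `training_accuracy` and `testing_accuracy` lists hold one value per epoch.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend; drop this line inside a notebook
import matplotlib.pyplot as plt


def plot_history(train_values, test_values, metric_name):
    """Hypothetical helper: plot one metric per epoch for training and validation."""
    epochs = range(1, len(train_values) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label=f"training {metric_name}")
    ax.plot(epochs, test_values, label=f"validation {metric_name}")
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric_name)
    ax.legend()
    return fig, ax


# Usage, once the lists hold one value per epoch:
# plot_history(training_loss, testing_loss, "loss")
# plot_history(training_accuracy, testing_accuracy, "accuracy")
```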

Final thoughts

You can apply transfer learning to take advantage of pre-trained models and get results with minimal effort. You have learned how to fine-tune a ResNet model in Flax. Specifically, you have covered:

  • How to define the ResNet model in Flax.
  • How to freeze the layers of the ResNet network.
  • Training a ResNet model on custom data in Flax.
  • Saving and loading a ResNet model in Flax.



Derrick Mwiti

Google Developer Expert - Machine Learning