
Train ResNet in Flax from scratch (Distributed ResNet training)
Apart from designing custom CNN architectures, you can use architectures that have already been built. ResNet is one such popular architecture. In most cases, you'll achieve better performance by using such architectures. In this article, you will learn how to perform distributed training of a ResNet model in Flax.
Install Flax models
The `flaxmodels` package provides pre-trained models for JAX and Flax, including:
- StyleGAN2
- GPT2
- VGG
- ResNet
git clone https://github.com/matthias-wright/flaxmodels.git
pip install -r flaxmodels/training/resnet/requirements.txt
In this project, we will train the model from scratch, meaning that we will not use the pre-trained weights. In a separate article, we have covered how to perform transfer learning with ResNet.
👉 Check our Transfer learning with JAX & Flax tutorial.
Perform standard imports
With `flaxmodels` installed, let's import the standard libraries used in this article.
import wget # pip install wget
import zipfile
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax
import time
from tqdm.notebook import tqdm
import math
from flax import jax_utils
Download dataset
We will train the ResNet model to predict two classes from the cats and dogs dataset. Download and extract the cat and dog images.
wget.download("https://ml.machinelearningnuggets.com/train.zip")
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')
Loading dataset in Flax
Since JAX and Flax don't ship with any data loaders, we use data loading utilities from PyTorch or TensorFlow. When using PyTorch, we start by creating a dataset class.
class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))
        if self.transform is not None:
            img = self.transform(img)
        return (img, y_label)
👉 Check our How to load datasets in JAX with TensorFlow tutorial.
Next, create a Pandas DataFrame containing the image paths and labels.
train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")
for idx, i in enumerate(os.listdir("train/")):
    if "cat" in i:
        train_df.loc[idx, "label"] = 0
    if "dog" in i:
        train_df.loc[idx, "label"] = 1
train_df.to_csv('train_csv.csv', index=False, header=True)
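As a quick sanity check (the exact counts depend on your copy of the dataset), verify that both classes appear in the DataFrame:
print(train_df["label"].value_counts())  # counts of cat (0) and dog (1) images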
Data transformation in Flax
Define a function that will stack the data and return it as a NumPy array.
def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
Create a transformation for resizing the images. Next, apply the transformation to the dataset created earlier.
size_image = 224
transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array])
dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)
Split this dataset into a training and testing set and create data loaders for each set.
batch_size = 32
train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)
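To confirm that the loaders and the collate function work together, you can pull one batch and check the shapes; with the settings above they should be (32, 224, 224, 3) for the images and (32,) for the labels:
sample_imgs, sample_labels = next(iter(train_loader))
print(sample_imgs.shape, sample_labels.shape)  # (32, 224, 224, 3) (32,)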
Instantiate Flax ResNet model
With the data in place, instantiate the Flax ResNet model using the `flaxmodels` package. The instantiation requires:
- The desired number of classes.
- The type of output.
- The data type.
- Whether the model is pre-trained, in this case `None` since we train from scratch.
import jax.numpy as jnp
import flaxmodels as fm
num_classes = 2
dtype = jnp.float32
model = fm.ResNet50(output='log_softmax', pretrained=None, num_classes=num_classes, dtype=dtype)
Compute metrics
Define the metrics for evaluating the model during training. Let's start by creating the loss function.
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
Next, define a function that computes and returns the loss and accuracy.
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
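As a minimal sketch with made-up values, the function expects logits of shape (batch, num_classes) and integer labels:
dummy_logits = jnp.array([[0.2, 0.8], [2.0, 0.5]])
dummy_labels = jnp.array([1, 0])
print(compute_metrics(logits=dummy_logits, labels=dummy_labels))
# accuracy should be 1.0 here: both argmax predictions match the labels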
Create Flax model training state
Flax provides a training state for storing training information. The training state can be modified to add new information. In this case, we need to alter the training state to add the batch statistics, since the ResNet model computes `batch_stats`.
class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict
We need the model parameters and batch statistics to create the training state. We can access these by initializing the model with `train` set to `False`.
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), train=False)
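The returned `variables` is a frozen dictionary holding two collections, which you can confirm by inspecting its keys:
print(list(variables.keys()))  # expected: ['params', 'batch_stats']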
The `create` method of `TrainState` expects the following parameters:
- The `apply_fn`: the model apply function.
- The model parameters: `variables['params']`.
- The optimizer, usually defined using Optax.
- The batch statistics: `variables['batch_stats']`.
We apply `pmap` to this function to create a distributed version of the training state. `pmap` compiles the function for execution on multiple devices, such as multiple GPUs or TPU cores.
import functools

@functools.partial(jax.pmap)
def create_train_state(rng):
    """Creates initial `TrainState`."""
    return TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(0.01, 0.9),
        batch_stats=variables['batch_stats'])
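Because `create_train_state` is wrapped in `pmap`, it must be called with one PRNG key per device (the leading axis must match the device count), which yields a replicated training state. A minimal sketch, assuming it runs right after the definitions above:
rng = jax.random.PRNGKey(0)
device_rngs = jax.random.split(rng, jax.device_count())  # one key per device
state = create_train_state(device_rngs)  # replicated TrainState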
👉 Check our Handling state in JAX and Flax tutorial.
Apply model function
Next, define a parallel model training function. Pass an `axis_name` so you can use it to aggregate the metrics from all the devices. The function:
- Computes the loss.
- Computes predictions from all devices by averaging the probabilities with `jax.lax.pmean()`.
When applying the model, we also include the batch statistics and the random number generator for dropout. Since this is the training function, the `train` parameter is `True`. The `batch_stats` are also included when computing the gradients. The `update_model` function applies the computed gradients, updating the model parameters.
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
    def loss_fn(params, batch_stats):
        # `mutable=['batch_stats']` makes the apply call return the updated
        # batch statistics alongside the logits.
        logits, batch_stats = model.apply(
            {'params': params, 'batch_stats': batch_stats},
            images,
            train=True,
            rngs={'dropout': jax.random.PRNGKey(0)},
            mutable=['batch_stats'])
        one_hot = jax.nn.one_hot(labels, num_classes)
        loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
        return loss, (logits, batch_stats)

    (loss, (logits, batch_stats)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params, state.batch_stats)
    # Average the predicted probabilities across all devices.
    probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
    accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
    return grads, loss, accuracy
@jax.pmap
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
TensorBoard in Flax
The next step is to train the ResNet model. However, you might be interested in tracking the training using TensorBoard. In that case, you have to configure TensorBoard. You can write the metrics to TensorBoard using the PyTorch `SummaryWriter`.
rm -rf ./flax_logs/
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
logdir = "flax_logs"
writer = SummaryWriter(logdir)
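As a quick check that logging works (the tag and values here are placeholders; the training loop sketched at the end logs the real metrics), write a scalar and view it with `tensorboard --logdir flax_logs`:
writer.add_scalar('Loss/train', 0.69, 0)  # tag, scalar value, global step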
👉 Check our How to use TensorBoard in JAX & Flax tutorial.
Train Flax ResNet model
Let's train the ResNet model on the entire training set and evaluate it on a subset of the test set. You can also evaluate it on the whole test set. Replicate the test set to the available devices.
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0
test_images = np.array(jax_utils.replicate(test_images))
test_labels = np.array(jax_utils.replicate(test_labels))
Create some lists to hold the training and evaluation metrics.
epoch_loss = []
epoch_accuracy = []
testing_accuracy = []
testing_loss = []
Next, define the ResNet model training function (a minimal sketch follows the list). The function does the following:
- Loops through the training dataset and scales it.
- Replicates the data on the available devices.
- Applies the model to the data and computes the metrics.
- Obtains the metrics from the devices using `jax_utils.unreplicate`.
- Appends the metrics to a list.
- Computes the mean of the loss and accuracy to obtain the metrics for each epoch.
- Applies the model to the test set and obtains the metrics.
- Appends the test metrics to a list.
- Writes the training and evaluation metrics to TensorBoard.
- Prints the training and evaluation metrics.
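A minimal sketch of these steps, assuming the objects defined earlier (`train_loader`, `apply_model`, `update_model`, `writer`, the metric lists, and the replicated `state`); the function name `train_one_epoch` and the epoch count are illustrative, not canonical:
def train_one_epoch(state, epoch):
    batch_loss, batch_accuracy = [], []
    for imgs, labels in tqdm(train_loader):
        # Scale the pixel values and replicate the batch to every device.
        imgs = jax_utils.replicate(imgs / 255.0)
        labels = jax_utils.replicate(labels)
        grads, loss, accuracy = apply_model(state, imgs, labels)
        state = update_model(state, grads)
        # The per-device copies are identical, so unreplicate one of them.
        batch_loss.append(jax_utils.unreplicate(loss))
        batch_accuracy.append(jax_utils.unreplicate(accuracy))

    # Epoch-level training metrics.
    train_loss, train_acc = np.mean(batch_loss), np.mean(batch_accuracy)
    epoch_loss.append(train_loss)
    epoch_accuracy.append(train_acc)

    # Evaluate on the replicated test batch.
    _, test_loss, test_acc = apply_model(state, test_images, test_labels)
    test_loss = jax_utils.unreplicate(test_loss)
    test_acc = jax_utils.unreplicate(test_acc)
    testing_loss.append(test_loss)
    testing_accuracy.append(test_acc)

    # Log to TensorBoard and print a summary.
    writer.add_scalar('Loss/train', float(train_loss), epoch)
    writer.add_scalar('Accuracy/train', float(train_acc), epoch)
    writer.add_scalar('Loss/test', float(test_loss), epoch)
    writer.add_scalar('Accuracy/test', float(test_acc), epoch)
    print(f'epoch {epoch}: loss {train_loss:.4f}, accuracy {train_acc:.4f}, '
          f'test loss {test_loss:.4f}, test accuracy {test_acc:.4f}')
    return state

epochs = 5  # illustrative epoch count
for epoch in range(1, epochs + 1):
    state = train_one_epoch(state, epoch)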