Train ResNet in Flax from scratch (Distributed ResNet training)
Apart from designing custom CNN architectures, you can use architectures that have already been built. ResNet is one such popular architecture. In most cases, you'll achieve better performance by using such architectures. In this article, you will learn how to perform distributed training of a ResNet model in Flax.
Install Flax models
The `flaxmodels` package provides pre-trained models for JAX and Flax, including:
- StyleGAN2
- GPT2
- VGG
- ResNet
```bash
git clone https://github.com/matthias-wright/flaxmodels.git
pip install -r flaxmodels/training/resnet/requirements.txt
```
In this project, we will train the model from scratch, meaning that we will not use the pre-trained weights. In a separate article, we have covered how to perform transfer learning with ResNet.
👉 Check our Transfer learning with JAX & Flax tutorial.
Perform standard imports
With `flaxmodels` installed, let's import the standard libraries used in this article.
```python
import wget  # pip install wget
import zipfile
import torch
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
from torchvision import transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")

import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
from flax import jax_utils
from flax.training import train_state
import optax
import time
from tqdm.notebook import tqdm
import math
```
Download dataset
We will train the ResNet model to predict two classes from the cats and dogs dataset. Download and extract the cat and dog images.
```python
wget.download("https://ml.machinelearningnuggets.com/train.zip")

with zipfile.ZipFile('train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')
```
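If you'd like to confirm the extraction worked, a quick check such as the one below should list the image files; the total of 25,000 matches the train/validation split used later in this article.

```python
# Quick sanity check on the extracted archive.
files = os.listdir("train/")
print(len(files))  # expected: 25000
print(sum("cat" in f for f in files), "cats,", sum("dog" in f for f in files), "dogs")
```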
Loading dataset in Flax
Since JAX and Flax don't ship with any data loaders, we use data loading utilities from PyTorch or TensorFlow. When using PyTorch, we start by creating a dataset class.
```python
class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)
```
👉 Check our How to load datasets in JAX with TensorFlow tutorial.
Next, create a Pandas DataFrame containing the image paths and labels.
```python
train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")

for idx, i in enumerate(os.listdir("train/")):
    if "cat" in i:
        train_df.loc[idx, "label"] = 0
    if "dog" in i:
        train_df.loc[idx, "label"] = 1

train_df.to_csv(r'train_csv.csv', index=False, header=True)
```
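As a quick sanity check, you can confirm the labels were assigned and the classes are balanced:

```python
print(train_df.head())
# A roughly even split between the two classes is expected here.
print(train_df["label"].value_counts())
```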
Data transformation in Flax
Define a function that will stack the data and return it as a NumPy array.
```python
def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
```
Create a transformation for resizing the images. Next, apply the transformation to the dataset created earlier.
```python
size_image = 224
transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array])

dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)
```
Split this dataset into a training and testing set and create data loaders for each set.
```python
batch_size = 32
train_set, validation_set = torch.utils.data.random_split(dataset, [20000, 5000])

train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn,
                          shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn,
                               shuffle=False, batch_size=batch_size)
```
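Before moving on, it helps to pull a single batch and confirm that the collate function returns NumPy arrays with the shapes JAX expects (a quick check, not part of the training pipeline):

```python
imgs, labels = next(iter(train_loader))
print(type(imgs), imgs.shape)   # <class 'numpy.ndarray'> (32, 224, 224, 3)
print(labels.shape)             # (32,)
```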
Instantiate Flax ResNet model
With the data in place, instantiate the Flax ResNet model using the `flaxmodels` package. The instantiation requires:
- The desired number of classes.
- The type of output.
- The data type.
- Whether the model is pre-trained; here we pass `pretrained=None` to train from scratch.
```python
import jax.numpy as jnp
import flaxmodels as fm

num_classes = 2
dtype = jnp.float32

model = fm.ResNet50(output='log_softmax', pretrained=None,
                    num_classes=num_classes, dtype=dtype)
```
Compute metrics
Define the metrics for evaluating the model during training. Let's start by creating the loss function.
```python
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
```
Next, define a function that computes and returns the loss and accuracy.
```python
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
```
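As a quick illustration, here is `compute_metrics` on a toy batch of hand-picked logits (dummy values, purely for illustration):

```python
# Two samples, two classes; both argmax predictions match the labels,
# so accuracy should be 1.0.
dummy_logits = jnp.array([[2.0, 0.5],
                          [0.1, 1.5]])
dummy_labels = jnp.array([0, 1])
print(compute_metrics(logits=dummy_logits, labels=dummy_labels))
```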
Create Flax model training state
Flax provides a training state for storing training information. The training state can be modified to add new information. In this case, we need to alter the training state to add the batch statistics, since the ResNet model computes `batch_stats`.
```python
class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict
```
We need the model parameters and batch statistics to create the training state. We can access these by initializing the model with `train` set to `False`.
```python
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), train=False)
```
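Before building the training state, you can sanity-check the initialization by running a dummy batch through the network. This is a quick check, assuming the same call signature used in `model.init` above; with `output='log_softmax'` the result is a row of log-probabilities:

```python
dummy_batch = jnp.ones([1, size_image, size_image, 3])
out = model.apply(variables, dummy_batch, train=False)
print(out.shape)  # (1, 2): log-probabilities for the two classes
```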
The `create` method of `TrainState` expects the following parameters:
- The `apply_fn`: the model apply function.
- The model parameters: `variables['params']`.
- The optimizer, usually defined using Optax.
- The batch statistics: `variables['batch_stats']`.
We apply `pmap` to this function to create a distributed version of the training state. `pmap` compiles the function for execution on multiple devices, such as multiple GPUs or TPUs.
```python
import functools

@functools.partial(jax.pmap)
def create_train_state(rng):
    """Creates initial `TrainState`."""
    # `rng` is unused in the body; it gives pmap one input per device
    # so the state is created on every device.
    return TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(0.01, 0.9),
        batch_stats=variables['batch_stats'])
```
👉 Check our Handling state in JAX and Flax tutorial.
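If you are new to `pmap`, a tiny standalone example makes its behavior concrete: the function runs once per device, with each device receiving one slice of the leading axis of the input.

```python
# A toy pmap example, separate from the training code.
n_devices = jax.device_count()
print("devices:", n_devices)

# The leading axis of the input must equal the device count.
out = jax.pmap(lambda x: x * 2)(jnp.arange(n_devices))
print(out)  # one result per device
```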
Apply model function
Next, define a parallel model training function. Pass an `axis_name` so you can use it to aggregate the metrics from all the devices. The function:
- Computes the loss.
- Computes predictions from all devices by averaging the probabilities using `jax.lax.pmean()`.

When applying the model, we also include the batch statistics and the random number for `Dropout`. Since this is the training function, the `train` parameter is `True`. The `batch_stats` are also included when computing the gradients. The `update_model` function applies the computed gradients, updating the model parameters.
```python
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
    def loss_fn(params, batch_stats):
        logits, batch_stats = model.apply(
            {'params': params, 'batch_stats': batch_stats},
            images, train=True,
            rngs={'dropout': jax.random.PRNGKey(0)},
            mutable=['batch_stats'])
        one_hot = jax.nn.one_hot(labels, num_classes)
        loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
        return loss, (logits, batch_stats)

    (loss, (logits, batch_stats)), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(state.params, state.batch_stats)
    probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
    accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
    return grads, loss, accuracy


@jax.pmap
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
```
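Note that in this tutorial the same batch is replicated to every device, so the per-device gradients are identical and no cross-device gradient averaging is needed. In a true data-parallel setup, where each device receives a different shard of the batch, you would average the gradients across devices with `jax.lax.pmean` before the update. Here is a sketch of that variant (an alternative, not the approach used in this article):

```python
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model_sharded(state, images, labels):
    # Same loss as above, but each device sees a *different* slice of the batch.
    def loss_fn(params, batch_stats):
        logits, new_batch_stats = model.apply(
            {'params': params, 'batch_stats': batch_stats},
            images, train=True,
            rngs={'dropout': jax.random.PRNGKey(0)},
            mutable=['batch_stats'])
        one_hot = jax.nn.one_hot(labels, num_classes)
        loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state.params, state.batch_stats)
    # Average gradients (and metrics, for logging) across devices.
    grads = jax.lax.pmean(grads, axis_name='ensemble')
    loss = jax.lax.pmean(loss, axis_name='ensemble')
    accuracy = jax.lax.pmean(
        jnp.mean(jnp.argmax(logits, -1) == labels), axis_name='ensemble')
    return grads, loss, accuracy
```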
TensorBoard in Flax
The next step is to train the ResNet model. However, you might be interested in tracking the training using TensorBoard. In that case, you have to configure TensorBoard. You can write the metrics to TensorBoard using the PyTorch `SummaryWriter`.
```python
!rm -rf ./flax_logs/

from torch.utils.tensorboard import SummaryWriter

logdir = "flax_logs"
writer = SummaryWriter(logdir)
```
👉 Check our How to use TensorBoard in JAX & Flax tutorial.
Train Flax ResNet model
Let's train the ResNet model on the entire training set and evaluate it on a subset of the test set. You can also evaluate it on the whole test set. Take one batch from the validation loader to serve as the test set and replicate it across the available devices.
```python
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

test_images = np.array(jax_utils.replicate(test_images))
test_labels = np.array(jax_utils.replicate(test_labels))
```
Create some lists to hold the training and evaluation metrics.
```python
epoch_loss = []
epoch_accuracy = []
testing_accuracy = []
testing_loss = []
```
Next, define the ResNet model training function. The function does the following:
- Loops through the training dataset and scales the images.
- Replicates the data on the available devices.
- Applies the model to the dataset and computes the metrics.
- Obtains the metrics from the devices using `jax_utils.unreplicate`.
- Appends the metrics to a list.
- Computes the mean of the loss and accuracy to obtain the metrics for each epoch.
- Applies the model to the test set and obtains the metrics.
- Appends the test metrics to a list.
- Writes the training and evaluation metrics to TensorBoard.
- Prints the training and evaluation metrics.
```python
def train_one_epoch(state, dataloader, num_epochs):
    """Train for `num_epochs` epochs on the training set."""
    for epoch in range(num_epochs):
        for cnt, (images, labels) in tqdm(enumerate(dataloader),
                                          total=(math.ceil(len(train_set) / batch_size))):
            images = images / 255.0
            images = jax_utils.replicate(images)
            labels = jax_utils.replicate(labels)

            grads, loss, accuracy = apply_model(state, images, labels)
            state = update_model(state, grads)

            epoch_loss.append(jax_utils.unreplicate(loss))
            epoch_accuracy.append(jax_utils.unreplicate(accuracy))

        train_loss = np.mean(epoch_loss)
        train_accuracy = np.mean(epoch_accuracy)

        _, test_loss, test_accuracy = jax_utils.unreplicate(
            apply_model(state, test_images, test_labels))
        testing_accuracy.append(test_accuracy)
        testing_loss.append(test_loss)

        writer.add_scalar('Loss/train', np.array(train_loss), epoch)
        writer.add_scalar('Loss/test', np.array(test_loss), epoch)
        writer.add_scalar('Accuracy/train', np.array(train_accuracy), epoch)
        writer.add_scalar('Accuracy/test', np.array(test_accuracy), epoch)

        print(f"Epoch: {epoch + 1}, "
              f"train loss: {train_loss:.4f}, "
              f"train accuracy: {train_accuracy * 100:.4f}, "
              f"test loss: {test_loss:.4f}, "
              f"test accuracy: {test_accuracy * 100:.4f}", flush=True)

    return state, epoch_loss, epoch_accuracy, testing_accuracy, testing_loss
```
Create the training state by generating one random key for each available device.
```python
seed = 0
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)

state = create_train_state(jax.random.split(init_rng, jax.device_count()))
del init_rng  # Must not be used anymore.
```
Train the ResNet model by passing the training data and the number of epochs.
```python
start = time.time()

num_epochs = 30
state, epoch_loss, epoch_accuracy, testing_accuracy, testing_loss = train_one_epoch(
    state, train_loader, num_epochs)

print("Total time: ", time.time() - start, "seconds")
```
Evaluate model with TensorBoard
Run TensorBoard to see the logged scalars.
```python
%load_ext tensorboard
%tensorboard --logdir={logdir}
```
👉 Check our TensorBoard tutorial (Deep dive with examples and notebook).
Visualize Flax model performance
The metrics that were stored in lists can be plotted using Matplotlib. Note that the training metrics were appended once per batch while the test metrics were appended once per epoch, so the two curves are on different x-axis scales.
```python
plt.plot(epoch_accuracy, label="Training")
plt.plot(testing_accuracy, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
```
```python
plt.plot(epoch_loss, label="Training")
plt.plot(testing_loss, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
```
Save Flax ResNet model
To save the trained Flax ResNet model, use the `save_checkpoint` function. The function expects:
- The folder where the ResNet model will be saved.
- The model to be saved: `target`.
- The step: the training step number.
- The model prefix.
- Whether to overwrite existing models.
```python
!pip install tensorstore

from flax.training import checkpoints

ckpt_dir = 'model_checkpoint/'

checkpoints.save_checkpoint(ckpt_dir=ckpt_dir,
                            target=state,
                            step=100,
                            prefix='flax_model',
                            overwrite=True)
```
Load Flax ResNet model
The saved ResNet Flax model can also be loaded to make predictions. Flax models are loaded using the `restore_checkpoint` function. The function expects:
- The target state.
- The folder containing the saved model.
- The model's prefix.
```python
loaded_model = checkpoints.restore_checkpoint(
    ckpt_dir=ckpt_dir,
    target=state,
    prefix='flax_model')
```
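With the checkpoint restored, you can use it for inference in the same way the model is applied during evaluation. Here is a minimal sketch, assuming the saved state holds replicated per-device parameters (as created by the pmapped `create_train_state` above), so we take a single copy first:

```python
# Take one copy of the replicated parameters for single-device inference.
single_state = jax_utils.unreplicate(loaded_model)

imgs, labels = next(iter(validation_loader))
imgs = jnp.array(imgs / 255.0)

log_probs = model.apply(
    {'params': single_state.params, 'batch_stats': single_state.batch_stats},
    imgs, train=False)
preds = jnp.argmax(log_probs, axis=-1)
print("predicted:", preds[:5])
print("actual:   ", labels[:5])
```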
Final thoughts
In this article, you have learned how to train a ResNet model from scratch in Flax. In particular, you have covered:
- Creating a ResNet model in Flax.
- Defining the training state for the ResNet Flax model.
- Training the Flax ResNet model in a distributed manner.
- Tracking the performance of the Flax ResNet model with TensorBoard.
- Saving and loading the Flax ResNet model.
Interested in other JAX and Flax tutorials? Check out other articles from our blog:
- What is JAX?
- Flax vs. TensorFlow
- JAX loss functions
- Transfer learning with JAX & Flax
- Activation functions in JAX and Flax
- Optimizers in JAX and Flax
- Handling state in JAX & Flax (BatchNorm and DropOut layers)
- How to load datasets in JAX using TensorFlow
- Building Convolutional Neural Networks in JAX and Flax
- Distributed training in JAX
- Using TensorBoard in JAX and Flax
- LSTM in JAX & Flax
- Elegy (High-level API for deep learning in JAX & Flax)
Follow us on LinkedIn, Twitter, GitHub, and subscribe to our blog, so you don't miss a new issue.