Flax vs. TensorFlow

Derrick Mwiti


Flax is a neural network library for JAX. TensorFlow is a deep learning library with a large ecosystem of tools and resources. The two libraries overlap in many ways but differ in others. For instance, both can run on XLA: JAX (and therefore Flax) compiles computations with XLA, and TensorFlow can use XLA as an optional compiler.

Let's look at the differences between Flax and TensorFlow from my perspective as a user of both libraries.

Random number generation in TensorFlow and Flax

In TensorFlow, you can set global or operation-level seeds, which makes generating random numbers quite straightforward.
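As a minimal sketch of TensorFlow's seeding options (the seed values are arbitrary), here is a global seed, an operation-level seed, and a stateless op whose output depends only on the seed you pass in:

```python
import tensorflow as tf

# Global seed: makes the sequence of random ops reproducible across program runs.
tf.random.set_seed(42)
a = tf.random.uniform([2])
b = tf.random.uniform([2])  # a new draw from the same global stream

# Operation-level seed, combined with the global seed.
c = tf.random.uniform([2], seed=1)

# Stateless ops: the output depends only on the explicit seed pair,
# which is conceptually close to JAX's PRNG keys.
d1 = tf.random.stateless_uniform([2], seed=[1, 2])
d2 = tf.random.stateless_uniform([2], seed=[1, 2])
print(a.numpy(), b.numpy(), c.numpy(), d1.numpy(), d2.numpy())
```

Note that consecutive calls with the global seed still return different values; only the stateless op returns identical values for identical seeds.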


Flax, however, works differently. Flax is built on top of JAX, and JAX expects pure functions, meaning functions without side effects. To achieve this, JAX uses stateless pseudo-random number generators (PRNGs) driven by explicit keys. Compare that with NumPy's stateful generator: calling it repeatedly returns a different number every time.

import numpy as np
print(np.random.random(), np.random.random())  # stateful: two different values

In JAX and Flax, calling a random function with the same key returns the same result every time. Random numbers are therefore generated from an explicit random state (a key). A key should not be reused; instead, split it to obtain new keys for several pseudo-random numbers.

import jax 
key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, num=3)
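To make the key semantics concrete, here is a small sketch (the shapes are arbitrary) showing that reusing a key reproduces the same numbers, while a freshly split subkey gives new ones:

```python
import jax

key = jax.random.PRNGKey(0)

# Same key -> same numbers, no matter how many times you call it.
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))

# For fresh randomness, split the key and use each subkey only once.
key, subkey = jax.random.split(key)
x3 = jax.random.normal(subkey, (3,))
print(x1, x2, x3)
```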

Model definition in Flax and TensorFlow

The Keras API makes model definition in TensorFlow easy. You can use Keras to define Sequential or Functional models, and it ships with many layers for building various types of networks, such as CNNs and LSTMs.

Read more: How to build TensorFlow models with the Keras Functional API
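For comparison with the Flax examples below, here is a minimal sketch of one small network written with both Keras APIs; the layer sizes and the 4-feature input are arbitrary choices for illustration:

```python
import tensorflow as tf

# Sequential API: a plain stack of layers.
seq_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(2),
])

# Functional API: the same network, built by calling layers on tensors.
inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(32, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(2)(x)
func_model = tf.keras.Model(inputs=inputs, outputs=outputs)

print(seq_model(tf.zeros((1, 4))).shape, func_model(tf.zeros((1, 4))).shape)
```

The Functional API is the one to reach for when a network has multiple inputs, multiple outputs, or branches.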

In Flax, networks are defined either with a setup method or with the compact decorator (nn.compact). The setup approach declares submodules explicitly, while the compact approach declares them inline inside __call__. Setup is very similar to how networks are defined in PyTorch. For example, here is a network defined the setup way.

import flax.linen as nn

class MLP(nn.Module):
  def setup(self):
    # Submodule names are derived from the attributes you assign to. In this
    # case, "dense1" and "dense2". This follows the logic in PyTorch.
    self.dense1 = nn.Dense(32)
    self.dense2 = nn.Dense(32)

  def __call__(self, x):
    x = self.dense1(x)
    x = nn.relu(x)
    x = self.dense2(x)
    return x

Here's the same network defined the compact way. It is more concise because each submodule is declared where it is used, so there is less duplicated code.

import flax.linen as nn

class MLP(nn.Module):
  @nn.compact  # required so that submodules can be declared inline
  def __call__(self, x):
    x = nn.Dense(32, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(32, name="dense2")(x)
    return x

Check out our Image classification with JAX & Flax article to learn more about designing networks in JAX and Flax.

Activations in Flax and TensorFlow

The tf.keras.activations module in TensorFlow provides most of the activation functions needed when designing networks. In Flax, activation functions such as nn.relu and nn.gelu are available via the flax.linen module.

Read more: Activation functions in JAX and Flax

Optimizers in Flax and TensorFlow

The tf.keras.optimizers module in TensorFlow provides popular optimizers such as Adam and SGD. Flax, however, doesn't ship its own optimizers. The optimizers used with Flax come from a separate library called Optax.

Read more: Optimizers in JAX and Flax

Metrics in Flax and TensorFlow

In TensorFlow, metrics are available via the tf.keras.metrics module. As of this writing, Flax has no metrics module. You'll need to define metric functions for your networks or use other third-party libraries.

import jax
import jax.numpy as jnp
import optax

def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=2)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics

Read more: JAX loss functions

Computing gradients in Flax and TensorFlow

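The jax.grad function is used to compute gradients in Flax. It can also return auxiliary data: with has_aux=True, the differentiated function returns a (value, auxiliary data) pair, and jax.grad returns the gradient together with the auxiliary data. To get the loss and the gradients at the same time, use jax.value_and_grad.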

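import jax
import jax.numpy as jnp

def sum_logistic(x):
  # Returns (value, aux): the sum of sigmoids and some auxiliary data (x + 1).
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print(derivative_fn(x_small))  # (gradient, aux)
# Output (older JAX versions print arrays as DeviceArray):
# (DeviceArray([0.25      , 0.19661194, 0.10499357, 0.04517666, 0.01766271,
#            0.00664806], dtype=float32), DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float32))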

Advanced automatic differentiation can also be done using jax.vjp() and jax.jvp().
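A minimal sketch of these transformations on the same logistic function (without auxiliary data): jax.value_and_grad returns the value and the gradient in one call, jax.jvp runs forward mode, and jax.vjp runs reverse mode:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(3.)

# Value and gradient in a single call.
value, grad = jax.value_and_grad(sum_logistic)(x)

# Forward mode: Jacobian-vector product along a tangent direction (all ones here).
primal_out, tangent_out = jax.jvp(sum_logistic, (x,), (jnp.ones_like(x),))

# Reverse mode: vector-Jacobian product. For a scalar output,
# pulling back 1.0 gives the gradient.
out, vjp_fn = jax.vjp(sum_logistic, x)
(grad_vjp,) = vjp_fn(jnp.ones_like(out))
print(value, grad, tangent_out, grad_vjp)
```

With a tangent of all ones, the JVP equals the sum of the gradient entries, and the VJP reproduces the gradient itself.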

Read more: JAX (What it is and how to use it in Python)

In TensorFlow, gradients are computed using tf.GradientTape.

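import tensorflow as tf

def grad(model, inputs, targets):
  # `loss` is a user-defined function that returns the scalar loss (not shown here).
  with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets, training=True)
  return loss_value, tape.gradient(loss_value, model.trainable_variables)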

Unless you write a custom training loop in TensorFlow, you don't need to define a gradient function: Keras computes the gradients automatically when you train the network with model.fit().
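For a self-contained view of what tf.GradientTape records, here is a minimal sketch that differentiates a hypothetical one-variable loss, w squared, at w = 3:

```python
import tensorflow as tf

w = tf.Variable(3.0)

with tf.GradientTape() as tape:
  # Operations on trainable variables inside the context are recorded.
  loss = w ** 2

# d(w^2)/dw = 2w = 6 at w = 3.
grad = tape.gradient(loss, w)
print(grad.numpy())  # 6.0
```

The tape is the TensorFlow counterpart of wrapping a function in jax.grad: TensorFlow records operations as they run, while JAX transforms a pure function.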

Loading datasets in Flax and TensorFlow

TensorFlow provides utilities for loading data, such as tf.data. Flax doesn't ship with any data loaders, so you have to use those from other libraries, such as TensorFlow. As long as the data is a NumPy or JAX array with the proper shape, it can be passed to a Flax network.

Read more: How to load datasets in Flax with TensorFlow
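As a minimal sketch of this hand-off (the in-memory data is hypothetical), a tf.data pipeline can yield NumPy batches through as_numpy_iterator, which convert directly to JAX arrays:

```python
import numpy as np
import jax.numpy as jnp
import tensorflow as tf

# Hypothetical in-memory data: 10 examples with 3 features each.
features = np.random.rand(10, 3).astype("float32")
labels = np.arange(10) % 2

ds = tf.data.Dataset.from_tensor_slices((features, labels)).shuffle(10).batch(4)

batch_shapes = []
for x, y in ds.as_numpy_iterator():  # yields NumPy arrays, not tf.Tensors
  x = jnp.asarray(x)  # ready to be passed to a Flax model's apply()
  batch_shapes.append(x.shape)
print(batch_shapes)  # the last batch is smaller: 10 = 4 + 4 + 2
```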

Training models in Flax vs. TensorFlow

In TensorFlow, you train a model by compiling it and calling its fit method. In Flax, you instead create a training state that holds the model's apply function, its parameters, and the optimizer, and then pass data to the network in a training step you write yourself.

import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng):
  """Creates initial `TrainState`."""
  # LSTMModel and X_train_padded (the padded training data) are defined elsewhere.
  model = LSTMModel()
  params = model.init(rng, jnp.array(X_train_padded[0]))['params']
  tx = optax.adam(learning_rate=0.001, b1=0.9, b2=0.999, eps=1e-07)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

Next, we define a training step that computes the loss and gradients, uses the gradients to update the model parameters, and returns the new state along with the model metrics.

import jax
import jax.numpy as jnp
import optax

def train_step(state, text, labels):
  def loss_fn(params):
    logits = LSTMModel().apply({'params': params}, text)
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits, labels=jax.nn.one_hot(labels, num_classes=2)))
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, labels)  # defined in the metrics section
  return state, metrics

If you prefer Keras-style training, Elegy offers a high-level, Keras-like API for JAX neural network libraries.

Distributed training in Flax and TensorFlow

In TensorFlow, distributed training is done by creating a distribution strategy and building the model inside its scope.

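import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()

# Variables created inside the scope are mirrored across the available GPUs.
with mirrored_strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

model.compile(loss='mse', optimizer='sgd')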

To train networks in a distributed way in Flax, we define distributed versions of our Flax functions with jax.pmap, which runs the same function in parallel on multiple devices. The training state is replicated on every device with flax.jax_utils.replicate, while each batch is split (sharded) across the devices. Inside the pmapped function, jax.lax.pmean() averages the gradients and metrics across devices. To get a single copy of the state or metrics back from the devices, use flax.jax_utils.unreplicate.

Read more: Distributed training with JAX & Flax

Working with TPU accelerators

You can use both Flax and TensorFlow with TPU and GPU accelerators. To use Flax with TPUs on Colab, you first need to set up the TPU runtime.


For TensorFlow, set up a TPU distribution strategy.

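cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # pass tpu=... if not auto-detected
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
tpu_strategy = tf.distribute.TPUStrategy(cluster_resolver)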
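Whichever library you use, it helps to confirm which accelerators are visible before training. On the JAX side, a minimal check looks like this; the printed backend and device list depend on your runtime:

```python
import jax

# Lists the devices JAX can see, e.g. TPU cores on a TPU runtime,
# or the CPU when no accelerator is attached.
devices = jax.devices()
backend = jax.default_backend()
print(backend, len(devices), devices)
```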

Model evaluation

TensorFlow provides the evaluate function for evaluating networks. Flax doesn't ship with such a function. You'll need to create a function that applies the model and returns the test metrics. Elegy provides Keras-like functions such as the evaluate method.

import jax

def eval_step(state, text, labels):
  # LSTMModel and compute_metrics are defined in the previous sections.
  logits = LSTMModel().apply({'params': state.params}, text)
  return compute_metrics(logits=logits, labels=labels)

def evaluate_model(state, text, test_lbls):
  """Evaluate on the validation set."""
  metrics = eval_step(state, text, test_lbls)
  metrics = jax.device_get(metrics)
  metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return metrics

Visualize model performance

Visualizing model performance works the same way in Flax and TensorFlow. Once you have the metrics, you can plot them with a package such as Matplotlib. You can also use TensorBoard with both libraries.

Final thoughts

You have seen the differences between the Flax and TensorFlow libraries, in particular in how models are defined and trained. Interested in exploring JAX and Flax further? Check out more articles on our blog.



Derrick Mwiti

Google Developer Expert - Machine Learning