What is JAX?

JAX (What it is and how to use it in Python)

Derrick Mwiti

JAX is a Python library for high-performance machine learning research, built on XLA and just-in-time (JIT) compilation. Its API is similar to NumPy's, with a few differences. JAX ships with features that aim to speed up machine learning research, including:

  • Automatic differentiation
  • Vectorization
  • JIT compilation

This article will cover these functionalities and other JAX concepts. Let's get started.

What is XLA?

XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that accelerates machine learning models. It can increase execution speed and reduce memory usage. Besides TensorFlow, XLA programs can be generated by JAX, PyTorch, Julia, and Nx (Elixir).

Installing JAX

JAX can be installed from the Python Package Index (PyPI). The following command installs the CPU-only version:

pip install jax

JAX is pre-installed on Google Colab. For other installation options, such as GPU and TPU builds, see the google/jax repository on GitHub ("Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more").

Setting up TPUs on Google Colab

You need to set up JAX to use TPUs on Colab. First, make sure you have changed the runtime to TPU by going to Runtime -> Change runtime type. If no accelerator is available, JAX will use the CPU.
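A quick way to confirm which accelerator JAX picked up is to list its devices and the default backend. This is a small sketch using the standard jax.devices() and jax.default_backend() calls; the exact output depends on the runtime you selected.

```python
import jax

# Devices JAX can see, e.g. a list of TPU devices on a TPU runtime
print(jax.devices())

# "tpu", "gpu", or "cpu", depending on the available accelerator
backend = jax.default_backend()
print(backend)
```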

Data types in JAX

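The data types in JAX arrays are similar to those in NumPy. For instance, you can create float and int arrays in JAX just as you would in NumPy.

When you check the type of the data in older JAX versions, you will see that it's a DeviceArray. In recent versions, JAX arrays are instances of jax.Array instead.

DeviceArray (jax.Array in recent versions) is JAX's equivalent of numpy.ndarray in NumPy.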
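The sketch below creates float and int arrays, checks their dtypes, and converts one back to a NumPy ndarray. It assumes a recent JAX version (where the public array type is jax.Array) with the default 32-bit settings.

```python
import jax
import jax.numpy as jnp
import numpy as np

# float32 by default (64-bit types are disabled unless jax_enable_x64 is set)
x_float = jnp.array([1.0, 2.5, 3.0])
x_int = jnp.array([1, 2, 3], dtype=jnp.int32)   # explicit integer dtype

print(x_float.dtype, x_int.dtype)
print(type(x_float))                 # the concrete class name varies by version
print(isinstance(x_float, jax.Array))

# Converting to a regular NumPy ndarray
x_np = np.asarray(x_float)
print(type(x_np))
```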

jax.numpy provides an interface similar to NumPy's. However, JAX also provides jax.lax, a lower-level API that is more powerful and stricter. For example, jax.numpy lets you add numbers of mixed types, but jax.lax does not allow this without an explicit type conversion.

Ways to create JAX arrays

You can create JAX arrays like you would in NumPy. For example, you can use:

  • arange
  • linspace
  • Python lists
  • ones
  • zeros
  • identity
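Each of the constructors listed above has a jax.numpy counterpart with the same name as in NumPy. A short sketch:

```python
import jax.numpy as jnp

a = jnp.arange(5)                  # [0 1 2 3 4]
b = jnp.linspace(0.0, 1.0, 5)      # 5 evenly spaced values from 0 to 1
c = jnp.array([1, 2, 3])           # from a Python list
d = jnp.ones((2, 3))               # 2x3 array of ones
e = jnp.zeros((2, 3))              # 2x3 array of zeros
f = jnp.identity(3)                # 3x3 identity matrix

print(a, b, c, d, e, f, sep="\n")
```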

Generating random numbers with JAX

Random number generation is one of the main differences between JAX and NumPy. JAX is meant to be used with functional programs, and its transformations expect these functions to be pure. A pure function has no side effects, and its output depends only on its inputs.

Therefore, when working with JAX, all input should be passed through function parameters, while all output should come from the function's return values. Hence, a function that calls Python's print is not pure.

A pure function returns the same results when called with the same inputs. This is not possible with np.random.random() because it is stateful and returns different results when called several times.

Instead, JAX implements random number generation with an explicit random state, referred to as a key 🔑. JAX generates pseudorandom numbers from the state of a pseudorandom number generator (PRNG).

Using the same key will always generate the same output. You should, therefore, not reuse the same key. Instead, split the key to obtain as many subkeys as you need.
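A minimal sketch of this workflow with jax.random: create a key, observe that reusing it repeats the numbers, and split it to get fresh subkeys. The seed 42 is an arbitrary choice.

```python
import jax

key = jax.random.PRNGKey(42)             # explicit PRNG state (a key)

x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))        # same key -> identical numbers
print(x1, x2)

key, subkey = jax.random.split(key)      # derive a new key and a subkey
x3 = jax.random.normal(subkey, (3,))     # fresh, different numbers
print(x3)

subkeys = jax.random.split(key, 4)       # as many subkeys as you need
print(len(subkeys))                      # 4
```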

Pure functions

We have mentioned that all output of a pure function should come from its return value. Calling Python's print inside a function therefore introduces impurity. You can demonstrate this by adding a print call to a function compiled with jax.jit.

The printed statement appears the first time the function is executed. However, it doesn't appear on subsequent calls, because JAX reuses the cached compiled function. The statement only appears again after you change the shape of the input, which forces JAX to retrace and recompile the function. More on jax.jit in a moment.
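Here is a sketch of that experiment. The function name impure_fn and the trace_log list are hypothetical; the list simply records each time the Python body actually runs, which only happens while JAX is tracing.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body actually runs

@jax.jit
def impure_fn(x):
    print("Running impure_fn with shape", x.shape)  # side effect
    trace_log.append(x.shape)                        # another side effect
    return x * 2

impure_fn(jnp.arange(3.0))   # prints: traced and compiled
impure_fn(jnp.arange(3.0))   # silent: cached compiled version reused
impure_fn(jnp.arange(4.0))   # prints again: new shape forces retracing

print(trace_log)             # [(3,), (4,)]
```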

JAX NumPy operations

Operations on JAX arrays are similar to operations on NumPy arrays. For example, you can use max, argmax, and sum as in NumPy.

However, unlike NumPy, jax.numpy operations such as sum do not accept non-array input. For example, passing Python lists or tuples leads to a TypeError.
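A short sketch of both points: NumPy-style reductions work on JAX arrays, while passing a plain Python list to jnp.sum is rejected until you convert it with jnp.array.

```python
import jax.numpy as jnp

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
print(jnp.max(x), jnp.argmax(x), jnp.sum(x))   # 5.0 4 14.0

list_rejected = False
try:
    jnp.sum([1, 2, 3])                # a Python list, not an array
except TypeError as err:
    list_rejected = True
    print("TypeError:", err)

print(jnp.sum(jnp.array([1, 2, 3])))  # convert to an array first: 6
```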

JAX arrays are immutable

Unlike NumPy arrays, JAX arrays cannot be modified in place. This is because JAX expects pure functions.

Array updates in JAX are performed using x.at[idx].set(y). This returns a new array while the old array stays unaltered.  
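The sketch below shows that direct item assignment fails and that the .at[...] methods return a new array, leaving the original untouched. Besides set, the same pattern supports methods such as add, multiply, min, and max.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

assignment_failed = False
try:
    x[0] = 10.0                   # in-place assignment is not allowed
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

y = x.at[0].set(10.0)             # new array with 10.0 at index 0
z = x.at[1:3].add(1.0)            # new array with 1.0 added at indices 1 and 2

print(x)                          # still all zeros
print(y)
print(z)
```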

Out-of-Bounds Indexing

NumPy raises an IndexError when you try to access an item that is out of bounds. JAX doesn't raise an error; for retrieval, it clamps the index to the bounds of the array, so accessing past the end returns the last item.

JAX is designed like this because raising errors from code running on accelerators can be challenging.
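A minimal comparison of the two behaviors. The last line uses the mode and fill_value arguments of .at[...].get(), which let you request a different out-of-bounds behavior explicitly; -1 is an arbitrary fill value chosen for illustration.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(10)
try:
    x_np[11]
except IndexError as err:
    print("NumPy:", err)

x = jnp.arange(10)
print(x[11])      # 9: the index is clamped to the last element

# Ask for a fill value instead of clamping
print(x.at[11].get(mode="fill", fill_value=-1))   # -1
```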

Data placement on devices in JAX

By default, JAX arrays are placed on the first device, jax.devices()[0], which is a GPU or TPU if one is available and the CPU otherwise. Data can be placed on a particular device using jax.device_put().

Data placed this way becomes committed to that device, and computations on committed data run on the same device.
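A small sketch of explicit placement, assuming a recent JAX version where arrays expose a devices() method. On a CPU-only machine every step uses the single CPU device, but the pattern is the same on GPUs and TPUs.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
print(x.devices())                  # by default, jax.devices()[0]

device = jax.devices()[0]           # pick a device explicitly
y = jax.device_put(x, device)       # y is committed to this device
z = y * 2                           # runs on y's device; z lives there too
print(z.devices())
```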

How fast is JAX?

JAX uses asynchronous dispatch, meaning that it does not wait for a computation to complete before giving control back to the Python program. Instead, an operation returns immediately with an array whose value may still be being computed, much like a future. Python only has to wait when it needs the actual values, for example when you print the output or convert the result to a NumPy array.

Therefore, if you want to measure the execution time of a program, you have to wait for the computation to complete, either by calling block_until_ready() on the result or by converting the result to a NumPy array. As for the question "Is JAX faster than NumPy?" (answered in the JAX FAQ), it depends on the workload: generally speaking, NumPy is often faster for small operations on the CPU, while JAX tends to outperform NumPy on accelerators and when using jitted functions.
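A minimal timing sketch: without block_until_ready(), the timer would mostly measure how long it takes to dispatch the work, not to do it. The 1000x1000 matrix size is arbitrary.

```python
import time
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

start = time.perf_counter()
y = jnp.dot(x, x.T)       # returns almost immediately (asynchronous dispatch)
y.block_until_ready()     # wait until the result has actually been computed
elapsed = time.perf_counter() - start
print(f"matmul took {elapsed:.4f} s")
```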

Using jit() to speed up functions

jax.jit performs just-in-time compilation with XLA. It expects a pure function: any Python side effects in the function run only while it is being traced, which happens on the first call. Let's create a pure function and time its execution without jit.

Let's now use jit and time the execution of the same function. In this case, using jit makes the execution almost 20 times faster.

In the above example, test_fn_jit is the jit-compiled version of the function. On the first call, JAX traces it and compiles it with XLA into code optimized for the available backend (CPU, GPU, or TPU). That compiled code is reused the next time the function is called with inputs of the same shape and dtype.

How JIT works

JAX works by converting Python functions into an intermediate representation called jaxpr (JAX expression). jax.make_jaxpr shows the jaxpr representation of a Python function. Side effects in the function are not recorded in the jaxpr. As we saw earlier, side effects such as printing only happen while the function is being traced, which is during the first call.

printed x: Traced<ShapedArray(float32[6])>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[6]. let
    b:f32[6] = neg a
    c:f32[6] = exp b
    d:f32[6] = add c 1.0
    e:f32[6] = div 1.0 d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

JAX creates the jaxpr through tracing. Each argument of the function is wrapped in a tracer object. These tracers record every JAX operation performed on them when the function is called, and JAX uses these records to build the jaxpr. Python side effects don't show up in the jaxpr because the tracers do not record them.

JAX requires array shapes to be static, that is, known at compile time. Under jit, arguments are traced as abstract values that have a shape and dtype but no concrete value, so jitting a function whose Python control flow depends on an argument's value (for example, if x > 0:) results in an error. Therefore, not all code can be jit-compiled.

There are a couple of solutions to this problem:

  • Remove conditionals on the value.
  • Use JAX control flow operators such as jax.lax.cond.  
  • Jit only a part of the function.
  • Make parameters static.
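The problem and one of the fixes listed above can be sketched as follows. relu_if and relu_cond are hypothetical names: the first branches with a Python if on a traced value and fails under jit, while the second expresses the same branch with jax.lax.cond, which takes a predicate, two branch functions, and the operand.

```python
import jax
import jax.numpy as jnp

def relu_if(x):
    # Python `if` on a traced value: fails under jit
    if x > 0:
        return x
    return 0.0 * x

jit_failed = False
try:
    jax.jit(relu_if)(jnp.float32(2.0))
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print(type(err).__name__)

@jax.jit
def relu_cond(x):
    # both branches are traced; the choice is made at run time
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)

print(relu_cond(jnp.float32(2.0)), relu_cond(jnp.float32(-3.0)))
```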

We can implement the last option and make the boolean parameter static. This is done by specifying static_argnums or static_argnames. This forces JAX to recompile the function when the value of the static parameter changes. This is not a good strategy if the function will get many values for the static argument. You don't want to recompile the function too many times.

When using jax.jit as a decorator, you can pass the static arguments using Python's functools.partial.
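Both ways of marking an argument as static are sketched below with a hypothetical scale function. Because negate is static, branching on it with a Python if is allowed, and each new value of negate triggers a separate compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

def scale(x, negate):
    # branching on `negate` is fine because it is a static Python bool
    if negate:
        return -x
    return x

# Option 1: mark the argument static by name
scale_jit = jax.jit(scale, static_argnames="negate")

# Option 2: use functools.partial to pass static_argnums to the decorator
@partial(jax.jit, static_argnums=(1,))
def scale_decorated(x, negate):
    if negate:
        return -x
    return x

x = jnp.arange(1.0, 4.0)
print(scale_jit(x, negate=True))    # [-1. -2. -3.]
print(scale_decorated(x, False))    # [1. 2. 3.]
```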

Taking derivatives with grad()

Computing derivatives in JAX is done using jax.grad.
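jax.grad takes a scalar-valued function and returns a new function that computes its derivative; it can be nested for higher-order derivatives. A minimal sketch with a hand-checkable polynomial:

```python
import jax

def f(x):
    return x ** 3 + 2.0 * x        # f'(x) = 3x^2 + 2, f''(x) = 6x

df = jax.grad(f)
d2f = jax.grad(jax.grad(f))

print(df(2.0))                     # 14.0
print(d2f(2.0))                    # 12.0
```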

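The grad function has a has_aux argument that lets the differentiated function return auxiliary data alongside the value being differentiated. For example, when building machine learning models, a loss function can return the loss together with the model's predictions or metrics; jax.grad then returns the gradients along with that auxiliary data. To get both the loss value and its gradients, use jax.value_and_grad.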
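A sketch with a hypothetical linear model and mean squared error loss. With has_aux=True the loss function returns a (loss, aux) pair; jax.grad returns (grads, aux), and jax.value_and_grad returns ((loss, aux), grads).

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    preds = x @ w
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds                 # (value to differentiate, aux data)

w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

grads, preds = jax.grad(loss_fn, has_aux=True)(w, x, y)
(loss, preds2), grads2 = jax.value_and_grad(loss_fn, has_aux=True)(w, x, y)

print(loss, grads, preds)              # 2.5 [-7. -10.] [-1. -1.]
```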

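You can perform more advanced automatic differentiation using jax.jvp() (forward mode, Jacobian-vector products) and jax.vjp() (reverse mode, vector-Jacobian products). See The Autodiff Cookbook and the Advanced Autodiff guide in the JAX documentation for details.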
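A small sketch of both functions. For an element-wise function the Jacobian is diagonal, so pushing a vector of ones forward (jvp) and pulling one back (vjp) both yield the element-wise derivative, which here is cos(x) * x + sin(x).

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])
v = jnp.ones_like(x)

# Forward mode: returns f(x) and J @ v
y, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode: returns f(x) and a function computing u @ J
y2, vjp_fn = jax.vjp(f, x)
(vjp_out,) = vjp_fn(jnp.ones_like(y2))

print(jvp_out, vjp_out)
```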

Auto-vectorization with vmap

vmap (vectorizing map) allows you to write a function that operates on a single example and then maps it over a batch of data. Without vmap, the solution would be to loop over the batch while applying the function. Under jit, Python for loops are unrolled during tracing, which can make compilation slow, so this approach is usually less efficient than a vectorized implementation.

"In JAX, the jax.vmap transformation is designed to generate a vectorized implementation of a function automatically. It does this by tracing the function similarly to jax.jit, and automatically adding batch axes at the beginning of each input. If the batch dimension is not the first, you may use the in_axes and out_axes arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise." (Matteo Hessel, JAX author)
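A sketch of vmap on a hypothetical per-example dot function. in_axes=(0, None) maps over the first axis of the batch and broadcasts the single weight vector; in_axes=(1, None) shows how to batch over a different axis. The result matches an explicit Python loop.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # written for a single pair of vectors
    return jnp.dot(a, b)

xs = jnp.arange(12.0).reshape(4, 3)    # batch of 4 vectors
w = jnp.array([1.0, 0.0, -1.0])        # one vector shared by the whole batch

batched_dot = jax.vmap(dot, in_axes=(0, None))
out = batched_dot(xs, w)               # shape (4,)

# Same result when the batch lives on axis 1 of the input
out_cols = jax.vmap(dot, in_axes=(1, None))(xs.T, w)

# Equivalent, but slower, Python loop
loop_out = jnp.stack([dot(x, w) for x in xs])
print(out)                             # [-2. -2. -2. -2.]
```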

Parallelization with pmap

jax.pmap works similarly to jax.vmap. The difference is that jax.pmap is meant for parallel execution, that is, computation on multiple devices. This is useful when training a machine learning model on batches of data.

Computation on different batches can happen on different devices, and the results are then aggregated. In older JAX versions, the pmapped function returns a ShardedDeviceArray, because the output is split across all the devices; in recent versions it returns a jax.Array sharded across the devices. There is no need to decorate the function with jit, because pmap compiles the function by default.

You may need to aggregate data across devices using one of the collective operators, for example, to compute the mean of the accuracy or of the logits. In that case, you'll need to specify an axis_name. This name identifies the mapped axis so that devices can communicate over it. See the article "Distributed training with JAX & Flax" for how to train machine learning models in a distributed manner in JAX; it covers replicating the data on multiple accelerators, running the training on multiple devices, and aggregating the results.
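A minimal pmap sketch using the collective jax.lax.pmean over a named axis ("devices" is an arbitrary name). The leading axis of the input must match the number of local devices; on a CPU-only machine that is usually 1, and the code still runs.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def global_mean(x):
    local_mean = jnp.mean(x)                   # mean of this device's shard
    # average the per-device means across the axis named "devices"
    return jax.lax.pmean(local_mean, axis_name="devices")

# One row of 4 values per device
data = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
out = jax.pmap(global_mean, axis_name="devices")(data)
print(out)                                     # the same global mean on every device
```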

Debugging NaNs in JAX

By default, the occurrence of NaNs in a JAX program will not lead to an error.

You can turn on the NaN checker, and your program will then raise an error as soon as a NaN is produced. Only use the NaN checker for debugging, because it slows the program down. Also, it doesn't work with pmap; to debug NaNs in pmap code, try replacing pmap with vmap.
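The NaN checker is controlled by the jax_debug_nans configuration option. This sketch turns it on, triggers a NaN with the logarithm of a negative number, and turns it off again so the rest of the program runs at full speed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)   # raise as soon as a NaN appears

caught = False
try:
    jnp.log(jnp.float32(-1.0))              # log of a negative number -> NaN
except FloatingPointError as err:
    caught = True
    print("FloatingPointError:", err)
finally:
    jax.config.update("jax_debug_nans", False)  # debugging only
```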

Double (64bit) precision

By default, JAX uses single precision. For example, you will get a warning when you create a float64 number, and if you check its type, you will notice that it's float32.

import jax.numpy as jnp
x = jnp.float64(1.25844)
# /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:1806: UserWarning: Explicitly requested dtype float64 requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
# lax_internal._check_user_dtype_supported(dtype, "array")
# DeviceArray(1.25844, dtype=float32)

You can use double-precision numbers by enabling the jax_enable_x64 configuration option or by setting the JAX_ENABLE_X64 environment variable, as the warning message suggests.
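A sketch of enabling 64-bit types from Python. The option is best set at program startup, before any arrays are created; setting JAX_ENABLE_X64=1 in the shell before launching Python has the same effect.

```python
import jax

# Enable 64-bit types; do this at startup, before creating any arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.float64(1.25844)
print(x.dtype)                       # float64, no truncation warning
```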

What is a pytree?

A pytree is a tree-like structure built from Python containers such as tuples, lists, and dictionaries, which can be nested inside one another. Its leaves can be arrays or other Python objects. For example, model parameters in JAX are usually stored as pytrees.
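The jax.tree_util module works on pytrees directly. This sketch uses a hypothetical parameter dictionary, lists its leaves, and applies a function to every leaf while preserving the nested structure.

```python
import jax
import jax.numpy as jnp

# A hypothetical parameter pytree: nested dicts and a list
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": [jnp.array(2.0), 0.5],
}

leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))                       # 4

# Apply a function to every leaf while keeping the structure
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
print(doubled["scale"])
```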
