JAX (What it is and how to use it in Python)
JAX is a Python library offering high performance in machine learning through XLA and just-in-time (JIT) compilation. Its API is similar to NumPy's, with a few differences. JAX ships with functionalities that aim to speed up machine learning research. These functionalities include:
- Automatic differentiation
- Vectorization
- JIT compilation
This article will cover these functionalities and other JAX concepts. Let's get started.
What is XLA?
XLA (Accelerated Linear Algebra) is a linear algebra compiler for accelerating machine learning models. It increases the speed of model execution and reduces memory usage. XLA programs can be generated by JAX, PyTorch, Julia, and Nx.
Installing JAX
JAX can be installed from the Python Package Index using:
pip install jax
JAX is pre-installed on Google Colab. See the link below for other installation options.
Setting up TPUs on Google Colab
You need to set up JAX to use TPUs on Colab. That is done by executing the following code. Ensure that you have changed the runtime to TPU by going to Runtime -> Change runtime type. If no accelerator is available, JAX will fall back to the CPU.
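The setup snippet isn't reproduced here, but at the time of writing it typically looks like the sketch below (the colab_tpu helper may change or be removed in newer JAX releases):

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.devices())   # should now list the available TPU devices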
Data types in JAX
The data types in JAX arrays are similar to those in NumPy. For instance, here is how you can create float and int data in JAX. When you check the type of the data, you will see that it's a DeviceArray. DeviceArray in JAX is the equivalent of numpy.ndarray in NumPy.
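A minimal sketch of what this looks like (the exact class path printed for a JAX array varies between JAX versions):

import jax.numpy as jnp

x = jnp.float32(1.5)   # float data
y = jnp.int32(4)       # int data

type(x)
# DeviceArray (exact class path depends on the JAX version)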
jax.numpy provides an interface similar to NumPy's. However, JAX also provides jax.lax, a low-level API that is more powerful and stricter. For example, with jax.numpy you can add numbers that have mixed types, but jax.lax will not allow this.
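For illustration, a small sketch of the difference:

import jax.numpy as jnp
from jax import lax

jnp.add(1, 1.0)   # works: the int is promoted to a float
lax.add(1, 1.0)   # raises a TypeError because the dtypes are mixed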
Ways to create JAX arrays
You can create JAX arrays like you would in NumPy. For example, you can use (as shown in the sketch below):
- arange
- linspace
- Python lists
- ones
- zeros
- identity
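A sketch of these creation functions in action:

import jax.numpy as jnp

jnp.arange(5)            # [0, 1, 2, 3, 4]
jnp.linspace(0, 1, 5)    # 5 evenly spaced values between 0 and 1
jnp.array([1, 2, 3])     # from a Python list
jnp.ones((2, 2))         # 2x2 array of ones
jnp.zeros((2, 2))        # 2x2 array of zeros
jnp.identity(3)          # 3x3 identity matrix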
Generating random numbers with JAX
Random number generation is one of the main differences between JAX and NumPy. JAX is meant to be used with functional programs, and it expects these functions to be pure. A pure function has no side effects, and its output comes only from its inputs. JAX transformation functions expect pure functions.
Therefore, when working with JAX, all input should be passed through function parameters and all output should come from the function's results. Hence, something like Python's print function is not pure.
A pure function returns the same results when called with the same inputs. This is not possible with np.random.random() because it is stateful and returns different results when called several times.
JAX implements random number generation using an explicit random state. This random state is referred to as a key. JAX generates pseudorandom numbers from the pseudorandom number generator (PRNG) state.
You should, therefore, not reuse the same state. Instead, you should split the key to obtain as many subkeys as you need.
Using the same key will always generate the same output.
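A minimal sketch of this workflow:

from jax import random

key = random.PRNGKey(42)             # the PRNG state
key, subkey = random.split(key)      # split the key before each use
random.normal(subkey, shape=(3,))    # draw numbers with the subkey

# The same key always produces the same numbers:
random.normal(random.PRNGKey(0), shape=(3,))
random.normal(random.PRNGKey(0), shape=(3,))   # identical output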
Pure functions
We have mentioned that the output of a pure function should only come from the result of the function. Therefore, something like Python's print function introduces impurity. This can be demonstrated using a function like the one below.
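A sketch of an impure, jitted function (the function name and body are illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def impure_fn(x):
    print("Running impure_fn")   # side effect: only visible while tracing
    return x * 2

impure_fn(jnp.arange(3))   # prints "Running impure_fn"
impure_fn(jnp.arange(3))   # prints nothing: the cached compilation is reused
impure_fn(jnp.arange(5))   # new shape forces a recompile, so it prints again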
We can see the printed statement the first time the function is executed. However, we don't see the print statement in subsequent runs because the compiled function is cached. We only see the statement again after changing the data's shape, which forces JAX to recompile the function. More on jax.jit in a moment.
JAX NumPy operations
Operations on JAX arrays are similar to operations on NumPy arrays. For example, you can use max, argmax, and sum like in NumPy.
However, unlike NumPy, JAX doesn't allow operations on non-array input. For example, passing Python lists or tuples will lead to an error.
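A short sketch of both behaviors:

import jax.numpy as jnp

x = jnp.array([3, 1, 7, 2])
jnp.max(x)      # 7
jnp.argmax(x)   # 2
jnp.sum(x)      # 13

jnp.sum([3, 1, 7, 2])
# TypeError: sum requires ndarray or scalar arguments, got <class 'list'>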
JAX arrays are immutable
Unlike NumPy arrays, JAX arrays cannot be modified in place. This is because JAX expects pure functions.
Array updates in JAX are performed using x.at[idx].set(y). This returns a new array while the old array stays unaltered.
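For example:

import jax.numpy as jnp

x = jnp.arange(5)
# x[0] = 10 would raise an error because JAX arrays are immutable.
y = x.at[0].set(10)   # returns an updated copy
x   # [0, 1, 2, 3, 4] -- unchanged
y   # [10, 1, 2, 3, 4]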
Out-of-Bounds Indexing
NumPy throws an error when you try to get an item at an index that is out of bounds. JAX doesn't throw an error; the index is clamped instead, so reading past the end returns the last item in the array.
JAX is designed like this because raising errors from code running on accelerators can be challenging.
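For example:

import jax.numpy as jnp

x = jnp.arange(5)
x[10]   # no error: returns 4, the last item in the array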
Data placement on devices in JAX
JAX arrays are placed on the first device, jax.devices()[0], which is a GPU, TPU, or CPU depending on what is available. Data can be placed on a particular device using jax.device_put().
The data then becomes committed to that device, and operations on it run on the same device.
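A short sketch:

import jax
import jax.numpy as jnp

print(jax.devices())    # all available devices; jax.devices()[0] is the default

x = jnp.ones((1000, 1000))                # placed on the default device
x = jax.device_put(x, jax.devices()[0])   # explicitly commit x to a device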
How fast is JAX?
JAX uses asynchronous dispatch, meaning that it does not wait for a computation to complete before giving control back to the Python program. When you run a computation, JAX returns a future. JAX forces Python to wait for the execution when you print the output or convert the result to a NumPy array.
Therefore, if you want to measure the execution time of a program, you'll have to call block_until_ready() on the result (or convert it to a NumPy array) to wait for the execution to complete. Generally speaking, NumPy will outperform JAX on the CPU, but JAX will outperform NumPy on accelerators and when using jitted functions.
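A minimal timing sketch (the array size is arbitrary):

import time
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (3000, 3000))

start = time.time()
jnp.dot(x, x).block_until_ready()   # wait for the result before stopping the timer
print(time.time() - start)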
Using jit() to speed up functions
jit performs just-in-time compilation with XLA. jax.jit expects a pure function; any side effects in the function will only be executed once. Let's create a pure function and time its execution without jit.
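The original snippet isn't shown, so here is an illustrative stand-in for a Colab/IPython cell (the body of test_fn and the data are assumptions; timings depend on your hardware):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))

def test_fn(x):
    return x * 2 + x ** 2 + x / 3   # a simple element-wise computation

%timeit test_fn(x).block_until_ready()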
Let's now use jit and time the execution of the same function. In this case, we can see that using jit makes the execution almost 20 times faster.
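Continuing the same sketch, test_fn_jit is the jitted version (the exact speedup will vary by device):

test_fn_jit = jax.jit(test_fn)
test_fn_jit(x).block_until_ready()           # the first call compiles the function
%timeit test_fn_jit(x).block_until_ready()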
In the above example, test_fn_jit is the jit-compiled version of the function. JAX creates code that is optimized for the GPU or TPU, and that optimized code is used the next time the function is called.
How JIT works
JAX works by converting Python functions into an intermediate language called jaxpr (JAX expression). jax.make_jaxpr can be used to show the jaxpr representation of a Python function. If the function has any side effects, they are not recorded in the jaxpr. We saw earlier that side effects, for example printing, only show up during the first call.
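The original code isn't shown here, but a function along the following lines would produce a jaxpr like the listing below (the function name is hypothetical, reconstructed from the output):

import jax
import jax.numpy as jnp

def sum_of_sigmoids(x):
    print("printed x:", x)                    # side effect: not recorded in the jaxpr
    return jnp.sum(1.0 / (jnp.exp(-x) + 1.0))

x = jnp.arange(6.0)
print(jax.make_jaxpr(sum_of_sigmoids)(x))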
printed x: Traced<ShapedArray(float32[6])>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[6]. let
b:f32[6] = neg a
c:f32[6] = exp b
d:f32[6] = add c 1.0
e:f32[6] = div 1.0 d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
JAX creates the jaxpr through tracing. Each argument in the function is wrapped with a tracer object. The purpose of these tracers is to record all JAX operations performed on them when the function is called. JAX uses the tracer records to rebuild the function, which leads to jaxpr. Python side-effects don't show up in the jaxpr because the tracers do not record them.
JAX requires array shapes to be static and known at compile time. Decorating a function whose control flow depends on the value of an argument with jit results in an error. Therefore, not all code can be jit-compiled.
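For example, a sketch of a function that fails under jit because its Python if depends on an argument's value (the function name is illustrative):

import jax

@jax.jit
def negate_if(x, negative):
    return -x if negative else x   # Python control flow on a traced value

negate_if(1.0, True)
# Error: the truth value of `negative` is not known at trace time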
There are a couple of solutions to this problem:
- Remove conditionals on the value.
- Use JAX control flow operators such as jax.lax.cond.
- Jit only a part of the function.
- Make parameters static.
We can implement the last option and make the boolean parameter static. This is done by specifying static_argnums or static_argnames. This forces JAX to recompile the function whenever the value of the static parameter changes. This is not a good strategy if the function will receive many different values for the static argument, because you don't want to recompile the function too many times.
You can pass the static arguments using Python's functools.partial.
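A sketch of this fix applied to the function above:

from functools import partial
import jax

@partial(jax.jit, static_argnames=["negative"])
def negate_if(x, negative):
    return -x if negative else x

negate_if(1.0, negative=True)    # compiles a version for negative=True
negate_if(1.0, negative=False)   # triggers a recompile for negative=False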
Taking derivatives with grad()
Computing derivatives in JAX is done using jax.grad.
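A minimal example:

import jax

def f(x):
    return x ** 2

grad_f = jax.grad(f)
grad_f(3.0)   # 6.0, since df/dx = 2x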
The grad function has a has_aux argument that allows the function to return auxiliary data alongside the value being differentiated. For example, when building machine learning models, you can use it to get the gradients together with auxiliary outputs such as the model's predictions.
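A toy sketch (the loss function and data are made up for illustration):

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    preds = w * x
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds              # (value to differentiate, auxiliary data)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

grads, preds = jax.grad(loss_fn, has_aux=True)(1.5, x, y)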
You can perform advanced automatic differentiation using jax.vjp() and jax.jvp().
Auto-vectorization with vmap
vmap (vectorizing map) allows you to write a function that operates on a single example and then use vmap to apply it to a batch of data. Without vmap, the alternative would be to loop through the batch while applying the function. Jitting functions that contain Python for loops is complicated and may be slower.
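A minimal sketch of vectorizing a per-example function:

import jax
import jax.numpy as jnp

def affine(x, w, b):
    return jnp.dot(x, w) + b        # written for a single example

w = jnp.ones((4, 2))
b = jnp.zeros(2)
batch = jnp.ones((8, 4))            # a batch of 8 examples

batched_affine = jax.vmap(affine, in_axes=(0, None, None))
batched_affine(batch, w, b).shape   # (8, 2)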
Parallelization with pmap
jax.pmap works similarly to jax.vmap. The difference is that jax.pmap is meant for parallel execution, that is, computation on multiple devices. This is applicable when training a machine learning model on batches of data.
Computation on the batches can occur on different devices, and then the results are aggregated. The pmapped function returns a ShardedDeviceArray because the arrays are split across all the devices. There is no need to decorate the function with jit because it is jit-compiled by default when using pmap.
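A minimal sketch (on a single-device machine this still runs, just with one shard):

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)   # one shard per device

out = jax.pmap(lambda x: x ** 2)(xs)   # runs on all devices in parallel
type(out)   # ShardedDeviceArray (the exact class name varies by JAX version)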
You may need to aggregate data across devices using one of the collective operators, for example to compute the mean of the accuracy or the mean of the logits. In that case, you'll need to specify an axis_name. This name is what enables communication between devices. See the article below on how to train machine learning models in a distributed manner with JAX.
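A short sketch of a collective with an axis_name (the name "devices" is arbitrary, and the per-device accuracies are made up):

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
accuracies = jnp.arange(float(n_devices))   # pretend per-device accuracies

mean_fn = jax.pmap(lambda acc: jax.lax.pmean(acc, axis_name="devices"),
                   axis_name="devices")
mean_fn(accuracies)   # every device receives the mean across all devices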
Debugging NANs in JAX
By default, the occurrence of NaNs in a JAX program will not lead to an error.
You can turn on the NaN checker, and your program will then error out as soon as a NaN occurs. You should only use the NaN checker for debugging because it leads to performance issues. Also, it doesn't work with pmap; use vmap instead.
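The checker is enabled through a configuration flag:

import jax

jax.config.update("jax_debug_nans", True)   # error out as soon as a NaN is produced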
Double (64bit) precision
JAX enforces single-precision numbers by default. For example, you will get a warning when you create a float64 number, and if you check its type, you will notice that it's float32.
x = jnp.float64(1.25844)
# /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:1806: UserWarning: Explicitly requested dtype float64 requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
# lax_internal._check_user_dtype_supported(dtype, "array")
# DeviceArray(1.25844, dtype=float32)
You can use double-precision numbers by enabling the jax_enable_x64 configuration option.
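For example (set this at startup, before creating any arrays):

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
jnp.float64(1.25844)   # now keeps dtype float64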
What is a pytree?
A pytree is a container of Python objects. In JAX, it can hold arrays, tuples, lists, dictionaries, etc. The values at the ends of the nested structure are called leaves. For example, model parameters in JAX are pytrees.
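A small sketch of a pytree and the utilities for working with it:

import jax
import jax.numpy as jnp

params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, "scale": 2.0}

jax.tree_util.tree_leaves(params)                   # the leaves: two arrays and 2.0
jax.tree_util.tree_map(lambda p: p * 0.1, params)   # apply a function to every leaf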
Handling state in JAX
Training machine learning models often involves state such as model parameters, optimizer state, and stateful layers such as BatchNorm. However, jit-compiled functions must have no side effects. We, therefore, need a way to track and update model parameters, optimizer state, and stateful layers. The solution is to define the state explicitly. This article shows how to handle training state while training a machine learning model.
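A minimal sketch of the explicit-state pattern, here as a hypothetical SGD update:

import jax

@jax.jit
def sgd_step(params, grads, lr=0.1):
    # The state (params) goes in as an argument and the updated state
    # comes back as the return value; nothing is mutated in place.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)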
Loading datasets with JAX
JAX doesn't ship with any data loading tools. However, JAX recommends using data loaders from PyTorch and TensorFlow.
Building neural networks with JAX
You can build a model from scratch using JAX. However, various neural network libraries are built on top of JAX to make building neural networks easier. The Image classification with JAX & Flax article shows how to load data with PyTorch and build a convolutional neural network with JAX and Flax.
Final thoughts
In this article, we have covered the basics of JAX. We have seen that JAX uses XLA and just-in-time compilation to improve the performance of Python functions. Specifically, we have covered:
- Setting up JAX to use TPUs on Google Colab.
- Comparison between data types in JAX and NumPy.
- Creating arrays in JAX.
- How to generate random numbers in JAX.
- Operations on JAX arrays.
- Gotchas in JAX, such as using pure functions and the immutability of JAX arrays.
- Placing JAX arrays in GPUs or TPUs.
- How to use JIT to speed up functions.
...and so much more
Resources
- What is JAX?
- Flax vs. TensorFlow
- Distributed training with JAX & Flax
- How to load datasets in JAX with TensorFlow
- JAX loss functions
- Optimizers in JAX and Flax
- Image classification with JAX & Flax
- How to use TensorBoard in Flax
- LSTM in JAX & Flax
- Elegy (high-level API for deep learning in JAX & Flax)