JAX

Handling state in JAX & Flax (BatchNorm and Dropout layers)
Jitting functions in Flax makes them faster but requires that they have no side effects. This purity requirement introduces a challenge when dealing with stateful items such as model parameters and stateful layers such as batch normalization. In this article, we'll create…
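As a rough sketch of the pattern this article covers, here is how Flax threads batch normalization statistics through a jitted function instead of mutating them in place (the model and shapes below are illustrative, not the article's exact code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(16)(x)
        # BatchNorm keeps running statistics in a separate
        # "batch_stats" collection rather than mutating state in place.
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.Dense(1)(x)

model = Model()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x, train=True)

@jax.jit
def forward(variables, x):
    # Mutable collections are declared explicitly; this stays
    # jit-compatible because the updated state is returned as an
    # output instead of happening as a side effect.
    out, updates = model.apply(variables, x, train=True,
                               mutable=["batch_stats"])
    return out, updates

out, updates = forward(variables, x)
```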

Transfer learning with JAX & Flax
Training large neural networks can take days or weeks. Once these networks are trained, you can take advantage of their weights and apply them to new tasks, a technique known as transfer learning. As a result, you fine-tune a new network and get good results in a short time. Let's look at how you…
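One common recipe, sketched below with Optax, is to freeze the pretrained backbone and train only a new head; the parameter names and label function here are illustrative assumptions, not the article's exact setup:

```python
import optax

# params is assumed to be a nested dict of pretrained weights plus a
# freshly initialized "head" subtree. Label each top-level subtree,
# then give frozen subtrees a zero update so only the head is trained.
def label_fn(params):
    return {name: ("head" if name == "head" else "backbone")
            for name in params}

tx = optax.multi_transform(
    {"head": optax.adam(1e-3), "backbone": optax.set_to_zero()},
    label_fn,
)
# tx.init(params) and tx.update(...) then behave like any other
# Optax optimizer in the training loop.
```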

Activation functions in JAX and Flax
Activation functions are applied in neural networks to ensure that the network outputs the desired result. They cap the output within a specific range. For instance, when solving a binary classification problem, the output should be a number between 0 and 1, indicating the probability of an…
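For a quick illustration of the range-capping idea, here is a minimal example using JAX's built-in activations (the values are arbitrary):

```python
import jax.numpy as jnp
from jax import nn

logits = jnp.array([-2.0, 0.0, 3.0])

# Sigmoid squashes any real number into (0, 1), so the output can be
# read as a probability for binary classification.
probs = nn.sigmoid(logits)   # ~[0.12, 0.5, 0.95]

# ReLU caps the range from below at 0.
hidden = nn.relu(logits)     # [0., 0., 3.]
```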

Optimizers in JAX and Flax
Optimizers are applied when training neural networks to reduce the error between the true and predicted values. This optimization is done via gradient descent, which adjusts the network's weights to minimize a cost function. In JAX, optimizers come from the Optax library. Optimizers can be classified into two…
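As a minimal sketch of the Optax pattern, an optimizer is initialized from the parameters and then applied step by step (the toy problem below is illustrative):

```python
import jax
import jax.numpy as jnp
import optax

# Toy problem: fit w so that w * x approximates y.
params = {"w": jnp.zeros(())}
x, y = jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 4.0, 6.0])

def loss_fn(params):
    pred = params["w"] * x
    return jnp.mean((pred - y) ** 2)

tx = optax.adam(learning_rate=0.1)
opt_state = tx.init(params)

for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
# params["w"] converges toward 2.0
```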

Elegy (High-level API for deep learning in JAX & Flax)
Training deep learning networks in Flax is done in several steps. It involves creating a model definition, a metrics computation, a training state, a training step, and a training and evaluation function. Flax and JAX give more control in defining and training deep learning networks. However, this comes with more verbosity.
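To illustrate the verbosity point, here is a rough sketch of just one of those pieces, the Flax training state; the model and optimizer choices are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(nn.relu(nn.Dense(64)(x)))

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))["params"]

# TrainState bundles apply_fn, params, and the optimizer; the training
# step, metrics, and evaluation loop still have to be written by hand.
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
)
```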

LSTM in JAX & Flax (Complete example with code and notebook)
LSTMs are a class of neural networks used to solve sequence problems such as time series and natural language processing. LSTMs maintain an internal state that is useful in solving these problems, and they iterate over each time step with a loop. We can use functions from JAX and…
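A minimal sketch of the idea, assuming a recent Flax release where nn.RNN wraps an LSTMCell and handles the scan over time (the sizes and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LSTMModel(nn.Module):
    hidden: int = 32

    @nn.compact
    def __call__(self, x):            # x: (batch, time, features)
        # nn.RNN scans the LSTM cell over the time axis, carrying the
        # internal (hidden, cell) state from one step to the next.
        y = nn.RNN(nn.LSTMCell(features=self.hidden))(x)
        return nn.Dense(1)(y[:, -1])  # predict from the last step

model = LSTMModel()
x = jnp.ones((4, 10, 3))              # 4 sequences, 10 steps, 3 features
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)          # shape (4, 1)
```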

How to load datasets in JAX with TensorFlow
JAX doesn't ship with data loading utilities. This keeps JAX focused on providing a fast tool for building and training machine learning models. Loading data in JAX is done using either TensorFlow or PyTorch. In the Image classification with JAX & Flax tutorial, we saw how to load image data with…
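As a small sketch of the TensorFlow route, TensorFlow Datasets can serve NumPy batches that JAX consumes directly (assuming tensorflow_datasets is installed; the dataset choice is illustrative):

```python
import tensorflow_datasets as tfds

# Load MNIST with TensorFlow Datasets, then hand plain NumPy batches
# to JAX; tfds.as_numpy strips out the TensorFlow tensor types.
ds = tfds.load("mnist", split="train", as_supervised=True)
ds = ds.batch(32).prefetch(1)

for images, labels in tfds.as_numpy(ds):
    # images: (32, 28, 28, 1) uint8 NumPy array, ready for jnp.asarray
    break
```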

JAX loss functions
Loss functions are at the core of training machine learning models. They measure how well the model is performing on a dataset. Poor performance leads to a very high loss, while a well-performing model will have a lower loss. Therefore, the choice of a loss function is…
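As a small illustration, Optax ships common loss functions; a confident correct prediction yields a low loss and a poor one a high loss (the numbers below are arbitrary):

```python
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 1.5,  0.3]])
labels = jnp.array([0, 1])

# Cross-entropy for classification: per-example losses, then the mean
# is what a training loop would typically minimize.
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
mean_loss = loss.mean()
```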