JAX

Convolutional Neural Networks in JAX: Ultimate Guide

JAX is a high-performance library that offers accelerated computing through XLA and just-in-time (JIT) compilation. It also has handy features that let you write one codebase that can be applied to batches of data and run on CPU, GPU, or TPU. However, one of its biggest selling…
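As a rough illustration of "one codebase applied to batches", here is a minimal sketch that writes a toy function for a single example and lets jax.vmap batch it and jax.jit compile it with XLA. The names predict and batched_predict are hypothetical and are not taken from the article; the same code runs on whichever backend (CPU, GPU, or TPU) JAX finds.

```python
import jax
import jax.numpy as jnp

def predict(w, b, x):
    # Hypothetical toy model: one linear unit applied to ONE example x of shape (3,)
    return jnp.dot(w, x) + b

# vmap maps predict over the leading axis of x only (None = don't map w or b);
# jit then compiles the batched function with XLA.
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, None, 0)))

w = jnp.array([1.0, 2.0, 3.0])
b = 0.5
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 examples

out = batched_predict(w, b, xs)  # returns [8.5, 26.5, 44.5, 62.5]
```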

Derrick Mwiti

Train ResNet in Flax from scratch (Distributed ResNet training)

Apart from designing custom CNN architectures, you can use architectures that have already been built. ResNet is one such popular architecture. These proven architectures often perform better than a network designed from scratch. In this article, you will learn how to perform distributed training of a ResNet model in Flax.
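Distributed (data-parallel) training in JAX is commonly built on jax.pmap: parameters are replicated on every device, each device gets a shard of the batch, and gradients are averaged across devices. The sketch below shows only that pattern, using a hypothetical toy linear model rather than ResNet; the article's own ResNet code may structure replication and sharding differently.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Toy mean-squared-error loss for a linear model (stand-in for ResNet)
    return jnp.mean((x @ w - y) ** 2)

def train_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients over all devices taking part in the pmap
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads  # plain SGD update

# pmap runs train_step once per device; axis_name names the device axis for pmean
p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
w = jnp.zeros((3,))
w_rep = jnp.stack([w] * n_dev)    # replicate parameters: leading axis = device
x = jnp.ones((n_dev, 8, 3))       # shard the batch: (device, per-device batch, features)
y = jnp.ones((n_dev, 8))

w_rep = p_train_step(w_rep, x, y)  # every replica ends up with identical weights
```

On a machine with a single CPU device, n_dev is 1 and the same code still runs, which makes it convenient to test before moving to multiple GPUs or TPU cores.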

Derrick Mwiti

Handling state in JAX & Flax (BatchNorm and Dropout layers)

JIT-compiling functions in Flax makes them faster but requires that the functions be pure, that is, free of side effects. This restriction makes it challenging to handle stateful items such as model parameters and stateful layers such as batch normalization layers. In this article, we'll create…

Derrick Mwiti

Transfer learning with JAX & Flax

Training large neural networks can take days or weeks. Once these networks are trained, you can reuse their weights on new tasks; this is transfer learning. Fine-tuning a pretrained network in this way can give good results in a short time. Let's look at how you…

Derrick Mwiti

Flax vs. TensorFlow

Flax is a neural network library for JAX. TensorFlow is a deep learning library with a large ecosystem of tools and resources. Flax and TensorFlow are alike in some ways and different in others. For instance, both can run on XLA. Let's look at the differences between Flax and…

Derrick Mwiti

Activation functions in JAX and Flax

Activation functions are applied in neural networks to introduce non-linearity and to shape each layer's output so that the network can produce the desired result. Some activation functions cap the output within a specific range. For instance, when solving a binary classification problem, the outcome should be a number between 0 and 1. This indicates the probability of an…
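For the binary classification case, the sigmoid function is the usual choice for squashing a raw score into (0, 1). A minimal sketch using jax.nn, with hypothetical example logits, contrasting it with ReLU, which is not capped above:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-4.0, 0.0, 4.0])  # hypothetical raw network outputs

# Sigmoid maps any real number into the open interval (0, 1)
probs = jax.nn.sigmoid(logits)

# Threshold at 0.5 to turn probabilities into class predictions
preds = (probs > 0.5).astype(jnp.int32)

# ReLU, by contrast, only clips negatives and has no upper bound
relu_out = jax.nn.relu(logits)
```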

Derrick Mwiti

Optimizers in JAX and Flax

Optimizers are applied when training neural networks to reduce the error between the true and predicted values. This optimization is done via gradient descent, which adjusts the network's parameters in the direction that reduces a cost (loss) function. In JAX, optimizers are typically provided by the Optax library. Optimizers can be classified into two…

Derrick Mwiti

Elegy (High-level API for deep learning in JAX & Flax)

Training deep learning networks in Flax takes several steps. It involves creating the following functions:

* Model definition.
* Compute metrics.
* Training state.
* Training step.
* Training and evaluation function.

Flax and JAX give you more control over defining and training deep learning networks. However, this comes with more verbosity.

Derrick Mwiti