Flax

Train ResNet in Flax from scratch (Distributed ResNet training)

Apart from designing custom CNN architectures, you can use architectures that have already been built. ResNet is one such popular architecture, and in most cases you'll achieve better performance by using it. In this article, you will learn how to perform distributed training of a ResNet model in Flax…
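
(The article builds the full pipeline; as a rough sketch of the core idea, with a toy linear model standing in for ResNet, distributed training in JAX replicates the parameters across devices, `pmap`s the train step, and keeps replicas in sync by averaging gradients with `pmean`.)

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across devices so every replica stays in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

n = jax.local_device_count()
params = jax.device_put_replicated({"w": jnp.zeros((3, 1))}, jax.local_devices())
batch = {"x": jnp.ones((n, 8, 3)), "y": jnp.ones((n, 8, 1))}  # leading axis = devices
params = train_step(params, batch)
```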

Derrick Mwiti
Flax

Handling state in JAX & Flax (BatchNorm and Dropout layers)

Jitting functions in Flax makes them faster but requires that the functions have no side effects. This restriction introduces a challenge when dealing with stateful items such as model parameters and stateful layers such as batch normalization. In this article, we'll create…
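
(As a minimal sketch of the pattern the article covers: a Flax module with `BatchNorm` and `Dropout`, where the running statistics are threaded through explicitly by declaring `batch_stats` mutable instead of mutating state inside a jitted function.)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(16)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)  # stateful layer
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(1)(x)

model = Net()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
# In training mode, declare batch_stats mutable so the updated running
# statistics are returned explicitly rather than mutated as a side effect.
y, updates = model.apply(
    variables, x, train=True,
    rngs={"dropout": jax.random.PRNGKey(1)},
    mutable=["batch_stats"],
)
```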

Derrick Mwiti
Flax

Transfer learning with JAX & Flax

Training large neural networks can take days or weeks. Once these networks are trained, you can take advantage of their weights and apply them to new tasks, a technique known as transfer learning. As a result, you fine-tune a new network and get good results in a short time. Let's look at how you…
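
(A minimal sketch of the freezing step, not the article's exact code: with Optax's `multi_transform` you can fine-tune a fresh head while the pretrained backbone receives zero updates. The `backbone`/`head` param tree here is hypothetical.)

```python
import jax.numpy as jnp
import optax

# Hypothetical param tree: pretrained backbone weights plus a fresh head.
params = {
    "backbone": {"kernel": jnp.ones((8, 8))},
    "head": {"kernel": jnp.zeros((8, 2))},
}
labels = {"backbone": {"kernel": "frozen"}, "head": {"kernel": "trainable"}}

# Only the head is updated; the backbone's updates are set to zero.
tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    labels,
)
opt_state = tx.init(params)
```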

Derrick Mwiti
Flax

Flax vs. TensorFlow

Flax is a neural network library for JAX. TensorFlow is a deep learning library with a large ecosystem of tools and resources. Flax and TensorFlow are similar in some ways but different in others. For instance, both can run on XLA. Let's look at the differences between Flax and…

Derrick Mwiti
Flax

Activation functions in JAX and Flax

Activation functions are applied in neural networks to ensure that the network outputs the desired result. Many activation functions cap the output within a specific range. For instance, when solving a binary classification problem, the outcome should be a number between 0 and 1. This indicates the probability of an…
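
(A quick illustration in JAX, with illustrative values: sigmoid squashes any real number into (0, 1), which is why it suits binary classification, while ReLU only caps the output from below.)

```python
import jax.numpy as jnp
from jax import nn

logits = jnp.array([-2.0, 0.0, 3.0])
probs = nn.sigmoid(logits)   # [0.119, 0.5, 0.953] -- always in (0, 1)
hidden = nn.relu(logits)     # [0.0, 0.0, 3.0]     -- only bounded below
```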

Derrick Mwiti
Flax

Optimizers in JAX and Flax

Optimizers are applied when training neural networks to reduce the error between the true and predicted values. This optimization is done via gradient descent, which updates the network's weights to minimize a cost function. In JAX, optimizers come from the Optax library. Optimizers can be classified into two…
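
(A minimal sketch of the Optax pattern with a toy loss: create a gradient transformation, initialise its state, then apply the updates it produces to the parameters.)

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    return jnp.mean((x @ params["w"] - y) ** 2)

tx = optax.sgd(learning_rate=0.1)  # or optax.adam(1e-3), optax.rmsprop(...)
params = {"w": jnp.zeros((3, 1))}
opt_state = tx.init(params)

x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
grads = jax.grad(loss_fn)(params, x, y)          # one gradient descent step
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```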

Derrick Mwiti
JAX

Elegy (High-level API for deep learning in JAX & Flax)

Training deep learning networks in Flax is done in a couple of steps: defining the model, computing metrics, creating the training state, writing the training step, and running the training and evaluation functions. Flax and JAX give you more control in defining and training deep learning networks, but this control comes with more verbosity.
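
(For context, this is the kind of boilerplate being abstracted away; a minimal sketch of the plain Flax setup with a hypothetical MLP and training step.)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(nn.relu(nn.Dense(32)(x)))

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]
# TrainState bundles the apply function, parameters, and optimizer state.
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3))

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["x"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["y"]).mean()
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)
```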

Derrick Mwiti
JAX

LSTM in JAX & Flax (Complete example with code and notebook)

LSTMs are a class of neural networks used to solve sequence problems such as time series and natural language processing. LSTMs maintain an internal state that is useful in solving these problems, and they loop over each time step of the sequence. We can use functions from JAX and…
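
(A minimal sketch of that per-time-step loop written directly in JAX with `lax.scan`, rather than the Flax recurrent modules the article uses; all names and shapes here are illustrative.)

```python
import jax
import jax.numpy as jnp

def lstm_step(carry, x_t, params):
    h, c = carry
    # One fused matmul produces the input, forget, cell, and output gates.
    z = jnp.concatenate([h, x_t]) @ params["W"] + params["b"]
    i, f, g, o = jnp.split(z, 4)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h

hidden, n_in = 16, 8
params = {
    "W": jax.random.normal(jax.random.PRNGKey(0), (hidden + n_in, 4 * hidden)) * 0.1,
    "b": jnp.zeros((4 * hidden,)),
}
xs = jnp.ones((20, n_in))  # 20 time steps
carry0 = (jnp.zeros(hidden), jnp.zeros(hidden))
# lax.scan compiles the loop over time steps instead of a Python for loop.
(h_final, c_final), outputs = jax.lax.scan(
    lambda c, x: lstm_step(c, x, params), carry0, xs)
```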

Derrick Mwiti
Flax