JAX

LSTM in JAX & Flax (Complete example with code and notebook)

LSTMs are a class of neural networks used to solve sequence problems in areas such as time series and natural language processing. They maintain an internal state that is useful in solving these problems, and they use a for loop to iterate over each time step. We can use functions from JAX and

Derrick Mwiti

How to use TensorBoard in JAX & Flax

Tracking machine learning experiments makes understanding and visualizing the model's performance easy. It also makes it possible to spot any problems in the network. For example, you can quickly spot overfitting by looking at the training and validation charts. You can plot these charts using your favorite charts

Derrick Mwiti

How to load datasets in JAX with TensorFlow

JAX doesn't ship with data loading utilities. This keeps JAX focused on being a fast tool for building and training machine learning models. Data is instead loaded using either TensorFlow or PyTorch. In the Image classification with JAX & Flax tutorial, we saw how to load

Derrick Mwiti

JAX loss functions

Loss functions are at the core of training machine learning models. They measure how well the model is performing on a dataset. Poor performance leads to a high loss, while a well-performing model has a lower loss. Therefore, the choice of a loss function is

Derrick Mwiti

JAX (What it is and how to use it in Python)

JAX is a Python library offering high performance in machine learning with XLA and Just In Time (JIT) compilation. Its API is similar to NumPy's, with a few differences. JAX ships with features that aim to speed up machine learning research. These features include: * Automatic

Derrick Mwiti

Distributed training with JAX & Flax

Training models on accelerators with JAX and Flax differs slightly from training on a CPU. For instance, the data needs to be replicated across the devices when using multiple accelerators. After that, we need to execute the training on multiple devices and aggregate the results. Flax supports TPU and GPU

Derrick Mwiti

Image classification with JAX & Flax

Flax is a neural network library for JAX. JAX is a Python library that provides high-performance computing for machine learning research. JAX provides an API similar to NumPy's, making it easy to adopt. JAX also includes other features for improving machine learning research. They include: * Automatic differentiation. JAX supports forward

Derrick Mwiti