JAX
LSTM in JAX & Flax (Complete example with code and notebook)
LSTMs are a class of neural networks used to solve sequence problems such as time series and natural language processing. LSTMs maintain an internal state that is useful in solving these problems, and they process a sequence in a loop, one time step at a time. We can use functions from JAX and…
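The looping idea can be sketched in plain JAX. This is a minimal single-layer LSTM written under our own assumptions, not the article's implementation: it uses `jax.lax.scan` in place of a Python for loop, and the helpers `lstm_step`, `lstm_forward`, and `init_params` (with packed weights named `W`, `U`, and `b`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, carry, x_t):
    """One LSTM time step. carry = (h, c); x_t has shape (input_dim,)."""
    h, c = carry
    # Compute all four gates at once: input, forget, candidate, output.
    z = x_t @ params["W"] + h @ params["U"] + params["b"]
    i, f, g, o = jnp.split(z, 4)
    i, f, o = jax.nn.sigmoid(i), jax.nn.sigmoid(f), jax.nn.sigmoid(o)
    g = jnp.tanh(g)
    c = f * c + i * g        # update the cell (internal) state
    h = o * jnp.tanh(c)      # new hidden state, also the step output
    return (h, c), h


def lstm_forward(params, xs, hidden_dim):
    """Run the LSTM over xs of shape (time, input_dim)."""
    init = (jnp.zeros(hidden_dim), jnp.zeros(hidden_dim))
    # lax.scan replaces the Python for loop over time steps.
    (h_last, _), hs = jax.lax.scan(
        lambda carry, x_t: lstm_step(params, carry, x_t), init, xs
    )
    return hs, h_last


def init_params(key, input_dim, hidden_dim):
    k1, k2 = jax.random.split(key)
    return {
        "W": jax.random.normal(k1, (input_dim, 4 * hidden_dim)) * 0.1,
        "U": jax.random.normal(k2, (hidden_dim, 4 * hidden_dim)) * 0.1,
        "b": jnp.zeros(4 * hidden_dim),
    }


params = init_params(jax.random.PRNGKey(0), input_dim=3, hidden_dim=8)
xs = jnp.ones((10, 3))  # dummy sequence: 10 time steps, 3 features each
hs, h_last = lstm_forward(params, xs, hidden_dim=8)
print(hs.shape, h_last.shape)  # (10, 8) (8,)
```

`lax.scan` carries `(h, c)` from one step to the next and stacks the per-step outputs, so the whole loop is traced once and can be compiled with `jax.jit`.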
How to use TensorBoard in JAX & Flax
Tracking machine learning experiments makes it easy to understand and visualize a model's performance. It also helps you spot problems in the network. For example, you can quickly spot overfitting by looking at the training and validation charts. You can plot these charts using your favorite charts…
How to load datasets in JAX with TensorFlow
JAX doesn't ship with data-loading utilities, which keeps it focused on being a fast tool for building and training machine learning models. Data for JAX is therefore usually loaded with TensorFlow or PyTorch. In the Image classification with JAX & Flax tutorial, we saw how to load…
JAX loss functions
Loss functions are at the core of training machine learning models. They measure how well a model performs on a dataset: a poorly performing model has a high loss, while a well-performing model has a lower loss. Therefore, the choice of a loss function is…
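To make the high-versus-low loss idea concrete, here is a sketch of two common losses written directly with `jax.numpy`. The function names are hypothetical; libraries such as Optax ship ready-made versions, but this sketch does not depend on them.

```python
import jax
import jax.numpy as jnp


def mean_squared_error(preds, targets):
    # Average squared difference; a common choice for regression.
    return jnp.mean((preds - targets) ** 2)


def softmax_cross_entropy(logits, labels):
    # labels are integer class ids; log_softmax is numerically stable.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


targets = jnp.array([1.0, 2.0, 3.0])
print(float(mean_squared_error(jnp.array([1.0, 2.0, 3.0]), targets)))  # 0.0
print(float(mean_squared_error(jnp.array([2.0, 3.0, 4.0]), targets)))  # 1.0

# A confident correct prediction vs. a confident wrong one (true class 0).
labels = jnp.array([0])
good = jnp.array([[4.0, 0.0, 0.0]])
bad = jnp.array([[0.0, 4.0, 0.0]])
print(bool(softmax_cross_entropy(good, labels) < softmax_cross_entropy(bad, labels)))  # True
```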
JAX (What it is and how to use it in Python)
JAX is a Python library for high-performance machine learning that uses XLA and just-in-time (JIT) compilation. Its API is similar to NumPy's, with a few differences. JAX ships with features that aim to speed up machine learning research. These features include: * Automatic…
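A small sketch of the NumPy-like API, one of its differences (JAX arrays are immutable), and JIT compilation. The `normalize` function is a hypothetical example, not taken from the article.

```python
import jax
import jax.numpy as jnp
import numpy as np

# jax.numpy mirrors most of the NumPy API.
x_np = np.arange(5.0)
x = jnp.arange(5.0)
print(np.allclose(x_np.sum(), x.sum()))  # True

# One difference: JAX arrays are immutable, so updates use the
# functional .at[...] syntax, which returns a new array.
y = x.at[0].set(10.0)
print(float(x[0]), float(y[0]))  # 0.0 10.0


# jax.jit compiles the function with XLA on the first call for a given
# input shape and dtype; later calls reuse the compiled version.
@jax.jit
def normalize(v):
    return (v - v.mean()) / v.std()


print(normalize(x))
```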
Distributed training with JAX & Flax
Training models on accelerators with JAX and Flax differs slightly from training on a CPU. For instance, when using multiple accelerators, the data has to be distributed across the devices. We then run the training on every device and aggregate the results. Flax supports TPU and GPU…
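One way to do this in plain JAX is `jax.pmap` with `jax.lax.pmean`, sketched below for a toy linear regression: the parameters get one copy per device, each device receives its own slice of the batch, and the gradients are averaged across devices before the update. This is an illustration under our own assumptions, not the article's Flax code; the model, learning rate, and names such as `train_step` are hypothetical. It also runs on a single CPU device.

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads


# Replicate the parameters: one copy per device along the leading axis.
w_repl = jnp.stack([jnp.zeros((3,))] * n_devices)

# Build a batch and shard it so each device gets 32 examples.
true_w = jnp.array([1.0, -2.0, 0.5])
x = jax.random.normal(jax.random.PRNGKey(0), (n_devices * 32, 3))
y = x @ true_w
x_sharded = x.reshape(n_devices, 32, 3)
y_sharded = y.reshape(n_devices, 32)

for _ in range(200):
    w_repl = train_step(w_repl, x_sharded, y_sharded)

print(w_repl[0])  # close to true_w after training
```

Because `pmean` makes the averaged gradient identical on every device, the replicas stay in sync without any extra communication step.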
Image classification with JAX & Flax
Flax is a neural network library for JAX. JAX is a Python library that provides high-performance computing for machine learning research. Its API is similar to NumPy's, which makes it easy to adopt. JAX also includes other features that help machine learning research. They include: * Automatic differentiation. JAX supports forward…
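Forward- and reverse-mode automatic differentiation can be compared on a toy function. In this sketch, `f` is a hypothetical example; `jax.jacfwd` and `jax.jvp` use forward mode, while `jax.grad` uses reverse mode.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A scalar function of a vector: f(x) = sum(sin(x) * x)
    return jnp.sum(jnp.sin(x) * x)


x = jnp.array([0.0, 1.0, 2.0])

g_rev = jax.grad(f)(x)    # reverse mode: full gradient in one backward pass
g_fwd = jax.jacfwd(f)(x)  # forward mode: one forward pass per input dimension

# jax.jvp computes a directional derivative in a single forward pass.
v = jnp.array([1.0, 0.0, 0.0])
_, dir_deriv = jax.jvp(f, (x,), (v,))

print(bool(jnp.allclose(g_rev, g_fwd)))  # True
```

Both results match the analytic derivative `sin(x) + x * cos(x)`. Reverse mode is usually preferred for losses, which map many parameters to a single scalar.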