JAX loss functions
Loss functions are at the core of training machine learning models. They measure how well a model is performing on a dataset: a poorly performing model has a high loss, while a well-performing model has a lower loss. Therefore, the choice of a loss function is an important one when building machine learning models.
In this article, we'll look at the loss functions available in JAX and how you can use them.
What is a loss function?
Machine learning models learn by comparing their predictions against the true values and adjusting their weights accordingly. The objective is to find the weights that minimize the loss function, that is, the error. The loss function is also referred to as the cost function. The choice of a loss function depends on the problem. The two most common types of problems are classification and regression, and each requires a different set of loss functions.
Creating custom loss functions in JAX
When training networks with JAX, you first compute the logits in the forward pass of the training step. These logits are used to compute the loss. You then evaluate the loss function and its gradient with respect to the model parameters, and use that gradient to update the parameters. At this point, you can also compute training metrics for the model.
Logits are unnormalized log probabilities.
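The flow above can be sketched as a single training step. This is a minimal sketch, not the article's original code: it assumes a hypothetical linear classifier (apply_model), a hand-written cross entropy loss built on jax.nn.log_softmax, and plain gradient descent with a hypothetical LEARNING_RATE. In practice you would usually plug in an optax optimizer instead.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.5  # hypothetical value for this toy example


def apply_model(params, x):
    # Hypothetical linear classifier: returns unnormalized logits.
    return x @ params["w"] + params["b"]


def cross_entropy_loss(params, x, labels):
    # The forward pass produces logits, which the loss consumes.
    logits = apply_model(params, x)
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    loss = -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))
    return loss, logits  # logits returned as auxiliary output for metrics


@jax.jit
def train_step(params, x, labels):
    # Evaluate the loss and its gradient with respect to params in one call.
    (loss, logits), grads = jax.value_and_grad(cross_entropy_loss, has_aux=True)(
        params, x, labels
    )
    # Plain gradient descent update applied to every parameter array.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    # Training metric computed from the same logits.
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return params, loss, accuracy


x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
labels = jnp.array([0, 1])
params = {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}

losses = []
for step in range(20):
    params, loss, accuracy = train_step(params, x, labels)
    losses.append(float(loss))
```

Using jax.value_and_grad with has_aux=True returns the loss, the logits, and the gradients from one call, so the same forward pass feeds both the parameter update and the accuracy metric.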
You can use JAX functions such as log_sigmoid and log_softmax to build custom loss functions. You can even write your loss functions from scratch without using these functions. Here is an example of computing the sigmoid binary cross entropy loss.
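A minimal sketch of such a custom loss, assuming the labels are 0/1 values (or probabilities) with the same shape as the logits. The function name sigmoid_binary_cross_entropy here is our own choice; it is built only from jax.nn.log_sigmoid and is not the optax function of the same name.

```python
import jax
import jax.numpy as jnp


def sigmoid_binary_cross_entropy(logits, labels):
    """Element-wise binary cross entropy computed directly from logits.

    log_sigmoid(x) is log(sigmoid(x)), and log_sigmoid(-x) equals
    log(1 - sigmoid(x)); computing both from logits avoids taking the log
    of a probability that has rounded to 0 or 1.
    """
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -labels * log_p - (1.0 - labels) * log_not_p


logits = jnp.array([0.0, 2.0, -3.0])
labels = jnp.array([1.0, 1.0, 0.0])

per_example = sigmoid_binary_cross_entropy(logits, labels)
mean_loss = per_example.mean()  # reduce to a scalar for training
```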
Which loss functions are available in JAX?
Building custom loss functions for your networks can introduce errors into your program, and you also take on the burden of maintaining them. If the loss function you want is unavailable, there is a strong case for writing a custom one. Otherwise, there is no need to reinvent the wheel by rewriting loss functions that are already implemented.
JAX itself doesn't ship with loss functions. Instead, we use optax, a gradient processing and optimization library for JAX, which provides common loss functions. It's important to use JAX-compatible libraries so that you can take advantage of transformations such as jit, vmap, and pmap, which make your programs faster.
Let's take a look at some of the loss functions available in optax.
Sigmoid binary cross entropy
The sigmoid binary cross entropy loss is computed using optax.sigmoid_binary_cross_entropy. The function expects logits and labels of the same shape, with label values between 0 and 1. It is used in problems where the classes are not mutually exclusive. For example, in an image classification problem, the model can predict that an image contains two different objects.
Softmax cross entropy
The softmax cross entropy loss, computed with optax.softmax_cross_entropy, is used where the classes are mutually exclusive. For example, in the MNIST dataset, each digit image has exactly one label. The function expects an array of logits and an array of labels given as probability distributions over the classes (for example, one-hot vectors), each summing to 1.
Cosine distance
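optax.cosine_distance measures the cosine distance between predictions and targets, defined as 1 minus their cosine similarity. It ranges from 0 for vectors pointing in the same direction to 2 for vectors pointing in opposite directions.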
Cosine similarity
The cosine similarity loss measures the cosine similarity between the true and predicted values. The cosine similarity is the cosine of the angle between two vectors. It is obtained by dividing the dot product of the vectors by the product of their lengths.
The result is a number between -1 and 1: 1 means the vectors point in the same direction, 0 indicates orthogonality, and -1 means they point in opposite directions. When cosine similarity is used as a loss to be minimized, it is typically negated, so that values closer to -1 indicate greater similarity.
Huber loss
The Huber loss is used for regression problems. It is less sensitive to outliers than the squared error loss because it is quadratic for small errors and only linear for errors larger than a threshold, delta. A variant of the Huber loss also exists for classification problems.
L2 loss
The L2 loss, also known as the least squares error, aims at minimizing the sum of the squared differences between the true and predicted values. The Mean Squared Error (MSE) is the mean of these squared differences. Note that optax.l2_loss includes a factor of 1/2, so averaging its output gives half the MSE.
log cosh
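optax.log_cosh computes the logarithm of the hyperbolic cosine of the prediction error. It behaves like half the squared error for small errors and like the absolute error minus log 2 for large errors, which makes it less sensitive to outliers than the squared error.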
Smooth labels
optax.smooth_labels is used together with a cross-entropy loss to smooth labels. It returns a smoothed version of the one-hot input labels, computed as (1 - alpha) * labels + alpha / num_classes, where alpha is the smoothing factor. Label smoothing has been applied in image classification, language translation, and speech recognition to prevent models from becoming overconfident.
Computing loss with JAX Metrics
JAX Metrics is an open-source package for computing losses and metrics in JAX. It provides a Keras-like API for computing model loss and metrics.
For example, here is how you use the library to compute the cross-entropy loss. Similar to Keras, the losses can be computed either by instantiating a Loss class or by calling the corresponding loss function.
Here is what the code would look like in a JAX training step.
The losses we have seen earlier can also be computed using JAX Metrics.
How to monitor JAX loss functions
Monitoring the loss of your network is important because it indicates whether the network is learning. A glance at the loss can reveal problems such as overfitting, where the validation loss starts rising while the training loss keeps falling. One way to monitor the loss is to print the training and validation loss as the network trains.
You can also plot the training and validation loss to represent the training visually.
Why NaN losses happen in JAX
JAX does not raise errors when NaNs occur in your program. This is by design, because of the complexity involved in reporting errors from accelerators. When debugging, you can turn on the NaN checker (the jax_debug_nans configuration option) to raise an error where a NaN is produced. NaNs should be fixed, because the network stops learning once they occur.
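A minimal sketch of turning the NaN checker on through the jax_debug_nans configuration option (setting the environment variable JAX_DEBUG_NANS=True has the same effect). The function log_term is hypothetical; it produces a NaN by taking the log of a negative number. By default the NaN propagates silently; with the checker on, JAX raises a FloatingPointError instead.

```python
import jax
import jax.numpy as jnp

# Default behavior: the NaN is returned silently, no error is raised.
silent_result = jnp.log(-1.0)


@jax.jit
def log_term(x):
    # Hypothetical loss term that yields NaN for negative inputs.
    return jnp.log(x)


# Turn on the NaN checker while debugging.
jax.config.update("jax_debug_nans", True)
try:
    log_term(-1.0)
    nan_detected = False
except FloatingPointError as err:
    nan_detected = True
    print("NaN checker raised:", err)
finally:
    # Switch it back off afterwards: the extra checks slow execution down.
    jax.config.update("jax_debug_nans", False)
```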
So what produces NaNs in a network? Common causes include, but are not limited to:
- The dataset has not been scaled.
- There are NaNs in the training set.
- There are infinite values in the training data.
- An unsuitable optimizer was chosen.
- Exploding gradients lead to large updates to the weights.
- The learning rate is very large.
Final thoughts
In this article, we have seen that choosing the right loss function is critical to how well a network learns. We have also discussed various loss functions available in JAX. More precisely, we have covered:
- What is a loss function?
- How to create custom loss functions in JAX.
- Loss functions available in JAX.
- Computing loss with JAX Metrics.
- Monitoring loss in JAX.
- How to avoid NaNs in JAX.