JAX loss functions
Image by Gerd Altmann from Pixabay

Derrick Mwiti

Loss functions are at the core of training machine learning models. They measure how well a model is performing on a dataset: a poorly performing model has a high loss, while a well-performing model has a lower loss. Therefore, the choice of loss function is an important decision when building machine learning models.

In this article, we'll look at the loss functions available for JAX and how you can use them.

What is a loss function?

Machine learning models learn by evaluating their predictions against the true values and adjusting their weights accordingly. The objective is to find the weights that minimize the loss function, that is, the error. The loss function is also referred to as the cost function. The right loss function depends on the problem. The two most common types of problems are classification and regression, and each calls for a different set of loss functions.

Creating custom loss functions in JAX

When training networks with JAX, you first obtain the logits from the model during the training step. These logits are used to compute the loss. You then evaluate the loss function and its gradient with respect to the model parameters, and use the gradient to update those parameters. At this point, you can also compute training metrics for the model.

What are logits?
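Logits are the raw, unnormalized outputs of a model before an activation such as sigmoid or softmax is applied. They can be interpreted as unnormalized log probabilities.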
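To make "unnormalized" concrete, here is a small sketch with made-up logits for one sample and three classes. `jax.nn.log_softmax` normalizes them into log probabilities, and shifting every logit by the same constant leaves the resulting probabilities unchanged.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up logits for one sample and three classes.
logits = jnp.array([2.0, 1.0, -1.0])

# log_softmax normalizes logits into log probabilities.
log_probs = jax.nn.log_softmax(logits)
# The same normalization written out: z - logsumexp(z).
manual_log_probs = logits - logsumexp(logits)

probs = jnp.exp(log_probs)  # equivalent to jax.nn.softmax(logits)

# Adding a constant to every logit does not change the probabilities,
# which is why logits are called "unnormalized".
shifted_probs = jax.nn.softmax(logits + 10.0)

print(probs, probs.sum())
```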

You can use JAX functions such as jax.nn.log_sigmoid and jax.nn.log_softmax to build custom loss functions. You can even write your loss functions from scratch without using these functions. Here is an example of computing the sigmoid binary cross entropy loss.
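A minimal sketch of such a custom loss, assuming raw logits and binary (0 or 1) labels as inputs; the name `custom_sigmoid_bce` is hypothetical. It relies on the identity log(1 - sigmoid(z)) = log_sigmoid(-z), so it never takes the log of a saturated sigmoid directly.

```python
import jax
import jax.numpy as jnp


def custom_sigmoid_bce(logits, labels):
    # log(sigmoid(z)) computed stably from the raw logits.
    log_p = jax.nn.log_sigmoid(logits)
    # log(1 - sigmoid(z)) == log(sigmoid(-z)), also computed stably.
    log_not_p = jax.nn.log_sigmoid(-logits)
    # Element-wise binary cross entropy.
    return -(labels * log_p + (1.0 - labels) * log_not_p)


logits = jnp.array([0.0, 3.0, -2.0, 50.0])
labels = jnp.array([1.0, 1.0, 0.0, 0.0])
per_example = custom_sigmoid_bce(logits, labels)
loss = per_example.mean()  # reduce to a scalar for training
```

The last example (a large logit with label 0) stays finite at about 50 instead of overflowing, which is the main reason to work in log space.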

Which loss functions are available in JAX?

Building custom loss functions for your networks can introduce errors into your program, and you also take on the burden of maintaining them. If the loss function you need is not available, there is a strong case for writing a custom one. Otherwise, there is no need to reinvent the wheel by rewriting loss functions that are already implemented.

JAX doesn't ship with ready-made loss functions. Instead, we use optax, a gradient processing and optimization library for JAX, which implements common losses. It's important to use JAX-compatible libraries so that you can take advantage of JAX transformations such as jit, vmap, and pmap, which compile, vectorize, and parallelize your programs.

Let's take a look at some of the loss functions available in optax.

Sigmoid binary cross entropy

The sigmoid binary cross entropy loss is computed using optax.sigmoid_binary_cross_entropy. The function expects logits and class labels and returns one loss value per element. It is used in problems where the classes are not mutually exclusive. For example, in a multi-label image classification problem, the model can predict that an image contains two different objects.

Softmax cross entropy

The softmax cross entropy loss, computed with optax.softmax_cross_entropy, is used when the classes are mutually exclusive. For example, in the MNIST dataset, each image of a digit has exactly one label. The function expects an array of logits and an array of labels given as probability distributions over the classes (for example, one-hot vectors), each of which sums to 1.

Cosine distance

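optax.cosine_distance measures the dissimilarity between predictions and targets, defined as 1 minus their cosine similarity. It ranges from 0 for vectors pointing in the same direction to 2 for vectors pointing in opposite directions.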

Cosine similarity

The cosine similarity loss measures the cosine similarity between the true and predicted values. The cosine similarity is the cosine of the angle between two vectors. This is obtained by the dot product of the vectors divided by the product of their lengths.

The result is a number between -1 and 1. A value of 1 means the vectors point in the same direction (high similarity), 0 indicates orthogonality, and values close to -1 mean the vectors point in opposite directions (high dissimilarity).

Huber loss

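The Huber loss, available as optax.huber_loss, is used for regression problems. It is quadratic for small errors and linear for large ones, with the switch point set by the delta parameter, which makes it less sensitive to outliers than the squared error loss. A variant of the Huber loss also exists for classification problems.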

L2 loss

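The L2 loss, also called the least squares error, is based on the squared differences between the true and predicted values, and training aims to minimize their sum. The mean squared error (MSE) is the mean of these squared differences. Note that optax.l2_loss includes a factor of 0.5 and returns 0.5 * (prediction - target) ** 2 for each element.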

Log cosh

optax.log_cosh computes the logarithm of the hyperbolic cosine of the prediction error.

"log(cosh(x)) is approximately equal to (x ** 2) / 2 for small x and to abs(x) - log(2) for large x. This means that 'logcosh' works mostly like the mean squared error, but will not be so strongly affected by the occasional wildly incorrect prediction." (TensorFlow Docs)

Smooth labels

optax.smooth_labels is used together with a cross-entropy loss to smooth labels. It returns a smoothed version of the one-hot input labels by mixing them with a uniform distribution over the classes: smoothed = (1 - alpha) * labels + alpha / num_classes. Label smoothing has been applied in image classification, language translation, and speech recognition to prevent models from becoming overconfident.

Computing loss with JAX Metrics

JAX Metrics is an open-source package for computing losses and metrics in JAX. It provides a Keras-like API for computing model loss and metrics.

For example, here is how you can use the library to compute the cross-entropy loss. Similar to Keras, a loss can be computed either by instantiating a Loss class or by calling the corresponding loss function.

Here is what the code would look like in a JAX training step.

The losses we have seen earlier can also be computed using JAX Metrics.

How to monitor JAX loss functions

Monitoring the loss of your network is important because it indicates whether the network is learning. A glance at the loss can also reveal problems such as overfitting: a validation loss that starts rising while the training loss keeps falling is a typical sign. One way to monitor the loss is to print the training and validation loss as the network trains.
