# Activation functions in JAX and Flax

Activation functions are applied in neural networks to introduce non-linearity and to shape what each layer outputs. Some activation functions also cap the output within a specific range. For instance, when solving a binary classification problem, the outcome should be a number between 0 and 1, which indicates the probability that an item belongs to the positive class. In a regression problem, however, you want an unbounded numerical prediction of a quantity, for example, the price of an item. You should, therefore, choose an activation function that is appropriate for the problem being solved.

Let's look at common activation functions in JAX and Flax.

### ReLU – Rectified linear unit

The **ReLU activation function** is primarily used in the hidden layers of neural networks to introduce non-linearity. The function clamps outputs at zero: inputs below zero are returned as zero, while inputs above zero are returned as they are. As a result, a layer followed by ReLU never emits negative activations.

In the network below, we apply the ReLU activation function after each convolution layer and after the first dense layer.

```python
from flax import linen as nn


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        # First convolution block: conv -> ReLU -> average pooling
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Second convolution block
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch dimension
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        # Two output classes, returned as log-probabilities
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)
        return x
```
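To see the clamping behaviour in isolation, here is a minimal sketch that applies ReLU to a few hand-picked values. It calls `jax.nn.relu`, which `flax.linen` re-exports as `nn.relu`; the input array is made up for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs: two negatives, zero, and one positive value.
x = jnp.array([-2.0, -0.5, 0.0, 1.5])

# jax.nn.relu is the same function that flax.linen exposes as nn.relu.
y = jax.nn.relu(x)
print(y)  # negatives become 0, the positive value is unchanged
```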

### PReLU – Parametric rectified linear unit

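**Parametric Rectified Linear Unit** (PReLU) is ReLU with a learnable parameter *a* that sets the slope for negative inputs. In the original formulation there is one such parameter per channel. Positive inputs are returned unchanged and negative inputs are multiplied by *a*, so, unlike ReLU, PReLU lets negative values through.

In Flax, `PReLU` is a module rather than a plain function, so instantiate it before applying it:

`x = nn.PReLU()(x)`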

### Sigmoid

The **sigmoid activation function** caps output to a number between 0 and 1 and is mainly used for binary classification tasks. It is also used for multi-label problems, where the classes are non-exclusive. For example, an image can have a car, a building, a tree, etc. Just because there is a car in the image doesn't mean a tree can't be in the picture. Use the sigmoid function when there can be more than one correct answer.

`x = nn.sigmoid(x)`
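The multi-label case described above can be sketched as follows. The logits and the 0.5 decision threshold are hypothetical; `jax.nn.sigmoid` is the function that `flax.linen` re-exports as `nn.sigmoid`.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits for three non-exclusive labels: car, building, tree.
logits = jnp.array([2.0, -1.0, 0.0])

# Each logit is squashed independently, so the probabilities need not sum to 1.
probs = jax.nn.sigmoid(logits)

# A label counts as present when its probability exceeds 0.5 (illustrative threshold).
present = probs > 0.5
print(probs, present)
```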

### Log sigmoid

**Log sigmoid** computes the log of the sigmoid activation, and its output is within the range of −∞ to 0.

`x = nn.log_sigmoid(x)`
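A short check of the stated range, assuming a handful of made-up inputs. For moderate values the result matches taking the log of the sigmoid by hand, and for a very negative input the fused function still returns a finite number. The sketch uses `jax.nn.log_sigmoid`, which `flax.linen` re-exports as `nn.log_sigmoid`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, 3.0])

log_sig = jax.nn.log_sigmoid(x)
# For moderate inputs this agrees with composing log and sigmoid by hand.
naive = jnp.log(jax.nn.sigmoid(x))

# For a very negative input the fused version still returns a finite value.
far_left = jax.nn.log_sigmoid(jnp.array(-200.0))
print(log_sig, far_left)
```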

### Softmax

The **softmax activation function** generalizes the sigmoid function to multi-class problems where labels are mutually exclusive. It turns a vector of scores into probabilities that sum to 1. For example, a picture is either grayscale or color. Use the softmax activation when there is only one correct answer.

`x = nn.softmax(x)`
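As a small illustration of the single-correct-answer case, the sketch below turns hypothetical logits into class probabilities and picks the most likely class. It uses `jax.nn.softmax`, which `flax.linen` re-exports as `nn.softmax`.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits for one example over three mutually exclusive classes.
logits = jnp.array([1.0, 2.0, 0.5])

# Normalise over the class axis so the probabilities sum to 1.
probs = jax.nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs)
print(probs, probs.sum(), predicted)
```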

### Log softmax

**Log softmax** computes the logarithm of the softmax function, which rescales elements to the range −∞ to 0.

`x = nn.log_softmax(x)`
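Log-probabilities are what a negative log-likelihood loss consumes, which is one reason the CNN earlier in this article ends with `nn.log_softmax`. The sketch below assumes made-up logits and a hypothetical label index, and uses `jax.nn.log_softmax` (re-exported by `flax.linen`).

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 0.5])
label = 1  # hypothetical index of the correct class

log_probs = jax.nn.log_softmax(logits, axis=-1)
# Negative log-likelihood of the correct class.
loss = -log_probs[label]
print(log_probs, loss)
```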

### ELU – Exponential linear unit activation

The **ELU activation function** helps with the vanishing and exploding gradients problem. Unlike ReLU, ELU allows negative outputs, which pushes the mean unit activation closer to zero. ELUs may lead to faster training and better generalization in networks with more than five layers.

Values above zero are returned as is, while values below zero are mapped to *α(eˣ − 1)*: a negative number that is close to zero for small negative inputs and saturates towards *−α* as the input becomes more negative.

`x = nn.elu(x)`
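The piecewise definition can be checked against the library function. This is a minimal sketch with made-up inputs, using `jax.nn.elu` (re-exported by `flax.linen` as `nn.elu`) and its default `alpha` of 1.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-5.0, -1.0, 0.0, 2.0])
alpha = 1.0  # the default in jax.nn.elu

y = jax.nn.elu(x, alpha=alpha)
# The same piecewise definition written out by hand.
manual = jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1.0))
print(y)  # negative inputs stay above -alpha, positive inputs are unchanged
```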

### CELU – Continuously-differentiable exponential linear unit

**CELU** is a variant of ELU that remains continuously differentiable for any value of *α*.

`x = nn.celu(x)`
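The difference from ELU is that the exponent is divided by `alpha`. The sketch below writes that out by hand and compares it with `jax.nn.celu` (re-exported by `flax.linen`); the inputs and `alpha = 2.0` are illustrative choices.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-4.0, -1.0, 0.0, 2.0])
alpha = 2.0  # illustrative value

y = jax.nn.celu(x, alpha=alpha)
# CELU written out by hand: note the x / alpha inside the exponential.
manual = jnp.where(x > 0, x, alpha * (jnp.exp(x / alpha) - 1.0))

# With alpha = 1 the two functions coincide.
same_at_one = jnp.allclose(jax.nn.celu(x, alpha=1.0), jax.nn.elu(x, alpha=1.0))
print(y, same_at_one)
```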

### GELU – Gaussian error linear unit activation

The **GELU** non-linearity weights inputs by their value, rather than gating inputs by their sign as ReLU does.

`x = nn.gelu(x)`
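"Weighting by value" means each input is multiplied by the standard normal CDF evaluated at that input. The sketch below compares that definition with `jax.nn.gelu` (re-exported by `flax.linen`). It passes `approximate=False` to request the erf-based form; the inputs are made up.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# approximate=False requests the erf-based form rather than the tanh approximation.
y = jax.nn.gelu(x, approximate=False)
# x * Phi(x), where Phi is the standard normal CDF.
manual = x * 0.5 * (1.0 + erf(x / jnp.sqrt(2.0)))
print(y)  # small negative inputs give small negative outputs instead of 0
```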

### GLU – Gated linear unit activation

**GLU** is computed as *GLU(a, b) = a ⊗ σ(b)*, where ⊗ is element-wise multiplication and σ is the sigmoid function. It has been applied in Gated CNNs for natural language processing. In the formula, the gate *b* controls what information is passed to the next layer. GLU helps tackle the vanishing gradient problem.

`x = nn.glu(x)`
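In practice `a` and `b` are the two halves of the input along one axis, so the output has half as many features as the input. A minimal sketch with made-up values, using `jax.nn.glu` (re-exported by `flax.linen` as `nn.glu`):

```python
import jax
import jax.numpy as jnp

# One example with four features: the first half is `a`, the second half is the gate `b`.
x = jnp.array([[1.0, 2.0, 0.0, 10.0]])

y = jax.nn.glu(x, axis=-1)

# The same computation written out by hand.
a, b = jnp.split(x, 2, axis=-1)
manual = a * jax.nn.sigmoid(b)
print(y.shape)  # the feature dimension is halved
```

Because the split is even, the size of the chosen axis has to be divisible by two.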

### Soft sign

The **Soft sign** activation function caps values between -1 and 1. It is similar to the hyperbolic tangent activation function, tanh. The difference is that tanh converges to its limits exponentially while Soft sign converges polynomially.

`x = nn.soft_sign(x)`
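The slower convergence is easy to see numerically. The sketch below compares the two functions on a few made-up inputs, using `jax.nn.soft_sign` (re-exported by `flax.linen`).

```python
import jax
import jax.numpy as jnp

x = jnp.array([-5.0, -1.0, 0.0, 1.0, 5.0])

soft = jax.nn.soft_sign(x)  # x / (1 + |x|)
tanh = jnp.tanh(x)
print(soft)
print(tanh)  # much closer to -1 and 1 at the ends than soft sign
```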

### Softplus

The **Softplus activation** returns values strictly above zero. It is a smooth approximation of ReLU.

`x = nn.softplus(x)`
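A quick side-by-side with ReLU, as a sketch on made-up inputs. It uses `jax.nn.softplus`, which `flax.linen` re-exports as `nn.softplus`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

smooth = jax.nn.softplus(x)  # log(1 + exp(x))
hard = jax.nn.relu(x)
print(smooth)  # always positive, and close to ReLU far from zero
print(hard)
```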

### Swish – Sigmoid linear unit (SiLU)

The Swish activation function is computed as `x * sigmoid(beta * x)`, where `beta` is a hyperparameter. SiLU is the special case `beta = 1`, so it is computed by multiplying the input by its sigmoid. This is what `nn.swish` (also available as `nn.silu`) computes.

`x = nn.swish(x)`
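The library function has no `beta` argument, so a general Swish has to be written by hand. The helper `swish_beta` below is a hypothetical name introduced for this sketch; `jax.nn.swish` and `jax.nn.silu` are the functions that `flax.linen` re-exports.

```python
import jax
import jax.numpy as jnp


def swish_beta(x, beta=1.0):
    """Hypothetical helper: Swish with an explicit beta, x * sigmoid(beta * x)."""
    return x * jax.nn.sigmoid(beta * x)


x = jnp.array([-2.0, 0.0, 2.0])

builtin = jax.nn.swish(x)          # fixed beta = 1, i.e. SiLU
same = swish_beta(x, beta=1.0)
sharper = swish_beta(x, beta=5.0)  # larger beta moves the curve towards ReLU
print(builtin, sharper)
```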

### Custom activation functions in JAX and Flax

You can also define custom activation functions in JAX and Flax. For example, the LeakyReLU activation function can be written as a class that implements the `__call__` method.

## Final thoughts

You have learned about the various activation functions you can use in JAX and Flax. You have also seen that you can create new functions by creating a class that implements the `__call__` method.

Want to dive deeper into JAX and Flax? Here are some more resources from our blog:

- What is JAX?
- Flax vs. TensorFlow
- JAX loss functions
- Optimizers in JAX and Flax
- How to load datasets in JAX using TensorFlow
- Building Convolutional Neural Networks in JAX and Flax
- Distributed training in JAX
- Using TensorBoard in JAX and Flax
- LSTM in JAX & Flax
- Elegy (High-level API for deep learning in JAX & Flax)
