Training models on accelerators with JAX and Flax differs slightly from training on a CPU. For instance, when using multiple accelerators, the training state needs to be replicated across the devices and each batch of data split among them. The training step then runs on every device in parallel, and the results are aggregated. Flax supports both TPUs and GPUs.
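The replicate, run, and aggregate pattern can be sketched with `jax.pmap`. This is a minimal illustration rather than a full Flax training loop: the linear model, the `loss_fn` and `train_step` names, the learning rate, and the synthetic data are all assumptions made for the example. It also runs on a machine with a single device, in which case the device axis simply has size one.

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# Hypothetical toy model: a linear regressor with three inputs.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}


def loss_fn(params, x, y):
    # Mean squared error of the linear model on one device's shard.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own shard.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Aggregate: average gradients and loss across all devices.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain gradient descent with an assumed learning rate of 0.1.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


# Replicate: give every parameter a leading axis of size n_devices.
replicated = jax.tree_util.tree_map(
    lambda p: jnp.stack([p] * n_devices), params
)

# Synthetic data, reshaped so the leading axis indexes the device.
per_device = 8
x = jax.random.normal(jax.random.PRNGKey(0), (n_devices * per_device, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3
x_sharded = x.reshape(n_devices, per_device, 3)
y_sharded = y.reshape(n_devices, per_device)

losses = []
for _ in range(20):
    replicated, loss = train_step(replicated, x_sharded, y_sharded)
    losses.append(float(loss[0]))  # identical on every device after pmean

# Take one copy of the parameters back out of the replicated tree.
final_params = jax.tree_util.tree_map(lambda p: p[0], replicated)
```

Because the gradients are averaged with `jax.lax.pmean` before the update, every device applies the same update and the replicated parameters stay in sync. In a Flax project the stacking and unstacking steps are usually handled by helpers such as `flax.jax_utils.replicate` and `flax.jax_utils.unreplicate`.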