docs/xla_flags.md
This guide gives a brief overview of XLA and how XLA relates to JAX. For in-depth details, please refer to the XLA documentation.
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in JAX's performance and flexibility. It enables JAX to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions.
JAX uses XLA's just-in-time (JIT) compilation capabilities to transform your Python functions into optimized XLA computations at runtime.
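To make this concrete, the sketch below wraps a small function with `jax.jit` and uses JAX's ahead-of-time API (`.lower(...)` and `.compile()`) to inspect the computation that JAX hands to XLA before running it. The function `scaled_sum` is a hypothetical example, and the exact text of the lowered module varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x, y):
    # Ordinary NumPy-style code; JAX traces it into an XLA computation.
    return jnp.sum(2.0 * x + y)

# jax.jit defers compilation until shapes and dtypes of the inputs are known.
fast_scaled_sum = jax.jit(scaled_sum)

x = jnp.arange(4.0)
y = jnp.ones(4)

# Lower to the intermediate representation passed to XLA (StableHLO text).
lowered = fast_scaled_sum.lower(x, y)
print(lowered.as_text())

# Ask XLA to compile it for the current backend, then run the executable.
compiled = lowered.compile()
result = compiled(x, y)
print(result)
```

In everyday use you simply call `fast_scaled_sum(x, y)`; JAX performs the lowering and XLA compilation on the first call and caches the result for later calls with the same input shapes and dtypes.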
You can influence XLA's behavior in JAX by setting the XLA_FLAGS environment variable before running your Python script or Colab notebook.
For Colab notebooks:
Provide flags using os.environ['XLA_FLAGS']:
```python
import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
```
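As a concrete sketch, the snippet below uses `--xla_force_host_platform_device_count`, an XLA flag that makes the CPU backend expose several virtual devices (useful for trying out multi-device code on a single machine). It appends to any flags that are already set instead of overwriting them, and prints a warning if JAX has already been imported, since flags set afterwards may be ignored. The device count of 8 is an arbitrary choice for illustration.

```python
import os
import sys

# XLA reads XLA_FLAGS when the backend initializes, so set it before importing JAX.
if "jax" in sys.modules:
    print("Warning: JAX is already imported; new XLA_FLAGS may have no effect.")

# Keep any flags that were already set and add ours, separated by a space.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (existing + " --xla_force_host_platform_device_count=8").strip()

import jax  # imported only after XLA_FLAGS is in place

# The CPU backend should now report 8 devices.
print(jax.devices("cpu"))
```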
For Python scripts:
Specify XLA_FLAGS as part of the CLI command:
```shell
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
```
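The same per-process pattern can be reproduced from Python, for example in a launcher that starts JAX workloads as subprocesses. The sketch below sets XLA_FLAGS only in the child's environment and uses `--xla_dump_to`, an XLA flag that writes debugging output (the HLO modules XLA compiles) into a directory. The helper `run_with_xla_flags` and the inline `CHILD_SCRIPT` are hypothetical stand-ins for your own launcher and `source.py`; the names of the dumped files depend on the XLA version.

```python
import os
import subprocess
import sys
import tempfile

# A tiny JAX program that triggers an XLA compilation (stand-in for source.py).
CHILD_SCRIPT = """
import jax
import jax.numpy as jnp
print(jax.jit(lambda x: x * 2 + 1)(jnp.arange(3.0)))
"""

def run_with_xla_flags(flags: str, script: str) -> subprocess.CompletedProcess:
    # Copy the current environment and set XLA_FLAGS only for the child process,
    # mirroring `XLA_FLAGS='...' python3 source.py` on the command line.
    env = dict(os.environ, XLA_FLAGS=flags)
    return subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )

dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
proc = run_with_xla_flags(f"--xla_dump_to={dump_dir}", CHILD_SCRIPT)
print(proc.stdout)

# XLA writes the modules it compiled into dump_dir; show the first few files.
print(sorted(os.listdir(dump_dir))[:5])
```

Because the variable is set on the child's environment, the parent process and any JAX backend it has already initialized are unaffected.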
Important Notes:
Set XLA_FLAGS before importing JAX or other relevant libraries. Changing XLA_FLAGS after backend initialization has no effect, and because the exact point at which the backend initializes is not clearly defined, it is safest to set XLA_FLAGS before executing any JAX code.
For further information:
Complete and up-to-date documentation about XLA can be found in the official XLA documentation.
For backends supported by the open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in xla/debug_options_flags.cc, and a complete list of flags can be found here.
A guide on how to use key XLA flags can be found here.
Additional reading: