set_xla_metadata

Summary: set_xla_metadata allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as frontend_attributes and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger.
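You can use it in three ways: by wrapping an individual value, as a context manager around a block of code, or as a decorator on a function.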
Warning: set_xla_metadata is an experimental feature and its API is subject to change.
When JAX transforms and compiles your code, it ultimately generates an XLA (Accelerated Linear Algebra) computation graph. Each operation in this graph can have associated metadata, specifically frontend_attributes. This metadata doesn't change the numerical result of the operation, but it can be used to signal special behavior to the compiler or runtime.
set_xla_metadata provides a way to attach this metadata directly from your JAX code. This is a powerful feature for low-level debugging and profiling.
Tagging an individual operation gives you precise control over which parts of your computation you want to inspect. To do this, wrap the output (value) of an operation with set_xla_metadata. If the wrapped value comes from a function that contains multiple operations, only the final operation of that function is tagged.
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging an individual operation
def value_tagging(x):
    y = jnp.sin(x)
    z = jnp.cos(x)
    return set_xla_metadata(y * z, breakpoint=True)

print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))
Results in:
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  cos.3 = f32[] cosine(x.1)
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}
If you want to apply the same metadata to a larger section of code, you can use set_xla_metadata as a context manager. All JAX operations within the with block will have the specified metadata attached.
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging a block of code
def context_tagging(x):
    with set_xla_metadata(_xla_log=True):
        y = jnp.sin(x)
        z = jnp.cos(y)
        return y * z

print(jax.jit(context_tagging).lower(1.0).as_text("hlo"))
Results in:
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
  cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}
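When a block tags many operations, reading the HLO by eye gets tedious. The following is a minimal sketch of a helper that lists which instructions carry frontend attributes. `tagged_ops` and the regular expressions are hypothetical, not part of JAX, and they assume the simple one-instruction-per-line text format of the sample output, which is hardcoded here so the snippet runs without a device.

```python
import re

# Sample HLO text in the format of the context-manager output.
HLO_TEXT = '''
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
  cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}
'''

# Matches "<name> = ... frontend_attributes={...}" on a single HLO line.
_LINE_RE = re.compile(r'^\s*(?:ROOT\s+)?(\S+)\s*=.*frontend_attributes=\{([^}]*)\}')
# Matches individual key="value" pairs inside the braces.
_ATTR_RE = re.compile(r'(\w+)="([^"]*)"')

def tagged_ops(hlo_text):
    """Map each tagged instruction name to a dict of its frontend attributes."""
    result = {}
    for line in hlo_text.splitlines():
        match = _LINE_RE.match(line)
        if match:
            name, attrs = match.groups()
            result[name] = dict(_ATTR_RE.findall(attrs))
    return result

ops = tagged_ops(HLO_TEXT)
for name, attrs in ops.items():
    print(name, attrs)
```

In practice you would pass the string returned by `.lower(...).as_text("hlo")` instead of the hardcoded sample. Attribute values containing braces or quotes would need a more careful parser.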
If you want to tag all operations in a function, you can also use set_xla_metadata as a decorator:
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging with a decorator
@set_xla_metadata(_xla_log=True)
@jax.jit
def decorator_tagging(x):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return y * z

print(decorator_tagging.lower(1.0).as_text("hlo"))
This will result in the same HLO as above.
set_xla_metadata uses either an XlaMetadataContextManager or a JAX primitive, depending on the use case, and is compatible with JAX transformations such as jit, vmap, and grad.
vmap: When you vmap a function containing set_xla_metadata, the metadata is applied to all of the relevant batched operations.

grad: When you use the context manager form, with set_xla_metadata(...):, the metadata is applied to both the forward pass and the backward pass of the operations within it. When you wrap a value, set_xla_metadata() currently applies only to the forward pass of a function. To tag individual operations generated by the backward pass (i.e., the gradient computation), a simple custom_vjp can be used:
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

def fn(x):
    y = jnp.sin(x)
    z = jnp.cos(x)
    return y * z

metadata = {"example": "grad_tagging"}

# --- Define Custom VJP to tag gradients ---
@jax.custom_vjp
def wrapped_fn(x):
    return fn(x)

def fwd(*args):
    # Save the VJP function as the residual for the backward pass.
    primal_out, vjp_fn = jax.vjp(fn, *args)
    return primal_out, vjp_fn

def bwd(vjp_fn, cts_in):
    cts_out = vjp_fn(cts_in)
    # Tag the cotangents produced by the backward pass.
    cts_out = set_xla_metadata(cts_out, **metadata)
    return cts_out

wrapped_fn.defvjp(fwd, bwd)
# ------

print(jax.jit(jax.grad(wrapped_fn)).lower(jnp.array(3.0)).as_text("hlo"))
Results in:
ENTRY main.10 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  neg.6 = f32[] negate(sin.2)
  sin.5 = f32[] sine(x.1)
  mul.7 = f32[] multiply(neg.6, sin.5)
  cos.4 = f32[] cosine(x.1)
  cos.3 = f32[] cosine(x.1)
  mul.8 = f32[] multiply(cos.4, cos.3)
  ROOT add_any.9 = f32[] add(mul.7, mul.8), frontend_attributes={example="grad_tagging"}
}
When set_xla_metadata wraps a value, it tags only the forward pass, so a custom_vjp must be used in order to tag gradients in this case. See above for an example.