docs/internals/constants.md
(constants-note)=
"Closed-over constants" are non-scalar arrays that are encountered during JAX
tracing of a function and do not have dependencies on any of the function's
arguments.
JAX operations such as jax.numpy and lax are staged out and do not create
closed-over constants.
In the following example, the array a_jax_array and the array returned by
np.full are closed-over constants, but the array produced by jnp.full
is not. We refer below to closed-over constants simply as constants.
import numpy as np
from jax import jit
from jax import numpy as jnp
a_jax_array = jnp.ones((16,), dtype=np.float32)
@jit
def f(x):
    return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)
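As a rough illustration, the example can be traced with the public jax.make_jaxpr to see how the constants show up in the Jaxpr. The printed form depends on the JAX version and on which constants implementation is active, so the sketch assumes no particular output and only checks the numerical result.

```python
import numpy as np
import jax
from jax import numpy as jnp

a_jax_array = jnp.ones((16,), dtype=np.float32)


def f(x):
    # a_jax_array and the np.full result do not depend on x: closed-over constants.
    # jnp.full is a JAX operation and is staged out instead.
    return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)


x = jnp.zeros((16,), dtype=np.float32)

# Inspect the traced program; the exact text varies across JAX versions.
print(jax.make_jaxpr(f)(x))

# The jitted function computes 0 + 1 + 42 + 142 for every element.
result = jax.jit(f)(x)
```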
We describe below the future internal implementation details for
constants. As of July 2025, this is not yet the default implementation;
it is enabled by the environment variable JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True.
See further below for the details of the previous
implementation, including its drawbacks.
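A minimal sketch of opting in from Python, assuming the environment variable has to be set before jax is first imported (JAX reads its environment-based configuration at import time). The function add_const is a hypothetical example, and the variable may be ignored by JAX versions that do not define this flag.

```python
import os

# Must happen before the first `import jax` in the process.
os.environ["JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS"] = "True"

import jax
from jax import numpy as jnp

const = jnp.arange(4.0)  # closed over by add_const below


@jax.jit
def add_const(x):
    return x + const


out = add_const(jnp.ones((4,)))
```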
When JAX tracing encounters a constant that is either an argument of a JAX
primitive
or a function return, it is represented as a core.Literal, and is embedded
in the Jaxpr along with the primitives that use it.
The function core.is_literalable decides which constants are turned into
core.Literal. All scalar constants are turned into core.Literal, along with
non-scalar np.ndarray and jax.Array.
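The scalar case can be observed with public APIs. The sketch below traces a function that multiplies by a Python scalar and collects the equation operands that are literals. Matching on the class name is a workaround chosen here to avoid importing from JAX internals; it is not how JAX itself tests for literals.

```python
import jax

closed = jax.make_jaxpr(lambda x: x * 2.0)(1.0)

# Equation operands are either variables or literals; keep only the literals.
literal_operands = [
    operand
    for eqn in closed.jaxpr.eqns
    for operand in eqn.invars
    if type(operand).__name__ == "Literal"
]

print(closed)
print(literal_operands)
```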
When lowering the code to HLO we could just emit a stablehlo.constant
operation for a core.Literal, but this would have several disadvantages:
* If the constant is a jax.Array (e.g., the a_jax_array above), then it is
  pulled from the device to the host during lowering, and it will later be
  re-materialized on the device when the lowered module executes.
  This can increase the host memory usage, sometimes dramatically.
  Furthermore, if the constant is sharded on multiple devices this
  sharding is lost.

Instead, during lowering we use the function core.jaxpr_const_args to scan
a Jaxpr and return a list of constants contained within, uniquified by their
id. The result of core.jaxpr_const_args is memoized for each Jaxpr and sub-Jaxpr
on which it is called.
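The scan can be pictured with a toy model in plain Python. ToyConst, ToyEqn, ToyJaxpr, and toy_jaxpr_const_args are all hypothetical names invented for this sketch; the real core.jaxpr_const_args works on actual Jaxprs and may differ in its details.

```python
from dataclasses import dataclass
from functools import lru_cache


class ToyConst:
    """Hypothetical stand-in for a non-scalar constant (np.ndarray or jax.Array)."""

    def __init__(self, name: str):
        self.name = name


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so instances work as cache keys
class ToyEqn:
    operands: tuple         # ToyConst objects, or strings naming variables
    sub_jaxprs: tuple = ()  # nested ToyJaxpr objects, e.g. a control-flow body


@dataclass(eq=False)
class ToyJaxpr:
    eqns: tuple


@lru_cache(maxsize=None)  # memoized per (sub-)Jaxpr, keyed by object identity
def toy_jaxpr_const_args(jaxpr: ToyJaxpr) -> tuple:
    unique = {}  # id(constant) -> constant, in first-seen order
    for eqn in jaxpr.eqns:
        for operand in eqn.operands:
            if isinstance(operand, ToyConst):
                unique.setdefault(id(operand), operand)
        for sub in eqn.sub_jaxprs:
            for const in toy_jaxpr_const_args(sub):
                unique.setdefault(id(const), const)
    return tuple(unique.values())


weights, bias = ToyConst("weights"), ToyConst("bias")
inner = ToyJaxpr(eqns=(ToyEqn(operands=(weights, "x")),))
outer = ToyJaxpr(eqns=(
    ToyEqn(operands=(weights, bias)),
    ToyEqn(operands=("y",), sub_jaxprs=(inner,)),
))

outer_consts = toy_jaxpr_const_args(outer)  # weights is used twice but listed once
inner_consts = toy_jaxpr_const_args(inner)  # already computed while scanning outer
```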
Each lowered HLO function takes one additional argument
for each unique constant appearing in its corresponding Jaxpr.
These arguments, referred to as const_args,
come after the dimension variable arguments, after the
token arguments, and just before the actual array arguments.
During lowering we maintain a mapping const_lowering: dict[int, mlir.IrValues]
from the id of the constants to the HLO values for the corresponding
const args.
This mapping is stored in the mlir.LoweringRuleContext and is used
by mlir.ir_constant: when a constant is encountered, we just reuse
the existing lowering from const_lowering instead of emitting a
stablehlo.constant.
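The argument order and the lookup by id can be sketched in plain Python as follows. All names (ToyConst, toy_function_signature, toy_const_lowering, toy_ir_constant) and the string placeholders standing in for HLO values are assumptions of this sketch, not the mlir module's API.

```python
from typing import Any


class ToyConst:
    """Hypothetical stand-in for a non-scalar constant."""

    def __init__(self, name: str):
        self.name = name


def toy_function_signature(dim_var_args, token_args, const_args, array_args):
    # Order described above: dimension variables, tokens, const args, array args.
    const_names = [f"const_arg:{c.name}" for c in const_args]
    return [*dim_var_args, *token_args, *const_names, *array_args]


def toy_const_lowering(const_args, first_const_index: int) -> dict:
    # Map id(constant) -> placeholder for the function argument that carries it.
    return {id(c): f"%arg{first_const_index + i}" for i, c in enumerate(const_args)}


def toy_ir_constant(value: Any, const_lowering: dict) -> str:
    # Reuse the hoisted argument when there is one; otherwise emit an inline constant.
    hoisted = const_lowering.get(id(value))
    if hoisted is not None:
        return hoisted
    return f"stablehlo.constant({value!r})"


weights = ToyConst("weights")
signature = toy_function_signature(["dim_b"], ["token"], [weights], ["x"])
lowering = toy_const_lowering([weights], first_const_index=2)
```

In this toy model, looking up weights yields the argument placeholder, while a value without an entry in the mapping (here a scalar) falls back to an inline constant.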
When we lower an HLO inner function (i.e., not the main function),
we call core.jaxpr_const_args again
to get the actual constants in the corresponding Jaxpr. These are
expected to be among the constants for which we have a const_lowering.
The inner function will get its own smaller set of const_args and
its own const_lowering mapping to be used when lowering the body.
E.g., the function mlir.lower_jaxpr_as_fun is one place where some
of this happens.
The function mlir.jaxpr_subcomp does not create a new HLO function,
but instead creates a block within the current function. It uses
the enclosing function's const_lowering.
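A toy sketch of the inner-function case, under the assumption that the caller passes its own lowered values as operands and the inner function binds them to fresh arguments. The helper toy_inner_const_lowering and the string placeholders are hypothetical.

```python
def toy_inner_const_lowering(inner_consts, outer_const_lowering):
    """Return (call_operands, inner_const_lowering) for a hypothetical inner function."""
    call_operands = []
    inner_lowering = {}
    for position, const in enumerate(inner_consts):
        # Each inner constant is expected to already have a lowering in the caller.
        call_operands.append(outer_const_lowering[id(const)])
        # Inside the inner function the same constant is one of its own arguments.
        inner_lowering[id(const)] = f"%inner_arg{position}"
    return call_operands, inner_lowering


const_a, const_b, const_c = object(), object(), object()
outer_lowering = {id(const_a): "%arg0", id(const_b): "%arg1", id(const_c): "%arg2"}

# The inner Jaxpr only uses const_c and const_a, so it gets a smaller set of const args.
operands, inner_lowering = toy_inner_const_lowering([const_c, const_a], outer_lowering)
```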
Note also that there will still be stablehlo.constant in the lowered
code, in three cases:
* The constant is not in the Jaxpr. This can happen for constants that
  arise during lowering, e.g., the lowering of some PRNG functions
  includes constants.
* The mlir.LoweringParameters.hoist_constants_as_args parameter
  can be used to control this.

One additional complication is that some of the internal lowering functions
need to take the argument avals and sometimes also the shardings and
layouts for the arguments. Furthermore, the avals, shardings, and layouts for
all arguments, including the const args,
are also used after lowering. Therefore, it is convenient
to compute these fairly high in the call stack, e.g., in
pxla.lower_sharding_computations, and pass them down.
For example, the functions mlir.lower_jaxpr_to_module,
pjit._pjit_cached_lower_jaxpr_to_fun, and mlir.lower_jaxpr_to_fun
take in_avals, in_shardings, and in_layouts
that include both the avals for const_args and for the regular args
(the ones corresponding to the Jaxpr.invars).
They also take a num_const_args argument.
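The convention of one combined list plus a count can be sketched as below. ToyAval, toy_combine, and toy_split are hypothetical helpers. The sketch assumes the const args come first in the combined list, consistent with the const args preceding the array arguments in the lowered function.

```python
from typing import NamedTuple


class ToyAval(NamedTuple):
    """Hypothetical stand-in for an abstract value: just a shape and a dtype name."""
    shape: tuple
    dtype: str


def toy_combine(const_avals, arg_avals):
    # One list for everything, plus the number of leading const args.
    return [*const_avals, *arg_avals], len(const_avals)


def toy_split(in_avals, num_const_args):
    # Recover the two groups from the combined list.
    return in_avals[:num_const_args], in_avals[num_const_args:]


const_avals = [ToyAval((16,), "float32")]
arg_avals = [ToyAval((16,), "float32"), ToyAval((), "int32")]

in_avals, num_const_args = toy_combine(const_avals, arg_avals)
recovered_consts, recovered_args = toy_split(in_avals, num_const_args)
```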
The lowered MLIR module contains arguments for the const args, so
the compiled executable will need to be passed the const args.
It is important to choose the right place where we prepend the
const args. For example, in the following code, the second invocation
of the jitted function f is expected to hit the C++ jit cache without
any Python code executing.
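import jax
from jax import numpy as jnp
const = jnp.array([42.])
f = jax.jit(lambda: const)
f()  # first call: traces, lowers, and compiles
f()  # second call: expected to hit the C++ jit cache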
(TODO: yashk2810 plans to write a description of how the jit caches work.)
This means that the const will have to be passed to the executable in C++
(and thus stored in pxla.MeshExecutableFastpathData),
and therefore the C++ cache
miss functions (e.g., pjit._cpp_pjit.cache_miss,
or aot_cache_miss in pxla.MeshExecutable.create_cpp_call)
will not take the const args as arguments. Instead these cache
miss functions will have to prepend the const args.
The C++ fast path has support for const args starting with jaxlib 0.7.1. In prior versions, the fast path is disabled when there are const args.
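The division of responsibilities can be pictured with a toy model in which the executable keeps its const args and the call path prepends them, so callers only supply the user-visible arguments. ToyExecutable and toy_call are hypothetical; they do not mirror the actual pxla or C++ data structures.

```python
class ToyExecutable:
    """Hypothetical executable whose entry point expects the const args first."""

    def __init__(self, entry_point, const_args):
        self.entry_point = entry_point
        self.const_args = tuple(const_args)  # stored alongside the executable


def toy_call(executable, *user_args):
    # The caller never sees the const args; they are prepended here.
    return executable.entry_point(*executable.const_args, *user_args)


# Analogue of jax.jit(lambda: const): no user arguments, one const arg.
returns_const = ToyExecutable(lambda const: const, const_args=[42.0])

# Analogue of a function that adds a closed-over constant to its argument.
adds_const = ToyExecutable(lambda const, x: const + x, const_args=[42.0])

no_arg_result = toy_call(returns_const)
one_arg_result = toy_call(adds_const, 1.0)
```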
To implement this scheme, we keep the const_args in
stages.Lowering, stages.Lowered, and stages.CompiledCallParams.
Interestingly, when we serialize an executable, e.g., for the compilation cache, we do not need to serialize the closed-over constants. The executable itself does not contain them, and needs to take them as const args. Whoever is going to deserialize the cached executable will have to pass the const args.
In AOT mode, the lowering and execution may
use different values of the jax_enable_x64 configuration value.
If the constants are 64-bit ndarrays, we must use the same value
of jax_enable_x64 for lowering and execution.
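A minimal AOT sketch using the public lower and compile stages, with a 64-bit ndarray as the closed-over constant. It records jax_enable_x64 at lowering and at execution time and checks that they agree. The function add_const is a hypothetical example, and the check is only an illustration, not something JAX is stated to enforce.

```python
import numpy as np
import jax
from jax import numpy as jnp

const64 = np.arange(4, dtype=np.float64)  # a 64-bit ndarray, closed over below


def add_const(x):
    return x + const64


x = jnp.ones((4,), dtype=np.float32)

# Lowering and compilation happen ahead of time.
x64_at_lowering = jax.config.jax_enable_x64
compiled = jax.jit(add_const).lower(x).compile()

# Execution may happen much later; the setting should not have changed in between.
x64_at_execution = jax.config.jax_enable_x64
assert x64_at_lowering == x64_at_execution
out = compiled(x)
```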
The following describes the current way we handle closed-over constants, as
of July 2025 (as long as JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=False).
When JAX traces a function to a Jaxpr it collects the closed-over values
into a set of constants, and adds a corresponding set of constvars to the
Jaxpr (the actual arguments are represented by invars).
Most tracing functions, e.g., trace_to_jaxpr_dynamic,
return both the Jaxpr and the constants.
In many places in the code we use a class core.ClosedJaxpr that contains a
Jaxpr and consts corresponding to the Jaxpr.constvars.
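The pairing of consts with constvars can be inspected through the object returned by the public jax.make_jaxpr. How many constants appear depends on which implementation is active, so the sketch only relies on the two lengths matching. The function g is a hypothetical example.

```python
import numpy as np
import jax
from jax import numpy as jnp

big_const = np.arange(16, dtype=np.float32)


def g(x):
    return x + big_const


closed = jax.make_jaxpr(g)(jnp.zeros((16,), dtype=np.float32))

num_consts = len(closed.consts)              # values of the closed-over constants
num_constvars = len(closed.jaxpr.constvars)  # variables standing for them in the Jaxpr
num_invars = len(closed.jaxpr.invars)        # variables for the actual arguments

print(num_consts, num_constvars, num_invars)
```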
There are several issues with ClosedJaxpr:
* Lowering the consts in ClosedJaxpr results in inlined
  stablehlo.constant, with all the issues described above.
* Jaxpr and ClosedJaxpr are used pervasively in JAX, often with the
  generic name jaxpr, and it is not easy to tell which kind of Jaxpr we
  have.
  We have started to add type declarations, but in some places the code
  is written with isinstance conditionals to work with both.
* Because Jaxpr and ClosedJaxpr are hashed by id, we would like to memoize
  their construction.
  For example, the function pe.closed_jaxpr
  memoizes the construction of ClosedJaxpr but only for the case when
  consts is empty.
  This is because sometimes consts are not hashable.
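The hashability problem can be reproduced in a few lines. In this toy sketch, construction is memoized only when there are no consts, because values such as np.ndarray cannot be part of a cache key. ToyJaxpr, ToyClosedJaxpr, and toy_closed_jaxpr are hypothetical and only loosely modeled on the behavior described for pe.closed_jaxpr.

```python
from functools import lru_cache

import numpy as np


class ToyJaxpr:
    """Hypothetical stand-in; like any plain Python object it is hashed by id."""


class ToyClosedJaxpr:
    def __init__(self, jaxpr, consts):
        self.jaxpr = jaxpr
        self.consts = tuple(consts)


@lru_cache(maxsize=None)
def _closed_without_consts(jaxpr):
    return ToyClosedJaxpr(jaxpr, ())


def toy_closed_jaxpr(jaxpr, consts):
    if len(consts) == 0:
        return _closed_without_consts(jaxpr)  # memoized: same object every time
    # consts such as np.ndarray are unhashable, so no cache key can be built.
    return ToyClosedJaxpr(jaxpr, consts)


jaxpr = ToyJaxpr()
array_const = np.ones((3,))

# Check directly that an ndarray cannot be hashed.
try:
    hash(array_const)
    ndarray_is_hashable = True
except TypeError:
    ndarray_is_hashable = False
```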