(shape_poly)=
# Shape polymorphism

When JAX is used in JIT mode, a function will be traced, lowered to StableHLO, and compiled for each
combination of input types and shapes. After exporting a function and
deserializing it on another system we don't have the Python sources available anymore,
so we cannot re-trace and re-lower it. Shape polymorphism is a feature of JAX export
that allows some exported functions to be used for a whole family of input shapes.
These functions are traced and lowered once, during exporting, and the
`Exported` object contains the information needed to compile and execute the function
on many concrete input shapes. We do this by specifying shapes that contain
dimension variables (symbolic shapes) when exporting, as in the
following example:
>>> import jax
>>> import numpy as np
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x): # f: i32[a, b]
...   return jnp.concatenate([x, x], axis=1)
>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")
>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)
>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)
>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)
Note that such functions are still re-compiled on demand for each concrete input shape they are invoked on. Only the tracing and the lowering are saved.
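For example, the same `exp` can be invoked again with other concrete dimensions; only the compilation is repeated behind the scenes (a small illustrative check):

>>> exp.call(np.ones((2, 5), dtype=np.int32)).shape
(2, 10)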
The {func}`jax.export.symbolic_shape` function is used in the above
example to parse a string representation of a symbolic shape
into dimension expression objects (of type `_DimExpr`) that are usable in place of integer
constants to construct shapes. The dimension expression objects
overload most integer operators, so you can use them as
you'd use integer constants in most cases.
See {ref}`computing-with-dimension-variables` for more details.
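As a small sketch of this duck-typing (reusing `a` and `b` from above; `exp_sum` is just an illustrative name), we can do arithmetic on dimension variables to build derived shapes:

>>> exp_sum = export.export(jax.jit(lambda x: x.sum()))(
... jax.ShapeDtypeStruct((a + 2, b), jnp.int32))
>>> exp_sum.call(np.ones((5, 4), dtype=np.int32))  # here a = 3 and b = 4
Array(20, dtype=int32)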
Additionally, we provide the {func}`jax.export.symbolic_args_specs` function that
can be used to construct pytrees of `jax.ShapeDtypeStruct` objects based
on a polymorphic shape specification:
>>> def f1(x, y): # x: i32[a, 1], y: i32[a, 4]
...   return x + y
>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(*args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))
Note how the polymorphic shape specification "a, ..." contains
the placeholder `...` to be filled in from the concrete shapes of
the arguments `(x, y)`.
The placeholder `...` stands for 0 or more dimensions, while the
placeholder `_` stands for one dimension.
The {func}`jax.export.symbolic_args_specs` function supports pytrees of arguments,
which are used to fill in the dtypes and any placeholders.
The function will construct a pytree of
argument specifications ({class}`jax.ShapeDtypeStruct`)
matching the structure of the arguments passed to it.
The polymorphic shapes specification can be a
pytree prefix in cases where one specification should apply
to multiple arguments, as in the above example.
See the pytree documentation for how optional parameters are matched to arguments.
A few examples of shape specifications:

  * `("(b, _, _)", None)` can be used for a function with two arguments, the first
    being a 3D array with a batch leading dimension that should be symbolic.
    The other dimensions for the
    first argument and the shape of the second argument are specialized based on the actual
    arguments. Note that the same specification would work if the first
    argument is a pytree of 3D arrays, all with the same leading dimension
    but possibly with different trailing dimensions.
    The value `None` for the second argument means that the argument
    is not symbolic. Equivalently, one can use `...`.

  * `("(batch, ...)", "(batch,)")` specifies that the two arguments
    have matching leading dimensions, the first argument has rank at
    least 1, and the second has rank 1. A sketch of this second
    specification follows below.
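Here is a minimal sketch of the second specification (the argument names are illustrative):

>>> xs = np.ones((8, 2, 3), dtype=np.float32)
>>> w = np.ones((8,), dtype=np.float32)
>>> specs = export.symbolic_args_specs((xs, w), ("(batch, ...)", "(batch,)"))
>>> specs[0].shape
(batch, 2, 3)
>>> specs[1].shape
(batch,)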
## Correctness of shape polymorphism

We want to trust that the exported program produces the same results as the original JAX program when compiled and executed for any applicable concrete shapes. More precisely:
For any JAX function `f` and any argument specification `arg_spec` containing a
symbolic shape, and any concrete argument `arg` whose shape matches `arg_spec`:

  * if the JAX native execution succeeds on the concrete argument: `res = f(arg)`,
  * and if the exporting succeeds with symbolic shapes: `exp = export.export(f)(arg_spec)`,
  * then compiling and running the export will succeed with the same result: `res == exp.call(arg)`.

It is crucial to understand that `f(arg)` has the freedom to re-invoke
the JAX tracing machinery,
and in fact it does so for each distinct concrete `arg` shape,
while the execution of `exp.call(arg)` cannot use JAX tracing anymore
(this execution may happen in an environment where the source code
of `f` is not available).
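As a small illustrative check of this contract (assuming both the native execution and the export succeed):

>>> def g(x): return jnp.sin(x) * 2.
>>> arg = np.ones((4,), dtype=np.float32)
>>> exp_g = export.export(jax.jit(g))(
... jax.ShapeDtypeStruct(export.symbolic_shape("n,"), np.float32))
>>> bool(jnp.allclose(g(arg), exp_g.call(arg)))
True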
Ensuring this form of correctness is hard, and in the hardest cases exporting fails. The rest of this chapter describes how to handle these failures.
(computing-with-dimension-variables)=
## Computation with dimension variables
JAX keeps track of the shapes of all intermediate results. When those shapes depend
on dimension variables JAX computes them as symbolic dimension expressions
involving dimension variables.
Dimension variables stand for integer values greater than or equal to 1.
The symbolic expressions can represent the result
of applying arithmetic operators (add, sub, mul, floordiv, mod,
including the NumPy variants `np.sum`, `np.prod`, etc.) on dimension
expressions and integers (`int`, `np.int`, or anything convertible by `operator.index`).
These symbolic dimensions can then be used in shape parameters of JAX primitives
and APIs, e.g., in `jnp.reshape`, `jnp.arange`, slicing indices, etc.
For example, in the following code to flatten a 2D array, the computation
`x.shape[0] * x.shape[1]` computes the symbolic dimension `4*b` as the
new shape:
>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)
It is possible to convert dimension expressions explicitly
to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`.
The result of these operations can be used as regular JAX arrays,
but cannot be used anymore as dimensions in shapes, e.g., in `reshape`:
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)
>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
When a symbolic dimension is used in arithmetic operations with non-integers,
e.g., `float`, `np.float`, `np.ndarray`, or JAX arrays, it is automatically
converted to a JAX array using `jnp.array`.
For example, in the function below all occurrences of `x.shape[0]`
are converted implicitly to `jnp.array(x.shape[0])` because
they are involved in operations with non-integer scalars or with
JAX arrays:
>>> exp = export.export(jax.jit(
... lambda x: (5. + x.shape[0],
... x.shape[0] - np.arange(5, dtype=jnp.int32),
... x + x.shape[0] + jnp.sin(x.shape[0]))))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
ShapedArray(int32[5]),
ShapedArray(float32[b], weak_type=True))
>>> exp.call(jnp.ones((3,), jnp.int32))
(Array(8., dtype=float32, weak_type=True),
Array([ 3, 2, 1, 0, -1], dtype=int32),
Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))
Another typical example is when computing averages
(observe how `x.shape[0]` is automatically turned into a JAX array):
>>> exp = export.export(jax.jit(
... lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)
### Errors in presence of shape polymorphism

Most JAX code assumes that the shapes of JAX arrays are tuples of integers, but with shape polymorphism some dimensions may be symbolic expressions. This can lead to a number of errors. For example, we can have the usual JAX shape check errors:
>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))( # doctest: +IGNORE_EXCEPTION_DETAIL
... jax.ShapeDtypeStruct((v,), dtype=np.int32),
... jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).
>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))( # doctest: +IGNORE_EXCEPTION_DETAIL
... jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).
We can fix the above matmul example by specifying that the
argument has shape `(v, v)`.
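For example (a sketch; `exp_mm` is an illustrative name):

>>> exp_mm = export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
... jax.ShapeDtypeStruct((v, v), dtype=np.int32))
>>> exp_mm.out_avals
(ShapedArray(int32[v,v]),)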
### Comparison of symbolic dimensions is partially supported

Inside JAX there are a number of equality and inequality comparisons involving shapes, e.g., for doing shape checking or even for choosing the implementation for some primitives. Comparisons are supported as follows:

  * Equality evaluates to `True` when the two dimensions denote the same value for all
    valuations of the dimension variables,
    e.g., for `b + b == 2*b`; otherwise the equality evaluates to `False`.
    See the caveat on equality comparisons below
    for a discussion of important consequences of this behavior.
  * Inequalities take into consideration that dimension variables range over strictly
    positive integers: `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`,
    `a >= b`, `a - b >= 0` are inconclusive and result in an exception.

In cases where a comparison operation cannot be resolved to a boolean,
we raise {class}`InconclusiveDimensionOperation`. E.g.,
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
... jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
If you do get an `InconclusiveDimensionOperation`, you can try
several strategies:

  * If your code uses the built-in `max` or `min`, or the
    `np.max` or `np.min`, then you can replace those with
    `core.max_dim` and `core.min_dim`, which have the effect
    of delaying the inequality comparison to compilation
    time, when shapes become known (see the sketch further below).
  * Try to rewrite conditionals using `core.max_dim` and
    `core.min_dim`, e.g., instead of `d if d > 0 else 0`
    you can write `core.max_dim(d, 0)`.
  * Try to rewrite the code to be less dependent on dimensions
    being Python integers; e.g., instead of `int(d) + 5` write
    `d + 5`.

By default, JAX assumes that all dimension variables range over values greater than or equal to 1, and it tries to derive other simple inequalities from that, e.g.:

  * `a + 2 >= 3`,
  * `a * 2 >= 1`,
  * `a + b + c >= 3`,
  * `a // 4 >= 0`,
  * `a**2 >= 1`, and so on.

You can avoid some inequality comparison failures if you change the symbolic shape specifications to add implicit constraints for dimension sizes. E.g.,

  * You can use `2*b` for a dimension to constrain it to be even and greater than or equal
    to 2.
  * You can use `b + 15` for a dimension to constrain it to
    be at least 16. E.g., the following code would fail without
    the `+ 15` part, because JAX will want to verify that slice sizes
    are at most as large as the axis size:

>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))
Such implicit symbolic constraints are used for deciding comparisons and are checked at compile time, as explained below.
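As an illustration of the `core.max_dim` strategy mentioned above, here is a minimal sketch (the function and the names `f_clamped`, `exp_clamped` are ours, not part of the API):

>>> from jax import core
>>> d, = export.symbolic_shape("d")
>>> # Instead of the inconclusive `x.shape[0] - 2 if x.shape[0] - 2 > 0 else 0`:
>>> f_clamped = lambda x: jnp.zeros((core.max_dim(x.shape[0] - 2, 0),))
>>> exp_clamped = export.export(jax.jit(f_clamped))(
... jax.ShapeDtypeStruct((d,), dtype=np.float32))
>>> exp_clamped.call(np.ones((5,), dtype=np.float32)).shape
(3,)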
#### User-specified symbolic constraints

You can also specify explicit symbolic constraints:
>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
... constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
... jax.ShapeDtypeStruct((a, b), dtype=np.int32))
The constraints form a conjunction together with the implicit
constraints. You can specify `>=`, `<=`, and `==` constraints.
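For example (an illustrative sketch reusing `a` and `b` with the constraints above):

>>> exp_constrained = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
... jax.ShapeDtypeStruct((a, b), dtype=np.int32))
>>> exp_constrained.call(np.ones((30, 24), dtype=np.int32)).shape  # a = 30 >= b = 24 >= 16
(24, 16)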
At the moment, JAX has limited support for reasoning with
symbolic constraints:

  * From simple constraints, such as `a >= 16` and `b >= 8`, we can infer
    that `a + 2*b >= 32`.
  * From constraints involving more complex expressions we can infer less: e.g., from `a >= b + 8` we
    can infer that `a - b >= 8` but not that `a >= 9`.
    We may improve this area somewhat in the future.
  * Equality constraints act as rewrite rules: whenever the expression on the left of `==`
    is encountered, it is rewritten to the expression on
    the right.
    E.g., `floordiv(a, b) == c` works by replacing all
    occurrences of `floordiv(a, b)` with `c`.
    Equality constraints must not contain addition or
    subtraction at the top-level on the left-hand-side. Examples of
    valid left-hand-sides are `a * b`, or `4 * a`, or
    `floordiv(a + c, b)`.

>>> # Introduce dimension variables with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
... constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c
>>> a * b * b
b*d + b*c
The symbolic constraints can also help to work around the
limitations in the JAX reasoning mechanisms.
For example, in the code below JAX will attempt to prove that
the slice size `x.shape[0] % 3`, which is the symbolic expression
`mod(b, 3)`, is less than or equal to the axis size, which is `b`.
This happens to be true for all strictly positive values of
`b`, but it is not something JAX's symbolic comparison rules
can prove. Hence, the following code raises an error:
>>> from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
One option here would be to restrict the code to work only on
axis sizes that are multiples of 3 (by replacing
`b` with `3*b` in the shape). Then, JAX would be able
to simplify the modulo operation `mod(3*b, 3)` to 0.
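Here is a sketch of this first option (`b3` and `exp3` are illustrative names):

>>> b3, = export.symbolic_shape("b")
>>> exp3 = export.export(jax.jit(lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)))(
... jax.ShapeDtypeStruct((3 * b3,), dtype=np.int32))
>>> exp3.out_avals
(ShapedArray(int32[0]),)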
Another option is to add a symbolic constraint
with the exact inconclusive inequality that JAX
is attempting to prove:
>>> b, = export.symbolic_shape("b",
... constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
Just like the implicit constraints, the explicit symbolic constraints are checked at compile time, using the same mechanism as explained below.
### Symbolic dimension scopes

The symbolic constraints are stored in an
{class}`jax.export.SymbolicScope` object, which is created implicitly
for each call to {func}`jax.export.symbolic_shape`. You must be careful
to not mix symbolic expressions that use different scopes.
For example,
the following code will fail because `a1` and `a2`
use different scopes (created by different invocations of
{func}`jax.export.symbolic_shape`):
>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> a1 + a2 # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
a >= 8
The symbolic expressions that originate from a single call
to {func}`jax.export.symbolic_shape` share a scope and
can be mixed up in arithmetic operations. The result would
also share the same scope.
You can reuse scopes:
>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope) # Reuse the scope of `a`
>>> a + b # Allowed
b + a
You can also create scopes explicitly:
>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d # Allowed
d + c
JAX tracing uses caches keyed partially by shapes, and symbolic shapes that are printed identically will be considered distinct if they use different scopes.
### Caveat for equality comparisons

The equality comparison returns `False` for `b + 1 == b` or `b == 0`
(in which case it is certain that the dimensions are different for all values
of the dimension variables),
but also for `b == 1` and for `a == b`. This is unsound, and we
ought to raise `core.InconclusiveDimensionOperation` because under
some valuations the result should be `True` and under other
valuations it should be `False`. We choose to make equality total,
thus allowing unsoundness, because otherwise we may get spurious errors
in the presence of hash collisions
when hashing dimension expressions or objects that include
them (shapes, `core.AbstractValue`, `core.Jaxpr`).
Besides the hashing errors, a partial semantics of equality
leads to errors for the following expressions: `b == a or b == b`, or `b in [a, b]`,
even though the error is avoided if we change the order of the comparisons.

Code of the form `if x.shape[0] != 1: raise NiceErrorMessage` is sound even
with this treatment of equality, but code of the form `if x.shape[0] != 1: return 1`
is unsound.
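A short illustration of this behavior:

>>> b, = export.symbolic_shape("b")
>>> b + b == 2 * b  # equal for all values of b
True
>>> b == 1  # not provably equal; evaluates to False rather than raising
False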
## Dimension variables must be solvable from the input shapes

Currently, the only way to pass the values of dimension variables
when an exported object is invoked is indirectly through the shapes
of the array arguments. E.g., the value of `b` can be inferred at the
call site from the shape of the first argument of type `f32[b]`.
This works well for most use cases, and
it mirrors the calling convention of JIT functions.
Sometimes you may want to export a function parameterized
by an integer value that determines some shapes in the program.
For example, we may
want to export the function `my_top_k` defined below,
parameterized by the
value of `k`, which determines the shape of the result.
The following attempt will lead to an error since the dimension
variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`:
>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10
...   return lax.top_k(x, k)[0] # : i32[4, k]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))
>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])
>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])
>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
In the future, we may add an additional mechanism to pass the values of
dimension variables, besides implicitly through the input shapes.
Meanwhile, the workaround for the above use case is to replace the
function parameter `k` with an array of shape `(0, k)`, so that
`k` can be derived from the input shape of an array.
The first dimension is 0 to ensure that the whole array is empty
and there is no performance penalty when we call the exported function.
>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10]
...   return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
... jax.ShapeDtypeStruct((0, k), dtype=np.int32),
... x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
>>> exp.out_avals[0]
ShapedArray(int32[4,k])
>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
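Note that the same export now works for other values of `k` (a quick check, here with k=5):

>>> exp.call(np.zeros((0, 5), dtype=np.int32), x).shape
(4, 5)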
Another situation when you may get an error is when some dimension variables do appear in the input shapes, but in a non-linear expression that JAX cannot currently solve:
>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
... jax.ShapeDtypeStruct((a * a,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].
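A sketch of a workaround (changing the input layout so that `a` appears linearly in some input shape):

>>> _ = export.export(jax.jit(lambda x: x.shape[0]))(
... jax.ShapeDtypeStruct((a, a), dtype=np.int32))  # `a` is solvable from axis 0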
## Shape assertion errors

JAX assumes that dimension variables range over strictly positive integers, and this assumption is checked when the code is compiled for concrete input shapes.
For example, given the symbolic input shape `(b, b, 2*d)`,
JAX will generate code to check the following assertions when
invoked with actual argument `arg`:

  * `arg.shape[0] >= 1`
  * `arg.shape[1] == arg.shape[0]`
  * `arg.shape[2] % 2 == 0`
  * `arg.shape[2] // 2 >= 1`

For example, here is the error we get when we call the exported function
on an argument of shape `(3, 3, 5)`:
>>> def f(x): # x: i32[b, b, 2*d]
...   return x
>>> exp = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details.
These errors arise in a pre-processing step before the compilation.
(shape_poly_debugging)=
## Debugging

First, see the {ref}`export_debugging` documentation.
Additionally, you can debug the shape refinement, which is
invoked at compilation time for modules that have dimension variables or multi-platform
support.
If there is an error during shape refinement, you can set the `JAX_DUMP_IR_TO`
environment variable to see a dump of the HLO module before
shape refinement (named `..._before_refine_polymorphic_shapes.mlir`).
This module should already have static input shapes.
To enable the logging of all stages of shape refinement you can set the
environment variable `TF_CPP_VMODULE=refine_polymorphic_shapes=3` in OSS
(inside Google, you pass `--vmodule=refine_polymorphic_shapes=3`):
# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3