Back to jax

Rank promotion warning

docs/rank_promotion_warning.rst

0.3.251.8 KB
Original Source

Rank promotion warning

NumPy broadcasting rules <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules>_ allow the automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended but can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.

Here's an example of rank promotion:

from jax import numpy as jnp x = jnp.arange(12).reshape(4, 3) y = jnp.array([0, 1, 0]) x + y Array([[ 0, 2, 2], [ 3, 5, 5], [ 6, 8, 8], [ 9, 11, 11]], dtype=int32)

To avoid potential surprises, :code:jax.numpy is configurable so that expressions requiring rank promotion can lead to a warning, error, or can be allowed just like regular NumPy. The configuration option is named :code:jax_numpy_rank_promotion and it can take on string values :code:allow, :code:warn, and :code:raise. The default setting is :code:allow, which allows rank promotion without warning or error. The :code:raise setting raises an error on rank promotion, and :code:warn raises a warning on the first occurrence of rank promotion.

Rank promotion can be enabled or disabled locally with the :func:jax.numpy_rank_promotion context manager:

.. code-block:: python

with jax.numpy_rank_promotion("warn"): z = x + y

This configuration can also be set globally in several ways. One is by using :code:jax.config in your code:

.. code-block:: python

import jax jax.config.update("jax_numpy_rank_promotion", "warn")

You can also set the option using the environment variable :code:JAX_NUMPY_RANK_PROMOTION, for example as :code:JAX_NUMPY_RANK_PROMOTION='warn'. Finally, when using :code:absl-py the option can be set with a command-line flag.