.. _dev-arrayapi:
.. note:: Array API standard support is still experimental and hidden behind an environment variable. Only a small part of the public API is covered right now.
This guide describes how to use and add support for the
`Python array API standard <https://data-apis.org/array-api/latest/index.html>`_.
This standard allows users to use any array API compatible array library
with parts of SciPy out of the box.
The RFC_ defines how SciPy implements support for the standard, with the main
principle being "array type in equals array type out". In addition, the
implementation validates array-like inputs more strictly, e.g. rejecting
``numpy.matrix`` and masked array instances, and arrays with object dtype.

In the following, an array API compatible namespace is denoted ``xp``.
To enable the array API standard support, an environment variable must be set before importing SciPy:
.. code:: bash

   export SCIPY_ARRAY_API=1

This enables both array API standard support and the stricter input
validation for array-like arguments. Note that this environment variable is
meant to be temporary, as a way to make incremental changes and merge them into
main without affecting backwards compatibility immediately. We do not
intend to keep this environment variable around long-term.
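
When SciPy is driven from a script or notebook rather than from a shell, the
same switch can be set from Python, provided this happens before the first
``import scipy``. The following is a minimal sketch; ``enable_scipy_array_api``
is a hypothetical helper written for this illustration and is not part of SciPy:

.. code:: python

   import os
   import sys


   def enable_scipy_array_api(loaded_modules=None):
       """Set SCIPY_ARRAY_API=1, refusing if SciPy has already been imported.

       ``loaded_modules`` defaults to ``sys.modules``; it is a parameter only
       so that the check can be exercised without importing SciPy.
       """
       if loaded_modules is None:
           loaded_modules = sys.modules
       if "scipy" in loaded_modules:
           # The variable must be set before SciPy is imported, so setting
           # it at this point would come too late.
           raise RuntimeError("set SCIPY_ARRAY_API before importing scipy")
       os.environ["SCIPY_ARRAY_API"] = "1"


   # Typical use at the very top of a script:
   #     enable_scipy_array_api()
   #     import scipy
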
This clustering example shows usage with PyTorch tensors as inputs and return values:
.. code:: python

   >>> import torch
   >>> from scipy.cluster.vq import vq
   >>> code_book = torch.tensor([[1., 1., 1.],
   ...                           [2., 2., 2.]])
   >>> features = torch.tensor([[1.9, 2.3, 1.7],
   ...                          [1.5, 2.5, 2.2],
   ...                          [0.8, 0.6, 1.7]])
   >>> code, dist = vq(features, code_book)
   >>> code
   tensor([1, 1, 0], dtype=torch.int32)
   >>> dist
   tensor([0.4359, 0.7348, 0.8307])

Note that the above example works for PyTorch CPU tensors. For GPU tensors or
CuPy arrays, the expected result for ``vq`` is a ``TypeError``, because ``vq``
uses compiled code in its implementation, which won't work on GPU, and there
are currently no GPU specific implementations to delegate to.
Stricter array input validation will reject ``np.matrix`` and
``np.ma.MaskedArray`` instances, as well as arrays with ``object`` dtype:
.. code:: python

   >>> import numpy as np
   >>> from scipy.cluster.vq import vq
   >>> code_book = np.array([[1., 1., 1.],
   ...                       [2., 2., 2.]])
   >>> features = np.array([[1.9, 2.3, 1.7],
   ...                      [1.5, 2.5, 2.2],
   ...                      [0.8, 0.6, 1.7]])
   >>> vq(features, code_book)
   (array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))
   >>> # The above uses numpy arrays; trying to use np.matrix instances or object
   >>> # arrays instead will yield an exception with `SCIPY_ARRAY_API=1`:
   >>> vq(np.asmatrix(features), code_book)
   ...
   TypeError: 'numpy.matrix' are not supported
   >>> vq(np.ma.asarray(features), code_book)
   ...
   TypeError: 'numpy.ma.MaskedArray' are not supported
   >>> vq(features.astype(np.object_), code_book)
   ...
   TypeError: object arrays are not supported

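
The kind of check that produces these errors can be sketched in a few lines.
This is only an illustration of the rule described above, assuming plain NumPy;
``reject_unsupported`` is a hypothetical name and this is not SciPy's actual
validation code:

.. code:: python

   import numpy as np


   def reject_unsupported(x):
       """Raise TypeError for inputs that the stricter validation refuses."""
       # Check the subclasses first: both are also instances of np.ndarray.
       if isinstance(x, np.ma.MaskedArray):
           raise TypeError("'numpy.ma.MaskedArray' are not supported")
       if isinstance(x, np.matrix):
           raise TypeError("'numpy.matrix' are not supported")
       if isinstance(x, np.ndarray) and x.dtype == np.object_:
           raise TypeError("object arrays are not supported")
       return x

Plain ``ndarray`` inputs pass through unchanged, while each of the three
rejected input kinds raises ``TypeError``.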
=========  =========  =========
Library    CPU        GPU
=========  =========  =========
NumPy      ✅         n/a
CuPy       n/a        ✅
PyTorch    ✅         ✅
JAX        ⚠️ no JIT  ⛔
Dask       ⛔         n/a
=========  =========  =========

In the example above, the feature has some support for NumPy, CuPy, PyTorch, and JAX
arrays, but no support for Dask arrays. Some backends, like JAX and PyTorch, natively
support multiple devices (CPU and GPU), but SciPy support for such arrays may be
limited; for instance, this SciPy feature is only expected to work with JAX arrays
located on the CPU. Additionally, some backends can have major caveats; in the example
the function will fail when running inside ``jax.jit``.
Additional caveats may be listed in the docstring of the function.
Some functions also note support for `MArray <https://mdhaber.github.io/marray/tutorial.html>`__,
a library that adds "missing data" awareness to the array library of your choice. MArray
is not an independent array library; rather, it wraps the namespace of an array API
compatible library to add "mask" support. Consequently, where MArray support is noted,
it is supported in conjunction with all backend/device combinations marked as supported
in the capabilities table.
While the elements of the table marked with "n/a" are inherently out of scope, we are continually working on filling in the rest. Dask wrapping around backends other than NumPy (notably, CuPy) is currently out of scope but this may change in the future.
Please see the tracker issue_ for updates.
.. _dev-arrayapi_implementation_notes:
A key part of the support for the array API standard and specific compatibility
functions for NumPy, CuPy and PyTorch is provided through
`array-api-compat <https://github.com/data-apis/array-api-compat>`_.
This package is included in the SciPy codebase via a git submodule (under
``scipy/_lib``), so no new dependencies are introduced.

``array-api-compat`` provides generic utility functions and adds aliases such
as ``xp.concat`` (which, for NumPy, mapped to ``np.concatenate`` before NumPy added
``np.concat`` in NumPy 2.0). This allows using a uniform API across NumPy, PyTorch,
CuPy and JAX (with other libraries, such as Dask, being worked on).
When the environment variable isn't set and hence array API standard support in
SciPy is disabled, we still use the wrapped version of the NumPy namespace,
which is ``array_api_compat.numpy``. That should not change behavior of SciPy
functions, as it's effectively the existing ``numpy`` namespace with a number of
aliases added and a handful of functions amended/added for array API standard
support. When support is enabled, ``xp = array_namespace(input)`` will
be the standard-compatible namespace matching the input array type to a
function (e.g., if the input to ``cluster.vq.kmeans`` is a PyTorch tensor, then
``xp`` is ``array_api_compat.torch``).
.. _dev-arrayapi_adding_support:
As much as possible, new code added to SciPy should follow the array API standard closely (the functions it defines are typically best-practice idioms for NumPy usage as well). When code follows the standard, adding array API standard support is typically straightforward, and ideally no customization needs to be maintained.
Various helper functions are available in ``scipy._lib._array_api``; please see
the ``__all__`` in that module for a list of current helpers, and their docstrings
for more information.
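To add support to a SciPy function which is defined in a ``.py`` file, what you
have to change is:

1. Input array validation,
2. Using ``xp`` rather than ``np`` functions,
3. When calling into compiled code, convert the array to a NumPy array before
   and convert it back to the input array type after.

Input array validation uses the following pattern::

   xp = array_namespace(arr)  # where arr is the input array
   # alternatively, if there are multiple array inputs, include them all:
   xp = array_namespace(arr1, arr2)

   # replace np.asarray with xp.asarray
   arr = xp.asarray(arr)
   # uses of non-standard parameters of np.asarray can be replaced with _asarray
   arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)
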
Note that if one input is a non-NumPy array type, all array-like inputs have to be of that type; trying to mix non-NumPy arrays with lists, Python scalars or other arbitrary Python objects will raise an exception. For NumPy arrays, those types will continue to be accepted for backwards compatibility reasons.
If a function calls into compiled code just once, use the following pattern::

   x = np.asarray(x)  # convert to numpy right before compiled call(s)
   y = _call_compiled_code(x)
   y = xp.asarray(y)  # convert back to original array type

If there are multiple calls to compiled code, make sure to do the conversion just once to avoid excessive overhead.
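
A minimal sketch of the "convert once" pattern with two consecutive compiled
calls follows. It assumes plain NumPy; ``_stage_one`` and ``_stage_two`` are
hypothetical stand-ins for compiled routines, and ``namespace_of`` is a
simplified stand-in for SciPy's ``array_namespace`` helper:

.. code:: python

   import numpy as np


   def _stage_one(x):
       # Stand-in for a compiled routine: accepts NumPy arrays only.
       assert isinstance(x, np.ndarray)
       return np.cumsum(x)


   def _stage_two(x):
       # Second stand-in for a compiled routine.
       assert isinstance(x, np.ndarray)
       return np.sqrt(x)


   def namespace_of(x):
       """Return the array API namespace of ``x``, falling back to NumPy."""
       get_namespace = getattr(x, "__array_namespace__", None)
       return np if get_namespace is None else get_namespace()


   def pipeline(x):
       xp = namespace_of(x)
       x = xp.asarray(x)
       x_np = np.asarray(x)      # one conversion to NumPy ...
       y = _stage_one(x_np)      # ... shared by every compiled call
       z = _stage_two(y)
       return xp.asarray(z)      # one conversion back to the input type

The intermediate result ``y`` stays a NumPy array between the two calls, so no
round trip through ``xp`` happens in the middle of the pipeline.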
Here is an example for a hypothetical public SciPy function ``toto``::

   def toto(a, b):
       a = np.asarray(a)
       b = np.asarray(b, copy=True)

       c = np.sum(a) - np.prod(b)

       # this is some C or Cython call
       d = cdist(c)

       return d

You would convert this like so::

   def toto(a, b):
       xp = array_namespace(a, b)
       a = xp.asarray(a)
       b = xp.asarray(b, copy=True)

       c = xp.sum(a) - xp.prod(b)

       # this is some C or Cython call
       c = np.asarray(c)
       d = cdist(c)
       d = xp.asarray(d)

       return d

Going through compiled code requires going back to a NumPy array, because SciPy's extension modules only work with NumPy arrays (or memoryviews in the case of Cython). For arrays on CPU, the conversions should be zero-copy, while on GPU and other devices the attempt at conversion will raise an exception. The reason for that is that silent data transfer between devices is considered bad practice, as it is likely to be a large and hard-to-detect performance bottleneck.
In some cases, compiled code can be supported on alternative backends
through delegation to native implementations. Such delegation has currently been set
up in :mod:`~scipy.fft`, :mod:`~scipy.ndimage`, :mod:`~scipy.signal`
(see :ref:`array_api_support_signal_caveats`),
and :mod:`~scipy.special`, though there is not yet a
standard approach, and support in each module has mostly evolved separately.
There is also some effort being put into expanding access to native
implementations, such as the `xsf project <https://github.com/scipy/xsf/issues/1>`_
to establish a library of mathematical special function implementations which support
both CPU and GPU.
.. _dev-arrayapi_jax_support:
A note on JAX support
`````````````````````
JAX was designed with deliberate restrictions to make code easier to reason about and exploits
this to better support features like JIT-compilation and autodifferentiation. The most
relevant restrictions for SciPy developers are:
* JAX arrays are immutable. Rather than performing in-place updates of arrays, one
can use the `at <https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
property to transform an array in an equivalent way. Inside a JIT compiled function,
an expression like ``x = x.at[idx].set(y)`` will be applied in-place under the hood.
* Functions using the JAX JIT must be functionally pure. They cannot have side
effects, cannot mutate data, and their outputs must be determined completely by their
inputs. Raising a Python exception is a side-effect that is not permitted within a JITed
function.
* Within the JIT, value based control flow with Python ``if`` statements is not permitted.
Only static properties of arrays such as their ``shape`` and ``dtype`` are permitted to be
used with ``if``. `xp.where <https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.where.html#where>`_
and `array_api_extra.apply_where <https://data-apis.org/array-api-extra/generated/array_api_extra.apply_where.html>`_
provide some basic control flow that works with the JIT.
* Within the JIT, the shapes of output arrays cannot depend dynamically on the *values* in input arrays.
See `Common Gotchas in JAX <https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html>`_ for more information.
**Recommendations for developers:**
* To work around the mutability restriction, developers adding JAX support
to SciPy functions which make in-place updates should use
`array_api_extra.at <https://data-apis.org/array-api-extra/generated/array_api_extra.at.html>`_
which works for all array API compatible backends, delegating to JAX's ``at`` for
JAX arrays and performing regular in-place operations with ``__setitem__`` for other kinds
of arrays.
* The restriction that functions be functionally pure to support the JAX JIT necessitates
that input-validation that raises with bad input must be skipped when ``xp`` is JAX.
* Compiled functions ``f`` which cannot be supported on JAX through delegation to a native
implementation can potentially still be supported through the use of
`array_api_extra.lazy_apply <https://data-apis.org/array-api-extra/generated/array_api_extra.lazy_apply.html>`_, which uses JAX's `pure_callback <https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html>`_
mechanism to enable calling Python functions within JIT-ed JAX functions.
Using ``lazy_apply``, the example function ``toto`` might be made compatible
with the JAX JIT in the following way::

    import array_api_extra as xpx

    def toto(a, b):
        xp = array_namespace(a, b)
        a = xp.asarray(a)
        b = xp.asarray(b, copy=True)

        c = xp.sum(a) - xp.prod(b)

        # this is some C or Cython call
        # as_numpy=True tells lazy_apply to convert to and from NumPy.
        d = xpx.lazy_apply(cdist, c, as_numpy=True)

        return d

``lazy_apply`` can be used so long as ``f`` is a pure function for which the output
shape can be determined knowing only the input shapes.
* `xp.where <https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.where.html#where>`_
and `array_api_extra.apply_where <https://data-apis.org/array-api-extra/generated/array_api_extra.apply_where.html>`_ provide a level of basic control flow that works with the JIT
and in some cases these can be used to replace the value dependent use of ``if``. In
some cases it's also possible to wrap code using ``if`` within a pure function and use
``lazy_apply``.
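
As a small illustration of replacing a value dependent ``if`` with ``xp.where``,
consider a reciprocal that returns zero where the input is zero. This is a
sketch only: ``safe_reciprocal`` is a hypothetical function, and NumPy stands in
for ``xp`` so that the example runs without JAX installed:

.. code:: python

   import numpy as np


   def safe_reciprocal(x, xp=np):
       """Return 1/x where x is nonzero and 0 elsewhere, without ``if``."""
       nonzero = x != 0
       # Substitute a harmless denominator where x == 0 so that no division
       # by zero is evaluated; both branches are computed for all elements.
       denom = xp.where(nonzero, x, xp.ones_like(x))
       return xp.where(nonzero, 1 / denom, xp.zeros_like(x))

The function never branches in Python on the values of ``x``, which is the
property the JIT restriction above asks for. The price is that both branches
are always evaluated, so the guarded denominator is needed to keep the unused
branch well defined.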
**JAX Eager:**
It is also possible to run JAX in eager-mode without the JIT (in fact this is the
default behavior when ``@jax.jit`` is not applied). Eager-mode comes with serious
performance limitations and is typically only used to debug functions which are
ultimately intended to be run with the JIT. Do not be tempted to attempt to distinguish
whether JAX is being used with the JIT in order to bypass some restrictions while running
with eager-mode. There is no way to make this distinction using JAX's public API, and any means
of determining if JAX is running with the JIT would necessarily involve implementation
details that SciPy should not rely upon. In general, support for eager-mode is not a high-value
target, and it is not considered a good use of developer time to put significant effort
into enabling eager-only support.
.. _dev-arrayapi_marray_support:
A note on MArray support
````````````````````````
MArray wraps array API compatible namespaces, so it is common for an array API compatible function to execute without warnings or errors when given MArray input. This does not necessarily mean that the function supports MArray input. The mask of MArrays is used to denote missing values; therefore, to consider a function compatible with MArray, numerical output with masked input must equal the numerical output expected if the masked values were removed entirely. MArray provides primitives to facilitate this, but typically, implementations must be generalized to ensure that this property holds.
For instance, suppose that the function ``mean`` were not defined by the array API
standard, and it had been implemented in SciPy as::

   def mean(x, axis=0):
       xp = array_namespace(x)
       n = x.shape[axis]
       return xp.sum(x, axis=axis) / n

This implementation assumes that the denominator used to normalize the sum will be the
same for each axis-slice: the length of the array along the axis. This assumption
is not valid for an array with masked values. Although MArray will ensure that masked
values are not considered when computing the sum, the implementation still needs to be
generalized to normalize by the number of non-masked elements in a slice::

   from scipy._lib._array_api import _count_nonmasked

   def mean(x, axis=0):
       xp = array_namespace(x)
       # n = x.shape[axis]
       n = _count_nonmasked(x, axis=axis, keepdims=True, xp=xp)
       return xp.sum(x, axis=axis) / n

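
The effect of the two normalizations can be checked numerically without MArray.
The sketch below assumes plain NumPy and represents missing values with an
explicit boolean ``mask`` argument; ``mean_ignoring_masked`` is a hypothetical
function for illustration only:

.. code:: python

   import numpy as np


   def mean_ignoring_masked(data, mask, axis=0):
       """Mean over ``axis`` that skips elements where ``mask`` is True."""
       keep = ~mask
       total = np.sum(np.where(keep, data, 0.0), axis=axis)
       n = np.sum(keep, axis=axis)  # per-slice count of non-masked elements
       return total / n


   data = np.array([[1.0, 2.0],
                    [3.0, 100.0],
                    [5.0, 6.0]])
   mask = np.array([[False, False],
                    [False, True],
                    [False, False]])

   # Dividing by the slice length is wrong for the column with a masked value.
   naive = np.sum(np.where(~mask, data, 0.0), axis=0) / data.shape[0]
   correct = mean_ignoring_masked(data, mask, axis=0)

For the second column, ``correct`` equals the mean of the two remaining values,
which is what one would get if the masked element were removed entirely, while
``naive`` divides their sum by three.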
Using counts of non-masked elements instead of slice lengths was by far the most common
generalization needed to add MArray support throughout ``scipy.stats``, and for many
functions, it was the only generalization needed. Other common changes included:

- [...] ``scipy.stats.cramervonmises_2samp``),
- [...] ``scipy._lib._array_api._share_masks`` (e.g., see ``scipy.stats.spearmanrho``),
- using ``scipy._lib._array_api._masked_apply`` to evaluate an elementwise function
  on all data, then re-applying the common mask of the inputs to the output (e.g., see
  ``scipy.stats.cramervonmises``), and
- [...] ``sum``, which propagates a mask, and ``xp.sum``,
  which ignores masked values (e.g., see ``scipy.stats.chisquare``).

In most cases, use of ``_count_nonmasked``, ``_share_masks``, and ``_masked_apply`` has
been sufficient to generalize functions without treating MArray input as a special case
or accessing the ``mask`` or ``data`` attributes of an MArray directly. In exceptional
cases, ``scipy._lib._array_api.is_marray`` is available to determine when the namespace
has been wrapped with MArray, and the namespace of the underlying array library
can be accessed using the ``_xp`` attribute of an MArray variable (e.g. see
``scipy.stats.mode``).
.. warning::

   NumPy masked arrays tend to mask the output of calculations that would otherwise return infinite or NaN output. This is unsafe because subsequent calculations will treat these "invalid" outputs as "missing", erroneously producing seemingly valid numerical output when the correct calculation would propagate the infinite or NaN value. When adding MArray support to SciPy functions, ensure that masked values of the output arise due to masked input (e.g., the output of an elementwise function on a masked input value is masked; the output of a reducing function on an all-masked slice is masked) and only due to masked input.

.. _dev-arrayapi_default_dtype:
Default Datatypes
`````````````````
The Array API Standard allows conforming libraries to have
`default datatypes <https://data-apis.org/array-api/latest/API_specification/data_types.html#default-data-types>`_
for integers and real and complex floating point numbers which differ from the
``int64``, ``float64``, ``complex128`` defaults used by NumPy. Our aim is to
have array API supporting SciPy functions with array inputs have behavior which
is independent of the default dtype to the extent that this is practical. This means
that when using array creation functions from the ``xp`` namespace such as ``xp.zeros``
or ``xp.arange``, one should take care to explicitly set a dtype with the ``dtype``
kwarg; otherwise, the result will depend on the default dtype.
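
For example, a helper that builds uniform weights for an input array can take
its dtype from the input instead of relying on the backend default. This is a
minimal sketch assuming plain NumPy as ``xp``; ``uniform_weights`` is a
hypothetical function:

.. code:: python

   import numpy as np


   def uniform_weights(x, xp=np):
       """Weights 1/n with the dtype of ``x`` rather than the backend default."""
       n = x.shape[0]
       # Without dtype=..., xp.ones would return the backend's default float.
       return xp.ones(n, dtype=x.dtype) / n


   w = uniform_weights(np.zeros(4, dtype=np.float32))

Here ``w`` has dtype ``float32`` because the input does, whichever default
float dtype the backend is configured with.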
Note that SciPy is currently only tested in CI on platforms/backends/backend-configurations
where a ``float64`` dtype is available. At the moment, ``float32``-only support
varies from function to function and is not well-documented. Some examples of
``float32``-only backends are JAX in its
`default configuration <https://docs.jax.dev/en/latest/default_dtypes.html>`_,
and PyTorch with the
`Metal performance shader backend <https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/>`_
on ARM Mac. We are open to expanding and improving ``float32``-only
support in cases where this is feasible and there is sufficient user interest.
It is therefore recommended that those using SciPy with the JAX backend set the X64 flag by one
of the means JAX provides since the default ``float32``-only configuration is not
currently tested.
Though not directly related to default dtypes, it may be useful to know that JAX defaults to using
`TensorFloat32 <https://github.com/jax-ml/jax/issues/4873>`_
precision for matrix multiplication of ``float32`` arrays on GPUs that support TensorFloat32.
The half-precision mantissas used in this format can cause accuracy issues in scientific applications.
SciPy sets `jax_default_matmul_precision <https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html>`_ to
``"float32"`` in its JAX GPU tests and we recommend this configuration for users
of the JAX backend. This configuration option does not affect matrix multiplications of ``float64``
arrays when the X64 flag is enabled.
Array creation functions without array inputs
`````````````````````````````````````````````
For array creation functions without array inputs,
adding array API standard support can be accomplished by adding keyword-only
arguments ``xp`` and ``device`` to specify the desired backend and
device respectively. See for instance :func:`~scipy.signal.buttap` which
constructs the analog prototype of an Nth-order Butterworth filter. It may also
be desirable to add a ``dtype`` kwarg to control the output dtype for
such functions.
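
The shape of such a signature can be sketched as follows. ``unit_ramp`` is a
hypothetical function, not part of SciPy; defaulting to NumPy and to
``float64`` are assumptions made only for this illustration:

.. code:: python

   import numpy as np


   def unit_ramp(n, *, xp=None, device=None, dtype=None):
       """Return ``n`` evenly spaced points on [0, 1] in the requested backend."""
       xp = np if xp is None else xp
       # Assumption for this sketch only: default to float64 on every backend.
       dtype = xp.float64 if dtype is None else dtype
       # Forward ``device`` only when given, so that namespaces without that
       # keyword (e.g. NumPy before 2.0) still work for the default case.
       kwargs = {} if device is None else {"device": device}
       return xp.linspace(0.0, 1.0, n, dtype=dtype, **kwargs)


   ramp = unit_ramp(5)

Making ``xp``, ``device`` and ``dtype`` keyword-only keeps the positional
signature free for the parameters that define the result itself.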
Note that none of these keyword arguments are necessary when
there are array arguments from which the backend, device, and desired output dtype
can be inferred (ideally, output dtype should follow input dtypes through the expected
`type promotion rules <https://data-apis.org/array-api/latest/API_specification/type_promotion.html>`_).
For the sake of API simplicity and consistency and in the
spirit of "There should be one-- and preferably only one --obvious way to do it.",
it is recommended to avoid the use of these kwargs in functions which take
at least one array argument.
It is still under debate how array creation functions without array inputs should
behave with respect to :ref:`default dtypes <dev-arrayapi_default_dtype>`.
Should they respect default dtype or should the output dtype be fixed across
backends and defaults? Should there be a ``dtype`` kwarg for controlling the output
dtype or is being able to apply ``xp.astype`` on the output sufficient?
Since there is not yet a consistent pattern to follow, for now it's
important to clearly document how such functions behave with respect to the
default dtype in the :ref:`extra_note <dev-arrayapi_extra_note>` described below.
Support for alternative array API standard backends can be registered and
documented using the ``xp_capabilities`` decorator which has the following
signature::

   def xp_capabilities(
       *,
       # Alternative capabilities table.
       # Used only for testing this decorator.
       capabilities_table=None,
       # Generate pytest.mark.skip/xfail_xp_backends.
       # See documentation in conftest.py.
       # lists of tuples [(module name, reason), ...]
       skip_backends=(), xfail_backends=(),
       cpu_only=False, np_only=False, reason=None,
       out_of_scope=False, exceptions=(),
       # lists of tuples [(module name, reason), ...]
       warnings=(),
       # xpx.testing.lazy_xp_function kwargs.
       # Refer to array-api-extra documentation.
       allow_dask_compute=False, jax_jit=True,
       # Extra note to inject into the docstring
       extra_note=None,
       # Dictionary mapping method names to dictionaries of method
       # specific capabilities for use when xp_capabilities is
       # applied to a class with varying capabilities per method
       method_capabilities=None,
       # Whether the function supports MArrays (used only in documentation)
       marray=False,
   ):

This is available in ``scipy._lib._array_api`` and can be applied to functions,
methods, and classes to declare the current extent of their array API standard
support. For the sake of brevity, in the remainder of this document, we write
as if ``xp_capabilities`` only applies to functions.

The ``xp_capabilities`` decorator is what inserts the capabilities table into
docstrings. It also allows developers to tag tests with the
``@make_xp_test_case`` decorator or apply ``make_xp_pytest_param`` to pytest
parameters to automatically generate backend specific SKIP/XFAIL markers, and
to set up testing that functions work with the JAX JIT or work in Dask lazily
(i.e. without materializing arrays with ``dask.compute`` or otherwise triggering
computation with ``dask.persist``).
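
The general mechanism, a decorator that records a declaration and appends a
note to the docstring, can be illustrated with a toy version. The
``toy_capabilities`` decorator below is hypothetical and far simpler than
``xp_capabilities``; it shares only the ``cpu_only`` and ``exceptions`` keyword
names and does not reflect SciPy's implementation:

.. code:: python

   def toy_capabilities(*, cpu_only=False, exceptions=()):
       """Toy stand-in: record capabilities and note them in the docstring."""
       def decorator(func):
           if cpu_only:
               note = "CPU only"
               if exceptions:
                   note += "; GPU also supported for: " + ", ".join(exceptions)
           else:
               note = "CPU and GPU"
           func.__doc__ = (func.__doc__ or "") + "\n\nArray API support: " + note
           # Keep the raw declaration around so that tests could consult it.
           func.capabilities = {
               "cpu_only": cpu_only,
               "exceptions": tuple(exceptions),
           }
           return func
       return decorator


   @toy_capabilities(cpu_only=True, exceptions=["cupy"])
   def my_function(x):
       """Return ``x`` unchanged."""
       return x

Keeping the declaration next to the function definition lets one source of
truth drive both the documentation and the test markers.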
.. warning::

   The modification of docstrings by ``xp_capabilities`` can potentially break
   intersphinx references because it currently has the side effect of replacing
   implicit roles with ``:func:``. This can be avoided by explicitly
   setting the role for references to classes and methods that are
   outside of SciPy. The following snippet is taken from the docstring for
   :func:`~scipy.signal.detrend` where the role ``:meth:`` for a class method is
   needed.

   .. code-block:: rst

      See Also
      --------
      :meth:`numpy.polynomial.polynomial.Polynomial.fit` : Create least squares fit polynomial

Note that in some modules a systematic process for delegation to native
implementations is set up, where functions are replaced with wrappers
that perform delegation. In this case, ``xp_capabilities`` is not always
applied as a decorator with ``@`` syntax, but may instead be applied
programmatically on the wrappers. When working on array API standard
support within a module, it's important to be aware of how such delegation
is set up, if any, and how ``xp_capabilities`` is being applied. A common
practice currently is to have a file, ``_support_alternative_backends.py``
within a module that sets up such delegation. See for instance
`scipy/signal/_support_alternative_backends.py <https://github.com/scipy/scipy/blob/main/scipy/signal/_support_alternative_backends.py>`_.
Basic behavior
``````````````
Using ``xp_capabilities`` with no arguments, like this::

    @xp_capabilities()
    def my_function(x):
        ...

declares that a function works on all supported backends, on JAX with the JIT
and lazily in Dask. This is most likely to occur if a function is written
entirely in terms of the array API standard as described earlier in this
document. Such functions are commonly referred to as array-agnostic. For
functions which are written mostly in terms of the array API standard, but
include calls to compiled code sandwiched between conversions to and from NumPy,
``xp_capabilities`` should be given the ``cpu_only=True`` option. Backends
which are supported on GPU by such a function ``f`` through delegation to a
native implementation can be specified with the ``exceptions`` kwarg, which
in this case takes a list of strings specifying GPU-capable backends. The
currently supported string values when using ``cpu_only=True`` are
``"cupy"``, ``"torch"``, and ``"jax.numpy"``.
It is recommended to reserve ``cpu_only=False`` (the default) for
array-agnostic functions which are expected to work on all array API compliant
backends, including ones not tested in SciPy and ones that do not yet exist.
If a function is supported on GPU on all tested backends through delegation to
respective native implementations, one should use ``cpu_only=True`` while listing
each backend in the list of ``exceptions`` like so::

    @xp_capabilities(cpu_only=True, exceptions=["cupy", "torch", "jax.numpy"])

When setting ``cpu_only=True``, one may list a reason by passing a string with
the ``reason`` kwarg. This can be helpful for documenting why something is not
supported for other developers. The reason will appear in the ``pytest`` output
when the SciPy test suite is run with ``pytest`` in verbose mode
(with the ``-v`` flag).
JAX JIT
```````
One may declare a function as not supporting the JAX JIT with the option
``jax_jit=False``. See the earlier note on :ref:`JAX support <dev-arrayapi_jax_support>`
for more information.
Unsupported functions
`````````````````````
Functions which do not support the array API standard through the means
described earlier in this document should either receive the ``np_only=True``
option or the ``out_of_scope=True`` option. The former should be used for
functions which are not currently supported but which are considered in-scope
for future support.

Functions for which array API support has not been added following the
procedures described earlier in this document, but for which delegation to a
native implementation has been set up for one or more array API backends, should
still use ``np_only=True`` in their ``xp_capabilities`` entries. Just as for
``cpu_only=True``, exceptions can be passed with the ``exceptions`` kwarg (and
also just as for ``cpu_only=True`` one can pass a reason with the ``reason``
kwarg)::

   @xp_capabilities(
       np_only=True, exceptions=["cupy"],
       reason="not converted yet but has CuPy delegation."
   )

Valid strings to pass in the ``exceptions`` list are ``"array_api_strict"``,
``"cupy"``, ``"dask.array"``, ``"jax.numpy"``, and ``"torch"``.
If ``np_only=True`` and ``"torch"`` or ``"jax.numpy"`` is added to
the list of exceptions, the function will be declared as supported on both CPU and
GPU.
.. _dev-arrayapi_jax_jit_no_gpu:
.. dropdown:: Functions with JAX JIT support but no GPU support
   :icon: alert
   :color: warning

   It's possible for a function to be natively available in JAX,
   support ``jax.jit``, but not be supported on GPU. Thus, it may be
   possible for JAX delegation to be set up in a function which has
   not yet received the array API standard compatibility treatment,
   and for the JIT to be supported but not the GPU.

   Because ``exceptions`` does double
   duty declaring exceptions to ``cpu_only=True`` and ``np_only=True``, it is not
   possible to express this situation using ``xp_capabilities`` in the way
   described above. This is not too serious of an issue because the intention is
   that ``np_only=True`` is only a temporary state. Through the means described
   above in the section on :ref:`adding array API support <dev-arrayapi_adding_support>`,
   it is a reasonable goal for all functions in SciPy's public API to at least reach
   the state ``cpu_only=True``. For functions still waiting in the ``np_only=True`` state,
   ``xp_capabilities``'s ``skip_backends`` kwarg can be used as an escape hatch to
   allow more fine grained declaration of capabilities. See the section on
   :ref:`skip_backends and xfail_backends <dev-arrayapi_skip_xfail_backends>`.

``out_of_scope=True`` signals that there is no intention to ever provide array
API support for a given function. There is not yet a formal policy for which
functions should be considered out-of-scope. Some general rules of thumb that are
being followed are to exclude:

1. [...] :doc:`scipy.constants.value <../../reference/generated/scipy.constants.value>`
2. [...] ``scipy.linalg.blas`` which give direct wrappers to low-level BLAS routines.
3. [...]

As an example, the contents of ``scipy.odr`` are considered out-of-scope for a
combination of reasons 2 and 3 above. ``scipy.odr`` essentially provides a direct
wrapper of the monolithic ODRPACK Fortran library, and its API is tied
to the structure of this monolithic library. An efficient GPU
accelerated implementation of nonlinear weighted orthogonal distance regression
would benefit from not having to support an API so tightly coupled to ODRPACK
but is also a challenging problem in its own right.
(Since the previous paragraph was written, ``scipy.odr`` has been slated for
deprecation. Things that are deprecated are inherently out-of-scope.)
What is considered in-scope is evolving, and something which is now considered out-of-scope may be decided to be in-scope in the future if sufficient user interest and feasibility are demonstrated.
.. _dev-arrayapi_skip_xfail_backends:
``skip_backends`` and ``xfail_backends``
````````````````````````````````````````
One may pass lists of tuples of backend string, reason pairs to ``xp_capabilities``
with the ``skip_backends`` and ``xfail_backends`` kwargs. The valid backend strings
are ``"array_api_strict"``, ``"cupy"``, ``"dask.array"``, ``"jax.numpy"`` and ``"torch"``
(note that one should almost never want to skip tests for
`array_api_strict <https://data-apis.org/array-api-strict/>`_ as failures with this
backend most likely indicate a failure to correctly follow the array API standard).
Any backend passed in such a way with either kwarg is declared as unsupported with both
CPU and GPU. The difference between ``skip_backends`` and ``xfail_backends`` is that for
tests using the ``xp`` fixture, ``skip_backends`` adds ``pytest.skip`` markers for
backends and the corresponding tests are skipped entirely, while with ``xfail_backends``,
``pytest.xfail`` markers are added, and tests are still run but expected to fail.
One example in which it is pertinent to use
``skip_backends`` is for functions which otherwise support the array API standard, but
use features which are not available on a particular backend, such as mutation of
arrays through item assignment, which is not supported in JAX. For instance the following
can be used to signify a function which is otherwise array-agnostic, but uses
in-place item-assignment::

    @xp_capabilities(
        skip_backends=[("jax.numpy", "in-place item assignment")]
    )
    def function_with_internal_mutation(x):
        ...

Another case is when there is a bug in the corresponding array library, in which
case the ``reason`` string should contain a link to the upstream issue.
In the caveat above about functions with JAX JIT support but no JAX GPU support
we discussed the edge-case of a function which has not been given array
API standard support in the usual way, is available on JAX through delegation to
a native implementation which supports ``jax.jit``, but does not work on the GPU.
For now, such situations can in principle be handled by using ``cpu_only=True``
and passing in any backends which are not supported even on CPU to ``skip_backends``::

    @xp_capabilities(
        cpu_only=True,
        skip_backends=[
            ("array_api_strict", "not supported"),
            ("cupy", "not supported"),
            ("dask.array", "not supported"),
            ("torch", "not supported"),
        ]
    )
    def oddball_function(x):
        ...

Such situations are hopefully rare enough that special handling isn't needed.
``xp_capabilities`` has evolved naturally over time to meet developer needs; good
suggestions for ways to improve developer ergonomics are welcome.
Dask Compute
````````````
The default, ``allow_dask_compute=False``, declares that a function works lazily
in Dask and will not materialize any Dask arrays with ``dask.compute`` or
otherwise initiate computation with ``dask.persist``. Use
``allow_dask_compute=True`` to declare that a function supports Dask arrays but
not lazily. Developers can also pass an integer to give a cap for the number of
combined calls to ``dask.compute`` and ``dask.persist`` that are allowed. If a function
is not array-agnostic, then it will typically be the case that
``allow_dask_compute=True`` should be set, unless Dask specific codepaths have been added.
Dask support is currently deprioritized due to structural barriers
that make the development of meaningful Dask support particularly challenging. At present,
developers should feel free to reflexively add ``skip_backends=[("dask.array", "deprioritized")]``
to the ``xp_capabilities`` entry of any function they are working on. Reprioritization may
be considered in the future if a champion emerges and the structural outlook improves.
See `RFC: Should Dask support remain a priority? #24205 <https://github.com/scipy/scipy/issues/24205>`_
for relevant discussion.
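The three accepted forms of ``allow_dask_compute`` can be pictured as a budget on materializing calls. The following is only a toy model in plain Python: ``ComputeBudget`` is a hypothetical class, not SciPy's or Dask's implementation, and it merely counts calls that would correspond to ``dask.compute`` or ``dask.persist``.

```python
class ComputeBudget:
    """Toy stand-in that counts calls which would materialize a lazy array."""

    def __init__(self, allow_dask_compute=False):
        # False -> no calls allowed, True -> unlimited, int -> that many calls.
        if allow_dask_compute is True:
            self.limit = None
        elif allow_dask_compute is False:
            self.limit = 0
        else:
            self.limit = int(allow_dask_compute)
        self.calls = 0

    def record(self, kind):
        """Register one call to ``compute`` or ``persist``."""
        self.calls += 1
        if self.limit is not None and self.calls > self.limit:
            raise AssertionError(
                f"{kind}() call number {self.calls} exceeds the cap of {self.limit}"
            )


budget = ComputeBudget(allow_dask_compute=2)
budget.record("compute")
budget.record("persist")      # compute and persist share one combined budget
try:
    budget.record("compute")  # the third call is over the cap
except AssertionError as exc:
    print(exc)
```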
.. _dev-arrayapi_extra_note:
``extra_note``
``````````````
Some functions may be supported on an alternative backend, but only in particular
cases, perhaps only for some values of a kwarg, for real arrays but not complex ones,
or only for arrays with fewer than a given number of dimensions. Such caveats should
be supplied with the ``extra_note`` kwarg of ``xp_capabilities``. Note that the
implementation of ``extra_note`` simply inserts a string directly into the Notes section
of the docstring, and one must be careful about whitespace. This is perhaps
best demonstrated with an example::
uses_choose_conv_extra_note = (
"""CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
but does support higher dimensional arrays for ``method="direct"``
and ``method="fft"``.
"""
)
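To see why whitespace matters, consider the sketch below. It assumes, purely for illustration, that a note is substituted verbatim into an indented docstring which is later cleaned with ``inspect.cleandoc``; ``DOC_TEMPLATE`` and ``render`` are hypothetical and do not reproduce SciPy's implementation.

```python
import inspect

# A docstring body indented by four spaces, as inside a function definition.
DOC_TEMPLATE = (
    "Summary line.\n"
    "\n"
    "    Notes\n"
    "    -----\n"
    "    {note}\n"
)

# Continuation lines carry the same indentation as the docstring body.
aligned = "Only 1-D input is supported\n    when ``method='auto'``."
# Continuation lines start at column zero.
misaligned = "Only 1-D input is supported\nwhen ``method='auto'``."


def render(note):
    """Insert the note verbatim, then strip the common indentation."""
    return inspect.cleandoc(DOC_TEMPLATE.format(note=note))


print(render(aligned))
print("=" * 40)
print(render(misaligned))
```

With the misaligned note, the common indentation is zero, so the ``Notes`` heading keeps its leading spaces and the section no longer lines up with the rest of the docstring.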
.. _dev-arrayapi_marray:
``marray``
``````````
Use ``marray=True`` to document support for MArrays below the backend capabilities
table. Unless noted in ``extra_note``, MArrays are understood to be supported by
all features of the function and in conjunction with all supported backends.
Applying ``xp_capabilities`` to classes
```````````````````````````````````````
For classes with array API standard support, one must apply ``xp_capabilities``
once to the class itself, not separately to individual methods. The class level
capabilities should be decided based on best judgment of which backends
are generally usable with the class in a holistic sense. If individual methods
differ in their capabilities, this can be specified using the
``method_capabilities`` kwarg of ``xp_capabilities`` like in the example
below::

    @xp_capabilities(
        method_capabilities={
            "__init__": dict(jax_jit=False),
            "bar": dict(cpu_only=True, exceptions=["cupy"], jax_jit=False),
        }
    )
    class Foo:
        def __init__(self, x):
            ...

        def bar(self, y):
            # not array-agnostic but has delegation to CuPy to set up
            ...

        def baz(self, y):
            # array-agnostic method
            ...

Adding ``method_capabilities`` makes no changes to the documentation but does
make it possible to access method level capabilities when adding tests and
to test class methods with the JAX JIT. Documentation of
method specific support and limitations should be added to the ``extra_note``
described above.
``method_capabilities`` should be a dictionary mapping method names to
dictionaries with keys corresponding to the usual arguments of ``xp_capabilities``.
Keys that are not supplied in the inner dictionaries will be filled with the
``xp_capabilities`` default values. Entries in ``method_capabilities`` completely
override the class level capabilities entry so that one can declare that some
methods are supported on backends for which the class itself is considered
unsupported; this is useful for incremental development. If a method has no
corresponding entry in ``method_capabilities``, then by default, its capabilities
will be the same as the class level capabilities.
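The override rule can be summarized with a small sketch. The ``DEFAULTS`` dictionary and the function ``resolve_method_capabilities`` are hypothetical and exist only to illustrate the lookup order described above; the real default values are those in the signature of ``xp_capabilities``.

```python
# Hypothetical defaults, for illustration only.
DEFAULTS = {"cpu_only": False, "np_only": False, "jax_jit": True, "exceptions": ()}


def resolve_method_capabilities(class_caps, method_capabilities, method_name):
    """Return the effective capabilities for one method of a class."""
    if method_name in method_capabilities:
        # A method entry replaces the class level entry wholesale; keys it
        # omits fall back to the defaults, not to the class level values.
        return {**DEFAULTS, **method_capabilities[method_name]}
    # No entry: the method inherits the class level capabilities.
    return {**DEFAULTS, **class_caps}


class_caps = {"cpu_only": True}
method_caps = {"bar": {"jax_jit": False}}

print(resolve_method_capabilities(class_caps, method_caps, "bar"))
print(resolve_method_capabilities(class_caps, method_caps, "baz"))
```

Note that ``bar`` does not inherit ``cpu_only=True`` from the class, because its entry overrides the class level capabilities completely.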
.. _dev-arrayapi_adding_tests:
Adding tests
------------
To run a test on multiple array backends, you should add the ``xp`` fixture to
it. ``xp`` currently supports testing with the following backends:
* `array_api_strict <https://data-apis.org/array-api-strict/>`_
* `cupy <https://cupy.dev/>`_
* `dask.array <https://www.dask.org/>`_
* `jax.numpy <https://docs.jax.dev/en/latest/>`_
* `numpy <https://numpy.org/>`_
* `torch <https://pytorch.org/>`_
``xp`` is a
`pytest fixture <https://docs.pytest.org/en/6.2.x/fixture.html>`_
which is parameterized over all currently installed backends among
those listed above. Note that ``xp`` takes values from the set of "raw"
namespaces, not the wrapped namespaces from
:ref:`array_api_compat <dev-arrayapi_implementation_notes>`.
``scipy._lib._array_api`` provides the ``make_xp_test_case``
decorator, and the ``make_xp_pytest_param`` and ``make_xp_pytest_marks``
functions to declare which functions are being tested by a test. These draw on
the ``xp_capabilities`` entries for a function (or in some cases those for a
list of functions) to insert the relevant backend specific skip and xfail
markers.
**make_xp_test_case:**
In most cases, developers should use ``make_xp_test_case``, which is applied as a
decorator to a test function, test method, or entire test class. Applying it to a
test class is equivalent to applying it to each method separately. The decorator can
be applied at both the class and method level as below::

    @make_xp_test_case(my_function)
    class TestMyFunction:
        def test1(self, xp):
            ...

        @make_xp_test_case(other_function)
        def test_integration_with_other_function(self, xp):
            ...

Applying ``@make_xp_test_case(my_function)`` to ``TestMyFunction`` causes
all skips and xfails from the ``xp_capabilities`` entry for ``my_function``
to be applied to all methods in the class. Additional applications of
``@make_xp_test_case`` to individual methods will add additional skips and
xfails and not override the class level decorator. Below is an equivalent
way to write the above example. This style can become unwieldy when there
are many methods in a class testing the same function::

    class TestMyFunction:
        @make_xp_test_case(my_function)
        def test1(self, xp):
            ...

        @make_xp_test_case(my_function, other_function)
        def test_integration_with_other_function(self, xp):
            ...

**make_xp_pytest_param:**
``make_xp_pytest_param`` covers the situation where a common test body is
parametrized over a list of functions using ``pytest.mark.parametrize``.
It is not used as frequently as ``make_xp_test_case`` but this kind of
situation is not too uncommon::

    @pytest.mark.parametrize(
        "func",
        [make_xp_pytest_param(func) for func in tested_funcs]
    )
    def test_foo(func, xp):
        ...

Without access to ``make_xp_pytest_param``, one might instead have to do
something like::

    @make_xp_test_case(*tested_funcs)
    @pytest.mark.parametrize(
        "func", tested_funcs
    )
    def test_foo(func, xp):
        ...

But then ``test_foo`` would take on the collective skips and xfails
for all of the functions in ``tested_funcs`` taken together, leading to
unnecessary skips and xfails for individual functions.
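The cost of collective skips can be made concrete with a small sketch. The per-function skip lists below are invented for illustration, and plain sets stand in for the pytest marks that the real helpers generate.

```python
# Hypothetical per-function skip lists, keyed by function name.
skips = {
    "f": {"jax.numpy"},
    "g": {"torch"},
    "h": set(),
}

# Decorating one parametrized test with all functions at once behaves like
# taking the union: every parameter inherits every skip.
collective = set().union(*skips.values())

# Attaching marks per parameter keeps each function's skips separate.
per_param = {name: set(backends) for name, backends in skips.items()}

for name in skips:
    lost = collective - per_param[name]
    print(f"{name}: needlessly skipped with the collective approach: {sorted(lost)}")
```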
Unlike ``make_xp_test_case``, only a single function can be passed to any given
call to ``make_xp_pytest_param``. Additional arguments specify additional
parameters for ``pytest.mark.parametrize``, such as in the example
below::

    import itertools as it

    @pytest.mark.parametrize(
        "func,norm",
        [
            make_xp_pytest_param(func, norm)
            for func, norm in it.product(tested_functions, [True, False])
        ]
    )
    def test_normed_foo(func, norm, xp):
        ...

**make_xp_pytest_marks:**
``make_xp_pytest_marks`` is rarely used. It directly returns a list of
pytest marks which can be used with the ``pytestmark = ...`` variable
to set marks for all tests in a file.
**Strict checks:**
The ``xp`` fixture should almost always be used along with ``make_xp_test_case``
or one of the similar functions listed above and the ``xp`` fixture has
strict checks to enforce this. If one had accidentally written::

    @pytest.mark.parametrize(
        "func", tested_funcs
    )
    def test_foo(func, xp):
        ...

without using ``make_xp_pytest_param`` then running this test would result
in an error with the message::

    ERROR scipy/my_module/tests/test_foo.py::test_foo[numpy] - UserWarning: test uses `xp`
    fixture without drawing from `xp_capabilities` but is not explicitly marked with ``pytest.mark...

Since ``xp_capabilities`` is used to declare alternative backend support for the
purpose of both testing and documentation, this strict check in the ``xp``
fixture ensures that documentation of tested array API capabilities does not
become out-of-date. There may be cases where one cannot or
does not want to use ``make_xp_test_case`` or an equivalent, such as for private
functions which do not have associated ``xp_capabilities`` entries. To bypass
the strict checks, one can explicitly mark a test with
``@pytest.mark.uses_xp_capabilities(False)``. An optional ``reason`` string can
be passed to this mark. For tests of private functionality for which there are no
relevant ``xp_capabilities`` entries, one should use ``reason="private"``::

    @pytest.mark.uses_xp_capabilities(False, reason="private")
    def test_private_toto_helper(xp):
        ...

MArray testing
``````````````
Note that tests for MArray support are not added automatically by any of the mechanisms
above; support must be tested manually. A common pattern is property-based testing:
rather than testing the output of a function with particular MArray input against
manually calculated reference values, compare the output of the function with
randomly-generated MArray input against the output of a reference implementation,
such as the same function with ``nan_policy='omit'`` or with the masked data removed
programmatically. For reducing functions, it is also important to test the behavior of
the function when some slices have insufficient (e.g. zero or one) elements remaining
after masked elements have been removed. Typically, MArray tests use the ``xp`` fixture
to test MArray used *in conjunction* with all array backends supported by the function
being tested. See ``scipy/stats/tests/test_marray.py`` for examples.
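The property-based pattern can be sketched without any array library. In the example below, plain Python lists stand in for MArrays, and ``masked_mean`` is a hypothetical function under test; the reference removes the masked data programmatically before calling a trusted routine.

```python
import math
import random
import statistics


def masked_mean(data, mask):
    """Toy function under test: weighted mean where masked entries weigh zero."""
    weights = [0.0 if m else 1.0 for m in mask]
    total = sum(weights)
    if total == 0:
        return math.nan  # nothing left after masking
    return sum(w * x for w, x in zip(weights, data)) / total


def reference_mean(data, mask):
    """Reference: drop masked entries, then use the standard library mean."""
    kept = [x for x, m in zip(data, mask) if not m]
    return statistics.fmean(kept) if kept else math.nan


rng = random.Random(0)  # seeded for reproducibility
for _ in range(100):
    n = rng.randint(1, 8)
    data = [rng.uniform(-1, 1) for _ in range(n)]
    mask = [rng.random() < 0.5 for _ in range(n)]  # True marks a masked entry
    res, ref = masked_mean(data, mask), reference_mean(data, mask)
    # Fully masked inputs must agree on NaN; otherwise compare numerically.
    assert (math.isnan(res) and math.isnan(ref)) or math.isclose(
        res, ref, abs_tol=1e-12
    )
```

Short random inputs with a high masking probability make it likely that the "zero or one element remaining" cases are exercised as well.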
Directly adding pytest markers
``````````````````````````````
Though most of the time it's sufficient to use ``make_xp_test_case`` and
``make_xp_pytest_param``, the following ``pytest`` markers are available and can
be added directly to tests. (``make_xp_test_case`` and its equivalents provide a
declarative means of adding ``skip_xp_backends`` and ``xfail_xp_backends``
markers).
* ``skip_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None)``:
skip certain backends or categories of backends.
See docstring of ``scipy.conftest.skip_or_xfail_xp_backends`` for information on how
to use this marker to skip tests.
* ``xfail_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None)``:
xfail certain backends or categories of backends.
See docstring of ``scipy.conftest.skip_or_xfail_xp_backends`` for information on how
to use this marker to xfail tests.
* ``skip_xp_invalid_arg`` is used to skip tests that use arguments which
are invalid when ``SCIPY_ARRAY_API`` is enabled. For instance, some tests of
`scipy.stats` functions pass masked arrays to the function being tested, but
masked arrays are incompatible with the array API. Use of the
``skip_xp_invalid_arg`` decorator allows these tests to protect against
regressions when ``SCIPY_ARRAY_API`` is not used without resulting in failures
when ``SCIPY_ARRAY_API`` is used. In time, we will want these functions to emit
deprecation warnings when they receive array API invalid input, and this
decorator will check that the deprecation warning is emitted without it
causing the test to fail. When ``SCIPY_ARRAY_API=1`` behavior becomes the
default and only behavior, these tests (and the decorator itself) will be
removed.
* ``array_api_backends``: this marker is automatically added by the ``xp`` fixture to
all tests that use it. This is useful e.g. to select all and only such tests::
spin test -b all -m array_api_backends
* ``uses_xp_capabilities(status, funcs=None, reason=None)``: discussed above.
``make_xp_test_case`` and its equivalents apply the marker
``uses_xp_capabilities(True)`` and direct use of ``uses_xp_capabilities(False)``
can be used to declare a test intentionally does not use ``make_xp_test_case``
or one of its equivalents.
Test specific skips and xfails
``````````````````````````````
For a public function ``f``, ``skip_xp_backends`` and ``xfail_xp_backends`` should
only be used directly for backend related skips and xfails which are needed for
the specific test but which do not reflect the general capabilities of
``f``. Reasons to directly use ``skip_xp_backends`` include when:

1. the test body itself contains unsupported functionality (though one should
   try to avoid this whenever possible, see the subsection on testing
   practice below).
2. ``f`` is only partially supported on a backend and the test relies on
   cases which are not supported, e.g. tests involving complex values for
   functions which only support real values on a given backend, or tests involving
   higher dimensional arrays for functions which only support arrays with two
   or fewer dimensions on a given backend.
3. the test exposes a bug in ``f`` on a given backend which crashes test
   execution.

For tests exposing bugs on alternative backends that do not crash test
execution, such as bugs that lead to numerical errors, it is preferable to use
``xfail_xp_backends`` so we can be notified with an ``XPASS`` when the
upstream bug is fixed.
``xfail_xp_backends`` should not be used for test failures for an alternative
backend which are actually unrelated to ``f`` but are instead due to bugs
outside ``f`` exposed by other parts of the test body. To avoid such situations,
we recommend as a general practice to isolate use of the alternative
backend to the function ``f`` being tested, with the caveat that there are
situations where it is necessary or desirable to do otherwise: see the section
on :ref:`backend isolation <dev-arrayapi_backend_isolation>` below for more
information.
Tests which are inherently NumPy only should avoid the ``xp`` fixture
altogether rather than using it with an ``np_only=True`` skip marker.
Note that, in one case, ``xp_capabilities`` offers more granularity than
``skip_xp_backends`` and ``xfail_xp_backends``. ``xp_capabilities`` allows
developers to separately declare support for the JAX JIT and support for lazy
computation with Dask with the respective ``jax_jit`` and ``allow_dask_compute``
kwargs. ``skip_xp_backends`` (``xfail_xp_backends``) offers only an
``eager_only`` kwarg which can only add skips (xfails) for both the JAX JIT and
lazy Dask together. The current state is that one cannot add test specific skips
(xfails) for the JAX JIT without also adding them for lazy Dask and vice
versa. This is a known limitation that arose through the historical process
through which ``xp_capabilities``, ``skip_xp_backends``, and
``xfail_xp_backends`` were developed, and it may be addressed in the future if
there is sufficient developer need.
Array-agnostic assertions
`````````````````````````
``scipy._lib._array_api`` contains array-agnostic assertions such as ``xp_assert_close``
which can be used to replace assertions from `numpy.testing`.
When these assertions are executed within a test that uses the ``xp`` fixture, they
enforce that the namespaces of both the actual and desired arrays match the namespace
which was set by the fixture. Tests without the ``xp`` fixture infer the namespace from
the desired array. This machinery can be overridden by explicitly passing the ``xp=``
parameter to the assertion functions.
Examples
````````
The following examples demonstrate how to use direct markers together with
``make_xp_test_case``::

    import numpy as np
    import pytest

    from scipy.conftest import skip_xp_invalid_arg
    from scipy._lib._array_api import xp_assert_close, make_xp_test_case
    from scipy.mymodule import toto

    @make_xp_test_case(toto)
    class TestToto:
        def test_toto_list_input(self):
            # This test is inherently NumPy only so avoids the xp fixture altogether.
            a = [1., 2., 3.]
            b = [0., 2., 5.]
            xp_assert_close(toto(a, b), np.array(a))

        ...

        @pytest.mark.skip_xp_backends(
            'cupy',
            reason='cupy does not support inputs with ndim>2'
        )
        def test_toto2(self, xp):
            ...

        ...

        # Do not run when SCIPY_ARRAY_API=1 is used since calling toto on masked
        # arrays will raise in this case.
        @skip_xp_invalid_arg
        def test_toto_masked_array(self):
            ...

Running tests
`````````````
After applying these markers, either through ``make_xp_test_case`` or one of its
equivalents, or directly,
``spin test`` can be used with the option ``-b`` or ``--array-api-backend``::

    spin test -b numpy -b torch -s cluster

This automatically sets ``SCIPY_ARRAY_API`` appropriately and causes
tests that use the ``xp`` fixture to be collected and run only for the selected
backends. Valid backends are ``numpy``, ``array_api_strict``, ``cupy``,
``dask.array``, ``jax.numpy``, and ``torch``. One may also use the
``-m array_api_backends`` option to restrict collection to only tests using
the ``xp`` fixture. For instance the following command causes pytest to only
collect tests using the ``xp`` fixture with the CuPy backend::
spin test -b cupy -m array_api_backends
To test a library
that has multiple devices with a non-default device, a second environment
variable (``SCIPY_DEVICE``, only used in the test suite) can be set. Valid
values depend on the array library under test, e.g. for PyTorch, valid values are
``"cpu", "cuda", "mps"``. To run the test suite with the PyTorch MPS
backend, use: ``SCIPY_DEVICE=mps spin test -b torch``.
Note that in SciPy's GitHub Actions workflows, there are regular tests
with array-api-strict, Dask, PyTorch, and JAX on CPU, and tests with
CuPy, PyTorch, and JAX on GPU.
A third environment variable (``SCIPY_DEFAULT_DTYPE``, again only used in the
test suite) can be used to control the :ref:`default dtype <dev-arrayapi_default_dtype>`
used by ``torch`` in tests. Valid values are ``"float64"`` and ``"float32"``.
If ``SCIPY_DEFAULT_DTYPE`` is unset, then ``torch``'s default dtype will be ``float64``.
The intention behind testing with different default dtypes is primarily to catch
subtle bugs that can arise with the ``torch`` backend due to internal array creation
that does not explicitly specify a dtype. The intention is not to implicitly test
that functions are numerically accurate with both ``float32`` and ``float64`` inputs,
or that input dtype controls output dtype. These tasks, if done, should instead be accomplished
mindfully by explicitly setting dtypes within tests. For the sake of consistency,
tests intended to be run with the ``torch`` backend should not use array creation
functions without explicitly setting the dtype. At the time of writing, there are many
tests in the test suite which do not follow this practice, and this could be a good source
of first issues for new contributors.
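As an illustration of the documented semantics of ``SCIPY_DEFAULT_DTYPE``, the sketch below reads the variable with a ``float64`` fallback. The helper ``read_default_dtype`` is hypothetical, and raising on an invalid value is an assumption made for this example rather than a description of how SciPy's test configuration behaves.

```python
import os

VALID_DEFAULT_DTYPES = ("float64", "float32")


def read_default_dtype(environ=os.environ):
    """Return the requested default dtype name, falling back to float64."""
    value = environ.get("SCIPY_DEFAULT_DTYPE", "float64")
    if value not in VALID_DEFAULT_DTYPES:
        # Assumption for this sketch: reject anything outside the valid set.
        raise ValueError(
            f"SCIPY_DEFAULT_DTYPE must be one of {VALID_DEFAULT_DTYPES}, got {value!r}"
        )
    return value


print(read_default_dtype({}))                                  # variable unset
print(read_default_dtype({"SCIPY_DEFAULT_DTYPE": "float32"}))  # explicit choice
```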
Tests are not optional
``````````````````````
It is not permitted for a function ``f`` to have an ``xp_capabilities``
entry advertising some level of alternative backend support if there are
no tests in the test suite for ``f`` that use the ``xp`` fixture along with
``make_xp_test_case(f)`` or one of its equivalents. All documented alternative
backend support must be tested and the only valid ``xp_capabilities`` entries
for a function without tests of the form described above are ``np_only=True``
with no ``exceptions``, or ``out_of_scope=True``.
The command::
spin check --xp-markers
can be used to check that all functions advertising alternative backend
support have at least one test using the ``xp`` fixture that draws from
``xp_capabilities`` through ``make_xp_test_case`` or one of its equivalents.
This check is run in CI.
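Conceptually, the check amounts to a set difference between the functions that advertise alternative backend support and the functions covered by an ``xp``-based test. The sketch below uses invented names and is not how ``spin check --xp-markers`` is implemented.

```python
def find_untested(advertised, tested):
    """Return names that advertise backend support but have no xp-based test."""
    return sorted(set(advertised) - set(tested))


# Hypothetical function names, for illustration only.
advertised = {"scipy.mymodule.toto", "scipy.mymodule.tata"}
tested = {"scipy.mymodule.toto"}

missing = find_untested(advertised, tested)
if missing:
    print("missing xp tests for:", ", ".join(missing))
```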
.. _dev-arrayapi_backend_isolation:
Backend isolation in tests
``````````````````````````
In most cases, it's important that for any supported function ``f``, there exist
tests using the ``xp`` fixture that restrict use of alternative backends to only
the function ``f`` being tested. Other functions evaluated within a test, for
the purpose of producing reference values, inputs, round-trip calculations,
etc. should instead use the NumPy backend. This helps ensure that any failures
that occur on a backend actually relate to the function of interest, and avoids
the need to skip backends due to lack of support for functions other than
``f``. Property based integration tests which check that some invariant holds
using the same alternative backend across different functions can also have
value, giving a window into the general health of backend support for a module,
but in order to ensure the test suite actually reflects the state of backend
support for each function, it's usually vital to have tests which isolate use
of the alternative backend only to the function being tested.
To help facilitate such backend isolation, there is a function
``_xp_copy_to_numpy`` in ``scipy._lib._array_api`` which can copy an arbitrary
``xp`` array to a NumPy array, bypassing any device transfer guards, while
preserving dtypes. It is essential that this function is only used in
tests. Attempts to copy a device array to NumPy outside of tests should fail,
because otherwise it is opaque as to whether a function is working on GPU or
not. Creation of input arrays and reference output arrays, and computations that
verify that the output of the function being tested satisfies an invariant (such
as round trip tests that a function composed with its inverse gives the identity
function), should all be done with NumPy (using the ``_xp_copy_to_numpy``
function if necessary).
Such backend isolation should not be applied blindly. Consider for example a
vectorized root finding function like `scipy.optimize.elementwise.find_root`.
When testing such a function on alternative backends, isolating use of the
alternative backend only to ``find_root`` by using an input callable ``f`` (the
function for which roots are sought) that converts to and from NumPy would not
be desirable since ``find_root`` and ``f`` are so tightly coupled in this
case. In other cases, a function ``h`` used in the tests of a function ``g`` may
be known to be so simple and rock solid that there is no point in going through
the trouble of backend isolation. Maintainers are free to use their discretion to
decide whether backend isolation is necessary or desirable.
Testing the JAX JIT compiler (and lazy evaluation with Dask)
````````````````````````````````````````````````````````````
The `JAX JIT compiler <https://jax.readthedocs.io/en/latest/jit-compilation.html>`_
introduces special restrictions to all code wrapped by ``@jax.jit``, which are not
present when running JAX in eager mode. Notably, boolean masks in ``__getitem__``
and ``.at`` aren't supported, and you can't materialize the arrays by applying
``bool()``, ``float()``, ``np.asarray()`` etc. to them.

To properly test scipy with JAX, the tested scipy functions must be wrapped
with ``@jax.jit`` before they are called by the unit tests. This is done
automatically when using ``make_xp_test_case`` and its friends when the
associated ``xp_capabilities`` entry (or entries) have ``jax_jit=True``::

    from scipy._lib._array_api import make_xp_test_case, xp_assert_close
    from scipy.mymodule import toto

    @make_xp_test_case(toto)
    def test_toto(xp):
        a = xp.asarray([3, 10, 5, 16, 8, 4, 2, 1, ])
        b = xp.asarray([3, 5, 8, 4, 2, 1])
        # When xp==jax.numpy, toto is wrapped with @jax.jit
        # so long as the xp_capabilities entry for toto has
        # jax_jit=True.
        xp_assert_close(toto(a), b)

To achieve this for private functions without ``xp_capabilities`` entries,
you should tag them as follows in your test module::

    from scipy._lib._array_api import xp_assert_close
    from scipy._external.array_api_extra.testing import lazy_xp_function
    from scipy.mymodule import _private_toto_helper

    lazy_xp_function(_private_toto_helper)

    @pytest.mark.uses_xp_capabilities(False, reason="private")
    def test_private_toto_helper(xp):
        a = xp.asarray([1, 2, 3])
        b = xp.asarray([0, 2, 5])
        # When xp==jax.numpy, _private_toto_helper is wrapped with @jax.jit
        xp_assert_close(_private_toto_helper(a, b), a)

.. warning::

   If instead of importing the functions from ``scipy.mymodule``, the above example
   imported ``mymodule`` and called ``toto`` through the qualified name
   ``mymodule.toto``, ``@jax.jit`` would not be applied to ``toto``. This is due to an
   implementation specific quirk which limits the application of ``@jax.jit`` only
   to functions in the globals of the module that defines the current test.

If one wishes to use a pattern like ``mymodule.toto`` in a test, one must define a
variable ``lazy_xp_modules`` at the top of the test file to specify additional places
the testing framework should look for functions tagged with ``lazy_xp_function``::

    import scipy.mymodule as mymodule
    from scipy._lib._array_api import make_xp_test_case, xp_assert_close

    lazy_xp_modules = [mymodule]

    @make_xp_test_case(mymodule.toto)
    def test_toto(xp):
        a = xp.asarray([3, 10, 5, 16, 8, 4, 2, 1, ])
        b = xp.asarray([3, 5, 8, 4, 2, 1])
        # When xp==jax.numpy, toto is wrapped with @jax.jit
        # so long as the xp_capabilities entry for toto has
        # jax_jit=True.
        xp_assert_close(mymodule.toto(a), b)

This can be slightly annoying to remember at first, but in practice isn't too bad
once one gets in the habit of checking for this. The essential complexity of
``lazy_xp_function`` is actually quite high, and the current design trades off on
developer ergonomics to allow for a simpler implementation.
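The reason qualified names escape wrapping can be illustrated with a plain Python model. In the sketch below, ``wrap`` stands in for ``jax.jit`` and ``patch_namespace`` is a hypothetical helper that rebinds names in a single namespace; this mimics the limitation described above and is not the actual implementation of ``lazy_xp_function``.

```python
import functools
import types


def toto(x):
    return x + 1


# A throwaway module standing in for ``scipy.mymodule``.
mymodule = types.ModuleType("mymodule")
mymodule.toto = toto

calls = []


def wrap(func):
    """Stand-in for ``jax.jit``: records that the wrapped version was used."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        calls.append(func.__name__)
        return func(*args, **kwargs)
    return wrapper


def patch_namespace(namespace, func):
    """Replace every name in ``namespace`` bound to ``func`` with a wrapper."""
    for name, value in list(namespace.items()):
        if value is func:
            namespace[name] = wrap(func)


original = mymodule.toto
# A dictionary playing the role of the test module's globals.
test_globals = {"toto": original, "mymodule": mymodule}

# Only the test module's globals are searched, so the bare name is wrapped...
patch_namespace(test_globals, original)
test_globals["toto"](1)
print(calls)   # ['toto']

# ...but the qualified name still reaches the unwrapped function.
test_globals["mymodule"].toto(1)
print(calls)   # still ['toto']

# Listing the module (the role played by lazy_xp_modules) patches it as well.
patch_namespace(mymodule.__dict__, original)
mymodule.toto(1)
print(calls)   # ['toto', 'toto']
```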
Testing lazy evaluation with Dask works similarly, except ``lazy_xp_function`` wraps
functions with a decorator that disables ``compute()`` and ``persist()`` and ensures
that exceptions and warnings are raised eagerly. As with the JAX JIT,
``make_xp_test_case`` and friends will automatically do this when the associated
``xp_capabilities`` entry has ``allow_dask_compute=False``. The same warning about
requiring ``lazy_xp_modules`` applies to tests of lazy evaluation with Dask just
as it does to tests of the JAX JIT.

See full documentation `here <https://data-apis.org/array-api-extra/generated/array_api_extra.testing.lazy_xp_function.html>`_.

Adding tests for class methods
``````````````````````````````
To declare that a test is testing a particular method of a class,
one can pass a tuple of the form ``tuple[type, str]`` as an entry of
``funcs`` in ``make_xp_test_case`` and ``make_xp_pytest_marks`` or as
the argument ``func`` of ``make_xp_pytest_param``. The tuple
``(A, "f")`` signifies that one is testing the method ``A.f`` of the
class ``A``. Such a tuple is used rather than simply ``A.f``
in order to allow unambiguous specification of what is being tested in
cases where a method is inherited from a parent class::

    @make_xp_test_case((Foo, "bar"))
    def test_Foo_bar(xp):
        ...

When passing such a tuple to ``make_xp_pytest_param``, only the first
entry of the tuple is actually used in the resulting pytest param::

    @pytest.mark.parametrize(
        "cls",
        [make_xp_pytest_param(entry) for entry in [(A, "f"), (B, "f"), C]],
    )
    def test(cls, xp):
        # cls iterates over A, B, C.
        ...

When using such tuple arguments, the pytest skips and xfails will be
taken from the class level capabilities, unless a method specific
override was added in the ``method_capabilities`` kwarg of
``xp_capabilities``.
If the capabilities for ``(A, "f")`` have
``jax_jit=True`` (or if Dask is not in ``skip_backends``) then using
``@make_xp_test_case((A, "f"))`` or one of its equivalents
will cause ``lazy_xp_function`` to be applied to ``(A, "f")``.
(``lazy_xp_function`` will in this case replace ``A.f`` with
a clone to avoid unintentional modification of a parent
in cases where a method is inherited from a parent class).
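The ambiguity that the ``(class, "method")`` tuple resolves can be shown with two small classes. ``A``, ``B``, and ``resolve`` below are illustrative names only.

```python
class A:
    def f(self):
        return "A.f"


class B(A):
    pass  # inherits f unchanged


# The attribute alone cannot tell the two classes apart...
print(B.f is A.f)            # True

# ...whereas a (class, method name) pair keeps them distinct.
print((A, "f") == (B, "f"))  # False


def resolve(entry):
    """Look up the method named by a (class, name) tuple."""
    cls, name = entry
    return getattr(cls, name)


print(resolve((B, "f")) is A.f)  # True: same function, reached via B
```

Because ``B.f`` and ``A.f`` are the same object, passing ``B.f`` alone would not reveal that the test targets ``B``.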
Additional information
----------------------
Here are some additional resources which motivated some design decisions and
helped during the development phase:
* Initial `PR <https://github.com/tupui/scipy/pull/24>`__ with some discussions
* Development was quick-started from this `PR <https://github.com/scipy/scipy/pull/15395>`__, with
some inspiration taken from
`scikit-learn <https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_array_api.py>`__.
* `PR <https://github.com/scikit-learn/scikit-learn/issues/22352>`__ adding Array
API support to scikit-learn
* Some other relevant scikit-learn PRs:
`#22554 <https://github.com/scikit-learn/scikit-learn/pull/22554>`__ and
`#25956 <https://github.com/scikit-learn/scikit-learn/pull/25956>`__
.. _RFC: https://github.com/scipy/scipy/issues/18286
.. _the tracker issue: https://github.com/scipy/scipy/issues/18867
.. _dev-arrayapi_coverage:
API Coverage
------------
The below tables show the current state of alternative backend support across
SciPy's modules. Public functions, function-like callables, and classes are
included in the tables. Parts of the public API which are deemed out-of-scope
are excluded from consideration when calculating coverage percentages. If a
module or submodule contains no in-scope functions, it is excluded from the
tables. For example, `scipy.datasets` is excluded because its contents are
considered out-of-scope.
.. toctree::
:hidden:
array_api_modules_tables/cluster_vq
array_api_modules_tables/cluster_hierarchy
array_api_modules_tables/constants
array_api_modules_tables/differentiate
array_api_modules_tables/fft
array_api_modules_tables/integrate
array_api_modules_tables/interpolate
array_api_modules_tables/io
array_api_modules_tables/linalg
array_api_modules_tables/linalg_interpolative
array_api_modules_tables/ndimage
array_api_modules_tables/optimize
array_api_modules_tables/optimize_elementwise
array_api_modules_tables/signal
array_api_modules_tables/signal_windows
array_api_modules_tables/sparse
array_api_modules_tables/sparse_linalg
array_api_modules_tables/sparse_csgraph
array_api_modules_tables/spatial
array_api_modules_tables/spatial_distance
array_api_modules_tables/spatial_transform
array_api_modules_tables/special
array_api_modules_tables/stats
array_api_modules_tables/stats_contingency
array_api_modules_tables/stats_qmc
Support on CPU
``````````````
.. array-api-support-per-module::
:backend_type: cpu
:cluster.vq: array_api_support_cluster_vq_cpu
:cluster.hierarchy: array_api_support_cluster_hierarchy_cpu
:constants: array_api_support_constants_cpu
:differentiate: array_api_support_differentiate_cpu
:fft: array_api_support_fft_cpu
:integrate: array_api_support_integrate_cpu
:interpolate: array_api_support_interpolate_cpu
:io: array_api_support_io_cpu
:linalg: array_api_support_linalg_cpu
:linalg.interpolative: array_api_support_linalg_interpolative_cpu
:ndimage: array_api_support_ndimage_cpu
:optimize: array_api_support_optimize_cpu
:optimize.elementwise: array_api_support_optimize_elementwise_cpu
:signal: array_api_support_signal_cpu
:signal.windows: array_api_support_signal_windows_cpu
:sparse: array_api_support_sparse_cpu
:sparse.linalg: array_api_support_sparse_linalg_cpu
:sparse.csgraph: array_api_support_sparse_csgraph_cpu
:spatial: array_api_support_spatial_cpu
:spatial.distance: array_api_support_spatial_distance_cpu
:spatial.transform: array_api_support_spatial_transform_cpu
:special: array_api_support_special_cpu
:stats: array_api_support_stats_cpu
:stats.contingency: array_api_support_stats_contingency_cpu
:stats.qmc: array_api_support_stats_qmc_cpu
Support on GPU
``````````````
.. array-api-support-per-module::
:backend_type: gpu
:cluster.vq: array_api_support_cluster_vq_gpu
:cluster.hierarchy: array_api_support_cluster_hierarchy_gpu
:constants: array_api_support_constants_gpu
:differentiate: array_api_support_differentiate_gpu
:fft: array_api_support_fft_gpu
:integrate: array_api_support_integrate_gpu
:interpolate: array_api_support_interpolate_gpu
:io: array_api_support_io_gpu
:linalg: array_api_support_linalg_gpu
:linalg.interpolative: array_api_support_linalg_interpolative_gpu
:ndimage: array_api_support_ndimage_gpu
:optimize: array_api_support_optimize_gpu
:optimize.elementwise: array_api_support_optimize_elementwise_gpu
:signal: array_api_support_signal_gpu
:signal.windows: array_api_support_signal_windows_gpu
:sparse: array_api_support_sparse_gpu
:sparse.linalg: array_api_support_sparse_linalg_gpu
:sparse.csgraph: array_api_support_sparse_csgraph_gpu
:spatial: array_api_support_spatial_gpu
:spatial.distance: array_api_support_spatial_distance_gpu
:spatial.transform: array_api_support_spatial_transform_gpu
:special: array_api_support_special_gpu
:stats: array_api_support_stats_gpu
:stats.contingency: array_api_support_stats_contingency_gpu
:stats.qmc: array_api_support_stats_qmc_gpu
Support with JIT
````````````````
.. array-api-support-per-module::
:backend_type: jit
:cluster.vq: array_api_support_cluster_vq_jit
:cluster.hierarchy: array_api_support_cluster_hierarchy_jit
:constants: array_api_support_constants_jit
:differentiate: array_api_support_differentiate_jit
:fft: array_api_support_fft_jit
:integrate: array_api_support_integrate_jit
:interpolate: array_api_support_interpolate_jit
:io: array_api_support_io_jit
:linalg: array_api_support_linalg_jit
:linalg.interpolative: array_api_support_linalg_interpolative_jit
:ndimage: array_api_support_ndimage_jit
:optimize: array_api_support_optimize_jit
:optimize.elementwise: array_api_support_optimize_elementwise_jit
:signal: array_api_support_signal_jit
:signal.windows: array_api_support_signal_windows_jit
:sparse: array_api_support_sparse_jit
:sparse.linalg: array_api_support_sparse_linalg_jit
:sparse.csgraph: array_api_support_sparse_csgraph_jit
:spatial: array_api_support_spatial_jit
:spatial.distance: array_api_support_spatial_distance_jit
:spatial.transform: array_api_support_spatial_transform_jit
:special: array_api_support_special_jit
:stats: array_api_support_stats_jit
:stats.contingency: array_api_support_stats_contingency_jit
:stats.qmc: array_api_support_stats_qmc_jit