docs/pallas/index.rst
.. _pallas:
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
It aims to provide fine-grained control over the generated code, combined with
the high-level ergonomics of JAX tracing and the jax.numpy API.
This section contains tutorials, guides and examples for using Pallas.
See also the :class:jax.experimental.pallas module API documentation.
.. warning::
Pallas is experimental and is changing frequently.
See the :ref:pallas-changelog for the recent changes.
You can expect to encounter errors and unimplemented cases, e.g., when lowering of high-level JAX concepts that would require emulation, or simply because Pallas is still under development.
.. toctree:: :caption: Guides :maxdepth: 2
quickstart pipelining grid_blockspec
.. toctree:: :caption: TPU backend guide :maxdepth: 2
tpu/index
.. toctree:: :caption: Mosaic GPU backend guide :maxdepth: 2
gpu/index
.. toctree:: :caption: Instruction Reference :maxdepth: 2
Instruction Reference <../jax.experimental.pallas>
.. toctree:: :caption: Design Notes :maxdepth: 2
design/index
.. toctree:: :caption: Other :maxdepth: 1
CHANGELOG