
.. _pallas:

Pallas: a JAX kernel language
=============================

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
It aims to provide fine-grained control over the generated code, combined with
the high-level ergonomics of JAX tracing and the ``jax.numpy`` API.
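As a rough illustration of what such a kernel looks like, here is a minimal sketch of an element-wise vector addition. The names ``add_vectors_kernel`` and ``add_vectors`` are chosen for this example only, and ``interpret=True`` is an assumption made so that the sketch can run on CPU without a GPU or TPU; a real kernel would normally drop it and would usually also specify a grid and block specs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Kernel arguments are references to buffers, not arrays.
    # Indexing with [...] reads the whole block; assigning writes it back.
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # out_shape describes the output buffer that Pallas allocates.
    # interpret=True (an assumption for this sketch) runs the kernel through
    # the interpreter instead of compiling it for an accelerator.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
result = add_vectors(x, x)
```

The kernel body is written with ordinary ``jax.numpy``-style operations, while ``pl.pallas_call`` is the point where the kernel is handed to Pallas.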

This section contains tutorials, guides and examples for using Pallas.
See also the :class:`jax.experimental.pallas` module API documentation.

.. warning::
   Pallas is experimental and is changing frequently.
   See the :ref:`pallas-changelog` for the recent changes.

   You can expect to encounter errors and unimplemented cases, e.g., when
   lowering high-level JAX concepts that would require emulation,
   or simply because Pallas is still under development.

.. toctree::
   :caption: Guides
   :maxdepth: 2

   quickstart
   pipelining
   grid_blockspec

.. toctree::
   :caption: TPU backend guide
   :maxdepth: 2

   tpu/index

.. toctree::
   :caption: Mosaic GPU backend guide
   :maxdepth: 2

   gpu/index

.. toctree::
   :caption: Instruction Reference
   :maxdepth: 2

   Instruction Reference <../jax.experimental.pallas>

.. toctree::
   :caption: Design Notes
   :maxdepth: 2

   design/index

.. toctree::
   :caption: Other
   :maxdepth: 1

   CHANGELOG