PJRT Examples

third_party/xla/docs/pjrt/examples.md

Example: JAX CUDA plugin

  1. Implement the PJRT C API through a wrapper (pjrt_c_api_gpu.h).
  2. Set up the entry point for the package (setup.py).
  3. Implement an initialize() method (__init__.py).
  4. Test the plugin with any of the JAX tests for CUDA.
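
Step 3 above can be sketched roughly as follows. This is a minimal illustration, not the actual JAX CUDA plugin code: the names `PLUGIN_NAME`, `LIBRARY_FILENAME`, and `find_library` are hypothetical, and the fallback to `jax._src.xla_bridge.register_plugin` assumes that internal JAX helper is available with a `library_path` keyword, which may change between JAX releases. The registration function is injectable so the sketch runs without JAX installed.

```python
from pathlib import Path

# Hypothetical names for illustration; a real plugin chooses its own.
PLUGIN_NAME = "my_cuda_plugin"
LIBRARY_FILENAME = "pjrt_plugin.so"


def find_library(package_dir):
    """Return the path of the plugin shared library, or None if it is absent."""
    candidate = Path(package_dir) / LIBRARY_FILENAME
    return candidate if candidate.exists() else None


def initialize(register=None, package_dir=None):
    """Register the PJRT plugin if its shared library is present.

    Returns True when registration was attempted, False when the
    library file could not be found.
    """
    if package_dir is None:
        # Default: look next to this __init__.py.
        package_dir = Path(__file__).parent
    path = find_library(package_dir)
    if path is None:
        return False
    if register is None:
        # Assumption: JAX's internal registration helper. It is not a
        # stable public API, so treat this import as illustrative.
        from jax._src import xla_bridge
        register = xla_bridge.register_plugin
    register(PLUGIN_NAME, library_path=str(path))
    return True
```

For step 2, the package would then advertise the module containing `initialize()` through an entry point in setup.py; to my understanding JAX discovers plugins through a `jax_plugins` entry-point group, but check the JAX plugin documentation for the exact form.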

Framework Implementations

Some references for using PJRT on the framework side, to interface with PJRT devices:

Hardware Implementations