This doc focuses on recommendations for integrating with PJRT and for testing a PJRT integration with JAX.
Option A: You can implement the PJRT C API directly.
Option B: If you're able to build against C++ code in the xla repo (via forking or bazel), you can also implement the PJRT C++ API and use the C→C++ wrapper:
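
1. Implement a C++ PJRT client by subclassing xla::PjRtClient and the related PJRT classes (devices, buffers, executables).
2. Implement the C API methods that are not part of the C++ PJRT client, such as PJRT_Client_Create. Below is sample code, assuming GetPluginPjRtClient returns the C++ PJRT client implemented in step 1:

```c++
#include <memory>
#include <utility>

// Path relative to the root of the XLA repository.
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
// Defined by your plugin: returns your C++ PJRT client.
std::unique_ptr<xla::PjRtClient> GetPluginPjRtClient();

// Creates the C++ client and wraps it so it can be returned through the C API.
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;  // A null PJRT_Error* signals success.
}
}  // namespace my_plugin
```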
Note that PJRT_Client_Create can take options passed from the framework. The GPU plugin's client creation is one example of a client that uses this feature.
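On the JAX side, such options come from the options argument of jax._src.xla_bridge.register_plugin; in current versions of pjrt_c_api.h they reach the plugin as an array of PJRT_NamedValue entries (create_options / num_options in PJRT_Client_Create_Args). The sketch below is a hypothetical helper, validate_plugin_options, that checks an options dict before it is handed to register_plugin. It assumes the value types allowed by the options annotation in recent JAX versions (str, int, list of int, float, bool); verify this against the JAX version you use.

```python
from typing import Mapping


def validate_plugin_options(options: Mapping[str, object]) -> dict:
    """Hypothetical helper: checks option types before calling register_plugin.

    Returns a plain dict copy of `options`; raises TypeError on any value
    whose type cannot be passed through to the plugin.
    """
    checked = {}
    for name, value in options.items():
        if not isinstance(name, str):
            raise TypeError(f"option name must be str, got {type(name).__name__}")
        if isinstance(value, list):
            # Lists are only supported as lists of integers.
            if not all(isinstance(v, int) for v in value):
                raise TypeError(f"option {name!r}: list elements must be int")
        elif not isinstance(value, (str, int, float, bool)):
            raise TypeError(
                f"option {name!r}: unsupported type {type(value).__name__}")
        checked[name] = value
    return checked
```

For example, a plugin might call validate_plugin_options({"memory_fraction": 0.5, "visible_devices": [0, 1]}) and pass the result as options=...; both option names here are hypothetical, and it is up to your PJRT_Client_Create to interpret them.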
With the wrapper, you do not need to implement the remaining C APIs.
You need to implement a function GetPjrtApi that returns a PJRT_Api* containing function pointers to your PJRT C API implementations. Below is an example assuming the plugin is implemented through the wrapper (similar to pjrt_c_api_cpu.cc):

```c++
// Exported with C linkage so the framework can look the symbol up by name.
extern "C" const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}
```
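Because the framework loads the plugin as a shared library and looks GetPjrtApi up by name, the function must be exported under its plain, unmangled name; in C++ this typically means declaring it extern "C". A quick sanity check of a built plugin is to load it with Python's ctypes. This is a minimal sketch for POSIX systems; exports_symbol is a hypothetical helper name.

```python
import ctypes
from typing import Optional


def exports_symbol(library_path: Optional[str], symbol: str) -> bool:
    """Returns True if the shared library exports `symbol` under that exact name.

    library_path=None inspects the symbols already loaded into this process.
    """
    lib = ctypes.CDLL(library_path)
    # ctypes resolves attributes with dlsym(); a C++-mangled function such as
    # _Z10GetPjrtApiv is not found under the plain name "GetPjrtApi".
    return hasattr(lib, symbol)
```

For example, exports_symbol("path/to/my_plugin.so", "GetPjrtApi") should return True for a correctly built plugin (the path is a placeholder). Loading the library this way runs its static initializers and requires its dependent libraries to be resolvable.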
You can call RegisterPjRtCApiTestFactory to run a small set of tests for basic PJRT C API behaviors.
You can use JAX nightly:

```shell
pip install --pre -U jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
pip install git+https://github.com/google/jax
```
For now, you need to match the jaxlib version with the PJRT C API version. It's usually sufficient to use a jaxlib nightly version from the same day as the XLA commit you're building your plugin against, for example:

```shell
pip install --pre -U jaxlib==0.6.1.dev20250428 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
```
You can also build jaxlib from source at exactly the XLA commit you're building against (see the JAX instructions for building from source).
We will start supporting ABI compatibility soon.
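When matching versions by date, it can help to build or inspect nightly version strings programmatically. The sketch below assumes the nightly format seen in the pin above, a release number followed by .dev and a YYYYMMDD date (as in 0.6.1.dev20250428); the release number itself (0.6.1 here) must be looked up on the nightly index. Both function names are hypothetical.

```python
import datetime
import re
from typing import Optional

# Assumed nightly format: "<major>.<minor>.<patch>.dev<YYYYMMDD>".
_NIGHTLY_RE = re.compile(r"^\d+\.\d+\.\d+\.dev(\d{8})$")


def nightly_pin(base_version: str, day: datetime.date) -> str:
    """Builds a nightly version string for `pip install jaxlib==...`."""
    return f"{base_version}.dev{day:%Y%m%d}"


def nightly_date(version: str) -> Optional[datetime.date]:
    """Returns the date embedded in a nightly version, or None if not nightly."""
    match = _NIGHTLY_RE.match(version)
    if match is None:
        return None
    return datetime.datetime.strptime(match.group(1), "%Y%m%d").date()
```

To check an existing environment, nightly_date(importlib.metadata.version("jaxlib")) gives the date of the installed nightly, which you can compare with the date of the XLA commit your plugin is built against.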
There are two options for your plugin to be discovered by JAX.
Option 1: Use the jax_plugins namespace package (i.e. create a jax_plugins directory and define your module below it). Here is an example directory structure:

```
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
```
Option 2: Set up an entry point in the jax_plugins group that points to your full module name. Here is an example via pyproject.toml or setup.py:

```toml
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
```

```python
# use setup.py: pass this as the entry_points argument of setuptools.setup()
entry_points={
    "jax_plugins": [
        "my_plugin = my_plugin",
    ],
}
```
Here are examples of how openxla-pjrt-plugin is implemented using Option 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120.
You need to implement an initialize() function in your Python module to register the plugin, for example:
```python
import os

import jax._src.xla_bridge as xb


def initialize():
    # The plugin's shared library sits next to this __init__.py.
    path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
    xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
```
See the docstring of register_plugin in jax/_src/xla_bridge.py for how to use xla_bridge.register_plugin. It is currently a private method; a public API will be released in the future.
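To verify that the plugin is registered, run the lines below; JAX raises an error if the plugin can't be loaded.

```python
import jax

jax.config.update("jax_platforms", "my_plugin")
# Backends are initialized lazily; querying the devices forces the plugin to load.
print(jax.devices())
```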
JAX may have multiple backends/plugins. There are a few options to ensure your plugin is used as the default backend:

- Run jax.config.update("jax_platforms", "my_plugin") at the beginning of the program.
- Set the environment variable JAX_PLATFORMS=my_plugin.
- Set a high enough priority when calling xla_bridge.register_plugin. The backend with the highest priority is used only when JAX_PLATFORMS=''. The default value of JAX_PLATFORMS is '', but sometimes it gets overwritten.

Some basic test cases to try:
```python
import jax

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap
arr = jax.numpy.arange(jax.device_count())
print(jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
```
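The expected pmap output follows from simple arithmetic: with x = arange(n) spread over n devices, psum(x, 'i') gives 0 + 1 + ... + (n - 1) = n(n - 1)/2 on every device, and each device adds its own element. The plain-Python reference below (expected_pmap_output is a hypothetical name) computes this for any device count, so you can compare it with your plugin's result, for example numpy.asarray(out).tolist().

```python
from typing import List


def expected_pmap_output(device_count: int) -> List[int]:
    """Reference result of x + psum(x) with x = arange(device_count)."""
    total = device_count * (device_count - 1) // 2  # Sum of 0..n-1.
    return [i + total for i in range(device_count)]
```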
(We'll add instructions for running the jax unit tests against your plugin soon!)
For more examples of PJRT plugins, see PJRT Examples.