ML wheels

third_party/xla/third_party/py/README.md

Version 2.21.0

Provides a standardized and efficient system for packaging and verifying ML software.

ML wheels features

  • Standardized creation, validation (auditwheel) and testing of wheel artifacts.

  • Availability of the final wheel artifacts in the Bazel build phase, which lets regular py_test targets test the generated wheels together with the rest of the existing tests.

  • Ability to use Bazel RBE for wheel creation and testing.

  • Reproducible and unified steps for generating and testing the wheels on different platforms.

Getting started

  1. Integrate hermetic Python, C++ and CUDA (if needed) toolchains into the project.

    Examples:

    JAX hermetic Python and C++ integration

    JAX hermetic CUDA integration

    TensorFlow hermetic Python integration

    TensorFlow hermetic C++ and CUDA integration

  2. Create a Python script that produces a wheel, and declare it as a py_binary build rule.

    In a common scenario, the Python script takes the wheel sources passed in its argument list, applies the required transformations, and runs a command such as python -m build in the folder with the collected resources.

    JAX py_binary declaration

    TensorFlow py_binary declaration
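    The kind of script described in this step can be sketched as follows. This is a hypothetical example, not the JAX or TensorFlow script: the --output_dir and --src_root flags, the stage_sources and build_command helpers, and the assumption that the sources include a pyproject.toml are all illustrative. It copies the sources into a temporary folder and runs python -m build there.

```python
# Hypothetical wheel-building script for a py_binary target; a sketch, not
# the JAX or TensorFlow script. The flags (--output_dir, --src_root) and the
# staging layout are illustrative assumptions.
import argparse
import os
import shutil
import subprocess
import sys
import tempfile


def stage_sources(srcs, src_root, staging_dir):
    """Copy each file in srcs into staging_dir, keeping its path relative to src_root."""
    staged = []
    for src in srcs:
        rel_path = os.path.relpath(src, src_root)
        dest = os.path.join(staging_dir, rel_path)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy2(src, dest)
        staged.append(rel_path)
    return sorted(staged)


def build_command(staging_dir, output_dir):
    """Return the `python -m build` command that builds a wheel from staging_dir."""
    return [
        sys.executable, "-m", "build",
        "--wheel",         # build only the wheel, not an sdist
        "--no-isolation",  # use packages already in this (hermetic) environment
        "--outdir", output_dir,
        staging_dir,
    ]


def main(argv=None):
    parser = argparse.ArgumentParser(description="Build a wheel from Bazel-provided sources.")
    parser.add_argument("--output_dir", required=True, help="Where the .whl file is written.")
    parser.add_argument("--src_root", default=".", help="Prefix stripped from each source path.")
    parser.add_argument("srcs", nargs="+", help="Wheel sources, including pyproject.toml.")
    args = parser.parse_args(argv)

    with tempfile.TemporaryDirectory() as staging_dir:
        stage_sources(args.srcs, args.src_root, staging_dir)
        # Project-specific transformations of the staged tree would go here.
        subprocess.run(build_command(staging_dir, args.output_dir), check=True)
```

    A py_binary built from this file would call main() under an if __name__ == "__main__": guard; keeping the helpers importable makes them easy to unit-test. The sketch also assumes the build package is already installed in the hermetic Python environment, which is why it passes --no-isolation.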

  3. Create a Bazel build rule that returns a Python wheel in its output.

    In a common scenario, this Bazel rule runs the py_binary (created in step 2) that is passed in the rule attributes.

    JAX rule definition

    TensorFlow rule definition

    • The wheel sources should be provided in the wheel build rule attributes.

      To collect wheel sources in a way that works for all types of Bazel builds, including cross-compilation builds, use the following build rules: collect_data_files and transitive_py_deps from @xla//third_party/py:python_wheel.bzl, and transitive_hdrs from @xla//xla/tsl:tsl.bzl.

      jaxlib wheel sources

      TensorFlow wheel sources

    • The wheel name should conform to the PEP 491 naming convention.

      JAX example

      TensorFlow example

    • How the wheel version is stored is project-specific and should be implemented per project, for example as an additional repository rule or as a constant in a .bzl file.

      JAX example

      TensorFlow example

    • The wheel version suffix is controlled by the common repository rule python_wheel_version_suffix_repository, which should be called in the WORKSPACE file.

      JAX rule call

      TensorFlow rule call
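    To illustrate the naming and versioning points above, the sketch below composes a version from a stored base version plus a suffix and builds a wheel filename of the form {distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl. BASE_VERSION, the .dev20250101 suffix, and the wheel_filename and escape_distribution helpers are hypothetical; the real suffix comes from python_wheel_version_suffix_repository and is not modeled here.

```python
# Sketch of composing a wheel version and filename. BASE_VERSION and the
# ".dev20250101" suffix used below are hypothetical placeholders, not values
# produced by python_wheel_version_suffix_repository.
import re

BASE_VERSION = "1.2.3"  # hypothetical per-project constant, e.g. kept in a .bzl file


def escape_distribution(name):
    """Replace each run of characters other than letters, digits, '_' and '.' with '_'."""
    return re.sub(r"[^\w.]+", "_", name)


def wheel_filename(distribution, version, python_tag, abi_tag, platform_tag, build_tag=None):
    """Build {distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl."""
    if "-" in version:
        # '-' separates filename components, so it cannot appear in the version.
        raise ValueError("a wheel version must not contain '-': " + version)
    parts = [escape_distribution(distribution), version]
    if build_tag is not None:
        if not build_tag[:1].isdigit():
            raise ValueError("a build tag must start with a digit: " + build_tag)
        parts.append(build_tag)
    parts.extend([python_tag, abi_tag, platform_tag])
    return "-".join(parts) + ".whl"
```

    For example, wheel_filename("my-ml-lib", BASE_VERSION + ".dev20250101", "cp312", "cp312", "manylinux_2_27_x86_64") returns "my_ml_lib-1.2.3.dev20250101-cp312-cp312-manylinux_2_27_x86_64.whl".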

  4. To verify manylinux tag compliance, use the common py_binary verify_manylinux_compliance_test.

    JAX tests

    TensorFlow test
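    A manylinux platform tag encodes the minimum glibc version a wheel requires, either as manylinux_<major>_<minor>_<arch> (PEP 600) or as a legacy alias such as manylinux2014. The hypothetical sketch below reads that requirement from a wheel filename. It only checks the tag: it is not the implementation of verify_manylinux_compliance_test, and real compliance checking with auditwheel inspects the binaries inside the wheel. glibc_requirement and tags_within_glibc are illustrative names.

```python
# Sketch: read the glibc requirement encoded in the manylinux platform tags of
# a wheel filename. This is NOT how verify_manylinux_compliance_test works and
# does not replace auditwheel, which inspects the binaries inside the wheel.
import re

# Legacy manylinux tags and the glibc versions they alias (per PEP 600).
LEGACY_TAGS = {"manylinux1": (2, 5), "manylinux2010": (2, 12), "manylinux2014": (2, 17)}
PEP600_TAG = re.compile(r"manylinux_(\d+)_(\d+)_(\w+)")


def glibc_requirement(platform_tag):
    """Return (major, minor, arch) for a manylinux tag, or None for other tags."""
    match = PEP600_TAG.fullmatch(platform_tag)
    if match:
        return int(match.group(1)), int(match.group(2)), match.group(3)
    for legacy, (major, minor) in LEGACY_TAGS.items():
        prefix = legacy + "_"
        if platform_tag.startswith(prefix):
            return major, minor, platform_tag[len(prefix):]
    return None


def tags_within_glibc(wheel_filename, max_glibc):
    """True if every platform tag is a manylinux tag needing glibc <= max_glibc."""
    if not wheel_filename.endswith(".whl"):
        raise ValueError("not a wheel filename: " + wheel_filename)
    platform_field = wheel_filename[: -len(".whl")].split("-")[-1]
    for tag in platform_field.split("."):  # compressed tag sets are '.'-separated
        requirement = glibc_requirement(tag)
        if requirement is None or requirement[:2] > max_glibc:
            return False
    return True
```

    A check like tags_within_glibc(name, (2, 27)) only confirms the filename claims a manylinux tag no newer than glibc 2.27; whether the shared libraries actually honor that claim is what auditwheel verifies.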

  5. With the wheel build rule defined, Bazel test targets can depend on the wheel instead of on individual Bazel targets. To set this up, define a py_import target; it can be used by other Python targets in the same way as a py_library.

    JAX example

    TensorFlow example

    TensorFlow tests dependent on py_import
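    A wheel is a zip archive, so one way to let tests import from it is to unpack it and put the unpacked tree on the import path. The sketch below does this in plain Python, assuming (not confirming) that this is roughly what a py_import target provides to the targets that depend on it; unpack_wheel and import_from_wheel are hypothetical helpers, not part of py_import.

```python
# Sketch: make a wheel importable by unpacking it. This assumes, without
# confirming, that it resembles what a py_import target gives its dependents;
# unpack_wheel and import_from_wheel are hypothetical helpers.
import importlib
import os
import sys
import tempfile
import zipfile


def unpack_wheel(wheel_path, target_dir):
    """Extract a .whl file (a zip archive) into target_dir and return target_dir."""
    with zipfile.ZipFile(wheel_path) as wheel:
        wheel.extractall(target_dir)
    return target_dir


def import_from_wheel(wheel_path, module_name):
    """Unpack the wheel into a temp dir, prepend it to sys.path and import module_name."""
    target_dir = unpack_wheel(wheel_path, tempfile.mkdtemp(prefix="wheel_"))
    sys.path.insert(0, target_dir)
    importlib.invalidate_caches()  # make the import system notice the new path entry
    return importlib.import_module(module_name)
```

    Because imports resolve against the unpacked wheel, a test that depends only on the wheel fails with an ImportError if a module was left out of the package.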