# Compiling a StableHLO program with XLA (`xla/examples/axpy`)
This tutorial and the code in this directory show how to write a simple StableHLO program and then compile it with XLA and PJRT. The purpose is to show how XLA can ingest a StableHLO program and produce an executable that's compatible with the local device. As such, the program is very simple: $\alpha x+y$ ("axpy").
The process includes just a few steps:

1. Define the computation as a StableHLO program.
2. Set up a PJRT client that can compile for the local device.
3. Compile the StableHLO program to an executable.
4. Execute the compiled program with some inputs and check the results.

All the code is already provided in this directory, which you can build and run using the steps at the end of this page.
We'll define the computation axpy as a StableHLO program, using an MLIR file in the StableHLO dialect.
It can be helpful to think of the computation as a graph, where each node is an operation (an "op", or "HLO", short for "high-level operation") and the edges are the data flowing between operations. The graph for axpy looks like this:
```mermaid
graph TD
    p0(alpha f32) --> mul(Multiply 4xf32)
    p1(x 4xf32) --> mul --> add(Add 4xf32)
    p2(y 4xf32) --> add
```
And here's how we define the program using MLIR (in the StableHLO dialect):
```mlir
func.func @main(
  %alpha: tensor<f32>, %x: tensor<4xf32>, %y: tensor<4xf32>
) -> tensor<4xf32> {
  %0 = stablehlo.broadcast_in_dim %alpha, dims = []
    : (tensor<f32>) -> tensor<4xf32>
  %1 = stablehlo.multiply %0, %x : tensor<4xf32>
  %2 = stablehlo.add %1, %y : tensor<4xf32>
  func.return %2: tensor<4xf32>
}
```
This code is in `stablehlo_axpy.mlir`.
Note: StableHLO expresses broadcasting explicitly, so we use `stablehlo.broadcast_in_dim` to broadcast our scalar to a 1-dimensional tensor.
Our program for this tutorial is set up as a test in `stablehlo_compile_test.cc`. In this file, you'll see that we first set up a `PjRtClient` with the XLA:CPU plugin that allows us to compile our StableHLO program:
```cpp
ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                     GetCApiClient(kCpuPjrtName));
```
Then we read the StableHLO program from our MLIR file into a string:
```cpp
// Read StableHLO program to string.
std::string program_path = tsl::io::JoinPath(
    tsl::testing::XlaSrcRoot(), "examples", "axpy", "stablehlo_axpy.mlir");
std::string program_string;
TF_ASSERT_OK(
    tsl::ReadFileToString(tsl::Env::Default(), program_path, &program_string));
```
To parse the StableHLO program, we must first register the appropriate MLIR dialects:
```cpp
// Register MLIR dialects necessary to parse our program. In our case this is
// just the Func dialect and StableHLO.
mlir::DialectRegistry dialects;
dialects.insert<mlir::func::FuncDialect>();
mlir::stablehlo::registerAllDialects(dialects);

// Parse StableHLO program.
auto ctx = std::make_unique<mlir::MLIRContext>(dialects);
mlir::OwningOpRef<mlir::ModuleOp> program =
    mlir::parseSourceString<mlir::ModuleOp>(program_string, ctx.get());
```
Now that we've set up our client and parsed the StableHLO program, we can compile it to an executable:
```cpp
// Use our client to compile our StableHLO program to an executable.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtLoadedExecutable> executable,
                        client->Compile(*program, CompileOptions{}));
```
Finally, in `stablehlo_compile_test.cc`, we can feed the executable some inputs for the three arguments and compute the results:
```cpp
// Create inputs to our computation.
auto alpha_literal = xla::LiteralUtil::CreateR0<float>(3.14f);
auto x_literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto y_literal =
    xla::LiteralUtil::CreateR1<float>({10.5f, 20.5f, 30.5f, 40.5f});

// Get the host device.
PjRtDevice* cpu = client->devices()[0];

// Transfer our literals to buffers. If we were using a GPU, these buffers
// would correspond to device memory.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtBuffer> alpha,
                        client->BufferFromHostLiteral(alpha_literal, cpu));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtBuffer> x,
                        client->BufferFromHostLiteral(x_literal, cpu));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtBuffer> y,
                        client->BufferFromHostLiteral(y_literal, cpu));

// Do our computation.
TF_ASSERT_OK_AND_ASSIGN(
    std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> axpy_result,
    executable->Execute({{alpha.get(), x.get(), y.get()}}, /*options=*/{}));

// Convert result buffer back to literal.
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<Literal> axpy_result_literal,
                        axpy_result[0][0]->ToLiteralSync());

// Check to make sure that our results match what we expect.
xla::LiteralTestUtil::ExpectR1Near<float>({13.64f, 26.78f, 39.92f, 53.06f},
                                          *axpy_result_literal,
                                          xla::ErrorSpec(0.01f));
```
You can build and run this example using Bazelisk or Bazel (run from within `xla/examples/axpy/`):

```sh
bazelisk test :stablehlo_compile_test --test_output=all --nocheck_visibility
```
Sample output from the test should look like this:
```
==================== Test output for //xla/examples/axpy:stablehlo_compile_test:
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from StableHloAxpyTest
[ RUN      ] StableHloAxpyTest.LoadAndRunCpuExecutable
Loaded StableHLO program from xla/examples/axpy/stablehlo_axpy.mlir:
func.func @main(
  %alpha: tensor<f32>, %x: tensor<4xf32>, %y: tensor<4xf32>
) -> tensor<4xf32> {
  %0 = stablehlo.broadcast_in_dim %alpha, dims = []
    : (tensor<f32>) -> tensor<4xf32>
  %1 = stablehlo.multiply %0, %x : tensor<4xf32>
  %2 = stablehlo.add %1, %y : tensor<4xf32>
  func.return %2: tensor<4xf32>
}

Computation inputs:
        alpha:f32[] 3.14
        x:f32[4] {1, 2, 3, 4}
        y:f32[4] {10.5, 20.5, 30.5, 40.5}
Computation output: f32[4] {13.64, 26.78, 39.920002, 53.06}
[       OK ] StableHloAxpyTest.LoadAndRunCpuExecutable (264 ms)
[----------] 1 test from StableHloAxpyTest (264 ms total)
[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (264 ms total)
[  PASSED  ] 1 test.
```