all-reduce-start and
all-reduce-done).

To target the first shortcoming, we propose to introduce one last set of new
asynchronous opcodes: kAsyncStart, kAsyncUpdate, and kAsyncDone. The idea
is to create a generic asynchronous opcode that can wrap any HLO instruction.
The operation to be performed asynchronously is encoded as a called
computation that contains only the wrapped instruction as its root, plus
parameters for its inputs. The in-flight input/output buffer handling and aliasing
can then be shared across all asynchronous operations. The async-start instruction’s
output shape will then be a tuple of the input operands, output values, and any
intermediate state that is needed for the async-update or async-done
instructions.

    %async_op {
      %param0 = f32[64] parameter(0)
      ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr="foo"
    }

    %async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand),
                                             calls=%async_op
    %async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start)

In the representation above, only async-start has a called computation, because
the computation for an async-done can be found by following its operand back to
the corresponding async-start.
Also note
that the first element in the output tuple of async-start aliases with the
operand, so the buffer stays alive until at least the async-done instruction.
Similarly, the second element aliases with the output of async-done, and the
third element is the context state that is used to keep track of the
asynchronous operation. This representation also supports multiple tensors in
the asynchronous operation input and/or output and the aliasing works the same
way:

    %async_op {
      %param0 = f32[64] parameter(0)
      %param1 = f32[64] parameter(1)
      ROOT %op = (f32[32], f32[32]) op(f32[64] %param0, f32[64] %param1),
                                    op_specific_attr="foo"
    }

    %async-start = ((f32[64], f32[64]), (f32[32], f32[32]), s32[])
                   async-start(f32[64] %operand0, f32[64] %operand1),
                   calls=%async_op
    %async-done = (f32[32], f32[32]) async-done(%async-start)

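The shape rule described above can be sketched in a few lines of Python. This is only an illustration that manipulates shape strings, not XLA code: the helper names (`tuple_shape`, `async_start_shape`, `async_done_shape`) are hypothetical, and the `s32[]` context shape is an assumption taken from the examples in this document.

```python
from typing import Sequence


def tuple_shape(shapes: Sequence[str]) -> str:
    """Render a single shape as-is and several shapes as an HLO tuple shape."""
    if len(shapes) == 1:
        return shapes[0]
    return "(" + ", ".join(shapes) + ")"


def async_start_shape(operand_shapes: Sequence[str],
                      output_shapes: Sequence[str],
                      context_shape: str = "s32[]") -> str:
    """Build the (operands, outputs, context) tuple shape of async-start.

    The context shape defaults to s32[] only because the examples in this
    document use it; it is an assumption of this sketch.
    """
    parts = [tuple_shape(operand_shapes), tuple_shape(output_shapes), context_shape]
    return "(" + ", ".join(parts) + ")"


def async_done_shape(output_shapes: Sequence[str]) -> str:
    """async-done yields only the outputs of the wrapped operation."""
    return tuple_shape(output_shapes)


# Single-tensor and multi-tensor cases from the listings above.
print(async_start_shape(["f32[64]"], ["f32[32]"]))
print(async_start_shape(["f32[64]", "f32[64]"], ["f32[32]", "f32[32]"]))
print(async_done_shape(["f32[32]", "f32[32]"]))
```

The first element of the resulting tuple corresponds to the buffer that aliases the operand(s), the second to the buffer that aliases the async-done output, and the third to the context state.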
In addition, the op can be further decomposed into zero or more async-update
steps that perform intermediate computations. The input/output aliasing works
the same way with the async-update instruction, and every async-start and
async-update instruction must have exactly one user, which is either another
async-update or an async-done:

    %async_op {
      %param0 = f32[64] parameter(0)
      ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr="foo"
    }

    %async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand),
                                             calls=%async_op
    %async-update0 = (f32[64], f32[32], s32[]) async-update(
                               (f32[64], f32[32], s32[]) %async-start)
    %async-update1 = (f32[64], f32[32], s32[]) async-update(
                               (f32[64], f32[32], s32[]) %async-update0)
    %async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-update1)

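The two structural properties described so far, finding the called computation from any instruction in the chain and requiring a single async-update or async-done user, can be modelled with a toy graph. The `Instr` class and both functions below are hypothetical names invented for this sketch; they do not correspond to XLA's actual C++ classes or verifier.

```python
from typing import List, Optional


class Instr:
    """Toy stand-in for an HLO instruction with at most one operand."""

    def __init__(self, name: str, opcode: str,
                 operand: Optional["Instr"] = None,
                 calls: Optional[str] = None) -> None:
        self.name = name
        self.opcode = opcode
        self.operand = operand
        self.calls = calls  # only set on async-start
        self.users: List["Instr"] = []
        if operand is not None:
            operand.users.append(self)


def called_computation(instr: Instr) -> Optional[str]:
    """Follow operands back to the async-start and return its computation."""
    while instr.opcode != "async-start":
        if instr.opcode not in ("async-update", "async-done") or instr.operand is None:
            raise ValueError(f"{instr.name} is not part of an async chain")
        instr = instr.operand
    return instr.calls


def check_chain(start: Instr) -> List[str]:
    """Check that each start/update has one user that is an update or a done."""
    errors: List[str] = []
    current = start
    while current.opcode != "async-done":
        if len(current.users) != 1:
            errors.append(f"{current.name} has {len(current.users)} users, expected 1")
            break
        user = current.users[0]
        if user.opcode not in ("async-update", "async-done"):
            errors.append(f"{current.name} is used by {user.opcode}")
            break
        current = user
    return errors


# Mirror of the listing above: start -> update0 -> update1 -> done.
start = Instr("async-start", "async-start", calls="async_op")
update0 = Instr("async-update0", "async-update", operand=start)
update1 = Instr("async-update1", "async-update", operand=update0)
done = Instr("async-done", "async-done", operand=update1)

print(called_computation(done))
print(check_chain(start))
```

Storing the computation only on async-start keeps a single source of truth; the cost is a short walk along the operand chain whenever an async-update or async-done needs it.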
Because defining a separate computation for the operation that will be performed
asynchronously is somewhat cumbersome, we also propose syntactic sugar that
automatically prints and parses asynchronous operations as if they were
first-class opcodes. The idea is to treat the “-start”, “-update”, and “-done”
suffixes specially by automatically creating the computation and the instruction
(without the suffix) when parsing. For example, the code snippet above can be
pretty-printed as follows, and both forms parse to the same representation:

    %op-start = (f32[64], f32[32], s32[]) op-start(f32[64] %operand),
                                          op_specific_attr="foo"
    %op-update0 = (f32[64], f32[32], s32[]) op-update(
                            (f32[64], f32[32], s32[]) %op-start),
                            op_specific_attr="foo"
    %op-update1 = (f32[64], f32[32], s32[]) op-update(
                            (f32[64], f32[32], s32[]) %op-update0)
    %op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1)

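The suffix handling behind this sugar might look roughly like the following. This is a simplified sketch of the opcode-name mapping only; it does not build the wrapped computation, and `desugar_opcode` and `sugar_opcode` are hypothetical helpers rather than functions of the XLA parser.

```python
from typing import Optional, Tuple

# Suffix of the sugared opcode -> generic asynchronous opcode.
ASYNC_SUFFIXES = {
    "-start": "async-start",
    "-update": "async-update",
    "-done": "async-done",
}


def desugar_opcode(opcode: str) -> Optional[Tuple[str, str]]:
    """Map 'op-start' to ('async-start', 'op'); return None if not sugared."""
    if opcode in ASYNC_SUFFIXES.values():
        return None  # already one of the generic async opcodes
    for suffix, async_opcode in ASYNC_SUFFIXES.items():
        if opcode.endswith(suffix) and len(opcode) > len(suffix):
            return async_opcode, opcode[: -len(suffix)]
    return None


def sugar_opcode(async_opcode: str, wrapped_opcode: str) -> str:
    """Inverse mapping used when pretty-printing: ('async-done', 'op') -> 'op-done'."""
    return wrapped_opcode + async_opcode[len("async"):]


print(desugar_opcode("op-start"))
print(desugar_opcode("op-update"))
print(sugar_opcode("async-done", "op"))
```

Because the two functions are inverses of each other for sugared names, printing and then parsing an instruction returns the same generic opcode and wrapped opcode.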
To avoid ambiguity, the verifier will not allow an operation to
be wrapped with async-start if an opcode with the “-start” and/or “-done” suffix
is explicitly defined for that operation. This also serves as an escape
hatch for any instructions that require HLO-level treatment that
doesn’t fit in the model described above (e.g. the aliasing input/output
buffers). So, initially, copy-start/copy-done,
collective-permute-start/collective-permute-done etc. will continue to use
their respective first-class opcodes instead of the new
async-start/async-done opcodes until we clean up the code to remove these
“-start”/“-done” opcodes.
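The ambiguity rule can be expressed as a simple membership check. The sketch below is illustrative only: `can_wrap_in_async_start` is a hypothetical helper, and the opcode set lists just the first-class opcodes named in this section, not the full set that the real verifier would consult.

```python
# Illustrative, incomplete set: only the first-class opcodes named above.
FIRST_CLASS_ASYNC_OPCODES = frozenset({
    "copy-start", "copy-done",
    "collective-permute-start", "collective-permute-done",
})


def can_wrap_in_async_start(wrapped_opcode: str,
                            first_class=FIRST_CLASS_ASYNC_OPCODES) -> bool:
    """Reject wrapping when a dedicated -start or -done opcode already exists."""
    return (wrapped_opcode + "-start" not in first_class
            and wrapped_opcode + "-done" not in first_class)


print(can_wrap_in_async_start("op"))
print(can_wrap_in_async_start("copy"))
print(can_wrap_in_async_start("collective-permute"))
```

With this check in place, a name such as `copy-start` always refers to the first-class opcode and is never produced by pretty-printing an async-start that wraps `copy`.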