third_party/xla/docs/lhs_cost_model.md
This page describes the internals of the cost model used by the Latency Hiding Scheduler. If you are only interested in tuning the model, go straight to the Tuning section.
The Latency Hiding Scheduler (LHS) is a compiler pass that schedules an HLO DAG in a way that minimizes wall time.
Its decisions are guided by a unified cost model, which uses a mixture of performance tables and analytical models. In particular, XLA embeds performance tables for GEMMs and fast-interconnect collectives, and uses analytical networking and fusion cost models for the remaining cases. The rest of this document describes the inner workings of these components at a high level.
Performance tables consist of two main components: a collector and an interpolator.
The collector is a C++ tool responsible for generating the performance
tables for collective operations. It measures the performance of individual HLO
ops (e.g., all-gather, all-reduce) across a statically defined parameter
space.
The tool performs a sweep over a range of collective ops, transfer sizes, and
transfer schemes for a given cluster. It uses the existing multi-host HLO runner
infrastructure and ExecutionProfile data to run the generated HLO and gather
performance metrics.
Latency tables are collected for a cross-product of the following parameters:

*   Collective ops: `all-reduce`, `all-gather`, `reduce-scatter`
*   Transfer schemes: rail-aligned, non-rail-aligned

This sweep is run for intra-node clusters with 2, 4, and 8 devices.
The result of a collection run is a latency table in .pbtxt format
(approximately 116 KB per platform).
The interpolator is the compiler component that consumes the generated performance tables to provide runtime estimates during compilation.
On initialization, the Interpolator processes the performance table into a map.
This map uses a tuple of (collective_type, transfer_scheme) as its key.
The value associated with each key is a 2D Euclidean plane. This plane indexes the network throughput (measured by the Collector) along two axes: transfer size and number of devices.
When the compiler encounters a collective operation, the Interpolator performs the following steps:

1.  It looks up `(collective_type, transfer_scheme)` as the map key.
2.  It queries the plane with `(transfer_size, num_devices)` as the query point.

The system is designed to store network throughput rather than raw latency. This design choice significantly simplifies extrapolating performance for transfer sizes not explicitly present in the table.
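The lookup described above can be sketched as follows. This is a simplified, hypothetical Python stand-in for the C++ interpolator; the names (`throughput_table`, `estimate_runtime`) and the throughput numbers are illustrative, not the actual XLA implementation.

```python
# Key: (collective_type, transfer_scheme); value: measured points mapping
# (transfer_size_bytes, num_devices) -> throughput in bytes/second.
# The throughput numbers below are made up for illustration.
throughput_table = {
    ("all-reduce", "rail-aligned"): {
        (2**20, 8): 50e9,
        (2**30, 8): 400e9,
    },
}

def estimate_runtime(collective_type, transfer_scheme, transfer_size, num_devices):
    """Nearest-neighbor estimate: pick the closest measured point in the
    (transfer_size, num_devices) plane and divide size by its throughput."""
    plane = throughput_table[(collective_type, transfer_scheme)]
    nearest = min(
        plane,
        key=lambda p: (p[0] - transfer_size) ** 2 + (p[1] - num_devices) ** 2,
    )
    return transfer_size / plane[nearest]
```

A real interpolator would blend neighboring points rather than snap to the nearest one, but the shape of the data structure is the same: a map keyed by collective type and transfer scheme, pointing at a 2D plane of measurements.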
If the latency tables capture network bandwidth saturation at a collective size
S, the throughput T at that point is considered the maximum. For any new
collective of size S' > S, the runtime can be estimated as:
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
$$\text{EstimatedTime}(S') = \frac{S'}{T_{\text{saturated}}}$$
<!-- mdformat on -->

This allows the model to estimate performance for collectives of any size, even those larger than the 2 GiB maximum measured by the Collector.
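The extrapolation rule can be sketched in a few lines. The saturation size and throughput below are made-up values for illustration, not numbers from any actual latency table.

```python
# Illustrative values: assume throughput saturates at 1 GiB transfers.
SATURATION_SIZE = 2**30        # bytes (S in the formula above)
SATURATED_THROUGHPUT = 400e9   # bytes/second (T_saturated), made-up number

def estimated_time(transfer_size):
    """Beyond the saturation point, runtime scales linearly with size:
    EstimatedTime(S') = S' / T_saturated."""
    if transfer_size < SATURATION_SIZE:
        raise ValueError("below saturation: interpolate within the table instead")
    return transfer_size / SATURATED_THROUGHPUT
```

The key property is that once the table contains a saturated throughput measurement, doubling the transfer size simply doubles the estimated runtime.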
Important: This extrapolation model relies on the assumption that the generated latency tables capture true network bandwidth saturation. If the tables do not contain measurements at or beyond the saturation point, the interpolator will extrapolate from a throughput below the true maximum and overestimate the runtime of large collectives.
In general, the XLA:GPU team maintains the performance tables. If users decide to provide their own, it is their responsibility to ensure the tables are representative and include measurements in the bandwidth-saturated region for the target hardware.
Similar to the system for collectives, GEMM latency tables are supported by two components: a collector and an interpolator.
The collector is a C++ tool that computes performance tables for General
Matrix Multiplications (GEMMs). It measures the performance of matrix
multiplications at the HLO dot op level.
The tool performs a sweep over a static space of GEMM dimensions (batch, two non-contracting, and one contracting dimension) and data types.
Latency tables are collected for a cross-product of the following data types and dimensions:

*   Data types: LHS = {bf16, f32}, RHS = {bf16, f32}, OUT = {bf16, f32}
*   Batch dimension: {1, 2, 4}
*   First non-contracting dimension: {256, 512, ..., 4096}
*   Second non-contracting dimension: {256, 512, ..., 4096}
*   Contracting dimension: {256, 512, ..., 4096}

A full sweep generates a `.pbtxt` latency table, ready to be consumed by the interpolator.
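The size of such a sweep can be sketched with a simple cross-product. Note the spacing within {256, 512, ..., 4096} is not spelled out above; the step of 256 below is an assumption for illustration.

```python
import itertools

# Dimension ranges mirroring the sweep listed above; a uniform step of
# 256 is assumed here, the actual sweep's spacing may differ.
batch = [1, 2, 4]
m = list(range(256, 4096 + 1, 256))  # 256, 512, ..., 4096
n = list(range(256, 4096 + 1, 256))
k = list(range(256, 4096 + 1, 256))

# One (b, m, n, k) entry per GEMM shape, per data-type combination.
sweep = list(itertools.product(batch, m, n, k))
```

Even a modest grid like this produces thousands of shapes per data-type combination, which is why the results are stored as a precomputed `.pbtxt` table rather than measured at compile time.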
The interpolator is the compiler component that uses the generated tables to estimate GEMM performance.
The collected latency tables allow the interpolator to reconstruct FLOPS for each entry:
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
$$\text{FLOPS} = \frac{2 \times b \times m \times n \times k}{\text{runtime}}$$
<!-- mdformat on -->

A key insight is that FLOPS saturate at a certain point; that is, the hardware reaches peak FLOPS beyond a certain matrix shape. This saturation allows the use of the same extrapolation method employed for collectives.
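As a concrete instance of the formula above (the shape and runtime below are illustrative numbers, not entries from a real table):

```python
def flops(b, m, n, k, runtime_s):
    """A dot op with batch b, non-contracting dims m and n, and contracting
    dim k performs 2*b*m*n*k floating point operations (one multiply and
    one add per contraction step)."""
    return 2 * b * m * n * k / runtime_s

# Illustrative table entry: a 4096x4096x4096 GEMM measured at 1 ms
# reconstructs to roughly 1.37e14 FLOPS.
print(flops(1, 4096, 4096, 4096, 1e-3))
```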
The interpolator builds a 4D Euclidean space from the table data. To provide a performance estimate, it performs a weighted-average interpolation within this 4D space. If there is no table for a given data type, each dimension is normalized to its size in bytes as a heuristic.
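A weighted-average interpolation over such a 4D point set can be sketched as follows. This is a simplified inverse-distance-weighting stand-in, not the actual C++ interpolator; the point values in the usage note are made up.

```python
def weighted_average_interpolation(points, query):
    """points: dict mapping a 4D tuple (b, m, n, k) to a measured value
    (e.g., FLOPS). Returns an inverse-squared-distance-weighted average of
    all measured values at `query`; an exact table hit is returned as-is."""
    if query in points:
        return points[query]
    weights = {}
    for p in points:
        dist2 = sum((a - b) ** 2 for a, b in zip(p, query))
        weights[p] = 1.0 / dist2  # closer points contribute more
    total = sum(weights.values())
    return sum(points[p] * w for p, w in weights.items()) / total
```

For example, with two measured points `{(1, 256, 256, 256): 10.0, (1, 4096, 4096, 4096): 100.0}`, a query at the midpoint `(1, 2176, 2176, 2176)` is equidistant from both and yields their average, 55.0.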
The S-curve model is a fully analytical networking roofline model.
The model is designed to estimate the performance of collective operations based on a set of fixed network properties.
The model requires two categories of inputs:
Fixed Network Properties (User-Defined):

*   NIC speed per GPU
*   Number of GPUs per node connected by a fast interconnect (e.g., NVLink)

By default, XLA auto-detects the platform and uses values for the most common architectures. These properties are configurable by the user; see the Tuning section for details.
Per-Collective Inputs:

*   The collective type (e.g., `AllGather`, `ReduceScatter`)

The S-curve model is integrated into XLA:GPU and is used on Hopper and Blackwell.
For other kernels, we rely on the GPU performance cost model to estimate runtimes. You can read more about it here.
The S-curve model can be tuned with the right XLA flags. The default configuration should be good enough in the majority of cases, but the model's controls are exposed for the rest.

```shell
export NIC_SPEED_GBPS=... # NIC speed per GPU, in gigabytes per second
export GPUS_PER_NODE=...  # Number of GPUs per node interconnected with a fast network (e.g. NVLink)
export XLA_FLAGS=--xla_gpu_analytical_latency_estimator_options="nic_speed_gbps=$NIC_SPEED_GBPS,gpus_per_node=$GPUS_PER_NODE"
```