third_party/xla/docs/errors/error_1000.md
Category: Compile Time: HBM OOM
This error indicates that the program requires more High Bandwidth Memory (HBM) than is physically available on the TPU device.
Sample error messages:
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
XLA backends: TPU
XLA performs checks to ensure that the aggregate size of all necessary static allocations fits in the device's HBM.
The compiler manages the TPU's fixed HBM capacity across several types of allocations, such as program input and output buffers, intermediate temporaries, and stack. This error occurs when the XLA compiler cannot fit all of these allocations into the device HBM.
Carefully analyze the error message and logs to determine which category of HBM OOM below best describes your error:
If the error explicitly breaks down usage, e.g., "TC Hbm usage: X, SC Hbm usage Y", the aggregate TensorCore (TC) + SparseCore (SC) usage exceeds the HBM limit. Compare the two values to identify the bottleneck:

*   If SparseCore usage dominates, review the embedding configuration; SC memory usage is driven by parameters such as `feature_width`, `max_unique_nz_per_row` and `logical_replica_count`.
*   You can reduce peak stack usage by tuning the `--xla_sc_num_serialized_tables_to_optimize_hbm` flag, which serializes the processing of tables. This comes at the cost of reduced parallelism.
*   Higher values of `maximum_parallel_iterations` increase the amount of input data prefetched into the HBM heap. Lowering this value can free up significant memory.

If you observe the error message "Ran out of memory in memory space hbm" and one or more unexpectedly large allocations are present in the logs (> 50% of the HBM limit), it is almost never a hardware capacity issue; it is typically a configuration error. Check the XLA label (if present) of the large allocations for hints about the JAX source code that produced them, for example by confirming suspect shapes with `jax.debug.print()`s.

If you observe the error message "Ran out of memory in memory space hbm" and no unexpectedly large tensors are present in the logs, the program runs out of capacity because the aggregate sum of allocations exceeds the HBM limit. In this case, it is often helpful to visualize the memory profile to identify the specific buffers contributing to the peak usage. See Debug OOM errors with XProf for a step-by-step guide on identifying peak memory contributors.
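As a minimal sketch of capturing such a profile from a JAX program (the log directory path is an arbitrary choice here):

```python
import jax
import jax.numpy as jnp

# Capture a trace that the XProf / TensorBoard Memory Viewer can read.
# Run the suspect computation inside the trace so its allocations are
# recorded, then open /tmp/jax-trace in the profiler UI.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((4096, 4096))
    y = (x @ x).block_until_ready()  # force execution inside the trace
```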
Once you have identified some of the top contributors, use the following steps to optimize the memory footprint.
You can often resolve OOMs with these configuration adjustments:
*   When calling `jax.jit`, specify `donate_argnums` for your model parameters. This allows XLA to overwrite the input memory with the output, as shown in the sketch below.
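For example, a minimal sketch of buffer donation, assuming a toy `train_step` (the model and shapes are placeholders):

```python
import functools

import jax
import jax.numpy as jnp

# donate_argnums=(0,) tells XLA it may reuse the HBM backing `params`
# for the returned, updated parameters instead of holding both copies.
@functools.partial(jax.jit, donate_argnums=(0,))
def train_step(params, batch):
    loss = lambda p: jnp.mean((batch @ p) ** 2)
    grads = jax.grad(loss)(params)
    return params - 0.1 * grads

params = jnp.ones((1024, 1024))
batch = jnp.ones((8, 1024))
params = train_step(params, batch)  # the old `params` buffer is invalidated
```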
If configuration changes are insufficient, the model topology might be too large for the current hardware setup.
Inefficient tensor shapes are a common, silent cause of OOMs on TPUs. To get peak performance on TPUs, XLA pads tensor dimensions, typically to multiples of 128 for the minor-most dimension and 8 for the second-minor. This padding affects both input arrays and intermediate tensors (HLO temporaries), potentially inflating memory usage significantly, especially with small dimension sizes. See Array Layouts.

*   A tensor of shape (1024, 129) might be padded to (1024, 256), resulting in nearly 50% memory waste.
*   A tensor of shape (1024, 128) requires no padding and incurs 0% memory waste.
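As a rough illustration, a sketch that estimates the padded footprint of a 2D float32 tensor, assuming the typical (8, 128) tile (actual tile sizes vary by dtype and layout):

```python
import math

def padded_bytes(rows, cols, dtype_bytes=4, minor=128, second_minor=8):
    # Round each dimension up to its tiling multiple, then compute bytes.
    padded_rows = math.ceil(rows / second_minor) * second_minor
    padded_cols = math.ceil(cols / minor) * minor
    return padded_rows * padded_cols * dtype_bytes

raw = 1024 * 129 * 4                      # unpadded footprint
padded = padded_bytes(1024, 129)          # 129 -> 256 on the minor dim
print(f"wasted: {1 - raw / padded:.1%}")  # ~49.6%
```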
Key memory flags can be tuned to trade off performance for lower memory usage. However, this strategy should be used only as a last resort, since it can adversely affect performance.
If the model is close to fitting into memory, you can use the
jax.checkpoint
decorator with jax.grad to manually control which intermediates are saved on
the forward pass versus recomputed on the backward pass, trading compute cycles
for HBM.
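A minimal sketch, assuming a hypothetical two-layer model (the shapes and loss are placeholders):

```python
import jax
import jax.numpy as jnp

# Wrapping a layer in jax.checkpoint (a.k.a. jax.remat) tells XLA to
# recompute its intermediates during the backward pass instead of
# keeping them resident in HBM for the whole forward pass.
@jax.checkpoint
def layer(w, x):
    return jnp.tanh(x @ w)

def loss(params, x):
    w1, w2 = params
    return jnp.sum(layer(w2, layer(w1, x)) ** 2)

params = (jnp.ones((512, 512)), jnp.ones((512, 512)))
x = jnp.ones((8, 512))
grads = jax.grad(loss)(params, x)  # layer intermediates are rematerialized
```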
Alternatively, you can force the XLA::Rematerialization pass to prioritize
memory savings, potentially at the cost of slower compilation:
| Flag | Description | Impact / Trade-off |
|---|---|---|
| `--xla_tpu_max_hbm_size_mib` | Manually sets the limit on HBM size used by the Rematerialization pass. | Forces the compiler to work harder to fit the program into a limit smaller than the actual physical HBM. |
| `--xla_tpu_rematerialization_algo=PEAK_PRIORITY` | Focuses efforts at the points of peak memory usage. | Can be more efficient for aggressive memory reduction than the default algorithm. |
| `--xla_tpu_rematerialization_max_block_size_limit=32` | Controls the maximum number of instructions in a block that can be rematerialized at once. | Increasing this allows for more memory savings at the cost of significantly increased compile time. |
| `--xla_tpu_rematerialization_block_effort_factor=10.0` | Defines the amount of effort (compile time) spent searching for blocks to rematerialize. | Higher values allow a more exhaustive search for memory savings at the cost of increased compile time. |
| `--xla_tpu_pre_fusion_remat=true` | Enables an additional Rematerialization pass before the fusion pass. | Can find more memory savings, but increases compile time and may impact numerical stability. |
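How these flags are passed depends on your setup; for JAX on Cloud TPU, one common route is the `LIBTPU_INIT_ARGS` environment variable, set before JAX initializes the backend. A sketch, noting that flag availability can vary across releases:

```python
import os

# Must be set before the TPU backend is initialized (i.e., before the
# first JAX operation), otherwise the flags are ignored.
os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_rematerialization_algo=PEAK_PRIORITY "
    "--xla_tpu_rematerialization_block_effort_factor=10.0"
)

import jax  # noqa: E402  (imported after setting the environment)
```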
Note that changing XLA flags should be treated as a last resort, since it can adversely affect performance.
Debug OOM errors with XProf provides a tutorial on using the XProf Memory Viewer to visualize the compiler's view of HBM usage.
This tool allows you to see peak memory allocation and buffer lifetimes, which is crucial for understanding exactly what consumes HBM at the point of peak utilization. For general profiling setup, see Getting started with XProf and TensorBoard Profiling.