# XLA Flags Guidance

third_party/xla/docs/flags_guidance.md

This guide offers a curated selection of key XLA flags to assist users in effectively navigating and utilizing XLA's capabilities. The following sections detail flags that can significantly impact runtime performance and memory utilization. Should any issues, such as crashes, arise after enabling a flag, it is recommended to revert to the default setting and create a GitHub issue.

## Correctness Flags

| Flag | Description | Default Values | Suggested Values | Candidate Values |
|---|---|---|---|---|
| `xla_mosaic_on_device_checks` | Enables on-device checks for Mosaic codegen. Currently the supported checks are bounds checks, i.e., if out-of-bounds memory is touched, compilation/execution catches it. | `bounds` | `bounds` | `bounds` |
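XLA flags are commonly delivered through the `XLA_FLAGS` environment variable before the client program starts. The sketch below assumes that mechanism; how flags actually reach XLA depends on your frontend (for example, TPU frontends such as JAX may instead expect TPU runtime flags in `LIBTPU_INIT_ARGS`), and the launch command is a placeholder.

```shell
# Minimal sketch: pass an XLA flag via the XLA_FLAGS environment variable.
# Flags use a --name=value syntax and must be set before XLA initializes.
export XLA_FLAGS="--xla_mosaic_on_device_checks=bounds"

# Placeholder launch step; substitute your own entry point, e.g.:
#   python train.py
echo "$XLA_FLAGS"
```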

## Performance Flags

The following flags are instrumental in enhancing runtime performance. Experimenting with these settings may lead to considerable performance gains.

| Category | Flag | Description | Default Values | Suggested Values | Candidate Values |
|---|---|---|---|---|---|
| Pipelining | `xla_should_allow_loop_variant_parameter_in_chain`<br>`xla_should_add_loop_invariant_op_in_chain`<br>`xla_tpu_enable_ici_ag_pipelining` | These three flags should be used in conjunction to enable collective pipelining of ICI (Inter-Chip Interconnect) all-gather operations, which creates more opportunities for overlapping execution. | kDisabled<br>kDisabled<br>false | kEnabled<br>kEnabled<br>true | kDisabled/kEnabled/kAuto<br>kDisabled/kEnabled/kAuto<br>true/false |
| v5e / Async | `xla_enable_async_all_gather`<br>`xla_tpu_enable_async_collective_fusion`<br>`xla_tpu_enable_async_collective_fusion_fuse_all_gather` | These three flags should be used in conjunction to activate asynchronous all-gather operations on v5e. | kAuto<br>true<br>true | kAuto<br>true<br>true | kDisabled/kEnabled/kAuto<br>true/false<br>true/false |
| v5e / Async | `xla_tpu_enable_async_collective_fusion`<br>`xla_tpu_enable_async_collective_fusion_fuse_all_reduce` | These two flags should be used in conjunction to activate asynchronous all-reduce operations on v5e. | true<br>false | true<br>true | true/false<br>true/false |
| Async | `xla_tpu_enable_async_all_to_all` | Enables asynchronous all-to-all communication. | false | true | true/false |
| Latency-bound | `xla_all_gather_latency_bound_threshold_in_bytes` | Intended for latency-bound (i.e., small) all-gather operations. Enabling it triggers optimizations that can reduce execution time for latency-bound all-gathers. Typically used in inference workloads. | -1 (not enabled) | 4~16 MiB (i.e. `4~16 * 1024 * 1024`) | [0, 9223372036854775807] |
| Latency-bound | `xla_all_reduce_latency_bound_threshold_in_bytes` | Intended for latency-bound (i.e., small) all-reduce operations. Enabling it triggers optimizations that can reduce execution time for latency-bound all-reduces. Typically used in inference workloads. | -1 (not enabled) | 4~16 MiB (i.e. `4~16 * 1024 * 1024`) | [0, 9223372036854775807] |
| Latency-bound | `xla_collective_permute_latency_bound_threshold_in_bytes` | Intended for latency-bound (i.e., small) collective-permute operations. Enabling it triggers optimizations that can reduce execution time for latency-bound collective-permutes. Typically used in inference workloads. | -1 (not enabled) | 4~16 MiB (i.e. `4~16 * 1024 * 1024`) | [0, 9223372036854775807] |
| Latency-bound | `xla_all_to_all_latency_bound_threshold_in_bytes` | Intended for latency-bound (i.e., small) all-to-all operations. Enabling it triggers optimizations that can reduce execution time for latency-bound all-to-alls. Typically used in inference workloads. | -1 (not enabled) | 4~16 MiB (i.e. `4~16 * 1024 * 1024`) | [0, 9223372036854775807] |
| Async | `xla_enable_async_collective_permute` | Rewrites all collective-permute operations to their asynchronous variants. When set to kAuto, XLA can turn on async collective-permute automatically based on other configurations or conditions. | kAuto | kAuto | kAuto/kEnabled/kDisabled |
| Compute centric | `xla_tpu_enable_dot_strength_reduction` | Rewrites non-compute-intensive dots as multiply + reduce operations. | true | true | true/false |
| Compute centric | `xla_tpu_dot_dot_fusion` | Enables dot-dot fusion, which fuses a producer dot operation with a consumer dot operation. The producer dot's output is then not materialized in slow/main memory, driving down memory footprint. | true | true | true/false |
| Compute centric | `xla_jf_enable_multi_output_fusion` | Enables fusions that fuse multiple consumers (i.e., the resulting fusion has multiple outputs). | true | true | true/false |
| Compute centric | `xla_tpu_scoped_vmem_limit_kib` | Sets the amount of scratchpad VMEM available per op for local usage, in kibibytes. The rest of the VMEM is used as buffer space. | 16384 | 16384 | [4096, VMEM size of the architecture - 1024] |
| Compute centric | `xla_tpu_async_copy_bandwidth_scaling_factor` | Scales the effective bandwidth for async copies. Used when making prefetch decisions and deciding which tensors should live in VMEM. | 1 | 1 | (0, 1] |
| Compute centric | `xla_msa_enable_cross_program_prefetch_freeing` | Enables the freeing optimization for cross-program-prefetched buffers. | enabled | enabled | enabled/disabled |
| Compute centric | `xla_tpu_msa_inefficient_use_to_copy_ratio` | The ratio of use bytes to copy bytes for a given allocation site below which the site is considered inefficient; used when making VMEM placement decisions. A value of 0 treats all sites as efficient, while a value of 1 requires the bytes used at the site to be at least as much as the async copy bytes. | 0.5 | 0.5 | [0, 1] |
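Several rows above note that certain flags "should be used in conjunction". As a sketch (assuming flags are delivered via the `XLA_FLAGS` environment variable, which depends on your frontend), the three ICI all-gather pipelining flags would be assembled into a single string like this:

```shell
# Sketch: the three ICI all-gather pipelining flags only help when
# enabled together, so build them into a single XLA_FLAGS string.
FLAGS="--xla_should_allow_loop_variant_parameter_in_chain=kEnabled"
FLAGS="$FLAGS --xla_should_add_loop_invariant_op_in_chain=kEnabled"
FLAGS="$FLAGS --xla_tpu_enable_ici_ag_pipelining=true"
export XLA_FLAGS="$FLAGS"
echo "$XLA_FLAGS"
```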

## Memory Flags

The flags listed below are provided to address HBM-related issues. These should only be adjusted if you encounter HBM "out of memory" errors during model compilation. In all other scenarios, the default values are recommended, as altering them could adversely affect performance.

| Category | Flag | Description | Default Values | Suggested Values | Candidate Values |
|---|---|---|---|---|---|
| Scheduler | `xla_latency_hiding_scheduler_rerun` | Adjusts the behavior of the latency-hiding scheduler: each "rerun" of the process incrementally reduces the memory limit allocated for scheduling. | 1 | 5 | 0~10 (it doesn't make much sense beyond 10 reruns) |
| Fusion | `xla_tpu_rwb_fusion` | Enables reduce+broadcast types of fusions, and may decrease memory usage. | true | false | true/false |
| Scheduler | `xla_memory_scheduler` | Specifies the algorithm the memory scheduler uses to minimize memory consumption. A more advanced algorithm may produce a less memory-consuming schedule, at the cost of longer compilation time. | kDefault | kBrkga | kDefault/kList/kDfs/kPostOrder/kBrkga |
| Scheduler | `xla_tpu_enable_latency_hiding_scheduler` | Enables the latency-hiding scheduler, which allows asynchronous collectives instead of synchronous ones. Disabling it reduces memory usage at the cost of losing the performance gains from those asynchronous operations. | true | false | true/false |
| SPMD | `xla_jf_spmd_threshold_for_windowed_einsum_mib` | Sets the lower threshold on the minimum dot size that triggers collective matmul. A higher value saves memory at the cost of losing opportunities to perform collective matmul. | -1 | 10 MB~1 GB (i.e. `10*1024*1024 ~ 1024*1024*1024`) | [0, 9223372036854775807] |
| Scheduler | `xla_gpu_enable_analytical_sol_latency_estimator` | Enables the analytical estimator, which maximizes compute-communication overlap on GPUs. | true | false | true/false |
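One plausible escalation path when compilation hits HBM OOM, sketched below with the suggested values from the table (assuming `XLA_FLAGS` delivery; tune the values for your workload):

```shell
# Sketch of a first OOM mitigation step: let the latency-hiding scheduler
# rerun a few times, each pass with a tighter scheduling memory limit.
export XLA_FLAGS="--xla_latency_hiding_scheduler_rerun=5"

# If OOM persists, a more aggressive (and slower-to-compile) option is
# the BRKGA memory scheduler:
#   export XLA_FLAGS="--xla_memory_scheduler=kBrkga"
echo "$XLA_FLAGS"
```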

## Other commonly used flags

| Flag | Type | Notes |
|---|---|---|
| `xla_dump_to` | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see XLA Tools). |
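For example, a dump directory can be wired up as follows (a minimal sketch; the directory path is arbitrary):

```shell
# Sketch: collect pre-optimization HLO and other compiler artifacts
# into a directory for offline inspection.
DUMP_DIR=/tmp/xla_dump            # arbitrary path
mkdir -p "$DUMP_DIR"
export XLA_FLAGS="--xla_dump_to=$DUMP_DIR"

# Run your program here; afterwards $DUMP_DIR holds the dumped artifacts.
echo "$XLA_FLAGS"
```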

## TPU XLA flags

| Flag | Type | Notes |
|---|---|---|
| `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data-parallel sharding. |
| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data-parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. |
| `xla_tpu_spmd_rng_bit_generator_unsafe` | Boolean (true/false) | Whether to run the RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. |
| `xla_tpu_megacore_fusion_allow_ags` | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. |
| `xla_tpu_enable_ag_backward_pipelining` | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. |

## GPU XLA flags

The -O1 optimization level enables advanced compiler passes for improved GPU performance, including several categories of flags below: pipelining of data-parallel collectives (xla_gpu_enable_pipelined_all_gather, xla_gpu_enable_pipelined_all_reduce, xla_gpu_enable_pipelined_reduce_scatter), while loop unrolling (xla_gpu_enable_while_loop_double_buffering), latency hiding scheduling (xla_gpu_enable_latency_hiding_scheduler), and SOL latency estimator on Hopper/Blackwell (xla_gpu_enable_analytical_sol_latency_estimator). See GPU Effort Levels for details.
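Spelled out explicitly, the flag set that the paragraph above groups under -O1 would look roughly like the sketch below (useful when you want to toggle just one of them; assumes `XLA_FLAGS` delivery):

```shell
# Sketch: the GPU flags the -O1 effort level enables, set explicitly.
FLAGS="--xla_gpu_enable_pipelined_all_gather=true"
FLAGS="$FLAGS --xla_gpu_enable_pipelined_all_reduce=true"
FLAGS="$FLAGS --xla_gpu_enable_pipelined_reduce_scatter=true"
FLAGS="$FLAGS --xla_gpu_enable_while_loop_double_buffering=true"
FLAGS="$FLAGS --xla_gpu_enable_latency_hiding_scheduler=true"
FLAGS="$FLAGS --xla_gpu_enable_analytical_sol_latency_estimator=true"
export XLA_FLAGS="$FLAGS"
echo "$XLA_FLAGS"
```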

| Flag | Type | Notes |
|---|---|---|
| `xla_gpu_enable_latency_hiding_scheduler` | Boolean (true/false) | Enables latency-hiding schedulers to efficiently overlap asynchronous communication with computation. The default value is false. |
| `xla_gpu_enable_analytical_sol_latency_estimator` | Boolean (true/false) | Enables platform-specific scheduling decisions, which in turn improve compute-communication overlap. The default value is true. |
| `xla_gpu_analytical_latency_estimator_options` | Structured string | Configures parameters for `xla_gpu_enable_analytical_sol_latency_estimator`. Adjust by setting `nic_speed_gbps=$NIC_SPEED,nccl_op_launch_us=$LAUNCH_OVERHEAD,chunk_prep_us=$CHUNK_PREP,rtt_us=$RTT,chunk_size_bytes=$CHUNK_SIZE,gpus_per_node=$GPUS_PER_NODE`. The default value depends on the detected platform. |
| `xla_gpu_enable_triton_gemm` | Boolean (true/false) | Use Triton-based matrix multiplication. |
| `xla_gpu_enable_command_buffer` | List of CommandBufferCmdType | Which kinds of commands should be captured in command buffers. |
| `xla_gpu_all_reduce_combine_threshold_bytes` | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce ops into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough to combine at least a Transformer layer's weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256. |
| `xla_gpu_all_gather_combine_threshold_bytes` | Integer (bytes) | See `xla_gpu_all_reduce_combine_threshold_bytes` above. |
| `xla_gpu_reduce_scatter_combine_threshold_bytes` | Integer (bytes) | See `xla_gpu_all_reduce_combine_threshold_bytes` above. |
| `xla_gpu_enable_pipelined_all_gather` | Boolean (true/false) | Enable pipelining of all-gather instructions. |
| `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. |
| `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelining of all-reduce instructions. |
| `xla_gpu_enable_pipelined_host_offloading` | Boolean (true/false) | Enable pipelining of host offloading instructions. |
| `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loops. |
| `xla_gpu_enable_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension, or irrespective of their dimension. |
| `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension, or irrespective of their dimension. |
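To make the combine-threshold advice concrete, here is a sketch that sizes the thresholds to one transformer layer's weights. `LAYER_BYTES` and the 64 MiB figure are placeholders you would replace with values computed from your own model:

```shell
# Sketch: raise combine thresholds so at least one layer's weights are
# gathered/scattered in a single collective. 64 MiB is a made-up size.
LAYER_BYTES=$((64 * 1024 * 1024))   # = 67108864 bytes (placeholder)
export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$LAYER_BYTES --xla_gpu_reduce_scatter_combine_threshold_bytes=$LAYER_BYTES"
echo "$XLA_FLAGS"
```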