# SparseCore
SparseCore is a specialized tiled processor engineered for high-performance acceleration of workloads that involve irregular, sparse memory access and computation, particularly on large datasets stored in High Bandwidth Memory (HBM). While it excels at tasks like embedding lookups, its capabilities extend to accelerating a variety of other dynamic and sparse workloads.
## Architecture

Key architectural features:
| Attribute | TPU v4 | TPU v5p | Trillium |
|---|---|---|---|
| SparseCores/Chip | 4 | 4 | 2 |
| Tiles/SparseCore | 16 | 16 | 16 |
| SIMD Width | 8 | 8 | 8 (F32) 16 (BF16) |
| HBM Capacity | 32 GiB | 96 GiB | 32 GiB |
## Host preprocessing

Effective data preparation is paramount for SparseCore performance, and this
is where host preprocessing plays a vital role. It encompasses several key
functionalities, including converting input batches into sparse formats and
validating ID counts against configured limits such as `max_ids_per_partition`
and `max_unique_ids_per_partition`.

## Table stacking

Table stacking is a significant optimization technique in which multiple
embedding tables are logically combined to improve embedding lookup
efficiency. This process is typically handled automatically by the underlying
ML framework.

The primary advantage of table stacking is the creation of a larger effective
batch size for the operations on these stacked tables. This reduces
computational overhead and can help hide inter-chip interconnect (ICI)
latencies. For optimal performance, a moderate number of stacked tables
(generally in the range of 5 to 100) is recommended.
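To make the idea concrete, here is a minimal NumPy sketch of how tables with
the same feature width can be logically stacked by offsetting IDs into a
shared vocabulary. The helper name `stack_tables` and the toy tables are
illustrative assumptions, not a framework API:

```python
import numpy as np

def stack_tables(tables):
    """Logically stack embedding tables with equal feature width.

    Returns the stacked table and the row offset of each original table,
    so ID `j` of table `t` maps to row `offsets[t] + j` of the stack.
    """
    assert len({t.shape[1] for t in tables}) == 1, "feature widths must match"
    offsets = np.cumsum([0] + [t.shape[0] for t in tables[:-1]])
    return np.concatenate(tables, axis=0), offsets

# Two toy tables with feature width 4; lookups against the stacked table
# use a single, larger batch of (offset-adjusted) IDs.
t0 = np.random.rand(10, 4).astype(np.float32)  # vocabulary size 10
t1 = np.random.rand(6, 4).astype(np.float32)   # vocabulary size 6
stacked, offsets = stack_tables([t0, t1])
ids_t1 = np.array([2, 5])                      # IDs into table 1
rows = stacked[offsets[1] + ids_t1]            # same rows as t1[[2, 5]]
```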
## COO sparse tensor format

Before data can be processed by SparseCore, it is commonly converted into the
Coordinate (COO) sparse tensor format. The COO format represents sparse
matrices efficiently, typically using three arrays:

*   `row_ids`: An array containing the row indices for each non-zero element.
    In the context of batch processing, this often corresponds to the batch
    dimension.
*   `col_ids`: An array containing the column indices for each non-zero
    element. For embeddings, these are often the feature or ID values.
*   `values` (optional): An array holding the actual values of the non-zero
    elements at the corresponding (row, col) coordinates. For limit
    calculations (discussed later) related to ID counts, these values (gains)
    are often not considered.

Consider an input sparse matrix representing batches of IDs:
```
[
  [id_A],              // Sample 0
  [id_A, id_B, id_C],  // Sample 1
  [id_B, id_B, id_D],  // Sample 2 (note duplicate id_B)
]
```
After conversion to COO format (and, as in this example, after deduplicating
IDs within the same sample):

```
row_ids = [0, 1, 1, 1, 2, 2]
col_ids = [id_A, id_A, id_B, id_C, id_B, id_D]
```
This conversion is fundamental to how SparseCore processes and distributes
work. The `col_ids`, in particular, are crucial for determining which
SparseCore partition an ID belongs to, enabling efficient sharding and lookup.
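The following is a minimal sketch of this conversion in plain Python/NumPy
(illustrative only, not a framework API), including per-sample deduplication:

```python
import numpy as np

def to_coo(samples):
    """Convert a ragged batch of ID lists to deduplicated COO arrays."""
    row_ids, col_ids = [], []
    for row, ids in enumerate(samples):
        for col in sorted(set(ids)):   # drop duplicate IDs within a sample
            row_ids.append(row)
            col_ids.append(col)
    return np.array(row_ids), np.array(col_ids)

# Mirrors the example above, with id_A=0, id_B=1, id_C=2, id_D=3.
row_ids, col_ids = to_coo([[0], [0, 1, 2], [1, 1, 3]])
print(row_ids)  # [0 1 1 1 2 2]
print(col_ids)  # [0 0 1 2 1 3]
```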
## SparseCore configuration

The `SparsecoreConfig`, or equivalent mechanisms such as XLA flags, serves as
a high-level interface for controlling a wide array of SparseCore behaviors.
A thorough understanding of these parameters is vital for effective
performance tuning and ensuring the correct operation of your models.
*   `disable_table_stacking: bool = False`

    When `True`, disables automatic table stacking. Default: `False`
    (implying table stacking is generally enabled by default where the
    framework supports it).

*   `max_ids_per_chip_per_sample: int = 64`

    Caps the total number of IDs a single chip processes for each input
    sample. Default: `64`.

*   `max_ids_per_table: Optional[Dict[str, int]] = None`

    A per-table cap on total IDs, related to the per-partition limit
    `max_ids_per_partition`. If a table `T` is divided into `P` partitions,
    this limit applies to the sum of IDs directed to all `P` partitions. It
    is often related to `max_ids_per_partition_per_sample` and the overall
    batch size. It interacts with the limits file (specified via the
    `xla_sparse_core_max_ids_file` flag), where `max_ids_per_partition` is
    defined; this table-level concept is a method to set those
    partition-level limits (max_ids and max_uniques). Default: `None` (the
    value may be inferred from per-partition limits or other configurations
    if not explicitly provided).

*   `max_unique_ids_per_table: Optional[Dict[str, int]] = None`

    Analogous to `max_ids_per_table`, but this parameter specifies the
    maximum number of unique IDs for each logical table. This is a critical
    setting for appropriately sizing on-device buffers used in unique ID
    processing and subsequent vector operations, and it is related to
    `max_unique_ids_per_partition_per_sample`. Default: `None`.

*   `allow_id_dropping: bool = False`

    Controls what happens when observed ID counts exceed the compiled limits
    (for example, `max_ids_per_partition`).

    *   `True`: IDs that would cause the limits to be exceeded are silently
        dropped. Typically, IDs within a partition are processed in sorted
        order, and any ID that would push the running count over the limit
        for its designated mini-batch is discarded. This allows the program
        to continue execution but may have an adverse impact on model
        accuracy.
    *   `False`: An error is triggered, and the process will likely
        terminate, if observed limits exceed compiled limits. This approach
        ensures all data is processed but requires limits to be configured
        more conservatively.

    Default: `False` (causing an error on overflow rather than silent data
    dropping).

*   `initialize_tables_on_host: bool = True`

    Embedding tables are conventionally initialized on the host and then
    transferred to the device; `True` follows this convention. If it were
    set to `False`, it would imply an on-device initialization mechanism,
    which could have different performance implications or specific
    initialization prerequisites.

*   `enable_fast_table_initialization: bool = False`
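As an illustration only, a configuration emphasizing explicit per-table
limits might look like the sketch below. The exact constructor, import path,
and table names are assumptions that depend on your framework:

```python
# Hypothetical construction of a SparsecoreConfig; shown only to illustrate
# the parameters documented above, not an exact API.
config = SparsecoreConfig(
    disable_table_stacking=False,      # keep automatic table stacking on
    max_ids_per_chip_per_sample=64,
    max_ids_per_table={"user_table": 4096, "item_table": 8192},
    max_unique_ids_per_table={"user_table": 1024, "item_table": 2048},
    allow_id_dropping=False,           # fail loudly instead of dropping IDs
    initialize_tables_on_host=True,
)
```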
## Pipelining

Pipelining is a performance optimization technique that enables the
simultaneous execution of operations on the TensorCore (TC) and the
SparseCore (SC). By overlapping these computations, overall throughput can be
significantly improved.
*   The SC can execute its portion of a training step `i` (for example, the
    forward or backward pass) while the TC is concurrently processing a
    different part of the same step `i`, or even parts of adjacent steps like
    `i-1` or `i+1`.
*   A consequence of this overlap is staleness: embedding updates from step
    `i` might not be fully applied and visible to the SC until step `i+2`.
*   Pipelining can be controlled by flags such as
    `tf_xla_disable_full_embedding_pipelining`. Setting this flag to `true`
    disables full pipelining (overlapping TensorCore and SparseCore
    computation), whereas setting it to `false` activates it.

Without pipelining (simplified sequential flow):
```
Loop: SC/F_i -> TC/F_i -> TC/B_i -> SC/B_i
```
With pipelining (simplified overlapped flow):
```
Time ->
Step i:   SC/F_i | TC/F_i   | TC/B_i   | SC/B_i
Step i+1:          SC/F_i+1 | TC/F_i+1 | TC/B_i+1 | SC/B_i+1
```
Note: The actual pipelining stages implemented in the hardware and compiler can be more intricate, often involving pre-loops, main execution loops, and post-loops to manage data dependencies and ensure correctness.
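To make the staleness point concrete, here is a toy, plain-Python model of
the schedule. It is an assumption-laden sketch (not how the compiler
implements pipelining) that only tracks when each step's embedding update
becomes visible to a later SC forward pass:

```python
# Toy model of update staleness under pipelining: the gradient from step i
# is applied to the embedding table only before the SC forward pass of step
# i+2, so each lookup sees parameters that lag roughly two steps behind.
pending = []               # steps whose updates are still in flight
applied_through = -1       # latest step whose update is in the table

for step in range(5):
    # Updates from steps <= step - 2 retire before this forward pass.
    while pending and pending[0] <= step - 2:
        applied_through = pending.pop(0)
    print(f"SC/F_{step}: table includes updates through step {applied_through}")
    pending.append(step)   # this step's gradient lands at step + 2

# Output: SC/F_0 and SC/F_1 see no updates; SC/F_k sees through step k-2.
```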
## XLA compilation

XLA (Accelerated Linear Algebra) is the domain-specific compiler that
translates high-level computational graphs, typically from frameworks like
TensorFlow, into highly optimized machine code tailored for TPUs. This
includes generating the instructions for operations destined for the
SparseCore.
*   XLA lowers embedding operations (such as `SparseDenseMatmulOp`) and other
    sparse computations into low-level, executable SparseCore programs.
*   The compiler uses the configured limits (`max_ids_per_partition` and
    `max_unique_ids_per_partition`, often provided via a limits file
    specified by flags like `xla_sparse_core_max_ids_file`) to statically
    determine the sizes of and allocate on-device memory buffers,
    particularly within the SPMEM.
*   Additional XLA flags influence SparseCore compilation and behavior (for
    example, `xla_sparse_core_estimate_max_ids` for limit estimation, or
    `xla_sc_detect_nan` for debugging).

Currently, the SparseCore implementation is internal and served using
`libtpu.so`.
Compilation failures related to SparseCore configurations or resource
constraints often manifest as XLA:TPU compile-time errors. These error
messages can provide valuable insights into issues such as limits being set too
high for the available SPMEM, or the use of unsupported configurations.
## Limits

On SparseCore, "limits" are fundamental configuration parameters that
primarily refer to two per-partition settings for each table that is sharded
(distributed) across the available SparseCores:
*   `max_ids_per_partition`: This defines the maximum number of total IDs
    (including duplicates) that any single SparseCore is expected to send to,
    or process for, a specific partition of a given table within a single
    computational step.
*   `max_unique_ids_per_partition`: This defines the maximum number of unique
    IDs that any single SparseCore is expected to send to, or process for, a
    specific partition of a given table within a single computational step.

**Table sharding strategy:** Embedding tables are typically "mod-sharded"
across all SparseCores in the system. This means each SparseCore becomes
responsible for a distinct subset of the vocabulary (rows) of each table. An
ID `j` would generally be assigned to `SparseCore_k` based on a formula like
`k = j % num_total_sparse_cores`.
**Definition of a "partition":** In this context, a "partition" refers to the
specific segment of an embedding table for which a single SparseCore handles
lookups.
**SPMEM buffer allocation:** These limits are used by the XLA compiler to
statically size and allocate buffers within the on-device scratchpad memory
(SPMEM). Buffers are dimensioned such that all necessary data related to the
IDs for a given partition (up to the specified `max_ids` and
`max_unique_ids` limits) can be loaded into SPMEM for processing. This is
particularly crucial for non-elementwise computations, such as reducing
duplicate IDs within a partition (for example, when creating a Compressed
Sparse Row (CSR) representation, as sketched below), where the entire
relevant dataset for that partition's IDs needs to be readily available in
fast memory.
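A minimal NumPy sketch (illustrative only; the helper names are not an XLA
API) of mod-sharding a batch of IDs into partitions and then reducing
duplicates within one partition into CSR-style unique IDs and summed gains:

```python
import numpy as np

NUM_SC = 4  # total SparseCores in this toy system

def shard_ids(col_ids):
    """Group IDs by the SparseCore that owns them (mod-sharding)."""
    return {k: col_ids[col_ids % NUM_SC == k] for k in range(NUM_SC)}

def reduce_partition(ids, gains):
    """Collapse duplicate IDs in one partition, summing their gains
    (the kind of reduction done when building a CSR representation)."""
    order = np.argsort(ids)
    ids, gains = ids[order], gains[order]
    unique_ids, index = np.unique(ids, return_inverse=True)
    summed = np.zeros(len(unique_ids))
    np.add.at(summed, index, gains)
    return unique_ids, summed

col_ids = np.array([8, 4, 5, 8, 13, 8])
parts = shard_ids(col_ids)             # partition 0 owns [8, 4, 8, 8]
uids, gains = reduce_partition(parts[0], np.ones(len(parts[0])))
print(uids, gains)                     # [4 8] [1. 3.]
```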
**Compiled limits versus observed limits:** The compiled limits are the
values baked into the program at compilation time, whereas the observed
limits are the actual per-partition ID counts seen at runtime. If observed
limits exceed compiled limits, the result is either ID dropping (if
`allow_id_dropping` is enabled) or errors.

**Calculating limits:** The process of determining appropriate limits
involves a careful analysis of the input data distribution. For any given
table (let's call it `T1`, which might itself be part of a larger stacked
table `T`):
1.  The input batch (for example, a `SparseTensor` of shape
    `[BatchSize, MaxSequenceLength]`) is initially split across the available
    SparseCores. For instance, if a TensorCore is paired with 2 SparseCores,
    each SparseCore might receive a sub-batch of shape
    `[BatchSize/2, MaxSequenceLength]`.
2.  Each sub-batch is converted to COO format, yielding `row_ids` and
    `col_ids`.
3.  Duplicate IDs within a sample (identical `row_id` and `col_id`) are
    removed.
4.  For each `col_id` (within a sample), the target SparseCore responsible
    for this ID is determined using the mod-sharding rule:
    `target_sc_id = col_id % num_total_sparse_cores`.
5.  Counters track the total number of IDs
    (`ids_per_sparse_core[target_sc_id]++`) and the number of unique IDs
    (`unique_ids_per_sparse_core[target_sc_id]++`, after ensuring uniqueness
    for that specific `target_sc_id`) destined for each `target_sc_id`.
6.  `max_ids_per_partition` for table `T1` is then set to
    `max(ids_per_sparse_core_array)`.
7.  `max_unique_ids_per_partition` for table `T1` is set to
    `max(unique_ids_per_sparse_core_array)`.
8.  If `T1` is a component of a stacked table, additional transformations
    like rotations or shifts might be applied to the ID distributions before
    summing statistics from all constituent tables. This helps in balancing
    the load across chips.

A sketch of steps 3 through 7 appears below. Setting these limits correctly
is a balancing act: lower limits can potentially lead to higher performance
(as less data needs to be processed per step and SPMEM pressure is reduced),
but if set too low, they can result in excessive mini-batching or undesirable
ID dropping.
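The following NumPy sketch illustrates steps 3 through 7 (it is not the
framework's implementation; `samples` stands for one SparseCore's sub-batch
of per-sample ID lists):

```python
import numpy as np

def compute_limits(samples, num_total_sparse_cores):
    """Derive max_ids/max_unique_ids per partition for one table."""
    ids_per_sc = np.zeros(num_total_sparse_cores, dtype=int)
    unique_per_sc = np.zeros(num_total_sparse_cores, dtype=int)
    seen = [set() for _ in range(num_total_sparse_cores)]
    for ids in samples:
        for col_id in set(ids):                    # dedup within the sample
            sc = col_id % num_total_sparse_cores   # mod-sharding rule
            ids_per_sc[sc] += 1
            if col_id not in seen[sc]:             # unique per target SC
                seen[sc].add(col_id)
                unique_per_sc[sc] += 1
    return ids_per_sc.max(), unique_per_sc.max()

max_ids, max_unique = compute_limits(
    samples=[[0], [0, 1, 2], [1, 1, 3]], num_total_sparse_cores=2)
print(max_ids, max_unique)  # 3 2: each core receives 3 IDs, at most 2 unique
```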
## Communication

SparseCore communication, particularly in the context of processing a list of
IDs for embedding lookups, relies on several coordinated mechanisms:

*   Mod-sharding: for each incoming ID (`col_ids`), the `col_id` value is
    used to determine which SparseCore is responsible for that specific ID:
    `target_sc_id = col_id % num_total_sparse_cores`.
*   Host preprocessing transfers the `row_ids` and `col_ids` (along with any
    associated features or weights, if applicable) either to memory (HBM)
    directly accessible by each SparseCore or to a shared HBM from which the
    SparseCores fetch their required data.

In essence, the initial "distribution" of IDs to the appropriate SparseCores
is largely handled by the sharding scheme and the host preprocessing steps.
Subsequent communication involves SparseCores operating on their local data,
potentially followed by collective communication operations like all-to-all
if data needs to be globally exchanged or reordered across SparseCores before
further processing by the TensorCores.
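A conceptual, single-process sketch of this flow (plain Python/NumPy; the
dictionaries stand in for per-core HBM, and no real collective API is used):

```python
import numpy as np

NUM_SC = 2
table = np.arange(8 * 3, dtype=np.float32).reshape(8, 3)  # toy table, width 3
# Mod-sharded storage: core k holds rows j with j % NUM_SC == k.
local_rows = {k: {j: table[j] for j in range(8) if j % NUM_SC == k}
              for k in range(NUM_SC)}

# Each core's incoming requests, already routed by target_sc_id = id % NUM_SC.
requests = {0: [4, 6], 1: [1, 7]}      # core 0 owns even IDs, core 1 odd

# Step 1: every core answers the lookups it owns from its local shard.
replies = {k: [local_rows[k][j] for j in ids] for k, ids in requests.items()}

# Step 2: an all-to-all style exchange would now return each reply to the
# core (and ultimately the TensorCore) that issued the original request.
print(replies[0][0])                   # row 4 of the table: [12. 13. 14.]
```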
## Memory

Each SparseCore efficiently manages several distinct types of memory to
perform its computations:

*   Scratchpad memory (SPMEM): small, fast, on-tile memory whose buffers are
    statically sized by the compiler from the configured limits
    `max_ids_per_partition` and `max_unique_ids_per_partition`. This static
    allocation ensures that for any given operation on a table partition
    (such as CSR reduction), all the necessary data for that partition's IDs
    (up to the defined limits) can fit into SPMEM.
*   HBM stack memory: per-table buffers whose sizes are driven by expressions
    such as `(feature_width + 1) * max_unique_nz_per_row *
    logical_replica_count * 4 bytes` and `feature_width *
    max_unique_nz_per_row * logical_replica_count * 4 bytes`.
*   Input prefetch buffers: HBM used to prefetch upcoming input batches
    (tunable via the `maximum_parallel_iterations` flag). While more
    prefetching can improve performance by overlapping host-to-device
    transfers with device computation, it also consumes more HBM.
*   The flag `xla_sc_num_serialized_tables_to_optimize_hbm` provides a
    mechanism to control how many tables' data is kept "live" in HBM stack
    memory at any given time. Increasing this number effectively serializes
    the processing of more tables, which can reduce peak HBM stack usage but
    may come at the cost of performance due to reduced parallelism.

The core memory management strategy for SparseCore revolves around using the
small, fast SPMEM for the "hot" data that is actively being processed by a
SparseCore tile, thereby minimizing accesses to the slower HBM. The
configured limits are the primary mechanism for ensuring that SPMEM does not
overflow. HBM is utilized for storing large embedding tables and temporary
data that either exceeds SPMEM capacity or needs to be shared across
different processing units or pipeline stages. The XLA compiler is
responsible for orchestrating all data movement and buffer allocation based
on these architectural principles and the user-configured limits.
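Plugging illustrative numbers into the first sizing expression above (all
values hypothetical):

```python
feature_width = 128          # embedding dimension
max_unique_nz_per_row = 64   # illustrative value
logical_replica_count = 4    # illustrative value

# (feature_width + 1) * max_unique_nz_per_row * logical_replica_count * 4 bytes
stack_bytes = (feature_width + 1) * max_unique_nz_per_row * logical_replica_count * 4
print(f"{stack_bytes / 2**20:.2f} MiB per table")  # 0.13 MiB
```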
## Performance optimization

Achieving optimal performance with SparseCore necessitates a clear
understanding of potential bottlenecks and how to address them. These can
arise on the host, within the SparseCore itself, or in its interaction with
the TensorCores.
*   Host input bottlenecks: use the `maximum_parallel_iterations` flag to
    fine-tune data prefetching.
*   Missing TC/SC overlap: ensure pipelining is enabled
    (`tf_xla_disable_full_embedding_pipelining = false` or its equivalent).
*   Overly low limits: if `allow_id_dropping` is true, overly low limits can
    lead to ID dropping, which impacts model accuracy.
*   Overly high limits: if `max_ids_per_partition` and
    `max_unique_ids_per_partition` are set too high, the XLA compiler may be
    unable to allocate sufficient SPMEM, resulting in compilation errors
    like: "Fixed size allocations (...) do not fit in TileSpmem (...)".
    Additionally, if the term `(sample_count * feature_width) / kNumTiles`
    (where `kNumTiles` is the number of tiles per SC) is too large for
    staging gather operands within the tile SPMEM, errors such as "Gather
    operand too large..." can occur.
*   HBM pressure: if the combination of `feature_width`,
    `max_unique_nz_per_row`, and `logical_replica_count` leads to HBM stack
    memory requirements that exceed the available HBM, this can cause
    Out-Of-Memory (OOM) errors either at runtime or during compilation.
    Possible mitigations include:
    *   Using the `xla_sc_num_serialized_tables_to_optimize_hbm` flag to
        reduce HBM stack usage by serializing the processing of tables (this
        usually comes at a performance cost).
    *   Reducing input prefetching (`maximum_parallel_iterations`).
    *   Reducing the batch size, which lowers `sample_count` per SparseCore
        and in turn influences memory consumption and compute load.

## Profiling

Analyzing a performance profile is a key step in identifying bottlenecks and
uncovering opportunities for optimization within your SparseCore workloads.
*   Trace viewer rows for SparseCores (for example,
    `TPU:0 SparseCore 1 (pid 1005)`) are invaluable for visually identifying
    dominant operations and idle periods.
*   Running with `--vmodule=best_fit_allocator=1` might provide logs of peak
    heap usage.
*   Interpret profiles in light of your configuration: if
    `xla_sc_num_serialized_tables_to_optimize_hbm` is set to a high value,
    you might expect slower SC performance but lower HBM stack consumption.

## Debugging flags

Several flags can be enabled to assist in debugging issues related to
SparseCore execution. Note that enabling these checks often incurs a
performance penalty, so they should typically be disabled for production
runs.
*   `xla_sparse_core_enable_id_bound_check = true`
*   `xla_sc_detect_nan = true`
*   `xla_sc_assert_level=bounds`
*   `xla_tpu_buffer_contents_sanitizer_config='cores_to_sanitize: [TC, SC_SCS, SC_TILE], sanitizer_mode: LOCAL_ONLY'`
*   `xla_tpu_verify_launch_id_across_cores=true`

## Quantization

SparseCore's `SparseDenseMatmulOp` is designed to support operations on
embedding tables using both 32-bit floating-point (FP32) and integer data
types. While model training is typically performed using FP32 precision for
embedding tables, post-training quantization (PTQ) can be applied. PTQ
allows the use of lower-precision data types (such as 8-bit integers) for
inference, which can potentially lead to improved performance and a reduced
memory footprint.
The SparseDenseMatmulOp can be configured to perform "simulated quantization." In this operational mode, embedding vectors are first quantized to a lower precision and then dequantized back to a higher precision (for example, FP32) before they are used in subsequent computations. This technique allows models to be trained while accounting for the effects of quantization noise. Training with simulated quantization can improve the accuracy of the final model when it is fully quantized for inference.
Example configuration parameters for `SparseDenseMatmulOp` (for
quantization):

```
quantization_config_num_buckets = 256
quantization_config_low = -X.X
quantization_config_high = Y.Y
```
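A minimal NumPy sketch of the quantize-dequantize round trip described above,
assuming uniform bucketing over `[low, high]`. The parameter names mirror the
config above; the concrete `low`/`high` values are illustrative assumptions:

```python
import numpy as np

def simulated_quantize(x, num_buckets=256, low=-1.0, high=1.0):
    """Quantize to one of `num_buckets` uniform levels in [low, high],
    then dequantize back to FP32 (quantization-aware "fake quant")."""
    scale = (high - low) / (num_buckets - 1)
    bucket = np.round((np.clip(x, low, high) - low) / scale)
    return (bucket * scale + low).astype(np.float32)

emb = np.array([-0.37, 0.0, 0.7201, 1.5], dtype=np.float32)
print(simulated_quantize(emb))  # values snapped to the nearest bucket;
                                # 1.5 is first clipped to high = 1.0
```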
The numerical behavior of the model can change depending on whether pipelining between the TensorCore and SparseCore is enabled. If pipelining is active, gradients processed by the SparseCore might be "stale" (from a previous iteration). This can interact with the quantization process and potentially affect model training dynamics or final accuracy.
## Future development

The SparseCore ecosystem is under continuous development. Ongoing
enhancements focus on improving SparseCore's performance, memory efficiency,
and operational flexibility for an even wider range of sparse workloads.