Strip the test case down to the smallest input that triggers the hang:
- … compute-sanitizer so you can distinguish a hang from slow execution

GPU printf (`cute.printf`) is the primary tool. The goal is a binary search: narrow down which warp and which operation is blocked.
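Printf guards, to avoid print storms:

```python
import cutlass.cute as cute  # the snippets below go inside a @cute.kernel function

# One thread per warp (lane 0):
if cute.arch.thread_idx()[0] % 32 == 0:
    cute.printf("...")

# One elected thread per warp that reaches this point (elect_one is a context
# manager, not a bool). Inside a single-warp role this gives one print per CTA:
with cute.arch.elect_one():
    cute.printf("...")

# One specific thread:
tidx, _, _ = cute.arch.thread_idx()
if tidx == 0:
    cute.printf("...")
```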
Strategy, coarse to fine:

- Print before and after each pipeline barrier wait (`consumer_wait`, `producer_acquire`). This tells you which barrier is stuck.

What to print:

- The CTA index (`cute.arch.block_idx()[0]`): critical for multi-CTA debugging
- Whether `try_wait` succeeds or fails (use the `try_wait_token` parameter)

A hang is usually a cycle of waits. A typical chain in a pipelined kernel:
```
MMA waiting for K from load (pipeline_kv full barrier)
  -> Load finished but stuck in producer_tail (waiting for MMA to release empty barrier)
  -> MMA can't release because it's waiting for K
```
Once you see which barrier is stuck, trace backwards: who is supposed to signal it, and why haven't they?
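Tracing backwards amounts to following "X waits on Y" edges until a role repeats. A minimal sketch, assuming you have already reduced the printf output to a map from each warp role to the role it is blocked on; the role names and the `find_wait_cycle` helper are hypothetical, not part of CuTe DSL:

```python
from typing import Dict, List, Optional


def find_wait_cycle(waits_on: Dict[str, str]) -> Optional[List[str]]:
    """Follow 'role -> role it waits on' edges and return the first cycle found."""
    for start in waits_on:
        position: Dict[str, int] = {}  # role -> index in the current path
        path: List[str] = []
        node = start
        # Walk until we leave the graph (no recorded wait) or revisit a role.
        while node in waits_on and node not in position:
            position[node] = len(path)
            path.append(node)
            node = waits_on[node]
        if node in position:
            return path[position[node]:]
    return None


# Hypothetical observations matching the chain above:
observed = {
    "mma": "load",  # MMA blocked on the pipeline_kv full barrier (needs K)
    "load": "mma",  # load blocked in producer_tail (needs MMA to release a stage)
}
print(find_wait_cycle(observed))  # ['mma', 'load']
```

If the function returns `None`, the chain ends at a role that is not blocked on another role, so look for a signal that was never sent or a barrier whose expected count is wrong.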
Test with different sequence lengths / block counts to find the pattern:
| seqlen | n_blocks | Result |
|---|---|---|
| 128 | 1 | ? |
| 256 | 2 | ? |
| 384 | 3 | ? |
| 512 | 4 | ? |
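Because a hung kernel never returns, it helps to run each row of the table in a child process with a timeout. A minimal sketch, assuming a hypothetical `repro_2cta.py` script that runs one forward pass for a given `--seqlen`, and a tile of 128 rows (so `n_blocks = seqlen // 128`); a timeout is reported as a hang, a nonzero exit as a failure:

```python
import subprocess
import sys
from typing import Callable, Dict, List


def classify(cmd: List[str], timeout_s: float) -> str:
    """Run one configuration in a child process and report ok / fail / hang."""
    try:
        # subprocess.run kills the child if the timeout expires.
        proc = subprocess.run(cmd, capture_output=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return "hang"
    return "ok" if proc.returncode == 0 else "fail"


def sweep(make_cmd: Callable[[int], List[str]], seqlens: List[int],
          timeout_s: float) -> Dict[int, str]:
    """Classify every sequence length; a hang in one config cannot stall the rest."""
    return {s: classify(make_cmd(s), timeout_s) for s in seqlens}


if __name__ == "__main__":
    # Hypothetical repro script; replace with your own minimal test case.
    def make_cmd(seqlen: int) -> List[str]:
        return [sys.executable, "repro_2cta.py", "--seqlen", str(seqlen)]

    for seqlen, result in sweep(make_cmd, [128, 256, 384, 512], timeout_s=60).items():
        print(f"| {seqlen} | {seqlen // 128} | {result} |")
```

Pick the timeout well above the normal runtime (longer still under compute-sanitizer), otherwise slow runs are misreported as hangs.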
If the hang correlates with the number of visits to a pipeline stage (e.g., works for n_blocks <= kv_stages but fails when stages wrap around), the problem is likely in barrier tx_count or phase tracking.
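To see why the `n_blocks <= kv_stages` boundary matters, here is a small model of how a pipeline state walks its stages. It assumes the usual convention (the stage index advances once per tile and the phase bit flips whenever the index wraps to 0) and a starting phase of 0; `stage_visits` is an illustrative helper, not a library function:

```python
from typing import List, Tuple


def stage_visits(n_blocks: int, stages: int) -> List[Tuple[int, int]]:
    """(stage index, phase bit) used for each of n_blocks tiles."""
    index, phase = 0, 0  # assumed starting state
    visits = []
    for _ in range(n_blocks):
        visits.append((index, phase))
        index += 1
        if index == stages:  # wrap around: same slots, next phase
            index, phase = 0, phase ^ 1
    return visits


for n_blocks in (1, 2, 3, 4):
    print(n_blocks, stage_visits(n_blocks, stages=2))
```

With `stages=2`, runs with one or two blocks only ever use phase 0; three blocks is the first run that reuses stage 0 with the phase flipped, so bugs that only bite when a stage is reused (a wrong phase bit, or surplus bytes from a miscounted tx_count landing in the next phase) cannot appear before that point.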
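For TMA-based pipelines, `arrive_and_expect_tx` sets the expected transaction byte count on an mbarrier. If the expected count doesn't match the actual bytes arriving, the barrier either:

- never completes, when it expects more bytes than arrive (the kernel hangs), or
- completes early, when it expects fewer bytes than arrive (the consumer can read a tile that has not fully landed, and the surplus bytes throw off the next phase's count).
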
In 2CTA / cluster mode, both CTAs' TMAs signal the same cluster-level mbarrier. If each CTA's TMA contributes N bytes, the barrier receives 2N bytes total. The tx_count must be N * cta_group_size, not just N.
All TMA pipelines need doubling — Q, K, and V. Even though each CTA loads a different M-tile for Q, both CTAs' TMA operations still signal the same cluster-level barrier, so the expected byte count must account for both.
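The arithmetic is simple but easy to get wrong for one of the pipelines. A sketch with illustrative tile shapes (128 x 128 fp16 tiles are an assumption, not taken from the kernel); `tma_tx_bytes` is a hypothetical helper:

```python
def tma_tx_bytes(rows: int, cols: int, elem_bytes: int, cta_group_size: int) -> int:
    """Expected tx byte count for one stage of a cluster-shared TMA barrier.

    Each CTA's TMA copies rows * cols * elem_bytes; in 2CTA mode both copies
    complete on the same barrier, so the total scales with cta_group_size.
    """
    per_cta = rows * cols * elem_bytes
    return per_cta * cta_group_size


FP16 = 2  # bytes per element
for name in ("Q", "K", "V"):  # all three pipelines need the same doubling
    print(name,
          tma_tx_bytes(128, 128, FP16, cta_group_size=1),   # 32768
          tma_tx_bytes(128, 128, FP16, cta_group_size=2))   # 65536
```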
`mbarrier_try_wait_parity` uses a single parity bit (0 or 1). If your pipeline state tracks phase as a monotonically increasing counter (0, 1, 2, 3, ...), you need `phase % 2` before passing it to the barrier wait. The underlying PTX `mbarrier.try_wait.parity` instruction only defines the parity operand for the values 0 and 1, so passing the raw counter (2, 3, ...) is outside its contract and can cause waits on already-completed phases or misses on pending ones.
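A toy model of the parity wait makes the requirement concrete. It assumes the barrier exposes only the parity of its in-flight phase and that the waiter is at most one phase behind; `try_wait_parity` here is a pure-Python stand-in, not the CuTe DSL function:

```python
def try_wait_parity(completed_phases: int, parity: int) -> bool:
    """Toy stand-in for mbarrier try_wait.parity.

    The barrier's in-flight phase is number completed_phases, so its parity is
    completed_phases % 2. Waiting on parity p succeeds once the barrier has
    moved past the phase with that parity, i.e. when the parities differ.
    """
    if parity not in (0, 1):
        raise ValueError(f"parity must be 0 or 1, got {parity}")
    return completed_phases % 2 != parity


phase = 2  # monotonically increasing counter: third fill of this stage
print(try_wait_parity(completed_phases=2, parity=phase % 2))  # False: phase 2 still pending
print(try_wait_parity(completed_phases=3, parity=phase % 2))  # True: phase 2 has completed
```

Passing the raw counter raises in this model; on hardware the value is simply outside the documented 0/1 range.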
If the kernel works WITH printf but hangs WITHOUT it, the printf is acting as a compiler barrier. The MLIR/LLVM backend cannot optimize through an opaque function call like printf, which prevents harmful instruction reordering.
Signs this is happening:
- Adding `cute.printf("\n")` in the right function fixes the hang
- Fences and syncs (`fence_view_async_shared`, `fence_acq_rel_cluster`, `sync_warp`, `fence_proxy`) do NOT fix it: these affect hardware memory ordering, not compiler scheduling

Possible workarounds:

- Apply the `@dsl_user_op` decorator to pipeline methods to make them opaque to the compiler
- `asm volatile` barriers (if available in the DSL)

`tcgen05.commit(mbar, mask, cta_group::2)` is supposed to signal an mbarrier after all pending MMA operations complete. But if there are no pending operations (an empty commit group), the signal only reaches the local CTA's barrier, not the remote CTA's. Fix: arrive explicitly on both CTAs' barriers with `mbarrier_arrive(barrier, dst_cta_rank)`.
The default producer_tail (inherited from sm90 pipelines) drains the pipeline by calling producer_acquire in a loop. In 2CTA mode this deadlocks because the consumer (MMA warp) may have already exited without releasing all stages. Fix: make producer_tail a no-op for 2CTA.
Both CTAs in a cluster must get the same tile coordinate. Raw blockIdx.x assigns consecutive values to CTAs in the same cluster. Fix: divide blockIdx.x by cluster_shape_m.
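A sketch of the mapping, assuming a cluster of `cluster_shape_m` CTAs laid out along x so that one cluster's CTAs have consecutive `blockIdx.x` values; `cluster_tile_coord` is a hypothetical name:

```python
from typing import Tuple


def cluster_tile_coord(block_idx_x: int, cluster_shape_m: int) -> Tuple[int, int]:
    """Return (tile coordinate shared by the cluster, CTA rank within the cluster)."""
    return block_idx_x // cluster_shape_m, block_idx_x % cluster_shape_m


for bid in range(4):
    # Raw blockIdx.x would give CTAs 0 and 1 different tiles; the quotient does not.
    print(bid, cluster_tile_coord(bid, cluster_shape_m=2))
```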
Pipelines where CTA 1's threads remotely arrive on CTA 0's barriers need cluster-sized cooperative group counts. Pipelines that are purely local to each CTA keep per-CTA counts.
Causal mask row positions must account for the CTA's position within the cluster. Multiply m_block by cta_group_size when computing mask coordinates.
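A sketch of the row-offset computation, assuming `m_block` is the cluster-level tile index (after the division by `cluster_shape_m`), each CTA owns a contiguous slice of `tile_m` query rows, and the causal mask is the square form that hides keys after the query row; the helper names are hypothetical:

```python
def cta_first_row(m_block: int, cta_rank: int, cta_group_size: int, tile_m: int) -> int:
    """First global query row owned by this CTA."""
    return (m_block * cta_group_size + cta_rank) * tile_m


def causal_masked(row: int, col: int) -> bool:
    """True if key column col must be hidden from query row row."""
    return col > row


tile_m = 128
for rank in (0, 1):
    print(rank, cta_first_row(m_block=1, cta_rank=rank, cta_group_size=2, tile_m=tile_m))
# A naive m_block * tile_m would give row 128 here, which belongs to cluster tile 0.
```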