# SM90 Block Size Tuning

How to choose tile sizes and MMA configurations for FlashAttention on Hopper (SM90).
Use `flash_attn/cute/sm90_config_search.py` to enumerate feasible configs:

```sh
# Both fwd and bwd
python flash_attn/cute/sm90_config_search.py --headdim 128
# Forward only
python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128
# Backward only, custom tile choices
python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-m 64,80 --tile-n 64,96
```
Register limits are set with `setmaxnreg`; the budget per MMA warp group must be a multiple of 8. Each SM90 kernel has num_wg + 1 warp groups (128 threads each): num_wg MMA (consumer) WGs plus 1 producer WG. For forward, tile_m = num_wg * 64 (no swap).
| num_wg | tile_m (fwd) | Threads | Reg budget | Best for |
|---|---|---|---|---|
| 2 | 128 | 384 | 216/thread | hdim <= 128 |
| 3 | 192 | 512 | 128/thread | hdim 129-192 |
More WGs = larger tile_m = better M-direction parallelism, but tighter register budget and higher smem usage.
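The budgets in the table can be sanity-checked against the SM register file. The 64K 32-bit registers per SM is a known SM90 figure; the producer allocation of 80 regs/thread and the helper `fits` are my assumptions for illustration, not values stated in this document:

```python
# Check that an MMA-WG register budget fits the SM90 register file.
# Assumptions: 64K 32-bit registers per SM, one producer warp group
# capped at 80 regs/thread (assumed, not stated above), 128 threads/WG.
REGFILE = 64 * 1024
PRODUCER_REGS = 80
THREADS_PER_WG = 128

def fits(num_wg: int, mma_regs: int) -> bool:
    """True if num_wg MMA warp groups at mma_regs regs/thread plus one
    producer WG fit in the register file (setmaxnreg needs multiples of 8)."""
    total = THREADS_PER_WG * (PRODUCER_REGS + num_wg * mma_regs)
    return total <= REGFILE and mma_regs % 8 == 0

print(fits(2, 216))  # the 2-WG row above
print(fits(3, 128))  # the 3-WG row above
```

Under these assumptions the 2-WG row uses the register file exactly: 128 * (80 + 2 * 216) = 65536.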
Each MMA can optionally swap its A and B operands. This transposes the output tile, exchanging which dimension maps to M (must be divisible by 64) and which maps to N.
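The transpose effect of a swap is plain matrix algebra: B^T A^T = (A B)^T, so a swapped MMA produces the transposed output tile. A NumPy sketch (the shapes are arbitrary illustrations):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32))   # stands in for the A operand (M x K)
B = rng.standard_normal((32, 96))   # stands in for the B operand (K x N)

C = A @ B                # normal MMA: M x N output
C_swapped = B.T @ A.T    # swapped operands: N x M output

# Swapping A and B transposes the output tile, so M and N exchange roles.
assert C_swapped.shape == (96, 64)
assert np.allclose(C_swapped, C.T)
```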
When to swap:
- Forward: no swap needed, since tile_m = num_wg * 64 is always divisible by 64.
- Backward (5 MMAs: S = Q K^T, dP = dO V^T, dV = P^T dO, dK = dS^T Q, dQ = dS K): swaps are controlled by SdP_swapAB (S and dP), dKV_swapAB (dV and dK), and dQ_swapAB (dQ).
The atom_layout distributes WGs across the M and N dimensions of an MMA output. With num_wg MMA WGs and atom_layout_m = A, the remaining atom_layout_n = num_wg / A WGs tile the N dimension, so each WG computes an (M / A) x (N * A / num_wg) sub-tile. After a swap, the atom layout is swapped along with the operands.
Impact on smem traffic: More WGs in the N direction (wg_n larger) means each instruction reads a smaller B slice, but more instructions total read overlapping A slices. Fewer WGs in N (wg_n smaller) means fewer instructions but each reads a larger B slice. Typically smaller wg_n = less total smem traffic.
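This tradeoff can be modeled in a few lines. A rough sketch following the instruction-count formula used later in this document; `smem_traffic_elems` is a hypothetical helper and counts elements, not bytes:

```python
# Rough model of per-tile smem read traffic for one GMMA as a function
# of wg_n (warp groups along N). Each instruction reads a 64 x K slice
# of A and an (N / wg_n) x K slice of B; A slices overlap across WGs.
def smem_traffic_elems(M: int, N: int, K: int, wg_n: int) -> int:
    n_instr = (M // 64) * wg_n           # total GMMA instructions
    a_reads = n_instr * 64 * K           # overlapping A slices
    b_reads = n_instr * (N // wg_n) * K  # disjoint B slices
    return a_reads + b_reads

# Example: 128 x 128 output tile, K = 128, 2 MMA warp groups.
print(smem_traffic_elems(128, 128, 128, wg_n=1))  # both WGs split M
print(smem_traffic_elems(128, 128, 128, wg_n=2))  # both WGs split N
```

In this model B traffic is constant while A traffic grows linearly with wg_n, which is why smaller wg_n typically means less total smem traffic.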
When AtomLayoutMSdP == 1 && AtomLayoutNdKV == num_wg && SdP_swapAB && !dKV_swapAB, the P and dS matrices can be kept in registers and fed directly as the A operand of the dV and dK GEMMs (the "RS" variant). This avoids the R2S stores of P and dS and the corresponding smem reads. It is a significant optimization and is always preferred when the conditions are met.
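As a sketch, the condition can be expressed as a predicate; `mma_dkv_is_rs` here is a hypothetical helper named after the flag quoted in the config lines below:

```python
# Can P/dS stay in registers and feed the dV/dK GEMMs directly
# (the "RS" variant)? Mirrors the condition stated above.
def mma_dkv_is_rs(atom_layout_m_sdp: int, atom_layout_n_dkv: int,
                  sdp_swap_ab: bool, dkv_swap_ab: bool, num_wg: int) -> bool:
    return (atom_layout_m_sdp == 1
            and atom_layout_n_dkv == num_wg
            and sdp_swap_ab
            and not dkv_swap_ab)

# The C++ FA3 config quoted below (aSdP=1, adKV=2, SdP_swap=T,
# dKV_swap=F, 2 MMA warp groups) satisfies the condition.
print(mma_dkv_is_rs(1, 2, True, False, 2))  # True
```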
Forward: accumulators are S (tile_m x tile_n) and O (tile_m x hdimv).
Backward: accumulators are S/dP (shared), dK, dV, and dQ.
Accumulator registers per thread per WG = M * N / (num_wg * 128), where M x N is the output tile.
Forward peak registers, depending on the variant:
- regs_S + regs_P + regs_O (S, P in bf16, and O all live)
- regs_S + regs_O (S and O alternate; P reuses S's registers)

where regs_P = regs_S / 2 (bf16 vs f32).
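A worked example of the forward register accounting (the tile sizes here are illustrative, not a recommendation):

```python
# Accumulator registers per thread per WG = M * N / (num_wg * 128),
# applied to the forward S and O accumulators.
def acc_regs(M: int, N: int, num_wg: int) -> int:
    return M * N // (num_wg * 128)

tile_m, tile_n, hdimv, num_wg = 128, 192, 128, 2
regs_S = acc_regs(tile_m, tile_n, num_wg)  # f32 S accumulator
regs_P = regs_S // 2                       # bf16 copy of S
regs_O = acc_regs(tile_m, hdimv, num_wg)   # f32 O accumulator

peak_all_live = regs_S + regs_P + regs_O   # S, P, O all live
peak_reuse = regs_S + regs_O               # P reuses S's registers
```

For this example peak_reuse = 96 + 64 = 160, which fits a 216-reg/thread budget, while peak_all_live = 208 is close to the limit.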
Backward peak registers:
max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV

Smem usage is the sum of tensor buffers (ignoring alignment padding, which is small):
Forward: max(sQ, sO) + sK*2 + sV*2 + sP
Backward: sQ*2 + sK + sV + sdO*dO_stage + sP + sdS + sdQaccum
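The smem formulas above can be turned into a rough calculator. Assumptions not stated in this document: bf16 (2-byte) tiles, an f32 (4-byte) dQaccum, and the stage counts shown; the function names are mine:

```python
# Smem estimates in bytes from the fwd/bwd formulas above.
# Assumed: bf16 tiles (2 B/elem), f32 dQaccum (4 B/elem),
# 2 stages for K/V in forward, dO_stage stages for dO in backward.
def fwd_smem(tile_m, tile_n, hdim, hdimv):
    sQ = tile_m * hdim * 2
    sO = tile_m * hdimv * 2
    sK = tile_n * hdim * 2
    sV = tile_n * hdimv * 2
    sP = tile_m * tile_n * 2
    return max(sQ, sO) + sK * 2 + sV * 2 + sP

def bwd_smem(tile_m, tile_n, hdim, hdimv, dO_stage=2):
    sQ = tile_m * hdim * 2
    sK = tile_n * hdim * 2
    sV = tile_n * hdimv * 2
    sdO = tile_m * hdimv * 2
    sP = sdS = tile_m * tile_n * 2
    sdQaccum = tile_m * hdim * 4
    return sQ * 2 + sK + sV + sdO * dO_stage + sP + sdS + sdQaccum
```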
Per-iteration smem bandwidth consumed. Total GMMA instructions = (M_eff / 64) * wg_n, and each instruction independently reads its A and B slices from smem.
Additional traffic: R2S stores for P, dS (bf16), dQ smem store + TMA load (f32).
Traffic per block (traffic / (tile_m * tile_n)) normalizes across tile sizes for comparison. Lower is better.
Best: tile_m=128, tile_n=192, RS, 2 WG. 224K smem, 9.3 tr/blk.
C++ FA3 config: tile_m=80, tile_n=128, SdP_swap=T, dKV_swap=F, dQ_swap=T, aSdP=1, adKV=2. mma_dkv_is_rs=True. 204K smem, 208 regs, 39.6 tr/blk.
3 WG, tile_m=64, tile_n=96, SdP_swap=F, dKV_swap=T, adKV=1 or 3. 216K smem, 128 regs. This is the only feasible tile_n > 64 for hdim=192 due to register pressure.
With 3 WG: need AtomLayoutNdKV=3 (since hdimv=128 not divisible by 3). tile_n=96, 212K smem. With 2 WG: tile_n=112 feasible at 210K smem, or tile_n=64 at 168K smem.