docs_new/docs/advanced_features/attention_backend.mdx
SGLang supports a wide variety of attention backends, each with different pros and cons. You can test them to find the best fit for your model and workload.
<Warning> Selecting an optimal attention backend is crucial for maximizing performance. Different backends excel in different scenarios, so choose based on your model, hardware, and use case. Not all backends are supported on all platforms and model architectures. If you don't specify `--attention-backend`, SGLang makes a best effort to automatically select the most performant backend based on your hardware and model architecture.
</Warning>
The support matrix is split into two parts: MHA (standard attention) and MLA (multi-head latent attention). For an explanation of the key differences between MHA and MLA, please see the SGLang documentation on DeepSeek MLA and the original DeepSeek MLA paper.
<Note>
Spec V2 support:
- Verified backends: TRTLLM MLA, TRTLLM MHA, FA3, Ascend (NPU), Triton.
- Limited support: FlashInfer can run under Spec V2, but its plan stream (used for split-KV optimization) introduces a synchronization point that limits overlap benefits.
</Note>
<Tip> Page size controls how many tokens are grouped into a KV cache block. For the prefix cache to take effect, the number of tokens must fill at least one complete page. For example, if your prompt is only 32 tokens and `page_size = 64`, it won't fill a complete page and cannot be matched in the prefix cache (pages cannot be padded). With 65 tokens and `page_size = 64`, only the first page of 64 tokens will be cached and matched; the remaining 1 token is discarded. Use `page_size = 1` for maximum prefix reuse (token-level matching). Note that larger page sizes generally improve attention kernel performance, so prefer `page_size > 1` when prefix cache reuse is not critical. </Tip>

Many backends that do not natively operate on pages can emulate `page_size > 1` at the wrapper layer by expanding page tables to per-token indices. The "Page Size > 1 (native)" column indicates true in-kernel paging.

Some backends require a fixed native page size that cannot be reduced or emulated differently:

- TRTLLM MHA: 16, 32, or 64
- TRTLLM MLA: 32 or 64
- FlashMLA: 64
- Cutlass MLA: 128
- Ascend: 128
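To make the trade-off concrete, the sketch below launches a server with token-level paging for maximum prefix-cache reuse; the model and backend are illustrative choices only, and raising `--page-size` trades matching granularity for kernel efficiency.

```bash
# Illustrative sketch: token-granular prefix caching with page_size = 1.
# Model and backend are example choices, not requirements.
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend fa3 \
    --page-size 1
```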
GDN (Gated Delta Network) is a linear attention mechanism with O(n) complexity, used in hybrid models that alternate GDN linear attention layers with standard full attention layers. GDN is not selected via --attention-backend; it is automatically activated when the model architecture requires it (e.g., Qwen 3.5, Qwen 3 Next, Jet Nemotron, Jet VLM).
The GDN linear attention layers have their own kernel backends, selected via --linear-attn-backend (default: triton). You can override the kernel per phase with --linear-attn-decode-backend and --linear-attn-prefill-backend.
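A minimal launch sketch for a GDN hybrid model with explicit linear-attention kernel selection is shown below; the model path and `--tp` value are illustrative assumptions, and GDN itself is enabled automatically by the architecture rather than by a flag.

```bash
# Illustrative sketch: GDN hybrid model with an explicit linear-attention kernel.
# The model path and parallelism are example values; GDN activates automatically.
python3 -m sglang.launch_server \
    --model Qwen/Qwen3-Next-80B-A3B-Instruct \
    --tp 4 \
    --linear-attn-backend triton \
    --trust-remote-code

# Per-phase overrides are also available:
#   --linear-attn-prefill-backend / --linear-attn-decode-backend
```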
DSA (Deepseek Sparse Attention) is a native sparse attention mechanism used by DeepSeek V3.2. It is activated automatically when the model architecture requires it and is selected via --attention-backend nsa.
Internally, the NSA backend dispatches to different sub-backends for the prefill and decode phases. You can override these with --nsa-prefill-backend and --nsa-decode-backend.
For deployment examples, see the DeepSeek V3.2 deployment guide.
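A minimal launch sketch for a DSA model with the NSA backend follows; the model path and `--tp` value are illustrative assumptions, and the backend would normally be selected automatically for this architecture.

```bash
# Illustrative sketch: DeepSeek V3.2 with the NSA (DeepSeek Sparse Attention) backend.
# Model path and --tp are example values; see the deployment guide for tuned settings.
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-V3.2-Exp \
    --attention-backend nsa \
    --trust-remote-code
```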
You can mix and match attention backends for prefill and decode. This is useful when one backend excels at prefill and another excels at decode. For implementation details, see `python/sglang/srt/layers/attention/hybrid_attn_backend.py`.
```bash
# Example: Prefill with FA4, Decode with TRTLLM MLA (Blackwell)
python3 -m sglang.launch_server \
    --model-path nvidia/DeepSeek-R1-FP4 \
    --tp 8 \
    --attention-backend trtllm_mla \
    --moe-runner-backend flashinfer_trtllm \
    --quantization modelopt_fp4 \
    --prefill-attention-backend fa4
```
Hybrid attention also works with speculative decoding. The backend used for draft decoding and target verification depends on --speculative-attention-mode:
- `--speculative-attention-mode decode` (recommended): draft/verify use the decode backend.
- `--speculative-attention-mode prefill` (default): draft/verify use the prefill backend.

Constraints when combining hybrid attention with speculative decoding:
- With `trtllm_mha`, speculative decoding supports only `--speculative-eagle-topk 1`.
- With `--page-size > 1` and `--speculative-eagle-topk > 1`, only `flashinfer` is supported.
- `--speculative-attention-mode prefill`.
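For illustration, here is a hedged sketch that combines hybrid prefill/decode backends with EAGLE speculative decoding and runs draft/verify on the decode backend; the model, draft model path, and speculative parameters are example values, not recommendations.

```bash
# Illustrative sketch: hybrid attention backends + EAGLE speculative decoding,
# with draft/verify running on the decode backend.
# Model, draft model, and speculative parameters are example values only.
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --prefill-attention-backend fa3 \
    --decode-attention-backend flashinfer \
    --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B \
    --speculative-num-steps 3 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 4 \
    --speculative-attention-mode decode
```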
If the `--attention-backend` argument is not specified, SGLang automatically selects the best backend based on the hardware (CUDA) and model architecture:

1. MHA Models (e.g., Llama, Qwen)
   - `fa3` if using CUDA 12.3+ and the model configuration is supported.
   - `trtllm_mha`, unless using speculative decoding with topk > 1.
   - `flashinfer` if available; otherwise falls back to `triton`.
2. MLA Models (e.g., DeepSeek V3)
   - `fa3` (requires CUDA 12.3+).
   - `flashinfer`; `trtllm_mla` is auto-selected for DeepSeek V3 models specifically.
   - `triton`.

```bash
# FlashInfer with an MHA model
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend flashinfer

# FlashInfer with an MLA model
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-V3 \
    --attention-backend flashinfer \
    --trust-remote-code
```
```bash
# FlashAttention 3 (fa3) with an MHA model
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend fa3

# FlashAttention 3 (fa3) with an MLA model
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --attention-backend fa3
```
```bash
# Triton with an MHA model
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend triton

# Triton with an MLA model
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-V3 \
    --attention-backend triton \
    --trust-remote-code
```
```bash
# FlashMLA
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-R1 \
    --attention-backend flashmla \
    --trust-remote-code

# FlashMLA with FP8 KV cache
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-R1 \
    --attention-backend flashmla \
    --kv-cache-dtype fp8_e4m3 \
    --trust-remote-code
```
```bash
# TRTLLM MLA
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-R1 \
    --attention-backend trtllm_mla \
    --trust-remote-code

# TRTLLM MLA with FP8 KV cache
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-R1 \
    --attention-backend trtllm_mla \
    --kv-cache-dtype fp8_e4m3 \
    --trust-remote-code
```
```bash
# TRTLLM MHA for both prefill and decode
python3 -m sglang.launch_server \
    --tp 4 \
    --model Qwen/Qwen3.5-35B-A3B-FP8 \
    --attention-backend trtllm_mha \
    --trust-remote-code

# TRTLLM MHA for decode only
python3 -m sglang.launch_server \
    --tp 4 \
    --model Qwen/Qwen3.5-35B-A3B-FP8 \
    --decode-attention-backend trtllm_mha \
    --trust-remote-code
```
```bash
# FA4 for both prefill and decode on SM90/SM100
python3 -m sglang.launch_server \
    --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 \
    --attention-backend fa4 \
    --page-size 128 \
    --trust-remote-code

# FA4 for prefill only with an MLA model
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-R1 \
    --prefill-attention-backend fa4 \
    --trust-remote-code
```
```bash
# Cutlass MLA
python3 -m sglang.launch_server \
    --tp 8 \
    --model deepseek-ai/DeepSeek-R1 \
    --attention-backend cutlass_mla \
    --trust-remote-code
```
```bash
# Ascend (NPU)
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend ascend
```
```bash
# Intel XPU
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend intel_xpu
```
```bash
# Wave
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend wave
```
```bash
# FlexAttention
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend flex_attention
```
```bash
# Dual Chunk Flash Attention (long-context models)
python3 -m sglang.launch_server \
    --model Qwen/Qwen2.5-14B-Instruct-1M \
    --attention-backend dual_chunk_flash_attn
```
```bash
# Torch native attention
python3 -m sglang.launch_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend torch_native
```
To add a new attention backend, you can learn from the existing backends (`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`) and follow the steps below.