docs_new/docs/hardware-platforms/ascend-npus/ascend_npu_optimization.mdx
This guide explains the role of each parameter used in SGLang deployments on Ascend NPU. It uses the DeepSeek-V3.2 best practice configuration as the reference example. For a complete list of tested deployment configurations, see the Ascend NPU Best Practice page.
<Note>
Parameters in this guide fall into two categories. [Required]: These must be set correctly for the target deployment scenario (e.g., multi-node communication, PD disaggregation). Incorrect values will cause deployment failures or incorrect behavior.
</Note>

The following system-level tuning steps reduce OS interference and improve CPU scheduling determinism:
<table>
<thead> <tr> <th>Command / Variable</th> <th>Purpose</th> </tr> </thead>
<tbody>
<tr> <td>`echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor`</td> <td>Locks all CPU cores to the maximum frequency, eliminating DVFS-induced latency jitter during inference-critical paths.</td> </tr>
<tr> <td>`sysctl -w vm.swappiness=0`</td> <td>Minimizes kernel swapping of anonymous pages. Reduces the risk of page faults on NPU memory buffers pinned to host RAM.</td> </tr>
<tr> <td>`sysctl -w kernel.numa_balancing=0`</td> <td>Disables automatic NUMA page migration. Prevents the kernel from moving memory pages between NUMA nodes while inference is running, which would cause latency spikes.</td> </tr>
<tr> <td>`sysctl -w kernel.sched_migration_cost_ns=50000`</td> <td>Sets a minimum task migration cost, discouraging the scheduler from moving inference threads between CPU cores unnecessarily.</td> </tr>
<tr> <td>`SGLANG_SET_CPU_AFFINITY=1`</td> <td>Binds SGLang worker processes to specific CPU cores, avoiding cross-core migration overhead for high-frequency scheduling loops.</td> </tr>
</tbody>
</table>

These arguments and environment variables are critical for tuning prefill performance:
<table>
<thead> <tr> <th>Argument / Variable</th> <th>Purpose</th> <th>Reference Value</th> </tr> </thead>
<tbody>
<tr> <td>`--chunked-prefill-size`</td> <td>Sets the maximum number of tokens per prefill chunk. A positive value enables chunked prefill, which interleaves prefill and decode for better concurrency in mixed workloads. Set to `-1` to disable chunking and process each request in a single forward pass, which is preferred for dedicated prefill servers with long-context sequences.</td> <td>`-1`</td> </tr>
<tr> <td>`--max-prefill-tokens`</td> <td>Limits the total number of tokens the prefill server can process in one batch. The effective bound is `max(this value, model_max_context_length)`. Set this based on your target sequence length and available NPU memory to bound memory usage while maximizing throughput. Tune by increasing the value until out-of-memory errors appear, then back off.</td> <td>`68000`</td> </tr>
<tr> <td>`--max-running-requests`</td> <td>Limits the number of concurrent requests being processed. For prefill, a low value (e.g., `1`) dedicates more compute and memory to each request, achieving higher per-request throughput, which is ideal for dedicated prefill nodes processing long sequences. For general-purpose serving, use a higher value to support multi-request concurrency.</td> <td>`1`</td> </tr>
<tr> <td>`--disable-radix-cache`</td> <td>Disables prefix caching via RadixAttention. Set this flag when processing non-overlapping long sequences where prefix caching provides no benefit and only consumes memory. Leave unset (radix cache enabled) for chat/conversation workloads with shared system prompts.</td> <td>true</td> </tr>
<tr> <td>`--disable-cuda-graph`</td> <td>Disables CUDA Graph capture. CUDA Graphs reduce kernel launch overhead for small, predictable batch sizes, making them ideal for decode. For prefill with large and variable batch sizes, CUDA Graphs provide minimal benefit and can cause issues with dynamic shapes. Set this flag on prefill nodes; leave unset on decode nodes.</td> <td>true</td> </tr>
<tr> <td>`--enable-nsa-prefill-context-parallel`</td> <td><strong>(DeepSeek V3.2 NSA-specific)</strong> Enables context parallelism for the long-sequence prefill phase of DeepSeek V3.2 with NSA (Native Sparse Attention). Distributes the sequence across CP ranks to parallelize the computationally expensive NSA prefill for ultra-long contexts.</td> <td>Enabled</td> </tr>
<tr> <td>`--nsa-prefill-cp-mode`</td> <td><strong>(DeepSeek V3.2 NSA-specific)</strong> Controls how the long sequence is split across context parallel ranks: `in-seq-split` divides each sequence uniformly across CP ranks, optimal for single-request prefill. `round-robin-split` (code default) distributes tokens by index mod CP size, supporting multi-batch prefill. Only effective when `--enable-nsa-prefill-context-parallel` is enabled.</td> <td>`in-seq-split`</td> </tr>
<tr> <td>`--attn-cp-size`</td> <td>Specifies the context parallelism group size for attention computation. Larger values distribute the sequence across more ranks, reducing per-rank memory and compute at the cost of increased communication. For models with NSA, this controls the CP size for sparse attention prefill. Set to the number of available devices for maximum parallelization.</td> <td>`32`</td> </tr>
</tbody>
</table>

These arguments and environment variables are critical for tuning decode performance:
<table>
<thead> <tr> <th>Argument / Variable</th> <th>Purpose</th> <th>Reference Value</th> </tr> </thead>
<tbody>
<tr> <td>`--dp-size`</td> <td>Sets the data parallelism degree for the decode server. With DP attention enabled, attention layers are sharded across DP ranks while FFN/MoE layers use tensor parallelism. Higher values create more independent decode instances, increasing throughput through parallel request processing. Choose a value that divides evenly into your total card count, with remaining cards used for TP/EP.</td> <td>`8`</td> </tr>
<tr> <td>`--ep`</td> <td>Sets the expert parallelism degree. For MoE models, this distributes experts across cards, reducing per-card expert loading overhead and enabling all-to-all dispatch. The code default is `1`; set explicitly for MoE models. The optimal value depends on your model's expert count and architecture. DeepSeek V3.2 with 256 routed experts uses `ep=32`. For models with fewer experts, use a proportionally smaller value.</td> <td>`32`</td> </tr>
<tr> <td>`--enable-dp-attention`</td> <td>Enables data parallelism for attention layers while keeping tensor parallelism for FFN/MoE layers. This is a key optimization for decode throughput: attention is DP-sharded to reduce KV cache duplication, while MoE layers remain TP-sharded to leverage expert parallelism. Best suited for MoE models where attention is not the compute bottleneck.</td> <td>Enabled</td> </tr>
<tr> <td>`--enable-dp-lm-head`</td> <td>Enables vocabulary parallelism across the DP attention group, sharding the LM head weight across ranks. Each rank only computes logits for its vocabulary shard, avoiding a costly all-gather of logits across the DP group. This is essential when DP attention is enabled to maintain throughput.</td> <td>Enabled</td> </tr>
<tr> <td>`--cuda-graph-max-bs`</td> <td>Caps the maximum batch size for which CUDA Graphs are captured. Larger values cover more batch sizes but increase graph capture time and memory overhead. If your `--max-running-requests` is high but typical batch sizes are lower, use a smaller value to reduce capture overhead. Tune based on your observed batch size distribution during serving.</td> <td>`4`</td> </tr>
<tr> <td>`SGLANG_SCHEDULER_SKIP_ALL_GATHER=1`</td> <td>When DP attention is enabled, the scheduler normally performs an all-gather across DP ranks to determine the full set of ready requests. Setting this to `1` skips that operation, reducing decode scheduling latency. Only safe when load is balanced across DP ranks (e.g., via a round-robin load balancing policy). Disable if you observe uneven load distribution across DP ranks.</td> <td>`1`</td> </tr>
</tbody>
</table>

Speculative decoding reduces per-token latency by generating draft tokens that are then verified by the target model:
<table>
<thead> <tr> <th>Argument / Variable</th> <th>Purpose</th> <th>Reference Value</th> </tr> </thead>
<tbody>
<tr> <td>`--speculative-algorithm`</td> <td>Selects the speculative decoding algorithm. `NEXTN` (aliased to `EAGLE`) uses the model's built-in MTP (Multi-Token Prediction) heads, requiring no separate draft model. `EAGLE3` uses an external draft model, which can achieve higher acceptance rates at the cost of additional memory. Other built-in options include `STANDALONE`, `NGRAM`, and `DFLASH`, plus any plugin-registered name via `SpeculativeAlgorithm.register`. Choose `NEXTN` for models with native MTP support (e.g., DeepSeek V3.2/R1); choose `EAGLE3` for models without MTP (e.g., Qwen MoE).</td> <td>`NEXTN`</td> </tr>
<tr> <td>`--speculative-num-steps`</td> <td>Number of speculative forward passes per iteration. More steps can increase the acceptance length and throughput but add latency. For prefill, use a small value (`1`) to minimize prefill latency impact. For decode, use a larger value (`2`–`4`) to maximize throughput. Tune based on your latency versus throughput requirements.</td> <td>Prefill: `1`, Decode: `3`</td> </tr>
<tr> <td>`--speculative-eagle-topk`</td> <td>Limits the number of draft tokens considered per position. Lower values reduce compute on unlikely tokens and are required for the experimental SpecV2 overlap scheduler. Higher values may increase acceptance rates but add overhead. Start with `1` if using SpecV2; otherwise, `4`–`8` is typical.</td> <td>`1`</td> </tr>
<tr> <td>`--speculative-num-draft-tokens`</td> <td>Number of draft tokens generated per speculative step. Higher values increase potential acceptance length and throughput but add per-step computation. Balance against your latency budget: prefill typically uses fewer draft tokens (`2`) to minimize overhead; decode can use more (`4`) to maximize throughput.</td> <td>Prefill: `2`, Decode: `4`</td> </tr>
<tr> <td>`SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1`</td> <td>Enables the overlap plan stream feature for EAGLE v2/v3 speculative decoding workers. This overlaps draft model computation with target model verification, effectively hiding draft latency. Enable when using EAGLE-based speculative decoding; not applicable for NEXTN.</td> <td>`1`</td> </tr>
<tr> <td>`SGLANG_ENABLE_SPEC_V2=1`</td> <td>Enables the experimental SpecV2 overlap scheduler for speculative decoding. Works with `--speculative-eagle-topk 1` to overlap the draft generation and verification stages. Requires `SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1`.</td> <td>`1`</td> </tr>
</tbody>
</table>

The following environment variables are used in other best practice configurations and may be applicable depending on your model and deployment:
<table>
<thead> <tr> <th>Variable</th> <th>Purpose</th> <th>Typical Usage</th> </tr> </thead>
<tbody>
<tr> <td>`HCCL_OP_EXPANSION_MODE=AIV`</td> <td>Configures the HCCL communication algorithm scheduling to use AIV (AI Vector core) expansion mode, which expands collective operations on the device-side vector cores and can improve communication efficiency for certain collective operations.</td> <td>Used in Qwen MoE and R1 non-NSA configurations</td> </tr>
<tr> <td>`SGLANG_NPU_FUSED_MOE_MODE`</td> <td>Controls the fused MoE optimization mode on Ascend NPU. `1` is default; `2` enables a more aggressive fusion strategy (`DISPATCH_FFN_COMBINE`) that can improve MoE dispatch throughput. Mode `2` requires `--quantization modelslim`. Used primarily with DeepSeek R1 models.</td> <td>`1` or `2`</td> </tr>
<tr> <td>`SGLANG_NPU_USE_MLAPO=1`</td> <td><strong>(DeepSeek MLA-specific)</strong> Adopts the `MLAPO` fusion operator in the MLA (Multi-Head Latent Attention) preprocessing stage for DeepSeek models with MLA architecture.</td> <td>Used with DeepSeek R1</td> </tr>
<tr> <td>`SGLANG_USE_FIA_NZ=1`</td> <td><strong>(DeepSeek MLA-specific)</strong> Reshapes the KV Cache into FIA NZ format for improved memory access efficiency. Must be used together with `SGLANG_NPU_USE_MLAPO=1`.</td> <td>Used with DeepSeek R1</td> </tr>
<tr> <td>`SGLANG_NPU_USE_MULTI_STREAM=1`</td> <td><strong>(DeepSeek MoE-specific)</strong> Enables dual-stream computation for shared experts and routed experts in DeepSeek MoE models, allowing the two expert types to execute concurrently on separate streams.</td> <td>Used with DeepSeek R1</td> </tr>
<tr> <td>`SGLANG_USE_AG_AFTER_QLORA=1`</td> <td>Delays the all-gather operation until after Q-LoRA processing. This reduces communication overhead by performing Q-LoRA projection before the all-gather, requiring fewer bytes to be transferred.</td> <td>Used with DeepSeek V3.2/R1 prefill</td> </tr>
</tbody>
</table>
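
Several variables in this guide are documented as depending on another setting, so it can be useful to check a launch environment before starting the server. The following is a minimal sketch of such a check, not part of SGLang: the helper name `check_env_constraints` and its signature are hypothetical, and it encodes only the three dependencies stated in the tables of this guide.

```python
from typing import List, Mapping, Optional


def check_env_constraints(
    env: Mapping[str, str], quantization: Optional[str] = None
) -> List[str]:
    """Return human-readable problems; an empty list means no conflict was found.

    Hypothetical helper: `env` is a mapping of environment variables and
    `quantization` mirrors the value that would be passed to `--quantization`.
    """
    problems: List[str] = []

    # SGLANG_USE_FIA_NZ=1 must be used together with SGLANG_NPU_USE_MLAPO=1.
    if env.get("SGLANG_USE_FIA_NZ") == "1" and env.get("SGLANG_NPU_USE_MLAPO") != "1":
        problems.append("SGLANG_USE_FIA_NZ=1 requires SGLANG_NPU_USE_MLAPO=1")

    # Fused MoE mode 2 requires `--quantization modelslim`.
    if env.get("SGLANG_NPU_FUSED_MOE_MODE") == "2" and quantization != "modelslim":
        problems.append("SGLANG_NPU_FUSED_MOE_MODE=2 requires --quantization modelslim")

    # SpecV2 requires the overlap plan stream to be enabled.
    if (
        env.get("SGLANG_ENABLE_SPEC_V2") == "1"
        and env.get("SGLANG_ENABLE_OVERLAP_PLAN_STREAM") != "1"
    ):
        problems.append(
            "SGLANG_ENABLE_SPEC_V2=1 requires SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1"
        )

    return problems


if __name__ == "__main__":
    import os

    # Check the current shell environment; no quantization argument is assumed here.
    for problem in check_env_constraints(os.environ):
        print(problem)
```

Passing `os.environ` checks the current shell. Passing a plain dictionary makes it possible to validate a planned configuration before exporting anything.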