Expert Parallelism (EP) in SGLang distributes expert weights across multiple devices in Mixture-of-Experts (MoE) models, addressing memory bottlenecks and enabling efficient scaling for high-performance inference. It is particularly vital for serving large-scale MoE models where tokens are dynamically routed to specialized experts across GPUs. By leveraging optimized all-to-all communication and grouped matrix multiplications (GEMMs), EP reduces latency, boosts throughput, and minimizes idle GPU time. SGLang's EP offers strong extensibility through its modular framework, allowing seamless integration of custom kernels, backends, and optimizations without refactoring core logic, supporting diverse hardware and quantization schemes.
SGLang's EP integrates diverse, highly efficient backends for different use cases, allowing fine-grained control over performance trade-offs. Users specify backends via command-line flags:
- `--moe-a2a-backend`: Selects the backend for all-to-all communication.
- `--moe-runner-backend`: Selects the backend for MoE computation.

The DeepEP and Mooncake backends support two modes for token dispatch: `normal` mode (optimized for prefill workloads with high throughput) and `low_latency` mode (optimized for decode workloads with low latency and CUDA Graph compatibility). The MORI backend currently supports only `normal` mode, while NIXL-EP currently operates in low-latency mode with CUDA Graph support. We recommend setting `--deepep-mode auto` to switch dispatch modes automatically at runtime; forcing `--deepep-mode normal` or `--deepep-mode low_latency` is useful for debugging or development.
Currently, DeepEP, Mooncake, NIXL-EP, `ascend_fuseep`, and MORI support only the case where `ep_size == tp_size`. For hybrid EP and TP (i.e., `ep_size < tp_size`), only the `none` backend (All-Reduce- or All-Gather-based dispatching) is supported.
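For example, a hybrid EP+TP launch might look like the following minimal sketch (the model path and parallel sizes are illustrative):

```bash
# Hybrid EP and TP (ep_size < tp_size): only the default `none` a2a backend
# (All-Reduce/All-Gather-based dispatching) supports this configuration.
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 8 --ep 4
```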
Launch with DeepEP and DeepGEMM for DeepSeek-V3:
```bash
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --moe-a2a-backend deepep --moe-runner-backend deep_gemm --tp 8 --ep 8
```
SGLang's EP framework provides modular abstractions for easy integration of custom kernels, backends, and optimizations. It decouples the MoE forward pass into stages (dispatch → pre-permute → core runner → post-permute → combine), enabling seamless extensions without refactoring core logic.
The framework centers on `FusedMoE` as the unified entry point for a single, extensible structure. Key components include:

- Token dispatchers (`BaseDispatcher` subclasses).
- `MoeRunnerCore` implementations (e.g., `TritonRunnerCore`).
- Permutation and fusion registration (`register_pre_permute` and `register_post_permute` for dynamic modes, or `register_fused_func` for static, `torch.compile`-compatible fused operations).

This design supports multiple backends via `--moe-a2a-backend` and `--moe-runner-backend`, with quantization integrated through a standardized `apply()` method. The computation flow ensures modularity:
```text
[input_hidden_states]
        |
        v
 TopK.forward -> select_experts / triton_kernels.routing / bypass
        |
        v
  [TopKOutput]
        |
        v
 FusedMoE.forward -> Dispatcher.dispatch -> DeepEP / bypass
        |                   |
        |                   v
        |            [DispatchOutput]
        |                   |
        |                   v
        |          quant_method.apply -> MoeRunner.forward
        |                   |                   |
        |                   |                   v
        |                   |    pre-permute + grouped_gemm + post-permute
        |                   |                   |
        |                   |<------------------
        |                   v
        |            [CombineInput]
        |                   |
        |                   v
        |          Dispatcher.combine -> DeepEP / bypass
        |                   |
        |<------------------
        v
[final_hidden_states]
```
For details, see the MoE Refactor Roadmap.
To add a new backend:
1. Implement a `BaseDispatcher` subclass with `dispatch` and `combine` methods (a minimal skeleton is sketched below).
2. Implement a `MoeRunnerCore` subclass for the core operations (e.g., grouped GEMMs).
3. Define the runner input/output types (`RunnerInput`, `RunnerOutput`).
4. Register a `register_fused_func` for end-to-end fused operations.
5. Use `register_pre_permute` and `register_post_permute` for flexible layouts.

See the MoE Refactor Implementation PR for full changes, including type hints and config expansions.
For an example implementation, see `moe_runner/triton.py`, which demonstrates Triton-based grouped GEMMs with registered fused and permutation functions.
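As a rough illustration of the dispatcher contract, a custom backend skeleton might look like the sketch below; only `BaseDispatcher`, `dispatch`, and `combine` come from the framework described above, while the class name and import path are assumptions.

```python
# Minimal sketch of a custom dispatcher backend; `MyDispatcher` is
# hypothetical and the import path is an assumption.
from sglang.srt.layers.moe.token_dispatcher import BaseDispatcher


class MyDispatcher(BaseDispatcher):
    def dispatch(self, hidden_states, topk_output):
        # Route each token to the ranks hosting its top-k experts and return
        # a DispatchOutput for quant_method.apply / MoeRunner.forward.
        ...

    def combine(self, combine_input):
        # Gather expert outputs back to their source ranks and reduce them
        # into final_hidden_states.
        ...
```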
SGLang's EP employs advanced overlap techniques to hide communication latency behind computation, maximizing GPU utilization in MoE layers.
Two-Batch Overlap (TBO) splits requests into micro-batches, interleaving attention computation with dispatch/combine operations. Yield points in the execution graph allow pausing for overlaps, increasing overall throughput without peak memory spikes:
```python
operations = [
    self._forward_attn,
    YieldOperation(),  # Overlap with dispatch of prior micro-batch
    self._forward_dispatch,
    self._forward_mlp,
    YieldOperation(),  # Overlap with combine
    self._forward_combine,
]
```
Specify `--enable-two-batch-overlap` to enable TBO, which can deliver up to a 2x throughput gain. For details, see the Large-Scale EP Blog.
SGLang introduces a dispatcher-hook system for Single-Batch Overlap (SBO), enabling the overlap of operations within a single batch (such as shared experts computation with communication) while decentralizing logic to enhance modularity. These hooks execute before and after the dispatch and combine operations without modifying core MoE modules. This design simplifies interfaces, reduces coupling, and improves extensibility. For implementation details and an example of overlapping shared experts with DeepEP's combine operation, refer to PR #13327. Users can set `--enable-single-batch-overlap` to enable this feature.
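Conceptually, the hook mechanism can be pictured as in the sketch below; the hook names and attachment API are illustrative rather than SGLang's actual interface (see the PR above for the real one).

```python
# Conceptual sketch of dispatcher hooks for SBO; names are illustrative.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class DispatcherHooks:
    """Callbacks that run around combine without modifying core MoE modules."""
    before_combine: List[Callable[[], None]] = field(default_factory=list)
    after_combine: List[Callable[[], None]] = field(default_factory=list)


hooks = DispatcherHooks()
# Launch shared-experts compute just before the combine all-to-all starts
# (e.g., on a side CUDA stream) so the two overlap.
hooks.before_combine.append(lambda: print("launch shared-experts kernel"))
hooks.after_combine.append(lambda: print("add shared-experts output"))

for hook in hooks.before_combine:
    hook()
# ... Dispatcher.combine would run here, overlapping the shared-experts work ...
for hook in hooks.after_combine:
    hook()
```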
SGLang integrates the Expert Parallelism Load Balancer (EPLB) from DeepSeek to address routing imbalances in MoE models. By analyzing expert activation statistics, EPLB computes an optimal expert arrangement, strategically placing or replicating experts to minimize GPU utilization variance, reduce idle cycles, and enhance scalability.
To enable EPLB, use the `--enable-eplb` flag. For optimal performance, increase batch sizes to stabilize activation statistics and configure periodic rebalancing (e.g., every 1000 requests) to adapt to evolving workloads. Simulations demonstrate significant improvements in load balancedness (the ratio of mean to max computation time), which correlates strongly with throughput gains.
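As a quick illustration of that metric (the timings are made up):

```python
# Balancedness = mean / max of per-GPU MoE computation time; 1.0 is perfect.
gpu_moe_time_ms = [12.1, 11.8, 19.5, 12.3]  # illustrative per-rank timings
balancedness = (sum(gpu_moe_time_ms) / len(gpu_moe_time_ms)) / max(gpu_moe_time_ms)
print(f"balancedness = {balancedness:.2f}")  # ~0.71 here; EPLB aims to raise this
```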
For more details, refer to the EPLB Section in the Large-Scale EP Blog and the EPLB Repository.
When using speculative decoding with MTP on MoE architectures, use the `--speculative-moe-runner-backend` and `--speculative-moe-a2a-backend` arguments to customize the MoE layer behavior for the draft model. They default to the target model's settings, but users can differentiate them when the target and draft models use different precisions.
For a model like `nvidia/DeepSeek-R1-0528-NVFP4-v2`, the target model uses NVFP4 precision while the draft model uses BF16. To apply the `flashinfer_trtllm` kernel to the target model's MoE layers while falling back to the Triton fused MoE kernel for the draft model's MoE layers, set the arguments as follows:
```bash
...
--moe-runner-backend flashinfer_trtllm \
--speculative-moe-runner-backend triton \
...
```
On Ascend NPUs:

- `--moe-a2a-backend` supports only the `deepep` and `ascend_fuseep` backends:
  - `deepep`: The mechanism is consistent with the description above.
  - `ascend_fuseep`: Offers a single large fused operator that integrates all operations between dispatch and combine to accelerate MoE computation. It is used only for the decode stage in PD Disaggregation Mode.
- `--moe-runner-backend` does not need to be configured.
- `--deepep-mode`:
  - In PD mixed mode, set `--deepep-mode auto`.
  - In PD Disaggregation Mode, the prefill instance should set `--deepep-mode normal` and the decode instance `--deepep-mode low_latency` (see the example launch below).
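A hedged sketch of a decode-instance launch under PD disaggregation (the model path and parallel sizes are illustrative, and PD-disaggregation-specific flags are omitted):

```bash
# Decode instance on Ascend NPUs: low-latency dispatch for decode workloads.
# Swap `deepep` for `ascend_fuseep` to use the fused decode-stage operator.
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 \
  --moe-a2a-backend deepep \
  --deepep-mode low_latency \
  --tp 16 --ep 16
```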
DeepEP Ascend is the version of the DeepEP communication library adapted for Huawei Ascend NPUs, designed specifically for Expert Parallelism (EP) in Mixture-of-Experts (MoE) models. It supports an ant-moving function (splitting the sequence length into rounds for streaming, batched transmission) that reduces the buffer size occupied during collective communication in the prefill stage, especially for long sequences.
The ant-moving function can be enabled for the dispatch and combine phases via the following environment variables:

- `DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS`: Enables the ant-moving function in the dispatch stage; sets the number of tokens transmitted per round on each rank (default: 8192).
- `DEEPEP_NORMAL_LONG_SEQ_ROUND`: Enables the ant-moving function in the dispatch stage; sets the number of rounds transmitted on each rank (default: 1).
- `DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ`: Enables the ant-moving function in the combine stage (default: 0, i.e., disabled).

`DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS * DEEPEP_NORMAL_LONG_SEQ_ROUND` equals the input sequence length. When the input sequence length exceeds 8192, enabling the ant-moving function in both the dispatch and combine phases is recommended.
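For example, to stream a 16384-token input in two rounds per rank and also enable the combine phase (assuming `1` enables it, mirroring the `0` default):

```bash
# Ant-moving: 8192 tokens/round x 2 rounds per rank = 16384-token input.
export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=8192
export DEEPEP_NORMAL_LONG_SEQ_ROUND=2
export DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ=1  # assumption: 1 enables (default 0 disables)
```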
The environment variable `HCCL_BUFFSIZE` configures the buffer size (in MB) that is actually allocated. It should satisfy the following formula:
```text
# Ant-moving function enabled
HCCL_BUFFSIZE >= 2 * (102MB + 4MB + DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS * (hidden_size + hidden_size + hidden_size) * topk) + PADDING_BUFFSIZE

# Ant-moving function disabled
HCCL_BUFFSIZE >= 2 * (102MB + 4MB + TOTAL_SEQ_LEN * (hidden_size + hidden_size) * topk) + PADDING_BUFFSIZE
```
The parameters are as follows:

- `hidden_size`: The hidden size from the model config.
- `topk`: The number of selected routing experts.
- `TOTAL_SEQ_LEN`: The input sequence length.
- `PADDING_BUFFSIZE`: A value of 20 or greater is recommended.
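As a hedged worked example of the enabled-case bound, assuming the per-token term counts bytes (one byte per element, e.g., FP8 dispatch) before conversion to MB; the model numbers are illustrative:

```python
# Worked example of the HCCL_BUFFSIZE lower bound with ant-moving enabled.
# Assumption: the token term is in bytes (1 byte/element) and converts to MB.
MB = 1024 * 1024
hidden_size = 7168          # hidden size from the model config (illustrative)
topk = 8                    # number of selected routing experts (illustrative)
per_round_tokens = 8192     # DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS
padding_mb = 20             # recommended PADDING_BUFFSIZE

token_term_mb = per_round_tokens * (3 * hidden_size) * topk / MB  # 1344.0 MB
hccl_buffsize_mb = 2 * (102 + 4 + token_term_mb) + padding_mb     # 2920.0 MB
print(f"HCCL_BUFFSIZE >= {hccl_buffsize_mb:.0f} MB")
```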