# Expert Parallel Deployment
vLLM supports Expert Parallelism (EP), which allows experts in Mixture-of-Experts (MoE) models to be deployed on separate GPUs, increasing locality, efficiency, and throughput overall.
EP is typically coupled with Data Parallelism (DP). While DP can be used independently of EP, EP is more efficient when used in conjunction with DP. You can read more about data parallelism here.
Before using EP, you need to install the necessary dependencies. We are actively working on making this easier in the future:

- Install `gdrcopy` by running the `install_gdrcopy.sh` script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions here.

vLLM provides multiple communication backends for EP. Use `--all2all-backend` to select one:
| Backend | Use Case | Features | Best For |
|---|---|---|---|
| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with contiguous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
| `flashinfer_nvlink_one_sided` | MNNVL systems | FlashInfer's one-sided A2A strategy for multi-node NVLink | High-throughput workloads |
| `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer's two-sided A2A strategy for multi-node NVLink | Systems with NVLink across nodes |
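For example, a decode-oriented deployment might select the DeepEP low-latency backend as follows (model choice illustrative; full deployment examples appear later in this page):

```bash
vllm serve Qwen/Qwen3-30B-A3B \
    --enable-expert-parallel \
    --all2all-backend deepep_low_latency
```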
!!! warning
    EP is an experimental feature. Argument names and default values may change in the future.
Enable EP by setting the --enable-expert-parallel flag. The EP size is automatically calculated as:
`EP_SIZE = TP_SIZE × DP_SIZE`
Where:

- `TP_SIZE`: Tensor parallel size
- `DP_SIZE`: Data parallel size
- `EP_SIZE`: Expert parallel size (computed automatically)

When EP is enabled, different layers in MoE models behave differently:
| Layer Type | Behavior | Parallelism Used |
|---|---|---|
| Expert (MoE) Layers | Sharded across all EP ranks | Expert Parallel (EP) of size TP × DP |
| Attention Layers | Behavior depends on TP size | See below |
Attention layer parallelism:

- `TP = 1`: Attention weights are replicated across all DP ranks (data parallelism)
- `TP > 1`: Attention weights are sharded using tensor parallelism across TP ranks within each DP group

For example, with TP=2, DP=4 (8 GPUs total):

- Attention layers: sharded 2-way with tensor parallelism within each of the 4 DP groups
- Expert (MoE) layers: sharded with expert parallelism across all 8 GPUs (EP size = 2 × 4 = 8)
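The resulting rank layout can be sketched in a few lines of Python (illustrative only; vLLM assigns groups internally):

```python
# Sketch: parallel-group layout for TP=2, DP=4 (8 GPUs total)
TP_SIZE, DP_SIZE = 2, 4
for rank in range(TP_SIZE * DP_SIZE):
    dp_group, tp_rank = divmod(rank, TP_SIZE)
    # With EP enabled, every GPU is its own EP rank, so experts shard 8 ways
    print(f"GPU {rank}: DP group {dp_group}, TP rank {tp_rank}, EP rank {rank}")
```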
!!! note "Key Difference from Data Parallel Deployment"
    Without `--enable-expert-parallel`, MoE layers would use tensor parallelism (forming a TP group of size TP × DP), similar to dense models. With EP enabled, expert layers switch to expert parallelism, which can provide better efficiency and locality for MoE models.
The following command serves the DeepSeek-V3-0324 model with 1-way tensor parallelism, 8-way (attention) data parallelism, and 8-way expert parallelism. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on an H200 (or H20) node with 8 GPUs. For H100, you can try serving a smaller model or refer to the multi-node deployment section.
```bash
# Single node EP deployment:
# 1-way tensor parallelism, 8-way data parallelism, expert parallelism enabled
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --tensor-parallel-size 1 \
    --data-parallel-size 8 \
    --enable-expert-parallel
```
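Once the server is up, it can be exercised through the standard OpenAI-compatible endpoint (prompt illustrative):

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "deepseek-ai/DeepSeek-V3-0324", "prompt": "Hello,", "max_tokens": 16}'
```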
For multi-node deployment, use the DeepEP communication kernel with one of two modes (see Backend Selection Guide above).
The following example deploys DeepSeek-V3-0324 across 2 nodes using deepep_low_latency mode:
```bash
# Node 1 (Primary - handles incoming requests)
# --data-parallel-size is the total DP size across all nodes, while
# --data-parallel-size-local is the DP size on this node (8 GPUs per node).
# Replace --data-parallel-address with the actual IP of Node 1; the RPC port
# can be any port reachable by all nodes. Scaling --api-server-count up to the
# number of local ranks is recommended for load handling.
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --all2all-backend deepep_low_latency \
    --tensor-parallel-size 1 \
    --enable-expert-parallel \
    --data-parallel-size 16 \
    --data-parallel-size-local 8 \
    --data-parallel-address 192.168.1.100 \
    --data-parallel-rpc-port 13345 \
    --api-server-count 8
```
```bash
# Node 2 (Secondary - headless mode, no API server)
# --data-parallel-start-rank is this node's rank offset (the cumulative local
# DP size of all previous nodes). The address and RPC port must point at the
# primary node (Node 1).
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --all2all-backend deepep_low_latency \
    --tensor-parallel-size 1 \
    --enable-expert-parallel \
    --data-parallel-size 16 \
    --data-parallel-size-local 8 \
    --data-parallel-start-rank 8 \
    --data-parallel-address 192.168.1.100 \
    --data-parallel-rpc-port 13345 \
    --headless
```
Key points for multi-node deployment:

- Secondary nodes run with the `--headless` flag, meaning all client requests are handled by the primary node
- `--data-parallel-start-rank` should equal the cumulative local DP size of all previous nodes
- Increase `--api-server-count` on the primary node to handle higher request loads

!!! important "InfiniBand Clusters"
    On InfiniBand networked clusters, set this environment variable to prevent initialization hangs:

    ```bash
    export GLOO_SOCKET_IFNAME=eth0
    ```

    This ensures torch distributed group discovery uses Ethernet instead of InfiniBand for initial setup.
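    The interface name varies by system; `eth0` above is illustrative. One way to list candidates:

    ```bash
    ip -brief addr show
    ```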
While MoE models are typically trained so that each expert receives a similar number of tokens, in practice the distribution of tokens across experts can be highly skewed. vLLM provides an Expert Parallel Load Balancer (EPLB) to redistribute expert mappings across EP ranks, evening the load across experts.
Enable EPLB with the --enable-eplb flag.
When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution.
Configure EPLB with the --eplb-config argument, which accepts a JSON string. The available keys and their descriptions are:
| Parameter | Description | Default |
|---|---|---|
| `window_size` | Number of engine steps to track for rebalancing decisions | `1000` |
| `step_interval` | Frequency of rebalancing (every N engine steps) | `3000` |
| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` |
| `policy` | The policy type for expert parallel load balancing | `"default"` |
| `communicator` | Backend for expert weight transfers: `"torch_nccl"`, `"torch_gloo"`, `"pynccl"`, `"nixl"`, or `null` (auto) | `null` |
For example:
```bash
vllm serve Qwen/Qwen3-30B-A3B \
    --enable-eplb \
    --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
```
??? tip "Prefer individual arguments instead of JSON?"

    ```bash
    vllm serve Qwen/Qwen3-30B-A3B \
        --enable-eplb \
        --eplb-config.window_size 1000 \
        --eplb-config.step_interval 3000 \
        --eplb-config.num_redundant_experts 2 \
        --eplb-config.log_balancedness true
    ```
Expert distribution works as follows:

- Without redundant experts, each EP rank holds `NUM_TOTAL_EXPERTS ÷ NUM_EP_RANKS` experts.
- With redundant experts, each EP rank holds `(NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS` experts.

EPLB uses redundant experts that need to fit in GPU memory. This means that EPLB may not be a good fit for memory-constrained environments or when KV cache space is at a premium. This overhead equals `NUM_MOE_LAYERS × BYTES_PER_EXPERT × (NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS`.
For DeepSeekV3, this is approximately 2.4 GB for one redundant expert per EP rank.
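A back-of-the-envelope check of that figure, using architecture numbers that are assumptions taken from the public DeepSeek-V3 config:

```python
# Rough sanity check of the ~2.4 GB overhead for one redundant expert per EP rank
num_moe_layers = 58                    # 61 layers total, first 3 dense (assumption)
hidden_size = 7168                     # model hidden dimension (assumption)
moe_intermediate_size = 2048           # per-expert FFN width (assumption)
params_per_expert = 3 * hidden_size * moe_intermediate_size  # gate, up, down projections
bytes_per_expert = params_per_expert   # FP8 weights: ~1 byte per parameter
overhead_bytes = num_moe_layers * bytes_per_expert  # one extra expert per MoE layer
print(f"{overhead_bytes / 2**30:.2f} GiB")  # ≈ 2.38 GiB
```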
Single node deployment with EPLB enabled:

```bash
# Single node with EPLB load balancing:
# 1-way tensor parallelism, 8-way data parallelism, EP enabled,
# with the load balancer tracking 1000 steps and rebalancing every 3000
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --tensor-parallel-size 1 \
    --data-parallel-size 8 \
    --enable-expert-parallel \
    --enable-eplb \
    --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
```
For multi-node deployment, add these EPLB flags to each node's command. In large-scale use cases, we recommend setting `num_redundant_experts` to 32 via `--eplb-config '{"num_redundant_experts":32}'` so the most popular experts are always available, as sketched below.
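For illustration, the earlier Node 1 command with EPLB enabled might look like this (IP, port, and parallel sizes carried over from the multi-node example above):

```bash
# Node 1 with EPLB: 32 redundant experts shared across the EP ranks
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --all2all-backend deepep_low_latency \
    --tensor-parallel-size 1 \
    --enable-expert-parallel \
    --data-parallel-size 16 \
    --data-parallel-size-local 8 \
    --data-parallel-address 192.168.1.100 \
    --data-parallel-rpc-port 13345 \
    --api-server-count 8 \
    --enable-eplb \
    --eplb-config '{"num_redundant_experts":32}'
```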
Additional performance tips and troubleshooting notes:

- The `deepep_high_throughput` and `deepep_low_latency` kernels are optimized for disaggregated serving and may show poor performance for mixed workloads.
- Use `--enable-dbo` to overlap all-to-all communication with compute. See Dual Batch Overlap for more details.
- Use `--async-scheduling` to overlap scheduling with model execution.
- `non-zero status: 7 cannot register cq buf`: when using InfiniBand/RoCE, make sure the host VM and pods show `ulimit -l` as `unlimited`.
- `init failed for transport: IBGDA`: the InfiniBand GDA kernel modules are missing. Run `tools/ep_kernels/configure_system_drivers.sh` on each GPU node and reboot. This also fixes the error `NVSHMEM API called before NVSHMEM initialization has completed`.
- On Kubernetes, set `hostNetwork: true` and `securityContext.privileged: true` so pods can access InfiniBand.
- Use the simulator flags `VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random` and `VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1` so token routing is balanced across EP ranks (see the example below).
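For example, to launch with simulated balanced routing (model choice illustrative; the environment variables are the ones named above):

```bash
export VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random
export VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1
vllm serve Qwen/Qwen3-30B-A3B --enable-expert-parallel
```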
Increasing `VLLM_MOE_DP_CHUNK_SIZE` may increase throughput by raising the maximum batch size for inter-rank token transfers. This may cause DeepEP to throw `assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2`, which can be fixed by increasing the `NVSHMEM_QP_DEPTH` environment variable.
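A sketch of the corresponding environment setup before launching (values illustrative; tune them for your workload):

```bash
export VLLM_MOE_DP_CHUNK_SIZE=512   # larger chunks batch more tokens per transfer
export NVSHMEM_QP_DEPTH=2048        # raise if the DeepEP assert above fires
```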
For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations.
For disaggregated serving:

- Use the `deepep_high_throughput` backend for optimal prefill performance
- Use the `deepep_low_latency` backend for minimal decode latency
- Install gdrcopy/ucx/nixl: for maximum performance, run the `install_gdrcopy.sh` script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions here. If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. On non-CUDA platforms, install `nixl` with a non-CUDA UCX build by running the `install_nixl_from_source_ubuntu.py` script.
Configure Both Instances: add `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'` to both the prefill and decode instances. Note that you may also specify one or more NIXL backends, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'`
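Putting it together, the two instances might be launched like this (addresses, ports, and backend pairing are illustrative, following the backend guidance above):

```bash
# Prefill instance (port 8000): throughput-optimized backend
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --port 8000 \
    --enable-expert-parallel \
    --all2all-backend deepep_high_throughput \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# Decode instance (port 8001): latency-optimized backend
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --port 8001 \
    --enable-expert-parallel \
    --all2all-backend deepep_low_latency \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
```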
Client Orchestration: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
```python
from openai import OpenAI
import uuid

try:
    # 1: Set up clients for prefill and decode instances
    openai_api_key = "EMPTY"  # vLLM doesn't require a real API key

    # Replace these IP addresses with your actual instance addresses
    prefill_client = OpenAI(
        api_key=openai_api_key,
        base_url="http://192.168.1.100:8000/v1",  # Prefill instance URL
    )
    decode_client = OpenAI(
        api_key=openai_api_key,
        base_url="http://192.168.1.101:8001/v1",  # Decode instance URL
    )

    # Get model name from prefill instance
    models = prefill_client.models.list()
    model = models.data[0].id
    print(f"Using model: {model}")

    # 2: Prefill Phase
    # Generate unique request ID to link prefill and decode operations
    request_id = str(uuid.uuid4())
    print(f"Request ID: {request_id}")

    prefill_response = prefill_client.completions.create(
        model=model,
        # Prompt must exceed vLLM's block size (16 tokens) for PD to work
        prompt="Write a detailed explanation of how Paged Attention for Transformers works, including the management of KV cache for multi-turn conversations",
        max_tokens=1,  # Force prefill-only operation
        extra_body={
            "kv_transfer_params": {
                "do_remote_decode": True,    # Enable remote decode
                "do_remote_prefill": False,  # This is the prefill instance
                "remote_engine_id": None,    # Will be populated by vLLM
                "remote_block_ids": None,    # Will be populated by vLLM
                "remote_host": None,         # Will be populated by vLLM
                "remote_port": None,         # Will be populated by vLLM
            }
        },
        extra_headers={"X-Request-Id": request_id},
    )

    print("-" * 50)
    print("✓ Prefill completed successfully")
    print(f"Prefill response: {prefill_response.choices[0].text}")

    # 3: Decode Phase
    # Transfer KV cache parameters from prefill to decode instance
    decode_response = decode_client.completions.create(
        model=model,
        prompt="This prompt is ignored during decode",  # Original prompt not needed
        max_tokens=150,  # Generate up to 150 tokens
        extra_body={
            "kv_transfer_params": prefill_response.kv_transfer_params  # Pass KV cache info
        },
        extra_headers={"X-Request-Id": request_id},  # Same request ID
    )

    print("-" * 50)
    print("✓ Decode completed successfully")
    print(f"Final response: {decode_response.choices[0].text}")

except Exception as e:
    print(f"❌ Error during disaggregated serving: {e}")
    print("Check that both prefill and decode instances are running and accessible")
```
To simulate the decode deployment of disaggregated serving, pass `--kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}'` to the `vllm serve` invocation. The connector populates the KV cache with random values so decode can be profiled in isolation.

CUDA graph capture: use `--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'` to enable CUDA graph capture for decode only and save KV cache memory.
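Both options can be combined into a single decode-profiling invocation (a sketch; model choice illustrative):

```bash
# Profile decode in isolation: random-filled KV cache plus decode-only CUDA graphs
vllm serve deepseek-ai/DeepSeek-V3-0324 \
    --kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}' \
    --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'
```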