docs/features/nixl_connector_usage.md
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
For feature compatibility details (supported model architectures, TP configurations, and feature interactions), see the NixlConnector Compatibility Matrix.
As a quick start on the NVIDIA platform, install the NIXL library with `uv pip install nixl`.
On the ROCm platform, the base ROCm Dockerfile already includes NIXL and UCX.
On non-CUDA platforms, install NIXL with UCX built from source using the script below:
```bash
python tools/install_nixl_from_source_ubuntu.py
```
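Whichever installation path you use, you can optionally verify that the Python bindings are importable. This is a minimal sanity check and assumes the package exposes a top-level `nixl` module:

```bash
# Should print without raising ImportError if NIXL is installed correctly.
python -c "import nixl; print('NIXL bindings found')"
```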
NixlConnector uses the NIXL library for its underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the default transport library used by NIXL. Configure the transport environment variables:
```bash
# Example UCX configuration; adjust according to your environment
export UCX_TLS=all           # or specify transports, e.g. "rc,ud,sm,^cuda_ipc"
export UCX_NET_DEVICES=all   # or specify network devices, e.g. "mlx5_0:1,mlx5_1:1"
```
!!! tip
    When using UCX as the transport backend, NCCL environment variables (such as `NCCL_IB_HCA` and `NCCL_SOCKET_IFNAME`) are not applicable to NixlConnector; configure the UCX-specific environment variables shown above instead.
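If UCX is installed and on your `PATH`, you can inspect what it detects before tuning these variables. This is a minimal sketch; the `ucx_info` utility ships with UCX, and its exact output varies by installation:

```bash
# Print the UCX version and build configuration.
ucx_info -v

# List the transport layers and network devices UCX detects;
# the names shown are valid values for UCX_TLS / UCX_NET_DEVICES.
ucx_info -d
```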
NixlConnector can use different NIXL transport backends (plugins). By default, NixlConnector uses UCX as the transport backend.
To select a different backend, set kv_connector_extra_config.backends in --kv-transfer-config.
```bash
vllm serve <MODEL> \
  --kv-transfer-config '{
    "kv_connector":"NixlConnector",
    "kv_role":"kv_both",
    "kv_connector_extra_config":{"backends":["LIBFABRIC"]}
  }'
```
You can also pass JSON keys individually using dotted arguments, and you can append list elements using +:
```bash
vllm serve <MODEL> \
  --kv-transfer-config.kv_connector NixlConnector \
  --kv-transfer-config.kv_role kv_both \
  --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC
```
!!! note
    Backend availability depends on how NIXL was built and which plugins are present in your environment. Refer to the NIXL repository for available backends and build instructions.
Start a prefiller instance that produces KV caches:
```bash
# 1st GPU as prefiller
CUDA_VISIBLE_DEVICES=0 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
vllm serve Qwen/Qwen3-0.6B \
  --port 8100 \
  --enforce-eager \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
```
Start a decoder instance that consumes KV caches:
```bash
# 2nd GPU as decoder
CUDA_VISIBLE_DEVICES=1 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
vllm serve Qwen/Qwen3-0.6B \
  --port 8200 \
  --enforce-eager \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
```
Use a proxy server to route requests between prefiller and decoder:
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
  --port 8192 \
  --prefiller-hosts localhost \
  --prefiller-ports 8100 \
  --decoder-hosts localhost \
  --decoder-ports 8200
```
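Requests sent to the proxy are routed to a prefiller first and then to a decoder. A minimal client sketch, assuming the proxy forwards the OpenAI-compatible `/v1/completions` route:

```bash
# Send a completion request through the proxy on port 8192.
curl http://localhost:8192/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "Qwen/Qwen3-0.6B",
        "prompt": "Explain disaggregated prefilling in one sentence.",
        "max_tokens": 32
      }'
```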
- `VLLM_NIXL_SIDE_CHANNEL_PORT`: Port for NIXL handshake communication. Each DP rank on a node uses its own port derived from this base port (e.g., with `--data-parallel-size=2` and a base port of 5600, DP ranks 0 and 1 use ports 5600 and 5601 on that node).
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication.
- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: (Optional) Timeout in seconds after which the prefiller automatically releases the KV cache held for a particular request.
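An illustrative way to set these on a prefiller node; the host address and timeout value below are placeholders, not defaults:

```bash
# Address other instances use to reach this node for the NIXL handshake.
export VLLM_NIXL_SIDE_CHANNEL_HOST=10.0.0.1

# Base handshake port; with --data-parallel-size=2, DP ranks 0 and 1
# on this node would use ports 5600 and 5601.
export VLLM_NIXL_SIDE_CHANNEL_PORT=5600

# Optional: release a request's KV cache on the prefiller after this many seconds.
export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=480
```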
```bash
# Prefiller 1 on Machine A (example IP: ${IP1})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP1} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
  --tensor-parallel-size 8 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}'
```
```bash
# Prefiller 2 on Machine B (example IP: ${IP2})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP2} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
  --tensor-parallel-size 8 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}'
```
```bash
# Decoder 1 on Machine C (example IP: ${IP3})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP3} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
  --tensor-parallel-size 8 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}'
```
```bash
# Decoder 2 on Machine D (example IP: ${IP4})
VLLM_NIXL_SIDE_CHANNEL_HOST=${IP4} \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
UCX_NET_DEVICES=all \
vllm serve Qwen/Qwen3-0.6B --port 8000 \
  --tensor-parallel-size 8 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}'
```
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
  --port 8192 \
  --prefiller-hosts ${IP1} ${IP2} \
  --prefiller-ports 8000 8000 \
  --decoder-hosts ${IP3} ${IP4} \
  --decoder-ports 8000 8000
```
For a multi-host DP deployment, you only need to provide the host and port of each group's head instance, as sketched below.
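For example, if each prefiller and decoder is itself a multi-host DP group, a hypothetical proxy invocation lists only the head node of each group (`${PREFILL_HEAD_IP}` and `${DECODE_HEAD_IP}` are placeholders):

```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
  --port 8192 \
  --prefiller-hosts ${PREFILL_HEAD_IP} \
  --prefiller-ports 8000 \
  --decoder-hosts ${DECODE_HEAD_IP} \
  --decoder-ports 8000
```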
!!! tip
    NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` via `--prefiller-hosts` and `--decoder-hosts`).
    Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
The kv_load_failure_policy setting controls how the system handles failures when the decoder instance loads KV cache blocks from the prefiller instance:
!!! warning
    Using `kv_load_failure_policy="recompute"` can lead to performance degradation in production deployments. When KV loads fail, the decode instance will execute prefill work with decode-optimized configurations, which is inefficient and defeats the purpose of disaggregated prefilling. This also increases tail latency for other ongoing decode requests.
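The deployments above opt into strict failure handling. A sketch of both settings, using the same config shape as the commands in this document:

```bash
# Fail the request if its KV cache cannot be loaded from the prefiller.
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'

# Fall back to recomputing the missing KV blocks on the decode instance
# (see the performance warning above).
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"recompute"}'
```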
Supported use case: prefill with the 'HND' KV cache layout and decode with the 'NHD' layout, enabled via the following experimental configuration:

```bash
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```
The cross-layer blocks feature is disabled by default. On attention backends that support it, all layers of a logical block are contiguous in physical memory, which reduces the number of buffers that need to be transferred. To enable it:

```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```
Refer to the example scripts under `tests/v1/kv_connector/nixl_integration/` in the vLLM repository.