# HiCache Best Practices
SGLang HiCache extends the traditional RadixAttention with a three-tier hierarchical KV caching system that dramatically improves performance for long-context and multi-turn conversation scenarios. By intelligently managing KV caches across GPU memory, host memory, and external storage backends, HiCache addresses the fundamental capacity bottleneck that limits cache hit rates in conventional systems.
```bash
# Essential HiCache flags
--page-size 64                        # Page size for cache management
--enable-hierarchical-cache           # Enable HiCache
--hicache-ratio 2                     # Host memory ratio (2x GPU memory)
--hicache-size 100                    # Host memory size in GB; overrides the ratio above
--hicache-io-backend kernel           # I/O backend for moving data between CPU and GPU
--hicache-write-policy write_through  # Cache write policy from GPU to CPU
--hicache-storage-backend             # Optional storage backend (e.g., hf3fs, mooncake)
```
Notes:
- In addition to setting `--hicache-storage-backend` at startup, SGLang also supports runtime attach/detach of the HiCache storage backend (no restart required) via HTTP admin endpoints. See Runtime Attach/Detach HiCache Storage Backend.

HiCache supports three host memory layouts:

```bash
# Page-first: optimized for I/O efficiency with zero-copy (recommended with the kernel backend)
--hicache-mem-layout page_first

# Page-first-direct: optimized for direct I/O operations (compatible with fa3, with the same zero-copy performance as page_first)
--hicache-mem-layout page_first_direct

# Layer-first
--hicache-mem-layout layer_first
```
Layout Compatibility:
- `page_first`: only compatible with the `kernel` I/O backend; automatically switches to `layer_first` with the `direct` backend
- `page_first_direct`: specifically designed for the `direct` I/O backend, with an optimized memory organization

HiCache storage supports cross-cluster KV reuse when different deployments use different TP sizes (for example, tp=4 and tp=8) and share the same storage backend namespace.
Use `tp_lcm_size` in `--hicache-storage-backend-extra-config`:

```bash
# Example: heterogeneous TP = {4, 8}, so lcm = 8
--hicache-storage-backend-extra-config '{"tp_lcm_size": 8}'
```
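For a larger mix of TP sizes, the value can be computed with Python's standard `math.lcm` (available since Python 3.9); the TP sizes below are just the example above:

```python
import math

# TP sizes of all deployments that will share the same HiCache storage namespace
tp_sizes = [4, 8]

# tp_lcm_size must be the least common multiple of all participating TP sizes
tp_lcm_size = math.lcm(*tp_sizes)
print(tp_lcm_size)  # 8
```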
Guidelines:
- Set `tp_lcm_size` to the least common multiple (LCM) of all TP sizes that will share the same HiCache storage.
- With the `page_head` layout, HiCache splits head shards based on `tp_lcm_size` to make keys reusable across heterogeneous TP deployments.

Three prefetch policies are available:

```bash
# Best-effort: terminate prefetch when needed
--hicache-storage-prefetch-policy best_effort

# Wait-complete: ensure complete prefetch, higher cache reuse
--hicache-storage-prefetch-policy wait_complete

# Timeout: balance between completion and best-effort
--hicache-storage-prefetch-policy timeout
```
HiCache works seamlessly with PD Disaggregation. You can choose between two configurations:
```bash
# Prefill node with HiCache enabled for cross-prefill sharing (ideal for SystemPrompt scenarios)
python3 -m sglang.launch_server \
    --model-path /xxx/DeepSeek-R1/ \
    --tp 8 \
    --host 0.0.0.0 \
    --port 10000 \
    --enable-metrics \
    --enable-cache-report \
    --mem-fraction-static 0.85 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 \
    --hicache-size 0 \
    --hicache-mem-layout page_first_direct \
    --hicache-io-backend direct \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs \
    --hicache-storage-prefetch-policy wait_complete \
    --disaggregation-ib-device mlx5_0 \
    --disaggregation-mode prefill \
    --disaggregation-transfer-backend mooncake
```
```bash
# Decode node with async offloading enabled for KV cache reuse by Prefill (ideal for multi-turn conversations);
# --disaggregation-decode-enable-offload-kvcache enables async KV cache offloading on the decode node
python3 -m sglang.launch_server \
    --model-path /xxx/DeepSeek-R1/ \
    --tp 8 \
    --host 0.0.0.0 \
    --port 10000 \
    --enable-metrics \
    --enable-cache-report \
    --page-size 64 \
    --hicache-ratio 2 \
    --hicache-size 0 \
    --hicache-mem-layout page_first_direct \
    --hicache-io-backend direct \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs \
    --hicache-storage-prefetch-policy wait_complete \
    --disaggregation-decode-enable-offload-kvcache \
    --disaggregation-ib-device mlx5_0 \
    --disaggregation-mode decode \
    --disaggregation-transfer-backend mooncake
```
Here is an example of deploying DeepSeek-R1 with HiCache-HF3FS. For more details, see the HF3FS Documentation.
```bash
python3 -m sglang.launch_server \
    --model-path /xxx/DeepSeek-R1/ \
    --log-level info \
    --tp 8 \
    --host 0.0.0.0 \
    --port 10000 \
    --enable-metrics \
    --enable-cache-report \
    --page-size 64 \
    --mem-fraction-static 0.85 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 \
    --hicache-size 0 \
    --hicache-mem-layout page_first_direct \
    --hicache-io-backend direct \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs \
    --hicache-storage-prefetch-policy wait_complete
```
Here is an example of deploying Qwen3-235B-A22B-Instruct-2507 with Mooncake. For more details, see the Mooncake Documentation.
```bash
# Set Mooncake environment variables
export MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata"
export MOONCAKE_GLOBAL_SEGMENT_SIZE=816043786240
export MOONCAKE_PROTOCOL="rdma"
export MOONCAKE_DEVICE="$DEVICE_LIST"
export MOONCAKE_MASTER=127.0.0.1:50051

# Launch SGLang server with the Mooncake backend
python3 -m sglang.launch_server \
    --model-path $MODEL_PATH \
    --tp 8 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 \
    --hicache-mem-layout page_first_direct \
    --hicache-io-backend direct \
    --hicache-storage-backend mooncake \
    --hicache-write-policy write_through \
    --hicache-storage-prefetch-policy timeout
```
To integrate a new storage backend:
1. Implement three core methods:
   - `get(key)`: retrieve a value by key
   - `exists(key)`: check key existence
   - `set(key, value)`: store a key-value pair
2. Register your backend: add your storage backend to the HiCache `BackendFactory`.
The HiCache controller handles all scheduling and synchronization automatically.
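As an illustrative sketch only (this is not the actual SGLang base class, and a real backend must still be registered with the `BackendFactory`), a minimal in-memory backend implementing the three core methods could look like this:

```python
class InMemoryStorageBackend:
    """Illustrative HiCache-style storage backend backed by a Python dict.

    Only the three core methods from the text are shown; class and
    constructor details are assumptions for demonstration purposes.
    """

    def __init__(self):
        self._store = {}

    def get(self, key):
        # Retrieve a value by key; return None on a cache miss
        return self._store.get(key)

    def exists(self, key):
        # Check key existence
        return key in self._store

    def set(self, key, value):
        # Store a key-value pair
        self._store[key] = value


backend = InMemoryStorageBackend()
backend.set("prefix/page0", b"kv-bytes")
print(backend.exists("prefix/page0"))   # True
print(backend.get("missing") is None)   # True
```

A production backend would replace the dict with calls into the actual storage system (e.g., a distributed file system or object store), while keeping the same three-method surface.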
Alternatively, you can use dynamic loading to avoid hard-coding your backend in the repository:
```bash
python3 -m sglang.launch_server \
    --model-path your-model \
    --enable-hierarchical-cache \
    --hicache-storage-backend dynamic \
    --hicache-storage-backend-extra-config '{"backend_name": "custom_backend_name", "module_path": "your_module_path", "class_name": "YourHiCacheClassName"}'
```
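If the launch command is generated programmatically, the extra-config value can be built with Python's standard `json` module; the backend name, module path, and class name below are the same placeholders used in the command above:

```python
import json

# Placeholder values; substitute your actual module path and class name
extra_config = {
    "backend_name": "custom_backend_name",
    "module_path": "your_module_path",
    "class_name": "YourHiCacheClassName",
    "interface_v1": 1,  # 1 enables batch_get_v1/batch_set_v1, 0 disables them
}

# Single-quote the JSON on the shell command line so the double quotes survive
print(f"--hicache-storage-backend-extra-config '{json.dumps(extra_config)}'")
```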
Configuration Parameters:
- `--hicache-storage-backend`: set to `dynamic`
- `--hicache-storage-backend-extra-config`: JSON configuration with:
  - `backend_name`: custom backend identifier
  - `module_path`: Python module path to your implementation
  - `class_name`: your HiCache implementation class name
  - `interface_v1`: `0` (disable) or `1` (enable) to control usage of the `batch_get_v1` and `batch_set_v1` methods

This document will be continuously updated based on community feedback and new features. Contributions and suggestions are welcome!