docs/advanced_features/sglang_for_rl.md
This document is a practical guide for infrastructure teams integrating SGLang into RL and post-training systems. It focuses on the operational pain points in the loop (rollout, evaluation, training, weight sync) and maps them to concrete SGLang APIs, flags, and integration patterns. The goal is to maximize rollout efficiency, accuracy, and stability while keeping rollout and serving behavior aligned in production environments.
Let's embrace a guiding principle from DeepMind's early RL engineering:
Be a library, not a framework.
This philosophy empowers innovation by providing SGLang as a set of flexible tools rather than a rigid structure. Here are five reasons to use SGLang for your RL lifecycle:

- Memory-aware sleep/wake APIs that let rollout and training share GPUs without restarts.
- Three weight refit strategies (disk, tensor, distributed) covering co-located and disaggregated architectures.
- Pause/resume generation controls that tame long-tail multi-turn rollouts.
- A deterministic inference mode that mitigates the training-inference mismatch.
- The SGLang Model Gateway, a fault-tolerant control plane for large-scale rollouts.

The following sections cover these aspects in detail.
Rollout and training are both memory-intensive, and co-locating them on the same GPUs often leads to memory pressure and slow handoffs. SGLang provides a memory-aware sleep/wake mechanism that releases KV cache and weights while keeping the server process alive, then resumes them for rollout without a full restart. This avoids repeated disk I/O and CUDA graph recapture during each RL step.
Under the hood, the RL team uses CUDA-graph-aware weight offload via torch_memory_saver to preserve virtual memory addresses for graph replay. For details, see: Efficient RL Training - Optimizing Memory Usage in verl.
Enable memory saver support when launching the server:
`--enable-memory-saver`
Endpoint: POST /release_memory_occupation
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| tags | Which memory regions to release. If omitted, all are released. | None | Type: list[str]; values: kv_cache, weights |
Behavior notes:
- When `kv_cache` is released, SGLang flushes the cache; subsequent requests will rebuild the KV cache as needed.

Endpoint: POST /resume_memory_occupation
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| tags | Which memory regions to resume. If omitted, all are resumed. | None | Type: list[str]; values: kv_cache, weights |
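A minimal sketch of the sleep/wake cycle around one RL step, assuming a server launched with `--enable-memory-saver` at `http://localhost:30000`; the address and the `run_training_step` helper are placeholders:

```python
import requests

BASE_URL = "http://localhost:30000"  # assumed server address

# Put the rollout engine to sleep: release KV cache and weights while
# the server process (and its captured CUDA graphs) stays alive.
requests.post(
    f"{BASE_URL}/release_memory_occupation",
    json={"tags": ["kv_cache", "weights"]},
).raise_for_status()

run_training_step()  # hypothetical trainer call that now has the GPU memory

# Wake the engine back up for the next rollout; the weights still need a
# refit (see the weight update strategies below) to generate on-policy data.
requests.post(
    f"{BASE_URL}/resume_memory_occupation",
    json={"tags": ["weights", "kv_cache"]},
).raise_for_status()
```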
After training completes each step, rollout engines must be refit with new weights. SGLang supports three refit strategies so you can match your infrastructure style (co-located vs disaggregated) and scaling needs. Each strategy maps to a concrete API with clear request schemas. For a deeper dive into SGLang's weight update utilities, see RL System Deep Thinking: Weight Update Mechanisms.
How to choose:

- Update from disk: the simplest and safest path; best when rollout engines are disaggregated or when the checkpoint should remain the source of truth.
- Update from tensor: the fastest path for co-located setups where training and rollout share GPU memory.
- Update from distributed: avoids disk I/O in disaggregated setups by broadcasting weights over a dedicated communication group.
When to use:

- Disaggregated or checkpoint-driven pipelines, multi-engine fleets, and deployments where high availability matters more than refit latency.
Why it works well:
This path trades some I/O overhead for simplicity and flexibility. It integrates naturally with checkpointing and makes it trivial to add new rollout engines: point them at the same checkpoint and call the API. It is also the safest option for high availability because the checkpoint itself is the source of truth.
Endpoint: POST /update_weights_from_disk
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| model_path | The model path with the new weights. | Required | Type: str |
| load_format | The format to load the weights. | None | Type: str |
| abort_all_requests | Abort all running requests before update. | False | Type: bool |
| weight_version | Optional weight version label tracked by the server. | None | Type: str |
| is_async | Perform weight load asynchronously. | False | Type: bool |
| torch_empty_cache | Empty torch cache. | False | Type: bool |
| keep_pause | Keep scheduler paused after update. | False | Type: bool |
| recapture_cuda_graph | Recapture CUDA graphs after update. | False | Type: bool |
| token_step | Trainer step id for rollout bookkeeping. | 0 | Type: int |
| flush_cache | Flush KV cache after update. | True | Type: bool |
Response body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| success | Whether the update succeeded. | - | Type: bool |
| message | Status / error message. | - | Type: str |
| num_paused_requests | Number of paused requests during update. | 0 | Type: int |
Python Engine API: engine.update_weights_from_disk(model_path, load_format=None)
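For instance, a trainer that writes checkpoints to shared storage could refit a rollout server like this (a sketch; the server URL and checkpoint path are placeholders):

```python
import requests

BASE_URL = "http://localhost:30000"  # assumed rollout server address

resp = requests.post(
    f"{BASE_URL}/update_weights_from_disk",
    json={
        "model_path": "/shared/ckpts/step_42",  # hypothetical checkpoint dir
        "abort_all_requests": True,             # drop stale-policy requests
        "weight_version": "step-42",            # label for bookkeeping
        "flush_cache": True,                    # old KV entries are invalid
    },
)
result = resp.json()
assert result["success"], result["message"]
```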
Diffusion engine (SGLang-Diffusion): The diffusion engine exposes the same POST /update_weights_from_disk endpoint with the following behavior:
- When layerwise offload (`--dit-layerwise-offload`) is enabled, the diffusion offload manager replaces GPU parameters with small `torch.empty((1,))` placeholders while the real weights live in consolidated pinned CPU buffers. A naive `param.data.copy_()` would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state.
- Parameters wrapped in `torch.distributed.tensor` (tensor parallelism) are updated through `distribute_tensor` so that each shard is correctly placed on the right device mesh.

Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| model_path | The model path with the new weights. | Required | Type: str |
| flush_cache | Flush TeaCache state after update. | True | Type: bool |
| target_modules | List of module names to update (e.g. ["transformer"]). If omitted, all nn.Module components are updated. | None | Type: list[str] |
Response body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| success | Whether the update succeeded. | - | Type: bool |
| message | Status / error message. | - | Type: str |
Note: The diffusion engine (SGLang-Diffusion) does not currently support hot refit (updating weights while inference is in progress). The diffusion scheduler processes one request at a time and completes the entire inference before handling the next request, so weight updates and inference never run concurrently.
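As a concrete sketch of the diffusion refit request above (the server address and checkpoint path are placeholders), updating only the transformer of a diffusion pipeline might look like:

```python
import requests

resp = requests.post(
    "http://localhost:30000/update_weights_from_disk",  # assumed diffusion server
    json={
        "model_path": "/shared/ckpts/dit_step_42",  # hypothetical checkpoint
        "target_modules": ["transformer"],          # leave other components untouched
        "flush_cache": True,                        # reset TeaCache state
    },
)
result = resp.json()
assert result["success"], result["message"]
```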
When to use:

- Co-located setups where the training process and the rollout engine share the same GPUs, and you want the lowest-latency refit with no disk or network I/O.
Important constraints:
This strategy requires the training process and rollout engine to share access to the tensors. Co-located setups must keep the model on GPU; moving tensors to CPU will break the update path. For high-performance MoE or specialized attention kernels, co-location may limit some optimizations compared to disaggregated rollouts.
Endpoint: POST /update_weights_from_tensor
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| serialized_named_tensors | Per-TP serialized tensor payloads. | Required | Type: list[str] |
| load_format | Optional load format selector. | None | None, direct, flattened_bucket, or a custom loader path string |
| flush_cache | Flush KV cache after update. | True | Type: bool |
| abort_all_requests | Abort all running requests before update. | False | Type: bool |
| weight_version | Optional version label tracked by the server. | None | Type: str |
Note: The serialized tensor payloads must be created with MultiprocessingSerializer.serialize(...) and should be base64-safe strings.
Python Engine API: engine.update_weights_from_tensor(named_tensors, load_format=None, flush_cache=True)
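A sketch of an in-process, co-located refit via the Python engine API. The model path and `trainer_model` are placeholders, and a single-TP engine is assumed so no `MultiprocessingSerializer` round-trip is needed:

```python
import sglang as sgl

# Rollout engine co-located with the trainer on the same GPUs.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

# Hand the live training tensors to the engine; no disk I/O involved.
named_tensors = [
    (name, param.detach())
    for name, param in trainer_model.named_parameters()  # your training model
]
engine.update_weights_from_tensor(named_tensors, flush_cache=True)
```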
When to use:

- Disaggregated setups where training and rollout run on separate GPUs or nodes, and you want to refit without writing checkpoints to disk.
How it works:
Training workers gather weights (typically on TP rank 0), broadcast them to the rollout group, and each rollout TP shard loads the parameters it needs. This avoids disk I/O and keeps training and rollout decoupled, at the cost of managing a dedicated communication group.
Initialize weight update group
Endpoint: POST /init_weights_update_group
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| master_address | Group master address. | Required | Type: str |
| master_port | Group master port. | Required | Type: int |
| rank_offset | Offset for local rank mapping. | Required | Type: int |
| world_size | Total world size. | Required | Type: int |
| group_name | Group name. | weight_update_group | Type: str |
| backend | Communication backend. | nccl | Type: str |
Update weight
Endpoint: POST /update_weights_from_distributed
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| names | Parameter names to update. | Required | Type: list[str] |
| dtypes | Dtype strings for each parameter. | Required | Type: list[str] |
| shapes | Tensor shapes. | Required | Type: list[list[int]] |
| group_name | Group name. | weight_update_group | Type: str |
| flush_cache | Flush KV cache after update. | True | Type: bool |
| abort_all_requests | Abort all running requests before update. | False | Type: bool |
| weight_version | Optional version label. | None | Type: str |
| load_format | Optional format selector. | None | None or flattened_bucket |
Destroy weights update group
Endpoint: POST /destroy_weights_update_group
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| group_name | Group name. | weight_update_group | Type: str |
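Putting the three endpoints together, here is a hedged sketch of a trainer-side refit loop. The addresses, the TP-2 topology (`world_size=3`, `rank_offset=1`), the dtype string format, and `trainer_model` are all assumptions:

```python
from concurrent.futures import ThreadPoolExecutor

import requests
import torch.distributed as dist

BASE_URL = "http://localhost:30000"            # assumed rollout server address
MASTER_ADDR, MASTER_PORT = "10.0.0.1", 29600   # hypothetical rendezvous endpoint
GROUP = "weight_update_group"

pool = ThreadPoolExecutor(max_workers=1)

# 1) Both sides must rendezvous on the new NCCL group, so issue the server
#    request in the background while the trainer joins as rank 0. One trainer
#    rank plus a TP-2 rollout engine gives world_size=3 (assumed topology).
fut = pool.submit(requests.post, f"{BASE_URL}/init_weights_update_group", json={
    "master_address": MASTER_ADDR, "master_port": MASTER_PORT,
    "rank_offset": 1, "world_size": 3,
    "group_name": GROUP, "backend": "nccl",
})
dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}",
    rank=0, world_size=3,
)
fut.result().raise_for_status()

# 2) For each updated parameter, tell the server to receive a broadcast,
#    then broadcast from rank 0. The HTTP call only returns once the
#    transfer lands, hence the background thread.
for name, param in trainer_model.named_parameters():  # your training model
    fut = pool.submit(requests.post, f"{BASE_URL}/update_weights_from_distributed", json={
        "names": [name],
        "dtypes": [str(param.dtype).removeprefix("torch.")],  # e.g. "bfloat16" (assumed format)
        "shapes": [list(param.shape)],
        "group_name": GROUP,
    })
    dist.broadcast(param.data, src=0)
    fut.result().raise_for_status()

# 3) Tear down the dedicated group once the refit is complete.
requests.post(
    f"{BASE_URL}/destroy_weights_update_group", json={"group_name": GROUP}
).raise_for_status()
```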
Python Engine APIs:
- `engine.init_weights_update_group(...)`
- `engine.update_weights_from_distributed(names, dtypes, shapes, ...)`
- `engine.destroy_weights_update_group(group_name)`

Multi-turn RL rollouts often suffer from long-tail requests that block the entire batch. A small number of slow interactions can stall all GPUs, and the long-tail behavior makes profiling and monitoring difficult.
SGLang exposes explicit pause/resume APIs so you can pause slow requests and continue them later. This pattern matches systems like APRIL: terminate the rollout once enough responses are collected, and recycle incomplete responses in the next step. The result is higher GPU utilization without discarding partial work.

`pause_generation` → update weights → `continue_generation` is the correct execution flow when updating weights from training: an update can only happen when SGLang is not actively processing inference tasks.
Endpoint: POST /pause_generation
Request body:
| Field | Description | Defaults | Options |
|---|---|---|---|
| mode | Pause mode. | abort | abort, retract, in_place |
Modes:
- `abort`: Default behavior, identical to the abort endpoint with `abort_all` set. Pending requests from the waiting and running queues are returned to the caller immediately.
- `retract`: Put the engine in a "paused" state and move running requests back to the waiting queue. The KV cache can be flushed and recomputed later.
- `in_place`: Put the engine in a "paused" state without changing the state of the requests. Running requests rely on the availability of their KV cache to continue, so any subsequent `flush_cache` call will fail.

Endpoint: POST /continue_generation
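A sketch of the pause → update → continue flow (the server address is a placeholder):

```python
import requests

BASE_URL = "http://localhost:30000"  # assumed server address

# Retract in-flight requests so no inference runs during the refit.
requests.post(
    f"{BASE_URL}/pause_generation", json={"mode": "retract"}
).raise_for_status()

# ... perform the weight update here, e.g. POST /update_weights_from_disk ...

# Retracted requests re-enter the waiting queue and resume decoding.
requests.post(f"{BASE_URL}/continue_generation").raise_for_status()
```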
In many RL stacks, rollout and training are implemented with different kernels or batching behavior. Even when weights are identical, token probabilities can drift, silently breaking the on-policy assumption. This is the training–inference mismatch problem.
SGLang supports a deterministic inference mode that reduces non-determinism across batch shapes. This mitigates variance introduced by runtime batching and kernel selection. To further achieve true on-policy training, you need to modify the training engine to use the same deterministic kernels. For implementation details, see these miles examples: True On-Policy and True On-Policy for VLM. For additional context, see the blog post Let Speed Be With Stability: All-In-One Solution to Training-Inference Mismatch with Miles.
Server flag:
`--enable-deterministic-inference`
For more details, see Deterministic Inference.
SGLang Model Gateway is the recommended control plane for large-scale RL rollouts. It provides async, non-blocking request handling, cache-aware load balancing, and fault-tolerant routing across rollout and reward servers. This lets you keep GPUs saturated while avoiding long-tail stalls and brittle, engine-local concurrency logic. It has been deployed in the training of GLM 4.5+ models and has proven highly efficient in production-scale RL workloads.
Key benefits for RL infrastructure:

- Async, non-blocking request handling that keeps rollout GPUs saturated.
- Cache-aware load balancing across rollout workers.
- Fault-tolerant routing across rollout and reward servers, avoiding long-tail stalls and engine-local concurrency logic.
For deployment and configuration, see: SGLang Model Gateway