# Rollout Correction

Author: Yingru Li
Last updated: 10/30/2025

## 📖 Documentation Structure
- This document - Practical usage guide: configurations, presets, troubleshooting
- Mathematical Formulations - Theoretical foundations, derivations, and algorithmic details
Start here for implementation, refer to the math doc for theory and design rationale.
This document provides a comprehensive overview of the Rollout Correction implementation in verl.
Note on Naming: This feature is called "Rollout Correction" to reflect the complete functionality: importance sampling (IS) weights and rejection sampling (RS). The internal variable rollout_is_weights retains its name as it specifically refers to the IS weights component.
Citation:

```bibtex
@online{liu-li-2025-rl-collapse,
  title  = {When Speed Kills Stability: Demystifying {RL} Collapse from the Training-Inference Mismatch},
  author = {Liu, Jiacai and Li, Yingru and Fu, Yuqian and Wang, Jiawei and Liu, Qian and Shen, Yu},
  year   = {2025},
  month  = sep,
  url    = {https://richardli.xyz/rl-collapse}
}

@article{li2025trust,
  title   = {Trust Region Masking for Long-Horizon LLM Reinforcement Learning},
  author  = {Li, Yingru and Liu, Jiacai and Xu, Jiawei and Tong, Yuxuan and Li, Ziniu and Liu, Qian and Wang, Baoxiang},
  journal = {arXiv preprint arXiv:2512.23075},
  year    = {2025}
}
```
Rollout Correction provides a unified framework to handle general off-policy problems in RL training. Any scenario where the data collection distribution differs from the training distribution can benefit from these methods.
Common off-policy scenarios:
- Policy Mismatch (Implementation Differences)
- Temporal Lag (Model Staleness)
- Replay Buffers
- Off-Policy Algorithms
- Data Quality Filtering
These off-policy gaps can cause training instability and policy collapse. Rollout Correction uses importance sampling (IS) weights and rejection sampling (RS) to correct for any distribution shift between data collection and training.
Important Note on Common Implementation Mistakes:
Many LLM-RL implementations incorrectly apply PPO by ignoring the actual rollout policy π_rollout and assuming the training reference policy π_old is the behavior policy. This is mathematically incorrect when π_rollout ≠ π_old (which is typical in LLM-RL due to precision/backend differences between rollout and training).
This is not PPO's fault - PPO itself is mathematically correct. The issue is the incorrect assumption that π_old = π_rollout in naive implementations.
This critical implementation mistake, which leads to RL training collapse, was identified in the blog post "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" and motivated the development of this rollout correction framework.
Mathematically correct approaches: see Mathematical Formulations for the detailed treatment and design rationale.
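For intuition, here is a sketch of one mathematically correct form: token-level decoupled PPO with truncated IS. This is an informal restatement; see the math doc for the precise objective and notation.

$$
\rho_t = \min\left(\frac{\pi_{\mathrm{old}}(a_t \mid s_t)}{\pi_{\mathrm{rollout}}(a_t \mid s_t)},\ C\right), \qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\mathrm{old}}(a_t \mid s_t)}
$$

$$
\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \pi_{\mathrm{rollout}}}\left[\rho_t \cdot \min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\, 1-\varepsilon,\, 1+\varepsilon\big)\,\hat{A}_t\right)\right]
$$

Here C is the truncation threshold (`rollout_is_threshold`) and Â_t the advantage. The naive-but-wrong variant silently assumes π_old = π_rollout, i.e. ρ_t ≡ 1.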
The implementation cleanly separates two orthogonal mechanisms:
- IS Weights (`rollout_is_weights`): continuous reweighting for gradient correction, truncated via `.clamp(max=rollout_is_threshold)` (TIS: Truncated Importance Sampling)
- Rejection Sampling (`modified_response_mask`): binary filtering for outlier exclusion
This separation ensures that the two mechanisms can be enabled and tuned independently.
NEW: We now provide typed configuration with verified presets for common scenarios. These presets have been validated with tens of thousands of GPU hours across various models and training scenarios.
```python
from verl.trainer.config.algorithm import RolloutCorrectionConfig
# === Decoupled PPO mode (3 policies: π_rollout, π_old, π_θ) ===
# IS weights correct for gap between π_old and π_rollout
config = RolloutCorrectionConfig.decoupled_token_is() # Token-TIS
config = RolloutCorrectionConfig.decoupled_seq_is() # Seq-TIS
config = RolloutCorrectionConfig.decoupled_seq_is_rs() # Seq-MIS
config = RolloutCorrectionConfig.decoupled_geo_rs() # Geo-RS (ratio mode)
config = RolloutCorrectionConfig.decoupled_geo_rs_token_tis() # Geo-RS + Token-TIS
# === K3 KL Estimator presets (more stable for small KL) ===
config = RolloutCorrectionConfig.decoupled_k3_rs() # K3-RS only
config = RolloutCorrectionConfig.decoupled_k3_rs_token_tis() # K3-RS + Token-TIS
# === Bypass PPO mode (2 policies: π_rollout = π_old, π_θ) - fast ===
# PPO ratio handles IS, so no explicit IS weights needed
config = RolloutCorrectionConfig.bypass_ppo_clip() # PPO-clip only
config = RolloutCorrectionConfig.bypass_ppo_clip_geo_rs() # PPO-clip + Geo-RS
config = RolloutCorrectionConfig.bypass_ppo_clip_k3_rs() # PPO-clip + K3-RS
# === Bypass PG mode (2 policies, no PPO clipping) - fast ===
# IS weights computed on-the-fly as π_θ / π_rollout
config = RolloutCorrectionConfig.bypass_pg_is() # Seq-TIS + PG
config = RolloutCorrectionConfig.bypass_pg_geo_rs() # Geo-RS + PG
config = RolloutCorrectionConfig.bypass_pg_geo_rs_token_tis() # Geo-RS + Token-TIS + PG
# === Other ===
config = RolloutCorrectionConfig.disabled()  # Metrics only (no correction)
```
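Presets are ordinary `RolloutCorrectionConfig` instances; for anything the presets don't cover, the same fields can be set manually. A sketch, using only the field names documented in the parameter reference below:

```python
# Preset with an adjusted truncation threshold
config = RolloutCorrectionConfig.decoupled_seq_is(threshold=3.0)

# Equivalent manual construction (see the manual configuration examples below)
config = RolloutCorrectionConfig(
    rollout_is="sequence",     # sequence-level IS weights
    rollout_is_threshold=3.0,  # TIS truncation threshold
    rollout_rs=None,           # no rejection sampling
)
```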
For advanced customization or YAML-based configs:
```yaml
algorithm:
  rollout_correction:
    rollout_is: token                  # IS weights: "token", "sequence", or null
    rollout_is_threshold: 2.0          # Upper threshold for IS weights
    rollout_is_batch_normalize: false  # Batch-normalize IS weights to mean=1.0
    rollout_rs: null                   # Rejection sampling: comma-separated canonical options (e.g. "token_k1,seq_max_k2")
    rollout_rs_threshold: null         # Threshold spec: float(s) or "lower_upper" string(s)
    bypass_mode: false                 # Skip old_log_prob computation (sets π_old = π_rollout)
    loss_type: ppo_clip                # Loss type in bypass mode: "ppo_clip" (default) or "reinforce"

# REQUIRED: Enable log prob calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```
Implementation files:
- `verl/trainer/ppo/rollout_corr_helper.py` - contains `compute_rollout_correction_and_rejection_mask()` and `compute_offpolicy_metrics()`
- `verl/trainer/ppo/core_algos.py` - Rollout Correction integration with PPO and REINFORCE modes (`compute_policy_loss_bypass_mode()`, `compute_policy_loss_reinforce()`)
- `verl/trainer/ppo/ray_trainer.py` - bypass mode implementation (skips old_log_prob computation)
- `verl/workers/actor/dp_actor.py` - mode selection logic and metrics collection
- `verl/trainer/config/algorithm.py` - Rollout Correction parameters in `RolloutCorrectionConfig`
- `verl/workers/config/actor.py` - Rollout Correction parameters in `PolicyLossConfig`
- `verl/trainer/config/actor/actor.yaml` - Rollout Correction configuration section
- `verl/trainer/config/ppo_trainer.yaml` - algorithm config with Rollout Correction
- `docs/examples/config.rst` - configuration parameter descriptions
- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh` - DAPO example with Rollout Correction
- `examples/rollout_correction/run_with_rollout_corr.sh` - basic example
- `examples/rollout_correction/run_with_rollout_corr_multi_rs.sh` - multi-RS example
- `tests/trainer/ppo/test_rollout_corr.py` - unit tests for IS/RS mechanisms
- `tests/trainer/ppo/test_rollout_corr_integration.py` - integration tests

All parameters are under `algorithm.rollout_correction`:
`rollout_is` (str or null): Importance sampling weight aggregation level.
- `null`: no IS weights computed (metrics-only mode)
- `"token"`: per-token IS weights
- `"sequence"`: per-sequence weight ρ_seq = ∏_t ρ_t

All IS weights are safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8].

`rollout_is_threshold` (float): Upper threshold for IS weight truncation. Default: 2.0.
- Weights are truncated via `.clamp(max=rollout_is_threshold)` (TIS: Truncated Importance Sampling)
- Lower-bound filtering is handled by rejection sampling (see the `rollout_rs` parameters)

`rollout_is_batch_normalize` (bool): Apply batch normalization to IS weights. Default: `False`.
- `True`: normalize IS weights to have mean=1.0 within each batch
- `False`: use raw (truncated) IS weights

`rollout_rs` (str or null): Rejection sampling aggregation modes. Supply a comma-separated string (spaces optional) using the canonical options implemented in `rollout_corr_helper`:
- `token_k1`: token-level rejection with -log r bounds (ratio thresholds supplied as `lower_upper`, e.g. `"0.6_1.4"`)
- `token_k2`: token-level rejection with 0.5 * (log r)^2 (upper bound only)
- `token_k3`: token-level rejection with exp(log r) - 1 - log r (upper bound only)
- `seq_sum_k1`: sequence-level rejection with the sum of -log r (ratio bounds)
- `seq_sum_k2`: sequence-level rejection with the sum of 0.5 * (log r)^2 (upper bound only)
- `seq_sum_k3`: sequence-level rejection with the sum of exp(log r) - 1 - log r (upper bound only)
- `seq_mean_k1`: sequence-level rejection with the mean of -log r (ratio bounds)
- `seq_mean_k2`: sequence-level rejection with the mean of 0.5 * (log r)^2 (upper bound only)
- `seq_mean_k3`: sequence-level rejection with the mean of exp(log r) - 1 - log r (upper bound only)
- `seq_max_k2`: sequence-level rejection with the max of 0.5 * (log r)^2 (upper bound only)
- `seq_max_k3`: sequence-level rejection with the max of exp(log r) - 1 - log r (upper bound only)

`rollout_rs_threshold` (str, float, or null): Threshold specification for rejection sampling.
- K1 modes (`*k1`): use `"lower_upper"` strings (e.g. `"0.7_1.3"`). Supplying a float implies only the upper bound; the lower bound defaults to its reciprocal.
- K2/K3 modes (`*k2`/`*k3`): supply positive upper bounds (float or numeric string).
- Set to `null` to disable thresholds entirely (only valid when `rollout_rs` is null).
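For example, combining two RS criteria might look like the following. This is an illustrative sketch, assuming per-mode thresholds are supplied as a comma-separated string in the same order as `rollout_rs` (as in the multi-RS example script); the values are not tuned recommendations.

```yaml
algorithm:
  rollout_correction:
    rollout_is: token
    rollout_is_threshold: 2.0
    rollout_rs: token_k1,seq_max_k2      # two RS criteria, comma-separated
    rollout_rs_threshold: "0.6_1.4,2.5"  # k1 takes "lower_upper"; k2 takes an upper bound
```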
The rollout correction framework is built from orthogonal components that can be combined flexibly. Understanding these components helps you choose the right configuration for your scenario.

1. Operating Mode (Section: Operation Modes)
2. Loss Function (in bypass mode, controlled by `loss_type`)
   - `loss_type="ppo_clip"` (default): PPO clipped objective (IS handled by the ratio)
   - `loss_type="reinforce"`: policy gradient with explicit IS weights (no clipping)
3. IS/RS Aggregation Level
See Mathematical Formulations for detailed theory.
This section provides detailed guidance on choosing and using the verified presets. Each preset is a specific combination of components optimized for common scenarios.
| Preset Method | Estimator | Mode | IS Level | RS Level | Properties |
|---|---|---|---|---|---|
| **Decoupled PPO Mode** (3 policies: π_rollout, π_old, π_θ) | | | | | |
| `decoupled_token_is()` | Token-TIS | Decoupled | token | - | Token-level IS weights |
| `decoupled_seq_is()` | Seq-TIS | Decoupled | sequence | - | Sequence-level IS weights |
| `decoupled_seq_is_rs()` | Seq-MIS | Decoupled | sequence | sequence | Sequence IS + seq_sum_k1 RS |
| `decoupled_geo_rs()` | Geo-RS | Decoupled | - | sequence | Geometric RS (seq_mean_k1) |
| `decoupled_geo_rs_token_tis()` | Geo-RS-Token-TIS | Decoupled | token | sequence | Geometric RS + token IS |
| **K3 KL Estimator** (more stable for small KL values) | | | | | |
| `decoupled_k3_rs()` | K3-RS | Decoupled | - | sequence | seq_mean_k3 RS |
| `decoupled_k3_rs_token_tis()` | K3-RS-Token-TIS | Decoupled | token | sequence | seq_mean_k3 RS + token IS |
| **Bypass Mode (PPO-clip)** (2 policies; ratio handles IS, RS masks outliers) | | | | | |
| `bypass_ppo_clip()` | - | Bypass (PPO-clip) | - | - | PPO-clip only |
| `bypass_ppo_clip_geo_rs()` | Geo-RS | Bypass (PPO-clip) | - | sequence | PPO-clip + Geo-RS |
| `bypass_ppo_clip_k3_rs()` | K3-RS | Bypass (PPO-clip) | - | sequence | PPO-clip + K3-RS |
| **Bypass Mode (REINFORCE)** (2 policies; explicit IS weights, no PPO clipping) | | | | | |
| `bypass_pg_is()` | Seq-TIS | Bypass (REINFORCE) | sequence | - | REINFORCE with explicit IS |
| `bypass_pg_geo_rs()` | Geo-RS | Bypass (REINFORCE) | - | sequence | REINFORCE with Geo-RS |
| `bypass_pg_geo_rs_token_tis()` | Geo-RS-Token-TIS | Bypass (REINFORCE) | token | sequence | REINFORCE + Geo-RS + token IS |
| **Other** | | | | | |
| `disabled()` | - | - | - | - | Metrics only, no correction |
Note: In bypass mode, `loss_type` selects the loss function:
- `"ppo_clip"` (default): PPO clipped objective, where ratio = π_θ/π_rollout already handles IS
- `"reinforce"`: REINFORCE with explicit IS weights as π_θ/π_rollout

Other supported combinations without preset methods: see the detailed configuration examples below for manual configurations.

Key properties:
- `bypass_pg_rs` uses bypass + geometric RS with `loss_type="reinforce"` (no IS weights)

### Token-Level TIS (`decoupled_token_is`)

Configuration:
```python
config = RolloutCorrectionConfig.decoupled_token_is(threshold=2.0)
```
Components:
Equivalent YAML:
```yaml
algorithm:
  rollout_correction:
    rollout_is: token
    rollout_is_threshold: 2.0
    rollout_rs: null
    bypass_mode: false  # Decoupled mode
```
Properties:
Theory: See rollout_corr_math.md §3.3.1
### Sequence-Level IS (`decoupled_seq_is`)

Also known as: Seq-TIS (Sequence-Level Truncated IS)
Configuration:
```python
config = RolloutCorrectionConfig.decoupled_seq_is(threshold=2.0)
```
Components:
Equivalent YAML:
```yaml
algorithm:
  rollout_correction:
    rollout_is: sequence
    rollout_is_threshold: 2.0
    rollout_rs: null
    bypass_mode: false  # Decoupled mode
```
Properties:
Theory: See rollout_corr_math.md §3.3.2
### Sequence-Level IS with Rejection (`decoupled_seq_is_rs`)

Also known as: Seq-MIS (Sequence-Level Masked IS)
Configuration:
```python
config = RolloutCorrectionConfig.decoupled_seq_is_rs(is_threshold=2.0, rs_threshold="0.5_2.0")
```
Components:
Equivalent YAML:
```yaml
algorithm:
  rollout_correction:
    rollout_is: sequence
    rollout_is_threshold: 2.0
    rollout_rs: seq_sum_k1
    rollout_rs_threshold: "0.5_2.0"
    bypass_mode: false  # Decoupled mode
```
Properties:
When to use Seq-MIS over Seq-TIS:
Theory: See rollout_corr_math.md §3.5
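For intuition, here is a minimal sketch of `seq_sum_k1`-style rejection (hypothetical standalone code, not the verl helper): a sequence is kept only when the product of its token ratios lies within the `[lower, upper]` ratio bounds.

```python
import math

import torch

def seq_sum_k1_mask(log_prob, rollout_log_prob, response_mask, lower=0.5, upper=2.0):
    """Mask whole sequences whose product of token ratios falls outside [lower, upper]."""
    # Sum of token log-ratios equals the log of the product of token ratios
    seq_log_ratio = ((log_prob - rollout_log_prob) * response_mask).sum(dim=-1)
    accept = (seq_log_ratio >= math.log(lower)) & (seq_log_ratio <= math.log(upper))
    # Broadcast the per-sequence accept decision over the token dimension
    return response_mask * accept.unsqueeze(-1).to(response_mask.dtype)
```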
### Bypass Mode with PPO-Clip (`bypass_ppo_clip`)

Configuration:
```python
config = RolloutCorrectionConfig.bypass_ppo_clip()
```
Components:
Equivalent YAML:
```yaml
rollout_correction:
  rollout_is: null
  rollout_rs: null
  bypass_mode: true
  loss_type: ppo_clip
```
Properties:
- Skips the `actor.compute_log_prob()` forward pass (2 policies instead of 3)
- Use `bypass_ppo_clip_geo_rs()` for RS

Configuration requirement:
- `actor_rollout_ref.rollout.calculate_log_probs: true`

Additional requirements for bypass mode:
- `actor_rollout_ref.actor.use_rollout_log_probs: true`
- `actor_rollout_ref.actor.policy_loss.loss_mode: bypass_mode`
- Rollout Correction parameters under `actor_rollout_ref.actor.policy_loss.rollout_correction`

Theory: See rollout_corr_math.md §3.1.2
### Bypass Mode with REINFORCE + Sequence IS (`bypass_pg_is`)

Configuration:
```python
config = RolloutCorrectionConfig.bypass_pg_is(threshold=2.0)
```
Components:
Equivalent YAML:
```yaml
rollout_correction:
  rollout_is: sequence
  rollout_is_threshold: 2.0
  rollout_rs: null
  bypass_mode: true
  loss_type: reinforce  # REINFORCE with explicit IS weights
```
Properties:
Theory: See rollout_corr_math.md §3.2.2
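A minimal sketch of what this loss computes (hypothetical standalone code, assuming sequence-level truncated IS weights as in `bypass_pg_is`; the actual implementation lives in `compute_policy_loss_reinforce()`):

```python
import torch

def reinforce_seq_is_loss(log_prob, rollout_log_prob, advantages, response_mask, threshold=2.0):
    """REINFORCE with an explicit truncated sequence-level IS weight (pi_old = pi_rollout)."""
    # One weight per sequence: product of token ratios, safety-bounded then truncated (TIS)
    seq_log_ratio = ((log_prob - rollout_log_prob) * response_mask).sum(dim=-1, keepdim=True)
    w_seq = torch.exp(seq_log_ratio.clamp(min=-20.0, max=20.0)).clamp(max=threshold)
    # Detached: the weight rescales the policy-gradient term but carries no gradient itself
    pg_loss = -(w_seq.detach() * log_prob * advantages) * response_mask
    return pg_loss.sum() / response_mask.sum().clamp(min=1)
```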
These configurations are fully supported but don't have convenience preset methods yet.
### Token IS + Token RS (`token_is_rs`)

Token-level IS weights with token-level RS mask.
Python:
```python
config = RolloutCorrectionConfig(
    rollout_is="token",
    rollout_is_threshold=2.0,
    rollout_rs="token_k1",
    rollout_rs_threshold=2.0,
)
```
Properties: Per-token IS weights + per-token RS mask.
### Token RS Only (`token_rs`)

Token-level RS only, no IS weights.
Python:
```python
config = RolloutCorrectionConfig(
    rollout_is=None,
    rollout_rs="token_k1",
    rollout_rs_threshold=2.0,
)
```
Properties: Token-level RS mask, no IS reweighting.
### Sequence RS Only (`seq_rs`)

Sequence-level RS only, no IS weights.
Python:
```python
config = RolloutCorrectionConfig(
    rollout_is=None,
    rollout_rs="seq_sum_k1",
    rollout_rs_threshold="0.5_2.0",
)
```
Properties: Sequence-level RS mask, no IS reweighting.
IS weights (rollout_is_weights) go through a fixed processing pipeline:
Stage 1: Safety Bound (Prevent Overflow)
- Token level: `exp(clamp(log_ratio, -20, 20))` per token → bounds each token weight to [2e-9, 5e8]
- Sequence level: `exp(clamp(sum(log_ratio), -20, 20))` → bounds the product to [2e-9, 5e8], broadcast to all tokens

Stage 2: Truncation (Reduce Variance)
- `.clamp(max=rollout_is_threshold)` → caps weights at the upper threshold (TIS: Truncated Importance Sampling)

Stage 3: Padding Zeroing (Correct Aggregation)
- `weights * response_mask` → zeros out padding positions

Stage 4: Optional Batch Normalization
- `rollout_is_batch_normalize=True`: normalize weights to mean=1.0 within the batch
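To make the pipeline concrete, here is a minimal sketch of the four stages for token-level aggregation (hypothetical standalone code, not the verl implementation):

```python
import torch

def process_is_weights(log_ratio, response_mask, threshold=2.0, batch_normalize=False):
    """Token-level IS-weight pipeline: safety bound, truncate, zero padding, optional normalize."""
    # Stage 1: safety bound - clamp log-ratios to [-20, 20] before exponentiating
    weights = torch.exp(log_ratio.clamp(min=-20.0, max=20.0))
    # Stage 2: truncation (TIS) - cap weights at the upper threshold
    weights = weights.clamp(max=threshold)
    # Stage 3: zero out padding positions so they cannot affect aggregation
    weights = weights * response_mask
    # Stage 4 (optional): batch-normalize to mean 1.0 over valid tokens
    if batch_normalize:
        weights = weights * (response_mask.sum() / weights.sum().clamp(min=1e-8))
    return weights
```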
Rejection Sampling (Separate Mechanism)

Rejection sampling modifies `response_mask` (NOT the weights) through `compute_rollout_rejection_mask()`.
The framework provides two operating modes for computing π_old, which can be combined with different loss functions.
| Configuration | bypass_mode | loss_type | Operating Mode | Loss Function | Description |
|---|---|---|---|---|---|
| Decoupled | false | N/A | Decoupled | PPO | Computes old_log_prob separately via actor.compute_log_prob() |
| Bypass + PPO-clip | true | "ppo_clip" (default) | Bypass | PPO-clip | PPO clipped objective (IS handled by ratio) |
| Bypass + REINFORCE | true | "reinforce" | Bypass | REINFORCE | Policy gradient with explicit IS weights (no PPO clipping) |
### Decoupled Mode

Policy setup:
- π_old is computed separately (via `actor.compute_log_prob()` at the start of the training epoch)

Configuration: `bypass_mode = false`
Properties:
- Requires a separate `actor.compute_log_prob()` forward pass

Theory: See rollout_corr_math.md §3.1.1
### Bypass Mode

Policy setup:
- π_old = π_rollout (no separate computation)

Configuration: `bypass_mode = true`

Properties:
- Skips the `actor.compute_log_prob()` call (faster)

Theory: See rollout_corr_math.md §3.1.2
The aggregation level can be chosen independently of the operating mode. Any aggregation level works in either decoupled or bypass mode.
| rollout_is | rollout_rs | Behavior |
|---|---|---|
| null | null | Disabled: no computation, no metrics, no rejection |
| null | "token_k1", "seq_sum_k1", "seq_mean_k1", "seq_max_k2", etc. | Rejection only: compute metrics, NO weight correction, YES rejection sampling |
| "token" or "sequence" | null | IS weights only: weight correction enabled, NO rejection sampling |
| "token" or "sequence" | "token_k1", "seq_sum_k1", "seq_mean_k1", "seq_max_k2", etc. | Full correction: both weight correction and rejection sampling enabled |
Examples:
- Rejection only: `rollout_is=null, rollout_rs="token_k1"`
- IS weights only: `rollout_is="token", rollout_rs=null`
- Full correction: `rollout_is="token", rollout_rs="token_k1"`
- Metrics only: leave `rollout_is` as `null` but still provide `rollout_log_probs`

Theory: See rollout_corr_math.md §3.3 for details on aggregation levels.
Recommended: Bypass Mode
This workflow uses bypass mode for efficiency.
Start with metrics only to understand the off-policy gap:
```yaml
rollout_correction:
  rollout_is: null
  rollout_rs: null
  bypass_mode: true    # Bypass mode (recommended)
  loss_type: ppo_clip  # Default: PPO clipped objective
```
Monitor rollout_corr/kl, rollout_corr/log_ppl_abs_diff, rollout_corr/chi2_token to assess off-policy gap.
Enable rejection sampling if you see high outlier fractions:
```yaml
rollout_correction:
  rollout_is: null
  rollout_rs: sequence  # or "geometric" for higher sensitivity
  rollout_rs_threshold: 2.0
  bypass_mode: true     # Bypass mode
  loss_type: ppo_clip   # or "reinforce" for explicit IS weights
```
This excludes outliers from training without modifying gradients.
Enable full IS correction (with REINFORCE loss) once comfortable with metrics:
```yaml
rollout_correction:
  rollout_is: sequence  # Recommended: unbiased, suitable for most cases
  rollout_is_threshold: 2.0
  rollout_rs: sequence  # or "geometric" for more aggressive filtering
  rollout_rs_threshold: 2.0
  bypass_mode: true     # Bypass mode
  loss_type: reinforce  # REINFORCE with explicit IS weights
```
Benefits of bypass mode:
- Skips the `actor.compute_log_prob()` forward pass (faster)
- `loss_type` controls the loss function: `"ppo_clip"` (default) or `"reinforce"`

Minimal example:

```yaml
algorithm:
  rollout_correction:
    rollout_is: token          # Enable IS weights at token level
    rollout_is_threshold: 2.0  # Threshold for IS weights
    rollout_rs: null           # No rejection sampling

actor_rollout_ref:
  rollout:
    calculate_log_probs: true  # Required!
```
For bypass mode, additionally set:
- `actor_rollout_ref.actor.use_rollout_log_probs: true`
- `actor_rollout_ref.actor.policy_loss.loss_mode: bypass_mode`
- Rollout Correction parameters under `actor_rollout_ref.actor.policy_loss.rollout_correction`

All metrics are prefixed with `rollout_corr/` in logs. For example, `rollout_is_mean` appears as `rollout_corr/rollout_is_mean`.
These metrics cover both operating modes (decoupled and bypass):
- `rollout_is_mean`: Mean importance sampling weight across all valid tokens
- `rollout_is_std`: Standard deviation of IS weights
- `rollout_is_min`: Minimum IS weight observed
- `rollout_is_max`: Maximum IS weight observed
  - Compare with `rollout_is_threshold` to see the truncation impact
- `rollout_is_eff_sample_size`: Effective sample size after IS weighting
  - Computed as 1 / mean(weights²), where the weights are normalized
- `rollout_is_ratio_fraction_high`: Fraction of weights exceeding the upper threshold
- `rollout_is_ratio_fraction_low`: Fraction of weights below the lower threshold (1/upper_threshold)
- `rollout_is_seq_mean`: Mean IS weight at sequence level
  - Analogous to `rollout_is_mean` for sequence-level aggregation
- `rollout_is_seq_std`: Standard deviation of sequence-level IS weights
- `rollout_is_seq_min`: Minimum sequence-level IS weight
- `rollout_is_seq_max`: Maximum sequence-level IS weight
- `rollout_is_seq_max_deviation`: Maximum absolute deviation from 1.0 at sequence level
- `rollout_is_seq_fraction_high`: Fraction of sequences exceeding the upper threshold
- `rollout_is_seq_fraction_low`: Fraction of sequences below the lower threshold

Rejection sampling metrics (only when `rollout_rs` is enabled):
- `rollout_rs_masked_fraction`: Fraction of tokens rejected via rejection sampling
  - Applied via `response_mask` (sets rejected tokens to 0)
  - Reported whenever `rollout_rs` is enabled (token/sequence/geometric)
- `rollout_rs_seq_masked_fraction`: Fraction of sequences with at least one rejected token
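For reference when reading `rollout_is_eff_sample_size`, here is a sketch of the normalized-ESS formula above (hypothetical standalone code). A value of 1.0 means all weights are equal; values near 0 mean a few tokens dominate.

```python
import torch

def effective_sample_size(weights, response_mask):
    """Normalized ESS in (0, 1]: 1 / mean(w_norm^2), computed over valid tokens."""
    w = weights[response_mask.bool()]
    w_norm = w / w.mean().clamp(min=1e-8)  # normalize weights to mean 1.0
    return (1.0 / w_norm.pow(2).mean().clamp(min=1e-8)).item()
```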
Note on terminology: These metrics use "training" to refer to the training reference policy and "rollout" to refer to π_rollout (the behavior policy used for data collection).
In bypass/pure IS mode, metrics measure the drift between π_θ and π_rollout directly.
- `training_ppl`: Perplexity of the training reference policy (π_old in decoupled mode, π_θ in bypass/pure IS mode)
  - Computed as exp(-mean(log_probs))
- `rollout_ppl`: Perplexity of the rollout policy π_rollout (e.g., vLLM BF16)
- `ppl_ratio`: Ratio of training PPL to rollout PPL
  - Computed as exp(mean(log(training_ppl / rollout_ppl)))
- `training_log_ppl`: Log perplexity of the training policy
- `rollout_log_ppl`: Log perplexity of the rollout policy
- `log_ppl_diff`: Mean difference in log perplexities
  - Computed as mean(log_ppl_rollout - log_ppl_training)
- `log_ppl_abs_diff`: Mean absolute log perplexity difference
- `log_ppl_diff_max`: Maximum log perplexity difference across sequences
- `log_ppl_diff_min`: Minimum log perplexity difference across sequences
- `kl`: KL divergence KL(π_rollout || π_training)
  - Estimated as mean(log_prob_rollout - log_prob_training)
- `k3_kl`: K3 divergence (equals KL(π_rollout || π_training) in expectation)
  - Estimated as mean(exp(log_ratio) - log_ratio - 1)
- `chi2_token`: Chi-squared divergence at token level
  - Computed as mean(ratio²) - 1, where ratio = π_training/π_rollout
- `chi2_seq`: Chi-squared divergence at sequence level
  - Computed as mean((∏_t ratio_t)²) - 1

Example usage:

```python
# Metrics are returned from compute_rollout_correction_and_rejection_mask
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask

# Returns 3 values: (weights, modified_response_mask, metrics)
weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
    old_log_prob=training_log_probs,     # from training policy
    rollout_log_prob=rollout_log_probs,  # from rollout policy
    response_mask=response_mask,
    rollout_is="token",                  # Enable IS weights at token level
    rollout_is_threshold=2.0,
    rollout_rs="token_k1",
    rollout_rs_threshold="0.5_2.0",
)

# Extract IS weights (processed, zeroed at padding)
is_weights = weights_proto.batch["rollout_is_weights"]

# IS weights processing (with IS enabled at token level):
# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token
# 2. Truncated: .clamp(max=2.0) to cap extreme weights
# 3. Zeroed at padding positions
# Note: Truncation is ALWAYS applied to IS weights (TIS: Truncated Importance Sampling)

# modified_response_mask has rejection applied (since rollout_rs="token_k1"):
# 1. RS rejection: tokens outside [0.5, 2.0] masked to 0 via response_mask
# Note: RS and IS are separate mechanisms - both can be enabled independently

# All metrics have the 'rollout_corr/' prefix
print(f"Mean IS weight: {metrics['rollout_corr/rollout_is_mean']:.3f}")
print(f"Effective sample size: {metrics['rollout_corr/rollout_is_eff_sample_size']:.3f}")
print(f"RS masked fraction: {metrics['rollout_corr/rollout_rs_masked_fraction']:.3f}")
print(f"KL divergence: {metrics['rollout_corr/kl']:.3f}")

# Check IS weights for valid tokens (non-padding)
valid_weights = is_weights[response_mask.bool()]
print(f"\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}")
print(f"✓ IS weights max (valid tokens): {valid_weights.max():.4f}")
print(f"✓ All valid IS weights > 0: {(valid_weights > 0).all()}")
print(f"✓ IS weights are capped at threshold: {(valid_weights <= 2.0).all()}")

# Check rejection via response_mask
rejected_tokens = (response_mask == 1) & (modified_response_mask == 0)
print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask")
print("✓ Rejection sampling modifies response_mask (separate from IS weight truncation)")
print("✓ IS weights are always truncated to [0, threshold] after safety bounding")

# Check for warning conditions
if metrics['rollout_corr/rollout_is_mean'] < 0.5 or metrics['rollout_corr/rollout_is_mean'] > 2.0:
    print("⚠️ Warning: Mean IS weight far from 1.0, significant off-policy gap detected")
if metrics['rollout_corr/rollout_is_eff_sample_size'] < 0.3:
    print("⚠️ Warning: Low effective sample size, high weight concentration")
```
```python
# In your training loop
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        # ... rollout phase ...

        # Compute IS weights and get metrics
        rollout_corr_config = config.algorithm.get("rollout_correction", None)
        if rollout_corr_config is not None:
            weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
                old_log_prob=batch.old_log_prob,
                rollout_log_prob=batch.rollout_log_prob,
                response_mask=batch.response_mask,
                rollout_is=rollout_corr_config.get("rollout_is", None),
                rollout_is_threshold=rollout_corr_config.get("rollout_is_threshold", 2.0),
                rollout_rs=rollout_corr_config.get("rollout_rs", None),
                rollout_rs_threshold=rollout_corr_config.get("rollout_rs_threshold", None),
            )

            # Log to tensorboard/wandb
            for metric_name, metric_value in metrics.items():
                logger.log_scalar(metric_name, metric_value, step=global_step)

            # IMPORTANT: Update batch response_mask with rejection applied
            batch.response_mask = modified_response_mask

            # Use IS weights in training (always safety-bounded, zeroed at padding)
            is_weights = weights_proto.batch["rollout_is_weights"]

            # ... apply weights to policy gradient ...
```
```python
def check_rollout_correction_health(metrics, config):
    """Check if Rollout Correction metrics indicate healthy training."""
    warnings = []

    # Check mean IS weight
    mean_weight = metrics['rollout_corr/rollout_is_mean']
    if mean_weight < 0.5 or mean_weight > 2.0:
        warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")

    # Check effective sample size
    ess = metrics['rollout_corr/rollout_is_eff_sample_size']
    if ess < 0.3:
        warnings.append(f"Effective sample size {ess:.3f} is too low")

    # Check standard deviation
    std = metrics['rollout_corr/rollout_is_std']
    if std > 1.0:
        warnings.append(f"IS weight std {std:.3f} is too high")

    # Check KL divergence
    kl = metrics['rollout_corr/kl']
    if abs(kl) > 0.1:
        warnings.append(f"KL divergence {kl:.3f} indicates significant off-policy gap")

    # Check chi-squared divergence
    if 'rollout_corr/chi2_token' in metrics:
        chi2_token = metrics['rollout_corr/chi2_token']
        if chi2_token > 1.0:
            warnings.append(f"Chi-squared divergence (token) {chi2_token:.3f} indicates severe distribution shift")

    if warnings:
        print("⚠️ Rollout Correction Health Warnings:")
        for warning in warnings:
            print(f"  - {warning}")
        return False
    else:
        print("✅ Rollout Correction metrics look healthy")
        return True

# Use in training
_, _, metrics = compute_rollout_correction_and_rejection_mask(...)
is_healthy = check_rollout_correction_health(metrics, config)
if not is_healthy:
    # Consider adjusting config or investigating issues
    print("Consider:")
    print("  - Tightening rollout_is_threshold")
    print("  - Switching to geometric aggregation level")
    print("  - Checking if rollout and training policies are too different")
```
Start with the basic token-level truncate configuration:
```bash
bash examples/rollout_correction/run_with_rollout_corr.sh
```
Monitor metrics for 1-2 epochs before adjusting parameters.
IS weights only:

```yaml
algorithm:
  rollout_correction:
    rollout_is: token
    rollout_is_threshold: 2.0
    rollout_rs: null  # No rejection sampling
```
Rejection sampling only:

```yaml
algorithm:
  rollout_correction:
    rollout_is: null  # No IS weights
    rollout_rs: token_k1
    rollout_rs_threshold: "0.5_2.0"
```
IS weights + rejection sampling:

```yaml
algorithm:
  rollout_correction:
    rollout_is: token
    rollout_is_threshold: 2.0
    rollout_rs: token_k1
    rollout_rs_threshold: "0.5_2.0"
```
IS + RS in bypass mode:

```yaml
algorithm:
  rollout_correction:
    rollout_is: token
    rollout_is_threshold: 2.0
    rollout_rs: token_k1
    rollout_rs_threshold: "0.5_2.0"
    bypass_mode: true    # Skip old_log_prob computation
    loss_type: ppo_clip  # PPO clipped objective (default)
```
Skips expensive actor.compute_log_prob() forward pass. PPO ratio = π_θ/π_rollout handles IS.
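A sketch of the resulting loss (hypothetical standalone code, assuming π_old = π_rollout so the PPO ratio doubles as the importance correction; the actual implementation is `compute_policy_loss_bypass_mode()`):

```python
import torch

def ppo_clip_bypass_loss(log_prob, rollout_log_prob, advantages, response_mask, clip_eps=0.2):
    """PPO-clip in bypass mode: the ratio pi_theta / pi_rollout handles the IS correction."""
    ratio = torch.exp(log_prob - rollout_log_prob)
    unclipped = ratio * advantages
    clipped = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Standard pessimistic PPO objective, masked to valid tokens
    loss = -torch.min(unclipped, clipped) * response_mask
    return loss.sum() / response_mask.sum().clamp(min=1)
```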
```yaml
rollout_correction:
  rollout_is: sequence  # Explicit IS correction in loss
  rollout_is_threshold: 2.0
  rollout_rs: null      # Optional: can add rejection sampling
  bypass_mode: true
  loss_type: reinforce  # REINFORCE with explicit IS weights
```

No PPO clipping; pure policy gradient with IS correction.
```yaml
rollout_correction:
  rollout_is: sequence    # Computed for metrics
  rollout_is_threshold: 2.0
  rollout_rs: seq_max_k2  # Sequence max χ²/2 guard
  rollout_rs_threshold: 2.5
  bypass_mode: true
  loss_type: ppo_clip     # PPO clipped objective (IS handled by ratio)
```
PPO clipping with rejection sampling. IS handled by PPO ratio (no explicit IS weights).
Problem: High IS-weight variance

Symptoms: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
Solutions:
- Switch from sequence-level to geometric-level aggregation

Problem: Mean IS weight far from 1.0

Symptoms: `rollout_is_mean` < 0.5 or > 2.0
Solutions:
- Check that `calculate_log_probs=True` is set

Example: Plot IS metric trends over training
```python
import matplotlib.pyplot as plt

def plot_is_metrics(metrics_history):
    """Plot rollout IS metrics over training steps."""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Plot 1: Mean IS weight over time
    axes[0, 0].plot(metrics_history['rollout_corr/rollout_is_mean'])
    axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
    axes[0, 0].set_title('Mean IS Weight')
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].legend()

    # Plot 2: Effective sample size
    axes[0, 1].plot(metrics_history['rollout_corr/rollout_is_eff_sample_size'])
    axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
    axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
    axes[0, 1].set_title('Effective Sample Size')
    axes[0, 1].set_xlabel('Step')
    axes[0, 1].legend()

    # Plot 3: KL divergence over time
    axes[1, 0].plot(metrics_history['rollout_corr/kl'], label='KL')
    axes[1, 0].plot(metrics_history['rollout_corr/k3_kl'], label='K3 KL')
    axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3)
    axes[1, 0].set_title('KL Divergence')
    axes[1, 0].set_xlabel('Step')
    axes[1, 0].legend()

    # Plot 4: PPL ratio over time
    axes[1, 1].plot(metrics_history['rollout_corr/ppl_ratio'])
    axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
    axes[1, 1].set_title('PPL Ratio (Training/Rollout)')
    axes[1, 1].set_xlabel('Step')
    axes[1, 1].legend()

    # Plot 5: Chi-squared divergence
    if 'rollout_corr/chi2_token' in metrics_history:
        axes[1, 2].plot(metrics_history['rollout_corr/chi2_token'], label='Token-level')
        if 'rollout_corr/chi2_seq' in metrics_history:
            axes[1, 2].plot(metrics_history['rollout_corr/chi2_seq'], label='Seq-level')
        axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Warning')
        axes[1, 2].set_title('Chi-squared Divergence')
        axes[1, 2].set_xlabel('Step')
        axes[1, 2].legend()
    else:
        axes[1, 2].axis('off')

    axes[0, 2].axis('off')  # Unused panel in the 2x3 grid
    plt.tight_layout()
    plt.savefig('rollout_is_metrics.png', dpi=150)
    print("Saved plot to rollout_is_metrics.png")
```
Example: Metric collection during training
```python
# Collect metrics over time
metrics_history = {
    'rollout_corr/rollout_is_mean': [],
    'rollout_corr/rollout_is_eff_sample_size': [],
    'rollout_corr/kl': [],
    'rollout_corr/k3_kl': [],
    'rollout_corr/ppl_ratio': [],
    'rollout_corr/chi2_token': [],
    'rollout_corr/chi2_seq': [],
}

# In the training loop
for step in range(num_steps):
    # ... compute IS weights and rejection mask ...
    _, _, metrics = compute_rollout_correction_and_rejection_mask(...)

    # Store metrics
    for key in metrics_history.keys():
        if key in metrics:
            metrics_history[key].append(metrics[key])

    # Plot every 100 steps
    if step % 100 == 0:
        plot_is_metrics(metrics_history)
```
Run the test suite to verify everything works:
```bash
# Basic unit tests
python tests/trainer/ppo/test_rollout_corr.py

# Integration tests (if pytest is available)
pytest tests/trainer/ppo/test_rollout_corr_integration.py -v
```
Expected output: All tests pass ✓
See also:
- `verl/trainer/ppo/rollout_corr_helper.py`
- `examples/rollout_correction/`
- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh`

Rollout Correction provides a unified framework for handling general off-policy problems in RL.