docs/algo/rollout_corr_math.md
Author: Yingru Li · Last updated: 2025-11-04
📖 Documentation Structure
- This document - Mathematical theory: formulations, derivations, and algorithmic foundations
- Rollout Correction Usage Guide - Practical implementation: configurations, presets, troubleshooting
Start here for theory and design rationale, refer to the usage guide for implementation.
```bibtex
@online{liu-li-2025-rl-collapse,
  title  = {When Speed Kills Stability: Demystifying {RL} Collapse from the Training-Inference Mismatch},
  author = {Liu, Jiacai and Li, Yingru and Fu, Yuqian and Wang, Jiawei and Liu, Qian and Shen, Yu},
  year   = {2025},
  month  = sep,
  url    = {https://richardli.xyz/rl-collapse}
}

@article{li2025trust,
  title   = {Trust Region Masking for Long-Horizon LLM Reinforcement Learning},
  author  = {Li, Yingru and Liu, Jiacai and Xu, Jiawei and Tong, Yuxuan and Li, Ziniu and Liu, Qian and Wang, Baoxiang},
  journal = {arXiv preprint arXiv:2512.23075},
  year    = {2025}
}
```
This document provides the definitive mathematical formulations for rollout correction methods in verl, following the natural progression from REINFORCE to PPO to Decoupled PPO.
Rollout correction provides a unified framework to handle general off-policy problems in RL training - any scenario where the data collection distribution differs from the training distribution.
Applicable scenarios include:
- Training-inference mismatch: numerical differences between the rollout engine and the training backend mean the logged rollout probabilities differ from the trainer's
- Stale or asynchronous rollouts, replay buffers, and data aggregated from multiple workers, all of which change the effective behavior policy
This section establishes the theoretical progression that verl implements.
The REINFORCE algorithm (Williams, 1992) is the foundation of policy gradient methods.
Vanilla REINFORCE (On-Policy)
For trajectories $\tau = (s_0, a_0, s_1, a_1, \ldots, s_T, a_T)$ sampled from the current policy $\pi_\theta$, the policy gradient is:
$$ \nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t \right] $$
where $A_t$ is the advantage function at timestep $t$.
Off-Policy REINFORCE
When trajectories are sampled from a different behavior policy $\mu$, we apply importance sampling over the joint trajectory distribution:
$$ \nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \mu} \left[ \frac{P_{\pi_\theta}(\tau)}{P_\mu(\tau)} \sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t \right] $$
where the trajectory-level importance weight is:
$$ \frac{P_{\pi_\theta}(\tau)}{P_\mu(\tau)} = \frac{p(s_0) \prod_{t=0}^T \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)}{p(s_0) \prod_{t=0}^T \mu(a_t|s_t) p(s_{t+1}|s_t, a_t)} = \prod_{t=0}^T \frac{\pi_\theta(a_t|s_t)}{\mu(a_t|s_t)} $$
The transition dynamics $p(s_{t+1}|s_t, a_t)$ and initial state $p(s_0)$ cancel out, leaving only the product of per-step action probability ratios.
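To make the cancellation concrete, here is a minimal sketch (hypothetical helper; assumes per-step log-probs under both policies are available as 1-D tensors):

```python
import torch

def trajectory_is_weight(logp_pi: torch.Tensor, logp_mu: torch.Tensor) -> torch.Tensor:
    """Trajectory IS weight prod_t pi_theta(a_t|s_t) / mu(a_t|s_t).

    logp_pi, logp_mu: shape [T], per-step log-probs under pi_theta and mu.
    Working in log space (sum of log-ratios, then exp) avoids the numerical
    underflow/overflow a direct product of T ratios would cause.
    """
    return torch.exp((logp_pi - logp_mu).sum())
```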
Key properties:
- Unbiased, but the trajectory-level weight is a product of $T$ per-step ratios, so its variance grows rapidly with sequence length
Implementation in verl: The `bypass_pg_is` preset implements off-policy REINFORCE with truncated importance sampling.
Proximal Policy Optimization (Schulman et al., 2017) adds a clipped surrogate objective:
$$ L_{\text{PPO}}(\theta) = -\mathbb{E}_{(s,a) \sim \mu} \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$
where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\mu(a_t|s_t)}$ and $\epsilon$ is the clip range (typically 0.2).
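A minimal PyTorch sketch of this objective (illustrative names; per-token tensors of shape `[T]` assumed):

```python
import torch

def ppo_clip_loss(logp_new: torch.Tensor, logp_behavior: torch.Tensor,
                  advantages: torch.Tensor, eps: float = 0.2) -> torch.Tensor:
    """Clipped surrogate loss from Schulman et al. (2017)."""
    ratio = torch.exp(logp_new - logp_behavior)                  # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()                 # negated for gradient descent
```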
Key properties:
- Clipping bounds the per-update policy change around the behavior policy
- Implicitly assumes the behavior policy equals the proximal policy $\pi_{\text{old}}$, an assumption that breaks under data aggregation or stale rollouts
Decoupled PPO (Hilton et al., 2021) solves PPO's batch size sensitivity by decoupling two roles:
The problem: Standard PPO controls policy update size via the ratio $\frac{\pi_\theta}{\pi_{\text{old}}}$, where $\pi_{\text{old}}$ is assumed to be both the proximal policy and the behavior policy. This coupling makes the algorithm sensitive to batch size because aggregating data from multiple workers or using replay buffers changes the effective behavior policy.
The solution: Decouple these two roles, leading to a three-policy formulation:
$$ L_{\text{DecoupledPPO}}(\theta) = -\mathbb{E}_{(s,a) \sim \mu} \left[ w_t \cdot \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$
where:
- $w_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\mu(a_t|s_t)}$ is a stopgrad importance weight correcting the behavior→proximal gap
- $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ is the PPO ratio, clipped against the proximal policy

Key properties: By decoupling, the clipping trust region is anchored at the proximal policy regardless of where the data came from, restoring batch size invariance.
This is the algorithm that verl implements via its three-policy framework.
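A hedged sketch of the three-policy loss (hypothetical function, not verl's exact API; the truncation constant `c_is` and tensor shapes are assumptions):

```python
import torch

def decoupled_ppo_loss(logp_cur, logp_prox, logp_rollout, advantages,
                       eps: float = 0.2, c_is: float = 2.0) -> torch.Tensor:
    """Decoupled PPO: w_t corrects behavior->proximal drift; r_t clips against the proximal policy."""
    # Truncated IS weight w_t = min(pi_prox / pi_rollout, C), treated as a constant (stopgrad).
    w = torch.exp(logp_prox - logp_rollout).clamp(max=c_is).detach()
    # PPO ratio r_t = pi_theta / pi_prox, clipped as in standard PPO.
    ratio = torch.exp(logp_cur - logp_prox)
    surrogate = torch.min(ratio * advantages,
                          torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages)
    return -(w * surrogate).mean()
```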
The verl library implements decoupled PPO using three distinct policies, each serving a specific role.
- $\pi_{\text{rollout}}$ (Behavior Policy $\mu$): The policy used for data collection. This is the behavior distribution $\mu$ from theory.
- $\pi_{\text{old}}$ (Proximal Policy $\pi_{\text{prox}}$): The reference policy for PPO clipping, computed via `actor.compute_log_prob()`. This is the "proximal policy" from decoupled PPO theory.
- $\pi_{\theta}$ (Current Policy): The policy being actively optimized during training.
The three-policy framework can operate in two modes:
- Decoupled Mode (Three Policies): $\pi_{\text{rollout}}$, $\pi_{\text{old}}$, and $\pi_{\theta}$ are all distinct; achieves batch size invariance at the cost of an extra `actor.compute_log_prob()` pass
- Bypass Mode (Two Policies): sets $\pi_{\text{old}} = \pi_{\text{rollout}}$ (skips the extra `actor.compute_log_prob()` call); does not achieve batch size invariance

The three-policy framework handles two types of distribution drift:
Drift 1: $\pi_{\text{rollout}} \to \pi_{\text{old}}$ (Off-Policy Gap)
This is the distribution shift between the data collection policy and the training reference policy.
Drift 2: $\pi_{\text{old}} \to \pi_{\theta}$ (Policy Update Drift)
This is the drift from policy parameter updates during training.
The rollout correction framework in verl is built from orthogonal components that can be combined flexibly:
This section explains each component and their valid combinations.
The operating mode determines how the proximal policy $\pi_{\text{old}}$ is computed.
Configuration: `bypass_mode = false`

Policy setup:
- $\pi_{\text{old}}$ is computed separately (via `actor.compute_log_prob()` at the start of the training epoch)
- IS ratio: $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ (corrects Drift 1: rollout→old)
- PPO ratio: $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ (corrects Drift 2: old→current)
Properties:
- Achieves batch size invariance: the PPO trust region is anchored at $\pi_{\text{old}}$, independent of how the rollout data was collected
- Requires an extra forward pass (`actor.compute_log_prob()`)

Configuration: `bypass_mode = true`
Policy setup:
- $\pi_{\text{old}} = \pi_{\text{rollout}}$ (no separate computation)

Ratios:
- PPO-clip loss (`loss_type = "ppo_clip"`, default): PPO ratio $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ clips against the rollout policy (IS handled by the ratio)
- REINFORCE loss (`loss_type = "reinforce"`): IS ratio $\rho_t = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ computed on-the-fly in the loss function

Properties:
- Skips the extra `actor.compute_log_prob()` call (faster)
- Does not achieve batch size invariance

Configuration: `loss_type = "ppo_clip"` (default in bypass mode)
Loss function:
$$ L_{\text{PPO}}(\theta) = -\mathbb{E}_t \left[ w_t \cdot \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$
where:
- $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ is the PPO ratio (clips against the rollout policy)
- $w_t$ is an optional stopgrad rollout-correction weight from the configured estimator (identically 1 when no IS weights are configured)

Properties:
- IS correction is absorbed into the PPO ratio itself, so explicit IS weights are unnecessary (adding them would double-count the correction)
Configuration: `loss_type = "reinforce"` (requires `bypass_mode = true`)
Loss function (example with sequence-level IS):
$$ L_{\text{PG}}(\theta) = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ \text{stopgrad}(w_{\text{seq}}(\theta)) \cdot \sum_{t \in T} \log \pi_{\theta}(a_t|s_t) \cdot A_t \right] $$
where:
- $w_{\text{seq}}(\theta) = \min\left(\prod_{t \in T} \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}, C_{\text{IS}}\right)$ is the truncated sequence-level IS weight
- $\text{stopgrad}(\cdot)$ treats its argument as a constant during differentiation
Effective gradient:
$$ \nabla_\theta L_{\text{PG}} = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ \text{stopgrad}(w_{\text{seq}}(\theta)) \cdot \sum_{t \in T} \nabla_\theta \log \pi_{\theta}(a_t|s_t) \cdot A_t \right] $$
Theoretical Justification for stopgrad:
The stopgrad operator is mathematically required by importance sampling theory, not an implementation detail. Here's why:
The fundamental principle: Importance sampling is a technique to change the measure (reweight samples from one distribution to estimate expectations under another), not to optimize the reweighting function itself.
Formal derivation:
Original objective: We want to optimize $J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_t A_t\right]$.
Off-policy setting: We only have samples from $\pi_{\text{rollout}}$, so we use importance sampling: $$ J(\theta) = \mathbb{E}_{\tau \sim \pi_{\text{rollout}}} \left[ \underbrace{\frac{P_{\pi_\theta}(\tau)}{P_{\pi_{\text{rollout}}}(\tau)}}_{w(\tau;\theta)} \sum_t A_t \right] $$
Computing the policy gradient: The correct gradient applies the policy gradient theorem BEFORE importance sampling: $$ \begin{aligned} \nabla_\theta J(\theta) &= \nabla_\theta \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_t A_t\right] \\ &= \mathbb{E}_{\tau \sim \pi_\theta} \left[\sum_t A_t \nabla_\theta \log \pi_\theta(a_t|s_t) \right] \quad \text{(policy gradient theorem)} \\ &= \mathbb{E}_{\tau \sim \pi_{\text{rollout}}} \left[ w(\tau;\theta) \sum_t A_t \nabla_\theta \log \pi_\theta(a_t|s_t) \right] \quad \text{(change of measure)} \end{aligned} $$
In the final line, $w(\tau;\theta)$ appears as a multiplicative coefficient from the change of measure, not as something we differentiate.
What goes wrong without stopgrad: If we naively compute $\nabla_\theta \left[w(\theta) \log \pi_\theta \right]$ in the loss, we get: $$ \nabla_\theta \left[w(\theta) \log \pi_\theta \right] = \underbrace{\log \pi_\theta \cdot \nabla_\theta w(\theta)}_{\text{WRONG: bias term}} + \underbrace{w(\theta) \cdot \nabla_\theta \log \pi_\theta}_{\text{CORRECT: IS-weighted gradient}} $$
The first term $\log \pi_\theta \cdot \nabla_\theta w(\theta)$ is an artifact of the computational trick (using loss times log-prob), not part of the true policy gradient. It biases the gradient estimator and optimizes a different objective than $J(\theta)$.
Implementation requirement: In PyTorch, to compute only the second term, we must use:
```python
loss = -advantages * log_prob * rollout_is_weights.detach()  # stopgrad on weights
```
Without .detach(), autograd computes both terms, giving an incorrect gradient.
Intuition: The IS weight $w(\theta)$ tells us "how much to trust this sample" for estimating the gradient under $\pi_\theta$. We update $\theta$ to maximize the reweighted objective, but we don't update $\theta$ to maximize the weight itself—that would be circular reasoning (optimizing the correction factor instead of the actual objective).
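A small self-contained experiment makes the bias visible (all names hypothetical):

```python
import torch

logp = torch.tensor([-1.0], requires_grad=True)   # log pi_theta(a|s)
logp_rollout = torch.tensor([-1.2])               # fixed rollout log-prob
adv = torch.tensor([1.0])
w = torch.exp(logp - logp_rollout)                # IS weight depends on theta

biased = -(w * logp * adv).mean()                 # autograd also differentiates w
correct = -(w.detach() * logp * adv).mean()       # stopgrad on w

g_biased, = torch.autograd.grad(biased, logp, retain_graph=True)
g_correct, = torch.autograd.grad(correct, logp)
print(g_biased, g_correct)  # differ exactly by the bias term log(pi) * dw/dtheta
```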
Properties:
- Pure policy gradient: no PPO clipping; the off-policy correction comes entirely from the explicit IS weights

Loss selection (`loss_type` config option in bypass mode):
- `"ppo_clip"` (default): PPO clipped objective
- `"reinforce"`: Pure policy gradient with explicit IS weights, no PPO clipping
Implementation: `compute_policy_loss_bypass_mode()` and `compute_policy_loss_reinforce()` in `core_algos.py`
The aggregation level determines how per-token probability ratios are combined into IS weights and/or rejection masks. This choice is orthogonal to the operating mode - you can use any aggregation level in either decoupled or bypass mode.
IS weights: $w_t = \min(\rho_t, C_{\text{IS}})$ where $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ (decoupled) or $\rho_t = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ (bypass/pure IS)
Configuration:
```
rollout_is = "token"      # IS weights
rollout_rs = "token_k1"   # Optional: rejection sampling (ratio bounds)
```
Properties:
- Clips each per-token ratio independently: lower variance than sequence-level IS, at the cost of some bias
Loss function (REINFORCE + Token IS):
$$ L_{\text{REINFORCE+TIS}}(\theta) = -\mathbb{E}_t \left[ \text{stopgrad}(w_t) \cdot \log \pi_\theta(a_t|s_t) \cdot A_t \right] $$
where $w_t = \min(\rho_t, C_{\text{IS}})$ are the truncated token-level IS weights. The stopgrad operator ensures that when computing $\nabla_\theta L$, the weights are treated as constants (see §3.2.2 for theoretical justification). This formulation can also be combined with PPO clipping by replacing the REINFORCE gradient with the clipped surrogate objective.
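A minimal sketch of the weight computation (illustrative; the actual logic lives in `compute_rollout_correction_weights()`):

```python
import torch

def token_tis_weights(logp_old: torch.Tensor, logp_rollout: torch.Tensor,
                      c_is: float = 2.0) -> torch.Tensor:
    """Truncated token-level IS weights w_t = min(rho_t, C_IS).

    logp_old: log-probs under pi_old (pi_theta in bypass mode), shape [B, T].
    logp_rollout: log-probs recorded by the rollout engine, shape [B, T].
    """
    rho = torch.exp(logp_old - logp_rollout)
    return rho.clamp(max=c_is).detach()  # stopgrad: weights enter the loss as constants
```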
Implementation:
- IS weights: `compute_rollout_correction_weights()` in `rollout_corr_helper.py`
- Loss: `compute_policy_loss()` in `core_algos.py`

IS weights: $w_{\text{seq}} = \min\left( \prod_{t \in T} \rho_t, C_{\text{IS}} \right) = \min\left( \exp\left(\sum_{t \in T} \log \rho_t\right), C_{\text{IS}} \right)$ (broadcast to all tokens)
Configuration:
```
rollout_is = "sequence"     # IS weights
rollout_rs = "seq_sum_k1"   # Optional: rejection sampling
```
Properties:
- One weight per sequence: unbiased for the sequence-level objective, but weight variance grows with sequence length (see the Length Trap below)

Terminology Note: Seq-TIS (truncated IS) clips the sequence weight at $C_{\text{IS}}$; pairing it with sequence-level rejection (`rollout_rs = "seq_sum_k1"`) yields the Seq-MIS estimator listed in §5.
Loss function (REINFORCE + Sequence IS):
$$ L_{\text{REINFORCE+SeqIS}}(\theta) = -\mathbb{E}_t \left[ \text{stopgrad}(w_{\text{seq}}) \cdot \log \pi_\theta(a_t|s_t) \cdot A_t \right] $$
where $w_{\text{seq}}$ is broadcast to all tokens in the sequence. The stopgrad operator ensures correct IS gradient computation (see §3.2.2). This formulation can also be combined with PPO clipping.
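The sequence-level counterpart, as a sketch under the same assumptions (`[B, T]` log-prob tensors plus a padding mask; not verl's exact implementation):

```python
import torch

def seq_tis_weights(logp_old, logp_rollout, response_mask, c_is: float = 2.0):
    """Sequence-level truncated IS weight, broadcast to every token in the sequence."""
    log_rho = (logp_old - logp_rollout) * response_mask      # zero out padding
    w_seq = torch.exp(log_rho.sum(dim=-1)).clamp(max=c_is)   # [B]
    return w_seq.unsqueeze(-1).expand_as(logp_old).detach()  # broadcast to [B, T]
```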
Geometric mean ratio: $\rho_{\text{geo}} = \exp\left( \frac{1}{|T|} \sum_{t \in T} \log \rho_t \right) = \left(\prod_{t \in T} \rho_t\right)^{1/|T|}$ (broadcast to all tokens)
Configuration:
```
rollout_is = null            # No IS weights, pure rejection
rollout_rs = "seq_mean_k1"   # Geometric mean rejection sampling (ratio bounds)
```
Properties:
- Length-invariant trust region: thresholds are tight two-sided ratio bounds, e.g. `"0.999_1.001"` (~±0.1% average per-token drift)

The Length Trap Problem:
Standard IS estimators have a systematic length bias that penalizes long sequences. The importance ratio $\rho(y)$ is multiplicative:
$$ \rho(y) = \prod_{t=1}^T \frac{\pi(y_t|y_{<t})}{\mu(y_t|y_{<t})} $$
Assume the new policy $\pi$ differs slightly from $\mu$, with an average per-token ratio of $\approx 1.1$:
- A 10-token answer gets $\rho(y) \approx 1.1^{10} \approx 2.6$, within typical truncation bounds
- A 100-token chain of thought gets $\rho(y) \approx 1.1^{100} \approx 1.4 \times 10^4$, which is truncated to the ceiling or rejected outright
This creates Context Collapse: the model preferentially learns from short, shallow answers and rejects long chains of thought—even if per-step quality is identical. For reasoning models (CoT) and agents, this effectively penalizes "thinking too long."
Geo-RS Solution:
Geometric-level rejection normalizes by sequence length, converting the extensive property (total probability product) to an intensive property (average per-token drift):
$$ \rho_{\text{geo}}(y) = \rho(y)^{1/T} $$
Now both sequences have the same "trust score": $\rho_{\text{geo}} \approx 1.1$ regardless of length, so acceptance depends on per-token drift rather than sequence length.
Why tight thresholds? For 100 tokens with per-token log-ratio = 0.01 each:
- Sequence-level: $\sum_t \log \rho_t = 1.0$, so $\rho(y) = e^{1.0} \approx 2.72$ (looks like severe drift)
- Geometric mean: $\rho_{\text{geo}} = e^{0.01} \approx 1.01$ (correctly reads as mild per-token drift)

A ratio bound of `"0.999_1.001"` rejects sequences whose average per-token log-deviation exceeds ≈0.1%.
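A minimal sketch of geometric-mean rejection under these assumptions (hypothetical helper; `[B, T]` tensors):

```python
import torch

def geo_rs_mask(logp_old, logp_rollout, response_mask, lo=0.999, hi=1.001):
    """Accept sequences whose geometric-mean ratio lies in [lo, hi]."""
    log_rho = (logp_old - logp_rollout) * response_mask
    n_tokens = response_mask.sum(dim=-1).clamp(min=1)
    rho_geo = torch.exp(log_rho.sum(dim=-1) / n_tokens)  # length-normalized ratio
    accept = (rho_geo >= lo) & (rho_geo <= hi)           # [B] boolean
    return response_mask * accept.unsqueeze(-1).float()  # rejected sequences drop out
```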
Loss function (REINFORCE + Geometric RS):
$$ L_{\text{GeoRS}}(\theta) = -\mathbb{E}_{(s,a) \,\mid\, \text{seq} \in \mathcal{A}_{\text{geo}}} \left[ \sum_{t \in T} \log \pi_\theta(a_t|s_t) \cdot A_t \right] $$
where $\mathcal{A}_{\text{geo}} = \{ \text{seq} : C_{\text{RS-lower}} \leq \rho_{\text{geo}} \leq C_{\text{RS-upper}} \}$ is the acceptance set (rejection mask). No IS weights are used, so no stopgrad is needed. This formulation can also be combined with PPO clipping.
Combined Estimator (Geo-RS-Token-TIS):
For best results, combine the Geometric Filter (length-invariant validity check) with Token-level IS weights (lower variance):
$$ \hat{g}_{\text{geo-rs-token-tis}}(y) = \underbrace{\mathbb{I}\left( C_{\text{low}} \le \rho(y)^{1/T} \le C_{\text{high}} \right)}_{\text{Geometric Filter}} \cdot \prod_t \min(\rho_t, C) \cdot f(y) $$
This is implemented by combining `rollout_rs="seq_mean_k1"` with `rollout_is="token"`.
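For example, a hypothetical configuration in the style of the manual examples in §5 (threshold values echo the defaults quoted above):

```python
# Illustrative only; mirrors the RolloutCorrectionConfig examples shown later in this doc.
config = RolloutCorrectionConfig(
    rollout_is="token",                  # token-level truncated IS weights (lower variance)
    rollout_rs="seq_mean_k1",            # geometric-mean, length-invariant rejection
    rollout_rs_threshold="0.999_1.001",  # tight two-sided bound (~±0.1% per-token drift)
)
```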
Per-token statistic:
$$ K2_t = \frac{1}{2} \left(\log \rho_t\right)^2 $$
where $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ and the implementation clips $\log \rho_t$ to $[-20, 20]$ for numerical safety.
Sequence aggregations (share the same per-token $K2_t$):
- `seq_sum_k2`: $K2_{\text{sum}} = \sum_{t \in T} K2_t$
- `seq_mean_k2`: $K2_{\text{mean}} = \frac{1}{|T|} \sum_{t \in T} K2_t$
- `seq_max_k2`: $K2_{\text{max}} = \max_{t \in T} K2_t$

Configuration:
```
rollout_is = null            # Optional: pair with token IS weights for lower variance
rollout_rs = "token_k2"      # or "seq_sum_k2", "seq_mean_k2", "seq_max_k2"
rollout_rs_threshold = 2.0   # Positive upper bound only
```
Properties:
- Recommended thresholds depend on the aggregation: ~2.0 for `token_k2` (the default above), 2.0-2.5 for `seq_mean_k2`, and 2.5-4.0 for `seq_sum_k2`
- `seq_max_k2` isolates single-token spikes even when the rest of the sequence is clean
- Pair with token IS weights (`rollout_is="token"`) to keep useful samples while clipping variance
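A minimal sketch of the K2 statistic and its aggregations (illustrative helper; `[B, T]` tensors assumed):

```python
import torch

def k2_stats(logp_old, logp_rollout, response_mask):
    """Per-token K2 = 0.5 * (log rho)^2 plus the three sequence aggregations."""
    log_rho = (logp_old - logp_rollout).clamp(-20, 20) * response_mask
    k2_token = 0.5 * log_rho.pow(2)                 # token_k2, zero at padding
    n = response_mask.sum(dim=-1).clamp(min=1)
    return {
        "token_k2": k2_token,
        "seq_sum_k2": k2_token.sum(dim=-1),
        "seq_mean_k2": k2_token.sum(dim=-1) / n,
        "seq_max_k2": k2_token.max(dim=-1).values,  # K2 >= 0, so padding zeros are safe
    }
```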
Combined Estimator (K2-RS-Token-TIS):
For combined filtering and weighting, let $K2_{\text{agg}}$ denote the selected aggregation (token, sum, mean, or max):
$$ \hat{g}_{\text{k2-rs-token-tis}}(y) = \underbrace{\mathbb{I}\left( K2_{\text{agg}}(y) \le C_{\text{k2}} \right)}_{\text{K2 Filter}} \cdot \prod_t \min(\rho_t, C) \cdot f(y) $$
This is implemented via `rollout_rs="seq_mean_k2"` (or another k2 mode) together with `rollout_is="token"`.
K3 divergence at sequence level:
$$ K3_{\text{seq}} = \frac{1}{|T|} \sum_{t \in T} \left( \rho_t - \log \rho_t - 1 \right) $$
where $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ is the per-token ratio.
K3 equals the reverse KL: In expectation, $K3 = \text{KL}(\pi_{\text{rollout}} \,\|\, \pi_{\text{old}})$. This follows from $\mathbb{E}_{\pi_{\text{rollout}}}[\rho_t] = 1$ and $\mathbb{E}_{\pi_{\text{rollout}}}[-\log \rho_t] = \text{KL}(\pi_{\text{rollout}} \,\|\, \pi_{\text{old}})$, so the three terms $\rho_t - \log \rho_t - 1$ average to $1 + \text{KL} - 1 = \text{KL}$.
Configuration:
```
rollout_is = null            # No IS weights, pure rejection
rollout_rs = "seq_mean_k3"   # K3 rejection sampling
```
Properties:
- Always non-negative, so only a single positive upper threshold is needed (e.g. 0.01)

Why K3 over geometric ratio? For small drift, $\rho_t - \log \rho_t - 1 \approx \frac{1}{2}(\log \rho_t)^2$, so K3 acts as a smooth, always-positive divergence that resolves small KL values well, whereas the geometric ratio requires carefully tuned two-sided bounds around 1.
Combined Estimator (K3-RS-Token-TIS):
For best results, combine K3 filter with token-level IS weights:
$$ \hat{g}_{\text{k3-rs-token-tis}}(y) = \underbrace{\mathbb{I}\left( K3_{\text{seq}} \le C_{\text{k3}} \right)}_{\text{K3 Filter}} \cdot \prod_t \min(\rho_t, C) \cdot f(y) $$
This is implemented by combining `rollout_rs="seq_mean_k3"` with `rollout_is="token"`.
An optional variance reduction technique that normalizes IS weights to have mean 1.0 within each batch.
Configuration:
```
rollout_is_batch_normalize = True   # Default: False
```
Normalization formula (aggregation-aware):
For token-level IS (§3.3.1):
$$ \tilde{w}_t = \frac{w_t}{\frac{1}{\sum_{i,t} m_{i,t}} \sum_{i,t} w_{i,t} \cdot m_{i,t}} $$
where $w_{i,t}$ are truncated token IS weights, $m_{i,t}$ is the response mask, and normalization is over all tokens.
For sequence-level IS (§3.3.2):
$$ \tilde{w}_i = \frac{w_i}{\frac{1}{B}\sum_{j=1}^B \bar{w}_j} $$
where $\bar{w}_j = \frac{1}{T_j}\sum_{t=1}^{T_j} w_{j,t} \cdot m_{j,t}$ is the per-sequence mean (all tokens in a sequence have the same weight), and normalization is over sequences.
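A sketch of the token-level variant (illustrative; the real code uses `masked_mean` in `rollout_corr_helper.py`):

```python
import torch

def batch_normalize_token_is(weights, response_mask):
    """Rescale token-level IS weights so their masked batch mean is exactly 1.0."""
    denom = response_mask.sum().clamp(min=1)
    norm_factor = (weights * response_mask).sum() / denom  # masked mean over all tokens
    return weights / norm_factor, norm_factor              # factor is logged as a metric
```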
Properties:
- Keeps the batch-mean weight at exactly 1.0, reducing variance without changing the average gradient scale
- Uses `masked_mean` to respect padding tokens

Metrics:
- `rollout_is_batch_norm_factor`: the normalization factor applied (batch mean before normalization)

Implementation: `rollout_corr_helper.py`
Rejection sampling can be added to any combination of operating mode and aggregation level. It modifies the response_mask to exclude outlier tokens/sequences.
Configuration examples:
```
rollout_rs = "token_k1"            # Token-level ratio bounds
rollout_rs_threshold = "0.6_1.6"

rollout_rs = "seq_sum_k1"          # Sequence sum of log ratios
rollout_rs_threshold = "0.5_2.0"

rollout_rs = "seq_mean_k3"         # Sequence mean of K3 divergence
rollout_rs_threshold = 0.01
```
Acceptance set: $\mathcal{A} = \{ x : C_{\text{lower}} \leq s(x) \leq C_{\text{upper}} \}$, where $s(\cdot)$ is the configured statistic (ratio, K2, or K3); non-negative statistics (K2/K3) use only the upper bound.

Properties:
- Operates on the `response_mask`: rejected tokens/sequences simply drop out of the loss
- Reduces the effective batch size when many samples are rejected
Implementation: `compute_rollout_rejection_mask()` in `rollout_corr_helper.py`
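A hedged sketch of the masking logic (the actual helper is `compute_rollout_rejection_mask()`; the threshold parsing here is an assumption inferred from the configuration examples above):

```python
import torch

def rejection_mask(stat, response_mask, threshold):
    """Accept sequences whose statistic falls within the configured bounds.

    stat: per-sequence statistic, shape [B] (e.g. sum of log-ratios, mean K3).
    threshold: "lo_hi" string for two-sided bounds, or a float for a one-sided upper bound.
    """
    if isinstance(threshold, str):
        lo, hi = (float(x) for x in threshold.split("_"))
        accept = (stat >= lo) & (stat <= hi)
    else:
        accept = stat <= threshold                   # one-sided (K2/K3 are non-negative)
    return response_mask * accept.unsqueeze(-1).float()
```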
Key insight: Estimators (how IS/RS is computed) and operating modes (decoupled PPO vs bypass PG) are orthogonal. Any estimator can be combined with any operating mode.
| Estimator | Configuration | Compatible Modes |
|---|---|---|
| Token-TIS | `rollout_is="token"` | Decoupled PPO, Bypass PG |
| Seq-TIS | `rollout_is="sequence"` | Decoupled PPO, Bypass PG |
| Seq-MIS | `rollout_is="sequence"` + `rollout_rs="seq_sum_k1"` | Decoupled PPO, Bypass PG |
| Geo-RS | `rollout_rs="seq_mean_k1"` (geometric mean) | Decoupled PPO, Bypass PG |
| Geo-RS-Token-TIS | `rollout_is="token"` + `rollout_rs="seq_mean_k1"` | Decoupled PPO, Bypass PG |
| K3-RS | `rollout_rs="seq_mean_k3"` | Decoupled PPO, Bypass PG |
| K3-RS-Token-TIS | `rollout_is="token"` + `rollout_rs="seq_mean_k3"` | Decoupled PPO, Bypass PG |
Note: In bypass mode, `loss_type` controls the loss function. Use `"ppo_clip"` (default) or `"reinforce"`.
| Preset Method | Estimator | Mode | Properties |
|---|---|---|---|
| **Decoupled PPO Mode** (3 policies: π_rollout, π_old, π_θ) | | | |
| `decoupled_token_is()` | Token-TIS | Decoupled PPO | Per-token IS weights |
| `decoupled_seq_is()` | Seq-TIS | Decoupled PPO | Sequence-level IS weights |
| `decoupled_seq_is_rs()` | Seq-MIS | Decoupled PPO | Sequence IS + sequence RS |
| `decoupled_geo_rs()` | Geo-RS | Decoupled PPO | Geometric RS |
| `decoupled_geo_rs_token_tis()` | Geo-RS-Token-TIS | Decoupled PPO | Geometric filter + token IS |
| **K3 KL Estimator** (more stable for small KL values) | | | |
| `decoupled_k3_rs()` | K3-RS | Decoupled PPO | K3 rejection, no IS weights |
| `decoupled_k3_rs_token_tis()` | K3-RS-Token-TIS | Decoupled PPO | K3 filter + token clipped weight |
| **Bypass Mode (PPO-clip)** (ratio handles IS, RS masks outliers) | | | |
| `bypass_ppo_clip()` | - | Bypass (PPO-clip) | PPO-clip only |
| `bypass_ppo_clip_geo_rs()` | Geo-RS | Bypass (PPO-clip) | PPO-clip + Geo-RS (ratio) |
| `bypass_ppo_clip_k3_rs()` | K3-RS | Bypass (PPO-clip) | PPO-clip + K3-RS |
| **Bypass Mode (REINFORCE)** (explicit IS weights, no PPO clipping) | | | |
| `bypass_pg_is()` | Seq-TIS | Bypass (REINFORCE) | REINFORCE + Seq IS |
| `bypass_pg_geo_rs()` | Geo-RS | Bypass (REINFORCE) | REINFORCE + Geo-RS (ratio) |
| `bypass_pg_geo_rs_token_tis()` | Geo-RS-Token-TIS | Bypass (REINFORCE) | REINFORCE + Geo filter + token IS |
| **Other** | | | |
| `disabled()` | - | - | Metrics only |

Note: Bypass mode sets π_old = π_rollout and uses `loss_type` to select the loss function.
These combinations are fully supported but require manual configuration:
1. Token IS + Token RS
```python
config = RolloutCorrectionConfig(
    rollout_is="token",
    rollout_is_threshold=2.0,
    rollout_rs="token_k1",
    rollout_rs_threshold="0.5_2.0",
)
```
Properties: Token-level IS weights + token-level RS mask.
2. Pure Token RS
```python
config = RolloutCorrectionConfig(
    rollout_is=None,
    rollout_rs="token_k1",
    rollout_rs_threshold="0.5_2.0",
)
```
Properties: Token-level RS mask only, no IS weights.
3. Pure Sequence RS
```python
config = RolloutCorrectionConfig(
    rollout_is=None,
    rollout_rs="seq_sum_k1",
    rollout_rs_threshold="0.5_2.0",
)
```
Properties: Sequence-level RS mask only, no IS weights.
Key properties:
- The REINFORCE variant (`bypass_pg_rs`) uses bypass + geometric RS with `loss_type="reinforce"` for REINFORCE (no IS weights)

Theory: Naive LLM-RL implementation that incorrectly applies PPO by ignoring the actual rollout policy and assuming $\pi_{\text{old}} = \pi_{\text{rollout}}$.
Note: This incorrect implementation pattern was identified in Liu, Li, et al. (2025) as a key cause of training instability in LLM-RL systems, motivating the development of this rollout correction framework.
Loss Function:
$$ L_{\text{PPO}}(\theta) = -\mathbb{E}_t \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$
where $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ (ignores $\pi_{\text{rollout}}$).
Why it's wrong:
- The ratio $\frac{\pi_\theta}{\pi_{\text{old}}}$ never involves $\pi_{\text{rollout}}$, so the rollout→training mismatch (Drift 1) is silently ignored
- In practice $\pi_{\text{old}}$ and $\pi_{\text{rollout}}$ genuinely differ (e.g. inference-engine vs training-backend numerics), so off-policy data is treated as on-policy

Correct alternatives:
- Decoupled mode: keep all three policies and correct Drift 1 with IS weights
- Bypass mode: use the actual rollout log-probs in the ratio (set $\pi_{\text{old}} = \pi_{\text{rollout}}$)
Implementation: `compute_policy_loss()` in `core_algos.py`
These metrics quantify the severity of off-policy drift.
Note on notation: Metrics use $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$. In bypass mode, $\pi_{\text{old}} = \pi_{\text{rollout}}$, so metrics measure rollout→current drift using $\rho_t = \frac{\pi_{\theta}}{\pi_{\text{rollout}}}$ instead.
Direct KL estimator:
$$ \text{KL}(\pi_{\text{rollout}} \,\|\, \pi_{\text{old}}) = \mathbb{E}_{t \sim \pi_{\text{rollout}}} \left[ \log \pi_{\text{rollout}}(a_t|s_t) - \log \pi_{\text{old}}(a_t|s_t) \right] $$
K3 KL estimator (alternative formulation):
$$ \text{KL}_{\text{K3}} = \mathbb{E}_{t \sim \pi_{\text{rollout}}} \left[ \rho_t - \log \rho_t - 1 \right] $$
where $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$.
Old policy perplexity:
$$ \text{PPL}_{\text{old}} = \exp\left( -\frac{1}{|T|} \sum_{t \in T} \log \pi_{\text{old}}(a_t|s_t) \right) $$
Rollout policy perplexity:
$$ \text{PPL}_{\text{rollout}} = \exp\left( -\frac{1}{|T|} \sum_{t \in T} \log \pi_{\text{rollout}}(a_t|s_t) \right) $$
PPL ratio (inverse of geometric mean IS weight):
$$ \text{PPL}_{\text{ratio}} = \frac{\text{PPL}_{\text{old}}}{\text{PPL}_{\text{rollout}}} = \exp\left( -\frac{1}{|T|} \sum_{t \in T} \log \rho_t \right) = \left(\prod_{t \in T} \rho_t\right)^{-1/|T|} $$
Interpretation: Values > 1 mean $\pi_{\text{old}}$ assigns lower probability than $\pi_{\text{rollout}}$ to the observed actions (distribution shift).
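A minimal sketch computing these diagnostics from per-token log-probs (illustrative, not verl's `compute_offpolicy_metrics()`):

```python
import torch

def offpolicy_drift_metrics(logp_old, logp_rollout, response_mask):
    """KL, K3, and PPL-ratio drift diagnostics; inputs have shape [B, T]."""
    n = response_mask.sum().clamp(min=1)
    log_rho = (logp_old - logp_rollout) * response_mask
    rho = torch.exp(log_rho)
    kl = -log_rho.sum() / n                                 # direct KL estimate
    k3 = ((rho - log_rho - 1.0) * response_mask).sum() / n  # K3 estimator
    ppl_ratio = torch.exp(-log_rho.sum() / n)               # inverse geometric-mean ratio
    return {"kl": kl, "k3": k3, "ppl_ratio": ppl_ratio}
```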
Measures the second moment of the IS weight distribution.
Token-level:
$$ \chi^2_{\text{token}} = \mathbb{E}_{t \sim \pi_{\text{rollout}}} \left[ \rho_t^2 \right] - 1 $$
Sequence-level:
$$ \chi^2_{\text{seq}} = \mathbb{E}_{\text{seq} \sim \pi_{\text{rollout}}} \left[ \left(\prod_{t \in T} \rho_t\right)^2 \right] - 1 $$
Interpretation:
- Values near 0 indicate a well-behaved weight distribution; large values indicate heavy-tailed IS weights and high estimator variance
Implementation: `compute_offpolicy_metrics()` in `rollout_corr_helper.py`
| Method | Theory | Policies | PPO Clip | IS Correction | Correctness | Speed |
|---|---|---|---|---|---|---|
| **Bypass Mode** (π_old = π_rollout, `loss_type` selects algorithm) | | | | | | |
| `loss_type="ppo_clip"` (default) | PPO (ratio = π_θ/π_rollout) | 2 (rollout, θ) | ✅ | RS mask only (ratio handles IS) | ✅ Correct | Fast |
| `loss_type="reinforce"` | Off-policy REINFORCE | 2 (rollout, θ) | ❌ | ✅ (explicit IS weights) | ✅ Correct | Fast |
| **Bypass Mode Presets (PPO-clip)** | | | | | | |
| `bypass_ppo_clip` | PPO only | 2 (rollout, θ) | ✅ | - | ✅ Correct | Fast |
| `bypass_ppo_clip_geo_rs` | PPO + Geo-RS | 2 (rollout, θ) | ✅ | Geo-RS mask (ratio) | ✅ Correct | Fast |
| **Bypass Mode Presets (REINFORCE)** | | | | | | |
| `bypass_pg_is` | REINFORCE + Seq-TIS | 2 (rollout, θ) | ❌ | ✅ Seq-TIS | ✅ Correct | Fast |
| `bypass_pg_geo_rs` | REINFORCE + Geo-RS | 2 (rollout, θ) | ❌ | Geo-RS only (ratio) | ✅ Correct | Fast |
| `bypass_pg_geo_rs_token_tis` | REINFORCE + Geo RS + Token IS | 2 (rollout, θ) | ❌ | ✅ Geo-RS-Token-TIS | ✅ Correct | Fast |
| **Decoupled PPO Mode** (IS weights = π_old / π_rollout) | | | | | | |
| `decoupled_token_is` | Decoupled PPO | 3 (rollout, old, θ) | ✅ | ✅ Token-TIS | ✅ Correct | Standard |
| `decoupled_seq_is` | Decoupled PPO | 3 (rollout, old, θ) | ✅ | ✅ Seq-TIS | ✅ Correct | Standard |
| `decoupled_seq_is_rs` | Decoupled PPO + RS | 3 (rollout, old, θ) | ✅ | ✅ Seq-MIS | ✅ Correct | Standard |
| `decoupled_geo_rs` | Decoupled PPO + Geo-RS | 3 (rollout, old, θ) | ✅ | Geo-RS only (ratio) | ✅ Correct | Standard |
| `decoupled_geo_rs_token_tis` | Decoupled PPO + Geo RS + Token IS | 3 (rollout, old, θ) | ✅ | ✅ Geo-RS-Token-TIS | ✅ Correct | Standard |
| **Incorrect (for reference)** | | | | | | |
| Naive LLM-RL | Incorrect PPO usage | 2 (old, θ) | ✅ | ❌ | ⚠️ Incorrect | Standard |
Notes:
- In bypass mode, use `loss_type` to select the loss function:
  - `"ppo_clip"` (default): PPO clipped ratio (IS handled by ratio = π_θ/π_rollout; no explicit IS weights, to avoid double-counting)
  - `"reinforce"`: Explicit IS weights applied as $w \cdot \log \pi \cdot A$

These estimators define how IS weights and rejection masks are computed. They are orthogonal to the operating mode (decoupled PPO vs bypass policy gradient) and can be combined with either.
| Estimator | Configuration | Mechanism | Best For |
|---|---|---|---|
| Token-TIS | `rollout_is="token"` | Clips per-token ratios | Lower variance IS with acceptable bias |
| Seq-TIS | `rollout_is="sequence"` | Clips sequence ratio $\rho(\tau) \to \min(\rho(\tau), C)$ | Clean data with moderate mismatch; unbiased |
| Seq-MIS | `rollout_is="sequence"` + `rollout_rs="seq_sum_k1"` | Rejects sequences with $\rho(\tau) > C$ | Severe mismatch; filters "toxic tail" (garbage data) |
| Geo-RS | `rollout_rs="seq_mean_k1"` | Rejects on geometric mean ratio $\exp(\mathbb{E}[\log \rho_t])$ | Length-invariant trust region |
| Geo-RS-Token-TIS | `rollout_is="token"` + `rollout_rs="seq_mean_k1"` | Geometric filter + token IS weights | Ratio-based length normalization + lower variance IS |
| K3-RS | `rollout_rs="seq_mean_k3"` | Rejects on K3 KL divergence | Small KL values; smooth detector |
| K3-RS-Token-TIS | `rollout_is="token"` + `rollout_rs="seq_mean_k3"` | K3 filter + token IS weights | Small KL + lower variance IS |
Note: Each estimator can be used with either:
- Decoupled mode (`bypass_mode=false`): Three policies with PPO clipping
- Bypass mode (`bypass_mode=true`): Two policies with a configurable loss type:
  - `loss_type="ppo_clip"` (default): PPO clipped objective (IS via ratio, RS mask applied)
  - `loss_type="reinforce"`: REINFORCE with explicit IS weights

Choosing estimator by off-policy severity:
- Mild mismatch: Token-TIS or Seq-TIS (weighting keeps all samples)
- Severe mismatch: Seq-MIS or another RS variant (rejection filters the toxic tail)

Choosing estimator by sequence length:
- Long sequences (CoT, agents): prefer length-invariant statistics (Geo-RS, K3-RS, or mean-aggregated K2) over raw sequence products, which suffer from the length trap

Choosing operating mode:
- Need batch size invariance → decoupled mode (`bypass_mode=false`)
- Need speed → bypass mode (`bypass_mode=true`) to skip `old_log_prob` computation
- Want pure policy gradient → bypass mode with `loss_type="reinforce"`

Decoupled mode (computes `old_log_prob` separately):
- Corrects Drift 1 (rollout→old) with IS weights and Drift 2 (old→current) with the PPO ratio

Bypass mode (sets $\pi_{\text{old}} = \pi_{\text{rollout}}$):
- Skips `old_log_prob` computation