Back to Trl

Post-Training Toolkit Integration

docs/source/ptt_integration.md

1.3.04.4 KB
Original Source

Post-Training Toolkit Integration

Post-Training Toolkit is a diagnostic and observability layer for RLHF training runs. Add one callback to any TRL trainer and get auto-metrics, crash postmortems, and literature-backed heuristics—without writing glue code.

It was built to operationalize the debugging patterns we found most useful when running post-training at scale.

Usage

  1. First, install Post-Training Toolkit:
bash
pip install post-training-toolkit
  1. Add one callback to your trainer. That's it!
<hfoptions id="trainer"> <hfoption id="DPO">
python
from post_training_toolkit import DiagnosticsCallback
from trl import DPOTrainer

trainer = DPOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> <hfoption id="PPO">
python
from post_training_toolkit import DiagnosticsCallback
from trl.experimental.ppo import PPOTrainer

trainer = PPOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> <hfoption id="SFT">
python
from post_training_toolkit import DiagnosticsCallback
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> <hfoption id="ORPO">
python
from post_training_toolkit import DiagnosticsCallback
from trl.experimental.orpo import ORPOTrainer

trainer = ORPOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> <hfoption id="KTO">
python
from post_training_toolkit import DiagnosticsCallback
from trl import KTOTrainer

trainer = KTOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> <hfoption id="CPO">
python
from post_training_toolkit import DiagnosticsCallback
from trl.experimental.cpo import CPOTrainer

trainer = CPOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> <hfoption id="GRPO">
python
from post_training_toolkit import DiagnosticsCallback
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    ...
)
trainer.train()
</hfoption> </hfoptions>

What You Get

Example output:

text
[HIGH] DPO loss stuck at ~0.693 (random chance). Model may not be learning preferences.
       Ref: Rafailov et al. (2023) 'DPO', Section 4.2

[RECOMMENDED] Increase learning rate 2-5x, check data quality, or reduce beta.

Example Demo

See a full working example with auto-stop in action:

📂 demo/live_demo.ipynb

📂 demo/scripts/custom_heuristic.py

1. Auto-Metrics

The callback automatically captures algorithm-specific metrics, backed by the latest research and industry push:

TrainerKey Metrics Captured
DPOloss, win_rate, reward_margin, logps_chosen/rejected
PPOpolicy_loss, value_loss, entropy, clip_fraction, KL
GRPOgroup rewards, advantages, policy loss, KL
SFTloss, perplexity, accuracy
ORPOsft_loss, odds_ratio_loss, log_odds_ratio
KTOkl, logps for desirable/undesirable

2. Crash Postmortems

If training crashes or gets interrupted, you get a postmortem.json with full context:

json
{
  "exit_reason": "exception",
  "last_step": 847,
  "timestamp": "2025-12-17T19:26:04Z",
  "final_metrics": {"dpo_loss": 0.693, "win_rate": 0.52}
}

No more "what step did it die on?"

3. Auto-Stop on Critical Issues

Enable automatic training termination when critical issues are detected:

python
callback = DiagnosticsCallback(stop_on_critical=True)

Distributed Training

Works automatically with multi-GPU setups. Zero configuration needed:

bash
accelerate launch --num_processes 8 train.py

Automatically detects stragglers, aggregates metrics across ranks, and tracks memory balance.