
BEMA for Reference Model

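This feature implements BEMA (Bias-Corrected Exponential Moving Average) to update the reference model during DPO training. Standard DPO keeps the reference model frozen at its initial weights. With this feature, the reference model is instead updated during training with BEMA weights of the policy being trained.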
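
The update behind BEMA can be sketched as follows. This sketch assumes the form in which BEMA is usually described:

- `EMA_t = (1 - beta_t) * EMA_{t-1} + beta_t * theta_t`
- `BEMA_t = alpha_t * (theta_t - theta_0) + EMA_t`
- `beta_t = (rho + gamma * t) ** -kappa` and `alpha_t = (rho + gamma * t) ** -eta`

Several details are assumptions rather than TRL's implementation: the function name `bema_update`, its argument names, the default constants, and initializing the EMA at `theta_0`. The code also operates on a flat list of floats. A real callback would apply the same arithmetic to every parameter tensor of the model.

```python
from typing import List, Tuple


def bema_update(
    theta: List[float],
    theta0: List[float],
    ema: List[float],
    step: int,
    ema_power: float = 0.5,   # kappa: decay exponent of the EMA step size (illustrative)
    bias_power: float = 0.2,  # eta: decay exponent of the bias-correction weight (illustrative)
    lag: float = 10.0,        # rho: offset that damps the first few updates (illustrative)
    multiplier: float = 1.0,  # gamma: scales the step counter (illustrative)
) -> Tuple[List[float], List[float]]:
    """One BEMA update on a flat parameter vector (hypothetical helper, not TRL's API).

    theta:  current parameters of the model being trained
    theta0: parameters at the point where averaging started
    ema:    running EMA from the previous step
    Returns (updated EMA, bias-corrected BEMA weights).
    """
    base = lag + multiplier * step
    beta = base ** (-ema_power)    # EMA step size, shrinks as training proceeds
    alpha = base ** (-bias_power)  # bias-correction weight, shrinks more slowly
    new_ema = [(1.0 - beta) * e + beta * t for e, t in zip(ema, theta)]
    # Add back a fraction of the drift from theta0 to offset the EMA's lag
    bema = [alpha * (t - t0) + e for t, t0, e in zip(theta, theta0, new_ema)]
    return new_ema, bema


# Example: one parameter that moved from 0.0 to 1.0, first update (step 0).
# beta = 10 ** -0.5 and alpha = 10 ** -0.2, so BEMA = alpha * 1.0 + beta.
theta0 = [0.0]
ema, bema = bema_update(theta=[1.0], theta0=theta0, ema=list(theta0), step=0)
```

Dropping the `alpha * (t - t0)` term leaves a plain EMA with a decaying step size. That extra term is the bias correction, and it pulls the average back toward the current weights to reduce the EMA's lag.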

Usage

```python
from datasets import load_dataset
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Tiny preference dataset (prompt, chosen, rejected) from the trl-internal-testing org
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

# update_ref_model=True: also use the BEMA weights to update the DPO reference model
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    train_dataset=dataset,
    callbacks=[bema_callback],
)
trainer.train()
```
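
The sketch below shows where the reference model enters DPO training. It computes the standard sigmoid DPO loss for a single preference pair from summed log-probabilities. This is a plain-Python illustration, not TRL's `DPOTrainer` code. `dpo_loss` and its argument names are hypothetical, and `beta=0.1` is an illustrative value. The policy is scored only through its log-ratios against the reference model. Replacing the reference weights during training, which is what `update_ref_model=True` is for, therefore changes the loss even when the policy's log-probabilities stay the same.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # illustrative temperature, not taken from any TRL config
) -> float:
    """Sigmoid DPO loss for one preference pair (hypothetical helper).

    Each *_logp is the summed log-probability of a completion, scored either
    by the policy being trained or by the reference model.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(x)) == log(1 + exp(-x)), written with log1p for accuracy
    return math.log1p(math.exp(-logits))


# Policy scores both completions exactly like the reference: zero margin, loss equals log(2)
same_as_ref = dpo_loss(-10.0, -12.0, -10.0, -12.0)

# Same policy numbers, but the reference now scores the chosen completion lower,
# so the chosen log-ratio becomes positive and the loss drops below log(2)
after_ref_change = dpo_loss(-10.0, -12.0, -11.0, -12.0)
```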

DPOTrainer

[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
    - train
    - save_model
    - push_to_hub

BEMACallback

[[autodoc]] experimental.bema_for_ref_model.BEMACallback