Complete guide to SimPO loss functions and mathematical formulations.
SimPO supports two loss types: sigmoid (the default) and hinge.
Both are reference-free (no reference model needed).
Step 1: Log probability ratio:
pi_logratios = log P_θ(y_chosen|x) - log P_θ(y_rejected|x)
Step 2: Apply target margin:
logits = pi_logratios - γ/β
Where `gamma_beta_ratio` is the target margin γ/β.

Step 3: Compute the loss (depends on the loss type).
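Expanding the scaled logit shows where the target margin ends up: the β in γ/β cancels, leaving the plain margin γ that appears in the SimPO formula.

```latex
\beta \cdot \text{logits}
  = \beta\left(\log P_\theta(y_{\text{chosen}}\mid x) - \log P_\theta(y_{\text{rejected}}\mid x) - \frac{\gamma}{\beta}\right)
  = \beta\left(\log P_\theta(y_{\text{chosen}}\mid x) - \log P_\theta(y_{\text{rejected}}\mid x)\right) - \gamma
```

For example, β = 2.0 with `gamma_beta_ratio` = 0.5 corresponds to an effective margin γ = 1.0.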
Sigmoid loss (default) formula:
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
Where `beta` (β) is the reward scaling and `label_smoothing` (ε) defaults to 0.0.

Implementation:
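```python
import torch.nn.functional as F

# self.beta = β, self.label_smoothing = ε; logits comes from Step 2
losses = (
    -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
    - F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
```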
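For a quick check without PyTorch, here is a minimal scalar sketch of the same formula in plain Python. `log_sigmoid` and `simpo_sigmoid_loss` are hypothetical helper names, and `log_sigmoid` plays the role of `F.logsigmoid` using the usual numerically stable split on the sign of the input.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # x >= 0: log σ(x) = -log(1 + e^{-x});  x < 0: log σ(x) = x - log(1 + e^{x})
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def simpo_sigmoid_loss(logits: float, beta: float, label_smoothing: float = 0.0) -> float:
    """Label-smoothed sigmoid loss for one preference pair.

    `logits` is pi_logratios - gamma_beta_ratio (Step 2).
    """
    return (
        -log_sigmoid(beta * logits) * (1 - label_smoothing)
        - log_sigmoid(-beta * logits) * label_smoothing
    )


print(round(simpo_sigmoid_loss(0.8, beta=2.0), 3))                       # 0.184
print(round(simpo_sigmoid_loss(0.8, beta=2.0, label_smoothing=0.1), 3))  # 0.344
```

Even for a correctly ordered pair, ε = 0.1 raises the loss, because the smoothing term penalizes confidence in the preferred direction.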
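Characteristics: smooth and differentiable everywhere; with ε = 0 it is the standard SimPO objective, and every pair contributes a gradient that shrinks as its margin grows.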
Hinge loss formula:
L = max(0, 1 - β * logits)
Implementation:
```python
import torch
losses = torch.relu(1 - self.beta * logits)  # max(0, 1 - β · logits)
```
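A scalar sketch of the hinge loss in plain Python (`simpo_hinge_loss` is a hypothetical name) shows that a pair already past the margin contributes nothing:

```python
def simpo_hinge_loss(logits: float, beta: float) -> float:
    """Hinge loss for one preference pair: max(0, 1 - beta * logits)."""
    return max(0.0, 1.0 - beta * logits)


# logits = 0.8 (as in the worked example below) already clears the margin at beta = 2.0
print(simpo_hinge_loss(0.8, beta=2.0))  # 0.0
# A pair with a smaller log-ratio still incurs loss
print(simpo_hinge_loss(0.2, beta=2.0))  # 0.6
```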
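Characteristics: pairs whose scaled margin β · logits is at least 1 contribute zero loss and zero gradient, so updates concentrate on pairs that still violate the margin.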
DPO (reference-based) formula:
L_DPO = -E[log σ(β * log(π_θ(y_w|x)/π_ref(y_w|x)) - β * log(π_θ(y_l|x)/π_ref(y_l|x)))]
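Key features: requires a reference model π_ref (typically a frozen copy of the starting policy) whose log-probabilities must be computed, or precomputed and cached, for every pair; the implicit reward is the policy-to-reference log-ratio.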
SimPO (reference-free) formula:
L_SimPO = -log σ(β * (log π_θ(y_w|x) - log π_θ(y_l|x) - γ/β))
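Key features: no reference model; the implicit reward is the policy's own log-probability (length-normalized when `average_log_prob=True`), and the target margin γ pushes the chosen reward to exceed the rejected reward by at least γ.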
Visual comparison:
DPO: [Policy] - [Reference] → Loss
SimPO: [Policy] → Loss
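To make the structural difference concrete, this sketch computes the argument of σ for each objective from per-sequence log-probabilities. The function names and the reference log-probabilities are hypothetical; only `dpo_logit` needs `ref_w`/`ref_l`, which is the reference-model dependency SimPO removes.

```python
import math


def dpo_logit(pol_w: float, pol_l: float, ref_w: float, ref_l: float, beta: float) -> float:
    """Argument of σ in the DPO loss: rewards are log-ratios against the reference."""
    return beta * (pol_w - ref_w) - beta * (pol_l - ref_l)


def simpo_logit(pol_w: float, pol_l: float, beta: float, gamma_beta_ratio: float) -> float:
    """Argument of σ in the SimPO loss: policy log-probs only, minus the target margin."""
    return beta * (pol_w - pol_l - gamma_beta_ratio)


def neg_log_sigmoid(x: float) -> float:
    # -log σ(x) = log(1 + e^{-x}), written stably for either sign of x
    return math.log1p(math.exp(-x)) if x >= 0 else -x + math.log1p(math.exp(x))


# Policy log-probs from the worked example below; reference log-probs are made up
policy_w, policy_l = -1.2, -2.5
ref_w, ref_l = -1.5, -2.4
print(round(dpo_logit(policy_w, policy_l, ref_w, ref_l, beta=0.1), 3))          # 0.04
print(round(simpo_logit(policy_w, policy_l, beta=2.0, gamma_beta_ratio=0.5), 3))  # 1.6
print(round(neg_log_sigmoid(simpo_logit(policy_w, policy_l, 2.0, 0.5)), 3))      # 0.184
```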
Per-token log probabilities:
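```python
import torch

# Assumes logits/labels are already aligned for next-token prediction
loss_mask = labels != label_pad_token_id  # ignore padding (and prompt) tokens
safe_labels = labels.masked_fill(~loss_mask, 0)  # gather needs valid indices at pad positions
per_token_logps = torch.gather(
    logits.log_softmax(-1), dim=-1, index=safe_labels.unsqueeze(-1)
).squeeze(-1)
```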
Average log probability (if `average_log_prob=True`):
```python
avg_logp = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
```
Sum log probability (if `average_log_prob=False`):
```python
sum_logp = (per_token_logps * loss_mask).sum(-1)
```
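Why average? Summed log-probabilities grow more negative with every additional token, so a sum-based reward depends on response length as well as quality. Averaging over the non-padding tokens (length normalization) removes most of that dependence.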
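The two reductions can be compared in plain Python, with hypothetical lists standing in for tensors and a hypothetical `masked_logp` helper. The summed score penalizes the longer response purely for its length, while the average does not.

```python
def masked_logp(per_token_logps, loss_mask, average=True):
    """Sum or average per-token log-probs over positions where loss_mask is True."""
    kept = [lp for lp, keep in zip(per_token_logps, loss_mask) if keep]
    total = sum(kept)
    return total / len(kept) if average else total


# Two hypothetical responses with identical per-token quality (-1.0 per token)
short_resp = ([-1.0] * 5 + [0.0] * 5, [True] * 5 + [False] * 5)  # 5 real tokens, 5 padding
long_resp = ([-1.0] * 10, [True] * 10)                           # 10 real tokens

print(masked_logp(*short_resp, average=False), masked_logp(*long_resp, average=False))  # -5.0 -10.0
print(masked_logp(*short_resp), masked_logp(*long_resp))  # -1.0 -1.0
```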
Chosen reward:
```python
chosen_rewards = beta * policy_chosen_logps.detach()
```
Rejected reward:
```python
rejected_rewards = beta * policy_rejected_logps.detach()
```
Reward margin:
```python
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
```
Label smoothing (sigmoid loss only):
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
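Effect: ε moves a small amount of weight onto the reversed preference, so the loss no longer rewards pushing the margin toward infinity; its minimum sits at a finite β · logits = log((1 - ε)/ε).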
When to use: when a fraction of the preference labels may be noisy or flipped.
Config:
```yaml
label_smoothing: 0.1  # 10% smoothing
```
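With ε > 0 the smoothed loss has a finite minimizer: setting its derivative to zero gives σ(β · logits) = 1 - ε, i.e. β · logits = log((1 - ε)/ε). This plain-Python sketch (`smoothed_loss` is a hypothetical helper) evaluates the loss around that point for ε = 0.1.

```python
import math


def smoothed_loss(scaled_logit: float, eps: float) -> float:
    """Label-smoothed sigmoid loss as a function of beta * logits."""
    # -log σ(x) = log(1 + e^{-x}); log1p/exp is fine for the moderate x values used here
    return (1 - eps) * math.log1p(math.exp(-scaled_logit)) + eps * math.log1p(math.exp(scaled_logit))


eps = 0.1
best = math.log((1 - eps) / eps)  # analytic minimizer, ln 9 ≈ 2.197
for x in (0.0, 1.0, best, 4.0, 8.0):
    print(f"beta*logits = {x:5.3f}  loss = {smoothed_loss(x, eps):.4f}")
```

The printed loss falls until β · logits ≈ 2.197 and rises afterwards; with ε = 0 it would keep decreasing toward zero instead.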
With SFT component:
L_total = L_SimPO + λ * L_SFT
Where λ is `sft_weight` (0.0 to 1.0).

Implementation:
```python
if self.sft_weight > 0:
    sft_loss = -policy_chosen_logps  # negative log-likelihood of the chosen response
    total_loss = simpo_loss + self.sft_weight * sft_loss
```
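When to use: when the likelihood of the chosen responses tends to drop during preference training, or when outputs should stay close to the chosen responses.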
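Trade-off: a larger `sft_weight` anchors the model more strongly to imitating the chosen responses, which can dilute the preference signal from the SimPO term.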
Config:
```yaml
sft_weight: 0.1  # 10% SFT regularization
```
| Aspect | Sigmoid | Hinge |
|---|---|---|
| Smoothness | Smooth | Non-smooth |
| Gradients | Continuous | Discontinuous at margin |
| Sparsity | Dense solutions | Sparse solutions |
| Interpretability | Probabilistic | Geometric margin |
| Use case | General purpose | Margin-based tasks |
| Recommendation | Default choice | Experimental |
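Evaluating both losses over a few hypothetical logits values (β = 2.0, no label smoothing) makes the Smoothness and Sparsity rows concrete: the hinge loss is exactly zero once β · logits ≥ 1, while the sigmoid loss only approaches zero. `compare_losses` is a hypothetical helper.

```python
import math


def compare_losses(beta, logit_values):
    """Return (logits, sigmoid_loss, hinge_loss) rows, with label_smoothing = 0."""
    rows = []
    for z in logit_values:
        sigmoid_loss = math.log1p(math.exp(-beta * z))  # -log σ(β · logits)
        hinge_loss = max(0.0, 1.0 - beta * z)           # max(0, 1 - β · logits)
        rows.append((z, sigmoid_loss, hinge_loss))
    return rows


for z, s, h in compare_losses(beta=2.0, logit_values=(-0.5, 0.0, 0.25, 0.5, 1.0, 2.0)):
    print(f"logits={z:5.2f}  sigmoid={s:.4f}  hinge={h:.2f}")
```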
Config:
```yaml
# Sigmoid (default)
loss_type: sigmoid

# Hinge (alternative): use instead of the line above
# loss_type: hinge
```
Sigmoid loss gradient:
```
∂L/∂logits = -β * σ(-β * logits) * (1 - ε) + β * σ(β * logits) * ε
```
Hinge loss gradient:
```
∂L/∂logits = -β   if logits < 1/β
              0   otherwise
```
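A finite-difference check in plain Python (all helper names are hypothetical) confirms the sigmoid gradient formula and shows the hinge gradient switching off at logits = 1/β, which is 0.5 for β = 2.0.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_loss(z: float, beta: float, eps: float) -> float:
    # -log σ(βz)·(1-ε) - log σ(-βz)·ε, written with log1p for moderate inputs
    return (1 - eps) * math.log1p(math.exp(-beta * z)) + eps * math.log1p(math.exp(beta * z))


def sigmoid_grad(z: float, beta: float, eps: float) -> float:
    # Analytic gradient from the formula above
    return -beta * sigmoid(-beta * z) * (1 - eps) + beta * sigmoid(beta * z) * eps


def hinge_grad(z: float, beta: float) -> float:
    return -beta if z < 1.0 / beta else 0.0


beta, eps, h = 2.0, 0.1, 1e-6
for z in (-0.5, 0.3, 0.8, 1.5):
    # Central difference approximates the derivative of the loss at z
    numeric = (sigmoid_loss(z + h, beta, eps) - sigmoid_loss(z - h, beta, eps)) / (2 * h)
    print(f"z={z:4.1f}  analytic={sigmoid_grad(z, beta, eps):+.5f}  "
          f"numeric={numeric:+.5f}  hinge={hinge_grad(z, beta):+.1f}")
```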
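Implications:
Sigmoid: with ε = 0 the gradient is nonzero for every pair, and its magnitude β · σ(-β · logits) decays smoothly as the margin grows, so well-separated pairs still contribute a little.
Hinge: every pair with logits < 1/β receives the same gradient magnitude β, and pairs past that margin receive none.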
Worked example 1 (sigmoid loss) config:
```yaml
beta: 2.0
gamma_beta_ratio: 0.5
loss_type: sigmoid
label_smoothing: 0.0
sft_weight: 0.0
```
Loss calculation:
```
# Step 1: Compute average log probs under the policy
chosen_logps   = avg_log_prob(policy(chosen))    # e.g., -1.2
rejected_logps = avg_log_prob(policy(rejected))  # e.g., -2.5

# Step 2: Log ratio and margin
pi_logratios = -1.2 - (-2.5) = 1.3
logits       = 1.3 - 0.5 = 0.8

# Step 3: Sigmoid loss
loss = -log(sigmoid(2.0 * 0.8))
     = -log(sigmoid(1.6))
     = -log(0.832)
     = 0.184
```
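The same arithmetic as runnable Python (standard library only), reproducing the 0.184 above:

```python
import math

beta, gamma_beta_ratio = 2.0, 0.5
chosen_logps, rejected_logps = -1.2, -2.5  # average log-probs from the example

pi_logratios = chosen_logps - rejected_logps  # 1.3
logits = pi_logratios - gamma_beta_ratio      # 0.8
loss = -math.log(1.0 / (1.0 + math.exp(-beta * logits)))  # -log σ(β · logits)
print(f"{loss:.3f}")  # 0.184
```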
Worked example 2 (sigmoid loss + SFT) config:
```yaml
beta: 2.5
gamma_beta_ratio: 0.5
loss_type: sigmoid
sft_weight: 0.1
```
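Loss calculation:
```
# SimPO loss (same log probs as above, but beta = 2.5)
simpo_loss = -log(sigmoid(2.5 * 0.8))
           = -log(sigmoid(2.0))
           = -log(0.881)
           = 0.127
# SFT loss
sft_loss = -chosen_logps = -(-1.2) = 1.2
# Total loss
total_loss = simpo_loss + 0.1 * sft_loss
           = 0.127 + 0.12
           = 0.247
```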
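Recomputing this example in Python with the config's β = 2.5 (the first example used β = 2.0, so its SimPO term cannot be reused directly):

```python
import math

beta, gamma_beta_ratio, sft_weight = 2.5, 0.5, 0.1
chosen_logps, rejected_logps = -1.2, -2.5

logits = (chosen_logps - rejected_logps) - gamma_beta_ratio  # 0.8
simpo_loss = math.log1p(math.exp(-beta * logits))            # -log σ(2.5 * 0.8) = -log σ(2.0)
sft_loss = -chosen_logps                                     # 1.2
total_loss = simpo_loss + sft_weight * sft_loss
print(f"simpo={simpo_loss:.3f} sft={sft_loss:.1f} total={total_loss:.3f}")  # simpo=0.127 sft=1.2 total=0.247
```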
Low margin (< 0.5):
High margin (> 5.0):
Monitor:
```python
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
print(f"Reward margin: {reward_margin:.2f}")
```
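The batch-mean margin can hide individual pairs that are ranked the wrong way. A common companion metric (TRL's preference trainers log a similar `rewards/accuracies` value) is the fraction of pairs whose chosen reward beats the rejected one. `preference_stats` is a hypothetical plain-Python helper that sketches both, using made-up rewards.

```python
from statistics import mean


def preference_stats(chosen_rewards, rejected_rewards):
    """Batch reward margin plus the fraction of pairs ranked correctly."""
    margin = mean(chosen_rewards) - mean(rejected_rewards)
    accuracy = mean(1.0 if c > r else 0.0 for c, r in zip(chosen_rewards, rejected_rewards))
    return margin, accuracy


# Hypothetical rewards (beta * average log-prob) for four pairs; the third pair is misranked
chosen = [-2.4, -1.8, -3.0, -2.0]
rejected = [-5.0, -2.2, -2.8, -4.0]
margin, accuracy = preference_stats(chosen, rejected)
print(f"Reward margin: {margin:.2f}  accuracy: {accuracy:.2f}")  # Reward margin: 1.20  accuracy: 0.75
```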
Typical values:
Warning signs: