.ai/skills/parity-testing/checkpoint-mechanism.md
Pipelines are monolithic `__call__` methods -- you can't just call "the encode part" on its own. The checkpoint mechanism lets you stop, save, or inject tensors at named locations inside the pipeline.
Add a `_checkpoints` argument to both the diffusers pipeline and the reference implementation.
```python
from dataclasses import dataclass, field

@dataclass
class Checkpoint:
    save: bool = False  # capture variables into ckpt.data
    stop: bool = False  # halt pipeline after this point
    load: bool = False  # inject ckpt.data into local variables
    data: dict = field(default_factory=dict)
```
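As a usage sketch, a test might build the dict up front and name only the checkpoints it needs. The dataclass is repeated so the block runs on its own, and the checkpoint names match the pipeline sketch that follows. Note the `field(default_factory=dict)`: it gives every `Checkpoint` its own `data` dict, whereas a plain `= {}` default is rejected by `dataclasses` with a `ValueError`.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False  # capture variables into ckpt.data
    stop: bool = False  # halt pipeline after this point
    load: bool = False  # inject ckpt.data into local variables
    data: dict = field(default_factory=dict)


# Capture the text-encoder output, then capture every loop input and stop
# before the (expensive) denoising loop runs.
checkpoints = {
    "text_encoding": Checkpoint(save=True),
    "preloop": Checkpoint(save=True, stop=True),
}

# default_factory creates a fresh dict per instance, so captures never mix.
assert checkpoints["text_encoding"].data is not checkpoints["preloop"].data
```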
The pipeline accepts an optional `dict[str, Checkpoint]` keyed by checkpoint name. Place checkpoint calls at the boundaries between pipeline stages:

- after each encoder
- before the denoising loop (capture all loop inputs)
- after each loop iteration
- after the loop (capture the final latents before decode)
```python
def __call__(self, prompt, ..., _checkpoints=None):
    # --- text encoding ---
    prompt_embeds = self.text_encoder(prompt)
    _maybe_checkpoint(_checkpoints, "text_encoding", {
        "prompt_embeds": prompt_embeds,
    })

    # --- prepare latents, sigmas, positions ---
    latents = self.prepare_latents(...)
    sigmas = self.scheduler.sigmas
    # ...
    _maybe_checkpoint(_checkpoints, "preloop", {
        "latents": latents,
        "sigmas": sigmas,
        "prompt_embeds": prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "video_coords": video_coords,
        # capture EVERYTHING the loop needs -- every tensor the transformer
        # forward() receives. Missing even one variable here means you can't
        # tell whether it's the source of divergence during denoise debugging.
    })

    # --- denoising loop ---
    for i, t in enumerate(timesteps):
        noise_pred = self.transformer(latents, t, prompt_embeds, ...)
        latents = self.scheduler.step(noise_pred, t, latents)[0]
        _maybe_checkpoint(_checkpoints, f"after_step_{i}", {
            "latents": latents,
        })

    _maybe_checkpoint(_checkpoints, "post_loop", {
        "latents": latents,
    })

    # --- decode ---
    video = self.vae.decode(latents)
    return video
```
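Because the per-iteration checkpoints follow the naming pattern `after_step_{i}`, a test can request them programmatically. The `per_step_checkpoints` helper below is hypothetical (it is not part of either pipeline); it is a minimal sketch that assumes the `Checkpoint` dataclass defined earlier, repeated here so the block is self-contained.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


def per_step_checkpoints(num_steps, stop_after=None):
    """Request a save at every after_step_{i} checkpoint (hypothetical helper).

    If stop_after is given, the pipeline halts right after that step, which
    skips the remaining denoising steps and the VAE decode.
    """
    ckpts = {f"after_step_{i}": Checkpoint(save=True) for i in range(num_steps)}
    if stop_after is not None:
        if not 0 <= stop_after < num_steps:
            raise ValueError(f"stop_after={stop_after} is outside 0..{num_steps - 1}")
        ckpts[f"after_step_{stop_after}"].stop = True
    return ckpts
```

Saving at every step keeps a reference to each step's latents in its `ckpt.data`, so for long schedules it may be cheaper to request only the steps around a suspected divergence.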
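Each `_maybe_checkpoint` call acts on two of the `Checkpoint`'s flags: `save` copies the passed-in variables into `ckpt.data`, and `stop` halts execution by raising an exception that is caught at the top level of `__call__`. The third flag, `load`, can't be handled inside the helper, because a function can't reassign its caller's local variables; injection is written out at each call site instead.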
```python
class PipelineStop(Exception):
    """Raised to halt the pipeline at a checkpoint with stop=True."""


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update(data)
    if ckpt.stop:
        raise PipelineStop  # caught at __call__ level, returns None
```
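The helper relies on `__call__` catching `PipelineStop`, but the pipeline sketch doesn't show where that happens. One way to wire it up is the self-contained toy below, in which `ToyPipeline`, its `_run` method, and the halving "denoiser" are all hypothetical stand-ins; a real pipeline body would live in `_run`. Splitting the body out keeps the `try`/`except` from wrapping hundreds of lines.

```python
from dataclasses import dataclass, field


class PipelineStop(Exception):
    """Raised by _maybe_checkpoint to halt the pipeline early."""


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update(data)
    if ckpt.stop:
        raise PipelineStop(name)


class ToyPipeline:
    """Stand-in for a real pipeline: each 'denoising' step halves the latents."""

    def __call__(self, latents, num_steps=3, _checkpoints=None):
        try:
            return self._run(latents, num_steps, _checkpoints)
        except PipelineStop:
            return None  # stopped early; captured data lives in _checkpoints

    def _run(self, latents, num_steps, _checkpoints):
        _maybe_checkpoint(_checkpoints, "preloop", {"latents": latents})
        for i in range(num_steps):
            latents = latents * 0.5
            _maybe_checkpoint(_checkpoints, f"after_step_{i}", {"latents": latents})
        _maybe_checkpoint(_checkpoints, "post_loop", {"latents": latents})
        return latents
```

With `{"after_step_1": Checkpoint(save=True, stop=True)}`, calling `ToyPipeline()(8.0, _checkpoints=...)` returns `None` and leaves `{"latents": 2.0}` in that checkpoint's `data`.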
Add `load` support at each checkpoint where you might want to inject data:
```python
_maybe_checkpoint(_checkpoints, "preloop", {"latents": latents, ...})
# Load support: replace local variables with injected data.
# `device` is the pipeline's execution device (e.g. self._execution_device).
if _checkpoints:
    ckpt = _checkpoints.get("preloop")
    if ckpt is not None and ckpt.load:
        latents = ckpt.data["latents"].to(device=device, dtype=latents.dtype)
        # ...repeat for every other injected variable
```
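Writing this `if` block out for every injected variable gets repetitive. Because a function can't reassign its caller's local variables, one option is a hypothetical `_maybe_load` helper that *returns* the injected value (or the current one), so each injection point becomes a single assignment. This is a minimal sketch: `toy_pipeline` is a stand-in, and the values are plain floats rather than tensors.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


def _maybe_load(checkpoints, name, key, current):
    """Return the injected value for `key` at checkpoint `name`, else `current`.

    Hypothetical helper: the call site does `x = _maybe_load(..., x)`.
    A missing key on a load checkpoint raises KeyError, since it usually
    means the test's key mapping is wrong.
    """
    if not checkpoints:
        return current
    ckpt = checkpoints.get(name)
    if ckpt is None or not ckpt.load:
        return current
    return ckpt.data[key]


def toy_pipeline(latents, _checkpoints=None):
    """Stand-in pipeline: one 'preloop' injection point, then one step."""
    latents = _maybe_load(_checkpoints, "preloop", "latents", latents)
    return latents * 0.5
```

With tensors, the device and dtype conversion can stay on the same line: `latents = _maybe_load(_checkpoints, "preloop", "latents", latents).to(device=device, dtype=latents.dtype)`. The right-hand side is evaluated before the assignment, so `latents.dtype` still refers to the pipeline's own tensor, and `Tensor.to` returns the tensor itself when it already has the requested device and dtype.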
The checkpoint dict is passed into the pipeline and mutated in place. After the pipeline returns (or stops early), read back `ckpt.data` to get the captured tensors. Each pipeline saves under its own key names, so the test maps between them (e.g. reference `"video_state.latent"` -> diffusers `"latents"`).
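Once both pipelines have run, the comparison is a loop over the key map. A minimal sketch with a hypothetical `compare_checkpoints` helper, assuming captured values have been flattened to lists of floats so the example runs without PyTorch; with real tensors the per-key metric would be `(a - b).abs().max()`, after casting both tensors to a common dtype and device.

```python
def compare_checkpoints(ref_data, diff_data, key_map, atol=1e-5):
    """Compare captured values across two pipelines (hypothetical helper).

    key_map maps reference key names to diffusers key names, e.g.
    {"video_state.latent": "latents"}. Values are assumed to be flat
    sequences of floats. Returns {diffusers_key: (max_abs_diff, within_atol)}.
    """
    report = {}
    for ref_key, diff_key in key_map.items():
        a, b = ref_data[ref_key], diff_data[diff_key]
        if len(a) != len(b):
            raise ValueError(f"{ref_key} vs {diff_key}: length {len(a)} != {len(b)}")
        max_diff = max((abs(x - y) for x, y in zip(a, b)), default=0.0)
        report[diff_key] = (max_diff, max_diff <= atol)
    return report
```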
For large models, free the source pipeline's GPU memory before loading the target pipeline: clone the tensors you will inject to CPU, delete everything else, then run the target with `enable_model_cpu_offload()`.
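The handoff can be sketched as two hypothetical helpers: `snapshot_to_cpu` copies the tensors you plan to inject, and `free_accelerator_memory` runs the garbage collector and then `torch.cuda.empty_cache()`. Deleting the source pipeline itself has to happen in the caller, since only the caller holds those references. The sketch treats anything with a `.detach()` method as a tensor, an assumption made so it also runs without PyTorch installed.

```python
import gc


def snapshot_to_cpu(data):
    """Copy captured values to CPU so the source pipeline can be freed.

    Tensor-like values (anything with .detach) are detached, moved to CPU and
    cloned; clone() guarantees a fresh copy even if a value was already on CPU.
    Other values pass through unchanged.
    """
    snapshot = {}
    for key, value in data.items():
        if hasattr(value, "detach"):
            value = value.detach().to("cpu").clone()
        snapshot[key] = value
    return snapshot


def free_accelerator_memory():
    """Collect unreachable objects, then release PyTorch's cached CUDA blocks."""
    gc.collect()
    try:
        import torch
    except ImportError:  # PyTorch not installed: nothing more to release
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Usage sketch (variable names are hypothetical):
#   injected = snapshot_to_cpu(ref_checkpoints["preloop"].data)
#   del ref_pipe, ref_checkpoints   # drop every reference to the source run
#   free_accelerator_memory()
#   pipe.enable_model_cpu_offload()
#   ckpts = {"preloop": Checkpoint(load=True, data=injected)}
```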