Checkpoint Mechanism for Stage Testing

.ai/skills/parity-testing/checkpoint-mechanism.md

Overview

Pipelines are monolithic __call__ methods -- you can't just call "the encode part". The checkpoint mechanism lets you stop, save, or inject tensors at named locations inside the pipeline.

The Checkpoint class

Add a _checkpoints argument to both the diffusers pipeline and the reference implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False   # capture variables into ckpt.data
    stop: bool = False   # halt pipeline after this point
    load: bool = False   # inject ckpt.data into local variables
    data: dict = field(default_factory=dict)
```
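A sketch of how a stage test might combine the flags, using the checkpoint names from the pipeline instrumentation. The scenario variable names (`encoder_test`, `loop_test`, `first_step_test`) are illustrative only.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False   # capture variables into ckpt.data
    stop: bool = False   # halt pipeline after this point
    load: bool = False   # inject ckpt.data into local variables
    data: dict = field(default_factory=dict)


# Test the text encoder alone: capture its outputs, skip the rest of the pipeline.
encoder_test = {"text_encoding": Checkpoint(save=True, stop=True)}

# Test the denoising loop end to end: record its inputs and its final latents.
loop_test = {
    "preloop": Checkpoint(save=True),
    "post_loop": Checkpoint(save=True),
}

# Debug a single step: record the loop inputs and the first iteration's output, then stop.
first_step_test = {
    "preloop": Checkpoint(save=True),
    "after_step_0": Checkpoint(save=True, stop=True),
}
```

`field(default_factory=dict)` matters here: every `Checkpoint` gets its own `data` dict, so captures from different checkpoints never alias each other.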

Pipeline instrumentation

The pipeline accepts an optional dict[str, Checkpoint]. Place checkpoint calls at boundaries between pipeline stages -- after each encoder, before the denoising loop (capture all loop inputs), after each loop iteration, after the loop (capture final latents before decode).

```python
def __call__(self, prompt, _checkpoints=None):  # other arguments elided
    try:
        # --- text encoding ---
        prompt_embeds = self.text_encoder(prompt)
        _maybe_checkpoint(_checkpoints, "text_encoding", {
            "prompt_embeds": prompt_embeds,
        })

        # --- prepare latents, sigmas, positions ---
        latents = self.prepare_latents(...)
        sigmas = self.scheduler.sigmas
        # ...

        _maybe_checkpoint(_checkpoints, "preloop", {
            "latents": latents,
            "sigmas": sigmas,
            "prompt_embeds": prompt_embeds,
            "prompt_attention_mask": prompt_attention_mask,
            "video_coords": video_coords,
            # capture EVERYTHING the loop needs -- every tensor the transformer
            # forward() receives. Missing even one variable here means you can't
            # tell if it's the source of divergence during denoise debugging.
        })

        # --- denoising loop ---
        for i, t in enumerate(timesteps):
            noise_pred = self.transformer(latents, t, prompt_embeds, ...)
            latents = self.scheduler.step(noise_pred, t, latents)[0]

            _maybe_checkpoint(_checkpoints, f"after_step_{i}", {
                "latents": latents,
            })

        _maybe_checkpoint(_checkpoints, "post_loop", {
            "latents": latents,
        })

        # --- decode ---
        video = self.vae.decode(latents)
        return video
    except PipelineStop:
        # a checkpoint with stop=True was reached; the captured data is in _checkpoints
        return None
```
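To see the control flow without loading a model, here is a runnable toy version of the instrumented `__call__`. `ToyPipeline` is hypothetical and uses Python floats where the real pipeline has tensors; the checkpoint names, the helper, and returning `None` when `PipelineStop` is caught follow the text.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


class PipelineStop(Exception):
    """Raised by _maybe_checkpoint when a checkpoint has stop=True."""


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update(data)
    if ckpt.stop:
        raise PipelineStop(name)


class ToyPipeline:
    """Hypothetical stand-in for a diffusers pipeline: floats instead of tensors."""

    def __init__(self, num_steps=3):
        self.num_steps = num_steps

    def __call__(self, prompt, _checkpoints=None):
        try:
            # --- "text encoding": a fake one-number embedding ---
            prompt_embeds = float(len(prompt))
            _maybe_checkpoint(_checkpoints, "text_encoding", {"prompt_embeds": prompt_embeds})

            # --- prepare loop inputs ---
            latents = 1.0
            _maybe_checkpoint(_checkpoints, "preloop", {
                "latents": latents,
                "prompt_embeds": prompt_embeds,
            })

            # --- "denoising" loop ---
            for i in range(self.num_steps):
                latents = 0.5 * latents + prompt_embeds
                _maybe_checkpoint(_checkpoints, f"after_step_{i}", {"latents": latents})

            _maybe_checkpoint(_checkpoints, "post_loop", {"latents": latents})

            # --- "decode" ---
            return 2.0 * latents
        except PipelineStop:
            return None
```

For example, `ToyPipeline()("cat", _checkpoints={"preloop": Checkpoint(save=True, stop=True)})` returns `None` and leaves `{"latents": 1.0, "prompt_embeds": 3.0}` in the checkpoint's `data`; checkpoints after the stop point are never reached.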

The helper function

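Each _maybe_checkpoint call acts on the Checkpoint's flags: save copies the passed variables into ckpt.data, and stop halts execution by raising an exception that __call__ catches. Because save runs before stop, a checkpoint with both flags set captures its data and then halts. load is not handled here: a helper function cannot rebind its caller's local variables, so injection happens inline at each call site, as described under Injection support.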

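```python
class PipelineStop(Exception):
    """Raised by _maybe_checkpoint to halt the pipeline at a stop checkpoint."""


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update(data)  # save first, so save+stop still captures
    if ckpt.stop:
        raise PipelineStop(name)  # caught at __call__ level, returns None
```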

Injection support

Add load support at each checkpoint where you might want to inject:

```python
_maybe_checkpoint(_checkpoints, "preloop", {"latents": latents})  # plus the other loop inputs

# Load support: replace local variables with injected data
if _checkpoints:
    ckpt = _checkpoints.get("preloop")
    if ckpt is not None and ckpt.load:
        # match the pipeline's device and dtype; repeat for every injected variable
        latents = ckpt.data["latents"].to(device=device, dtype=latents.dtype)
```
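The point of injection is to isolate a stage from upstream divergence. In this runnable toy, `toy_denoise` is a hypothetical stand-in whose initial latents depend on a seed (as if the two pipelines sampled noise differently) while its loop is deterministic. Two runs with different seeds disagree, but once the reference run's `preloop` data is injected, the loop outputs match exactly.

```python
import random
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


class PipelineStop(Exception):
    pass


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update(data)
    if ckpt.stop:
        raise PipelineStop(name)


def toy_denoise(seed, num_steps=2, _checkpoints=None):
    """Hypothetical stand-in: seed-dependent initial latents, deterministic loop."""
    rng = random.Random(seed)
    latents = [rng.random() for _ in range(3)]
    _maybe_checkpoint(_checkpoints, "preloop", {"latents": latents})

    # Load support, following the inline pattern: rebind the local variable.
    if _checkpoints:
        ckpt = _checkpoints.get("preloop")
        if ckpt is not None and ckpt.load:
            latents = list(ckpt.data["latents"])  # copy so the loop cannot alter the injected data

    for _ in range(num_steps):
        latents = [0.5 * x for x in latents]
    return latents
```

With this ordering, setting save=True together with load=True on the same checkpoint would defeat injection: the save step overwrites the injected entries with the pipeline's own values before the load code reads them.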

Key insight

The checkpoint dict is passed into the pipeline and mutated in place. After the pipeline returns (or stops early), you read back ckpt.data to get the captured tensors. Both pipelines save under their own key names, so the test maps between them (e.g. reference "video_state.latent" -> diffusers "latents").
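A sketch of the mapping step, assuming flat lists of floats in place of tensors. `KEY_MAP` and `compare_captures` are hypothetical names, and only the `"video_state.latent"` -> `"latents"` pair comes from the text.

```python
# Hypothetical mapping from reference key names to diffusers key names.
KEY_MAP = {"video_state.latent": "latents"}


def compare_captures(ref_data, diffusers_data, key_map, atol=1e-5):
    """Return ({diffusers_key: max_abs_diff}, [keys whose diff exceeds atol])."""
    report = {}
    for ref_key, diff_key in key_map.items():
        a, b = ref_data[ref_key], diffusers_data[diff_key]
        if len(a) != len(b):
            raise ValueError(f"size mismatch for {ref_key} -> {diff_key}: {len(a)} vs {len(b)}")
        report[diff_key] = max((abs(x - y) for x, y in zip(a, b)), default=0.0)
    failing = [key for key, diff in report.items() if diff > atol]
    return report, failing
```

With real tensors the inner computation becomes `(a - b).abs().max().item()`, or `torch.testing.assert_close(b, a, atol=..., rtol=...)` when a pass/fail check is enough. Reporting the per-key maximum difference makes it easy to see which captured variable diverges first.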

Memory management for large models

For large models, free the source pipeline's GPU memory before loading the target pipeline. Clone injected tensors to CPU, delete everything else, then run the target with enable_model_cpu_offload().
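A minimal sketch of that handoff, assuming PyTorch tensors in ckpt.data. `detach_to_cpu` and `release_gpu_memory` are hypothetical helper names, and the pipeline variables in the trailing comment are placeholders with the loading code elided.

```python
import gc

import torch


def detach_to_cpu(captured: dict) -> dict:
    """Return a CPU copy of a checkpoint's data that outlives the source pipeline."""
    return {
        # .detach() drops autograd history, .cpu() moves the tensor off the GPU,
        # and .clone() ensures the copy owns its storage even if it was already on CPU.
        key: value.detach().cpu().clone() if torch.is_tensor(value) else value
        for key, value in captured.items()
    }


def release_gpu_memory() -> None:
    """Collect unreachable Python objects, then release PyTorch's cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Intended flow (placeholder names):
#   injected = detach_to_cpu(ref_checkpoints["preloop"].data)
#   del ref_pipe, ref_checkpoints      # drop every reference to source GPU tensors
#   release_gpu_memory()
#   pipe = ...                         # load the target pipeline
#   pipe.enable_model_cpu_offload()
#   pipe(prompt, _checkpoints={"preloop": Checkpoint(load=True, data=injected)})
```

`torch.cuda.empty_cache()` can only return memory that no live tensor still uses, so the `del` step has to remove every remaining reference (including the checkpoint dicts) before the cache is cleared.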