docs_new/docs/sglang-diffusion/support_new_models.mdx
This document explains how to add support for new diffusion models in SGLang Diffusion.
SGLang Diffusion is engineered for both performance and flexibility, built upon a pipeline architecture. This design allows developers to construct pipelines for various diffusion models while keeping the core generation loop standardized for optimization.
At its core, the architecture revolves around two key concepts, as highlighted in our blog post:
- `ComposedPipeline`: This class orchestrates a series of PipelineStages to define the complete generation process for a specific model. It acts as the main entry point for a model and manages the data flow between the different stages of the diffusion process.
- `PipelineStage`: Each stage is a modular component that encapsulates a function within the diffusion process. Examples include prompt encoding, the denoising loop, or VAE decoding.

SGLang Diffusion supports two pipeline composition styles. Both are valid; choose the one that best fits your model.
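To make the relationship between the two concepts concrete, here is a minimal, dependency-free sketch. The class bodies, stage names, and dict-based batch below are purely illustrative and not the framework's actual base classes:

```python
class PipelineStage:
    """One modular step of the diffusion process (illustrative sketch)."""
    def forward(self, batch):
        raise NotImplementedError


class ComposedPipeline:
    """Chains stages and threads the batch through them (illustrative sketch)."""
    def __init__(self):
        self.stages = []

    def add_stage(self, stage):
        self.stages.append(stage)

    def run(self, batch):
        # Each stage receives the batch, enriches it, and passes it on.
        for stage in self.stages:
            batch = stage.forward(batch)
        return batch


# Two trivial stages passing a dict "batch" along.
class EncodeStage(PipelineStage):
    def forward(self, batch):
        batch["embeds"] = f"embeds({batch['prompt']})"
        return batch


class DecodeStage(PipelineStage):
    def forward(self, batch):
        batch["image"] = f"image from {batch['embeds']}"
        return batch


pipe = ComposedPipeline()
pipe.add_stage(EncodeStage())
pipe.add_stage(DecodeStage())
out = pipe.run({"prompt": "a cat"})
```

The real classes add module loading, device placement, and batching, but the data flow is the same: a batch object threaded through an ordered list of stages.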
Hybrid style, the recommended default for most new models, uses a three-stage structure:
BeforeDenoisingStage (model-specific) → DenoisingStage (standard) → DecodingStage (standard)
Why recommended? Modern diffusion models often have highly heterogeneous pre-processing requirements — different text encoders, different latent formats, different conditioning mechanisms. The Hybrid approach keeps pre-processing isolated per model, avoids fragile shared stages with excessive conditional logic, and lets developers port Diffusers reference code quickly.
Modular style uses the framework's fine-grained standard stages (`TextEncodingStage`, `LatentPreparationStage`, `TimestepPreparationStage`, etc.) to build the pipeline by composition. Convenience methods like `add_standard_t2i_stages()` and `add_standard_ti2i_stages()` make this very concise.
This style is appropriate when the model's pre-processing largely fits the standard patterns.
To add support for a new diffusion model, you will need to define or configure the following components:
PipelineConfig: A dataclass holding static configurations for your model pipeline — precision settings, model architecture parameters, and callback methods used by the standard DenoisingStage and DecodingStage. Each model has its own subclass.
SamplingParams: A dataclass defining runtime generation parameters — prompt, negative_prompt, guidance_scale, num_inference_steps, seed, height, width, etc.
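As a sketch, a model-specific subclass typically just sets model-appropriate defaults for these fields. The class name and default values below are hypothetical; in practice you would subclass the framework's `SamplingParams` base class:

```python
from dataclasses import dataclass


@dataclass
class MyModelSamplingParams:  # in practice: subclass the framework's SamplingParams
    """Runtime generation parameters with model-specific defaults (illustrative)."""
    prompt: str = ""
    negative_prompt: str = ""
    guidance_scale: float = 4.0       # hypothetical default for this model
    num_inference_steps: int = 28
    seed: int = 0
    height: int = 1024
    width: int = 1024


# Callers override only what they need; everything else keeps the model default.
params = MyModelSamplingParams(prompt="a red bicycle", num_inference_steps=20)
```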
Pre-processing stage(s): Either a single model-specific {Model}BeforeDenoisingStage (Hybrid style) or a combination of standard stages (Modular style). See Two Pipeline Styles above.
ComposedPipeline: A class that wires together your pre-processing stage(s) with the standard DenoisingStage and DecodingStage. See base definitions:
Modules (model components): Each pipeline references modules loaded from the model repository (e.g., Diffusers model_index.json):
- `text_encoder`: Encodes text prompts into embeddings.
- `tokenizer`: Tokenizes raw text input for the text encoder(s).
- `processor`: Preprocesses images and extracts features; often used in image-to-image tasks.
- `image_encoder`: Specialized image feature extractor.
- `dit`/`transformer`: The core denoising network (DiT/UNet architecture) operating in latent space.
- `scheduler`: Controls the timestep schedule and denoising dynamics.
- `vae`: Variational Autoencoder for encoding/decoding between pixel space and latent space.

The following fine-grained stages can be composed to build the pre-processing portion of a pipeline. They are best suited for models whose pre-processing largely fits the standard patterns. If your model requires significant customization, consider the Hybrid style with a single `BeforeDenoisingStage` instead.
Before writing any code, obtain the model's original implementation or Diffusers pipeline code:
- The `pipeline_*.py` file (from the diffusers library or HuggingFace repo)
- The `model_index.json` and the associated pipeline class

Once you have the reference code, study it thoroughly:
- Inspect `model_index.json` to identify required modules.
- Read the pipeline's `__call__` method to understand:
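A quick way to enumerate the declared modules is to read `model_index.json` directly with the standard library. The file contents below are hypothetical but follow the Diffusers convention of mapping module names to `(library, class)` pairs alongside underscore-prefixed metadata keys:

```python
import json

# Hypothetical model_index.json contents, following the Diffusers layout.
raw = """
{
  "_class_name": "MyModelPipeline",
  "_diffusers_version": "0.30.0",
  "text_encoder": ["transformers", "T5EncoderModel"],
  "tokenizer": ["transformers", "T5Tokenizer"],
  "transformer": ["diffusers", "MyModelTransformer2DModel"],
  "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
  "vae": ["diffusers", "AutoencoderKL"]
}
"""
index = json.loads(raw)

# Module keys are everything that is not underscore-prefixed metadata.
modules = sorted(k for k in index if not k.startswith("_"))
print(index["_class_name"], modules)
```

The `_class_name` value is what your `pipeline_name` must later match, and `modules` is the list your `_required_config_modules` should cover.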
Before creating any new files, check whether an existing pipeline or stage can be reused or extended. Only create new pipelines/stages when the existing ones would need substantial structural changes or when no architecturally similar implementation exists.
- Existing stages live in `runtime/pipelines_core/stages/` and `stages/model_specific_stages/`.
- The framework already provides standard VAEs (e.g., `AutoencoderKL`), text encoders (CLIP, T5), and schedulers. Reuse these directly.

Adapt the model's core components:
- DiT/transformer → `runtime/models/dits/`
- Text/image encoders → `runtime/models/encoders/`
- VAE → `runtime/models/vaes/`
- Scheduler → `runtime/models/schedulers/`, if needed

Use SGLang's fused kernels where possible (see `LayerNormScaleShift`, `RMSNormScaleShift`, `apply_qk_norm`, etc.).
Tensor Parallel (TP) and Sequence Parallel (SP): For multi-GPU deployment, it is recommended to add TP/SP support to the DiT model. This can be done incrementally after the single-GPU implementation is verified. Reference implementations:
- `runtime/models/dits/wanvideo.py` — full TP + SP: `ColumnParallelLinear`/`RowParallelLinear` for attention, sequence-dimension sharding via `get_sp_world_size()`
- `runtime/models/dits/qwen_image.py` — SP via `USPAttention` (Ulysses + Ring Attention)

You will also need the corresponding config files:

- `configs/models/dits/{model_name}.py`
- `configs/models/vaes/{model_name}.py`
- `configs/sample/{model_name}.py`

The `PipelineConfig` provides callbacks that the standard `DenoisingStage` and `DecodingStage` use:
```python
# python/sglang/multimodal_gen/configs/pipeline_configs/my_model.py
@dataclass
class MyModelPipelineConfig(ImagePipelineConfig):
    task_type: ModelTaskType = ModelTaskType.T2I
    vae_precision: str = "bf16"
    should_use_guidance: bool = True
    dit_config: DiTConfig = field(default_factory=MyModelDitConfig)
    vae_config: VAEConfig = field(default_factory=MyModelVAEConfig)

    def get_freqs_cis(self, batch, device, rotary_emb, dtype):
        """Prepare rotary position embeddings for the DiT."""
        ...

    def prepare_pos_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
        """Build positive conditioning kwargs for each denoising step."""
        return {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": batch.prompt_embeds[0],
            "timestep": t,
        }

    def prepare_neg_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
        """Build negative conditioning kwargs for CFG."""
        return {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": batch.negative_prompt_embeds[0],
            "timestep": t,
        }

    def get_decode_scale_and_shift(self):
        """Return (scale, shift) for latent denormalization before VAE decode."""
        ...
```
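To show why the positive/negative callback pair exists, here is a dependency-free sketch of how a denoising step with classifier-free guidance conceptually consumes them. The real `DenoisingStage` operates on tensors and the framework's config objects; the scalar "model" and `ToyConfig` below are stand-ins:

```python
def denoise_step(model, config, batch, latent, t, guidance_scale):
    """One CFG step: run the DiT once per conditioning, then combine
    the two predictions (illustrative sketch, not the real stage)."""
    pos = model(**config.prepare_pos_cond_kwargs(batch, latent, t))
    neg = model(**config.prepare_neg_cond_kwargs(batch, latent, t))
    # Classifier-free guidance: extrapolate from the negative prediction
    # toward the positive one.
    return neg + guidance_scale * (pos - neg)


class ToyConfig:
    """Builds kwargs the same way the PipelineConfig callbacks do."""
    def prepare_pos_cond_kwargs(self, batch, latent, t):
        return {"x": latent, "cond": batch["prompt_embeds"]}

    def prepare_neg_cond_kwargs(self, batch, latent, t):
        return {"x": latent, "cond": batch["negative_prompt_embeds"]}


def toy_model(x, cond):
    # Stand-in for the DiT forward pass: output depends on the conditioning.
    return x + cond


batch = {"prompt_embeds": 1.0, "negative_prompt_embeds": 0.0}
out = denoise_step(toy_model, ToyConfig(), batch, latent=0.5, t=0, guidance_scale=3.0)
# neg = 0.5, pos = 1.5, so out = 0.5 + 3.0 * (1.5 - 0.5) = 3.5
```

The key point: the callbacks own the model-specific kwarg names, so the standard denoising loop never needs to know your DiT's `forward()` signature.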
Choose based on your model's needs (see How to Choose):
Create a single stage that handles all pre-processing. Best when the model has custom/complex pre-processing logic.
```python
# python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/my_model.py
class MyModelBeforeDenoisingStage(PipelineStage):
    """Monolithic pre-processing stage for MyModel.

    Consolidates: input validation, text/image encoding, latent
    preparation, and timestep computation.
    """

    def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
        super().__init__()
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.scheduler = scheduler

    @torch.no_grad()
    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        device = get_local_torch_device()
        # 1. Encode prompt (model-specific logic)
        prompt_embeds, negative_prompt_embeds = self._encode_prompt(...)
        # 2. Prepare latents
        latents = self._prepare_latents(...)
        # 3. Prepare timesteps
        timesteps, sigmas = self._prepare_timesteps(...)
        # 4. Populate batch for DenoisingStage
        batch.prompt_embeds = [prompt_embeds]
        batch.negative_prompt_embeds = [negative_prompt_embeds]
        batch.latents = latents
        batch.timesteps = timesteps
        batch.num_inference_steps = len(timesteps)
        batch.sigmas = sigmas.tolist()
        batch.generator = generator
        batch.raw_latent_shape = latents.shape
        return batch
```
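As an illustration of what a `_prepare_timesteps` helper might produce, here is a dependency-free sketch of a linear flow-matching sigma schedule. The real implementation delegates to the model's scheduler; the linear formula and train-timestep scaling below are assumptions for illustration only:

```python
def prepare_timesteps(num_inference_steps, num_train_timesteps=1000):
    """Linear sigma schedule from 1.0 down toward 0, with matching
    timesteps scaled to the training range (illustrative sketch)."""
    sigmas = [1.0 - i / num_inference_steps for i in range(num_inference_steps)]
    timesteps = [s * num_train_timesteps for s in sigmas]
    return timesteps, sigmas


timesteps, sigmas = prepare_timesteps(4)
# sigmas:    [1.0, 0.75, 0.5, 0.25]
# timesteps: [1000.0, 750.0, 500.0, 250.0]
```

Whatever schedule your scheduler actually uses, the contract is the same: the pre-processing stage must leave `batch.timesteps`, `batch.sigmas`, and `batch.num_inference_steps` mutually consistent before `DenoisingStage` runs.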
Skip creating a custom stage entirely — configure via PipelineConfig callbacks and use framework helpers. Best when the model fits standard patterns.
(This option has no separate stage file; the pipeline class in Step 7 calls add_standard_t2i_stages() directly.)
Key batch fields that DenoisingStage expects (regardless of which option you choose):
```python
# python/sglang/multimodal_gen/runtime/pipelines/my_model.py
class MyModelPipeline(LoRAPipeline, ComposedPipelineBase):
    pipeline_name = "MyModelPipeline"  # Must match model_index.json _class_name
    _required_config_modules = [
        "text_encoder", "tokenizer", "vae", "transformer", "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        # 1. Monolithic pre-processing (model-specific)
        self.add_stage(
            MyModelBeforeDenoisingStage(
                vae=self.get_module("vae"),
                text_encoder=self.get_module("text_encoder"),
                tokenizer=self.get_module("tokenizer"),
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
            ),
        )
        # 2. Standard denoising loop (framework-provided)
        self.add_stage(
            DenoisingStage(
                transformer=self.get_module("transformer"),
                scheduler=self.get_module("scheduler"),
            ),
        )
        # 3. Standard VAE decoding (framework-provided)
        self.add_standard_decoding_stage()


EntryClass = [MyModelPipeline]
```
```python
# python/sglang/multimodal_gen/runtime/pipelines/my_model.py
class MyModelPipeline(LoRAPipeline, ComposedPipelineBase):
    pipeline_name = "MyModelPipeline"
    _required_config_modules = [
        "text_encoder", "tokenizer", "vae", "transformer", "scheduler",
    ]

    def create_pipeline_stages(self, server_args: ServerArgs):
        # All pre-processing + denoising + decoding in one call
        self.add_standard_t2i_stages(
            prepare_extra_timestep_kwargs=[prepare_mu],  # model-specific hooks
        )


EntryClass = [MyModelPipeline]
```
Register your configs in registry.py:
```python
register_configs(
    model_family="my_model",
    sampling_param_cls=MyModelSamplingParams,
    pipeline_config_cls=MyModelPipelineConfig,
    hf_model_paths=["org/my-model-name"],
)
```
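Conceptually, the registry is a mapping from HuggingFace model paths to the registered classes. The simplified, self-contained sketch below mirrors the `register_configs()` signature shown above; the `_REGISTRY` dict, `resolve()` helper, and empty stand-in classes are illustrative, not the framework's actual internals:

```python
_REGISTRY = {}


def register_configs(model_family, sampling_param_cls, pipeline_config_cls, hf_model_paths):
    # Every known HF path resolves to the same (family, configs) entry.
    for path in hf_model_paths:
        _REGISTRY[path] = {
            "model_family": model_family,
            "sampling_param_cls": sampling_param_cls,
            "pipeline_config_cls": pipeline_config_cls,
        }


def resolve(hf_model_path):
    """Look up the registered configs for a model path (illustrative)."""
    return _REGISTRY[hf_model_path]


class MyModelSamplingParams: ...
class MyModelPipelineConfig: ...


register_configs(
    model_family="my_model",
    sampling_param_cls=MyModelSamplingParams,
    pipeline_config_cls=MyModelPipelineConfig,
    hf_model_paths=["org/my-model-name"],
)
entry = resolve("org/my-model-name")
```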
The EntryClass in your pipeline file is automatically discovered by the registry — no additional registration needed for the pipeline class itself.
After implementation, verify that the generated output is not noise. A noisy or garbled output is the most common sign of an incorrect implementation. Common causes include:
- Incorrect rotary embedding configuration (e.g., `is_neox_style`)

Debug by comparing intermediate tensor values against the Diffusers reference pipeline with the same seed.
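For the comparison itself, a max-absolute-difference check per named intermediate is usually enough to localize the first diverging stage. A dependency-free sketch (with real models you would flatten `torch` tensors rather than use plain lists, and the reference values below are made up):

```python
def max_abs_diff(a, b):
    """Max absolute elementwise difference between two flat lists of floats."""
    assert len(a) == len(b), "shape mismatch is itself a bug signal"
    return max(abs(x - y) for x, y in zip(a, b))


# Hypothetical intermediates captured from both pipelines with the same seed.
reference = {"prompt_embeds": [0.12, -0.50], "latents": [1.00, 0.25]}
ours      = {"prompt_embeds": [0.12, -0.50], "latents": [1.00, 0.75]}

for name in reference:
    diff = max_abs_diff(reference[name], ours[name])
    status = "OK" if diff < 1e-4 else "DIVERGES"
    print(f"{name}: max|diff|={diff:.4f} {status}")
```

Check intermediates in pipeline order (embeddings, then initial latents, then per-step predictions): the first one that diverges points at the buggy component.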
Before submitting your implementation, verify:
Common (both styles):
- `runtime/pipelines/{model_name}.py` with `EntryClass`
- `configs/pipeline_configs/{model_name}.py`
- `configs/sample/{model_name}.py`
- `runtime/models/dits/{model_name}.py`
- Model configs in `configs/models/dits/` and `configs/models/vaes/`
- Registered in `registry.py` via `register_configs()`
- `pipeline_name` matches the Diffusers `model_index.json` `_class_name`
- `_required_config_modules` lists all modules from `model_index.json`
- `PipelineConfig` callbacks (`prepare_pos_cond_kwargs`, etc.) match the DiT's `forward()` signature
- Pipeline uses the standard `DenoisingStage` and `DecodingStage` (not custom denoising loops)
- TP/SP support where applicable (see `wanvideo.py` for TP+SP, `qwen_image.py` for `USPAttention`)

Hybrid style only:
- `stages/model_specific_stages/{model_name}.py`
- `BeforeDenoisingStage.forward()` populates all batch fields required by `DenoisingStage`