The plugin framework allows hardware vendors and developers to extend SGLang without modifying the main repository code. It provides two plugin types, both discovered via Python's standard setuptools entry points:
- Both plugin types install with `pip install`; no sglang code changes are required.
- `SGLANG_PLATFORM` selects or validates the active platform plugin; `SGLANG_PLUGINS` (comma-separated) controls which general plugins to load.

The plugin system currently targets out-of-tree (OOT) hardware platforms, enabling new devices to integrate with SGLang without any changes to the main repository. The main-repo hardware paths (CUDA, ROCm, NPU, XPU, etc.) continue to use the existing `is_cuda()`/`is_npu()`/… utility functions.
As the plugin interfaces mature and stabilize, in-tree hardware backends can be gradually migrated to the same plugin architecture. This would replace the scattered `if device == "cuda" … elif device == "npu" …` branches throughout the codebase with a single polymorphic dispatch through the platform interface, making each hardware backend self-contained and the core engine hardware-agnostic.
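The contrast can be sketched with a toy example (class and function names here are illustrative, not SGLang APIs):

```python
class CudaPlatform:
    def get_torch_distributed_backend_str(self) -> str:
        return "nccl"

class NpuPlatform:
    def get_torch_distributed_backend_str(self) -> str:
        return "hccl"

# Before: every call site branches on a device string.
def backend_before(device: str) -> str:
    if device == "cuda":
        return "nccl"
    elif device == "npu":
        return "hccl"
    raise ValueError(f"unsupported device: {device}")

# After: one polymorphic call through the active platform object;
# supporting a new backend means adding a class, not editing call sites.
def backend_after(platform) -> str:
    return platform.get_torch_distributed_backend_str()

print(backend_before("npu"), backend_after(NpuPlatform()))  # hccl hccl
```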
The platform hierarchy uses a DeviceMixin pattern to share device operations between SRT (LLM inference) and Multimodal subsystems:
```
DeviceMixin (shared device identity + operations)
├── SRTPlatform(DeviceMixin)                       # + graph runner, KV pool, …
│   └── MySRTPlatform(SRTPlatform, MyDeviceMixin)  # OOT plugin
└── MMPlatform(DeviceMixin)                        # + attention backend, VAE, … (future)
    └── MyMMPlatform(MMPlatform, MyDeviceMixin)    # OOT plugin
```
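How the composition resolves can be shown with a runnable toy version of the hierarchy (class bodies here are stand-ins; the real classes carry far more):

```python
class DeviceMixin:
    """Shared device identity + operations (fail-fast defaults)."""
    device_name = "unknown"

    def get_device_name(self, device_id: int = 0) -> str:
        raise NotImplementedError

class SRTPlatform(DeviceMixin):
    """Engine-facing interface with conservative defaults."""
    def support_cuda_graph(self) -> bool:
        return False

class MyDeviceMixin(DeviceMixin):
    """OOT device operations."""
    device_name = "my_device"

    def get_device_name(self, device_id: int = 0) -> str:
        return f"my_device:{device_id}"

class MySRTPlatform(SRTPlatform, MyDeviceMixin):
    # Device ops resolve to MyDeviceMixin via the MRO; engine-facing
    # defaults come from SRTPlatform.
    pass

p = MySRTPlatform()
print(p.get_device_name(1))    # my_device:1 (from MyDeviceMixin)
print(p.support_cuda_graph())  # False (SRTPlatform default)
```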
Key design points:
- `DeviceMixin` provides device identity checks (`is_cuda()`, `is_npu()`, etc.) and device operations (`set_device()`, `get_device_name()`, etc.)
- All access goes through the `current_platform` singleton
- Interface methods raise `NotImplementedError` by default (fail-fast); a few default to `False`/`pass`
- Each method is annotated [Active] (called by SGLang core) or [Planned] (reserved for future migration)

`current_platform` is a lazy singleton in `sglang.srt.platforms`. On first access it resolves the active platform through the following priority chain:
```
entry_points("sglang.srt.platforms") → Enumerate ALL plugins by name (metadata only)
│
├─ SGLANG_PLATFORM set (front-loading filter):
│   ├─ Name not found in discovered → RuntimeError
│   ├─ activate() returns non-None  → load that platform
│   └─ activate() returns None      → RuntimeError (hardware unavailable)
│
└─ SGLANG_PLATFORM unset (auto-discover, activate all):
    ├─ 0 activated → fall back to base SRTPlatform
    ├─ 1 activated → use it
    └─ N activated → RuntimeError (must set SGLANG_PLATFORM)
```
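The chain above can be approximated in a few lines. This is a simplified sketch, not the actual `sglang.srt.platforms` code; `resolve_platform` and `PlatformResolutionError` are hypothetical names:

```python
import os

class PlatformResolutionError(RuntimeError):
    """Stand-in for the RuntimeErrors raised by the real resolver."""

def resolve_platform(discovered, default="SRTPlatform"):
    """discovered maps plugin name -> activate() callable, mirroring
    entry_points("sglang.srt.platforms")."""
    selected = os.environ.get("SGLANG_PLATFORM")
    if selected is not None:
        # Front-loading filter: only the named plugin is consulted.
        if selected not in discovered:
            raise PlatformResolutionError(f"unknown platform: {selected!r}")
        cls_path = discovered[selected]()
        if cls_path is None:
            raise PlatformResolutionError(f"{selected!r}: hardware unavailable")
        return cls_path
    # Auto-discover: activate everything, then count the survivors.
    activated = {name: fn() for name, fn in discovered.items()}
    activated = {name: c for name, c in activated.items() if c is not None}
    if not activated:
        return default  # fall back to the base SRTPlatform
    if len(activated) == 1:
        return next(iter(activated.values()))
    raise PlatformResolutionError("multiple platforms activated; set SGLANG_PLATFORM")

os.environ.pop("SGLANG_PLATFORM", None)  # demo assumes the var is unset
plugins = {"my_device": lambda: "my_pkg.MySRTPlatform", "other": lambda: None}
print(resolve_platform(plugins))  # my_pkg.MySRTPlatform
```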
`load_plugins()` discovers and executes general plugins, then applies all registered hooks. It is called at four points:
Note: `load_plugins()` is idempotent (guarded by a `_plugins_loaded` flag). In subprocesses started via `spawn`, the flag resets, so plugins are correctly re-loaded.
```
load_plugins()
├── _get_excluded_dists()                       → compute dists to skip (via SGLANG_PLATFORM)
├── load_plugins_by_group("sglang.srt.plugins", → discover entry_points, filter by SGLANG_PLUGINS,
│                         excluded_dists=...)     skip plugins from unselected platform packages
├── for each plugin: func()                     → set _current_plugin_source context var;
│                                                 side effects (register hooks with source tracking)
└── HookRegistry.apply_hooks()                  → monkey-patch targets
```
A hardware platform plugin registers an SRTPlatform subclass that tells SGLang how to interact with a specific hardware backend.
1. Create a minimal package:
```
my_platform_plugin/
├── pyproject.toml
└── my_platform_plugin/
    ├── __init__.py   # activate() function
    ├── device.py     # MyDeviceMixin
    └── platform.py   # MySRTPlatform
```
2. `pyproject.toml`:

```toml
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "my-platform-plugin"
version = "0.1.0"

[project.entry-points."sglang.srt.platforms"]
my_device = "my_platform_plugin:activate"
```
3. `__init__.py` — activation function:

```python
def activate():
    """Return fully-qualified class name to activate, or None to skip."""
    if _my_device_is_available():
        return "my_platform_plugin.platform.MySRTPlatform"
    return None
```
4. `device.py` — device mixin:

```python
from sglang.srt.platforms.device_mixin import DeviceMixin, PlatformEnum

class MyDeviceMixin(DeviceMixin):
    _enum = PlatformEnum.OOT
    device_name = "my_device"
    device_type = "my_device"  # torch device type

    def set_device(self, device) -> None: ...
    def get_device_name(self, device_id=0) -> str: ...
    def get_device_total_memory(self, device_id=0) -> int: ...
    def get_current_memory_usage(self, device=None) -> float: ...
    def get_device_capability(self, device_id=0): ...
    def get_torch_distributed_backend_str(self) -> str: ...
```
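To make the stubs concrete, here is a hedged sketch of a few of these methods, with a fake `my_driver` class standing in for a vendor runtime so the example is self-contained (everything here is illustrative):

```python
class my_driver:
    """Fake vendor runtime so the sketch runs anywhere."""
    _current = 0

    @staticmethod
    def set_device(index: int) -> None:
        my_driver._current = index

    @staticmethod
    def name(index: int) -> str:
        return f"MyAccel-{index}"

    @staticmethod
    def total_memory(index: int) -> int:
        return 32 * 1024**3  # pretend every card has 32 GiB

class MyDeviceMixin:
    device_name = "my_device"
    device_type = "my_device"

    def set_device(self, device) -> None:
        # Accept an int index or a "my_device:N" style string.
        index = int(str(device).rsplit(":", 1)[-1])
        my_driver.set_device(index)

    def get_device_name(self, device_id: int = 0) -> str:
        return my_driver.name(device_id)

    def get_device_total_memory(self, device_id: int = 0) -> int:
        return my_driver.total_memory(device_id)

m = MyDeviceMixin()
m.set_device("my_device:1")
print(m.get_device_name(1))  # MyAccel-1
```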
5. `platform.py` — SRT platform:

```python
from sglang.srt.platforms.interface import SRTPlatform
from my_platform_plugin.device import MyDeviceMixin

class MySRTPlatform(SRTPlatform, MyDeviceMixin):
    def get_default_attention_backend(self) -> str: ...
    def support_cuda_graph(self) -> bool: ...
    # ... override other methods as needed
```
6. Install and verify:
```shell
pip install -e my_platform_plugin/
python -c "from sglang.srt.platforms import current_platform; print(current_platform)"
```
<table style={{width: "100%", borderCollapse: "collapse", tableLayout: "fixed"}}> <colgroup> <col style={{width: "25%"}} /> <col style={{width: "25%"}} /> <col style={{width: "25%"}} /> <col style={{width: "25%"}} /> </colgroup> <thead> <tr> <th>Method</th> <th>Default</th> <th>Status</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>get_device(local_rank)</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Return <code>torch.device</code> for a given local rank</td> </tr> <tr> <td><code>set_device(device)</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Set the current device</td> </tr> <tr> <td><code>get_device_name(device_id)</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Get human-readable device name</td> </tr> <tr> <td><code>get_device_uuid(device_id)</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Get unique device identifier</td> </tr> <tr> <td><code>get_device_capability(device_id)</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Get <code>DeviceCapability(major, minor)</code>. 
None if N/A</td> </tr> <tr> <td><code>empty_cache()</code></td> <td><code>pass</code></td> <td>Planned</td> <td>Release cached device memory</td> </tr> <tr> <td><code>synchronize()</code></td> <td><code>pass</code></td> <td>Planned</td> <td>Synchronize device operations</td> </tr> <tr> <td><code>get_device_total_memory(device_id)</code></td> <td><code>raise NotImplementedError</code></td> <td><strong>Active</strong></td> <td>Get total device memory in bytes</td> </tr> <tr> <td><code>get_available_memory(device_id)</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Return <code>(free_bytes, total_bytes)</code></td> </tr> <tr> <td><code>get_current_memory_usage(device)</code></td> <td><code>raise NotImplementedError</code></td> <td><strong>Active</strong></td> <td>Get current peak memory usage in bytes</td> </tr> <tr> <td><code>get_torch_distributed_backend_str()</code></td> <td><code>raise NotImplementedError</code></td> <td>Planned</td> <td>Distributed backend string (e.g. "nccl", "hccl")</td> </tr> <tr> <td><code>get_communicator_class()</code></td> <td><code>None</code></td> <td>Planned</td> <td>Platform-specific communicator class</td> </tr> <tr> <td><code>inference_mode()</code></td> <td><code>torch.inference_mode(True)</code></td> <td>Planned</td> <td>Return inference mode context manager</td> </tr> <tr> <td><code>seed_everything(seed)</code></td> <td>Set random/np/torch seeds</td> <td>Planned</td> <td>Set random seeds for reproducibility</td> </tr> <tr> <td><code>verify_quantization(quant)</code></td> <td><code>pass</code></td> <td>Planned</td> <td>Validate quantization method support</td> </tr> <tr> <td><code>get_cpu_architecture()</code></td> <td>Auto-detect x86/arm</td> <td>Planned</td> <td>Detect CPU architecture (<code>CpuArchEnum</code>)</td> </tr> </tbody> </table>

Methods annotated [Active] are called by SGLang core through `current_platform`, so OOT implementations take effect immediately. Methods annotated [Planned] are reserved interfaces: SGLang core still uses hardcoded calls (e.g. `torch.cuda.empty_cache()`), so OOT implementations will NOT take effect until the core is migrated in a future PR.
General function plugins inject behavior into SGLang without requiring a custom platform. Use cases include:
1. Create a minimal package:
```
my_general_plugin/
├── pyproject.toml
└── my_general_plugin/
    └── __init__.py   # register() function
```
2. `pyproject.toml`:

```toml
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "my-general-plugin"
version = "0.1.0"

[project.entry-points."sglang.srt.plugins"]
my_plugin = "my_general_plugin:register"
```
3. `__init__.py` — register hooks:

```python
from sglang.srt.plugins.hook_registry import HookRegistry, HookType

def register():
    """Entry point called by load_plugins()."""
    HookRegistry.register(
        "sglang.srt.managers.scheduler.Scheduler.__init__",
        my_hook,
        HookType.AROUND,
    )

def my_hook(original_fn, self, *args, **kwargs):
    result = original_fn(self, *args, **kwargs)
    print(f"Scheduler initialized! gpu_id={self.gpu_id}")
    return result
```
4. Install and run:
```shell
pip install -e my_general_plugin/
sglang serve --model-path <model> [options]
# Look for "Scheduler initialized!" in logs
```
`HookRegistry` supports four hook types: `BEFORE` (runs before the target), `AFTER` (runs after it), `AROUND` (wraps the call and receives `original_fn`), and `REPLACE` (substitutes the target entirely).

Note: Only `REPLACE` accepts a class as the hook. Passing a class to `BEFORE`/`AFTER`/`AROUND` raises `TypeError` at registration time.
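The registration-time check can be sketched as follows (a simplified stand-in, not the real `HookRegistry` internals):

```python
import enum
import inspect

class HookType(enum.Enum):
    BEFORE = "before"
    AFTER = "after"
    AROUND = "around"
    REPLACE = "replace"

def register(target: str, hook, hook_type: HookType):
    # Classes only make sense as wholesale replacements; rejecting them
    # here surfaces the error at registration time, not at call time.
    if inspect.isclass(hook) and hook_type is not HookType.REPLACE:
        raise TypeError(
            f"class hooks are only valid for REPLACE, got {hook_type.name}"
        )
    return (target, hook, hook_type)  # a real registry would store this

register("pkg.mod.Class.method", lambda *a, **k: None, HookType.AROUND)  # ok
try:
    register("pkg.mod.Class", int, HookType.BEFORE)  # class + BEFORE
except TypeError as exc:
    print(exc)
```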
Hooks can be registered using the imperative API or the decorator API:
```python
import time

# --- Imperative API ---
from sglang.srt.plugins.hook_registry import HookRegistry, HookType

def my_timer(original_fn, *args, **kwargs):
    start = time.perf_counter()
    result = original_fn(*args, **kwargs)
    print(f"Elapsed: {time.perf_counter() - start:.3f}s")
    return result

HookRegistry.register(
    "sglang.srt.managers.scheduler.Scheduler.get_next_batch_to_run",
    my_timer,
    HookType.AROUND,
)
```

```python
import time

# --- Decorator API ---
from sglang.srt.plugins.hook_registry import plugin_hook, HookType

@plugin_hook(
    "sglang.srt.managers.scheduler.Scheduler.get_next_batch_to_run",
    type=HookType.AROUND,
)
def my_timer(original_fn, *args, **kwargs):
    start = time.perf_counter()
    result = original_fn(*args, **kwargs)
    print(f"Elapsed: {time.perf_counter() - start:.3f}s")
    return result
```
```python
# --- Class replacement (REPLACE) ---
from sglang.srt.plugins.hook_registry import plugin_hook, HookType
from sglang.srt.managers.scheduler import Scheduler

@plugin_hook(
    "sglang.srt.managers.scheduler.Scheduler",
    type=HookType.REPLACE,
)
class MyScheduler(Scheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        print("Enhanced scheduler initialized!")
```
Target paths use fully-qualified dotted notation. Both formats are supported:

- `sglang.srt.managers.scheduler.Scheduler.__init__`
- `sglang.srt.managers.scheduler:Scheduler.__init__` (colon treated as dot)
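One way such paths can be resolved is to normalize the colon to a dot and then walk for the longest importable module prefix. This is an illustrative sketch using stdlib targets, not the actual resolver:

```python
import importlib

def resolve_target(path: str):
    """Resolve "pkg.mod.Class.attr" or "pkg.mod:Class.attr" to an object."""
    dotted = path.replace(":", ".")
    parts = dotted.split(".")
    # Find the longest importable module prefix, then treat the
    # remaining parts as an attribute chain on that module.
    for i in range(len(parts), 0, -1):
        module_name = ".".join(parts[:i])
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        obj = module
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot resolve {path!r}")

# Both formats resolve to the same object (stdlib target for illustration):
print(resolve_target("os.path.join") is resolve_target("os.path:join"))  # True
```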