docs_new/docs/advanced_features/forward_hooks.mdx
SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via server_args JSON.
This is useful for:
Hooks are attached once during ModelRunner.initialize and run on every forward pass.
Hooks are configured via a ServerArgs field:
class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
In JSON form, a minimal configuration looks like:
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
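As a rough sketch, the same spec can be assembled and sanity-checked in Python before it is serialized to JSON. The helper `check_hook_spec` is hypothetical (it is not part of SGLang); it only checks the two fields this page documents as required, `target_modules` and `hook_factory`.

```python
import json


# Hypothetical helper, not part of SGLang: flags missing required fields.
def check_hook_spec(spec: dict) -> list[str]:
    """Return a list of problems found in one forward-hook spec."""
    problems = []
    if not spec.get("target_modules"):
        problems.append("missing 'target_modules'")
    if not spec.get("hook_factory"):
        problems.append("missing 'hook_factory'")
    return problems


forward_hooks = [
    {
        "name": "outer_linear_hooks",
        "target_modules": ["outer.0", "outer.1"],
        "hook_factory": "my_project.hooks:dummy_hook_factory",
        "config": {"tag": "outer-layer"},
    }
]

# Report any spec that would be skipped for lack of a required field.
for spec in forward_hooks:
    for problem in check_hook_spec(spec):
        print(f"{spec.get('name', '')}: {problem}")

# Serialize to the JSON form used for the server arguments.
payload = json.dumps({"forward_hooks": forward_hooks}, indent=2)
print(payload)
```

Checking specs up front is a design convenience only: a spec with a missing field is otherwise skipped with a warning at startup rather than rejected.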
forward_hooks (optional list of objects)
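Each element is a hook spec describing which modules to target (target_modules), which factory builds the hook (hook_factory), and an optional name and config.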
Each entry in forward_hooks is a JSON object with the following shape:
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
name (optional): Human-readable name for logging.
Used only in log messages such as:
Registered forward hook 'outer_linear_hooks' on outer.0
target_modules (required): List of module name patterns used to match entries in model.named_modules().
Patterns are matched using fnmatch.fnmatch, so:
- "outer.0" matches exactly "outer.0".
- "outer.*" matches "outer.0", "outer.1", "outer.inner", etc.
- "outer.inner.*" matches children under outer.inner.

If no modules match the given patterns, hook registration does not fail. Instead, SGLang logs a warning and continues:
No modules matched hook spec 'name' patterns=['...']
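The matching rules can be tried out with the standard library alone. The sketch below uses a hypothetical list of module names standing in for what model.named_modules() might report; only `fnmatch.fnmatch` itself comes from the text above.

```python
from fnmatch import fnmatch

# Hypothetical module names, standing in for model.named_modules() keys.
# The root module is reported under the empty name "".
module_names = ["", "outer", "outer.0", "outer.1", "outer.inner", "outer.inner.0"]


def match(pattern: str) -> list[str]:
    """Return the module names selected by one target_modules pattern."""
    return [name for name in module_names if fnmatch(name, pattern)]


print(match("outer.0"))        # exact name only
print(match("outer.*"))        # everything below "outer", at any depth
print(match("outer.inner.*"))  # everything below "outer.inner"
print(match("missing.*"))      # no match: this case only produces a warning
```

Note that `*` in fnmatch also spans dots, so "outer.*" selects "outer.inner.0" as well as the direct children, while "outer" itself is not selected.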
hook_factory (required): String path to the Python factory function that creates the hook.
Supported formats:
- "package.module:factory_name"
- "package.module.submodule.factory_name"

The path is resolved via:
def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None

    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)

    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
Failure modes:
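- If the path is malformed (fewer than two dot-separated parts and no :), a ValueError is raised at startup.
- If the module imports but does not define the named attribute, an AttributeError is raised with a clear error message.
- If the hook factory returns None, a warning is logged and no hook is registered for that spec (initialization continues).

The first two cause initialization to fail fast with a descriptive error; the last one is non-fatal.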
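Both path formats and the two fatal failure modes can be exercised outside SGLang. The sketch below carries a condensed copy of the resolver so that it runs on its own, and resolves names from the standard-library json module purely as stand-ins for a real hook factory.

```python
import importlib
import json
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    """Condensed copy of the resolver, kept here so the sketch is standalone."""
    if path is None:
        return None
    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(f"Invalid hook callable path '{path}'.")
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)
    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e


# Both supported formats resolve to the same kind of object.
colon_form = resolve_callable("json:dumps")
dotted_form = resolve_callable("json.decoder.JSONDecoder")

# Record which exception each malformed path raises.
errors = {}
for bad_path in ("nodots", "json:does_not_exist"):
    try:
        resolve_callable(bad_path)
    except (ValueError, AttributeError) as exc:
        errors[bad_path] = type(exc).__name__

print(colon_form, dotted_form, errors)
```

Resolving a hook path this way in a unit test catches typos before server startup, where the same errors would abort initialization.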
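config (optional): JSON object passed to the hook factory as a Python dict. If it is omitted, the factory receives an empty dict.

Hooks are registered in ModelRunner.initialize():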
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
The actual registration logic is implemented by register_forward_hooks:
def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())

    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue

        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue

        config = spec.get("config") or {}
        hook_factory = resolve_callable(hook_factory_path)
        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue

        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))

        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue

        for module_name, module in matched:
            if hook:
                _ = module.register_forward_hook(hook)
                logger.info(
                    f"Registered forward hook '{spec_name}' "
                    f"on {module_name}"
                )
Key points:
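- Hooks are standard PyTorch forward hooks (registered via module.register_forward_hook).
- Hooks stay attached for the lifetime of the ModelRunner (they cannot be removed later via this API).
- If a hook factory returns None, a warning is logged and that spec is skipped.

A hook factory is a regular Python function: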
- It takes config: dict (from JSON).
- It returns a forward hook with the signature (module, inputs, output).

Example:
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),
            }
        )
        # Returning output unchanged (or returning None) leaves the module's
        # output as-is; return a different value only to replace the output.
        return output

    return hook
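A factory like this can be checked without loading a model, because the hook it returns is an ordinary function. The sketch below repeats the factory so that it runs standalone and calls the hook with stand-ins: `FakeLinear` and the `SimpleNamespace` output are assumptions made for illustration, where a real run would pass an nn.Module and a tensor.

```python
from types import SimpleNamespace

HOOK_CALLS = []


def dummy_hook_factory(config):
    """Same factory as in the example, repeated so this sketch is standalone."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),
            }
        )
        return output

    return hook


# Stand-in for a real submodule (assumption for this sketch only).
class FakeLinear:
    pass


# The hook only reads `.shape`, so any object exposing it will do here.
fake_output = SimpleNamespace(shape=(2, 4))

hook = dummy_hook_factory({"tag": "outer"})
result = hook(FakeLinear(), (), fake_output)

print(HOOK_CALLS)
```

Because the hook returns the object it was given, `result` is the same `fake_output` instance, which mirrors a hook that observes a module without changing its output.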
In JSON:
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer"
      }
    }
  ]
}
This will:
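1. Resolve my_project.hooks:dummy_hook_factory to a Python callable.
2. Call it with config = {"tag": "outer"}.
3. Register the returned hook on outer.0 and outer.1.
4. Append a record to HOOK_CALLS each time one of those modules runs its forward pass.

Define forward_hooks as a list of specs in ServerArgs to turn on the feature.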
Each spec:
- selects modules via target_modules (glob patterns over model.named_modules()),
- names a hook_factory,
- passes config into that factory.

Hook factories are resolved via resolve_callable, which supports module:factory and module.submodule.factory.
Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.
Misconfiguration is either:
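- fatal at startup, with a descriptive error (an invalid hook_factory path or a missing attribute), or
- non-fatal, logged as a warning with no hook registered (missing target_modules or hook_factory, no matching modules, or a factory that returns None).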