docs/source/en/kernel_doc/writing_kernels.md
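This guide explains how to write kernels that go beyond a stateless forward replacement. It covers two capabilities the extended KernelConfig API supports:

- Kernels that carry their own parameters, under different names or in a different shape than the original model checkpoint.
- Kernels that fuse multiple adjacent modules into a single implementation.

For basic kernels (stateless forward replacements with no parameter changes), see the kernels library documentation.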
Any kernel that carries its own parameters follows a two-class pattern.
- `KernelName`: contains only the forward pass. The `kernels` library uses this class to kernelize the model because it does not allow stateful kernel classes.
- `KernelNameLayout`: an `nn.Module` that holds the parameters and monkey-patches the original module before the checkpoint is loaded. At runtime, `kernelize` replaces its forward with the forward from `KernelName`. You do not need to define `forward`. Transformers injects one automatically with the same signature as `KernelName.forward`.

> [!IMPORTANT]
> The naming convention is strict. The layout class must be named `{KernelName}Layout` and defined in the same module as `KernelName`.
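The two-class convention can be illustrated without any library code. The sketch below is only a conceptual stand-in for what happens during loading: plain Python classes replace `nn.Module`, and `find_layout` and `borrow_forward` are hypothetical helpers, not functions from Transformers or `kernels`.

```python
import types


class Scale:
    """Kernel class: forward only, no state of its own."""

    def forward(self, x):
        # `self.factor` is defined on ScaleLayout, not on this class.
        return [self.factor * v for v in x]


class ScaleLayout:
    """Layout class: holds the state and defines no forward."""

    def __init__(self, factor):
        self.factor = factor


def find_layout(namespace, kernel_cls):
    """Look up `{KernelName}Layout` next to the kernel class (hypothetical helper)."""
    return getattr(namespace, f"{kernel_cls.__name__}Layout", None)


def borrow_forward(layout_cls, kernel_cls):
    """Attach the kernel's forward to the layout class (hypothetical helper)."""
    layout_cls.forward = kernel_cls.forward
    return layout_cls


# A namespace standing in for the Python module that defines the kernel.
module = types.SimpleNamespace(Scale=Scale, ScaleLayout=ScaleLayout)

layout_cls = find_layout(module, Scale)
borrow_forward(layout_cls, Scale)
result = layout_cls(factor=2).forward([1, 2, 3])
print(result)  # [2, 4, 6]
```

The point of the sketch is that `forward` can refer to attributes it never defines, because it only ever runs on an instance of the layout class.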
Use this pattern when the kernel expects weights under different names or in a different shape than the original model checkpoint.
The KernelNameLayout class has the same __init__ signature as the module it replaces and declares a conversion_mapping class attribute that tells Transformers how to remap checkpoint keys to the new parameter names (see Dynamic weight loading for more details).
```python
import torch
import torch.nn as nn


class CustomRMSNormLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the new parameter names

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


class CustomRMSNorm(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.scale * hidden_states.to(input_dtype)


class layers:
    CustomRMSNorm = CustomRMSNorm
```
> [!NOTE]
> The `layers` class is required by the `kernels` library to expose the kernel entry point.
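The contents of `conversion_mapping` are elided above, and its real format is described in the dynamic weight loading documentation. To show only the underlying idea of remapping checkpoint keys, here is a plain-Python sketch. The `(regex, replacement)` rule format and the `remap_keys` helper are assumptions made for illustration, not the Transformers API.

```python
import re


def remap_keys(state_dict, rules):
    """Return a new dict with keys rewritten by the first matching rule.

    `rules` is a list of (regex_pattern, replacement) pairs. This rule format
    is an assumption for illustration only.
    """
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in rules:
            if re.search(pattern, key):
                new_key = re.sub(pattern, replacement, key)
                break
        remapped[new_key] = value
    return remapped


# The checkpoint stores the norm weight as `weight`; the layout calls it `scale`.
# Strings stand in for tensors to keep the sketch dependency-free.
checkpoint = {
    "model.norm.weight": "norm-tensor",
    "model.embed_tokens.weight": "embedding-tensor",
}
rules = [(r"norm\.weight$", "norm.scale")]

remapped = remap_keys(checkpoint, rules)
print(sorted(remapped))  # ['model.embed_tokens.weight', 'model.norm.scale']
```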
Load this kernel by passing the repo and class name to [KernelConfig]. The key is the original module class name from the model. The value points to the KernelName class (not the Layout) in the repo.
```python
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig({"RMSNorm": "owner/my-kernel:CustomRMSNorm"})

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)
```
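The value string pairs a repo with a class name. As a rough illustration of how such a reference relates to the layout naming convention, the sketch below splits it by hand. `split_kernel_reference` is a hypothetical helper, and it assumes the plain `owner/repo:ClassName` form shown above with no extra qualifiers.

```python
def split_kernel_reference(reference):
    """Split "owner/repo:ClassName" into (repo_id, class_name) (hypothetical helper)."""
    repo_id, sep, class_name = reference.partition(":")
    if not sep or not repo_id or not class_name:
        raise ValueError(f"expected 'owner/repo:ClassName', got {reference!r}")
    return repo_id, class_name


repo_id, class_name = split_kernel_reference("owner/my-kernel:CustomRMSNorm")
# The layout class is never named in the config; its name follows from the kernel's.
layout_name = f"{class_name}Layout"
print(repo_id, class_name, layout_name)
# owner/my-kernel CustomRMSNorm CustomRMSNormLayout
```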
When the model loads, Transformers:

1. Loads `CustomRMSNorm` from the repo and looks for `CustomRMSNormLayout` in the same module.
2. Replaces each `RMSNorm` in the model with `CustomRMSNormLayout`.
3. Remaps the checkpoint weights with `conversion_mapping` so they load into the new parameter names.
4. Calls `kernelize`, which replaces `CustomRMSNormLayout.forward` with `CustomRMSNorm.forward`.

Use this pattern when a kernel replaces multiple adjacent modules with a single fused implementation. Because the fused module combines parameters from several original modules, `KernelNameLayout.__init__` receives the instantiated child modules rather than their constructor arguments.
```python
import torch
import torch.nn as nn


class RMSNormMLPLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the fused parameter names

    def __init__(self, norm, mlp):
        super().__init__()
        self.variance_epsilon = norm.variance_epsilon
        self.scale = nn.Parameter(torch.empty_like(norm.weight))
        self.gate_up_proj = nn.Linear(
            mlp.gate_proj.in_features,
            mlp.gate_proj.out_features + mlp.up_proj.out_features,
            bias=mlp.gate_proj.bias is not None,
            device=mlp.gate_proj.weight.device,
            dtype=mlp.gate_proj.weight.dtype,
        )
        self.down_proj = nn.Linear(
            mlp.down_proj.in_features,
            mlp.down_proj.out_features,
            bias=mlp.down_proj.bias is not None,
            device=mlp.down_proj.weight.device,
            dtype=mlp.down_proj.weight.dtype,
        )
        self.act_fn = mlp.act_fn


class RMSNormMLP(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.scale * hidden_states.to(input_dtype)
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


class layers:
    RMSNormMLP = RMSNormMLP
```
To fuse modules, pass a tuple of (class_name, path_pattern) pairs as the key in KernelConfig instead of a plain string. All patterns must share the same parent module (Transformers fuses the children in that parent). The * wildcard matches any single path segment.
```python
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig(
    {
        (
            ("RMSNorm", "model.layers.*.post_attention_layernorm"),
            ("MLP", "model.layers.*.mlp"),
        ): "owner/my-kernel:RMSNormMLP",
    }
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)
```
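The two rules for path patterns (the `*` wildcard matches exactly one path segment, and all patterns share one parent) can be sketched in a few lines of plain Python. `matches` and `parent_of` are hypothetical helpers written for this illustration. They are not how Transformers implements matching.

```python
def matches(pattern, path):
    """True if dotted `path` matches `pattern`, where `*` is exactly one segment."""
    pattern_parts = pattern.split(".")
    path_parts = path.split(".")
    if len(pattern_parts) != len(path_parts):
        return False
    return all(p == "*" or p == s for p, s in zip(pattern_parts, path_parts))


def parent_of(pattern):
    """Everything before the last segment, i.e. the parent module's pattern."""
    return pattern.rpartition(".")[0]


patterns = ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]

print(matches(patterns[1], "model.layers.3.mlp"))            # True
print(matches(patterns[1], "model.layers.3.mlp.gate_proj"))  # False

# Both patterns resolve to the same parent, so they can be fused together.
shared_parents = {parent_of(p) for p in patterns}
print(shared_parents)  # {'model.layers.*'}
```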
When the model loads, Transformers:

1. Loads `RMSNormMLP` from the repo and finds `RMSNormMLPLayout` in the same module.
2. Matches the patterns under `model.layers.*` and builds a fused parent class whose `__init__` calls `RMSNormMLPLayout(post_attention_layernorm, mlp)`.
3. Replaces the remaining fused children (here, `mlp`) with `nn.Identity()` to preserve the parent module's interface.
4. Remaps the checkpoint weights with `conversion_mapping`.
5. Calls `kernelize`, which replaces `RMSNormMLPLayout.forward` with `RMSNormMLP.forward`.

> [!TIP]
> The order of pairs in the fusion tuple determines the argument order passed to `KernelNameLayout.__init__`.
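To make the ordering rule concrete, the sketch below mimics the fusion step with plain Python objects. Everything in it is a stand-in: `fuse_children`, `FusedLayout`, and `Identity` are hypothetical, and placing the fused module in the first child's slot is an assumption of this sketch rather than documented behavior.

```python
from types import SimpleNamespace


class Identity:
    """Stand-in for nn.Identity: returns its input unchanged."""

    def __call__(self, x):
        return x


class FusedLayout:
    """Stand-in layout; its argument order follows the fusion tuple."""

    def __init__(self, norm, mlp):
        self.norm_name = norm.name
        self.mlp_name = mlp.name


def fuse_children(parent, child_names, layout_cls):
    """Build the layout from `parent`'s children, in order (hypothetical helper)."""
    children = [getattr(parent, name) for name in child_names]
    fused = layout_cls(*children)
    # Assumption: the fused module takes the first slot, the rest become no-ops.
    setattr(parent, child_names[0], fused)
    for name in child_names[1:]:
        setattr(parent, name, Identity())
    return parent


layer = SimpleNamespace(
    post_attention_layernorm=SimpleNamespace(name="norm"),
    mlp=SimpleNamespace(name="mlp"),
)
fuse_children(layer, ["post_attention_layernorm", "mlp"], FusedLayout)

print(layer.post_attention_layernorm.norm_name)  # norm
print(layer.mlp("unchanged"))                    # unchanged
```

Swapping the two names in the list would pass the MLP as `norm` and the norm as `mlp`, which is why the order of pairs in the config has to match the layout's `__init__` signature.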