# Template Lookup Table
The template lookup table system provides a way to pre-configure kernel template parameters for specific operations and input configurations, replacing the default choice generation and bypassing autotuning. It is orthogonal to max-autotune(-gemm) in the following way:

- If a lookup table is provided and an entry matches, the pre-configured choices are used directly.
- If there is no match, we fall back to the default choice generation process, including the max-autotune(-gemm) logic.
Enable the system by setting both:

```python
from torch._inductor import config
config.lookup_table.table = your_table_dict

# You also need to set it as the default choice handler
from torch._inductor.lookup_table import LookupTableChoices
torch._inductor.V.set_choices_handler(LookupTableChoices())
```
The key schema format is described in detail in the Key Schemas section below.
Configure device key behavior: a config option controls whether entries include device-specific keys for lookups; device-agnostic entries work across different GPU models.
**Lookup behavior:** During lookup, the system automatically tries both key formats:

1. Device-specific key (e.g., `"NVIDIA H100+input_data+mm"`) - tried first
2. Device-agnostic key (e.g., `"input_data+mm"`) - tried if the device-specific lookup fails

**Priority:** If both device-specific and device-agnostic entries exist for the same inputs, the device-specific entry takes priority.
NOTE: Device-based keys simplify hardware-specific optimization without complex build rules. Currently limited to device name only. If you need additional conditional key attributes (e.g., CUDA version filtering), please file an issue or submit a patch.
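The fallback behavior above can be sketched in plain Python. This is an illustrative sketch only: the `lookup` helper and its signature are hypothetical, not the actual internal API.

```python
# Illustrative sketch of the two-step key lookup (hypothetical helper,
# not the real internal API).
def lookup(table, device_name, input_key, op_name, extra_params=None):
    parts = [input_key] + ([extra_params] if extra_params else []) + [op_name]
    device_agnostic = "+".join(parts)
    device_specific = device_name + "+" + device_agnostic
    # Device-specific entries take priority; fall back to device-agnostic.
    return table.get(device_specific) or table.get(device_agnostic)
```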
When the table is active, matching entries replace default choice generation for all supported operations.

Currently supports: `mm`, `addmm`, `bmm`, `mm_plus_mm`, and `scaled_mm` operations.
The table is a dictionary with keys in the format:

```
"input_key+op_name"
```

Where:

- `input_key`: Generated from the `KernelInputs.key` property; represents tensor shapes/dtypes/strides
- `op_name`: Operation name (`"mm"`, `"addmm"`, etc.)

Each value is a list of configuration dictionaries containing:

- `template_id`: Template identifier (`"triton::mm"`, `"triton::mm_persistent_tma"`, `"decompose_k"`, etc.)
- Template-specific parameters (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `num_warps`, etc.)

NOTE: The key schema format is subject to change as the system evolves.
## Key Schemas

The lookup table uses composite keys to match kernel configurations. See Implementation Details below for more technical information about key generation. This section describes the structure of these keys.

Keys follow the pattern:

```
[device_name+]input_key+[additional_params+]op_name
```
Components:

- `device_name` (optional): GPU device identifier (e.g., `"NVIDIA H100"`), taken from `torch.cuda.get_device_properties().gcnArchName`
- `input_key`: Tensor configuration representation from `KernelInputs.key`
  - Format: `((dtype, shape, stride), (dtype, shape, stride), ...)`
  - Example: `((torch.float16, [128, 256], [0, 1]), (torch.float16, [64, 256], [256, 1]))`
- `additional_params` (optional): Operation-specific parameters
  - Format: `key1=value1&key2=value2`, e.g., `alpha=1&beta=1` for addmm operations
- `op_name`: Operation identifier: `"mm"`, `"addmm"`, `"bmm"`, `"mm_plus_mm"`, `"scaled_mm"`

Device-specific key for addmm:

```
"NVIDIA H100+((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm"
```

Device-agnostic key for mm:

```
"((torch.float16, [64, 128], [128, 1]), (torch.float16, [128, 256], [256, 1]))+mm"
```

Key with no additional parameters:

```
"((torch.float32, [512, 512], [512, 1]), (torch.float32, [512, 512], [512, 1]))+bmm"
```
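A key following this schema can be assembled mechanically from its components. The `make_key` helper below is hypothetical, shown only to illustrate the component ordering and separators:

```python
# Hypothetical helper illustrating the key schema; not part of the real API.
def make_key(input_key, op_name, device_name=None, params=None):
    parts = []
    if device_name:
        parts.append(device_name)  # optional device prefix
    parts.append(input_key)        # from KernelInputs.key
    if params:
        # Operation-specific parameters as key1=value1&key2=value2
        parts.append("&".join(f"{k}={v}" for k, v in sorted(params.items())))
    parts.append(op_name)
    return "+".join(parts)
```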
During lookup, the system tries keys in priority order: the device-specific key first, then the device-agnostic key. This allows tables to contain a mix of device-specific and device-agnostic entries.
This is an example table for a single input, showing two configurations:

```python
table = {
    "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm": [
        {
            "template_id": "triton::mm",
            "EVEN_K": True,
            "USE_FAST_ACCUM": False,
            "ACC_TYPE": "tl.float32",
            "num_stages": 2,
            "num_warps": 4,
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_K": 64,
            "hint_override": None,
            "GROUP_M": 8,
            "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896",
        },
        {
            "template_id": "aten::bias_addmm",
        },
    ]
}
```
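Since the table is a plain dictionary of strings, lists, and scalars, it can also be serialized as JSON and loaded at startup. The `load_table` helper and its validation below are illustrative assumptions, not the actual API; note that JSON `true`/`false`/`null` map to Python `True`/`False`/`None` on load:

```python
import json

# Illustrative loader (hypothetical helper, not the real API).
def load_table(json_text):
    table = json.loads(json_text)
    for key, configs in table.items():
        for cfg in configs:
            # Every configuration must at least name its template.
            if "template_id" not in cfg:
                raise ValueError(f"entry under {key!r} is missing 'template_id'")
    return table
```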
The lookup table system includes source hashing to prevent using stale configurations when template code changes.

- Enable checking via `torch._inductor.config.lookup_table.check_src_hash = True`
- Add a `"template_hash"` to table entries for enhanced safety

When source hash checking is enabled, `"template_hash"` fields are validated against current template source hashes:

```python
{
    "template_id": "triton::mm",
    "BLOCK_M": 32,
    "BLOCK_N": 32,
    "BLOCK_K": 16,
    "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896",
}
```
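The exact hashing scheme is internal to Inductor, but conceptually a template hash is a digest of the template's source text, and an entry is stale when its recorded hash no longer matches. A sketch under that assumption (sha256 here is illustrative, not necessarily the scheme used):

```python
import hashlib

# Conceptual sketch only: the real hashing scheme is internal to Inductor.
def template_src_hash(template_src: str) -> str:
    return hashlib.sha256(template_src.encode("utf-8")).hexdigest()

def is_stale(entry: dict, current_src: str) -> bool:
    # Stale = the recorded hash no longer matches the current template
    # source; stale entries should be skipped rather than used.
    recorded = entry.get("template_hash")
    return recorded is not None and recorded != template_src_hash(current_src)
```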
## Implementation Details

- The device name comes from `torch.cuda.get_device_properties().gcnArchName` (e.g., `"NVIDIA H100"`)
- The input key comes from `KernelInputs.key`, containing tensor properties

The system is accessed through:

- `lookup_template_configs(kernel_inputs, op_name, template_uids)` - Main lookup function
- `LookupTableChoices._finalize_template_configs()` - Integration point with the existing choice system
- Entries are matched to templates via their `template_id` field
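The filtering step at the end of a lookup can be pictured as a simple comprehension. This sketch illustrates the behavior, not the actual implementation:

```python
# Illustrative sketch: keep only looked-up configs whose template_id is
# among the template UIDs the caller asked for.
def filter_by_template_uids(configs, template_uids):
    return [cfg for cfg in configs if cfg.get("template_id") in template_uids]
```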