examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
CuTe DSL provides autotune and benchmark utilities to help users evaluate and optimize kernel performance. This notebook demonstrates how to use these tools.
We provide two kinds of autotune utilities for users: the autotune_jit decorator and the tune function. The former is used as a decorator on top of @cute.jit, while the latter is used as a standalone function.
We take the elementwise_add_kernel as an example. After writing the jit host function and kernel, we can add the @testing.autotune_jit decorator on top of the jit host function to enable autotune.
@testing.autotune_jit(
params_dict={"copy_bits": [64, 128]},
update_on_change=["M", "N"],
warmup_iterations=100,
iterations=100,
)
The autotune_jit decorator provides several parameters to control the autotuning process:
@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,  # tiled input A, ((TileM,TileN),(RestM,RestN))
    gB: cute.Tensor,  # tiled input B, same tiling as gA
    gC: cute.Tensor,  # tiled output C, same tiling as gA
    cC: cute.Tensor,  # coordinate tensor (identity tensor tiled like gC)
    shape: cute.Shape,  # logical (M, N) extent used for the bounds predicate
    thr_layout: cute.Layout,  # thread layout of the tiled copy
    val_layout: cute.Layout,  # per-thread value layout of the tiled copy
):
    """Compute C = A + B elementwise, one CTA per tile.

    Each block copies its tile of A and B from global memory into
    register fragments, adds them, and writes the result back to C.
    Out-of-bounds elements are masked via a per-element predicate
    built from the coordinate tensor `cC` and `shape`.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    # slice for CTAs
    # logical id -> address: block index selects this CTA's tile
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]  # (TileM,TileN)
    blkB = gB[blk_coord]  # (TileM,TileN)
    blkC = gC[blk_coord]  # (TileM,TileN)
    blkCrd = cC[blk_coord]  # (TileM, TileN)
    # declare the atoms which will be used later for memory copy
    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)
    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)
    # partition the tiles across threads according to the tiled copies
    thr_copy_A = tiled_copy_A.get_slice(tidx)
    thr_copy_B = tiled_copy_B.get_slice(tidx)
    thr_copy_C = tiled_copy_C.get_slice(tidx)
    thrA = thr_copy_A.partition_S(blkA)
    thrB = thr_copy_B.partition_S(blkB)
    thrC = thr_copy_C.partition_S(blkC)
    # allocate fragments for gmem->rmem
    frgA = cute.make_fragment_like(thrA)
    frgB = cute.make_fragment_like(thrB)
    frgC = cute.make_fragment_like(thrC)
    # build the bounds predicate: True iff the element's logical
    # coordinate lies inside `shape` (guards partial tiles)
    thrCrd = thr_copy_C.partition_S(blkCrd)
    frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)
    for i in range(0, cute.size(frgPred), 1):
        val = cute.elem_less(thrCrd[i], shape)
        frgPred[i] = val
    ##########################################################
    # Move data to reg address space
    ##########################################################
    cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
    cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
    # Load data before use. The compiler will optimize the copy and load
    # operations to convert some memory ld/st into register uses.
    result = frgA.load() + frgB.load()
    # Store the sum into C's register fragment.
    frgC.store(result)
    # Copy the results back to c (masked by the same predicate)
    cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
@testing.autotune_jit(
    params_dict={"copy_bits": [64, 128]},  # candidate configs to sweep
    update_on_change=["M", "N"],  # re-tune when the problem size changes
    warmup_iterations=100,
    iterations=100,
)
@cute.jit
def elementwise_add_autotune(mA, mB, mC, M, N, copy_bits: cutlass.Constexpr = 128):
    """JIT host function: tile A/B/C and launch the elementwise-add kernel.

    `copy_bits` (a compile-time constant) sets the per-copy vector width;
    autotune sweeps it over the values in `params_dict`.
    """
    dtype = mA.element_type
    # elements moved per copy instruction for this dtype
    vector_size = copy_bits // dtype.width
    thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
    val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    gA = cute.zipped_divide(mA, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    gB = cute.zipped_divide(mB, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    gC = cute.zipped_divide(mC, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    # identity tensor supplies logical coordinates for bounds masking
    idC = cute.make_identity_tensor(mC.shape)
    cC = cute.zipped_divide(idC, tiler=tiler_mn)
    # one CTA per rest-mode tile; one thread per TV-layout thread slot
    elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )
When we run the jit function elementwise_add_autotune, the CuTe DSL will tune the kernel by looping over the specified configs and then run the kernel with the best config.
# Problem size, element type, and whether to skip the reference check.
M, N = 1024, 1024
dtype = cutlass.Float32
skip_ref_check = False
print(f"\nRunning Elementwise Add test with:")
print(f"Tensor dimensions: [{M}, {N}]")
print(f"Input and Output Data type: {dtype}")
# Map the cutlass dtype to the matching torch dtype.
torch_dtype = cutlass_torch.dtype(dtype)
# Random inputs and a zeroed output, all on the GPU.
a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
c = torch.zeros_like(a)
print(f"Input tensor shapes:")
print(f"a: {a.shape}, dtype: {a.dtype}")
print(f"b: {b.shape}, dtype: {b.dtype}")
print(f"c: {c.shape}, dtype: {c.dtype}\n")
# First call triggers the autotune sweep; later calls reuse the cached best config.
elementwise_add_autotune(a, b, c, M, N)
if not skip_ref_check:
    print("Verifying results for autotuned function ...")
    # Compare against the eager torch reference.
    torch.testing.assert_close(a + b, c)
    print("Results verified successfully!")
The output is as follows:
Running Elementwise Add test with:
Tensor dimensions: [1024, 1024]
Input and Output Data type: Float32
Input tensor shapes:
a: torch.Size([1024, 1024]), dtype: torch.float32
b: torch.Size([1024, 1024]), dtype: torch.float32
c: torch.Size([1024, 1024]), dtype: torch.float32
Verifying results for autotuned function ...
Results verified successfully!
To monitor the autotuning process in detail, you can enable logging by setting the environment variable CUTE_DSL_LOG_AUTOTUNE.
export CUTE_DSL_LOG_AUTOTUNE=1
This will display comprehensive information including:
Below is a sample output showing the autotuning process with different configurations:
2025-07-23 06:17:03,978 - cutlass.cute.testing_Autotune - INFO - Tuning configuration: {'copy_bits': 64}
2025-07-23 06:17:04,519 - cutlass.cute.testing_Autotune - INFO - Execution time: 0.010857919985428453 us
2025-07-23 06:17:04,519 - cutlass.cute.testing_Autotune - INFO - Tuning configuration: {'copy_bits': 128}
2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Execution time: 0.011117440033704042 us
2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Best configuration: {'copy_bits': 64}, execution time: 0.010857919985428453 us
2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Total tuning time: 0.7053244113922119 s
...
2025-07-23 06:17:04,700 - cutlass.cute.testing_Autotune - INFO - Using cached best configuration: {'copy_bits': 64}
We also provide a tune function. The interface of the tune function is as follows:
def tune(
func: Callable[[Any], Callable[[], Any]],
params_dict: Dict[str, List[Any]] = None,
kernel_arguments: JitArguments = JitArguments(),
warmup_iterations=10,
iterations=100,
stream: Optional[cuda_driver.CUstream] = None,
) -> Dict[str, Any]:
The tune function takes the following parameters:
It returns a dictionary containing the best kernel configuration found.
Here is an example to use the tune function:
First remove the @testing.autotune_jit decorator from the elementwise_add_autotune function:
@testing.autotune_jit(
params_dict={"copy_bits": [64, 128]},
update_on_change=["M", "N"],
warmup_iterations=100,
iterations=100,
)
Define a tune_func that:
Compiles the elementwise_add_autotune function using cute.compile(). Then pass tune_func to the testing.tune function along with:
# The tune function will find optimal parameters automatically.
def tune_func(a, b, c, M, N, copy_bits=128):
    """Build one tuning candidate: compile the kernel for `copy_bits`.

    Returns a zero-argument callable that runs the compiled kernel, which
    is what testing.tune benchmarks for each configuration.

    Fix: the original hard-coded copy_bits=128 in the cute.compile call,
    so every candidate in the sweep compiled (and timed) the identical
    kernel; the parameter must be forwarded for tuning to be meaningful.
    """
    compiled_func = cute.compile(
        elementwise_add_autotune, a, b, c, M, N, copy_bits=copy_bits
    )
    return lambda: compiled_func(a, b, c, M, N)
# Sweep copy_bits over the candidates; testing.tune returns the best config.
params = testing.tune(
    tune_func,
    params_dict={"copy_bits": [64, 128]},
    kernel_arguments=testing.JitArguments(a, b, c, M, N),
)
print(f"The best kernel configs found: {params}")
# run the kernel with the best config
compiled_func = cute.compile(elementwise_add_autotune, a, b, c, M, N, **params)
compiled_func(a, b, c, M, N)
In CuTe DSL, the benchmark utility can be used to measure kernel execution time. The interface of benchmark routine is as follows:
def benchmark(
callable: Callable,
*,
warmup_iterations: int = 10,
iterations: int = 100,
stream: Optional[cuda_driver.CUstream] = None,
kernel_arguments: Optional[JitArguments] = None,
workspace_generator: Optional[Callable[[], JitArguments]] = None,
workspace_count: int = 1,
use_cuda_graphs: bool = False,
) -> float:
The benchmark utility exposes several key configuration parameters to control profiling behavior:
When benchmarking, there are several key parameters that can be configured:
Core parameters:
Stream configuration:
Cache effects mitigation:
CUDA Graph support:
This function will return the execution time of the callable in microseconds. As GPU frequency can vary dynamically, we can fix the SM and memory frequencies to get more stable and reproducible benchmark results. This can be done by setting the GPU clocks using nvidia-smi before running the benchmark. Next, let's use the benchmark function to get the execution time of the above elementwise_add kernel.
def generate_kernel_arguments():
    """Allocate a fresh workspace (a, b, c, M, N) for one benchmark round.

    The benchmark utility calls this once per workspace so successive
    iterations operate on distinct buffers, mitigating cache effects.
    """
    device = torch.device("cuda")
    lhs = torch.randn(M, N, device=device, dtype=torch_dtype)
    rhs = torch.randn(M, N, device=device, dtype=torch_dtype)
    out = torch.zeros_like(lhs)
    return testing.JitArguments(lhs, rhs, out, M, N)
# Time the autotuned kernel; rotate through 10 fresh workspaces to
# mitigate cache effects. Result is the average time in microseconds.
avg_time_us = testing.benchmark(
    elementwise_add_autotune,
    workspace_generator=generate_kernel_arguments,
    workspace_count=10,
    warmup_iterations=10,
    iterations=100,
)
# Print execution results
print(
    f"Kernel execution time for cute.jit kernel with M={M}, N={N}: {avg_time_us / 1e3:.4f} ms"
)
# Throughput model: 2 reads + 1 write of M*N elements, dtype.width bits each.
print(
    f"Achieved memory throughput for M={M}, N={N}: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
)
After running the code, we will get output similar to the following:
Kernel execution time for cute.jit kernel with M=1024, N=1024: 0.0403 ms
Achieved memory throughput for M=1024, N=1024: 312.37 GB/s