
sglang-kernel (prior sgl-kernel)


Kernel Library for LLM inference engines


sglang-kernel provides optimized compute primitives for LLM inference engines, enabling efficient inference for large language models and vision-language models through custom kernel operations. The source tree remains under the sgl-kernel/ directory and the Python import path remains sgl_kernel.

Installation

Requires torch == 2.11.0

bash
# Latest version
pip3 install sglang-kernel --upgrade

Building from Source

Requires

  • CMake ≥ 3.31
  • Python ≥ 3.10
  • scikit-build-core
  • ninja (optional)

Use the Makefile to build from the sgl-kernel source tree:

bash
make build

Limit build resource usage (CPU / parallelism)

By default, make build uses all available CPU cores. You can override build parallelism and NVCC compile threads:

bash
# Limit parallel jobs (controls both make and cmake parallelism)
make build MAX_JOBS=2

# Additionally limit NVCC internal threads (reduces CPU and peak memory)
make build MAX_JOBS=2 CMAKE_ARGS="-DSGL_KERNEL_COMPILE_THREADS=1"
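When the build is driven from a script, the same limits can be assembled programmatically. The following is a minimal sketch, not part of sglang-kernel: `build_command` and `half_the_cores` are hypothetical helper names, and using half of the available cores is only an illustrative policy. The only facts taken from this README are the `make build` target, the `MAX_JOBS` variable, and the `SGL_KERNEL_COMPILE_THREADS` CMake option.

```python
import os


def half_the_cores():
    """Illustrative policy (an assumption): use half the cores, at least one."""
    return max(1, (os.cpu_count() or 1) // 2)


def build_command(max_jobs=None, nvcc_threads=None):
    """Return the argv list for `make build` with optional resource limits."""
    cmd = ["make", "build"]
    if max_jobs is not None:
        # Controls both make and cmake parallelism
        cmd.append(f"MAX_JOBS={max_jobs}")
    if nvcc_threads is not None:
        # Limits NVCC internal compile threads
        cmd.append(f"CMAKE_ARGS=-DSGL_KERNEL_COMPILE_THREADS={nvcc_threads}")
    return cmd


# Print the command instead of running it, so the sketch has no side effects.
print(" ".join(build_command(max_jobs=half_the_cores(), nvcc_threads=1)))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from inside the sgl-kernel directory. Because each argument is a separate list element, no shell quoting is needed around `CMAKE_ARGS`.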

Contribution

Steps to add a new kernel:

  1. Implement the kernel in csrc
  2. Expose the interface in include/sgl_kernel_ops.h
  3. Create torch extension in csrc/common_extension.cc
  4. Update CMakeLists.txt to include new CUDA source
  5. Expose Python interface in python
  6. Add test and benchmark

Development Tips

  1. When creating torch extensions, add the function definition with m.def, and device binding with m.impl:
  • How to write schema: Schema reference

    cpp
    // We need def with schema here for torch.compile
    m.def(
     "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
     "int cublas_handle) -> ()");
    m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
    

Adapting C++ Native Types for Torch Compatibility

Third-party C++ libraries often use int and float, but PyTorch bindings require int64_t and double due to Python's type mapping.

Use make_pytorch_shim from sgl_kernel_torch_shim.h to handle conversions automatically:

cpp
// Add type conversion for int -> int64_t
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) {
    // Reject values that do not fit in a C++ int before narrowing
    TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "value too large");
    TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "value too small");
    return static_cast<int>(arg);
  }
};

cpp
// Wrap your function
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
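Python integers are arbitrary precision, so it can help to see the same range check from the caller's side. The sketch below mirrors the bounds test in `convert_from_type`, assuming a 32-bit C++ `int` (true on common platforms but not guaranteed by the standard). `checked_int32` is a hypothetical name used only for illustration. It is not part of sgl_kernel.

```python
# Bounds of a 32-bit signed integer (assumed width of C++ `int`)
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def checked_int32(value: int) -> int:
    """Return value unchanged if it fits in a 32-bit signed int, else raise."""
    if value > INT32_MAX:
        raise ValueError("value too large")
    if value < INT32_MIN:
        raise ValueError("value too small")
    return value


print(checked_int32(1024))
```

Validating in Python is optional, since the shim performs the check again in C++, but it surfaces the problem before the call crosses into the extension.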

Testing & Benchmarking

  1. Add pytest tests in tests/. If you need to skip a test, use @pytest.mark.skipif:

python
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)

  2. Add benchmarks using triton benchmark in benchmark/

    We recommend using triton.testing.do_bench_cudagraph for kernel benchmarking.

    Compared to triton.testing.do_bench, do_bench_cudagraph provides:

    • Reduced CPU overhead impact for more accurate kernel performance measurements
    • Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results
    • More realistic performance data on PDL-supported architectures (SM >= 90)
  3. Run the test suite
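The `skip_condition` passed to `@pytest.mark.skipif` has to be computed somewhere. One way to do so is sketched below, under stated assumptions: `should_skip_below_capability` is a hypothetical helper name, and treating a missing torch install or a missing GPU as a reason to skip is an illustrative choice, not a convention documented by this project. The torch calls used are `torch.cuda.is_available` and `torch.cuda.get_device_capability`.

```python
def should_skip_below_capability(min_major: int) -> bool:
    """Return True when a test that needs SM >= min_major should be skipped."""
    try:
        import torch
    except ImportError:
        # No torch installed: nothing to run the kernel with
        return True
    if not torch.cuda.is_available():
        # No CUDA device visible
        return True
    major, _minor = torch.cuda.get_device_capability()
    return major < min_major


# Condition for tests that require compute capability 10 or above
skip_condition = should_skip_below_capability(10)
print(skip_condition)
```

The resulting boolean can be passed as the first argument to `@pytest.mark.skipif`, with the reason string explaining the hardware requirement.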

Kernel Size Analysis

Analyze CUDA kernel sizes in compiled wheel files to identify oversized kernels and template-instantiation bloat.

This tool requires cubloaty (install with pip install cubloaty) to work.

bash
# Install cubloaty
pip install cubloaty

# Analyze a wheel file
python analyze_whl_kernel_sizes.py path/to/sglang_kernel-*.whl

# Custom output file
python analyze_whl_kernel_sizes.py path/to/sglang_kernel-*.whl --output my_analysis.txt

The tool generates:

  • A text report with:
    • Kernel groups (by name prefix)
    • Individual kernel sizes (sorted by size)

Use this to identify large kernels and potential template instantiation bloat.
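For a coarse first look before running the analysis tool, the shared libraries inside a wheel can be listed with the standard library alone, since a wheel is a zip archive. This sketch is independent of analyze_whl_kernel_sizes.py and cubloaty. It reports file-level sizes only, not per-kernel sizes, and `shared_object_sizes` is a hypothetical helper name.

```python
import zipfile


def shared_object_sizes(wheel_path):
    """Return (name, uncompressed_bytes) for each .so in the wheel, largest first."""
    with zipfile.ZipFile(wheel_path) as whl:
        entries = [
            (info.filename, info.file_size)
            for info in whl.infolist()
            if info.filename.endswith(".so")
        ]
    return sorted(entries, key=lambda entry: entry[1], reverse=True)


# Example usage (the path is a placeholder):
# for name, size in shared_object_sizes("path/to/wheel.whl"):
#     print(f"{size / 2**20:8.1f} MiB  {name}")
```

A large shared object found this way is a good candidate for the per-kernel breakdown that the analysis tool produces.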

FAQ