Step-by-step tutorial for adding new CUDA kernels to FlashInfer
This tutorial walks through adding a simple element-wise scale operation to FlashInfer. We'll implement scale(x, factor) = x * factor to demonstrate the complete workflow.
Add a new operation that scales each element of a tensor by a scalar factor:
Input: a tensor x and a scalar factor
Output: x * factor (element-wise)

Define the CUDA kernel in include/flashinfer/. Create include/flashinfer/scale.cuh:
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace flashinfer {
/*!
* \brief Element-wise scale kernel
* \tparam T Data type (half, __nv_bfloat16, float)
* \param input Input tensor
* \param output Output tensor
* \param factor Scale factor
* \param n Number of elements
*/
template <typename T>
__global__ void ScaleKernel(const T* input, T* output, T factor, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = input[idx] * factor;
}
}
/*!
* \brief Launch scale kernel
* \tparam T Data type
* \param input Input pointer
* \param output Output pointer
* \param factor Scale factor
* \param n Number of elements
* \param stream CUDA stream
*/
template <typename T>
cudaError_t ScaleLauncher(const T* input, T* output, T factor, int n,
cudaStream_t stream = nullptr) {
const int threads = 256;
const int blocks = (n + threads - 1) / threads;
ScaleKernel<T><<<blocks, threads, 0, stream>>>(input, output, factor, n);
return cudaGetLastError();
}
} // namespace flashinfer
Key points:
- The kernel and launcher are header-only templates under include/flashinfer/, so they can be instantiated for half, __nv_bfloat16, or float.
- The launcher computes the grid size, launches the kernel on the given CUDA stream, and returns cudaGetLastError() so callers can detect launch failures.
Add the launcher in csrc/. Create csrc/scale.cu:
#include "flashinfer/scale.cuh"
using namespace flashinfer;
void scale_launcher(TensorView input, TensorView output,
float factor) {
CHECK_INPUT(input);
CHECK_INPUT(output);
TVM_FFI_ICHECK_EQ(input.dtype(), output.dtype());
int n = input.numel();
auto stream = get_stream(input.device());
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(input.dtype(), DType, [&] {
cudaError_t status = ScaleLauncher<DType>(
input.data_ptr<DType>(),
output.data_ptr<DType>(),
static_cast<DType>(factor),
n,
stream
);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "Failed to run ScaleLauncher: " << cudaGetErrorString(status);
return true;
});
}
Key points:
- Include tvm_ffi_utils.h (only allowed in csrc/).
- Use tvm::ffi::TensorView as the input and output tensor types.
- Use CHECK_INPUT from tvm_ffi_utils.h to check that the input and output are both on a CUDA device, both contiguous, and share the same data type.
- Use the dtype dispatch macros from tvm_ffi_utils.h, or add a new one if the needed dtypes are not covered.
- Report errors with TVM_FFI_ICHECK and the << operator, or with TVM_FFI_THROW(ErrorType) << "message" for custom error checking.

TVM-FFI Error Handling:
- TVM_FFI_THROW(ValueError) << "message" - throw a ValueError with a custom message.
- TVM_FFI_THROW(TypeError) << "message" - throw a TypeError.
- Use << to chain multiple values in the error message.

When to use TVM_FFI_THROW vs TVM_FFI_LOG_AND_THROW:
TVM_FFI_THROW: Use for normal runtime error handling. This is the standard way to report errors that will be caught and propagated to Python.
TVM_FFI_LOG_AND_THROW: Use only where the exception may not propagate correctly to Python, such as setup or validation code that runs during construction, or internal errors that should never happen.
This variant logs the error message before throwing, ensuring visibility even if the exception doesn't propagate correctly.
Example from fused_moe (see csrc/trtllm_fused_moe_kernel_launcher.cu):
// In a setup/validation function that may be called during construction
void check_weights_shape(std::string which_weights) const {
if (which_weights != "gemm1" && which_weights != "gemm2") {
// Internal error that should never happen - use LOG_AND_THROW
TVM_FFI_LOG_AND_THROW(InternalError) << "Internal error: which_weights = " << which_weights;
}
// ...
if (weight_layout is unsupported) {
// Unsupported config during setup - use LOG_AND_THROW
TVM_FFI_LOG_AND_THROW(NotImplementedError)
<< "Unsupported weight_layout: " << (int)weight_layout;
}
}
// In a normal runtime function
void scale_run(TensorView input, TensorView output, double factor) {
if (!input.is_cuda()) {
// Normal validation error - use TVM_FFI_THROW
TVM_FFI_THROW(ValueError) << "Input must be a CUDA tensor";
}
}
Add the TVM-FFI binding in csrc/. Create csrc/scale_jit_binding.cu:
#include "scale.cu"
#include "tvm_ffi_utils.h"
// Forward declaration
void scale_launcher(TensorView input, TensorView output, float factor);
// Export to TVM-FFI
TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, scale_launcher);
Key points:
- Use TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, function) to export the launcher; here it is exported under the name run.

Add the JIT generator in flashinfer/jit/. Create flashinfer/jit/scale.py:
import os
import shutil

from . import JitSpec, gen_jit_spec
from . import env as jit_env
def get_scale_uri(dtype_in: str, dtype_out: str) -> str:
"""Generate unique identifier for scale module."""
return f"scale_dtype_in_{dtype_in}_dtype_out_{dtype_out}"
def gen_scale_module(dtype_in: str, dtype_out: str) -> JitSpec:
"""
Generate JIT module for scale operation.
Note: This is a simple example without Jinja templating.
The dtype dispatch is handled at runtime in the C++ code.
"""
# Compute URI
uri = get_scale_uri(dtype_in, dtype_out)
# Create generation directory
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
os.makedirs(gen_directory, exist_ok=True)
# Copy source files (no Jinja needed for this simple case)
sources = []
for fname in ["scale.cu", "scale_jit_binding.cu"]:
src_path = jit_env.FLASHINFER_CSRC_DIR / fname
dest_path = gen_directory / fname
shutil.copy(src_path, dest_path)
sources.append(dest_path)
# Return JitSpec
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=[],
)
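Once the spec is generated it can be compiled and loaded on demand; the Python API in a later step does exactly this. A minimal sketch (the tensor names are placeholders, assuming the gen_scale_module defined above):

import torch
from flashinfer.jit.scale import gen_scale_module

# Build (or reuse a cached build of) the module, then call the function
# exported as "run" by the TVM-FFI binding.
module = gen_scale_module("float16", "float16").build_and_load()

x = torch.randn(1024, dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
module.run(x, y, 2.0)  # y = x * 2.0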
Key points:
- gen_jit_spec declares the module name (the URI) and the list of source files to compile; see CLAUDE.md for more details on the JIT build system.
- The URI must uniquely identify the configuration so that different dtype combinations get separate cached builds.

CUDA Architecture Targets (CompilationContext)

FlashInfer uses CompilationContext to manage CUDA architecture targets. This is critical because some kernels only work on specific GPU architectures (e.g., Hopper SM90, Blackwell SM100).
Automatic Detection (default):
from flashinfer.compilation_context import CompilationContext
ctx = CompilationContext()
# Automatically detects all GPUs in the system
# For SM90+, adds 'a' suffix (e.g., 9.0a for Hopper)
# Result: ctx.TARGET_CUDA_ARCHS = {(9, '0a'), (10, '0a'), ...}
Manual Override (via environment variable):
export FLASHINFER_CUDA_ARCH_LIST="8.0 9.0a 10.0a"
# Now only these architectures will be compiled
When creating a JIT module, specify which major SM versions are supported:
from flashinfer.jit.core import gen_jit_spec
from flashinfer.jit import current_compilation_context
def gen_my_hopper_only_module():
"""Example: Kernel works on SM90 and later supported architectures."""
uri = get_my_uri(...)
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
# ... copy sources ...
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
# Explicitly list supported SM versions - no automatic future compatibility
supported_major_versions=[9, 10, 11, 12] # SM90, SM100, SM110, SM120
)
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=nvcc_flags,
)
def gen_my_blackwell_only_module():
"""Example: Kernel only works on SM100 (Blackwell)"""
uri = get_my_uri(...)
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
# ... copy sources ...
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10] # SM100 only
)
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=nvcc_flags,
)
def gen_my_universal_module():
"""Example: Kernel works on all architectures"""
uri = get_my_uri(...)
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
# ... copy sources ...
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=None # All available architectures
)
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=nvcc_flags,
)
What Happens:
If none of the requested major versions are available (either on the detected GPUs or in FLASHINFER_CUDA_ARCH_LIST), module generation fails with an error such as:

RuntimeError: No supported CUDA architectures found for major versions [9, 10, 11, 12]

Real examples from the codebase:

# MLA kernel: Blackwell and newer only
def gen_mla_module() -> JitSpec:
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11] # SM100, SM110
)
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=nvcc_flags,
)
# Blackwell FMHA: SM120 only
def gen_fmhav2_blackwell_module(...):
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[12] # SM120 only
)
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=nvcc_flags,
)
# Standard attention: Hopper and later supported architectures
def gen_batch_prefill_module(...):
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[9, 10, 11, 12] # SM90, SM100, SM110, SM120
)
return gen_jit_spec(
name=uri,
sources=sources,
extra_cuda_cflags=nvcc_flags,
)
| Supported Versions | Architectures | Use Case |
|---|---|---|
| None | All available GPUs | Universal kernels (default) |
| [9, 10, 11, 12] | SM90, SM100, SM110, SM120 | Hopper, Blackwell |
| [10, 11, 12] | SM100, SM110, SM120 | Blackwell only |
| [12] | SM120 | Specific architecture only |
| [8, 9, 10, 11, 12] | SM80, SM90, SM100, SM110, SM120 | Ampere, Hopper, Blackwell |
When your kernel has architecture requirements, add skip checks in tests (see the testing step below):
import pytest
import torch
from flashinfer.utils import is_sm90a_supported
def test_hopper_kernel():
if not is_sm90a_supported(torch.device("cuda")):
pytest.skip("SM90a is not supported on this GPU")
# Test code here
...
Add the Python API in flashinfer/. Create flashinfer/scale.py:
import functools
import torch
from typing import Optional
from .jit.scale import gen_scale_module
from .utils import backend_requirement, supported_compute_capability
from .api_logging import flashinfer_api
@functools.cache
def get_scale_module(dtype_in, dtype_out):
"""Get or compile scale module (cached)."""
return gen_scale_module(dtype_in, dtype_out).build_and_load()
@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120])
def _check_scale_problem_size(input: torch.Tensor, factor: float,
out: Optional[torch.Tensor] = None) -> bool:
"""Validate inputs for scale operation."""
# Validate input
if not input.is_cuda:
raise ValueError("Input must be a CUDA tensor")
# Validate output if provided
if out is not None:
if out.shape != input.shape:
raise ValueError("Output shape mismatch")
if out.dtype != input.dtype:
raise ValueError("Output dtype mismatch")
if not out.is_cuda:
raise ValueError("Output must be a CUDA tensor")
return True
@flashinfer_api
@backend_requirement(
backend_checks={}, # No backend choices for this simple kernel
common_check=_check_scale_problem_size,
)
def scale(input: torch.Tensor, factor: float,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Element-wise scale operation.
Parameters
----------
input : torch.Tensor
Input tensor (CUDA)
factor : float
Scale factor
out : Optional[torch.Tensor]
Output tensor (if None, allocate new tensor)
Returns
-------
output : torch.Tensor
Scaled tensor (input * factor)
Examples
--------
>>> import torch
>>> import flashinfer
>>> x = torch.randn(1024, dtype=torch.float16, device="cuda")
>>> y = flashinfer.scale(x, 2.0)
>>> torch.allclose(y, x * 2.0)
True
"""
# Allocate output if needed
if out is None:
out = torch.empty_like(input)
# Get module (compile if first call with this dtype)
dtype_str = str(input.dtype).replace("torch.", "")
module = get_scale_module(dtype_str, dtype_str)
# Call TVM-FFI function (exported as "run")
module.run(input, out, float(factor))
return out
Key points:
- Use @functools.cache to cache compiled modules per dtype combination.
- The @flashinfer_api decorator enables logging and sets the public API apart from helper functions.
- Accept an optional output tensor (out: Optional[torch.Tensor] = None). This allows users to provide pre-allocated buffers to avoid allocation overhead in performance-critical paths. Only allocate internally when out is None.
- Enforce problem-size and architecture requirements with the @backend_requirement decorator.

The @backend_requirement and @supported_compute_capability Decorators

FlashInfer provides two decorators for enforcing compute capability and backend requirements:
The @supported_compute_capability Decorator

Marks a function with its supported CUDA compute capabilities:
from flashinfer.utils import supported_compute_capability
@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120])
def _my_check_function(input, output):
"""Supports SM80 (Ampere) through SM120 (Blackwell)."""
# Validation logic here
return True
The @backend_requirement Decorator

Enforces backend and problem size requirements at runtime. There are three usage patterns:
Pattern 1: Single Backend (No Backend Choices)
For kernels with only one implementation (like our scale example):
from flashinfer.utils import backend_requirement, supported_compute_capability
@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120])
def _check_my_kernel(input, output):
"""Validate inputs. Must return True if valid."""
if input.shape[-1] > 256:
raise ValueError("Head dimension must be <= 256")
return True
@backend_requirement(
backend_checks={}, # Empty dict = no backend parameter
common_check=_check_my_kernel,
)
def my_kernel(input, output):
# Kernel implementation
pass
Pattern 2: Multiple Backends
For kernels with multiple implementation backends (e.g., CUTLASS, cuDNN):
@supported_compute_capability([80, 86, 89, 90])
def _cutlass_check(q, k, v, backend):
"""CUTLASS backend: Ampere through Hopper."""
if q.shape[-1] > 256:
raise ValueError("CUTLASS: head_dim must be <= 256")
return True
@supported_compute_capability([75, 80, 86, 89, 90, 100])
def _cudnn_check(q, k, v, backend):
"""cuDNN backend: Turing through Blackwell."""
return True
@backend_requirement(
backend_checks={
"cutlass": _cutlass_check,
"cudnn": _cudnn_check,
},
common_check=None, # Optional: shared validation for all backends
)
def attention(q, k, v, backend="cutlass"):
if backend == "cutlass":
# CUTLASS implementation
pass
elif backend == "cudnn":
# cuDNN implementation
pass
Pattern 3: Auto Backend Selection
For kernels that can automatically select the best backend:
def _heuristic_func(suitable_backends, q, k, v, backend):
"""Return backends in order of preference."""
# Prefer CUTLASS for small head dims, cuDNN for larger
if q.shape[-1] <= 128:
preferred = ["cutlass", "cudnn"]
else:
preferred = ["cudnn", "cutlass"]
return [b for b in preferred if b in suitable_backends]
@backend_requirement(
backend_checks={
"cutlass": _cutlass_check,
"cudnn": _cudnn_check,
},
common_check=_common_validation,
heuristic_func=_heuristic_func, # Required when backend="auto" is used
)
def attention(q, k, v, backend="auto"):
if backend == "auto":
# Use the first backend from suitable_auto_backends
backend = attention.suitable_auto_backends[0]
# ... rest of implementation
Introspection Methods Added by @backend_requirement

The decorator adds these methods to the wrapped function:
# Check if a backend is supported (optionally for a specific CC)
scale.is_backend_supported("cutlass") # True/False
scale.is_backend_supported("cutlass", cc=90) # True/False for Hopper
# Check if any backend supports this compute capability
scale.is_compute_capability_supported(90) # True/False
# Check if a backend exists
scale.has_backend("cutlass") # True/False
# Check if there are multiple backend choices
scale.has_backend_choices() # True/False
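These methods can also be used in tests to decide whether to skip on the current GPU. A minimal sketch (the test name and structure are illustrative, assuming the scale API defined above):

import pytest
import torch
import flashinfer

def test_scale_on_supported_gpu():
    # Map torch's (major, minor) capability to the integer form used above, e.g. (9, 0) -> 90.
    major, minor = torch.cuda.get_device_capability()
    cc = major * 10 + minor
    if not flashinfer.scale.is_compute_capability_supported(cc):
        pytest.skip(f"scale is not supported on SM{cc}")
    x = torch.randn(256, dtype=torch.float16, device="cuda")
    torch.testing.assert_close(flashinfer.scale(x, 2.0), x * 2.0)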
The skip_check Keyword Argument

The decorator adds a skip_check keyword argument to bypass validation for performance-critical code paths:
# Normal call with validation
result = scale(x, 2.0)
# Skip validation for performance (use with caution!)
result = scale(x, 2.0, skip_check=True)
Check functions must:
- Return True if validation passes.
- Raise ValueError with a descriptive message if validation fails.
- Be decorated with @supported_compute_capability to specify supported architectures.

Trace Templates for flashinfer-bench

Every new kernel must have a TraceTemplate so that flashinfer-bench can auto-generate
benchmark definition files via @flashinfer_api(trace=...).
Define the template in flashinfer/trace/templates/: add a file (or extend an existing one) there. The
real flashinfer/trace/templates/norm.py is a good reference — it shows two
variants that share an op_type but have distinct name_prefix values:
# flashinfer/trace/templates/norm.py (real file, simplified for illustration)
from ..template import Const, Tensor, TraceTemplate, Var
# op_type – high-level operation category written to the JSON "op_type" field.
# Two templates can share the same op_type when they are variants of
# the same operation family.
# name_prefix – base string for the auto-generated filename and JSON "name" field.
# Const axis values are appended, e.g. rmsnorm_h4096.json.
# Must be unique across templates that share an op_type.
rmsnorm_trace = TraceTemplate(
op_type="rmsnorm", # category: all RMSNorm variants share this
name_prefix="rmsnorm", # specific variant → file: rmsnorm_h<hidden>.json
description="Root Mean Square Normalization. Epsilon is fixed at 1e-6.",
axes={
"batch_size": Var(), # runtime-variable: omitted from filename
"hidden_size": Const(abbrev="h"), # baked into filename as "h<value>"
},
inputs={
# json_key "hidden_states" differs from the Python param name "input",
# so param= is set explicitly.
"hidden_states": Tensor(["batch_size", "hidden_size"], param="input"),
"weight": Tensor(["hidden_size"]), # key == param, no param= needed
},
outputs={
"output": Tensor(["batch_size", "hidden_size"], dtype_from="input"),
},
tags=["status:verified"],
reference=_rmsnorm_reference,
)
fused_add_rmsnorm_trace = TraceTemplate(
op_type="rmsnorm", # same category as rmsnorm_trace above
name_prefix="fused_add_rmsnorm", # different variant → fused_add_rmsnorm_h<hidden>.json
description="Fused Add + RMSNorm. Epsilon is fixed at 1e-6.",
axes={
"batch_size": Var(),
"hidden_size": Const(abbrev="h"),
},
inputs={
"hidden_states": Tensor(["batch_size", "hidden_size"], param="input"),
"residual": Tensor(["batch_size", "hidden_size"]),
"weight": Tensor(["hidden_size"]),
},
outputs={
"output": Tensor(["batch_size", "hidden_size"], dtype_from="input"),
"residual": Tensor(
["batch_size", "hidden_size"],
dtype_from="input",
description="Updated residual (in-place: residual += hidden_states).",
),
},
tags=["status:verified", "fused"],
reference=_fused_add_rmsnorm_reference,
)
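For this tutorial's kernel, a minimal flashinfer/trace/templates/scale.py might look like the sketch below. The template name scale_trace, the single n axis, and the reference function are illustrative assumptions (the scalar factor is not modeled as an axis here):

# flashinfer/trace/templates/scale.py (sketch, not an existing file)
from ..template import Tensor, TraceTemplate, Var

def _scale_reference(input, factor):
    # Pure-PyTorch reference used to check traced results.
    return input * factor

scale_trace = TraceTemplate(
    op_type="scale",        # operation category
    name_prefix="scale",    # base for the generated file / JSON "name"
    description="Element-wise scale: output = input * factor.",
    axes={
        "n": Var(),         # runtime-variable length, not baked into the name
    },
    inputs={
        "input": Tensor(["n"]),   # json_key == Python param name, so no param=
    },
    outputs={
        "output": Tensor(["n"], dtype_from="input"),
    },
    reference=_scale_reference,
)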
Key rules:
- Var() → the value is NOT baked into the generated name or JSON value.
- Const(abbrev=...) → the value IS extracted and written to JSON. abbrev="h" → h4096; abbrev="" → omit from the filename.
- A Tensor key defaults to param=key; use param="other_name" when they differ.
- dtype_from="<input_key>" copies the dtype from that input tensor (use the JSON key, not the param name).

If one API function maps to several templates depending on its arguments, pass a callable to trace=:

def _my_trace_dispatch(**kwargs):
if kwargs.get("mode") == "fast":
return fast_trace
return slow_trace
@flashinfer_api(trace=_my_trace_dispatch)
def my_op(..., mode="fast"):
...
See flashinfer/fused_moe/core.py + flashinfer/trace/templates/moe.py for a
real dispatch example keyed on routing_method_type.

Attach the template to the API function via @flashinfer_api(trace=...):

# flashinfer/norm.py (real file)
from .trace.templates.norm import rmsnorm_trace
@flashinfer_api(trace=rmsnorm_trace)
def rmsnorm(input: torch.Tensor, weight: torch.Tensor, ...) -> torch.Tensor:
...
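Applying the same pattern to the scale API from the earlier step would look roughly like this (a sketch; scale_trace refers to the hypothetical template sketched above):

# flashinfer/scale.py (sketch)
from .trace.templates.scale import scale_trace

@flashinfer_api(trace=scale_trace)
@backend_requirement(backend_checks={}, common_check=_check_scale_problem_size)
def scale(input: torch.Tensor, factor: float,
          out: Optional[torch.Tensor] = None) -> torch.Tensor:
    ...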
The fi_api tag is derived automatically from func.__module__ + "." + func.__qualname__.
Open tests/trace/test_fi_trace_template_consistency.py and add your module to
the import list inside _collect_template_func_pairs():
import flashinfer.norm # ← add your module here
That's it. @flashinfer_api(trace=...) automatically registers every
(func, template) pair in flashinfer.api_logging._TRACE_REGISTRY at
decoration time. Importing the module triggers the decorator, and the
parameterized tests then check:
- Every param= reference exists in the function's signature.
- Every Const axis appears in at least one tensor's dim_names or the function's parameter list.
- Calling fi_trace with auto-generated CPU tensors returns a complete dict (no "unknown" dtypes for non-optional inputs, all Const axes have values).

If your template uses tuple inputs or exotic dtypes (fp8 scale tensors, etc.),
add a targeted end-to-end test at the bottom of the file and add your label to
_E2E_SKIP (see the MoE example there).
For dispatch templates (callable trace=), also set a .templates
attribute on the dispatch function listing all possible return values:
def _my_trace_dispatch(**kwargs): ...
_my_trace_dispatch.templates = [fast_trace, slow_trace]
This lets the registry auto-discover and check all variants.
Run the consistency test:

pytest tests/trace/test_fi_trace_template_consistency.py -v
A failing structural test looks like:
AssertionError: [rmsnorm] Template 'rmsnorm' has param mismatches:
Input 'hidden_states' → param='x' not found in rmsnorm(['input', 'weight', 'eps'])
which tells you exactly which key is wrong and what names are available.
Add tests under tests/. Create tests in an appropriate subdirectory (e.g., tests/elementwise/test_scale.py, or create a new subdirectory if needed):
import pytest
import torch
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("size", [128, 1024, 4096])
def test_scale_correctness(dtype, size):
"""Test scale operation correctness."""
import flashinfer
# Setup
x = torch.randn(size, dtype=dtype, device="cuda")
factor = 3.14
# Run FlashInfer kernel
y = flashinfer.scale(x, factor)
# Reference implementation
expected = x * factor
# Compare
if dtype == torch.float32:
rtol, atol = 1e-5, 1e-6
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(y, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scale_inplace(dtype):
"""Test scale with pre-allocated output."""
import flashinfer
x = torch.randn(1024, dtype=dtype, device="cuda")
out = torch.empty_like(x)
factor = 2.0
result = flashinfer.scale(x, factor, out=out)
# Should return the same tensor
assert result is out
# Check correctness
expected = x * factor
torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)
def test_scale_cpu_error():
"""Test that CPU tensors raise an error."""
import flashinfer
x = torch.randn(128, dtype=torch.float32)
with pytest.raises(ValueError, match="CUDA"):
flashinfer.scale(x, 2.0)
Key points:
- Use pytest.mark.parametrize for multiple configurations.

Register your kernel in AOT so users with flashinfer-jit-cache can skip JIT compilation.
Edit flashinfer/aot.py, add to the appropriate section:
def gen_scale_modules() -> Iterator[JitSpec]:
"""Generate scale operation modules for AOT compilation."""
from .jit.scale import gen_scale_module
# Pre-compile common dtypes
for dtype in ["float16", "bfloat16", "float32"]:
yield gen_scale_module(dtype, dtype)
# In the main AOT build loop, add:
# for spec in gen_scale_modules():
# spec.build()
Key points:
- Users with flashinfer-jit-cache installed won't need to compile at runtime.

Export the API. Edit flashinfer/__init__.py:
from .scale import scale as scale
# Or in the existing imports section:
# from .scale import scale
# The kernel compiles automatically on first use
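A quick smoke test from Python (assumes a CUDA GPU; the first call triggers JIT compilation of the scale module):

import torch
import flashinfer

x = torch.randn(1024, dtype=torch.float16, device="cuda")
y = flashinfer.scale(x, 2.0)   # compiles the module on first use, then runs it
assert torch.allclose(y, x * 2.0)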
Run the tests:

pytest tests/test_scale.py -v
# Run with different dtypes
pytest tests/test_scale.py::test_scale_correctness[float16-128] -v
All new kernels should have benchmarks. This helps track performance regressions and allows users to compare against other implementations.
Create a benchmark file in benchmarks/ (e.g., benchmarks/bench_scale.py):
import torch
from flashinfer.testing import bench_gpu_time
def bench_scale():
"""Benchmark scale kernel."""
import flashinfer
sizes = [1024, 4096, 16384, 65536, 262144]
dtypes = [torch.float16, torch.bfloat16]
print("Scale Kernel Benchmark")
print("-" * 60)
print(f"{'Size':>10} {'Dtype':>10} {'Time (us)':>12} {'Std (us)':>10}")
print("-" * 60)
for size in sizes:
for dtype in dtypes:
x = torch.randn(size, dtype=dtype, device="cuda")
# Benchmark with CUPTI (auto-fallback to CUDA events)
median_time, std_time = bench_gpu_time(
flashinfer.scale,
args=(x, 2.0),
enable_cupti=True,
dry_run_iters=10,
repeat_iters=100,
)
print(f"{size:>10} {str(dtype):>10} {median_time*1e6:>12.2f} {std_time*1e6:>10.2f}")
if __name__ == "__main__":
bench_scale()
For more complex kernels, consider:
- Integration with benchmarks/flashinfer_benchmark.py if applicable.

→ For the complete benchmarking guide, see .claude/skills/benchmark-kernel/SKILL.md

Summary of new and modified files:
include/flashinfer/scale.cuh # NEW: CUDA kernel definition
csrc/scale.cu # NEW: PyTorch launcher
csrc/scale_jit_binding.cu # NEW: TVM-FFI binding
flashinfer/jit/scale.py # NEW: JIT generator
flashinfer/scale.py # NEW: Python API (with @flashinfer_api(trace=...))
flashinfer/trace/templates/scale.py # NEW: TraceTemplate definition
flashinfer/__init__.py # MODIFIED: Export API
flashinfer/aot.py # MODIFIED: Register AOT
tests/test_scale.py # NEW: Kernel unit tests
tests/trace/test_fi_trace_template_consistency.py # MODIFIED: Add (func, template) pair
benchmarks/bench_scale.py # NEW: Benchmark script