
    pytorch/metal-kernel

    About

    Write Metal/MPS kernels for PyTorch operators. Use when adding MPS device support to operators, implementing Metal shaders, or porting CUDA kernels to Apple Silicon...

    SKILL.md

    Metal Kernel Writing Guide

    This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.

    Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.

    Overview

    There are two workflows covered by this skill:

    1. Adding new MPS support - Implementing a new operator from scratch
    2. Migrating from MPSGraph - Converting existing MPSGraph-based operators to native Metal

    Both workflows involve:

    1. Update dispatch in aten/src/ATen/native/native_functions.yaml
    2. Write Metal kernel in aten/src/ATen/native/mps/kernels/
    3. Implement host-side stub in aten/src/ATen/native/mps/operations/

    Step 1: Update native_functions.yaml

    Location: aten/src/ATen/native/native_functions.yaml

    For New Operators

    Find the operator entry and add MPS dispatch:

    # Simple MPS-specific implementation
    - func: my_op(Tensor self) -> Tensor
      dispatch:
        CPU: my_op_cpu
        CUDA: my_op_cuda
        MPS: my_op_mps
    
    # Shared implementation across devices (preferred for structured kernels)
    - func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
      dispatch:
        CPU, CUDA, MPS: my_op_out
    
    # Structured kernel (preferred for new ops)
    - func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
      structured: True
      structured_inherits: TensorIteratorBase
      dispatch:
        CPU, CUDA, MPS: my_op_out
    

    For Migrating from MPSGraph

    When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:

    # BEFORE (MPSGraph-based, separate dispatch)
    - func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
      structured: True
      structured_inherits: TensorIteratorBase
      dispatch:
        CPU, CUDA: atan2_out
        MPS: atan2_out_mps  # Separate MPS implementation
    
    # AFTER (native Metal, shared dispatch via stub)
    - func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
      structured: True
      structured_inherits: TensorIteratorBase
      dispatch:
        CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
    

    Key change: delete the separate MPS: my_op_out_mps entry and add MPS to the shared dispatch line (e.g., CPU, CUDA, MPS: my_op_out).

    Dispatch naming conventions:

    • MPS: function_name_mps - MPS-specific implementation (old MPSGraph pattern)
    • CPU, CUDA, MPS: function_name - Shared stub implementation (native Metal pattern)

    Step 2: Implement Metal Kernel

    Location: aten/src/ATen/native/mps/kernels/

    Unary Kernel Pattern

    // MyKernel.metal
    #include <c10/metal/indexing.h>
    #include <c10/metal/utils.h>
    #include <metal_stdlib>
    
    using namespace metal;
    using namespace c10::metal;
    
    // Define operation functor
    struct my_op_functor {
      template <typename T>
      inline T operator()(const T x) {
        return /* your operation */;
      }
    };
    
    // Register for supported types
    REGISTER_UNARY_OP(my_op, float, float);
    REGISTER_UNARY_OP(my_op, half, half);
    REGISTER_UNARY_OP(my_op, bfloat, bfloat);
    

    Binary Kernel Pattern

    struct my_binary_functor {
      template <typename T>
      inline T operator()(const T a, const T b) {
        return /* your operation */;
      }
    };
    
    REGISTER_BINARY_OP(my_binary, float, float);
    REGISTER_BINARY_OP(my_binary, half, half);
    

    Binary Kernel Type Registration Macros

    For binary operations, use the convenience macros defined in BinaryKernel.metal:

    // Floating-point types only (float, half, bfloat)
    REGISTER_FLOAT_BINARY_OP(my_op);
    
    // Integral types with float output (for math ops like atan2, copysign)
    // Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
    REGISTER_INT2FLOAT_BINARY_OP(my_op);
    
    // Integral types with same-type output (for bitwise/logical ops)
    // Registers: long, int, short, uchar, char, bool
    REGISTER_INTEGER_BINARY_OP(my_op);
    
    // Floating-point with opmath precision (for ops needing higher precision)
    REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
    

    Common patterns:

    • Math functions (atan2, copysign, logaddexp): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INT2FLOAT_BINARY_OP
    • Comparison/logical ops (maximum, minimum): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
    • Arithmetic ops (add, sub, mul): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP

    Example for atan2 (supports both float and int inputs):

    struct atan2_functor {
      template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
      inline T operator()(const T a, const T b) {
        return static_cast<T>(precise::atan2(float(a), float(b)));
      }
      template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
      inline float operator()(const T a, const T b) {
        return precise::atan2(float(a), float(b));
      }
    };
    
    REGISTER_FLOAT_BINARY_OP(atan2);
    REGISTER_INT2FLOAT_BINARY_OP(atan2);
    

    With Scalar Parameter

    struct my_alpha_functor {
      template <typename T>
      inline T operator()(const T a, const T b, const T alpha) {
        return a + c10::metal::mul(alpha, b);
      }
    };
    
    REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
    REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
    

    Type-Specialized Functor

    struct special_functor {
      // Floating point types
      template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
      inline T operator()(const T x) {
        return precise::exp(x);  // Use precise math
      }
    
      // Integral types
      template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
      inline float operator()(const T x) {
        return precise::exp(float(x));
      }
    
      // Complex types (float2 for cfloat, half2 for chalf)
      template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
      inline T operator()(const T x) {
        // x.x = real, x.y = imaginary
        return T(/* real */, /* imag */);
      }
    };
    

    Note on complex types: Complex numbers in Metal are represented as vector types:

    • c10::complex<float> maps to float2 (x = real, y = imaginary)
    • c10::complex<half> maps to half2

    Use is_complex_v<T> to specialize for complex types in functors.
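
    As a concrete illustration of the layout, a complex64 tensor can be handed to a JIT-compiled kernel through torch.view_as_real, which exposes its storage as [..., 2] float pairs matching float2. A minimal sketch using torch.mps.compile_shader (covered in detail below); conj_kernel is an illustrative name, not an existing PyTorch kernel:

    import torch

    source = '''
    #include <metal_stdlib>
    using namespace metal;

    // Each float2 element is one complex number: x = real, y = imaginary.
    kernel void conj_kernel(
        const device float2* input [[buffer(0)]],
        device float2* output [[buffer(1)]],
        uint tid [[thread_position_in_grid]]) {
      output[tid] = float2(input[tid].x, -input[tid].y);
    }
    '''

    lib = torch.mps.compile_shader(source)
    z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64, device='mps')
    out = torch.zeros_like(z)
    # view_as_real exposes complex64 storage as [..., 2] float pairs,
    # matching the float2 layout the kernel expects.
    lib.conj_kernel(torch.view_as_real(z), torch.view_as_real(out),
                    threads=[2, 1, 1], group_size=[2, 1, 1])
    torch.mps.synchronize()
    print(out)  # tensor([1.-2.j, 3.+4.j], device='mps:0')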

    Available c10/metal Utilities

    utils.h:

    • opmath_t<T> - Operation math type (half->float)
    • accum_t<T> - Accumulation type for reductions
    • max(), min() with NaN propagation

    special_math.h:

    • precise::exp(), precise::log(), precise::sqrt()
    • precise::sin(), precise::cos(), precise::tan()
    • erf(), erfc(), erfinv()

    indexing.h:

    • REGISTER_UNARY_OP(name, in_type, out_type)
    • REGISTER_BINARY_OP(name, in_type, out_type)
    • REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)
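
    The registration macros above are only meaningful inside the PyTorch source tree, but the precise-math behavior is easy to probe standalone with torch.mps.compile_shader (covered in detail below). A minimal sketch: metal_stdlib itself provides the precise:: math functions listed above, so this example sticks to metal_stdlib rather than assuming the c10/metal headers are includable from JIT-compiled shaders on your build:

    import torch

    source = '''
    #include <metal_stdlib>
    using namespace metal;

    kernel void exp_kernel(
        const device float* input [[buffer(0)]],
        device float* output [[buffer(1)]],
        uint tid [[thread_position_in_grid]]) {
      // precise:: trades speed for accuracy versus fast-math defaults.
      output[tid] = precise::exp(input[tid]);
    }
    '''

    lib = torch.mps.compile_shader(source)
    x = torch.tensor([0.0, 1.0, 2.0], device='mps')
    y = torch.empty_like(x)
    lib.exp_kernel(x, y, threads=[3, 1, 1], group_size=[3, 1, 1])
    torch.mps.synchronize()
    torch.testing.assert_close(y, torch.exp(x))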

    Step 3: Implement Host-Side Stub

    Location: aten/src/ATen/native/mps/operations/

    Choose or create an appropriate file based on operation type:

    • UnaryKernel.mm - Single input operations via stub dispatch
    • BinaryKernel.mm - Two input operations via stub dispatch
    • UnaryOps.mm / BinaryOps.mm - Legacy MPSGraph implementations (for reference)
    • ReduceOps.mm - Reductions (sum, mean, max, etc.)
    • Create new file for distinct operation categories

    Stub Registration Pattern (Preferred for Native Metal)

    For structured kernels that use the TensorIterator pattern:

    // In BinaryKernel.mm (or appropriate file)
    
    static void my_op_mps_kernel(TensorIteratorBase& iter) {
      lib.exec_binary_kernel(iter, "my_op");  // "my_op" matches the name registered in the .metal file
    }
    
    // Register the MPS stub - this connects to the dispatch system
    REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
    

    For unary operations:

    static void my_unary_mps_kernel(TensorIteratorBase& iter) {
      lib.exec_unary_kernel(iter, "my_unary");
    }
    
    REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
    

    Migration: Removing Old MPSGraph Implementation

    When migrating from MPSGraph, also remove the old implementation:

    1. Remove from BinaryOps.mm (or UnaryOps.mm):

      • Delete the TORCH_IMPL_FUNC(my_op_out_mps) implementation
      • Remove the corresponding #include <ATen/ops/my_op_native.h> header
    2. Add to BinaryKernel.mm (or UnaryKernel.mm):

      • Add the static kernel function
      • Add the REGISTER_DISPATCH call

    Step 4: Compile

    After making changes, compile to verify everything builds correctly:

    cd build && ninja torch_cpu
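
    Once the build succeeds, a quick parity check against the CPU implementation catches gross errors before running the full test suite. A minimal sketch, using atan2 (the migration example above) as a stand-in for your operator:

    import torch

    x = torch.randn(1024, device='mps')
    y = torch.randn(1024, device='mps')

    # Compare the MPS kernel against the reference CPU implementation.
    torch.testing.assert_close(torch.atan2(x, y).cpu(),
                               torch.atan2(x.cpu(), y.cpu()))

    # Also exercise the edge cases from the checklist:
    # empty and non-contiguous inputs.
    empty = torch.empty(0, device='mps')
    torch.atan2(empty, empty)  # must not crash
    xt = torch.randn(8, 8, device='mps').t()  # non-contiguous view
    torch.testing.assert_close(torch.atan2(xt, xt).cpu(),
                               torch.atan2(xt.cpu(), xt.cpu()))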
    

    Testing

    Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:

    1. Remove from common_mps.py

    Location: torch/testing/_internal/common_mps.py

    Find and remove the operator from skip/xfail lists:

    # Remove entries like:
    MPS_XFAILLIST = {
        "my_op": ...,  # Remove this line
    }
    
    MPS_SKIPLIST = {
        "my_op": ...,  # Remove this line
    }
    

    2. Remove from OpInfo decorators

    Location: torch/testing/_internal/common_methods_invocations.py (or related files)

    Remove MPS-specific decorators from the OpInfo:

    OpInfo(
        "my_op",
        # Remove decorators like:
        # decorators=[skipMPS, expectedFailureMPS("reason")],
        ...
    )
    

    3. Run tests to verify

    # Run the specific operator test
    python test/test_mps.py -k test_output_match_my_op
    
    # Or run full MPS test suite
    python test/test_mps.py
    

    Debugging Metal Kernels with torch.mps.compile_shader

    Use torch.mps.compile_shader to JIT-compile and test individual Metal kernels in isolation. This is invaluable for debugging multi-kernel pipelines where you need to verify each stage independently.

    Basic Usage

    import torch
    
    source = '''
    #include <metal_stdlib>
    using namespace metal;
    
    kernel void my_kernel(
        const device float* input [[buffer(0)]],
        device float* output [[buffer(1)]],
        uint tid [[thread_position_in_grid]]) {
      output[tid] = input[tid] * 2.0;
    }
    '''
    
    lib = torch.mps.compile_shader(source)
    
    inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
    out = torch.zeros(3, device='mps')
    lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
    torch.mps.synchronize()
    print(out)  # tensor([2., 4., 6.], device='mps:0')
    

    Dispatch Semantics

    compile_shader uses dispatchThreads semantics (same as mtl_dispatch1DJob in PyTorch):

    • threads=[N, 1, 1] — total number of threads (NOT threadgroups)
    • group_size=[G, 1, 1] — threads per threadgroup

    This differs from the dispatchThreadgroups API used by some host-side code. To match dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1):

    # Equivalent compile_shader call:
    lib.kernel(*args,
        threads=[num_tgs * TG_SIZE, num_slices, 1],
        group_size=[TG_SIZE, 1, 1])
    

    Constant Buffer Parameters

    Pass scalar constants as single-element tensors:

    slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
    lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])
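
    On the Metal side, the single-element tensor binds to the next buffer slot like any other argument. A sketch of a matching kernel (my_kernel and slice_size are the illustrative names from the call above); the scalar can be read through a constant reference, or equivalently a const device int*:

    import torch

    source = '''
    #include <metal_stdlib>
    using namespace metal;

    kernel void my_kernel(
        const device float* data [[buffer(0)]],
        device float* output [[buffer(1)]],
        // The single-element int32 tensor binds to this slot.
        constant int& slice_size [[buffer(2)]],
        uint tid [[thread_position_in_grid]]) {
      output[tid] = data[tid % uint(slice_size)];
    }
    '''

    lib = torch.mps.compile_shader(source)
    data = torch.arange(1024.0, device='mps')
    output = torch.empty_like(data)
    slice_size = torch.tensor([256], dtype=torch.int32, device='mps')
    lib.my_kernel(data, output, slice_size,
                  threads=[1024, 1, 1], group_size=[256, 1, 1])
    torch.mps.synchronize()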
    

    Debugging Strategy for Multi-Kernel Pipelines

    When a pipeline of kernels (e.g., histogram → prefix_sum → scatter) produces wrong results, test each kernel individually and verify its output against a Python/NumPy reference:

    # 1. Run GPU kernel
    lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
    torch.mps.synchronize()
    
    # 2. Compute reference in Python
    ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)
    
    # 3. Compare
    assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"
    

    This isolates which kernel in the pipeline is broken, rather than debugging the entire pipeline at once.

    Common Pitfalls

    • Wrong threads count — threads is total threads, not threadgroups. For 5 threadgroups of 256, use threads=[1280, 1, 1].
    • Threadgroup memory — compile_shader doesn't support [[threadgroup(N)]] parameters directly. If your kernel needs threadgroup memory, restructure it to declare threadgroup arrays inside the kernel body instead, as sketched below.
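
    A sketch of that workaround: the threadgroup array is declared in the kernel body rather than taken as a [[threadgroup(N)]] parameter. This computes per-threadgroup partial sums (group_sum is an illustrative name):

    import torch

    source = '''
    #include <metal_stdlib>
    using namespace metal;

    kernel void group_sum(
        const device float* input [[buffer(0)]],
        device float* partial_sums [[buffer(1)]],
        uint tid [[thread_position_in_grid]],
        uint lid [[thread_position_in_threadgroup]],
        uint tg [[threadgroup_position_in_grid]]) {
      // Declared in the body instead of as a [[threadgroup(N)]] parameter.
      threadgroup float shared[256];
      shared[lid] = input[tid];
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Tree reduction within the threadgroup.
      for (uint stride = 128; stride > 0; stride >>= 1) {
        if (lid < stride) {
          shared[lid] += shared[lid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
      }
      if (lid == 0) {
        partial_sums[tg] = shared[0];
      }
    }
    '''

    lib = torch.mps.compile_shader(source)
    x = torch.randn(1024, device='mps')
    partial = torch.zeros(4, device='mps')  # 1024 threads / 256 per group
    lib.group_sum(x, partial, threads=[1024, 1, 1], group_size=[256, 1, 1])
    torch.mps.synchronize()
    torch.testing.assert_close(partial.sum(), x.sum(), rtol=1e-4, atol=1e-4)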

    Checklist

    • Added MPS dispatch to native_functions.yaml
    • Implemented Metal kernel in kernels/
    • Implemented host-side operator in operations/
    • Handles empty tensors
    • Handles non-contiguous tensors
    • Supports required dtypes (float32, float16, bfloat16, and often complex types via float2/half2)
    • Removed expected failures from torch/testing/_internal/common_mps.py
    • Removed skip/xfail decorators from OpInfo (if applicable)