Write Metal/MPS kernels for PyTorch operators. Use when adding MPS device support to operators, implementing Metal shaders, or porting CUDA kernels to Apple Silicon...
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
There are two workflows covered by this skill:
- Implementing native Metal support for a new operator
- Migrating an existing MPSGraph-based operator to native Metal

Both workflows involve the same three locations:
- aten/src/ATen/native/native_functions.yaml - dispatch registration
- aten/src/ATen/native/mps/kernels/ - Metal shader code
- aten/src/ATen/native/mps/operations/ - host-side kernel invocation

Location: aten/src/ATen/native/native_functions.yaml
Find the operator entry and add MPS dispatch:
```yaml
# Simple MPS-specific implementation
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps

# Shared implementation across devices (preferred for structured kernels)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out

# Structured kernel (preferred for new ops)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out
```
When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:
```yaml
# BEFORE (MPSGraph-based, separate dispatch)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Separate MPS implementation

# AFTER (native Metal, shared dispatch via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
```
Key change: remove the separate `MPS: my_op_out_mps` entry and add MPS to the shared dispatch line (e.g., `CPU, CUDA, MPS: my_op_out`).
Dispatch naming conventions:
- MPS: function_name_mps - MPS-specific implementation (old MPSGraph pattern)
- CPU, CUDA, MPS: function_name - Shared stub implementation (native Metal pattern)

Location: aten/src/ATen/native/mps/kernels/
```metal
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Define operation functor
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* your operation */;
  }
};

// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);

struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* your operation */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
```
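Conceptually, each REGISTER_* macro instantiates the templated functor for one input/output type pair and exposes it as a named kernel entry point. A rough Python model of that registry may help build intuition; the names and shapes here are illustrative only, not PyTorch internals:

```python
# Hypothetical model of how REGISTER_BINARY_OP-style macros build a
# (kernel name, input dtype) -> callable table. Illustrative only.
registry = {}

def register_binary_op(name, in_type, out_type, functor):
    """Record one type instantiation of a functor under a kernel name."""
    registry[(name, in_type)] = (functor, out_type)

def run_binary(name, in_type, a, b):
    """Look up the instantiation for this dtype and apply it elementwise."""
    functor, out_type = registry[(name, in_type)]
    return [out_type(functor(x, y)) for x, y in zip(a, b)]

# "Instantiate" an add functor for two dtypes, as the macros would.
register_binary_op("my_binary", "float", float, lambda a, b: a + b)
register_binary_op("my_binary", "int", int, lambda a, b: a + b)

print(run_binary("my_binary", "float", [1.0, 2.0], [3.0, 4.0]))  # [4.0, 6.0]
```

The real macros do this at compile time by stamping out one Metal kernel per registered type, keyed by the functor name the host side passes to exec_binary_kernel.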
For binary operations, use the convenience macros defined in BinaryKernel.metal:
```metal
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);

// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);

// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);

// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
```
Common patterns:
- Math ops that promote integer inputs to a float output: REGISTER_FLOAT_BINARY_OP and REGISTER_INT2FLOAT_BINARY_OP
- Ops that keep same-type integer outputs: REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP

Example for atan2 (supports both float and int inputs):
```metal
struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }
  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};

REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
```
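The two overloads above encode a dtype-promotion rule: floating inputs keep their dtype, integral inputs produce a float result. A quick Python reference of that behavior (pure illustration, not the PyTorch dispatcher):

```python
import math

def atan2_ref(a, b):
    """Reference model of the atan2 functor: compute in float, then
    integral inputs keep the float result (INT2FLOAT path), while
    floating inputs would be cast back to their own precision."""
    return math.atan2(float(a), float(b))

# Integer inputs still produce a floating-point result,
# matching REGISTER_INT2FLOAT_BINARY_OP.
result = atan2_ref(1, 1)
assert isinstance(result, float)
print(result)  # 0.7853981633974483 (pi/4)
```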
```metal
struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
```
```metal
struct special_functor {
  // Floating point types
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x); // Use precise math
  }
  // Integral types
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }
  // Complex types (float2 for cfloat, half2 for chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real, x.y = imaginary
    return T(/* real */, /* imag */);
  }
};
```
Note on complex types: Complex numbers in Metal are represented as vector types:
- c10::complex<float> maps to float2 (x = real, y = imaginary)
- c10::complex<half> maps to half2

Use is_complex_v<T> to specialize for complex types in functors.
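Because a complex value arrives as a two-component vector, a functor must implement complex arithmetic on the components by hand. A Python sketch of complex multiply on (real, imag) pairs, checked against Python's built-in complex type (illustrative, not the c10 implementation):

```python
def cmul(a, b):
    """Complex multiply on (real, imag) pairs, as a Metal functor
    would compute it on float2 values: a.x/a.y times b.x/b.y."""
    ar, ai = a
    br, bi = b
    return (ar * br - ai * bi, ar * bi + ai * br)

a, b = (1.0, 2.0), (3.0, -1.0)
ref = complex(*a) * complex(*b)  # cross-check against Python's complex
assert cmul(a, b) == (ref.real, ref.imag)
print(cmul(a, b))  # (5.0, 5.0)
```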
utils.h:
- opmath_t<T> - operation math type (half -> float)
- accum_t<T> - accumulation type for reductions
- max(), min() with NaN propagation

special_math.h:
- precise::exp(), precise::log(), precise::sqrt()
- precise::sin(), precise::cos(), precise::tan()
- erf(), erfc(), erfinv()

indexing.h:
- REGISTER_UNARY_OP(name, in_type, out_type)
- REGISTER_BINARY_OP(name, in_type, out_type)
- REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)

Location: aten/src/ATen/native/mps/operations/
Choose or create an appropriate file based on operation type:
- UnaryKernel.mm - Single-input operations via stub dispatch
- BinaryKernel.mm - Two-input operations via stub dispatch
- UnaryOps.mm / BinaryOps.mm - Legacy MPSGraph implementations (for reference)
- ReduceOps.mm - Reductions (sum, mean, max, etc.)

For structured kernels that use the TensorIterator pattern:
```cpp
// In BinaryKernel.mm (or appropriate file)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op");  // "my_op" matches the functor name in .metal
}

// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
```
For unary operations:
```cpp
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}

REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
```
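The stub mechanism is essentially a per-device function table: each backend registers its kernel under the stub's name, and the shared operator implementation calls through the stub for whatever device the tensors live on. A simplified Python model (class and names are illustrative, not PyTorch internals):

```python
# Hypothetical model of REGISTER_DISPATCH: a stub is a table keyed by
# device type, filled in by each backend at registration time.
class DispatchStub:
    def __init__(self, name):
        self.name = name
        self.kernels = {}

    def register(self, device, fn):
        """What REGISTER_DISPATCH does for one backend."""
        self.kernels[device] = fn

    def __call__(self, device, *args):
        # The shared operator implementation calls the stub; the stub
        # routes to whichever kernel this device registered.
        return self.kernels[device](*args)

my_op_stub = DispatchStub("my_op_stub")
my_op_stub.register("cpu", lambda x: [v * 2 for v in x])
my_op_stub.register("mps", lambda x: [v * 2 for v in x])  # Metal-backed in reality

print(my_op_stub("mps", [1, 2, 3]))  # [2, 4, 6]
```

This is why the native_functions.yaml change above can drop the MPS-specific entry: once the MPS kernel is registered on the stub, the shared `CPU, CUDA, MPS` implementation reaches it automatically.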
When migrating from MPSGraph, also remove the old implementation:
Remove from BinaryOps.mm (or UnaryOps.mm):
- The TORCH_IMPL_FUNC(my_op_out_mps) implementation
- The #include <ATen/ops/my_op_native.h> header

Add to BinaryKernel.mm (or UnaryKernel.mm):
- The kernel function and its REGISTER_DISPATCH call

After making changes, compile to verify everything builds correctly:
```shell
cd build && ninja torch_cpu
```
Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:
Location: torch/testing/_internal/common_mps.py
Find and remove the operator from skip/xfail lists:
```python
# Remove entries like:
MPS_XFAILLIST = {
    "my_op": ...,  # Remove this line
}
MPS_SKIPLIST = {
    "my_op": ...,  # Remove this line
}
```
Location: torch/testing/_internal/common_methods_invocations.py (or related files)
Remove MPS-specific decorators from the OpInfo:
```python
OpInfo(
    "my_op",
    # Remove decorators like:
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)
```
```shell
# Run the specific operator test
python test/test_mps.py -k test_output_match_my_op

# Or run full MPS test suite
python test/test_mps.py
```
Use torch.mps.compile_shader to JIT-compile and test individual Metal kernels in isolation. This is invaluable for debugging multi-kernel pipelines where you need to verify each stage independently.
```python
import torch

source = '''
#include <metal_stdlib>
using namespace metal;

kernel void my_kernel(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint tid [[thread_position_in_grid]]) {
  output[tid] = input[tid] * 2.0;
}
'''

lib = torch.mps.compile_shader(source)
inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
out = torch.zeros(3, device='mps')
lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
torch.mps.synchronize()
print(out)  # tensor([2., 4., 6.], device='mps:0')
```
compile_shader uses dispatchThreads semantics (same as mtl_dispatch1DJob in PyTorch):
- threads=[N, 1, 1] - total number of threads (NOT threadgroups)
- group_size=[G, 1, 1] - threads per threadgroup

This differs from the dispatchThreadgroups API used by some host-side code. To match a dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1) call:
```python
# Equivalent compile_shader call:
lib.kernel(args...,
           threads=[num_tgs * TG_SIZE, num_slices, 1],
           group_size=[TG_SIZE, 1, 1])
```
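The conversion is just a per-axis multiplication; a small helper (hypothetical, not a PyTorch API) makes the mapping explicit:

```python
def to_dispatch_threads(threadgroups, group_size):
    """Convert dispatchThreadgroups-style arguments (threadgroup counts
    plus threads-per-threadgroup) into the total-thread counts that
    compile_shader's `threads=` parameter expects."""
    return [n * g for n, g in zip(threadgroups, group_size)]

# 5 threadgroups of 256 threads along x, 4 slices along y:
threads = to_dispatch_threads([5, 4, 1], [256, 1, 1])
print(threads)  # [1280, 4, 1]
```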
Pass scalar constants as single-element tensors:
```python
slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])
```
When a pipeline of kernels (e.g., histogram → prefix_sum → scatter) produces wrong results, test each kernel individually and verify its output against a Python/NumPy reference:
```python
# 1. Run GPU kernel
lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
torch.mps.synchronize()

# 2. Compute reference in Python
ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)

# 3. Compare
assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"
```
This isolates which kernel in the pipeline is broken, rather than debugging the entire pipeline at once.
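Reference implementations for the first two stages of such a pipeline might look like this; compute_histogram_cpu above is a placeholder from the original, and these pure-Python versions are one possible shape for it:

```python
def histogram_ref(keys, num_bins):
    """Count occurrences of each key; mirrors a GPU histogram kernel."""
    hist = [0] * num_bins
    for k in keys:
        hist[k] += 1
    return hist

def exclusive_prefix_sum_ref(hist):
    """Exclusive scan: out[i] = sum(hist[:i]); gives each bin's
    starting offset for the scatter stage."""
    out, running = [], 0
    for count in hist:
        out.append(running)
        running += count
    return out

keys = [2, 0, 2, 1, 2]
hist = histogram_ref(keys, 3)
print(hist)                            # [1, 1, 3]
print(exclusive_prefix_sum_ref(hist))  # [0, 1, 2]
```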
Common pitfalls:
- threads count: threads is the total number of threads, not threadgroups. For 5 threadgroups of 256, use threads=[1280, 1, 1].
- compile_shader doesn't support [[threadgroup(N)]] parameters directly. If your kernel needs threadgroup memory, restructure to use threadgroup arrays declared inside the kernel body instead.

Key files: native_functions.yaml, kernels/, operations/, torch/testing/_internal/common_mps.py