# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from importlib import resources
from typing import Iterator, Mapping
import mlx.core as mx
from mlx_sparse._ext_loader import extension
class NativeBackend(str, Enum):
    """Backend families that native ``mlx-sparse`` code is organised around.

    The enum mixes in ``str``, so every member is itself a string and compares
    equal to its lowercase value (``NativeBackend.CUDA == "cuda"``).
    Iteration follows declaration order.
    """

    CPU = "cpu"
    METAL = "metal"
    ACCELERATE = "accelerate"

    CUDA = "cuda"
    ROCM = "rocm"
class NativeCapability(str, Enum):
    """Native capabilities whose availability can be checked at runtime.

    The members are deliberately backend-oriented (one per native backend plus
    the extension itself). Later releases may add finer-grained,
    operation-level capabilities without changing the public
    :data:`capabilities` view.

    Like :class:`NativeBackend`, members are ``str`` instances and compare
    equal to their value.
    """

    NATIVE_EXTENSION = "native_extension"

    CPU_KERNELS = "cpu_kernels"
    METAL_KERNELS = "metal_kernels"
    ACCELERATE_SOLVERS = "accelerate_solvers"
    CUDA_KERNELS = "cuda_kernels"
    ROCM_KERNELS = "rocm_kernels"
class NativeCapabilityStatus(str, Enum):
    """Availability state reported for each :class:`NativeCapability`.

    ``NOT_BUILT`` marks a feature that was not compiled into this build.
    ``UNAVAILABLE`` covers a missing native extension or a failed runtime
    check. Members are ``str`` instances, so ``status == "available"`` works.
    """

    AVAILABLE = "available"  # usable right now
    UNAVAILABLE = "unavailable"
    NOT_BUILT = "not_built"
@dataclass(frozen=True)
class NativeCapabilityRecord:
    """Immutable status entry describing one :class:`NativeCapability`.

    Only ``capability`` and ``status`` are required; every other field has a
    conservative default (no backend, not built, not runtime-available, empty
    reason).
    """

    # Which capability this record describes.
    capability: NativeCapability
    # Overall verdict; drives the ``available`` property below.
    status: NativeCapabilityStatus
    # Backend family the capability belongs to (``None`` when not tied to one).
    backend: NativeBackend | None = None
    # Whether the capability was compiled into this build.
    built: bool = False
    # Whether the runtime check for a compiled capability succeeded.
    runtime_available: bool = False
    # Human-readable explanation of ``status``.
    reason: str = ""

    @property
    def available(self) -> bool:
        """``True`` exactly when ``status`` is ``NativeCapabilityStatus.AVAILABLE``."""
        return NativeCapabilityStatus.AVAILABLE is self.status
@dataclass(frozen=True)
class NativeCapabilities:
    """Snapshot of native capabilities for the current Python process.

    Iterating yields the individual :class:`NativeCapabilityRecord` objects.
    Capability arguments may be :class:`NativeCapability` members or any
    string accepted by ``_coerce_capability`` (e.g. ``"metal"``, ``"cpu"``).
    """

    records: tuple[NativeCapabilityRecord, ...]
    platform: str
    architecture: str

    def __iter__(self) -> Iterator[NativeCapabilityRecord]:
        return iter(self.records)

    def __contains__(self, capability: NativeCapability | str) -> bool:
        """Return ``True`` if ``capability`` is available now.

        A capability name that cannot be resolved, or that has no record in
        this snapshot, yields ``False``. Without that, a membership test would
        raise, which breaks the usual container contract for ``in``.
        """
        try:
            return self.has(capability)
        except (KeyError, ValueError):
            return False

    def __getitem__(self, capability: NativeCapability | str) -> NativeCapabilityRecord:
        """Return the record for ``capability`` (see :meth:`get`)."""
        return self.get(capability)

    def get(self, capability: NativeCapability | str) -> NativeCapabilityRecord:
        """Return the status record for ``capability``.

        Raises:
            KeyError: if the snapshot holds no record for the capability.
            ValueError: if a string does not name a known capability.
        """
        key = _coerce_capability(capability)
        for record in self.records:
            if record.capability is key:
                return record
        raise KeyError(key)

    def has(self, capability: NativeCapability | str) -> bool:
        """Return ``True`` if ``capability`` is available now."""
        return self.get(capability).available

    def status(self, capability: NativeCapability | str) -> NativeCapabilityStatus:
        """Return the enum status for ``capability``."""
        return self.get(capability).status

    def by_backend(
        self, backend: NativeBackend | str
    ) -> tuple[NativeCapabilityRecord, ...]:
        """Return all capability records associated with ``backend``."""
        key = _coerce_backend(backend)
        return tuple(record for record in self.records if record.backend is key)

    @property
    def available(self) -> frozenset[NativeCapability]:
        """Capabilities that are available in the current process."""
        return frozenset(
            record.capability for record in self.records if record.available
        )
class _CapabilityView:
    """User-facing native capability view.

    ``mlx_sparse.capabilities`` is intentionally small and string-friendly:

    * capability names: ``"extension"``, ``"cpu"``, ``"metal"``,
      ``"accelerate"``, ``"cuda"``, ``"rocm"``
    * statuses: ``"available"``, ``"unavailable"``, ``"not_built"``

    Every query builds a fresh snapshot via ``_native_capabilities()``, so
    answers reflect the process state at call time, not at import time.

    Example::

        import mlx_sparse as ms

        if ms.capabilities.METAL:
            ms.use_gpu()

        if ms.capabilities.status("accelerate") == "not_built":
            ...
    """

    _PUBLIC_NAMES = ("extension", "cpu", "metal", "accelerate", "cuda", "rocm")

    def __repr__(self) -> str:
        # Use a single snapshot so every field describes the same moment and
        # the native probes run once instead of once per capability name.
        snapshot = _native_capabilities()
        statuses = ", ".join(
            f"{name}={snapshot.status(name).value!r}" for name in self._PUBLIC_NAMES
        )
        return f"mlx_sparse.capabilities({statuses})"

    @property
    def extension(self) -> bool:
        """Whether the native extension is loaded."""
        return self.has("extension")

    @property
    def EXTENSION(self) -> bool:
        """Whether the native extension is loaded."""
        return self.extension

    @property
    def cpu(self) -> bool:
        """Whether native CPU kernels are available."""
        return self.has("cpu")

    @property
    def CPU(self) -> bool:
        """Whether native CPU kernels are available."""
        return self.cpu

    @property
    def metal(self) -> bool:
        """Whether native Metal kernels are available."""
        return self.has("metal")

    @property
    def METAL(self) -> bool:
        """Whether native Metal kernels are available."""
        return self.metal

    @property
    def accelerate(self) -> bool:
        """Whether Accelerate solver support is available."""
        return self.has("accelerate")

    @property
    def ACCELERATE(self) -> bool:
        """Whether Accelerate solver support is available."""
        return self.accelerate

    @property
    def cuda(self) -> bool:
        """Whether native CUDA kernels are available."""
        return self.has("cuda")

    @property
    def CUDA(self) -> bool:
        """Whether native CUDA kernels are available."""
        return self.cuda

    @property
    def rocm(self) -> bool:
        """Whether native ROCm/HIP kernels are available."""
        return self.has("rocm")

    @property
    def ROCM(self) -> bool:
        """Whether native ROCm/HIP kernels are available."""
        return self.rocm

    @property
    def names(self) -> tuple[str, ...]:
        """Public capability names accepted by ``has`` and ``status``."""
        return self._PUBLIC_NAMES

    @property
    def platform(self) -> str:
        """Native extension platform reported by the current build."""
        return _native_capabilities().platform

    @property
    def architecture(self) -> str:
        """Native extension architecture reported by the current build."""
        return _native_capabilities().architecture

    def has(self, capability: NativeCapability | str) -> bool:
        """Return ``True`` if ``capability`` is available now."""
        return _native_capabilities().has(capability)

    def status(self, capability: NativeCapability | str) -> str:
        """Return ``"available"``, ``"unavailable"``, or ``"not_built"``."""
        return _native_capabilities().status(capability).value

    def reason(self, capability: NativeCapability | str) -> str:
        """Return a human-readable reason for the current status."""
        return _native_capabilities().get(capability).reason

    def built(self, capability: NativeCapability | str) -> bool:
        """Return ``True`` if ``capability`` was compiled into this build."""
        return _native_capabilities().get(capability).built

    def runtime_available(self, capability: NativeCapability | str) -> bool:
        """Return ``True`` if the runtime can use a compiled capability."""
        return _native_capabilities().get(capability).runtime_available


# Process-wide view exported as ``mlx_sparse.capabilities``.
capabilities = _CapabilityView()
def _native_capabilities() -> NativeCapabilities:
    """Build a fresh :class:`NativeCapabilities` snapshot for this process.

    The raw facts come from ``_compiled_facts()``. Each capability is turned
    into a :class:`NativeCapabilityRecord` by its dedicated builder. Platform
    and architecture fall back to the Python-side detection whenever the facts
    leave them empty. This function does no caching of its own, so every call
    re-reads the facts and re-runs the per-record checks.
    """
    facts = _compiled_facts()

    def flag(key: str) -> bool:
        # A missing key is treated the same as an explicit False.
        return bool(facts.get(key, False))

    platform_name = str(facts.get("platform") or _python_platform())
    arch_name = str(facts.get("architecture") or _python_architecture())
    loaded = flag("extension")

    return NativeCapabilities(
        records=(
            _extension_record(loaded),
            _cpu_record(loaded, flag("cpu")),
            _metal_record(loaded, flag("metal")),
            _backend_record(
                NativeCapability.ACCELERATE_SOLVERS,
                NativeBackend.ACCELERATE,
                loaded,
                flag("accelerate"),
                platform_name,
                _accelerate_not_built_reason(flag("accelerate_framework")),
            ),
            _backend_record(
                NativeCapability.CUDA_KERNELS,
                NativeBackend.CUDA,
                loaded,
                flag("cuda"),
                platform_name,
                "CUDA kernels are not compiled into this build.",
            ),
            _backend_record(
                NativeCapability.ROCM_KERNELS,
                NativeBackend.ROCM,
                loaded,
                flag("rocm"),
                platform_name,
                "ROCm/HIP kernels are not compiled into this build.",
            ),
        ),
        platform=platform_name,
        architecture=arch_name,
    )
def has_capability(capability: NativeCapability | str) -> bool:
    """Return ``True`` if ``capability`` is available in this process.

    Convenience wrapper around ``capabilities.has``. It accepts the same
    :class:`NativeCapability` members and string aliases (for example
    ``"metal"``, ``"cpu"``, ``"accelerate"``).

    Raises:
        ValueError: if ``capability`` is a string that does not name a known
            capability.
    """
    return capabilities.has(capability)
# Short-hand spellings accepted by ``_coerce_capability``, mapped to canonical
# ``NativeCapability`` values. Canonical values (``"cpu_kernels"`` etc.) need no
# entry because the enum constructor resolves them directly.
_CAPABILITY_ALIASES: dict[str, str] = {
    "extension": "native_extension",
    "native": "native_extension",
    "cpu": "cpu_kernels",
    "metal": "metal_kernels",
    "gpu": "metal_kernels",
    "accelerate": "accelerate_solvers",
    "cuda": "cuda_kernels",
    "rocm": "rocm_kernels",
    "hip": "rocm_kernels",
}


def _coerce_capability(capability: NativeCapability | str) -> NativeCapability:
    """Resolve an enum member or a string alias to a :class:`NativeCapability`.

    Before lookup, strings are stripped of surrounding whitespace, lower-cased,
    and have ``-`` replaced by ``_``. As a result ``"Metal"``,
    ``" cpu-kernels "`` and ``"NATIVE_EXTENSION"`` are all accepted.

    Raises:
        TypeError: if ``capability`` is neither a ``NativeCapability`` nor a
            string.
        ValueError: if the string does not name a known capability.
    """
    if isinstance(capability, NativeCapability):
        return capability
    if not isinstance(capability, str):
        raise TypeError(
            "capability must be a NativeCapability or str, "
            f"got {type(capability).__name__}"
        )
    normalized = capability.strip().lower().replace("-", "_")
    try:
        return NativeCapability(_CAPABILITY_ALIASES.get(normalized, normalized))
    except ValueError:
        known = ", ".join(
            sorted({*_CAPABILITY_ALIASES, *(member.value for member in NativeCapability)})
        )
        raise ValueError(
            f"Unknown native capability {capability!r}; expected one of: {known}"
        ) from None
# Extra spellings accepted by ``_coerce_backend``. They mirror the "gpu" and
# "hip" aliases used for capability names.
_BACKEND_ALIASES: dict[str, str] = {"gpu": "metal", "hip": "rocm"}


def _coerce_backend(backend: NativeBackend | str) -> NativeBackend:
    """Resolve an enum member or string to a :class:`NativeBackend`.

    Strings get the same normalization as capability names: surrounding
    whitespace is stripped, the text is lower-cased and ``-`` becomes ``_``.
    ``"Metal"`` therefore resolves just like ``"metal"``, and ``"gpu"`` /
    ``"hip"`` resolve to Metal / ROCm.

    Raises:
        TypeError: if ``backend`` is neither a ``NativeBackend`` nor a string.
        ValueError: if the string does not name a known backend.
    """
    if isinstance(backend, NativeBackend):
        return backend
    if not isinstance(backend, str):
        raise TypeError(
            f"backend must be a NativeBackend or str, got {type(backend).__name__}"
        )
    normalized = backend.strip().lower().replace("-", "_")
    return NativeBackend(_BACKEND_ALIASES.get(normalized, normalized))
def _compiled_facts() -> Mapping[str, object]:
    """Return the raw compiled-capability facts for this process.

    There are three cases:

    * the native extension is not loaded: every feature flag is ``False``;
    * the extension exposes ``_compiled_capabilities``: its result is returned
      as-is;
    * an older extension lacks that binding: a conservative guess is
      returned, with the extension and CPU marked as built and Metal marked
      as built only if ``mlx_sparse.metallib`` ships with the package.

    In the first and third cases, platform and architecture come from the
    Python-side helpers.
    """
    ext = extension()
    if ext is not None:
        report = getattr(ext, "_compiled_capabilities", None)
        if report is not None:
            return report()
        # Older editable builds may have the extension loaded before this
        # binding exists. Keep capability checks useful until it is rebuilt.
        loaded, cpu_built, metal_built = True, True, _metallib_present()
    else:
        loaded = cpu_built = metal_built = False

    return {
        "extension": loaded,
        "cpu": cpu_built,
        "metal": metal_built,
        "accelerate": False,
        "accelerate_framework": False,
        "cuda": False,
        "rocm": False,
        "platform": _python_platform(),
        "architecture": _python_architecture(),
    }
def _extension_record(loaded: bool) -> NativeCapabilityRecord:
    """Describe whether the ``mlx_sparse`` native extension itself is usable.

    The extension is not tied to a backend, so ``backend`` keeps its ``None``
    default in both outcomes.
    """
    if not loaded:
        return NativeCapabilityRecord(
            capability=NativeCapability.NATIVE_EXTENSION,
            status=NativeCapabilityStatus.UNAVAILABLE,
            reason=(
                "The mlx_sparse native extension is not loaded; Python fallback "
                "implementations will be used where available."
            ),
        )
    return NativeCapabilityRecord(
        capability=NativeCapability.NATIVE_EXTENSION,
        status=NativeCapabilityStatus.AVAILABLE,
        built=True,
        runtime_available=True,
        reason="The mlx_sparse native extension is loaded.",
    )
def _accelerate_not_built_reason(framework_built: bool) -> str:
if framework_built:
return (
"The Accelerate framework was detected and linked at build time, "
"but Accelerate sparse solver integration is not enabled in this build."
)
return "Accelerate sparse solver integration is not compiled into this build."
def _device_record(
    *,
    capability: NativeCapability,
    backend: NativeBackend,
    extension_loaded: bool,
    built: bool,
    device_kind: mx.DeviceType,
    device_label: str,
    missing_extension_reason: str,
    not_built_reason: str,
    available_reason: str,
) -> NativeCapabilityRecord:
    """Build a record for a capability gated on an MLX device probe.

    Shared by :func:`_cpu_record` and :func:`_metal_record`, which differ only
    in the device they probe and the wording of their reasons.

    Decision order:

    1. extension not loaded: ``UNAVAILABLE`` (``built`` keeps its ``False``
       default);
    2. capability not compiled in: ``NOT_BUILT``;
    3. otherwise MLX device 0 of ``device_kind`` is probed. Success gives
       ``AVAILABLE``; failure gives ``UNAVAILABLE`` with the probe's reason.
    """
    if not extension_loaded:
        return NativeCapabilityRecord(
            capability=capability,
            backend=backend,
            status=NativeCapabilityStatus.UNAVAILABLE,
            reason=missing_extension_reason,
        )
    if not built:
        return NativeCapabilityRecord(
            capability=capability,
            backend=backend,
            status=NativeCapabilityStatus.NOT_BUILT,
            reason=not_built_reason,
        )
    available, probe_reason = _probe_mlx_device(device_kind, device_label)
    return NativeCapabilityRecord(
        capability=capability,
        backend=backend,
        status=(
            NativeCapabilityStatus.AVAILABLE
            if available
            else NativeCapabilityStatus.UNAVAILABLE
        ),
        built=True,
        runtime_available=available,
        # The probe reports an empty reason on success; use a descriptive
        # message in that case.
        reason=probe_reason or available_reason,
    )


def _cpu_record(extension_loaded: bool, built: bool) -> NativeCapabilityRecord:
    """Status record for the native C++ CPU kernels (probes MLX CPU device 0)."""
    return _device_record(
        capability=NativeCapability.CPU_KERNELS,
        backend=NativeBackend.CPU,
        extension_loaded=extension_loaded,
        built=built,
        device_kind=mx.cpu,
        device_label="CPU",
        missing_extension_reason="Native CPU kernels require the mlx_sparse extension.",
        not_built_reason="Native CPU kernels are not compiled into this build.",
        available_reason="Native C++ CPU sparse kernels are available.",
    )


def _metal_record(extension_loaded: bool, built: bool) -> NativeCapabilityRecord:
    """Status record for the native Metal kernels (probes MLX GPU device 0)."""
    return _device_record(
        capability=NativeCapability.METAL_KERNELS,
        backend=NativeBackend.METAL,
        extension_loaded=extension_loaded,
        built=built,
        device_kind=mx.gpu,
        device_label="Metal GPU",
        missing_extension_reason="Metal kernels require the mlx_sparse extension.",
        not_built_reason=(
            "Metal kernels are not compiled into this build or the "
            "mlx_sparse.metallib resource is missing."
        ),
        available_reason="Native Metal sparse kernels are available.",
    )
def _backend_record(
    capability: NativeCapability,
    backend: NativeBackend,
    extension_loaded: bool,
    built: bool,
    platform: str,
    not_built_reason: str,
) -> NativeCapabilityRecord:
    """Build the record for a backend that has no MLX device probe.

    In this module it is called for Accelerate, CUDA and ROCm. Runtime
    availability of a compiled backend is decided by
    ``_future_backend_runtime_status``.

    Args:
        capability: capability the record describes.
        backend: backend family the capability belongs to.
        extension_loaded: whether the native extension is loaded.
        built: whether the capability was compiled into this build.
        platform: normalized platform name (e.g. ``"darwin"``).
        not_built_reason: reason text used when ``built`` is false.
    """
    identity = {"capability": capability, "backend": backend}

    if not extension_loaded:
        return NativeCapabilityRecord(
            **identity,
            status=NativeCapabilityStatus.UNAVAILABLE,
            reason=f"{backend.value} support requires the mlx_sparse extension.",
        )

    if not built:
        return NativeCapabilityRecord(
            **identity,
            status=NativeCapabilityStatus.NOT_BUILT,
            reason=not_built_reason,
        )

    runtime_ok, runtime_reason = _future_backend_runtime_status(backend, platform)
    if runtime_ok:
        status = NativeCapabilityStatus.AVAILABLE
    else:
        status = NativeCapabilityStatus.UNAVAILABLE
    return NativeCapabilityRecord(
        **identity,
        status=status,
        built=True,
        runtime_available=runtime_ok,
        reason=runtime_reason,
    )
def _future_backend_runtime_status(
    backend: NativeBackend, platform: str
) -> tuple[bool, str]:
    """Decide runtime availability for a compiled backend without a probe.

    Only Accelerate on ``"darwin"`` is reported as usable. Every other
    backend is reported as unavailable; this function performs no runtime
    check for them.

    Returns:
        ``(available, reason)``.
    """
    if backend is not NativeBackend.ACCELERATE or platform != "darwin":
        return False, f"{backend.value} support is compiled but not available at runtime."
    return True, "Accelerate support is compiled and running on Darwin."
def _probe_mlx_device(kind: mx.DeviceType, label: str) -> tuple[bool, str]:
    """Check that MLX can construct, see and initialise device 0 of ``kind``.

    Three steps run in order: ``mx.Device``, ``mx.is_available`` and
    ``mx.device_info``. An exception from any step becomes a
    ``(False, reason)`` result instead of being propagated; ``label`` is
    used in the reason text.

    Returns:
        ``(True, "")`` when every step succeeds, otherwise ``(False, reason)``.
    """
    try:
        dev = mx.Device(kind, 0)
    except Exception as err:
        return False, f"Could not construct MLX {label} device: {err}"

    try:
        usable = bool(mx.is_available(dev))
    except Exception as err:
        return False, f"Could not query MLX {label} availability: {err}"
    if not usable:
        return False, f"MLX reports that {label} device 0 is unavailable."

    try:
        mx.device_info(dev)
    except Exception as err:
        return False, f"MLX could not initialize {label} device 0: {err}"

    return True, ""
def _metallib_present() -> bool:
try:
return resources.files("mlx_sparse").joinpath("mlx_sparse.metallib").is_file()
except Exception:
return False
def _python_platform() -> str:
import sys
if sys.platform == "darwin":
return "darwin"
if sys.platform.startswith("linux"):
return "linux"
if sys.platform.startswith("win"):
return "windows"
return sys.platform
def _python_architecture() -> str:
import platform
return platform.machine()
# ``from ... import *`` exports only the process-wide capability view and its
# convenience predicate. The enums and record types stay importable by name.
__all__ = ["capabilities", "has_capability"]