# Source code for mlx_sparse.runtime

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runtime controls for mlx-sparse CPU execution.

The public API exposes direct module attributes for common interactive use, and
an enum for structured programmatic use:

.. code-block:: python

   import mlx_sparse as ms

   print(ms.runtime.N_THREADS)
   ms.runtime.N_THREADS = 8
   ms.runtime.SPGEMM_PARALLEL = True

   with ms.runtime.context(n_threads=1, spgemm_parallel=False):
       C = A @ B

``N_THREADS`` is the resolved package-wide CPU worker count. It is
intentionally separate from operation-family switches and per-family thread
overrides so future kernels can share a common worker budget by default while
still allowing users to disable or tune one family independently.
"""

from __future__ import annotations

import contextlib
import enum
import os
import sys
import types
from collections.abc import Iterator, Mapping
from typing import Any

from mlx_sparse._config import config
from mlx_sparse._typing import is_available


class RuntimeOption(str, enum.Enum):
    """Runtime option identifiers accepted by :mod:`mlx_sparse.runtime`.

    Each member is also a ``str`` (the class mixes in ``str``), so a member
    compares equal to its value. The member *value* is the key handed to the
    configuration layer by the functions in this module.
    """

    # NOTE: the value is "CPU_THREADS", not "N_THREADS"; the member name and
    # the stored key intentionally differ for this option only.
    N_THREADS = "CPU_THREADS"
    SPGEMM_PARALLEL = "SPGEMM_PARALLEL"
    SPGEMM_THREADS = "SPGEMM_THREADS"
    SOLVER_PARALLEL = "SOLVER_PARALLEL"
    SOLVER_THREADS = "SOLVER_THREADS"
# Module-level names for the enum members.
N_THREADS = RuntimeOption.N_THREADS
SPGEMM_PARALLEL = RuntimeOption.SPGEMM_PARALLEL
SPGEMM_THREADS = RuntimeOption.SPGEMM_THREADS
SOLVER_PARALLEL = RuntimeOption.SOLVER_PARALLEL
SOLVER_THREADS = RuntimeOption.SOLVER_THREADS

# Sentinel meaning "argument not supplied" where ``None`` may be a real value.
_MISSING = object()

# Upper-case spelling -> canonical option value. Both "N_THREADS" and
# "CPU_THREADS" resolve to the same option.
_OPTION_ALIASES = {
    "N_THREADS": RuntimeOption.N_THREADS.value,
    "CPU_THREADS": RuntimeOption.N_THREADS.value,
    "SPGEMM_PARALLEL": RuntimeOption.SPGEMM_PARALLEL.value,
    "SPGEMM_THREADS": RuntimeOption.SPGEMM_THREADS.value,
    "SOLVER_PARALLEL": RuntimeOption.SOLVER_PARALLEL.value,
    "SOLVER_THREADS": RuntimeOption.SOLVER_THREADS.value,
}

# Environment variables consulted for a worker count, in lookup order
# within each group.
_SCHEDULER_THREAD_ENV_VARS = (
    "SLURM_CPUS_PER_TASK",
    "PBS_NP",
    "LSB_DJOB_NUMPROC",
    "NSLOTS",
)
_THREAD_HINT_ENV_VARS = ("OMP_NUM_THREADS",)

# Every environment variable echoed back in the diagnostics report.
_RUNTIME_ENV_VARS = (
    "MLX_SPARSE_CPU_THREADS",
    "MLX_SPARSE_N_THREADS",
    "MLX_SPARSE_SPGEMM_PARALLEL",
    "MLX_SPARSE_SPGEMM_THREADS",
    "MLX_SPARSE_SOLVER_PARALLEL",
    "MLX_SPARSE_SOLVER_THREADS",
    *_THREAD_HINT_ENV_VARS,
    *_SCHEDULER_THREAD_ENV_VARS,
)


def _normalize_option(option: RuntimeOption | str) -> str:
    """Map an enum member or alias string to its canonical option value.

    Strings are matched after stripping whitespace and upper-casing.

    Raises:
        TypeError: if ``option`` is neither a ``RuntimeOption`` nor a ``str``.
        KeyError: if the string is not a known alias.
    """
    # Enum members are checked first; they are also ``str`` instances.
    if isinstance(option, RuntimeOption):
        return option.value
    if not isinstance(option, str):
        raise TypeError("runtime option must be a RuntimeOption or string.")
    alias = option.strip().upper()
    try:
        canonical = _OPTION_ALIASES[alias]
    except KeyError as exc:
        raise KeyError(f"Unknown runtime option {option!r}.") from exc
    return canonical


def _normalize_updates(updates: Mapping[Any, Any]) -> dict[str, Any]:
    """Return ``updates`` re-keyed by canonical option value."""
    normalized: dict[str, Any] = {}
    for option, value in updates.items():
        normalized[_normalize_option(option)] = value
    return normalized


def _normalize_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]:
    """Return keyword-style updates re-keyed by canonical option value."""
    normalized: dict[str, Any] = {}
    for name, value in kwargs.items():
        normalized[_normalize_option(name)] = value
    return normalized


def _parse_positive_env_int(name: str) -> int | None:
    """Read environment variable ``name`` as an integer >= 1.

    Only the text before the first comma is considered. Returns ``None`` when
    the variable is unset, empty, not a base-10 integer, or below 1.
    """
    raw = os.environ.get(name)
    if raw is None:
        return None
    head = raw.strip().partition(",")[0]
    if not head:
        return None
    try:
        parsed = int(head, 10)
    except ValueError:
        return None
    return parsed if parsed >= 1 else None


def _hardware_concurrency() -> int | None:
    """Return ``os.cpu_count()`` when it is a positive number, else ``None``."""
    cpus = os.cpu_count()
    if cpus is not None and cpus >= 1:
        return cpus
    return None


def _affinity_count() -> int | None:
    """Return the size of this process's CPU affinity set, if obtainable.

    ``None`` is returned when ``os.sched_getaffinity`` does not exist, raises
    ``OSError``, or reports an empty set.
    """
    query = getattr(os, "sched_getaffinity", None)
    if query is None:
        return None
    try:
        count = len(query(0))
    except OSError:
        return None
    return count if count >= 1 else None


def _positive_env_ints(names: tuple[str, ...]) -> dict[str, int]:
    """Collect ``{name: count}`` for each variable that parses to an int >= 1."""
    found: dict[str, int] = {}
    for name in names:
        count = _parse_positive_env_int(name)
        if count is not None:
            found[name] = count
    return found


def _detected_scheduler_counts() -> dict[str, int]:
    """Return the usable counts found in the scheduler variables."""
    return _positive_env_ints(_SCHEDULER_THREAD_ENV_VARS)


def _detected_thread_hints() -> dict[str, int]:
    """Return the usable counts found in the thread-hint variables."""
    return _positive_env_ints(_THREAD_HINT_ENV_VARS)


def _mlx_info() -> dict[str, Any]:
    """Describe the MLX installation for diagnostics.

    If importing ``mlx.core`` raises, only the exception type name is
    reported; otherwise the default device and Metal availability are.
    """
    try:
        import mlx.core as mx
    except Exception as exc:  # pragma: no cover - MLX is a hard dependency here.
        return {"available": False, "error": type(exc).__name__}
    report: dict[str, Any] = {"available": True}
    report["default_device"] = str(mx.default_device())
    report["metal_available"] = bool(mx.metal.is_available())
    return report


def _resolve_auto_threads() -> tuple[int, str]:
    """Pick a worker count without consulting the package configuration.

    Sources are tried in order: thread-hint variables, scheduler variables,
    process affinity, hardware concurrency. The second tuple element names
    the source that supplied the count; ``(1, "fallback")`` is returned when
    none of them does.
    """
    for env_name in (*_THREAD_HINT_ENV_VARS, *_SCHEDULER_THREAD_ENV_VARS):
        count = _parse_positive_env_int(env_name)
        if count is not None:
            return count, env_name
    # Probes run lazily: hardware concurrency is only queried when the
    # affinity probe yields nothing.
    probes = (
        (_affinity_count, "process_affinity"),
        (_hardware_concurrency, "hardware_concurrency"),
    )
    for probe, label in probes:
        count = probe()
        if count is not None:
            return count, label
    return 1, "fallback"
def resolve_n_threads() -> tuple[int, str]:
    """Resolve the effective CPU worker count and the source used.

    Explicit ``MLX_SPARSE_CPU_THREADS`` / ``ms.runtime.N_THREADS = ...``
    values win first. In ``"auto"`` mode, standard thread hints are consulted
    before scheduler allocations, then process affinity, then hardware
    concurrency.

    Returns:
        A ``(count, source)`` pair. ``source`` is ``"configured"`` when the
        stored option is an ``int``; otherwise it is whatever
        ``_resolve_auto_threads`` reports (an environment variable name,
        ``"process_affinity"``, ``"hardware_concurrency"`` or ``"fallback"``).
    """
    configured = config.get(RuntimeOption.N_THREADS.value)
    # Any non-int stored value is treated as a request for auto-detection.
    if isinstance(configured, int):
        return configured, "configured"
    return _resolve_auto_threads()
def _resolve_family_threads(option: RuntimeOption) -> tuple[int, str]:
    """Resolve the worker count for one operation family.

    An ``int`` setting is returned as ``(value, "configured")``. The string
    ``"inherit"`` defers to :func:`resolve_n_threads`; any other setting falls
    back to auto-detection.
    """
    setting = config.get(option.value)
    if isinstance(setting, int):
        return setting, "configured"
    fallback = resolve_n_threads if setting == "inherit" else _resolve_auto_threads
    return fallback()
def resolve_spgemm_threads() -> tuple[int, str]:
    """Resolve the CPU worker count and source for sparse-sparse products.

    Returns:
        ``(1, "disabled")`` when the ``SPGEMM_PARALLEL`` option is falsy;
        otherwise the ``(count, source)`` pair produced by
        ``_resolve_family_threads`` for ``SPGEMM_THREADS``.
    """
    if not bool(config.get(RuntimeOption.SPGEMM_PARALLEL.value)):
        return 1, "disabled"
    return _resolve_family_threads(SPGEMM_THREADS)
def resolve_solver_threads() -> tuple[int, str]:
    """Resolve the CPU worker count and source for solver routines.

    Returns:
        ``(1, "disabled")`` when the ``SOLVER_PARALLEL`` option is falsy;
        otherwise the ``(count, source)`` pair produced by
        ``_resolve_family_threads`` for ``SOLVER_THREADS``.
    """
    if not bool(config.get(RuntimeOption.SOLVER_PARALLEL.value)):
        return 1, "disabled"
    return _resolve_family_threads(SOLVER_THREADS)
@contextlib.contextmanager
def context(
    arg1: RuntimeOption | str | Mapping[Any, Any] | None = None,
    arg2: Any = _MISSING,
    **kwargs: Any,
) -> Iterator[None]:
    """Temporarily patch runtime options.

    Accepted forms mirror :meth:`mlx_sparse.config.patch`:

    ``context(ms.runtime.RuntimeOption.N_THREADS, 4)``
        Patch one enum option.
    ``context({ms.runtime.RuntimeOption.N_THREADS: 4})``
        Patch several enum options.
    ``context(n_threads=4, spgemm_parallel=False)``
        Patch options with readable keyword aliases.

    The three forms are mutually exclusive. The normalized updates are
    applied through ``config.patch`` for the duration of the ``with`` body.

    Raises:
        TypeError: if the call mixes forms (mapping or option/value pair
            together with keywords, a mapping with a second positional
            value, ``None`` with a value) or omits the value in the
            two-argument form.
        KeyError: if an option name is not recognized.
    """
    if arg1 is None:
        # Keyword-only form.
        if arg2 is not _MISSING:
            raise TypeError("context(None, value) is not a valid call form.")
        updates = _normalize_kwargs(kwargs)
    elif isinstance(arg1, Mapping):
        # Mapping form.
        if arg2 is not _MISSING:
            raise TypeError("Mapping context form does not accept a second value.")
        if kwargs:
            raise TypeError("Cannot combine mapping context with keywords.")
        updates = _normalize_updates(arg1)
    else:
        # Single option/value form. ``_MISSING`` (not ``None``) marks the
        # absent value, so ``None`` can be passed as a real value.
        if arg2 is _MISSING:
            raise TypeError("context(option, value) requires a value.")
        if kwargs:
            raise TypeError("Cannot combine two-argument context with keywords.")
        updates = {_normalize_option(arg1): arg2}
    with config.patch(updates):
        yield
# ``patch`` is a second public name bound to the same callable as ``context``.
patch = context
def info() -> dict[str, Any]:
    """Return structured runtime information for reports and diagnostics.

    The result combines the stored option values and their configuration
    sources (keyed by lower-cased ``RuntimeOption`` member name), the
    resolved worker counts with the source each came from, host probes
    (hardware concurrency, process affinity, detected thread hints and
    scheduler allocations), the relevant environment variables that are
    currently set, MLX details, and the configuration fingerprint.
    """
    resolved_threads, source = resolve_n_threads()
    spgemm_threads, spgemm_source = resolve_spgemm_threads()
    solver_threads, solver_source = resolve_solver_threads()
    spgemm_parallel = bool(config.get(RuntimeOption.SPGEMM_PARALLEL.value))
    solver_parallel = bool(config.get(RuntimeOption.SOLVER_PARALLEL.value))
    # Report keys use the lower-cased member name; lookups use the value.
    option_names = {option.name.lower(): option.value for option in RuntimeOption}
    return {
        "runtime_options": {
            key: config.get(name) for key, name in option_names.items()
        },
        "config_sources": {
            key: config.value_source(name).value
            for key, name in option_names.items()
        },
        "n_threads": resolved_threads,
        "n_threads_source": source,
        "spgemm_parallel": spgemm_parallel,
        "spgemm_n_threads": spgemm_threads,
        "spgemm_n_threads_source": spgemm_source,
        "solver_parallel": solver_parallel,
        "solver_n_threads": solver_threads,
        "solver_n_threads_source": solver_source,
        "native_extension_available": is_available(),
        "hardware_concurrency": _hardware_concurrency(),
        "process_affinity": _affinity_count(),
        "thread_hints": _detected_thread_hints(),
        "scheduler": _detected_scheduler_counts(),
        # Only variables that are actually set are included.
        "environment": {
            name: os.environ[name] for name in _RUNTIME_ENV_VARS if name in os.environ
        },
        "mlx": _mlx_info(),
        "config_fingerprint": config.fingerprint(),
    }
# Public attribute name -> option, for the attributes that are routed through
# the configuration instead of the module namespace.
_OPTION_ATTRIBUTE_OPTIONS = {
    option.name: option
    for option in (
        RuntimeOption.N_THREADS,
        RuntimeOption.SPGEMM_PARALLEL,
        RuntimeOption.SPGEMM_THREADS,
        RuntimeOption.SOLVER_PARALLEL,
        RuntimeOption.SOLVER_THREADS,
    )
}


def _read_option_attribute(name: str) -> Any:
    """Return the value exposed for the module attribute ``name``.

    The ``*_THREADS`` attributes yield the resolved worker count; the
    remaining attributes yield the stored configuration value.
    """
    if name == "N_THREADS":
        count, _source = resolve_n_threads()
        return count
    if name == "SPGEMM_THREADS":
        count, _source = resolve_spgemm_threads()
        return count
    if name == "SOLVER_THREADS":
        count, _source = resolve_solver_threads()
        return count
    option = _OPTION_ATTRIBUTE_OPTIONS[name]
    return config.get(option.value)


class _RuntimeModule(types.ModuleType):
    """Module type whose runtime-option attributes read and write the config."""

    def __getattribute__(self, name: str) -> Any:
        # Everything except the option attributes uses normal module lookup.
        if name not in _OPTION_ATTRIBUTE_OPTIONS:
            return super().__getattribute__(name)
        return _read_option_attribute(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in _OPTION_ATTRIBUTE_OPTIONS:
            # Option writes go to the configuration, not the module dict.
            config.set(_OPTION_ATTRIBUTE_OPTIONS[name].value, value)
        else:
            super().__setattr__(name, value)

    def __dir__(self) -> list[str]:
        names = {*super().__dir__(), *_OPTION_ATTRIBUTE_OPTIONS}
        return sorted(names)


# Swap this module's class so attribute access goes through the hooks above.
sys.modules[__name__].__class__ = _RuntimeModule

__all__ = [
    "N_THREADS",
    "SOLVER_PARALLEL",
    "SOLVER_THREADS",
    "SPGEMM_PARALLEL",
    "SPGEMM_THREADS",
    "RuntimeOption",
    "context",
    "info",
    "patch",
    "resolve_n_threads",
    "resolve_solver_threads",
    "resolve_spgemm_threads",
]