# Source code for mlx_sparse._config

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Typed runtime configuration for mlx-sparse.

The design follows the same broad shape as neuraLQX's configuration manager:
options are declared in one schema, read from environment variables, validated
on mutation, observable through hooks, and temporarily patchable from user code.
The implementation here is deliberately smaller because mlx-sparse only needs a
few package-level knobs today.
"""

from __future__ import annotations

import contextlib
import enum
import hashlib
import json
import os
import threading
from collections.abc import Callable, Iterator, Mapping, Sequence
from dataclasses import dataclass, replace
from typing import Any, Generic, TypeVar, cast

T = TypeVar("T")


class ConfigError(Exception):
    """Base class for configuration errors.

    ``UnknownOptionError`` and ``ConfigValidationError`` derive from this
    class, so catching ``ConfigError`` handles every configuration failure
    raised by this module. It is also raised directly when a mutation is
    rejected (forced, immutable, or startup-only options).
    """


class UnknownOptionError(ConfigError):
    """Raised when a configuration option name is not registered.

    Option lookup is case-insensitive, so this signals that no option with
    the upper-cased form of the requested name exists.
    """


class ConfigValidationError(ConfigError):
    """Raised when a configuration value cannot be parsed or validated.

    Also used for invalid option declarations (empty or duplicate option
    names, malformed environment-variable names) and for environment
    variables whose contents fail to parse.
    """


class ConfigMutability(str, enum.Enum):
    """When an option may be changed through the Python API."""

    # Never changeable from Python; every mutation attempt is rejected.
    IMMUTABLE = "immutable"
    # Changeable only until the manager's runtime lock is engaged.
    STARTUP = "startup"
    # Changeable at any time (unless forced by an environment variable).
    RUNTIME = "runtime"


class ConfigSource(str, enum.Enum):
    """Where an effective configuration value came from.

    The first four members are what ``ConfigManager.value_source`` reports;
    ``PATCH`` and ``RESET`` only appear as the ``source`` of mutation events.
    """

    # Built-in default from the option declaration.
    DEFAULT = "default"
    # Value read from one of the option's ``env_default`` variables.
    ENV_DEFAULT = "env_default"
    # Value read from one of the option's ``env_force`` variables.
    ENV_FORCE = "env_force"
    # Programmatic override set from Python.
    USER = "user"
    # Change applied or reverted by ``ConfigManager.patch``.
    PATCH = "patch"
    # Default source for ``ConfigManager.clear_override``.
    RESET = "reset"


class _UnsetType:
    """Type of the module's sentinel objects.

    Instances carry no state (empty ``__slots__``) and are compared by
    identity; every instance shares the same ``repr``.
    """

    __slots__ = ()

    def __repr__(self) -> str:
        return "<UNSET>"


# Sentinel stored in an option-state slot that holds no value.
UNSET = _UnsetType()
# Distinct sentinel used as the "argument not supplied" default of
# ``ConfigManager.patch``; kept separate so UNSET itself stays a passable value.
_MISSING = _UnsetType()

# Signatures of the callables an option declaration or hook list may hold.
Parser = Callable[[Any], Any]
Validator = Callable[[Any], None]
MutationHook = Callable[["ConfigMutation"], None]


@dataclass(frozen=True, slots=True)
class ConfigMutation:
    """A single effective configuration change.

    Immutable event record handed to every mutation hook.
    """

    # Registered (upper-cased) option name.
    name: str
    # Effective value before the change.
    old_value: Any
    # Effective value after the change.
    new_value: Any
    # Origin reported by the code path that performed the change.
    source: ConfigSource
    # Mutability class declared for the option.
    mutability: ConfigMutability


@dataclass(frozen=True, slots=True)
class ConfigOption(Generic[T]):
    """Static declaration for one configuration option.

    The declaration is immutable; per-option runtime state lives elsewhere.
    """

    # Option name; the manager upper-cases it on registration.
    name: str
    # Built-in default, used when nothing else supplies a value.
    default: T
    # Human-readable description.
    doc: str
    # Accepted type(s) of the parsed value; ``None`` disables the type check.
    value_type: type[Any] | tuple[type[Any], ...] | None = None
    # Converts raw input (e.g. an environment string) into the stored value.
    parser: Parser | None = None
    # Extra check run on the parsed value; expected to raise on rejection.
    validator: Validator | None = None
    # Environment variables that may supply a default value.
    env_default: tuple[str, ...] = ()
    # Environment variables that pin the value and block Python-side changes.
    env_force: tuple[str, ...] = ()
    # Free-form grouping label.
    role: str = "general"
    # When the option may be changed from Python.
    mutability: ConfigMutability = ConfigMutability.RUNTIME
    # Whether the option contributes to the configuration fingerprint.
    include_in_fingerprint: bool = True

    def parse(self, raw_value: Any) -> T:
        """Convert and validate ``raw_value`` for this option.

        The parser (if any) runs first, then the ``value_type`` check, then
        the validator (if any).

        Raises:
            ConfigValidationError: If the parsed value is not an instance of
                ``value_type``. Exceptions raised by the parser or validator
                propagate unchanged.
        """
        convert = self.parser
        candidate = raw_value if convert is None else convert(raw_value)

        expected = self.value_type
        type_ok = expected is None or isinstance(candidate, expected)
        if not type_ok:
            raise ConfigValidationError(
                f"Option {self.name!r} expects {self.value_type}, "
                f"got {type(candidate).__name__} with value {candidate!r}."
            )

        check = self.validator
        if check is not None:
            check(candidate)
        return cast(T, candidate)


@dataclass(slots=True)
class _OptionState(Generic[T]):
    """Mutable per-option runtime state held by the manager.

    Each value slot holds ``UNSET`` when that layer supplies nothing.
    """

    # The option's static declaration.
    spec: ConfigOption[T]
    # Value parsed from an ``env_default`` variable at registration time.
    env_default_value: T | _UnsetType = UNSET
    # Value parsed from an ``env_force`` variable at registration time.
    env_force_value: T | _UnsetType = UNSET
    # Value set programmatically from Python.
    user_override: T | _UnsetType = UNSET


def parse_bool(value: Any) -> bool:
    """Interpret ``value`` as a boolean, accepting common spellings.

    Accepted inputs are real booleans, the integers ``0`` and ``1``, and the
    strings ``1/true/t/yes/y/on`` or ``0/false/f/no/n/off`` (case-insensitive,
    surrounding whitespace ignored).

    Raises:
        ConfigValidationError: For any other input.
    """

    if isinstance(value, bool):
        return value
    if isinstance(value, int):
        # Only 0 and 1 are meaningful; other integers fall through to the error.
        if value in (0, 1):
            return bool(value)
    elif isinstance(value, str):
        token = value.strip().lower()
        if token in ("1", "true", "t", "yes", "y", "on"):
            return True
        if token in ("0", "false", "f", "no", "n", "off"):
            return False
    raise ConfigValidationError(f"Cannot parse boolean from value {value!r}.")


def parse_thread_count(value: Any) -> int | str:
    """Return a positive thread count, or ``"auto"`` for the auto sentinel.

    Integers (but not booleans) of at least 1 are returned unchanged. Strings
    are stripped and lower-cased, then either matched against ``"auto"`` or
    parsed as a base-10 integer.

    Raises:
        ConfigValidationError: For booleans, non-positive counts, strings
            that are neither ``"auto"`` nor an integer, and any other type.
    """

    def _invalid() -> ConfigValidationError:
        # Built lazily so ``repr(value)`` only runs on the failure path.
        return ConfigValidationError(
            f"Thread count must be a positive integer or 'auto', got {value!r}."
        )

    if isinstance(value, str):
        token = value.strip().lower()
        if token == "auto":
            return "auto"
        try:
            count = int(token, 10)
        except ValueError as exc:
            raise _invalid() from exc
    elif isinstance(value, int) and not isinstance(value, bool):
        count = value
    else:
        raise _invalid()

    if count < 1:
        raise _invalid()
    return count


def parse_thread_count_or_inherit(value: Any) -> int | str:
    """Like ``parse_thread_count``, but also accepts ``"inherit"``.

    The string ``"inherit"`` (case-insensitive, surrounding whitespace
    ignored) is returned normalized; everything else is delegated to
    ``parse_thread_count``.
    """

    is_inherit = isinstance(value, str) and value.strip().lower() == "inherit"
    if not is_inherit:
        return parse_thread_count(value)
    return "inherit"


def _format_env_value(value: Any) -> str:
    """Render a configuration value as environment-variable text.

    ``None`` becomes the empty string, booleans become ``"1"``/``"0"``, and
    anything else is passed through ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, bool):
        # A bool indexes the pair directly: False -> "0", True -> "1".
        return ("0", "1")[value]
    return str(value)


class ConfigManager:
    """Typed, hookable singleton configuration manager for mlx-sparse.

    Effective value precedence is:

    1. forced environment variable (``env_force``)
    2. programmatic override
    3. default environment variable (``env_default``)
    4. built-in default

    Option names are stored upper-cased. The name-based methods (``get``,
    ``set``, ...) accept any casing; attribute-style access
    (``manager.CPU_THREADS``) requires the registered upper-case name.

    Environment variables are read once, when an option is registered. After
    registration and after every override change, the option's effective
    value is written back to ``os.environ`` under ``PREFIX + name``.
    """

    _instance: "ConfigManager | None" = None
    _class_lock = threading.Lock()

    PREFIX = "MLX_SPARSE_"

    def __new__(cls) -> "ConfigManager":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._class_lock:
                # Re-check under the lock so only one thread creates the instance.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        """Set up internal state and register the built-in options once.

        ``__init__`` runs on every ``ConfigManager()`` call; the
        ``_initialized`` flag makes every call after the first a no-op.
        """
        if self._initialized:
            return
        object.__setattr__(self, "_options", {})
        object.__setattr__(self, "_hooks", {})
        object.__setattr__(self, "_global_hooks", [])
        object.__setattr__(self, "_runtime_locked", False)
        object.__setattr__(self, "_lock", threading.RLock())
        self._register_default_options()
        self._initialized = True

    def __repr__(self) -> str:
        return (
            "ConfigManager("
            f"options={len(self._options)}, "
            f"runtime_locked={self._runtime_locked})"
        )

    def __getattr__(self, name: str) -> Any:
        """Expose registered options as read-only-style attributes.

        Only called when normal attribute lookup fails. The ``_options``
        presence check avoids recursion before ``__init__`` has run.
        """
        if "_options" in self.__dict__ and name in self._options:
            return self.get(name)
        raise AttributeError(f"{self.__class__.__name__} has no option {name!r}.")

    def __setattr__(self, name: str, value: Any) -> None:
        """Route public attribute assignment through ``set``.

        Underscore-prefixed names are stored as ordinary attributes; any
        other name must be a registered option.
        """
        if name.startswith("_"):
            object.__setattr__(self, name, value)
            return
        if "_options" in self.__dict__ and name in self._options:
            self.set(name, value)
            return
        raise AttributeError(f"{self.__class__.__name__} has no option {name!r}.")

    @staticmethod
    def _normalize_env_names(names: str | Sequence[str] | None) -> tuple[str, ...]:
        """Return ``names`` as a tuple of upper-cased variable names.

        Accepts ``None`` (empty tuple), a single string, or a sequence of
        strings.

        Raises:
            ConfigValidationError: If any entry is empty or not a string.
        """
        if names is None:
            return ()
        if isinstance(names, str):
            names = (names,)
        out: list[str] = []
        for name in names:
            if not isinstance(name, str) or not name:
                raise ConfigValidationError(
                    "Environment names must be non-empty strings."
                )
            out.append(name.upper())
        return tuple(out)

    def _lookup_env(self, key: str) -> str | None:
        """Look up ``key`` in ``os.environ`` ignoring case.

        Returns the value of the first matching variable in iteration order,
        or ``None`` when no variable matches.
        """
        key_upper = key.upper()
        for env_key, env_value in os.environ.items():
            if env_key.upper() == key_upper:
                return env_value
        return None

    def _read_env_value(
        self, option: ConfigOption[Any], env_names: Sequence[str]
    ) -> Any | _UnsetType:
        """Parse the first of ``env_names`` that is present in the environment.

        Returns ``UNSET`` when none of the variables is set.

        Raises:
            ConfigValidationError: If the variable's content fails to parse;
                the message names the offending variable and the original
                error is chained.
        """
        for env_name in env_names:
            raw_value = self._lookup_env(env_name)
            if raw_value is None:
                continue
            try:
                return option.parse(raw_value)
            except ConfigValidationError as exc:
                raise ConfigValidationError(
                    f"Invalid value from environment variable {env_name!r} "
                    f"for option {option.name!r}: {raw_value!r}."
                ) from exc
        return UNSET

    def _direct_effective_value(self, state: _OptionState[Any]) -> Any:
        """Resolve the effective value following the documented precedence."""
        if state.env_force_value is not UNSET:
            return state.env_force_value
        if state.user_override is not UNSET:
            return state.user_override
        if state.env_default_value is not UNSET:
            return state.env_default_value
        return state.spec.default

    def _sync_env(self, state: _OptionState[Any]) -> None:
        """Mirror the option's effective value into ``os.environ``."""
        os.environ[self.PREFIX + state.spec.name] = _format_env_value(
            self._direct_effective_value(state)
        )

    def _apply_override(
        self,
        state: _OptionState[Any],
        override: Any,
        source: ConfigSource,
    ) -> Any:
        """Store ``override`` (or ``UNSET``), sync the environment, emit hooks.

        Performs no validation or mutability checks; the caller must hold
        ``self._lock`` and must already have vetted the change. Returns the
        new effective value.
        """
        old_value = self._direct_effective_value(state)
        state.user_override = override
        self._sync_env(state)
        new_value = self._direct_effective_value(state)
        self._maybe_emit(state, old_value, new_value, source)
        return new_value

    def register_option(self, option: ConfigOption[Any]) -> None:
        """Register a new option declaration.

        The name is upper-cased, the default is run through the option's own
        parser/validator, the environment variables are read, and the
        effective value is mirrored into ``os.environ``.

        Raises:
            ConfigValidationError: If the name is empty, not a string, or
                already registered, or if the default or an environment
                value is invalid.
        """
        with self._lock:
            if not option.name or not isinstance(option.name, str):
                raise ConfigValidationError("Option name must be a non-empty string.")
            name = option.name.upper()
            if name in self._options:
                raise ConfigValidationError(f"Option {name!r} is already registered.")

            # Validate the default up front so a bad declaration fails here.
            default = option.parse(option.default)
            option = replace(
                option,
                name=name,
                default=default,
                env_default=self._normalize_env_names(option.env_default),
                env_force=self._normalize_env_names(option.env_force),
            )
            state = _OptionState(spec=option)
            state.env_default_value = self._read_env_value(option, option.env_default)
            state.env_force_value = self._read_env_value(option, option.env_force)
            self._options[name] = state
            self._hooks[name] = []
            self._sync_env(state)

    def define_option(
        self,
        name: str,
        *,
        default: Any,
        doc: str,
        value_type: type[Any] | tuple[type[Any], ...] | None = None,
        parser: Parser | None = None,
        validator: Validator | None = None,
        env_default: str | Sequence[str] | None = None,
        env_force: str | Sequence[str] | None = None,
        role: str = "general",
        mutability: ConfigMutability = ConfigMutability.RUNTIME,
        include_in_fingerprint: bool = True,
    ) -> None:
        """Build a ``ConfigOption`` from keyword arguments and register it.

        ``env_default`` and ``env_force`` may each be a single variable name
        or a sequence of names. Raises whatever ``register_option`` raises.
        """
        self.register_option(
            ConfigOption(
                name=name,
                default=default,
                doc=doc,
                value_type=value_type,
                parser=parser,
                validator=validator,
                env_default=self._normalize_env_names(env_default),
                env_force=self._normalize_env_names(env_force),
                role=role,
                mutability=mutability,
                include_in_fingerprint=include_in_fingerprint,
            )
        )

    def define_bool(
        self,
        name: str,
        *,
        default: bool,
        doc: str,
        env_default: str | Sequence[str] | None = None,
        env_force: str | Sequence[str] | None = None,
        role: str = "general",
        mutability: ConfigMutability = ConfigMutability.RUNTIME,
        include_in_fingerprint: bool = True,
    ) -> None:
        """Register a boolean option parsed with ``parse_bool``."""
        self.define_option(
            name,
            default=default,
            doc=doc,
            value_type=bool,
            parser=parse_bool,
            env_default=env_default,
            env_force=env_force,
            role=role,
            mutability=mutability,
            include_in_fingerprint=include_in_fingerprint,
        )

    def _get_state(self, name: str) -> _OptionState[Any]:
        """Return the state for ``name`` (case-insensitive).

        Raises:
            UnknownOptionError: If no such option is registered.
        """
        key = name.upper()
        try:
            return self._options[key]
        except KeyError as exc:
            raise UnknownOptionError(f"Unknown option {name!r}.") from exc

    def _assert_can_mutate(self, state: _OptionState[Any]) -> None:
        """Raise ``ConfigError`` if the option may not be changed right now.

        A change is rejected when the option is forced by an environment
        variable, is immutable, or is startup-only and the runtime is locked.
        """
        if state.env_force_value is not UNSET:
            raise ConfigError(
                f"Option {state.spec.name!r} is forced by environment "
                f"{state.spec.env_force} and cannot be changed from Python."
            )
        if state.spec.mutability is ConfigMutability.IMMUTABLE:
            raise ConfigError(f"Option {state.spec.name!r} is immutable.")
        if state.spec.mutability is ConfigMutability.STARTUP and self._runtime_locked:
            raise ConfigError(
                f"Option {state.spec.name!r} is startup-only and runtime is locked."
            )

    def get(self, name: str) -> Any:
        """Return the effective value of option ``name``.

        Raises:
            UnknownOptionError: If the option is not registered.
        """
        return self._direct_effective_value(self._get_state(name))

    def read(self, name: str) -> Any:
        """Alias of ``get``."""
        return self.get(name)

    def set(
        self,
        name: str,
        value: Any,
        *,
        source: ConfigSource = ConfigSource.USER,
    ) -> Any:
        """Set a programmatic override and return the new effective value.

        ``value`` is parsed and validated by the option before anything is
        changed. Hooks fire only if the effective value actually changes.

        Raises:
            UnknownOptionError: If the option is not registered.
            ConfigValidationError: If ``value`` is rejected by the option.
            ConfigError: If the option may not be mutated.
        """
        state = self._get_state(name)
        parsed = state.spec.parse(value)
        with self._lock:
            self._assert_can_mutate(state)
            return self._apply_override(state, parsed, source)

    def update(
        self,
        name: str,
        value: Any,
        *,
        source: ConfigSource = ConfigSource.USER,
    ) -> Any:
        """Alias of ``set``."""
        return self.set(name, value, source=source)

    def clear_override(
        self,
        name: str,
        *,
        source: ConfigSource = ConfigSource.RESET,
    ) -> Any:
        """Drop the programmatic override and return the new effective value.

        Raises:
            UnknownOptionError: If the option is not registered.
            ConfigError: If the option may not be mutated.
        """
        state = self._get_state(name)
        with self._lock:
            self._assert_can_mutate(state)
            return self._apply_override(state, UNSET, source)

    def set_many(
        self,
        updates: Mapping[str, Any],
        *,
        source: ConfigSource = ConfigSource.USER,
    ) -> None:
        """Set several options, validating all of them before applying any.

        Every entry is looked up, parsed, and checked for mutability first,
        so an unknown name or invalid value raises without leaving earlier
        entries applied. Entries are then applied in mapping order. A hook
        that raises while entries are being applied can still interrupt the
        batch part-way.

        Raises:
            UnknownOptionError, ConfigValidationError, ConfigError: As for
                ``set``, for the first offending entry.
        """
        with self._lock:
            staged: list[tuple[_OptionState[Any], Any]] = []
            for name, value in updates.items():
                state = self._get_state(name)
                parsed = state.spec.parse(value)
                self._assert_can_mutate(state)
                staged.append((state, parsed))
            for state, parsed in staged:
                self._apply_override(state, parsed, source)

    @contextlib.contextmanager
    def patch(
        self,
        arg1: str | Mapping[str, Any] | None = None,
        arg2: Any = _MISSING,
        **kwargs: Any,
    ) -> Iterator[None]:
        """Temporarily override options for the duration of a ``with`` block.

        Call forms: ``patch("NAME", value)``, ``patch({"NAME": value, ...})``
        or ``patch(NAME=value, ...)``.

        On exit, each patched option gets back the programmatic override it
        had on entry (or none). The saved overrides are written back
        directly, without re-running mutability checks, and all of them are
        restored before any hook is notified. A runtime lock engaged inside
        the block or a raising hook therefore cannot leave options patched.

        Raises:
            TypeError: For an invalid combination of arguments.
            UnknownOptionError, ConfigValidationError, ConfigError: If the
                requested overrides cannot be applied.
        """
        updates = self._normalize_patch_args(arg1, arg2, kwargs)
        previous: dict[str, Any] = {}
        for name in updates:
            state = self._get_state(name)
            previous[state.spec.name] = state.user_override

        try:
            self.set_many(updates, source=ConfigSource.PATCH)
            yield
        finally:
            with self._lock:
                restored: list[tuple[_OptionState[Any], Any, Any]] = []
                # Pass 1: put every saved override back before running hooks.
                for name, old_override in previous.items():
                    state = self._get_state(name)
                    before = self._direct_effective_value(state)
                    state.user_override = old_override
                    self._sync_env(state)
                    restored.append(
                        (state, before, self._direct_effective_value(state))
                    )
                # Pass 2: notify hooks about the values that actually changed.
                for state, before, after in restored:
                    self._maybe_emit(state, before, after, ConfigSource.PATCH)

    def list_options(self) -> tuple[str, ...]:
        """Return the registered option names, sorted."""
        with self._lock:
            return tuple(sorted(self._options))

    def snapshot(self) -> dict[str, Any]:
        """Return a ``{name: effective value}`` dict for every option."""
        return {name: self.get(name) for name in self.list_options()}

    @property
    def values(self) -> dict[str, Any]:
        """Same as ``snapshot()``, exposed as a property."""
        return self.snapshot()

    def user_overrides(self) -> dict[str, Any]:
        """Return the options that currently have a programmatic override."""
        return {
            name: state.user_override
            for name, state in self._options.items()
            if state.user_override is not UNSET
        }

    def describe_option(self, name: str) -> dict[str, Any]:
        """Return the declaration and current status of option ``name``.

        Raises:
            UnknownOptionError: If the option is not registered.
        """
        state = self._get_state(name)
        return {
            "name": state.spec.name,
            "doc": state.spec.doc,
            "role": state.spec.role,
            "default": state.spec.default,
            "mutability": state.spec.mutability.value,
            "env_default": state.spec.env_default,
            "env_force": state.spec.env_force,
            "effective_value": self.get(state.spec.name),
            "source": self.value_source(state.spec.name).value,
            "runtime_locked": self._runtime_locked,
            "include_in_fingerprint": state.spec.include_in_fingerprint,
        }

    def options_by_role(self) -> dict[str, tuple[str, ...]]:
        """Group option names by their ``role``; names are sorted per role."""
        grouped: dict[str, list[str]] = {}
        for name, state in self._options.items():
            grouped.setdefault(state.spec.role, []).append(name)
        return {role: tuple(sorted(names)) for role, names in grouped.items()}

    def value_source(self, name: str) -> ConfigSource:
        """Report which layer supplies the effective value of ``name``.

        Any programmatic override is reported as ``ConfigSource.USER``,
        including one installed by ``patch``.

        Raises:
            UnknownOptionError: If the option is not registered.
        """
        state = self._get_state(name)
        if state.env_force_value is not UNSET:
            return ConfigSource.ENV_FORCE
        if state.user_override is not UNSET:
            return ConfigSource.USER
        if state.env_default_value is not UNSET:
            return ConfigSource.ENV_DEFAULT
        return ConfigSource.DEFAULT

    def fingerprint(self) -> str:
        """Return a SHA-256 hex digest of the fingerprinted effective values.

        Only options declared with ``include_in_fingerprint=True`` take
        part. Values are serialized as key-sorted JSON; values JSON cannot
        encode are stringified.
        """
        values = {
            name: self.get(name)
            for name, state in self._options.items()
            if state.spec.include_in_fingerprint
        }
        payload = json.dumps(values, sort_keys=True, default=str).encode("utf-8")
        return hashlib.sha256(payload).hexdigest()

    @property
    def runtime_locked(self) -> bool:
        """Whether startup-only options are currently frozen."""
        return self._runtime_locked

    def lock_runtime(self) -> None:
        """Freeze startup-only options against further changes."""
        with self._lock:
            self._runtime_locked = True

    def unlock_runtime_for_testing(self) -> None:
        """Lift the runtime lock again."""
        with self._lock:
            self._runtime_locked = False

    def add_hook(
        self,
        name: str,
        hook: MutationHook,
        *,
        run_immediately: bool = False,
    ) -> None:
        """Register ``hook`` for changes to option ``name``.

        With ``run_immediately=True`` the hook is also called once right
        away with an event whose old and new values both equal the current
        effective value.

        Raises:
            UnknownOptionError: If the option is not registered.
        """
        state = self._get_state(name)
        with self._lock:
            self._hooks[state.spec.name].append(hook)
        if run_immediately:
            current = self.get(state.spec.name)
            hook(
                ConfigMutation(
                    name=state.spec.name,
                    old_value=current,
                    new_value=current,
                    source=self.value_source(state.spec.name),
                    mutability=state.spec.mutability,
                )
            )

    def add_global_hook(self, hook: MutationHook) -> None:
        """Register ``hook`` for changes to any option."""
        with self._lock:
            self._global_hooks.append(hook)

    def show(self) -> str:
        """Return a plain-text table of all options, one row per option."""
        rows = [
            (
                name,
                state.spec.role,
                repr(state.spec.default),
                state.spec.mutability.value,
                repr(self.get(name)),
                self.value_source(name).value,
            )
            for name, state in sorted(self._options.items())
        ]
        headers = ("name", "role", "default", "mutability", "effective", "source")
        # Each column is as wide as its longest cell, header included.
        widths = [
            max(len(str(cell)) for cell in (header, *(row[i] for row in rows)))
            for i, header in enumerate(headers)
        ]
        lines = [
            "  ".join(header.ljust(widths[i]) for i, header in enumerate(headers)),
            "  ".join("-" * width for width in widths),
        ]
        lines.extend(
            "  ".join(str(cell).ljust(widths[i]) for i, cell in enumerate(row))
            for row in rows
        )
        return "\n".join(lines)

    def _maybe_emit(
        self,
        state: _OptionState[Any],
        old_value: Any,
        new_value: Any,
        source: ConfigSource,
    ) -> None:
        """Notify hooks if the effective value changed.

        Option-specific hooks run first, then global hooks. Hook lists are
        copied before iteration so a hook may register further hooks.
        """
        if old_value == new_value:
            return
        event = ConfigMutation(
            name=state.spec.name,
            old_value=old_value,
            new_value=new_value,
            source=source,
            mutability=state.spec.mutability,
        )
        for hook in tuple(self._hooks[state.spec.name]):
            hook(event)
        for hook in tuple(self._global_hooks):
            hook(event)

    @staticmethod
    def _normalize_patch_args(
        arg1: str | Mapping[str, Any] | None,
        arg2: Any,
        kwargs: Mapping[str, Any],
    ) -> dict[str, Any]:
        """Turn the three ``patch`` call forms into one ``{name: value}`` dict.

        Raises:
            TypeError: If the call forms are mixed or incomplete.
        """
        if arg1 is None:
            if arg2 is not _MISSING:
                raise TypeError("patch(None, value) is not a valid call form.")
            return dict(kwargs)
        if isinstance(arg1, str):
            if arg2 is _MISSING:
                raise TypeError("patch('NAME', value) requires a value.")
            if kwargs:
                raise TypeError("Cannot combine two-argument patch with keywords.")
            return {arg1: arg2}
        if arg2 is not _MISSING:
            raise TypeError("Mapping patch form does not accept a second value.")
        if kwargs:
            raise TypeError("Cannot combine mapping patch with keywords.")
        if not isinstance(arg1, Mapping):
            raise TypeError("patch expects a name or mapping.")
        return dict(arg1)

    def _register_default_options(self) -> None:
        """Declare the package's built-in options."""
        self.define_option(
            "CPU_THREADS",
            default="auto",
            doc=(
                "Package-wide CPU worker setting. Use a positive integer for an "
                "explicit worker count, or 'auto' to resolve from standard "
                "threading and scheduler environment variables, process affinity, "
                "and hardware concurrency."
            ),
            value_type=(int, str),
            parser=parse_thread_count,
            env_default=("MLX_SPARSE_CPU_THREADS", "MLX_SPARSE_N_THREADS"),
            role="runtime",
            mutability=ConfigMutability.RUNTIME,
        )
        self.define_bool(
            "SPGEMM_PARALLEL",
            default=True,
            doc=(
                "Enable package-level CPU parallel execution for sparse-sparse "
                "matrix products when a parallel implementation is available."
            ),
            env_default="MLX_SPARSE_SPGEMM_PARALLEL",
            role="runtime",
            mutability=ConfigMutability.RUNTIME,
        )
        self.define_option(
            "SPGEMM_THREADS",
            default="inherit",
            doc=(
                "CPU worker setting for sparse-sparse matrix products. Use a "
                "positive integer for an explicit family-specific count, 'auto' "
                "for dynamic runtime resolution, or 'inherit' to use CPU_THREADS."
            ),
            value_type=(int, str),
            parser=parse_thread_count_or_inherit,
            env_default="MLX_SPARSE_SPGEMM_THREADS",
            role="runtime",
            mutability=ConfigMutability.RUNTIME,
        )
        self.define_bool(
            "SOLVER_PARALLEL",
            default=False,
            doc=(
                "Enable package-level CPU parallel execution for solver routines "
                "when a parallel implementation is available."
            ),
            env_default="MLX_SPARSE_SOLVER_PARALLEL",
            role="runtime",
            mutability=ConfigMutability.RUNTIME,
        )
        self.define_option(
            "SOLVER_THREADS",
            default="inherit",
            doc=(
                "CPU worker setting for solver routines. Use a positive integer "
                "for an explicit family-specific count, 'auto' for dynamic "
                "runtime resolution, or 'inherit' to use CPU_THREADS."
            ),
            value_type=(int, str),
            parser=parse_thread_count_or_inherit,
            env_default="MLX_SPARSE_SOLVER_THREADS",
            role="runtime",
            mutability=ConfigMutability.RUNTIME,
        )
        self.define_bool(
            "EXPERIMENTAL_METAL_SPGEMM",
            default=False,
            doc=(
                "Enable experimental staged Metal implementations for same-format "
                "CSR, COO, and CSC sparse-sparse products. "
                "The optimized native host SpGEMM path remains the default because "
                "it is faster on current small and medium benchmark cases."
            ),
            env_default="MLX_SPARSE_EXPERIMENTAL_METAL_SPGEMM",
            env_force="MLX_SPARSE_FORCE_EXPERIMENTAL_METAL_SPGEMM",
            role="sparse",
            mutability=ConfigMutability.RUNTIME,
        )


# Package-wide manager instance; ``ConfigManager`` is a singleton, so any later
# ``ConfigManager()`` call returns this same object.
config = ConfigManager()


def get_config(name: str) -> Any:
    """Read a package configuration value.

    Thin wrapper around ``config.get``: ``name`` is matched
    case-insensitively and the option's effective value is returned.

    Raises:
        UnknownOptionError: If ``name`` is not a registered option.
    """
    return config.get(name)


def set_config(name: str, value: Any) -> Any:
    """Set a package configuration value.

    Thin wrapper around ``config.set``: ``value`` is parsed and validated by
    the option, stored as a programmatic override, and the resulting
    effective value is returned.

    Raises:
        UnknownOptionError: If ``name`` is not a registered option.
        ConfigValidationError: If ``value`` is rejected by the option.
        ConfigError: If the option may not be changed from Python.
    """
    return config.set(name, value)


def config_context(
    *args: Any, **kwargs: Any
) -> contextlib.AbstractContextManager[None]:
    """Temporarily patch package configuration values.

    Thin wrapper around ``config.patch``; all arguments are forwarded
    unchanged. Accepted forms are ``config_context("NAME", value)``,
    ``config_context({"NAME": value})`` and ``config_context(NAME=value)``.
    """
    return config.patch(*args, **kwargs)


# Names exported by ``from mlx_sparse._config import *``.
__all__ = [
    "config",
    "config_context",
    "get_config",
    "parse_bool",
    "parse_thread_count",
    "parse_thread_count_or_inherit",
    "set_config",
]