# Source code for mlx_sparse.linalg.preconditioners

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

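"""Native-backed sparse solver preconditioners.

The Python objects in this module are containers and dispatch helpers.
Application and Krylov iteration dispatch to native mlx-sparse primitives rather
than Python solver loops. The exceptions are :class:`CallablePreconditioner`,
which invokes a user-supplied Python callable on every application, and
:class:`ExactFactorPreconditioner`, which delegates to its wrapped solver object
when no explicit native factors are available. Constructors may use existing
sparse native kernels and MLX scalar array expressions to build immutable setup
data.
"""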

from __future__ import annotations

import math
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Protocol, runtime_checkable

import mlx.core as mx

import mlx_sparse._native as _native
from mlx_sparse._coo import COOArray
from mlx_sparse._csc import CSCArray
from mlx_sparse._csr import CSRArray
from mlx_sparse.linalg.utils.arrays import (
    ensure_float32_csr,
    ensure_float32_vector,
    ensure_rank1_or_rank2_rhs,
    finite_scalar,
    host_bool,
)
from mlx_sparse.linalg.utils.preconditioners import normalize_identity_dtype
from mlx_sparse.linalg.utils.sparse import canonical_csr, square_shape


@runtime_checkable
class Preconditioner(Protocol):
    """Structural interface for inverse-apply preconditioner objects.

    A solver-facing preconditioner models the action ``M^{-1} @ x`` rather
    than the matrix ``M`` itself. Conforming objects publish:

    * ``shape`` -- the square operator shape;
    * ``dtype`` -- the value dtype;
    * ``kind`` -- a stable string identifier;
    * ``is_symmetric`` / ``is_positive_definite`` -- structural metadata;
    * ``setup_device`` / ``apply_device`` -- setup and apply device
      descriptors;
    * ``nnz`` -- the effective storage entry count;
    * ``setup_info`` -- a structured mapping describing setup choices.

    :meth:`solve` and its ``__call__`` alias must both accept a rank-1 vector
    RHS or a rank-2 dense RHS matrix, and must not mutate the matrix that was
    used during setup.
    """

    shape: tuple[int, int]
    dtype: object
    kind: str
    is_symmetric: bool
    is_positive_definite: bool
    setup_device: str
    apply_device: str
    nnz: int
    setup_info: Mapping[str, object]

    def solve(self, x) -> mx.array:
        """Return the preconditioner solve ``M^{-1} @ x`` applied to ``x``."""

    def __call__(self, x) -> mx.array:
        """Same operation as :meth:`solve`."""
@dataclass(frozen=True, slots=True)
class IdentityPreconditioner:
    """No-op inverse-apply preconditioner.

    ``IdentityPreconditioner`` is useful as an explicit baseline and as the
    normalized representation of ``M=None`` inside solver plumbing.

    Stored fields include the compatible square ``shape``, the native value
    ``dtype`` (currently ``mlx.core.float32``), the stable ``kind`` string,
    and symmetry/positive-definiteness metadata.
    """

    shape: tuple[int, int]
    dtype: object = mx.float32
    kind: str = "identity"
    is_symmetric: bool = True
    is_positive_definite: bool = True

    def __post_init__(self) -> None:
        """Normalize the stored square shape and validate the value dtype.

        The dtype is passed through the same ``normalize_identity_dtype``
        policy used by :func:`identity`, so direct construction cannot
        advertise a dtype that :meth:`solve` (documented to return
        ``float32``) does not actually produce.
        """
        object.__setattr__(self, "shape", square_shape(self.shape))
        # Previously only the shape was normalized; an arbitrary dtype was
        # stored verbatim and reported as metadata to solver plumbing.
        object.__setattr__(self, "dtype", normalize_identity_dtype(self.dtype))

    @property
    def nnz(self) -> int:
        """Number of effective diagonal entries."""
        return self.shape[0]

    @property
    def setup_device(self) -> str:
        """Device used during setup."""
        return "none"

    @property
    def apply_device(self) -> str:
        """Device used during inverse application."""
        return "none"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Structured metadata describing setup choices and assumptions."""
        return {
            "kind": self.kind,
            "shape": self.shape,
            "is_symmetric": self.is_symmetric,
            "is_positive_definite": self.is_positive_definite,
        }

    def solve(self, x) -> mx.array:
        """Return ``x`` after validating rank, shape, dtype, and finiteness.

        Args:
            x: Right-hand side vector ``(n,)`` or matrix ``(n, nrhs)``.

        Returns:
            ``x`` as a finite ``float32`` MLX array.
        """
        return ensure_rank1_or_rank2_rhs(
            x, leading_dim=self.shape[0], require_finite=True
        )

    def __call__(self, x) -> mx.array:
        """Alias for :meth:`solve`."""
        return self.solve(x)
@dataclass(frozen=True, slots=True)
class DiagonalPreconditioner:
    """Inverse-apply preconditioner defined by an explicit inverse diagonal.

    ``inverse_diagonal`` holds the per-row multipliers applied to a
    right-hand side, i.e. it already stores ``1 / diag(M)``. Application is
    delegated to the native ``diagonal_preconditioner_apply`` primitive, which
    handles both vector and rank-2 RHS inputs.

    Fields: a finite ``float32`` ``inverse_diagonal`` vector, the compatible
    square ``shape``, the stable ``kind`` tag, and symmetry /
    positive-definiteness flags.
    """

    inverse_diagonal: mx.array
    shape: tuple[int, int]
    kind: str = "diagonal"
    is_symmetric: bool = True
    is_positive_definite: bool = False

    def __post_init__(self) -> None:
        """Check the shape and coerce/validate the inverse diagonal vector."""
        normalized_shape = square_shape(self.shape)
        inv = ensure_float32_vector(
            "inverse_diagonal", self.inverse_diagonal, require_finite=True
        )
        length = inv.shape[0]
        expected = normalized_shape[0]
        if length != expected:
            raise ValueError(
                f"inverse_diagonal has length {length}, "
                f"expected {expected}."
            )
        object.__setattr__(self, "shape", normalized_shape)
        object.__setattr__(self, "inverse_diagonal", inv)

    @property
    def dtype(self):
        """Dtype of the stored ``inverse_diagonal`` values."""
        return self.inverse_diagonal.dtype

    @property
    def nnz(self) -> int:
        """Count of stored inverse-diagonal entries."""
        return int(self.inverse_diagonal.shape[0])

    @property
    def setup_device(self) -> str:
        """Where setup validation happens."""
        return "host_validation"

    @property
    def apply_device(self) -> str:
        """Where the inverse application runs."""
        return "native_cpu_or_metal"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Mapping of setup choices and structural assumptions."""
        info = {"kind": self.kind, "shape": self.shape, "nnz": self.nnz}
        info["is_symmetric"] = self.is_symmetric
        info["is_positive_definite"] = self.is_positive_definite
        return info

    def solve(self, x) -> mx.array:
        """Scale a vector or dense RHS matrix row-wise by the inverse diagonal.

        Args:
            x: Right-hand side shaped ``(n,)`` or ``(n, nrhs)``.

        Returns:
            The native result: ``inverse_diagonal * x`` for a vector, or
            ``inverse_diagonal[:, None] * x`` for a matrix.
        """
        validated = ensure_rank1_or_rank2_rhs(
            x, leading_dim=self.shape[0], require_finite=True
        )
        return _native.diagonal_preconditioner_apply(
            self.inverse_diagonal, validated
        )

    def matvec(self, x) -> mx.array:
        """SciPy-style inverse-operator alias for :meth:`solve`."""
        return self.solve(x)

    def __call__(self, x) -> mx.array:
        """Call-syntax alias for :meth:`solve`."""
        return self.solve(x)
@dataclass(frozen=True, slots=True)
class JacobiPreconditioner(DiagonalPreconditioner):
    """Diagonal preconditioner derived from a sparse matrix's diagonal.

    This is a :class:`DiagonalPreconditioner` that additionally remembers the
    parameters :func:`jacobi` used to build it. When handed to
    :func:`mlx_sparse.linalg.cg`, it routes to the native Jacobi-PCG
    primitive.

    Extra fields: ``omega``, ``shift``, ``zero_policy``, ``zero_atol``,
    ``checked`` (whether validation ran), and ``positive_diagonal`` (the
    result of the optional cheap positive shifted-diagonal check, or
    ``None`` when it was not requested).
    """

    kind: str = "jacobi"
    omega: float = 1.0
    shift: float = 0.0
    zero_policy: str = "raise"
    zero_atol: float = 0.0
    checked: bool = False
    positive_diagonal: bool | None = None

    @property
    def setup_device(self) -> str:
        """Where the Jacobi setup ran."""
        return "native_sparse_diagonal"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Mapping of the Jacobi setup parameters and structural flags."""
        info: dict[str, object] = {}
        info["kind"] = self.kind
        info["shape"] = self.shape
        info["nnz"] = self.nnz
        info["omega"] = self.omega
        info["shift"] = self.shift
        info["zero_policy"] = self.zero_policy
        info["zero_atol"] = self.zero_atol
        info["checked"] = self.checked
        info["positive_diagonal"] = self.positive_diagonal
        info["is_symmetric"] = self.is_symmetric
        info["is_positive_definite"] = self.is_positive_definite
        return info
@dataclass(frozen=True, slots=True)
class CallablePreconditioner:
    """Wrapper that turns a Python callable into an inverse-apply object.

    The wrapped ``apply`` is handed a validated rank-1 or rank-2 ``float32``
    MLX array and must return an array of identical shape holding
    ``M^{-1} @ x``. Every application goes through Python, so solver
    integrations should use this wrapper only on documented host fallback
    paths.

    Fields: the ``apply`` callable, the compatible square ``shape``, the
    value ``dtype``, the ``kind`` tag, and conservative symmetry /
    positive-definiteness flags.
    """

    apply: object
    shape: tuple[int, int]
    dtype: object = mx.float32
    kind: str = "callable"
    is_symmetric: bool = False
    is_positive_definite: bool = False

    def __post_init__(self) -> None:
        """Check that ``apply`` is callable, the shape square, dtype float32."""
        if not callable(self.apply):
            raise TypeError("callable preconditioner apply object must be callable.")
        object.__setattr__(self, "shape", square_shape(self.shape))
        if self.dtype != mx.float32:
            raise TypeError("callable preconditioners currently use float32 values.")

    @property
    def nnz(self) -> int:
        """Storage count; ``-1`` because a custom callable is opaque."""
        return -1

    @property
    def setup_device(self) -> str:
        """Where setup happens."""
        return "python_host"

    @property
    def apply_device(self) -> str:
        """Where each inverse application runs."""
        return "python_host"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Mapping describing the callable contract and structural flags."""
        return dict(
            kind=self.kind,
            shape=self.shape,
            assume_inverse=True,
            is_symmetric=self.is_symmetric,
            is_positive_definite=self.is_positive_definite,
        )

    def solve(self, x) -> mx.array:
        """Run the wrapped callable on ``x`` and validate what it returns.

        Args:
            x: Right-hand side shaped ``(n,)`` or ``(n, nrhs)``.

        Returns:
            A finite ``float32`` array whose shape equals that of ``x``.

        Raises:
            ValueError: If the callable's output fails rank/leading-dimension
                or finiteness validation, or its shape differs from ``x``.
        """
        n = self.shape[0]
        rhs = ensure_rank1_or_rank2_rhs(x, leading_dim=n, require_finite=True)
        raw = self.apply(rhs)
        try:
            checked = ensure_rank1_or_rank2_rhs(
                raw, leading_dim=n, require_finite=True
            )
        except ValueError as exc:
            raise ValueError(
                "preconditioner output shape or finite-value validation failed."
            ) from exc
        if checked.shape == rhs.shape:
            return checked
        raise ValueError(
            f"preconditioner output shape {checked.shape} does not match "
            f"input shape {rhs.shape}."
        )

    def __call__(self, x) -> mx.array:
        """Call-syntax alias for :meth:`solve`."""
        return self.solve(x)
@dataclass(frozen=True, slots=True)
class ExactFactorPreconditioner:
    """Exact inverse-apply preconditioner on top of a sparse factorization.

    Bridges existing direct sparse solve objects to the iterative ``M``
    protocol. All factorization work is done before construction; applying
    the preconditioner never refactorizes. :meth:`solve` takes one of these
    routes:

    * explicit native LU/Cholesky apply bindings when
      ``native_apply_kind`` is ``"lu"``/``"cholesky"`` and
      ``native_factorization`` is provided;
    * otherwise the stored reusable ``solver.solve`` (this includes the
      guarded Accelerate solver for Accelerate-factorized objects).

    Fields: the reusable ``solver``, the compatible square ``shape``, the
    factorization ``method``, the implementation ``backend``, conservative
    symmetry / positive-definiteness flags, ``factor_nnz``, and the optional
    native apply metadata. Accelerate-backed solves stay on the Accelerate
    CPU apply boundary; explicit native factors use the native CPU/Metal
    triangular-solve boundary.
    """

    solver: object
    shape: tuple[int, int]
    method: str
    backend: str
    kind: str = "exact"
    is_symmetric: bool = False
    is_positive_definite: bool = False
    factor_nnz: int = -1
    native_apply_kind: str | None = None
    native_factorization: object | None = None

    def __post_init__(self) -> None:
        """Check the solver contract, shape, and factor metadata."""
        solve_method = getattr(self.solver, "solve", None)
        if not callable(solve_method):
            raise TypeError("exact factor preconditioners require solve(x).")
        object.__setattr__(self, "shape", square_shape(self.shape))
        if self.factor_nnz < -1:
            raise ValueError("factor_nnz must be -1 or a non-negative integer.")
        allowed_kinds = {None, "lu", "cholesky", "accelerate"}
        if self.native_apply_kind not in allowed_kinds:
            raise ValueError(
                "native_apply_kind must be None, 'lu', 'cholesky', or 'accelerate'."
            )

    @property
    def dtype(self):
        """Value dtype of the current sparse direct solve backends."""
        return mx.float32

    @property
    def nnz(self) -> int:
        """Factor nonzero count; ``-1`` when the factors are opaque."""
        return int(self.factor_nnz)

    @property
    def setup_device(self) -> str:
        """Where the factorization setup ran."""
        return "accelerate_cpu" if self.backend == "accelerate" else "native_cpu"

    @property
    def apply_device(self) -> str:
        """Where each inverse application runs."""
        if self.backend != "accelerate":
            return "native_cpu_or_metal"
        return "accelerate_cpu"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Mapping describing the wrapped factorization and apply routing."""
        apply_kind = self.native_apply_kind
        return {
            "kind": self.kind,
            "shape": self.shape,
            "method": self.method,
            "backend": self.backend,
            "setup_device": self.setup_device,
            "apply_device": self.apply_device,
            "nnz": self.nnz,
            "is_symmetric": self.is_symmetric,
            "is_positive_definite": self.is_positive_definite,
            "solver_type": type(self.solver).__name__,
            "native_apply_kind": apply_kind,
            "has_native_solver_apply": apply_kind is not None,
        }

    def solve(self, x) -> mx.array:
        """Solve with the exact factors for a vector or dense RHS matrix.

        Args:
            x: Right-hand side shaped ``(n,)`` or ``(n, nrhs)``.

        Returns:
            A finite ``float32`` solution whose shape equals that of ``x``.

        Raises:
            ValueError: If the solve result's shape differs from ``x``.
        """
        rhs = ensure_rank1_or_rank2_rhs(
            x, leading_dim=self.shape[0], require_finite=True
        )
        solved = ensure_rank1_or_rank2_rhs(
            self._native_or_wrapped_solve(rhs),
            leading_dim=self.shape[1],
            require_finite=True,
        )
        if solved.shape == rhs.shape:
            return solved
        raise ValueError(
            f"exact factor preconditioner output shape {solved.shape} does "
            f"not match input shape {rhs.shape}."
        )

    def _native_or_wrapped_solve(self, rhs) -> mx.array:
        """Use explicit native factors if present, else the wrapped solver."""
        factor = self.native_factorization
        if factor is not None:
            if self.native_apply_kind == "lu":
                return _native.csr_exact_lu_preconditioner_apply(
                    factor.perm,
                    factor.L.data,
                    factor.L.indices,
                    factor.L.indptr,
                    factor.U.data,
                    factor.U.indices,
                    factor.U.indptr,
                    rhs,
                    self.shape,
                )
            if self.native_apply_kind == "cholesky":
                # The native Cholesky apply needs an explicit CSR upper factor
                # alongside L; the factor object supplies it via _upper().
                upper = factor._upper()
                return _native.csr_exact_cholesky_preconditioner_apply(
                    factor.L.data,
                    factor.L.indices,
                    factor.L.indptr,
                    upper.data,
                    upper.indices,
                    upper.indptr,
                    rhs,
                    self.shape,
                )
        return self.solver.solve(rhs)

    def matvec(self, x) -> mx.array:
        """Inverse-operator alias for :meth:`solve`."""
        return self.solve(x)

    def __call__(self, x) -> mx.array:
        """Call-syntax alias for :meth:`solve`."""
        return self.solve(x)
@dataclass(frozen=True, slots=True)
class ILU0Preconditioner:
    """Incomplete LU preconditioner with no fill and natural ordering.

    Holds explicit CSR factors ``L`` (unit diagonal) and ``U`` from a native
    CPU ILU(0) setup. The factors keep the lower/upper sparsity pattern of
    the canonical CSR input, with no fill-reducing ordering or pivoting.
    Applying it performs the forward solve ``L y = x`` and then the backward
    solve ``U z = y`` with native CSR triangular-solve kernels.

    Fields: the ``L``/``U`` CSR factors, the diagonal ``shift`` used during
    setup, the validation ``check`` mode, whether triangular analysis is
    cached (``reuse_analysis``), and conservative nonsymmetric metadata.
    """

    L: CSRArray
    U: CSRArray
    shift: float = 0.0
    check: bool = True
    reuse_analysis: bool = False
    kind: str = "ilu0"
    is_symmetric: bool = False
    is_positive_definite: bool = False
    _l_diagonal_positions: mx.array | None = field(
        init=False, default=None, repr=False, compare=False
    )
    _u_diagonal_positions: mx.array | None = field(
        init=False, default=None, repr=False, compare=False
    )
    _l_level_schedule: tuple[mx.array, mx.array] | None = field(
        init=False, default=None, repr=False, compare=False
    )
    _u_level_schedule: tuple[mx.array, mx.array] | None = field(
        init=False, default=None, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        """Validate the factors and, if requested, precompute solve analysis."""
        shape = square_shape(self.L.shape)
        if square_shape(self.U.shape) != shape:
            raise ValueError("ILU0 L and U factors must have matching square shapes.")
        if any(f.data.dtype != mx.float32 for f in (self.L, self.U)):
            raise TypeError("ILU0 factors currently require float32 values.")
        object.__setattr__(self, "shift", finite_scalar("shift", self.shift))
        object.__setattr__(self, "check", bool(self.check))
        object.__setattr__(self, "reuse_analysis", bool(self.reuse_analysis))
        if not self.reuse_analysis:
            return
        # Diagonal positions and level schedules are computed once here so
        # that every later solve() can skip re-analysing the factors.
        cached = {
            "_l_diagonal_positions": _native.csr_triangular_diagonal_positions(
                self.L.indices, self.L.indptr, shape
            ),
            "_u_diagonal_positions": _native.csr_triangular_diagonal_positions(
                self.U.indices, self.U.indptr, shape
            ),
            "_l_level_schedule": _native.csr_triangular_level_schedule(
                self.L.indices, self.L.indptr, shape, lower=True
            ),
            "_u_level_schedule": _native.csr_triangular_level_schedule(
                self.U.indices, self.U.indptr, shape, lower=False
            ),
        }
        for attr_name, attr_value in cached.items():
            object.__setattr__(self, attr_name, attr_value)

    @property
    def shape(self) -> tuple[int, int]:
        """Square shape of the preconditioned operator."""
        return self.L.shape

    @property
    def dtype(self):
        """Value dtype of the stored factors."""
        return mx.float32

    @property
    def nnz_L(self) -> int:
        """Stored entries in the unit-lower factor."""
        return int(self.L.nnz)

    @property
    def nnz_U(self) -> int:
        """Stored entries in the upper factor."""
        return int(self.U.nnz)

    @property
    def nnz(self) -> int:
        """Stored entries across both factors."""
        return self.nnz_L + self.nnz_U

    @property
    def setup_device(self) -> str:
        """Where the ILU(0) setup ran."""
        return "native_cpu"

    @property
    def apply_device(self) -> str:
        """Where each inverse application runs."""
        return "native_cpu_or_metal"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Mapping of ILU(0) setup choices and factor statistics."""
        return dict(
            kind=self.kind,
            shape=self.shape,
            shift=self.shift,
            check=self.check,
            reuse_analysis=self.reuse_analysis,
            ordering="natural",
            fill=0,
            unit_diagonal_L=True,
            nnz_L=self.nnz_L,
            nnz_U=self.nnz_U,
            nnz=self.nnz,
            is_symmetric=self.is_symmetric,
            is_positive_definite=self.is_positive_definite,
        )

    def solve(self, x) -> mx.array:
        """Apply ``U^{-1} L^{-1}`` to a vector or dense RHS matrix.

        Args:
            x: Right-hand side shaped ``(n,)`` or ``(n, nrhs)``.

        Returns:
            The native triangular-solve result ``U^{-1} L^{-1} x`` with the
            same rank and leading dimension as ``x``.
        """
        rhs = ensure_rank1_or_rank2_rhs(
            x, leading_dim=self.shape[0], require_finite=True
        )
        if not self.reuse_analysis:
            # No cached analysis: one fused native call does both solves.
            return _native.csr_ilu0_preconditioner_apply(
                self.L.data,
                self.L.indices,
                self.L.indptr,
                self.U.data,
                self.U.indices,
                self.U.indptr,
                rhs,
                self.shape,
            )
        forward = _native.csr_triangular_solve(
            self.L.data,
            self.L.indices,
            self.L.indptr,
            rhs,
            self.shape,
            lower=True,
            unit_diagonal=True,
            diagonal_positions=self._l_diagonal_positions,
            level_schedule=self._l_level_schedule,
        )
        return _native.csr_triangular_solve(
            self.U.data,
            self.U.indices,
            self.U.indptr,
            forward,
            self.shape,
            lower=False,
            unit_diagonal=False,
            diagonal_positions=self._u_diagonal_positions,
            level_schedule=self._u_level_schedule,
        )

    def matvec(self, x) -> mx.array:
        """Inverse-operator alias for :meth:`solve`."""
        return self.solve(x)

    def __call__(self, x) -> mx.array:
        """Call-syntax alias for :meth:`solve`."""
        return self.solve(x)
@dataclass(frozen=True, slots=True)
class IC0Preconditioner:
    """Incomplete Cholesky preconditioner with no fill and natural ordering.

    Holds an explicit lower CSR factor ``L`` from a native CPU IC(0) setup.
    The factor keeps the symmetric lower sparsity pattern of the canonical
    CSR input and adds no fill. Applying it runs two native triangular
    solves: ``L y = x`` and then ``L.T z = y``.

    Fields: the lower factor, the non-negative diagonal ``shift`` used during
    setup, the strict positive-pivot ``check`` mode, and SPD metadata that
    makes it suitable for CG and MINRES-style solvers.
    """

    L: CSRArray
    shift: float = 0.0
    check: bool = True
    kind: str = "ichol0"
    is_symmetric: bool = True
    is_positive_definite: bool = True
    _upper_factor: CSRArray | None = field(
        init=False, default=None, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        """Validate the factor and enforce a non-negative shift."""
        square_shape(self.L.shape)
        if self.L.data.dtype != mx.float32:
            raise TypeError("IC0 factors currently require float32 values.")
        shift = finite_scalar("shift", self.shift)
        if shift < 0.0:
            raise ValueError("shift must be non-negative for IC0.")
        object.__setattr__(self, "shift", shift)
        object.__setattr__(self, "check", bool(self.check))

    @property
    def shape(self) -> tuple[int, int]:
        """Square shape of the preconditioned operator."""
        return self.L.shape

    @property
    def dtype(self):
        """Value dtype of the stored factor."""
        return mx.float32

    @property
    def nnz_L(self) -> int:
        """Stored entries in the lower factor."""
        return int(self.L.nnz)

    @property
    def nnz(self) -> int:
        """Stored factor entries (only ``L`` is stored)."""
        return self.nnz_L

    @property
    def setup_device(self) -> str:
        """Where the IC(0) setup ran."""
        return "native_cpu"

    @property
    def apply_device(self) -> str:
        """Where each inverse application runs."""
        return "native_cpu_or_metal"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Mapping of IC(0) setup choices and factor statistics."""
        return dict(
            kind=self.kind,
            shape=self.shape,
            shift=self.shift,
            check=self.check,
            ordering="natural",
            fill=0,
            factor="lower",
            nnz_L=self.nnz_L,
            nnz=self.nnz,
            is_symmetric=self.is_symmetric,
            is_positive_definite=self.is_positive_definite,
        )

    def _upper(self) -> CSRArray:
        """Return ``L.T``, computing it on first use and caching it after."""
        if self._upper_factor is not None:
            return self._upper_factor
        transposed = self.L.T
        # The dataclass is frozen, so the lazy cache bypasses __setattr__.
        object.__setattr__(self, "_upper_factor", transposed)
        return transposed

    def solve(self, x) -> mx.array:
        """Apply ``L.T^{-1} L^{-1}`` to a vector or dense RHS matrix.

        Args:
            x: Right-hand side shaped ``(n,)`` or ``(n, nrhs)``.

        Returns:
            The native triangular-solve result ``L.T^{-1} L^{-1} x`` with the
            same rank and leading dimension as ``x``.
        """
        rhs = ensure_rank1_or_rank2_rhs(
            x, leading_dim=self.shape[0], require_finite=True
        )
        lower = self.L
        upper = self._upper()
        return _native.csr_ic0_preconditioner_apply(
            lower.data,
            lower.indices,
            lower.indptr,
            upper.data,
            upper.indices,
            upper.indptr,
            rhs,
            self.shape,
        )

    def matvec(self, x) -> mx.array:
        """Inverse-operator alias for :meth:`solve`."""
        return self.solve(x)

    def __call__(self, x) -> mx.array:
        """Call-syntax alias for :meth:`solve`."""
        return self.solve(x)
@dataclass(frozen=True, slots=True)
class ChebyshevPreconditioner:
    """Polynomial inverse preconditioner for SPD sparse matrices.

    ``ChebyshevPreconditioner`` stores a canonical ``float32`` CSR operator
    and applies a fixed-degree first-kind Chebyshev semi-iteration with zero
    initial guess. Application uses only sparse matrix-vector/matrix products
    and vector updates, so it follows the selected native CPU or Metal device
    path without triangular solves or dense factorization.

    Stored fields include the CSR operator, polynomial ``degree``, validated
    positive spectral interval ``[lambda_min, lambda_max]``, whether setup
    used native spectral ``estimate`` data, and detailed ``spectral_info``
    metadata (stored as a private ``dict`` copy so later mutation of the
    caller's mapping cannot change this frozen object).
    """

    A: CSRArray
    degree: int = 2
    lambda_min: float = 0.0
    lambda_max: float = 0.0
    estimate: bool = True
    spectral_info: Mapping[str, object] = field(default_factory=dict)
    kind: str = "chebyshev"
    is_symmetric: bool = True
    is_positive_definite: bool = True

    def __post_init__(self) -> None:
        """Validate polynomial metadata and CSR storage.

        Raises:
            TypeError: If the operator values are not ``float32``.
            ValueError: If ``degree`` is non-integral or not positive, the
                spectral interval is not ``0 < lambda_min < lambda_max``, or
                the operator shape is not square.
        """
        shape = square_shape(self.A.shape)
        if self.A.data.dtype != mx.float32:
            raise TypeError("Chebyshev preconditioners require float32 CSR values.")
        degree_value = int(self.degree)
        # int() truncates silently: degree=2.5 would quietly become a degree-2
        # polynomial. Reject non-integral real inputs instead.
        if isinstance(self.degree, float) and degree_value != self.degree:
            raise ValueError(f"degree must be an integer, got {self.degree!r}.")
        if degree_value <= 0:
            raise ValueError("degree must be positive.")
        lambda_min_value = finite_scalar("lambda_min", self.lambda_min)
        lambda_max_value = finite_scalar("lambda_max", self.lambda_max)
        if lambda_min_value <= 0.0 or lambda_max_value <= lambda_min_value:
            raise ValueError(
                "Chebyshev spectral interval must satisfy "
                "0 < lambda_min < lambda_max."
            )
        # Finish validating before any field is rewritten.
        if shape != self.A.shape:
            raise ValueError("Chebyshev operator shape must be square.")
        object.__setattr__(self, "degree", degree_value)
        object.__setattr__(self, "lambda_min", lambda_min_value)
        object.__setattr__(self, "lambda_max", lambda_max_value)
        object.__setattr__(self, "estimate", bool(self.estimate))
        # Defensive copy: the object is frozen, but a caller-owned mapping
        # could otherwise be mutated after construction.
        object.__setattr__(self, "spectral_info", dict(self.spectral_info))

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the preconditioned square operator."""
        return self.A.shape

    @property
    def dtype(self):
        """Value dtype used by the stored operator."""
        return mx.float32

    @property
    def nnz(self) -> int:
        """Number of sparse operator entries used during polynomial apply."""
        return int(self.A.nnz)

    @property
    def setup_device(self) -> str:
        """Device category used during spectral setup."""
        return "native_cpu"

    @property
    def apply_device(self) -> str:
        """Device category used during inverse application."""
        return "native_cpu_or_metal"

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Structured metadata describing Chebyshev setup choices."""
        return {
            "kind": self.kind,
            "shape": self.shape,
            "degree": self.degree,
            "lambda_min": self.lambda_min,
            "lambda_max": self.lambda_max,
            "estimate": self.estimate,
            "nnz": self.nnz,
            "is_symmetric": self.is_symmetric,
            "is_positive_definite": self.is_positive_definite,
            "spectral_info": dict(self.spectral_info),
        }

    def solve(self, x) -> mx.array:
        """Apply the Chebyshev polynomial inverse to a vector or matrix RHS.

        Args:
            x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.

        Returns:
            Native Chebyshev semi-iteration result with the same shape as ``x``.
        """
        rhs = ensure_rank1_or_rank2_rhs(
            x, leading_dim=self.shape[0], require_finite=True
        )
        return _native.csr_chebyshev_preconditioner_apply(
            self.A.data,
            self.A.indices,
            self.A.indptr,
            rhs,
            self.shape,
            degree=self.degree,
            lambda_min=self.lambda_min,
            lambda_max=self.lambda_max,
        )

    def matvec(self, x) -> mx.array:
        """Alias for :meth:`solve` for inverse-operator composition."""
        return self.solve(x)

    def __call__(self, x) -> mx.array:
        """Alias for :meth:`solve`."""
        return self.solve(x)
def identity(A_or_shape, *, dtype=None) -> IdentityPreconditioner:
    """Build a no-op (identity) preconditioner.

    Args:
        A_or_shape: A square sparse matrix, an ``(n, n)`` shape tuple, or an
            integer dimension; it is normalized through ``square_shape``.
        dtype: Optional dtype, normalized by ``normalize_identity_dtype``. The
            current native solver integration accepts only ``None`` or
            ``mlx.core.float32``.

    Returns:
        An :class:`IdentityPreconditioner` for the resolved square shape.
    """
    pc_shape = square_shape(A_or_shape)
    pc_dtype = normalize_identity_dtype(dtype)
    return IdentityPreconditioner(shape=pc_shape, dtype=pc_dtype)
def diagonal(
    inv_diag_or_diag,
    *,
    inverse: bool = False,
    shape=None,
    dtype=None,
    zero_atol: float = 0.0,
) -> DiagonalPreconditioner:
    """Create an explicit diagonal inverse-apply preconditioner.

    Args:
        inv_diag_or_diag: Rank-1 diagonal values. Interpreted as a diagonal by
            default, or as an inverse diagonal when ``inverse=True``.
        inverse: If ``True``, use ``inv_diag_or_diag`` directly as the inverse
            diagonal. If ``False``, validate and invert it.
        shape: Optional square shape. Defaults to ``(n, n)`` where ``n`` is
            the vector length.
        dtype: Optional dtype. The current native preconditioner path accepts
            only ``None`` or ``mlx.core.float32``.
        zero_atol: Absolute threshold used when rejecting zero diagonal
            entries before inversion.

    Returns:
        A :class:`DiagonalPreconditioner` with finite ``float32`` inverse
        diagonal storage.

    Raises:
        TypeError: If ``dtype`` is not ``None`` or ``float32``.
        ValueError: If the length does not match ``shape``, ``zero_atol`` is
            negative, an entry is within ``zero_atol`` of zero, or inverting
            a tiny entry overflows ``float32``.
    """
    values = ensure_float32_vector("diagonal", inv_diag_or_diag, require_finite=True)
    if dtype is not None and dtype != mx.float32:
        raise TypeError("diagonal preconditioners currently use float32 values.")
    pc_shape = (
        square_shape((values.shape[0], values.shape[0]))
        if shape is None
        else square_shape(shape)
    )
    if values.shape[0] != pc_shape[0]:
        raise ValueError(
            f"diagonal has length {values.shape[0]}, expected {pc_shape[0]}."
        )
    if inverse:
        inv_diag = values
    else:
        atol = float(zero_atol)
        if atol < 0.0:
            raise ValueError("zero_atol must be non-negative.")
        if host_bool(mx.any(mx.abs(values) <= atol)):
            raise ValueError("diagonal contains zero or near-zero entries.")
        inv_diag = 1.0 / values
        # A subnormal float32 entry passes the `<= atol` test when atol is 0,
        # but its reciprocal overflows to inf. Reject it so the stored
        # inverse diagonal is finite, as documented.
        if not host_bool(mx.all(mx.isfinite(inv_diag))):
            raise ValueError(
                "diagonal contains entries too small to invert in float32."
            )
    return DiagonalPreconditioner(inv_diag, pc_shape)
def jacobi(
    A,
    *,
    omega: float = 1.0,
    shift: float = 0.0,
    zero_policy: str = "raise",
    zero_atol: float = 0.0,
    check: bool = False,
) -> JacobiPreconditioner:
    """Create a Jacobi preconditioner from a sparse matrix diagonal.

    The inverse diagonal is computed as ``omega / (diag(A) + shift)`` after
    normalizing ``A`` to canonical CSR so duplicate diagonal entries are
    summed. The input sparse matrix is never mutated.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator``.
        omega: Damping/weighting factor.
        shift: Explicit diagonal shift applied before inversion.
        zero_policy: ``"raise"`` rejects zero/near-zero shifted diagonals.
            ``"unit"`` replaces those entries with ``1`` before inversion.
        zero_atol: Absolute threshold used to identify near-zero shifted
            diagonal entries.
        check: When ``True``, require ``omega > 0`` and a strictly positive
            shifted diagonal before any ``zero_policy`` replacement, then mark
            the preconditioner as positive definite.

    Returns:
        A :class:`JacobiPreconditioner` suitable for native PCG.

    Raises:
        ValueError: If ``zero_policy`` is unknown, ``A`` is not square, the
            shifted diagonal is non-finite or (with ``zero_policy="raise"``)
            near zero, ``zero_atol`` is negative, ``check=True`` validation
            fails, or the ``float32`` inverse diagonal is not finite.
    """
    if zero_policy not in {"raise", "unit"}:
        raise ValueError("zero_policy must be 'raise' or 'unit'.")
    omega_value = finite_scalar("omega", omega)
    shift_value = finite_scalar("shift", shift)
    checked = bool(check)
    if checked and omega_value <= 0.0:
        raise ValueError("omega must be positive when check=True.")
    csr = canonical_csr(
        A,
        context="jacobi",
        dense_guidance="",
        allow_sparse_linear_operator=True,
    )
    if csr.shape[0] != csr.shape[1]:
        raise ValueError(f"jacobi requires a square matrix, got {csr.shape}.")
    diag = ensure_float32_vector("diagonal", csr.diagonal())
    shifted = diag + mx.array(shift_value, dtype=mx.float32)
    if not host_bool(mx.all(mx.isfinite(shifted))):
        raise ValueError("shifted diagonal must contain only finite values.")
    atol = float(zero_atol)
    if atol < 0.0:
        raise ValueError("zero_atol must be non-negative.")
    near_zero = mx.abs(shifted) <= atol
    # Positivity for check=True is judged on the raw shifted diagonal, i.e.
    # before any "unit" replacement below.
    positive_shifted_diagonal = host_bool(mx.all(shifted > atol)) if checked else None
    if host_bool(mx.any(near_zero)):
        if zero_policy == "raise":
            raise ValueError(
                "Jacobi shifted diagonal contains zero or near-zero entries."
            )
        shifted = mx.where(near_zero, mx.ones_like(shifted), shifted)
    positive_diagonal = None
    is_positive_definite = False
    if checked:
        positive_diagonal = positive_shifted_diagonal
        if not positive_diagonal:
            raise ValueError(
                "Jacobi shifted diagonal must be strictly positive when check=True."
            )
        is_positive_definite = True
    inv_diag = mx.array(omega_value, dtype=mx.float32) / shifted
    # finite_scalar validates omega as a Python float, but it can still
    # overflow on float32 conversion. Dividing by a subnormal shifted entry
    # (possible with zero_atol=0) can overflow too. Reject rather than hand
    # an inf/nan inverse diagonal to PCG.
    if not host_bool(mx.all(mx.isfinite(inv_diag))):
        raise ValueError("Jacobi inverse diagonal must contain only finite values.")
    return JacobiPreconditioner(
        inv_diag,
        csr.shape,
        is_positive_definite=is_positive_definite,
        omega=omega_value,
        shift=shift_value,
        zero_policy=zero_policy,
        zero_atol=atol,
        checked=checked,
        positive_diagonal=positive_diagonal,
    )
def ilu0(
    A,
    *,
    shift: float = 0.0,
    check: bool = True,
    reuse_analysis: bool = False,
) -> ILU0Preconditioner:
    """Build a natural-order, zero-fill ILU(0) preconditioner for ``A``.

    ``A`` is normalized to canonical CSR (real low-precision values promoted
    to ``float32``) and handed to the native CPU ILU(0) kernel. The kernel uses
    no fill-reducing reordering and no pivoting. The caller's sparse matrix
    is left untouched. ``shift`` is added only to diagonal entries that are
    already stored; a structurally missing diagonal is still an error.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator`` describing a square nonsingular matrix.
        shift: Explicit diagonal shift added before factorization.
        check: ``True`` (the default) enables the native scale-aware
            near-zero pivot guard. ``False`` rejects only exact-zero and
            non-finite pivots during setup.
        reuse_analysis: ``True`` caches triangular diagonal-position and
            level-schedule analysis for repeated explicit ``M(rhs)`` calls.
            Off by default because v0.0.4b1 triangular-analysis benchmarks
            showed it must be workload-measured before enabling.

    Returns:
        An :class:`ILU0Preconditioner` whose application uses native CSR
        triangular solves.

    Raises:
        ValueError: If the normalized matrix is not square.
    """
    shift_value = finite_scalar("shift", shift)
    guidance = (
        "Dense MLX arrays belong in mlx.linalg, not "
        "mlx_sparse.linalg.preconditioners."
    )
    canonical = canonical_csr(
        A,
        context="ILU(0)",
        dense_guidance=guidance,
        allow_sparse_linear_operator=True,
    )
    csr = ensure_float32_csr(canonical, context="ILU(0)")
    if csr.shape[0] != csr.shape[1]:
        raise ValueError(f"ilu0 requires a square matrix, got {csr.shape}.")

    check_flag = bool(check)
    (
        lower_vals,
        lower_cols,
        lower_ptr,
        upper_vals,
        upper_cols,
        upper_ptr,
    ) = _native.csr_ilu0(
        csr.data,
        csr.indices,
        csr.indptr,
        csr.shape,
        shift=shift_value,
        check=check_flag,
    )

    def as_factor(values, columns, pointers) -> CSRArray:
        # Wrap one native triangular factor; it is flagged sorted/canonical.
        return CSRArray(
            data=values,
            indices=columns,
            indptr=pointers,
            shape=csr.shape,
            sorted_indices=True,
            has_canonical_format=True,
        )

    return ILU0Preconditioner(
        L=as_factor(lower_vals, lower_cols, lower_ptr),
        U=as_factor(upper_vals, upper_cols, upper_ptr),
        shift=shift_value,
        check=check_flag,
        reuse_analysis=bool(reuse_analysis),
    )
def ichol0(A, *, shift: float = 0.0, check: bool = True) -> IC0Preconditioner:
    """Build a natural-order, zero-fill IC(0) preconditioner for SPD ``A``.

    ``A`` is normalized to canonical CSR (real low-precision values promoted
    to ``float32``) and factored by the native CPU incomplete Cholesky kernel
    with natural ordering and no fill. The caller's matrix is not modified.
    The non-negative ``shift`` is added only to diagonal entries that already
    exist. A missing diagonal is still an error, and no pivot is silently
    perturbed.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator`` describing an SPD square matrix.
        shift: Explicit non-negative diagonal shift added before
            factorization. Defaults to ``0.0``.
        check: ``True`` (the default) makes native setup enforce symmetry of
            explicitly stored mirrored entries and apply a scale-aware
            positive pivot guard. ``False`` still rejects non-positive and
            non-finite pivots but relaxes the near-zero pivot and symmetry
            checks.

    Returns:
        An :class:`IC0Preconditioner` whose application uses native CSR
        triangular solves with ``L`` and ``L.T``.

    Raises:
        ValueError: If ``shift`` is negative or the normalized matrix is not
            square.
    """
    shift_value = finite_scalar("shift", shift)
    if shift_value < 0.0:
        raise ValueError("shift must be non-negative for IC0.")
    guidance = (
        "Dense MLX arrays belong in mlx.linalg, not "
        "mlx_sparse.linalg.preconditioners."
    )
    canonical = canonical_csr(
        A,
        context="IC(0)",
        dense_guidance=guidance,
        allow_sparse_linear_operator=True,
    )
    csr = ensure_float32_csr(canonical, context="IC(0)")
    if csr.shape[0] != csr.shape[1]:
        raise ValueError(f"ichol0 requires a square matrix, got {csr.shape}.")

    check_flag = bool(check)
    factor_vals, factor_cols, factor_ptr = _native.csr_ic0(
        csr.data,
        csr.indices,
        csr.indptr,
        csr.shape,
        shift=shift_value,
        check=check_flag,
    )
    lower = CSRArray(
        data=factor_vals,
        indices=factor_cols,
        indptr=factor_ptr,
        shape=csr.shape,
        sorted_indices=True,
        has_canonical_format=True,
    )
    return IC0Preconditioner(L=lower, shift=shift_value, check=check_flag)
def chebyshev(
    A,
    *,
    degree: int = 2,
    lambda_min: float | None = None,
    lambda_max: float | None = None,
    estimate: bool = True,
) -> ChebyshevPreconditioner:
    """Create a GPU-friendly Chebyshev polynomial preconditioner.

    The setup normalizes ``A`` to canonical CSR, promotes real low-precision
    values to ``float32``, and computes native Gershgorin spectral bounds plus
    optional native Lanczos Ritz estimates. The resulting preconditioner
    applies a fixed-degree first-kind Chebyshev semi-iteration with zero
    initial guess, using only sparse matrix products and vector updates.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator`` describing a real SPD square matrix.
        degree: Positive polynomial degree. Defaults to ``2``.
        lambda_min: Optional positive lower spectral bound. If omitted, setup
            uses a positive Gershgorin lower bound when available, otherwise a
            conservative Lanczos-derived lower estimate when ``estimate=True``.
            That estimate is ``min(0.5 * ritz_min, 0.1 * lambda_max)``,
            floored at ``16 * eps(float32) * min(1, lambda_max)``. The floor
            scales with small operators instead of overtaking ``lambda_max``.
        lambda_max: Optional upper spectral bound. If omitted, setup uses the
            Gershgorin upper bound when available, with Lanczos metadata
            recorded for diagnostics.
        estimate: Whether native Lanczos Ritz estimates should be computed as a
            fallback/refinement for the spectral interval. Defaults to ``True``.

    Returns:
        A :class:`ChebyshevPreconditioner` with native CPU/Metal apply support.

    Raises:
        ValueError: If ``degree`` is not positive, ``A`` is not square, or no
            valid positive interval can be established.
    """
    degree_value = int(degree)
    if degree_value <= 0:
        raise ValueError("degree must be positive.")
    csr = ensure_float32_csr(
        canonical_csr(
            A,
            context="Chebyshev",
            dense_guidance="Dense MLX arrays belong in mlx.linalg, not "
            "mlx_sparse.linalg.preconditioners.",
            allow_sparse_linear_operator=True,
        ),
        context="Chebyshev",
    )
    if csr.shape[0] != csr.shape[1]:
        raise ValueError(f"chebyshev requires a square matrix, got {csr.shape}.")
    (
        gershgorin_min,
        gershgorin_max,
        ritz_min,
        ritz_max,
        diagonal_min,
        diagonal_max,
        estimate_steps,
    ) = _native.csr_chebyshev_spectral_bounds(
        csr.data,
        csr.indices,
        csr.indptr,
        csr.shape,
        estimate=bool(estimate),
        estimate_steps=0,
    )

    def positive_finite(value: float) -> bool:
        return math.isfinite(float(value)) and float(value) > 0.0

    if lambda_max is None:
        if positive_finite(gershgorin_max):
            lambda_max_value = float(gershgorin_max)
            lambda_max_source = "gershgorin"
        elif bool(estimate) and positive_finite(ritz_max):
            # Inflate the Ritz value: Lanczos underestimates lambda_max.
            lambda_max_value = 1.1 * float(ritz_max)
            lambda_max_source = "lanczos_1.1"
        else:
            raise ValueError(
                "could not determine a positive Chebyshev upper spectral "
                "bound; pass lambda_max explicitly."
            )
    else:
        lambda_max_value = finite_scalar("lambda_max", lambda_max)
        lambda_max_source = "explicit"

    # Floor for the Lanczos-derived lower bound: 16 float32 ulps, relative to
    # lambda_max when lambda_max < 1. The previous absolute floor (~1.9e-6)
    # exceeded lambda_max for small-scale SPD operators, which made setup fail
    # even though a valid interval exists. For lambda_max >= 1 it is unchanged.
    float32_eps = 1.1920928955078125e-7
    tiny = 16.0 * float32_eps * min(1.0, lambda_max_value)
    if lambda_min is None:
        if positive_finite(gershgorin_min):
            lambda_min_value = float(gershgorin_min)
            lambda_min_source = "gershgorin"
        elif bool(estimate) and positive_finite(ritz_min):
            lambda_min_value = max(
                tiny, min(0.5 * float(ritz_min), 0.1 * lambda_max_value)
            )
            lambda_min_source = "lanczos_0.5"
        else:
            raise ValueError(
                "could not determine a positive Chebyshev lower spectral "
                "bound; pass lambda_min explicitly or use estimate=True for "
                "SPD matrices whose Gershgorin lower bound is non-positive."
            )
    else:
        lambda_min_value = finite_scalar("lambda_min", lambda_min)
        lambda_min_source = "explicit"

    if (
        not math.isfinite(lambda_min_value)
        or not math.isfinite(lambda_max_value)
        or lambda_min_value <= 0.0
        or lambda_max_value <= lambda_min_value
    ):
        raise ValueError(
            "Chebyshev spectral interval must satisfy 0 < lambda_min < lambda_max."
        )
    spectral_info = {
        "gershgorin_min": float(gershgorin_min),
        "gershgorin_max": float(gershgorin_max),
        "ritz_min": float(ritz_min),
        "ritz_max": float(ritz_max),
        "diagonal_min": float(diagonal_min),
        "diagonal_max": float(diagonal_max),
        "estimate_steps": int(estimate_steps),
        "lambda_min_source": lambda_min_source,
        "lambda_max_source": lambda_max_source,
    }
    return ChebyshevPreconditioner(
        A=csr,
        degree=degree_value,
        lambda_min=lambda_min_value,
        lambda_max=lambda_max_value,
        estimate=bool(estimate),
        spectral_info=spectral_info,
    )
def from_factorized(solver) -> ExactFactorPreconditioner:
    """Wrap an existing sparse factorization as an exact preconditioner.

    Native ``SparseCholesky`` and ``SparseLU`` objects serve directly as their
    own native apply path. A ``FactorizedSolve`` is unwrapped when its
    ``_solver`` is a ``NativeFactorizedSolve`` holding an LU or Cholesky
    factorization. It is also unwrapped when its backend is ``"accelerate"``
    and ``_native.is_accelerate_float_solve`` accepts the wrapped solver.
    Otherwise no native apply kind is recorded and ``factor_nnz`` is ``-1``.

    Args:
        solver: A :class:`~mlx_sparse.linalg.FactorizedSolve`,
            :class:`~mlx_sparse.linalg.SparseLU`, or
            :class:`~mlx_sparse.linalg.SparseCholesky` instance.

    Returns:
        An :class:`ExactFactorPreconditioner` whose inverse application uses a
        native exact-apply path when the factorization exposes one.

    Raises:
        TypeError: If ``solver`` is not one of the supported factorization
            objects.
        ValueError: If the factorization does not represent a square operator.
    """
    from mlx_sparse.linalg._factorizations import (
        FactorizedSolve,
        SparseCholesky,
        SparseLU,
    )
    from mlx_sparse.linalg.utils.factorization import NativeFactorizedSolve

    # Cholesky is tested before LU, matching the original dispatch order.
    if isinstance(solver, SparseCholesky):
        return ExactFactorPreconditioner(
            solver=solver,
            shape=square_shape(solver.shape),
            method="cholesky",
            backend="native",
            is_symmetric=True,
            is_positive_definite=True,
            factor_nnz=int(solver.L.nnz),
            native_apply_kind="cholesky",
            native_factorization=solver,
        )
    if isinstance(solver, SparseLU):
        return ExactFactorPreconditioner(
            solver=solver,
            shape=square_shape(solver.shape),
            method="lu",
            backend="native",
            is_symmetric=False,
            is_positive_definite=False,
            factor_nnz=int(solver.L.nnz + solver.U.nnz),
            native_apply_kind="lu",
            native_factorization=solver,
        )
    if not isinstance(solver, FactorizedSolve):
        raise TypeError(
            "from_factorized expects FactorizedSolve, SparseLU, or SparseCholesky."
        )

    op_shape = square_shape(solver.shape)
    if int(solver.rhs_size) != op_shape[0] or int(solver.solution_size) != op_shape[1]:
        raise ValueError(
            "exact factor preconditioners require matching RHS and "
            "solution dimensions."
        )
    solve_method = str(solver.method)
    apply_kind = None
    apply_factorization = None
    stored_nnz = -1
    inner = getattr(solver, "_solver", None)
    if isinstance(inner, NativeFactorizedSolve):
        factor = inner.factorization
        if isinstance(factor, SparseLU):
            apply_kind, apply_factorization = "lu", factor
            stored_nnz = int(factor.L.nnz + factor.U.nnz)
        elif isinstance(factor, SparseCholesky):
            apply_kind, apply_factorization = "cholesky", factor
            stored_nnz = int(factor.L.nnz)
    elif str(solver.backend) == "accelerate" and _native.is_accelerate_float_solve(
        inner
    ):
        apply_kind, apply_factorization = "accelerate", inner

    return ExactFactorPreconditioner(
        solver=solver,
        shape=op_shape,
        method=solve_method,
        backend=str(solver.backend),
        is_symmetric=solve_method in {"cholesky", "ldlt"},
        is_positive_definite=solve_method == "cholesky",
        factor_nnz=stored_nnz,
        native_apply_kind=apply_kind,
        native_factorization=apply_factorization,
    )
def exact(A, *, method: str = "auto") -> ExactFactorPreconditioner:
    """Factorize ``A`` once and wrap the result as an exact preconditioner.

    This is a thin convenience layer over
    :func:`mlx_sparse.linalg.factorized` followed by :func:`from_factorized`.
    It is meant as a correctness baseline, diagnostic tool, and composition
    point for existing direct solvers rather than a performance headline.

    Args:
        A: Sparse coefficient matrix accepted by
            :func:`mlx_sparse.linalg.factorized`.
        method: Direct factorization method. Defaults to ``"auto"``.

    Returns:
        An :class:`ExactFactorPreconditioner` wrapping the reusable factorized
        solve object.
    """
    from mlx_sparse.linalg._factorizations import factorized

    reusable_solve = factorized(A, method=method)
    return from_factorized(reusable_solve)
def aspreconditioner(M, A=None, *, assume_inverse: bool = True) -> Preconditioner:
    """Normalize supported preconditioner-like objects.

    Args:
        M: ``None``, an existing preconditioner, an object with ``solve(x)``,
            or a callable. Sparse matrices are rejected because they do not
            explicitly define an inverse-apply contract.
        A: Optional reference matrix or shape used to validate compatibility.
        assume_inverse: Must be ``True`` for callables and custom objects,
            documenting that their output is already an inverse/preconditioner
            application.

    Returns:
        A supported preconditioner object. Built-in preconditioners (including
        :class:`JacobiPreconditioner`) are returned unchanged; factorizations
        are wrapped via :func:`from_factorized`; custom objects and callables
        are wrapped in :class:`CallablePreconditioner`.

    Raises:
        ValueError: If ``A`` is required but missing (``M`` is ``None``, a
            bare callable, or a ``solve`` object without ``shape``), or the
            preconditioner shape mismatches ``A``.
        TypeError: If ``M`` is a sparse matrix or unsupported object.
    """
    if M is None:
        if A is None:
            raise ValueError("A is required when M is None.")
        return identity(A)
    # JacobiPreconditioner is listed explicitly so it is passed through with
    # its kind/metadata intact instead of falling to the generic solve() path.
    if isinstance(
        M,
        (
            IdentityPreconditioner,
            DiagonalPreconditioner,
            JacobiPreconditioner,
            CallablePreconditioner,
            ChebyshevPreconditioner,
            ExactFactorPreconditioner,
            IC0Preconditioner,
            ILU0Preconditioner,
        ),
    ):
        if A is not None and M.shape != square_shape(A):
            raise ValueError(f"preconditioner shape {M.shape} does not match A.shape.")
        return M
    if isinstance(M, (CSRArray, COOArray, CSCArray)):
        raise TypeError(
            "sparse matrices are not inverse-apply preconditioners; use "
            "preconditioners.jacobi(A), preconditioners.ilu0(A), "
            "preconditioners.ichol0(A), or preconditioners.diagonal(...)."
        )
    from mlx_sparse.linalg._factorizations import (
        FactorizedSolve,
        SparseCholesky,
        SparseLU,
    )

    if isinstance(M, (FactorizedSolve, SparseLU, SparseCholesky)):
        pc = from_factorized(M)
        if A is not None and pc.shape != square_shape(A):
            raise ValueError(f"preconditioner shape {pc.shape} does not match A.shape.")
        return pc
    if hasattr(M, "solve") and callable(M.solve):
        if not assume_inverse:
            raise TypeError("custom preconditioner objects must apply the inverse.")
        if A is not None:
            shape = square_shape(A)
        elif hasattr(M, "shape"):
            shape = square_shape(M.shape)
        else:
            # Without A or M.shape there is nothing to size the wrapper from.
            raise ValueError("A is required when M has no shape attribute.")
        if hasattr(M, "shape") and square_shape(M.shape) != shape:
            raise ValueError(
                f"preconditioner shape {square_shape(M.shape)} does not match "
                f"A.shape {shape}."
            )
        return CallablePreconditioner(M.solve, shape)
    if callable(M):
        if not assume_inverse:
            raise TypeError("callable preconditioners must apply the inverse.")
        if A is None:
            raise ValueError("A is required when M is a callable.")
        return CallablePreconditioner(M, square_shape(A))
    raise TypeError(
        "M must be None, a supported preconditioner, an inverse-apply object "
        "with solve(x), or an inverse-apply callable."
    )
# Public API of this module, kept in sorted order (classes, then factory
# functions) so additions and omissions are easy to spot in review.
__all__ = [
    "CallablePreconditioner",
    "ChebyshevPreconditioner",
    "DiagonalPreconditioner",
    "ExactFactorPreconditioner",
    "IC0Preconditioner",
    "ILU0Preconditioner",
    "IdentityPreconditioner",
    "JacobiPreconditioner",
    "Preconditioner",
    "aspreconditioner",
    "chebyshev",
    "diagonal",
    "exact",
    "from_factorized",
    "ichol0",
    "identity",
    "ilu0",
    "jacobi",
]