# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Native-backed sparse solver preconditioners.
The Python objects in this module are containers and dispatch helpers.
Application and Krylov iteration dispatch to native mlx-sparse primitives rather
than Python solver loops. Constructors may use existing sparse native kernels
and MLX scalar array expressions to build immutable setup data.
"""
from __future__ import annotations
import math
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Protocol, runtime_checkable
import mlx.core as mx
import mlx_sparse._native as _native
from mlx_sparse._coo import COOArray
from mlx_sparse._csc import CSCArray
from mlx_sparse._csr import CSRArray
from mlx_sparse.linalg.utils.arrays import (
ensure_float32_csr,
ensure_float32_vector,
ensure_rank1_or_rank2_rhs,
finite_scalar,
host_bool,
)
from mlx_sparse.linalg.utils.preconditioners import normalize_identity_dtype
from mlx_sparse.linalg.utils.sparse import canonical_csr, square_shape
[docs]
@runtime_checkable
class Preconditioner(Protocol):
    """Structural interface for objects that apply an approximate inverse.

    A solver-facing preconditioner stands for the action ``M^{-1} @ x`` and
    not for the matrix ``M`` itself. Conforming objects carry the attributes
    declared below and provide :meth:`solve` together with a ``__call__``
    alias. Both entry points must accept rank-1 vector RHS and rank-2 dense
    RHS matrices, and must not mutate the matrix used during setup.

    Attributes:
        shape: Square shape of the preconditioned operator.
        dtype: Value dtype.
        kind: Stable identifier string.
        is_symmetric: Symmetry metadata.
        is_positive_definite: Positive-definiteness metadata.
        setup_device: Descriptor of the device used during setup.
        apply_device: Descriptor of the device used during application.
        nnz: Effective storage entry count.
        setup_info: Structured mapping describing the setup.
    """

    shape: tuple[int, int]
    dtype: object
    kind: str
    is_symmetric: bool
    is_positive_definite: bool
    setup_device: str
    apply_device: str
    nnz: int
    setup_info: Mapping[str, object]

    def solve(self, x) -> mx.array:
        """Apply the preconditioner solve to ``x``."""
        ...

    def __call__(self, x) -> mx.array:
        """Alias for :meth:`solve`."""
        ...
[docs]
@dataclass(frozen=True, slots=True)
class IdentityPreconditioner:
    """Preconditioner whose inverse application leaves the input unchanged.

    Serves as an explicit baseline and as the normalized stand-in for
    ``M=None`` inside solver plumbing.

    Attributes:
        shape: Compatible square shape, normalized in ``__post_init__``.
        dtype: Native value dtype (defaults to ``mlx.core.float32``).
        kind: Stable identifier string, ``"identity"`` by default.
        is_symmetric: Symmetry metadata.
        is_positive_definite: Positive-definiteness metadata.
    """

    shape: tuple[int, int]
    dtype: object = mx.float32
    kind: str = "identity"
    is_symmetric: bool = True
    is_positive_definite: bool = True

    def __post_init__(self) -> None:
        """Replace ``shape`` with its validated square form."""
        # The dataclass is frozen, so the field is written through ``object``.
        normalized = square_shape(self.shape)
        object.__setattr__(self, "shape", normalized)

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Structured metadata describing setup choices and assumptions."""
        reported = ("kind", "shape", "is_symmetric", "is_positive_definite")
        return {name: getattr(self, name) for name in reported}

    @property
    def setup_device(self) -> str:
        """Device used during setup."""
        return "none"

    @property
    def apply_device(self) -> str:
        """Device used during inverse application."""
        return "none"

    @property
    def nnz(self) -> int:
        """Number of effective diagonal entries."""
        return self.shape[0]

    def solve(self, x) -> mx.array:
        """Validate ``x`` and hand it back.

        Args:
            x: Right-hand side vector ``(n,)`` or matrix ``(n, nrhs)``.

        Returns:
            The result of ``ensure_rank1_or_rank2_rhs`` applied to ``x`` with
            the leading dimension fixed to ``n`` and finiteness required.
        """
        n = self.shape[0]
        validated = ensure_rank1_or_rank2_rhs(x, leading_dim=n, require_finite=True)
        return validated

    def __call__(self, x) -> mx.array:
        """Alias for :meth:`solve`."""
        return self.solve(x)
[docs]
@dataclass(frozen=True, slots=True)
class DiagonalPreconditioner:
"""Explicit diagonal inverse-apply preconditioner.
The stored vector is the inverse diagonal that should multiply each row of
a right-hand side. Application dispatches to the native
``diagonal_preconditioner_apply`` primitive, including rank-2 RHS support.
Stored fields include the finite ``float32`` ``inverse_diagonal`` vector,
the compatible square ``shape``, the stable ``kind`` string, and
symmetry/positive-definiteness metadata.
"""
inverse_diagonal: mx.array
shape: tuple[int, int]
kind: str = "diagonal"
is_symmetric: bool = True
is_positive_definite: bool = False
[docs]
def __post_init__(self) -> None:
"""Validate shape and inverse diagonal storage."""
shape = square_shape(self.shape)
inv_diag = ensure_float32_vector(
"inverse_diagonal", self.inverse_diagonal, require_finite=True
)
if inv_diag.shape[0] != shape[0]:
raise ValueError(
f"inverse_diagonal has length {inv_diag.shape[0]}, "
f"expected {shape[0]}."
)
object.__setattr__(self, "shape", shape)
object.__setattr__(self, "inverse_diagonal", inv_diag)
@property
def dtype(self):
"""Value dtype of ``inverse_diagonal``."""
return self.inverse_diagonal.dtype
@property
def nnz(self) -> int:
"""Number of stored inverse diagonal entries."""
return int(self.inverse_diagonal.shape[0])
@property
def setup_device(self) -> str:
"""Device category used for setup validation."""
return "host_validation"
@property
def apply_device(self) -> str:
"""Device category used for inverse application."""
return "native_cpu_or_metal"
@property
def setup_info(self) -> Mapping[str, object]:
"""Structured metadata describing setup choices and assumptions."""
return {
"kind": self.kind,
"shape": self.shape,
"nnz": self.nnz,
"is_symmetric": self.is_symmetric,
"is_positive_definite": self.is_positive_definite,
}
[docs]
def solve(self, x) -> mx.array:
"""Apply the diagonal inverse to a vector or dense RHS matrix.
Args:
x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.
Returns:
Native-applied ``inverse_diagonal[:, None] * x`` for matrix RHS, or
``inverse_diagonal * x`` for vector RHS.
"""
rhs = ensure_rank1_or_rank2_rhs(
x, leading_dim=self.shape[0], require_finite=True
)
return _native.diagonal_preconditioner_apply(self.inverse_diagonal, rhs)
[docs]
def matvec(self, x) -> mx.array:
"""Alias for :meth:`solve` for SciPy-style inverse-operator use."""
return self.solve(x)
[docs]
def __call__(self, x) -> mx.array:
"""Alias for :meth:`solve`."""
return self.solve(x)
[docs]
@dataclass(frozen=True, slots=True)
class JacobiPreconditioner(DiagonalPreconditioner):
    """Diagonal preconditioner that remembers its Jacobi setup parameters.

    This is a :class:`DiagonalPreconditioner` specialization that additionally
    records the parameters used by :func:`jacobi`. Passing it to
    :func:`mlx_sparse.linalg.cg` dispatches to the native Jacobi-PCG
    primitive.

    Attributes:
        kind: Stable identifier string, ``"jacobi"`` by default.
        omega: Damping/weighting factor recorded from setup.
        shift: Diagonal shift recorded from setup.
        zero_policy: Policy name recorded from setup.
        zero_atol: Near-zero threshold recorded from setup.
        checked: Whether validation was requested.
        positive_diagonal: Result of the positive shifted-diagonal check when
            it was requested, otherwise ``None``.
    """

    kind: str = "jacobi"
    omega: float = 1.0
    shift: float = 0.0
    zero_policy: str = "raise"
    zero_atol: float = 0.0
    checked: bool = False
    positive_diagonal: bool | None = None

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Structured metadata describing Jacobi setup choices."""
        reported = (
            "kind",
            "shape",
            "nnz",
            "omega",
            "shift",
            "zero_policy",
            "zero_atol",
            "checked",
            "positive_diagonal",
            "is_symmetric",
            "is_positive_definite",
        )
        return {name: getattr(self, name) for name in reported}

    @property
    def setup_device(self) -> str:
        """Device category used for Jacobi setup."""
        return "native_sparse_diagonal"
[docs]
@dataclass(frozen=True, slots=True)
class CallablePreconditioner:
    """Wrapper that adapts a Python callable to the preconditioner protocol.

    The wrapped callable is invoked with the validated right-hand side and
    must return an array of the same shape holding ``M^{-1} @ x``. Every
    application goes through Python, so solver integrations may use this
    wrapper only on documented host fallback paths.

    Attributes:
        apply: The wrapped inverse-apply callable.
        shape: Compatible square shape.
        dtype: Value dtype; only ``mlx.core.float32`` is accepted.
        kind: Stable identifier string, ``"callable"`` by default.
        is_symmetric: Conservative symmetry flag.
        is_positive_definite: Conservative positive-definiteness flag.
    """

    apply: object
    shape: tuple[int, int]
    dtype: object = mx.float32
    kind: str = "callable"
    is_symmetric: bool = False
    is_positive_definite: bool = False

    def __post_init__(self) -> None:
        """Check the callable, normalize the shape, and check the dtype.

        Raises:
            TypeError: If ``apply`` is not callable or ``dtype`` is not
                ``mlx.core.float32``.
        """
        wrapped = self.apply
        if not callable(wrapped):
            raise TypeError("callable preconditioner apply object must be callable.")
        normalized = square_shape(self.shape)
        object.__setattr__(self, "shape", normalized)
        if self.dtype != mx.float32:
            raise TypeError("callable preconditioners currently use float32 values.")

    @property
    def setup_info(self) -> Mapping[str, object]:
        """Structured metadata describing the callable contract."""
        info: dict[str, object] = {"kind": self.kind, "shape": self.shape}
        info["assume_inverse"] = True
        info["is_symmetric"] = self.is_symmetric
        info["is_positive_definite"] = self.is_positive_definite
        return info

    @property
    def setup_device(self) -> str:
        """Device category used during setup."""
        return "python_host"

    @property
    def apply_device(self) -> str:
        """Device category used during inverse application."""
        return "python_host"

    @property
    def nnz(self) -> int:
        """Unknown effective storage count for custom callables."""
        return -1

    def solve(self, x) -> mx.array:
        """Run the wrapped callable on ``x`` and validate what it returns.

        Args:
            x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.

        Returns:
            The validated callable output, whose shape equals that of the
            validated input.

        Raises:
            ValueError: If output validation raises ``ValueError`` or the
                output shape differs from the input shape.
        """
        n = self.shape[0]
        rhs = ensure_rank1_or_rank2_rhs(x, leading_dim=n, require_finite=True)
        raw = self.apply(rhs)
        try:
            out = ensure_rank1_or_rank2_rhs(raw, leading_dim=n, require_finite=True)
        except ValueError as exc:
            # Re-raise with a preconditioner-specific message, keeping the cause.
            raise ValueError(
                "preconditioner output shape or finite-value validation failed."
            ) from exc
        if out.shape == rhs.shape:
            return out
        raise ValueError(
            f"preconditioner output shape {out.shape} does not match "
            f"input shape {rhs.shape}."
        )

    def __call__(self, x) -> mx.array:
        """Alias for :meth:`solve`."""
        return self.solve(x)
[docs]
@dataclass(frozen=True, slots=True)
class ExactFactorPreconditioner:
"""Exact inverse-apply preconditioner backed by a sparse factorization.
This wrapper composes existing direct sparse solve objects with the
iterative ``M`` protocol. It does not refactorize on application: setup is
completed before construction. :meth:`solve` uses explicit native
LU/Cholesky apply bindings when factors are available, uses the guarded
Accelerate solver for real Accelerate factorized objects, and otherwise
delegates to the stored reusable solve object.
Stored fields include the reusable ``solver``, compatible square ``shape``,
factorization ``method``, implementation ``backend``, and conservative
symmetry/positive-definiteness metadata. Accelerate-backed
``FactorizedSolve`` instances keep their Accelerate CPU apply boundary;
native explicit factors keep their native CPU/Metal triangular-solve apply
boundary.
"""
solver: object
shape: tuple[int, int]
method: str
backend: str
kind: str = "exact"
is_symmetric: bool = False
is_positive_definite: bool = False
factor_nnz: int = -1
native_apply_kind: str | None = None
native_factorization: object | None = None
[docs]
def __post_init__(self) -> None:
"""Validate the wrapped exact factorization metadata."""
if not hasattr(self.solver, "solve") or not callable(self.solver.solve):
raise TypeError("exact factor preconditioners require solve(x).")
object.__setattr__(self, "shape", square_shape(self.shape))
if self.factor_nnz < -1:
raise ValueError("factor_nnz must be -1 or a non-negative integer.")
if self.native_apply_kind not in {None, "lu", "cholesky", "accelerate"}:
raise ValueError(
"native_apply_kind must be None, 'lu', 'cholesky', or 'accelerate'."
)
@property
def dtype(self):
"""Value dtype used by current sparse direct solve backends."""
return mx.float32
@property
def nnz(self) -> int:
"""Stored factor nonzero count, or ``-1`` for opaque factors."""
return int(self.factor_nnz)
@property
def setup_device(self) -> str:
"""Device category used during factorization setup."""
if self.backend == "accelerate":
return "accelerate_cpu"
return "native_cpu"
@property
def apply_device(self) -> str:
"""Device category used during inverse application."""
if self.backend == "accelerate":
return "accelerate_cpu"
return "native_cpu_or_metal"
@property
def setup_info(self) -> Mapping[str, object]:
"""Structured metadata describing the exact factorization wrapper."""
return {
"kind": self.kind,
"shape": self.shape,
"method": self.method,
"backend": self.backend,
"setup_device": self.setup_device,
"apply_device": self.apply_device,
"nnz": self.nnz,
"is_symmetric": self.is_symmetric,
"is_positive_definite": self.is_positive_definite,
"solver_type": type(self.solver).__name__,
"native_apply_kind": self.native_apply_kind,
"has_native_solver_apply": self.native_apply_kind is not None,
}
[docs]
def solve(self, x) -> mx.array:
"""Apply the exact factorized solve to a vector or dense RHS matrix.
Args:
x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.
Returns:
Finite ``float32`` solution with the same shape as ``x``.
"""
rhs = ensure_rank1_or_rank2_rhs(
x, leading_dim=self.shape[0], require_finite=True
)
result = self._native_or_wrapped_solve(rhs)
out = ensure_rank1_or_rank2_rhs(
result, leading_dim=self.shape[1], require_finite=True
)
if out.shape != rhs.shape:
raise ValueError(
f"exact factor preconditioner output shape {out.shape} does "
f"not match input shape {rhs.shape}."
)
return out
def _native_or_wrapped_solve(self, rhs) -> mx.array:
"""Apply explicit native factors when available, otherwise delegate."""
factor = self.native_factorization
if self.native_apply_kind == "lu" and factor is not None:
return _native.csr_exact_lu_preconditioner_apply(
factor.perm,
factor.L.data,
factor.L.indices,
factor.L.indptr,
factor.U.data,
factor.U.indices,
factor.U.indptr,
rhs,
self.shape,
)
if self.native_apply_kind == "cholesky" and factor is not None:
upper = factor._upper()
return _native.csr_exact_cholesky_preconditioner_apply(
factor.L.data,
factor.L.indices,
factor.L.indptr,
upper.data,
upper.indices,
upper.indptr,
rhs,
self.shape,
)
return self.solver.solve(rhs)
[docs]
def matvec(self, x) -> mx.array:
"""Alias for :meth:`solve` for inverse-operator composition."""
return self.solve(x)
[docs]
def __call__(self, x) -> mx.array:
"""Alias for :meth:`solve`."""
return self.solve(x)
[docs]
@dataclass(frozen=True, slots=True)
class ILU0Preconditioner:
"""Natural-order no-fill incomplete LU preconditioner.
``ILU0Preconditioner`` stores explicit CSR factors ``L`` and ``U`` from a
native CPU ILU(0) setup. ``L`` has a unit diagonal and the factors preserve
the lower/upper sparsity pattern of the canonical CSR input without
fill-reducing ordering or pivoting. Application performs the standard
forward and backward triangular solves ``L y = x`` and ``U z = y`` using
native CSR triangular-solve kernels.
Stored fields include the lower/upper CSR factors, explicit diagonal
``shift`` used during setup, validation ``check`` mode, optional
triangular-analysis reuse flag, and conservative nonsymmetric metadata.
"""
L: CSRArray
U: CSRArray
shift: float = 0.0
check: bool = True
reuse_analysis: bool = False
kind: str = "ilu0"
is_symmetric: bool = False
is_positive_definite: bool = False
_l_diagonal_positions: mx.array | None = field(
init=False, default=None, repr=False, compare=False
)
_u_diagonal_positions: mx.array | None = field(
init=False, default=None, repr=False, compare=False
)
_l_level_schedule: tuple[mx.array, mx.array] | None = field(
init=False, default=None, repr=False, compare=False
)
_u_level_schedule: tuple[mx.array, mx.array] | None = field(
init=False, default=None, repr=False, compare=False
)
[docs]
def __post_init__(self) -> None:
"""Validate factor metadata and optionally cache triangular analysis."""
shape = square_shape(self.L.shape)
if square_shape(self.U.shape) != shape:
raise ValueError("ILU0 L and U factors must have matching square shapes.")
if self.L.data.dtype != mx.float32 or self.U.data.dtype != mx.float32:
raise TypeError("ILU0 factors currently require float32 values.")
object.__setattr__(self, "shift", finite_scalar("shift", self.shift))
object.__setattr__(self, "check", bool(self.check))
object.__setattr__(self, "reuse_analysis", bool(self.reuse_analysis))
if self.reuse_analysis:
l_diag = _native.csr_triangular_diagonal_positions(
self.L.indices, self.L.indptr, shape
)
u_diag = _native.csr_triangular_diagonal_positions(
self.U.indices, self.U.indptr, shape
)
l_levels = _native.csr_triangular_level_schedule(
self.L.indices, self.L.indptr, shape, lower=True
)
u_levels = _native.csr_triangular_level_schedule(
self.U.indices, self.U.indptr, shape, lower=False
)
object.__setattr__(self, "_l_diagonal_positions", l_diag)
object.__setattr__(self, "_u_diagonal_positions", u_diag)
object.__setattr__(self, "_l_level_schedule", l_levels)
object.__setattr__(self, "_u_level_schedule", u_levels)
@property
def shape(self) -> tuple[int, int]:
"""Shape of the preconditioned square operator."""
return self.L.shape
@property
def dtype(self):
"""Value dtype used by the stored factors."""
return mx.float32
@property
def nnz_L(self) -> int:
"""Stored nonzero count in the unit lower factor."""
return int(self.L.nnz)
@property
def nnz_U(self) -> int:
"""Stored nonzero count in the upper factor."""
return int(self.U.nnz)
@property
def nnz(self) -> int:
"""Total stored factor entries."""
return self.nnz_L + self.nnz_U
@property
def setup_device(self) -> str:
"""Device category used during ILU(0) setup."""
return "native_cpu"
@property
def apply_device(self) -> str:
"""Device category used during inverse application."""
return "native_cpu_or_metal"
@property
def setup_info(self) -> Mapping[str, object]:
"""Structured metadata describing ILU(0) setup choices."""
return {
"kind": self.kind,
"shape": self.shape,
"shift": self.shift,
"check": self.check,
"reuse_analysis": self.reuse_analysis,
"ordering": "natural",
"fill": 0,
"unit_diagonal_L": True,
"nnz_L": self.nnz_L,
"nnz_U": self.nnz_U,
"nnz": self.nnz,
"is_symmetric": self.is_symmetric,
"is_positive_definite": self.is_positive_definite,
}
[docs]
def solve(self, x) -> mx.array:
"""Apply the ILU(0) inverse approximation to a vector or matrix RHS.
Args:
x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.
Returns:
Native triangular-solve result ``U^{-1} L^{-1} x`` with the same
rank and leading dimension as ``x``.
"""
rhs = ensure_rank1_or_rank2_rhs(
x, leading_dim=self.shape[0], require_finite=True
)
if self.reuse_analysis:
y = _native.csr_triangular_solve(
self.L.data,
self.L.indices,
self.L.indptr,
rhs,
self.shape,
lower=True,
unit_diagonal=True,
diagonal_positions=self._l_diagonal_positions,
level_schedule=self._l_level_schedule,
)
return _native.csr_triangular_solve(
self.U.data,
self.U.indices,
self.U.indptr,
y,
self.shape,
lower=False,
unit_diagonal=False,
diagonal_positions=self._u_diagonal_positions,
level_schedule=self._u_level_schedule,
)
return _native.csr_ilu0_preconditioner_apply(
self.L.data,
self.L.indices,
self.L.indptr,
self.U.data,
self.U.indices,
self.U.indptr,
rhs,
self.shape,
)
[docs]
def matvec(self, x) -> mx.array:
"""Alias for :meth:`solve` for inverse-operator composition."""
return self.solve(x)
[docs]
def __call__(self, x) -> mx.array:
"""Alias for :meth:`solve`."""
return self.solve(x)
[docs]
@dataclass(frozen=True, slots=True)
class IC0Preconditioner:
"""Natural-order no-fill incomplete Cholesky preconditioner.
``IC0Preconditioner`` stores an explicit lower CSR factor ``L`` produced by
native CPU IC(0) setup. The factor uses natural ordering, preserves the
symmetric lower sparsity pattern of the canonical CSR input, and introduces
no fill. Application performs two native triangular solves,
``L y = x`` followed by ``L.T z = y``.
Stored fields include the lower factor, explicit non-negative diagonal
``shift`` used during setup, strict positive-pivot ``check`` mode, and SPD
metadata suitable for CG and MINRES-style solvers.
"""
L: CSRArray
shift: float = 0.0
check: bool = True
kind: str = "ichol0"
is_symmetric: bool = True
is_positive_definite: bool = True
_upper_factor: CSRArray | None = field(
init=False, default=None, repr=False, compare=False
)
[docs]
def __post_init__(self) -> None:
"""Validate factor metadata and explicit shift policy."""
square_shape(self.L.shape)
if self.L.data.dtype != mx.float32:
raise TypeError("IC0 factors currently require float32 values.")
shift_value = finite_scalar("shift", self.shift)
if shift_value < 0.0:
raise ValueError("shift must be non-negative for IC0.")
object.__setattr__(self, "shift", shift_value)
object.__setattr__(self, "check", bool(self.check))
@property
def shape(self) -> tuple[int, int]:
"""Shape of the preconditioned square operator."""
return self.L.shape
@property
def dtype(self):
"""Value dtype used by the stored factor."""
return mx.float32
@property
def nnz_L(self) -> int:
"""Stored nonzero count in the lower factor."""
return int(self.L.nnz)
@property
def nnz(self) -> int:
"""Total stored factor entries."""
return self.nnz_L
@property
def setup_device(self) -> str:
"""Device category used during IC(0) setup."""
return "native_cpu"
@property
def apply_device(self) -> str:
"""Device category used during inverse application."""
return "native_cpu_or_metal"
@property
def setup_info(self) -> Mapping[str, object]:
"""Structured metadata describing IC(0) setup choices."""
return {
"kind": self.kind,
"shape": self.shape,
"shift": self.shift,
"check": self.check,
"ordering": "natural",
"fill": 0,
"factor": "lower",
"nnz_L": self.nnz_L,
"nnz": self.nnz,
"is_symmetric": self.is_symmetric,
"is_positive_definite": self.is_positive_definite,
}
def _upper(self) -> CSRArray:
"""Return and cache ``L.T`` as a CSR upper factor."""
upper = self._upper_factor
if upper is None:
upper = self.L.T
object.__setattr__(self, "_upper_factor", upper)
return upper
[docs]
def solve(self, x) -> mx.array:
"""Apply the IC(0) inverse approximation to a vector or matrix RHS.
Args:
x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.
Returns:
Native triangular-solve result ``L.T^{-1} L^{-1} x`` with the same
rank and leading dimension as ``x``.
"""
rhs = ensure_rank1_or_rank2_rhs(
x, leading_dim=self.shape[0], require_finite=True
)
upper = self._upper()
return _native.csr_ic0_preconditioner_apply(
self.L.data,
self.L.indices,
self.L.indptr,
upper.data,
upper.indices,
upper.indptr,
rhs,
self.shape,
)
[docs]
def matvec(self, x) -> mx.array:
"""Alias for :meth:`solve` for inverse-operator composition."""
return self.solve(x)
[docs]
def __call__(self, x) -> mx.array:
"""Alias for :meth:`solve`."""
return self.solve(x)
[docs]
@dataclass(frozen=True, slots=True)
class ChebyshevPreconditioner:
"""Polynomial inverse preconditioner for SPD sparse matrices.
``ChebyshevPreconditioner`` stores a canonical ``float32`` CSR operator and
applies a fixed-degree first-kind Chebyshev semi-iteration with zero initial
guess. Application uses only sparse matrix-vector/matrix products and
vector updates, so it follows the selected native CPU or Metal device path
without triangular solves or dense factorization.
Stored fields include the CSR operator, polynomial ``degree``, validated
positive spectral interval ``[lambda_min, lambda_max]``, whether setup used
native spectral ``estimate`` data, and detailed ``spectral_info`` metadata.
"""
A: CSRArray
degree: int = 2
lambda_min: float = 0.0
lambda_max: float = 0.0
estimate: bool = True
spectral_info: Mapping[str, object] = field(default_factory=dict)
kind: str = "chebyshev"
is_symmetric: bool = True
is_positive_definite: bool = True
[docs]
def __post_init__(self) -> None:
"""Validate polynomial metadata and CSR storage."""
shape = square_shape(self.A.shape)
if self.A.data.dtype != mx.float32:
raise TypeError("Chebyshev preconditioners require float32 CSR values.")
degree_value = int(self.degree)
if degree_value <= 0:
raise ValueError("degree must be positive.")
lambda_min_value = finite_scalar("lambda_min", self.lambda_min)
lambda_max_value = finite_scalar("lambda_max", self.lambda_max)
if lambda_min_value <= 0.0 or lambda_max_value <= lambda_min_value:
raise ValueError(
"Chebyshev spectral interval must satisfy "
"0 < lambda_min < lambda_max."
)
object.__setattr__(self, "degree", degree_value)
object.__setattr__(self, "lambda_min", lambda_min_value)
object.__setattr__(self, "lambda_max", lambda_max_value)
object.__setattr__(self, "estimate", bool(self.estimate))
if shape != self.A.shape:
raise ValueError("Chebyshev operator shape must be square.")
@property
def shape(self) -> tuple[int, int]:
"""Shape of the preconditioned square operator."""
return self.A.shape
@property
def dtype(self):
"""Value dtype used by the stored operator."""
return mx.float32
@property
def nnz(self) -> int:
"""Number of sparse operator entries used during polynomial apply."""
return int(self.A.nnz)
@property
def setup_device(self) -> str:
"""Device category used during spectral setup."""
return "native_cpu"
@property
def apply_device(self) -> str:
"""Device category used during inverse application."""
return "native_cpu_or_metal"
@property
def setup_info(self) -> Mapping[str, object]:
"""Structured metadata describing Chebyshev setup choices."""
return {
"kind": self.kind,
"shape": self.shape,
"degree": self.degree,
"lambda_min": self.lambda_min,
"lambda_max": self.lambda_max,
"estimate": self.estimate,
"nnz": self.nnz,
"is_symmetric": self.is_symmetric,
"is_positive_definite": self.is_positive_definite,
"spectral_info": dict(self.spectral_info),
}
[docs]
def solve(self, x) -> mx.array:
"""Apply the Chebyshev polynomial inverse to a vector or matrix RHS.
Args:
x: Right-hand side with shape ``(n,)`` or ``(n, nrhs)``.
Returns:
Native Chebyshev semi-iteration result with the same shape as
``x``.
"""
rhs = ensure_rank1_or_rank2_rhs(
x, leading_dim=self.shape[0], require_finite=True
)
return _native.csr_chebyshev_preconditioner_apply(
self.A.data,
self.A.indices,
self.A.indptr,
rhs,
self.shape,
degree=self.degree,
lambda_min=self.lambda_min,
lambda_max=self.lambda_max,
)
[docs]
def matvec(self, x) -> mx.array:
"""Alias for :meth:`solve` for inverse-operator composition."""
return self.solve(x)
[docs]
def __call__(self, x) -> mx.array:
"""Alias for :meth:`solve`."""
return self.solve(x)
[docs]
def identity(A_or_shape, *, dtype=None) -> IdentityPreconditioner:
    """Build a no-op preconditioner for a square shape or sparse matrix.

    Args:
        A_or_shape: Square sparse matrix, ``(n, n)`` shape tuple, or integer
            dimension.
        dtype: Optional dtype. The current native solver integration accepts
            only ``None`` or ``mlx.core.float32``.

    Returns:
        An :class:`IdentityPreconditioner` carrying the normalized shape and
        dtype.
    """
    # The shape is normalized first, then the dtype, before construction.
    normalized_shape = square_shape(A_or_shape)
    value_dtype = normalize_identity_dtype(dtype)
    return IdentityPreconditioner(dtype=value_dtype, shape=normalized_shape)
[docs]
def diagonal(
    inv_diag_or_diag,
    *,
    inverse: bool = False,
    shape=None,
    dtype=None,
    zero_atol: float = 0.0,
) -> DiagonalPreconditioner:
    """Create an explicit diagonal inverse-apply preconditioner.

    Args:
        inv_diag_or_diag: Rank-1 diagonal values. Interpreted as a diagonal by
            default, or as an inverse diagonal when ``inverse=True``.
        inverse: If ``True``, use ``inv_diag_or_diag`` directly as the inverse
            diagonal. If ``False``, validate and invert it.
        shape: Optional square shape. Defaults to ``(n, n)`` where ``n`` is the
            vector length.
        dtype: Optional dtype. The current native preconditioner path accepts
            only ``None`` or ``mlx.core.float32``.
        zero_atol: Absolute threshold used when rejecting zero diagonal entries
            before inversion. Must be non-negative and not NaN. Only consulted
            when ``inverse=False``.

    Returns:
        A :class:`DiagonalPreconditioner` built from the (possibly inverted)
        diagonal values and the resolved square shape.

    Raises:
        TypeError: If ``dtype`` is neither ``None`` nor ``mlx.core.float32``.
        ValueError: If the vector length does not match ``shape``, if
            ``zero_atol`` is negative or NaN, or if an entry satisfies
            ``abs(value) <= zero_atol`` when inversion is requested.
    """
    values = ensure_float32_vector("diagonal", inv_diag_or_diag, require_finite=True)
    if dtype is not None and dtype != mx.float32:
        raise TypeError("diagonal preconditioners currently use float32 values.")

    length = values.shape[0]
    pc_shape = square_shape((length, length)) if shape is None else square_shape(shape)
    if length != pc_shape[0]:
        raise ValueError(f"diagonal has length {length}, expected {pc_shape[0]}.")

    if inverse:
        return DiagonalPreconditioner(values, pc_shape)

    atol = float(zero_atol)
    # NaN fails every ordered comparison, so without the explicit test it would
    # pass the sign check and turn the near-zero guard below into a no-op.
    if math.isnan(atol) or atol < 0.0:
        raise ValueError("zero_atol must be non-negative.")
    if host_bool(mx.any(mx.abs(values) <= atol)):
        raise ValueError("diagonal contains zero or near-zero entries.")
    return DiagonalPreconditioner(1.0 / values, pc_shape)
[docs]
def jacobi(
    A,
    *,
    omega: float = 1.0,
    shift: float = 0.0,
    zero_policy: str = "raise",
    zero_atol: float = 0.0,
    check: bool = False,
) -> JacobiPreconditioner:
    """Create a Jacobi preconditioner from a sparse matrix diagonal.

    The inverse diagonal is computed as ``omega / (diag(A) + shift)`` after
    normalizing ``A`` to canonical CSR so duplicate diagonal entries are summed.
    The input sparse matrix is never mutated.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator``.
        omega: Damping/weighting factor.
        shift: Explicit diagonal shift applied before inversion.
        zero_policy: ``"raise"`` rejects zero/near-zero shifted diagonals.
            ``"unit"`` replaces those entries with ``1`` before inversion.
        zero_atol: Absolute threshold used to identify near-zero shifted
            diagonal entries. Must be non-negative and not NaN.
        check: When ``True``, require ``omega > 0`` and a strictly positive
            shifted diagonal before any ``zero_policy`` replacement, then mark
            the preconditioner as positive definite.

    Returns:
        A :class:`JacobiPreconditioner` suitable for native PCG.

    Raises:
        ValueError: If ``zero_policy`` is unknown, ``omega`` is not positive
            while ``check=True``, ``A`` is not square, the shifted diagonal is
            non-finite, ``zero_atol`` is negative or NaN, near-zero entries
            are found under ``zero_policy="raise"``, or the shifted diagonal
            is not strictly positive while ``check=True``.
    """
    if zero_policy not in {"raise", "unit"}:
        raise ValueError("zero_policy must be 'raise' or 'unit'.")
    omega_value = finite_scalar("omega", omega)
    shift_value = finite_scalar("shift", shift)
    checked = bool(check)
    if checked and omega_value <= 0.0:
        raise ValueError("omega must be positive when check=True.")

    csr = canonical_csr(
        A,
        context="jacobi",
        dense_guidance="",
        allow_sparse_linear_operator=True,
    )
    if csr.shape[0] != csr.shape[1]:
        raise ValueError(f"jacobi requires a square matrix, got {csr.shape}.")

    diag = ensure_float32_vector("diagonal", csr.diagonal())
    shifted = diag + mx.array(shift_value, dtype=mx.float32)
    if not host_bool(mx.all(mx.isfinite(shifted))):
        raise ValueError("shifted diagonal must contain only finite values.")

    atol = float(zero_atol)
    # NaN fails every ordered comparison: it would pass the sign check, make
    # ``near_zero`` all-False, and make the positivity test below all-False.
    if math.isnan(atol) or atol < 0.0:
        raise ValueError("zero_atol must be non-negative.")

    near_zero = mx.abs(shifted) <= atol
    # Positivity is judged on the shifted diagonal before any unit replacement.
    positive_diagonal = host_bool(mx.all(shifted > atol)) if checked else None
    if host_bool(mx.any(near_zero)):
        if zero_policy == "raise":
            raise ValueError(
                "Jacobi shifted diagonal contains zero or near-zero entries."
            )
        shifted = mx.where(near_zero, mx.ones_like(shifted), shifted)

    if checked and not positive_diagonal:
        raise ValueError(
            "Jacobi shifted diagonal must be strictly positive when check=True."
        )

    inv_diag = mx.array(omega_value, dtype=mx.float32) / shifted
    return JacobiPreconditioner(
        inv_diag,
        csr.shape,
        # Only a passed check justifies the positive-definite flag.
        is_positive_definite=checked,
        omega=omega_value,
        shift=shift_value,
        zero_policy=zero_policy,
        zero_atol=atol,
        checked=checked,
        positive_diagonal=positive_diagonal,
    )
[docs]
def ilu0(
    A,
    *,
    shift: float = 0.0,
    check: bool = True,
    reuse_analysis: bool = False,
) -> ILU0Preconditioner:
    """Build a natural-order, zero-fill ILU(0) preconditioner.

    ``A`` is first brought to canonical CSR form and real low-precision values
    are promoted to ``float32``. The factorization itself is delegated to the
    native CPU ILU(0) kernel, which uses no fill-reducing ordering and no
    pivoting. The input matrix is left untouched. ``shift`` only affects
    diagonal entries that are already stored; a structurally missing diagonal
    is still an error.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator`` describing a square nonsingular matrix.
        shift: Explicit diagonal shift applied before factorization. Must be
            a finite scalar.
        check: When ``True`` (default), native setup applies a scale-aware
            near-zero pivot guard. When ``False``, only exact-zero and
            non-finite pivots are rejected during setup.
        reuse_analysis: When ``True``, triangular diagonal-position and
            level-schedule analysis is cached for repeated explicit
            ``M(rhs)`` calls. Off by default because v0.0.4b1
            triangular-analysis benchmarks showed it must be measured per
            workload before enabling.

    Returns:
        An :class:`ILU0Preconditioner` whose application uses native CSR
        triangular solves.

    Raises:
        ValueError: If the normalized matrix is not square.
    """
    diagonal_shift = finite_scalar("shift", shift)

    canonical = canonical_csr(
        A,
        context="ILU(0)",
        dense_guidance="Dense MLX arrays belong in mlx.linalg, not "
        "mlx_sparse.linalg.preconditioners.",
        allow_sparse_linear_operator=True,
    )
    matrix = ensure_float32_csr(canonical, context="ILU(0)")
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"ilu0 requires a square matrix, got {matrix.shape}.")

    guard_pivots = bool(check)
    (
        lower_data,
        lower_indices,
        lower_indptr,
        upper_data,
        upper_indices,
        upper_indptr,
    ) = _native.csr_ilu0(
        matrix.data,
        matrix.indices,
        matrix.indptr,
        matrix.shape,
        shift=diagonal_shift,
        check=guard_pivots,
    )

    def wrap_factor(data, indices, indptr):
        # Both factors share the operator shape and are flagged as sorted and
        # canonical, exactly as the native kernel output is consumed.
        return CSRArray(
            data=data,
            indices=indices,
            indptr=indptr,
            shape=matrix.shape,
            sorted_indices=True,
            has_canonical_format=True,
        )

    lower_factor = wrap_factor(lower_data, lower_indices, lower_indptr)
    upper_factor = wrap_factor(upper_data, upper_indices, upper_indptr)
    return ILU0Preconditioner(
        L=lower_factor,
        U=upper_factor,
        shift=diagonal_shift,
        check=guard_pivots,
        reuse_analysis=bool(reuse_analysis),
    )
[docs]
def ichol0(A, *, shift: float = 0.0, check: bool = True) -> IC0Preconditioner:
    """Build a natural-order, zero-fill IC(0) preconditioner.

    ``A`` is normalized to canonical CSR and real low-precision values are
    promoted to ``float32`` before the native CPU incomplete Cholesky kernel
    runs with zero fill and natural ordering. The input matrix is not
    mutated. ``shift`` is a non-negative scalar added only to diagonal entries
    that are already stored; a missing diagonal stays an error and no pivot is
    silently perturbed.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator`` describing an SPD square matrix.
        shift: Explicit non-negative diagonal shift applied before
            factorization. Defaults to ``0.0``.
        check: When ``True`` (default), native setup enforces symmetry for
            explicitly stored mirrored entries and applies a scale-aware
            positive pivot guard. When ``False``, non-positive and non-finite
            pivots are still rejected, but near-zero pivot and symmetry
            checks are relaxed.

    Returns:
        An :class:`IC0Preconditioner` whose application uses native CSR
        triangular solves with ``L`` and ``L.T``.

    Raises:
        ValueError: If ``shift`` is negative or the normalized matrix is not
            square.
    """
    diagonal_shift = finite_scalar("shift", shift)
    if diagonal_shift < 0.0:
        raise ValueError("shift must be non-negative for IC0.")

    canonical = canonical_csr(
        A,
        context="IC(0)",
        dense_guidance="Dense MLX arrays belong in mlx.linalg, not "
        "mlx_sparse.linalg.preconditioners.",
        allow_sparse_linear_operator=True,
    )
    matrix = ensure_float32_csr(canonical, context="IC(0)")
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"ichol0 requires a square matrix, got {matrix.shape}.")

    guard_pivots = bool(check)
    factor_data, factor_indices, factor_indptr = _native.csr_ic0(
        matrix.data,
        matrix.indices,
        matrix.indptr,
        matrix.shape,
        shift=diagonal_shift,
        check=guard_pivots,
    )
    # Only the lower factor is stored; the object is built with L alone.
    lower_factor = CSRArray(
        data=factor_data,
        indices=factor_indices,
        indptr=factor_indptr,
        shape=matrix.shape,
        sorted_indices=True,
        has_canonical_format=True,
    )
    return IC0Preconditioner(
        L=lower_factor,
        shift=diagonal_shift,
        check=guard_pivots,
    )
[docs]
def chebyshev(
    A,
    *,
    degree: int = 2,
    lambda_min: float | None = None,
    lambda_max: float | None = None,
    estimate: bool = True,
) -> ChebyshevPreconditioner:
    """Build a GPU-friendly Chebyshev polynomial preconditioner.

    ``A`` is normalized to canonical CSR with real low-precision values
    promoted to ``float32``. Native code then supplies Gershgorin spectral
    bounds and, optionally, Lanczos Ritz estimates. The returned
    preconditioner applies a fixed-degree first-kind Chebyshev semi-iteration
    from a zero initial guess, using only sparse matrix products and vector
    updates.

    Each end of the spectral interval is chosen independently, in this order:
    an explicit argument, then a positive finite Gershgorin bound, then (only
    when ``estimate`` is true) a value derived from a positive finite Ritz
    estimate.

    Args:
        A: ``CSRArray``, ``COOArray``, ``CSCArray``, or sparse-backed
            ``LinearOperator`` describing a real SPD square matrix.
        degree: Positive polynomial degree. Defaults to ``2``.
        lambda_min: Optional positive lower spectral bound. When omitted, the
            Gershgorin lower bound is used if positive; otherwise half the
            Ritz minimum, capped at a tenth of the upper bound and floored at
            a small positive constant.
        lambda_max: Optional upper spectral bound. When omitted, the
            Gershgorin upper bound is used if positive; otherwise the Ritz
            maximum inflated by 10%.
        estimate: Whether native Lanczos Ritz estimates are computed and may
            be used as a fallback for the spectral interval. Defaults to
            ``True``.

    Returns:
        A :class:`ChebyshevPreconditioner` with native CPU/Metal apply
        support. The raw bounds and the source of each chosen endpoint are
        passed along as ``spectral_info``.

    Raises:
        ValueError: If ``degree`` is not positive, the matrix is not square,
            or no interval with ``0 < lambda_min < lambda_max`` can be
            established.
    """
    polynomial_degree = int(degree)
    if polynomial_degree <= 0:
        raise ValueError("degree must be positive.")

    canonical = canonical_csr(
        A,
        context="Chebyshev",
        dense_guidance="Dense MLX arrays belong in mlx.linalg, not "
        "mlx_sparse.linalg.preconditioners.",
        allow_sparse_linear_operator=True,
    )
    matrix = ensure_float32_csr(canonical, context="Chebyshev")
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"chebyshev requires a square matrix, got {matrix.shape}.")

    use_estimate = bool(estimate)
    native_bounds = _native.csr_chebyshev_spectral_bounds(
        matrix.data,
        matrix.indices,
        matrix.indptr,
        matrix.shape,
        estimate=use_estimate,
        estimate_steps=0,
    )
    (
        gersh_lower,
        gersh_upper,
        ritz_lower,
        ritz_upper,
        diag_lower,
        diag_upper,
        lanczos_steps,
    ) = native_bounds

    def usable(candidate: float) -> bool:
        # A bound is only trusted when it is finite and strictly positive.
        return math.isfinite(float(candidate)) and float(candidate) > 0.0

    # Upper end of the interval: explicit > Gershgorin > inflated Ritz value.
    if lambda_max is not None:
        upper = finite_scalar("lambda_max", lambda_max)
        upper_origin = "explicit"
    elif usable(gersh_upper):
        upper = float(gersh_upper)
        upper_origin = "gershgorin"
    elif use_estimate and usable(ritz_upper):
        upper = 1.1 * float(ritz_upper)
        upper_origin = "lanczos_1.1"
    else:
        raise ValueError(
            "could not determine a positive Chebyshev upper spectral "
            "bound; pass lambda_max explicitly."
        )

    # Lower end of the interval: explicit > Gershgorin > damped Ritz value.
    if lambda_min is not None:
        lower = finite_scalar("lambda_min", lambda_min)
        lower_origin = "explicit"
    elif usable(gersh_lower):
        lower = float(gersh_lower)
        lower_origin = "gershgorin"
    elif use_estimate and usable(ritz_lower):
        # 1.1920928955078125e-7 is 2**-23 (float32 machine epsilon); the floor
        # keeps the derived lower bound strictly positive.
        positive_floor = max(1.0e-12, 16.0 * 1.1920928955078125e-7)
        damped = min(0.5 * float(ritz_lower), 0.1 * upper)
        lower = max(positive_floor, damped)
        lower_origin = "lanczos_0.5"
    else:
        raise ValueError(
            "could not determine a positive Chebyshev lower spectral "
            "bound; pass lambda_min explicitly or use estimate=True for "
            "SPD matrices whose Gershgorin lower bound is non-positive."
        )

    interval_is_valid = (
        math.isfinite(lower) and math.isfinite(upper) and 0.0 < lower < upper
    )
    if not interval_is_valid:
        raise ValueError(
            "Chebyshev spectral interval must satisfy 0 < lambda_min < lambda_max."
        )

    diagnostics = {
        "gershgorin_min": float(gersh_lower),
        "gershgorin_max": float(gersh_upper),
        "ritz_min": float(ritz_lower),
        "ritz_max": float(ritz_upper),
        "diagonal_min": float(diag_lower),
        "diagonal_max": float(diag_upper),
        "estimate_steps": int(lanczos_steps),
        "lambda_min_source": lower_origin,
        "lambda_max_source": upper_origin,
    }
    return ChebyshevPreconditioner(
        A=matrix,
        degree=polynomial_degree,
        lambda_min=lower,
        lambda_max=upper,
        estimate=use_estimate,
        spectral_info=diagnostics,
    )
[docs]
def from_factorized(solver) -> ExactFactorPreconditioner:
    """Expose an existing sparse factorization as an exact preconditioner.

    ``SparseCholesky`` and ``SparseLU`` objects are wrapped directly with a
    native apply kind. A ``FactorizedSolve`` is inspected through its private
    ``_solver`` attribute to find a native LU/Cholesky factorization or an
    Accelerate float solve; when neither is found the wrapper carries no
    native apply kind and reports ``factor_nnz=-1``.

    Args:
        solver: A :class:`~mlx_sparse.linalg.FactorizedSolve`,
            :class:`~mlx_sparse.linalg.SparseLU`, or
            :class:`~mlx_sparse.linalg.SparseCholesky` instance.

    Returns:
        An :class:`ExactFactorPreconditioner` whose inverse application uses a
        native exact-apply path when the factorization exposes one.

    Raises:
        TypeError: If ``solver`` is not one of the supported factorization
            objects.
        ValueError: If the factorization does not represent a square operator.
    """
    # Imported lazily inside the function, as in the original implementation.
    from mlx_sparse.linalg._factorizations import (
        FactorizedSolve,
        SparseCholesky,
        SparseLU,
    )
    from mlx_sparse.linalg.utils.factorization import NativeFactorizedSolve

    if isinstance(solver, SparseCholesky):
        return ExactFactorPreconditioner(
            solver=solver,
            shape=square_shape(solver.shape),
            method="cholesky",
            backend="native",
            is_symmetric=True,
            is_positive_definite=True,
            factor_nnz=int(solver.L.nnz),
            native_apply_kind="cholesky",
            native_factorization=solver,
        )

    if isinstance(solver, SparseLU):
        return ExactFactorPreconditioner(
            solver=solver,
            shape=square_shape(solver.shape),
            method="lu",
            backend="native",
            is_symmetric=False,
            is_positive_definite=False,
            factor_nnz=int(solver.L.nnz + solver.U.nnz),
            native_apply_kind="lu",
            native_factorization=solver,
        )

    if not isinstance(solver, FactorizedSolve):
        raise TypeError(
            "from_factorized expects FactorizedSolve, SparseLU, or SparseCholesky."
        )

    operator_shape = square_shape(solver.shape)
    sizes_match = (
        int(solver.rhs_size) == operator_shape[0]
        and int(solver.solution_size) == operator_shape[1]
    )
    if not sizes_match:
        raise ValueError(
            "exact factor preconditioners require matching RHS and "
            "solution dimensions."
        )

    factor_method = str(solver.method)

    # Defaults describe a solve object with no recognized native factor.
    apply_kind = None
    native_factor = None
    stored_nnz = -1

    inner_solver = getattr(solver, "_solver", None)
    if isinstance(inner_solver, NativeFactorizedSolve):
        held_factor = inner_solver.factorization
        if isinstance(held_factor, SparseLU):
            apply_kind = "lu"
            native_factor = held_factor
            stored_nnz = int(held_factor.L.nnz + held_factor.U.nnz)
        elif isinstance(held_factor, SparseCholesky):
            apply_kind = "cholesky"
            native_factor = held_factor
            stored_nnz = int(held_factor.L.nnz)
    elif str(solver.backend) == "accelerate" and _native.is_accelerate_float_solve(
        inner_solver
    ):
        # The Accelerate solve object itself is the native handle; no factor
        # nnz is recorded on this path.
        apply_kind = "accelerate"
        native_factor = inner_solver

    return ExactFactorPreconditioner(
        solver=solver,
        shape=operator_shape,
        method=factor_method,
        backend=str(solver.backend),
        is_symmetric=factor_method in {"cholesky", "ldlt"},
        is_positive_definite=factor_method == "cholesky",
        factor_nnz=stored_nnz,
        native_apply_kind=apply_kind,
        native_factorization=native_factor,
    )
[docs]
def exact(A, *, method: str = "auto") -> ExactFactorPreconditioner:
    """Factorize ``A`` once and wrap the result as an exact preconditioner.

    This is a thin convenience layer over
    :func:`mlx_sparse.linalg.factorized` followed by
    :func:`from_factorized`. It is meant as a correctness baseline,
    diagnostic tool, and composition point for existing direct solvers
    rather than a performance headline.

    Args:
        A: Sparse coefficient matrix accepted by
            :func:`mlx_sparse.linalg.factorized`.
        method: Direct factorization method, forwarded unchanged. Defaults to
            ``"auto"``.

    Returns:
        An :class:`ExactFactorPreconditioner` wrapping the reusable factorized
        solve object.
    """
    # Imported lazily inside the function, as in the original implementation.
    from mlx_sparse.linalg._factorizations import factorized

    factorization = factorized(A, method=method)
    return from_factorized(factorization)
[docs]
def aspreconditioner(M, A=None, *, assume_inverse: bool = True) -> Preconditioner:
    """Normalize supported preconditioner-like objects.

    Dispatch order: ``None`` becomes an identity preconditioner for ``A``;
    known preconditioner instances are returned unchanged after an optional
    shape check; raw sparse matrices are rejected; direct factorizations are
    wrapped through :func:`from_factorized`; objects exposing ``solve(x)`` and
    plain callables are wrapped in :class:`CallablePreconditioner`.

    Args:
        M: ``None``, an existing preconditioner, an object with ``solve(x)``,
            or a callable. Sparse matrices are rejected because they do not
            explicitly define an inverse-apply contract.
        A: Optional reference matrix or shape used to validate compatibility.
            Required when ``M`` is ``None``, a bare callable, or a
            ``solve(x)`` object that carries no ``shape`` attribute.
        assume_inverse: Must be ``True`` for callables and custom objects,
            documenting that their output is already an inverse/preconditioner
            application.

    Returns:
        A supported preconditioner object.

    Raises:
        ValueError: If ``A`` is required or the preconditioner shape mismatches.
        TypeError: If ``M`` is a sparse matrix or unsupported object, or if
            ``assume_inverse`` is false for a callable or custom object.
    """
    if M is None:
        if A is None:
            raise ValueError("A is required when M is None.")
        return identity(A)

    # NOTE(review): JacobiPreconditioner is not listed here; presumably it
    # subclasses DiagonalPreconditioner and is matched through it -- confirm.
    if isinstance(
        M,
        (
            IdentityPreconditioner,
            DiagonalPreconditioner,
            CallablePreconditioner,
            ChebyshevPreconditioner,
            ExactFactorPreconditioner,
            IC0Preconditioner,
            ILU0Preconditioner,
        ),
    ):
        if A is not None and M.shape != square_shape(A):
            raise ValueError(f"preconditioner shape {M.shape} does not match A.shape.")
        return M

    if isinstance(M, (CSRArray, COOArray, CSCArray)):
        raise TypeError(
            "sparse matrices are not inverse-apply preconditioners; use "
            "preconditioners.jacobi(A), preconditioners.ilu0(A), "
            "preconditioners.ichol0(A), or preconditioners.diagonal(...)."
        )

    # Imported lazily inside the function, as in the original implementation.
    from mlx_sparse.linalg._factorizations import (
        FactorizedSolve,
        SparseCholesky,
        SparseLU,
    )

    if isinstance(M, (FactorizedSolve, SparseLU, SparseCholesky)):
        pc = from_factorized(M)
        if A is not None and pc.shape != square_shape(A):
            raise ValueError(f"preconditioner shape {pc.shape} does not match A.shape.")
        return pc

    if hasattr(M, "solve") and callable(M.solve):
        if not assume_inverse:
            raise TypeError("custom preconditioner objects must apply the inverse.")
        if A is not None:
            shape = square_shape(A)
        elif hasattr(M, "shape"):
            shape = square_shape(M.shape)
        else:
            # Without A or M.shape there is nothing to derive the operator
            # size from; fail with the documented ValueError instead of
            # handing None to square_shape.
            raise ValueError(
                "A is required when M provides solve(x) but no shape."
            )
        if hasattr(M, "shape") and square_shape(M.shape) != shape:
            raise ValueError(
                f"preconditioner shape {square_shape(M.shape)} does not match "
                f"A.shape {shape}."
            )
        return CallablePreconditioner(M.solve, shape)

    if callable(M):
        if not assume_inverse:
            raise TypeError("callable preconditioners must apply the inverse.")
        if A is None:
            raise ValueError("A is required when M is a callable.")
        return CallablePreconditioner(M, square_shape(A))

    raise TypeError(
        "M must be None, a supported preconditioner, an inverse-apply object "
        "with solve(x), or an inverse-apply callable."
    )
# Public API of this module: the names exported by a star-import. It lists
# the preconditioner classes, the ``Preconditioner`` protocol, and the
# factory/normalization functions.
__all__ = [
"DiagonalPreconditioner",
"CallablePreconditioner",
"ChebyshevPreconditioner",
"ExactFactorPreconditioner",
"IC0Preconditioner",
"ILU0Preconditioner",
"IdentityPreconditioner",
"JacobiPreconditioner",
"Preconditioner",
"aspreconditioner",
"chebyshev",
"diagonal",
"exact",
"from_factorized",
"identity",
"ichol0",
"ilu0",
"jacobi",
]