Source code for mlx_sparse.linalg._interface

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

import mlx.core as mx

from mlx_sparse._coo import COOArray
from mlx_sparse._csc import CSCArray
from mlx_sparse._csr import CSRArray
from mlx_sparse._validation import normalize_shape
from mlx_sparse.linalg.utils.arrays import ensure_array
from mlx_sparse.linalg.utils.operators import sparse_operator

# Type aliases for the callables stored on LinearOperator. Both take one
# ``mlx.core.array`` and return one; they are kept as separate names so that
# signatures document intent (vector input vs. matrix input).
Matvec = Callable[[mx.array], mx.array]  # x of shape (n,) -> A @ x of shape (m,)
Matmat = Callable[[mx.array], mx.array]  # X of shape (n, k) -> A @ X of shape (m, k)


class LinearOperator:
    """Sparse/matrix-free operator interface.

    This class stores callables only. It does not densify an operator or
    provide dense fallbacks. Sparse arrays are normalized to canonical CSR and
    wrapped with native CSR matvec/matmat kernels through
    :func:`aslinearoperator`.

    Attributes:
        shape: Operator shape ``(m, n)`` as returned by ``normalize_shape``.
        matvec_fn: Callable computing ``A @ x`` for a rank-1 ``x``.
        dtype: Optional dtype tag given at construction. It is stored as-is
            and is not inferred or validated here.
        matmat_fn: Optional callable computing ``A @ X`` for a rank-2 ``X``.
        rmatvec_fn: Optional callable computing ``A.H @ x``.
    """

    __slots__ = (
        "shape",
        "matvec_fn",
        "dtype",
        "matmat_fn",
        "rmatvec_fn",
        "_sparse_array",
    )

    def __init__(
        self,
        shape,
        matvec: Matvec | None = None,
        *,
        matvec_fn: Matvec | None = None,
        dtype=None,
        matmat: Matmat | None = None,
        matmat_fn: Matmat | None = None,
        rmatvec: Matvec | None = None,
        rmatvec_fn: Matvec | None = None,
        _sparse_array=None,
    ) -> None:
        """Store the operator callables.

        Args:
            shape: Operator shape, normalized with ``normalize_shape``.
            matvec: Alias for ``matvec_fn`` (may be passed positionally).
            matvec_fn: Callable computing ``A @ x``. Required, either through
                this keyword or through ``matvec``.
            dtype: Optional dtype tag, stored unchanged.
            matmat: Alias for ``matmat_fn``.
            matmat_fn: Optional callable computing ``A @ X``.
            rmatvec: Alias for ``rmatvec_fn``.
            rmatvec_fn: Optional callable computing ``A.H @ x``.
            _sparse_array: Private handle to the backing sparse array, if any.

        Raises:
            TypeError: If neither ``matvec_fn`` nor ``matvec`` is given.
        """
        # The explicit ``*_fn`` keywords win over the short aliases when both
        # are supplied.
        matvec_impl = matvec_fn if matvec_fn is not None else matvec
        if matvec_impl is None:
            raise TypeError("LinearOperator requires a matvec callable.")
        self.shape = normalize_shape(shape)
        self.matvec_fn = matvec_impl
        self.dtype = dtype
        self.matmat_fn = matmat_fn if matmat_fn is not None else matmat
        self.rmatvec_fn = rmatvec_fn if rmatvec_fn is not None else rmatvec
        self._sparse_array = _sparse_array

    def __repr__(self) -> str:
        """Return a compact description, e.g. ``<3x4 LinearOperator with dtype=None>``."""
        dims = "x".join(str(dim) for dim in self.shape)
        return f"<{dims} {type(self).__name__} with dtype={self.dtype}>"

    @property
    def ndim(self) -> int:
        """Number of dimensions exposed by every linear operator."""
        return 2

    def matvec(self, x) -> mx.array:
        """Apply the operator to a vector: compute ``A @ x``.

        Args:
            x: Input vector of shape ``(n,)``.

        Returns:
            Output vector of shape ``(m,)`` as an ``mlx.core.array``.

        Raises:
            ValueError: If ``x`` is not rank-1 or has the wrong length.
        """
        x = ensure_array(x)
        if x.ndim != 1:
            raise ValueError(f"matvec expects rank-1 input, got shape={x.shape}.")
        if x.shape[0] != self.shape[1]:
            raise ValueError(
                f"matvec input has length {x.shape[0]}, expected {self.shape[1]}."
            )
        return self.matvec_fn(x)

    def matmat(self, X) -> mx.array:
        """Apply the operator to a matrix: compute ``A @ X``.

        Args:
            X: Input matrix of shape ``(n, k)``.

        Returns:
            Output matrix of shape ``(m, k)`` as an ``mlx.core.array``.

        Raises:
            NotImplementedError: If no ``matmat_fn`` was provided at
                construction time.
            ValueError: If ``X`` is not rank-2 or has the wrong leading
                dimension.
        """
        X = ensure_array(X)
        if X.ndim != 2:
            raise ValueError(f"matmat expects rank-2 input, got shape={X.shape}.")
        if X.shape[0] != self.shape[1]:
            raise ValueError(
                f"matmat input has leading dimension {X.shape[0]}, "
                f"expected {self.shape[1]}."
            )
        # Shape validation runs first, so bad inputs report ValueError even
        # on operators that lack a matmat kernel.
        if self.matmat_fn is None:
            raise NotImplementedError("matmat is not defined for this operator.")
        return self.matmat_fn(X)

    def rmatvec(self, x) -> mx.array:
        """Apply the adjoint operator to a vector: compute ``A.H @ x``.

        Args:
            x: Input vector of shape ``(m,)``.

        Returns:
            Output vector of shape ``(n,)`` as an ``mlx.core.array``.

        Raises:
            NotImplementedError: If no ``rmatvec_fn`` was provided at
                construction time.
            ValueError: If ``x`` is not rank-1 or has the wrong length.
        """
        x = ensure_array(x)
        if x.ndim != 1:
            raise ValueError(f"rmatvec expects rank-1 input, got shape={x.shape}.")
        if x.shape[0] != self.shape[0]:
            raise ValueError(
                f"rmatvec input has length {x.shape[0]}, expected {self.shape[0]}."
            )
        if self.rmatvec_fn is None:
            raise NotImplementedError("rmatvec is not defined for this operator.")
        return self.rmatvec_fn(x)

    def __matmul__(self, rhs):
        """Apply the operator to a vector or dense matrix with ``@``.

        Rank-1 right-hand sides dispatch to :meth:`matvec` and rank-2 ones
        to :meth:`matmat`.

        Raises:
            ValueError: If ``rhs`` is neither rank-1 nor rank-2.
        """
        rhs = ensure_array(rhs)
        if rhs.ndim == 1:
            return self.matvec(rhs)
        if rhs.ndim == 2:
            return self.matmat(rhs)
        raise ValueError(
            f"LinearOperator matmul expects rank-1 or rank-2 RHS, got {rhs.shape}."
        )

    @property
    def T(self) -> "LinearOperator":
        """Transpose operator. ``(op.T) @ x`` computes ``A.T @ x``.

        For real operators ``A.T == A.H``. For complex operators the formula
        ``A.T @ x = conj(A.H @ conj(x))`` is used, so no extra kernel is
        needed. Requires :attr:`rmatvec_fn` to be defined.

        The returned operator keeps ``dtype`` but carries no ``matmat``, so
        calling its :meth:`matmat` raises :class:`NotImplementedError`. If a
        backing sparse array is present, it is transposed with its
        ``transpose()`` method.

        Raises:
            NotImplementedError: If :attr:`rmatvec_fn` is ``None``.
        """
        if self.rmatvec_fn is None:
            raise NotImplementedError(
                "LinearOperator.T requires rmatvec to be defined."
            )
        # A.T @ x = conj( rmatvec( conj(x) ) )
        # For real dtypes mx.conjugate is a no-op in values, so this is exact.
        # Bind the callables to locals so the lambdas do not reference ``self``.
        _rv = self.rmatvec_fn
        _mv = self.matvec_fn
        sparse = (
            self._sparse_array.transpose() if self._sparse_array is not None else None
        )
        return LinearOperator(
            shape=(self.shape[1], self.shape[0]),
            matvec_fn=lambda x: mx.conjugate(_rv(mx.conjugate(x))),
            rmatvec_fn=lambda x: mx.conjugate(_mv(mx.conjugate(x))),
            dtype=self.dtype,
            _sparse_array=sparse,
        )

    @property
    def H(self) -> "LinearOperator":
        """Hermitian (conjugate) transpose operator. ``(op.H) @ x`` computes ``A.H @ x``.

        Requires :attr:`rmatvec_fn` to be defined (it stores ``A.H``). The
        double adjoint ``(A.H).H`` recovers the original ``A``, because the
        matvec and rmatvec callables are simply swapped.

        The returned operator keeps ``dtype`` but carries no ``matmat``. If a
        backing sparse array is present, its ``H`` attribute is used.

        Raises:
            NotImplementedError: If :attr:`rmatvec_fn` is ``None``.
        """
        if self.rmatvec_fn is None:
            raise NotImplementedError(
                "LinearOperator.H requires rmatvec to be defined."
            )
        _rv = self.rmatvec_fn
        _mv = self.matvec_fn
        sparse = self._sparse_array.H if self._sparse_array is not None else None
        return LinearOperator(
            shape=(self.shape[1], self.shape[0]),
            matvec_fn=_rv,
            rmatvec_fn=_mv,
            dtype=self.dtype,
            _sparse_array=sparse,
        )
def aslinearoperator(A) -> LinearOperator:
    """Wrap a sparse matrix or callable as a :class:`LinearOperator`.

    Accepts several input forms and returns a :class:`LinearOperator` that
    exposes :meth:`~LinearOperator.matvec`, :meth:`~LinearOperator.matmat`,
    and :meth:`~LinearOperator.rmatvec` via the native sparse kernels where
    possible.

    Args:
        A: The object to wrap. Accepted types:

            * :class:`LinearOperator`: returned unchanged.
            * :class:`~mlx_sparse.CSRArray`, :class:`~mlx_sparse.COOArray`,
              or :class:`~mlx_sparse.CSCArray`: converted once to canonical
              CSR and wrapped with native CSR matvec/matmat/rmatvec kernels.
            * SciPy sparse matrix (``scipy.sparse``): converted to CSR via
              :func:`~mlx_sparse.from_scipy` then wrapped.
            * ``(shape, matvec)`` or ``(shape, matvec, matmat)`` tuple: the
              callables are stored directly with no conversion.

    Returns:
        A :class:`LinearOperator` instance.

    Raises:
        TypeError: If ``A`` is not one of the accepted types, or is a tuple
            whose length is not 2 or 3.
    """
    if isinstance(A, LinearOperator):
        return A
    if isinstance(A, CSRArray | COOArray | CSCArray):
        return sparse_operator(A, LinearOperator)
    if isinstance(A, tuple):
        # Reject other lengths rather than silently dropping extra elements
        # (e.g. an rmatvec passed as a fourth entry would otherwise be lost).
        if len(A) not in (2, 3):
            raise TypeError(
                "aslinearoperator tuple input must be (shape, matvec) or "
                f"(shape, matvec, matmat), got a tuple of length {len(A)}."
            )
        shape, matvec, *rest = A
        matmat = rest[0] if rest else None
        return LinearOperator(shape=tuple(shape), matvec_fn=matvec, matmat_fn=matmat)

    # SciPy is an optional dependency; only consult it when it is installed.
    try:
        import scipy.sparse as sp
    except ImportError:
        sp = None

    if sp is not None and sp.issparse(A):
        from mlx_sparse._construct import from_scipy

        return sparse_operator(from_scipy(A), LinearOperator)

    raise TypeError(
        "aslinearoperator accepts LinearOperator, CSRArray, COOArray, CSCArray, "
        "SciPy sparse matrices, or (shape, matvec[, matmat]) tuples."
    )