Source code for mlx_sparse._csr

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass

import mlx.core as mx

import mlx_sparse._native as _native
from mlx_sparse._typing import Shape2D, ValidationMode
from mlx_sparse._validation import (
    ensure_mx_array,
    normalize_shape,
    normalize_validation_mode,
    sanitize_scalar,
    validate_csr_metadata,
    validate_csr_values,
)


[docs] @dataclass(frozen=True, slots=True) class CSRArray: """A 2D sparse matrix in Compressed Sparse Row (CSR) format. CSRArray stores a 2D sparse matrix using three MLX arrays: - **data**: non-zero values, shape ``(nnz,)``. - **indices**: column index of each stored value, shape ``(nnz,)``. - **indptr**: row pointer array, shape ``(n_rows + 1,)``. Row ``i`` spans ``data[indptr[i] : indptr[i+1]]`` with corresponding column indices ``indices[indptr[i] : indptr[i+1]]``. Duplicate column entries are permitted unless the matrix is in canonical form. **Format invariants** (checked by ``validate="metadata"`` by default): - All three arrays must be rank-1. - ``data.shape[0] == indices.shape[0]`` (the ``nnz`` count). - ``indptr.shape[0] == n_rows + 1``. - ``indices`` and ``indptr`` share the same integer dtype (``int32`` or ``int64``). - ``data`` dtype is one of ``float32``, ``float16``, ``bfloat16``, or ``complex64``. **Additional value-level invariants** (``validate="full"`` only): - ``indptr[0] == 0``, ``indptr[-1] == nnz``. - ``indptr`` is monotonically nondecreasing. - ``0 <= indices[j] < n_cols`` for all stored values. ``CSRArray`` is immutable (frozen dataclass). Structural operations return new instances. The ``sorted_indices`` and ``has_canonical_format`` flags are metadata hints. Set them only when the input is already known to satisfy those properties. Use ``canonicalize()`` to sort and sum duplicates. Args: data: Non-zero values, shape ``(nnz,)``. indices: Column indices, shape ``(nnz,)``. indptr: Row pointer array, shape ``(n_rows + 1,)``. shape: Matrix dimensions as ``(n_rows, n_cols)``. sorted_indices: Hint that column indices within each row are sorted ascending. Defaults to ``False``. has_canonical_format: Hint that the matrix has sorted column indices and no duplicate column index in any row. Implies ``sorted_indices=True``. Defaults to ``False``. 
Example:: import mlx.core as mx import mlx_sparse as ms data = mx.array([2.0, -1.0, 4.0, 5.0], dtype=mx.float32) indices = mx.array([0, 2, 1, 3], dtype=mx.int32) indptr = mx.array([0, 2, 2, 4], dtype=mx.int32) A = ms.csr_array((data, indices, indptr), shape=(3, 4)) x = mx.array([1.0, 0.0, 1.0, 1.0], dtype=mx.float32) y = A @ x # shape (3,) """ data: mx.array indices: mx.array indptr: mx.array shape: Shape2D sorted_indices: bool = False has_canonical_format: bool = False def __post_init__(self) -> None: object.__setattr__(self, "shape", normalize_shape(self.shape)) @property def nnz(self) -> int: """Number of stored values (including any duplicates).""" return int(self.data.shape[0]) @property def dtype(self): """Value dtype of the stored non-zeros (e.g. ``mlx.core.float32``).""" return self.data.dtype @property def index_dtype(self): """Integer dtype used for ``indices`` and ``indptr``.""" return self.indices.dtype @property def ndim(self) -> int: """Always 2. Sparse arrays in this package are rank-2.""" return 2 def __repr__(self) -> str: return ( "CSRArray(" f"shape={self.shape}, nnz={self.nnz}, dtype={self.dtype}, " f"index_dtype={self.index_dtype}, " f"sorted_indices={self.sorted_indices}, " f"has_canonical_format={self.has_canonical_format})" )
[docs] def todense(self) -> mx.array: """Materialize the sparse matrix as a dense MLX array. Duplicate column entries in the same row are summed, matching the semantics of ``canonicalize().todense()``. Returns: Dense array of shape ``(n_rows, n_cols)`` and the same dtype as ``self.data``. """ return _native.csr_todense(self.data, self.indices, self.indptr, self.shape)
[docs] def row_sums(self) -> mx.array: """Return the sum of stored values in each CSR row.""" return _native.csr_row_sums(self.data, self.indices, self.indptr, self.shape)
[docs] def col_sums(self) -> mx.array: """Return the sum of stored values in each CSR column.""" return _native.csr_col_sums(self.data, self.indices, self.indptr, self.shape)
[docs] def column_sums(self) -> mx.array: """Alias for :meth:`col_sums`.""" return self.col_sums()
[docs] def row_norms(self) -> mx.array: """Return the L2 norm of each CSR row as ``float32``.""" array = self if self.has_canonical_format else self.canonicalize() return _native.csr_row_norms( array.data, array.indices, array.indptr, array.shape, )
[docs] def diagonal(self) -> mx.array: """Return the summed main diagonal.""" return _native.csr_diagonal(self.data, self.indices, self.indptr, self.shape)
[docs] def trace(self) -> mx.array: """Return the summed main diagonal as a scalar.""" return _native.csr_trace(self.data, self.indices, self.indptr, self.shape)
[docs] def sum(self, axis=None) -> mx.array: """Sum sparse values over all entries, rows, or columns. ``axis=None`` returns a scalar, ``axis=1`` returns row sums, and ``axis=0`` returns column sums. """ if axis is None: return mx.sum(self.row_sums()) if axis in (1, -1): return self.row_sums() if axis in (0, -2): return self.col_sums() raise ValueError(f"CSRArray.sum axis must be None, 0, or 1; got {axis!r}.")
[docs] def sort_indices(self) -> "CSRArray": """Return a new CSRArray with column indices sorted within each row. If ``self.sorted_indices`` is already ``True``, returns ``self`` unchanged (no copy). Otherwise dispatches the native sort primitive. Returns: A new ``CSRArray`` with ``sorted_indices=True`` and ``has_canonical_format=False`` (duplicates may still be present). """ if self.sorted_indices: return self data, indices, indptr = _native.csr_sort_indices( self.data, self.indices, self.indptr, ) return CSRArray( data=data, indices=indices, indptr=indptr, shape=self.shape, sorted_indices=True, has_canonical_format=False, )
[docs] def sum_duplicates(self) -> "CSRArray": """Sum duplicate column entries within each row. Sorts indices first (via ``sort_indices``), then accumulates entries that share the same column index. The resulting ``nnz`` may be smaller than the original. Returns: A new ``CSRArray`` with ``sorted_indices=True`` and ``has_canonical_format=True``. """ sorted_self = self.sort_indices() data, indices, indptr = _native.csr_sum_duplicates( sorted_self.data, sorted_self.indices, sorted_self.indptr, ) return CSRArray( data=data, indices=indices, indptr=indptr, shape=self.shape, sorted_indices=True, has_canonical_format=True, )
[docs] def canonicalize(self) -> "CSRArray": """Return the canonical form: sorted indices, no duplicates. If ``self.has_canonical_format`` is already ``True``, returns ``self`` with no work done. Otherwise calls ``sum_duplicates()``. Returns: A ``CSRArray`` with ``has_canonical_format=True``. """ if self.has_canonical_format: return self return self.sum_duplicates()
[docs] def transpose(self) -> "CSRArray": """Transpose the sparse matrix, returning a new CSRArray. The result has ``shape=(n_cols, n_rows)`` and ``sorted_indices=True``. If the source has ``has_canonical_format=True``, the result also inherits that flag. Returns: A new ``CSRArray`` with shape ``(n_cols, n_rows)``. """ data, indices, indptr = _native.csr_transpose( self.data, self.indices, self.indptr, self.shape, ) return CSRArray( data=data, indices=indices, indptr=indptr, shape=(self.shape[1], self.shape[0]), sorted_indices=True, has_canonical_format=self.has_canonical_format, )
[docs] def tocsc(self, *, canonical: bool | None = None): """Convert to :class:`~mlx_sparse.CSCArray`. Args: canonical: If ``True``, canonicalize the returned CSC matrix. If ``False`` or ``None`` (default), return the structural conversion as produced by the native count/prefix/fill path. The structural path preserves values but does not promise sorted output metadata on every backend. """ from mlx_sparse._csc import CSCArray data, indices, indptr = _native.csr_tocsc( self.data, self.indices, self.indptr, self.shape, ) out = CSCArray( data=data, indices=indices, indptr=indptr, shape=self.shape, sorted_indices=False, has_canonical_format=False, ) if canonical is True: return out.canonicalize() if canonical is False: return CSCArray( data=out.data, indices=out.indices, indptr=out.indptr, shape=out.shape, sorted_indices=out.sorted_indices, has_canonical_format=False, ) return out
@property def T(self) -> "CSRArray": """Transposed matrix. Alias for :meth:`transpose`.""" return self.transpose()
[docs] def conj(self) -> "CSRArray": """Complex-conjugate the stored values. Structure (indices, indptr, shape) is shared. For real dtypes this is a no-op at the value level but still returns a new ``CSRArray`` pointing to the conjugated data array. Returns: A new ``CSRArray`` with conjugated ``data``. """ return CSRArray( data=mx.conjugate(self.data), indices=self.indices, indptr=self.indptr, shape=self.shape, sorted_indices=self.sorted_indices, has_canonical_format=self.has_canonical_format, )
[docs] def conjugate(self) -> "CSRArray": """Alias for :meth:`conj`.""" return self.conj()
@property def H(self) -> "CSRArray": """Hermitian (conjugate) transpose. Equivalent to ``conj().T``.""" return self.conj().transpose()
[docs] def vdot(self, other) -> mx.array: """Sparse Frobenius inner product with another sparse array. Both operands are canonicalized CSR matrices and the matching-column merge is executed by the native sparse kernel. Dense materialization is never used. """ rhs = other if isinstance(rhs, CSRArray): rhs_csr = rhs.canonicalize() else: from mlx_sparse._coo import COOArray if not isinstance(rhs, COOArray): raise TypeError( f"CSRArray.vdot expects CSRArray or COOArray, got {type(rhs).__name__}." ) rhs_csr = rhs.tocsr(canonical=True) if rhs_csr.shape != self.shape: raise ValueError( f"vdot shape mismatch: got {self.shape} and {rhs_csr.shape}." ) lhs = self.canonicalize() if lhs.data.dtype in {mx.float16, mx.bfloat16}: lhs = CSRArray( data=lhs.data.astype(mx.float32), indices=lhs.indices, indptr=lhs.indptr, shape=lhs.shape, sorted_indices=lhs.sorted_indices, has_canonical_format=lhs.has_canonical_format, ) if rhs_csr.data.dtype in {mx.float16, mx.bfloat16}: rhs_csr = CSRArray( data=rhs_csr.data.astype(mx.float32), indices=rhs_csr.indices, indptr=rhs_csr.indptr, shape=rhs_csr.shape, sorted_indices=rhs_csr.sorted_indices, has_canonical_format=rhs_csr.has_canonical_format, ) if lhs.data.dtype != rhs_csr.data.dtype: raise TypeError( "CSRArray.vdot operands must have the same dtype after low-precision promotion." ) if lhs.data.dtype not in {mx.float32, mx.complex64}: raise TypeError("CSRArray.vdot supports float32 and complex64 data.") return _native.csr_vdot( lhs.data, lhs.indices, lhs.indptr, rhs_csr.data, rhs_csr.indices, rhs_csr.indptr, lhs.shape, )
[docs] def dot(self, other) -> mx.array: """Sparse Frobenius dot product with another sparse array. Unlike :meth:`vdot`, complex operands are not conjugated. """ rhs = other if isinstance(rhs, CSRArray): rhs_csr = rhs.canonicalize() else: from mlx_sparse._coo import COOArray if not isinstance(rhs, COOArray): raise TypeError( f"CSRArray.dot expects CSRArray or COOArray, got {type(rhs).__name__}." ) rhs_csr = rhs.tocsr(canonical=True) if rhs_csr.shape != self.shape: raise ValueError( f"dot shape mismatch: got {self.shape} and {rhs_csr.shape}." ) lhs = self.canonicalize() if lhs.data.dtype in {mx.float16, mx.bfloat16}: lhs = CSRArray( data=lhs.data.astype(mx.float32), indices=lhs.indices, indptr=lhs.indptr, shape=lhs.shape, sorted_indices=lhs.sorted_indices, has_canonical_format=lhs.has_canonical_format, ) if rhs_csr.data.dtype in {mx.float16, mx.bfloat16}: rhs_csr = CSRArray( data=rhs_csr.data.astype(mx.float32), indices=rhs_csr.indices, indptr=rhs_csr.indptr, shape=rhs_csr.shape, sorted_indices=rhs_csr.sorted_indices, has_canonical_format=rhs_csr.has_canonical_format, ) if lhs.data.dtype != rhs_csr.data.dtype: raise TypeError( "CSRArray.dot operands must have the same dtype after low-precision promotion." ) if lhs.data.dtype not in {mx.float32, mx.complex64}: raise TypeError("CSRArray.dot supports float32 and complex64 data.") return _native.csr_dot( lhs.data, lhs.indices, lhs.indptr, rhs_csr.data, rhs_csr.indices, rhs_csr.indptr, lhs.shape, )
[docs] def __matmul__(self, rhs): """Matrix multiplication via the ``@`` operator. Dispatches to :func:`~mlx_sparse.csr_matmat` for CSR operands, :func:`~mlx_sparse.csr_matvec` for a rank-1 dense ``rhs``, or :func:`~mlx_sparse.csr_matmul` for rank-2 and batched dense operands. Dense inputs are converted to MLX arrays if needed. Args: rhs: CSR sparse matrix, dense vector of shape ``(n_cols,)``, dense matrix of shape ``(n_cols, k)``, or batched dense matrix with sparse dimension at ``rhs.shape[-2]``. Returns: A CSRArray for CSR RHS, otherwise a dense MLX array. Raises: ValueError: If dense ``rhs.ndim`` is not at least 1. TypeError: If ``rhs`` dtype does not match ``self.data`` dtype. """ if isinstance(rhs, CSRArray): from mlx_sparse._ops import csr_matmat return csr_matmat(self, rhs) from mlx_sparse._coo import COOArray if isinstance(rhs, COOArray): from mlx_sparse._ops import csr_matmat return csr_matmat(self, rhs.tocsr(canonical=True)) rhs = ensure_mx_array(rhs) if rhs.ndim == 1: from mlx_sparse._ops import csr_matvec return csr_matvec(self, rhs) if rhs.ndim >= 2: from mlx_sparse._ops import csr_matmul return csr_matmul(self, rhs) raise ValueError(f"CSR matmul expects rank-1 or higher RHS, got {rhs.shape}.")
[docs] def __rmul__(self, other): """Multiply the current CSRArray by a number using the ``*`` operator. This returns a new CSRArray with the data multiplied by the number, and therefore does not in-place mutate the current CSRArray. Args: other: A valid number (complex or not). Returns: A new CSRArray with the data multiplied by the number. Raises: TypeError: If ``other`` is not an actual number. """ other = sanitize_scalar(other) return CSRArray( data=other * self.data, indices=self.indices, indptr=self.indptr, shape=self.shape, sorted_indices=self.sorted_indices, has_canonical_format=self.has_canonical_format, )
[docs] def __mul__(self, other): """Multiply the current CSRArray by a number using the ``*`` operator. This returns a new CSRArray with the data multiplied by the number, and therefore does not in-place mutate the current CSRArray. Args: other: A valid number (complex or not). Returns: A new CSRArray with the data multiplied by the number. Raises: TypeError: If ``other`` is not an actual number. """ return self.__rmul__(other)
def csr_array(
    arg,
    shape,
    *,
    validate: ValidationMode = "metadata",
    sorted_indices: bool = False,
    canonical: bool | None = None,
) -> CSRArray:
    """Construct a :class:`CSRArray` from explicit CSR buffers.

    Accepts either a ``(data, indices, indptr)`` triple or an existing
    ``CSRArray``. All array inputs are converted to ``mlx.core.array`` if
    they are not already.

    Args:
        arg: A length-3 iterable ``(data, indices, indptr)`` where

            - *data*: non-zero values, shape ``(nnz,)``, dtype
              ``float32 | float16 | bfloat16 | complex64``.
            - *indices*: column indices, shape ``(nnz,)``, dtype
              ``int32 | int64``.
            - *indptr*: row pointers, shape ``(n_rows + 1,)``, same integer
              dtype as *indices*.

            Alternatively, an existing :class:`CSRArray` instance (returned
            unchanged if ``shape`` matches).
        shape: Matrix dimensions as a length-2 sequence ``(n_rows, n_cols)``.
        validate: Validation level, one of:

            - ``"metadata"`` *(default)*: checks ranks, lengths, and dtypes
              without reading array values. Safe to call on device arrays.
            - ``"full"`` / ``True``: also verifies ``indptr`` monotonicity
              and column index bounds. May synchronize to host.
            - ``False`` / ``"none"``: skips all checks.
        sorted_indices: Set to ``True`` to assert that column indices within
            each row are already sorted ascending. Default ``False``.
        canonical: Set to ``True`` to assert canonical format (sorted
            indices, no duplicate columns). Implies ``sorted_indices=True``.
            Default ``None`` (not asserted).

    Returns:
        A :class:`CSRArray` with the given buffers and shape.

    Raises:
        TypeError: If ``arg`` is not a 3-tuple or a ``CSRArray``, or if
            dtype constraints are violated.
        ValueError: If shape or length constraints are violated.

    Example::

        import mlx.core as mx
        import mlx_sparse as ms

        data = mx.array([1.0, 2.0, 3.0], dtype=mx.float32)
        indices = mx.array([0, 1, 0], dtype=mx.int32)
        indptr = mx.array([0, 2, 3], dtype=mx.int32)
        A = ms.csr_array((data, indices, indptr), shape=(2, 2))

        # Full validation from host arrays:
        A_checked = ms.csr_array(
            (data, indices, indptr), shape=(2, 2), validate="full"
        )
    """
    validation_level = normalize_validation_mode(validate)
    expected_shape = normalize_shape(shape)

    # Pass-through: an existing CSRArray only has its shape checked.
    if isinstance(arg, CSRArray):
        if arg.shape == expected_shape:
            return arg
        raise ValueError(
            f"CSRArray shape mismatch: got {arg.shape}, expected {expected_shape}."
        )

    # Any failure to unpack three items is reported as a TypeError, with the
    # underlying exception chained as its cause.
    try:
        values, col_indices, row_ptr = arg
    except Exception as exc:
        raise TypeError(
            "csr_array expects (data, indices, indptr) or a CSRArray instance."
        ) from exc

    values, col_indices, row_ptr = (
        ensure_mx_array(buffer) for buffer in (values, col_indices, row_ptr)
    )

    if validation_level != "none":
        validate_csr_metadata(values, col_indices, row_ptr, expected_shape)
        if validation_level == "full":
            validate_csr_values(
                col_indices, row_ptr, expected_shape, values.shape[0]
            )

    # ``None`` means "not asserted", which is the same as not canonical.
    is_canonical = bool(canonical)
    return CSRArray(
        data=values,
        indices=col_indices,
        indptr=row_ptr,
        shape=expected_shape,
        sorted_indices=True if is_canonical else sorted_indices,
        has_canonical_format=is_canonical,
    )