Source code for mlx_sparse._coo

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass

import mlx.core as mx

import mlx_sparse._native as _native
from mlx_sparse._csr import CSRArray
from mlx_sparse._typing import Shape2D, ValidationMode
from mlx_sparse._validation import (
    ensure_mx_array,
    normalize_shape,
    normalize_validation_mode,
    sanitize_scalar,
    validate_coo_metadata,
    validate_coo_values,
)


@dataclass(frozen=True, slots=True)
class COOArray:
    """A 2D sparse matrix in Coordinate (COO) format.

    COOArray stores a sparse matrix as three parallel arrays:

    - **data**: non-zero values, shape ``(nnz,)``.
    - **row**: row coordinate for each value, shape ``(nnz,)``.
    - **col**: column coordinate for each value, shape ``(nnz,)``.

    COO is the primary *construction* format. It allows duplicate
    ``(row, col)`` entries and does not require sorted coordinates, making it
    straightforward to assemble matrices from element lists, graph adjacency
    lists, finite-element stencils, or Hamiltonians.

    COO supports native sparse-dense products directly. For heavily repeated
    row-oriented workloads, CSR may still be preferable after construction
    because its compressed row layout avoids coordinate scatter.

    .. note::
        Constructing ``COOArray(...)`` directly only normalizes ``shape``
        (see ``__post_init__``); it performs no buffer validation. The
        invariants below are checked by :func:`coo_array`.

    **Format invariants** (checked by ``validate="metadata"`` by default):

    - All three arrays must be rank-1 with the same length.
    - ``row`` and ``col`` share the same integer dtype (``int32`` or
      ``int64``).
    - ``data`` dtype is one of ``float32``, ``float16``, ``bfloat16``, or
      ``complex64``.

    **Additional value-level invariants** (``validate="full"`` only):

    - ``0 <= row[i] < n_rows`` for all entries.
    - ``0 <= col[i] < n_cols`` for all entries.

    Args:
        data: Non-zero values, shape ``(nnz,)``.
        row: Row coordinates, shape ``(nnz,)``.
        col: Column coordinates, shape ``(nnz,)``.
        shape: Matrix dimensions as ``(n_rows, n_cols)``.
        has_canonical_format: Hint that coordinates are sorted and
            duplicate-free. Defaults to ``False``.

    Example::

        import mlx.core as mx
        import mlx_sparse as ms

        data = mx.array([2.0, -1.0, 4.0], dtype=mx.float32)
        row = mx.array([0, 0, 1], dtype=mx.int32)
        col = mx.array([0, 2, 1], dtype=mx.int32)

        # 2×3 matrix: [[2, 0, -1],
        #              [0, 4,  0]]
        coo = ms.coo_array((data, (row, col)), shape=(2, 3))
        csr = coo.tocsr(canonical=True)
    """

    data: mx.array
    row: mx.array
    col: mx.array
    shape: Shape2D
    has_canonical_format: bool = False

    def __post_init__(self) -> None:
        # The dataclass is frozen, so the normalized shape has to be written
        # through ``object.__setattr__`` rather than plain assignment.
        object.__setattr__(self, "shape", normalize_shape(self.shape))

    @property
    def nnz(self) -> int:
        """Number of stored values (including any duplicates)."""
        return int(self.data.shape[0])

    @property
    def dtype(self) -> mx.Dtype:
        """Value dtype of the stored non-zeros (e.g. ``mlx.core.float32``)."""
        return self.data.dtype

    @property
    def index_dtype(self) -> mx.Dtype:
        """Integer dtype used for ``row`` and ``col``."""
        return self.row.dtype

    @property
    def ndim(self) -> int:
        """Always 2. Sparse arrays in this package are rank-2."""
        return 2

    def __repr__(self) -> str:
        return (
            "COOArray("
            f"shape={self.shape}, nnz={self.nnz}, dtype={self.dtype}, "
            f"index_dtype={self.index_dtype}, "
            f"has_canonical_format={self.has_canonical_format})"
        )

    def tocsr(self, *, canonical: bool = False) -> CSRArray:
        """Convert to :class:`CSRArray`.

        Sorts entries by row then column and builds a ``(n_rows + 1,)`` row
        pointer array. Duplicate ``(row, col)`` entries are preserved in the
        raw output. Pass ``canonical=True`` to sum them.

        Args:
            canonical: If ``True``, call :meth:`~CSRArray.canonicalize` on the
                result to sort indices and sum duplicates. Default ``False``.

        Returns:
            A :class:`CSRArray` with ``sorted_indices=True``. If
            ``canonical=True``, also ``has_canonical_format=True``.
        """
        data, indices, indptr = _native.coo_tocsr(
            self.data, self.row, self.col, self.shape
        )
        csr = CSRArray(
            data=data,
            indices=indices,
            indptr=indptr,
            shape=self.shape,
            sorted_indices=True,
            # Duplicates may still be present in the raw conversion output.
            has_canonical_format=False,
        )
        if canonical:
            return csr.canonicalize()
        return csr

    def tocsc(self, *, canonical: bool = False):
        """Convert to :class:`~mlx_sparse.CSCArray`.

        Column-oriented counterpart of :meth:`tocsr`: the native
        ``coo_tocsc`` kernel produces the ``data``/``indices``/``indptr``
        buffers, which are wrapped in a ``CSCArray`` of the same shape.

        Args:
            canonical: If ``True``, return the result of calling
                ``canonicalize()`` on the converted array. Default ``False``.

        Returns:
            A ``CSCArray`` constructed with ``sorted_indices=True`` and
            ``has_canonical_format=False``, or its canonicalized form when
            ``canonical=True``.
        """
        # Imported lazily rather than at module level, like the other
        # function-scope imports in this class.
        from mlx_sparse._csc import CSCArray

        data, indices, indptr = _native.coo_tocsc(
            self.data, self.row, self.col, self.shape
        )
        csc = CSCArray(
            data=data,
            indices=indices,
            indptr=indptr,
            shape=self.shape,
            sorted_indices=True,
            has_canonical_format=False,
        )
        if canonical:
            return csc.canonicalize()
        return csc

    def todense(self) -> mx.array:
        """Materialize as a dense MLX array.

        Internally converts to CSR and then calls :meth:`~CSRArray.todense`.
        Duplicate entries are summed.

        Returns:
            Dense array of shape ``(n_rows, n_cols)`` with the same dtype as
            ``self.data``.
        """
        return self.tocsr(canonical=False).todense()

    def row_sums(self) -> mx.array:
        """Return the sum of stored values in each COO row."""
        from mlx_sparse._ops import coo_row_sums

        return coo_row_sums(self)

    def col_sums(self) -> mx.array:
        """Return the sum of stored values in each COO column."""
        from mlx_sparse._ops import coo_col_sums

        return coo_col_sums(self)

    def column_sums(self) -> mx.array:
        """Alias for :meth:`col_sums`."""
        return self.col_sums()

    def row_norms(self) -> mx.array:
        """Return the dense-semantics L2 norm of each COO row as ``float32``."""
        from mlx_sparse._ops import coo_row_norms

        return coo_row_norms(self)

    def col_norms(self) -> mx.array:
        """Return the dense-semantics L2 norm of each COO column as ``float32``."""
        from mlx_sparse._ops import coo_col_norms

        return coo_col_norms(self)

    def column_norms(self) -> mx.array:
        """Alias for :meth:`col_norms`."""
        return self.col_norms()

    def diagonal(self) -> mx.array:
        """Return the summed main diagonal."""
        from mlx_sparse._ops import coo_diagonal

        return coo_diagonal(self)

    def trace(self) -> mx.array:
        """Return the summed main diagonal as a scalar."""
        from mlx_sparse._ops import coo_trace

        return coo_trace(self)

    def sum(self, axis=None) -> mx.array:
        """Sum sparse values over all entries, rows, or columns.

        ``axis=None`` returns a scalar, ``axis=1`` returns row sums, and
        ``axis=0`` returns column sums. The negative spellings ``-1`` (same
        as ``1``) and ``-2`` (same as ``0``) are accepted as well.

        Args:
            axis: ``None``, ``0``/``-2``, or ``1``/``-1``.

        Returns:
            The total as a scalar array, or the per-row / per-column sums.

        Raises:
            ValueError: If ``axis`` is any other value.
        """
        if axis is None:
            # The grand total is obtained by reducing the per-row sums.
            return mx.sum(self.row_sums())
        if axis in (1, -1):
            return self.row_sums()
        if axis in (0, -2):
            return self.col_sums()
        raise ValueError(f"COOArray.sum axis must be None, 0, or 1; got {axis!r}.")

    def __matmul__(self, rhs):
        """Matrix multiplication via the ``@`` operator.

        Dispatch rules:

        - ``COOArray @ COOArray`` uses the sparse-sparse ``coo_matmat``.
        - A ``CSRArray`` or ``CSCArray`` right-hand side is rejected.
        - Anything else is converted with ``ensure_mx_array``; a rank-1
          operand goes to ``coo_matvec`` and rank-2 or higher to
          ``coo_matmul``.

        Args:
            rhs: A ``COOArray`` or a dense array-like operand.

        Returns:
            The product as returned by the selected kernel.

        Raises:
            NotImplementedError: If ``rhs`` is a ``CSRArray`` or ``CSCArray``.
            ValueError: If the dense operand has rank 0.
        """
        from mlx_sparse._csc import CSCArray
        from mlx_sparse._ops import coo_matmat, coo_matmul, coo_matvec

        if isinstance(rhs, COOArray):
            return coo_matmat(self, rhs)
        if isinstance(rhs, (CSRArray, CSCArray)):
            raise NotImplementedError(
                "Mixed-format COO sparse-sparse matmul is not implemented. "
                "Convert explicitly if another format is acceptable for your workload."
            )
        rhs = ensure_mx_array(rhs)
        if rhs.ndim == 1:
            return coo_matvec(self, rhs)
        if rhs.ndim >= 2:
            return coo_matmul(self, rhs)
        # Only a rank-0 (scalar) operand reaches this point.
        raise ValueError(f"COO matmul expects rank-1 or higher RHS, got {rhs.shape}.")

    def __rmul__(self, other):
        """Multiply the current COOArray by a number (``other * self``).

        This returns a new COOArray with the data multiplied by the number,
        and therefore does not in-place mutate the current COOArray. The
        ``row`` and ``col`` buffers, ``shape`` and ``has_canonical_format``
        are carried over unchanged.

        Args:
            other: A valid number (complex or not).

        Returns:
            A new COOArray with the data multiplied by the number.

        Raises:
            TypeError: If ``other`` is not an actual number.
        """
        other = sanitize_scalar(other)
        return COOArray(
            data=other * self.data,
            row=self.row,
            col=self.col,
            shape=self.shape,
            has_canonical_format=self.has_canonical_format,
        )

    def __mul__(self, other):
        """Multiply the current COOArray by a number (``self * other``).

        This returns a new COOArray with the data multiplied by the number,
        and therefore does not in-place mutate the current COOArray.
        Delegates to :meth:`__rmul__`.

        Args:
            other: A valid number (complex or not).

        Returns:
            A new COOArray with the data multiplied by the number.

        Raises:
            TypeError: If ``other`` is not an actual number.
        """
        return self.__rmul__(other)
def coo_array(
    arg,
    shape,
    *,
    validate: ValidationMode = "metadata",
    canonical: bool | None = None,
) -> COOArray:
    """Build a :class:`COOArray` from coordinate buffers.

    ``arg`` is either a ``(data, (row, col))`` pair or an already constructed
    ``COOArray``. Buffers that are not ``mlx.core.array`` instances are
    converted first.

    Args:
        arg: A ``(data, (row, col))`` pair where

            - *data*: non-zero values, shape ``(nnz,)``, dtype
              ``float32 | float16 | bfloat16 | complex64``.
            - *row*: row coordinates, shape ``(nnz,)``, dtype
              ``int32 | int64``.
            - *col*: column coordinates, shape ``(nnz,)``, same integer dtype
              as *row*.

            Alternatively, an existing :class:`COOArray`, which is handed
            back as-is when its shape equals ``shape``.
        shape: Matrix dimensions as a length-2 sequence
            ``(n_rows, n_cols)``.
        validate: Validation level, one of:

            - ``"metadata"`` *(default)*: checks ranks, lengths, and dtypes.
            - ``"full"`` / ``True``: also verifies coordinate bounds. May
              synchronize to host.
            - ``False`` / ``"none"``: skips all checks.
        canonical: Set to ``True`` to assert the coordinates are sorted and
            duplicate-free. Default ``None`` (not asserted).

    Returns:
        A :class:`COOArray` with the given buffers and shape.

    Raises:
        TypeError: If ``arg`` cannot be unpacked as ``(data, (row, col))``,
            or if dtype constraints are violated.
        ValueError: If shape or length constraints are violated.

    Example::

        import mlx.core as mx
        import mlx_sparse as ms

        data = mx.array([1.0, 2.0, 3.0, 2.0], dtype=mx.float32)
        row = mx.array([0, 0, 1, 0], dtype=mx.int32)
        col = mx.array([0, 1, 0, 0], dtype=mx.int32)

        # Two entries at (0, 0). Summed when converting to CSR.
        coo = ms.coo_array((data, (row, col)), shape=(2, 2))
        csr = coo.tocsr(canonical=True)  # sums duplicates
    """
    level = normalize_validation_mode(validate)
    shape = normalize_shape(shape)

    # Pass-through: an existing COOArray only needs its shape confirmed.
    if isinstance(arg, COOArray):
        if arg.shape != shape:
            raise ValueError(
                f"COOArray shape mismatch: got {arg.shape}, expected {shape}."
            )
        return arg

    # Any failure while destructuring is reported as one TypeError.
    try:
        values, (rows, cols) = arg
    except Exception as exc:
        raise TypeError("coo_array expects (data, (row, col)) or a COOArray.") from exc

    data, row, col = map(ensure_mx_array, (values, rows, cols))

    if level != "none":
        validate_coo_metadata(data, row, col, shape)
        # Bounds checking is layered on top of the metadata checks.
        if level == "full":
            validate_coo_values(row, col, shape)

    return COOArray(
        data=data,
        row=row,
        col=col,
        shape=shape,
        has_canonical_format=bool(canonical),
    )