Containers#

Sparse array containers in mlx-sparse are immutable (frozen) dataclasses. They hold mlx.core.array buffers and format metadata but do not subclass mlx.core.array. Structural operations return new instances; nothing is mutated in place.
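The immutability contract can be illustrated with a plain frozen dataclass. This is a minimal sketch, not the mlx-sparse source: `TinyCSR` is a hypothetical stand-in that holds tuples where the real containers hold mlx.core.array buffers.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class TinyCSR:
    """Hypothetical stand-in for a sparse container (tuples, not MLX arrays)."""

    data: tuple
    indices: tuple
    indptr: tuple
    shape: tuple
    sorted_indices: bool = False


a = TinyCSR(data=(2.0, -1.0), indices=(0, 2), indptr=(0, 2), shape=(1, 3))

# Assigning to a field of a frozen dataclass raises instead of mutating.
try:
    a.sorted_indices = True
    mutated = True
except FrozenInstanceError:
    mutated = False

# A "structural operation" builds a new instance; the original is untouched
# and the unchanged buffers are shared rather than copied.
b = replace(a, sorted_indices=True)
```

Here `b` is a distinct object with the updated flag, while `a.sorted_indices` is still False and `b.data` is the same object as `a.data`.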

CSRArray#

class mlx_sparse.CSRArray(data, indices, indptr, shape, sorted_indices=False, has_canonical_format=False)[source]#

Bases: object

A 2D sparse matrix in Compressed Sparse Row (CSR) format.

CSRArray stores a 2D sparse matrix using three MLX arrays:

  • data: non-zero values, shape (nnz,).

  • indices: column index of each stored value, shape (nnz,).

  • indptr: row pointer array, shape (n_rows + 1,).

Row i spans data[indptr[i] : indptr[i+1]] with corresponding column indices indices[indptr[i] : indptr[i+1]]. Duplicate column entries are permitted unless the matrix is in canonical form.
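The row-slicing rule can be made concrete with a reference matrix-vector product over plain Python lists. This is an illustrative sketch of the layout only; `csr_matvec` here is a hypothetical helper, not the library's native kernel.

```python
def csr_matvec(data, indices, indptr, x):
    """Reference y = A @ x for a CSR matrix stored as plain Python lists."""
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for i in range(n_rows):
        # Row i owns the half-open slice [indptr[i], indptr[i + 1]).
        for k in range(indptr[i], indptr[i + 1]):
            # A duplicated column index just contributes a second term.
            y[i] += data[k] * x[indices[k]]
    return y


# 3x4 matrix whose middle row is empty (indptr[1] == indptr[2]).
data = [2.0, -1.0, 4.0, 5.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
y = csr_matvec(data, indices, indptr, [1.0, 0.0, 1.0, 1.0])
# y == [1.0, 0.0, 5.0]
```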

Format invariants (checked by default, i.e. with validate="metadata"):

  • All three arrays must be rank-1.

  • data.shape[0] == indices.shape[0] (the nnz count).

  • indptr.shape[0] == n_rows + 1.

  • indices and indptr share the same integer dtype (int32 or int64).

  • data dtype is one of float32, float16, bfloat16, or complex64.

Additional value-level invariants (validate="full" only):

  • indptr[0] == 0, indptr[-1] == nnz.

  • indptr is monotonically nondecreasing.

  • 0 <= indices[j] < n_cols for all stored values.
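The split between the two validation levels can be sketched as follows. `check_csr` is a hypothetical helper written for plain lists, so the rank and dtype rules are left out; how mlx-sparse itself performs these checks is not shown in this page.

```python
def check_csr(data, indices, indptr, shape, validate="metadata"):
    """Hypothetical checker for the invariants above, using plain lists."""
    n_rows, n_cols = shape
    # Metadata level: compare lengths only; no stored value is read.
    if len(data) != len(indices):
        raise ValueError("data and indices must have the same length")
    if len(indptr) != n_rows + 1:
        raise ValueError("indptr must have n_rows + 1 entries")
    if validate != "full":
        return
    # Value level: every pointer and every index is inspected.
    if indptr[0] != 0 or indptr[-1] != len(data):
        raise ValueError("indptr must start at 0 and end at nnz")
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be nondecreasing")
    if any(not 0 <= j < n_cols for j in indices):
        raise ValueError("column index out of range")


# Column index 7 is invalid for n_cols == 3, yet all lengths are consistent.
bad = ([1.0, 2.0], [0, 7], [0, 1, 2], (2, 3))
check_csr(*bad)  # the metadata-level check passes
try:
    check_csr(*bad, validate="full")
    full_error = None
except ValueError as exc:
    full_error = str(exc)
```

The example shows why the value-level checks are opt-in in this sketch: they have to read every stored index, whereas the metadata checks only look at array lengths.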

CSRArray is an immutable (frozen) dataclass, so structural operations return new instances. The sorted_indices and has_canonical_format flags are metadata hints: set them only when the input is already known to satisfy those properties. To bring arbitrary input into canonical form, use canonicalize(), which sorts indices and sums duplicates.

Parameters:
  • data (mlx.core.array) – Non-zero values, shape (nnz,).

  • indices (mlx.core.array) – Column indices, shape (nnz,).

  • indptr (mlx.core.array) – Row pointer array, shape (n_rows + 1,).

  • shape (tuple[int, int]) – Matrix dimensions as (n_rows, n_cols).

  • sorted_indices (bool) – Hint that column indices within each row are sorted ascending. Defaults to False.

  • has_canonical_format (bool) – Hint that the matrix has sorted column indices and no duplicate column index in any row. Implies sorted_indices=True. Defaults to False.

Example:

import mlx.core as mx
import mlx_sparse as ms

data = mx.array([2.0, -1.0, 4.0, 5.0], dtype=mx.float32)
indices = mx.array([0, 2, 1, 3], dtype=mx.int32)
indptr = mx.array([0, 2, 2, 4], dtype=mx.int32)
A = ms.csr_array((data, indices, indptr), shape=(3, 4))
x = mx.array([1.0, 0.0, 1.0, 1.0], dtype=mx.float32)
y = A @ x  # shape (3,); values [1.0, 0.0, 5.0] (row 1 stores no entries)

Properties

  • nnz – Number of stored values (including any duplicates).

  • dtype – Value dtype of the stored non-zeros (e.g. mlx.core.float32).

  • index_dtype – Integer dtype used for indices and indptr.

  • ndim – Always 2.

  • T – Transposed matrix.

  • H – Hermitian (conjugate) transpose.

Methods

  • todense() – Materialize the sparse matrix as a dense MLX array.

  • tocsc(*[, canonical]) – Convert to CSCArray.

  • sort_indices() – Return a new CSRArray with column indices sorted within each row.

  • sum_duplicates() – Sum duplicate column entries within each row.

  • canonicalize() – Return the canonical form: sorted indices, no duplicates.

  • transpose() – Transpose the sparse matrix, returning a new CSRArray.

  • conj() – Complex-conjugate the stored values.

  • conjugate() – Alias for conj().

data: mlx.core.array#
indices: mlx.core.array#
indptr: mlx.core.array#
shape: tuple[int, int]#
sorted_indices: bool#
has_canonical_format: bool#
property nnz: int#

Number of stored values (including any duplicates).

property dtype#

Value dtype of the stored non-zeros (e.g. mlx.core.float32).

property index_dtype#

Integer dtype used for indices and indptr.

property ndim: int#

Always 2. Sparse arrays in this package are rank-2.

todense()[source]#

Materialize the sparse matrix as a dense MLX array.

Duplicate column entries in the same row are summed, matching the semantics of canonicalize().todense().

Returns:

Dense array of shape (n_rows, n_cols) and the same dtype as self.data.

Return type:

mlx.core.array
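The duplicate-summing behaviour of dense materialization can be sketched over plain lists. `csr_todense` below is a hypothetical reference routine, not the library implementation; the key detail is the `+=` accumulation.

```python
def csr_todense(data, indices, indptr, shape):
    """Reference dense materialization of a CSR matrix (nested lists)."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            # Accumulate rather than assign, so duplicates add up.
            dense[i][indices[k]] += data[k]
    return dense


# Row 0 stores column 1 twice (1.0 and 2.5); row 1 stores one entry.
dense = csr_todense([1.0, 2.5, 4.0], [1, 1, 0], [0, 2, 3], (2, 3))
# dense == [[0.0, 3.5, 0.0], [4.0, 0.0, 0.0]]
```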

row_sums()[source]#

Return the sum of stored values in each CSR row.

Return type:

mlx.core.array

col_sums()[source]#

Return the sum of stored values in each CSR column.

Return type:

mlx.core.array

column_sums()[source]#

Alias for col_sums().

Return type:

mlx.core.array

row_norms()[source]#

Return the L2 norm of each CSR row as float32.

Return type:

mlx.core.array

diagonal()[source]#

Return the summed main diagonal.

Return type:

mlx.core.array

trace()[source]#

Return the summed main diagonal as a scalar.

Return type:

mlx.core.array

sum(axis=None)[source]#

Sum sparse values over all entries, rows, or columns.

axis=None returns a scalar, axis=1 returns row sums, and axis=0 returns column sums.

Return type:

mlx.core.array
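The three axis settings correspond to three different reductions over the stored values. The following is a sketch over plain lists with a hypothetical `csr_sum` helper; it assumes a structurally valid CSR input.

```python
def csr_sum(data, indices, indptr, shape, axis=None):
    """Reference reduction over a CSR matrix stored as plain lists."""
    n_rows, n_cols = shape
    if axis is None:
        # Every stored value contributes once to the scalar total.
        return sum(data, 0.0)
    if axis == 1:
        # Row sums: each row is a contiguous slice of data.
        return [sum(data[indptr[i]:indptr[i + 1]], 0.0) for i in range(n_rows)]
    if axis == 0:
        # Column sums: scatter each value into its column's bucket.
        out = [0.0] * n_cols
        for j, v in zip(indices, data):
            out[j] += v
        return out
    raise ValueError("axis must be None, 0, or 1")


args = ([2.0, -1.0, 4.0, 5.0], [0, 2, 1, 3], [0, 2, 2, 4], (3, 4))
total = csr_sum(*args)              # 10.0
row_sums = csr_sum(*args, axis=1)   # [1.0, 0.0, 9.0]
col_sums = csr_sum(*args, axis=0)   # [2.0, 4.0, -1.0, 5.0]
```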

sort_indices()[source]#

Return a new CSRArray with column indices sorted within each row.

If self.sorted_indices is already True, returns self unchanged (no copy). Otherwise dispatches the native sort primitive.

Returns:

A CSRArray with sorted_indices=True. When a sort is actually performed, the result is a new instance with has_canonical_format=False (duplicates may still be present).

Return type:

CSRArray

sum_duplicates()[source]#

Sum duplicate column entries within each row.

Sorts indices first (via sort_indices), then accumulates entries that share the same column index. The resulting nnz may be smaller than the original.

Returns:

A new CSRArray with sorted_indices=True and has_canonical_format=True.

Return type:

CSRArray

canonicalize()[source]#

Return the canonical form: sorted indices, no duplicates.

If self.has_canonical_format is already True, returns self with no work done. Otherwise calls sum_duplicates().

Returns:

A CSRArray with has_canonical_format=True.

Return type:

CSRArray
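The sort-then-merge pipeline described by sort_indices(), sum_duplicates(), and canonicalize() can be sketched in pure Python. `canonicalize_csr` is a hypothetical sequential reference over plain lists; the library dispatches native primitives instead.

```python
def canonicalize_csr(data, indices, indptr):
    """Sort column indices within each row, then fold equal neighbours."""
    out_data, out_indices, out_indptr = [], [], [0]
    for i in range(len(indptr) - 1):
        start, stop = indptr[i], indptr[i + 1]
        # Step 1 (sort_indices): order this row's entries by column index.
        row = sorted(zip(indices[start:stop], data[start:stop]), key=lambda e: e[0])
        for col, val in row:
            # Only merge with an entry that belongs to the current row.
            in_row = len(out_indices) > out_indptr[-1]
            if in_row and out_indices[-1] == col:
                # Step 2 (sum_duplicates): repeated column folds into the last entry.
                out_data[-1] += val
            else:
                out_indices.append(col)
                out_data.append(val)
        out_indptr.append(len(out_indices))
    return out_data, out_indices, out_indptr


# Row 0 holds columns [2, 0, 2] (unsorted, column 2 duplicated); row 1 holds [2].
canonical = canonicalize_csr([1.0, 2.0, 3.0, 4.0], [2, 0, 2, 2], [0, 3, 4])
# canonical == ([2.0, 4.0, 4.0], [0, 2, 2], [0, 2, 3])
```

Note that nnz drops from 4 to 3, and that the entry in row 1 is not merged with the equal column index at the end of row 0.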

transpose()[source]#

Transpose the sparse matrix, returning a new CSRArray.

The result has shape=(n_cols, n_rows) and sorted_indices=True. If the source has has_canonical_format=True, the result also inherits that flag.

Returns:

A new CSRArray with shape (n_cols, n_rows).

Return type:

CSRArray

tocsc(*, canonical=None)[source]#

Convert to CSCArray.

Parameters:

canonical (bool | None) – If True, canonicalize the returned CSC matrix. If False or None (default), return the structural conversion as produced by the native count/prefix/fill path. The structural path preserves values but does not promise sorted output metadata on every backend.
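A count/prefix/fill conversion can be sketched sequentially as below. This is an assumption about the general technique the parameter description names, written for plain lists with a hypothetical `csr_to_csc` helper; the native path may differ in detail.

```python
def csr_to_csc(data, indices, indptr, shape):
    """Reference CSR -> CSC conversion in three passes over plain lists."""
    n_rows, n_cols = shape
    # Count: number of stored entries in each column.
    counts = [0] * n_cols
    for j in indices:
        counts[j] += 1
    # Prefix: a running sum turns counts into column start offsets.
    col_ptr = [0] * (n_cols + 1)
    for j in range(n_cols):
        col_ptr[j + 1] = col_ptr[j] + counts[j]
    # Fill: place each entry in the next free slot of its column.
    next_slot = col_ptr[:-1]  # slicing already yields a fresh list
    out_data = [0.0] * len(data)
    out_rows = [0] * len(data)
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            dst = next_slot[indices[k]]
            out_data[dst] = data[k]
            out_rows[dst] = i
            next_slot[indices[k]] += 1
    return out_data, out_rows, col_ptr


csc = csr_to_csc([1.0, 2.0, 3.0, 4.0], [2, 0, 0, 2], [0, 2, 3, 4], (3, 3))
# csc == ([2.0, 3.0, 1.0, 4.0], [0, 1, 0, 2], [0, 2, 2, 4])
```

In this sequential version the fill pass visits rows in order, so row indices within each column come out sorted. A parallel fill need not preserve that order, which is one plausible reason the structural path makes no promise about sorted output.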

property T: CSRArray#

Transposed matrix. Alias for transpose().

conj()[source]#

Complex-conjugate the stored values.

Structure (indices, indptr, shape) is shared with the original. For real dtypes the values are unchanged, but the method still returns a new CSRArray wrapping the conjugated data array.

Returns:

A new CSRArray with conjugated data.

Return type:

CSRArray

conjugate()[source]#

Alias for conj().

Return type:

CSRArray

property H: CSRArray#

Hermitian (conjugate) transpose. Equivalent to conj().T.

vdot(other)[source]#

Sparse Frobenius inner product with another sparse array.

Both operands are brought to canonical CSR form, and the matching-column merge is executed by the native sparse kernel. No dense matrix is materialized.

Return type:

mlx.core.array

dot(other)[source]#

Sparse Frobenius dot product with another sparse array.

Unlike vdot(), complex operands are not conjugated.

Return type:

mlx.core.array
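The matching-column merge behind vdot() and dot() can be sketched as a two-pointer walk over canonical rows. `frobenius` is a hypothetical helper over plain lists; the choice to conjugate the first operand when `conjugate=True` follows the NumPy vdot convention and is an assumption here, since this page does not say which operand is conjugated.

```python
def frobenius(a, b, conjugate=False):
    """Sum of elementwise products of two canonical CSR matrices.

    a and b are (data, indices, indptr) triples of equal shape.
    """
    (da, ia, pa), (db, ib, pb) = a, b
    total = 0
    for i in range(len(pa) - 1):
        p, q = pa[i], pb[i]
        # Walk both sorted rows; only columns present in both contribute.
        while p < pa[i + 1] and q < pb[i + 1]:
            if ia[p] < ib[q]:
                p += 1
            elif ia[p] > ib[q]:
                q += 1
            else:
                lhs = da[p].conjugate() if conjugate else da[p]
                total += lhs * db[q]
                p += 1
                q += 1
    return total


a = ([1j, 2 + 0j, 3 + 0j], [0, 2, 1], [0, 2, 3])
b = ([1j, 5 + 0j, 2 + 0j], [0, 1, 1], [0, 2, 3])
plain = frobenius(a, b)                        # 1j*1j + 3*2 = 5
conjugated = frobenius(a, b, conjugate=True)   # conj(1j)*1j + 3*2 = 7
```

The merge relies on sorted, duplicate-free rows, which is why both operands need to be in canonical form first.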

__matmul__(rhs)[source]#

Matrix multiplication via the @ operator.

Dispatches to csr_matmat() for CSR operands, csr_matvec() for a rank-1 dense rhs, or csr_matmul() for rank-2 and batched dense operands. Dense inputs are converted to MLX arrays if needed.

Parameters:

rhs – CSR sparse matrix, dense vector of shape (n_cols,), dense matrix of shape (n_cols, k), or batched dense matrix with sparse dimension at rhs.shape[-2].

Returns:

A CSRArray for CSR RHS, otherwise a dense MLX array.

Raises:
  • ValueError – If a dense rhs has fewer than one dimension (rhs.ndim < 1).

  • TypeError – If rhs dtype does not match self.data dtype.

__rmul__(other)[source]#

Multiply a number by the current CSRArray using the * operator, with the number on the left (number * array).

This returns a new CSRArray whose data is multiplied by the number; the current CSRArray is not mutated in place.

Parameters:

other – A valid number (complex or not).

Returns:

A new CSRArray with the data multiplied by the number.

Raises:

TypeError – If other is not an actual number.

__mul__(other)[source]#

Multiply the current CSRArray by a number using the * operator, with the array on the left (array * number).

This returns a new CSRArray whose data is multiplied by the number; the current CSRArray is not mutated in place.

Parameters:

other – A valid number (complex or not).

Returns:

A new CSRArray with the data multiplied by the number.

Raises:

TypeError – If other is not an actual number.
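The scalar-multiplication contract of `__mul__` and `__rmul__` can be sketched on a frozen dataclass. `TinyCSR` is a hypothetical stand-in with tuple buffers, and the `numbers.Number` check is an assumption about what counts as "an actual number".

```python
from dataclasses import dataclass, replace
from numbers import Number


@dataclass(frozen=True)
class TinyCSR:
    """Hypothetical stand-in for a CSR container (tuples, not MLX arrays)."""

    data: tuple
    indices: tuple
    indptr: tuple
    shape: tuple

    def __mul__(self, other):
        if not isinstance(other, Number):
            raise TypeError(f"expected a number, got {type(other).__name__}")
        # Only the values change; indices, indptr, and shape are reused.
        return replace(self, data=tuple(v * other for v in self.data))

    # number * A: Python falls back to A.__rmul__(number).
    __rmul__ = __mul__


a = TinyCSR(data=(2.0, -1.0), indices=(0, 2), indptr=(0, 2), shape=(1, 3))
b = 3 * a      # uses __rmul__
c = a * 0.5    # uses __mul__

try:
    a * "2"
    rejected = False
except TypeError:
    rejected = True
```

After this runs, `b.data` is `(6.0, -3.0)`, `c.data` is `(1.0, -0.5)`, and `a.data` is unchanged.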

CSCArray#

class mlx_sparse.CSCArray(data, indices, indptr, shape, sorted_indices=False, has_canonical_format=False)[source]#

Bases: object

A 2D sparse matrix in Compressed Sparse Column (CSC) format.

CSCArray stores a 2D sparse matrix using three MLX arrays:

  • data: non-zero values, shape (nnz,).

  • indices: row index of each stored value, shape (nnz,).

  • indptr: column pointer array, shape (n_cols + 1,).

Column j spans data[indptr[j] : indptr[j+1]] with corresponding row indices indices[indptr[j] : indptr[j+1]]. Duplicate row entries within a column are permitted unless the matrix is in canonical form.

CSC is the column-compressed dual of CSR. It is the natural layout for operations that consume one full column at a time, such as transpose matvec, column-oriented canonicalization, and future direct factorization kernels.
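The column-at-a-time access pattern can be illustrated with a transpose matvec over plain lists. `csc_rmatvec` is a hypothetical reference helper, not a function of the library.

```python
def csc_rmatvec(data, indices, indptr, x):
    """Reference y = A.T @ x for a CSC matrix stored as plain Python lists."""
    n_cols = len(indptr) - 1
    y = [0.0] * n_cols
    for j in range(n_cols):
        # Column j owns the half-open slice [indptr[j], indptr[j + 1]),
        # so each output element is a dot product over one contiguous run.
        for k in range(indptr[j], indptr[j + 1]):
            y[j] += data[k] * x[indices[k]]
    return y


# A = [[2, 0, -1],
#      [0, 4,  0]]  stored column by column.
y = csc_rmatvec([2.0, 4.0, -1.0], [0, 1, 0], [0, 1, 2, 3], [1.0, 10.0])
# y == [2.0, 40.0, -1.0]
```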

Parameters:
  • data (mlx.core.array) – Non-zero values, shape (nnz,).

  • indices (mlx.core.array) – Row indices, shape (nnz,).

  • indptr (mlx.core.array) – Column pointer array, shape (n_cols + 1,).

  • shape (tuple[int, int]) – Matrix dimensions as (n_rows, n_cols).

  • sorted_indices (bool) – Hint that row indices within each column are sorted ascending. Defaults to False.

  • has_canonical_format (bool) – Hint that the matrix has sorted row indices and no duplicate row index in any column. Implies sorted_indices=True. Defaults to False.

Properties

  • nnz – Number of stored values (including any duplicates).

  • dtype – Value dtype of the stored non-zeros.

  • index_dtype – Integer dtype used for indices and indptr.

  • ndim – Always 2.

  • T – Transposed matrix.

  • H – Hermitian (conjugate) transpose.

Methods

  • todense() – Materialize the sparse matrix as a dense MLX array.

  • tocsr(*[, canonical]) – Convert to CSRArray.

  • sort_indices() – Return a new CSCArray with row indices sorted within each column.

  • sum_duplicates() – Sum duplicate row entries within each column.

  • canonicalize() – Return canonical form: sorted row indices, no duplicates.

  • transpose() – Transpose the sparse matrix, returning a zero-copy CSRArray.

  • conj() – Complex-conjugate the stored values.

  • conjugate() – Alias for conj().

data: mlx.core.array#
indices: mlx.core.array#
indptr: mlx.core.array#
shape: tuple[int, int]#
sorted_indices: bool#
has_canonical_format: bool#
property nnz: int#

Number of stored values (including any duplicates).

property dtype#

Value dtype of the stored non-zeros.

property index_dtype#

Integer dtype used for indices and indptr.

property ndim: int#

Always 2. Sparse arrays in this package are rank-2.

todense()[source]#

Materialize the sparse matrix as a dense MLX array.

Return type:

mlx.core.array

row_sums()[source]#

Return the sum of stored values in each CSC row.

Return type:

mlx.core.array

col_sums()[source]#

Return the sum of stored values in each CSC column.

Return type:

mlx.core.array

column_sums()[source]#

Alias for col_sums().

Return type:

mlx.core.array

row_norms()[source]#

Return the dense-semantics L2 norm of each CSC row as float32.

Return type:

mlx.core.array

col_norms()[source]#

Return the dense-semantics L2 norm of each CSC column as float32.

Return type:

mlx.core.array

column_norms()[source]#

Alias for col_norms().

Return type:

mlx.core.array

diagonal()[source]#

Return the summed main diagonal.

Return type:

mlx.core.array

trace()[source]#

Return the summed main diagonal as a scalar.

Return type:

mlx.core.array

sum(axis=None)[source]#

Sum sparse values over all entries, rows, or columns.

axis=None returns a scalar, axis=1 returns row sums, and axis=0 returns column sums.

Return type:

mlx.core.array

tocsr(*, canonical=None)[source]#

Convert to CSRArray.

Parameters:

canonical (bool | None) – If True, canonicalize the returned CSR matrix. If False or None (default), return the structural conversion as produced by the native count/prefix/fill path. The structural path preserves values but does not promise sorted output metadata on every backend.

Return type:

CSRArray

sort_indices()[source]#

Return a new CSCArray with row indices sorted within each column.

Return type:

CSCArray

sum_duplicates()[source]#

Sum duplicate row entries within each column.

Return type:

CSCArray

canonicalize()[source]#

Return canonical form: sorted row indices, no duplicates.

Return type:

CSCArray

transpose()[source]#

Transpose the sparse matrix, returning a zero-copy CSRArray.

Return type:

CSRArray
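The transpose can be zero-copy because the CSC buffers of A, read with the CSR rule and the shape swapped, already describe A transposed. The sketch below checks this with two hypothetical dense-materialization helpers over plain lists.

```python
def csr_todense(data, indices, indptr, shape):
    """Dense form of a CSR matrix: indptr delimits rows, indices are columns."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            dense[i][indices[k]] += data[k]
    return dense


def csc_todense(data, indices, indptr, shape):
    """Dense form of a CSC matrix: indptr delimits columns, indices are rows."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for j in range(n_cols):
        for k in range(indptr[j], indptr[j + 1]):
            dense[indices[k]][j] += data[k]
    return dense


# A = [[2, 0, -1], [0, 4, 0]] in CSC form.
data, indices, indptr = [2.0, 4.0, -1.0], [0, 1, 0], [0, 1, 2, 3]
dense_a = csc_todense(data, indices, indptr, (2, 3))
# The same three buffers, read as CSR with shape (3, 2).
dense_t = csr_todense(data, indices, indptr, (3, 2))
```

`dense_t` equals the transpose of `dense_a`, and no buffer was copied or reordered to get there.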

property T: CSRArray#

Transposed matrix. Alias for transpose().

conj()[source]#

Complex-conjugate the stored values.

Return type:

CSCArray

conjugate()[source]#

Alias for conj().

Return type:

CSCArray

property H: CSRArray#

Hermitian (conjugate) transpose. Equivalent to conj().T.

__matmul__(rhs)[source]#

Matrix multiplication via the @ operator.

__rmul__(other)[source]#

Multiply a number by the current CSCArray using the * operator, with the number on the left (number * array).

This returns a new CSCArray whose data is multiplied by the number; the current CSCArray is not mutated in place.

Parameters:

other – A valid number (complex or not).

Returns:

A new CSCArray with the data multiplied by the number.

Raises:

TypeError – If other is not an actual number.

__mul__(other)[source]#

Multiply the current CSCArray by a number using the * operator, with the array on the left (array * number).

This returns a new CSCArray whose data is multiplied by the number; the current CSCArray is not mutated in place.

Parameters:

other – A valid number (complex or not).

Returns:

A new CSCArray with the data multiplied by the number.

Raises:

TypeError – If other is not an actual number.

COOArray#

class mlx_sparse.COOArray(data, row, col, shape, has_canonical_format=False)[source]#

Bases: object

A 2D sparse matrix in Coordinate (COO) format.

COOArray stores a sparse matrix as three parallel arrays:

  • data: non-zero values, shape (nnz,).

  • row: row coordinate for each value, shape (nnz,).

  • col: column coordinate for each value, shape (nnz,).

COO is the primary construction format. It allows duplicate (row, col) entries and does not require sorted coordinates, making it straightforward to assemble matrices from element lists, graph adjacency lists, finite-element stencils, or Hamiltonians.

COO supports native sparse-dense products directly. For heavily repeated row-oriented workloads, CSR may still be preferable after construction because its compressed row layout avoids coordinate scatter.
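The "coordinate scatter" mentioned above refers to the access pattern of a COO product: each stored entry adds into an output slot chosen by its row coordinate. A minimal sketch over plain lists, with a hypothetical `coo_matvec` helper:

```python
def coo_matvec(data, row, col, n_rows, x):
    """Reference y = A @ x for a COO matrix stored as plain Python lists."""
    y = [0.0] * n_rows
    for v, i, j in zip(data, row, col):
        # Scatter-add: entries may arrive in any order and may repeat a
        # coordinate, so each one accumulates into y[i].
        y[i] += v * x[j]
    return y


# Unsorted entries with a duplicate at (1, 1): 4.0 and 1.0.
y = coo_matvec(
    data=[4.0, 2.0, -1.0, 1.0],
    row=[1, 0, 0, 1],
    col=[1, 0, 2, 1],
    n_rows=2,
    x=[1.0, 2.0, 3.0],
)
# y == [-1.0, 10.0]
```

With CSR, by contrast, all contributions to one output row are contiguous, so a row can be reduced without scattered writes.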

Format invariants (checked by default, i.e. with validate="metadata"):

  • All three arrays must be rank-1 with the same length.

  • row and col share the same integer dtype (int32 or int64).

  • data dtype is one of float32, float16, bfloat16, or complex64.

Additional value-level invariants (validate="full" only):

  • 0 <= row[i] < n_rows for all entries.

  • 0 <= col[i] < n_cols for all entries.

Parameters:
  • data (mlx.core.array) – Non-zero values, shape (nnz,).

  • row (mlx.core.array) – Row coordinates, shape (nnz,).

  • col (mlx.core.array) – Column coordinates, shape (nnz,).

  • shape (tuple[int, int]) – Matrix dimensions as (n_rows, n_cols).

  • has_canonical_format (bool) – Hint that coordinates are sorted and duplicate-free. Defaults to False.

Example:

import mlx.core as mx
import mlx_sparse as ms

data = mx.array([2.0, -1.0, 4.0], dtype=mx.float32)
row = mx.array([0, 0, 1], dtype=mx.int32)
col = mx.array([0, 2, 1], dtype=mx.int32)

# 2×3 matrix:  [[2, 0, -1],
#               [0, 4,  0]]
coo = ms.coo_array((data, (row, col)), shape=(2, 3))
csr = coo.tocsr(canonical=True)

Properties

  • nnz – Number of stored values (including any duplicates).

  • dtype – Value dtype of the stored non-zeros (e.g. mlx.core.float32).

  • index_dtype – Integer dtype used for row and col.

  • ndim – Always 2.

Methods

  • tocsr(*[, canonical]) – Convert to CSRArray.

  • tocsc(*[, canonical]) – Convert to CSCArray.

  • todense() – Materialize as a dense MLX array.

data: mlx.core.array#
row: mlx.core.array#
col: mlx.core.array#
shape: tuple[int, int]#
has_canonical_format: bool#
property nnz: int#

Number of stored values (including any duplicates).

property dtype#

Value dtype of the stored non-zeros (e.g. mlx.core.float32).

property index_dtype#

Integer dtype used for row and col.

property ndim: int#

Always 2. Sparse arrays in this package are rank-2.

tocsr(*, canonical=False)[source]#

Convert to CSRArray.

Sorts entries by row then column and builds a (n_rows + 1,) row pointer array. Duplicate (row, col) entries are preserved in the raw output. Pass canonical=True to sum them.

Parameters:

canonical (bool) – If True, call canonicalize() on the result to sort indices and sum duplicates. Default False.

Returns:

A CSRArray with sorted_indices=True. If canonical=True, also has_canonical_format=True.

Return type:

CSRArray
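The non-canonical conversion described above (sort by row then column, build the row pointer, keep duplicates) can be sketched over plain lists. `coo_to_csr` is a hypothetical sequential reference, not the library's native path.

```python
def coo_to_csr(data, row, col, n_rows):
    """Reference COO -> CSR conversion that preserves duplicate entries."""
    # Sort entry positions by (row, col); the sort is stable, so duplicates
    # keep their relative order and remain separate entries.
    order = sorted(range(len(data)), key=lambda k: (row[k], col[k]))
    out_data = [data[k] for k in order]
    out_indices = [col[k] for k in order]
    # Count entries per row, then prefix-sum the counts into indptr.
    indptr = [0] * (n_rows + 1)
    for i in row:
        indptr[i + 1] += 1
    for i in range(n_rows):
        indptr[i + 1] += indptr[i]
    return out_data, out_indices, indptr


# Four entries for a 3-row matrix; (1, 1) appears twice and row 2 is empty.
csr = coo_to_csr([4.0, 2.0, -1.0, 1.0], [1, 0, 0, 1], [1, 2, 0, 1], n_rows=3)
# csr == ([-1.0, 2.0, 4.0, 1.0], [0, 2, 1, 1], [0, 2, 4, 4])
```

The output has sorted column indices in every row, but the duplicate at (1, 1) is still stored twice, which matches the described behaviour for canonical=False.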

tocsc(*, canonical=False)[source]#

Convert to CSCArray.

Parameters:

canonical (bool) – If True, canonicalize the returned CSC matrix (sort indices and sum duplicates). Default False.

todense()[source]#

Materialize as a dense MLX array.

Internally converts to CSR and then calls todense(). Duplicate entries are summed.

Returns:

Dense array of shape (n_rows, n_cols) with the same dtype as self.data.

Return type:

mlx.core.array

row_sums()[source]#

Return the sum of stored values in each COO row.

Return type:

mlx.core.array

col_sums()[source]#

Return the sum of stored values in each COO column.

Return type:

mlx.core.array

column_sums()[source]#

Alias for col_sums().

Return type:

mlx.core.array

row_norms()[source]#

Return the dense-semantics L2 norm of each COO row as float32.

Return type:

mlx.core.array

col_norms()[source]#

Return the dense-semantics L2 norm of each COO column as float32.

Return type:

mlx.core.array

column_norms()[source]#

Alias for col_norms().

Return type:

mlx.core.array

diagonal()[source]#

Return the summed main diagonal.

Return type:

mlx.core.array

trace()[source]#

Return the summed main diagonal as a scalar.

Return type:

mlx.core.array

sum(axis=None)[source]#

Sum sparse values over all entries, rows, or columns.

axis=None returns a scalar, axis=1 returns row sums, and axis=0 returns column sums.

Return type:

mlx.core.array

__matmul__(rhs)[source]#

Matrix multiplication via the @ operator.

__rmul__(other)[source]#

Multiply a number by the current COOArray using the * operator, with the number on the left (number * array).

This returns a new COOArray whose data is multiplied by the number; the current COOArray is not mutated in place.

Parameters:

other – A valid number (complex or not).

Returns:

A new COOArray with the data multiplied by the number.

Raises:

TypeError – If other is not an actual number.

__mul__(other)[source]#

Multiply the current COOArray by a number using the * operator, with the array on the left (array * number).

This returns a new COOArray whose data is multiplied by the number; the current COOArray is not mutated in place.

Parameters:

other – A valid number (complex or not).

Returns:

A new COOArray with the data multiplied by the number.

Raises:

TypeError – If other is not an actual number.

Utility functions#

mlx_sparse.issparse(x)[source]#

Return True if x is a recognized mlx-sparse container.

Currently returns True for COOArray, CSRArray, and CSCArray instances. All other objects return False.

Parameters:

x – Any Python object.

Returns:

True if x is a COOArray, CSRArray, or CSCArray.

Return type:

bool

Example:

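import mlx.core as mx
import mlx_sparse as ms

# my_csr is assumed to be an existing CSRArray, e.g. one built with ms.csr_array(...).
ms.issparse(my_csr)  # True
ms.issparse(mx.ones((3, 4)))  # False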