Containers#
Sparse array containers in mlx-sparse are immutable frozen dataclasses. They
hold mlx.core.array buffers and format metadata but do not subclass
mx.array. Structural operations return new instances and nothing is mutated
in place.
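The frozen-dataclass behaviour described above can be illustrated with a small stand-alone sketch. `ToyContainer` is a hypothetical stand-in written for this example, not the mlx-sparse class, and it uses plain tuples instead of MLX arrays.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ToyContainer:
    """Hypothetical stand-in for a sparse container (not the library class)."""
    data: tuple
    shape: tuple
    sorted_indices: bool = False


a = ToyContainer(data=(2.0, -1.0), shape=(1, 3))

# Assigning to a field of a frozen dataclass raises instead of mutating.
try:
    a.sorted_indices = True
    mutated = True
except FrozenInstanceError:
    mutated = False

# A "structural" update builds a new instance; the original is untouched
# and the unchanged buffer is shared rather than copied.
b = replace(a, sorted_indices=True)
```

The same pattern is why operations on the real containers hand back a new object rather than editing the receiver.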
CSRArray#
- class mlx_sparse.CSRArray(data, indices, indptr, shape, sorted_indices=False, has_canonical_format=False)[source]#
Bases: object
A 2D sparse matrix in Compressed Sparse Row (CSR) format.
CSRArray stores a 2D sparse matrix using three MLX arrays:
data: non-zero values, shape (nnz,).
indices: column index of each stored value, shape (nnz,).
indptr: row pointer array, shape (n_rows + 1,).
Row i spans data[indptr[i] : indptr[i+1]] with corresponding column indices indices[indptr[i] : indptr[i+1]]. Duplicate column entries are permitted unless the matrix is in canonical form.
Format invariants (checked by validate="metadata" by default):
All three arrays must be rank-1.
data.shape[0] == indices.shape[0] (the nnz count).
indptr.shape[0] == n_rows + 1.
indices and indptr share the same integer dtype (int32 or int64).
data dtype is one of float32, float16, bfloat16, or complex64.
Additional value-level invariants (validate="full" only):
indptr[0] == 0, indptr[-1] == nnz.
indptr is monotonically nondecreasing.
0 <= indices[j] < n_cols for all stored values.
CSRArray is immutable (frozen dataclass). Structural operations return new instances. The sorted_indices and has_canonical_format flags are metadata hints. Set them only when the input is already known to satisfy those properties. Use canonicalize() to sort and sum duplicates.
- Parameters:
data (mlx.core.array) – Non-zero values, shape (nnz,).
indices (mlx.core.array) – Column indices, shape (nnz,).
indptr (mlx.core.array) – Row pointer array, shape (n_rows + 1,).
shape (tuple[int, int]) – Matrix dimensions as (n_rows, n_cols).
sorted_indices (bool) – Hint that column indices within each row are sorted ascending. Defaults to False.
has_canonical_format (bool) – Hint that the matrix has sorted column indices and no duplicate column index in any row. Implies sorted_indices=True. Defaults to False.
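The value-level CSR invariants listed for validate="full" can be sketched as a plain-Python check. This is only an illustration over Python lists; `check_csr_values` is a hypothetical helper, and the library's own validation path may differ.

```python
def check_csr_values(indices, indptr, shape):
    """Raise ValueError if the value-level CSR invariants are violated."""
    n_rows, n_cols = shape
    nnz = len(indices)
    if len(indptr) != n_rows + 1:
        raise ValueError("indptr must have n_rows + 1 entries")
    if indptr[0] != 0 or indptr[-1] != nnz:
        raise ValueError("indptr must start at 0 and end at nnz")
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be monotonically nondecreasing")
    if any(not (0 <= j < n_cols) for j in indices):
        raise ValueError("column index out of range")


# A valid 3x4 layout passes silently.
ok = check_csr_values([0, 2, 1, 3], [0, 2, 2, 4], (3, 4))

# Column index 4 does not fit in a matrix with 4 columns.
try:
    check_csr_values([0, 4], [0, 2], (1, 4))
    bad_accepted = True
except ValueError:
    bad_accepted = False
```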
Example:
    import mlx.core as mx
    import mlx_sparse as ms

    data = mx.array([2.0, -1.0, 4.0, 5.0], dtype=mx.float32)
    indices = mx.array([0, 2, 1, 3], dtype=mx.int32)
    indptr = mx.array([0, 2, 2, 4], dtype=mx.int32)
    A = ms.csr_array((data, indices, indptr), shape=(3, 4))

    x = mx.array([1.0, 0.0, 1.0, 1.0], dtype=mx.float32)
    y = A @ x  # shape (3,)
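To make the row-span rule concrete, the following plain-Python sketch walks the same three arrays by hand and forms the matrix-vector product. It does not call mlx-sparse; `row_entries` and `toy_matvec` are hypothetical helpers for illustration only.

```python
data = [2.0, -1.0, 4.0, 5.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
x = [1.0, 0.0, 1.0, 1.0]


def row_entries(i):
    """Return the (column, value) pairs stored for row i."""
    start, stop = indptr[i], indptr[i + 1]
    return list(zip(indices[start:stop], data[start:stop]))


def toy_matvec(vec):
    """Multiply the CSR matrix by a dense vector, one row span at a time."""
    n_rows = len(indptr) - 1
    return [sum((v * vec[j] for j, v in row_entries(i)), 0.0) for i in range(n_rows)]


y = toy_matvec(x)
```

Row 1 has an empty span because indptr[1] == indptr[2], so its output entry is zero.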
Properties
Number of stored values (including any duplicates).
dtype – Value dtype of the stored non-zeros (e.g. mlx.core.float32).
index_dtype – Integer dtype used for indices and indptr.
Always 2.
T – Transposed matrix.
Hermitian (conjugate) transpose.
Methods
todense() – Materialize the sparse matrix as a dense MLX array.
tocsc(*[, canonical]) – Convert to CSCArray.
sort_indices() – Return a new CSRArray with column indices sorted within each row.
sum_duplicates() – Sum duplicate column entries within each row.
canonicalize() – Return the canonical form: sorted indices, no duplicates.
transpose() – Transpose the sparse matrix, returning a new CSRArray.
conj() – Complex-conjugate the stored values.
Alias for conj().
- data: mlx.core.array#
- indices: mlx.core.array#
- indptr: mlx.core.array#
- property dtype#
Value dtype of the stored non-zeros (e.g. mlx.core.float32).
- property index_dtype#
Integer dtype used for indices and indptr.
- todense()[source]#
Materialize the sparse matrix as a dense MLX array.
Duplicate column entries in the same row are summed, matching the semantics of canonicalize().todense().
- Returns:
Dense array of shape (n_rows, n_cols) and the same dtype as self.data.
- Return type:
mlx.core.array
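The duplicate-summing rule can be seen in a few lines of plain Python. This is a sketch of the semantics only, with a hypothetical helper `toy_todense` operating on lists; it says nothing about how the native kernel is implemented.

```python
def toy_todense(data, indices, indptr, shape):
    """Scatter CSR entries into a nested-list dense matrix."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            # Accumulate rather than assign, so duplicates add up.
            dense[i][indices[k]] += data[k]
    return dense


# Row 0 stores column 1 twice (3.0 and 4.0); row 1 stores column 0 once.
dense = toy_todense([3.0, 4.0, 1.0], [1, 1, 0], [0, 2, 3], (2, 2))
```

The two stored values at position (0, 1) collapse into a single dense entry of 7.0.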
- column_sums()[source]#
Alias for col_sums().
- Return type:
mlx.core.array
- sum(axis=None)[source]#
Sum sparse values over all entries, rows, or columns.
axis=None returns a scalar, axis=1 returns row sums, and axis=0 returns column sums.
- Return type:
mlx.core.array
- sort_indices()[source]#
Return a new CSRArray with column indices sorted within each row.
If self.sorted_indices is already True, returns self unchanged (no copy). Otherwise dispatches the native sort primitive.
- Returns:
A new CSRArray with sorted_indices=True and has_canonical_format=False (duplicates may still be present).
- Return type:
CSRArray
- sum_duplicates()[source]#
Sum duplicate column entries within each row.
Sorts indices first (via sort_indices), then accumulates entries that share the same column index. The resulting nnz may be smaller than the original.
- Returns:
A new CSRArray with sorted_indices=True and has_canonical_format=True.
- Return type:
CSRArray
- canonicalize()[source]#
Return the canonical form: sorted indices, no duplicates.
If self.has_canonical_format is already True, returns self with no work done. Otherwise calls sum_duplicates().
- Returns:
A CSRArray with has_canonical_format=True.
- Return type:
CSRArray
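The sort-then-merge sequence behind sort_indices(), sum_duplicates(), and canonicalize() can be sketched in plain Python as follows. `canonicalize_rows` is a hypothetical list-based helper; the library dispatches native primitives instead.

```python
def canonicalize_rows(data, indices, indptr):
    """Return (data, indices, indptr) with sorted, duplicate-free rows."""
    out_data, out_indices, out_indptr = [], [], [0]
    for i in range(len(indptr) - 1):
        start, stop = indptr[i], indptr[i + 1]
        # Step 1: order the row's entries by column index.
        row = sorted(zip(indices[start:stop], data[start:stop]), key=lambda e: e[0])
        # Step 2: merge neighbours in this row that share a column.
        row_start = len(out_indices)
        for col, val in row:
            if len(out_indices) > row_start and out_indices[-1] == col:
                out_data[-1] += val
            else:
                out_indices.append(col)
                out_data.append(val)
        out_indptr.append(len(out_indices))
    return out_data, out_indices, out_indptr


# Row 0 holds columns [2, 0, 2]; row 1 holds column [1].
c_data, c_indices, c_indptr = canonicalize_rows(
    [1.0, 2.0, 3.0, 5.0], [2, 0, 2, 1], [0, 3, 4]
)
```

Here nnz drops from 4 to 3 because the two entries in column 2 of row 0 are merged.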
- transpose()[source]#
Transpose the sparse matrix, returning a new CSRArray.
The result has shape=(n_cols, n_rows) and sorted_indices=True. If the source has has_canonical_format=True, the result also inherits that flag.
- Returns:
A new CSRArray with shape (n_cols, n_rows).
- Return type:
CSRArray
- tocsc(*, canonical=None)[source]#
Convert to CSCArray.
- Parameters:
canonical (bool | None) – If True, canonicalize the returned CSC matrix. If False or None (default), return the structural conversion as produced by the native count/prefix/fill path. The structural path preserves values but does not promise sorted output metadata on every backend.
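A count/prefix/fill conversion can be sketched in plain Python as below. This is one common way to structure such a pass, written as the hypothetical helper `toy_csr_to_csc`; it is an assumption about the general technique, not a description of the native kernel.

```python
def toy_csr_to_csc(data, indices, indptr, shape):
    """Convert CSR lists to CSC lists with a count/prefix/fill pass."""
    n_rows, n_cols = shape
    nnz = len(data)
    # Count: number of stored entries in each column.
    counts = [0] * n_cols
    for j in indices:
        counts[j] += 1
    # Prefix: a running sum gives each column's start offset.
    col_ptr = [0] * (n_cols + 1)
    for j in range(n_cols):
        col_ptr[j + 1] = col_ptr[j] + counts[j]
    # Fill: place each entry into the next free slot of its column.
    next_slot = list(col_ptr[:-1])
    out_data = [0.0] * nnz
    out_rows = [0] * nnz
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            dst = next_slot[indices[k]]
            out_data[dst] = data[k]
            out_rows[dst] = i
            next_slot[indices[k]] += 1
    return out_data, out_rows, col_ptr


csc_data, csc_rows, csc_ptr = toy_csr_to_csc(
    [2.0, -1.0, 4.0, 5.0], [0, 2, 1, 3], [0, 2, 2, 4], (3, 4)
)
```

This sequential loop visits rows in order, so row indices come out ascending within each column. A parallel fill would not necessarily preserve that order.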
- property T: CSRArray#
Transposed matrix. Alias for transpose().
- conj()[source]#
Complex-conjugate the stored values.
Structure (indices, indptr, shape) is shared. For real dtypes this is a no-op at the value level but still returns a new CSRArray pointing to the conjugated data array.
- Returns:
A new CSRArray with conjugated data.
- Return type:
CSRArray
- vdot(other)[source]#
Sparse Frobenius inner product with another sparse array.
Both operands are canonicalized CSR matrices and the matching-column merge is executed by the native sparse kernel. Dense materialization is never used.
- Return type:
mlx.core.array
- dot(other)[source]#
Sparse Frobenius dot product with another sparse array.
Unlike vdot(), complex operands are not conjugated.
- Return type:
mlx.core.array
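The difference between vdot() and dot() can be illustrated with a plain-Python merge over two canonical CSR layouts. The sketch assumes that vdot() conjugates the first operand, following NumPy's vdot convention; the text above does not state which operand is conjugated. `toy_frobenius` is a hypothetical helper.

```python
def toy_frobenius(a, b, conjugate_first):
    """Sum products of entries that share (row, col) in two canonical CSR triples."""
    (a_data, a_idx, a_ptr), (b_data, b_idx, b_ptr) = a, b
    total = 0
    for i in range(len(a_ptr) - 1):
        p, q = a_ptr[i], b_ptr[i]
        # Two-pointer merge over the sorted column indices of row i.
        while p < a_ptr[i + 1] and q < b_ptr[i + 1]:
            if a_idx[p] < b_idx[q]:
                p += 1
            elif a_idx[p] > b_idx[q]:
                q += 1
            else:
                left = a_data[p].conjugate() if conjugate_first else a_data[p]
                total += left * b_data[q]
                p += 1
                q += 1
    return total


A = ([1j, 2.0], [0, 2], [0, 2])            # 1x3 row: [1j, 0, 2]
B = ([3.0, 1j, 4.0], [0, 1, 2], [0, 3])    # 1x3 row: [3, 1j, 4]

vdot_like = toy_frobenius(A, B, conjugate_first=True)
dot_like = toy_frobenius(A, B, conjugate_first=False)
```

Only columns 0 and 2 appear in both operands, so the entry at column 1 of B never contributes. No dense matrix is built at any point.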
- __matmul__(rhs)[source]#
Matrix multiplication via the @ operator.
Dispatches to csr_matmat() for CSR operands, csr_matvec() for a rank-1 dense rhs, or csr_matmul() for rank-2 and batched dense operands. Dense inputs are converted to MLX arrays if needed.
- Parameters:
rhs – CSR sparse matrix, dense vector of shape (n_cols,), dense matrix of shape (n_cols, k), or batched dense matrix with sparse dimension at rhs.shape[-2].
- Returns:
A CSRArray when rhs is a CSR matrix, otherwise a dense MLX array.
- Raises:
ValueError – If dense rhs.ndim is not at least 1.
TypeError – If rhs dtype does not match self.data dtype.
- __rmul__(other)[source]#
Multiply the current CSRArray by a number using the * operator, with the number on the left-hand side.
Returns a new CSRArray whose data is multiplied by the number. The current CSRArray is not mutated.
- Parameters:
other – A valid number (real or complex).
- Returns:
A new CSRArray with the data multiplied by the number.
- Raises:
TypeError – If other is not a number.
- __mul__(other)[source]#
Multiply the current CSRArray by a number using the * operator.
Returns a new CSRArray whose data is multiplied by the number. The current CSRArray is not mutated.
- Parameters:
other – A valid number (real or complex).
- Returns:
A new CSRArray with the data multiplied by the number.
- Raises:
TypeError – If other is not a number.
CSCArray#
- class mlx_sparse.CSCArray(data, indices, indptr, shape, sorted_indices=False, has_canonical_format=False)[source]#
Bases: object
A 2D sparse matrix in Compressed Sparse Column (CSC) format.
CSCArray stores a 2D sparse matrix using three MLX arrays:
data: non-zero values, shape (nnz,).
indices: row index of each stored value, shape (nnz,).
indptr: column pointer array, shape (n_cols + 1,).
Column j spans data[indptr[j] : indptr[j+1]] with corresponding row indices indices[indptr[j] : indptr[j+1]]. Duplicate row entries within a column are permitted unless the matrix is in canonical form.
CSC is the column-compressed dual of CSR. It is the natural layout for operations that consume one full column at a time, such as transpose matvec, column-oriented canonicalization, and future direct factorization kernels.
- Parameters:
data (mlx.core.array) – Non-zero values, shape (nnz,).
indices (mlx.core.array) – Row indices, shape (nnz,).
indptr (mlx.core.array) – Column pointer array, shape (n_cols + 1,).
shape (tuple[int, int]) – Matrix dimensions as (n_rows, n_cols).
sorted_indices (bool) – Hint that row indices within each column are sorted ascending. Defaults to False.
has_canonical_format (bool) – Hint that the matrix has sorted row indices and no duplicate row index in any column. Implies sorted_indices=True. Defaults to False.
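The column-span rule, and the reason a CSC matrix can be transposed into CSR without copying, can both be seen in a plain-Python sketch. The helpers `column_entries` and `rows_of_transpose` are hypothetical and work on lists rather than MLX arrays.

```python
# CSC arrays for the 3x4 matrix [[2, 0, -1, 0], [0, 0, 0, 0], [0, 4, 0, 5]].
data = [2.0, 4.0, -1.0, 5.0]
indices = [0, 2, 0, 2]      # row index of each stored value
indptr = [0, 1, 2, 3, 4]    # one start offset per column, plus the end offset


def column_entries(j):
    """Return the (row, value) pairs stored for column j."""
    start, stop = indptr[j], indptr[j + 1]
    return list(zip(indices[start:stop], data[start:stop]))


def rows_of_transpose(n_cols):
    """Read the same buffers as CSR: row j of the transpose is column j."""
    return [column_entries(j) for j in range(n_cols)]


col_1 = column_entries(1)
transpose_rows = rows_of_transpose(4)
```

The three buffers are never rewritten. Only the interpretation of indptr (rows instead of columns) and the shape change.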
Properties
Number of stored values (including any duplicates).
dtype – Value dtype of the stored non-zeros.
index_dtype – Integer dtype used for indices and indptr.
Always 2.
T – Transposed matrix.
Hermitian (conjugate) transpose.
Methods
todense() – Materialize the sparse matrix as a dense MLX array.
tocsr(*[, canonical]) – Convert to CSRArray.
sort_indices() – Return a new CSCArray with row indices sorted within each column.
Sum duplicate row entries within each column.
Return canonical form: sorted row indices, no duplicates.
transpose() – Transpose the sparse matrix, returning a zero-copy CSRArray.
conj() – Complex-conjugate the stored values.
Alias for conj().
- data: mlx.core.array#
- indices: mlx.core.array#
- indptr: mlx.core.array#
- property dtype#
Value dtype of the stored non-zeros.
- property index_dtype#
Integer dtype used for indices and indptr.
- column_sums()[source]#
Alias for col_sums().
- Return type:
mlx.core.array
- row_norms()[source]#
Return the dense-semantics L2 norm of each CSC row as float32.
- Return type:
mlx.core.array
- col_norms()[source]#
Return the dense-semantics L2 norm of each CSC column as float32.
- Return type:
mlx.core.array
- column_norms()[source]#
Alias for col_norms().
- Return type:
mlx.core.array
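One reading of "dense-semantics" is that the norm is taken over the matrix that todense() would produce, so duplicates are summed before squaring. That reading is an assumption. The plain-Python sketch below follows it for column norms, using a hypothetical helper `toy_col_norms`.

```python
import math


def toy_col_norms(data, indices, indptr):
    """L2 norm of each CSC column, merging duplicate rows first."""
    norms = []
    for j in range(len(indptr) - 1):
        # Accumulate duplicates per row, as a dense matrix would hold them.
        merged = {}
        for k in range(indptr[j], indptr[j + 1]):
            merged[indices[k]] = merged.get(indices[k], 0.0) + data[k]
        norms.append(math.sqrt(sum(abs(v) ** 2 for v in merged.values())))
    return norms


# Column 0 stores row 0 twice (3.0 and 1.0); column 1 stores 3.0 at row 1.
col_norms = toy_col_norms([3.0, 1.0, 3.0], [0, 0, 1], [0, 2, 3])
```

Squaring the stored values of column 0 directly would give sqrt(10) instead of 4, which is why the merge step matters when duplicates are present.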
- sum(axis=None)[source]#
Sum sparse values over all entries, rows, or columns.
axis=None returns a scalar, axis=1 returns row sums, and axis=0 returns column sums.
- Return type:
mlx.core.array
- tocsr(*, canonical=None)[source]#
Convert to CSRArray.
- Parameters:
canonical (bool | None) – If True, canonicalize the returned CSR matrix. If False or None (default), return the structural conversion as produced by the native count/prefix/fill path. The structural path preserves values but does not promise sorted output metadata on every backend.
- Return type:
CSRArray
- sort_indices()[source]#
Return a new CSCArray with row indices sorted within each column.
- Return type:
CSCArray
- property T: CSRArray#
Transposed matrix. Alias for transpose().
- __rmul__(other)[source]#
Multiply the current CSCArray by a number using the * operator, with the number on the left-hand side.
Returns a new CSCArray whose data is multiplied by the number. The current CSCArray is not mutated.
- Parameters:
other – A valid number (real or complex).
- Returns:
A new CSCArray with the data multiplied by the number.
- Raises:
TypeError – If other is not a number.
- __mul__(other)[source]#
Multiply the current CSCArray by a number using the * operator.
Returns a new CSCArray whose data is multiplied by the number. The current CSCArray is not mutated.
- Parameters:
other – A valid number (real or complex).
- Returns:
A new CSCArray with the data multiplied by the number.
- Raises:
TypeError – If other is not a number.
COOArray#
- class mlx_sparse.COOArray(data, row, col, shape, has_canonical_format=False)[source]#
Bases: object
A 2D sparse matrix in Coordinate (COO) format.
COOArray stores a sparse matrix as three parallel arrays:
data: non-zero values, shape (nnz,).
row: row coordinate for each value, shape (nnz,).
col: column coordinate for each value, shape (nnz,).
COO is the primary construction format. It allows duplicate (row, col) entries and does not require sorted coordinates, making it straightforward to assemble matrices from element lists, graph adjacency lists, finite-element stencils, or Hamiltonians.
COO supports native sparse-dense products directly. For heavily repeated row-oriented workloads, CSR may still be preferable after construction because its compressed row layout avoids coordinate scatter.
Format invariants (checked by validate="metadata" by default):
All three arrays must be rank-1 with the same length.
row and col share the same integer dtype (int32 or int64).
data dtype is one of float32, float16, bfloat16, or complex64.
Additional value-level invariants (validate="full" only):
0 <= row[i] < n_rows for all entries.
0 <= col[i] < n_cols for all entries.
- Parameters:
data (mlx.core.array) – Non-zero values, shape (nnz,).
row (mlx.core.array) – Row coordinates, shape (nnz,).
col (mlx.core.array) – Column coordinates, shape (nnz,).
shape (tuple[int, int]) – Matrix dimensions as (n_rows, n_cols).
has_canonical_format (bool) – Hint that coordinates are sorted and duplicate-free. Defaults to False.
Example:
    import mlx.core as mx
    import mlx_sparse as ms

    data = mx.array([2.0, -1.0, 4.0], dtype=mx.float32)
    row = mx.array([0, 0, 1], dtype=mx.int32)
    col = mx.array([0, 2, 1], dtype=mx.int32)
    # 2×3 matrix: [[2, 0, -1],
    #              [0, 4,  0]]
    coo = ms.coo_array((data, (row, col)), shape=(2, 3))
    csr = coo.tocsr(canonical=True)
Properties
Number of stored values (including any duplicates).
dtype – Value dtype of the stored non-zeros (e.g. mlx.core.float32).
index_dtype – Integer dtype used for row and col.
Always 2.
Methods
tocsr(*[, canonical]) – Convert to CSRArray.
tocsc(*[, canonical]) – Convert to CSCArray.
todense() – Materialize as a dense MLX array.
- data: mlx.core.array#
- row: mlx.core.array#
- col: mlx.core.array#
- property dtype#
Value dtype of the stored non-zeros (e.g. mlx.core.float32).
- property index_dtype#
Integer dtype used for row and col.
- tocsr(*, canonical=False)[source]#
Convert to CSRArray.
Sorts entries by row then column and builds a (n_rows + 1,) row pointer array. Duplicate (row, col) entries are preserved in the raw output. Pass canonical=True to sum them.
- Parameters:
canonical (bool) – If True, call canonicalize() on the result to sort indices and sum duplicates. Default False.
- Returns:
A CSRArray with sorted_indices=True. If canonical=True, also has_canonical_format=True.
- Return type:
CSRArray
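The sort-and-build-pointer step can be sketched in plain Python. `toy_coo_to_csr` is a hypothetical list-based helper that mirrors the description above (sort by row then column, keep duplicates, build the row pointer); it is not the library's implementation.

```python
def toy_coo_to_csr(data, row, col, shape):
    """Convert COO lists to CSR lists without merging duplicates."""
    n_rows, _ = shape
    # Sort entry positions by (row, col); duplicates stay as separate entries.
    order = sorted(range(len(data)), key=lambda k: (row[k], col[k]))
    out_data = [data[k] for k in order]
    out_indices = [col[k] for k in order]
    # Row pointer: count entries per row, then take a running sum.
    indptr = [0] * (n_rows + 1)
    for r in row:
        indptr[r + 1] += 1
    for i in range(n_rows):
        indptr[i + 1] += indptr[i]
    return out_data, out_indices, indptr


# Unsorted input with a duplicate at (0, 2).
csr_data, csr_indices, csr_indptr = toy_coo_to_csr(
    [4.0, -1.0, 2.0, 1.0], [1, 0, 0, 0], [1, 2, 0, 2], (2, 3)
)
```

The duplicate pair at (0, 2) survives as two adjacent entries, which corresponds to the canonical=False result. Summing them is the extra step that canonical=True requests.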
- todense()[source]#
Materialize as a dense MLX array.
Internally converts to CSR and then calls todense(). Duplicate entries are summed.
- Returns:
Dense array of shape (n_rows, n_cols) with the same dtype as self.data.
- Return type:
mlx.core.array
- column_sums()[source]#
Alias for col_sums().
- Return type:
mlx.core.array
- row_norms()[source]#
Return the dense-semantics L2 norm of each COO row as float32.
- Return type:
mlx.core.array
- col_norms()[source]#
Return the dense-semantics L2 norm of each COO column as float32.
- Return type:
mlx.core.array
- column_norms()[source]#
Alias for col_norms().
- Return type:
mlx.core.array
- sum(axis=None)[source]#
Sum sparse values over all entries, rows, or columns.
axis=None returns a scalar, axis=1 returns row sums, and axis=0 returns column sums.
- Return type:
mlx.core.array
- __rmul__(other)[source]#
Multiply the current COOArray by a number using the * operator, with the number on the left-hand side.
Returns a new COOArray whose data is multiplied by the number. The current COOArray is not mutated.
- Parameters:
other – A valid number (real or complex).
- Returns:
A new COOArray with the data multiplied by the number.
- Raises:
TypeError – If other is not a number.
- __mul__(other)[source]#
Multiply the current COOArray by a number using the * operator.
Returns a new COOArray whose data is multiplied by the number. The current COOArray is not mutated.
- Parameters:
other – A valid number (real or complex).
- Returns:
A new COOArray with the data multiplied by the number.
- Raises:
TypeError – If other is not a number.
Utility functions#
- mlx_sparse.issparse(x)[source]#
Return True if x is a recognized mlx-sparse container.
Currently returns True for COOArray, CSRArray, and CSCArray instances. All other objects return False.
- Parameters:
x – Any Python object.
- Returns:
True if x is a COOArray, CSRArray, or CSCArray instance, otherwise False.
- Return type:
bool
Example:
    import mlx.core as mx
    import mlx_sparse as ms

    # my_csr is assumed to be an existing CSRArray.
    ms.issparse(my_csr)  # True
    ms.issparse(mx.ones((3, 4)))  # False