Constructors
These functions create sparse array instances from raw buffers or structured
definitions. All array inputs accept MLX arrays, NumPy arrays, or any sequence
convertible via mx.array().
csr_array
- mlx_sparse.csr_array(arg, shape, *, validate='metadata', sorted_indices=False, canonical=None)
Construct a CSRArray from explicit CSR buffers.
Accepts either a (data, indices, indptr) triple or an existing CSRArray. All array inputs are converted to mlx.core.array if they are not already.
- Parameters:
  - arg – A length-3 iterable (data, indices, indptr) where:
    - data: non-zero values, shape (nnz,), dtype float32 | float16 | bfloat16 | complex64.
    - indices: column indices, shape (nnz,), dtype int32 | int64.
    - indptr: row pointers, shape (n_rows + 1,), same integer dtype as indices.
    Alternatively, an existing CSRArray instance (returned unchanged if shape matches).
  - shape – Matrix dimensions as a length-2 sequence (n_rows, n_cols).
  - validate (bool | Literal['metadata', 'full']) – Validation level, one of:
    - "metadata" (default): checks ranks, lengths, and dtypes without reading array values. Safe to call on device arrays.
    - "full" / True: also verifies indptr monotonicity and column index bounds. May synchronize to host.
    - False / "none": skips all checks.
  - sorted_indices (bool) – Set to True to assert that column indices within each row are already sorted in ascending order. Default False.
  - canonical (bool | None) – Set to True to assert canonical format (sorted indices, no duplicate columns). Implies sorted_indices=True. Default None (not asserted).
- Returns:
  A CSRArray with the given buffers and shape.
- Raises:
  - TypeError – If arg is not a 3-tuple or a CSRArray, or if dtype constraints are violated.
  - ValueError – If shape or length constraints are violated.
- Return type:
  CSRArray
Example:
import mlx.core as mx
import mlx_sparse as ms

data = mx.array([1.0, 2.0, 3.0], dtype=mx.float32)
indices = mx.array([0, 1, 0], dtype=mx.int32)
indptr = mx.array([0, 2, 3], dtype=mx.int32)
A = ms.csr_array((data, indices, indptr), shape=(2, 2))

# Full validation also reads index values (may synchronize to host):
A_checked = ms.csr_array(
    (data, indices, indptr), shape=(2, 2), validate="full"
)
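The following plain-Python sketch mirrors the checks listed above so that the difference between the validation levels is concrete. It is not the mlx_sparse implementation. check_csr_buffers is a hypothetical helper that works on Python lists rather than MLX arrays and skips the dtype and rank checks. It also assumes the standard CSR invariants indptr[0] == 0 and indptr[-1] == nnz, which this reference does not state explicitly.

```python
def check_csr_buffers(data, indices, indptr, shape, validate="metadata"):
    """Hypothetical sketch of csr_array's validation levels (not the library code)."""
    # Normalise the accepted spellings: True -> "full", False -> "none".
    if validate is True:
        validate = "full"
    elif validate is False:
        validate = "none"
    if validate not in ("metadata", "full", "none"):
        raise ValueError(f"unknown validation level: {validate!r}")
    if validate == "none":
        return  # skip all checks

    n_rows, n_cols = shape
    # "metadata": only lengths are compared; no element values are read.
    if len(data) != len(indices):
        raise ValueError("data and indices must have the same length (nnz)")
    if len(indptr) != n_rows + 1:
        raise ValueError("indptr must have length n_rows + 1")
    if validate == "metadata":
        return

    # "full": value-level checks (the part that may synchronize to host).
    nnz = len(data)
    # Assumption: standard CSR invariants on the first and last row pointer.
    if indptr[0] != 0 or indptr[-1] != nnz:
        raise ValueError("indptr must start at 0 and end at nnz")
    for a, b in zip(indptr, indptr[1:]):
        if b < a:
            raise ValueError("indptr must be non-decreasing")
    for j in indices:
        if not 0 <= j < n_cols:
            raise ValueError(f"column index {j} out of bounds for {n_cols} columns")
```

Calling check_csr_buffers([1.0], [5], [0, 1], (1, 2)) succeeds at the default "metadata" level. The out-of-range column 5 is only caught by the value-level "full" pass, which is why that pass has to read the index buffer.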
coo_array
- mlx_sparse.coo_array(arg, shape, *, validate='metadata', canonical=None)
Construct a COOArray from coordinate arrays.
Accepts either a (data, (row, col)) pair or an existing COOArray. All array inputs are converted to mlx.core.array if they are not already.
- Parameters:
  - arg – A (data, (row, col)) pair where:
    - data: non-zero values, shape (nnz,), dtype float32 | float16 | bfloat16 | complex64.
    - row: row coordinates, shape (nnz,), dtype int32 | int64.
    - col: column coordinates, shape (nnz,), same integer dtype as row.
    Alternatively, an existing COOArray (returned unchanged if shape matches).
  - shape – Matrix dimensions as a length-2 sequence (n_rows, n_cols).
  - validate (bool | Literal['metadata', 'full']) – Validation level, one of:
    - "metadata" (default): checks ranks, lengths, and dtypes.
    - "full" / True: also verifies coordinate bounds. May synchronize to host.
    - False / "none": skips all checks.
  - canonical (bool | None) – Set to True to assert that the coordinates are sorted and duplicate-free. Default None (not asserted).
- Returns:
  A COOArray with the given buffers and shape.
- Raises:
  - TypeError – If arg cannot be unpacked as (data, (row, col)), or if dtype constraints are violated.
  - ValueError – If shape or length constraints are violated.
- Return type:
  COOArray
Example:
import mlx.core as mx
import mlx_sparse as ms

data = mx.array([1.0, 2.0, 3.0, 2.0], dtype=mx.float32)
row = mx.array([0, 0, 1, 0], dtype=mx.int32)
col = mx.array([0, 1, 0, 0], dtype=mx.int32)
# Two entries at (0, 0). Summed when converting to CSR.
coo = ms.coo_array((data, (row, col)), shape=(2, 2))
csr = coo.tocsr(canonical=True)  # sums duplicates
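The duplicate-summing step performed by tocsr(canonical=True) can be sketched as follows. The sketch assumes the canonical-CSR rules described on this page: rows in order, column indices sorted within each row, and no duplicate columns. coo_to_canonical_csr is a hypothetical pure-Python helper that works on lists. The library's actual conversion may be organised differently.

```python
def coo_to_canonical_csr(data, row, col, shape):
    """Hypothetical sketch: COO triples -> canonical CSR buffers, summing duplicates."""
    n_rows, _ = shape
    # Accumulate values per (row, col); repeated coordinates are summed.
    summed = {}
    for v, i, j in zip(data, row, col):
        summed[(i, j)] = summed.get((i, j), 0) + v
    # Canonical order: by row, then by column within each row.
    keys = sorted(summed)
    out_data = [summed[k] for k in keys]
    out_indices = [j for _, j in keys]
    # Count entries per row, then prefix-sum the counts into row pointers.
    counts = [0] * n_rows
    for i, _ in keys:
        counts[i] += 1
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)
    return out_data, out_indices, indptr
```

With the buffers from the example above, the two (0, 0) entries 1.0 and 2.0 collapse into a single stored value 3.0. The result is data [3.0, 2.0, 3.0], indices [0, 1, 0], and indptr [0, 2, 3].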
eye
- mlx_sparse.eye(n, m=None, *, k=0, dtype=mlx.core.float32, index_dtype=mlx.core.int32)
Return a sparse identity-like CSR matrix with ones on a specified diagonal.
Produces the same result as numpy.eye() with the same n, m, and k, but returns a CSRArray instead of a dense array. The matrix has at most min(n, m) stored values. Rows that the diagonal does not cross store no entries, so they appear as empty rows in the CSR representation (consecutive equal indptr values).
- Parameters:
  - n (int) – Number of rows.
  - m (int | None) – Number of columns. Defaults to n, producing a square matrix.
  - k (int) – Diagonal offset. 0 selects the main diagonal. Positive values shift the diagonal above the main diagonal (superdiagonal); negative values shift it below (subdiagonal).
  - dtype – Value dtype for the stored ones. Must be one of mx.float32, mx.float16, mx.bfloat16, or mx.complex64. Defaults to mx.float32.
  - index_dtype – Integer dtype for indices and indptr. Must be mx.int32 or mx.int64. Defaults to mx.int32.
- Returns:
  A canonical CSRArray with has_canonical_format=True and sorted_indices=True.
- Raises:
  - TypeError – If dtype or index_dtype is not a supported value.
- Return type:
  CSRArray
Example:
import mlx_sparse as ms
import mlx.core as mx

# 4x4 identity matrix
I = ms.eye(4)  # CSRArray(shape=(4, 4), nnz=4, ...)
mx.eval(I.data)

# 3x5 matrix with ones on the first superdiagonal
A = ms.eye(3, 5, k=1)  # Non-zeros at (0,1), (1,2), (2,3)
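The following minimal pure-Python sketch shows the CSR buffers that eye() produces. eye_csr_buffers is a hypothetical helper and is not part of mlx_sparse. It always stores 1.0 in Python lists instead of honouring dtype and index_dtype.

```python
def eye_csr_buffers(n, m=None, k=0):
    """Hypothetical sketch of the CSR buffers behind an eye(n, m, k) result."""
    m = n if m is None else m
    data, indices, indptr = [], [], [0]
    for i in range(n):
        j = i + k  # column where diagonal k crosses row i
        if 0 <= j < m:
            data.append(1.0)
            indices.append(j)
        # Rows the diagonal misses add nothing, so indptr repeats its value.
        indptr.append(len(indices))
    return data, indices, indptr
```

For eye_csr_buffers(3, 5, k=1) this yields indices [1, 2, 3] and indptr [0, 1, 2, 3], matching the non-zeros (0,1), (1,2), (2,3) in the example above. Each row holds at most one entry, so the result is canonical by construction.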
diags
- mlx_sparse.diags(diagonals, offsets=0, *, shape=None, dtype=None, index_dtype=mlx.core.int32)
Construct a CSR matrix from one or more diagonals.
Mirrors the behaviour of scipy.sparse.diags() but returns a CSRArray. Each diagonal is placed at the position specified by the corresponding offset. Diagonals are assembled into a COO triple and sorted before the CSR row-pointer array is built, so the result is always in canonical form.
- Parameters:
  - diagonals – The diagonal values. Accepted forms:
    - A single 1-D array-like (or scalar) placed at offsets.
    - A 2-D array whose rows are individual diagonals.
    - A list of 1-D array-likes, one per entry in offsets.
    Each diagonal’s length must not exceed the number of elements that the diagonal at the corresponding offset can hold given shape.
  - offsets – Diagonal offset(s). 0 is the main diagonal; positive integers are superdiagonals and negative integers are subdiagonals. When diagonals is a list, offsets must be a matching list of integers. Repeated offsets are not allowed.
  - shape (Sequence[int] | None) – Output matrix shape as (n_rows, n_cols). When omitted, the smallest square shape that fits all diagonals is inferred automatically.
  - dtype – Value dtype. When None (default), the dtype is inferred from the diagonal arrays: complex64 if any diagonal is complex, float16 if any diagonal has dtype float16, otherwise float32.
  - index_dtype – Integer dtype for indices and indptr. Must be mx.int32 or mx.int64. Defaults to mx.int32.
- Returns:
  A canonical CSRArray with has_canonical_format=True and sorted_indices=True.
- Raises:
  - TypeError – If dtype or index_dtype is not supported.
  - ValueError – If the numbers of diagonals and offsets differ, if offsets are repeated, or if a diagonal is longer than its allocated space.
- Return type:
  CSRArray
Example:
import numpy as np
import mlx_sparse as ms
import mlx.core as mx

# Tridiagonal matrix: main diagonal 2, off-diagonals -1
A = ms.diags(
    [np.full(4, -1.0), np.full(5, 2.0), np.full(4, -1.0)],
    offsets=[-1, 0, 1],
)  # 5x5, nnz=13

# Single diagonal at offset 2
B = ms.diags([1.0, 2.0, 3.0], offsets=2, shape=(5, 5))
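The assembly strategy described for diags() can be sketched in plain Python. The steps are: place each diagonal as COO entries, sort them, then build indptr. diags_csr_buffers is hypothetical. It accepts only the list-of-diagonals form with a matching list of offsets, and it ignores dtype handling. It computes each diagonal's capacity with the usual rule: min(n_rows, n_cols - k) for k >= 0 and min(n_rows + k, n_cols) for k < 0.

```python
def diags_csr_buffers(diagonals, offsets, shape=None):
    """Hypothetical sketch of diags(): place diagonals as COO, sort, build indptr."""
    if len(diagonals) != len(offsets):
        raise ValueError("number of diagonals and offsets differ")
    if len(set(offsets)) != len(offsets):
        raise ValueError("repeated offsets are not allowed")
    if shape is None:
        # Smallest square matrix that fits every diagonal at its offset.
        n = max(len(d) + abs(k) for d, k in zip(diagonals, offsets))
        shape = (n, n)
    n_rows, n_cols = shape

    triples = []  # (row, col, value) in COO form
    for diag, k in zip(diagonals, offsets):
        # Number of slots diagonal k has in an n_rows x n_cols matrix.
        capacity = min(n_rows, n_cols - k) if k >= 0 else min(n_rows + k, n_cols)
        if len(diag) > max(0, capacity):
            raise ValueError(f"diagonal at offset {k} is too long")
        r0, c0 = (0, k) if k >= 0 else (-k, 0)
        for t, v in enumerate(diag):
            triples.append((r0 + t, c0 + t, v))

    # Distinct offsets never share a (row, col), so this is a row-then-column sort.
    triples.sort()
    data = [v for _, _, v in triples]
    indices = [c for _, c, _ in triples]
    # Count entries per row, then prefix-sum into row pointers.
    indptr = [0] * (n_rows + 1)
    for r, _, _ in triples:
        indptr[r + 1] += 1
    for r in range(n_rows):
        indptr[r + 1] += indptr[r]
    return data, indices, indptr
```

For the tridiagonal example above, the sketch infers a 5x5 shape, stores 13 values, and produces indptr [0, 2, 5, 8, 11, 13]. The first and last rows have two entries each and the interior rows have three.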
fromdense
- mlx_sparse.fromdense(array, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)
Construct a canonical CSR matrix from a rank-2 dense MLX array.
Identifies the non-zero (or above-threshold) entries of a dense matrix and packages them into a CSRArray. The native path stages the work as three passes (count, allocate, then fill), so Metal builds can perform the dense scan and the CSR writes on device while still returning compact buffers.
The value dtype is preserved from the input array unless dtype is given. The index dtype defaults to int32. Pass index_dtype=mx.int64 for matrices with more than about 2 billion (2^31 - 1) non-zeros, which is not typical on Apple Silicon.
- Parameters:
  - array – A rank-2 array-like. Converted to mlx.core.array if not already. Dtype must be one of float32, float16, bfloat16, or complex64.
  - threshold (float) – Entries with absolute value less than or equal to threshold are treated as structural zeros and excluded from the output. The default 0.0 keeps every numerically non-zero entry. Must be non-negative.
  - dtype – Optional value dtype to cast to before extracting non-zeros. When None, the input dtype chosen by MLX is preserved.
  - index_dtype – Integer dtype for indices and indptr. Must be mx.int32 or mx.int64. Defaults to mx.int32.
- Returns:
  A canonical CSRArray with has_canonical_format=True and sorted_indices=True.
- Raises:
  - TypeError – If the input dtype is not a supported value dtype.
  - ValueError – If the input is not rank-2, or if threshold is negative.
- Return type:
  CSRArray
Example:
import mlx.core as mx
import numpy as np
import mlx_sparse as ms

dense = mx.array(np.array([
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 0.0],
    [3.0, 4.0, 0.0],
], dtype=np.float32))
csr = ms.fromdense(dense)  # CSRArray(shape=(3, 3), nnz=4, dtype=float32, ...)

# Treat entries with |value| <= 0.5 as structural zeros
csr_thresholded = ms.fromdense(dense, threshold=0.5)
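A sequential pure-Python sketch can illustrate the count, allocate, fill staging used by the native fromdense() path. fromdense_csr_buffers is a hypothetical helper that operates on nested lists. It shows the three passes and the threshold rule (keep entries whose absolute value is strictly greater than threshold). It does not reproduce the Metal implementation.

```python
def fromdense_csr_buffers(dense, threshold=0.0):
    """Hypothetical sketch of a count -> allocate -> fill dense-to-CSR conversion."""
    if threshold < 0:
        raise ValueError("threshold must be non-negative")
    # Pass 1 (count): entries kept per row, i.e. |value| > threshold.
    counts = [sum(1 for v in row if abs(v) > threshold) for row in dense]
    # Pass 2 (allocate): prefix-sum the counts into indptr and size the buffers.
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)
    nnz = indptr[-1]
    data = [0.0] * nnz
    indices = [0] * nnz
    # Pass 3 (fill): row i writes only into its own slice indptr[i]:indptr[i + 1].
    for i, row in enumerate(dense):
        pos = indptr[i]
        for j, v in enumerate(row):
            if abs(v) > threshold:
                data[pos] = v
                indices[pos] = j
                pos += 1
    return data, indices, indptr
```

Pass 2 fixes every row's output slice before any values are written, so each row in pass 3 can be filled independently. That independence makes this staging a natural fit for a parallel device kernel, although this reference does not document the exact scheme used by the Metal build.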
from_dense
- mlx_sparse.from_dense(array, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)
Alias for fromdense() with a PEP 8 compatible name.
from_numpy
from_scipy
- mlx_sparse.from_scipy(matrix, *, format='csr', dtype=None, index_dtype=mlx.core.int32, canonical=True)
Convert a SciPy sparse matrix or sparse array to mlx-sparse.
Any SciPy sparse format is accepted. format="csr" returns a CSRArray, format="csc" returns a CSCArray, and format="coo" returns a COOArray. The conversion preserves supported float32, float16, and complex64 values. Other real floating dtypes, including SciPy’s default float64, are cast to float32 unless dtype is provided.
- Parameters:
  - matrix – A scipy.sparse matrix or array.
  - format (str) – Output sparse format: "csr" (default), "csc", or "coo".
  - dtype – Optional MLX value dtype. Must be one of mx.float32, mx.float16, mx.bfloat16, or mx.complex64.
  - index_dtype – Integer dtype for sparse indices. Must be mx.int32 or mx.int64.
  - canonical (bool) – If True (default), sum duplicates and sort indices before exporting buffers.
- Returns:
  A CSRArray, CSCArray, or COOArray.
- Raises:
  - TypeError – If SciPy is not installed, matrix is not sparse, or a dtype is unsupported.
  - ValueError – If format is not "csr", "csc", or "coo".
asarray
- mlx_sparse.asarray(x, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)
Convert common sparse or dense inputs to a sparse array.
Existing CSRArray and CSCArray instances are returned unchanged unless dtype requests a value cast. COOArray instances are converted with tocsr(canonical=True). SciPy sparse matrices and arrays are routed through from_scipy(). Dense MLX arrays, NumPy arrays, and Python array-likes are routed through fromdense().
- Parameters:
  - x – An existing mlx-sparse array, SciPy sparse array, dense MLX array, NumPy array, or Python rank-2 array-like.
  - threshold (float) – Structural-zero threshold; applies to dense inputs only.
  - dtype – Optional target value dtype.
  - index_dtype – Target index dtype for newly constructed sparse arrays.
- Returns:
  Existing CSRArray or CSCArray inputs are preserved. Other inputs return a canonical CSRArray.
- Return type:
  CSRArray | CSCArray
Validation modes
The validate parameter on csr_array() and coo_array() accepts:

| Value | Behaviour |
|---|---|
| "metadata" | Checks ranks, array lengths, and dtypes. No value reads. Default. |
| "full" / True | Metadata checks plus value-level checks (index bounds; for CSR, also indptr monotonicity). May synchronize to host. |
| False / "none" | No checks. Use only when inputs are known valid. |
See Validation for a detailed discussion.
Structured constructors
eye(), diags(), fromdense(), from_dense(),
from_numpy(), from_scipy(), and asarray() all make common
construction paths explicit. Dense conversions return a canonical
CSRArray (has_canonical_format=True, sorted_indices=True).
They are host-side assembly operations and call mx.eval internally when
necessary to determine the output size.
import numpy as np
import mlx.core as mx
import mlx_sparse as ms
# 5x5 identity
I = ms.eye(5)
# 4x4 tridiagonal
T = ms.diags(
[np.full(3, -1.0), np.full(4, 2.0), np.full(3, -1.0)],
offsets=[-1, 0, 1],
)
# Convert a dense weight matrix
W = mx.array(np.random.randn(8, 8).astype(np.float32))
W_sparse = ms.fromdense(W, threshold=0.1)
# Convert SciPy sparse or dense NumPy inputs without hand-building buffers
W_from_np = ms.asarray(np.eye(8, dtype=np.float32))