Constructors#

These functions create sparse array instances from raw buffers or structured definitions. Array inputs may be MLX arrays, NumPy arrays, or any sequence convertible via mx.array().

csr_array#

mlx_sparse.csr_array(arg, shape, *, validate='metadata', sorted_indices=False, canonical=None)

Construct a CSRArray from explicit CSR buffers.

Accepts either a (data, indices, indptr) triple or an existing CSRArray. All array inputs are converted to mlx.core.array if they are not already.

Parameters:
  • arg

    A length-3 iterable (data, indices, indptr) where

    • data: non-zero values, shape (nnz,), dtype float32 | float16 | bfloat16 | complex64.

    • indices: column indices, shape (nnz,), dtype int32 | int64.

    • indptr: row pointers, shape (n_rows + 1,), same integer dtype as indices.

    Alternatively, an existing CSRArray instance (returned unchanged if shape matches).

  • shape – Matrix dimensions as a length-2 sequence (n_rows, n_cols).

  • validate (bool | Literal['metadata', 'full', 'none']) –

    Validation level, one of:

    • "metadata" (default): checks ranks, lengths, and dtypes without reading array values. Safe to call on device arrays.

    • "full" / True: also verifies indptr monotonicity and column index bounds. May synchronize to host.

    • False / "none": skips all checks.

  • sorted_indices (bool) – Set to True to assert that column indices within each row are already sorted ascending. Default False.

  • canonical (bool | None) – Set to True to assert canonical format (sorted indices, no duplicate columns). Implies sorted_indices=True. Default None (not asserted).

Returns:

A CSRArray with the given buffers and shape.

Raises:
  • TypeError – If arg is not a length-3 iterable or a CSRArray, or if dtype constraints are violated.

  • ValueError – If shape or length constraints are violated.

Return type:

CSRArray

Example:

import mlx.core as mx
import mlx_sparse as ms

data = mx.array([1.0, 2.0, 3.0], dtype=mx.float32)
indices = mx.array([0, 1, 0], dtype=mx.int32)
indptr = mx.array([0, 2, 3], dtype=mx.int32)

A = ms.csr_array((data, indices, indptr), shape=(2, 2))
# Full validation also reads index values (may synchronize to host):
A_checked = ms.csr_array(
    (data, indices, indptr), shape=(2, 2), validate="full"
)
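
The value-level checks that validate="full" is described as performing, and the property that canonical=True asserts, can be sketched in plain Python. This is an illustration of the CSR invariants only, not the library's implementation; check_csr_full and is_canonical_csr are hypothetical helper names, and the buffers are ordinary lists rather than MLX arrays.

```python
def check_csr_full(indices, indptr, shape):
    """Value-level CSR checks: indptr monotonicity and column bounds."""
    n_rows, n_cols = shape
    # Length check (this part belongs to the "metadata" level).
    if len(indptr) != n_rows + 1:
        raise ValueError("indptr must have n_rows + 1 entries")
    # Row pointers must never decrease from one row to the next.
    for start, stop in zip(indptr, indptr[1:]):
        if stop < start:
            raise ValueError("indptr must be non-decreasing")
    # Every stored column index must address a real column.
    for col in indices:
        if not 0 <= col < n_cols:
            raise ValueError("column index out of bounds")


def is_canonical_csr(indices, indptr):
    """True if each row's column indices are strictly increasing
    (sorted ascending and free of duplicates)."""
    for start, stop in zip(indptr, indptr[1:]):
        row_cols = indices[start:stop]
        if any(b <= a for a, b in zip(row_cols, row_cols[1:])):
            return False
    return True


# Same index buffers as the example above:
# row 0 stores columns 0 and 1, row 1 stores column 0.
indices = [0, 1, 0]
indptr = [0, 2, 3]
check_csr_full(indices, indptr, shape=(2, 2))  # passes silently
canonical = is_canonical_csr(indices, indptr)  # True
```

Both helpers have to read the index values, which is why the "full" level may synchronize with the device while the "metadata" level does not.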

coo_array#

mlx_sparse.coo_array(arg, shape, *, validate='metadata', canonical=None)

Construct a COOArray from coordinate arrays.

Accepts either a (data, (row, col)) pair or an existing COOArray. All array inputs are converted to mlx.core.array if they are not already.

Parameters:
  • arg

    A (data, (row, col)) pair where

    • data: non-zero values, shape (nnz,), dtype float32 | float16 | bfloat16 | complex64.

    • row: row coordinates, shape (nnz,), dtype int32 | int64.

    • col: column coordinates, shape (nnz,), same integer dtype as row.

    Alternatively, an existing COOArray (returned unchanged if shape matches).

  • shape – Matrix dimensions as a length-2 sequence (n_rows, n_cols).

  • validate (bool | Literal['metadata', 'full', 'none']) –

    Validation level, one of:

    • "metadata" (default): checks ranks, lengths, and dtypes.

    • "full" / True: also verifies coordinate bounds. May synchronize to host.

    • False / "none": skips all checks.

  • canonical (bool | None) – Set to True to assert the coordinates are sorted and duplicate-free. Default None (not asserted).

Returns:

A COOArray with the given buffers and shape.

Raises:
  • TypeError – If arg cannot be unpacked as (data, (row, col)), or if dtype constraints are violated.

  • ValueError – If shape or length constraints are violated.

Return type:

COOArray

Example:

import mlx.core as mx
import mlx_sparse as ms

data = mx.array([1.0, 2.0, 3.0, 2.0], dtype=mx.float32)
row = mx.array([0, 0, 1, 0], dtype=mx.int32)
col = mx.array([0, 1, 0, 0], dtype=mx.int32)

# Two entries at (0, 0). Summed when converting to CSR.
coo = ms.coo_array((data, (row, col)), shape=(2, 2))
csr = coo.tocsr(canonical=True)  # sums duplicates
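
To make the duplicate handling concrete, here is a plain-Python sketch of what a canonical COO-to-CSR conversion has to produce: duplicates summed, entries sorted by (row, col), and a row-pointer array built from per-row counts. It is a reference model under those assumptions, not the code behind tocsr(); coo_to_canonical_csr is a hypothetical name.

```python
def coo_to_canonical_csr(data, row, col, shape):
    """Sum duplicate coordinates, sort by (row, col), and build CSR buffers."""
    n_rows, _ = shape

    # Accumulate values that share the same (row, col) coordinate.
    summed = {}
    for value, r, c in zip(data, row, col):
        summed[(r, c)] = summed.get((r, c), 0.0) + value

    # Emit entries in row-major order and count the entries in each row.
    out_data, out_indices = [], []
    counts = [0] * n_rows
    for (r, c) in sorted(summed):
        out_data.append(summed[(r, c)])
        out_indices.append(c)
        counts[r] += 1

    # Prefix-sum the per-row counts into row pointers.
    indptr = [0]
    for count in counts:
        indptr.append(indptr[-1] + count)
    return out_data, out_indices, indptr


# Same coordinates as the example above, with two entries at (0, 0).
csr_data, csr_indices, csr_indptr = coo_to_canonical_csr(
    data=[1.0, 2.0, 3.0, 2.0],
    row=[0, 0, 1, 0],
    col=[0, 1, 0, 0],
    shape=(2, 2),
)
print(csr_data, csr_indices, csr_indptr)
# [3.0, 2.0, 3.0] [0, 1, 0] [0, 2, 3]
```

The four COO entries collapse to three stored values because the two entries at (0, 0) are added together (1.0 + 2.0).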

eye#

mlx_sparse.eye(n, m=None, *, k=0, dtype=mlx.core.float32, index_dtype=mlx.core.int32)

Return a sparse identity-like CSR matrix with ones on a specified diagonal.

Produces the same result as numpy.eye() with k=k, but returns a CSRArray instead of a dense array. The matrix has at most min(n, m) stored values. Rows that the diagonal does not pass through are stored as empty rows in the CSR representation, and columns it does not pass through simply never appear in indices.

Parameters:
  • n (int) – Number of rows.

  • m (int | None) – Number of columns. Defaults to n, producing a square matrix.

  • k (int) – Diagonal offset. 0 selects the main diagonal. Positive values shift the diagonal above the main diagonal (superdiagonal). Negative values shift it below (subdiagonal).

  • dtype – Value dtype for the stored ones. Must be one of mx.float32, mx.float16, mx.bfloat16, or mx.complex64. Defaults to mx.float32.

  • index_dtype – Integer dtype for indices and indptr. Must be mx.int32 or mx.int64. Defaults to mx.int32.

Returns:

A canonical CSRArray with has_canonical_format=True and sorted_indices=True.

Raises:

TypeError – If dtype or index_dtype is not a supported value.

Return type:

CSRArray

Example:

import mlx_sparse as ms
import mlx.core as mx

# 4x4 identity matrix
I = ms.eye(4)
mx.eval(I.data)
# CSRArray(shape=(4, 4), nnz=4, ...)

# 3x5 matrix with ones on the first superdiagonal
A = ms.eye(3, 5, k=1)
# Non-zeros at (0,1), (1,2), (2,3)
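
The CSR layout described above, including the empty rows, can be modelled in a few lines of plain Python. This is a sketch of the expected buffers under the stated rule (row i holds a one at column i + k when that column exists), not the library's implementation; eye_csr_buffers is a hypothetical name.

```python
def eye_csr_buffers(n, m=None, k=0):
    """Return (data, indices, indptr) for an n x m matrix with ones on diagonal k."""
    m = n if m is None else m
    data, indices, indptr = [], [], [0]
    for i in range(n):
        j = i + k  # column that diagonal k crosses in row i
        if 0 <= j < m:
            data.append(1.0)
            indices.append(j)
        # A row the diagonal misses repeats the previous pointer (empty row).
        indptr.append(len(indices))
    return data, indices, indptr


print(eye_csr_buffers(3, 5, k=1))
# ([1.0, 1.0, 1.0], [1, 2, 3], [0, 1, 2, 3])

print(eye_csr_buffers(3, 3, k=-1))
# ([1.0, 1.0], [0, 1], [0, 0, 1, 2])
```

In the second call row 0 is empty, which shows up as the repeated 0 at the start of indptr.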

diags#

mlx_sparse.diags(diagonals, offsets=0, *, shape=None, dtype=None, index_dtype=mlx.core.int32)

Construct a CSR matrix from one or more diagonals.

Mirrors the behaviour of scipy.sparse.diags() but returns a CSRArray. Each diagonal is placed at the position specified by the corresponding offset. Diagonals are assembled into a COO triple and sorted before the CSR row-pointer array is built, so the result is always in canonical form.

Parameters:
  • diagonals

    The diagonal values. Accepted forms:

    • A single 1-D array-like (or scalar) placed at offsets.

    • A 2-D array whose rows are individual diagonals.

    • A list of 1-D array-likes, one per entry in offsets.

    Each diagonal’s length must not exceed the number of elements that the diagonal at the corresponding offset can hold given shape.

  • offsets – Diagonal offset(s). 0 is the main diagonal. Positive integers are superdiagonals. Negative integers are subdiagonals. When diagonals is a list, offsets must be a matching list of integers. Repeated offsets are not allowed.

  • shape (Sequence[int] | None) – Output matrix shape as (n_rows, n_cols). When omitted, the minimum square shape that fits all diagonals is inferred automatically.

  • dtype – Value dtype. When None (default), the dtype is inferred from the diagonal arrays: complex64 if any diagonal is complex, float16 if any diagonal has dtype float16, otherwise float32.

  • index_dtype – Integer dtype for indices and indptr. Must be mx.int32 or mx.int64. Defaults to mx.int32.

Returns:

A canonical CSRArray with has_canonical_format=True and sorted_indices=True.

Raises:
  • TypeError – If dtype or index_dtype is not supported.

  • ValueError – If the number of diagonals and offsets differ, if offsets are repeated, or if a diagonal is longer than its allocated space.

Return type:

CSRArray

Example:

import numpy as np
import mlx_sparse as ms
import mlx.core as mx

# Tridiagonal matrix: main diagonal 2, off-diagonals -1
A = ms.diags(
    [np.full(4, -1.0), np.full(5, 2.0), np.full(4, -1.0)],
    offsets=[-1, 0, 1],
)
# 5x5, nnz=13

# Single diagonal at offset 2
B = ms.diags([1.0, 2.0, 3.0], offsets=2, shape=(5, 5))
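
How offsets map to coordinates, and how a minimum square shape can be inferred, is sketched below in plain Python. It follows the usual convention that diagonal k starts at row max(-k, 0) and column max(k, 0); the helper names diag_coordinates and min_square_size are hypothetical and this is not the library's assembly code.

```python
def diag_coordinates(length, offset):
    """(row, col) positions of a diagonal with `length` entries at `offset`."""
    row_start = max(-offset, 0)  # subdiagonals start lower down
    col_start = max(offset, 0)   # superdiagonals start further right
    return [(row_start + j, col_start + j) for j in range(length)]


def min_square_size(lengths, offsets):
    """Smallest n such that every diagonal fits in an n x n matrix."""
    return max(length + abs(offset) for length, offset in zip(lengths, offsets))


# The tridiagonal example above: lengths 4, 5, 4 at offsets -1, 0, 1.
lengths, offsets = [4, 5, 4], [-1, 0, 1]
n = min_square_size(lengths, offsets)
coords = sorted(
    rc
    for length, k in zip(lengths, offsets)
    for rc in diag_coordinates(length, k)
)
print(n, len(coords))           # 5 13
print(diag_coordinates(3, 2))   # [(0, 2), (1, 3), (2, 4)]
```

Sorting the coordinates in row-major order is the step that lets a row-pointer array be built directly, consistent with the canonical result described above.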

fromdense#

mlx_sparse.fromdense(array, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)

Construct a canonical CSR matrix from a rank-2 dense MLX array.

Identifies the non-zero (or above-threshold) entries of a dense matrix and packages them into a CSRArray. The native path stages this as count, allocate, then fill work so Metal builds can perform the dense scan and CSR writes on device while still returning compact buffers.

The value dtype is preserved from the input array. Index dtype defaults to int32 and can be overridden for matrices with more than ~2 billion non-zeros (not typical on Apple Silicon).

Parameters:
  • array – A rank-2 array-like. Converted to mlx.core.array if not already. Dtype must be one of float32, float16, bfloat16, or complex64.

  • threshold (float) – Entries with absolute value less than or equal to threshold are treated as structural zeros and excluded from the output. The default 0.0 keeps every numerically non-zero entry. Must be non-negative.

  • dtype – Optional value dtype to cast to before extracting non-zeros. When None, the input dtype chosen by MLX is preserved.

  • index_dtype – Integer dtype for indices and indptr. Must be mx.int32 or mx.int64. Defaults to mx.int32.

Returns:

A canonical CSRArray with has_canonical_format=True and sorted_indices=True.

Raises:
  • TypeError – If the input dtype is not a supported value dtype.

  • ValueError – If the input is not rank-2, or if threshold is negative.

Return type:

CSRArray

Example:

import mlx.core as mx
import numpy as np
import mlx_sparse as ms

dense = mx.array(np.array([
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 0.0],
    [3.0, 4.0, 0.0],
], dtype=np.float32))

csr = ms.fromdense(dense)
# CSRArray(shape=(3, 3), nnz=4, dtype=float32, ...)

# Drop entries whose absolute value is at most 0.5
csr_thresholded = ms.fromdense(dense, threshold=0.5)
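
The count, allocate, then fill staging mentioned above can be illustrated with a sequential plain-Python model. It is only a sketch of the three stages and of the threshold rule (entries with absolute value less than or equal to threshold are dropped); it says nothing about how the native or Metal path is actually written, and dense_to_csr is a hypothetical name.

```python
def dense_to_csr(dense, threshold=0.0):
    """Two-pass dense-to-CSR conversion over a list of rows."""
    if threshold < 0:
        raise ValueError("threshold must be non-negative")

    # Stage 1 (count): number of kept entries per row, prefix-summed into indptr.
    indptr = [0]
    for row in dense:
        kept = sum(1 for v in row if abs(v) > threshold)
        indptr.append(indptr[-1] + kept)

    # Stage 2 (allocate): compact buffers of exactly nnz elements.
    nnz = indptr[-1]
    data = [0.0] * nnz
    indices = [0] * nnz

    # Stage 3 (fill): each row writes into its own slice, in column order.
    for i, row in enumerate(dense):
        write = indptr[i]
        for j, v in enumerate(row):
            if abs(v) > threshold:
                data[write] = v
                indices[write] = j
                write += 1
    return data, indices, indptr


dense_rows = [
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 0.0],
    [3.0, 4.0, 0.0],
]
print(dense_to_csr(dense_rows))
# ([1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], [0, 2, 2, 4])
print(dense_to_csr(dense_rows, threshold=2.0))
# ([3.0, 4.0], [0, 1], [0, 0, 0, 2])
```

Because every row knows its write offset after the count stage, the fill stage has no cross-row dependencies, which is what makes this staging suitable for parallel execution.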

from_dense#

mlx_sparse.from_dense(array, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)

Alias for fromdense() with a PEP 8 compatible name.

Parameters:

threshold (float)

Return type:

CSRArray

from_numpy#

mlx_sparse.from_numpy(array, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)

Convert a rank-2 NumPy array to a canonical CSRArray.

Parameters:

threshold (float)

Return type:

CSRArray

from_scipy#

mlx_sparse.from_scipy(matrix, *, format='csr', dtype=None, index_dtype=mlx.core.int32, canonical=True)

Convert a SciPy sparse matrix or sparse array to mlx-sparse.

Any SciPy sparse format is accepted. format="csr" returns a CSRArray, format="csc" returns a CSCArray, and format="coo" returns a COOArray. The conversion preserves supported float32, float16, and complex64 values. Other real floating dtypes, including SciPy’s default float64, are cast to float32 unless dtype is provided.

Parameters:
  • matrix – A scipy.sparse matrix or array.

  • format (str) – Output sparse format: "csr" (default), "csc", or "coo".

  • dtype – Optional MLX value dtype. Must be one of mx.float32, mx.float16, mx.bfloat16, or mx.complex64.

  • index_dtype – Integer dtype for sparse indices. Must be mx.int32 or mx.int64.

  • canonical (bool) – If True (default), sum duplicates and sort indices before exporting buffers.

Returns:

A CSRArray, CSCArray, or COOArray.

Raises:
  • TypeError – If SciPy is not installed, matrix is not sparse, or a dtype is unsupported.

  • ValueError – If format is not "csr", "csc", or "coo".

asarray#

mlx_sparse.asarray(x, *, threshold=0.0, dtype=None, index_dtype=mlx.core.int32)

Convert common sparse or dense inputs to a sparse array.

Existing CSRArray and CSCArray instances are returned unchanged unless dtype requests a value cast. COOArray instances are converted with tocsr(canonical=True). SciPy sparse matrices and arrays route through from_scipy(); dense MLX arrays, NumPy arrays, and Python array-likes route through fromdense().

Parameters:
  • x – Existing mlx-sparse array, SciPy sparse array, dense MLX array, NumPy array, or Python rank-2 array-like.

  • threshold (float) – Dense-only structural-zero threshold.

  • dtype – Optional target value dtype.

  • index_dtype – Target index dtype for newly constructed sparse arrays.

Returns:

Existing CSRArray or CSCArray inputs are preserved. Other inputs return a canonical CSRArray.

Return type:

CSRArray | CSCArray

Validation modes#

The validate parameter on csr_array() and coo_array() accepts:

  • "metadata" – Checks ranks, array lengths, and dtypes. No value reads. Default.

  • "full" or True – All metadata checks plus value-level checks (bounds, monotonicity). May call mx.eval to read index values from device.

  • False or "none" – No checks. Use only when inputs are known valid.

See Validation for a detailed discussion.
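
The equivalences between boolean and string values can be summarized as a small normalization step. The following is a sketch of that mapping only; normalize_validate is a hypothetical helper, and raising ValueError for unrecognized values is an assumption rather than documented behaviour.

```python
def normalize_validate(validate):
    """Map the accepted validate values onto one of three level names."""
    if validate is True:
        return "full"
    if validate is False:
        return "none"
    if validate in ("metadata", "full", "none"):
        return validate
    # Assumption: unknown values are rejected; the exception type is illustrative.
    raise ValueError(f"unsupported validate value: {validate!r}")


levels = [normalize_validate(v) for v in ("metadata", "full", True, False, "none")]
print(levels)
# ['metadata', 'full', 'full', 'none', 'none']
```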

Structured constructors#

eye(), diags(), fromdense(), from_dense(), from_numpy(), from_scipy(), and asarray() all make common construction paths explicit. Dense conversions return a canonical CSRArray (has_canonical_format=True, sorted_indices=True). They are host-side assembly operations and call mx.eval internally when necessary to determine the output size.

import numpy as np
import mlx.core as mx
import mlx_sparse as ms

# 5x5 identity
I = ms.eye(5)

# 4x4 tridiagonal
T = ms.diags(
    [np.full(3, -1.0), np.full(4, 2.0), np.full(3, -1.0)],
    offsets=[-1, 0, 1],
)

# Convert a dense weight matrix
W = mx.array(np.random.randn(8, 8).astype(np.float32))
W_sparse = ms.fromdense(W, threshold=0.1)

# Convert SciPy sparse or dense NumPy inputs without hand-building buffers
W_from_np = ms.asarray(np.eye(8, dtype=np.float32))