Quick start

This page covers the most common usage patterns in order. By the end you will have assembled a sparse matrix, run sparse-dense products on the selected MLX device, computed a gradient through sparse values and dense operands, and seen the structured and interchange constructors.

Selecting a device

Call use_gpu() or use_cpu() once before any computation. This sets MLX’s default device for all subsequent operations.

import mlx_sparse as ms

# Pick one of the two calls:
ms.use_cpu()  # portable default, including Linux CPU-only wheels
# ms.use_gpu()  # Apple Silicon GPU (Metal), when available

If no device is selected, MLX uses its own default (usually GPU on Apple Silicon). Fixed-shape sparse primitives support all value and index dtype combinations on CPU and on the Apple Silicon Metal backend. Linux wheels are CPU-only in this release.

Constructing a COO matrix

coo_array() accepts a (data, (row, col)) tuple plus a shape. All arrays can be mlx.core.array or anything convertible via mx.array(), including NumPy arrays.

import mlx.core as mx
import numpy as np
import mlx_sparse as ms

# Represents the 3x4 matrix:
#  [[ 2,  0, -1,  0],
#   [ 0,  0,  0,  0],
#   [ 0,  4,  0,  5]]

data = mx.array(np.array([2.0, -1.0, 4.0, 5.0], dtype=np.float32))
row = mx.array(np.array([0, 0, 2, 2], dtype=np.int32))
col = mx.array(np.array([0, 2, 1, 3], dtype=np.int32))

coo = ms.coo_array((data, (row, col)), shape=(3, 4))
print(coo)
# COOArray(shape=(3, 4), nnz=4, dtype=float32, index_dtype=int32, ...)
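To make the triplet layout concrete, the following plain-Python sketch expands the same (data, (row, col)) triple into a dense nested list. It does not call mlx_sparse; the helper name coo_to_dense is hypothetical and exists only for illustration.

```python
def coo_to_dense(data, row, col, shape):
    """Expand COO triplets into a dense nested list (hypothetical helper)."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for value, r, c in zip(data, row, col):
        # Entry k stores `value` at position (row[k], col[k]).
        # Repeated coordinates add up in this sketch.
        dense[r][c] += value
    return dense


data = [2.0, -1.0, 4.0, 5.0]
row = [0, 0, 2, 2]
col = [0, 2, 1, 3]

dense = coo_to_dense(data, row, col, shape=(3, 4))
for line in dense:
    print(line)
# [2.0, 0.0, -1.0, 0.0]
# [0.0, 0.0, 0.0, 0.0]
# [0.0, 4.0, 0.0, 5.0]
```

Row 1 has no stored entries, so it stays all zeros without taking any space in the triplet arrays.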

Converting to CSR

COO is the assembly format. Convert to CSR before running repeated products. Pass canonical=True to sort column indices and sum any duplicate entries in the same conversion call.

csr = coo.tocsr(canonical=True)
print(csr)
# CSRArray(shape=(3, 4), nnz=4, dtype=float32, index_dtype=int32,
#          sorted_indices=True, has_canonical_format=True)
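As a rough picture of what a canonical conversion produces, here is a plain-Python sketch that sorts entries by (row, col), sums duplicates, and builds the row pointer. It is not the library's implementation; coo_to_canonical_csr is a hypothetical helper, and the input deliberately contains an out-of-order duplicate at (2, 3).

```python
def coo_to_canonical_csr(data, row, col, n_rows):
    """Return (data, indices, indptr) with sorted columns and merged duplicates."""
    # Sorting by (row, col) makes each row's column indices ascending.
    entries = sorted(zip(row, col, data))

    out_data, out_indices = [], []
    counts = [0] * n_rows
    prev = None
    for r, c, v in entries:
        if prev == (r, c):
            out_data[-1] += v  # duplicate coordinate: sum into the stored value
        else:
            out_data.append(v)
            out_indices.append(c)
            counts[r] += 1
            prev = (r, c)

    # indptr[i]:indptr[i + 1] is the slice of data/indices belonging to row i.
    indptr = [0]
    for n in counts:
        indptr.append(indptr[-1] + n)
    return out_data, out_indices, indptr


# The 3x4 example, with the value 5 at (2, 3) split into 3 + 2.
data = [3.0, 2.0, 4.0, -1.0, 2.0]
row = [2, 0, 2, 0, 2]
col = [3, 0, 1, 2, 3]

csr_data, csr_indices, csr_indptr = coo_to_canonical_csr(data, row, col, n_rows=3)
print(csr_data)     # [2.0, -1.0, 4.0, 5.0]
print(csr_indices)  # [0, 2, 1, 3]
print(csr_indptr)   # [0, 2, 2, 4]
```

The repeated indptr value (2, 2) encodes the empty middle row.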

You can also construct CSR directly from three arrays if you already have the CSR buffers (e.g. from SciPy):

import scipy.sparse

sp = scipy.sparse.random(256, 256, density=0.01, format="csr",
                          dtype=np.float32, random_state=0)
csr = ms.csr_array(
    (mx.array(sp.data),
     mx.array(sp.indices, dtype=mx.int32),
     mx.array(sp.indptr, dtype=mx.int32)),
    shape=sp.shape,
    sorted_indices=True,
    canonical=True,
)

Structured constructors

For common structured matrices and interchange, use eye(), diags(), fromdense(), from_scipy(), or asarray() instead of assembling a COO triple by hand.

import numpy as np

# 4x4 identity
I = ms.eye(4)

# Tridiagonal Laplacian: diagonals [-1, 2, -1] at offsets [-1, 0, 1]
n = 6
L = ms.diags(
    [np.full(n - 1, -1.0), np.full(n, 2.0), np.full(n - 1, -1.0)],
    offsets=[-1, 0, 1],
)

# Dense-to-sparse conversion
dense = mx.array(np.eye(4, dtype=np.float32) * 3.0)
csr = ms.fromdense(dense)

# Generic conversion: CSR inputs pass through, SciPy sparse and dense arrays
# become canonical CSRArray instances.
csr = ms.asarray(np.eye(4, dtype=np.float32))
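The offsets argument is easiest to read from the dense result. The sketch below places diagonals in a nested list, assuming diags() follows the SciPy convention in which positive offsets lie above the main diagonal and negative offsets below it. dense_from_diagonals is a hypothetical helper and does not use mlx_sparse.

```python
def dense_from_diagonals(diagonals, offsets, n):
    """Place each diagonal at its offset in an n x n nested list."""
    out = [[0.0] * n for _ in range(n)]
    for diag, k in zip(diagonals, offsets):
        for i, value in enumerate(diag):
            # Assumed convention: k > 0 is above the main diagonal, k < 0 below.
            r, c = (i, i + k) if k >= 0 else (i - k, i)
            out[r][c] = value
    return out


n = 6
laplacian = dense_from_diagonals(
    [[-1.0] * (n - 1), [2.0] * n, [-1.0] * (n - 1)],
    offsets=[-1, 0, 1],
    n=n,
)
print(laplacian[0])  # [2.0, -1.0, 0.0, 0.0, 0.0, 0.0]
print(laplacian[1])  # [-1.0, 2.0, -1.0, 0.0, 0.0, 0.0]
```

Off-diagonals at offset ±1 hold n - 1 values, which is why the example passes arrays of length n - 1 for them.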

Sparse-dense matrix-vector product

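Use the @ operator. The result is a lazy MLX array: nothing is computed until the array is evaluated, either explicitly with mx.eval() or implicitly when its values are read (for example, by printing it or converting it to NumPy).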

csr = coo.tocsr(canonical=True)  # the 3x4 matrix from "Converting to CSR"
x = mx.array(np.ones(4, dtype=np.float32))
y = csr @ x  # lazy, shape (3,)
mx.eval(y)  # evaluate on the active device
print(np.array(y))  # [1. 0. 9.]
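For reference, the product above can be reproduced with a row-by-row loop over the CSR buffers. This plain-Python sketch only illustrates the arithmetic for the 3x4 example; csr_matvec_reference is a hypothetical name, and the library's kernels are not assumed to work this way.

```python
def csr_matvec_reference(data, indices, indptr, x):
    """Row-by-row CSR matrix-vector product in plain Python."""
    y = []
    for i in range(len(indptr) - 1):
        acc = 0.0
        # Stored entries of row i live in data[indptr[i]:indptr[i + 1]].
        for k in range(indptr[i], indptr[i + 1]):
            acc += data[k] * x[indices[k]]
        y.append(acc)
    return y


# Canonical CSR buffers of the 3x4 example matrix.
data = [2.0, -1.0, 4.0, 5.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]

y = csr_matvec_reference(data, indices, indptr, [1.0, 1.0, 1.0, 1.0])
print(y)  # [1.0, 0.0, 9.0]
```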

Sparse-dense matrix-matrix product

The same @ operator dispatches to csr_matmul() when the right-hand side is rank-2, and handles batched dense operands as well.

B = mx.array(np.random.randn(4, 8).astype(np.float32))
Y = csr @ B  # shape (3, 8)
mx.eval(Y)

# Batched: rhs shape (batch, n_cols, k) -> output shape (batch, n_rows, k)
B_batch = mx.array(np.random.randn(2, 4, 8).astype(np.float32))
Y_batch = csr @ B_batch  # shape (2, 3, 8)
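The shape rule for batched operands can be checked with a small reference: the same sparse matrix is applied to each (n_cols, k) slice of the batch. The sketch below is plain Python with a hypothetical helper name, and it says nothing about how the library actually schedules the work.

```python
def csr_matmul_reference(data, indices, indptr, B):
    """Compute A @ B for CSR A and a dense nested list B of shape (n_cols, k)."""
    n_rows = len(indptr) - 1
    k = len(B[0])
    Y = [[0.0] * k for _ in range(n_rows)]
    for i in range(n_rows):
        for p in range(indptr[i], indptr[i + 1]):
            a, j = data[p], indices[p]
            # Row i of the output accumulates a * (row j of B).
            for c in range(k):
                Y[i][c] += a * B[j][c]
    return Y


data = [2.0, -1.0, 4.0, 5.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]

B0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
B1 = [[1.0, 1.0] for _ in range(4)]

# Batched: one independent product per slice, so (2, 4, 2) -> (2, 3, 2).
Y_batch = [csr_matmul_reference(data, indices, indptr, B) for B in (B0, B1)]
print(Y_batch[0])  # [[1.0, -1.0], [0.0, 0.0], [10.0, 4.0]]
print(Y_batch[1])  # [[1.0, 1.0], [0.0, 0.0], [9.0, 9.0]]
```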

Sparse-sparse matrix product

When both operands are CSRArray instances, @ dispatches to csr_matmat() and returns a new CSRArray:

A = ms.eye(4)
B = ms.diags([1.0, 2.0, 3.0, 4.0])
C = A @ B  # CSRArray, sparse-sparse product
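One common way to multiply two CSR matrices is the row-wise (Gustavson-style) scheme sketched below in plain Python. Whether csr_matmat() uses this approach is not stated here, so treat it only as a reference for the expected result; csr_matmat_reference is a hypothetical name.

```python
def csr_matmat_reference(a, b):
    """Multiply two CSR matrices given as (data, indices, indptr) tuples."""
    a_data, a_indices, a_indptr = a
    b_data, b_indices, b_indptr = b
    out_data, out_indices, out_indptr = [], [], [0]
    for i in range(len(a_indptr) - 1):
        acc = {}  # column index -> accumulated value for output row i
        for p in range(a_indptr[i], a_indptr[i + 1]):
            k, a_ik = a_indices[p], a_data[p]
            # Row i of the result is a weighted sum of rows k of b.
            for q in range(b_indptr[k], b_indptr[k + 1]):
                j = b_indices[q]
                acc[j] = acc.get(j, 0.0) + a_ik * b_data[q]
        for j in sorted(acc):  # keep column indices sorted within the row
            out_indices.append(j)
            out_data.append(acc[j])
        out_indptr.append(len(out_indices))
    return out_data, out_indices, out_indptr


identity = ([1.0, 1.0, 1.0, 1.0], [0, 1, 2, 3], [0, 1, 2, 3, 4])
diagonal = ([1.0, 2.0, 3.0, 4.0], [0, 1, 2, 3], [0, 1, 2, 3, 4])
example = ([2.0, -1.0, 4.0, 5.0], [0, 2, 1, 3], [0, 2, 2, 4])  # 3x4 matrix

print(csr_matmat_reference(identity, diagonal))
# ([1.0, 2.0, 3.0, 4.0], [0, 1, 2, 3], [0, 1, 2, 3, 4])
print(csr_matmat_reference(example, diagonal))
# ([2.0, -3.0, 8.0, 20.0], [0, 2, 1, 3], [0, 2, 2, 4])
```

This sketch keeps entries that cancel to exactly zero as explicit stored zeros; a production implementation may choose to prune them.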

Converting to dense

dense = csr.todense()  # mx.array, shape (3, 4)
# or, using the module-level helper:
dense = ms.todense(csr)

Transpose and Hermitian transpose

T returns a structural transpose as a new CSRArray with shape (n_cols, n_rows). H additionally conjugates the values.

At = csr.T  # CSRArray(shape=(4, 3), ...)
Ah = csr.H  # conjugate transpose; differs from T only for complex64 matrices
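A structural transpose of CSR buffers can be sketched as a counting sort on column indices. The plain-Python version below is illustrative only (csr_transpose_reference is a hypothetical helper, not the library's code); passing conjugate=True mimics the Hermitian case.

```python
def csr_transpose_reference(data, indices, indptr, n_cols, conjugate=False):
    """Return CSR buffers of the (optionally conjugated) transpose."""
    n_rows = len(indptr) - 1

    # Count entries per column, then prefix-sum into the new row pointer.
    t_indptr = [0] * (n_cols + 1)
    for j in indices:
        t_indptr[j + 1] += 1
    for j in range(n_cols):
        t_indptr[j + 1] += t_indptr[j]

    t_data = [0] * len(data)
    t_indices = [0] * len(data)
    next_slot = t_indptr[:-1]  # next free position for each output row (a copy)
    for i in range(n_rows):
        for p in range(indptr[i], indptr[i + 1]):
            j = indices[p]
            dst = next_slot[j]
            t_indices[dst] = i  # old row index becomes the new column index
            t_data[dst] = data[p].conjugate() if conjugate else data[p]
            next_slot[j] += 1
    return t_data, t_indices, t_indptr


data = [2.0, -1.0, 4.0, 5.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]

t_data, t_indices, t_indptr = csr_transpose_reference(data, indices, indptr, n_cols=4)
print(t_data)     # [2.0, 4.0, -1.0, 5.0]
print(t_indices)  # [0, 2, 0, 2]
print(t_indptr)   # [0, 1, 2, 3, 4]
```

Because rows are visited in ascending order, the column indices of the transposed matrix come out sorted without an extra pass.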

Computing gradients

mx.grad differentiates through csr_matvec() and csr_matmul() with respect to both sparse data values and the dense operand. CPU and GPU are supported for real and complex64 dtypes.

ms.use_gpu()  # or ms.use_cpu() on platforms without Metal

csr = ms.coo_array((data, (row, col)), shape=(3, 4)).tocsr()
x = mx.array(np.ones(4, dtype=np.float32))

def loss(values, x):
    A = ms.csr_array((values, csr.indices, csr.indptr), shape=csr.shape)
    y = A @ x
    return mx.sum(y * y)

grad_values, grad_x = mx.grad(loss, argnums=(0, 1))(csr.data, x)
mx.eval(grad_values, grad_x)

Both gradients match mx.grad applied to the equivalent dense matrix multiply, up to floating-point rounding. grad_values corresponds to the dense gradient at the stored (row, col) positions only.
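For the loss above, the expected values can be worked out by hand: with y = A @ x, the gradient with respect to stored value k is 2 * y[row_k] * x[col_k], and the gradient with respect to x is A^T @ (2 * y). The plain-Python sketch below evaluates these formulas for the 3x4 example; it is a reference calculation under those formulas, not the library's autodiff, and loss_and_grads is a hypothetical name.

```python
def loss_and_grads(data, indices, indptr, x):
    """Return loss = sum((A @ x) ** 2) and its gradients w.r.t. data and x."""
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for i in range(n_rows):
        for p in range(indptr[i], indptr[i + 1]):
            y[i] += data[p] * x[indices[p]]
    loss = sum(v * v for v in y)

    grad_data = [0.0] * len(data)
    grad_x = [0.0] * len(x)
    for i in range(n_rows):
        dy = 2.0 * y[i]  # d(loss) / d(y[i])
        for p in range(indptr[i], indptr[i + 1]):
            j = indices[p]
            grad_data[p] = dy * x[j]   # only stored entries receive a gradient
            grad_x[j] += dy * data[p]  # accumulates A^T @ (2 * y)
    return loss, grad_data, grad_x


data = [2.0, -1.0, 4.0, 5.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
x = [1.0, 1.0, 1.0, 1.0]

loss, grad_data, grad_x = loss_and_grads(data, indices, indptr, x)
print(loss)       # 82.0
print(grad_data)  # [2.0, 2.0, 18.0, 18.0]
print(grad_x)     # [4.0, 72.0, -2.0, 90.0]
```

These numbers give a quick sanity target when comparing against the mx.grad output for the same inputs.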