First steps with mlx-sparse

This notebook walks through the minimum viable workflow: construct a sparse matrix, inspect its storage, run a sparse-dense product, and convert back to dense.

What is a sparse matrix?

A sparse matrix is one where most entries are zero. Instead of storing all m x n values, we store only the non-zeros together with their coordinates. For a 10,000 x 10,000 matrix with 0.1% density, that is 100,000 stored values instead of 100,000,000, a 1,000x reduction in stored values (the coordinates add some index overhead on top of that).
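
The memory argument can be checked with a quick back-of-envelope calculation. The sketch below assumes 4-byte values (float32) and 4-byte indices (int32), matching the dtypes used later in this notebook, and a COO-style layout with one row index and one column index per stored entry. The helper names are hypothetical and not part of mlx-sparse.

```python
# Back-of-envelope storage estimate (assumed sizes: float32 values, int32 indices).
def dense_bytes(m, n, value_bytes=4):
    # A dense matrix stores every entry, zero or not.
    return m * n * value_bytes

def coo_bytes(nnz, value_bytes=4, index_bytes=4):
    # COO stores one value plus a row index and a column index per entry.
    return nnz * (value_bytes + 2 * index_bytes)

m = n = 10_000
nnz = m * n // 1000  # 0.1% density

print("stored values:", nnz)
print("dense bytes:  ", dense_bytes(m, n))
print("COO bytes:    ", coo_bytes(nnz))
print("byte ratio:   ", round(dense_bytes(m, n) / coo_bytes(nnz), 1))
```

Under these assumptions the saving in bytes is smaller than the saving in stored values, because each non-zero carries two indices.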

mlx-sparse provides three primary formats:

  • COO (Coordinate): one (row, column, value) triple per entry; the assembly format, with native coordinate products.

  • CSR (Compressed Sparse Row): compressed row storage; efficient row-scan products.

  • CSC (Compressed Sparse Column): compressed column storage; efficient column-oriented products.

This first notebook assembles in COO and then uses CSR to show the row-oriented product path.
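
To make the three layouts concrete, here is a minimal pure-Python sketch that stores the same small matrix, [[1, 0, 2], [0, 3, 0]], in each format and rebuilds the dense form from it. The `from_coo`, `from_csr`, and `from_csc` helpers are hypothetical names for illustration only; they are not mlx-sparse functions and say nothing about how the library stores its buffers internally.

```python
# Hypothetical helpers for illustration; not part of mlx-sparse.
def from_coo(data, row, col, shape):
    m, n = shape
    dense = [[0.0] * n for _ in range(m)]
    for v, i, j in zip(data, row, col):
        dense[i][j] += v  # one explicit (row, col) pair per value
    return dense

def from_csr(data, indices, indptr, shape):
    m, n = shape
    dense = [[0.0] * n for _ in range(m)]
    for i in range(m):
        # Row i owns the slice indptr[i]:indptr[i + 1]; indices holds columns.
        for k in range(indptr[i], indptr[i + 1]):
            dense[i][indices[k]] += data[k]
    return dense

def from_csc(data, indices, indptr, shape):
    m, n = shape
    dense = [[0.0] * n for _ in range(m)]
    for j in range(n):
        # Column j owns the slice indptr[j]:indptr[j + 1]; indices holds rows.
        for k in range(indptr[j], indptr[j + 1]):
            dense[indices[k]][j] += data[k]
    return dense

shape = (2, 3)  # the matrix [[1, 0, 2], [0, 3, 0]]
a_coo = from_coo([1.0, 2.0, 3.0], [0, 0, 1], [0, 2, 1], shape)
a_csr = from_csr([1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 3], shape)
a_csc = from_csc([1.0, 3.0, 2.0], [0, 1, 0], [0, 1, 2, 3], shape)

print(a_coo == a_csr == a_csc)
```

Note that CSR and CSC hold the same values in a different order: CSR walks the matrix row by row, CSC column by column.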

import mlx.core as mx
import numpy as np
import mlx_sparse as ms

# Use GPU for all operations (Apple Silicon Metal backend).
ms.use_gpu()
Device(gpu, 0)

Building a COO matrix

We will represent the 3 x 4 matrix:

[[2,  0, -1,  0],
 [0,  0,  0,  0],
 [0,  4,  0,  5]]

Four non-zero entries at positions (0,0), (0,2), (2,1), (2,3).

data = mx.array([2.0, -1.0, 4.0, 5.0], dtype=mx.float32)
row = mx.array([0, 0, 2, 2], dtype=mx.int32)
col = mx.array([0, 2, 1, 3], dtype=mx.int32)

coo = ms.coo_array((data, (row, col)), shape=(3, 4))
print(coo)
COOArray(shape=(3, 4), nnz=4, dtype=mlx.core.float32, index_dtype=mlx.core.int32, has_canonical_format=False)

Converting to CSR

Pass canonical=True to sort column indices and sum any duplicate entries in one step.

csr = coo.tocsr(canonical=True)
print(csr)

# Materialise and inspect the raw buffers.
mx.eval(csr.data, csr.indices, csr.indptr)
print()
print("data:   ", np.array(csr.data))
print("indices:", np.array(csr.indices))
print("indptr: ", np.array(csr.indptr))
CSRArray(shape=(3, 4), nnz=4, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)

data:    [ 2. -1.  4.  5.]
indices: [0 2 1 3]
indptr:  [0 2 2 4]

Reading the indptr: row 0 spans entries [0, 2) (columns 0 and 2), row 1 spans [2, 2) (empty), and row 2 spans [2, 4) (columns 1 and 3).
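
The conversion itself can be sketched in a few lines of plain Python: sum duplicates, sort by (row, column), then build indptr as a running count of entries per row. This is an illustrative reference under those assumptions, not the kernel mlx-sparse actually runs, and `coo_to_csr` is a hypothetical name.

```python
# Illustrative COO -> CSR conversion; not the mlx-sparse implementation.
def coo_to_csr(data, row, col, n_rows):
    # Sum duplicate (row, col) entries.
    merged = {}
    for v, i, j in zip(data, row, col):
        merged[(i, j)] = merged.get((i, j), 0.0) + v

    # Sort by row first, then by column within each row.
    keys = sorted(merged)
    out_data = [merged[k] for k in keys]
    indices = [j for _, j in keys]

    # Count entries per row, then take a prefix sum to get row offsets.
    indptr = [0] * (n_rows + 1)
    for i, _ in keys:
        indptr[i + 1] += 1
    for i in range(n_rows):
        indptr[i + 1] += indptr[i]
    return out_data, indices, indptr

csr_data, csr_indices, csr_indptr = coo_to_csr(
    [2.0, -1.0, 4.0, 5.0], [0, 0, 2, 2], [0, 2, 1, 3], n_rows=3
)
print(csr_data, csr_indices, csr_indptr)

# Walk the rows using indptr, as described above.
for i in range(3):
    start, stop = csr_indptr[i], csr_indptr[i + 1]
    print("row", i, "columns", csr_indices[start:stop])
```

An empty row shows up as two equal consecutive indptr values, which is why row 1 yields an empty column list.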

dense = csr.todense()
mx.eval(dense)
print(np.array(dense))
[[ 2.  0. -1.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  4.  0.  5.]]

Sparse matrix-vector product

The @ operator dispatches to the native csr_matvec primitive. The result is a lazy MLX array: nothing is computed until the array is evaluated, for example with mx.eval.

x = mx.array([1.0, 1.0, 1.0, 1.0], dtype=mx.float32)

y = csr @ x  # lazy, dispatches Metal kernel on eval
mx.eval(y)

print("y = A @ x:", np.array(y))

# Verify against dense matmul.
y_dense = csr.todense() @ x
mx.eval(y_dense)
print("expected: ", np.array(y_dense))
y = A @ x: [1. 0. 9.]
expected:  [1. 0. 9.]
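
For intuition about what a row-scan product does, here is a minimal pure-Python reference of a CSR matrix-vector product. It is a sketch of the general technique, not the native csr_matvec primitive, and the function name `csr_matvec_reference` is hypothetical.

```python
# Reference CSR matvec in plain Python; illustrative only.
def csr_matvec_reference(data, indices, indptr, x):
    y = []
    for i in range(len(indptr) - 1):
        acc = 0.0
        # Only the stored entries of row i contribute to y[i].
        for k in range(indptr[i], indptr[i + 1]):
            acc += data[k] * x[indices[k]]
        y.append(acc)
    return y

y_ref = csr_matvec_reference(
    [2.0, -1.0, 4.0, 5.0],  # data
    [0, 2, 1, 3],           # indices
    [0, 2, 2, 4],           # indptr
    [1.0, 1.0, 1.0, 1.0],   # x
)
print(y_ref)  # [1.0, 0.0, 9.0]
```

Each output element depends only on one row's slice of data and indices, which is what makes the row-oriented layout convenient for this product.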

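Transpose and Hermitian transpose

csr.T returns the transpose as a new sparse array. The matrix here is real-valued, so its Hermitian (conjugate) transpose has the same entries as its plain transpose.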

At = csr.T
print("A.T shape:", At.shape)
print(np.array(At.todense()))
A.T shape: (4, 3)
[[ 2.  0.  0.]
 [ 0.  0.  4.]
 [-1.  0.  0.]
 [ 0.  0.  5.]]
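
One way to transpose CSR buffers is a counting sort on the column indices: count the entries in each column, prefix-sum the counts into the new row offsets, then scatter each entry into its slot. The sketch below shows that technique in plain Python. How mlx-sparse implements .T is not described in this notebook, so treat `csr_transpose` as a hypothetical helper rather than the library's method.

```python
# Illustrative CSR transpose via counting sort; not the mlx-sparse implementation.
def csr_transpose(data, indices, indptr, n_cols):
    n_rows = len(indptr) - 1
    nnz = len(data)

    # Entries per column of A become entries per row of A.T.
    t_indptr = [0] * (n_cols + 1)
    for j in indices:
        t_indptr[j + 1] += 1
    for j in range(n_cols):
        t_indptr[j + 1] += t_indptr[j]

    t_data = [0.0] * nnz
    t_indices = [0] * nnz
    next_slot = t_indptr[:-1].copy()  # next free position in each output row
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            j = indices[k]
            dst = next_slot[j]
            t_data[dst] = data[k]
            t_indices[dst] = i  # the old row becomes the new column
            next_slot[j] += 1
    return t_data, t_indices, t_indptr

t_data, t_indices, t_indptr = csr_transpose(
    [2.0, -1.0, 4.0, 5.0], [0, 2, 1, 3], [0, 2, 2, 4], n_cols=4
)
print(t_data, t_indices, t_indptr)
```

Because the input rows are visited in increasing order, the column indices within each output row come out sorted.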

Summary

Step         Call                                          Result
-----------  --------------------------------------------  -------------------------
Construct    ms.coo_array((data, (row, col)), shape=...)   COOArray
Convert      coo.tocsr(canonical=True)                     CSRArray
Product      csr @ x                                       lazy mx.array
Materialise  mx.eval(y)                                    computes on device
Dense view   csr.todense()                                 mx.array shape (m, n)
Transpose    csr.T                                         new CSRArray shape (n, m)

Next: 02 — CSR matvec for larger matrices and timing.