First steps with mlx-sparse
This notebook walks through the minimum viable workflow: construct a sparse matrix, inspect its storage, run a sparse-dense product, and convert back to dense.
What is a sparse matrix?
A sparse matrix is one where most entries are zero. Instead of storing all
m x n values, we store only the non-zeros together with their coordinates.
For a 10,000 x 10,000 matrix with 0.1% density, that is 100,000 stored values
vs 100,000,000, a 1,000x reduction in stored values. The memory saving is
somewhat smaller, because each stored non-zero also carries its index information.
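The arithmetic above can be checked with a few lines of Python. This sketch assumes float32 values and int32 indices (4 bytes each), matching the dtypes used later in this notebook, and counts raw buffer sizes only, ignoring any per-array overhead.

```python
# Raw buffer sizes for an n x n matrix at a given density.
# Assumption: float32 values and int32 indices, 4 bytes each.
n = 10_000
density = 0.001
value_bytes = 4
index_bytes = 4

nnz = round(n * n * density)                         # stored non-zeros
dense_bytes = n * n * value_bytes                    # every entry stored
coo_bytes = nnz * (value_bytes + 2 * index_bytes)    # value + row + col per entry
csr_bytes = nnz * (value_bytes + index_bytes) + (n + 1) * index_bytes  # data + indices + indptr

print(f"nnz:   {nnz:,}")
print(f"dense: {dense_bytes / 1e6:.1f} MB")
print(f"COO:   {coo_bytes / 1e6:.2f} MB ({dense_bytes / coo_bytes:.0f}x smaller)")
print(f"CSR:   {csr_bytes / 1e6:.2f} MB ({dense_bytes / csr_bytes:.0f}x smaller)")
```

Under these assumptions the 1,000x reduction in stored values becomes roughly 333x for COO and 476x for CSR once index storage is counted.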
mlx-sparse provides three primary formats:
COO (Coordinate): assembly format and native coordinate products.
CSR (Compressed Sparse Row): compressed row storage; efficient row-scan products.
CSC (Compressed Sparse Column): compressed column storage; efficient column-oriented products.
This first notebook assembles in COO and then uses CSR to show the row-oriented product path.
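The three layouts above can be compared on a small example. The following plain-Python sketch makes no mlx-sparse calls; the matrix M and all variable names are illustrative. It builds the COO, CSR, and CSC buffers for a 3 x 4 matrix directly from its dense form.

```python
# Illustrative only: build COO, CSR and CSC buffers from a small dense matrix.
M = [
    [1, 0, 0, 2],
    [0, 0, 3, 0],
    [4, 0, 0, 5],
]
n_rows, n_cols = len(M), len(M[0])

# COO: one (row, col, value) triple per non-zero, collected in row-major order.
coo_row, coo_col, coo_data = [], [], []
for i in range(n_rows):
    for j in range(n_cols):
        if M[i][j] != 0:
            coo_row.append(i)
            coo_col.append(j)
            coo_data.append(M[i][j])

# CSR: rows are the outer loop; indptr[i]:indptr[i+1] is row i's slice.
csr_data, csr_indices, csr_indptr = [], [], [0]
for i in range(n_rows):
    for j in range(n_cols):
        if M[i][j] != 0:
            csr_data.append(M[i][j])
            csr_indices.append(j)   # store the column index
    csr_indptr.append(len(csr_data))

# CSC: the same idea with columns as the outer loop.
csc_data, csc_indices, csc_indptr = [], [], [0]
for j in range(n_cols):
    for i in range(n_rows):
        if M[i][j] != 0:
            csc_data.append(M[i][j])
            csc_indices.append(i)   # store the row index
    csc_indptr.append(len(csc_data))

print("COO row/col/data:        ", coo_row, coo_col, coo_data)
print("CSR data/indices/indptr: ", csr_data, csr_indices, csr_indptr)
print("CSC data/indices/indptr: ", csc_data, csc_indices, csc_indptr)
```

CSR's indptr has one entry per row plus one, and CSC's has one entry per column plus one; that pointer array is what makes a whole row (CSR) or column (CSC) reachable with a single slice.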
import mlx.core as mx
import numpy as np
import mlx_sparse as ms
# Use GPU for all operations (Apple Silicon Metal backend).
ms.use_gpu()
Device(gpu, 0)
Building a COO matrix
We will represent the 3 x 4 matrix:
[[2, 0, -1, 0],
[0, 0, 0, 0],
[0, 4, 0, 5]]
It has four non-zero entries, at (row, column) positions (0,0), (0,2), (2,1) and (2,3).
data = mx.array([2.0, -1.0, 4.0, 5.0], dtype=mx.float32)
row = mx.array([0, 0, 2, 2], dtype=mx.int32)
col = mx.array([0, 2, 1, 3], dtype=mx.int32)
coo = ms.coo_array((data, (row, col)), shape=(3, 4))
print(coo)
COOArray(shape=(3, 4), nnz=4, dtype=mlx.core.float32, index_dtype=mlx.core.int32, has_canonical_format=False)
Converting to CSR
Passing canonical=True sorts the column indices within each row and sums any
duplicate entries in a single step.
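What canonical=True is described as doing (sort the column indices and sum duplicates) can be sketched in plain Python, together with the compression of row indices into indptr. This is an illustrative reimplementation with a hypothetical function name, not the mlx-sparse code path. The input deliberately shuffles the notebook's entries and splits the 5.0 at (2, 3) into two duplicate contributions.

```python
def coo_to_canonical_csr(data, row, col, n_rows):
    """Sort COO triples by (row, col), sum duplicates, compress rows.

    Hypothetical helper for illustration; not part of mlx-sparse.
    """
    # Sum values that share a (row, col) coordinate.
    acc = {}
    for v, i, j in zip(data, row, col):
        acc[(i, j)] = acc.get((i, j), 0.0) + v

    # Sorting the coordinate keys gives row-major order with sorted
    # column indices inside each row.
    keys = sorted(acc)
    out_data = [acc[k] for k in keys]
    out_indices = [j for _, j in keys]

    # Count entries per row, then prefix-sum the counts into indptr.
    indptr = [0] * (n_rows + 1)
    for i, _ in keys:
        indptr[i + 1] += 1
    for i in range(n_rows):
        indptr[i + 1] += indptr[i]
    return out_data, out_indices, indptr


# Notebook entries, shuffled, with 5.0 at (2, 3) split into 2.0 + 3.0.
py_data = [4.0, 2.0, 2.0, -1.0, 3.0]
py_row = [2, 2, 0, 0, 2]
py_col = [1, 3, 0, 2, 3]

print(coo_to_canonical_csr(py_data, py_row, py_col, n_rows=3))
# ([2.0, -1.0, 4.0, 5.0], [0, 2, 1, 3], [0, 2, 2, 4])
```

The result matches the data, indices and indptr buffers that the next cell prints for the canonical CSR array.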
csr = coo.tocsr(canonical=True)
print(csr)
# Materialise and inspect the raw buffers.
mx.eval(csr.data, csr.indices, csr.indptr)
print()
print("data: ", np.array(csr.data))
print("indices:", np.array(csr.indices))
print("indptr: ", np.array(csr.indptr))
CSRArray(shape=(3, 4), nnz=4, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
data: [ 2. -1. 4. 5.]
indices: [0 2 1 3]
indptr: [0 2 2 4]
Reading the indptr: row 0 spans entries [0, 2) (columns 0 and 2),
row 1 spans [2, 2) (empty), and row 2 spans [2, 4) (columns 1 and 3).
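The indptr walk just described can be written out directly. This plain-Python sketch copies the three printed buffers into lists (the buf_ names are illustrative) and rebuilds each row by scattering its stored values into the matching column slots.

```python
# CSR buffers printed above, copied as plain Python lists.
buf_data = [2.0, -1.0, 4.0, 5.0]
buf_indices = [0, 2, 1, 3]
buf_indptr = [0, 2, 2, 4]
n_rows, n_cols = 3, 4

dense_rows = []
for i in range(n_rows):
    start, end = buf_indptr[i], buf_indptr[i + 1]   # row i owns entries [start, end)
    row_vals = [0.0] * n_cols
    for k in range(start, end):
        row_vals[buf_indices[k]] = buf_data[k]      # scatter into the column slot
    dense_rows.append(row_vals)
    print(f"row {i}: columns {buf_indices[start:end]} -> {row_vals}")
```

The rows it builds match the dense matrix printed by csr.todense() in the next cell, including the all-zero row 1, whose empty slice [2, 2) contributes nothing.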
dense = csr.todense()
mx.eval(dense)
print(np.array(dense))
[[ 2. 0. -1. 0.]
[ 0. 0. 0. 0.]
[ 0. 4. 0. 5.]]
Sparse matrix-vector product
The @ operator dispatches to the native csr_matvec primitive.
The result is a lazy MLX array: nothing is computed until mx.eval is called.
x = mx.array([1.0, 1.0, 1.0, 1.0], dtype=mx.float32)
y = csr @ x # lazy, dispatches Metal kernel on eval
mx.eval(y)
print("y = A @ x:", np.array(y))
# Verify against dense matmul.
y_dense = csr.todense() @ x
mx.eval(y_dense)
print("expected: ", np.array(y_dense))
y = A @ x: [1. 0. 9.]
expected: [1. 0. 9.]
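The row-scan product behind a CSR matvec can be sketched in plain Python: each output entry y[i] is the dot product of row i's stored values with the matching entries of x, so zeros are never touched. This is the textbook algorithm under a hypothetical name; how the Metal csr_matvec kernel divides this work across GPU threads is not covered in this notebook.

```python
def csr_matvec_sketch(data, indices, indptr, x):
    """Compute y = A @ x for A given by CSR buffers (plain-Python sketch)."""
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for i in range(n_rows):
        acc = 0.0
        # Only the stored entries of row i contribute.
        for k in range(indptr[i], indptr[i + 1]):
            acc += data[k] * x[indices[k]]
        y[i] = acc
    return y


data_ = [2.0, -1.0, 4.0, 5.0]
indices_ = [0, 2, 1, 3]
indptr_ = [0, 2, 2, 4]
print(csr_matvec_sketch(data_, indices_, indptr_, [1.0, 1.0, 1.0, 1.0]))
# [1.0, 0.0, 9.0]
```

With x set to all ones, each y[i] is simply the sum of row i, which is why the notebook's result is [1, 0, 9].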
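Transpose and Hermitian transpose
For a real-valued matrix like this one, the Hermitian (conjugate) transpose is identical to the plain transpose, so only A.T is shown here.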
At = csr.T
print("A.T shape:", At.shape)
print(np.array(At.todense()))
A.T shape: (4, 3)
[[ 2. 0. 0.]
[ 0. 0. 4.]
[-1. 0. 0.]
[ 0. 0. 5.]]
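Transposing swaps the roles of rows and columns, so the CSR buffers of A, read as CSC buffers with the shape swapped, already describe A.T; a library can therefore return a transposed view without moving data. Whether csr.T does exactly that in mlx-sparse is not shown here. If explicit CSR buffers of A.T are needed, they can be built with a counting sort over the column indices, sketched below in plain Python (the function name is hypothetical).

```python
def csr_transpose_sketch(data, indices, indptr, n_cols):
    """Return CSR buffers of A.T from CSR buffers of A (plain-Python sketch).

    The result is also the CSC layout of A itself.
    """
    nnz = len(data)
    # Count stored entries per column of A (= per row of A.T), then prefix-sum.
    t_indptr = [0] * (n_cols + 1)
    for j in indices:
        t_indptr[j + 1] += 1
    for j in range(n_cols):
        t_indptr[j + 1] += t_indptr[j]

    # Scatter each entry into its slot. Walking A's rows in order keeps the
    # new column indices (A's row numbers) sorted within each row of A.T.
    next_slot = t_indptr[:-1]          # slicing makes a copy
    t_data = [0.0] * nnz
    t_indices = [0] * nnz
    for i in range(len(indptr) - 1):
        for k in range(indptr[i], indptr[i + 1]):
            dest = next_slot[indices[k]]
            t_data[dest] = data[k]
            t_indices[dest] = i
            next_slot[indices[k]] += 1
    return t_data, t_indices, t_indptr


print(csr_transpose_sketch([2.0, -1.0, 4.0, 5.0], [0, 2, 1, 3], [0, 2, 2, 4], n_cols=4))
# ([2.0, 4.0, -1.0, 5.0], [0, 2, 0, 2], [0, 1, 2, 3, 4])
```

Reading the result row by row gives 2 at (0, 0), 4 at (1, 2), -1 at (2, 0) and 5 at (3, 2), which matches the dense A.T printed above.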
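Summary
| Step | Call | Result |
|---|---|---|
| Construct | ms.coo_array((data, (row, col)), shape=(3, 4)) | COOArray |
| Convert | coo.tocsr(canonical=True) | CSRArray in canonical format |
| Product | csr @ x | lazy MLX array |
| Materialise | mx.eval(y) | computes on device |
| Dense view | csr.todense() | dense MLX array |
| Transpose | csr.T | new sparse array of shape (4, 3) |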
Next: 02 — CSR matvec for larger matrices and timing.