Performance benchmarks
This notebook benchmarks mlx-sparse against dense MLX matrix operations across a range of matrix sizes, densities, and right-hand-side widths.
Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0
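The key insight: sparse operations scale with the number of non-zeros (nnz),
not with n². At low densities the advantage is dramatic. Writing $\rho$ for the density, so that $\text{nnz} = \rho n^2$:

$$\text{Cost}_{\text{SpMV}} \approx O(\text{nnz} + n) = O(\rho n^2 + n)$$
$$\text{Cost}_{\text{dense}} \approx O(n^2)$$
$$\text{Speedup} \approx 1/\rho \quad \text{(ideal, ignoring the } O(n) \text{ row-pointer term and fixed per-call overhead)}$$

In the measurements below, the sparse timings sit near a fixed floor of about 0.2 ms, so observed speedups stay far below the ideal $1/\rho$.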
import time
import statistics
import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
# Enable the GPU path of mlx-sparse before any benchmark runs.
# NOTE(review): presumably a process-wide backend selection that affects every
# sparse op below — confirm against the mlx-sparse documentation.
ms.use_gpu()
def bench(fn, warmup=5, iters=50):
    """Return the median wall-clock time, in milliseconds, of ``mx.eval(fn())``.

    ``fn`` is first run and evaluated ``warmup`` times without timing, then
    ``iters`` more times with each run timed on its own. The timed span covers
    both building the lazy graph (``fn()``) and forcing it (``mx.eval``).
    ``statistics.median`` raises ``StatisticsError`` when ``iters`` is 0.
    """
    def run_once():
        # One timed sample: graph construction plus evaluation.
        start = time.perf_counter()
        mx.eval(fn())
        return time.perf_counter() - start

    for _ in range(warmup):
        mx.eval(fn())
    samples = [run_once() for _ in range(iters)]
    return 1000 * statistics.median(samples)
def make_csr(n, density, seed=0):
    """Build a random ``n x n`` float32 CSR matrix and its evaluated dense copy.

    Args:
        n: number of rows and columns.
        density: fraction of entries that are non-zero, forwarded to
            ``scipy.sparse.random``.
        seed: random seed, so the sparsity pattern and values are reproducible.

    Returns:
        ``(A, D)``, where ``A`` is an ``ms.csr_array`` and ``D = A.todense()``.
        ``D`` is already evaluated, so densification is not charged to later
        timings.
    """
    sp = scipy.sparse.random(n, n, density=density, format="csr",
                             dtype=np.float32, random_state=seed)
    # The csr_array below is told the indices are sorted and duplicate-free.
    # Enforce that here (sum_duplicates sorts and merges in place) instead of
    # relying on scipy's construction path to happen to produce it.
    sp.sum_duplicates()
    A = ms.csr_array(
        (mx.array(sp.data),
         mx.array(sp.indices.astype(np.int32)),
         mx.array(sp.indptr.astype(np.int32))),
        shape=sp.shape, sorted_indices=True, canonical=True,
    )
    D = A.todense()
    mx.eval(D)
    return A, D
SpMV: sparse vs dense matrix-vector product
Increasing matrix size from 2k to 16k while lowering the density (0.05% → 0.003%), so every matrix has roughly one non-zero per row or fewer.
# SpMV: time sparse A @ x against dense D @ x for growing n.
# A seeded generator keeps the right-hand side reproducible, like the matrix.
rng = np.random.default_rng(0)
print(f"{'shape':<14} {'nnz':<8} {'density':<10} {'sparse_ms':<12} {'dense_ms':<12} {'speedup'}")
for n, density in [(2048, 0.00050), (4096, 0.00025), (8192, 0.0001), (16384, 0.00003)]:
    A, D = make_csr(n, density)
    x = mx.array(rng.standard_normal(n, dtype=np.float32))
    # The lambdas are consumed immediately by bench, so late binding is safe.
    t_sp = bench(lambda: A @ x)
    t_dn = bench(lambda: D @ x)
    # Format each field first so it can be padded to its header's width.
    shape_str = f"({n},{n})"
    density_str = f"{density * 100:.3f}%"
    sparse_str = f"{t_sp:.3f} ms"
    dense_str = f"{t_dn:.3f} ms"
    print(f"{shape_str:<14} {A.nnz:<8} {density_str:<10} "
          f"{sparse_str:<12} {dense_str:<12} {t_dn / t_sp:.1f}x")
shape nnz density sparse_ms dense_ms speedup
(2048,2048) 2097 0.050% 0.195 ms 0.340 ms 1.7x
(4096,4096) 4194 0.025% 0.190 ms 0.771 ms 4.1x
(8192,8192) 6711 0.010% 0.194 ms 2.477 ms 12.8x
(16384,16384) 8053 0.003% 0.210 ms 9.118 ms 43.5x
SpMM: sparse vs dense matrix-matrix product
Two right-hand-side widths (k=16, k=64) at increasing matrix size, with the density lowered as n grows (0.1% → 0.01%).
# SpMM: time sparse A @ B against dense D @ B for two RHS widths per size.
rng = np.random.default_rng(0)
print(f"{'shape':<14} {'k':<5} {'density':<10} {'sparse_ms':<12} {'dense_ms':<12} {'speedup'}")
for n, density in [(2048, 0.001), (4096, 0.00025), (8192, 0.0001)]:
    # The matrix depends only on (n, density): build it once per size and
    # reuse it for every RHS width. Rows print in the same order as before.
    A, D = make_csr(n, density)
    shape_str = f"({n},{n})"
    density_str = f"{density * 100:.3f}%"
    for k in (16, 64):
        B = mx.array(rng.standard_normal((n, k), dtype=np.float32))
        t_sp = bench(lambda: A @ B)
        t_dn = bench(lambda: D @ B)
        # Pad each pre-formatted field to its header's width.
        sparse_str = f"{t_sp:.3f} ms"
        dense_str = f"{t_dn:.3f} ms"
        print(f"{shape_str:<14} {k:<5} {density_str:<10} "
              f"{sparse_str:<12} {dense_str:<12} {t_dn / t_sp:.1f}x")
shape k density sparse_ms dense_ms speedup
(2048,2048) 16 0.100% 0.174 ms 0.401 ms 2.3x
(2048,2048) 64 0.100% 0.203 ms 0.445 ms 2.2x
(4096,4096) 16 0.025% 0.196 ms 1.000 ms 5.1x
(4096,4096) 64 0.025% 0.211 ms 1.023 ms 4.8x
(8192,8192) 16 0.010% 0.203 ms 3.266 ms 16.1x
(8192,8192) 64 0.010% 0.224 ms 3.748 ms 16.8x
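Density crossover: when does dense win?
Below some density threshold, sparse is faster; above it, the overhead of indirect memory access makes dense the better choice. The threshold varies with matrix size.
We sweep density at n=4096 to look for the crossover point. In the run below, sparse stays ahead at every density tested (up to 5%), so for this size the crossover lies above 5%.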
# Sweep density at a fixed size and report which representation is faster.
n = 4096
rng = np.random.default_rng(0)
x = mx.array(rng.standard_normal(n, dtype=np.float32))
print(f"n={n} SpMV crossover analysis (M5)")
print(f"\n{'density':<12} {'nnz':<10} {'sparse_ms':<12} {'dense_ms':<12} {'ratio (dense/sparse)'}")
for density in [0.00001, 0.00005, 0.00025, 0.001, 0.005, 0.01, 0.025, 0.05]:
    A, D = make_csr(n, density)
    t_sp = bench(lambda: A @ x)
    t_dn = bench(lambda: D @ x)
    ratio = t_dn / t_sp
    winner = "[sparse faster]" if ratio > 1 else "[dense faster] "
    # Pad each pre-formatted field to its header's width.
    density_str = f"{density * 100:.4f}%"
    sparse_str = f"{t_sp:.3f} ms"
    dense_str = f"{t_dn:.3f} ms"
    print(f"{density_str:<12} {A.nnz:<10} {sparse_str:<12} {dense_str:<12} "
          f"{ratio:.1f}x {winner}")
n=4096 SpMV crossover analysis (M5)
density nnz sparse_ms dense_ms ratio (dense/sparse)
0.0010% 168 0.206 ms 0.778 ms 3.8x [sparse faster]
0.0050% 839 0.207 ms 0.776 ms 3.7x [sparse faster]
0.0250% 4194 0.197 ms 0.767 ms 3.9x [sparse faster]
0.1000% 16777 0.209 ms 0.753 ms 3.6x [sparse faster]
0.5000% 83886 0.216 ms 0.750 ms 3.5x [sparse faster]
1.0000% 167772 0.233 ms 0.765 ms 3.3x [sparse faster]
2.5000% 419430 0.236 ms 0.771 ms 3.3x [sparse faster]
5.0000% 838861 0.247 ms 0.768 ms 3.1x [sparse faster]
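Batched SpMM: multiple RHS tensors
Batched SpMM (A @ B_batch, where B_batch is rank-3) needs no Python-level loop: it reshapes internally and runs a single kernel call.
At the sizes below, its timings are on par with a per-slice Python loop (speedups of 0.9x–1.1x).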
# Compare one batched (rank-3) SpMM call against a loop of 2-D SpMM calls.
n, density, k = 4096, 0.001, 16
A, _ = make_csr(n, density)
rng = np.random.default_rng(0)
print(f"Batched SpMM vs loop-based reference (n={n}, density={density*100:.1f}%, k={k})")
print(f"\n{'batch':<7} {'batched_ms':<12} {'loop_ms':<10} {'speedup'}")
for batch in [1, 4, 8, 16, 32]:
    B_batch = mx.array(rng.standard_normal((batch, n, k), dtype=np.float32))
    # Batched call (single kernel)
    t_batched = bench(lambda: A @ B_batch)
    # Loop reference: one SpMM per slice, stacked back into a rank-3 result.
    slices = [B_batch[i] for i in range(batch)]

    def loop_spmm():
        return mx.stack([A @ s for s in slices])

    t_loop = bench(loop_spmm)
    # Pad each pre-formatted field to its header's width.
    batched_str = f"{t_batched:.3f} ms"
    loop_str = f"{t_loop:.3f} ms"
    print(f"{batch:<7} {batched_str:<12} {loop_str:<10} {t_loop / t_batched:.1f}x")
Batched SpMM vs loop-based reference (n=4096, density=0.1%, k=16)
batch batched_ms loop_ms speedup
1 0.204 ms 0.200 ms 1.0x
4 0.228 ms 0.230 ms 1.0x
8 0.245 ms 0.248 ms 1.0x
16 0.319 ms 0.300 ms 0.9x
32 0.534 ms 0.562 ms 1.1x
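Performance by dtype
float16 and bfloat16 can be faster than float32 because they halve the bytes moved per element, which reduces memory traffic. On some hardware they are no faster, because throughput is compute-bound.
In the run below, only the dense baseline speeds up (about 1.9x), so the sparse advantage shrinks from 11.5x to about 6x.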
# Time SpMV for the same matrix and vector stored in different value dtypes.
n, density = 8192, 0.0001
sp_b = scipy.sparse.random(n, n, density=density, format="csr",
                           dtype=np.float32, random_state=0)
# The csr_array calls below claim sorted, duplicate-free indices; enforce it.
sp_b.sum_duplicates()
# The index arrays do not depend on the value dtype, so build them once.
indices = mx.array(sp_b.indices.astype(np.int32))
indptr = mx.array(sp_b.indptr.astype(np.int32))
# One shared, seeded vector so every dtype multiplies the same values.
x_f32 = np.random.default_rng(0).standard_normal(n, dtype=np.float32)
print(f"SpMV timing by dtype (n={n}, density={density*100:.2f}%)")
print(f"\n{'dtype':<11} {'sparse_ms':<12} {'dense_ms':<11} {'speedup'}")
for mlx_dtype, label in [
    (mx.float32, "float32 "),
    (mx.float16, "float16 "),
    (mx.bfloat16, "bfloat16"),
]:
    data = mx.array(sp_b.data).astype(mlx_dtype)
    Ai = ms.csr_array(
        (data, indices, indptr),
        shape=sp_b.shape, sorted_indices=True, canonical=True,
    )
    Di = Ai.todense().astype(mlx_dtype)
    mx.eval(Di)
    xi = mx.array(x_f32).astype(mlx_dtype)
    t_sp = bench(lambda: Ai @ xi)
    t_dn = bench(lambda: Di @ xi)
    # Pad each pre-formatted field to its header's width.
    sparse_str = f"{t_sp:.3f} ms"
    dense_str = f"{t_dn:.3f} ms"
    print(f"{label:<11} {sparse_str:<12} {dense_str:<11} {t_dn / t_sp:.1f}x")
SpMV timing by dtype (n=8192, density=0.01%)
dtype sparse_ms dense_ms speedup
float32 0.210 ms 2.427 ms 11.5x
float16 0.211 ms 1.293 ms 6.1x
bfloat16 0.204 ms 1.295 ms 6.4x