CSR matvec (SpMV)#
csr_matvec computes y = A @ x where A is a CSRArray and x is a
rank-1 dense vector. On Apple Silicon the Metal backend dispatches:
A scalar row kernel for short rows (one thread per row).
A threadgroup vector-reduction kernel for long rows, selected automatically from the known
nnz / n_rowsratio without any host synchronisation.
All four value dtypes (float32, float16, bfloat16, complex64) and both
index dtypes (int32, int64) are supported on CPU and GPU.
import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
ms.use_gpu()
print("Native extension:", ms.is_available())
Native extension: True
Build a medium-sized random sparse matrix#
We use SciPy to generate a reproducible random CSR matrix and then transfer the buffers to mlx-sparse.
rng = np.random.default_rng(42)
sp = scipy.sparse.random(
4096, 4096, density=0.00025, format="csr",
dtype=np.float32, random_state=rng
)
A = ms.csr_array(
(mx.array(sp.data), mx.array(sp.indices.astype(np.int32)),
mx.array(sp.indptr.astype(np.int32))),
shape=sp.shape,
sorted_indices=True,
canonical=True,
)
print(A)
print(f"density: {A.nnz / (A.shape[0] * A.shape[1]) * 100:.3f}%")
CSRArray(shape=(4096, 4096), nnz=4178, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
density: 0.025%
Correctness check against SciPy#
x_np = rng.standard_normal(4096).astype(np.float32)
x = mx.array(x_np)
y_ms = A @ x
mx.eval(y_ms)
y_sp = sp @ x_np
err = np.max(np.abs(np.array(y_ms) - y_sp))
print(f"max absolute error vs SciPy: {err:.2e}")
assert err < 1e-4, "Results diverge!"
print("Results match within float32 tolerance.")
max absolute error vs SciPy: 9.54e-07
Results match within float32 tolerance.
All value dtypes work on GPU#
As of v0.0.1b0, the Metal backend supports float32, float16, bfloat16,
and complex64, all with int32 or int64 indices.
for value_dtype, mlx_dtype in [
(np.float32, mx.float32),
(np.float16, mx.float16),
("bfloat16", mx.bfloat16),
(np.complex64, mx.complex64),
]:
if value_dtype == "bfloat16":
A_typed = ms.csr_array(
(A.data.astype(mx.bfloat16), A.indices, A.indptr),
shape=A.shape, sorted_indices=True, canonical=True,
)
x_typed = x.astype(mx.bfloat16)
else:
A_typed = ms.csr_array(
(A.data.astype(mlx_dtype), A.indices, A.indptr),
shape=A.shape, sorted_indices=True, canonical=True,
)
x_typed = x.astype(mlx_dtype)
y_typed = A_typed @ x_typed
mx.eval(y_typed)
name = getattr(value_dtype, '__name__', str(value_dtype))
print(f"{name:<9} -> y.dtype = {y_typed.dtype}, shape {y_typed.shape}")
float32 -> y.dtype = float32, shape (4096,)
float16 -> y.dtype = float16, shape (4096,)
bfloat16 -> y.dtype = bfloat16, shape (4096,)
complex64 -> y.dtype = complex64, shape (4096,)
Timing: sparse vs dense on M5#
We compare csr_matvec to mx.matmul (dense) at increasing matrix sizes.
Timings are the median of 50 iterations after 5 warmup rounds.
Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0
import time, statistics
def bench(fn, warmup=5, iters=50):
for _ in range(warmup):
mx.eval(fn())
times = []
for _ in range(iters):
t0 = time.perf_counter()
mx.eval(fn())
times.append(time.perf_counter() - t0)
return statistics.median(times) * 1000 # ms
print(f"{'shape':<15} {'nnz':<8} {'density':<9} {'sparse_ms':<11} {'dense_ms':<11} {'speedup'}")
for n, density in [(4096, 0.00025), (8192, 0.0001), (16384, 0.00003), (32768, 0.00001)]:
sp_b = scipy.sparse.random(n, n, density=density, format="csr",
dtype=np.float32, random_state=0)
A_b = ms.csr_array(
(mx.array(sp_b.data), mx.array(sp_b.indices.astype(np.int32)),
mx.array(sp_b.indptr.astype(np.int32))),
shape=sp_b.shape, sorted_indices=True, canonical=True,
)
x_b = mx.array(np.random.randn(n).astype(np.float32))
dense_b = A_b.todense()
mx.eval(dense_b)
t_sp = bench(lambda: A_b @ x_b)
t_dn = bench(lambda: dense_b @ x_b)
print(f"({n},{n}){' '*(12 - len(str(n))*2)} {A_b.nnz:<8} {density*100:.3f}% {t_sp:.3f} ms {t_dn:.3f} ms {t_dn/t_sp:.1f}x")
shape nnz density sparse_ms dense_ms speedup
(4096,4096) 4178 0.025% 0.046 ms 0.871 ms 19.0x
(8192,8192) 6718 0.010% 0.061 ms 3.124 ms 51.2x
(16384,16384) 8020 0.003% 0.072 ms 12.31 ms 171.0x
(32768,32768) 10753 0.001% 0.091 ms 48.77 ms 536.0x
Key insight: at very low densities (< 0.01%) and large matrices, sparse
is dramatically faster because it only touches the non-zero entries. Dense
matmul costs O(n²) regardless of sparsity.