CSR matvec (SpMV)

csr_matvec computes y = A @ x where A is a CSRArray and x is a rank-1 dense vector. On Apple Silicon the Metal backend dispatches:

  • A scalar row kernel for short rows (one thread per row).

  • A threadgroup vector-reduction kernel for long rows, selected automatically from the known nnz / n_rows ratio without any host synchronisation.

All four value dtypes (float32, float16, bfloat16, complex64) and both index dtypes (int32, int64) are supported on CPU and GPU.
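For reference, the computation that both kernels perform can be written as a plain row loop. The sketch below is pure Python and is not the Metal implementation. `csr_matvec_reference` and `pick_kernel` are hypothetical names introduced for illustration, and the threshold of 32 non-zeros per row is an assumed placeholder rather than the library's actual cutoff.

```python
def csr_matvec_reference(data, indices, indptr, x):
    """Reference y = A @ x for CSR buffers given as plain Python lists."""
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for row in range(n_rows):  # a scalar row kernel maps one thread to each row
        acc = 0.0
        for k in range(indptr[row], indptr[row + 1]):
            acc += data[k] * x[indices[k]]
        y[row] = acc
    return y


def pick_kernel(nnz, n_rows, threshold=32):
    """Hypothetical selection rule: the mean row length decides the kernel.

    nnz and n_rows are host-side metadata, so no device readback is needed.
    The threshold value is an assumption made for this illustration.
    """
    mean_row_nnz = nnz / max(n_rows, 1)
    return "vector" if mean_row_nnz >= threshold else "scalar"


# 3x4 example matrix: [[1, 0, 2, 0], [0, 0, 0, 0], [0, 3, 0, 4]]
data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
x = [1.0, 10.0, 100.0, 1000.0]

y = csr_matvec_reference(data, indices, indptr, x)
kernel = pick_kernel(nnz=len(data), n_rows=len(indptr) - 1)
print(y, kernel)
```

Under this assumed rule, a matrix averaging about one non-zero per row would take the scalar path, while a matrix with dozens of non-zeros per row would take the vector-reduction path.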

import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms

ms.use_gpu()
print("Native extension:", ms.is_available())
Native extension: True

Build a medium-sized random sparse matrix

We use SciPy to generate a reproducible random CSR matrix and then transfer the buffers to mlx-sparse.

rng = np.random.default_rng(42)
sp = scipy.sparse.random(
    4096, 4096, density=0.00025, format="csr",
    dtype=np.float32, random_state=rng
)

A = ms.csr_array(
    (mx.array(sp.data), mx.array(sp.indices.astype(np.int32)),
     mx.array(sp.indptr.astype(np.int32))),
    shape=sp.shape,
    sorted_indices=True,
    canonical=True,
)
print(A)
print(f"density: {A.nnz / (A.shape[0] * A.shape[1]) * 100:.3f}%")
CSRArray(shape=(4096, 4096), nnz=4178, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
density: 0.025%
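The `sorted_indices=True` and `canonical=True` flags above are assertions about the buffers being passed in. Assuming they carry the same meaning as SciPy's `has_sorted_indices` and `has_canonical_format` (column indices strictly increasing within each row, so sorted with no duplicates), a minimal pure-Python check might look like the sketch below. `is_canonical_csr` is a hypothetical helper, not part of mlx-sparse.

```python
def is_canonical_csr(indices, indptr, n_cols):
    """True if each row's column indices are in range and strictly increasing."""
    if indptr[0] != 0 or indptr[-1] != len(indices):
        return False
    for row in range(len(indptr) - 1):
        start, end = indptr[row], indptr[row + 1]
        if end < start:  # row pointers must be non-decreasing
            return False
        prev = -1
        for k in range(start, end):
            col = indices[k]
            # col <= prev rejects unsorted, duplicate, and negative columns
            if col <= prev or col >= n_cols:
                return False
            prev = col
    return True


# 3x4 matrix with an empty middle row
ok = is_canonical_csr([0, 2, 1, 3], [0, 2, 2, 4], n_cols=4)
unsorted_row = is_canonical_csr([2, 0, 1, 3], [0, 2, 2, 4], n_cols=4)
duplicate_col = is_canonical_csr([0, 0, 1, 3], [0, 2, 2, 4], n_cols=4)
print(ok, unsorted_row, duplicate_col)
```

With SciPy-derived buffers, the same check could be run on `sp.indices.tolist()` and `sp.indptr.tolist()` before setting the flags.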

Correctness check against SciPy

x_np = rng.standard_normal(4096).astype(np.float32)
x = mx.array(x_np)

y_ms = A @ x
mx.eval(y_ms)

y_sp = sp @ x_np

err = np.max(np.abs(np.array(y_ms) - y_sp))
print(f"max absolute error vs SciPy: {err:.2e}")
assert err < 1e-4, "Results diverge!"
print("Results match within float32 tolerance.")
max absolute error vs SciPy: 9.54e-07
Results match within float32 tolerance.
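To judge whether an observed error is plausible for a given dtype, one option is the standard forward error bound for a dot product of k terms, γ_k · Σ|a_ij · x_j| with γ_k = k·u / (1 − k·u), where u is the unit roundoff (2⁻²⁴ for float32). The sketch below computes that bound per row in pure Python. `row_error_bounds` is a hypothetical helper, and the bound assumes no overflow or underflow. Because both results being compared here are rounded float32 values, their difference could be up to roughly twice this bound.

```python
def row_error_bounds(data, indices, indptr, x, unit_roundoff=2.0 ** -24):
    """Per-row bound gamma_k * sum(|a_ij * x_j|), with k = non-zeros in the row."""
    bounds = []
    for row in range(len(indptr) - 1):
        start, end = indptr[row], indptr[row + 1]
        k = end - start
        abs_sum = sum(abs(data[p] * x[indices[p]]) for p in range(start, end))
        gamma = k * unit_roundoff / (1.0 - k * unit_roundoff)
        bounds.append(gamma * abs_sum)
    return bounds


# 3x4 example matrix: [[1, 0, 2, 0], [0, 0, 0, 0], [0, 3, 0, 4]]
data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 3]
indptr = [0, 2, 2, 4]
x = [1.0, 10.0, 100.0, 1000.0]

bounds = row_error_bounds(data, indices, indptr, x)
print(bounds)
```

Passing a larger `unit_roundoff` (for example 2⁻¹¹ for float16 or 2⁻⁸ for bfloat16) shows how much looser a tolerance the lower-precision dtypes would need.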

All value dtypes work on GPU

As of v0.0.1b0, the Metal backend supports float32, float16, bfloat16, and complex64, all with int32 or int64 indices.

for name, mlx_dtype in [
    ("float32",   mx.float32),
    ("float16",   mx.float16),
    ("bfloat16",  mx.bfloat16),
    ("complex64", mx.complex64),
]:
    # Cast the values and the vector to the target dtype; index buffers are reused.
    A_typed = ms.csr_array(
        (A.data.astype(mlx_dtype), A.indices, A.indptr),
        shape=A.shape, sorted_indices=True, canonical=True,
    )
    x_typed = x.astype(mlx_dtype)
    y_typed = A_typed @ x_typed
    mx.eval(y_typed)
    print(f"{name:<9} -> y.dtype = {y_typed.dtype}, shape {y_typed.shape}")
float32  -> y.dtype = float32,  shape (4096,)
float16  -> y.dtype = float16,  shape (4096,)
bfloat16 -> y.dtype = bfloat16, shape (4096,)
complex64 -> y.dtype = complex64, shape (4096,)

Timing: sparse vs dense on M5

We compare csr_matvec to mx.matmul (dense) at increasing matrix sizes. Timings are the median of 50 iterations after 5 warmup rounds.

Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0

import statistics
import time

def bench(fn, warmup=5, iters=50):
    for _ in range(warmup):
        mx.eval(fn())
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        mx.eval(fn())
        times.append(time.perf_counter() - t0)
    return statistics.median(times) * 1000  # ms

print(f"{'shape':<15} {'nnz':<8} {'density':<9} {'sparse_ms':<11} {'dense_ms':<11} {'speedup'}")

for n, density in [(4096, 0.00025), (8192, 0.0001), (16384, 0.00003), (32768, 0.00001)]:
    sp_b = scipy.sparse.random(n, n, density=density, format="csr",
                               dtype=np.float32, random_state=0)
    A_b = ms.csr_array(
        (mx.array(sp_b.data), mx.array(sp_b.indices.astype(np.int32)),
         mx.array(sp_b.indptr.astype(np.int32))),
        shape=sp_b.shape, sorted_indices=True, canonical=True,
    )
    # Seeded generator so the input vector is reproducible across runs.
    x_b = mx.array(np.random.default_rng(0).standard_normal(n).astype(np.float32))
    dense_b = A_b.todense()
    mx.eval(dense_b)

    t_sp = bench(lambda: A_b @ x_b)
    t_dn = bench(lambda: dense_b @ x_b)
    print(f"({n},{n}){' '*(12 - len(str(n))*2)} {A_b.nnz:<8} {density*100:.3f}%    {t_sp:.3f} ms   {t_dn:.3f} ms   {t_dn/t_sp:.1f}x")
shape         nnz     density  sparse_ms  dense_ms   speedup
(4096,4096)   4178    0.025%   0.046 ms   0.871 ms   19.0x
(8192,8192)   6718    0.010%   0.061 ms   3.124 ms   51.2x
(16384,16384) 8020    0.003%   0.072 ms   12.31 ms   171.0x
(32768,32768) 10753   0.001%   0.091 ms   48.77 ms   536.0x

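Key insight: for large matrices at very low densities (0.025% and below in the table above), the sparse kernel is dramatically faster because its work scales with the number of non-zero entries, roughly O(nnz + n_rows), whereas a dense matrix-vector product costs O(n²) regardless of sparsity. The measured speedup grows from 19x at n = 4096 to 536x at n = 32768.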
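A rough way to see where the gap comes from is to count the bytes each approach has to move. The sketch below is a back-of-the-envelope model, not a measurement. `dense_matvec_bytes` and `csr_matvec_bytes` are hypothetical helpers, and the model assumes float32 values, int32 indices, one gathered x element per non-zero, and no cache reuse.

```python
def dense_matvec_bytes(n, value_bytes=4):
    """Bytes touched by a dense n x n matvec: the matrix plus x and y."""
    return (n * n + 2 * n) * value_bytes


def csr_matvec_bytes(n, nnz, value_bytes=4, index_bytes=4):
    """Rough bytes for a CSR matvec: per non-zero a value, a column index and
    a gathered x element, plus the row-pointer array and the output vector."""
    per_nonzero = 2 * value_bytes + index_bytes
    return nnz * per_nonzero + (n + 1) * index_bytes + n * value_bytes


# (n, nnz) pairs taken from the benchmark table above
for n, nnz in [(4096, 4178), (8192, 6718), (16384, 8020), (32768, 10753)]:
    d = dense_matvec_bytes(n)
    s = csr_matvec_bytes(n, nnz)
    print(f"n={n:<6} dense={d / 2**20:9.1f} MiB  csr={s / 2**10:7.1f} KiB  ratio={d / s:,.0f}x")
```

The byte ratio under this model is much larger than the measured speedups. One plausible reading is that at these sizes the sparse timings are dominated by fixed per-dispatch overhead rather than memory traffic, which would also fit the sparse time only roughly doubling while n grows by a factor of eight.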