CSR matmul (SpMM) and batched products

csr_matmul computes Y = A @ B, where A is a CSRArray and B is a dense array of rank 2 or higher. Batched inputs are supported: when B.ndim > 2, the batch dimensions are handled by reshaping internally rather than by looping in Python.

The @ operator on a CSRArray dispatches automatically based on the right-hand operand:

A @ x        # x.ndim == 1          -> csr_matvec
A @ B        # B.ndim == 2          -> csr_matmul
A @ B_batch  # B_batch.ndim > 2     -> csr_matmul (batched)
A @ C        # C is a CSRArray      -> csr_matmat (sparse @ sparse)
import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms

# NOTE(review): presumably makes the GPU the default device for the mlx-sparse
# work below — confirm in the mlx_sparse docs. The cell echoes the returned
# value, which is Device(gpu, 0) here.
ms.use_gpu()
Device(gpu, 0)
rng = np.random.default_rng(7)
sp = scipy.sparse.random(512, 512, density=0.002, format="csr",
                         dtype=np.float32, random_state=rng)

# Build the CSRArray straight from SciPy's CSR buffers (data, indices, indptr).
# Index arrays are cast to int32 because SciPy may store them with a wider dtype.
# The sorted/canonical flags come from SciPy's own checks rather than being
# asserted unconditionally, so a non-canonical input can never be mislabelled.
A = ms.csr_array(
    (mx.array(sp.data), mx.array(sp.indices.astype(np.int32)),
     mx.array(sp.indptr.astype(np.int32))),
    shape=sp.shape,
    sorted_indices=sp.has_sorted_indices,
    canonical=sp.has_canonical_format,
)
print("A:", A)
A: CSRArray(shape=(512, 512), nnz=524, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)

Rank-2 RHS: A @ B

# Dense right-hand side with 64 columns, drawn from the same seeded generator.
B_np = rng.standard_normal((512, 64)).astype(np.float32)
B = mx.array(B_np)

# B is rank-2, so `@` routes to csr_matmul; mx.eval forces the result.
Y = A @ B
mx.eval(Y)
print(f"Y = A @ B:  shape {Y.shape}  dtype {Y.dtype}")

# Reference: SciPy's sparse @ dense product on the identical matrix and RHS.
Y_sp = sp @ B_np
err = np.abs(np.array(Y) - Y_sp).max()
print(f"max error vs SciPy: {err:.2e}")
Y = A @ B:  shape (512, 64)  dtype mlx.core.float32
max error vs SciPy: 4.77e-07

Batched RHS: A @ B_batch

Pass a rank-3 array of shape (batch, n_cols, k), where n_cols equals A.shape[1]. The output has shape (batch, n_rows, k), where n_rows equals A.shape[0].

# B_batch: 8 independent dense matrices, each (512, 16)
B_batch = mx.array(rng.standard_normal((8, 512, 16)).astype(np.float32))
print("B_batch shape:", B_batch.shape)

# A single `@` call handles the whole batch.
Y_batch = A @ B_batch
mx.eval(Y_batch)
print("Y_batch shape:", Y_batch.shape)

# Verify one batch entry against the rank-2 path: Y_batch[idx] should equal
# A @ B_batch[idx]. The printed label is derived from `idx`, so it always names
# the slice that was actually checked.
idx = 0
Y_ref = A @ B_batch[idx]
mx.eval(Y_ref)
match = np.allclose(np.array(Y_batch[idx]), np.array(Y_ref), atol=1e-5)
print(f"Slice [{idx}] matches rank-2 result:", match)
B_batch shape: (8, 512, 16)
Y_batch shape: (8, 512, 16)
Slice [0] matches rank-2 result: True

Batched gradient

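Autodiff works through the batched matmul without any special handling. Below, mx.grad of a sum-of-squares loss through the sparse product is compared with the gradient of the same loss computed through a dense copy of A.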

def loss(B):
    """Return the sum of squared entries of the sparse product ``A @ B``.

    ``A`` is the module-level CSRArray; ``B`` may be rank-2 or batched (rank-3).
    """
    return mx.sum((A @ B) ** 2)

# Gradient of the sparse loss with respect to the whole batch, in one call.
grad_fn = mx.grad(loss)
grad_B = grad_fn(B_batch)
mx.eval(grad_B)
print("grad shape:", grad_B.shape)

# Compare to dense reference: materialize A as a dense array once.
dense = A.todense()
mx.eval(dense)
def dense_loss(B):
    """Same sum-of-squares loss as ``loss``, but computed through ``dense @ B``."""
    Y_dense = dense @ B
    return mx.sum(Y_dense ** 2)
# Gradient of the identical loss taken through the dense matrix.
dense_grad_fn = mx.grad(dense_loss)
grad_B_ref = dense_grad_fn(B_batch)
mx.eval(grad_B_ref)

err = np.abs(np.array(grad_B) - np.array(grad_B_ref)).max()
print(f"max error vs dense grad: {err:.2e}")
grad shape: (8, 512, 16)
max error vs dense grad: 9.54e-07

SpMM timing vs dense matmul

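Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0. Each figure is the mean wall-clock time per call over 50 timed calls after 5 warm-up calls, with every call's result forced by mx.eval.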

import time

def bench(fn, warmup=5, iters=50):
    """Return the mean wall-clock time of one ``fn()`` call, in milliseconds.

    Every call's result is passed to ``mx.eval`` so the computation actually
    runs inside the timed region. Because evaluation happens per call, the
    figure includes per-call dispatch/synchronization cost, which can matter
    for sub-millisecond kernels.

    Args:
        fn: zero-argument callable whose result ``mx.eval`` accepts.
        warmup: number of untimed calls made before measuring.
        iters: number of timed calls; must be at least 1.

    Raises:
        ValueError: if ``iters`` < 1 (the mean would be undefined).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    for _ in range(warmup):
        mx.eval(fn())
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1000

# Header and rows share the same column widths (14, 4, 9, 11, 11), so the table
# stays aligned for any matrix size.
print(f"{'shape':<14} {'k':<4} {'density':<9} {'sparse_ms':<11} {'dense_ms':<11} {'speedup'}")
for n, density, k in [
    (2048, 0.001, 16), (2048, 0.001, 64),
    (4096, 0.00025, 16), (4096, 0.00025, 64),
    (8192, 0.0001, 16), (8192, 0.0001, 64),
]:
    # Fixed seeds keep both the sparsity pattern and the dense RHS reproducible.
    sp_b = scipy.sparse.random(n, n, density=density, format="csr",
                               dtype=np.float32, random_state=1)
    Am = ms.csr_array(
        (mx.array(sp_b.data), mx.array(sp_b.indices.astype(np.int32)),
         mx.array(sp_b.indptr.astype(np.int32))),
        shape=sp_b.shape,
        sorted_indices=sp_b.has_sorted_indices,
        canonical=sp_b.has_canonical_format,
    )
    Bm = mx.array(np.random.default_rng(1).standard_normal((n, k), dtype=np.float32))
    # Densify once, outside the timed region, so only the matmul is measured.
    Dm = Am.todense()
    mx.eval(Dm)

    # bench() calls these lambdas immediately, so late binding of Am/Bm/Dm is harmless.
    t_sp = bench(lambda: Am @ Bm)
    t_dn = bench(lambda: Dm @ Bm)

    shape_str = f"({n},{n})"
    density_str = f"{density * 100:.3f}%"
    sparse_str = f"{t_sp:.3f} ms"
    dense_str = f"{t_dn:.3f} ms"
    print(f"{shape_str:<14} {k:<4} {density_str:<9} {sparse_str:<11} {dense_str:<11} {t_dn / t_sp:.1f}x")
shape          k    density   sparse_ms   dense_ms    speedup
(2048,2048)   16   0.100%    0.206 ms   0.441 ms   2.1x
(2048,2048)   64   0.100%    0.214 ms   0.431 ms   2.0x
(4096,4096)   16   0.025%    0.221 ms   0.963 ms   4.4x
(4096,4096)   64   0.025%    0.237 ms   1.027 ms   4.3x
(8192,8192)   16   0.010%    0.198 ms   3.308 ms   16.7x
(8192,8192)   64   0.010%    0.240 ms   3.685 ms   15.4x