CSR matmul (SpMM) and batched products
csr_matmul computes Y = A @ B, where A is a CSRArray and B is a dense array
of rank 2 or higher. Batched inputs are supported: when B.ndim > 2, the batch
dimensions are handled internally by reshaping, with no Python loops.
The @ operator on a CSRArray dispatches automatically based on the right-hand operand:
A @ x # x.ndim == 1 -> csr_matvec
A @ B # B.ndim == 2 -> csr_matmul
A @ Bbatch # B.ndim > 2 -> csr_matmul (batch)
A @ C # C is CSRArray -> csr_matmat (sparse-sparse)
import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
# Switch to the GPU (presumably sets the default device for subsequent
# mlx / mlx_sparse work); the call returns the selected MLX Device.
ms.use_gpu()
Device(gpu, 0)
# Seeded generator shared by the examples that follow.
rng = np.random.default_rng(7)

# Reference operand: a 512 x 512 SciPy CSR matrix at 0.2% density.
sp = scipy.sparse.random(
    512,
    512,
    density=0.002,
    format="csr",
    dtype=np.float32,
    random_state=rng,
)

# Wrap SciPy's (data, indices, indptr) triple as MLX arrays; the two index
# arrays are cast to int32 before conversion.
A = ms.csr_array(
    tuple(
        mx.array(part)
        for part in (
            sp.data,
            sp.indices.astype(np.int32),
            sp.indptr.astype(np.int32),
        )
    ),
    shape=sp.shape,
    sorted_indices=True,
    canonical=True,
)
print("A:", A)
A: CSRArray(shape=(512, 512), nnz=524, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
Rank-2 RHS: A @ B
# Dense right-hand side, kept in NumPy as well so SciPy can use the same data.
B_np = rng.standard_normal((512, 64)).astype(np.float32)
B = mx.array(B_np)

Y = A @ B  # CSRArray @ rank-2 dense -> csr_matmul
mx.eval(Y)
print(f"Y = A @ B: shape {Y.shape} dtype {Y.dtype}")

# Cross-check against SciPy's CSR @ dense product on identical inputs.
Y_sp = sp @ B_np
err = np.abs(np.array(Y) - Y_sp).max()
print(f"max error vs SciPy: {err:.2e}")
Y = A @ B: shape (512, 64) dtype mlx.core.float32
max error vs SciPy: 4.77e-07
Batched RHS: A @ B_batch
Pass a rank-3 array of shape (batch, n_cols, k), where n_cols is the number
of columns of A. The result has shape (batch, n_rows, k).
# B_batch: 8 independent dense matrices, each (512, 16)
B_batch = mx.array(rng.standard_normal((8, 512, 16)).astype(np.float32))
print("B_batch shape:", B_batch.shape)

Y_batch = A @ B_batch  # rank-3 RHS -> batched csr_matmul
mx.eval(Y_batch)
print("Y_batch shape:", Y_batch.shape)

# Verify one slice: Y_batch[slice_idx] must equal the rank-2 product
# A @ B_batch[slice_idx]. The same index drives both the comparison and the
# printed label, so the reported slice is always the one actually checked.
slice_idx = 0
Y_ref = A @ B_batch[slice_idx]
mx.eval(Y_ref)
match = np.allclose(np.array(Y_batch[slice_idx]), np.array(Y_ref), atol=1e-5)
print(f"Slice [{slice_idx}] matches rank-2 result:", match)
B_batch shape: (8, 512, 16)
Y_batch shape: (8, 512, 16)
Slice [0] matches rank-2 result: True
Batched gradient
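Autodiff (mx.grad) works through the batched matmul without any special
handling. Below, the gradient is checked against the same loss computed with a
dense copy of A.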
def loss(B):
    """Sum of squared entries of the sparse product ``A @ B``."""
    out = A @ B
    return mx.sum(out ** 2)


grad_B = mx.grad(loss)(B_batch)
mx.eval(grad_B)
print("grad shape:", grad_B.shape)

# Reference: the identical objective evaluated with a dense copy of A.
dense = A.todense()
mx.eval(dense)


def dense_loss(B):
    """Dense counterpart of ``loss``, used to validate the sparse gradient."""
    return mx.sum((dense @ B) ** 2)


grad_B_ref = mx.grad(dense_loss)(B_batch)
mx.eval(grad_B_ref)
err = np.max(np.abs(np.array(grad_B) - np.array(grad_B_ref)))
print(f"max error vs dense grad: {err:.2e}")
grad shape: (8, 512, 16)
max error vs dense grad: 9.54e-07
SpMM timing vs dense matmul
Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0.
Each timing is the mean wall-clock time over 50 calls, measured after 5 warm-up calls.
import time
def bench(fn, warmup=5, iters=50):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    ``fn`` is called ``warmup`` times untimed, then ``iters`` times timed.
    Each result is passed to ``mx.eval`` so the computation is actually
    materialized inside the measured region.
    """
    for _ in range(warmup):
        mx.eval(fn())

    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    elapsed = time.perf_counter() - start

    return elapsed / iters * 1000
# Benchmark sparse (CSR) vs dense matmul over a few sizes/densities.
# Each row compares ``Am @ Bm`` (CSRArray) against ``Dm @ Bm`` (same matrix,
# densified) for an (n, k) dense right-hand side.
print(f"{'shape':<14} {'k':<4} {'density':<9} {'sparse_ms':<11} {'dense_ms':<11} {'speedup'}")

# Fixed-seed generator for the dense RHS so the benchmark inputs are
# reproducible (the sparse operand is already seeded via random_state=1).
bench_rng = np.random.default_rng(1)

for n, density, k in [
    (2048, 0.001, 16), (2048, 0.001, 64),
    (4096, 0.00025, 16), (4096, 0.00025, 64),
    (8192, 0.0001, 16), (8192, 0.0001, 64),
]:
    sp_b = scipy.sparse.random(n, n, density=density, format="csr",
                               dtype=np.float32, random_state=1)
    Am = ms.csr_array(
        (mx.array(sp_b.data), mx.array(sp_b.indices.astype(np.int32)),
         mx.array(sp_b.indptr.astype(np.int32))),
        shape=sp_b.shape, sorted_indices=True, canonical=True,
    )
    Bm = mx.array(bench_rng.standard_normal((n, k)).astype(np.float32))
    Dm = Am.todense()
    mx.eval(Dm)

    # bench() calls these lambdas immediately, so capturing the loop
    # variables by reference is safe here.
    t_sp = bench(lambda: Am @ Bm)
    t_dn = bench(lambda: Dm @ Bm)

    # Pad every cell to its header column width so rows stay aligned for any
    # n (the previous arithmetic padding went negative for n >= 10000).
    shape_str = f"({n},{n})"
    print(f"{shape_str:<14} {k:<4} {density:<9.3%} "
          f"{f'{t_sp:.3f} ms':<11} {f'{t_dn:.3f} ms':<11} {t_dn / t_sp:.1f}x")
shape k density sparse_ms dense_ms speedup
(2048,2048) 16 0.100% 0.206 ms 0.441 ms 2.1x
(2048,2048) 64 0.100% 0.214 ms 0.431 ms 2.0x
(4096,4096) 16 0.025% 0.221 ms 0.963 ms 4.4x
(4096,4096) 64 0.025% 0.237 ms 1.027 ms 4.3x
(8192,8192) 16 0.010% 0.198 ms 3.308 ms 16.7x
(8192,8192) 64 0.010% 0.240 ms 3.685 ms 15.4x