Sparse-sparse products (csr_matmat)#

csr_matmat computes C = A @ B where both operands are CSRArray instances. The @ operator on CSRArray dispatches automatically based on the type and rank of the right-hand operand:

A @ x  # x is rank-1 dense -> csr_matvec
A @ B  # B is rank-2 dense -> csr_matmul
A @ C  # C is CSRArray -> csr_matmat

The result is always a canonical CSRArray with has_canonical_format=True.

Note: The output sparsity pattern is data-dependent. csr_matmat performs a structural pass on the host (calls mx.eval on inputs internally) and is therefore not suitable inside a JIT-compiled function or a GPU-hot loop.

import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms

# Presumably selects the GPU as the active device for the examples that
# follow (TODO confirm against the mlx_sparse docs); the value returned by
# the call is what the notebook echoes on the next line.
ms.use_gpu()
Device(gpu, 0)

Basic example#

# Dense reference operands: A is 3x3 and B is 3x2, both float32.
A_dense = np.array(
    [
        [1, 0, 2],
        [0, 3, 0],
        [0, 0, 4],
    ],
    dtype=np.float32,
)
B_dense = np.array(
    [
        [1, 0],
        [0, 2],
        [3, 0],
    ],
    dtype=np.float32,
)

# Convert each dense operand into a CSRArray.
A, B = (ms.fromdense(mx.array(dense)) for dense in (A_dense, B_dense))

print("A:\n", A_dense)
print("\nB:\n", B_dense)

# Both operands are CSRArray instances, so @ routes to csr_matmat.
C = A @ B
mx.eval(C.data)

print("\nC = A @ B:\n", C)
C_as_numpy = np.array(C.todense())
print(C_as_numpy)

# The same product computed densely with NumPy, for visual comparison.
dense_reference = np.matmul(A_dense, B_dense)
print("\nDense reference:\n", dense_reference)
A:
 [[1. 0. 2.]
 [0. 3. 0.]
 [0. 0. 4.]]

B:
 [[1. 0.]
 [0. 2.]
 [3. 0.]]

C = A @ B:
 CSRArray(shape=(3, 2), nnz=3, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
[[ 7.  0.]
 [ 0.  6.]
 [12.  0.]]

Dense reference:
 [[ 7.  0.]
 [ 0.  6.]
 [12.  0.]]

Correctness check on a larger random matrix#

rng = np.random.default_rng(42)

# Both operands share the same generator, so they are drawn in sequence
# (sp_a first, then sp_b) from one reproducible stream.
random_csr_opts = dict(
    density=0.001,
    format="csr",
    dtype=np.float32,
    random_state=rng,
)
sp_a = scipy.sparse.random(512, 512, **random_csr_opts)
sp_b = scipy.sparse.random(512, 512, **random_csr_opts)

A = ms.from_scipy(sp_a)
print("A:", A)
B = ms.from_scipy(sp_b)
print("B:", B)

# Call the sparse-sparse kernel explicitly instead of going through @.
C = ms.csr_matmat(A, B)
mx.eval(C.data)
print("C = A @ B:", C)

# SciPy's own sparse product, densified, serves as the reference.
C_sp = (sp_a @ sp_b).toarray()
abs_diff = np.abs(np.array(C.todense()) - C_sp)
err = abs_diff.max()
print(f"max error vs SciPy: {err:.2e}")
A: CSRArray(shape=(512, 512), nnz=262, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
B: CSRArray(shape=(512, 512), nnz=262, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
C = A @ B: CSRArray(shape=(512, 512), nnz=127, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
max error vs SciPy: 0.00e+00

Matrix powers#

Chaining @ computes sparse matrix powers. Each intermediate result is a canonical CSRArray.

n = 6

# Tridiagonal stencil: 2 on the main diagonal, -1 on the first sub- and
# super-diagonals.
sub_diag = np.full(n - 1, -1.0)
main_diag = np.full(n, 2.0)
super_diag = np.full(n - 1, -1.0)
T = ms.diags([sub_diag, main_diag, super_diag], offsets=[-1, 0, 1])
print("T (tridiagonal 6x6):\n", np.array(T.todense()))

# Repeated squaring: T^2 = T @ T, then T^4 = T^2 @ T^2.
T2 = T @ T
T4 = T2 @ T2
mx.eval(T2.data, T4.data)
print(f"\nT^2 nnz={T2.nnz}  (dense would be {n*n})")
print(f"T^4 nnz={T4.nnz}  (dense would be {n*n})")

# Dense NumPy references for both powers.
T_np = np.array(T.todense())
T2_ref = T_np @ T_np
T4_ref = np.linalg.matrix_power(T_np, 4)

err2 = np.abs(np.array(T2.todense()) - T2_ref).max()
err4 = np.abs(np.array(T4.todense()) - T4_ref).max()
print(f"max error T^2 vs dense: {err2:.2e}")
print(f"max error T^4 vs dense: {err4:.2e}")
T (tridiagonal 6x6):
 [[ 2. -1.  0.  0.  0.  0.]
 [-1.  2. -1.  0.  0.  0.]
 [ 0. -1.  2. -1.  0.  0.]
 [ 0.  0. -1.  2. -1.  0.]
 [ 0.  0.  0. -1.  2. -1.]
 [ 0.  0.  0.  0. -1.  2.]]

T^2 nnz=24  (dense would be 36)
T^4 nnz=34  (dense would be 36)
max error T^2 vs dense: 0.00e+00
max error T^4 vs dense: 0.00e+00

Non-square (rectangular) products#

rng = np.random.default_rng(7)

# Options shared by both operands; only shape and density differ.
shared_opts = dict(format="csr", dtype=np.float32, random_state=rng)
sp_a2 = scipy.sparse.random(4, 6, density=0.3, **shared_opts)
sp_b2 = scipy.sparse.random(6, 3, density=0.4, **shared_opts)

A2 = ms.from_scipy(sp_a2)
B2 = ms.from_scipy(sp_b2)

# (4x6) @ (6x3) -> (4x3)
C2 = A2 @ B2
mx.eval(C2.data)

print("A (4x6):", A2)
print("B (6x3):", B2)
print("C (4x3):", C2)

# Compare against the product of the densified SciPy operands.
ref = np.matmul(sp_a2.toarray(), sp_b2.toarray())
err = np.abs(np.array(C2.todense()) - ref).max()
print(f"max error vs dense: {err:.2e}")
A (4x6): CSRArray(shape=(4, 6), nnz=7, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
B (6x3): CSRArray(shape=(6, 3), nnz=7, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
C (4x3): CSRArray(shape=(4, 3), nnz=4, dtype=mlx.core.float32, index_dtype=mlx.core.int32, sorted_indices=True, has_canonical_format=True)
max error vs dense: 0.00e+00

Fill-in: how product density compares to input density#

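Sparse-sparse products can produce a denser result than either input; this is called fill-in. The amount depends on the sparsity patterns of the operands, and for very sparse inputs the product can even be sparser than the inputs (a fill-in factor below 1).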

n = 256
total_entries = n * n  # number of entries in a dense n x n matrix

# Sweep the nominal input density and report how dense the product becomes.
for density in (0.001, 0.005, 0.01, 0.05):
    # Fixed, distinct seeds keep the two operands reproducible and different.
    sp1 = scipy.sparse.random(
        n, n, density=density, format="csr", dtype=np.float32, random_state=0
    )
    sp2 = scipy.sparse.random(
        n, n, density=density, format="csr", dtype=np.float32, random_state=1
    )

    Ac = ms.from_scipy(sp1)
    Bc = ms.from_scipy(sp2)
    Cc = Ac @ Bc
    mx.eval(Cc.data)

    # Densities as percentages; d_in is the requested (nominal) density.
    d_in = density * 100
    d_out = Cc.nnz / total_entries * 100
    print(f"A density={d_in:.3f}%  B density={d_in:.3f}%  "
          f"C density={d_out:.3f}%  fill-in factor: {d_out/d_in:.2f}x")
A density=0.100%  B density=0.100%  C density=0.027%  fill-in factor: 0.27x
A density=0.500%  B density=0.500%  C density=0.607%  fill-in factor: 1.21x
A density=1.000%  B density=1.000%  C density=2.541%  fill-in factor: 2.54x
A density=5.000%  B density=5.000%  C density=46.960%  fill-in factor: 9.39x