SciPy interoperability

mlx-sparse does not provide automatic SciPy interop: scipy.sparse and MLX live on different sides of the host/accelerator boundary. However, CSRArray uses the same three-buffer layout as scipy.sparse.csr_matrix (data, indices, indptr), so converting in either direction is a single constructor call.

This notebook shows:

  • SciPy -> mlx-sparse (the common path for existing workflows)

  • mlx-sparse -> SciPy (for passing results to SciPy solvers or plotters)

  • Cross-validation of numerical results

  • Using SciPy solvers alongside mlx-sparse

  • COO interop

import mlx.core as mx
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import mlx_sparse as ms

# Presumably selects the GPU backend for the mlx-sparse operations below.
# NOTE(review): the scope of use_gpu() (global vs. per-stream/thread) is not
# visible here -- confirm against the mlx_sparse documentation.
ms.use_gpu()

SciPy -> mlx-sparse

# Seeded generator; later cells draw from it too, so draw order matters.
rng = np.random.default_rng(0)

# 256 x 256 random CSR matrix, 2.5% density, float32 values.
sp = scipy.sparse.random(
    256,
    256,
    format="csr",
    density=0.025,
    random_state=rng,
    dtype=np.float32,
)
print(f"SciPy CSR: {sp.shape} nnz={sp.nnz}  dtype={sp.dtype}")

def scipy_to_mlx(sp_csr):
    """Convert a SciPy sparse matrix/array to an ``mlx_sparse.CSRArray``.

    Any SciPy sparse format is accepted; it is converted to CSR first.
    ``tocsr()`` does NOT guarantee sorted indices or the absence of duplicate
    entries (for input that is already CSR it returns the same data as-is),
    so a non-canonical matrix is canonicalized on a copy before its buffers
    are handed to MLX with ``sorted_indices=True, canonical=True``. The
    caller's matrix is never modified.

    Raises:
        ValueError: if the shape or the number of stored entries does not fit
            in the int32 indices used on the MLX side.
    """
    sp_csr = sp_csr.tocsr()
    if not sp_csr.has_canonical_format:
        # tocsr() may hand back the caller's object, so work on a copy.
        sp_csr = sp_csr.copy()
        # sum_duplicates() sorts column indices within each row and merges
        # repeated (row, col) entries, which is exactly canonical CSR.
        sp_csr.sum_duplicates()

    # astype(np.int32) would silently wrap larger values; fail loudly instead.
    int32_max = np.iinfo(np.int32).max
    if max(sp_csr.shape, default=0) > int32_max or sp_csr.nnz > int32_max:
        raise ValueError(
            f"matrix of shape {sp_csr.shape} with nnz={sp_csr.nnz} does not "
            "fit in int32 indices"
        )

    return ms.csr_array(
        (
            mx.array(sp_csr.data),
            mx.array(sp_csr.indices.astype(np.int32, copy=False)),
            mx.array(sp_csr.indptr.astype(np.int32, copy=False)),
        ),
        shape=sp_csr.shape,
        sorted_indices=True,
        canonical=True,
    )

A = scipy_to_mlx(sp)
print("mlx-sparse:", A)
SciPy CSR: (256, 256) nnz=1664  dtype=float32
mlx-sparse: CSRArray(shape=(256, 256), nnz=1664, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)

mlx-sparse -> SciPy

def mlx_to_scipy(csr: ms.CSRArray):
    """Convert an ``mlx_sparse.CSRArray`` to a ``scipy.sparse.csr_array``.

    The three CSR buffers (data, indices, indptr) are evaluated on the MLX
    side and copied into NumPy exactly once each. The index arrays keep
    their MLX dtype and SciPy's constructor chooses the final index dtype.
    The previous forced ``astype(np.int32)`` made a second copy and would
    silently wrap int64 indices whose values exceed the int32 range.
    """
    # Materialize all three lazily-computed buffers in one batched call.
    mx.eval(csr.data, csr.indices, csr.indptr)
    return scipy.sparse.csr_array(
        (np.array(csr.data), np.array(csr.indices), np.array(csr.indptr)),
        shape=csr.shape,
    )

sp_back = mlx_to_scipy(A)
print(f"SciPy (from mlx): {sp_back.shape} nnz={sp_back.nnz}  dtype={sp_back.dtype}")

diff = np.max(np.abs(sp.toarray() - sp_back.toarray()))
print(f"Round-trip max diff: {diff:.2e}")
SciPy (from mlx): (256, 256) nnz=1664  dtype=float32
Round-trip max diff: 0.00e+00

Cross-validation: mlx-sparse vs SciPy (SpMV and SpMM)

# SpMV: push one random vector through both libraries and compare.
x_np = rng.standard_normal(256).astype(np.float32)
x = mx.array(x_np)
y_sp = sp @ x_np
y_mlx = A @ x
mx.eval(y_mlx)
err_v = np.abs(np.array(y_mlx) - y_sp).max()
print(f"SpMV max error mlx vs SciPy: {err_v:.2e}")

# SpMM: same comparison against a dense (256, 64) right-hand side.
B_np = rng.standard_normal((256, 64)).astype(np.float32)
B = mx.array(B_np)
Y_sp = sp @ B_np
Y_mlx = A @ B
mx.eval(Y_mlx)
err_m = np.abs(np.array(Y_mlx) - Y_sp).max()
print(f"\nSpMM (256x64) max error mlx vs SciPy: {err_m:.2e}")
SpMV max error mlx vs SciPy: 9.54e-07

SpMM (256x64) max error mlx vs SciPy: 1.91e-06

Using SciPy solvers on mlx-sparse matrices

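mlx-sparse handles the forward pass (sparse matrix products) on the GPU. When you need a linear solver (cg, spsolve, etc.), run the solve in SciPy on the CPU, converting the matrix with mlx_to_scipy first if it currently lives in MLX. Then convert the solution back to an MLX array for downstream GPU work. The example below builds the system directly in SciPy and checks the GPU-side residual against SciPy's.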

# Build a symmetric positive definite matrix AᵀA + δI with δ = 0.5:
# AᵀA is positive semidefinite, so adding δI makes every eigenvalue >= δ.
n = 64
sp_small = scipy.sparse.random(n, n, density=0.1, format="csr",
                               dtype=np.float32, random_state=10)
ATA = sp_small.T @ sp_small + scipy.sparse.eye(n, dtype=np.float32) * 0.5
ATA_csr = ATA.tocsr()

b_np = rng.standard_normal(n).astype(np.float32)

# Solve on CPU via SciPy CG (info == 0 means it converged)
x_sol, info = scipy.sparse.linalg.cg(ATA_csr, b_np)
residual = np.linalg.norm(ATA_csr @ x_sol - b_np)
print(f"CG converged: {info == 0}  residual: {residual:.2e}")

# Solution available as MLX array for downstream GPU operations
x_mlx = mx.array(x_sol.astype(np.float32))
print(f"solution as mx.array shape: {x_mlx.shape}")

# Now apply on GPU and check the residual against SciPy's instead of just
# asserting in a comment that they agree. The absolute tolerance leaves room
# for float32 rounding differences between the two SpMV implementations.
ATA_mlx = scipy_to_mlx(ATA_csr.astype(np.float32))
residual_mlx = ATA_mlx @ x_mlx - mx.array(b_np)
mx.eval(residual_mlx)
np.testing.assert_allclose(
    np.array(residual_mlx), ATA_csr @ x_sol - b_np, rtol=0, atol=1e-4
)
CG converged: True  residual: 2.93e-08
solution as mx.array shape: (64,)

COO interop

SciPy's COO constructor takes (data, (row, col)), the same input layout that ms.coo_array accepts.

# Small 8 x 8 COO matrix with a fixed seed so the repr stays readable.
sp_coo = scipy.sparse.random(
    8, 8, format="coo", density=0.3, random_state=2, dtype=np.float32
)

# SciPy's (data, (row, col)) triplet maps straight onto ms.coo_array;
# only the index arrays need casting to int32.
coo_mlx = ms.coo_array(
    (
        mx.array(sp_coo.data),
        tuple(mx.array(idx.astype(np.int32)) for idx in (sp_coo.row, sp_coo.col)),
    ),
    shape=sp_coo.shape,
)
print("mlx COO:", coo_mlx)

# Densify both sides and report the largest absolute deviation.
csr_mlx = coo_mlx.tocsr(canonical=True)
ref = sp_coo.toarray()
diff = np.abs(np.array(csr_mlx.todense()) - ref).max()
print(f"Round-trip COO->CSR max diff: {diff:.2e}")
mlx COO: COOArray(shape=(8, 8), nnz=20, dtype=float32, index_dtype=int32, has_canonical_format=False)
Round-trip COO->CSR max diff: 0.00e+00