SciPy interoperability#
mlx-sparse does not provide automatic SciPy interop, because scipy.sparse and
MLX live on different sides of the host/accelerator boundary. But the
buffer layout of CSRArray is identical to scipy.sparse.csr_matrix:
data, indices, indptr. Conversion is a one-liner in both directions.
This notebook shows:
SciPy -> mlx-sparse (the common path for existing workflows)
mlx-sparse -> SciPy (for passing results to SciPy solvers or plotters)
Cross-validation of numerical results
import mlx.core as mx
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import mlx_sparse as ms
# Ask mlx-sparse to use the GPU for the operations below
# (exact scope of this switch is defined by mlx_sparse — TODO confirm).
ms.use_gpu()
SciPy -> mlx-sparse#
# Seeded NumPy generator; it also feeds the dense test vectors drawn later.
rng = np.random.default_rng(0)

# Random float32 CSR matrix with 2.5% of its entries populated.
n_rows, n_cols = 256, 256
sp = scipy.sparse.random(
    n_rows,
    n_cols,
    density=0.025,
    format="csr",
    dtype=np.float32,
    random_state=rng,
)
print(f"SciPy CSR: {sp.shape} nnz={sp.nnz} dtype={sp.dtype}")
# NOTE: tocsr() alone does not canonicalize. It returns an existing CSR input
# unchanged, and the results of sparse products or sums may carry unsorted
# column indices, so the converter normalizes explicitly before tagging the
# result as sorted/canonical.
def scipy_to_mlx(sp_csr):
    """Convert a scipy.sparse.csr_matrix/array to mlx_sparse.CSRArray.

    Parameters
    ----------
    sp_csr : SciPy sparse matrix or array
        Any SciPy sparse container; non-CSR formats are converted with
        ``tocsr()`` first. The caller's object is never modified.

    Returns
    -------
    mlx_sparse.CSRArray
        Built from the ``data``/``indices``/``indptr`` buffers (index buffers
        cast to int32) and flagged ``sorted_indices=True, canonical=True``.

    Raises
    ------
    ValueError
        If the number of stored entries or the largest column index does not
        fit in int32, where the cast would otherwise wrap around silently.
    """
    sp_csr = sp_csr.tocsr()  # ensure format

    # The flags passed below are promises, so make them true: sum_duplicates()
    # sorts indices and merges duplicate entries. It works in place, hence the
    # copy to leave the caller's matrix untouched.
    if not sp_csr.has_canonical_format:
        sp_csr = sp_csr.copy()
        sp_csr.sum_duplicates()

    int32_max = np.iinfo(np.int32).max
    if sp_csr.nnz > int32_max or sp_csr.indices.max(initial=0) > int32_max:
        raise ValueError("matrix is too large for int32 CSR indices")

    return ms.csr_array(
        (
            mx.array(sp_csr.data),
            mx.array(sp_csr.indices.astype(np.int32)),
            mx.array(sp_csr.indptr.astype(np.int32)),
        ),
        shape=sp_csr.shape,
        sorted_indices=True,
        canonical=True,
    )
# Convert the SciPy matrix built above and print the resulting CSRArray.
A = scipy_to_mlx(sp)
print("mlx-sparse:", A)
SciPy CSR: (256, 256) nnz=1664 dtype=float32
mlx-sparse: CSRArray(shape=(256, 256), nnz=1664, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
mlx-sparse -> SciPy#
def mlx_to_scipy(csr: ms.CSRArray):
    """Build a scipy.sparse.csr_array from an mlx_sparse.CSRArray.

    The three CSR buffers are evaluated, copied into NumPy arrays, and the
    two index buffers are cast to int32 before SciPy assembles the matrix
    with the same shape as the input.
    """
    # Evaluate the MLX buffers before they are handed to NumPy.
    mx.eval(csr.data, csr.indices, csr.indptr)

    values = np.array(csr.data)
    col_indices = np.array(csr.indices).astype(np.int32)
    row_pointers = np.array(csr.indptr).astype(np.int32)

    buffers = (values, col_indices, row_pointers)
    return scipy.sparse.csr_array(buffers, shape=csr.shape)
# Round trip: convert the mlx-sparse array back to SciPy and report its summary.
sp_back = mlx_to_scipy(A)
print(f"SciPy (from mlx): {sp_back.shape} nnz={sp_back.nnz} dtype={sp_back.dtype}")
# Largest absolute element-wise difference between the original matrix and the
# round-tripped one, compared densely; 0 means no value changed.
diff = np.max(np.abs(sp.toarray() - sp_back.toarray()))
print(f"Round-trip max diff: {diff:.2e}")
SciPy (from mlx): (256, 256) nnz=1664 dtype=float32
Round-trip max diff: 0.00e+00
Cross-validation: mlx-sparse vs SciPy SpMV and SpMM#
# SpMV: sparse matrix times a dense float32 vector, computed by both libraries.
x_np = rng.standard_normal(256).astype(np.float32)
x = mx.array(x_np)
y_mlx = A @ x
# Evaluate the MLX result before it is copied into NumPy for the comparison.
mx.eval(y_mlx)
y_sp = sp @ x_np
# Largest absolute element-wise deviation between the two results.
err_v = np.max(np.abs(np.array(y_mlx) - y_sp))
print(f"SpMV max error mlx vs SciPy: {err_v:.2e}")
# SpMM: same check with a dense 256x64 float32 right-hand side.
B_np = rng.standard_normal((256, 64)).astype(np.float32)
B = mx.array(B_np)
Y_mlx = A @ B
mx.eval(Y_mlx)
Y_sp = sp @ B_np
err_m = np.max(np.abs(np.array(Y_mlx) - Y_sp))
print(f"\nSpMM (256x64) max error mlx vs SciPy: {err_m:.2e}")
SpMV max error mlx vs SciPy: 9.54e-07
SpMM (256x64) max error mlx vs SciPy: 1.91e-06
Using SciPy solvers on mlx-sparse matrices#
mlx-sparse handles the forward pass on the GPU. When you need a linear
solver (cg, spsolve, etc.), convert back to SciPy for the solve, then
convert the solution back to MLX.
# Build a symmetric positive definite matrix AᵀA + δI (δ = 0.5 below)
n = 64
sp_small = scipy.sparse.random(n, n, density=0.1, format="csr",
dtype=np.float32, random_state=10)
ATA = sp_small.T @ sp_small + scipy.sparse.eye(n, dtype=np.float32) * 0.5
ATA_csr = ATA.tocsr()
# Right-hand side drawn from the shared generator `rng`.
b_np = rng.standard_normal(n).astype(np.float32)
# Solve on CPU via SciPy CG
x_sol, info = scipy.sparse.linalg.cg(ATA_csr, b_np)
# Norm of the residual A x - b as a check on the returned solution.
residual = np.linalg.norm(ATA_csr @ x_sol - b_np)
# The code treats info == 0 as "converged".
print(f"CG converged: {info == 0} residual: {residual:.2e}")
# Solution available as MLX array for downstream GPU operations
x_mlx = mx.array(x_sol.astype(np.float32))
print(f"solution as mx.array shape: {x_mlx.shape}")
# Now apply on GPU
ATA_mlx = scipy_to_mlx(ATA_csr.astype(np.float32))
residual_mlx = ATA_mlx @ x_mlx - mx.array(b_np)
mx.eval(residual_mlx)
# residual_mlx is evaluated but never printed or compared here.
# (results would match scipy's)
CG converged: True residual: 2.93e-08
solution as mx.array shape: (64,)
COO interop#
SciPy COO uses (data, (row, col)), the same layout as ms.coo_array.
# Small random 8x8 float32 matrix in SciPy's COO format.
sp_coo = scipy.sparse.random(8, 8, density=0.3, format="coo",
dtype=np.float32, random_state=2)
# Pass the value buffer and the (row, col) index pair to ms.coo_array,
# casting both index arrays to int32.
coo_mlx = ms.coo_array(
(
mx.array(sp_coo.data),
(mx.array(sp_coo.row.astype(np.int32)),
mx.array(sp_coo.col.astype(np.int32))),
),
shape=sp_coo.shape,
)
print("mlx COO:", coo_mlx)
# Convert to CSR and compare to SciPy's dense view
csr_mlx = coo_mlx.tocsr(canonical=True)
ref = sp_coo.toarray()
# Largest absolute element-wise difference between the two dense matrices.
diff = np.max(np.abs(np.array(csr_mlx.todense()) - ref))
print(f"Round-trip COO->CSR max diff: {diff:.2e}")
mlx COO: COOArray(shape=(8, 8), nnz=20, dtype=float32, index_dtype=int32, has_canonical_format=False)
Round-trip COO->CSR max diff: 0.00e+00