Sparse linalg: direct factorizations

mlx_sparse.linalg provides two sparse direct factorizations:

  • Sparse Cholesky: for symmetric positive-definite (SPD) matrices. Computes A = L @ L.T where L is a sparse lower-triangular CSR factor.

  • Sparse LU: for any non-singular square matrix. Computes P @ A = L @ U where P is a row permutation, L is unit lower-triangular, and U is upper-triangular.
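For reference, these are the textbook entry-by-entry recurrences that define the two factorizations. They are shown without any fill-reducing reordering, and they are not a description of how mlx_sparse schedules its computation. Here $\tilde{a}_{ij}$ denotes the entries of $PA$:

```latex
% Cholesky, A = L L^T, for i > j
\ell_{jj} = \sqrt{a_{jj} - \sum_{k<j} \ell_{jk}^{2}}, \qquad
\ell_{ij} = \frac{1}{\ell_{jj}} \Big( a_{ij} - \sum_{k<j} \ell_{ik}\,\ell_{jk} \Big)

% LU, PA = LU with unit diagonal \ell_{ii} = 1
u_{ij} = \tilde{a}_{ij} - \sum_{k<i} \ell_{ik}\,u_{kj} \quad (j \ge i), \qquad
\ell_{ij} = \frac{1}{u_{jj}} \Big( \tilde{a}_{ij} - \sum_{k<j} \ell_{ik}\,u_{kj} \Big) \quad (i > j)
```

For the 4×4 test matrix used below, the Cholesky recurrence gives $\ell_{11} = \sqrt{4} = 2$, $\ell_{21} = -1/2$, $\ell_{22} = \sqrt{4 - 1/4} = \sqrt{3.75} \approx 1.9365$. The LU recurrence gives $u_{11} = 4$, $\ell_{21} = -1/4$, $u_{22} = 4 - (-1/4)(-1) = 3.75$. These match the factors printed in the sections below.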

Both factorizations store their factors as CSRArray objects and solve with sparse triangular-solve kernels: forward substitution with the lower factor, then back-substitution with the upper factor. No dense matrices are ever allocated.
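To make "sparse triangular solve" concrete, here is a minimal NumPy sketch of row-oriented forward substitution that reads only the three CSR arrays. It illustrates the technique and is not mlx_sparse's actual kernel. The helper name `csr_lower_solve` is hypothetical. It also assumes sorted column indices and an explicitly stored diagonal.

```python
import numpy as np

def csr_lower_solve(data, indices, indptr, b):
    """Solve L y = b for a lower-triangular L stored as CSR arrays.

    Assumes column indices are sorted within each row and every diagonal
    entry is stored, so the diagonal is the last entry of its row.
    """
    n = len(indptr) - 1
    y = np.array(b, dtype=np.float64)
    for i in range(n):
        start, end = indptr[i], indptr[i + 1]
        acc = y[i]
        # Off-diagonal entries (columns j < i) use already-solved y[j].
        for k in range(start, end - 1):
            acc -= data[k] * y[indices[k]]
        y[i] = acc / data[end - 1]  # divide by the diagonal L[i, i]
    return y

# Example: L = [[2, 0, 0], [1, 3, 0], [0, 1, 4]] in CSR form.
data = np.array([2.0, 1.0, 3.0, 1.0, 4.0])
indices = np.array([0, 0, 1, 1, 2])
indptr = np.array([0, 1, 3, 5])
y = csr_lower_solve(data, indices, indptr, [2.0, 5.0, 9.0])
print(y)  # approximately [1.0, 1.3333, 1.9167]
```

Each stored entry of L is read exactly once, so the solve costs O(nnz(L)) rather than the O(n²) of a dense triangular solve.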

spsolve(A, b) is a convenience wrapper that runs LU and solves in one call.

import mlx.core as mx
import numpy as np
import mlx_sparse as ms
from mlx_sparse import linalg

Test matrix

We use a small but non-trivial 4×4 SPD matrix so the factor structure is easy to inspect by hand.

import scipy.sparse

# 4x4 SPD matrix (symmetric, strictly diagonally dominant)
#   A = [[ 4, -1,  0,  0],
#        [-1,  4, -1,  0],
#        [ 0, -1,  4, -1],
#        [ 0,  0, -1,  4]]
A_sp = scipy.sparse.diags(
    [[-1.0, -1.0, -1.0], [4.0, 4.0, 4.0, 4.0], [-1.0, -1.0, -1.0]],
    [-1, 0, 1], shape=(4, 4), dtype=np.float32,
).tocsr()

# Build an mlx_sparse CSR array from SciPy's (data, indices, indptr) arrays
A = ms.csr_array(
    (mx.array(A_sp.data), mx.array(A_sp.indices), mx.array(A_sp.indptr)),
    shape=(4, 4), canonical=True,
)
b_np = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
b = mx.array(b_np)

print("A (dense):\n", np.array(A.todense()))
print(f"nnz={A.nnz}")
A (dense):
 [[ 4. -1.  0.  0.]
 [-1.  4. -1.  0.]
 [ 0. -1.  4. -1.]
 [ 0.  0. -1.  4.]]
nnz=10
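Diagonal dominance implies SPD here by the Gershgorin circle theorem. A symmetric matrix with a positive diagonal that is strictly diagonally dominant has all of its eigenvalues positive. The quick dense NumPy check below is independent of mlx_sparse and rebuilds the same matrix.

```python
import numpy as np

n = 4
# Same tridiagonal matrix: 4 on the diagonal, -1 on both off-diagonals.
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)

is_symmetric = np.array_equal(A, A.T)
# Strict diagonal dominance: |a_ii| > sum_{j != i} |a_ij| for every row.
off_diag = np.abs(A).sum(axis=1) - np.abs(np.diag(A))
strictly_dominant = np.all(np.abs(np.diag(A)) > off_diag)
positive_diag = np.all(np.diag(A) > 0)

# Symmetric + positive diagonal + strict dominance => SPD (Gershgorin).
eigvals = np.linalg.eigvalsh(A)
print(is_symmetric, strictly_dominant, positive_diag, eigvals.min() > 0)
# True True True True; smallest eigenvalue is 4 - 2*cos(pi/5), about 2.38
```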

Sparse Cholesky

sparse_cholesky(A) returns a SparseCholesky object whose .L attribute is the sparse lower-triangular factor. The object is callable: chol(b) is equivalent to chol.solve(b).

chol = linalg.sparse_cholesky(A)
print("Cholesky factor L (dense view):")
print(np.array(chol.L.todense()))
print(f"L.nnz={chol.L.nnz}  (vs A.nnz={A.nnz})")
Cholesky factor L (dense view):
[[ 2.          0.          0.          0.        ]
 [-0.5         1.9364917   0.          0.        ]
 [ 0.         -0.5163978   1.9321836   0.        ]
 [ 0.          0.         -0.51754916  1.9318755 ]]
L.nnz=7  (vs A.nnz=10)
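Why 7? For a tridiagonal SPD matrix, Cholesky produces no fill-in. L has exactly the sparsity pattern of the lower triangle of A: n diagonal entries plus n − 1 subdiagonal entries, so 4 + 3 = 7. The dense NumPy cross-check below does not use mlx_sparse, and it treats exact zeros as structural zeros.

```python
import numpy as np

n = 4
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)

L = np.linalg.cholesky(A)                   # dense lower-triangular factor
nnz_lower_A = np.count_nonzero(np.tril(A))  # pattern of tril(A)
nnz_L = np.count_nonzero(L)
print(nnz_lower_A, nnz_L)  # 7 7: no fill-in for a tridiagonal matrix
```

For general sparsity patterns, L usually gains nonzeros where tril(A) has zeros (fill-in), and how many depends on the ordering of rows and columns.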
# Verify: L @ L.T must equal A
L_dense = np.array(chol.L.todense())
A_dense = np.array(A.todense())
reconstruction_error = np.max(np.abs(L_dense @ L_dense.T - A_dense))
print(f"max |L @ L.T - A| = {reconstruction_error:.2e}")
max |L @ L.T - A| = 0.00e+00
# Solve A x = b via Cholesky: forward-solve L y = b, then back-solve L.T x = y
x_chol = chol.solve(b)
mx.eval(x_chol)

x_ref = np.linalg.solve(A_dense, b_np)
rel_error = np.linalg.norm(np.array(x_chol) - x_ref) / np.linalg.norm(x_ref)
print(f"x = {np.array(x_chol)}")
print(f"relative error vs numpy: {rel_error:.2e}")
x = [0.48803824 0.952153   1.320574   1.3301437 ]
relative error vs numpy: 8.40e-08
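The second triangular solve, L.T x = y, does not need an explicit transpose. With L stored in CSR, one can walk its rows from last to first and subtract each newly solved x[i] from the remaining right-hand side. The NumPy sketch below is hypothetical: neither `csr_lower_transpose_solve` nor `dense_to_csr` is an mlx_sparse function. It assumes sorted indices and a stored diagonal, and it does the forward step densely for brevity.

```python
import numpy as np

def csr_lower_transpose_solve(data, indices, indptr, y):
    """Solve L.T x = y, where lower-triangular L is stored as CSR arrays.

    Row i of L is column i of L.T: once x[i] is known, each stored
    L[i, j] (j < i) contributes L[i, j] * x[i] to equation j, which is
    subtracted from the pending right-hand side.
    """
    n = len(indptr) - 1
    r = np.array(y, dtype=np.float64)  # pending right-hand side
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        start, end = indptr[i], indptr[i + 1]
        x[i] = r[i] / data[end - 1]  # diagonal is last in a sorted row
        for k in range(start, end - 1):
            r[indices[k]] -= data[k] * x[i]
    return x

def dense_to_csr(M):
    """Minimal dense -> CSR conversion (row-major, sorted columns)."""
    rows, cols = np.nonzero(M)
    indptr = np.searchsorted(rows, np.arange(M.shape[0] + 1))
    return M[rows, cols], cols, indptr

n = 4
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
b = np.array([1.0, 2.0, 3.0, 4.0])
L = np.linalg.cholesky(A)
y = np.linalg.solve(L, b)  # forward step, done densely here for brevity
x = csr_lower_transpose_solve(*dense_to_csr(L), y)
print(np.allclose(x, np.linalg.solve(A, b)))  # True
```

Reading L row by row for L.T is the usual trick that lets a single CSR factor serve both halves of a Cholesky solve.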

Sparse LU

sparse_lu(A) returns a SparseLU object with attributes perm, L, and U satisfying P @ A = L @ U, where P is the row permutation encoded by the index array perm. L has a unit diagonal (stored implicitly), and the diagonal of U holds the pivots.
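The dense Doolittle-style elimination with partial pivoting below makes the P @ A = L @ U convention concrete. It is a textbook sketch, not the sparse algorithm in mlx_sparse, which may differ in ordering and pivot selection. It assumes perm follows the same convention as the verification cell in this section, where row i of P @ A is row perm[i] of A. The name `lu_partial_pivot` is hypothetical.

```python
import numpy as np

def lu_partial_pivot(A):
    """Return perm, L, U with A[perm] == L @ U (L unit lower, U upper)."""
    U = np.array(A, dtype=np.float64)
    n = U.shape[0]
    L = np.eye(n)
    perm = np.arange(n)
    for k in range(n - 1):
        # Pick the row with the largest |entry| in column k as the pivot.
        p = k + np.argmax(np.abs(U[k:, k]))
        if U[p, k] == 0.0:
            raise np.linalg.LinAlgError("matrix is singular")
        if p != k:
            # Swap rows in U, in perm, and in the already-computed part of L.
            U[[k, p], :] = U[[p, k], :]
            perm[[k, p]] = perm[[p, k]]
            L[[k, p], :k] = L[[p, k], :k]
        # Eliminate below the pivot and store the multipliers in L.
        L[k + 1:, k] = U[k + 1:, k] / U[k, k]
        U[k + 1:, :] -= np.outer(L[k + 1:, k], U[k, :])
        U[k + 1:, k] = 0.0  # exact zeros below the diagonal
    if U[n - 1, n - 1] == 0.0:
        raise np.linalg.LinAlgError("matrix is singular")
    return perm, L, np.triu(U)

A = 4.0 * np.eye(4) - np.eye(4, k=1) - np.eye(4, k=-1)
perm, L, U = lu_partial_pivot(A)
print(perm)  # [0 1 2 3]: diagonal dominance means no row swaps
print(np.diag(U))  # approximately 4, 3.75, 3.7333, 3.7321
```

On a matrix with a zero leading entry, such as [[0, 2], [1, 1]], the same function swaps the two rows and returns perm = [1, 0].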

LU works on any non-singular square matrix, not just SPD ones.
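The small dense NumPy illustration below is independent of mlx_sparse and uses a non-symmetric matrix whose leading entry is zero. No Cholesky factor exists, but LU with row pivoting still solves the system. NumPy's `cholesky` reads only the lower triangle and raises `LinAlgError` here. `np.linalg.solve` uses an LU factorization with partial pivoting (LAPACK `gesv`).

```python
import numpy as np

# Non-symmetric, non-singular (det = -5), with a zero in the (0, 0) position.
A = np.array([[0.0, 2.0, 1.0],
              [1.0, 1.0, 0.0],
              [3.0, 0.0, 1.0]])
b = np.array([1.0, 2.0, 3.0])

try:
    np.linalg.cholesky(A)
    cholesky_failed = False
except np.linalg.LinAlgError:
    cholesky_failed = True  # no L with A = L @ L.T exists

x = np.linalg.solve(A, b)  # LU with partial pivoting handles it
print(cholesky_failed, np.allclose(A @ x, b))  # True True
```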

lu = linalg.sparse_lu(A)
mx.eval(lu.perm, lu.L.data, lu.U.data)

print("Row permutation perm:", np.array(lu.perm))
print("\nL factor (unit lower triangular):")
print(np.round(np.array(lu.L.todense()), 4))
print("\nU factor (upper triangular):")
print(np.round(np.array(lu.U.todense()), 4))
Row permutation perm: [0 1 2 3]

L factor (unit lower triangular):
[[ 1.      0.      0.      0.    ]
 [-0.25    1.      0.      0.    ]
 [ 0.     -0.2667  1.      0.    ]
 [ 0.      0.     -0.2679  1.    ]]

U factor (upper triangular):
[[ 4.     -1.      0.      0.    ]
 [ 0.      3.75   -1.      0.    ]
 [ 0.      0.      3.7333 -1.    ]
 [ 0.      0.      0.      3.7321]]
# Verify: L @ U must equal P @ A
L_lu = np.array(lu.L.todense())
U_lu = np.array(lu.U.todense())
perm = np.array(lu.perm)
PA = A_dense[perm]  # permuted rows
lu_error = np.max(np.abs(L_lu @ U_lu - PA))
print(f"max |L @ U - P @ A| = {lu_error:.2e}")
max |L @ U - P @ A| = 0.00e+00
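The verification cell indexes A_dense[perm] instead of forming P. The NumPy sketch below shows that this is the same as multiplying by P when P is the identity with its rows reordered by perm. It uses a made-up permutation rather than output from sparse_lu. It also shows that P.T undoes the reordering and that np.argsort(perm) gives the inverse permutation as an index array.

```python
import numpy as np

A = np.arange(16.0).reshape(4, 4)
perm = np.array([2, 0, 3, 1])  # illustrative permutation, not from sparse_lu

P = np.eye(4)[perm]  # row i of P is the unit vector e_{perm[i]}
assert np.array_equal(P @ A, A[perm])  # P @ A picks rows in perm order

inv_perm = np.argsort(perm)  # inverse permutation: [1 3 0 2]
assert np.array_equal(P.T @ (P @ A), A)  # P is orthogonal, so P.T undoes it
assert np.array_equal(A[perm][inv_perm], A)
```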
# Solve A x = b via LU
x_lu = lu.solve(b)
mx.eval(x_lu)

rel_error_lu = np.linalg.norm(np.array(x_lu) - x_ref) / np.linalg.norm(x_ref)
print(f"x = {np.array(x_lu)}")
print(f"relative error vs numpy: {rel_error_lu:.2e}")
x = [0.4880383  0.95215315 1.3205743  1.3301436 ]
relative error vs numpy: 6.33e-08
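Given P @ A = L @ U, solving A x = b reduces to permuting b and running two triangular solves. The first is L y = P b, which needs no divisions because L has a unit diagonal, and the second is U x = y. The sketch below spells this out with dense NumPy loops on hand-picked factors that are exact in binary floating point. `lu_solve_dense` is a hypothetical helper, and whether lu.solve is organized exactly this way is an assumption.

```python
import numpy as np

def lu_solve_dense(perm, L, U, b):
    """Solve A x = b given A[perm] == L @ U (L unit lower, U upper)."""
    n = len(b)
    y = np.array(b, dtype=np.float64)[perm]  # y starts as P @ b
    for i in range(n):  # forward: L y = P b, with L[i, i] == 1
        y[i] -= L[i, :i] @ y[:i]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):  # backward: U x = y
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

# Hand-picked factors (exactly representable in binary floating point).
L = np.array([[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.0, 0.25, 1.0]])
U = np.array([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0], [0.0, 0.0, 4.0]])
perm = np.array([2, 0, 1])
A = np.empty((3, 3))
A[perm] = L @ U  # build A so that A[perm] == L @ U
b = np.array([1.0, 2.0, 3.0])
x = lu_solve_dense(perm, L, U, b)
print(np.allclose(A @ x, b))  # True
```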

spsolve: one-shot convenience wrapper

When you only need the solution and not the factors, spsolve(A, b) runs sparse LU and solves in a single call.

x_direct = linalg.spsolve(A, b)
mx.eval(x_direct)

rel_error_direct = np.linalg.norm(np.array(x_direct) - x_ref) / np.linalg.norm(x_ref)
print(f"spsolve x = {np.array(x_direct)}")
print(f"relative error: {rel_error_direct:.2e}")
spsolve x = [0.4880383  0.95215315 1.3205743  1.3301436 ]
relative error: 6.33e-08

When to use Cholesky vs LU vs iterative solvers

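| Method | Matrix type | Cost | Re-use factors? |
|---|---|---|---|
| sparse_cholesky | SPD only | Cheapest direct method: stores only L; fill-in depends on the sparsity pattern (none for a tridiagonal matrix) | Yes, call chol.solve(b) multiple times |
| sparse_lu | Any non-singular | Stores both L and U; row pivoting can add fill-in | Yes, call lu.solve(b) multiple times |
| spsolve | Any non-singular | Same as LU, paid on every call | No, factorizes on each call |
| cg / gmres / minres | See the notebook on Solvers | O(nnz × iters) | No factors to reuse |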
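The O(nnz × iters) cost of the iterative solvers comes from each iteration being dominated by one sparse matrix-vector product. The sketch below is the textbook conjugate gradient method in plain NumPy with a CSR matvec, not the package's cg. It solves the same tridiagonal system. For an n×n SPD matrix, CG converges in at most n iterations in exact arithmetic. The helper names are hypothetical.

```python
import numpy as np

def csr_matvec(data, indices, indptr, x):
    """y = A @ x for A in CSR form; touches each stored entry once (O(nnz))."""
    y = np.zeros(len(indptr) - 1)
    for i in range(len(y)):
        s, e = indptr[i], indptr[i + 1]
        y[i] = data[s:e] @ x[indices[s:e]]
    return y

def conjugate_gradient(matvec, b, tol=1e-10, max_iters=100):
    """Textbook CG for SPD systems; returns (x, iterations used)."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for it in range(1, max_iters + 1):
        Ap = matvec(p)  # the only O(nnz) step per iteration
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * np.linalg.norm(b):
            return x, it
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, max_iters

# CSR arrays of the 4x4 test matrix (4 on the diagonal, -1 off-diagonal).
data = np.array([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
indices = np.array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3])
indptr = np.array([0, 2, 5, 8, 10])
b = np.array([1.0, 2.0, 3.0, 4.0])

x, iters = conjugate_gradient(lambda v: csr_matvec(data, indices, indptr, v), b)
print(iters, x)  # x should match the direct solves above
```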

Direct factorizations pay their cost up front but amortize well when solving many right-hand sides. For very large sparse systems where fill-in is prohibitive, iterative solvers are preferred.
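The example below shows the amortization pattern with SciPy's SuperLU wrapper: factor once, then call solve for each new right-hand side. SciPy is used here only as a familiar stand-in, not as mlx_sparse. SuperLU also applies a column permutation, so its factors differ from those of sparse_lu.

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

n = 4
A = scipy.sparse.diags(
    [[-1.0] * (n - 1), [4.0] * n, [-1.0] * (n - 1)], [-1, 0, 1], format="csc"
)

lu = scipy.sparse.linalg.splu(A)  # factorization cost paid once
rhs = np.eye(n)  # n right-hand sides, one per column
X = np.column_stack([lu.solve(rhs[:, j]) for j in range(n)])  # cheap solves

# Solving with every unit vector yields the inverse, so A @ X should be I.
print(np.allclose(A @ X, np.eye(n)))  # True
```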