Sparse linalg: iterative solvers#
mlx_sparse.linalg provides native sparse iterative solvers that operate
directly on CSR arrays without ever materializing a dense matrix. All three
solvers, CG, GMRES, and MINRES, dispatch through native C++/Metal
kernels and participate in MLX’s lazy computation graph.
| Solver | Matrix requirement | Typical use |
|---|---|---|
| CG | Symmetric positive-definite (SPD) | Poisson/FEM systems |
| GMRES | Any non-singular square matrix | Non-symmetric PDEs |
| MINRES | Symmetric (possibly indefinite) | Saddle-point problems |
Each solver returns (x, info) where info == 0 means the solver
converged to the requested tolerance within maxiter iterations.
import mlx.core as mx
import numpy as np
import mlx_sparse as ms
from mlx_sparse import linalg
Building a sparse SPD matrix#
Conjugate gradients requires a symmetric positive-definite matrix. A
reliable way to construct one is to take any sparse matrix B and form
A = B @ B.T + shift * I with shift > 0. Here we instead use a matrix that
is already SPD: the 5-point 2D Laplacian stencil on a small 4x4 grid.
import scipy.sparse

# 5-point Laplacian on a 4x4 grid gives a 16x16 SPD matrix
n = 4
N = n * n
rows_list, cols_list, vals_list = [], [], []
for i in range(n):
    for j in range(n):
        k = i * n + j  # row-major index of grid point (i, j)
        # Diagonal entry of the stencil
        rows_list.append(k); cols_list.append(k); vals_list.append(4.0)
        # Off-diagonal entries, only for neighbours that lie inside the grid
        for ni, nj in [(i-1, j), (i+1, j), (i, j-1), (i, j+1)]:
            if 0 <= ni < n and 0 <= nj < n:
                rows_list.append(k); cols_list.append(ni*n+nj); vals_list.append(-1.0)
# Assemble in COO form with SciPy, then convert to CSR
L_scipy = scipy.sparse.coo_matrix(
    (np.array(vals_list, np.float32),
     (np.array(rows_list, np.int32), np.array(cols_list, np.int32))),
    shape=(N, N),
).tocsr()
# Wrap the CSR triple (data, indices, indptr) as an mlx_sparse array
A = ms.csr_array(
    (mx.array(L_scipy.data), mx.array(L_scipy.indices), mx.array(L_scipy.indptr)),
    shape=(N, N), canonical=True,
)
b_np = np.ones(N, dtype=np.float32)
b = mx.array(b_np)
print(f"shape={A.shape}, nnz={A.nnz}, density={A.nnz/N**2:.4f}")
shape=(16, 16), nnz=64, density=0.2500
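The B @ B.T + shift * I recipe mentioned above is not exercised in this tutorial. The following is a minimal pure-Python sketch of why it yields an SPD matrix, using small dense lists rather than mlx_sparse arrays. The helper names gram_plus_shift and is_positive_definite, the 6x6 size, and the shift of 0.5 are illustrative assumptions, not part of any library.

```python
import math
import random


def gram_plus_shift(B, shift):
    """Return A = B @ B.T + shift * I for a dense list-of-lists matrix B."""
    m = len(B)
    A = [[sum(bi * bj for bi, bj in zip(B[i], B[j])) for j in range(m)]
         for i in range(m)]
    for i in range(m):
        A[i][i] += shift
    return A


def is_positive_definite(A):
    """Try a Cholesky factorization of a symmetric matrix.

    For symmetric input the factorization succeeds only if A is
    positive definite, so a non-positive pivot means "not SPD".
    """
    m = len(A)
    chol = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(i + 1):
            s = A[i][j] - sum(chol[i][k] * chol[j][k] for k in range(j))
            if i == j:
                if s <= 0.0:
                    return False
                chol[i][j] = math.sqrt(s)
            else:
                chol[i][j] = s / chol[j][j]
    return True


rng = random.Random(0)
# Hypothetical 6x6 matrix in which roughly 30% of the entries are non-zero
B_demo = [[rng.uniform(-1.0, 1.0) if rng.random() < 0.3 else 0.0
           for _ in range(6)] for _ in range(6)]
A_demo = gram_plus_shift(B_demo, shift=0.5)

symmetric = all(A_demo[i][j] == A_demo[j][i] for i in range(6) for j in range(6))
print(f"symmetric={symmetric}, positive_definite={is_positive_definite(A_demo)}")
```

The reason the recipe works for any B is that x.T @ A @ x = ||B.T @ x||^2 + shift * ||x||^2, which is strictly positive for every non-zero x whenever shift > 0.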
Conjugate Gradients#
In exact arithmetic, CG converges in at most n steps for an n×n SPD system
(here n is the matrix dimension, 16 in this example, not the grid size).
In practice it often needs far fewer iterations when the matrix is
well-conditioned. rtol sets the relative part of the residual stopping criterion
||r_k|| ≤ max(atol, rtol * ||b||).
x_cg, info_cg = linalg.cg(A, b, rtol=1e-6, maxiter=200)
mx.eval(x_cg)
print(f"CG info={info_cg} (0=converged)")
residual = np.linalg.norm(np.array(A @ x_cg) - b_np)
rel_residual = residual / np.linalg.norm(b_np)
print(f"||Ax - b|| = {residual:.2e}, relative = {rel_residual:.2e}")
CG info=0 (0=converged)
||Ax - b|| = 4.62e-07, relative = 1.15e-07
# Compare CG solution with numpy dense reference
x_ref = np.linalg.solve(L_scipy.toarray(), b_np)
rel_error = np.linalg.norm(np.array(x_cg) - x_ref) / np.linalg.norm(x_ref)
print(f"Relative error vs numpy: {rel_error:.2e}")
print(f"First 6 entries: {np.array(x_cg)[:6]}")
Relative error vs numpy: 8.63e-08
First 6 entries: [0.8333334 1.1666667 1.1666667 0.8333334 1.1666667 1.6666667]
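To make the iteration and the stopping criterion concrete, here is a textbook conjugate-gradient loop in pure Python that works on a CSR triple (data, indices, indptr). It is a reference sketch only: it is not how mlx_sparse implements linalg.cg, and the function names laplacian_csr, csr_matvec, and cg_reference are hypothetical. It mimics the (x, info) return convention described in this tutorial, including the maxiter default of 10 * n.

```python
import math


def laplacian_csr(n):
    """CSR triple (data, indices, indptr) of the 5-point Laplacian on an n x n grid."""
    data, indices, indptr = [], [], [0]
    for i in range(n):
        for j in range(n):
            row = {i * n + j: 4.0}
            for ni, nj in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
                if 0 <= ni < n and 0 <= nj < n:
                    row[ni * n + nj] = -1.0
            for col in sorted(row):  # keep column indices sorted within the row
                indices.append(col)
                data.append(row[col])
            indptr.append(len(indices))
    return data, indices, indptr


def csr_matvec(data, indices, indptr, x):
    """Multiply a CSR matrix by a dense vector without forming a dense matrix."""
    return [
        sum(data[p] * x[indices[p]] for p in range(indptr[r], indptr[r + 1]))
        for r in range(len(indptr) - 1)
    ]


def dot(u, v):
    return sum(a * c for a, c in zip(u, v))


def cg_reference(data, indices, indptr, b, rtol=1e-5, atol=0.0, maxiter=None):
    """Return (x, info); info == 0 on convergence, else the iteration count."""
    size = len(b)
    maxiter = 10 * size if maxiter is None else maxiter
    threshold = max(atol, rtol * math.sqrt(dot(b, b)))
    x = [0.0] * size
    r = list(b)  # residual b - A x for the zero initial guess
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    for _ in range(maxiter):
        if math.sqrt(rs_old) <= threshold:
            return x, 0
        Ap = csr_matvec(data, indices, indptr, p)
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    # Iteration budget used up: succeed only if the final residual qualifies
    return x, (0 if math.sqrt(rs_old) <= threshold else maxiter)


data, indices, indptr = laplacian_csr(4)
x_py, info_py = cg_reference(data, indices, indptr, [1.0] * 16, rtol=1e-6)
print(info_py, [round(v, 4) for v in x_py[:6]])
```

Because this sketch runs in double precision, its residuals will not match the float32 figures printed above digit for digit, but the solution entries should agree to several decimals.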
Solver parameters: rtol, atol, maxiter#
- rtol (default 1e-5): stop when ||r|| ≤ rtol * ||b||.
- atol (default 0): absolute floor, useful when b is near zero.
- maxiter (default 10 * n): maximum number of iterations.

When maxiter is exhausted without convergence, info > 0 reports the iteration count at termination.
# Force early termination by setting maxiter=2 on a 16x16 system
_, info_early = linalg.cg(A, b, rtol=1e-12, maxiter=2)
print(f"info with maxiter=2: {info_early} (non-zero = not converged)")
# Loose tolerance converges in fewer iterations
_, info_loose = linalg.cg(A, b, rtol=1e-2)
print(f"info with rtol=1e-2: {info_loose}")
info with maxiter=2: 2 (non-zero = not converged)
info with rtol=1e-2: 0
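The interplay between rtol and atol is easiest to see by computing the threshold directly. The sketch below assumes the criterion is exactly max(atol, rtol * ||b||) as stated in the CG section; stopping_threshold is a hypothetical helper written for illustration, not an mlx_sparse function.

```python
import math


def stopping_threshold(b, rtol=1e-5, atol=0.0):
    """Residual norm at or below which max(atol, rtol * ||b||) is satisfied."""
    b_norm = math.sqrt(sum(v * v for v in b))
    return max(atol, rtol * b_norm)


ones = [1.0] * 16    # the right-hand side used in this tutorial, ||b|| = 4
tiny = [1e-12] * 16  # a right-hand side that is almost zero

print(stopping_threshold(ones, rtol=1e-6))             # relative term sets the threshold
print(stopping_threshold(tiny, rtol=1e-6))             # threshold shrinks along with ||b||
print(stopping_threshold(tiny, rtol=1e-6, atol=1e-9))  # atol acts as a floor
```

With a near-zero b and atol left at 0, the threshold can fall below what floating-point arithmetic can reach, which is the situation a non-zero atol guards against.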
GMRES for non-symmetric systems#
GMRES uses restarted Arnoldi and works for any non-singular square matrix,
including non-symmetric ones. The restart parameter controls the Krylov
subspace dimension before each restart (default: min(20, n)).
# Build a non-symmetric sparse matrix by adding half of the Laplacian's
# strict upper triangle to the Laplacian itself
import scipy.sparse
B_scipy = L_scipy + scipy.sparse.triu(L_scipy, k=1) * 0.5 # asymmetric perturbation
B = ms.csr_array(
(mx.array(B_scipy.data.astype(np.float32)),
mx.array(B_scipy.indices.astype(np.int32)),
mx.array(B_scipy.indptr.astype(np.int32))),
shape=(N, N), canonical=True,
)
x_gmres, info_gmres = linalg.gmres(B, b, restart=16, rtol=1e-6, maxiter=200)
mx.eval(x_gmres)
print(f"GMRES info={info_gmres}")
res_gmres = np.linalg.norm(np.array(B @ x_gmres) - b_np) / np.linalg.norm(b_np)
print(f"Relative residual: {res_gmres:.2e}")
GMRES info=200
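Relative residual: 5.51e-06
Here info equals maxiter, so GMRES stopped at the iteration limit rather than converging: the final relative residual, 5.51e-06, is still above the requested rtol=1e-6.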
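The role of restart is easier to see from the Arnoldi process that builds the Krylov basis. The sketch below is a plain-Python illustration with modified Gram-Schmidt, not the kernel used by linalg.gmres; the function arnoldi and the 3x3 matrix M_demo are made up for the example. M_demo follows the same pattern as B above, with off-diagonals of -1.5 above and -1.0 below the diagonal.

```python
import math


def arnoldi(matvec, v0, m):
    """Run up to m Arnoldi steps.

    Returns an orthonormal basis V of the Krylov subspace and the
    (m + 1) x m upper Hessenberg matrix H of projection coefficients.
    """
    def norm(v):
        return math.sqrt(sum(t * t for t in v))

    beta = norm(v0)
    V = [[t / beta for t in v0]]
    H = [[0.0] * m for _ in range(m + 1)]
    for j in range(m):
        w = matvec(V[j])
        # Modified Gram-Schmidt against every basis vector built so far
        for i in range(j + 1):
            H[i][j] = sum(a * c for a, c in zip(V[i], w))
            w = [wk - H[i][j] * vk for wk, vk in zip(w, V[i])]
        H[j + 1][j] = norm(w)
        if H[j + 1][j] < 1e-12:
            break  # the subspace is invariant, so there is no new direction
        V.append([t / H[j + 1][j] for t in w])
    return V, H


# Small non-symmetric example matrix (illustrative values only)
M_demo = [[4.0, -1.5, 0.0],
          [-1.0, 4.0, -1.5],
          [0.0, -1.0, 4.0]]


def demo_matvec(v):
    return [sum(a * t for a, t in zip(row, v)) for row in M_demo]


V, H = arnoldi(demo_matvec, [1.0, 0.0, 0.0], m=2)
print(f"basis vectors stored: {len(V)}")
```

Each step adds one basis vector and orthogonalizes against all earlier ones, so memory and work grow with the step count. A restart value of m caps the basis at m + 1 vectors before GMRES discards it and starts again from the current residual.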
MINRES for symmetric indefinite systems#
MINRES is built on the Lanczos process and works when the matrix is symmetric but indefinite, that is, when it has both positive and negative eigenvalues. At each step it minimizes the 2-norm of the residual over the current Krylov subspace.
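# Shift the Laplacian so that it is no longer positive definite.
# Note: 8.0 bounds the largest eigenvalue (about 7.24) from above, so every
# eigenvalue of C is negative. C is therefore symmetric negative definite
# rather than strictly indefinite; MINRES only requires symmetry.
eig_approx = 8.0 # rough upper bound for the 4x4 grid Laplacian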
C_scipy = L_scipy - scipy.sparse.eye(N, dtype=np.float32) * eig_approx
C = ms.csr_array(
(mx.array(C_scipy.data.astype(np.float32)),
mx.array(C_scipy.indices.astype(np.int32)),
mx.array(C_scipy.indptr.astype(np.int32))),
shape=(N, N), canonical=True,
)
# C is non-singular, so a unique solution exists for any right-hand side
b_c = mx.array(np.random.default_rng(0).normal(size=N).astype(np.float32))
x_minres, info_minres = linalg.minres(C, b_c, rtol=1e-5, maxiter=500)
mx.eval(x_minres)
print(f"MINRES info={info_minres}")
res_mr = np.linalg.norm(np.array(C @ x_minres) - np.array(b_c)) / np.linalg.norm(np.array(b_c))
print(f"Relative residual: {res_mr:.2e}")
MINRES info=0
Relative residual: 6.80e-08
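Whether a shifted Laplacian is indefinite can be checked from the closed-form spectrum of the 5-point stencil with zero boundary values, whose eigenvalues are 4 - 2cos(pπ/(n+1)) - 2cos(qπ/(n+1)) for p, q = 1..n. The sketch below counts negative and positive eigenvalues of L - shift * I. The helper names are hypothetical and the shift of 2.0 is an illustrative value that does not appear elsewhere in this tutorial.

```python
import math


def grid_laplacian_eigenvalues(n):
    """Eigenvalues of the 5-point Laplacian on an n x n grid (zero boundary)."""
    return sorted(
        4.0
        - 2.0 * math.cos(p * math.pi / (n + 1))
        - 2.0 * math.cos(q * math.pi / (n + 1))
        for p in range(1, n + 1)
        for q in range(1, n + 1)
    )


def inertia(eigenvalues, shift):
    """Count (negative, positive) eigenvalues of L - shift * I."""
    shifted = [lam - shift for lam in eigenvalues]
    return sum(v < 0 for v in shifted), sum(v > 0 for v in shifted)


eigs = grid_laplacian_eigenvalues(4)
print(f"spectrum of L lies in [{eigs[0]:.3f}, {eigs[-1]:.3f}]")
print("shift=8.0 -> (negative, positive) =", inertia(eigs, 8.0))
print("shift=2.0 -> (negative, positive) =", inertia(eigs, 2.0))
```

For the 4x4 grid the spectrum lies in roughly [0.76, 7.24], so a shift of 8.0 leaves every eigenvalue negative, while a shift strictly inside that interval gives a mix of signs. A shift that coincides with an eigenvalue (4.0, for example) would make the matrix singular and should be avoided.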
All three solvers on the SPD system#
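All three solvers can handle the SPD Laplacian. For SPD systems, CG and MINRES are usually the more economical choice because they exploit symmetry through short recurrences, while GMRES has to store and orthogonalize a Krylov basis that grows until each restart.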
solvers = {
    "cg": lambda: linalg.cg(A, b, rtol=1e-7),
    "gmres": lambda: linalg.gmres(A, b, rtol=1e-7),
    "minres": lambda: linalg.minres(A, b, rtol=1e-7),
}
for name, solve_fn in solvers.items():
    x, info = solve_fn()
    mx.eval(x)
    # Relative residual ||Ax - b|| / ||b||, recomputed in NumPy
    rel = np.linalg.norm(np.array(A @ x) - b_np) / np.linalg.norm(b_np)
    print(f"{name:8s} converged={info==0} rel_residual={rel:.2e}")
cg converged=True rel_residual=1.15e-07
gmres converged=False rel_residual=1.51e-07
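minres converged=False rel_residual=1.93e-07
The requested rtol=1e-7 is below float32 machine epsilon (about 1.19e-07). That is the likely reason GMRES and MINRES report non-convergence here even though their residuals are already at roundoff level.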