Sparse linalg: iterative solvers

mlx_sparse.linalg provides native sparse iterative solvers that operate directly on CSR arrays without ever materializing a dense matrix. All three solvers, CG, GMRES, and MINRES, dispatch through native C++/Metal kernels and participate in MLX’s lazy computation graph.

| Solver | Matrix requirement | Typical use |
|---|---|---|
| cg | Symmetric positive-definite (SPD) | Poisson/FEM systems |
| gmres | Any non-singular square matrix | Non-symmetric PDEs |
| minres | Symmetric (possibly indefinite) | Saddle-point problems |

Each solver returns (x, info) where info == 0 means the solver converged to the requested tolerance within maxiter iterations.
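
The decision rule in the table can be sketched as a small check on a dense test matrix. This is an illustration only: `suggest_solver` is a hypothetical helper, not part of mlx_sparse, and the dense eigenvalue test is practical only for small matrices.

```python
import numpy as np

def suggest_solver(A_dense: np.ndarray, tol: float = 1e-6) -> str:
    """Suggest 'cg', 'minres', or 'gmres' for a small dense test matrix.

    Hypothetical helper for illustration; not an mlx_sparse function.
    """
    if A_dense.ndim != 2 or A_dense.shape[0] != A_dense.shape[1]:
        raise ValueError("expected a square matrix")
    if not np.allclose(A_dense, A_dense.T, atol=tol):
        return "gmres"                      # non-symmetric
    eigvals = np.linalg.eigvalsh(A_dense)   # real eigenvalues of a symmetric matrix
    return "cg" if eigvals.min() > tol else "minres"

spd = np.array([[2.0, -1.0], [-1.0, 2.0]])       # eigenvalues 1 and 3
indefinite = np.array([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues -1 and 3
nonsym = np.array([[1.0, 2.0], [0.0, 1.0]])      # not symmetric

print(suggest_solver(spd), suggest_solver(indefinite), suggest_solver(nonsym))
```

For large sparse systems the matrix class is usually known from the discretization rather than tested numerically.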

import mlx.core as mx
import numpy as np
import mlx_sparse as ms
from mlx_sparse import linalg

Building a sparse SPD matrix

Conjugate gradients requires a symmetric positive-definite matrix. A reliable way to construct one is to take any sparse matrix B and form A = B @ B.T + shift * I with shift > 0. Here we instead use the 5-point 2D Laplacian stencil on a small 4x4 grid, which is already SPD.

# 5-point Laplacian on a 4x4 grid: a 16x16 SPD matrix
n = 4
N = n * n

rows_list, cols_list, vals_list = [], [], []
for i in range(n):
    for j in range(n):
        k = i * n + j  # row-major index of grid point (i, j)
        # Diagonal entry
        rows_list.append(k)
        cols_list.append(k)
        vals_list.append(4.0)
        # Off-diagonal entries for the in-bounds neighbours
        for ni, nj in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:
            if 0 <= ni < n and 0 <= nj < n:
                rows_list.append(k)
                cols_list.append(ni * n + nj)
                vals_list.append(-1.0)

import scipy.sparse
L_scipy = scipy.sparse.coo_matrix(
    (np.array(vals_list, np.float32),
     (np.array(rows_list, np.int32), np.array(cols_list, np.int32))),
    shape=(N, N),
).tocsr()

A = ms.csr_array(
    (mx.array(L_scipy.data), mx.array(L_scipy.indices), mx.array(L_scipy.indptr)),
    shape=(N, N), canonical=True,
)
b_np = np.ones(N, dtype=np.float32)
b = mx.array(b_np)

print(f"shape={A.shape}, nnz={A.nnz}, density={A.nnz/N**2:.4f}")
shape=(16, 16), nnz=64, density=0.2500
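
The B @ B.T + shift * I recipe mentioned above can be checked on a small dense example. The sketch below uses NumPy only; the size, sparsity level, and shift value are arbitrary choices for illustration. The reasoning is that x.T @ A @ x = ||B.T @ x||^2 + shift * ||x||^2, which is strictly positive for any non-zero x when shift > 0.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
# Random matrix with roughly 30% non-zeros (dense storage, for illustration).
B = rng.normal(size=(n, n)) * (rng.random((n, n)) < 0.3)
shift = 0.5  # any shift > 0 makes A strictly positive definite

A = B @ B.T + shift * np.eye(n)

eigvals = np.linalg.eigvalsh(A)
print("symmetric:", np.allclose(A, A.T))
print("smallest eigenvalue >= shift:", eigvals.min() >= shift - 1e-10)
```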

Conjugate Gradients

CG converges in at most n steps for an n×n SPD system in exact arithmetic; in practice it converges much faster for well-conditioned matrices. rtol and atol control the stopping criterion ||r_k|| ≤ max(atol, rtol * ||b||), where r_k = b - A x_k is the residual at iteration k.

x_cg, info_cg = linalg.cg(A, b, rtol=1e-6, maxiter=200)
mx.eval(x_cg)

print(f"CG info={info_cg}  (0=converged)")
residual = np.linalg.norm(np.array(A @ x_cg) - b_np)
rel_residual = residual / np.linalg.norm(b_np)
print(f"||Ax - b|| = {residual:.2e},  relative = {rel_residual:.2e}")
CG info=0  (0=converged)
||Ax - b|| = 4.62e-07,  relative = 1.15e-07
# Compare CG solution with numpy dense reference
x_ref = np.linalg.solve(L_scipy.toarray(), b_np)
rel_error = np.linalg.norm(np.array(x_cg) - x_ref) / np.linalg.norm(x_ref)
print(f"Relative error vs numpy: {rel_error:.2e}")
print(f"First 6 entries: {np.array(x_cg)[:6]}")
Relative error vs numpy: 8.63e-08
First 6 entries: [0.8333334 1.1666667 1.1666667 0.8333334 1.1666667 1.6666667]
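
For readers who want to see the iteration behind the call, here is a textbook CG loop in plain NumPy with the same stopping rule. It is a sketch on a dense array, not the library's kernel; the name `cg_sketch` and the choice to return maxiter as info on failure are assumptions made for illustration.

```python
import numpy as np

def cg_sketch(A, b, rtol=1e-5, atol=0.0, maxiter=None):
    """Textbook conjugate gradients; returns (x, info), info == 0 on convergence."""
    n = b.shape[0]
    maxiter = 10 * n if maxiter is None else maxiter
    threshold = max(atol, rtol * np.linalg.norm(b))
    x = np.zeros_like(b)
    r = b - A @ x            # initial residual
    p = r.copy()             # initial search direction
    rs_old = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs_old) <= threshold:
            return x, 0
        Ap = A @ p
        alpha = rs_old / (p @ Ap)          # step length along p
        x = x + alpha * p
        r = r - alpha * Ap                 # recurrence for the residual
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p      # next A-conjugate direction
        rs_old = rs_new
    info = 0 if np.sqrt(rs_old) <= threshold else maxiter
    return x, info

# 1D Laplacian (tridiagonal, SPD) as a small float64 test problem.
n = 10
A = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
b = np.ones(n)
x, info = cg_sketch(A, b, rtol=1e-8)
print(info, np.linalg.norm(A @ x - b) / np.linalg.norm(b))
```

Note that the loop tests the recurrence residual, which can drift slightly from the true residual b - A @ x in finite precision.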

Solver parameters: rtol, atol, maxiter

  • rtol (default 1e-5): stop when ||r|| ≤ rtol * ||b||.

  • atol (default 0): absolute floor, useful when b is near zero.

  • maxiter (default 10 * n): maximum number of iterations.

  • When maxiter is exhausted without convergence, info > 0 reports the iteration count at termination.

# Force early termination by setting maxiter=2 on a 16x16 system
_, info_early = linalg.cg(A, b, rtol=1e-12, maxiter=2)
print(f"info with maxiter=2: {info_early}  (non-zero = not converged)")

# Loose tolerance converges in fewer iterations
_, info_loose = linalg.cg(A, b, rtol=1e-2)
print(f"info with rtol=1e-2: {info_loose}")
info with maxiter=2: 2  (non-zero = not converged)
info with rtol=1e-2: 0
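
How rtol and atol combine can be seen by computing the threshold max(atol, rtol * ||b||) directly. The helper name `stopping_threshold` is hypothetical and only restates the formula given earlier.

```python
import numpy as np

def stopping_threshold(b, rtol=1e-5, atol=0.0):
    """Residual norm at or below which iteration stops: max(atol, rtol * ||b||)."""
    return max(atol, rtol * float(np.linalg.norm(b)))

b_unit = np.ones(16)           # ||b|| = 4
b_tiny = 1e-12 * np.ones(16)   # ||b|| = 4e-12

print(stopping_threshold(b_unit))               # rtol term dominates
print(stopping_threshold(b_tiny))               # threshold shrinks with ||b||
print(stopping_threshold(b_tiny, atol=1e-10))   # atol sets an absolute floor
```

With atol left at 0 the threshold scales with ||b||, so a very small right-hand side demands a correspondingly small residual; a positive atol caps that demand.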

GMRES for non-symmetric systems

GMRES uses restarted Arnoldi and works for any non-singular square matrix, including non-symmetric ones. The restart parameter controls the Krylov subspace dimension before each restart (default: min(20, n)).

# Build a non-symmetric sparse matrix: scale the strict upper triangle of the Laplacian by 1.5
import scipy.sparse
B_scipy = L_scipy + scipy.sparse.triu(L_scipy, k=1) * 0.5  # asymmetric perturbation
B = ms.csr_array(
    (mx.array(B_scipy.data.astype(np.float32)),
     mx.array(B_scipy.indices.astype(np.int32)),
     mx.array(B_scipy.indptr.astype(np.int32))),
    shape=(N, N), canonical=True,
)

x_gmres, info_gmres = linalg.gmres(B, b, restart=16, rtol=1e-6, maxiter=200)
mx.eval(x_gmres)

print(f"GMRES info={info_gmres}")
res_gmres = np.linalg.norm(np.array(B @ x_gmres) - b_np) / np.linalg.norm(b_np)
print(f"Relative residual: {res_gmres:.2e}")
GMRES info=200
Relative residual: 5.51e-06
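
The role of the restart parameter is easier to see in code. Below is a minimal restarted GMRES in NumPy: each cycle builds an Arnoldi basis of at most `restart` vectors, solves a small least-squares problem, updates x, and starts over. It is a dense sketch, not the library's implementation; in particular, counting maxiter in inner iterations and returning maxiter as info on failure are assumptions.

```python
import numpy as np

def gmres_sketch(A, b, restart=20, rtol=1e-5, atol=0.0, maxiter=200):
    """Restarted GMRES via Arnoldi; returns (x, info), info == 0 on convergence."""
    n = b.shape[0]
    m = min(restart, n)
    threshold = max(atol, rtol * np.linalg.norm(b))
    x = np.zeros_like(b)
    total = 0                               # inner iterations used so far
    while total < maxiter:
        r = b - A @ x
        beta = np.linalg.norm(r)
        if beta <= threshold:
            return x, 0
        Q = np.zeros((n, m + 1))            # orthonormal Krylov basis
        H = np.zeros((m + 1, m))            # upper Hessenberg matrix
        Q[:, 0] = r / beta
        k = 0
        for j in range(m):
            w = A @ Q[:, j]
            for i in range(j + 1):          # modified Gram-Schmidt
                H[i, j] = Q[:, i] @ w
                w = w - H[i, j] * Q[:, i]
            H[j + 1, j] = np.linalg.norm(w)
            k = j + 1
            total += 1
            if H[j + 1, j] < 1e-14 or total >= maxiter:
                break                       # Krylov space exhausted or budget spent
            Q[:, j + 1] = w / H[j + 1, j]
        # Minimize ||beta * e1 - H y|| over the k-dimensional subspace.
        e1 = np.zeros(k + 1)
        e1[0] = beta
        y, *_ = np.linalg.lstsq(H[:k + 1, :k], e1, rcond=None)
        x = x + Q[:, :k] @ y
    info = 0 if np.linalg.norm(b - A @ x) <= threshold else maxiter
    return x, info

# Non-symmetric, diagonally dominant tridiagonal test matrix (float64).
n = 16
A = 4.0 * np.eye(n) - 1.5 * np.eye(n, k=1) - 1.0 * np.eye(n, k=-1)
b = np.ones(n)
x, info = gmres_sketch(A, b, restart=8, rtol=1e-8)
print(info, np.linalg.norm(A @ x - b) / np.linalg.norm(b))
```

A larger restart value costs more memory and orthogonalization work per cycle but discards less information at each restart.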

MINRES for symmetric indefinite systems

MINRES uses Lanczos projection and works when the matrix is symmetric but may have both positive and negative eigenvalues (indefinite). It minimizes the residual over the Krylov space in the 2-norm at each step.

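# Shift the Laplacian so that it is no longer positive definite.
# Note: 8.0 is an upper bound on the largest eigenvalue (about 7.24), so
# C = L - 8*I is negative definite; a smaller shift such as 2.0 would give
# eigenvalues of both signs (strictly indefinite).
eig_approx = 8.0  # rough upper bound for the 4x4 grid Laplacian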
C_scipy = L_scipy - scipy.sparse.eye(N, dtype=np.float32) * eig_approx
C = ms.csr_array(
    (mx.array(C_scipy.data.astype(np.float32)),
     mx.array(C_scipy.indices.astype(np.int32)),
     mx.array(C_scipy.indptr.astype(np.int32))),
    shape=(N, N), canonical=True,
)

# C is non-singular here, so a unique solution exists for any right-hand side b
b_c = mx.array(np.random.default_rng(0).normal(size=N).astype(np.float32))

x_minres, info_minres = linalg.minres(C, b_c, rtol=1e-5, maxiter=500)
mx.eval(x_minres)

print(f"MINRES info={info_minres}")
res_mr = np.linalg.norm(np.array(C @ x_minres) - np.array(b_c)) / np.linalg.norm(np.array(b_c))
print(f"Relative residual: {res_mr:.2e}")
MINRES info=0
Relative residual: 6.80e-08
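
Whether a shifted Laplacian is definite or indefinite can be checked on a small dense copy. The sketch below rebuilds the 16x16 Laplacian with Kronecker products in NumPy and classifies L - shift * I by the signs of its eigenvalues; `classify` is a hypothetical helper for illustration.

```python
import numpy as np

n = 4
T = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)   # 1D second-difference matrix
L = np.kron(np.eye(n), T) + np.kron(T, np.eye(n))        # 5-point Laplacian, 16x16

def classify(M, tol=1e-10):
    """Classify a symmetric matrix by the signs of its eigenvalues."""
    eigs = np.linalg.eigvalsh(M)
    if np.any(np.abs(eigs) <= tol):
        return "singular"
    if eigs.min() > 0:
        return "positive definite"
    if eigs.max() < 0:
        return "negative definite"
    return "indefinite"

for shift in (0.0, 2.0, 8.0):
    print(f"shift={shift}: {classify(L - shift * np.eye(n * n))}")
```

The three shifts give a positive definite, an indefinite, and a negative definite matrix respectively. MINRES applies to all three cases because it only requires symmetry.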

All three solvers on the SPD system

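All three solvers can be applied to the SPD Laplacian. CG and MINRES exploit symmetry through short recurrences, so each iteration is typically cheaper than a GMRES iteration, which must orthogonalize against its stored Krylov basis. Note that rtol=1e-7 is below float32 machine epsilon (about 1.2e-7), so a solver may report info != 0 even when its true residual is comparable to that of a solver that reports convergence.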

solvers = {
    "cg": lambda: linalg.cg(A, b, rtol=1e-7),
    "gmres": lambda: linalg.gmres(A, b, rtol=1e-7),
    "minres": lambda: linalg.minres(A, b, rtol=1e-7),
}

for name, solve_fn in solvers.items():
    x, info = solve_fn()
    mx.eval(x)
    rel = np.linalg.norm(np.array(A @ x) - b_np) / np.linalg.norm(b_np)
    print(f"{name:8s}  converged={info==0}  rel_residual={rel:.2e}")
cg        converged=True  rel_residual=1.15e-07
gmres     converged=False  rel_residual=1.51e-07
minres    converged=False  rel_residual=1.93e-07
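
The residuals above all sit near 1e-7, which is about the resolution of float32. The following NumPy sketch (independent of mlx_sparse) compares the requested rtol with float32 machine epsilon and measures the residual left after simply rounding a float64 solution of the same Laplacian system to float32.

```python
import numpy as np

eps32 = np.finfo(np.float32).eps
rtol = 1e-7
print(f"float32 eps = {eps32:.2e}, rtol = {rtol:.0e}, rtol below eps: {rtol < eps32}")

# Same 16x16 Laplacian and right-hand side as in this section, built densely.
n = 4
T = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
L = np.kron(np.eye(n), T) + np.kron(T, np.eye(n))
b = np.ones(n * n)

x64 = np.linalg.solve(L, b)          # reference solution in float64
x32 = x64.astype(np.float32)         # best case for a float32 solver
r32 = L.astype(np.float32) @ x32 - b.astype(np.float32)
rel32 = np.linalg.norm(r32) / np.linalg.norm(b)
print(f"relative residual of the rounded solution: {rel32:.2e}")
```

When the requested tolerance is this close to working precision, info is better read together with the measured residual than on its own.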