Preconditioners: Chebyshev polynomial

Chebyshev is the GPU-friendly preconditioner in this release. It avoids triangular solves entirely: each application consists only of sparse matrix-vector products and vector updates. That makes it attractive on Metal for SPD, Poisson-like systems, although its effectiveness depends on the spectral interval and the polynomial degree.
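To make "only matrix-vector products and vector updates" concrete, here is a minimal NumPy/SciPy sketch of the classical three-term Chebyshev iteration, run from a zero initial guess and used as a CG preconditioner. It is not mlx_sparse's kernel. The function `chebyshev_apply`, its `degree` convention (the number of products with A), the helpers `poisson_2d` and `run_cg`, and the fixed interval are all assumptions made for this illustration. Whether this `degree` corresponds to the library's `degree` argument is not established here.

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg


def chebyshev_apply(A, r, lam_min, lam_max, degree):
    """Approximate A^{-1} r with a Chebyshev polynomial in A (sketch).

    Starts from z = 0 and performs `degree` products with A, so the result is
    p(A) r for a polynomial p of that degree. Only matvecs and axpy-style
    vector updates are used; there are no triangular solves.
    """
    theta = 0.5 * (lam_max + lam_min)   # centre of the spectral interval
    delta = 0.5 * (lam_max - lam_min)   # half-width of the spectral interval
    sigma = theta / delta
    res = np.asarray(r, dtype=np.float64).ravel().copy()  # residual r - A z at z = 0
    rho = 1.0 / sigma
    d = res / theta
    z = d.copy()                        # degree-0 iterate: r / theta
    for _ in range(degree):
        res = res - A @ d               # the only use of A: one matvec per step
        rho_next = 1.0 / (2.0 * sigma - rho)
        d = rho_next * rho * d + (2.0 * rho_next / delta) * res
        rho = rho_next
        z = z + d
    return z


def poisson_2d(grid):
    """Hypothetical helper: same 5-point Poisson construction as in this section."""
    T = scipy.sparse.diags(
        [-np.ones(grid - 1), 2.0 * np.ones(grid), -np.ones(grid - 1)],
        [-1, 0, 1],
        format="csr",
    )
    I = scipy.sparse.eye(grid, format="csr")
    return (scipy.sparse.kron(I, T) + scipy.sparse.kron(T, I)).tocsr()


A = poisson_2d(16)
n = A.shape[0]
b = np.ones(n)

# Assumed conservative interval: the Gershgorin upper bound 8 and half of the
# exact smallest eigenvalue 4 * (1 - cos(pi / 17)) of this matrix.
lam_max = 8.0
lam_min = 0.5 * 4.0 * (1.0 - np.cos(np.pi / 17))

# Wrap the fixed polynomial as a linear operator so SciPy's CG can use it as M.
M = scipy.sparse.linalg.LinearOperator(
    (n, n),
    matvec=lambda r: chebyshev_apply(A, r, lam_min, lam_max, degree=2),
    dtype=np.float64,
)


def run_cg(M=None):
    """Run SciPy CG with default tolerances and count iterations via the callback."""
    count = [0]

    def callback(xk):
        count[0] += 1

    x, info = scipy.sparse.linalg.cg(A, b, M=M, maxiter=512, callback=callback)
    return x, info, count[0]


x_plain, info_plain, it_plain = run_cg()
x_cheb, info_cheb, it_cheb = run_cg(M)
print("plain CG iterations:     ", it_plain)
print("Chebyshev-PCG iterations:", it_cheb)
```

Because the interval is fixed, the preconditioner is the same linear map p(A) at every CG step. When the spectrum lies inside the interval, the residual polynomial 1 - λp(λ) has magnitude below 1 there, so p(λ) > 0 and p(A) is SPD, which is what CG requires of M.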

import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
from mlx_sparse import linalg
from mlx_sparse.linalg import preconditioners

# Use CPU execution throughout.
ms.use_cpu(require_available=False)
np.set_printoptions(precision=4, suppress=True)

Poisson smoother/preconditioner

If no spectral bounds are given, setup uses conservative bounds computed natively, together with a Lanczos estimate where one is available. You can also pass lambda_min and lambda_max explicitly to fully control the polynomial.
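The idea behind a Lanczos estimate can be sketched in a few lines. A short run of the symmetric Lanczos recurrence produces a small tridiagonal matrix whose extreme eigenvalues (Ritz values) approximate the extreme eigenvalues of A from inside the spectrum. This is an illustration, not the setup routine mlx_sparse uses. The function `lanczos_extremes`, its `steps` and `seed` parameters, and the choice to skip reorthogonalization are assumptions made for this sketch.

```python
import numpy as np
import scipy.sparse


def lanczos_extremes(A, steps=50, seed=0):
    """Estimate (lambda_min, lambda_max) of a symmetric matrix A (sketch).

    Runs `steps` iterations of the three-term Lanczos recurrence from a random
    start vector (no reorthogonalization) and returns the extreme eigenvalues
    of the resulting tridiagonal matrix.
    """
    n = A.shape[0]
    rng = np.random.default_rng(seed)
    # A random start avoids structured vectors; for this Poisson matrix the
    # all-ones vector is orthogonal to the top eigenvector.
    q = rng.standard_normal(n)
    q /= np.linalg.norm(q)
    q_prev = np.zeros(n)
    beta = 0.0
    alphas, betas = [], []
    for _ in range(steps):
        w = A @ q - beta * q_prev
        alpha = q @ w
        w = w - alpha * q
        alphas.append(alpha)
        beta = np.linalg.norm(w)
        if beta < 1e-12:          # invariant Krylov space: Ritz values are exact
            break
        betas.append(beta)
        q_prev, q = q, w / beta
    k = len(alphas)
    off = np.array(betas[: k - 1])
    tri = np.diag(alphas) + np.diag(off, 1) + np.diag(off, -1)
    ritz = np.linalg.eigvalsh(tri)
    return ritz[0], ritz[-1]


grid = 16
T = scipy.sparse.diags(
    [-np.ones(grid - 1), 2.0 * np.ones(grid), -np.ones(grid - 1)], [-1, 0, 1], format="csr"
)
I = scipy.sparse.eye(grid, format="csr")
A = (scipy.sparse.kron(I, T) + scipy.sparse.kron(T, I)).tocsr()

lo, hi = lanczos_extremes(A, steps=50)
c = np.cos(np.pi / (grid + 1))
exact_min, exact_max = 4.0 * (1.0 - c), 4.0 * (1.0 + c)   # closed form for this matrix
print(f"Lanczos estimate: [{lo:.4f}, {hi:.4f}]")
print(f"exact spectrum:   [{exact_min:.4f}, {exact_max:.4f}]")
```

Ritz values lie inside the spectrum, so a raw Lanczos estimate tends to overshoot lambda_min and undershoot lambda_max. For a Chebyshev polynomial, eigenvalues that fall outside the interval are poorly damped, which is one reason to combine such an estimate with conservative bounds.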

# 2-D Poisson matrix (5-point stencil) on a 16 x 16 grid: A = kron(I, T) + kron(T, I).
grid = 16
T = scipy.sparse.diags(
    [-np.ones(grid - 1), 2.0 * np.ones(grid), -np.ones(grid - 1)],
    [-1, 0, 1],
    format="csr",
    dtype=np.float32,
)
I = scipy.sparse.eye(grid, format="csr", dtype=np.float32)
A_sp = (scipy.sparse.kron(I, T) + scipy.sparse.kron(T, I)).astype(np.float32)
A = ms.from_scipy(A_sp)
b = mx.ones((A.shape[0],), dtype=mx.float32)

# Solve with no preconditioner, with Jacobi, and with a degree-2 Chebyshev polynomial.
x0, info0 = linalg.cg(A, b, rtol=1e-4, maxiter=512, return_info=True)
Mj = preconditioners.jacobi(A, check=True)
xj, infoj = linalg.cg(A, b, M=Mj, rtol=1e-4, maxiter=512, return_info=True)
Mc = preconditioners.chebyshev(A, degree=2)
xc, infoc = linalg.cg(A, b, M=Mc, rtol=1e-4, maxiter=512, return_info=True)

print("Chebyshev interval:", Mc.lambda_min, Mc.lambda_max)
print("none     :", info0.status, info0.iterations, f"{info0.residual_norm:.3e}")
print("jacobi   :", infoj.status, infoj.iterations, f"{infoj.residual_norm:.3e}")
print("chebyshev:", infoc.status, infoc.iterations, f"{infoc.residual_norm:.3e}")
Chebyshev interval: 0.03405384719371796 8.0
none     : 0 21 9.470e-04
jacobi   : 0 21 9.470e-04
chebyshev: 0 11 1.148e-03
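Two features of this output can be checked directly with SciPy. First, the none and jacobi rows match because every diagonal entry of this matrix is 4, so Jacobi preconditioning is multiplication by 0.25 * I. A scalar preconditioner cancels out of the CG step-length and direction updates, so preconditioned CG produces the same iterates as plain CG. Second, the printed interval can be compared with the closed-form extreme eigenvalues 4(1 - cos(π/17)) and 4(1 + cos(π/17)) of this matrix and with its Gershgorin bound. The sketch copies the printed interval from the output above. It rebuilds the matrix in float64 rather than float32.

```python
import numpy as np
import scipy.sparse

grid = 16
T = scipy.sparse.diags(
    [-np.ones(grid - 1), 2.0 * np.ones(grid), -np.ones(grid - 1)], [-1, 0, 1], format="csr"
)
I = scipy.sparse.eye(grid, format="csr")
A = (scipy.sparse.kron(I, T) + scipy.sparse.kron(T, I)).tocsr()

# Every diagonal entry is 4, so the Jacobi preconditioner D^{-1} is just 0.25 * I.
diag_values = np.unique(A.diagonal())
print("distinct diagonal entries:", diag_values)

# Gershgorin: with a positive diagonal, no eigenvalue exceeds the largest
# absolute row sum.
gershgorin_max = np.asarray(abs(A).sum(axis=1)).ravel().max()

# Closed-form extreme eigenvalues of this matrix, checked against a dense solve.
c = np.cos(np.pi / (grid + 1))
exact_min, exact_max = 4.0 * (1.0 - c), 4.0 * (1.0 + c)
eigs = np.linalg.eigvalsh(A.toarray())

# Interval printed by the mlx_sparse example above.
printed_min, printed_max = 0.03405384719371796, 8.0

print("Gershgorin upper bound:", gershgorin_max)
print("exact extremes:", exact_min, exact_max)
print("printed lambda_min / exact smallest eigenvalue:", printed_min / exact_min)
```

The printed lambda_max equals the Gershgorin bound. The printed lambda_min is about half the exact smallest eigenvalue, so the interval encloses the spectrum with margin at both ends. This check does not show how mlx_sparse derives those numbers.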

Device boundary

Setup runs natively on the CPU because it inspects the CSR structure and estimates a spectral interval. Application then runs on the selected MLX device, using the native sparse matrix-vector and vector kernels. On the GPU, this avoids the sequential dependency chain of the triangular solves used by ILU(0)/IC(0).
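The dependency chain can be made concrete by counting wavefront levels in a forward substitution. Row i of L y = r can only be computed after every earlier row it references, so rows are grouped into levels that must run one after another. This sketch assumes natural row ordering and uses the fact that IC(0)/ILU(0) keep the sparsity pattern of A, so the lower factor has the pattern of tril(A). `level_schedule` is a hypothetical helper written for this illustration, not part of mlx_sparse. It reads only the sparsity pattern and does not compute a factorization.

```python
import numpy as np
import scipy.sparse


def level_schedule(L):
    """Wavefront level of each row when solving L y = r by forward substitution.

    Row i needs every earlier row j with L[i, j] != 0, so
    level[i] = 1 + max(level[j]). Rows on the same level are independent;
    the number of levels is the length of the sequential dependency chain.
    """
    L = scipy.sparse.csr_matrix(L)
    level = np.zeros(L.shape[0], dtype=int)
    for i in range(L.shape[0]):
        cols = L.indices[L.indptr[i] : L.indptr[i + 1]]
        deps = cols[cols < i]
        level[i] = 1 + (level[deps].max() if deps.size else 0)
    return level


grid = 16
T = scipy.sparse.diags(
    [-np.ones(grid - 1), 2.0 * np.ones(grid), -np.ones(grid - 1)], [-1, 0, 1], format="csr"
)
I = scipy.sparse.eye(grid, format="csr")
A = (scipy.sparse.kron(I, T) + scipy.sparse.kron(T, I)).tocsr()

# Zero-fill factors share A's pattern, so the lower factor looks like tril(A).
L = scipy.sparse.tril(A, format="csr")
levels = level_schedule(L)
print("sequential levels in one triangular solve:", levels.max())
# By contrast, every row of a sparse matvec y = A x can be computed independently.
```

On this 16 × 16 grid, node (row r, column c) lands on level r + c + 1, giving 2 × 16 - 1 = 31 dependent levels per triangular solve. A sparse matrix-vector product is a single level.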