Sparse linalg: linear operators and inner products
This notebook covers two features that round out mlx_sparse.linalg:
LinearOperator: a matrix-free wrapper that lets solvers and spectral routines work with any callable that implements matrix-vector products, not just a stored sparse matrix.
dot / vdot: sparse Frobenius inner products that merge canonical CSR rows in native code without materializing a dense matrix.
import mlx.core as mx
import numpy as np
import mlx_sparse as ms
from mlx_sparse import linalg
# 2x2 SPD matrix used throughout
A = ms.csr_array(
    (
        mx.array([4.0, -1.0, -1.0, 4.0], dtype=mx.float32),
        mx.array([0, 1, 0, 1], dtype=mx.int32),
        mx.array([0, 2, 4], dtype=mx.int32),
    ),
    shape=(2, 2),
    canonical=True,
)
x = mx.array(np.array([1.0, 2.0], dtype=np.float32))
aslinearoperator: wrapping sparse matrices
aslinearoperator(A) turns CSRArray, COOArray, or CSCArray inputs
into a LinearOperator. Sparse inputs are normalized once to canonical CSR
storage so solver iterations reuse the same row-oriented native kernels.
op = linalg.aslinearoperator(A)
print("shape", op.shape)
print("dtype", op.dtype)
print("A @ x ", op @ x)
print("A.T @ x", op.rmatvec(x))
shape (2, 2)
dtype mlx.core.float32
A @ x array([2, 7], dtype=float32)
A.T @ x array([2, 7], dtype=float32)
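To make the normalization step concrete, here is a minimal NumPy sketch of turning COO triplets into CSR with sorted, duplicate-free column indices in each row. It assumes that this is what "canonical" means here; the helper name coo_to_canonical_csr is hypothetical and is not part of mlx_sparse.

```python
import numpy as np

def coo_to_canonical_csr(rows, cols, vals, shape):
    """Hypothetical helper: COO triplets -> CSR with sorted, unique columns per row."""
    rows = np.asarray(rows, dtype=np.int64)
    cols = np.asarray(cols, dtype=np.int64)
    vals = np.asarray(vals)
    n_rows, n_cols = shape
    # Linearize (row, col) so a single sort orders by row, then by column.
    keys = rows * n_cols + cols
    unique_keys, inverse = np.unique(keys, return_inverse=True)
    # Sum values that land on the same (row, col) position.
    data = np.zeros(unique_keys.shape[0], dtype=vals.dtype)
    np.add.at(data, inverse, vals)
    out_rows = unique_keys // n_cols
    indices = unique_keys % n_cols
    # indptr[i] is the offset of the first stored entry of row i.
    counts = np.bincount(out_rows, minlength=n_rows)
    indptr = np.concatenate(([0], np.cumsum(counts)))
    return data, indices, indptr

# Unsorted triplets with one duplicate at position (1, 1): 3.0 + 1.0 = 4.0.
data, indices, indptr = coo_to_canonical_csr(
    rows=[1, 0, 1, 0, 1],
    cols=[1, 1, 0, 0, 1],
    vals=np.array([3.0, -1.0, -1.0, 4.0, 1.0], dtype=np.float32),
    shape=(2, 2),
)
```

The result reproduces the (data, indices, indptr) triple used to build A at the top of the notebook. Doing this once up front is what lets every later product walk rows in a fixed, sorted order.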
Custom LinearOperator from a callable
When the matrix is not explicitly stored, e.g. a Fourier transform, a
preconditioned system, or a projection, you can construct a LinearOperator
directly from Python callables.
# Scaled operator: 0.5 * A without storing a new matrix
scaled = linalg.LinearOperator(
    A.shape,
    matvec=lambda rhs: 0.5 * (A @ rhs),
    matmat=lambda rhs: 0.5 * (A @ rhs),
    dtype=A.dtype,
)
print("0.5 * A @ x =", scaled @ x)
print("direct check:", 0.5 * (A @ x))
0.5 * A @ x = array([1, 3.5], dtype=float32)
direct check: array([1, 3.5], dtype=float32)
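The idea behind a callable-backed operator fits in a few lines. The class below is a hypothetical NumPy stand-in (TinyOperator is an illustrative name, not the mlx_sparse implementation): it stores a shape plus callables and dispatches `@` to matvec after a shape check.

```python
import numpy as np

class TinyOperator:
    """Hypothetical matrix-free operator; a sketch, not the mlx_sparse class."""

    def __init__(self, shape, matvec, rmatvec=None):
        self.shape = shape
        self._matvec = matvec
        self._rmatvec = rmatvec

    def matvec(self, v):
        v = np.asarray(v)
        if v.shape != (self.shape[1],):
            raise ValueError(f"expected shape ({self.shape[1]},), got {v.shape}")
        return self._matvec(v)

    def rmatvec(self, v):
        if self._rmatvec is None:
            raise NotImplementedError("no rmatvec callable was supplied")
        v = np.asarray(v)
        if v.shape != (self.shape[0],):
            raise ValueError(f"expected shape ({self.shape[0]},), got {v.shape}")
        return self._rmatvec(v)

    def __matmul__(self, v):
        # `op @ v` is sugar for the forward product.
        return self.matvec(v)

dense = np.array([[4.0, -1.0], [-1.0, 4.0]])
half = TinyOperator(
    dense.shape,
    matvec=lambda v: 0.5 * (dense @ v),
    rmatvec=lambda v: 0.5 * (dense.T @ v),
)
y = half @ np.array([1.0, 2.0])
```

Nothing in the wrapper ever asks for the entries of the matrix, which is why a solver written against this interface works equally well for stored and implicit operators.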
Adjoint / transpose operator
The .H (adjoint) and .T (transpose) properties return new
LinearOperator objects that swap matvec and rmatvec. For a real operator
such as A, the adjoint and the transpose coincide.
This is useful when building solvers for the normal equations A.T @ A @ x = A.T @ b.
op_T = op.T # transpose operator
print("(A.T) @ x via op.T =", op_T @ x)
# Normal operator: A.T @ A (matrix-free)
ATA = linalg.LinearOperator(
    A.shape,
    matvec=lambda v: op_T @ (op @ v),
    dtype=A.dtype,
)
print("(A.T @ A) @ x = ", ATA @ x)
print("dense check = ", mx.array(np.array(A.todense()).T) @ (A @ x))
(A.T) @ x via op.T = array([2, 7], dtype=float32)
(A.T @ A) @ x = array([1, 26], dtype=float32)
dense check = array([1, 26], dtype=float32)
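When you supply your own matvec and rmatvec callables, a cheap sanity check is the dot-product test: for an adjoint pair, `<y, M x>` must equal `<M^H y, x>` for any vectors x and y. The sketch below uses plain NumPy with a random complex matrix (all names are illustrative) and also shows that a plain transpose does not pass the test in the complex case.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((3, 2)) + 1j * rng.standard_normal((3, 2))

def matvec(v):
    # Forward product: maps length-2 vectors to length-3 vectors.
    return M @ v

def rmatvec(w):
    # Adjoint product: conjugate transpose, maps length-3 to length-2.
    return M.conj().T @ w

x = rng.standard_normal(2) + 1j * rng.standard_normal(2)
y = rng.standard_normal(3) + 1j * rng.standard_normal(3)

# np.vdot conjugates its first argument, so this is y^H (M x).
lhs = np.vdot(y, matvec(x))
# (M^H y)^H x, which equals y^H M x when rmatvec is the true adjoint.
rhs = np.vdot(rmatvec(y), x)
adjoint_ok = bool(np.allclose(lhs, rhs))

# Using the plain transpose instead of the conjugate transpose breaks the identity.
rhs_plain_transpose = np.vdot(M.T @ y, x)
transpose_ok = bool(np.allclose(lhs, rhs_plain_transpose))
```

For real matrices both checks pass, which is why the distinction between .H and .T only shows up with complex data.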
Using a LinearOperator with a solver
Iterative solvers accept a LinearOperator in place of an explicit sparse array,
which helps when the matrix is too large to store or is defined only through
its action on a vector. Here we build a 6×6 tridiagonal SPD matrix, wrap it
with aslinearoperator, and solve with CG.
# Build a slightly larger (6x6) tridiagonal SPD operator for a non-trivial example
import scipy.sparse
n6 = 6
tri = scipy.sparse.diags(
    [[4.0] * n6, [-1.0] * (n6 - 1), [-1.0] * (n6 - 1)],
    [0, 1, -1],
    dtype=np.float32,
).tocsr()
A6 = ms.csr_array(
    (mx.array(tri.data), mx.array(tri.indices), mx.array(tri.indptr)),
    shape=(n6, n6),
    canonical=True,
)
b6 = mx.array(np.ones(n6, dtype=np.float32))
op6 = linalg.aslinearoperator(A6)
x_sol, info = linalg.cg(op6, b6, rtol=1e-7)
mx.eval(x_sol)
print(f"CG via LinearOperator: info={info}")
residual = np.linalg.norm(np.array(A6 @ x_sol) - np.ones(n6))
print(f"||Ax - b|| = {residual:.2e}")
CG via LinearOperator: info=0
||Ax - b|| = 0.00e+00
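CG can run on an operator because the algorithm only ever needs products with the matrix. The following is a textbook conjugate gradient sketch in NumPy that takes a matvec callable; it is not the mlx_sparse implementation, and the function names, the default maxiter, and the return convention (info=0 on convergence, maxiter otherwise) are assumptions made for illustration.

```python
import numpy as np

def cg_matrix_free(matvec, b, rtol=1e-6, maxiter=None):
    """Textbook CG for an SPD operator given only its matvec; returns (x, info)."""
    b = np.asarray(b, dtype=np.float64)
    maxiter = 10 * b.shape[0] if maxiter is None else maxiter
    x = np.zeros_like(b)
    r = b.copy()            # residual b - A x for the initial guess x = 0
    p = r.copy()            # first search direction
    rs_old = float(r @ r)
    b_norm = float(np.linalg.norm(b))
    if b_norm == 0.0:
        return x, 0
    for _ in range(maxiter):
        Ap = matvec(p)      # the only place the operator is used
        alpha = rs_old / float(p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = float(r @ r)
        if np.sqrt(rs_new) <= rtol * b_norm:
            return x, 0
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x, maxiter

def tridiag_matvec(v):
    """Apply tridiag(-1, 4, -1) without storing the matrix."""
    out = 4.0 * v
    out[:-1] -= v[1:]   # superdiagonal contribution
    out[1:] -= v[:-1]   # subdiagonal contribution
    return out

b = np.ones(6)
x_sol, info = cg_matrix_free(tridiag_matvec, b, rtol=1e-10)
residual = float(np.linalg.norm(tridiag_matvec(x_sol) - b))
```

The tridiagonal operator here is the same 6×6 system as above, applied with two shifted subtractions instead of a stored matrix.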
dot and vdot: sparse Frobenius inner products
linalg.dot(A, B) computes the Frobenius inner product
sum(A * B) by merging canonical CSR rows directly in native code.
linalg.vdot(A, B) conjugates the left operand first (NumPy/MLX convention),
so for real matrices dot and vdot are identical.
The equivalent dense operation is (A.todense() * B.todense()).sum() which
allocates two n×n dense arrays; the sparse version never does this.
print("dot(A, A) =", linalg.dot(A, A))
print("vdot(A, A) =", linalg.vdot(A, A))
# Both should equal the sum of element-wise squares (Frobenius norm squared)
A_np = np.array(A.todense())
print("numpy ref =", float(np.sum(A_np * A_np)))
dot(A, A) = array(34, dtype=float32)
vdot(A, A) = array(34, dtype=float32)
numpy ref = 34.0
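The row-merge idea can be sketched in pure Python over (data, indices, indptr) triples. This is an illustration of the technique under the assumption that both inputs have sorted, duplicate-free column indices per row; the function name csr_frobenius_dot is hypothetical, and the native kernel in mlx_sparse may be organized differently.

```python
import numpy as np

def csr_frobenius_dot(a, b, conjugate_left=False):
    """sum(A * B), or sum(conj(A) * B), for two canonical CSR triples of equal shape."""
    a_data, a_idx, a_ptr = a
    b_data, b_idx, b_ptr = b
    total = 0
    for row in range(len(a_ptr) - 1):
        i, i_end = a_ptr[row], a_ptr[row + 1]
        j, j_end = b_ptr[row], b_ptr[row + 1]
        # Two-pointer merge: valid because columns are sorted and unique.
        while i < i_end and j < j_end:
            if a_idx[i] < b_idx[j]:
                i += 1          # entry only in A contributes nothing
            elif a_idx[i] > b_idx[j]:
                j += 1          # entry only in B contributes nothing
            else:
                left = np.conj(a_data[i]) if conjugate_left else a_data[i]
                total += left * b_data[j]
                i += 1
                j += 1
    return total

# Same 2x2 matrix as A, plus a diagonal matrix with a different sparsity pattern.
A_csr = (np.array([4.0, -1.0, -1.0, 4.0]), [0, 1, 0, 1], [0, 2, 4])
D_csr = (np.array([2.0, 3.0]), [0, 1], [0, 1, 2])
C_csr = (np.array([1.0 + 2.0j, -3.0 + 0.5j]), [0, 1], [0, 1, 2])

aa = csr_frobenius_dot(A_csr, A_csr)
ad = csr_frobenius_dot(A_csr, D_csr)
cc_dot = csr_frobenius_dot(C_csr, C_csr)
cc_vdot = csr_frobenius_dot(C_csr, C_csr, conjugate_left=True)
```

Only positions stored in both operands are ever multiplied, so the work is proportional to the number of stored entries rather than to n².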
vdot with complex matrices
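For complex64 inputs, vdot(A, B) = sum(conj(A) * B) while
dot(A, B) = sum(A * B) without conjugation. The two results generally differ
in both their real and imaginary parts: vdot(C, C) is the squared Frobenius
norm, so it is real and non-negative, whereas dot(C, C) is complex in general.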
C = ms.csr_array(
    (
        mx.array(np.array([1.0 + 2.0j, -3.0 + 0.5j], dtype=np.complex64)),
        mx.array([0, 1], dtype=mx.int32),
        mx.array([0, 1, 2], dtype=mx.int32),
    ),
    shape=(2, 2),
    canonical=True,
)
print("dot(C, C) =", linalg.dot(C, C)) # no conjugation
print("vdot(C, C) =", linalg.vdot(C, C)) # conjugates left: always real and non-negative
c_np = np.array([1.0+2.0j, -3.0+0.5j])
print("numpy dot = ", np.dot(c_np, c_np)) # no conjugation
print("numpy vdot = ", np.vdot(c_np, c_np)) # conjugates first arg
dot(C, C) = array(5.75+1j, dtype=complex64)
vdot(C, C) = array(14.25+0j, dtype=complex64)
numpy dot = (5.75+1j)
numpy vdot = (14.25+0j)
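The printed values can be checked by hand. Writing c_1 = 1 + 2i and c_2 = -3 + 0.5i for the two stored entries of C (notation introduced here for the check only):

```latex
\begin{aligned}
\operatorname{dot}(C, C)  &= c_1^2 + c_2^2
                           = (-3 + 4i) + (8.75 - 3i)
                           = 5.75 + i, \\
\operatorname{vdot}(C, C) &= |c_1|^2 + |c_2|^2
                           = (1 + 4) + (9 + 0.25)
                           = 14.25 .
\end{aligned}
```

The conjugation changes the real part as well as the imaginary part, because each squared entry c_k^2 is replaced by the squared modulus |c_k|^2.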