Autodiff through sparse operations#
All sparse-dense products in mlx-sparse are differentiable with respect
to sparse data values and the dense operand. The sparse structure
(indices and indptr) is treated as a non-differentiable constant.
Function |
Differentiable w.r.t. |
|---|---|
|
|
|
|
|
|
On Apple Silicon Metal GPU, gradients work for float32, float16,
bfloat16, and complex64.
import mlx.core as mx
import numpy as np
import scipy.sparse

import mlx_sparse as ms

# Select the GPU backend for the sparse kernels (going by the function name;
# see the mlx-sparse docs for the exact semantics).
ms.use_gpu()

# One fixed 64x64 sparse matrix (about 5% non-zeros) shared by every example.
# Seeding the generator keeps the matrix, and the numbers printed further
# down, reproducible.
rng = np.random.default_rng(0)
sp = scipy.sparse.random(
    64,
    64,
    density=0.05,
    format="csr",
    dtype=np.float32,
    random_state=rng,
)

# Move the three CSR components into MLX arrays, casting the index arrays
# to int32 on the way.
csr_data = mx.array(sp.data)
csr_indices = mx.array(sp.indices.astype(np.int32))
csr_indptr = mx.array(sp.indptr.astype(np.int32))

# NOTE(review): sorted_indices / canonical presumably declare that the SciPy
# output already has sorted, duplicate-free column indices -- confirm against
# the mlx-sparse csr_array documentation.
A = ms.csr_array(
    (csr_data, csr_indices, csr_indptr),
    shape=sp.shape,
    sorted_indices=True,
    canonical=True,
)
print("A:", A)
Gradient of a scalar loss through SpMV#
We differentiate a simple sum-of-squares loss L(x) = ||A x||² with mx.grad and verify the
result against the analytic gradient ∇_x L = 2 Aᵀ A x.
x = mx.array(rng.standard_normal(64).astype(np.float32))


def loss_matvec(x):
    """Return the scalar loss ||A x||² for a dense vector ``x``.

    ``A`` is the sparse matrix built in the setup cell; ``A @ x`` is the
    sparse matrix-vector product being differentiated.
    """
    y = A @ x
    return mx.sum(y ** 2)


# mx.grad differentiates with respect to the function's first argument.
grad_x = mx.grad(loss_matvec)(x)
mx.eval(grad_x)  # force evaluation before inspecting the result
print("grad_x shape:", grad_x.shape)

# Analytic: ∇_x ||Ax||² = 2 AᵀAx
# Computed with a dense copy of A so the reference does not go through the
# sparse kernels under test.
A_dense = A.todense()
mx.eval(A_dense)
grad_ref = 2.0 * (A_dense.T @ (A_dense @ x))
mx.eval(grad_ref)

# Compare on the host with NumPy: largest absolute element-wise difference.
err = np.max(np.abs(np.array(grad_x) - np.array(grad_ref)))
print(f"max error vs analytic: {err:.2e}")
grad_x shape: (64,)
max error vs analytic: 1.19e-07
Gradient through SpMM (rank-2 RHS)#
B = mx.array(rng.standard_normal((64, 16)).astype(np.float32))


def loss_matmul(B):
    """Return the scalar loss ||A B||² for a dense rank-2 operand ``B``.

    ``A @ B`` is the sparse-dense matrix product being differentiated.
    """
    Y = A @ B
    return mx.sum(Y ** 2)


grad_B = mx.grad(loss_matmul)(B)
mx.eval(grad_B)
print("grad_B shape:", grad_B.shape)


# Dense reference: ∇_B ||AB||² = 2 AᵀAB
def dense_loss(B):
    """Same loss as ``loss_matmul``, but using the dense copy of ``A``."""
    return mx.sum((A_dense @ B) ** 2)


# Differentiating the dense formulation gives a reference gradient that does
# not go through the sparse kernels.
grad_B_ref = mx.grad(dense_loss)(B)
mx.eval(grad_B_ref)
err = np.max(np.abs(np.array(grad_B) - np.array(grad_B_ref)))
print(f"max error vs dense grad: {err:.2e}")
grad_B shape: (64, 16)
max error vs dense grad: 2.38e-07
Gradient through batched SpMM#
Pass a rank-3 dense operand of shape (batch, n, k): the sparse product is applied to every batch slice, and mx.grad returns a gradient with the same shape as the input.
B_batch = mx.array(rng.standard_normal((4, 64, 8)).astype(np.float32))
print("B_batch shape:", B_batch.shape)


def loss_batch(B):
    """Return the sum of squares of ``A @ B`` for a rank-3 operand ``B``.

    The loss is summed over every entry of the product, so it covers all
    batch slices at once.
    """
    Y = A @ B
    return mx.sum(Y ** 2)


grad_batch = mx.grad(loss_batch)(B_batch)
mx.eval(grad_batch)
print("grad_batch shape:", grad_batch.shape)


# Dense reference
def dense_loss_batch(B):
    """Same loss as ``loss_batch``, but using the dense copy of ``A``."""
    return mx.sum((A_dense @ B) ** 2)


grad_batch_ref = mx.grad(dense_loss_batch)(B_batch)
mx.eval(grad_batch_ref)
err = np.max(np.abs(np.array(grad_batch) - np.array(grad_batch_ref)))
print(f"max error vs dense grad: {err:.2e}")
B_batch shape: (4, 64, 8)
grad_batch shape: (4, 64, 8)
max error vs dense grad: 4.77e-07
VJP and JVP#
mx.vjp and mx.jvp also compose with sparse operations.
# VJP: vjp(f, primals, cotangents)
def f_matvec(x):
    """Return the sparse matrix-vector product ``A @ x``."""
    return A @ x


x0 = mx.array(rng.standard_normal(64).astype(np.float32))
cotangent = mx.array(rng.standard_normal(64).astype(np.float32))

# mx.vjp takes lists of primals and cotangents and returns
# (outputs, vjps); only the vector-Jacobian product is needed here.
_, vjp_out = mx.vjp(f_matvec, [x0], [cotangent])
mx.eval(vjp_out[0])
print("VJP output shape:", vjp_out[0].shape)

# Manual: Aᵀ cotangent
# f is linear in x, so its VJP is the transpose of A applied to the cotangent.
vjp_ref = A_dense.T @ cotangent
mx.eval(vjp_ref)
err_vjp = np.max(np.abs(np.array(vjp_out[0]) - np.array(vjp_ref)))
print(f"VJP max error vs manual: {err_vjp:.2e}")
# JVP: jvp(f, primals, tangents)
def f_matmul(B):
    """Return the sparse-dense matrix product ``A @ B``."""
    return A @ B


B0 = mx.array(rng.standard_normal((64, 8)).astype(np.float32))
dB = mx.array(rng.standard_normal((64, 8)).astype(np.float32))

# mx.jvp takes lists of primals and tangents and returns
# (outputs, jvps); only the Jacobian-vector product is needed here.
_, jvp_out = mx.jvp(f_matmul, [B0], [dB])
mx.eval(jvp_out[0])
print("\nJVP tangent shape:", jvp_out[0].shape)

# Manual: A dB
# f is linear in B, so pushing the tangent dB forward is just A applied to dB.
jvp_ref = A_dense @ dB
mx.eval(jvp_ref)
err_jvp = np.max(np.abs(np.array(jvp_out[0]) - np.array(jvp_ref)))
print(f"JVP max error vs manual: {err_jvp:.2e}")
VJP output shape: (64,)
VJP max error vs manual: 2.38e-07
JVP tangent shape: (64, 8)
JVP max error vs manual: 0.00e+00
End-to-end: sparse linear layer in a gradient descent loop#
We fit the rank-1 correction u vᵀ in W = A + u vᵀ so that W x_in matches a fixed target vector (not realistic, but pedagogically clear).
Only the dense weights u and v are optimized. The sparse A is fixed.
# Fixed target vector
target = mx.array(rng.standard_normal(64).astype(np.float32))

# Learnable dense parameters
u = mx.array(rng.standard_normal(64).astype(np.float32))
v = mx.array(rng.standard_normal(64).astype(np.float32))

# Fixed input vector fed through the layer
x_in = mx.array(rng.standard_normal(64).astype(np.float32))
lr = 0.01


def compute_loss(u, v):
    """Return the squared error ||(A + u vᵀ) x_in - target||².

    Args:
        u: Left factor of the rank-1 correction.
        v: Right factor of the rank-1 correction.

    Returns:
        A scalar MLX array holding the loss.
    """
    # Sparse part
    y_sparse = A @ x_in
    # Dense rank-1 correction: (u vᵀ) x_in == u * (v · x_in), so the outer
    # product u vᵀ never has to be materialized.
    y_dense = u * mx.sum(v * x_in)
    y = y_sparse + y_dense
    return mx.sum((y - target) ** 2)


# value_and_grad differentiates only the first argument by default and would
# then return a single gradient array; request gradients for both u and v so
# the result can be unpacked as (du, dv).
loss_and_grad = mx.value_and_grad(compute_loss, argnums=[0, 1])

for step in range(101):
    loss, (du, dv) = loss_and_grad(u, v)
    mx.eval(loss, du, dv)
    # Plain gradient-descent update on both factors.
    u = u - lr * du
    v = v - lr * dv
    if step % 20 == 0:
        print(f"step {step:3d} loss={float(loss):.4f}")
step 0 loss=28.3741
step 20 loss=10.5234
step 40 loss=4.0117
step 60 loss=1.5623
step 80 loss=0.6189
step 100 loss=0.2481