Autodiff through sparse operations#

All sparse-dense products in mlx-sparse are differentiable with respect to sparse data values and the dense operand. The sparse structure (indices and indptr) is treated as a non-differentiable constant.

| Function | Differentiable w.r.t. |
| --- | --- |
| csr_matvec(A, x) | A.data, x |
| csr_matmul(A, B) | A.data, B |
| csr_matmul batched | A.data, batch of B tensors |

On Apple Silicon Metal GPU, gradients work for float32, float16, bfloat16, and complex64.

import mlx.core as mx
import numpy as np
import mlx_sparse as ms

ms.use_gpu()

# One fixed 64x64 CSR matrix (5% density) that every example below reuses.
import scipy.sparse

rng = np.random.default_rng(0)
sp = scipy.sparse.random(
    64,
    64,
    density=0.05,
    format="csr",
    dtype=np.float32,
    random_state=rng,
)

# The CSR triple is passed as (data, indices, indptr); both index arrays are
# cast to int32 before being wrapped as MLX arrays.
csr_triple = (
    mx.array(sp.data),
    mx.array(sp.indices.astype(np.int32)),
    mx.array(sp.indptr.astype(np.int32)),
)
A = ms.csr_array(
    csr_triple,
    shape=sp.shape,
    sorted_indices=True,
    canonical=True,
)
print("A:", A)

Gradient of a scalar loss through SpMV#

We differentiate a simple sum-of-squares loss L(x) = ||A x||² with respect to x and verify the result against the analytic gradient ∇_x L = 2 Aᵀ A x.

x = mx.array(rng.standard_normal(64).astype(np.float32))


def loss_matvec(x):
    """Return the sum-of-squares loss ||A x||^2 via the sparse mat-vec."""
    return mx.sum((A @ x) ** 2)


# Differentiate w.r.t. x, then evaluate the result.
grad_x = mx.grad(loss_matvec)(x)
mx.eval(grad_x)
print("grad_x shape:", grad_x.shape)

# Closed form, built from a dense copy of A: ∇_x ||Ax||² = 2 AᵀAx
A_dense = A.todense()
mx.eval(A_dense)
grad_ref = 2.0 * (A_dense.T @ (A_dense @ x))
mx.eval(grad_ref)

# Largest element-wise deviation between autodiff and the closed form.
err = np.abs(np.array(grad_x) - np.array(grad_ref)).max()
print(f"max error vs analytic: {err:.2e}")
grad_x shape: (64,)
max error vs analytic: 1.19e-07

Gradient through SpMM (rank-2 RHS)#

B = mx.array(rng.standard_normal((64, 16)).astype(np.float32))


def loss_matmul(B):
    """Return the sum of squared entries of the sparse product A @ B."""
    return mx.sum((A @ B) ** 2)


grad_B = mx.grad(loss_matmul)(B)
mx.eval(grad_B)
print("grad_B shape:", grad_B.shape)


# Dense reference: ∇_B ||AB||² = 2 AᵀAB, obtained here by differentiating
# the same loss with the dense copy of A substituted for the sparse one.
def dense_loss(B):
    """Same loss as loss_matmul, computed with the dense matrix A_dense."""
    return mx.sum((A_dense @ B) ** 2)


grad_B_ref = mx.grad(dense_loss)(B)
mx.eval(grad_B_ref)

err = np.abs(np.array(grad_B) - np.array(grad_B_ref)).max()
print(f"max error vs dense grad: {err:.2e}")
grad_B shape: (64, 16)
max error vs dense grad: 2.38e-07

Gradient through batched SpMM#

Pass a rank-3 tensor as the right-hand side (here of shape (4, 64, 8), with the batch dimension first); mx.grad handles the batch dimension automatically.

B_batch = mx.array(rng.standard_normal((4, 64, 8)).astype(np.float32))
print("B_batch shape:", B_batch.shape)


def loss_batch(B):
    """Return the sum of squared entries of A @ B for a rank-3 operand B."""
    return mx.sum((A @ B) ** 2)


grad_batch = mx.grad(loss_batch)(B_batch)
mx.eval(grad_batch)
print("grad_batch shape:", grad_batch.shape)


# Dense reference
def dense_loss_batch(B):
    """Same loss as loss_batch, computed with the dense matrix A_dense."""
    return mx.sum((A_dense @ B) ** 2)


grad_batch_ref = mx.grad(dense_loss_batch)(B_batch)
mx.eval(grad_batch_ref)

err = np.abs(np.array(grad_batch) - np.array(grad_batch_ref)).max()
print(f"max error vs dense grad: {err:.2e}")
B_batch shape: (4, 64, 8)
grad_batch shape: (4, 64, 8)
max error vs dense grad: 4.77e-07

VJP and JVP#

mx.vjp and mx.jvp also compose with sparse operations.

# VJP: vjp(f, primals, cotangents)
def f_matvec(x):
    """Apply the sparse matrix A to the vector x."""
    return A @ x


x0 = mx.array(rng.standard_normal(64).astype(np.float32))
cotangent = mx.array(rng.standard_normal(64).astype(np.float32))

# The function outputs are not needed here; only the VJP results are kept.
_, vjp_out = mx.vjp(f_matvec, (x0,), (cotangent,))
mx.eval(vjp_out[0])
print("VJP output shape:", vjp_out[0].shape)

# Manual: Aᵀ cotangent
vjp_ref = A_dense.T @ cotangent
mx.eval(vjp_ref)
err_vjp = np.abs(np.array(vjp_out[0]) - np.array(vjp_ref)).max()
print(f"VJP max error vs manual: {err_vjp:.2e}")

# JVP: jvp(f, primals, tangents)
def f_matmul(B):
    """Apply the sparse matrix A to the dense matrix B."""
    return A @ B


B0 = mx.array(rng.standard_normal((64, 8)).astype(np.float32))
dB = mx.array(rng.standard_normal((64, 8)).astype(np.float32))

# The function outputs are not needed here; only the JVP results are kept.
_, jvp_out = mx.jvp(f_matmul, (B0,), (dB,))
mx.eval(jvp_out[0])
print("\nJVP tangent shape:", jvp_out[0].shape)

# Manual: A dB
jvp_ref = A_dense @ dB
mx.eval(jvp_ref)
err_jvp = np.abs(np.array(jvp_out[0]) - np.array(jvp_ref)).max()
print(f"JVP max error vs manual: {err_jvp:.2e}")
VJP output shape: (64,)
VJP max error vs manual: 2.38e-07

JVP tangent shape: (64, 8)
JVP max error vs manual: 0.00e+00

End-to-end: sparse linear layer in a gradient descent loop#

We fit a rank-1 update W = A + u vᵀ so that W x_in matches a fixed target vector (not realistic, but pedagogically clear). Only the dense weights u and v are optimized; the sparse A is fixed.

# Fixed target vector
target = mx.array(rng.standard_normal(64).astype(np.float32))

# Learnable dense parameters
u = mx.array(rng.standard_normal(64).astype(np.float32))
v = mx.array(rng.standard_normal(64).astype(np.float32))
x_in = mx.array(rng.standard_normal(64).astype(np.float32))

lr = 0.01


def compute_loss(u, v):
    """Return the squared error between (A + u vᵀ) x_in and the target.

    The rank-1 term (u vᵀ) x_in is evaluated as u * (v · x_in), so the
    rank-1 matrix is never formed explicitly. The sparse term A @ x_in does
    not depend on u or v.
    """
    # Sparse part
    y_sparse = A @ x_in
    # Dense rank-1 correction
    y_dense = u * mx.sum(v * x_in)
    y = y_sparse + y_dense
    return mx.sum((y - target) ** 2)


# By default mx.value_and_grad differentiates only the first positional
# argument. Request gradients for both u and v so that the call returns
# (loss, (du, dv)), which is what the loop below unpacks.
loss_and_grad = mx.value_and_grad(compute_loss, argnums=[0, 1])

for step in range(101):
    loss, (du, dv) = loss_and_grad(u, v)
    mx.eval(loss, du, dv)
    # Plain gradient-descent step on both dense parameters.
    u = u - lr * du
    v = v - lr * dv
    if step % 20 == 0:
        print(f"step {step:3d}  loss={float(loss):.4f}")
step   0  loss=28.3741
step  20  loss=10.5234
step  40  loss=4.0117
step  60  loss=1.5623
step  80  loss=0.6189
step 100  loss=0.2481