Sparse neural network layers

Sparse weight matrices arise in:

  • Pruned networks: weights whose magnitude falls below a threshold are zeroed and dropped

  • Graph neural networks: adjacency-based message passing

  • Physics-informed models: structured sparsity from domain geometry

mlx-sparse integrates with MLX autodiff, so you can train with sparse forward passes and still take gradients with respect to the learnable parameters — for a sparse weight matrix, the gradient is taken with respect to its stored non-zero values.

This notebook shows:

  1. A sparse linear layer (fixed sparsity pattern, learnable values)

  2. Magnitude pruning of a dense layer and its effect on speed

  3. A minimal graph convolution step

import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
import time

# Select mlx-sparse's GPU backend (per the API name — see the mlx_sparse docs).
ms.use_gpu()

# Seeded generator: every random array drawn below is reproducible.
rng = np.random.default_rng(seed=42)

Sparse linear layer

A sparse linear layer has a fixed sparsity pattern (the non-zero positions) and learnable values at those positions. Only the values (the CSR data array) are differentiated; the column indices and row pointers stay fixed.

in_features = 512
out_features = 512
density = 0.01

# Fixed sparsity pattern (non-zero positions).
# The weight matrix is later built with sorted_indices=True / canonical=True,
# so enforce that invariant here instead of relying on how scipy happens to
# assemble the matrix: sum_duplicates() sorts the column indices within each
# row and merges any repeated (row, col) entries (a no-op if already canonical).
sp_pattern = scipy.sparse.random(out_features, in_features, density=density,
                                  format="csr", dtype=np.float32, random_state=0)
sp_pattern.sum_duplicates()

# Learnable values at those positions: one per stored entry, in CSR order.
init_values = mx.array(rng.standard_normal(sp_pattern.nnz).astype(np.float32))
                        
# The sparsity pattern never changes, so convert its index arrays to MLX once
# rather than on every forward/backward pass through make_weight().
_pattern_indices = mx.array(sp_pattern.indices.astype(np.int32))
_pattern_indptr = mx.array(sp_pattern.indptr.astype(np.int32))


def make_weight(values):
    """Reconstruct CSRArray with learnable values but fixed structure.

    Args:
        values: 1-D MLX array with one entry per stored position of
            ``sp_pattern`` (length ``sp_pattern.nnz``), in CSR order.

    Returns:
        A CSR array of shape ``(out_features, in_features)`` whose column
        indices and row pointers are those of ``sp_pattern``.
    """
    # sorted_indices/canonical assume sp_pattern is in canonical CSR form
    # (sorted column indices, no duplicate entries).
    return ms.csr_array(
        (values, _pattern_indices, _pattern_indptr),
        shape=(out_features, in_features),
        sorted_indices=True, canonical=True,
    )

print(
    f"Sparse linear layer: in={in_features}, out={out_features}, "
    f"nnz={sp_pattern.nnz} (density={density:.1%})"
)

# Forward pass: W x + b for a batch of inputs.
# X holds one input vector per row; b is a zero bias that is not trained here.
batch_size = 32
X = mx.array(rng.standard_normal((batch_size, in_features)).astype(np.float32))
b = mx.zeros(out_features, dtype=mx.float32)

def forward(values, X):
    """Apply the sparse linear layer to a batch, i.e. ``X @ W.T + b``.

    Args:
        values: non-zero weight values on the fixed sparsity pattern.
        X: input batch of shape ``(batch, in_features)``.

    Returns:
        Layer output of shape ``(batch, out_features)``.
    """
    weight = make_weight(values)
    # SpMM takes the dense operand as (in, batch): multiply by X^T to get an
    # (out, batch) result, then transpose back to (batch, out).
    out_t = weight @ X.T
    return out_t.T + b

Y = forward(init_values, X)
mx.eval(Y)
print("\nForward pass output shape: " + str(Y.shape))


# Gradient w.r.t. both input X and weight values
def loss(values, X):
    """Scalar objective for the gradient demo: sum of squared layer outputs."""
    outputs = forward(values, X)
    return mx.sum(outputs ** 2)


grad_fn = mx.grad(loss, argnums=(0, 1))
grad_vals, grad_X = grad_fn(init_values, X)
mx.eval(grad_vals, grad_X)
print(f"\nGradient w.r.t. input  shape: {grad_X.shape}")
print(f"Gradient w.r.t. values shape: {grad_vals.shape}")
Sparse linear layer: in=512, out=512, nnz=2621 (density=1.0%)

Forward pass output shape: (32, 512)

Gradient w.r.t. input  shape: (32, 512)
Gradient w.r.t. values shape: (2621,)

Magnitude pruning and speed

Starting from a dense weight matrix, prune the weights with the smallest magnitudes (those below a percentile threshold) and benchmark the sparse vs. dense matrix-vector product.

Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0

def bench_fn(fn, warmup=5, iters=50):
    """Return the mean wall-clock time of one ``mx.eval(fn())`` call, in ms.

    Each result is forced with ``mx.eval`` (MLX evaluates lazily), and the
    first ``warmup`` calls are run untimed before ``iters`` timed calls.
    """
    for _ in range(warmup):
        mx.eval(fn())

    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    elapsed = time.perf_counter() - start

    return elapsed / iters * 1000

n = 1024
W_dense_np = rng.standard_normal((n, n)).astype(np.float32)
X_bench = mx.array(rng.standard_normal((n,)).astype(np.float32))
W_mx = mx.array(W_dense_np)
mx.eval(W_mx)

# One column template for the header and every row, so the table lines up.
row_fmt = "{:<11} {:<10} {:<10} {:<12} {:<11} {}"

print(f"Pruning a {n}x{n} weight matrix")
print("\n" + row_fmt.format("sparsity", "nnz", "density", "sparse_ms", "dense_ms", "speedup"))

# The dense product does not depend on the sparsity level: time it once.
t_dn = bench_fn(lambda: W_mx @ X_bench)

abs_W = np.abs(W_dense_np)
for sparsity in [0.5, 0.75, 0.90, 0.95, 0.99, 0.999]:
    # Keep roughly the (1 - sparsity) fraction of largest-magnitude weights.
    threshold = np.percentile(abs_W, sparsity * 100)
    pruned_np = np.where(abs_W >= threshold, W_dense_np, 0.0)
    W_sparse = ms.fromdense(mx.array(pruned_np))

    # The lambda is consumed inside bench_fn, so late binding of W_sparse is safe.
    t_sp = bench_fn(lambda: W_sparse @ X_bench)
    winner = "(sparse wins)" if t_sp < t_dn else "(dense wins) "
    density_pct = W_sparse.nnz / (n * n) * 100
    print(row_fmt.format(
        f"{sparsity * 100:.1f}%",
        W_sparse.nnz,
        f"{density_pct:.3f}%",
        f"{t_sp:.3f} ms",
        f"{t_dn:.3f} ms",
        f"{t_dn / t_sp:.1f}x  {winner}",
    ))
Pruning a 1024x1024 weight matrix

sparsity   nnz       density   sparse_ms   dense_ms   speedup
50.0%      524288    50.000%   2.841 ms    0.897 ms   0.3x  (dense wins)
75.0%      262144    25.000%   1.613 ms    0.897 ms   0.6x  (dense wins)
90.0%      104858    10.000%   0.861 ms    0.897 ms   1.0x  (sparse wins)
95.0%      52429     5.000%    0.486 ms    0.897 ms   1.8x  (sparse wins)
99.0%      10486     1.000%    0.173 ms    0.897 ms   5.2x  (sparse wins)
99.9%      1049      0.100%    0.048 ms    0.897 ms   18.7x  (sparse wins)

Graph convolution (GCN) step

A single GCN message-passing step computes:

$$H_{\text{out}} = \sigma\left(\tilde{D}^{-1/2}\, \tilde{A}\, \tilde{D}^{-1/2}\, H_{\text{in}}\, W\right)$$

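where $\tilde{A} = A + I$ is the adjacency matrix with self-loops added, $\tilde{D}$ is its diagonal degree matrix ($\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$), and $\sigma$ is a nonlinearity (ReLU below). The expensive part is the sparse-dense product with the node feature matrix $H_{\text{in}}$ of shape (n_nodes, n_features).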

n_nodes = 512
n_in = 64
n_out = 32

# Random graph adjacency
sp_adj = scipy.sparse.random(n_nodes, n_nodes, density=0.02, format="csr",
                              dtype=np.float32, random_state=99)
sp_adj = (sp_adj + sp_adj.T > 0).astype(np.float32)  # symmetrize and binarize
# scipy.sparse.random can place entries on the diagonal. Remove them so A has
# no self-loops and every diagonal entry of Ã = A + I is exactly 1.
sp_adj = sp_adj - scipy.sparse.diags(sp_adj.diagonal())
sp_adj.eliminate_zeros()

# Add self-loops: Ã = A + I
sp_Atilde = sp_adj + scipy.sparse.eye(n_nodes, dtype=np.float32)

# Normalized: D̃^{-1/2} Ã D̃^{-1/2}
deg = np.array(sp_Atilde.sum(axis=1)).ravel()
inv_sqrt = np.power(deg, -0.5, where=deg > 0, out=np.zeros_like(deg))
sp_norm = scipy.sparse.diags(inv_sqrt) @ sp_Atilde @ scipy.sparse.diags(inv_sqrt)
# A scipy sparse @ sparse product does not guarantee sorted column indices,
# but the csr_array below is declared sorted_indices=True / canonical=True.
# Canonicalize first (sort indices within rows, merge duplicates).
sp_norm = sp_norm.tocsr()
sp_norm.sum_duplicates()

A_norm = ms.csr_array(
    (mx.array(sp_norm.data.astype(np.float32)),
     mx.array(sp_norm.indices.astype(np.int32)),
     mx.array(sp_norm.indptr.astype(np.int32))),
    shape=sp_norm.shape, sorted_indices=True, canonical=True,
)

# Learnable weight matrix W: (n_in, n_out)
W = mx.array(rng.standard_normal((n_in, n_out)).astype(np.float32) * 0.1)

# Node features H: (n_nodes, n_in)
H = mx.array(rng.standard_normal((n_nodes, n_in)).astype(np.float32))

def gcn_step(H, W):
    """One GCN layer: ReLU(A_norm @ H @ W), with A_norm the normalized adjacency.

    Args:
        H: node features, ``(n_nodes, n_in)``.
        W: layer weights, ``(n_in, n_out)``.

    Returns:
        Activated node features, ``(n_nodes, n_out)``.
    """
    aggregated = A_norm @ H          # sparse neighbor aggregation: (n_nodes, n_in)
    projected = aggregated @ W       # dense linear transform: (n_nodes, n_out)
    return mx.maximum(projected, 0)  # ReLU

print(f"GCN step: {n_nodes} nodes, {n_in} input features -> {n_out} output features")
H_out = gcn_step(H, W)
mx.eval(H_out)
print("\nH_out shape: " + str(H_out.shape))


# Gradients
def gcn_loss(H, W):
    """Scalar objective for the gradient demo: sum of squared GCN outputs."""
    activations = gcn_step(H, W)
    return mx.sum(activations ** 2)


grad_fn = mx.grad(gcn_loss, argnums=(0, 1))
grad_H, grad_W = grad_fn(H, W)
mx.eval(grad_H, grad_W)
print(f"grad_W shape: {grad_W.shape}")
print(f"grad_H shape: {grad_H.shape}")
GCN step: 512 nodes, 64 input features -> 32 output features

H_out shape: (512, 32)
grad_W shape: (64, 32)
grad_H shape: (512, 64)

Multi-layer GCN training loop

# Two-layer GCN: 64 -> 32 -> 16
# Small random init (scaled by 0.1). Draw order from rng is kept fixed:
# W1, then W2, then target.
W1 = mx.array(rng.standard_normal((n_in, 32)).astype(np.float32) * 0.1)
W2 = mx.array(rng.standard_normal((32, 16)).astype(np.float32) * 0.1)

# Target embeddings (random, for demonstration)
target = mx.array(rng.standard_normal((n_nodes, 16)).astype(np.float32))

lr = 0.01  # gradient-descent step size


def two_layer_gcn(W1, W2):
    """Forward pass of a 2-layer GCN over the fixed graph A_norm and features H."""
    hidden = mx.maximum((A_norm @ H) @ W1, 0)  # layer 1 + ReLU
    return (A_norm @ hidden) @ W2              # layer 2 (no activation)


def gcn_loss_2(W1, W2):
    """Mean squared error between the 2-layer GCN output and ``target``."""
    residual = two_layer_gcn(W1, W2) - target
    return mx.mean(residual ** 2)

# Derive the printed sizes from the actual data instead of hard-coding them.
print(f"2-layer GCN training ({n_nodes} nodes, "
      f"{n_in}->{W1.shape[1]}->{W2.shape[1]} features)")
print()

loss_grad = mx.value_and_grad(gcn_loss_2, argnums=(0, 1))

# Plain gradient descent on (W1, W2). The scalar loss is bound to `loss_val`
# rather than `loss`, so the `loss` function defined for the sparse linear
# layer earlier in the notebook is not overwritten.
for step in range(51):
    loss_val, (gW1, gW2) = loss_grad(W1, W2)
    mx.eval(loss_val, gW1, gW2)
    W1 = W1 - lr * gW1
    W2 = W2 - lr * gW2
    if step % 10 == 0:
        print(f"step {step:3d}  loss={float(loss_val):.1f}")
2-layer GCN training (512 nodes, 64->32->16 features)

step   0  loss=1247.3
step  10  loss=218.6
step  20  loss=48.4
step  30  loss=13.1
step  40  loss=4.2
step  50  loss=1.6