Sparse neural network layers#
Sparse weight matrices arise in:
Pruned networks: weights below a threshold are zeroed and dropped
Graph neural networks: adjacency-based message passing
Physics-informed models: structured sparsity from domain geometry
mlx-sparse integrates cleanly with MLX autodiff, making it possible to train with sparse forward passes while computing dense gradients w.r.t. learnable parameters.
This notebook shows:
A sparse linear layer (fixed sparsity pattern, learnable values)
Magnitude pruning of a dense layer and its effect on speed
A minimal graph convolution step
import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
import time
# Run mlx-sparse on the GPU (presumably what use_gpu() selects — confirm against
# the mlx-sparse docs) and create the single seeded NumPy generator that every
# later cell draws from, so results are reproducible when cells run in order.
ms.use_gpu()
rng = np.random.default_rng(seed=42)
Sparse linear layer#
A sparse linear layer has a fixed sparsity pattern (the non-zero positions)
and learnable values at those positions. Of the three CSR arrays, only the values (`data`) are differentiated; the index arrays (`indices`, `indptr`) stay constant.
# Layer geometry: a square weight matrix with 1% of its entries stored.
in_features = out_features = 512
density = 0.01

# Fixed sparsity pattern: the stored positions are drawn once (with their own
# seed, independent of `rng`) and are reused unchanged from here on.
sp_pattern = scipy.sparse.random(
    out_features,
    in_features,
    density=density,
    format="csr",
    dtype=np.float32,
    random_state=0,
)

# Learnable values: one float32 per stored position, drawn from N(0, 1).
init_values = mx.array(
    rng.standard_normal(sp_pattern.nnz).astype(np.float32)
)
# Index arrays of the fixed sparsity pattern. They never change, so they are
# converted to MLX arrays once here instead of on every make_weight() call
# (make_weight runs each time forward() is evaluated or differentiated).
_pattern_indices = mx.array(sp_pattern.indices.astype(np.int32))
_pattern_indptr = mx.array(sp_pattern.indptr.astype(np.int32))


def make_weight(values):
    """Build the sparse weight matrix from its learnable values.

    Args:
        values: 1-D MLX array holding one entry per stored position of
            ``sp_pattern`` (``sp_pattern.nnz`` entries), in CSR order.

    Returns:
        An ``ms.csr_array`` of shape ``(out_features, in_features)`` that
        combines ``values`` with the fixed column indices and row pointers
        of ``sp_pattern``.
    """
    return ms.csr_array(
        (values, _pattern_indices, _pattern_indptr),
        shape=(out_features, in_features),
        # NOTE(review): these flags are asserted, not checked here. They rely
        # on sp_pattern (from scipy.sparse.random(format="csr")) already having
        # sorted, duplicate-free indices.
        sorted_indices=True,
        canonical=True,
    )
print(
    f"Sparse linear layer: in={in_features}, out={out_features}, "
    f"nnz={sp_pattern.nnz} (density={density*100:.1f}%)"
)

# Inputs for the forward pass W x + b: a batch of random rows plus a zero bias.
batch_size = 32
X = mx.array(
    rng.standard_normal((batch_size, in_features)).astype(np.float32)
)
b = mx.zeros((out_features,), dtype=mx.float32)
def forward(values, X):
    """Apply the sparse linear layer to a batch of inputs.

    Args:
        values: learnable non-zero values, passed on to ``make_weight``.
        X: input batch laid out as (batch, in_features).

    Returns:
        The layer output laid out as (batch, out_features), including the
        module-level bias ``b``.
    """
    weight = make_weight(values)
    # The sparse product takes the batch as columns, i.e. (in, batch), and
    # yields (out, batch); transpose back so each row is one sample again.
    return (weight @ X.T).T + b


Y = forward(init_values, X)
mx.eval(Y)  # materialize the result before reporting on it
print(f"\nForward pass output shape: {Y.shape}")
def loss(values, X):
    """Scalar objective: the sum of squared outputs of ``forward``."""
    outputs = forward(values, X)
    return mx.sum(outputs ** 2)


# Differentiate with respect to both positional arguments: the sparse weight
# values (argument 0) and the dense input batch (argument 1).
grad_fn = mx.grad(loss, argnums=(0, 1))
grad_vals, grad_X = grad_fn(init_values, X)
mx.eval(grad_vals, grad_X)

print(f"\nGradient w.r.t. input shape: {grad_X.shape}")
print(f"Gradient w.r.t. values shape: {grad_vals.shape}")
Sparse linear layer: in=512, out=512, nnz=2621 (density=1.0%)
Forward pass output shape: (32, 512)
Gradient w.r.t. input shape: (32, 512)
Gradient w.r.t. values shape: (2621,)
Magnitude pruning and speed#
Starting from a dense weight matrix, prune weights below a threshold and benchmark sparse vs dense forward pass.
Environment: Apple M5, 10-core GPU, macOS 26.0, MLX 0.31, mlx-sparse 0.0.1b0
def bench_fn(fn, warmup=5, iters=50):
    """Return the mean wall-clock time of one ``fn()`` call, in milliseconds.

    Args:
        fn: zero-argument callable whose result is handed to ``mx.eval``.
        warmup: number of untimed calls made before measuring.
        iters: number of timed calls the mean is taken over.
    """
    # Untimed warm-up calls.
    for _ in range(warmup):
        mx.eval(fn())

    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    elapsed = time.perf_counter() - start

    # Seconds per call, converted to milliseconds.
    return elapsed / iters * 1000
# Dense n x n reference weights (NumPy copy kept for thresholding) and the
# length-n vector both variants are multiplied with.
n = 1024
W_dense_np = rng.standard_normal(size=(n, n)).astype(np.float32)
X_bench = mx.array(rng.standard_normal(size=(n,)).astype(np.float32))

# Dense MLX copy of the weights, materialized before any timing starts.
W_mx = mx.array(W_dense_np)
mx.eval(W_mx)

print(f"Pruning a {n}x{n} weight matrix")
print(f"\n{'sparsity':<11} {'nnz':<10} {'density':<10} {'sparse_ms':<12} {'dense_ms':<11} {'speedup'}")
# Neither of these depends on the sparsity level, so they are computed once
# instead of inside the loop: the magnitudes used for thresholding and the
# dense baseline timing (one shared baseline also keeps the dense_ms column
# identical across rows).
abs_weights = np.abs(W_dense_np)
t_dn = bench_fn(lambda: W_mx @ X_bench)

for sparsity in [0.5, 0.75, 0.90, 0.95, 0.99, 0.999]:
    # Magnitude pruning: keep only entries at or above the `sparsity` quantile
    # of |W|, zero the rest, then convert the result to a sparse matrix.
    threshold = np.percentile(abs_weights, sparsity * 100)
    pruned_np = np.where(abs_weights >= threshold, W_dense_np, 0.0)
    W_sparse = ms.fromdense(mx.array(pruned_np))

    # The lambda is consumed within this iteration, so capturing the loop
    # variable W_sparse is safe.
    t_sp = bench_fn(lambda: W_sparse @ X_bench)

    winner = "(sparse wins)" if t_sp < t_dn else "(dense wins) "
    density_pct = W_sparse.nnz / (n * n) * 100
    print(f"{sparsity*100:.1f}% {W_sparse.nnz:<10} {density_pct:.3f}% "
          f"{t_sp:.3f} ms {t_dn:.3f} ms {t_dn/t_sp:.1f}x {winner}")
Pruning a 1024x1024 weight matrix
sparsity nnz density sparse_ms dense_ms speedup
50.0% 524288 50.000% 2.841 ms 0.897 ms 0.3x (dense wins)
75.0% 262144 25.000% 1.613 ms 0.897 ms 0.6x (dense wins)
90.0% 104858 10.000% 0.861 ms 0.897 ms 1.0x (sparse wins)
95.0% 52429 5.000% 0.486 ms 0.897 ms 1.8x (sparse wins)
99.0% 10486 1.000% 0.173 ms 0.897 ms 5.2x (sparse wins)
99.9% 1049 0.100% 0.048 ms 0.897 ms 18.7x (sparse wins)
Graph convolution (GCN) step#
A single GCN message-passing step computes:
H_out = σ(D̃^{-1/2} Ã D̃^{-1/2} H_in W)
where Ã = A + I is the adjacency matrix with self-loops and D̃ is its degree
matrix. The expensive part is the sparse-dense product with the node feature
matrix H_in of shape (n_nodes, n_features).
# Graph size and feature widths for the GCN example.
n_nodes = 512
n_in = 64
n_out = 32

# Random graph adjacency: draw a sparse random matrix, then symmetrize and
# binarize it so every stored edge is undirected with weight 1.
sp_adj = scipy.sparse.random(n_nodes, n_nodes, density=0.02, format="csr",
                             dtype=np.float32, random_state=99)
sp_adj = ((sp_adj + sp_adj.T) > 0).astype(np.float32)

# Add self-loops: Ã = A + I.
# scipy.sparse.random may already have placed entries on the diagonal; a plain
# A + I would give those nodes a self-loop weight of 2. Binarizing after the
# addition keeps every entry of Ã, diagonal included, at exactly 1.
sp_Atilde = sp_adj + scipy.sparse.eye(n_nodes, dtype=np.float32, format="csr")
sp_Atilde = (sp_Atilde > 0).astype(np.float32)
# Normalized adjacency: D̃^{-1/2} Ã D̃^{-1/2}
deg = np.array(sp_Atilde.sum(axis=1)).ravel()
# Rows with zero degree get a scale of 0 rather than a division by zero.
inv_sqrt = np.power(deg, -0.5, where=deg > 0, out=np.zeros_like(deg))
D_inv_sqrt = scipy.sparse.diags(inv_sqrt)  # built once, applied on both sides
sp_norm = (D_inv_sqrt @ sp_Atilde @ D_inv_sqrt).tocsr()

# SciPy does not guarantee that the result of a sparse matrix product has
# sorted column indices. Canonicalize (sort indices, merge duplicates) so the
# sorted_indices/canonical flags passed below are actually true.
sp_norm.sum_duplicates()

A_norm = ms.csr_array(
    (mx.array(sp_norm.data.astype(np.float32)),
     mx.array(sp_norm.indices.astype(np.int32)),
     mx.array(sp_norm.indptr.astype(np.int32))),
    shape=sp_norm.shape,
    sorted_indices=True,
    canonical=True,
)
# Dense learnable weights of shape (n_in, n_out), scaled down to 0.1 * N(0, 1).
W = mx.array(
    rng.standard_normal((n_in, n_out)).astype(np.float32) * 0.1
)

# Input node features of shape (n_nodes, n_in): one row per node.
H = mx.array(
    rng.standard_normal((n_nodes, n_in)).astype(np.float32)
)
def gcn_step(H, W):
    """Run one graph-convolution layer: ``ReLU((A_norm @ H) @ W)``.

    Args:
        H: node features, (n_nodes, n_in).
        W: dense layer weights, (n_in, n_out).

    Returns:
        Updated node features, (n_nodes, n_out).
    """
    aggregated = A_norm @ H          # sparse part: normalized neighbor aggregation
    projected = aggregated @ W       # dense part: linear transform
    return mx.maximum(projected, 0)  # ReLU


print(f"GCN step: {n_nodes} nodes, {n_in} input features -> {n_out} output features")

H_out = gcn_step(H, W)
mx.eval(H_out)
print(f"\nH_out shape: {H_out.shape}")
def gcn_loss(H, W):
    """Scalar objective: the sum of squared outputs of ``gcn_step``."""
    activations = gcn_step(H, W)
    return mx.sum(activations ** 2)


# Gradients with respect to the node features (argument 0) and the dense
# weights (argument 1).
grad_fn = mx.grad(gcn_loss, argnums=(0, 1))
grad_H, grad_W = grad_fn(H, W)
mx.eval(grad_H, grad_W)

print(f"grad_W shape: {grad_W.shape}")
print(f"grad_H shape: {grad_H.shape}")
GCN step: 512 nodes, 64 input features -> 32 output features
H_out shape: (512, 32)
grad_W shape: (64, 32)
grad_H shape: (512, 64)
Multi-layer GCN training loop#
# Parameters of a two-layer GCN mapping 64 -> 32 -> 16 features per node.
W1 = mx.array(
    rng.standard_normal((n_in, 32)).astype(np.float32) * 0.1
)
W2 = mx.array(
    rng.standard_normal((32, 16)).astype(np.float32) * 0.1
)

# Regression target: random embeddings, used for demonstration only.
target = mx.array(
    rng.standard_normal((n_nodes, 16)).astype(np.float32)
)

# Step size for the plain gradient-descent updates.
lr = 0.01
def two_layer_gcn(W1, W2):
    """Forward pass of the two-layer GCN over the fixed features ``H``.

    Layer 1 aggregates, projects with ``W1`` and applies ReLU; layer 2
    aggregates and projects with ``W2`` without an activation.
    """
    hidden = mx.maximum((A_norm @ H) @ W1, 0)
    return (A_norm @ hidden) @ W2


def gcn_loss_2(W1, W2):
    """Mean squared error between the network output and ``target``."""
    residual = two_layer_gcn(W1, W2) - target
    return mx.mean(residual ** 2)
# Report the actual sizes instead of repeating them as literals, so the banner
# stays correct if n_nodes or the weight shapes are changed above.
print(f"2-layer GCN training ({n_nodes} nodes, "
      f"{n_in}->{W1.shape[1]}->{W2.shape[1]} features)")
print()

# value_and_grad returns the loss together with the gradients for W1 and W2.
loss_grad = mx.value_and_grad(gcn_loss_2, argnums=(0, 1))

for step in range(51):
    # Named loss_val (not `loss`) so the loss() function defined for the
    # sparse linear layer earlier in the notebook is not overwritten.
    loss_val, (gW1, gW2) = loss_grad(W1, W2)
    mx.eval(loss_val, gW1, gW2)

    # Plain gradient-descent update.
    W1 = W1 - lr * gW1
    W2 = W2 - lr * gW2

    if step % 10 == 0:
        print(f"step {step:3d} loss={float(loss_val):.1f}")
2-layer GCN training (512 nodes, 64->32->16 features)
step 0 loss=1247.3
step 10 loss=218.6
step 20 loss=48.4
step 30 loss=13.1
step 40 loss=4.2
step 50 loss=1.6