Device selection and execution model#
mlx-sparse builds on MLX’s device abstraction. MLX arrays live in unified memory, and every operation runs on a stream that belongs to a device. Operations that create new arrays run on the current default device unless explicitly overridden with a stream argument.
Selecting the default device#
import mlx_sparse as ms
ms.use_gpu() # Apple Silicon GPU via Metal
ms.use_cpu() # CPU
ms.use_device("gpu") # same as use_gpu(); useful with argparse
These functions call mx.set_default_device and optionally probe the device
with a trivial evaluation to confirm it is available. The selected device
persists for the lifetime of the Python process or until changed by another
call.
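Because use_device takes the device name as a string, it pairs naturally with a command-line flag. The sketch below shows one way to wire that up. The --device flag and the build_parser and select_device helpers are hypothetical names chosen for illustration, and the sketch assumes use_device accepts "cpu" as well as "gpu".

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical CLI with a --device flag restricted to known names."""
    parser = argparse.ArgumentParser(description="mlx-sparse device selection demo")
    parser.add_argument(
        "--device",
        choices=("cpu", "gpu"),
        default="gpu",
        help="device used for new operations",
    )
    return parser


def select_device(name: str) -> bool:
    """Forward the name to ms.use_device; return False if mlx-sparse is missing."""
    try:
        import mlx_sparse as ms
    except ImportError:
        return False
    ms.use_device(name)  # assumption: "cpu" is accepted as well as "gpu"
    return True


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--device", "cpu"])
select_device(args.device)
```

Restricting the flag with choices means a misspelled device name is rejected by argparse before any MLX call is made.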
Note
Calling use_gpu() or use_cpu() after operations have already been
dispatched does not retroactively move prior work. It only affects new
operations.
Lazy execution#
MLX uses a deferred execution model. Operations like A @ x or A.todense()
do not compute anything immediately. They add nodes to a computation graph.
Computation runs when mx.eval() is called explicitly, or implicitly when a
value is read (for example via numpy.array(y) or print(y)).
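The deferred model can be pictured with a toy graph in plain Python. This is only an illustration of the idea, not how MLX is implemented; the Lazy class and the as_lazy and evaluate helpers are hypothetical names.

```python
class Lazy:
    """Toy deferred value: records an operation instead of running it."""

    def __init__(self, fn, *inputs):
        self.fn = fn            # how to compute this node
        self.inputs = inputs    # upstream nodes it depends on
        self.value = None
        self.evaluated = False

    def __add__(self, other):
        return Lazy(lambda a, b: a + b, self, as_lazy(other))

    def __mul__(self, other):
        return Lazy(lambda a, b: a * b, self, as_lazy(other))


def as_lazy(x):
    """Wrap a plain number as a leaf node; pass Lazy nodes through."""
    return x if isinstance(x, Lazy) else Lazy(lambda: x)


def evaluate(node):
    """Compute a node and its inputs on demand, caching each result."""
    if not node.evaluated:
        args = [evaluate(i) for i in node.inputs]
        node.value = node.fn(*args)
        node.evaluated = True
    return node.value


x = as_lazy(3.0)
y = x * 2.0             # graph node only, nothing computed yet
z = y + 1.0             # another graph node
pending = z.evaluated   # False: no arithmetic has run so far
result = evaluate(z)    # the whole chain runs here
```

Building y and z costs almost nothing; the arithmetic happens only inside evaluate, which plays the role that mx.eval() plays in MLX.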
mlx-sparse follows this model for fixed-output numerical kernels:
Fixed-output operations are lazy. csr_matvec, csr_matmul, todense, T, H, transpose products, and autodiff primitives add nodes to the MLX graph and do not materialize values immediately.
Dynamic-output structural operations must discover output sizes. fromdense(), sum_duplicates()/canonicalize(), and sparse-sparse matmat run native counting or symbolic work first, then synchronize compact counts or structure so final sparse buffers can be allocated.
Full validation (validate="full") also reads values. It must inspect indptr and indices to check bounds, so it calls mx.eval on those arrays. Keep this in mind when constructing from device arrays.
to_numpy (used internally by fallback operations and full validation) always calls mx.eval.
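To see why dynamic-output operations have to synchronize, consider a dense-to-CSR conversion written as count, prefix sum, and fill stages. The following is a plain-Python sketch of that general pattern, not the library's native kernels; toy_fromdense is a hypothetical name.

```python
def toy_fromdense(dense):
    """Convert a list-of-lists dense matrix to CSR buffers in three stages."""
    n_rows = len(dense)

    # Stage 1 (count): number of stored entries in each row.
    counts = [sum(1 for v in row if v != 0) for row in dense]

    # Stage 2 (prefix): running sum of the counts gives the row pointers.
    indptr = [0] * (n_rows + 1)
    for i, c in enumerate(counts):
        indptr[i + 1] = indptr[i] + c

    # The total must be known on the host before buffers can be allocated.
    nnz = indptr[-1]
    indices = [0] * nnz
    data = [0.0] * nnz

    # Stage 3 (fill): write each nonzero at its row's next free slot.
    for i, row in enumerate(dense):
        k = indptr[i]
        for j, v in enumerate(row):
            if v != 0:
                indices[k] = j
                data[k] = v
                k += 1
    return data, indices, indptr


dense = [
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 3.0],
    [4.0, 5.0, 0.0],
]
data, indices, indptr = toy_fromdense(dense)
```

The allocation between stage 2 and stage 3 depends on nnz, a value that only exists after the counting work has actually run. That dependency is what prevents this kind of operation from being fully lazy.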
A graph composition example:
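import mlx.core as mx
import mlx_sparse as ms

# A is a sparse array and x a dense MLX vector, both constructed elsewhere.
ms.use_gpu()
y = A @ x # lazy: one graph node
z = mx.sin(y) + 2.0 # lazy: two more graph nodes
mx.eval(z) # GPU runs here; only one dispatch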
This means you can build multi-step computations before triggering any GPU work, letting MLX fuse and optimize the graph.
Which operations run on GPU#
| Operation | CPU | Metal GPU |
|---|---|---|
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes |
| | Yes | Yes (staged count/prefix/fill, synchronizes row counts) |
| | Yes | Yes (staged count/prefix/fill, synchronizes row counts) |
| Sparse-sparse | Yes | Experimental via |
| Sparse-sparse | Yes | Not yet. Native host symbolic/numeric paths are used. |
| Batched sparse-dense products for COO, CSR, and CSC | Yes | Yes |
| Autodiff (JVP / VJP, sparse values and dense RHS) | Yes | Yes |
When a GPU primitive encounters an unsupported configuration, it raises a
RuntimeError with a clear message. Some public operations intentionally
lower to other native primitives on GPU. For example, some non-float32 CSR
transpose products use csr_transpose followed by the ordinary product
rather than a direct complex or low-precision atomic scatter kernel. COO and
CSC scatter products keep native GPU coverage: float32 uses atomic
scatter-add, while other value dtypes use a native serial scatter where Metal
lacks compatible atomic adds.
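The two strategies for a CSR transpose product can be compared in plain Python. This sketch only illustrates the arithmetic; the toy_* functions are hypothetical stand-ins, not the library's kernels. The scatter version accumulates into shared output slots, which is where a parallel kernel needs atomic adds. The lowered version transposes first and then runs an ordinary gather-style product.

```python
def toy_transpose_matvec_scatter(data, indices, indptr, n_cols, x):
    """y = A^T x by scattering each stored entry into y[column]."""
    y = [0.0] * n_cols
    for i in range(len(indptr) - 1):
        for k in range(indptr[i], indptr[i + 1]):
            # Several rows may hit the same column, so a parallel
            # implementation must make this accumulation atomic.
            y[indices[k]] += data[k] * x[i]
    return y


def toy_csr_transpose(data, indices, indptr, n_cols):
    """Return CSR buffers of A^T using count, prefix sum, and fill."""
    counts = [0] * n_cols
    for j in indices:
        counts[j] += 1
    t_indptr = [0] * (n_cols + 1)
    for j in range(n_cols):
        t_indptr[j + 1] = t_indptr[j] + counts[j]
    cursor = t_indptr[:-1]  # next free slot per output row (slice is a copy)
    t_indices = [0] * len(indices)
    t_data = [0.0] * len(data)
    for i in range(len(indptr) - 1):
        for k in range(indptr[i], indptr[i + 1]):
            dst = cursor[indices[k]]
            t_indices[dst] = i
            t_data[dst] = data[k]
            cursor[indices[k]] += 1
    return t_data, t_indices, t_indptr


def toy_csr_matvec(data, indices, indptr, x):
    """Ordinary y = A x: each output row is an independent gather."""
    return [
        sum(data[k] * x[indices[k]] for k in range(indptr[i], indptr[i + 1]))
        for i in range(len(indptr) - 1)
    ]


# A = [[1, 0, 2], [0, 0, 3], [4, 5, 0]] in CSR form
data = [1.0, 2.0, 3.0, 4.0, 5.0]
indices = [0, 2, 2, 0, 1]
indptr = [0, 2, 3, 5]
x = [1.0, 2.0, 3.0]

y_scatter = toy_transpose_matvec_scatter(data, indices, indptr, 3, x)
t_data, t_indices, t_indptr = toy_csr_transpose(data, indices, indptr, 3)
y_lowered = toy_csr_matvec(t_data, t_indices, t_indptr, x)
```

Both paths produce the same vector. The lowered path pays for an explicit transpose but needs no atomic accumulation, which is the trade-off described above for dtypes without compatible atomic adds.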
Typical workflow: construct on CPU, multiply on GPU#
The most common pattern for large-scale workloads is:
Assemble or canonicalize sparse structure once. Native staged constructors can run on CPU or GPU, but they may synchronize counts to allocate compact output buffers.
Keep the resulting CSR buffers and dense RHS arrays on the target device.
Run repeated COO/CSR/CSC matvec, matmul, and batched products on GPU.
import mlx.core as mx
import numpy as np
import mlx_sparse as ms
# Assembly phase: build and canonicalize once
# (data, row, col are 1-D arrays of equal length; m, n are the matrix dimensions)
ms.use_cpu()
coo = ms.coo_array((data, (row, col)), shape=(m, n))
csr = coo.tocsr(canonical=True)
mx.eval(csr.data, csr.indices, csr.indptr) # materialise buffers
# Compute phase: multiply on GPU
ms.use_gpu()
# Re-wrap the same (already evaluated) buffers in a new csr_array.
# No data is copied: MLX arrays are device-agnostic.
csr_gpu = ms.csr_array(
    (csr.data, csr.indices, csr.indptr),
    shape=csr.shape,
    sorted_indices=csr.sorted_indices,
    canonical=csr.has_canonical_format,
    validate=False, # buffers already validated
)
x = mx.array(np.random.randn(n).astype(np.float32))
y = csr_gpu @ x # dispatches Metal kernel
mx.eval(y)
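The canonicalization step in the assembly phase is worth doing once because it changes the structure itself. As a rough plain-Python picture, assuming canonical form carries the SciPy meaning (sorted coordinates with duplicates summed), it looks like this. toy_canonicalize_coo is a hypothetical name and not part of mlx-sparse.

```python
def toy_canonicalize_coo(row, col, data):
    """Sum duplicate (row, col) entries and return them in sorted order."""
    merged = {}
    for r, c, v in zip(row, col, data):
        merged[(r, c)] = merged.get((r, c), 0.0) + v
    keys = sorted(merged)  # row-major order: by row, then by column
    out_row = [r for r, _ in keys]
    out_col = [c for _, c in keys]
    out_data = [merged[k] for k in keys]
    return out_row, out_col, out_data


# Entry (1, 2) appears twice and is merged into a single value.
row = [1, 0, 1, 0]
col = [2, 1, 2, 0]
data = [1.0, 2.0, 3.0, 4.0]
c_row, c_col, c_data = toy_canonicalize_coo(row, col, data)
```

The output length is not known until the duplicates have been merged, which is why this step belongs in the one-time assembly phase rather than in the repeated compute loop.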
Stream safety#
All native primitives pass MLX’s StreamOrDevice parameter through to the
underlying operation wrappers and C++ primitive constructors. When the default
stream is used, MLX handles command sequencing automatically. Do not call
mx.synchronize() or perform your own Metal synchronization inside a sparse
operation; doing so will deadlock with MLX’s command encoder.