Operations#
These module-level functions wrap the native sparse primitives. In most cases
the @ operator on COOArray, CSRArray, and
CSCArray is the preferred spelling. These functions exist for
explicit dispatch and for callers who prefer a functional style.
Sparse-dense operations return lazy mlx.core.array values. Sparse-sparse
matmat operations return new sparse arrays and may synchronize structure to
the host because their output sparsity pattern is data-dependent.
csr_matvec#
- mlx_sparse.csr_matvec(a, x)
Multiply a CSR sparse matrix by a dense vector.
Computes y = A @ x where A is a CSRArray and x is a rank-1 dense array. The result is added to the MLX computation graph and not evaluated eagerly.
On Apple Silicon, the Metal backend dispatches a scalar row kernel for short rows and a vector-reduction kernel for long rows. CPU and GPU paths support float32, float16, bfloat16, and complex64 values with int32 or int64 indices.
- Parameters:
a (CSRArray) – The sparse matrix, shape (n_rows, n_cols).
x – Dense vector, shape (n_cols,). Converted to mx.array if needed. Must have the same dtype as a.data.
- Returns:
Dense vector of shape (n_rows,) with the same dtype as a.data.
- Raises:
TypeError – If a is not a CSRArray, or if the dtypes of a.data and x do not match.
ValueError – If shape constraints are violated.
- Return type:
mlx.core.array
Example:
import mlx.core as mx
import mlx_sparse as ms
y = a @ x  # preferred via __matmul__
y = ms.csr_matvec(a, x)  # explicit call
mx.eval(y)
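For readers new to the CSR layout, the arithmetic behind y = A @ x can be sketched in pure Python. This is a reference for the semantics only, not the library's kernel. The names csr_matvec_reference, indptr, indices, and data are assumptions for illustration (standard CSR terminology) and are not taken from the mlx_sparse API.

```python
def csr_matvec_reference(indptr, indices, data, x):
    """Pure-Python reference for y = A @ x with A stored in CSR form.

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for row in range(n_rows):
        acc = 0.0
        # Stored entries of this row occupy positions indptr[row]..indptr[row + 1].
        for k in range(indptr[row], indptr[row + 1]):
            acc += data[k] * x[indices[k]]
        y[row] = acc
    return y


# A = [[1, 0, 2],
#      [0, 0, 3]]
indptr = [0, 2, 3]
indices = [0, 2, 2]
data = [1.0, 2.0, 3.0]
y = csr_matvec_reference(indptr, indices, data, [1.0, 1.0, 1.0])
```

Each output row is an independent reduction over that row's stored entries, which is why a kernel can pick a different strategy for short and long rows.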
coo_matvec#
csc_matvec#
csc_matvec_transpose#
csr_matmul#
- mlx_sparse.csr_matmul(a, rhs)
Multiply a CSR sparse matrix by a dense matrix.
Computes Y = A @ B where A is a CSRArray and B is a rank-2 or batched dense array. The result is added to the MLX computation graph and not evaluated eagerly.
On Apple Silicon, the Metal backend dispatches scalar output-element kernels for short rows and vector-reduction kernels for long rows. CPU and GPU paths support float32, float16, bfloat16, and complex64 values with int32 or int64 indices.
- Parameters:
a (CSRArray) – The sparse matrix, shape (n_rows, n_cols).
rhs – Dense matrix, shape (n_cols, k), or batched dense matrix with sparse dimension at rhs.shape[-2]. Converted to mx.array if needed. Must have the same dtype as a.data.
- Returns:
Dense matrix or batched dense matrix with sparse dimension replaced by n_rows and the same dtype as a.data.
- Raises:
TypeError – If a is not a CSRArray, or if dtype constraints are violated.
ValueError – If shape constraints are violated.
- Return type:
mlx.core.array
Example:
import mlx.core as mx
import mlx_sparse as ms
Y = a @ B  # preferred via __matmul__
Y = ms.csr_matmul(a, B)  # explicit call
mx.eval(Y)
coo_matmul#
csc_matmul#
csr_batched_matvec#
- mlx_sparse.csr_batched_matvec(a, rhs)
Multiply a CSR sparse matrix by a batch of dense vectors.
Computes Y[b] = A @ X[b] for X with shape (..., n_cols) and returns shape (..., n_rows). The implementation uses native batched CPU/Metal kernels after flattening any leading batch dimensions.
- Parameters:
a (CSRArray)
- Return type:
mlx.core.array
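The shape rule for the batched product, including the flattening of leading batch dimensions, can be sketched as a small shape calculation. The helper name batched_matvec_shapes is hypothetical, and the ValueError on a mismatched trailing dimension is an assumption modeled on the other functions on this page.

```python
import math


def batched_matvec_shapes(a_shape, x_shape):
    """Return (flattened_input_shape, output_shape) for Y[b] = A @ X[b].

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    n_rows, n_cols = a_shape
    if len(x_shape) < 1 or x_shape[-1] != n_cols:
        # Assumption: a mismatched trailing dimension is a shape error.
        raise ValueError("last dimension of X must equal n_cols")
    batch_dims = tuple(x_shape[:-1])
    # All leading dimensions collapse into one batch axis for the kernel.
    flat_batch = math.prod(batch_dims)
    return (flat_batch, n_cols), batch_dims + (n_rows,)


flat_shape, out_shape = batched_matvec_shapes((4, 3), (2, 5, 3))
```

Here a (2, 5, 3) input against a 4 x 3 matrix is treated as 10 vectors of length 3, and the result is reshaped back to (2, 5, 4).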
coo_batched_matvec#
csc_batched_matvec#
csr_batched_matmul#
- mlx_sparse.csr_batched_matmul(a, rhs)
Multiply a CSR sparse matrix by a batch of dense matrices.
rhs must have shape (..., n_cols, k) and the result has shape (..., n_rows, k). For rank-2 dense matrices, use csr_matmul().
- Parameters:
a (CSRArray)
- Return type:
mlx.core.array
coo_batched_matmul#
csc_batched_matmul#
csr_matmat#
- mlx_sparse.csr_matmat(a, rhs)
Multiply two CSR sparse matrices and return a canonical CSR matrix.
Computes C = A @ B where both A and B are CSRArray instances. The output sparsity pattern is not known at graph-build time, so this operation performs a native C++ structural assembly pass on the host (calling mx.eval on the input arrays internally) and returns a new CSRArray with canonical format.
Because the output size is data-dependent, this operation is not representable as a fixed-shape MLX primitive. It is suitable for one-shot matrix products and matrix-power computations, but is not appropriate inside a JIT-compiled function.
- Parameters:
a (CSRArray)
rhs (CSRArray)
- Returns:
A canonical CSRArray with shape (m, n), has_canonical_format=True, and sorted_indices=True.
- Raises:
ValueError – If the inner dimensions do not match (a.shape[1] != rhs.shape[0]).
- Return type:
CSRArray
Example:
import mlx_sparse as ms
# Compute the square of a sparse matrix
C = A @ A  # dispatches csr_matmat when A is CSRArray
C = ms.csr_matmat(A, A)  # explicit call
# Chain sparse matrix products
D = ms.csr_matmat(ms.csr_matmat(A, B), C)
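To see why the output pattern is data-dependent, a row-wise sparse-sparse product can be sketched in pure Python. This is a textbook row-accumulation scheme written for illustration, not the library's native assembly pass; csr_matmat_reference and the (indptr, indices, data) tuples are assumed names.

```python
def csr_matmat_reference(a, b):
    """Row-wise CSR @ CSR product returning sorted, duplicate-free CSR.

    a and b are (indptr, indices, data) tuples. Hypothetical helper for
    illustration; not part of mlx_sparse.
    """
    a_indptr, a_indices, a_data = a
    b_indptr, b_indices, b_data = b
    out_indptr, out_indices, out_data = [0], [], []
    for i in range(len(a_indptr) - 1):
        acc = {}  # column index -> accumulated value for output row i
        for ka in range(a_indptr[i], a_indptr[i + 1]):
            k = a_indices[ka]
            for kb in range(b_indptr[k], b_indptr[k + 1]):
                j = b_indices[kb]
                acc[j] = acc.get(j, 0.0) + a_data[ka] * b_data[kb]
        # Emit columns in ascending order so each row has sorted indices.
        for j in sorted(acc):
            out_indices.append(j)
            out_data.append(acc[j])
        out_indptr.append(len(out_indices))
    return out_indptr, out_indices, out_data


# A = [[1, 2],
#      [0, 3]]
A = ([0, 2, 3], [0, 1, 1], [1.0, 2.0, 3.0])
C = csr_matmat_reference(A, A)
```

The number of stored entries in each output row is only known after the column sets have been merged, so the output buffers cannot be sized from the input shapes alone.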
coo_matmat#
csc_matmat#
Reductions#
All sparse containers expose reduction methods as well as module-level helper
functions. row_sums / col_sums return the same dtype as the sparse
values, row_norms / col_norms return float32, and diagonal /
trace sum duplicate diagonal entries.
| Format | Functions |
|---|---|
| COO | |
| CSR | |
| CSC | |
COO and CSC reductions are native C++/Metal paths. Norm reductions use dense matrix semantics, so non-canonical COO/CSC inputs are canonicalized before norming to ensure duplicate coordinates are summed before the square is taken.
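The effect of canonicalizing before norming can be shown with a small pure-Python sketch. The helper coo_row_norms and its canonicalize flag are hypothetical and exist only to contrast the two orders of operation; they are not part of the mlx_sparse API.

```python
import math


def coo_row_norms(rows, cols, vals, n_rows, canonicalize=True):
    """Row 2-norms of a COO matrix, optionally summing duplicates first.

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    if canonicalize:
        merged = {}
        for r, c, v in zip(rows, cols, vals):
            # Duplicate coordinates are summed before squaring.
            merged[(r, c)] = merged.get((r, c), 0.0) + v
        entries = [(r, c, v) for (r, c), v in merged.items()]
    else:
        entries = list(zip(rows, cols, vals))
    squares = [0.0] * n_rows
    for r, _c, v in entries:
        squares[r] += v * v
    return [math.sqrt(s) for s in squares]


# Two stored entries at the same coordinate (0, 1): 3 and 4.
rows, cols, vals = [0, 0], [1, 1], [3.0, 4.0]
dense_norm = coo_row_norms(rows, cols, vals, 1)
naive_norm = coo_row_norms(rows, cols, vals, 1, canonicalize=False)
```

With dense matrix semantics the entry at (0, 1) is 7, so the row norm is 7. Squaring the stored values separately would give 5 instead, which is why duplicates are summed first.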
todense#
- mlx_sparse.todense(array)
Materialize a sparse array as a dense MLX array.
Convenience wrapper that calls array.todense() on any sparse container. Duplicate entries are summed, consistent with canonicalize().todense().
- Parameters:
array – A sparse container that provides a todense() method.
- Returns:
Dense array of shape (n_rows, n_cols) with the same dtype as array.data.
- Raises:
TypeError – If array does not have a todense method.
- Return type:
mlx.core.array
Example:
import mlx_sparse as ms
dense = ms.todense(my_csr)
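The duplicate-summing behavior can be illustrated with a pure-Python densification of COO triplets. The function coo_todense_reference is a hypothetical stand-in written for this sketch and says nothing about how the library materializes its output.

```python
def coo_todense_reference(rows, cols, vals, shape):
    """Densify COO triplets, summing entries that share a coordinate.

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for r, c, v in zip(rows, cols, vals):
        dense[r][c] += v  # accumulate rather than overwrite
    return dense


# The coordinate (0, 1) is stored twice, with values 1 and 2.
dense = coo_todense_reference([0, 0, 1], [1, 1, 0], [1.0, 2.0, 5.0], (2, 2))
```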
identity_like#
- mlx_sparse.identity_like(x)
Return a native MLX copy of x.
This function exists as an extension smoke test. It passes x through the native _ext module (if available) and returns an identical MLX array. For production code, prefer mlx.core operations directly.
- Parameters:
x (mlx.core.array) – Any MLX array.
- Returns:
An MLX array with the same shape, dtype, and values as x.
- Return type:
mlx.core.array
is_available#
- mlx_sparse.is_available()
Return True if the native C++ extension is loaded.
The mlx-sparse native extension (_ext) provides MLX-primitive implementations of sparse operations with CPU and Metal backends. When it is absent (e.g. a pure-source checkout without a build step), all operations fall back to NumPy-based Python implementations in mlx_sparse._fallback.
- Returns:
True if mlx_sparse._ext was successfully imported at package load time, False otherwise.
- Return type:
bool
Example:
import mlx_sparse as ms
if not ms.is_available():
    print("Native extension not found. Using Python fallback.")
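A common way to implement this kind of availability check is an import probe. The sketch below shows that general pattern with a hypothetical helper, extension_available; whether mlx_sparse implements is_available() this way is an assumption.

```python
import importlib


def extension_available(module_name):
    """Return True if module_name can be imported, False otherwise.

    Hypothetical helper showing a generic import probe; not the
    mlx_sparse implementation.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


# A standard-library module is importable; a made-up name is not.
has_json = extension_available("json")
has_missing = extension_available("no_such_module_for_this_sketch")
```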
Dispatch summary#
The @ operator on CSRArray dispatches based on the type and rank
of rhs:
C = A @ B # rhs is CSRArray -> csr_matmat(A, B) returns CSRArray
y = A @ x # rhs.ndim == 1 -> csr_matvec(A, x) returns mx.array
Y = A @ X # rhs.ndim == 2 -> csr_matmul(A, X) returns mx.array
Yb = A @ Xb # rhs.ndim > 2 -> csr_matmul(A, Xb) returns mx.array
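The dispatch rules above can be restated as a small decision function. This sketch uses stand-in objects rather than real arrays: SparseStandIn and dispatch_target are hypothetical names, and the handling of rank-0 operands is an assumption because the rules above do not cover it.

```python
from types import SimpleNamespace


class SparseStandIn:
    """Hypothetical stand-in for a CSRArray operand (not the real class)."""


def dispatch_target(rhs):
    """Return the name of the function the documented rules would select."""
    if isinstance(rhs, SparseStandIn):
        return "csr_matmat"  # sparse @ sparse returns a sparse result
    ndim = len(rhs.shape)
    if ndim == 1:
        return "csr_matvec"  # rank-1 dense operand
    if ndim >= 2:
        return "csr_matmul"  # rank-2 and batched dense operands
    # Assumption: rank-0 operands are rejected; the source does not say.
    raise ValueError("rhs must have at least one dimension")


vec = SimpleNamespace(shape=(3,))
mat = SimpleNamespace(shape=(3, 4))
batch = SimpleNamespace(shape=(2, 3, 4))
targets = [dispatch_target(r) for r in (SparseStandIn(), vec, mat, batch)]
```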
The explicit function calls accept the same arguments:
y = ms.csr_matvec(A, x)
Y = ms.csr_matmul(A, X)
yb = ms.csr_batched_matvec(A, xb)
Yb = ms.csr_batched_matmul(A, Xb)
C = ms.csr_matmat(A, B)
For COOArray and CSCArray, dense RHS dispatch mirrors CSR:
rank-1 RHS uses format-native matvec, rank-2 RHS uses format-native matmul,
and higher-rank RHS is flattened into the corresponding native batched
primitive. Same-format sparse-sparse products are also native:
COOArray @ COOArray returns canonical COO via coo_matmat(), and
CSCArray @ CSCArray returns canonical CSC via csc_matmat().
Mixed-format sparse-sparse products remain explicit: convert the operand
yourself when a different storage format is acceptable.
All sparse-dense products validate that rhs.dtype == A.data.dtype. There
is no implicit type promotion. See Dtype policy for the
full dtype matrix.
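The no-promotion rule amounts to a strict equality check on dtypes. A minimal sketch, using dtype names as plain strings and a hypothetical helper check_same_dtype, might look like this; the error message wording is an assumption.

```python
def check_same_dtype(sparse_dtype, rhs_dtype):
    """Raise TypeError unless both dtypes are identical (no promotion).

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    if sparse_dtype != rhs_dtype:
        raise TypeError(
            f"rhs dtype {rhs_dtype} does not match sparse data dtype {sparse_dtype}"
        )


check_same_dtype("float32", "float32")  # passes silently

try:
    check_same_dtype("float32", "float16")
    mismatch_rejected = False
except TypeError:
    mismatch_rejected = True
```

In practice this means casting the dense operand explicitly before the product, for example with rhs.astype(A.data.dtype) on an MLX array.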
Native dispatch notes#
Sparse-dense matvec and matmul for COO, CSR, and CSC are fixed-output primitives and stay lazy in the MLX graph. Explicit batched helpers use native C++/Metal kernels for leading batch dimensions rather than materializing dense matrices in Python.
Transpose products used by autodiff are also native. On Metal, float32
transpose matvec/matmul use atomic scatter-add kernels. Other GPU value dtypes
lower through native csr_transpose followed by the ordinary native product.
Sparse-sparse matmat is different: its output shape depends on the input
structure, so it performs symbolic/count work and synchronizes enough structure
to allocate compact output buffers. CSR uses row symbolic/numeric assembly, COO
groups coordinate rows without routing through CSR, and CSC walks right-hand
columns against left-hand compressed columns to produce sorted output columns.
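The symbolic/count step for CSR can be sketched separately from the numeric work: it visits only the index arrays and yields the output row pointers, which is the information needed to size the output buffers. The helper csr_matmat_symbolic is a hypothetical pure-Python illustration, not the native pass.

```python
def csr_matmat_symbolic(a_indptr, a_indices, b_indptr, b_indices):
    """Count distinct output columns per row of A @ B (structure only).

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    out_indptr = [0]
    for i in range(len(a_indptr) - 1):
        cols = set()
        for ka in range(a_indptr[i], a_indptr[i + 1]):
            k = a_indices[ka]
            # Every stored column of row k in B may appear in output row i.
            cols.update(b_indices[b_indptr[k]:b_indptr[k + 1]])
        out_indptr.append(out_indptr[-1] + len(cols))
    return out_indptr


# Structure of A = [[x, x], [0, x]] multiplied by itself.
out_indptr = csr_matmat_symbolic([0, 2, 3], [0, 1, 1], [0, 2, 3], [0, 1, 1])
nnz = out_indptr[-1]
```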
coo_matvec / coo_matmul are native coordinate scatter products. On
Metal, float32 uses atomic scatter-add over stored coordinates; other
value dtypes stay native through a serial scatter kernel because Metal does
not provide storage-compatible atomic adds for float16, bfloat16, or
complex64.
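A coordinate scatter product can be sketched in a few lines of pure Python. Each stored entry contributes to one output row, which is the update that a parallel kernel has to perform atomically. The name coo_matvec_reference is hypothetical and the loop below is serial.

```python
def coo_matvec_reference(rows, cols, vals, x, n_rows):
    """Scatter-add reference for y = A @ x with A stored as COO triplets.

    Hypothetical helper for illustration; not part of mlx_sparse.
    """
    y = [0.0] * n_rows
    for r, c, v in zip(rows, cols, vals):
        # Several entries may target the same y[r]; in a parallel kernel
        # this accumulation is the step that needs an atomic add.
        y[r] += v * x[c]
    return y


# Entries (0, 0) = 1, (1, 1) = 2, and a duplicate (0, 0) = 3.
y = coo_matvec_reference([0, 1, 0], [0, 1, 0], [1.0, 2.0, 3.0], [2.0, 10.0], 2)
```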
csc_matvec / csc_matmul are native compressed-column scatter products.
Forward CSC products walk columns and scatter into output rows. On Metal,
float32 uses atomic scatter-add while other value dtypes use native serial
scatter. CSC transpose products are the layout’s reduction fast path: each
output entry is one compressed-column dot product.
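The contrast between the forward scatter and the transpose reduction can be seen in a pure-Python sketch. Both helpers below are hypothetical references for the arithmetic and use standard CSC array names (indptr, indices, data) as assumptions.

```python
def csc_matvec_reference(indptr, indices, data, x, n_rows):
    """y = A @ x for CSC A: walk columns, scatter into output rows."""
    y = [0.0] * n_rows
    for col in range(len(indptr) - 1):
        for k in range(indptr[col], indptr[col + 1]):
            y[indices[k]] += data[k] * x[col]  # scattered write
    return y


def csc_matvec_transpose_reference(indptr, indices, data, x):
    """y = A.T @ x for CSC A: one dot product per compressed column."""
    y = []
    for col in range(len(indptr) - 1):
        acc = 0.0
        for k in range(indptr[col], indptr[col + 1]):
            acc += data[k] * x[indices[k]]  # gathered read, local sum
        y.append(acc)
    return y


# A = [[1, 0],
#      [2, 3]] stored by columns.
indptr, indices, data = [0, 2, 3], [0, 1, 1], [1.0, 2.0, 3.0]
forward = csc_matvec_reference(indptr, indices, data, [1.0, 1.0], 2)
transposed = csc_matvec_transpose_reference(indptr, indices, data, [1.0, 1.0])
```

The forward product writes to output rows shared across columns, while the transpose product owns each output entry, so no conflicting writes arise.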