Dtype policy#
mlx-sparse enforces explicit dtype constraints on both the value arrays
(`data`) and the index arrays (`indices`, `indptr`, `row`, `col`).
Mixed dtypes between the sparse data and the dense operand are rejected when
the operation is called rather than silently promoted.
Value dtypes#
The following value dtypes are supported:
| Dtype | Python name | Notes |
|---|---|---|
| `float32` | `mx.float32` | Fully supported on CPU and Metal GPU. The primary dtype. |
| `float16` | `mx.float16` | Supported on CPU and Metal GPU. CPU and GPU accumulate in `float32`. |
| `bfloat16` | `mx.bfloat16` | Supported on CPU and Metal GPU. Same accumulation convention as `float16`. |
| `complex64` | `mx.complex64` | Supported on CPU and Metal GPU for forward operations and autodiff through sparse values and dense RHS operands. |
Integer (`int32`, `int64`), boolean, and `float64` value dtypes are not supported.
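As an illustration of the rule above, the value-dtype policy amounts to a whitelist check. This is a minimal sketch, not the mlx-sparse implementation: the names `SUPPORTED_VALUE_DTYPES` and `check_value_dtype` are hypothetical, and dtypes are identified by plain strings because NumPy has no native `bfloat16`.

```python
# Hypothetical sketch of the value-dtype whitelist; not the mlx-sparse code.
# Dtype names are strings because NumPy has no native bfloat16 type.
SUPPORTED_VALUE_DTYPES = frozenset({"float32", "float16", "bfloat16", "complex64"})


def check_value_dtype(name: str) -> str:
    """Return `name` unchanged if it is a supported value dtype, else raise TypeError."""
    if name not in SUPPORTED_VALUE_DTYPES:
        raise TypeError(
            f"unsupported sparse value dtype {name!r}; expected one of "
            f"{sorted(SUPPORTED_VALUE_DTYPES)}"
        )
    return name


print(check_value_dtype("bfloat16"))  # bfloat16
for bad in ("int32", "bool", "float64"):
    try:
        check_value_dtype(bad)
    except TypeError as exc:
        print(exc)  # each unsupported dtype is rejected
```

Note that `int32` is rejected here as a *value* dtype even though it is the default *index* dtype.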
Index dtypes#
| Dtype | Notes |
|---|---|
| `int32` | Default. All CPU and Metal kernels support it. Sufficient for matrices with up to roughly 2 billion non-zeros (the `int32` maximum is 2,147,483,647), which covers all practical Apple Silicon workloads. |
| `int64` | Supported on CPU and Metal GPU. Use for matrices that exceed the `int32` index range. |
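The "roughly 2 billion" limit comes from the `int32` maximum. A minimal sketch of choosing an index dtype follows; `pick_index_dtype` is a hypothetical helper, not part of the documented API, and it also checks the matrix dimensions (an assumption beyond the table above), since stored row/column indices must fit in the index dtype too.

```python
import numpy as np

# Hypothetical helper, not part of mlx-sparse: pick the narrowest index dtype
# that can represent every stored index and every indptr offset.
INT32_MAX = int(np.iinfo(np.int32).max)  # 2_147_483_647


def pick_index_dtype(shape, nnz):
    # indptr offsets run from 0 to nnz; row/column indices run up to max(shape) - 1.
    largest = max(nnz, max(shape) - 1)
    return np.int32 if largest <= INT32_MAX else np.int64


print(pick_index_dtype((1_000, 1_000), 50_000).__name__)             # int32
print(pick_index_dtype((100_000, 100_000), 3_000_000_000).__name__)  # int64
```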
`indices` and `indptr` must share the same dtype in a `CSRArray`.
Similarly, `row` and `col` must share the same dtype in a `COOArray`.
Mismatched index dtypes are caught at metadata validation time.
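The metadata check described above can be modelled in a few lines. This is a sketch under assumptions: `validate_index_dtypes` and `ALLOWED_INDEX_DTYPES` are hypothetical names, NumPy arrays stand in for MLX arrays, and the error text is illustrative rather than the library's actual message.

```python
import numpy as np

# Hypothetical model of the index-dtype metadata check; not the mlx-sparse code.
ALLOWED_INDEX_DTYPES = {np.dtype(np.int32), np.dtype(np.int64)}


def validate_index_dtypes(fmt, *index_arrays):
    """All index arrays of one sparse array (CSR: indices, indptr; COO: row, col)
    must share a single allowed dtype. Returns that dtype."""
    dtypes = {np.asarray(a).dtype for a in index_arrays}
    if len(dtypes) != 1:
        got = ", ".join(sorted(str(d) for d in dtypes))
        raise TypeError(f"{fmt}: index arrays must share one dtype, got {got}")
    (dtype,) = dtypes
    if dtype not in ALLOWED_INDEX_DTYPES:
        raise TypeError(f"{fmt}: unsupported index dtype {dtype}")
    return dtype


indices = np.array([0, 2, 1], dtype=np.int32)
indptr = np.array([0, 2, 3], dtype=np.int64)
try:
    validate_index_dtypes("CSRArray", indices, indptr)
except TypeError as exc:
    print(exc)  # CSRArray: index arrays must share one dtype, got int32, int64
```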
Mixed dtype rejection#
Both the Python entry points and the native C++ validation check, on every
operation call, that `data.dtype` matches the dense operand's dtype.
There is no implicit promotion:
```python
import mlx.core as mx
import mlx_sparse as ms

A = ms.coo_array(
    (
        mx.array([1.0], dtype=mx.float32),
        (mx.array([0], dtype=mx.int32), mx.array([0], dtype=mx.int32)),
    ),
    shape=(1, 1),
).tocsr()

x_fp16 = mx.array([1.0], dtype=mx.float16)
A @ x_fp16  # TypeError: csr_matvec requires sparse data and RHS to have
            # the same dtype, got float32 and float16.
```
To use a different dtype, convert before constructing:
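```python
# Rebuild the CSR array with float16 values; the index arrays are reused as-is.
A_fp16 = ms.csr_array(
    (A.data.astype(mx.float16), A.indices, A.indptr),
    shape=A.shape,
)
y = A_fp16 @ x_fp16  # both operands are now float16, so no TypeError
```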
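The "no implicit promotion" rule can be modelled outside MLX. The following is a NumPy sketch, not the mlx-sparse kernel: `csr_matvec_sketch` is a hypothetical name, and it only reproduces the dtype check (copying the documented error text) plus a plain CSR row loop.

```python
import numpy as np


def csr_matvec_sketch(data, indices, indptr, x):
    """Hypothetical NumPy model of CSR @ dense vector with a strict dtype check."""
    # Reject mismatched dtypes instead of promoting, as the policy requires.
    if data.dtype != x.dtype:
        raise TypeError(
            "csr_matvec requires sparse data and RHS to have the same dtype, "
            f"got {data.dtype} and {x.dtype}."
        )
    n_rows = len(indptr) - 1
    y = np.zeros(n_rows, dtype=data.dtype)
    for i in range(n_rows):
        start, end = indptr[i], indptr[i + 1]
        # Row i's stored values times the matching entries of x.
        y[i] = np.dot(data[start:end], x[indices[start:end]])
    return y


data = np.array([1, 2, 3], dtype=np.float32)
indices = np.array([0, 2, 1], dtype=np.int32)
indptr = np.array([0, 2, 3], dtype=np.int32)
x = np.array([1, 10, 100], dtype=np.float32)
print(csr_matvec_sketch(data, indices, indptr, x).tolist())  # [201.0, 30.0]

try:
    csr_matvec_sketch(data, indices, indptr, x.astype(np.float16))
except TypeError as exc:
    print(exc)  # rejected: float32 data vs float16 RHS
```

Casting either side explicitly (here, `x.astype(np.float32)` on the RHS, or the values as in the example above) is the only way past the check, which keeps every precision change visible in user code.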
Accumulation policy#
For float16 and bfloat16, both CPU and Metal GPU backends use a
float32 accumulator to reduce rounding error during the inner-product
loop, then cast back to the storage dtype on output.
Staged Metal trace reductions also store intermediate float32 partials for
these dtypes before the final cast, so large traces do not introduce an extra
low-precision partial-sum boundary.
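To see why the wider accumulator matters, the sketch below sums 4096 float16 products sequentially. It assumes a simple left-to-right loop, which is not necessarily the order the real kernels use; `dot_with_accumulator` is a hypothetical name. With a float16 accumulator the running sum stalls at 2048, because 2049 is not representable in float16 and rounds back to 2048.

```python
import numpy as np


def dot_with_accumulator(a, b, acc_type):
    """Hypothetical sequential inner product over float16 inputs.

    Products and partial sums are held in `acc_type`; only the final result
    is cast back to float16 storage.
    """
    acc = acc_type(0)
    for x, y in zip(a, b):
        acc = acc_type(acc + acc_type(x) * acc_type(y))
    return np.float16(acc)


a = np.ones(4096, dtype=np.float16)
b = np.ones(4096, dtype=np.float16)
print(dot_with_accumulator(a, b, np.float16))  # 2048.0: float16 partial sums stall
print(dot_with_accumulator(a, b, np.float32))  # 4096.0: exact, then cast to float16
```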
For complex64, both real and imaginary components accumulate in
complex64 (i.e. float32 component precision). There is no upcasting
to complex128.
For float32, accumulation is in float32 throughout.
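The accumulation rules above reduce to a small lookup from storage dtype to accumulator dtype. This is a sketch restating the documented policy, with hypothetical names (`ACCUMULATOR_FOR`, `accumulator_dtype`) and string dtype names since NumPy lacks `bfloat16`; the NumPy lines confirm that a `complex64` accumulator carries `float32` precision per component.

```python
import numpy as np

# Hypothetical table restating the documented accumulation policy.
ACCUMULATOR_FOR = {
    "float32": "float32",
    "float16": "float32",
    "bfloat16": "float32",
    "complex64": "complex64",  # no upcast to complex128
}


def accumulator_dtype(storage: str) -> str:
    return ACCUMULATOR_FOR[storage]


# Summing in complex64 keeps each component at float32 precision.
z = np.array([1 + 2j, 3 - 1j], dtype=np.complex64)
total = z.sum(dtype=np.complex64)
print(accumulator_dtype("bfloat16"))     # float32
print(total.dtype, total.real.dtype)     # complex64 float32
```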
Metal dtype coverage#
Every Metal GPU kernel is compiled for all four value dtypes (float32,
float16, bfloat16, complex64) and both index dtypes (int32,
int64). The kernels with Metal implementations are:
- `csr_matvec`
- `csr_matmul`
- `coo_matvec` / `coo_matmul` and batched variants
- `csc_matvec` / `csc_matmul` and batched variants
- `coo_tocsr`
- `coo_tocsc`
- `csr_todense`
- `csc_todense`
- `csr_sort_indices`
- `csc_sort_indices`
- `csr_transpose`
- sparse reductions, including row/column sums, row/column norms, diagonal, and trace for COO, CSR, and CSC
Dynamic-output helpers such as `canonicalize()`, `fromdense()`, and CSR @ CSR
products synchronize with the host because the size of their output (the number
of stored entries) depends on the input values.
These are host-side assembly operations and are not Metal kernels.
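A dense-to-CSR conversion shows where the data dependence comes from: the output buffers cannot be allocated until the non-zeros have been counted. The NumPy sketch below illustrates the idea only; `fromdense_csr_sketch` is a hypothetical name and does not describe how mlx-sparse implements `fromdense()`.

```python
import numpy as np


def fromdense_csr_sketch(dense):
    """Hypothetical NumPy model of dense -> CSR conversion."""
    dense = np.asarray(dense)
    # Per-row non-zero counts depend on the values, not just the shape.
    row_counts = np.count_nonzero(dense, axis=1)
    indptr = np.concatenate(([0], np.cumsum(row_counts)))
    nnz = int(indptr[-1])  # output size is only known at this point
    data = np.empty(nnz, dtype=dense.dtype)
    indices = np.empty(nnz, dtype=np.int32)
    rows, cols = np.nonzero(dense)  # row-major order matches CSR layout
    data[:] = dense[rows, cols]
    indices[:] = cols
    return data, indices, indptr.astype(np.int32)


data, indices, indptr = fromdense_csr_sketch(
    np.array([[0, 5, 0], [0, 0, 0], [7, 0, 8]], dtype=np.float32)
)
print(data.tolist(), indices.tolist(), indptr.tolist())
# [5.0, 7.0, 8.0] [1, 0, 2] [0, 1, 1, 3]
```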