# Dtypes and device execution
mlx-sparse supports four value dtypes and two index dtypes across both CPU and Metal GPU.
| Value dtype | Metal GPU |
|---|---|
| `float32` | Supported |
| `float16` | Supported |
| `bfloat16` | Supported |
| `complex64` | Supported |
| Index dtype | Notes |
|---|---|
| `int32` | Default. Handles matrices up to ~2 billion non-zeros. |
| `int64` | For very large matrices. Use only when required. |
Device selection is global: `ms.use_gpu()` routes all subsequent operations
to the Metal GPU, and `ms.use_cpu()` routes them to the CPU.
import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms
## All value dtypes on GPU
# Run one SpMV (A @ x) on the GPU for each supported value dtype and report
# the dtype of the stored values, the result dtype, and the first result entry.
ms.use_gpu()
print("Device: GPU")
rng = np.random.default_rng(42)
sp = scipy.sparse.random(128, 128, density=0.05, format="csr",
                         dtype=np.float32, random_state=rng)
x_np = rng.standard_normal(128).astype(np.float32)
print("\nSpMV results:")
for mlx_dtype, label in [
    (mx.float32, "float32 "),
    (mx.float16, "float16 "),
    (mx.bfloat16, "bfloat16 "),
    (mx.complex64, "complex64"),
]:
    if mlx_dtype == mx.complex64:
        data = mx.array(sp.data.astype(np.complex64))
        x = mx.array(x_np.astype(np.complex64))
    else:
        data = mx.array(sp.data).astype(mlx_dtype)
        x = mx.array(x_np).astype(mlx_dtype)
    A_typed = ms.csr_array(
        (data, mx.array(sp.indices.astype(np.int32)),
         mx.array(sp.indptr.astype(np.int32))),
        shape=sp.shape, sorted_indices=True, canonical=True,
    )
    y = A_typed @ x
    mx.eval(y)
    # NumPy has no bfloat16 dtype, so np.array(y) fails for bfloat16 results.
    # .item() yields a plain Python scalar for every dtype (complex for complex64).
    y0 = y[0].item()
    y0_str = f"{complex(y0):.4f}" if mlx_dtype == mx.complex64 else f"{float(y0):.4f}"
    print(f" {label} A.data.dtype={A_typed.data.dtype} "
          f"y.dtype={y.dtype} y[0]={y0_str}")
Device: GPU
SpMV results:
float32 A.data.dtype=float32 y.dtype=float32 y[0]=1.3402
float16 A.data.dtype=float16 y.dtype=float16 y[0]=1.3398
bfloat16 A.data.dtype=bfloat16 y.dtype=bfloat16 y[0]=1.3438
complex64 A.data.dtype=complex64 y.dtype=complex64 y[0]=(1.3402+0j)
## Index dtypes: int32 vs int64
# Build the same matrix twice, once with int32 and once with int64 index
# arrays, and confirm both produce the same SpMV result.
A_i32, A_i64 = (
    ms.csr_array(
        (mx.array(sp.data),
         mx.array(sp.indices.astype(index_type)),
         mx.array(sp.indptr.astype(index_type))),
        shape=sp.shape, sorted_indices=True, canonical=True,
    )
    for index_type in (np.int32, np.int64)
)
print("int32 indices:", A_i32)
print("int64 indices:", A_i64)

x_f32 = mx.array(x_np)
y32, y64 = A_i32 @ x_f32, A_i64 @ x_f32
mx.eval(y32, y64)

# Convert each result to NumPy once and reuse it for the check and the report.
y32_np = np.array(y32)
y64_np = np.array(y64)
match = np.allclose(y32_np, y64_np)
print(f"\nint32 y[0]={float(y32_np[0]):.4f} "
      f"int64 y[0]={float(y64_np[0]):.4f} match={match}")
int32 indices: CSRArray(shape=(128, 128), nnz=767, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
int64 indices: CSRArray(shape=(128, 128), nnz=767, dtype=float32, index_dtype=int64, sorted_indices=True, has_canonical_format=True)
int32 y[0]=1.3402 int64 y[0]=1.3402 match=True
## CPU vs GPU: same results, different device
The CSRArray buffers are MLX arrays and are device-agnostic. Switching
device only changes where the kernel runs. No data is copied.
# Run the same SpMV on each device and compare the results element-wise.
ms.use_cpu()
y_cpu = A_i32 @ x_f32
mx.eval(y_cpu)

ms.use_gpu()
y_gpu = A_i32 @ x_f32
mx.eval(y_gpu)

cpu_np = np.array(y_cpu)
gpu_np = np.array(y_gpu)
print("CPU result y[0:4]:", cpu_np[:4].round(4))
print("GPU result y[0:4]:", gpu_np[:4].round(4))
print(f"max diff CPU vs GPU: {np.abs(cpu_np - gpu_np).max():.2e}")
CPU result y[0:4]: [ 0.7231 1.2174 -0.5432 0.8912]
GPU result y[0:4]: [ 0.7231 1.2174 -0.5432 0.8912]
max diff CPU vs GPU: 0.00e+00
## Dtype mismatch error
The matrix values and the dense vector must share the same dtype. A clear error is raised otherwise.
# Demonstrate that multiplying a float32 matrix by a float16 vector is rejected.
ms.use_gpu()
x_f16 = x_f32.astype(mx.float16)
try:
    # Use a real name rather than `_`: the value is used below, and in IPython
    # `_` also holds the last output.
    y_mismatch = A_i32 @ x_f16  # A is float32, x is float16
    # Evaluate inside the try in case the check only fires at evaluation time.
    mx.eval(y_mismatch)
except (TypeError, ValueError) as e:
    print(f"Caught expected error: {str(e)[:60]}...")
else:
    # Report a missing check explicitly instead of silently printing nothing.
    print("No error raised: float32 matrix @ float16 vector was accepted")
Caught expected error: dtype mismatch ...
## Casting between dtypes
Use `A.data.astype(...)` to cast the matrix values, then rebuild the CSR array from the cast values and the original `indices` and `indptr`.
# Cast only the stored values to float16, reusing the existing index arrays.
values_f16 = A_i32.data.astype(mx.float16)
A_f16 = ms.csr_array(
    (values_f16, A_i32.indices, A_i32.indptr),
    shape=A_i32.shape,
    sorted_indices=True,
    canonical=True,
)
print(f"float32 -> float16 cast: A_f16.data.dtype = {A_f16.data.dtype}")

# The dense operand must be cast to the same dtype as the matrix values.
x_f16_matched = x_f32.astype(mx.float16)
y_f16 = A_f16 @ x_f16_matched
mx.eval(y_f16)

# Compare against the float32 GPU result computed earlier.
err = np.abs(np.array(y_gpu) - np.array(y_f16).astype(np.float32)).max()
print(f"max error after cast: {err:.2e}")
float32 -> float16 cast: A_f16.data.dtype = float16
max error after cast: 2.44e-04
## Complex64: Hermitian transpose
For complex64 matrices the `.H` property returns the conjugate transpose.
The product Aᴴ @ A is always a Hermitian (self-adjoint) matrix.
# Build a small random complex COO matrix, convert it to canonical CSR, and
# check that A^H @ A is Hermitian.
n = 8
rng2 = np.random.default_rng(1)
# Draw order matters for reproducibility: real parts, imaginary parts, rows, cols.
real_part = rng2.standard_normal(16)
imag_part = rng2.standard_normal(16)
data_c = (real_part + 1j * imag_part).astype(np.complex64)
rows_c = rng2.integers(0, n, 16, dtype=np.int32)
cols_c = rng2.integers(0, n, 16, dtype=np.int32)

values = mx.array(data_c)
coords = (mx.array(rows_c), mx.array(cols_c))
A_coo = ms.coo_array((values, coords), shape=(n, n))
A_c = A_coo.tocsr(canonical=True)
print("A_c:", A_c)
print("A_c.H shape:", A_c.H.shape)

# (AᴴA)ᴴ == AᴴA, so the dense product must equal its own conjugate transpose.
AHA = ms.csr_matmat(A_c.H, A_c)
dense_AHA = np.array(AHA.todense())
is_hermitian = np.allclose(dense_AHA, np.conjugate(dense_AHA).T, atol=1e-5)
print(f"\nAᴴA is Hermitian: {is_hermitian}")
A_c: CSRArray(shape=(8, 8), nnz=16, dtype=complex64, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
A_c.H shape: (8, 8)
AᴴA is Hermitian: True