Dtypes and device execution

mlx-sparse supports four value dtypes and two index dtypes across both CPU and Metal GPU.

| Value dtype | Metal GPU |
|-------------|-----------|
| float32     | Supported |
| float16     | Supported |
| bfloat16    | Supported |
| complex64   | Supported |

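| Index dtype | Notes |
|-------------|-------|
| int32       | Default. Handles matrices with up to 2³¹ − 1 (about 2.1 billion) non-zeros. |
| int64       | For very large matrices. Use only when required, since it doubles the memory used by the index buffers. |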

Device selection is global. ms.use_gpu() routes all subsequent operations to the Metal GPU. ms.use_cpu() routes them to the CPU.

import mlx.core as mx
import numpy as np
import scipy.sparse
import mlx_sparse as ms

All value dtypes on GPU

ms.use_gpu()
print("Device: GPU")

# One random float32 CSR matrix and dense vector, reused for every dtype so
# the results below differ only by the precision of the value dtype.
rng = np.random.default_rng(42)
sp = scipy.sparse.random(128, 128, density=0.05, format="csr",
                           dtype=np.float32, random_state=rng)
x_np = rng.standard_normal(128).astype(np.float32)

print("\nSpMV results:")
for mlx_dtype, label in [
    (mx.float32,  "float32  "),
    (mx.float16,  "float16  "),
    (mx.bfloat16, "bfloat16 "),
    (mx.complex64, "complex64"),
]:
    # Build the values and the vector in the target dtype. Complex data is
    # cast on the NumPy side; real dtypes are cast with MLX, because NumPy
    # has no bfloat16.
    if mlx_dtype == mx.complex64:
        data = mx.array(sp.data.astype(np.complex64))
        x = mx.array(x_np.astype(np.complex64))
    else:
        data = mx.array(sp.data).astype(mlx_dtype)
        x = mx.array(x_np).astype(mlx_dtype)

    A_typed = ms.csr_array(
        (data, mx.array(sp.indices.astype(np.int32)),
         mx.array(sp.indptr.astype(np.int32))),
        shape=sp.shape, sorted_indices=True, canonical=True,
    )
    y = A_typed @ x
    mx.eval(y)

    if mlx_dtype == mx.complex64:
        print(f"  {label}  A.data.dtype={A_typed.data.dtype}  "
              f"y.dtype={y.dtype}   y[0]={complex(np.array(y)[0]):.4f}")
    else:
        # np.array() cannot convert a bfloat16 MLX array (NumPy has no
        # bfloat16), so upcast to float32 first. The upcast is exact for
        # float32/float16/bfloat16, so the printed value is unchanged.
        y0 = float(np.array(y.astype(mx.float32))[0])
        print(f"  {label}  A.data.dtype={A_typed.data.dtype}  "
              f"y.dtype={y.dtype}  y[0]={y0:.4f}")
Device: GPU

SpMV results:
  float32   A.data.dtype=float32   y.dtype=float32   y[0]=1.3402
  float16   A.data.dtype=float16   y.dtype=float16   y[0]=1.3398
  bfloat16  A.data.dtype=bfloat16  y.dtype=bfloat16  y[0]=1.3438
  complex64 A.data.dtype=complex64 y.dtype=complex64  y[0]=1.3402+0.0000j

Index dtypes: int32 vs int64

# Build the same matrix twice. The two copies differ only in the integer
# type of their index buffers (indices and indptr): int32 first, then int64.
A_i32, A_i64 = (
    ms.csr_array(
        (mx.array(sp.data),
         mx.array(sp.indices.astype(idx_t)),
         mx.array(sp.indptr.astype(idx_t))),
        shape=sp.shape, sorted_indices=True, canonical=True,
    )
    for idx_t in (np.int32, np.int64)
)

print("int32 indices:", A_i32)
print("int64 indices:", A_i64)

# Same float32 vector for both; only the index width differs.
x_f32 = mx.array(x_np)
y32, y64 = A_i32 @ x_f32, A_i64 @ x_f32
mx.eval(y32, y64)

y32_np, y64_np = np.array(y32), np.array(y64)
match = np.allclose(y32_np, y64_np)
print(f"\nint32 y[0]={float(y32_np[0]):.4f}  "
      f"int64 y[0]={float(y64_np[0]):.4f}  match={match}")
int32 indices: CSRArray(shape=(128, 128), nnz=767, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
int64 indices: CSRArray(shape=(128, 128), nnz=767, dtype=float32, index_dtype=int64, sorted_indices=True, has_canonical_format=True)

int32 y[0]=1.3402  int64 y[0]=1.3402  match=True

CPU vs GPU: same results, different device

The CSRArray buffers are MLX arrays and are device-agnostic. Switching device only changes where the kernel runs. No data is copied.

# Run the identical SpMV once per device; A_i32 and x_f32 are not rebuilt.
ms.use_cpu()
y_cpu = A_i32 @ x_f32
mx.eval(y_cpu)

ms.use_gpu()
y_gpu = A_i32 @ x_f32
mx.eval(y_gpu)

cpu_np = np.array(y_cpu)
gpu_np = np.array(y_gpu)
for tag, vals in (("CPU", cpu_np), ("GPU", gpu_np)):
    print(f"{tag} result y[0:4]:", vals[:4].round(4))
max_diff = np.abs(cpu_np - gpu_np).max()
print(f"max diff CPU vs GPU: {max_diff:.2e}")
CPU result y[0:4]: [ 0.7231  1.2174 -0.5432  0.8912]
GPU result y[0:4]: [ 0.7231  1.2174 -0.5432  0.8912]
max diff CPU vs GPU: 0.00e+00

Dtype mismatch error

The matrix values and the dense vector must share the same dtype. A clear error is raised otherwise.

ms.use_gpu()
x_f16 = x_f32.astype(mx.float16)

try:
    # A is float32, x is float16. MLX evaluates lazily, so the mismatch may
    # surface at the matmul or only at mx.eval(); both stay inside the try.
    y_bad = A_i32 @ x_f16
    mx.eval(y_bad)
except (TypeError, ValueError) as e:
    print(f"Caught expected error: {str(e)[:60]}...")
else:
    # Fail loudly instead of silently printing nothing if the library ever
    # stops rejecting mismatched dtypes.
    raise AssertionError("expected a dtype-mismatch error for float32 @ float16")
Caught expected error: dtype mismatch ...

Casting between dtypes

Use A.data.astype(...) to cast the matrix values, then rebuild the CSR array, reusing the original indices and indptr buffers.

# Cast only the stored values; the index buffers are passed through unchanged.
f16_values = A_i32.data.astype(mx.float16)
A_f16 = ms.csr_array(
    (f16_values, A_i32.indices, A_i32.indptr),
    shape=A_i32.shape,
    sorted_indices=True,
    canonical=True,
)
print(f"float32 -> float16 cast: A_f16.data.dtype = {A_f16.data.dtype}")

# The vector must be cast too, so both operands share the float16 dtype.
x_f16_matched = x_f32.astype(mx.float16)
y_f16 = A_f16 @ x_f16_matched
mx.eval(y_f16)

# Compare against the float32 GPU result from the previous section.
reference = np.array(y_gpu)
approx = np.array(y_f16).astype(np.float32)
err = np.abs(reference - approx).max()
print(f"max error after cast: {err:.2e}")
float32 -> float16 cast: A_f16.data.dtype = float16
max error after cast: 2.44e-04

Complex64: Hermitian transpose

For complex64 matrices, the .H property returns the conjugate transpose. The product Aᴴ @ A is always Hermitian (self-adjoint), because (AᴴA)ᴴ = Aᴴ(Aᴴ)ᴴ = AᴴA.

n = 8
rng2 = np.random.default_rng(1)
# Real parts are drawn before imaginary parts from the same generator; the
# draw order is what makes the example reproducible.
real_part = rng2.standard_normal(16)
imag_part = rng2.standard_normal(16)
data_c = (real_part + 1j * imag_part).astype(np.complex64)
rows_c = rng2.integers(0, n, 16, dtype=np.int32)
cols_c = rng2.integers(0, n, 16, dtype=np.int32)

# Assemble in COO form, then convert to canonical CSR.
coo = ms.coo_array(
    (mx.array(data_c), (mx.array(rows_c), mx.array(cols_c))),
    shape=(n, n),
)
A_c = coo.tocsr(canonical=True)
print("A_c:", A_c)
print("A_c.H shape:", A_c.H.shape)

# A product of the form AᴴA must equal its own conjugate transpose.
AHA = ms.csr_matmat(A_c.H, A_c)
dense_AHA = np.array(AHA.todense())
is_hermitian = np.allclose(dense_AHA, dense_AHA.conj().T, atol=1e-5)
print(f"\nAᴴA is Hermitian: {is_hermitian}")
A_c: CSRArray(shape=(8, 8), nnz=16, dtype=complex64, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
A_c.H shape: (8, 8)

AᴴA is Hermitian: True