Source code for mlx_sparse._device

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import mlx.core as mx


def _device(kind, index: int) -> mx.Device:
    """Construct the ``mx.Device`` for a device kind and index.

    Args:
        kind: Device type passed straight to ``mx.Device`` (the callers in
            this module pass ``mx.cpu`` or ``mx.gpu``).
        index: Device index passed straight to ``mx.Device``.

    Returns:
        The newly constructed ``mx.Device``. The default device is not
        changed by this helper.
    """
    handle = mx.Device(kind, index)
    return handle


def use_cpu(index: int = 0, *, require_available: bool = True) -> mx.Device:
    """Set MLX's default device to CPU and return the device object.

    Calling this once at the start of a script or test is the recommended way
    to pin all subsequent MLX and mlx-sparse operations to the CPU. The
    setting persists for the lifetime of the Python process or until
    overridden by another ``use_*`` call.

    Args:
        index: CPU device index. Almost always ``0``. Multiple CPU devices
            are not typical in Apple Silicon runtimes. Default ``0``.
        require_available: If ``True`` (default), probe the device
            immediately with a trivial ``mx.eval`` and raise
            ``RuntimeError`` if it fails. Set to ``False`` to skip the probe
            (e.g. in environments where eager evaluation is not yet
            possible).

    Returns:
        The ``mlx.core.Device`` object that was set as the default.

    Raises:
        RuntimeError: If ``require_available=True`` and the CPU device
            cannot be probed successfully. In that case the default device
            that was active before the call is restored before raising.

    Example::

        import mlx_sparse as ms

        ms.use_cpu()   # pin to CPU
        y = A @ x      # runs on CPU
    """
    device = _device(mx.cpu, index)

    if not require_available:
        mx.set_default_device(device)
        return device

    # Remember the current default so a failed probe does not leave the
    # process pinned to a device we just reported as unusable.
    previous = mx.default_device()
    mx.set_default_device(device)
    try:
        probe = mx.array([0.0])
        mx.eval(probe)
    except Exception as exc:
        mx.set_default_device(previous)
        raise RuntimeError(
            f"MLX CPU device {index} is not available to this Python "
            "process. Verify that native MLX can create a CPU array in the "
            "same virtual environment."
        ) from exc
    return device
def use_gpu(index: int = 0, *, require_available: bool = True) -> mx.Device:
    """Set MLX's default device to GPU (Metal) and return the device object.

    On Apple Silicon, this selects the integrated GPU via MLX's Metal
    backend. Fixed-shape sparse primitives dispatch native Metal kernels for
    supported value and index dtypes. Operations with dynamic output
    structure, such as sparse-sparse products and duplicate summation, may
    synchronize to host for structural assembly.

    Args:
        index: GPU device index. ``0`` selects the primary GPU. Default
            ``0``.
        require_available: If ``True`` (default), verify that
            ``mx.is_available`` returns ``True`` for the selected device and
            that a trivial array can be evaluated. Raises ``RuntimeError``
            if either check fails. If ``False``, the device is set as the
            default without any check.

    Returns:
        The ``mlx.core.Device`` object that was set as the default.

    Raises:
        RuntimeError: If ``require_available=True`` and the GPU is not
            available or cannot be probed. In that case the default device
            that was active before the call is restored before raising, so
            callers can catch the error and fall back to another device.

    Example::

        import mlx_sparse as ms

        ms.use_gpu()   # pin to GPU
        y = A @ x      # dispatches Metal csr_matvec kernel
    """
    device = _device(mx.gpu, index)

    if not require_available:
        mx.set_default_device(device)
        return device

    # Remember the current default so either failure path below can undo the
    # switch instead of leaving the process pinned to an unusable GPU.
    previous = mx.default_device()
    mx.set_default_device(device)
    try:
        available = mx.is_available(device)
        probe = mx.array([0.0])
        mx.eval(probe)
    except Exception as exc:
        mx.set_default_device(previous)
        raise RuntimeError(
            f"MLX GPU device {index} is not available to this Python "
            "process. Verify that native MLX can create a GPU array in the "
            "same virtual environment."
        ) from exc

    # The probe may succeed even when MLX reports the device as unavailable;
    # treat that report as authoritative.
    if not available:
        mx.set_default_device(previous)
        raise RuntimeError(
            f"MLX GPU device {index} is not available. Native MLX must be "
            "able to create arrays on the GPU before mlx-sparse GPU kernels "
            "can run."
        )
    return device
def use_device(
    name: str,
    index: int = 0,
    *,
    require_available: bool = True,
) -> mx.Device:
    """Set MLX's default device by name string.

    A convenience wrapper around :func:`use_cpu` and :func:`use_gpu` that
    accepts a plain string device name. Useful when the target device is
    provided as a command-line argument.

    Args:
        name: ``"cpu"`` or ``"gpu"`` (case-insensitive).
        index: Device index. Default ``0``.
        require_available: Forwarded unchanged to :func:`use_cpu` or
            :func:`use_gpu`. Default ``True``.

    Returns:
        The ``mlx.core.Device`` object that was set as the default.

    Raises:
        ValueError: If ``name`` is not ``"cpu"`` or ``"gpu"``, including
            when it is not a string at all.
        RuntimeError: Propagated from :func:`use_cpu` or :func:`use_gpu`.

    Example::

        import argparse
        import mlx_sparse as ms

        parser = argparse.ArgumentParser()
        parser.add_argument("--device", default="gpu")
        args = parser.parse_args()
        ms.use_device(args.device)
    """
    # Non-string input (e.g. None from an unset option) falls through to the
    # documented ValueError instead of failing on ``.lower()``.
    normalized = name.lower() if isinstance(name, str) else None
    if normalized == "cpu":
        return use_cpu(index, require_available=require_available)
    if normalized == "gpu":
        return use_gpu(index, require_available=require_available)
    raise ValueError(f"device must be 'cpu' or 'gpu', got {name!r}.")