# Source code for mlx_sparse

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse array containers and primitives for MLX.

mlx-sparse provides COO, CSR, and CSC sparse matrix containers backed by MLX
arrays, with native C++ primitives for sparse operations on the CPU and, on
supported Apple Silicon systems, the Metal GPU. The public API is intentionally
small: construct sparse matrices from MLX, NumPy, SciPy, or explicit sparse
buffers; run sparse-dense products; differentiate through sparse values and
dense operands; and convert back to dense when needed.

Typical usage::

    import mlx.core as mx
    import numpy as np
    import mlx_sparse as ms

    ms.use_cpu()

    data = mx.array(np.array([2.0, -1.0, 4.0], dtype=np.float32))
    row = mx.array(np.array([0, 0, 1], dtype=np.int32))
    col = mx.array(np.array([0, 2, 1], dtype=np.int32))

    A = ms.coo_array((data, (row, col)), shape=(2, 3)).tocsr(canonical=True)
    x = mx.array(np.array([3.0, 10.0, 7.0], dtype=np.float32))

    y = A @ x  # CSR matvec
    dense = A.todense()  # materialise as dense
    At = A.T  # structural transpose
    Ah = A.H  # Hermitian (conjugate) transpose
"""

from mlx_sparse import linalg, random, runtime
from mlx_sparse._capabilities import (
    capabilities,
    has_capability,
)
from mlx_sparse._config import (
    config,
    config_context,
    get_config,
    set_config,
)
from mlx_sparse._construct import (
    asarray,
    diags,
    eye,
    from_dense,
    from_numpy,
    from_scipy,
    fromdense,
)
from mlx_sparse._coo import COOArray, coo_array
from mlx_sparse._csc import CSCArray, csc_array
from mlx_sparse._csr import CSRArray, csr_array
from mlx_sparse._device import use_cpu, use_device, use_gpu
from mlx_sparse._ops import (
    coo_batched_matmul,
    coo_batched_matvec,
    coo_col_norms,
    coo_col_sums,
    coo_column_norms,
    coo_column_sums,
    coo_diagonal,
    coo_matmat,
    coo_matmul,
    coo_matvec,
    coo_row_norms,
    coo_row_sums,
    coo_trace,
    csc_batched_matmul,
    csc_batched_matvec,
    csc_col_norms,
    csc_col_sums,
    csc_column_norms,
    csc_column_sums,
    csc_diagonal,
    csc_matmat,
    csc_matmul,
    csc_matvec,
    csc_matvec_transpose,
    csc_row_norms,
    csc_row_sums,
    csc_trace,
    csr_batched_matmul,
    csr_batched_matvec,
    csr_col_sums,
    csr_column_sums,
    csr_diagonal,
    csr_matmat,
    csr_matmul,
    csr_matvec,
    csr_row_norms,
    csr_row_sums,
    csr_trace,
    identity_like,
    todense,
)
from mlx_sparse._typing import is_available

# Resolve the package version. The ``_version`` module is generated when the
# package is built/installed; a bare source checkout has none, in which case a
# development placeholder is used instead.
try:
    from mlx_sparse._version import __version__ as _resolved_version
except ImportError:
    _resolved_version = "0.0.0.dev0"
__version__ = _resolved_version
del _resolved_version  # keep the module namespace free of the temporary name


def issparse(x: object) -> bool:
    """Return ``True`` if ``x`` is a recognized mlx-sparse container.

    Currently returns ``True`` for :class:`COOArray`, :class:`CSRArray`, and
    :class:`CSCArray` instances (subclass instances included, since the check
    uses :func:`isinstance`). All other objects return ``False``.

    Args:
        x: Any Python object.

    Returns:
        ``True`` if ``x`` is a :class:`COOArray`, :class:`CSRArray`, or
        :class:`CSCArray`; ``False`` otherwise.

    Example::

        import mlx.core as mx
        import mlx_sparse as ms

        ms.issparse(my_csr)           # True  (my_csr is any CSRArray)
        ms.issparse(mx.ones((3, 4)))  # False (dense MLX array)
    """
    return isinstance(x, (COOArray, CSRArray, CSCArray))
# Public API re-exported by ``from mlx_sparse import *``. Kept in sorted order
# (classes first, then functions/modules) so it can be audited against the
# imports above at a glance.
__all__ = [
    "COOArray",
    "CSCArray",
    "CSRArray",
    "asarray",
    "capabilities",
    "config",
    "config_context",
    "coo_array",
    "coo_batched_matmul",
    "coo_batched_matvec",
    "coo_col_norms",
    "coo_col_sums",
    "coo_column_norms",
    "coo_column_sums",
    "coo_diagonal",
    "coo_matmat",
    "coo_matmul",
    "coo_matvec",
    "coo_row_norms",
    "coo_row_sums",
    "coo_trace",
    "csc_array",
    "csc_batched_matmul",
    "csc_batched_matvec",
    "csc_col_norms",
    "csc_col_sums",
    "csc_column_norms",
    "csc_column_sums",
    "csc_diagonal",
    "csc_matmat",
    "csc_matmul",
    "csc_matvec",
    "csc_matvec_transpose",
    "csc_row_norms",
    "csc_row_sums",
    "csc_trace",
    "csr_array",
    "csr_batched_matmul",
    "csr_batched_matvec",
    "csr_col_sums",
    "csr_column_sums",
    "csr_diagonal",
    "csr_matmat",
    "csr_matmul",
    "csr_matvec",
    "csr_row_norms",
    "csr_row_sums",
    "csr_trace",
    "diags",
    "eye",
    "from_dense",
    "from_numpy",
    "from_scipy",
    "fromdense",
    "get_config",
    "has_capability",
    "identity_like",
    "is_available",
    "issparse",
    "linalg",
    "random",
    "runtime",
    "set_config",
    "todense",
    "use_cpu",
    "use_device",
    "use_gpu",
]