# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import reduce
from operator import mul
import mlx.core as mx
import mlx_sparse._native as _native
from mlx_sparse._coo import COOArray
from mlx_sparse._csc import CSCArray
from mlx_sparse._csr import CSRArray
from mlx_sparse._validation import (
ensure_mx_array,
validate_coo_matmul_inputs,
validate_coo_matvec_inputs,
validate_csc_matmul_inputs,
validate_csc_matvec_inputs,
validate_csc_matvec_transpose_inputs,
validate_csr_matmul_inputs,
validate_csr_matvec_inputs,
validate_csr_metadata,
)
def _prod(values) -> int:
return int(reduce(mul, values, 1))
[docs]
def identity_like(x: mx.array) -> mx.array:
    """Return a native MLX copy of ``x``.

    This is an extension smoke test. ``x`` is normalized with
    ``ensure_mx_array`` and handed to ``mlx_sparse._native.identity_like``,
    whose result is returned unchanged. Production code should use
    ``mlx.core`` operations directly.

    Args:
        x: Any MLX array; it is passed through ``ensure_mx_array`` first.

    Returns:
        An MLX array with the same shape, dtype, and values as ``x``.
    """
    checked = ensure_mx_array(x)
    return _native.identity_like(checked)
[docs]
def todense(array) -> mx.array:
    """Materialize a sparse array as a dense MLX array.

    Thin convenience wrapper that delegates to ``array.todense()`` on any
    sparse container. Duplicate entries are summed, consistent with
    ``canonicalize().todense()``.

    Args:
        array: A :class:`~mlx_sparse.COOArray`, :class:`~mlx_sparse.CSRArray`,
            or :class:`~mlx_sparse.CSCArray` instance.

    Returns:
        Dense array of shape ``(n_rows, n_cols)`` with the same dtype as
        ``array.data``.

    Raises:
        TypeError: If ``array`` does not have a ``todense`` method.

    Example::

        import mlx_sparse as ms
        dense = ms.todense(my_csr)
    """
    if not hasattr(array, "todense"):
        raise TypeError(
            f"todense expects an mlx-sparse array, got {type(array).__name__}."
        )
    return array.todense()
def _ensure_csr_array(name: str, a) -> CSRArray:
    """Return ``a`` after checking its type and CSR metadata.

    Args:
        name: Public function name used as the prefix of the error message.
        a: Candidate sparse operand.

    Returns:
        ``a`` itself, unchanged.

    Raises:
        TypeError: If ``a`` is not a :class:`CSRArray`. Errors raised by
            ``validate_csr_metadata`` propagate unchanged.
    """
    if isinstance(a, CSRArray):
        validate_csr_metadata(a.data, a.indices, a.indptr, a.shape)
        return a
    raise TypeError(f"{name} expects CSRArray, got {type(a).__name__}.")
def _ensure_csc_array(name: str, a) -> CSCArray:
    """Return ``a`` unchanged if it is a :class:`CSCArray`.

    Only the type is checked; no CSC metadata validation happens here.

    Raises:
        TypeError: If ``a`` is not a :class:`CSCArray`; the message is
            prefixed with ``name``.
    """
    if isinstance(a, CSCArray):
        return a
    raise TypeError(f"{name} expects CSCArray, got {type(a).__name__}.")
def _ensure_coo_array(name: str, a) -> COOArray:
    """Return ``a`` unchanged if it is a :class:`COOArray`.

    Only the type is checked; no COO metadata validation happens here.

    Raises:
        TypeError: If ``a`` is not a :class:`COOArray`; the message is
            prefixed with ``name``.
    """
    if isinstance(a, COOArray):
        return a
    raise TypeError(f"{name} expects COOArray, got {type(a).__name__}.")
def csr_row_sums(a: CSRArray) -> mx.array:
    """Sum the stored values of every row of a CSR matrix.

    Args:
        a: Sparse matrix; its type and CSR metadata are validated first.

    Returns:
        The native per-row reduction (presumably shape ``(n_rows,)``).
    """
    matrix = _ensure_csr_array("csr_row_sums", a)
    return _native.csr_row_sums(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def csr_col_sums(a: CSRArray) -> mx.array:
    """Sum the stored values of every column of a CSR matrix.

    Args:
        a: Sparse matrix; its type and CSR metadata are validated first.

    Returns:
        The native per-column reduction (presumably shape ``(n_cols,)``).
    """
    matrix = _ensure_csr_array("csr_col_sums", a)
    return _native.csr_col_sums(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def csr_column_sums(a: CSRArray) -> mx.array:
    """Alias for :func:`csr_col_sums`.

    The argument is forwarded unchanged and the result is returned as-is,
    so validation and errors are exactly those of :func:`csr_col_sums`.
    """
    return csr_col_sums(a)
def csr_row_norms(a: CSRArray) -> mx.array:
    """Return the L2 norm of each row of a CSR matrix.

    Inputs without canonical format are canonicalized first. Presumably this
    merges duplicate entries before the norm is taken.

    Args:
        a: Sparse matrix; its type and CSR metadata are validated first.

    Returns:
        The native per-row norms (presumably shape ``(n_rows,)``).
    """
    matrix = _ensure_csr_array("csr_row_norms", a)
    canonical = matrix if matrix.has_canonical_format else matrix.canonicalize()
    # NOTE(review): the CSC/COO norm helpers pass assume_canonical=True to the
    # native kernel after canonicalizing; this call does not. Confirm intended.
    return _native.csr_row_norms(
        canonical.data, canonical.indices, canonical.indptr, canonical.shape
    )
def csr_diagonal(a: CSRArray) -> mx.array:
    """Return the summed diagonal of a CSR matrix, as computed natively.

    Args:
        a: Sparse matrix; its type and CSR metadata are validated first.
    """
    matrix = _ensure_csr_array("csr_diagonal", a)
    return _native.csr_diagonal(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def csr_trace(a: CSRArray) -> mx.array:
    """Return the trace of a CSR matrix, as computed by the native kernel.

    Args:
        a: Sparse matrix; its type and CSR metadata are validated first.
    """
    matrix = _ensure_csr_array("csr_trace", a)
    return _native.csr_trace(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def coo_row_sums(a: COOArray) -> mx.array:
    """Sum the stored values of every row of a COO matrix.

    Args:
        a: Sparse matrix; only its type is checked here.

    Returns:
        The native per-row reduction (presumably shape ``(n_rows,)``).
    """
    matrix = _ensure_coo_array("coo_row_sums", a)
    return _native.coo_row_sums(matrix.data, matrix.row, matrix.col, matrix.shape)
def coo_col_sums(a: COOArray) -> mx.array:
    """Sum the stored values of every column of a COO matrix.

    Args:
        a: Sparse matrix; only its type is checked here.

    Returns:
        The native per-column reduction (presumably shape ``(n_cols,)``).
    """
    matrix = _ensure_coo_array("coo_col_sums", a)
    return _native.coo_col_sums(matrix.data, matrix.row, matrix.col, matrix.shape)
def coo_column_sums(a: COOArray) -> mx.array:
    """Alias for :func:`coo_col_sums`.

    The argument is forwarded unchanged and the result is returned as-is,
    so validation and errors are exactly those of :func:`coo_col_sums`.
    """
    return coo_col_sums(a)
def coo_row_norms(a: COOArray) -> mx.array:
    """Return the dense-semantics L2 norm of each row of a COO matrix.

    Canonical inputs go straight to the native COO kernel with
    ``assume_canonical=True``. Other inputs are converted to a canonical CSR
    copy and reduced with its ``row_norms`` method.
    """
    matrix = _ensure_coo_array("coo_row_norms", a)
    if matrix.has_canonical_format:
        return _native.coo_row_norms(
            matrix.data, matrix.row, matrix.col, matrix.shape, assume_canonical=True
        )
    return matrix.tocsr(canonical=True).row_norms()
def coo_col_norms(a: COOArray) -> mx.array:
    """Return the dense-semantics L2 norm of each column of a COO matrix.

    Canonical inputs go straight to the native COO kernel with
    ``assume_canonical=True``. Other inputs are converted to a canonical CSC
    copy and reduced with its ``col_norms`` method.
    """
    matrix = _ensure_coo_array("coo_col_norms", a)
    if matrix.has_canonical_format:
        return _native.coo_col_norms(
            matrix.data, matrix.row, matrix.col, matrix.shape, assume_canonical=True
        )
    return matrix.tocsc(canonical=True).col_norms()
def coo_column_norms(a: COOArray) -> mx.array:
    """Alias for :func:`coo_col_norms`.

    The argument is forwarded unchanged and the result is returned as-is,
    so validation and errors are exactly those of :func:`coo_col_norms`.
    """
    return coo_col_norms(a)
def coo_diagonal(a: COOArray) -> mx.array:
    """Return the summed diagonal of a COO matrix, as computed natively.

    Args:
        a: Sparse matrix; only its type is checked here.
    """
    matrix = _ensure_coo_array("coo_diagonal", a)
    return _native.coo_diagonal(matrix.data, matrix.row, matrix.col, matrix.shape)
def coo_trace(a: COOArray) -> mx.array:
    """Return the trace of a COO matrix, as computed by the native kernel.

    Args:
        a: Sparse matrix; only its type is checked here.
    """
    matrix = _ensure_coo_array("coo_trace", a)
    return _native.coo_trace(matrix.data, matrix.row, matrix.col, matrix.shape)
def csc_row_sums(a: CSCArray) -> mx.array:
    """Sum the stored values of every row of a CSC matrix.

    Args:
        a: Sparse matrix; only its type is checked here.

    Returns:
        The native per-row reduction (presumably shape ``(n_rows,)``).
    """
    matrix = _ensure_csc_array("csc_row_sums", a)
    return _native.csc_row_sums(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def csc_col_sums(a: CSCArray) -> mx.array:
    """Sum the stored values of every column of a CSC matrix.

    Args:
        a: Sparse matrix; only its type is checked here.

    Returns:
        The native per-column reduction (presumably shape ``(n_cols,)``).
    """
    matrix = _ensure_csc_array("csc_col_sums", a)
    return _native.csc_col_sums(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def csc_column_sums(a: CSCArray) -> mx.array:
    """Alias for :func:`csc_col_sums`.

    The argument is forwarded unchanged and the result is returned as-is,
    so validation and errors are exactly those of :func:`csc_col_sums`.
    """
    return csc_col_sums(a)
def csc_row_norms(a: CSCArray) -> mx.array:
    """Return the L2 norm of each row of a CSC matrix.

    Inputs without canonical format are canonicalized first. The native
    kernel is then told it may assume canonical input.
    """
    matrix = _ensure_csc_array("csc_row_norms", a)
    canonical = matrix if matrix.has_canonical_format else matrix.canonicalize()
    return _native.csc_row_norms(
        canonical.data,
        canonical.indices,
        canonical.indptr,
        canonical.shape,
        assume_canonical=True,
    )
def csc_col_norms(a: CSCArray) -> mx.array:
    """Return the L2 norm of each column of a CSC matrix.

    Inputs without canonical format are canonicalized first. The native
    kernel is then told it may assume canonical input.
    """
    matrix = _ensure_csc_array("csc_col_norms", a)
    canonical = matrix if matrix.has_canonical_format else matrix.canonicalize()
    return _native.csc_col_norms(
        canonical.data,
        canonical.indices,
        canonical.indptr,
        canonical.shape,
        assume_canonical=True,
    )
def csc_column_norms(a: CSCArray) -> mx.array:
    """Alias for :func:`csc_col_norms`.

    The argument is forwarded unchanged and the result is returned as-is,
    so validation and errors are exactly those of :func:`csc_col_norms`.
    """
    return csc_col_norms(a)
def csc_diagonal(a: CSCArray) -> mx.array:
    """Return the summed diagonal of a CSC matrix, as computed natively.

    Args:
        a: Sparse matrix; only its type is checked here.
    """
    matrix = _ensure_csc_array("csc_diagonal", a)
    return _native.csc_diagonal(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
def csc_trace(a: CSCArray) -> mx.array:
    """Return the trace of a CSC matrix, as computed by the native kernel.

    Args:
        a: Sparse matrix; only its type is checked here.
    """
    matrix = _ensure_csc_array("csc_trace", a)
    return _native.csc_trace(
        matrix.data, matrix.indices, matrix.indptr, matrix.shape
    )
[docs]
def csc_matvec(a: CSCArray, x) -> mx.array:
    """Compute ``A @ x`` for a CSC sparse matrix ``A`` and a dense vector ``x``.

    Args:
        a: Sparse matrix of shape ``(n_rows, n_cols)``.
        x: Dense vector; converted with ``ensure_mx_array``.

    Raises:
        TypeError: If ``a`` is not a :class:`CSCArray`. Shape/dtype errors
            come from ``validate_csc_matvec_inputs``.
    """
    matrix = _ensure_csc_array("csc_matvec", a)
    vector = ensure_mx_array(x)
    operands = (matrix.data, matrix.indices, matrix.indptr, vector, matrix.shape)
    validate_csc_matvec_inputs(*operands)
    return _native.csc_matvec(*operands)
[docs]
def coo_matvec(a: COOArray, x) -> mx.array:
    """Compute ``A @ x`` for a COO sparse matrix ``A`` and a dense vector ``x``.

    Args:
        a: Sparse matrix of shape ``(n_rows, n_cols)``.
        x: Dense vector; converted with ``ensure_mx_array``.

    Raises:
        TypeError: If ``a`` is not a :class:`COOArray`. Shape/dtype errors
            come from ``validate_coo_matvec_inputs``.
    """
    matrix = _ensure_coo_array("coo_matvec", a)
    vector = ensure_mx_array(x)
    operands = (matrix.data, matrix.row, matrix.col, vector, matrix.shape)
    validate_coo_matvec_inputs(*operands)
    return _native.coo_matvec(*operands)
[docs]
def csc_matvec_transpose(a: CSCArray, x) -> mx.array:
    """Compute ``A.T @ x`` for a CSC sparse matrix ``A`` and a dense vector ``x``.

    Args:
        a: Sparse matrix of shape ``(n_rows, n_cols)``.
        x: Dense vector; converted with ``ensure_mx_array``.

    Raises:
        TypeError: If ``a`` is not a :class:`CSCArray`. Shape/dtype errors
            come from ``validate_csc_matvec_transpose_inputs``.
    """
    matrix = _ensure_csc_array("csc_matvec_transpose", a)
    vector = ensure_mx_array(x)
    operands = (matrix.data, matrix.indices, matrix.indptr, vector, matrix.shape)
    validate_csc_matvec_transpose_inputs(*operands)
    return _native.csc_matvec_transpose(*operands)
[docs]
def coo_batched_matvec(a: COOArray, rhs) -> mx.array:
    """Multiply a COO sparse matrix by a batch of dense vectors.

    ``rhs`` has shape ``(..., n_cols)``. The leading dimensions are flattened
    into a single batch axis for the native kernel and restored on the
    output, which has shape ``(..., n_rows)``.

    Raises:
        TypeError: If ``a`` is not a :class:`COOArray`, or if ``a.data`` and
            ``rhs`` have different dtypes.
        ValueError: If ``rhs`` has rank < 2 or its last dimension differs
            from ``n_cols``.
    """
    matrix = _ensure_coo_array("coo_batched_matvec", a)
    vectors = ensure_mx_array(rhs)
    if vectors.ndim < 2:
        raise ValueError(
            f"coo_batched_matvec expects rank-2 or higher RHS, got {vectors.shape}."
        )
    n_cols = matrix.shape[1]
    if vectors.shape[-1] != n_cols:
        raise ValueError(
            f"coo_batched_matvec RHS has vector dimension {vectors.shape[-1]}, "
            f"but sparse n_cols={n_cols}."
        )
    if matrix.data.dtype != vectors.dtype:
        raise TypeError(
            "coo_batched_matvec requires sparse data and RHS to have the same dtype, "
            f"got {matrix.data.dtype} and {vectors.dtype}."
        )
    lead_dims = tuple(map(int, vectors.shape[:-1]))
    flat = mx.reshape(vectors, (_prod(lead_dims), n_cols))
    result = _native.coo_batched_matvec(
        matrix.data, matrix.row, matrix.col, flat, matrix.shape
    )
    return mx.reshape(result, (*lead_dims, matrix.shape[0]))
[docs]
def csc_batched_matvec(a: CSCArray, rhs) -> mx.array:
    """Multiply a CSC sparse matrix by a batch of dense vectors.

    ``rhs`` has shape ``(..., n_cols)``. The leading dimensions are flattened
    into a single batch axis for the native kernel and restored on the
    output, which has shape ``(..., n_rows)``.

    Raises:
        TypeError: If ``a`` is not a :class:`CSCArray`, or if ``a.data`` and
            ``rhs`` have different dtypes.
        ValueError: If ``rhs`` has rank < 2 or its last dimension differs
            from ``n_cols``.
    """
    matrix = _ensure_csc_array("csc_batched_matvec", a)
    vectors = ensure_mx_array(rhs)
    if vectors.ndim < 2:
        raise ValueError(
            f"csc_batched_matvec expects rank-2 or higher RHS, got {vectors.shape}."
        )
    n_cols = matrix.shape[1]
    if vectors.shape[-1] != n_cols:
        raise ValueError(
            f"csc_batched_matvec RHS has vector dimension {vectors.shape[-1]}, "
            f"but sparse n_cols={n_cols}."
        )
    if matrix.data.dtype != vectors.dtype:
        raise TypeError(
            "csc_batched_matvec requires sparse data and RHS to have the same dtype, "
            f"got {matrix.data.dtype} and {vectors.dtype}."
        )
    lead_dims = tuple(map(int, vectors.shape[:-1]))
    flat = mx.reshape(vectors, (_prod(lead_dims), n_cols))
    result = _native.csc_batched_matvec(
        matrix.data, matrix.indices, matrix.indptr, flat, matrix.shape
    )
    return mx.reshape(result, (*lead_dims, matrix.shape[0]))
[docs]
def csr_matvec(a: CSRArray, x) -> mx.array:
    """Multiply a CSR sparse matrix by a dense vector.

    Computes ``y = A @ x``, where ``A`` is a :class:`~mlx_sparse.CSRArray` and
    ``x`` is a rank-1 dense array. The result is added to the MLX computation
    graph and is not evaluated eagerly.

    On Apple Silicon, the Metal backend dispatches a scalar row kernel for
    short rows and a vector-reduction kernel for long rows. CPU and GPU paths
    support ``float32``, ``float16``, ``bfloat16``, and ``complex64`` values
    with ``int32`` or ``int64`` indices.

    Args:
        a: The sparse matrix, shape ``(n_rows, n_cols)``.
        x: Dense vector, shape ``(n_cols,)``. Converted to ``mx.array`` if
            needed. Must have the same dtype as ``a.data``.

    Returns:
        Dense vector of shape ``(n_rows,)`` with the same dtype as ``a.data``.

    Raises:
        TypeError: If ``a`` is not a :class:`~mlx_sparse.CSRArray`, or if the
            dtypes of ``a.data`` and ``x`` do not match.
        ValueError: If shape constraints are violated.

    Example::

        import mlx.core as mx
        import mlx_sparse as ms
        y = a @ x                # preferred via __matmul__
        y = ms.csr_matvec(a, x)  # explicit call
        mx.eval(y)
    """
    if not isinstance(a, CSRArray):
        kind = type(a).__name__
        raise TypeError(f"csr_matvec expects CSRArray, got {kind}.")
    vector = ensure_mx_array(x)
    operands = (a.data, a.indices, a.indptr, vector, a.shape)
    validate_csr_matvec_inputs(*operands)
    return _native.csr_matvec(*operands)
[docs]
def csr_batched_matvec(a: CSRArray, rhs) -> mx.array:
    """Multiply a CSR sparse matrix by a batch of dense vectors.

    Computes ``Y[b] = A @ X[b]`` for ``X`` with shape ``(..., n_cols)`` and
    returns shape ``(..., n_rows)``. Leading batch dimensions are flattened
    into one axis for the native batched CPU/Metal kernels, then restored.

    Raises:
        TypeError: If ``a`` is not a :class:`CSRArray`, or if ``a.data`` and
            ``rhs`` have different dtypes.
        ValueError: If ``rhs`` has rank < 2 or its last dimension differs
            from ``n_cols``.
    """
    if not isinstance(a, CSRArray):
        kind = type(a).__name__
        raise TypeError(f"csr_batched_matvec expects CSRArray, got {kind}.")
    vectors = ensure_mx_array(rhs)
    if vectors.ndim < 2:
        raise ValueError(
            f"csr_batched_matvec expects rank-2 or higher RHS, got {vectors.shape}."
        )
    n_cols = a.shape[1]
    if vectors.shape[-1] != n_cols:
        raise ValueError(
            f"csr_batched_matvec RHS has vector dimension {vectors.shape[-1]}, "
            f"but sparse n_cols={n_cols}."
        )
    if a.data.dtype != vectors.dtype:
        raise TypeError(
            "csr_batched_matvec requires sparse data and RHS to have the same dtype, "
            f"got {a.data.dtype} and {vectors.dtype}."
        )
    lead_dims = tuple(map(int, vectors.shape[:-1]))
    flat = mx.reshape(vectors, (_prod(lead_dims), n_cols))
    result = _native.csr_batched_matvec(a.data, a.indices, a.indptr, flat, a.shape)
    return mx.reshape(result, (*lead_dims, a.shape[0]))
[docs]
def csr_matmat(a: CSRArray, rhs: CSRArray) -> CSRArray:
    """Multiply two CSR sparse matrices and return a canonical CSR matrix.

    Computes ``C = A @ B``, where both ``A`` and ``B`` are
    :class:`~mlx_sparse.CSRArray` instances. The output sparsity pattern is
    not known at graph-build time. This operation therefore runs a native C++
    structural assembly pass on the host, calling ``mx.eval`` on the input
    arrays internally. It returns a new :class:`~mlx_sparse.CSRArray` in
    canonical format.

    Because the output size is data-dependent, this operation cannot be
    represented as a fixed-shape MLX primitive. It suits one-shot matrix
    products and matrix-power computations, but should not be used inside a
    JIT-compiled function.

    Args:
        a: Left-hand sparse matrix, shape ``(m, k)``.
        rhs: Right-hand sparse matrix, shape ``(k, n)``.

    Returns:
        A canonical :class:`~mlx_sparse.CSRArray` with shape ``(m, n)``,
        ``has_canonical_format=True``, and ``sorted_indices=True``.

    Raises:
        TypeError: If either argument is not a :class:`~mlx_sparse.CSRArray`,
            or if the two operands have different value dtypes.
        ValueError: If the inner dimensions do not match
            (``a.shape[1] != rhs.shape[0]``).

    Example::

        import mlx_sparse as ms
        # Compute the square of a sparse matrix
        C = A @ A                 # dispatches csr_matmat when A is CSRArray
        C = ms.csr_matmat(A, A)   # explicit call
        # Chain sparse matrix products
        D = ms.csr_matmat(ms.csr_matmat(A, B), C)
    """
    if not isinstance(a, CSRArray):
        raise TypeError(f"csr_matmat expects CSRArray lhs, got {type(a).__name__}.")
    if not isinstance(rhs, CSRArray):
        raise TypeError(f"csr_matmat expects CSRArray rhs, got {type(rhs).__name__}.")
    # Enforce the documented ValueError in Python, before any native/host work.
    # coo_matmat and csc_matmat perform the same check in the same position.
    if a.shape[1] != rhs.shape[0]:
        raise ValueError(
            f"CSR sparse-sparse matmul dimension mismatch: {a.shape} @ {rhs.shape}."
        )
    if a.data.dtype != rhs.data.dtype:
        raise TypeError(
            "CSR sparse-sparse matmul requires matching value dtypes, "
            f"got {a.data.dtype} and {rhs.data.dtype}."
        )
    data, indices, indptr = _native.csr_matmat(a, rhs)
    return CSRArray(
        data=data,
        indices=indices,
        indptr=indptr,
        shape=(a.shape[0], rhs.shape[1]),
        sorted_indices=True,
        has_canonical_format=True,
    )
[docs]
def coo_matmat(a: COOArray, rhs: COOArray) -> COOArray:
    """Multiply two COO sparse matrices and return a canonical COO matrix.

    The native implementation groups both operands by coordinate rows, runs a
    symbolic row pass to size the result, then fills sorted output
    coordinates without routing through CSR.

    Args:
        a: Left-hand sparse matrix, shape ``(m, k)``.
        rhs: Right-hand sparse matrix, shape ``(k, n)``.

    Returns:
        A :class:`COOArray` of shape ``(m, n)`` flagged
        ``has_canonical_format=True``.

    Raises:
        TypeError: If either operand is not a :class:`COOArray`, or if the
            value dtypes differ.
        ValueError: If ``a.shape[1] != rhs.shape[0]``.
    """
    lhs = _ensure_coo_array("coo_matmat", a)
    if not isinstance(rhs, COOArray):
        raise TypeError(f"coo_matmat expects COOArray rhs, got {type(rhs).__name__}.")
    if lhs.shape[1] != rhs.shape[0]:
        raise ValueError(
            f"COO sparse-sparse matmul dimension mismatch: {lhs.shape} @ {rhs.shape}."
        )
    if lhs.data.dtype != rhs.data.dtype:
        raise TypeError(
            "COO sparse-sparse matmul requires matching value dtypes, "
            f"got {lhs.data.dtype} and {rhs.data.dtype}."
        )
    values, rows, cols = _native.coo_matmat(lhs, rhs)
    out_shape = (lhs.shape[0], rhs.shape[1])
    return COOArray(
        data=values,
        row=rows,
        col=cols,
        shape=out_shape,
        has_canonical_format=True,
    )
[docs]
def csc_matmat(a: CSCArray, rhs: CSCArray) -> CSCArray:
    """Multiply two CSC sparse matrices and return a canonical CSC matrix.

    The native implementation walks right-hand columns and left-hand
    compressed columns directly, producing sorted row indices per output
    column. It does not convert to CSR internally.

    Args:
        a: Left-hand sparse matrix, shape ``(m, k)``.
        rhs: Right-hand sparse matrix, shape ``(k, n)``.

    Returns:
        A :class:`CSCArray` of shape ``(m, n)`` flagged
        ``sorted_indices=True`` and ``has_canonical_format=True``.

    Raises:
        TypeError: If either operand is not a :class:`CSCArray`, or if the
            value dtypes differ.
        ValueError: If ``a.shape[1] != rhs.shape[0]``.
    """
    lhs = _ensure_csc_array("csc_matmat", a)
    if not isinstance(rhs, CSCArray):
        raise TypeError(f"csc_matmat expects CSCArray rhs, got {type(rhs).__name__}.")
    if lhs.shape[1] != rhs.shape[0]:
        raise ValueError(
            f"CSC sparse-sparse matmul dimension mismatch: {lhs.shape} @ {rhs.shape}."
        )
    if lhs.data.dtype != rhs.data.dtype:
        raise TypeError(
            "CSC sparse-sparse matmul requires matching value dtypes, "
            f"got {lhs.data.dtype} and {rhs.data.dtype}."
        )
    values, row_indices, col_ptr = _native.csc_matmat(lhs, rhs)
    out_shape = (lhs.shape[0], rhs.shape[1])
    return CSCArray(
        data=values,
        indices=row_indices,
        indptr=col_ptr,
        shape=out_shape,
        sorted_indices=True,
        has_canonical_format=True,
    )
def _csr_matmul_rank2(a: CSRArray, rhs: mx.array) -> mx.array:
    """Validate operands, then run the native CSR @ rank-2 dense kernel."""
    operands = (a.data, a.indices, a.indptr, rhs, a.shape)
    validate_csr_matmul_inputs(*operands)
    return _native.csr_matmul(*operands)
def _coo_matmul_rank2(a: COOArray, rhs: mx.array) -> mx.array:
    """Validate operands, then run the native COO @ rank-2 dense kernel."""
    operands = (a.data, a.row, a.col, rhs, a.shape)
    validate_coo_matmul_inputs(*operands)
    return _native.coo_matmul(*operands)
def _csc_matmul_rank2(a: CSCArray, rhs: mx.array) -> mx.array:
    """Validate operands, then run the native CSC @ rank-2 dense kernel."""
    operands = (a.data, a.indices, a.indptr, rhs, a.shape)
    validate_csc_matmul_inputs(*operands)
    return _native.csc_matmul(*operands)
def _csr_matmul_batched(a: CSRArray, rhs: mx.array) -> mx.array:
    """Run the native batched CSR kernel on ``rhs`` of shape ``(..., n_cols, k)``.

    Leading batch dimensions are flattened to one axis for the kernel and
    restored on the ``(..., n_rows, k)`` output. Only the dtype is checked
    here; callers check rank and the sparse dimension.
    """
    if a.data.dtype != rhs.dtype:
        raise TypeError(
            "csr_matmul requires sparse data and RHS to have the same dtype, "
            f"got {a.data.dtype} and {rhs.dtype}."
        )
    lead_dims = tuple(map(int, rhs.shape[:-2]))
    k = int(rhs.shape[-1])
    flat_rhs = mx.reshape(rhs, (_prod(lead_dims), a.shape[1], k))
    flat_out = _native.csr_batched_matmul(
        a.data, a.indices, a.indptr, flat_rhs, a.shape
    )
    return mx.reshape(flat_out, (*lead_dims, a.shape[0], k))
def _coo_matmul_batched(a: COOArray, rhs: mx.array) -> mx.array:
    """Run the native batched COO kernel on ``rhs`` of shape ``(..., n_cols, k)``.

    Leading batch dimensions are flattened to one axis for the kernel and
    restored on the ``(..., n_rows, k)`` output. Only the dtype is checked
    here; callers check rank and the sparse dimension.
    """
    if a.data.dtype != rhs.dtype:
        raise TypeError(
            "coo_matmul requires sparse data and RHS to have the same dtype, "
            f"got {a.data.dtype} and {rhs.dtype}."
        )
    lead_dims = tuple(map(int, rhs.shape[:-2]))
    k = int(rhs.shape[-1])
    flat_rhs = mx.reshape(rhs, (_prod(lead_dims), a.shape[1], k))
    flat_out = _native.coo_batched_matmul(a.data, a.row, a.col, flat_rhs, a.shape)
    return mx.reshape(flat_out, (*lead_dims, a.shape[0], k))
def _csc_matmul_batched(a: CSCArray, rhs: mx.array) -> mx.array:
    """Run the native batched CSC kernel on ``rhs`` of shape ``(..., n_cols, k)``.

    Leading batch dimensions are flattened to one axis for the kernel and
    restored on the ``(..., n_rows, k)`` output. Only the dtype is checked
    here; callers check rank and the sparse dimension.
    """
    if a.data.dtype != rhs.dtype:
        raise TypeError(
            "csc_matmul requires sparse data and RHS to have the same dtype, "
            f"got {a.data.dtype} and {rhs.dtype}."
        )
    lead_dims = tuple(map(int, rhs.shape[:-2]))
    k = int(rhs.shape[-1])
    flat_rhs = mx.reshape(rhs, (_prod(lead_dims), a.shape[1], k))
    flat_out = _native.csc_batched_matmul(
        a.data, a.indices, a.indptr, flat_rhs, a.shape
    )
    return mx.reshape(flat_out, (*lead_dims, a.shape[0], k))
[docs]
def coo_batched_matmul(a: COOArray, rhs) -> mx.array:
    """Multiply a COO sparse matrix by a batch of dense matrices.

    ``rhs`` must have shape ``(..., n_cols, k)`` with at least one batch
    dimension; the result has shape ``(..., n_rows, k)``.

    Raises:
        TypeError: If ``a`` is not a :class:`COOArray`, or on dtype mismatch.
        ValueError: If ``rhs`` has rank < 3 or ``rhs.shape[-2] != n_cols``.
    """
    matrix = _ensure_coo_array("coo_batched_matmul", a)
    dense = ensure_mx_array(rhs)
    if dense.ndim < 3:
        raise ValueError(
            f"coo_batched_matmul expects rank-3 or higher RHS, got {dense.shape}."
        )
    sparse_dim = dense.shape[-2]
    if sparse_dim != matrix.shape[1]:
        raise ValueError(
            f"coo_batched_matmul RHS has sparse dimension {sparse_dim}, "
            f"but sparse n_cols={matrix.shape[1]}."
        )
    return _coo_matmul_batched(matrix, dense)
[docs]
def csc_batched_matmul(a: CSCArray, rhs) -> mx.array:
    """Multiply a CSC sparse matrix by a batch of dense matrices.

    ``rhs`` must have shape ``(..., n_cols, k)`` with at least one batch
    dimension; the result has shape ``(..., n_rows, k)``.

    Raises:
        TypeError: If ``a`` is not a :class:`CSCArray`, or on dtype mismatch.
        ValueError: If ``rhs`` has rank < 3 or ``rhs.shape[-2] != n_cols``.
    """
    matrix = _ensure_csc_array("csc_batched_matmul", a)
    dense = ensure_mx_array(rhs)
    if dense.ndim < 3:
        raise ValueError(
            f"csc_batched_matmul expects rank-3 or higher RHS, got {dense.shape}."
        )
    sparse_dim = dense.shape[-2]
    if sparse_dim != matrix.shape[1]:
        raise ValueError(
            f"csc_batched_matmul RHS has sparse dimension {sparse_dim}, "
            f"but sparse n_cols={matrix.shape[1]}."
        )
    return _csc_matmul_batched(matrix, dense)
[docs]
def coo_matmul(a: COOArray, rhs) -> mx.array:
    """Multiply a COO sparse matrix by a dense matrix or a batch of matrices.

    A rank-2 ``rhs`` goes to the validated rank-2 native kernel. Higher-rank
    ``rhs`` must carry the sparse dimension at ``rhs.shape[-2]`` and goes to
    the batched kernel.

    Raises:
        TypeError: If ``a`` is not a :class:`COOArray`, or on dtype mismatch.
        ValueError: If ``rhs`` has rank < 2 or a mismatched sparse dimension.
    """
    matrix = _ensure_coo_array("coo_matmul", a)
    dense = ensure_mx_array(rhs)
    rank = dense.ndim
    if rank < 2:
        raise ValueError(f"coo_matmul expects rank-2 or higher RHS, got {dense.shape}.")
    if rank == 2:
        return _coo_matmul_rank2(matrix, dense)
    sparse_dim = dense.shape[-2]
    if sparse_dim != matrix.shape[1]:
        raise ValueError(
            f"coo_matmul RHS has sparse dimension {sparse_dim}, "
            f"but sparse n_cols={matrix.shape[1]}."
        )
    return _coo_matmul_batched(matrix, dense)
[docs]
def csc_matmul(a: CSCArray, rhs) -> mx.array:
    """Multiply a CSC sparse matrix by a dense matrix or a batch of matrices.

    A rank-2 ``rhs`` goes to the validated rank-2 native kernel. Higher-rank
    ``rhs`` must carry the sparse dimension at ``rhs.shape[-2]`` and goes to
    the batched kernel.

    Raises:
        TypeError: If ``a`` is not a :class:`CSCArray`, or on dtype mismatch.
        ValueError: If ``rhs`` has rank < 2 or a mismatched sparse dimension.
    """
    matrix = _ensure_csc_array("csc_matmul", a)
    dense = ensure_mx_array(rhs)
    rank = dense.ndim
    if rank < 2:
        raise ValueError(f"csc_matmul expects rank-2 or higher RHS, got {dense.shape}.")
    if rank == 2:
        return _csc_matmul_rank2(matrix, dense)
    sparse_dim = dense.shape[-2]
    if sparse_dim != matrix.shape[1]:
        raise ValueError(
            f"csc_matmul RHS has sparse dimension {sparse_dim}, "
            f"but sparse n_cols={matrix.shape[1]}."
        )
    return _csc_matmul_batched(matrix, dense)
[docs]
def csr_batched_matmul(a: CSRArray, rhs) -> mx.array:
    """Multiply a CSR sparse matrix by a batch of dense matrices.

    ``rhs`` must have shape ``(..., n_cols, k)`` and the result has shape
    ``(..., n_rows, k)``. For rank-2 dense matrices, use :func:`csr_matmul`.

    Raises:
        TypeError: If ``a`` is not a :class:`CSRArray`, or on dtype mismatch.
        ValueError: If ``rhs`` has rank < 3 or ``rhs.shape[-2] != n_cols``.
    """
    if not isinstance(a, CSRArray):
        kind = type(a).__name__
        raise TypeError(f"csr_batched_matmul expects CSRArray, got {kind}.")
    dense = ensure_mx_array(rhs)
    if dense.ndim < 3:
        raise ValueError(
            f"csr_batched_matmul expects rank-3 or higher RHS, got {dense.shape}."
        )
    sparse_dim = dense.shape[-2]
    if sparse_dim != a.shape[1]:
        raise ValueError(
            f"csr_batched_matmul RHS has sparse dimension {sparse_dim}, "
            f"but sparse n_cols={a.shape[1]}."
        )
    return _csr_matmul_batched(a, dense)
[docs]
def csr_matmul(a: CSRArray, rhs) -> mx.array:
    """Multiply a CSR sparse matrix by a dense matrix.

    Computes ``Y = A @ B``, where ``A`` is a :class:`~mlx_sparse.CSRArray` and
    ``B`` is a rank-2 or batched dense array. The result is added to the MLX
    computation graph and is not evaluated eagerly.

    On Apple Silicon, the Metal backend dispatches scalar output-element
    kernels for short rows and vector-reduction kernels for long rows. CPU and
    GPU paths support ``float32``, ``float16``, ``bfloat16``, and ``complex64``
    values with ``int32`` or ``int64`` indices.

    Args:
        a: The sparse matrix, shape ``(n_rows, n_cols)``.
        rhs: Dense matrix of shape ``(n_cols, k)``, or a batched dense matrix
            with the sparse dimension at ``rhs.shape[-2]``. Converted to
            ``mx.array`` if needed. Must have the same dtype as ``a.data``.

    Returns:
        Dense matrix or batched dense matrix with the sparse dimension
        replaced by ``n_rows`` and the same dtype as ``a.data``.

    Raises:
        TypeError: If ``a`` is not a :class:`~mlx_sparse.CSRArray`, or if
            dtype constraints are violated.
        ValueError: If shape constraints are violated.

    Example::

        import mlx.core as mx
        import mlx_sparse as ms
        Y = a @ B                # preferred via __matmul__
        Y = ms.csr_matmul(a, B)  # explicit call
        mx.eval(Y)
    """
    if not isinstance(a, CSRArray):
        kind = type(a).__name__
        raise TypeError(f"csr_matmul expects CSRArray, got {kind}.")
    dense = ensure_mx_array(rhs)
    rank = dense.ndim
    if rank < 2:
        raise ValueError(f"csr_matmul expects rank-2 or higher RHS, got {dense.shape}.")
    if rank == 2:
        return _csr_matmul_rank2(a, dense)
    sparse_dim = dense.shape[-2]
    if sparse_dim != a.shape[1]:
        raise ValueError(
            f"csr_matmul RHS has sparse dimension {sparse_dim}, "
            f"but sparse n_cols={a.shape[1]}."
        )
    return _csr_matmul_batched(a, dense)