Source code for mlx_sparse.linalg._sparse_ops
# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from mlx_sparse.linalg.utils.sparse import inner_product_csr as _as_csr
[docs]
def vdot(a, b):
    """Return the conjugating Frobenius inner product of two sparse matrices.

    The result is ``sum(conj(a) * b)`` accumulated over the stored non-zero
    pairs of the two operands. Each operand is first passed through the
    module's CSR helper; the reduction itself is delegated to the ``vdot``
    method of the converted left operand (the native CSR sorted-merge
    kernel). For real-valued matrices conjugation is a no-op, so the value
    equals ``sum(a * b)``.

    Args:
        a: Left operand, whose entries are conjugated. Must be a
            :class:`~mlx_sparse.CSRArray`, :class:`~mlx_sparse.COOArray`,
            or :class:`~mlx_sparse.CSCArray`.
        b: Right operand with the same shape as ``a``. Must be a
            :class:`~mlx_sparse.CSRArray`, :class:`~mlx_sparse.COOArray`,
            or :class:`~mlx_sparse.CSCArray`.

    Returns:
        A scalar ``mlx.core.array`` holding ``sum(conj(a) * b)``.

    Raises:
        TypeError: If ``a`` or ``b`` is not a supported sparse type.
    """
    # NOTE(review): ``_as_csr`` is the import alias of ``inner_product_csr``;
    # presumably it validates/converts an operand to CSR -- confirm against
    # ``mlx_sparse.linalg.utils.sparse``.
    lhs = _as_csr(a)
    # Look the method up before converting ``b`` so the left operand is
    # fully resolved first, exactly as in a single chained expression.
    conj_kernel = lhs.vdot
    rhs = _as_csr(b)
    return conj_kernel(rhs)
[docs]
def dot(a, b):
    """Return the non-conjugating Frobenius dot product of two sparse matrices.

    The result is ``sum(a * b)`` accumulated over the stored non-zero pairs
    of the two operands; no complex conjugation is applied to either one.
    Each operand is first passed through the module's CSR helper, and the
    reduction is delegated to the ``dot`` method of the converted left
    operand (the native CSR sorted-merge kernel).

    Args:
        a: Left operand. Must be a :class:`~mlx_sparse.CSRArray`,
            :class:`~mlx_sparse.COOArray`, or :class:`~mlx_sparse.CSCArray`.
        b: Right operand with the same shape as ``a``. Must be a
            :class:`~mlx_sparse.CSRArray`, :class:`~mlx_sparse.COOArray`,
            or :class:`~mlx_sparse.CSCArray`.

    Returns:
        A scalar ``mlx.core.array`` holding ``sum(a * b)``.

    Raises:
        TypeError: If ``a`` or ``b`` is not a supported sparse type.
    """
    # NOTE(review): ``_as_csr`` is the import alias of ``inner_product_csr``;
    # presumably it validates/converts an operand to CSR -- confirm against
    # ``mlx_sparse.linalg.utils.sparse``.
    left_csr = _as_csr(a)
    # Resolve the bound method on the left operand before the right operand
    # is converted, matching the order of a single chained expression.
    product_kernel = left_csr.dot
    right_csr = _as_csr(b)
    return product_kernel(right_csr)