# Source code for mlx_sparse._csc

# Copyright (c) 2026 The mlx-sparse contributors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass

import mlx.core as mx

import mlx_sparse._native as _native
from mlx_sparse._csr import CSRArray
from mlx_sparse._typing import Shape2D, ValidationMode
from mlx_sparse._validation import (
    ensure_mx_array,
    normalize_shape,
    normalize_validation_mode,
    sanitize_scalar,
    validate_csc_metadata,
    validate_csc_values,
)


[docs] @dataclass(frozen=True, slots=True) class CSCArray: """A 2D sparse matrix in Compressed Sparse Column (CSC) format. CSCArray stores a 2D sparse matrix using three MLX arrays: - **data**: non-zero values, shape ``(nnz,)``. - **indices**: row index of each stored value, shape ``(nnz,)``. - **indptr**: column pointer array, shape ``(n_cols + 1,)``. Column ``j`` spans ``data[indptr[j] : indptr[j+1]]`` with corresponding row indices ``indices[indptr[j] : indptr[j+1]]``. Duplicate row entries within a column are permitted unless the matrix is in canonical form. CSC is the column-compressed dual of CSR. It is the natural layout for operations that consume one full column at a time, such as transpose matvec, column-oriented canonicalization, and future direct factorization kernels. Args: data: Non-zero values, shape ``(nnz,)``. indices: Row indices, shape ``(nnz,)``. indptr: Column pointer array, shape ``(n_cols + 1,)``. shape: Matrix dimensions as ``(n_rows, n_cols)``. sorted_indices: Hint that row indices within each column are sorted ascending. Defaults to ``False``. has_canonical_format: Hint that the matrix has sorted row indices and no duplicate row index in any column. Implies ``sorted_indices=True``. Defaults to ``False``. """ data: mx.array indices: mx.array indptr: mx.array shape: Shape2D sorted_indices: bool = False has_canonical_format: bool = False def __post_init__(self) -> None: object.__setattr__(self, "shape", normalize_shape(self.shape)) @property def nnz(self) -> int: """Number of stored values (including any duplicates).""" return int(self.data.shape[0]) @property def dtype(self): """Value dtype of the stored non-zeros.""" return self.data.dtype @property def index_dtype(self): """Integer dtype used for ``indices`` and ``indptr``.""" return self.indices.dtype @property def ndim(self) -> int: """Always 2. 
Sparse arrays in this package are rank-2.""" return 2 def __repr__(self) -> str: return ( "CSCArray(" f"shape={self.shape}, nnz={self.nnz}, dtype={self.dtype}, " f"index_dtype={self.index_dtype}, " f"sorted_indices={self.sorted_indices}, " f"has_canonical_format={self.has_canonical_format})" )
[docs] def todense(self) -> mx.array: """Materialize the sparse matrix as a dense MLX array.""" return _native.csc_todense(self.data, self.indices, self.indptr, self.shape)
[docs] def row_sums(self) -> mx.array: """Return the sum of stored values in each CSC row.""" from mlx_sparse._ops import csc_row_sums return csc_row_sums(self)
[docs] def col_sums(self) -> mx.array: """Return the sum of stored values in each CSC column.""" from mlx_sparse._ops import csc_col_sums return csc_col_sums(self)
[docs] def column_sums(self) -> mx.array: """Alias for :meth:`col_sums`.""" return self.col_sums()
[docs] def row_norms(self) -> mx.array: """Return the dense-semantics L2 norm of each CSC row as ``float32``.""" from mlx_sparse._ops import csc_row_norms return csc_row_norms(self)
[docs] def col_norms(self) -> mx.array: """Return the dense-semantics L2 norm of each CSC column as ``float32``.""" from mlx_sparse._ops import csc_col_norms return csc_col_norms(self)
[docs] def column_norms(self) -> mx.array: """Alias for :meth:`col_norms`.""" return self.col_norms()
[docs] def diagonal(self) -> mx.array: """Return the summed main diagonal.""" from mlx_sparse._ops import csc_diagonal return csc_diagonal(self)
[docs] def trace(self) -> mx.array: """Return the summed main diagonal as a scalar.""" from mlx_sparse._ops import csc_trace return csc_trace(self)
[docs] def sum(self, axis=None) -> mx.array: """Sum sparse values over all entries, rows, or columns. ``axis=None`` returns a scalar, ``axis=1`` returns row sums, and ``axis=0`` returns column sums. """ if axis is None: return mx.sum(self.col_sums()) if axis in (1, -1): return self.row_sums() if axis in (0, -2): return self.col_sums() raise ValueError(f"CSCArray.sum axis must be None, 0, or 1; got {axis!r}.")
[docs] def tocsr(self, *, canonical: bool | None = None) -> CSRArray: """Convert to :class:`~mlx_sparse.CSRArray`. Args: canonical: If ``True``, canonicalize the returned CSR matrix. If ``False`` or ``None`` (default), return the structural conversion as produced by the native count/prefix/fill path. The structural path preserves values but does not promise sorted output metadata on every backend. """ data, indices, indptr = _native.csc_tocsr( self.data, self.indices, self.indptr, self.shape, ) out = CSRArray( data=data, indices=indices, indptr=indptr, shape=self.shape, sorted_indices=False, has_canonical_format=False, ) if canonical is True: return out.canonicalize() if canonical is False: return CSRArray( data=out.data, indices=out.indices, indptr=out.indptr, shape=out.shape, sorted_indices=out.sorted_indices, has_canonical_format=False, ) return out
[docs] def sort_indices(self) -> "CSCArray": """Return a new CSCArray with row indices sorted within each column.""" if self.sorted_indices: return self data, indices, indptr = _native.csc_sort_indices( self.data, self.indices, self.indptr, ) return CSCArray( data=data, indices=indices, indptr=indptr, shape=self.shape, sorted_indices=True, has_canonical_format=False, )
[docs] def sum_duplicates(self) -> "CSCArray": """Sum duplicate row entries within each column.""" sorted_self = self.sort_indices() data, indices, indptr = _native.csc_sum_duplicates( sorted_self.data, sorted_self.indices, sorted_self.indptr, ) return CSCArray( data=data, indices=indices, indptr=indptr, shape=self.shape, sorted_indices=True, has_canonical_format=True, )
[docs] def canonicalize(self) -> "CSCArray": """Return canonical form: sorted row indices, no duplicates.""" if self.has_canonical_format: return self return self.sum_duplicates()
[docs] def transpose(self) -> CSRArray: """Transpose the sparse matrix, returning a zero-copy CSRArray.""" return CSRArray( data=self.data, indices=self.indices, indptr=self.indptr, shape=(self.shape[1], self.shape[0]), sorted_indices=self.sorted_indices, has_canonical_format=self.has_canonical_format, )
@property def T(self) -> CSRArray: """Transposed matrix. Alias for :meth:`transpose`.""" return self.transpose()
[docs] def conj(self) -> "CSCArray": """Complex-conjugate the stored values.""" return CSCArray( data=mx.conjugate(self.data), indices=self.indices, indptr=self.indptr, shape=self.shape, sorted_indices=self.sorted_indices, has_canonical_format=self.has_canonical_format, )
[docs] def conjugate(self) -> "CSCArray": """Alias for :meth:`conj`.""" return self.conj()
@property def H(self) -> CSRArray: """Hermitian (conjugate) transpose. Equivalent to ``conj().T``.""" return self.conj().transpose()
[docs] def __matmul__(self, rhs): """Matrix multiplication via the ``@`` operator.""" from mlx_sparse._coo import COOArray from mlx_sparse._ops import csc_matmat, csc_matmul, csc_matvec if isinstance(rhs, CSCArray): return csc_matmat(self, rhs) if isinstance(rhs, CSRArray | COOArray): raise NotImplementedError( "Mixed-format CSC sparse-sparse matmul is not implemented. " "Convert explicitly if another format is acceptable for your workload." ) rhs = ensure_mx_array(rhs) if rhs.ndim == 1: return csc_matvec(self, rhs) if rhs.ndim >= 2: return csc_matmul(self, rhs) raise ValueError(f"CSC matmul expects rank-1 or higher RHS, got {rhs.shape}.")
[docs] def __rmul__(self, other): """Multiply the current CSCArray by a number using the ``*`` operator. This returns a new CSCArray with the data multiplied by the number, and therefore does not in-place mutate the current CSCArray. Args: other: A valid number (complex or not). Returns: A new CSCArray with the data multiplied by the number. Raises: TypeError: If ``other`` is not an actual number. """ other = sanitize_scalar(other) return CSCArray( data=other * self.data, indices=self.indices, indptr=self.indptr, shape=self.shape, sorted_indices=self.sorted_indices, has_canonical_format=self.has_canonical_format, )
[docs] def __mul__(self, other): """Multiply the current CSCArray by a number using the ``*`` operator. This returns a new CSCArray with the data multiplied by the number, and therefore does not in-place mutate the current CSCArray. Args: other: A valid number (complex or not). Returns: A new CSCArray with the data multiplied by the number. Raises: TypeError: If ``other`` is not an actual number. """ return self.__rmul__(other)
def csc_array(
    arg,
    shape,
    *,
    validate: ValidationMode = "metadata",
    sorted_indices: bool = False,
    canonical: bool | None = None,
) -> CSCArray:
    """Construct a :class:`CSCArray` from explicit CSC buffers.

    Args:
        arg: Either an existing :class:`CSCArray`, or an iterable that
            unpacks into exactly ``(data, indices, indptr)``.
        shape: Matrix dimensions; normalized with ``normalize_shape``.
        validate: Validation mode. ``"none"`` skips all checks, any other
            mode runs the metadata checks, and ``"full"`` additionally runs
            the value checks.
        sorted_indices: Hint that row indices within each column are sorted.
            Forced to ``True`` when ``canonical`` is truthy.
        canonical: Hint that the buffers are already in canonical format.
            ``None`` (default) is treated as ``False``.

    Returns:
        ``arg`` itself when it is already a CSCArray with a matching shape
        (``validate``, ``sorted_indices`` and ``canonical`` are not applied in
        that case); otherwise a new CSCArray wrapping the given buffers.

    Raises:
        ValueError: If ``arg`` is a CSCArray whose shape differs from ``shape``.
        TypeError: If ``arg`` cannot be unpacked into three buffers.
    """
    mode = normalize_validation_mode(validate)
    shape = normalize_shape(shape)

    # An existing CSCArray is passed through after a shape check only.
    if isinstance(arg, CSCArray):
        if arg.shape != shape:
            raise ValueError(
                f"CSCArray shape mismatch: got {arg.shape}, expected {shape}."
            )
        return arg

    try:
        data, indices, indptr = arg
    except Exception as exc:
        raise TypeError(
            "csc_array expects (data, indices, indptr) or a CSCArray instance."
        ) from exc

    data, indices, indptr = (
        ensure_mx_array(data),
        ensure_mx_array(indices),
        ensure_mx_array(indptr),
    )

    if mode != "none":
        validate_csc_metadata(data, indices, indptr, shape)
        if mode == "full":
            validate_csc_values(indices, indptr, shape, data.shape[0])

    # ``None`` means "no canonical claim"; a canonical claim implies sorted.
    is_canonical = canonical is not None and bool(canonical)
    return CSCArray(
        data=data,
        indices=indices,
        indptr=indptr,
        shape=shape,
        sorted_indices=True if is_canonical else sorted_indices,
        has_canonical_format=is_canonical,
    )