Canonicalization and validation#

The CSR format has two optional invariants that make operations faster and results predictable:

  1. Sorted column indices: within each row, column indices appear in ascending order. Tracked by sorted_indices on the CSRArray.

  2. No duplicate entries: at most one stored value per (row, col) pair. When combined with sorted indices, this is called canonical form, tracked by has_canonical_format.

These flags are hints. mlx-sparse trusts them without re-checking by default. Passing validate="full" at construction time verifies them on the host.

import mlx.core as mx
import numpy as np
import mlx_sparse as ms

# NOTE(review): presumably selects the GPU as the device for the examples
# below — confirm against the mlx-sparse API documentation.
ms.use_gpu()

Building a non-canonical matrix#

We deliberately build a COO matrix with unsorted indices and duplicate entries to show what happens before and after canonicalization.

# Entries are written as (row, col, value) triplets so the layout is easy to scan:
#   row 0 -> col 2 first, then col 0 (deliberately out of order)
#   row 1 -> col 1 twice (1.0 + 2.0 sums to 3.0) and col 3
#   row 2 -> a single entry at col 0
triplets = [
    (0, 2, 1.0),
    (0, 0, 1.0),
    (1, 1, 1.0),
    (1, 1, 2.0),
    (1, 3, 5.0),
    (2, 0, 7.0),
]
data = mx.array([v for _, _, v in triplets], dtype=mx.float32)
row = mx.array([r for r, _, _ in triplets], dtype=mx.int32)
col = mx.array([c for _, c, _ in triplets], dtype=mx.int32)

coo = ms.coo_array((data, (row, col)), shape=(3, 4))
print("COO before canonicalization:\n", coo)

# canonical=False: the CSR result keeps whatever entry order the COO input had.
csr_raw = coo.tocsr(canonical=False)
mx.eval(csr_raw.indices)
print("\nCSR (not canonical, sorted_indices=False):\n", csr_raw)
print("indices:", np.array(csr_raw.indices))

# canonical=True: column indices are sorted and duplicates are summed.
csr_can = coo.tocsr(canonical=True)
mx.eval(csr_can.data, csr_can.indices)
print("\nCSR canonical (tocsr(canonical=True)):\n", csr_can)
for label, arr in (("indices:", csr_can.indices), ("data:   ", csr_can.data)):
    print(label, np.array(arr))
COO before canonicalization:
 COOArray(shape=(3, 4), nnz=6, dtype=float32, index_dtype=int32, has_canonical_format=False)

CSR (not canonical, sorted_indices=False):
 CSRArray(shape=(3, 4), nnz=6, dtype=float32, index_dtype=int32, sorted_indices=False, has_canonical_format=False)
indices: [2 0 1 1 3 0]

CSR canonical (tocsr(canonical=True)):
 CSRArray(shape=(3, 4), nnz=5, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
indices: [0 2 1 3 0]
data:    [1. 1. 3. 5. 7.]

canonicalize() on an existing CSRArray#

# Start from a fresh non-canonical CSR conversion of the same COO input.
csr_raw2 = coo.tocsr(canonical=False)
print(
    "Before canonicalize: "
    f"sorted_indices={csr_raw2.sorted_indices}  "
    f"has_canonical_format={csr_raw2.has_canonical_format}  "
    f"nnz={csr_raw2.nnz}"
)

# canonicalize() returns a new array; the flags and nnz are read from the result.
csr_fixed = csr_raw2.canonicalize()
mx.eval(csr_fixed.data)
print(
    "After canonicalize:  "
    f"sorted_indices={csr_fixed.sorted_indices}   "
    f"has_canonical_format={csr_fixed.has_canonical_format}   "
    f"nnz={csr_fixed.nnz}"
)

dense_host = np.array(csr_fixed.todense())
print("\nDense (canonical):\n", dense_host)
Before canonicalize: sorted_indices=False  has_canonical_format=False  nnz=6
After canonicalize:  sorted_indices=True   has_canonical_format=True   nnz=5

Dense (canonical):
 [[1. 0. 1. 0.]
 [0. 3. 0. 5.]
 [7. 0. 0. 0.]]

sort_indices(): sort only, preserve duplicates#

# The raw conversion still holds the duplicate (1, 1) entry; sort_indices()
# only reorders column indices within each row and leaves duplicates in place.
csr_unsorted = coo.tocsr(canonical=False)
mx.eval(csr_unsorted.indices)
indices_before = np.array(csr_unsorted.indices)
print(
    f"Before sort_indices: indices={indices_before}  "
    f"sorted_indices={csr_unsorted.sorted_indices}"
)

csr_sorted = csr_unsorted.sort_indices()
mx.eval(csr_sorted.indices)
indices_after = np.array(csr_sorted.indices)
print(
    f"After sort_indices:  indices={indices_after}  "
    f"sorted_indices={csr_sorted.sorted_indices}"
)
print(f"Note: nnz still = {csr_sorted.nnz} (duplicates kept)")
Before sort_indices: indices=[2 0 1 1 3 0]  sorted_indices=False
After sort_indices:  indices=[0 2 1 1 3 0]  sorted_indices=True
Note: nnz still = 6 (duplicates kept)

sum_duplicates(): merge duplicate entries#

# Sort first so that only the duplicate-merging step remains to be shown.
csr_unsorted2 = coo.tocsr(canonical=False)
csr_sorted2 = csr_unsorted2.sort_indices()
print(
    f"Before sum_duplicates: nnz={csr_sorted2.nnz}  "
    f"has_canonical_format={csr_sorted2.has_canonical_format}"
)

csr_summed = csr_sorted2.sum_duplicates()
mx.eval(csr_summed.data)
print(
    f"After sum_duplicates:  nnz={csr_summed.nnz}  "
    f"has_canonical_format={csr_summed.has_canonical_format}"
)

# Row 1 spans data[indptr[1]:indptr[2]]; after merging it holds
# col 1 = 1.0 + 2.0 = 3.0 and col 3 = 5.0.
indptr_host = np.array(csr_summed.indptr)
row1_start, row1_stop = indptr_host[1], indptr_host[2]
row1_data = np.array(csr_summed.data)[row1_start:row1_stop]
print(f"Row 1 data (duplicate col 1 summed): {row1_data}")
Before sum_duplicates: nnz=6  has_canonical_format=False
After sum_duplicates:  nnz=5  has_canonical_format=True
Row 1 data (duplicate col 1 summed): [3. 5.]

Validation levels#

Pass validate= to csr_array() or coo_array() to control how much checking is done at construction time.

| Level | What is checked |
| --- | --- |
| "metadata" (default) | Shapes, dtypes, lengths only (fast, no host sync) |
| "full" / True | All metadata checks plus value bounds (requires mx.eval) |
| False / "none" | No checks at all (caller guarantees correctness) |

import scipy.sparse

# Use SciPy to generate a small random CSR matrix as well-formed input.
rng = np.random.default_rng(5)
sp = scipy.sparse.random(32, 32, density=0.1, format="csr",
                         dtype=np.float32, random_state=rng)

# Build the same matrix once per validation level. Each entry is
# (validate argument, label to print, short note on what that level does).
for level, label, note in [
    ("metadata", "metadata validate", "fast path"),
    ("full", "full validate", "bounds checked"),
    (False, "no validate", "trusted unconditionally"),
]:
    # sorted_indices/canonical are caller-supplied hints.
    # NOTE(review): assumes SciPy's CSR output is already sorted and
    # duplicate-free — confirm if the generator call changes.
    A = ms.csr_array(
        (mx.array(sp.data),
         mx.array(sp.indices.astype(np.int32)),
         mx.array(sp.indptr.astype(np.int32))),
        shape=sp.shape, sorted_indices=True, canonical=True,
        validate=level,
    )
    # Pad "label:" to a fixed width so the "ok" column lines up across rows.
    heading = f"{label}:"
    print(f"{heading:<18} ok ({note})")

# Full validation catches out-of-bounds column indices
bad_indices = mx.array(np.array([0, 99, 0], dtype=np.int32))  # 99 is out of range for n_cols=4
bad_indptr = mx.array(np.array([0, 2, 2, 3], dtype=np.int32))
bad_data = mx.array(np.ones(3, dtype=np.float32))

try:
    ms.csr_array(
        (bad_data, bad_indices, bad_indptr),
        shape=(3, 4), validate="full",
    )
except ValueError as e:
    print(f"\nCaught expected error (out-of-bounds column index):\n  {e}")
metadata validate: ok (fast path)
full validate:     ok (bounds checked)
no validate:       ok (trusted unconditionally)

Caught expected error (out-of-bounds column index):
  CSRArray indices must be in bounds for n_cols=4, got min=0, max=99