# Canonicalization and validation
The CSR format has two optional invariants that make operations faster and results predictable:
- Sorted column indices: within each row, column indices appear in ascending order. Tracked by `sorted_indices` on the `CSRArray`.
- No duplicate entries: at most one stored value per `(row, col)` pair. When combined with sorted indices, this is called canonical form, tracked by `has_canonical_format`.
These flags are hints. mlx-sparse trusts them without re-checking by
default. Passing validate="full" at construction time verifies them on
the host.
import mlx.core as mx
import numpy as np
import mlx_sparse as ms
ms.use_gpu()
## Building a non-canonical matrix
We deliberately build a COO matrix with unsorted indices and duplicate entries to show what happens before and after canonicalization.
# Row 0: entries at col 2, then col 0 (unsorted)
# Row 1: entries at col 1 and col 3, plus a duplicate at col 1 (sum -> 3.0)
# Row 2: single entry at col 0
data = mx.array([1.0, 1.0, 1.0, 2.0, 5.0, 7.0], dtype=mx.float32)
row = mx.array([0, 0, 1, 1, 1, 2], dtype=mx.int32)
col = mx.array([2, 0, 1, 1, 3, 0], dtype=mx.int32)
coo = ms.coo_array((data, (row, col)), shape=(3, 4))
print("COO before canonicalization:\n", coo)
# tocsr without canonical: preserves whatever order COO had
csr_raw = coo.tocsr(canonical=False)
mx.eval(csr_raw.indices)
print("\nCSR (not canonical, sorted_indices=False):\n", csr_raw)
print("indices:", np.array(csr_raw.indices))
# tocsr with canonical: sort + sum duplicates
csr_can = coo.tocsr(canonical=True)
mx.eval(csr_can.data, csr_can.indices)
print("\nCSR canonical (tocsr(canonical=True)):\n", csr_can)
print("indices:", np.array(csr_can.indices))
print("data: ", np.array(csr_can.data))
COO before canonicalization:
COOArray(shape=(3, 4), nnz=6, dtype=float32, index_dtype=int32, has_canonical_format=False)
CSR (not canonical, sorted_indices=False):
CSRArray(shape=(3, 4), nnz=6, dtype=float32, index_dtype=int32, sorted_indices=False, has_canonical_format=False)
indices: [2 0 1 1 3 0]
CSR canonical (tocsr(canonical=True)):
CSRArray(shape=(3, 4), nnz=5, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
indices: [0 2 1 3 0]
data: [1. 1. 3. 5. 7.]
## `canonicalize()` on an existing `CSRArray`
csr_raw2 = coo.tocsr(canonical=False)
print(f"Before canonicalize: sorted_indices={csr_raw2.sorted_indices} "
f"has_canonical_format={csr_raw2.has_canonical_format} nnz={csr_raw2.nnz}")
csr_fixed = csr_raw2.canonicalize()
mx.eval(csr_fixed.data)
print(f"After canonicalize: sorted_indices={csr_fixed.sorted_indices} "
f"has_canonical_format={csr_fixed.has_canonical_format} nnz={csr_fixed.nnz}")
print("\nDense (canonical):\n", np.array(csr_fixed.todense()))
Before canonicalize: sorted_indices=False has_canonical_format=False nnz=6
After canonicalize: sorted_indices=True has_canonical_format=True nnz=5
Dense (canonical):
[[1. 0. 1. 0.]
[0. 3. 0. 5.]
[7. 0. 0. 0.]]
## `sort_indices()`: sort only, preserve duplicates
csr_nodup = coo.tocsr(canonical=False)
mx.eval(csr_nodup.indices)
print(f"Before sort_indices: indices={np.array(csr_nodup.indices)} "
f"sorted_indices={csr_nodup.sorted_indices}")
csr_sorted = csr_nodup.sort_indices()
mx.eval(csr_sorted.indices)
print(f"After sort_indices: indices={np.array(csr_sorted.indices)} "
f"sorted_indices={csr_sorted.sorted_indices}")
print(f"Note: nnz still = {csr_sorted.nnz} (duplicates kept)")
Before sort_indices: indices=[2 0 1 1 3 0] sorted_indices=False
After sort_indices: indices=[0 2 1 1 3 0] sorted_indices=True
Note: nnz still = 6 (duplicates kept)
## `sum_duplicates()`: merge duplicate entries
csr_sorted2 = coo.tocsr(canonical=False).sort_indices()
print(f"Before sum_duplicates: nnz={csr_sorted2.nnz} "
f"has_canonical_format={csr_sorted2.has_canonical_format}")
csr_summed = csr_sorted2.sum_duplicates()
mx.eval(csr_summed.data)
print(f"After sum_duplicates: nnz={csr_summed.nnz} "
f"has_canonical_format={csr_summed.has_canonical_format}")
# Row 1 now has col 1 = 1.0 + 2.0 = 3.0, col 3 = 5.0
ip = np.array(csr_summed.indptr)
row1_data = np.array(csr_summed.data)[ip[1]:ip[2]]
print(f"Row 1 data (duplicate col 1 summed): {row1_data}")
Before sum_duplicates: nnz=6 has_canonical_format=False
After sum_duplicates: nnz=5 has_canonical_format=True
Row 1 data (duplicate col 1 summed): [3. 5.]
## Validation levels
Pass validate= to csr_array() or coo_array() to control how much
checking is done at construction time.
| Level | What is checked |
|---|---|
| `"metadata"` | Shapes, dtypes, lengths only (fast, no host sync) |
| `"full"` | All metadata checks plus value bounds (requires a host sync) |
| `False` | No checks at all (caller guarantees correctness) |
import scipy.sparse
rng = np.random.default_rng(5)
sp = scipy.sparse.random(32, 32, density=0.1, format="csr",
dtype=np.float32, random_state=rng)
for level, label in [
("metadata", "metadata validate"),
("full", "full validate "),
(False, "no validate "),
]:
A = ms.csr_array(
(mx.array(sp.data),
mx.array(sp.indices.astype(np.int32)),
mx.array(sp.indptr.astype(np.int32))),
shape=sp.shape, sorted_indices=True, canonical=True,
validate=level,
)
print(f"{label}: ok")
# Full validation catches out-of-bounds column indices
bad_indices = mx.array(np.array([0, 99, 0], dtype=np.int32)) # 99 is out of range for n_cols=4
bad_indptr = mx.array(np.array([0, 2, 2, 3], dtype=np.int32))
bad_data = mx.array(np.ones(3, dtype=np.float32))
try:
ms.csr_array(
(bad_data, bad_indices, bad_indptr),
shape=(3, 4), validate="full",
)
except ValueError as e:
print(f"\nCaught expected error (out-of-bounds column index):\n {e}")
metadata validate: ok (fast path)
full validate: ok (bounds checked)
no validate: ok (trusted unconditionally)
Caught expected error (out-of-bounds column index):
CSRArray indices must be in bounds for n_cols=4, got min=0, max=99
## Recommended workflow summary
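| Scenario | Recommended flags |
|---|---|
| Assemble from scratch (COO) | `coo.tocsr(canonical=True)` |
| Import from SciPy CSR | `ms.from_scipy(sp)` |
| Untrusted external data | `validate="full"` |
| Performance-critical hot path | `validate=False` (only when the caller guarantees the data and flags are correct) |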
# Pattern 1: Assemble from COO
data2 = mx.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=mx.float32)
row2 = mx.array([0, 0, 1, 2, 3], dtype=mx.int32)
col2 = mx.array([0, 3, 1, 2, 3], dtype=mx.int32)
coo2 = ms.coo_array((data2, (row2, col2)), shape=(4, 4))
csr2 = coo2.tocsr(canonical=True)
print("Assembled from COO:\n", csr2)
# Pattern 2: Import from SciPy
A_from_sp = ms.from_scipy(sp)
print("\nImported from SciPy:\n", A_from_sp)
Assembled from COO:
CSRArray(shape=(4, 4), nnz=5, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
Imported from SciPy:
CSRArray(shape=(32, 32), nnz=102, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)