Structured constructors: eye, diags, fromdense#
mlx-sparse provides three convenience constructors for building CSR matrices from structured patterns rather than raw COO triplets.
| Function | What it builds |
|---|---|
| `ms.eye` | Identity-like matrix with ones on a chosen diagonal |
| `ms.diags` | Band matrix from one or more diagonal vectors |
| `ms.fromdense` | Sparse view of an existing dense MLX array |
All three return a canonical CSRArray: sorted indices, no duplicates, and
has_canonical_format=True.
import mlx.core as mx
import numpy as np
import mlx_sparse as ms
ms.use_gpu()
eye: identity and shifted diagonals#
ms.eye(n) returns an n x n identity. The k argument shifts the
diagonal: positive values move above the main diagonal (superdiagonal),
negative values below (subdiagonal).
# No k given: ones on the main diagonal, i.e. the 4x4 identity
I4 = ms.eye(4)
print("I4:\n", I4)
# todense() yields a dense array; wrap it in np.array for a readable printout
print("dense I4:\n", np.array(I4.todense()))
# k=1: ones one step above the main diagonal
# (only three entries fit in a 4x4 matrix, so the last row stays empty)
A_sup = ms.eye(4, k=1)
print("\nSuperdiagonal k=1:\n", np.array(A_sup.todense()))
# k=-2: ones two steps below the main diagonal
# (only two entries fit, so the first two rows stay empty)
A_sub = ms.eye(4, k=-2)
print("\nSubdiagonal k=-2:\n", np.array(A_sub.todense()))
I4:
CSRArray(shape=(4, 4), nnz=4, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
dense I4:
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
Superdiagonal k=1:
[[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 0.]]
Subdiagonal k=-2:
[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]]
# Rectangular: 3 rows, 5 columns, ones at k=1
# (the second positional argument is the column count)
R = ms.eye(3, 5, k=1)
print("Rectangular 3x5, k=1:\n", np.array(R.todense()))
Rectangular 3x5, k=1:
[[0. 1. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
All value dtypes are supported#
# Build a 4x4 identity in each value dtype and report the dtype of the
# stored values. The labels are the display names printed on each line.
for dtype, label in [
    (mx.float32, "float32 "),
    (mx.float16, "float16 "),
    (mx.bfloat16, "bfloat16"),
    (mx.complex64, "complex64"),
]:
    E = ms.eye(4, dtype=dtype)
    # Evaluate the value buffer explicitly before inspecting it.
    mx.eval(E.data)
    print(f"{label} eye nnz={E.nnz} data dtype {E.data.dtype}")
float32 eye nnz=4 data dtype float32
float16 eye nnz=4 data dtype float16
bfloat16 eye nnz=4 data dtype bfloat16
complex64 eye nnz=4 data dtype complex64
diags: band matrices from diagonal vectors#
ms.diags is modelled after scipy.sparse.diags. Pass a list of diagonal
vectors and a matching list of offsets.
n = 5
# 1-D Laplacian stencil: -1, 2, -1
# The off-diagonals (offsets -1 and 1) carry n-1 entries each and the main
# diagonal carries n. No shape is passed; the result printed below is 5x5.
T = ms.diags(
[np.full(n - 1, -1.0), np.full(n, 2.0), np.full(n - 1, -1.0)],
offsets=[-1, 0, 1],
)
print("Tridiagonal 5x5:\n", T)
print(np.array(T.todense()))
Tridiagonal 5x5:
CSRArray(shape=(5, 5), nnz=13, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
[[ 2. -1. 0. 0. 0.]
[-1. 2. -1. 0. 0.]
[ 0. -1. 2. -1. 0.]
[ 0. 0. -1. 2. -1.]
[ 0. 0. 0. -1. 2.]]
# Single diagonal at offset 2 inside a 5x5 matrix
# A single diagonal is given as a flat list of values with a scalar offset.
# The shape is passed explicitly; the three values land in rows 0-2.
D = ms.diags([1.0, 2.0, 3.0], offsets=2, shape=(5, 5))
print("Single off-diagonal at k=2:\n", np.array(D.todense()))
Single off-diagonal at k=2:
[[0. 0. 1. 0. 0.]
[0. 0. 0. 2. 0.]
[0. 0. 0. 0. 3.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
# Pentadiagonal matrix (bandwidth 5): useful as a higher-order FD stencil
# Diagonal lengths shrink with distance from the main diagonal:
# n entries at offset 0, n-1 at offsets +/-1, n-2 at offsets +/-2.
n = 6
P = ms.diags(
[
np.full(n - 2, 1.0),
np.full(n - 1, -2.0),
np.full(n, 4.0),
np.full(n - 1, -2.0),
np.full(n - 2, 1.0),
],
offsets=[-2, -1, 0, 1, 2],
)
print("Pentadiagonal 6x6 (bandwidth 5):\n", np.array(P.todense()))
Pentadiagonal 6x6 (bandwidth 5):
[[ 4. -2. 1. 0. 0. 0.]
[-2. 4. -2. 1. 0. 0.]
[ 1. -2. 4. -2. 1. 0.]
[ 0. 1. -2. 4. -2. 1.]
[ 0. 0. 1. -2. 4. -2.]
[ 0. 0. 0. 1. -2. 4.]]
fromdense: convert an existing dense array to CSR#
ms.fromdense inspects a rank-2 dense MLX array and extracts all non-zero
entries. An optional threshold argument drops near-zero entries below a
given absolute value.
# 3x3 float32 input with four non-zeros and an all-zero middle row
dense = mx.array(np.array([
[1.0, 0.0, 2.0],
[0.0, 0.0, 0.0],
[3.0, 4.0, 0.0],
], dtype=np.float32))
print("Dense array:\n", np.array(dense))
csr = ms.fromdense(dense)
print("\nCSR:\n", csr)
# The three raw CSR buffers: stored values, their column indices, and the
# row pointer array (the empty middle row shows up as a repeated entry).
print("data: ", np.array(csr.data))
print("indices:", np.array(csr.indices))
print("indptr: ", np.array(csr.indptr))
Dense array:
[[1. 0. 2.]
[0. 0. 0.]
[3. 4. 0.]]
CSR:
CSRArray(shape=(3, 3), nnz=4, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
data: [1. 2. 3. 4.]
indices: [0 2 0 1]
indptr: [0 2 2 4]
# Demonstrate the threshold parameter
# Seeded generator so the example is reproducible.
rng = np.random.default_rng(0)
A = mx.array(rng.uniform(0, 1, (4, 4)).astype(np.float32))
for thr in [0.0, 0.3, 0.6]:
    s = ms.fromdense(A, threshold=thr)
    label = f"threshold={thr}"
    # Left-hand caption, padded to 14 characters so the nnz columns line up.
    heading = "All entries:" if thr == 0.0 else f"Above {thr}:"
    print(f"{heading:14} nnz={s.nnz} ({label})")
All entries: nnz=9 (threshold=0.0)
Above 0.3: nnz=7 (threshold=0.3)
Above 0.6: nnz=3 (threshold=0.6)
Round-trip: dense -> sparse -> dense#
Converting to CSR and back must reproduce the original values exactly.
# Seeded generator so the example is reproducible.
rng2 = np.random.default_rng(42)
mask = rng2.random((64, 64)) < 0.1 # around 10% non-zero
vals = rng2.standard_normal((64, 64)).astype(np.float32)
# Zero everything outside the mask so the dense input is mostly zeros.
vals[~mask] = 0.0
dense_mx = mx.array(vals)
# dense -> CSR -> dense
sparse = ms.fromdense(dense_mx)
dense2 = sparse.todense()
mx.eval(dense2)
# Largest absolute element-wise difference; 0 means the round trip is exact.
err = np.max(np.abs(np.array(dense2) - vals))
print(f"Round-trip max error: {err:.2e}")
Round-trip max error: 0.00e+00
Comparison with SciPy equivalents#
All three constructors mirror their SciPy counterparts. The snippet below compares their dense output numerically.
import scipy.sparse

# eye: superdiagonal (k=1) of a 6x6 matrix, compared with scipy.sparse.eye
n = 6
I_ms = np.array(ms.eye(n, k=1).todense())
I_sp = scipy.sparse.eye(n, k=1, format="csr").toarray().astype(np.float32)
print(f"eye: max diff vs SciPy = {np.max(np.abs(I_ms - I_sp)):.2e}")

# diags: rebuild the same stencil with scipy.sparse.diags
T_ms = np.array(T.todense())  # 5x5 tridiagonal from above
T_sp = scipy.sparse.diags(
    [np.full(4, -1.0), np.full(5, 2.0), np.full(4, -1.0)],
    offsets=[-1, 0, 1], format="csr",
).toarray().astype(np.float32)
print(f"diags: max diff vs SciPy = {np.max(np.abs(T_ms - T_sp)):.2e}")

# fromdense: SciPy's equivalent is the csr_matrix constructor applied to a
# dense array. Route the reference through it so the check really exercises
# SciPy instead of comparing the round trip against its own input.
csr_ms = np.array(ms.fromdense(dense).todense())
csr_sp = scipy.sparse.csr_matrix(np.array(dense)).toarray()
print(f"fromdense: max diff vs SciPy = {np.max(np.abs(csr_ms - csr_sp)):.2e}")
eye: max diff vs SciPy = 0.00e+00
diags: max diff vs SciPy = 0.00e+00
fromdense: max diff vs SciPy = 0.00e+00