Structured constructors: eye, diags, fromdense#

mlx-sparse provides three convenience constructors for building CSR matrices from structured patterns rather than raw COO triplets.

| Function | What it builds |
| --- | --- |
| ms.eye(n, m, k=0) | Identity-like matrix with ones on a chosen diagonal |
| ms.diags(diags, offsets) | Band matrix from one or more diagonal vectors |
| ms.fromdense(A) | Sparse view of an existing dense MLX array |

All three return a CSRArray in canonical form: sorted indices, no duplicates, and has_canonical_format=True.

import mlx.core as mx
import numpy as np
import mlx_sparse as ms

ms.use_gpu()

eye: identity and shifted diagonals#

ms.eye(n) returns an n x n identity. The k argument shifts the diagonal: positive values move it above the main diagonal (a superdiagonal), negative values move it below (a subdiagonal).

I4 = ms.eye(4)
print("I4:\n", I4)
print("dense I4:\n", np.array(I4.todense()))

# k=1: ones one step above the main diagonal
A_sup = ms.eye(4, k=1)
print("\nSuperdiagonal k=1:\n", np.array(A_sup.todense()))

# k=-2: ones two steps below the main diagonal
A_sub = ms.eye(4, k=-2)
print("\nSubdiagonal k=-2:\n", np.array(A_sub.todense()))
I4:
 CSRArray(shape=(4, 4), nnz=4, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
dense I4:
 [[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]

Superdiagonal k=1:
 [[0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 0.]]

Subdiagonal k=-2:
 [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]]
# Rectangular: 3 rows, 5 columns, ones at k=1
R = ms.eye(3, 5, k=1)
print("Rectangular 3x5, k=1:\n", np.array(R.todense()))
Rectangular 3x5, k=1:
 [[0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]

All value dtypes are supported#

for dtype, label in [
    (mx.float32,  "float32 "),
    (mx.float16,  "float16 "),
    (mx.bfloat16, "bfloat16"),
    (mx.complex64, "complex64"),
]:
    E = ms.eye(4, dtype=dtype)
    mx.eval(E.data)
    print(f"{label} eye nnz={E.nnz}  data dtype {E.data.dtype}")
float32  eye nnz=4  data dtype float32
float16  eye nnz=4  data dtype float16
bfloat16 eye nnz=4  data dtype bfloat16
complex64 eye nnz=4  data dtype complex64

diags: band matrices from diagonal vectors#

ms.diags is modelled after scipy.sparse.diags. Pass a list of diagonal vectors and a matching list of offsets.

n = 5
# 1-D Laplacian stencil: -1, 2, -1
T = ms.diags(
    [np.full(n - 1, -1.0), np.full(n, 2.0), np.full(n - 1, -1.0)],
    offsets=[-1, 0, 1],
)
print("Tridiagonal 5x5:\n", T)
print(np.array(T.todense()))
Tridiagonal 5x5:
 CSRArray(shape=(5, 5), nnz=13, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
[[ 2. -1.  0.  0.  0.]
 [-1.  2. -1.  0.  0.]
 [ 0. -1.  2. -1.  0.]
 [ 0.  0. -1.  2. -1.]
 [ 0.  0.  0. -1.  2.]]
# Single diagonal at offset 2 inside a 5x5 matrix
D = ms.diags([1.0, 2.0, 3.0], offsets=2, shape=(5, 5))
print("Single off-diagonal at k=2:\n", np.array(D.todense()))
Single off-diagonal at k=2:
 [[0. 0. 1. 0. 0.]
 [0. 0. 0. 2. 0.]
 [0. 0. 0. 0. 3.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
# Pentadiagonal matrix (bandwidth 5): useful as a higher-order FD stencil
n = 6
P = ms.diags(
    [
        np.full(n - 2, 1.0),
        np.full(n - 1, -2.0),
        np.full(n, 4.0),
        np.full(n - 1, -2.0),
        np.full(n - 2, 1.0),
    ],
    offsets=[-2, -1, 0, 1, 2],
)
print("Pentadiagonal 6x6 (bandwidth 5):\n", np.array(P.todense()))
Pentadiagonal 6x6 (bandwidth 5):
 [[ 4. -2.  1.  0.  0.  0.]
 [-2.  4. -2.  1.  0.  0.]
 [ 1. -2.  4. -2.  1.  0.]
 [ 0.  1. -2.  4. -2.  1.]
 [ 0.  0.  1. -2.  4. -2.]
 [ 0.  0.  0.  1. -2.  4.]]

fromdense: convert an existing dense array to CSR#

ms.fromdense inspects a rank-2 dense MLX array and extracts all non-zero entries. An optional threshold argument additionally drops near-zero entries, i.e. those whose absolute value is below the given threshold.

dense = mx.array(np.array([
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 0.0],
    [3.0, 4.0, 0.0],
], dtype=np.float32))

print("Dense array:\n", np.array(dense))

csr = ms.fromdense(dense)
print("\nCSR:\n", csr)
print("data:   ", np.array(csr.data))
print("indices:", np.array(csr.indices))
print("indptr: ", np.array(csr.indptr))
Dense array:
 [[1. 0. 2.]
 [0. 0. 0.]
 [3. 4. 0.]]

CSR:
 CSRArray(shape=(3, 3), nnz=4, dtype=float32, index_dtype=int32, sorted_indices=True, has_canonical_format=True)
data:    [1. 2. 3. 4.]
indices: [0 2 0 1]
indptr:  [0 2 2 4]
# Demonstrate the threshold parameter
rng = np.random.default_rng(0)
A = mx.array(rng.uniform(0, 1, (4, 4)).astype(np.float32))

for thr in [0.0, 0.3, 0.6]:
    s = ms.fromdense(A, threshold=thr)
    label = f"threshold={thr}"
    print(f"{'All entries:' if thr == 0.0 else f'Above {thr}:':14} nnz={s.nnz}  ({label})")
All entries:   nnz=16  (threshold=0.0)
Above 0.3:     nnz=10  (threshold=0.3)
Above 0.6:     nnz=9  (threshold=0.6)

Round-trip: dense -> sparse -> dense#

Converting to CSR and back must reproduce the original values exactly.

rng2 = np.random.default_rng(42)
mask = rng2.random((64, 64)) < 0.1  # around 10% non-zero
vals = rng2.standard_normal((64, 64)).astype(np.float32)
vals[~mask] = 0.0

dense_mx = mx.array(vals)
sparse = ms.fromdense(dense_mx)
dense2 = sparse.todense()
mx.eval(dense2)

err = np.max(np.abs(np.array(dense2) - vals))
print(f"Round-trip max error: {err:.2e}")
Round-trip max error: 0.00e+00

Comparison with SciPy equivalents#

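ms.eye and ms.diags mirror their SciPy counterparts, scipy.sparse.eye and scipy.sparse.diags. The code below checks each result against the SciPy equivalent; ms.fromdense has no single SciPy counterpart here, so its round-trip result is compared against the original dense array.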

import scipy.sparse

# eye
n = 6
I_ms = np.array(ms.eye(n, k=1).todense())
I_sp = scipy.sparse.eye(n, k=1, format="csr").toarray().astype(np.float32)
print(f"eye:   max diff vs SciPy = {np.max(np.abs(I_ms - I_sp)):.2e}")

# diags
T_ms = np.array(T.todense())  # 5x5 tridiagonal from above
T_sp = scipy.sparse.diags(
    [np.full(4, -1.0), np.full(5, 2.0), np.full(4, -1.0)],
    offsets=[-1, 0, 1], format="csr"
).toarray().astype(np.float32)
print(f"diags: max diff vs SciPy = {np.max(np.abs(T_ms - T_sp)):.2e}")

# fromdense
csr_ms = np.array(ms.fromdense(dense).todense())
csr_sp = np.array(dense)
print(f"fromdense: max diff vs SciPy = {np.max(np.abs(csr_ms - csr_sp)):.2e}")
eye:   max diff vs SciPy = 0.00e+00
diags: max diff vs SciPy = 0.00e+00
fromdense: max diff vs SciPy = 0.00e+00