From b44dbcc48f47cff0dc82ba851501bdfcbd0c784d Mon Sep 17 00:00:00 2001 From: ak2k <19240940+ak2k@users.noreply.github.com> Date: Sat, 20 Jun 2026 12:11:03 -0400 Subject: [PATCH] perf(build_csr): preallocate the CSR indices array MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Step 4 collected each batch's remapped tgt indices in a Python list and np.concatenate'd them after the loop. For work_referenced_works (~3B deduplicated edges) the joined int32 array is ~12 GB, and concatenate holds the per-batch list and the joined array at once — transiently doubling that 12 GB at the worst moment. n_edges is known exactly before the loop, so preallocate indices = np.empty(n_edges, int32) and fill it slice-by-slice via a running offset. An assert pins that every deduplicated edge is placed exactly once. Output is unchanged: edges stream in (src, tgt) order, so writing them in batch order yields the same array np.concatenate produced. Measured on a 3M-node / 60M-edge synthetic graph the concatenate transient is ~0.33 GB of peak RSS; isolating just the array assembly at 500M edges it is 4.09 GB -> 2.03 GB, i.e. the saving is ~4 bytes/edge and scales to ~12 GB at work_referenced_works. Adds tests/test_build_csr.py — the module had no coverage. It checks the CSR against an independently computed reference (null handling, duplicate collapse, dense remap) at both a small fixture and a 2000-node graph, cross-shard deduplication, the empty-relationship case, idempotent skip on unchanged inputs, and byte-identical output across runs. 
--- sync/build_csr.py | 25 ++++- tests/test_build_csr.py | 198 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 218 insertions(+), 5 deletions(-) create mode 100644 tests/test_build_csr.py diff --git a/sync/build_csr.py b/sync/build_csr.py index 22a7fdb..fe4d229 100644 --- a/sync/build_csr.py +++ b/sync/build_csr.py @@ -314,8 +314,18 @@ def _build_csr_duckdb( # Step 4: Read edges in chunks, remap to dense indices via # binary search (original_ids is sorted), and accumulate CSR # components. Peak memory per batch: ~80 MB (5M rows × 16 bytes). + # + # The CSR ``indices`` array is the full deduplicated tgt-index set + # (~12 GB for work_referenced_works's ~3B edges), so it is + # preallocated at its final size and filled slice-by-slice as the + # batches stream in. Collecting per-batch arrays in a list and + # ``np.concatenate``-ing at the end would transiently hold both the + # list and the joined array — doubling that 12 GB at the worst + # possible moment. The edge count is known exactly (``n_edges``), + # so the single allocation is safe. indptr = np.zeros(n_nodes + 1, dtype=np.int64) - indices_chunks: list[np.ndarray] = [] + indices = np.empty(n_edges, dtype=np.int32) + offset = 0 batch_count = 0 pf = pq.ParquetFile(str(tmp_edges)) @@ -330,7 +340,12 @@ def _build_csr_duckdb( src_idx = np.searchsorted(original_ids, src).astype(np.int32) tgt_idx = np.searchsorted(original_ids, tgt).astype(np.int32) - indices_chunks.append(tgt_idx) + # Append this batch's column indices into the preallocated array. + # Edges arrive in (src, tgt) order, so writing them in batch + # order keeps ``indices`` grouped by source node — exactly the + # CSR layout that ``indptr`` describes. + indices[offset:offset + len(tgt_idx)] = tgt_idx + offset += len(tgt_idx) # Count edges per source node for indptr construction. 
counts = np.bincount(src_idx, minlength=n_nodes) @@ -339,12 +354,12 @@ def _build_csr_duckdb( batch_count += 1 log.debug("Processed batch %d (%d rows)", batch_count, len(src)) + # All deduplicated edges must have been placed exactly once. + assert offset == n_edges, f"placed {offset} indices, expected {n_edges}" + # Build CSR indptr (cumulative sum of per-row edge counts). np.cumsum(indptr, out=indptr) - indices = np.concatenate(indices_chunks) - del indices_chunks - data = np.ones(n_edges, dtype=np.float32) csr = sparse.csr_matrix( (data, indices, indptr), diff --git a/tests/test_build_csr.py b/tests/test_build_csr.py new file mode 100644 index 0000000..480593d --- /dev/null +++ b/tests/test_build_csr.py @@ -0,0 +1,198 @@ +"""Tests for sync.build_csr — CSR construction from relationship shards. + +These pin the module's two load-bearing guarantees: + +* **Correctness** — the dense-remapped CSR matches an independently computed + reference, with nulls dropped, duplicate edges collapsed, and the sparse + OpenAlex IDs mapped to a contiguous ``[0, n_nodes)`` range sorted by ID. +* **Determinism** — the same shards always produce a byte-identical ``.npz`` + (a documented invariant: explicit ``ORDER BY`` plus a sorted shard list). + +build_csr depends on numpy, scipy, and duckdb, which are optional relative to +the core sync pipeline; the tests skip where those are unavailable. +""" + +from __future__ import annotations + +import hashlib + +import pytest + +# build_csr's deps (numpy/scipy/duckdb) are optional relative to the core sync +# pipeline; importorskip yields the modules and skips the file where any is +# absent. The first-party import then necessarily follows the guards (E402). 
+np = pytest.importorskip("numpy")
+pa = pytest.importorskip("pyarrow")
+pq = pytest.importorskip("pyarrow.parquet")
+sparse = pytest.importorskip("scipy.sparse")
+pytest.importorskip("duckdb")
+
+from sync import build_csr  # noqa: E402
+
+REL = "work_referenced_works"
+SRC_COL, TGT_COL = "work_id", "referenced_work_id"
+
+
+def _sha256(path):
+    return hashlib.sha256(path.read_bytes()).hexdigest()
+
+
+def _write_shard(shard_dir, edges, index=0):
+    """Write ``edges`` (list of (src, tgt), None allowed) as one parquet shard."""
+    shard_dir.mkdir(parents=True, exist_ok=True)
+    src = pa.array([e[0] for e in edges], type=pa.uint64())
+    tgt = pa.array([e[1] for e in edges], type=pa.uint64())
+    pq.write_table(
+        pa.table({SRC_COL: src, TGT_COL: tgt}),
+        shard_dir / f"part_{index:04d}.parquet",
+    )
+
+
+def _random_graph(seed, *, n_nodes=2_000, n_edges=8_000):
+    """A reproducible sparse graph over OpenAlex-scale (uint64) IDs."""
+    rng = np.random.default_rng(seed)
+    ids = rng.choice(6_999_999_999, size=n_nodes, replace=False) + 1  # no 56 GB arange
+    return [
+        (int(ids[i]), int(ids[j]))
+        for i, j in zip(
+            rng.integers(0, n_nodes, n_edges), rng.integers(0, n_nodes, n_edges)
+        )
+    ]
+
+
+def _reference_csr(edges):
+    """Build the expected CSR the long way, independent of build_csr's path.
+
+    Mirrors the module contract: node set is every non-null src plus every
+    non-null tgt; edges are the distinct rows with both endpoints present;
+    dense index is the rank of the ID in sorted order.
+    """
+    nodes = {s for s, _ in edges if s is not None}
+    nodes |= {t for _, t in edges if t is not None}
+    sorted_ids = sorted(nodes)
+    index = {nid: i for i, nid in enumerate(sorted_ids)}
+    valid = sorted({(s, t) for s, t in edges if s is not None and t is not None})
+    rows = [index[s] for s, _ in valid]
+    cols = [index[t] for _, t in valid]
+    n = len(sorted_ids)
+    ref = sparse.csr_matrix(
+        (np.ones(len(valid), dtype=np.float32), (rows, cols)),
+        shape=(n, n),
+    )
+    ref.sort_indices()
+    return ref, np.array(sorted_ids, dtype=np.uint64)
+
+
+def _build(tmp_path, edges=None, *, shards=None, force=True):
+    """Build CSR for ``edges`` (one shard) or ``shards`` (one shard per list)."""
+    groups = shards if shards is not None else [edges]
+    parquet_dir = tmp_path / "data"
+    shard_dir = build_csr.rt_dir(parquet_dir, REL)
+    for i, grp in enumerate(groups):
+        _write_shard(shard_dir, grp, i)
+    out_dir = tmp_path / "csr"
+    result = build_csr.build_csr(
+        REL, parquet_dir=parquet_dir, output_dir=out_dir, force=force
+    )
+    return result, out_dir / f"{REL}.npz"
+
+
+def test_csr_matches_independent_reference(tmp_path):
+    # Sparse, OpenAlex-scale IDs; duplicate edge (10, 30); a null on each side.
+    edges = [
+        (10, 30), (10, 20), (40, 10), (10, 30),  # dup
+        (30, 30), (20, 40), (40, 20),
+        (50, None), (None, 60),  # endpoints seed nodes 50 & 60 but no edge
+        (7_000_000_001, 10), (10, 7_000_000_001),
+    ]
+    _, npz = _build(tmp_path, edges)
+    got = sparse.load_npz(npz)
+    id_map = np.load(npz.with_suffix(".id_map.npy"))
+
+    ref, ref_ids = _reference_csr(edges)
+    assert got.shape == ref.shape
+    assert got.nnz == ref.nnz
+    assert got.indices.dtype == np.int32
+    np.testing.assert_array_equal(got.indptr, ref.indptr)
+    np.testing.assert_array_equal(got.indices, ref.indices)
+    np.testing.assert_array_equal(got.data, ref.data)
+    np.testing.assert_array_equal(id_map, ref_ids)
+    # {10,20,30,40,50,60,7000000001}: 50 and 60 are seeded by null-partnered
+    # rows, so they are nodes with no incident edge.
+    assert got.shape[0] == 7
+
+
+def test_node_set_includes_null_partnered_endpoints(tmp_path):
+    edges = [(10, 20), (30, None), (None, 40)]
+    _, npz = _build(tmp_path, edges)
+    got = sparse.load_npz(npz)
+    id_map = np.load(npz.with_suffix(".id_map.npy"))
+    # 10,20 from the real edge; 30 and 40 seeded by their null-partnered rows
+    np.testing.assert_array_equal(id_map, np.array([10, 20, 30, 40], dtype=np.uint64))
+    assert got.shape == (4, 4)
+    assert got.nnz == 1  # only (10,20) is a real edge
+
+
+def test_duplicate_edges_collapse(tmp_path):
+    edges = [(1, 2)] * 5 + [(2, 1)]
+    _, npz = _build(tmp_path, edges)
+    got = sparse.load_npz(npz)
+    assert got.nnz == 2  # (1,2) and (2,1), the four duplicates dropped
+
+
+def test_empty_relationship_yields_empty_matrix(tmp_path):
+    result, npz = _build(tmp_path, [])
+    got = sparse.load_npz(npz)
+    assert got.shape == (0, 0)
+    assert got.nnz == 0
+    assert result["n_edges"] == 0
+
+
+def test_duplicate_edges_collapse_across_shards(tmp_path):
+    # Production always reads many shards; an edge repeated in *different* shards
+    # must dedup globally, not just within one shard.
+    shards = [[(1, 2), (3, 4)], [(1, 2), (5, 6)]]
+    _, npz = _build(tmp_path, shards=shards)
+    got = sparse.load_npz(npz)
+    ref, ref_ids = _reference_csr([e for shard in shards for e in shard])
+    assert got.nnz == 3  # (1, 2) once, plus (3, 4) and (5, 6)
+    np.testing.assert_array_equal(got.indptr, ref.indptr)
+    np.testing.assert_array_equal(got.indices, ref.indices)
+    np.testing.assert_array_equal(np.load(npz.with_suffix(".id_map.npy")), ref_ids)
+
+
+def test_larger_graph_matches_reference(tmp_path):
+    # Validate the remap against the independent reference at a scale where
+    # dense-index ordering bugs would surface, not just on the ~7-node fixture.
+    edges = _random_graph(20260620)
+    _, npz = _build(tmp_path, edges)
+    got = sparse.load_npz(npz)
+    ref, ref_ids = _reference_csr(edges)
+    assert got.shape == ref.shape
+    assert got.nnz == ref.nnz
+    np.testing.assert_array_equal(got.indptr, ref.indptr)
+    np.testing.assert_array_equal(got.indices, ref.indices)
+    np.testing.assert_array_equal(got.data, ref.data)
+    np.testing.assert_array_equal(np.load(npz.with_suffix(".id_map.npy")), ref_ids)
+
+
+def test_unchanged_inputs_skip_and_preserve_output(tmp_path):
+    # Idempotency: a second build with force=False detects unchanged inputs via
+    # the provenance fingerprint, skips, and leaves the .npz byte-identical.
+    edges = [(10, 20), (20, 30), (10, 30)]
+    result1, npz = _build(tmp_path, edges)
+    assert result1["status"] == "built"
+    before = _sha256(npz)
+    result2, _ = _build(tmp_path, edges, force=False)
+    assert result2["status"] == "skipped"
+    assert _sha256(npz) == before
+
+
+def test_output_is_byte_identical_across_runs(tmp_path):
+    edges = _random_graph(20260620)
+    _, npz_a = _build(tmp_path / "a", edges)
+    _, npz_b = _build(tmp_path / "b", edges)
+    assert _sha256(npz_a) == _sha256(npz_b)
+    assert _sha256(npz_a.with_suffix(".id_map.npy")) == _sha256(
+        npz_b.with_suffix(".id_map.npy")
+    )