From ce1aa1870d34b7da9ad80d9158e18cad1622e35f Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sat, 6 Jun 2026 08:30:50 -0700 Subject: [PATCH] fix(bq-squashfs): run the training_images MERGE once per run, not per chunk MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The download pipeline MERGEd each chunk's results into training_images immediately after downloading it. A BigQuery MERGE is billed mainly for scanning its target table, and training_images in global_all_leps_2605 is ~24M rows / ~7 GB with no partitioning or clustering, so every chunk re-scanned the whole table. At the default chunk size of 10,000 that is ~2,400 full-table scans (~17 TB, ~$85 at $5/TB) for a single download pass. The downloads table already records every result via free batch loads, and resume already reads from it (get_pending_rows LEFT JOINs against it), so the per-chunk MERGE was only keeping training_images.fetch_status cosmetically current mid-run. This moves the MERGE to a single call at the end of the run, sourced from the (deduplicated) downloads table — one target scan instead of thousands. Also adds a cost note at the top of the file for future readers (human and AI) and rewrites the merge tests for the new contract. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../bq_squashfs/download_images.py | 153 +++++++---- tests/dataset_tools/test_download_images.py | 240 +++++++----------- 2 files changed, 183 insertions(+), 210 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index 6c172ab..3285cfb 100644 --- a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -31,12 +31,41 @@ chunk sqfs files into the final task_N.sqfs archives. 
""" +# ───────────────────────────────────────────────────────────────────────────── +# COST NOTE FOR AI ASSISTANTS (and humans) — READ BEFORE ADDING BQ CALLS HERE +# +# This script runs at scale against large tables. Know the sizes before you +# add or change any BigQuery statement: +# +# training_images ~24M rows / ~7 GB (global_all_leps_2605) +# training_images_downloads grows to a comparable size as downloads complete +# +# BigQuery on-demand pricing is $5 per TB *scanned* (not per row written). +# Two rules that matter most in this file: +# +# 1. A MERGE / UPDATE / DELETE is billed mainly for scanning the *target* +# table, regardless of how few rows the source touches. Running one inside +# the per-chunk loop re-scans all ~7 GB on every chunk — at chunk_size +# 10000 that is ~2400 full-table scans (~17 TB, ~$85) for a single pass. +# Do DML *once per run*, not once per chunk. See +# merge_downloads_into_training_images() below. +# +# 2. Appending rows via load_table_from_dataframe() (batch load) is free and +# parallel-safe — prefer it over streaming inserts or per-row DML. The +# downloads table is the source of truth for progress; status only needs +# to land in training_images once, at the end. +# +# Before adding a query: estimate bytes scanned (a full SELECT * over +# training_images is ~7 GB = ~$0.035 each — cheap once, expensive in a loop), +# filter/project columns, and never put unbounded DML inside a loop. To verify +# real spend, check INFORMATION_SCHEMA.JOBS_BY_PROJECT (total_bytes_billed). 
+# ───────────────────────────────────────────────────────────────────────────── + import argparse import random import subprocess import threading import time -import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path @@ -219,15 +248,14 @@ def write_results_to_bq( raise -def merge_chunk_into_training_images( +def merge_downloads_into_training_images( client: bigquery.Client, - results: list[dict], training_table: str, downloads_table: str, ) -> int: - """MERGE all chunk results into training_images, updating fetch_status for every outcome. + """MERGE all recorded download results into training_images, once per run. - Outcomes merged: + Updates fetch_status (and dims) for every outcome: downloaded — fetch_status='downloaded', dims and corrupted populated corrupted — fetch_status='corrupted', corrupted=True, dims NULL failed — fetch_status='failed', all fields NULL @@ -237,56 +265,61 @@ def merge_chunk_into_training_images( the WHERE fetch_status='pending' clause — no wasted retry attempts. Retrying can still be done intentionally via retry_failed_downloads.py. - Uses a temp table containing only this chunk's rows so the MERGE scans - a small dataset rather than all of training_images_downloads. - Only updates rows that are still 'pending' — safe from parallel tasks. - Returns the number of rows updated. 
- """ - to_merge = [r for r in results if r["fetch_status"] in ("downloaded", "corrupted", "failed")] - if not to_merge: - return 0 - - # Derive "project.dataset" from training_table ("project.dataset.table_name") - _parts = training_table.split(".") - _dataset_ref = ".".join(_parts[:2]) if len(_parts) >= 2 else _parts[0] - tmp_table = f"{_dataset_ref}._dl_merge_tmp_{uuid.uuid4().hex[:8]}" - df = pd.DataFrame(to_merge) - job_config = bigquery.LoadJobConfig( - write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE, - schema=DOWNLOADS_SCHEMA, - ) - client.load_table_from_dataframe(df, tmp_table, job_config=job_config).result() + Source is the downloads table, which already holds every result (appended + per chunk via the free batch-load tier). Because a MERGE is billed mainly + for scanning the *target* table, running it once per run instead of once + per chunk turns N full-table scans into a single one. The source is + deduplicated per dataset_source_uuid (the table is append-only and + --force-redownload can append a second row), preferring a successful + outcome over a failed one. + Only updates rows still 'pending' — safe to re-run and safe from parallel + tasks. Returns the number of rows updated. 
+ """ merge_sql = f""" MERGE `{training_table}` T - USING `{tmp_table}` S + USING ( + SELECT + dataset_source_uuid, + ARRAY_AGG( + STRUCT(fetch_status, image_width, image_height, image_size, corrupted) + ORDER BY CASE fetch_status + WHEN 'downloaded' THEN 0 + WHEN 'corrupted' THEN 1 + ELSE 2 + END + LIMIT 1 + )[OFFSET(0)] AS r + FROM `{downloads_table}` + WHERE fetch_status IN ('downloaded', 'corrupted', 'failed') + GROUP BY dataset_source_uuid + ) S ON T.dataset_source_uuid = S.dataset_source_uuid WHEN MATCHED AND (T.fetch_status = 'pending' OR T.fetch_status IS NULL) THEN UPDATE SET - T.fetch_status = S.fetch_status, - T.image_width = S.image_width, - T.image_height = S.image_height, - T.image_size = S.image_size, - T.corrupted = S.corrupted + T.fetch_status = S.r.fetch_status, + T.image_width = S.r.image_width, + T.image_height = S.r.image_height, + T.image_size = S.r.image_size, + T.corrupted = S.r.corrupted """ - try: - # Concurrent MERGEs from parallel tasks can collide with - # "Could not serialize access ... due to concurrent update" (400). - # BQ docs recommend retrying — back off with jitter until a slot frees. - for attempt in range(_MERGE_MAX_RETRIES + 1): - try: - job = client.query(merge_sql) - job.result() - return job.dml_stats.updated_row_count - except google_exceptions.BadRequest as e: - if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES: - raise - delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) - delay += random.uniform(0, delay * 0.5) - print(f" MERGE serialization conflict — retry " - f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", flush=True) - time.sleep(delay) - finally: - client.delete_table(tmp_table, not_found_ok=True) + # Concurrent MERGEs from parallel tasks can collide with + # "Could not serialize access ... due to concurrent update" (400). + # BQ docs recommend retrying — back off with jitter until a slot frees. 
+ for attempt in range(_MERGE_MAX_RETRIES + 1): + try: + job = client.query(merge_sql) + job.result() + return job.dml_stats.updated_row_count + except google_exceptions.BadRequest as e: + if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES: + raise + delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.5) + print(f" MERGE serialization conflict — retry " + f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", flush=True) + time.sleep(delay) + # Unreachable: the loop either returns or raises on the final attempt. + return 0 def get_pending_rows( @@ -514,22 +547,30 @@ def main(): print(f" Chunk done in {elapsed:.0f}s ({len(chunk)/elapsed:.0f} img/s) " f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,}", flush=True) - # Append to downloads table (batch load — free tier, parallel-safe) + # Append to downloads table (batch load — free tier, parallel-safe). + # This is the source of truth for progress: get_pending_rows() resumes + # by LEFT JOINing against it, so status is durable without any DML on + # training_images mid-run. The MERGE into training_images runs once at + # the end of the run (see after the loop). write_results_to_bq(client, results, downloads_table) print(f" Written to {downloads_table}", flush=True) - # Inline MERGE into training_images so status is current without a separate job - n_updated = merge_chunk_into_training_images( - client, results, training_table, downloads_table - ) - print(f" Merged {n_updated:,} rows into {training_table}", flush=True) - # Pack images into chunk sqfs then clear raw files to keep inode usage low pack_chunk_to_sqfs(staging_dir, chunk_num, num_workers=4) clear_staging(staging_dir) warn_chunk_accumulation(staging_dir) print(f" Staging cleared (chunk sqfs kept)", flush=True) + # One MERGE per run instead of one per chunk. 
A MERGE is billed mainly for + # scanning the target table, so the per-chunk version re-scanned the whole + # training_images table on every chunk (~N full scans per pass). Sourcing + # from the already-populated downloads table lets us do it in a single scan. + print(f"\nMerging download status into {training_table}...", flush=True) + n_updated = merge_downloads_into_training_images( + client, training_table, downloads_table + ) + print(f" Merged {n_updated:,} rows into {training_table}", flush=True) + print(f"\n[Task {args.task_id}] Done. " f"downloaded={total_downloaded:,} failed={total_failed:,} " f"corrupted={total_corrupted:,}", flush=True) diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py index f11e7a5..39c69c2 100644 --- a/tests/dataset_tools/test_download_images.py +++ b/tests/dataset_tools/test_download_images.py @@ -257,119 +257,93 @@ def test_empty_staging_returns_none(self, tmp_path): mock_run.assert_not_called() -# ── merge_chunk_into_training_images ───────────────────────────────────────── +# ── merge_downloads_into_training_images ───────────────────────────────────── -class TestMergeChunkIntoTrainingImages: +class TestMergeDownloadsIntoTrainingImages: + """The MERGE now runs once per run, sourced from the downloads table. + + Cost rationale: a MERGE is billed mainly for scanning the *target* table + (~7 GB for global_all_leps_2605), so it must run once per run, not once per + chunk. These tests pin that contract: a single MERGE query, no per-chunk + temp tables, sourced from the downloads table, deduplicated per uuid. 
+ """ def _make_client(self, updated_count: int = 1) -> MagicMock: client = MagicMock() - client.load_table_from_dataframe.return_value.result.return_value = None job = MagicMock() job.dml_stats.updated_row_count = updated_count client.query.return_value = job return client - def test_empty_list_skips_merge(self): - """Completely empty results list → no BQ calls.""" - client = self._make_client() - n = di.merge_chunk_into_training_images(client, [], "t", "d") - assert n == 0 - client.load_table_from_dataframe.assert_not_called() - - def test_downloaded_triggers_merge(self): - """downloaded rows → temp table load + MERGE + cleanup.""" - client = self._make_client(updated_count=1) - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, - "image_size": 5000, "corrupted": False}] - n = di.merge_chunk_into_training_images(client, results, "t", "d") - assert n == 1 - assert client.load_table_from_dataframe.call_count == 1 + def test_runs_single_merge_no_temp_table(self): + """One MERGE query, and no temp-table load/delete (the expensive per-chunk + pattern that re-scanned the whole target table).""" + client = self._make_client(updated_count=7) + n = di.merge_downloads_into_training_images(client, "t", "d") + assert n == 7 assert client.query.call_count == 1 - assert client.delete_table.call_count == 1 - - def test_corrupted_triggers_merge(self): - """corrupted rows → merged with fetch_status='corrupted'.""" - client = self._make_client(updated_count=1) - results = [{"dataset_source_uuid": "u1", "fetch_status": "corrupted", - "image_width": None, "image_height": None, - "image_size": None, "corrupted": True}] - n = di.merge_chunk_into_training_images(client, results, "t", "d") - assert n == 1 - assert client.load_table_from_dataframe.call_count == 1 + client.load_table_from_dataframe.assert_not_called() + client.delete_table.assert_not_called() - def test_failed_triggers_merge(self): - """failed rows 
(404/403/exhausted) → merged so fetch_status='failed' in - training_images. Permanent failures are excluded from future re-runs - via WHERE fetch_status='pending' without needing the LEFT JOIN.""" - client = self._make_client(updated_count=1) - results = [{"dataset_source_uuid": "u1", "fetch_status": "failed", - "image_width": None, "image_height": None, - "image_size": None, "corrupted": None}] - n = di.merge_chunk_into_training_images(client, results, "t", "d") - assert n == 1 - assert client.load_table_from_dataframe.call_count == 1 - assert client.query.call_count == 1 - assert client.delete_table.call_count == 1 - - def test_all_three_statuses_merged_together(self): - """Mixed chunk — downloaded, corrupted, failed — all three trigger one MERGE.""" - client = self._make_client(updated_count=3) - results = [ - {"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, "image_size": 5000, "corrupted": False}, - {"dataset_source_uuid": "u2", "fetch_status": "corrupted", - "image_width": None, "image_height": None, "image_size": 500, "corrupted": True}, - {"dataset_source_uuid": "u3", "fetch_status": "failed", - "image_width": None, "image_height": None, "image_size": None, "corrupted": None}, - ] - n = di.merge_chunk_into_training_images(client, results, "t", "d") - assert n == 3 - assert client.load_table_from_dataframe.call_count == 1 # one temp table for all 3 - assert client.query.call_count == 1 # one MERGE - assert client.delete_table.call_count == 1 # one cleanup + def test_merge_sources_from_downloads_table(self): + """The MERGE reads its source from the downloads table (already populated + via free batch loads), not a freshly loaded temp table.""" + client = self._make_client() + di.merge_downloads_into_training_images( + client, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + sql = client.query.call_args[0][0] + 
assert "MERGE `leps-ai.global_all_leps_2605.training_images`" in sql + assert "FROM `leps-ai.global_all_leps_2605.training_images_downloads`" in sql - def test_failed_rows_included_in_temp_table(self): - """Verify the dataframe passed to BQ includes the failed row.""" - import pandas as pd + def test_merge_deduplicates_source_by_uuid(self): + """The downloads table is append-only (and --force-redownload can add a + second row), so the source must collapse to one row per uuid.""" client = self._make_client() - captured_df = {} + di.merge_downloads_into_training_images(client, "t", "d") + sql = client.query.call_args[0][0] + assert "GROUP BY dataset_source_uuid" in sql - def capture_load(df, table, **kwargs): - captured_df["data"] = df.copy() - return MagicMock(result=MagicMock(return_value=None)) + def test_merge_condition_handles_null_fetch_status(self): + """MERGE must update rows where T.fetch_status IS NULL, not just 'pending' + — needed for global_all_leps_2605 whose initial state is NULL.""" + client = self._make_client() + di.merge_downloads_into_training_images( + client, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql - client.load_table_from_dataframe.side_effect = capture_load + def test_merge_retries_on_serialization_conflict(self): + """Concurrent MERGEs can raise a 'serialize access' BadRequest; the + function backs off and retries rather than failing the run.""" + from google.api_core import exceptions as google_exceptions - results = [ - {"dataset_source_uuid": "ok", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, "image_size": 1000, "corrupted": False}, - {"dataset_source_uuid": "dead", "fetch_status": "failed", - "image_width": None, "image_height": None, "image_size": None, "corrupted": None}, + client = 
MagicMock() + ok = MagicMock() + ok.dml_stats.updated_row_count = 3 + client.query.side_effect = [ + google_exceptions.BadRequest("Could not serialize access to table"), + ok, ] - di.merge_chunk_into_training_images(client, results, "t", "d") + with patch("time.sleep"): + n = di.merge_downloads_into_training_images(client, "t", "d") + assert n == 3 + assert client.query.call_count == 2 - df = captured_df["data"] - assert len(df) == 2 # both rows in temp table - statuses = set(df["fetch_status"].tolist()) - assert "downloaded" in statuses - assert "failed" in statuses # failed row present + def test_merge_reraises_non_serialization_error(self): + """A BadRequest that is not a serialization conflict must propagate.""" + from google.api_core import exceptions as google_exceptions - def test_temp_table_deleted_even_on_merge_failure(self): - """Temp table must be cleaned up even if the MERGE query fails.""" client = MagicMock() - client.load_table_from_dataframe.return_value.result.return_value = None - client.query.side_effect = Exception("MERGE failed") - - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, - "image_size": 5000, "corrupted": False}] - - with pytest.raises(Exception, match="MERGE failed"): - di.merge_chunk_into_training_images( - client, results, "training_table", "downloads_table" - ) - client.delete_table.assert_called_once() + client.query.side_effect = google_exceptions.BadRequest("syntax error") + with pytest.raises(google_exceptions.BadRequest, match="syntax error"): + di.merge_downloads_into_training_images(client, "t", "d") # ── get_pending_rows / MOD split ───────────────────────────────────────────── @@ -600,38 +574,27 @@ def test_downloads_table_appends_are_independent(self): assert calls[1][0][1] == "downloads_table" def test_merge_from_two_tasks_updates_correct_rows(self): - """Each task's MERGE only touches its own rows — no cross-task collision.""" + """Each task runs its own 
end-of-run MERGE; the WHEN MATCHED guard on T.fetch_status + keeps them from colliding (only 'pending'/NULL rows are updated). One MERGE query per task, no temp tables.""" client = MagicMock() - client.load_table_from_dataframe.return_value.result.return_value = None - # task 0 merges even photo_ids job0 = MagicMock() job0.dml_stats.updated_row_count = 5 - # task 1 merges odd photo_ids job1 = MagicMock() job1.dml_stats.updated_row_count = 5 client.query.side_effect = [job0, job1] - task0_results = [{"dataset_source_uuid": f"uuid-{i}", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False} - for i in range(5)] - task1_results = [{"dataset_source_uuid": f"uuid-{i+5}", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False} - for i in range(5)] - - n0 = di.merge_chunk_into_training_images( - client, task0_results, "training_table", "downloads_table" + n0 = di.merge_downloads_into_training_images( + client, "training_table", "downloads_table" ) - n1 = di.merge_chunk_into_training_images( - client, task1_results, "training_table", "downloads_table" + n1 = di.merge_downloads_into_training_images( + client, "training_table", "downloads_table" ) assert n0 == 5 assert n1 == 5 assert client.query.call_count == 2 # one MERGE per task - assert client.delete_table.call_count == 2 # temp table cleaned per task + client.delete_table.assert_not_called() # no temp tables anymore def test_total_coverage_after_all_tasks_complete(self): """After all tasks finish, every image should be accounted for.""" @@ -720,17 +683,12 @@ def test_merge_condition_handles_null_fetch_status(self): """MERGE SQL must allow updating rows where T.fetch_status IS NULL, not just 'pending' — needed for global_all_leps_2605 initial state.""" client = MagicMock() - client.load_table_from_dataframe.return_value.result.return_value = None job = MagicMock() job.dml_stats.updated_row_count = 1 client.query.return_value = job - results = 
[{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False}] - - di.merge_chunk_into_training_images( - client, results, + di.merge_downloads_into_training_images( + client, training_table="leps-ai.global_all_leps_2605.training_images", downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", ) @@ -739,49 +697,23 @@ def test_merge_condition_handles_null_fetch_status(self): assert "fetch_status IS NULL" in sql assert "fetch_status = 'pending'" in sql - # ── tmp_table derived from training_table ───────────────────────────────── + # ── MERGE targets the dataset from the args, not a hardcoded one ─────────── - def test_tmp_table_uses_same_dataset_as_training_table(self): - """Temp table for MERGE must be in the same dataset as training_table, - not hardcoded to global_butterflies_2604.""" + def test_merge_targets_dataset_from_args(self): + """The MERGE must read and write the dataset passed in (global_all_leps_2605), + never the old hardcoded global_butterflies_2604.""" client = MagicMock() - client.load_table_from_dataframe.return_value.result.return_value = None job = MagicMock() job.dml_stats.updated_row_count = 1 client.query.return_value = job - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False}] - - di.merge_chunk_into_training_images( - client, results, - training_table="leps-ai.global_all_leps_2605.training_images", - downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", - ) - - # The temp table passed to load_table_from_dataframe must be in global_all_leps_2605 - tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] - assert tmp_table_arg.startswith("leps-ai.global_all_leps_2605.") - assert "global_butterflies_2604" not in tmp_table_arg - - def test_tmp_table_not_in_wrong_dataset_when_using_new_dataset(self): - """Regression: 
old code used BQ_DATASET module constant — ensure it no longer does.""" - client = MagicMock() - client.load_table_from_dataframe.return_value.result.return_value = None - job = MagicMock() - job.dml_stats.updated_row_count = 1 - client.query.return_value = job - - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False}] - - di.merge_chunk_into_training_images( - client, results, + di.merge_downloads_into_training_images( + client, training_table="leps-ai.global_all_leps_2605.training_images", downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", ) - tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] - assert "global_butterflies_2604" not in tmp_table_arg # must not leak old dataset + sql = client.query.call_args[0][0] + assert "leps-ai.global_all_leps_2605.training_images" in sql + assert "leps-ai.global_all_leps_2605.training_images_downloads" in sql + assert "global_butterflies_2604" not in sql # must not leak old dataset