Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 97 additions & 56 deletions src/dataset_tools/bq_squashfs/download_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,41 @@
chunk sqfs files into the final task_N.sqfs archives.
"""

# ─────────────────────────────────────────────────────────────────────────────
# COST NOTE FOR AI ASSISTANTS (and humans) — READ BEFORE ADDING BQ CALLS HERE
#
# This script runs at scale against large tables. Know the sizes before you
# add or change any BigQuery statement:
#
#   training_images            ~24M rows / ~7 GB  (global_all_leps_2605)
#   training_images_downloads  grows to a comparable size as downloads complete
#
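# BigQuery on-demand pricing is billed per TB *scanned* (not per row written).
# The dollar figures in this note assume $5 per TB; Google raised the list
# price in 2023 ($6.25 per TiB in US regions), so scale them to the current rate.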
# Two rules that matter most in this file:
#
# 1. A MERGE / UPDATE / DELETE is billed mainly for scanning the *target*
# table, regardless of how few rows the source touches. Running one inside
# the per-chunk loop re-scans all ~7 GB on every chunk — at chunk_size
# 10000 that is ~2400 full-table scans (~17 TB, ~$85) for a single pass.
# Do DML *once per run*, not once per chunk. See
# merge_downloads_into_training_images() below.
#
# 2. Appending rows via load_table_from_dataframe() (batch load) is free and
# parallel-safe — prefer it over streaming inserts or per-row DML. The
# downloads table is the source of truth for progress; status only needs
# to land in training_images once, at the end.
#
# Before adding a query: estimate bytes scanned (a full SELECT * over
# training_images is ~7 GB = ~$0.035 each — cheap once, expensive in a loop),
# filter/project columns, and never put unbounded DML inside a loop. To verify
# real spend, check INFORMATION_SCHEMA.JOBS_BY_PROJECT (total_bytes_billed).
# ─────────────────────────────────────────────────────────────────────────────

import argparse
import random
import subprocess
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

Expand Down Expand Up @@ -219,15 +248,14 @@ def write_results_to_bq(
raise


def merge_chunk_into_training_images(
def merge_downloads_into_training_images(
client: bigquery.Client,
results: list[dict],
training_table: str,
downloads_table: str,
) -> int:
"""MERGE all chunk results into training_images, updating fetch_status for every outcome.
"""MERGE all recorded download results into training_images, once per run.

Outcomes merged:
Updates fetch_status (and dims) for every outcome:
downloaded — fetch_status='downloaded', dims and corrupted populated
corrupted — fetch_status='corrupted', corrupted=True, dims NULL
failed — fetch_status='failed', all fields NULL
Expand All @@ -237,56 +265,61 @@ def merge_chunk_into_training_images(
the WHERE fetch_status='pending' clause — no wasted retry attempts.
Retrying can still be done intentionally via retry_failed_downloads.py.

Uses a temp table containing only this chunk's rows so the MERGE scans
a small dataset rather than all of training_images_downloads.
Only updates rows that are still 'pending' — safe from parallel tasks.
Returns the number of rows updated.
"""
to_merge = [r for r in results if r["fetch_status"] in ("downloaded", "corrupted", "failed")]
if not to_merge:
return 0

# Derive "project.dataset" from training_table ("project.dataset.table_name")
_parts = training_table.split(".")
_dataset_ref = ".".join(_parts[:2]) if len(_parts) >= 2 else _parts[0]
tmp_table = f"{_dataset_ref}._dl_merge_tmp_{uuid.uuid4().hex[:8]}"
df = pd.DataFrame(to_merge)
job_config = bigquery.LoadJobConfig(
write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
schema=DOWNLOADS_SCHEMA,
)
client.load_table_from_dataframe(df, tmp_table, job_config=job_config).result()
Source is the downloads table, which already holds every result (appended
per chunk via the free batch-load tier). Because a MERGE is billed mainly
for scanning the *target* table, running it once per run instead of once
per chunk turns N full-table scans into a single one. The source is
deduplicated per dataset_source_uuid (the table is append-only and
--force-redownload can append a second row), preferring a successful
outcome over a failed one.

Only updates rows still 'pending' — safe to re-run and safe from parallel
tasks. Returns the number of rows updated.
"""
merge_sql = f"""
MERGE `{training_table}` T
USING `{tmp_table}` S
USING (
SELECT
dataset_source_uuid,
ARRAY_AGG(
STRUCT(fetch_status, image_width, image_height, image_size, corrupted)
ORDER BY CASE fetch_status
WHEN 'downloaded' THEN 0
WHEN 'corrupted' THEN 1
ELSE 2
END
LIMIT 1
)[OFFSET(0)] AS r
FROM `{downloads_table}`
WHERE fetch_status IN ('downloaded', 'corrupted', 'failed')
GROUP BY dataset_source_uuid
) S
ON T.dataset_source_uuid = S.dataset_source_uuid
WHEN MATCHED AND (T.fetch_status = 'pending' OR T.fetch_status IS NULL) THEN UPDATE SET
T.fetch_status = S.fetch_status,
T.image_width = S.image_width,
T.image_height = S.image_height,
T.image_size = S.image_size,
T.corrupted = S.corrupted
T.fetch_status = S.r.fetch_status,
T.image_width = S.r.image_width,
T.image_height = S.r.image_height,
T.image_size = S.r.image_size,
T.corrupted = S.r.corrupted
"""
try:
# Concurrent MERGEs from parallel tasks can collide with
# "Could not serialize access ... due to concurrent update" (400).
# BQ docs recommend retrying — back off with jitter until a slot frees.
for attempt in range(_MERGE_MAX_RETRIES + 1):
try:
job = client.query(merge_sql)
job.result()
return job.dml_stats.updated_row_count
except google_exceptions.BadRequest as e:
if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES:
raise
delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX)
delay += random.uniform(0, delay * 0.5)
print(f" MERGE serialization conflict — retry "
f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", flush=True)
time.sleep(delay)
finally:
client.delete_table(tmp_table, not_found_ok=True)
# Concurrent MERGEs from parallel tasks can collide with
# "Could not serialize access ... due to concurrent update" (400).
# BQ docs recommend retrying — back off with jitter until a slot frees.
for attempt in range(_MERGE_MAX_RETRIES + 1):
try:
job = client.query(merge_sql)
job.result()
return job.dml_stats.updated_row_count
except google_exceptions.BadRequest as e:
if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES:
raise
delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX)
delay += random.uniform(0, delay * 0.5)
print(f" MERGE serialization conflict — retry "
f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", flush=True)
time.sleep(delay)
# Unreachable: the loop either returns or raises on the final attempt.
return 0


def get_pending_rows(
Expand Down Expand Up @@ -514,22 +547,30 @@ def main():
print(f" Chunk done in {elapsed:.0f}s ({len(chunk)/elapsed:.0f} img/s) "
f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,}", flush=True)

# Append to downloads table (batch load — free tier, parallel-safe)
# Append to downloads table (batch load — free tier, parallel-safe).
# This is the source of truth for progress: get_pending_rows() resumes
# by LEFT JOINing against it, so status is durable without any DML on
# training_images mid-run. The MERGE into training_images runs once at
# the end of the run (see after the loop).
write_results_to_bq(client, results, downloads_table)
print(f" Written to {downloads_table}", flush=True)

# Inline MERGE into training_images so status is current without a separate job
n_updated = merge_chunk_into_training_images(
client, results, training_table, downloads_table
)
print(f" Merged {n_updated:,} rows into {training_table}", flush=True)

# Pack images into chunk sqfs then clear raw files to keep inode usage low
pack_chunk_to_sqfs(staging_dir, chunk_num, num_workers=4)
clear_staging(staging_dir)
warn_chunk_accumulation(staging_dir)
print(f" Staging cleared (chunk sqfs kept)", flush=True)

# One MERGE per run instead of one per chunk. A MERGE is billed mainly for
# scanning the target table, so the per-chunk version re-scanned the whole
# training_images table on every chunk (~N full scans per pass). Sourcing
# from the already-populated downloads table lets us do it in a single scan.
print(f"\nMerging download status into {training_table}...", flush=True)
n_updated = merge_downloads_into_training_images(
client, training_table, downloads_table
)
print(f" Merged {n_updated:,} rows into {training_table}", flush=True)

print(f"\n[Task {args.task_id}] Done. "
f"downloaded={total_downloaded:,} failed={total_failed:,} "
f"corrupted={total_corrupted:,}", flush=True)
Expand Down
Loading