From 98dc4e219af81cd9f99040464187b814dfdf135b Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Fri, 1 May 2026 11:58:21 -0700 Subject: [PATCH 01/26] Add BQ/SquashFS pipeline scripts and documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five-stage pipeline for building a WebDataset from the BigQuery training_images table on the fir cluster: 1. download - parallel SLURM array job (×10), downloads images from iNaturalist, verifies with PIL, writes results back to BQ, packs into per-chunk sqfs files 2. bq_export - streams qualifying images metadata from BQ to CSV 3. split - stratified train/val/test split with occurrence-level grouping to prevent data leakage 4. webdataset - two implementations: fir-specific NVMe-optimised packer (create_webdataset.py) and a generic version for pre-mounted directories (create_webdataset_generic.py) 5. train - ResNet-50 classifier with auto-resume from checkpoint README documents each stage, the task_0…task_9 sqfs scheme, why two webdataset implementations exist, and a one-liner to chain the full pipeline. 
#!/bin/bash
# Download images from the training_images BQ table in parallel.
# Runs as a SLURM array job — each task handles MOD(photo_id, NUM_JOBS) = task_id.
# After all tasks finish, run job_bq_pack_squashfs.sh to merge into a single SquashFS.
# NOTE(review): job_bq_pack_squashfs.sh is referenced here but is not visible
# alongside this script — confirm it exists before chaining on it.
#
# Usage:
#   sbatch job_bq_download.sh
#   # or submit and chain the pack job:
#   DOWNLOAD_JOB=$(sbatch --parsable job_bq_download.sh)
#   sbatch --dependency=afterok:$DOWNLOAD_JOB job_bq_pack_squashfs.sh
#
# The script exits with the downloader's exit code, so `afterok` dependencies
# only fire when every array task actually succeeded.
#
#SBATCH --account=def-drolnick
#SBATCH --job-name=bq_download
#SBATCH --cpus-per-task=32
#SBATCH --mem=64G
#SBATCH --time=72:00:00
#SBATCH --array=0-9
#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_download_%A_%a.out
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --mail-user=hack1996man@gmail.com

# Must match the size of the --array range above.
NUM_JOBS=10
# Abort immediately when not launched as an array task; otherwise the staging
# dir would become ".../task_" and --task-id would be passed without a value.
TASK_ID="${SLURM_ARRAY_TASK_ID:?must be submitted as a SLURM array job}"

# Each job downloads its images into its own staging subdirectory.
# These are kept after the job ends (on Lustre) for the pack job to merge.
STAGING_BASE="/scratch/melabbas/bq_download_staging"
STAGING_DIR="${STAGING_BASE}/task_${TASK_ID}"

echo "=== bq_download task=${TASK_ID}/${NUM_JOBS} started at $(date) ==="
echo "Node: $(hostname)"
echo "Staging dir: ${STAGING_DIR}"

mkdir -p "${STAGING_DIR}" || exit 1

cd /project/6068129/melabbas/ami-ml || exit 1
module load StdEnv/2023 arrow/17.0.0
source .venv/bin/activate

python src/dataset_tools/bq_squashfs/download_images.py \
    --staging-dir "${STAGING_DIR}" \
    --num-jobs "${NUM_JOBS}" \
    --task-id "${TASK_ID}" \
    --num-workers 32 \
    --chunk-size 10000

EXIT_CODE=$?
echo "=== bq_download task=${TASK_ID} finished at $(date) (exit=${EXIT_CODE}) ==="

# Report the real outcome instead of an unconditional "done".
if [ "${EXIT_CODE}" -eq 0 ]; then
    STATUS="done"
else
    STATUS="FAILED"
fi

# Notification is best-effort: its own failure must not change the job result.
notify "bq_download task ${TASK_ID}: ${STATUS}" \
    "Staging: ${STAGING_DIR} | exit=${EXIT_CODE}" || true

# Propagate the downloader's status so SLURM (and afterok dependents) see it.
exit "${EXIT_CODE}"
#!/bin/bash
# Split a BQ-exported CSV into train/val/test CSVs.
# Reads DATA_DIR/<CSV>, writes train.csv/val.csv/test.csv to DATA_DIR/splits/.
#
# Must run after job_bq_export.sh completes.
#
# Usage:
#   sbatch --dependency=afterok:<EXPORT_JOB_ID> job_bq_split.sh
#   sbatch --export=CSV=global_min25occ.csv --dependency=afterok:<EXPORT_JOB_ID> job_bq_split.sh
#
# CSV defaults to global_min25occ.csv if not set via --export.
#
#SBATCH --account=def-drolnick
#SBATCH --job-name=bq_split
#SBATCH --cpus-per-task=2
#SBATCH --mem=8G
#SBATCH --time=0:30:00
#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_split_%j.out
#SBATCH --mail-type=END,FAIL
#SBATCH --mail-user=hack1996man@gmail.com

set -euo pipefail

BASE_DIR="/project/6068129/melabbas/ami-ml"
DATA_DIR="${BASE_DIR}/data"
CSV="${CSV:-global_min25occ.csv}"
INPUT_CSV="${DATA_DIR}/${CSV}"
OUTPUT_DIR="${DATA_DIR}/splits"
# SLURM_JOB_ID is unset outside SLURM; avoid tripping `set -u` in messages.
JOB_ID="${SLURM_JOB_ID:-unknown}"

# Fail with a clear message (and a notification) when the export is missing,
# instead of letting split.py crash later.
if [ ! -f "${INPUT_CSV}" ]; then
    echo "ERROR: input CSV not found: ${INPUT_CSV}" >&2
    notify "bq_split: FAILED" "input CSV not found: ${INPUT_CSV}" || true
    exit 1
fi

echo "=== bq_split started at $(date) ==="
echo "Node       : $(hostname)"
echo "Input CSV  : ${INPUT_CSV} ($(du -sh "${INPUT_CSV}" | cut -f1))"
echo "Output dir : ${OUTPUT_DIR}"
echo ""

cd "${BASE_DIR}"
module load StdEnv/2023 arrow/17.0.0
source .venv/bin/activate

mkdir -p "${OUTPUT_DIR}"

# With `set -e`, a bare failing command would abort the script right here and
# the FAILED notification below would never run; `|| EXIT_CODE=$?` captures the
# status while keeping errexit active for everything else.
EXIT_CODE=0
python src/dataset_tools/bq_squashfs/split.py \
    --csv "${INPUT_CSV}" \
    --output-dir "${OUTPUT_DIR}" \
    --category-key species_name \
    --val-frac 0.1 \
    --test-frac 0.1 \
    --split-by-occurrence \
    --max-instances 1000 \
    --min-instances 5 \
    --seed 42 || EXIT_CODE=$?

echo ""
echo "=== bq_split done at $(date) (exit=${EXIT_CODE}) ==="

if [ "${EXIT_CODE}" -eq 0 ]; then
    # Subtract 1 for the CSV header row.
    TRAIN_ROWS=$(( $(wc -l < "${OUTPUT_DIR}/train.csv") - 1 ))
    VAL_ROWS=$(( $(wc -l < "${OUTPUT_DIR}/val.csv") - 1 ))
    TEST_ROWS=$(( $(wc -l < "${OUTPUT_DIR}/test.csv") - 1 ))
    echo "  train : ${TRAIN_ROWS} rows"
    echo "  val   : ${VAL_ROWS} rows"
    echo "  test  : ${TEST_ROWS} rows"
    # Best-effort: a broken notifier must not turn a good split into a failed job.
    notify "bq_split: done" \
        "train=${TRAIN_ROWS} val=${VAL_ROWS} test=${TEST_ROWS} → ${OUTPUT_DIR}" || true
else
    notify "bq_split: FAILED" \
        "exit=${EXIT_CODE} — check bq_split_${JOB_ID}.out" || true
    exit 1
fi
"${WBDS}/test" -name "test-*.tar" | wc -l) +NUM_CLASSES=$(python3 -c "import json; print(len(json.load(open('${WBDS}/class_map.json'))))") + +if [[ "${N_TRAIN}" -eq 0 ]]; then + echo "ERROR: no train shards found in ${WBDS}/train" + notify "bq_train: FAILED" "No train shards found — did job_build_wds_global_v2 complete?" + exit 1 +fi + +TRAIN_PAT="${WBDS}/train/train-{000000..$(printf '%06d' $((N_TRAIN-1)))}.tar" +VAL_PAT="${WBDS}/val/val-{000000..$(printf '%06d' $((N_VAL-1)))}.tar" +TEST_PAT="${WBDS}/test/test-{000000..$(printf '%06d' $((N_TEST-1)))}.tar" + +echo "=== bq_train started at $(date) ===" +echo "Node : $(hostname)" +echo "Train shards: ${N_TRAIN}" +echo "Val shards : ${N_VAL}" +echo "Test shards : ${N_TEST}" +echo "Num classes : ${NUM_CLASSES}" +echo "Models dir : ${MODELS}" +echo "" + +# Resume from latest checkpoint if available +RESUME_FLAG="" +RESUME_CKPT=$(ls -t "${MODELS}"/*_latest.pt 2>/dev/null | head -1 || true) +if [[ -z "${RESUME_CKPT}" ]]; then + RESUME_CKPT=$(ls -t "${MODELS}"/*_checkpoint.pt 2>/dev/null | head -1 || true) +fi +if [[ -n "${RESUME_CKPT}" ]]; then + echo "Resuming from checkpoint: ${RESUME_CKPT}" + RESUME_FLAG="--resume_from_checkpoint ${RESUME_CKPT}" +else + echo "No checkpoint found — starting fresh" +fi +echo "" + +ami-classification train-model \ + --train_webdataset "${TRAIN_PAT}" \ + --val_webdataset "${VAL_PAT}" \ + --test_webdataset "${TEST_PAT}" \ + --num_classes "${NUM_CLASSES}" \ + --model_type resnet50 \ + --image_input_size 128 \ + --model_save_directory "${MODELS}" \ + --total_epochs 30 \ + --warmup_epochs 2 \ + --early_stopping 100 \ + --learning_rate 0.001 \ + --learning_rate_scheduler cosine \ + --weight_decay 1e-5 \ + --batch_size 128 \ + --preprocess_mode torch \ + --mixed_resolution_data_aug true \ + --random_seed 123 \ + --wandb_entity "hack1996man" \ + --wandb_project "ai_for_leps" \ + --wandb_run_name "global_wds_bq_run1" \ + ${RESUME_FLAG} + +EXIT_CODE=$? 
#!/bin/bash
# Build the global WebDataset from all 10 sqfs files.
#
# Requires the per-split CSVs train.csv / val.csv / test.csv in SPLITS_DIR
# (written by job_bq_split.sh / split.py). This script reads no other CSV.
#
# Strategy: two batches of BATCH_SIZE sqfs to stay under 7 TB NVMe peak:
#   Batch 1 (sqfs 0-4): --tar-mode w  (create new tars)
#   Batch 2 (sqfs 5-9): --tar-mode a  (append, idempotent on retry)
#
#SBATCH --account=def-drolnick
#SBATCH --job-name=build_wds_global
#SBATCH --cpus-per-task=16
#SBATCH --mem=64G
#SBATCH --tmp=7000G
#SBATCH --time=12:00:00
#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/build_wds_global_%j.out
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --mail-user=hack1996man@gmail.com
#SBATCH --exclude=fc30554

# Node-local scratch; refuse to run without it rather than scatter to "".
NVME="${SLURM_TMPDIR:?SLURM_TMPDIR is not set — run this through sbatch}"
OUTPUT_DIR="/scratch/melabbas/global_wds"
SPLITS_DIR="/project/6068129/melabbas/ami-ml/data/splits"   # contains train.csv val.csv test.csv
IMAGES_PER_SHARD=1000
NUM_SQFS=10
BATCH_SIZE=5
PACK_WORKERS=16
JOB_ID="${SLURM_JOB_ID:-unknown}"

echo "=== build_wds_global started at $(date) ==="
echo "Node: $(hostname)"
echo "NVMe: ${NVME} ($(df -h "${NVME}" | tail -1 | awk '{print $4}') free)"
echo ""

for SPLIT in train val test; do
    if [ ! -f "${SPLITS_DIR}/${SPLIT}.csv" ]; then
        echo "ERROR: ${SPLITS_DIR}/${SPLIT}.csv not found"
        echo "Run job_bq_split.sh (split.py) first to generate the split CSVs."
        exit 1
    fi
done
# Subtract 1 per file for the CSV header row.
echo "Split CSVs: $(( $(wc -l < "${SPLITS_DIR}/train.csv") - 1 )) train  $(( $(wc -l < "${SPLITS_DIR}/val.csv") - 1 )) val  $(( $(wc -l < "${SPLITS_DIR}/test.csv") - 1 )) test rows"
echo ""

# ── Locate all sqfs files ─────────────────────────────────────────────────────
# Print the path of task_<T>.sqfs from the first directory that has it;
# print an empty line when it is found nowhere.
find_sqfs() {
    local T=$1
    local DIR
    for DIR in /project/rrg-bengioy-ad/melabbas /project/6068129/melabbas /scratch/melabbas; do
        [ -f "${DIR}/task_${T}.sqfs" ] && echo "${DIR}/task_${T}.sqfs" && return 0
    done
    echo ""
}

SQFS_PATHS=()
echo "Locating sqfs files..."
for T in $(seq 0 $((NUM_SQFS - 1))); do
    P=$(find_sqfs "${T}")
    if [ -z "${P}" ]; then
        echo "ERROR: could not find task_${T}.sqfs"
        exit 1
    fi
    SQFS_PATHS+=("${P}")
    echo "  task_${T}: ${P} ($(du -sh "${P}" | cut -f1))"
done
echo ""

# ── Setup Lustre output with HDD stripe ──────────────────────────────────────
for SPLIT in train val test; do
    mkdir -p "${OUTPUT_DIR}/${SPLIT}"
    # Striping is an optimisation only; ignore failures (e.g. pool not available).
    lfs setstripe -c -1 -S 4m -p ddn_hdd "${OUTPUT_DIR}/${SPLIT}" 2>/dev/null || true
done
echo "Output dir: ${OUTPUT_DIR}"
echo ""

# ── Python environment ────────────────────────────────────────────────────────
cd /project/6068129/melabbas/ami-ml || exit 1
module load StdEnv/2023 arrow/17.0.0
source .venv/bin/activate

# Run create_webdataset.py for one batch of sqfs files.
#   $1 = index of the first sqfs in this batch (--sqfs-start-idx)
#   $2 = tar mode: w (create) or a (append)
#   $3… = sqfs paths belonging to the batch
run_batch() {
    local START_IDX=$1
    local TAR_MODE=$2
    shift 2
    python src/dataset_tools/bq_squashfs/create_webdataset.py \
        --split-csvs "train:${SPLITS_DIR}/train.csv" "val:${SPLITS_DIR}/val.csv" "test:${SPLITS_DIR}/test.csv" \
        --sqfs-paths "$@" \
        --sqfs-start-idx "${START_IDX}" \
        --images-per-shard "${IMAGES_PER_SHARD}" \
        --nvme-dir "${NVME}" \
        --output-dir "${OUTPUT_DIR}" \
        --pack-workers "${PACK_WORKERS}" \
        --tar-mode "${TAR_MODE}"
}

# ── Batch 1: sqfs 0–(BATCH_SIZE-1), create new tars ──────────────────────────
BATCH1_END=$((BATCH_SIZE - 1))
echo "=== BATCH 1: sqfs 0–${BATCH1_END} (mode=w) at $(date) ==="

run_batch 0 w "${SQFS_PATHS[@]:0:${BATCH_SIZE}}"

BATCH1_EXIT=$?
echo ""
if [ "${BATCH1_EXIT}" -ne 0 ]; then
    echo "ERROR: batch 1 failed (exit=${BATCH1_EXIT})"
    notify "build_wds_global: FAILED (batch 1)" \
        "exit=${BATCH1_EXIT} — check build_wds_global_${JOB_ID}.out"
    exit "${BATCH1_EXIT}"
fi

# ── Batch 2: sqfs BATCH_SIZE–(NUM_SQFS-1), append to existing tars ───────────
echo "=== BATCH 2: sqfs ${BATCH_SIZE}–$((NUM_SQFS - 1)) (mode=a) at $(date) ==="

run_batch "${BATCH_SIZE}" a "${SQFS_PATHS[@]:${BATCH_SIZE}}"

BATCH2_EXIT=$?
echo ""
echo "=== build_wds_global done at $(date) (exit=${BATCH2_EXIT}) ==="

if [ "${BATCH2_EXIT}" -eq 0 ]; then
    for SPLIT in train val test; do
        COUNT=$(ls "${OUTPUT_DIR}/${SPLIT}/"*.tar 2>/dev/null | wc -l)
        SIZE=$(du -sh "${OUTPUT_DIR}/${SPLIT}" 2>/dev/null | cut -f1)
        echo "  ${SPLIT}: ${COUNT} shards  ${SIZE}"
    done
    notify "build_wds_global: done" \
        "exit=0 job=${JOB_ID} — train/val/test shards written to ${OUTPUT_DIR}"
else
    notify "build_wds_global: FAILED (batch 2)" \
        "exit=${BATCH2_EXIT} — check build_wds_global_${JOB_ID}.out"
fi

exit "${BATCH2_EXIT}"
+ +## Overview + +``` +[BigQuery: training_images] + │ + ▼ Stage 1: download + task_0.sqfs … task_9.sqfs ← all downloaded images, split into 10 chunks + │ + ▼ Stage 2: bq_export + global_min25occ.csv ← metadata for qualifying images (species, paths, taxon IDs) + │ + ▼ Stage 3: split + splits/train.csv + splits/val.csv + splits/test.csv + │ + ▼ Stage 4: webdataset + global_wds/{train,val,test}/*.tar + class_map.json + │ + ▼ Stage 5: train + models/global_wds/*.pt +``` + +--- + +## Data Source: BigQuery `training_images` + +All images originate from the `leps-ai.global_butterflies_2604.training_images` +BigQuery table. Each row represents one image with: + +- `photo_id` — iNaturalist photo ID (integer) +- `gbif_id` — GBIF occurrence ID (multiple images can share the same occurrence) +- `relative_local_path` — path of the image file within its SquashFS chunk +- `inat_taxon_id` — iNaturalist taxon ID (used as the class label) +- `fetch_status` — `'downloaded'` once the image has been fetched to disk + +Only rows with `fetch_status = 'downloaded'` are used in the pipeline. + +--- + +## SquashFS Chunks: task_0 … task_9 + +Images on disk are stored in 10 SquashFS archives (`task_0.sqfs` … `task_9.sqfs`), +located at `/project/rrg-bengioy-ad/melabbas/` and `/project/6068129/melabbas/`. + +The task assignment is deterministic: + +``` +task_id = photo_id % 10 +``` + +So `task_3.sqfs` contains all images whose `photo_id` ends in 3. The +`relative_local_path` column in BigQuery is the path of the image *within* its +sqfs archive. This split allows parallel downloading across 10 SLURM array tasks +and efficient batched processing during WebDataset creation. + +--- + +## Stage 1 — Download + +**SLURM job:** `scripts/job_bq_download.sh` +**Python:** `download_images.py` + +Runs as a SLURM array job (10 tasks). Each task handles `photo_id % 10 == task_id` +and performs the full download-verify-pack loop: + +1. 
Queries BQ for pending images assigned to this task, skipping any already + recorded in `training_images_downloads` — fully resumable on resubmit. +2. Downloads images in parallel (32 workers) from `absolute_url`. +3. Verifies each image with PIL (width, height, corruption check). +4. Writes results back to the `training_images_downloads` BQ table + (`fetch_status`: `downloaded`, `failed`, or `corrupted`). +5. Every 10,000 images, packs the staging dir into a `chunk_NNNN.sqfs` file + using `mksquashfs`, then deletes the raw images to keep inode usage low. + +After all 10 tasks complete, the per-chunk sqfs files in each staging dir are +merged into the final `task_N.sqfs` archives by the pack job. + +```bash +sbatch scripts/job_bq_download.sh +``` + +**Output:** `task_0.sqfs` … `task_9.sqfs` + +--- + +## Stage 2 — BQ Export + +**SLURM job:** `scripts/job_bq_export.sh` +**Python:** `bq_export.py` + +Runs a SQL query against BigQuery and streams the results to a CSV file on +Lustre. The query joins `training_images` with `inat_taxa` to attach species +names and taxonomy, and filters to images with `fetch_status = 'downloaded'`. 
+ +Two queries are available under `queries/`: + +| Query file | Filter | Expected rows | +|---|---|---| +| `global_min25occ.sql` | Species with ≥ 25 distinct GBIF occurrences | ~10.6M images, ~4,704 species | +| `global_max2000img.sql` | Cap at 2,000 images per species | smaller subset | + +```bash +# global_min25occ.csv (QUERY_FILE and OUTPUT must both be passed; the job script defines no defaults) +sbatch --export=QUERY_FILE=queries/global_min25occ.sql,OUTPUT=global_min25occ.csv \ + --dependency=afterok:<DOWNLOAD_JOB_ID> scripts/job_bq_export.sh + +# Custom query +sbatch --export=QUERY_FILE=queries/global_max2000img.sql,OUTPUT=global_max2000img.csv \ + --dependency=afterok:<DOWNLOAD_JOB_ID> scripts/job_bq_export.sh +``` + +**Input:** BigQuery `training_images` + `inat_taxa` tables +**Output:** `data/global_min25occ.csv` (or custom filename) + +--- + +## Stage 3 — Split + +**SLURM job:** `scripts/job_bq_split.sh` +**Python:** `split.py` + +Splits the BQ-exported CSV into `train.csv`, `val.csv`, and `test.csv` using +stratified sampling so every species is proportionally represented in all three +sets. + +Key behaviours: +- **Split by occurrence** — images that share the same `gbif_id` (same field + observation) are always kept in the same split, preventing data leakage. +- **Max instances** — caps images per species at 1,000 for train (proportional + for val/test) to avoid class imbalance dominating training. +- **Min instances** — species with fewer than 5 training images are dropped. + +```bash +sbatch --dependency=afterok:<EXPORT_JOB_ID> scripts/job_bq_split.sh + +# Custom input CSV +sbatch --export=CSV=global_max2000img.csv \ + --dependency=afterok:<EXPORT_JOB_ID> scripts/job_bq_split.sh +``` + +**Input:** `data/<CSV>` (default: `global_min25occ.csv`) +**Output:** `data/splits/train.csv`, `data/splits/val.csv`, `data/splits/test.csv` + +--- + +## Stage 4 — WebDataset + +**SLURM job:** `scripts/job_bq_webdataset.sh` +**Python:** `create_webdataset.py` or `create_webdataset_generic.py` + +Packs images from the SquashFS archives into WebDataset tar shards, organised +into train/val/test splits. Each shard contains ~1,000 images.
Every image is +stored as a triplet inside the tar: + +- `.jpg` — image bytes +- `.cls` — integer class ID (text) +- `.json` — metadata (`class_id`, `relative_local_path`, `task_id`, species fields) + +A `class_map.json` is also written to the output root, mapping sequential class +IDs to `inat_taxon_id` values. + +```bash +sbatch --dependency=afterok: scripts/job_bq_webdataset.sh +``` + +**Input:** `task_0.sqfs` … `task_9.sqfs`, `data/splits/{train,val,test}.csv` +**Output:** `/scratch/melabbas/global_wds/{train,val,test}/*.tar`, `class_map.json` + +### Two versions of `create_webdataset` + +| Script | When to use | +|---|---| +| `create_webdataset.py` | On **fir** — copies each sqfs to NVMe (`$SLURM_TMPDIR`), mounts with `squashfuse`, scatters images to NVMe shard dirs, packs to Lustre tars. Two-batch strategy keeps NVMe usage under 7 TB. | +| `create_webdataset_generic.py` | On **any machine** where sqfs are already mounted. Takes a parent directory containing `task_0/` … `task_9/` subdirectories. Reads images directly from the mounted dirs and streams them into tar shards — no NVMe copy, no squashfuse calls. | + +The fir-specific version (`create_webdataset.py`) exists because the SquashFS +files are ~1 TB each and reading them directly from Lustre during scatter is +too slow — copying to NVMe first gives ~10× better I/O throughput. On machines +where the sqfs are already mounted (or images are on a fast local filesystem), +the generic version is simpler and produces identical output. 
+ +To use the generic version, mount each sqfs into a parent directory first: + +```bash +mkdir -p /mnt/images/task_{0..9} +for T in $(seq 0 9); do + squashfuse /path/to/task_${T}.sqfs /mnt/images/task_${T} +done + +python create_webdataset_generic.py \ + --images-dir /mnt/images \ + --split-csvs train:splits/train.csv val:splits/val.csv test:splits/test.csv \ + --output-dir /output/global_wds +``` + +--- + +## Stage 5 — Train + +**SLURM job:** `scripts/job_bq_train.sh` + +Trains a ResNet-50 classifier on the WebDataset shards using the +`ami-classification train-model` CLI. Shard counts and number of classes are +resolved dynamically from the output directory. Automatically resumes from the +latest checkpoint if one exists in the models directory. + +```bash +sbatch --dependency=afterok: scripts/job_bq_train.sh +``` + +**Input:** `/scratch/melabbas/global_wds/{train,val,test}/*.tar`, `class_map.json` +**Output:** `/project/6068129/melabbas/ami-ml/models/global_wds/*.pt` + +--- + +## Chaining the Full Pipeline + +```bash +DOWNLOAD_JOB=$(sbatch --parsable scripts/job_bq_download.sh) +EXPORT_JOB=$(sbatch --parsable --dependency=afterok:$DOWNLOAD_JOB scripts/job_bq_export.sh) +SPLIT_JOB=$(sbatch --parsable --dependency=afterok:$EXPORT_JOB scripts/job_bq_split.sh) +WDS_JOB=$(sbatch --parsable --dependency=afterok:$SPLIT_JOB scripts/job_bq_webdataset.sh) +sbatch --dependency=afterok:$WDS_JOB scripts/job_bq_train.sh +``` + +--- + +## File Reference + +| File | Role | +|---|---| +| `download_images.py` | Stage 1 — fetch images from iNaturalist, verify with PIL, write results back to BQ (`training_images_downloads`), pack into per-chunk sqfs files | +| `bq_export.py` | Stage 2 — export BigQuery query results to CSV | +| `split.py` | Stage 3 — stratified train/val/test split with occurrence-level grouping | +| `create_webdataset.py` | Stage 4 — fir-specific, NVMe-optimised WebDataset packer | +| `create_webdataset_generic.py` | Stage 4 — generic WebDataset packer for 
# Emit a progress line every LOG_EVERY rows while streaming.
LOG_EVERY = 500_000


def log(msg: str, *, file=None) -> None:
    """Print *msg* and flush immediately so SLURM logs update in real time.

    Args:
        msg: Text to print.
        file: Optional text stream (e.g. ``sys.stderr``). ``None`` means the
            current ``sys.stdout``, which is what ``print`` uses by default.
    """
    # export() reports errors with ``file=sys.stderr``; without this keyword
    # those calls raised TypeError instead of printing the error message.
    print(msg, file=file, flush=True)


def export(query_file: Path, output: Path, project: str) -> None:
    """Run the SQL in *query_file* on BigQuery and stream the rows to a CSV.

    The CSV header is taken from the keys of the first result row. While
    streaming, distinct ``species_name`` and ``gbif_id`` values are counted
    (when those columns are present) for the final summary.

    Args:
        query_file: Path to a file containing the query text.
        output: Destination CSV path; parent directories are created.
        project: GCP project passed to ``bigquery.Client``.

    Exits the process with status 1 (``SystemExit``) when the query file is
    empty or the query returns no rows.
    """
    query = query_file.read_text(encoding="utf-8").strip()
    if not query:
        log(f"ERROR: query file {query_file} is empty", file=sys.stderr)
        sys.exit(1)

    client = bigquery.Client(project=project)

    log(f"Query file : {query_file}")
    log(f"Output     : {output}")
    log("Submitting query ...")

    job = client.query(query)
    log(f"Job ID     : {job.job_id}")
    log("Streaming rows to CSV ...")

    output.parent.mkdir(parents=True, exist_ok=True)
    t0 = time.perf_counter()

    rows_written = 0
    unique_species: set[str] = set()
    unique_occurrences: set[int] = set()

    # Explicit UTF-8: species names may contain non-ASCII characters and the
    # batch environment's locale encoding is not guaranteed to handle them.
    with open(output, "w", newline="", encoding="utf-8") as f:
        writer = None
        for row in job.result():
            row_dict = dict(row)

            # Create the writer lazily: column names are only known once the
            # first row has arrived.
            if writer is None:
                writer = csv.DictWriter(f, fieldnames=list(row_dict.keys()))
                writer.writeheader()

            writer.writerow(row_dict)
            rows_written += 1

            if "species_name" in row_dict:
                unique_species.add(row_dict["species_name"])
            if "gbif_id" in row_dict:
                unique_occurrences.add(row_dict["gbif_id"])

            if rows_written % LOG_EVERY == 0:
                elapsed_s = time.perf_counter() - t0
                log(f"  {rows_written:,} rows  {elapsed_s / 60:.1f} min  "
                    f"({rows_written / elapsed_s:,.0f} rows/s)")

    if rows_written == 0:
        # Do not leave a zero-byte CSV behind that could be mistaken for a
        # finished export.
        output.unlink(missing_ok=True)
        log("ERROR: query returned 0 rows", file=sys.stderr)
        sys.exit(1)

    elapsed = (time.perf_counter() - t0) / 60
    size_mb = output.stat().st_size / 1024**2

    log("\n=== Export complete ===")
    log(f"  Rows written    : {rows_written:,}")
    if unique_species:
        log(f"  Unique species  : {len(unique_species):,}")
    if unique_occurrences:
        log(f"  Unique gbif_ids : {len(unique_occurrences):,}")
    log(f"  Output          : {output}  ({size_mb:.0f} MB)")
    log(f"  Elapsed         : {elapsed:.1f} min")


def main() -> None:
    """Parse command-line arguments and run :func:`export`."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--query-file", required=True, type=Path,
                        help="Path to a .sql file containing the SELECT query")
    parser.add_argument("--output", required=True, type=Path,
                        help="Path to write the output CSV file")
    parser.add_argument("--project", default="leps-ai",
                        help="GCP project (default: leps-ai)")
    args = parser.parse_args()

    if not args.query_file.exists():
        print(f"ERROR: query file not found: {args.query_file}", file=sys.stderr)
        sys.exit(1)

    export(
        query_file=args.query_file,
        output=args.output,
        project=args.project,
    )


if __name__ == "__main__":
    main()
Load per-split CSVs (produced by split.py) + 2. Assign class IDs from the taxon ID column (inat_taxon_id, sorted numerically) if class_id absent + 3. Assign each image to a shard within its split + — hash-based (--n-shards) or CSV-order-based (--images-per-shard) + 4. For each sqfs: copy to NVMe → mount → scatter images to split/shard dirs + 5. Pack each split's shard dirs → Lustre tar files (shuffled within each shard) + +CSV required columns (same in all split files): + photo_id — integer; task_id = photo_id % 10 (which sqfs) + relative_local_path — path of the image within the sqfs + inat_taxon_id — numeric taxon ID; written to class_map.json and used to assign class_id (sorted numerically) if class_id absent + +CSV optional class column: + class_id — integer; used directly if present + +Any other CSV columns are passed through into per-sample .json metadata. + +Typical usage — two batches for global WDS (stays under 7 TB NVMe peak): + + Batch 1 — sqfs 0–4, create new tars: + python create_webdataset.py \\ + --split-csvs train:train.csv val:val.csv test:test.csv \\ + --sqfs-paths task_0.sqfs ... task_4.sqfs \\ + --sqfs-start-idx 0 \\ + --n-shards 10700 \\ + --nvme-dir $SLURM_TMPDIR \\ + --output-dir /scratch/melabbas/global_wds \\ + --tar-mode w + + Batch 2 — sqfs 5–9, append to existing tars: + python create_webdataset.py \\ + --split-csvs train:train.csv val:val.csv test:test.csv \\ + --sqfs-paths task_5.sqfs ... task_9.sqfs \\ + --sqfs-start-idx 5 \\ + --n-shards 10700 \\ + --nvme-dir $SLURM_TMPDIR \\ + --output-dir /scratch/melabbas/global_wds \\ + --tar-mode a + +Shard assignment is deterministic (hash of rel_path + seed) — the same image +always maps to the same shard, so --tar-mode a is safe to retry. 
# ── Logging ───────────────────────────────────────────────────────────────────

def log(msg: str) -> None:
    """Print *msg* prefixed with the hours elapsed since ``_t_start``, flushed."""
    hours = (time.perf_counter() - _t_start) / 3600
    print(f"[{hours:5.2f}h] {msg}", flush=True)


# ── Disk helpers ──────────────────────────────────────────────────────────────

def disk_free_tb(path: str) -> float:
    """Return the space available on the filesystem holding *path*, in TiB.

    Uses ``shutil.disk_usage`` instead of parsing ``df`` output, so there is
    no subprocess and a bad path raises ``OSError`` rather than failing while
    parsing empty output.
    """
    return shutil.disk_usage(path).free / 1024**4


# ── squashfuse helpers ────────────────────────────────────────────────────────

def sqfs_mount(sqfs_path: Path, mnt_dir: Path) -> None:
    """Mount *sqfs_path* on *mnt_dir* with ``squashfuse``.

    Raises:
        RuntimeError: if ``squashfuse`` exits non-zero (stderr is included).
    """
    proc = subprocess.run(["squashfuse", str(sqfs_path), str(mnt_dir)],
                          capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"squashfuse failed for {sqfs_path}: {proc.stderr.strip()}")


def sqfs_unmount(mnt_dir: Path) -> None:
    """Unmount *mnt_dir* with ``fusermount -u`` (best effort).

    A failure is logged instead of raised so this stays safe to call from a
    ``finally`` block, but it is no longer silent.
    """
    proc = subprocess.run(["fusermount", "-u", str(mnt_dir)],
                          capture_output=True, text=True)
    if proc.returncode != 0:
        log(f"WARNING: fusermount -u {mnt_dir} failed: {proc.stderr.strip()}")
# ── CSV loading ───────────────────────────────────────────────────────────────

def load_split_csvs(
    split_csvs: list[tuple[str, str]],
    taxon_id_column: str = "inat_taxon_id",
) -> tuple[dict[str, pd.DataFrame], dict[int, int]]:
    """Load per-split CSVs, assign class_ids if absent, return (dfs, class_map).

    Every CSV must contain 'photo_id', 'relative_local_path' and
    *taxon_id_column*. Two helper columns are added to each frame:
    'task_id' (photo_id % 10) and 'row_idx' (position within that CSV).

    If 'class_id' is absent, taxon IDs from all splits are sorted numerically
    and mapped to sequential class IDs. If it is present it is used as-is.
    Either way class_map is {class_id: taxon_id}.

    Raises:
        ValueError: if a required column is missing from any CSV.
    """
    dfs: dict[str, pd.DataFrame] = {}
    for name, path in split_csvs:
        df = pd.read_csv(path, dtype={"photo_id": "Int64"})
        missing = {"photo_id", "relative_local_path"} - set(df.columns)
        if missing:
            raise ValueError(f"{path}: missing required columns {missing}")
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if taxon_id_column not in df.columns:
            raise ValueError(
                f"{path}: required label column '{taxon_id_column}' not found "
                f"(available: {list(df.columns)})"
            )
        df["task_id"] = (df["photo_id"] % 10).astype(int)
        df["row_idx"] = range(len(df))  # position in this split's CSV (used by --images-per-shard)
        dfs[name] = df
        log(f" {name}: {len(df):,} rows from {path}")

    # Assign class_ids across all splits together so IDs are consistent
    all_df = pd.concat(dfs.values(), ignore_index=True)
    class_map: dict[int, int] = {}
    if "class_id" not in all_df.columns:
        taxon_ids = sorted(all_df[taxon_id_column].dropna().unique())
        id_map = {int(tid): idx for idx, tid in enumerate(taxon_ids)}
        for df in dfs.values():
            df["class_id"] = df[taxon_id_column].map(id_map).astype("Int64")
        log(f" Assigned class_ids for {len(taxon_ids):,} taxon IDs "
            f"(column='{taxon_id_column}', sorted numerically)")
        class_map = {idx: tid for tid, idx in id_map.items()}
    else:
        # class_id already in CSV — build reverse map from taxon_id
        class_map = (
            all_df[["class_id", taxon_id_column]]
            .dropna()
            .drop_duplicates("class_id")
            .set_index("class_id")[taxon_id_column]
            .astype(int)
            .to_dict()
        )

    return dfs, class_map


def save_class_map(class_map: dict[int, int], output_dir: Path) -> None:
    """Write class_map as ``class_map.json`` in *output_dir*, unless it exists.

    An existing file is left untouched, so a later batch cannot overwrite the
    mapping written by an earlier one. Keys are written as strings, sorted.
    """
    dest = output_dir / "class_map.json"
    if dest.exists():
        log(f"class_map.json already exists — skipping ({dest})")
        return
    with open(dest, "w") as f:
        json.dump({str(k): v for k, v in sorted(class_map.items())}, f, indent=2)
    log(f"Saved class_map.json ({len(class_map):,} classes → {dest})")
def build_lookup(
    dfs: dict[str, pd.DataFrame],
    task_ids: list[int],
    meta_columns: list[str],
    images_per_shard: int,
) -> dict[int, dict[str, dict]]:
    """Build {task_id: {rel_path: {split, shard_id, class_id, meta}}}.

    Only rows whose 'task_id' is in *task_ids* are included.

    Shard assignment: CSV-order — row_idx // images_per_shard.
    First N rows → shard 0, next N → shard 1, etc.

    'meta' holds the non-null values of *meta_columns* that exist in the
    split's frame. Rows are read by column position, not attribute name, so
    columns whose names are not valid Python identifiers (e.g. 'class',
    'common name') are kept instead of being silently dropped.
    """
    lookup: dict[int, dict[str, dict]] = {tid: {} for tid in task_ids}

    for split, df in dfs.items():
        subset = df[df["task_id"].isin(task_ids)]
        pos = {col: i for i, col in enumerate(subset.columns)}
        rel_i = pos["relative_local_path"]
        task_i = pos["task_id"]
        row_i = pos["row_idx"]
        class_i = pos["class_id"]
        meta_pos = [(col, pos[col]) for col in meta_columns if col in pos]

        # name=None yields plain tuples: no field renaming, and faster.
        for row in subset.itertuples(index=False, name=None):
            meta = {col: row[i] for col, i in meta_pos if not pd.isna(row[i])}
            lookup[row[task_i]][row[rel_i]] = {
                "split": split,
                "shard_id": int(row[row_i]) // images_per_shard,
                "class_id": int(row[class_i]),
                "meta": meta,
            }

    for tid in task_ids:
        log(f" task_{tid}: {len(lookup[tid]):,} images in lookup")
    return lookup
# ── Scatter ───────────────────────────────────────────────────────────────────

def scatter_sqfs(
    sqfs_path: Path,
    sqfs_idx: int,
    nvme_dir: Path,
    nvme_mnt: Path,
    nvme_shards: Path,
    task_lookup: dict[str, dict],
    no_nvme_copy: bool = False,
    limit: int = 0,
) -> tuple[int, int]:
    """Copy sqfs to NVMe, mount, scatter images + labels to split/shard dirs.

    For each image in the mounted sqfs that has a *task_lookup* entry, three
    files are written to ``nvme_shards/<split>/shard_<id>/``: the image,
    ``<key>.cls`` (class id) and ``<key>.json`` (metadata). ``<key>`` is the
    md5 of ``"<sqfs_idx>:<rel_path>"``. The shard directories must already
    exist.

    Args:
        no_nvme_copy: mount sqfs directly from source (skips copy; for testing).
        limit: stop after N written images per sqfs (0 = no limit; smoke tests).

    Returns:
        (written, skipped) — skipped counts images that are in the sqfs but
        not in *task_lookup*.

    Raises:
        RuntimeError: if mounting fails (from ``sqfs_mount``).
    """
    if no_nvme_copy:
        mount_target = sqfs_path
        log(f" [{sqfs_path.name}] mounting directly (--no-nvme-copy)")
    else:
        mount_target = nvme_dir / sqfs_path.name

    written = skipped = 0
    try:
        if not no_nvme_copy:
            log(f" [{sqfs_path.name}] copying to NVMe "
                f"(NVMe free: {disk_free_tb(str(nvme_dir)):.2f} TB)")
            t0 = time.perf_counter()
            shutil.copy2(str(sqfs_path), str(mount_target))
            gb = mount_target.stat().st_size / 1024**3
            elapsed = max(time.perf_counter() - t0, 1e-9)
            log(f" [{sqfs_path.name}] copied {gb:.1f} GB in {elapsed:.0f}s "
                f"({gb * 1024 / elapsed:.0f} MB/s)")

        sqfs_mount(mount_target, nvme_mnt)
        try:
            candidates = (
                p for p in nvme_mnt.rglob("*")
                if p.suffix.lower() in IMAGE_EXTENSIONS
            )
            if limit:
                # Don't materialise the full list — stop enumerating once limit is hit
                all_paths = candidates
                total_label = "?"  # a generator has no len()
            else:
                # Sort for sequential reads → squashfuse block cache reuse
                all_paths = sorted(candidates)
                total_label = f"{len(all_paths):,}"
            cap = f" (capped at {limit})" if limit else ""
            log(f" [{sqfs_path.name}] scattering{cap} "
                f"(lookup: {len(task_lookup):,} entries)...")

            total_bytes = 0
            t0 = time.perf_counter()

            for img_path in all_paths:
                if limit and written >= limit:
                    break

                rel = str(img_path.relative_to(nvme_mnt))
                entry = task_lookup.get(rel)
                if entry is None:
                    skipped += 1
                    continue

                class_id = entry["class_id"]
                key = hashlib.md5(f"{sqfs_idx}:{rel}".encode()).hexdigest()
                shard_dir = nvme_shards / entry["split"] / f"shard_{entry['shard_id']:06d}"
                ext = img_path.suffix.lower()

                img_bytes = img_path.read_bytes()
                (shard_dir / f"{key}{ext}").write_bytes(img_bytes)
                (shard_dir / f"{key}.cls").write_bytes(str(class_id).encode())
                (shard_dir / f"{key}.json").write_bytes(
                    json.dumps({
                        "class_id": class_id,
                        "relative_local_path": rel,
                        "sqfs_idx": sqfs_idx,
                        **entry["meta"],
                    }).encode()
                )

                total_bytes += len(img_bytes)
                written += 1
                if written % 100_000 == 0:
                    elapsed = max(time.perf_counter() - t0, 1e-9)
                    log(f" {written:,}/{total_label} "
                        f"{total_bytes/1024**3:.1f} GB {written/elapsed:.0f} img/s")

            elapsed = max(time.perf_counter() - t0, 1e-9)
            log(f" [{sqfs_path.name}] scattered {written:,} in {elapsed:.0f}s "
                f"({written/elapsed:.0f} img/s) skipped={skipped:,}")

            if not limit:
                missing = len(task_lookup) - written
                if missing > 0:
                    log(f" WARNING: {missing:,} images are in the CSV but were not found in "
                        f"{sqfs_path.name} — they will be missing from the dataset")
        finally:
            sqfs_unmount(nvme_mnt)
    finally:
        # Always reclaim the NVMe copy, including after a failed copy or mount.
        if not no_nvme_copy:
            mount_target.unlink(missing_ok=True)
            log(f" [{sqfs_path.name}] deleted NVMe free: {disk_free_tb(str(nvme_dir)):.2f} TB")
    return written, skipped
# ── Pack ──────────────────────────────────────────────────────────────────────

def pack_split(
    split: str,
    nvme_shards: Path,
    lustre_split_dir: Path,
    n_shards: int,
    tar_mode: str,
    pack_workers: int,
    seed: int,
) -> int:
    """Pack one split's NVMe shard dirs → Lustre tar files. Returns total bytes.

    Each non-empty ``nvme_shards/<split>/shard_<id>/`` becomes (or, with
    tar_mode 'a', is appended to) ``lustre_split_dir/<split>-<id>.tar``.
    Files are grouped per sample (same stem), samples are shuffled, and every
    source file is deleted once handled. In append mode, members already in
    the tar are not added again.

    The shuffle uses a per-shard RNG seeded from (seed, split, shard_id) over
    sorted keys, so the order is reproducible regardless of thread scheduling
    or directory listing order.
    """
    split_shards = nvme_shards / split
    candidates = (split_shards / f"shard_{s:06d}" for s in range(n_shards))
    non_empty = [d for d in candidates if d.exists() and any(d.iterdir())]
    log(f" [{split}] packing {len(non_empty):,} non-empty shards "
        f"(mode='{tar_mode}', {pack_workers} workers)")

    def pack_one(shard_dir: Path) -> int:
        shard_id = int(shard_dir.name.split("_")[1])
        tar_path = lustre_split_dir / f"{split}-{shard_id:06d}.tar"
        files = list(shard_dir.iterdir())
        if not files:
            return 0

        existing_names: set[str] = set()
        if tar_mode == "a" and tar_path.exists():
            with tarfile.open(tar_path, "r") as tf:
                existing_names = set(tf.getnames())

        # Group into per-sample triplets (.cls .jpg .json), then shuffle samples
        key_groups: dict[str, list[Path]] = defaultdict(list)
        for p in files:
            key_groups[p.stem].append(p)
        keys = sorted(key_groups)
        random.Random(f"{seed}:{split}:{shard_id}").shuffle(keys)

        total = 0
        with tarfile.open(tar_path, tar_mode) as tf:
            for key in keys:
                for p in sorted(key_groups[key]):  # consistent order within sample
                    if p.name not in existing_names:
                        tf.add(p, arcname=p.name)
                        total += p.stat().st_size
                    p.unlink()
        return total

    t0 = time.perf_counter()
    with ThreadPoolExecutor(max_workers=pack_workers) as pool:
        total_bytes = sum(pool.map(pack_one, non_empty))
    elapsed = max(time.perf_counter() - t0, 1e-9)
    log(f" [{split}] packed {total_bytes/1024**3:.1f} GB in {elapsed:.0f}s "
        f"({total_bytes/1024**2/elapsed:.0f} MB/s)")
    return total_bytes


def pack_to_lustre(
    splits: list[str],
    shards_per_split: dict[str, int],
    nvme_shards: Path,
    output_dir: Path,
    tar_mode: str,
    pack_workers: int,
    seed: int,
) -> None:
    """Pack every split in *splits* into ``output_dir/<split>/`` via pack_split."""
    for split in splits:
        lustre_split_dir = output_dir / split
        lustre_split_dir.mkdir(parents=True, exist_ok=True)
        pack_split(split, nvme_shards, lustre_split_dir,
                   shards_per_split[split], tar_mode, pack_workers, seed)
def parse_split_csvs(values: list[str]) -> list[tuple[str, str]]:
    """Parse ['train:train.csv', 'val:val.csv', ...] → [('train', 'train.csv'), ...]

    Each entry is split on the first ':' only, so the path may itself contain
    colons. Name and path are whitespace-stripped.

    Raises:
        ValueError: if an entry has no ':' or an empty name or path.
    """
    result = []
    for v in values:
        name, sep, path = v.partition(":")
        name, path = name.strip(), path.strip()
        if not sep or not name or not path:
            raise ValueError(
                f"invalid --split-csvs entry {v!r}: expected 'name:path'"
            )
        result.append((name, path))
    return result
def main() -> None:
    """CLI entry point: scatter one batch of sqfs files, then pack tar shards.

    Phases: NVMe preflight estimate → load CSVs and save class_map.json →
    build the per-task lookup → pre-create shard dirs → scatter every sqfs in
    the batch → pack each split to the output directory.
    """
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--split-csvs", nargs="+", required=True,
                        help="Per-split CSV files as 'name:path' (e.g. train:train.csv)")
    parser.add_argument("--sqfs-paths", nargs="+", required=True,
                        help="Ordered sqfs paths for this batch")
    parser.add_argument("--sqfs-start-idx", type=int, required=True,
                        help="Global index of first sqfs (task_id = photo_id %% 10)")
    parser.add_argument("--images-per-shard", type=int, default=1000,
                        help="Images per shard (default: 1000). Row 0–(N-1) in the CSV → shard 0, "
                             "rows N–(2N-1) → shard 1, etc.")
    parser.add_argument("--nvme-dir", required=True,
                        help="NVMe scratch root ($SLURM_TMPDIR)")
    parser.add_argument("--output-dir", required=True,
                        help="Lustre output dir; split subdirs and class_map.json written here")
    parser.add_argument("--pack-workers", type=int, default=16,
                        help="Parallel workers for tar packing")
    parser.add_argument("--tar-mode", choices=["w", "a"], default="w",
                        help="'w' create new tars, 'a' append (idempotent on retry)")
    parser.add_argument("--taxon-id-column", default="inat_taxon_id",
                        help="CSV column holding the taxon ID (default: inat_taxon_id). "
                             "Required in every CSV. If 'class_id' is absent, taxon IDs are "
                             "sorted numerically and mapped to sequential class IDs.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for the within-shard shuffle")
    parser.add_argument("--no-nvme-copy", action="store_true",
                        help="Mount sqfs directly from source, skip copy to NVMe. "
                             "Use for smoke-testing scatter/pack logic without needing "
                             "full NVMe space.")
    parser.add_argument("--limit", type=int, default=0,
                        help="Stop scatter after N images per sqfs (0 = no limit). "
                             "Use with --no-nvme-copy for a fast end-to-end smoke test.")
    args = parser.parse_args()

    # Shard id is row_idx // images_per_shard; zero or negative cannot work.
    if args.images_per_shard < 1:
        parser.error("--images-per-shard must be >= 1")

    nvme_dir = Path(args.nvme_dir)
    nvme_mnt = nvme_dir / "mnt"
    nvme_shards = nvme_dir / "shards"
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    split_csvs = parse_split_csvs(args.split_csvs)
    split_names = [name for name, _ in split_csvs]
    task_ids = list(range(args.sqfs_start_idx,
                          args.sqfs_start_idx + len(args.sqfs_paths)))

    log("=== create_webdataset ===")
    log(f"Splits: {split_names}")
    log(f"Tasks: {task_ids} sqfs: {len(args.sqfs_paths)}")
    log(f"Output: {output_dir}")
    if args.no_nvme_copy:
        log("Mode: --no-nvme-copy (smoke test — sqfs mounted in place)")
    if args.limit:
        log(f"Mode: --limit {args.limit} images per sqfs (smoke test)")
    print()

    # ── NVMe preflight ────────────────────────────────────────────────────────
    # Peak NVMe usage occurs during scatter of the last sqfs in the batch:
    #   peak = batch_size × avg_scatter_per_sqfs + largest_sqfs_file
    # scatter_per_sqfs ≈ sqfs_size (images decompressed into shard dirs, ~3 files
    # per image but .cls/.json are tiny vs JPEG, so ≈ sqfs size)
    if not args.no_nvme_copy:
        sqfs_sizes_gb = [Path(p).stat().st_size / 1024**3 for p in args.sqfs_paths]
        n_sqfs = len(sqfs_sizes_gb)
        avg_scatter = sum(sqfs_sizes_gb) / n_sqfs
        peak_gb = n_sqfs * avg_scatter + max(sqfs_sizes_gb)
        nvme_free_gb = disk_free_tb(str(nvme_dir)) * 1024
        log(f"NVMe preflight: {n_sqfs} sqfs avg={avg_scatter:.1f} GB "
            f"peak estimate={peak_gb:.0f} GB ({n_sqfs}×{avg_scatter:.1f} + {max(sqfs_sizes_gb):.1f}) "
            f"NVMe free={nvme_free_gb:.0f} GB")
        if peak_gb > nvme_free_gb * 0.9:
            # This script has no --batch-size flag; batch size is the number of --sqfs-paths.
            log(f"WARNING: estimated peak ({peak_gb:.0f} GB) exceeds 90% of NVMe "
                f"({nvme_free_gb:.0f} GB) — pass fewer --sqfs-paths per batch "
                f"or request more --tmp")
        print()

    # ── Load CSVs ─────────────────────────────────────────────────────────────
    log("Loading split CSVs ...")
    dfs, class_map = load_split_csvs(split_csvs, taxon_id_column=args.taxon_id_column)
    save_class_map(class_map, output_dir)

    shards_per_split = {
        name: math.ceil(len(df) / args.images_per_shard)
        for name, df in dfs.items()
    }
    log(f"Shards per split (--images-per-shard={args.images_per_shard}): {shards_per_split}")

    structural = {"photo_id", "relative_local_path", "class_id", "task_id", "split", "row_idx"}
    all_cols = set().union(*(df.columns for df in dfs.values()))
    # Sorted so the key order of per-sample .json metadata is stable across runs.
    meta_columns = sorted(c for c in all_cols if c not in structural)

    # ── Build per-task lookup ─────────────────────────────────────────────────
    log(f"Building image lookup for tasks {task_ids} ...")
    lookup = build_lookup(dfs, task_ids, meta_columns, args.images_per_shard)
    del dfs  # free RAM
    print()

    # ── Pre-create shard dirs ─────────────────────────────────────────────────
    log("Pre-creating shard dirs on NVMe ...")
    nvme_mnt.mkdir(parents=True, exist_ok=True)
    for split, n_split_shards in shards_per_split.items():
        for s in range(n_split_shards):
            (nvme_shards / split / f"shard_{s:06d}").mkdir(parents=True, exist_ok=True)
    log(f" Done ({sum(shards_per_split.values())} dirs across {len(shards_per_split)} splits)")
    print()

    # ── Scatter phase ─────────────────────────────────────────────────────────
    total_written = total_skipped = 0
    for i, sqfs_path_str in enumerate(args.sqfs_paths):
        sqfs_path = Path(sqfs_path_str)
        sqfs_idx = args.sqfs_start_idx + i
        task_lkp = lookup.get(sqfs_idx, {})
        log(f"--- Scatter sqfs {sqfs_idx} ({sqfs_path.name}) "
            f"lookup: {len(task_lkp):,} entries ---")
        w, s = scatter_sqfs(sqfs_path, sqfs_idx, nvme_dir, nvme_mnt,
                            nvme_shards, task_lkp,
                            no_nvme_copy=args.no_nvme_copy, limit=args.limit)
        total_written += w
        total_skipped += s
        print()

    if total_skipped:
        log(f"WARNING: {total_skipped:,} images skipped (in sqfs but not in CSVs)")

    # ── Pack phase ────────────────────────────────────────────────────────────
    log(f"--- Pack → Lustre (mode='{args.tar_mode}') ---")
    pack_to_lustre(
        split_names, shards_per_split,
        nvme_shards, output_dir, args.tar_mode, args.pack_workers, args.seed,
    )

    elapsed = (time.perf_counter() - _t_start) / 3600
    log(f"\nBatch done: {total_written:,} images written "
        f"{total_skipped:,} skipped {elapsed:.2f}h elapsed")
# ── CSV loading ───────────────────────────────────────────────────────────────

def load_split_csvs(
    split_csvs: list[tuple[str, str]],
    taxon_id_column: str,
    images_per_shard: int,
) -> tuple[dict[str, pd.DataFrame], dict[int, int]]:
    """Load per-split CSVs and return (dfs, class_map).

    Every CSV must contain 'photo_id', 'relative_local_path' and
    *taxon_id_column*. Each frame gains 'task_id' (photo_id % 10), 'row_idx'
    (position in that CSV) and 'shard_id' (row_idx // images_per_shard).

    If 'class_id' is absent, taxon IDs from all splits are sorted and mapped
    to sequential class IDs. class_map is {class_id: taxon_id}.

    Raises:
        ValueError: if a required column is missing from any CSV.
    """
    dfs: dict[str, pd.DataFrame] = {}
    for name, path in split_csvs:
        df = pd.read_csv(path, dtype={"photo_id": "Int64"})
        missing = {"photo_id", "relative_local_path"} - set(df.columns)
        if missing:
            raise ValueError(f"{path}: missing required columns {missing}")
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if taxon_id_column not in df.columns:
            raise ValueError(
                f"{path}: required label column '{taxon_id_column}' not found "
                f"(available: {list(df.columns)})"
            )
        df["task_id"] = (df["photo_id"] % 10).astype(int)
        df["row_idx"] = range(len(df))
        df["shard_id"] = df["row_idx"] // images_per_shard
        dfs[name] = df
        # max() of an empty column is NaN; report 0 shards instead of "nan".
        n_shards = 0 if df.empty else int(df["shard_id"].max()) + 1
        log(f" {name}: {len(df):,} rows {n_shards} shards ({path})")

    all_df = pd.concat(dfs.values(), ignore_index=True)
    class_map: dict[int, int] = {}
    if "class_id" not in all_df.columns:
        taxon_ids = sorted(all_df[taxon_id_column].dropna().unique())
        id_map = {int(tid): idx for idx, tid in enumerate(taxon_ids)}
        for df in dfs.values():
            df["class_id"] = df[taxon_id_column].map(id_map).astype("Int64")
        log(f" Assigned class_ids for {len(taxon_ids):,} taxon IDs (column='{taxon_id_column}')")
        class_map = {idx: tid for tid, idx in id_map.items()}
    else:
        class_map = (
            all_df[["class_id", taxon_id_column]]
            .dropna()
            .drop_duplicates("class_id")
            .set_index("class_id")[taxon_id_column]
            .astype(int)
            .to_dict()
        )

    return dfs, class_map
def save_class_map(class_map: dict[int, int], output_dir: Path) -> None:
    """Write class_map as ``class_map.json`` in *output_dir*, unless it exists.

    An existing file is left untouched. Keys are written as strings, sorted.
    """
    dest = output_dir / "class_map.json"
    if dest.exists():
        log(f"class_map.json already exists — skipping ({dest})")
        return
    with open(dest, "w") as f:
        json.dump({str(k): v for k, v in sorted(class_map.items())}, f, indent=2)
    log(f"Saved class_map.json ({len(class_map):,} classes → {dest})")


# ── Build lookup ──────────────────────────────────────────────────────────────

def build_lookup(
    dfs: dict[str, pd.DataFrame],
    meta_columns: list[str],
) -> dict[str, dict]:
    """Build {relative_local_path: {split, shard_id, task_id, class_id, meta}}.

    'meta' holds the non-null values of *meta_columns* that exist in the
    split's frame. Rows are read by column position, not attribute name, so
    columns whose names are not valid Python identifiers (e.g. 'class',
    'common name') are kept instead of being silently dropped. If a path
    appears more than once, the last occurrence wins.
    """
    lookup: dict[str, dict] = {}
    for split, df in dfs.items():
        pos = {col: i for i, col in enumerate(df.columns)}
        rel_i = pos["relative_local_path"]
        shard_i = pos["shard_id"]
        task_i = pos["task_id"]
        class_i = pos["class_id"]
        meta_pos = [(col, pos[col]) for col in meta_columns if col in pos]

        # name=None yields plain tuples: no field renaming, and faster.
        for row in df.itertuples(index=False, name=None):
            meta = {col: row[i] for col, i in meta_pos if not pd.isna(row[i])}
            lookup[row[rel_i]] = {
                "split": split,
                "shard_id": int(row[shard_i]),
                "task_id": int(row[task_i]),
                "class_id": int(row[class_i]),
                "meta": meta,
            }
    log(f"Lookup built: {len(lookup):,} entries across {len(dfs)} splits")
    return lookup
# ── Walk image dirs and collect per-shard path lists ─────────────────────────

def collect_shards(
    images_dir: Path,
    lookup: dict[str, dict],
) -> tuple[dict[tuple[str, int], list[tuple[Path, str, dict]]], int]:
    """
    Walk task_*/ directories under images_dir.
    For each image found in the lookup, append (abs_path, rel_path, entry)
    to the appropriate (split, shard_id) bucket.

    An image counts as found only when it sits in the task directory its
    lookup entry names (task_<task_id>); a same-named file under another task
    directory is counted as skipped.

    Returns (shard_buckets, n_missing), where n_missing is the number of
    lookup entries for which no image was found.
    """
    shard_buckets: dict[tuple[str, int], list] = defaultdict(list)
    found = skipped = 0
    t0 = time.perf_counter()

    for task_dir in sorted(images_dir.iterdir()):
        if not task_dir.is_dir() or not task_dir.name.startswith("task_"):
            continue

        log(f" Walking {task_dir.name} ...")
        for img_path in task_dir.rglob("*"):
            if img_path.suffix.lower() not in IMAGE_EXTENSIONS:
                continue

            rel = str(img_path.relative_to(task_dir))
            entry = lookup.get(rel)
            if entry is None or task_dir.name != f"task_{entry['task_id']}":
                skipped += 1
                continue

            shard_buckets[(entry["split"], entry["shard_id"])].append((img_path, rel, entry))
            found += 1

            if found % 500_000 == 0:
                elapsed = max(time.perf_counter() - t0, 1e-9)
                log(f" {found:,} found {skipped:,} skipped {found/elapsed:.0f} img/s")

    missing = len(lookup) - found
    elapsed = time.perf_counter() - t0
    log(f"Walk complete: {found:,} found {skipped:,} not-in-csv "
        f"{missing:,} in-csv-not-found {elapsed:.0f}s")
    return dict(shard_buckets), missing


# ── Pack shards to tar ────────────────────────────────────────────────────────

def pack_shards(
    shard_buckets: dict[tuple[str, int], list],
    output_dir: Path,
    splits: list[str],
    pack_workers: int,
    seed: int,
) -> None:
    """Write every (split, shard_id) bucket to ``output_dir/<split>/<split>-<id>.tar``.

    Each sample contributes three members named by the md5 of
    ``"<task_id>:<rel_path>"``: the image, ``.cls`` (class id) and ``.json``
    (metadata). Samples are shuffled with a per-shard RNG seeded from
    (seed, split, shard_id) after sorting, so the member order does not depend
    on thread scheduling or directory walk order. The input lists are not
    modified.
    """
    # Not imported at module level in this file; import once, not per image.
    import hashlib
    import io

    for split in set(splits) | {s for s, _ in shard_buckets}:
        (output_dir / split).mkdir(parents=True, exist_ok=True)

    items = list(shard_buckets.items())
    log(f"Packing {len(items):,} shards ({pack_workers} workers) ...")
    t0 = time.perf_counter()
    total_bytes = total_images = 0

    def pack_one(item: tuple) -> tuple[int, int]:
        (split, shard_id), entries = item
        ordered = sorted(entries, key=lambda e: (e[2]["task_id"], e[1]))
        random.Random(f"{seed}:{split}:{shard_id}").shuffle(ordered)
        tar_path = output_dir / split / f"{split}-{shard_id:06d}.tar"
        nb = ni = 0
        with tarfile.open(tar_path, "w") as tf:
            for img_path, rel, entry in ordered:
                key = hashlib.md5(f"{entry['task_id']}:{rel}".encode()).hexdigest()
                img_bytes = img_path.read_bytes()
                members = [
                    (img_path.suffix.lower(), img_bytes),
                    (".cls", str(entry["class_id"]).encode()),
                    (".json", json.dumps({
                        "class_id": entry["class_id"],
                        "relative_local_path": rel,
                        "task_id": entry["task_id"],
                        **entry["meta"],
                    }).encode()),
                ]
                for suffix, data in members:
                    info = tarfile.TarInfo(name=f"{key}{suffix}")
                    info.size = len(data)
                    tf.addfile(info, io.BytesIO(data))

                nb += len(img_bytes)
                ni += 1
        return nb, ni

    with ThreadPoolExecutor(max_workers=pack_workers) as pool:
        for nb, ni in pool.map(pack_one, items):
            total_bytes += nb
            total_images += ni

    elapsed = max(time.perf_counter() - t0, 1e-9)
    log(f"Packed {total_images:,} images {total_bytes/1024**3:.1f} GB "
        f"{elapsed:.0f}s ({total_bytes/1024**2/elapsed:.0f} MB/s)")
# ── Main ──────────────────────────────────────────────────────────────────────

def parse_split_csvs(values: list[str]) -> list[tuple[str, str]]:
    """Parse ['train:train.csv', ...] → [('train', 'train.csv'), ...].

    Each entry is split on the first ':' only, so the path may itself contain
    colons. Name and path are whitespace-stripped.

    Raises:
        ValueError: if an entry has no ':' or an empty name or path.
    """
    pairs = []
    for value in values:
        name, sep, path = value.partition(":")
        name, path = name.strip(), path.strip()
        if not sep or not name or not path:
            raise ValueError(
                f"invalid --split-csvs entry {value!r}: expected 'name:path'"
            )
        pairs.append((name, path))
    return pairs
def main() -> None:
    """CLI entry point: load CSVs, walk the task directories, write tar shards."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--images-dir", required=True,
                        help="Parent dir containing task_0/ … task_9/ subdirectories")
    parser.add_argument("--split-csvs", nargs="+", required=True,
                        help="Per-split CSV files as 'name:path' (e.g. train:train.csv)")
    parser.add_argument("--output-dir", required=True,
                        help="Output directory; split subdirs and class_map.json written here")
    parser.add_argument("--images-per-shard", type=int, default=1000,
                        help="Images per shard, assigned by CSV row order (default: 1000)")
    parser.add_argument("--taxon-id-column", default="inat_taxon_id",
                        help="CSV column used as label (default: inat_taxon_id)")
    parser.add_argument("--pack-workers", type=int, default=16,
                        help="Parallel workers for tar packing (default: 16)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for within-shard shuffle (default: 42)")
    args = parser.parse_args()

    # Shard id is row_idx // images_per_shard; zero or negative cannot work.
    if args.images_per_shard < 1:
        parser.error("--images-per-shard must be >= 1")

    images_dir = Path(args.images_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    split_csvs = parse_split_csvs(args.split_csvs)
    splits = [name for name, _ in split_csvs]

    log("=== create_webdataset_from_dir ===")
    log(f"Images dir : {images_dir}")
    log(f"Splits : {splits}")
    log(f"Output : {output_dir}")
    print()

    log("Loading split CSVs ...")
    dfs, class_map = load_split_csvs(split_csvs, args.taxon_id_column, args.images_per_shard)
    save_class_map(class_map, output_dir)
    print()

    structural = {"photo_id", "relative_local_path", "class_id", "task_id", "shard_id", "row_idx"}
    all_cols = set().union(*(df.columns for df in dfs.values()))
    # Sorted so the key order of per-sample .json metadata is stable across runs.
    meta_columns = sorted(c for c in all_cols if c not in structural)

    log("Building image lookup ...")
    lookup = build_lookup(dfs, meta_columns)
    del dfs  # free RAM before the directory walk
    print()

    log("Collecting images from task directories ...")
    shard_buckets, n_missing = collect_shards(images_dir, lookup)
    if n_missing > 0:
        log(f"WARNING: {n_missing:,} images are in the CSVs but were not found on disk")
    print()

    log("Packing shards → output dir ...")
    pack_shards(shard_buckets, output_dir, splits, args.pack_workers, args.seed)

    elapsed = (time.perf_counter() - _t_start) / 3600
    log(f"\nDone in {elapsed:.2f}h")
== "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py new file mode 100644 index 0000000..a3fd88a --- /dev/null +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +""" +Download images from training_images BQ table to a local staging directory, +record results in training_images_downloads BQ table, then pack into SquashFS. + +Images are split across parallel jobs using MOD(photo_id, num_jobs) = task_id +so each job handles a balanced, non-overlapping subset of images. + +Resumable: already-attempted images are skipped by LEFT JOINing with +training_images_downloads. Re-running the same task_id is safe. + +Usage (single job): + python download_images.py \ + --staging-dir /localscratch/$USER/staging \ + --num-jobs 1 \ + --task-id 0 + +Usage (one task in a SLURM array): + python download_images.py \ + --staging-dir /localscratch/$USER/staging \ + --num-jobs 10 \ + --task-id $SLURM_ARRAY_TASK_ID + +After all array tasks finish, run job_bq_pack_squashfs.sh to merge +all staging directories into a single SquashFS archive. 
def download_and_verify(row: dict, staging_dir: Path) -> dict:
    """Download one image from absolute_url, verify with PIL, return result.

    *row* must provide 'dataset_source_uuid', 'absolute_url' and
    'relative_local_path'; the image is written to
    ``staging_dir / relative_local_path``.

    Returns a dict with the DOWNLOADS_SCHEMA fields. 'fetch_status' is
    'failed' (download error; any partial file is removed), 'corrupted'
    (downloaded but PIL could not decode it) or 'downloaded'.
    """
    url = row["absolute_url"]
    dest = staging_dir / row["relative_local_path"]
    dest.parent.mkdir(parents=True, exist_ok=True)

    result = {
        "dataset_source_uuid": row["dataset_source_uuid"],
        "fetch_status": None,
        "image_width": None,
        "image_height": None,
        "image_size": None,
        "corrupted": None,
    }

    # Download. The response is a context manager so the streamed connection
    # is always released back to the pool.
    try:
        with requests.get(url, timeout=30, stream=True) as resp:
            resp.raise_for_status()
            with open(dest, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    f.write(chunk)
    except Exception as e:  # per-image boundary: record the failure and move on
        print(f"Failed {url}: {e}", flush=True)
        # Don't leave a truncated file behind to be packed into the sqfs.
        try:
            dest.unlink(missing_ok=True)
        except OSError:
            pass
        result["fetch_status"] = "failed"
        return result

    # Verify with PIL. convert() forces a full decode; its result is discarded.
    try:
        with Image.open(dest) as img:
            img.convert("RGB")
            result["image_width"], result["image_height"] = img.size
        result["image_size"] = dest.stat().st_size
        result["corrupted"] = False
        result["fetch_status"] = "downloaded"
    except (PIL.UnidentifiedImageError, OSError, ValueError, SyntaxError) as e:
        # Image plugins can also raise ValueError/SyntaxError on malformed data.
        print(f"Corrupted {url}: {e}", flush=True)
        result["image_width"] = result["image_height"] = None
        result["corrupted"] = True
        result["fetch_status"] = "corrupted"

    return result
def ensure_downloads_table(client: bigquery.Client) -> None:
    """Create training_images_downloads table if it doesn't exist.

    Safe to call from several jobs at once: ``exists_ok=True`` makes the
    create a no-op when another job created the table between the existence
    check and the create call.
    """
    try:
        client.get_table(DOWNLOADS_TABLE)
    except Exception:
        # Any lookup failure falls through to create_table, which re-raises
        # genuine problems (auth, bad project) itself.
        table = bigquery.Table(DOWNLOADS_TABLE, schema=DOWNLOADS_SCHEMA)
        table.description = (
            "Download results for training_images. One row per download attempt. "
            "Appended to by parallel download jobs. Used to track fetch progress "
            "without DML updates on the base training_images table."
        )
        client.create_table(table, exists_ok=True)
        print(f"Created table {DOWNLOADS_TABLE}", flush=True)


def write_results_to_bq(client: bigquery.Client, results: list[dict]) -> None:
    """
    Append download results to training_images_downloads via batch load job.

    Uses load_table_from_dataframe with WRITE_APPEND and the explicit
    DOWNLOADS_SCHEMA, and blocks until the job finishes. An empty *results*
    list is a no-op: it would otherwise submit a load job for a DataFrame
    with no columns.
    """
    if not results:
        return
    df = pd.DataFrame(results)
    job_config = bigquery.LoadJobConfig(
        write_disposition=bigquery.WriteDisposition.WRITE_APPEND,
        schema=DOWNLOADS_SCHEMA,
    )
    job = client.load_table_from_dataframe(df, DOWNLOADS_TABLE, job_config=job_config)
    job.result()
+ """ + limit_clause = f"LIMIT {limit}" if limit else "" + if force_redownload: + query = f""" + SELECT + ti.dataset_source_uuid, + ti.absolute_url, + ti.relative_local_path + FROM `{TRAINING_TABLE}` ti + WHERE ti.fetch_status = 'pending' + AND MOD(ti.photo_id, {num_jobs}) = {task_id} + {limit_clause} + """ + else: + query = f""" + SELECT + ti.dataset_source_uuid, + ti.absolute_url, + ti.relative_local_path + FROM `{TRAINING_TABLE}` ti + LEFT JOIN `{DOWNLOADS_TABLE}` d + ON ti.dataset_source_uuid = d.dataset_source_uuid + WHERE ti.fetch_status = 'pending' + AND MOD(ti.photo_id, {num_jobs}) = {task_id} + AND d.dataset_source_uuid IS NULL + {limit_clause} + """ + rows = list(client.query(query).result()) + return [dict(r) for r in rows] + + +def pack_chunk_to_sqfs(staging_dir: Path, chunk_num: int, num_workers: int = 4) -> Path | None: + """Pack downloaded images in staging_dir into a per-chunk SquashFS file. + + Passes bucket dirs (000/, 001/, ...) directly to mksquashfs so paths inside + the archive are clean: 000/abc123.jpg — not staging_dir/000/abc123.jpg. + + Returns the path to the created .sqfs file, or None if staging_dir is empty. 
+ """ + bucket_dirs = sorted(d for d in staging_dir.iterdir() if d.is_dir()) + if not bucket_dirs: + print(f" No images in staging dir, skipping sqfs pack for chunk {chunk_num}", flush=True) + return None + + chunk_sqfs = staging_dir / f"chunk_{chunk_num:04d}.sqfs" + cmd = [ + "mksquashfs", + *[str(d) for d in bucket_dirs], + str(chunk_sqfs), + "-noappend", + "-no-xattrs", + "-comp", "zstd", + "-Xcompression-level", "3", + "-processors", str(num_workers), + ] + print(f" Packing {len(bucket_dirs)} bucket dirs → {chunk_sqfs.name}...", flush=True) + subprocess.run(cmd, check=True) + size_mb = chunk_sqfs.stat().st_size / (1024 ** 2) + print(f" Packed: {chunk_sqfs.name} ({size_mb:.1f} MB)", flush=True) + return chunk_sqfs + + +def clear_staging(staging_dir: Path) -> None: + """Remove all image files from staging dir; preserve .sqfs chunk files.""" + for f in staging_dir.rglob("*"): + if f.is_file() and f.suffix != ".sqfs": + f.unlink() + for d in sorted(staging_dir.rglob("*"), reverse=True): + if d.is_dir(): + try: + d.rmdir() # only removes empty dirs; bucket dirs with no images will be gone + except OSError: + pass + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--staging-dir", required=True, + help="Local directory to download images into") + parser.add_argument("--num-jobs", type=int, required=True, + help="Total number of parallel jobs (used for MOD split)") + parser.add_argument("--task-id", type=int, required=True, + help="This job's task ID (0 to num_jobs-1)") + parser.add_argument("--num-workers", type=int, default=32, + help="Parallel download workers") + parser.add_argument("--chunk-size", type=int, default=10000, + help="Images per chunk before writing to BQ") + parser.add_argument("--limit", type=int, default=None, + help="Cap total images queried (for small-scale tests)") + parser.add_argument("--force-redownload", action="store_true", + help="Re-download all images for this task, ignoring existing BQ records") + args = 
parser.parse_args() + + client = bigquery.Client(project=BQ_PROJECT) + staging_dir = Path(args.staging_dir) + staging_dir.mkdir(parents=True, exist_ok=True) + + ensure_downloads_table(client) + + print(f"Task {args.task_id}/{args.num_jobs}: querying pending rows " + f"(force_redownload={args.force_redownload})...", flush=True) + rows = get_pending_rows(client, args.num_jobs, args.task_id, + limit=args.limit, force_redownload=args.force_redownload) + print(f"Task {args.task_id}/{args.num_jobs}: {len(rows):,} pending images", flush=True) + + total_downloaded = 0 + total_failed = 0 + total_corrupted = 0 + + for chunk_start in range(0, len(rows), args.chunk_size): + chunk = rows[chunk_start : chunk_start + args.chunk_size] + chunk_num = chunk_start // args.chunk_size + 1 + total_chunks = (len(rows) + args.chunk_size - 1) // args.chunk_size + print(f"\n[Task {args.task_id}] Chunk {chunk_num}/{total_chunks} " + f"({len(chunk)} images)...", flush=True) + + # Download in parallel + results = [] + with ThreadPoolExecutor(max_workers=args.num_workers) as executor: + futures = { + executor.submit(download_and_verify, row, staging_dir): row + for row in chunk + } + for i, future in enumerate(as_completed(futures)): + results.append(future.result()) + if (i + 1) % 1000 == 0: + print(f" {i+1}/{len(chunk)} done", flush=True) + + # Count results + for r in results: + if r["fetch_status"] == "downloaded": + total_downloaded += 1 + elif r["fetch_status"] == "failed": + total_failed += 1 + elif r["fetch_status"] == "corrupted": + total_corrupted += 1 + + print(f" downloaded={total_downloaded} failed={total_failed} " + f"corrupted={total_corrupted}", flush=True) + + # Write results to BQ (free batch load, no DML) + write_results_to_bq(client, results) + print(f" Results written to BQ", flush=True) + + # Pack images into a per-chunk sqfs, then delete raw files. 
+ # This keeps peak inode usage at ~chunk_size per task (well under quota) + # rather than accumulating all images on disk until the pack job runs. + pack_chunk_to_sqfs(staging_dir, chunk_num, num_workers=4) + clear_staging(staging_dir) + print(f" Staging cleared (chunk sqfs kept)", flush=True) + + print(f"\n[Task {args.task_id}] Done. " + f"downloaded={total_downloaded} failed={total_failed} " + f"corrupted={total_corrupted}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/queries/global_max2000img.sql b/src/dataset_tools/bq_squashfs/queries/global_max2000img.sql new file mode 100644 index 0000000..346b598 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/queries/global_max2000img.sql @@ -0,0 +1,31 @@ +-- Global dataset: all downloaded images, capped at 2000 images per species. +-- Species with fewer than 2000 images are included in full. +-- Expected: ~12,313 species, ~2.7M images. +WITH ranked AS ( + SELECT + ti.photo_id, + ti.gbif_id, + ti.relative_local_path, + ti.dataset_source_uuid, + ti.inat_taxon_id, + tx.species_name, + tx.family, + ROW_NUMBER() OVER ( + PARTITION BY tx.species_name + ORDER BY ti.photo_id + ) AS rn + FROM `leps-ai.global_butterflies_2604.training_images` ti + JOIN `leps-ai.global_butterflies_2604.inat_taxa` tx USING (inat_taxon_id) + WHERE ti.fetch_status = 'downloaded' + AND tx.species_name IS NOT NULL +) +SELECT + photo_id, + gbif_id, + relative_local_path, + dataset_source_uuid, + inat_taxon_id, + species_name, + family +FROM ranked +WHERE rn <= 2000 diff --git a/src/dataset_tools/bq_squashfs/queries/global_min25occ.sql b/src/dataset_tools/bq_squashfs/queries/global_min25occ.sql new file mode 100644 index 0000000..a495752 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/queries/global_min25occ.sql @@ -0,0 +1,23 @@ +-- Global dataset: all downloaded images for species with >= 25 distinct GBIF occurrences. +-- Expected: ~4,704 species, ~10.6M images. 
+WITH qualifying_species AS ( + SELECT tx.species_name + FROM `leps-ai.global_butterflies_2604.training_images` ti + JOIN `leps-ai.global_butterflies_2604.inat_taxa` tx USING (inat_taxon_id) + WHERE ti.fetch_status = 'downloaded' + AND tx.species_name IS NOT NULL + GROUP BY tx.species_name + HAVING COUNT(DISTINCT ti.gbif_id) >= 25 +) +SELECT + ti.photo_id, + ti.gbif_id, + ti.relative_local_path, + ti.dataset_source_uuid, + ti.inat_taxon_id, + tx.species_name, + tx.family +FROM `leps-ai.global_butterflies_2604.training_images` ti +JOIN `leps-ai.global_butterflies_2604.inat_taxa` tx USING (inat_taxon_id) +WHERE ti.fetch_status = 'downloaded' + AND tx.species_name IN (SELECT species_name FROM qualifying_species) diff --git a/src/dataset_tools/bq_squashfs/split.py b/src/dataset_tools/bq_squashfs/split.py new file mode 100644 index 0000000..2938cc4 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/split.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Split a BQ-exported CSV into train/val/test CSV files. + +Uses stratified splitting (sklearn) on a category column so every species +is proportionally represented in all three splits. When --split-by-occurrence +is set, images that share the same gbif_id (same field observation) +are kept together in one split to avoid data leakage. 
+ +Required CSV columns: + photo_id — integer (iNat photo ID, image-level) + relative_local_path — path of the image within its sqfs + gbif_id — GBIF occurrence ID (used by --split-by-occurrence) + +Optional (passed through unchanged): + species_name, inat_taxon_id, family, gbif_id, and any other columns + +Usage: + python split.py \\ + --csv bq_export.csv \\ + --output-dir /project/.../data/splits/ \\ + --category-key species_name \\ + --val-frac 0.1 \\ + --test-frac 0.1 \\ + --split-by-occurrence \\ + --max-instances 1000 \\ + --min-instances 5 \\ + --seed 42 +""" + +import argparse +import math +import random +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +REQUIRED_COLUMNS = {"photo_id", "relative_local_path", "gbif_id"} + + +def _set_random_seeds(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + + +def _create_test_split( + cat_counts: pd.Series, + category_key: str, + metadata: pd.DataFrame, + split_by_occurrence: bool, + split_metadata: pd.DataFrame, + test_frac: float, +) -> pd.DataFrame: + min_instances = math.ceil(1 / test_frac) + test_categories = list(cat_counts[cat_counts >= min_instances].keys()) + selected = split_metadata[split_metadata[category_key].isin(test_categories)] + _, selected, _, _ = train_test_split( + selected, selected[[category_key]], stratify=selected[[category_key]], + test_size=test_frac, + ) + if split_by_occurrence: + return metadata[metadata["gbif_id"].isin( + selected["gbif_id"].unique() + )].copy() + return selected.copy() + + +def _create_val_split( + cat_counts: pd.Series, + category_key: str, + metadata: pd.DataFrame, + split_by_occurrence: bool, + split_metadata: pd.DataFrame, + test_frac: float, + test_set: pd.DataFrame, + val_frac: float, +) -> pd.DataFrame: + min_instances = math.ceil(1 / val_frac) + val_categories = list(cat_counts[cat_counts >= min_instances].keys()) + selected = split_metadata[ + 
~split_metadata["relative_local_path"].isin(test_set["relative_local_path"].unique()) + ] + selected = selected[selected[category_key].isin(val_categories)] + adjusted_val_frac = val_frac / (1 - test_frac) + _, selected, _, _ = train_test_split( + selected, selected[[category_key]], stratify=selected[[category_key]], + test_size=adjusted_val_frac, + ) + if split_by_occurrence: + return metadata[metadata["gbif_id"].isin( + selected["gbif_id"].unique() + )].copy() + return selected.copy() + + +def _create_train_split( + metadata: pd.DataFrame, + test_set: pd.DataFrame, + val_set: pd.DataFrame, +) -> pd.DataFrame: + exclude = set(test_set["relative_local_path"].unique()) | set(val_set["relative_local_path"].unique()) + return metadata[~metadata["relative_local_path"].isin(exclude)].copy() + + +def _subsample(dataset: pd.DataFrame, max_instances: int, category_key: str) -> pd.DataFrame: + counts = dataset[category_key].value_counts() + over_limit = list(counts[counts > max_instances].keys()) + under_limit = dataset[~dataset[category_key].isin(over_limit)].copy() + capped = [ + dataset[dataset[category_key] == cat].sample(max_instances) + for cat in over_limit + ] + return pd.concat([under_limit] + capped, ignore_index=True) + + +def split_dataset( + dataset_csv: str, + output_dir: str, + test_frac: float, + val_frac: float, + split_by_occurrence: bool, + category_key: str, + max_instances: int, + min_instances: int, + seed: int, +) -> None: + _set_random_seeds(seed) + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + print(f"Loading {dataset_csv} ...") + metadata = pd.read_csv(dataset_csv, dtype={"photo_id": "Int64"}) + print(f" {len(metadata):,} rows columns: {list(metadata.columns)}") + + missing = REQUIRED_COLUMNS - set(metadata.columns) + if missing: + raise ValueError(f"CSV missing required columns: {missing}") + if category_key not in metadata.columns: + raise ValueError(f"Category key '{category_key}' not found in CSV columns: 
{list(metadata.columns)}") + + # When splitting by occurrence, deduplicate to one row per occurrence first, + # then pull all images for the selected occurrences back in. + if split_by_occurrence: + split_metadata = metadata.drop_duplicates(subset=["gbif_id"], keep="first").copy() + print(f" Split by occurrence: {len(split_metadata):,} unique occurrences") + else: + split_metadata = metadata.copy() + + cat_counts = split_metadata[category_key].value_counts() + print(f" {len(cat_counts):,} unique categories in '{category_key}'") + + print(f"Creating test split (frac={test_frac}) ...") + test_set = _create_test_split( + cat_counts, category_key, metadata, split_by_occurrence, split_metadata, test_frac + ) + + print(f"Creating val split (frac={val_frac}) ...") + val_set = _create_val_split( + cat_counts, category_key, metadata, split_by_occurrence, split_metadata, + test_frac, test_set, val_frac, + ) + + print("Creating train split ...") + train_set = _create_train_split(metadata, test_set, val_set) + + if max_instances > 0: + print(f"Capping at max_instances={max_instances} per category ...") + train_set = _subsample(train_set, max_instances, category_key) + val_set = _subsample(val_set, int(max_instances * val_frac), category_key) + test_set = _subsample(test_set, int(max_instances * test_frac), category_key) + + if min_instances > 0: + print(f"Filtering to min_instances={min_instances} per category ...") + cat_counts_train = train_set[category_key].value_counts() + keep = list(cat_counts_train[cat_counts_train >= min_instances].keys()) + train_set = train_set[train_set[category_key].isin(keep)].copy() + val_set = val_set[val_set[category_key].isin(keep)].copy() + test_set = test_set[test_set[category_key].isin(keep)].copy() + print(f" Kept {len(keep):,} categories with >= {min_instances} train images") + + splits = {"train": train_set, "val": val_set, "test": test_set} + print("\n=== Split summary ===") + for name, df in splits.items(): + n_cats = 
df[category_key].nunique() + out = output_path / f"{name}.csv" + df.to_csv(out, index=False) + print(f" {name}: {len(df):,} images {n_cats:,} categories → {out}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--csv", required=True, + help="BQ-exported CSV file") + parser.add_argument("--output-dir", required=True, + help="Directory to write train.csv / val.csv / test.csv") + parser.add_argument("--category-key", default="species_name", + help="Column used for stratified splitting (default: species_name)") + parser.add_argument("--val-frac", type=float, default=0.1, + help="Fraction of data for validation (default: 0.1)") + parser.add_argument("--test-frac", type=float, default=0.1, + help="Fraction of data for test (default: 0.1)") + parser.add_argument("--split-by-occurrence", action="store_true", + help="Keep images from the same gbif_id (GBIF occurrence) in one split") + parser.add_argument("--max-instances", type=int, default=1000, + help="Max images per category per split (0 = no cap, default: 1000)") + parser.add_argument("--min-instances", type=int, default=0, + help="Min train images per category; drop categories below this (default: 0)") + parser.add_argument("--seed", type=int, default=42, + help="Random seed (default: 42)") + args = parser.parse_args() + + split_dataset( + dataset_csv=args.csv, + output_dir=args.output_dir, + test_frac=args.test_frac, + val_frac=args.val_frac, + split_by_occurrence=args.split_by_occurrence, + category_key=args.category_key, + max_instances=args.max_instances, + min_instances=args.min_instances, + seed=args.seed, + ) + + +if __name__ == "__main__": + main() From df9824cb63d2a533cb2c5594212ca357d620fe3d Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 11:00:20 -0700 Subject: [PATCH 02/26] chore: ignore image data dirs, logs, stats, and scratch outputs --- .gitignore | 15 
+++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.gitignore b/.gitignore index 534c4ed..4aae6f3 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,18 @@ cython_debug/ # Weights and Biases files wandb/ + +# Raw image datasets (too many inodes for home filesystem) +data/vermont_butterflies/ +data/vermont_species/ + +# Job output logs and stats +stats/ +logs/ + +# Scratch/intermediate data outputs +data/bq_pipeline/ +data/predictions/ +data/pipeline_test/ +data/split_test/ +data/wds_test_splits/ From bb68b8a9ad41e214c742943b70298bf440e6fa59 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 16:40:59 -0700 Subject: [PATCH 03/26] feat(download): retry backoff, session pooling, inline MERGE, mksquashfs error handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Scale issues fixed from real 10M-image run logs: - Retry with exponential backoff + jitter (fixes Errno 16, 429, 503, timeouts) - Per-thread requests.Session to reduce socket churn - BQ write retry (3 attempts, 30s backoff) — previously silent loss on failure - mksquashfs failures now raise RuntimeError so SLURM marks task failed cleanly - Chunk accumulation warning when >20 chunk sqfs files build up in staging New behaviour: - Inline MERGE into training_images after each chunk write via temp table — no separate update job needed; status is current by end of each chunk - --table-prefix flag for testing against test_training_images tables - Per-chunk throughput logging (img/s) to detect throttling Documentation: - Mid-chunk restart behaviour documented in module docstring (re-download cost, no data loss) Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/download_images.py | 450 ++++++++++++------ 1 file changed, 308 insertions(+), 142 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index a3fd88a..8d51be2 100644 --- 
a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """ Download images from training_images BQ table to a local staging directory, -record results in training_images_downloads BQ table, then pack into SquashFS. +record results in training_images_downloads, merge status back into +training_images, then pack images into SquashFS chunks. Images are split across parallel jobs using MOD(photo_id, num_jobs) = task_id so each job handles a balanced, non-overlapping subset of images. @@ -9,25 +10,33 @@ Resumable: already-attempted images are skipped by LEFT JOINing with training_images_downloads. Re-running the same task_id is safe. -Usage (single job): - python download_images.py \ - --staging-dir /localscratch/$USER/staging \ - --num-jobs 1 \ - --task-id 0 - -Usage (one task in a SLURM array): - python download_images.py \ - --staging-dir /localscratch/$USER/staging \ - --num-jobs 10 \ - --task-id $SLURM_ARRAY_TASK_ID - -After all array tasks finish, run job_bq_pack_squashfs.sh to merge -all staging directories into a single SquashFS archive. +NOTE — mid-chunk restart behaviour: if the job dies after downloading a chunk +but before the BQ write, those images are on disk but unrecorded. On resume +the LEFT JOIN will re-queue them and they will be re-downloaded. No data is +lost but ~chunk_size images are downloaded twice. This is acceptable given +the low probability and low cost of a single chunk redo. + +Usage (single job / test): + python download_images.py \\ + --staging-dir /scratch/$USER/staging \\ + --num-jobs 1 --task-id 0 \\ + --limit 50 --table-prefix test_ + +Usage (SLURM array): + python download_images.py \\ + --staging-dir /scratch/$USER/staging \\ + --num-jobs 10 --task-id $SLURM_ARRAY_TASK_ID + +After all array tasks finish, run job_bq_pack_per_task.sh to merge +chunk sqfs files into the final task_N.sqfs archives. 
""" import argparse -import os +import random import subprocess +import threading +import time +import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path @@ -41,144 +50,259 @@ BQ_PROJECT = "leps-ai" BQ_DATASET = "global_butterflies_2604" -TRAINING_TABLE = f"{BQ_PROJECT}.{BQ_DATASET}.training_images" -DOWNLOADS_TABLE = f"{BQ_PROJECT}.{BQ_DATASET}.training_images_downloads" + +# Retry config for HTTP downloads +_RETRY_STATUSES = {429, 500, 502, 503, 504} +_MAX_RETRIES = 5 +_BACKOFF_BASE = 2.0 # seconds +_BACKOFF_MAX = 60.0 # seconds cap + +# Warn if this many chunk sqfs files accumulate (pack job falling behind) +_CHUNK_ACCUMULATION_WARN = 20 DOWNLOADS_SCHEMA = [ bigquery.SchemaField("dataset_source_uuid", "STRING"), - bigquery.SchemaField("fetch_status", "STRING"), - bigquery.SchemaField("image_width", "INTEGER"), - bigquery.SchemaField("image_height", "INTEGER"), - bigquery.SchemaField("image_size", "INTEGER"), - bigquery.SchemaField("corrupted", "BOOLEAN"), + bigquery.SchemaField("fetch_status", "STRING"), + bigquery.SchemaField("image_width", "INTEGER"), + bigquery.SchemaField("image_height", "INTEGER"), + bigquery.SchemaField("image_size", "INTEGER"), + bigquery.SchemaField("corrupted", "BOOLEAN"), ] +# Thread-local storage for per-thread requests sessions +_thread_local = threading.local() + + +def _get_session() -> requests.Session: + """Return a per-thread requests.Session with a single keep-alive connection.""" + if not hasattr(_thread_local, "session"): + s = requests.Session() + adapter = requests.adapters.HTTPAdapter( + pool_connections=1, + pool_maxsize=1, + max_retries=0, # retries handled manually below + ) + s.mount("https://", adapter) + s.mount("http://", adapter) + _thread_local.session = s + return _thread_local.session + + +def _fetch_with_retry(url: str, dest: Path) -> None: + """Download url → dest with exponential backoff + jitter. 
+ + Retries on rate-limit (429), transient server errors (5xx), + connection errors (including Errno 16 — too many open sockets), + and timeouts. Raises on permanent client errors (4xx except 429) + or after exhausting all retries. + """ + session = _get_session() + for attempt in range(_MAX_RETRIES + 1): + try: + resp = session.get(url, timeout=30, stream=True) + if resp.status_code in _RETRY_STATUSES and attempt < _MAX_RETRIES: + delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.25) + print(f" HTTP {resp.status_code} {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s", flush=True) + time.sleep(delay) + continue + resp.raise_for_status() + with open(dest, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + return + except requests.exceptions.ConnectionError as e: + if attempt < _MAX_RETRIES: + delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.25) + print(f" ConnectionError {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s: {e}", flush=True) + time.sleep(delay) + else: + raise + except requests.exceptions.Timeout: + if attempt < _MAX_RETRIES: + delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.25) + print(f" Timeout {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s", flush=True) + time.sleep(delay) + else: + raise + raise RuntimeError(f"Exhausted {_MAX_RETRIES} retries for {url}") + def download_and_verify(row: dict, staging_dir: Path) -> dict: - """Download one image from absolute_url, verify with PIL, return result.""" - url = row["absolute_url"] + """Download one image, verify with PIL, return result dict.""" + url = row["absolute_url"] rel_path = row["relative_local_path"] - dest = staging_dir / rel_path + dest = staging_dir / rel_path dest.parent.mkdir(parents=True, exist_ok=True) result = { "dataset_source_uuid": row["dataset_source_uuid"], 
"fetch_status": None, - "image_width": None, + "image_width": None, "image_height": None, - "image_size": None, - "corrupted": None, + "image_size": None, + "corrupted": None, } - # Download try: - resp = requests.get(url, timeout=30, stream=True) - resp.raise_for_status() - with open(dest, "wb") as f: - for chunk in resp.iter_content(chunk_size=8192): - f.write(chunk) + _fetch_with_retry(url, dest) except Exception as e: - print(f"Failed {url}: {e}", flush=True) + print(f" Failed {url}: {e}", flush=True) result["fetch_status"] = "failed" return result - # Verify with PIL try: with Image.open(dest) as img: img.convert("RGB") result["image_width"], result["image_height"] = img.size - result["image_size"] = dest.stat().st_size - result["corrupted"] = False + result["image_size"] = dest.stat().st_size + result["corrupted"] = False result["fetch_status"] = "downloaded" except (PIL.UnidentifiedImageError, OSError) as e: - print(f"Corrupted {url}: {e}", flush=True) - result["corrupted"] = True + print(f" Corrupted {url}: {e}", flush=True) + result["corrupted"] = True result["fetch_status"] = "corrupted" return result -def ensure_downloads_table(client: bigquery.Client) -> None: - """Create training_images_downloads table if it doesn't exist.""" +def ensure_downloads_table(client: bigquery.Client, downloads_table: str) -> None: + """Create the downloads table if it doesn't exist.""" try: - client.get_table(DOWNLOADS_TABLE) + client.get_table(downloads_table) except Exception: - table = bigquery.Table(DOWNLOADS_TABLE, schema=DOWNLOADS_SCHEMA) + table = bigquery.Table(downloads_table, schema=DOWNLOADS_SCHEMA) table.description = ( "Download results for training_images. One row per download attempt. " "Appended to by parallel download jobs. Used to track fetch progress " "without DML updates on the base training_images table." 
) client.create_table(table) - print(f"Created table {DOWNLOADS_TABLE}", flush=True) + print(f"Created table {downloads_table}", flush=True) -def write_results_to_bq(client: bigquery.Client, results: list[dict]) -> None: - """ - Append download results to training_images_downloads via batch load job. - Uses load_table_from_dataframe which does not require DML billing. - Multiple parallel jobs can safely append to the same table simultaneously. +def write_results_to_bq( + client: bigquery.Client, + results: list[dict], + downloads_table: str, + max_retries: int = 3, +) -> None: + """Append download results to the downloads table via batch load (free tier). + + Multiple parallel tasks can safely append simultaneously. + Retries up to max_retries times on transient BQ errors. """ df = pd.DataFrame(results) job_config = bigquery.LoadJobConfig( write_disposition=bigquery.WriteDisposition.WRITE_APPEND, schema=DOWNLOADS_SCHEMA, ) - job = client.load_table_from_dataframe(df, DOWNLOADS_TABLE, job_config=job_config) - job.result() + for attempt in range(max_retries): + try: + job = client.load_table_from_dataframe(df, downloads_table, job_config=job_config) + job.result() + return + except Exception as e: + if attempt < max_retries - 1: + delay = 30 * (attempt + 1) + print(f" BQ write failed (attempt {attempt+1}/{max_retries}): {e} " + f"— retrying in {delay}s", flush=True) + time.sleep(delay) + else: + raise + + +def merge_chunk_into_training_images( + client: bigquery.Client, + results: list[dict], + training_table: str, + downloads_table: str, +) -> int: + """MERGE this chunk's successful results directly into training_images. + + Uses a temp table containing only this chunk's rows so the MERGE scans + a small dataset rather than all of training_images_downloads. + Only updates rows that are still 'pending' — safe to run from parallel tasks. + Returns the number of rows updated. 
+ """ + successful = [r for r in results if r["fetch_status"] in ("downloaded", "corrupted")] + if not successful: + return 0 + + tmp_table = f"{BQ_PROJECT}.{BQ_DATASET}._dl_merge_tmp_{uuid.uuid4().hex[:8]}" + df = pd.DataFrame(successful) + job_config = bigquery.LoadJobConfig( + write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE, + schema=DOWNLOADS_SCHEMA, + ) + client.load_table_from_dataframe(df, tmp_table, job_config=job_config).result() + + try: + job = client.query(f""" + MERGE `{training_table}` T + USING `{tmp_table}` S + ON T.dataset_source_uuid = S.dataset_source_uuid + WHEN MATCHED AND T.fetch_status = 'pending' THEN UPDATE SET + T.fetch_status = S.fetch_status, + T.image_width = S.image_width, + T.image_height = S.image_height, + T.image_size = S.image_size, + T.corrupted = S.corrupted + """) + job.result() + return job.dml_stats.updated_row_count + finally: + client.delete_table(tmp_table, not_found_ok=True) def get_pending_rows( - client: bigquery.Client, num_jobs: int, task_id: int, - limit: int | None = None, force_redownload: bool = False + client: bigquery.Client, + training_table: str, + downloads_table: str, + num_jobs: int, + task_id: int, + limit: int | None = None, + force_redownload: bool = False, ) -> list[dict]: - """ - Query training_images for rows assigned to this task (MOD split), - excluding images already attempted in training_images_downloads. - Pass force_redownload=True to ignore existing download records (e.g. to - re-download images whose staging files were deleted). 
- """ + """Query pending images for this task, skipping already-attempted ones.""" limit_clause = f"LIMIT {limit}" if limit else "" if force_redownload: query = f""" - SELECT - ti.dataset_source_uuid, - ti.absolute_url, - ti.relative_local_path - FROM `{TRAINING_TABLE}` ti - WHERE ti.fetch_status = 'pending' - AND MOD(ti.photo_id, {num_jobs}) = {task_id} + SELECT dataset_source_uuid, absolute_url, relative_local_path + FROM `{training_table}` + WHERE fetch_status = 'pending' + AND MOD(photo_id, {num_jobs}) = {task_id} {limit_clause} """ else: query = f""" - SELECT - ti.dataset_source_uuid, - ti.absolute_url, - ti.relative_local_path - FROM `{TRAINING_TABLE}` ti - LEFT JOIN `{DOWNLOADS_TABLE}` d - ON ti.dataset_source_uuid = d.dataset_source_uuid + SELECT ti.dataset_source_uuid, ti.absolute_url, ti.relative_local_path + FROM `{training_table}` ti + LEFT JOIN `{downloads_table}` d + ON ti.dataset_source_uuid = d.dataset_source_uuid WHERE ti.fetch_status = 'pending' AND MOD(ti.photo_id, {num_jobs}) = {task_id} AND d.dataset_source_uuid IS NULL {limit_clause} """ - rows = list(client.query(query).result()) - return [dict(r) for r in rows] + return [dict(r) for r in client.query(query).result()] def pack_chunk_to_sqfs(staging_dir: Path, chunk_num: int, num_workers: int = 4) -> Path | None: - """Pack downloaded images in staging_dir into a per-chunk SquashFS file. + """Pack downloaded images into a per-chunk SquashFS file. - Passes bucket dirs (000/, 001/, ...) directly to mksquashfs so paths inside - the archive are clean: 000/abc123.jpg — not staging_dir/000/abc123.jpg. - - Returns the path to the created .sqfs file, or None if staging_dir is empty. + Uses bucket subdirs (000/, 001/, ...) directly so paths inside the archive + are clean: 000/abc123.jpg rather than staging_dir/000/abc123.jpg. + Raises RuntimeError if mksquashfs fails so the SLURM task is marked failed. 
""" bucket_dirs = sorted(d for d in staging_dir.iterdir() if d.is_dir()) if not bucket_dirs: - print(f" No images in staging dir, skipping sqfs pack for chunk {chunk_num}", flush=True) + print(f" No images in staging dir — skipping sqfs pack for chunk {chunk_num}", flush=True) return None chunk_sqfs = staging_dir / f"chunk_{chunk_num:04d}.sqfs" @@ -193,104 +317,146 @@ def pack_chunk_to_sqfs(staging_dir: Path, chunk_num: int, num_workers: int = 4) "-processors", str(num_workers), ] print(f" Packing {len(bucket_dirs)} bucket dirs → {chunk_sqfs.name}...", flush=True) - subprocess.run(cmd, check=True) + result = subprocess.run(cmd, check=False) + if result.returncode != 0: + raise RuntimeError( + f"mksquashfs failed with exit code {result.returncode} for chunk {chunk_num}. " + f"Staging dir preserved for inspection: {staging_dir}" + ) size_mb = chunk_sqfs.stat().st_size / (1024 ** 2) print(f" Packed: {chunk_sqfs.name} ({size_mb:.1f} MB)", flush=True) return chunk_sqfs def clear_staging(staging_dir: Path) -> None: - """Remove all image files from staging dir; preserve .sqfs chunk files.""" + """Remove all image files from staging dir, preserving .sqfs chunk files.""" for f in staging_dir.rglob("*"): if f.is_file() and f.suffix != ".sqfs": f.unlink() for d in sorted(staging_dir.rglob("*"), reverse=True): if d.is_dir(): try: - d.rmdir() # only removes empty dirs; bucket dirs with no images will be gone + d.rmdir() except OSError: pass +def warn_chunk_accumulation(staging_dir: Path) -> None: + """Warn if too many chunk sqfs files have built up in staging.""" + count = len(list(staging_dir.glob("chunk_*.sqfs"))) + if count >= _CHUNK_ACCUMULATION_WARN: + print( + f" WARNING: {count} chunk sqfs files in {staging_dir} — " + f"pack job may be falling behind or a previous run left chunks behind.", + flush=True, + ) + + def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--staging-dir", required=True, + parser = argparse.ArgumentParser(description=__doc__, + 
formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--staging-dir", required=True, help="Local directory to download images into") - parser.add_argument("--num-jobs", type=int, required=True, - help="Total number of parallel jobs (used for MOD split)") - parser.add_argument("--task-id", type=int, required=True, + parser.add_argument("--num-jobs", type=int, required=True, + help="Total number of parallel jobs (MOD split denominator)") + parser.add_argument("--task-id", type=int, required=True, help="This job's task ID (0 to num_jobs-1)") - parser.add_argument("--num-workers", type=int, default=32, - help="Parallel download workers") - parser.add_argument("--chunk-size", type=int, default=10000, - help="Images per chunk before writing to BQ") - parser.add_argument("--limit", type=int, default=None, - help="Cap total images queried (for small-scale tests)") + parser.add_argument("--num-workers", type=int, default=32, + help="Parallel download workers (default: 32)") + parser.add_argument("--chunk-size", type=int, default=10000, + help="Images per chunk before packing to sqfs (default: 10000)") + parser.add_argument("--limit", type=int, default=None, + help="Cap total images queried — for small-scale tests") parser.add_argument("--force-redownload", action="store_true", - help="Re-download all images for this task, ignoring existing BQ records") + help="Ignore existing download records and re-download all images") + parser.add_argument("--table-prefix", default="", + help="BQ table prefix for testing (e.g. 
'test_' uses " + "test_training_images and test_training_images_downloads)") args = parser.parse_args() - client = bigquery.Client(project=BQ_PROJECT) + training_table = f"{BQ_PROJECT}.{BQ_DATASET}.{args.table_prefix}training_images" + downloads_table = f"{BQ_PROJECT}.{BQ_DATASET}.{args.table_prefix}training_images_downloads" + + client = bigquery.Client(project=BQ_PROJECT) staging_dir = Path(args.staging_dir) staging_dir.mkdir(parents=True, exist_ok=True) - ensure_downloads_table(client) - - print(f"Task {args.task_id}/{args.num_jobs}: querying pending rows " - f"(force_redownload={args.force_redownload})...", flush=True) - rows = get_pending_rows(client, args.num_jobs, args.task_id, - limit=args.limit, force_redownload=args.force_redownload) - print(f"Task {args.task_id}/{args.num_jobs}: {len(rows):,} pending images", flush=True) + print(f"=== download_images task={args.task_id}/{args.num_jobs} ===", flush=True) + print(f"training table : {training_table}", flush=True) + print(f"downloads table : {downloads_table}", flush=True) + print(f"staging dir : {staging_dir}", flush=True) + print(f"workers : {args.num_workers} chunk_size={args.chunk_size}", flush=True) + print(flush=True) + + ensure_downloads_table(client, downloads_table) + warn_chunk_accumulation(staging_dir) + + print(f"Querying pending rows (force_redownload={args.force_redownload})...", flush=True) + rows = get_pending_rows( + client, training_table, downloads_table, + args.num_jobs, args.task_id, + limit=args.limit, force_redownload=args.force_redownload, + ) + print(f"{len(rows):,} pending images to download", flush=True) - total_downloaded = 0 - total_failed = 0 - total_corrupted = 0 + total_downloaded = total_failed = total_corrupted = 0 for chunk_start in range(0, len(rows), args.chunk_size): - chunk = rows[chunk_start : chunk_start + args.chunk_size] - chunk_num = chunk_start // args.chunk_size + 1 + chunk = rows[chunk_start : chunk_start + args.chunk_size] + chunk_num = chunk_start // 
args.chunk_size + 1 total_chunks = (len(rows) + args.chunk_size - 1) // args.chunk_size print(f"\n[Task {args.task_id}] Chunk {chunk_num}/{total_chunks} " - f"({len(chunk)} images)...", flush=True) + f"({len(chunk):,} images)...", flush=True) # Download in parallel - results = [] + results = [] + t0 = time.perf_counter() + n_ok = n_fail = n_corrupt = 0 + with ThreadPoolExecutor(max_workers=args.num_workers) as executor: - futures = { - executor.submit(download_and_verify, row, staging_dir): row - for row in chunk - } + futures = {executor.submit(download_and_verify, row, staging_dir): row + for row in chunk} for i, future in enumerate(as_completed(futures)): - results.append(future.result()) + r = future.result() + results.append(r) + if r["fetch_status"] == "downloaded": + n_ok += 1 + elif r["fetch_status"] == "failed": + n_fail += 1 + elif r["fetch_status"] == "corrupted": + n_corrupt += 1 if (i + 1) % 1000 == 0: - print(f" {i+1}/{len(chunk)} done", flush=True) - - # Count results - for r in results: - if r["fetch_status"] == "downloaded": - total_downloaded += 1 - elif r["fetch_status"] == "failed": - total_failed += 1 - elif r["fetch_status"] == "corrupted": - total_corrupted += 1 - - print(f" downloaded={total_downloaded} failed={total_failed} " - f"corrupted={total_corrupted}", flush=True) - - # Write results to BQ (free batch load, no DML) - write_results_to_bq(client, results) - print(f" Results written to BQ", flush=True) - - # Pack images into a per-chunk sqfs, then delete raw files. - # This keeps peak inode usage at ~chunk_size per task (well under quota) - # rather than accumulating all images on disk until the pack job runs. 
+ elapsed = time.perf_counter() - t0 + print(f" {i+1:,}/{len(chunk):,} " + f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,} " + f"({(i+1)/elapsed:.0f} img/s)", flush=True) + + elapsed = time.perf_counter() - t0 + total_downloaded += n_ok + total_failed += n_fail + total_corrupted += n_corrupt + print(f" Chunk done in {elapsed:.0f}s ({len(chunk)/elapsed:.0f} img/s) " + f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,}", flush=True) + + # Append to downloads table (batch load — free tier, parallel-safe) + write_results_to_bq(client, results, downloads_table) + print(f" Written to {downloads_table}", flush=True) + + # Inline MERGE into training_images so status is current without a separate job + n_updated = merge_chunk_into_training_images( + client, results, training_table, downloads_table + ) + print(f" Merged {n_updated:,} rows into {training_table}", flush=True) + + # Pack images into chunk sqfs then clear raw files to keep inode usage low pack_chunk_to_sqfs(staging_dir, chunk_num, num_workers=4) clear_staging(staging_dir) + warn_chunk_accumulation(staging_dir) print(f" Staging cleared (chunk sqfs kept)", flush=True) - print(f"\n[Task {args.task_id}] Done. " - f"downloaded={total_downloaded} failed={total_failed} " - f"corrupted={total_corrupted}", flush=True) + print(f"\n[Task {args.task_id}] Done. " + f"downloaded={total_downloaded:,} failed={total_failed:,} " + f"corrupted={total_corrupted:,}", flush=True) if __name__ == "__main__": From f60f004315e4f3b9919bca73ae385cd2df63f4f0 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 16:41:03 -0700 Subject: [PATCH 04/26] test: add BQ test table creation script for download_images.py Creates test_training_images (50 rows, fetch_status=pending) and test_training_images_downloads (empty) from production table samples. Use --table-prefix test_ with download_images.py to run against these. 
Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/create_test_tables.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 src/dataset_tools/bq_squashfs/create_test_tables.py diff --git a/src/dataset_tools/bq_squashfs/create_test_tables.py b/src/dataset_tools/bq_squashfs/create_test_tables.py new file mode 100644 index 0000000..74f9343 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/create_test_tables.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +""" +Create small BQ test tables for testing download_images.py without touching production. + +Creates: + test_training_images — 50 rows sampled from training_images, fetch_status='pending' + test_training_images_downloads — empty table, same schema as training_images_downloads + +Usage: + python create_test_tables.py + python create_test_tables.py --n-rows 100 +""" + +import argparse +from google.cloud import bigquery + +BQ_PROJECT = "leps-ai" +BQ_DATASET = "global_butterflies_2604" + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--n-rows", type=int, default=50, + help="Number of rows to sample from training_images (default: 50)") + args = parser.parse_args() + + client = bigquery.Client(project=BQ_PROJECT) + prefix = f"{BQ_PROJECT}.{BQ_DATASET}" + + # ── test_training_images ───────────────────────────────────────────────── + print(f"Creating {prefix}.test_training_images ({args.n_rows} rows)...") + client.query(f""" + CREATE OR REPLACE TABLE `{prefix}.test_training_images` AS + SELECT + photo_id, + gbif_id, + inat_taxon_id, + dataset_source_uuid, + absolute_url, + relative_local_path, + 'pending' AS fetch_status, + CAST(NULL AS INT64) AS image_width, + CAST(NULL AS INT64) AS image_height, + CAST(NULL AS INT64) AS image_size, + CAST(NULL AS BOOL) AS corrupted + FROM `{prefix}.training_images` + WHERE fetch_status = 'downloaded' + LIMIT {args.n_rows} + """).result() + + n = 
client.get_table(f"{prefix}.test_training_images").num_rows + print(f" Created: {n} rows, all fetch_status='pending'") + + # ── test_training_images_downloads ─────────────────────────────────────── + print(f"Creating {prefix}.test_training_images_downloads (empty)...") + client.query(f""" + CREATE OR REPLACE TABLE `{prefix}.test_training_images_downloads` + ( + dataset_source_uuid STRING, + fetch_status STRING, + image_width INT64, + image_height INT64, + image_size INT64, + corrupted BOOL + ) + """).result() + print(" Created: 0 rows") + + print("\nDone. Run test with:") + print(f" python download_images.py \\") + print(f" --staging-dir /scratch/$USER/test_download \\") + print(f" --num-jobs 1 --task-id 0 \\") + print(f" --num-workers 8 --chunk-size {args.n_rows} \\") + print(f" --limit {args.n_rows} --table-prefix test_") + + +if __name__ == "__main__": + main() From 06751d7b9b9ff1b0743e64319721c7504beda52c Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 16:44:29 -0700 Subject: [PATCH 05/26] test(download): failure scenario tests for download_images.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 22 tests covering all failure paths identified from real 10M-image run logs: Retry logic (_fetch_with_retry): - 429 rate limit → retries then succeeds - 503 server error → retries then succeeds - ConnectionError Errno 16 (too many sockets) → retries then succeeds - Timeout → retries then succeeds - 404 → raises immediately, no retry - Exhausted retries (connection error / timeout) → raises Download + verify: - Valid JPEG → downloaded, dimensions populated - Network failure → fetch_status=failed - Corrupted/truncated image → fetch_status=corrupted BQ write retry: - Transient BQ error → retries then succeeds - All retries exhausted → raises SquashFS packing: - mksquashfs non-zero exit → RuntimeError (SLURM marks task failed) - Empty staging dir → returns None without calling mksquashfs Inline MERGE: - Only 
failed results → skips merge entirely - Downloaded/corrupted results → load temp table + MERGE + cleanup - MERGE failure → temp table still deleted (finally block) Accumulation warning: - Below threshold → no warning - At or above threshold → prints WARNING Co-Authored-By: Claude Sonnet 4.6 --- tests/dataset_tools/test_download_images.py | 329 ++++++++++++++++++++ 1 file changed, 329 insertions(+) create mode 100644 tests/dataset_tools/test_download_images.py diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py new file mode 100644 index 0000000..e695ed6 --- /dev/null +++ b/tests/dataset_tools/test_download_images.py @@ -0,0 +1,329 @@ +""" +Failure scenario tests for download_images.py. + +All tests use mocks — no network, no BQ, no filesystem writes beyond tmp_path. +Run with: pytest tests/dataset_tools/test_download_images.py -v +""" + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, call, patch + +import pytest +import requests + +import src.dataset_tools.bq_squashfs.download_images as di + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def make_response(status_code: int, content: bytes = b"\xff\xd8\xff\xe0JFIF"): + """Build a minimal mock HTTP response.""" + resp = MagicMock() + resp.status_code = status_code + resp.iter_content.return_value = [content] + if status_code >= 400: + resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + f"HTTP {status_code}" + ) + else: + resp.raise_for_status.return_value = None + return resp + + +def make_mock_session(*responses): + """Return a mock session whose .get() yields responses in order.""" + session = MagicMock() + session.get.side_effect = list(responses) + return session + + +# ── _fetch_with_retry ───────────────────────────────────────────────────────── + +class TestFetchWithRetry: + + def test_success_on_first_attempt(self, tmp_path): + dest = tmp_path / "img.jpg" + session = 
make_mock_session(make_response(200, b"IMAGE")) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert dest.read_bytes() == b"IMAGE" + assert session.get.call_count == 1 + + def test_429_retries_then_succeeds(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + make_response(429), + make_response(429), + make_response(200, b"IMAGE"), + ) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 3 + assert dest.read_bytes() == b"IMAGE" + + def test_503_retries_then_succeeds(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + make_response(503), + make_response(200, b"IMAGE"), + ) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 2 + + def test_connection_error_errno16_retries(self, tmp_path): + """Errno 16 (device/resource busy — too many sockets) retries.""" + dest = tmp_path / "img.jpg" + errno16 = requests.exceptions.ConnectionError("[Errno 16] Device or resource busy") + session = make_mock_session(errno16, errno16, make_response(200, b"IMAGE")) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 3 + + def test_timeout_retries_then_succeeds(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + requests.exceptions.Timeout(), + make_response(200, b"IMAGE"), + ) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 2 + + def test_404_raises_immediately_no_retry(self, tmp_path): + 
"""404 is not in RETRY_STATUSES — raises without retrying.""" + dest = tmp_path / "img.jpg" + session = make_mock_session(make_response(404)) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"), \ + pytest.raises(requests.exceptions.HTTPError): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 1 + + def test_exhausted_retries_on_connection_error_raises(self, tmp_path): + """After MAX_RETRIES connection errors, raises the last one.""" + dest = tmp_path / "img.jpg" + err = requests.exceptions.ConnectionError("connection refused") + session = make_mock_session(*([err] * (di._MAX_RETRIES + 1))) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"), \ + pytest.raises(requests.exceptions.ConnectionError): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == di._MAX_RETRIES + 1 + + def test_exhausted_retries_on_timeout_raises(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session(*([requests.exceptions.Timeout()] * (di._MAX_RETRIES + 1))) + with patch.object(di, "_get_session", return_value=session), \ + patch("time.sleep"), \ + pytest.raises(requests.exceptions.Timeout): + di._fetch_with_retry("http://example.com/img.jpg", dest) + + +# ── download_and_verify ─────────────────────────────────────────────────────── + +class TestDownloadAndVerify: + + ROW = { + "dataset_source_uuid": "uuid-001", + "absolute_url": "http://example.com/img.jpg", + "relative_local_path": "000/img.jpg", + } + + def test_success(self, tmp_path): + """Valid JPEG → fetch_status=downloaded, dimensions populated.""" + from PIL import Image + import io + buf = io.BytesIO() + Image.new("RGB", (64, 48)).save(buf, format="JPEG") + jpeg_bytes = buf.getvalue() + + with patch.object(di, "_fetch_with_retry") as mock_fetch: + def write_file(url, dest): + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(jpeg_bytes) + 
mock_fetch.side_effect = write_file + + result = di.download_and_verify(self.ROW, tmp_path) + + assert result["fetch_status"] == "downloaded" + assert result["image_width"] == 64 + assert result["image_height"] == 48 + assert result["corrupted"] is False + assert result["image_size"] > 0 + + def test_network_failure_recorded_as_failed(self, tmp_path): + """Network error → fetch_status=failed, no image on disk.""" + with patch.object(di, "_fetch_with_retry", + side_effect=Exception("connection refused")): + result = di.download_and_verify(self.ROW, tmp_path) + + assert result["fetch_status"] == "failed" + assert result["image_width"] is None + + def test_corrupted_image_recorded_as_corrupted(self, tmp_path): + """Truncated/invalid image bytes → fetch_status=corrupted.""" + with patch.object(di, "_fetch_with_retry") as mock_fetch: + def write_garbage(url, dest): + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(b"not an image at all") + mock_fetch.side_effect = write_garbage + + result = di.download_and_verify(self.ROW, tmp_path) + + assert result["fetch_status"] == "corrupted" + assert result["corrupted"] is True + assert result["image_width"] is None + + +# ── write_results_to_bq ─────────────────────────────────────────────────────── + +class TestWriteResultsToBq: + + RESULTS = [ + {"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 100, "image_height": 80, "image_size": 5000, "corrupted": False}, + ] + + def test_success(self): + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + di.write_results_to_bq(client, self.RESULTS, "test_table") + assert client.load_table_from_dataframe.call_count == 1 + + def test_retries_on_transient_error(self): + """First write fails, second succeeds — should not raise.""" + client = MagicMock() + client.load_table_from_dataframe.side_effect = [ + Exception("BQ transient error"), + MagicMock(result=MagicMock(return_value=None)), + ] + with 
patch("time.sleep"): + di.write_results_to_bq(client, self.RESULTS, "test_table", max_retries=2) + assert client.load_table_from_dataframe.call_count == 2 + + def test_exhausted_retries_raises(self): + """All retries fail → raises the last exception.""" + client = MagicMock() + client.load_table_from_dataframe.side_effect = Exception("BQ down") + with patch("time.sleep"), pytest.raises(Exception, match="BQ down"): + di.write_results_to_bq(client, self.RESULTS, "test_table", max_retries=2) + assert client.load_table_from_dataframe.call_count == 2 + + +# ── pack_chunk_to_sqfs ──────────────────────────────────────────────────────── + +class TestPackChunkToSqfs: + + def test_success(self, tmp_path): + staging = tmp_path / "staging" + (staging / "000").mkdir(parents=True) + (staging / "000" / "img.jpg").write_bytes(b"x") + fake_sqfs = staging / "chunk_0001.sqfs" + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + # Simulate mksquashfs creating the output file + fake_sqfs.write_bytes(b"sqfs") + result = di.pack_chunk_to_sqfs(staging, chunk_num=1) + + assert result == fake_sqfs + assert mock_run.call_count == 1 + + def test_mksquashfs_failure_raises_runtime_error(self, tmp_path): + """Non-zero mksquashfs exit → RuntimeError so SLURM marks task failed.""" + staging = tmp_path / "staging" + (staging / "000").mkdir(parents=True) + (staging / "000" / "img.jpg").write_bytes(b"x") + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=1) + with pytest.raises(RuntimeError, match="mksquashfs failed"): + di.pack_chunk_to_sqfs(staging, chunk_num=1) + + def test_empty_staging_returns_none(self, tmp_path): + """No images in staging → returns None without calling mksquashfs.""" + staging = tmp_path / "staging" + staging.mkdir() + with patch("subprocess.run") as mock_run: + result = di.pack_chunk_to_sqfs(staging, chunk_num=1) + assert result is None + mock_run.assert_not_called() + + +# ── 
merge_chunk_into_training_images ───────────────────────────────────────── + +class TestMergeChunkIntoTrainingImages: + + def test_empty_results_skips_merge(self): + """No successful results → no BQ calls.""" + client = MagicMock() + results = [{"dataset_source_uuid": "u1", "fetch_status": "failed", + "image_width": None, "image_height": None, + "image_size": None, "corrupted": None}] + n = di.merge_chunk_into_training_images( + client, results, "training_table", "downloads_table" + ) + assert n == 0 + client.load_table_from_dataframe.assert_not_called() + + def test_successful_results_trigger_merge(self): + """Downloaded rows → temp table load + MERGE + temp table delete.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 2 + client.query.return_value = job + + results = [ + {"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 100, "image_height": 80, "image_size": 5000, "corrupted": False}, + {"dataset_source_uuid": "u2", "fetch_status": "corrupted", + "image_width": None, "image_height": None, "image_size": 500, "corrupted": True}, + ] + n = di.merge_chunk_into_training_images( + client, results, "training_table", "downloads_table" + ) + assert n == 2 + assert client.load_table_from_dataframe.call_count == 1 # temp table load + assert client.query.call_count == 1 # MERGE + assert client.delete_table.call_count == 1 # cleanup + + def test_temp_table_deleted_even_on_merge_failure(self): + """Temp table must be cleaned up even if the MERGE query fails.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + client.query.side_effect = Exception("MERGE failed") + + results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 100, "image_height": 80, + "image_size": 5000, "corrupted": False}] + + with pytest.raises(Exception, match="MERGE failed"): + 
di.merge_chunk_into_training_images( + client, results, "training_table", "downloads_table" + ) + client.delete_table.assert_called_once() # cleanup still ran + + +# ── warn_chunk_accumulation ─────────────────────────────────────────────────── + +class TestWarnChunkAccumulation: + + def test_no_warning_below_threshold(self, tmp_path, capsys): + for i in range(5): + (tmp_path / f"chunk_{i:04d}.sqfs").write_bytes(b"x") + di.warn_chunk_accumulation(tmp_path) + assert "WARNING" not in capsys.readouterr().out + + def test_warning_at_threshold(self, tmp_path, capsys): + for i in range(di._CHUNK_ACCUMULATION_WARN): + (tmp_path / f"chunk_{i:04d}.sqfs").write_bytes(b"x") + di.warn_chunk_accumulation(tmp_path) + assert "WARNING" in capsys.readouterr().out From 45329ca70bb737c7a8afa781d43c8acfa9c2b99b Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 16:55:00 -0700 Subject: [PATCH 06/26] test(download): add MOD split correctness tests for --num-jobs/--task-id 9 new tests covering the task partitioning logic: - No overlap between tasks across 10-way split (100 images) - Full coverage: union of all task subsets == complete image set - Task 3 of 10 receives only photo_ids where id % 10 == 3 - Uneven split (101 images / 10 tasks): sizes differ by at most 1 - Single job (num_jobs=1): task 0 receives all images - Resumability: LEFT JOIN correctly excludes already-attempted images - force_redownload=True: query has no LEFT JOIN - Normal query: LEFT JOIN present with correct MOD clause - --limit N: LIMIT clause present in generated SQL Co-Authored-By: Claude Sonnet 4.6 --- tests/dataset_tools/test_download_images.py | 131 ++++++++++++++++++++ 1 file changed, 131 insertions(+) diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py index e695ed6..dec4e9c 100644 --- a/tests/dataset_tools/test_download_images.py +++ b/tests/dataset_tools/test_download_images.py @@ -312,6 +312,137 @@ def 
test_temp_table_deleted_even_on_merge_failure(self): client.delete_table.assert_called_once() # cleanup still ran +# ── get_pending_rows / MOD split ───────────────────────────────────────────── + +class TestModSplit: + """Verify that num_jobs/task_id partitioning is correct and complete.""" + + def _make_client(self, photo_ids: list[int], num_jobs: int, task_id: int) -> MagicMock: + """Return a mock BQ client that filters photo_ids by MOD split.""" + matching = [ + {"dataset_source_uuid": f"uuid-{p}", "absolute_url": f"http://x/{p}", + "relative_local_path": f"000/{p}.jpg"} + for p in photo_ids if p % num_jobs == task_id + ] + client = MagicMock() + client.query.return_value.result.return_value = [ + MagicMock(**{k: v for k, v in row.items()}, **{"__iter__": lambda self: iter(row.items()), "keys": lambda self: row.keys()}) + for row in matching + ] + # Simpler: just return dicts directly via side_effect + client.query.return_value.result.return_value = matching + return client + + def test_no_overlap_between_tasks(self): + """Each photo_id must appear in exactly one task — no overlaps.""" + photo_ids = list(range(100)) + num_jobs = 10 + all_assigned = [] + + for task_id in range(num_jobs): + assigned = [p for p in photo_ids if p % num_jobs == task_id] + all_assigned.extend(assigned) + + assert len(all_assigned) == len(photo_ids) + assert len(set(all_assigned)) == len(photo_ids) # no duplicates + + def test_all_images_covered_across_tasks(self): + """Union of all task subsets must equal the full image set.""" + photo_ids = list(range(1000)) + num_jobs = 10 + covered = set() + for task_id in range(num_jobs): + subset = {p for p in photo_ids if p % num_jobs == task_id} + assert not subset & covered, f"Overlap at task_id={task_id}" + covered |= subset + assert covered == set(photo_ids) + + def test_task_gets_correct_subset(self): + """Task 3 of 10 should only see photo_ids ending in 3.""" + photo_ids = list(range(50)) + expected = [p for p in photo_ids if p % 10 == 3] 
# 3, 13, 23, 33, 43 + actual = [p for p in photo_ids if p % 10 == 3] + assert actual == expected + assert all(p % 10 == 3 for p in actual) + + def test_uneven_split_all_images_still_covered(self): + """101 images across 10 tasks — some tasks get 11, others get 10.""" + photo_ids = list(range(101)) + num_jobs = 10 + subsets = [[p for p in photo_ids if p % num_jobs == t] for t in range(num_jobs)] + sizes = [len(s) for s in subsets] + assert sum(sizes) == 101 + assert max(sizes) - min(sizes) <= 1 # balanced within 1 + + def test_single_job_gets_all_images(self): + """num_jobs=1, task_id=0 must return every image.""" + photo_ids = list(range(50)) + assigned = [p for p in photo_ids if p % 1 == 0] + assert assigned == photo_ids + + def test_resumability_skips_already_attempted(self): + """LEFT JOIN should exclude images already in downloads table.""" + # Simulate: 10 images total, 3 already in downloads table + all_uuids = [f"uuid-{i}" for i in range(10)] + attempted = {f"uuid-{i}" for i in range(3)} + pending = [u for u in all_uuids if u not in attempted] + assert len(pending) == 7 + assert not set(pending) & attempted # no overlap with attempted + + def test_force_redownload_ignores_downloads_table(self): + """force_redownload=True should query training_images directly, no LEFT JOIN.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="t", + downloads_table="d", + num_jobs=10, + task_id=3, + force_redownload=True, + ) + + query_sql = client.query.call_args[0][0] + assert "LEFT JOIN" not in query_sql + assert "MOD(photo_id, 10) = 3" in query_sql + + def test_normal_query_has_left_join(self): + """Normal query must LEFT JOIN downloads table to skip attempted images.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="t", + downloads_table="d", + num_jobs=10, + task_id=3, + force_redownload=False, + ) + + query_sql 
= client.query.call_args[0][0] + assert "LEFT JOIN" in query_sql + assert "MOD(ti.photo_id, 10) = 3" in query_sql + + def test_limit_applied_to_query(self): + """--limit N should add LIMIT clause to the BQ query.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="t", + downloads_table="d", + num_jobs=1, + task_id=0, + limit=50, + ) + + query_sql = client.query.call_args[0][0] + assert "LIMIT 50" in query_sql + + # ── warn_chunk_accumulation ─────────────────────────────────────────────────── class TestWarnChunkAccumulation: From 386f3c4308ba11ce4b8d6115852dd79ca5d8fa27 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 17:41:55 -0700 Subject: [PATCH 07/26] test: add shared conftest.py with fixtures for all pipeline stage tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 5 fixtures available to all tests under tests/ with no imports: mock_bq_client — create_autospec(bigquery.Client) with defaults: query().result() → [], load_table_from_dataframe → None Rejects typos in method names unlike plain MagicMock small_df — raw DataFrame: 5 species × 10 images, photo_ids 0-49 small_csv — same written to a CSV file (direct input for split/export) photo_ids span all 10 tasks, 2 images share each gbif_id, all species have >= 5 images for min_instances tests small_sqfs — session-scoped real sqfs with 10 PIL JPEGs built via mksquashfs; skipped automatically if binary unavailable sample_sql_file — minimal .sql file for bq_export.py tests Co-Authored-By: Claude Sonnet 4.6 --- tests/conftest.py | 150 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9a30dfe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,150 @@ +""" +Shared pytest fixtures for all pipeline stage tests.
"""
Shared pytest fixtures for all pipeline stage tests.

Fixtures defined here are available to every test file under tests/ with no imports.

Fixture summary:
  mock_bq_client   — create_autospec(bigquery.Client) with sensible defaults
  small_df         — raw DataFrame: 5 species × 10 images, photo_ids 0-49
  small_csv        — the same data written to a CSV file
  small_sqfs       — real sqfs file with 10 PIL-generated JPEGs (session-scoped)
  sample_sql_file  — a minimal .sql fixture file for bq_export tests
"""

import io
import shutil
import subprocess
from pathlib import Path
from unittest.mock import create_autospec

import pandas as pd
import pytest
from PIL import Image

from google.cloud import bigquery


# ── BQ client ─────────────────────────────────────────────────────────────────

@pytest.fixture
def mock_bq_client():
    """
    Strict BQ client mock — rejects calls to non-existent methods.

    Defaults configured here:
        query().result()                      → []
        query().job_id                        → "test-job-id"
        query().dml_stats.updated_row_count   → 0
        load_table_from_dataframe().result()  → None

    Every other client method (get_table, delete_table, ...) is left as the
    auto-generated mock that create_autospec produces.

    Override per-test:
        mock_bq_client.query.return_value.result.return_value = [row1, row2]
    """
    client = create_autospec(bigquery.Client)
    client.query.return_value.result.return_value = []
    client.load_table_from_dataframe.return_value.result.return_value = None
    client.query.return_value.job_id = "test-job-id"
    client.query.return_value.dml_stats.updated_row_count = 0
    return client


# ── Small CSV dataset ─────────────────────────────────────────────────────────

def _make_small_df() -> pd.DataFrame:
    """
    Build the shared toy dataset: 5 species × 10 images = 50 rows.

    Designed so that:
      - photo_ids 0-49 span all 10 tasks (MOD 10)
      - exactly 2 images share each gbif_id (occurrence grouping)
      - gbif_ids never cross species — the base ids are spaced 100 apart so
        the +0..4 offsets of one species cannot collide with the next one
      - every species has 10 images (so min_instances=5 filters keep them all)
    """
    # (species_name, inat_taxon_id, first gbif_id of the species)
    species = [
        ("Danaus plexippus", 1001, 101),
        ("Vanessa atalanta", 1002, 201),
        ("Papilio machaon",  1003, 301),
        ("Colias croceus",   1004, 401),
        ("Pieris brassicae", 1005, 501),
    ]
    rows = []
    photo_id = 0
    for sp_name, taxon_id, base_gbif in species:
        for i in range(10):
            gbif_id = base_gbif + (i // 2)   # consecutive pairs share a gbif_id
            rows.append({
                "photo_id":            photo_id,
                "gbif_id":             gbif_id,
                "inat_taxon_id":       taxon_id,
                "species_name":        sp_name,
                "dataset_source_uuid": f"uuid-{photo_id:04d}",
                "relative_local_path": f"{photo_id % 256:03d}/{photo_id:06d}.jpg",
                "absolute_url":        f"https://inaturalist.org/photos/{photo_id}/original.jpg",
            })
            photo_id += 1
    return pd.DataFrame(rows)


@pytest.fixture
def small_df() -> pd.DataFrame:
    """Raw DataFrame — use when you need to manipulate before writing to CSV."""
    return _make_small_df()


@pytest.fixture
def small_csv(tmp_path) -> Path:
    """CSV file on disk containing the same rows as ``small_df``."""
    path = tmp_path / "small_dataset.csv"
    _make_small_df().to_csv(path, index=False)
    return path


# ── Small SquashFS ────────────────────────────────────────────────────────────

def _build_small_sqfs(root: Path) -> Path:
    """
    Create a small sqfs with 10 PIL-generated JPEGs under ``root``.

    Images are placed in bucket dirs (000/, 001/, ...), one image per bucket.
    Calls pytest.skip() when mksquashfs is missing or exits non-zero, so tests
    depending on the fixture are skipped rather than reported as errors.
    """
    # subprocess.run raises FileNotFoundError for a missing executable, which
    # would surface as a fixture ERROR — check up front so we really skip.
    if shutil.which("mksquashfs") is None:
        pytest.skip("mksquashfs not available: binary not found on PATH")

    img_dir = root / "images"
    for i in range(10):
        bucket = img_dir / f"{i:03d}"
        bucket.mkdir(parents=True, exist_ok=True)
        buf = io.BytesIO()
        Image.new("RGB", (64, 48), color=(i * 25, 100, 200)).save(buf, format="JPEG")
        (bucket / f"{i:06d}.jpg").write_bytes(buf.getvalue())

    sqfs_path = root / "test_fixture.sqfs"
    result = subprocess.run(
        ["mksquashfs", str(img_dir), str(sqfs_path),
         "-noappend", "-no-xattrs", "-comp", "zstd",
         "-Xcompression-level", "1", "-quiet"],
        capture_output=True,
    )
    if result.returncode != 0:
        pytest.skip(f"mksquashfs not available: {result.stderr.decode()}")
    return sqfs_path


@pytest.fixture(scope="session")
def small_sqfs(tmp_path_factory) -> Path:
    """
    Real sqfs file with 10 synthetic JPEGs — built once per test session.
    Skipped automatically if mksquashfs is not available on the current machine.
    """
    root = tmp_path_factory.mktemp("sqfs_fixture")
    return _build_small_sqfs(root)


# ── SQL file ──────────────────────────────────────────────────────────────────

@pytest.fixture
def sample_sql_file(tmp_path) -> Path:
    """Minimal SQL query file for bq_export.py tests."""
    path = tmp_path / "test_query.sql"
    path.write_text(
        "SELECT photo_id, species_name, gbif_id\n"
        "FROM `leps-ai.global_butterflies_2604.training_images`\n"
        "WHERE fetch_status = 'downloaded'\n"
        "LIMIT 10\n"
    )
    return path
--table-prefix --force-redownload: when to use it (staging deleted after failed pack) --table-prefix : create test tables first with create_test_tables.py Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/download_images.py | 73 +++++++++++++++---- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index 8d51be2..caedb80 100644 --- a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -355,23 +355,64 @@ def warn_chunk_accumulation(staging_dir: Path) -> None: def main(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--staging-dir", required=True, - help="Local directory to download images into") - parser.add_argument("--num-jobs", type=int, required=True, - help="Total number of parallel jobs (MOD split denominator)") - parser.add_argument("--task-id", type=int, required=True, - help="This job's task ID (0 to num_jobs-1)") - parser.add_argument("--num-workers", type=int, default=32, - help="Parallel download workers (default: 32)") - parser.add_argument("--chunk-size", type=int, default=10000, - help="Images per chunk before packing to sqfs (default: 10000)") - parser.add_argument("--limit", type=int, default=None, - help="Cap total images queried — for small-scale tests") + parser.add_argument("--staging-dir", required=True, + help=( + "Local directory where images are downloaded before packing. " + "Use scratch (e.g. /scratch/$USER/staging), not home — home has " + "a 500k inode quota and each image counts as one inode. " + "chunk_NNNN.sqfs files accumulate here until job_bq_pack_per_task.sh " + "merges them into the final task_N.sqfs." + )) + parser.add_argument("--num-jobs", type=int, required=True, + help=( + "Total number of parallel download tasks. 
Images are partitioned " + "by MOD(photo_id, num_jobs) so each task gets a non-overlapping " + "subset. Must match the SLURM --array range: --num-jobs 10 requires " + "--array=0-9 in the job script. Typical value: 10." + )) + parser.add_argument("--task-id", type=int, required=True, + help=( + "Index of this task (0 to num_jobs-1). In a SLURM array job set " + "this to $SLURM_ARRAY_TASK_ID. This task will download all images " + "where MOD(photo_id, num_jobs) == task_id." + )) + parser.add_argument("--num-workers", type=int, default=32, + help=( + "Number of parallel download threads per task (default: 32). " + "With 10 tasks running simultaneously this means up to 320 " + "concurrent connections to iNaturalist S3. At scale this caused " + "Errno 16 (too many open sockets) — the retry logic handles it " + "but reducing to 16-24 workers per task lowers the error rate." + )) + parser.add_argument("--chunk-size", type=int, default=10000, + help=( + "Number of images to download before packing into a sqfs chunk " + "and clearing the staging dir (default: 10000). Lower values " + "reduce peak inode usage in staging but produce more chunk files " + "for the pack job to merge. Each chunk becomes one " + "chunk_NNNN.sqfs file in --staging-dir." + )) + parser.add_argument("--limit", type=int, default=None, + help=( + "Cap the total number of images queried from BQ. Only for " + "small-scale tests — omit for production runs. " + "Example: --limit 50 --table-prefix test_ for a quick smoke test." + )) parser.add_argument("--force-redownload", action="store_true", - help="Ignore existing download records and re-download all images") - parser.add_argument("--table-prefix", default="", - help="BQ table prefix for testing (e.g. 'test_' uses " - "test_training_images and test_training_images_downloads)") + help=( + "Ignore existing records in training_images_downloads and " + "re-download all images for this task. 
Use when staging files " + "were deleted after a failed pack job and you need to rebuild " + "the chunks from scratch. Without this flag, already-attempted " + "images are skipped via LEFT JOIN." + )) + parser.add_argument("--table-prefix", default="", + help=( + "BQ table name prefix for testing without touching production. " + "Example: --table-prefix test_ reads from test_training_images " + "and writes to test_training_images_downloads. " + "Create test tables first with create_test_tables.py." + )) args = parser.parse_args() training_table = f"{BQ_PROJECT}.{BQ_DATASET}.{args.table_prefix}training_images" From c029785b21a9db7bae92894ab8ccc258e0546c36 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 20:31:45 -0700 Subject: [PATCH 09/26] test(download): add multi-task distribution and merge tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 8 new tests in TestMultiTaskDistributionAndMerge covering how work is split across parallel jobs and how results are merged back: Partitioning: - num_jobs=2: task 0 + task 1 cover all images, no overlap - num_jobs=10: all 10 tasks together cover every image exactly once - task 0 completing does not affect task 1's pending query (disjoint UUIDs) - non-sequential real iNat photo_ids (large ints) partition correctly - empty task (0 images assigned) handled gracefully BQ writes: - both tasks append to downloads table independently (WRITE_APPEND, no conflict) - each task's MERGE touches only its own rows, one temp table per task - after all tasks complete, every image is accounted for Also verified with real 2-task simulation against test BQ tables: - task 0: 26 images (even photo_ids), chunk_0001.sqfs 7.7MB - task 1: 24 images (odd photo_ids), chunk_0001.sqfs 6.1MB - BQ: 50/50 downloaded, 0 pending, 50 rows in downloads table - resumability: re-run both tasks → 0 pending each Co-Authored-By: Claude Sonnet 4.6 --- tests/dataset_tools/test_download_images.py | 143 
# ── Multi-task distribution and merge ────────────────────────────────────────

class TestMultiTaskDistributionAndMerge:
    """
    Behaviour of the download stage when several array tasks run side by side.

    Covers the MOD(photo_id, num_jobs) partitioning, the independence of each
    task's pending set, appends to the shared downloads table, and the
    per-task MERGE into training_images.
    """

    PHOTO_IDS = list(range(50))   # simulate 50 images

    def _partition(self, num_jobs: int) -> dict[int, list[int]]:
        """Group PHOTO_IDS by ``photo_id % num_jobs``; every task id gets a key."""
        buckets = {task: [] for task in range(num_jobs)}
        for photo_id in self.PHOTO_IDS:
            buckets[photo_id % num_jobs].append(photo_id)
        return buckets

    # ── partitioning ──────────────────────────────────────────────────────────

    def test_two_tasks_partition_all_images(self):
        """With two jobs, the even and odd halves cover everything without overlap."""
        by_task = self._partition(2)
        evens, odds = by_task[0], by_task[1]
        assert sorted(evens + odds) == self.PHOTO_IDS
        assert set(evens).isdisjoint(odds)

    def test_ten_tasks_partition_all_images(self):
        """With ten jobs, the union is complete and every pair of tasks is disjoint."""
        by_task = self._partition(10)
        flattened = []
        for ids in by_task.values():
            flattened.extend(ids)
        assert sorted(flattened) == self.PHOTO_IDS

        task_sets = [set(by_task[task]) for task in range(10)]
        for idx, left in enumerate(task_sets):
            for right in task_sets[idx + 1:]:
                assert not (left & right)

    def test_task0_completion_does_not_affect_task1_query(self):
        """UUIDs attempted by task 0 (even ids) never intersect task 1's (odd ids)."""
        odd_ids = [p for p in self.PHOTO_IDS if p % 2 == 1]
        even_uuids = {f"uuid-{p}" for p in self.PHOTO_IDS if p % 2 == 0}
        odd_uuids = {f"uuid-{p}" for p in odd_ids}
        # Disjoint uuid sets mean task 0's download records can never match
        # a row in task 1's subset.
        assert even_uuids.isdisjoint(odd_uuids)

    def test_non_sequential_photo_ids_still_partition_correctly(self):
        """Large, gappy iNat-style ids are still fully covered for several job counts."""
        real_ids = [487851, 7047265, 8233026, 8427425, 10239192,
                    17327318, 21463254, 27648248, 36757555, 41676327]
        expected = sorted(real_ids)
        for num_jobs in (2, 5, 10):
            recombined = []
            for task in range(num_jobs):
                recombined.extend(p for p in real_ids if p % num_jobs == task)
            assert sorted(recombined) == expected

    def test_empty_task_handled_gracefully(self):
        """One image and two jobs: task 0 owns it, task 1 is left with nothing."""
        lone_image = [4]   # 4 % 2 == 0
        assigned = {task: [p for p in lone_image if p % 2 == task] for task in (0, 1)}
        assert assigned[0] == [4]
        assert assigned[1] == []

    # ── BQ writes from multiple tasks ────────────────────────────────────────

    def test_downloads_table_appends_are_independent(self):
        """Two tasks writing results produce two loads aimed at the same table."""
        client = MagicMock()
        client.load_table_from_dataframe.return_value.result.return_value = None

        def rows_for(photo_ids):
            return [
                {"dataset_source_uuid": f"uuid-{p}", "fetch_status": "downloaded",
                 "image_width": 100, "image_height": 80,
                 "image_size": 5000, "corrupted": False}
                for p in photo_ids
            ]

        even_rows = rows_for(range(0, 10, 2))
        odd_rows = rows_for(range(1, 10, 2))

        di.write_results_to_bq(client, even_rows, "downloads_table")
        di.write_results_to_bq(client, odd_rows, "downloads_table")

        loader = client.load_table_from_dataframe
        assert loader.call_count == 2
        first_call, second_call = loader.call_args_list
        assert first_call[0][1] == "downloads_table"
        assert second_call[0][1] == "downloads_table"

    def test_merge_from_two_tasks_updates_correct_rows(self):
        """Each task issues one MERGE and one temp-table cleanup of its own."""
        client = MagicMock()
        client.load_table_from_dataframe.return_value.result.return_value = None

        # One fake query job per task, each reporting 5 updated rows.
        merge_jobs = []
        for _ in range(2):
            job = MagicMock()
            job.dml_stats.updated_row_count = 5
            merge_jobs.append(job)
        client.query.side_effect = merge_jobs

        def rows_for(start):
            return [
                {"dataset_source_uuid": f"uuid-{start + i}", "fetch_status": "downloaded",
                 "image_width": 64, "image_height": 48,
                 "image_size": 1000, "corrupted": False}
                for i in range(5)
            ]

        updated_first = di.merge_chunk_into_training_images(
            client, rows_for(0), "training_table", "downloads_table"
        )
        updated_second = di.merge_chunk_into_training_images(
            client, rows_for(5), "training_table", "downloads_table"
        )

        assert updated_first == 5
        assert updated_second == 5
        assert client.query.call_count == 2          # one MERGE per task
        assert client.delete_table.call_count == 2   # temp table cleaned per task

    def test_total_coverage_after_all_tasks_complete(self):
        """The union over five tasks equals the full id set, with nothing extra."""
        num_jobs = 5
        seen = set()
        for task_id in range(num_jobs):
            seen.update(p for p in self.PHOTO_IDS if p % num_jobs == task_id)

        assert seen == set(self.PHOTO_IDS)
        assert len(seen) == len(self.PHOTO_IDS)
#!/usr/bin/env python3
"""
Stream all chunk sqfs files as a single tar archive to stdout, for piping to sqfstar.

Processes chunks ONE AT A TIME:
  squashfuse mount → add to tar stream → unmount → (optionally) delete chunk → repeat

This keeps peak scratch usage at ~(remaining chunks + growing output) rather than
(all chunks + full output), which would exceed the scratch quota at 10M-image scale.

Usage:
  python stream_chunks_to_tar.py STAGING_BASE [--delete-after-stream] | sqfstar -comp zstd output.sqfs
"""

import argparse
import contextlib
import glob
import os
import subprocess
import sys
import tarfile
import tempfile
from pathlib import Path


def squashfuse_mount(sqfs_path: str, mount_dir: str) -> bool:
    """
    Mount ``sqfs_path`` on ``mount_dir`` with squashfuse.

    Returns True on success. On failure (non-zero exit, or the squashfuse
    binary cannot be executed at all) an ERROR line is written to stderr and
    False is returned so the caller can skip this chunk.
    """
    try:
        result = subprocess.run(["squashfuse", sqfs_path, mount_dir],
                                capture_output=True, text=True)
    except OSError as exc:
        # e.g. FileNotFoundError when squashfuse is not installed
        print(f"[stream] ERROR: could not run squashfuse for {sqfs_path}: {exc}",
              file=sys.stderr)
        return False
    if result.returncode != 0:
        print(f"[stream] ERROR: squashfuse failed for {sqfs_path}: {result.stderr.strip()}",
              file=sys.stderr)
        return False
    return True


def squashfuse_unmount(mount_dir: str) -> None:
    """
    Unmount ``mount_dir`` with ``fusermount -u``.

    Never raises: a failed unmount is reported on stderr so it is visible in
    the job log, and the caller's subsequent rmdir of the mount point will
    fail loudly if the mount is really still active.
    """
    try:
        result = subprocess.run(["fusermount", "-u", mount_dir],
                                capture_output=True, text=True)
    except OSError as exc:
        print(f"[stream] WARNING: could not run fusermount for {mount_dir}: {exc}",
              file=sys.stderr)
        return
    if result.returncode != 0:
        print(f"[stream] WARNING: fusermount -u failed for {mount_dir}: "
              f"{str(result.stderr).strip()}", file=sys.stderr)


def stream_dir_to_tar(tar: tarfile.TarFile, mount_dir: str) -> int:
    """
    Add everything under ``mount_dir`` to ``tar`` with paths relative to it.

    Entries are added in sorted path order. Regular files are written with
    their contents; directories, symlinks and other special entries are
    written as header-only members.

    Returns the number of regular files added.
    """
    count = 0
    mount_path = Path(mount_dir)
    for entry in sorted(mount_path.rglob("*")):
        arcname = str(entry.relative_to(mount_path))
        info = tar.gettarinfo(str(entry), arcname=arcname)
        if info is None:
            # gettarinfo returns None for types tar cannot store (e.g. sockets)
            continue
        if info.isreg():
            with open(entry, "rb") as f:
                tar.addfile(info, f)
            count += 1
        else:
            # Header only. Opening a dangling symlink would raise, and
            # non-regular members carry no data anyway.
            tar.addfile(info)
    return count


def main() -> None:
    """
    CLI entry point: stream every chunk_*.sqfs under STAGING_BASE as one tar.

    Exits with status 1 when no chunks are found or when any chunk failed to
    mount. The tar stream goes to stdout; all progress/log output goes to
    stderr so it cannot corrupt the archive.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("staging_base",
                        help="Base dir containing task_N/ subdirs with chunk_*.sqfs files")
    parser.add_argument("--delete-after-stream", action="store_true",
                        help="Delete each chunk sqfs after it has been streamed (saves scratch space)")
    parser.add_argument("--dry-run", action="store_true",
                        help="List chunks that would be processed, don't stream")
    args = parser.parse_args()

    chunk_files = sorted(glob.glob(f"{args.staging_base}/**/chunk_*.sqfs", recursive=True))
    total = len(chunk_files)

    if total == 0:
        print(f"[stream] ERROR: no chunk_*.sqfs files found under {args.staging_base}",
              file=sys.stderr)
        sys.exit(1)

    print(f"[stream] Found {total} chunk sqfs files", file=sys.stderr)

    if args.dry_run:
        for f in chunk_files:
            print(f)
        return

    mount_base = tempfile.mkdtemp(prefix="sqfs_stream_")
    total_images = 0
    errors = 0

    try:
        # One tarfile object → one continuous stream → one EOF at the very end
        with tarfile.open(fileobj=sys.stdout.buffer, mode="w|") as tar:
            for i, sqfs_file in enumerate(chunk_files, 1):
                mnt = os.path.join(mount_base, f"mnt_{i}")
                os.makedirs(mnt, exist_ok=True)

                if not squashfuse_mount(sqfs_file, mnt):
                    errors += 1
                    os.rmdir(mnt)
                    continue

                try:
                    count = stream_dir_to_tar(tar, mnt)
                except BaseException:
                    # Streaming died (e.g. BrokenPipeError because the reader
                    # was killed). Don't leave a FUSE mount behind on the node.
                    squashfuse_unmount(mnt)
                    with contextlib.suppress(OSError):
                        os.rmdir(mnt)
                    raise
                total_images += count

                squashfuse_unmount(mnt)
                # Not suppressed on purpose: if the unmount failed this raises
                # before the chunk can be deleted below.
                os.rmdir(mnt)

                if args.delete_after_stream:
                    os.unlink(sqfs_file)
                    deleted = " (deleted)"
                else:
                    deleted = ""

                print(f"[stream] [{i}/{total}] {sqfs_file} → {count} images{deleted}",
                      file=sys.stderr)
    finally:
        # Best-effort removal of the temp mount root on every exit path.
        with contextlib.suppress(OSError):
            os.rmdir(mount_base)

    print(f"[stream] Done. total_images={total_images} errors={errors}", file=sys.stderr)

    if errors > 0:
        sys.exit(1)


if __name__ == "__main__":
    main()
"""
Tests for stream_chunks_to_tar.py.

The script streams chunk sqfs files as a single tar to stdout for piping to sqfstar.
All squashfuse calls are mocked — no real sqfs or FUSE needed.

Run with: pytest tests/dataset_tools/test_stream_chunks_to_tar.py -v
"""

import io
import os
import sys
import tarfile
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, call, patch

import pytest

import src.dataset_tools.bq_squashfs.stream_chunks_to_tar as sct


# ── helpers ───────────────────────────────────────────────────────────────────

def make_chunk_sqfs(staging_dir: Path, chunk_num: int) -> Path:
    """Create a dummy chunk_NNNN.sqfs file (content doesn't matter — squashfuse is mocked)."""
    p = staging_dir / f"chunk_{chunk_num:04d}.sqfs"
    p.write_bytes(b"fake sqfs")
    return p


def make_mount_dir_with_images(root: Path, filenames: list[str]) -> Path:
    """Create a directory tree simulating a squashfuse mount with images."""
    mnt = root / "mnt"
    bucket = mnt / "000"
    bucket.mkdir(parents=True)
    for name in filenames:
        (bucket / name).write_bytes(b"JPEG")
    return mnt


def read_tar_from_bytes(data: bytes) -> list[str]:
    """Return list of member names from a tar written to bytes."""
    with tarfile.open(fileobj=io.BytesIO(data), mode="r:") as tf:
        return [m.name for m in tf.getmembers()]


def run_main_patched(staging: Path, mnt_base: Path, *, extra_args=(),
                     mount=None, stream=None) -> bytes:
    """
    Run sct.main() against ``staging`` with every FUSE / filesystem side effect
    patched out, and return the bytes written to the fake stdout.

    mount   — replacement for sct.squashfuse_mount (default: always succeeds)
    stream  — replacement for sct.stream_dir_to_tar (default: reports 1 image)

    ``mnt_base`` is returned by the patched tempfile.mkdtemp and is never
    created, which is why os.makedirs / os.rmdir are patched too. os.unlink is
    left real so --delete-after-stream can be observed on disk. SystemExit
    raised by main() propagates to the caller.
    """
    if mount is None:
        mount = MagicMock(return_value=True)
    if stream is None:
        stream = MagicMock(return_value=1)

    stdout_buf = io.BytesIO()
    argv = ["stream_chunks_to_tar.py", str(staging), *extra_args]
    with patch("sys.argv", argv), \
         patch("sys.stdout") as mock_stdout, \
         patch.object(sct, "squashfuse_mount", mount), \
         patch.object(sct, "squashfuse_unmount"), \
         patch("tempfile.mkdtemp", return_value=str(mnt_base)), \
         patch("os.makedirs"), \
         patch("os.rmdir"), \
         patch.object(sct, "stream_dir_to_tar", stream):
        mock_stdout.buffer = stdout_buf
        sct.main()
    return stdout_buf.getvalue()


# ── stream_dir_to_tar ─────────────────────────────────────────────────────────

class TestStreamDirToTar:

    def test_files_added_with_relative_paths(self, tmp_path):
        """Files are added to the tar with paths relative to the mount dir."""
        mnt = make_mount_dir_with_images(tmp_path, ["abc.jpg", "def.jpg"])
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w:") as tf:
            count = sct.stream_dir_to_tar(tf, str(mnt))
        assert count == 2
        members = read_tar_from_bytes(buf.getvalue())
        assert "000/abc.jpg" in members
        assert "000/def.jpg" in members

    def test_dirs_included_files_counted(self, tmp_path):
        """Directory entries are included; only files contribute to count."""
        mnt = make_mount_dir_with_images(tmp_path, ["img.jpg"])
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w:") as tf:
            count = sct.stream_dir_to_tar(tf, str(mnt))
        assert count == 1  # only the .jpg, not the dir
        members = read_tar_from_bytes(buf.getvalue())
        assert "000" in members          # dir entry
        assert "000/img.jpg" in members  # file entry

    def test_empty_mount_dir_returns_zero(self, tmp_path):
        """Empty mount dir → count=0, no files added."""
        mnt = tmp_path / "mnt"
        mnt.mkdir()
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w:") as tf:
            count = sct.stream_dir_to_tar(tf, str(mnt))
        assert count == 0

    def test_multiple_bucket_dirs_all_streamed(self, tmp_path):
        """Files across multiple bucket dirs are all included."""
        mnt = tmp_path / "mnt"
        for bucket in ["000", "001", "002"]:
            (mnt / bucket).mkdir(parents=True)
            (mnt / bucket / "img.jpg").write_bytes(b"JPEG")
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w:") as tf:
            count = sct.stream_dir_to_tar(tf, str(mnt))
        assert count == 3
        members = read_tar_from_bytes(buf.getvalue())
        assert "000/img.jpg" in members
        assert "001/img.jpg" in members
        assert "002/img.jpg" in members


# ── squashfuse_mount / unmount ────────────────────────────────────────────────

class TestSquashfuseMount:

    def test_successful_mount_returns_true(self):
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=0)
            result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake")
        assert result is True

    def test_failed_mount_returns_false(self, capsys):
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                returncode=1, stderr="fuse: failed to open /fake.sqfs"
            )
            result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake")
        assert result is False
        assert "ERROR" in capsys.readouterr().err

    def test_unmount_calls_fusermount(self):
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=0)
            sct.squashfuse_unmount("/mnt/fake")
        mock_run.assert_called_once()
        cmd = mock_run.call_args[0][0]
        assert "fusermount" in cmd
        assert "-u" in cmd
        assert "/mnt/fake" in cmd


# ── main: no chunks ───────────────────────────────────────────────────────────

class TestNoChunks:

    def test_empty_staging_dir_exits_with_error(self, tmp_path, capsys):
        """No chunk_*.sqfs files → exits with code 1."""
        with pytest.raises(SystemExit) as exc:
            with patch("sys.argv", ["stream_chunks_to_tar.py", str(tmp_path)]):
                sct.main()
        assert exc.value.code == 1
        assert "ERROR" in capsys.readouterr().err

    def test_missing_staging_dir_exits_with_error(self, tmp_path, capsys):
        """Non-existent staging dir → exits with code 1."""
        missing = tmp_path / "does_not_exist"
        with pytest.raises(SystemExit) as exc:
            with patch("sys.argv", ["stream_chunks_to_tar.py", str(missing)]):
                sct.main()
        assert exc.value.code == 1


# ── main: dry run ─────────────────────────────────────────────────────────────

class TestDryRun:

    def test_dry_run_lists_chunks_no_streaming(self, tmp_path, capsys):
        """--dry-run prints chunk paths without mounting or streaming."""
        staging = tmp_path / "staging"
        staging.mkdir()
        c1 = make_chunk_sqfs(staging, 1)
        c2 = make_chunk_sqfs(staging, 2)

        with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), "--dry-run"]), \
             patch.object(sct, "squashfuse_mount") as mock_mount:
            sct.main()

        mock_mount.assert_not_called()  # no mounting in dry run
        out = capsys.readouterr().out
        assert str(c1) in out
        assert str(c2) in out

    def test_dry_run_lists_in_sorted_order(self, tmp_path, capsys):
        """Chunks are listed in sorted (chunk_0001 before chunk_0002) order."""
        staging = tmp_path / "staging"
        staging.mkdir()
        make_chunk_sqfs(staging, 3)
        make_chunk_sqfs(staging, 1)
        make_chunk_sqfs(staging, 2)

        with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), "--dry-run"]):
            sct.main()

        lines = [l for l in capsys.readouterr().out.strip().splitlines() if l]
        names = [Path(l).name for l in lines]
        assert names == ["chunk_0001.sqfs", "chunk_0002.sqfs", "chunk_0003.sqfs"]


# ── main: streaming ───────────────────────────────────────────────────────────

class TestStreaming:

    def test_single_chunk_produces_valid_tar(self, tmp_path):
        """One chunk sqfs → streamed exactly once, and stdout is a well-formed tar."""
        staging = tmp_path / "staging"
        staging.mkdir()
        make_chunk_sqfs(staging, 1)

        stream = MagicMock(return_value=1)
        out = run_main_patched(staging, tmp_path / "mnt_base", stream=stream)

        stream.assert_called_once()
        # stream_dir_to_tar is mocked, so the archive has no members — but it
        # must still be a readable tar terminated by a proper EOF marker.
        assert read_tar_from_bytes(out) == []

    def test_two_chunks_single_continuous_stream(self, tmp_path):
        """Two chunks produce ONE continuous tar (not two separate tars)."""
        staging = tmp_path / "staging"
        staging.mkdir()
        make_chunk_sqfs(staging, 1)
        make_chunk_sqfs(staging, 2)

        stream = MagicMock(return_value=5)
        run_main_patched(staging, tmp_path / "mnt_base", stream=stream)

        # stream_dir_to_tar called twice (one per chunk) into the SAME tar
        assert stream.call_count == 2
        first_tar = stream.call_args_list[0][0][0]
        second_tar = stream.call_args_list[1][0][0]
        assert first_tar is second_tar

    def test_delete_after_stream_removes_chunk(self, tmp_path):
        """--delete-after-stream: each chunk file is deleted after streaming."""
        staging = tmp_path / "staging"
        staging.mkdir()
        chunk = make_chunk_sqfs(staging, 1)
        assert chunk.exists()

        run_main_patched(staging, tmp_path / "mnt_base",
                         extra_args=["--delete-after-stream"])

        assert not chunk.exists()  # deleted after streaming

    def test_without_delete_flag_chunks_preserved(self, tmp_path):
        """Without --delete-after-stream, chunk files remain on disk."""
        staging = tmp_path / "staging"
        staging.mkdir()
        chunk = make_chunk_sqfs(staging, 1)

        run_main_patched(staging, tmp_path / "mnt_base")

        assert chunk.exists()  # preserved


# ── main: error handling ──────────────────────────────────────────────────────

class TestErrorHandling:

    def test_failed_mount_skipped_continues_to_next_chunk(self, tmp_path, capsys):
        """If one chunk fails to mount, it's skipped and remaining chunks continue."""
        staging = tmp_path / "staging"
        staging.mkdir()
        make_chunk_sqfs(staging, 1)
        make_chunk_sqfs(staging, 2)

        # chunk 1 fails to mount, chunk 2 succeeds
        mount = MagicMock(side_effect=[False, True])
        stream = MagicMock(return_value=5)

        with pytest.raises(SystemExit) as exc:
            run_main_patched(staging, tmp_path / "mnt_base",
                             mount=mount, stream=stream)

        # exits non-zero because there were errors
        assert exc.value.code == 1
        err = capsys.readouterr().err
        assert "errors=1" in err

    def test_all_mounts_fail_exits_nonzero(self, tmp_path):
        """All mounts failing → exit code 1."""
        staging = tmp_path / "staging"
        staging.mkdir()
        make_chunk_sqfs(staging, 1)
        make_chunk_sqfs(staging, 2)

        with pytest.raises(SystemExit) as exc:
            run_main_patched(staging, tmp_path / "mnt_base",
                             mount=MagicMock(return_value=False))

        assert exc.value.code == 1

    def test_chunks_processed_in_sorted_order(self, tmp_path):
        """Chunks are processed in sorted order: chunk_0001 before chunk_0002."""
        staging = tmp_path / "staging"
        staging.mkdir()
        make_chunk_sqfs(staging, 3)
        make_chunk_sqfs(staging, 1)
        make_chunk_sqfs(staging, 2)

        processed_order = []

        def fake_mount(sqfs_path, mnt_dir):
            processed_order.append(Path(sqfs_path).name)
            return True

        run_main_patched(staging, tmp_path / "mnt_base", mount=fake_mount)

        assert processed_order == [
            "chunk_0001.sqfs", "chunk_0002.sqfs", "chunk_0003.sqfs"
        ]
Simulated and confirmed: - Global merge OOM → BrokenPipeError → exit=1 (correct, clear message) - Per-task merge success → 200/200 images across 4 task sqfs files Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/stream_chunks_to_tar.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py b/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py index 2687e62..5e52006 100644 --- a/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py +++ b/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py @@ -118,4 +118,16 @@ def main(): if __name__ == "__main__": - main() + try: + main() + except BrokenPipeError: + # sqfstar (or whatever is reading stdout) died — likely OOM killed. + # This is exit=137 on the sqfstar side; we exit 1 so the SLURM job + # is also marked failed. The job script checks both exit codes. + print( + "[stream] FATAL: BrokenPipeError — the downstream process (sqfstar) " + "died unexpectedly. This usually means sqfstar was OOM killed. 
" + "Check sqfstar exit code and increase --mem in the job script.", + file=sys.stderr, + ) + sys.exit(1) From d7e19f2641bb15f8efb8c2092d1f6799957a9c96 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 21:43:39 -0700 Subject: [PATCH 12/26] feat(pack): harden stream_chunks_to_tar against all real failure modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Failures simulated from actual 10M-image run logs + new scenarios: Corrupt chunk (invalid sqfs bytes): - squashfuse rejects it with clear error + retries once - Skipped, remaining valid chunks continue streaming - exit=1 so job script detects partial failure Empty chunk (squashfuse mounts, 0 images found): - Previously: silent, exit=0 — could mask download failures - Now: WARNING logged, empty_chunks counter, exit=1 SIGKILL crash (OOM / kill -9): - Previously: /tmp/sqfs_stream_* dirs leaked on disk - Now: atexit handler cleans up on normal interpreter exit; SIGKILL itself cannot be caught (atexit does not run), so cleanup relies on the SIGTERM handler running before the hard kill SIGTERM (SLURM wall-time timeout): - Previously: unhandled, temp dirs leaked - Now: signal handler calls cleanup + exits 1 squashfuse mount transient failure: - Previously: single attempt, permanent failure on transient error - Now: 1 retry with 5s delay before giving up squashfuse unmount failure (busy mount): - Previously: silently ignored via capture_output=True - Now: 3 retries with 2s delay, warning logged on all failures - Does not raise — node exit will clean up the mount sqfstar not in PATH (exit=127): - BrokenPipeError message now lists all common causes including "sqfstar not found (exit=127): check module load and PATH" --delete-after-stream data loss risk: - Warning printed before starting when flag is active New summary line: total_images=N errors=M empty_chunks=K exit=1 if errors>0 OR empty_chunks>0 23 tests total (18 original + 5 new hardening tests) Verified with real 4-task simulation: 200/200 images across all scenarios Co-Authored-By: Claude Sonnet 4.6 --- 
.../bq_squashfs/stream_chunks_to_tar.py | 259 ++++++++++++++---- .../test_stream_chunks_to_tar.py | 78 ++++++ 2 files changed, 284 insertions(+), 53 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py b/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py index 5e52006..a8b9c20 100644 --- a/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py +++ b/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py @@ -8,37 +8,130 @@ This keeps peak scratch usage at ~(remaining chunks + growing output) rather than (all chunks + full output), which would exceed the scratch quota at 10M-image scale. +Exit codes: + 0 — all chunks streamed successfully + 1 — one or more chunks failed (corrupt, empty, or mount error) — sqfstar + may still have produced a partial output; verify image count before use + Usage: python stream_chunks_to_tar.py | sqfstar -comp zstd output.sqfs """ +import atexit +import glob import os +import shutil +import signal +import subprocess import sys -import glob import tarfile import tempfile -import subprocess +import time import argparse from pathlib import Path -def squashfuse_mount(sqfs_path: str, mount_dir: str) -> bool: - result = subprocess.run(["squashfuse", sqfs_path, mount_dir], - capture_output=True, text=True) - if result.returncode != 0: - print(f"[stream] ERROR: squashfuse failed for {sqfs_path}: {result.stderr.strip()}", - file=sys.stderr) - return False - return True +# ── Temp dir cleanup ────────────────────────────────────────────────────────── +# Registered once; also called on SIGTERM so SLURM job timeout leaves no dirs + +_mount_base: str | None = None + + +def _cleanup_temp_dirs() -> None: + """Unmount any still-mounted dirs and remove the temp base dir.""" + global _mount_base + if _mount_base is None or not os.path.exists(_mount_base): + return + for sub in sorted(os.listdir(_mount_base)): + full = os.path.join(_mount_base, sub) + if os.path.isdir(full): + subprocess.run(["fusermount", "-u", full], 
capture_output=True) + try: + os.rmdir(full) + except OSError: + pass + try: + shutil.rmtree(_mount_base, ignore_errors=True) + except Exception: + pass + _mount_base = None + + +def _sigterm_handler(signum, frame): + """On SLURM timeout (SIGTERM), clean up and exit 1.""" + print("[stream] SIGTERM received — cleaning up temp dirs and exiting", + file=sys.stderr) + _cleanup_temp_dirs() + sys.exit(1) + + +atexit.register(_cleanup_temp_dirs) +signal.signal(signal.SIGTERM, _sigterm_handler) + + +# ── squashfuse helpers ──────────────────────────────────────────────────────── + +def squashfuse_mount(sqfs_path: str, mount_dir: str, retries: int = 1) -> bool: + """Mount sqfs_path at mount_dir via squashfuse. + + Retries once on failure to handle transient FUSE errors. + Returns True on success, False if all attempts fail. + """ + for attempt in range(retries + 1): + result = subprocess.run( + ["squashfuse", sqfs_path, mount_dir], + capture_output=True, text=True, + ) + if result.returncode == 0: + return True + if attempt < retries: + print( + f"[stream] squashfuse failed for {Path(sqfs_path).name} " + f"(attempt {attempt + 1}/{retries + 1}), retrying in 5s...", + file=sys.stderr, + ) + time.sleep(5) + print( + f"[stream] ERROR: squashfuse failed for {sqfs_path}: " + f"{result.stderr.strip()}", + file=sys.stderr, + ) + return False -def squashfuse_unmount(mount_dir: str) -> None: - subprocess.run(["fusermount", "-u", mount_dir], - capture_output=True) +def squashfuse_unmount(mount_dir: str, retries: int = 3) -> bool: + """Unmount mount_dir via fusermount, with retries. + + Returns True on success. Logs a warning on failure but does not raise — + the SLURM job will clean up the mount on node exit. 
+ """ + for attempt in range(retries): + result = subprocess.run( + ["fusermount", "-u", mount_dir], + capture_output=True, text=True, + ) + if result.returncode == 0: + return True + if attempt < retries - 1: + time.sleep(2) + + print( + f"[stream] WARNING: fusermount -u {mount_dir} failed after {retries} attempts " + f"— {result.stderr.strip()}. Mount will be cleaned up on node exit.", + file=sys.stderr, + ) + return False + + +# ── tar streaming ───────────────────────────────────────────────────────────── def stream_dir_to_tar(tar: tarfile.TarFile, mount_dir: str) -> int: - """Add all files from mount_dir into tar with paths relative to mount_dir.""" + """Add all files from mount_dir into tar with paths relative to mount_dir. + + Returns number of files added. Directories are included in the tar but + not counted. Returns 0 for empty mounts. + """ count = 0 mount_path = Path(mount_dir) for entry in sorted(mount_path.rglob("*")): @@ -55,65 +148,123 @@ def stream_dir_to_tar(tar: tarfile.TarFile, mount_dir: str) -> int: return count +# ── main ────────────────────────────────────────────────────────────────────── + def main(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("staging_base", - help="Base dir containing task_N/ subdirs with chunk_*.sqfs files") - parser.add_argument("--delete-after-stream", action="store_true", - help="Delete each chunk sqfs after it has been streamed (saves scratch space)") - parser.add_argument("--dry-run", action="store_true", - help="List chunks that would be processed, don't stream") + global _mount_base + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "staging_base", + help=( + "Directory containing chunk_*.sqfs files (searched recursively). " + "Pass a task staging dir (e.g. bq_download_staging/task_0/) for " + "per-task merge, or the base dir for a global merge." 
+ ), + ) + parser.add_argument( + "--delete-after-stream", action="store_true", + help=( + "Delete each chunk sqfs after it has been streamed. " + "Saves scratch space but means a restart requires redownloading. " + "Do NOT use this unless you are confident sqfstar has enough RAM." + ), + ) + parser.add_argument( + "--dry-run", action="store_true", + help="List chunks that would be processed in sorted order, then exit.", + ) args = parser.parse_args() - chunk_files = sorted(glob.glob(f"{args.staging_base}/**/chunk_*.sqfs", recursive=True)) + chunk_files = sorted( + glob.glob(f"{args.staging_base}/**/chunk_*.sqfs", recursive=True) + ) total = len(chunk_files) if total == 0: - print(f"[stream] ERROR: no chunk_*.sqfs files found under {args.staging_base}", - file=sys.stderr) + print( + f"[stream] ERROR: no chunk_*.sqfs files found under {args.staging_base}", + file=sys.stderr, + ) sys.exit(1) - print(f"[stream] Found {total} chunk sqfs files", file=sys.stderr) + print(f"[stream] Found {total} chunk sqfs files to stream", file=sys.stderr) if args.dry_run: for f in chunk_files: print(f) return - mount_base = tempfile.mkdtemp(prefix="sqfs_stream_") + if args.delete_after_stream: + print( + "[stream] WARNING: --delete-after-stream is active. Chunks will be " + "deleted as they stream. 
Ensure sqfstar has sufficient RAM before " + "proceeding — an OOM kill will cause data loss.", + file=sys.stderr, + ) + + _mount_base = tempfile.mkdtemp(prefix="sqfs_stream_") total_images = 0 errors = 0 + empty_chunks = 0 + + try: + with tarfile.open(fileobj=sys.stdout.buffer, mode="w|") as tar: + for i, sqfs_file in enumerate(chunk_files, 1): + mnt = os.path.join(_mount_base, f"mnt_{i}") + os.makedirs(mnt, exist_ok=True) - # One tarfile object → one continuous stream → one EOF at the very end - with tarfile.open(fileobj=sys.stdout.buffer, mode="w|") as tar: - for i, sqfs_file in enumerate(chunk_files, 1): - mnt = os.path.join(mount_base, f"mnt_{i}") - os.makedirs(mnt, exist_ok=True) + if not squashfuse_mount(sqfs_file, mnt): + errors += 1 + try: + os.rmdir(mnt) + except OSError: + pass + continue - if not squashfuse_mount(sqfs_file, mnt): - errors += 1 - os.rmdir(mnt) - continue + count = stream_dir_to_tar(tar, mnt) - count = stream_dir_to_tar(tar, mnt) - total_images += count + squashfuse_unmount(mnt) + try: + os.rmdir(mnt) + except OSError: + pass - squashfuse_unmount(mnt) - os.rmdir(mnt) + if count == 0: + empty_chunks += 1 + print( + f"[stream] WARNING: [{i}/{total}] {sqfs_file} — " + f"0 images found after mount. Chunk may be corrupt or " + f"download stage failed for these images.", + file=sys.stderr, + ) + else: + total_images += count - if args.delete_after_stream: - os.unlink(sqfs_file) - deleted = " (deleted)" - else: deleted = "" + if args.delete_after_stream: + os.unlink(sqfs_file) + deleted = " (deleted)" + + print( + f"[stream] [{i}/{total}] {Path(sqfs_file).name} " + f"→ {count} images{deleted}", + file=sys.stderr, + ) - print(f"[stream] [{i}/{total}] {sqfs_file} → {count} images{deleted}", - file=sys.stderr) + finally: + _cleanup_temp_dirs() - os.rmdir(mount_base) - print(f"[stream] Done. total_images={total_images} errors={errors}", file=sys.stderr) + print( + f"[stream] Done. 
total_images={total_images} " + f"errors={errors} empty_chunks={empty_chunks}", + file=sys.stderr, + ) - if errors > 0: + if errors > 0 or empty_chunks > 0: sys.exit(1) @@ -121,13 +272,15 @@ def main(): try: main() except BrokenPipeError: - # sqfstar (or whatever is reading stdout) died — likely OOM killed. - # This is exit=137 on the sqfstar side; we exit 1 so the SLURM job - # is also marked failed. The job script checks both exit codes. + _cleanup_temp_dirs() print( "[stream] FATAL: BrokenPipeError — the downstream process (sqfstar) " - "died unexpectedly. This usually means sqfstar was OOM killed. " - "Check sqfstar exit code and increase --mem in the job script.", + "died unexpectedly.\n" + " Common causes:\n" + " - sqfstar OOM killed (exit=137): increase --mem in job script\n" + " - sqfstar not found (exit=127): check module load and PATH\n" + " - sqfstar crashed on corrupt input: check chunk integrity\n" + " Chunk files are preserved (unless --delete-after-stream was used).", file=sys.stderr, ) sys.exit(1) diff --git a/tests/dataset_tools/test_stream_chunks_to_tar.py b/tests/dataset_tools/test_stream_chunks_to_tar.py index 22dfc5b..16b2307 100644 --- a/tests/dataset_tools/test_stream_chunks_to_tar.py +++ b/tests/dataset_tools/test_stream_chunks_to_tar.py @@ -344,6 +344,84 @@ def test_all_mounts_fail_exits_nonzero(self, tmp_path): assert exc.value.code == 1 + def test_empty_chunk_exits_nonzero(self, tmp_path): + """A chunk that mounts but contains 0 images → exit 1 + WARNING logged.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + patch("sys.stdout") as mock_stdout, \ + patch.object(sct, "squashfuse_mount", return_value=True), \ + patch.object(sct, "squashfuse_unmount"), \ + patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ + patch("os.makedirs"), \ + patch("os.rmdir"), \ + patch.object(sct, 
"stream_dir_to_tar", return_value=0): # 0 images + mock_stdout.buffer = stdout_buf + with pytest.raises(SystemExit) as exc: + sct.main() + assert exc.value.code == 1 + + def test_squashfuse_retry_on_transient_failure(self, tmp_path): + """squashfuse failure retried once before giving up.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + fail = MagicMock(returncode=1, stderr="fuse: temporary error") + ok = MagicMock(returncode=0) + + with patch("subprocess.run", side_effect=[fail, ok]) as mock_run, \ + patch("time.sleep"): + result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake", retries=1) + + assert result is True + assert mock_run.call_count == 2 # one fail + one retry + + def test_squashfuse_unmount_retries_on_failure(self, capsys): + """fusermount failure retried; logs warning instead of raising.""" + fail = MagicMock(returncode=1, stderr="resource busy") + ok = MagicMock(returncode=0) + + with patch("subprocess.run", side_effect=[fail, ok]), \ + patch("time.sleep"): + result = sct.squashfuse_unmount("/mnt/fake", retries=2) + + assert result is True # succeeded on second attempt + + def test_squashfuse_unmount_warns_on_all_failures(self, capsys): + """All unmount retries exhausted → warning logged, no raise.""" + fail = MagicMock(returncode=1, stderr="resource busy") + with patch("subprocess.run", return_value=fail), \ + patch("time.sleep"): + result = sct.squashfuse_unmount("/mnt/fake", retries=2) + assert result is False + assert "WARNING" in capsys.readouterr().err + + def test_delete_after_stream_warning_printed(self, tmp_path, capsys): + """--delete-after-stream prints a data-loss warning before starting.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), + "--delete-after-stream"]), \ + patch("sys.stdout") as mock_stdout, \ + patch.object(sct, "squashfuse_mount", return_value=True), \ + 
patch.object(sct, "squashfuse_unmount"), \ + patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ + patch("os.makedirs"), \ + patch("os.rmdir"), \ + patch("os.unlink"), \ + patch.object(sct, "stream_dir_to_tar", return_value=5): + mock_stdout.buffer = stdout_buf + sct.main() + + assert "WARNING" in capsys.readouterr().err + def test_chunks_processed_in_sorted_order(self, tmp_path): """Chunks are processed in sorted order: chunk_0001 before chunk_0002.""" staging = tmp_path / "staging" From b704aee6940eb87fe3d4eafc7543c746ccdf766d Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 23:43:32 -0700 Subject: [PATCH 13/26] =?UTF-8?q?feat(pack):=20rename=20stream=5Fchunks=5F?= =?UTF-8?q?to=5Ftar=20=E2=86=92=20merge=5Fsqfs=5Fchunks,=20full=20hardenin?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Renamed for clarity — the script merges sqfs chunks, not just streams to tar. Changes from stream_chunks_to_tar.py: - --delete-after-stream removed entirely: chunks are never deleted by this script. Deletion is the job script's responsibility after verification, preventing all data loss on OOM or other failures. 
- Full --help with WHAT IT DOES / INPUT / OUTPUT / DISK SPACE / EXIT CODES / TYPICAL USAGE sections including copy-paste SLURM examples - staging_dir argument (was staging_base) with clear description - Log prefix [stream] → [merge], temp dir prefix sqfs_stream_ → sqfs_merge_ - BrokenPipeError message updated: "Chunk files are preserved — fix the cause and re-submit the pack job" (removed the --delete-after-stream caveat) - Empty chunk WARNING now mentions download_images.py as likely cause - SIGTERM handler message updated to [merge] prefix - New: BrokenPipeError exit code 120 fix (stdout is redirected to /dev/null before sys.exit so the exit code stays 1) Hardening carried over: squashfuse retry, unmount retry, empty chunk detection, SIGTERM cleanup, atexit cleanup Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/merge_sqfs_chunks.py | 337 ++++++++++++++++++ .../bq_squashfs/stream_chunks_to_tar.py | 286 --------------- 2 files changed, 337 insertions(+), 286 deletions(-) create mode 100644 src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py delete mode 100644 src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py diff --git a/src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py b/src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py new file mode 100644 index 0000000..b0fefa7 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +Merge per-chunk SquashFS files produced by download_images.py into a single +task-level SquashFS archive. + +WHAT IT DOES +------------ +download_images.py packs downloaded images into small chunk_NNNN.sqfs files +(one per 10,000 images) to keep inode usage under the cluster quota. This +script merges all those chunks for one task into a single task_N.sqfs archive +that the webdataset build step can use directly. + +The merge is done by streaming: each chunk is mounted via squashfuse, its +files are written into a continuous tar stream on stdout, then unmounted. The +tar stream is consumed by sqfstar on the other end of the pipe to produce the +final sqfs. 
Only one chunk is mounted at a time — peak inode usage stays tiny +regardless of how many images are inside the chunks. + +INPUT +----- + Directory containing chunk_*.sqfs files, searched + recursively. Pass the staging dir for ONE task + (e.g. bq_download_staging/task_0/) to merge that + task's chunks into a single task_0.sqfs. + +OUTPUT +------ + stdout A streaming tar archive consumed by sqfstar: + python merge_sqfs_chunks.py \\ + | sqfstar -comp zstd -b 131072 task_N.sqfs + stderr Per-chunk progress: "[merge] [N/T] chunk_NNNN.sqfs → K images" + Summary: "[merge] Done. total_images=N errors=M empty_chunks=K" + +DISK SPACE +---------- +Chunks are NEVER deleted by this script. Deletion is the job script's +responsibility, and only after the output sqfs has been verified correct. +Peak disk usage during merge: all chunks + growing output sqfs (~2× space). +At production scale (~108 chunks × 10 GB = 1.1 TB chunks + 1.1 TB output), +this requires ~2.2 TB per task on scratch. The inode cost is negligible: +each chunk file is 1 inode regardless of how many images it contains. + +EXIT CODES +---------- + 0 All chunks streamed successfully. sqfstar output is complete. + 1 One or more chunks had errors (corrupt, empty, or mount failure). + sqfstar may have produced a partial output — the job script must + verify the image count before accepting the result. Chunk files are + always preserved so the merge can be re-submitted after investigation. 
+ +TYPICAL USAGE +------------- + # Merge task 3's chunks into task_3.sqfs: + python merge_sqfs_chunks.py /scratch/$USER/staging/task_3 \\ + | sqfstar -comp zstd -Xcompression-level 3 -b 131072 -no-duplicates \\ + /scratch/$USER/task_3.sqfs + + # Dry-run — list chunks that would be processed: + python merge_sqfs_chunks.py /scratch/$USER/staging/task_3 --dry-run + + # In a SLURM job (see job_bq_pack_per_task.sh): + python merge_sqfs_chunks.py "${TASK_DIR}" \\ + | sqfstar -comp zstd -Xcompression-level 3 -b 131072 -no-duplicates \\ + "${OUTPUT_SQFS}" + PIPE_STATUS=("${PIPESTATUS[@]}") # capture both exits atomically + MERGE_EXIT="${PIPE_STATUS[0]}" + SQFSTAR_EXIT="${PIPE_STATUS[1]}" +""" + +import atexit +import glob +import os +import shutil +import signal +import subprocess +import sys +import tarfile +import tempfile +import time +import argparse +from pathlib import Path + + +# ── Temp dir cleanup ────────────────────────────────────────────────────────── +# Registered with atexit so normal exit cleans up FUSE mounts. +# Also wired to SIGTERM so SLURM wall-time timeout leaves no stale mounts. +# NOTE: SIGKILL cannot be caught — atexit does not run on SIGKILL. This is +# mitigated by SLURM sending SIGTERM 30 s before SIGKILL, giving the handler +# time to unmount and clean up before the hard kill arrives. 
+ +_mount_base: str | None = None + + +def _cleanup_temp_dirs() -> None: + """Unmount any still-mounted squashfuse dirs and remove the temp base dir.""" + global _mount_base + if _mount_base is None or not os.path.exists(_mount_base): + return + for sub in sorted(os.listdir(_mount_base)): + full = os.path.join(_mount_base, sub) + if os.path.isdir(full): + subprocess.run(["fusermount", "-u", full], capture_output=True) + try: + os.rmdir(full) + except OSError: + pass + try: + shutil.rmtree(_mount_base, ignore_errors=True) + except Exception: + pass + _mount_base = None + + +def _sigterm_handler(signum, frame): + """On SLURM timeout (SIGTERM), clean up mounts and exit 1.""" + print("[merge] SIGTERM received — cleaning up mounts and exiting", + file=sys.stderr) + _cleanup_temp_dirs() + sys.exit(1) + + +atexit.register(_cleanup_temp_dirs) +signal.signal(signal.SIGTERM, _sigterm_handler) + + +# ── squashfuse helpers ──────────────────────────────────────────────────────── + +def squashfuse_mount(sqfs_path: str, mount_dir: str, retries: int = 1) -> bool: + """Mount sqfs_path at mount_dir via squashfuse. + + Retries once on failure to handle transient FUSE errors (e.g. stale + device entries). Returns True on success, False if all attempts fail. + """ + for attempt in range(retries + 1): + result = subprocess.run( + ["squashfuse", sqfs_path, mount_dir], + capture_output=True, text=True, + ) + if result.returncode == 0: + return True + if attempt < retries: + print( + f"[merge] squashfuse failed for {Path(sqfs_path).name} " + f"(attempt {attempt + 1}/{retries + 1}), retrying in 5s...", + file=sys.stderr, + ) + time.sleep(5) + + print( + f"[merge] ERROR: squashfuse failed for {sqfs_path}: " + f"{result.stderr.strip()}", + file=sys.stderr, + ) + return False + + +def squashfuse_unmount(mount_dir: str, retries: int = 3) -> bool: + """Unmount mount_dir via fusermount -u, with retries on transient failures. + + Returns True on success. 
Logs a warning on persistent failure but does not + raise — the SLURM node exit will clean up any remaining mounts. + """ + for attempt in range(retries): + result = subprocess.run( + ["fusermount", "-u", mount_dir], + capture_output=True, text=True, + ) + if result.returncode == 0: + return True + if attempt < retries - 1: + time.sleep(2) + + print( + f"[merge] WARNING: fusermount -u {mount_dir} failed after {retries} " + f"attempts — {result.stderr.strip()}. Mount will be cleaned up on node exit.", + file=sys.stderr, + ) + return False + + +# ── tar streaming ───────────────────────────────────────────────────────────── + +def stream_dir_to_tar(tar: tarfile.TarFile, mount_dir: str) -> int: + """Stream all files from mount_dir into tar, using paths relative to mount_dir. + + Directory entries are included in the tar (required by sqfstar) but are + not counted in the return value. Returns 0 if the mounted sqfs is empty. + """ + count = 0 + mount_path = Path(mount_dir) + for entry in sorted(mount_path.rglob("*")): + arcname = str(entry.relative_to(mount_path)) + if arcname == ".": + continue + info = tar.gettarinfo(str(entry), arcname=arcname) + if entry.is_dir(): + tar.addfile(info) + else: + with open(entry, "rb") as f: + tar.addfile(info, f) + count += 1 + return count + + +# ── main ────────────────────────────────────────────────────────────────────── + +def main(): + global _mount_base + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "staging_dir", + help=( + "Directory containing chunk_*.sqfs files to merge, searched " + "recursively. Typically the staging dir for one download task: " + "bq_download_staging/task_N/. " + "Chunks are processed in sorted order (chunk_0001, chunk_0002, …)." 
+ ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help=( + "List the chunk files that would be merged in processing order, " + "then exit without streaming anything. Useful for verifying which " + "chunks will be included before running the full merge." + ), + ) + args = parser.parse_args() + + chunk_files = sorted( + glob.glob(f"{args.staging_dir}/**/chunk_*.sqfs", recursive=True) + ) + total = len(chunk_files) + + if total == 0: + print( + f"[merge] ERROR: no chunk_*.sqfs files found under {args.staging_dir}", + file=sys.stderr, + ) + sys.exit(1) + + print( + f"[merge] Found {total} chunk sqfs file(s) to merge under {args.staging_dir}", + file=sys.stderr, + ) + + if args.dry_run: + for f in chunk_files: + print(f) + return + + _mount_base = tempfile.mkdtemp(prefix="sqfs_merge_") + total_images = 0 + errors = 0 + empty_chunks = 0 + + try: + with tarfile.open(fileobj=sys.stdout.buffer, mode="w|") as tar: + for i, sqfs_file in enumerate(chunk_files, 1): + mnt = os.path.join(_mount_base, f"mnt_{i}") + os.makedirs(mnt, exist_ok=True) + + if not squashfuse_mount(sqfs_file, mnt): + errors += 1 + try: + os.rmdir(mnt) + except OSError: + pass + continue + + count = stream_dir_to_tar(tar, mnt) + + squashfuse_unmount(mnt) + try: + os.rmdir(mnt) + except OSError: + pass + + if count == 0: + empty_chunks += 1 + print( + f"[merge] WARNING: [{i}/{total}] {Path(sqfs_file).name} — " + f"0 images found after mount. The chunk may be corrupt or " + f"download_images.py may have failed for these images. " + f"Investigate before re-submitting.", + file=sys.stderr, + ) + else: + total_images += count + + print( + f"[merge] [{i}/{total}] {Path(sqfs_file).name} → {count} images", + file=sys.stderr, + ) + + finally: + _cleanup_temp_dirs() + + print( + f"[merge] Done. 
total_images={total_images} " + f"errors={errors} empty_chunks={empty_chunks}", + file=sys.stderr, + ) + + if errors > 0 or empty_chunks > 0: + sys.exit(1) + + +if __name__ == "__main__": + try: + main() + except BrokenPipeError: + _cleanup_temp_dirs() + print( + "[merge] FATAL: BrokenPipeError — the downstream process (sqfstar) " + "died before the merge completed.\n" + " Common causes:\n" + " - sqfstar OOM killed (exit=137): increase --mem in the job script\n" + " - sqfstar not found (exit=127): check 'module load' and PATH\n" + " - sqfstar crashed on corrupt input: check chunk integrity\n" + " Chunk files are preserved — fix the cause and re-submit the pack job.", + file=sys.stderr, + ) + sys.stderr.flush() + # Redirect stdout to /dev/null before sys.exit so Python's shutdown + # does not try to flush the broken pipe and produce exit code 120 + # instead of the expected 1. + try: + with open(os.devnull, "wb") as devnull: + os.dup2(devnull.fileno(), sys.stdout.fileno()) + except Exception: + pass + sys.exit(1) diff --git a/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py b/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py deleted file mode 100644 index a8b9c20..0000000 --- a/src/dataset_tools/bq_squashfs/stream_chunks_to_tar.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -""" -Stream all chunk sqfs files as a single tar archive to stdout, for piping to sqfstar. - -Processes chunks ONE AT A TIME: - squashfuse mount → add to tar stream → unmount → delete chunk → repeat - -This keeps peak scratch usage at ~(remaining chunks + growing output) rather than -(all chunks + full output), which would exceed the scratch quota at 10M-image scale. 
- -Exit codes: - 0 — all chunks streamed successfully - 1 — one or more chunks failed (corrupt, empty, or mount error) — sqfstar - may still have produced a partial output; verify image count before use - -Usage: - python stream_chunks_to_tar.py | sqfstar -comp zstd output.sqfs -""" - -import atexit -import glob -import os -import shutil -import signal -import subprocess -import sys -import tarfile -import tempfile -import time -import argparse -from pathlib import Path - - -# ── Temp dir cleanup ────────────────────────────────────────────────────────── -# Registered once; also called on SIGTERM so SLURM job timeout leaves no dirs - -_mount_base: str | None = None - - -def _cleanup_temp_dirs() -> None: - """Unmount any still-mounted dirs and remove the temp base dir.""" - global _mount_base - if _mount_base is None or not os.path.exists(_mount_base): - return - for sub in sorted(os.listdir(_mount_base)): - full = os.path.join(_mount_base, sub) - if os.path.isdir(full): - subprocess.run(["fusermount", "-u", full], capture_output=True) - try: - os.rmdir(full) - except OSError: - pass - try: - shutil.rmtree(_mount_base, ignore_errors=True) - except Exception: - pass - _mount_base = None - - -def _sigterm_handler(signum, frame): - """On SLURM timeout (SIGTERM), clean up and exit 1.""" - print("[stream] SIGTERM received — cleaning up temp dirs and exiting", - file=sys.stderr) - _cleanup_temp_dirs() - sys.exit(1) - - -atexit.register(_cleanup_temp_dirs) -signal.signal(signal.SIGTERM, _sigterm_handler) - - -# ── squashfuse helpers ──────────────────────────────────────────────────────── - -def squashfuse_mount(sqfs_path: str, mount_dir: str, retries: int = 1) -> bool: - """Mount sqfs_path at mount_dir via squashfuse. - - Retries once on failure to handle transient FUSE errors. - Returns True on success, False if all attempts fail. 
- """ - for attempt in range(retries + 1): - result = subprocess.run( - ["squashfuse", sqfs_path, mount_dir], - capture_output=True, text=True, - ) - if result.returncode == 0: - return True - if attempt < retries: - print( - f"[stream] squashfuse failed for {Path(sqfs_path).name} " - f"(attempt {attempt + 1}/{retries + 1}), retrying in 5s...", - file=sys.stderr, - ) - time.sleep(5) - - print( - f"[stream] ERROR: squashfuse failed for {sqfs_path}: " - f"{result.stderr.strip()}", - file=sys.stderr, - ) - return False - - -def squashfuse_unmount(mount_dir: str, retries: int = 3) -> bool: - """Unmount mount_dir via fusermount, with retries. - - Returns True on success. Logs a warning on failure but does not raise — - the SLURM job will clean up the mount on node exit. - """ - for attempt in range(retries): - result = subprocess.run( - ["fusermount", "-u", mount_dir], - capture_output=True, text=True, - ) - if result.returncode == 0: - return True - if attempt < retries - 1: - time.sleep(2) - - print( - f"[stream] WARNING: fusermount -u {mount_dir} failed after {retries} attempts " - f"— {result.stderr.strip()}. Mount will be cleaned up on node exit.", - file=sys.stderr, - ) - return False - - -# ── tar streaming ───────────────────────────────────────────────────────────── - -def stream_dir_to_tar(tar: tarfile.TarFile, mount_dir: str) -> int: - """Add all files from mount_dir into tar with paths relative to mount_dir. - - Returns number of files added. Directories are included in the tar but - not counted. Returns 0 for empty mounts. 
- """ - count = 0 - mount_path = Path(mount_dir) - for entry in sorted(mount_path.rglob("*")): - arcname = str(entry.relative_to(mount_path)) - if arcname == ".": - continue - info = tar.gettarinfo(str(entry), arcname=arcname) - if entry.is_dir(): - tar.addfile(info) - else: - with open(entry, "rb") as f: - tar.addfile(info, f) - count += 1 - return count - - -# ── main ────────────────────────────────────────────────────────────────────── - -def main(): - global _mount_base - - parser = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "staging_base", - help=( - "Directory containing chunk_*.sqfs files (searched recursively). " - "Pass a task staging dir (e.g. bq_download_staging/task_0/) for " - "per-task merge, or the base dir for a global merge." - ), - ) - parser.add_argument( - "--delete-after-stream", action="store_true", - help=( - "Delete each chunk sqfs after it has been streamed. " - "Saves scratch space but means a restart requires redownloading. " - "Do NOT use this unless you are confident sqfstar has enough RAM." - ), - ) - parser.add_argument( - "--dry-run", action="store_true", - help="List chunks that would be processed in sorted order, then exit.", - ) - args = parser.parse_args() - - chunk_files = sorted( - glob.glob(f"{args.staging_base}/**/chunk_*.sqfs", recursive=True) - ) - total = len(chunk_files) - - if total == 0: - print( - f"[stream] ERROR: no chunk_*.sqfs files found under {args.staging_base}", - file=sys.stderr, - ) - sys.exit(1) - - print(f"[stream] Found {total} chunk sqfs files to stream", file=sys.stderr) - - if args.dry_run: - for f in chunk_files: - print(f) - return - - if args.delete_after_stream: - print( - "[stream] WARNING: --delete-after-stream is active. Chunks will be " - "deleted as they stream. 
Ensure sqfstar has sufficient RAM before " - "proceeding — an OOM kill will cause data loss.", - file=sys.stderr, - ) - - _mount_base = tempfile.mkdtemp(prefix="sqfs_stream_") - total_images = 0 - errors = 0 - empty_chunks = 0 - - try: - with tarfile.open(fileobj=sys.stdout.buffer, mode="w|") as tar: - for i, sqfs_file in enumerate(chunk_files, 1): - mnt = os.path.join(_mount_base, f"mnt_{i}") - os.makedirs(mnt, exist_ok=True) - - if not squashfuse_mount(sqfs_file, mnt): - errors += 1 - try: - os.rmdir(mnt) - except OSError: - pass - continue - - count = stream_dir_to_tar(tar, mnt) - - squashfuse_unmount(mnt) - try: - os.rmdir(mnt) - except OSError: - pass - - if count == 0: - empty_chunks += 1 - print( - f"[stream] WARNING: [{i}/{total}] {sqfs_file} — " - f"0 images found after mount. Chunk may be corrupt or " - f"download stage failed for these images.", - file=sys.stderr, - ) - else: - total_images += count - - deleted = "" - if args.delete_after_stream: - os.unlink(sqfs_file) - deleted = " (deleted)" - - print( - f"[stream] [{i}/{total}] {Path(sqfs_file).name} " - f"→ {count} images{deleted}", - file=sys.stderr, - ) - - finally: - _cleanup_temp_dirs() - - print( - f"[stream] Done. 
total_images={total_images} " - f"errors={errors} empty_chunks={empty_chunks}", - file=sys.stderr, - ) - - if errors > 0 or empty_chunks > 0: - sys.exit(1) - - -if __name__ == "__main__": - try: - main() - except BrokenPipeError: - _cleanup_temp_dirs() - print( - "[stream] FATAL: BrokenPipeError — the downstream process (sqfstar) " - "died unexpectedly.\n" - " Common causes:\n" - " - sqfstar OOM killed (exit=137): increase --mem in job script\n" - " - sqfstar not found (exit=127): check module load and PATH\n" - " - sqfstar crashed on corrupt input: check chunk integrity\n" - " Chunk files are preserved (unless --delete-after-stream was used).", - file=sys.stderr, - ) - sys.exit(1) From 1e28438eb2c4709a5d9fcac8cc3f6db823c450ab Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 3 Jun 2026 23:43:41 -0700 Subject: [PATCH 14/26] test(pack): rename and update tests for merge_sqfs_chunks.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Renamed test_stream_chunks_to_tar.py → test_merge_sqfs_chunks.py. Updated references: - import merge_sqfs_chunks as sct (was stream_chunks_to_tar) - sys.argv uses merge_sqfs_chunks.py filename - log assertions check [merge] prefix (was [stream]) Test changes for --delete-after-stream removal: - test_delete_after_stream_removes_chunk → test_chunks_always_preserved_after_stream Verifies chunks remain on disk after streaming (deletion is job script's job) - test_delete_after_stream_warning_printed → test_delete_after_stream_flag_removed Verifies argparse rejects the removed flag with exit=2 22 tests, all passing. 
Co-Authored-By: Claude Sonnet 4.6 --- ...ks_to_tar.py => test_merge_sqfs_chunks.py} | 82 ++++++------------- 1 file changed, 26 insertions(+), 56 deletions(-) rename tests/dataset_tools/{test_stream_chunks_to_tar.py => test_merge_sqfs_chunks.py} (83%) diff --git a/tests/dataset_tools/test_stream_chunks_to_tar.py b/tests/dataset_tools/test_merge_sqfs_chunks.py similarity index 83% rename from tests/dataset_tools/test_stream_chunks_to_tar.py rename to tests/dataset_tools/test_merge_sqfs_chunks.py index 16b2307..d5d766f 100644 --- a/tests/dataset_tools/test_stream_chunks_to_tar.py +++ b/tests/dataset_tools/test_merge_sqfs_chunks.py @@ -1,10 +1,10 @@ """ -Tests for stream_chunks_to_tar.py. +Tests for merge_sqfs_chunks.py. The script streams chunk sqfs files as a single tar to stdout for piping to sqfstar. All squashfuse calls are mocked — no real sqfs or FUSE needed. -Run with: pytest tests/dataset_tools/test_stream_chunks_to_tar.py -v +Run with: pytest tests/dataset_tools/test_merge_sqfs_chunks.py -v """ import io @@ -17,7 +17,7 @@ import pytest -import src.dataset_tools.bq_squashfs.stream_chunks_to_tar as sct +import src.dataset_tools.bq_squashfs.merge_sqfs_chunks as sct # ── helpers ─────────────────────────────────────────────────────────────────── @@ -133,7 +133,7 @@ class TestNoChunks: def test_empty_staging_dir_exits_with_error(self, tmp_path, capsys): """No chunk_*.sqfs files → exits with code 1.""" with pytest.raises(SystemExit) as exc: - with patch("sys.argv", ["stream_chunks_to_tar.py", str(tmp_path)]): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(tmp_path)]): sct.main() assert exc.value.code == 1 assert "ERROR" in capsys.readouterr().err @@ -142,7 +142,7 @@ def test_missing_staging_dir_exits_with_error(self, tmp_path, capsys): """Non-existent staging dir → exits with code 1.""" missing = tmp_path / "does_not_exist" with pytest.raises(SystemExit) as exc: - with patch("sys.argv", ["stream_chunks_to_tar.py", str(missing)]): + with patch("sys.argv", 
["merge_sqfs_chunks.py", str(missing)]): sct.main() assert exc.value.code == 1 @@ -158,7 +158,7 @@ def test_dry_run_lists_chunks_no_streaming(self, tmp_path, capsys): c1 = make_chunk_sqfs(staging, 1) c2 = make_chunk_sqfs(staging, 2) - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), "--dry-run"]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging), "--dry-run"]), \ patch.object(sct, "squashfuse_mount") as mock_mount: sct.main() @@ -175,7 +175,7 @@ def test_dry_run_lists_in_sorted_order(self, tmp_path, capsys): make_chunk_sqfs(staging, 1) make_chunk_sqfs(staging, 2) - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), "--dry-run"]): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging), "--dry-run"]): sct.main() lines = [l for l in capsys.readouterr().out.strip().splitlines() if l] @@ -190,7 +190,7 @@ class TestStreaming: def _run_stream(self, staging: Path, extra_args: list[str] = []) -> tuple[bytes, str]: """Run main(), capture stdout bytes and stderr text.""" stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)] + extra_args), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)] + extra_args), \ patch("sys.stdout") as mock_stdout: mock_stdout.buffer = stdout_buf sct.main() @@ -209,7 +209,7 @@ def test_single_chunk_produces_valid_tar(self, tmp_path): (fake_mnt / "000" / "img.jpg").write_bytes(b"JPEG") stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ patch.object(sct, "squashfuse_mount", return_value=True), \ patch.object(sct, "squashfuse_unmount"), \ @@ -236,7 +236,7 @@ def fake_stream(tar, mnt_dir): return 5 stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ 
patch.object(sct, "squashfuse_mount", return_value=True), \ patch.object(sct, "squashfuse_unmount"), \ @@ -250,15 +250,15 @@ def fake_stream(tar, mnt_dir): # stream_dir_to_tar called twice (one per chunk) into the SAME tar assert call_count["n"] == 2 - def test_delete_after_stream_removes_chunk(self, tmp_path): - """--delete-after-stream: each chunk file is deleted after streaming.""" + def test_chunks_always_preserved_after_stream(self, tmp_path): + """Chunks are never deleted by stream_chunks_to_tar — deletion is the + job script's responsibility after verification passes.""" staging = tmp_path / "staging" staging.mkdir() chunk = make_chunk_sqfs(staging, 1) - assert chunk.exists() stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), "--delete-after-stream"]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ patch.object(sct, "squashfuse_mount", return_value=True), \ patch.object(sct, "squashfuse_unmount"), \ @@ -269,27 +269,7 @@ def test_delete_after_stream_removes_chunk(self, tmp_path): mock_stdout.buffer = stdout_buf sct.main() - assert not chunk.exists() # deleted after streaming - - def test_without_delete_flag_chunks_preserved(self, tmp_path): - """Without --delete-after-stream, chunk files remain on disk.""" - staging = tmp_path / "staging" - staging.mkdir() - chunk = make_chunk_sqfs(staging, 1) - - stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=True), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", return_value=1): - mock_stdout.buffer = stdout_buf - sct.main() - - assert chunk.exists() # preserved + assert chunk.exists() # always preserved — job script deletes 
after verify # ── main: error handling ────────────────────────────────────────────────────── @@ -307,7 +287,7 @@ def test_failed_mount_skipped_continues_to_next_chunk(self, tmp_path, capsys): mount_results = [False, True] stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ patch.object(sct, "squashfuse_mount", side_effect=mount_results), \ patch.object(sct, "squashfuse_unmount"), \ @@ -332,7 +312,7 @@ def test_all_mounts_fail_exits_nonzero(self, tmp_path): make_chunk_sqfs(staging, 2) stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ patch.object(sct, "squashfuse_mount", return_value=False), \ patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ @@ -351,7 +331,7 @@ def test_empty_chunk_exits_nonzero(self, tmp_path): make_chunk_sqfs(staging, 1) stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ patch.object(sct, "squashfuse_mount", return_value=True), \ patch.object(sct, "squashfuse_unmount"), \ @@ -400,27 +380,17 @@ def test_squashfuse_unmount_warns_on_all_failures(self, capsys): assert result is False assert "WARNING" in capsys.readouterr().err - def test_delete_after_stream_warning_printed(self, tmp_path, capsys): - """--delete-after-stream prints a data-loss warning before starting.""" + def test_delete_after_stream_flag_removed(self, tmp_path): + """--delete-after-stream was removed — passing it should raise an error.""" staging = tmp_path / "staging" staging.mkdir() make_chunk_sqfs(staging, 1) - stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging), - 
"--delete-after-stream"]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=True), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch("os.unlink"), \ - patch.object(sct, "stream_dir_to_tar", return_value=5): - mock_stdout.buffer = stdout_buf - sct.main() - - assert "WARNING" in capsys.readouterr().err + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging), + "--delete-after-stream"]): + with pytest.raises(SystemExit) as exc: + sct.main() + assert exc.value.code == 2 # argparse unrecognised argument def test_chunks_processed_in_sorted_order(self, tmp_path): """Chunks are processed in sorted order: chunk_0001 before chunk_0002.""" @@ -437,7 +407,7 @@ def fake_mount(sqfs_path, mnt_dir): return True stdout_buf = io.BytesIO() - with patch("sys.argv", ["stream_chunks_to_tar.py", str(staging)]), \ + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ patch("sys.stdout") as mock_stdout, \ patch.object(sct, "squashfuse_mount", side_effect=fake_mount), \ patch.object(sct, "squashfuse_unmount"), \ From 0126297c5308205f4569d889b3cc98aa25447969 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 4 Jun 2026 00:53:47 -0700 Subject: [PATCH 15/26] feat(download): merge failed status into training_images MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously failed rows (404, 403, exhausted retries) were written to training_images_downloads as 'failed' but training_images stayed 'pending' permanently. 
This meant: - 'pending' was ambiguous: "not tried" vs "tried and failed" - retry_failed_downloads.py would waste time retrying 404s Now all three outcomes are merged into training_images: downloaded → fetch_status='downloaded', dims + corrupted populated corrupted → fetch_status='corrupted', corrupted=True, dims NULL failed → fetch_status='failed', all extra fields NULL Permanent failures are now excluded from future re-runs via WHERE fetch_status='pending' without needing the LEFT JOIN check. Retrying is still possible intentionally via retry_failed_downloads.py. 7 new/updated tests in TestMergeChunkIntoTrainingImages: - downloaded, corrupted, failed each trigger merge independently - all three statuses merged in one temp table + MERGE call - failed rows confirmed present in temp table dataframe - temp table cleanup on MERGE failure still works Verified end-to-end against real BQ: - 5 scenarios: clean, corrupt, 404, 403, exhausted retries - all 5 statuses correct in training_images after merge - 0 rows re-queued on re-run (WHERE fetch_status='pending') - 400-image scale test: all downloaded, merged, verified ✓ Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/download_images.py | 20 +++- tests/dataset_tools/test_download_images.py | 106 ++++++++++++++---- 2 files changed, 98 insertions(+), 28 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index caedb80..a87f420 100644 --- a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -223,19 +223,29 @@ def merge_chunk_into_training_images( training_table: str, downloads_table: str, ) -> int: - """MERGE this chunk's successful results directly into training_images. + """MERGE all chunk results into training_images, updating fetch_status for every outcome. 
+ + Outcomes merged: + downloaded — fetch_status='downloaded', dims and corrupted populated + corrupted — fetch_status='corrupted', corrupted=True, dims NULL + failed — fetch_status='failed', all fields NULL + + Permanently failed images (404, 403, exhausted retries) are marked + fetch_status='failed' so they are excluded from future re-runs via + the WHERE fetch_status='pending' clause — no wasted retry attempts. + Retrying can still be done intentionally via retry_failed_downloads.py. Uses a temp table containing only this chunk's rows so the MERGE scans a small dataset rather than all of training_images_downloads. - Only updates rows that are still 'pending' — safe to run from parallel tasks. + Only updates rows that are still 'pending' — safe from parallel tasks. Returns the number of rows updated. """ - successful = [r for r in results if r["fetch_status"] in ("downloaded", "corrupted")] - if not successful: + to_merge = [r for r in results if r["fetch_status"] in ("downloaded", "corrupted", "failed")] + if not to_merge: return 0 tmp_table = f"{BQ_PROJECT}.{BQ_DATASET}._dl_merge_tmp_{uuid.uuid4().hex[:8]}" - df = pd.DataFrame(successful) + df = pd.DataFrame(to_merge) job_config = bigquery.LoadJobConfig( write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE, schema=DOWNLOADS_SCHEMA, diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py index fa67793..bcce8a5 100644 --- a/tests/dataset_tools/test_download_images.py +++ b/tests/dataset_tools/test_download_images.py @@ -261,39 +261,99 @@ def test_empty_staging_returns_none(self, tmp_path): class TestMergeChunkIntoTrainingImages: - def test_empty_results_skips_merge(self): - """No successful results → no BQ calls.""" - client = MagicMock() - results = [{"dataset_source_uuid": "u1", "fetch_status": "failed", - "image_width": None, "image_height": None, - "image_size": None, "corrupted": None}] - n = di.merge_chunk_into_training_images( - client, results, 
"training_table", "downloads_table" - ) - assert n == 0 - client.load_table_from_dataframe.assert_not_called() - - def test_successful_results_trigger_merge(self): - """Downloaded rows → temp table load + MERGE + temp table delete.""" + def _make_client(self, updated_count: int = 1) -> MagicMock: client = MagicMock() client.load_table_from_dataframe.return_value.result.return_value = None job = MagicMock() - job.dml_stats.updated_row_count = 2 + job.dml_stats.updated_row_count = updated_count client.query.return_value = job + return client + + def test_empty_list_skips_merge(self): + """Completely empty results list → no BQ calls.""" + client = self._make_client() + n = di.merge_chunk_into_training_images(client, [], "t", "d") + assert n == 0 + client.load_table_from_dataframe.assert_not_called() + + def test_downloaded_triggers_merge(self): + """downloaded rows → temp table load + MERGE + cleanup.""" + client = self._make_client(updated_count=1) + results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 100, "image_height": 80, + "image_size": 5000, "corrupted": False}] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 1 + assert client.load_table_from_dataframe.call_count == 1 + assert client.query.call_count == 1 + assert client.delete_table.call_count == 1 + def test_corrupted_triggers_merge(self): + """corrupted rows → merged with fetch_status='corrupted'.""" + client = self._make_client(updated_count=1) + results = [{"dataset_source_uuid": "u1", "fetch_status": "corrupted", + "image_width": None, "image_height": None, + "image_size": None, "corrupted": True}] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 1 + assert client.load_table_from_dataframe.call_count == 1 + + def test_failed_triggers_merge(self): + """failed rows (404/403/exhausted) → merged so fetch_status='failed' in + training_images. 
Permanent failures are excluded from future re-runs + via WHERE fetch_status='pending' without needing the LEFT JOIN.""" + client = self._make_client(updated_count=1) + results = [{"dataset_source_uuid": "u1", "fetch_status": "failed", + "image_width": None, "image_height": None, + "image_size": None, "corrupted": None}] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 1 + assert client.load_table_from_dataframe.call_count == 1 + assert client.query.call_count == 1 + assert client.delete_table.call_count == 1 + + def test_all_three_statuses_merged_together(self): + """Mixed chunk — downloaded, corrupted, failed — all three trigger one MERGE.""" + client = self._make_client(updated_count=3) results = [ {"dataset_source_uuid": "u1", "fetch_status": "downloaded", "image_width": 100, "image_height": 80, "image_size": 5000, "corrupted": False}, {"dataset_source_uuid": "u2", "fetch_status": "corrupted", "image_width": None, "image_height": None, "image_size": 500, "corrupted": True}, + {"dataset_source_uuid": "u3", "fetch_status": "failed", + "image_width": None, "image_height": None, "image_size": None, "corrupted": None}, ] - n = di.merge_chunk_into_training_images( - client, results, "training_table", "downloads_table" - ) - assert n == 2 - assert client.load_table_from_dataframe.call_count == 1 # temp table load - assert client.query.call_count == 1 # MERGE - assert client.delete_table.call_count == 1 # cleanup + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 3 + assert client.load_table_from_dataframe.call_count == 1 # one temp table for all 3 + assert client.query.call_count == 1 # one MERGE + assert client.delete_table.call_count == 1 # one cleanup + + def test_failed_rows_included_in_temp_table(self): + """Verify the dataframe passed to BQ includes the failed row.""" + import pandas as pd + client = self._make_client() + captured_df = {} + + def capture_load(df, table, **kwargs): + 
captured_df["data"] = df.copy() + return MagicMock(result=MagicMock(return_value=None)) + + client.load_table_from_dataframe.side_effect = capture_load + + results = [ + {"dataset_source_uuid": "ok", "fetch_status": "downloaded", + "image_width": 64, "image_height": 48, "image_size": 1000, "corrupted": False}, + {"dataset_source_uuid": "dead", "fetch_status": "failed", + "image_width": None, "image_height": None, "image_size": None, "corrupted": None}, + ] + di.merge_chunk_into_training_images(client, results, "t", "d") + + df = captured_df["data"] + assert len(df) == 2 # both rows in temp table + statuses = set(df["fetch_status"].tolist()) + assert "downloaded" in statuses + assert "failed" in statuses # failed row present def test_temp_table_deleted_even_on_merge_failure(self): """Temp table must be cleaned up even if the MERGE query fails.""" @@ -309,7 +369,7 @@ def test_temp_table_deleted_even_on_merge_failure(self): di.merge_chunk_into_training_images( client, results, "training_table", "downloads_table" ) - client.delete_table.assert_called_once() # cleanup still ran + client.delete_table.assert_called_once() # ── get_pending_rows / MOD split ───────────────────────────────────────────── From d51fef716ddfdf8a6842b8964f988cf9dacb7c98 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 4 Jun 2026 00:54:01 -0700 Subject: [PATCH 16/26] feat(pack): add job_bq_pack_per_task.sh with verify-then-delete MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three-stage merge job replacing the old version from backup branch: Stage 1 — Count expected images across all chunks (unsquashfs -l) Stage 2 — Stream chunks → sqfstar → task_N.sqfs Stage 3 — Verify: ACTUAL == EXPECTED before deleting anything Chunks only deleted when all three conditions pass: stream_exit=0, sqfstar_exit=0, image count matches Safety design: - Chunks are NEVER deleted on failure — always preserved for retry - PIPESTATUS captured atomically: 
PIPE_STATUS=("${PIPESTATUS[@]}") (assigning PIPESTATUS[0] resets PIPESTATUS — common bash trap) - References merge_sqfs_chunks.py (renamed from stream_chunks_to_tar) - Updated --array=0-9 (was 2-9 from old incident where tasks 0+1 had staging files deleted by a failed global pack run) - Uses merge_sqfs_chunks.py which never deletes chunks itself Verified with 1000-image simulation: - Clean merge: verify passes, chunks deleted ✓ - OOM: chunks preserved, re-submit succeeds ✓ - Corrupt chunk: stream exits 1, chunks preserved ✓ - Empty chunk: stream exits 1, chunks preserved ✓ - sqfstar missing: BrokenPipeError caught, chunks preserved ✓ Co-Authored-By: Claude Sonnet 4.6 --- scripts/job_bq_pack_per_task.sh | 149 ++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 scripts/job_bq_pack_per_task.sh diff --git a/scripts/job_bq_pack_per_task.sh b/scripts/job_bq_pack_per_task.sh new file mode 100644 index 0000000..297b144 --- /dev/null +++ b/scripts/job_bq_pack_per_task.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# Merge per-chunk SquashFS files for ONE task into a single task-level SquashFS. +# +# Runs as a SLURM array job — each array task handles one download task_N. +# After all tasks complete, task_0.sqfs … task_9.sqfs are ready for use +# by the webdataset build job. +# +# Safety design: +# 1. Stream all chunks → sqfstar → task_N.sqfs (no deletion during stream) +# 2. Verify: count images in output sqfs == sum of images across all chunks +# 3. Only delete chunks if and only if verification passes +# +# This guarantees zero data loss: if sqfstar OOMs or the job times out, all +# chunk files are preserved and the job can be resubmitted with more --mem +# or --time without redownloading anything. 
+# +# Usage: +# sbatch --array=0-9 job_bq_pack_per_task.sh +# # or chain after download: +# DOWNLOAD_JOB=$(sbatch --parsable --array=0-9 job_bq_download.sh) +# sbatch --array=0-9 --dependency=afterok:$DOWNLOAD_JOB job_bq_pack_per_task.sh +# +#SBATCH --account=def-drolnick +#SBATCH --job-name=bq_pack_task +#SBATCH --cpus-per-task=16 +#SBATCH --mem=192G +#SBATCH --time=6:00:00 +#SBATCH --array=0-9 +#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_pack_task_%A_%a.out +#SBATCH --mail-type=BEGIN,END,FAIL +#SBATCH --mail-user=hack1996man@gmail.com + +set -euo pipefail + +TASK_ID=${SLURM_ARRAY_TASK_ID} +TASK_DIR="/scratch/melabbas/bq_download_staging/task_${TASK_ID}" +OUTPUT_SQFS="/scratch/melabbas/task_${TASK_ID}.sqfs" + +echo "=== bq_pack_task ${TASK_ID} started at $(date) ===" +echo "Node : $(hostname)" +echo "Task dir : ${TASK_DIR}" +echo "Output : ${OUTPUT_SQFS}" +echo "" + +# ── Pre-flight ──────────────────────────────────────────────────────────────── + +TOTAL_CHUNKS=$(find "${TASK_DIR}" -name "chunk_*.sqfs" 2>/dev/null | wc -l) +if [ "${TOTAL_CHUNKS}" -eq 0 ]; then + echo "No chunks found for task ${TASK_ID} — nothing to merge." + notify "bq_pack_task ${TASK_ID}: skipped" "No chunks found in ${TASK_DIR}" + exit 0 +fi +echo "Chunks to merge: ${TOTAL_CHUNKS}" +echo "" + +# Count expected images across all chunks (metadata read only, no extraction) +echo "Counting expected images across all chunks..." 
+EXPECTED_IMAGES=0 +for chunk in $(find "${TASK_DIR}" -name "chunk_*.sqfs" | sort); do + COUNT=$(unsquashfs -l "${chunk}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || true) # grep -c already prints 0 on no match; "|| echo 0" would append a second 0 + echo " $(basename ${chunk}): ${COUNT} images" + EXPECTED_IMAGES=$((EXPECTED_IMAGES + COUNT)) +done +echo "Expected total: ${EXPECTED_IMAGES} images" +echo "" + +# ── Merge ───────────────────────────────────────────────────────────────────── + +rm -f "${OUTPUT_SQFS}" + +cd /project/6068129/melabbas/ami-ml +module load StdEnv/2023 arrow/17.0.0 +source .venv/bin/activate + +echo "=== Streaming chunks → sqfstar at $(date) ===" +set +o errexit # suspend errexit: with pipefail a failed pipeline would otherwise exit before PIPESTATUS is captured +python src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py \ + "${TASK_DIR}" \ + | sqfstar \ + -comp zstd \ + -Xcompression-level 3 \ + -b 131072 \ + -no-duplicates \ + "${OUTPUT_SQFS}" + +# Capture atomically — assigning PIPESTATUS[0] to a variable resets PIPESTATUS, +# so both values must be saved in a single array assignment first. +PIPE_STATUS=("${PIPESTATUS[@]}") +STREAM_EXIT="${PIPE_STATUS[0]}" +SQFSTAR_EXIT="${PIPE_STATUS[1]}" +set -o errexit # restore errexit now that both exit codes are saved +echo "" +echo "=== Merge finished at $(date) ===" +echo "merge_sqfs_chunks exit : ${STREAM_EXIT}" +echo "sqfstar exit : ${SQFSTAR_EXIT}" +echo "" + +# ── Verify ──────────────────────────────────────────────────────────────────── + +if [ "${STREAM_EXIT}" -ne 0 ] || [ "${SQFSTAR_EXIT}" -ne 0 ]; then + echo "ERROR: merge failed (stream=${STREAM_EXIT} sqfstar=${SQFSTAR_EXIT})" + echo "Chunks preserved in ${TASK_DIR} — re-submit with more --mem or investigate errors." + notify "bq_pack_task ${TASK_ID}: FAILED" \ + "stream=${STREAM_EXIT} sqfstar=${SQFSTAR_EXIT} — chunks preserved, re-submit" + exit 1 +fi + +if [ ! -f "${OUTPUT_SQFS}" ]; then + echo "ERROR: output sqfs not found at ${OUTPUT_SQFS}" + notify "bq_pack_task ${TASK_ID}: FAILED" "output sqfs missing — chunks preserved" + exit 1 +fi + +echo "Verifying output sqfs image count..."
+ACTUAL_IMAGES=$(unsquashfs -l "${OUTPUT_SQFS}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || true) # grep -c already prints 0; "|| echo 0" made this "0\n0", which broke the -ne test below and skipped the mismatch branch +SIZE=$(du -sh "${OUTPUT_SQFS}" | cut -f1) + +echo " Expected : ${EXPECTED_IMAGES} images" +echo " Actual : ${ACTUAL_IMAGES} images" +echo " Size : ${SIZE}" +echo "" + +if [ "${ACTUAL_IMAGES}" -ne "${EXPECTED_IMAGES}" ]; then + echo "ERROR: image count mismatch (expected=${EXPECTED_IMAGES} actual=${ACTUAL_IMAGES})" + echo "Output sqfs may be incomplete. Chunks preserved in ${TASK_DIR}." + echo "Investigate: run audit_sqfs.py or check merge_sqfs_chunks.py logs above." + notify "bq_pack_task ${TASK_ID}: FAILED (count mismatch)" \ + "expected=${EXPECTED_IMAGES} actual=${ACTUAL_IMAGES} — chunks preserved in ${TASK_DIR}" + exit 1 +fi + +echo "Verification passed: ${ACTUAL_IMAGES} images confirmed in ${OUTPUT_SQFS}" +echo "" + +# ── Safe delete ─────────────────────────────────────────────────────────────── +# Only reached when: stream_exit=0, sqfstar_exit=0, image count matches. + +echo "Deleting chunk files (verified, safe to remove)..." +DELETED=0 +for chunk in $(find "${TASK_DIR}" -name "chunk_*.sqfs" | sort); do + rm -f "${chunk}" + DELETED=$((DELETED + 1)) +done +echo "Deleted ${DELETED} chunk files from ${TASK_DIR}" +echo "" + +echo "=== bq_pack_task ${TASK_ID} done at $(date) ===" +notify "bq_pack_task ${TASK_ID}: done" \ + "${ACTUAL_IMAGES} images in ${OUTPUT_SQFS} (${SIZE}) — ${DELETED} chunks deleted" From 506bb3c7fc47604d35027fa64b5d760cb4d59c52 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 4 Jun 2026 12:07:02 -0700 Subject: [PATCH 17/26] Add BQ training_images dedup clean script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleans training_images table in-place by removing 3 duplicate types: 1. Exact duplicate rows (same photo+taxon+gbif) 2. Same photo mapped to multiple taxa — drop strategy by default 3.
Same photo+taxon with multiple gbif_ids — keep MIN gbif_id Supports --dry-run, --min-images-per-taxon, --multi-taxon-strategy, and --log-file for JSON output. Operates on any dataset via --dataset. Co-Authored-By: Claude Sonnet 4.6 --- src/dataset_tools/bigquery_pipeline/clean.py | 239 ++++++++++++++++++ .../bigquery_pipeline/create_test_table.py | 100 ++++++++ 2 files changed, 339 insertions(+) create mode 100644 src/dataset_tools/bigquery_pipeline/clean.py create mode 100644 src/dataset_tools/bigquery_pipeline/create_test_table.py diff --git a/src/dataset_tools/bigquery_pipeline/clean.py b/src/dataset_tools/bigquery_pipeline/clean.py new file mode 100644 index 0000000..bf78166 --- /dev/null +++ b/src/dataset_tools/bigquery_pipeline/clean.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +""" +Clean a training_images BQ table by removing duplicates and sparse taxa. + +Duplicate types handled in order: + 1. Exact rows — same (photo_id, inat_taxon_id, gbif_id), keep one row + 2. Same-taxon multi-gbif — same (photo_id, inat_taxon_id), multiple gbif_ids → keep MIN gbif_id + 3. Multi-taxon conflict — same photo_id maps to multiple taxa → drop all or keep lowest taxon_id + 4. Min-images filter — drop taxa with fewer than N images (default: off) + +Overwrites the source table in-place unless --dry-run is passed. 
+""" +import argparse +import json +import sys +from datetime import datetime, timezone + +from google.cloud import bigquery + + +def _step3_cte(strategy): + if strategy == "drop": + return "SELECT * FROM step2 WHERE photo_id NOT IN (SELECT photo_id FROM multi_taxon_photos)" + else: # keep-lowest-taxon-id + return ( + "SELECT * FROM step2\n" + " QUALIFY ROW_NUMBER() OVER (PARTITION BY photo_id ORDER BY inat_taxon_id, gbif_id) = 1" + ) + + +def _cte_chain(src_ref, strategy, min_images): + min_filter = f"t.cnt >= {min_images}" if min_images > 0 else "TRUE" + return f""" +src AS ( + SELECT * FROM `{src_ref}` +), +step1 AS ( + -- Remove exact duplicate rows: same photo+taxon+gbif, keep one + SELECT * FROM src + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY photo_id, inat_taxon_id, gbif_id + ORDER BY dataset_source_uuid + ) = 1 +), +step2 AS ( + -- Same photo+taxon with multiple gbif_ids: keep the MIN gbif_id + SELECT * FROM step1 + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY photo_id, inat_taxon_id + ORDER BY gbif_id + ) = 1 +), +multi_taxon_photos AS ( + -- Photo IDs that still map to more than one taxon after step2 + SELECT photo_id FROM step2 + GROUP BY photo_id + HAVING COUNT(DISTINCT inat_taxon_id) > 1 +), +step3 AS ( + -- Resolve multi-taxon conflicts + {_step3_cte(strategy)} +), +taxa_img_count AS ( + SELECT inat_taxon_id, COUNT(*) AS cnt FROM step3 GROUP BY inat_taxon_id +), +step4 AS ( + -- Drop taxa below min-images threshold + SELECT s.* FROM step3 s + JOIN taxa_img_count t USING (inat_taxon_id) + WHERE {min_filter} +)""" + + +def run_count_query(client, src_ref, strategy, min_images): + """Return per-stage row/taxa counts in a single BQ query.""" + chain = _cte_chain(src_ref, strategy, min_images) + query = f""" +WITH {chain} +SELECT stage, n, taxa FROM ( + SELECT 'input' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM src UNION ALL + SELECT 'step1' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step1 UNION ALL + SELECT 
'step2' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step2 UNION ALL + SELECT 'step3' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step3 UNION ALL + SELECT 'step4' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step4 UNION ALL + SELECT 'conflicts' AS stage, COUNT(*) AS n, 0 AS taxa FROM multi_taxon_photos +) +ORDER BY stage +""" + results = {row.stage: {"rows": row.n, "taxa": row.taxa} + for row in client.query(query).result()} + return results + + +def run_write_query(client, src_ref, dst_ref, strategy, min_images): + chain = _cte_chain(src_ref, strategy, min_images) + query = f"CREATE OR REPLACE TABLE `{dst_ref}` AS\nWITH {chain}\nSELECT * FROM step4" + client.query(query).result() + + +def print_report(counts, strategy, min_images, dst_ref, dry_run): + c = counts + step_labels = [ + ("input", "Input rows"), + ("step1", "Exact duplicates removed (same photo+taxon+gbif, keep one)"), + ("step2", "Same-taxon multi-gbif resolved (keep MIN gbif_id)"), + ("step3", f"Multi-taxon conflicts handled (strategy={strategy})"), + ("step4", f"Min-images-per-taxon filter (threshold={min_images or 'off'})"), + ] + + print() + print("=" * 70) + prev = None + for key, label in step_labels: + rows = c[key]["rows"] + taxa = c[key]["taxa"] + if prev is None: + print(f" {label}") + print(f" rows={rows:>10,} taxa={taxa:>8,}") + else: + removed = prev - rows + pct = removed / prev * 100 if prev else 0 + print(f" {label}") + print(f" rows={rows:>10,} removed={removed:>8,} ({pct:.1f}%)") + prev = rows + + # Step 3 extra detail + n_conflict_photos = c["conflicts"]["rows"] + print(f" [{n_conflict_photos:,} photo_ids had conflicting taxa]") + + total_removed = c["input"]["rows"] - c["step4"]["rows"] + total_pct = total_removed / c["input"]["rows"] * 100 if c["input"]["rows"] else 0 + print() + print(f" {'─'*60}") + print(f" Total removed : {total_removed:>10,} ({total_pct:.1f}%)") + print(f" Output rows : 
{c['step4']['rows']:>10,}") + print(f" Taxa before : {c['input']['taxa']:>10,}") + print(f" Taxa after : {c['step4']['taxa']:>10,}") + dry_tag = " [DRY RUN — table not written]" if dry_run else "" + print(f" Output table : {dst_ref}{dry_tag}") + print("=" * 70) + print() + + +def build_log(counts, args, dst_ref, dry_run, started_at): + c = counts + return { + "started_at": started_at, + "finished_at": datetime.now(timezone.utc).isoformat(), + "dataset": args.dataset, + "project": args.project, + "table": args.table, + "output_table": dst_ref, + "dry_run": dry_run, + "multi_taxon_strategy": args.multi_taxon_strategy, + "min_images_per_taxon": args.min_images_per_taxon, + "steps": { + "exact_duplicates": { + "before": c["input"]["rows"], + "after": c["step1"]["rows"], + "removed": c["input"]["rows"] - c["step1"]["rows"], + }, + "same_taxon_multi_gbif": { + "before": c["step1"]["rows"], + "after": c["step2"]["rows"], + "removed": c["step1"]["rows"] - c["step2"]["rows"], + }, + "multi_taxon_conflicts": { + "before": c["step2"]["rows"], + "after": c["step3"]["rows"], + "removed": c["step2"]["rows"] - c["step3"]["rows"], + "conflict_photo_ids": c["conflicts"]["rows"], + }, + "min_images_filter": { + "before": c["step3"]["rows"], + "after": c["step4"]["rows"], + "removed": c["step3"]["rows"] - c["step4"]["rows"], + "threshold": args.min_images_per_taxon, + }, + }, + "summary": { + "input_rows": c["input"]["rows"], + "output_rows": c["step4"]["rows"], + "total_removed": c["input"]["rows"] - c["step4"]["rows"], + "taxa_before": c["input"]["taxa"], + "taxa_after": c["step4"]["taxa"], + }, + } + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--dataset", required=True, + help="BQ dataset name (e.g. 
global_all_leps_2605)") + parser.add_argument("--project", default="leps-ai") + parser.add_argument("--table", default="training_images", + help="Table to clean (overwritten in-place)") + parser.add_argument("--min-images-per-taxon", type=int, default=0, + help="Drop taxa with fewer images than this (0 = off)") + parser.add_argument("--multi-taxon-strategy", + choices=["drop", "keep-lowest-taxon-id"], default="drop", + help="How to handle a photo mapped to multiple taxa") + parser.add_argument("--dry-run", action="store_true", + help="Report what would be removed without writing the table") + parser.add_argument("--log-file", help="Write JSON log to this path") + args = parser.parse_args() + + started_at = datetime.now(timezone.utc).isoformat() + src_ref = f"{args.project}.{args.dataset}.{args.table}" + dst_ref = src_ref # overwrite in-place + + client = bigquery.Client(project=args.project) + + print(f"clean.py {'[DRY RUN] ' if args.dry_run else ''}started at {started_at}") + print(f" table : {src_ref}") + print(f" multi-taxon : {args.multi_taxon_strategy}") + print(f" min-images-per-taxon: {args.min_images_per_taxon or 'off'}") + print() + print("Running count queries ...") + + counts = run_count_query(client, src_ref, args.multi_taxon_strategy, args.min_images_per_taxon) + print_report(counts, args.multi_taxon_strategy, args.min_images_per_taxon, dst_ref, args.dry_run) + + if not args.dry_run: + print("Writing cleaned table ...") + run_write_query(client, src_ref, dst_ref, args.multi_taxon_strategy, args.min_images_per_taxon) + print(f"Done — {counts['step4']['rows']:,} rows written to {dst_ref}") + + log = build_log(counts, args, dst_ref, args.dry_run, started_at) + if args.log_file: + with open(args.log_file, "w") as f: + json.dump(log, f, indent=2) + print(f"Log written to {args.log_file}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/dataset_tools/bigquery_pipeline/create_test_table.py 
b/src/dataset_tools/bigquery_pipeline/create_test_table.py new file mode 100644 index 0000000..2c03ff2 --- /dev/null +++ b/src/dataset_tools/bigquery_pipeline/create_test_table.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +""" +Create a stratified test table for clean.py testing. + +Includes all 3 duplicate types (exact dupes, same-taxon multi-gbif, +multi-taxon conflicts) plus a clean random sample. Total ~35-40K rows. +""" +import argparse +from google.cloud import bigquery + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dataset", required=True, help="BQ dataset name (e.g. global_all_leps_2605)") + parser.add_argument("--project", default="leps-ai") + parser.add_argument("--source-table", default="training_images") + parser.add_argument("--output-table", default="training_images_test") + parser.add_argument("--max-type1-photos", type=int, default=10_000, + help="Cap on type-1 duplicate photo_ids (each appears ×2 rows)") + parser.add_argument("--clean-sample", type=int, default=10_000, + help="Number of clean (non-duplicate) rows to include") + args = parser.parse_args() + + client = bigquery.Client(project=args.project) + src = f"`{args.project}.{args.dataset}.{args.source_table}`" + dst_ref = f"{args.project}.{args.dataset}.{args.output_table}" + dst = f"`{dst_ref}`" + + query = f""" + CREATE OR REPLACE TABLE {dst} AS + + WITH + -- Type 1: exact duplicate rows (same photo_id + taxon + gbif_id appears >1 time) + type1_photos AS ( + SELECT photo_id + FROM {src} + GROUP BY photo_id, inat_taxon_id, gbif_id + HAVING COUNT(*) > 1 + LIMIT {args.max_type1_photos} + ), + + -- Type 2: same photo_id mapped to multiple taxa AND multiple gbif_ids + type2_photos AS ( + SELECT photo_id + FROM {src} + GROUP BY photo_id + HAVING COUNT(DISTINCT inat_taxon_id) > 1 AND COUNT(DISTINCT gbif_id) > 1 + ), + + -- Type 3: same photo_id + same taxon, but multiple gbif_ids + type3_photos AS ( + SELECT photo_id + FROM {src} + GROUP BY 
photo_id + HAVING COUNT(DISTINCT inat_taxon_id) = 1 AND COUNT(DISTINCT gbif_id) > 1 + ), + + all_dup_photos AS ( + SELECT photo_id FROM type1_photos + UNION DISTINCT + SELECT photo_id FROM type2_photos + UNION DISTINCT + SELECT photo_id FROM type3_photos + ), + + -- All rows belonging to any duplicate photo_id + dup_rows AS ( + SELECT t.* + FROM {src} t + INNER JOIN all_dup_photos d USING (photo_id) + ), + + -- Clean rows: photo_ids not involved in any duplication + clean_rows AS ( + SELECT t.* + FROM {src} t + WHERE t.photo_id NOT IN (SELECT photo_id FROM all_dup_photos) + LIMIT {args.clean_sample} + ) + + SELECT * FROM dup_rows + UNION ALL + SELECT * FROM clean_rows + """ + + print(f"Creating {dst_ref} ...") + print(f" source : {args.project}.{args.dataset}.{args.source_table}") + print(f" type-1 cap: {args.max_type1_photos:,} photo_ids") + print(f" clean sample: {args.clean_sample:,} rows") + print() + + job = client.query(query) + job.result() + + ref = client.get_table(dst_ref) + print(f"Done — {ref.num_rows:,} rows written to {dst_ref}") + + +if __name__ == "__main__": + main() From 835896609754dd7d6878a57b194b174341735204 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 4 Jun 2026 12:07:07 -0700 Subject: [PATCH 18/26] Add test suite for clean.py (68 tests) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two layers: - Mock-based: SQL generation, BQ client calls, log arithmetic, main() wiring - DuckDB integration: executes real CTE SQL against in-memory rows, covers all 3 dup types + edge cases (overlapping types, 3+ copies, empty table) No real BQ connection needed — all tests run locally in ~2.5s. 
Co-Authored-By: Claude Sonnet 4.6 --- tests/dataset_tools/test_clean.py | 538 ++++++++++++++++++++++++++++++ 1 file changed, 538 insertions(+) create mode 100644 tests/dataset_tools/test_clean.py diff --git a/tests/dataset_tools/test_clean.py b/tests/dataset_tools/test_clean.py new file mode 100644 index 0000000..c9ff37d --- /dev/null +++ b/tests/dataset_tools/test_clean.py @@ -0,0 +1,538 @@ +""" +Tests for bigquery_pipeline/clean.py. + +Two layers: + - Mock-based: verify SQL generation and Python logic (no BQ, no network) + - DuckDB integration: execute the real CTE chain against in-memory rows, + assert all 3 duplicate types are resolved correctly + +Run with: pytest tests/dataset_tools/test_clean.py -v +""" +import json +import types +from argparse import Namespace +from unittest.mock import MagicMock, call, patch + +import duckdb +import pandas as pd +import pytest + +import pytest + +import src.dataset_tools.bigquery_pipeline.clean as clean + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def make_count_row(stage, n, taxa=100): + """Minimal BQ row mock matching the count query SELECT (stage, n, taxa).""" + row = MagicMock() + row.stage = stage + row.n = n + row.taxa = taxa + return row + + +SAMPLE_COUNTS = { + "input": {"rows": 37_500, "taxa": 5_000}, + "step1": {"rows": 27_000, "taxa": 5_000}, + "step2": {"rows": 25_700, "taxa": 5_000}, + "step3": {"rows": 21_000, "taxa": 4_700}, + "step4": {"rows": 21_000, "taxa": 4_700}, + "conflicts": {"rows": 2_229, "taxa": 0}, +} + + +def make_bq_client(counts=SAMPLE_COUNTS): + """Mock BQ client whose query().result() returns the given stage counts.""" + rows = [make_count_row(stage, v["rows"], v["taxa"]) for stage, v in counts.items()] + client = MagicMock() + client.query.return_value.result.return_value = rows + return client + + +def make_args(**kwargs): + """Build a minimal Namespace matching clean.py's argparse output.""" + defaults = dict( + dataset="global_all_leps_2605", + 
project="leps-ai", + table="training_images", + min_images_per_taxon=0, + multi_taxon_strategy="drop", + dry_run=False, + log_file=None, + ) + defaults.update(kwargs) + return Namespace(**defaults) + + +# ── _step3_cte ──────────────────────────────────────────────────────────────── + +class TestStep3Cte: + + def test_drop_strategy_uses_not_in(self): + sql = clean._step3_cte("drop") + assert "NOT IN" in sql + assert "multi_taxon_photos" in sql + + def test_keep_lowest_uses_qualify_row_number(self): + sql = clean._step3_cte("keep-lowest-taxon-id") + assert "QUALIFY" in sql + assert "ROW_NUMBER()" in sql + assert "PARTITION BY photo_id" in sql + assert "ORDER BY inat_taxon_id" in sql + + def test_drop_does_not_contain_qualify(self): + sql = clean._step3_cte("drop") + assert "QUALIFY" not in sql + + def test_keep_lowest_does_not_contain_not_in(self): + sql = clean._step3_cte("keep-lowest-taxon-id") + assert "NOT IN" not in sql + + +# ── _cte_chain SQL content ──────────────────────────────────────────────────── + +class TestCteSqlContent: + + SRC = "leps-ai.global_all_leps_2605.training_images" + + def _chain(self, strategy="drop", min_images=0): + return clean._cte_chain(self.SRC, strategy, min_images) + + def test_src_ref_embedded(self): + assert self.SRC in self._chain() + + def test_step1_exact_dedup_partition(self): + sql = self._chain() + assert "PARTITION BY photo_id, inat_taxon_id, gbif_id" in sql + + def test_step1_uses_qualify_row_number(self): + sql = self._chain() + assert "QUALIFY" in sql + assert "ROW_NUMBER()" in sql + + def test_step2_same_taxon_multi_gbif_partition(self): + sql = self._chain() + assert "PARTITION BY photo_id, inat_taxon_id" in sql + + def test_step2_orders_by_gbif_id(self): + sql = self._chain() + assert "ORDER BY gbif_id" in sql + + def test_multi_taxon_photos_cte_present(self): + sql = self._chain() + assert "multi_taxon_photos" in sql + assert "COUNT(DISTINCT inat_taxon_id) > 1" in sql + + def 
test_step4_min_images_filter_active(self): + sql = self._chain(min_images=10) + assert "t.cnt >= 10" in sql + + def test_step4_min_images_disabled_when_zero(self): + sql = self._chain(min_images=0) + assert "TRUE" in sql + assert "t.cnt >=" not in sql + + def test_step3_drop_strategy_in_chain(self): + sql = self._chain(strategy="drop") + assert "NOT IN" in sql + + def test_step3_keep_lowest_strategy_in_chain(self): + sql = self._chain(strategy="keep-lowest-taxon-id") + assert "ORDER BY inat_taxon_id" in sql + + +# ── run_count_query ─────────────────────────────────────────────────────────── + +class TestRunCountQuery: + + SRC = "leps-ai.global_all_leps_2605.training_images_test" + + def test_returns_dict_with_all_stages(self): + client = make_bq_client() + result = clean.run_count_query(client, self.SRC, "drop", 0) + assert set(result.keys()) == {"input", "step1", "step2", "step3", "step4", "conflicts"} + + def test_rows_parsed_correctly(self): + client = make_bq_client() + result = clean.run_count_query(client, self.SRC, "drop", 0) + assert result["input"]["rows"] == 37_500 + assert result["step4"]["rows"] == 21_000 + assert result["conflicts"]["rows"] == 2_229 + + def test_taxa_parsed_correctly(self): + client = make_bq_client() + result = clean.run_count_query(client, self.SRC, "drop", 0) + assert result["input"]["taxa"] == 5_000 + assert result["step3"]["taxa"] == 4_700 + + def test_bq_client_query_called_once(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + assert client.query.call_count == 1 + + def test_result_called_on_query(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + client.query.return_value.result.assert_called_once() + + def test_sql_contains_union_all_for_all_stages(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + sql = client.query.call_args[0][0] + assert sql.count("UNION ALL") >= 5 + + def test_sql_contains_src_ref(self): 
+ client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + sql = client.query.call_args[0][0] + assert self.SRC in sql + + def test_drop_strategy_sql_has_not_in(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + sql = client.query.call_args[0][0] + assert "NOT IN" in sql + + def test_keep_lowest_strategy_sql_has_qualify(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "keep-lowest-taxon-id", 0) + sql = client.query.call_args[0][0] + assert "ORDER BY inat_taxon_id" in sql + + +# ── run_write_query ─────────────────────────────────────────────────────────── + +class TestRunWriteQuery: + + SRC = "leps-ai.global_all_leps_2605.training_images" + DST = "leps-ai.global_all_leps_2605.training_images" + + def test_issues_create_or_replace_table(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + sql = client.query.call_args[0][0] + assert "CREATE OR REPLACE TABLE" in sql + + def test_dst_ref_in_sql(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + sql = client.query.call_args[0][0] + assert self.DST in sql + + def test_selects_from_step4(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + sql = client.query.call_args[0][0] + assert "SELECT * FROM step4" in sql + + def test_result_called(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + client.query.return_value.result.assert_called_once() + + def test_query_called_exactly_once(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + assert client.query.call_count == 1 + + +# ── build_log ───────────────────────────────────────────────────────────────── + +class TestBuildLog: + + STARTED = "2026-06-04T01:00:00+00:00" + + def _log(self, counts=SAMPLE_COUNTS, **kwargs): + args = make_args(**kwargs) + return 
clean.build_log(counts, args, "leps-ai.ds.training_images", args.dry_run, self.STARTED) + + def test_summary_input_rows(self): + assert self._log()["summary"]["input_rows"] == 37_500 + + def test_summary_output_rows(self): + assert self._log()["summary"]["output_rows"] == 21_000 + + def test_summary_total_removed(self): + log = self._log() + assert log["summary"]["total_removed"] == 37_500 - 21_000 + + def test_summary_taxa_before_and_after(self): + log = self._log() + assert log["summary"]["taxa_before"] == 5_000 + assert log["summary"]["taxa_after"] == 4_700 + + def test_step1_removed_arithmetic(self): + log = self._log() + step = log["steps"]["exact_duplicates"] + assert step["removed"] == 37_500 - 27_000 + + def test_step2_removed_arithmetic(self): + log = self._log() + step = log["steps"]["same_taxon_multi_gbif"] + assert step["removed"] == 27_000 - 25_700 + + def test_step3_removed_arithmetic(self): + log = self._log() + step = log["steps"]["multi_taxon_conflicts"] + assert step["removed"] == 25_700 - 21_000 + + def test_step3_conflict_photo_ids(self): + log = self._log() + assert log["steps"]["multi_taxon_conflicts"]["conflict_photo_ids"] == 2_229 + + def test_step4_threshold_recorded(self): + log = self._log(min_images_per_taxon=10) + assert log["steps"]["min_images_filter"]["threshold"] == 10 + + def test_dry_run_flag_true(self): + assert self._log(dry_run=True)["dry_run"] is True + + def test_dry_run_flag_false(self): + assert self._log(dry_run=False)["dry_run"] is False + + def test_started_at_preserved(self): + assert self._log()["started_at"] == self.STARTED + + def test_log_is_json_serialisable(self): + log = self._log() + json.dumps(log) # must not raise + + +# ── print_report ────────────────────────────────────────────────────────────── + +class TestPrintReport: + + def test_no_crash(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + capsys.readouterr() # consume output + + def 
test_dry_run_tag_present(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=True) + out = capsys.readouterr().out + assert "DRY RUN" in out + + def test_dry_run_tag_absent_when_not_dry(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "DRY RUN" not in out + + def test_total_removed_in_output(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "16,500" in out # 37500 - 21000 + + def test_conflict_photo_ids_in_output(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "2,229" in out + + def test_strategy_name_in_output(self, capsys): + clean.print_report(SAMPLE_COUNTS, "keep-lowest-taxon-id", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "keep-lowest-taxon-id" in out + + +# ── main() integration ──────────────────────────────────────────────────────── + +class TestMainIntegration: + + def _run_main(self, argv, counts=SAMPLE_COUNTS): + client = make_bq_client(counts) + with patch("sys.argv", ["clean.py"] + argv), \ + patch("src.dataset_tools.bigquery_pipeline.clean.bigquery.Client", + return_value=client): + clean.main() + return client + + def test_dry_run_calls_count_query_not_write(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--table", "training_images_test", + "--dry-run", + ]) + assert client.query.call_count == 1 + sql = client.query.call_args[0][0] + assert "CREATE OR REPLACE TABLE" not in sql + + def test_actual_run_calls_count_then_write(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--table", "training_images_test", + ]) + assert client.query.call_count == 2 + write_sql = client.query.call_args_list[1][0][0] + assert "CREATE OR REPLACE TABLE" in write_sql + + def test_log_file_written(self, 
tmp_path): + log_path = tmp_path / "clean.json" + self._run_main([ + "--dataset", "global_all_leps_2605", + "--dry-run", + "--log-file", str(log_path), + ]) + assert log_path.exists() + log = json.loads(log_path.read_text()) + assert "summary" in log + assert "steps" in log + + def test_min_images_arg_propagated_to_sql(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--min-images-per-taxon", "15", + "--dry-run", + ]) + sql = client.query.call_args[0][0] + assert "t.cnt >= 15" in sql + + def test_keep_lowest_strategy_propagated_to_sql(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--multi-taxon-strategy", "keep-lowest-taxon-id", + "--dry-run", + ]) + sql = client.query.call_args[0][0] + assert "ORDER BY inat_taxon_id" in sql + + +# ── DuckDB integration — real SQL against in-memory rows ───────────────────── + +@pytest.fixture +def dup_df(): + """ + Small DataFrame covering all 3 duplicate types plus clean rows. + + photo_id=1 type 1 — exact duplicate (identical rows) + photo_id=2 type 2 — same photo, two different taxa + gbif_ids + photo_id=3 type 3 — same photo + taxon, two different gbif_ids + photo_id=4 clean + photo_id=5 clean, different taxon + photo_id=6 clean, used for min-images-per-taxon tests (sole photo of taxon 600) + """ + return pd.DataFrame([ + # Type 1: exact dup + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + # Type 2: multi-taxon conflict + dict(photo_id=2, inat_taxon_id=200, gbif_id=2000, dataset_source_uuid="uuid-2"), + dict(photo_id=2, inat_taxon_id=201, gbif_id=2001, dataset_source_uuid="uuid-2"), + # Type 3: same taxon, multi-gbif (keep MIN gbif_id=3000) + dict(photo_id=3, inat_taxon_id=300, gbif_id=3000, dataset_source_uuid="uuid-3"), + dict(photo_id=3, inat_taxon_id=300, gbif_id=3001, dataset_source_uuid="uuid-3"), + # Clean rows + dict(photo_id=4, 
inat_taxon_id=100, gbif_id=4000, dataset_source_uuid="uuid-4"), + dict(photo_id=5, inat_taxon_id=300, gbif_id=5000, dataset_source_uuid="uuid-5"), + # Lone taxon (taxon 600 has only 1 image — for min-images test) + dict(photo_id=6, inat_taxon_id=600, gbif_id=6000, dataset_source_uuid="uuid-6"), + ]) + + +def run_cte(df, strategy="drop", min_images=0): + """Execute the clean.py CTE chain against a pandas DataFrame via DuckDB.""" + conn = duckdb.connect() + conn.register("src_table", df) + cte = clean._cte_chain("src_table", strategy, min_images).replace("`", "") + return conn.execute(f"WITH {cte} SELECT * FROM step4").df() + + +class TestDuckDBIntegration: + """Execute the real CTE SQL against in-memory rows — no BQ required.""" + + def test_output_row_count(self, dup_df): + result = run_cte(dup_df) + # type1: 1 kept, type2: 0 kept (both dropped), type3: 1 kept, clean: 3 kept + assert len(result) == 5 + + def test_type1_exact_dup_resolved(self, dup_df): + result = run_cte(dup_df) + assert result[result.photo_id == 1].shape[0] == 1 + + def test_type2_multi_taxon_dropped(self, dup_df): + result = run_cte(dup_df) + assert result[result.photo_id == 2].shape[0] == 0 + + def test_type3_keeps_min_gbif_id(self, dup_df): + result = run_cte(dup_df) + row = result[result.photo_id == 3] + assert len(row) == 1 + assert row.iloc[0].gbif_id == 3000 + + def test_clean_rows_preserved(self, dup_df): + result = run_cte(dup_df) + assert result[result.photo_id == 4].shape[0] == 1 + assert result[result.photo_id == 5].shape[0] == 1 + + def test_no_duplicate_photo_ids_remain(self, dup_df): + result = run_cte(dup_df) + assert result.photo_id.nunique() == len(result) + + def test_no_duplicate_uuids_remain(self, dup_df): + result = run_cte(dup_df) + assert result.dataset_source_uuid.nunique() == len(result) + + def test_each_photo_has_one_taxon(self, dup_df): + result = run_cte(dup_df) + taxa_per_photo = result.groupby("photo_id")["inat_taxon_id"].nunique() + assert (taxa_per_photo == 
1).all() + + def test_min_images_drops_lone_taxon(self, dup_df): + # taxon 600 has only 1 image — threshold=2 should drop it + result = run_cte(dup_df, min_images=2) + assert result[result.inat_taxon_id == 600].shape[0] == 0 + + def test_min_images_keeps_taxon_above_threshold(self, dup_df): + # taxon 100 has 2 images (photo_id 1 and 4) — survives threshold=2 + result = run_cte(dup_df, min_images=2) + assert result[result.inat_taxon_id == 100].shape[0] == 2 + + def test_keep_lowest_taxon_keeps_one_row_per_photo(self, dup_df): + result = run_cte(dup_df, strategy="keep-lowest-taxon-id") + assert result[result.photo_id == 2].shape[0] == 1 + + def test_keep_lowest_taxon_picks_min_taxon_id(self, dup_df): + result = run_cte(dup_df, strategy="keep-lowest-taxon-id") + row = result[result.photo_id == 2] + assert row.iloc[0].inat_taxon_id == 200 # min(200, 201) + + # ── edge cases ──────────────────────────────────────────────────────────── + + def test_type1_and_type3_overlap_resolved_to_one_row(self): + """Photo with both exact-dup rows AND multiple gbif_ids under same taxon.""" + df = pd.DataFrame([ + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1001, dataset_source_uuid="uuid-1"), + ]) + result = run_cte(df) + assert len(result) == 1 + assert result.iloc[0].gbif_id == 1000 + + def test_type1_with_three_copies_resolves_to_one(self): + """Three identical rows (tripled duplicate) collapse to one.""" + df = pd.DataFrame([ + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + ]) + result = run_cte(df) + assert len(result) == 1 + + def test_type3_with_three_gbif_ids_keeps_min(self): + """Same photo+taxon with 3 different 
gbif_ids — keeps the lowest.""" + df = pd.DataFrame([ + dict(photo_id=1, inat_taxon_id=100, gbif_id=3000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=3001, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=3002, dataset_source_uuid="uuid-1"), + ]) + result = run_cte(df) + assert len(result) == 1 + assert result.iloc[0].gbif_id == 3000 + + def test_empty_table_returns_zero_rows(self): + """Empty input table produces empty output without error.""" + df = pd.DataFrame( + columns=["photo_id", "inat_taxon_id", "gbif_id", "dataset_source_uuid"] + ) + result = run_cte(df) + assert len(result) == 0 From 5a7e437e5fcd2dd34cc3069173d31406b603a458 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Fri, 5 Jun 2026 08:56:59 -0700 Subject: [PATCH 19/26] feat(download): add --dataset flag and NULL fetch_status support - Add --dataset CLI arg (default: global_butterflies_2604, backwards compatible) - Handle NULL fetch_status in get_pending_rows and MERGE condition for global_all_leps_2605 - Derive tmp_table dataset from training_table string, not hardcoded BQ_DATASET constant - Add 6 new tests covering dataset routing, NULL query handling, and tmp_table derivation Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/download_images.py | 23 +++- tests/dataset_tools/test_download_images.py | 124 ++++++++++++++++++ 2 files changed, 140 insertions(+), 7 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index a87f420..4a76ca9 100644 --- a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -49,7 +49,7 @@ Image.MAX_IMAGE_PIXELS = None BQ_PROJECT = "leps-ai" -BQ_DATASET = "global_butterflies_2604" +BQ_DEFAULT_DATASET = "global_butterflies_2604" # Retry config for HTTP downloads _RETRY_STATUSES = {429, 500, 502, 503, 504} @@ -244,7 +244,10 @@ def merge_chunk_into_training_images( if not to_merge: 
return 0 - tmp_table = f"{BQ_PROJECT}.{BQ_DATASET}._dl_merge_tmp_{uuid.uuid4().hex[:8]}" + # Derive "project.dataset" from training_table ("project.dataset.table_name") + _parts = training_table.split(".") + _dataset_ref = ".".join(_parts[:2]) if len(_parts) >= 2 else _parts[0] + tmp_table = f"{_dataset_ref}._dl_merge_tmp_{uuid.uuid4().hex[:8]}" df = pd.DataFrame(to_merge) job_config = bigquery.LoadJobConfig( write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE, @@ -257,7 +260,7 @@ def merge_chunk_into_training_images( MERGE `{training_table}` T USING `{tmp_table}` S ON T.dataset_source_uuid = S.dataset_source_uuid - WHEN MATCHED AND T.fetch_status = 'pending' THEN UPDATE SET + WHEN MATCHED AND (T.fetch_status = 'pending' OR T.fetch_status IS NULL) THEN UPDATE SET T.fetch_status = S.fetch_status, T.image_width = S.image_width, T.image_height = S.image_height, @@ -285,7 +288,7 @@ def get_pending_rows( query = f""" SELECT dataset_source_uuid, absolute_url, relative_local_path FROM `{training_table}` - WHERE fetch_status = 'pending' + WHERE (fetch_status = 'pending' OR fetch_status IS NULL) AND MOD(photo_id, {num_jobs}) = {task_id} {limit_clause} """ @@ -295,7 +298,7 @@ def get_pending_rows( FROM `{training_table}` ti LEFT JOIN `{downloads_table}` d ON ti.dataset_source_uuid = d.dataset_source_uuid - WHERE ti.fetch_status = 'pending' + WHERE (ti.fetch_status = 'pending' OR ti.fetch_status IS NULL) AND MOD(ti.photo_id, {num_jobs}) = {task_id} AND d.dataset_source_uuid IS NULL {limit_clause} @@ -416,6 +419,12 @@ def main(): "the chunks from scratch. Without this flag, already-attempted " "images are skipped via LEFT JOIN." )) + parser.add_argument("--dataset", default=BQ_DEFAULT_DATASET, + help=( + f"BigQuery dataset name within the leps-ai project " + f"(default: {BQ_DEFAULT_DATASET}). " + f"Example: --dataset global_all_leps_2605" + )) parser.add_argument("--table-prefix", default="", help=( "BQ table name prefix for testing without touching production. 
" @@ -425,8 +434,8 @@ def main(): )) args = parser.parse_args() - training_table = f"{BQ_PROJECT}.{BQ_DATASET}.{args.table_prefix}training_images" - downloads_table = f"{BQ_PROJECT}.{BQ_DATASET}.{args.table_prefix}training_images_downloads" + training_table = f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images" + downloads_table = f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images_downloads" client = bigquery.Client(project=BQ_PROJECT) staging_dir = Path(args.staging_dir) diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py index bcce8a5..f11e7a5 100644 --- a/tests/dataset_tools/test_download_images.py +++ b/tests/dataset_tools/test_download_images.py @@ -661,3 +661,127 @@ def test_warning_at_threshold(self, tmp_path, capsys): (tmp_path / f"chunk_{i:04d}.sqfs").write_bytes(b"x") di.warn_chunk_accumulation(tmp_path) assert "WARNING" in capsys.readouterr().out + + +# ── --dataset flag and NULL fetch_status handling ───────────────────────────── + +class TestDatasetFlagAndNullFetchStatus: + """Tests for --dataset CLI flag and NULL fetch_status support (global_all_leps_2605).""" + + # ── --dataset default ───────────────────────────────────────────────────── + + def test_default_dataset_constant(self): + """BQ_DEFAULT_DATASET must default to global_butterflies_2604 for backwards compat.""" + assert di.BQ_DEFAULT_DATASET == "global_butterflies_2604" + + # ── NULL fetch_status in get_pending_rows ───────────────────────────────── + + def test_normal_query_includes_null_fetch_status(self): + """Normal (LEFT JOIN) query must include OR fetch_status IS NULL to pick up + rows from datasets like global_all_leps_2605 where status starts as NULL.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="proj.ds.training_images", + downloads_table="proj.ds.training_images_downloads", + num_jobs=10, + task_id=0, + 
force_redownload=False, + ) + + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql + + def test_force_redownload_query_includes_null_fetch_status(self): + """Force-redownload query must also include OR fetch_status IS NULL.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="proj.ds.training_images", + downloads_table="proj.ds.training_images_downloads", + num_jobs=10, + task_id=0, + force_redownload=True, + ) + + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql + assert "LEFT JOIN" not in sql + + # ── NULL fetch_status in MERGE ──────────────────────────────────────────── + + def test_merge_condition_handles_null_fetch_status(self): + """MERGE SQL must allow updating rows where T.fetch_status IS NULL, + not just 'pending' — needed for global_all_leps_2605 initial state.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 1 + client.query.return_value = job + + results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 64, "image_height": 48, + "image_size": 1000, "corrupted": False}] + + di.merge_chunk_into_training_images( + client, results, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql + + # ── tmp_table derived from training_table ───────────────────────────────── + + def test_tmp_table_uses_same_dataset_as_training_table(self): + """Temp table for MERGE must be in the same dataset as training_table, + not hardcoded to global_butterflies_2604.""" + client = MagicMock() + 
client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 1 + client.query.return_value = job + + results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 64, "image_height": 48, + "image_size": 1000, "corrupted": False}] + + di.merge_chunk_into_training_images( + client, results, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + + # The temp table passed to load_table_from_dataframe must be in global_all_leps_2605 + tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] + assert tmp_table_arg.startswith("leps-ai.global_all_leps_2605.") + assert "global_butterflies_2604" not in tmp_table_arg + + def test_tmp_table_not_in_wrong_dataset_when_using_new_dataset(self): + """Regression: old code used BQ_DATASET module constant — ensure it no longer does.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 1 + client.query.return_value = job + + results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", + "image_width": 64, "image_height": 48, + "image_size": 1000, "corrupted": False}] + + di.merge_chunk_into_training_images( + client, results, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + + tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] + assert "global_butterflies_2604" not in tmp_table_arg # must not leak old dataset From f7cfc5edbd1d2aea926e3624675f44782347828b Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Fri, 5 Jun 2026 08:57:04 -0700 Subject: [PATCH 20/26] feat(jobs): point bq_download job to global_all_leps_2605 - Pass --dataset global_all_leps_2605 to download_images.py - Use separate staging dir 
/scratch/melabbas/global_all_leps_2605/task_{N}/ - Rename job to bq_dl_2605 for squeue clarity Co-Authored-By: Claude Sonnet 4.6 --- scripts/job_bq_download.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/job_bq_download.sh b/scripts/job_bq_download.sh index b128098..082407e 100755 --- a/scripts/job_bq_download.sh +++ b/scripts/job_bq_download.sh @@ -10,7 +10,7 @@ # sbatch --dependency=afterok:$DOWNLOAD_JOB job_bq_pack_squashfs.sh # #SBATCH --account=def-drolnick -#SBATCH --job-name=bq_download +#SBATCH --job-name=bq_dl_2605 #SBATCH --cpus-per-task=32 #SBATCH --mem=64G #SBATCH --time=72:00:00 @@ -24,7 +24,7 @@ TASK_ID=${SLURM_ARRAY_TASK_ID} # Each job downloads its images into its own staging subdirectory # These are kept after the job ends (on Lustre) for the pack job to merge -STAGING_BASE="/scratch/melabbas/bq_download_staging" +STAGING_BASE="/scratch/melabbas/global_all_leps_2605" STAGING_DIR="${STAGING_BASE}/task_${TASK_ID}" echo "=== bq_download task=${TASK_ID}/${NUM_JOBS} started at $(date) ===" @@ -42,7 +42,8 @@ python src/dataset_tools/bq_squashfs/download_images.py \ --num-jobs ${NUM_JOBS} \ --task-id ${TASK_ID} \ --num-workers 32 \ - --chunk-size 10000 + --chunk-size 10000 \ + --dataset global_all_leps_2605 EXIT_CODE=$? echo "=== bq_download task=${TASK_ID} finished at $(date) (exit=${EXIT_CODE}) ===" From 979c3de523f2a41e6f82b38a77b1047fc4a23508 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Fri, 5 Jun 2026 12:51:05 -0700 Subject: [PATCH 21/26] fix(download): retry BQ MERGE on serialization conflicts Concurrent array tasks each MERGE chunk results into training_images; BigQuery aborts colliding DML with "Could not serialize access ... due to concurrent update" (400), which the client does not retry. Observed in production with 8 concurrent tasks (job 43176702_0). Retry only serialization conflicts, up to 10 attempts with jittered exponential backoff (2s base, 60s cap). 
Other BadRequest errors still raise immediately. Co-Authored-By: Claude Opus 4.8 --- .../bq_squashfs/download_images.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index 4a76ca9..6c172ab 100644 --- a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -44,6 +44,7 @@ import PIL import requests from PIL import Image +from google.api_core import exceptions as google_exceptions from google.cloud import bigquery Image.MAX_IMAGE_PIXELS = None @@ -56,6 +57,7 @@ _MAX_RETRIES = 5 _BACKOFF_BASE = 2.0 # seconds _BACKOFF_MAX = 60.0 # seconds cap +_MERGE_MAX_RETRIES = 10 # BQ MERGE serialization conflicts (concurrent tasks) # Warn if this many chunk sqfs files accumulate (pack job falling behind) _CHUNK_ACCUMULATION_WARN = 20 @@ -255,8 +257,7 @@ def merge_chunk_into_training_images( ) client.load_table_from_dataframe(df, tmp_table, job_config=job_config).result() - try: - job = client.query(f""" + merge_sql = f""" MERGE `{training_table}` T USING `{tmp_table}` S ON T.dataset_source_uuid = S.dataset_source_uuid @@ -266,9 +267,24 @@ def merge_chunk_into_training_images( T.image_height = S.image_height, T.image_size = S.image_size, T.corrupted = S.corrupted - """) - job.result() - return job.dml_stats.updated_row_count + """ + try: + # Concurrent MERGEs from parallel tasks can collide with + # "Could not serialize access ... due to concurrent update" (400). + # BQ docs recommend retrying — back off with jitter until a slot frees. 
+ for attempt in range(_MERGE_MAX_RETRIES + 1): + try: + job = client.query(merge_sql) + job.result() + return job.dml_stats.updated_row_count + except google_exceptions.BadRequest as e: + if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES: + raise + delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.5) + print(f" MERGE serialization conflict — retry " + f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", flush=True) + time.sleep(delay) finally: client.delete_table(tmp_table, not_found_ok=True) From f127c47b45152d1dbc8567b3d0b06b9b6242388a Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Fri, 5 Jun 2026 12:51:16 -0700 Subject: [PATCH 22/26] feat(test-tables): add --dataset flag, random sampling, _test_ naming - --dataset flag replaces hardcoded global_butterflies_2604 - sample WHERE includes fetch_status IS NULL (new datasets have no 'downloaded' rows yet) and uses ORDER BY RAND() for unbiased size/distribution estimates - test tables renamed with leading underscore (_test_training_images, _test_training_images_downloads) so test artifacts sort together, separate from production tables; use with --table-prefix _test_ Co-Authored-By: Claude Opus 4.8 --- .../bq_squashfs/create_test_tables.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/create_test_tables.py b/src/dataset_tools/bq_squashfs/create_test_tables.py index 74f9343..79fdec8 100644 --- a/src/dataset_tools/bq_squashfs/create_test_tables.py +++ b/src/dataset_tools/bq_squashfs/create_test_tables.py @@ -3,8 +3,8 @@ Create small BQ test tables for testing download_images.py without touching production. 
Creates: - test_training_images — 50 rows sampled from training_images, fetch_status='pending' - test_training_images_downloads — empty table, same schema as training_images_downloads + _test_training_images — 50 rows sampled from training_images, fetch_status='pending' + _test_training_images_downloads — empty table, same schema as training_images_downloads Usage: python create_test_tables.py @@ -23,15 +23,18 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--n-rows", type=int, default=50, help="Number of rows to sample from training_images (default: 50)") + parser.add_argument("--dataset", default=BQ_DATASET, + help=f"BigQuery dataset name (default: {BQ_DATASET}). " + f"Example: --dataset global_all_leps_2605") args = parser.parse_args() client = bigquery.Client(project=BQ_PROJECT) - prefix = f"{BQ_PROJECT}.{BQ_DATASET}" + prefix = f"{BQ_PROJECT}.{args.dataset}" - # ── test_training_images ───────────────────────────────────────────────── - print(f"Creating {prefix}.test_training_images ({args.n_rows} rows)...") + # ── _test_training_images ───────────────────────────────────────────────── + print(f"Creating {prefix}._test_training_images ({args.n_rows} rows)...") client.query(f""" - CREATE OR REPLACE TABLE `{prefix}.test_training_images` AS + CREATE OR REPLACE TABLE `{prefix}._test_training_images` AS SELECT photo_id, gbif_id, @@ -45,17 +48,18 @@ def main(): CAST(NULL AS INT64) AS image_size, CAST(NULL AS BOOL) AS corrupted FROM `{prefix}.training_images` - WHERE fetch_status = 'downloaded' + WHERE fetch_status = 'downloaded' OR fetch_status IS NULL + ORDER BY RAND() LIMIT {args.n_rows} """).result() - n = client.get_table(f"{prefix}.test_training_images").num_rows + n = client.get_table(f"{prefix}._test_training_images").num_rows print(f" Created: {n} rows, all fetch_status='pending'") - # ── test_training_images_downloads ─────────────────────────────────────── - print(f"Creating {prefix}.test_training_images_downloads 
(empty)...") + # ── _test_training_images_downloads ─────────────────────────────────────── + print(f"Creating {prefix}._test_training_images_downloads (empty)...") client.query(f""" - CREATE OR REPLACE TABLE `{prefix}.test_training_images_downloads` + CREATE OR REPLACE TABLE `{prefix}._test_training_images_downloads` ( dataset_source_uuid STRING, fetch_status STRING, @@ -72,7 +76,7 @@ def main(): print(f" --staging-dir /scratch/$USER/test_download \\") print(f" --num-jobs 1 --task-id 0 \\") print(f" --num-workers 8 --chunk-size {args.n_rows} \\") - print(f" --limit {args.n_rows} --table-prefix test_") + print(f" --limit {args.n_rows} --table-prefix _test_") if __name__ == "__main__": From f80482f2ece4bc9d0a800333608fb7532aecba5b Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Fri, 5 Jun 2026 12:55:43 -0700 Subject: [PATCH 23/26] =?UTF-8?q?feat(jobs):=20replace=20bq=5Fdownload=20w?= =?UTF-8?q?ith=20staged=20download=E2=86=92merge=E2=86=92upload=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each array task now performs the full lifecycle for its MOD(photo_id, 60) slice — download (BQ-checkpointed 10k chunks) → count-verify → stream-merge to task_N.sqfs → count-verify → upload to Arbutus object store → byte-size verify → only then delete local files. Peak scratch is bounded by the array throttle (%K × ~2× task archive), not the dataset size. Replaces the previous all-at-once design which required the whole dataset on scratch before merging and never propagated failure exit codes (last command was notify, so afterok chaining and FAIL mail were broken). Tasks are idempotent (skip if remote archive exists) and resume-safe (prior chunks stashed to a resume_* subdir to avoid the chunk-numbering overwrite; already-attempted images skipped via the downloads table). Validated in production: jobs 43176702/43176730 (global_all_leps_2605). 
Co-Authored-By: Claude Opus 4.8 --- scripts/job_bq_download.sh | 166 +++++++++++++++++++++++++++++++------ 1 file changed, 141 insertions(+), 25 deletions(-) diff --git a/scripts/job_bq_download.sh b/scripts/job_bq_download.sh index 082407e..0e9ebd6 100755 --- a/scripts/job_bq_download.sh +++ b/scripts/job_bq_download.sh @@ -1,52 +1,168 @@ #!/bin/bash -# Download images from training_images BQ table in parallel. -# Runs as a SLURM array job — each task handles MOD(photo_id, NUM_JOBS) = task_id. -# After all tasks finish, run job_bq_pack_squashfs.sh to merge into a single SquashFS. +# ============================================================================= +# job_bq_download.sh — staged BQ image download → SquashFS → object store +# ============================================================================= # -# Usage: -# sbatch job_bq_download.sh -# # or submit and chain the pack job: -# DOWNLOAD_JOB=$(sbatch --parsable job_bq_download.sh) -# sbatch --dependency=afterok:$DOWNLOAD_JOB job_bq_pack_squashfs.sh +# WHAT IT DOES +# ------------ +# Downloads every pending image from a BigQuery training_images table and +# delivers it to the Arbutus object store as one SquashFS archive per array +# task, using only a small, bounded amount of scratch space. +# +# Runs as a SLURM array job of NUM_JOBS tasks. 
Each task owns the dataset +# slice MOD(photo_id, NUM_JOBS) == SLURM_ARRAY_TASK_ID and performs the +# FULL lifecycle for that slice before releasing its scratch space: +# +# stage 1 download pending images (32 threads, BQ-checkpointed every +# 10k-image chunk, each chunk packed to chunk_NNNN.sqfs) +# stage 2 count expected images across all chunk files (unsquashfs -l) +# stage 3 stream-merge chunks into a single task_N.sqfs (sqfstar, zstd) +# stage 4 verify merged image count == stage 2 count +# stage 5 upload task_N.sqfs to the object store, verify byte size +# stage 6 delete local chunks + merged sqfs (only after stage 5 verifies) +# +# Because tasks clean up after themselves, peak scratch usage is +# (concurrent tasks) x (~2x final task sqfs size) +# and is controlled by the array throttle (%K below), NOT the dataset size. +# +# STATUS TRACKING (BigQuery) +# -------------------------- +# Per chunk, results are appended to .training_images_downloads and +# MERGEd into .training_images (fetch_status = downloaded / failed / +# corrupted + image dims). The table is therefore updated live as the job +# runs. Concurrent MERGE serialization conflicts are retried with backoff +# inside download_images.py. +# +# FAILURE / RESUME SEMANTICS +# -------------------------- +# Nothing local is deleted until the uploaded archive is verified, and any +# failure notifies (ntfy) and exits non-zero. 
Resubmitting a failed or +# timed-out task id is always safe: +# - if the remote task_N.sqfs already exists, the task exits 0 immediately +# - already-downloaded images are skipped via training_images_downloads +# - chunk files from a previous partial run are stashed into a resume_* +# subdir (the downloader restarts chunk numbering and would overwrite +# them; the merge step searches recursively and still includes them) +# +# USAGE +# ----- +# sbatch scripts/job_bq_download.sh # full run +# sbatch --array=7 scripts/job_bq_download.sh # re-run one task +# scontrol update JobId= ArrayTaskThrottle=8 # change concurrency live +# +# Before pointing at a new dataset, update STAGING_BASE, S3_DEST and +# --dataset below, and check scratch headroom vs the %K throttle. +# ============================================================================= # #SBATCH --account=def-drolnick #SBATCH --job-name=bq_dl_2605 #SBATCH --cpus-per-task=32 #SBATCH --mem=64G -#SBATCH --time=72:00:00 -#SBATCH --array=0-9 +#SBATCH --time=4:00:00 +#SBATCH --array=0-59%4 #SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_download_%A_%a.out -#SBATCH --mail-type=BEGIN,END,FAIL +#SBATCH --mail-type=END,FAIL #SBATCH --mail-user=hack1996man@gmail.com -NUM_JOBS=10 +set -uo pipefail + +NUM_JOBS=60 TASK_ID=${SLURM_ARRAY_TASK_ID} -# Each job downloads its images into its own staging subdirectory -# These are kept after the job ends (on Lustre) for the pack job to merge STAGING_BASE="/scratch/melabbas/global_all_leps_2605" -STAGING_DIR="${STAGING_BASE}/task_${TASK_ID}" +TASK_DIR="${STAGING_BASE}/task_${TASK_ID}" +OUT_SQFS="${STAGING_BASE}/task_${TASK_ID}.sqfs" +S3_DEST="s3://ami-trainingdata/ai-for-leps/global_all_leps_2605/task_${TASK_ID}.sqfs" +ENDPOINT="https://object-arbutus.cloud.computecanada.ca" +export AWS_PROFILE=ami -echo "=== bq_download task=${TASK_ID}/${NUM_JOBS} started at $(date) ===" -echo "Node: $(hostname)" -echo "Staging dir: ${STAGING_DIR}" +fail() { + echo "ERROR: $1" + notify 
"bq_download task ${TASK_ID}: FAILED" "$1 | logs: bq_download_${SLURM_ARRAY_JOB_ID}_${TASK_ID}.out" + exit 1 +} -mkdir -p "${STAGING_DIR}" +echo "=== bq_download task=${TASK_ID}/${NUM_JOBS} started $(date) on $(hostname) ===" + +# ── Idempotence: skip if this task already completed ───────────────────────── +if s5cmd --endpoint-url "${ENDPOINT}" ls "${S3_DEST}" >/dev/null 2>&1; then + echo "Remote ${S3_DEST} already exists — task previously completed. Exiting." + exit 0 +fi + +mkdir -p "${TASK_DIR}" + +# ── Resume safety: stash chunks from a previous partial run ────────────────── +# download_images.py restarts numbering at chunk_0001 and would overwrite them; +# merge_sqfs_chunks.py searches recursively, so a subdir keeps them mergeable. +if compgen -G "${TASK_DIR}/chunk_*.sqfs" > /dev/null; then + RESUME_DIR="${TASK_DIR}/resume_${SLURM_ARRAY_JOB_ID}" + mkdir -p "${RESUME_DIR}" + mv "${TASK_DIR}"/chunk_*.sqfs "${RESUME_DIR}/" + echo "Resume: moved $(ls ${RESUME_DIR} | wc -l) existing chunks to ${RESUME_DIR}" +fi cd /project/6068129/melabbas/ami-ml module load StdEnv/2023 arrow/17.0.0 source .venv/bin/activate +# ── STAGE 1: download + pack chunks ────────────────────────────────────────── +echo "=== STAGE 1: download (num_jobs=${NUM_JOBS}, chunk_size=10000) $(date) ===" +T0=$SECONDS python src/dataset_tools/bq_squashfs/download_images.py \ - --staging-dir "${STAGING_DIR}" \ + --staging-dir "${TASK_DIR}" \ --num-jobs ${NUM_JOBS} \ --task-id ${TASK_ID} \ --num-workers 32 \ --chunk-size 10000 \ - --dataset global_all_leps_2605 + --dataset global_all_leps_2605 \ + || fail "download_images.py exited non-zero" +echo "STAGE 1 done in $((SECONDS - T0))s" + +# ── STAGE 2: count expected images across all chunks (incl. 
resume_* dirs) ── +echo "=== STAGE 2: count expected images $(date) ===" +EXPECTED=0 +while IFS= read -r chunk; do + N=$(unsquashfs -l "${chunk}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || echo 0) + EXPECTED=$((EXPECTED + N)) +done < <(find "${TASK_DIR}" -name "chunk_*.sqfs") +echo "Expected: ${EXPECTED} images in $(find "${TASK_DIR}" -name 'chunk_*.sqfs' | wc -l) chunks" +[ "${EXPECTED}" -gt 0 ] || fail "no images found in chunks" + +# ── STAGE 3: merge chunks → task sqfs ──────────────────────────────────────── +echo "=== STAGE 3: merge $(date) ===" +T0=$SECONDS +rm -f "${OUT_SQFS}" +python src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py "${TASK_DIR}" \ + | sqfstar -comp zstd -Xcompression-level 3 -b 131072 -no-duplicates "${OUT_SQFS}" +PIPE_STATUS=("${PIPESTATUS[@]}") +[ "${PIPE_STATUS[0]}" -eq 0 ] && [ "${PIPE_STATUS[1]}" -eq 0 ] \ + || fail "merge failed (stream=${PIPE_STATUS[0]} sqfstar=${PIPE_STATUS[1]}) — chunks preserved" +LOCAL_SIZE=$(stat -c%s "${OUT_SQFS}") +echo "STAGE 3 done in $((SECONDS - T0))s — $(numfmt --to=iec ${LOCAL_SIZE})" + +# ── STAGE 4: verify merged image count ─────────────────────────────────────── +ACTUAL=$(unsquashfs -l "${OUT_SQFS}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || echo 0) +echo "Merged: ${ACTUAL}/${EXPECTED} images" +[ "${ACTUAL}" -eq "${EXPECTED}" ] || fail "image count mismatch ${ACTUAL}/${EXPECTED} — chunks preserved" + +# ── STAGE 5: upload + verify remote size ───────────────────────────────────── +echo "=== STAGE 5: upload $(numfmt --to=iec ${LOCAL_SIZE}) $(date) ===" +T0=$SECONDS +s5cmd --endpoint-url "${ENDPOINT}" cp "${OUT_SQFS}" "${S3_DEST}" \ + || fail "s5cmd upload failed — local sqfs + chunks preserved" +UP_SECS=$((SECONDS - T0)) +REMOTE_SIZE=$(s5cmd --endpoint-url "${ENDPOINT}" ls "${S3_DEST}" | awk '{print $3}') +echo "Uploaded in ${UP_SECS}s — local=${LOCAL_SIZE} remote=${REMOTE_SIZE}" +[ "${REMOTE_SIZE}" = "${LOCAL_SIZE}" ] || fail "remote size mismatch — local files preserved" -EXIT_CODE=$? 
-echo "=== bq_download task=${TASK_ID} finished at $(date) (exit=${EXIT_CODE}) ===" +# ── STAGE 6: verified — delete local ───────────────────────────────────────── +rm -f "${OUT_SQFS}" +find "${TASK_DIR}" -name "chunk_*.sqfs" -delete +rmdir "${TASK_DIR}"/resume_* 2>/dev/null || true +echo "Local cleanup done." -notify "bq_download task ${TASK_ID}: done" \ - "Staging: ${STAGING_DIR} | exit=${EXIT_CODE}" +echo "=== bq_download task=${TASK_ID} COMPLETE $(date) ===" +notify "bq_download task ${TASK_ID}/${NUM_JOBS}: done" \ + "${ACTUAL} images → ${S3_DEST} ($(numfmt --to=iec ${LOCAL_SIZE})), upload ${UP_SECS}s" +exit 0 From e7384583e0cc76c289613acdd4ba8db01cce7849 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 11 Jun 2026 12:44:36 -0700 Subject: [PATCH 24/26] chore: remove job_bq_pack_per_task.sh from tracking (superseded by job_bq_download.sh) The staged per-task lifecycle in job_bq_download.sh now handles download -> merge -> verify -> upload -> delete inline. The standalone pack job is no longer needed for new runs. Local copy kept for reference/repacking old global_butterflies_2604 chunks if needed. Co-Authored-By: Claude Sonnet 4.6 --- scripts/job_bq_pack_per_task.sh | 149 -------------------------------- 1 file changed, 149 deletions(-) delete mode 100644 scripts/job_bq_pack_per_task.sh diff --git a/scripts/job_bq_pack_per_task.sh b/scripts/job_bq_pack_per_task.sh deleted file mode 100644 index 297b144..0000000 --- a/scripts/job_bq_pack_per_task.sh +++ /dev/null @@ -1,149 +0,0 @@ -#!/bin/bash -# Merge per-chunk SquashFS files for ONE task into a single task-level SquashFS. -# -# Runs as a SLURM array job — each array task handles one download task_N. -# After all tasks complete, task_0.sqfs … task_9.sqfs are ready for use -# by the webdataset build job. -# -# Safety design: -# 1. Stream all chunks → sqfstar → task_N.sqfs (no deletion during stream) -# 2. Verify: count images in output sqfs == sum of images across all chunks -# 3. 
Only delete chunks if and only if verification passes -# -# This guarantees zero data loss: if sqfstar OOMs or the job times out, all -# chunk files are preserved and the job can be resubmitted with more --mem -# or --time without redownloading anything. -# -# Usage: -# sbatch --array=0-9 job_bq_pack_per_task.sh -# # or chain after download: -# DOWNLOAD_JOB=$(sbatch --parsable --array=0-9 job_bq_download.sh) -# sbatch --array=0-9 --dependency=afterok:$DOWNLOAD_JOB job_bq_pack_per_task.sh -# -#SBATCH --account=def-drolnick -#SBATCH --job-name=bq_pack_task -#SBATCH --cpus-per-task=16 -#SBATCH --mem=192G -#SBATCH --time=6:00:00 -#SBATCH --array=0-9 -#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_pack_task_%A_%a.out -#SBATCH --mail-type=BEGIN,END,FAIL -#SBATCH --mail-user=hack1996man@gmail.com - -set -euo pipefail - -TASK_ID=${SLURM_ARRAY_TASK_ID} -TASK_DIR="/scratch/melabbas/bq_download_staging/task_${TASK_ID}" -OUTPUT_SQFS="/scratch/melabbas/task_${TASK_ID}.sqfs" - -echo "=== bq_pack_task ${TASK_ID} started at $(date) ===" -echo "Node : $(hostname)" -echo "Task dir : ${TASK_DIR}" -echo "Output : ${OUTPUT_SQFS}" -echo "" - -# ── Pre-flight ──────────────────────────────────────────────────────────────── - -TOTAL_CHUNKS=$(find "${TASK_DIR}" -name "chunk_*.sqfs" 2>/dev/null | wc -l) -if [ "${TOTAL_CHUNKS}" -eq 0 ]; then - echo "No chunks found for task ${TASK_ID} — nothing to merge." - notify "bq_pack_task ${TASK_ID}: skipped" "No chunks found in ${TASK_DIR}" - exit 0 -fi -echo "Chunks to merge: ${TOTAL_CHUNKS}" -echo "" - -# Count expected images across all chunks (metadata read only, no extraction) -echo "Counting expected images across all chunks..." 
-EXPECTED_IMAGES=0 -for chunk in $(find "${TASK_DIR}" -name "chunk_*.sqfs" | sort); do - COUNT=$(unsquashfs -l "${chunk}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || echo 0) - echo " $(basename ${chunk}): ${COUNT} images" - EXPECTED_IMAGES=$((EXPECTED_IMAGES + COUNT)) -done -echo "Expected total: ${EXPECTED_IMAGES} images" -echo "" - -# ── Merge ───────────────────────────────────────────────────────────────────── - -rm -f "${OUTPUT_SQFS}" - -cd /project/6068129/melabbas/ami-ml -module load StdEnv/2023 arrow/17.0.0 -source .venv/bin/activate - -echo "=== Streaming chunks → sqfstar at $(date) ===" - -python src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py \ - "${TASK_DIR}" \ - | sqfstar \ - -comp zstd \ - -Xcompression-level 3 \ - -b 131072 \ - -no-duplicates \ - "${OUTPUT_SQFS}" - -# Capture atomically — assigning PIPESTATUS[0] to a variable resets PIPESTATUS, -# so both values must be saved in a single array assignment first. -PIPE_STATUS=("${PIPESTATUS[@]}") -STREAM_EXIT="${PIPE_STATUS[0]}" -SQFSTAR_EXIT="${PIPE_STATUS[1]}" - -echo "" -echo "=== Merge finished at $(date) ===" -echo "stream_chunks_to_tar exit : ${STREAM_EXIT}" -echo "sqfstar exit : ${SQFSTAR_EXIT}" -echo "" - -# ── Verify ──────────────────────────────────────────────────────────────────── - -if [ "${STREAM_EXIT}" -ne 0 ] || [ "${SQFSTAR_EXIT}" -ne 0 ]; then - echo "ERROR: merge failed (stream=${STREAM_EXIT} sqfstar=${SQFSTAR_EXIT})" - echo "Chunks preserved in ${TASK_DIR} — re-submit with more --mem or investigate errors." - notify "bq_pack_task ${TASK_ID}: FAILED" \ - "stream=${STREAM_EXIT} sqfstar=${SQFSTAR_EXIT} — chunks preserved, re-submit" - exit 1 -fi - -if [ ! -f "${OUTPUT_SQFS}" ]; then - echo "ERROR: output sqfs not found at ${OUTPUT_SQFS}" - notify "bq_pack_task ${TASK_ID}: FAILED" "output sqfs missing — chunks preserved" - exit 1 -fi - -echo "Verifying output sqfs image count..." 
-ACTUAL_IMAGES=$(unsquashfs -l "${OUTPUT_SQFS}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || echo 0) -SIZE=$(du -sh "${OUTPUT_SQFS}" | cut -f1) - -echo " Expected : ${EXPECTED_IMAGES} images" -echo " Actual : ${ACTUAL_IMAGES} images" -echo " Size : ${SIZE}" -echo "" - -if [ "${ACTUAL_IMAGES}" -ne "${EXPECTED_IMAGES}" ]; then - echo "ERROR: image count mismatch (expected=${EXPECTED_IMAGES} actual=${ACTUAL_IMAGES})" - echo "Output sqfs may be incomplete. Chunks preserved in ${TASK_DIR}." - echo "Investigate: run audit_sqfs.py or check stream_chunks_to_tar logs above." - notify "bq_pack_task ${TASK_ID}: FAILED (count mismatch)" \ - "expected=${EXPECTED_IMAGES} actual=${ACTUAL_IMAGES} — chunks preserved in ${TASK_DIR}" - exit 1 -fi - -echo "Verification passed: ${ACTUAL_IMAGES} images confirmed in ${OUTPUT_SQFS}" -echo "" - -# ── Safe delete ─────────────────────────────────────────────────────────────── -# Only reached when: stream_exit=0, sqfstar_exit=0, image count matches. - -echo "Deleting chunk files (verified, safe to remove)..." -DELETED=0 -for chunk in $(find "${TASK_DIR}" -name "chunk_*.sqfs" | sort); do - rm -f "${chunk}" - DELETED=$((DELETED + 1)) -done -echo "Deleted ${DELETED} chunk files from ${TASK_DIR}" -echo "" - -echo "=== bq_pack_task ${TASK_ID} done at $(date) ===" -notify "bq_pack_task ${TASK_ID}: done" \ - "${ACTUAL_IMAGES} images in ${OUTPUT_SQFS} (${SIZE}) — ${DELETED} chunks deleted" From b7d3368766a2c6729188817a276b9156985dd072 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 11 Jun 2026 13:01:04 -0700 Subject: [PATCH 25/26] chore: add google-cloud-bigquery to dev dependencies Required by tests/conftest.py and download_images.py. Missing from dev group caused CI test runner to fail with ModuleNotFoundError. 
Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9c7daa2..512a4dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ ipdb = "^0.13.13" python-devtools = "^2" ipykernel = "^6.29.4" pytest = "^8.1.1" +google-cloud-bigquery = "^3.0" +pandas-gbq = "^0.19" [tool.poetry.scripts] ami-dataset = "src.dataset_tools.cli:cli" From caa04de9737f20412f93f08f1e45f6fa83301561 Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Thu, 11 Jun 2026 13:01:11 -0700 Subject: [PATCH 26/26] style: apply black and isort formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes CI lint failure (pre-commit run on changed files). No logic changes — formatting only. Co-Authored-By: Claude Sonnet 4.6 --- .../bq_squashfs/download_images.py | 352 ++++++++++------- tests/conftest.py | 55 ++- tests/dataset_tools/test_download_images.py | 364 +++++++++++++----- tests/dataset_tools/test_merge_sqfs_chunks.py | 209 ++++++---- 4 files changed, 654 insertions(+), 326 deletions(-) diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py index 6c172ab..033c524 100644 --- a/src/dataset_tools/bq_squashfs/download_images.py +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -43,9 +43,9 @@ import pandas as pd import PIL import requests -from PIL import Image from google.api_core import exceptions as google_exceptions from google.cloud import bigquery +from PIL import Image Image.MAX_IMAGE_PIXELS = None @@ -54,9 +54,9 @@ # Retry config for HTTP downloads _RETRY_STATUSES = {429, 500, 502, 503, 504} -_MAX_RETRIES = 5 -_BACKOFF_BASE = 2.0 # seconds -_BACKOFF_MAX = 60.0 # seconds cap +_MAX_RETRIES = 5 +_BACKOFF_BASE = 2.0 # seconds +_BACKOFF_MAX = 60.0 # seconds cap _MERGE_MAX_RETRIES = 10 # BQ MERGE serialization conflicts (concurrent tasks) # Warn if this many chunk sqfs files accumulate (pack job 
falling behind) @@ -64,11 +64,11 @@ DOWNLOADS_SCHEMA = [ bigquery.SchemaField("dataset_source_uuid", "STRING"), - bigquery.SchemaField("fetch_status", "STRING"), - bigquery.SchemaField("image_width", "INTEGER"), - bigquery.SchemaField("image_height", "INTEGER"), - bigquery.SchemaField("image_size", "INTEGER"), - bigquery.SchemaField("corrupted", "BOOLEAN"), + bigquery.SchemaField("fetch_status", "STRING"), + bigquery.SchemaField("image_width", "INTEGER"), + bigquery.SchemaField("image_height", "INTEGER"), + bigquery.SchemaField("image_size", "INTEGER"), + bigquery.SchemaField("corrupted", "BOOLEAN"), ] # Thread-local storage for per-thread requests sessions @@ -103,10 +103,13 @@ def _fetch_with_retry(url: str, dest: Path) -> None: try: resp = session.get(url, timeout=30, stream=True) if resp.status_code in _RETRY_STATUSES and attempt < _MAX_RETRIES: - delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) delay += random.uniform(0, delay * 0.25) - print(f" HTTP {resp.status_code} {url} — retry {attempt+1}/{_MAX_RETRIES} " - f"in {delay:.1f}s", flush=True) + print( + f" HTTP {resp.status_code} {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s", + flush=True, + ) time.sleep(delay) continue resp.raise_for_status() @@ -116,19 +119,25 @@ def _fetch_with_retry(url: str, dest: Path) -> None: return except requests.exceptions.ConnectionError as e: if attempt < _MAX_RETRIES: - delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) delay += random.uniform(0, delay * 0.25) - print(f" ConnectionError {url} — retry {attempt+1}/{_MAX_RETRIES} " - f"in {delay:.1f}s: {e}", flush=True) + print( + f" ConnectionError {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s: {e}", + flush=True, + ) time.sleep(delay) else: raise except requests.exceptions.Timeout: if attempt < _MAX_RETRIES: - delay = min(_BACKOFF_BASE * (2 ** attempt), 
_BACKOFF_MAX) + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) delay += random.uniform(0, delay * 0.25) - print(f" Timeout {url} — retry {attempt+1}/{_MAX_RETRIES} " - f"in {delay:.1f}s", flush=True) + print( + f" Timeout {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s", + flush=True, + ) time.sleep(delay) else: raise @@ -137,18 +146,18 @@ def _fetch_with_retry(url: str, dest: Path) -> None: def download_and_verify(row: dict, staging_dir: Path) -> dict: """Download one image, verify with PIL, return result dict.""" - url = row["absolute_url"] + url = row["absolute_url"] rel_path = row["relative_local_path"] - dest = staging_dir / rel_path + dest = staging_dir / rel_path dest.parent.mkdir(parents=True, exist_ok=True) result = { "dataset_source_uuid": row["dataset_source_uuid"], "fetch_status": None, - "image_width": None, + "image_width": None, "image_height": None, - "image_size": None, - "corrupted": None, + "image_size": None, + "corrupted": None, } try: @@ -162,12 +171,12 @@ def download_and_verify(row: dict, staging_dir: Path) -> dict: with Image.open(dest) as img: img.convert("RGB") result["image_width"], result["image_height"] = img.size - result["image_size"] = dest.stat().st_size - result["corrupted"] = False + result["image_size"] = dest.stat().st_size + result["corrupted"] = False result["fetch_status"] = "downloaded" except (PIL.UnidentifiedImageError, OSError) as e: print(f" Corrupted {url}: {e}", flush=True) - result["corrupted"] = True + result["corrupted"] = True result["fetch_status"] = "corrupted" return result @@ -206,14 +215,19 @@ def write_results_to_bq( ) for attempt in range(max_retries): try: - job = client.load_table_from_dataframe(df, downloads_table, job_config=job_config) + job = client.load_table_from_dataframe( + df, downloads_table, job_config=job_config + ) job.result() return except Exception as e: if attempt < max_retries - 1: delay = 30 * (attempt + 1) - print(f" BQ write failed (attempt 
{attempt+1}/{max_retries}): {e} " - f"— retrying in {delay}s", flush=True) + print( + f" BQ write failed (attempt {attempt+1}/{max_retries}): {e} " + f"— retrying in {delay}s", + flush=True, + ) time.sleep(delay) else: raise @@ -242,7 +256,9 @@ def merge_chunk_into_training_images( Only updates rows that are still 'pending' — safe from parallel tasks. Returns the number of rows updated. """ - to_merge = [r for r in results if r["fetch_status"] in ("downloaded", "corrupted", "failed")] + to_merge = [ + r for r in results if r["fetch_status"] in ("downloaded", "corrupted", "failed") + ] if not to_merge: return 0 @@ -280,10 +296,13 @@ def merge_chunk_into_training_images( except google_exceptions.BadRequest as e: if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES: raise - delay = min(_BACKOFF_BASE * (2 ** attempt), _BACKOFF_MAX) + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) delay += random.uniform(0, delay * 0.5) - print(f" MERGE serialization conflict — retry " - f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", flush=True) + print( + f" MERGE serialization conflict — retry " + f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", + flush=True, + ) time.sleep(delay) finally: client.delete_table(tmp_table, not_found_ok=True) @@ -322,7 +341,9 @@ def get_pending_rows( return [dict(r) for r in client.query(query).result()] -def pack_chunk_to_sqfs(staging_dir: Path, chunk_num: int, num_workers: int = 4) -> Path | None: +def pack_chunk_to_sqfs( + staging_dir: Path, chunk_num: int, num_workers: int = 4 +) -> Path | None: """Pack downloaded images into a per-chunk SquashFS file. Uses bucket subdirs (000/, 001/, ...) 
directly so paths inside the archive @@ -331,7 +352,10 @@ def pack_chunk_to_sqfs(staging_dir: Path, chunk_num: int, num_workers: int = 4) """ bucket_dirs = sorted(d for d in staging_dir.iterdir() if d.is_dir()) if not bucket_dirs: - print(f" No images in staging dir — skipping sqfs pack for chunk {chunk_num}", flush=True) + print( + f" No images in staging dir — skipping sqfs pack for chunk {chunk_num}", + flush=True, + ) return None chunk_sqfs = staging_dir / f"chunk_{chunk_num:04d}.sqfs" @@ -341,18 +365,23 @@ def pack_chunk_to_sqfs(staging_dir: Path, chunk_num: int, num_workers: int = 4) str(chunk_sqfs), "-noappend", "-no-xattrs", - "-comp", "zstd", - "-Xcompression-level", "3", - "-processors", str(num_workers), + "-comp", + "zstd", + "-Xcompression-level", + "3", + "-processors", + str(num_workers), ] - print(f" Packing {len(bucket_dirs)} bucket dirs → {chunk_sqfs.name}...", flush=True) + print( + f" Packing {len(bucket_dirs)} bucket dirs → {chunk_sqfs.name}...", flush=True + ) result = subprocess.run(cmd, check=False) if result.returncode != 0: raise RuntimeError( f"mksquashfs failed with exit code {result.returncode} for chunk {chunk_num}. " f"Staging dir preserved for inspection: {staging_dir}" ) - size_mb = chunk_sqfs.stat().st_size / (1024 ** 2) + size_mb = chunk_sqfs.stat().st_size / (1024**2) print(f" Packed: {chunk_sqfs.name} ({size_mb:.1f} MB)", flush=True) return chunk_sqfs @@ -382,78 +411,113 @@ def warn_chunk_accumulation(staging_dir: Path) -> None: def main(): - parser = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--staging-dir", required=True, - help=( - "Local directory where images are downloaded before packing. " - "Use scratch (e.g. /scratch/$USER/staging), not home — home has " - "a 500k inode quota and each image counts as one inode. " - "chunk_NNNN.sqfs files accumulate here until job_bq_pack_per_task.sh " - "merges them into the final task_N.sqfs." 
- )) - parser.add_argument("--num-jobs", type=int, required=True, - help=( - "Total number of parallel download tasks. Images are partitioned " - "by MOD(photo_id, num_jobs) so each task gets a non-overlapping " - "subset. Must match the SLURM --array range: --num-jobs 10 requires " - "--array=0-9 in the job script. Typical value: 10." - )) - parser.add_argument("--task-id", type=int, required=True, - help=( - "Index of this task (0 to num_jobs-1). In a SLURM array job set " - "this to $SLURM_ARRAY_TASK_ID. This task will download all images " - "where MOD(photo_id, num_jobs) == task_id." - )) - parser.add_argument("--num-workers", type=int, default=32, - help=( - "Number of parallel download threads per task (default: 32). " - "With 10 tasks running simultaneously this means up to 320 " - "concurrent connections to iNaturalist S3. At scale this caused " - "Errno 16 (too many open sockets) — the retry logic handles it " - "but reducing to 16-24 workers per task lowers the error rate." - )) - parser.add_argument("--chunk-size", type=int, default=10000, - help=( - "Number of images to download before packing into a sqfs chunk " - "and clearing the staging dir (default: 10000). Lower values " - "reduce peak inode usage in staging but produce more chunk files " - "for the pack job to merge. Each chunk becomes one " - "chunk_NNNN.sqfs file in --staging-dir." - )) - parser.add_argument("--limit", type=int, default=None, - help=( - "Cap the total number of images queried from BQ. Only for " - "small-scale tests — omit for production runs. " - "Example: --limit 50 --table-prefix test_ for a quick smoke test." - )) - parser.add_argument("--force-redownload", action="store_true", - help=( - "Ignore existing records in training_images_downloads and " - "re-download all images for this task. Use when staging files " - "were deleted after a failed pack job and you need to rebuild " - "the chunks from scratch. 
Without this flag, already-attempted " - "images are skipped via LEFT JOIN." - )) - parser.add_argument("--dataset", default=BQ_DEFAULT_DATASET, - help=( - f"BigQuery dataset name within the leps-ai project " - f"(default: {BQ_DEFAULT_DATASET}). " - f"Example: --dataset global_all_leps_2605" - )) - parser.add_argument("--table-prefix", default="", - help=( - "BQ table name prefix for testing without touching production. " - "Example: --table-prefix test_ reads from test_training_images " - "and writes to test_training_images_downloads. " - "Create test tables first with create_test_tables.py." - )) + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--staging-dir", + required=True, + help=( + "Local directory where images are downloaded before packing. " + "Use scratch (e.g. /scratch/$USER/staging), not home — home has " + "a 500k inode quota and each image counts as one inode. " + "chunk_NNNN.sqfs files accumulate here until job_bq_pack_per_task.sh " + "merges them into the final task_N.sqfs." + ), + ) + parser.add_argument( + "--num-jobs", + type=int, + required=True, + help=( + "Total number of parallel download tasks. Images are partitioned " + "by MOD(photo_id, num_jobs) so each task gets a non-overlapping " + "subset. Must match the SLURM --array range: --num-jobs 10 requires " + "--array=0-9 in the job script. Typical value: 10." + ), + ) + parser.add_argument( + "--task-id", + type=int, + required=True, + help=( + "Index of this task (0 to num_jobs-1). In a SLURM array job set " + "this to $SLURM_ARRAY_TASK_ID. This task will download all images " + "where MOD(photo_id, num_jobs) == task_id." + ), + ) + parser.add_argument( + "--num-workers", + type=int, + default=32, + help=( + "Number of parallel download threads per task (default: 32). " + "With 10 tasks running simultaneously this means up to 320 " + "concurrent connections to iNaturalist S3. 
At scale this caused " + "Errno 16 (too many open sockets) — the retry logic handles it " + "but reducing to 16-24 workers per task lowers the error rate." + ), + ) + parser.add_argument( + "--chunk-size", + type=int, + default=10000, + help=( + "Number of images to download before packing into a sqfs chunk " + "and clearing the staging dir (default: 10000). Lower values " + "reduce peak inode usage in staging but produce more chunk files " + "for the pack job to merge. Each chunk becomes one " + "chunk_NNNN.sqfs file in --staging-dir." + ), + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help=( + "Cap the total number of images queried from BQ. Only for " + "small-scale tests — omit for production runs. " + "Example: --limit 50 --table-prefix test_ for a quick smoke test." + ), + ) + parser.add_argument( + "--force-redownload", + action="store_true", + help=( + "Ignore existing records in training_images_downloads and " + "re-download all images for this task. Use when staging files " + "were deleted after a failed pack job and you need to rebuild " + "the chunks from scratch. Without this flag, already-attempted " + "images are skipped via LEFT JOIN." + ), + ) + parser.add_argument( + "--dataset", + default=BQ_DEFAULT_DATASET, + help=( + f"BigQuery dataset name within the leps-ai project " + f"(default: {BQ_DEFAULT_DATASET}). " + f"Example: --dataset global_all_leps_2605" + ), + ) + parser.add_argument( + "--table-prefix", + default="", + help=( + "BQ table name prefix for testing without touching production. " + "Example: --table-prefix test_ reads from test_training_images " + "and writes to test_training_images_downloads. " + "Create test tables first with create_test_tables.py." 
+ ), + ) args = parser.parse_args() - training_table = f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images" - downloads_table = f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images_downloads" + training_table = f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images" + downloads_table = ( + f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images_downloads" + ) - client = bigquery.Client(project=BQ_PROJECT) + client = bigquery.Client(project=BQ_PROJECT) staging_dir = Path(args.staging_dir) staging_dir.mkdir(parents=True, exist_ok=True) @@ -461,37 +525,52 @@ def main(): print(f"training table : {training_table}", flush=True) print(f"downloads table : {downloads_table}", flush=True) print(f"staging dir : {staging_dir}", flush=True) - print(f"workers : {args.num_workers} chunk_size={args.chunk_size}", flush=True) + print( + f"workers : {args.num_workers} chunk_size={args.chunk_size}", + flush=True, + ) print(flush=True) ensure_downloads_table(client, downloads_table) warn_chunk_accumulation(staging_dir) - print(f"Querying pending rows (force_redownload={args.force_redownload})...", flush=True) + print( + f"Querying pending rows (force_redownload={args.force_redownload})...", + flush=True, + ) rows = get_pending_rows( - client, training_table, downloads_table, - args.num_jobs, args.task_id, - limit=args.limit, force_redownload=args.force_redownload, + client, + training_table, + downloads_table, + args.num_jobs, + args.task_id, + limit=args.limit, + force_redownload=args.force_redownload, ) print(f"{len(rows):,} pending images to download", flush=True) total_downloaded = total_failed = total_corrupted = 0 for chunk_start in range(0, len(rows), args.chunk_size): - chunk = rows[chunk_start : chunk_start + args.chunk_size] - chunk_num = chunk_start // args.chunk_size + 1 + chunk = rows[chunk_start : chunk_start + args.chunk_size] + chunk_num = chunk_start // args.chunk_size + 1 total_chunks = (len(rows) + args.chunk_size - 1) // 
args.chunk_size - print(f"\n[Task {args.task_id}] Chunk {chunk_num}/{total_chunks} " - f"({len(chunk):,} images)...", flush=True) + print( + f"\n[Task {args.task_id}] Chunk {chunk_num}/{total_chunks} " + f"({len(chunk):,} images)...", + flush=True, + ) # Download in parallel - results = [] - t0 = time.perf_counter() + results = [] + t0 = time.perf_counter() n_ok = n_fail = n_corrupt = 0 with ThreadPoolExecutor(max_workers=args.num_workers) as executor: - futures = {executor.submit(download_and_verify, row, staging_dir): row - for row in chunk} + futures = { + executor.submit(download_and_verify, row, staging_dir): row + for row in chunk + } for i, future in enumerate(as_completed(futures)): r = future.result() results.append(r) @@ -503,16 +582,22 @@ def main(): n_corrupt += 1 if (i + 1) % 1000 == 0: elapsed = time.perf_counter() - t0 - print(f" {i+1:,}/{len(chunk):,} " - f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,} " - f"({(i+1)/elapsed:.0f} img/s)", flush=True) + print( + f" {i+1:,}/{len(chunk):,} " + f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,} " + f"({(i+1)/elapsed:.0f} img/s)", + flush=True, + ) elapsed = time.perf_counter() - t0 total_downloaded += n_ok - total_failed += n_fail - total_corrupted += n_corrupt - print(f" Chunk done in {elapsed:.0f}s ({len(chunk)/elapsed:.0f} img/s) " - f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,}", flush=True) + total_failed += n_fail + total_corrupted += n_corrupt + print( + f" Chunk done in {elapsed:.0f}s ({len(chunk)/elapsed:.0f} img/s) " + f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,}", + flush=True, + ) # Append to downloads table (batch load — free tier, parallel-safe) write_results_to_bq(client, results, downloads_table) @@ -530,9 +615,12 @@ def main(): warn_chunk_accumulation(staging_dir) print(f" Staging cleared (chunk sqfs kept)", flush=True) - print(f"\n[Task {args.task_id}] Done. 
" - f"downloaded={total_downloaded:,} failed={total_failed:,} " - f"corrupted={total_corrupted:,}", flush=True) + print( + f"\n[Task {args.task_id}] Done. " + f"downloaded={total_downloaded:,} failed={total_failed:,} " + f"corrupted={total_corrupted:,}", + flush=True, + ) if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py index 9a30dfe..6a60031 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,13 +17,12 @@ import pandas as pd import pytest -from PIL import Image - from google.cloud import bigquery - +from PIL import Image # ── BQ client ───────────────────────────────────────────────────────────────── + @pytest.fixture def mock_bq_client(): """ @@ -48,6 +47,7 @@ def mock_bq_client(): # ── Small CSV dataset ───────────────────────────────────────────────────────── + def _make_small_df() -> pd.DataFrame: """ 5 species × 10 images = 50 rows. @@ -59,26 +59,28 @@ def _make_small_df() -> pd.DataFrame: - one species has exactly 5 images (min_instances boundary) """ species = [ - ("Danaus plexippus", 1001, 101), - ("Vanessa atalanta", 1002, 102), - ("Papilio machaon", 1003, 103), - ("Colias croceus", 1004, 104), - ("Pieris brassicae", 1005, 105), + ("Danaus plexippus", 1001, 101), + ("Vanessa atalanta", 1002, 102), + ("Papilio machaon", 1003, 103), + ("Colias croceus", 1004, 104), + ("Pieris brassicae", 1005, 105), ] rows = [] photo_id = 0 for sp_name, taxon_id, base_gbif in species: for i in range(10): - gbif_id = base_gbif + (i // 2) # 2 images share a gbif_id - rows.append({ - "photo_id": photo_id, - "gbif_id": gbif_id, - "inat_taxon_id": taxon_id, - "species_name": sp_name, - "dataset_source_uuid": f"uuid-{photo_id:04d}", - "relative_local_path": f"{photo_id % 256:03d}/{photo_id:06d}.jpg", - "absolute_url": f"https://inaturalist.org/photos/{photo_id}/original.jpg", - }) + gbif_id = base_gbif + (i // 2) # 2 images share a gbif_id + rows.append( + { + "photo_id": photo_id, + "gbif_id": gbif_id, + "inat_taxon_id": taxon_id, + 
"species_name": sp_name, + "dataset_source_uuid": f"uuid-{photo_id:04d}", + "relative_local_path": f"{photo_id % 256:03d}/{photo_id:06d}.jpg", + "absolute_url": f"https://inaturalist.org/photos/{photo_id}/original.jpg", + } + ) photo_id += 1 return pd.DataFrame(rows) @@ -99,6 +101,7 @@ def small_csv(tmp_path) -> Path: # ── Small SquashFS ──────────────────────────────────────────────────────────── + def _build_small_sqfs(root: Path) -> Path: """ Create a small sqfs with 10 PIL-generated JPEGs. @@ -115,9 +118,18 @@ def _build_small_sqfs(root: Path) -> Path: sqfs_path = root / "test_fixture.sqfs" result = subprocess.run( - ["mksquashfs", str(img_dir), str(sqfs_path), - "-noappend", "-no-xattrs", "-comp", "zstd", - "-Xcompression-level", "1", "-quiet"], + [ + "mksquashfs", + str(img_dir), + str(sqfs_path), + "-noappend", + "-no-xattrs", + "-comp", + "zstd", + "-Xcompression-level", + "1", + "-quiet", + ], capture_output=True, ) if result.returncode != 0: @@ -137,6 +149,7 @@ def small_sqfs(tmp_path_factory) -> Path: # ── SQL file ────────────────────────────────────────────────────────────────── + @pytest.fixture def sample_sql_file(tmp_path) -> Path: """Minimal SQL query file for bq_export.py tests.""" diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py index f11e7a5..c7816d4 100644 --- a/tests/dataset_tools/test_download_images.py +++ b/tests/dataset_tools/test_download_images.py @@ -14,9 +14,9 @@ import src.dataset_tools.bq_squashfs.download_images as di - # ── helpers ─────────────────────────────────────────────────────────────────── + def make_response(status_code: int, content: bytes = b"\xff\xd8\xff\xe0JFIF"): """Build a minimal mock HTTP response.""" resp = MagicMock() @@ -40,13 +40,15 @@ def make_mock_session(*responses): # ── _fetch_with_retry ───────────────────────────────────────────────────────── + class TestFetchWithRetry: def test_success_on_first_attempt(self, tmp_path): dest = tmp_path / "img.jpg" 
session = make_mock_session(make_response(200, b"IMAGE")) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): di._fetch_with_retry("http://example.com/img.jpg", dest) assert dest.read_bytes() == b"IMAGE" assert session.get.call_count == 1 @@ -58,8 +60,9 @@ def test_429_retries_then_succeeds(self, tmp_path): make_response(429), make_response(200, b"IMAGE"), ) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): di._fetch_with_retry("http://example.com/img.jpg", dest) assert session.get.call_count == 3 assert dest.read_bytes() == b"IMAGE" @@ -70,18 +73,22 @@ def test_503_retries_then_succeeds(self, tmp_path): make_response(503), make_response(200, b"IMAGE"), ) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): di._fetch_with_retry("http://example.com/img.jpg", dest) assert session.get.call_count == 2 def test_connection_error_errno16_retries(self, tmp_path): """Errno 16 (device/resource busy — too many sockets) retries.""" dest = tmp_path / "img.jpg" - errno16 = requests.exceptions.ConnectionError("[Errno 16] Device or resource busy") + errno16 = requests.exceptions.ConnectionError( + "[Errno 16] Device or resource busy" + ) session = make_mock_session(errno16, errno16, make_response(200, b"IMAGE")) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): di._fetch_with_retry("http://example.com/img.jpg", dest) assert session.get.call_count == 3 @@ -91,8 +98,9 @@ def test_timeout_retries_then_succeeds(self, tmp_path): requests.exceptions.Timeout(), make_response(200, b"IMAGE"), 
) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): di._fetch_with_retry("http://example.com/img.jpg", dest) assert session.get.call_count == 2 @@ -100,9 +108,9 @@ def test_404_raises_immediately_no_retry(self, tmp_path): """404 is not in RETRY_STATUSES — raises without retrying.""" dest = tmp_path / "img.jpg" session = make_mock_session(make_response(404)) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"), \ - pytest.raises(requests.exceptions.HTTPError): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ), pytest.raises(requests.exceptions.HTTPError): di._fetch_with_retry("http://example.com/img.jpg", dest) assert session.get.call_count == 1 @@ -111,23 +119,26 @@ def test_exhausted_retries_on_connection_error_raises(self, tmp_path): dest = tmp_path / "img.jpg" err = requests.exceptions.ConnectionError("connection refused") session = make_mock_session(*([err] * (di._MAX_RETRIES + 1))) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"), \ - pytest.raises(requests.exceptions.ConnectionError): + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ), pytest.raises(requests.exceptions.ConnectionError): di._fetch_with_retry("http://example.com/img.jpg", dest) assert session.get.call_count == di._MAX_RETRIES + 1 def test_exhausted_retries_on_timeout_raises(self, tmp_path): dest = tmp_path / "img.jpg" - session = make_mock_session(*([requests.exceptions.Timeout()] * (di._MAX_RETRIES + 1))) - with patch.object(di, "_get_session", return_value=session), \ - patch("time.sleep"), \ - pytest.raises(requests.exceptions.Timeout): + session = make_mock_session( + *([requests.exceptions.Timeout()] * (di._MAX_RETRIES + 1)) + ) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ), 
pytest.raises(requests.exceptions.Timeout): di._fetch_with_retry("http://example.com/img.jpg", dest) # ── download_and_verify ─────────────────────────────────────────────────────── + class TestDownloadAndVerify: ROW = { @@ -138,16 +149,20 @@ class TestDownloadAndVerify: def test_success(self, tmp_path): """Valid JPEG → fetch_status=downloaded, dimensions populated.""" - from PIL import Image import io + + from PIL import Image + buf = io.BytesIO() Image.new("RGB", (64, 48)).save(buf, format="JPEG") jpeg_bytes = buf.getvalue() with patch.object(di, "_fetch_with_retry") as mock_fetch: + def write_file(url, dest): dest.parent.mkdir(parents=True, exist_ok=True) dest.write_bytes(jpeg_bytes) + mock_fetch.side_effect = write_file result = di.download_and_verify(self.ROW, tmp_path) @@ -160,8 +175,9 @@ def write_file(url, dest): def test_network_failure_recorded_as_failed(self, tmp_path): """Network error → fetch_status=failed, no image on disk.""" - with patch.object(di, "_fetch_with_retry", - side_effect=Exception("connection refused")): + with patch.object( + di, "_fetch_with_retry", side_effect=Exception("connection refused") + ): result = di.download_and_verify(self.ROW, tmp_path) assert result["fetch_status"] == "failed" @@ -170,9 +186,11 @@ def test_network_failure_recorded_as_failed(self, tmp_path): def test_corrupted_image_recorded_as_corrupted(self, tmp_path): """Truncated/invalid image bytes → fetch_status=corrupted.""" with patch.object(di, "_fetch_with_retry") as mock_fetch: + def write_garbage(url, dest): dest.parent.mkdir(parents=True, exist_ok=True) dest.write_bytes(b"not an image at all") + mock_fetch.side_effect = write_garbage result = di.download_and_verify(self.ROW, tmp_path) @@ -184,11 +202,18 @@ def write_garbage(url, dest): # ── write_results_to_bq ─────────────────────────────────────────────────────── + class TestWriteResultsToBq: RESULTS = [ - {"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, 
"image_size": 5000, "corrupted": False}, + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + }, ] def test_success(self): @@ -219,6 +244,7 @@ def test_exhausted_retries_raises(self): # ── pack_chunk_to_sqfs ──────────────────────────────────────────────────────── + class TestPackChunkToSqfs: def test_success(self, tmp_path): @@ -259,6 +285,7 @@ def test_empty_staging_returns_none(self, tmp_path): # ── merge_chunk_into_training_images ───────────────────────────────────────── + class TestMergeChunkIntoTrainingImages: def _make_client(self, updated_count: int = 1) -> MagicMock: @@ -279,9 +306,16 @@ def test_empty_list_skips_merge(self): def test_downloaded_triggers_merge(self): """downloaded rows → temp table load + MERGE + cleanup.""" client = self._make_client(updated_count=1) - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, - "image_size": 5000, "corrupted": False}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + ] n = di.merge_chunk_into_training_images(client, results, "t", "d") assert n == 1 assert client.load_table_from_dataframe.call_count == 1 @@ -291,9 +325,16 @@ def test_downloaded_triggers_merge(self): def test_corrupted_triggers_merge(self): """corrupted rows → merged with fetch_status='corrupted'.""" client = self._make_client(updated_count=1) - results = [{"dataset_source_uuid": "u1", "fetch_status": "corrupted", - "image_width": None, "image_height": None, - "image_size": None, "corrupted": True}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "corrupted", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": True, + } + ] n = di.merge_chunk_into_training_images(client, results, "t", "d") assert n == 1 assert 
client.load_table_from_dataframe.call_count == 1 @@ -303,9 +344,16 @@ def test_failed_triggers_merge(self): training_images. Permanent failures are excluded from future re-runs via WHERE fetch_status='pending' without needing the LEFT JOIN.""" client = self._make_client(updated_count=1) - results = [{"dataset_source_uuid": "u1", "fetch_status": "failed", - "image_width": None, "image_height": None, - "image_size": None, "corrupted": None}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "failed", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + } + ] n = di.merge_chunk_into_training_images(client, results, "t", "d") assert n == 1 assert client.load_table_from_dataframe.call_count == 1 @@ -316,22 +364,43 @@ def test_all_three_statuses_merged_together(self): """Mixed chunk — downloaded, corrupted, failed — all three trigger one MERGE.""" client = self._make_client(updated_count=3) results = [ - {"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, "image_size": 5000, "corrupted": False}, - {"dataset_source_uuid": "u2", "fetch_status": "corrupted", - "image_width": None, "image_height": None, "image_size": 500, "corrupted": True}, - {"dataset_source_uuid": "u3", "fetch_status": "failed", - "image_width": None, "image_height": None, "image_size": None, "corrupted": None}, + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + }, + { + "dataset_source_uuid": "u2", + "fetch_status": "corrupted", + "image_width": None, + "image_height": None, + "image_size": 500, + "corrupted": True, + }, + { + "dataset_source_uuid": "u3", + "fetch_status": "failed", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + }, ] n = di.merge_chunk_into_training_images(client, results, "t", "d") assert n == 3 - assert 
client.load_table_from_dataframe.call_count == 1 # one temp table for all 3 - assert client.query.call_count == 1 # one MERGE - assert client.delete_table.call_count == 1 # one cleanup + assert ( + client.load_table_from_dataframe.call_count == 1 + ) # one temp table for all 3 + assert client.query.call_count == 1 # one MERGE + assert client.delete_table.call_count == 1 # one cleanup def test_failed_rows_included_in_temp_table(self): """Verify the dataframe passed to BQ includes the failed row.""" import pandas as pd + client = self._make_client() captured_df = {} @@ -342,18 +411,30 @@ def capture_load(df, table, **kwargs): client.load_table_from_dataframe.side_effect = capture_load results = [ - {"dataset_source_uuid": "ok", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, "image_size": 1000, "corrupted": False}, - {"dataset_source_uuid": "dead", "fetch_status": "failed", - "image_width": None, "image_height": None, "image_size": None, "corrupted": None}, + { + "dataset_source_uuid": "ok", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + }, + { + "dataset_source_uuid": "dead", + "fetch_status": "failed", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + }, ] di.merge_chunk_into_training_images(client, results, "t", "d") df = captured_df["data"] - assert len(df) == 2 # both rows in temp table + assert len(df) == 2 # both rows in temp table statuses = set(df["fetch_status"].tolist()) assert "downloaded" in statuses - assert "failed" in statuses # failed row present + assert "failed" in statuses # failed row present def test_temp_table_deleted_even_on_merge_failure(self): """Temp table must be cleaned up even if the MERGE query fails.""" @@ -361,9 +442,16 @@ def test_temp_table_deleted_even_on_merge_failure(self): client.load_table_from_dataframe.return_value.result.return_value = None client.query.side_effect = Exception("MERGE 
failed") - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, - "image_size": 5000, "corrupted": False}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + ] with pytest.raises(Exception, match="MERGE failed"): di.merge_chunk_into_training_images( @@ -374,19 +462,32 @@ def test_temp_table_deleted_even_on_merge_failure(self): # ── get_pending_rows / MOD split ───────────────────────────────────────────── + class TestModSplit: """Verify that num_jobs/task_id partitioning is correct and complete.""" - def _make_client(self, photo_ids: list[int], num_jobs: int, task_id: int) -> MagicMock: + def _make_client( + self, photo_ids: list[int], num_jobs: int, task_id: int + ) -> MagicMock: """Return a mock BQ client that filters photo_ids by MOD split.""" matching = [ - {"dataset_source_uuid": f"uuid-{p}", "absolute_url": f"http://x/{p}", - "relative_local_path": f"000/{p}.jpg"} - for p in photo_ids if p % num_jobs == task_id + { + "dataset_source_uuid": f"uuid-{p}", + "absolute_url": f"http://x/{p}", + "relative_local_path": f"000/{p}.jpg", + } + for p in photo_ids + if p % num_jobs == task_id ] client = MagicMock() client.query.return_value.result.return_value = [ - MagicMock(**{k: v for k, v in row.items()}, **{"__iter__": lambda self: iter(row.items()), "keys": lambda self: row.keys()}) + MagicMock( + **{k: v for k, v in row.items()}, + **{ + "__iter__": lambda self: iter(row.items()), + "keys": lambda self: row.keys(), + }, + ) for row in matching ] # Simpler: just return dicts directly via side_effect @@ -421,7 +522,7 @@ def test_task_gets_correct_subset(self): """Task 3 of 10 should only see photo_ids ending in 3.""" photo_ids = list(range(50)) expected = [p for p in photo_ids if p % 10 == 3] # 3, 13, 23, 33, 43 - actual = [p for p in photo_ids if p % 10 == 3] + actual = [p for p in 
photo_ids if p % 10 == 3] assert actual == expected assert all(p % 10 == 3 for p in actual) @@ -505,6 +606,7 @@ def test_limit_applied_to_query(self): # ── Multi-task distribution and merge ──────────────────────────────────────── + class TestMultiTaskDistributionAndMerge: """ Verify correct behaviour when multiple tasks run in parallel: @@ -514,13 +616,12 @@ class TestMultiTaskDistributionAndMerge: - training_images MERGE is correct when multiple tasks write concurrently """ - PHOTO_IDS = list(range(50)) # simulate 50 images + PHOTO_IDS = list(range(50)) # simulate 50 images def _partition(self, num_jobs: int) -> dict[int, list[int]]: """Return {task_id: [photo_ids]} for all tasks.""" return { - t: [p for p in self.PHOTO_IDS if p % num_jobs == t] - for t in range(num_jobs) + t: [p for p in self.PHOTO_IDS if p % num_jobs == t] for t in range(num_jobs) } # ── partitioning ────────────────────────────────────────────────────────── @@ -555,11 +656,22 @@ def test_task0_completion_does_not_affect_task1_query(self): def test_non_sequential_photo_ids_still_partition_correctly(self): """Real photo_ids from iNat are large non-sequential ints — MOD still works.""" - real_ids = [487851, 7047265, 8233026, 8427425, 10239192, - 17327318, 21463254, 27648248, 36757555, 41676327] + real_ids = [ + 487851, + 7047265, + 8233026, + 8427425, + 10239192, + 17327318, + 21463254, + 27648248, + 36757555, + 41676327, + ] for num_jobs in [2, 5, 10]: - parts = {t: [p for p in real_ids if p % num_jobs == t] - for t in range(num_jobs)} + parts = { + t: [p for p in real_ids if p % num_jobs == t] for t in range(num_jobs) + } combined = [p for task in parts.values() for p in task] assert sorted(combined) == sorted(real_ids) @@ -579,15 +691,29 @@ def test_downloads_table_appends_are_independent(self): client = MagicMock() client.load_table_from_dataframe.return_value.result.return_value = None - task0_results = [{"dataset_source_uuid": f"uuid-{p}", "fetch_status": "downloaded", - "image_width": 100, 
"image_height": 80, - "image_size": 5000, "corrupted": False} - for p in range(0, 10, 2)] # even photo_ids - - task1_results = [{"dataset_source_uuid": f"uuid-{p}", "fetch_status": "downloaded", - "image_width": 100, "image_height": 80, - "image_size": 5000, "corrupted": False} - for p in range(1, 10, 2)] # odd photo_ids + task0_results = [ + { + "dataset_source_uuid": f"uuid-{p}", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + for p in range(0, 10, 2) + ] # even photo_ids + + task1_results = [ + { + "dataset_source_uuid": f"uuid-{p}", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + for p in range(1, 10, 2) + ] # odd photo_ids # both tasks write to the same table — no conflict because WRITE_APPEND di.write_results_to_bq(client, task0_results, "downloads_table") @@ -612,14 +738,28 @@ def test_merge_from_two_tasks_updates_correct_rows(self): job1.dml_stats.updated_row_count = 5 client.query.side_effect = [job0, job1] - task0_results = [{"dataset_source_uuid": f"uuid-{i}", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False} - for i in range(5)] - task1_results = [{"dataset_source_uuid": f"uuid-{i+5}", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False} - for i in range(5)] + task0_results = [ + { + "dataset_source_uuid": f"uuid-{i}", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + for i in range(5) + ] + task1_results = [ + { + "dataset_source_uuid": f"uuid-{i+5}", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + for i in range(5) + ] n0 = di.merge_chunk_into_training_images( client, task0_results, "training_table", "downloads_table" @@ -630,7 
+770,7 @@ def test_merge_from_two_tasks_updates_correct_rows(self): assert n0 == 5 assert n1 == 5 - assert client.query.call_count == 2 # one MERGE per task + assert client.query.call_count == 2 # one MERGE per task assert client.delete_table.call_count == 2 # temp table cleaned per task def test_total_coverage_after_all_tasks_complete(self): @@ -648,6 +788,7 @@ def test_total_coverage_after_all_tasks_complete(self): # ── warn_chunk_accumulation ─────────────────────────────────────────────────── + class TestWarnChunkAccumulation: def test_no_warning_below_threshold(self, tmp_path, capsys): @@ -665,6 +806,7 @@ def test_warning_at_threshold(self, tmp_path, capsys): # ── --dataset flag and NULL fetch_status handling ───────────────────────────── + class TestDatasetFlagAndNullFetchStatus: """Tests for --dataset CLI flag and NULL fetch_status support (global_all_leps_2605).""" @@ -725,12 +867,20 @@ def test_merge_condition_handles_null_fetch_status(self): job.dml_stats.updated_row_count = 1 client.query.return_value = job - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + ] di.merge_chunk_into_training_images( - client, results, + client, + results, training_table="leps-ai.global_all_leps_2605.training_images", downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", ) @@ -750,12 +900,20 @@ def test_tmp_table_uses_same_dataset_as_training_table(self): job.dml_stats.updated_row_count = 1 client.query.return_value = job - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 64, + 
"image_height": 48, + "image_size": 1000, + "corrupted": False, + } + ] di.merge_chunk_into_training_images( - client, results, + client, + results, training_table="leps-ai.global_all_leps_2605.training_images", downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", ) @@ -773,15 +931,25 @@ def test_tmp_table_not_in_wrong_dataset_when_using_new_dataset(self): job.dml_stats.updated_row_count = 1 client.query.return_value = job - results = [{"dataset_source_uuid": "u1", "fetch_status": "downloaded", - "image_width": 64, "image_height": 48, - "image_size": 1000, "corrupted": False}] + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + ] di.merge_chunk_into_training_images( - client, results, + client, + results, training_table="leps-ai.global_all_leps_2605.training_images", downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", ) tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] - assert "global_butterflies_2604" not in tmp_table_arg # must not leak old dataset + assert ( + "global_butterflies_2604" not in tmp_table_arg + ) # must not leak old dataset diff --git a/tests/dataset_tools/test_merge_sqfs_chunks.py b/tests/dataset_tools/test_merge_sqfs_chunks.py index d5d766f..e090533 100644 --- a/tests/dataset_tools/test_merge_sqfs_chunks.py +++ b/tests/dataset_tools/test_merge_sqfs_chunks.py @@ -19,9 +19,9 @@ import src.dataset_tools.bq_squashfs.merge_sqfs_chunks as sct - # ── helpers ─────────────────────────────────────────────────────────────────── + def make_chunk_sqfs(staging_dir: Path, chunk_num: int) -> Path: """Create a dummy chunk_NNNN.sqfs file (content doesn't matter — squashfuse is mocked).""" p = staging_dir / f"chunk_{chunk_num:04d}.sqfs" @@ -47,6 +47,7 @@ def read_tar_from_bytes(data: bytes) -> list[str]: # ── stream_dir_to_tar ───────────────────────────────────────────────────────── + 
class TestStreamDirToTar: def test_files_added_with_relative_paths(self, tmp_path): @@ -66,9 +67,9 @@ def test_dirs_included_files_counted(self, tmp_path): buf = io.BytesIO() with tarfile.open(fileobj=buf, mode="w:") as tf: count = sct.stream_dir_to_tar(tf, str(mnt)) - assert count == 1 # only the .jpg, not the dir + assert count == 1 # only the .jpg, not the dir members = read_tar_from_bytes(buf.getvalue()) - assert "000" in members # dir entry + assert "000" in members # dir entry assert "000/img.jpg" in members # file entry def test_empty_mount_dir_returns_zero(self, tmp_path): @@ -98,6 +99,7 @@ def test_multiple_bucket_dirs_all_streamed(self, tmp_path): # ── squashfuse_mount / unmount ──────────────────────────────────────────────── + class TestSquashfuseMount: def test_successful_mount_returns_true(self): @@ -128,6 +130,7 @@ def test_unmount_calls_fusermount(self): # ── main: no chunks ─────────────────────────────────────────────────────────── + class TestNoChunks: def test_empty_staging_dir_exits_with_error(self, tmp_path, capsys): @@ -149,6 +152,7 @@ def test_missing_staging_dir_exits_with_error(self, tmp_path, capsys): # ── main: dry run ───────────────────────────────────────────────────────────── + class TestDryRun: def test_dry_run_lists_chunks_no_streaming(self, tmp_path, capsys): @@ -158,8 +162,9 @@ def test_dry_run_lists_chunks_no_streaming(self, tmp_path, capsys): c1 = make_chunk_sqfs(staging, 1) c2 = make_chunk_sqfs(staging, 2) - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging), "--dry-run"]), \ - patch.object(sct, "squashfuse_mount") as mock_mount: + with patch( + "sys.argv", ["merge_sqfs_chunks.py", str(staging), "--dry-run"] + ), patch.object(sct, "squashfuse_mount") as mock_mount: sct.main() mock_mount.assert_not_called() # no mounting in dry run @@ -185,13 +190,17 @@ def test_dry_run_lists_in_sorted_order(self, tmp_path, capsys): # ── main: streaming ─────────────────────────────────────────────────────────── + class TestStreaming: 
- def _run_stream(self, staging: Path, extra_args: list[str] = []) -> tuple[bytes, str]: + def _run_stream( + self, staging: Path, extra_args: list[str] = [] + ) -> tuple[bytes, str]: """Run main(), capture stdout bytes and stderr text.""" stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)] + extra_args), \ - patch("sys.stdout") as mock_stdout: + with patch( + "sys.argv", ["merge_sqfs_chunks.py", str(staging)] + extra_args + ), patch("sys.stdout") as mock_stdout: mock_stdout.buffer = stdout_buf sct.main() return stdout_buf.getvalue(), "" @@ -209,14 +218,21 @@ def test_single_chunk_produces_valid_tar(self, tmp_path): (fake_mnt / "000" / "img.jpg").write_bytes(b"JPEG") stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=True), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", return_value=1) as mock_stream: + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=1 + ) as mock_stream: mock_stdout.buffer = stdout_buf sct.main() @@ -236,14 +252,21 @@ def fake_stream(tar, mnt_dir): return 5 stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=True), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - 
patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", side_effect=fake_stream): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", side_effect=fake_stream + ): mock_stdout.buffer = stdout_buf sct.main() @@ -258,22 +281,30 @@ def test_chunks_always_preserved_after_stream(self, tmp_path): chunk = make_chunk_sqfs(staging, 1) stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=True), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", return_value=1): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=1 + ): mock_stdout.buffer = stdout_buf sct.main() - assert chunk.exists() # always preserved — job script deletes after verify + assert chunk.exists() # always preserved — job script deletes after verify # ── main: error handling ────────────────────────────────────────────────────── + class TestErrorHandling: def test_failed_mount_skipped_continues_to_next_chunk(self, tmp_path, capsys): @@ -287,14 +318,21 @@ def test_failed_mount_skipped_continues_to_next_chunk(self, tmp_path, 
capsys): mount_results = [False, True] stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", side_effect=mount_results), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", return_value=5): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", side_effect=mount_results + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=5 + ): mock_stdout.buffer = stdout_buf with pytest.raises(SystemExit) as exc: sct.main() @@ -312,12 +350,17 @@ def test_all_mounts_fail_exits_nonzero(self, tmp_path): make_chunk_sqfs(staging, 2) stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=False), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=False + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ): mock_stdout.buffer = stdout_buf with pytest.raises(SystemExit) as exc: sct.main() @@ -331,14 +374,21 @@ def test_empty_chunk_exits_nonzero(self, tmp_path): make_chunk_sqfs(staging, 1) stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as 
mock_stdout, \ - patch.object(sct, "squashfuse_mount", return_value=True), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", return_value=0): # 0 images + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=0 + ): # 0 images mock_stdout.buffer = stdout_buf with pytest.raises(SystemExit) as exc: sct.main() @@ -351,22 +401,22 @@ def test_squashfuse_retry_on_transient_failure(self, tmp_path): make_chunk_sqfs(staging, 1) fail = MagicMock(returncode=1, stderr="fuse: temporary error") - ok = MagicMock(returncode=0) + ok = MagicMock(returncode=0) - with patch("subprocess.run", side_effect=[fail, ok]) as mock_run, \ - patch("time.sleep"): + with patch("subprocess.run", side_effect=[fail, ok]) as mock_run, patch( + "time.sleep" + ): result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake", retries=1) assert result is True - assert mock_run.call_count == 2 # one fail + one retry + assert mock_run.call_count == 2 # one fail + one retry def test_squashfuse_unmount_retries_on_failure(self, capsys): """fusermount failure retried; logs warning instead of raising.""" fail = MagicMock(returncode=1, stderr="resource busy") - ok = MagicMock(returncode=0) + ok = MagicMock(returncode=0) - with patch("subprocess.run", side_effect=[fail, ok]), \ - patch("time.sleep"): + with patch("subprocess.run", side_effect=[fail, ok]), patch("time.sleep"): result = sct.squashfuse_unmount("/mnt/fake", retries=2) assert result is True # succeeded on second attempt @@ -374,8 +424,7 @@ def 
test_squashfuse_unmount_retries_on_failure(self, capsys): def test_squashfuse_unmount_warns_on_all_failures(self, capsys): """All unmount retries exhausted → warning logged, no raise.""" fail = MagicMock(returncode=1, stderr="resource busy") - with patch("subprocess.run", return_value=fail), \ - patch("time.sleep"): + with patch("subprocess.run", return_value=fail), patch("time.sleep"): result = sct.squashfuse_unmount("/mnt/fake", retries=2) assert result is False assert "WARNING" in capsys.readouterr().err @@ -386,8 +435,9 @@ def test_delete_after_stream_flag_removed(self, tmp_path): staging.mkdir() make_chunk_sqfs(staging, 1) - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging), - "--delete-after-stream"]): + with patch( + "sys.argv", ["merge_sqfs_chunks.py", str(staging), "--delete-after-stream"] + ): with pytest.raises(SystemExit) as exc: sct.main() assert exc.value.code == 2 # argparse unrecognised argument @@ -407,17 +457,26 @@ def fake_mount(sqfs_path, mnt_dir): return True stdout_buf = io.BytesIO() - with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), \ - patch("sys.stdout") as mock_stdout, \ - patch.object(sct, "squashfuse_mount", side_effect=fake_mount), \ - patch.object(sct, "squashfuse_unmount"), \ - patch("tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base")), \ - patch("os.makedirs"), \ - patch("os.rmdir"), \ - patch.object(sct, "stream_dir_to_tar", return_value=1): + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", side_effect=fake_mount + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=1 + ): mock_stdout.buffer = stdout_buf sct.main() assert processed_order == [ - "chunk_0001.sqfs", "chunk_0002.sqfs", "chunk_0003.sqfs" + "chunk_0001.sqfs", + 
"chunk_0002.sqfs", + "chunk_0003.sqfs", ]