diff --git a/.gitignore b/.gitignore index 534c4ed..4aae6f3 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,18 @@ cython_debug/ # Weights and Biases files wandb/ + +# Raw image datasets (too many inodes for home filesystem) +data/vermont_butterflies/ +data/vermont_species/ + +# Job output logs and stats +stats/ +logs/ + +# Scratch/intermediate data outputs +data/bq_pipeline/ +data/predictions/ +data/pipeline_test/ +data/split_test/ +data/wds_test_splits/ diff --git a/pyproject.toml b/pyproject.toml index 9c7daa2..512a4dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ ipdb = "^0.13.13" python-devtools = "^2" ipykernel = "^6.29.4" pytest = "^8.1.1" +google-cloud-bigquery = "^3.0" +pandas-gbq = "^0.19" [tool.poetry.scripts] ami-dataset = "src.dataset_tools.cli:cli" diff --git a/scripts/job_bq_download.sh b/scripts/job_bq_download.sh new file mode 100755 index 0000000..0e9ebd6 --- /dev/null +++ b/scripts/job_bq_download.sh @@ -0,0 +1,168 @@ +#!/bin/bash +# ============================================================================= +# job_bq_download.sh — staged BQ image download → SquashFS → object store +# ============================================================================= +# +# WHAT IT DOES +# ------------ +# Downloads every pending image from a BigQuery training_images table and +# delivers it to the Arbutus object store as one SquashFS archive per array +# task, using only a small, bounded amount of scratch space. +# +# Runs as a SLURM array job of NUM_JOBS tasks. Each task owns the dataset +# slice MOD(photo_id, NUM_JOBS) == SLURM_ARRAY_TASK_ID and performs the +# FULL lifecycle for that slice before releasing its scratch space: +# +# stage 1 download pending images (32 threads, BQ-checkpointed every +# 10k-image chunk, each chunk packed to chunk_NNNN.sqfs) +# stage 2 count expected images across all chunk files (unsquashfs -l) +# stage 3 stream-merge chunks into a single task_N.sqfs (sqfstar, zstd) +# stage 4 verify merged image count == stage 2 count +# stage 5 upload task_N.sqfs to the object store, verify byte size +# stage 6 delete local chunks + merged sqfs (only after stage 5 verifies) +# +# Because tasks clean up after themselves, peak scratch usage is +# (concurrent tasks) x (~2x final task sqfs size) +# and is controlled by the array throttle (%K below), NOT the dataset size. +# +# STATUS TRACKING (BigQuery) +# -------------------------- +# Per chunk, results are appended to .training_images_downloads and +# MERGEd into .training_images (fetch_status = downloaded / failed / +# corrupted + image dims). The table is therefore updated live as the job +# runs. Concurrent MERGE serialization conflicts are retried with backoff +# inside download_images.py. +# +# FAILURE / RESUME SEMANTICS +# -------------------------- +# Nothing local is deleted until the uploaded archive is verified, and any +# failure notifies (ntfy) and exits non-zero. Resubmitting a failed or +# timed-out task id is always safe: +# - if the remote task_N.sqfs already exists, the task exits 0 immediately +# - already-downloaded images are skipped via training_images_downloads +# - chunk files from a previous partial run are stashed into a resume_* +# subdir (the downloader restarts chunk numbering and would overwrite +# them; the merge step searches recursively and still includes them) +# +# USAGE +# ----- +# sbatch scripts/job_bq_download.sh # full run +# sbatch --array=7 scripts/job_bq_download.sh # re-run one task +# scontrol update JobId= ArrayTaskThrottle=8 # change concurrency live +# +# Before pointing at a new dataset, update STAGING_BASE, S3_DEST and +# --dataset below, and check scratch headroom vs the %K throttle. +# ============================================================================= +# +#SBATCH --account=def-drolnick +#SBATCH --job-name=bq_dl_2605 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=64G +#SBATCH --time=4:00:00 +#SBATCH --array=0-59%4 +#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_download_%A_%a.out +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=hack1996man@gmail.com + +set -uo pipefail + +NUM_JOBS=60 +TASK_ID=${SLURM_ARRAY_TASK_ID} + +STAGING_BASE="/scratch/melabbas/global_all_leps_2605" +TASK_DIR="${STAGING_BASE}/task_${TASK_ID}" +OUT_SQFS="${STAGING_BASE}/task_${TASK_ID}.sqfs" +S3_DEST="s3://ami-trainingdata/ai-for-leps/global_all_leps_2605/task_${TASK_ID}.sqfs" +ENDPOINT="https://object-arbutus.cloud.computecanada.ca" +export AWS_PROFILE=ami + +fail() { + echo "ERROR: $1" + notify "bq_download task ${TASK_ID}: FAILED" "$1 | logs: bq_download_${SLURM_ARRAY_JOB_ID}_${TASK_ID}.out" + exit 1 +} + +echo "=== bq_download task=${TASK_ID}/${NUM_JOBS} started $(date) on $(hostname) ===" + +# ── Idempotence: skip if this task already completed ───────────────────────── +if s5cmd --endpoint-url "${ENDPOINT}" ls "${S3_DEST}" >/dev/null 2>&1; then + echo "Remote ${S3_DEST} already exists — task previously completed. Exiting." + exit 0 +fi + +mkdir -p "${TASK_DIR}" + +# ── Resume safety: stash chunks from a previous partial run ────────────────── +# download_images.py restarts numbering at chunk_0001 and would overwrite them; +# merge_sqfs_chunks.py searches recursively, so a subdir keeps them mergeable. +if compgen -G "${TASK_DIR}/chunk_*.sqfs" > /dev/null; then + RESUME_DIR="${TASK_DIR}/resume_${SLURM_ARRAY_JOB_ID}" + mkdir -p "${RESUME_DIR}" + mv "${TASK_DIR}"/chunk_*.sqfs "${RESUME_DIR}/" + echo "Resume: moved $(ls ${RESUME_DIR} | wc -l) existing chunks to ${RESUME_DIR}" +fi + +cd /project/6068129/melabbas/ami-ml +module load StdEnv/2023 arrow/17.0.0 +source .venv/bin/activate + +# ── STAGE 1: download + pack chunks ────────────────────────────────────────── +echo "=== STAGE 1: download (num_jobs=${NUM_JOBS}, chunk_size=10000) $(date) ===" +T0=$SECONDS +python src/dataset_tools/bq_squashfs/download_images.py \ + --staging-dir "${TASK_DIR}" \ + --num-jobs ${NUM_JOBS} \ + --task-id ${TASK_ID} \ + --num-workers 32 \ + --chunk-size 10000 \ + --dataset global_all_leps_2605 \ + || fail "download_images.py exited non-zero" +echo "STAGE 1 done in $((SECONDS - T0))s" + +# ── STAGE 2: count expected images across all chunks (incl. resume_* dirs) ── +echo "=== STAGE 2: count expected images $(date) ===" +EXPECTED=0 +while IFS= read -r chunk; do + N=$(unsquashfs -l "${chunk}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || echo 0) + EXPECTED=$((EXPECTED + N)) +done < <(find "${TASK_DIR}" -name "chunk_*.sqfs") +echo "Expected: ${EXPECTED} images in $(find "${TASK_DIR}" -name 'chunk_*.sqfs' | wc -l) chunks" +[ "${EXPECTED}" -gt 0 ] || fail "no images found in chunks" + +# ── STAGE 3: merge chunks → task sqfs ──────────────────────────────────────── +echo "=== STAGE 3: merge $(date) ===" +T0=$SECONDS +rm -f "${OUT_SQFS}" +python src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py "${TASK_DIR}" \ + | sqfstar -comp zstd -Xcompression-level 3 -b 131072 -no-duplicates "${OUT_SQFS}" +PIPE_STATUS=("${PIPESTATUS[@]}") +[ "${PIPE_STATUS[0]}" -eq 0 ] && [ "${PIPE_STATUS[1]}" -eq 0 ] \ + || fail "merge failed (stream=${PIPE_STATUS[0]} sqfstar=${PIPE_STATUS[1]}) — chunks preserved" +LOCAL_SIZE=$(stat -c%s "${OUT_SQFS}") +echo "STAGE 3 done in $((SECONDS - T0))s — $(numfmt --to=iec ${LOCAL_SIZE})" + +# ── STAGE 4: verify merged image count ─────────────────────────────────────── +ACTUAL=$(unsquashfs -l "${OUT_SQFS}" 2>/dev/null | grep -cE '\.(jpg|jpeg|png)$' || echo 0) +echo "Merged: ${ACTUAL}/${EXPECTED} images" +[ "${ACTUAL}" -eq "${EXPECTED}" ] || fail "image count mismatch ${ACTUAL}/${EXPECTED} — chunks preserved" + +# ── STAGE 5: upload + verify remote size ───────────────────────────────────── +echo "=== STAGE 5: upload $(numfmt --to=iec ${LOCAL_SIZE}) $(date) ===" +T0=$SECONDS +s5cmd --endpoint-url "${ENDPOINT}" cp "${OUT_SQFS}" "${S3_DEST}" \ + || fail "s5cmd upload failed — local sqfs + chunks preserved" +UP_SECS=$((SECONDS - T0)) +REMOTE_SIZE=$(s5cmd --endpoint-url "${ENDPOINT}" ls "${S3_DEST}" | awk '{print $3}') +echo "Uploaded in ${UP_SECS}s — local=${LOCAL_SIZE} remote=${REMOTE_SIZE}" +[ "${REMOTE_SIZE}" = "${LOCAL_SIZE}" ] || fail "remote size mismatch — local files preserved" + +# ── STAGE 6: verified — delete local ───────────────────────────────────────── +rm -f "${OUT_SQFS}" +find "${TASK_DIR}" -name "chunk_*.sqfs" -delete +rmdir "${TASK_DIR}"/resume_* 2>/dev/null || true +echo "Local cleanup done." + +echo "=== bq_download task=${TASK_ID} COMPLETE $(date) ===" +notify "bq_download task ${TASK_ID}/${NUM_JOBS}: done" \ + "${ACTUAL} images → ${S3_DEST} ($(numfmt --to=iec ${LOCAL_SIZE})), upload ${UP_SECS}s" +exit 0 diff --git a/scripts/job_bq_export.sh b/scripts/job_bq_export.sh new file mode 100644 index 0000000..1b04283 --- /dev/null +++ b/scripts/job_bq_export.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Export a BigQuery query to CSV. +# +# Usage: +# sbatch --export=QUERY_FILE=queries/global_min25occ.sql,OUTPUT=global_min25occ.csv job_bq_export.sh +# sbatch --export=QUERY_FILE=queries/global_max2000img.sql,OUTPUT=global_max2000img.csv job_bq_export.sh +# +#SBATCH --account=def-drolnick +#SBATCH --job-name=bq_export +#SBATCH --cpus-per-task=2 +#SBATCH --mem=8G +#SBATCH --time=1:00:00 +#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_export_%j.out +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=hack1996man@gmail.com + +set -euo pipefail + +BASE_DIR="/home/melabbas/projects/def-drolnick/melabbas/ami-ml" +DATA_DIR="/project/6068129/melabbas/ami-ml/data" + +echo "=== bq_export started at $(date) ===" +echo "Node : $(hostname)" +echo "QUERY_FILE : ${QUERY_FILE}" +echo "OUTPUT : ${DATA_DIR}/${OUTPUT}" +echo "" + +cd "${BASE_DIR}" +module load StdEnv/2023 arrow/17.0.0 +source .venv/bin/activate + +python src/dataset_tools/bq_squashfs/bq_export.py \ + --query-file "src/dataset_tools/bq_squashfs/${QUERY_FILE}" \ + --output "${DATA_DIR}/${OUTPUT}" \ + --project leps-ai + +EXIT_CODE=$? +echo "" +echo "=== bq_export done at $(date) (exit=${EXIT_CODE}) ===" + +if [ "${EXIT_CODE}" -eq 0 ]; then + SIZE=$(du -sh "${DATA_DIR}/${OUTPUT}" | cut -f1) + ROWS=$(( $(wc -l < "${DATA_DIR}/${OUTPUT}") - 1 )) + notify "bq_export done" "${OUTPUT} ${ROWS} rows ${SIZE}" +else + notify "bq_export FAILED" "exit=${EXIT_CODE} — check bq_export_${SLURM_JOB_ID}.out" + exit 1 +fi diff --git a/scripts/job_bq_split.sh b/scripts/job_bq_split.sh new file mode 100644 index 0000000..082e61b --- /dev/null +++ b/scripts/job_bq_split.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# Split a BQ-exported CSV into train/val/test CSVs. +# Reads from DATA_DIR/, writes train.csv/val.csv/test.csv to DATA_DIR/splits/. +# +# Must run after job_bq_export.sh completes. +# +# Usage: +# sbatch --dependency=afterok: job_bq_split.sh +# sbatch --export=CSV=global_min25occ.csv --dependency=afterok: job_bq_split.sh +# +# CSV defaults to global_min25occ.csv if not set via --export. +# +#SBATCH --account=def-drolnick +#SBATCH --job-name=bq_split +#SBATCH --cpus-per-task=2 +#SBATCH --mem=8G +#SBATCH --time=0:30:00 +#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_split_%j.out +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=hack1996man@gmail.com + +set -euo pipefail + +BASE_DIR="/project/6068129/melabbas/ami-ml" +DATA_DIR="${BASE_DIR}/data" +CSV="${CSV:-global_min25occ.csv}" +INPUT_CSV="${DATA_DIR}/${CSV}" +OUTPUT_DIR="${DATA_DIR}/splits" + +echo "=== bq_split started at $(date) ===" +echo "Node : $(hostname)" +echo "Input CSV : ${INPUT_CSV} ($(du -sh ${INPUT_CSV} | cut -f1))" +echo "Output dir : ${OUTPUT_DIR}" +echo "" + +cd "${BASE_DIR}" +module load StdEnv/2023 arrow/17.0.0 +source .venv/bin/activate + +mkdir -p "${OUTPUT_DIR}" + +python src/dataset_tools/bq_squashfs/split.py \ + --csv "${INPUT_CSV}" \ + --output-dir "${OUTPUT_DIR}" \ + --category-key species_name \ + --val-frac 0.1 \ + --test-frac 0.1 \ + --split-by-occurrence \ + --max-instances 1000 \ + --min-instances 5 \ + --seed 42 + +EXIT_CODE=$? +echo "" +echo "=== bq_split done at $(date) (exit=${EXIT_CODE}) ===" + +if [ "${EXIT_CODE}" -eq 0 ]; then + TRAIN_ROWS=$(( $(wc -l < "${OUTPUT_DIR}/train.csv") - 1 )) + VAL_ROWS=$(( $(wc -l < "${OUTPUT_DIR}/val.csv") - 1 )) + TEST_ROWS=$(( $(wc -l < "${OUTPUT_DIR}/test.csv") - 1 )) + echo " train : ${TRAIN_ROWS} rows" + echo " val : ${VAL_ROWS} rows" + echo " test : ${TEST_ROWS} rows" + notify "bq_split: done" \ + "train=${TRAIN_ROWS} val=${VAL_ROWS} test=${TEST_ROWS} → ${OUTPUT_DIR}" +else + notify "bq_split: FAILED" \ + "exit=${EXIT_CODE} — check bq_split_${SLURM_JOB_ID}.out" + exit 1 +fi diff --git a/scripts/job_bq_train.sh b/scripts/job_bq_train.sh new file mode 100644 index 0000000..a6e0d8e --- /dev/null +++ b/scripts/job_bq_train.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# Train a ResNet-50 classifier on the global WebDataset. +# Must run after job_build_wds_global_v2.sh completes. +# +# Reads : /scratch/melabbas/global_wds/{train,val,test}/*.tar + class_map.json +# Writes : /project/6068129/melabbas/ami-ml/models/global_wds/ +# +# Usage: +# sbatch --dependency=afterok: job_bq_train.sh +# +#SBATCH --account=def-drolnick_gpu +#SBATCH --job-name=bq_train +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G +#SBATCH --time=3-00:00:00 +#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/bq_train_%j.out +#SBATCH --mail-type=BEGIN,END,FAIL +#SBATCH --mail-user=hack1996man@gmail.com + +set -euo pipefail + +BASE_DIR="/home/melabbas/projects/def-drolnick/melabbas/ami-ml" +WBDS="/scratch/melabbas/global_wds" +MODELS="/project/6068129/melabbas/ami-ml/models/global_wds" + +mkdir -p "${MODELS}" + +cd "${BASE_DIR}" +source .venv/bin/activate + +if [[ -f .env ]]; then + set -a + source .env + set +a +fi + +export WANDB_API_KEY="${WANDB_API_KEY_HACK1996MAN}" + +# Dynamically resolve shard counts and num_classes from the WDS output +N_TRAIN=$(find "${WBDS}/train" -name "train-*.tar" | wc -l) +N_VAL=$(find "${WBDS}/val" -name "val-*.tar" | wc -l) +N_TEST=$(find "${WBDS}/test" -name "test-*.tar" | wc -l) +NUM_CLASSES=$(python3 -c "import json; print(len(json.load(open('${WBDS}/class_map.json'))))") + +if [[ "${N_TRAIN}" -eq 0 ]]; then + echo "ERROR: no train shards found in ${WBDS}/train" + notify "bq_train: FAILED" "No train shards found — did job_build_wds_global_v2 complete?" + exit 1 +fi + +TRAIN_PAT="${WBDS}/train/train-{000000..$(printf '%06d' $((N_TRAIN-1)))}.tar" +VAL_PAT="${WBDS}/val/val-{000000..$(printf '%06d' $((N_VAL-1)))}.tar" +TEST_PAT="${WBDS}/test/test-{000000..$(printf '%06d' $((N_TEST-1)))}.tar" + +echo "=== bq_train started at $(date) ===" +echo "Node : $(hostname)" +echo "Train shards: ${N_TRAIN}" +echo "Val shards : ${N_VAL}" +echo "Test shards : ${N_TEST}" +echo "Num classes : ${NUM_CLASSES}" +echo "Models dir : ${MODELS}" +echo "" + +# Resume from latest checkpoint if available +RESUME_FLAG="" +RESUME_CKPT=$(ls -t "${MODELS}"/*_latest.pt 2>/dev/null | head -1 || true) +if [[ -z "${RESUME_CKPT}" ]]; then + RESUME_CKPT=$(ls -t "${MODELS}"/*_checkpoint.pt 2>/dev/null | head -1 || true) +fi +if [[ -n "${RESUME_CKPT}" ]]; then + echo "Resuming from checkpoint: ${RESUME_CKPT}" + RESUME_FLAG="--resume_from_checkpoint ${RESUME_CKPT}" +else + echo "No checkpoint found — starting fresh" +fi +echo "" + +ami-classification train-model \ + --train_webdataset "${TRAIN_PAT}" \ + --val_webdataset "${VAL_PAT}" \ + --test_webdataset "${TEST_PAT}" \ + --num_classes "${NUM_CLASSES}" \ + --model_type resnet50 \ + --image_input_size 128 \ + --model_save_directory "${MODELS}" \ + --total_epochs 30 \ + --warmup_epochs 2 \ + --early_stopping 100 \ + --learning_rate 0.001 \ + --learning_rate_scheduler cosine \ + --weight_decay 1e-5 \ + --batch_size 128 \ + --preprocess_mode torch \ + --mixed_resolution_data_aug true \ + --random_seed 123 \ + --wandb_entity "hack1996man" \ + --wandb_project "ai_for_leps" \ + --wandb_run_name "global_wds_bq_run1" \ + ${RESUME_FLAG} + +EXIT_CODE=$? +echo "" +echo "=== bq_train done at $(date) (exit=${EXIT_CODE}) ===" + +if [ "${EXIT_CODE}" -eq 0 ]; then + notify "bq_train: done" \ + "exit=0 classes=${NUM_CLASSES} train=${N_TRAIN} val=${N_VAL} test=${N_TEST} shards — models in ${MODELS}" +else + notify "bq_train: FAILED" \ + "exit=${EXIT_CODE} — check bq_train_${SLURM_JOB_ID}.out" + exit 1 +fi diff --git a/scripts/job_bq_webdataset.sh b/scripts/job_bq_webdataset.sh new file mode 100644 index 0000000..4acb8d4 --- /dev/null +++ b/scripts/job_bq_webdataset.sh @@ -0,0 +1,144 @@ +#!/bin/bash +# Build global WebDataset from all 10 sqfs files. +# +# Requires a BQ-exported CSV at CSV_PATH with columns: +# photo_id, relative_local_path, species_name (or class_id), + any other metadata +# +# Strategy: two batches of BATCH_SIZE sqfs to stay under 7 TB NVMe peak: +# Batch 1 (sqfs 0-4): --tar-mode w (create new tars) +# Batch 2 (sqfs 5-9): --tar-mode a (append, idempotent on retry) +# +# Export the CSV from BQ before submitting: +# bq query --format=csv --max_rows=15000000 \ +# "SELECT ti.photo_id, ti.relative_local_path, ti.dataset_source_uuid, +# tx.species_name, tx.inat_taxon_id, tx.family, tx.gbif_accepted_taxon_key +# FROM leps-ai.global_butterflies_2604.training_images ti +# JOIN leps-ai.global_butterflies_2604.inat_taxa tx USING (inat_taxon_id) +# JOIN (SELECT DISTINCT dataset_source_uuid +# FROM leps-ai.global_butterflies_2604.training_images_downloads +# WHERE fetch_status='downloaded') d USING (dataset_source_uuid) +# WHERE tx.species_name IS NOT NULL" > /project/6068129/melabbas/ami-ml/data/global_wds_export.csv +# +#SBATCH --account=def-drolnick +#SBATCH --job-name=build_wds_global +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G +#SBATCH --tmp=7000G +#SBATCH --time=12:00:00 +#SBATCH --output=/project/6068129/melabbas/ami-ml/scripts/build_wds_global_%j.out +#SBATCH --mail-type=BEGIN,END,FAIL +#SBATCH --mail-user=hack1996man@gmail.com +#SBATCH --exclude=fc30554 + +NVME="${SLURM_TMPDIR}" +OUTPUT_DIR="/scratch/melabbas/global_wds" +SPLITS_DIR="/project/6068129/melabbas/ami-ml/data/splits" # contains train.csv val.csv test.csv +IMAGES_PER_SHARD=1000 +BATCH_SIZE=5 +PACK_WORKERS=16 + +echo "=== build_wds_global started at $(date) ===" +echo "Node: $(hostname)" +echo "NVMe: ${NVME} ($(df -h ${NVME} | tail -1 | awk '{print $4}') free)" +echo "CSV: ${CSV_PATH} ($(du -sh ${CSV_PATH} | cut -f1))" +echo "" + +for SPLIT in train val test; do + if [ ! -f "${SPLITS_DIR}/${SPLIT}.csv" ]; then + echo "ERROR: ${SPLITS_DIR}/${SPLIT}.csv not found" + echo "Run split_csv.py first to generate the split CSVs." + exit 1 + fi +done +echo "Split CSVs: $(wc -l < ${SPLITS_DIR}/train.csv) train $(wc -l < ${SPLITS_DIR}/val.csv) val $(wc -l < ${SPLITS_DIR}/test.csv) test rows" +echo "" + +# ── Locate all sqfs files ───────────────────────────────────────────────────── +find_sqfs() { + local T=$1 + for DIR in /project/rrg-bengioy-ad/melabbas /project/6068129/melabbas /scratch/melabbas; do + [ -f "${DIR}/task_${T}.sqfs" ] && echo "${DIR}/task_${T}.sqfs" && return 0 + done + echo "" +} + +SQFS_PATHS=() +echo "Locating sqfs files..." +for T in $(seq 0 9); do + P=$(find_sqfs $T) + if [ -z "$P" ]; then + echo "ERROR: could not find task_${T}.sqfs" + exit 1 + fi + SQFS_PATHS+=("$P") + echo " task_${T}: ${P} ($(du -sh ${P} | cut -f1))" +done +echo "" + +# ── Setup Lustre output with HDD stripe ────────────────────────────────────── +for SPLIT in train val test; do + mkdir -p "${OUTPUT_DIR}/${SPLIT}" + lfs setstripe -c -1 -S 4m -p ddn_hdd "${OUTPUT_DIR}/${SPLIT}" 2>/dev/null || true +done +echo "Output dir: ${OUTPUT_DIR}" +echo "" + +# ── Python environment ──────────────────────────────────────────────────────── +cd /project/6068129/melabbas/ami-ml +module load StdEnv/2023 arrow/17.0.0 +source .venv/bin/activate + +# ── Batch 1: sqfs 0–(BATCH_SIZE-1), create new tars ────────────────────────── +BATCH1_END=$((BATCH_SIZE - 1)) +echo "=== BATCH 1: sqfs 0–${BATCH1_END} (mode=w) at $(date) ===" + +python src/dataset_tools/bq_squashfs/create_webdataset.py \ + --split-csvs train:${SPLITS_DIR}/train.csv val:${SPLITS_DIR}/val.csv test:${SPLITS_DIR}/test.csv \ + --sqfs-paths "${SQFS_PATHS[@]:0:${BATCH_SIZE}}" \ + --sqfs-start-idx 0 \ + --images-per-shard ${IMAGES_PER_SHARD} \ + --nvme-dir "${NVME}" \ + --output-dir "${OUTPUT_DIR}" \ + --pack-workers ${PACK_WORKERS} \ + --tar-mode w + +BATCH1_EXIT=$? +echo "" +if [ ${BATCH1_EXIT} -ne 0 ]; then + echo "ERROR: batch 1 failed (exit=${BATCH1_EXIT})" + notify "build_wds_global: FAILED (batch 1)" \ + "exit=${BATCH1_EXIT} — check build_wds_global_${SLURM_JOB_ID}.out" + exit ${BATCH1_EXIT} +fi + +# ── Batch 2: sqfs BATCH_SIZE–9, append to existing tars ────────────────────── +echo "=== BATCH 2: sqfs ${BATCH_SIZE}–9 (mode=a) at $(date) ===" + +python src/dataset_tools/bq_squashfs/create_webdataset.py \ + --split-csvs train:${SPLITS_DIR}/train.csv val:${SPLITS_DIR}/val.csv test:${SPLITS_DIR}/test.csv \ + --sqfs-paths "${SQFS_PATHS[@]:${BATCH_SIZE}}" \ + --sqfs-start-idx ${BATCH_SIZE} \ + --images-per-shard ${IMAGES_PER_SHARD} \ + --nvme-dir "${NVME}" \ + --output-dir "${OUTPUT_DIR}" \ + --pack-workers ${PACK_WORKERS} \ + --tar-mode a + +BATCH2_EXIT=$? +echo "" +echo "=== build_wds_global done at $(date) (exit=${BATCH2_EXIT}) ===" + +if [ ${BATCH2_EXIT} -eq 0 ]; then + for SPLIT in train val test; do + COUNT=$(ls "${OUTPUT_DIR}/${SPLIT}/"*.tar 2>/dev/null | wc -l) + SIZE=$(du -sh "${OUTPUT_DIR}/${SPLIT}" 2>/dev/null | cut -f1) + echo " ${SPLIT}: ${COUNT} shards ${SIZE}" + done + notify "build_wds_global: done" \ + "exit=0 job=${SLURM_JOB_ID} — train/val/test shards written to ${OUTPUT_DIR}" +else + notify "build_wds_global: FAILED (batch 2)" \ + "exit=${BATCH2_EXIT} — check build_wds_global_${SLURM_JOB_ID}.out" +fi + +exit ${BATCH2_EXIT} diff --git a/src/dataset_tools/bigquery_pipeline/clean.py b/src/dataset_tools/bigquery_pipeline/clean.py new file mode 100644 index 0000000..bf78166 --- /dev/null +++ b/src/dataset_tools/bigquery_pipeline/clean.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +""" +Clean a training_images BQ table by removing duplicates and sparse taxa. + +Duplicate types handled in order: + 1. Exact rows — same (photo_id, inat_taxon_id, gbif_id), keep one row + 2. Same-taxon multi-gbif — same (photo_id, inat_taxon_id), multiple gbif_ids → keep MIN gbif_id + 3. Multi-taxon conflict — same photo_id maps to multiple taxa → drop all or keep lowest taxon_id + 4. Min-images filter — drop taxa with fewer than N images (default: off) + +Overwrites the source table in-place unless --dry-run is passed. +""" +import argparse +import json +import sys +from datetime import datetime, timezone + +from google.cloud import bigquery + + +def _step3_cte(strategy): + if strategy == "drop": + return "SELECT * FROM step2 WHERE photo_id NOT IN (SELECT photo_id FROM multi_taxon_photos)" + else: # keep-lowest-taxon-id + return ( + "SELECT * FROM step2\n" + " QUALIFY ROW_NUMBER() OVER (PARTITION BY photo_id ORDER BY inat_taxon_id, gbif_id) = 1" + ) + + +def _cte_chain(src_ref, strategy, min_images): + min_filter = f"t.cnt >= {min_images}" if min_images > 0 else "TRUE" + return f""" +src AS ( + SELECT * FROM `{src_ref}` +), +step1 AS ( + -- Remove exact duplicate rows: same photo+taxon+gbif, keep one + SELECT * FROM src + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY photo_id, inat_taxon_id, gbif_id + ORDER BY dataset_source_uuid + ) = 1 +), +step2 AS ( + -- Same photo+taxon with multiple gbif_ids: keep the MIN gbif_id + SELECT * FROM step1 + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY photo_id, inat_taxon_id + ORDER BY gbif_id + ) = 1 +), +multi_taxon_photos AS ( + -- Photo IDs that still map to more than one taxon after step2 + SELECT photo_id FROM step2 + GROUP BY photo_id + HAVING COUNT(DISTINCT inat_taxon_id) > 1 +), +step3 AS ( + -- Resolve multi-taxon conflicts + {_step3_cte(strategy)} +), +taxa_img_count AS ( + SELECT inat_taxon_id, COUNT(*) AS cnt FROM step3 GROUP BY inat_taxon_id +), +step4 AS ( + -- Drop taxa below min-images threshold + SELECT s.* FROM step3 s + JOIN taxa_img_count t USING (inat_taxon_id) + WHERE {min_filter} +)""" + + +def run_count_query(client, src_ref, strategy, min_images): + """Return per-stage row/taxa counts in a single BQ query.""" + chain = _cte_chain(src_ref, strategy, min_images) + query = f""" +WITH {chain} +SELECT stage, n, taxa FROM ( + SELECT 'input' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM src UNION ALL + SELECT 'step1' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step1 UNION ALL + SELECT 'step2' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step2 UNION ALL + SELECT 'step3' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step3 UNION ALL + SELECT 'step4' AS stage, COUNT(*) AS n, COUNT(DISTINCT inat_taxon_id) AS taxa FROM step4 UNION ALL + SELECT 'conflicts' AS stage, COUNT(*) AS n, 0 AS taxa FROM multi_taxon_photos +) +ORDER BY stage +""" + results = {row.stage: {"rows": row.n, "taxa": row.taxa} + for row in client.query(query).result()} + return results + + +def run_write_query(client, src_ref, dst_ref, strategy, min_images): + chain = _cte_chain(src_ref, strategy, min_images) + query = f"CREATE OR REPLACE TABLE `{dst_ref}` AS\nWITH {chain}\nSELECT * FROM step4" + client.query(query).result() + + +def print_report(counts, strategy, min_images, dst_ref, dry_run): + c = counts + step_labels = [ + ("input", "Input rows"), + ("step1", "Exact duplicates removed (same photo+taxon+gbif, keep one)"), + ("step2", "Same-taxon multi-gbif resolved (keep MIN gbif_id)"), + ("step3", f"Multi-taxon conflicts handled (strategy={strategy})"), + ("step4", f"Min-images-per-taxon filter (threshold={min_images or 'off'})"), + ] + + print() + print("=" * 70) + prev = None + for key, label in step_labels: + rows = c[key]["rows"] + taxa = c[key]["taxa"] + if prev is None: + print(f" {label}") + print(f" rows={rows:>10,} taxa={taxa:>8,}") + else: + removed = prev - rows + pct = removed / prev * 100 if prev else 0 + print(f" {label}") + print(f" rows={rows:>10,} removed={removed:>8,} ({pct:.1f}%)") + prev = rows + + # Step 3 extra detail + n_conflict_photos = c["conflicts"]["rows"] + print(f" [{n_conflict_photos:,} photo_ids had conflicting taxa]") + + total_removed = c["input"]["rows"] - c["step4"]["rows"] + total_pct = total_removed / c["input"]["rows"] * 100 if c["input"]["rows"] else 0 + print() + print(f" {'─'*60}") + print(f" Total removed : {total_removed:>10,} ({total_pct:.1f}%)") + print(f" Output rows : {c['step4']['rows']:>10,}") + print(f" Taxa before : {c['input']['taxa']:>10,}") + print(f" Taxa after : {c['step4']['taxa']:>10,}") + dry_tag = " [DRY RUN — table not written]" if dry_run else "" + print(f" Output table : {dst_ref}{dry_tag}") + print("=" * 70) + print() + + +def build_log(counts, args, dst_ref, dry_run, started_at): + c = counts + return { + "started_at": started_at, + "finished_at": datetime.now(timezone.utc).isoformat(), + "dataset": args.dataset, + "project": args.project, + "table": args.table, + "output_table": dst_ref, + "dry_run": dry_run, + "multi_taxon_strategy": args.multi_taxon_strategy, + "min_images_per_taxon": args.min_images_per_taxon, + "steps": { + "exact_duplicates": { + "before": c["input"]["rows"], + "after": c["step1"]["rows"], + "removed": c["input"]["rows"] - c["step1"]["rows"], + }, + "same_taxon_multi_gbif": { + "before": c["step1"]["rows"], + "after": c["step2"]["rows"], + "removed": c["step1"]["rows"] - c["step2"]["rows"], + }, + "multi_taxon_conflicts": { + "before": c["step2"]["rows"], + "after": c["step3"]["rows"], + "removed": c["step2"]["rows"] - c["step3"]["rows"], + "conflict_photo_ids": c["conflicts"]["rows"], + }, + "min_images_filter": { + "before": c["step3"]["rows"], + "after": c["step4"]["rows"], + "removed": c["step3"]["rows"] - c["step4"]["rows"], + "threshold": args.min_images_per_taxon, + }, + }, + "summary": { + "input_rows": c["input"]["rows"], + "output_rows": c["step4"]["rows"], + "total_removed": c["input"]["rows"] - c["step4"]["rows"], + "taxa_before": c["input"]["taxa"], + "taxa_after": c["step4"]["taxa"], + }, + } + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--dataset", required=True, + help="BQ dataset name (e.g. global_all_leps_2605)") + parser.add_argument("--project", default="leps-ai") + parser.add_argument("--table", default="training_images", + help="Table to clean (overwritten in-place)") + parser.add_argument("--min-images-per-taxon", type=int, default=0, + help="Drop taxa with fewer images than this (0 = off)") + parser.add_argument("--multi-taxon-strategy", + choices=["drop", "keep-lowest-taxon-id"], default="drop", + help="How to handle a photo mapped to multiple taxa") + parser.add_argument("--dry-run", action="store_true", + help="Report what would be removed without writing the table") + parser.add_argument("--log-file", help="Write JSON log to this path") + args = parser.parse_args() + + started_at = datetime.now(timezone.utc).isoformat() + src_ref = f"{args.project}.{args.dataset}.{args.table}" + dst_ref = src_ref # overwrite in-place + + client = bigquery.Client(project=args.project) + + print(f"clean.py {'[DRY RUN] ' if args.dry_run else ''}started at {started_at}") + print(f" table : {src_ref}") + print(f" multi-taxon : {args.multi_taxon_strategy}") + print(f" min-images-per-taxon: {args.min_images_per_taxon or 'off'}") + print() + print("Running count queries ...") + + counts = run_count_query(client, src_ref, args.multi_taxon_strategy, args.min_images_per_taxon) + print_report(counts, args.multi_taxon_strategy, args.min_images_per_taxon, dst_ref, args.dry_run) + + if not args.dry_run: + print("Writing cleaned table ...") + run_write_query(client, src_ref, dst_ref, args.multi_taxon_strategy, args.min_images_per_taxon) + print(f"Done — {counts['step4']['rows']:,} rows written to {dst_ref}") + + log = build_log(counts, args, dst_ref, args.dry_run, started_at) + if args.log_file: + with open(args.log_file, "w") as f: + json.dump(log, f, indent=2) + print(f"Log written to {args.log_file}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/dataset_tools/bigquery_pipeline/create_test_table.py b/src/dataset_tools/bigquery_pipeline/create_test_table.py new file mode 100644 index 0000000..2c03ff2 --- /dev/null +++ b/src/dataset_tools/bigquery_pipeline/create_test_table.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +""" +Create a stratified test table for clean.py testing. + +Includes all 3 duplicate types (exact dupes, same-taxon multi-gbif, +multi-taxon conflicts) plus a clean random sample. Total ~35-40K rows. +""" +import argparse +from google.cloud import bigquery + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dataset", required=True, help="BQ dataset name (e.g. global_all_leps_2605)") + parser.add_argument("--project", default="leps-ai") + parser.add_argument("--source-table", default="training_images") + parser.add_argument("--output-table", default="training_images_test") + parser.add_argument("--max-type1-photos", type=int, default=10_000, + help="Cap on type-1 duplicate photo_ids (each appears ×2 rows)") + parser.add_argument("--clean-sample", type=int, default=10_000, + help="Number of clean (non-duplicate) rows to include") + args = parser.parse_args() + + client = bigquery.Client(project=args.project) + src = f"`{args.project}.{args.dataset}.{args.source_table}`" + dst_ref = f"{args.project}.{args.dataset}.{args.output_table}" + dst = f"`{dst_ref}`" + + query = f""" + CREATE OR REPLACE TABLE {dst} AS + + WITH + -- Type 1: exact duplicate rows (same photo_id + taxon + gbif_id appears >1 time) + type1_photos AS ( + SELECT photo_id + FROM {src} + GROUP BY photo_id, inat_taxon_id, gbif_id + HAVING COUNT(*) > 1 + LIMIT {args.max_type1_photos} + ), + + -- Type 2: same photo_id mapped to multiple taxa AND multiple gbif_ids + type2_photos AS ( + SELECT photo_id + FROM {src} + GROUP BY photo_id + HAVING COUNT(DISTINCT inat_taxon_id) > 1 AND COUNT(DISTINCT gbif_id) > 1 + ), + + -- Type 3: same photo_id + same taxon, but multiple gbif_ids + type3_photos AS ( + SELECT photo_id + FROM {src} + GROUP BY photo_id + HAVING COUNT(DISTINCT inat_taxon_id) = 1 AND COUNT(DISTINCT gbif_id) > 1 + ), + + all_dup_photos AS ( + SELECT photo_id FROM type1_photos + UNION DISTINCT + SELECT photo_id FROM type2_photos + UNION DISTINCT + SELECT photo_id FROM type3_photos + ), + + -- All rows belonging to any duplicate photo_id + dup_rows AS ( + SELECT t.* + FROM {src} t + INNER JOIN all_dup_photos d USING (photo_id) + ), + + -- Clean rows: photo_ids not involved in any duplication + clean_rows AS ( + SELECT t.* + FROM {src} t + WHERE t.photo_id NOT IN (SELECT photo_id FROM all_dup_photos) + LIMIT {args.clean_sample} + ) + + SELECT * FROM dup_rows + UNION ALL + SELECT * FROM clean_rows + """ + + print(f"Creating {dst_ref} ...") + print(f" source : {args.project}.{args.dataset}.{args.source_table}") + print(f" type-1 cap: {args.max_type1_photos:,} photo_ids") + print(f" clean sample: {args.clean_sample:,} rows") + print() + + job = client.query(query) + job.result() + + ref = client.get_table(dst_ref) + print(f"Done — {ref.num_rows:,} rows written to {dst_ref}") + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/README.md b/src/dataset_tools/bq_squashfs/README.md new file mode 100644 index 0000000..5972b6a --- /dev/null +++ b/src/dataset_tools/bq_squashfs/README.md @@ -0,0 +1,243 @@ +# BQ/SquashFS Pipeline + +End-to-end pipeline for building a WebDataset training set from the BigQuery +`training_images` table on the **fir** cluster (Compute Canada / DRAC). + +## Overview + +``` +[BigQuery: training_images] + │ + ▼ Stage 1: download + task_0.sqfs … task_9.sqfs ← all downloaded images, split into 10 chunks + │ + ▼ Stage 2: bq_export + global_min25occ.csv ← metadata for qualifying images (species, paths, taxon IDs) + │ + ▼ Stage 3: split + splits/train.csv + splits/val.csv + splits/test.csv + │ + ▼ Stage 4: webdataset + global_wds/{train,val,test}/*.tar + class_map.json + │ + ▼ Stage 5: train + models/global_wds/*.pt +``` + +--- + +## Data Source: BigQuery `training_images` + +All images originate from the `leps-ai.global_butterflies_2604.training_images` +BigQuery table. Each row represents one image with: + +- `photo_id` — iNaturalist photo ID (integer) +- `gbif_id` — GBIF occurrence ID (multiple images can share the same occurrence) +- `relative_local_path` — path of the image file within its SquashFS chunk +- `inat_taxon_id` — iNaturalist taxon ID (used as the class label) +- `fetch_status` — `'downloaded'` once the image has been fetched to disk + +Only rows with `fetch_status = 'downloaded'` are used in the pipeline. + +--- + +## SquashFS Chunks: task_0 … task_9 + +Images on disk are stored in 10 SquashFS archives (`task_0.sqfs` … `task_9.sqfs`), +located at `/project/rrg-bengioy-ad/melabbas/` and `/project/6068129/melabbas/`. + +The task assignment is deterministic: + +``` +task_id = photo_id % 10 +``` + +So `task_3.sqfs` contains all images whose `photo_id` ends in 3. The +`relative_local_path` column in BigQuery is the path of the image *within* its +sqfs archive. This split allows parallel downloading across 10 SLURM array tasks +and efficient batched processing during WebDataset creation. + +--- + +## Stage 1 — Download + +**SLURM job:** `scripts/job_bq_download.sh` +**Python:** `download_images.py` + +Runs as a SLURM array job (10 tasks). Each task handles `photo_id % 10 == task_id` +and performs the full download-verify-pack loop: + +1. Queries BQ for pending images assigned to this task, skipping any already + recorded in `training_images_downloads` — fully resumable on resubmit. +2. Downloads images in parallel (32 workers) from `absolute_url`. +3. Verifies each image with PIL (width, height, corruption check). +4. Writes results back to the `training_images_downloads` BQ table + (`fetch_status`: `downloaded`, `failed`, or `corrupted`). +5. Every 10,000 images, packs the staging dir into a `chunk_NNNN.sqfs` file + using `mksquashfs`, then deletes the raw images to keep inode usage low. + +After all 10 tasks complete, the per-chunk sqfs files in each staging dir are +merged into the final `task_N.sqfs` archives by the pack job. + +```bash +sbatch scripts/job_bq_download.sh +``` + +**Output:** `task_0.sqfs` … `task_9.sqfs` + +--- + +## Stage 2 — BQ Export + +**SLURM job:** `scripts/job_bq_export.sh` +**Python:** `bq_export.py` + +Runs a SQL query against BigQuery and streams the results to a CSV file on +Lustre. The query joins `training_images` with `inat_taxa` to attach species +names and taxonomy, and filters to images with `fetch_status = 'downloaded'`. + +Two queries are available under `queries/`: + +| Query file | Filter | Expected rows | +|---|---|---| +| `global_min25occ.sql` | Species with ≥ 25 distinct GBIF occurrences | ~10.6M images, ~4,704 species | +| `global_max2000img.sql` | Cap at 2,000 images per species | smaller subset | + +```bash +# Default (global_min25occ.csv) +sbatch --dependency=afterok: scripts/job_bq_export.sh + +# Custom query +sbatch --export=QUERY_FILE=queries/global_max2000img.sql,OUTPUT=global_max2000img.csv \ + --dependency=afterok: scripts/job_bq_export.sh +``` + +**Input:** BigQuery `training_images` + `inat_taxa` tables +**Output:** `data/global_min25occ.csv` (or custom filename) + +--- + +## Stage 3 — Split + +**SLURM job:** `scripts/job_bq_split.sh` +**Python:** `split.py` + +Splits the BQ-exported CSV into `train.csv`, `val.csv`, and `test.csv` using +stratified sampling so every species is proportionally represented in all three +sets. + +Key behaviours: +- **Split by occurrence** — images that share the same `gbif_id` (same field + observation) are always kept in the same split, preventing data leakage. +- **Max instances** — caps images per species at 1,000 for train (proportional + for val/test) to avoid class imbalance dominating training. +- **Min instances** — species with fewer than 5 training images are dropped. + +```bash +sbatch --dependency=afterok: scripts/job_bq_split.sh + +# Custom input CSV +sbatch --export=CSV=global_max2000img.csv \ + --dependency=afterok: scripts/job_bq_split.sh +``` + +**Input:** `data/` (default: `global_min25occ.csv`) +**Output:** `data/splits/train.csv`, `data/splits/val.csv`, `data/splits/test.csv` + +--- + +## Stage 4 — WebDataset + +**SLURM job:** `scripts/job_bq_webdataset.sh` +**Python:** `create_webdataset.py` or `create_webdataset_generic.py` + +Packs images from the SquashFS archives into WebDataset tar shards, organised +into train/val/test splits. Each shard contains ~1,000 images. Every image is +stored as a triplet inside the tar: + +- `.jpg` — image bytes +- `.cls` — integer class ID (text) +- `.json` — metadata (`class_id`, `relative_local_path`, `task_id`, species fields) + +A `class_map.json` is also written to the output root, mapping sequential class +IDs to `inat_taxon_id` values. + +```bash +sbatch --dependency=afterok: scripts/job_bq_webdataset.sh +``` + +**Input:** `task_0.sqfs` … `task_9.sqfs`, `data/splits/{train,val,test}.csv` +**Output:** `/scratch/melabbas/global_wds/{train,val,test}/*.tar`, `class_map.json` + +### Two versions of `create_webdataset` + +| Script | When to use | +|---|---| +| `create_webdataset.py` | On **fir** — copies each sqfs to NVMe (`$SLURM_TMPDIR`), mounts with `squashfuse`, scatters images to NVMe shard dirs, packs to Lustre tars. Two-batch strategy keeps NVMe usage under 7 TB. | +| `create_webdataset_generic.py` | On **any machine** where sqfs are already mounted. Takes a parent directory containing `task_0/` … `task_9/` subdirectories. Reads images directly from the mounted dirs and streams them into tar shards — no NVMe copy, no squashfuse calls. | + +The fir-specific version (`create_webdataset.py`) exists because the SquashFS +files are ~1 TB each and reading them directly from Lustre during scatter is +too slow — copying to NVMe first gives ~10× better I/O throughput. On machines +where the sqfs are already mounted (or images are on a fast local filesystem), +the generic version is simpler and produces identical output. + +To use the generic version, mount each sqfs into a parent directory first: + +```bash +mkdir -p /mnt/images/task_{0..9} +for T in $(seq 0 9); do + squashfuse /path/to/task_${T}.sqfs /mnt/images/task_${T} +done + +python create_webdataset_generic.py \ + --images-dir /mnt/images \ + --split-csvs train:splits/train.csv val:splits/val.csv test:splits/test.csv \ + --output-dir /output/global_wds +``` + +--- + +## Stage 5 — Train + +**SLURM job:** `scripts/job_bq_train.sh` + +Trains a ResNet-50 classifier on the WebDataset shards using the +`ami-classification train-model` CLI. Shard counts and number of classes are +resolved dynamically from the output directory. Automatically resumes from the +latest checkpoint if one exists in the models directory. + +```bash +sbatch --dependency=afterok: scripts/job_bq_train.sh +``` + +**Input:** `/scratch/melabbas/global_wds/{train,val,test}/*.tar`, `class_map.json` +**Output:** `/project/6068129/melabbas/ami-ml/models/global_wds/*.pt` + +--- + +## Chaining the Full Pipeline + +```bash +DOWNLOAD_JOB=$(sbatch --parsable scripts/job_bq_download.sh) +EXPORT_JOB=$(sbatch --parsable --dependency=afterok:$DOWNLOAD_JOB scripts/job_bq_export.sh) +SPLIT_JOB=$(sbatch --parsable --dependency=afterok:$EXPORT_JOB scripts/job_bq_split.sh) +WDS_JOB=$(sbatch --parsable --dependency=afterok:$SPLIT_JOB scripts/job_bq_webdataset.sh) +sbatch --dependency=afterok:$WDS_JOB scripts/job_bq_train.sh +``` + +--- + +## File Reference + +| File | Role | +|---|---| +| `download_images.py` | Stage 1 — fetch images from iNaturalist, verify with PIL, write results back to BQ (`training_images_downloads`), pack into per-chunk sqfs files | +| `bq_export.py` | Stage 2 — export BigQuery query results to CSV | +| `split.py` | Stage 3 — stratified train/val/test split with occurrence-level grouping | +| `create_webdataset.py` | Stage 4 — fir-specific, NVMe-optimised WebDataset packer | +| `create_webdataset_generic.py` | Stage 4 — generic WebDataset packer for pre-mounted image directories | +| `queries/global_min25occ.sql` | BQ query — species with ≥ 25 occurrences (~10.6M images) | +| `queries/global_max2000img.sql` | BQ query — capped at 2,000 images per species | diff --git a/src/dataset_tools/bq_squashfs/bq_export.py b/src/dataset_tools/bq_squashfs/bq_export.py new file mode 100644 index 0000000..ce5d0cb --- /dev/null +++ b/src/dataset_tools/bq_squashfs/bq_export.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +Export BigQuery query results to a CSV file. + +Reads a SQL query from a file, executes it against BigQuery, and streams +results row-by-row to a CSV file. Generic — works with any SELECT query. + +Usage: + python bq_export.py \\ + --query-file queries/global_min25occ.sql \\ + --output /project/.../global_min25occ.csv + + python bq_export.py \\ + --query-file queries/global_max2000img.sql \\ + --output /project/.../global_max2000img.csv +""" + +import argparse +import csv +import sys +import time +from pathlib import Path + +from google.cloud import bigquery + +LOG_EVERY = 500_000 + + +def log(msg: str) -> None: + print(msg, flush=True) + + +def export(query_file: Path, output: Path, project: str) -> None: + query = query_file.read_text().strip() + if not query: + log(f"ERROR: query file {query_file} is empty", file=sys.stderr) + sys.exit(1) + + client = bigquery.Client(project=project) + + log(f"Query file : {query_file}") + log(f"Output : {output}") + log(f"Submitting query ...") + + job = client.query(query) + log(f"Job ID : {job.job_id}") + log("Streaming rows to CSV ...") + + output.parent.mkdir(parents=True, exist_ok=True) + t0 = time.perf_counter() + + rows_written = 0 + unique_species: set[str] = set() + unique_occurrences: set[int] = set() + + with open(output, "w", newline="") as f: + writer = None + for row in job.result(): + row_dict = dict(row) + + if writer is None: + writer = csv.DictWriter(f, fieldnames=list(row_dict.keys())) + writer.writeheader() + + writer.writerow(row_dict) + rows_written += 1 + + if "species_name" in row_dict: + unique_species.add(row_dict["species_name"]) + if "gbif_id" in row_dict: + unique_occurrences.add(row_dict["gbif_id"]) + + if rows_written % LOG_EVERY == 0: + elapsed = (time.perf_counter() - t0) / 60 + rate = rows_written / (time.perf_counter() - t0) + log(f" {rows_written:,} rows {elapsed:.1f} min ({rate:,.0f} rows/s)") + + if rows_written == 0: + log("ERROR: query returned 0 rows", file=sys.stderr) + sys.exit(1) + + elapsed = (time.perf_counter() - t0) / 60 + size_mb = output.stat().st_size / 1024**2 + + log(f"\n=== Export complete ===") + log(f" Rows written : {rows_written:,}") + if unique_species: + log(f" Unique species : {len(unique_species):,}") + if unique_occurrences: + log(f" Unique gbif_ids : {len(unique_occurrences):,}") + log(f" Output : {output} ({size_mb:.0f} MB)") + log(f" Elapsed : {elapsed:.1f} min") + + +def main() -> None: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--query-file", required=True, type=Path, + help="Path to a .sql file containing the SELECT query") + parser.add_argument("--output", required=True, type=Path, + help="Path to write the output CSV file") + parser.add_argument("--project", default="leps-ai", + help="GCP project (default: leps-ai)") + args = parser.parse_args() + + if not args.query_file.exists(): + print(f"ERROR: query file not found: {args.query_file}", file=sys.stderr) + sys.exit(1) + + export( + query_file=args.query_file, + output=args.output, + project=args.project, + ) + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/create_test_tables.py b/src/dataset_tools/bq_squashfs/create_test_tables.py new file mode 100644 index 0000000..79fdec8 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/create_test_tables.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Create small BQ test tables for testing download_images.py without touching production. + +Creates: + _test_training_images — 50 rows sampled from training_images, fetch_status='pending' + _test_training_images_downloads — empty table, same schema as training_images_downloads + +Usage: + python create_test_tables.py + python create_test_tables.py --n-rows 100 +""" + +import argparse +from google.cloud import bigquery + +BQ_PROJECT = "leps-ai" +BQ_DATASET = "global_butterflies_2604" + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--n-rows", type=int, default=50, + help="Number of rows to sample from training_images (default: 50)") + parser.add_argument("--dataset", default=BQ_DATASET, + help=f"BigQuery dataset name (default: {BQ_DATASET}). " + f"Example: --dataset global_all_leps_2605") + args = parser.parse_args() + + client = bigquery.Client(project=BQ_PROJECT) + prefix = f"{BQ_PROJECT}.{args.dataset}" + + # ── _test_training_images ───────────────────────────────────────────────── + print(f"Creating {prefix}._test_training_images ({args.n_rows} rows)...") + client.query(f""" + CREATE OR REPLACE TABLE `{prefix}._test_training_images` AS + SELECT + photo_id, + gbif_id, + inat_taxon_id, + dataset_source_uuid, + absolute_url, + relative_local_path, + 'pending' AS fetch_status, + CAST(NULL AS INT64) AS image_width, + CAST(NULL AS INT64) AS image_height, + CAST(NULL AS INT64) AS image_size, + CAST(NULL AS BOOL) AS corrupted + FROM `{prefix}.training_images` + WHERE fetch_status = 'downloaded' OR fetch_status IS NULL + ORDER BY RAND() + LIMIT {args.n_rows} + """).result() + + n = client.get_table(f"{prefix}._test_training_images").num_rows + print(f" Created: {n} rows, all fetch_status='pending'") + + # ── _test_training_images_downloads ─────────────────────────────────────── + print(f"Creating {prefix}._test_training_images_downloads (empty)...") + client.query(f""" + CREATE OR REPLACE TABLE `{prefix}._test_training_images_downloads` + ( + dataset_source_uuid STRING, + fetch_status STRING, + image_width INT64, + image_height INT64, + image_size INT64, + corrupted BOOL + ) + """).result() + print(" Created: 0 rows") + + print("\nDone. Run test with:") + print(f" python download_images.py \\") + print(f" --staging-dir /scratch/$USER/test_download \\") + print(f" --num-jobs 1 --task-id 0 \\") + print(f" --num-workers 8 --chunk-size {args.n_rows} \\") + print(f" --limit {args.n_rows} --table-prefix _test_") + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/create_webdataset.py b/src/dataset_tools/bq_squashfs/create_webdataset.py new file mode 100644 index 0000000..c8dad44 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/create_webdataset.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +""" +Build a WebDataset from per-split CSVs and SquashFS files. + +Workflow: + 1. Load per-split CSVs (produced by split_csv.py) + 2. Assign class IDs alphabetically from species_name if class_id absent + 3. Assign each image to a shard within its split + — hash-based (--n-shards) or CSV-order-based (--images-per-shard) + 4. For each sqfs: copy to NVMe → mount → scatter images to split/shard dirs + 5. Pack each split's shard dirs → Lustre tar files (shuffled within each shard) + +CSV required columns (same in all split files): + photo_id — integer; task_id = photo_id % 10 (which sqfs) + relative_local_path — path of the image within the sqfs + +CSV class column (one of): + class_id — integer; used directly if present + species_name — used to assign class_id alphabetically if class_id absent + +Any other CSV columns are passed through into per-sample .json metadata. + +Typical usage — two batches for global WDS (stays under 7 TB NVMe peak): + + Batch 1 — sqfs 0–4, create new tars: + python create_webdataset.py \\ + --split-csvs train:train.csv val:val.csv test:test.csv \\ + --sqfs-paths task_0.sqfs ... task_4.sqfs \\ + --sqfs-start-idx 0 \\ + --n-shards 10700 \\ + --nvme-dir $SLURM_TMPDIR \\ + --output-dir /scratch/melabbas/global_wds \\ + --tar-mode w + + Batch 2 — sqfs 5–9, append to existing tars: + python create_webdataset.py \\ + --split-csvs train:train.csv val:val.csv test:test.csv \\ + --sqfs-paths task_5.sqfs ... task_9.sqfs \\ + --sqfs-start-idx 5 \\ + --n-shards 10700 \\ + --nvme-dir $SLURM_TMPDIR \\ + --output-dir /scratch/melabbas/global_wds \\ + --tar-mode a + +Shard assignment is deterministic (hash of rel_path + seed) — the same image +always maps to the same shard, so --tar-mode a is safe to retry. +""" + +import argparse +import hashlib +import json +import math +import random +import shutil +import subprocess +import tarfile +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pandas as pd + +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"} + +_t_start = time.perf_counter() + + +# ── Logging ─────────────────────────────────────────────────────────────────── + +def log(msg: str) -> None: + elapsed = (time.perf_counter() - _t_start) / 3600 + print(f"[{elapsed:5.2f}h] {msg}", flush=True) + + +# ── Disk helpers ────────────────────────────────────────────────────────────── + +def disk_free_tb(path: str) -> float: + r = subprocess.run(["df", "--output=avail", "-k", path], + capture_output=True, text=True) + return int(r.stdout.strip().split()[-1]) / 1024**3 + + +# ── squashfuse helpers ──────────────────────────────────────────────────────── + +def sqfs_mount(sqfs_path: Path, mnt_dir: Path) -> None: + r = subprocess.run(["squashfuse", str(sqfs_path), str(mnt_dir)], + capture_output=True, text=True) + if r.returncode != 0: + raise RuntimeError(f"squashfuse failed for {sqfs_path}: {r.stderr.strip()}") + + +def sqfs_unmount(mnt_dir: Path) -> None: + subprocess.run(["fusermount", "-u", str(mnt_dir)], capture_output=True) + + +# ── CSV loading ─────────────────────────────────────────────────────────────── + +def load_split_csvs( + split_csvs: list[tuple[str, str]], + taxon_id_column: str = "inat_taxon_id", +) -> tuple[dict[str, pd.DataFrame], dict[int, int]]: + """Load per-split CSVs, assign class_ids if absent, return (dfs, class_map). + + If 'class_id' is absent, taxon_id_column is required and must be present in + every split CSV. Taxon IDs are sorted numerically and mapped to sequential + class IDs. class_map.json stores {sequential_id: taxon_id}. + """ + dfs: dict[str, pd.DataFrame] = {} + for name, path in split_csvs: + df = pd.read_csv(path, dtype={"photo_id": "Int64"}) + missing = {"photo_id", "relative_local_path"} - set(df.columns) + if missing: + raise ValueError(f"{path}: missing required columns {missing}") + assert taxon_id_column in df.columns, ( + f"{path}: required label column '{taxon_id_column}' not found " + f"(available: {list(df.columns)})" + ) + df["task_id"] = (df["photo_id"] % 10).astype(int) + df["row_idx"] = range(len(df)) # position in this split's CSV (used by --images-per-shard) + dfs[name] = df + log(f" {name}: {len(df):,} rows from {path}") + + # Assign class_ids across all splits together so IDs are consistent + all_df = pd.concat(dfs.values(), ignore_index=True) + class_map: dict[int, int] = {} + if "class_id" not in all_df.columns: + taxon_ids = sorted(all_df[taxon_id_column].dropna().unique()) + id_map = {int(tid): idx for idx, tid in enumerate(taxon_ids)} + for split_name, df in dfs.items(): + dfs[split_name]["class_id"] = df[taxon_id_column].map(id_map).astype("Int64") + log(f" Assigned class_ids for {len(taxon_ids):,} taxon IDs " + f"(column='{taxon_id_column}', sorted numerically)") + class_map = {idx: int(tid) for tid, idx in id_map.items()} + else: + # class_id already in CSV — build reverse map from taxon_id + class_map = ( + all_df[["class_id", taxon_id_column]] + .dropna() + .drop_duplicates("class_id") + .set_index("class_id")[taxon_id_column] + .astype(int) + .to_dict() + ) + + return dfs, class_map + + +def save_class_map(class_map: dict[int, str], output_dir: Path) -> None: + dest = output_dir / "class_map.json" + if dest.exists(): + log(f"class_map.json already exists — skipping ({dest})") + return + with open(dest, "w") as f: + json.dump({str(k): v for k, v in sorted(class_map.items())}, f, indent=2) + log(f"Saved class_map.json ({len(class_map):,} classes → {dest})") + + +# ── Shard assignment ────────────────────────────────────────────────────────── + + + +# ── Build per-task lookup ───────────────────────────────────────────────────── + +def build_lookup( + dfs: dict[str, pd.DataFrame], + task_ids: list[int], + meta_columns: list[str], + images_per_shard: int, +) -> dict[int, dict[str, dict]]: + """Build {task_id: {rel_path: {split, shard_id, class_id, meta}}}. + + Shard assignment: CSV-order — row_idx // images_per_shard. + First N rows → shard 0, next N → shard 1, etc. + """ + lookup: dict[int, dict[str, dict]] = {tid: {} for tid in task_ids} + + for split, df in dfs.items(): + subset = df[df["task_id"].isin(task_ids)] + for row in subset.itertuples(index=False): + rel = row.relative_local_path + tid = row.task_id + meta = { + col: getattr(row, col) + for col in meta_columns + if hasattr(row, col) and not pd.isna(getattr(row, col)) + } + lookup[tid][rel] = { + "split": split, + "shard_id": int(row.row_idx) // images_per_shard, + "class_id": int(row.class_id), + "meta": meta, + } + + for tid in task_ids: + log(f" task_{tid}: {len(lookup[tid]):,} images in lookup") + return lookup + + +# ── Scatter ─────────────────────────────────────────────────────────────────── + +def scatter_sqfs( + sqfs_path: Path, + sqfs_idx: int, + nvme_dir: Path, + nvme_mnt: Path, + nvme_shards: Path, + task_lookup: dict[str, dict], + no_nvme_copy: bool = False, + limit: int = 0, +) -> tuple[int, int]: + """Copy sqfs to NVMe, mount, scatter images + labels to split/shard dirs. + + --no-nvme-copy: mount sqfs directly from source (skips copy; for testing). + --limit N: stop after N images per sqfs (for smoke tests). + """ + if no_nvme_copy: + mount_target = sqfs_path + log(f" [{sqfs_path.name}] mounting directly (--no-nvme-copy)") + else: + mount_target = nvme_dir / sqfs_path.name + log(f" [{sqfs_path.name}] copying to NVMe " + f"(NVMe free: {disk_free_tb(str(nvme_dir)):.2f} TB)") + t0 = time.perf_counter() + shutil.copy2(str(sqfs_path), str(mount_target)) + gb = mount_target.stat().st_size / 1024**3 + elapsed = time.perf_counter() - t0 + log(f" [{sqfs_path.name}] copied {gb:.1f} GB in {elapsed:.0f}s " + f"({gb * 1024 / elapsed:.0f} MB/s)") + + sqfs_mount(mount_target, nvme_mnt) + try: + if limit: + # Don't materialise the full list — stop enumerating once limit is hit + all_paths = ( + p for p in nvme_mnt.rglob("*") + if p.suffix.lower() in IMAGE_EXTENSIONS + ) + else: + # Sort for sequential reads → squashfuse block cache reuse + all_paths = sorted( + p for p in nvme_mnt.rglob("*") + if p.suffix.lower() in IMAGE_EXTENSIONS + ) + cap = f" (capped at {limit})" if limit else "" + log(f" [{sqfs_path.name}] scattering{cap} " + f"(lookup: {len(task_lookup):,} entries)...") + + written = skipped = 0 + total_bytes = 0 + t0 = time.perf_counter() + + for img_path in all_paths: + if limit and written >= limit: + break + + rel = str(img_path.relative_to(nvme_mnt)) + entry = task_lookup.get(rel) + if entry is None: + skipped += 1 + continue + + split = entry["split"] + shard_id = entry["shard_id"] + class_id = entry["class_id"] + meta = entry["meta"] + key = hashlib.md5(f"{sqfs_idx}:{rel}".encode()).hexdigest() + shard_dir = nvme_shards / split / f"shard_{shard_id:06d}" + ext = img_path.suffix.lower() + + img_bytes = img_path.read_bytes() + (shard_dir / f"{key}{ext}").write_bytes(img_bytes) + (shard_dir / f"{key}.cls").write_bytes(str(class_id).encode()) + (shard_dir / f"{key}.json").write_bytes( + json.dumps({ + "class_id": class_id, + "relative_local_path": rel, + "sqfs_idx": sqfs_idx, + **meta, + }).encode() + ) + + total_bytes += len(img_bytes) + written += 1 + if written % 100_000 == 0: + elapsed = time.perf_counter() - t0 + log(f" {written:,}/{len(all_paths):,} " + f"{total_bytes/1024**3:.1f} GB {written/elapsed:.0f} img/s") + + elapsed = time.perf_counter() - t0 + log(f" [{sqfs_path.name}] scattered {written:,} in {elapsed:.0f}s " + f"({written/elapsed:.0f} img/s) skipped={skipped:,}") + + if not limit: + missing = len(task_lookup) - written + if missing > 0: + log(f" WARNING: {missing:,} images are in the CSV but were not found in " + f"{sqfs_path.name} — they will be missing from the dataset") + finally: + sqfs_unmount(nvme_mnt) + + if not no_nvme_copy: + mount_target.unlink() + log(f" [{sqfs_path.name}] deleted NVMe free: {disk_free_tb(str(nvme_dir)):.2f} TB") + return written, skipped + + +# ── Pack ────────────────────────────────────────────────────────────────────── + +def pack_split( + split: str, + nvme_shards: Path, + lustre_split_dir: Path, + n_shards: int, + tar_mode: str, + pack_workers: int, + seed: int, +) -> int: + """Pack one split's NVMe shard dirs → Lustre tar files. Returns total bytes.""" + split_shards = nvme_shards / split + non_empty = [ + split_shards / f"shard_{s:06d}" + for s in range(n_shards) + if (split_shards / f"shard_{s:06d}").exists() + and any((split_shards / f"shard_{s:06d}").iterdir()) + ] + log(f" [{split}] packing {len(non_empty):,} non-empty shards " + f"(mode='{tar_mode}', {pack_workers} workers)") + + rng = random.Random(seed) + + def pack_one(shard_dir: Path) -> int: + shard_id = int(shard_dir.name.split("_")[1]) + tar_path = lustre_split_dir / f"{split}-{shard_id:06d}.tar" + files = list(shard_dir.iterdir()) + if not files: + return 0 + + existing_keys: set[str] = set() + if tar_mode == "a" and tar_path.exists(): + with tarfile.open(tar_path, "r") as tf: + existing_keys = {m.name for m in tf.getmembers()} + + # Group into per-sample triplets (.cls .jpg .json), then shuffle samples + key_groups: dict[str, list[Path]] = defaultdict(list) + for p in files: + key_groups[p.stem].append(p) + keys = list(key_groups) + rng.shuffle(keys) + + total = 0 + with tarfile.open(tar_path, tar_mode) as tf: + for key in keys: + for p in sorted(key_groups[key]): # consistent order within sample + if p.name not in existing_keys: + tf.add(p, arcname=p.name) + total += p.stat().st_size + p.unlink() + return total + + t0 = time.perf_counter() + total_bytes = 0 + with ThreadPoolExecutor(max_workers=pack_workers) as pool: + for nb in pool.map(pack_one, non_empty, chunksize=32): + total_bytes += nb + elapsed = time.perf_counter() - t0 + log(f" [{split}] packed {total_bytes/1024**3:.1f} GB in {elapsed:.0f}s " + f"({total_bytes/1024**2/elapsed:.0f} MB/s)") + return total_bytes + + +def pack_to_lustre( + splits: list[str], + shards_per_split: dict[str, int], + nvme_shards: Path, + output_dir: Path, + tar_mode: str, + pack_workers: int, + seed: int, +) -> None: + for split in splits: + lustre_split_dir = output_dir / split + lustre_split_dir.mkdir(parents=True, exist_ok=True) + pack_split(split, nvme_shards, lustre_split_dir, + shards_per_split[split], tar_mode, pack_workers, seed) + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def parse_split_csvs(values: list[str]) -> list[tuple[str, str]]: + """Parse ['train:train.csv', 'val:val.csv', ...] → [('train', 'train.csv'), ...]""" + result = [] + for v in values: + name, path = v.split(":", 1) + result.append((name.strip(), path.strip())) + return result + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--split-csvs", nargs="+", required=True, + help="Per-split CSV files as 'name:path' (e.g. train:train.csv)") + parser.add_argument("--sqfs-paths", nargs="+", required=True, + help="Ordered sqfs paths for this batch") + parser.add_argument("--sqfs-start-idx", type=int, required=True, + help="Global index of first sqfs (task_id = photo_id %% 10)") + parser.add_argument("--images-per-shard", type=int, default=1000, + help="Images per shard (default: 1000). Row 0–(N-1) in the CSV → shard 0, " + "rows N–(2N-1) → shard 1, etc.") + parser.add_argument("--nvme-dir", required=True, + help="NVMe scratch root ($SLURM_TMPDIR)") + parser.add_argument("--output-dir", required=True, + help="Lustre output dir; split subdirs and class_map.json written here") + parser.add_argument("--pack-workers", type=int, default=16, + help="Parallel workers for tar packing") + parser.add_argument("--tar-mode", choices=["w", "a"], default="w", + help="'w' create new tars, 'a' append (idempotent on retry)") + parser.add_argument("--taxon-id-column", default="inat_taxon_id", + help="CSV column to use as label (default: inat_taxon_id). " + "Taxon IDs are sorted numerically and mapped to sequential " + "class IDs. Ignored if 'class_id' is already in the CSV.") + parser.add_argument("--seed", type=int, default=42, + help="Seed for shard assignment and within-shard shuffle") + parser.add_argument("--no-nvme-copy", action="store_true", + help="Mount sqfs directly from source, skip copy to NVMe. " + "Use for smoke-testing scatter/pack logic without needing " + "full NVMe space.") + parser.add_argument("--limit", type=int, default=0, + help="Stop scatter after N images per sqfs (0 = no limit). " + "Use with --no-nvme-copy for a fast end-to-end smoke test.") + args = parser.parse_args() + + + nvme_dir = Path(args.nvme_dir) + nvme_mnt = nvme_dir / "mnt" + nvme_shards = nvme_dir / "shards" + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + split_csvs = parse_split_csvs(args.split_csvs) + task_ids = list(range(args.sqfs_start_idx, + args.sqfs_start_idx + len(args.sqfs_paths))) + + log("=== create_webdataset ===") + log(f"Splits: {[n for n,_ in split_csvs]}") + log(f"Tasks: {task_ids} sqfs: {len(args.sqfs_paths)}") + log(f"Output: {output_dir}") + if args.no_nvme_copy: + log("Mode: --no-nvme-copy (smoke test — sqfs mounted in place)") + if args.limit: + log(f"Mode: --limit {args.limit} images per sqfs (smoke test)") + print() + + # ── NVMe preflight ──────────────────────────────────────────────────────── + # Peak NVMe usage occurs during scatter of the last sqfs in the batch: + # peak = batch_size × avg_scatter_per_sqfs + largest_sqfs_file + # scatter_per_sqfs ≈ sqfs_size (images decompressed into shard dirs, ~3 files + # per image but .cls/.json are tiny vs JPEG, so ≈ sqfs size) + # This matches the analysis in job_build_wds_global.sh: + # N=4 (5th sqfs): 5×1.07 + 1.1 = 6.45 TB (safe under 7 TB with BATCH_SIZE=5) + if not args.no_nvme_copy: + sqfs_sizes_gb = [Path(p).stat().st_size / 1024**3 for p in args.sqfs_paths] + n = len(sqfs_sizes_gb) + avg_scatter = sum(sqfs_sizes_gb) / n + peak_gb = n * avg_scatter + max(sqfs_sizes_gb) + nvme_free_gb = disk_free_tb(str(nvme_dir)) * 1024 + log(f"NVMe preflight: {n} sqfs avg={avg_scatter:.1f} GB " + f"peak estimate={peak_gb:.0f} GB ({n}×{avg_scatter:.1f} + {max(sqfs_sizes_gb):.1f}) " + f"NVMe free={nvme_free_gb:.0f} GB") + if peak_gb > nvme_free_gb * 0.9: + log(f"WARNING: estimated peak ({peak_gb:.0f} GB) exceeds 90% of NVMe " + f"({nvme_free_gb:.0f} GB) — reduce --batch-size or request more --tmp") + print() + + # ── Load CSVs ───────────────────────────────────────────────────────────── + log("Loading split CSVs ...") + dfs, class_map = load_split_csvs(split_csvs, taxon_id_column=args.taxon_id_column) + save_class_map(class_map, output_dir) + + shards_per_split = { + name: math.ceil(len(df) / args.images_per_shard) + for name, df in dfs.items() + } + log(f"Shards per split (--images-per-shard={args.images_per_shard}): {shards_per_split}") + + structural = {"photo_id", "relative_local_path", "class_id", "task_id", "split", "row_idx"} + all_cols = set().union(*(df.columns for df in dfs.values())) + meta_columns = [c for c in all_cols if c not in structural] + + # ── Build per-task lookup ────────────────────────────────────────────────── + log(f"Building image lookup for tasks {task_ids} ...") + lookup = build_lookup(dfs, task_ids, meta_columns, args.images_per_shard) + del dfs # free RAM + print() + + # ── Pre-create shard dirs ───────────────────────────────────────────────── + log("Pre-creating shard dirs on NVMe ...") + nvme_mnt.mkdir(exist_ok=True) + for split, n in shards_per_split.items(): + for s in range(n): + (nvme_shards / split / f"shard_{s:06d}").mkdir(parents=True, exist_ok=True) + log(f" Done ({sum(shards_per_split.values())} dirs across {len(shards_per_split)} splits)") + print() + + # ── Scatter phase ───────────────────────────────────────────────────────── + total_written = total_skipped = 0 + for i, sqfs_path_str in enumerate(args.sqfs_paths): + sqfs_path = Path(sqfs_path_str) + sqfs_idx = args.sqfs_start_idx + i + task_lkp = lookup.get(sqfs_idx, {}) + log(f"--- Scatter sqfs {sqfs_idx} ({sqfs_path.name}) " + f"lookup: {len(task_lkp):,} entries ---") + w, s = scatter_sqfs(sqfs_path, sqfs_idx, nvme_dir, nvme_mnt, + nvme_shards, task_lkp, + no_nvme_copy=args.no_nvme_copy, limit=args.limit) + total_written += w + total_skipped += s + print() + + if total_skipped: + log(f"WARNING: {total_skipped:,} images skipped (in sqfs but not in CSVs)") + + # ── Pack phase ──────────────────────────────────────────────────────────── + log(f"--- Pack → Lustre (mode='{args.tar_mode}') ---") + pack_to_lustre( + [n for n, _ in split_csvs], shards_per_split, + nvme_shards, output_dir, args.tar_mode, args.pack_workers, args.seed, + ) + + elapsed = (time.perf_counter() - _t_start) / 3600 + log(f"\nBatch done: {total_written:,} images written " + f"{total_skipped:,} skipped {elapsed:.2f}h elapsed") + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/create_webdataset_generic.py b/src/dataset_tools/bq_squashfs/create_webdataset_generic.py new file mode 100644 index 0000000..cf0e3e2 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/create_webdataset_generic.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +""" +Build a WebDataset from pre-mounted task directories and per-split CSVs. + +Assumes images are mounted under a parent directory with one subdirectory per +task, named task_0/ through task_9/. The task that owns each image is derived +from its photo_id: task_id = photo_id % 10. relative_local_path in the CSV is +the path of the image *within* its task directory. + +Full image path: /task_/ + +This is a generic alternative to create_webdataset.py that does not require +SquashFS, squashfuse, or NVMe scratch space. It produces identical output tar +shards and is compatible with the same training scripts. + +Usage: + python create_webdataset_from_dir.py \\ + --images-dir /mnt/images \\ + --split-csvs train:train.csv val:val.csv test:test.csv \\ + --output-dir /scratch/global_wds \\ + --images-per-shard 1000 +""" + +import argparse +import json +import math +import random +import tarfile +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pandas as pd + +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"} + +_t_start = time.perf_counter() + + +def log(msg: str) -> None: + elapsed = (time.perf_counter() - _t_start) / 3600 + print(f"[{elapsed:5.2f}h] {msg}", flush=True) + + +# ── CSV loading ─────────────────────────────────────────────────────────────── + +def load_split_csvs( + split_csvs: list[tuple[str, str]], + taxon_id_column: str, + images_per_shard: int, +) -> tuple[dict[str, pd.DataFrame], dict[int, int]]: + dfs: dict[str, pd.DataFrame] = {} + for name, path in split_csvs: + df = pd.read_csv(path, dtype={"photo_id": "Int64"}) + missing = {"photo_id", "relative_local_path"} - set(df.columns) + if missing: + raise ValueError(f"{path}: missing required columns {missing}") + assert taxon_id_column in df.columns, ( + f"{path}: required label column '{taxon_id_column}' not found " + f"(available: {list(df.columns)})" + ) + df["task_id"] = (df["photo_id"] % 10).astype(int) + df["row_idx"] = range(len(df)) + df["shard_id"] = df["row_idx"] // images_per_shard + dfs[name] = df + log(f" {name}: {len(df):,} rows {df['shard_id'].max() + 1} shards ({path})") + + all_df = pd.concat(dfs.values(), ignore_index=True) + class_map: dict[int, int] = {} + if "class_id" not in all_df.columns: + taxon_ids = sorted(all_df[taxon_id_column].dropna().unique()) + id_map = {int(tid): idx for idx, tid in enumerate(taxon_ids)} + for split_name, df in dfs.items(): + dfs[split_name]["class_id"] = df[taxon_id_column].map(id_map).astype("Int64") + log(f" Assigned class_ids for {len(taxon_ids):,} taxon IDs (column='{taxon_id_column}')") + class_map = {idx: int(tid) for tid, idx in id_map.items()} + else: + class_map = ( + all_df[["class_id", taxon_id_column]] + .dropna() + .drop_duplicates("class_id") + .set_index("class_id")[taxon_id_column] + .astype(int) + .to_dict() + ) + + return dfs, class_map + + +def save_class_map(class_map: dict[int, int], output_dir: Path) -> None: + dest = output_dir / "class_map.json" + if dest.exists(): + log(f"class_map.json already exists — skipping ({dest})") + return + with open(dest, "w") as f: + json.dump({str(k): v for k, v in sorted(class_map.items())}, f, indent=2) + log(f"Saved class_map.json ({len(class_map):,} classes → {dest})") + + +# ── Build lookup ────────────────────────────────────────────────────────────── + +def build_lookup( + dfs: dict[str, pd.DataFrame], + meta_columns: list[str], +) -> dict[str, dict]: + """Build {relative_local_path: {split, shard_id, task_id, class_id, meta}}.""" + lookup: dict[str, dict] = {} + for split, df in dfs.items(): + for row in df.itertuples(index=False): + meta = { + col: getattr(row, col) + for col in meta_columns + if hasattr(row, col) and not pd.isna(getattr(row, col)) + } + lookup[row.relative_local_path] = { + "split": split, + "shard_id": int(row.shard_id), + "task_id": int(row.task_id), + "class_id": int(row.class_id), + "meta": meta, + } + log(f"Lookup built: {len(lookup):,} entries across {len(dfs)} splits") + return lookup + + +# ── Walk image dirs and collect per-shard path lists ───────────────────────── + +def collect_shards( + images_dir: Path, + lookup: dict[str, dict], +) -> tuple[dict[tuple[str, int], list[tuple[Path, str, dict]]], int]: + """ + Walk task_0/ … task_9/ under images_dir. + For each image found in the lookup, append (abs_path, rel_path, entry) + to the appropriate (split, shard_id) bucket. + + Returns (shard_buckets, n_missing). + """ + shard_buckets: dict[tuple[str, int], list] = defaultdict(list) + found = missing = skipped = 0 + t0 = time.perf_counter() + + for task_dir in sorted(images_dir.iterdir()): + if not task_dir.is_dir() or not task_dir.name.startswith("task_"): + continue + + log(f" Walking {task_dir.name} ...") + for img_path in task_dir.rglob("*"): + if img_path.suffix.lower() not in IMAGE_EXTENSIONS: + continue + + rel = str(img_path.relative_to(task_dir)) + entry = lookup.get(rel) + if entry is None: + skipped += 1 + continue + + key = (entry["split"], entry["shard_id"]) + shard_buckets[key].append((img_path, rel, entry)) + found += 1 + + if found % 500_000 == 0: + elapsed = time.perf_counter() - t0 + log(f" {found:,} found {skipped:,} skipped {found/elapsed:.0f} img/s") + + missing = len(lookup) - found + elapsed = time.perf_counter() - t0 + log(f"Walk complete: {found:,} found {skipped:,} not-in-csv " + f"{missing:,} in-csv-not-found {elapsed:.0f}s") + return dict(shard_buckets), missing + + +# ── Pack shards to tar ──────────────────────────────────────────────────────── + +def pack_shards( + shard_buckets: dict[tuple[str, int], list], + output_dir: Path, + splits: list[str], + pack_workers: int, + seed: int, +) -> None: + rng = random.Random(seed) + + for split in splits: + split_dir = output_dir / split + split_dir.mkdir(parents=True, exist_ok=True) + + items = list(shard_buckets.items()) + log(f"Packing {len(items):,} shards ({pack_workers} workers) ...") + t0 = time.perf_counter() + total_bytes = total_images = 0 + + def pack_one(item: tuple) -> tuple[int, int]: + (split, shard_id), entries = item + rng.shuffle(entries) + tar_path = output_dir / split / f"{split}-{shard_id:06d}.tar" + nb = ni = 0 + with tarfile.open(tar_path, "w") as tf: + for img_path, rel, entry in entries: + import hashlib + key = hashlib.md5(f"{entry['task_id']}:{rel}".encode()).hexdigest() + ext = img_path.suffix.lower() + + img_bytes = img_path.read_bytes() + cls_bytes = str(entry["class_id"]).encode() + meta_bytes = json.dumps({ + "class_id": entry["class_id"], + "relative_local_path": rel, + "task_id": entry["task_id"], + **entry["meta"], + }).encode() + + for suffix, data in [(ext, img_bytes), (".cls", cls_bytes), (".json", meta_bytes)]: + import io + buf = io.BytesIO(data) + info = tarfile.TarInfo(name=f"{key}{suffix}") + info.size = len(data) + tf.addfile(info, buf) + + nb += len(img_bytes) + ni += 1 + return nb, ni + + with ThreadPoolExecutor(max_workers=pack_workers) as pool: + for nb, ni in pool.map(pack_one, items): + total_bytes += nb + total_images += ni + + elapsed = time.perf_counter() - t0 + log(f"Packed {total_images:,} images {total_bytes/1024**3:.1f} GB " + f"{elapsed:.0f}s ({total_bytes/1024**2/elapsed:.0f} MB/s)") + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def parse_split_csvs(values: list[str]) -> list[tuple[str, str]]: + return [(v.split(":", 1)[0].strip(), v.split(":", 1)[1].strip()) for v in values] + + +def main() -> None: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--images-dir", required=True, + help="Parent dir containing task_0/ … task_9/ subdirectories") + parser.add_argument("--split-csvs", nargs="+", required=True, + help="Per-split CSV files as 'name:path' (e.g. train:train.csv)") + parser.add_argument("--output-dir", required=True, + help="Output directory; split subdirs and class_map.json written here") + parser.add_argument("--images-per-shard", type=int, default=1000, + help="Images per shard, assigned by CSV row order (default: 1000)") + parser.add_argument("--taxon-id-column", default="inat_taxon_id", + help="CSV column used as label (default: inat_taxon_id)") + parser.add_argument("--pack-workers", type=int, default=16, + help="Parallel workers for tar packing (default: 16)") + parser.add_argument("--seed", type=int, default=42, + help="Seed for within-shard shuffle (default: 42)") + args = parser.parse_args() + + images_dir = Path(args.images_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + split_csvs = parse_split_csvs(args.split_csvs) + splits = [name for name, _ in split_csvs] + + log("=== create_webdataset_from_dir ===") + log(f"Images dir : {images_dir}") + log(f"Splits : {splits}") + log(f"Output : {output_dir}") + print() + + log("Loading split CSVs ...") + dfs, class_map = load_split_csvs(split_csvs, args.taxon_id_column, args.images_per_shard) + save_class_map(class_map, output_dir) + print() + + structural = {"photo_id", "relative_local_path", "class_id", "task_id", "shard_id", "row_idx"} + all_cols = set().union(*(df.columns for df in dfs.values())) + meta_columns = [c for c in all_cols if c not in structural] + + log("Building image lookup ...") + lookup = build_lookup(dfs, meta_columns) + del dfs + print() + + log("Collecting images from task directories ...") + shard_buckets, n_missing = collect_shards(images_dir, lookup) + if n_missing > 0: + log(f"WARNING: {n_missing:,} images are in the CSVs but were not found on disk") + print() + + log("Packing shards → output dir ...") + pack_shards(shard_buckets, output_dir, splits, args.pack_workers, args.seed) + + elapsed = (time.perf_counter() - _t_start) / 3600 + log(f"\nDone in {elapsed:.2f}h") + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/download_images.py b/src/dataset_tools/bq_squashfs/download_images.py new file mode 100644 index 0000000..033c524 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/download_images.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python3 +""" +Download images from training_images BQ table to a local staging directory, +record results in training_images_downloads, merge status back into +training_images, then pack images into SquashFS chunks. + +Images are split across parallel jobs using MOD(photo_id, num_jobs) = task_id +so each job handles a balanced, non-overlapping subset of images. + +Resumable: already-attempted images are skipped by LEFT JOINing with +training_images_downloads. Re-running the same task_id is safe. + +NOTE — mid-chunk restart behaviour: if the job dies after downloading a chunk +but before the BQ write, those images are on disk but unrecorded. On resume +the LEFT JOIN will re-queue them and they will be re-downloaded. No data is +lost but ~chunk_size images are downloaded twice. This is acceptable given +the low probability and low cost of a single chunk redo. + +Usage (single job / test): + python download_images.py \\ + --staging-dir /scratch/$USER/staging \\ + --num-jobs 1 --task-id 0 \\ + --limit 50 --table-prefix test_ + +Usage (SLURM array): + python download_images.py \\ + --staging-dir /scratch/$USER/staging \\ + --num-jobs 10 --task-id $SLURM_ARRAY_TASK_ID + +After all array tasks finish, run job_bq_pack_per_task.sh to merge +chunk sqfs files into the final task_N.sqfs archives. +""" + +import argparse +import random +import subprocess +import threading +import time +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import pandas as pd +import PIL +import requests +from google.api_core import exceptions as google_exceptions +from google.cloud import bigquery +from PIL import Image + +Image.MAX_IMAGE_PIXELS = None + +BQ_PROJECT = "leps-ai" +BQ_DEFAULT_DATASET = "global_butterflies_2604" + +# Retry config for HTTP downloads +_RETRY_STATUSES = {429, 500, 502, 503, 504} +_MAX_RETRIES = 5 +_BACKOFF_BASE = 2.0 # seconds +_BACKOFF_MAX = 60.0 # seconds cap +_MERGE_MAX_RETRIES = 10 # BQ MERGE serialization conflicts (concurrent tasks) + +# Warn if this many chunk sqfs files accumulate (pack job falling behind) +_CHUNK_ACCUMULATION_WARN = 20 + +DOWNLOADS_SCHEMA = [ + bigquery.SchemaField("dataset_source_uuid", "STRING"), + bigquery.SchemaField("fetch_status", "STRING"), + bigquery.SchemaField("image_width", "INTEGER"), + bigquery.SchemaField("image_height", "INTEGER"), + bigquery.SchemaField("image_size", "INTEGER"), + bigquery.SchemaField("corrupted", "BOOLEAN"), +] + +# Thread-local storage for per-thread requests sessions +_thread_local = threading.local() + + +def _get_session() -> requests.Session: + """Return a per-thread requests.Session with a single keep-alive connection.""" + if not hasattr(_thread_local, "session"): + s = requests.Session() + adapter = requests.adapters.HTTPAdapter( + pool_connections=1, + pool_maxsize=1, + max_retries=0, # retries handled manually below + ) + s.mount("https://", adapter) + s.mount("http://", adapter) + _thread_local.session = s + return _thread_local.session + + +def _fetch_with_retry(url: str, dest: Path) -> None: + """Download url → dest with exponential backoff + jitter. + + Retries on rate-limit (429), transient server errors (5xx), + connection errors (including Errno 16 — too many open sockets), + and timeouts. Raises on permanent client errors (4xx except 429) + or after exhausting all retries. + """ + session = _get_session() + for attempt in range(_MAX_RETRIES + 1): + try: + resp = session.get(url, timeout=30, stream=True) + if resp.status_code in _RETRY_STATUSES and attempt < _MAX_RETRIES: + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.25) + print( + f" HTTP {resp.status_code} {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s", + flush=True, + ) + time.sleep(delay) + continue + resp.raise_for_status() + with open(dest, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + return + except requests.exceptions.ConnectionError as e: + if attempt < _MAX_RETRIES: + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.25) + print( + f" ConnectionError {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s: {e}", + flush=True, + ) + time.sleep(delay) + else: + raise + except requests.exceptions.Timeout: + if attempt < _MAX_RETRIES: + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.25) + print( + f" Timeout {url} — retry {attempt+1}/{_MAX_RETRIES} " + f"in {delay:.1f}s", + flush=True, + ) + time.sleep(delay) + else: + raise + raise RuntimeError(f"Exhausted {_MAX_RETRIES} retries for {url}") + + +def download_and_verify(row: dict, staging_dir: Path) -> dict: + """Download one image, verify with PIL, return result dict.""" + url = row["absolute_url"] + rel_path = row["relative_local_path"] + dest = staging_dir / rel_path + dest.parent.mkdir(parents=True, exist_ok=True) + + result = { + "dataset_source_uuid": row["dataset_source_uuid"], + "fetch_status": None, + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + } + + try: + _fetch_with_retry(url, dest) + except Exception as e: + print(f" Failed {url}: {e}", flush=True) + result["fetch_status"] = "failed" + return result + + try: + with Image.open(dest) as img: + img.convert("RGB") + result["image_width"], result["image_height"] = img.size + result["image_size"] = dest.stat().st_size + result["corrupted"] = False + result["fetch_status"] = "downloaded" + except (PIL.UnidentifiedImageError, OSError) as e: + print(f" Corrupted {url}: {e}", flush=True) + result["corrupted"] = True + result["fetch_status"] = "corrupted" + + return result + + +def ensure_downloads_table(client: bigquery.Client, downloads_table: str) -> None: + """Create the downloads table if it doesn't exist.""" + try: + client.get_table(downloads_table) + except Exception: + table = bigquery.Table(downloads_table, schema=DOWNLOADS_SCHEMA) + table.description = ( + "Download results for training_images. One row per download attempt. " + "Appended to by parallel download jobs. Used to track fetch progress " + "without DML updates on the base training_images table." + ) + client.create_table(table) + print(f"Created table {downloads_table}", flush=True) + + +def write_results_to_bq( + client: bigquery.Client, + results: list[dict], + downloads_table: str, + max_retries: int = 3, +) -> None: + """Append download results to the downloads table via batch load (free tier). + + Multiple parallel tasks can safely append simultaneously. + Retries up to max_retries times on transient BQ errors. + """ + df = pd.DataFrame(results) + job_config = bigquery.LoadJobConfig( + write_disposition=bigquery.WriteDisposition.WRITE_APPEND, + schema=DOWNLOADS_SCHEMA, + ) + for attempt in range(max_retries): + try: + job = client.load_table_from_dataframe( + df, downloads_table, job_config=job_config + ) + job.result() + return + except Exception as e: + if attempt < max_retries - 1: + delay = 30 * (attempt + 1) + print( + f" BQ write failed (attempt {attempt+1}/{max_retries}): {e} " + f"— retrying in {delay}s", + flush=True, + ) + time.sleep(delay) + else: + raise + + +def merge_chunk_into_training_images( + client: bigquery.Client, + results: list[dict], + training_table: str, + downloads_table: str, +) -> int: + """MERGE all chunk results into training_images, updating fetch_status for every outcome. + + Outcomes merged: + downloaded — fetch_status='downloaded', dims and corrupted populated + corrupted — fetch_status='corrupted', corrupted=True, dims NULL + failed — fetch_status='failed', all fields NULL + + Permanently failed images (404, 403, exhausted retries) are marked + fetch_status='failed' so they are excluded from future re-runs via + the WHERE fetch_status='pending' clause — no wasted retry attempts. + Retrying can still be done intentionally via retry_failed_downloads.py. + + Uses a temp table containing only this chunk's rows so the MERGE scans + a small dataset rather than all of training_images_downloads. + Only updates rows that are still 'pending' — safe from parallel tasks. + Returns the number of rows updated. + """ + to_merge = [ + r for r in results if r["fetch_status"] in ("downloaded", "corrupted", "failed") + ] + if not to_merge: + return 0 + + # Derive "project.dataset" from training_table ("project.dataset.table_name") + _parts = training_table.split(".") + _dataset_ref = ".".join(_parts[:2]) if len(_parts) >= 2 else _parts[0] + tmp_table = f"{_dataset_ref}._dl_merge_tmp_{uuid.uuid4().hex[:8]}" + df = pd.DataFrame(to_merge) + job_config = bigquery.LoadJobConfig( + write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE, + schema=DOWNLOADS_SCHEMA, + ) + client.load_table_from_dataframe(df, tmp_table, job_config=job_config).result() + + merge_sql = f""" + MERGE `{training_table}` T + USING `{tmp_table}` S + ON T.dataset_source_uuid = S.dataset_source_uuid + WHEN MATCHED AND (T.fetch_status = 'pending' OR T.fetch_status IS NULL) THEN UPDATE SET + T.fetch_status = S.fetch_status, + T.image_width = S.image_width, + T.image_height = S.image_height, + T.image_size = S.image_size, + T.corrupted = S.corrupted + """ + try: + # Concurrent MERGEs from parallel tasks can collide with + # "Could not serialize access ... due to concurrent update" (400). + # BQ docs recommend retrying — back off with jitter until a slot frees. + for attempt in range(_MERGE_MAX_RETRIES + 1): + try: + job = client.query(merge_sql) + job.result() + return job.dml_stats.updated_row_count + except google_exceptions.BadRequest as e: + if "serialize" not in str(e).lower() or attempt >= _MERGE_MAX_RETRIES: + raise + delay = min(_BACKOFF_BASE * (2**attempt), _BACKOFF_MAX) + delay += random.uniform(0, delay * 0.5) + print( + f" MERGE serialization conflict — retry " + f"{attempt+1}/{_MERGE_MAX_RETRIES} in {delay:.0f}s", + flush=True, + ) + time.sleep(delay) + finally: + client.delete_table(tmp_table, not_found_ok=True) + + +def get_pending_rows( + client: bigquery.Client, + training_table: str, + downloads_table: str, + num_jobs: int, + task_id: int, + limit: int | None = None, + force_redownload: bool = False, +) -> list[dict]: + """Query pending images for this task, skipping already-attempted ones.""" + limit_clause = f"LIMIT {limit}" if limit else "" + if force_redownload: + query = f""" + SELECT dataset_source_uuid, absolute_url, relative_local_path + FROM `{training_table}` + WHERE (fetch_status = 'pending' OR fetch_status IS NULL) + AND MOD(photo_id, {num_jobs}) = {task_id} + {limit_clause} + """ + else: + query = f""" + SELECT ti.dataset_source_uuid, ti.absolute_url, ti.relative_local_path + FROM `{training_table}` ti + LEFT JOIN `{downloads_table}` d + ON ti.dataset_source_uuid = d.dataset_source_uuid + WHERE (ti.fetch_status = 'pending' OR ti.fetch_status IS NULL) + AND MOD(ti.photo_id, {num_jobs}) = {task_id} + AND d.dataset_source_uuid IS NULL + {limit_clause} + """ + return [dict(r) for r in client.query(query).result()] + + +def pack_chunk_to_sqfs( + staging_dir: Path, chunk_num: int, num_workers: int = 4 +) -> Path | None: + """Pack downloaded images into a per-chunk SquashFS file. + + Uses bucket subdirs (000/, 001/, ...) directly so paths inside the archive + are clean: 000/abc123.jpg rather than staging_dir/000/abc123.jpg. + Raises RuntimeError if mksquashfs fails so the SLURM task is marked failed. + """ + bucket_dirs = sorted(d for d in staging_dir.iterdir() if d.is_dir()) + if not bucket_dirs: + print( + f" No images in staging dir — skipping sqfs pack for chunk {chunk_num}", + flush=True, + ) + return None + + chunk_sqfs = staging_dir / f"chunk_{chunk_num:04d}.sqfs" + cmd = [ + "mksquashfs", + *[str(d) for d in bucket_dirs], + str(chunk_sqfs), + "-noappend", + "-no-xattrs", + "-comp", + "zstd", + "-Xcompression-level", + "3", + "-processors", + str(num_workers), + ] + print( + f" Packing {len(bucket_dirs)} bucket dirs → {chunk_sqfs.name}...", flush=True + ) + result = subprocess.run(cmd, check=False) + if result.returncode != 0: + raise RuntimeError( + f"mksquashfs failed with exit code {result.returncode} for chunk {chunk_num}. " + f"Staging dir preserved for inspection: {staging_dir}" + ) + size_mb = chunk_sqfs.stat().st_size / (1024**2) + print(f" Packed: {chunk_sqfs.name} ({size_mb:.1f} MB)", flush=True) + return chunk_sqfs + + +def clear_staging(staging_dir: Path) -> None: + """Remove all image files from staging dir, preserving .sqfs chunk files.""" + for f in staging_dir.rglob("*"): + if f.is_file() and f.suffix != ".sqfs": + f.unlink() + for d in sorted(staging_dir.rglob("*"), reverse=True): + if d.is_dir(): + try: + d.rmdir() + except OSError: + pass + + +def warn_chunk_accumulation(staging_dir: Path) -> None: + """Warn if too many chunk sqfs files have built up in staging.""" + count = len(list(staging_dir.glob("chunk_*.sqfs"))) + if count >= _CHUNK_ACCUMULATION_WARN: + print( + f" WARNING: {count} chunk sqfs files in {staging_dir} — " + f"pack job may be falling behind or a previous run left chunks behind.", + flush=True, + ) + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--staging-dir", + required=True, + help=( + "Local directory where images are downloaded before packing. " + "Use scratch (e.g. /scratch/$USER/staging), not home — home has " + "a 500k inode quota and each image counts as one inode. " + "chunk_NNNN.sqfs files accumulate here until job_bq_pack_per_task.sh " + "merges them into the final task_N.sqfs." + ), + ) + parser.add_argument( + "--num-jobs", + type=int, + required=True, + help=( + "Total number of parallel download tasks. Images are partitioned " + "by MOD(photo_id, num_jobs) so each task gets a non-overlapping " + "subset. Must match the SLURM --array range: --num-jobs 10 requires " + "--array=0-9 in the job script. Typical value: 10." + ), + ) + parser.add_argument( + "--task-id", + type=int, + required=True, + help=( + "Index of this task (0 to num_jobs-1). In a SLURM array job set " + "this to $SLURM_ARRAY_TASK_ID. This task will download all images " + "where MOD(photo_id, num_jobs) == task_id." + ), + ) + parser.add_argument( + "--num-workers", + type=int, + default=32, + help=( + "Number of parallel download threads per task (default: 32). " + "With 10 tasks running simultaneously this means up to 320 " + "concurrent connections to iNaturalist S3. At scale this caused " + "Errno 16 (too many open sockets) — the retry logic handles it " + "but reducing to 16-24 workers per task lowers the error rate." + ), + ) + parser.add_argument( + "--chunk-size", + type=int, + default=10000, + help=( + "Number of images to download before packing into a sqfs chunk " + "and clearing the staging dir (default: 10000). Lower values " + "reduce peak inode usage in staging but produce more chunk files " + "for the pack job to merge. Each chunk becomes one " + "chunk_NNNN.sqfs file in --staging-dir." + ), + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help=( + "Cap the total number of images queried from BQ. Only for " + "small-scale tests — omit for production runs. " + "Example: --limit 50 --table-prefix test_ for a quick smoke test." + ), + ) + parser.add_argument( + "--force-redownload", + action="store_true", + help=( + "Ignore existing records in training_images_downloads and " + "re-download all images for this task. Use when staging files " + "were deleted after a failed pack job and you need to rebuild " + "the chunks from scratch. Without this flag, already-attempted " + "images are skipped via LEFT JOIN." + ), + ) + parser.add_argument( + "--dataset", + default=BQ_DEFAULT_DATASET, + help=( + f"BigQuery dataset name within the leps-ai project " + f"(default: {BQ_DEFAULT_DATASET}). " + f"Example: --dataset global_all_leps_2605" + ), + ) + parser.add_argument( + "--table-prefix", + default="", + help=( + "BQ table name prefix for testing without touching production. " + "Example: --table-prefix test_ reads from test_training_images " + "and writes to test_training_images_downloads. " + "Create test tables first with create_test_tables.py." + ), + ) + args = parser.parse_args() + + training_table = f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images" + downloads_table = ( + f"{BQ_PROJECT}.{args.dataset}.{args.table_prefix}training_images_downloads" + ) + + client = bigquery.Client(project=BQ_PROJECT) + staging_dir = Path(args.staging_dir) + staging_dir.mkdir(parents=True, exist_ok=True) + + print(f"=== download_images task={args.task_id}/{args.num_jobs} ===", flush=True) + print(f"training table : {training_table}", flush=True) + print(f"downloads table : {downloads_table}", flush=True) + print(f"staging dir : {staging_dir}", flush=True) + print( + f"workers : {args.num_workers} chunk_size={args.chunk_size}", + flush=True, + ) + print(flush=True) + + ensure_downloads_table(client, downloads_table) + warn_chunk_accumulation(staging_dir) + + print( + f"Querying pending rows (force_redownload={args.force_redownload})...", + flush=True, + ) + rows = get_pending_rows( + client, + training_table, + downloads_table, + args.num_jobs, + args.task_id, + limit=args.limit, + force_redownload=args.force_redownload, + ) + print(f"{len(rows):,} pending images to download", flush=True) + + total_downloaded = total_failed = total_corrupted = 0 + + for chunk_start in range(0, len(rows), args.chunk_size): + chunk = rows[chunk_start : chunk_start + args.chunk_size] + chunk_num = chunk_start // args.chunk_size + 1 + total_chunks = (len(rows) + args.chunk_size - 1) // args.chunk_size + print( + f"\n[Task {args.task_id}] Chunk {chunk_num}/{total_chunks} " + f"({len(chunk):,} images)...", + flush=True, + ) + + # Download in parallel + results = [] + t0 = time.perf_counter() + n_ok = n_fail = n_corrupt = 0 + + with ThreadPoolExecutor(max_workers=args.num_workers) as executor: + futures = { + executor.submit(download_and_verify, row, staging_dir): row + for row in chunk + } + for i, future in enumerate(as_completed(futures)): + r = future.result() + results.append(r) + if r["fetch_status"] == "downloaded": + n_ok += 1 + elif r["fetch_status"] == "failed": + n_fail += 1 + elif r["fetch_status"] == "corrupted": + n_corrupt += 1 + if (i + 1) % 1000 == 0: + elapsed = time.perf_counter() - t0 + print( + f" {i+1:,}/{len(chunk):,} " + f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,} " + f"({(i+1)/elapsed:.0f} img/s)", + flush=True, + ) + + elapsed = time.perf_counter() - t0 + total_downloaded += n_ok + total_failed += n_fail + total_corrupted += n_corrupt + print( + f" Chunk done in {elapsed:.0f}s ({len(chunk)/elapsed:.0f} img/s) " + f"downloaded={n_ok:,} failed={n_fail:,} corrupted={n_corrupt:,}", + flush=True, + ) + + # Append to downloads table (batch load — free tier, parallel-safe) + write_results_to_bq(client, results, downloads_table) + print(f" Written to {downloads_table}", flush=True) + + # Inline MERGE into training_images so status is current without a separate job + n_updated = merge_chunk_into_training_images( + client, results, training_table, downloads_table + ) + print(f" Merged {n_updated:,} rows into {training_table}", flush=True) + + # Pack images into chunk sqfs then clear raw files to keep inode usage low + pack_chunk_to_sqfs(staging_dir, chunk_num, num_workers=4) + clear_staging(staging_dir) + warn_chunk_accumulation(staging_dir) + print(f" Staging cleared (chunk sqfs kept)", flush=True) + + print( + f"\n[Task {args.task_id}] Done. " + f"downloaded={total_downloaded:,} failed={total_failed:,} " + f"corrupted={total_corrupted:,}", + flush=True, + ) + + +if __name__ == "__main__": + main() diff --git a/src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py b/src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py new file mode 100644 index 0000000..b0fefa7 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/merge_sqfs_chunks.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +Merge per-chunk SquashFS files produced by download_images.py into a single +task-level SquashFS archive. + +WHAT IT DOES +------------ +download_images.py packs downloaded images into small chunk_NNNN.sqfs files +(one per 10,000 images) to keep inode usage under the cluster quota. This +script merges all those chunks for one task into a single task_N.sqfs archive +that the webdataset build step can use directly. + +The merge is done by streaming: each chunk is mounted via squashfuse, its +files are written into a continuous tar stream on stdout, then unmounted. The +tar stream is consumed by sqfstar on the other end of the pipe to produce the +final sqfs. Only one chunk is mounted at a time — peak inode usage stays tiny +regardless of how many images are inside the chunks. + +INPUT +----- + Directory containing chunk_*.sqfs files, searched + recursively. Pass the staging dir for ONE task + (e.g. bq_download_staging/task_0/) to merge that + task's chunks into a single task_0.sqfs. + +OUTPUT +------ + stdout A streaming tar archive consumed by sqfstar: + python merge_sqfs_chunks.py \\ + | sqfstar -comp zstd -b 131072 task_N.sqfs + stderr Per-chunk progress: "[merge] [N/T] chunk_NNNN.sqfs → K images" + Summary: "[merge] Done. total_images=N errors=M empty_chunks=K" + +DISK SPACE +---------- +Chunks are NEVER deleted by this script. Deletion is the job script's +responsibility, and only after the output sqfs has been verified correct. +Peak disk usage during merge: all chunks + growing output sqfs (~2× space). +At production scale (~108 chunks × 10 GB = 1.1 TB chunks + 1.1 TB output), +this requires ~2.2 TB per task on scratch. The inode cost is negligible: +each chunk file is 1 inode regardless of how many images it contains. + +EXIT CODES +---------- + 0 All chunks streamed successfully. sqfstar output is complete. + 1 One or more chunks had errors (corrupt, empty, or mount failure). + sqfstar may have produced a partial output — the job script must + verify the image count before accepting the result. Chunk files are + always preserved so the merge can be re-submitted after investigation. + +TYPICAL USAGE +------------- + # Merge task 3's chunks into task_3.sqfs: + python merge_sqfs_chunks.py /scratch/$USER/staging/task_3 \\ + | sqfstar -comp zstd -Xcompression-level 3 -b 131072 -no-duplicates \\ + /scratch/$USER/task_3.sqfs + + # Dry-run — list chunks that would be processed: + python merge_sqfs_chunks.py /scratch/$USER/staging/task_3 --dry-run + + # In a SLURM job (see job_bq_pack_per_task.sh): + python merge_sqfs_chunks.py "${TASK_DIR}" \\ + | sqfstar -comp zstd -Xcompression-level 3 -b 131072 -no-duplicates \\ + "${OUTPUT_SQFS}" + PIPE_STATUS=("${PIPESTATUS[@]}") # capture both exits atomically + MERGE_EXIT="${PIPE_STATUS[0]}" + SQFSTAR_EXIT="${PIPE_STATUS[1]}" +""" + +import atexit +import glob +import os +import shutil +import signal +import subprocess +import sys +import tarfile +import tempfile +import time +import argparse +from pathlib import Path + + +# ── Temp dir cleanup ────────────────────────────────────────────────────────── +# Registered with atexit so normal exit cleans up FUSE mounts. +# Also wired to SIGTERM so SLURM wall-time timeout leaves no stale mounts. +# NOTE: SIGKILL cannot be caught — atexit does not run on SIGKILL. This is +# mitigated by SLURM sending SIGTERM 30 s before SIGKILL, giving the handler +# time to unmount and clean up before the hard kill arrives. + +_mount_base: str | None = None + + +def _cleanup_temp_dirs() -> None: + """Unmount any still-mounted squashfuse dirs and remove the temp base dir.""" + global _mount_base + if _mount_base is None or not os.path.exists(_mount_base): + return + for sub in sorted(os.listdir(_mount_base)): + full = os.path.join(_mount_base, sub) + if os.path.isdir(full): + subprocess.run(["fusermount", "-u", full], capture_output=True) + try: + os.rmdir(full) + except OSError: + pass + try: + shutil.rmtree(_mount_base, ignore_errors=True) + except Exception: + pass + _mount_base = None + + +def _sigterm_handler(signum, frame): + """On SLURM timeout (SIGTERM), clean up mounts and exit 1.""" + print("[merge] SIGTERM received — cleaning up mounts and exiting", + file=sys.stderr) + _cleanup_temp_dirs() + sys.exit(1) + + +atexit.register(_cleanup_temp_dirs) +signal.signal(signal.SIGTERM, _sigterm_handler) + + +# ── squashfuse helpers ──────────────────────────────────────────────────────── + +def squashfuse_mount(sqfs_path: str, mount_dir: str, retries: int = 1) -> bool: + """Mount sqfs_path at mount_dir via squashfuse. + + Retries once on failure to handle transient FUSE errors (e.g. stale + device entries). Returns True on success, False if all attempts fail. + """ + for attempt in range(retries + 1): + result = subprocess.run( + ["squashfuse", sqfs_path, mount_dir], + capture_output=True, text=True, + ) + if result.returncode == 0: + return True + if attempt < retries: + print( + f"[merge] squashfuse failed for {Path(sqfs_path).name} " + f"(attempt {attempt + 1}/{retries + 1}), retrying in 5s...", + file=sys.stderr, + ) + time.sleep(5) + + print( + f"[merge] ERROR: squashfuse failed for {sqfs_path}: " + f"{result.stderr.strip()}", + file=sys.stderr, + ) + return False + + +def squashfuse_unmount(mount_dir: str, retries: int = 3) -> bool: + """Unmount mount_dir via fusermount -u, with retries on transient failures. + + Returns True on success. Logs a warning on persistent failure but does not + raise — the SLURM node exit will clean up any remaining mounts. + """ + for attempt in range(retries): + result = subprocess.run( + ["fusermount", "-u", mount_dir], + capture_output=True, text=True, + ) + if result.returncode == 0: + return True + if attempt < retries - 1: + time.sleep(2) + + print( + f"[merge] WARNING: fusermount -u {mount_dir} failed after {retries} " + f"attempts — {result.stderr.strip()}. Mount will be cleaned up on node exit.", + file=sys.stderr, + ) + return False + + +# ── tar streaming ───────────────────────────────────────────────────────────── + +def stream_dir_to_tar(tar: tarfile.TarFile, mount_dir: str) -> int: + """Stream all files from mount_dir into tar, using paths relative to mount_dir. + + Directory entries are included in the tar (required by sqfstar) but are + not counted in the return value. Returns 0 if the mounted sqfs is empty. + """ + count = 0 + mount_path = Path(mount_dir) + for entry in sorted(mount_path.rglob("*")): + arcname = str(entry.relative_to(mount_path)) + if arcname == ".": + continue + info = tar.gettarinfo(str(entry), arcname=arcname) + if entry.is_dir(): + tar.addfile(info) + else: + with open(entry, "rb") as f: + tar.addfile(info, f) + count += 1 + return count + + +# ── main ────────────────────────────────────────────────────────────────────── + +def main(): + global _mount_base + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "staging_dir", + help=( + "Directory containing chunk_*.sqfs files to merge, searched " + "recursively. Typically the staging dir for one download task: " + "bq_download_staging/task_N/. " + "Chunks are processed in sorted order (chunk_0001, chunk_0002, …)." + ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help=( + "List the chunk files that would be merged in processing order, " + "then exit without streaming anything. Useful for verifying which " + "chunks will be included before running the full merge." + ), + ) + args = parser.parse_args() + + chunk_files = sorted( + glob.glob(f"{args.staging_dir}/**/chunk_*.sqfs", recursive=True) + ) + total = len(chunk_files) + + if total == 0: + print( + f"[merge] ERROR: no chunk_*.sqfs files found under {args.staging_dir}", + file=sys.stderr, + ) + sys.exit(1) + + print( + f"[merge] Found {total} chunk sqfs file(s) to merge under {args.staging_dir}", + file=sys.stderr, + ) + + if args.dry_run: + for f in chunk_files: + print(f) + return + + _mount_base = tempfile.mkdtemp(prefix="sqfs_merge_") + total_images = 0 + errors = 0 + empty_chunks = 0 + + try: + with tarfile.open(fileobj=sys.stdout.buffer, mode="w|") as tar: + for i, sqfs_file in enumerate(chunk_files, 1): + mnt = os.path.join(_mount_base, f"mnt_{i}") + os.makedirs(mnt, exist_ok=True) + + if not squashfuse_mount(sqfs_file, mnt): + errors += 1 + try: + os.rmdir(mnt) + except OSError: + pass + continue + + count = stream_dir_to_tar(tar, mnt) + + squashfuse_unmount(mnt) + try: + os.rmdir(mnt) + except OSError: + pass + + if count == 0: + empty_chunks += 1 + print( + f"[merge] WARNING: [{i}/{total}] {Path(sqfs_file).name} — " + f"0 images found after mount. The chunk may be corrupt or " + f"download_images.py may have failed for these images. " + f"Investigate before re-submitting.", + file=sys.stderr, + ) + else: + total_images += count + + print( + f"[merge] [{i}/{total}] {Path(sqfs_file).name} → {count} images", + file=sys.stderr, + ) + + finally: + _cleanup_temp_dirs() + + print( + f"[merge] Done. total_images={total_images} " + f"errors={errors} empty_chunks={empty_chunks}", + file=sys.stderr, + ) + + if errors > 0 or empty_chunks > 0: + sys.exit(1) + + +if __name__ == "__main__": + try: + main() + except BrokenPipeError: + _cleanup_temp_dirs() + print( + "[merge] FATAL: BrokenPipeError — the downstream process (sqfstar) " + "died before the merge completed.\n" + " Common causes:\n" + " - sqfstar OOM killed (exit=137): increase --mem in the job script\n" + " - sqfstar not found (exit=127): check 'module load' and PATH\n" + " - sqfstar crashed on corrupt input: check chunk integrity\n" + " Chunk files are preserved — fix the cause and re-submit the pack job.", + file=sys.stderr, + ) + sys.stderr.flush() + # Redirect stdout to /dev/null before sys.exit so Python's shutdown + # does not try to flush the broken pipe and produce exit code 120 + # instead of the expected 1. + try: + with open(os.devnull, "wb") as devnull: + os.dup2(devnull.fileno(), sys.stdout.fileno()) + except Exception: + pass + sys.exit(1) diff --git a/src/dataset_tools/bq_squashfs/queries/global_max2000img.sql b/src/dataset_tools/bq_squashfs/queries/global_max2000img.sql new file mode 100644 index 0000000..346b598 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/queries/global_max2000img.sql @@ -0,0 +1,31 @@ +-- Global dataset: all downloaded images, capped at 2000 images per species. +-- Species with fewer than 2000 images are included in full. +-- Expected: ~12,313 species, ~2.7M images. +WITH ranked AS ( + SELECT + ti.photo_id, + ti.gbif_id, + ti.relative_local_path, + ti.dataset_source_uuid, + ti.inat_taxon_id, + tx.species_name, + tx.family, + ROW_NUMBER() OVER ( + PARTITION BY tx.species_name + ORDER BY ti.photo_id + ) AS rn + FROM `leps-ai.global_butterflies_2604.training_images` ti + JOIN `leps-ai.global_butterflies_2604.inat_taxa` tx USING (inat_taxon_id) + WHERE ti.fetch_status = 'downloaded' + AND tx.species_name IS NOT NULL +) +SELECT + photo_id, + gbif_id, + relative_local_path, + dataset_source_uuid, + inat_taxon_id, + species_name, + family +FROM ranked +WHERE rn <= 2000 diff --git a/src/dataset_tools/bq_squashfs/queries/global_min25occ.sql b/src/dataset_tools/bq_squashfs/queries/global_min25occ.sql new file mode 100644 index 0000000..a495752 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/queries/global_min25occ.sql @@ -0,0 +1,23 @@ +-- Global dataset: all downloaded images for species with >= 25 distinct GBIF occurrences. +-- Expected: ~4,704 species, ~10.6M images. +WITH qualifying_species AS ( + SELECT tx.species_name + FROM `leps-ai.global_butterflies_2604.training_images` ti + JOIN `leps-ai.global_butterflies_2604.inat_taxa` tx USING (inat_taxon_id) + WHERE ti.fetch_status = 'downloaded' + AND tx.species_name IS NOT NULL + GROUP BY tx.species_name + HAVING COUNT(DISTINCT ti.gbif_id) >= 25 +) +SELECT + ti.photo_id, + ti.gbif_id, + ti.relative_local_path, + ti.dataset_source_uuid, + ti.inat_taxon_id, + tx.species_name, + tx.family +FROM `leps-ai.global_butterflies_2604.training_images` ti +JOIN `leps-ai.global_butterflies_2604.inat_taxa` tx USING (inat_taxon_id) +WHERE ti.fetch_status = 'downloaded' + AND tx.species_name IN (SELECT species_name FROM qualifying_species) diff --git a/src/dataset_tools/bq_squashfs/split.py b/src/dataset_tools/bq_squashfs/split.py new file mode 100644 index 0000000..2938cc4 --- /dev/null +++ b/src/dataset_tools/bq_squashfs/split.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Split a BQ-exported CSV into train/val/test CSV files. + +Uses stratified splitting (sklearn) on a category column so every species +is proportionally represented in all three splits. When --split-by-occurrence +is set, images that share the same gbif_id (same field observation) +are kept together in one split to avoid data leakage. + +Required CSV columns: + photo_id — integer (iNat photo ID, image-level) + relative_local_path — path of the image within its sqfs + gbif_id — GBIF occurrence ID (used by --split-by-occurrence) + +Optional (passed through unchanged): + species_name, inat_taxon_id, family, gbif_id, and any other columns + +Usage: + python split.py \\ + --csv bq_export.csv \\ + --output-dir /project/.../data/splits/ \\ + --category-key species_name \\ + --val-frac 0.1 \\ + --test-frac 0.1 \\ + --split-by-occurrence \\ + --max-instances 1000 \\ + --min-instances 5 \\ + --seed 42 +""" + +import argparse +import math +import random +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +REQUIRED_COLUMNS = {"photo_id", "relative_local_path", "gbif_id"} + + +def _set_random_seeds(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + + +def _create_test_split( + cat_counts: pd.Series, + category_key: str, + metadata: pd.DataFrame, + split_by_occurrence: bool, + split_metadata: pd.DataFrame, + test_frac: float, +) -> pd.DataFrame: + min_instances = math.ceil(1 / test_frac) + test_categories = list(cat_counts[cat_counts >= min_instances].keys()) + selected = split_metadata[split_metadata[category_key].isin(test_categories)] + _, selected, _, _ = train_test_split( + selected, selected[[category_key]], stratify=selected[[category_key]], + test_size=test_frac, + ) + if split_by_occurrence: + return metadata[metadata["gbif_id"].isin( + selected["gbif_id"].unique() + )].copy() + return selected.copy() + + +def _create_val_split( + cat_counts: pd.Series, + category_key: str, + metadata: pd.DataFrame, + split_by_occurrence: bool, + split_metadata: pd.DataFrame, + test_frac: float, + test_set: pd.DataFrame, + val_frac: float, +) -> pd.DataFrame: + min_instances = math.ceil(1 / val_frac) + val_categories = list(cat_counts[cat_counts >= min_instances].keys()) + selected = split_metadata[ + ~split_metadata["relative_local_path"].isin(test_set["relative_local_path"].unique()) + ] + selected = selected[selected[category_key].isin(val_categories)] + adjusted_val_frac = val_frac / (1 - test_frac) + _, selected, _, _ = train_test_split( + selected, selected[[category_key]], stratify=selected[[category_key]], + test_size=adjusted_val_frac, + ) + if split_by_occurrence: + return metadata[metadata["gbif_id"].isin( + selected["gbif_id"].unique() + )].copy() + return selected.copy() + + +def _create_train_split( + metadata: pd.DataFrame, + test_set: pd.DataFrame, + val_set: pd.DataFrame, +) -> pd.DataFrame: + exclude = set(test_set["relative_local_path"].unique()) | set(val_set["relative_local_path"].unique()) + return metadata[~metadata["relative_local_path"].isin(exclude)].copy() + + +def _subsample(dataset: pd.DataFrame, max_instances: int, category_key: str) -> pd.DataFrame: + counts = dataset[category_key].value_counts() + over_limit = list(counts[counts > max_instances].keys()) + under_limit = dataset[~dataset[category_key].isin(over_limit)].copy() + capped = [ + dataset[dataset[category_key] == cat].sample(max_instances) + for cat in over_limit + ] + return pd.concat([under_limit] + capped, ignore_index=True) + + +def split_dataset( + dataset_csv: str, + output_dir: str, + test_frac: float, + val_frac: float, + split_by_occurrence: bool, + category_key: str, + max_instances: int, + min_instances: int, + seed: int, +) -> None: + _set_random_seeds(seed) + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + print(f"Loading {dataset_csv} ...") + metadata = pd.read_csv(dataset_csv, dtype={"photo_id": "Int64"}) + print(f" {len(metadata):,} rows columns: {list(metadata.columns)}") + + missing = REQUIRED_COLUMNS - set(metadata.columns) + if missing: + raise ValueError(f"CSV missing required columns: {missing}") + if category_key not in metadata.columns: + raise ValueError(f"Category key '{category_key}' not found in CSV columns: {list(metadata.columns)}") + + # When splitting by occurrence, deduplicate to one row per occurrence first, + # then pull all images for the selected occurrences back in. + if split_by_occurrence: + split_metadata = metadata.drop_duplicates(subset=["gbif_id"], keep="first").copy() + print(f" Split by occurrence: {len(split_metadata):,} unique occurrences") + else: + split_metadata = metadata.copy() + + cat_counts = split_metadata[category_key].value_counts() + print(f" {len(cat_counts):,} unique categories in '{category_key}'") + + print(f"Creating test split (frac={test_frac}) ...") + test_set = _create_test_split( + cat_counts, category_key, metadata, split_by_occurrence, split_metadata, test_frac + ) + + print(f"Creating val split (frac={val_frac}) ...") + val_set = _create_val_split( + cat_counts, category_key, metadata, split_by_occurrence, split_metadata, + test_frac, test_set, val_frac, + ) + + print("Creating train split ...") + train_set = _create_train_split(metadata, test_set, val_set) + + if max_instances > 0: + print(f"Capping at max_instances={max_instances} per category ...") + train_set = _subsample(train_set, max_instances, category_key) + val_set = _subsample(val_set, int(max_instances * val_frac), category_key) + test_set = _subsample(test_set, int(max_instances * test_frac), category_key) + + if min_instances > 0: + print(f"Filtering to min_instances={min_instances} per category ...") + cat_counts_train = train_set[category_key].value_counts() + keep = list(cat_counts_train[cat_counts_train >= min_instances].keys()) + train_set = train_set[train_set[category_key].isin(keep)].copy() + val_set = val_set[val_set[category_key].isin(keep)].copy() + test_set = test_set[test_set[category_key].isin(keep)].copy() + print(f" Kept {len(keep):,} categories with >= {min_instances} train images") + + splits = {"train": train_set, "val": val_set, "test": test_set} + print("\n=== Split summary ===") + for name, df in splits.items(): + n_cats = df[category_key].nunique() + out = output_path / f"{name}.csv" + df.to_csv(out, index=False) + print(f" {name}: {len(df):,} images {n_cats:,} categories → {out}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--csv", required=True, + help="BQ-exported CSV file") + parser.add_argument("--output-dir", required=True, + help="Directory to write train.csv / val.csv / test.csv") + parser.add_argument("--category-key", default="species_name", + help="Column used for stratified splitting (default: species_name)") + parser.add_argument("--val-frac", type=float, default=0.1, + help="Fraction of data for validation (default: 0.1)") + parser.add_argument("--test-frac", type=float, default=0.1, + help="Fraction of data for test (default: 0.1)") + parser.add_argument("--split-by-occurrence", action="store_true", + help="Keep images from the same gbif_id (GBIF occurrence) in one split") + parser.add_argument("--max-instances", type=int, default=1000, + help="Max images per category per split (0 = no cap, default: 1000)") + parser.add_argument("--min-instances", type=int, default=0, + help="Min train images per category; drop categories below this (default: 0)") + parser.add_argument("--seed", type=int, default=42, + help="Random seed (default: 42)") + args = parser.parse_args() + + split_dataset( + dataset_csv=args.csv, + output_dir=args.output_dir, + test_frac=args.test_frac, + val_frac=args.val_frac, + split_by_occurrence=args.split_by_occurrence, + category_key=args.category_key, + max_instances=args.max_instances, + min_instances=args.min_instances, + seed=args.seed, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6a60031 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,163 @@ +""" +Shared pytest fixtures for all pipeline stage tests. + +Fixtures defined here are available to every test file under tests/ with no imports. + +Fixture summary: + mock_bq_client — create_autospec(bigquery.Client) with sensible defaults + small_csv — CSV file: 5 species × 10 images, photo_ids spanning tasks 0-9 + small_sqfs — real sqfs file with 10 PIL-generated JPEGs (session-scoped) + sample_sql_file — a minimal .sql fixture file for bq_export tests +""" + +import io +import subprocess +from pathlib import Path +from unittest.mock import create_autospec + +import pandas as pd +import pytest +from google.cloud import bigquery +from PIL import Image + +# ── BQ client ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_bq_client(): + """ + Strict BQ client mock — rejects calls to non-existent methods. + + Defaults: + query().result() → [] + load_table_from_dataframe().result() → None + get_table() → MagicMock (table exists) + delete_table() → None + + Override per-test: + mock_bq_client.query.return_value.result.return_value = [row1, row2] + """ + client = create_autospec(bigquery.Client) + client.query.return_value.result.return_value = [] + client.load_table_from_dataframe.return_value.result.return_value = None + client.query.return_value.job_id = "test-job-id" + client.query.return_value.dml_stats.updated_row_count = 0 + return client + + +# ── Small CSV dataset ───────────────────────────────────────────────────────── + + +def _make_small_df() -> pd.DataFrame: + """ + 5 species × 10 images = 50 rows. + + Designed so that: + - photo_ids 0-49 span all 10 tasks (MOD 10) + - 2 images share each gbif_id (occurrence grouping) + - all species have >= 5 images (min_instances=5 tests pass) + - one species has exactly 5 images (min_instances boundary) + """ + species = [ + ("Danaus plexippus", 1001, 101), + ("Vanessa atalanta", 1002, 102), + ("Papilio machaon", 1003, 103), + ("Colias croceus", 1004, 104), + ("Pieris brassicae", 1005, 105), + ] + rows = [] + photo_id = 0 + for sp_name, taxon_id, base_gbif in species: + for i in range(10): + gbif_id = base_gbif + (i // 2) # 2 images share a gbif_id + rows.append( + { + "photo_id": photo_id, + "gbif_id": gbif_id, + "inat_taxon_id": taxon_id, + "species_name": sp_name, + "dataset_source_uuid": f"uuid-{photo_id:04d}", + "relative_local_path": f"{photo_id % 256:03d}/{photo_id:06d}.jpg", + "absolute_url": f"https://inaturalist.org/photos/{photo_id}/original.jpg", + } + ) + photo_id += 1 + return pd.DataFrame(rows) + + +@pytest.fixture +def small_df() -> pd.DataFrame: + """Raw DataFrame — use when you need to manipulate before writing to CSV.""" + return _make_small_df() + + +@pytest.fixture +def small_csv(tmp_path) -> Path: + """CSV file on disk — the direct input format for split.py and bq_export.py.""" + path = tmp_path / "small_dataset.csv" + _make_small_df().to_csv(path, index=False) + return path + + +# ── Small SquashFS ──────────────────────────────────────────────────────────── + + +def _build_small_sqfs(root: Path) -> Path: + """ + Create a small sqfs with 10 PIL-generated JPEGs. + Images are placed in bucket dirs (000/, 001/, ...) matching + the structure download_images.py produces. + """ + img_dir = root / "images" + for i in range(10): + bucket = img_dir / f"{i:03d}" + bucket.mkdir(parents=True, exist_ok=True) + buf = io.BytesIO() + Image.new("RGB", (64, 48), color=(i * 25, 100, 200)).save(buf, format="JPEG") + (bucket / f"{i:06d}.jpg").write_bytes(buf.getvalue()) + + sqfs_path = root / "test_fixture.sqfs" + result = subprocess.run( + [ + "mksquashfs", + str(img_dir), + str(sqfs_path), + "-noappend", + "-no-xattrs", + "-comp", + "zstd", + "-Xcompression-level", + "1", + "-quiet", + ], + capture_output=True, + ) + if result.returncode != 0: + pytest.skip(f"mksquashfs not available: {result.stderr.decode()}") + return sqfs_path + + +@pytest.fixture(scope="session") +def small_sqfs(tmp_path_factory) -> Path: + """ + Real sqfs file with 10 synthetic JPEGs — built once per test session. + Skipped automatically if mksquashfs is not available on the current machine. + """ + root = tmp_path_factory.mktemp("sqfs_fixture") + return _build_small_sqfs(root) + + +# ── SQL file ────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def sample_sql_file(tmp_path) -> Path: + """Minimal SQL query file for bq_export.py tests.""" + path = tmp_path / "test_query.sql" + path.write_text( + "SELECT photo_id, species_name, gbif_id\n" + "FROM `leps-ai.global_butterflies_2604.training_images`\n" + "WHERE fetch_status = 'downloaded'\n" + "LIMIT 10\n" + ) + return path diff --git a/tests/dataset_tools/test_clean.py b/tests/dataset_tools/test_clean.py new file mode 100644 index 0000000..c9ff37d --- /dev/null +++ b/tests/dataset_tools/test_clean.py @@ -0,0 +1,538 @@ +""" +Tests for bigquery_pipeline/clean.py. + +Two layers: + - Mock-based: verify SQL generation and Python logic (no BQ, no network) + - DuckDB integration: execute the real CTE chain against in-memory rows, + assert all 3 duplicate types are resolved correctly + +Run with: pytest tests/dataset_tools/test_clean.py -v +""" +import json +import types +from argparse import Namespace +from unittest.mock import MagicMock, call, patch + +import duckdb +import pandas as pd +import pytest + +import pytest + +import src.dataset_tools.bigquery_pipeline.clean as clean + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def make_count_row(stage, n, taxa=100): + """Minimal BQ row mock matching the count query SELECT (stage, n, taxa).""" + row = MagicMock() + row.stage = stage + row.n = n + row.taxa = taxa + return row + + +SAMPLE_COUNTS = { + "input": {"rows": 37_500, "taxa": 5_000}, + "step1": {"rows": 27_000, "taxa": 5_000}, + "step2": {"rows": 25_700, "taxa": 5_000}, + "step3": {"rows": 21_000, "taxa": 4_700}, + "step4": {"rows": 21_000, "taxa": 4_700}, + "conflicts": {"rows": 2_229, "taxa": 0}, +} + + +def make_bq_client(counts=SAMPLE_COUNTS): + """Mock BQ client whose query().result() returns the given stage counts.""" + rows = [make_count_row(stage, v["rows"], v["taxa"]) for stage, v in counts.items()] + client = MagicMock() + client.query.return_value.result.return_value = rows + return client + + +def make_args(**kwargs): + """Build a minimal Namespace matching clean.py's argparse output.""" + defaults = dict( + dataset="global_all_leps_2605", + project="leps-ai", + table="training_images", + min_images_per_taxon=0, + multi_taxon_strategy="drop", + dry_run=False, + log_file=None, + ) + defaults.update(kwargs) + return Namespace(**defaults) + + +# ── _step3_cte ──────────────────────────────────────────────────────────────── + +class TestStep3Cte: + + def test_drop_strategy_uses_not_in(self): + sql = clean._step3_cte("drop") + assert "NOT IN" in sql + assert "multi_taxon_photos" in sql + + def test_keep_lowest_uses_qualify_row_number(self): + sql = clean._step3_cte("keep-lowest-taxon-id") + assert "QUALIFY" in sql + assert "ROW_NUMBER()" in sql + assert "PARTITION BY photo_id" in sql + assert "ORDER BY inat_taxon_id" in sql + + def test_drop_does_not_contain_qualify(self): + sql = clean._step3_cte("drop") + assert "QUALIFY" not in sql + + def test_keep_lowest_does_not_contain_not_in(self): + sql = clean._step3_cte("keep-lowest-taxon-id") + assert "NOT IN" not in sql + + +# ── _cte_chain SQL content ──────────────────────────────────────────────────── + +class TestCteSqlContent: + + SRC = "leps-ai.global_all_leps_2605.training_images" + + def _chain(self, strategy="drop", min_images=0): + return clean._cte_chain(self.SRC, strategy, min_images) + + def test_src_ref_embedded(self): + assert self.SRC in self._chain() + + def test_step1_exact_dedup_partition(self): + sql = self._chain() + assert "PARTITION BY photo_id, inat_taxon_id, gbif_id" in sql + + def test_step1_uses_qualify_row_number(self): + sql = self._chain() + assert "QUALIFY" in sql + assert "ROW_NUMBER()" in sql + + def test_step2_same_taxon_multi_gbif_partition(self): + sql = self._chain() + assert "PARTITION BY photo_id, inat_taxon_id" in sql + + def test_step2_orders_by_gbif_id(self): + sql = self._chain() + assert "ORDER BY gbif_id" in sql + + def test_multi_taxon_photos_cte_present(self): + sql = self._chain() + assert "multi_taxon_photos" in sql + assert "COUNT(DISTINCT inat_taxon_id) > 1" in sql + + def test_step4_min_images_filter_active(self): + sql = self._chain(min_images=10) + assert "t.cnt >= 10" in sql + + def test_step4_min_images_disabled_when_zero(self): + sql = self._chain(min_images=0) + assert "TRUE" in sql + assert "t.cnt >=" not in sql + + def test_step3_drop_strategy_in_chain(self): + sql = self._chain(strategy="drop") + assert "NOT IN" in sql + + def test_step3_keep_lowest_strategy_in_chain(self): + sql = self._chain(strategy="keep-lowest-taxon-id") + assert "ORDER BY inat_taxon_id" in sql + + +# ── run_count_query ─────────────────────────────────────────────────────────── + +class TestRunCountQuery: + + SRC = "leps-ai.global_all_leps_2605.training_images_test" + + def test_returns_dict_with_all_stages(self): + client = make_bq_client() + result = clean.run_count_query(client, self.SRC, "drop", 0) + assert set(result.keys()) == {"input", "step1", "step2", "step3", "step4", "conflicts"} + + def test_rows_parsed_correctly(self): + client = make_bq_client() + result = clean.run_count_query(client, self.SRC, "drop", 0) + assert result["input"]["rows"] == 37_500 + assert result["step4"]["rows"] == 21_000 + assert result["conflicts"]["rows"] == 2_229 + + def test_taxa_parsed_correctly(self): + client = make_bq_client() + result = clean.run_count_query(client, self.SRC, "drop", 0) + assert result["input"]["taxa"] == 5_000 + assert result["step3"]["taxa"] == 4_700 + + def test_bq_client_query_called_once(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + assert client.query.call_count == 1 + + def test_result_called_on_query(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + client.query.return_value.result.assert_called_once() + + def test_sql_contains_union_all_for_all_stages(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + sql = client.query.call_args[0][0] + assert sql.count("UNION ALL") >= 5 + + def test_sql_contains_src_ref(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + sql = client.query.call_args[0][0] + assert self.SRC in sql + + def test_drop_strategy_sql_has_not_in(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "drop", 0) + sql = client.query.call_args[0][0] + assert "NOT IN" in sql + + def test_keep_lowest_strategy_sql_has_qualify(self): + client = make_bq_client() + clean.run_count_query(client, self.SRC, "keep-lowest-taxon-id", 0) + sql = client.query.call_args[0][0] + assert "ORDER BY inat_taxon_id" in sql + + +# ── run_write_query ─────────────────────────────────────────────────────────── + +class TestRunWriteQuery: + + SRC = "leps-ai.global_all_leps_2605.training_images" + DST = "leps-ai.global_all_leps_2605.training_images" + + def test_issues_create_or_replace_table(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + sql = client.query.call_args[0][0] + assert "CREATE OR REPLACE TABLE" in sql + + def test_dst_ref_in_sql(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + sql = client.query.call_args[0][0] + assert self.DST in sql + + def test_selects_from_step4(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + sql = client.query.call_args[0][0] + assert "SELECT * FROM step4" in sql + + def test_result_called(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + client.query.return_value.result.assert_called_once() + + def test_query_called_exactly_once(self): + client = MagicMock() + clean.run_write_query(client, self.SRC, self.DST, "drop", 0) + assert client.query.call_count == 1 + + +# ── build_log ───────────────────────────────────────────────────────────────── + +class TestBuildLog: + + STARTED = "2026-06-04T01:00:00+00:00" + + def _log(self, counts=SAMPLE_COUNTS, **kwargs): + args = make_args(**kwargs) + return clean.build_log(counts, args, "leps-ai.ds.training_images", args.dry_run, self.STARTED) + + def test_summary_input_rows(self): + assert self._log()["summary"]["input_rows"] == 37_500 + + def test_summary_output_rows(self): + assert self._log()["summary"]["output_rows"] == 21_000 + + def test_summary_total_removed(self): + log = self._log() + assert log["summary"]["total_removed"] == 37_500 - 21_000 + + def test_summary_taxa_before_and_after(self): + log = self._log() + assert log["summary"]["taxa_before"] == 5_000 + assert log["summary"]["taxa_after"] == 4_700 + + def test_step1_removed_arithmetic(self): + log = self._log() + step = log["steps"]["exact_duplicates"] + assert step["removed"] == 37_500 - 27_000 + + def test_step2_removed_arithmetic(self): + log = self._log() + step = log["steps"]["same_taxon_multi_gbif"] + assert step["removed"] == 27_000 - 25_700 + + def test_step3_removed_arithmetic(self): + log = self._log() + step = log["steps"]["multi_taxon_conflicts"] + assert step["removed"] == 25_700 - 21_000 + + def test_step3_conflict_photo_ids(self): + log = self._log() + assert log["steps"]["multi_taxon_conflicts"]["conflict_photo_ids"] == 2_229 + + def test_step4_threshold_recorded(self): + log = self._log(min_images_per_taxon=10) + assert log["steps"]["min_images_filter"]["threshold"] == 10 + + def test_dry_run_flag_true(self): + assert self._log(dry_run=True)["dry_run"] is True + + def test_dry_run_flag_false(self): + assert self._log(dry_run=False)["dry_run"] is False + + def test_started_at_preserved(self): + assert self._log()["started_at"] == self.STARTED + + def test_log_is_json_serialisable(self): + log = self._log() + json.dumps(log) # must not raise + + +# ── print_report ────────────────────────────────────────────────────────────── + +class TestPrintReport: + + def test_no_crash(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + capsys.readouterr() # consume output + + def test_dry_run_tag_present(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=True) + out = capsys.readouterr().out + assert "DRY RUN" in out + + def test_dry_run_tag_absent_when_not_dry(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "DRY RUN" not in out + + def test_total_removed_in_output(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "16,500" in out # 37500 - 21000 + + def test_conflict_photo_ids_in_output(self, capsys): + clean.print_report(SAMPLE_COUNTS, "drop", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "2,229" in out + + def test_strategy_name_in_output(self, capsys): + clean.print_report(SAMPLE_COUNTS, "keep-lowest-taxon-id", 0, "leps-ai.ds.t", dry_run=False) + out = capsys.readouterr().out + assert "keep-lowest-taxon-id" in out + + +# ── main() integration ──────────────────────────────────────────────────────── + +class TestMainIntegration: + + def _run_main(self, argv, counts=SAMPLE_COUNTS): + client = make_bq_client(counts) + with patch("sys.argv", ["clean.py"] + argv), \ + patch("src.dataset_tools.bigquery_pipeline.clean.bigquery.Client", + return_value=client): + clean.main() + return client + + def test_dry_run_calls_count_query_not_write(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--table", "training_images_test", + "--dry-run", + ]) + assert client.query.call_count == 1 + sql = client.query.call_args[0][0] + assert "CREATE OR REPLACE TABLE" not in sql + + def test_actual_run_calls_count_then_write(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--table", "training_images_test", + ]) + assert client.query.call_count == 2 + write_sql = client.query.call_args_list[1][0][0] + assert "CREATE OR REPLACE TABLE" in write_sql + + def test_log_file_written(self, tmp_path): + log_path = tmp_path / "clean.json" + self._run_main([ + "--dataset", "global_all_leps_2605", + "--dry-run", + "--log-file", str(log_path), + ]) + assert log_path.exists() + log = json.loads(log_path.read_text()) + assert "summary" in log + assert "steps" in log + + def test_min_images_arg_propagated_to_sql(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--min-images-per-taxon", "15", + "--dry-run", + ]) + sql = client.query.call_args[0][0] + assert "t.cnt >= 15" in sql + + def test_keep_lowest_strategy_propagated_to_sql(self): + client = self._run_main([ + "--dataset", "global_all_leps_2605", + "--multi-taxon-strategy", "keep-lowest-taxon-id", + "--dry-run", + ]) + sql = client.query.call_args[0][0] + assert "ORDER BY inat_taxon_id" in sql + + +# ── DuckDB integration — real SQL against in-memory rows ───────────────────── + +@pytest.fixture +def dup_df(): + """ + Small DataFrame covering all 3 duplicate types plus clean rows. + + photo_id=1 type 1 — exact duplicate (identical rows) + photo_id=2 type 2 — same photo, two different taxa + gbif_ids + photo_id=3 type 3 — same photo + taxon, two different gbif_ids + photo_id=4 clean + photo_id=5 clean, different taxon + photo_id=6 clean, used for min-images-per-taxon tests (sole photo of taxon 600) + """ + return pd.DataFrame([ + # Type 1: exact dup + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + # Type 2: multi-taxon conflict + dict(photo_id=2, inat_taxon_id=200, gbif_id=2000, dataset_source_uuid="uuid-2"), + dict(photo_id=2, inat_taxon_id=201, gbif_id=2001, dataset_source_uuid="uuid-2"), + # Type 3: same taxon, multi-gbif (keep MIN gbif_id=3000) + dict(photo_id=3, inat_taxon_id=300, gbif_id=3000, dataset_source_uuid="uuid-3"), + dict(photo_id=3, inat_taxon_id=300, gbif_id=3001, dataset_source_uuid="uuid-3"), + # Clean rows + dict(photo_id=4, inat_taxon_id=100, gbif_id=4000, dataset_source_uuid="uuid-4"), + dict(photo_id=5, inat_taxon_id=300, gbif_id=5000, dataset_source_uuid="uuid-5"), + # Lone taxon (taxon 600 has only 1 image — for min-images test) + dict(photo_id=6, inat_taxon_id=600, gbif_id=6000, dataset_source_uuid="uuid-6"), + ]) + + +def run_cte(df, strategy="drop", min_images=0): + """Execute the clean.py CTE chain against a pandas DataFrame via DuckDB.""" + conn = duckdb.connect() + conn.register("src_table", df) + cte = clean._cte_chain("src_table", strategy, min_images).replace("`", "") + return conn.execute(f"WITH {cte} SELECT * FROM step4").df() + + +class TestDuckDBIntegration: + """Execute the real CTE SQL against in-memory rows — no BQ required.""" + + def test_output_row_count(self, dup_df): + result = run_cte(dup_df) + # type1: 1 kept, type2: 0 kept (both dropped), type3: 1 kept, clean: 3 kept + assert len(result) == 5 + + def test_type1_exact_dup_resolved(self, dup_df): + result = run_cte(dup_df) + assert result[result.photo_id == 1].shape[0] == 1 + + def test_type2_multi_taxon_dropped(self, dup_df): + result = run_cte(dup_df) + assert result[result.photo_id == 2].shape[0] == 0 + + def test_type3_keeps_min_gbif_id(self, dup_df): + result = run_cte(dup_df) + row = result[result.photo_id == 3] + assert len(row) == 1 + assert row.iloc[0].gbif_id == 3000 + + def test_clean_rows_preserved(self, dup_df): + result = run_cte(dup_df) + assert result[result.photo_id == 4].shape[0] == 1 + assert result[result.photo_id == 5].shape[0] == 1 + + def test_no_duplicate_photo_ids_remain(self, dup_df): + result = run_cte(dup_df) + assert result.photo_id.nunique() == len(result) + + def test_no_duplicate_uuids_remain(self, dup_df): + result = run_cte(dup_df) + assert result.dataset_source_uuid.nunique() == len(result) + + def test_each_photo_has_one_taxon(self, dup_df): + result = run_cte(dup_df) + taxa_per_photo = result.groupby("photo_id")["inat_taxon_id"].nunique() + assert (taxa_per_photo == 1).all() + + def test_min_images_drops_lone_taxon(self, dup_df): + # taxon 600 has only 1 image — threshold=2 should drop it + result = run_cte(dup_df, min_images=2) + assert result[result.inat_taxon_id == 600].shape[0] == 0 + + def test_min_images_keeps_taxon_above_threshold(self, dup_df): + # taxon 100 has 2 images (photo_id 1 and 4) — survives threshold=2 + result = run_cte(dup_df, min_images=2) + assert result[result.inat_taxon_id == 100].shape[0] == 2 + + def test_keep_lowest_taxon_keeps_one_row_per_photo(self, dup_df): + result = run_cte(dup_df, strategy="keep-lowest-taxon-id") + assert result[result.photo_id == 2].shape[0] == 1 + + def test_keep_lowest_taxon_picks_min_taxon_id(self, dup_df): + result = run_cte(dup_df, strategy="keep-lowest-taxon-id") + row = result[result.photo_id == 2] + assert row.iloc[0].inat_taxon_id == 200 # min(200, 201) + + # ── edge cases ──────────────────────────────────────────────────────────── + + def test_type1_and_type3_overlap_resolved_to_one_row(self): + """Photo with both exact-dup rows AND multiple gbif_ids under same taxon.""" + df = pd.DataFrame([ + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1001, dataset_source_uuid="uuid-1"), + ]) + result = run_cte(df) + assert len(result) == 1 + assert result.iloc[0].gbif_id == 1000 + + def test_type1_with_three_copies_resolves_to_one(self): + """Three identical rows (tripled duplicate) collapse to one.""" + df = pd.DataFrame([ + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=1000, dataset_source_uuid="uuid-1"), + ]) + result = run_cte(df) + assert len(result) == 1 + + def test_type3_with_three_gbif_ids_keeps_min(self): + """Same photo+taxon with 3 different gbif_ids — keeps the lowest.""" + df = pd.DataFrame([ + dict(photo_id=1, inat_taxon_id=100, gbif_id=3000, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=3001, dataset_source_uuid="uuid-1"), + dict(photo_id=1, inat_taxon_id=100, gbif_id=3002, dataset_source_uuid="uuid-1"), + ]) + result = run_cte(df) + assert len(result) == 1 + assert result.iloc[0].gbif_id == 3000 + + def test_empty_table_returns_zero_rows(self): + """Empty input table produces empty output without error.""" + df = pd.DataFrame( + columns=["photo_id", "inat_taxon_id", "gbif_id", "dataset_source_uuid"] + ) + result = run_cte(df) + assert len(result) == 0 diff --git a/tests/dataset_tools/test_download_images.py b/tests/dataset_tools/test_download_images.py new file mode 100644 index 0000000..c7816d4 --- /dev/null +++ b/tests/dataset_tools/test_download_images.py @@ -0,0 +1,955 @@ +""" +Failure scenario tests for download_images.py. + +All tests use mocks — no network, no BQ, no filesystem writes beyond tmp_path. +Run with: pytest tests/dataset_tools/test_download_images.py -v +""" + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, call, patch + +import pytest +import requests + +import src.dataset_tools.bq_squashfs.download_images as di + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def make_response(status_code: int, content: bytes = b"\xff\xd8\xff\xe0JFIF"): + """Build a minimal mock HTTP response.""" + resp = MagicMock() + resp.status_code = status_code + resp.iter_content.return_value = [content] + if status_code >= 400: + resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + f"HTTP {status_code}" + ) + else: + resp.raise_for_status.return_value = None + return resp + + +def make_mock_session(*responses): + """Return a mock session whose .get() yields responses in order.""" + session = MagicMock() + session.get.side_effect = list(responses) + return session + + +# ── _fetch_with_retry ───────────────────────────────────────────────────────── + + +class TestFetchWithRetry: + + def test_success_on_first_attempt(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session(make_response(200, b"IMAGE")) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert dest.read_bytes() == b"IMAGE" + assert session.get.call_count == 1 + + def test_429_retries_then_succeeds(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + make_response(429), + make_response(429), + make_response(200, b"IMAGE"), + ) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 3 + assert dest.read_bytes() == b"IMAGE" + + def test_503_retries_then_succeeds(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + make_response(503), + make_response(200, b"IMAGE"), + ) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 2 + + def test_connection_error_errno16_retries(self, tmp_path): + """Errno 16 (device/resource busy — too many sockets) retries.""" + dest = tmp_path / "img.jpg" + errno16 = requests.exceptions.ConnectionError( + "[Errno 16] Device or resource busy" + ) + session = make_mock_session(errno16, errno16, make_response(200, b"IMAGE")) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 3 + + def test_timeout_retries_then_succeeds(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + requests.exceptions.Timeout(), + make_response(200, b"IMAGE"), + ) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 2 + + def test_404_raises_immediately_no_retry(self, tmp_path): + """404 is not in RETRY_STATUSES — raises without retrying.""" + dest = tmp_path / "img.jpg" + session = make_mock_session(make_response(404)) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ), pytest.raises(requests.exceptions.HTTPError): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == 1 + + def test_exhausted_retries_on_connection_error_raises(self, tmp_path): + """After MAX_RETRIES connection errors, raises the last one.""" + dest = tmp_path / "img.jpg" + err = requests.exceptions.ConnectionError("connection refused") + session = make_mock_session(*([err] * (di._MAX_RETRIES + 1))) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ), pytest.raises(requests.exceptions.ConnectionError): + di._fetch_with_retry("http://example.com/img.jpg", dest) + assert session.get.call_count == di._MAX_RETRIES + 1 + + def test_exhausted_retries_on_timeout_raises(self, tmp_path): + dest = tmp_path / "img.jpg" + session = make_mock_session( + *([requests.exceptions.Timeout()] * (di._MAX_RETRIES + 1)) + ) + with patch.object(di, "_get_session", return_value=session), patch( + "time.sleep" + ), pytest.raises(requests.exceptions.Timeout): + di._fetch_with_retry("http://example.com/img.jpg", dest) + + +# ── download_and_verify ─────────────────────────────────────────────────────── + + +class TestDownloadAndVerify: + + ROW = { + "dataset_source_uuid": "uuid-001", + "absolute_url": "http://example.com/img.jpg", + "relative_local_path": "000/img.jpg", + } + + def test_success(self, tmp_path): + """Valid JPEG → fetch_status=downloaded, dimensions populated.""" + import io + + from PIL import Image + + buf = io.BytesIO() + Image.new("RGB", (64, 48)).save(buf, format="JPEG") + jpeg_bytes = buf.getvalue() + + with patch.object(di, "_fetch_with_retry") as mock_fetch: + + def write_file(url, dest): + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(jpeg_bytes) + + mock_fetch.side_effect = write_file + + result = di.download_and_verify(self.ROW, tmp_path) + + assert result["fetch_status"] == "downloaded" + assert result["image_width"] == 64 + assert result["image_height"] == 48 + assert result["corrupted"] is False + assert result["image_size"] > 0 + + def test_network_failure_recorded_as_failed(self, tmp_path): + """Network error → fetch_status=failed, no image on disk.""" + with patch.object( + di, "_fetch_with_retry", side_effect=Exception("connection refused") + ): + result = di.download_and_verify(self.ROW, tmp_path) + + assert result["fetch_status"] == "failed" + assert result["image_width"] is None + + def test_corrupted_image_recorded_as_corrupted(self, tmp_path): + """Truncated/invalid image bytes → fetch_status=corrupted.""" + with patch.object(di, "_fetch_with_retry") as mock_fetch: + + def write_garbage(url, dest): + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(b"not an image at all") + + mock_fetch.side_effect = write_garbage + + result = di.download_and_verify(self.ROW, tmp_path) + + assert result["fetch_status"] == "corrupted" + assert result["corrupted"] is True + assert result["image_width"] is None + + +# ── write_results_to_bq ─────────────────────────────────────────────────────── + + +class TestWriteResultsToBq: + + RESULTS = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + }, + ] + + def test_success(self): + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + di.write_results_to_bq(client, self.RESULTS, "test_table") + assert client.load_table_from_dataframe.call_count == 1 + + def test_retries_on_transient_error(self): + """First write fails, second succeeds — should not raise.""" + client = MagicMock() + client.load_table_from_dataframe.side_effect = [ + Exception("BQ transient error"), + MagicMock(result=MagicMock(return_value=None)), + ] + with patch("time.sleep"): + di.write_results_to_bq(client, self.RESULTS, "test_table", max_retries=2) + assert client.load_table_from_dataframe.call_count == 2 + + def test_exhausted_retries_raises(self): + """All retries fail → raises the last exception.""" + client = MagicMock() + client.load_table_from_dataframe.side_effect = Exception("BQ down") + with patch("time.sleep"), pytest.raises(Exception, match="BQ down"): + di.write_results_to_bq(client, self.RESULTS, "test_table", max_retries=2) + assert client.load_table_from_dataframe.call_count == 2 + + +# ── pack_chunk_to_sqfs ──────────────────────────────────────────────────────── + + +class TestPackChunkToSqfs: + + def test_success(self, tmp_path): + staging = tmp_path / "staging" + (staging / "000").mkdir(parents=True) + (staging / "000" / "img.jpg").write_bytes(b"x") + fake_sqfs = staging / "chunk_0001.sqfs" + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + # Simulate mksquashfs creating the output file + fake_sqfs.write_bytes(b"sqfs") + result = di.pack_chunk_to_sqfs(staging, chunk_num=1) + + assert result == fake_sqfs + assert mock_run.call_count == 1 + + def test_mksquashfs_failure_raises_runtime_error(self, tmp_path): + """Non-zero mksquashfs exit → RuntimeError so SLURM marks task failed.""" + staging = tmp_path / "staging" + (staging / "000").mkdir(parents=True) + (staging / "000" / "img.jpg").write_bytes(b"x") + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=1) + with pytest.raises(RuntimeError, match="mksquashfs failed"): + di.pack_chunk_to_sqfs(staging, chunk_num=1) + + def test_empty_staging_returns_none(self, tmp_path): + """No images in staging → returns None without calling mksquashfs.""" + staging = tmp_path / "staging" + staging.mkdir() + with patch("subprocess.run") as mock_run: + result = di.pack_chunk_to_sqfs(staging, chunk_num=1) + assert result is None + mock_run.assert_not_called() + + +# ── merge_chunk_into_training_images ───────────────────────────────────────── + + +class TestMergeChunkIntoTrainingImages: + + def _make_client(self, updated_count: int = 1) -> MagicMock: + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = updated_count + client.query.return_value = job + return client + + def test_empty_list_skips_merge(self): + """Completely empty results list → no BQ calls.""" + client = self._make_client() + n = di.merge_chunk_into_training_images(client, [], "t", "d") + assert n == 0 + client.load_table_from_dataframe.assert_not_called() + + def test_downloaded_triggers_merge(self): + """downloaded rows → temp table load + MERGE + cleanup.""" + client = self._make_client(updated_count=1) + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + ] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 1 + assert client.load_table_from_dataframe.call_count == 1 + assert client.query.call_count == 1 + assert client.delete_table.call_count == 1 + + def test_corrupted_triggers_merge(self): + """corrupted rows → merged with fetch_status='corrupted'.""" + client = self._make_client(updated_count=1) + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "corrupted", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": True, + } + ] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 1 + assert client.load_table_from_dataframe.call_count == 1 + + def test_failed_triggers_merge(self): + """failed rows (404/403/exhausted) → merged so fetch_status='failed' in + training_images. Permanent failures are excluded from future re-runs + via WHERE fetch_status='pending' without needing the LEFT JOIN.""" + client = self._make_client(updated_count=1) + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "failed", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + } + ] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 1 + assert client.load_table_from_dataframe.call_count == 1 + assert client.query.call_count == 1 + assert client.delete_table.call_count == 1 + + def test_all_three_statuses_merged_together(self): + """Mixed chunk — downloaded, corrupted, failed — all three trigger one MERGE.""" + client = self._make_client(updated_count=3) + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + }, + { + "dataset_source_uuid": "u2", + "fetch_status": "corrupted", + "image_width": None, + "image_height": None, + "image_size": 500, + "corrupted": True, + }, + { + "dataset_source_uuid": "u3", + "fetch_status": "failed", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + }, + ] + n = di.merge_chunk_into_training_images(client, results, "t", "d") + assert n == 3 + assert ( + client.load_table_from_dataframe.call_count == 1 + ) # one temp table for all 3 + assert client.query.call_count == 1 # one MERGE + assert client.delete_table.call_count == 1 # one cleanup + + def test_failed_rows_included_in_temp_table(self): + """Verify the dataframe passed to BQ includes the failed row.""" + import pandas as pd + + client = self._make_client() + captured_df = {} + + def capture_load(df, table, **kwargs): + captured_df["data"] = df.copy() + return MagicMock(result=MagicMock(return_value=None)) + + client.load_table_from_dataframe.side_effect = capture_load + + results = [ + { + "dataset_source_uuid": "ok", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + }, + { + "dataset_source_uuid": "dead", + "fetch_status": "failed", + "image_width": None, + "image_height": None, + "image_size": None, + "corrupted": None, + }, + ] + di.merge_chunk_into_training_images(client, results, "t", "d") + + df = captured_df["data"] + assert len(df) == 2 # both rows in temp table + statuses = set(df["fetch_status"].tolist()) + assert "downloaded" in statuses + assert "failed" in statuses # failed row present + + def test_temp_table_deleted_even_on_merge_failure(self): + """Temp table must be cleaned up even if the MERGE query fails.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + client.query.side_effect = Exception("MERGE failed") + + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + ] + + with pytest.raises(Exception, match="MERGE failed"): + di.merge_chunk_into_training_images( + client, results, "training_table", "downloads_table" + ) + client.delete_table.assert_called_once() + + +# ── get_pending_rows / MOD split ───────────────────────────────────────────── + + +class TestModSplit: + """Verify that num_jobs/task_id partitioning is correct and complete.""" + + def _make_client( + self, photo_ids: list[int], num_jobs: int, task_id: int + ) -> MagicMock: + """Return a mock BQ client that filters photo_ids by MOD split.""" + matching = [ + { + "dataset_source_uuid": f"uuid-{p}", + "absolute_url": f"http://x/{p}", + "relative_local_path": f"000/{p}.jpg", + } + for p in photo_ids + if p % num_jobs == task_id + ] + client = MagicMock() + client.query.return_value.result.return_value = [ + MagicMock( + **{k: v for k, v in row.items()}, + **{ + "__iter__": lambda self: iter(row.items()), + "keys": lambda self: row.keys(), + }, + ) + for row in matching + ] + # Simpler: just return dicts directly via side_effect + client.query.return_value.result.return_value = matching + return client + + def test_no_overlap_between_tasks(self): + """Each photo_id must appear in exactly one task — no overlaps.""" + photo_ids = list(range(100)) + num_jobs = 10 + all_assigned = [] + + for task_id in range(num_jobs): + assigned = [p for p in photo_ids if p % num_jobs == task_id] + all_assigned.extend(assigned) + + assert len(all_assigned) == len(photo_ids) + assert len(set(all_assigned)) == len(photo_ids) # no duplicates + + def test_all_images_covered_across_tasks(self): + """Union of all task subsets must equal the full image set.""" + photo_ids = list(range(1000)) + num_jobs = 10 + covered = set() + for task_id in range(num_jobs): + subset = {p for p in photo_ids if p % num_jobs == task_id} + assert not subset & covered, f"Overlap at task_id={task_id}" + covered |= subset + assert covered == set(photo_ids) + + def test_task_gets_correct_subset(self): + """Task 3 of 10 should only see photo_ids ending in 3.""" + photo_ids = list(range(50)) + expected = [p for p in photo_ids if p % 10 == 3] # 3, 13, 23, 33, 43 + actual = [p for p in photo_ids if p % 10 == 3] + assert actual == expected + assert all(p % 10 == 3 for p in actual) + + def test_uneven_split_all_images_still_covered(self): + """101 images across 10 tasks — some tasks get 11, others get 10.""" + photo_ids = list(range(101)) + num_jobs = 10 + subsets = [[p for p in photo_ids if p % num_jobs == t] for t in range(num_jobs)] + sizes = [len(s) for s in subsets] + assert sum(sizes) == 101 + assert max(sizes) - min(sizes) <= 1 # balanced within 1 + + def test_single_job_gets_all_images(self): + """num_jobs=1, task_id=0 must return every image.""" + photo_ids = list(range(50)) + assigned = [p for p in photo_ids if p % 1 == 0] + assert assigned == photo_ids + + def test_resumability_skips_already_attempted(self): + """LEFT JOIN should exclude images already in downloads table.""" + # Simulate: 10 images total, 3 already in downloads table + all_uuids = [f"uuid-{i}" for i in range(10)] + attempted = {f"uuid-{i}" for i in range(3)} + pending = [u for u in all_uuids if u not in attempted] + assert len(pending) == 7 + assert not set(pending) & attempted # no overlap with attempted + + def test_force_redownload_ignores_downloads_table(self): + """force_redownload=True should query training_images directly, no LEFT JOIN.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="t", + downloads_table="d", + num_jobs=10, + task_id=3, + force_redownload=True, + ) + + query_sql = client.query.call_args[0][0] + assert "LEFT JOIN" not in query_sql + assert "MOD(photo_id, 10) = 3" in query_sql + + def test_normal_query_has_left_join(self): + """Normal query must LEFT JOIN downloads table to skip attempted images.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="t", + downloads_table="d", + num_jobs=10, + task_id=3, + force_redownload=False, + ) + + query_sql = client.query.call_args[0][0] + assert "LEFT JOIN" in query_sql + assert "MOD(ti.photo_id, 10) = 3" in query_sql + + def test_limit_applied_to_query(self): + """--limit N should add LIMIT clause to the BQ query.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="t", + downloads_table="d", + num_jobs=1, + task_id=0, + limit=50, + ) + + query_sql = client.query.call_args[0][0] + assert "LIMIT 50" in query_sql + + +# ── Multi-task distribution and merge ──────────────────────────────────────── + + +class TestMultiTaskDistributionAndMerge: + """ + Verify correct behaviour when multiple tasks run in parallel: + - work is partitioned correctly across tasks + - tasks don't interfere with each other's queries + - BQ downloads table receives appends from all tasks safely + - training_images MERGE is correct when multiple tasks write concurrently + """ + + PHOTO_IDS = list(range(50)) # simulate 50 images + + def _partition(self, num_jobs: int) -> dict[int, list[int]]: + """Return {task_id: [photo_ids]} for all tasks.""" + return { + t: [p for p in self.PHOTO_IDS if p % num_jobs == t] for t in range(num_jobs) + } + + # ── partitioning ────────────────────────────────────────────────────────── + + def test_two_tasks_partition_all_images(self): + """num_jobs=2: task 0 + task 1 together cover every image exactly once.""" + parts = self._partition(2) + combined = parts[0] + parts[1] + assert sorted(combined) == self.PHOTO_IDS + assert set(parts[0]) & set(parts[1]) == set() # no overlap + + def test_ten_tasks_partition_all_images(self): + """num_jobs=10: all 10 tasks together cover every image exactly once.""" + parts = self._partition(10) + combined = [p for task in parts.values() for p in task] + assert sorted(combined) == self.PHOTO_IDS + for i in range(10): + for j in range(i + 1, 10): + assert set(parts[i]) & set(parts[j]) == set() + + def test_task0_completion_does_not_affect_task1_query(self): + """Task 1's LEFT JOIN only skips images task 1 itself attempted — not task 0's.""" + # task 0 attempted photo_ids 0,2,4... (even); task 1 should still see 1,3,5... + task0_uuids = {f"uuid-{p}" for p in self.PHOTO_IDS if p % 2 == 0} + task1_pending = [p for p in self.PHOTO_IDS if p % 2 == 1] + + # task 1 query: LEFT JOIN filters on task 1's uuids only + # since task 0's uuids (even photo_ids) aren't in task 1's subset, + # they never appear in the LEFT JOIN result anyway + task1_uuids = {f"uuid-{p}" for p in task1_pending} + assert task0_uuids & task1_uuids == set() # completely disjoint + + def test_non_sequential_photo_ids_still_partition_correctly(self): + """Real photo_ids from iNat are large non-sequential ints — MOD still works.""" + real_ids = [ + 487851, + 7047265, + 8233026, + 8427425, + 10239192, + 17327318, + 21463254, + 27648248, + 36757555, + 41676327, + ] + for num_jobs in [2, 5, 10]: + parts = { + t: [p for p in real_ids if p % num_jobs == t] for t in range(num_jobs) + } + combined = [p for task in parts.values() for p in task] + assert sorted(combined) == sorted(real_ids) + + def test_empty_task_handled_gracefully(self): + """A task assigned 0 images should produce 0 downloads cleanly.""" + # with 1 image and num_jobs=2, one task will have 0 images + single_id = [4] # 4 % 2 == 0, so task 1 gets nothing + task0 = [p for p in single_id if p % 2 == 0] + task1 = [p for p in single_id if p % 2 == 1] + assert task0 == [4] + assert task1 == [] + + # ── BQ writes from multiple tasks ──────────────────────────────────────── + + def test_downloads_table_appends_are_independent(self): + """Both tasks append to downloads table — append-only, no conflicts.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + + task0_results = [ + { + "dataset_source_uuid": f"uuid-{p}", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + for p in range(0, 10, 2) + ] # even photo_ids + + task1_results = [ + { + "dataset_source_uuid": f"uuid-{p}", + "fetch_status": "downloaded", + "image_width": 100, + "image_height": 80, + "image_size": 5000, + "corrupted": False, + } + for p in range(1, 10, 2) + ] # odd photo_ids + + # both tasks write to the same table — no conflict because WRITE_APPEND + di.write_results_to_bq(client, task0_results, "downloads_table") + di.write_results_to_bq(client, task1_results, "downloads_table") + + assert client.load_table_from_dataframe.call_count == 2 + # both calls target same table + calls = client.load_table_from_dataframe.call_args_list + assert calls[0][0][1] == "downloads_table" + assert calls[1][0][1] == "downloads_table" + + def test_merge_from_two_tasks_updates_correct_rows(self): + """Each task's MERGE only touches its own rows — no cross-task collision.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + + # task 0 merges even photo_ids + job0 = MagicMock() + job0.dml_stats.updated_row_count = 5 + # task 1 merges odd photo_ids + job1 = MagicMock() + job1.dml_stats.updated_row_count = 5 + client.query.side_effect = [job0, job1] + + task0_results = [ + { + "dataset_source_uuid": f"uuid-{i}", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + for i in range(5) + ] + task1_results = [ + { + "dataset_source_uuid": f"uuid-{i+5}", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + for i in range(5) + ] + + n0 = di.merge_chunk_into_training_images( + client, task0_results, "training_table", "downloads_table" + ) + n1 = di.merge_chunk_into_training_images( + client, task1_results, "training_table", "downloads_table" + ) + + assert n0 == 5 + assert n1 == 5 + assert client.query.call_count == 2 # one MERGE per task + assert client.delete_table.call_count == 2 # temp table cleaned per task + + def test_total_coverage_after_all_tasks_complete(self): + """After all tasks finish, every image should be accounted for.""" + num_jobs = 5 + all_downloaded = set() + + for task_id in range(num_jobs): + task_images = {p for p in self.PHOTO_IDS if p % num_jobs == task_id} + all_downloaded |= task_images + + assert all_downloaded == set(self.PHOTO_IDS) + assert len(all_downloaded) == len(self.PHOTO_IDS) + + +# ── warn_chunk_accumulation ─────────────────────────────────────────────────── + + +class TestWarnChunkAccumulation: + + def test_no_warning_below_threshold(self, tmp_path, capsys): + for i in range(5): + (tmp_path / f"chunk_{i:04d}.sqfs").write_bytes(b"x") + di.warn_chunk_accumulation(tmp_path) + assert "WARNING" not in capsys.readouterr().out + + def test_warning_at_threshold(self, tmp_path, capsys): + for i in range(di._CHUNK_ACCUMULATION_WARN): + (tmp_path / f"chunk_{i:04d}.sqfs").write_bytes(b"x") + di.warn_chunk_accumulation(tmp_path) + assert "WARNING" in capsys.readouterr().out + + +# ── --dataset flag and NULL fetch_status handling ───────────────────────────── + + +class TestDatasetFlagAndNullFetchStatus: + """Tests for --dataset CLI flag and NULL fetch_status support (global_all_leps_2605).""" + + # ── --dataset default ───────────────────────────────────────────────────── + + def test_default_dataset_constant(self): + """BQ_DEFAULT_DATASET must default to global_butterflies_2604 for backwards compat.""" + assert di.BQ_DEFAULT_DATASET == "global_butterflies_2604" + + # ── NULL fetch_status in get_pending_rows ───────────────────────────────── + + def test_normal_query_includes_null_fetch_status(self): + """Normal (LEFT JOIN) query must include OR fetch_status IS NULL to pick up + rows from datasets like global_all_leps_2605 where status starts as NULL.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="proj.ds.training_images", + downloads_table="proj.ds.training_images_downloads", + num_jobs=10, + task_id=0, + force_redownload=False, + ) + + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql + + def test_force_redownload_query_includes_null_fetch_status(self): + """Force-redownload query must also include OR fetch_status IS NULL.""" + client = MagicMock() + client.query.return_value.result.return_value = [] + + di.get_pending_rows( + client, + training_table="proj.ds.training_images", + downloads_table="proj.ds.training_images_downloads", + num_jobs=10, + task_id=0, + force_redownload=True, + ) + + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql + assert "LEFT JOIN" not in sql + + # ── NULL fetch_status in MERGE ──────────────────────────────────────────── + + def test_merge_condition_handles_null_fetch_status(self): + """MERGE SQL must allow updating rows where T.fetch_status IS NULL, + not just 'pending' — needed for global_all_leps_2605 initial state.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 1 + client.query.return_value = job + + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + ] + + di.merge_chunk_into_training_images( + client, + results, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + + sql = client.query.call_args[0][0] + assert "fetch_status IS NULL" in sql + assert "fetch_status = 'pending'" in sql + + # ── tmp_table derived from training_table ───────────────────────────────── + + def test_tmp_table_uses_same_dataset_as_training_table(self): + """Temp table for MERGE must be in the same dataset as training_table, + not hardcoded to global_butterflies_2604.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 1 + client.query.return_value = job + + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + ] + + di.merge_chunk_into_training_images( + client, + results, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + + # The temp table passed to load_table_from_dataframe must be in global_all_leps_2605 + tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] + assert tmp_table_arg.startswith("leps-ai.global_all_leps_2605.") + assert "global_butterflies_2604" not in tmp_table_arg + + def test_tmp_table_not_in_wrong_dataset_when_using_new_dataset(self): + """Regression: old code used BQ_DATASET module constant — ensure it no longer does.""" + client = MagicMock() + client.load_table_from_dataframe.return_value.result.return_value = None + job = MagicMock() + job.dml_stats.updated_row_count = 1 + client.query.return_value = job + + results = [ + { + "dataset_source_uuid": "u1", + "fetch_status": "downloaded", + "image_width": 64, + "image_height": 48, + "image_size": 1000, + "corrupted": False, + } + ] + + di.merge_chunk_into_training_images( + client, + results, + training_table="leps-ai.global_all_leps_2605.training_images", + downloads_table="leps-ai.global_all_leps_2605.training_images_downloads", + ) + + tmp_table_arg = client.load_table_from_dataframe.call_args[0][1] + assert ( + "global_butterflies_2604" not in tmp_table_arg + ) # must not leak old dataset diff --git a/tests/dataset_tools/test_merge_sqfs_chunks.py b/tests/dataset_tools/test_merge_sqfs_chunks.py new file mode 100644 index 0000000..e090533 --- /dev/null +++ b/tests/dataset_tools/test_merge_sqfs_chunks.py @@ -0,0 +1,482 @@ +""" +Tests for merge_sqfs_chunks.py. + +The script streams chunk sqfs files as a single tar to stdout for piping to sqfstar. +All squashfuse calls are mocked — no real sqfs or FUSE needed. + +Run with: pytest tests/dataset_tools/test_merge_sqfs_chunks.py -v +""" + +import io +import os +import sys +import tarfile +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, call, patch + +import pytest + +import src.dataset_tools.bq_squashfs.merge_sqfs_chunks as sct + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def make_chunk_sqfs(staging_dir: Path, chunk_num: int) -> Path: + """Create a dummy chunk_NNNN.sqfs file (content doesn't matter — squashfuse is mocked).""" + p = staging_dir / f"chunk_{chunk_num:04d}.sqfs" + p.write_bytes(b"fake sqfs") + return p + + +def make_mount_dir_with_images(root: Path, filenames: list[str]) -> Path: + """Create a directory tree simulating a squashfuse mount with images.""" + mnt = root / "mnt" + bucket = mnt / "000" + bucket.mkdir(parents=True) + for name in filenames: + (bucket / name).write_bytes(b"JPEG") + return mnt + + +def read_tar_from_bytes(data: bytes) -> list[str]: + """Return list of member names from a tar written to bytes.""" + with tarfile.open(fileobj=io.BytesIO(data), mode="r:") as tf: + return [m.name for m in tf.getmembers()] + + +# ── stream_dir_to_tar ───────────────────────────────────────────────────────── + + +class TestStreamDirToTar: + + def test_files_added_with_relative_paths(self, tmp_path): + """Files are added to the tar with paths relative to the mount dir.""" + mnt = make_mount_dir_with_images(tmp_path, ["abc.jpg", "def.jpg"]) + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:") as tf: + count = sct.stream_dir_to_tar(tf, str(mnt)) + assert count == 2 + members = read_tar_from_bytes(buf.getvalue()) + assert "000/abc.jpg" in members + assert "000/def.jpg" in members + + def test_dirs_included_files_counted(self, tmp_path): + """Directory entries are included; only files contribute to count.""" + mnt = make_mount_dir_with_images(tmp_path, ["img.jpg"]) + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:") as tf: + count = sct.stream_dir_to_tar(tf, str(mnt)) + assert count == 1 # only the .jpg, not the dir + members = read_tar_from_bytes(buf.getvalue()) + assert "000" in members # dir entry + assert "000/img.jpg" in members # file entry + + def test_empty_mount_dir_returns_zero(self, tmp_path): + """Empty mount dir → count=0, no files added.""" + mnt = tmp_path / "mnt" + mnt.mkdir() + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:") as tf: + count = sct.stream_dir_to_tar(tf, str(mnt)) + assert count == 0 + + def test_multiple_bucket_dirs_all_streamed(self, tmp_path): + """Files across multiple bucket dirs are all included.""" + mnt = tmp_path / "mnt" + for bucket in ["000", "001", "002"]: + (mnt / bucket).mkdir(parents=True) + (mnt / bucket / "img.jpg").write_bytes(b"JPEG") + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:") as tf: + count = sct.stream_dir_to_tar(tf, str(mnt)) + assert count == 3 + members = read_tar_from_bytes(buf.getvalue()) + assert "000/img.jpg" in members + assert "001/img.jpg" in members + assert "002/img.jpg" in members + + +# ── squashfuse_mount / unmount ──────────────────────────────────────────────── + + +class TestSquashfuseMount: + + def test_successful_mount_returns_true(self): + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake") + assert result is True + + def test_failed_mount_returns_false(self, capsys): + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=1, stderr="fuse: failed to open /fake.sqfs" + ) + result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake") + assert result is False + assert "ERROR" in capsys.readouterr().err + + def test_unmount_calls_fusermount(self): + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + sct.squashfuse_unmount("/mnt/fake") + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + assert "fusermount" in cmd + assert "-u" in cmd + assert "/mnt/fake" in cmd + + +# ── main: no chunks ─────────────────────────────────────────────────────────── + + +class TestNoChunks: + + def test_empty_staging_dir_exits_with_error(self, tmp_path, capsys): + """No chunk_*.sqfs files → exits with code 1.""" + with pytest.raises(SystemExit) as exc: + with patch("sys.argv", ["merge_sqfs_chunks.py", str(tmp_path)]): + sct.main() + assert exc.value.code == 1 + assert "ERROR" in capsys.readouterr().err + + def test_missing_staging_dir_exits_with_error(self, tmp_path, capsys): + """Non-existent staging dir → exits with code 1.""" + missing = tmp_path / "does_not_exist" + with pytest.raises(SystemExit) as exc: + with patch("sys.argv", ["merge_sqfs_chunks.py", str(missing)]): + sct.main() + assert exc.value.code == 1 + + +# ── main: dry run ───────────────────────────────────────────────────────────── + + +class TestDryRun: + + def test_dry_run_lists_chunks_no_streaming(self, tmp_path, capsys): + """--dry-run prints chunk paths without mounting or streaming.""" + staging = tmp_path / "staging" + staging.mkdir() + c1 = make_chunk_sqfs(staging, 1) + c2 = make_chunk_sqfs(staging, 2) + + with patch( + "sys.argv", ["merge_sqfs_chunks.py", str(staging), "--dry-run"] + ), patch.object(sct, "squashfuse_mount") as mock_mount: + sct.main() + + mock_mount.assert_not_called() # no mounting in dry run + out = capsys.readouterr().out + assert str(c1) in out + assert str(c2) in out + + def test_dry_run_lists_in_sorted_order(self, tmp_path, capsys): + """Chunks are listed in sorted (chunk_0001 before chunk_0002) order.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 3) + make_chunk_sqfs(staging, 1) + make_chunk_sqfs(staging, 2) + + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging), "--dry-run"]): + sct.main() + + lines = [l for l in capsys.readouterr().out.strip().splitlines() if l] + names = [Path(l).name for l in lines] + assert names == ["chunk_0001.sqfs", "chunk_0002.sqfs", "chunk_0003.sqfs"] + + +# ── main: streaming ─────────────────────────────────────────────────────────── + + +class TestStreaming: + + def _run_stream( + self, staging: Path, extra_args: list[str] = [] + ) -> tuple[bytes, str]: + """Run main(), capture stdout bytes and stderr text.""" + stdout_buf = io.BytesIO() + with patch( + "sys.argv", ["merge_sqfs_chunks.py", str(staging)] + extra_args + ), patch("sys.stdout") as mock_stdout: + mock_stdout.buffer = stdout_buf + sct.main() + return stdout_buf.getvalue(), "" + + def test_single_chunk_produces_valid_tar(self, tmp_path): + """One chunk sqfs → tar stream contains all files from that chunk.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + # Create fake mount content + fake_mnt = tmp_path / "fake_mnt" + fake_mnt.mkdir() + (fake_mnt / "000").mkdir() + (fake_mnt / "000" / "img.jpg").write_bytes(b"JPEG") + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=1 + ) as mock_stream: + mock_stdout.buffer = stdout_buf + sct.main() + + mock_stream.assert_called_once() + + def test_two_chunks_single_continuous_stream(self, tmp_path): + """Two chunks produce ONE continuous tar (not two separate tars).""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + make_chunk_sqfs(staging, 2) + + call_count = {"n": 0} + + def fake_stream(tar, mnt_dir): + call_count["n"] += 1 + return 5 + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", side_effect=fake_stream + ): + mock_stdout.buffer = stdout_buf + sct.main() + + # stream_dir_to_tar called twice (one per chunk) into the SAME tar + assert call_count["n"] == 2 + + def test_chunks_always_preserved_after_stream(self, tmp_path): + """Chunks are never deleted by stream_chunks_to_tar — deletion is the + job script's responsibility after verification passes.""" + staging = tmp_path / "staging" + staging.mkdir() + chunk = make_chunk_sqfs(staging, 1) + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=1 + ): + mock_stdout.buffer = stdout_buf + sct.main() + + assert chunk.exists() # always preserved — job script deletes after verify + + +# ── main: error handling ────────────────────────────────────────────────────── + + +class TestErrorHandling: + + def test_failed_mount_skipped_continues_to_next_chunk(self, tmp_path, capsys): + """If one chunk fails to mount, it's skipped and remaining chunks continue.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + make_chunk_sqfs(staging, 2) + + # chunk 1 fails to mount, chunk 2 succeeds + mount_results = [False, True] + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", side_effect=mount_results + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=5 + ): + mock_stdout.buffer = stdout_buf + with pytest.raises(SystemExit) as exc: + sct.main() + + # exits non-zero because there were errors + assert exc.value.code == 1 + err = capsys.readouterr().err + assert "errors=1" in err + + def test_all_mounts_fail_exits_nonzero(self, tmp_path): + """All mounts failing → exit code 1.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + make_chunk_sqfs(staging, 2) + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=False + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ): + mock_stdout.buffer = stdout_buf + with pytest.raises(SystemExit) as exc: + sct.main() + + assert exc.value.code == 1 + + def test_empty_chunk_exits_nonzero(self, tmp_path): + """A chunk that mounts but contains 0 images → exit 1 + WARNING logged.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", return_value=True + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=0 + ): # 0 images + mock_stdout.buffer = stdout_buf + with pytest.raises(SystemExit) as exc: + sct.main() + assert exc.value.code == 1 + + def test_squashfuse_retry_on_transient_failure(self, tmp_path): + """squashfuse failure retried once before giving up.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + fail = MagicMock(returncode=1, stderr="fuse: temporary error") + ok = MagicMock(returncode=0) + + with patch("subprocess.run", side_effect=[fail, ok]) as mock_run, patch( + "time.sleep" + ): + result = sct.squashfuse_mount("/fake.sqfs", "/mnt/fake", retries=1) + + assert result is True + assert mock_run.call_count == 2 # one fail + one retry + + def test_squashfuse_unmount_retries_on_failure(self, capsys): + """fusermount failure retried; logs warning instead of raising.""" + fail = MagicMock(returncode=1, stderr="resource busy") + ok = MagicMock(returncode=0) + + with patch("subprocess.run", side_effect=[fail, ok]), patch("time.sleep"): + result = sct.squashfuse_unmount("/mnt/fake", retries=2) + + assert result is True # succeeded on second attempt + + def test_squashfuse_unmount_warns_on_all_failures(self, capsys): + """All unmount retries exhausted → warning logged, no raise.""" + fail = MagicMock(returncode=1, stderr="resource busy") + with patch("subprocess.run", return_value=fail), patch("time.sleep"): + result = sct.squashfuse_unmount("/mnt/fake", retries=2) + assert result is False + assert "WARNING" in capsys.readouterr().err + + def test_delete_after_stream_flag_removed(self, tmp_path): + """--delete-after-stream was removed — passing it should raise an error.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 1) + + with patch( + "sys.argv", ["merge_sqfs_chunks.py", str(staging), "--delete-after-stream"] + ): + with pytest.raises(SystemExit) as exc: + sct.main() + assert exc.value.code == 2 # argparse unrecognised argument + + def test_chunks_processed_in_sorted_order(self, tmp_path): + """Chunks are processed in sorted order: chunk_0001 before chunk_0002.""" + staging = tmp_path / "staging" + staging.mkdir() + make_chunk_sqfs(staging, 3) + make_chunk_sqfs(staging, 1) + make_chunk_sqfs(staging, 2) + + processed_order = [] + + def fake_mount(sqfs_path, mnt_dir): + processed_order.append(Path(sqfs_path).name) + return True + + stdout_buf = io.BytesIO() + with patch("sys.argv", ["merge_sqfs_chunks.py", str(staging)]), patch( + "sys.stdout" + ) as mock_stdout, patch.object( + sct, "squashfuse_mount", side_effect=fake_mount + ), patch.object( + sct, "squashfuse_unmount" + ), patch( + "tempfile.mkdtemp", return_value=str(tmp_path / "mnt_base") + ), patch( + "os.makedirs" + ), patch( + "os.rmdir" + ), patch.object( + sct, "stream_dir_to_tar", return_value=1 + ): + mock_stdout.buffer = stdout_buf + sct.main() + + assert processed_order == [ + "chunk_0001.sqfs", + "chunk_0002.sqfs", + "chunk_0003.sqfs", + ]