diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd22d44..698fe4c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,8 +21,10 @@ repos: - id: ruff args: [--fix, --exit-non-zero-on-fix] types_or: [python, jupyter] + exclude: "^openpmcvl/granular/" - id: ruff-format types_or: [python, jupyter] + exclude: "^openpmcvl/granular/" - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.11.2 @@ -31,7 +33,7 @@ repos: entry: python3 -m mypy --config-file pyproject.toml language: system types: [python] - exclude: "tests" + exclude: "(^tests/|^openpmcvl/granular/|^openpmcvl/.*/tests/)" - repo: https://github.com/crate-ci/typos rev: v1.24.5 @@ -44,6 +46,7 @@ repos: hooks: - id: nbqa-ruff args: [--fix, --exit-non-zero-on-fix] + exclude: "^openpmcvl/granular/" ci: autofix_commit_msg: | diff --git a/README.md b/README.md index 0b0604e..52efff3 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ [![license](https://img.shields.io/github/license/VectorInstitute/aieng-template.svg)](https://github.com/VectorInstitute/pmc-data-extraction/blob/main/LICENSE.md)
- Open-PMC Pipeline
diff --git a/openpmcvl/granular/models/subfigure_ocr.py b/openpmcvl/granular/models/subfigure_ocr.py index a470b83..cf25111 100644 --- a/openpmcvl/granular/models/subfigure_ocr.py +++ b/openpmcvl/granular/models/subfigure_ocr.py @@ -89,7 +89,7 @@ def detect_subfigure_boundaries(self, figure_path): ## Reformat model outputs to display bounding boxes in our desired format ## List of lists where each inner list is [x1, y1, x2, y2, confidence] - subfigure_info = list() + subfigure_info = [] if outputs[0] is None: return subfigure_info diff --git a/openpmcvl/granular/models/yolo_layer.py b/openpmcvl/granular/models/yolo_layer.py index e7c48b6..2bf325d 100644 --- a/openpmcvl/granular/models/yolo_layer.py +++ b/openpmcvl/granular/models/yolo_layer.py @@ -470,7 +470,7 @@ class (float): class index. for ti in range(n): i, j = truth_i[ti], truth_j[ti] - # find box with iou over 0.7 and under 0.3 (achor point) + # find box with iou over 0.7 and under 0.3 (anchor point) current_truth_box = truth_box[ti : ti + 1] current_pred_boxes = pred[b, :, j, i, :4] pred_ious = bboxes_iou( diff --git a/openpmcvl/granular/pipeline/subcaption.ipynb b/openpmcvl/granular/pipeline/subcaption.ipynb index 0fe63b5..969f6a0 100644 --- a/openpmcvl/granular/pipeline/subcaption.ipynb +++ b/openpmcvl/granular/pipeline/subcaption.ipynb @@ -17,7 +17,7 @@ "\n", "PMC_ROOT = \"set this directory\"\n", "\n", - "# Make sure .env file containt OPENAI_API_KEY\n", + "# Make sure .env file contains OPENAI_API_KEY\n", "load_dotenv()\n", "client = OpenAI()" ] @@ -47,9 +47,9 @@ "PROMPT = \"\"\"\n", "Subfigure labels are letters referring to individual subfigures within a larger figure.\n", "This is a caption: \"%s\"\n", - "Check if the caption contains explicit subfigure label. \n", - "If not, output \"NO\" and end the generation. \n", - "If yes, output \"YES\", then generate the subcaption of the subfigures according to the caption. 
\n", + "Check if the caption contains explicit subfigure label.\n", + "If not, output \"NO\" and end the generation.\n", + "If yes, output \"YES\", then generate the subcaption of the subfigures according to the caption.\n", "The output should use the template:\n", " YES\n", " Subfigure-A: ...\n", @@ -158,7 +158,8 @@ "outputs": [], "source": [ "# Upload the requests file to OpenAI for batch processing\n", - "batch_input_file = client.files.create(file=open(requests_file, \"rb\"), purpose=\"batch\")\n", + "with open(requests_file, \"rb\") as request_file:\n", + " batch_input_file = client.files.create(file=request_file, purpose=\"batch\")\n", "batch_input_file_id = batch_input_file.id\n", "\n", "# Create a batch job to process the requests\n", diff --git a/pyproject.toml b/pyproject.toml index 6954641..19eaa8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,11 @@ nbqa = { version = "^1.7.0", extras = ["toolchain"] } pip-audit = "^2.7.1" [tool.mypy] +exclude = [ + "^working/", + "^openpmcvl/granular/", + "^openpmcvl/.*/tests/", +] ignore_missing_imports = true install_types = true pretty = true @@ -68,6 +73,7 @@ extra_checks = true [tool.ruff] include = ["*.py", "pyproject.toml", "*.ipynb"] +extend-exclude = ["working", "openpmcvl/granular"] line-length = 88 [tool.ruff.format] @@ -110,6 +116,7 @@ ignore = [ # Ignore import violations in all `__init__.py` files. 
[tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402", "F401", "F403", "F811"] +"*.ipynb" = ["D100"] [tool.ruff.lint.pep8-naming] ignore-names = ["X*", "setUp"] @@ -132,6 +139,7 @@ norecursedirs = ["working","openpmcvl"] [tool.typos.default.extend-words] nd = "nd" +thre = "thre" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/working/process/subcaption_and_summary_generation/README.md b/working/process/subcaption_and_summary_generation/README.md new file mode 100644 index 0000000..737d91f --- /dev/null +++ b/working/process/subcaption_and_summary_generation/README.md @@ -0,0 +1,106 @@ +## vLLM Inference Pipeline for Open-PMC-18M Subcaption, Image-context Summary generation, and Modality Labeling + +This repo contains three vLLM inference stages, each launched via a Slurm bash script: + +* **Stage 1 (Subcaption extraction, VLM):** `Qwen2.5-VL-32B-Instruct` generates a *verbatim* subfigure caption from a full figure caption + subfigure image. +* **Stage 2 (Context summary, LLM):** `Qwen2.5-14B-Instruct` generates a focused summary of the context passage relevant to the subcaption. +* **Stage 3 (Modality Labeling, VLM):** `Qwen2.5-VL-32B-Instruct` generates L2 labels, then L1 and L0 labels are inferred from a predefined set based on the generated L2 label. + +### Environment / Versions + +This pipeline was run with: + +* `vllm==0.8.2` +* `xformers==0.0.29.post2` +* `torch==2.6.0` + +### Inputs + +Each script reads its input file and **overwrites** it in place: a CSV for the subcaption and summary stages, a JSONL file for the modality labeling stage (checkpointing is done by writing back to `--data_path`). 
+ +**Required columns** + +* Subcaption stage (`generate_subcaption_vllm.py`): + + * `subfig_path` (path to subfigure image) + * `caption` (full compound figure caption) + * Output column: `sub_caption` +* Summary stage (`generate_summary_vllm.py`): + + * `caption` (full compound figure caption) + * `sub_caption` (subcaption for each subfigure) + * `image_context` (image context related to subfigure) + * Output column: `summary` +* Modality Labeling stage (`generate_modality_labels_vllm.py`): + + * `subfig_path` (path to subfigure image) + * Output columns: `L0_label`, `L1_label`, and `L2_label` + +All stages support **resume** behavior: they skip rows where the output column is already filled (non-empty). + +--- + +## How to Run (Slurm) + +### 1) Subcaption generation (Qwen2.5-VL-32B-Instruct) + +Edit the Slurm script to point to: + +* your python file path +* your CSV path (`--data_path`) +* your model weights path (`--model_dir`) +* any desired batch/tp settings + +Then submit: + +```bash +sbatch run_vllm_subcaption_inference.sh +``` + +Slurm script reference: + +**What it does:** launches `generate_subcaption_vllm.py` with vLLM tensor parallelism and writes `sub_caption` back into the CSV. + +--- + +### 2) Summary generation (Qwen2.5-14B-Instruct) + +After Stage 1 finishes (CSV now has `sub_caption`), edit and submit: + +```bash +sbatch run_vllm_summary_inference.sh +``` + +Slurm script reference: + +**What it does:** runs `generate_summary_vllm.py` and writes `summary` back into the same CSV. + +--- + +### 3) Modality Label generation (Qwen2.5-VL-32B-Instruct) + +Edit the Slurm script to point to: + +* your python file path +* your JSONL path (`--data_path`) +* your model weights path (`--model_dir`) +* any desired batch/tp settings + +Then submit: + +```bash +sbatch run_vllm_modality_inference.sh +``` + +Slurm script reference: + +**What it does:** runs `generate_modality_labels_vllm.py` and writes the `L0_label`, `L1_label`, and `L2_label` columns back into the same JSONL file. 
+ +--- + +## Notes + +* **Paths:** All Slurm scripts include placeholder paths like `/path/to/...` — replace them before submitting. +* **GPU selection:** All scripts set `CUDA_VISIBLE_DEVICES=0,1` and use `--tp_size 2` to shard across 2 GPUs. +* **Checkpointing:** All scripts allow periodic checkpointing. +* **Outputs formatting:** subcaptions are extracted from `...`, and summaries from `...` (regex-based extraction). diff --git a/working/process/subcaption_and_summary_generation/scripts/run_vllm_modality_inference.sh b/working/process/subcaption_and_summary_generation/scripts/run_vllm_modality_inference.sh new file mode 100644 index 0000000..c15338d --- /dev/null +++ b/working/process/subcaption_and_summary_generation/scripts/run_vllm_modality_inference.sh @@ -0,0 +1,33 @@ +#!/bin/bash +#SBATCH --job-name=pmc-subcaption-qwen32b +#SBATCH --partition=a100 +#SBATCH --time=1-00:00:00 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=59G +#SBATCH --output=qwen32b-subcap.%j.out + +# Activate your environment + +echo "Script Run Start!" +nvidia-smi + +#module load cuda-12.4 +module load gcc-12.3.0 +gcc --version + +source ~/envs/exp/bin/activate # Adjust this path to your virtual environment + +echo "Module Loaded and Environment Activated!" 
+ +# Specify which GPUs to use +CUDA_VISIBLE_DEVICES=0,1 \ +python /path/to/generate_modality_labels_vllm.py \ + --data_path /path/to/data \ + --model_dir /path/to/Qwen2.5-VL-32B-Instruct \ + --batch_size 512 \ + --max_new_tokens 128 \ + --tp_size 2 \ + --gpu_mem_util 0.90 \ + --dtype bfloat16 \ No newline at end of file diff --git a/working/process/subcaption_and_summary_generation/scripts/run_vllm_subcaption_inference.sh b/working/process/subcaption_and_summary_generation/scripts/run_vllm_subcaption_inference.sh new file mode 100644 index 0000000..05a6b1b --- /dev/null +++ b/working/process/subcaption_and_summary_generation/scripts/run_vllm_subcaption_inference.sh @@ -0,0 +1,33 @@ +#!/bin/bash +#SBATCH --job-name=pmc-subcaption-qwen32b +#SBATCH --partition=a100 +#SBATCH --time=1-00:00:00 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=59G +#SBATCH --output=qwen32b-subcap.%j.out + +# Activate your environment + +echo "Script Run Start!" +nvidia-smi + +#module load cuda-12.4 +module load gcc-12.3.0 +gcc --version + +source ~/envs/exp/bin/activate # Adjust this path to your virtual environment + +echo "Module Loaded and Environment Activated!" 
+ +# Specify which GPUs to use +CUDA_VISIBLE_DEVICES=0,1 \ +python /path/to/generate_subcaption_vllm.py \ + --data_path /path/to/data.csv \ + --model_dir /path/to/qwen2.5_vl_32B_model_weights_directory \ + --batch_size 32 \ + --max_new_tokens 1024 \ + --tp_size 2 \ + --gpu_mem_util 0.90 \ + --dtype bfloat16 diff --git a/working/process/subcaption_and_summary_generation/scripts/run_vllm_summary_inference.sh b/working/process/subcaption_and_summary_generation/scripts/run_vllm_summary_inference.sh new file mode 100644 index 0000000..b3c6fee --- /dev/null +++ b/working/process/subcaption_and_summary_generation/scripts/run_vllm_summary_inference.sh @@ -0,0 +1,32 @@ +#!/bin/bash +#SBATCH --job-name=summary-pmc +#SBATCH --partition=a40 +#SBATCH --time=24:00:00 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=43G +#SBATCH --output=qwen14b-summary.%j.out + +echo "Script Run Start!" +nvidia-smi + +#module load cuda-12.4 +module load gcc-12.3.0 +gcc --version + +source ~/envs/exp2/bin/activate # Adjust this path to your virtual environment + +echo "Module Loaded and Environment Activated!" 
+ +# Specify which GPUs to use +CUDA_VISIBLE_DEVICES=0,1 \ +python /path/to/generate_summary_vllm.py \ + --data_path /path/to/data.csv \ + --model_dir /path/to/qwen2.5_14b_instruct_model_weights \ + --batch_size 1024 \ + --max_new_tokens 256 \ + --tp_size 2 \ + --gpu_mem_util 0.90 \ + --dtype bfloat16 + diff --git a/working/process/subcaption_and_summary_generation/src/generate_modality_labels_vllm.py b/working/process/subcaption_and_summary_generation/src/generate_modality_labels_vllm.py new file mode 100644 index 0000000..fb54637 --- /dev/null +++ b/working/process/subcaption_and_summary_generation/src/generate_modality_labels_vllm.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +import argparse +import json +import logging +import os +import re +import time +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +from PIL import Image +from qwen_vl_utils import process_vision_info +from tqdm import tqdm +from transformers import AutoProcessor +from vllm import LLM, SamplingParams + + +os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + +PROMPT_MEDICAL_L2_ONLY = ( + "You are an expert in medical image modality classification. 
" + "You are given a single image.\n\n" + "Your task is to assign ONE fine-grained subclass label (L2) to the image.\n\n" + "You must choose exactly ONE L2 label from the following allowed subclasses:\n" + "- Radiology: [Ultrasound, Magnetic Resonance, Computerized Tomography, " + "X-Ray, 2D Radiography, Angiography, PET, Combined modalities in one image]\n" + "- Microscopy: [Light microscopy, Electron microscopy, Transmission microscopy, " + "Fluorescence microscopy]\n" + "- Visible Light Photography: [Dermatology, skin, Endoscopy, Other organs]\n" + "- Other: [Other]\n\n" + 'If the image clearly does NOT belong to any medical modality above, choose "Other".\n' + "If the image appears medical but you are unsure among subclasses, choose the most visually plausible one.\n\n" + "OUTPUT FORMAT:\n" + "Return your answer as a single JSON object with ONLY the L2 field:\n" + "{\n" + ' "L2": ""\n' + "}\n" + "Do not include explanations, reasoning, or any additional text. Only output the JSON object." +) + +# L2 Radiology label sets +L2_RADIOLOGY = { + "ultrasound", + "magnetic resonance", + "computerized tomography", + "x-ray", + "2d radiography", + "angiography", + "pet", + "combined modalities in one image", +} + +# L2 Microscopy label sets +L2_MICROSCOPY = { + "light microscopy", + "electron microscopy", + "transmission microscopy", + "fluorescence microscopy", +} + +# L2 Visible Light Photography label sets +L2_VLP = { + "dermatology", + "skin", + "endoscopy", + "other organs", +} + +# -------------------- Logging -------------------- +logging.basicConfig( + level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s" +) +log = logging.getLogger(__name__) + + +# -------------------- Helpers -------------------- +def _is_empty(x) -> bool: + """ + Check if a response is empty (None, NaN, or empty string). Used to identify unprocessed rows. + + Args: + x: The input to check. + + Returns + ------- + bool: True if x is considered empty, False otherwise. 
+ """ + return x is None or (isinstance(x, float) and pd.isna(x)) or (str(x).strip() == "") + + +def _jsonl_overwrite(_df: pd.DataFrame, _path: str): + """ + Safely overwrite a JSONL file by writing to a temporary file first and then replacing the original. + + Args: + _df (pd.DataFrame): DataFrame to save. + _path (str): Path to the JSONL file. + """ + tmp = _path + ".tmp" + _df.to_json(tmp, lines=True, orient="records") + os.replace(tmp, _path) + + +def _load_rgb(path: str) -> Image.Image: + """ + Load an image from the given path and convert it to RGB mode if necessary. + + Args: + path (str): Path to the image file. + + Returns + ------- + Image.Image: The loaded RGB image. + """ + img = Image.open(path) + if img.mode != "RGB": + img = img.convert("RGB") + return img + + +def build_messages(img: Image.Image, prompt: str) -> List[Dict[str, Any]]: + """ + Build the message structure for the vLLM compatible VLM input. + + Args: + img (Image.Image): The input image. + prompt (str): The text prompt. + + Returns + ------- + List[Dict[str, Any]]: The constructed message list. + """ + return [ + { + "role": "user", + "content": [ + {"type": "image", "image": img}, + {"type": "text", "text": prompt}, + ], + } + ] + + +def extract_l2_label(text: str) -> Optional[str]: + """ + Extract JSON {L2: "..."} from model text output. If parsing fails, return None. + + Args: + text (str): The raw text output from the model. + + Returns + ------- + Optional[str]: The extracted L2 label, or None if parsing fails. 
+ """ + cleaned = text.strip() + + # strip Markdown fences if present + if cleaned.startswith("```"): + cleaned = re.sub(r"^```[a-zA-Z0-9]*\s*", "", cleaned) + cleaned = re.sub(r"```$", "", cleaned).strip() + + # keep only the JSON object part if there's extra text + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end > start: + cleaned = cleaned[start : end + 1] + + try: + obj = json.loads(cleaned) + except Exception: + log.warning("JSON parse failed; storing raw text instead.") + return None + + l2 = str(obj.get("L2") or obj.get("l2") or "").strip() + return l2 + + +def infer_from_l2(l2_raw: str) -> Tuple[str, str, str]: + """ + Infer (L0, L1, L2) from an L2 string. + + L0 ∈ {Medical, Other} + L1 ∈ {Radiology, Microscopy, Visible Light Photography, Other} + L2 = original L2 text (possibly normalized upstream). + + Args: + l2_raw (str): The raw L2 label. + + Returns + ------- + Tuple[str, str, str]: The inferred (L0, L1, L2) labels. + """ + l2 = (l2_raw or "").strip() + l2_norm = l2.lower() + + if l2_norm in L2_RADIOLOGY: + l1 = "Radiology" + l0 = "Medical" + elif l2_norm in L2_MICROSCOPY: + l1 = "Microscopy" + l0 = "Medical" + elif l2_norm in L2_VLP: + l1 = "Visible Light Photography" + l0 = "Medical" + else: + l1 = "Other" + l0 = "Other" + + return l0, l1, l2 + + +# -------------------- Batch processing -------------------- +def process_batched( + df: pd.DataFrame, + llm: LLM, + processor, + out_path: str, + batch_size: int = 8, + max_new_tokens: int = 256, + temperature: float = 0.0, + top_p: float = 1.0, +) -> pd.DataFrame: + """ + Process the DataFrame in batches to generate modality labels using the provided vLLM model. + + Args: + df (pd.DataFrame): Input DataFrame with image paths. + llm (LLM): The vLLM model instance. + processor: The processor for preparing inputs. + out_path (str): Path to save the output CSV. + batch_size (int): Number of samples to process in each batch. 
+ max_new_tokens (int): Maximum number of tokens to generate. + temperature (float): Sampling temperature. + top_p (float): Top-p sampling parameter. + + Returns + ------- + pd.DataFrame: The updated DataFrame with generated modality labels. + """ + image_col = "subfig_path" + label_cols = ["L0_label", "L1_label", "L2_label"] + + # ensure label columns exist, if not exist, create and store empty strings + for col in label_cols: + if col not in df.columns: + df[col] = "" + + # Sampling parameters for generation. + sampling = SamplingParams( + max_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + ) + + t0_all = time.time() + n = len(df) + total_loaded, total_failed, total_done = 0, 0, 0 # counters to track progress + + # rows needing inference = those with empty L0_label + to_infer = sum(_is_empty(x) for x in df.get("L0_label", pd.Series([None] * n))) + pbar = tqdm(total=to_infer, desc="inference", ncols=100, unit="img") # progress bar + json_ok, json_fail = 0, 0 + + log.info(f"Starting batched processing on {n:,} rows (to infer: {to_infer:,})") + + flag = False + + for start in range(0, n, batch_size): + end = min(start + batch_size, n) + + # Select unprocessed rows. This also allows resuming. 
+ idxs = [ + i + for i in range(start, end) + if any(_is_empty(df.at[i, col]) for col in label_cols) + ] + if not idxs: + continue # skip if all rows in this batch are already processed + + t_img0 = time.time() + requests = [] + idx_map = [] + + # Load tqdm for progress tracking + iterable = tqdm( + idxs, + desc=f"[prep] rows {start}-{end - 1}", + leave=False, + ncols=100, + unit="row", + ) + + batch_loaded, batch_failed = 0, 0 + + # Prepare inputs for each row in the batch + for i in iterable: + img_path = str(df.at[i, image_col]) if image_col in df.columns else "" + + try: + pil_img = _load_rgb(img_path) + batch_loaded += 1 + except Exception as e: + batch_failed += 1 + log.warning(f"Failed to load image at row {i}, path={img_path}: {e}") + continue + + messages = build_messages( + pil_img, PROMPT_MEDICAL_L2_ONLY + ) # Build vLLM message structure + image_inputs, _videos = process_vision_info( + messages + ) # Process images for vLLM using qwen_vl_utils's process_vision_info function. 
+ + # Apply chat template to format the prompt correctly + fprompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Final request List for vLLM + requests.append( + { + "prompt": fprompt, + "multi_modal_data": {"image": image_inputs}, + } + ) + idx_map.append(i) + + t_img = time.time() - t_img0 + total_loaded += batch_loaded + total_failed += batch_failed + + log.info( + f"[prep] batch {start}-{end - 1}: loaded={batch_loaded}, " + f"failed={batch_failed}, time={t_img:.2f}s" + ) + + if requests: + t_gen0 = time.time() + responses = llm.generate(requests, sampling) # vLLM generation call + t_gen = time.time() - t_gen0 + + # Process and store outputs + for j, res in enumerate(responses): + raw = res.outputs[0].text if res.outputs else "" + l2_parsed = extract_l2_label(raw) + + if l2_parsed is not None: + l0, l1, l2 = infer_from_l2(l2_parsed) + json_ok += 1 + else: + # if JSON extraction fails, store full raw string in all labels + l0 = l1 = l2 = raw.strip() + json_fail += 1 + + row_idx = idx_map[j] + df.at[row_idx, "L0_label"] = l0 + df.at[row_idx, "L1_label"] = l1 + df.at[row_idx, "L2_label"] = l2 + + pbar.update(1) + + total_done += len(responses) + flag = True + log.info( + f"[gen ] batch {start}-{end - 1}: outputs={len(responses)}, " + f"time={t_gen:.2f}s | json_ok={json_ok}, json_fail={json_fail}" + ) + + # Checkpointing every 1000 batches + if flag and start and ((start // batch_size) % 1000 == 0): + _jsonl_overwrite(df, out_path) + elapsed = time.time() - t0_all + log.info( + f"[ckpt] saved at row {start} → {out_path} | elapsed={elapsed / 60:.1f}m | " + f"done={total_done} | loaded={total_loaded} | failed_img={total_failed}" + ) + flag = False + + # Final save after all batches are processed + _jsonl_overwrite(df, out_path) + pbar.close() + log.info( + f"Total time {time.time() - t0_all:.2f}s | done={total_done} | " + f"loaded_img={total_loaded} | failed_img={total_failed} | " + f"json_ok={json_ok} | 
json_fail={json_fail}. Final saved → {out_path}" + ) + + return df + + +# -------------------- Main -------------------- +def main(): + args = argparse.ArgumentParser() + args.add_argument( + "--data_path", required=True, help="JSONL with column 'subfig_path'." + ) + args.add_argument( + "--model_dir", + default="Qwen/Qwen2.5-VL-32B-Instruct", + help="HF id or local path to Qwen2.5-VL-32B-Instruct", + ) + args.add_argument( + "--batch_size", type=int, default=8, help="Keep modest; VLMs are memory heavy" + ) + args.add_argument("--max_new_tokens", type=int, default=256) + args.add_argument( + "--tp_size", type=int, default=4, help="Tensor parallel degree for 32B" + ) + args.add_argument("--gpu_mem_util", type=float, default=0.90) + args.add_argument( + "--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16"] + ) + args.add_argument("--temperature", type=float, default=0.0) + args.add_argument("--top_p", type=float, default=1.0) + + args_dct = args.parse_args() + + log.info(f"Loading processor and model from {args_dct.model_dir}") + processor = AutoProcessor.from_pretrained(args_dct.model_dir) + llm = LLM( + model=args_dct.model_dir, + tensor_parallel_size=args_dct.tp_size, + gpu_memory_utilization=args_dct.gpu_mem_util, + dtype=None if args_dct.dtype == "auto" else args_dct.dtype, + ) + + log.info(f"Reading data from {args_dct.data_path}") + df = pd.read_json(args_dct.data_path, lines=True) + + # Process in batches and generate modality labels + df = process_batched( + df=df, + llm=llm, + processor=processor, + out_path=args_dct.data_path, + batch_size=args_dct.batch_size, + max_new_tokens=args_dct.max_new_tokens, + temperature=args_dct.temperature, + top_p=args_dct.top_p, + ) + + log.info(f"Completed writing {len(df):,} rows → {args_dct.data_path}") + + +if __name__ == "__main__": + main() diff --git a/working/process/subcaption_and_summary_generation/src/generate_subcaption_vllm.py 
b/working/process/subcaption_and_summary_generation/src/generate_subcaption_vllm.py new file mode 100644 index 0000000..265a075 --- /dev/null +++ b/working/process/subcaption_and_summary_generation/src/generate_subcaption_vllm.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 +import argparse +import os +import re +import time +from typing import Any, Dict, List + +import pandas as pd +from PIL import Image +from qwen_vl_utils import process_vision_info +from tqdm import tqdm +from transformers import AutoProcessor +from vllm import LLM, SamplingParams + + +os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + +prompt = ( + "### INSTRUCTIONS:\n" + "You are an expert medical image captioning assistant. Your task is the following:\n" + "1. You will be provided with a subfigure image that is part of a full image figure and the full figure caption in the input.\n" + "2. The full caption contains descriptions for multiple subfigures (e.g., Subfigure-A, Subfigure-B, etc.).\n" + "3. Your task is to identify the relevant subfigure caption corresponding to the provided subfigure image from the full caption exactly as it appears.\n" + "4. If the subcaption is written jointly for two or more subfigures (e.g., A–C together, (A–C), Axial (A) and coronal (B), etc.), copy that combined description exactly as it appears.\n" + "5. Do NOT rewrite, summarize, or generate new text. Copy the relevant portion exactly as it appears in the full caption.\n" + "6. Here, 'exactly as it appears' mean the extracted caption must match word-for-word, character-for-character with the correct subfigure caption text from the full caption. It must be a verbatim copy, not paraphrased, summarized, or partially copied.\n" + "7. 
If no relevant caption is found in the full caption, output the verbatim copy of the entire full caption.\n" + "### OUTPUT FORMAT:\n" + "\n" + "\n" + "\n\n" + "### INPUT:\n\n" +) + + +def _is_empty(x) -> bool: + """ + Check if a response is empty (None, NaN, or empty string). Used to identify unprocessed rows. + + Args: + x: The input to check. + + Returns + ------- + bool: True if x is considered empty, False otherwise. + """ + return x is None or (isinstance(x, float) and pd.isna(x)) or (str(x).strip() == "") + + +def _csv_overwrite(_df: pd.DataFrame, _path: str): + """ + Safely overwrite a CSV file by writing to a temporary file first and then replacing the original. + + Args: + _df (pd.DataFrame): DataFrame to save. + _path (str): Path to the CSV file. + """ + tmp = _path + ".tmp" + _df.to_csv(tmp, index=False) + os.replace(tmp, _path) + + +def _load_rgb(path: str) -> Image.Image: + """ + Load an image from the given path and convert it to RGB mode if necessary. + + Args: + path (str): Path to the image file. + + Returns + ------- + Image.Image: The loaded RGB image. + """ + img = Image.open(path) + if img.mode != "RGB": + img = img.convert("RGB") + return img + + +def build_messages(img: Image.Image, prompt: str) -> List[Dict[str, Any]]: + """ + Build the message structure for the vLLM compatible VLM input. + + Args: + img (Image.Image): The input image. + prompt (str): The text prompt. + + Returns + ------- + List[Dict[str, Any]]: The constructed message list. + """ + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": img}, + {"type": "text", "text": prompt}, + ], + } + ] + + return messages + + +def process_batched( + df: pd.DataFrame, + llm: LLM, + processor, + out_path: str, + batch_size: int = 8, + max_new_tokens: int = 256, + temperature: float = 0.0, + top_p: float = 1.0, +) -> pd.DataFrame: + """ + Process the DataFrame in batches to generate subcaptions using the provided vLLM model. 
+ + Args: + df (pd.DataFrame): Input DataFrame with image paths and captions. + llm (LLM): The vLLM model instance. + processor: The processor for preparing inputs. + out_path (str): Path to save the output CSV. + batch_size (int): Number of samples to process in each batch. + max_new_tokens (int): Maximum number of tokens to generate. + temperature (float): Sampling temperature. + top_p (float): Top-p sampling parameter. + + Returns + ------- + pd.DataFrame: The updated DataFrame with generated subcaptions. + """ + image_col = "subfig_path" + output_col = "sub_caption" + + # Sampling parameters for generation. Stop at . + sampling = SamplingParams( + max_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + stop=[""], + ) + + pattern = re.compile( + r"\s*(.*?)\s*", re.DOTALL + ) # to extract text within tags + + t0_all = time.time() + n = len(df) + total_loaded, total_failed, total_done = 0, 0, 0 # counters to track progress + + for start in range(0, n, batch_size): + end = min(start + batch_size, n) + + idxs = [ + i for i in range(start, end) if _is_empty(df.at[i, output_col]) + ] # Select unprocessed rows. This also allows resuming. 
+ if not idxs: + continue # skip if all rows in this batch are already processed + + t_img0 = time.time() + requests = [] + idx_map = [] + + # Load tqdm for progress tracking + iterable = tqdm( + idxs, + desc=f"[prep] rows {start}-{end - 1}", + leave=False, + ncols=100, + unit="row", + ) + + batch_loaded, batch_failed = 0, 0 # counters to track batch progress + + # Prepare inputs for each row in the batch + for i in iterable: + img_path = str(df.at[i, image_col]) if image_col in df.columns else "" + text = f"{prompt}\n\n##Full Caption:\n{df.caption.iloc[i]}" # Final text prompt containing full caption + + try: + pil_img = _load_rgb(img_path) + batch_loaded += 1 + except Exception: + batch_failed += 1 + continue + + messages = build_messages(pil_img, text) # Build vLLM message structure + image_inputs, _videos = process_vision_info( + messages + ) # Process images for vLLM using qwen_vl_utils's process_vision_info function. + + # Apply chat template to format the prompt correctly + fprompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Final request List for vLLM + requests.append( + { + "prompt": fprompt, + "multi_modal_data": {"image": image_inputs}, + } + ) + idx_map.append(i) + + t_img = time.time() - t_img0 + total_loaded += batch_loaded + total_failed += batch_failed + + print( + f"[prep] batch {start}-{end - 1}: loaded={batch_loaded}, failed={batch_failed}, time={t_img:.2f}s" + ) + + if requests: + t_gen0 = time.time() + responses = llm.generate(requests, sampling) # vLLM generation call + t_gen = time.time() - t_gen0 + + # Process and store outputs + for j, res in enumerate(responses): + out = res.outputs[0].text if res.outputs else "" + m = pattern.search(out) + df.at[idx_map[j], output_col] = ( + m.group(1).strip() if m else out.replace("", "").strip() + ) # Strip of extra caption tags if regex fails. 
def main():
    """
    Command-line entry point for subcaption generation.

    Parses the CLI options, reads the CSV given by ``--data_path``, builds the
    processor and the vLLM engine from ``--model_dir``, runs ``process_batched``
    and reports how many rows were written. The output path handed to
    ``process_batched`` is the input path itself, so the CSV is updated in place.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_path",
        required=True,
        help="CSV with at least two columns: image path + full caption.",
    )
    parser.add_argument(
        "--model_dir",
        default="Qwen/Qwen2.5-VL-32B-Instruct",
        help="HF id or local path to Qwen2.5-VL-32B-Instruct",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Keep modest; VLMs are memory heavy"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=256, help="Max tokens to generate"
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=4,
        help="Tensor parallel degree for 32B (e.g., 4×A100-80GB)",
    )
    parser.add_argument(
        "--gpu_mem_util",
        type=float,
        default=0.90,
        help="GPU memory utilization for vLLM",
    )
    parser.add_argument(
        "--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16"]
    )
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)

    cli = parser.parse_args()

    # Read the CSV before building the engine: a wrong --data_path should fail
    # immediately, not after the model weights have been loaded.
    df = pd.read_csv(cli.data_path)

    processor = AutoProcessor.from_pretrained(cli.model_dir)
    llm = LLM(
        model=cli.model_dir,
        tensor_parallel_size=cli.tp_size,
        gpu_memory_utilization=cli.gpu_mem_util,
        # vLLM understands the string "auto" itself; None is not a valid dtype
        # and made `--dtype auto` fail, so the CLI value is forwarded verbatim.
        dtype=cli.dtype,
    )

    # Process in batches and generate subcaptions (results are written back to
    # the same CSV that was read).
    df = process_batched(
        df=df,
        llm=llm,
        processor=processor,
        out_path=cli.data_path,
        batch_size=cli.batch_size,
        max_new_tokens=cli.max_new_tokens,
        temperature=cli.temperature,
        top_p=cli.top_p,
    )

    print(f"Completed writing {len(df)} rows → {cli.data_path}")


if __name__ == "__main__":
    main()
def build_chat(tokenizer, user_prompt: str, max_length: int = 32700):
    """
    Build the chat-formatted prompt string for vLLM from a user prompt.

    The prompt is wrapped in a fixed system message plus the user message and
    rendered with the tokenizer's chat template (generation prompt included).
    When the rendered text is longer than ``max_length`` tokens, the *user
    content* is shortened from its end and the template is rendered again, so
    the closing template markers and the generation prompt stay in place.
    Cutting the rendered string itself would remove exactly those trailing
    markers.

    Args:
        tokenizer: Object providing ``apply_chat_template``, ``encode`` and ``decode``.
        user_prompt (str): The user prompt string.
        max_length (int): Maximum token length for the rendered input.

    Returns
    -------
        str: The rendered chat prompt, at most ``max_length`` tokens long
        according to ``tokenizer.encode``.
    """

    def _render(content: str) -> str:
        # Same two-message conversation for every call; only the user text varies.
        messages = [
            {
                "role": "system",
                "content": "You are a biomedical image context summary generator.",
            },
            {"role": "user", "content": content},
        ]
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def _n_tokens(text: str) -> int:
        return len(tokenizer.encode(text, add_special_tokens=False))

    rendered = _render(user_prompt)
    overflow = _n_tokens(rendered) - max_length
    if overflow <= 0:
        return rendered

    # Drop the surplus from the tail of the user content. Re-tokenising the
    # shortened text can shift the count slightly, hence the loop; `keep`
    # strictly decreases, so it always terminates.
    user_ids = tokenizer.encode(user_prompt, add_special_tokens=False)
    keep = len(user_ids)
    while overflow > 0 and keep > 0:
        keep = max(keep - overflow, 0)
        shortened = tokenizer.decode(user_ids[:keep], skip_special_tokens=False)
        rendered = _render(shortened)
        overflow = _n_tokens(rendered) - max_length

    if overflow > 0:
        # The template alone is longer than max_length: the limit can only be
        # honoured by a hard cut of the rendered text.
        token_ids = tokenizer.encode(rendered, add_special_tokens=False)
        rendered = tokenizer.decode(token_ids[:max_length], skip_special_tokens=False)

    return rendered


def _is_empty(x) -> bool:
    """
    Tell whether a cell holds no response yet (None, float NaN, or blank text).

    Used to pick the rows that still have to be processed.

    Args:
        x: The value to check.

    Returns
    -------
        bool: True if ``x`` is None, a float NaN, or a string form that is
        empty after stripping whitespace; False otherwise.
    """
    if x is None:
        return True
    if isinstance(x, float) and pd.isna(x):
        return True
    return str(x).strip() == ""
def _csv_overwrite(_df: pd.DataFrame, _path: str):
    """
    Overwrite a CSV file by writing a temporary file and then replacing the original.

    The data is first written to ``<_path>.tmp``; ``os.replace`` then swaps it
    in, so the target is never left half-written by this function.

    Args:
        _df (pd.DataFrame): DataFrame to save.
        _path (str): Path to the CSV file.
    """
    tmp = _path + ".tmp"
    _df.to_csv(tmp, index=False)
    os.replace(tmp, _path)


def process_data_batched_vllm(
    df: pd.DataFrame,
    llm: LLM,
    tokenizer,
    out_path: str,
    batch_size: int = 16,
    max_new_tokens: int = 192,
) -> None:
    """
    Process the DataFrame in batches using vLLM to generate summaries.

    Rows whose 'summary' cell is empty are prompted with the full caption, the
    subcaption and the context passage; the generated summary is stored back
    into ``df`` in place. Rows that already have a summary are skipped, which
    makes an interrupted run resumable. The CSV is checkpointed every 10
    batches (when something changed) and saved once more at the end.

    Args:
        df (pd.DataFrame): Input DataFrame with columns 'caption', 'sub_caption',
            and 'image_context'. A 'summary' column is created when missing.
        llm (LLM): The vLLM model instance.
        tokenizer: The tokenizer for building prompts.
        out_path (str): Path to save the output CSV.
        batch_size (int): Number of samples to process in each batch.
        max_new_tokens (int): Maximum number of new tokens to generate for each summary.

    Raises
    ------
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # A first run has no 'summary' column yet; create it instead of failing.
    if "summary" not in df.columns:
        df["summary"] = None
    # An all-NaN column is read as float; switch to object so text can be stored.
    df["summary"] = df["summary"].astype(object)
    summary_pos = df.columns.get_loc("summary")

    # Both tags are required: without the opening tag the capture would start
    # at the beginning of the output and keep a literal "<summary>" prefix.
    pattern = re.compile(r"<summary>\s*(.*?)\s*</summary>", re.DOTALL)

    sampling_params = SamplingParams(
        max_tokens=max_new_tokens, temperature=0.0, top_p=1.0
    )
    t0_all = time.time()
    unsaved = False  # True when df holds results not yet written to disk

    # Batch processing loop; all row access is positional so it does not
    # depend on the DataFrame's index labels.
    for start in range(0, len(df), batch_size):
        end = min(start + batch_size, len(df))
        # Select unprocessed rows. This also allows resuming.
        idxs = [i for i in range(start, end) if _is_empty(df.iat[i, summary_pos])]

        if idxs:
            # Prompt construction with full caption, subcaption, and context passage
            batch_prompts = [
                build_chat(
                    tokenizer,
                    prompt
                    + f"Full Caption:\n{df.caption.iloc[i]}\n\n"
                    + f"Subcaption:\n{df.sub_caption.iloc[i]}\n\n"
                    + f"Context Passage:\n{df.image_context.iloc[i]}",
                )
                for i in idxs
            ]

            outs = llm.generate(batch_prompts, sampling_params)  # vLLM generation call

            for i, out in zip(idxs, outs):
                text = out.outputs[0].text if out.outputs else ""
                m = pattern.search(text)
                if m:
                    summary = m.group(1).strip()
                else:
                    # No complete tag pair (e.g. generation was cut off): keep
                    # the text but drop any stray tag.
                    summary = (
                        text.replace("<summary>", "").replace("</summary>", "").strip()
                    )
                df.iat[i, summary_pos] = summary
            unsaved = True

        # Checkpoint every 10 batches (i.e. every 10 * batch_size rows), but
        # only when there is something new to write.
        if unsaved and start and (start % (10 * batch_size) == 0):
            _csv_overwrite(df, out_path)
            unsaved = False
            print(f"[ckpt] Saved at row {start} → {out_path}")

    # Final save after all batches are processed
    _csv_overwrite(df, out_path)
    print(f"Total time {time.time() - t0_all:.2f}s. Final saved → {out_path}")


def main():
    """
    Command-line entry point for summary generation.

    Parses the CLI options, reads the CSV given by ``--data_path``, loads the
    tokenizer and the vLLM engine from ``--model_dir`` and runs
    ``process_data_batched_vllm``, which writes the results back to the same CSV.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--data_path", required=True, help="CSV path to data")
    parser.add_argument(
        "--model_dir",
        required=True,
        help="Path or HF id for the model (e.g., /model-weights/Qwen2.5-7B-Instruct)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="vLLM micro-batch size per generate() call",
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=192, help="Max new tokens to generate"
    )
    parser.add_argument(
        "--tp_size", type=int, default=1, help="Tensor parallel size for vLLM"
    )
    parser.add_argument(
        "--gpu_mem_util",
        type=float,
        default=0.90,
        help="GPU memory utilization fraction for vLLM",
    )
    parser.add_argument(
        "--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16"]
    )

    args = parser.parse_args()

    data_path = args.data_path
    model_dir = args.model_dir

    # Load the input CSV first so a wrong path fails before the model is loaded.
    df = pd.read_csv(data_path)

    # Tokenizer used only to template chat → plain prompt string
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

    # Init vLLM engine
    # Notes:
    # - tensor_parallel_size lets you span multiple GPUs if available.
    # - gpu_memory_utilization tunes how full vLLM packs the GPU.
    # - max_model_len can be set if you have very long contexts (defaults are fine for most).
    # - dtype is forwarded verbatim: vLLM accepts the string "auto" itself,
    #   whereas None is not a valid dtype and made `--dtype auto` fail.
    llm = LLM(
        model=model_dir,
        tensor_parallel_size=args.tp_size,
        gpu_memory_utilization=args.gpu_mem_util,
        dtype=args.dtype,
    )

    process_data_batched_vllm(
        df=df,
        llm=llm,
        tokenizer=tokenizer,
        out_path=data_path,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
    )

    print(f"Completed writing {len(df)} entries to: {data_path}")


if __name__ == "__main__":
    main()