diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bd22d44..698fe4c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,8 +21,10 @@ repos:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
types_or: [python, jupyter]
+ exclude: "^openpmcvl/granular/"
- id: ruff-format
types_or: [python, jupyter]
+ exclude: "^openpmcvl/granular/"
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
@@ -31,7 +33,7 @@ repos:
entry: python3 -m mypy --config-file pyproject.toml
language: system
types: [python]
- exclude: "tests"
+ exclude: "(^tests/|^openpmcvl/granular/|^openpmcvl/.*/tests/)"
- repo: https://github.com/crate-ci/typos
rev: v1.24.5
@@ -44,6 +46,7 @@ repos:
hooks:
- id: nbqa-ruff
args: [--fix, --exit-non-zero-on-fix]
+ exclude: "^openpmcvl/granular/"
ci:
autofix_commit_msg: |
diff --git a/README.md b/README.md
index 0b0604e..52efff3 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,8 @@
[](https://github.com/VectorInstitute/pmc-data-extraction/blob/main/LICENSE.md)
-
diff --git a/openpmcvl/granular/models/subfigure_ocr.py b/openpmcvl/granular/models/subfigure_ocr.py
index a470b83..cf25111 100644
--- a/openpmcvl/granular/models/subfigure_ocr.py
+++ b/openpmcvl/granular/models/subfigure_ocr.py
@@ -89,7 +89,7 @@ def detect_subfigure_boundaries(self, figure_path):
## Reformat model outputs to display bounding boxes in our desired format
## List of lists where each inner list is [x1, y1, x2, y2, confidence]
- subfigure_info = list()
+ subfigure_info = []
if outputs[0] is None:
return subfigure_info
diff --git a/openpmcvl/granular/models/yolo_layer.py b/openpmcvl/granular/models/yolo_layer.py
index e7c48b6..2bf325d 100644
--- a/openpmcvl/granular/models/yolo_layer.py
+++ b/openpmcvl/granular/models/yolo_layer.py
@@ -470,7 +470,7 @@ class (float): class index.
for ti in range(n):
i, j = truth_i[ti], truth_j[ti]
- # find box with iou over 0.7 and under 0.3 (achor point)
+ # find box with iou over 0.7 and under 0.3 (anchor point)
current_truth_box = truth_box[ti : ti + 1]
current_pred_boxes = pred[b, :, j, i, :4]
pred_ious = bboxes_iou(
diff --git a/openpmcvl/granular/pipeline/subcaption.ipynb b/openpmcvl/granular/pipeline/subcaption.ipynb
index 0fe63b5..969f6a0 100644
--- a/openpmcvl/granular/pipeline/subcaption.ipynb
+++ b/openpmcvl/granular/pipeline/subcaption.ipynb
@@ -17,7 +17,7 @@
"\n",
"PMC_ROOT = \"set this directory\"\n",
"\n",
- "# Make sure .env file containt OPENAI_API_KEY\n",
+ "# Make sure .env file contains OPENAI_API_KEY\n",
"load_dotenv()\n",
"client = OpenAI()"
]
@@ -47,9 +47,9 @@
"PROMPT = \"\"\"\n",
"Subfigure labels are letters referring to individual subfigures within a larger figure.\n",
"This is a caption: \"%s\"\n",
- "Check if the caption contains explicit subfigure label. \n",
- "If not, output \"NO\" and end the generation. \n",
- "If yes, output \"YES\", then generate the subcaption of the subfigures according to the caption. \n",
+ "Check if the caption contains explicit subfigure label.\n",
+ "If not, output \"NO\" and end the generation.\n",
+ "If yes, output \"YES\", then generate the subcaption of the subfigures according to the caption.\n",
"The output should use the template:\n",
" YES\n",
" Subfigure-A: ...\n",
@@ -158,7 +158,8 @@
"outputs": [],
"source": [
"# Upload the requests file to OpenAI for batch processing\n",
- "batch_input_file = client.files.create(file=open(requests_file, \"rb\"), purpose=\"batch\")\n",
+ "with open(requests_file, \"rb\") as request_file:\n",
+ " batch_input_file = client.files.create(file=request_file, purpose=\"batch\")\n",
"batch_input_file_id = batch_input_file.id\n",
"\n",
"# Create a batch job to process the requests\n",
diff --git a/pyproject.toml b/pyproject.toml
index 6954641..19eaa8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,11 @@ nbqa = { version = "^1.7.0", extras = ["toolchain"] }
pip-audit = "^2.7.1"
[tool.mypy]
+exclude = [
+ "^working/",
+ "^openpmcvl/granular/",
+ "^openpmcvl/.*/tests/",
+]
ignore_missing_imports = true
install_types = true
pretty = true
@@ -68,6 +73,7 @@ extra_checks = true
[tool.ruff]
include = ["*.py", "pyproject.toml", "*.ipynb"]
+extend-exclude = ["working", "openpmcvl/granular"]
line-length = 88
[tool.ruff.format]
@@ -110,6 +116,7 @@ ignore = [
# Ignore import violations in all `__init__.py` files.
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
+"*.ipynb" = ["D100"]
[tool.ruff.lint.pep8-naming]
ignore-names = ["X*", "setUp"]
@@ -132,6 +139,7 @@ norecursedirs = ["working","openpmcvl"]
[tool.typos.default.extend-words]
nd = "nd"
+thre = "thre"
[build-system]
requires = ["poetry-core>=1.0.0"]
diff --git a/working/process/subcaption_and_summary_generation/README.md b/working/process/subcaption_and_summary_generation/README.md
new file mode 100644
index 0000000..737d91f
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/README.md
@@ -0,0 +1,106 @@
+## vLLM Inference Pipeline for Open-PMC-18M Subcaption, Image-context Summary generation, and Modality Labeling
+
+This repo contains three vLLM inference stages, each launched via a Slurm bash script:
+
+* **Stage 1 (Subcaption extraction, VLM):** `Qwen2.5-VL-32B-Instruct` generates a *verbatim* subfigure caption from a full figure caption + subfigure image.
+* **Stage 2 (Context summary, LLM):** `Qwen2.5-14B-Instruct` generates a focused summary of the context passage relevant to the subcaption.
+* **Stage 3 (Modality Labeling, VLM):** `Qwen2.5-VL-32B-Instruct` generates L2 labels, then L1 and L0 labels are inferred from a predefined set based on the generated L2 label.
+
+### Environment / Versions
+
+This pipeline was run with:
+
+* `vllm==0.8.2`
+* `xformers==0.0.29.post2`
+* `torch==2.6.0`
+
+### Inputs
+
+All scripts read and **overwrite** the same CSV or Jsonl (checkpointing is done by writing back to `--data_path`).
+
+**Required columns**
+
+* Subcaption stage (`generate_subcaption_vllm.py`):
+
+ * `subfig_path` (path to subfigure image)
+ * `caption` (full compound figure caption)
+ * Output column: `sub_caption`
+* Summary stage (`generate_summary_vllm.py`):
+
+ * `caption` (full compound figure caption)
+ * `sub_caption` (subcaption for each subfigure)
+ * `image_context` (image context related to subfigure)
+ * Output column: `summary`
+* Modality Labeling stage (`generate_modality_labels_vllm.py`):
+
+ * `subfig_path` (path to subfigure image)
+ * Output column: `L0_label`, `L1_label`, and `L2_label`
+
+All stages support **resume** behavior: they skip rows where the output column is already filled (non-empty).
+
+---
+
+## How to Run (Slurm)
+
+### 1) Subcaption generation (Qwen2.5-VL-32B-Instruct)
+
+Edit the Slurm script to point to:
+
+* your python file path
+* your CSV path (`--data_path`)
+* your model weights path (`--model_dir`)
+* any desired batch/tp settings
+
+Then submit:
+
+```bash
+sbatch run_vllm_subcaption_inference.sh
+```
+
+Slurm script reference:
+
+**What it does:** launches `generate_subcaption_vllm.py` with vLLM tensor parallelism and writes `sub_caption` back into the CSV.
+
+---
+
+### 2) Summary generation (Qwen2.5-14B-Instruct)
+
+After Stage 1 finishes (CSV now has `sub_caption`), edit and submit:
+
+```bash
+sbatch run_vllm_summary_inference.sh
+```
+
+Slurm script reference:
+
+**What it does:** runs `generate_summary_vllm.py` and writes `summary` back into the same CSV.
+
+---
+
+### 3) Modality Label generation (Qwen2.5-VL-32B-Instruct)
+
+Edit the Slurm script to point to:
+
+* your python file path
+* your CSV path (`--data_path`)
+* your model weights path (`--model_dir`)
+* any desired batch/tp settings
+
+Then submit:
+
+```bash
+sbatch run_vllm_modality_inference.sh
+```
+
+Slurm script reference:
+
+**What it does:** runs `generate_modality_labels_vllm.py` and writes `L0`, `L1`, and `L2` labels back into the same jsonl file.
+
+---
+
+## Notes
+
+* **Paths:** All Slurm scripts include placeholder paths like `/path/to/...` — replace them before submitting.
+* **GPU selection:** All scripts set `CUDA_VISIBLE_DEVICES=0,1` and use `--tp_size 2` to shard across 2 GPUs.
+* **Checkpointing:** All scripts allow periodic checkpointing.
+* **Outputs formatting:** subcaptions are extracted from `...`, and summaries from `...` (regex-based extraction).
diff --git a/working/process/subcaption_and_summary_generation/scripts/run_vllm_modality_inference.sh b/working/process/subcaption_and_summary_generation/scripts/run_vllm_modality_inference.sh
new file mode 100644
index 0000000..c15338d
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/scripts/run_vllm_modality_inference.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+#SBATCH --job-name=pmc-subcaption-qwen32b
+#SBATCH --partition=a100
+#SBATCH --time=1-00:00:00
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=2
+#SBATCH --cpus-per-task=4
+#SBATCH --mem=59G
+#SBATCH --output=qwen32b-subcap.%j.out
+
+# Activate your environment
+
+echo "Script Run Start!"
+nvidia-smi
+
+#module load cuda-12.4
+module load gcc-12.3.0
+gcc --version
+
+source ~/envs/exp/bin/activate # Adjust this path to your virtual environment
+
+echo "Module Loaded and Environment Activated!"
+
+# Specify which GPUs to use
+CUDA_VISIBLE_DEVICES=0,1 \
+python /path/to/generate_modality_labels_vllm.py \
+ --data_path /path/to/data \
+ --model_dir /path/to/Qwen2.5-VL-32B-Instruct \
+ --batch_size 512 \
+ --max_new_tokens 128 \
+ --tp_size 2 \
+ --gpu_mem_util 0.90 \
+ --dtype bfloat16
\ No newline at end of file
diff --git a/working/process/subcaption_and_summary_generation/scripts/run_vllm_subcaption_inference.sh b/working/process/subcaption_and_summary_generation/scripts/run_vllm_subcaption_inference.sh
new file mode 100644
index 0000000..05a6b1b
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/scripts/run_vllm_subcaption_inference.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+#SBATCH --job-name=pmc-subcaption-qwen32b
+#SBATCH --partition=a100
+#SBATCH --time=1-00:00:00
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=2
+#SBATCH --cpus-per-task=4
+#SBATCH --mem=59G
+#SBATCH --output=qwen32b-subcap.%j.out
+
+# Activate your environment
+
+echo "Script Run Start!"
+nvidia-smi
+
+#module load cuda-12.4
+module load gcc-12.3.0
+gcc --version
+
+source ~/envs/exp/bin/activate # Adjust this path to your virtual environment
+
+echo "Module Loaded and Environment Activated!"
+
+# Specify which GPUs to use
+CUDA_VISIBLE_DEVICES=0,1 \
+python /path/to/generate_subcaption_vllm.py \
+ --data_path /path/to/data.csv \
+ --model_dir /path/to/qwen2.5_vl_32B_model_weights_directory \
+ --batch_size 32 \
+ --max_new_tokens 1024 \
+ --tp_size 2 \
+ --gpu_mem_util 0.90 \
+ --dtype bfloat16
diff --git a/working/process/subcaption_and_summary_generation/scripts/run_vllm_summary_inference.sh b/working/process/subcaption_and_summary_generation/scripts/run_vllm_summary_inference.sh
new file mode 100644
index 0000000..b3c6fee
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/scripts/run_vllm_summary_inference.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+#SBATCH --job-name=summary-pmc
+#SBATCH --partition=a40
+#SBATCH --time=24:00:00
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=2
+#SBATCH --cpus-per-task=4
+#SBATCH --mem=43G
+#SBATCH --output=qwen14b-summary.%j.out
+
+echo "Script Run Start!"
+nvidia-smi
+
+#module load cuda-12.4
+module load gcc-12.3.0
+gcc --version
+
+source ~/envs/exp2/bin/activate # Adjust this path to your virtual environment
+
+echo "Module Loaded and Environment Activated!"
+
+# Specify which GPUs to use
+CUDA_VISIBLE_DEVICES=0,1 \
+python /path/to/generate_summary_vllm.py \
+ --data_path /path/to/data.csv \
+ --model_dir /path/to/qwen2.5_14b_instruct_model_weights \
+ --batch_size 1024 \
+ --max_new_tokens 256 \
+ --tp_size 2 \
+ --gpu_mem_util 0.90 \
+ --dtype bfloat16
+
diff --git a/working/process/subcaption_and_summary_generation/src/generate_modality_labels_vllm.py b/working/process/subcaption_and_summary_generation/src/generate_modality_labels_vllm.py
new file mode 100644
index 0000000..fb54637
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/src/generate_modality_labels_vllm.py
@@ -0,0 +1,446 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+import os
+import re
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+import pandas as pd
+from PIL import Image
+from qwen_vl_utils import process_vision_info
+from tqdm import tqdm
+from transformers import AutoProcessor
+from vllm import LLM, SamplingParams
+
+
+os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+
+PROMPT_MEDICAL_L2_ONLY = (
+ "You are an expert in medical image modality classification. "
+ "You are given a single image.\n\n"
+ "Your task is to assign ONE fine-grained subclass label (L2) to the image.\n\n"
+ "You must choose exactly ONE L2 label from the following allowed subclasses:\n"
+ "- Radiology: [Ultrasound, Magnetic Resonance, Computerized Tomography, "
+ "X-Ray, 2D Radiography, Angiography, PET, Combined modalities in one image]\n"
+ "- Microscopy: [Light microscopy, Electron microscopy, Transmission microscopy, "
+ "Fluorescence microscopy]\n"
+ "- Visible Light Photography: [Dermatology, skin, Endoscopy, Other organs]\n"
+ "- Other: [Other]\n\n"
+ 'If the image clearly does NOT belong to any medical modality above, choose "Other".\n'
+ "If the image appears medical but you are unsure among subclasses, choose the most visually plausible one.\n\n"
+ "OUTPUT FORMAT:\n"
+ "Return your answer as a single JSON object with ONLY the L2 field:\n"
+ "{\n"
+ ' "L2": ""\n'
+ "}\n"
+ "Do not include explanations, reasoning, or any additional text. Only output the JSON object."
+)
+
+# L2 Radiology label sets
+L2_RADIOLOGY = {
+ "ultrasound",
+ "magnetic resonance",
+ "computerized tomography",
+ "x-ray",
+ "2d radiography",
+ "angiography",
+ "pet",
+ "combined modalities in one image",
+}
+
+# L2 Microscopy label sets
+L2_MICROSCOPY = {
+ "light microscopy",
+ "electron microscopy",
+ "transmission microscopy",
+ "fluorescence microscopy",
+}
+
+# L2 Visible Light Photography label sets
+L2_VLP = {
+ "dermatology",
+ "skin",
+ "endoscopy",
+ "other organs",
+}
+
+# -------------------- Logging --------------------
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s"
+)
+log = logging.getLogger(__name__)
+
+
+# -------------------- Helpers --------------------
+def _is_empty(x) -> bool:
+ """
+ Check if a response is empty (None, NaN, or empty string). Used to identify unprocessed rows.
+
+ Args:
+ x: The input to check.
+
+ Returns
+ -------
+ bool: True if x is considered empty, False otherwise.
+ """
+ return x is None or (isinstance(x, float) and pd.isna(x)) or (str(x).strip() == "")
+
+
+def _jsonl_overwrite(_df: pd.DataFrame, _path: str):
+ """
+ Safely overwrite a JSONL file by writing to a temporary file first and then replacing the original.
+
+ Args:
+ _df (pd.DataFrame): DataFrame to save.
+ _path (str): Path to the JSONL file.
+ """
+ tmp = _path + ".tmp"
+ _df.to_json(tmp, lines=True, orient="records")
+ os.replace(tmp, _path)
+
+
+def _load_rgb(path: str) -> Image.Image:
+ """
+ Load an image from the given path and convert it to RGB mode if necessary.
+
+ Args:
+ path (str): Path to the image file.
+
+ Returns
+ -------
+ Image.Image: The loaded RGB image.
+ """
+ img = Image.open(path)
+ if img.mode != "RGB":
+ img = img.convert("RGB")
+ return img
+
+
+def build_messages(img: Image.Image, prompt: str) -> List[Dict[str, Any]]:
+ """
+ Build the message structure for the vLLM compatible VLM input.
+
+ Args:
+ img (Image.Image): The input image.
+ prompt (str): The text prompt.
+
+ Returns
+ -------
+ List[Dict[str, Any]]: The constructed message list.
+ """
+ return [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": img},
+ {"type": "text", "text": prompt},
+ ],
+ }
+ ]
+
+
+def extract_l2_label(text: str) -> Optional[str]:
+ """
+ Extract JSON {L2: "..."} from model text output. If parsing fails, return None.
+
+ Args:
+ text (str): The raw text output from the model.
+
+ Returns
+ -------
+ Optional[str]: The extracted L2 label, or None if parsing fails.
+ """
+ cleaned = text.strip()
+
+ # strip Markdown fences if present
+ if cleaned.startswith("```"):
+ cleaned = re.sub(r"^```[a-zA-Z0-9]*\s*", "", cleaned)
+ cleaned = re.sub(r"```$", "", cleaned).strip()
+
+ # keep only the JSON object part if there's extra text
+ start = cleaned.find("{")
+ end = cleaned.rfind("}")
+ if start != -1 and end != -1 and end > start:
+ cleaned = cleaned[start : end + 1]
+
+ try:
+ obj = json.loads(cleaned)
+ except Exception:
+ log.warning("JSON parse failed; storing raw text instead.")
+ return None
+
+ l2 = str(obj.get("L2") or obj.get("l2") or "").strip()
+ return l2
+
+
+def infer_from_l2(l2_raw: str) -> Tuple[str, str, str]:
+ """
+ Infer (L0, L1, L2) from an L2 string.
+
+ L0 ∈ {Medical, Other}
+ L1 ∈ {Radiology, Microscopy, Visible Light Photography, Other}
+ L2 = original L2 text (possibly normalized upstream).
+
+ Args:
+ l2_raw (str): The raw L2 label.
+
+ Returns
+ -------
+ Tuple[str, str, str]: The inferred (L0, L1, L2) labels.
+ """
+ l2 = (l2_raw or "").strip()
+ l2_norm = l2.lower()
+
+ if l2_norm in L2_RADIOLOGY:
+ l1 = "Radiology"
+ l0 = "Medical"
+ elif l2_norm in L2_MICROSCOPY:
+ l1 = "Microscopy"
+ l0 = "Medical"
+ elif l2_norm in L2_VLP:
+ l1 = "Visible Light Photography"
+ l0 = "Medical"
+ else:
+ l1 = "Other"
+ l0 = "Other"
+
+ return l0, l1, l2
+
+
+# -------------------- Batch processing --------------------
+def process_batched(
+ df: pd.DataFrame,
+ llm: LLM,
+ processor,
+ out_path: str,
+ batch_size: int = 8,
+ max_new_tokens: int = 256,
+ temperature: float = 0.0,
+ top_p: float = 1.0,
+) -> pd.DataFrame:
+ """
+ Process the DataFrame in batches to generate modality labels using the provided vLLM model.
+
+ Args:
+ df (pd.DataFrame): Input DataFrame with image paths.
+ llm (LLM): The vLLM model instance.
+ processor: The processor for preparing inputs.
+ out_path (str): Path to save the output CSV.
+ batch_size (int): Number of samples to process in each batch.
+ max_new_tokens (int): Maximum number of tokens to generate.
+ temperature (float): Sampling temperature.
+ top_p (float): Top-p sampling parameter.
+
+ Returns
+ -------
+ pd.DataFrame: The updated DataFrame with generated modality labels.
+ """
+ image_col = "subfig_path"
+ label_cols = ["L0_label", "L1_label", "L2_label"]
+
+ # ensure label columns exist, if not exist, create and store empty strings
+ for col in label_cols:
+ if col not in df.columns:
+ df[col] = ""
+
+ # Sampling parameters for generation.
+ sampling = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ t0_all = time.time()
+ n = len(df)
+ total_loaded, total_failed, total_done = 0, 0, 0 # counters to track progress
+
+ # rows needing inference = those with empty L0_label
+ to_infer = sum(_is_empty(x) for x in df.get("L0_label", pd.Series([None] * n)))
+ pbar = tqdm(total=to_infer, desc="inference", ncols=100, unit="img") # progress bar
+ json_ok, json_fail = 0, 0
+
+ log.info(f"Starting batched processing on {n:,} rows (to infer: {to_infer:,})")
+
+ flag = False
+
+ for start in range(0, n, batch_size):
+ end = min(start + batch_size, n)
+
+ # Select unprocessed rows. This also allows resuming.
+ idxs = [
+ i
+ for i in range(start, end)
+ if any(_is_empty(df.at[i, col]) for col in label_cols)
+ ]
+ if not idxs:
+ continue # skip if all rows in this batch are already processed
+
+ t_img0 = time.time()
+ requests = []
+ idx_map = []
+
+ # Load tqdm for progress tracking
+ iterable = tqdm(
+ idxs,
+ desc=f"[prep] rows {start}-{end - 1}",
+ leave=False,
+ ncols=100,
+ unit="row",
+ )
+
+ batch_loaded, batch_failed = 0, 0
+
+ # Prepare inputs for each row in the batch
+ for i in iterable:
+ img_path = str(df.at[i, image_col]) if image_col in df.columns else ""
+
+ try:
+ pil_img = _load_rgb(img_path)
+ batch_loaded += 1
+ except Exception as e:
+ batch_failed += 1
+ log.warning(f"Failed to load image at row {i}, path={img_path}: {e}")
+ continue
+
+ messages = build_messages(
+ pil_img, PROMPT_MEDICAL_L2_ONLY
+ ) # Build vLLM message structure
+ image_inputs, _videos = process_vision_info(
+ messages
+ ) # Process images for vLLM using qwen_vl_utils's process_vision_info function.
+
+ # Apply chat template to format the prompt correctly
+ fprompt = processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ # Final request List for vLLM
+ requests.append(
+ {
+ "prompt": fprompt,
+ "multi_modal_data": {"image": image_inputs},
+ }
+ )
+ idx_map.append(i)
+
+ t_img = time.time() - t_img0
+ total_loaded += batch_loaded
+ total_failed += batch_failed
+
+ log.info(
+ f"[prep] batch {start}-{end - 1}: loaded={batch_loaded}, "
+ f"failed={batch_failed}, time={t_img:.2f}s"
+ )
+
+ if requests:
+ t_gen0 = time.time()
+ responses = llm.generate(requests, sampling) # vLLM generation call
+ t_gen = time.time() - t_gen0
+
+ # Process and store outputs
+ for j, res in enumerate(responses):
+ raw = res.outputs[0].text if res.outputs else ""
+ l2_parsed = extract_l2_label(raw)
+
+ if l2_parsed is not None:
+ l0, l1, l2 = infer_from_l2(l2_parsed)
+ json_ok += 1
+ else:
+ # if JSON extraction fails, store full raw string in all labels
+ l0 = l1 = l2 = raw.strip()
+ json_fail += 1
+
+ row_idx = idx_map[j]
+ df.at[row_idx, "L0_label"] = l0
+ df.at[row_idx, "L1_label"] = l1
+ df.at[row_idx, "L2_label"] = l2
+
+ pbar.update(1)
+
+ total_done += len(responses)
+ flag = True
+ log.info(
+ f"[gen ] batch {start}-{end - 1}: outputs={len(responses)}, "
+ f"time={t_gen:.2f}s | json_ok={json_ok}, json_fail={json_fail}"
+ )
+
+ # Checkpointing every 1000 batches
+ if flag and start and ((start // batch_size) % 1000 == 0):
+ _jsonl_overwrite(df, out_path)
+ elapsed = time.time() - t0_all
+ log.info(
+ f"[ckpt] saved at row {start} → {out_path} | elapsed={elapsed / 60:.1f}m | "
+ f"done={total_done} | loaded={total_loaded} | failed_img={total_failed}"
+ )
+ flag = False
+
+ # Final save after all batches are processed
+ _jsonl_overwrite(df, out_path)
+ pbar.close()
+ log.info(
+ f"Total time {time.time() - t0_all:.2f}s | done={total_done} | "
+ f"loaded_img={total_loaded} | failed_img={total_failed} | "
+ f"json_ok={json_ok} | json_fail={json_fail}. Final saved → {out_path}"
+ )
+
+ return df
+
+
+# -------------------- Main --------------------
+def main():
+ args = argparse.ArgumentParser()
+ args.add_argument(
+ "--data_path", required=True, help="JSONL with column 'subfig_path'."
+ )
+ args.add_argument(
+ "--model_dir",
+ default="Qwen/Qwen2.5-VL-32B-Instruct",
+ help="HF id or local path to Qwen2.5-VL-32B-Instruct",
+ )
+ args.add_argument(
+ "--batch_size", type=int, default=8, help="Keep modest; VLMs are memory heavy"
+ )
+ args.add_argument("--max_new_tokens", type=int, default=256)
+ args.add_argument(
+ "--tp_size", type=int, default=4, help="Tensor parallel degree for 32B"
+ )
+ args.add_argument("--gpu_mem_util", type=float, default=0.90)
+ args.add_argument(
+ "--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16"]
+ )
+ args.add_argument("--temperature", type=float, default=0.0)
+ args.add_argument("--top_p", type=float, default=1.0)
+
+ args_dct = args.parse_args()
+
+ log.info(f"Loading processor and model from {args_dct.model_dir}")
+ processor = AutoProcessor.from_pretrained(args_dct.model_dir)
+ llm = LLM(
+ model=args_dct.model_dir,
+ tensor_parallel_size=args_dct.tp_size,
+ gpu_memory_utilization=args_dct.gpu_mem_util,
+ dtype=None if args_dct.dtype == "auto" else args_dct.dtype,
+ )
+
+ log.info(f"Reading data from {args_dct.data_path}")
+ df = pd.read_json(args_dct.data_path, lines=True)
+
+ # Process in batches and generate modality labels
+ df = process_batched(
+ df=df,
+ llm=llm,
+ processor=processor,
+ out_path=args_dct.data_path,
+ batch_size=args_dct.batch_size,
+ max_new_tokens=args_dct.max_new_tokens,
+ temperature=args_dct.temperature,
+ top_p=args_dct.top_p,
+ )
+
+ log.info(f"Completed writing {len(df):,} rows → {args_dct.data_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/working/process/subcaption_and_summary_generation/src/generate_subcaption_vllm.py b/working/process/subcaption_and_summary_generation/src/generate_subcaption_vllm.py
new file mode 100644
index 0000000..265a075
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/src/generate_subcaption_vllm.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+import argparse
+import os
+import re
+import time
+from typing import Any, Dict, List
+
+import pandas as pd
+from PIL import Image
+from qwen_vl_utils import process_vision_info
+from tqdm import tqdm
+from transformers import AutoProcessor
+from vllm import LLM, SamplingParams
+
+
+os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+
+prompt = (
+ "### INSTRUCTIONS:\n"
+ "You are an expert medical image captioning assistant. Your task is the following:\n"
+ "1. You will be provided with a subfigure image that is part of a full image figure and the full figure caption in the input.\n"
+ "2. The full caption contains descriptions for multiple subfigures (e.g., Subfigure-A, Subfigure-B, etc.).\n"
+ "3. Your task is to identify the relevant subfigure caption corresponding to the provided subfigure image from the full caption exactly as it appears.\n"
+ "4. If the subcaption is written jointly for two or more subfigures (e.g., A–C together, (A–C), Axial (A) and coronal (B), etc.), copy that combined description exactly as it appears.\n"
+ "5. Do NOT rewrite, summarize, or generate new text. Copy the relevant portion exactly as it appears in the full caption.\n"
+ "6. Here, 'exactly as it appears' mean the extracted caption must match word-for-word, character-for-character with the correct subfigure caption text from the full caption. It must be a verbatim copy, not paraphrased, summarized, or partially copied.\n"
+ "7. If no relevant caption is found in the full caption, output the verbatim copy of the entire full caption.\n"
+ "### OUTPUT FORMAT:\n"
+ "\n"
+ "\n"
+ "\n\n"
+ "### INPUT:\n\n"
+)
+
+
+def _is_empty(x) -> bool:
+ """
+ Check if a response is empty (None, NaN, or empty string). Used to identify unprocessed rows.
+
+ Args:
+ x: The input to check.
+
+ Returns
+ -------
+ bool: True if x is considered empty, False otherwise.
+ """
+ return x is None or (isinstance(x, float) and pd.isna(x)) or (str(x).strip() == "")
+
+
+def _csv_overwrite(_df: pd.DataFrame, _path: str):
+ """
+ Safely overwrite a CSV file by writing to a temporary file first and then replacing the original.
+
+ Args:
+ _df (pd.DataFrame): DataFrame to save.
+ _path (str): Path to the CSV file.
+ """
+ tmp = _path + ".tmp"
+ _df.to_csv(tmp, index=False)
+ os.replace(tmp, _path)
+
+
+def _load_rgb(path: str) -> Image.Image:
+ """
+ Load an image from the given path and convert it to RGB mode if necessary.
+
+ Args:
+ path (str): Path to the image file.
+
+ Returns
+ -------
+ Image.Image: The loaded RGB image.
+ """
+ img = Image.open(path)
+ if img.mode != "RGB":
+ img = img.convert("RGB")
+ return img
+
+
+def build_messages(img: Image.Image, prompt: str) -> List[Dict[str, Any]]:
+ """
+ Build the message structure for the vLLM compatible VLM input.
+
+ Args:
+ img (Image.Image): The input image.
+ prompt (str): The text prompt.
+
+ Returns
+ -------
+ List[Dict[str, Any]]: The constructed message list.
+ """
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": img},
+ {"type": "text", "text": prompt},
+ ],
+ }
+ ]
+
+ return messages
+
+
+def process_batched(
+ df: pd.DataFrame,
+ llm: LLM,
+ processor,
+ out_path: str,
+ batch_size: int = 8,
+ max_new_tokens: int = 256,
+ temperature: float = 0.0,
+ top_p: float = 1.0,
+) -> pd.DataFrame:
+ """
+ Process the DataFrame in batches to generate subcaptions using the provided vLLM model.
+
+ Args:
+ df (pd.DataFrame): Input DataFrame with image paths and captions.
+ llm (LLM): The vLLM model instance.
+ processor: The processor for preparing inputs.
+ out_path (str): Path to save the output CSV.
+ batch_size (int): Number of samples to process in each batch.
+ max_new_tokens (int): Maximum number of tokens to generate.
+ temperature (float): Sampling temperature.
+ top_p (float): Top-p sampling parameter.
+
+ Returns
+ -------
+ pd.DataFrame: The updated DataFrame with generated subcaptions.
+ """
+ image_col = "subfig_path"
+ output_col = "sub_caption"
+
+ # Sampling parameters for generation. Stop at .
+ sampling = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stop=[""],
+ )
+
+ pattern = re.compile(
+ r"\s*(.*?)\s*", re.DOTALL
+ ) # to extract text within tags
+
+ t0_all = time.time()
+ n = len(df)
+ total_loaded, total_failed, total_done = 0, 0, 0 # counters to track progress
+
+ for start in range(0, n, batch_size):
+ end = min(start + batch_size, n)
+
+ idxs = [
+ i for i in range(start, end) if _is_empty(df.at[i, output_col])
+ ] # Select unprocessed rows. This also allows resuming.
+ if not idxs:
+ continue # skip if all rows in this batch are already processed
+
+ t_img0 = time.time()
+ requests = []
+ idx_map = []
+
+ # Load tqdm for progress tracking
+ iterable = tqdm(
+ idxs,
+ desc=f"[prep] rows {start}-{end - 1}",
+ leave=False,
+ ncols=100,
+ unit="row",
+ )
+
+ batch_loaded, batch_failed = 0, 0 # counters to track batch progress
+
+ # Prepare inputs for each row in the batch
+ for i in iterable:
+ img_path = str(df.at[i, image_col]) if image_col in df.columns else ""
+ text = f"{prompt}\n\n##Full Caption:\n{df.caption.iloc[i]}" # Final text prompt containing full caption
+
+ try:
+ pil_img = _load_rgb(img_path)
+ batch_loaded += 1
+ except Exception:
+ batch_failed += 1
+ continue
+
+ messages = build_messages(pil_img, text) # Build vLLM message structure
+ image_inputs, _videos = process_vision_info(
+ messages
+ ) # Process images for vLLM using qwen_vl_utils's process_vision_info function.
+
+ # Apply chat template to format the prompt correctly
+ fprompt = processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ # Final request List for vLLM
+ requests.append(
+ {
+ "prompt": fprompt,
+ "multi_modal_data": {"image": image_inputs},
+ }
+ )
+ idx_map.append(i)
+
+ t_img = time.time() - t_img0
+ total_loaded += batch_loaded
+ total_failed += batch_failed
+
+ print(
+ f"[prep] batch {start}-{end - 1}: loaded={batch_loaded}, failed={batch_failed}, time={t_img:.2f}s"
+ )
+
+ if requests:
+ t_gen0 = time.time()
+ responses = llm.generate(requests, sampling) # vLLM generation call
+ t_gen = time.time() - t_gen0
+
+ # Process and store outputs
+ for j, res in enumerate(responses):
+ out = res.outputs[0].text if res.outputs else ""
+ m = pattern.search(out)
+ df.at[idx_map[j], output_col] = (
+ m.group(1).strip() if m else out.replace("", "").strip()
+ ) # Strip of extra caption tags if regex fails.
+
+ total_done += len(responses)
+ print(
+ f"[gen ] batch {start}-{end - 1}: outputs={len(responses)}, time={t_gen:.2f}s"
+ )
+
+ # Checkpointing every 10 batches
+ if start and ((start // batch_size) % 10 == 0):
+ _csv_overwrite(df, out_path)
+ elapsed = time.time() - t0_all
+ print(
+ f"[ckpt] saved at row {start} → {out_path} | elapsed={elapsed / 60:.1f}m | "
+ f"done={total_done} | loaded={total_loaded} | failed={total_failed}"
+ )
+
+ # Final save after all batches are processed
+ _csv_overwrite(df, out_path)
+ print(
+ f"Total time {time.time() - t0_all:.2f}s | done={total_done} | loaded={total_loaded} | failed={total_failed}. "
+ f"Final saved → {out_path}"
+ )
+ return df
+
+
+def main():
+ args = argparse.ArgumentParser()
+ args.add_argument(
+ "--data_path",
+ required=True,
+ help="CSV with at least two columns: image path + full caption.",
+ )
+ args.add_argument(
+ "--model_dir",
+ default="Qwen/Qwen2.5-VL-32B-Instruct",
+ help="HF id or local path to Qwen2.5-VL-32B-Instruct",
+ )
+ args.add_argument(
+ "--batch_size", type=int, default=8, help="Keep modest; VLMs are memory heavy"
+ )
+ args.add_argument(
+ "--max_new_tokens", type=int, default=256, help="Max tokens to generate"
+ )
+ args.add_argument(
+ "--tp_size",
+ type=int,
+ default=4,
+ help="Tensor parallel degree for 32B (e.g., 4×A100-80GB)",
+ )
+ args.add_argument(
+ "--gpu_mem_util",
+ type=float,
+ default=0.90,
+ help="GPU memory utilization for vLLM",
+ )
+ args.add_argument(
+ "--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16"]
+ )
+ args.add_argument("--temperature", type=float, default=0.0)
+ args.add_argument("--top_p", type=float, default=1.0)
+
+ args_dct = args.parse_args()
+
+ processor = AutoProcessor.from_pretrained(args_dct.model_dir)
+ llm = LLM(
+ model=args_dct.model_dir,
+ tensor_parallel_size=args_dct.tp_size,
+ gpu_memory_utilization=args_dct.gpu_mem_util,
+ dtype=None if args_dct.dtype == "auto" else args_dct.dtype,
+ )
+
+ df = pd.read_csv(args_dct.data_path) # Load input CSV
+
+ # Process in batches and generate subcaptions
+ df = process_batched(
+ df=df,
+ llm=llm,
+ processor=processor,
+ out_path=args_dct.data_path,
+ batch_size=args_dct.batch_size,
+ max_new_tokens=args_dct.max_new_tokens,
+ temperature=args_dct.temperature,
+ top_p=args_dct.top_p,
+ )
+
+ print(f"Completed writing {len(df)} rows → {args_dct.data_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/working/process/subcaption_and_summary_generation/src/generate_summary_vllm.py b/working/process/subcaption_and_summary_generation/src/generate_summary_vllm.py
new file mode 100644
index 0000000..6653167
--- /dev/null
+++ b/working/process/subcaption_and_summary_generation/src/generate_summary_vllm.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+import argparse
+import os
+import re
+import time
+
+import pandas as pd
+from transformers import AutoTokenizer
+from vllm import LLM, SamplingParams
+
+
+os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+
+prompt = (
+ "### INSTRUCTIONS:\n"
+ "You will be provided with:\n"
+ "1. A subcaption that describes a subfigure from a compound figure.\n"
+ "2. The full caption of the compound figure.\n"
+ "3. A context passage related to the compound figure.\n"
+ "**Definition of compound figure:** A compound figure is a figure that contains multiple subfigures of the same topic (e.g., panels A, B, C, etc.).\n\n"
+ "Your task is to summarize only the portions of the context passage "
+ "that are most relevant to the given subcaption. The full caption \n"
+ "is provided for additional information.\n"
+ "The summary should:\n"
+ "- Use both the subcaption and the full caption to determine context.\n"
+ "- Be concise and focused on the subcaption's content.\n"
+ "- Exclude unrelated information from the context passage.\n"
+ "- Preserve key biomedical terminology exactly as it appears.\n"
+ "- Output the summary only, without any labels or additional text in the following format:\n"
+ "\n"
+ "\n"
+ "\n\n"
+ "### INPUT:\n\n"
+)
+
+
+def build_chat(tokenizer, user_prompt: str, max_length: int = 32700):
+ """
+ Build chat-style input encoding for vLLM from user prompt.
+
+ Args:
+ tokenizer: The tokenizer to use.
+ user_prompt (str): The user prompt string.
+ max_length (int): Maximum token length for the input.
+
+ Returns
+ -------
+ encoded inputs.
+ """
+ messages = [
+ {
+ "role": "system",
+ "content": "You are a biomedical image context summary generator.",
+ },
+ {"role": "user", "content": user_prompt},
+ ]
+
+ enc = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ token_ids = tokenizer.encode(enc, add_special_tokens=False)
+ if len(token_ids) > max_length:
+ enc = tokenizer.decode(token_ids[:max_length], skip_special_tokens=False)
+
+ return enc
+
+
+def _is_empty(x) -> bool:
+ """
+ Check if a response is empty (None, NaN, or empty string). Used to identify unprocessed rows.
+
+ Args:
+ x: The input to check.
+
+ Returns
+ -------
+ bool: True if x is considered empty, False otherwise.
+ """
+ return (
+ (x is None) or (isinstance(x, float) and pd.isna(x)) or (str(x).strip() == "")
+ )
+
+
+def _csv_overwrite(_df: pd.DataFrame, _path: str):
+ """
+ Safely overwrite a CSV file by writing to a temporary file first and then replacing the original.
+
+ Args:
+ _df (pd.DataFrame): DataFrame to save.
+ _path (str): Path to the CSV file.
+ """
+ tmp = _path + ".tmp"
+ _df.to_csv(tmp, index=False)
+ os.replace(tmp, _path)
+
+
+def process_data_batched_vllm(
+ df: pd.DataFrame,
+ llm: LLM,
+ tokenizer,
+ out_path: str,
+ batch_size: int = 16,
+ max_new_tokens: int = 192,
+) -> None:
+ """
+ Process the DataFrame in batches using vLLM to generate summaries.
+
+ Args:
+ df (pd.DataFrame): Input DataFrame with columns 'caption', 'sub_caption', and 'image_context'.
+ llm (LLM): The vLLM model instance.
+ tokenizer: The tokenizer for building prompts.
+ out_path (str): Path to save the output CSV.
+ batch_size (int): Number of samples to process in each batch.
+ max_new_tokens (int): Maximum number of new tokens to generate for each summary.
+ """
+ pattern = re.compile(
+ r"\s*(.*?)\s*<\/summary>", re.DOTALL
+ ) # Pattern to extract summary text
+
+ sampling_params = SamplingParams(
+ max_tokens=max_new_tokens, temperature=0.0, top_p=1.0
+ )
+ t0_all = time.time()
+
+ # Batch Processing Loop
+ for start in range(0, len(df), batch_size):
+ end = min(start + batch_size, len(df))
+ idxs = [
+ i for i in range(start, end) if _is_empty(df.loc[i, "summary"])
+ ] # Select unprocessed rows. This also allows resuming.
+ if idxs:
+ batch_prompts = []
+ for i in idxs:
+ # Prompt construction with full caption, subcaption, and context passage
+ user_prompt = (
+ prompt
+ + f"Full Caption:\n{df.caption.iloc[i]}\n\n"
+ + f"Subcaption:\n{df.sub_caption.iloc[i]}\n\n"
+ + f"Context Passage:\n{df.image_context.iloc[i]}"
+ )
+ batch_prompts.append(build_chat(tokenizer, user_prompt))
+
+ outs = llm.generate(batch_prompts, sampling_params) # vLLM generation call
+
+ for j, out in enumerate(outs):
+ text = out.outputs[0].text
+ m = pattern.search(text)
+ df.loc[idxs[j], "summary"] = (
+ m.group(1).strip() if m else text.strip()
+ ) # Extract summary or use full text if pattern not found
+
+ # Overwrite CSV checkpoint every (batch size * 10) batches
+ if start and (start % (10 * batch_size) == 0):
+ _csv_overwrite(df, out_path)
+ print(f"[ckpt] Saved at row {start} → {out_path}")
+
+ # Final save after all batches are processed
+ _csv_overwrite(df, out_path)
+ print(f"Total time {time.time() - t0_all:.2f}s. Final saved → {out_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument("--data_path", required=True, help="CSV path to data")
+ parser.add_argument(
+ "--model_dir",
+ required=True,
+ help="Path or HF id for the model (e.g., /model-weights/Qwen2.5-7B-Instruct)",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=16,
+ help="vLLM micro-batch size per generate() call",
+ )
+ parser.add_argument(
+ "--max_new_tokens", type=int, default=192, help="Max new tokens to generate"
+ )
+ parser.add_argument(
+ "--tp_size", type=int, default=1, help="Tensor parallel size for vLLM"
+ )
+ parser.add_argument(
+ "--gpu_mem_util",
+ type=float,
+ default=0.90,
+ help="GPU memory utilization fraction for vLLM",
+ )
+ parser.add_argument(
+ "--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16"]
+ )
+
+ args = parser.parse_args()
+
+ data_path = args.data_path
+ model_dir = args.model_dir
+
+ # Tokenizer used only to template chat → plain prompt string
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+
+ # Init vLLM engine
+ # Notes:
+ # - tensor_parallel_size lets you span multiple GPUs if available.
+ # - gpu_memory_utilization tunes how full vLLM packs the GPU.
+ # - max_model_len can be set if you have very long contexts (defaults are fine for most).
+ llm = LLM(
+ model=model_dir,
+ tensor_parallel_size=args.tp_size,
+ gpu_memory_utilization=args.gpu_mem_util,
+ dtype=None if args.dtype == "auto" else args.dtype,
+ )
+
+ df = pd.read_csv(data_path) # Load input CSV
+
+ process_data_batched_vllm(
+ df=df,
+ llm=llm,
+ tokenizer=tokenizer,
+ out_path=data_path,
+ batch_size=args.batch_size,
+ max_new_tokens=args.max_new_tokens,
+ )
+
+ print(f"Completed writing {len(df)} entries to: {data_path}")
+
+
+if __name__ == "__main__":
+ main()