diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b02439d26..ea772bfec 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -224,6 +224,7 @@ Available Datasets
    datasets/pyhealth.datasets.SampleDataset
    datasets/pyhealth.datasets.MIMIC3Dataset
    datasets/pyhealth.datasets.MIMIC4Dataset
+   datasets/pyhealth.datasets.MIMIC4FHIRDataset
    datasets/pyhealth.datasets.MedicalTranscriptionsDataset
    datasets/pyhealth.datasets.CardiologyDataset
    datasets/pyhealth.datasets.eICUDataset
diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
new file mode 100644
index 000000000..1a19fc5e9
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
@@ -0,0 +1,70 @@
+pyhealth.datasets.MIMIC4FHIRDataset
+=====================================
+
+`MIMIC-IV on FHIR `_ NDJSON ingest
+for CEHR-style token sequences used with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.models.EHRMambaCEHR`.
+
+YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml``. Unlike the
+earlier nested-object approach, the YAML now declares a normal ``tables:``
+schema for flattened FHIR resources (``patient``, ``encounter``, ``condition``,
+``observation``, ``medication_request``, ``procedure``). The class subclasses
+:class:`~pyhealth.datasets.BaseDataset` and builds a standard Polars
+``global_event_df`` backed by cached Parquet
+(``global_event_df.parquet/part-*.parquet``), so the same tabular entry points
+as other datasets work unchanged: :meth:`~pyhealth.datasets.BaseDataset.set_task`,
+:meth:`iter_patients`, :meth:`get_patient`, etc.
+
+**Ingest (out-of-core).** Matching ``*.ndjson`` / ``*.ndjson.gz`` files are read
+**line by line**; each resource is normalized into a flattened per-resource
+Parquet table under ``cache/flattened_tables/``. 
Those tables are then fed +through the regular YAML-driven :class:`~pyhealth.datasets.BaseDataset` loader to +materialize ``global_event_df``. This keeps FHIR aligned with PyHealth's usual +table-first pipeline instead of reparsing nested JSON per patient downstream. + +**``max_patients``.** When set, the loader selects the first *N* patient ids after +a **sorted** ``unique`` over the flattened patient table, filters every +normalized table to that cohort, and then builds ``global_event_df`` from the +filtered tables. Ingest still scans all matching NDJSON once unless you also +override ``glob_patterns`` / ``glob_pattern`` (defaults skip non-flattened PhysioNet shards). + +**Downstream memory (still important).** Streaming ingest avoids loading the +entire NDJSON corpus into RAM at once, but other steps can still be heavy on +large cohorts: ``global_event_df`` materialization, MPF vocabulary warmup, and +:meth:`set_task` still walk patients and samples; training needs RAM/VRAM for the +model and batches. For a **full** PhysioNet tree, plan for **large disk** +(flattened tables plus event cache), **comfortable system RAM** for Polars/PyArrow +and task pipelines, and restrict ``glob_patterns`` / ``glob_pattern`` or ``max_patients`` when +prototyping on a laptop. + +**Recommended hardware (informal)** + +Order-of-magnitude guides, not guarantees. Ingest footprint is **much smaller** +than “load everything into Python”; wall time still grows with **decompressed +NDJSON volume** and the amount of flattened table data produced. + +* **Smoke / CI** + Small on-disk fixtures (see tests and ``examples/mimic4fhir_mpf_ehrmamba.py``): + a recent laptop is sufficient. + +* **Laptop-scale real FHIR subset** + A **narrow** ``glob_patterns`` / ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps + cache and task passes manageable. **≥ 16 GB** system RAM is a practical + comfort target for Polars + trainer + OS; validate GPU **VRAM** for your + ``max_len`` and batch size. 
+ +* **Full default globs on a complete export** + Favor **workstations or servers** with **fast SSD**, **large disk**, and + **ample RAM** for downstream steps—not because NDJSON is fully buffered in + memory during ingest, but because total work and caches still scale with the + full dataset. + +.. autoclass:: pyhealth.datasets.MIMIC4FHIRDataset + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.datasets.ConceptVocab + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..c7e9f2729 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -185,6 +185,7 @@ API Reference models/pyhealth.models.MoleRec models/pyhealth.models.Deepr models/pyhealth.models.EHRMamba + models/pyhealth.models.EHRMambaCEHR models/pyhealth.models.JambaEHR models/pyhealth.models.ContraWR models/pyhealth.models.SparcNet diff --git a/docs/api/models/pyhealth.models.EHRMambaCEHR.rst b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst new file mode 100644 index 000000000..79466cad3 --- /dev/null +++ b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst @@ -0,0 +1,12 @@ +pyhealth.models.EHRMambaCEHR +=================================== + +EHRMambaCEHR applies CEHR-style embeddings (:class:`~pyhealth.models.cehr_embeddings.MambaEmbeddingsForCEHR`) +and a stack of :class:`~pyhealth.models.MambaBlock` layers to a single FHIR token stream, for use with +:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and +:class:`~pyhealth.datasets.mimic4_fhir.MIMIC4FHIRDataset`. + +.. 
autoclass:: pyhealth.models.EHRMambaCEHR + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..83790ca44 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -214,6 +214,7 @@ Available Tasks Drug Recommendation Length of Stay Prediction Medical Transcriptions Classification + MPF Clinical Prediction (FHIR) Mortality Prediction (Next Visit) Mortality Prediction (StageNet MIMIC-IV) Patient Linkage (MIMIC-III) diff --git a/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst new file mode 100644 index 000000000..ff66deb08 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst @@ -0,0 +1,12 @@ +pyhealth.tasks.mpf_clinical_prediction +====================================== + +Multitask Prompted Fine-tuning (MPF) style binary clinical prediction on FHIR +token timelines, paired with :class:`~pyhealth.datasets.MIMIC4FHIRDataset` and +:class:`~pyhealth.models.EHRMambaCEHR`. Based on CEHR / EHRMamba ideas; see the +paper linked in the course replication PR. + +.. autoclass:: pyhealth.tasks.MPFClinicalPredictionTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py new file mode 100644 index 000000000..e53098779 --- /dev/null +++ b/examples/mimic4fhir_mpf_ehrmamba.py @@ -0,0 +1,868 @@ +"""EHRMambaCEHR on MIMIC-IV FHIR NDJSON with MPF clinical prediction (ablations). + +Replication target: EHRMamba / CEHR-style modeling on tokenized FHIR timelines +(e.g. `arXiv:2405.14567 `_). This script is +runnable end-to-end on **synthetic** NDJSON (``--quick-test``) or on +credentialled MIMIC-IV on FHIR from PhysioNet. + +Experimental setup (for write-ups / PR): + * **Data**: Synthetic two-patient NDJSON (``--quick-test``) or disk NDJSON + under ``MIMIC4_FHIR_ROOT`` / ``--fhir-root``. 
+    * **Task ablations**: ``max_len`` (context window), ``use_mpf`` vs generic
+      ``CLS``/``REG`` boundaries (``--no-mpf``).
+    * **Model ablations**: ``hidden_dim`` (embedding width); dropout is
+      fixed at 0.1 in this script.
+    * **Train**: Adam via :class:`~pyhealth.trainer.Trainer`, monitor ROC-AUC,
+      report test ROC-AUC / PR-AUC.
+
+**Ablation mode** (``--ablation``): sweeps a small grid on synthetic data only,
+trains 1 epoch per config, and prints a comparison table. Use this to document
+how task/model knobs affect metrics on the minimal fixture before scaling to
+real FHIR.
+
+**Findings** (fill in after your runs; synthetic runs are noisy):
+    On ``--quick-test`` data, longer ``max_len`` and MPF specials typically
+    change logits enough to move AUC slightly; real MIMIC-IV FHIR runs are
+    needed for conclusive comparisons. Paste your table from ``--ablation``
+    into the PR description.
+
+**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON into
+flattened per-resource Parquet tables (bounded RAM during ingest). This example
+trains via ``dataset.set_task(MPFClinicalPredictionTask)`` → LitData-backed
+:class:`~pyhealth.datasets.sample_dataset.SampleDataset` →
+:class:`~pyhealth.trainer.Trainer` (PyHealth’s standard path), instead of
+materializing all samples with ``gather_samples()``. Prefer ``--max-patients`` to
+bound ingest when possible. Very large cohorts still need RAM/disk for task
+caches and MPF vocabulary warmup.
+
+**Offline flattened tables (NDJSON normalization already done):** pass
+``--prebuilt-global-event-dir`` pointing at a directory containing the normalized
+FHIR tables (``patient.parquet``, ``encounter.parquet``, ``condition.parquet``,
+etc.). The example seeds ``flattened_tables/`` under the usual PyHealth cache UUID,
+then lets :class:`~pyhealth.datasets.BaseDataset` rebuild
+``global_event_df.parquet/`` from those tables — the downstream path is still
+``global_event_df`` → :class:`~pyhealth.data.Patient` →
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` →
+:class:`~pyhealth.trainer.Trainer`. Use ``--fhir-root`` / ``--glob-pattern`` /
+``--max-patients -1`` matching the ingest fingerprint.
+``--train-patient-cap`` restricts task transforms via ``task.pre_filter`` using a
+label-aware deterministic patient subset. The full ``unique_patient_ids`` scan
+and MPF vocab warmup in the dataset still walk the cached cohort.
+
+**Approximate minimum specs** (``--quick-test``, CPU, synthetic 2-patient
+fixture; measured once on macOS/arm64 with ``/usr/bin/time -l``): peak RSS
+~**600–700 MiB**, wall **~10–15 s** for two short epochs. Real NDJSON/GZ at scale
+needs proportionally more RAM, disk, and time; GPU accelerates training, not the
+streaming NDJSON ingest, which is CPU- and disk-bound.
+
+Usage:
+    cd PyHealth && PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test
+    PYTHONPATH=. 
python examples/mimic4fhir_mpf_ehrmamba.py --quick-test --ablation + export MIMIC4_FHIR_ROOT=/path/to/fhir + pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py --fhir-root "$MIMIC4_FHIR_ROOT" + + # Prebuilt flattened FHIR tables (skip NDJSON normalization); cap patients for a smoke train + pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py \\ + --prebuilt-global-event-dir /path/to/flattened_table_dir \\ + --fhir-root /same/as/ndjson/ingest/root --glob-pattern 'Mimic*.ndjson.gz' --max-patients -1 \\ + --train-patient-cap 2048 --epochs 2 \\ + --ntfy-url 'https://ntfy.sh/your-topic' +""" + +from __future__ import annotations + +import argparse +import os +import random +import re +import shutil +import sys +import tempfile +import time +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any, Dict, List, Optional + +_parser = argparse.ArgumentParser(description="EHRMambaCEHR on MIMIC-IV FHIR (MPF)") +_parser.add_argument( + "--gpu", + type=int, + default=None, + help="GPU index; sets CUDA_VISIBLE_DEVICES.", +) +_parser.add_argument( + "--fhir-root", + type=str, + default=None, + help="Root directory with NDJSON (default: MIMIC4_FHIR_ROOT env).", +) +_parser.add_argument( + "--glob-pattern", + type=str, + default=None, + help=( + "Override glob for NDJSON/NDJSON.GZ (default: yaml **/*.ndjson.gz). " + "Use a narrow pattern to limit ingest time and cache size." + ), +) +_parser.add_argument( + "--max-len", + type=int, + default=512, + help="Sequence length ablation (e.g. 
512 / 1024 / 2048 per proposal).", +) +_parser.add_argument( + "--no-mpf", + action="store_true", + help="Ablation: use generic CLS/REG specials instead of task MPF tokens.", +) +_parser.add_argument( + "--hidden-dim", + type=int, + default=128, + help="Embedding / hidden size (model ablation).", +) +_parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Adam learning rate (trainer.train optimizer_params).", +) +_parser.add_argument( + "--quick-test", + action="store_true", + help="Use synthetic in-memory FHIR lines only (no disk root).", +) +_parser.add_argument( + "--ablation", + action="store_true", + help="Run a small max_len × MPF × hidden_dim grid on synthetic data; print table.", +) +_parser.add_argument( + "--epochs", + type=int, + default=None, + help="Training epochs (default: 2 with --quick-test, else 20).", +) +_parser.add_argument( + "--max-patients", + type=int, + default=500, + help=( + "Fingerprint for cache dir: cap patients during normalization (-1 = full cohort, " + "match an uncapped NDJSON→flattened-table export)." + ), +) +_parser.add_argument( + "--prebuilt-global-event-dir", + type=str, + default=None, + help=( + "Directory with normalized flattened FHIR tables (*.parquet). Seeds " + "cache/flattened_tables/ so training skips NDJSON normalization " + "(downstream unchanged: Patient + MPF + Trainer)." + ), +) +_parser.add_argument( + "--ingest-num-shards", + type=int, + default=None, + help="Compatibility no-op: retained for CLI stability with older runs.", +) +_parser.add_argument( + "--train-patient-cap", + type=int, + default=None, + help=( + "After cache is ready, only build samples from a deterministic label-aware " + "patient subset of size N (reduces train time; unique-id scan of " + "global_event_df still runs once)." + ), +) +_parser.add_argument( + "--ntfy-url", + type=str, + default=None, + help="POST notification when main() finishes (e.g. 
https://ntfy.sh/topic).", +) +_parser.add_argument( + "--loss-plot-path", + type=str, + default=None, + help="Write loss curve PNG here (default: alongside Trainer log under output/).", +) +_parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="PyHealth dataset cache parent (UUID subdir added by MIMIC4FHIRDataset).", +) +_parser.add_argument( + "--task-num-workers", + type=int, + default=None, + help=( + "Workers for LitData task/processor transforms (default: dataset " + "``num_workers``, usually 1)." + ), +) +_pre_args, _ = _parser.parse_known_args() +if _pre_args.gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(_pre_args.gpu) + +import torch + +import polars as pl + +from pyhealth.datasets import MIMIC4FHIRDataset, get_dataloader +from pyhealth.processors.cehr_processor import infer_mortality_label +from pyhealth.models import EHRMambaCEHR +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask +from pyhealth.trainer import Trainer + + +class PatientCappedMPFTask(MPFClinicalPredictionTask): + """Example-only: limit task transform to an explicit patient_id allow-list.""" + + def __init__( + self, + *, + max_len: int, + use_mpf: bool, + patient_ids_allow: List[str], + ) -> None: + super().__init__(max_len=max_len, use_mpf=use_mpf) + self.patient_ids_allow = patient_ids_allow + + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + return df.filter(pl.col("patient_id").is_in(self.patient_ids_allow)) + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +SEED = 42 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +BATCH_SIZE = 8 +EPOCHS = 20 +SPLIT_RATIOS = (0.7, 0.1, 0.2) + + +def _max_patients_arg(v: int) -> Optional[int]: + return None if v is not None and v < 0 else v + + +def _seed_flattened_table_cache(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None: + """Copy normalized per-resource parquet tables into the dataset cache.""" + + tables = sorted(prebuilt_dir.glob("*.parquet")) + if not 
tables: + raise FileNotFoundError( + f"No *.parquet tables under {prebuilt_dir} — expected flattened FHIR tables." + ) + prepared = ds.prepared_tables_dir + if prepared.exists() and any(prepared.glob("*.parquet")): + return + prepared.mkdir(parents=True, exist_ok=True) + for src in tables: + dest = prepared / src.name + if dest.exists(): + continue + try: + os.link(src, dest) + except OSError: + shutil.copy2(src, dest) + + +def _parse_train_losses_from_log(log_path: Path) -> List[float]: + """Mean training loss per epoch from Trainer file log.""" + + if not log_path.is_file(): + return [] + text = log_path.read_text(encoding="utf-8", errors="replace") + losses: List[float] = [] + lines = text.splitlines() + for i, line in enumerate(lines): + if "--- Train epoch-" in line and i + 1 < len(lines): + m = re.search(r"loss:\s*([0-9.eE+-]+)", lines[i + 1]) + if m: + losses.append(float(m.group(1))) + return losses + + +def _write_loss_plot(losses: List[float], out_path: Path) -> None: + if not losses: + return + out_path.parent.mkdir(parents=True, exist_ok=True) + try: + import matplotlib.pyplot as plt + except ImportError: + csv_path = out_path.with_suffix(".csv") + csv_path.write_text( + "epoch,train_loss_mean\n" + + "\n".join(f"{i},{v}" for i, v in enumerate(losses)), + encoding="utf-8", + ) + print( + "matplotlib not installed; wrote", csv_path, "(pip install matplotlib for PNG)" + ) + return + plt.figure(figsize=(6, 3.5)) + plt.plot(range(len(losses)), losses, marker="o", linewidth=1) + plt.xlabel("epoch") + plt.ylabel("mean train loss") + plt.title("EHRMambaCEHR training loss (MPF)") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(out_path, dpi=120) + plt.close() + print("loss plot:", out_path) + + +def _ntfy(url: str, title: str, message: str) -> None: + try: + req = urllib.request.Request( + url, + data=message.encode("utf-8"), + method="POST", + ) + req.add_header("Title", title[:200]) + with urllib.request.urlopen(req, timeout=60) as resp: + if 
resp.status >= 400: + print("ntfy HTTP", resp.status, file=sys.stderr) + except urllib.error.URLError as e: + print("ntfy failed:", e, file=sys.stderr) + + +def _quick_test_ndjson_dir() -> str: + """Write two-patient synthetic NDJSON; returns temp directory (caller cleans up).""" + import orjson + + _resources = [ + {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"}, + { + "resourceType": "Encounter", "id": "e1", + "subject": {"reference": "Patient/p-synth-1"}, + "period": {"start": "2020-06-01T10:00:00Z"}, "class": {"code": "IMP"}, + }, + { + "resourceType": "Condition", "id": "c1", + "subject": {"reference": "Patient/p-synth-1"}, + "encounter": {"reference": "Encounter/e1"}, + "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]}, + "onsetDateTime": "2020-06-01T11:00:00Z", + }, + {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True}, + { + "resourceType": "Encounter", "id": "e-dead", + "subject": {"reference": "Patient/p-synth-2"}, + "period": {"start": "2020-07-01T10:00:00Z"}, "class": {"code": "IMP"}, + }, + { + "resourceType": "Observation", "id": "o-dead", + "subject": {"reference": "Patient/p-synth-2"}, + "encounter": {"reference": "Encounter/e-dead"}, + "effectiveDateTime": "2020-07-01T12:00:00Z", + "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]}, + }, + ] + ndjson_text = "\n".join(orjson.dumps(r).decode("utf-8") for r in _resources) + "\n" + + tmp = tempfile.mkdtemp(prefix="pyhealth_mimic4_fhir_quick_") + Path(tmp, "fixture.ndjson").write_text(ndjson_text, encoding="utf-8") + return tmp + + +def _patient_label(ds: MIMIC4FHIRDataset, patient_id: str) -> int: + patient = ds.get_patient(patient_id) + return int(infer_mortality_label(patient)) + + +def _ensure_binary_label_coverage(ds: MIMIC4FHIRDataset) -> None: + found: Dict[int, str] = {} + scanned = 0 + for patient_id in ds.unique_patient_ids: + label = 
_patient_label(ds, patient_id) + scanned += 1 + found.setdefault(label, patient_id) + if len(found) == 2: + print( + "label preflight:", + {"scanned_patients": scanned, "example_patient_ids": found}, + ) + return + raise SystemExit( + "Binary mortality example found only one label in the available cohort; " + "cannot build a valid binary training set from this cache." + ) + + +def _select_patient_ids_for_cap( + ds: MIMIC4FHIRDataset, requested_cap: int +) -> List[str]: + patient_ids = ds.unique_patient_ids + if not patient_ids: + return [] + + desired = max(2, requested_cap) + desired = min(desired, len(patient_ids)) + if desired < requested_cap: + print( + f"train_patient_cap requested {requested_cap}, but only {desired} patients are available." + ) + elif requested_cap < 2: + print( + f"train_patient_cap={requested_cap} is too small for binary labels; using {desired}." + ) + + encountered: List[str] = [] + label_by_patient_id: Dict[str, int] = {} + first_by_label: Dict[int, str] = {} + for patient_id in patient_ids: + label = _patient_label(ds, patient_id) + encountered.append(patient_id) + label_by_patient_id[patient_id] = label + first_by_label.setdefault(label, patient_id) + if len(encountered) >= desired and len(first_by_label) == 2: + break + + if len(first_by_label) < 2: + raise SystemExit( + "Unable to satisfy --train-patient-cap with both binary labels from the " + "available cohort. Use a different cache/export or remove the cap." 
+ ) + + selected = encountered[:desired] + selected_labels = {label_by_patient_id[pid] for pid in selected} + if len(selected_labels) == 1: + missing_label = 1 - next(iter(selected_labels)) + replacement = first_by_label[missing_label] + for idx in range(len(selected) - 1, -1, -1): + if label_by_patient_id[selected[idx]] != missing_label: + selected[idx] = replacement + break + + counts = { + 0: sum(1 for pid in selected if label_by_patient_id[pid] == 0), + 1: sum(1 for pid in selected if label_by_patient_id[pid] == 1), + } + print( + "train_patient_cap selection:", + { + "requested": requested_cap, + "selected": len(selected), + "scanned_patients": len(encountered), + "label_counts": counts, + }, + ) + return selected + + +def _sample_label(sample: Dict[str, Any]) -> int: + label = sample["label"] + if isinstance(label, torch.Tensor): + return int(label.reshape(-1)[0].item()) + return int(label) + + +def _split_counts(n: int) -> List[int]: + if n < 3: + raise ValueError("Need at least 3 samples for three-way stratified split.") + counts = [1, 1, 1] + remaining = n - 3 + raw = [ratio * remaining for ratio in SPLIT_RATIOS] + floors = [int(x) for x in raw] + for i, floor in enumerate(floors): + counts[i] += floor + assigned = sum(counts) + order = sorted( + range(3), + key=lambda i: raw[i] - floors[i], + reverse=True, + ) + for i in order: + if assigned >= n: + break + counts[i] += 1 + assigned += 1 + counts[0] += n - assigned + return counts + + +def _split_sample_dataset_for_binary_metrics(sample_ds: Any) -> tuple[Any, Any, Any]: + if len(sample_ds) < 8: + print("sample count < 8; reusing the full dataset for train/val/test.") + return sample_ds, sample_ds, sample_ds + + label_to_indices: Dict[int, List[int]] = {0: [], 1: []} + for idx in range(len(sample_ds)): + label_to_indices[_sample_label(sample_ds[idx])].append(idx) + + label_counts = {label: len(indices) for label, indices in label_to_indices.items()} + min_count = min(label_counts.values()) + if min_count < 
3: + print( + "label distribution too small for disjoint binary train/val/test splits; " + "reusing the full dataset for train/val/test.", + label_counts, + ) + return sample_ds, sample_ds, sample_ds + + rng = random.Random(SEED) + split_indices: List[List[int]] = [[], [], []] + for indices in label_to_indices.values(): + shuffled = indices[:] + rng.shuffle(shuffled) + n_train, n_val, n_test = _split_counts(len(shuffled)) + split_indices[0].extend(shuffled[:n_train]) + split_indices[1].extend(shuffled[n_train : n_train + n_val]) + split_indices[2].extend(shuffled[n_train + n_val : n_train + n_val + n_test]) + + for indices in split_indices: + indices.sort() + + split_counts = [] + for indices in split_indices: + split_counts.append( + { + 0: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 0), + 1: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 1), + "n": len(indices), + } + ) + print( + "binary stratified split counts:", + {"train": split_counts[0], "val": split_counts[1], "test": split_counts[2]}, + ) + return ( + sample_ds.subset(split_indices[0]), + sample_ds.subset(split_indices[1]), + sample_ds.subset(split_indices[2]), + ) + + +def _build_loaders_from_sample_dataset( + sample_ds: Any, + vocab_size: int, +) -> tuple[Any, Any, Any, Any, int]: + train_ds, val_ds, test_ds = _split_sample_dataset_for_binary_metrics(sample_ds) + train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False) + return sample_ds, train_loader, val_loader, test_loader, vocab_size + + +def run_single_train( + *, + fhir_root: str, + max_len: int, + use_mpf: bool, + hidden_dim: int, + epochs: int, + lr: float = 1e-3, + glob_pattern: str = "*.ndjson", + cache_dir: Optional[str] = None, + dataset_max_patients: Optional[int] = 500, + ingest_num_shards: Optional[int] = None, + 
prebuilt_global_event_dir: Optional[str] = None, + train_patient_cap: Optional[int] = None, +) -> Dict[str, float]: + """Train/eval one configuration; returns test metrics (floats).""" + + ds_kw: Dict[str, Any] = { + "root": fhir_root, + "glob_pattern": glob_pattern, + "cache_dir": cache_dir, + "max_patients": dataset_max_patients, + } + if ingest_num_shards is not None: + ds_kw["ingest_num_shards"] = ingest_num_shards + ds = MIMIC4FHIRDataset(**ds_kw) + if prebuilt_global_event_dir: + _seed_flattened_table_cache( + Path(prebuilt_global_event_dir).expanduser().resolve(), ds + ) + if train_patient_cap is not None: + allow = _select_patient_ids_for_cap(ds, train_patient_cap) + task: MPFClinicalPredictionTask = PatientCappedMPFTask( + max_len=max_len, + use_mpf=use_mpf, + patient_ids_allow=allow, + ) + else: + _ensure_binary_label_coverage(ds) + task = MPFClinicalPredictionTask(max_len=max_len, use_mpf=use_mpf) + sample_ds = ds.set_task(task, num_workers=1) + vocab_size = ds.vocab.vocab_size + sample_ds, train_l, val_l, test_l, vocab_size = _build_loaders_from_sample_dataset( + sample_ds, vocab_size + ) + model = EHRMambaCEHR( + dataset=sample_ds, + vocab_size=vocab_size, + embedding_dim=hidden_dim, + num_layers=2, + dropout=0.1, + ) + trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE) + trainer.train( + train_dataloader=train_l, + val_dataloader=val_l, + epochs=epochs, + monitor="roc_auc", + optimizer_params={"lr": lr}, + ) + results = trainer.evaluate(test_l) + return {k: float(v) for k, v in results.items()} + + +def run_ablation_table(*, lr: float = 1e-3) -> None: + """Task × model grid on synthetic NDJSON (short runs for comparison).""" + + # Ablations: context length, MPF vs CLS/REG, plus one hidden_dim pair. 
+    grid = [
+        (32, True, 32),
+        (32, False, 32),
+        (96, True, 64),
+        (96, False, 64),
+    ]
+    tmp = _quick_test_ndjson_dir()
+    try:
+        print(
+            "Ablation (synthetic, 1 epoch each): max_len, use_mpf, hidden_dim, lr="
+            f"{lr} -> test roc_auc, pr_auc"
+        )
+        rows = []
+        t0 = time.perf_counter()
+        for max_len, use_mpf, hidden_dim in grid:
+            metrics = run_single_train(
+                fhir_root=tmp,
+                max_len=max_len,
+                use_mpf=use_mpf,
+                hidden_dim=hidden_dim,
+                epochs=1,
+                lr=lr,
+                cache_dir=tmp,
+                dataset_max_patients=500,
+            )
+            rows.append((max_len, use_mpf, hidden_dim, metrics))
+            print(
+                f"  max_len={max_len} mpf={use_mpf} hid={hidden_dim} -> "
+                f"roc_auc={metrics['roc_auc']:.4f} pr_auc={metrics['pr_auc']:.4f}"
+            )
+        print("ablation_wall_s:", round(time.perf_counter() - t0, 2))
+        best = max(rows, key=lambda r: r[3]["roc_auc"])
+        print(
+            "best_by_roc_auc:",
+            {
+                "max_len": best[0],
+                "use_mpf": best[1],
+                "hidden_dim": best[2],
+                "metrics": best[3],
+            },
+        )
+    except Exception:
+        print(f"ablation: leaving scratch directory for debugging: {tmp}", file=sys.stderr)
+        raise
+    else:
+        shutil.rmtree(tmp, ignore_errors=True)
+
+
+def main() -> None:
+    args = _parser.parse_args()
+    status = "abort"
+    ntfy_detail = ""
+    try:
+        _main_train(args)
+        status = "ok"
+        ntfy_detail = "Training finished successfully." 
+ except SystemExit as e: + status = "exit" + ntfy_detail = f"SystemExit {e.code!r}" + raise + except Exception as e: + status = "error" + ntfy_detail = f"{type(e).__name__}: {e}"[:3800] + raise + finally: + if args.ntfy_url and status in ("ok", "error"): + _ntfy( + args.ntfy_url, + "mimic-fhir-train OK" if status == "ok" else "mimic-fhir-train FAIL", + ntfy_detail, + ) + + +def _main_train(args: argparse.Namespace) -> None: + fhir_root = args.fhir_root or os.environ.get("MIMIC4_FHIR_ROOT") + quick = args.quick_test + quick_test_tmp: Optional[str] = None + if args.epochs is not None: + epochs = args.epochs + else: + epochs = 2 if quick else EPOCHS + + if args.ablation: + if not quick: + raise SystemExit("--ablation requires --quick-test (synthetic data only).") + run_ablation_table(lr=args.lr) + return + + print("EHRMambaCEHR – MIMIC-IV FHIR (MPF clinical prediction)") + print("device:", DEVICE) + print("max_len:", args.max_len, "| use_mpf:", not args.no_mpf) + print("hidden_dim:", args.hidden_dim, "| lr:", args.lr) + + sample_ds: Any + vocab: Any + + if quick: + quick_test_tmp = _quick_test_ndjson_dir() + ds = MIMIC4FHIRDataset( + root=quick_test_tmp, + glob_pattern="*.ndjson", + cache_dir=quick_test_tmp, + max_patients=500, + ) + try: + print( + "pipeline: synthetic NDJSON → flattened tables → global_event_df " + "→ set_task → SampleDataset → Trainer" + ) + task = MPFClinicalPredictionTask( + max_len=args.max_len, + use_mpf=not args.no_mpf, + ) + print("set_task (quick-test, num_workers=1)...") + t_task0 = time.perf_counter() + sample_ds = ds.set_task(task, num_workers=1) + print( + "set_task done: n_samples=", + len(sample_ds), + "wall_s=", + round(time.perf_counter() - t_task0, 2), + ) + vocab = ds.vocab + except Exception: + print( + f"quick-test: leaving NDJSON/Parquet scratch at {quick_test_tmp}", + file=sys.stderr, + ) + raise + else: + mp = _max_patients_arg(args.max_patients) + if not fhir_root or not os.path.isdir(fhir_root): + raise SystemExit( + "Set 
MIMIC4_FHIR_ROOT or pass --fhir-root to an existing directory " + "(NDJSON tree for ingest fingerprint, even when using --prebuilt-global-event-dir)." + ) + ds_kw: Dict[str, Any] = { + "root": fhir_root, + "max_patients": mp, + "cache_dir": args.cache_dir, + } + if args.glob_pattern is not None: + ds_kw["glob_pattern"] = args.glob_pattern + if args.ingest_num_shards is not None: + ds_kw["ingest_num_shards"] = args.ingest_num_shards + ds = MIMIC4FHIRDataset(**ds_kw) + if args.prebuilt_global_event_dir: + pb = Path(args.prebuilt_global_event_dir).expanduser().resolve() + if not pb.is_dir(): + raise SystemExit(f"--prebuilt-global-event-dir not a directory: {pb}") + print( + "pipeline: offline flattened FHIR tables → seed flattened table cache " + "→ global_event_df → set_task → SampleDataset → Trainer " + "(no NDJSON normalization)" + ) + _seed_flattened_table_cache(pb, ds) + else: + print( + "pipeline: NDJSON root → MIMIC4FHIRDataset flattening → global_event_df " + "→ set_task → SampleDataset → Trainer" + ) + print("glob_pattern:", ds.glob_pattern, "| max_patients fingerprint:", mp) + if args.train_patient_cap is not None: + print("train_patient_cap:", args.train_patient_cap) + allow = _select_patient_ids_for_cap(ds, args.train_patient_cap) + mpf_task: MPFClinicalPredictionTask = PatientCappedMPFTask( + max_len=args.max_len, + use_mpf=not args.no_mpf, + patient_ids_allow=allow, + ) + print("task patient allow-list size:", len(allow)) + else: + _ensure_binary_label_coverage(ds) + mpf_task = MPFClinicalPredictionTask( + max_len=args.max_len, + use_mpf=not args.no_mpf, + ) + nw = args.task_num_workers + if nw is None: + nw = ds.num_workers + print(f"set_task (LitData task cache, num_workers={nw})...") + t_task0 = time.perf_counter() + sample_ds = ds.set_task(mpf_task, num_workers=nw) + print( + "set_task done: n_samples=", + len(sample_ds), + "wall_s=", + round(time.perf_counter() - t_task0, 2), + ) + vocab = ds.vocab + print("fhir_root:", fhir_root) + + try: + if 
len(sample_ds) == 0: + raise SystemExit( + "No training samples (0 patients or empty sequences). " + "PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (see default glob_patterns in " + "pyhealth/datasets/configs/mimic4_fhir.yaml). If your tree is plain *.ndjson, " + "construct MIMIC4FHIRDataset with glob_pattern='**/*.ndjson'." + ) + + sample_ds, train_loader, val_loader, test_loader, vocab_size = ( + _build_loaders_from_sample_dataset(sample_ds, vocab.vocab_size) + ) + + model = EHRMambaCEHR( + dataset=sample_ds, + vocab_size=vocab_size, + embedding_dim=args.hidden_dim, + num_layers=2, + dropout=0.1, + ) + trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE) + + t0 = time.perf_counter() + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + monitor="roc_auc", + optimizer_params={"lr": args.lr}, + ) + results = trainer.evaluate(test_loader) + print("Test:", {k: float(v) for k, v in results.items()}) + print("wall_s:", round(time.perf_counter() - t0, 1)) + print("concept_vocab_size:", vocab.vocab_size) + + log_txt = ( + Path(trainer.exp_path) / "log.txt" if trainer.exp_path else None + ) + if log_txt and log_txt.is_file(): + losses = _parse_train_losses_from_log(log_txt) + print("train_loss_per_epoch:", losses) + plot_path = ( + Path(args.loss_plot_path) + if args.loss_plot_path + else Path(trainer.exp_path) / "train_loss.png" + ) + if trainer.exp_path: + _write_loss_plot(losses, plot_path) + finally: + if quick_test_tmp is not None: + shutil.rmtree(quick_test_tmp, ignore_errors=True) + + +if __name__ == "__main__": + main() diff --git a/pixi.lock b/pixi.lock index 42762b62b..51e563c4d 100644 --- a/pixi.lock +++ b/pixi.lock @@ -96,6 +96,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -217,6 +218,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl @@ -328,6 +330,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl @@ -440,6 +443,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl @@ -597,6 +601,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -727,6 +732,7 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -848,6 +854,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: 
https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -970,6 +977,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1158,6 +1166,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1324,6 +1333,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1474,6 +1484,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1626,6 +1637,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: 
https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1777,6 +1789,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -1906,6 +1919,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -2026,6 +2040,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -2147,6 +2162,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl @@ -2289,6 +2305,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -2415,6 +2432,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl @@ -2531,6 +2549,7 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl @@ -2648,6 +2667,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl @@ -2792,6 +2812,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -2913,6 +2934,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl @@ -3024,6 +3046,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl @@ -3136,6 
+3159,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl @@ -6152,6 +6176,26 @@ packages: purls: [] size: 9327033 timestamp: 1751392489008 +- pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl + name: orjson + version: 3.11.7 + sha256: b9f95dcdea9d4f805daa9ddf02617a89e484c6985fa03055459f90e87d7a0757 + requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl + name: orjson + version: 3.11.7 + sha256: 814be4b49b228cfc0b3c565acf642dd7d13538f966e3ccde61f4f55be3e20785 + requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl + 
name: orjson + version: 3.11.7 + sha256: 1d98b30cc1313d52d4af17d9c3d307b08389752ec5f2e5febdfada70b0f8c733 + requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + name: orjson + version: 3.11.7 + sha256: a12b80df61aab7b98b490fe9e4879925ba666fccdfcd175252ce4d9035865ace + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl name: outdated version: 0.2.2 @@ -7082,7 +7126,7 @@ packages: - pypi: ./ name: pyhealth version: 2.0.1 - sha256: bf368461a8e66f93ad43f5880295045cbabe6688064af9a650a92bdaf1665332 + sha256: 941556c467dc4bb2cbe43e6b755c9b30c080151d9506bcfe584315280889b50b requires_dist: - torch~=2.7.1 - torchvision @@ -7107,6 +7151,7 @@ packages: - more-itertools~=10.8.0 - einops>=0.8.0 - linear-attention-transformer>=0.19.1 + - orjson~=3.10 - torch-geometric>=2.6.0 ; extra == 'graph' - editdistance~=0.8.1 ; extra == 'nlp' - rouge-score~=0.1.2 ; extra == 'nlp' diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..c41285d0a 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -59,6 +59,7 @@ def __init__(self, *args, **kwargs): from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset +from .mimic4_fhir import MIMIC4FHIRDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset diff --git a/pyhealth/datasets/configs/mimic4_fhir.yaml b/pyhealth/datasets/configs/mimic4_fhir.yaml new file mode 100644 index 000000000..989ab79cf --- /dev/null +++ b/pyhealth/datasets/configs/mimic4_fhir.yaml @@ 
-0,0 +1,118 @@ +# MIMIC-IV FHIR Resource Flattening Configuration +# ================================================ +# +# This YAML defines the normalized schema for MIMIC-IV FHIR exports after streaming +# ingestion. Raw NDJSON/NDJSON.GZ resources are parsed and flattened into six +# per-resource-type Parquet tables (see the ``tables`` section below), then loaded through the +# standard BaseDataset pipeline for task construction. +# +# For ingest details, see the docstring of ``stream_fhir_ndjson_to_flat_tables()`` +# in ``pyhealth/datasets/fhir_utils.py``. + +version: "fhir_r4_flattened" + +# Glob Pattern(s) for NDJSON File Discovery +# =========================================== +# +# ``glob_patterns`` (list) or ``glob_pattern`` (string): Patterns to match NDJSON files +# under the ingest root directory. Patterns are applied via pathlib.Path.glob(). +# +# Default: Six targeted patterns matching PhysioNet MIMIC-IV FHIR Mimic* shard families +# that map to flattened tables. This avoids decompressing and parsing ~10% of PhysioNet +# exports (MedicationAdministration, Specimen, Organization, …) that are skipped by +# the flattener. +# +# Alternatives: +# - For non-PhysioNet naming, use a single broad pattern: +# glob_pattern: "**/*.ndjson.gz" +# - To test on a subset, use a narrower list: +# glob_patterns: +# - "**/MimicPatient*.ndjson.gz" +# - "**/MimicObservation*.ndjson.gz" +# +# Notes: +# - Patterns use ``**/`` for recursive search (works in both flat and nested layouts). +# - Can be overridden at runtime via MIMIC4FHIRDataset(glob_pattern=...) or +# MIMIC4FHIRDataset(glob_patterns=[...]). + +glob_patterns: + - "**/MimicPatient*.ndjson.gz" + - "**/MimicEncounter*.ndjson.gz" + - "**/MimicCondition*.ndjson.gz" + - "**/MimicObservation*.ndjson.gz" + - "**/MimicMedicationRequest*.ndjson.gz" + - "**/MimicProcedure*.ndjson.gz" + +# Flattened Table Schema +# ====================== +# +# Each table is normalized from a single FHIR resource type.
Columns are: +# - patient_id (str): Foreign key to patient (derived from subject.reference or id). +# - [timestamp] (str): ISO 8601 datetime string (coerced; nullable). +# - attributes (List[str]): Additional columns from the resource. +# +# Unsupported resource types (Medication, MedicationAdministration, Specimen, …) +# are silently dropped during ingest; only tables listed here are written. + +tables: + patient: + file_path: "patient.parquet" + patient_id: "patient_id" + timestamp: "birth_date" + attributes: + - "patient_fhir_id" + - "birth_date" + - "gender" + - "deceased_boolean" + - "deceased_datetime" + + encounter: + file_path: "encounter.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "encounter_class" + - "encounter_end" + + condition: + file_path: "condition.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + observation: + file_path: "observation.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + medication_request: + file_path: "medication_request.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + procedure: + file_path: "procedure.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" diff --git a/pyhealth/datasets/fhir_utils.py b/pyhealth/datasets/fhir_utils.py new file mode 100644 index 000000000..a80c6f1fe --- /dev/null +++ b/pyhealth/datasets/fhir_utils.py @@ -0,0 +1,438 @@ +"""FHIR NDJSON parsing, flattening, and Parquet table writing. 
+ +Key public API +-------------- +stream_fhir_ndjson_to_flat_tables(root, glob_pattern, out_dir) + Stream all matching NDJSON/NDJSON.GZ resources into six per-type Parquet tables. + +sorted_ndjson_files(root, glob_pattern) + List matching NDJSON files under root (deduplicated, sorted). + +filter_flat_tables_by_patient_ids(source_dir, out_dir, keep_ids) + Subset existing flattened tables to a specific patient cohort. +""" + +from __future__ import annotations + +import gzip +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple + +import orjson +import polars as pl +import pyarrow as pa +import pyarrow.parquet as pq + +GlobPatternArg = str | Sequence[str] +"""Single glob string or sequence of strings for NDJSON file discovery.""" + +__all__ = [ + # Types + "GlobPatternArg", + # Constants + "FHIR_SCHEMA_VERSION", + "FHIR_TABLES", + "FHIR_TABLES_FOR_PATIENT_IDS", + "FHIR_TABLE_FILE_NAMES", + "FHIR_TABLE_COLUMNS", + # Datetime helpers + "parse_dt", + "as_naive", + # FHIR iteration + "iter_ndjson_objects", + "iter_resources_from_ndjson_obj", + # Resource extraction + "patient_id_for_resource", + # Pipeline + "sorted_ndjson_files", + "stream_fhir_ndjson_to_flat_tables", + "filter_flat_tables_by_patient_ids", + "sorted_patient_ids_from_flat_tables", +] + +FHIR_SCHEMA_VERSION = 3 + +FHIR_TABLES: List[str] = [ + "patient", + "encounter", + "condition", + "observation", + "medication_request", + "procedure", +] + +FHIR_TABLES_FOR_PATIENT_IDS: List[str] = [t for t in FHIR_TABLES if t != "patient"] + +FHIR_TABLE_FILE_NAMES: Dict[str, str] = {t: f"{t}.parquet" for t in FHIR_TABLES} + +FHIR_TABLE_COLUMNS: Dict[str, List[str]] = { + "patient": ["patient_id", "patient_fhir_id", "birth_date", "gender", "deceased_boolean", "deceased_datetime"], + "encounter": ["patient_id", "resource_id", "encounter_id", "event_time", "encounter_class", "encounter_end"], + "condition": ["patient_id", "resource_id", 
"encounter_id", "event_time", "concept_key"], + "observation": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], + "medication_request": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], + "procedure": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"], +} + +# --------------------------------------------------------------------------- +# Datetime helpers (also imported by cehr_processor) +# --------------------------------------------------------------------------- + + +def parse_dt(s: Optional[str]) -> Optional[datetime]: + if not s: + return None + try: + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + except ValueError: + dt = None + if dt is None and len(s) >= 10: + try: + dt = datetime.strptime(s[:10], "%Y-%m-%d") + except ValueError: + return None + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +def as_naive(dt: Optional[datetime]) -> Optional[datetime]: + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +# --------------------------------------------------------------------------- +# FHIR JSON helpers +# --------------------------------------------------------------------------- + + +def _coding_key(coding: Dict[str, Any]) -> str: + return f"{coding.get('system') or 'unknown'}|{coding.get('code') or 'unknown'}" + + +def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]: + if not obj: + return None + codings = obj.get("coding") or [] + if not codings and "concept" in obj: + codings = (obj.get("concept") or {}).get("coding") or [] + return _coding_key(codings[0]) if codings else None + + +def _ref_id(ref: Optional[str]) -> Optional[str]: + if not ref: + return None + return ref.rsplit("/", 1)[-1] if "/" in ref else ref + + +def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]: + if not isinstance(raw, dict): + return None + resource = 
raw.get("resource") if "resource" in raw else raw + return resource if isinstance(resource, dict) else None + + +def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + """Yield resource dicts from one parsed NDJSON object (Bundle or bare resource).""" + if "entry" in obj: + for entry in obj.get("entry") or []: + resource = entry.get("resource") + if isinstance(resource, dict): + yield resource + return + resource = _unwrap_resource_dict(obj) + if resource is not None: + yield resource + + +def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]: + """Yield parsed JSON objects from a plain or gzip-compressed NDJSON file.""" + opener = ( + gzip.open(path, "rt", encoding="utf-8", errors="replace") + if path.suffix == ".gz" + else open(path, encoding="utf-8", errors="replace") + ) + with opener as stream: + for line in stream: + line = line.strip() + if not line: + continue + parsed = orjson.loads(line) + if isinstance(parsed, dict): + yield parsed + + +# --------------------------------------------------------------------------- +# Resource field extraction +# --------------------------------------------------------------------------- + + +def _clinical_concept_key(res: Dict[str, Any]) -> Optional[str]: + """Resolve a stable token key from a FHIR resource.""" + resource_type = res.get("resourceType") + if resource_type == "MedicationRequest": + med_cc = res.get("medicationCodeableConcept") + if isinstance(med_cc, dict): + key = _first_coding(med_cc) + if key: + return key + med_ref = res.get("medicationReference") + if isinstance(med_ref, dict): + ref = med_ref.get("reference") + if ref: + return f"MedicationRequest/reference|{_ref_id(ref) or ref}" + return None + code = res.get("code") + return _first_coding(code) if isinstance(code, dict) else None + + +def patient_id_for_resource( + resource: Dict[str, Any], + resource_type: Optional[str] = None, +) -> Optional[str]: + resource_type = resource_type or 
resource.get("resourceType") + if resource_type == "Patient": + pid = resource.get("id") + return str(pid) if pid is not None else None + if resource_type in {"Encounter", "Condition", "Observation", "MedicationRequest", "Procedure"}: + return _ref_id((resource.get("subject") or {}).get("reference")) + return None + + +def _resource_time_string( + resource: Dict[str, Any], + resource_type: Optional[str] = None, +) -> Optional[str]: + resource_type = resource_type or resource.get("resourceType") + if resource_type == "Patient": + return resource.get("birthDate") + if resource_type == "Encounter": + return (resource.get("period") or {}).get("start") + if resource_type == "Condition": + return resource.get("onsetDateTime") or resource.get("recordedDate") + if resource_type == "Observation": + return resource.get("effectiveDateTime") or resource.get("issued") + if resource_type == "MedicationRequest": + return resource.get("authoredOn") + if resource_type == "Procedure": + return resource.get("performedDateTime") or resource.get("recordedDate") + return None + + +# --------------------------------------------------------------------------- +# Flattening +# --------------------------------------------------------------------------- + + +def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]: + """Map Patient.deceasedBoolean to stored "true"/"false"/None. + + FHIR JSON uses real booleans; some exports use strings. Python's + bool("false") is True, so we must not coerce with bool(). 
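To make the contract concrete, a minimal standalone sketch (hypothetical helper name; mirrors the branches of this function, not the function itself):

```python
# Hypothetical sketch of the deceasedBoolean -> "true"/"false"/None mapping.
def normalize_deceased(value):
    if isinstance(value, bool):           # real FHIR JSON booleans
        return "true" if value else "false"
    if isinstance(value, str):            # string-typed exports
        key = value.strip().lower()
        if key in ("true", "1", "yes", "y", "t"):
            return "true"
        if key in ("false", "0", "no", "n", "f", ""):
            return "false"
        return None
    if isinstance(value, (int, float)):   # numeric 0/1 encodings
        return {0: "false", 1: "true"}.get(value)
    return None                           # None and anything unrecognized

assert normalize_deceased(False) == "false"   # bool("false") would be True!
assert normalize_deceased("False") == "false"
assert normalize_deceased(1) == "true"
assert normalize_deceased(None) is None
```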
+ """ + if value is None: + return None + if value is True: + return "true" + if value is False: + return "false" + if isinstance(value, str): + key = value.strip().lower() + if key in ("true", "1", "yes", "y", "t"): + return "true" + if key in ("false", "0", "no", "n", "f", ""): + return "false" + return None + if isinstance(value, (int, float)) and not isinstance(value, bool): + if value == 0: + return "false" + if value == 1: + return "true" + return None + return None + + +_RESOURCE_TYPE_TO_TABLE: Dict[str, str] = { + "Condition": "condition", + "Observation": "observation", + "MedicationRequest": "medication_request", + "Procedure": "procedure", +} + + +def _flatten_resource_to_table_row( + resource: Dict[str, Any], +) -> Optional[Tuple[str, Dict[str, Optional[str]]]]: + """Map one FHIR resource dict to (table_name, row_dict), or None if unsupported.""" + resource_type = resource.get("resourceType") + patient_id = patient_id_for_resource(resource, resource_type) + if not patient_id: + return None + + if resource_type == "Patient": + return "patient", { + "patient_id": patient_id, + "patient_fhir_id": str(resource.get("id") or patient_id), + "birth_date": resource.get("birthDate"), + "gender": resource.get("gender"), + "deceased_boolean": _normalize_deceased_boolean_for_storage(resource.get("deceasedBoolean")), + "deceased_datetime": resource.get("deceasedDateTime"), + } + + resource_id = str(resource.get("id")) if resource.get("id") is not None else None + event_time = _resource_time_string(resource, resource_type) + + if resource_type == "Encounter": + return "encounter", { + "patient_id": patient_id, + "resource_id": resource_id, + "encounter_id": resource_id, + "event_time": event_time, + "encounter_class": (resource.get("class") or {}).get("code"), + "encounter_end": (resource.get("period") or {}).get("end"), + } + + table_name = _RESOURCE_TYPE_TO_TABLE.get(resource_type) + if table_name is None: + return None + return table_name, { + "patient_id": 
patient_id, + "resource_id": resource_id, + "encounter_id": _ref_id((resource.get("encounter") or {}).get("reference")), + "event_time": event_time, + "concept_key": _clinical_concept_key(resource), + } + + +# --------------------------------------------------------------------------- +# Parquet writer +# --------------------------------------------------------------------------- + + +def _table_schema(table_name: str) -> pa.Schema: + return pa.schema([(col, pa.string()) for col in FHIR_TABLE_COLUMNS[table_name]]) + + +class _BufferedParquetWriter: + def __init__(self, path: Path, schema: pa.Schema, batch_size: int = 50_000) -> None: + self.path = path + self.schema = schema + self.batch_size = batch_size + self.rows: List[Dict[str, Any]] = [] + self.writer: Optional[pq.ParquetWriter] = None + self.path.parent.mkdir(parents=True, exist_ok=True) + + def add(self, row: Dict[str, Any]) -> None: + self.rows.append(row) + if len(self.rows) >= self.batch_size: + self.flush() + + def flush(self) -> None: + if not self.rows: + return + table = pa.Table.from_pylist(self.rows, schema=self.schema) + if self.writer is None: + self.writer = pq.ParquetWriter(str(self.path), self.schema) + self.writer.write_table(table) + self.rows.clear() + + def close(self) -> None: + self.flush() + if self.writer is None: + pq.write_table(pa.Table.from_pylist([], schema=self.schema), str(self.path)) + return + self.writer.close() + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]: + """Return sorted unique file paths under root matching glob pattern(s). + + Args: + root: Root directory to search under. + glob_pattern: Single glob string or sequence of glob strings. + + Returns: + Sorted list of matching files. Empty if no matches. 
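A throwaway-directory illustration of the dedup-and-sort behavior (a simplified re-sketch with a hypothetical ``discover`` helper, not this function):

```python
# Simplified sketch of multi-pattern, deduplicated, sorted file discovery.
import tempfile
from pathlib import Path

def discover(root, patterns):
    if isinstance(patterns, str):
        patterns = [patterns]
    found = set()                      # set dedupes overlapping patterns
    for pat in patterns:
        found.update(p for p in root.glob(pat) if p.is_file())
    return sorted(found, key=str)      # deterministic order

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "fhir").mkdir()
    for name in ("MimicPatient_1.ndjson.gz", "MimicCondition_1.ndjson.gz"):
        (root / "fhir" / name).touch()
    hits = discover(root, ["**/MimicPatient*.ndjson.gz",
                           "**/Mimic*.ndjson.gz"])  # overlapping globs
    names = [p.name for p in hits]

assert names == ["MimicCondition_1.ndjson.gz", "MimicPatient_1.ndjson.gz"]
```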
+ """ + patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern) + files: set[Path] = set() + for pat in patterns: + files.update(p for p in root.glob(pat) if p.is_file()) + return sorted(files, key=lambda p: str(p)) + + +def stream_fhir_ndjson_to_flat_tables( + root: Path, + glob_pattern: GlobPatternArg, + out_dir: Path, +) -> None: + """Stream NDJSON resources into normalized per-resource Parquet tables under out_dir. + + Args: + root: Root directory containing NDJSON/NDJSON.GZ files. + glob_pattern: Single glob string or sequence of glob strings. + out_dir: Output directory for per-resource-type Parquet tables. + Creates patient.parquet, encounter.parquet, condition.parquet, + observation.parquet, medication_request.parquet, procedure.parquet. + """ + out_dir.mkdir(parents=True, exist_ok=True) + writers = { + name: _BufferedParquetWriter(path=out_dir / FHIR_TABLE_FILE_NAMES[name], schema=_table_schema(name)) + for name in FHIR_TABLES + } + try: + for file_path in sorted_ndjson_files(root, glob_pattern): + for ndjson_obj in iter_ndjson_objects(file_path): + for resource in iter_resources_from_ndjson_obj(ndjson_obj): + result = _flatten_resource_to_table_row(resource) + if result is not None: + writers[result[0]].add(result[1]) + finally: + for writer in writers.values(): + writer.close() + + +def sorted_patient_ids_from_flat_tables(table_dir: Path) -> List[str]: + """Return sorted unique patient IDs from a directory of flattened Parquet tables.""" + patient_table = table_dir / FHIR_TABLE_FILE_NAMES["patient"] + if patient_table.exists(): + return ( + pl.scan_parquet(str(patient_table)) + .select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) + frames = [ + pl.scan_parquet(str(table_dir / FHIR_TABLE_FILE_NAMES[t])).select("patient_id") + for t in FHIR_TABLES_FOR_PATIENT_IDS + ] + return ( + pl.concat(frames) + .unique() + .sort("patient_id") + 
.collect(engine="streaming")["patient_id"] + .to_list() + ) + + +def filter_flat_tables_by_patient_ids( + source_dir: Path, + out_dir: Path, + keep_ids: Sequence[str], +) -> None: + """Filter all flattened tables to only include rows for the given patient IDs.""" + out_dir.mkdir(parents=True, exist_ok=True) + keep_set = set(keep_ids) + for name in FHIR_TABLES: + src = source_dir / FHIR_TABLE_FILE_NAMES[name] + dst = out_dir / FHIR_TABLE_FILE_NAMES[name] + pl.scan_parquet(str(src)).filter(pl.col("patient_id").is_in(keep_set)).sink_parquet(str(dst)) diff --git a/pyhealth/datasets/mimic4_fhir.py b/pyhealth/datasets/mimic4_fhir.py new file mode 100644 index 000000000..7bc53b878 --- /dev/null +++ b/pyhealth/datasets/mimic4_fhir.py @@ -0,0 +1,392 @@ +"""MIMIC-IV FHIR ingestion using flattened resource tables. + +Architecture +------------ +1. Stream NDJSON/NDJSON.GZ FHIR resources from disk. +2. Normalize each resource type into a 2D table (Patient, Encounter, + Condition, Observation, MedicationRequest, Procedure) via + :mod:`~pyhealth.datasets.fhir_utils`. +3. Feed those tables through the standard YAML-driven + :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task + processing operates on :class:`~pyhealth.data.Patient` and + ``global_event_df`` rows. 
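A hand-traced instance of steps 1-2 for a single Condition resource (the ``system|code`` key format and reference handling follow :mod:`~pyhealth.datasets.fhir_utils`; the literal values are made up):

```python
# One NDJSON line for a Condition, flattened to the row shape that the
# pipeline writes to condition.parquet.
import json

line = ('{"resourceType": "Condition", "id": "c1",'
        ' "subject": {"reference": "Patient/p42"},'
        ' "encounter": {"reference": "Encounter/e7"},'
        ' "onsetDateTime": "2180-03-02T10:00:00Z",'
        ' "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10",'
        ' "code": "I10"}]}}')
res = json.loads(line)

coding = res["code"]["coding"][0]
row = {
    "patient_id": res["subject"]["reference"].rsplit("/", 1)[-1],
    "encounter_id": res["encounter"]["reference"].rsplit("/", 1)[-1],
    "event_time": res["onsetDateTime"],
    "concept_key": f'{coding["system"]}|{coding["code"]}',
}
assert row == {
    "patient_id": "p42",
    "encounter_id": "e7",
    "event_time": "2180-03-02T10:00:00Z",
    "concept_key": "http://hl7.org/fhir/sid/icd-10|I10",
}
```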
+""" + +from __future__ import annotations + +import functools +import hashlib +import logging +import operator +import os +import shutil +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + +import dask.dataframe as dd +import narwhals as nw +import orjson +import pandas as pd +import platformdirs +from yaml import safe_load + +from .base_dataset import BaseDataset +from .fhir_utils import ( + FHIR_SCHEMA_VERSION, + FHIR_TABLE_FILE_NAMES, + FHIR_TABLES, + sorted_patient_ids_from_flat_tables, + filter_flat_tables_by_patient_ids, + stream_fhir_ndjson_to_flat_tables, +) + +logger = logging.getLogger(__name__) + + +def read_fhir_settings_yaml( + path: Optional[str] = None, +) -> Dict[str, Any]: + if path is None: + path = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_fhir.yaml" + ) + with open(path, encoding="utf-8") as stream: + data = safe_load(stream) + return data if isinstance(data, dict) else {} + + +def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series: + if getattr(part.dtype, "tz", None) is not None: + part = part.dt.tz_localize(None) + return part.astype("datetime64[ms]") + + +class MIMIC4FHIRDataset(BaseDataset): + """MIMIC-IV FHIR with flattened resource tables. + + Streams raw MIMIC-IV FHIR NDJSON/NDJSON.GZ exports into six + flattened Parquet tables then pipelines them through + :class:`~pyhealth.datasets.BaseDataset` for standard downstream + task processing (global event dataframe, patient iteration, task + sampling). + + Args: + root: Path to the NDJSON/NDJSON.GZ export directory. + config_path: Path to a custom YAML config. Defaults to + ``pyhealth/datasets/configs/mimic4_fhir.yaml``. + glob_pattern: Single glob for NDJSON files. Mutually + exclusive with *glob_patterns*. + glob_patterns: Multiple glob patterns. Mutually exclusive + with *glob_pattern*. + max_patients: Limit ingest to the first *N* unique patient + IDs. + ingest_num_shards: Ignored; retained for API compatibility. 
+ cache_dir: Cache directory root (UUID subdir appended per + config). + num_workers: Worker processes for task sampling. + dev: Development mode; limits to 1 000 patients if + *max_patients* is ``None``. + + Examples: + >>> ds = MIMIC4FHIRDataset( + ... root="/data/mimic-iv-fhir", + ... glob_pattern="**/*.ndjson.gz", + ... max_patients=500, + ... ) + >>> sample_ds = ds.set_task(task, num_workers=4) + """ + + def __init__( + self, + root: str, + config_path: Optional[str] = None, + glob_pattern: Optional[str] = None, + glob_patterns: Optional[Sequence[str]] = None, + max_patients: Optional[int] = None, + ingest_num_shards: Optional[int] = None, + cache_dir: Optional[str | Path] = None, + num_workers: int = 1, + dev: bool = False, + ) -> None: + del ingest_num_shards + + default_cfg = os.path.join( + os.path.dirname(__file__), "configs", "mimic4_fhir.yaml" + ) + self._fhir_config_path = str( + Path(config_path or default_cfg).resolve() + ) + self._fhir_settings = read_fhir_settings_yaml( + self._fhir_config_path + ) + + if glob_pattern is not None and glob_patterns is not None: + raise ValueError( + "Pass at most one of glob_pattern and glob_patterns." + ) + if glob_patterns is not None: + self.glob_patterns: List[str] = list(glob_patterns) + elif glob_pattern is not None: + self.glob_patterns = [glob_pattern] + else: + raw_list = self._fhir_settings.get("glob_patterns") + if raw_list: + if not isinstance(raw_list, list): + raise TypeError( + "mimic4_fhir.yaml glob_patterns must be a " + "list of strings." 
+ ) + self.glob_patterns = [str(x) for x in raw_list] + elif self._fhir_settings.get("glob_pattern") is not None: + self.glob_patterns = [ + str(self._fhir_settings["glob_pattern"]) + ] + else: + self.glob_patterns = ["**/*.ndjson.gz"] + + self.glob_pattern = ( + self.glob_patterns[0] + if len(self.glob_patterns) == 1 + else "; ".join(self.glob_patterns) + ) + self.max_patients = ( + 1000 if dev and max_patients is None else max_patients + ) + + resolved_root = str(Path(root).expanduser().resolve()) + super().__init__( + root=resolved_root, + tables=FHIR_TABLES, + dataset_name="mimic4_fhir", + config_path=self._fhir_config_path, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + # ------------------------------------------------------------------ + # Cache identity + # ------------------------------------------------------------------ + + def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: + try: + yaml_digest = hashlib.sha256( + Path(self._fhir_config_path).read_bytes() + ).hexdigest()[:16] + except OSError: + yaml_digest = "missing" + identity = orjson.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + "glob_patterns": self.glob_patterns, + "max_patients": self.max_patients, + "fhir_schema_version": FHIR_SCHEMA_VERSION, + "fhir_yaml_digest16": yaml_digest, + }, + option=orjson.OPT_SORT_KEYS, + ).decode("utf-8") + cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity)) + out = ( + Path(platformdirs.user_cache_dir(appname="pyhealth")) + / cache_id + if cache_dir is None + else Path(cache_dir) / cache_id + ) + out.mkdir(parents=True, exist_ok=True) + logger.info(f"Cache dir: {out}") + return out + + # ------------------------------------------------------------------ + # NDJSON → Parquet ingest + # ------------------------------------------------------------------ + + @property + def prepared_tables_dir(self) -> Path: + return self.cache_dir / "flattened_tables" + 
+ def _ensure_prepared_tables(self) -> None: + root = Path(self.root) + if not root.is_dir(): + raise FileNotFoundError( + f"MIMIC4 FHIR root not found: {root}" + ) + + expected = [ + self.prepared_tables_dir / FHIR_TABLE_FILE_NAMES[t] + for t in FHIR_TABLES + ] + if all(p.is_file() for p in expected): + return + if self.prepared_tables_dir.exists(): + shutil.rmtree(self.prepared_tables_dir) + + try: + staging_root = self.create_tmpdir() + staging = staging_root / "flattened_fhir_tables" + staging.mkdir(parents=True, exist_ok=True) + stream_fhir_ndjson_to_flat_tables( + root, self.glob_patterns, staging + ) + if self.max_patients is None: + shutil.move( + str(staging), str(self.prepared_tables_dir) + ) + return + + filtered_root = self.create_tmpdir() + filtered = filtered_root / "filtered" + pids = sorted_patient_ids_from_flat_tables(staging) + filter_flat_tables_by_patient_ids( + staging, filtered, pids[: self.max_patients] + ) + shutil.move( + str(filtered), str(self.prepared_tables_dir) + ) + finally: + self.clean_tmpdir() + + def _event_transform(self, output_dir: Path) -> None: + self._ensure_prepared_tables() + super()._event_transform(output_dir) + + # ------------------------------------------------------------------ + # Table loading (Parquet instead of CSV) + # ------------------------------------------------------------------ + + def load_table(self, table_name: str) -> dd.DataFrame: + """Load one flattened Parquet table into the standard event + schema. + + Deviations from ``BaseDataset.load_table`` (CSV via + ``_scan_csv_tsv_gz``): + + * Reads from pre-built Parquet under ``prepared_tables_dir``. + * Timestamp parsing uses ``errors="coerce"`` + ``utc=True`` + (FHIR ISO strings include timezone suffix or partial dates). + * Strips tz-aware timestamps to naive UTC for Dask compat. + * Drops rows with null ``patient_id`` before returning. 
+ """ + assert self.config is not None + if table_name not in self.config.tables: + raise ValueError( + f"Table {table_name} not found in config" + ) + + table_cfg = self.config.tables[table_name] + path = self.prepared_tables_dir / table_cfg.file_path + if not path.exists(): + raise FileNotFoundError( + f"Flattened table not found: {path}" + ) + + logger.info( + f"Scanning FHIR flattened table: {table_name} " + f"from {path}" + ) + df: dd.DataFrame = dd.read_parquet( + str(path), split_row_groups=True, blocksize="64MB" + ).replace("", pd.NA) + df = df.rename(columns=str.lower) + + preprocess_func = getattr( + self, f"preprocess_{table_name}", None + ) + if preprocess_func is not None: + logger.info( + f"Preprocessing FHIR table: {table_name} " + f"with {preprocess_func.__name__}" + ) + df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr] + + for join_cfg in table_cfg.join: + join_path = ( + self.prepared_tables_dir + / Path(join_cfg.file_path).name + ) + if not join_path.exists(): + raise FileNotFoundError( + f"FHIR join table not found: {join_path}" + ) + logger.info( + f"Joining FHIR table {table_name} with {join_path}" + ) + join_df: dd.DataFrame = dd.read_parquet( + str(join_path), + split_row_groups=True, + blocksize="64MB", + ).replace("", pd.NA) + join_df = join_df.rename(columns=str.lower) + join_key = join_cfg.on.lower() + cols = [c.lower() for c in join_cfg.columns] + df = df.merge( + join_df[[join_key] + cols], + on=join_key, + how=join_cfg.how, + ) + + ts_col = table_cfg.timestamp + if ts_col: + ts = ( + functools.reduce( + operator.add, + (df[c].astype("string") for c in ts_col), + ) + if isinstance(ts_col, list) + else df[ts_col].astype("string") + ) + ts = dd.to_datetime( + ts, + format=table_cfg.timestamp_format, + errors="coerce", + utc=True, + ) + df = df.assign( + timestamp=ts.map_partitions(_strip_tz_to_naive_ms) + ) + else: + df = df.assign(timestamp=pd.NaT) + + if table_cfg.patient_id: + df = df.assign( + 
patient_id=df[table_cfg.patient_id].astype("string") + ) + else: + df = df.reset_index(drop=True) + df = df.assign(patient_id=df.index.astype("string")) + + df = df.dropna(subset=["patient_id"]) + df = df.assign(event_type=table_name) + rename_attr = { + attr.lower(): f"{table_name}/{attr}" + for attr in table_cfg.attributes + } + df = df.rename(columns=rename_attr) + return df[ + ["patient_id", "event_type", "timestamp"] + + [rename_attr[a.lower()] for a in table_cfg.attributes] + ] + + # ------------------------------------------------------------------ + # Patient IDs (deterministic sorted order) + # ------------------------------------------------------------------ + + @property + def unique_patient_ids(self) -> List[str]: + if self._unique_patient_ids is None: + self._unique_patient_ids = ( + self.global_event_df.select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming") + .to_series() + .to_list() + ) + logger.info( + f"Found {len(self._unique_patient_ids)} " + f"unique patient IDs" + ) + return self._unique_patient_ids diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..86486dab8 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -38,6 +38,7 @@ from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .ehrmamba import EHRMamba, MambaBlock +from .ehrmamba_cehr import EHRMambaCEHR from .vae import VAE from .vision_embedding import VisionEmbeddingModel from .text_embedding import TextEmbedding diff --git a/pyhealth/models/cehr_embeddings.py b/pyhealth/models/cehr_embeddings.py new file mode 100644 index 000000000..7974a699e --- /dev/null +++ b/pyhealth/models/cehr_embeddings.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Vector Institute / Odyssey authors +# +# Derived from Odyssey (https://github.com/VectorInstitute/odyssey): +# odyssey/models/embeddings.py — MambaEmbeddingsForCEHR, 
TimeEmbeddingLayer, VisitEmbedding +# Modifications: removed HuggingFace MambaConfig dependency; explicit constructor args. + +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch import nn + + +class TimeEmbeddingLayer(nn.Module): + """Embedding layer for time features (sinusoidal).""" + + def __init__(self, embedding_size: int, is_time_delta: bool = False): + super().__init__() + self.embedding_size = embedding_size + self.is_time_delta = is_time_delta + self.w = nn.Parameter(torch.empty(1, self.embedding_size)) + self.phi = nn.Parameter(torch.empty(1, self.embedding_size)) + nn.init.xavier_uniform_(self.w) + nn.init.xavier_uniform_(self.phi) + + def forward(self, time_stamps: torch.Tensor) -> torch.Tensor: + if self.is_time_delta: + time_stamps = torch.cat( + (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]), + dim=-1, + ) + time_stamps = time_stamps.float() + next_input = time_stamps.unsqueeze(-1) * self.w + self.phi + return torch.sin(next_input) + + +class VisitEmbedding(nn.Module): + """Embedding layer for visit segments.""" + + def __init__(self, visit_order_size: int, embedding_size: int): + super().__init__() + self.embedding = nn.Embedding(visit_order_size, embedding_size) + + def forward(self, visit_segments: torch.Tensor) -> torch.Tensor: + return self.embedding(visit_segments) + + +class MambaEmbeddingsForCEHR(nn.Module): + """CEHR-style combined embeddings for Mamba (concept + type + time + age + visit).""" + + def __init__( + self, + vocab_size: int, + hidden_size: int, + pad_token_id: int = 0, + type_vocab_size: int = 9, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_order_size: int = 3, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.1, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.pad_token_id = pad_token_id + self.type_vocab_size = type_vocab_size + self.max_num_visits = max_num_visits + self.word_embeddings = 
nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.visit_order_embeddings = nn.Embedding(max_num_visits, hidden_size) + self.time_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, is_time_delta=True + ) + self.age_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, is_time_delta=False + ) + self.visit_segment_embeddings = VisitEmbedding( + visit_order_size=visit_order_size, embedding_size=hidden_size + ) + self.scale_back_concat_layer = nn.Linear( + hidden_size + 2 * time_embeddings_size, hidden_size + ) + self.tanh = nn.Tanh() + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids_batch: torch.Tensor, + time_stamps: torch.Tensor, + ages: torch.Tensor, + visit_orders: torch.Tensor, + visit_segments: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.word_embeddings(input_ids) + time_stamps_embeds = self.time_embeddings(time_stamps) + ages_embeds = self.age_embeddings(ages) + visit_segments_embeds = self.visit_segment_embeddings(visit_segments) + visit_order_embeds = self.visit_order_embeddings(visit_orders) + token_type_embeds = self.token_type_embeddings(token_type_ids_batch) + concat_in = torch.cat( + (inputs_embeds, time_stamps_embeds, ages_embeds), dim=-1 + ) + h = self.tanh(self.scale_back_concat_layer(concat_in)) + embeddings = h + token_type_embeds + visit_order_embeds + visit_segments_embeds + embeddings = self.dropout(embeddings) + return self.LayerNorm(embeddings) diff --git a/pyhealth/models/ehrmamba_cehr.py b/pyhealth/models/ehrmamba_cehr.py new file mode 100644 index 000000000..cd555629c --- /dev/null +++ b/pyhealth/models/ehrmamba_cehr.py @@ -0,0 +1,117 @@ +"""EHRMamba with CEHR-style embeddings for single-stream FHIR token sequences.""" + +from __future__ import annotations + 
+from typing import Any, Dict, Optional + +import torch +from torch import nn + +from pyhealth.datasets import SampleDataset + +from .base_model import BaseModel +from .cehr_embeddings import MambaEmbeddingsForCEHR +from .ehrmamba import MambaBlock +from .utils import get_rightmost_masked_timestep + + +class EHRMambaCEHR(BaseModel): + """Mamba backbone over CEHR embeddings (FHIR / MPF pipeline). + + Args: + dataset: Fitted :class:`~pyhealth.datasets.SampleDataset` with MPF task schema. + vocab_size: Concept embedding vocabulary size (typically ``task.vocab.vocab_size``). + embedding_dim: Hidden size (``hidden_size`` in CEHR embeddings). + num_layers: Number of :class:`~pyhealth.models.ehrmamba.MambaBlock` layers. + pad_token_id: Padding id for masking (default 0). + state_size: SSM state size per channel. + conv_kernel: Causal conv kernel in each block. + dropout: Dropout before classifier. + """ + + def __init__( + self, + dataset: SampleDataset, + vocab_size: int, + embedding_dim: int = 128, + num_layers: int = 2, + pad_token_id: int = 0, + state_size: int = 16, + conv_kernel: int = 4, + dropout: float = 0.1, + type_vocab_size: int = 16, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_segment_vocab: int = 3, + ): + super().__init__(dataset=dataset) + self.embedding_dim = embedding_dim + self.num_layers = num_layers + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + + assert len(self.label_keys) == 1, "EHRMambaCEHR supports single label key only" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embeddings = MambaEmbeddingsForCEHR( + vocab_size=vocab_size, + hidden_size=embedding_dim, + pad_token_id=pad_token_id, + type_vocab_size=type_vocab_size, + max_num_visits=max_num_visits, + time_embeddings_size=time_embeddings_size, + visit_order_size=visit_segment_vocab, + ) + self.blocks = nn.ModuleList( + [ + MambaBlock( + d_model=embedding_dim, + state_size=state_size, + 
conv_kernel=conv_kernel, + ) + for _ in range(num_layers) + ] + ) + self.dropout = nn.Dropout(dropout) + out_dim = self.get_output_size() + self.fc = nn.Linear(embedding_dim, out_dim) + self._forecasting_head: Optional[nn.Module] = None + + def forward_forecasting(self, **kwargs: Any) -> Optional[torch.Tensor]: + """Optional next-token / forecasting head (extension point; not implemented).""" + + return None + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + concept_ids = kwargs["concept_ids"].to(self.device).long() + token_type_ids = kwargs["token_type_ids"].to(self.device).long() + time_stamps = kwargs["time_stamps"].to(self.device).float() + ages = kwargs["ages"].to(self.device).float() + visit_orders = kwargs["visit_orders"].to(self.device).long() + visit_segments = kwargs["visit_segments"].to(self.device).long() + + x = self.embeddings( + input_ids=concept_ids, + token_type_ids_batch=token_type_ids, + time_stamps=time_stamps, + ages=ages, + visit_orders=visit_orders, + visit_segments=visit_segments, + ) + mask = concept_ids != self.pad_token_id + for blk in self.blocks: + x = blk(x) + pooled = get_rightmost_masked_timestep(x, mask) + logits = self.fc(self.dropout(pooled)) + y_true = kwargs[self.label_key].to(self.device).float() + if y_true.dim() == 1: + y_true = y_true.unsqueeze(-1) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py index 67edc010e..45cd6608d 100644 --- a/pyhealth/models/utils.py +++ b/pyhealth/models/utils.py @@ -44,3 +44,31 @@ def get_last_visit(hidden_states, mask): last_hidden_states = torch.gather(hidden_states, 1, last_visit) last_hidden_state = last_hidden_states[:, 0, :] return last_hidden_state + + +def get_rightmost_masked_timestep(hidden_states, mask): + """Gather hidden state at the last True position in ``mask`` per row. 
+ + Unlike :func:`get_last_visit`, this does **not** assume valid tokens form a + contiguous prefix; it picks the maximum index where ``mask`` is True. + Use for MPF / CEHR layouts where padding can appear between boundary tokens. + + Args: + hidden_states: ``[batch, seq_len, hidden_size]``. + mask: ``[batch, seq_len]`` bool. + + Returns: + Tensor ``[batch, hidden_size]``. + """ + if mask is None: + return hidden_states[:, -1, :] + batch, seq_len, hidden = hidden_states.shape + device = hidden_states.device + idx = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand( + batch, -1 + ) + idx_m = torch.where(mask, idx, torch.full_like(idx, -1)) + last_idx = idx_m.max(dim=1).values.clamp(min=0) + last_idx = last_idx.view(batch, 1, 1).expand(batch, 1, hidden) + gathered = torch.gather(hidden_states, 1, last_idx) + return gathered[:, 0, :] diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index b48072270..4568a5ece 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -50,6 +50,7 @@ def get_processor(name: str): from .ignore_processor import IgnoreProcessor from .temporal_timeseries_processor import TemporalTimeseriesProcessor from .tuple_time_text_processor import TupleTimeTextProcessor +from .cehr_processor import CehrProcessor, ConceptVocab # Expose public API from .base_processor import ( @@ -79,4 +80,6 @@ def get_processor(name: str): "GraphProcessor", "AudioProcessor", "TupleTimeTextProcessor", + "CehrProcessor", + "ConceptVocab", ] diff --git a/pyhealth/processors/cehr_processor.py b/pyhealth/processors/cehr_processor.py new file mode 100644 index 000000000..19ba90170 --- /dev/null +++ b/pyhealth/processors/cehr_processor.py @@ -0,0 +1,554 @@ +"""CEHR-style tokenization, vocabulary, and sequence building for FHIR timelines. + +Key public API +-------------- +ConceptVocab + Token-to-dense-id mapping with PAD/UNK reserved at 0 and 1. JSON-serializable. 
+ +CehrProcessor + FeatureProcessor subclass that owns a ConceptVocab, can be warmed over a patient + stream, and converts a Patient's tabular FHIR rows into CEHR-aligned sequences. + +build_cehr_sequences(patient, vocab, max_len) + Flatten a Patient's tabular FHIR rows into CEHR-aligned feature lists. + +infer_mortality_label(patient) + Heuristic binary mortality label from flattened patient rows. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import orjson + +from pyhealth.data import Patient + +from .base_processor import FeatureProcessor +from . import register_processor + +DEFAULT_PAD = 0 +DEFAULT_UNK = 1 + +# --------------------------------------------------------------------------- +# Datetime helpers +# +# These are intentional copies of the identically-named functions in +# pyhealth.datasets.fhir_utils. They exist here to avoid a circular import: +# importing pyhealth.datasets.fhir_utils (even as a submodule) triggers +# pyhealth/datasets/__init__.py, which imports MIMIC4FHIRDataset, which in turn +# imports from this module — creating a cycle. The implementations are pure +# stdlib (no pyhealth deps), so keeping them in sync is straightforward; +# any change to the canonical copy in fhir_utils must be mirrored here. 
+# --------------------------------------------------------------------------- + + +def parse_dt(s: Optional[str]) -> Optional[datetime]: + """Parse an ISO 8601 or YYYY-MM-DD date string to a naive datetime.""" + if not s: + return None + try: + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + except ValueError: + dt = None + if dt is None and len(s) >= 10: + try: + dt = datetime.strptime(s[:10], "%Y-%m-%d") + except ValueError: + return None + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +def as_naive(dt: Optional[datetime]) -> Optional[datetime]: + """Strip timezone info from a datetime, or return None unchanged.""" + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +__all__ = [ + # Constants + "DEFAULT_PAD", + "DEFAULT_UNK", + "EVENT_TYPE_TO_TOKEN_TYPE", + # Vocabulary + "ConceptVocab", + "ensure_special_tokens", + # Processor + "CehrProcessor", + # Sequence building + "collect_cehr_timeline_events", + "warm_mpf_vocab_from_patient", + "build_cehr_sequences", + # Labels + "infer_mortality_label", +] + +EVENT_TYPE_TO_TOKEN_TYPE = { + "encounter": 1, + "condition": 2, + "medication_request": 3, + "observation": 4, + "procedure": 5, +} + +# Table-driven lookups for flattened event-row column access. 
+_CONCEPT_KEY_COL: Dict[str, str] = { + "condition": "condition/concept_key", + "observation": "observation/concept_key", + "medication_request": "medication_request/concept_key", + "procedure": "procedure/concept_key", +} + +_ENCOUNTER_ID_COL: Dict[str, str] = { + "condition": "condition/encounter_id", + "observation": "observation/encounter_id", + "medication_request": "medication_request/encounter_id", + "procedure": "procedure/encounter_id", + "encounter": "encounter/encounter_id", +} + +# --------------------------------------------------------------------------- +# ConceptVocab +# --------------------------------------------------------------------------- + + +@dataclass +class ConceptVocab: + """Maps concept keys to dense ids with PAD/UNK reserved at 0 and 1.""" + + token_to_id: Dict[str, int] = field(default_factory=dict) + pad_id: int = DEFAULT_PAD + unk_id: int = DEFAULT_UNK + _next_id: int = 2 + + def __post_init__(self) -> None: + if not self.token_to_id: + self.token_to_id = {"<pad>": self.pad_id, "<unk>": self.unk_id} + self._next_id = 2 + + def add_token(self, key: str) -> int: + if key in self.token_to_id: + return self.token_to_id[key] + tid = self._next_id + self.token_to_id[key] = tid + self._next_id += 1 + return tid + + def __getitem__(self, key: str) -> int: + return self.token_to_id.get(key, self.unk_id) + + @property + def vocab_size(self) -> int: + return self._next_id + + def to_json(self) -> Dict[str, Any]: + return { + "token_to_id": self.token_to_id, + "next_id": self._next_id, + "pad_id": self.pad_id, + "unk_id": self.unk_id, + } + + @classmethod + def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab": + pad_id = int(data.get("pad_id", DEFAULT_PAD)) + unk_id = int(data.get("unk_id", DEFAULT_UNK)) + vocab = cls(pad_id=pad_id, unk_id=unk_id) + loaded = dict(data.get("token_to_id") or {}) + if not loaded: + vocab._next_id = int(data.get("next_id", 2)) + return vocab + vocab.token_to_id = loaded + vocab._next_id = int(data.get("next_id", 
max(loaded.values()) + 1)) + return vocab + + def save(self, path: str) -> None: + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_bytes(orjson.dumps(self.to_json(), option=orjson.OPT_SORT_KEYS)) + + @classmethod + def load(cls, path: str) -> "ConceptVocab": + return cls.from_json(orjson.loads(Path(path).read_bytes())) + + +def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]: + """Add EHRMamba/CEHR special tokens and return their ids.""" + return {name: vocab.add_token(name) for name in ("<mor>", "<reg>", "<cls>", "<mask>")} + + +# --------------------------------------------------------------------------- +# Row utilities for flattened event stream +# --------------------------------------------------------------------------- + + +def _clean_string(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + return value.strip() or None + return str(value) + + +def _deceased_boolean_column_means_dead(value: Any) -> bool: + """True only for an explicit affirmative stored flag (not Python truthiness).""" + s = _clean_string(value) + return s is not None and s.lower() == "true" + + +def _row_datetime(value: Any) -> Optional[datetime]: + if value is None: + return None + if isinstance(value, datetime): + return as_naive(value) + try: + return parse_dt(str(value)) + except Exception: + return None + + +def _concept_key_from_row(row: Dict[str, Any]) -> str: + event_type = row.get("event_type") + col = _CONCEPT_KEY_COL.get(event_type) + if col: + return _clean_string(row.get(col)) or f"{event_type}|unknown" + if event_type == "encounter": + enc_class = _clean_string(row.get("encounter/encounter_class")) + return f"encounter|{enc_class}" if enc_class else "encounter|unknown" + return f"{event_type or 'event'}|unknown" + + +def _linked_encounter_id_from_row(row: Dict[str, Any]) -> Optional[str]: + col = _ENCOUNTER_ID_COL.get(row.get("event_type")) + return _clean_string(row.get(col)) if col else None + + +def 
_birth_datetime_from_patient(patient: Patient) -> Optional[datetime]: + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") != "patient": + continue + birth = _row_datetime(row.get("timestamp")) + if birth is not None: + return birth + raw = _clean_string(row.get("patient/birth_date")) + if raw: + return parse_dt(raw) + return None + + +def _sequential_visit_idx_for_time( + event_time: Optional[datetime], + visit_encounters: List[Tuple[datetime, int]], +) -> int: + if not visit_encounters: + return 0 + if event_time is None: + return visit_encounters[-1][1] + event_time = as_naive(event_time) + chosen = visit_encounters[0][1] + for encounter_start, visit_idx in visit_encounters: + if encounter_start <= event_time: + chosen = visit_idx + else: + break + return chosen + + +# --------------------------------------------------------------------------- +# CEHR timeline and sequence building +# --------------------------------------------------------------------------- + + +def collect_cehr_timeline_events( + patient: Patient, +) -> List[Tuple[datetime, str, str, int]]: + """Collect (time, concept_key, event_type, visit_idx) tuples from a patient's rows.""" + rows = list( + patient.data_source.sort(["timestamp", "event_type"], nulls_last=True).iter_rows(named=True) + ) + + # Build encounter list — rows are already timestamp-sorted so the loop + # preserves chronological order without an explicit sort. 
+ encounter_rows: List[Tuple[datetime, str]] = [] + for row in rows: + if row.get("event_type") != "encounter": + continue + enc_id = _linked_encounter_id_from_row(row) + enc_start = _row_datetime(row.get("timestamp")) + if enc_id is not None and enc_start is not None: + encounter_rows.append((enc_start, enc_id)) + + encounter_visit_idx = {enc_id: idx for idx, (_, enc_id) in enumerate(encounter_rows)} + encounter_start_by_id = {enc_id: enc_start for enc_start, enc_id in encounter_rows} + visit_encounters = [(enc_start, idx) for idx, (enc_start, _) in enumerate(encounter_rows)] + + events: List[Tuple[datetime, str, str, int]] = [] + unlinked: List[Tuple[Optional[datetime], str, str]] = [] + + for row in rows: + event_type = row.get("event_type") + if event_type not in EVENT_TYPE_TO_TOKEN_TYPE: + continue + event_time = _row_datetime(row.get("timestamp")) + concept_key = _concept_key_from_row(row) + + if event_type == "encounter": + enc_id = _linked_encounter_id_from_row(row) + if enc_id is None or event_time is None: + continue + visit_idx = encounter_visit_idx.get(enc_id) + if visit_idx is None: + continue + events.append((event_time, concept_key, event_type, visit_idx)) + continue + + enc_id = _linked_encounter_id_from_row(row) + if enc_id and enc_id in encounter_visit_idx: + visit_idx = encounter_visit_idx[enc_id] + if event_time is None: + event_time = encounter_start_by_id.get(enc_id) + if event_time is None: + continue + events.append((event_time, concept_key, event_type, visit_idx)) + else: + unlinked.append((event_time, concept_key, event_type)) + + for event_time, concept_key, event_type in unlinked: + visit_idx = _sequential_visit_idx_for_time(event_time, visit_encounters) + if event_time is None: + if not visit_encounters: + continue + for enc_start, enc_visit_idx in visit_encounters: + if enc_visit_idx == visit_idx: + event_time = enc_start + break + else: + event_time = visit_encounters[-1][0] + if event_time is None: + continue + 
events.append((event_time, concept_key, event_type, visit_idx)) + + events.sort(key=lambda item: item[0]) + return events + + +def warm_mpf_vocab_from_patient( + vocab: ConceptVocab, + patient: Patient, + clinical_cap: int, +) -> None: + """Add concept keys from the last clinical_cap events of a patient to vocab.""" + tail = collect_cehr_timeline_events(patient)[-clinical_cap:] if clinical_cap > 0 else [] + for _, concept_key, _, _ in tail: + vocab.add_token(concept_key) + + +def build_cehr_sequences( + patient: Patient, + vocab: ConceptVocab, + max_len: int, + *, + base_time: Optional[datetime] = None, + grow_vocab: bool = True, +) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]: + """Flatten a patient's tabular FHIR rows into CEHR-aligned feature lists.""" + events = collect_cehr_timeline_events(patient) + birth = _birth_datetime_from_patient(patient) + + if base_time is None: + base_time = events[0][0] if events else datetime.now() + base_time = as_naive(base_time) + birth = as_naive(birth) + + concept_ids: List[int] = [] + token_types: List[int] = [] + time_stamps: List[float] = [] + ages: List[float] = [] + visit_orders: List[int] = [] + visit_segments: List[int] = [] + + for event_time, concept_key, event_type, visit_idx in (events[-max_len:] if max_len > 0 else []): + event_time = as_naive(event_time) + concept_id = vocab.add_token(concept_key) if grow_vocab else vocab[concept_key] + token_type = EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0) + time_delta = ( + float((event_time - base_time).total_seconds()) + if base_time is not None and event_time is not None + else 0.0 + ) + age_years = ( + (event_time - birth).days / 365.25 + if birth is not None and event_time is not None + else 0.0 + ) + concept_ids.append(concept_id) + token_types.append(token_type) + time_stamps.append(time_delta) + ages.append(age_years) + visit_orders.append(min(visit_idx, 511)) + visit_segments.append(visit_idx % 2) + + return concept_ids, token_types, 
time_stamps, ages, visit_orders, visit_segments + + +# --------------------------------------------------------------------------- +# Label inference +# --------------------------------------------------------------------------- + + +def infer_mortality_label(patient: Patient) -> int: + """Heuristic binary mortality label from flattened patient rows.""" + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") == "patient": + if _deceased_boolean_column_means_dead(row.get("patient/deceased_boolean")): + return 1 + if _clean_string(row.get("patient/deceased_datetime")): + return 1 + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") == "condition": + key = (_clean_string(row.get("condition/concept_key")) or "").lower() + if any(token in key for token in ("death", "deceased", "mortality")): + return 1 + return 0 + + +# --------------------------------------------------------------------------- +# CehrProcessor +# --------------------------------------------------------------------------- + + +@register_processor("cehr") +class CehrProcessor(FeatureProcessor): + """CEHR concept sequence processor for FHIR timelines. + + Owns a :class:`ConceptVocab` and converts a + :class:`~pyhealth.data.Patient`'s tabular FHIR event rows into + CEHR-aligned integer sequence lists ready for downstream models. + + This processor departs from the standard ``FeatureProcessor`` contract + in two ways that are intentional for this domain: + + * ``process(patient)`` takes a full :class:`~pyhealth.data.Patient` rather + than a single scalar field value, because building CEHR sequences + requires access to all event rows simultaneously. + * ``fit(patients, clinical_cap)`` takes an iterable of + :class:`~pyhealth.data.Patient` objects instead of the base-class + ``fit(samples, field)`` signature, because vocabulary warming is driven + by the Patient timeline, not a pre-aggregated sample dict. 
+ + Typical usage:: + + processor = CehrProcessor(max_len=512) + processor.fit(dataset.iter_patients()) # warm vocabulary + sequences = processor.process(some_patient) # build sequences + processor.save("vocab.json") # persist vocab + + Attributes: + vocab: Concept-to-id mapping (PAD=0, UNK=1). + max_len: Maximum number of clinical tokens per patient (boundary tokens + not counted; see :class:`~pyhealth.tasks.MPFClinicalPredictionTask`). + frozen_vocab: When True, unknown concepts map to UNK instead of adding + new ids — used for multi-worker safety after vocab warm-up. + """ + + def __init__( + self, + vocab: Optional[ConceptVocab] = None, + max_len: int = 512, + frozen_vocab: bool = False, + ) -> None: + self.vocab = vocab or ConceptVocab() + self.max_len = max_len + self.frozen_vocab = frozen_vocab + + def fit( # type: ignore[override] + self, + patients: Iterable[Patient], + clinical_cap: Optional[int] = None, + ) -> "CehrProcessor": + """Warm vocabulary from a stream of patients. + + Note: this method intentionally overrides ``FeatureProcessor.fit(samples, + field)`` with a different signature, because CEHR vocabulary warming + operates on :class:`~pyhealth.data.Patient` timelines rather than + pre-aggregated sample dicts. + + Special tokens are *not* inserted here; they are added lazily by + :meth:`~pyhealth.tasks.MPFClinicalPredictionTask._ensure_processor` + on the first call to the task. This keeps ``fit`` focused on concept + key discovery. + + Args: + patients: Iterable of :class:`~pyhealth.data.Patient` objects. + clinical_cap: Maximum number of tail events to scan per patient. + Defaults to ``max_len - 2`` (room for two boundary tokens). + + Returns: + self (for chaining). 
+ """ + cap = clinical_cap if clinical_cap is not None else max(0, self.max_len - 2) + for patient in patients: + warm_mpf_vocab_from_patient(self.vocab, patient, cap) + return self + + def process( + self, + patient: Patient, + ) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]: + """Build CEHR sequences from a patient's FHIR event rows. + + Args: + patient: A tabular :class:`~pyhealth.data.Patient`. + + Returns: + Six equal-length lists ``(concept_ids, token_type_ids, time_stamps, + ages, visit_orders, visit_segments)`` ready for boundary-token + insertion and left-padding by the task. + """ + clinical_cap = max(0, self.max_len - 2) + return build_cehr_sequences( + patient, self.vocab, clinical_cap, grow_vocab=not self.frozen_vocab + ) + + def save(self, path: str) -> None: + """Persist the vocabulary to a JSON file at *path*.""" + self.vocab.save(path) + + def load(self, path: str) -> None: + """Load a previously saved vocabulary from *path*.""" + self.vocab = ConceptVocab.load(path) + + def is_token(self) -> bool: + """All six output lists contain discrete token/index values.""" + return True + + def schema(self) -> Tuple[str, ...]: + return ( + "concept_ids", + "token_type_ids", + "time_stamps", + "ages", + "visit_orders", + "visit_segments", + ) + + def dim(self) -> Tuple[int, ...]: + """Each of the six output lists becomes a 1-D tensor.""" + return (1, 1, 1, 1, 1, 1) + + def spatial(self) -> Tuple[bool, ...]: + """All six outputs are along the sequence (temporal) axis.""" + return (True, True, True, True, True, True) + + def __repr__(self) -> str: + # frozen_vocab is a runtime flag, not a constructor parameter, so it + # must be excluded here. This repr is used by BaseDataset.set_task to + # compute the LitData task-cache UUID via vars(task); including + # frozen_vocab would produce different UUIDs for single- vs + # multi-worker runs of the same pipeline, defeating caching. 
+ return f"CehrProcessor(max_len={self.max_len})" diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..ffdf99560 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,11 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task + + +def __getattr__(name: str): + if name == "MPFClinicalPredictionTask": + from .mpf_clinical_prediction import MPFClinicalPredictionTask + + return MPFClinicalPredictionTask + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py new file mode 100644 index 000000000..fb60c8a70 --- /dev/null +++ b/pyhealth/tasks/mpf_clinical_prediction.py @@ -0,0 +1,228 @@ +"""Multitask Prompted Fine-tuning (MPF) clinical prediction on FHIR timelines.""" + +from __future__ import annotations + +import itertools +from typing import Any, Dict, List, Optional + +import polars as pl +import torch + +from pyhealth.data import Patient +from pyhealth.processors.cehr_processor import ( + CehrProcessor, + ConceptVocab, + ensure_special_tokens, + infer_mortality_label, +) + +from .base_task import BaseTask + + +def _pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]: + if len(seq) > max_len: + return seq[-max_len:] + return seq + [pad] * (max_len - len(seq)) + + +def _pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]: + if len(seq) > max_len: + return seq[-max_len:] + return seq + [pad] * (max_len - len(seq)) + + +def _left_pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]: + if len(seq) > max_len: + return seq[-max_len:] + return [pad] * (max_len - len(seq)) + seq + + +def _left_pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]: + if len(seq) > max_len: + return seq[-max_len:] + return [pad] * (max_len - len(seq)) + seq + + +class MPFClinicalPredictionTask(BaseTask): + 
"""Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens. + + The task owns a :class:`~pyhealth.processors.CehrProcessor` and its + :class:`~pyhealth.processors.ConceptVocab`. For single-worker use the + vocabulary grows lazily in :meth:`__call__`. For multi-worker LitData + runs, call :meth:`warm_vocab` before + :meth:`~pyhealth.datasets.BaseDataset.set_task` so the vocabulary is + complete and :attr:`frozen_vocab` prevents races across workers. + + Attributes: + max_len: Truncated sequence length (must be >= 2 for boundary tokens). + use_mpf: If True, use ```` / ```` specials; else ```` / ````. + processor: The CEHR processor owning the shared concept vocabulary. + frozen_vocab: If True, do not add new concept ids (post-warmup parallel path). + """ + + task_name: str = "MPFClinicalPredictionFHIR" + input_schema: Dict[str, Any] = { + "concept_ids": ("tensor", {"dtype": torch.long}), + "token_type_ids": ("tensor", {"dtype": torch.long}), + "time_stamps": "tensor", + "ages": "tensor", + "visit_orders": ("tensor", {"dtype": torch.long}), + "visit_segments": ("tensor", {"dtype": torch.long}), + } + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__( + self, + max_len: int = 512, + use_mpf: bool = True, + processor: Optional[CehrProcessor] = None, + ) -> None: + if max_len < 2: + raise ValueError("max_len must be >= 2 for MPF boundary tokens") + self.max_len = max_len + self.use_mpf = use_mpf + self.processor = processor or CehrProcessor(max_len=max_len) + self._specials: Optional[Dict[str, int]] = None + + # ------------------------------------------------------------------ + # Backward-compatible property aliases + # ------------------------------------------------------------------ + + @property + def vocab(self) -> ConceptVocab: + return self.processor.vocab + + @vocab.setter + def vocab(self, value: ConceptVocab) -> None: + self.processor.vocab = value + + @property + def frozen_vocab(self) -> bool: + return 
self.processor.frozen_vocab + + @frozen_vocab.setter + def frozen_vocab(self, value: bool) -> None: + self.processor.frozen_vocab = value + + # ------------------------------------------------------------------ + # Vocabulary warm-up (call before set_task for multi-worker safety) + # ------------------------------------------------------------------ + + def warm_vocab(self, dataset: Any, num_workers: int = 1) -> None: + """Warm CEHR vocabulary from *dataset* before calling ``set_task``. + + For single-worker pipelines this is **optional** — ``__call__`` grows + the vocabulary lazily. For multi-worker pipelines call this first so + that the vocabulary is fully populated before LitData forks workers:: + + task = MPFClinicalPredictionTask(max_len=512) + task.warm_vocab(ds, num_workers=4) + sample_dataset = ds.set_task(task, num_workers=4) + + Args: + dataset: A :class:`~pyhealth.datasets.BaseDataset` instance whose + ``global_event_df`` has already been built. + num_workers: Number of workers that will be passed to ``set_task``. + When > 1, ``frozen_vocab`` is set after warming so that worker + processes look up tokens instead of racing on + :class:`~pyhealth.processors.ConceptVocab`. 
+ """ + from litdata.processing.data_processor import in_notebook + + worker_count = 1 if in_notebook() else num_workers + filtered = self.pre_filter(dataset.global_event_df) + warmup_pids = ( + filtered.select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + .sort() + .to_list() + ) + patient_count = len(warmup_pids) + effective_workers = min(worker_count, patient_count) if patient_count else 1 + + clinical_cap = max(0, self.max_len - 2) + base = dataset.global_event_df + + def _iter_warmup_patients(): + for batch in itertools.batched(warmup_pids, 128): + batch_df = ( + base.filter(pl.col("patient_id").is_in(batch)) + .collect(engine="streaming") + ) + for patient_df in batch_df.partition_by("patient_id"): + yield Patient( + patient_id=patient_df["patient_id"][0], + data_source=patient_df, + ) + + self.processor.fit(_iter_warmup_patients(), clinical_cap=clinical_cap) + self.frozen_vocab = effective_workers > 1 + self._specials = ensure_special_tokens(self.processor.vocab) + + # ------------------------------------------------------------------ + # Per-patient sample generation + # ------------------------------------------------------------------ + + def _ensure_processor(self) -> CehrProcessor: + if self._specials is None: + self._specials = ensure_special_tokens(self.processor.vocab) + return self.processor + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Build one labeled sample dict per patient. + + Args: + patient: A tabular :class:`~pyhealth.data.Patient`. + + Returns: + A one-element list with ``concept_ids``, tensor-ready feature lists, and + ``label`` (0/1). Boundary tokens are always included; when + ``max_len == 2`` the sequence is ````/```` and ```` only. 
+ """ + processor = self._ensure_processor() + pid = patient.patient_id + ( + concept_ids, + token_types, + time_stamps, + ages, + visit_orders, + visit_segments, + ) = processor.process(patient) + + assert self._specials is not None + mor_id = self._specials[""] if self.use_mpf else self._specials[""] + reg_id = self._specials[""] + z0 = 0 + zf = 0.0 + concept_ids = [mor_id] + concept_ids + [reg_id] + token_types = [z0] + token_types + [z0] + time_stamps = [zf] + time_stamps + [zf] + ages = [zf] + ages + [zf] + visit_orders = [z0] + visit_orders + [z0] + visit_segments = [z0] + visit_segments + [z0] + + ml = self.max_len + concept_ids = _left_pad_int(concept_ids, ml, processor.vocab.pad_id) + token_types = _left_pad_int(token_types, ml, 0) + time_stamps = _left_pad_float(time_stamps, ml, 0.0) + ages = _left_pad_float(ages, ml, 0.0) + visit_orders = _left_pad_int(visit_orders, ml, 0) + visit_segments = _left_pad_int(visit_segments, ml, 0) + + label = infer_mortality_label(patient) + return [ + { + "patient_id": pid, + "visit_id": f"{pid}-0", + "concept_ids": concept_ids, + "token_type_ids": token_types, + "time_stamps": time_stamps, + "ages": ages, + "visit_orders": visit_orders, + "visit_segments": visit_segments, + "label": label, + } + ] diff --git a/pyproject.toml b/pyproject.toml index 98f88d47b..47e40a62d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "more-itertools~=10.8.0", "einops>=0.8.0", "linear-attention-transformer>=0.19.1", + "orjson~=3.10", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/test_ehrmamba_cehr.py b/tests/core/test_ehrmamba_cehr.py new file mode 100644 index 000000000..4b1f5065e --- /dev/null +++ b/tests/core/test_ehrmamba_cehr.py @@ -0,0 +1,124 @@ +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import EHRMambaCEHR +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask 
+
+
+def _tiny_samples(seq: int = 16) -> tuple:
+    from pyhealth.processors.cehr_processor import ConceptVocab, ensure_special_tokens
+
+    task = MPFClinicalPredictionTask(max_len=seq, use_mpf=True)
+    task.vocab = ConceptVocab()
+    sp = ensure_special_tokens(task.vocab)
+    mid = task.vocab.add_token("test|filler")
+    samples = []
+    for lab in (0, 1):
+        samples.append(
+            {
+                "patient_id": f"p{lab}",
+                "visit_id": f"v{lab}",
+                "concept_ids": [sp["<mor>"]] + [mid] * (seq - 2) + [sp["<reg>"]],
+                "token_type_ids": [0] * seq,
+                "time_stamps": [0.0] * seq,
+                "ages": [50.0] * seq,
+                "visit_orders": [0] * seq,
+                "visit_segments": [0] * seq,
+                "label": lab,
+            }
+        )
+    return samples, task
+
+
+class TestEHRMambaCEHR(unittest.TestCase):
+    def test_readout_pools_rightmost_non_pad(self) -> None:
+        """MPF padding between tokens must not make pooling pick a pad position."""
+
+        from pyhealth.models.utils import (
+            get_last_visit,
+            get_rightmost_masked_timestep,
+        )
+
+        h = torch.tensor([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [99.0, 0.0]]])
+        m = torch.tensor([[True, True, False, True]])
+        out = get_rightmost_masked_timestep(h, m)
+        self.assertTrue(torch.allclose(out[0], torch.tensor([99.0, 0.0])))
+        wrong = get_last_visit(h, m)
+        self.assertFalse(torch.allclose(out[0], wrong[0]))
+
+    def test_end_to_end_fhir_pipeline(self) -> None:
+        import tempfile
+        from pathlib import Path
+
+        from pyhealth.datasets import MIMIC4FHIRDataset, create_sample_dataset
+        from pyhealth.datasets import get_dataloader
+
+        from tests.core.test_mimic4_fhir_ndjson_fixtures import run_task, write_two_class_ndjson
+
+        task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
+            )
+            samples = run_task(ds, task)
+            sample_ds = create_sample_dataset(
+                samples=samples,
+                input_schema=task.input_schema,
+                output_schema=task.output_schema,
+                dataset_name="fhir_test",
+            )
+            vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
+            model = EHRMambaCEHR(
+                dataset=sample_ds,
+                vocab_size=vocab_size,
+                embedding_dim=64,
+                num_layers=1,
+            )
+            batch = next(
+                iter(get_dataloader(sample_ds, batch_size=2, shuffle=False))
+            )
+            out = model(**batch)
+            self.assertIn("loss", out)
+            out["loss"].backward()
+
+    def test_forward_backward(self) -> None:
+        samples, task = _tiny_samples()
+        ds = create_sample_dataset(
+            samples=samples,
+            input_schema=task.input_schema,
+            output_schema=task.output_schema,
+        )
+        vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
+        model = EHRMambaCEHR(
+            dataset=ds,
+            vocab_size=vocab_size,
+            embedding_dim=64,
+            num_layers=1,
+            state_size=8,
+        )
+        batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
+        out = model(**batch)
+        self.assertEqual(out["logit"].shape[0], 2)
+        out["loss"].backward()
+
+    def test_eval_mode(self) -> None:
+        samples, task = _tiny_samples()
+        ds = create_sample_dataset(
+            samples=samples,
+            input_schema=task.input_schema,
+            output_schema=task.output_schema,
+        )
+        vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
+        model = EHRMambaCEHR(dataset=ds, vocab_size=vocab_size, embedding_dim=32, num_layers=1)
+        model.eval()
+        with torch.no_grad():
+            batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
+            out = model(**batch)
+        self.assertIn("y_prob", out)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/core/test_mimic4_fhir_dataset.py b/tests/core/test_mimic4_fhir_dataset.py
new file mode 100644
index 000000000..ceb0a9b25
--- /dev/null
+++ b/tests/core/test_mimic4_fhir_dataset.py
@@ -0,0 +1,694 @@
+import gzip
+import tempfile
+import unittest
+from pathlib import Path
+from typing import Dict, List
+
+import orjson
+import polars as pl
+
+from pyhealth.data import Patient
+from pyhealth.datasets import MIMIC4FHIRDataset
+from pyhealth.datasets.fhir_utils import (
+    _flatten_resource_to_table_row,
+)
+from pyhealth.processors.cehr_processor import (
+    ConceptVocab,
+    build_cehr_sequences,
+    collect_cehr_timeline_events,
+    infer_mortality_label,
+)
+
+from tests.core.test_mimic4_fhir_ndjson_fixtures import (
+    ndjson_two_class_text,
+    run_task,
+    write_one_patient_ndjson,
+    write_two_class_ndjson,
+)
+
+
+def _third_patient_loinc_resources() -> List[Dict[str, object]]:
+    return [
+        {
+            "resourceType": "Patient",
+            "id": "p-synth-3",
+            "birthDate": "1960-01-01",
+        },
+        {
+            "resourceType": "Encounter",
+            "id": "e3",
+            "subject": {"reference": "Patient/p-synth-3"},
+            "period": {"start": "2020-08-01T10:00:00Z"},
+            "class": {"code": "IMP"},
+        },
+        {
+            "resourceType": "Observation",
+            "id": "o3",
+            "subject": {"reference": "Patient/p-synth-3"},
+            "encounter": {"reference": "Encounter/e3"},
+            "effectiveDateTime": "2020-08-01T12:00:00Z",
+            "code": {"coding": [{"system": "http://loinc.org", "code": "999-9"}]},
+        },
+    ]
+
+
+def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+    lines = ndjson_two_class_text().strip().split("\n")
+    lines.extend(orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources())
+    path = directory / name
+    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+    return path
+
+
+def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient:
+    return Patient(patient_id=patient_id, data_source=pl.DataFrame(rows))
+
+
+class TestDeceasedBooleanFlattening(unittest.TestCase):
+    def test_string_false_not_coerced_by_python_bool(self) -> None:
+        """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``."""
+        row = _flatten_resource_to_table_row(
+            {
+                "resourceType": "Patient",
+                "id": "p-str-false",
+                "deceasedBoolean": "false",
+            }
+        )
+        self.assertIsNotNone(row)
+        _table, payload = row
+        self.assertEqual(payload.get("deceased_boolean"), "false")
+
+    def test_string_true_parsed(self) -> None:
+        row = _flatten_resource_to_table_row(
+            {
+                "resourceType": "Patient",
+                "id": "p-str-true",
+                "deceasedBoolean": "true",
+            }
+        )
+        self.assertIsNotNone(row)
+        self.assertEqual(row[1].get("deceased_boolean"), "true")
+
+    def test_json_booleans_unchanged(self) -> None:
+        for raw, expected in ((True, "true"), (False, "false")):
+            with self.subTest(raw=raw):
+                row = _flatten_resource_to_table_row(
+                    {
+                        "resourceType": "Patient",
+                        "id": "p-bool",
+                        "deceasedBoolean": raw,
+                    }
+                )
+                self.assertIsNotNone(row)
+                self.assertEqual(row[1].get("deceased_boolean"), expected)
+
+    def test_unknown_deceased_type_stored_as_none(self) -> None:
+        row = _flatten_resource_to_table_row(
+            {
+                "resourceType": "Patient",
+                "id": "p-garbage",
+                "deceasedBoolean": {"unexpected": "object"},
+            }
+        )
+        self.assertIsNotNone(row)
+        self.assertIsNone(row[1].get("deceased_boolean"))
+
+    def test_infer_mortality_respects_string_false_row(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "event_type": "patient",
+                    "timestamp": "2020-01-01T00:00:00",
+                    "patient/deceased_boolean": "false",
+                },
+            ],
+        )
+        self.assertEqual(infer_mortality_label(patient), 0)
+
+
+class TestMIMIC4FHIRDataset(unittest.TestCase):
+    def test_concept_vocab_from_json_empty_token_to_id(self) -> None:
+        v = ConceptVocab.from_json({"token_to_id": {}})
+        self.assertIn("<pad>", v.token_to_id)
+        self.assertIn("<unk>", v.token_to_id)
+        self.assertEqual(v._next_id, 2)
+
+    def test_concept_vocab_from_json_empty_respects_next_id(self) -> None:
+        v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50})
+        self.assertEqual(v._next_id, 50)
+
+    def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None:
+        from pyhealth.datasets.fhir_utils import sorted_ndjson_files
+
+        with tempfile.TemporaryDirectory() as tmp:
+            root = Path(tmp)
+            (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8")
+            (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8")
+            (root / "notes.txt").write_text("z", encoding="utf-8")
+            wide = sorted_ndjson_files(root, "**/*.ndjson.gz")
+            narrow = sorted_ndjson_files(
+                root,
+                ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"],
+            )
+            self.assertEqual(len(wide), 2)
+            self.assertEqual(len(narrow), 1)
+            self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz")
+
+    def test_dataset_accepts_glob_patterns_kwarg(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_one_patient_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp
+            )
+            self.assertEqual(ds.glob_patterns, ["*.ndjson"])
+            _ = ds.global_event_df.collect(engine="streaming")
+
+    def test_dataset_rejects_both_glob_kwargs(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            with self.assertRaises(ValueError):
+                MIMIC4FHIRDataset(
+                    root=tmp,
+                    glob_pattern="*.ndjson",
+                    glob_patterns=["*.ndjson"],
+                    cache_dir=tmp,
+                )
+
+    def test_disk_fixture_resolves_events_per_patient(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_one_patient_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            sub = ds.global_event_df.filter(pl.col("patient_id") == "p-synth-1").collect(
+                engine="streaming"
+            )
+            self.assertGreaterEqual(len(sub), 2)
+            self.assertIn("condition/concept_key", sub.columns)
+
+    def test_prepared_flat_tables_exist(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            _ = ds.global_event_df.collect(engine="streaming")
+            prepared = ds.prepared_tables_dir
+            self.assertTrue((prepared / "patient.parquet").is_file())
+            self.assertTrue((prepared / "encounter.parquet").is_file())
+            self.assertTrue((prepared / "condition.parquet").is_file())
+
+    def test_build_cehr_non_empty(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_one_patient_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            task = MPFClinicalPredictionTask(max_len=64, use_mpf=True)
+            run_task(ds, task)
+            self.assertIsInstance(task.vocab, ConceptVocab)
+            self.assertGreater(task.vocab.vocab_size, 2)
+
+    def test_set_task_vocab_warm_on_litdata_cache_hit(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            task_kw = {"max_len": 64, "use_mpf": True}
+            ds1 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            task1 = MPFClinicalPredictionTask(**task_kw)
+            task1.warm_vocab(ds1, num_workers=1)
+            ds1.set_task(task1, num_workers=1)
+            warm_size = task1.vocab.vocab_size
+            self.assertGreater(warm_size, 6)
+            ds2 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            task2 = MPFClinicalPredictionTask(**task_kw)
+            task2.warm_vocab(ds2, num_workers=1)
+            ds2.set_task(task2, num_workers=1)
+            self.assertEqual(task2.vocab.vocab_size, warm_size)
+
+    def test_mortality_heuristic(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            samples = run_task(ds, MPFClinicalPredictionTask(max_len=64, use_mpf=False))
+            self.assertEqual({s["label"] for s in samples}, {0, 1})
+
+    def test_infer_deceased(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            dead = ds.get_patient("p-synth-2")
+            self.assertEqual(infer_mortality_label(dead), 1)
+
+    def test_disk_ndjson_gz_physionet_style(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            gz_path = Path(tmp) / "fixture.ndjson.gz"
+            with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
+                gz.write(ndjson_two_class_text())
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5)
+            self.assertGreaterEqual(len(ds.unique_patient_ids), 1)
+
+    def test_disk_ndjson_temp_dir(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", max_patients=5)
+            self.assertEqual(len(ds.unique_patient_ids), 2)
+            samples = run_task(ds, MPFClinicalPredictionTask(max_len=48, use_mpf=True))
+            self.assertGreaterEqual(len(samples), 1)
+            for sample in samples:
+                self.assertIn("concept_ids", sample)
+                self.assertIn("label", sample)
+
+    def test_global_event_df_schema_and_flattened_columns(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            df = ds.global_event_df.collect(engine="streaming")
+            self.assertGreater(len(df), 0)
+            self.assertIn("patient_id", df.columns)
+            self.assertIn("timestamp", df.columns)
+            self.assertIn("event_type", df.columns)
+            self.assertIn("condition/concept_key", df.columns)
+            self.assertIn("observation/concept_key", df.columns)
+            self.assertIn("patient/deceased_boolean", df.columns)
+
+    def test_set_task_produces_correct_samples(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp), name="fx.ndjson")
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1
+            )
+            task = MPFClinicalPredictionTask(max_len=48, use_mpf=True)
+            task.warm_vocab(ds, num_workers=1)
+            sample_ds = ds.set_task(task, num_workers=1)
+            samples = sorted(
+                [sample_ds[i] for i in range(len(sample_ds))],
+                key=lambda s: s["patient_id"],
+            )
+            self.assertEqual(len(samples), 2)
+            for s in samples:
+                self.assertIn("concept_ids", s)
+                self.assertIn("label", s)
+            labels = {int(s["label"]) for s in samples}
+            self.assertEqual(labels, {0, 1})
+
+    def test_set_task_multi_worker_sets_frozen_vocab(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2
+            )
+            task = MPFClinicalPredictionTask(max_len=48, use_mpf=True)
+            task.warm_vocab(ds, num_workers=2)
+            ds.set_task(task, num_workers=2)
+            self.assertTrue(task.frozen_vocab)
+
+    def test_mpf_pre_filter_vocab_warmup_excludes_dropped_patients(self) -> None:
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        class TwoPatientMPFTask(MPFClinicalPredictionTask):
+            def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+                return df.filter(pl.col("patient_id").is_in(["p-synth-1", "p-synth-2"]))
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_plus_third_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1
+            )
+            self.assertEqual(len(ds.unique_patient_ids), 3)
+            task = TwoPatientMPFTask(max_len=48, use_mpf=True)
+            task.warm_vocab(ds, num_workers=1)
+            ds.set_task(task, num_workers=1)
+            self.assertNotIn("http://loinc.org|999-9", task.vocab.token_to_id)
+            self.assertIn("http://loinc.org|789-0", task.vocab.token_to_id)
+
+    def test_mpf_pre_filter_single_patient_limits_effective_workers(self) -> None:
+        """Pre-filter that yields one patient should cap effective_workers to 1.
+
+        We verify the effective_workers logic directly rather than via
+        ``set_task`` because ``set_task`` with a 1-patient cohort produces
+        only one label class (p-synth-1 is alive → label=0), which causes
+        ``BinaryLabelProcessor.fit`` to raise "Expected 2 unique labels, got 1".
+        The invariant under test belongs to the ``set_task`` override in
+        ``MIMIC4FHIRDataset``; the Polars pre-filter and worker-count
+        formula are both exercised here without triggering that constraint.
+        """
+        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+        class OnePatientMPFTask(MPFClinicalPredictionTask):
+            def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+                return df.filter(pl.col("patient_id") == "p-synth-1")
+
+        with tempfile.TemporaryDirectory() as tmp:
+            write_two_class_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(
+                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2
+            )
+            task = OnePatientMPFTask(max_len=48, use_mpf=True)
+            warmup_pids = (
+                task.pre_filter(ds.global_event_df)
+                .select("patient_id")
+                .unique()
+                .collect(engine="streaming")
+                .to_series()
+                .sort()
+                .to_list()
+            )
+            self.assertEqual(warmup_pids, ["p-synth-1"])
+            # One patient, two requested workers: effective_workers = min(2, 1) = 1
+            effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1
+            self.assertEqual(effective_workers, 1)
+
+    def test_encounter_reference_requires_exact_id(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-07-02T10:00:00",
+                    "encounter/encounter_id": "e10",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-07-02T11:00:00",
+                    "condition/encounter_id": "e10",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+        tid = vocab["http://hl7.org/fhir/sid/icd-10-cm|I99"]
+        self.assertEqual(concept_ids.count(tid), 1)
+
+    def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "ea",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-07-01T10:00:00",
+                    "encounter/encounter_id": "eb",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-06-15T12:00:00",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+        self.assertEqual(concept_ids.count(vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]), 1)
+
+    def test_cehr_sequence_shapes(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            write_one_patient_ndjson(Path(tmp))
+            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+            patient = ds.get_patient("p-synth-1")
+            vocab = ConceptVocab()
+            concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments = (
+                build_cehr_sequences(patient, vocab, max_len=32)
+            )
+            n = len(concept_ids)
+            self.assertEqual(len(token_types), n)
+            self.assertEqual(len(time_stamps), n)
+            self.assertEqual(len(ages), n)
+            self.assertEqual(len(visit_orders), n)
+            self.assertEqual(len(visit_segments), n)
+            self.assertGreater(n, 0)
+
+    def test_build_cehr_max_len_zero_no_clinical_tokens(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "condition/encounter_id": "e1",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        c, _, _, _, _, vs = build_cehr_sequences(patient, vocab, max_len=0)
+        self.assertEqual(c, [])
+        self.assertEqual(vs, [])
+
+    def test_visit_segments_alternate_by_visit_index(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e0",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "condition/encounter_id": "e0",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-07-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-07-01T11:00:00",
+                    "condition/encounter_id": "e1",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        _, _, _, _, _, visit_segments = build_cehr_sequences(patient, vocab, max_len=64)
+        self.assertEqual(visit_segments, [0, 0, 1, 1])
+
+    def test_unlinked_visit_idx_matches_sequential_counter(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": None,
+                    "encounter/encounter_id": "e_bad",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-03-01T10:00:00",
+                    "encounter/encounter_id": "e_ok",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-03-05T11:00:00",
+                    "condition/encounter_id": "e_ok",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-03-15T12:00:00",
+                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        concept_ids, _, _, _, visit_orders, visit_segments = build_cehr_sequences(
+            patient, vocab, max_len=64
+        )
+        i10 = vocab["http://hl7.org/fhir/sid/icd-10-cm|I10"]
+        z00 = vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]
+        i_link = concept_ids.index(i10)
+        i_free = concept_ids.index(z00)
+        self.assertEqual(visit_orders[i_link], visit_orders[i_free])
+        self.assertEqual(visit_segments[i_link], visit_segments[i_free])
+
+    def test_medication_request_uses_medication_codeable_concept(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "medication_request",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "medication_request/encounter_id": "e1",
+                    "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "medication_request",
+                    "timestamp": "2020-06-01T12:00:00",
+                    "medication_request/encounter_id": "e1",
+                    "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+        ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111"
+        kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222"
+        self.assertNotEqual(vocab[ka], vocab[kb])
+        self.assertEqual(c.count(vocab[ka]), 1)
+        self.assertEqual(c.count(vocab[kb]), 1)
+
+    def test_medication_request_medication_reference_token(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "medication_request",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "medication_request/encounter_id": "e1",
+                    "medication_request/concept_key": "MedicationRequest/reference|med-abc",
+                },
+            ],
+        )
+        vocab = ConceptVocab()
+        c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
+        key = "MedicationRequest/reference|med-abc"
+        self.assertIn(vocab[key], c)
+        self.assertEqual(c.count(vocab[key]), 1)
+
+    def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "condition/encounter_id": "e1",
+                    "condition/concept_key": "a|1",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "observation",
+                    "timestamp": "2020-06-01T12:00:00",
+                    "observation/encounter_id": "e1",
+                    "observation/concept_key": "b|2",
+                },
+            ],
+        )
+        events = collect_cehr_timeline_events(patient)
+        self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/core/test_mimic4_fhir_ndjson_fixtures.py b/tests/core/test_mimic4_fhir_ndjson_fixtures.py
new file mode 100644
index 000000000..450ea5aff
--- /dev/null
+++ b/tests/core/test_mimic4_fhir_ndjson_fixtures.py
@@ -0,0 +1,112 @@
+"""NDJSON file bodies for :mod:`tests.core.test_mimic4_fhir_dataset` (disk-only ingest)."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+import orjson
+
+
+# ---------------------------------------------------------------------------
+# Synthetic in-memory FHIR resources
+# ---------------------------------------------------------------------------
+
+
+def _one_patient_resources() -> List[Dict[str, Any]]:
+    return [
+        {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"},
+        {
+            "resourceType": "Encounter",
+            "id": "e1",
+            "subject": {"reference": "Patient/p-synth-1"},
+            "period": {"start": "2020-06-01T10:00:00Z"},
+            "class": {"code": "IMP"},
+        },
+        {
+            "resourceType": "Condition",
+            "id": "c1",
+            "subject": {"reference": "Patient/p-synth-1"},
+            "encounter": {"reference": "Encounter/e1"},
+            "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]},
+            "onsetDateTime": "2020-06-01T11:00:00Z",
+        },
+    ]
+
+
+def _two_patient_resources() -> List[Dict[str, Any]]:
+    return [
+        *_one_patient_resources(),
+        {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True},
+        {
+            "resourceType": "Encounter",
+            "id": "e-dead",
+            "subject": {"reference": "Patient/p-synth-2"},
+            "period": {"start": "2020-07-01T10:00:00Z"},
+            "class": {"code": "IMP"},
+        },
+        {
+            "resourceType": "Observation",
+            "id": "o-dead",
+            "subject": {"reference": "Patient/p-synth-2"},
+            "encounter": {"reference": "Encounter/e-dead"},
+            "effectiveDateTime": "2020-07-01T12:00:00Z",
+            "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]},
+        },
+    ]
+
+
+# ---------------------------------------------------------------------------
+# Text serialisers
+# ---------------------------------------------------------------------------
+
+
+def ndjson_one_patient_text() -> str:
+    return "\n".join(orjson.dumps(r).decode("utf-8") for r in _one_patient_resources()) + "\n"
+
+
+def ndjson_two_class_text() -> str:
+    return "\n".join(orjson.dumps(r).decode("utf-8") for r in _two_patient_resources()) + "\n"
+
+
+# ---------------------------------------------------------------------------
+# Disk writers
+# ---------------------------------------------------------------------------
+
+
+def write_two_class_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+    path = directory / name
+    path.write_text(ndjson_two_class_text(), encoding="utf-8")
+    return path
+
+
+def write_one_patient_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+    path = directory / name
+    path.write_text(ndjson_one_patient_text(), encoding="utf-8")
+    return path
+
+
+# ---------------------------------------------------------------------------
+# Shared test helper
+# ---------------------------------------------------------------------------
+
+
+def run_task(ds: Any, task: Any) -> List[Dict[str, Any]]:
+    """Run *task* over every patient in *ds* without the LitData caching pipeline.
+
+    This helper mirrors the direct-iteration path that the old
+    ``MIMIC4FHIRDataset.gather_samples`` provided. It is intentionally kept
+    here (the shared fixture module) so all FHIR test files can import it
+    rather than each maintaining their own copy.
+
+    Args:
+        ds: A :class:`~pyhealth.datasets.MIMIC4FHIRDataset` instance whose
+            ``global_event_df`` has already been built.
+        task: A :class:`~pyhealth.tasks.MPFClinicalPredictionTask` instance.
+
+    Returns:
+        Flat list of sample dicts, one per patient.
+ """ + task._specials = None + task.frozen_vocab = False + return [s for patient in ds.iter_patients() for s in task(patient)] diff --git a/tests/core/test_mpf_task.py b/tests/core/test_mpf_task.py new file mode 100644 index 000000000..947b284c1 --- /dev/null +++ b/tests/core/test_mpf_task.py @@ -0,0 +1,88 @@ +import shutil +import tempfile +import unittest +from pathlib import Path + +from pyhealth.datasets import MIMIC4FHIRDataset +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + +from tests.core.test_mimic4_fhir_ndjson_fixtures import run_task, write_two_class_ndjson + + +class TestMPFClinicalPredictionTask(unittest.TestCase): + def _two_patient_ds(self) -> MIMIC4FHIRDataset: + tmp = tempfile.mkdtemp() + self.addCleanup(lambda p=tmp: shutil.rmtree(p, ignore_errors=True)) + write_two_class_ndjson(Path(tmp)) + return MIMIC4FHIRDataset( + root=tmp, glob_pattern="*.ndjson", cache_dir=tmp + ) + + def test_max_len_validation(self) -> None: + with self.assertRaises(ValueError): + MPFClinicalPredictionTask(max_len=1, use_mpf=True) + + def test_mpf_sets_boundary_tokens(self) -> None: + task = MPFClinicalPredictionTask(max_len=32, use_mpf=True) + ds = self._two_patient_ds() + samples = run_task(ds, task) + vocab = task.vocab + self.assertGreater(len(samples), 0) + s0 = samples[0] + mor = vocab[""] + reg = vocab[""] + pad_id = vocab.pad_id + ids = s0["concept_ids"] + first = next(i for i, x in enumerate(ids) if x != pad_id) + last_nz = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id) + self.assertEqual(ids[first], mor) + self.assertEqual(ids[last_nz], reg) + self.assertEqual(ids[-1], reg) + + def test_no_mpf_uses_cls_reg(self) -> None: + task = MPFClinicalPredictionTask(max_len=32, use_mpf=False) + ds = self._two_patient_ds() + samples = run_task(ds, task) + vocab = task.vocab + s0 = samples[0] + cls_id = vocab[""] + reg_id = vocab[""] + pad_id = vocab.pad_id + ids = s0["concept_ids"] + first = next(i for i, x in 
enumerate(ids) if x != pad_id) + last_nz = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id) + self.assertEqual(ids[first], cls_id) + self.assertEqual(ids[last_nz], reg_id) + self.assertEqual(ids[-1], reg_id) + + def test_schema_keys(self) -> None: + task = MPFClinicalPredictionTask(max_len=16, use_mpf=True) + ds = self._two_patient_ds() + samples = run_task(ds, task) + for k in task.input_schema: + self.assertIn(k, samples[0]) + self.assertIn("label", samples[0]) + + def test_max_len_two_keeps_boundary_tokens(self) -> None: + """``clinical_cap=0`` must yield ``[, ]`` left-padded, not truncated.""" + + task = MPFClinicalPredictionTask(max_len=2, use_mpf=True) + ds = self._two_patient_ds() + samples = run_task(ds, task) + vocab = task.vocab + mor = vocab[""] + reg = vocab[""] + pad_id = vocab.pad_id + for s in samples: + ids = s["concept_ids"] + first = next(i for i, x in enumerate(ids) if x != pad_id) + last_nz = next( + i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id + ) + self.assertEqual(ids[first], mor) + self.assertEqual(ids[last_nz], reg) + self.assertEqual(ids[-1], reg) + + +if __name__ == "__main__": + unittest.main()
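The boundary-token tests in this patch all assert the same sample layout: the sequence starts (after padding) with a task token, ends with a regression/readout token, and is padded on the *left* so the readout token stays rightmost. As a standalone sketch of that invariant — using illustrative token ids and helper names, not PyHealth's actual vocabulary or `_left_pad_int` implementation:

```python
# Sketch of the left-padded boundary-token layout the tests assert.
# PAD/MOR/REG ids are hypothetical; the real ids come from ConceptVocab.
from typing import List

PAD, MOR, REG = 0, 1, 2  # assumed special-token ids for illustration


def left_pad(seq: List[int], max_len: int, pad_value: int) -> List[int]:
    """Truncate from the left, then pad on the left to exactly max_len."""
    seq = seq[-max_len:]
    return [pad_value] * (max_len - len(seq)) + seq


def wrap_and_pad(clinical_ids: List[int], max_len: int) -> List[int]:
    # Reserve room for the two boundary tokens; with max_len=2 the clinical
    # cap is 0 and the result must still be [MOR, REG], never truncated.
    clinical_cap = max(max_len - 2, 0)
    body = clinical_ids[-clinical_cap:] if clinical_cap else []
    return left_pad([MOR] + body + [REG], max_len, PAD)


ids = wrap_and_pad([7, 8, 9], max_len=8)
# First non-pad position holds MOR; the final position holds REG.
```

Note the `if clinical_cap else []` guard: a bare `clinical_ids[-0:]` would return the whole list, which is exactly the truncation bug `test_max_len_two_keeps_boundary_tokens` is written to catch.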