diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b02439d26..ea772bfec 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -224,6 +224,7 @@ Available Datasets
datasets/pyhealth.datasets.SampleDataset
datasets/pyhealth.datasets.MIMIC3Dataset
datasets/pyhealth.datasets.MIMIC4Dataset
+ datasets/pyhealth.datasets.MIMIC4FHIRDataset
datasets/pyhealth.datasets.MedicalTranscriptionsDataset
datasets/pyhealth.datasets.CardiologyDataset
datasets/pyhealth.datasets.eICUDataset
diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
new file mode 100644
index 000000000..1a19fc5e9
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
@@ -0,0 +1,70 @@
+pyhealth.datasets.MIMIC4FHIRDataset
+=====================================
+
+`MIMIC-IV on FHIR <https://physionet.org/content/mimic-iv-fhir/>`_ NDJSON ingest
+for CEHR-style token sequences used with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.models.EHRMambaCEHR`.
+
+YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml``. Unlike the
+earlier nested-object approach, the YAML now declares a normal ``tables:``
+schema for flattened FHIR resources (``patient``, ``encounter``, ``condition``,
+``observation``, ``medication_request``, ``procedure``). The class subclasses
+:class:`~pyhealth.datasets.BaseDataset` and builds a standard Polars
+``global_event_df`` backed by cached Parquet (``global_event_df.parquet/part-*.parquet``),
+same tabular path as other datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`,
+:meth:`iter_patients`, :meth:`get_patient`, etc.
+
+**Ingest (out-of-core).** Matching ``*.ndjson`` / ``*.ndjson.gz`` files are read
+**line by line**; each resource is normalized into a flattened per-resource
+Parquet table under ``cache/flattened_tables/``. Those tables are then fed
+through the regular YAML-driven :class:`~pyhealth.datasets.BaseDataset` loader to
+materialize ``global_event_df``. This keeps FHIR aligned with PyHealth's usual
+table-first pipeline instead of reparsing nested JSON per patient downstream.
+
+**``max_patients``.** When set, the loader selects the first *N* patient ids after
+a **sorted** ``unique`` over the flattened patient table, filters every
+normalized table to that cohort, and then builds ``global_event_df`` from the
+filtered tables. Ingest still scans all matching NDJSON once unless you also
+override ``glob_patterns`` / ``glob_pattern`` (defaults skip non-flattened PhysioNet shards).
+
+**Downstream memory (still important).** Streaming ingest avoids loading the
+entire NDJSON corpus into RAM at once, but other steps can still be heavy on
+large cohorts: ``global_event_df`` materialization, MPF vocabulary warmup, and
+:meth:`set_task` still walk patients and samples; training needs RAM/VRAM for the
+model and batches. For a **full** PhysioNet tree, plan for **large disk**
+(flattened tables plus event cache), **comfortable system RAM** for Polars/PyArrow
+and task pipelines, and restrict ``glob_patterns`` / ``glob_pattern`` or ``max_patients`` when
+prototyping on a laptop.
+
+**Recommended hardware (informal)**
+
+Order-of-magnitude guides, not guarantees. Ingest footprint is **much smaller**
+than “load everything into Python”; wall time still grows with **decompressed
+NDJSON volume** and the amount of flattened table data produced.
+
+* **Smoke / CI**
+ Small on-disk fixtures (see tests and ``examples/mimic4fhir_mpf_ehrmamba.py``):
+ a recent laptop is sufficient.
+
+* **Laptop-scale real FHIR subset**
+ A **narrow** ``glob_patterns`` / ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps
+ cache and task passes manageable. **≥ 16 GB** system RAM is a practical
+ comfort target for Polars + trainer + OS; validate GPU **VRAM** for your
+ ``max_len`` and batch size.
+
+* **Full default globs on a complete export**
+ Favor **workstations or servers** with **fast SSD**, **large disk**, and
+ **ample RAM** for downstream steps—not because NDJSON is fully buffered in
+ memory during ingest, but because total work and caches still scale with the
+ full dataset.
+
+.. autoclass:: pyhealth.datasets.MIMIC4FHIRDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: pyhealth.datasets.ConceptVocab
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..c7e9f2729 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -185,6 +185,7 @@ API Reference
models/pyhealth.models.MoleRec
models/pyhealth.models.Deepr
models/pyhealth.models.EHRMamba
+ models/pyhealth.models.EHRMambaCEHR
models/pyhealth.models.JambaEHR
models/pyhealth.models.ContraWR
models/pyhealth.models.SparcNet
diff --git a/docs/api/models/pyhealth.models.EHRMambaCEHR.rst b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
new file mode 100644
index 000000000..79466cad3
--- /dev/null
+++ b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
@@ -0,0 +1,12 @@
+pyhealth.models.EHRMambaCEHR
+===================================
+
+EHRMambaCEHR applies CEHR-style embeddings (:class:`~pyhealth.models.cehr_embeddings.MambaEmbeddingsForCEHR`)
+and a stack of :class:`~pyhealth.models.MambaBlock` layers to a single FHIR token stream, for use with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.datasets.mimic4_fhir.MIMIC4FHIRDataset`.
+
+.. autoclass:: pyhealth.models.EHRMambaCEHR
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 399b8f1aa..83790ca44 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -214,6 +214,7 @@ Available Tasks
Drug Recommendation
Length of Stay Prediction
Medical Transcriptions Classification
+ MPF Clinical Prediction (FHIR)
Mortality Prediction (Next Visit)
Mortality Prediction (StageNet MIMIC-IV)
Patient Linkage (MIMIC-III)
diff --git a/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
new file mode 100644
index 000000000..ff66deb08
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
@@ -0,0 +1,12 @@
+pyhealth.tasks.mpf_clinical_prediction
+======================================
+
+Multitask Prompted Fine-tuning (MPF) style binary clinical prediction on FHIR
+token timelines, paired with :class:`~pyhealth.datasets.MIMIC4FHIRDataset` and
+:class:`~pyhealth.models.EHRMambaCEHR`. Based on CEHR / EHRMamba ideas; see the
+paper linked in the course replication PR.
+
+.. autoclass:: pyhealth.tasks.MPFClinicalPredictionTask
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py
new file mode 100644
index 000000000..e53098779
--- /dev/null
+++ b/examples/mimic4fhir_mpf_ehrmamba.py
@@ -0,0 +1,868 @@
+"""EHRMambaCEHR on MIMIC-IV FHIR NDJSON with MPF clinical prediction (ablations).
+
+Replication target: EHRMamba / CEHR-style modeling on tokenized FHIR timelines
+(e.g. `arXiv:2405.14567 <https://arxiv.org/abs/2405.14567>`_). This script is
+runnable end-to-end on **synthetic** NDJSON (``--quick-test``) or on
+credentialled MIMIC-IV on FHIR from PhysioNet.
+
+Experimental setup (for write-ups / PR):
+ * **Data**: Synthetic two-patient NDJSON (``--quick-test``) or disk NDJSON
+ under ``MIMIC4_FHIR_ROOT`` / ``--fhir-root``.
+ * **Task ablations**: ``max_len`` (context window), ``use_mpf`` vs generic
+ ``<CLS>``/``<REG>`` boundaries (``--no-mpf``).
+ * **Model ablations**: ``hidden_dim`` (embedding width); optional dropout
+ fixed at 0.1 in this script.
+ * **Train**: Adam via :class:`~pyhealth.trainer.Trainer`, monitor ROC-AUC,
+ report test ROC-AUC / PR-AUC.
+
+**Ablation mode** (``--ablation``): sweeps a small grid on synthetic data only,
+trains 1 epoch per config, and prints a comparison table. Use this to document
+how task/model knobs affect metrics on the minimal fixture before scaling to
+real FHIR.
+
+**Findings** (fill in after your runs; synthetic runs are noisy):
+ On ``--quick-test`` data, longer ``max_len`` and MPF specials typically
+ change logits enough to move AUC slightly; real MIMIC-IV FHIR runs are
+ needed for conclusive comparisons. Paste your table from ``--ablation``
+ into the PR description.
+
+**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON into
+flattened per-resource Parquet tables (bounded RAM during ingest). This example trains via
+``dataset.set_task(MPFClinicalPredictionTask)`` → LitData-backed
+:class:`~pyhealth.datasets.sample_dataset.SampleDataset` →
+:class:`~pyhealth.trainer.Trainer` (PyHealth’s standard path), instead of
+materializing all samples with ``gather_samples()``. Prefer ``--max-patients`` to
+ bound ingest when possible. Very large cohorts still need RAM/disk for task
+caches and MPF vocabulary warmup.
+
+**Offline flattened tables (NDJSON normalization already done):** pass
+``--prebuilt-global-event-dir`` pointing at a directory containing the normalized
+FHIR tables (``patient.parquet``, ``encounter.parquet``, ``condition.parquet``,
+etc.). The example seeds ``flattened_tables/`` under the usual PyHealth cache UUID,
+then lets :class:`~pyhealth.datasets.BaseDataset` rebuild
+``global_event_df.parquet/`` from those tables — the downstream path is still
+``global_event_df`` → :class:`~pyhealth.data.Patient` →
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` →
+:class:`~pyhealth.trainer.Trainer`. Use ``--fhir-root`` / ``--glob-pattern`` /
+``--max-patients -1`` matching the ingest fingerprint.
+``--train-patient-cap`` restricts task transforms via ``task.pre_filter`` using a
+label-aware deterministic patient subset. The full ``unique_patient_ids`` scan and MPF vocab warmup
+in the dataset still walk the cached cohort.
+
+**Approximate minimum specs** (``--quick-test``, CPU, synthetic 2-patient
+fixture; measured once on macOS/arm64 with ``/usr/bin/time -l``): peak RSS
+~**600–700 MiB**, wall **~10–15 s** for two short epochs. Real NDJSON/GZ at scale
+needs proportionally more RAM, disk, and time; GPU helps training, not the
+current all-in-RAM parse.
+
+Usage:
+ cd PyHealth && PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test
+ PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py --quick-test --ablation
+ export MIMIC4_FHIR_ROOT=/path/to/fhir
+ pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py --fhir-root "$MIMIC4_FHIR_ROOT"
+
+ # Prebuilt flattened FHIR tables (skip NDJSON normalization); cap patients for a smoke train
+ pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py \\
+ --prebuilt-global-event-dir /path/to/flattened_table_dir \\
+ --fhir-root /same/as/ndjson/ingest/root --glob-pattern 'Mimic*.ndjson.gz' --max-patients -1 \\
+ --train-patient-cap 2048 --epochs 2 \\
+ --ntfy-url 'https://ntfy.sh/your-topic'
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import random
+import re
+import shutil
+import sys
+import tempfile
+import time
+import urllib.error
+import urllib.request
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+_parser = argparse.ArgumentParser(description="EHRMambaCEHR on MIMIC-IV FHIR (MPF)")
+_parser.add_argument(
+ "--gpu",
+ type=int,
+ default=None,
+ help="GPU index; sets CUDA_VISIBLE_DEVICES.",
+)
+_parser.add_argument(
+ "--fhir-root",
+ type=str,
+ default=None,
+ help="Root directory with NDJSON (default: MIMIC4_FHIR_ROOT env).",
+)
+_parser.add_argument(
+ "--glob-pattern",
+ type=str,
+ default=None,
+ help=(
+ "Override glob for NDJSON/NDJSON.GZ (default: yaml **/*.ndjson.gz). "
+ "Use a narrow pattern to limit ingest time and cache size."
+ ),
+)
+_parser.add_argument(
+ "--max-len",
+ type=int,
+ default=512,
+ help="Sequence length ablation (e.g. 512 / 1024 / 2048 per proposal).",
+)
+_parser.add_argument(
+ "--no-mpf",
+ action="store_true",
+ help="Ablation: use generic CLS/REG specials instead of task MPF tokens.",
+)
+_parser.add_argument(
+ "--hidden-dim",
+ type=int,
+ default=128,
+ help="Embedding / hidden size (model ablation).",
+)
+_parser.add_argument(
+ "--lr",
+ type=float,
+ default=1e-3,
+ help="Adam learning rate (trainer.train optimizer_params).",
+)
+_parser.add_argument(
+ "--quick-test",
+ action="store_true",
+ help="Use synthetic in-memory FHIR lines only (no disk root).",
+)
+_parser.add_argument(
+ "--ablation",
+ action="store_true",
+ help="Run a small max_len × MPF × hidden_dim grid on synthetic data; print table.",
+)
+_parser.add_argument(
+ "--epochs",
+ type=int,
+ default=None,
+ help="Training epochs (default: 2 with --quick-test, else 20).",
+)
+_parser.add_argument(
+ "--max-patients",
+ type=int,
+ default=500,
+ help=(
+ "Fingerprint for cache dir: cap patients during normalization (-1 = full cohort, "
+ "match an uncapped NDJSON→flattened-table export)."
+ ),
+)
+_parser.add_argument(
+ "--prebuilt-global-event-dir",
+ type=str,
+ default=None,
+ help=(
+ "Directory with normalized flattened FHIR tables (*.parquet). Seeds "
+ "cache/flattened_tables/ so training skips NDJSON normalization "
+ "(downstream unchanged: Patient + MPF + Trainer)."
+ ),
+)
+_parser.add_argument(
+ "--ingest-num-shards",
+ type=int,
+ default=None,
+ help="Compatibility no-op: retained for CLI stability with older runs.",
+)
+_parser.add_argument(
+ "--train-patient-cap",
+ type=int,
+ default=None,
+ help=(
+ "After cache is ready, only build samples from a deterministic label-aware "
+ "patient subset of size N (reduces train time; unique-id scan of "
+ "global_event_df still runs once)."
+ ),
+)
+_parser.add_argument(
+ "--ntfy-url",
+ type=str,
+ default=None,
+ help="POST notification when main() finishes (e.g. https://ntfy.sh/topic).",
+)
+_parser.add_argument(
+ "--loss-plot-path",
+ type=str,
+ default=None,
+ help="Write loss curve PNG here (default: alongside Trainer log under output/).",
+)
+_parser.add_argument(
+ "--cache-dir",
+ type=str,
+ default=None,
+ help="PyHealth dataset cache parent (UUID subdir added by MIMIC4FHIRDataset).",
+)
+_parser.add_argument(
+ "--task-num-workers",
+ type=int,
+ default=None,
+ help=(
+ "Workers for LitData task/processor transforms (default: dataset "
+ "``num_workers``, usually 1)."
+ ),
+)
+_pre_args, _ = _parser.parse_known_args()
+if _pre_args.gpu is not None:
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(_pre_args.gpu)
+
+import torch
+
+import polars as pl
+
+from pyhealth.datasets import MIMIC4FHIRDataset, get_dataloader
+from pyhealth.processors.cehr_processor import infer_mortality_label
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+from pyhealth.trainer import Trainer
+
+
+class PatientCappedMPFTask(MPFClinicalPredictionTask):
+ """Example-only: limit task transform to an explicit patient_id allow-list."""
+
+ def __init__(
+ self,
+ *,
+ max_len: int,
+ use_mpf: bool,
+ patient_ids_allow: List[str],
+ ) -> None:
+ super().__init__(max_len=max_len, use_mpf=use_mpf)
+ self.patient_ids_allow = patient_ids_allow
+
+ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ return df.filter(pl.col("patient_id").is_in(self.patient_ids_allow))
+
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+SEED = 42
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+BATCH_SIZE = 8
+EPOCHS = 20
+SPLIT_RATIOS = (0.7, 0.1, 0.2)
+
+
+def _max_patients_arg(v: int) -> Optional[int]:
+ return None if v is not None and v < 0 else v
+
+
+def _seed_flattened_table_cache(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None:
+ """Copy normalized per-resource parquet tables into the dataset cache."""
+
+ tables = sorted(prebuilt_dir.glob("*.parquet"))
+ if not tables:
+ raise FileNotFoundError(
+ f"No *.parquet tables under {prebuilt_dir} — expected flattened FHIR tables."
+ )
+ prepared = ds.prepared_tables_dir
+ if prepared.exists() and any(prepared.glob("*.parquet")):
+ return
+ prepared.mkdir(parents=True, exist_ok=True)
+ for src in tables:
+ dest = prepared / src.name
+ if dest.exists():
+ continue
+ try:
+ os.link(src, dest)
+ except OSError:
+ shutil.copy2(src, dest)
+
+
+def _parse_train_losses_from_log(log_path: Path) -> List[float]:
+ """Mean training loss per epoch from Trainer file log."""
+
+ if not log_path.is_file():
+ return []
+ text = log_path.read_text(encoding="utf-8", errors="replace")
+ losses: List[float] = []
+ lines = text.splitlines()
+ for i, line in enumerate(lines):
+ if "--- Train epoch-" in line and i + 1 < len(lines):
+ m = re.search(r"loss:\s*([0-9.eE+-]+)", lines[i + 1])
+ if m:
+ losses.append(float(m.group(1)))
+ return losses
+
+
+def _write_loss_plot(losses: List[float], out_path: Path) -> None:
+ if not losses:
+ return
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ try:
+ import matplotlib.pyplot as plt
+ except ImportError:
+ csv_path = out_path.with_suffix(".csv")
+ csv_path.write_text(
+ "epoch,train_loss_mean\n"
+ + "\n".join(f"{i},{v}" for i, v in enumerate(losses)),
+ encoding="utf-8",
+ )
+ print(
+ "matplotlib not installed; wrote", csv_path, "(pip install matplotlib for PNG)"
+ )
+ return
+ plt.figure(figsize=(6, 3.5))
+ plt.plot(range(len(losses)), losses, marker="o", linewidth=1)
+ plt.xlabel("epoch")
+ plt.ylabel("mean train loss")
+ plt.title("EHRMambaCEHR training loss (MPF)")
+ plt.grid(True, alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(out_path, dpi=120)
+ plt.close()
+ print("loss plot:", out_path)
+
+
+def _ntfy(url: str, title: str, message: str) -> None:
+ try:
+ req = urllib.request.Request(
+ url,
+ data=message.encode("utf-8"),
+ method="POST",
+ )
+ req.add_header("Title", title[:200])
+ with urllib.request.urlopen(req, timeout=60) as resp:
+ if resp.status >= 400:
+ print("ntfy HTTP", resp.status, file=sys.stderr)
+ except urllib.error.URLError as e:
+ print("ntfy failed:", e, file=sys.stderr)
+
+
+def _quick_test_ndjson_dir() -> str:
+ """Write two-patient synthetic NDJSON; returns temp directory (caller cleans up)."""
+ import orjson
+
+ _resources = [
+ {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"},
+ {
+ "resourceType": "Encounter", "id": "e1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "period": {"start": "2020-06-01T10:00:00Z"}, "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Condition", "id": "c1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "encounter": {"reference": "Encounter/e1"},
+ "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]},
+ "onsetDateTime": "2020-06-01T11:00:00Z",
+ },
+ {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True},
+ {
+ "resourceType": "Encounter", "id": "e-dead",
+ "subject": {"reference": "Patient/p-synth-2"},
+ "period": {"start": "2020-07-01T10:00:00Z"}, "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Observation", "id": "o-dead",
+ "subject": {"reference": "Patient/p-synth-2"},
+ "encounter": {"reference": "Encounter/e-dead"},
+ "effectiveDateTime": "2020-07-01T12:00:00Z",
+ "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]},
+ },
+ ]
+ ndjson_text = "\n".join(orjson.dumps(r).decode("utf-8") for r in _resources) + "\n"
+
+ tmp = tempfile.mkdtemp(prefix="pyhealth_mimic4_fhir_quick_")
+ Path(tmp, "fixture.ndjson").write_text(ndjson_text, encoding="utf-8")
+ return tmp
+
+
+def _patient_label(ds: MIMIC4FHIRDataset, patient_id: str) -> int:
+ patient = ds.get_patient(patient_id)
+ return int(infer_mortality_label(patient))
+
+
+def _ensure_binary_label_coverage(ds: MIMIC4FHIRDataset) -> None:
+ found: Dict[int, str] = {}
+ scanned = 0
+ for patient_id in ds.unique_patient_ids:
+ label = _patient_label(ds, patient_id)
+ scanned += 1
+ found.setdefault(label, patient_id)
+ if len(found) == 2:
+ print(
+ "label preflight:",
+ {"scanned_patients": scanned, "example_patient_ids": found},
+ )
+ return
+ raise SystemExit(
+ "Binary mortality example found only one label in the available cohort; "
+ "cannot build a valid binary training set from this cache."
+ )
+
+
+def _select_patient_ids_for_cap(
+ ds: MIMIC4FHIRDataset, requested_cap: int
+) -> List[str]:
+ patient_ids = ds.unique_patient_ids
+ if not patient_ids:
+ return []
+
+ desired = max(2, requested_cap)
+ desired = min(desired, len(patient_ids))
+ if desired < requested_cap:
+ print(
+ f"train_patient_cap requested {requested_cap}, but only {desired} patients are available."
+ )
+ elif requested_cap < 2:
+ print(
+ f"train_patient_cap={requested_cap} is too small for binary labels; using {desired}."
+ )
+
+ encountered: List[str] = []
+ label_by_patient_id: Dict[str, int] = {}
+ first_by_label: Dict[int, str] = {}
+ for patient_id in patient_ids:
+ label = _patient_label(ds, patient_id)
+ encountered.append(patient_id)
+ label_by_patient_id[patient_id] = label
+ first_by_label.setdefault(label, patient_id)
+ if len(encountered) >= desired and len(first_by_label) == 2:
+ break
+
+ if len(first_by_label) < 2:
+ raise SystemExit(
+ "Unable to satisfy --train-patient-cap with both binary labels from the "
+ "available cohort. Use a different cache/export or remove the cap."
+ )
+
+ selected = encountered[:desired]
+ selected_labels = {label_by_patient_id[pid] for pid in selected}
+ if len(selected_labels) == 1:
+ missing_label = 1 - next(iter(selected_labels))
+ replacement = first_by_label[missing_label]
+ for idx in range(len(selected) - 1, -1, -1):
+ if label_by_patient_id[selected[idx]] != missing_label:
+ selected[idx] = replacement
+ break
+
+ counts = {
+ 0: sum(1 for pid in selected if label_by_patient_id[pid] == 0),
+ 1: sum(1 for pid in selected if label_by_patient_id[pid] == 1),
+ }
+ print(
+ "train_patient_cap selection:",
+ {
+ "requested": requested_cap,
+ "selected": len(selected),
+ "scanned_patients": len(encountered),
+ "label_counts": counts,
+ },
+ )
+ return selected
+
+
+def _sample_label(sample: Dict[str, Any]) -> int:
+ label = sample["label"]
+ if isinstance(label, torch.Tensor):
+ return int(label.reshape(-1)[0].item())
+ return int(label)
+
+
+def _split_counts(n: int) -> List[int]:
+ if n < 3:
+ raise ValueError("Need at least 3 samples for three-way stratified split.")
+ counts = [1, 1, 1]
+ remaining = n - 3
+ raw = [ratio * remaining for ratio in SPLIT_RATIOS]
+ floors = [int(x) for x in raw]
+ for i, floor in enumerate(floors):
+ counts[i] += floor
+ assigned = sum(counts)
+ order = sorted(
+ range(3),
+ key=lambda i: raw[i] - floors[i],
+ reverse=True,
+ )
+ for i in order:
+ if assigned >= n:
+ break
+ counts[i] += 1
+ assigned += 1
+ counts[0] += n - assigned
+ return counts
+
+
+def _split_sample_dataset_for_binary_metrics(sample_ds: Any) -> tuple[Any, Any, Any]:
+ if len(sample_ds) < 8:
+ print("sample count < 8; reusing the full dataset for train/val/test.")
+ return sample_ds, sample_ds, sample_ds
+
+ label_to_indices: Dict[int, List[int]] = {0: [], 1: []}
+ for idx in range(len(sample_ds)):
+ label_to_indices[_sample_label(sample_ds[idx])].append(idx)
+
+ label_counts = {label: len(indices) for label, indices in label_to_indices.items()}
+ min_count = min(label_counts.values())
+ if min_count < 3:
+ print(
+ "label distribution too small for disjoint binary train/val/test splits; "
+ "reusing the full dataset for train/val/test.",
+ label_counts,
+ )
+ return sample_ds, sample_ds, sample_ds
+
+ rng = random.Random(SEED)
+ split_indices: List[List[int]] = [[], [], []]
+ for indices in label_to_indices.values():
+ shuffled = indices[:]
+ rng.shuffle(shuffled)
+ n_train, n_val, n_test = _split_counts(len(shuffled))
+ split_indices[0].extend(shuffled[:n_train])
+ split_indices[1].extend(shuffled[n_train : n_train + n_val])
+ split_indices[2].extend(shuffled[n_train + n_val : n_train + n_val + n_test])
+
+ for indices in split_indices:
+ indices.sort()
+
+ split_counts = []
+ for indices in split_indices:
+ split_counts.append(
+ {
+ 0: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 0),
+ 1: sum(1 for idx in indices if _sample_label(sample_ds[idx]) == 1),
+ "n": len(indices),
+ }
+ )
+ print(
+ "binary stratified split counts:",
+ {"train": split_counts[0], "val": split_counts[1], "test": split_counts[2]},
+ )
+ return (
+ sample_ds.subset(split_indices[0]),
+ sample_ds.subset(split_indices[1]),
+ sample_ds.subset(split_indices[2]),
+ )
+
+
+def _build_loaders_from_sample_dataset(
+ sample_ds: Any,
+ vocab_size: int,
+) -> tuple[Any, Any, Any, Any, int]:
+ train_ds, val_ds, test_ds = _split_sample_dataset_for_binary_metrics(sample_ds)
+ train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
+ val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
+ test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
+ return sample_ds, train_loader, val_loader, test_loader, vocab_size
+
+
+def run_single_train(
+ *,
+ fhir_root: str,
+ max_len: int,
+ use_mpf: bool,
+ hidden_dim: int,
+ epochs: int,
+ lr: float = 1e-3,
+ glob_pattern: str = "*.ndjson",
+ cache_dir: Optional[str] = None,
+ dataset_max_patients: Optional[int] = 500,
+ ingest_num_shards: Optional[int] = None,
+ prebuilt_global_event_dir: Optional[str] = None,
+ train_patient_cap: Optional[int] = None,
+) -> Dict[str, float]:
+ """Train/eval one configuration; returns test metrics (floats)."""
+
+ ds_kw: Dict[str, Any] = {
+ "root": fhir_root,
+ "glob_pattern": glob_pattern,
+ "cache_dir": cache_dir,
+ "max_patients": dataset_max_patients,
+ }
+ if ingest_num_shards is not None:
+ ds_kw["ingest_num_shards"] = ingest_num_shards
+ ds = MIMIC4FHIRDataset(**ds_kw)
+ if prebuilt_global_event_dir:
+ _seed_flattened_table_cache(
+ Path(prebuilt_global_event_dir).expanduser().resolve(), ds
+ )
+ if train_patient_cap is not None:
+ allow = _select_patient_ids_for_cap(ds, train_patient_cap)
+ task: MPFClinicalPredictionTask = PatientCappedMPFTask(
+ max_len=max_len,
+ use_mpf=use_mpf,
+ patient_ids_allow=allow,
+ )
+ else:
+ _ensure_binary_label_coverage(ds)
+ task = MPFClinicalPredictionTask(max_len=max_len, use_mpf=use_mpf)
+ sample_ds = ds.set_task(task, num_workers=1)
+ vocab_size = ds.vocab.vocab_size
+ sample_ds, train_l, val_l, test_l, vocab_size = _build_loaders_from_sample_dataset(
+ sample_ds, vocab_size
+ )
+ model = EHRMambaCEHR(
+ dataset=sample_ds,
+ vocab_size=vocab_size,
+ embedding_dim=hidden_dim,
+ num_layers=2,
+ dropout=0.1,
+ )
+ trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE)
+ trainer.train(
+ train_dataloader=train_l,
+ val_dataloader=val_l,
+ epochs=epochs,
+ monitor="roc_auc",
+ optimizer_params={"lr": lr},
+ )
+ results = trainer.evaluate(test_l)
+ return {k: float(v) for k, v in results.items()}
+
+
+def run_ablation_table(*, lr: float = 1e-3) -> None:
+ """Task × model grid on synthetic NDJSON (short runs for comparison)."""
+
+ # Ablations: context length, MPF vs CLS/REG, plus one hidden_dim pair.
+ grid = [
+ (32, True, 32),
+ (32, False, 32),
+ (96, True, 64),
+ (96, False, 64),
+ ]
+ tmp = _quick_test_ndjson_dir()
+ try:
+ print(
+ "Ablation (synthetic, 1 epoch each): max_len, use_mpf, hidden_dim, lr="
+ f"{lr} -> test roc_auc, pr_auc"
+ )
+ rows = []
+ t0 = time.perf_counter()
+ for max_len, use_mpf, hidden_dim in grid:
+ metrics = run_single_train(
+ fhir_root=tmp,
+ max_len=max_len,
+ use_mpf=use_mpf,
+ hidden_dim=hidden_dim,
+ epochs=1,
+ lr=lr,
+ cache_dir=tmp,
+ dataset_max_patients=500,
+ )
+ rows.append((max_len, use_mpf, hidden_dim, metrics))
+ print(
+ f" max_len={max_len} mpf={use_mpf} hid={hidden_dim} -> "
+ f"roc_auc={metrics['roc_auc']:.4f} pr_auc={metrics['pr_auc']:.4f}"
+ )
+ print("ablation_wall_s:", round(time.perf_counter() - t0, 2))
+ best = max(rows, key=lambda r: r[3]["roc_auc"])
+ print(
+ "best_by_roc_auc:",
+ {
+ "max_len": best[0],
+ "use_mpf": best[1],
+ "hidden_dim": best[2],
+ "metrics": best[3],
+ },
+ )
+ except Exception:
+ print(f"ablation: leaving scratch directory for debugging: {tmp}", file=sys.stderr)
+ raise
+ else:
+ shutil.rmtree(tmp, ignore_errors=True)
+
+
+def main() -> None:
+ args = _parser.parse_args()
+ status = "abort"
+ ntfy_detail = ""
+ try:
+ _main_train(args)
+ status = "ok"
+ ntfy_detail = "Training finished successfully."
+ except SystemExit as e:
+ status = "exit"
+ ntfy_detail = f"SystemExit {e.code!r}"
+ raise
+ except Exception as e:
+ status = "error"
+ ntfy_detail = f"{type(e).__name__}: {e}"[:3800]
+ raise
+ finally:
+ if args.ntfy_url and status in ("ok", "error"):
+ _ntfy(
+ args.ntfy_url,
+ "mimic-fhir-train OK" if status == "ok" else "mimic-fhir-train FAIL",
+ ntfy_detail,
+ )
+
+
+def _main_train(args: argparse.Namespace) -> None:
+ fhir_root = args.fhir_root or os.environ.get("MIMIC4_FHIR_ROOT")
+ quick = args.quick_test
+ quick_test_tmp: Optional[str] = None
+ if args.epochs is not None:
+ epochs = args.epochs
+ else:
+ epochs = 2 if quick else EPOCHS
+
+ if args.ablation:
+ if not quick:
+ raise SystemExit("--ablation requires --quick-test (synthetic data only).")
+ run_ablation_table(lr=args.lr)
+ return
+
+ print("EHRMambaCEHR – MIMIC-IV FHIR (MPF clinical prediction)")
+ print("device:", DEVICE)
+ print("max_len:", args.max_len, "| use_mpf:", not args.no_mpf)
+ print("hidden_dim:", args.hidden_dim, "| lr:", args.lr)
+
+ sample_ds: Any
+ vocab: Any
+
+ if quick:
+ quick_test_tmp = _quick_test_ndjson_dir()
+ ds = MIMIC4FHIRDataset(
+ root=quick_test_tmp,
+ glob_pattern="*.ndjson",
+ cache_dir=quick_test_tmp,
+ max_patients=500,
+ )
+ try:
+ print(
+ "pipeline: synthetic NDJSON → flattened tables → global_event_df "
+ "→ set_task → SampleDataset → Trainer"
+ )
+ task = MPFClinicalPredictionTask(
+ max_len=args.max_len,
+ use_mpf=not args.no_mpf,
+ )
+ print("set_task (quick-test, num_workers=1)...")
+ t_task0 = time.perf_counter()
+ sample_ds = ds.set_task(task, num_workers=1)
+ print(
+ "set_task done: n_samples=",
+ len(sample_ds),
+ "wall_s=",
+ round(time.perf_counter() - t_task0, 2),
+ )
+ vocab = ds.vocab
+ except Exception:
+ print(
+ f"quick-test: leaving NDJSON/Parquet scratch at {quick_test_tmp}",
+ file=sys.stderr,
+ )
+ raise
+ else:
+ mp = _max_patients_arg(args.max_patients)
+ if not fhir_root or not os.path.isdir(fhir_root):
+ raise SystemExit(
+ "Set MIMIC4_FHIR_ROOT or pass --fhir-root to an existing directory "
+ "(NDJSON tree for ingest fingerprint, even when using --prebuilt-global-event-dir)."
+ )
+ ds_kw: Dict[str, Any] = {
+ "root": fhir_root,
+ "max_patients": mp,
+ "cache_dir": args.cache_dir,
+ }
+ if args.glob_pattern is not None:
+ ds_kw["glob_pattern"] = args.glob_pattern
+ if args.ingest_num_shards is not None:
+ ds_kw["ingest_num_shards"] = args.ingest_num_shards
+ ds = MIMIC4FHIRDataset(**ds_kw)
+ if args.prebuilt_global_event_dir:
+ pb = Path(args.prebuilt_global_event_dir).expanduser().resolve()
+ if not pb.is_dir():
+ raise SystemExit(f"--prebuilt-global-event-dir not a directory: {pb}")
+ print(
+ "pipeline: offline flattened FHIR tables → seed flattened table cache "
+ "→ global_event_df → set_task → SampleDataset → Trainer "
+ "(no NDJSON normalization)"
+ )
+ _seed_flattened_table_cache(pb, ds)
+ else:
+ print(
+ "pipeline: NDJSON root → MIMIC4FHIRDataset flattening → global_event_df "
+ "→ set_task → SampleDataset → Trainer"
+ )
+ print("glob_pattern:", ds.glob_pattern, "| max_patients fingerprint:", mp)
+ if args.train_patient_cap is not None:
+ print("train_patient_cap:", args.train_patient_cap)
+ allow = _select_patient_ids_for_cap(ds, args.train_patient_cap)
+ mpf_task: MPFClinicalPredictionTask = PatientCappedMPFTask(
+ max_len=args.max_len,
+ use_mpf=not args.no_mpf,
+ patient_ids_allow=allow,
+ )
+ print("task patient allow-list size:", len(allow))
+ else:
+ _ensure_binary_label_coverage(ds)
+ mpf_task = MPFClinicalPredictionTask(
+ max_len=args.max_len,
+ use_mpf=not args.no_mpf,
+ )
+ nw = args.task_num_workers
+ if nw is None:
+ nw = ds.num_workers
+ print(f"set_task (LitData task cache, num_workers={nw})...")
+ t_task0 = time.perf_counter()
+ sample_ds = ds.set_task(mpf_task, num_workers=nw)
+ print(
+ "set_task done: n_samples=",
+ len(sample_ds),
+ "wall_s=",
+ round(time.perf_counter() - t_task0, 2),
+ )
+ vocab = ds.vocab
+ print("fhir_root:", fhir_root)
+
+ try:
+ if len(sample_ds) == 0:
+ raise SystemExit(
+ "No training samples (0 patients or empty sequences). "
+ "PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (see default glob_patterns in "
+ "pyhealth/datasets/configs/mimic4_fhir.yaml). If your tree is plain *.ndjson, "
+ "construct MIMIC4FHIRDataset with glob_pattern='**/*.ndjson'."
+ )
+
+ sample_ds, train_loader, val_loader, test_loader, vocab_size = (
+ _build_loaders_from_sample_dataset(sample_ds, vocab.vocab_size)
+ )
+
+ model = EHRMambaCEHR(
+ dataset=sample_ds,
+ vocab_size=vocab_size,
+ embedding_dim=args.hidden_dim,
+ num_layers=2,
+ dropout=0.1,
+ )
+ trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"], device=DEVICE)
+
+ t0 = time.perf_counter()
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ epochs=epochs,
+ monitor="roc_auc",
+ optimizer_params={"lr": args.lr},
+ )
+ results = trainer.evaluate(test_loader)
+ print("Test:", {k: float(v) for k, v in results.items()})
+ print("wall_s:", round(time.perf_counter() - t0, 1))
+ print("concept_vocab_size:", vocab.vocab_size)
+
+ log_txt = (
+ Path(trainer.exp_path) / "log.txt" if trainer.exp_path else None
+ )
+ if log_txt and log_txt.is_file():
+ losses = _parse_train_losses_from_log(log_txt)
+ print("train_loss_per_epoch:", losses)
+ plot_path = (
+ Path(args.loss_plot_path)
+ if args.loss_plot_path
+ else Path(trainer.exp_path) / "train_loss.png"
+ )
+ if trainer.exp_path:
+ _write_loss_plot(losses, plot_path)
+ finally:
+ if quick_test_tmp is not None:
+ shutil.rmtree(quick_test_tmp, ignore_errors=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pixi.lock b/pixi.lock
index 42762b62b..51e563c4d 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -96,6 +96,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -217,6 +218,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
@@ -328,6 +330,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
@@ -440,6 +443,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
@@ -597,6 +601,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -727,6 +732,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -848,6 +854,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -970,6 +977,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1158,6 +1166,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1324,6 +1333,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1474,6 +1484,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1626,6 +1637,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1777,6 +1789,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -1906,6 +1919,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -2026,6 +2040,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -2147,6 +2162,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
@@ -2289,6 +2305,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -2415,6 +2432,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
@@ -2531,6 +2549,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
@@ -2648,6 +2667,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
@@ -2792,6 +2812,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e9/e2/20a317688435470872885e7fc8f95109ae9683dec7c50be29b56911515a5/pandas-2.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -2913,6 +2934,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0f/b0/80f6ec783313f1e2356b28b4fd8d2148c378370045da918c73145e6aab50/pandas-2.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
@@ -3024,6 +3046,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/db/d8f24a7cc9fb0972adab0cc80b6817e8bef888cfd0024eeb5a21c0bb5c4a/pandas-2.3.1-cp313-cp313-macosx_11_0_arm64.whl
@@ -3136,6 +3159,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/c0/54415af59db5cdd86a3d3bf79863e8cc3fa9ed265f0745254061ac09d5f2/pandas-2.3.1-cp313-cp313-win_amd64.whl
@@ -6152,6 +6176,26 @@ packages:
purls: []
size: 9327033
timestamp: 1751392489008
+- pypi: https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl
+ name: orjson
+ version: 3.11.7
+ sha256: b9f95dcdea9d4f805daa9ddf02617a89e484c6985fa03055459f90e87d7a0757
+ requires_python: '>=3.10'
+- pypi: https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
+ name: orjson
+ version: 3.11.7
+ sha256: 814be4b49b228cfc0b3c565acf642dd7d13538f966e3ccde61f4f55be3e20785
+ requires_python: '>=3.10'
+- pypi: https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl
+ name: orjson
+ version: 3.11.7
+ sha256: 1d98b30cc1313d52d4af17d9c3d307b08389752ec5f2e5febdfada70b0f8c733
+ requires_python: '>=3.10'
+- pypi: https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+ name: orjson
+ version: 3.11.7
+ sha256: a12b80df61aab7b98b490fe9e4879925ba666fccdfcd175252ce4d9035865ace
+ requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/d3/04/7d2b9a0d1b81e30f39e6f358bac01f4f18b585f35b0ffc5c83fc274f146b/outdated-0.2.2-py2.py3-none-any.whl
name: outdated
version: 0.2.2
@@ -7082,7 +7126,7 @@ packages:
- pypi: ./
name: pyhealth
version: 2.0.1
- sha256: bf368461a8e66f93ad43f5880295045cbabe6688064af9a650a92bdaf1665332
+ sha256: 941556c467dc4bb2cbe43e6b755c9b30c080151d9506bcfe584315280889b50b
requires_dist:
- torch~=2.7.1
- torchvision
@@ -7107,6 +7151,7 @@ packages:
- more-itertools~=10.8.0
- einops>=0.8.0
- linear-attention-transformer>=0.19.1
+ - orjson~=3.10
- torch-geometric>=2.6.0 ; extra == 'graph'
- editdistance~=0.8.1 ; extra == 'nlp'
- rouge-score~=0.1.2 ; extra == 'nlp'
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 54e77670c..c41285d0a 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -59,6 +59,7 @@ def __init__(self, *args, **kwargs):
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
+from .mimic4_fhir import MIMIC4FHIRDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
diff --git a/pyhealth/datasets/configs/mimic4_fhir.yaml b/pyhealth/datasets/configs/mimic4_fhir.yaml
new file mode 100644
index 000000000..989ab79cf
--- /dev/null
+++ b/pyhealth/datasets/configs/mimic4_fhir.yaml
@@ -0,0 +1,118 @@
+# MIMIC-IV FHIR Resource Flattening Configuration
+# ================================================
+#
+# This YAML defines the normalized schema for MIMIC-IV FHIR exports after streaming
+# ingestion. Raw NDJSON/NDJSON.GZ resources are parsed and flattened into six
+# per-resource-type Parquet tables (see ``tables`` below), then loaded through the
+# standard BaseDataset pipeline for task construction.
+#
+# For ingest details, see the docstring of ``stream_fhir_ndjson_to_flat_tables()``
+# in ``pyhealth/datasets/mimic4_fhir.py``.
+
+version: "fhir_r4_flattened"
+
+# Glob Pattern(s) for NDJSON File Discovery
+# ===========================================
+#
+# ``glob_patterns`` (list) or ``glob_pattern`` (string): Patterns to match NDJSON files
+# under the ingest root directory. Patterns are applied via pathlib.Path.glob().
+#
+# Default: Six targeted patterns matching PhysioNet MIMIC-IV FHIR Mimic* shard families
+# that map to flattened tables. This avoids decompressing and parsing ~10% of PhysioNet
+# exports (MedicationAdministration, Specimen, Organization, …) that are skipped by
+# the flattener.
+#
+# Alternatives:
+# - For non-PhysioNet naming, use a single broad pattern:
+# glob_pattern: "**/*.ndjson.gz"
+# - To test on a subset, use a narrower list:
+# glob_patterns:
+# - "**/MimicPatient*.ndjson.gz"
+# - "**/MimicObservation*.ndjson.gz"
+#
+# Notes:
+# - Patterns use ``**/`` for recursive search (works in both flat and nested layouts).
+# - Can be overridden at runtime via MIMIC4FHIRDataset(glob_pattern=...) or
+# MIMIC4FHIRDataset(glob_patterns=[...]).
+
+glob_patterns:
+ - "**/MimicPatient*.ndjson.gz"
+ - "**/MimicEncounter*.ndjson.gz"
+ - "**/MimicCondition*.ndjson.gz"
+ - "**/MimicObservation*.ndjson.gz"
+ - "**/MimicMedicationRequest*.ndjson.gz"
+ - "**/MimicProcedure*.ndjson.gz"
+
+# Flattened Table Schema
+# ======================
+#
+# Each table is normalized from a single FHIR resource type. Columns are:
+# - patient_id (str): Foreign key to patient (derived from subject.reference or id).
+# - [timestamp] (str): ISO 8601 datetime string (coerced; nullable).
+# - attributes (List[str]): Additional columns from the resource.
+#
+# Unsupported resource types (Medication, MedicationAdministration, Specimen, …)
+# are silently dropped during ingest; only tables listed here are written.
+
+tables:
+ patient:
+ file_path: "patient.parquet"
+ patient_id: "patient_id"
+ timestamp: "birth_date"
+ attributes:
+ - "patient_fhir_id"
+ - "birth_date"
+ - "gender"
+ - "deceased_boolean"
+ - "deceased_datetime"
+
+ encounter:
+ file_path: "encounter.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "encounter_class"
+ - "encounter_end"
+
+ condition:
+ file_path: "condition.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ observation:
+ file_path: "observation.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ medication_request:
+ file_path: "medication_request.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ procedure:
+ file_path: "procedure.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
diff --git a/pyhealth/datasets/fhir_utils.py b/pyhealth/datasets/fhir_utils.py
new file mode 100644
index 000000000..a80c6f1fe
--- /dev/null
+++ b/pyhealth/datasets/fhir_utils.py
@@ -0,0 +1,438 @@
+"""FHIR NDJSON parsing, flattening, and Parquet table writing.
+
+Key public API
+--------------
+stream_fhir_ndjson_to_flat_tables(root, glob_pattern, out_dir)
+ Stream all matching NDJSON/NDJSON.GZ resources into six per-type Parquet tables.
+
+sorted_ndjson_files(root, glob_pattern)
+ List matching NDJSON files under root (deduplicated, sorted).
+
+filter_flat_tables_by_patient_ids(source_dir, out_dir, keep_ids)
+ Subset existing flattened tables to a specific patient cohort.
+"""
+
+from __future__ import annotations
+
+import gzip
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
+
+import orjson
+import polars as pl
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+GlobPatternArg = str | Sequence[str]
+"""Single glob string or sequence of strings for NDJSON file discovery."""
+
+__all__ = [
+ # Types
+ "GlobPatternArg",
+ # Constants
+ "FHIR_SCHEMA_VERSION",
+ "FHIR_TABLES",
+ "FHIR_TABLES_FOR_PATIENT_IDS",
+ "FHIR_TABLE_FILE_NAMES",
+ "FHIR_TABLE_COLUMNS",
+ # Datetime helpers
+ "parse_dt",
+ "as_naive",
+ # FHIR iteration
+ "iter_ndjson_objects",
+ "iter_resources_from_ndjson_obj",
+ # Resource extraction
+ "patient_id_for_resource",
+ # Pipeline
+ "sorted_ndjson_files",
+ "stream_fhir_ndjson_to_flat_tables",
+ "filter_flat_tables_by_patient_ids",
+ "sorted_patient_ids_from_flat_tables",
+]
+
+FHIR_SCHEMA_VERSION = 3
+
+FHIR_TABLES: List[str] = [
+ "patient",
+ "encounter",
+ "condition",
+ "observation",
+ "medication_request",
+ "procedure",
+]
+
+FHIR_TABLES_FOR_PATIENT_IDS: List[str] = [t for t in FHIR_TABLES if t != "patient"]
+
+FHIR_TABLE_FILE_NAMES: Dict[str, str] = {t: f"{t}.parquet" for t in FHIR_TABLES}
+
+FHIR_TABLE_COLUMNS: Dict[str, List[str]] = {
+ "patient": ["patient_id", "patient_fhir_id", "birth_date", "gender", "deceased_boolean", "deceased_datetime"],
+ "encounter": ["patient_id", "resource_id", "encounter_id", "event_time", "encounter_class", "encounter_end"],
+ "condition": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+ "observation": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+ "medication_request": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+ "procedure": ["patient_id", "resource_id", "encounter_id", "event_time", "concept_key"],
+}
+
+# ---------------------------------------------------------------------------
+# Datetime helpers (also imported by cehr_processor)
+# ---------------------------------------------------------------------------
+
+
+def parse_dt(s: Optional[str]) -> Optional[datetime]:
+ if not s:
+ return None
+ try:
+ dt = datetime.fromisoformat(s.replace("Z", "+00:00"))
+ except ValueError:
+ dt = None
+ if dt is None and len(s) >= 10:
+ try:
+ dt = datetime.strptime(s[:10], "%Y-%m-%d")
+ except ValueError:
+ return None
+ if dt is None:
+ return None
+ return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt
+
+
+def as_naive(dt: Optional[datetime]) -> Optional[datetime]:
+ if dt is None:
+ return None
+ return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt
+
+
+# ---------------------------------------------------------------------------
+# FHIR JSON helpers
+# ---------------------------------------------------------------------------
+
+
+def _coding_key(coding: Dict[str, Any]) -> str:
+ return f"{coding.get('system') or 'unknown'}|{coding.get('code') or 'unknown'}"
+
+
+def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]:
+ if not obj:
+ return None
+ codings = obj.get("coding") or []
+ if not codings and "concept" in obj:
+ codings = (obj.get("concept") or {}).get("coding") or []
+ return _coding_key(codings[0]) if codings else None
+
+
+def _ref_id(ref: Optional[str]) -> Optional[str]:
+ if not ref:
+ return None
+ return ref.rsplit("/", 1)[-1] if "/" in ref else ref
+
+
+def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]:
+ if not isinstance(raw, dict):
+ return None
+ resource = raw.get("resource") if "resource" in raw else raw
+ return resource if isinstance(resource, dict) else None
+
+
+def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+ """Yield resource dicts from one parsed NDJSON object (Bundle or bare resource)."""
+ if "entry" in obj:
+ for entry in obj.get("entry") or []:
+ resource = entry.get("resource")
+ if isinstance(resource, dict):
+ yield resource
+ return
+ resource = _unwrap_resource_dict(obj)
+ if resource is not None:
+ yield resource
+
+
+def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]:
+ """Yield parsed JSON objects from a plain or gzip-compressed NDJSON file."""
+ opener = (
+ gzip.open(path, "rt", encoding="utf-8", errors="replace")
+ if path.suffix == ".gz"
+ else open(path, encoding="utf-8", errors="replace")
+ )
+ with opener as stream:
+ for line in stream:
+ line = line.strip()
+ if not line:
+ continue
+ parsed = orjson.loads(line)
+ if isinstance(parsed, dict):
+ yield parsed
+
+
+# ---------------------------------------------------------------------------
+# Resource field extraction
+# ---------------------------------------------------------------------------
+
+
+def _clinical_concept_key(res: Dict[str, Any]) -> Optional[str]:
+ """Resolve a stable token key from a FHIR resource."""
+ resource_type = res.get("resourceType")
+ if resource_type == "MedicationRequest":
+ med_cc = res.get("medicationCodeableConcept")
+ if isinstance(med_cc, dict):
+ key = _first_coding(med_cc)
+ if key:
+ return key
+ med_ref = res.get("medicationReference")
+ if isinstance(med_ref, dict):
+ ref = med_ref.get("reference")
+ if ref:
+ return f"MedicationRequest/reference|{_ref_id(ref) or ref}"
+ return None
+ code = res.get("code")
+ return _first_coding(code) if isinstance(code, dict) else None
+
+
+def patient_id_for_resource(
+ resource: Dict[str, Any],
+ resource_type: Optional[str] = None,
+) -> Optional[str]:
+ resource_type = resource_type or resource.get("resourceType")
+ if resource_type == "Patient":
+ pid = resource.get("id")
+ return str(pid) if pid is not None else None
+ if resource_type in {"Encounter", "Condition", "Observation", "MedicationRequest", "Procedure"}:
+ return _ref_id((resource.get("subject") or {}).get("reference"))
+ return None
+
+
+def _resource_time_string(
+ resource: Dict[str, Any],
+ resource_type: Optional[str] = None,
+) -> Optional[str]:
+ resource_type = resource_type or resource.get("resourceType")
+ if resource_type == "Patient":
+ return resource.get("birthDate")
+ if resource_type == "Encounter":
+ return (resource.get("period") or {}).get("start")
+ if resource_type == "Condition":
+ return resource.get("onsetDateTime") or resource.get("recordedDate")
+ if resource_type == "Observation":
+ return resource.get("effectiveDateTime") or resource.get("issued")
+ if resource_type == "MedicationRequest":
+ return resource.get("authoredOn")
+ if resource_type == "Procedure":
+ return resource.get("performedDateTime") or resource.get("recordedDate")
+ return None
+
+
+# ---------------------------------------------------------------------------
+# Flattening
+# ---------------------------------------------------------------------------
+
+
+def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]:
+ """Map Patient.deceasedBoolean to stored "true"/"false"/None.
+
+ FHIR JSON uses real booleans; some exports use strings. Python's
+ bool("false") is True, so we must not coerce with bool().
+ """
+ if value is None:
+ return None
+ if value is True:
+ return "true"
+ if value is False:
+ return "false"
+ if isinstance(value, str):
+ key = value.strip().lower()
+ if key in ("true", "1", "yes", "y", "t"):
+ return "true"
+ if key in ("false", "0", "no", "n", "f", ""):
+ return "false"
+ return None
+ if isinstance(value, (int, float)) and not isinstance(value, bool):
+ if value == 0:
+ return "false"
+ if value == 1:
+ return "true"
+ return None
+ return None
+
+
+_RESOURCE_TYPE_TO_TABLE: Dict[str, str] = {
+ "Condition": "condition",
+ "Observation": "observation",
+ "MedicationRequest": "medication_request",
+ "Procedure": "procedure",
+}
+
+
+def _flatten_resource_to_table_row(
+ resource: Dict[str, Any],
+) -> Optional[Tuple[str, Dict[str, Optional[str]]]]:
+ """Map one FHIR resource dict to (table_name, row_dict), or None if unsupported."""
+ resource_type = resource.get("resourceType")
+ patient_id = patient_id_for_resource(resource, resource_type)
+ if not patient_id:
+ return None
+
+ if resource_type == "Patient":
+ return "patient", {
+ "patient_id": patient_id,
+ "patient_fhir_id": str(resource.get("id") or patient_id),
+ "birth_date": resource.get("birthDate"),
+ "gender": resource.get("gender"),
+ "deceased_boolean": _normalize_deceased_boolean_for_storage(resource.get("deceasedBoolean")),
+ "deceased_datetime": resource.get("deceasedDateTime"),
+ }
+
+ resource_id = str(resource.get("id")) if resource.get("id") is not None else None
+ event_time = _resource_time_string(resource, resource_type)
+
+ if resource_type == "Encounter":
+ return "encounter", {
+ "patient_id": patient_id,
+ "resource_id": resource_id,
+ "encounter_id": resource_id,
+ "event_time": event_time,
+ "encounter_class": (resource.get("class") or {}).get("code"),
+ "encounter_end": (resource.get("period") or {}).get("end"),
+ }
+
+ table_name = _RESOURCE_TYPE_TO_TABLE.get(resource_type)
+ if table_name is None:
+ return None
+ return table_name, {
+ "patient_id": patient_id,
+ "resource_id": resource_id,
+ "encounter_id": _ref_id((resource.get("encounter") or {}).get("reference")),
+ "event_time": event_time,
+ "concept_key": _clinical_concept_key(resource),
+ }
+
+
+# ---------------------------------------------------------------------------
+# Parquet writer
+# ---------------------------------------------------------------------------
+
+
+def _table_schema(table_name: str) -> pa.Schema:
+ return pa.schema([(col, pa.string()) for col in FHIR_TABLE_COLUMNS[table_name]])
+
+
+class _BufferedParquetWriter:
+ def __init__(self, path: Path, schema: pa.Schema, batch_size: int = 50_000) -> None:
+ self.path = path
+ self.schema = schema
+ self.batch_size = batch_size
+ self.rows: List[Dict[str, Any]] = []
+ self.writer: Optional[pq.ParquetWriter] = None
+ self.path.parent.mkdir(parents=True, exist_ok=True)
+
+ def add(self, row: Dict[str, Any]) -> None:
+ self.rows.append(row)
+ if len(self.rows) >= self.batch_size:
+ self.flush()
+
+ def flush(self) -> None:
+ if not self.rows:
+ return
+ table = pa.Table.from_pylist(self.rows, schema=self.schema)
+ if self.writer is None:
+ self.writer = pq.ParquetWriter(str(self.path), self.schema)
+ self.writer.write_table(table)
+ self.rows.clear()
+
+ def close(self) -> None:
+ self.flush()
+ if self.writer is None:
+ pq.write_table(pa.Table.from_pylist([], schema=self.schema), str(self.path))
+ return
+ self.writer.close()
+
+
+# ---------------------------------------------------------------------------
+# Pipeline
+# ---------------------------------------------------------------------------
+
+
+def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]:
+ """Return sorted unique file paths under root matching glob pattern(s).
+
+ Args:
+ root: Root directory to search under.
+ glob_pattern: Single glob string or sequence of glob strings.
+
+ Returns:
+ Sorted list of matching files. Empty if no matches.
+ """
+ patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern)
+ files: set[Path] = set()
+ for pat in patterns:
+ files.update(p for p in root.glob(pat) if p.is_file())
+ return sorted(files, key=lambda p: str(p))
+
+
+def stream_fhir_ndjson_to_flat_tables(
+ root: Path,
+ glob_pattern: GlobPatternArg,
+ out_dir: Path,
+) -> None:
+ """Stream NDJSON resources into normalized per-resource Parquet tables under out_dir.
+
+ Args:
+ root: Root directory containing NDJSON/NDJSON.GZ files.
+ glob_pattern: Single glob string or sequence of glob strings.
+ out_dir: Output directory for per-resource-type Parquet tables.
+ Creates patient.parquet, encounter.parquet, condition.parquet,
+ observation.parquet, medication_request.parquet, procedure.parquet.
+ """
+ out_dir.mkdir(parents=True, exist_ok=True)
+ writers = {
+ name: _BufferedParquetWriter(path=out_dir / FHIR_TABLE_FILE_NAMES[name], schema=_table_schema(name))
+ for name in FHIR_TABLES
+ }
+ try:
+ for file_path in sorted_ndjson_files(root, glob_pattern):
+ for ndjson_obj in iter_ndjson_objects(file_path):
+ for resource in iter_resources_from_ndjson_obj(ndjson_obj):
+ result = _flatten_resource_to_table_row(resource)
+ if result is not None:
+ writers[result[0]].add(result[1])
+ finally:
+ for writer in writers.values():
+ writer.close()
+
+
+def sorted_patient_ids_from_flat_tables(table_dir: Path) -> List[str]:
+ """Return sorted unique patient IDs from a directory of flattened Parquet tables."""
+ patient_table = table_dir / FHIR_TABLE_FILE_NAMES["patient"]
+ if patient_table.exists():
+ return (
+ pl.scan_parquet(str(patient_table))
+ .select("patient_id")
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+ frames = [
+ pl.scan_parquet(str(table_dir / FHIR_TABLE_FILE_NAMES[t])).select("patient_id")
+ for t in FHIR_TABLES_FOR_PATIENT_IDS
+ ]
+ return (
+ pl.concat(frames)
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+
+
+def filter_flat_tables_by_patient_ids(
+ source_dir: Path,
+ out_dir: Path,
+ keep_ids: Sequence[str],
+) -> None:
+ """Filter all flattened tables to only include rows for the given patient IDs."""
+ out_dir.mkdir(parents=True, exist_ok=True)
+ keep_set = set(keep_ids)
+ for name in FHIR_TABLES:
+ src = source_dir / FHIR_TABLE_FILE_NAMES[name]
+ dst = out_dir / FHIR_TABLE_FILE_NAMES[name]
+ pl.scan_parquet(str(src)).filter(pl.col("patient_id").is_in(keep_set)).sink_parquet(str(dst))
diff --git a/pyhealth/datasets/mimic4_fhir.py b/pyhealth/datasets/mimic4_fhir.py
new file mode 100644
index 000000000..7bc53b878
--- /dev/null
+++ b/pyhealth/datasets/mimic4_fhir.py
@@ -0,0 +1,392 @@
+"""MIMIC-IV FHIR ingestion using flattened resource tables.
+
+Architecture
+------------
+1. Stream NDJSON/NDJSON.GZ FHIR resources from disk.
+2. Normalize each resource type into a 2D table (Patient, Encounter,
+ Condition, Observation, MedicationRequest, Procedure) via
+ :mod:`~pyhealth.datasets.fhir_utils`.
+3. Feed those tables through the standard YAML-driven
+ :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task
+ processing operates on :class:`~pyhealth.data.Patient` and
+ ``global_event_df`` rows.
+"""
+
+from __future__ import annotations
+
+import functools
+import hashlib
+import logging
+import operator
+import os
+import shutil
+import uuid
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence
+
+import dask.dataframe as dd
+import narwhals as nw
+import orjson
+import pandas as pd
+import platformdirs
+from yaml import safe_load
+
+from .base_dataset import BaseDataset
+from .fhir_utils import (
+ FHIR_SCHEMA_VERSION,
+ FHIR_TABLE_FILE_NAMES,
+ FHIR_TABLES,
+ sorted_patient_ids_from_flat_tables,
+ filter_flat_tables_by_patient_ids,
+ stream_fhir_ndjson_to_flat_tables,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def read_fhir_settings_yaml(
+ path: Optional[str] = None,
+) -> Dict[str, Any]:
+ if path is None:
+ path = os.path.join(
+ os.path.dirname(__file__), "configs", "mimic4_fhir.yaml"
+ )
+ with open(path, encoding="utf-8") as stream:
+ data = safe_load(stream)
+ return data if isinstance(data, dict) else {}
+
+
+def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series:
+ if getattr(part.dtype, "tz", None) is not None:
+ part = part.dt.tz_localize(None)
+ return part.astype("datetime64[ms]")
+
+
+class MIMIC4FHIRDataset(BaseDataset):
+ """MIMIC-IV FHIR with flattened resource tables.
+
+    Streams raw MIMIC-IV FHIR NDJSON/NDJSON.GZ exports into six
+    flattened Parquet tables, then pipelines them through
+ :class:`~pyhealth.datasets.BaseDataset` for standard downstream
+ task processing (global event dataframe, patient iteration, task
+ sampling).
+
+ Args:
+ root: Path to the NDJSON/NDJSON.GZ export directory.
+ config_path: Path to a custom YAML config. Defaults to
+ ``pyhealth/datasets/configs/mimic4_fhir.yaml``.
+ glob_pattern: Single glob for NDJSON files. Mutually
+ exclusive with *glob_patterns*.
+ glob_patterns: Multiple glob patterns. Mutually exclusive
+ with *glob_pattern*.
+ max_patients: Limit ingest to the first *N* unique patient
+ IDs.
+ ingest_num_shards: Ignored; retained for API compatibility.
+ cache_dir: Cache directory root (UUID subdir appended per
+ config).
+ num_workers: Worker processes for task sampling.
+        dev: Development mode; limits to 1000 patients if
+ *max_patients* is ``None``.
+
+ Examples:
+ >>> ds = MIMIC4FHIRDataset(
+ ... root="/data/mimic-iv-fhir",
+ ... glob_pattern="**/*.ndjson.gz",
+ ... max_patients=500,
+ ... )
+ >>> sample_ds = ds.set_task(task, num_workers=4)
+ """
+
+ def __init__(
+ self,
+ root: str,
+ config_path: Optional[str] = None,
+ glob_pattern: Optional[str] = None,
+ glob_patterns: Optional[Sequence[str]] = None,
+ max_patients: Optional[int] = None,
+ ingest_num_shards: Optional[int] = None,
+ cache_dir: Optional[str | Path] = None,
+ num_workers: int = 1,
+ dev: bool = False,
+ ) -> None:
+ del ingest_num_shards
+
+ default_cfg = os.path.join(
+ os.path.dirname(__file__), "configs", "mimic4_fhir.yaml"
+ )
+ self._fhir_config_path = str(
+ Path(config_path or default_cfg).resolve()
+ )
+ self._fhir_settings = read_fhir_settings_yaml(
+ self._fhir_config_path
+ )
+
+ if glob_pattern is not None and glob_patterns is not None:
+ raise ValueError(
+ "Pass at most one of glob_pattern and glob_patterns."
+ )
+ if glob_patterns is not None:
+ self.glob_patterns: List[str] = list(glob_patterns)
+ elif glob_pattern is not None:
+ self.glob_patterns = [glob_pattern]
+ else:
+ raw_list = self._fhir_settings.get("glob_patterns")
+ if raw_list:
+ if not isinstance(raw_list, list):
+ raise TypeError(
+ "mimic4_fhir.yaml glob_patterns must be a "
+ "list of strings."
+ )
+ self.glob_patterns = [str(x) for x in raw_list]
+ elif self._fhir_settings.get("glob_pattern") is not None:
+ self.glob_patterns = [
+ str(self._fhir_settings["glob_pattern"])
+ ]
+ else:
+ self.glob_patterns = ["**/*.ndjson.gz"]
+
+ self.glob_pattern = (
+ self.glob_patterns[0]
+ if len(self.glob_patterns) == 1
+ else "; ".join(self.glob_patterns)
+ )
+ self.max_patients = (
+ 1000 if dev and max_patients is None else max_patients
+ )
+
+ resolved_root = str(Path(root).expanduser().resolve())
+ super().__init__(
+ root=resolved_root,
+ tables=FHIR_TABLES,
+ dataset_name="mimic4_fhir",
+ config_path=self._fhir_config_path,
+ cache_dir=cache_dir,
+ num_workers=num_workers,
+ dev=dev,
+ )
+
+ # ------------------------------------------------------------------
+ # Cache identity
+ # ------------------------------------------------------------------
+
+ def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
+ try:
+ yaml_digest = hashlib.sha256(
+ Path(self._fhir_config_path).read_bytes()
+ ).hexdigest()[:16]
+ except OSError:
+ yaml_digest = "missing"
+ identity = orjson.dumps(
+ {
+ "root": self.root,
+ "tables": sorted(self.tables),
+ "dataset_name": self.dataset_name,
+ "dev": self.dev,
+ "glob_patterns": self.glob_patterns,
+ "max_patients": self.max_patients,
+ "fhir_schema_version": FHIR_SCHEMA_VERSION,
+ "fhir_yaml_digest16": yaml_digest,
+ },
+ option=orjson.OPT_SORT_KEYS,
+ ).decode("utf-8")
+ cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity))
+ out = (
+ Path(platformdirs.user_cache_dir(appname="pyhealth"))
+ / cache_id
+ if cache_dir is None
+ else Path(cache_dir) / cache_id
+ )
+ out.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Cache dir: {out}")
+ return out
+
+ # ------------------------------------------------------------------
+ # NDJSON → Parquet ingest
+ # ------------------------------------------------------------------
+
+ @property
+ def prepared_tables_dir(self) -> Path:
+ return self.cache_dir / "flattened_tables"
+
+ def _ensure_prepared_tables(self) -> None:
+ root = Path(self.root)
+ if not root.is_dir():
+ raise FileNotFoundError(
+ f"MIMIC4 FHIR root not found: {root}"
+ )
+
+ expected = [
+ self.prepared_tables_dir / FHIR_TABLE_FILE_NAMES[t]
+ for t in FHIR_TABLES
+ ]
+ if all(p.is_file() for p in expected):
+ return
+ if self.prepared_tables_dir.exists():
+ shutil.rmtree(self.prepared_tables_dir)
+
+ try:
+ staging_root = self.create_tmpdir()
+ staging = staging_root / "flattened_fhir_tables"
+ staging.mkdir(parents=True, exist_ok=True)
+ stream_fhir_ndjson_to_flat_tables(
+ root, self.glob_patterns, staging
+ )
+ if self.max_patients is None:
+ shutil.move(
+ str(staging), str(self.prepared_tables_dir)
+ )
+ return
+
+ filtered_root = self.create_tmpdir()
+ filtered = filtered_root / "filtered"
+ pids = sorted_patient_ids_from_flat_tables(staging)
+ filter_flat_tables_by_patient_ids(
+ staging, filtered, pids[: self.max_patients]
+ )
+ shutil.move(
+ str(filtered), str(self.prepared_tables_dir)
+ )
+ finally:
+ self.clean_tmpdir()
+
+ def _event_transform(self, output_dir: Path) -> None:
+ self._ensure_prepared_tables()
+ super()._event_transform(output_dir)
+
+ # ------------------------------------------------------------------
+ # Table loading (Parquet instead of CSV)
+ # ------------------------------------------------------------------
+
+ def load_table(self, table_name: str) -> dd.DataFrame:
+ """Load one flattened Parquet table into the standard event
+ schema.
+
+ Deviations from ``BaseDataset.load_table`` (CSV via
+ ``_scan_csv_tsv_gz``):
+
+ * Reads from pre-built Parquet under ``prepared_tables_dir``.
+ * Timestamp parsing uses ``errors="coerce"`` + ``utc=True``
+ (FHIR ISO strings include timezone suffix or partial dates).
+ * Strips tz-aware timestamps to naive UTC for Dask compat.
+ * Drops rows with null ``patient_id`` before returning.
+ """
+ assert self.config is not None
+ if table_name not in self.config.tables:
+ raise ValueError(
+ f"Table {table_name} not found in config"
+ )
+
+ table_cfg = self.config.tables[table_name]
+ path = self.prepared_tables_dir / table_cfg.file_path
+ if not path.exists():
+ raise FileNotFoundError(
+ f"Flattened table not found: {path}"
+ )
+
+ logger.info(
+ f"Scanning FHIR flattened table: {table_name} "
+ f"from {path}"
+ )
+ df: dd.DataFrame = dd.read_parquet(
+ str(path), split_row_groups=True, blocksize="64MB"
+ ).replace("", pd.NA)
+ df = df.rename(columns=str.lower)
+
+ preprocess_func = getattr(
+ self, f"preprocess_{table_name}", None
+ )
+ if preprocess_func is not None:
+ logger.info(
+ f"Preprocessing FHIR table: {table_name} "
+ f"with {preprocess_func.__name__}"
+ )
+ df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr]
+
+ for join_cfg in table_cfg.join:
+ join_path = (
+ self.prepared_tables_dir
+ / Path(join_cfg.file_path).name
+ )
+ if not join_path.exists():
+ raise FileNotFoundError(
+ f"FHIR join table not found: {join_path}"
+ )
+ logger.info(
+ f"Joining FHIR table {table_name} with {join_path}"
+ )
+ join_df: dd.DataFrame = dd.read_parquet(
+ str(join_path),
+ split_row_groups=True,
+ blocksize="64MB",
+ ).replace("", pd.NA)
+ join_df = join_df.rename(columns=str.lower)
+ join_key = join_cfg.on.lower()
+ cols = [c.lower() for c in join_cfg.columns]
+ df = df.merge(
+ join_df[[join_key] + cols],
+ on=join_key,
+ how=join_cfg.how,
+ )
+
+ ts_col = table_cfg.timestamp
+ if ts_col:
+ ts = (
+ functools.reduce(
+ operator.add,
+ (df[c].astype("string") for c in ts_col),
+ )
+ if isinstance(ts_col, list)
+ else df[ts_col].astype("string")
+ )
+ ts = dd.to_datetime(
+ ts,
+ format=table_cfg.timestamp_format,
+ errors="coerce",
+ utc=True,
+ )
+ df = df.assign(
+ timestamp=ts.map_partitions(_strip_tz_to_naive_ms)
+ )
+ else:
+ df = df.assign(timestamp=pd.NaT)
+
+ if table_cfg.patient_id:
+ df = df.assign(
+ patient_id=df[table_cfg.patient_id].astype("string")
+ )
+ else:
+ df = df.reset_index(drop=True)
+ df = df.assign(patient_id=df.index.astype("string"))
+
+ df = df.dropna(subset=["patient_id"])
+ df = df.assign(event_type=table_name)
+ rename_attr = {
+ attr.lower(): f"{table_name}/{attr}"
+ for attr in table_cfg.attributes
+ }
+ df = df.rename(columns=rename_attr)
+ return df[
+ ["patient_id", "event_type", "timestamp"]
+ + [rename_attr[a.lower()] for a in table_cfg.attributes]
+ ]
+
+ # ------------------------------------------------------------------
+ # Patient IDs (deterministic sorted order)
+ # ------------------------------------------------------------------
+
+ @property
+ def unique_patient_ids(self) -> List[str]:
+ if self._unique_patient_ids is None:
+ self._unique_patient_ids = (
+ self.global_event_df.select("patient_id")
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")
+ .to_series()
+ .to_list()
+ )
+ logger.info(
+ f"Found {len(self._unique_patient_ids)} "
+ f"unique patient IDs"
+ )
+ return self._unique_patient_ids
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 5233b1726..86486dab8 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -38,6 +38,7 @@
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
from .ehrmamba import EHRMamba, MambaBlock
+from .ehrmamba_cehr import EHRMambaCEHR
from .vae import VAE
from .vision_embedding import VisionEmbeddingModel
from .text_embedding import TextEmbedding
diff --git a/pyhealth/models/cehr_embeddings.py b/pyhealth/models/cehr_embeddings.py
new file mode 100644
index 000000000..7974a699e
--- /dev/null
+++ b/pyhealth/models/cehr_embeddings.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2024 Vector Institute / Odyssey authors
+#
+# Derived from Odyssey (https://github.com/VectorInstitute/odyssey):
+# odyssey/models/embeddings.py — MambaEmbeddingsForCEHR, TimeEmbeddingLayer, VisitEmbedding
+# Modifications: removed HuggingFace MambaConfig dependency; explicit constructor args.
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+
+class TimeEmbeddingLayer(nn.Module):
+ """Embedding layer for time features (sinusoidal)."""
+
+ def __init__(self, embedding_size: int, is_time_delta: bool = False):
+ super().__init__()
+ self.embedding_size = embedding_size
+ self.is_time_delta = is_time_delta
+ self.w = nn.Parameter(torch.empty(1, self.embedding_size))
+ self.phi = nn.Parameter(torch.empty(1, self.embedding_size))
+ nn.init.xavier_uniform_(self.w)
+ nn.init.xavier_uniform_(self.phi)
+
+ def forward(self, time_stamps: torch.Tensor) -> torch.Tensor:
+ if self.is_time_delta:
+ time_stamps = torch.cat(
+ (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]),
+ dim=-1,
+ )
+ time_stamps = time_stamps.float()
+ next_input = time_stamps.unsqueeze(-1) * self.w + self.phi
+ return torch.sin(next_input)
+
+
+class VisitEmbedding(nn.Module):
+ """Embedding layer for visit segments."""
+
+ def __init__(self, visit_order_size: int, embedding_size: int):
+ super().__init__()
+ self.embedding = nn.Embedding(visit_order_size, embedding_size)
+
+ def forward(self, visit_segments: torch.Tensor) -> torch.Tensor:
+ return self.embedding(visit_segments)
+
+
+class MambaEmbeddingsForCEHR(nn.Module):
+ """CEHR-style combined embeddings for Mamba (concept + type + time + age + visit)."""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ hidden_size: int,
+ pad_token_id: int = 0,
+ type_vocab_size: int = 9,
+ max_num_visits: int = 512,
+ time_embeddings_size: int = 32,
+ visit_order_size: int = 3,
+ layer_norm_eps: float = 1e-12,
+ hidden_dropout_prob: float = 0.1,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.pad_token_id = pad_token_id
+ self.type_vocab_size = type_vocab_size
+ self.max_num_visits = max_num_visits
+ self.word_embeddings = nn.Embedding(
+ vocab_size, hidden_size, padding_idx=pad_token_id
+ )
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
+ self.visit_order_embeddings = nn.Embedding(max_num_visits, hidden_size)
+ self.time_embeddings = TimeEmbeddingLayer(
+ embedding_size=time_embeddings_size, is_time_delta=True
+ )
+ self.age_embeddings = TimeEmbeddingLayer(
+ embedding_size=time_embeddings_size, is_time_delta=False
+ )
+ self.visit_segment_embeddings = VisitEmbedding(
+ visit_order_size=visit_order_size, embedding_size=hidden_size
+ )
+ self.scale_back_concat_layer = nn.Linear(
+ hidden_size + 2 * time_embeddings_size, hidden_size
+ )
+ self.tanh = nn.Tanh()
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+ self.dropout = nn.Dropout(hidden_dropout_prob)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ token_type_ids_batch: torch.Tensor,
+ time_stamps: torch.Tensor,
+ ages: torch.Tensor,
+ visit_orders: torch.Tensor,
+ visit_segments: torch.Tensor,
+ ) -> torch.Tensor:
+ inputs_embeds = self.word_embeddings(input_ids)
+ time_stamps_embeds = self.time_embeddings(time_stamps)
+ ages_embeds = self.age_embeddings(ages)
+ visit_segments_embeds = self.visit_segment_embeddings(visit_segments)
+ visit_order_embeds = self.visit_order_embeddings(visit_orders)
+ token_type_embeds = self.token_type_embeddings(token_type_ids_batch)
+ concat_in = torch.cat(
+ (inputs_embeds, time_stamps_embeds, ages_embeds), dim=-1
+ )
+ h = self.tanh(self.scale_back_concat_layer(concat_in))
+ embeddings = h + token_type_embeds + visit_order_embeds + visit_segments_embeds
+ embeddings = self.dropout(embeddings)
+ return self.LayerNorm(embeddings)
diff --git a/pyhealth/models/ehrmamba_cehr.py b/pyhealth/models/ehrmamba_cehr.py
new file mode 100644
index 000000000..cd555629c
--- /dev/null
+++ b/pyhealth/models/ehrmamba_cehr.py
@@ -0,0 +1,117 @@
+"""EHRMamba with CEHR-style embeddings for single-stream FHIR token sequences."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from pyhealth.datasets import SampleDataset
+
+from .base_model import BaseModel
+from .cehr_embeddings import MambaEmbeddingsForCEHR
+from .ehrmamba import MambaBlock
+from .utils import get_rightmost_masked_timestep
+
+
class EHRMambaCEHR(BaseModel):
    """Mamba backbone over CEHR embeddings (FHIR / MPF pipeline).

    Args:
        dataset: Fitted :class:`~pyhealth.datasets.SampleDataset` with MPF task schema.
        vocab_size: Concept embedding vocabulary size (typically ``task.vocab.vocab_size``).
        embedding_dim: Hidden size (``hidden_size`` in CEHR embeddings).
        num_layers: Number of :class:`~pyhealth.models.ehrmamba.MambaBlock` layers.
        pad_token_id: Padding id for masking (default 0).
        state_size: SSM state size per channel.
        conv_kernel: Causal conv kernel in each block.
        dropout: Dropout before classifier.
        type_vocab_size: Embedding rows for token-type ids.
        max_num_visits: Embedding rows for visit-order ids (valid indices 0..max_num_visits-1).
        time_embeddings_size: Width of each time/age embedding before re-projection.
        visit_segment_vocab: Embedding rows for visit-segment ids.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        vocab_size: int,
        embedding_dim: int = 128,
        num_layers: int = 2,
        pad_token_id: int = 0,
        state_size: int = 16,
        conv_kernel: int = 4,
        dropout: float = 0.1,
        type_vocab_size: int = 16,
        max_num_visits: int = 512,
        time_embeddings_size: int = 32,
        visit_segment_vocab: int = 3,
    ):
        super().__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size

        # Single-label contract: label_keys/output_schema come from the task.
        assert len(self.label_keys) == 1, "EHRMambaCEHR supports single label key only"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        # NOTE(review): visit_segment_vocab is forwarded as visit_order_size
        # (the VisitEmbedding constructor's parameter name) — confirm the
        # intended mapping between these two names upstream.
        self.embeddings = MambaEmbeddingsForCEHR(
            vocab_size=vocab_size,
            hidden_size=embedding_dim,
            pad_token_id=pad_token_id,
            type_vocab_size=type_vocab_size,
            max_num_visits=max_num_visits,
            time_embeddings_size=time_embeddings_size,
            visit_order_size=visit_segment_vocab,
        )
        # Stack of causal Mamba (SSM) blocks applied over the fused embeddings.
        self.blocks = nn.ModuleList(
            [
                MambaBlock(
                    d_model=embedding_dim,
                    state_size=state_size,
                    conv_kernel=conv_kernel,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)
        out_dim = self.get_output_size()
        self.fc = nn.Linear(embedding_dim, out_dim)
        # Placeholder for an optional next-token head; see forward_forecasting.
        self._forecasting_head: Optional[nn.Module] = None

    def forward_forecasting(self, **kwargs: Any) -> Optional[torch.Tensor]:
        """Optional next-token / forecasting head (extension point; not implemented)."""

        return None

    def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Run CEHR embedding, Mamba blocks, pooled readout, and loss.

        Expects the six MPF feature tensors in ``kwargs`` plus the label
        under ``self.label_key``; returns the standard PyHealth output dict.
        """
        # Six parallel [batch, seq_len] feature tensors produced by the task.
        concept_ids = kwargs["concept_ids"].to(self.device).long()
        token_type_ids = kwargs["token_type_ids"].to(self.device).long()
        time_stamps = kwargs["time_stamps"].to(self.device).float()
        ages = kwargs["ages"].to(self.device).float()
        visit_orders = kwargs["visit_orders"].to(self.device).long()
        visit_segments = kwargs["visit_segments"].to(self.device).long()

        x = self.embeddings(
            input_ids=concept_ids,
            token_type_ids_batch=token_type_ids,
            time_stamps=time_stamps,
            ages=ages,
            visit_orders=visit_orders,
            visit_segments=visit_segments,
        )
        # True where a real (non-pad) token sits. MPF sequences are
        # left-padded, so valid tokens need not form a contiguous prefix —
        # hence the rightmost-masked readout instead of get_last_visit.
        mask = concept_ids != self.pad_token_id
        for blk in self.blocks:
            x = blk(x)
        pooled = get_rightmost_masked_timestep(x, mask)
        logits = self.fc(self.dropout(pooled))
        # Float label with a trailing dim; the task's output schema declares
        # "label" as binary, so this matches the binary loss/probability path.
        y_true = kwargs[self.label_key].to(self.device).float()
        if y_true.dim() == 1:
            y_true = y_true.unsqueeze(-1)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py
index 67edc010e..45cd6608d 100644
--- a/pyhealth/models/utils.py
+++ b/pyhealth/models/utils.py
@@ -44,3 +44,31 @@ def get_last_visit(hidden_states, mask):
last_hidden_states = torch.gather(hidden_states, 1, last_visit)
last_hidden_state = last_hidden_states[:, 0, :]
return last_hidden_state
+
+
def get_rightmost_masked_timestep(hidden_states, mask):
    """Gather the hidden state at the last True position of ``mask`` per row.

    Unlike :func:`get_last_visit`, valid tokens need not form a contiguous
    prefix: the maximum index where ``mask`` is True is selected. Intended
    for MPF / CEHR layouts where padding can sit between boundary tokens.
    Rows whose mask is entirely False fall back to position 0.

    Args:
        hidden_states: ``[batch, seq_len, hidden_size]``.
        mask: ``[batch, seq_len]`` bool, or None to take the final timestep.

    Returns:
        Tensor ``[batch, hidden_size]``.
    """
    if mask is None:
        return hidden_states[:, -1, :]
    batch, seq_len, _ = hidden_states.shape
    device = hidden_states.device
    positions = torch.arange(seq_len, device=device, dtype=torch.long)
    positions = positions.unsqueeze(0).expand(batch, seq_len)
    # Invalid positions become -1 so the per-row max lands on the rightmost
    # True entry; clamp maps all-False rows to index 0.
    masked_positions = positions.masked_fill(~mask, -1)
    last_positions = masked_positions.max(dim=1).values.clamp(min=0)
    rows = torch.arange(batch, device=device)
    return hidden_states[rows, last_positions]
diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py
index b48072270..4568a5ece 100644
--- a/pyhealth/processors/__init__.py
+++ b/pyhealth/processors/__init__.py
@@ -50,6 +50,7 @@ def get_processor(name: str):
from .ignore_processor import IgnoreProcessor
from .temporal_timeseries_processor import TemporalTimeseriesProcessor
from .tuple_time_text_processor import TupleTimeTextProcessor
+from .cehr_processor import CehrProcessor, ConceptVocab
# Expose public API
from .base_processor import (
@@ -79,4 +80,6 @@ def get_processor(name: str):
"GraphProcessor",
"AudioProcessor",
"TupleTimeTextProcessor",
+ "CehrProcessor",
+ "ConceptVocab",
]
diff --git a/pyhealth/processors/cehr_processor.py b/pyhealth/processors/cehr_processor.py
new file mode 100644
index 000000000..19ba90170
--- /dev/null
+++ b/pyhealth/processors/cehr_processor.py
@@ -0,0 +1,554 @@
+"""CEHR-style tokenization, vocabulary, and sequence building for FHIR timelines.
+
+Key public API
+--------------
+ConceptVocab
+ Token-to-dense-id mapping with PAD/UNK reserved at 0 and 1. JSON-serializable.
+
+CehrProcessor
+ FeatureProcessor subclass that owns a ConceptVocab, can be warmed over a patient
+ stream, and converts a Patient's tabular FHIR rows into CEHR-aligned sequences.
+
+build_cehr_sequences(patient, vocab, max_len)
+ Flatten a Patient's tabular FHIR rows into CEHR-aligned feature lists.
+
+infer_mortality_label(patient)
+ Heuristic binary mortality label from flattened patient rows.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import orjson
+
+from pyhealth.data import Patient
+
+from .base_processor import FeatureProcessor
+from . import register_processor
+
+DEFAULT_PAD = 0
+DEFAULT_UNK = 1
+
+# ---------------------------------------------------------------------------
+# Datetime helpers
+#
+# These are intentional copies of the identically-named functions in
+# pyhealth.datasets.fhir_utils. They exist here to avoid a circular import:
+# importing pyhealth.datasets.fhir_utils (even as a submodule) triggers
+# pyhealth/datasets/__init__.py, which imports MIMIC4FHIRDataset, which in turn
+# imports from this module — creating a cycle. The implementations are pure
+# stdlib (no pyhealth deps), so keeping them in sync is straightforward;
+# any change to the canonical copy in fhir_utils must be mirrored here.
+# ---------------------------------------------------------------------------
+
+
def parse_dt(s: Optional[str]) -> Optional[datetime]:
    """Parse an ISO 8601 or YYYY-MM-DD date string to a naive datetime.

    Tries ``datetime.fromisoformat`` first (with a trailing ``Z`` rewritten
    as ``+00:00``), then falls back to parsing the first ten characters as a
    plain date. Returns None for empty or unparseable input; any timezone
    information is stripped from the result.
    """
    if not s:
        return None
    parsed: Optional[datetime]
    try:
        parsed = datetime.fromisoformat(s.replace("Z", "+00:00"))
    except ValueError:
        parsed = None
    if parsed is None:
        if len(s) < 10:
            return None
        try:
            parsed = datetime.strptime(s[:10], "%Y-%m-%d")
        except ValueError:
            return None
    return as_naive(parsed)


def as_naive(dt: Optional[datetime]) -> Optional[datetime]:
    """Strip timezone info from a datetime, or return None unchanged."""
    if dt is None or dt.tzinfo is None:
        return dt
    return dt.replace(tzinfo=None)
+
+
__all__ = [
    # Constants
    "DEFAULT_PAD",
    "DEFAULT_UNK",
    "EVENT_TYPE_TO_TOKEN_TYPE",
    # Vocabulary
    "ConceptVocab",
    "ensure_special_tokens",
    # Processor
    "CehrProcessor",
    # Sequence building
    "collect_cehr_timeline_events",
    "warm_mpf_vocab_from_patient",
    "build_cehr_sequences",
    # Labels
    "infer_mortality_label",
]

# Token-type id per supported FHIR event table. Id 0 is not assigned here;
# the MPF task uses 0 for its boundary/special tokens and padding.
EVENT_TYPE_TO_TOKEN_TYPE = {
    "encounter": 1,
    "condition": 2,
    "medication_request": 3,
    "observation": 4,
    "procedure": 5,
}

# Table-driven lookups for flattened event-row column access.
_CONCEPT_KEY_COL: Dict[str, str] = {
    "condition": "condition/concept_key",
    "observation": "observation/concept_key",
    "medication_request": "medication_request/concept_key",
    "procedure": "procedure/concept_key",
}

# Maps each event table to the column holding its linked encounter id; the
# "encounter" entry points at the encounter's own id column.
_ENCOUNTER_ID_COL: Dict[str, str] = {
    "condition": "condition/encounter_id",
    "observation": "observation/encounter_id",
    "medication_request": "medication_request/encounter_id",
    "procedure": "procedure/encounter_id",
    "encounter": "encounter/encounter_id",
}
+
+# ---------------------------------------------------------------------------
+# ConceptVocab
+# ---------------------------------------------------------------------------
+
+
@dataclass
class ConceptVocab:
    """Maps concept keys to dense ids with PAD/UNK reserved at 0 and 1.

    JSON-serializable via :meth:`to_json` / :meth:`from_json`; unseen keys
    resolve to the UNK id via ``__getitem__`` and get fresh ids via
    :meth:`add_token`.
    """

    token_to_id: Dict[str, int] = field(default_factory=dict)
    pad_id: int = DEFAULT_PAD
    unk_id: int = DEFAULT_UNK
    _next_id: int = 2

    def __post_init__(self) -> None:
        # Seed the special entries only when starting from an empty mapping.
        # NOTE(review): both special-token literals render as "" here (they
        # look like angle-bracketed markers stripped by an extraction step),
        # so this dict collapses to a single entry — confirm the real token
        # strings against the canonical source.
        if not self.token_to_id:
            self.token_to_id = {"": self.pad_id, "": self.unk_id}
            self._next_id = 2

    def add_token(self, key: str) -> int:
        """Return the id for *key*, assigning the next dense id if unseen."""
        known = self.token_to_id.get(key)
        if known is not None:
            return known
        assigned = self._next_id
        self.token_to_id[key] = assigned
        self._next_id = assigned + 1
        return assigned

    def __getitem__(self, key: str) -> int:
        """Look up *key*, falling back to the UNK id for unseen tokens."""
        return self.token_to_id.get(key, self.unk_id)

    @property
    def vocab_size(self) -> int:
        """One past the highest assigned id (a valid embedding-table size)."""
        return self._next_id

    def to_json(self) -> Dict[str, Any]:
        """Serialize to a plain dict suitable for JSON round-tripping."""
        return {
            "token_to_id": self.token_to_id,
            "next_id": self._next_id,
            "pad_id": self.pad_id,
            "unk_id": self.unk_id,
        }

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab":
        """Rebuild a vocabulary from :meth:`to_json` output."""
        vocab = cls(
            pad_id=int(data.get("pad_id", DEFAULT_PAD)),
            unk_id=int(data.get("unk_id", DEFAULT_UNK)),
        )
        loaded = dict(data.get("token_to_id") or {})
        if not loaded:
            vocab._next_id = int(data.get("next_id", 2))
            return vocab
        vocab.token_to_id = loaded
        # Fall back to one past the largest stored id when next_id is absent.
        vocab._next_id = int(data.get("next_id", max(loaded.values()) + 1))
        return vocab

    def save(self, path: str) -> None:
        """Write the vocabulary as sorted JSON, creating parent directories."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(orjson.dumps(self.to_json(), option=orjson.OPT_SORT_KEYS))

    @classmethod
    def load(cls, path: str) -> "ConceptVocab":
        """Read a vocabulary previously written by :meth:`save`."""
        return cls.from_json(orjson.loads(Path(path).read_bytes()))
+
+
def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]:
    """Add EHRMamba/CEHR special tokens and return their ids."""
    # NOTE(review): all four token literals render as empty strings here, so
    # the comprehension collapses to a single {"": id} entry. These look like
    # angle-bracketed special-token names stripped by a markup/extraction
    # step — confirm against the canonical source and restore the four
    # distinct token strings before relying on the returned mapping.
    return {name: vocab.add_token(name) for name in ("", "", "", "")}
+
+
+# ---------------------------------------------------------------------------
+# Row utilities for flattened event stream
+# ---------------------------------------------------------------------------
+
+
+def _clean_string(value: Any) -> Optional[str]:
+ if value is None:
+ return None
+ if isinstance(value, str):
+ return value.strip() or None
+ return str(value)
+
+
def _deceased_boolean_column_means_dead(value: Any) -> bool:
    """True only for an explicit affirmative stored flag (not Python truthiness)."""
    cleaned = _clean_string(value)
    if cleaned is None:
        return False
    return cleaned.lower() == "true"
+
+
+def _row_datetime(value: Any) -> Optional[datetime]:
+ if value is None:
+ return None
+ if isinstance(value, datetime):
+ return as_naive(value)
+ try:
+ return parse_dt(str(value))
+ except Exception:
+ return None
+
+
def _concept_key_from_row(row: Dict[str, Any]) -> str:
    """Derive a vocabulary concept key for one flattened event row.

    Clinical tables use their concept-key column; encounters key on the
    encounter class; anything else degrades to ``"<type>|unknown"``.
    """
    event_type = row.get("event_type")
    concept_col = _CONCEPT_KEY_COL.get(event_type)
    if concept_col:
        cleaned = _clean_string(row.get(concept_col))
        return cleaned if cleaned else f"{event_type}|unknown"
    if event_type == "encounter":
        enc_class = _clean_string(row.get("encounter/encounter_class"))
        if enc_class:
            return f"encounter|{enc_class}"
        return "encounter|unknown"
    return f"{event_type or 'event'}|unknown"
+
+
def _linked_encounter_id_from_row(row: Dict[str, Any]) -> Optional[str]:
    """Return the row's linked encounter id, if its table defines one."""
    enc_col = _ENCOUNTER_ID_COL.get(row.get("event_type"))
    if not enc_col:
        return None
    return _clean_string(row.get(enc_col))
+
+
def _birth_datetime_from_patient(patient: Patient) -> Optional[datetime]:
    """Find the patient's birth datetime from the flattened ``patient`` row.

    Prefers the row's ``timestamp`` column, then falls back to parsing the
    raw ``patient/birth_date`` string. Returns None when neither is usable.
    """
    for row in patient.data_source.iter_rows(named=True):
        if row.get("event_type") != "patient":
            continue
        from_timestamp = _row_datetime(row.get("timestamp"))
        if from_timestamp is not None:
            return from_timestamp
        raw_birth = _clean_string(row.get("patient/birth_date"))
        if raw_birth:
            return parse_dt(raw_birth)
    return None
+
+
+def _sequential_visit_idx_for_time(
+ event_time: Optional[datetime],
+ visit_encounters: List[Tuple[datetime, int]],
+) -> int:
+ if not visit_encounters:
+ return 0
+ if event_time is None:
+ return visit_encounters[-1][1]
+ event_time = as_naive(event_time)
+ chosen = visit_encounters[0][1]
+ for encounter_start, visit_idx in visit_encounters:
+ if encounter_start <= event_time:
+ chosen = visit_idx
+ else:
+ break
+ return chosen
+
+
+# ---------------------------------------------------------------------------
+# CEHR timeline and sequence building
+# ---------------------------------------------------------------------------
+
+
def collect_cehr_timeline_events(
    patient: Patient,
) -> List[Tuple[datetime, str, str, int]]:
    """Collect (time, concept_key, event_type, visit_idx) tuples from a patient's rows.

    Visit indices follow chronological encounter order. Clinical events are
    attached to their linked encounter when possible; otherwise they fall
    back to the latest encounter starting at or before the event time.
    """
    # Sort once up front; null timestamps sink to the end so timestamped
    # rows lead every scan below.
    rows = list(
        patient.data_source.sort(["timestamp", "event_type"], nulls_last=True).iter_rows(named=True)
    )

    # Build encounter list — rows are already timestamp-sorted so the loop
    # preserves chronological order without an explicit sort.
    encounter_rows: List[Tuple[datetime, str]] = []
    for row in rows:
        if row.get("event_type") != "encounter":
            continue
        enc_id = _linked_encounter_id_from_row(row)
        enc_start = _row_datetime(row.get("timestamp"))
        if enc_id is not None and enc_start is not None:
            encounter_rows.append((enc_start, enc_id))

    # NOTE(review): these mappings keep the *last* entry if an encounter id
    # repeats — assumes encounter ids are unique per patient; confirm upstream.
    encounter_visit_idx = {enc_id: idx for idx, (_, enc_id) in enumerate(encounter_rows)}
    encounter_start_by_id = {enc_id: enc_start for enc_start, enc_id in encounter_rows}
    visit_encounters = [(enc_start, idx) for idx, (enc_start, _) in enumerate(encounter_rows)]

    events: List[Tuple[datetime, str, str, int]] = []
    # Clinical events with no resolvable encounter link; assigned a visit below.
    unlinked: List[Tuple[Optional[datetime], str, str]] = []

    for row in rows:
        event_type = row.get("event_type")
        if event_type not in EVENT_TYPE_TO_TOKEN_TYPE:
            continue
        event_time = _row_datetime(row.get("timestamp"))
        concept_key = _concept_key_from_row(row)

        if event_type == "encounter":
            # Encounters must carry both an id and a start time to be usable.
            enc_id = _linked_encounter_id_from_row(row)
            if enc_id is None or event_time is None:
                continue
            visit_idx = encounter_visit_idx.get(enc_id)
            if visit_idx is None:
                continue
            events.append((event_time, concept_key, event_type, visit_idx))
            continue

        enc_id = _linked_encounter_id_from_row(row)
        if enc_id and enc_id in encounter_visit_idx:
            visit_idx = encounter_visit_idx[enc_id]
            # Fall back to the encounter's start when the event has no time.
            if event_time is None:
                event_time = encounter_start_by_id.get(enc_id)
            if event_time is None:
                continue
            events.append((event_time, concept_key, event_type, visit_idx))
        else:
            unlinked.append((event_time, concept_key, event_type))

    # Second pass: place unlinked events into the latest visit starting at or
    # before their timestamp.
    for event_time, concept_key, event_type in unlinked:
        visit_idx = _sequential_visit_idx_for_time(event_time, visit_encounters)
        if event_time is None:
            if not visit_encounters:
                continue
            # Borrow the chosen visit's encounter start as the event time;
            # the for/else falls back to the last encounter's start when the
            # chosen index is not found.
            for enc_start, enc_visit_idx in visit_encounters:
                if enc_visit_idx == visit_idx:
                    event_time = enc_start
                    break
            else:
                event_time = visit_encounters[-1][0]
        if event_time is None:
            continue
        events.append((event_time, concept_key, event_type, visit_idx))

    # Final chronological order across linked and fallback events.
    events.sort(key=lambda item: item[0])
    return events
+
+
def warm_mpf_vocab_from_patient(
    vocab: ConceptVocab,
    patient: Patient,
    clinical_cap: int,
) -> None:
    """Add concept keys from the last ``clinical_cap`` events of *patient* to *vocab*."""
    if clinical_cap <= 0:
        return
    # Only the timeline tail that can fit in a sequence contributes tokens.
    for _, concept_key, _, _ in collect_cehr_timeline_events(patient)[-clinical_cap:]:
        vocab.add_token(concept_key)
+
+
def build_cehr_sequences(
    patient: Patient,
    vocab: ConceptVocab,
    max_len: int,
    *,
    base_time: Optional[datetime] = None,
    grow_vocab: bool = True,
    visit_order_cap: int = 511,
) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]:
    """Flatten a patient's tabular FHIR rows into CEHR-aligned feature lists.

    Args:
        patient: Tabular patient whose flattened event rows are consumed.
        vocab: Concept vocabulary used (and optionally grown) for id lookup.
        max_len: Keep only the last ``max_len`` timeline events; ``<= 0``
            yields six empty lists.
        base_time: Reference time for time deltas; defaults to the first
            event's time (or "now" when the patient has no events).
        grow_vocab: When True, unseen concepts get fresh ids; otherwise they
            map to UNK (multi-worker safety after vocabulary warm-up).
        visit_order_cap: Upper clamp for ``visit_orders`` values. The default
            511 matches ``EHRMambaCEHR``'s default ``max_num_visits=512``
            embedding table (valid indices 0..511); pass a different cap when
            the model uses a different table size.

    Returns:
        Six equal-length lists ``(concept_ids, token_types, time_stamps,
        ages, visit_orders, visit_segments)``.
    """
    events = collect_cehr_timeline_events(patient)
    birth = _birth_datetime_from_patient(patient)

    if base_time is None:
        base_time = events[0][0] if events else datetime.now()
    base_time = as_naive(base_time)
    birth = as_naive(birth)

    concept_ids: List[int] = []
    token_types: List[int] = []
    time_stamps: List[float] = []
    ages: List[float] = []
    visit_orders: List[int] = []
    visit_segments: List[int] = []

    for event_time, concept_key, event_type, visit_idx in (events[-max_len:] if max_len > 0 else []):
        event_time = as_naive(event_time)
        concept_id = vocab.add_token(concept_key) if grow_vocab else vocab[concept_key]
        token_type = EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0)
        # Seconds relative to base_time (negative for earlier events).
        time_delta = (
            float((event_time - base_time).total_seconds())
            if base_time is not None and event_time is not None
            else 0.0
        )
        # Age in fractional years; 0.0 when the birth date is unknown.
        age_years = (
            (event_time - birth).days / 365.25
            if birth is not None and event_time is not None
            else 0.0
        )
        concept_ids.append(concept_id)
        token_types.append(token_type)
        time_stamps.append(time_delta)
        ages.append(age_years)
        # Clamp so the value stays a valid index into the model's
        # visit-order embedding table.
        visit_orders.append(min(visit_idx, visit_order_cap))
        visit_segments.append(visit_idx % 2)

    return concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments
+
+
+# ---------------------------------------------------------------------------
+# Label inference
+# ---------------------------------------------------------------------------
+
+
def infer_mortality_label(patient: Patient) -> int:
    """Heuristic binary mortality label from flattened patient rows.

    Labels 1 when the flattened ``patient`` row carries an explicit deceased
    flag or datetime; failing that, when any condition concept key mentions a
    death-related term. Otherwise 0.
    """
    for row in patient.data_source.iter_rows(named=True):
        if row.get("event_type") != "patient":
            continue
        if _deceased_boolean_column_means_dead(row.get("patient/deceased_boolean")):
            return 1
        if _clean_string(row.get("patient/deceased_datetime")):
            return 1
    death_markers = ("death", "deceased", "mortality")
    for row in patient.data_source.iter_rows(named=True):
        if row.get("event_type") != "condition":
            continue
        concept = (_clean_string(row.get("condition/concept_key")) or "").lower()
        if any(marker in concept for marker in death_markers):
            return 1
    return 0
+
+
+# ---------------------------------------------------------------------------
+# CehrProcessor
+# ---------------------------------------------------------------------------
+
+
+@register_processor("cehr")
+class CehrProcessor(FeatureProcessor):
+ """CEHR concept sequence processor for FHIR timelines.
+
+ Owns a :class:`ConceptVocab` and converts a
+ :class:`~pyhealth.data.Patient`'s tabular FHIR event rows into
+ CEHR-aligned integer sequence lists ready for downstream models.
+
+ This processor departs from the standard ``FeatureProcessor`` contract
+ in two ways that are intentional for this domain:
+
+ * ``process(patient)`` takes a full :class:`~pyhealth.data.Patient` rather
+ than a single scalar field value, because building CEHR sequences
+ requires access to all event rows simultaneously.
+ * ``fit(patients, clinical_cap)`` takes an iterable of
+ :class:`~pyhealth.data.Patient` objects instead of the base-class
+ ``fit(samples, field)`` signature, because vocabulary warming is driven
+ by the Patient timeline, not a pre-aggregated sample dict.
+
+ Typical usage::
+
+ processor = CehrProcessor(max_len=512)
+ processor.fit(dataset.iter_patients()) # warm vocabulary
+ sequences = processor.process(some_patient) # build sequences
+ processor.save("vocab.json") # persist vocab
+
+ Attributes:
+ vocab: Concept-to-id mapping (PAD=0, UNK=1).
+ max_len: Maximum number of clinical tokens per patient (boundary tokens
+ not counted; see :class:`~pyhealth.tasks.MPFClinicalPredictionTask`).
+ frozen_vocab: When True, unknown concepts map to UNK instead of adding
+ new ids — used for multi-worker safety after vocab warm-up.
+ """
+
+ def __init__(
+ self,
+ vocab: Optional[ConceptVocab] = None,
+ max_len: int = 512,
+ frozen_vocab: bool = False,
+ ) -> None:
+ self.vocab = vocab or ConceptVocab()
+ self.max_len = max_len
+ self.frozen_vocab = frozen_vocab
+
+ def fit( # type: ignore[override]
+ self,
+ patients: Iterable[Patient],
+ clinical_cap: Optional[int] = None,
+ ) -> "CehrProcessor":
+ """Warm vocabulary from a stream of patients.
+
+ Note: this method intentionally overrides ``FeatureProcessor.fit(samples,
+ field)`` with a different signature, because CEHR vocabulary warming
+ operates on :class:`~pyhealth.data.Patient` timelines rather than
+ pre-aggregated sample dicts.
+
+ Special tokens are *not* inserted here; they are added lazily by
+ :meth:`~pyhealth.tasks.MPFClinicalPredictionTask._ensure_processor`
+ on the first call to the task. This keeps ``fit`` focused on concept
+ key discovery.
+
+ Args:
+ patients: Iterable of :class:`~pyhealth.data.Patient` objects.
+ clinical_cap: Maximum number of tail events to scan per patient.
+ Defaults to ``max_len - 2`` (room for two boundary tokens).
+
+ Returns:
+ self (for chaining).
+ """
+ cap = clinical_cap if clinical_cap is not None else max(0, self.max_len - 2)
+ for patient in patients:
+ warm_mpf_vocab_from_patient(self.vocab, patient, cap)
+ return self
+
+ def process(
+ self,
+ patient: Patient,
+ ) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]:
+ """Build CEHR sequences from a patient's FHIR event rows.
+
+ Args:
+ patient: A tabular :class:`~pyhealth.data.Patient`.
+
+ Returns:
+ Six equal-length lists ``(concept_ids, token_type_ids, time_stamps,
+ ages, visit_orders, visit_segments)`` ready for boundary-token
+ insertion and left-padding by the task.
+ """
+ clinical_cap = max(0, self.max_len - 2)
+ return build_cehr_sequences(
+ patient, self.vocab, clinical_cap, grow_vocab=not self.frozen_vocab
+ )
+
+ def save(self, path: str) -> None:
+ """Persist the vocabulary to a JSON file at *path*."""
+ self.vocab.save(path)
+
+ def load(self, path: str) -> None:
+ """Load a previously saved vocabulary from *path*."""
+ self.vocab = ConceptVocab.load(path)
+
+ def is_token(self) -> bool:
+ """All six output lists contain discrete token/index values."""
+ return True
+
+ def schema(self) -> Tuple[str, ...]:
+ return (
+ "concept_ids",
+ "token_type_ids",
+ "time_stamps",
+ "ages",
+ "visit_orders",
+ "visit_segments",
+ )
+
+ def dim(self) -> Tuple[int, ...]:
+ """Each of the six output lists becomes a 1-D tensor."""
+ return (1, 1, 1, 1, 1, 1)
+
+ def spatial(self) -> Tuple[bool, ...]:
+ """All six outputs are along the sequence (temporal) axis."""
+ return (True, True, True, True, True, True)
+
+ def __repr__(self) -> str:
+ # frozen_vocab is a runtime flag, not a constructor parameter, so it
+ # must be excluded here. This repr is used by BaseDataset.set_task to
+ # compute the LitData task-cache UUID via vars(task); including
+ # frozen_vocab would produce different UUIDs for single- vs
+ # multi-worker runs of the same pipeline, defeating caching.
+ return f"CehrProcessor(max_len={self.max_len})"
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 797988377..ffdf99560 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -66,3 +66,11 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
+
+
+def __getattr__(name: str):
+ if name == "MPFClinicalPredictionTask":
+ from .mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ return MPFClinicalPredictionTask
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py
new file mode 100644
index 000000000..fb60c8a70
--- /dev/null
+++ b/pyhealth/tasks/mpf_clinical_prediction.py
@@ -0,0 +1,228 @@
+"""Multitask Prompted Fine-tuning (MPF) clinical prediction on FHIR timelines."""
+
+from __future__ import annotations
+
+import itertools
+from typing import Any, Dict, List, Optional
+
+import polars as pl
+import torch
+
+from pyhealth.data import Patient
+from pyhealth.processors.cehr_processor import (
+ CehrProcessor,
+ ConceptVocab,
+ ensure_special_tokens,
+ infer_mortality_label,
+)
+
+from .base_task import BaseTask
+
+
+def _pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return seq + [pad] * (max_len - len(seq))
+
+
+def _pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return seq + [pad] * (max_len - len(seq))
+
+
+def _left_pad_int(seq: List[int], max_len: int, pad: int = 0) -> List[int]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return [pad] * (max_len - len(seq)) + seq
+
+
+def _left_pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[float]:
+ if len(seq) > max_len:
+ return seq[-max_len:]
+ return [pad] * (max_len - len(seq)) + seq
+
+
class MPFClinicalPredictionTask(BaseTask):
    """Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens.

    The task owns a :class:`~pyhealth.processors.CehrProcessor` and its
    :class:`~pyhealth.processors.ConceptVocab`. For single-worker use the
    vocabulary grows lazily in :meth:`__call__`. For multi-worker LitData
    runs, call :meth:`warm_vocab` before
    :meth:`~pyhealth.datasets.BaseDataset.set_task` so the vocabulary is
    complete and :attr:`frozen_vocab` prevents races across workers.

    Attributes:
        max_len: Truncated sequence length (must be >= 2 for boundary tokens).
        use_mpf: Selects which special token opens the sequence (the token
            literals appear stripped to empty strings in this copy — see the
            NOTE in ``__call__``).
        processor: The CEHR processor owning the shared concept vocabulary.
        frozen_vocab: If True, do not add new concept ids (post-warmup parallel path).
    """

    task_name: str = "MPFClinicalPredictionFHIR"
    # NOTE(review): values mix ("tensor", kwargs) tuples and plain "tensor"
    # strings — presumably both spellings are accepted by the processor
    # factory; confirm against the input-schema convention.
    input_schema: Dict[str, Any] = {
        "concept_ids": ("tensor", {"dtype": torch.long}),
        "token_type_ids": ("tensor", {"dtype": torch.long}),
        "time_stamps": "tensor",
        "ages": "tensor",
        "visit_orders": ("tensor", {"dtype": torch.long}),
        "visit_segments": ("tensor", {"dtype": torch.long}),
    }
    output_schema: Dict[str, str] = {"label": "binary"}

    def __init__(
        self,
        max_len: int = 512,
        use_mpf: bool = True,
        processor: Optional[CehrProcessor] = None,
    ) -> None:
        if max_len < 2:
            raise ValueError("max_len must be >= 2 for MPF boundary tokens")
        self.max_len = max_len
        self.use_mpf = use_mpf
        self.processor = processor or CehrProcessor(max_len=max_len)
        # Special-token ids; populated by warm_vocab or _ensure_processor.
        self._specials: Optional[Dict[str, int]] = None

    # ------------------------------------------------------------------
    # Backward-compatible property aliases
    # ------------------------------------------------------------------

    @property
    def vocab(self) -> ConceptVocab:
        return self.processor.vocab

    @vocab.setter
    def vocab(self, value: ConceptVocab) -> None:
        self.processor.vocab = value

    @property
    def frozen_vocab(self) -> bool:
        return self.processor.frozen_vocab

    @frozen_vocab.setter
    def frozen_vocab(self, value: bool) -> None:
        self.processor.frozen_vocab = value

    # ------------------------------------------------------------------
    # Vocabulary warm-up (call before set_task for multi-worker safety)
    # ------------------------------------------------------------------

    def warm_vocab(self, dataset: Any, num_workers: int = 1) -> None:
        """Warm CEHR vocabulary from *dataset* before calling ``set_task``.

        For single-worker pipelines this is **optional** — ``__call__`` grows
        the vocabulary lazily. For multi-worker pipelines call this first so
        that the vocabulary is fully populated before LitData forks workers::

            task = MPFClinicalPredictionTask(max_len=512)
            task.warm_vocab(ds, num_workers=4)
            sample_dataset = ds.set_task(task, num_workers=4)

        Args:
            dataset: A :class:`~pyhealth.datasets.BaseDataset` instance whose
                ``global_event_df`` has already been built.
            num_workers: Number of workers that will be passed to ``set_task``.
                When > 1, ``frozen_vocab`` is set after warming so that worker
                processes look up tokens instead of racing on
                :class:`~pyhealth.processors.ConceptVocab`.
        """
        # Third-party (litdata) import kept local so the task module stays
        # importable without it.
        from litdata.processing.data_processor import in_notebook

        # Notebooks run single-worker regardless of the requested count.
        worker_count = 1 if in_notebook() else num_workers
        filtered = self.pre_filter(dataset.global_event_df)
        warmup_pids = (
            filtered.select("patient_id")
            .unique()
            .collect(engine="streaming")
            .to_series()
            .sort()
            .to_list()
        )
        patient_count = len(warmup_pids)
        effective_workers = min(worker_count, patient_count) if patient_count else 1

        clinical_cap = max(0, self.max_len - 2)
        base = dataset.global_event_df

        def _iter_warmup_patients():
            # Stream patients in id batches so only 128 patients' rows are
            # materialized at a time.
            # NOTE: itertools.batched requires Python >= 3.12.
            for batch in itertools.batched(warmup_pids, 128):
                batch_df = (
                    base.filter(pl.col("patient_id").is_in(batch))
                    .collect(engine="streaming")
                )
                for patient_df in batch_df.partition_by("patient_id"):
                    yield Patient(
                        patient_id=patient_df["patient_id"][0],
                        data_source=patient_df,
                    )

        self.processor.fit(_iter_warmup_patients(), clinical_cap=clinical_cap)
        # Freeze only when multiple workers will actually run.
        self.frozen_vocab = effective_workers > 1
        self._specials = ensure_special_tokens(self.processor.vocab)

    # ------------------------------------------------------------------
    # Per-patient sample generation
    # ------------------------------------------------------------------

    def _ensure_processor(self) -> CehrProcessor:
        # Lazily register the special tokens on first use (single-worker path).
        if self._specials is None:
            self._specials = ensure_special_tokens(self.processor.vocab)
        return self.processor

    def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
        """Build one labeled sample dict per patient.

        Args:
            patient: A tabular :class:`~pyhealth.data.Patient`.

        Returns:
            A one-element list with ``concept_ids``, tensor-ready feature lists,
            and ``label`` (0/1). Boundary tokens are always included; at
            ``max_len == 2`` the sequence holds only the two boundary tokens.
        """
        processor = self._ensure_processor()
        pid = patient.patient_id
        (
            concept_ids,
            token_types,
            time_stamps,
            ages,
            visit_orders,
            visit_segments,
        ) = processor.process(patient)

        assert self._specials is not None
        # NOTE(review): both branches index self._specials with "" because the
        # special-token literals render as empty strings in this copy (likely
        # angle-bracketed markers stripped by extraction) — as written the
        # use_mpf flag has no effect; restore the distinct token names from
        # the canonical source.
        mor_id = self._specials[""] if self.use_mpf else self._specials[""]
        reg_id = self._specials[""]
        # Boundary tokens get type/order/segment 0 and zero time features.
        z0 = 0
        zf = 0.0
        concept_ids = [mor_id] + concept_ids + [reg_id]
        token_types = [z0] + token_types + [z0]
        time_stamps = [zf] + time_stamps + [zf]
        ages = [zf] + ages + [zf]
        visit_orders = [z0] + visit_orders + [z0]
        visit_segments = [z0] + visit_segments + [z0]

        ml = self.max_len
        concept_ids = _left_pad_int(concept_ids, ml, processor.vocab.pad_id)
        token_types = _left_pad_int(token_types, ml, 0)
        time_stamps = _left_pad_float(time_stamps, ml, 0.0)
        ages = _left_pad_float(ages, ml, 0.0)
        visit_orders = _left_pad_int(visit_orders, ml, 0)
        visit_segments = _left_pad_int(visit_segments, ml, 0)

        label = infer_mortality_label(patient)
        return [
            {
                "patient_id": pid,
                "visit_id": f"{pid}-0",
                "concept_ids": concept_ids,
                "token_type_ids": token_types,
                "time_stamps": time_stamps,
                "ages": ages,
                "visit_orders": visit_orders,
                "visit_segments": visit_segments,
                "label": label,
            }
        ]
diff --git a/pyproject.toml b/pyproject.toml
index 98f88d47b..47e40a62d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"more-itertools~=10.8.0",
"einops>=0.8.0",
"linear-attention-transformer>=0.19.1",
+ "orjson~=3.10",
]
license = "BSD-3-Clause"
license-files = ["LICENSE.md"]
diff --git a/tests/core/test_ehrmamba_cehr.py b/tests/core/test_ehrmamba_cehr.py
new file mode 100644
index 000000000..4b1f5065e
--- /dev/null
+++ b/tests/core/test_ehrmamba_cehr.py
@@ -0,0 +1,124 @@
+import unittest
+
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+
def _tiny_samples(seq: int = 16) -> tuple:
    """Return ``(samples, task)``: two fixed-length samples (labels 0 and 1).

    NOTE(review): the special-token literals were stripped by tooling;
    reconstructed as <mor>/<reg>, matching the ``use_mpf=True`` path of
    ``MPFClinicalPredictionTask.__call__`` -- confirm against
    ``ensure_special_tokens()``.
    """
    from pyhealth.processors.cehr_processor import ConceptVocab, ensure_special_tokens

    task = MPFClinicalPredictionTask(max_len=seq, use_mpf=True)
    task.vocab = ConceptVocab()
    sp = ensure_special_tokens(task.vocab)
    mid = task.vocab.add_token("test|filler")
    samples = []
    for lab in (0, 1):
        samples.append(
            {
                "patient_id": f"p{lab}",
                "visit_id": f"v{lab}",
                # <mor> opens and <reg> closes; filler tokens in between.
                "concept_ids": [sp["<mor>"]] + [mid] * (seq - 2) + [sp["<reg>"]],
                "token_type_ids": [0] * seq,
                "time_stamps": [0.0] * seq,
                "ages": [50.0] * seq,
                "visit_orders": [0] * seq,
                "visit_segments": [0] * seq,
                "label": lab,
            }
        )
    return samples, task
+
+
class TestEHRMambaCEHR(unittest.TestCase):
    """Unit tests for the EHRMambaCEHR model: pooling, end-to-end FHIR
    pipeline, and forward/backward/eval passes on tiny synthetic samples."""

    def test_readout_pools_rightmost_non_pad(self) -> None:
        """MPF padding between tokens must not make pooling pick a pad position."""

        from pyhealth.models.utils import (
            get_last_visit,
            get_rightmost_masked_timestep,
        )

        # One batch, four timesteps; position 2 is masked out (pad).
        h = torch.tensor([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [99.0, 0.0]]])
        m = torch.tensor([[True, True, False, True]])
        out = get_rightmost_masked_timestep(h, m)
        # Rightmost True position is index 3 -> hidden state [99.0, 0.0].
        self.assertTrue(torch.allclose(out[0], torch.tensor([99.0, 0.0])))
        # get_last_visit must disagree here, proving the two pool differently.
        wrong = get_last_visit(h, m)
        self.assertFalse(torch.allclose(out[0], wrong[0]))

    def test_end_to_end_fhir_pipeline(self) -> None:
        # Full path: NDJSON fixture -> MIMIC4FHIRDataset -> task samples ->
        # SampleDataset -> dataloader -> model forward + backward.
        import tempfile
        from pathlib import Path

        from pyhealth.datasets import MIMIC4FHIRDataset, create_sample_dataset
        from pyhealth.datasets import get_dataloader

        from tests.core.test_mimic4_fhir_ndjson_fixtures import run_task, write_two_class_ndjson

        task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(
                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
            )
            samples = run_task(ds, task)
            sample_ds = create_sample_dataset(
                samples=samples,
                input_schema=task.input_schema,
                output_schema=task.output_schema,
                dataset_name="fhir_test",
            )
            # Smallest embedding table that covers every id seen in samples.
            vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
            model = EHRMambaCEHR(
                dataset=sample_ds,
                vocab_size=vocab_size,
                embedding_dim=64,
                num_layers=1,
            )
            batch = next(
                iter(get_dataloader(sample_ds, batch_size=2, shuffle=False))
            )
            out = model(**batch)
            self.assertIn("loss", out)
            out["loss"].backward()

    def test_forward_backward(self) -> None:
        # Two synthetic samples (one per class) through a tiny model.
        samples, task = _tiny_samples()
        ds = create_sample_dataset(
            samples=samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
        )
        vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
        model = EHRMambaCEHR(
            dataset=ds,
            vocab_size=vocab_size,
            embedding_dim=64,
            num_layers=1,
            state_size=8,
        )
        batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
        out = model(**batch)
        self.assertEqual(out["logit"].shape[0], 2)
        out["loss"].backward()

    def test_eval_mode(self) -> None:
        # Inference path under eval()/no_grad must still emit probabilities.
        samples, task = _tiny_samples()
        ds = create_sample_dataset(
            samples=samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
        )
        vocab_size = max(max(s["concept_ids"]) for s in samples) + 1
        model = EHRMambaCEHR(dataset=ds, vocab_size=vocab_size, embedding_dim=32, num_layers=1)
        model.eval()
        with torch.no_grad():
            batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
            out = model(**batch)
        self.assertIn("y_prob", out)


if __name__ == "__main__":
    unittest.main()
diff --git a/tests/core/test_mimic4_fhir_dataset.py b/tests/core/test_mimic4_fhir_dataset.py
new file mode 100644
index 000000000..ceb0a9b25
--- /dev/null
+++ b/tests/core/test_mimic4_fhir_dataset.py
@@ -0,0 +1,694 @@
+import gzip
+import tempfile
+import unittest
+from pathlib import Path
+from typing import Dict, List
+
+import orjson
+import polars as pl
+
+from pyhealth.data import Patient
+from pyhealth.datasets import MIMIC4FHIRDataset
+from pyhealth.datasets.fhir_utils import (
+ _flatten_resource_to_table_row,
+)
+from pyhealth.processors.cehr_processor import (
+ ConceptVocab,
+ build_cehr_sequences,
+ collect_cehr_timeline_events,
+ infer_mortality_label,
+)
+
+from tests.core.test_mimic4_fhir_ndjson_fixtures import (
+ ndjson_two_class_text,
+ run_task,
+ write_one_patient_ndjson,
+ write_two_class_ndjson,
+)
+
+
+def _third_patient_loinc_resources() -> List[Dict[str, object]]:
+ return [
+ {
+ "resourceType": "Patient",
+ "id": "p-synth-3",
+ "birthDate": "1960-01-01",
+ },
+ {
+ "resourceType": "Encounter",
+ "id": "e3",
+ "subject": {"reference": "Patient/p-synth-3"},
+ "period": {"start": "2020-08-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Observation",
+ "id": "o3",
+ "subject": {"reference": "Patient/p-synth-3"},
+ "encounter": {"reference": "Encounter/e3"},
+ "effectiveDateTime": "2020-08-01T12:00:00Z",
+ "code": {"coding": [{"system": "http://loinc.org", "code": "999-9"}]},
+ },
+ ]
+
+
def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
    """Write the two-class NDJSON fixture plus the third (LOINC-only) patient."""
    serialized = [
        orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources()
    ]
    body = ndjson_two_class_text().strip().split("\n") + serialized
    target = directory / name
    target.write_text("\n".join(body) + "\n", encoding="utf-8")
    return target
+
+
def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient:
    """Wrap raw event-row dicts in a Patient backed by a Polars frame."""
    frame = pl.DataFrame(rows)
    return Patient(patient_id=patient_id, data_source=frame)
+
+
class TestDeceasedBooleanFlattening(unittest.TestCase):
    """Flattening of FHIR ``Patient.deceasedBoolean`` into the per-resource
    table row, including non-conformant string inputs, and its downstream
    effect on mortality-label inference."""

    def test_string_false_not_coerced_by_python_bool(self) -> None:
        """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``."""
        row = _flatten_resource_to_table_row(
            {
                "resourceType": "Patient",
                "id": "p-str-false",
                "deceasedBoolean": "false",
            }
        )
        self.assertIsNotNone(row)
        # row is a (table_name, payload) pair; only the payload matters here.
        _table, payload = row
        self.assertEqual(payload.get("deceased_boolean"), "false")

    def test_string_true_parsed(self) -> None:
        # String "true" should round-trip unchanged into the flattened row.
        row = _flatten_resource_to_table_row(
            {
                "resourceType": "Patient",
                "id": "p-str-true",
                "deceasedBoolean": "true",
            }
        )
        self.assertIsNotNone(row)
        self.assertEqual(row[1].get("deceased_boolean"), "true")

    def test_json_booleans_unchanged(self) -> None:
        # Proper JSON booleans flatten to the canonical lowercase strings.
        for raw, expected in ((True, "true"), (False, "false")):
            with self.subTest(raw=raw):
                row = _flatten_resource_to_table_row(
                    {
                        "resourceType": "Patient",
                        "id": "p-bool",
                        "deceasedBoolean": raw,
                    }
                )
                self.assertIsNotNone(row)
                self.assertEqual(row[1].get("deceased_boolean"), expected)

    def test_unknown_deceased_type_stored_as_none(self) -> None:
        # Garbage values (e.g. an object) must degrade to None, not crash.
        row = _flatten_resource_to_table_row(
            {
                "resourceType": "Patient",
                "id": "p-garbage",
                "deceasedBoolean": {"unexpected": "object"},
            }
        )
        self.assertIsNotNone(row)
        self.assertIsNone(row[1].get("deceased_boolean"))

    def test_infer_mortality_respects_string_false_row(self) -> None:
        # The "false" string stored above must yield a negative (0) label.
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "event_type": "patient",
                    "timestamp": "2020-01-01T00:00:00",
                    "patient/deceased_boolean": "false",
                },
            ],
        )
        self.assertEqual(infer_mortality_label(patient), 0)
+
+
class TestMIMIC4FHIRDataset(unittest.TestCase):
    """End-to-end tests for MIMIC4FHIRDataset NDJSON ingest, the flattened
    global_event_df schema, CEHR sequence building, and set_task plumbing."""

    def test_concept_vocab_from_json_empty_token_to_id(self) -> None:
        # NOTE(review): the two token literals below were stripped by tooling
        # (the original source had empty strings asserted twice). Reconstructed
        # as the two default specials implied by ``_next_id == 2`` --
        # confirm against ConceptVocab's constructor.
        v = ConceptVocab.from_json({"token_to_id": {}})
        self.assertIn("<pad>", v.token_to_id)
        self.assertIn("<unk>", v.token_to_id)
        self.assertEqual(v._next_id, 2)

    def test_concept_vocab_from_json_empty_respects_next_id(self) -> None:
        v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50})
        self.assertEqual(v._next_id, 50)

    def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None:
        from pyhealth.datasets.fhir_utils import sorted_ndjson_files

        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8")
            (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8")
            (root / "notes.txt").write_text("z", encoding="utf-8")
            wide = sorted_ndjson_files(root, "**/*.ndjson.gz")
            # Overlapping patterns must be de-duplicated to a single match.
            narrow = sorted_ndjson_files(
                root,
                ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"],
            )
            self.assertEqual(len(wide), 2)
            self.assertEqual(len(narrow), 1)
            self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz")

    def test_dataset_accepts_glob_patterns_kwarg(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            write_one_patient_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(
                root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp
            )
            self.assertEqual(ds.glob_patterns, ["*.ndjson"])
            _ = ds.global_event_df.collect(engine="streaming")

    def test_dataset_rejects_both_glob_kwargs(self) -> None:
        # glob_pattern and glob_patterns are mutually exclusive.
        with tempfile.TemporaryDirectory() as tmp:
            with self.assertRaises(ValueError):
                MIMIC4FHIRDataset(
                    root=tmp,
                    glob_pattern="*.ndjson",
                    glob_patterns=["*.ndjson"],
                    cache_dir=tmp,
                )

    def test_disk_fixture_resolves_events_per_patient(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            write_one_patient_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            sub = ds.global_event_df.filter(pl.col("patient_id") == "p-synth-1").collect(
                engine="streaming"
            )
            self.assertGreaterEqual(len(sub), 2)
            self.assertIn("condition/concept_key", sub.columns)

    def test_prepared_flat_tables_exist(self) -> None:
        # Ingest must materialize one flattened Parquet per resource type.
        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            _ = ds.global_event_df.collect(engine="streaming")
            prepared = ds.prepared_tables_dir
            self.assertTrue((prepared / "patient.parquet").is_file())
            self.assertTrue((prepared / "encounter.parquet").is_file())
            self.assertTrue((prepared / "condition.parquet").is_file())

    def test_build_cehr_non_empty(self) -> None:
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        with tempfile.TemporaryDirectory() as tmp:
            write_one_patient_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            task = MPFClinicalPredictionTask(max_len=64, use_mpf=True)
            run_task(ds, task)
            self.assertIsInstance(task.vocab, ConceptVocab)
            self.assertGreater(task.vocab.vocab_size, 2)

    def test_set_task_vocab_warm_on_litdata_cache_hit(self) -> None:
        # Re-running warm_vocab + set_task against a warm cache must yield the
        # same vocab size as the first (cold) run.
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            task_kw = {"max_len": 64, "use_mpf": True}
            ds1 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            task1 = MPFClinicalPredictionTask(**task_kw)
            task1.warm_vocab(ds1, num_workers=1)
            ds1.set_task(task1, num_workers=1)
            warm_size = task1.vocab.vocab_size
            self.assertGreater(warm_size, 6)
            ds2 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            task2 = MPFClinicalPredictionTask(**task_kw)
            task2.warm_vocab(ds2, num_workers=1)
            ds2.set_task(task2, num_workers=1)
            self.assertEqual(task2.vocab.vocab_size, warm_size)

    def test_mortality_heuristic(self) -> None:
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            samples = run_task(ds, MPFClinicalPredictionTask(max_len=64, use_mpf=False))
            self.assertEqual({s["label"] for s in samples}, {0, 1})

    def test_infer_deceased(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            dead = ds.get_patient("p-synth-2")
            self.assertEqual(infer_mortality_label(dead), 1)

    def test_disk_ndjson_gz_physionet_style(self) -> None:
        # PhysioNet distributes *.ndjson.gz; gzip ingest must work too.
        with tempfile.TemporaryDirectory() as tmp:
            gz_path = Path(tmp) / "fixture.ndjson.gz"
            with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
                gz.write(ndjson_two_class_text())
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5)
            self.assertGreaterEqual(len(ds.unique_patient_ids), 1)

    def test_disk_ndjson_temp_dir(self) -> None:
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", max_patients=5)
            self.assertEqual(len(ds.unique_patient_ids), 2)
            samples = run_task(ds, MPFClinicalPredictionTask(max_len=48, use_mpf=True))
            self.assertGreaterEqual(len(samples), 1)
            for sample in samples:
                self.assertIn("concept_ids", sample)
                self.assertIn("label", sample)

    def test_global_event_df_schema_and_flattened_columns(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            df = ds.global_event_df.collect(engine="streaming")
            self.assertGreater(len(df), 0)
            self.assertIn("patient_id", df.columns)
            self.assertIn("timestamp", df.columns)
            self.assertIn("event_type", df.columns)
            self.assertIn("condition/concept_key", df.columns)
            self.assertIn("observation/concept_key", df.columns)
            self.assertIn("patient/deceased_boolean", df.columns)

    def test_set_task_produces_correct_samples(self) -> None:
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp), name="fx.ndjson")
            ds = MIMIC4FHIRDataset(
                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1
            )
            task = MPFClinicalPredictionTask(max_len=48, use_mpf=True)
            task.warm_vocab(ds, num_workers=1)
            sample_ds = ds.set_task(task, num_workers=1)
            samples = sorted(
                [sample_ds[i] for i in range(len(sample_ds))],
                key=lambda s: s["patient_id"],
            )
            self.assertEqual(len(samples), 2)
            for s in samples:
                self.assertIn("concept_ids", s)
                self.assertIn("label", s)
            labels = {int(s["label"]) for s in samples}
            self.assertEqual(labels, {0, 1})

    def test_set_task_multi_worker_sets_frozen_vocab(self) -> None:
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(
                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2
            )
            task = MPFClinicalPredictionTask(max_len=48, use_mpf=True)
            task.warm_vocab(ds, num_workers=2)
            ds.set_task(task, num_workers=2)
            self.assertTrue(task.frozen_vocab)

    def test_mpf_pre_filter_vocab_warmup_excludes_dropped_patients(self) -> None:
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        class TwoPatientMPFTask(MPFClinicalPredictionTask):
            def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
                return df.filter(pl.col("patient_id").is_in(["p-synth-1", "p-synth-2"]))

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_plus_third_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(
                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1
            )
            self.assertEqual(len(ds.unique_patient_ids), 3)
            task = TwoPatientMPFTask(max_len=48, use_mpf=True)
            task.warm_vocab(ds, num_workers=1)
            ds.set_task(task, num_workers=1)
            # p-synth-3's LOINC code must be absent; p-synth-2's must be kept.
            self.assertNotIn("http://loinc.org|999-9", task.vocab.token_to_id)
            self.assertIn("http://loinc.org|789-0", task.vocab.token_to_id)

    def test_mpf_pre_filter_single_patient_limits_effective_workers(self) -> None:
        """Pre-filter that yields one patient should cap effective_workers to 1.

        We verify the effective_workers logic directly rather than via
        ``set_task`` because ``set_task`` with a 1-patient cohort produces
        only one label class (p-synth-1 is alive → label=0), which causes
        ``BinaryLabelProcessor.fit`` to raise "Expected 2 unique labels, got 1".
        The invariant under test belongs to the ``set_task`` override in
        ``MIMIC4FHIRDataset``; the Polars pre-filter and worker-count
        formula are both exercised here without triggering that constraint.
        """
        from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask

        class OnePatientMPFTask(MPFClinicalPredictionTask):
            def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
                return df.filter(pl.col("patient_id") == "p-synth-1")

        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(
                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2
            )
            task = OnePatientMPFTask(max_len=48, use_mpf=True)
            warmup_pids = (
                task.pre_filter(ds.global_event_df)
                .select("patient_id")
                .unique()
                .collect(engine="streaming")
                .to_series()
                .sort()
                .to_list()
            )
            self.assertEqual(warmup_pids, ["p-synth-1"])
            # One patient, two requested workers: effective_workers = min(2, 1) = 1
            effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1
            self.assertEqual(effective_workers, 1)

    def test_encounter_reference_requires_exact_id(self) -> None:
        # "e10" must not fuzzy-match "e1"; the token appears exactly once.
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "e1",
                    "encounter/encounter_class": "AMB",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-07-02T10:00:00",
                    "encounter/encounter_id": "e10",
                    "encounter/encounter_class": "IMP",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-07-02T11:00:00",
                    "condition/encounter_id": "e10",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99",
                },
            ],
        )
        vocab = ConceptVocab()
        concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64)
        tid = vocab["http://hl7.org/fhir/sid/icd-10-cm|I99"]
        self.assertEqual(concept_ids.count(tid), 1)

    def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "ea",
                    "encounter/encounter_class": "AMB",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-07-01T10:00:00",
                    "encounter/encounter_id": "eb",
                    "encounter/encounter_class": "IMP",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-06-15T12:00:00",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
                },
            ],
        )
        vocab = ConceptVocab()
        concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64)
        self.assertEqual(concept_ids.count(vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]), 1)

    def test_cehr_sequence_shapes(self) -> None:
        # All six parallel CEHR channels must share one length.
        with tempfile.TemporaryDirectory() as tmp:
            write_one_patient_ndjson(Path(tmp))
            ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
            patient = ds.get_patient("p-synth-1")
            vocab = ConceptVocab()
            concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments = (
                build_cehr_sequences(patient, vocab, max_len=32)
            )
            n = len(concept_ids)
            self.assertEqual(len(token_types), n)
            self.assertEqual(len(time_stamps), n)
            self.assertEqual(len(ages), n)
            self.assertEqual(len(visit_orders), n)
            self.assertEqual(len(visit_segments), n)
            self.assertGreater(n, 0)

    def test_build_cehr_max_len_zero_no_clinical_tokens(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "e1",
                    "encounter/encounter_class": "AMB",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-06-01T11:00:00",
                    "condition/encounter_id": "e1",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
                },
            ],
        )
        vocab = ConceptVocab()
        c, _, _, _, _, vs = build_cehr_sequences(patient, vocab, max_len=0)
        self.assertEqual(c, [])
        self.assertEqual(vs, [])

    def test_visit_segments_alternate_by_visit_index(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "e0",
                    "encounter/encounter_class": "AMB",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-06-01T11:00:00",
                    "condition/encounter_id": "e0",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-07-01T10:00:00",
                    "encounter/encounter_id": "e1",
                    "encounter/encounter_class": "IMP",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-07-01T11:00:00",
                    "condition/encounter_id": "e1",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20",
                },
            ],
        )
        vocab = ConceptVocab()
        _, _, _, _, _, visit_segments = build_cehr_sequences(patient, vocab, max_len=64)
        self.assertEqual(visit_segments, [0, 0, 1, 1])

    def test_unlinked_visit_idx_matches_sequential_counter(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": None,
                    "encounter/encounter_id": "e_bad",
                    "encounter/encounter_class": "AMB",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-03-01T10:00:00",
                    "encounter/encounter_id": "e_ok",
                    "encounter/encounter_class": "IMP",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-03-05T11:00:00",
                    "condition/encounter_id": "e_ok",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-03-15T12:00:00",
                    "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
                },
            ],
        )
        vocab = ConceptVocab()
        concept_ids, _, _, _, visit_orders, visit_segments = build_cehr_sequences(
            patient, vocab, max_len=64
        )
        i10 = vocab["http://hl7.org/fhir/sid/icd-10-cm|I10"]
        z00 = vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]
        i_link = concept_ids.index(i10)
        i_free = concept_ids.index(z00)
        self.assertEqual(visit_orders[i_link], visit_orders[i_free])
        self.assertEqual(visit_segments[i_link], visit_segments[i_free])

    def test_medication_request_uses_medication_codeable_concept(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "e1",
                    "encounter/encounter_class": "IMP",
                },
                {
                    "patient_id": "p1",
                    "event_type": "medication_request",
                    "timestamp": "2020-06-01T11:00:00",
                    "medication_request/encounter_id": "e1",
                    "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111",
                },
                {
                    "patient_id": "p1",
                    "event_type": "medication_request",
                    "timestamp": "2020-06-01T12:00:00",
                    "medication_request/encounter_id": "e1",
                    "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222",
                },
            ],
        )
        vocab = ConceptVocab()
        c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
        ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111"
        kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222"
        self.assertNotEqual(vocab[ka], vocab[kb])
        self.assertEqual(c.count(vocab[ka]), 1)
        self.assertEqual(c.count(vocab[kb]), 1)

    def test_medication_request_medication_reference_token(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "e1",
                    "encounter/encounter_class": "IMP",
                },
                {
                    "patient_id": "p1",
                    "event_type": "medication_request",
                    "timestamp": "2020-06-01T11:00:00",
                    "medication_request/encounter_id": "e1",
                    "medication_request/concept_key": "MedicationRequest/reference|med-abc",
                },
            ],
        )
        vocab = ConceptVocab()
        c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
        key = "MedicationRequest/reference|med-abc"
        self.assertIn(vocab[key], c)
        self.assertEqual(c.count(vocab[key]), 1)

    def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None:
        patient = _patient_from_rows(
            "p1",
            [
                {
                    "patient_id": "p1",
                    "event_type": "patient",
                    "timestamp": None,
                    "patient/birth_date": "1950-01-01",
                },
                {
                    "patient_id": "p1",
                    "event_type": "encounter",
                    "timestamp": "2020-06-01T10:00:00",
                    "encounter/encounter_id": "e1",
                    "encounter/encounter_class": "AMB",
                },
                {
                    "patient_id": "p1",
                    "event_type": "condition",
                    "timestamp": "2020-06-01T11:00:00",
                    "condition/encounter_id": "e1",
                    "condition/concept_key": "a|1",
                },
                {
                    "patient_id": "p1",
                    "event_type": "observation",
                    "timestamp": "2020-06-01T12:00:00",
                    "observation/encounter_id": "e1",
                    "observation/concept_key": "b|2",
                },
            ],
        )
        events = collect_cehr_timeline_events(patient)
        self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"])


if __name__ == "__main__":
    unittest.main()
diff --git a/tests/core/test_mimic4_fhir_ndjson_fixtures.py b/tests/core/test_mimic4_fhir_ndjson_fixtures.py
new file mode 100644
index 000000000..450ea5aff
--- /dev/null
+++ b/tests/core/test_mimic4_fhir_ndjson_fixtures.py
@@ -0,0 +1,112 @@
+"""NDJSON file bodies for :mod:`tests.core.test_mimic4_fhir_dataset` (disk-only ingest)."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+import orjson
+
+
+# ---------------------------------------------------------------------------
+# Synthetic in-memory FHIR resources
+# ---------------------------------------------------------------------------
+
+
+def _one_patient_resources() -> List[Dict[str, Any]]:
+ return [
+ {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"},
+ {
+ "resourceType": "Encounter",
+ "id": "e1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "period": {"start": "2020-06-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Condition",
+ "id": "c1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "encounter": {"reference": "Encounter/e1"},
+ "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]},
+ "onsetDateTime": "2020-06-01T11:00:00Z",
+ },
+ ]
+
+
def _two_patient_resources() -> List[Dict[str, Any]]:
    """Two-class cohort: p-synth-1 (alive) plus p-synth-2 (deceasedBoolean=True)."""
    deceased = [
        {
            "resourceType": "Patient",
            "id": "p-synth-2",
            "birthDate": "1940-05-05",
            "deceasedBoolean": True,
        },
        {
            "resourceType": "Encounter",
            "id": "e-dead",
            "subject": {"reference": "Patient/p-synth-2"},
            "period": {"start": "2020-07-01T10:00:00Z"},
            "class": {"code": "IMP"},
        },
        {
            "resourceType": "Observation",
            "id": "o-dead",
            "subject": {"reference": "Patient/p-synth-2"},
            "encounter": {"reference": "Encounter/e-dead"},
            "effectiveDateTime": "2020-07-01T12:00:00Z",
            "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]},
        },
    ]
    return _one_patient_resources() + deceased
+
+
+# ---------------------------------------------------------------------------
+# Text serialisers
+# ---------------------------------------------------------------------------
+
+
def ndjson_one_patient_text() -> str:
    """Serialize the single-patient fixture as newline-terminated NDJSON."""
    lines = [orjson.dumps(r).decode("utf-8") for r in _one_patient_resources()]
    return "\n".join(lines) + "\n"
+
+
def ndjson_two_class_text() -> str:
    """Serialize the two-class fixture as newline-terminated NDJSON."""
    lines = [orjson.dumps(r).decode("utf-8") for r in _two_patient_resources()]
    return "\n".join(lines) + "\n"
+
+
+# ---------------------------------------------------------------------------
+# Disk writers
+# ---------------------------------------------------------------------------
+
+
def write_two_class_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
    """Write the two-class NDJSON fixture to ``directory/name``; return the path."""
    target = directory / name
    target.write_text(ndjson_two_class_text(), encoding="utf-8")
    return target
+
+
def write_one_patient_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
    """Write the one-patient NDJSON fixture to ``directory/name``; return the path."""
    target = directory / name
    target.write_text(ndjson_one_patient_text(), encoding="utf-8")
    return target
+
+
+# ---------------------------------------------------------------------------
+# Shared test helper
+# ---------------------------------------------------------------------------
+
+
def run_task(ds: Any, task: Any) -> List[Dict[str, Any]]:
    """Run *task* over every patient in *ds* without the LitData caching pipeline.

    This helper mirrors the direct-iteration path that the old
    ``MIMIC4FHIRDataset.gather_samples`` provided. It is intentionally kept
    here (the shared fixture module) so all FHIR test files can import it
    rather than each maintaining their own copy.

    Args:
        ds: A :class:`~pyhealth.datasets.MIMIC4FHIRDataset` instance whose
            ``global_event_df`` has already been built.
        task: A :class:`~pyhealth.tasks.MPFClinicalPredictionTask` instance.

    Returns:
        Flat list of sample dicts, one per patient.
    """
    # Reset per-run task state so repeated invocations start clean.
    task._specials = None
    task.frozen_vocab = False
    samples: List[Dict[str, Any]] = []
    for patient in ds.iter_patients():
        samples.extend(task(patient))
    return samples
diff --git a/tests/core/test_mpf_task.py b/tests/core/test_mpf_task.py
new file mode 100644
index 000000000..947b284c1
--- /dev/null
+++ b/tests/core/test_mpf_task.py
@@ -0,0 +1,88 @@
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+
+from pyhealth.datasets import MIMIC4FHIRDataset
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+from tests.core.test_mimic4_fhir_ndjson_fixtures import run_task, write_two_class_ndjson
+
+
class TestMPFClinicalPredictionTask(unittest.TestCase):
    """Tests for ``MPFClinicalPredictionTask`` boundary-token placement.

    Sequences are left-padded: with MPF enabled the first non-pad token is
    the mortality task token and the last is the register token; without
    MPF the sequence starts with a CLS-style token instead.
    """

    # NOTE(review): the special-token keys were garbled to empty strings in
    # the original (``vocab[""]`` three times), which made every lookup hit
    # the same vocab entry and left the boundary assertions vacuous. The
    # angle-bracket names below are reconstructed from the test names and
    # docstrings — confirm them against the task's actual vocab.
    MOR_TOKEN = "<mor>"
    REG_TOKEN = "<reg>"
    CLS_TOKEN = "<cls>"

    def _two_patient_ds(self) -> MIMIC4FHIRDataset:
        """Build a throwaway two-patient dataset from the NDJSON fixture."""
        tmp = tempfile.mkdtemp()
        self.addCleanup(lambda p=tmp: shutil.rmtree(p, ignore_errors=True))
        write_two_class_ndjson(Path(tmp))
        return MIMIC4FHIRDataset(
            root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
        )

    @staticmethod
    def _nonpad_ends(ids, pad_id):
        """Return (first, last) indices of the non-padding span of *ids*."""
        first = next(i for i, x in enumerate(ids) if x != pad_id)
        last = next(i for i in range(len(ids) - 1, -1, -1) if ids[i] != pad_id)
        return first, last

    def test_max_len_validation(self) -> None:
        # max_len=1 cannot hold the two boundary tokens, so construction fails.
        with self.assertRaises(ValueError):
            MPFClinicalPredictionTask(max_len=1, use_mpf=True)

    def test_mpf_sets_boundary_tokens(self) -> None:
        task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
        ds = self._two_patient_ds()
        samples = run_task(ds, task)
        self.assertGreater(len(samples), 0)
        vocab = task.vocab
        mor = vocab[self.MOR_TOKEN]
        reg = vocab[self.REG_TOKEN]
        ids = samples[0]["concept_ids"]
        first, last_nz = self._nonpad_ends(ids, vocab.pad_id)
        self.assertEqual(ids[first], mor)
        self.assertEqual(ids[last_nz], reg)
        # Left padding means the register token is also the final position.
        self.assertEqual(ids[-1], reg)

    def test_no_mpf_uses_cls_reg(self) -> None:
        task = MPFClinicalPredictionTask(max_len=32, use_mpf=False)
        ds = self._two_patient_ds()
        samples = run_task(ds, task)
        vocab = task.vocab
        cls_id = vocab[self.CLS_TOKEN]
        reg_id = vocab[self.REG_TOKEN]
        ids = samples[0]["concept_ids"]
        first, last_nz = self._nonpad_ends(ids, vocab.pad_id)
        self.assertEqual(ids[first], cls_id)
        self.assertEqual(ids[last_nz], reg_id)
        self.assertEqual(ids[-1], reg_id)

    def test_schema_keys(self) -> None:
        # Every declared input key (plus the label) must be present on samples.
        task = MPFClinicalPredictionTask(max_len=16, use_mpf=True)
        ds = self._two_patient_ds()
        samples = run_task(ds, task)
        for key in task.input_schema:
            self.assertIn(key, samples[0])
        self.assertIn("label", samples[0])

    def test_max_len_two_keeps_boundary_tokens(self) -> None:
        """``clinical_cap=0`` must yield the bare boundary-token pair
        left-padded, not truncate either boundary token away."""
        task = MPFClinicalPredictionTask(max_len=2, use_mpf=True)
        ds = self._two_patient_ds()
        samples = run_task(ds, task)
        vocab = task.vocab
        mor = vocab[self.MOR_TOKEN]
        reg = vocab[self.REG_TOKEN]
        for sample in samples:
            ids = sample["concept_ids"]
            first, last_nz = self._nonpad_ends(ids, vocab.pad_id)
            self.assertEqual(ids[first], mor)
            self.assertEqual(ids[last_nz], reg)
            self.assertEqual(ids[-1], reg)
+
+
# Allow running this test module directly with `python`.
if __name__ == "__main__":
    unittest.main()