Merged
61 changes: 27 additions & 34 deletions docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst
@@ -6,56 +6,54 @@ for CEHR-style token sequences used with
:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
:class:`~pyhealth.models.EHRMambaCEHR`.

YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml`` (e.g.
``glob_pattern`` for ``**/*.ndjson.gz`` on PhysioNet exports, or ``**/*.ndjson``
if uncompressed). The class subclasses :class:`~pyhealth.datasets.BaseDataset`
and builds a standard Polars ``global_event_df`` backed by cached Parquet
(``global_event_df.parquet/part-*.parquet``), same tabular path as other
datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`, :meth:`iter_patients`,
:meth:`get_patient`, etc.
YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml``. Unlike the
earlier nested-object approach, the YAML now declares a normal ``tables:``
schema for flattened FHIR resources (``patient``, ``encounter``, ``condition``,
``observation``, ``medication_request``, ``procedure``). The class subclasses
:class:`~pyhealth.datasets.BaseDataset` and builds a standard Polars
``global_event_df`` backed by cached Parquet (``global_event_df.parquet/part-*.parquet``),
same tabular path as other datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`,
:meth:`iter_patients`, :meth:`get_patient`, etc.

**Ingest (out-of-core).** Matching ``*.ndjson`` / ``*.ndjson.gz`` files are read
**line by line**; each FHIR resource row is written to hash-partitioned Parquet
shards (``patient_id`` → stable shard via CRC32). That bounds **peak ingest RAM**
to on-disk batch buffers and shard writers (see constructor / YAML
``ingest_num_shards``), instead of materializing the full export in Python lists.
Shards are finalized into ``part-*.parquet`` under the dataset cache; there is no
full-table ``(patient_id, timestamp)`` sort on disk—per-event time order for
:class:`~pyhealth.data.Patient` comes from ``data_source.sort("timestamp")`` when
a patient slice is loaded.
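The stable CRC32 shard assignment described above (the pre-change scheme this PR replaces) can be sketched with the stdlib; the shard count and key encoding here are illustrative assumptions, not the library's exact code:

```python
import zlib


def shard_for(patient_id: str, num_shards: int = 16) -> int:
    """Map a patient_id to a stable shard index via CRC32.

    The same id always lands in the same shard, so a patient's events
    stay together on disk without a global full-table sort.
    """
    return zlib.crc32(patient_id.encode("utf-8")) % num_shards
```

Determinism is the point: re-running or resuming ingest routes each patient to the same shard file.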
**line by line**; each resource is normalized into a flattened per-resource
Parquet table under ``cache/flattened_tables/``. Those tables are then fed
through the regular YAML-driven :class:`~pyhealth.datasets.BaseDataset` loader to
materialize ``global_event_df``. This keeps FHIR aligned with PyHealth's usual
table-first pipeline instead of reparsing nested JSON per patient downstream.
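A minimal sketch of the line-by-line normalization step; the real loader writes Polars-backed Parquet tables, whereas here plain dicts stand in for the flattened rows, and the resource shapes are assumptions:

```python
import gzip
import io
import json
from collections import defaultdict


def group_ndjson_by_resource(fh):
    """Read NDJSON one line at a time, bucketing rows per FHIR resourceType."""
    tables = defaultdict(list)
    for line in fh:
        line = line.strip()
        if not line:
            continue
        resource = json.loads(line)
        tables[resource["resourceType"]].append(resource)
    return tables


# Tiny in-memory stand-in for a *.ndjson.gz export:
raw = (
    b'{"resourceType": "Patient", "id": "p1"}\n'
    b'{"resourceType": "Encounter", "id": "e1", "subject": "Patient/p1"}\n'
)
with gzip.open(io.BytesIO(gzip.compress(raw)), "rt") as fh:
    tables = group_ndjson_by_resource(fh)
```

Because only one line is decoded at a time, peak memory tracks the per-resource buffers rather than the whole export.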

**``max_patients``.** When set, the loader selects the first *N* patient ids after
a **sorted** ``unique`` over the cached table (then filters shards). That caps
stored patients and speeds downstream iteration for subsets; ingest still scans
all matching NDJSON once to populate shards unless you also narrow
``glob_pattern``.
a **sorted** ``unique`` over the flattened patient table, filters every
normalized table to that cohort, and then builds ``global_event_df`` from the
filtered tables. Ingest still scans all matching NDJSON once unless you also
override ``glob_patterns`` / ``glob_pattern`` (defaults skip non-flattened PhysioNet shards).
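The cohort cap can be pictured as a sorted unique followed by a head, then a filter over every table. This is a pure-Python stand-in under assumed column names; the real code does the equivalent in Polars:

```python
def select_cohort(patient_ids, max_patients=None):
    """First N ids after a sorted unique, mirroring max_patients selection."""
    ids = sorted(set(patient_ids))
    return set(ids if max_patients is None else ids[:max_patients])


def filter_table(rows, cohort, key="patient_id"):
    """Keep only events belonging to the selected cohort."""
    return [r for r in rows if r[key] in cohort]


events = [
    {"patient_id": "p3", "code": "A"},
    {"patient_id": "p1", "code": "B"},
    {"patient_id": "p2", "code": "C"},
    {"patient_id": "p1", "code": "D"},
]
cohort = select_cohort((r["patient_id"] for r in events), max_patients=2)
kept = filter_table(events, cohort)  # only p1 and p2 rows survive
```

Sorting before taking the head makes the subset deterministic across runs, which is what lets the cap participate in the cache fingerprint.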

**Downstream memory (still important).** Streaming ingest avoids loading the
entire NDJSON corpus into RAM at once, but other steps can still be heavy on large
cohorts: building :class:`~pyhealth.data.Patient` / :class:`~pyhealth.datasets.FHIRPatient`
from ``fhir/resource_json`` parses JSON per patient; MPF vocabulary warmup and
:meth:`set_task` walk patients/samples; training needs RAM/VRAM for the model and
batches. For a **full** PhysioNet tree, plan for **large disk** (Parquet cache),
**comfortable system RAM** for Polars/PyArrow and task pipelines, and restrict
``glob_pattern`` or ``max_patients`` when prototyping on a laptop.
entire NDJSON corpus into RAM at once, but other steps can still be heavy on
large cohorts: ``global_event_df`` materialization, MPF vocabulary warmup, and
:meth:`set_task` still walk patients and samples; training needs RAM/VRAM for the
model and batches. For a **full** PhysioNet tree, plan for **large disk**
(flattened tables plus event cache), **comfortable system RAM** for Polars/PyArrow
and task pipelines, and restrict ``glob_patterns`` / ``glob_pattern`` or ``max_patients`` when
prototyping on a laptop.

**Recommended hardware (informal)**

Order-of-magnitude guides, not guarantees. Ingest footprint is **much smaller**
than “load everything into Python”; wall time still grows with **decompressed
NDJSON volume** and shard/batch settings.
NDJSON volume** and the amount of flattened table data produced.

* **Smoke / CI**
Small on-disk fixtures (see tests and ``examples/mimic4fhir_mpf_ehrmamba.py``):
a recent laptop is sufficient.

* **Laptop-scale real FHIR subset**
A **narrow** ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps
A **narrow** ``glob_patterns`` / ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps
cache and task passes manageable. **≥ 16 GB** system RAM is a practical
comfort target for Polars + trainer + OS; validate GPU **VRAM** for your
``max_len`` and batch size.

* **Full default glob on a complete export**
* **Full default globs on a complete export**
Favor **workstations or servers** with **fast SSD**, **large disk**, and
**ample RAM** for downstream steps—not because NDJSON is fully buffered in
memory during ingest, but because total work and caches still scale with the
@@ -70,8 +68,3 @@
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pyhealth.datasets.FHIRPatient
:members:
:undoc-members:
:show-inheritance:
83 changes: 42 additions & 41 deletions examples/mimic4fhir_mpf_ehrmamba.py
@@ -26,24 +26,25 @@
needed for conclusive comparisons. Paste your table from ``--ablation``
into the PR description.

**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON to
hash-sharded Parquet (bounded RAM during ingest). This example trains via
**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON into
flattened per-resource Parquet tables (bounded RAM during ingest). This example trains via
``dataset.set_task(MPFClinicalPredictionTask)`` → LitData-backed
:class:`~pyhealth.datasets.sample_dataset.SampleDataset` →
:class:`~pyhealth.trainer.Trainer` (PyHealth’s standard path), instead of
materializing all samples with ``gather_samples()``. Prefer ``--max-patients`` to
bound ingest when possible. Very large cohorts still need RAM/disk for task
caches and MPF vocabulary warmup.

**Offline Parquet (NDJSON → Parquet already done):** pass
``--prebuilt-global-event-dir`` pointing at a directory of ``shard-*.parquet``
(from ingest / ``stream_fhir_ndjson_root_to_sharded_parquet``). The example seeds
``global_event_df.parquet/`` under the usual PyHealth cache UUID so
``BaseDataset.global_event_df`` skips re-ingest — the downstream path is still
**Offline flattened tables (NDJSON normalization already done):** pass
``--prebuilt-global-event-dir`` pointing at a directory containing the normalized
FHIR tables (``patient.parquet``, ``encounter.parquet``, ``condition.parquet``,
etc.). The example seeds ``flattened_tables/`` under the usual PyHealth cache UUID,
then lets :class:`~pyhealth.datasets.BaseDataset` rebuild
``global_event_df.parquet/`` from those tables — the downstream path is still
``global_event_df`` → :class:`~pyhealth.data.Patient` →
:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` →
:class:`~pyhealth.trainer.Trainer`. Use ``--fhir-root`` / ``--glob-pattern`` /
``--ingest-num-shards`` / ``--max-patients -1`` matching the ingest fingerprint.
``--max-patients -1`` matching the ingest fingerprint.
``--train-patient-cap`` restricts task transforms via ``task.pre_filter`` using a
label-aware deterministic patient subset. The full ``unique_patient_ids`` scan and MPF vocab warmup
in the dataset still walk the cached cohort.
@@ -60,11 +61,10 @@
export MIMIC4_FHIR_ROOT=/path/to/fhir
pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py --fhir-root "$MIMIC4_FHIR_ROOT"

# Prebuilt Parquet shards (skip NDJSON re-ingest); cap patients for a smoke train
# Prebuilt flattened FHIR tables (skip NDJSON normalization); cap patients for a smoke train
pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py \\
--prebuilt-global-event-dir /path/to/shard_parquet_dir \\
--fhir-root /same/as/ndjson/ingest/root \\
--glob-pattern 'Mimic*.ndjson.gz' --ingest-num-shards 16 --max-patients -1 \\
--prebuilt-global-event-dir /path/to/flattened_table_dir \\
--fhir-root /same/as/ndjson/ingest/root --glob-pattern 'Mimic*.ndjson.gz' --max-patients -1 \\
--train-patient-cap 2048 --epochs 2 \\
--ntfy-url 'https://ntfy.sh/your-topic'
"""
@@ -150,25 +150,25 @@
type=int,
default=500,
help=(
"Fingerprint for cache dir: cap patients during ingest (-1 = full cohort, "
"match an uncapped NDJSON→Parquet export)."
"Fingerprint for cache dir: cap patients during normalization (-1 = full cohort, "
"match an uncapped NDJSON→flattened-table export)."
),
)
_parser.add_argument(
"--prebuilt-global-event-dir",
type=str,
default=None,
help=(
"Directory with shard-*.parquet from NDJSON ingest. Seeds "
"cache/global_event_df.parquet/ so training skips re-ingest (downstream "
"unchanged: Patient + MPF + Trainer)."
"Directory with normalized flattened FHIR tables (*.parquet). Seeds "
"cache/flattened_tables/ so training skips NDJSON normalization "
"(downstream unchanged: Patient + MPF + Trainer)."
),
)
_parser.add_argument(
"--ingest-num-shards",
type=int,
default=None,
help="Fingerprint only: must match NDJSON→Parquet ingest (default: dataset YAML / heuristic).",
help="Compatibility no-op: retained for CLI stability with older runs.",
)
_parser.add_argument(
"--train-patient-cap",
@@ -216,7 +216,7 @@
import polars as pl

from pyhealth.datasets import MIMIC4FHIRDataset, get_dataloader
from pyhealth.datasets.mimic4_fhir import fhir_patient_from_patient, infer_mortality_label
from pyhealth.datasets.mimic4_fhir import infer_mortality_label
from pyhealth.models import EHRMambaCEHR
from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
from pyhealth.trainer import Trainer
@@ -251,20 +251,20 @@ def _max_patients_arg(v: int) -> Optional[int]:
return None if v is not None and v < 0 else v


def _seed_global_event_cache_from_shards(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None:
"""Link shard-*.parquet into the dataset cache as part-*.parquet (PyHealth layout)."""
def _seed_flattened_table_cache(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None:
"""Copy normalized per-resource parquet tables into the dataset cache."""

shards = sorted(prebuilt_dir.glob("shard-*.parquet"))
if not shards:
tables = sorted(prebuilt_dir.glob("*.parquet"))
if not tables:
raise FileNotFoundError(
f"No shard-*.parquet under {prebuilt_dir} — use ingest output directory."
f"No *.parquet tables under {prebuilt_dir} — expected flattened FHIR tables."
)
ge = ds.cache_dir / "global_event_df.parquet"
if ge.exists() and any(ge.glob("*.parquet")):
prepared = ds.prepared_tables_dir
if prepared.exists() and any(prepared.glob("*.parquet")):
return
ge.mkdir(parents=True, exist_ok=True)
for i, src in enumerate(shards):
dest = ge / f"part-{i:05d}.parquet"
prepared.mkdir(parents=True, exist_ok=True)
for src in tables:
dest = prepared / src.name
if dest.exists():
continue
try:
@@ -348,7 +348,7 @@ def _quick_test_ndjson_dir() -> str:

def _patient_label(ds: MIMIC4FHIRDataset, patient_id: str) -> int:
patient = ds.get_patient(patient_id)
return int(infer_mortality_label(fhir_patient_from_patient(patient)))
return int(infer_mortality_label(patient))


def _ensure_binary_label_coverage(ds: MIMIC4FHIRDataset) -> None:
@@ -552,7 +552,7 @@ def run_single_train(
ds_kw["ingest_num_shards"] = ingest_num_shards
ds = MIMIC4FHIRDataset(**ds_kw)
if prebuilt_global_event_dir:
_seed_global_event_cache_from_shards(
_seed_flattened_table_cache(
Path(prebuilt_global_event_dir).expanduser().resolve(), ds
)
if train_patient_cap is not None:
@@ -699,8 +699,8 @@ def _main_train(args: argparse.Namespace) -> None:
)
try:
print(
"pipeline: synthetic NDJSON → ingest Parquet → set_task → "
"SampleDataset → Trainer"
"pipeline: synthetic NDJSON → flattened tables → global_event_df "
"→ set_task → SampleDataset → Trainer"
)
task = MPFClinicalPredictionTask(
max_len=args.max_len,
Expand Down Expand Up @@ -744,14 +744,15 @@ def _main_train(args: argparse.Namespace) -> None:
if not pb.is_dir():
raise SystemExit(f"--prebuilt-global-event-dir not a directory: {pb}")
print(
"pipeline: offline NDJSON→Parquet shards → seed global_event_df cache → "
"set_task → SampleDataset → Trainer (no NDJSON re-ingest)"
"pipeline: offline flattened FHIR tables → seed flattened table cache "
"→ global_event_df → set_task → SampleDataset → Trainer "
"(no NDJSON normalization)"
)
_seed_global_event_cache_from_shards(pb, ds)
_seed_flattened_table_cache(pb, ds)
else:
print(
"pipeline: NDJSON root → MIMIC4FHIRDataset ingest → Parquet cache → "
"set_task → SampleDataset → Trainer"
"pipeline: NDJSON root → MIMIC4FHIRDataset flattening → global_event_df "
"→ set_task → SampleDataset → Trainer"
)
print("glob_pattern:", ds.glob_pattern, "| max_patients fingerprint:", mp)
if args.train_patient_cap is not None:
@@ -788,9 +789,9 @@ def _main_train(args: argparse.Namespace) -> None:
if len(sample_ds) == 0:
raise SystemExit(
"No training samples (0 patients or empty sequences). "
"PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (default glob **/*.ndjson.gz). "
"If your tree is plain *.ndjson, construct MIMIC4FHIRDataset with "
"glob_pattern='**/*.ndjson'."
"PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (see default glob_patterns in "
"pyhealth/datasets/configs/mimic4_fhir.yaml). If your tree is plain *.ndjson, "
"construct MIMIC4FHIRDataset with glob_pattern='**/*.ndjson'."
)

sample_ds, train_loader, val_loader, test_loader, vocab_size = (
2 changes: 1 addition & 1 deletion pyhealth/datasets/__init__.py
@@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs):
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimic4_fhir import ConceptVocab, FHIRPatient, MIMIC4FHIRDataset
from .mimic4_fhir import ConceptVocab, MIMIC4FHIRDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset