diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst index 5b03c4196..1a19fc5e9 100644 --- a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst +++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIRDataset.rst @@ -6,56 +6,54 @@ for CEHR-style token sequences used with :class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and :class:`~pyhealth.models.EHRMambaCEHR`. -YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml`` (e.g. -``glob_pattern`` for ``**/*.ndjson.gz`` on PhysioNet exports, or ``**/*.ndjson`` -if uncompressed). The class subclasses :class:`~pyhealth.datasets.BaseDataset` -and builds a standard Polars ``global_event_df`` backed by cached Parquet -(``global_event_df.parquet/part-*.parquet``), same tabular path as other -datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`, :meth:`iter_patients`, -:meth:`get_patient`, etc. +YAML defaults live in ``pyhealth/datasets/configs/mimic4_fhir.yaml``. Unlike the +earlier nested-object approach, the YAML now declares a normal ``tables:`` +schema for flattened FHIR resources (``patient``, ``encounter``, ``condition``, +``observation``, ``medication_request``, ``procedure``). The class subclasses +:class:`~pyhealth.datasets.BaseDataset` and builds a standard Polars +``global_event_df`` backed by cached Parquet (``global_event_df.parquet/part-*.parquet``), +same tabular path as other datasets: :meth:`~pyhealth.datasets.BaseDataset.set_task`, +:meth:`iter_patients`, :meth:`get_patient`, etc. **Ingest (out-of-core).** Matching ``*.ndjson`` / ``*.ndjson.gz`` files are read -**line by line**; each FHIR resource row is written to hash-partitioned Parquet -shards (``patient_id`` → stable shard via CRC32). That bounds **peak ingest RAM** -to on-disk batch buffers and shard writers (see constructor / YAML -``ingest_num_shards``), instead of materializing the full export in Python lists. 
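The line-by-line streaming pattern this hunk describes can be sketched in isolation. This is a minimal stand-in under stated assumptions, not the dataset's actual ingest code: ``iter_ndjson`` and ``count_resource_types`` are hypothetical helper names, and the only claim carried over from the text is that each NDJSON line is parsed independently so peak RAM stays bounded by one line at a time.

```python
import gzip
import json
from collections import Counter
from pathlib import Path


def iter_ndjson(path: Path):
    """Yield one parsed JSON object per non-empty line, never buffering the file."""
    if path.suffix == ".gz":
        opener = gzip.open(path, "rt", encoding="utf-8")
    else:
        opener = open(path, encoding="utf-8")
    with opener as stream:
        for line in stream:
            line = line.strip()
            if line:
                yield json.loads(line)


def count_resource_types(root: Path) -> Counter:
    """Walk matching NDJSON shards and tally resourceType, one line in memory at a time."""
    counts: Counter = Counter()
    for path in sorted(root.glob("**/*.ndjson*")):
        for resource in iter_ndjson(path):
            counts[resource.get("resourceType", "unknown")] += 1
    return counts
```

A quick sanity pass over a tiny gzip fixture is enough to confirm the pattern; real PhysioNet shards are simply larger inputs to the same loop.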
-Shards are finalized into ``part-*.parquet`` under the dataset cache; there is no -full-table ``(patient_id, timestamp)`` sort on disk—per-event time order for -:class:`~pyhealth.data.Patient` comes from ``data_source.sort("timestamp")`` when -a patient slice is loaded. +**line by line**; each resource is normalized into a flattened per-resource +Parquet table under ``cache/flattened_tables/``. Those tables are then fed +through the regular YAML-driven :class:`~pyhealth.datasets.BaseDataset` loader to +materialize ``global_event_df``. This keeps FHIR aligned with PyHealth's usual +table-first pipeline instead of reparsing nested JSON per patient downstream. **``max_patients``.** When set, the loader selects the first *N* patient ids after -a **sorted** ``unique`` over the cached table (then filters shards). That caps -stored patients and speeds downstream iteration for subsets; ingest still scans -all matching NDJSON once to populate shards unless you also narrow -``glob_pattern``. +a **sorted** ``unique`` over the flattened patient table, filters every +normalized table to that cohort, and then builds ``global_event_df`` from the +filtered tables. Ingest still scans all matching NDJSON once unless you also +override ``glob_patterns`` / ``glob_pattern`` (defaults skip non-flattened PhysioNet shards). **Downstream memory (still important).** Streaming ingest avoids loading the -entire NDJSON corpus into RAM at once, but other steps can still be heavy on large -cohorts: building :class:`~pyhealth.data.Patient` / :class:`~pyhealth.datasets.FHIRPatient` -from ``fhir/resource_json`` parses JSON per patient; MPF vocabulary warmup and -:meth:`set_task` walk patients/samples; training needs RAM/VRAM for the model and -batches. For a **full** PhysioNet tree, plan for **large disk** (Parquet cache), -**comfortable system RAM** for Polars/PyArrow and task pipelines, and restrict -``glob_pattern`` or ``max_patients`` when prototyping on a laptop. 
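The ``max_patients`` semantics above (sorted ``unique`` over patient ids, take the first *N*, then filter every flattened table to that cohort) can be illustrated with a small pure-Python stand-in. The real loader operates on Polars tables; ``restrict_to_first_n_patients`` and the dict-of-row-lists representation here are assumptions for the sketch.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def restrict_to_first_n_patients(
    tables: Dict[str, List[Row]], n: int
) -> Dict[str, List[Row]]:
    """Keep the first n patient ids (sorted unique) and filter all tables to them."""
    cohort = sorted({row["patient_id"] for row in tables["patient"]})[:n]
    keep = set(cohort)
    return {
        name: [row for row in rows if row["patient_id"] in keep]
        for name, rows in tables.items()
    }
```

Sorting before truncating is what makes the cap deterministic across runs: the same export always yields the same cohort, so cached artifacts stay reusable.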
+entire NDJSON corpus into RAM at once, but other steps can still be heavy on +large cohorts: ``global_event_df`` materialization, MPF vocabulary warmup, and +:meth:`set_task` still walk patients and samples; training needs RAM/VRAM for the +model and batches. For a **full** PhysioNet tree, plan for **large disk** +(flattened tables plus event cache), **comfortable system RAM** for Polars/PyArrow +and task pipelines, and restrict ``glob_patterns`` / ``glob_pattern`` or ``max_patients`` when +prototyping on a laptop. **Recommended hardware (informal)** Order-of-magnitude guides, not guarantees. Ingest footprint is **much smaller** than “load everything into Python”; wall time still grows with **decompressed -NDJSON volume** and shard/batch settings. +NDJSON volume** and the amount of flattened table data produced. * **Smoke / CI** Small on-disk fixtures (see tests and ``examples/mimic4fhir_mpf_ehrmamba.py``): a recent laptop is sufficient. * **Laptop-scale real FHIR subset** - A **narrow** ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps + A **narrow** ``glob_patterns`` / ``glob_pattern`` and/or ``max_patients`` in the hundreds keeps cache and task passes manageable. **≥ 16 GB** system RAM is a practical comfort target for Polars + trainer + OS; validate GPU **VRAM** for your ``max_len`` and batch size. -* **Full default glob on a complete export** +* **Full default globs on a complete export** Favor **workstations or servers** with **fast SSD**, **large disk**, and **ample RAM** for downstream steps—not because NDJSON is fully buffered in memory during ingest, but because total work and caches still scale with the @@ -70,8 +68,3 @@ NDJSON volume** and shard/batch settings. :members: :undoc-members: :show-inheritance: - -.. 
autoclass:: pyhealth.datasets.FHIRPatient - :members: - :undoc-members: - :show-inheritance: diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py index 05981e70e..ef58759e6 100644 --- a/examples/mimic4fhir_mpf_ehrmamba.py +++ b/examples/mimic4fhir_mpf_ehrmamba.py @@ -26,8 +26,8 @@ needed for conclusive comparisons. Paste your table from ``--ablation`` into the PR description. -**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON to -hash-sharded Parquet (bounded RAM during ingest). This example trains via +**Scaling:** :class:`~pyhealth.datasets.MIMIC4FHIRDataset` streams NDJSON into +flattened per-resource Parquet tables (bounded RAM during ingest). This example trains via ``dataset.set_task(MPFClinicalPredictionTask)`` → LitData-backed :class:`~pyhealth.datasets.sample_dataset.SampleDataset` → :class:`~pyhealth.trainer.Trainer` (PyHealth’s standard path), instead of @@ -35,15 +35,16 @@ bound ingest when possible. Very large cohorts still need RAM/disk for task caches and MPF vocabulary warmup. -**Offline Parquet (NDJSON → Parquet already done):** pass -``--prebuilt-global-event-dir`` pointing at a directory of ``shard-*.parquet`` -(from ingest / ``stream_fhir_ndjson_root_to_sharded_parquet``). The example seeds -``global_event_df.parquet/`` under the usual PyHealth cache UUID so -``BaseDataset.global_event_df`` skips re-ingest — the downstream path is still +**Offline flattened tables (NDJSON normalization already done):** pass +``--prebuilt-global-event-dir`` pointing at a directory containing the normalized +FHIR tables (``patient.parquet``, ``encounter.parquet``, ``condition.parquet``, +etc.). 
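The cache-seeding step described above (copy prebuilt per-resource Parquet tables into the cache, but only when the cache is not already populated) can be sketched as follows. This is a hypothetical simplification of the example script's seeding helper: ``seed_flattened_tables`` and its return convention are invented for illustration, and the idempotency check mirrors the "skip if already seeded" behavior the text describes.

```python
import shutil
from pathlib import Path


def seed_flattened_tables(prebuilt_dir: Path, cache_dir: Path) -> int:
    """Copy prebuilt *.parquet tables into the cache unless it is already populated."""
    tables = sorted(prebuilt_dir.glob("*.parquet"))
    if not tables:
        raise FileNotFoundError(f"No *.parquet tables under {prebuilt_dir}")
    if cache_dir.exists() and any(cache_dir.glob("*.parquet")):
        return 0  # cache already seeded; leave existing tables untouched
    cache_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    for src in tables:
        shutil.copy2(src, cache_dir / src.name)
        copied += 1
    return copied
```

Making the seeding idempotent matters because the same cache directory is keyed by an ingest fingerprint: a repeated run must reuse, not clobber, what the first run materialized.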
The example seeds ``flattened_tables/`` under the usual PyHealth cache UUID, +then lets :class:`~pyhealth.datasets.BaseDataset` rebuild +``global_event_df.parquet/`` from those tables — the downstream path is still ``global_event_df`` → :class:`~pyhealth.data.Patient` → :class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` → :class:`~pyhealth.trainer.Trainer``. Use ``--fhir-root`` / ``--glob-pattern`` / -``--ingest-num-shards`` / ``--max-patients -1`` matching the ingest fingerprint. +``--max-patients -1`` matching the ingest fingerprint. ``--train-patient-cap`` restricts task transforms via ``task.pre_filter`` using a label-aware deterministic patient subset. The full ``unique_patient_ids`` scan and MPF vocab warmup in the dataset still walk the cached cohort. @@ -60,11 +61,10 @@ export MIMIC4_FHIR_ROOT=/path/to/fhir pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py --fhir-root "$MIMIC4_FHIR_ROOT" - # Prebuilt Parquet shards (skip NDJSON re-ingest); cap patients for a smoke train + # Prebuilt flattened FHIR tables (skip NDJSON normalization); cap patients for a smoke train pixi run -e base python examples/mimic4fhir_mpf_ehrmamba.py \\ - --prebuilt-global-event-dir /path/to/shard_parquet_dir \\ - --fhir-root /same/as/ndjson/ingest/root \\ - --glob-pattern 'Mimic*.ndjson.gz' --ingest-num-shards 16 --max-patients -1 \\ + --prebuilt-global-event-dir /path/to/flattened_table_dir \\ + --fhir-root /same/as/ndjson/ingest/root --glob-pattern 'Mimic*.ndjson.gz' --max-patients -1 \\ --train-patient-cap 2048 --epochs 2 \\ --ntfy-url 'https://ntfy.sh/your-topic' """ @@ -150,8 +150,8 @@ type=int, default=500, help=( - "Fingerprint for cache dir: cap patients during ingest (-1 = full cohort, " - "match an uncapped NDJSON→Parquet export)." + "Fingerprint for cache dir: cap patients during normalization (-1 = full cohort, " + "match an uncapped NDJSON→flattened-table export)." 
), ) _parser.add_argument( @@ -159,16 +159,16 @@ type=str, default=None, help=( - "Directory with shard-*.parquet from NDJSON ingest. Seeds " - "cache/global_event_df.parquet/ so training skips re-ingest (downstream " - "unchanged: Patient + MPF + Trainer)." + "Directory with normalized flattened FHIR tables (*.parquet). Seeds " + "cache/flattened_tables/ so training skips NDJSON normalization " + "(downstream unchanged: Patient + MPF + Trainer)." ), ) _parser.add_argument( "--ingest-num-shards", type=int, default=None, - help="Fingerprint only: must match NDJSON→Parquet ingest (default: dataset YAML / heuristic).", + help="Compatibility no-op: retained for CLI stability with older runs.", ) _parser.add_argument( "--train-patient-cap", @@ -216,7 +216,7 @@ import polars as pl from pyhealth.datasets import MIMIC4FHIRDataset, get_dataloader -from pyhealth.datasets.mimic4_fhir import fhir_patient_from_patient, infer_mortality_label +from pyhealth.datasets.mimic4_fhir import infer_mortality_label from pyhealth.models import EHRMambaCEHR from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask from pyhealth.trainer import Trainer @@ -251,20 +251,20 @@ def _max_patients_arg(v: int) -> Optional[int]: return None if v is not None and v < 0 else v -def _seed_global_event_cache_from_shards(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None: - """Link shard-*.parquet into the dataset cache as part-*.parquet (PyHealth layout).""" +def _seed_flattened_table_cache(prebuilt_dir: Path, ds: MIMIC4FHIRDataset) -> None: + """Copy normalized per-resource parquet tables into the dataset cache.""" - shards = sorted(prebuilt_dir.glob("shard-*.parquet")) - if not shards: + tables = sorted(prebuilt_dir.glob("*.parquet")) + if not tables: raise FileNotFoundError( - f"No shard-*.parquet under {prebuilt_dir} — use ingest output directory." + f"No *.parquet tables under {prebuilt_dir} — expected flattened FHIR tables." 
) - ge = ds.cache_dir / "global_event_df.parquet" - if ge.exists() and any(ge.glob("*.parquet")): + prepared = ds.prepared_tables_dir + if prepared.exists() and any(prepared.glob("*.parquet")): return - ge.mkdir(parents=True, exist_ok=True) - for i, src in enumerate(shards): - dest = ge / f"part-{i:05d}.parquet" + prepared.mkdir(parents=True, exist_ok=True) + for src in tables: + dest = prepared / src.name if dest.exists(): continue try: @@ -348,7 +348,7 @@ def _quick_test_ndjson_dir() -> str: def _patient_label(ds: MIMIC4FHIRDataset, patient_id: str) -> int: patient = ds.get_patient(patient_id) - return int(infer_mortality_label(fhir_patient_from_patient(patient))) + return int(infer_mortality_label(patient)) def _ensure_binary_label_coverage(ds: MIMIC4FHIRDataset) -> None: @@ -552,7 +552,7 @@ def run_single_train( ds_kw["ingest_num_shards"] = ingest_num_shards ds = MIMIC4FHIRDataset(**ds_kw) if prebuilt_global_event_dir: - _seed_global_event_cache_from_shards( + _seed_flattened_table_cache( Path(prebuilt_global_event_dir).expanduser().resolve(), ds ) if train_patient_cap is not None: @@ -699,8 +699,8 @@ def _main_train(args: argparse.Namespace) -> None: ) try: print( - "pipeline: synthetic NDJSON → ingest Parquet → set_task → " - "SampleDataset → Trainer" + "pipeline: synthetic NDJSON → flattened tables → global_event_df " + "→ set_task → SampleDataset → Trainer" ) task = MPFClinicalPredictionTask( max_len=args.max_len, @@ -744,14 +744,15 @@ def _main_train(args: argparse.Namespace) -> None: if not pb.is_dir(): raise SystemExit(f"--prebuilt-global-event-dir not a directory: {pb}") print( - "pipeline: offline NDJSON→Parquet shards → seed global_event_df cache → " - "set_task → SampleDataset → Trainer (no NDJSON re-ingest)" + "pipeline: offline flattened FHIR tables → seed flattened table cache " + "→ global_event_df → set_task → SampleDataset → Trainer " + "(no NDJSON normalization)" ) - _seed_global_event_cache_from_shards(pb, ds) + 
_seed_flattened_table_cache(pb, ds) else: print( - "pipeline: NDJSON root → MIMIC4FHIRDataset ingest → Parquet cache → " - "set_task → SampleDataset → Trainer" + "pipeline: NDJSON root → MIMIC4FHIRDataset flattening → global_event_df " + "→ set_task → SampleDataset → Trainer" ) print("glob_pattern:", ds.glob_pattern, "| max_patients fingerprint:", mp) if args.train_patient_cap is not None: @@ -788,9 +789,9 @@ def _main_train(args: argparse.Namespace) -> None: if len(sample_ds) == 0: raise SystemExit( "No training samples (0 patients or empty sequences). " - "PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (default glob **/*.ndjson.gz). " - "If your tree is plain *.ndjson, construct MIMIC4FHIRDataset with " - "glob_pattern='**/*.ndjson'." + "PhysioNet MIMIC-IV FHIR uses *.ndjson.gz (see default glob_patterns in " + "pyhealth/datasets/configs/mimic4_fhir.yaml). If your tree is plain *.ndjson, " + "construct MIMIC4FHIRDataset with glob_pattern='**/*.ndjson'." ) sample_ds, train_loader, val_loader, test_loader, vocab_size = ( diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index d3070e545..620a0a908 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs): from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset -from .mimic4_fhir import ConceptVocab, FHIRPatient, MIMIC4FHIRDataset +from .mimic4_fhir import ConceptVocab, MIMIC4FHIRDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset diff --git a/pyhealth/datasets/configs/mimic4_fhir.yaml b/pyhealth/datasets/configs/mimic4_fhir.yaml index 617717fe1..989ab79cf 100644 --- a/pyhealth/datasets/configs/mimic4_fhir.yaml +++ b/pyhealth/datasets/configs/mimic4_fhir.yaml @@ -1,23 +1,118 @@ -# 
MIMIC-IV FHIR NDJSON ingest for :class:`~pyhealth.datasets.MIMIC4FHIRDataset`. -# This YAML is read by ``read_fhir_settings_yaml`` (FHIR layout), not the CSV -# ``DatasetConfig`` table schema used by MIMIC-IV relational loaders. -version: "fhir_r4" -# PhysioNet layout (credentialled export): a single ``fhir/`` directory containing -# gzip NDJSON shards, e.g. ``MimicPatient.ndjson.gz``, ``MimicEncounter.ndjson.gz``, -# ``MimicCondition.ndjson.gz``, ``MimicObservationLabevents.ndjson.gz``, etc. -# Set ``root`` to that ``fhir/`` path (or the unpacked folder that holds these files). -# -# ``**/*.ndjson.gz`` matches all such files whether laid out flat or in subfolders. -# For uncompressed trees, pass ``glob_pattern="**/*.ndjson"`` to MIMIC4FHIRDataset. -glob_pattern: "**/*.ndjson.gz" -# Parallel ingest: rows are hashed by ``patient_id`` into this many Parquet shards -# (bounded RAM, avoids one global on-disk sort). Omit to auto-pick from CPU count. -# ingest_num_shards: 16 -# Optional: restrict resource types when scanning (future use) -resource_types: - - Patient - - Encounter - - Condition - - Observation - - MedicationRequest - - Procedure +# MIMIC-IV FHIR Resource Flattening Configuration +# ================================================ +# +# This YAML defines the normalized schema for MIMIC-IV FHIR exports after streaming +# ingestion. Raw NDJSON/NDJSON.GZ resources are parsed and flattened into six +# per-resource-type Parquet tables (see ``version`` below), then loaded through the +# standard BaseDataset pipeline for task construction. +# +# For ingest details, see the docstring of ``stream_fhir_ndjson_to_flat_tables()`` +# in ``pyhealth/datasets/mimic4_fhir.py``. + +version: "fhir_r4_flattened" + +# Glob Pattern(s) for NDJSON File Discovery +# =========================================== +# +# ``glob_patterns`` (list) or ``glob_pattern`` (string): Patterns to match NDJSON files +# under the ingest root directory. Patterns are applied via pathlib.Path.glob(). 
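The pattern-matching behavior the YAML comments describe can be sketched with plain ``pathlib``. This is a hedged stand-in, not the dataset's discovery code: ``discover_ndjson_files`` is a hypothetical name, and the two patterns shown are a subset of the config's defaults, used only to demonstrate how targeted globs skip shard families the flattener would drop anyway.

```python
from pathlib import Path


def discover_ndjson_files(root: Path, patterns) -> list:
    """Union of Path.glob() matches across patterns, deduplicated and sorted."""
    found = set()
    for pattern in patterns:
        # ``**/`` makes each pattern recursive, so flat and nested layouts both work.
        found.update(root.glob(pattern))
    return sorted(found)
```

Sorting the union gives a stable ingest order regardless of filesystem enumeration, which keeps repeated runs over the same export deterministic.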
+# +# Default: Six targeted patterns matching PhysioNet MIMIC-IV FHIR Mimic* shard families +# that map to flattened tables. This avoids decompressing and parsing ~10% of PhysioNet +# exports (MedicationAdministration, Specimen, Organization, …) that are skipped by +# the flattener. +# +# Alternatives: +# - For non-PhysioNet naming, use a single broad pattern: +# glob_pattern: "**/*.ndjson.gz" +# - To test on a subset, use a narrower list: +# glob_patterns: +# - "**/MimicPatient*.ndjson.gz" +# - "**/MimicObservation*.ndjson.gz" +# +# Notes: +# - Patterns use ``**/`` for recursive search (works in both flat and nested layouts). +# - Can be overridden at runtime via MIMIC4FHIRDataset(glob_pattern=...) or +# MIMIC4FHIRDataset(glob_patterns=[...]). + +glob_patterns: + - "**/MimicPatient*.ndjson.gz" + - "**/MimicEncounter*.ndjson.gz" + - "**/MimicCondition*.ndjson.gz" + - "**/MimicObservation*.ndjson.gz" + - "**/MimicMedicationRequest*.ndjson.gz" + - "**/MimicProcedure*.ndjson.gz" + +# Flattened Table Schema +# ====================== +# +# Each table is normalized from a single FHIR resource type. Columns are: +# - patient_id (str): Foreign key to patient (derived from subject.reference or id). +# - [timestamp] (str): ISO 8601 datetime string (coerced; nullable). +# - attributes (List[str]): Additional columns from the resource. +# +# Unsupported resource types (Medication, MedicationAdministration, Specimen, …) +# are silently dropped during ingest; only tables listed here are written. 
+ +tables: + patient: + file_path: "patient.parquet" + patient_id: "patient_id" + timestamp: "birth_date" + attributes: + - "patient_fhir_id" + - "birth_date" + - "gender" + - "deceased_boolean" + - "deceased_datetime" + + encounter: + file_path: "encounter.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "encounter_class" + - "encounter_end" + + condition: + file_path: "condition.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + observation: + file_path: "observation.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + medication_request: + file_path: "medication_request.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + procedure: + file_path: "procedure.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" diff --git a/pyhealth/datasets/mimic4_fhir.py b/pyhealth/datasets/mimic4_fhir.py index 1fb023aa3..4c98884c1 100644 --- a/pyhealth/datasets/mimic4_fhir.py +++ b/pyhealth/datasets/mimic4_fhir.py @@ -1,92 +1,130 @@ -"""MIMIC-IV FHIR (NDJSON) ingestion for CEHR-style sequences. - -Loads newline-delimited JSON (plain ``*.ndjson`` or gzip ``*.ndjson.gz``, as on -PhysioNet), or Bundle ``entry`` resources, groups by Patient id, and builds -token timelines for MPF / EHRMambaCEHR. - -:class:`MIMIC4FHIRDataset` materializes a PyHealth-standard **global event -table** as Parquet (``patient_id``, ``timestamp``, ``event_type``, -``fhir/resource_json``) under the dataset cache. 
Ingest **hash-partitions** rows -by ``patient_id`` into multiple shard files (bounded memory, no full-table sort); -ingest schedules one process pool task per NDJSON/GZ file (worker count capped by -``os.cpu_count()``). ``global_event_df`` may scan several ``part-*.parquet`` files -like other multi-part caches. Per-patient time order still comes from -:class:`~pyhealth.data.Patient` (``data_source.sort("timestamp")``). The same -``global_event_df`` / :class:`~pyhealth.data.Patient` / :meth:`set_task` path -as CSV-backed datasets applies downstream. - -Settings such as ``glob_pattern`` live in ``configs/mimic4_fhir.yaml`` and are -read by :func:`read_fhir_settings_yaml`. For PhysioNet MIMIC-IV on FHIR, set -``root`` to the ``fhir/`` directory that contains ``Mimic*.ndjson.gz`` shards -(e.g. ``MimicPatient.ndjson.gz``, ``MimicEncounter.ndjson.gz``); the default -``glob_pattern`` is ``**/*.ndjson.gz``. For tests, write small -``*.ndjson`` / ``*.ndjson.gz`` files and point ``root`` / ``glob_pattern`` at them. - -**JSON / ingest.** NDJSON lines use ``orjson``. Row reading, batching, and Parquet -shard hashing follow a **fixed implementation** in this module (not configurable on -:class:`MIMIC4FHIRDataset`). :data:`FHIR_SCHEMA_VERSION` is part of the dataset cache -fingerprint so ingest or schema changes invalidate cached Parquet. - -Plain single-resource lines (not Bundle / ``{"resource":...}`` wrappers) store the -**original line text** in ``fhir/resource_json`` where safe (skipping -``orjson.dumps``). Otherwise the resource is serialized with ``orjson.dumps``. -Column buffers build Arrow tables without per-row ``from_pylist`` dict materialization. +"""MIMIC-IV FHIR ingestion using flattened resource tables. + +The maintainer-requested architecture for FHIR in PyHealth is: + +1. Stream NDJSON/NDJSON.GZ FHIR resources from disk. +2. Normalize each resource type into a 2D table (Patient, Encounter, Condition, + Observation, MedicationRequest, Procedure). +3. 
Feed those tables through the standard YAML-driven + :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task processing + operates on :class:`~pyhealth.data.Patient` and ``global_event_df`` rows rather + than custom nested FHIR objects. + +This module implements that flow. The dataset builds normalized resource tables +under its cache directory, then loads them through a regular ``tables:`` config in +``configs/mimic4_fhir.yaml``. """ from __future__ import annotations -import concurrent.futures +import functools import gzip -import itertools import hashlib +import itertools import logging -import multiprocessing +import operator import os -import zlib -import platformdirs import shutil import uuid from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple +import dask.dataframe as dd +import narwhals as nw import orjson +import pandas as pd +import platformdirs import polars as pl -from litdata.processing.data_processor import in_notebook -from tqdm import tqdm import pyarrow as pa import pyarrow.parquet as pq +from litdata.processing.data_processor import in_notebook from yaml import safe_load from ..data import Patient from .base_dataset import BaseDataset -# Normalized event table (BaseDataset / Patient contract) -FHIR_EVENT_TYPE: str = "fhir" -FHIR_RESOURCE_JSON_COL: str = "fhir/resource_json" -FHIR_SCHEMA_VERSION: int = 2 - -_FHIR_NDJSON_USE_RAW_LINE: bool = True -_FHIR_NDJSON_USE_COLUMN_BUFFERS: bool = True - logger = logging.getLogger(__name__) DEFAULT_PAD = 0 DEFAULT_UNK = 1 +FHIR_SCHEMA_VERSION = 3 + +FHIR_TABLES: List[str] = [ + "patient", + "encounter", + "condition", + "observation", + "medication_request", + "procedure", +] + +# Tables that carry ``patient_id`` for cohort discovery when ``patient.parquet`` is absent. 
+FHIR_TABLES_FOR_PATIENT_IDS: List[str] = [t for t in FHIR_TABLES if t != "patient"] + +FHIR_TABLE_FILE_NAMES: Dict[str, str] = { + table_name: f"{table_name}.parquet" for table_name in FHIR_TABLES +} +FHIR_TABLE_COLUMNS: Dict[str, List[str]] = { + "patient": [ + "patient_id", + "patient_fhir_id", + "birth_date", + "gender", + "deceased_boolean", + "deceased_datetime", + ], + "encounter": [ + "patient_id", + "resource_id", + "encounter_id", + "event_time", + "encounter_class", + "encounter_end", + ], + "condition": [ + "patient_id", + "resource_id", + "encounter_id", + "event_time", + "concept_key", + ], + "observation": [ + "patient_id", + "resource_id", + "encounter_id", + "event_time", + "concept_key", + ], + "medication_request": [ + "patient_id", + "resource_id", + "encounter_id", + "event_time", + "concept_key", + ], + "procedure": [ + "patient_id", + "resource_id", + "encounter_id", + "event_time", + "concept_key", + ], +} -def _fhir_json_loads_ndjson_line(line: str) -> Any: - return orjson.loads(line.encode("utf-8")) - - -def _fhir_json_dumps_resource(res: Dict[str, Any]) -> str: - return orjson.dumps(res).decode("utf-8") +EVENT_TYPE_TO_TOKEN_TYPE = { + "encounter": 1, + "condition": 2, + "medication_request": 3, + "observation": 4, + "procedure": 5, +} -def _fhir_json_loads_resource_column(raw: str | bytes) -> Any: - b = raw.encode("utf-8") if isinstance(raw, str) else raw - return orjson.loads(b) +def _fhir_json_loads_ndjson_line(line: str) -> Any: + return orjson.loads(line.encode("utf-8")) def _parse_dt(s: Optional[str]) -> Optional[datetime]: @@ -133,32 +171,113 @@ def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]: return _coding_key(codings[0]) +def _ref_id(ref: Optional[str]) -> Optional[str]: + if not ref: + return None + if "/" in ref: + return ref.rsplit("/", 1)[-1] + return ref + + +def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]: + if not isinstance(raw, dict): + return None + resource = raw.get("resource") if 
"resource" in raw else raw + return resource if isinstance(resource, dict) else None + + +def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + """Yield resource dictionaries from one parsed NDJSON object.""" + + if "entry" in obj: + for entry in obj.get("entry") or []: + resource = entry.get("resource") + if isinstance(resource, dict): + yield resource + return + + resource = _unwrap_resource_dict(obj) + if resource is not None: + yield resource + + +def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]: + if path.suffix == ".gz": + opener = gzip.open(path, "rt", encoding="utf-8", errors="replace") + else: + opener = open(path, encoding="utf-8", errors="replace") + with opener as stream: + for line in stream: + line = line.strip() + if not line: + continue + parsed = _fhir_json_loads_ndjson_line(line) + if isinstance(parsed, dict): + yield parsed + + def _clinical_concept_key(res: Dict[str, Any]) -> Optional[str]: - """Resolve a stable vocabulary key; resource-type-specific per FHIR R4.""" - - rt = res.get("resourceType") - if rt == "MedicationRequest": - mcc = res.get("medicationCodeableConcept") - if isinstance(mcc, dict): - ck = _first_coding(mcc) - if ck: - return ck - mref = res.get("medicationReference") - if isinstance(mref, dict): - ref = mref.get("reference") - if ref: - rid = _ref_id(ref) - return f"MedicationRequest/reference|{rid or ref}" + """Resolve a stable token key from a flattened FHIR resource.""" + + resource_type = res.get("resourceType") + if resource_type == "MedicationRequest": + medication_cc = res.get("medicationCodeableConcept") + if isinstance(medication_cc, dict): + concept_key = _first_coding(medication_cc) + if concept_key: + return concept_key + medication_ref = res.get("medicationReference") + if isinstance(medication_ref, dict): + reference = medication_ref.get("reference") + if reference: + ref_id = _ref_id(reference) + return f"MedicationRequest/reference|{ref_id or reference}" return 
None + code = res.get("code") if isinstance(code, dict): return _first_coding(code) return None +def patient_id_for_resource( + resource: Dict[str, Any], + resource_type: Optional[str] = None, +) -> Optional[str]: + resource_type = resource_type or resource.get("resourceType") + if resource_type == "Patient": + patient_id = resource.get("id") + return str(patient_id) if patient_id is not None else None + if resource_type == "Encounter": + return _ref_id((resource.get("subject") or {}).get("reference")) + if resource_type in {"Condition", "Observation", "MedicationRequest", "Procedure"}: + return _ref_id((resource.get("subject") or {}).get("reference")) + return None + + +def _resource_time_string( + resource: Dict[str, Any], + resource_type: Optional[str] = None, +) -> Optional[str]: + resource_type = resource_type or resource.get("resourceType") + if resource_type == "Patient": + return resource.get("birthDate") + if resource_type == "Encounter": + return (resource.get("period") or {}).get("start") + if resource_type == "Condition": + return resource.get("onsetDateTime") or resource.get("recordedDate") + if resource_type == "Observation": + return resource.get("effectiveDateTime") or resource.get("issued") + if resource_type == "MedicationRequest": + return resource.get("authoredOn") + if resource_type == "Procedure": + return resource.get("performedDateTime") or resource.get("recordedDate") + return None + + @dataclass class ConceptVocab: - """Maps FHIR coding keys to dense ids. 
Supports save/load for streaming builds.""" + """Maps concept keys to dense ids.""" token_to_id: Dict[str, int] = field(default_factory=dict) pad_id: int = DEFAULT_PAD @@ -173,10 +292,10 @@ def __post_init__(self) -> None: def add_token(self, key: str) -> int: if key in self.token_to_id: return self.token_to_id[key] - tid = self._next_id + token_id = self._next_id self._next_id += 1 - self.token_to_id[key] = tid - return tid + self.token_to_id[key] = token_id + return token_id def __getitem__(self, key: str) -> int: return self.token_to_id.get(key, self.unk_id) @@ -189,15 +308,15 @@ def to_json(self) -> Dict[str, Any]: return {"token_to_id": self.token_to_id, "next_id": self._next_id} @classmethod - def from_json(cls, data: Dict[str, Any]) -> ConceptVocab: - v = cls() + def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab": + vocab = cls() loaded = dict(data.get("token_to_id") or {}) if not loaded: - v._next_id = int(data.get("next_id", 2)) - return v - v.token_to_id = loaded - v._next_id = int(data.get("next_id", max(loaded.values()) + 1)) - return v + vocab._next_id = int(data.get("next_id", 2)) + return vocab + vocab.token_to_id = loaded + vocab._next_id = int(data.get("next_id", max(loaded.values()) + 1)) + return vocab def save(self, path: str) -> None: Path(path).parent.mkdir(parents=True, exist_ok=True) @@ -206,438 +325,32 @@ def save(self, path: str) -> None: ) @classmethod - def load(cls, path: str) -> ConceptVocab: + def load(cls, path: str) -> "ConceptVocab": return cls.from_json(orjson.loads(Path(path).read_bytes())) def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]: - """Reserve special tokens for MPF / readout.""" - - out: Dict[str, int] = {} + specials: Dict[str, int] = {} for name in ("", "", "", ""): - out[name] = vocab.add_token(name) - return out - - -@dataclass -class FHIRPatient: - """Minimal patient container for FHIR resources (not pyhealth.data.Patient).""" - - patient_id: str - resources: List[Dict[str, Any]] - 
birth_date: Optional[datetime] = None - - def get_patient_resource(self) -> Optional[Dict[str, Any]]: - for r in self.resources: - if r.get("resourceType") == "Patient": - return r - return None - - -def parse_ndjson_line(line: str) -> Any: - line = line.strip() - if not line: - return None - return _fhir_json_loads_ndjson_line(line) - - -def iter_ndjson_file(path: Path) -> Generator[Dict[str, Any], None, None]: - if path.suffix == ".gz": - opener = gzip.open(path, "rt", encoding="utf-8", errors="replace") - else: - opener = open(path, encoding="utf-8", errors="replace") - with opener as f: - for line in f: - obj = parse_ndjson_line(line) - if obj is not None: - yield obj - - -def iter_ndjson_raw_line_and_obj( - path: Path, -) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - """Yield ``(line_text, parsed_obj)`` for each non-empty NDJSON object line.""" - - - if path.suffix == ".gz": - opener = gzip.open(path, "rt", encoding="utf-8", errors="replace") - else: - opener = open(path, encoding="utf-8", errors="replace") - with opener as f: - for line in f: - raw = line.strip() - if not raw: - continue - obj: Any = _fhir_json_loads_ndjson_line(raw) - if isinstance(obj, dict): - yield raw, obj - - -def _ndjson_root_has_top_level_bundle_or_resource_wrapper(obj: Dict[str, Any]) -> bool: - """True if root object is a Bundle or a ``{"resource": ...}`` wrapper line.""" - - return "entry" in obj or "resource" in obj - - -def _ref_id(ref: Optional[str]) -> Optional[str]: - if not ref: - return None - if "/" in ref: - return ref.rsplit("/", 1)[-1] - return ref - - -def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]: - if not isinstance(raw, dict): - return None - r = raw.get("resource") if "resource" in raw else raw - return r if isinstance(r, dict) else None - - -def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]: - """Yield FHIR resource dicts from one parsed NDJSON object. 
- - Expands ``Bundle`` ``entry`` resources; otherwise yields a single resource. - """ - - if isinstance(obj, dict) and "entry" in obj: - for ent in obj.get("entry") or []: - res = ent.get("resource") - if isinstance(res, dict): - yield res - else: - r = _unwrap_resource_dict(obj) - if r is not None: - yield r - - -def patient_id_for_resource( - res: Dict[str, Any], - resource_type: Optional[str] = None, -) -> Optional[str]: - """Logical patient id for sharding and tabular ``patient_id`` (FHIR subject refs).""" - - rid: Optional[str] = None - rt = resource_type if resource_type is not None else res.get("resourceType") - if rt == "Patient": - pid = res.get("id") - rid = str(pid) if pid is not None else None - elif rt == "Encounter": - rid = _ref_id((res.get("subject") or {}).get("reference")) - elif rt in ("Condition", "Observation", "MedicationRequest", "Procedure"): - rid = _ref_id((res.get("subject") or {}).get("reference")) - return rid - - -RESOURCE_TYPE_TO_TOKEN_TYPE = { - "Encounter": 1, - "Condition": 2, - "MedicationRequest": 3, - "Observation": 4, - "Procedure": 5, -} - - -def _event_time( - res: Dict[str, Any], - resource_type: Optional[str] = None, -) -> Optional[datetime]: - rt = resource_type if resource_type is not None else res.get("resourceType") - if rt == "Encounter": - return _parse_dt((res.get("period") or {}).get("start")) - if rt == "Condition": - return _parse_dt(res.get("onsetDateTime") or res.get("recordedDate")) - if rt == "Observation": - return _parse_dt(res.get("effectiveDateTime") or res.get("issued")) - if rt == "MedicationRequest": - return _parse_dt(res.get("authoredOn")) - if rt == "Procedure": - return _parse_dt(res.get("performedDateTime") or res.get("recordedDate")) - return None - - -def resource_row_timestamp( - res: Dict[str, Any], - resource_type: Optional[str] = None, -) -> Optional[datetime]: - """Timestamp for ``Patient.data_source`` sort order and Parquet ``timestamp``.""" - - t = _event_time(res, resource_type) - if t is 
not None: - return t - rt = resource_type if resource_type is not None else res.get("resourceType") - if rt == "Patient": - return _parse_dt(res.get("birthDate")) - return None - - -def _sequential_visit_idx_for_time( - t: Optional[datetime], visit_encounters: List[Tuple[datetime, int]] -) -> int: - """Map event time to the sequential ``visit_idx`` used in the main encounter loop. - - ``visit_encounters`` lists ``(encounter_start, visit_idx)`` only for encounters - with a valid ``period.start``, in the same order as :func:`build_cehr_sequences` - assigns ``visit_idx`` (sorted ``encounters``, skipping those without start). This - must not use raw indices into the full ``encounters`` list, or indices diverge - when some encounters lack a start time. - """ - - if not visit_encounters: - return 0 - if t is None: - return visit_encounters[-1][1] - t = _as_naive(t) - chosen = visit_encounters[0][1] - for es, vidx in visit_encounters: - if es <= t: - chosen = vidx - else: - break - return chosen - - -def _concept_vocab_key_for_cehr(res: Dict[str, Any]) -> str: - """Dense-vocabulary string for one resource (aligned with MPF CEHR tokens).""" - - rt = res.get("resourceType") - ck = _clinical_concept_key(res) - if rt == "Observation": - ck = ck or "obs|unknown" - if ck is None: - ck = f"{(rt or 'res').lower()}|unknown" - return ck - - -def collect_cehr_timeline_events( - patient: FHIRPatient, -) -> List[Tuple[datetime, Dict[str, Any], int]]: - """Clinical events in CEHR timeline order (time, resource, visit_idx). - - Matches encounter grouping and visit indexing used by - :func:`build_cehr_sequences` / MPF so vocabulary warmup stays consistent - without materializing per-token feature lists. 
- """ - - events: List[Tuple[datetime, Dict[str, Any], int]] = [] - encounters = [r for r in patient.resources if r.get("resourceType") == "Encounter"] - encounters.sort(key=lambda e: _event_time(e) or datetime.min) - - visit_encounters: List[Tuple[datetime, int]] = [] - _v = 0 - for enc in encounters: - _es = _event_time(enc) - if _es is None: - continue - visit_encounters.append((_as_naive(_es), _v)) - _v += 1 - - visit_idx = 0 - for enc in encounters: - eid = enc.get("id") - enc_start = _event_time(enc) - if enc_start is None: - continue - for r in patient.resources: - if r.get("resourceType") == "Patient": - continue - rt = r.get("resourceType") - if rt not in RESOURCE_TYPE_TO_TOKEN_TYPE: - continue - if rt == "Encounter" and r.get("id") != eid: - continue - if rt != "Encounter": - enc_ref = (r.get("encounter") or {}).get("reference") - if enc_ref: - ref_eid = _ref_id(enc_ref) - if ref_eid is None or str(eid) != str(ref_eid): - continue - else: - continue - t = _event_time(r) - if t is None: - t = enc_start - events.append((t, r, visit_idx)) - visit_idx += 1 - - for r in patient.resources: - if r.get("resourceType") == "Patient": - continue - rt = r.get("resourceType") - if rt not in RESOURCE_TYPE_TO_TOKEN_TYPE: - continue - if rt == "Encounter": - continue - enc_ref = (r.get("encounter") or {}).get("reference") - if enc_ref: - continue - t_evt = _event_time(r) - v_idx = _sequential_visit_idx_for_time(t_evt, visit_encounters) - t = t_evt - if t is None: - if visit_encounters: - for es, v in visit_encounters: - if v == v_idx: - t = es - break - else: - t = visit_encounters[-1][0] - if t is None: - continue - events.append((t, r, v_idx)) - - events.sort(key=lambda x: x[0]) - return events - - -def warm_mpf_vocab_from_fhir_patient( - vocab: ConceptVocab, - patient: FHIRPatient, - clinical_cap: int, -) -> None: - """Register concept keys for the tail clinical window (MPF parallel path). - - Lighter than :func:`build_cehr_sequences`: no per-position feature lists. 
- """ - - tail = ( - collect_cehr_timeline_events(patient)[-clinical_cap:] - if clinical_cap > 0 - else [] - ) - for _, res, _ in tail: - vocab.add_token(_concept_vocab_key_for_cehr(res)) - - -def build_cehr_sequences( - patient: FHIRPatient, - vocab: ConceptVocab, - max_len: int, - *, - base_time: Optional[datetime] = None, - grow_vocab: bool = True, -) -> Tuple[ - List[int], - List[int], - List[float], - List[float], - List[int], - List[int], -]: - """Flatten patient resources into CEHR-aligned lists (pre-padding). - - Args: - max_len: Maximum number of **clinical** tokens emitted (after time sort and - tail slice). Use ``0`` to emit no clinical tokens (empty lists; avoids - Python's ``events[-0:]`` which would incorrectly take the full timeline). - Downstream MPF tasks reserve two slots for ````/```` and - ````, so pass ``max_len - 2`` there when the final tensor length - is fixed. - grow_vocab: If True (default), assign new dense ids via ``add_token``. If - False, use only existing ids (```` for unknown codes)—for parallel - ``set_task`` workers after a main-process vocabulary warmup. 
- """ - - birth = patient.birth_date - if birth is None: - pr = patient.get_patient_resource() - if pr: - birth = _parse_dt(pr.get("birthDate")) - - events = collect_cehr_timeline_events(patient) - - if base_time is None and events: - base_time = events[0][0] - elif base_time is None: - base_time = datetime.now() - - concept_ids: List[int] = [] - token_types: List[int] = [] - time_stamps: List[float] = [] - ages: List[float] = [] - visit_orders: List[int] = [] - visit_segments: List[int] = [] - - base_time = _as_naive(base_time) - birth = _as_naive(birth) - tail = events[-max_len:] if max_len > 0 else [] - for t, res, v_idx in tail: - t = _as_naive(t) - rt = res.get("resourceType") - ck = _concept_vocab_key_for_cehr(res) - if grow_vocab: - cid = vocab.add_token(ck) - else: - cid = vocab[ck] - tt = RESOURCE_TYPE_TO_TOKEN_TYPE.get(rt, 0) - ts = float((t - base_time).total_seconds()) if base_time and t else 0.0 - age_y = 0.0 - if birth and t: - age_y = (t - birth).days / 365.25 - seg = v_idx % 2 - concept_ids.append(cid) - token_types.append(tt) - time_stamps.append(ts) - ages.append(age_y) - visit_orders.append(min(v_idx, 511)) - visit_segments.append(seg) - - return concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments - - -def fhir_patient_from_patient(patient: Patient) -> FHIRPatient: - """Rebuild :class:`FHIRPatient` from a tabular :class:`~pyhealth.data.Patient`.""" - - resources: List[Dict[str, Any]] = [] - for row in patient.data_source.iter_rows(named=True): - raw = row.get(FHIR_RESOURCE_JSON_COL) - if not raw: - continue - resources.append(_fhir_json_loads_resource_column(raw)) - birth: Optional[datetime] = None - for r in resources: - if r.get("resourceType") == "Patient": - birth = _parse_dt(r.get("birthDate")) - break - return FHIRPatient( - patient_id=patient.patient_id, resources=resources, birth_date=birth - ) - - -def infer_mortality_label(patient: FHIRPatient) -> int: - """Heuristic binary label: 1 if deceased or explicit death 
condition.""" - - pr = patient.get_patient_resource() - if pr and pr.get("deceasedBoolean") is True: - return 1 - if pr and pr.get("deceasedDateTime"): - return 1 - for r in patient.resources: - if r.get("resourceType") != "Condition": - continue - ck = (_first_coding(r.get("code") or {}) or "").lower() - if any(x in ck for x in ("death", "deceased", "mortality")): - return 1 - return 0 + specials[name] = vocab.add_token(name) + return specials def synthetic_mpf_one_patient_resources() -> List[Dict[str, Any]]: - """Minimal FHIR NDJSON rows for one patient (tests, quick-start examples).""" - patient: Dict[str, Any] = { "resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female", } - enc: Dict[str, Any] = { + encounter: Dict[str, Any] = { "resourceType": "Encounter", "id": "e1", "subject": {"reference": "Patient/p-synth-1"}, "period": {"start": "2020-06-01T10:00:00Z"}, "class": {"code": "IMP"}, } - cond: Dict[str, Any] = { + condition: Dict[str, Any] = { "resourceType": "Condition", "id": "c1", "subject": {"reference": "Patient/p-synth-1"}, @@ -647,26 +360,24 @@ def synthetic_mpf_one_patient_resources() -> List[Dict[str, Any]]: }, "onsetDateTime": "2020-06-01T11:00:00Z", } - return [patient, enc, cond] + return [patient, encounter, condition] def synthetic_mpf_two_patient_resources() -> List[Dict[str, Any]]: - """Two-patient fixture including a deceased patient (binary label smoke tests).""" - - dead_p: Dict[str, Any] = { + dead_patient: Dict[str, Any] = { "resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True, } - dead_enc: Dict[str, Any] = { + dead_encounter: Dict[str, Any] = { "resourceType": "Encounter", "id": "e-dead", "subject": {"reference": "Patient/p-synth-2"}, "period": {"start": "2020-07-01T10:00:00Z"}, "class": {"code": "IMP"}, } - dead_obs: Dict[str, Any] = { + dead_observation: Dict[str, Any] = { "resourceType": "Observation", "id": "o-dead", "subject": {"reference": 
"Patient/p-synth-2"}, @@ -674,13 +385,19 @@ def synthetic_mpf_two_patient_resources() -> List[Dict[str, Any]]: "effectiveDateTime": "2020-07-01T12:00:00Z", "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]}, } - return [*synthetic_mpf_one_patient_resources(), dead_p, dead_enc, dead_obs] + return [ + *synthetic_mpf_one_patient_resources(), + dead_patient, + dead_encounter, + dead_observation, + ] def synthetic_mpf_one_patient_ndjson_text() -> str: return ( "\n".join( - orjson.dumps(r).decode("utf-8") for r in synthetic_mpf_one_patient_resources() + orjson.dumps(resource).decode("utf-8") + for resource in synthetic_mpf_one_patient_resources() ) + "\n" ) @@ -689,321 +406,559 @@ def synthetic_mpf_one_patient_ndjson_text() -> str: def synthetic_mpf_two_patient_ndjson_text() -> str: return ( "\n".join( - orjson.dumps(r).decode("utf-8") - for r in synthetic_mpf_two_patient_resources() + orjson.dumps(resource).decode("utf-8") + for resource in synthetic_mpf_two_patient_resources() ) + "\n" ) def read_fhir_settings_yaml(path: Optional[str] = None) -> Dict[str, Any]: - """Load FHIR YAML (glob pattern, version); not a CSV ``DatasetConfig`` schema. - - Args: - path: Defaults to ``configs/mimic4_fhir.yaml`` beside this module. - - Returns: - Parsed mapping. 
- """ if path is None: path = os.path.join(os.path.dirname(__file__), "configs", "mimic4_fhir.yaml") - with open(path, encoding="utf-8") as f: - data = safe_load(f) + with open(path, encoding="utf-8") as stream: + data = safe_load(stream) return data if isinstance(data, dict) else {} -def fhir_events_arrow_schema() -> pa.Schema: - """Arrow schema for normalized FHIR event rows.""" +def _table_schema(table_name: str) -> pa.Schema: + return pa.schema([(column, pa.string()) for column in FHIR_TABLE_COLUMNS[table_name]]) - return pa.schema( - [ - ("patient_id", pa.string()), - ("event_type", pa.string()), - ("timestamp", pa.timestamp("ms")), - (FHIR_RESOURCE_JSON_COL, pa.string()), - ] - ) + +class _BufferedParquetWriter: + def __init__(self, path: Path, schema: pa.Schema, batch_size: int = 50_000) -> None: + self.path = path + self.schema = schema + self.batch_size = batch_size + self.rows: List[Dict[str, Any]] = [] + self.writer: Optional[pq.ParquetWriter] = None + self.path.parent.mkdir(parents=True, exist_ok=True) + + def add(self, row: Dict[str, Any]) -> None: + self.rows.append(row) + if len(self.rows) >= self.batch_size: + self.flush() + + def flush(self) -> None: + if not self.rows: + return + table = pa.Table.from_pylist(self.rows, schema=self.schema) + if self.writer is None: + self.writer = pq.ParquetWriter(str(self.path), self.schema) + self.writer.write_table(table) + self.rows.clear() + + def close(self) -> None: + self.flush() + if self.writer is None: + pq.write_table(pa.Table.from_pylist([], schema=self.schema), str(self.path)) + return + self.writer.close() -def _crc32_shard_index(key: str, num_shards: int) -> int: - """Stable shard index in ``[0, num_shards)`` (portable ``crc32``).""" +def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]: + """Map ``Patient.deceasedBoolean`` to stored ``\"true\"`` / ``\"false\"`` / ``None``. 
-    u = zlib.crc32(key.encode("utf-8")) & 0xFFFFFFFF
-    return int(u % max(1, num_shards))
+    FHIR JSON uses real booleans; some exports incorrectly use strings. Python's
+    ``bool("false")`` is ``True``, so we must not coerce unknown values with
+    ``bool()`` or non-living patients can be written as ``deceased_boolean="true"``.
+    """
+    if value is None:
+        return None
+    if value is True:
+        return "true"
+    if value is False:
+        return "false"
+    if isinstance(value, str):
+        key = value.strip().lower()
+        if key in ("true", "1", "yes", "y", "t"):
+            return "true"
+        if key in ("false", "0", "no", "n", "f", ""):
+            return "false"
+        return None
+    if isinstance(value, (int, float)) and not isinstance(value, bool):
+        if value == 0:
+            return "false"
+        if value == 1:
+            return "true"
+        return None
+    return None
 
 
-def _normalize_fhir_chunk_work(args: Any) -> Dict[str, Any]:
-    """Pickle-friendly fields for :func:`_process_fhir_file_chunk`."""
+def _flatten_resource_to_table_row(
+    resource: Dict[str, Any],
+) -> Optional[Tuple[str, Dict[str, Optional[str]]]]:
+    resource_type = resource.get("resourceType")
+    patient_id = patient_id_for_resource(resource, resource_type)
+    if not patient_id:
+        return None
 
-    if isinstance(args, dict):
-        w = {**args}
-        w["file_path"] = Path(w["file_path"])
-        w["out_dir"] = Path(w["out_dir"])
-    elif len(args) == 5:
-        file_idx, file_path, out_dir, num_shards, batch_size = args  # type: ignore[misc]
-        w = {
-            "file_idx": file_idx,
-            "file_path": Path(file_path),
-            "out_dir": Path(out_dir),
-            "num_shards": num_shards,
-            "batch_size": batch_size,
+    if resource_type == "Patient":
+        return "patient", {
+            "patient_id": patient_id,
+            "patient_fhir_id": str(resource.get("id") or patient_id),
+            "birth_date": resource.get("birthDate"),
+            "gender": resource.get("gender"),
+            "deceased_boolean": _normalize_deceased_boolean_for_storage(
+                resource.get("deceasedBoolean")
+            ),
+            "deceased_datetime": resource.get("deceasedDateTime"),
         }
-    else:
-        raise TypeError(
-            "expected a
dict or 5-tuple " - "(file_idx, file_path, out_dir, num_shards, batch_size) for " - f"_process_fhir_file_chunk, got {type(args)!r}" - ) - w["ingest_use_raw_ndjson_line"] = _FHIR_NDJSON_USE_RAW_LINE - w["ingest_use_column_buffers"] = _FHIR_NDJSON_USE_COLUMN_BUFFERS - return w + resource_id = str(resource.get("id")) if resource.get("id") is not None else None + event_time = _resource_time_string(resource, resource_type) + + if resource_type == "Encounter": + return "encounter", { + "patient_id": patient_id, + "resource_id": resource_id, + "encounter_id": resource_id, + "event_time": event_time, + "encounter_class": (resource.get("class") or {}).get("code"), + "encounter_end": (resource.get("period") or {}).get("end"), + } -def _process_fhir_file_chunk( - args: Any, -) -> int: - """Read one NDJSON/NDJSON.GZ file and write hash-sharded Parquet rows. + linked_encounter_id = _ref_id((resource.get("encounter") or {}).get("reference")) + concept_key = _clinical_concept_key(resource) + row = { + "patient_id": patient_id, + "resource_id": resource_id, + "encounter_id": linked_encounter_id, + "event_time": event_time, + "concept_key": concept_key, + } + if resource_type == "Condition": + return "condition", row + if resource_type == "Observation": + return "observation", row + if resource_type == "MedicationRequest": + return "medication_request", row + if resource_type == "Procedure": + return "procedure", row + return None + + +GlobPatternArg = str | Sequence[str] +"""Type alias for glob pattern argument: single string or sequence of strings.""" - Output files are ``shard-{file_idx:04d}-{shard_idx:04d}.parquet`` (``file_idx`` is - the file's index in the sorted ingest list) so tasks never share a path. Multiple - flushes for the same shard use one :class:`~pyarrow.parquet.ParquetWriter`. - A single compressed ``.ndjson.gz`` is still consumed sequentially in one process - (standard gzip stream); scheduling **one process per input file** lets large - files (e.g. 
Chartevents vs Labevents) run on different CPUs instead of being - stuck together in a static path batch. +def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]: + """Return sorted unique file paths under ``root`` matching glob pattern(s). Args: - args: A ``dict`` with ``file_idx``, ``file_path``, ``out_dir``, - ``num_shards``, and ``batch_size``, or a 5-tuple of those values in that - order. Ingest behavior matches :data:`FHIR_SCHEMA_VERSION`. + root (Path): Root directory to search under. + glob_pattern (GlobPatternArg): Single glob string (e.g., ``"*.ndjson.gz"``) + or sequence of glob strings. Patterns are applied to ``root.glob()``; + results are deduplicated and sorted lexicographically by string path. Returns: - Row count (FHIR resources with a resolvable ``patient_id``) for this file. + List[Path]: Sorted list of matching files. Empty if no matches. + + Example: + >>> from pathlib import Path + >>> root = Path("/data/fhir") + >>> # Single pattern: + >>> files = sorted_ndjson_files(root, "**/*.ndjson.gz") + >>> # Multiple patterns (deduplicated): + >>> files = sorted_ndjson_files(root, [ + ... "**/MimicPatient*.ndjson.gz", + ... "**/MimicEncounter*.ndjson.gz", + ... ]) """ - wk = _normalize_fhir_chunk_work(args) - file_idx = int(wk["file_idx"]) - fp = wk["file_path"] - out_dir = wk["out_dir"] - num_shards = max(1, int(wk["num_shards"])) - batch_size = int(wk["batch_size"]) - ingest_use_raw_ndjson_line = bool(wk["ingest_use_raw_ndjson_line"]) - ingest_use_column_buffers = bool(wk["ingest_use_column_buffers"]) + patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern) + files: set[Path] = set() + for pat in patterns: + files.update(p for p in root.glob(pat) if p.is_file()) + return sorted(files, key=lambda p: str(p)) + + +def stream_fhir_ndjson_to_flat_tables( + root: Path, + glob_pattern: GlobPatternArg, + out_dir: Path, +) -> None: + """Stream NDJSON resources into normalized per-resource Parquet tables. 
+ + Reads all NDJSON/NDJSON.GZ files matching ``glob_pattern`` under ``root``, + parses each line as FHIR JSON, normalizes each resource via + ``_flatten_resource_to_table_row``, and writes rows to per-resource-type + Parquet tables under ``out_dir``. Resources are skipped if their type + is not in ``FHIR_TABLES`` (e.g., Medication, Specimen, Organization). + + Args: + root (Path): Root directory containing NDJSON/NDJSON.GZ files. + glob_pattern (GlobPatternArg): Single glob string or sequence of glob strings + to match NDJSON files. E.g., ``"**/*.ndjson.gz"`` or + ``["**/MimicPatient*.ndjson.gz", "**/MimicEncounter*.ndjson.gz"]``. + out_dir (Path): Output directory for per-resource-type Parquet tables. + Created if absent. Writes: + - ``patient.parquet`` + - ``encounter.parquet`` + - ``condition.parquet`` + - ``observation.parquet`` + - ``medication_request.parquet`` + - ``procedure.parquet`` + + Raises: + IOError: If files cannot be read or written. + + Notes: + - All matching files are decompressed and fully parsed (no early exit + for unsupported resource types). + - Rows are buffered in memory (batch size 50k) before writing. + - Empty output tables are still created. + - Writers are always closed in a ``finally`` block (including on errors). 
+ """ - schema = fhir_events_arrow_schema() out_dir.mkdir(parents=True, exist_ok=True) - writers: List[Optional[pq.ParquetWriter]] = [None] * num_shards - n_rows = 0 - - if ingest_use_column_buffers: - buf_pid: List[List[str]] = [[] for _ in range(num_shards)] - buf_et: List[List[str]] = [[] for _ in range(num_shards)] - buf_ts: List[List[Optional[datetime]]] = [[] for _ in range(num_shards)] - buf_rj: List[List[str]] = [[] for _ in range(num_shards)] - else: - batches: List[List[Dict[str, Any]]] = [[] for _ in range(num_shards)] + writers = { + table_name: _BufferedParquetWriter( + path=out_dir / FHIR_TABLE_FILE_NAMES[table_name], + schema=_table_schema(table_name), + ) + for table_name in FHIR_TABLES + } - def flush(shard: int) -> None: - nonlocal n_rows - if ingest_use_column_buffers: - if not buf_pid[shard]: - return - count = len(buf_pid[shard]) - table = pa.table( - { - "patient_id": buf_pid[shard], - "event_type": buf_et[shard], - "timestamp": buf_ts[shard], - FHIR_RESOURCE_JSON_COL: buf_rj[shard], - }, - schema=schema, - ) - buf_pid[shard].clear() - buf_et[shard].clear() - buf_ts[shard].clear() - buf_rj[shard].clear() - else: - if not batches[shard]: - return - table = pa.Table.from_pylist(batches[shard], schema=schema) - count = len(batches[shard]) - batches[shard].clear() - if writers[shard] is None: - writers[shard] = pq.ParquetWriter( - str(out_dir / f"shard-{file_idx:04d}-{shard:04d}.parquet"), - schema, - ) - writers[shard].write_table(table) - n_rows += count + try: + files = sorted_ndjson_files(root, glob_pattern) + if not files: + return - def append_row( - shard: int, pid: str, ts: Optional[datetime], resource_json: str - ) -> None: - if ingest_use_column_buffers: - buf_pid[shard].append(pid) - buf_et[shard].append(FHIR_EVENT_TYPE) - buf_ts[shard].append(ts) - buf_rj[shard].append(resource_json) - else: - batches[shard].append( - { - "patient_id": pid, - "event_type": FHIR_EVENT_TYPE, - "timestamp": ts, - FHIR_RESOURCE_JSON_COL: resource_json, - } 
- ) + for file_path in files: + for ndjson_obj in iter_ndjson_objects(file_path): + for resource in iter_resources_from_ndjson_obj(ndjson_obj): + flattened = _flatten_resource_to_table_row(resource) + if flattened is None: + continue + table_name, row = flattened + writers[table_name].add(row) + finally: + for writer in writers.values(): + writer.close() - if fp.is_file(): - for raw_line, obj in iter_ndjson_raw_line_and_obj(fp): - use_raw_for_line = ingest_use_raw_ndjson_line and ( - not _ndjson_root_has_top_level_bundle_or_resource_wrapper(obj) - ) - for res in iter_resources_from_ndjson_obj(obj): - rt = res.get("resourceType") - pid = patient_id_for_resource(res, rt) - if not pid: - continue - ts = resource_row_timestamp(res, rt) - resource_json = ( - raw_line if use_raw_for_line else _fhir_json_dumps_resource(res) - ) - shard = _crc32_shard_index(pid, num_shards) - append_row(shard, pid, ts, resource_json) - cur_len = ( - len(buf_pid[shard]) - if ingest_use_column_buffers - else len(batches[shard]) - ) - if cur_len >= batch_size: - flush(shard) - for s in range(num_shards): - flush(s) - for s in range(num_shards): - if writers[s] is not None: - writers[s].close() +def _sorted_patient_ids_from_flat_tables(table_dir: Path) -> List[str]: + patient_table = table_dir / FHIR_TABLE_FILE_NAMES["patient"] + if patient_table.exists(): + return ( + pl.scan_parquet(str(patient_table)) + .select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) - return n_rows + frames = [ + pl.scan_parquet(str(table_dir / FHIR_TABLE_FILE_NAMES[table_name])).select("patient_id") + for table_name in FHIR_TABLES_FOR_PATIENT_IDS + ] + return ( + pl.concat(frames) + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) -def stream_fhir_ndjson_root_to_sharded_parquet( - root: Path, - glob_pattern: str, +def filter_flat_tables_by_patient_ids( + source_dir: Path, out_dir: Path, - *, - num_shards: int 
= 16, - batch_size: int = 50_000, -) -> int: - """Stream matching NDJSON / NDJSON.GZ files into hash-sharded Parquet files. + keep_ids: Sequence[str], +) -> None: + out_dir.mkdir(parents=True, exist_ok=True) + keep_ids_set = set(keep_ids) + for table_name in FHIR_TABLES: + src = source_dir / FHIR_TABLE_FILE_NAMES[table_name] + dst = out_dir / FHIR_TABLE_FILE_NAMES[table_name] + pl.scan_parquet(str(src)).filter(pl.col("patient_id").is_in(keep_ids_set)).sink_parquet( + str(dst) + ) - Files under ``root`` matching ``glob_pattern`` are read in parallel: **one - process pool task per file**, up to ``min(os.cpu_count(), N)`` workers, so the - pool load-balances across uneven file sizes (MIMIC-IV FHIR has a few very - large ``*.ndjson.gz`` shards). Each task writes - ``shard-{file_index}-{hash_bucket}.parquet``; the downstream cache globs - ``shard-*.parquet`` and scans them with Polars. - All rows for a given ``patient_id`` share one hash bucket (same ``num_shards``); - output paths are disjoint per input file. Shards with no rows for a file - produce no file. If no input files match, or all rows lack a ``patient_id``, - writes a single empty ``shard-0000.parquet``. +def _clean_string(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + value = value.strip() + return value or None + return str(value) - Returns: - Number of rows written (FHIR resources with a resolvable ``patient_id``). 
- """ - schema = fhir_events_arrow_schema() - out_dir = Path(out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - num_shards = max(1, int(num_shards)) +def _deceased_boolean_column_means_dead(value: Any) -> bool: + """True only for an explicit affirmative stored flag (not Python truthiness).""" + s = _clean_string(value) + return s is not None and s.lower() == "true" - all_files = sorted( - p for p in root.glob(glob_pattern) if p.is_file() - ) - if not all_files: - pq.write_table( - pa.Table.from_pylist([], schema=schema), - str(out_dir / "shard-0000.parquet"), + +def _row_datetime(value: Any) -> Optional[datetime]: + if value is None: + return None + if isinstance(value, datetime): + return _as_naive(value) + try: + return _parse_dt(str(value)) + except Exception: + return None + + +def _concept_key_from_row(row: Dict[str, Any]) -> str: + event_type = row.get("event_type") + if event_type == "condition": + return _clean_string(row.get("condition/concept_key")) or "condition|unknown" + if event_type == "observation": + return _clean_string(row.get("observation/concept_key")) or "observation|unknown" + if event_type == "medication_request": + return ( + _clean_string(row.get("medication_request/concept_key")) + or "medication_request|unknown" ) + if event_type == "procedure": + return _clean_string(row.get("procedure/concept_key")) or "procedure|unknown" + if event_type == "encounter": + encounter_class = _clean_string(row.get("encounter/encounter_class")) + return f"encounter|{encounter_class}" if encounter_class else "encounter|unknown" + return f"{event_type or 'event'}|unknown" + + +def _linked_encounter_id_from_row(row: Dict[str, Any]) -> Optional[str]: + event_type = row.get("event_type") + if event_type == "condition": + return _clean_string(row.get("condition/encounter_id")) + if event_type == "observation": + return _clean_string(row.get("observation/encounter_id")) + if event_type == "medication_request": + return 
_clean_string(row.get("medication_request/encounter_id"))
+    if event_type == "procedure":
+        return _clean_string(row.get("procedure/encounter_id"))
+    if event_type == "encounter":
+        return _clean_string(row.get("encounter/encounter_id"))
+    return None
+
+
+def _birth_datetime_from_patient(patient: Patient) -> Optional[datetime]:
+    for row in patient.data_source.iter_rows(named=True):
+        if row.get("event_type") != "patient":
+            continue
+        birth = _row_datetime(row.get("timestamp"))
+        if birth is not None:
+            return birth
+        birth_raw = _clean_string(row.get("patient/birth_date"))
+        if birth_raw:
+            return _parse_dt(birth_raw)
+    return None
+
+
+def _sequential_visit_idx_for_time(
+    event_time: Optional[datetime],
+    visit_encounters: List[Tuple[datetime, int]],
+) -> int:
+    if not visit_encounters:
         return 0
+    if event_time is None:
+        return visit_encounters[-1][1]
+    event_time = _as_naive(event_time)
+    chosen = visit_encounters[0][1]
+    for encounter_start, visit_idx in visit_encounters:
+        if encounter_start <= event_time:
+            chosen = visit_idx
+        else:
+            break
+    return chosen
 
-    cpu = os.cpu_count() or 1
-    max_workers = min(cpu, len(all_files))
-    work_args = [
-        {
-            "file_idx": i,
-            "file_path": all_files[i],
-            "out_dir": out_dir,
-            "num_shards": num_shards,
-            "batch_size": batch_size,
-        }
-        for i in range(len(all_files))
+
+def collect_cehr_timeline_events(
+    patient: Patient,
+) -> List[Tuple[datetime, str, str, int]]:
+    """Collect CEHR timeline events directly from flattened patient rows."""
+
+    rows = list(
+        patient.data_source.sort(["timestamp", "event_type"], nulls_last=True).iter_rows(
+            named=True
+        )
+    )
+
+    encounter_rows: List[Tuple[datetime, str]] = []
+    for row in rows:
+        if row.get("event_type") != "encounter":
+            continue
+        encounter_id = _linked_encounter_id_from_row(row)
+        encounter_start = _row_datetime(row.get("timestamp"))
+        if encounter_id is None or encounter_start is None:
+            continue
+        encounter_rows.append((encounter_start, encounter_id))
+
+    encounter_rows.sort(key=lambda pair: pair[0])
+    encounter_visit_idx = {
+        encounter_id: visit_idx
+        for visit_idx, (_, encounter_id) in enumerate(encounter_rows)
+    }
+    encounter_start_by_id = {
+        encounter_id: encounter_start for encounter_start, encounter_id in encounter_rows
+    }
+    visit_encounters = [
+        (encounter_start, visit_idx)
+        for visit_idx, (encounter_start, _) in enumerate(encounter_rows)
+    ]
+
+    events: List[Tuple[datetime, str, str, int]] = []
+    unlinked: List[Tuple[Optional[datetime], str, str]] = []
+
+    for row in rows:
+        event_type = row.get("event_type")
+        if event_type not in EVENT_TYPE_TO_TOKEN_TYPE:
+            continue
+
+        event_time = _row_datetime(row.get("timestamp"))
+        concept_key = _concept_key_from_row(row)
+
+        if event_type == "encounter":
+            encounter_id = _linked_encounter_id_from_row(row)
+            if encounter_id is None or event_time is None:
+                continue
+            visit_idx = encounter_visit_idx.get(encounter_id)
+            if visit_idx is None:
+                continue
+            events.append((event_time, concept_key, event_type, visit_idx))
+            continue
+
+        encounter_id = _linked_encounter_id_from_row(row)
+        if encounter_id and encounter_id in encounter_visit_idx:
+            visit_idx = encounter_visit_idx[encounter_id]
+            if event_time is None:
+                event_time = encounter_start_by_id.get(encounter_id)
+            if event_time is None:
+                continue
+            events.append((event_time, concept_key, event_type, visit_idx))
+        else:
unlinked.append((event_time, concept_key, event_type)) + + for event_time, concept_key, event_type in unlinked: + visit_idx = _sequential_visit_idx_for_time(event_time, visit_encounters) + if event_time is None: + if not visit_encounters: + continue + for encounter_start, encounter_visit_idx_value in visit_encounters: + if encounter_visit_idx_value == visit_idx: + event_time = encounter_start + break + else: + event_time = visit_encounters[-1][0] + if event_time is None: + continue + events.append((event_time, concept_key, event_type, visit_idx)) + + events.sort(key=lambda item: item[0]) + return events + + +def warm_mpf_vocab_from_patient( + vocab: ConceptVocab, + patient: Patient, + clinical_cap: int, +) -> None: + tail = collect_cehr_timeline_events(patient)[-clinical_cap:] if clinical_cap > 0 else [] + for _, concept_key, _, _ in tail: + vocab.add_token(concept_key) + + +def build_cehr_sequences( + patient: Patient, + vocab: ConceptVocab, + max_len: int, + *, + base_time: Optional[datetime] = None, + grow_vocab: bool = True, +) -> Tuple[List[int], List[int], List[float], List[float], List[int], List[int]]: + """Flatten a patient's tabular FHIR rows into CEHR-aligned feature lists.""" - if n_rows == 0: - pq.write_table( - pa.Table.from_pylist([], schema=schema), - str(out_dir / "shard-0000.parquet"), + events = collect_cehr_timeline_events(patient) + birth = _birth_datetime_from_patient(patient) + + if base_time is None and events: + base_time = events[0][0] + elif base_time is None: + base_time = datetime.now() + + concept_ids: List[int] = [] + token_types: List[int] = [] + time_stamps: List[float] = [] + ages: List[float] = [] + visit_orders: List[int] = [] + visit_segments: List[int] = [] + + base_time = _as_naive(base_time) + birth = _as_naive(birth) + tail = events[-max_len:] if max_len > 0 else [] + + for event_time, concept_key, event_type, visit_idx in tail: + event_time = _as_naive(event_time) + concept_id = vocab.add_token(concept_key) if grow_vocab 
else vocab[concept_key] + token_type = EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0) + time_delta = ( + float((event_time - base_time).total_seconds()) + if base_time is not None and event_time is not None + else 0.0 ) - return n_rows + age_years = 0.0 + if birth is not None and event_time is not None: + age_years = (event_time - birth).days / 365.25 + concept_ids.append(concept_id) + token_types.append(token_type) + time_stamps.append(time_delta) + ages.append(age_years) + visit_orders.append(min(visit_idx, 511)) + visit_segments.append(visit_idx % 2) -class MIMIC4FHIRDataset(BaseDataset): - """MIMIC-IV on FHIR (NDJSON / NDJSON.GZ / Bundle) with PyHealth's tabular cache. + return concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments - Streams resources to ``global_event_df`` Parquet - (``patient_id``, ``timestamp``, ``event_type``, ``fhir/resource_json``), then - uses :class:`~pyhealth.data.Patient` and standard :meth:`set_task` like other - datasets. MPF uses :class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask`. - YAML defaults (e.g. ``glob_pattern``) live in - ``pyhealth/datasets/configs/mimic4_fhir.yaml``. NDJSON→Parquet ingest is - implemented in this module and is not tunable via the constructor. +def infer_mortality_label(patient: Patient) -> int: + """Heuristic binary label from flattened patient rows.""" - Args: - root: Root directory scanned for NDJSON/NDJSON.GZ (PhysioNet: the ``fhir/`` - folder with ``Mimic*.ndjson.gz`` files). - config_path: Optional path to the FHIR YAML settings file. - glob_pattern: If set, overrides the YAML ``glob_pattern`` (default - ``**/*.ndjson.gz`` for credentialled exports). - max_patients: After streaming, keep only the first N patient ids (sorted). - ingest_num_shards: Number of hash shards for the NDJSON→Parquet pass; - defaults from YAML ``ingest_num_shards`` or CPU-based heuristics. - vocab_path: Optional path to a saved :class:`ConceptVocab` JSON. 
- cache_dir: Forwarded to :class:`~pyhealth.datasets.BaseDataset`. - num_workers: Forwarded to :class:`~pyhealth.datasets.BaseDataset`. - dev: If True and ``max_patients`` is None, caps at 1000 patients. + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") == "patient": + if _deceased_boolean_column_means_dead(row.get("patient/deceased_boolean")): + return 1 + if _clean_string(row.get("patient/deceased_datetime")): + return 1 - Example: - >>> from pyhealth.datasets import MIMIC4FHIRDataset - >>> from pyhealth.tasks.mpf_clinical_prediction import ( - ... MPFClinicalPredictionTask, - ... ) - >>> ds = MIMIC4FHIRDataset(root="/path/to/fhir", max_patients=50) - >>> task = MPFClinicalPredictionTask(max_len=256) - >>> sample_ds = ds.set_task(task) # doctest: +SKIP + for row in patient.data_source.iter_rows(named=True): + if row.get("event_type") != "condition": + continue + concept_key = (_clean_string(row.get("condition/concept_key")) or "").lower() + if any(token in concept_key for token in ("death", "deceased", "mortality")): + return 1 + return 0 + + +class MIMIC4FHIRDataset(BaseDataset): + """MIMIC-IV on FHIR with flattened resource tables and standard task flow. + + This dataset normalizes raw MIMIC-IV FHIR NDJSON/NDJSON.GZ exports into + six flattened Parquet tables (Patient, Encounter, Condition, Observation, + MedicationRequest, Procedure), then pipelines them through + :class:`~pyhealth.datasets.BaseDataset` for standard downstream task + processing (global event dataframe, patient iteration, task sampling). + + **Ingest flow (out-of-core):** + 1. Scan NDJSON files matching ``glob_patterns`` (defaults to six Mimic* families). + 2. Parse and flatten each FHIR resource into a row in the appropriate table. + 3. Cache normalized tables as Parquet under ``cache_dir / {uuid} / flattened_tables/``. + 4. Load and compose tables into ``global_event_df`` via YAML config. 
+ + **Data model:** + - Resource types outside ``FHIR_TABLES`` (Medication, Specimen, …) are skipped. + - Timestamps are coerced from heterogeneous FHIR ISO 8601 strings (with/without + timezone, or date-only). Coercion keeps downstream Polars/Dask pipelines robust. + - Concept keys are derived from the first FHIR coding or synthesized from references. + + **Cache fingerprinting:** + Cache invalidation includes ``glob_patterns`` and YAML digest, so changes to either + create a new independent cache. """ def __init__( @@ -1011,6 +966,7 @@ def __init__( root: str, config_path: Optional[str] = None, glob_pattern: Optional[str] = None, + glob_patterns: Optional[Sequence[str]] = None, max_patients: Optional[int] = None, ingest_num_shards: Optional[int] = None, vocab_path: Optional[str] = None, @@ -1018,37 +974,105 @@ def __init__( num_workers: int = 1, dev: bool = False, ) -> None: + """Initialize a MIMIC-IV FHIR dataset. + + Args: + root (str): Path to the NDJSON/NDJSON.GZ export directory. + config_path (Optional[str]): Path to a custom YAML config file. Defaults to + ``pyhealth/datasets/configs/mimic4_fhir.yaml``. + glob_pattern (Optional[str]): Single glob pattern for NDJSON files + (e.g., ``"*.ndjson.gz"``). Mutually exclusive with ``glob_patterns``. + Overrides YAML setting. + glob_patterns (Optional[Sequence[str]]): Multiple glob patterns as a list. + Patterns are deduplicated and sorted. Mutually exclusive with ``glob_pattern``. + Overrides YAML setting. + max_patients (Optional[int]): If set, ingest is limited to the first *N* + unique patient IDs (sorted). Ingest still parses all matching NDJSON + unless you narrow ``glob_patterns`` / ``glob_pattern``. For faster + prototyping on a laptop, combine with narrow globs. + ingest_num_shards (Optional[int]): Ignored; retained for API compatibility. + vocab_path (Optional[str]): Path to a pre-built ConceptVocab JSON file. + If provided and exists, it is loaded; otherwise a new vocab is created. 
+ cache_dir (Optional[str | Path]): Cache directory root. Behavior: + + - **None** (default): Auto-generated under ``platformdirs.user_cache_dir()``. + - **str** or **Path**: Used as root; a UUID is appended per configuration. + + num_workers (int): Number of worker processes for task sampling. Defaults to 1. + dev (bool): Development mode: limits to 1000 patients if ``max_patients`` is None. + + Raises: + ValueError: If both ``glob_pattern`` and ``glob_patterns`` are provided. + TypeError: If ``glob_patterns`` in YAML is not a list. + FileNotFoundError: If ``root`` or ``config_path`` does not exist. + + Notes: + - **Glob resolution order:** ``glob_patterns`` kwarg → ``glob_pattern`` kwarg + → YAML ``glob_patterns`` → YAML ``glob_pattern`` → ``"**/*.ndjson.gz"`` (fallback). + - **Default YAML globs** match only the six MIMIC shard families that map to + flattened tables, skipping ~10% of PhysioNet exports (Medication, Specimen, …). + - **Cache fingerprinting** includes ``glob_patterns`` and config YAML digest, + so changes invalidate the cache. + + Example: + >>> from pyhealth.datasets import MIMIC4FHIRDataset + >>> # Using default YAML globs (PhysioNet-compatible): + >>> ds = MIMIC4FHIRDataset(root="/data/mimic4_fhir") + >>> print(ds.glob_patterns) # Shows: [MimicPatient*, ..., MimicProcedure*] + >>> # Using a custom glob for non-standard NDJSON naming: + >>> ds = MIMIC4FHIRDataset( + ... root="/data/ndjson", + ... glob_pattern="*.ndjson", + ... max_patients=100, + ... ) + >>> # Using a narrowed set of patterns for faster testing: + >>> ds = MIMIC4FHIRDataset( + ... root="/data/mimic4_fhir", + ... glob_patterns=["**/MimicPatient*.ndjson.gz", "**/MimicObservation*.ndjson.gz"], + ... 
) + """ + del ingest_num_shards + default_cfg = os.path.join( os.path.dirname(__file__), "configs", "mimic4_fhir.yaml" ) self._fhir_config_path = str(Path(config_path or default_cfg).resolve()) self._fhir_settings = read_fhir_settings_yaml(self._fhir_config_path) - self.glob_pattern = ( - glob_pattern - if glob_pattern is not None - else str(self._fhir_settings.get("glob_pattern", "**/*.ndjson.gz")) - ) - mp = max_patients - if dev and mp is None: - mp = 1000 - self.max_patients = mp - if ingest_num_shards is not None: - self.ingest_num_shards = max(1, int(ingest_num_shards)) + if glob_pattern is not None and glob_patterns is not None: + raise ValueError("Pass at most one of glob_pattern and glob_patterns.") + if glob_patterns is not None: + self.glob_patterns = list(glob_patterns) + elif glob_pattern is not None: + self.glob_patterns = [glob_pattern] else: - raw_shards = self._fhir_settings.get("ingest_num_shards") - if raw_shards is not None: - self.ingest_num_shards = max(1, int(raw_shards)) + raw_list = self._fhir_settings.get("glob_patterns") + if raw_list: + if not isinstance(raw_list, list): + raise TypeError( + "mimic4_fhir.yaml glob_patterns must be a list of strings." 
+ ) + self.glob_patterns = [str(x) for x in raw_list] + elif self._fhir_settings.get("glob_pattern") is not None: + self.glob_patterns = [str(self._fhir_settings["glob_pattern"])] else: - self.ingest_num_shards = max(4, min(32, (os.cpu_count() or 4) * 2)) - if vocab_path and os.path.isfile(vocab_path): - self.vocab = ConceptVocab.load(vocab_path) - else: - self.vocab = ConceptVocab() + self.glob_patterns = ["**/*.ndjson.gz"] + self.glob_pattern = ( + self.glob_patterns[0] + if len(self.glob_patterns) == 1 + else "; ".join(self.glob_patterns) + ) + self.max_patients = 1000 if dev and max_patients is None else max_patients + self.source_root = str(Path(root).expanduser().resolve()) + self.vocab = ( + ConceptVocab.load(vocab_path) + if vocab_path and os.path.isfile(vocab_path) + else ConceptVocab() + ) super().__init__( - root=root, - tables=["fhir_ndjson"], + root=self.source_root, + tables=FHIR_TABLES, dataset_name="mimic4_fhir", - config_path=None, + config_path=self._fhir_config_path, cache_dir=cache_dir, num_workers=num_workers, dev=dev, @@ -1056,92 +1080,200 @@ def __init__( def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: try: - y_digest = hashlib.sha256( + yaml_digest = hashlib.sha256( Path(self._fhir_config_path).read_bytes() ).hexdigest()[:16] except OSError: - y_digest = "missing" - id_str = orjson.dumps( + yaml_digest = "missing" + identity = orjson.dumps( { - "root": str(self.root), + "source_root": self.source_root, "tables": sorted(self.tables), "dataset_name": self.dataset_name, "dev": self.dev, - "glob_pattern": self.glob_pattern, + "glob_patterns": self.glob_patterns, "max_patients": self.max_patients, - "ingest_num_shards": self.ingest_num_shards, "fhir_schema_version": FHIR_SCHEMA_VERSION, - "fhir_yaml_digest16": y_digest, + "fhir_yaml_digest16": yaml_digest, }, option=orjson.OPT_SORT_KEYS, ).decode("utf-8") - cid = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity)) if cache_dir is 
None: - out = Path(platformdirs.user_cache_dir(appname="pyhealth")) / cid + out = Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_id out.mkdir(parents=True, exist_ok=True) logger.info("No cache_dir provided. Using default cache dir: %s", out) else: - out = Path(cache_dir) / cid + out = Path(cache_dir) / cache_id out.mkdir(parents=True, exist_ok=True) logger.info("Using provided cache_dir: %s", out) return out - def _event_transform(self, output_dir: Path) -> None: - root = Path(self.root).expanduser().resolve() + @property + def prepared_tables_dir(self) -> Path: + return self.cache_dir / "flattened_tables" + + def _ensure_prepared_tables(self) -> None: + root = Path(self.source_root) if not root.is_dir(): raise FileNotFoundError(f"MIMIC4 FHIR root not found: {root}") + + expected_files = [ + self.prepared_tables_dir / FHIR_TABLE_FILE_NAMES[table_name] + for table_name in FHIR_TABLES + ] + if all(path.is_file() for path in expected_files): + return + + if self.prepared_tables_dir.exists(): + shutil.rmtree(self.prepared_tables_dir) + try: - staging = self.create_tmpdir() / "fhir_event_shards" + if self.max_patients is None: + staging_root = self.create_tmpdir() + staging = staging_root / "flattened_fhir_tables" + staging.mkdir(parents=True, exist_ok=True) + stream_fhir_ndjson_to_flat_tables(root, self.glob_patterns, staging) + shutil.move(str(staging), str(self.prepared_tables_dir)) + return + + staging_root = self.create_tmpdir() + staging = staging_root / "flattened_fhir_tables" staging.mkdir(parents=True, exist_ok=True) - stream_fhir_ndjson_root_to_sharded_parquet( - root, - self.glob_pattern, - staging, - num_shards=self.ingest_num_shards, - batch_size=50_000, + stream_fhir_ndjson_to_flat_tables(root, self.glob_patterns, staging) + + filtered_root = self.create_tmpdir() + filtered = filtered_root / "flattened_fhir_tables_filtered" + patient_ids = _sorted_patient_ids_from_flat_tables(staging) + keep_ids = patient_ids[: self.max_patients] + 
filter_flat_tables_by_patient_ids(staging, filtered, keep_ids)
+            shutil.move(str(filtered), str(self.prepared_tables_dir))
+        finally:
+            self.clean_tmpdir()
+
+    def _event_transform(self, output_dir: Path) -> None:
+        self._ensure_prepared_tables()
+        super()._event_transform(output_dir)
+
+    def load_table(self, table_name: str) -> dd.DataFrame:
+        """Load one flattened Parquet table, mirroring BaseDataset.load_table's contract.
+
+        Everything else (column lowercasing, preprocess hook, join handling,
+        attribute renaming) matches BaseDataset.load_table (base_dataset.py)
+        exactly; the only intentional, FHIR-specific deviations are:
+
+        1. ``dd.read_parquet()`` on a pre-built Parquet file under
+           ``prepared_tables_dir`` instead of ``_scan_csv_tsv_gz()`` on CSV.
+        2. ``errors="coerce"`` + ``utc=True`` in ``dd.to_datetime``: FHIR ISO
+           strings carry a timezone ``Z`` suffix or are partial dates, so
+           ``errors="raise"`` would break.
+        3. ``map_partitions(tz_localize(None))`` strips tz-aware timestamps to
+           naive UTC (Dask's ``to_parquet`` / ``sort_values`` path cannot
+           handle tz-aware datetimes).
+        4. ``dropna(subset=["patient_id"])`` so the caller's
+           ``sort_values("patient_id")`` in ``_event_transform`` never sees
+           null keys.
+
+        If BaseDataset.load_table changes, audit those 4 points here. 
+ """ + + assert self.config is not None, "Config must be provided to load tables" + if table_name not in self.config.tables: + raise ValueError(f"Table {table_name} not found in config") + + table_cfg = self.config.tables[table_name] + path = self.prepared_tables_dir / table_cfg.file_path + if not path.exists(): + raise FileNotFoundError(f"Flattened table not found: {path}") + + logger.info("Scanning FHIR flattened table: %s from %s", table_name, path) + df: dd.DataFrame = dd.read_parquet( + str(path), + split_row_groups=True, # type: ignore[arg-type] + blocksize="64MB", + ).replace("", pd.NA) + + # Mirror BaseDataset.load_table: lowercase columns before preprocess hook. + df = df.rename(columns=str.lower) + + # Mirror BaseDataset.load_table: optional preprocess_{table_name} hook. + preprocess_func = getattr(self, f"preprocess_{table_name}", None) + if preprocess_func is not None: + logger.info( + "Preprocessing FHIR table: %s with %s", table_name, preprocess_func.__name__ ) - staged_files = sorted(staging.glob("shard-*.parquet")) - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - keep: Optional[Set[str]] = None - if self.max_patients is not None: - lf_all = pl.concat( - [pl.scan_parquet(str(p)) for p in staged_files] - ) - pids = ( - lf_all.select("patient_id") - .unique() - .sort("patient_id") - .collect(engine="streaming")["patient_id"] - .to_list() + df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr] + + # Mirror BaseDataset.load_table: handle joins (resolved against prepared_tables_dir). 
+ for join_cfg in table_cfg.join: + other_path = self.prepared_tables_dir / Path(join_cfg.file_path).name + if not other_path.exists(): + raise FileNotFoundError(f"FHIR join table not found: {other_path}") + logger.info("Joining FHIR table %s with %s", table_name, other_path) + join_df: dd.DataFrame = dd.read_parquet( + str(other_path), + split_row_groups=True, # type: ignore[arg-type] + blocksize="64MB", + ).replace("", pd.NA) + join_df = join_df.rename(columns=str.lower) + join_key = join_cfg.on.lower() + columns = [c.lower() for c in join_cfg.columns] + df = df.merge(join_df[[join_key] + columns], on=join_key, how=join_cfg.how) + + patient_id_col = table_cfg.patient_id + timestamp_col = table_cfg.timestamp + timestamp_format = table_cfg.timestamp_format + attribute_cols = table_cfg.attributes + + # Timestamp parsing: coerce rather than raise for FHIR heterogeneous strings. + if timestamp_col: + if isinstance(timestamp_col, list): + timestamp_series: dd.Series = functools.reduce( + operator.add, (df[col].astype("string") for col in timestamp_col) ) - keep = set(pids[: self.max_patients]) - - if keep is None: - for i, p in enumerate(staged_files): - shutil.move(str(p), str(output_dir / f"part-{i:05d}.parquet")) else: - for i, p in enumerate(staged_files): - pl.scan_parquet(str(p)).filter( - pl.col("patient_id").is_in(keep) - ).sink_parquet(str(output_dir / f"part-{i:05d}.parquet")) - except Exception as e: - if output_dir.exists(): - logger.error( - "Error during FHIR event caching, removing incomplete dir %s", - output_dir, - ) - shutil.rmtree(output_dir) - raise e - finally: - self.clean_tmpdir() + timestamp_series = df[timestamp_col].astype("string") + + # utc=True avoids mixed-offset parse errors; we strip tz after. 
+ timestamp_series = dd.to_datetime( + timestamp_series, + format=timestamp_format, + errors="coerce", + utc=True, + ) + + def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series: + if getattr(part.dtype, "tz", None) is not None: + part = part.dt.tz_localize(None) + return part.astype("datetime64[ms]") + + timestamp_series = timestamp_series.map_partitions(_strip_tz_to_naive_ms) + df = df.assign(timestamp=timestamp_series) + else: + df = df.assign(timestamp=pd.NaT) + + # Mirror BaseDataset.load_table: patient_id from config column or row index. + if patient_id_col: + df = df.assign(patient_id=df[patient_id_col].astype("string")) + else: + df = df.reset_index(drop=True) + df = df.assign(patient_id=df.index.astype("string")) + + # Drop rows without a patient key; BaseDataset._event_transform's sort_values + # on "patient_id" fails on null keys with Dask's division-calculation logic. + df = df.dropna(subset=["patient_id"]) + + df = df.assign(event_type=table_name) + + rename_attr = {attr.lower(): f"{table_name}/{attr}" for attr in attribute_cols} + df = df.rename(columns=rename_attr) + attr_cols = [rename_attr[attr.lower()] for attr in attribute_cols] + final_cols = ["patient_id", "event_type", "timestamp"] + attr_cols + return df[final_cols] @property def unique_patient_ids(self) -> List[str]: - """Sorted unique patient ids (stable across multi-part Parquet caches).""" - if self._unique_patient_ids is None: self._unique_patient_ids = ( self.global_event_df.select("patient_id") @@ -1153,20 +1285,6 @@ def unique_patient_ids(self) -> List[str]: logger.info("Found %d unique patient IDs", len(self._unique_patient_ids)) return self._unique_patient_ids - def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: - if df is not None: - yield from super().iter_patients(df) - return - base = self.global_event_df - for patient_id in self.unique_patient_ids: - patient_df = base.filter(pl.col("patient_id") == patient_id).collect( - engine="streaming" - ) - 
yield Patient(patient_id=patient_id, data_source=patient_df) - - def stats(self) -> None: - super().stats() - def set_task( self, task: Any = None, @@ -1179,27 +1297,24 @@ def set_task( raise ValueError( "Pass a task instance, e.g. MPFClinicalPredictionTask(max_len=512)." ) + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask if isinstance(task, MPFClinicalPredictionTask): - nw = ( + worker_count = ( 1 if in_notebook() else (num_workers if num_workers is not None else self.num_workers) ) - # Match :meth:`BaseDataset._task_transform`: unique ids after ``pre_filter``, - # not the full cached cohort (e.g. task subclasses that cap patients). warmup_pids = self._mpf_patient_ids_for_task(task) - pid_n = len(warmup_pids) - effective_workers = min(nw, pid_n) if pid_n else 1 + patient_count = len(warmup_pids) + effective_workers = min(worker_count, patient_count) if patient_count else 1 ensure_special_tokens(self.vocab) - # Always warm in the main process: single-worker ``set_task`` can skip - # ``BaseDataset._task_transform`` on LitData cache hit, leaving a fresh - # vocab at specials-only without this pass. 
self._warm_mpf_vocabulary(task, warmup_pids) task.frozen_vocab = effective_workers > 1 task.vocab = self.vocab task._specials = ensure_special_tokens(self.vocab) + return super().set_task( task, num_workers, @@ -1208,11 +1323,9 @@ def set_task( ) def _mpf_patient_ids_for_task(self, task: Any) -> List[str]: - """Sorted unique patient ids that ``task`` will see (same as ``_task_transform``).""" - - lf_filtered = task.pre_filter(self.global_event_df) + filtered = task.pre_filter(self.global_event_df) return ( - lf_filtered.select("patient_id") + filtered.select("patient_id") .unique() .collect(engine="streaming") .to_series() @@ -1221,20 +1334,9 @@ def _mpf_patient_ids_for_task(self, task: Any) -> List[str]: ) def _warm_mpf_vocabulary(self, task: Any, patient_ids: List[str]) -> None: - """Main-process vocab keys only (parallel ``set_task`` workers use frozen vocab). - - Args: - task: MPF task (uses ``max_len`` for clinical token cap). - patient_ids: Cohort to warm (post-``pre_filter``), not necessarily - :attr:`unique_patient_ids`. - """ - clinical_cap = max(0, task.max_len - 2) - # Same batching as :func:`_task_transform_fn` — one collect per batch, not - # one full scan per patient. 
- batch_size = 128 base = self.global_event_df - for batch in itertools.batched(patient_ids, batch_size): + for batch in itertools.batched(patient_ids, 128): patients = ( base.filter(pl.col("patient_id").is_in(batch)) .collect(engine="streaming") @@ -1242,17 +1344,14 @@ def _warm_mpf_vocabulary(self, task: Any, patient_ids: List[str]) -> None: ) for patient_key, patient_df in patients.items(): patient_id = patient_key[0] - py_patient = Patient(patient_id=patient_id, data_source=patient_df) - fp = fhir_patient_from_patient(py_patient) - warm_mpf_vocab_from_fhir_patient(self.vocab, fp, clinical_cap) + patient = Patient(patient_id=patient_id, data_source=patient_df) + warm_mpf_vocab_from_patient(self.vocab, patient, clinical_cap) def gather_samples(self, task: Any) -> List[Dict[str, Any]]: - """Run ``task`` on each :class:`~pyhealth.data.Patient` (tabular path).""" - task.vocab = self.vocab task._specials = None task.frozen_vocab = False samples: List[Dict[str, Any]] = [] - for p in self.iter_patients(): - samples.extend(task(p)) + for patient in self.iter_patients(): + samples.extend(task(patient)) return samples diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py index c1958870d..5a8e7ed5f 100644 --- a/pyhealth/tasks/mpf_clinical_prediction.py +++ b/pyhealth/tasks/mpf_clinical_prediction.py @@ -2,17 +2,15 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import torch from pyhealth.data import Patient from pyhealth.datasets.mimic4_fhir import ( ConceptVocab, - FHIRPatient, build_cehr_sequences, ensure_special_tokens, - fhir_patient_from_patient, infer_mortality_label, ) @@ -46,9 +44,9 @@ def _left_pad_float(seq: List[float], max_len: int, pad: float = 0.0) -> List[fl class MPFClinicalPredictionTask(BaseTask): """Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens. 
-    Works on :class:`~pyhealth.data.Patient` (standard ``global_event_df`` /
-    :meth:`~pyhealth.datasets.MIMIC4FHIRDataset.set_task` path) or legacy
-    :class:`~pyhealth.datasets.mimic4_fhir.FHIRPatient`. For :meth:`set_task`,
+    Works on :class:`~pyhealth.data.Patient` via the standard
+    ``global_event_df`` / :meth:`~pyhealth.datasets.MIMIC4FHIRDataset.set_task`
+    path. For :meth:`set_task`,
     :class:`~pyhealth.datasets.MIMIC4FHIRDataset` reserves specials, warms
     concept keys in the main process over the same patient cohort as
     :meth:`~pyhealth.tasks.base_task.BaseTask.pre_filter` (including when LitData
@@ -90,14 +88,11 @@ def _ensure_vocab(self) -> ConceptVocab:
         self._specials = ensure_special_tokens(self.vocab)
         return self.vocab
 
-    def __call__(
-        self, patient: Union[Patient, FHIRPatient]
-    ) -> List[Dict[str, Any]]:
+    def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
         """Build one labeled sample dict per patient.
 
         Args:
-            patient: Tabular :class:`~pyhealth.data.Patient` or legacy
-                :class:`~pyhealth.datasets.mimic4_fhir.FHIRPatient`.
+            patient: A tabular :class:`~pyhealth.data.Patient`.
 
         Returns:
             A one-element list with ``concept_ids``, tensor-ready feature lists, and
@@ -105,12 +100,7 @@ def __call__(
             ``max_len == 2`` the sequence contains only the reserved special tokens. 
""" vocab = self._ensure_vocab() - if isinstance(patient, Patient): - fhir_patient = fhir_patient_from_patient(patient) - pid = patient.patient_id - else: - fhir_patient = patient - pid = patient.patient_id + pid = patient.patient_id clinical_cap = max(0, self.max_len - 2) ( concept_ids, @@ -120,7 +110,7 @@ def __call__( visit_orders, visit_segments, ) = build_cehr_sequences( - fhir_patient, + patient, vocab, clinical_cap, grow_vocab=not self.frozen_vocab, @@ -146,7 +136,7 @@ def __call__( visit_orders = _left_pad_int(visit_orders, ml, 0) visit_segments = _left_pad_int(visit_segments, ml, 0) - label = infer_mortality_label(fhir_patient) + label = infer_mortality_label(patient) return [ { "patient_id": pid, diff --git a/tests/core/test_mimic4_fhir_dataset.py b/tests/core/test_mimic4_fhir_dataset.py index 9043b69fa..554f5bfe5 100644 --- a/tests/core/test_mimic4_fhir_dataset.py +++ b/tests/core/test_mimic4_fhir_dataset.py @@ -2,18 +2,18 @@ import tempfile import unittest from pathlib import Path -from typing import Any, Dict, List +from typing import Dict, List import orjson import polars as pl +from pyhealth.data import Patient from pyhealth.datasets import MIMIC4FHIRDataset from pyhealth.datasets.mimic4_fhir import ( ConceptVocab, - FHIR_RESOURCE_JSON_COL, - FHIR_EVENT_TYPE, + _flatten_resource_to_table_row, build_cehr_sequences, - fhir_patient_from_patient, + collect_cehr_timeline_events, infer_mortality_label, synthetic_mpf_two_patient_ndjson_text, ) @@ -25,9 +25,7 @@ ) -def _third_patient_loinc_resources() -> List[Dict[str, Any]]: - """Third synthetic patient with a LOINC code not present on p-synth-1/2.""" - +def _third_patient_loinc_resources() -> List[Dict[str, object]]: return [ { "resourceType": "Patient", @@ -53,21 +51,82 @@ def _third_patient_loinc_resources() -> List[Dict[str, Any]]: def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path: - """Two-class PhysioNet-style fixture plus an extra patient (LOINC 
999-9).""" 
     lines = synthetic_mpf_two_patient_ndjson_text().strip().split("\n")
-    lines.extend(
-        orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources()
-    )
+    lines.extend(orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources())
     path = directory / name
     path.write_text("\n".join(lines) + "\n", encoding="utf-8")
     return path
 
 
+def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient:
+    return Patient(patient_id=patient_id, data_source=pl.DataFrame(rows))
+
+
+class TestDeceasedBooleanFlattening(unittest.TestCase):
+    def test_string_false_not_coerced_by_python_bool(self) -> None:
+        """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``."""
+        row = _flatten_resource_to_table_row(
+            {
+                "resourceType": "Patient",
+                "id": "p-str-false",
+                "deceasedBoolean": "false",
+            }
+        )
+        self.assertIsNotNone(row)
+        _table, payload = row
+        self.assertEqual(payload.get("deceased_boolean"), "false")
+
+    def test_string_true_parsed(self) -> None:
+        row = _flatten_resource_to_table_row(
+            {
+                "resourceType": "Patient",
+                "id": "p-str-true",
+                "deceasedBoolean": "true",
+            }
+        )
+        self.assertIsNotNone(row)
+        self.assertEqual(row[1].get("deceased_boolean"), "true")
+
+    def test_json_booleans_unchanged(self) -> None:
+        for raw, expected in ((True, "true"), (False, "false")):
+            with self.subTest(raw=raw):
+                row = _flatten_resource_to_table_row(
+                    {
+                        "resourceType": "Patient",
+                        "id": "p-bool",
+                        "deceasedBoolean": raw,
+                    }
+                )
+                self.assertIsNotNone(row)
+                self.assertEqual(row[1].get("deceased_boolean"), expected)
+
+    def test_unknown_deceased_type_stored_as_none(self) -> None:
+        row = _flatten_resource_to_table_row(
+            {
+                "resourceType": "Patient",
+                "id": "p-garbage",
+                "deceasedBoolean": {"unexpected": "object"},
+            }
+        )
+        self.assertIsNotNone(row)
+        self.assertIsNone(row[1].get("deceased_boolean"))
+
+    def test_infer_mortality_respects_string_false_row(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+ 
[ + { + "event_type": "patient", + "timestamp": "2020-01-01T00:00:00", + "patient/deceased_boolean": "false", + }, + ], + ) + self.assertEqual(infer_mortality_label(patient), 0) + + class TestMIMIC4FHIRDataset(unittest.TestCase): def test_concept_vocab_from_json_empty_token_to_id(self) -> None: - """Corrupted save with empty ``token_to_id`` must not call ``max()`` on [].""" - v = ConceptVocab.from_json({"token_to_id": {}}) self.assertIn("", v.token_to_id) self.assertIn("", v.token_to_id) @@ -77,18 +136,61 @@ def test_concept_vocab_from_json_empty_respects_next_id(self) -> None: v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50}) self.assertEqual(v._next_id, 50) - def test_disk_fixture_resolves_events_per_patient(self) -> None: - """NDJSON on disk → Parquet cache carries multiple rows for ``p-synth-1``.""" + def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None: + from pyhealth.datasets.mimic4_fhir import sorted_ndjson_files with tempfile.TemporaryDirectory() as tmp: - tdir = Path(tmp) - write_one_patient_ndjson(tdir) + root = Path(tmp) + (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8") + (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8") + (root / "notes.txt").write_text("z", encoding="utf-8") + wide = sorted_ndjson_files(root, "**/*.ndjson.gz") + narrow = sorted_ndjson_files( + root, + ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"], + ) + self.assertEqual(len(wide), 2) + self.assertEqual(len(narrow), 1) + self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz") + + def test_dataset_accepts_glob_patterns_kwarg(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_one_patient_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset( + root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp + ) + self.assertEqual(ds.glob_patterns, ["*.ndjson"]) + _ = ds.global_event_df.collect(engine="streaming") + + def test_dataset_rejects_both_glob_kwargs(self) -> None: + with 
tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(ValueError): + MIMIC4FHIRDataset( + root=tmp, + glob_pattern="*.ndjson", + glob_patterns=["*.ndjson"], + cache_dir=tmp, + ) + + def test_disk_fixture_resolves_events_per_patient(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_one_patient_ndjson(Path(tmp)) ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) - sub = ( - ds.global_event_df.filter(pl.col("patient_id") == "p-synth-1") - .collect(engine="streaming") + sub = ds.global_event_df.filter(pl.col("patient_id") == "p-synth-1").collect( + engine="streaming" ) self.assertGreaterEqual(len(sub), 2) + self.assertIn("condition/concept_key", sub.columns) + + def test_prepared_flat_tables_exist(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_two_class_ndjson(Path(tmp)) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + _ = ds.global_event_df.collect(engine="streaming") + prepared = ds.prepared_tables_dir + self.assertTrue((prepared / "patient.parquet").is_file()) + self.assertTrue((prepared / "encounter.parquet").is_file()) + self.assertTrue((prepared / "condition.parquet").is_file()) def test_build_cehr_non_empty(self) -> None: from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask @@ -102,8 +204,6 @@ def test_build_cehr_non_empty(self) -> None: self.assertGreater(ds.vocab.vocab_size, 2) def test_set_task_vocab_warm_on_litdata_cache_hit(self) -> None: - """MPF ``set_task`` must fill ``ds.vocab`` even when ``_task_transform`` skips.""" - from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask with tempfile.TemporaryDirectory() as tmp: @@ -112,9 +212,7 @@ def test_set_task_vocab_warm_on_litdata_cache_hit(self) -> None: ds1 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) ds1.set_task(MPFClinicalPredictionTask(**task_kw), num_workers=1) warm_size = ds1.vocab.vocab_size - self.assertGreater( - warm_size, 6, 
"fixture plus MPF specials should exceed pad/unk only" - ) + self.assertGreater(warm_size, 6) ds2 = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) ds2.set_task(MPFClinicalPredictionTask(**task_kw), num_workers=1) self.assertEqual(ds2.vocab.vocab_size, warm_size) @@ -125,125 +223,83 @@ def test_mortality_heuristic(self) -> None: with tempfile.TemporaryDirectory() as tmp: write_two_class_ndjson(Path(tmp)) ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) - task = MPFClinicalPredictionTask(max_len=64, use_mpf=False) - samples = ds.gather_samples(task) - labels = {s["label"] for s in samples} - self.assertEqual(labels, {0, 1}) + samples = ds.gather_samples(MPFClinicalPredictionTask(max_len=64, use_mpf=False)) + self.assertEqual({s["label"] for s in samples}, {0, 1}) def test_infer_deceased(self) -> None: with tempfile.TemporaryDirectory() as tmp: write_two_class_ndjson(Path(tmp)) ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) - dead = fhir_patient_from_patient(ds.get_patient("p-synth-2")) + dead = ds.get_patient("p-synth-2") self.assertEqual(infer_mortality_label(dead), 1) def test_disk_ndjson_gz_physionet_style(self) -> None: - """Gzip NDJSON (PhysioNet ``*.ndjson.gz``) matches default glob when set.""" - with tempfile.TemporaryDirectory() as tmp: gz_path = Path(tmp) / "fixture.ndjson.gz" with gzip.open(gz_path, "wt", encoding="utf-8") as gz: gz.write(ndjson_two_class_text()) - ds = MIMIC4FHIRDataset( - root=tmp, glob_pattern="*.ndjson.gz", max_patients=5 - ) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5) self.assertGreaterEqual(len(ds.unique_patient_ids), 1) def test_disk_ndjson_temp_dir(self) -> None: - """Load from a temp directory (cleanup via context manager).""" + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask with tempfile.TemporaryDirectory() as tmp: write_two_class_ndjson(Path(tmp)) - ds = MIMIC4FHIRDataset( - root=tmp, 
glob_pattern="*.ndjson", max_patients=5 - ) + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", max_patients=5) self.assertEqual(len(ds.unique_patient_ids), 2) - from pyhealth.tasks.mpf_clinical_prediction import ( - MPFClinicalPredictionTask, - ) - - task = MPFClinicalPredictionTask(max_len=48, use_mpf=True) - samples = ds.gather_samples(task) + samples = ds.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)) self.assertGreaterEqual(len(samples), 1) - for s in samples: - self.assertIn("concept_ids", s) - self.assertIn("label", s) - - def test_sharded_ingest_sorted_patient_ids_multi_part_cache(self) -> None: - """Hash shards → ``part-*.parquet``; patient ids exposed in sorted order.""" + for sample in samples: + self.assertIn("concept_ids", sample) + self.assertIn("label", sample) + def test_global_event_df_schema_and_flattened_columns(self) -> None: with tempfile.TemporaryDirectory() as tmp: write_two_class_ndjson(Path(tmp)) - ds = MIMIC4FHIRDataset( - root=tmp, - glob_pattern="*.ndjson", - cache_dir=tmp, - ingest_num_shards=8, - ) - ids = ds.unique_patient_ids - self.assertEqual(ids, sorted(ids)) - self.assertEqual(set(ids), {"p-synth-1", "p-synth-2"}) - part_dir = ds.cache_dir / "global_event_df.parquet" - parts = sorted(part_dir.glob("part-*.parquet")) - self.assertGreaterEqual(len(parts), 1) - # ``p-synth-1`` / ``p-synth-2`` crc32 to different slots for 8 shards. 
- self.assertGreaterEqual(len(parts), 2) - - def test_global_event_df_schema_and_streaming_path(self) -> None: - with tempfile.TemporaryDirectory() as tmp: - write_two_class_ndjson(Path(tmp)) - ds = MIMIC4FHIRDataset( - root=tmp, - glob_pattern="*.ndjson", - cache_dir=tmp, - max_patients=5, - ) - lf = ds.global_event_df - df = lf.collect(engine="streaming") + ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + df = ds.global_event_df.collect(engine="streaming") self.assertGreater(len(df), 0) self.assertIn("patient_id", df.columns) self.assertIn("timestamp", df.columns) self.assertIn("event_type", df.columns) - self.assertIn(FHIR_RESOURCE_JSON_COL, df.columns) - self.assertTrue((df["event_type"] == FHIR_EVENT_TYPE).all()) + self.assertIn("condition/concept_key", df.columns) + self.assertIn("observation/concept_key", df.columns) + self.assertIn("patient/deceased_boolean", df.columns) def test_set_task_parity_with_gather_samples_ndjson(self) -> None: + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + with tempfile.TemporaryDirectory() as tmp: write_two_class_ndjson(Path(tmp), name="fx.ndjson") - from pyhealth.tasks.mpf_clinical_prediction import ( - MPFClinicalPredictionTask, - ) - ds = MIMIC4FHIRDataset( root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=1 ) - task = MPFClinicalPredictionTask(max_len=48, use_mpf=True) - ref = sorted(ds.gather_samples(task), key=lambda s: s["patient_id"]) - task2 = MPFClinicalPredictionTask(max_len=48, use_mpf=True) - sample_ds = ds.set_task(task2, num_workers=1) + ref = sorted( + ds.gather_samples(MPFClinicalPredictionTask(max_len=48, use_mpf=True)), + key=lambda s: s["patient_id"], + ) + sample_ds = ds.set_task( + MPFClinicalPredictionTask(max_len=48, use_mpf=True), num_workers=1 + ) got = sorted( [sample_ds[i] for i in range(len(sample_ds))], key=lambda s: s["patient_id"], ) self.assertEqual(len(got), len(ref)) - for a, b in zip(ref, got): - 
self.assertEqual(a["label"], int(b["label"])) - ac = a["concept_ids"] - bc = b["concept_ids"] - if hasattr(bc, "tolist"): - bc = bc.tolist() - self.assertEqual(ac, bc) + for expected, actual in zip(ref, got): + self.assertEqual(expected["label"], int(actual["label"])) + actual_ids = actual["concept_ids"] + if hasattr(actual_ids, "tolist"): + actual_ids = actual_ids.tolist() + self.assertEqual(expected["concept_ids"], actual_ids) def test_gather_samples_resets_frozen_vocab_after_set_task(self) -> None: - """Reusing the same task after ``set_task`` must grow a new dataset's vocab.""" + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask with tempfile.TemporaryDirectory() as tmp_a, tempfile.TemporaryDirectory() as tmp_b: write_two_class_ndjson(Path(tmp_a), name="a.ndjson") write_two_class_ndjson(Path(tmp_b), name="b.ndjson") - from pyhealth.tasks.mpf_clinical_prediction import ( - MPFClinicalPredictionTask, - ) - ds_a = MIMIC4FHIRDataset( root=tmp_a, glob_pattern="*.ndjson", cache_dir=tmp_a, num_workers=1 ) @@ -252,7 +308,6 @@ def test_gather_samples_resets_frozen_vocab_after_set_task(self) -> None: ) task = MPFClinicalPredictionTask(max_len=48, use_mpf=True) ds_a.set_task(task, num_workers=1) - # Single-process transform: vocab grows during caching; no pre-warm pass. 
self.assertFalse(task.frozen_vocab) ref = sorted( @@ -261,23 +316,15 @@ def test_gather_samples_resets_frozen_vocab_after_set_task(self) -> None: ) got = sorted(ds_b.gather_samples(task), key=lambda s: s["patient_id"]) self.assertEqual(len(got), len(ref)) - for a, b in zip(ref, got): - self.assertEqual(a["label"], b["label"]) - ac = a["concept_ids"] - bc = b["concept_ids"] - if hasattr(bc, "tolist"): - bc = bc.tolist() - self.assertEqual(ac, bc) + for expected, actual in zip(ref, got): + self.assertEqual(expected["label"], actual["label"]) + self.assertEqual(expected["concept_ids"], actual["concept_ids"]) def test_set_task_multi_worker_sets_frozen_vocab(self) -> None: - """``effective_workers > 1`` requires main-process warmup and frozen ids.""" + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask with tempfile.TemporaryDirectory() as tmp: write_two_class_ndjson(Path(tmp)) - from pyhealth.tasks.mpf_clinical_prediction import ( - MPFClinicalPredictionTask, - ) - ds = MIMIC4FHIRDataset( root=tmp, glob_pattern="*.ndjson", cache_dir=tmp, num_workers=2 ) @@ -286,8 +333,6 @@ def test_set_task_multi_worker_sets_frozen_vocab(self) -> None: self.assertTrue(task.frozen_vocab) def test_mpf_pre_filter_vocab_warmup_excludes_dropped_patients(self) -> None: - """Warmup must not deserialize patients omitted by ``pre_filter``.""" - from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask class TwoPatientMPFTask(MPFClinicalPredictionTask): @@ -306,8 +351,6 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: self.assertIn("http://loinc.org|789-0", ds.vocab.token_to_id) def test_mpf_pre_filter_patient_ids_drive_effective_workers(self) -> None: - """``len(_mpf_patient_ids_for_task)`` must match ``_task_transform`` slicing.""" - from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask class OnePatientMPFTask(MPFClinicalPredictionTask): @@ -322,303 +365,260 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: 
task = OnePatientMPFTask(max_len=48, use_mpf=True) warmup_pids = ds._mpf_patient_ids_for_task(task) self.assertEqual(warmup_pids, ["p-synth-1"]) - nw = 2 - pid_n = len(warmup_pids) - effective_workers = min(nw, pid_n) if pid_n else 1 + effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1 self.assertEqual(effective_workers, 1) - self.assertFalse(effective_workers > 1) def test_encounter_reference_requires_exact_id(self) -> None: - """``e1`` must not match reference ``Encounter/e10`` (substring bug).""" - - from pyhealth.datasets.mimic4_fhir import FHIRPatient - - patient_r = { - "resourceType": "Patient", - "id": "p1", - "birthDate": "1950-01-01", - } - enc1 = { - "resourceType": "Encounter", - "id": "e1", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-06-01T10:00:00Z"}, - "class": {"code": "AMB"}, - } - enc10 = { - "resourceType": "Encounter", - "id": "e10", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-07-02T10:00:00Z"}, - "class": {"code": "IMP"}, - } - cond_e10 = { - "resourceType": "Condition", - "id": "c99", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e10"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I99"}] - }, - "onsetDateTime": "2020-07-02T11:00:00Z", - } - pr = FHIRPatient( - patient_id="p1", - resources=[patient_r, enc1, enc10, cond_e10], + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-02T10:00:00", + "encounter/encounter_id": "e10", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": 
"2020-07-02T11:00:00", + "condition/encounter_id": "e10", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99", + }, + ], ) vocab = ConceptVocab() - concept_ids, *_ = build_cehr_sequences(pr, vocab, max_len=64) + concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64) tid = vocab["http://hl7.org/fhir/sid/icd-10-cm|I99"] self.assertEqual(concept_ids.count(tid), 1) def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None: - """No encounter.reference: must not duplicate once per encounter loop.""" - - from pyhealth.datasets.mimic4_fhir import FHIRPatient - - patient_r = { - "resourceType": "Patient", - "id": "p1", - "birthDate": "1950-01-01", - } - enc_a = { - "resourceType": "Encounter", - "id": "ea", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-06-01T10:00:00Z"}, - "class": {"code": "AMB"}, - } - enc_b = { - "resourceType": "Encounter", - "id": "eb", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-07-01T10:00:00Z"}, - "class": {"code": "IMP"}, - } - cond = { - "resourceType": "Condition", - "id": "cx", - "subject": {"reference": "Patient/p1"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "Z00"}] - }, - "onsetDateTime": "2020-06-15T12:00:00Z", - } - pr = FHIRPatient( - patient_id="p1", - resources=[patient_r, enc_a, enc_b, cond], + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "ea", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-01T10:00:00", + "encounter/encounter_id": "eb", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-15T12:00:00", + "condition/concept_key": 
"http://hl7.org/fhir/sid/icd-10-cm|Z00", + }, + ], ) vocab = ConceptVocab() - concept_ids, *_ = build_cehr_sequences(pr, vocab, max_len=64) - z00 = vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"] - self.assertEqual(concept_ids.count(z00), 1) + concept_ids, *_ = build_cehr_sequences(patient, vocab, max_len=64) + self.assertEqual(concept_ids.count(vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"]), 1) def test_cehr_sequence_shapes(self) -> None: with tempfile.TemporaryDirectory() as tmp: write_one_patient_ndjson(Path(tmp)) ds = MIMIC4FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) - p = fhir_patient_from_patient(ds.get_patient("p-synth-1")) - v = ConceptVocab() - c, tt, ts, ag, vo, vs = build_cehr_sequences(p, v, max_len=32) - n = len(c) - self.assertEqual(len(tt), n) - self.assertEqual(len(ts), n) + patient = ds.get_patient("p-synth-1") + vocab = ConceptVocab() + concept_ids, token_types, time_stamps, ages, visit_orders, visit_segments = ( + build_cehr_sequences(patient, vocab, max_len=32) + ) + n = len(concept_ids) + self.assertEqual(len(token_types), n) + self.assertEqual(len(time_stamps), n) + self.assertEqual(len(ages), n) + self.assertEqual(len(visit_orders), n) + self.assertEqual(len(visit_segments), n) self.assertGreater(n, 0) def test_build_cehr_max_len_zero_no_clinical_tokens(self) -> None: - """``max_len=0`` must not use ``events[-0:]`` (full list); emit nothing.""" - - from pyhealth.datasets.mimic4_fhir import FHIRPatient - - patient_r = { - "resourceType": "Patient", - "id": "p1", - "birthDate": "1950-01-01", - } - enc = { - "resourceType": "Encounter", - "id": "e1", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-06-01T10:00:00Z"}, - "class": {"code": "AMB"}, - } - cond = { - "resourceType": "Condition", - "id": "c1", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e1"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}] - }, - "onsetDateTime": 
"2020-06-01T11:00:00Z", - } - pr = FHIRPatient(patient_id="p1", resources=[patient_r, enc, cond]) + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + ], + ) vocab = ConceptVocab() - c, tt, ts, ag, vo, vs = build_cehr_sequences(pr, vocab, max_len=0) + c, _, _, _, _, vs = build_cehr_sequences(patient, vocab, max_len=0) self.assertEqual(c, []) self.assertEqual(vs, []) def test_visit_segments_alternate_by_visit_index(self) -> None: - """CEHR-style segments: all tokens in visit ``k`` share ``k % 2``.""" - - from pyhealth.datasets.mimic4_fhir import FHIRPatient - - patient_r = { - "resourceType": "Patient", - "id": "p1", - "birthDate": "1950-01-01", - } - enc0 = { - "resourceType": "Encounter", - "id": "e0", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-06-01T10:00:00Z"}, - "class": {"code": "AMB"}, - } - enc1 = { - "resourceType": "Encounter", - "id": "e1", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-07-01T10:00:00Z"}, - "class": {"code": "IMP"}, - } - c0 = { - "resourceType": "Condition", - "id": "c0", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e0"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}] - }, - "onsetDateTime": "2020-06-01T11:00:00Z", - } - c1 = { - "resourceType": "Condition", - "id": "c1", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e1"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": 
"I20"}] - }, - "onsetDateTime": "2020-07-01T11:00:00Z", - } - pr = FHIRPatient( - patient_id="p1", - resources=[patient_r, enc0, enc1, c0, c1], + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e0", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e0", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-07-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20", + }, + ], ) vocab = ConceptVocab() - _, _, _, _, _, vs = build_cehr_sequences(pr, vocab, max_len=64) - self.assertEqual(len(vs), 4) - self.assertEqual(vs, [0, 0, 1, 1]) + _, _, _, _, _, visit_segments = build_cehr_sequences(patient, vocab, max_len=64) + self.assertEqual(visit_segments, [0, 0, 1, 1]) def test_unlinked_visit_idx_matches_sequential_counter(self) -> None: - """Skipped encounters (no ``period.start``) must not shift unlinked ``visit_idx``.""" - - from pyhealth.datasets.mimic4_fhir import FHIRPatient - - patient_r = { - "resourceType": "Patient", - "id": "p1", - "birthDate": "1950-01-01", - } - enc_no_start = { - "resourceType": "Encounter", - "id": "e_bad", - "subject": {"reference": "Patient/p1"}, - "class": {"code": "AMB"}, - } - enc_ok = { - "resourceType": "Encounter", - "id": "e_ok", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-03-01T10:00:00Z"}, - "class": {"code": "IMP"}, - } - cond_linked = { - 
"resourceType": "Condition", - "id": "c_link", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e_ok"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}] - }, - "onsetDateTime": "2020-03-05T11:00:00Z", - } - cond_unlinked = { - "resourceType": "Condition", - "id": "c_free", - "subject": {"reference": "Patient/p1"}, - "code": { - "coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "Z00"}] - }, - "onsetDateTime": "2020-03-15T12:00:00Z", - } - pr = FHIRPatient( - patient_id="p1", - resources=[patient_r, enc_no_start, enc_ok, cond_linked, cond_unlinked], + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": None, + "encounter/encounter_id": "e_bad", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-03-01T10:00:00", + "encounter/encounter_id": "e_ok", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-03-05T11:00:00", + "condition/encounter_id": "e_ok", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-03-15T12:00:00", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00", + }, + ], ) vocab = ConceptVocab() - c, _, _, _, vo, vs = build_cehr_sequences(pr, vocab, max_len=64) + concept_ids, _, _, _, visit_orders, visit_segments = build_cehr_sequences( + patient, vocab, max_len=64 + ) i10 = vocab["http://hl7.org/fhir/sid/icd-10-cm|I10"] z00 = vocab["http://hl7.org/fhir/sid/icd-10-cm|Z00"] - i_link = c.index(i10) - i_free = c.index(z00) - self.assertEqual(vo[i_link], vo[i_free]) - self.assertEqual(vs[i_link], vs[i_free]) + i_link = concept_ids.index(i10) + i_free = 
concept_ids.index(z00) + self.assertEqual(visit_orders[i_link], visit_orders[i_free]) + self.assertEqual(visit_segments[i_link], visit_segments[i_free]) def test_medication_request_uses_medication_codeable_concept(self) -> None: - """FHIR R4 MedicationRequest carries Rx in ``medicationCodeableConcept``, not ``code``.""" - - from pyhealth.datasets.mimic4_fhir import FHIRPatient - - patient_r = { - "resourceType": "Patient", - "id": "p1", - "birthDate": "1950-01-01", - } - enc = { - "resourceType": "Encounter", - "id": "e1", - "subject": {"reference": "Patient/p1"}, - "period": {"start": "2020-06-01T10:00:00Z"}, - "class": {"code": "IMP"}, - } - mr_a = { - "resourceType": "MedicationRequest", - "id": "m1", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e1"}, - "authoredOn": "2020-06-01T11:00:00Z", - "medicationCodeableConcept": { - "coding": [ - { - "system": "http://www.nlm.nih.gov/research/umls/rxnorm", - "code": "111", - } - ] - }, - } - mr_b = { - "resourceType": "MedicationRequest", - "id": "m2", - "subject": {"reference": "Patient/p1"}, - "encounter": {"reference": "Encounter/e1"}, - "authoredOn": "2020-06-01T12:00:00Z", - "medicationCodeableConcept": { - "coding": [ - { - "system": "http://www.nlm.nih.gov/research/umls/rxnorm", - "code": "222", - } - ] - }, - } - pr = FHIRPatient( - patient_id="p1", - resources=[patient_r, enc, mr_a, mr_b], + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T11:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111", + }, + { + "patient_id": "p1", + 
+                    "event_type": "medication_request",
+                    "timestamp": "2020-06-01T12:00:00",
+                    "medication_request/encounter_id": "e1",
+                    "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222",
+                },
+            ],
         )
         vocab = ConceptVocab()
-        c, _, _, _, _, _ = build_cehr_sequences(pr, vocab, max_len=64)
+        c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
         ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111"
         kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222"
         self.assertNotEqual(vocab[ka], vocab[kb])
@@ -626,37 +626,73 @@ def test_medication_request_uses_medication_codeable_concept(self) -> None:
         self.assertEqual(c.count(vocab[kb]), 1)
 
     def test_medication_request_medication_reference_token(self) -> None:
-        """When only ``medicationReference`` is present, use a stable ref-based key."""
-
-        from pyhealth.datasets.mimic4_fhir import FHIRPatient
-
-        patient_r = {
-            "resourceType": "Patient",
-            "id": "p1",
-            "birthDate": "1950-01-01",
-        }
-        enc = {
-            "resourceType": "Encounter",
-            "id": "e1",
-            "subject": {"reference": "Patient/p1"},
-            "period": {"start": "2020-06-01T10:00:00Z"},
-            "class": {"code": "IMP"},
-        }
-        mr = {
-            "resourceType": "MedicationRequest",
-            "id": "m1",
-            "subject": {"reference": "Patient/p1"},
-            "encounter": {"reference": "Encounter/e1"},
-            "authoredOn": "2020-06-01T11:00:00Z",
-            "medicationReference": {"reference": "Medication/med-abc"},
-        }
-        pr = FHIRPatient(patient_id="p1", resources=[patient_r, enc, mr])
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "IMP",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "medication_request",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "medication_request/encounter_id": "e1",
+                    "medication_request/concept_key": "MedicationRequest/reference|med-abc",
+                },
+            ],
+        )
         vocab = ConceptVocab()
-        c, _, _, _, _, _ = build_cehr_sequences(pr, vocab, max_len=64)
+        c, *_ = build_cehr_sequences(patient, vocab, max_len=64)
         key = "MedicationRequest/reference|med-abc"
         self.assertIn(vocab[key], c)
         self.assertEqual(c.count(vocab[key]), 1)
 
+    def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None:
+        patient = _patient_from_rows(
+            "p1",
+            [
+                {
+                    "patient_id": "p1",
+                    "event_type": "patient",
+                    "timestamp": None,
+                    "patient/birth_date": "1950-01-01",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "encounter",
+                    "timestamp": "2020-06-01T10:00:00",
+                    "encounter/encounter_id": "e1",
+                    "encounter/encounter_class": "AMB",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "condition",
+                    "timestamp": "2020-06-01T11:00:00",
+                    "condition/encounter_id": "e1",
+                    "condition/concept_key": "a|1",
+                },
+                {
+                    "patient_id": "p1",
+                    "event_type": "observation",
+                    "timestamp": "2020-06-01T12:00:00",
+                    "observation/encounter_id": "e1",
+                    "observation/concept_key": "b|2",
+                },
+            ],
+        )
+        events = collect_cehr_timeline_events(patient)
+        self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"])
+
 
 if __name__ == "__main__":
     unittest.main()