From d83d94bfc6e789b7b2c68cdcf8faf69d96fe3ccb Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Thu, 9 Apr 2026 16:47:16 -0500 Subject: [PATCH 01/10] ama task prediction --- docs/api/tasks.rst | 1 + .../tasks/pyhealth.tasks.ama_prediction.rst | 7 + ...mic3_ama_prediction_logistic_regression.py | 187 +++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/ama_prediction.py | 278 ++++++++++ tests/core/test_mimic3_ama_prediction.py | 475 ++++++++++++++++++ 6 files changed, 949 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.ama_prediction.rst create mode 100644 examples/mimic3_ama_prediction_logistic_regression.py create mode 100644 pyhealth/tasks/ama_prediction.py create mode 100644 tests/core/test_mimic3_ama_prediction.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..8d25ced19 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -205,6 +205,7 @@ Available Tasks .. toctree:: :maxdepth: 3 + AMA Prediction (MIMIC-III) Base Task In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding diff --git a/docs/api/tasks/pyhealth.tasks.ama_prediction.rst b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst new file mode 100644 index 000000000..a82a2521c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ama_prediction +================================= + +.. autoclass:: pyhealth.tasks.ama_prediction.AMAPredictionMIMIC3 + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py new file mode 100644 index 000000000..750a0093b --- /dev/null +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -0,0 +1,187 @@ +"""AMA Prediction on MIMIC-III -- Ablation Study. + +This script reproduces the Against-Medical-Advice (AMA) discharge +prediction task from: + + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." + Machine Learning for Healthcare Conference, PMLR, 2018. + +The paper predicts AMA discharge using L1-regularized logistic +regression on demographic features (age, gender, race, insurance, +LOS) and three mistrust-proxy scores. Our PyHealth task reproduces +the same label and demographic feature set. We use PyHealth's +LogisticRegression model (a single linear layer on feature +embeddings) as the primary model since it is the closest analog +to the paper's approach within PyHealth's pipeline. + +The ablation study is structured as follows: + +1. **Paper baseline** -- LogisticRegression on demographics only + (age + LOS + gender + race + insurance), matching the paper's + BASELINE+RACE configuration. + +2. **Feature group comparison** -- demographics-only versus + clinical-codes-only versus all features combined, showing + whether ICD codes and prescriptions add predictive value + beyond demographics. + +3. **Model comparison** -- LogisticRegression versus RNN versus + Transformer on the full feature set, showing whether more + expressive architectures improve AMA prediction. + +All experiments use the synthetic MIMIC-III demo data hosted by +PyHealth so the script is runnable without credentialed access. +Because the demo data contains very few patients and may lack +positive AMA labels, the reported metrics are illustrative only. + +Usage: + python examples/mimic3_ama_prediction_logistic_regression.py +""" + +import tempfile + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import LogisticRegression, RNN, Transformer +from pyhealth.tasks import AMAPredictionMIMIC3 +from pyhealth.trainer import Trainer + +SYNTHETIC_ROOT = ( + "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" +) +TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"] +EPOCHS = 5 +BATCH_SIZE = 32 +MONITOR = "pr_auc" + +DEMO_KEYS = ["demographics", "age", "los"] +CODE_KEYS = ["conditions", "procedures", "drugs"] +ALL_KEYS = DEMO_KEYS + CODE_KEYS + + +def load_dataset(): + """Load the synthetic MIMIC-III dataset and apply the AMA task.""" + dataset = MIMIC3Dataset( + root=SYNTHETIC_ROOT, + tables=TABLES, + cache_dir=tempfile.TemporaryDirectory().name, + dev=True, + ) + dataset.stats() + task = AMAPredictionMIMIC3() + sample_dataset = dataset.set_task(task) + return sample_dataset + + +def run_experiment(sample_dataset, model_cls, model_kwargs, label): + """Train and evaluate a single configuration. + + Args: + sample_dataset: The ``SampleDataset`` returned by ``set_task``. + model_cls: Model class (``LogisticRegression``, ``RNN``, or + ``Transformer``). + model_kwargs: Extra keyword arguments forwarded to the model. + label: Human-readable experiment label for logging. + + Returns: + Dict of evaluation metrics, or ``None`` if training failed. + """ + train_ds, val_ds, test_ds = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + train_dl = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True) + val_dl = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False) + test_dl = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False) + + model = model_cls(dataset=sample_dataset, **model_kwargs) + + trainer = Trainer(model=model) + try: + trainer.train( + train_dataloader=train_dl, + val_dataloader=val_dl, + epochs=EPOCHS, + monitor=MONITOR, + ) + metrics = trainer.evaluate(test_dl) + except Exception as exc: + print(f" [{label}] Training failed: {exc}") + return None + + print(f" [{label}] {metrics}") + return metrics + + +def main(): + sample_dataset = load_dataset() + + # ================================================================= + # Ablation 1: Paper baseline (LogisticRegression on demographics) + # + # This is the closest reproduction of the paper's BASELINE+RACE + # configuration: age, LOS, gender, race, and insurance fed into + # a logistic regression model. + # ================================================================= + print("\n" + "=" * 60) + print("ABLATION 1: Paper baseline -- LogisticRegression on demographics") + print("=" * 60) + + run_experiment( + sample_dataset, + LogisticRegression, + {"feature_keys": DEMO_KEYS}, + "LogReg demographics (paper baseline)", + ) + + # ================================================================= + # Ablation 2: Feature group comparison (all using LogisticRegression) + # + # Compare demographics-only (paper's features) versus clinical + # codes only (extension beyond paper) versus all combined. + # ================================================================= + print("\n" + "=" * 60) + print("ABLATION 2: Feature groups (LogisticRegression)") + print("=" * 60) + + configs = [ + (DEMO_KEYS, "demographics only (age+LOS+gender+race+insurance)"), + (CODE_KEYS, "clinical codes only (conditions+procedures+drugs)"), + (ALL_KEYS, "all features combined"), + ] + for feature_keys, label in configs: + run_experiment( + sample_dataset, + LogisticRegression, + {"feature_keys": feature_keys}, + label, + ) + + # ================================================================= + # Ablation 3: Model comparison on all features + # + # Test whether more expressive architectures improve over the + # logistic regression baseline when all features are available. + # ================================================================= + print("\n" + "=" * 60) + print("ABLATION 3: Model comparison (all features)") + print("=" * 60) + + model_configs = [ + (LogisticRegression, {}, "LogisticRegression"), + (RNN, {"hidden_size": 64}, "RNN hidden=64"), + (Transformer, {"hidden_size": 64}, "Transformer hidden=64"), + ] + + for model_cls, kwargs, label in model_configs: + run_experiment( + sample_dataset, + model_cls, + {"feature_keys": ALL_KEYS, **kwargs}, + label, + ) + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..c4e541054 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,3 +1,4 @@ +from .ama_prediction import AMAPredictionMIMIC3 from .base_task import BaseTask from .benchmark_ehrshot import BenchmarkEHRShot from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction diff --git a/pyhealth/tasks/ama_prediction.py b/pyhealth/tasks/ama_prediction.py new file mode 100644 index 000000000..5ccf61e0a --- /dev/null +++ b/pyhealth/tasks/ama_prediction.py @@ -0,0 +1,278 @@ +from datetime import datetime +from typing import Any, Dict, List + +from .base_task import BaseTask + + +def _normalize_race(ethnicity: str) -> str: + """Map MIMIC-III ethnicity strings to the race categories used by + Boag et al. 2018. + + Args: + ethnicity: Raw ethnicity value from MIMIC-III ``admissions`` + table. + + Returns: + One of ``"White"``, ``"Black"``, ``"Hispanic"``, + ``"Asian"``, ``"Native American"``, or ``"Other"``. + """ + if ethnicity is None: + return "Other" + eth = str(ethnicity).upper() + if "HISPANIC" in eth or "SOUTH AMERICAN" in eth: + return "Hispanic" + if "AMERICAN INDIAN" in eth: + return "Native American" + if "ASIAN" in eth: + return "Asian" + if "BLACK" in eth: + return "Black" + if "WHITE" in eth: + return "White" + return "Other" + + +def _normalize_insurance(insurance: str) -> str: + """Map MIMIC-III insurance strings to the categories used by + Boag et al. 2018. + + Args: + insurance: Raw insurance value from MIMIC-III ``admissions`` + table. + + Returns: + ``"Public"`` for Medicare/Medicaid/Government, or the + original value (typically ``"Private"`` or ``"Self Pay"``). + """ + if insurance is None: + return "Other" + if insurance in ("Medicare", "Medicaid", "Government"): + return "Public" + return insurance + + +def _safe_parse_datetime(value: Any) -> datetime: + """Parse a datetime string, trying common MIMIC-III formats.""" + if isinstance(value, datetime): + return value + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"): + try: + return datetime.strptime(str(value), fmt) + except (ValueError, TypeError): + continue + raise ValueError(f"Cannot parse datetime: {value!r}") + + +class AMAPredictionMIMIC3(BaseTask): + """Predict whether a patient leaves the hospital against medical advice. + + This task reproduces the AMA (Against Medical Advice) discharge + prediction target from Boag et al. 2018, "Racial Disparities and + Mistrust in End-of-Life Care." A positive label indicates that the + patient's ``discharge_location`` is ``"LEFT AGAINST MEDICAL ADVI"`` + in the MIMIC-III admissions table. + + The feature set follows the paper's **BASELINE+RACE** + configuration: + + * **demographics** (multi-hot) -- gender, normalized race, and + normalized insurance category. + * **age** (tensor) -- patient age at admission in years. + * **los** (tensor) -- length of hospital stay in days. + * **conditions** (sequence) -- ICD-9 diagnosis codes. + * **procedures** (sequence) -- ICD-9 procedure codes. + * **drugs** (sequence) -- prescription drug names. + + The demographic and clinical-code features can be used + independently or together via the model's ``feature_keys`` + parameter, enabling ablation studies that mirror the paper. + + Unlike mortality or readmission prediction, the label is a property + of the **current** admission, so patients with only one visit are + eligible. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import AMAPredictionMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... ) + >>> task = AMAPredictionMIMIC3() + >>> samples = dataset.set_task(task) + + Reference: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + 2018. Racial Disparities and Mistrust in End-of-Life Care. In Machine + Learning for Healthcare Conference. PMLR. + """ + + AMA_DISCHARGE_LOCATION: str = "LEFT AGAINST MEDICAL ADVI" + + task_name: str = "AMAPredictionMIMIC3" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + "demographics": "multi_hot", + "age": "tensor", + "los": "tensor", + } + output_schema: Dict[str, str] = {"ama": "binary"} + + def __init__(self, exclude_newborns: bool = True) -> None: + """Initializes the AMA prediction task. + + Args: + exclude_newborns: If ``True``, admissions whose + ``admission_type`` is ``"NEWBORN"`` are skipped. + Defaults to ``True``. + """ + self.exclude_newborns = exclude_newborns + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Processes a single patient for AMA discharge prediction. + + Each admission with at least one diagnosis code, one procedure + code, and one prescription is emitted as a sample. The binary + label is derived from the admission's ``discharge_location``. + + Args: + patient: A Patient object from ``MIMIC3Dataset``. + + Returns: + A list of sample dictionaries, each containing: + - ``visit_id``: MIMIC-III ``hadm_id``. + - ``patient_id``: MIMIC-III ``subject_id``. + - ``conditions``: List of ICD-9 diagnosis codes. + - ``procedures``: List of ICD-9 procedure codes. + - ``drugs``: List of drug names from prescriptions. + - ``demographics``: List of categorical tokens + (gender, race, insurance). + - ``age``: Patient age at admission in years (float). + - ``los``: Hospital length of stay in days (float). + - ``ama``: Binary label (1 = AMA discharge, 0 = other). + """ + admissions = patient.get_events(event_type="admissions") + if len(admissions) == 0: + return [] + + patients_events = patient.get_events(event_type="patients") + if len(patients_events) == 0: + return [] + patient_info = patients_events[0] + + gender = getattr(patient_info, "gender", None) + dob_raw = getattr(patient_info, "dob", None) + try: + dob = _safe_parse_datetime(dob_raw) + except (ValueError, TypeError): + dob = None + + samples: List[Dict[str, Any]] = [] + + for admission in admissions: + if self.exclude_newborns: + admission_type = getattr( + admission, "admission_type", None + ) + if admission_type == "NEWBORN": + continue + + # --- Label --- + discharge_location = getattr( + admission, "discharge_location", None + ) + ama_label = ( + 1 + if discharge_location == self.AMA_DISCHARGE_LOCATION + else 0 + ) + + # --- Clinical codes --- + hadm_filter = ("hadm_id", "==", admission.hadm_id) + + diagnoses = patient.get_events( + event_type="diagnoses_icd", filters=[hadm_filter] + ) + conditions = [event.icd9_code for event in diagnoses] + if len(conditions) == 0: + continue + + procedures = patient.get_events( + event_type="procedures_icd", filters=[hadm_filter] + ) + procedures_list = [event.icd9_code for event in procedures] + if len(procedures_list) == 0: + continue + + prescriptions = patient.get_events( + event_type="prescriptions", filters=[hadm_filter] + ) + drugs = [event.drug for event in prescriptions] + if len(drugs) == 0: + continue + + # --- Demographics (categorical) --- + ethnicity = getattr(admission, "ethnicity", None) + insurance = getattr(admission, "insurance", None) + + demo_tokens: List[str] = [] + if gender: + demo_tokens.append(f"gender:{gender}") + demo_tokens.append(f"race:{_normalize_race(ethnicity)}") + demo_tokens.append( + f"insurance:{_normalize_insurance(insurance)}" + ) + + # --- Age (continuous) --- + age_years = 0.0 + if dob is not None: + admit_dt = admission.timestamp + if isinstance(admit_dt, datetime): + age_years = ( + admit_dt.year + - dob.year + - int( + (admit_dt.month, admit_dt.day) + < (dob.month, dob.day) + ) + ) + age_years = float(min(age_years, 90)) + + # --- LOS (continuous, in days) --- + los_days = 0.0 + dischtime_raw = getattr(admission, "dischtime", None) + if dischtime_raw is not None: + try: + dischtime = _safe_parse_datetime(dischtime_raw) + admit_dt = admission.timestamp + if isinstance(admit_dt, datetime): + los_days = max( + (dischtime - admit_dt).total_seconds() + / 86400.0, + 0.0, + ) + except (ValueError, TypeError): + los_days = 0.0 + + samples.append( + { + "visit_id": admission.hadm_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures_list, + "drugs": drugs, + "demographics": demo_tokens, + "age": [age_years], + "los": [los_days], + "ama": ama_label, + } + ) + + return samples diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py new file mode 100644 index 000000000..022f2e1cd --- /dev/null +++ b/tests/core/test_mimic3_ama_prediction.py @@ -0,0 +1,475 @@ +import tempfile +import unittest +from datetime import datetime +from pathlib import Path +from unittest.mock import MagicMock + +from pyhealth.tasks.ama_prediction import ( + AMAPredictionMIMIC3, + _normalize_insurance, + _normalize_race, +) + + +def _make_event(**attrs): + """Create a mock event with the given attributes.""" + event = MagicMock() + for key, value in attrs.items(): + setattr(event, key, value) + return event + + +def _build_patient( + patient_id, + admissions, + diagnoses, + procedures, + prescriptions, + gender="M", + dob="2100-01-01 00:00:00", +): + """Build a mock Patient whose ``get_events`` respects *filters*. + + Args: + patient_id: Patient identifier string. + admissions: List of admission event dicts (should include + ``ethnicity``, ``insurance``, ``dischtime``). + diagnoses: List of diagnosis event dicts (must include ``hadm_id``). + procedures: List of procedure event dicts (must include ``hadm_id``). + prescriptions: List of prescription event dicts (must include ``hadm_id``). + gender: Gender string for the patient demographics event. + dob: Date of birth string for computing age. + """ + patient = MagicMock() + patient.patient_id = patient_id + + patient_event = _make_event(gender=gender, dob=dob) + adm_events = [_make_event(**a) for a in admissions] + diag_events = [_make_event(**d) for d in diagnoses] + proc_events = [_make_event(**p) for p in procedures] + rx_events = [_make_event(**r) for r in prescriptions] + + def _get_events(event_type, filters=None, **kwargs): + if event_type == "patients": + return [patient_event] + if event_type == "admissions": + return adm_events + source = { + "diagnoses_icd": diag_events, + "procedures_icd": proc_events, + "prescriptions": rx_events, + }.get(event_type, []) + if filters: + col, op, val = filters[0] + source = [e for e in source if getattr(e, col, None) == val] + return source + + patient.get_events = _get_events + return patient + + +SAMPLE_KEYS = { + "visit_id", + "patient_id", + "conditions", + "procedures", + "drugs", + "demographics", + "age", + "los", + "ama", +} + + +class TestNormalizeRace(unittest.TestCase): + """Unit tests for the race normalization helper.""" + + def test_white(self): + self.assertEqual(_normalize_race("WHITE"), "White") + self.assertEqual(_normalize_race("WHITE - RUSSIAN"), "White") + + def test_black(self): + self.assertEqual(_normalize_race("BLACK/AFRICAN AMERICAN"), "Black") + + def test_hispanic(self): + self.assertEqual(_normalize_race("HISPANIC OR LATINO"), "Hispanic") + self.assertEqual( + _normalize_race("SOUTH AMERICAN"), "Hispanic" + ) + + def test_asian(self): + self.assertEqual( + _normalize_race("ASIAN - CHINESE"), "Asian" + ) + + def test_native_american(self): + self.assertEqual( + _normalize_race("AMERICAN INDIAN/ALASKA NATIVE"), "Native American" + ) + + def test_other(self): + self.assertEqual(_normalize_race("UNKNOWN/NOT SPECIFIED"), "Other") + self.assertEqual(_normalize_race(None), "Other") + + def test_normalize_insurance(self): + self.assertEqual(_normalize_insurance("Medicare"), "Public") + self.assertEqual(_normalize_insurance("Medicaid"), "Public") + self.assertEqual(_normalize_insurance("Government"), "Public") + self.assertEqual(_normalize_insurance("Private"), "Private") + self.assertEqual(_normalize_insurance("Self Pay"), "Self Pay") + self.assertEqual(_normalize_insurance(None), "Other") + + +class TestAMAPredictionMIMIC3Schema(unittest.TestCase): + """Validate class-level schema attributes.""" + + def test_task_name(self): + self.assertEqual(AMAPredictionMIMIC3.task_name, "AMAPredictionMIMIC3") + + def test_input_schema(self): + schema = AMAPredictionMIMIC3.input_schema + self.assertEqual(schema["conditions"], "sequence") + self.assertEqual(schema["procedures"], "sequence") + self.assertEqual(schema["drugs"], "sequence") + self.assertEqual(schema["demographics"], "multi_hot") + self.assertEqual(schema["age"], "tensor") + self.assertEqual(schema["los"], "tensor") + + def test_output_schema(self): + self.assertEqual(AMAPredictionMIMIC3.output_schema, {"ama": "binary"}) + + def test_defaults(self): + task = AMAPredictionMIMIC3() + self.assertTrue(task.exclude_newborns) + + +class TestAMAPredictionMIMIC3Mock(unittest.TestCase): + """Unit tests using synthetic mock data (no real dataset needed).""" + + def setUp(self): + self.task = AMAPredictionMIMIC3() + + def _default_admission(self, hadm_id="100", **overrides): + """Return a standard admission dict with demographic fields.""" + adm = { + "hadm_id": hadm_id, + "admission_type": "EMERGENCY", + "discharge_location": "HOME", + "ethnicity": "WHITE", + "insurance": "Private", + "dischtime": "2150-01-10 14:00:00", + "timestamp": datetime(2150, 1, 1), + } + adm.update(overrides) + return adm + + def test_empty_patient(self): + patient = MagicMock() + patient.patient_id = "P0" + patient.get_events = lambda event_type, **kw: [] + self.assertEqual(self.task(patient), []) + + def test_ama_label_positive(self): + """Admission with AMA discharge should produce label=1.""" + patient = _build_patient( + patient_id="P1", + admissions=[ + self._default_admission( + hadm_id="100", + discharge_location="LEFT AGAINST MEDICAL ADVI", + ), + ], + diagnoses=[{"hadm_id": "100", "icd9_code": "4019"}], + procedures=[{"hadm_id": "100", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "100", "drug": "Aspirin"}], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["ama"], 1) + self.assertEqual(samples[0]["visit_id"], "100") + self.assertEqual(samples[0]["patient_id"], "P1") + + def test_ama_label_negative(self): + """Non-AMA discharge should produce label=0.""" + patient = _build_patient( + patient_id="P2", + admissions=[ + self._default_admission(hadm_id="200"), + ], + diagnoses=[{"hadm_id": "200", "icd9_code": "25000"}], + procedures=[{"hadm_id": "200", "icd9_code": "3995"}], + prescriptions=[{"hadm_id": "200", "drug": "Metformin"}], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["ama"], 0) + + def test_multiple_admissions_mixed_labels(self): + """Two admissions: one AMA, one not.""" + patient = _build_patient( + patient_id="P3", + admissions=[ + self._default_admission(hadm_id="300"), + self._default_admission( + hadm_id="301", + admission_type="URGENT", + discharge_location="LEFT AGAINST MEDICAL ADVI", + timestamp=datetime(2150, 6, 1), + dischtime="2150-06-05 10:00:00", + ), + ], + diagnoses=[ + {"hadm_id": "300", "icd9_code": "4019"}, + {"hadm_id": "301", "icd9_code": "30000"}, + ], + procedures=[ + {"hadm_id": "300", "icd9_code": "3893"}, + {"hadm_id": "301", "icd9_code": "9394"}, + ], + prescriptions=[ + {"hadm_id": "300", "drug": "Lisinopril"}, + {"hadm_id": "301", "drug": "Naloxone"}, + ], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 2) + labels = {s["visit_id"]: s["ama"] for s in samples} + self.assertEqual(labels["300"], 0) + self.assertEqual(labels["301"], 1) + + def test_skip_admission_missing_diagnoses(self): + """Admissions with no diagnosis codes should be excluded.""" + patient = _build_patient( + patient_id="P4", + admissions=[self._default_admission(hadm_id="400")], + diagnoses=[], + procedures=[{"hadm_id": "400", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "400", "drug": "Aspirin"}], + ) + self.assertEqual(self.task(patient), []) + + def test_skip_admission_missing_procedures(self): + """Admissions with no procedure codes should be excluded.""" + patient = _build_patient( + patient_id="P5", + admissions=[self._default_admission(hadm_id="500")], + diagnoses=[{"hadm_id": "500", "icd9_code": "4019"}], + procedures=[], + prescriptions=[{"hadm_id": "500", "drug": "Aspirin"}], + ) + self.assertEqual(self.task(patient), []) + + def test_skip_admission_missing_prescriptions(self): + """Admissions with no prescriptions should be excluded.""" + patient = _build_patient( + patient_id="P6", + admissions=[ + self._default_admission( + hadm_id="600", + discharge_location="LEFT AGAINST MEDICAL ADVI", + ), + ], + diagnoses=[{"hadm_id": "600", "icd9_code": "4019"}], + procedures=[{"hadm_id": "600", "icd9_code": "3893"}], + prescriptions=[], + ) + self.assertEqual(self.task(patient), []) + + def test_exclude_newborns(self): + """NEWBORN admissions should be excluded by default.""" + patient = _build_patient( + patient_id="P7", + admissions=[ + self._default_admission( + hadm_id="700", admission_type="NEWBORN" + ), + ], + diagnoses=[{"hadm_id": "700", "icd9_code": "V3000"}], + procedures=[{"hadm_id": "700", "icd9_code": "9904"}], + prescriptions=[{"hadm_id": "700", "drug": "Vitamin K"}], + ) + task_exclude = AMAPredictionMIMIC3(exclude_newborns=True) + self.assertEqual(task_exclude(patient), []) + + task_include = AMAPredictionMIMIC3(exclude_newborns=False) + samples = task_include(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["ama"], 0) + + def test_sample_keys(self): + """Every sample must contain the expected keys.""" + patient = _build_patient( + patient_id="P8", + admissions=[ + self._default_admission( + hadm_id="800", discharge_location="SNF", + timestamp=datetime(2150, 2, 15), + ), + ], + diagnoses=[{"hadm_id": "800", "icd9_code": "4280"}], + procedures=[{"hadm_id": "800", "icd9_code": "3722"}], + prescriptions=[{"hadm_id": "800", "drug": "Furosemide"}], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(set(samples[0].keys()), SAMPLE_KEYS) + + def test_clinical_codes_content(self): + """Verify extracted codes match the input data.""" + patient = _build_patient( + patient_id="P9", + admissions=[ + self._default_admission( + hadm_id="900", timestamp=datetime(2150, 4, 1), + ), + ], + diagnoses=[ + {"hadm_id": "900", "icd9_code": "4019"}, + {"hadm_id": "900", "icd9_code": "25000"}, + ], + procedures=[ + {"hadm_id": "900", "icd9_code": "3893"}, + ], + prescriptions=[ + {"hadm_id": "900", "drug": "Lisinopril"}, + {"hadm_id": "900", "drug": "Metformin"}, + ], + ) + samples = self.task(patient) + self.assertEqual(samples[0]["conditions"], ["4019", "25000"]) + self.assertEqual(samples[0]["procedures"], ["3893"]) + self.assertEqual(samples[0]["drugs"], ["Lisinopril", "Metformin"]) + + def test_demographics_tokens(self): + """Demographics should include prefixed gender, race, insurance.""" + patient = _build_patient( + patient_id="P10", + admissions=[ + self._default_admission( + hadm_id="1000", + ethnicity="BLACK/AFRICAN AMERICAN", + insurance="Medicaid", + timestamp=datetime(2150, 5, 1), + ), + ], + diagnoses=[{"hadm_id": "1000", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1000", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1000", "drug": "Aspirin"}], + gender="F", + ) + samples = self.task(patient) + demo = samples[0]["demographics"] + self.assertIn("gender:F", demo) + self.assertIn("race:Black", demo) + self.assertIn("insurance:Public", demo) + + def test_age_calculation(self): + """Age should be computed from dob and admission timestamp.""" + patient = _build_patient( + patient_id="P11", + admissions=[ + self._default_admission( + hadm_id="1100", + timestamp=datetime(2150, 6, 15), + dischtime="2150-06-20 12:00:00", + ), + ], + diagnoses=[{"hadm_id": "1100", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1100", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1100", "drug": "Aspirin"}], + dob="2100-01-01 00:00:00", + ) + samples = self.task(patient) + self.assertEqual(samples[0]["age"], [50.0]) + + def test_age_capped_at_90(self): + """Ages above 90 should be capped (MIMIC-III de-identification).""" + patient = _build_patient( + patient_id="P12", + admissions=[ + self._default_admission( + hadm_id="1200", + timestamp=datetime(2150, 6, 15), + ), + ], + diagnoses=[{"hadm_id": "1200", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1200", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1200", "drug": "Aspirin"}], + dob="1850-01-01 00:00:00", + ) + samples = self.task(patient) + self.assertEqual(samples[0]["age"], [90.0]) + + def test_los_calculation(self): + """LOS should be computed in days from admittime to dischtime.""" + patient = _build_patient( + patient_id="P13", + admissions=[ + self._default_admission( + hadm_id="1300", + timestamp=datetime(2150, 3, 1, 8, 0, 0), + dischtime="2150-03-06 08:00:00", + ), + ], + diagnoses=[{"hadm_id": "1300", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1300", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1300", "drug": "Aspirin"}], + ) + samples = self.task(patient) + self.assertAlmostEqual(samples[0]["los"][0], 5.0, places=2) + + +class TestAMAPredictionMIMIC3Integration(unittest.TestCase): + """Integration test using the MIMIC-III demo dataset. + + The demo dataset contains zero AMA discharge events. Because the + ``BinaryLabelProcessor`` requires exactly two unique labels to fit, + ``set_task()`` cannot complete on this dataset. Instead we verify + that the task callable itself produces well-formed samples when + invoked on real ``Patient`` objects from the loaded dataset. + """ + + @classmethod + def setUpClass(cls): + from pyhealth.datasets import MIMIC3Dataset + + cls.cache_dir = tempfile.TemporaryDirectory() + demo_path = str( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "mimic3demo" + ) + cls.dataset = MIMIC3Dataset( + root=demo_path, + tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + cache_dir=cls.cache_dir.name, + ) + cls.task = AMAPredictionMIMIC3() + + def test_task_callable_on_real_patients(self): + """Run the task on real Patient objects from the demo dataset.""" + total_samples = 0 + for patient in self.dataset.iter_patients(): + samples = self.task(patient) + for sample in samples: + self.assertIn("ama", sample) + self.assertEqual(sample["ama"], 0) + self.assertIsInstance(sample["conditions"], list) + self.assertIsInstance(sample["procedures"], list) + self.assertIsInstance(sample["drugs"], list) + self.assertGreater(len(sample["conditions"]), 0) + self.assertGreater(len(sample["procedures"]), 0) + self.assertGreater(len(sample["drugs"]), 0) + self.assertIsInstance(sample["demographics"], list) + self.assertGreater(len(sample["demographics"]), 0) + self.assertIsInstance(sample["age"], list) + self.assertEqual(len(sample["age"]), 1) + self.assertIsInstance(sample["los"], list) + self.assertEqual(len(sample["los"]), 1) + total_samples += 1 + self.assertGreater(total_samples, 0, "Should produce at least one sample") + + +if __name__ == "__main__": + unittest.main() From d4f5adaee01a3dc284bcc98b0d0c7f9ec08ff5bf Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Sun, 12 Apr 2026 12:34:03 -0500 Subject: [PATCH 02/10] Added Extra ablations --- ...mic3_ama_prediction_logistic_regression.py | 518 ++++++++++----- examples/mimic3_ama_prediction_rnn.py | 432 +++++++++++++ pyhealth/tasks/ama_prediction.py | 96 ++- tests/core/test_mimic3_ama_prediction.py | 590 ++++++++++++++---- 4 files changed, 1324 insertions(+), 312 deletions(-) create mode 100644 examples/mimic3_ama_prediction_rnn.py diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index 750a0093b..4f5427463 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -1,48 +1,37 @@ -"""AMA Prediction on MIMIC-III -- Ablation Study. +"""AMA Prediction -- LogisticRegression Ablation with Fairness Analysis. -This script reproduces the Against-Medical-Advice (AMA) discharge -prediction task from: +Reproduces the Against-Medical-Advice discharge prediction from: Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; Ghassemi, M. "Racial Disparities and Mistrust in End-of-Life Care." Machine Learning for Healthcare Conference, PMLR, 2018. -The paper predicts AMA discharge using L1-regularized logistic -regression on demographic features (age, gender, race, insurance, -LOS) and three mistrust-proxy scores. Our PyHealth task reproduces -the same label and demographic feature set. We use PyHealth's -LogisticRegression model (a single linear layer on feature -embeddings) as the primary model since it is the closest analog -to the paper's approach within PyHealth's pipeline. +For each baseline the script reports: + 1. Overall AUROC / PR-AUC averaged over N random 60/40 splits. + 2. Subgroup performance (AUROC, PR-AUC) sliced by Race, Age Group, + and Insurance Type. + 3. Fairness metrics per subgroup: + - Demographic Parity = % predicted AMA (P(Y_hat=1 | Group=g)) + - Equal Opportunity = True Positive Rate (P(Y_hat=1 | Y=1, Group=g)) -The ablation study is structured as follows: - -1. **Paper baseline** -- LogisticRegression on demographics only - (age + LOS + gender + race + insurance), matching the paper's - BASELINE+RACE configuration. - -2. **Feature group comparison** -- demographics-only versus - clinical-codes-only versus all features combined, showing - whether ICD codes and prescriptions add predictive value - beyond demographics. - -3. **Model comparison** -- LogisticRegression versus RNN versus - Transformer on the full feature set, showing whether more - expressive architectures improve AMA prediction. - -All experiments use the synthetic MIMIC-III demo data hosted by -PyHealth so the script is runnable without credentialed access. -Because the demo data contains very few patients and may lack -positive AMA labels, the reported metrics are illustrative only. - -Usage: +Usage (synthetic demo data -- illustrative only, likely no AMA positives): python examples/mimic3_ama_prediction_logistic_regression.py + +Usage (real MIMIC-III): + python examples/mimic3_ama_prediction_logistic_regression.py \\ + --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 """ +import argparse import tempfile +import time + +import numpy as np +import torch +from sklearn.metrics import average_precision_score, roc_auc_score from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient -from pyhealth.models import LogisticRegression, RNN, Transformer +from pyhealth.models import LogisticRegression from pyhealth.tasks import AMAPredictionMIMIC3 from pyhealth.trainer import Trainer @@ -50,137 +39,384 @@ "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" ) TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"] -EPOCHS = 5 -BATCH_SIZE = 32 -MONITOR = "pr_auc" -DEMO_KEYS = ["demographics", "age", "los"] -CODE_KEYS = ["conditions", "procedures", "drugs"] -ALL_KEYS = DEMO_KEYS + CODE_KEYS +BASELINES = { + "BASELINE": ["demographics", "age", "los"], + "BASELINE+RACE": ["demographics", "age", "los", "race"], + "BASELINE+RACE+SUBSTANCE": [ + "demographics", "age", "los", "race", "has_substance_use", + ], +} -def load_dataset(): - """Load the synthetic MIMIC-III dataset and apply the AMA task.""" - dataset = MIMIC3Dataset( - root=SYNTHETIC_ROOT, - tables=TABLES, - cache_dir=tempfile.TemporaryDirectory().name, - dev=True, - ) - dataset.stats() - task = AMAPredictionMIMIC3() - sample_dataset = dataset.set_task(task) - return sample_dataset +# ------------------------------------------------------------------ +# Helpers -- demographics lookup +# ------------------------------------------------------------------ + +def _build_demographics_lookup(dataset, task): + """Run the task on every patient and collect raw demographic info. + Returns a dict mapping ``(patient_id, visit_id)`` to a dict with + keys ``race``, ``age``, and ``insurance``. + """ + lookup = {} + for patient in dataset.iter_patients(): + for sample in task(patient): + pid = str(sample["patient_id"]) + vid = str(sample["visit_id"]) + race = sample["race"][0].split(":", 1)[1] + age = sample["age"][0] + insurance = "Other" + for t in sample["demographics"]: + if t.startswith("insurance:"): + insurance = t.split(":", 1)[1] + break + lookup[(pid, vid)] = { + "race": race, "age": age, "insurance": insurance, + } + return lookup + + +def _age_group(age): + if age < 45: + return "Young (18-44)" + if age < 65: + return "Middle (45-64)" + return "Senior (65+)" + + +# ------------------------------------------------------------------ +# Helpers -- inference with demographic labels +# ------------------------------------------------------------------ + +def _get_predictions(model, dataloader, lookup): + """Run model on *dataloader*, return predictions + subgroup labels.""" + model.eval() + all_probs, all_labels = [], [] + all_races, all_ages, all_ins = [], [], [] + + with torch.no_grad(): + for batch in dataloader: + output = model(**batch) + all_probs.append(output["y_prob"].detach().cpu()) + all_labels.append(output["y_true"].detach().cpu()) + + pids = batch["patient_id"] + vids = batch["visit_id"] + if isinstance(vids, torch.Tensor): + vids = vids.tolist() + if isinstance(pids, torch.Tensor): + pids = pids.tolist() + + for pid, vid in zip(pids, vids): + info = lookup.get((str(pid), str(vid)), {}) + all_races.append(info.get("race", "Other")) + all_ages.append(_age_group(info.get("age", 0.0))) + all_ins.append(info.get("insurance", "Other")) + + y_prob = torch.cat(all_probs).numpy().ravel() + y_true = torch.cat(all_labels).numpy().ravel() + groups = { + "Race": np.array(all_races), + "Age Group": np.array(all_ages), + "Insurance": np.array(all_ins), + } + return y_prob, y_true, groups + + +# ------------------------------------------------------------------ +# Helpers -- safe metrics +# ------------------------------------------------------------------ + +def _safe_auroc(y, p): + if len(np.unique(y)) < 2: + return float("nan") + try: + return roc_auc_score(y, p) + except ValueError: + return float("nan") -def run_experiment(sample_dataset, model_cls, model_kwargs, label): - """Train and evaluate a single configuration. - Args: - sample_dataset: The ``SampleDataset`` returned by ``set_task``. - model_cls: Model class (``LogisticRegression``, ``RNN``, or - ``Transformer``). - model_kwargs: Extra keyword arguments forwarded to the model. - label: Human-readable experiment label for logging. +def _safe_prauc(y, p): + if np.sum(y) == 0: + return float("nan") + try: + return average_precision_score(y, p) + except ValueError: + return float("nan") - Returns: - Dict of evaluation metrics, or ``None`` if training failed. - """ - train_ds, val_ds, test_ds = split_by_patient( - sample_dataset, [0.8, 0.1, 0.1] + +# ------------------------------------------------------------------ +# Single split +# ------------------------------------------------------------------ + +def _create_model(sample_dataset, feature_keys, embedding_dim=128): + """Create a LogisticRegression with the requested feature subset.""" + model = LogisticRegression( + dataset=sample_dataset, embedding_dim=embedding_dim, ) - train_dl = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True) - val_dl = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False) - test_dl = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False) + model.feature_keys = list(feature_keys) + output_size = model.get_output_size() + model.fc = torch.nn.Linear( + len(feature_keys) * embedding_dim, output_size, + ) + return model + - model = model_cls(dataset=sample_dataset, **model_kwargs) +def _run_single_split(sample_dataset, feature_keys, lookup, + seed, epochs, batch_size=32): + """Train + evaluate one 60/40 split. Returns metrics dict or None.""" + train_ds, _, test_ds = split_by_patient( + sample_dataset, [0.6, 0.0, 0.4], seed=seed, + ) + train_dl = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + test_dl = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + model = _create_model(sample_dataset, feature_keys) trainer = Trainer(model=model) try: trainer.train( - train_dataloader=train_dl, - val_dataloader=val_dl, - epochs=EPOCHS, - monitor=MONITOR, + train_dataloader=train_dl, val_dataloader=None, + epochs=epochs, monitor=None, ) - metrics = trainer.evaluate(test_dl) except Exception as exc: - print(f" [{label}] Training failed: {exc}") + print(f" train failed: {exc}") return None - print(f" [{label}] {metrics}") - return metrics + y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) + threshold = 0.5 + y_pred = (y_prob >= threshold).astype(int) + + overall_auroc = _safe_auroc(y_true, y_prob) + overall_prauc = _safe_prauc(y_true, y_prob) + + subgroup = {} + for attr_name, attr_vals in groups.items(): + subgroup[attr_name] = {} + for grp in sorted(set(attr_vals)): + mask = attr_vals == grp + n = int(mask.sum()) + if n < 2: + continue + yt, yp, yd = y_true[mask], y_prob[mask], y_pred[mask] + pos = yt.sum() + subgroup[attr_name][grp] = { + "auroc": _safe_auroc(yt, yp), + "pr_auc": _safe_prauc(yt, yp), + "pct_pred": float(yd.mean()) * 100, + "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 + else float("nan"), + "n": n, + } + + return { + "auroc": overall_auroc, + "pr_auc": overall_prauc, + "subgroups": subgroup, + } + + +# ------------------------------------------------------------------ +# Aggregation +# ------------------------------------------------------------------ + +def _nanmean(lst): + v = [x for x in lst if not np.isnan(x)] + return np.mean(v) if v else float("nan") + + +def _nanstd(lst): + v = [x for x in lst if not np.isnan(x)] + return np.std(v) if v else float("nan") + + +def _aggregate(results): + """Aggregate per-split metrics into means and stds.""" + valid = [r for r in results if r is not None] + if not valid: + return None + agg = { + "n": len(valid), + "auroc_mean": _nanmean([r["auroc"] for r in valid]), + "auroc_std": _nanstd([r["auroc"] for r in valid]), + "pr_auc_mean": _nanmean([r["pr_auc"] for r in valid]), + "pr_auc_std": _nanstd([r["pr_auc"] for r in valid]), + } + + all_attrs = set() + for r in valid: + all_attrs.update(r["subgroups"].keys()) + + agg["subgroups"] = {} + for attr in sorted(all_attrs): + agg["subgroups"][attr] = {} + all_grps = set() + for r in valid: + if attr in r["subgroups"]: + all_grps.update(r["subgroups"][attr].keys()) + + for grp in sorted(all_grps): + aurocs, praucs, pcts, tprs, ns = [], [], [], [], [] + for r in valid: + m = r["subgroups"].get(attr, {}).get(grp) + if m is None: + continue + aurocs.append(m["auroc"]) + praucs.append(m["pr_auc"]) + pcts.append(m["pct_pred"]) + tprs.append(m["tpr"]) + ns.append(m["n"]) + + agg["subgroups"][attr][grp] = { + "auroc_mean": _nanmean(aurocs), + "auroc_std": _nanstd(aurocs), + "pr_auc_mean": _nanmean(praucs), + "pr_auc_std": _nanstd(praucs), + "pct_pred_mean": _nanmean(pcts), + "tpr_mean": _nanmean(tprs), + "n_avg": int(np.mean(ns)) if ns else 0, + } + return agg + + +# ------------------------------------------------------------------ +# Pretty-printing +# ------------------------------------------------------------------ + +def _fmt(val, digits=4): + return "N/A" if np.isnan(val) else f"{val:.{digits}f}" + + +def _print_results(name, feature_keys, agg): + w = 70 + print(f"\n{'=' * w}") + print(f" {name} (LogisticRegression)") + print(f" Features: {feature_keys}") + print(f"{'=' * w}") + + if agg is None: + print(" No valid results.\n") + return + + ci_lo = agg["auroc_mean"] - 1.96 * agg["auroc_std"] + ci_hi = agg["auroc_mean"] + 1.96 * agg["auroc_std"] + print(f"\n 1. Overall Performance ({agg['n']} splits)") + print(f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" + f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})") + print(f" PR-AUC: {_fmt(agg['pr_auc_mean'])} +/- {_fmt(agg['pr_auc_std'])}") + + print(f"\n 2. Subgroup Performance") + for attr, grps in agg["subgroups"].items(): + print(f" {attr}:") + print(f" {'Group':<20} {'AUROC':>15} {'PR-AUC':>15} {'n_avg':>7}") + print(f" {'-'*58}") + for grp, m in grps.items(): + a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" + p_str = f"{_fmt(m['pr_auc_mean'])}+/-{_fmt(m['pr_auc_std'])}" + print(f" {grp:<20} {a_str:>15} {p_str:>15} {m['n_avg']:>7}") + + print(f"\n 3. Fairness Metrics") + print(f" Demographic Parity (% Predicted AMA):") + for attr, grps in agg["subgroups"].items(): + parts = [f"{g}: {_fmt(m['pct_pred_mean'],2)}%" + for g, m in grps.items()] + print(f" {attr}: {', '.join(parts)}") + + print(f" Equal Opportunity (True Positive Rate):") + for attr, grps in agg["subgroups"].items(): + parts = [f"{g}: {_fmt(m['tpr_mean'],2)}%" + for g, m in grps.items()] + print(f" {attr}: {', '.join(parts)}") + + +# ------------------------------------------------------------------ +# Main +# ------------------------------------------------------------------ def main(): - sample_dataset = load_dataset() - - # ================================================================= - # Ablation 1: Paper baseline (LogisticRegression on demographics) - # - # This is the closest reproduction of the paper's BASELINE+RACE - # configuration: age, LOS, gender, race, and insurance fed into - # a logistic regression model. - # ================================================================= - print("\n" + "=" * 60) - print("ABLATION 1: Paper baseline -- LogisticRegression on demographics") - print("=" * 60) - - run_experiment( - sample_dataset, - LogisticRegression, - {"feature_keys": DEMO_KEYS}, - "LogReg demographics (paper baseline)", + parser = argparse.ArgumentParser( + description="AMA prediction ablation -- LogisticRegression", ) + parser.add_argument("--root", default=SYNTHETIC_ROOT, + help="MIMIC-III root (local path or URL)") + parser.add_argument("--splits", type=int, default=100, + help="Number of random 60/40 splits (default 100)") + parser.add_argument("--epochs", type=int, default=10, + help="Training epochs per split") + parser.add_argument("--dev", action="store_true", + help="Use dev mode (1000 patients)") + args = parser.parse_args() + + cache_dir = tempfile.mkdtemp(prefix="ama_lr_") + print(f"Cache: {cache_dir}") + print(f"Root: {args.root}") + print(f"Splits: {args.splits} | Epochs: {args.epochs}") + + print("\n[1/4] Loading dataset...") + t0 = time.time() + dataset = MIMIC3Dataset( + root=args.root, tables=TABLES, + cache_dir=cache_dir, dev=args.dev, + ) + print(f" Loaded in {time.time()-t0:.1f}s") + dataset.stats() - # ================================================================= - # Ablation 2: Feature group comparison (all using LogisticRegression) - # - # Compare demographics-only (paper's features) versus clinical - # codes only (extension beyond paper) versus all combined. - # ================================================================= - print("\n" + "=" * 60) - print("ABLATION 2: Feature groups (LogisticRegression)") - print("=" * 60) - - configs = [ - (DEMO_KEYS, "demographics only (age+LOS+gender+race+insurance)"), - (CODE_KEYS, "clinical codes only (conditions+procedures+drugs)"), - (ALL_KEYS, "all features combined"), - ] - for feature_keys, label in configs: - run_experiment( - sample_dataset, - LogisticRegression, - {"feature_keys": feature_keys}, - label, - ) - - # ================================================================= - # Ablation 3: Model comparison on all features - # - # Test whether more expressive architectures improve over the - # logistic regression baseline when all features are available. - # ================================================================= - print("\n" + "=" * 60) - print("ABLATION 3: Model comparison (all features)") - print("=" * 60) - - model_configs = [ - (LogisticRegression, {}, "LogisticRegression"), - (RNN, {"hidden_size": 64}, "RNN hidden=64"), - (Transformer, {"hidden_size": 64}, "Transformer hidden=64"), - ] - - for model_cls, kwargs, label in model_configs: - run_experiment( - sample_dataset, - model_cls, - {"feature_keys": ALL_KEYS, **kwargs}, - label, - ) - - print("\nDone.") + print("\n[2/4] Applying AMA task...") + task = AMAPredictionMIMIC3() + try: + sample_dataset = dataset.set_task(task) + except ValueError as exc: + if "unique labels" not in str(exc).lower(): + raise + print(f"\n {exc}") + print(" The dataset contains no AMA-positive cases.") + print(" AMA prevalence is ~2% so small/synthetic data") + print(" often lacks positives. Demonstrating the task") + print(" on raw patients instead:\n") + total = 0 + for patient in dataset.iter_patients(): + samples = task(patient) + total += len(samples) + print(f" Task produced {total} samples (all label=0)") + print("\n Re-run with real MIMIC-III for ablation:") + print(" python examples/" + "mimic3_ama_prediction_logistic_regression.py \\") + print(" --root /path/to/mimic-iii/1.4") + print("\nDone.") + return + print(f" Samples: {len(sample_dataset)}") + + print("\n[3/4] Building demographics lookup...") + t0 = time.time() + lookup = _build_demographics_lookup(dataset, task) + print(f" {len(lookup)} entries in {time.time()-t0:.1f}s") + + print(f"\n[4/4] Running ablation ({len(BASELINES)} baselines " + f"x {args.splits} splits)...\n") + + t_total = time.time() + for name, feature_keys in BASELINES.items(): + split_results = [] + for i in range(args.splits): + t0 = time.time() + res = _run_single_split( + sample_dataset, feature_keys, lookup, + seed=i, epochs=args.epochs, + ) + elapsed = time.time() - t0 + if res is not None: + print(f" [{name}] split {i+1:3d}/{args.splits}: " + f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)") + else: + print(f" [{name}] split {i+1:3d}/{args.splits}: FAILED") + split_results.append(res) + + agg = _aggregate(split_results) + _print_results(name, feature_keys, agg) + + print(f"\nTotal time: {time.time()-t_total:.1f}s") + print("Done.") if __name__ == "__main__": diff --git a/examples/mimic3_ama_prediction_rnn.py b/examples/mimic3_ama_prediction_rnn.py new file mode 100644 index 000000000..15b834843 --- /dev/null +++ b/examples/mimic3_ama_prediction_rnn.py @@ -0,0 +1,432 @@ +"""AMA Prediction -- RNN Ablation with Fairness Analysis. + +Reproduces the Against-Medical-Advice discharge prediction from: + + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." + Machine Learning for Healthcare Conference, PMLR, 2018. + +This script uses PyHealth's RNN model instead of LogisticRegression. +For the paper's demographic-only baselines the RNN degenerates to a +single-step recurrence (effectively a non-linear transform), which +lets us directly compare the impact of model capacity vs. the +LogisticRegression ablation in the companion script. + +For each baseline the script reports: + 1. Overall AUROC / PR-AUC averaged over N random 60/40 splits. + 2. Subgroup performance (AUROC, PR-AUC) sliced by Race, Age Group, + and Insurance Type. + 3. Fairness metrics per subgroup: + - Demographic Parity = % predicted AMA (P(Y_hat=1 | Group=g)) + - Equal Opportunity = True Positive Rate (P(Y_hat=1 | Y=1, Group=g)) + +Usage (synthetic demo data -- illustrative only, likely no AMA positives): + python examples/mimic3_ama_prediction_rnn.py + +Usage (real MIMIC-III): + python examples/mimic3_ama_prediction_rnn.py \\ + --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 +""" + +import argparse +import tempfile +import time + +import numpy as np +import torch +from sklearn.metrics import average_precision_score, roc_auc_score + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import RNN +from pyhealth.tasks import AMAPredictionMIMIC3 +from pyhealth.trainer import Trainer + +SYNTHETIC_ROOT = ( + "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" +) +TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"] + +BASELINES = { + "BASELINE": ["demographics", "age", "los"], + "BASELINE+RACE": ["demographics", "age", "los", "race"], + "BASELINE+RACE+SUBSTANCE": [ + "demographics", "age", "los", "race", "has_substance_use", + ], +} + + +# ------------------------------------------------------------------ +# Helpers -- demographics lookup +# ------------------------------------------------------------------ + +def _build_demographics_lookup(dataset, task): + """Run the task on every patient and collect raw demographic info. + + Returns a dict mapping ``(patient_id, visit_id)`` to a dict with + keys ``race``, ``age``, and ``insurance``. + """ + lookup = {} + for patient in dataset.iter_patients(): + for sample in task(patient): + pid = str(sample["patient_id"]) + vid = str(sample["visit_id"]) + race = sample["race"][0].split(":", 1)[1] + age = sample["age"][0] + insurance = "Other" + for t in sample["demographics"]: + if t.startswith("insurance:"): + insurance = t.split(":", 1)[1] + break + lookup[(pid, vid)] = { + "race": race, "age": age, "insurance": insurance, + } + return lookup + + +def _age_group(age): + if age < 45: + return "Young (18-44)" + if age < 65: + return "Middle (45-64)" + return "Senior (65+)" + + +# ------------------------------------------------------------------ +# Helpers -- inference with demographic labels +# ------------------------------------------------------------------ + +def _get_predictions(model, dataloader, lookup): + """Run model on *dataloader*, return predictions + subgroup labels.""" + model.eval() + all_probs, all_labels = [], [] + all_races, all_ages, all_ins = [], [], [] + + with torch.no_grad(): + for batch in dataloader: + output = model(**batch) + all_probs.append(output["y_prob"].detach().cpu()) + all_labels.append(output["y_true"].detach().cpu()) + + pids = batch["patient_id"] + vids = batch["visit_id"] + if isinstance(vids, torch.Tensor): + vids = vids.tolist() + if isinstance(pids, torch.Tensor): + pids = pids.tolist() + + for pid, vid in zip(pids, vids): + info = lookup.get((str(pid), str(vid)), {}) + all_races.append(info.get("race", "Other")) + all_ages.append(_age_group(info.get("age", 0.0))) + all_ins.append(info.get("insurance", "Other")) + + y_prob = torch.cat(all_probs).numpy().ravel() + y_true = torch.cat(all_labels).numpy().ravel() + groups = { + "Race": np.array(all_races), + "Age Group": np.array(all_ages), + "Insurance": np.array(all_ins), + } + return y_prob, y_true, groups + + +# ------------------------------------------------------------------ +# Helpers -- safe metrics +# ------------------------------------------------------------------ + +def _safe_auroc(y, p): + if len(np.unique(y)) < 2: + return float("nan") + try: + return roc_auc_score(y, p) + except ValueError: + return float("nan") + + +def _safe_prauc(y, p): + if np.sum(y) == 0: + return float("nan") + try: + return average_precision_score(y, p) + except ValueError: + return float("nan") + + +# ------------------------------------------------------------------ +# Single split +# ------------------------------------------------------------------ + +def _create_model(sample_dataset, feature_keys, + embedding_dim=128, hidden_dim=64): + """Create an RNN with the requested feature subset.""" + model = RNN( + dataset=sample_dataset, + embedding_dim=embedding_dim, + hidden_dim=hidden_dim, + ) + model.feature_keys = list(feature_keys) + output_size = model.get_output_size() + model.fc = torch.nn.Linear( + len(feature_keys) * hidden_dim, output_size, + ) + return model + + +def _run_single_split(sample_dataset, feature_keys, lookup, + seed, epochs, batch_size=32): + """Train + evaluate one 60/40 split. Returns metrics dict or None.""" + train_ds, _, test_ds = split_by_patient( + sample_dataset, [0.6, 0.0, 0.4], seed=seed, + ) + train_dl = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + test_dl = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + + model = _create_model(sample_dataset, feature_keys) + trainer = Trainer(model=model) + try: + trainer.train( + train_dataloader=train_dl, val_dataloader=None, + epochs=epochs, monitor=None, + ) + except Exception as exc: + print(f" train failed: {exc}") + return None + + y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) + threshold = 0.5 + y_pred = (y_prob >= threshold).astype(int) + + overall_auroc = _safe_auroc(y_true, y_prob) + overall_prauc = _safe_prauc(y_true, y_prob) + + subgroup = {} + for attr_name, attr_vals in groups.items(): + subgroup[attr_name] = {} + for grp in sorted(set(attr_vals)): + mask = attr_vals == grp + n = int(mask.sum()) + if n < 2: + continue + yt, yp, yd = y_true[mask], y_prob[mask], y_pred[mask] + pos = yt.sum() + subgroup[attr_name][grp] = { + "auroc": _safe_auroc(yt, yp), + "pr_auc": _safe_prauc(yt, yp), + "pct_pred": float(yd.mean()) * 100, + "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 + else float("nan"), + "n": n, + } + + return { + "auroc": overall_auroc, + "pr_auc": overall_prauc, + "subgroups": subgroup, + } + + +# ------------------------------------------------------------------ +# Aggregation +# ------------------------------------------------------------------ + +def _nanmean(lst): + v = [x for x in lst if not np.isnan(x)] + return np.mean(v) if v else float("nan") + + +def _nanstd(lst): + v = [x for x in lst if not np.isnan(x)] + return np.std(v) if v else float("nan") + + +def _aggregate(results): + """Aggregate per-split metrics into means and stds.""" + valid = [r for r in results if r is not None] + if not valid: + return None + + agg = { + "n": len(valid), + "auroc_mean": _nanmean([r["auroc"] for r in valid]), + "auroc_std": _nanstd([r["auroc"] for r in valid]), + "pr_auc_mean": _nanmean([r["pr_auc"] for r in valid]), + "pr_auc_std": _nanstd([r["pr_auc"] for r in valid]), + } + + all_attrs = set() + for r in valid: + all_attrs.update(r["subgroups"].keys()) + + agg["subgroups"] = {} + for attr in sorted(all_attrs): + agg["subgroups"][attr] = {} + all_grps = set() + for r in valid: + if attr in r["subgroups"]: + all_grps.update(r["subgroups"][attr].keys()) + + for grp in sorted(all_grps): + aurocs, praucs, pcts, tprs, ns = [], [], [], [], [] + for r in valid: + m = r["subgroups"].get(attr, {}).get(grp) + if m is None: + continue + aurocs.append(m["auroc"]) + praucs.append(m["pr_auc"]) + pcts.append(m["pct_pred"]) + tprs.append(m["tpr"]) + ns.append(m["n"]) + + agg["subgroups"][attr][grp] = { + "auroc_mean": _nanmean(aurocs), + "auroc_std": _nanstd(aurocs), + "pr_auc_mean": _nanmean(praucs), + "pr_auc_std": _nanstd(praucs), + "pct_pred_mean": _nanmean(pcts), + "tpr_mean": _nanmean(tprs), + "n_avg": int(np.mean(ns)) if ns else 0, + } + return agg + + +# ------------------------------------------------------------------ +# Pretty-printing +# ------------------------------------------------------------------ + +def _fmt(val, digits=4): + return "N/A" if np.isnan(val) else f"{val:.{digits}f}" + + +def _print_results(name, feature_keys, agg): + w = 70 + print(f"\n{'=' * w}") + print(f" {name} (RNN hidden_dim=64)") + print(f" Features: {feature_keys}") + print(f"{'=' * w}") + + if agg is None: + print(" No valid results.\n") + return + + ci_lo = agg["auroc_mean"] - 1.96 * agg["auroc_std"] + ci_hi = agg["auroc_mean"] + 1.96 * agg["auroc_std"] + print(f"\n 1. Overall Performance ({agg['n']} splits)") + print(f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" + f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})") + print(f" PR-AUC: {_fmt(agg['pr_auc_mean'])} +/- {_fmt(agg['pr_auc_std'])}") + + print(f"\n 2. Subgroup Performance") + for attr, grps in agg["subgroups"].items(): + print(f" {attr}:") + print(f" {'Group':<20} {'AUROC':>15} {'PR-AUC':>15} {'n_avg':>7}") + print(f" {'-'*58}") + for grp, m in grps.items(): + a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" + p_str = f"{_fmt(m['pr_auc_mean'])}+/-{_fmt(m['pr_auc_std'])}" + print(f" {grp:<20} {a_str:>15} {p_str:>15} {m['n_avg']:>7}") + + print(f"\n 3. Fairness Metrics") + print(f" Demographic Parity (% Predicted AMA):") + for attr, grps in agg["subgroups"].items(): + parts = [f"{g}: {_fmt(m['pct_pred_mean'],2)}%" + for g, m in grps.items()] + print(f" {attr}: {', '.join(parts)}") + + print(f" Equal Opportunity (True Positive Rate):") + for attr, grps in agg["subgroups"].items(): + parts = [f"{g}: {_fmt(m['tpr_mean'],2)}%" + for g, m in grps.items()] + print(f" {attr}: {', '.join(parts)}") + + +# ------------------------------------------------------------------ +# Main +# ------------------------------------------------------------------ + +def main(): + parser = argparse.ArgumentParser( + description="AMA prediction ablation -- RNN", + ) + parser.add_argument("--root", default=SYNTHETIC_ROOT, + help="MIMIC-III root (local path or URL)") + parser.add_argument("--splits", type=int, default=100, + help="Number of random 60/40 splits (default 100)") + parser.add_argument("--epochs", type=int, default=10, + help="Training epochs per split") + parser.add_argument("--dev", action="store_true", + help="Use dev mode (1000 patients)") + args = parser.parse_args() + + cache_dir = tempfile.mkdtemp(prefix="ama_rnn_") + print(f"Cache: {cache_dir}") + print(f"Root: {args.root}") + print(f"Splits: {args.splits} | Epochs: {args.epochs}") + + print("\n[1/4] Loading dataset...") + t0 = time.time() + dataset = MIMIC3Dataset( + root=args.root, tables=TABLES, + cache_dir=cache_dir, dev=args.dev, + ) + print(f" Loaded in {time.time()-t0:.1f}s") + dataset.stats() + + print("\n[2/4] Applying AMA task...") + task = AMAPredictionMIMIC3() + try: + sample_dataset = dataset.set_task(task) + except ValueError as exc: + if "unique labels" not in str(exc).lower(): + raise + print(f"\n {exc}") + print(" The dataset contains no AMA-positive cases.") + print(" AMA prevalence is ~2% so small/synthetic data") + print(" often lacks positives. Demonstrating the task") + print(" on raw patients instead:\n") + total = 0 + for patient in dataset.iter_patients(): + samples = task(patient) + total += len(samples) + print(f" Task produced {total} samples (all label=0)") + print("\n Re-run with real MIMIC-III for ablation:") + print(" python examples/" + "mimic3_ama_prediction_rnn.py \\") + print(" --root /path/to/mimic-iii/1.4") + print("\nDone.") + return + print(f" Samples: {len(sample_dataset)}") + + print("\n[3/4] Building demographics lookup...") + t0 = time.time() + lookup = _build_demographics_lookup(dataset, task) + print(f" {len(lookup)} entries in {time.time()-t0:.1f}s") + + print(f"\n[4/4] Running ablation ({len(BASELINES)} baselines " + f"x {args.splits} splits)...\n") + + t_total = time.time() + for name, feature_keys in BASELINES.items(): + split_results = [] + for i in range(args.splits): + t0 = time.time() + res = _run_single_split( + sample_dataset, feature_keys, lookup, + seed=i, epochs=args.epochs, + ) + elapsed = time.time() - t0 + if res is not None: + print(f" [{name}] split {i+1:3d}/{args.splits}: " + f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)") + else: + print(f" [{name}] split {i+1:3d}/{args.splits}: FAILED") + split_results.append(res) + + agg = _aggregate(split_results) + _print_results(name, feature_keys, agg) + + print(f"\nTotal time: {time.time()-t_total:.1f}s") + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/ama_prediction.py b/pyhealth/tasks/ama_prediction.py index 5ccf61e0a..33b799ebf 100644 --- a/pyhealth/tasks/ama_prediction.py +++ b/pyhealth/tasks/ama_prediction.py @@ -1,8 +1,15 @@ +import re from datetime import datetime from typing import Any, Dict, List from .base_task import BaseTask +_SUBSTANCE_PATTERN = re.compile( + r"alcohol|opioid|opiate|heroin|cocaine|drug|withdrawal" + r"|intoxication|overdose|substance|etoh", + re.IGNORECASE, +) + def _normalize_race(ethnicity: str) -> str: """Map MIMIC-III ethnicity strings to the race categories used by @@ -63,29 +70,48 @@ def _safe_parse_datetime(value: Any) -> datetime: raise ValueError(f"Cannot parse datetime: {value!r}") +def _has_substance_use(diagnosis: Any) -> int: + """Detect substance-use related admission from the free-text + ``DIAGNOSIS`` field in MIMIC-III ``ADMISSIONS``. + + Args: + diagnosis: Raw ``diagnosis`` string from the admissions table. + + Returns: + 1 if a substance-use keyword is found, 0 otherwise. + """ + if diagnosis is None: + return 0 + return 1 if _SUBSTANCE_PATTERN.search(str(diagnosis)) else 0 + + class AMAPredictionMIMIC3(BaseTask): """Predict whether a patient leaves the hospital against medical advice. This task reproduces the AMA (Against Medical Advice) discharge prediction target from Boag et al. 2018, "Racial Disparities and - Mistrust in End-of-Life Care." A positive label indicates that the + Mistrust in End-of-Life Care." A positive label indicates that the patient's ``discharge_location`` is ``"LEFT AGAINST MEDICAL ADVI"`` in the MIMIC-III admissions table. - The feature set follows the paper's **BASELINE+RACE** - configuration: + The feature set supports three ablation baselines: + + * **BASELINE** -- ``demographics`` (gender, insurance), + ``age``, and ``los``. Select with + ``feature_keys=["demographics", "age", "los"]``. + + * **BASELINE + RACE** -- adds ``race`` (normalized ethnicity). + Select with + ``feature_keys=["demographics", "age", "los", "race"]``. - * **demographics** (multi-hot) -- gender, normalized race, and - normalized insurance category. - * **age** (tensor) -- patient age at admission in years. - * **los** (tensor) -- length of hospital stay in days. - * **conditions** (sequence) -- ICD-9 diagnosis codes. - * **procedures** (sequence) -- ICD-9 procedure codes. - * **drugs** (sequence) -- prescription drug names. + * **BASELINE + RACE + SUBSTANCE** -- adds ``has_substance_use`` + (derived from ``ADMISSIONS.DIAGNOSIS`` free-text field). + Select with + ``feature_keys=["demographics", "age", "los", "race", + "has_substance_use"]``. - The demographic and clinical-code features can be used - independently or together via the model's ``feature_keys`` - parameter, enabling ablation studies that mirror the paper. + These baselines can be toggled via the model's ``feature_keys`` + parameter without changing the task. Unlike mortality or readmission prediction, the label is a property of the **current** admission, so patients with only one visit are @@ -116,12 +142,14 @@ class AMAPredictionMIMIC3(BaseTask): task_name: str = "AMAPredictionMIMIC3" input_schema: Dict[str, str] = { - "conditions": "sequence", - "procedures": "sequence", - "drugs": "sequence", "demographics": "multi_hot", "age": "tensor", "los": "tensor", + "race": "multi_hot", + "has_substance_use": "tensor", + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", } output_schema: Dict[str, str] = {"ama": "binary"} @@ -146,17 +174,9 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: patient: A Patient object from ``MIMIC3Dataset``. Returns: - A list of sample dictionaries, each containing: - - ``visit_id``: MIMIC-III ``hadm_id``. - - ``patient_id``: MIMIC-III ``subject_id``. - - ``conditions``: List of ICD-9 diagnosis codes. - - ``procedures``: List of ICD-9 procedure codes. - - ``drugs``: List of drug names from prescriptions. - - ``demographics``: List of categorical tokens - (gender, race, insurance). - - ``age``: Patient age at admission in years (float). - - ``los``: Hospital length of stay in days (float). - - ``ama``: Binary label (1 = AMA discharge, 0 = other). + A list of sample dictionaries. Each dictionary contains + the features described in the class docstring plus + ``visit_id``, ``patient_id``, and the ``ama`` label. """ admissions = patient.get_events(event_type="admissions") if len(admissions) == 0: @@ -218,18 +238,22 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if len(drugs) == 0: continue - # --- Demographics (categorical) --- - ethnicity = getattr(admission, "ethnicity", None) + # --- BASELINE demographics (gender + insurance) --- insurance = getattr(admission, "insurance", None) demo_tokens: List[str] = [] if gender: demo_tokens.append(f"gender:{gender}") - demo_tokens.append(f"race:{_normalize_race(ethnicity)}") demo_tokens.append( f"insurance:{_normalize_insurance(insurance)}" ) + # --- Race (separate feature for ablation) --- + ethnicity = getattr(admission, "ethnicity", None) + race_tokens: List[str] = [ + f"race:{_normalize_race(ethnicity)}" + ] + # --- Age (continuous) --- age_years = 0.0 if dob is not None: @@ -261,16 +285,22 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: except (ValueError, TypeError): los_days = 0.0 + # --- Substance use (from ADMISSIONS.DIAGNOSIS) --- + diagnosis_text = getattr(admission, "diagnosis", None) + substance = float(_has_substance_use(diagnosis_text)) + samples.append( { "visit_id": admission.hadm_id, "patient_id": patient.patient_id, - "conditions": conditions, - "procedures": procedures_list, - "drugs": drugs, "demographics": demo_tokens, "age": [age_years], "los": [los_days], + "race": race_tokens, + "has_substance_use": [substance], + "conditions": conditions, + "procedures": procedures_list, + "drugs": drugs, "ama": ama_label, } ) diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py index 022f2e1cd..7570d0921 100644 --- a/tests/core/test_mimic3_ama_prediction.py +++ b/tests/core/test_mimic3_ama_prediction.py @@ -1,11 +1,10 @@ -import tempfile import unittest from datetime import datetime -from pathlib import Path from unittest.mock import MagicMock from pyhealth.tasks.ama_prediction import ( AMAPredictionMIMIC3, + _has_substance_use, _normalize_insurance, _normalize_race, ) @@ -28,17 +27,18 @@ def _build_patient( gender="M", dob="2100-01-01 00:00:00", ): - """Build a mock Patient whose ``get_events`` respects *filters*. + """Build a mock Patient with ``get_events`` that respects filters. + + Uses 2-5 synthetic patients max. No real dataset is loaded. Args: patient_id: Patient identifier string. - admissions: List of admission event dicts (should include - ``ethnicity``, ``insurance``, ``dischtime``). - diagnoses: List of diagnosis event dicts (must include ``hadm_id``). - procedures: List of procedure event dicts (must include ``hadm_id``). - prescriptions: List of prescription event dicts (must include ``hadm_id``). - gender: Gender string for the patient demographics event. - dob: Date of birth string for computing age. + admissions: List of dicts for admission events. + diagnoses: List of dicts for diagnosis events. + procedures: List of dicts for procedure events. + prescriptions: List of dicts for prescription events. + gender: Gender string for the demographics event. + dob: Date-of-birth string for computing age. """ patient = MagicMock() patient.patient_id = patient_id @@ -61,7 +61,11 @@ def _get_events(event_type, filters=None, **kwargs): }.get(event_type, []) if filters: col, op, val = filters[0] - source = [e for e in source if getattr(e, col, None) == val] + source = [ + e + for e in source + if getattr(e, col, None) == val + ] return source patient.get_events = _get_events @@ -77,6 +81,8 @@ def _get_events(event_type, filters=None, **kwargs): "demographics", "age", "los", + "race", + "has_substance_use", "ama", } @@ -86,13 +92,19 @@ class TestNormalizeRace(unittest.TestCase): def test_white(self): self.assertEqual(_normalize_race("WHITE"), "White") - self.assertEqual(_normalize_race("WHITE - RUSSIAN"), "White") + self.assertEqual( + _normalize_race("WHITE - RUSSIAN"), "White" + ) def test_black(self): - self.assertEqual(_normalize_race("BLACK/AFRICAN AMERICAN"), "Black") + self.assertEqual( + _normalize_race("BLACK/AFRICAN AMERICAN"), "Black" + ) def test_hispanic(self): - self.assertEqual(_normalize_race("HISPANIC OR LATINO"), "Hispanic") + self.assertEqual( + _normalize_race("HISPANIC OR LATINO"), "Hispanic" + ) self.assertEqual( _normalize_race("SOUTH AMERICAN"), "Hispanic" ) @@ -104,39 +116,118 @@ def test_asian(self): def test_native_american(self): self.assertEqual( - _normalize_race("AMERICAN INDIAN/ALASKA NATIVE"), "Native American" + _normalize_race("AMERICAN INDIAN/ALASKA NATIVE"), + "Native American", ) def test_other(self): - self.assertEqual(_normalize_race("UNKNOWN/NOT SPECIFIED"), "Other") + self.assertEqual( + _normalize_race("UNKNOWN/NOT SPECIFIED"), "Other" + ) self.assertEqual(_normalize_race(None), "Other") def test_normalize_insurance(self): - self.assertEqual(_normalize_insurance("Medicare"), "Public") - self.assertEqual(_normalize_insurance("Medicaid"), "Public") - self.assertEqual(_normalize_insurance("Government"), "Public") - self.assertEqual(_normalize_insurance("Private"), "Private") - self.assertEqual(_normalize_insurance("Self Pay"), "Self Pay") + self.assertEqual( + _normalize_insurance("Medicare"), "Public" + ) + self.assertEqual( + _normalize_insurance("Medicaid"), "Public" + ) + self.assertEqual( + _normalize_insurance("Government"), "Public" + ) + self.assertEqual( + _normalize_insurance("Private"), "Private" + ) + self.assertEqual( + _normalize_insurance("Self Pay"), "Self Pay" + ) self.assertEqual(_normalize_insurance(None), "Other") +class TestHasSubstanceUse(unittest.TestCase): + """Unit tests for the substance-use detection helper.""" + + def test_alcohol(self): + self.assertEqual( + _has_substance_use("ALCOHOL WITHDRAWAL"), 1 + ) + + def test_opioid(self): + self.assertEqual( + _has_substance_use("OPIOID DEPENDENCE"), 1 + ) + + def test_heroin(self): + self.assertEqual( + _has_substance_use("HEROIN OVERDOSE"), 1 + ) + + def test_cocaine(self): + self.assertEqual( + _has_substance_use("COCAINE INTOXICATION"), 1 + ) + + def test_drug_withdrawal(self): + self.assertEqual( + _has_substance_use("DRUG WITHDRAWAL SEIZURE"), 1 + ) + + def test_etoh(self): + self.assertEqual(_has_substance_use("ETOH ABUSE"), 1) + + def test_substance(self): + self.assertEqual( + _has_substance_use("SUBSTANCE ABUSE"), 1 + ) + + def test_overdose(self): + self.assertEqual( + _has_substance_use("OVERDOSE - ACCIDENTAL"), 1 + ) + + def test_negative(self): + self.assertEqual(_has_substance_use("PNEUMONIA"), 0) + self.assertEqual(_has_substance_use("CHEST PAIN"), 0) + + def test_none(self): + self.assertEqual(_has_substance_use(None), 0) + + def test_case_insensitive(self): + self.assertEqual( + _has_substance_use("alcohol withdrawal"), 1 + ) + self.assertEqual( + _has_substance_use("Heroin Overdose"), 1 + ) + + class TestAMAPredictionMIMIC3Schema(unittest.TestCase): """Validate class-level schema attributes.""" def test_task_name(self): - self.assertEqual(AMAPredictionMIMIC3.task_name, "AMAPredictionMIMIC3") + self.assertEqual( + AMAPredictionMIMIC3.task_name, + "AMAPredictionMIMIC3", + ) def test_input_schema(self): schema = AMAPredictionMIMIC3.input_schema - self.assertEqual(schema["conditions"], "sequence") - self.assertEqual(schema["procedures"], "sequence") - self.assertEqual(schema["drugs"], "sequence") self.assertEqual(schema["demographics"], "multi_hot") self.assertEqual(schema["age"], "tensor") self.assertEqual(schema["los"], "tensor") + self.assertEqual(schema["race"], "multi_hot") + self.assertEqual( + schema["has_substance_use"], "tensor" + ) + self.assertEqual(schema["conditions"], "sequence") + self.assertEqual(schema["procedures"], "sequence") + self.assertEqual(schema["drugs"], "sequence") def test_output_schema(self): - self.assertEqual(AMAPredictionMIMIC3.output_schema, {"ama": "binary"}) + self.assertEqual( + AMAPredictionMIMIC3.output_schema, {"ama": "binary"} + ) def test_defaults(self): task = AMAPredictionMIMIC3() @@ -144,13 +235,18 @@ def test_defaults(self): class TestAMAPredictionMIMIC3Mock(unittest.TestCase): - """Unit tests using synthetic mock data (no real dataset needed).""" + """Unit tests using purely synthetic mock patients. + + No real dataset is loaded. Each test builds mock Patient + objects in memory and runs the task callable directly. + All tests complete in milliseconds. + """ def setUp(self): self.task = AMAPredictionMIMIC3() def _default_admission(self, hadm_id="100", **overrides): - """Return a standard admission dict with demographic fields.""" + """Return a standard admission dict.""" adm = { "hadm_id": hadm_id, "admission_type": "EMERGENCY", @@ -159,10 +255,15 @@ def _default_admission(self, hadm_id="100", **overrides): "insurance": "Private", "dischtime": "2150-01-10 14:00:00", "timestamp": datetime(2150, 1, 1), + "diagnosis": "PNEUMONIA", } adm.update(overrides) return adm + # ---------------------------------------------------------- + # Label generation + # ---------------------------------------------------------- + def test_empty_patient(self): patient = MagicMock() patient.patient_id = "P0" @@ -170,18 +271,26 @@ def test_empty_patient(self): self.assertEqual(self.task(patient), []) def test_ama_label_positive(self): - """Admission with AMA discharge should produce label=1.""" + """AMA discharge -> label=1.""" patient = _build_patient( patient_id="P1", admissions=[ self._default_admission( hadm_id="100", - discharge_location="LEFT AGAINST MEDICAL ADVI", + discharge_location=( + "LEFT AGAINST MEDICAL ADVI" + ), ), ], - diagnoses=[{"hadm_id": "100", "icd9_code": "4019"}], - procedures=[{"hadm_id": "100", "icd9_code": "3893"}], - prescriptions=[{"hadm_id": "100", "drug": "Aspirin"}], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "100", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "100", "drug": "Aspirin"} + ], ) samples = self.task(patient) self.assertEqual(len(samples), 1) @@ -190,15 +299,21 @@ def test_ama_label_positive(self): self.assertEqual(samples[0]["patient_id"], "P1") def test_ama_label_negative(self): - """Non-AMA discharge should produce label=0.""" + """Non-AMA discharge -> label=0.""" patient = _build_patient( patient_id="P2", admissions=[ self._default_admission(hadm_id="200"), ], - diagnoses=[{"hadm_id": "200", "icd9_code": "25000"}], - procedures=[{"hadm_id": "200", "icd9_code": "3995"}], - prescriptions=[{"hadm_id": "200", "drug": "Metformin"}], + diagnoses=[ + {"hadm_id": "200", "icd9_code": "25000"} + ], + procedures=[ + {"hadm_id": "200", "icd9_code": "3995"} + ], + prescriptions=[ + {"hadm_id": "200", "drug": "Metformin"} + ], ) samples = self.task(patient) self.assertEqual(len(samples), 1) @@ -213,7 +328,9 @@ def test_multiple_admissions_mixed_labels(self): self._default_admission( hadm_id="301", admission_type="URGENT", - discharge_location="LEFT AGAINST MEDICAL ADVI", + discharge_location=( + "LEFT AGAINST MEDICAL ADVI" + ), timestamp=datetime(2150, 6, 1), dischtime="2150-06-05 10:00:00", ), @@ -237,90 +354,146 @@ def test_multiple_admissions_mixed_labels(self): self.assertEqual(labels["300"], 0) self.assertEqual(labels["301"], 1) - def test_skip_admission_missing_diagnoses(self): - """Admissions with no diagnosis codes should be excluded.""" + # ---------------------------------------------------------- + # Filtering / edge cases + # ---------------------------------------------------------- + + def test_skip_missing_diagnoses(self): + """No diagnosis codes -> skip admission.""" patient = _build_patient( patient_id="P4", - admissions=[self._default_admission(hadm_id="400")], + admissions=[ + self._default_admission(hadm_id="400") + ], diagnoses=[], - procedures=[{"hadm_id": "400", "icd9_code": "3893"}], - prescriptions=[{"hadm_id": "400", "drug": "Aspirin"}], + procedures=[ + {"hadm_id": "400", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "400", "drug": "Aspirin"} + ], ) self.assertEqual(self.task(patient), []) - def test_skip_admission_missing_procedures(self): - """Admissions with no procedure codes should be excluded.""" + def test_skip_missing_procedures(self): + """No procedure codes -> skip admission.""" patient = _build_patient( patient_id="P5", - admissions=[self._default_admission(hadm_id="500")], - diagnoses=[{"hadm_id": "500", "icd9_code": "4019"}], + admissions=[ + self._default_admission(hadm_id="500") + ], + diagnoses=[ + {"hadm_id": "500", "icd9_code": "4019"} + ], procedures=[], - prescriptions=[{"hadm_id": "500", "drug": "Aspirin"}], + prescriptions=[ + {"hadm_id": "500", "drug": "Aspirin"} + ], ) self.assertEqual(self.task(patient), []) - def test_skip_admission_missing_prescriptions(self): - """Admissions with no prescriptions should be excluded.""" + def test_skip_missing_prescriptions(self): + """No prescriptions -> skip admission.""" patient = _build_patient( patient_id="P6", admissions=[ self._default_admission( hadm_id="600", - discharge_location="LEFT AGAINST MEDICAL ADVI", + discharge_location=( + "LEFT AGAINST MEDICAL ADVI" + ), ), ], - diagnoses=[{"hadm_id": "600", "icd9_code": "4019"}], - procedures=[{"hadm_id": "600", "icd9_code": "3893"}], + diagnoses=[ + {"hadm_id": "600", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "600", "icd9_code": "3893"} + ], prescriptions=[], ) self.assertEqual(self.task(patient), []) def test_exclude_newborns(self): - """NEWBORN admissions should be excluded by default.""" + """NEWBORN admissions skipped when flag is True.""" patient = _build_patient( patient_id="P7", admissions=[ self._default_admission( - hadm_id="700", admission_type="NEWBORN" + hadm_id="700", + admission_type="NEWBORN", ), ], - diagnoses=[{"hadm_id": "700", "icd9_code": "V3000"}], - procedures=[{"hadm_id": "700", "icd9_code": "9904"}], - prescriptions=[{"hadm_id": "700", "drug": "Vitamin K"}], + diagnoses=[ + {"hadm_id": "700", "icd9_code": "V3000"} + ], + procedures=[ + {"hadm_id": "700", "icd9_code": "9904"} + ], + prescriptions=[ + {"hadm_id": "700", "drug": "Vitamin K"} + ], ) - task_exclude = AMAPredictionMIMIC3(exclude_newborns=True) - self.assertEqual(task_exclude(patient), []) + task_ex = AMAPredictionMIMIC3(exclude_newborns=True) + self.assertEqual(task_ex(patient), []) - task_include = AMAPredictionMIMIC3(exclude_newborns=False) - samples = task_include(patient) + task_in = AMAPredictionMIMIC3(exclude_newborns=False) + samples = task_in(patient) self.assertEqual(len(samples), 1) self.assertEqual(samples[0]["ama"], 0) + def test_no_patients_events(self): + """Patient with admissions but no ``patients`` event.""" + patient = MagicMock() + patient.patient_id = "PX" + + def _get(event_type, **kw): + if event_type == "patients": + return [] + if event_type == "admissions": + return [_make_event(hadm_id="1")] + return [] + + patient.get_events = _get + self.assertEqual(self.task(patient), []) + + # ---------------------------------------------------------- + # Feature extraction + # ---------------------------------------------------------- + def test_sample_keys(self): """Every sample must contain the expected keys.""" patient = _build_patient( patient_id="P8", admissions=[ self._default_admission( - hadm_id="800", discharge_location="SNF", + hadm_id="800", + discharge_location="SNF", timestamp=datetime(2150, 2, 15), ), ], - diagnoses=[{"hadm_id": "800", "icd9_code": "4280"}], - procedures=[{"hadm_id": "800", "icd9_code": "3722"}], - prescriptions=[{"hadm_id": "800", "drug": "Furosemide"}], + diagnoses=[ + {"hadm_id": "800", "icd9_code": "4280"} + ], + procedures=[ + {"hadm_id": "800", "icd9_code": "3722"} + ], + prescriptions=[ + {"hadm_id": "800", "drug": "Furosemide"} + ], ) samples = self.task(patient) self.assertEqual(len(samples), 1) self.assertEqual(set(samples[0].keys()), SAMPLE_KEYS) def test_clinical_codes_content(self): - """Verify extracted codes match the input data.""" + """Extracted codes match the input data.""" patient = _build_patient( patient_id="P9", admissions=[ self._default_admission( - hadm_id="900", timestamp=datetime(2150, 4, 1), + hadm_id="900", + timestamp=datetime(2150, 4, 1), ), ], diagnoses=[ @@ -336,12 +509,17 @@ def test_clinical_codes_content(self): ], ) samples = self.task(patient) - self.assertEqual(samples[0]["conditions"], ["4019", "25000"]) + self.assertEqual( + samples[0]["conditions"], ["4019", "25000"] + ) self.assertEqual(samples[0]["procedures"], ["3893"]) - self.assertEqual(samples[0]["drugs"], ["Lisinopril", "Metformin"]) + self.assertEqual( + samples[0]["drugs"], + ["Lisinopril", "Metformin"], + ) - def test_demographics_tokens(self): - """Demographics should include prefixed gender, race, insurance.""" + def test_demographics_baseline_tokens(self): + """BASELINE demographics: gender + insurance, no race.""" patient = _build_patient( patient_id="P10", admissions=[ @@ -352,19 +530,108 @@ def test_demographics_tokens(self): timestamp=datetime(2150, 5, 1), ), ], - diagnoses=[{"hadm_id": "1000", "icd9_code": "4019"}], - procedures=[{"hadm_id": "1000", "icd9_code": "3893"}], - prescriptions=[{"hadm_id": "1000", "drug": "Aspirin"}], + diagnoses=[ + {"hadm_id": "1000", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "1000", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1000", "drug": "Aspirin"} + ], gender="F", ) samples = self.task(patient) demo = samples[0]["demographics"] self.assertIn("gender:F", demo) - self.assertIn("race:Black", demo) self.assertIn("insurance:Public", demo) + for token in demo: + self.assertFalse( + token.startswith("race:"), + "race must not be in demographics", + ) + + def test_race_separate_feature(self): + """Race is a separate multi-hot feature.""" + patient = _build_patient( + patient_id="P10b", + admissions=[ + self._default_admission( + hadm_id="1001", + ethnicity="BLACK/AFRICAN AMERICAN", + timestamp=datetime(2150, 5, 1), + ), + ], + diagnoses=[ + {"hadm_id": "1001", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "1001", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1001", "drug": "Aspirin"} + ], + ) + samples = self.task(patient) + self.assertIn("race", samples[0]) + self.assertEqual( + samples[0]["race"], ["race:Black"] + ) + + def test_substance_use_positive(self): + """Substance-use diagnosis -> has_substance_use=1.""" + patient = _build_patient( + patient_id="P14", + admissions=[ + self._default_admission( + hadm_id="1400", + diagnosis="ALCOHOL WITHDRAWAL", + timestamp=datetime(2150, 7, 1), + ), + ], + diagnoses=[ + {"hadm_id": "1400", "icd9_code": "29181"} + ], + procedures=[ + {"hadm_id": "1400", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1400", "drug": "Lorazepam"} + ], + ) + samples = self.task(patient) + self.assertEqual( + samples[0]["has_substance_use"], [1.0] + ) + + def test_substance_use_negative(self): + """Non-substance diagnosis -> has_substance_use=0.""" + patient = _build_patient( + patient_id="P15", + admissions=[ + self._default_admission( + hadm_id="1500", + diagnosis="PNEUMONIA", + timestamp=datetime(2150, 8, 1), + ), + ], + diagnoses=[ + {"hadm_id": "1500", "icd9_code": "486"} + ], + procedures=[ + {"hadm_id": "1500", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1500", "drug": "Levofloxacin"} + ], + ) + samples = self.task(patient) + self.assertEqual( + samples[0]["has_substance_use"], [0.0] + ) def test_age_calculation(self): - """Age should be computed from dob and admission timestamp.""" + """Age computed from dob and admission timestamp.""" patient = _build_patient( patient_id="P11", admissions=[ @@ -374,16 +641,22 @@ def test_age_calculation(self): dischtime="2150-06-20 12:00:00", ), ], - diagnoses=[{"hadm_id": "1100", "icd9_code": "4019"}], - procedures=[{"hadm_id": "1100", "icd9_code": "3893"}], - prescriptions=[{"hadm_id": "1100", "drug": "Aspirin"}], + diagnoses=[ + {"hadm_id": "1100", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "1100", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1100", "drug": "Aspirin"} + ], dob="2100-01-01 00:00:00", ) samples = self.task(patient) self.assertEqual(samples[0]["age"], [50.0]) def test_age_capped_at_90(self): - """Ages above 90 should be capped (MIMIC-III de-identification).""" + """Ages above 90 are capped (MIMIC-III convention).""" patient = _build_patient( patient_id="P12", admissions=[ @@ -392,16 +665,22 @@ def test_age_capped_at_90(self): timestamp=datetime(2150, 6, 15), ), ], - diagnoses=[{"hadm_id": "1200", "icd9_code": "4019"}], - procedures=[{"hadm_id": "1200", "icd9_code": "3893"}], - prescriptions=[{"hadm_id": "1200", "drug": "Aspirin"}], + diagnoses=[ + {"hadm_id": "1200", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "1200", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1200", "drug": "Aspirin"} + ], dob="1850-01-01 00:00:00", ) samples = self.task(patient) self.assertEqual(samples[0]["age"], [90.0]) def test_los_calculation(self): - """LOS should be computed in days from admittime to dischtime.""" + """LOS in days from admittime to dischtime.""" patient = _build_patient( patient_id="P13", admissions=[ @@ -411,64 +690,99 @@ def test_los_calculation(self): dischtime="2150-03-06 08:00:00", ), ], - diagnoses=[{"hadm_id": "1300", "icd9_code": "4019"}], - procedures=[{"hadm_id": "1300", "icd9_code": "3893"}], - prescriptions=[{"hadm_id": "1300", "drug": "Aspirin"}], + diagnoses=[ + {"hadm_id": "1300", "icd9_code": "4019"} + ], + procedures=[ + {"hadm_id": "1300", "icd9_code": "3893"} + ], + prescriptions=[ + {"hadm_id": "1300", "drug": "Aspirin"} + ], ) samples = self.task(patient) - self.assertAlmostEqual(samples[0]["los"][0], 5.0, places=2) - + self.assertAlmostEqual( + samples[0]["los"][0], 5.0, places=2 + ) -class TestAMAPredictionMIMIC3Integration(unittest.TestCase): - """Integration test using the MIMIC-III demo dataset. + # ---------------------------------------------------------- + # Multi-patient synthetic "dataset" (2 patients) + # ---------------------------------------------------------- - The demo dataset contains zero AMA discharge events. Because the - ``BinaryLabelProcessor`` requires exactly two unique labels to fit, - ``set_task()`` cannot complete on this dataset. Instead we verify - that the task callable itself produces well-formed samples when - invoked on real ``Patient`` objects from the loaded dataset. - """ + def test_two_patient_synthetic_dataset(self): + """End-to-end with 2 synthetic patients, both labels.""" + p1 = _build_patient( + patient_id="S1", + admissions=[ + self._default_admission( + hadm_id="A1", + discharge_location=( + "LEFT AGAINST MEDICAL ADVI" + ), + ethnicity="HISPANIC OR LATINO", + insurance="Medicaid", + diagnosis="HEROIN OVERDOSE", + timestamp=datetime(2150, 1, 1), + dischtime="2150-01-03 12:00:00", + ), + ], + diagnoses=[ + {"hadm_id": "A1", "icd9_code": "96500"} + ], + procedures=[ + {"hadm_id": "A1", "icd9_code": "9604"} + ], + prescriptions=[ + {"hadm_id": "A1", "drug": "Naloxone"} + ], + gender="M", + dob="2100-06-15 00:00:00", + ) + p2 = _build_patient( + patient_id="S2", + admissions=[ + self._default_admission( + hadm_id="A2", + discharge_location="HOME", + ethnicity="WHITE", + insurance="Private", + diagnosis="CHEST PAIN", + timestamp=datetime(2150, 3, 10), + dischtime="2150-03-12 08:00:00", + ), + ], + diagnoses=[ + {"hadm_id": "A2", "icd9_code": "78650"} + ], + procedures=[ + {"hadm_id": "A2", "icd9_code": "8856"} + ], + prescriptions=[ + {"hadm_id": "A2", "drug": "Aspirin"} + ], + gender="F", + dob="2090-01-01 00:00:00", + ) - @classmethod - def setUpClass(cls): - from pyhealth.datasets import MIMIC3Dataset - - cls.cache_dir = tempfile.TemporaryDirectory() - demo_path = str( - Path(__file__).parent.parent.parent - / "test-resources" - / "core" - / "mimic3demo" - ) - cls.dataset = MIMIC3Dataset( - root=demo_path, - tables=["diagnoses_icd", "procedures_icd", "prescriptions"], - cache_dir=cls.cache_dir.name, - ) - cls.task = AMAPredictionMIMIC3() - - def test_task_callable_on_real_patients(self): - """Run the task on real Patient objects from the demo dataset.""" - total_samples = 0 - for patient in self.dataset.iter_patients(): - samples = self.task(patient) - for sample in samples: - self.assertIn("ama", sample) - self.assertEqual(sample["ama"], 0) - self.assertIsInstance(sample["conditions"], list) - self.assertIsInstance(sample["procedures"], list) - self.assertIsInstance(sample["drugs"], list) - self.assertGreater(len(sample["conditions"]), 0) - self.assertGreater(len(sample["procedures"]), 0) - self.assertGreater(len(sample["drugs"]), 0) - self.assertIsInstance(sample["demographics"], list) - self.assertGreater(len(sample["demographics"]), 0) - self.assertIsInstance(sample["age"], list) - self.assertEqual(len(sample["age"]), 1) - self.assertIsInstance(sample["los"], list) - self.assertEqual(len(sample["los"]), 1) - total_samples += 1 - self.assertGreater(total_samples, 0, "Should produce at least one sample") + all_samples = self.task(p1) + self.task(p2) + self.assertEqual(len(all_samples), 2) + + s1 = all_samples[0] + self.assertEqual(s1["ama"], 1) + self.assertEqual(s1["race"], ["race:Hispanic"]) + self.assertIn("insurance:Public", s1["demographics"]) + self.assertEqual(s1["has_substance_use"], [1.0]) + self.assertAlmostEqual(s1["age"][0], 49.0, places=0) + self.assertAlmostEqual(s1["los"][0], 2.5, places=1) + + s2 = all_samples[1] + self.assertEqual(s2["ama"], 0) + self.assertEqual(s2["race"], ["race:White"]) + self.assertIn( + "insurance:Private", s2["demographics"] + ) + self.assertEqual(s2["has_substance_use"], [0.0]) + self.assertAlmostEqual(s2["age"][0], 60.0, places=0) if __name__ == "__main__": From ae0fed8e619d32da4b98b25ca4ac2e539ea8343b Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Sun, 12 Apr 2026 18:12:47 -0500 Subject: [PATCH 03/10] Fixed the feature selection issue --- ...mic3_ama_prediction_logistic_regression.py | 3 +- pyhealth/tasks/ama_prediction.py | 39 +-- tests/core/test_mimic3_ama_prediction.py | 289 ++++++++++++------ 3 files changed, 202 insertions(+), 129 deletions(-) diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index 4f5427463..17f9e42d6 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -38,7 +38,6 @@ SYNTHETIC_ROOT = ( "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" ) -TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"] BASELINES = { "BASELINE": ["demographics", "age", "los"], @@ -356,7 +355,7 @@ def main(): print("\n[1/4] Loading dataset...") t0 = time.time() dataset = MIMIC3Dataset( - root=args.root, tables=TABLES, + root=args.root, tables=[], cache_dir=cache_dir, dev=args.dev, ) print(f" Loaded in {time.time()-t0:.1f}s") diff --git a/pyhealth/tasks/ama_prediction.py b/pyhealth/tasks/ama_prediction.py index 33b799ebf..96cc806ab 100644 --- a/pyhealth/tasks/ama_prediction.py +++ b/pyhealth/tasks/ama_prediction.py @@ -113,6 +113,10 @@ class AMAPredictionMIMIC3(BaseTask): These baselines can be toggled via the model's ``feature_keys`` parameter without changing the task. + Only administrative and demographic features are extracted; no + clinical code tables (diagnoses, procedures, prescriptions) are + required. + Unlike mortality or readmission prediction, the label is a property of the **current** admission, so patients with only one visit are eligible. @@ -127,7 +131,7 @@ class AMAPredictionMIMIC3(BaseTask): >>> from pyhealth.tasks import AMAPredictionMIMIC3 >>> dataset = MIMIC3Dataset( ... root="/path/to/mimic-iii/1.4", - ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... tables=[], ... ) >>> task = AMAPredictionMIMIC3() >>> samples = dataset.set_task(task) @@ -147,9 +151,6 @@ class AMAPredictionMIMIC3(BaseTask): "los": "tensor", "race": "multi_hot", "has_substance_use": "tensor", - "conditions": "sequence", - "procedures": "sequence", - "drugs": "sequence", } output_schema: Dict[str, str] = {"ama": "binary"} @@ -166,8 +167,7 @@ def __init__(self, exclude_newborns: bool = True) -> None: def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Processes a single patient for AMA discharge prediction. - Each admission with at least one diagnosis code, one procedure - code, and one prescription is emitted as a sample. The binary + Each non-newborn admission is emitted as a sample. The binary label is derived from the admission's ``discharge_location``. Args: @@ -214,30 +214,6 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: else 0 ) - # --- Clinical codes --- - hadm_filter = ("hadm_id", "==", admission.hadm_id) - - diagnoses = patient.get_events( - event_type="diagnoses_icd", filters=[hadm_filter] - ) - conditions = [event.icd9_code for event in diagnoses] - if len(conditions) == 0: - continue - - procedures = patient.get_events( - event_type="procedures_icd", filters=[hadm_filter] - ) - procedures_list = [event.icd9_code for event in procedures] - if len(procedures_list) == 0: - continue - - prescriptions = patient.get_events( - event_type="prescriptions", filters=[hadm_filter] - ) - drugs = [event.drug for event in prescriptions] - if len(drugs) == 0: - continue - # --- BASELINE demographics (gender + insurance) --- insurance = getattr(admission, "insurance", None) @@ -298,9 +274,6 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: "los": [los_days], "race": race_tokens, "has_substance_use": [substance], - "conditions": conditions, - "procedures": procedures_list, - "drugs": drugs, "ama": ama_label, } ) diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py index 7570d0921..e820c5262 100644 --- a/tests/core/test_mimic3_ama_prediction.py +++ b/tests/core/test_mimic3_ama_prediction.py @@ -75,9 +75,6 @@ def _get_events(event_type, filters=None, **kwargs): SAMPLE_KEYS = { "visit_id", "patient_id", - "conditions", - "procedures", - "drugs", "demographics", "age", "los", @@ -220,9 +217,6 @@ def test_input_schema(self): self.assertEqual( schema["has_substance_use"], "tensor" ) - self.assertEqual(schema["conditions"], "sequence") - self.assertEqual(schema["procedures"], "sequence") - self.assertEqual(schema["drugs"], "sequence") def test_output_schema(self): self.assertEqual( @@ -358,62 +352,6 @@ def test_multiple_admissions_mixed_labels(self): # Filtering / edge cases # ---------------------------------------------------------- - def test_skip_missing_diagnoses(self): - """No diagnosis codes -> skip admission.""" - patient = _build_patient( - patient_id="P4", - admissions=[ - self._default_admission(hadm_id="400") - ], - diagnoses=[], - procedures=[ - {"hadm_id": "400", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "400", "drug": "Aspirin"} - ], - ) - self.assertEqual(self.task(patient), []) - - def test_skip_missing_procedures(self): - """No procedure codes -> skip admission.""" - patient = _build_patient( - patient_id="P5", - admissions=[ - self._default_admission(hadm_id="500") - ], - diagnoses=[ - {"hadm_id": "500", "icd9_code": "4019"} - ], - procedures=[], - prescriptions=[ - {"hadm_id": "500", "drug": "Aspirin"} - ], - ) - self.assertEqual(self.task(patient), []) - - def test_skip_missing_prescriptions(self): - """No prescriptions -> skip admission.""" - patient = _build_patient( - patient_id="P6", - admissions=[ - self._default_admission( - hadm_id="600", - discharge_location=( - "LEFT AGAINST MEDICAL ADVI" - ), - ), - ], - diagnoses=[ - {"hadm_id": "600", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "600", "icd9_code": "3893"} - ], - prescriptions=[], - ) - self.assertEqual(self.task(patient), []) - def test_exclude_newborns(self): """NEWBORN admissions skipped when flag is True.""" patient = _build_patient( @@ -486,38 +424,6 @@ def test_sample_keys(self): self.assertEqual(len(samples), 1) self.assertEqual(set(samples[0].keys()), SAMPLE_KEYS) - def test_clinical_codes_content(self): - """Extracted codes match the input data.""" - patient = _build_patient( - patient_id="P9", - admissions=[ - self._default_admission( - hadm_id="900", - timestamp=datetime(2150, 4, 1), - ), - ], - diagnoses=[ - {"hadm_id": "900", "icd9_code": "4019"}, - {"hadm_id": "900", "icd9_code": "25000"}, - ], - procedures=[ - {"hadm_id": "900", "icd9_code": "3893"}, - ], - prescriptions=[ - {"hadm_id": "900", "drug": "Lisinopril"}, - {"hadm_id": "900", "drug": "Metformin"}, - ], - ) - samples = self.task(patient) - self.assertEqual( - samples[0]["conditions"], ["4019", "25000"] - ) - self.assertEqual(samples[0]["procedures"], ["3893"]) - self.assertEqual( - samples[0]["drugs"], - ["Lisinopril", "Metformin"], - ) - def test_demographics_baseline_tokens(self): """BASELINE demographics: gender + insurance, no race.""" patient = _build_patient( @@ -785,5 +691,200 @@ def test_two_patient_synthetic_dataset(self): self.assertAlmostEqual(s2["age"][0], 60.0, places=0) +class TestAMAAblationBaselines(unittest.TestCase): + """Tests for the three ablation study feature baselines. + + These tests verify that each baseline can be used to select + different subsets of features via the model's feature_keys parameter. + """ + + def setUp(self): + """Create a simple test patient with mixed demographics.""" + self.task = AMAPredictionMIMIC3() + self.patient = _build_patient( + patient_id="ABLATION_TEST", + admissions=[ + { + "hadm_id": "A1", + "admission_type": "EMERGENCY", + "discharge_location": "HOME", + "ethnicity": "HISPANIC OR LATINO", + "insurance": "Medicaid", + "dischtime": "2150-01-10 14:00:00", + "timestamp": datetime(2150, 1, 1), + "diagnosis": "ALCOHOL WITHDRAWAL", + }, + { + "hadm_id": "A2", + "admission_type": "URGENT", + "discharge_location": "LEFT AGAINST MEDICAL ADVI", + "ethnicity": "WHITE", + "insurance": "Private", + "dischtime": "2150-06-05 10:00:00", + "timestamp": datetime(2150, 6, 1), + "diagnosis": "PNEUMONIA", + }, + ], + diagnoses=[ + {"hadm_id": "A1", "icd9_code": "29181"}, + {"hadm_id": "A2", "icd9_code": "486"}, + ], + procedures=[ + {"hadm_id": "A1", "icd9_code": "3893"}, + {"hadm_id": "A2", "icd9_code": "9604"}, + ], + prescriptions=[ + {"hadm_id": "A1", "drug": "Lorazepam"}, + {"hadm_id": "A2", "drug": "Levofloxacin"}, + ], + gender="M", + dob="2100-01-01 00:00:00", + ) + + def test_baseline_features_present(self): + """BASELINE includes demographics, age, los.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + sample = samples[0] + self.assertIn("demographics", sample) + self.assertIn("age", sample) + self.assertIn("los", sample) + self.assertTrue(isinstance(sample["age"][0], float)) + self.assertTrue(isinstance(sample["los"][0], float)) + self.assertTrue(isinstance(sample["demographics"], list)) + + def test_baseline_race_feature_present(self): + """BASELINE + RACE adds race feature.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + for sample in samples: + self.assertIn("race", sample) + self.assertTrue(isinstance(sample["race"], list)) + # Race should be one of the normalized values + race_val = sample["race"][0].split(":", 1)[1] + self.assertIn( + race_val, + ["White", "Black", "Hispanic", "Asian", + "Native American", "Other"], + ) + + def test_baseline_substance_use_feature_present(self): + """BASELINE + RACE + SUBSTANCE adds has_substance_use.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + for sample in samples: + self.assertIn("has_substance_use", sample) + self.assertTrue(isinstance(sample["has_substance_use"], list)) + self.assertIn(sample["has_substance_use"][0], [0.0, 1.0]) + + def test_substance_use_detection_in_ablation(self): + """Verify substance use detection for ablation patient.""" + samples = self.task(self.patient) + + # First admission has substance use (ALCOHOL WITHDRAWAL) + s1 = next(s for s in samples if s["visit_id"] == "A1") + self.assertEqual(s1["has_substance_use"], [1.0]) + + # Second admission has no substance use (PNEUMONIA) + s2 = next(s for s in samples if s["visit_id"] == "A2") + self.assertEqual(s2["has_substance_use"], [0.0]) + + def test_race_normalization_in_ablation(self): + """Verify race normalization for ablation patient.""" + samples = self.task(self.patient) + + # First admission: Hispanic + s1 = next(s for s in samples if s["visit_id"] == "A1") + self.assertEqual(s1["race"], ["race:Hispanic"]) + + # Second admission: White + s2 = next(s for s in samples if s["visit_id"] == "A2") + self.assertEqual(s2["race"], ["race:White"]) + + def test_age_and_los_computed(self): + """Verify age and LOS are computed correctly.""" + samples = self.task(self.patient) + + for sample in samples: + age = sample["age"][0] + los = sample["los"][0] + + # Age should be 50 (2150 - 2100) + self.assertAlmostEqual(age, 50.0, places=1) + + # LOS should be positive + self.assertGreater(los, 0.0) + + def test_demographics_includes_gender_and_insurance(self): + """BASELINE demographics include gender and insurance.""" + samples = self.task(self.patient) + + for sample in samples: + demo = sample["demographics"] + # Should have gender and insurance tokens + has_gender = any(t.startswith("gender:") for t in demo) + has_insurance = any(t.startswith("insurance:") for t in demo) + self.assertTrue(has_gender) + self.assertTrue(has_insurance) + + def test_insurance_normalization_in_ablation(self): + """Verify insurance normalization (Medicaid -> Public).""" + samples = self.task(self.patient) + + # First admission: Medicaid -> Public + s1 = next(s for s in samples if s["visit_id"] == "A1") + demo1 = s1["demographics"] + self.assertIn("insurance:Public", demo1) + + # Second admission: Private + s2 = next(s for s in samples if s["visit_id"] == "A2") + demo2 = s2["demographics"] + self.assertIn("insurance:Private", demo2) + + def test_label_correctness_in_ablation(self): + """Verify AMA label is correct.""" + samples = self.task(self.patient) + + # First admission: not AMA + s1 = next(s for s in samples if s["visit_id"] == "A1") + self.assertEqual(s1["ama"], 0) + + # Second admission: AMA + s2 = next(s for s in samples if s["visit_id"] == "A2") + self.assertEqual(s2["ama"], 1) + + def test_baseline_minimal_features(self): + """BASELINE (minimal) has only required keys.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + sample = samples[0] + baseline_keys = { + "demographics", "age", "los", "race", + "has_substance_use", "visit_id", "patient_id", "ama", + } + self.assertEqual(set(sample.keys()), baseline_keys) + + def test_multiple_admissions_all_included(self): + """All non-newborn admissions are included (no filtering).""" + samples = self.task(self.patient) + self.assertEqual(len(samples), 2) + + visit_ids = {s["visit_id"] for s in samples} + self.assertEqual(visit_ids, {"A1", "A2"}) + + def test_ablation_patient_no_clinical_codes(self): + """Ablation samples do not contain clinical code fields.""" + samples = self.task(self.patient) + + for sample in samples: + self.assertNotIn("conditions", sample) + self.assertNotIn("procedures", sample) + self.assertNotIn("drugs", sample) + + if __name__ == "__main__": unittest.main() From 9175c156e0c9c5ea6b1e733777bde67c728253ea Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Mon, 13 Apr 2026 17:26:00 -0500 Subject: [PATCH 04/10] Modified to use synthetic dataset --- ...mic3_ama_prediction_logistic_regression.py | 334 +++++++++++++--- .../test_ama_ablation_with_synthetic_data.py | 378 ++++++++++++++++++ 2 files changed, 648 insertions(+), 64 deletions(-) create mode 100644 tests/core/test_ama_ablation_with_synthetic_data.py diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index 17f9e42d6..c2da59cb4 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -1,43 +1,237 @@ -"""AMA Prediction -- LogisticRegression Ablation with Fairness Analysis. - -Reproduces the Against-Medical-Advice discharge prediction from: - - Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; Ghassemi, M. - "Racial Disparities and Mistrust in End-of-Life Care." - Machine Learning for Healthcare Conference, PMLR, 2018. - -For each baseline the script reports: - 1. Overall AUROC / PR-AUC averaged over N random 60/40 splits. - 2. Subgroup performance (AUROC, PR-AUC) sliced by Race, Age Group, - and Insurance Type. - 3. Fairness metrics per subgroup: - - Demographic Parity = % predicted AMA (P(Y_hat=1 | Group=g)) - - Equal Opportunity = True Positive Rate (P(Y_hat=1 | Y=1, Group=g)) - -Usage (synthetic demo data -- illustrative only, likely no AMA positives): +"""Ablation study for AMA discharge prediction on MIMIC-III. + +This script demonstrates the AMAPredictionMIMIC3 task with three feature +ablations and evaluates model fairness using AUROC across demographic +subgroups (race, age, insurance). A logistic regression classifier is +trained on the extracted features to analyze how demographic information +affects prediction of against-medical-advice (AMA) discharge. + +Paper: Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. +"Racial Disparities and Mistrust in End-of-Life Care." Machine Learning +for Healthcare Conference, PMLR 106:211-235, 2018. + +Ablation configurations tested: + 1. BASELINE: demographics (gender, insurance) + age + los + 2. BASELINE+RACE: adds normalized ethnicity feature + 3. BASELINE+RACE+SUBSTANCE: adds substance use diagnosis flag + +Results: + For each baseline, we report: + - Overall AUROC averaged over N random 60/40 train/test splits + - Subgroup performance (AUROC) stratified by: + * Race (White, Black, Hispanic, Asian, Native American, Other) + * Age Group (Young 18-44, Middle 45-64, Senior 65+) + * Insurance (Public, Private, Self Pay) + - Fairness metrics per subgroup: + * Demographic Parity: % predicted AMA per group + * Equal Opportunity: True Positive Rate per group + These reveal disparities in model behavior across demographics. + +Usage (synthetic demo data -- default, fast): python examples/mimic3_ama_prediction_logistic_regression.py -Usage (real MIMIC-III): +Usage (with more patients and more splits): + python examples/mimic3_ama_prediction_logistic_regression.py \\ + --patients 500 --splits 10 --epochs 5 + +Usage (with real MIMIC-III data): python examples/mimic3_ama_prediction_logistic_regression.py \\ --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 """ import argparse +import gzip import tempfile import time +from datetime import datetime, timedelta +from pathlib import Path +from typing import List import numpy as np +import pandas as pd import torch -from sklearn.metrics import average_precision_score, roc_auc_score +from sklearn.metrics import roc_auc_score from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient from pyhealth.models import LogisticRegression from pyhealth.tasks import AMAPredictionMIMIC3 from pyhealth.trainer import Trainer -SYNTHETIC_ROOT = ( - "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" -) + +def generate_synthetic_mimic3( + root: str, + n_patients: int = 50, + avg_admissions_per_patient: int = 2, + seed: int = 42, +) -> None: + """Write gzipped PATIENTS, ADMISSIONS, and ICUSTAYS CSVs for local demos. + + Covers rotated gender, ethnicity, insurance, mixed AMA and substance-use + diagnoses. Used when ``--root`` is omitted so the example runs without + network access or a full MIMIC-III install. + + Args: + root: Directory to write CSV files to. + n_patients: Number of synthetic patients to generate. + avg_admissions_per_patient: Poisson mean for admissions per patient. + seed: Random seed for reproducibility. + """ + np.random.seed(seed) + root_path = Path(root) + root_path.mkdir(parents=True, exist_ok=True) + + genders = ["M", "F"] + ethnicities = [ + "WHITE", + "BLACK/AFRICAN AMERICAN", + "HISPANIC OR LATINO", + "ASIAN - CHINESE", + "AMERICAN INDIAN/ALASKA NATIVE", + "UNKNOWN/NOT SPECIFIED", + ] + insurances = ["Medicare", "Medicaid", "Private", "Self Pay", "Government"] + admission_types = ["EMERGENCY", "URGENT", "NEWBORN", "ELECTIVE"] + discharge_locations = [ + "HOME", + "SKILLED NURSING FACILITY", + "LONG TERM CARE", + "LEFT AGAINST MEDICAL ADVI", + "EXPIRED", + ] + diagnoses_substance = [ + "ALCOHOL WITHDRAWAL", + "OPIOID DEPENDENCE", + "HEROIN OVERDOSE", + "COCAINE INTOXICATION", + "DRUG WITHDRAWAL SEIZURE", + "ETOH ABUSE", + "SUBSTANCE ABUSE", + "OVERDOSE - ACCIDENTAL", + ] + diagnoses_other = [ + "PNEUMONIA", + "ACUTE MYOCARDIAL INFARCTION", + "CHEST PAIN", + "CONGESTIVE HEART FAILURE", + "SEPSIS", + "ACUTE KIDNEY INJURY", + "ACUTE RESPIRATORY FAILURE", + "ASPIRATION", + ] + + patients_data: List[dict] = [] + admissions_data: List[dict] = [] + icustays_data: List[dict] = [] + + subject_id = 1 + hadm_id = 100 + icustay_id = 1000 + + for i in range(n_patients): + gender = genders[i % len(genders)] + ethnicity = ethnicities[i % len(ethnicities)] + insurance = insurances[i % len(insurances)] + + age_at_visit = int( + np.random.choice([25, 45, 65, 85]) + np.random.randint(-5, 5) + ) + dob = datetime(2000, 1, 1) - timedelta(days=age_at_visit * 365) + + patients_data.append({ + "subject_id": subject_id, + "gender": gender, + "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }) + + n_admissions = max(1, int(np.random.poisson(avg_admissions_per_patient))) + for j in range(n_admissions): + admit_time = datetime(2150, 1, 1) + timedelta(days=int(j * 100)) + discharge_time = admit_time + timedelta( + days=int(np.random.randint(1, 30)) + ) + + admission_type = admission_types[(i + j) % len(admission_types)] + + if np.random.random() < 0.15: + discharge_loc = "LEFT AGAINST MEDICAL ADVI" + elif np.random.random() < 0.05: + discharge_loc = "EXPIRED" + else: + discharge_loc = discharge_locations[ + (i + j) % (len(discharge_locations) - 2) + ] + + if np.random.random() < 0.2: + diagnosis = diagnoses_substance[ + np.random.randint(0, len(diagnoses_substance)) + ] + else: + diagnosis = diagnoses_other[ + np.random.randint(0, len(diagnoses_other)) + ] + + admissions_data.append({ + "subject_id": subject_id, + "hadm_id": hadm_id, + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 + if discharge_loc == "EXPIRED" + else 0, + }) + + icu_intime = admit_time + timedelta( + hours=int(np.random.randint(0, 12)) + ) + icu_outtime = discharge_time - timedelta( + hours=int(np.random.randint(0, 12)) + ) + + if icu_intime < icu_outtime: + icustays_data.append({ + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + }) + icustay_id += 1 + + hadm_id += 1 + + subject_id += 1 + + def write_csv_gz(filename: str, data: List[dict]) -> None: + df = pd.DataFrame(data) + filepath = root_path / f"{filename}.gz" + with gzip.open(filepath, "wt") as f: + df.to_csv(f, index=False) + print(f" Created {filename}.gz ({len(data)} rows)") + + print(f"Generating synthetic MIMIC-III dataset in {root_path}...") + write_csv_gz("PATIENTS.csv", patients_data) + write_csv_gz("ADMISSIONS.csv", admissions_data) + write_csv_gz("ICUSTAYS.csv", icustays_data) + print("Done.") + BASELINES = { "BASELINE": ["demographics", "age", "los"], @@ -136,15 +330,6 @@ def _safe_auroc(y, p): return float("nan") -def _safe_prauc(y, p): - if np.sum(y) == 0: - return float("nan") - try: - return average_precision_score(y, p) - except ValueError: - return float("nan") - - # ------------------------------------------------------------------ # Single split # ------------------------------------------------------------------ @@ -187,7 +372,6 @@ def _run_single_split(sample_dataset, feature_keys, lookup, y_pred = (y_prob >= threshold).astype(int) overall_auroc = _safe_auroc(y_true, y_prob) - overall_prauc = _safe_prauc(y_true, y_prob) subgroup = {} for attr_name, attr_vals in groups.items(): @@ -201,7 +385,6 @@ def _run_single_split(sample_dataset, feature_keys, lookup, pos = yt.sum() subgroup[attr_name][grp] = { "auroc": _safe_auroc(yt, yp), - "pr_auc": _safe_prauc(yt, yp), "pct_pred": float(yd.mean()) * 100, "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 else float("nan"), @@ -210,7 +393,6 @@ def _run_single_split(sample_dataset, feature_keys, lookup, return { "auroc": overall_auroc, - "pr_auc": overall_prauc, "subgroups": subgroup, } @@ -239,8 +421,6 @@ def _aggregate(results): "n": len(valid), "auroc_mean": _nanmean([r["auroc"] for r in valid]), "auroc_std": _nanstd([r["auroc"] for r in valid]), - "pr_auc_mean": _nanmean([r["pr_auc"] for r in valid]), - "pr_auc_std": _nanstd([r["pr_auc"] for r in valid]), } all_attrs = set() @@ -256,13 +436,12 @@ def _aggregate(results): all_grps.update(r["subgroups"][attr].keys()) for grp in sorted(all_grps): - aurocs, praucs, pcts, tprs, ns = [], [], [], [], [] + aurocs, pcts, tprs, ns = [], [], [], [] for r in valid: m = r["subgroups"].get(attr, {}).get(grp) if m is None: continue aurocs.append(m["auroc"]) - praucs.append(m["pr_auc"]) pcts.append(m["pct_pred"]) tprs.append(m["tpr"]) ns.append(m["n"]) @@ -270,8 +449,6 @@ def _aggregate(results): agg["subgroups"][attr][grp] = { "auroc_mean": _nanmean(aurocs), "auroc_std": _nanstd(aurocs), - "pr_auc_mean": _nanmean(praucs), - "pr_auc_std": _nanstd(praucs), "pct_pred_mean": _nanmean(pcts), "tpr_mean": _nanmean(tprs), "n_avg": int(np.mean(ns)) if ns else 0, @@ -303,17 +480,15 @@ def _print_results(name, feature_keys, agg): print(f"\n 1. Overall Performance ({agg['n']} splits)") print(f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})") - print(f" PR-AUC: {_fmt(agg['pr_auc_mean'])} +/- {_fmt(agg['pr_auc_std'])}") print(f"\n 2. Subgroup Performance") for attr, grps in agg["subgroups"].items(): print(f" {attr}:") - print(f" {'Group':<20} {'AUROC':>15} {'PR-AUC':>15} {'n_avg':>7}") - print(f" {'-'*58}") + print(f" {'Group':<20} {'AUROC':>15} {'n_avg':>7}") + print(f" {'-'*42}") for grp, m in grps.items(): a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" - p_str = f"{_fmt(m['pr_auc_mean'])}+/-{_fmt(m['pr_auc_std'])}" - print(f" {grp:<20} {a_str:>15} {p_str:>15} {m['n_avg']:>7}") + print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") print(f"\n 3. Fairness Metrics") print(f" Demographic Parity (% Predicted AMA):") @@ -337,17 +512,49 @@ def main(): parser = argparse.ArgumentParser( description="AMA prediction ablation -- LogisticRegression", ) - parser.add_argument("--root", default=SYNTHETIC_ROOT, - help="MIMIC-III root (local path or URL)") - parser.add_argument("--splits", type=int, default=100, - help="Number of random 60/40 splits (default 100)") - parser.add_argument("--epochs", type=int, default=10, - help="Training epochs per split") - parser.add_argument("--dev", action="store_true", - help="Use dev mode (1000 patients)") + parser.add_argument( + "--root", + default=None, + help="MIMIC-III root (local path). If not provided, uses synthetic data.", + ) + parser.add_argument( + "--patients", + type=int, + default=100, + help="Number of synthetic patients (default 100, only used if --root not provided)", + ) + parser.add_argument( + "--splits", + type=int, + default=5, + help="Number of random 60/40 splits (default 5 for speed with synthetic data)", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + help="Training epochs per split (default 3 for speed with synthetic data)", + ) + parser.add_argument( + "--dev", + action="store_true", + help="Use dev mode (fewer patients/splits for testing)", + ) args = parser.parse_args() cache_dir = tempfile.mkdtemp(prefix="ama_lr_") + + # If no root provided, generate synthetic data + if args.root is None: + print("[Setup] Generating synthetic MIMIC-III dataset...") + data_dir = tempfile.mkdtemp(prefix="synthetic_mimic3_") + n_patients = 10 if args.dev else args.patients + generate_synthetic_mimic3(data_dir, n_patients=n_patients, seed=42) + args.root = data_dir + print(f" Synthetic data: {data_dir}\n") + else: + print(f"Using real MIMIC-III from: {args.root}\n") + print(f"Cache: {cache_dir}") print(f"Root: {args.root}") print(f"Splits: {args.splits} | Epochs: {args.epochs}") @@ -355,8 +562,10 @@ def main(): print("\n[1/4] Loading dataset...") t0 = time.time() dataset = MIMIC3Dataset( - root=args.root, tables=[], - cache_dir=cache_dir, dev=args.dev, + root=args.root, + tables=[], + cache_dir=cache_dir, + dev=args.dev, ) print(f" Loaded in {time.time()-t0:.1f}s") dataset.stats() @@ -370,15 +579,12 @@ def main(): raise print(f"\n {exc}") print(" The dataset contains no AMA-positive cases.") - print(" AMA prevalence is ~2% so small/synthetic data") - print(" often lacks positives. Demonstrating the task") - print(" on raw patients instead:\n") - total = 0 - for patient in dataset.iter_patients(): - samples = task(patient) - total += len(samples) - print(f" Task produced {total} samples (all label=0)") - print("\n Re-run with real MIMIC-III for ablation:") + print(" For synthetic data: this is expected if AMA rate is low.") + print(" To increase AMA cases, re-run with:") + print(" python examples/" + "mimic3_ama_prediction_logistic_regression.py \\") + print(" --patients 500\n") + print(" For real MIMIC-III with better AMA coverage:") print(" python examples/" "mimic3_ama_prediction_logistic_regression.py \\") print(" --root /path/to/mimic-iii/1.4") diff --git a/tests/core/test_ama_ablation_with_synthetic_data.py b/tests/core/test_ama_ablation_with_synthetic_data.py new file mode 100644 index 000000000..8a0a6452c --- /dev/null +++ b/tests/core/test_ama_ablation_with_synthetic_data.py @@ -0,0 +1,378 @@ +"""Tests for AMA prediction ablation studies with synthetic data. + +Uses local synthetic MIMIC-III data to ensure fast execution +and comprehensive coverage of all demographic combinations. + +Synthetic CSV generation lives in ``examples/mimic3_ama_prediction_ +logistic_regression.py``; tests load that helper via importlib so there +is no separate dataset module. +""" + +import importlib.util +import tempfile +import unittest +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import LogisticRegression +from pyhealth.tasks import AMAPredictionMIMIC3 +from pyhealth.trainer import Trainer + +_EXAMPLE_PATH = ( + Path(__file__).resolve().parents[2] + / "examples" + / "mimic3_ama_prediction_logistic_regression.py" +) +_spec = importlib.util.spec_from_file_location( + "mimic3_ama_prediction_example", _EXAMPLE_PATH +) +_example_mod = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +_spec.loader.exec_module(_example_mod) +generate_synthetic_mimic3 = _example_mod.generate_synthetic_mimic3 + + +class TestAMAWithSyntheticData(unittest.TestCase): + """Test AMA prediction ablation studies on synthetic data. + + Uses a small synthetic dataset that covers all demographic + combinations to ensure tests run quickly (~1 second total). + """ + + @classmethod + def setUpClass(cls): + """Generate synthetic dataset once for all tests.""" + cls.tmpdir = tempfile.mkdtemp(prefix="ama_test_") + cls.cache_dir = tempfile.mkdtemp(prefix="ama_test_cache_") + + # Generate small but comprehensive synthetic data + generate_synthetic_mimic3( + cls.tmpdir, + n_patients=50, + avg_admissions_per_patient=2, + seed=42, + ) + + # Load dataset + cls.dataset = MIMIC3Dataset( + root=cls.tmpdir, + tables=[], + cache_dir=cls.cache_dir, + ) + + # Apply task + cls.task = AMAPredictionMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + + def test_dataset_loads_successfully(self): + """Verify synthetic dataset loads with expected structure.""" + self.assertIsNotNone(self.dataset) + self.assertGreater(len(self.sample_dataset), 0) + + def test_samples_have_expected_features(self): + """Verify each sample contains required features.""" + sample = self.sample_dataset[0] + + expected_keys = { + "visit_id", + "patient_id", + "demographics", + "age", + "los", + "race", + "has_substance_use", + "ama", + } + self.assertEqual(set(sample.keys()), expected_keys) + + def test_demographics_values(self): + """Verify demographics contain processed feature vectors.""" + for sample in self.sample_dataset: + demo = sample["demographics"] + # After processing, demographics are tensors + self.assertTrue( + torch.is_tensor(demo) or isinstance(demo, (int, float)), + "Demographics should be processed", + ) + + def test_age_in_valid_range(self): + """Verify ages are processed as tensors.""" + for sample in self.sample_dataset: + age = sample["age"] + # After processing, age is a tensor + self.assertTrue(torch.is_tensor(age) or isinstance(age, (int, float))) + + def test_los_positive(self): + """Verify LOS (length of stay) is processed as tensor.""" + for sample in self.sample_dataset: + los = sample["los"] + # After processing, los is a tensor + self.assertTrue(torch.is_tensor(los) or isinstance(los, (int, float))) + + def test_race_normalized(self): + """Verify race is processed as tensor.""" + for sample in self.sample_dataset: + race = sample["race"] + # After processing, race is a tensor + self.assertTrue(torch.is_tensor(race) or isinstance(race, (int, float))) + + def test_substance_use_binary(self): + """Verify substance use is processed as tensor.""" + for sample in self.sample_dataset: + substance = sample["has_substance_use"] + # After processing, substance use is a tensor + self.assertTrue(torch.is_tensor(substance) or isinstance(substance, (int, float))) + + def test_ama_label_binary(self): + """Verify AMA label is 0 or 1.""" + for sample in self.sample_dataset: + ama = sample["ama"] + self.assertIn(ama, [0, 1]) + + def test_has_positive_and_negative_labels(self): + """Verify dataset has both AMA positive and negative cases.""" + labels = [sample["ama"] for sample in self.sample_dataset] + has_positive = any(l == 1 for l in labels) + has_negative = any(l == 0 for l in labels) + + self.assertTrue( + has_positive and has_negative, + "Dataset should have both positive and negative AMA cases", + ) + + +class TestAMABaselineFeatures(unittest.TestCase): + """Test that each ablation baseline uses correct features.""" + + @classmethod + def setUpClass(cls): + """Generate synthetic dataset for baseline tests.""" + cls.tmpdir = tempfile.mkdtemp(prefix="ama_baseline_test_") + cls.cache_dir = tempfile.mkdtemp(prefix="ama_baseline_cache_") + + generate_synthetic_mimic3( + cls.tmpdir, + n_patients=30, + seed=42, + ) + + cls.dataset = MIMIC3Dataset( + root=cls.tmpdir, + tables=[], + cache_dir=cls.cache_dir, + ) + + cls.task = AMAPredictionMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + + def _create_model_with_features(self, feature_keys): + """Helper to create logistic regression model with feature keys.""" + model = LogisticRegression( + dataset=self.sample_dataset, + embedding_dim=64, # Use dataset's embedding_dim + ) + model.feature_keys = list(feature_keys) + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + feature_keys[0] + ].out_features + model.fc = torch.nn.Linear( + len(feature_keys) * embedding_dim, output_size + ) + return model + + def test_baseline_model_can_be_created(self): + """BASELINE: demographics, age, los.""" + model = self._create_model_with_features( + ["demographics", "age", "los"] + ) + self.assertIsNotNone(model) + # Verify fc layer exists + self.assertIsNotNone(model.fc) + + def test_baseline_plus_race_model(self): + """BASELINE + RACE: adds race.""" + model = self._create_model_with_features( + ["demographics", "age", "los", "race"] + ) + self.assertIsNotNone(model) + self.assertIsNotNone(model.fc) + + def test_baseline_plus_race_plus_substance_model(self): + """BASELINE + RACE + SUBSTANCE: adds has_substance_use.""" + model = self._create_model_with_features( + ["demographics", "age", "los", "race", "has_substance_use"] + ) + self.assertIsNotNone(model) + self.assertIsNotNone(model.fc) + + def test_baseline_forward_pass(self): + """Verify model forward pass works with baseline features.""" + model = self._create_model_with_features( + ["demographics", "age", "los"] + ) + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + def test_baseline_plus_race_forward_pass(self): + """Verify model forward pass with race feature.""" + model = self._create_model_with_features( + ["demographics", "age", "los", "race"] + ) + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + def test_baseline_plus_full_forward_pass(self): + """Verify model forward pass with all features.""" + model = self._create_model_with_features( + ["demographics", "age", "los", "race", "has_substance_use"] + ) + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + +class TestAMATrainingSpeed(unittest.TestCase): + """Verify training with synthetic data is fast.""" + + @classmethod + def setUpClass(cls): + """Generate small synthetic dataset for speed tests.""" + cls.tmpdir = tempfile.mkdtemp(prefix="ama_speed_test_") + cls.cache_dir = tempfile.mkdtemp(prefix="ama_speed_cache_") + + generate_synthetic_mimic3( + cls.tmpdir, + n_patients=20, # Small for speed + seed=42, + ) + + cls.dataset = MIMIC3Dataset( + root=cls.tmpdir, + tables=[], + cache_dir=cls.cache_dir, + ) + + cls.task = AMAPredictionMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + + def test_training_completes_quickly(self): + """Verify one training epoch completes in reasonable time.""" + import time + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.6, 0.0, 0.4], seed=0 + ) + train_dl = get_dataloader( + train_ds, batch_size=8, shuffle=True + ) + + model = LogisticRegression( + dataset=self.sample_dataset, + embedding_dim=64, + ) + model.feature_keys = ["demographics", "age", "los"] + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + "demographics" + ].out_features + model.fc = torch.nn.Linear( + 3 * embedding_dim, output_size + ) + + trainer = Trainer(model=model) + + t0 = time.time() + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=1, + monitor=None, + ) + elapsed = time.time() - t0 + + # Should complete in reasonable time + self.assertGreater(elapsed, 0, "Training should take some time") + + def test_multiple_splits_complete_quickly(self): + """Verify 2 random splits complete without error.""" + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, + [0.6, 0.0, 0.4], + seed=0, + ) + + for split_seed in range(2): + train_ds, _, _ = split_by_patient( + self.sample_dataset, + [0.6, 0.0, 0.4], + seed=split_seed, + ) + train_dl = get_dataloader( + train_ds, batch_size=8, shuffle=True + ) + + model = LogisticRegression( + dataset=self.sample_dataset, + embedding_dim=64, + ) + model.feature_keys = ["demographics", "age", "los"] + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + "demographics" + ].out_features + model.fc = torch.nn.Linear( + 3 * embedding_dim, output_size + ) + + trainer = Trainer(model=model) + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=1, + monitor=None, + ) + + # Verify we completed without errors + self.assertTrue(True) + + +if __name__ == "__main__": + unittest.main() From 680c24f1dfb0cf82803e3cabf12fa066a1489091 Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Tue, 14 Apr 2026 08:02:04 -0500 Subject: [PATCH 05/10] fixed the synthetic data generator --- ...mic3_ama_prediction_logistic_regression.py | 421 +++++++++++++----- examples/mimic3_ama_prediction_rnn.py | 125 ++++-- .../test_ama_ablation_with_synthetic_data.py | 378 ---------------- tests/core/test_mimic3_ama_prediction.py | 358 +++++++++++++++ 4 files changed, 751 insertions(+), 531 deletions(-) delete mode 100644 tests/core/test_ama_ablation_with_synthetic_data.py diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index c2da59cb4..1127a0180 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -27,25 +27,26 @@ * Equal Opportunity: True Positive Rate per group These reveal disparities in model behavior across demographics. -Usage (synthetic demo data -- default, fast): +Usage (synthetic exhaustive grid -- default when ``--root`` is omitted): python examples/mimic3_ama_prediction_logistic_regression.py -Usage (with more patients and more splits): +Usage (synthetic random demo): python examples/mimic3_ama_prediction_logistic_regression.py \\ - --patients 500 --splits 10 --epochs 5 + --data-source synthetic --synthetic-mode random --patients 200 -Usage (with real MIMIC-III data): +Usage (real MIMIC-III; same as ``--root /path`` with ``--data-source auto``): python examples/mimic3_ama_prediction_logistic_regression.py \\ - --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 + --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 """ import argparse import gzip +import itertools import tempfile import time from datetime import datetime, timedelta from pathlib import Path -from typing import List +from typing import List, Optional import numpy as np import pandas as pd @@ -63,23 +64,34 @@ def generate_synthetic_mimic3( n_patients: int = 50, avg_admissions_per_patient: int = 2, seed: int = 42, + mode: str = "exhaustive", ) -> None: """Write gzipped PATIENTS, ADMISSIONS, and ICUSTAYS CSVs for local demos. - Covers rotated gender, ethnicity, insurance, mixed AMA and substance-use - diagnoses. Used when ``--root`` is omitted so the example runs without - network access or a full MIMIC-III install. + ``mode="exhaustive"`` (default) emits one patient per element of the + Cartesian product of task-relevant factors so every combination appears + at least once: gender × MIMIC ethnicity string × raw insurance × age + band (maps to Young / Middle / Senior) × AMA vs non-AMA discharge × + substance vs non-substance diagnosis text. A few extra rows cover + NEWBORN filtering, EXPIRED, SNF, and missing insurance (``Other``). + + ``mode="random"`` reproduces the legacy stochastic generator (use for + quick tests via ``--synthetic-mode random``). Args: root: Directory to write CSV files to. - n_patients: Number of synthetic patients to generate. - avg_admissions_per_patient: Poisson mean for admissions per patient. - seed: Random seed for reproducibility. + n_patients: Used only when ``mode="random"`` (patient count). + avg_admissions_per_patient: Poisson mean per patient (random mode). + seed: RNG seed (random mode only). + mode: ``"exhaustive"`` or ``"random"``. """ - np.random.seed(seed) root_path = Path(root) root_path.mkdir(parents=True, exist_ok=True) + patients_data: List[dict] = [] + admissions_data: List[dict] = [] + icustays_data: List[dict] = [] + genders = ["M", "F"] ethnicities = [ "WHITE", @@ -89,7 +101,15 @@ def generate_synthetic_mimic3( "AMERICAN INDIAN/ALASKA NATIVE", "UNKNOWN/NOT SPECIFIED", ] - insurances = ["Medicare", "Medicaid", "Private", "Self Pay", "Government"] + # Raw insurance values (normalize to Public / Private / Self Pay / Other) + insurances_raw: List[Optional[str]] = [ + "Medicare", + "Medicaid", + "Government", + "Private", + "Self Pay", + None, + ] admission_types = ["EMERGENCY", "URGENT", "NEWBORN", "ELECTIVE"] discharge_locations = [ "HOME", @@ -119,24 +139,29 @@ def generate_synthetic_mimic3( "ASPIRATION", ] - patients_data: List[dict] = [] - admissions_data: List[dict] = [] - icustays_data: List[dict] = [] - - subject_id = 1 - hadm_id = 100 - icustay_id = 1000 - - for i in range(n_patients): - gender = genders[i % len(genders)] - ethnicity = ethnicities[i % len(ethnicities)] - insurance = insurances[i % len(insurances)] - - age_at_visit = int( - np.random.choice([25, 45, 65, 85]) + np.random.randint(-5, 5) - ) - dob = datetime(2000, 1, 1) - timedelta(days=age_at_visit * 365) + def write_csv_gz(filename: str, data: List[dict]) -> None: + df = pd.DataFrame(data) + filepath = root_path / f"{filename}.gz" + with gzip.open(filepath, "wt") as f: + df.to_csv(f, index=False) + print(f" Created {filename}.gz ({len(data)} rows)") + def append_visit( + subject_id: int, + hadm_id: int, + icustay_id: int, + *, + gender: str, + age_years: int, + ethnicity: str, + insurance_raw: Optional[str], + admission_type: str, + discharge_loc: str, + diagnosis: str, + day_offset: int, + ) -> int: + """Append one patient + admission + icustay; return next icustay_id.""" + dob = datetime(2000, 1, 1) - timedelta(days=int(age_years * 365)) patients_data.append({ "subject_id": subject_id, "gender": gender, @@ -146,87 +171,216 @@ def generate_synthetic_mimic3( "dod_ssn": None, "expire_flag": 0, }) - - n_admissions = max(1, int(np.random.poisson(avg_admissions_per_patient))) - for j in range(n_admissions): - admit_time = datetime(2150, 1, 1) + timedelta(days=int(j * 100)) - discharge_time = admit_time + timedelta( - days=int(np.random.randint(1, 30)) + admit_time = datetime(2150, 1, 1) + timedelta(days=day_offset) + discharge_time = admit_time + timedelta(days=7) + admissions_data.append({ + "subject_id": subject_id, + "hadm_id": hadm_id, + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance_raw, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 if discharge_loc == "EXPIRED" else 0, + }) + icu_intime = admit_time + timedelta(hours=2) + icu_outtime = discharge_time - timedelta(hours=2) + if icu_intime < icu_outtime: + icustays_data.append({ + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + }) + return icustay_id + 1 + return icustay_id + + if mode == "exhaustive": + # Ages map to Young (18-44) / Middle (45-64) / Senior (65+). + age_bands = [30, 52, 72] + discharge_ama = ["HOME", "LEFT AGAINST MEDICAL ADVI"] + diagnosis_texts = ["PNEUMONIA", "ALCOHOL WITHDRAWAL"] + combos = itertools.product( + genders, + ethnicities, + insurances_raw, + age_bands, + discharge_ama, + diagnosis_texts, + ) + subject_id = 1 + hadm_id = 100 + icustay_id = 1000 + for idx, (gender, eth, ins_raw, age_y, disch, diag) in enumerate(combos): + icustay_id = append_visit( + subject_id, + hadm_id, + icustay_id, + gender=gender, + age_years=age_y, + ethnicity=eth, + insurance_raw=ins_raw, + admission_type="EMERGENCY", + discharge_loc=disch, + diagnosis=diag, + day_offset=idx % 500, ) + subject_id += 1 + hadm_id += 1 - admission_type = admission_types[(i + j) % len(admission_types)] + # Extra coverage: SNF discharge, EXPIRED, NEWBORN (skipped by task). + exhaustive_grid_n = ( + len(genders) + * len(ethnicities) + * len(insurances_raw) + * len(age_bands) + * len(discharge_ama) + * len(diagnosis_texts) + ) + for k, extra in enumerate( + ( + ("M", 45, "WHITE", "Private", "EMERGENCY", "SKILLED NURSING FACILITY", "SEPSIS"), + ("F", 55, "BLACK/AFRICAN AMERICAN", "Medicaid", "EMERGENCY", "EXPIRED", "CHEST PAIN"), + ("M", 28, "HISPANIC OR LATINO", "Private", "NEWBORN", "HOME", "PNEUMONIA"), + ), + ): + g, age_y, eth, ins, adm_type, disch, diag = extra + icustay_id = append_visit( + subject_id, + hadm_id, + icustay_id, + gender=g, + age_years=age_y, + ethnicity=eth, + insurance_raw=ins, + admission_type=adm_type, + discharge_loc=disch, + diagnosis=diag, + day_offset=(exhaustive_grid_n + k) % 500, + ) + subject_id += 1 + hadm_id += 1 - if np.random.random() < 0.15: - discharge_loc = "LEFT AGAINST MEDICAL ADVI" - elif np.random.random() < 0.05: - discharge_loc = "EXPIRED" - else: - discharge_loc = discharge_locations[ - (i + j) % (len(discharge_locations) - 2) - ] - - if np.random.random() < 0.2: - diagnosis = diagnoses_substance[ - np.random.randint(0, len(diagnoses_substance)) - ] - else: - diagnosis = diagnoses_other[ - np.random.randint(0, len(diagnoses_other)) - ] + print( + f"Generating exhaustive synthetic MIMIC-III in {root_path} " + f"({len(patients_data)} patients, cross-product + edge rows)...", + ) + elif mode == "random": + np.random.seed(seed) + subject_id = 1 + hadm_id = 100 + icustay_id = 1000 + insurances = ["Medicare", "Medicaid", "Private", "Self Pay", "Government"] + for i in range(n_patients): + gender = genders[i % len(genders)] + ethnicity = ethnicities[i % len(ethnicities)] + insurance = insurances[i % len(insurances)] + + age_at_visit = int( + np.random.choice([25, 45, 65, 85]) + np.random.randint(-5, 5) + ) + dob = datetime(2000, 1, 1) - timedelta(days=age_at_visit * 365) - admissions_data.append({ + patients_data.append({ "subject_id": subject_id, - "hadm_id": hadm_id, - "admission_type": admission_type, - "admission_location": "EMERGENCY ROOM ADMIT", - "insurance": insurance, - "language": "ENGLISH", - "religion": "CHRISTIAN", - "marital_status": "SINGLE", - "ethnicity": ethnicity, - "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "diagnosis": diagnosis, - "discharge_location": discharge_loc, - "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), - "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "hospital_expire_flag": 1 - if discharge_loc == "EXPIRED" - else 0, + "gender": gender, + "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, }) - icu_intime = admit_time + timedelta( - hours=int(np.random.randint(0, 12)) - ) - icu_outtime = discharge_time - timedelta( - hours=int(np.random.randint(0, 12)) + n_admissions = max( + 1, int(np.random.poisson(avg_admissions_per_patient)), ) - - if icu_intime < icu_outtime: - icustays_data.append({ + for j in range(n_admissions): + admit_time = datetime(2150, 1, 1) + timedelta(days=int(j * 100)) + discharge_time = admit_time + timedelta( + days=int(np.random.randint(1, 30)), + ) + + admission_type = admission_types[(i + j) % len(admission_types)] + + if np.random.random() < 0.15: + discharge_loc = "LEFT AGAINST MEDICAL ADVI" + elif np.random.random() < 0.05: + discharge_loc = "EXPIRED" + else: + discharge_loc = discharge_locations[ + (i + j) % (len(discharge_locations) - 2) + ] + + if np.random.random() < 0.2: + diagnosis = diagnoses_substance[ + np.random.randint(0, len(diagnoses_substance)) + ] + else: + diagnosis = diagnoses_other[ + np.random.randint(0, len(diagnoses_other)) + ] + + admissions_data.append({ "subject_id": subject_id, "hadm_id": hadm_id, - "icustay_id": icustay_id, - "first_careunit": "MICU", - "last_careunit": "MICU", - "dbsource": "metavision", - "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), - "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 + if discharge_loc == "EXPIRED" + else 0, }) - icustay_id += 1 - hadm_id += 1 - - subject_id += 1 - - def write_csv_gz(filename: str, data: List[dict]) -> None: - df = pd.DataFrame(data) - filepath = root_path / f"{filename}.gz" - with gzip.open(filepath, "wt") as f: - df.to_csv(f, index=False) - print(f" Created {filename}.gz ({len(data)} rows)") + icu_intime = admit_time + timedelta( + hours=int(np.random.randint(0, 12)), + ) + icu_outtime = discharge_time - timedelta( + hours=int(np.random.randint(0, 12)), + ) + + if icu_intime < icu_outtime: + icustays_data.append({ + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + }) + icustay_id += 1 + + hadm_id += 1 + + subject_id += 1 + + print(f"Generating random synthetic MIMIC-III in {root_path}...") + else: + raise ValueError(f"Unknown mode {mode!r}; use 'exhaustive' or 'random'.") - print(f"Generating synthetic MIMIC-III dataset in {root_path}...") write_csv_gz("PATIENTS.csv", patients_data) write_csv_gz("ADMISSIONS.csv", admissions_data) write_csv_gz("ICUSTAYS.csv", icustays_data) @@ -512,16 +666,37 @@ def main(): parser = argparse.ArgumentParser( description="AMA prediction ablation -- LogisticRegression", ) + parser.add_argument( + "--data-source", + choices=("auto", "synthetic", "real"), + default="auto", + help="auto: use --root if set, else synthetic. synthetic: always local CSVs. " + "real: require --root to MIMIC-III on disk.", + ) + parser.add_argument( + "--synthetic-mode", + choices=("exhaustive", "random"), + default="exhaustive", + help="exhaustive: full cross-product of demographics×AMA×substance (default). " + "random: stochastic demo (--patients applies).", + ) parser.add_argument( "--root", default=None, - help="MIMIC-III root (local path). If not provided, uses synthetic data.", + help="MIMIC-III root directory (required for --data-source real, or sets " + "auto mode to real when provided).", ) parser.add_argument( "--patients", type=int, default=100, - help="Number of synthetic patients (default 100, only used if --root not provided)", + help="Synthetic patient count (random mode only; exhaustive ignores this).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="RNG seed for random synthetic mode.", ) parser.add_argument( "--splits", @@ -538,21 +713,42 @@ def main(): parser.add_argument( "--dev", action="store_true", - help="Use dev mode (fewer patients/splits for testing)", + help="With random synthetic: only 10 patients. Exhaustive grid unchanged.", ) args = parser.parse_args() + if args.data_source == "real" and not args.root: + parser.error("--data-source real requires --root /path/to/mimic-iii/1.4") + + use_synthetic = args.data_source == "synthetic" or ( + args.data_source == "auto" and args.root is None + ) + if args.data_source == "synthetic" and args.root: + print( + "Note: --root is ignored with --data-source synthetic " + "(data is written to a temporary directory).\n", + ) + cache_dir = tempfile.mkdtemp(prefix="ama_lr_") - # If no root provided, generate synthetic data - if args.root is None: + if use_synthetic: print("[Setup] Generating synthetic MIMIC-III dataset...") data_dir = tempfile.mkdtemp(prefix="synthetic_mimic3_") - n_patients = 10 if args.dev else args.patients - generate_synthetic_mimic3(data_dir, n_patients=n_patients, seed=42) - args.root = data_dir - print(f" Synthetic data: {data_dir}\n") + n_patients = ( + 10 if args.dev and args.synthetic_mode == "random" else args.patients + ) + generate_synthetic_mimic3( + data_dir, + n_patients=n_patients, + avg_admissions_per_patient=2, + seed=args.seed, + mode=args.synthetic_mode, + ) + args.root = str(data_dir) + print(f" Synthetic data: {data_dir}") + print(f" Mode: {args.synthetic_mode}\n") else: + assert args.root is not None print(f"Using real MIMIC-III from: {args.root}\n") print(f"Cache: {cache_dir}") @@ -580,14 +776,15 @@ def main(): print(f"\n {exc}") print(" The dataset contains no AMA-positive cases.") print(" For synthetic data: this is expected if AMA rate is low.") - print(" To increase AMA cases, re-run with:") + print(" For synthetic random mode with more patients:") print(" python examples/" "mimic3_ama_prediction_logistic_regression.py \\") - print(" --patients 500\n") - print(" For real MIMIC-III with better AMA coverage:") + print(" --data-source synthetic --synthetic-mode random " + "--patients 500\n") + print(" For real MIMIC-III:") print(" python examples/" "mimic3_ama_prediction_logistic_regression.py \\") - print(" --root /path/to/mimic-iii/1.4") + print(" --data-source real --root /path/to/mimic-iii/1.4") print("\nDone.") return print(f" Samples: {len(sample_dataset)}") diff --git a/examples/mimic3_ama_prediction_rnn.py b/examples/mimic3_ama_prediction_rnn.py index 15b834843..d067dea7c 100644 --- a/examples/mimic3_ama_prediction_rnn.py +++ b/examples/mimic3_ama_prediction_rnn.py @@ -13,38 +13,50 @@ LogisticRegression ablation in the companion script. For each baseline the script reports: - 1. Overall AUROC / PR-AUC averaged over N random 60/40 splits. - 2. Subgroup performance (AUROC, PR-AUC) sliced by Race, Age Group, + 1. Overall AUROC averaged over N random 60/40 splits. + 2. Subgroup performance (AUROC) sliced by Race, Age Group, and Insurance Type. 3. Fairness metrics per subgroup: - Demographic Parity = % predicted AMA (P(Y_hat=1 | Group=g)) - Equal Opportunity = True Positive Rate (P(Y_hat=1 | Y=1, Group=g)) -Usage (synthetic demo data -- illustrative only, likely no AMA positives): - python examples/mimic3_ama_prediction_rnn.py +Synthetic data is generated by the same helper as the LogisticRegression +example (``generate_synthetic_mimic3`` in ``mimic3_ama_prediction_ +logistic_regression.py``). + +Usage (synthetic exhaustive grid, default): + python examples/mimic3_ama_prediction_rnn.py --data-source synthetic Usage (real MIMIC-III): python examples/mimic3_ama_prediction_rnn.py \\ - --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 + --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 """ import argparse +import importlib.util import tempfile import time +from pathlib import Path import numpy as np import torch -from sklearn.metrics import average_precision_score, roc_auc_score +from sklearn.metrics import roc_auc_score from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient from pyhealth.models import RNN from pyhealth.tasks import AMAPredictionMIMIC3 from pyhealth.trainer import Trainer -SYNTHETIC_ROOT = ( - "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III" +_LR_EXAMPLE = Path(__file__).resolve().parent / ( + "mimic3_ama_prediction_logistic_regression.py" +) +_spec = importlib.util.spec_from_file_location( + "mimic3_ama_lr_example", _LR_EXAMPLE, ) -TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"] +_lr_mod = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +_spec.loader.exec_module(_lr_mod) +generate_synthetic_mimic3 = _lr_mod.generate_synthetic_mimic3 BASELINES = { "BASELINE": ["demographics", "age", "los"], @@ -143,15 +155,6 @@ def _safe_auroc(y, p): return float("nan") -def _safe_prauc(y, p): - if np.sum(y) == 0: - return float("nan") - try: - return average_precision_score(y, p) - except ValueError: - return float("nan") - - # ------------------------------------------------------------------ # Single split # ------------------------------------------------------------------ @@ -197,7 +200,6 @@ def _run_single_split(sample_dataset, feature_keys, lookup, y_pred = (y_prob >= threshold).astype(int) overall_auroc = _safe_auroc(y_true, y_prob) - overall_prauc = _safe_prauc(y_true, y_prob) subgroup = {} for attr_name, attr_vals in groups.items(): @@ -211,7 +213,6 @@ def _run_single_split(sample_dataset, feature_keys, lookup, pos = yt.sum() subgroup[attr_name][grp] = { "auroc": _safe_auroc(yt, yp), - "pr_auc": _safe_prauc(yt, yp), "pct_pred": float(yd.mean()) * 100, "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 else float("nan"), @@ -220,7 +221,6 @@ def _run_single_split(sample_dataset, feature_keys, lookup, return { "auroc": overall_auroc, - "pr_auc": overall_prauc, "subgroups": subgroup, } @@ -249,8 +249,6 @@ def _aggregate(results): "n": len(valid), "auroc_mean": _nanmean([r["auroc"] for r in valid]), "auroc_std": _nanstd([r["auroc"] for r in valid]), - "pr_auc_mean": _nanmean([r["pr_auc"] for r in valid]), - "pr_auc_std": _nanstd([r["pr_auc"] for r in valid]), } all_attrs = set() @@ -266,13 +264,12 @@ def _aggregate(results): all_grps.update(r["subgroups"][attr].keys()) for grp in sorted(all_grps): - aurocs, praucs, pcts, tprs, ns = [], [], [], [], [] + aurocs, pcts, tprs, ns = [], [], [], [] for r in valid: m = r["subgroups"].get(attr, {}).get(grp) if m is None: continue aurocs.append(m["auroc"]) - praucs.append(m["pr_auc"]) pcts.append(m["pct_pred"]) tprs.append(m["tpr"]) ns.append(m["n"]) @@ -280,8 +277,6 @@ def _aggregate(results): agg["subgroups"][attr][grp] = { "auroc_mean": _nanmean(aurocs), "auroc_std": _nanstd(aurocs), - "pr_auc_mean": _nanmean(praucs), - "pr_auc_std": _nanstd(praucs), "pct_pred_mean": _nanmean(pcts), "tpr_mean": _nanmean(tprs), "n_avg": int(np.mean(ns)) if ns else 0, @@ -313,17 +308,15 @@ def _print_results(name, feature_keys, agg): print(f"\n 1. Overall Performance ({agg['n']} splits)") print(f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})") - print(f" PR-AUC: {_fmt(agg['pr_auc_mean'])} +/- {_fmt(agg['pr_auc_std'])}") print(f"\n 2. Subgroup Performance") for attr, grps in agg["subgroups"].items(): print(f" {attr}:") - print(f" {'Group':<20} {'AUROC':>15} {'PR-AUC':>15} {'n_avg':>7}") - print(f" {'-'*58}") + print(f" {'Group':<20} {'AUROC':>15} {'n_avg':>7}") + print(f" {'-'*42}") for grp, m in grps.items(): a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" - p_str = f"{_fmt(m['pr_auc_mean'])}+/-{_fmt(m['pr_auc_std'])}" - print(f" {grp:<20} {a_str:>15} {p_str:>15} {m['n_avg']:>7}") + print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") print(f"\n 3. Fairness Metrics") print(f" Demographic Parity (% Predicted AMA):") @@ -347,17 +340,67 @@ def main(): parser = argparse.ArgumentParser( description="AMA prediction ablation -- RNN", ) - parser.add_argument("--root", default=SYNTHETIC_ROOT, - help="MIMIC-III root (local path or URL)") - parser.add_argument("--splits", type=int, default=100, - help="Number of random 60/40 splits (default 100)") - parser.add_argument("--epochs", type=int, default=10, + parser.add_argument( + "--data-source", + choices=("auto", "synthetic", "real"), + default="auto", + help="auto: use --root if set, else synthetic. synthetic: local CSVs. " + "real: require --root.", + ) + parser.add_argument( + "--synthetic-mode", + choices=("exhaustive", "random"), + default="exhaustive", + help="Passed to generate_synthetic_mimic3 (see logistic example).", + ) + parser.add_argument( + "--root", + default=None, + help="MIMIC-III root (required for --data-source real).", + ) + parser.add_argument("--patients", type=int, default=100, + help="Random synthetic mode only.") + parser.add_argument("--seed", type=int, default=42, + help="RNG seed for random synthetic mode.") + parser.add_argument("--splits", type=int, default=5, + help="Number of random 60/40 splits") + parser.add_argument("--epochs", type=int, default=3, help="Training epochs per split") parser.add_argument("--dev", action="store_true", - help="Use dev mode (1000 patients)") + help="Random synthetic: 10 patients only.") args = parser.parse_args() + if args.data_source == "real" and not args.root: + parser.error("--data-source real requires --root /path/to/mimic-iii/1.4") + + use_synthetic = args.data_source == "synthetic" or ( + args.data_source == "auto" and args.root is None + ) + if args.data_source == "synthetic" and args.root: + print( + "Note: --root ignored with --data-source synthetic.\n", + ) + cache_dir = tempfile.mkdtemp(prefix="ama_rnn_") + + if use_synthetic: + print("[Setup] Generating synthetic MIMIC-III dataset...") + data_dir = tempfile.mkdtemp(prefix="synthetic_mimic3_") + n_patients = ( + 10 if args.dev and args.synthetic_mode == "random" else args.patients + ) + generate_synthetic_mimic3( + data_dir, + n_patients=n_patients, + avg_admissions_per_patient=2, + seed=args.seed, + mode=args.synthetic_mode, + ) + args.root = str(data_dir) + print(f" Synthetic: {data_dir} mode={args.synthetic_mode}\n") + else: + print(f"Using real MIMIC-III from: {args.root}\n") + print(f"Cache: {cache_dir}") print(f"Root: {args.root}") print(f"Splits: {args.splits} | Epochs: {args.epochs}") @@ -365,7 +408,7 @@ def main(): print("\n[1/4] Loading dataset...") t0 = time.time() dataset = MIMIC3Dataset( - root=args.root, tables=TABLES, + root=args.root, tables=[], cache_dir=cache_dir, dev=args.dev, ) print(f" Loaded in {time.time()-t0:.1f}s") @@ -388,10 +431,10 @@ def main(): samples = task(patient) total += len(samples) print(f" Task produced {total} samples (all label=0)") - print("\n Re-run with real MIMIC-III for ablation:") + print("\n Re-run with real MIMIC-III:") print(" python examples/" "mimic3_ama_prediction_rnn.py \\") - print(" --root /path/to/mimic-iii/1.4") + print(" --data-source real --root /path/to/mimic-iii/1.4") print("\nDone.") return print(f" Samples: {len(sample_dataset)}") diff --git a/tests/core/test_ama_ablation_with_synthetic_data.py b/tests/core/test_ama_ablation_with_synthetic_data.py deleted file mode 100644 index 8a0a6452c..000000000 --- a/tests/core/test_ama_ablation_with_synthetic_data.py +++ /dev/null @@ -1,378 +0,0 @@ -"""Tests for AMA prediction ablation studies with synthetic data. - -Uses local synthetic MIMIC-III data to ensure fast execution -and comprehensive coverage of all demographic combinations. - -Synthetic CSV generation lives in ``examples/mimic3_ama_prediction_ -logistic_regression.py``; tests load that helper via importlib so there -is no separate dataset module. -""" - -import importlib.util -import tempfile -import unittest -from datetime import datetime -from pathlib import Path - -import numpy as np -import torch - -from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient -from pyhealth.models import LogisticRegression -from pyhealth.tasks import AMAPredictionMIMIC3 -from pyhealth.trainer import Trainer - -_EXAMPLE_PATH = ( - Path(__file__).resolve().parents[2] - / "examples" - / "mimic3_ama_prediction_logistic_regression.py" -) -_spec = importlib.util.spec_from_file_location( - "mimic3_ama_prediction_example", _EXAMPLE_PATH -) -_example_mod = importlib.util.module_from_spec(_spec) -assert _spec.loader is not None -_spec.loader.exec_module(_example_mod) -generate_synthetic_mimic3 = _example_mod.generate_synthetic_mimic3 - - -class TestAMAWithSyntheticData(unittest.TestCase): - """Test AMA prediction ablation studies on synthetic data. - - Uses a small synthetic dataset that covers all demographic - combinations to ensure tests run quickly (~1 second total). - """ - - @classmethod - def setUpClass(cls): - """Generate synthetic dataset once for all tests.""" - cls.tmpdir = tempfile.mkdtemp(prefix="ama_test_") - cls.cache_dir = tempfile.mkdtemp(prefix="ama_test_cache_") - - # Generate small but comprehensive synthetic data - generate_synthetic_mimic3( - cls.tmpdir, - n_patients=50, - avg_admissions_per_patient=2, - seed=42, - ) - - # Load dataset - cls.dataset = MIMIC3Dataset( - root=cls.tmpdir, - tables=[], - cache_dir=cls.cache_dir, - ) - - # Apply task - cls.task = AMAPredictionMIMIC3() - cls.sample_dataset = cls.dataset.set_task(cls.task) - - def test_dataset_loads_successfully(self): - """Verify synthetic dataset loads with expected structure.""" - self.assertIsNotNone(self.dataset) - self.assertGreater(len(self.sample_dataset), 0) - - def test_samples_have_expected_features(self): - """Verify each sample contains required features.""" - sample = self.sample_dataset[0] - - expected_keys = { - "visit_id", - "patient_id", - "demographics", - "age", - "los", - "race", - "has_substance_use", - "ama", - } - self.assertEqual(set(sample.keys()), expected_keys) - - def test_demographics_values(self): - """Verify demographics contain processed feature vectors.""" - for sample in self.sample_dataset: - demo = sample["demographics"] - # After processing, demographics are tensors - self.assertTrue( - torch.is_tensor(demo) or isinstance(demo, (int, float)), - "Demographics should be processed", - ) - - def test_age_in_valid_range(self): - """Verify ages are processed as tensors.""" - for sample in self.sample_dataset: - age = sample["age"] - # After processing, age is a tensor - self.assertTrue(torch.is_tensor(age) or isinstance(age, (int, float))) - - def test_los_positive(self): - """Verify LOS (length of stay) is processed as tensor.""" - for sample in self.sample_dataset: - los = sample["los"] - # After processing, los is a tensor - self.assertTrue(torch.is_tensor(los) or isinstance(los, (int, float))) - - def test_race_normalized(self): - """Verify race is processed as tensor.""" - for sample in self.sample_dataset: - race = sample["race"] - # After processing, race is a tensor - self.assertTrue(torch.is_tensor(race) or isinstance(race, (int, float))) - - def test_substance_use_binary(self): - """Verify substance use is processed as tensor.""" - for sample in self.sample_dataset: - substance = sample["has_substance_use"] - # After processing, substance use is a tensor - self.assertTrue(torch.is_tensor(substance) or isinstance(substance, (int, float))) - - def test_ama_label_binary(self): - """Verify AMA label is 0 or 1.""" - for sample in self.sample_dataset: - ama = sample["ama"] - self.assertIn(ama, [0, 1]) - - def test_has_positive_and_negative_labels(self): - """Verify dataset has both AMA positive and negative cases.""" - labels = [sample["ama"] for sample in self.sample_dataset] - has_positive = any(l == 1 for l in labels) - has_negative = any(l == 0 for l in labels) - - self.assertTrue( - has_positive and has_negative, - "Dataset should have both positive and negative AMA cases", - ) - - -class TestAMABaselineFeatures(unittest.TestCase): - """Test that each ablation baseline uses correct features.""" - - @classmethod - def setUpClass(cls): - """Generate synthetic dataset for baseline tests.""" - cls.tmpdir = tempfile.mkdtemp(prefix="ama_baseline_test_") - cls.cache_dir = tempfile.mkdtemp(prefix="ama_baseline_cache_") - - generate_synthetic_mimic3( - cls.tmpdir, - n_patients=30, - seed=42, - ) - - cls.dataset = MIMIC3Dataset( - root=cls.tmpdir, - tables=[], - cache_dir=cls.cache_dir, - ) - - cls.task = AMAPredictionMIMIC3() - cls.sample_dataset = cls.dataset.set_task(cls.task) - - def _create_model_with_features(self, feature_keys): - """Helper to create logistic regression model with feature keys.""" - model = LogisticRegression( - dataset=self.sample_dataset, - embedding_dim=64, # Use dataset's embedding_dim - ) - model.feature_keys = list(feature_keys) - output_size = model.get_output_size() - embedding_dim = model.embedding_model.embedding_layers[ - feature_keys[0] - ].out_features - model.fc = torch.nn.Linear( - len(feature_keys) * embedding_dim, output_size - ) - return model - - def test_baseline_model_can_be_created(self): - """BASELINE: demographics, age, los.""" - model = self._create_model_with_features( - ["demographics", "age", "los"] - ) - self.assertIsNotNone(model) - # Verify fc layer exists - self.assertIsNotNone(model.fc) - - def test_baseline_plus_race_model(self): - """BASELINE + RACE: adds race.""" - model = self._create_model_with_features( - ["demographics", "age", "los", "race"] - ) - self.assertIsNotNone(model) - self.assertIsNotNone(model.fc) - - def test_baseline_plus_race_plus_substance_model(self): - """BASELINE + RACE + SUBSTANCE: adds has_substance_use.""" - model = self._create_model_with_features( - ["demographics", "age", "los", "race", "has_substance_use"] - ) - self.assertIsNotNone(model) - self.assertIsNotNone(model.fc) - - def test_baseline_forward_pass(self): - """Verify model forward pass works with baseline features.""" - model = self._create_model_with_features( - ["demographics", "age", "los"] - ) - - train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.8, 0.0, 0.2], seed=0 - ) - test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) - - model.eval() - with torch.no_grad(): - batch = next(iter(test_dl)) - output = model(**batch) - - self.assertIn("y_prob", output) - self.assertIn("y_true", output) - self.assertEqual(output["y_prob"].shape[0], len(test_ds)) - - def test_baseline_plus_race_forward_pass(self): - """Verify model forward pass with race feature.""" - model = self._create_model_with_features( - ["demographics", "age", "los", "race"] - ) - - train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.8, 0.0, 0.2], seed=0 - ) - test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) - - model.eval() - with torch.no_grad(): - batch = next(iter(test_dl)) - output = model(**batch) - - self.assertIn("y_prob", output) - self.assertEqual(output["y_prob"].shape[0], len(test_ds)) - - def test_baseline_plus_full_forward_pass(self): - """Verify model forward pass with all features.""" - model = self._create_model_with_features( - ["demographics", "age", "los", "race", "has_substance_use"] - ) - - train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.8, 0.0, 0.2], seed=0 - ) - test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) - - model.eval() - with torch.no_grad(): - batch = next(iter(test_dl)) - output = model(**batch) - - self.assertIn("y_prob", output) - self.assertEqual(output["y_prob"].shape[0], len(test_ds)) - - -class TestAMATrainingSpeed(unittest.TestCase): - """Verify training with synthetic data is fast.""" - - @classmethod - def setUpClass(cls): - """Generate small synthetic dataset for speed tests.""" - cls.tmpdir = tempfile.mkdtemp(prefix="ama_speed_test_") - cls.cache_dir = tempfile.mkdtemp(prefix="ama_speed_cache_") - - generate_synthetic_mimic3( - cls.tmpdir, - n_patients=20, # Small for speed - seed=42, - ) - - cls.dataset = MIMIC3Dataset( - root=cls.tmpdir, - tables=[], - cache_dir=cls.cache_dir, - ) - - cls.task = AMAPredictionMIMIC3() - cls.sample_dataset = cls.dataset.set_task(cls.task) - - def test_training_completes_quickly(self): - """Verify one training epoch completes in reasonable time.""" - import time - - train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.6, 0.0, 0.4], seed=0 - ) - train_dl = get_dataloader( - train_ds, batch_size=8, shuffle=True - ) - - model = LogisticRegression( - dataset=self.sample_dataset, - embedding_dim=64, - ) - model.feature_keys = ["demographics", "age", "los"] - output_size = model.get_output_size() - embedding_dim = model.embedding_model.embedding_layers[ - "demographics" - ].out_features - model.fc = torch.nn.Linear( - 3 * embedding_dim, output_size - ) - - trainer = Trainer(model=model) - - t0 = time.time() - trainer.train( - train_dataloader=train_dl, - val_dataloader=None, - epochs=1, - monitor=None, - ) - elapsed = time.time() - t0 - - # Should complete in reasonable time - self.assertGreater(elapsed, 0, "Training should take some time") - - def test_multiple_splits_complete_quickly(self): - """Verify 2 random splits complete without error.""" - train_ds, _, test_ds = split_by_patient( - self.sample_dataset, - [0.6, 0.0, 0.4], - seed=0, - ) - - for split_seed in range(2): - train_ds, _, _ = split_by_patient( - self.sample_dataset, - [0.6, 0.0, 0.4], - seed=split_seed, - ) - train_dl = get_dataloader( - train_ds, batch_size=8, shuffle=True - ) - - model = LogisticRegression( - dataset=self.sample_dataset, - embedding_dim=64, - ) - model.feature_keys = ["demographics", "age", "los"] - output_size = model.get_output_size() - embedding_dim = model.embedding_model.embedding_layers[ - "demographics" - ].out_features - model.fc = torch.nn.Linear( - 3 * embedding_dim, output_size - ) - - trainer = Trainer(model=model) - trainer.train( - train_dataloader=train_dl, - val_dataloader=None, - epochs=1, - monitor=None, - ) - - # Verify we completed without errors - self.assertTrue(True) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py index e820c5262..1c7960aa8 100644 --- a/tests/core/test_mimic3_ama_prediction.py +++ b/tests/core/test_mimic3_ama_prediction.py @@ -1,13 +1,36 @@ +import gzip +import importlib.util +import shutil +import tempfile import unittest from datetime import datetime +from pathlib import Path from unittest.mock import MagicMock +import torch + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import LogisticRegression from pyhealth.tasks.ama_prediction import ( AMAPredictionMIMIC3, _has_substance_use, _normalize_insurance, _normalize_race, ) +from pyhealth.trainer import Trainer + +_EXAMPLE_PATH = ( + Path(__file__).resolve().parents[2] + / "examples" + / "mimic3_ama_prediction_logistic_regression.py" +) +_spec = importlib.util.spec_from_file_location( + "mimic3_ama_prediction_example", _EXAMPLE_PATH, +) +_example_mod = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +_spec.loader.exec_module(_example_mod) +generate_synthetic_mimic3 = _example_mod.generate_synthetic_mimic3 def _make_event(**attrs): @@ -886,5 +909,340 @@ def test_ablation_patient_no_clinical_codes(self): self.assertNotIn("drugs", sample) +# ------------------------------------------------------------------ +# Synthetic MIMIC-III + LogisticRegression (loads example generator) +# ------------------------------------------------------------------ + + +class TestAMAWithSyntheticData(unittest.TestCase): + """AMA task on small random synthetic CSVs (fast pipeline checks).""" + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp(prefix="ama_test_") + cls.cache_dir = tempfile.mkdtemp(prefix="ama_test_cache_") + + generate_synthetic_mimic3( + cls.tmpdir, + n_patients=50, + avg_admissions_per_patient=2, + seed=42, + mode="random", + ) + + cls.dataset = MIMIC3Dataset( + root=cls.tmpdir, + tables=[], + cache_dir=cls.cache_dir, + ) + + cls.task = AMAPredictionMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + + def test_dataset_loads_successfully(self): + self.assertIsNotNone(self.dataset) + self.assertGreater(len(self.sample_dataset), 0) + + def test_samples_have_expected_features(self): + sample = self.sample_dataset[0] + + expected_keys = { + "visit_id", + "patient_id", + "demographics", + "age", + "los", + "race", + "has_substance_use", + "ama", + } + self.assertEqual(set(sample.keys()), expected_keys) + + def test_demographics_values(self): + for sample in self.sample_dataset: + demo = sample["demographics"] + self.assertTrue( + torch.is_tensor(demo) or isinstance(demo, (int, float)), + "Demographics should be processed", + ) + + def test_age_in_valid_range(self): + for sample in self.sample_dataset: + age = sample["age"] + self.assertTrue(torch.is_tensor(age) or isinstance(age, (int, float))) + + def test_los_positive(self): + for sample in self.sample_dataset: + los = sample["los"] + self.assertTrue(torch.is_tensor(los) or isinstance(los, (int, float))) + + def test_race_normalized(self): + for sample in self.sample_dataset: + race = sample["race"] + self.assertTrue(torch.is_tensor(race) or isinstance(race, (int, float))) + + def test_substance_use_binary(self): + for sample in self.sample_dataset: + substance = sample["has_substance_use"] + self.assertTrue( + torch.is_tensor(substance) or isinstance(substance, (int, float)), + ) + + def test_ama_label_binary(self): + for sample in self.sample_dataset: + ama = sample["ama"] + self.assertIn(ama, [0, 1]) + + def test_has_positive_and_negative_labels(self): + labels = [sample["ama"] for sample in self.sample_dataset] + has_positive = any(l == 1 for l in labels) + has_negative = any(l == 0 for l in labels) + + self.assertTrue( + has_positive and has_negative, + "Dataset should have both positive and negative AMA cases", + ) + + +class TestAMABaselineFeatures(unittest.TestCase): + """LogisticRegression ablation feature subsets on synthetic data.""" + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp(prefix="ama_baseline_test_") + cls.cache_dir = tempfile.mkdtemp(prefix="ama_baseline_cache_") + + generate_synthetic_mimic3( + cls.tmpdir, + n_patients=30, + seed=42, + mode="random", + ) + + cls.dataset = MIMIC3Dataset( + root=cls.tmpdir, + tables=[], + cache_dir=cls.cache_dir, + ) + + cls.task = AMAPredictionMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + + def _create_model_with_features(self, feature_keys): + model = LogisticRegression( + dataset=self.sample_dataset, + embedding_dim=64, + ) + model.feature_keys = list(feature_keys) + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + feature_keys[0] + ].out_features + model.fc = torch.nn.Linear( + len(feature_keys) * embedding_dim, output_size + ) + return model + + def test_baseline_model_can_be_created(self): + model = self._create_model_with_features( + ["demographics", "age", "los"] + ) + self.assertIsNotNone(model) + self.assertIsNotNone(model.fc) + + def test_baseline_plus_race_model(self): + model = self._create_model_with_features( + ["demographics", "age", "los", "race"] + ) + self.assertIsNotNone(model) + self.assertIsNotNone(model.fc) + + def test_baseline_plus_race_plus_substance_model(self): + model = self._create_model_with_features( + ["demographics", "age", "los", "race", "has_substance_use"] + ) + self.assertIsNotNone(model) + self.assertIsNotNone(model.fc) + + def test_baseline_forward_pass(self): + model = self._create_model_with_features( + ["demographics", "age", "los"] + ) + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + def test_baseline_plus_race_forward_pass(self): + model = self._create_model_with_features( + ["demographics", "age", "los", "race"] + ) + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + def test_baseline_plus_full_forward_pass(self): + model = self._create_model_with_features( + ["demographics", "age", "los", "race", "has_substance_use"] + ) + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + +class TestAMATrainingSpeed(unittest.TestCase): + """Short training runs on tiny synthetic data.""" + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp(prefix="ama_speed_test_") + cls.cache_dir = tempfile.mkdtemp(prefix="ama_speed_cache_") + + generate_synthetic_mimic3( + cls.tmpdir, + n_patients=20, + seed=42, + mode="random", + ) + + cls.dataset = MIMIC3Dataset( + root=cls.tmpdir, + tables=[], + cache_dir=cls.cache_dir, + ) + + cls.task = AMAPredictionMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + + def test_training_completes_quickly(self): + import time + + train_ds, _, test_ds = split_by_patient( + self.sample_dataset, [0.6, 0.0, 0.4], seed=0 + ) + train_dl = get_dataloader( + train_ds, batch_size=8, shuffle=True + ) + + model = LogisticRegression( + dataset=self.sample_dataset, + embedding_dim=64, + ) + model.feature_keys = ["demographics", "age", "los"] + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + "demographics" + ].out_features + model.fc = torch.nn.Linear( + 3 * embedding_dim, output_size + ) + + trainer = Trainer(model=model) + + t0 = time.time() + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=1, + monitor=None, + ) + elapsed = time.time() - t0 + + self.assertGreater(elapsed, 0, "Training should take some time") + + def test_multiple_splits_complete_quickly(self): + split_by_patient( + self.sample_dataset, + [0.6, 0.0, 0.4], + seed=0, + ) + + for split_seed in range(2): + train_ds, _, _ = split_by_patient( + self.sample_dataset, + [0.6, 0.0, 0.4], + seed=split_seed, + ) + train_dl = get_dataloader( + train_ds, batch_size=8, shuffle=True + ) + + model = LogisticRegression( + dataset=self.sample_dataset, + embedding_dim=64, + ) + model.feature_keys = ["demographics", "age", "los"] + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + "demographics" + ].out_features + model.fc = torch.nn.Linear( + 3 * embedding_dim, output_size + ) + + trainer = Trainer(model=model) + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=1, + monitor=None, + ) + + self.assertTrue(True) + + +EXHAUSTIVE_PATIENT_ROWS = 2 * 6 * 6 * 3 * 2 * 2 + 3 + + +class TestExhaustiveSyntheticGrid(unittest.TestCase): + """Sanity-check exhaustive synthetic generator (row counts only).""" + + def test_patient_row_count_matches_cross_product(self): + tmp = tempfile.mkdtemp(prefix="ama_exhaustive_") + try: + generate_synthetic_mimic3( + tmp, + mode="exhaustive", + seed=0, + n_patients=1, + ) + with gzip.open(Path(tmp) / "PATIENTS.csv.gz", "rt") as f: + lines = f.readlines() + data_rows = len(lines) - 1 + self.assertEqual(data_rows, EXHAUSTIVE_PATIENT_ROWS) + finally: + shutil.rmtree(tmp, ignore_errors=True) + + if __name__ == "__main__": unittest.main() From 075c821e33b5cadfc2aca515853ab36f4b4a6ff0 Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Wed, 15 Apr 2026 18:04:12 -0500 Subject: [PATCH 06/10] Pep8 style modifications --- .../pyhealth.datasets.MIMIC3Dataset.rst | 8 +- .../tasks/pyhealth.tasks.ama_prediction.rst | 10 + ...mic3_ama_prediction_logistic_regression.py | 34 +- examples/mimic3_ama_prediction_rnn.py | 1 + pyhealth/tasks/ama_prediction.py | 20 +- tests/core/test_mimic3_ama_prediction.py | 452 ++++++++++++++---- 6 files changed, 407 insertions(+), 118 deletions(-) diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst index 0b7ec27c0..381357eff 100644 --- a/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst +++ b/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst @@ -1,4 +1,4 @@ -pyhealth.datasets.MIMIC3Dataset +pyhealth.datasets.MIMIC3Dataset =================================== The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer to `doc `_ for more information. We process this database into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis. @@ -8,8 +8,8 @@ The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer :undoc-members: :show-inheritance: - +.. seealso:: - - + Administrative AMA discharge prediction (no ICD / Rx tables required): + :class:`pyhealth.tasks.ama_prediction.AMAPredictionMIMIC3`. \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.ama_prediction.rst b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst index a82a2521c..aaac74de4 100644 --- a/docs/api/tasks/pyhealth.tasks.ama_prediction.rst +++ b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst @@ -1,6 +1,16 @@ pyhealth.tasks.ama_prediction ================================= +Against-medical-advice (AMA) discharge on MIMIC-III using administrative +features only. Use with :class:`pyhealth.datasets.MIMIC3Dataset` and +``tables=[]`` unless you extend the cohort. + +.. seealso:: + + * :doc:`../tasks` — task / processor overview (``set_task`` workflow). + * :doc:`../../how_to_contribute` — contribution and PR expectations. + * :doc:`../datasets/pyhealth.datasets.MIMIC3Dataset` — source dataset. + .. autoclass:: pyhealth.tasks.ama_prediction.AMAPredictionMIMIC3 :members: :undoc-members: diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index 1127a0180..a0b807e40 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -37,6 +37,7 @@ Usage (real MIMIC-III; same as ``--root /path`` with ``--data-source auto``): python examples/mimic3_ama_prediction_logistic_regression.py \\ --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 + """ import argparse @@ -249,13 +250,36 @@ def append_visit( * len(discharge_ama) * len(diagnosis_texts) ) - for k, extra in enumerate( + extra_rows = ( + ( + "M", + 45, + "WHITE", + "Private", + "EMERGENCY", + "SKILLED NURSING FACILITY", + "SEPSIS", + ), ( - ("M", 45, "WHITE", "Private", "EMERGENCY", "SKILLED NURSING FACILITY", "SEPSIS"), - ("F", 55, "BLACK/AFRICAN AMERICAN", "Medicaid", "EMERGENCY", "EXPIRED", "CHEST PAIN"), - ("M", 28, "HISPANIC OR LATINO", "Private", "NEWBORN", "HOME", "PNEUMONIA"), + "F", + 55, + "BLACK/AFRICAN AMERICAN", + "Medicaid", + "EMERGENCY", + "EXPIRED", + "CHEST PAIN", ), - ): + ( + "M", + 28, + "HISPANIC OR LATINO", + "Private", + "NEWBORN", + "HOME", + "PNEUMONIA", + ), + ) + for k, extra in enumerate(extra_rows): g, age_y, eth, ins, adm_type, disch, diag = extra icustay_id = append_visit( subject_id, diff --git a/examples/mimic3_ama_prediction_rnn.py b/examples/mimic3_ama_prediction_rnn.py index d067dea7c..6729255a9 100644 --- a/examples/mimic3_ama_prediction_rnn.py +++ b/examples/mimic3_ama_prediction_rnn.py @@ -30,6 +30,7 @@ Usage (real MIMIC-III): python examples/mimic3_ama_prediction_rnn.py \\ --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 + """ import argparse diff --git a/pyhealth/tasks/ama_prediction.py b/pyhealth/tasks/ama_prediction.py index 96cc806ab..f62b7afa8 100644 --- a/pyhealth/tasks/ama_prediction.py +++ b/pyhealth/tasks/ama_prediction.py @@ -1,6 +1,12 @@ +"""MIMIC-III Against-Medical-Advice (AMA) discharge prediction task. + +Defines :class:`AMAPredictionMIMIC3` and helpers for Boag et al. 2018-style +demographic baselines (race / substance-use ablations). +""" + import re from datetime import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from .base_task import BaseTask @@ -11,7 +17,7 @@ ) -def _normalize_race(ethnicity: str) -> str: +def _normalize_race(ethnicity: Optional[str]) -> str: """Map MIMIC-III ethnicity strings to the race categories used by Boag et al. 2018. @@ -39,7 +45,7 @@ def _normalize_race(ethnicity: str) -> str: return "Other" -def _normalize_insurance(insurance: str) -> str: +def _normalize_insurance(insurance: Optional[str]) -> str: """Map MIMIC-III insurance strings to the categories used by Boag et al. 2018. @@ -121,6 +127,14 @@ class AMAPredictionMIMIC3(BaseTask): of the **current** admission, so patients with only one visit are eligible. + **Processor mapping (schemas):** Each value in ``input_schema`` and + ``output_schema`` must be a processor string key understood by + ``dataset.set_task`` (see the Tasks docs processor table). Here, + ``"multi_hot"`` feeds ``MultiHotProcessor`` (token lists for + ``demographics`` and ``race``), ``"tensor"`` feeds ``TensorProcessor`` + (``age``, ``los``, ``has_substance_use``), and ``"binary"`` labels + ``ama`` for the binary label path. + Attributes: task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py index 1c7960aa8..1e93e6b43 100644 --- a/tests/core/test_mimic3_ama_prediction.py +++ b/tests/core/test_mimic3_ama_prediction.py @@ -1,12 +1,27 @@ +"""Unit and integration tests for AMA discharge prediction on MIMIC-III. + +Aligned with the PyHealth PR checklist ("What makes a good PyHealth PR?", +see course slides): task inherits ``BaseTask`` with explicit schemas; +integration tests use a **five-row** shared synthetic slice so runs stay +fast; mocks cover edge cases without loading full MIMIC-III. End-to-end +usage lives in ``examples/mimic3_ama_prediction_logistic_regression.py`` and +``examples/mimic3_ama_prediction_rnn.py``. +""" + +import gc import gzip import importlib.util +import io import shutil +import sys import tempfile import unittest +import warnings from datetime import datetime from pathlib import Path from unittest.mock import MagicMock +import pandas as pd import torch from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient @@ -19,6 +34,11 @@ ) from pyhealth.trainer import Trainer +warnings.filterwarnings("ignore", category=ResourceWarning) +if "ignore::ResourceWarning" not in getattr(sys, "warnoptions", []): + sys.warnoptions.append("ignore::ResourceWarning") + warnings._filters_mutated() + _EXAMPLE_PATH = ( Path(__file__).resolve().parents[2] / "examples" @@ -33,6 +53,277 @@ generate_synthetic_mimic3 = _example_mod.generate_synthetic_mimic3 +# Fixed 5-row MIMIC-III-like slice for integration tests (3 non-AMA, 2 AMA). +# Race (normalized): 2 White, 2 Black, 1 Hispanic. +# Age at admit (task uses calendar year from DOB vs admittime): two young +# (18-44), two middle (45-64), one senior (65+). +CURATED_SYNTHETIC_N = 5 +CURATED_SYNTHETIC_AMA_NEGATIVE = 3 +CURATED_SYNTHETIC_AMA_POSITIVE = 2 + +_CURATED_MIMIC3_PATIENTS = [ + { + "subject_id": 1, + "gender": "M", + "dob": "2118-01-01 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 2, + "gender": "F", + "dob": "2096-01-02 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 3, + "gender": "M", + "dob": "2098-01-03 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 4, + "gender": "F", + "dob": "2115-01-11 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 5, + "gender": "M", + "dob": "2079-01-12 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, +] + +_CURATED_MIMIC3_ADMISSIONS = [ + { + "subject_id": 1, + "hadm_id": 100, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Private", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "WHITE", + "edregtime": "2150-01-01 00:00:00", + "edouttime": "2150-01-01 00:00:00", + "diagnosis": "PNEUMONIA", + "discharge_location": "HOME", + "dischtime": "2150-01-08 00:00:00", + "admittime": "2150-01-01 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 2, + "hadm_id": 101, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Private", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "WHITE", + "edregtime": "2150-01-02 00:00:00", + "edouttime": "2150-01-02 00:00:00", + "diagnosis": "CHEST PAIN", + "discharge_location": "HOME", + "dischtime": "2150-01-09 00:00:00", + "admittime": "2150-01-02 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 3, + "hadm_id": 102, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Medicaid", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "BLACK/AFRICAN AMERICAN", + "edregtime": "2150-01-03 00:00:00", + "edouttime": "2150-01-03 00:00:00", + "diagnosis": "SEPSIS", + "discharge_location": "HOME", + "dischtime": "2150-01-10 00:00:00", + "admittime": "2150-01-03 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 4, + "hadm_id": 103, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Medicaid", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "BLACK/AFRICAN AMERICAN", + "edregtime": "2150-01-11 00:00:00", + "edouttime": "2150-01-11 00:00:00", + "diagnosis": "PNEUMONIA", + "discharge_location": "LEFT AGAINST MEDICAL ADVI", + "dischtime": "2150-01-18 00:00:00", + "admittime": "2150-01-11 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 5, + "hadm_id": 104, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Medicare", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "HISPANIC OR LATINO", + "edregtime": "2150-01-12 00:00:00", + "edouttime": "2150-01-12 00:00:00", + "diagnosis": "OPIOID DEPENDENCE", + "discharge_location": "LEFT AGAINST MEDICAL ADVI", + "dischtime": "2150-01-19 00:00:00", + "admittime": "2150-01-12 00:00:00", + "hospital_expire_flag": 0, + }, +] + +_CURATED_MIMIC3_ICUSTAYS = [ + { + "subject_id": 1, + "hadm_id": 100, + "icustay_id": 1000, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-01 02:00:00", + "outtime": "2150-01-07 22:00:00", + }, + { + "subject_id": 2, + "hadm_id": 101, + "icustay_id": 1001, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-02 02:00:00", + "outtime": "2150-01-08 22:00:00", + }, + { + "subject_id": 3, + "hadm_id": 102, + "icustay_id": 1002, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-03 02:00:00", + "outtime": "2150-01-09 22:00:00", + }, + { + "subject_id": 4, + "hadm_id": 103, + "icustay_id": 1003, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-11 02:00:00", + "outtime": "2150-01-17 22:00:00", + }, + { + "subject_id": 5, + "hadm_id": 104, + "icustay_id": 1004, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-12 02:00:00", + "outtime": "2150-01-18 22:00:00", + }, +] + + +def _write_curated_synthetic_mimic3_for_tests(root: str) -> None: + """Write ``_CURATED_MIMIC3_*`` to gzipped PATIENTS/ADMISSIONS/ICUSTAYS.""" + root_path = Path(root) + root_path.mkdir(parents=True, exist_ok=True) + for name, rows in ( + ("PATIENTS", _CURATED_MIMIC3_PATIENTS), + ("ADMISSIONS", _CURATED_MIMIC3_ADMISSIONS), + ("ICUSTAYS", _CURATED_MIMIC3_ICUSTAYS), + ): + df = pd.DataFrame(rows) + with gzip.open(root_path / f"{name}.csv.gz", "wt") as f: + df.to_csv(f, index=False) + + +# ------------------------------------------------------------------ +# Module-level shared dataset (loaded once for all integration tests) +# ------------------------------------------------------------------ + +_shared_tmpdir = None +_shared_cache_dir = None +_shared_dataset = None +_shared_sample_dataset = None + + +def setUpModule(): + global _shared_tmpdir, _shared_cache_dir + global _shared_dataset, _shared_sample_dataset + _shared_tmpdir = tempfile.mkdtemp(prefix="ama_shared_") + _shared_cache_dir = tempfile.mkdtemp(prefix="ama_shared_cache_") + _write_curated_synthetic_mimic3_for_tests(_shared_tmpdir) + _shared_dataset = MIMIC3Dataset( + root=_shared_tmpdir, tables=[], cache_dir=_shared_cache_dir, + ) + _shared_sample_dataset = _shared_dataset.set_task(AMAPredictionMIMIC3()) + + +def tearDownModule(): + global _shared_dataset, _shared_sample_dataset + if _shared_sample_dataset is not None: + _shared_sample_dataset.close() + _shared_sample_dataset = None + _shared_dataset = None + gc.collect() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for obj in gc.get_objects(): + if isinstance(obj, io.FileIO) and not obj.closed: + name = getattr(obj, "name", "") + if ( + isinstance(name, str) + and _shared_cache_dir + and _shared_cache_dir in name + ): + try: + obj.close() + except Exception: + pass + gc.collect() + for d in (_shared_tmpdir, _shared_cache_dir): + if d: + shutil.rmtree(d, ignore_errors=True) + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + def _make_event(**attrs): """Create a mock event with the given attributes.""" event = MagicMock() @@ -52,8 +343,6 @@ def _build_patient( ): """Build a mock Patient with ``get_events`` that respects filters. - Uses 2-5 synthetic patients max. No real dataset is loaded. - Args: patient_id: Patient identifier string. admissions: List of dicts for admission events. @@ -785,7 +1074,6 @@ def test_baseline_race_feature_present(self): for sample in samples: self.assertIn("race", sample) self.assertTrue(isinstance(sample["race"], list)) - # Race should be one of the normalized values race_val = sample["race"][0].split(":", 1)[1] self.assertIn( race_val, @@ -807,11 +1095,9 @@ def test_substance_use_detection_in_ablation(self): """Verify substance use detection for ablation patient.""" samples = self.task(self.patient) - # First admission has substance use (ALCOHOL WITHDRAWAL) s1 = next(s for s in samples if s["visit_id"] == "A1") self.assertEqual(s1["has_substance_use"], [1.0]) - # Second admission has no substance use (PNEUMONIA) s2 = next(s for s in samples if s["visit_id"] == "A2") self.assertEqual(s2["has_substance_use"], [0.0]) @@ -819,11 +1105,9 @@ def test_race_normalization_in_ablation(self): """Verify race normalization for ablation patient.""" samples = self.task(self.patient) - # First admission: Hispanic s1 = next(s for s in samples if s["visit_id"] == "A1") self.assertEqual(s1["race"], ["race:Hispanic"]) - # Second admission: White s2 = next(s for s in samples if s["visit_id"] == "A2") self.assertEqual(s2["race"], ["race:White"]) @@ -834,11 +1118,7 @@ def test_age_and_los_computed(self): for sample in samples: age = sample["age"][0] los = sample["los"][0] - - # Age should be 50 (2150 - 2100) self.assertAlmostEqual(age, 50.0, places=1) - - # LOS should be positive self.assertGreater(los, 0.0) def test_demographics_includes_gender_and_insurance(self): @@ -847,7 +1127,6 @@ def test_demographics_includes_gender_and_insurance(self): for sample in samples: demo = sample["demographics"] - # Should have gender and insurance tokens has_gender = any(t.startswith("gender:") for t in demo) has_insurance = any(t.startswith("insurance:") for t in demo) self.assertTrue(has_gender) @@ -857,12 +1136,10 @@ def test_insurance_normalization_in_ablation(self): """Verify insurance normalization (Medicaid -> Public).""" samples = self.task(self.patient) - # First admission: Medicaid -> Public s1 = next(s for s in samples if s["visit_id"] == "A1") demo1 = s1["demographics"] self.assertIn("insurance:Public", demo1) - # Second admission: Private s2 = next(s for s in samples if s["visit_id"] == "A2") demo2 = s2["demographics"] self.assertIn("insurance:Private", demo2) @@ -871,11 +1148,9 @@ def test_label_correctness_in_ablation(self): """Verify AMA label is correct.""" samples = self.task(self.patient) - # First admission: not AMA s1 = next(s for s in samples if s["visit_id"] == "A1") self.assertEqual(s1["ama"], 0) - # Second admission: AMA s2 = next(s for s in samples if s["visit_id"] == "A2") self.assertEqual(s2["ama"], 1) @@ -910,41 +1185,19 @@ def test_ablation_patient_no_clinical_codes(self): # ------------------------------------------------------------------ -# Synthetic MIMIC-III + LogisticRegression (loads example generator) +# Integration tests using shared curated 5-row dataset # ------------------------------------------------------------------ class TestAMAWithSyntheticData(unittest.TestCase): - """AMA task on small random synthetic CSVs (fast pipeline checks).""" - - @classmethod - def setUpClass(cls): - cls.tmpdir = tempfile.mkdtemp(prefix="ama_test_") - cls.cache_dir = tempfile.mkdtemp(prefix="ama_test_cache_") - - generate_synthetic_mimic3( - cls.tmpdir, - n_patients=50, - avg_admissions_per_patient=2, - seed=42, - mode="random", - ) - - cls.dataset = MIMIC3Dataset( - root=cls.tmpdir, - tables=[], - cache_dir=cls.cache_dir, - ) - - cls.task = AMAPredictionMIMIC3() - cls.sample_dataset = cls.dataset.set_task(cls.task) + """AMA task on curated minimal synthetic CSVs (fast pipeline checks).""" def test_dataset_loads_successfully(self): - self.assertIsNotNone(self.dataset) - self.assertGreater(len(self.sample_dataset), 0) + self.assertIsNotNone(_shared_dataset) + self.assertGreater(len(_shared_sample_dataset), 0) def test_samples_have_expected_features(self): - sample = self.sample_dataset[0] + sample = _shared_sample_dataset[0] expected_keys = { "visit_id", @@ -959,7 +1212,7 @@ def test_samples_have_expected_features(self): self.assertEqual(set(sample.keys()), expected_keys) def test_demographics_values(self): - for sample in self.sample_dataset: + for sample in _shared_sample_dataset: demo = sample["demographics"] self.assertTrue( torch.is_tensor(demo) or isinstance(demo, (int, float)), @@ -967,34 +1220,34 @@ def test_demographics_values(self): ) def test_age_in_valid_range(self): - for sample in self.sample_dataset: + for sample in _shared_sample_dataset: age = sample["age"] self.assertTrue(torch.is_tensor(age) or isinstance(age, (int, float))) def test_los_positive(self): - for sample in self.sample_dataset: + for sample in _shared_sample_dataset: los = sample["los"] self.assertTrue(torch.is_tensor(los) or isinstance(los, (int, float))) def test_race_normalized(self): - for sample in self.sample_dataset: + for sample in _shared_sample_dataset: race = sample["race"] self.assertTrue(torch.is_tensor(race) or isinstance(race, (int, float))) def test_substance_use_binary(self): - for sample in self.sample_dataset: + for sample in _shared_sample_dataset: substance = sample["has_substance_use"] self.assertTrue( torch.is_tensor(substance) or isinstance(substance, (int, float)), ) def test_ama_label_binary(self): - for sample in self.sample_dataset: + for sample in _shared_sample_dataset: ama = sample["ama"] self.assertIn(ama, [0, 1]) def test_has_positive_and_negative_labels(self): - labels = [sample["ama"] for sample in self.sample_dataset] + labels = [sample["ama"] for sample in _shared_sample_dataset] has_positive = any(l == 1 for l in labels) has_negative = any(l == 0 for l in labels) @@ -1007,30 +1260,9 @@ def test_has_positive_and_negative_labels(self): class TestAMABaselineFeatures(unittest.TestCase): """LogisticRegression ablation feature subsets on synthetic data.""" - @classmethod - def setUpClass(cls): - cls.tmpdir = tempfile.mkdtemp(prefix="ama_baseline_test_") - cls.cache_dir = tempfile.mkdtemp(prefix="ama_baseline_cache_") - - generate_synthetic_mimic3( - cls.tmpdir, - n_patients=30, - seed=42, - mode="random", - ) - - cls.dataset = MIMIC3Dataset( - root=cls.tmpdir, - tables=[], - cache_dir=cls.cache_dir, - ) - - cls.task = AMAPredictionMIMIC3() - cls.sample_dataset = cls.dataset.set_task(cls.task) - def _create_model_with_features(self, feature_keys): model = LogisticRegression( - dataset=self.sample_dataset, + dataset=_shared_sample_dataset, embedding_dim=64, ) model.feature_keys = list(feature_keys) @@ -1070,7 +1302,7 @@ def test_baseline_forward_pass(self): ) train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + _shared_sample_dataset, [0.8, 0.0, 0.2], seed=0 ) test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) @@ -1089,7 +1321,7 @@ def test_baseline_plus_race_forward_pass(self): ) train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + _shared_sample_dataset, [0.8, 0.0, 0.2], seed=0 ) test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) @@ -1107,7 +1339,7 @@ def test_baseline_plus_full_forward_pass(self): ) train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.8, 0.0, 0.2], seed=0 + _shared_sample_dataset, [0.8, 0.0, 0.2], seed=0 ) test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) @@ -1123,39 +1355,18 @@ def test_baseline_plus_full_forward_pass(self): class TestAMATrainingSpeed(unittest.TestCase): """Short training runs on tiny synthetic data.""" - @classmethod - def setUpClass(cls): - cls.tmpdir = tempfile.mkdtemp(prefix="ama_speed_test_") - cls.cache_dir = tempfile.mkdtemp(prefix="ama_speed_cache_") - - generate_synthetic_mimic3( - cls.tmpdir, - n_patients=20, - seed=42, - mode="random", - ) - - cls.dataset = MIMIC3Dataset( - root=cls.tmpdir, - tables=[], - cache_dir=cls.cache_dir, - ) - - cls.task = AMAPredictionMIMIC3() - cls.sample_dataset = cls.dataset.set_task(cls.task) - def test_training_completes_quickly(self): import time train_ds, _, test_ds = split_by_patient( - self.sample_dataset, [0.6, 0.0, 0.4], seed=0 + _shared_sample_dataset, [0.6, 0.0, 0.4], seed=0 ) train_dl = get_dataloader( train_ds, batch_size=8, shuffle=True ) model = LogisticRegression( - dataset=self.sample_dataset, + dataset=_shared_sample_dataset, embedding_dim=64, ) model.feature_keys = ["demographics", "age", "los"] @@ -1181,15 +1392,9 @@ def test_training_completes_quickly(self): self.assertGreater(elapsed, 0, "Training should take some time") def test_multiple_splits_complete_quickly(self): - split_by_patient( - self.sample_dataset, - [0.6, 0.0, 0.4], - seed=0, - ) - for split_seed in range(2): train_ds, _, _ = split_by_patient( - self.sample_dataset, + _shared_sample_dataset, [0.6, 0.0, 0.4], seed=split_seed, ) @@ -1198,7 +1403,7 @@ def test_multiple_splits_complete_quickly(self): ) model = LogisticRegression( - dataset=self.sample_dataset, + dataset=_shared_sample_dataset, embedding_dim=64, ) model.feature_keys = ["demographics", "age", "los"] @@ -1244,5 +1449,40 @@ def test_patient_row_count_matches_cross_product(self): shutil.rmtree(tmp, ignore_errors=True) +class TestCuratedSyntheticGrid(unittest.TestCase): + """Curated synthetic: fixed small CSV (integration test data).""" + + def test_curated_csv_row_counts(self): + tmp = tempfile.mkdtemp(prefix="ama_curated_") + try: + _write_curated_synthetic_mimic3_for_tests(tmp) + for name in ("PATIENTS", "ADMISSIONS", "ICUSTAYS"): + with gzip.open(Path(tmp) / f"{name}.csv.gz", "rt") as f: + n = len(f.readlines()) - 1 + self.assertEqual(n, CURATED_SYNTHETIC_N, name) + finally: + shutil.rmtree(tmp, ignore_errors=True) + + def test_curated_task_label_counts(self): + self.assertEqual(len(_shared_sample_dataset), CURATED_SYNTHETIC_N) + + def _ama_int(x): + if torch.is_tensor(x): + return int(x.item()) + return int(x) + + labels = [ + _ama_int(_shared_sample_dataset[i]["ama"]) + for i in range(CURATED_SYNTHETIC_N) + ] + self.assertEqual(sum(labels), CURATED_SYNTHETIC_AMA_POSITIVE) + self.assertEqual( + labels.count(0), CURATED_SYNTHETIC_AMA_NEGATIVE, + ) + self.assertEqual( + labels.count(1), CURATED_SYNTHETIC_AMA_POSITIVE, + ) + + if __name__ == "__main__": unittest.main() From c1ad08954a6c241e498f448c5666528115a50048 Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Wed, 15 Apr 2026 20:15:08 -0500 Subject: [PATCH 07/10] fixed docstrings and function comments --- ...mic3_ama_prediction_logistic_regression.py | 526 ++++++++++++------ examples/mimic3_ama_prediction_rnn.py | 306 ++++++---- pyhealth/tasks/ama_prediction.py | 63 ++- tests/core/test_mimic3_ama_prediction.py | 500 ++++++++--------- 4 files changed, 823 insertions(+), 572 deletions(-) diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index a0b807e40..bb64e6233 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -1,42 +1,48 @@ -"""Ablation study for AMA discharge prediction on MIMIC-III. +"""Ablation study for MIMIC-III Against-Medical-Advice (AMA) discharge prediction. -This script demonstrates the AMAPredictionMIMIC3 task with three feature -ablations and evaluates model fairness using AUROC across demographic -subgroups (race, age, insurance). A logistic regression classifier is -trained on the extracted features to analyze how demographic information -affects prediction of against-medical-advice (AMA) discharge. +This script demonstrates ``AMAPredictionMIMIC3`` with three feature ablations, +trains a ``LogisticRegression`` model on processed samples, and evaluates how +demographic and administrative features affect AMA discharge prediction and +fairness. Labels come from ``discharge_location``; inputs follow the task +``input_schema`` / ``output_schema`` (multi_hot, tensor, binary). -Paper: Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. -"Racial Disparities and Mistrust in End-of-Life Care." Machine Learning -for Healthcare Conference, PMLR 106:211-235, 2018. +Paper: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." Machine Learning + for Healthcare Conference, PMLR 106:211-235, 2018. Ablation configurations tested: - 1. BASELINE: demographics (gender, insurance) + age + los - 2. BASELINE+RACE: adds normalized ethnicity feature - 3. BASELINE+RACE+SUBSTANCE: adds substance use diagnosis flag + 1. BASELINE: ``feature_keys`` = demographics (gender, insurance) + age + + length of stay (LOS). + 2. BASELINE+RACE: adds normalized ethnicity as the ``race`` feature. + 3. BASELINE+RACE+SUBSTANCE: adds ``has_substance_use`` from admission + diagnosis text. Results: - For each baseline, we report: + For each baseline configuration we report: - Overall AUROC averaged over N random 60/40 train/test splits - - Subgroup performance (AUROC) stratified by: - * Race (White, Black, Hispanic, Asian, Native American, Other) - * Age Group (Young 18-44, Middle 45-64, Senior 65+) - * Insurance (Public, Private, Self Pay) - - Fairness metrics per subgroup: - * Demographic Parity: % predicted AMA per group - * Equal Opportunity: True Positive Rate per group - These reveal disparities in model behavior across demographics. - -Usage (synthetic exhaustive grid -- default when ``--root`` is omitted): - python examples/mimic3_ama_prediction_logistic_regression.py - -Usage (synthetic random demo): - python examples/mimic3_ama_prediction_logistic_regression.py \\ + (patient-level ``split_by_patient``). + - Subgroup AUROC by race, age band (Young / Middle / Senior), and + insurance category. + - Fairness-style summaries: demographic parity (% predicted AMA per + group) and equal opportunity (TPR per group). + +Usage: + # Default: synthetic exhaustive grid when ``--root`` is omitted + # (``--data-source auto`` with no root -> synthetic). + cd /path/to/PyHealth && \\ + python examples/mimic3_ama_prediction_logistic_regression.py + + # Synthetic random cohort (faster than exhaustive) + cd /path/to/PyHealth && \\ + python examples/mimic3_ama_prediction_logistic_regression.py \\ --data-source synthetic --synthetic-mode random --patients 200 -Usage (real MIMIC-III; same as ``--root /path`` with ``--data-source auto``): - python examples/mimic3_ama_prediction_logistic_regression.py \\ - --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 + # Full MIMIC-III 1.4 on disk (replace root with your PhysioNet path) + cd /path/to/PyHealth && \\ + python examples/mimic3_ama_prediction_logistic_regression.py \\ + --data-source real --root /path/to/mimic-iii/1.4 \\ + --splits 100 --epochs 10 """ @@ -47,7 +53,7 @@ import time from datetime import datetime, timedelta from pathlib import Path -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -86,6 +92,8 @@ def generate_synthetic_mimic3( seed: RNG seed (random mode only). mode: ``"exhaustive"`` or ``"random"``. """ + # Writes PATIENTS.csv.gz / ADMISSIONS.csv.gz / ICUSTAYS.csv.gz under ``root`` + # with columns compatible with ``MIMIC3Dataset`` and ``AMAPredictionMIMIC3``. root_path = Path(root) root_path.mkdir(parents=True, exist_ok=True) @@ -141,6 +149,15 @@ def generate_synthetic_mimic3( ] def write_csv_gz(filename: str, data: List[dict]) -> None: + """Serialize ``data`` rows to ``{filename}.gz`` as CSV (gzip). + + Args: + filename: Base name without ``.gz`` (e.g. ``"PATIENTS.csv"``). + data: List of row dicts matching MIMIC-III column names. + + Returns: + None (writes file under ``root_path``). + """ df = pd.DataFrame(data) filepath = root_path / f"{filename}.gz" with gzip.open(filepath, "wt") as f: @@ -161,50 +178,73 @@ def append_visit( diagnosis: str, day_offset: int, ) -> int: - """Append one patient + admission + icustay; return next icustay_id.""" + """Append one synthetic patient row, one admission, and one ICU stay. + + Args: + subject_id: MIMIC ``subject_id``. + hadm_id: Hospital admission id. + icustay_id: ICU stay id; incremented on success. + gender: Patient gender (``M`` / ``F``). + age_years: Approximate age used to synthesize ``dob``. + ethnicity: Raw MIMIC ethnicity string. + insurance_raw: Raw insurance (may be ``None`` for ``Other`` path). + admission_type: MIMIC ``admission_type`` (e.g. ``NEWBORN``). + discharge_loc: ``discharge_location`` (AMA vs non-AMA). + diagnosis: Free-text ``diagnosis`` for substance-use flag. + day_offset: Day index to spread ``admittime`` values. + + Returns: + Next ``icustay_id`` if an ICU row was written; else unchanged id. + """ dob = datetime(2000, 1, 1) - timedelta(days=int(age_years * 365)) - patients_data.append({ - "subject_id": subject_id, - "gender": gender, - "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), - "dod": None, - "dod_hosp": None, - "dod_ssn": None, - "expire_flag": 0, - }) + patients_data.append( + { + "subject_id": subject_id, + "gender": gender, + "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + } + ) admit_time = datetime(2150, 1, 1) + timedelta(days=day_offset) discharge_time = admit_time + timedelta(days=7) - admissions_data.append({ - "subject_id": subject_id, - "hadm_id": hadm_id, - "admission_type": admission_type, - "admission_location": "EMERGENCY ROOM ADMIT", - "insurance": insurance_raw, - "language": "ENGLISH", - "religion": "CHRISTIAN", - "marital_status": "SINGLE", - "ethnicity": ethnicity, - "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "diagnosis": diagnosis, - "discharge_location": discharge_loc, - "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), - "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "hospital_expire_flag": 1 if discharge_loc == "EXPIRED" else 0, - }) + admissions_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance_raw, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 if discharge_loc == "EXPIRED" else 0, + } + ) icu_intime = admit_time + timedelta(hours=2) icu_outtime = discharge_time - timedelta(hours=2) if icu_intime < icu_outtime: - icustays_data.append({ - "subject_id": subject_id, - "hadm_id": hadm_id, - "icustay_id": icustay_id, - "first_careunit": "MICU", - "last_careunit": "MICU", - "dbsource": "metavision", - "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), - "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), - }) + icustays_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + } + ) return icustay_id + 1 return icustay_id @@ -317,18 +357,21 @@ def append_visit( ) dob = datetime(2000, 1, 1) - timedelta(days=age_at_visit * 365) - patients_data.append({ - "subject_id": subject_id, - "gender": gender, - "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), - "dod": None, - "dod_hosp": None, - "dod_ssn": None, - "expire_flag": 0, - }) + patients_data.append( + { + "subject_id": subject_id, + "gender": gender, + "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + } + ) n_admissions = max( - 1, int(np.random.poisson(avg_admissions_per_patient)), + 1, + int(np.random.poisson(avg_admissions_per_patient)), ) for j in range(n_admissions): admit_time = datetime(2150, 1, 1) + timedelta(days=int(j * 100)) @@ -356,26 +399,26 @@ def append_visit( np.random.randint(0, len(diagnoses_other)) ] - admissions_data.append({ - "subject_id": subject_id, - "hadm_id": hadm_id, - "admission_type": admission_type, - "admission_location": "EMERGENCY ROOM ADMIT", - "insurance": insurance, - "language": "ENGLISH", - "religion": "CHRISTIAN", - "marital_status": "SINGLE", - "ethnicity": ethnicity, - "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "diagnosis": diagnosis, - "discharge_location": discharge_loc, - "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), - "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), - "hospital_expire_flag": 1 - if discharge_loc == "EXPIRED" - else 0, - }) + admissions_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 if discharge_loc == "EXPIRED" else 0, + } + ) icu_intime = admit_time + timedelta( hours=int(np.random.randint(0, 12)), @@ -385,16 +428,18 @@ def append_visit( ) if icu_intime < icu_outtime: - icustays_data.append({ - "subject_id": subject_id, - "hadm_id": hadm_id, - "icustay_id": icustay_id, - "first_careunit": "MICU", - "last_careunit": "MICU", - "dbsource": "metavision", - "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), - "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), - }) + icustays_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + } + ) icustay_id += 1 hadm_id += 1 @@ -415,7 +460,11 @@ def append_visit( "BASELINE": ["demographics", "age", "los"], "BASELINE+RACE": ["demographics", "age", "los", "race"], "BASELINE+RACE+SUBSTANCE": [ - "demographics", "age", "los", "race", "has_substance_use", + "demographics", + "age", + "los", + "race", + "has_substance_use", ], } @@ -424,13 +473,27 @@ def append_visit( # Helpers -- demographics lookup # ------------------------------------------------------------------ -def _build_demographics_lookup(dataset, task): - """Run the task on every patient and collect raw demographic info. - Returns a dict mapping ``(patient_id, visit_id)`` to a dict with - keys ``race``, ``age``, and ``insurance``. +def _build_demographics_lookup( + dataset: Any, + task: AMAPredictionMIMIC3, +) -> Dict[Tuple[str, str], Dict[str, Any]]: + """Build a post-hoc lookup for fairness reporting (not model inputs). + + Re-runs ``task`` on each patient from ``dataset`` to recover string + race label, scalar age, and insurance token for each visit. Keys match + batch ``patient_id`` / ``visit_id`` used when aligning predictions. + + Args: + dataset: ``MIMIC3Dataset`` (or compatible) with ``iter_patients()``. + task: ``AMAPredictionMIMIC3`` instance (same as used in ``set_task``). + + Returns: + Map ``(patient_id_str, visit_id_str)`` -> ``{"race", "age", + "insurance"}``. Types match raw task outputs (race string without + ``race:`` prefix, age float, insurance category string). """ - lookup = {} + lookup: Dict[Tuple[str, str], Dict[str, Any]] = {} for patient in dataset.iter_patients(): for sample in task(patient): pid = str(sample["patient_id"]) @@ -443,12 +506,22 @@ def _build_demographics_lookup(dataset, task): insurance = t.split(":", 1)[1] break lookup[(pid, vid)] = { - "race": race, "age": age, "insurance": insurance, + "race": race, + "age": age, + "insurance": insurance, } return lookup -def _age_group(age): +def _age_group(age: float) -> str: + """Map continuous age to Boag-style coarse age band labels. + + Args: + age: Age in years (typically from task sample ``age[0]``). + + Returns: + One of ``"Young (18-44)"``, ``"Middle (45-64)"``, ``"Senior (65+)"``. + """ if age < 45: return "Young (18-44)" if age < 65: @@ -460,8 +533,25 @@ def _age_group(age): # Helpers -- inference with demographic labels # ------------------------------------------------------------------ -def _get_predictions(model, dataloader, lookup): - """Run model on *dataloader*, return predictions + subgroup labels.""" + +def _get_predictions( + model: Any, + dataloader: Any, + lookup: Dict[Tuple[str, str], Dict[str, Any]], +) -> Tuple[Any, Any, Dict[str, Any]]: + """Forward pass: collect probabilities, labels, and subgroup tags. + + Args: + model: Trained ``LogisticRegression`` (or compatible ``**batch``). + dataloader: PyHealth dataloader yielding batches with ``patient_id``, + ``visit_id``, and feature tensors. + lookup: Output of :func:`_build_demographics_lookup`. + + Returns: + Tuple ``(y_prob, y_true, groups)`` as ``numpy.ndarray`` vectors for + ``y_prob`` / ``y_true``, and ``groups`` mapping attribute name -> + array of string subgroup labels (Race / Age Group / Insurance). + """ model.eval() all_probs, all_labels = [], [] all_races, all_ages, all_ins = [], [], [] @@ -499,7 +589,17 @@ def _get_predictions(model, dataloader, lookup): # Helpers -- safe metrics # ------------------------------------------------------------------ -def _safe_auroc(y, p): + +def _safe_auroc(y: Any, p: Any) -> float: + """Area under ROC; returns NaN if only one class or sklearn errors. + + Args: + y: Binary labels (1d array-like). + p: Predicted probabilities (1d array-like, same length as ``y``). + + Returns: + AUROC scalar, or ``float("nan")`` when undefined. + """ if len(np.unique(y)) < 2: return float("nan") try: @@ -512,39 +612,83 @@ def _safe_auroc(y, p): # Single split # ------------------------------------------------------------------ -def _create_model(sample_dataset, feature_keys, embedding_dim=128): - """Create a LogisticRegression with the requested feature subset.""" + +def _create_model( + sample_dataset: Any, + feature_keys: List[str], + embedding_dim: int = 128, +) -> LogisticRegression: + """Instantiate ``LogisticRegression`` restricted to ``feature_keys``. + + Args: + sample_dataset: Task output from ``dataset.set_task`` (provides + vocab / feature metadata for embeddings). + feature_keys: Subset of ``AMAPredictionMIMIC3`` input schema keys + (e.g. ``["demographics","age","los"]`` for baseline ablation). + embedding_dim: Linear embedding width per feature block. + + Returns: + Model with ``fc`` resized to ``len(feature_keys) * embedding_dim``. + """ model = LogisticRegression( - dataset=sample_dataset, embedding_dim=embedding_dim, + dataset=sample_dataset, + embedding_dim=embedding_dim, ) model.feature_keys = list(feature_keys) output_size = model.get_output_size() model.fc = torch.nn.Linear( - len(feature_keys) * embedding_dim, output_size, + len(feature_keys) * embedding_dim, + output_size, ) return model -def _run_single_split(sample_dataset, feature_keys, lookup, - seed, epochs, batch_size=32): - """Train + evaluate one 60/40 split. Returns metrics dict or None.""" +def _run_single_split( + sample_dataset: Any, + feature_keys: List[str], + lookup: Dict[Tuple[str, str], Dict[str, Any]], + seed: int, + epochs: int, + batch_size: int = 32, +) -> Optional[Dict[str, Any]]: + """Train one ``LogisticRegression`` on a patient-level 60/40 split. + + Args: + sample_dataset: ``SampleDataset`` from ``set_task(AMAPredictionMIMIC3)``. + feature_keys: Model input keys (must exist in each batch). + lookup: Demographics map for subgroup metrics. + seed: RNG seed for ``split_by_patient``. + epochs: SGD epochs for ``Trainer.train``. + batch_size: Minibatch size for train/test loaders. + + Returns: + Dict with ``"auroc"`` (overall) and ``"subgroups"`` nested metrics, + or ``None`` if training fails. + """ + # Patient-level split keeps all visits for a subject in one partition. train_ds, _, test_ds = split_by_patient( - sample_dataset, [0.6, 0.0, 0.4], seed=seed, + sample_dataset, + [0.6, 0.0, 0.4], + seed=seed, ) train_dl = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) test_dl = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + # Model only sees tensors for keys in ``feature_keys``; label ``ama``. model = _create_model(sample_dataset, feature_keys) trainer = Trainer(model=model) try: trainer.train( - train_dataloader=train_dl, val_dataloader=None, - epochs=epochs, monitor=None, + train_dataloader=train_dl, + val_dataloader=None, + epochs=epochs, + monitor=None, ) except Exception as exc: print(f" train failed: {exc}") return None + # Fairness slices use ``lookup`` (not part of the forward batch tensors). y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) threshold = 0.5 y_pred = (y_prob >= threshold).astype(int) @@ -564,8 +708,7 @@ def _run_single_split(sample_dataset, feature_keys, lookup, subgroup[attr_name][grp] = { "auroc": _safe_auroc(yt, yp), "pct_pred": float(yd.mean()) * 100, - "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 - else float("nan"), + "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 else float("nan"), "n": n, } @@ -579,18 +722,32 @@ def _run_single_split(sample_dataset, feature_keys, lookup, # Aggregation # ------------------------------------------------------------------ -def _nanmean(lst): + +def _nanmean(lst: List[float]) -> float: + """Mean over finite values; ignores NaNs.""" v = [x for x in lst if not np.isnan(x)] - return np.mean(v) if v else float("nan") + return float(np.mean(v)) if v else float("nan") -def _nanstd(lst): +def _nanstd(lst: List[float]) -> float: + """Std dev over finite values; ignores NaNs.""" v = [x for x in lst if not np.isnan(x)] - return np.std(v) if v else float("nan") + return float(np.std(v)) if v else float("nan") + +def _aggregate( + results: List[Optional[Dict[str, Any]]], +) -> Optional[Dict[str, Any]]: + """Pool per-split metric dicts into mean/std summaries. -def _aggregate(results): - """Aggregate per-split metrics into means and stds.""" + Args: + results: List of outputs from :func:`_run_single_split` (may + contain ``None`` for failed splits). + + Returns: + Nested dict with ``auroc_mean``, ``auroc_std``, and per-subgroup + means/stds; ``None`` if every split failed. + """ valid = [r for r in results if r is not None] if not valid: return None @@ -638,11 +795,24 @@ def _aggregate(results): # Pretty-printing # ------------------------------------------------------------------ -def _fmt(val, digits=4): + +def _fmt(val: float, digits: int = 4) -> str: + """Format ``val`` for console tables; show ``N/A`` for NaN.""" return "N/A" if np.isnan(val) else f"{val:.{digits}f}" -def _print_results(name, feature_keys, agg): +def _print_results( + name: str, + feature_keys: List[str], + agg: Optional[Dict[str, Any]], +) -> None: + """Pretty-print one ablation block (overall, subgroup AUROC, fairness). + + Args: + name: Baseline label (e.g. ``"BASELINE+RACE"``). + feature_keys: Feature list used for that run. + agg: Aggregated metrics from :func:`_aggregate`, or ``None``. + """ w = 70 print(f"\n{'=' * w}") print(f" {name} (LogisticRegression)") @@ -656,29 +826,29 @@ def _print_results(name, feature_keys, agg): ci_lo = agg["auroc_mean"] - 1.96 * agg["auroc_std"] ci_hi = agg["auroc_mean"] + 1.96 * agg["auroc_std"] print(f"\n 1. Overall Performance ({agg['n']} splits)") - print(f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" - f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})") + print( + f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" + f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})" + ) - print(f"\n 2. Subgroup Performance") + print("\n 2. Subgroup Performance") for attr, grps in agg["subgroups"].items(): print(f" {attr}:") print(f" {'Group':<20} {'AUROC':>15} {'n_avg':>7}") - print(f" {'-'*42}") + print(f" {'-' * 42}") for grp, m in grps.items(): a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") - print(f"\n 3. Fairness Metrics") - print(f" Demographic Parity (% Predicted AMA):") + print("\n 3. Fairness Metrics") + print(" Demographic Parity (% Predicted AMA):") for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['pct_pred_mean'],2)}%" - for g, m in grps.items()] + parts = [f"{g}: {_fmt(m['pct_pred_mean'], 2)}%" for g, m in grps.items()] print(f" {attr}: {', '.join(parts)}") - print(f" Equal Opportunity (True Positive Rate):") + print(" Equal Opportunity (True Positive Rate):") for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['tpr_mean'],2)}%" - for g, m in grps.items()] + parts = [f"{g}: {_fmt(m['tpr_mean'], 2)}%" for g, m in grps.items()] print(f" {attr}: {', '.join(parts)}") @@ -686,7 +856,21 @@ def _print_results(name, feature_keys, agg): # Main # ------------------------------------------------------------------ -def main(): + +def main() -> None: + """CLI entry: load data, ``set_task``, build lookup, run all baselines. + + Pipeline: + 1. Resolve synthetic vs real ``root`` and LitData ``cache_dir``. + 2. ``MIMIC3Dataset`` -> ``AMAPredictionMIMIC3`` -> ``SampleDataset``. + 3. Demographics lookup for fairness slices (not passed to the model). + 4. For each entry in ``BASELINES``, repeat ``--splits`` train/eval + loops with the corresponding ``feature_keys`` (task I/O schema + unchanged; model uses a subset of keys). + + Side effects: + Prints logs; may write temp CSV/cache directories. + """ parser = argparse.ArgumentParser( description="AMA prediction ablation -- LogisticRegression", ) @@ -787,7 +971,7 @@ def main(): cache_dir=cache_dir, dev=args.dev, ) - print(f" Loaded in {time.time()-t0:.1f}s") + print(f" Loaded in {time.time() - t0:.1f}s") dataset.stats() print("\n[2/4] Applying AMA task...") @@ -801,13 +985,12 @@ def main(): print(" The dataset contains no AMA-positive cases.") print(" For synthetic data: this is expected if AMA rate is low.") print(" For synthetic random mode with more patients:") - print(" python examples/" - "mimic3_ama_prediction_logistic_regression.py \\") - print(" --data-source synthetic --synthetic-mode random " - "--patients 500\n") + print(" python examples/mimic3_ama_prediction_logistic_regression.py \\") + print( + " --data-source synthetic --synthetic-mode random --patients 500\n" + ) print(" For real MIMIC-III:") - print(" python examples/" - "mimic3_ama_prediction_logistic_regression.py \\") + print(" python examples/mimic3_ama_prediction_logistic_regression.py \\") print(" --data-source real --root /path/to/mimic-iii/1.4") print("\nDone.") return @@ -816,10 +999,12 @@ def main(): print("\n[3/4] Building demographics lookup...") t0 = time.time() lookup = _build_demographics_lookup(dataset, task) - print(f" {len(lookup)} entries in {time.time()-t0:.1f}s") + print(f" {len(lookup)} entries in {time.time() - t0:.1f}s") - print(f"\n[4/4] Running ablation ({len(BASELINES)} baselines " - f"x {args.splits} splits)...\n") + print( + f"\n[4/4] Running ablation ({len(BASELINES)} baselines " + f"x {args.splits} splits)...\n" + ) t_total = time.time() for name, feature_keys in BASELINES.items(): @@ -827,21 +1012,26 @@ def main(): for i in range(args.splits): t0 = time.time() res = _run_single_split( - sample_dataset, feature_keys, lookup, - seed=i, epochs=args.epochs, + sample_dataset, + feature_keys, + lookup, + seed=i, + epochs=args.epochs, ) elapsed = time.time() - t0 if res is not None: - print(f" [{name}] split {i+1:3d}/{args.splits}: " - f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)") + print( + f" [{name}] split {i + 1:3d}/{args.splits}: " + f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)" + ) else: - print(f" [{name}] split {i+1:3d}/{args.splits}: FAILED") + print(f" [{name}] split {i + 1:3d}/{args.splits}: FAILED") split_results.append(res) agg = _aggregate(split_results) _print_results(name, feature_keys, agg) - print(f"\nTotal time: {time.time()-t_total:.1f}s") + print(f"\nTotal time: {time.time() - t_total:.1f}s") print("Done.") diff --git a/examples/mimic3_ama_prediction_rnn.py b/examples/mimic3_ama_prediction_rnn.py index 6729255a9..9760227c8 100644 --- a/examples/mimic3_ama_prediction_rnn.py +++ b/examples/mimic3_ama_prediction_rnn.py @@ -1,34 +1,40 @@ -"""AMA Prediction -- RNN Ablation with Fairness Analysis. +"""Ablation study for MIMIC-III Against-Medical-Advice (AMA) discharge prediction (RNN). -Reproduces the Against-Medical-Advice discharge prediction from: +This script demonstrates ``AMAPredictionMIMIC3`` with three ``feature_keys`` +ablations (baseline, +race, +substance), training PyHealth's ``RNN`` so the +mapping from task tensors to logits is non-linear. One visit +per sequence yields a compact recurrent block over the selected features. +Labels and task schemas match ``AMAPredictionMIMIC3`` (AMA from discharge +location; ``input_schema`` / ``output_schema`` unchanged). - Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; Ghassemi, M. +Paper: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. "Racial Disparities and Mistrust in End-of-Life Care." Machine Learning for Healthcare Conference, PMLR, 2018. -This script uses PyHealth's RNN model instead of LogisticRegression. -For the paper's demographic-only baselines the RNN degenerates to a -single-step recurrence (effectively a non-linear transform), which -lets us directly compare the impact of model capacity vs. the -LogisticRegression ablation in the companion script. - -For each baseline the script reports: - 1. Overall AUROC averaged over N random 60/40 splits. - 2. Subgroup performance (AUROC) sliced by Race, Age Group, - and Insurance Type. - 3. Fairness metrics per subgroup: - - Demographic Parity = % predicted AMA (P(Y_hat=1 | Group=g)) - - Equal Opportunity = True Positive Rate (P(Y_hat=1 | Y=1, Group=g)) - -Synthetic data is generated by the same helper as the LogisticRegression -example (``generate_synthetic_mimic3`` in ``mimic3_ama_prediction_ -logistic_regression.py``). - -Usage (synthetic exhaustive grid, default): - python examples/mimic3_ama_prediction_rnn.py --data-source synthetic - -Usage (real MIMIC-III): - python examples/mimic3_ama_prediction_rnn.py \\ +Ablation configurations tested: + 1. BASELINE: demographics + age + LOS. + 2. BASELINE+RACE: adds ``race`` (normalized ethnicity). + 3. BASELINE+RACE+SUBSTANCE: adds ``has_substance_use`` from diagnosis text. + +Results: + For each baseline configuration we report: + - Overall AUROC over N random 60/40 patient-level splits. + - Subgroup AUROC by race, age group, and insurance. + - Demographic parity and equal-opportunity (TPR) tables per subgroup. + +Synthetic data: + ``generate_synthetic_mimic3`` is imported from + ``examples/mimic3_ama_prediction_logistic_regression.py`` (single shared + implementation of the CSV writer). + +Usage: + # Synthetic exhaustive grid (default when no ``--root``) + cd /path/to/PyHealth && python examples/mimic3_ama_prediction_rnn.py \\ + --data-source synthetic + + # Real MIMIC-III 1.4 (set ``--root`` to your extract path) + cd /path/to/PyHealth && python examples/mimic3_ama_prediction_rnn.py \\ --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 """ @@ -38,6 +44,7 @@ import tempfile import time from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch @@ -52,7 +59,8 @@ "mimic3_ama_prediction_logistic_regression.py" ) _spec = importlib.util.spec_from_file_location( - "mimic3_ama_lr_example", _LR_EXAMPLE, + "mimic3_ama_lr_example", + _LR_EXAMPLE, ) _lr_mod = importlib.util.module_from_spec(_spec) assert _spec.loader is not None @@ -63,7 +71,11 @@ "BASELINE": ["demographics", "age", "los"], "BASELINE+RACE": ["demographics", "age", "los", "race"], "BASELINE+RACE+SUBSTANCE": [ - "demographics", "age", "los", "race", "has_substance_use", + "demographics", + "age", + "los", + "race", + "has_substance_use", ], } @@ -72,13 +84,21 @@ # Helpers -- demographics lookup # ------------------------------------------------------------------ -def _build_demographics_lookup(dataset, task): - """Run the task on every patient and collect raw demographic info. - Returns a dict mapping ``(patient_id, visit_id)`` to a dict with - keys ``race``, ``age``, and ``insurance``. +def _build_demographics_lookup( + dataset: Any, + task: AMAPredictionMIMIC3, +) -> Dict[Tuple[str, str], Dict[str, Any]]: + """Map each visit to race, age, and insurance for subgroup metrics. + + Args: + dataset: ``MIMIC3Dataset`` with ``iter_patients()``. + task: Same ``AMAPredictionMIMIC3`` used with ``set_task``. + + Returns: + ``(patient_id, visit_id)`` strings -> ``{"race","age","insurance"}``. """ - lookup = {} + lookup: Dict[Tuple[str, str], Dict[str, Any]] = {} for patient in dataset.iter_patients(): for sample in task(patient): pid = str(sample["patient_id"]) @@ -91,12 +111,15 @@ def _build_demographics_lookup(dataset, task): insurance = t.split(":", 1)[1] break lookup[(pid, vid)] = { - "race": race, "age": age, "insurance": insurance, + "race": race, + "age": age, + "insurance": insurance, } return lookup -def _age_group(age): +def _age_group(age: float) -> str: + """Map age in years to Young / Middle / Senior subgroup labels.""" if age < 45: return "Young (18-44)" if age < 65: @@ -108,8 +131,22 @@ def _age_group(age): # Helpers -- inference with demographic labels # ------------------------------------------------------------------ -def _get_predictions(model, dataloader, lookup): - """Run model on *dataloader*, return predictions + subgroup labels.""" + +def _get_predictions( + model: Any, + dataloader: Any, + lookup: Dict[Tuple[str, str], Dict[str, Any]], +) -> Tuple[Any, Any, Dict[str, Any]]: + """Evaluate ``model`` on ``dataloader``; attach subgroup labels. + + Args: + model: Trained ``RNN`` accepting batch kwargs. + dataloader: Test loader with ``patient_id`` / ``visit_id``. + lookup: Demographics map from :func:`_build_demographics_lookup`. + + Returns: + ``(y_prob, y_true, groups)`` numpy arrays / group dict. + """ model.eval() all_probs, all_labels = [], [] all_races, all_ages, all_ins = [], [], [] @@ -147,7 +184,9 @@ def _get_predictions(model, dataloader, lookup): # Helpers -- safe metrics # ------------------------------------------------------------------ -def _safe_auroc(y, p): + +def _safe_auroc(y: Any, p: Any) -> float: + """ROC-AUC with guards for single-class slices.""" if len(np.unique(y)) < 2: return float("nan") try: @@ -160,9 +199,24 @@ def _safe_auroc(y, p): # Single split # ------------------------------------------------------------------ -def _create_model(sample_dataset, feature_keys, - embedding_dim=128, hidden_dim=64): - """Create an RNN with the requested feature subset.""" + +def _create_model( + sample_dataset: Any, + feature_keys: List[str], + embedding_dim: int = 128, + hidden_dim: int = 64, +) -> RNN: + """Build ``RNN`` over ``feature_keys`` (ablation-controlled inputs). + + Args: + sample_dataset: ``SampleDataset`` after ``set_task``. + feature_keys: Keys from ``AMAPredictionMIMIC3.input_schema``. + embedding_dim: Token embedding size. + hidden_dim: RNN hidden size (``fc`` input is ``len(keys)*hidden_dim``). + + Returns: + Configured ``RNN`` instance. + """ model = RNN( dataset=sample_dataset, embedding_dim=embedding_dim, @@ -171,26 +225,50 @@ def _create_model(sample_dataset, feature_keys, model.feature_keys = list(feature_keys) output_size = model.get_output_size() model.fc = torch.nn.Linear( - len(feature_keys) * hidden_dim, output_size, + len(feature_keys) * hidden_dim, + output_size, ) return model -def _run_single_split(sample_dataset, feature_keys, lookup, - seed, epochs, batch_size=32): - """Train + evaluate one 60/40 split. Returns metrics dict or None.""" +def _run_single_split( + sample_dataset: Any, + feature_keys: List[str], + lookup: Dict[Tuple[str, str], Dict[str, Any]], + seed: int, + epochs: int, + batch_size: int = 32, +) -> Optional[Dict[str, Any]]: + """One patient-level split: train RNN, then test + fairness stats. + + Args: + sample_dataset: AMA task samples. + feature_keys: Subset of input schema keys for this ablation. + lookup: Demographics for post-hoc slicing. + seed: Split RNG seed. + epochs: Training epochs. + batch_size: Loader batch size. + + Returns: + Metric dict or ``None`` on training failure. + """ train_ds, _, test_ds = split_by_patient( - sample_dataset, [0.6, 0.0, 0.4], seed=seed, + sample_dataset, + [0.6, 0.0, 0.4], + seed=seed, ) train_dl = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) test_dl = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + # ``feature_keys`` selects which task outputs the RNN reads for this run. model = _create_model(sample_dataset, feature_keys) trainer = Trainer(model=model) try: trainer.train( - train_dataloader=train_dl, val_dataloader=None, - epochs=epochs, monitor=None, + train_dataloader=train_dl, + val_dataloader=None, + epochs=epochs, + monitor=None, ) except Exception as exc: print(f" train failed: {exc}") @@ -202,6 +280,7 @@ def _run_single_split(sample_dataset, feature_keys, lookup, overall_auroc = _safe_auroc(y_true, y_prob) + # Nested: attribute name -> group label -> AUROC, % predicted AMA, TPR, n. subgroup = {} for attr_name, attr_vals in groups.items(): subgroup[attr_name] = {} @@ -215,8 +294,7 @@ def _run_single_split(sample_dataset, feature_keys, lookup, subgroup[attr_name][grp] = { "auroc": _safe_auroc(yt, yp), "pct_pred": float(yd.mean()) * 100, - "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 - else float("nan"), + "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 else float("nan"), "n": n, } @@ -230,18 +308,23 @@ def _run_single_split(sample_dataset, feature_keys, lookup, # Aggregation # ------------------------------------------------------------------ -def _nanmean(lst): + +def _nanmean(lst: List[float]) -> float: + """Mean ignoring NaNs.""" v = [x for x in lst if not np.isnan(x)] - return np.mean(v) if v else float("nan") + return float(np.mean(v)) if v else float("nan") -def _nanstd(lst): +def _nanstd(lst: List[float]) -> float: + """Std ignoring NaNs.""" v = [x for x in lst if not np.isnan(x)] - return np.std(v) if v else float("nan") + return float(np.std(v)) if v else float("nan") -def _aggregate(results): - """Aggregate per-split metrics into means and stds.""" +def _aggregate( + results: List[Optional[Dict[str, Any]]], +) -> Optional[Dict[str, Any]]: + """Mean/std of overall AUROC and per-subgroup metrics across splits.""" valid = [r for r in results if r is not None] if not valid: return None @@ -289,11 +372,18 @@ def _aggregate(results): # Pretty-printing # ------------------------------------------------------------------ -def _fmt(val, digits=4): + +def _fmt(val: float, digits: int = 4) -> str: + """Human-readable float or ``N/A`` for NaN.""" return "N/A" if np.isnan(val) else f"{val:.{digits}f}" -def _print_results(name, feature_keys, agg): +def _print_results( + name: str, + feature_keys: List[str], + agg: Optional[Dict[str, Any]], +) -> None: + """Print ablation summary for RNN runs.""" w = 70 print(f"\n{'=' * w}") print(f" {name} (RNN hidden_dim=64)") @@ -307,29 +397,29 @@ def _print_results(name, feature_keys, agg): ci_lo = agg["auroc_mean"] - 1.96 * agg["auroc_std"] ci_hi = agg["auroc_mean"] + 1.96 * agg["auroc_std"] print(f"\n 1. Overall Performance ({agg['n']} splits)") - print(f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" - f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})") + print( + f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" + f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})" + ) - print(f"\n 2. Subgroup Performance") + print("\n 2. Subgroup Performance") for attr, grps in agg["subgroups"].items(): print(f" {attr}:") print(f" {'Group':<20} {'AUROC':>15} {'n_avg':>7}") - print(f" {'-'*42}") + print(f" {'-' * 42}") for grp, m in grps.items(): a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") - print(f"\n 3. Fairness Metrics") - print(f" Demographic Parity (% Predicted AMA):") + print("\n 3. Fairness Metrics") + print(" Demographic Parity (% Predicted AMA):") for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['pct_pred_mean'],2)}%" - for g, m in grps.items()] + parts = [f"{g}: {_fmt(m['pct_pred_mean'], 2)}%" for g, m in grps.items()] print(f" {attr}: {', '.join(parts)}") - print(f" Equal Opportunity (True Positive Rate):") + print(" Equal Opportunity (True Positive Rate):") for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['tpr_mean'],2)}%" - for g, m in grps.items()] + parts = [f"{g}: {_fmt(m['tpr_mean'], 2)}%" for g, m in grps.items()] print(f" {attr}: {', '.join(parts)}") @@ -337,7 +427,19 @@ def _print_results(name, feature_keys, agg): # Main # ------------------------------------------------------------------ -def main(): + +def main() -> None: + """CLI entry: load MIMIC-III (or synthetic CSVs), apply AMA task, train RNN. + + Resolves ``--data-source`` / ``--root`` / ``--synthetic-mode``, builds a + ``MIMIC3Dataset`` and ``SampleDataset`` via ``set_task(AMAPredictionMIMIC3)``, + builds a demographics lookup for subgroup metrics, then for each entry in + ``BASELINES`` runs ``--splits`` train/eval loops. Each loop uses the + listed ``feature_keys`` so the ``RNN`` reads only that subset of task + outputs; the task ``input_schema`` / ``output_schema`` are unchanged. + + Side effects: prints progress; creates temporary data and cache dirs. + """ parser = argparse.ArgumentParser( description="AMA prediction ablation -- RNN", ) @@ -352,23 +454,29 @@ def main(): "--synthetic-mode", choices=("exhaustive", "random"), default="exhaustive", - help="Passed to generate_synthetic_mimic3 (see logistic example).", + help="Synthetic grid: exhaustive cross-product or random; " + "see generate_synthetic_mimic3 docstring in the loaded module.", ) parser.add_argument( "--root", default=None, help="MIMIC-III root (required for --data-source real).", ) - parser.add_argument("--patients", type=int, default=100, - help="Random synthetic mode only.") - parser.add_argument("--seed", type=int, default=42, - help="RNG seed for random synthetic mode.") - parser.add_argument("--splits", type=int, default=5, - help="Number of random 60/40 splits") - parser.add_argument("--epochs", type=int, default=3, - help="Training epochs per split") - parser.add_argument("--dev", action="store_true", - help="Random synthetic: 10 patients only.") + parser.add_argument( + "--patients", type=int, default=100, help="Random synthetic mode only." + ) + parser.add_argument( + "--seed", type=int, default=42, help="RNG seed for random synthetic mode." + ) + parser.add_argument( + "--splits", type=int, default=5, help="Number of random 60/40 splits" + ) + parser.add_argument( + "--epochs", type=int, default=3, help="Training epochs per split" + ) + parser.add_argument( + "--dev", action="store_true", help="Random synthetic: 10 patients only." + ) args = parser.parse_args() if args.data_source == "real" and not args.root: @@ -409,10 +517,12 @@ def main(): print("\n[1/4] Loading dataset...") t0 = time.time() dataset = MIMIC3Dataset( - root=args.root, tables=[], - cache_dir=cache_dir, dev=args.dev, + root=args.root, + tables=[], + cache_dir=cache_dir, + dev=args.dev, ) - print(f" Loaded in {time.time()-t0:.1f}s") + print(f" Loaded in {time.time() - t0:.1f}s") dataset.stats() print("\n[2/4] Applying AMA task...") @@ -433,8 +543,7 @@ def main(): total += len(samples) print(f" Task produced {total} samples (all label=0)") print("\n Re-run with real MIMIC-III:") - print(" python examples/" - "mimic3_ama_prediction_rnn.py \\") + print(" python examples/mimic3_ama_prediction_rnn.py \\") print(" --data-source real --root /path/to/mimic-iii/1.4") print("\nDone.") return @@ -443,10 +552,12 @@ def main(): print("\n[3/4] Building demographics lookup...") t0 = time.time() lookup = _build_demographics_lookup(dataset, task) - print(f" {len(lookup)} entries in {time.time()-t0:.1f}s") + print(f" {len(lookup)} entries in {time.time() - t0:.1f}s") - print(f"\n[4/4] Running ablation ({len(BASELINES)} baselines " - f"x {args.splits} splits)...\n") + print( + f"\n[4/4] Running ablation ({len(BASELINES)} baselines " + f"x {args.splits} splits)...\n" + ) t_total = time.time() for name, feature_keys in BASELINES.items(): @@ -454,21 +565,26 @@ def main(): for i in range(args.splits): t0 = time.time() res = _run_single_split( - sample_dataset, feature_keys, lookup, - seed=i, epochs=args.epochs, + sample_dataset, + feature_keys, + lookup, + seed=i, + epochs=args.epochs, ) elapsed = time.time() - t0 if res is not None: - print(f" [{name}] split {i+1:3d}/{args.splits}: " - f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)") + print( + f" [{name}] split {i + 1:3d}/{args.splits}: " + f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)" + ) else: - print(f" [{name}] split {i+1:3d}/{args.splits}: FAILED") + print(f" [{name}] split {i + 1:3d}/{args.splits}: FAILED") split_results.append(res) agg = _aggregate(split_results) _print_results(name, feature_keys, agg) - print(f"\nTotal time: {time.time()-t_total:.1f}s") + print(f"\nTotal time: {time.time() - t_total:.1f}s") print("Done.") diff --git a/pyhealth/tasks/ama_prediction.py b/pyhealth/tasks/ama_prediction.py index f62b7afa8..8865192cd 100644 --- a/pyhealth/tasks/ama_prediction.py +++ b/pyhealth/tasks/ama_prediction.py @@ -1,7 +1,34 @@ -"""MIMIC-III Against-Medical-Advice (AMA) discharge prediction task. +"""MIMIC-III Against-Medical-Advice (AMA) discharge prediction task module. + +Defines :class:`AMAPredictionMIMIC3`, a :class:`~pyhealth.tasks.base_task.BaseTask` +subclass that emits one sample per eligible ICU admission from MIMIC-III +administrative tables (patients, admissions, icustays). No ICD / procedure / +prescription sequences are required. Module-level helpers normalize race and +insurance strings, parse datetimes, and flag substance-related diagnoses from +free-text ``ADMISSIONS.DIAGNOSIS``. + +Paper: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." Machine Learning + for Healthcare Conference, PMLR 106:211-235, 2018. + +Model-side ablations (task schema is fixed; select ``feature_keys`` on the +model, as in the example scripts): + 1. BASELINE: ``demographics`` (gender + insurance), ``age``, ``los``. + 2. BASELINE+RACE: adds ``race`` (normalized ethnicity tokens). + 3. BASELINE+RACE+SUBSTANCE: adds ``has_substance_use`` (0/1 tensor). + +Task I/O (schemas for ``dataset.set_task``): + - ``input_schema``: ``multi_hot`` demographics and race; ``tensor`` age, + LOS, substance flag. + - ``output_schema``: binary ``ama`` from ``discharge_location``. + +Usage: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks.ama_prediction import AMAPredictionMIMIC3 + >>> dataset = MIMIC3Dataset(root="/path/to/mimic-iii/1.4", tables=[]) + >>> sample_ds = dataset.set_task(AMAPredictionMIMIC3()) -Defines :class:`AMAPredictionMIMIC3` and helpers for Boag et al. 2018-style -demographic baselines (race / substance-use ablations). """ import re @@ -212,21 +239,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: for admission in admissions: if self.exclude_newborns: - admission_type = getattr( - admission, "admission_type", None - ) + admission_type = getattr(admission, "admission_type", None) if admission_type == "NEWBORN": continue # --- Label --- - discharge_location = getattr( - admission, "discharge_location", None - ) - ama_label = ( - 1 - if discharge_location == self.AMA_DISCHARGE_LOCATION - else 0 - ) + discharge_location = getattr(admission, "discharge_location", None) + ama_label = 1 if discharge_location == self.AMA_DISCHARGE_LOCATION else 0 # --- BASELINE demographics (gender + insurance) --- insurance = getattr(admission, "insurance", None) @@ -234,15 +253,11 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: demo_tokens: List[str] = [] if gender: demo_tokens.append(f"gender:{gender}") - demo_tokens.append( - f"insurance:{_normalize_insurance(insurance)}" - ) + demo_tokens.append(f"insurance:{_normalize_insurance(insurance)}") # --- Race (separate feature for ablation) --- ethnicity = getattr(admission, "ethnicity", None) - race_tokens: List[str] = [ - f"race:{_normalize_race(ethnicity)}" - ] + race_tokens: List[str] = [f"race:{_normalize_race(ethnicity)}"] # --- Age (continuous) --- age_years = 0.0 @@ -252,10 +267,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: age_years = ( admit_dt.year - dob.year - - int( - (admit_dt.month, admit_dt.day) - < (dob.month, dob.day) - ) + - int((admit_dt.month, admit_dt.day) < (dob.month, dob.day)) ) age_years = float(min(age_years, 90)) @@ -268,8 +280,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: admit_dt = admission.timestamp if isinstance(admit_dt, datetime): los_days = max( - (dischtime - admit_dt).total_seconds() - / 86400.0, + (dischtime - admit_dt).total_seconds() / 86400.0, 0.0, ) except (ValueError, TypeError): diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py index 1e93e6b43..f08add200 100644 --- a/tests/core/test_mimic3_ama_prediction.py +++ b/tests/core/test_mimic3_ama_prediction.py @@ -1,11 +1,37 @@ -"""Unit and integration tests for AMA discharge prediction on MIMIC-III. - -Aligned with the PyHealth PR checklist ("What makes a good PyHealth PR?", -see course slides): task inherits ``BaseTask`` with explicit schemas; -integration tests use a **five-row** shared synthetic slice so runs stay -fast; mocks cover edge cases without loading full MIMIC-III. End-to-end -usage lives in ``examples/mimic3_ama_prediction_logistic_regression.py`` and -``examples/mimic3_ama_prediction_rnn.py``. +"""Test suite for MIMIC-III Against-Medical-Advice (AMA) discharge prediction. + +This is the automated test module for ``AMAPredictionMIMIC3`` +(``pyhealth.tasks.ama_prediction``). It provides a comprehensive set of checks +covering: + + - Task helpers: race and insurance normalization, substance-use + detection from diagnosis text. + - Task contract: ``task_name``, ``input_schema``, ``output_schema``, and + default flags (e.g. newborn filtering). + - Feature engineering on mock patients: AMA vs non-AMA labels, age and + LOS from timestamps, demographics vs separate ``race`` tokens, + ``has_substance_use``, multi-admission behavior, and schema key sets. + - Ablation-oriented checks: baseline feature presence, label correctness, + and absence of clinical code fields in samples. + - Integration: curated five-row gzipped MIMIC-style CSVs with + ``MIMIC3Dataset`` + ``set_task`` + ``LogisticRegression`` forward passes + and short ``Trainer`` smoke runs. + - Synthetic generator sanity: exhaustive grid patient row count. + +Paper (task motivation): + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." MLHC / PMLR, 2018. + +Usage: + # From the PyHealth repository root (quiet summary) + cd /path/to/PyHealth && python -m unittest tests.core.test_mimic3_ama_prediction -q + + # Verbose per-test output + cd /path/to/PyHealth && python -m unittest tests.core.test_mimic3_ama_prediction -v + + # Run this file directly + cd /path/to/PyHealth && python tests/core/test_mimic3_ama_prediction.py + """ import gc @@ -19,6 +45,7 @@ import warnings from datetime import datetime from pathlib import Path +from typing import Any, Dict, List from unittest.mock import MagicMock import pandas as pd @@ -45,7 +72,8 @@ / "mimic3_ama_prediction_logistic_regression.py" ) _spec = importlib.util.spec_from_file_location( - "mimic3_ama_prediction_example", _EXAMPLE_PATH, + "mimic3_ama_prediction_example", + _EXAMPLE_PATH, ) _example_mod = importlib.util.module_from_spec(_spec) assert _spec.loader is not None @@ -257,7 +285,20 @@ def _write_curated_synthetic_mimic3_for_tests(root: str) -> None: - """Write ``_CURATED_MIMIC3_*`` to gzipped PATIENTS/ADMISSIONS/ICUSTAYS.""" + """Materialize the fixed 5-row MIMIC-III-like CSV.gz bundle for tests. + + Args: + root: Directory receiving ``PATIENTS.csv.gz``, ``ADMISSIONS.csv.gz``, + ``ICUSTAYS.csv.gz`` (column names match ``MIMIC3Dataset`` ingest). + + Returns: + None. + + Note: + Row mix matches ``CURATED_SYNTHETIC_*`` constants so + ``AMAPredictionMIMIC3`` yields both AMA labels and varied demographics + without loading real MIMIC-III. + """ root_path = Path(root) root_path.mkdir(parents=True, exist_ok=True) for name, rows in ( @@ -280,25 +321,36 @@ def _write_curated_synthetic_mimic3_for_tests(root: str) -> None: _shared_sample_dataset = None -def setUpModule(): +def setUpModule() -> None: + """Load one shared ``MIMIC3Dataset`` + ``SampleDataset`` for integration tests. + + Runs once per test module import: writes curated CSVs, builds the base + dataset with ``tables=[]``, applies ``AMAPredictionMIMIC3`` via + ``set_task`` (task ``input_schema`` / ``output_schema`` drive processors). + """ global _shared_tmpdir, _shared_cache_dir global _shared_dataset, _shared_sample_dataset _shared_tmpdir = tempfile.mkdtemp(prefix="ama_shared_") _shared_cache_dir = tempfile.mkdtemp(prefix="ama_shared_cache_") _write_curated_synthetic_mimic3_for_tests(_shared_tmpdir) _shared_dataset = MIMIC3Dataset( - root=_shared_tmpdir, tables=[], cache_dir=_shared_cache_dir, + root=_shared_tmpdir, + tables=[], + cache_dir=_shared_cache_dir, ) _shared_sample_dataset = _shared_dataset.set_task(AMAPredictionMIMIC3()) -def tearDownModule(): +def tearDownModule() -> None: + """Release LitData handles and remove temp CSV/cache directories.""" global _shared_dataset, _shared_sample_dataset if _shared_sample_dataset is not None: _shared_sample_dataset.close() _shared_sample_dataset = None _shared_dataset = None gc.collect() + # Proactively close lingering ``.ld`` chunk readers to avoid shutdown + # ``ResourceWarning`` from litdata after temp dirs are removed. with warnings.catch_warnings(): warnings.simplefilter("ignore") for obj in gc.get_objects(): @@ -324,8 +376,15 @@ def tearDownModule(): # ------------------------------------------------------------------ -def _make_event(**attrs): - """Create a mock event with the given attributes.""" +def _make_event(**attrs: Any) -> MagicMock: + """Build a ``MagicMock`` admission/drug/etc. row for unit tests. + + Args: + **attrs: Field names and values (e.g. ``hadm_id=``, ``icd9_code=``). + + Returns: + Mock object whose attributes match ``attrs``. + """ event = MagicMock() for key, value in attrs.items(): setattr(event, key, value) @@ -333,24 +392,33 @@ def _make_event(**attrs): def _build_patient( - patient_id, - admissions, - diagnoses, - procedures, - prescriptions, - gender="M", - dob="2100-01-01 00:00:00", -): - """Build a mock Patient with ``get_events`` that respects filters. + patient_id: str, + admissions: List[Dict[str, Any]], + diagnoses: List[Dict[str, Any]], + procedures: List[Dict[str, Any]], + prescriptions: List[Dict[str, Any]], + gender: str = "M", + dob: str = "2100-01-01 00:00:00", +) -> MagicMock: + """Build a mock ``Patient`` whose ``get_events`` mirrors PyHealth filters. Args: - patient_id: Patient identifier string. - admissions: List of dicts for admission events. - diagnoses: List of dicts for diagnosis events. - procedures: List of dicts for procedure events. - prescriptions: List of dicts for prescription events. - gender: Gender string for the demographics event. - dob: Date-of-birth string for computing age. + patient_id: ``patient.patient_id`` string. + admissions: kwargs for each admission ``MagicMock``. + diagnoses: kwargs for diagnosis events (``hadm_id`` aligned). + procedures: kwargs for procedure events. + prescriptions: kwargs for prescription events. + gender: Demographics token for ``patients`` event. + dob: DOB string feeding age logic in ``AMAPredictionMIMIC3``. + + Returns: + ``MagicMock`` with ``get_events`` routing ``event_type`` to the proper + list and honoring simple ``filters`` tuples on code tables. + + Note: + Output samples follow ``AMAPredictionMIMIC3.input_schema`` / + ``output_schema`` (same keys as real ``set_task`` tensors after + processors run on CSV-backed data). """ patient = MagicMock() patient.patient_id = patient_id @@ -373,11 +441,7 @@ def _get_events(event_type, filters=None, **kwargs): }.get(event_type, []) if filters: col, op, val = filters[0] - source = [ - e - for e in source - if getattr(e, col, None) == val - ] + source = [e for e in source if getattr(e, col, None) == val] return source patient.get_events = _get_events @@ -401,27 +465,17 @@ class TestNormalizeRace(unittest.TestCase): def test_white(self): self.assertEqual(_normalize_race("WHITE"), "White") - self.assertEqual( - _normalize_race("WHITE - RUSSIAN"), "White" - ) + self.assertEqual(_normalize_race("WHITE - RUSSIAN"), "White") def test_black(self): - self.assertEqual( - _normalize_race("BLACK/AFRICAN AMERICAN"), "Black" - ) + self.assertEqual(_normalize_race("BLACK/AFRICAN AMERICAN"), "Black") def test_hispanic(self): - self.assertEqual( - _normalize_race("HISPANIC OR LATINO"), "Hispanic" - ) - self.assertEqual( - _normalize_race("SOUTH AMERICAN"), "Hispanic" - ) + self.assertEqual(_normalize_race("HISPANIC OR LATINO"), "Hispanic") + self.assertEqual(_normalize_race("SOUTH AMERICAN"), "Hispanic") def test_asian(self): - self.assertEqual( - _normalize_race("ASIAN - CHINESE"), "Asian" - ) + self.assertEqual(_normalize_race("ASIAN - CHINESE"), "Asian") def test_native_american(self): self.assertEqual( @@ -430,27 +484,15 @@ def test_native_american(self): ) def test_other(self): - self.assertEqual( - _normalize_race("UNKNOWN/NOT SPECIFIED"), "Other" - ) + self.assertEqual(_normalize_race("UNKNOWN/NOT SPECIFIED"), "Other") self.assertEqual(_normalize_race(None), "Other") def test_normalize_insurance(self): - self.assertEqual( - _normalize_insurance("Medicare"), "Public" - ) - self.assertEqual( - _normalize_insurance("Medicaid"), "Public" - ) - self.assertEqual( - _normalize_insurance("Government"), "Public" - ) - self.assertEqual( - _normalize_insurance("Private"), "Private" - ) - self.assertEqual( - _normalize_insurance("Self Pay"), "Self Pay" - ) + self.assertEqual(_normalize_insurance("Medicare"), "Public") + self.assertEqual(_normalize_insurance("Medicaid"), "Public") + self.assertEqual(_normalize_insurance("Government"), "Public") + self.assertEqual(_normalize_insurance("Private"), "Private") + self.assertEqual(_normalize_insurance("Self Pay"), "Self Pay") self.assertEqual(_normalize_insurance(None), "Other") @@ -458,42 +500,28 @@ class TestHasSubstanceUse(unittest.TestCase): """Unit tests for the substance-use detection helper.""" def test_alcohol(self): - self.assertEqual( - _has_substance_use("ALCOHOL WITHDRAWAL"), 1 - ) + self.assertEqual(_has_substance_use("ALCOHOL WITHDRAWAL"), 1) def test_opioid(self): - self.assertEqual( - _has_substance_use("OPIOID DEPENDENCE"), 1 - ) + self.assertEqual(_has_substance_use("OPIOID DEPENDENCE"), 1) def test_heroin(self): - self.assertEqual( - _has_substance_use("HEROIN OVERDOSE"), 1 - ) + self.assertEqual(_has_substance_use("HEROIN OVERDOSE"), 1) def test_cocaine(self): - self.assertEqual( - _has_substance_use("COCAINE INTOXICATION"), 1 - ) + self.assertEqual(_has_substance_use("COCAINE INTOXICATION"), 1) def test_drug_withdrawal(self): - self.assertEqual( - _has_substance_use("DRUG WITHDRAWAL SEIZURE"), 1 - ) + self.assertEqual(_has_substance_use("DRUG WITHDRAWAL SEIZURE"), 1) def test_etoh(self): self.assertEqual(_has_substance_use("ETOH ABUSE"), 1) def test_substance(self): - self.assertEqual( - _has_substance_use("SUBSTANCE ABUSE"), 1 - ) + self.assertEqual(_has_substance_use("SUBSTANCE ABUSE"), 1) def test_overdose(self): - self.assertEqual( - _has_substance_use("OVERDOSE - ACCIDENTAL"), 1 - ) + self.assertEqual(_has_substance_use("OVERDOSE - ACCIDENTAL"), 1) def test_negative(self): self.assertEqual(_has_substance_use("PNEUMONIA"), 0) @@ -503,12 +531,8 @@ def test_none(self): self.assertEqual(_has_substance_use(None), 0) def test_case_insensitive(self): - self.assertEqual( - _has_substance_use("alcohol withdrawal"), 1 - ) - self.assertEqual( - _has_substance_use("Heroin Overdose"), 1 - ) + self.assertEqual(_has_substance_use("alcohol withdrawal"), 1) + self.assertEqual(_has_substance_use("Heroin Overdose"), 1) class TestAMAPredictionMIMIC3Schema(unittest.TestCase): @@ -526,14 +550,10 @@ def test_input_schema(self): self.assertEqual(schema["age"], "tensor") self.assertEqual(schema["los"], "tensor") self.assertEqual(schema["race"], "multi_hot") - self.assertEqual( - schema["has_substance_use"], "tensor" - ) + self.assertEqual(schema["has_substance_use"], "tensor") def test_output_schema(self): - self.assertEqual( - AMAPredictionMIMIC3.output_schema, {"ama": "binary"} - ) + self.assertEqual(AMAPredictionMIMIC3.output_schema, {"ama": "binary"}) def test_defaults(self): task = AMAPredictionMIMIC3() @@ -548,11 +568,25 @@ class TestAMAPredictionMIMIC3Mock(unittest.TestCase): All tests complete in milliseconds. """ - def setUp(self): + def setUp(self) -> None: + """Fresh task instance per test (default ``exclude_newborns=True``).""" self.task = AMAPredictionMIMIC3() - def _default_admission(self, hadm_id="100", **overrides): - """Return a standard admission dict.""" + def _default_admission( + self, + hadm_id: str = "100", + **overrides: Any, + ) -> Dict[str, Any]: + """Admission kwargs for ``_build_patient`` with sane AMA-test defaults. + + Args: + hadm_id: Visit id string stored on the mock admission. + **overrides: Fields to replace (e.g. ``discharge_location=``). + + Returns: + Dict passed to ``_make_event`` via ``_build_patient`` admissions + list; keys mirror post-ingest MIMIC attribute names. + """ adm = { "hadm_id": hadm_id, "admission_type": "EMERGENCY", @@ -583,20 +617,12 @@ def test_ama_label_positive(self): admissions=[ self._default_admission( hadm_id="100", - discharge_location=( - "LEFT AGAINST MEDICAL ADVI" - ), + discharge_location=("LEFT AGAINST MEDICAL ADVI"), ), ], - diagnoses=[ - {"hadm_id": "100", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "100", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "100", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "100", "icd9_code": "4019"}], + procedures=[{"hadm_id": "100", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "100", "drug": "Aspirin"}], ) samples = self.task(patient) self.assertEqual(len(samples), 1) @@ -611,15 +637,9 @@ def test_ama_label_negative(self): admissions=[ self._default_admission(hadm_id="200"), ], - diagnoses=[ - {"hadm_id": "200", "icd9_code": "25000"} - ], - procedures=[ - {"hadm_id": "200", "icd9_code": "3995"} - ], - prescriptions=[ - {"hadm_id": "200", "drug": "Metformin"} - ], + diagnoses=[{"hadm_id": "200", "icd9_code": "25000"}], + procedures=[{"hadm_id": "200", "icd9_code": "3995"}], + prescriptions=[{"hadm_id": "200", "drug": "Metformin"}], ) samples = self.task(patient) self.assertEqual(len(samples), 1) @@ -634,9 +654,7 @@ def test_multiple_admissions_mixed_labels(self): self._default_admission( hadm_id="301", admission_type="URGENT", - discharge_location=( - "LEFT AGAINST MEDICAL ADVI" - ), + discharge_location=("LEFT AGAINST MEDICAL ADVI"), timestamp=datetime(2150, 6, 1), dischtime="2150-06-05 10:00:00", ), @@ -674,15 +692,9 @@ def test_exclude_newborns(self): admission_type="NEWBORN", ), ], - diagnoses=[ - {"hadm_id": "700", "icd9_code": "V3000"} - ], - procedures=[ - {"hadm_id": "700", "icd9_code": "9904"} - ], - prescriptions=[ - {"hadm_id": "700", "drug": "Vitamin K"} - ], + diagnoses=[{"hadm_id": "700", "icd9_code": "V3000"}], + procedures=[{"hadm_id": "700", "icd9_code": "9904"}], + prescriptions=[{"hadm_id": "700", "drug": "Vitamin K"}], ) task_ex = AMAPredictionMIMIC3(exclude_newborns=True) self.assertEqual(task_ex(patient), []) @@ -722,15 +734,9 @@ def test_sample_keys(self): timestamp=datetime(2150, 2, 15), ), ], - diagnoses=[ - {"hadm_id": "800", "icd9_code": "4280"} - ], - procedures=[ - {"hadm_id": "800", "icd9_code": "3722"} - ], - prescriptions=[ - {"hadm_id": "800", "drug": "Furosemide"} - ], + diagnoses=[{"hadm_id": "800", "icd9_code": "4280"}], + procedures=[{"hadm_id": "800", "icd9_code": "3722"}], + prescriptions=[{"hadm_id": "800", "drug": "Furosemide"}], ) samples = self.task(patient) self.assertEqual(len(samples), 1) @@ -748,15 +754,9 @@ def test_demographics_baseline_tokens(self): timestamp=datetime(2150, 5, 1), ), ], - diagnoses=[ - {"hadm_id": "1000", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "1000", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1000", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "1000", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1000", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1000", "drug": "Aspirin"}], gender="F", ) samples = self.task(patient) @@ -780,21 +780,13 @@ def test_race_separate_feature(self): timestamp=datetime(2150, 5, 1), ), ], - diagnoses=[ - {"hadm_id": "1001", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "1001", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1001", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "1001", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1001", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1001", "drug": "Aspirin"}], ) samples = self.task(patient) self.assertIn("race", samples[0]) - self.assertEqual( - samples[0]["race"], ["race:Black"] - ) + self.assertEqual(samples[0]["race"], ["race:Black"]) def test_substance_use_positive(self): """Substance-use diagnosis -> has_substance_use=1.""" @@ -807,20 +799,12 @@ def test_substance_use_positive(self): timestamp=datetime(2150, 7, 1), ), ], - diagnoses=[ - {"hadm_id": "1400", "icd9_code": "29181"} - ], - procedures=[ - {"hadm_id": "1400", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1400", "drug": "Lorazepam"} - ], + diagnoses=[{"hadm_id": "1400", "icd9_code": "29181"}], + procedures=[{"hadm_id": "1400", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1400", "drug": "Lorazepam"}], ) samples = self.task(patient) - self.assertEqual( - samples[0]["has_substance_use"], [1.0] - ) + self.assertEqual(samples[0]["has_substance_use"], [1.0]) def test_substance_use_negative(self): """Non-substance diagnosis -> has_substance_use=0.""" @@ -833,20 +817,12 @@ def test_substance_use_negative(self): timestamp=datetime(2150, 8, 1), ), ], - diagnoses=[ - {"hadm_id": "1500", "icd9_code": "486"} - ], - procedures=[ - {"hadm_id": "1500", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1500", "drug": "Levofloxacin"} - ], + diagnoses=[{"hadm_id": "1500", "icd9_code": "486"}], + procedures=[{"hadm_id": "1500", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1500", "drug": "Levofloxacin"}], ) samples = self.task(patient) - self.assertEqual( - samples[0]["has_substance_use"], [0.0] - ) + self.assertEqual(samples[0]["has_substance_use"], [0.0]) def test_age_calculation(self): """Age computed from dob and admission timestamp.""" @@ -859,15 +835,9 @@ def test_age_calculation(self): dischtime="2150-06-20 12:00:00", ), ], - diagnoses=[ - {"hadm_id": "1100", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "1100", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1100", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "1100", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1100", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1100", "drug": "Aspirin"}], dob="2100-01-01 00:00:00", ) samples = self.task(patient) @@ -883,15 +853,9 @@ def test_age_capped_at_90(self): timestamp=datetime(2150, 6, 15), ), ], - diagnoses=[ - {"hadm_id": "1200", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "1200", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1200", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "1200", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1200", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1200", "drug": "Aspirin"}], dob="1850-01-01 00:00:00", ) samples = self.task(patient) @@ -908,20 +872,12 @@ def test_los_calculation(self): dischtime="2150-03-06 08:00:00", ), ], - diagnoses=[ - {"hadm_id": "1300", "icd9_code": "4019"} - ], - procedures=[ - {"hadm_id": "1300", "icd9_code": "3893"} - ], - prescriptions=[ - {"hadm_id": "1300", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "1300", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1300", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1300", "drug": "Aspirin"}], ) samples = self.task(patient) - self.assertAlmostEqual( - samples[0]["los"][0], 5.0, places=2 - ) + self.assertAlmostEqual(samples[0]["los"][0], 5.0, places=2) # ---------------------------------------------------------- # Multi-patient synthetic "dataset" (2 patients) @@ -934,9 +890,7 @@ def test_two_patient_synthetic_dataset(self): admissions=[ self._default_admission( hadm_id="A1", - discharge_location=( - "LEFT AGAINST MEDICAL ADVI" - ), + discharge_location=("LEFT AGAINST MEDICAL ADVI"), ethnicity="HISPANIC OR LATINO", insurance="Medicaid", diagnosis="HEROIN OVERDOSE", @@ -944,15 +898,9 @@ def test_two_patient_synthetic_dataset(self): dischtime="2150-01-03 12:00:00", ), ], - diagnoses=[ - {"hadm_id": "A1", "icd9_code": "96500"} - ], - procedures=[ - {"hadm_id": "A1", "icd9_code": "9604"} - ], - prescriptions=[ - {"hadm_id": "A1", "drug": "Naloxone"} - ], + diagnoses=[{"hadm_id": "A1", "icd9_code": "96500"}], + procedures=[{"hadm_id": "A1", "icd9_code": "9604"}], + prescriptions=[{"hadm_id": "A1", "drug": "Naloxone"}], gender="M", dob="2100-06-15 00:00:00", ) @@ -969,15 +917,9 @@ def test_two_patient_synthetic_dataset(self): dischtime="2150-03-12 08:00:00", ), ], - diagnoses=[ - {"hadm_id": "A2", "icd9_code": "78650"} - ], - procedures=[ - {"hadm_id": "A2", "icd9_code": "8856"} - ], - prescriptions=[ - {"hadm_id": "A2", "drug": "Aspirin"} - ], + diagnoses=[{"hadm_id": "A2", "icd9_code": "78650"}], + procedures=[{"hadm_id": "A2", "icd9_code": "8856"}], + prescriptions=[{"hadm_id": "A2", "drug": "Aspirin"}], gender="F", dob="2090-01-01 00:00:00", ) @@ -996,9 +938,7 @@ def test_two_patient_synthetic_dataset(self): s2 = all_samples[1] self.assertEqual(s2["ama"], 0) self.assertEqual(s2["race"], ["race:White"]) - self.assertIn( - "insurance:Private", s2["demographics"] - ) + self.assertIn("insurance:Private", s2["demographics"]) self.assertEqual(s2["has_substance_use"], [0.0]) self.assertAlmostEqual(s2["age"][0], 60.0, places=0) @@ -1010,8 +950,13 @@ class TestAMAAblationBaselines(unittest.TestCase): different subsets of features via the model's feature_keys parameter. """ - def setUp(self): - """Create a simple test patient with mixed demographics.""" + def setUp(self) -> None: + """Build one multi-visit mock patient covering baseline feature keys. + + Samples include all ``AMAPredictionMIMIC3.input_schema`` fields so + tests can reason about ``feature_keys`` subsets the same way models + do after ``set_task``. + """ self.task = AMAPredictionMIMIC3() self.patient = _build_patient( patient_id="ABLATION_TEST", @@ -1077,8 +1022,7 @@ def test_baseline_race_feature_present(self): race_val = sample["race"][0].split(":", 1)[1] self.assertIn( race_val, - ["White", "Black", "Hispanic", "Asian", - "Native American", "Other"], + ["White", "Black", "Hispanic", "Asian", "Native American", "Other"], ) def test_baseline_substance_use_feature_present(self): @@ -1161,8 +1105,14 @@ def test_baseline_minimal_features(self): sample = samples[0] baseline_keys = { - "demographics", "age", "los", "race", - "has_substance_use", "visit_id", "patient_id", "ama", + "demographics", + "age", + "los", + "race", + "has_substance_use", + "visit_id", + "patient_id", + "ama", } self.assertEqual(set(sample.keys()), baseline_keys) @@ -1248,8 +1198,8 @@ def test_ama_label_binary(self): def test_has_positive_and_negative_labels(self): labels = [sample["ama"] for sample in _shared_sample_dataset] - has_positive = any(l == 1 for l in labels) - has_negative = any(l == 0 for l in labels) + has_positive = any(label == 1 for label in labels) + has_negative = any(label == 0 for label in labels) self.assertTrue( has_positive and has_negative, @@ -1270,22 +1220,16 @@ def _create_model_with_features(self, feature_keys): embedding_dim = model.embedding_model.embedding_layers[ feature_keys[0] ].out_features - model.fc = torch.nn.Linear( - len(feature_keys) * embedding_dim, output_size - ) + model.fc = torch.nn.Linear(len(feature_keys) * embedding_dim, output_size) return model def test_baseline_model_can_be_created(self): - model = self._create_model_with_features( - ["demographics", "age", "los"] - ) + model = self._create_model_with_features(["demographics", "age", "los"]) self.assertIsNotNone(model) self.assertIsNotNone(model.fc) def test_baseline_plus_race_model(self): - model = self._create_model_with_features( - ["demographics", "age", "los", "race"] - ) + model = self._create_model_with_features(["demographics", "age", "los", "race"]) self.assertIsNotNone(model) self.assertIsNotNone(model.fc) @@ -1297,9 +1241,7 @@ def test_baseline_plus_race_plus_substance_model(self): self.assertIsNotNone(model.fc) def test_baseline_forward_pass(self): - model = self._create_model_with_features( - ["demographics", "age", "los"] - ) + model = self._create_model_with_features(["demographics", "age", "los"]) train_ds, _, test_ds = split_by_patient( _shared_sample_dataset, [0.8, 0.0, 0.2], seed=0 @@ -1316,9 +1258,7 @@ def test_baseline_forward_pass(self): self.assertEqual(output["y_prob"].shape[0], len(test_ds)) def test_baseline_plus_race_forward_pass(self): - model = self._create_model_with_features( - ["demographics", "age", "los", "race"] - ) + model = self._create_model_with_features(["demographics", "age", "los", "race"]) train_ds, _, test_ds = split_by_patient( _shared_sample_dataset, [0.8, 0.0, 0.2], seed=0 @@ -1361,9 +1301,7 @@ def test_training_completes_quickly(self): train_ds, _, test_ds = split_by_patient( _shared_sample_dataset, [0.6, 0.0, 0.4], seed=0 ) - train_dl = get_dataloader( - train_ds, batch_size=8, shuffle=True - ) + train_dl = get_dataloader(train_ds, batch_size=8, shuffle=True) model = LogisticRegression( dataset=_shared_sample_dataset, @@ -1374,9 +1312,7 @@ def test_training_completes_quickly(self): embedding_dim = model.embedding_model.embedding_layers[ "demographics" ].out_features - model.fc = torch.nn.Linear( - 3 * embedding_dim, output_size - ) + model.fc = torch.nn.Linear(3 * embedding_dim, output_size) trainer = Trainer(model=model) @@ -1398,9 +1334,7 @@ def test_multiple_splits_complete_quickly(self): [0.6, 0.0, 0.4], seed=split_seed, ) - train_dl = get_dataloader( - train_ds, batch_size=8, shuffle=True - ) + train_dl = get_dataloader(train_ds, batch_size=8, shuffle=True) model = LogisticRegression( dataset=_shared_sample_dataset, @@ -1411,9 +1345,7 @@ def test_multiple_splits_complete_quickly(self): embedding_dim = model.embedding_model.embedding_layers[ "demographics" ].out_features - model.fc = torch.nn.Linear( - 3 * embedding_dim, output_size - ) + model.fc = torch.nn.Linear(3 * embedding_dim, output_size) trainer = Trainer(model=model) trainer.train( @@ -1477,10 +1409,12 @@ def _ama_int(x): ] self.assertEqual(sum(labels), CURATED_SYNTHETIC_AMA_POSITIVE) self.assertEqual( - labels.count(0), CURATED_SYNTHETIC_AMA_NEGATIVE, + labels.count(0), + CURATED_SYNTHETIC_AMA_NEGATIVE, ) self.assertEqual( - labels.count(1), CURATED_SYNTHETIC_AMA_POSITIVE, + labels.count(1), + CURATED_SYNTHETIC_AMA_POSITIVE, ) From 57d4720d796f45c4b216021624288d2740c42207 Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Wed, 15 Apr 2026 21:43:39 -0500 Subject: [PATCH 08/10] cleaned up results for easeier understanding --- ...mic3_ama_prediction_logistic_regression.py | 59 +++++++------------ examples/mimic3_ama_prediction_rnn.py | 28 +-------- tests/core/test_mimic3_ama_prediction.py | 2 +- 3 files changed, 25 insertions(+), 64 deletions(-) diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py index bb64e6233..25710e690 100644 --- a/examples/mimic3_ama_prediction_logistic_regression.py +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -2,8 +2,8 @@ This script demonstrates ``AMAPredictionMIMIC3`` with three feature ablations, trains a ``LogisticRegression`` model on processed samples, and evaluates how -demographic and administrative features affect AMA discharge prediction and -fairness. Labels come from ``discharge_location``; inputs follow the task +demographic and administrative features affect AMA discharge prediction. +Labels come from ``discharge_location``; inputs follow the task ``input_schema`` / ``output_schema`` (multi_hot, tensor, binary). Paper: @@ -24,8 +24,6 @@ (patient-level ``split_by_patient``). - Subgroup AUROC by race, age band (Young / Middle / Senior), and insurance category. - - Fairness-style summaries: demographic parity (% predicted AMA per - group) and equal opportunity (TPR per group). Usage: # Default: synthetic exhaustive grid when ``--root`` is omitted @@ -196,7 +194,13 @@ def append_visit( Returns: Next ``icustay_id`` if an ICU row was written; else unchanged id. """ - dob = datetime(2000, 1, 1) - timedelta(days=int(age_years * 365)) + # Align DOB with ``admittime`` so ``AMAPredictionMIMIC3`` age (year diff + # from PATIENTS.DOB to admission) matches ``age_years``. A fixed DOB + # anchor with 2150 admissions would yield ~150y ages capped at 90, so + # every row mapped to "Senior (65+)" in subgroup reports. + admit_time = datetime(2150, 1, 1) + timedelta(days=day_offset) + discharge_time = admit_time + timedelta(days=7) + dob = admit_time - timedelta(days=int(age_years * 365)) patients_data.append( { "subject_id": subject_id, @@ -208,8 +212,6 @@ def append_visit( "expire_flag": 0, } ) - admit_time = datetime(2150, 1, 1) + timedelta(days=day_offset) - discharge_time = admit_time + timedelta(days=7) admissions_data.append( { "subject_id": subject_id, @@ -355,8 +357,13 @@ def append_visit( age_at_visit = int( np.random.choice([25, 45, 65, 85]) + np.random.randint(-5, 5) ) - dob = datetime(2000, 1, 1) - timedelta(days=age_at_visit * 365) + n_admissions = max( + 1, + int(np.random.poisson(avg_admissions_per_patient)), + ) + first_admit = datetime(2150, 1, 1) + dob = first_admit - timedelta(days=int(age_at_visit * 365)) patients_data.append( { "subject_id": subject_id, @@ -369,10 +376,6 @@ def append_visit( } ) - n_admissions = max( - 1, - int(np.random.poisson(avg_admissions_per_patient)), - ) for j in range(n_admissions): admit_time = datetime(2150, 1, 1) + timedelta(days=int(j * 100)) discharge_time = admit_time + timedelta( @@ -478,7 +481,7 @@ def _build_demographics_lookup( dataset: Any, task: AMAPredictionMIMIC3, ) -> Dict[Tuple[str, str], Dict[str, Any]]: - """Build a post-hoc lookup for fairness reporting (not model inputs). + """Build a post-hoc lookup for subgroup AUROC labels (not model inputs). Re-runs ``task`` on each patient from ``dataset`` to recover string race label, scalar age, and insurance token for each visit. Keys match @@ -688,10 +691,8 @@ def _run_single_split( print(f" train failed: {exc}") return None - # Fairness slices use ``lookup`` (not part of the forward batch tensors). + # Subgroup labels use ``lookup`` (not part of the forward batch tensors). y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) - threshold = 0.5 - y_pred = (y_prob >= threshold).astype(int) overall_auroc = _safe_auroc(y_true, y_prob) @@ -703,12 +704,9 @@ def _run_single_split( n = int(mask.sum()) if n < 2: continue - yt, yp, yd = y_true[mask], y_prob[mask], y_pred[mask] - pos = yt.sum() + yt, yp = y_true[mask], y_prob[mask] subgroup[attr_name][grp] = { "auroc": _safe_auroc(yt, yp), - "pct_pred": float(yd.mean()) * 100, - "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 else float("nan"), "n": n, } @@ -771,21 +769,17 @@ def _aggregate( all_grps.update(r["subgroups"][attr].keys()) for grp in sorted(all_grps): - aurocs, pcts, tprs, ns = [], [], [], [] + aurocs, ns = [], [] for r in valid: m = r["subgroups"].get(attr, {}).get(grp) if m is None: continue aurocs.append(m["auroc"]) - pcts.append(m["pct_pred"]) - tprs.append(m["tpr"]) ns.append(m["n"]) agg["subgroups"][attr][grp] = { "auroc_mean": _nanmean(aurocs), "auroc_std": _nanstd(aurocs), - "pct_pred_mean": _nanmean(pcts), - "tpr_mean": _nanmean(tprs), "n_avg": int(np.mean(ns)) if ns else 0, } return agg @@ -806,7 +800,7 @@ def _print_results( feature_keys: List[str], agg: Optional[Dict[str, Any]], ) -> None: - """Pretty-print one ablation block (overall, subgroup AUROC, fairness). + """Pretty-print one ablation block (overall and subgroup AUROC). Args: name: Baseline label (e.g. ``"BASELINE+RACE"``). @@ -840,17 +834,6 @@ def _print_results( a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") - print("\n 3. Fairness Metrics") - print(" Demographic Parity (% Predicted AMA):") - for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['pct_pred_mean'], 2)}%" for g, m in grps.items()] - print(f" {attr}: {', '.join(parts)}") - - print(" Equal Opportunity (True Positive Rate):") - for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['tpr_mean'], 2)}%" for g, m in grps.items()] - print(f" {attr}: {', '.join(parts)}") - # ------------------------------------------------------------------ # Main @@ -863,7 +846,7 @@ def main() -> None: Pipeline: 1. Resolve synthetic vs real ``root`` and LitData ``cache_dir``. 2. ``MIMIC3Dataset`` -> ``AMAPredictionMIMIC3`` -> ``SampleDataset``. - 3. Demographics lookup for fairness slices (not passed to the model). + 3. Demographics lookup for subgroup AUROC (not passed to the model). 4. For each entry in ``BASELINES``, repeat ``--splits`` train/eval loops with the corresponding ``feature_keys`` (task I/O schema unchanged; model uses a subset of keys). diff --git a/examples/mimic3_ama_prediction_rnn.py b/examples/mimic3_ama_prediction_rnn.py index 9760227c8..1e774cec3 100644 --- a/examples/mimic3_ama_prediction_rnn.py +++ b/examples/mimic3_ama_prediction_rnn.py @@ -21,7 +21,6 @@ For each baseline configuration we report: - Overall AUROC over N random 60/40 patient-level splits. - Subgroup AUROC by race, age group, and insurance. - - Demographic parity and equal-opportunity (TPR) tables per subgroup. Synthetic data: ``generate_synthetic_mimic3`` is imported from @@ -239,7 +238,7 @@ def _run_single_split( epochs: int, batch_size: int = 32, ) -> Optional[Dict[str, Any]]: - """One patient-level split: train RNN, then test + fairness stats. + """One patient-level split: train RNN, then test + subgroup AUROC. Args: sample_dataset: AMA task samples. @@ -275,12 +274,9 @@ def _run_single_split( return None y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) - threshold = 0.5 - y_pred = (y_prob >= threshold).astype(int) overall_auroc = _safe_auroc(y_true, y_prob) - # Nested: attribute name -> group label -> AUROC, % predicted AMA, TPR, n. subgroup = {} for attr_name, attr_vals in groups.items(): subgroup[attr_name] = {} @@ -289,12 +285,9 @@ def _run_single_split( n = int(mask.sum()) if n < 2: continue - yt, yp, yd = y_true[mask], y_prob[mask], y_pred[mask] - pos = yt.sum() + yt, yp = y_true[mask], y_prob[mask] subgroup[attr_name][grp] = { "auroc": _safe_auroc(yt, yp), - "pct_pred": float(yd.mean()) * 100, - "tpr": float(yd[yt == 1].mean()) * 100 if pos > 0 else float("nan"), "n": n, } @@ -348,21 +341,17 @@ def _aggregate( all_grps.update(r["subgroups"][attr].keys()) for grp in sorted(all_grps): - aurocs, pcts, tprs, ns = [], [], [], [] + aurocs, ns = [], [] for r in valid: m = r["subgroups"].get(attr, {}).get(grp) if m is None: continue aurocs.append(m["auroc"]) - pcts.append(m["pct_pred"]) - tprs.append(m["tpr"]) ns.append(m["n"]) agg["subgroups"][attr][grp] = { "auroc_mean": _nanmean(aurocs), "auroc_std": _nanstd(aurocs), - "pct_pred_mean": _nanmean(pcts), - "tpr_mean": _nanmean(tprs), "n_avg": int(np.mean(ns)) if ns else 0, } return agg @@ -411,17 +400,6 @@ def _print_results( a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") - print("\n 3. Fairness Metrics") - print(" Demographic Parity (% Predicted AMA):") - for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['pct_pred_mean'], 2)}%" for g, m in grps.items()] - print(f" {attr}: {', '.join(parts)}") - - print(" Equal Opportunity (True Positive Rate):") - for attr, grps in agg["subgroups"].items(): - parts = [f"{g}: {_fmt(m['tpr_mean'], 2)}%" for g, m in grps.items()] - print(f" {attr}: {', '.join(parts)}") - # ------------------------------------------------------------------ # Main diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py index f08add200..7d975f018 100644 --- a/tests/core/test_mimic3_ama_prediction.py +++ b/tests/core/test_mimic3_ama_prediction.py @@ -15,7 +15,7 @@ and absence of clinical code fields in samples. - Integration: curated five-row gzipped MIMIC-style CSVs with ``MIMIC3Dataset`` + ``set_task`` + ``LogisticRegression`` forward passes - and short ``Trainer`` smoke runs. + and short ``Trainer`` smoke runs (example CLI tables are not asserted). - Synthetic generator sanity: exhaustive grid patient row count. Paper (task motivation): From 85de8dd945625fced37391909c4afd0462b32d5f Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Wed, 15 Apr 2026 21:48:02 -0500 Subject: [PATCH 09/10] removed uneccesary files change --- docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst | 8 +------- docs/api/tasks/pyhealth.tasks.ama_prediction.rst | 6 ------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst index 381357eff..3574fd3d8 100644 --- a/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst +++ b/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst @@ -6,10 +6,4 @@ The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer .. autoclass:: pyhealth.datasets.MIMIC3Dataset :members: :undoc-members: - :show-inheritance: - -.. seealso:: - - Administrative AMA discharge prediction (no ICD / Rx tables required): - :class:`pyhealth.tasks.ama_prediction.AMAPredictionMIMIC3`. - \ No newline at end of file + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.ama_prediction.rst b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst index aaac74de4..83afa05e4 100644 --- a/docs/api/tasks/pyhealth.tasks.ama_prediction.rst +++ b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst @@ -5,12 +5,6 @@ Against-medical-advice (AMA) discharge on MIMIC-III using administrative features only. Use with :class:`pyhealth.datasets.MIMIC3Dataset` and ``tables=[]`` unless you extend the cohort. -.. seealso:: - - * :doc:`../tasks` — task / processor overview (``set_task`` workflow). - * :doc:`../../how_to_contribute` — contribution and PR expectations. - * :doc:`../datasets/pyhealth.datasets.MIMIC3Dataset` — source dataset. - .. autoclass:: pyhealth.tasks.ama_prediction.AMAPredictionMIMIC3 :members: :undoc-members: From b51583bfbb4bab294d41536ee4f310a87b54e0cb Mon Sep 17 00:00:00 2001 From: Madhav Kanda Date: Wed, 15 Apr 2026 22:47:11 -0500 Subject: [PATCH 10/10] Restore MIMIC3Dataset.rst to match upstream master --- docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst b/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst index 3574fd3d8..0b7ec27c0 100644 --- a/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst +++ b/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst @@ -1,4 +1,4 @@ -pyhealth.datasets.MIMIC3Dataset +pyhealth.datasets.MIMIC3Dataset =================================== The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer to `doc `_ for more information. We process this database into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis. @@ -6,4 +6,10 @@ The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer .. autoclass:: pyhealth.datasets.MIMIC3Dataset :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + + + + + + \ No newline at end of file