diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..8d25ced19 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -205,6 +205,7 @@ Available Tasks .. toctree:: :maxdepth: 3 + AMA Prediction (MIMIC-III) Base Task In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding diff --git a/docs/api/tasks/pyhealth.tasks.ama_prediction.rst b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst new file mode 100644 index 000000000..83afa05e4 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ama_prediction.rst @@ -0,0 +1,11 @@ +pyhealth.tasks.ama_prediction +================================= + +Against-medical-advice (AMA) discharge on MIMIC-III using administrative +features only. Use with :class:`pyhealth.datasets.MIMIC3Dataset` and +``tables=[]`` unless you extend the cohort. + +.. autoclass:: pyhealth.tasks.ama_prediction.AMAPredictionMIMIC3 + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic3_ama_prediction_logistic_regression.py b/examples/mimic3_ama_prediction_logistic_regression.py new file mode 100644 index 000000000..25710e690 --- /dev/null +++ b/examples/mimic3_ama_prediction_logistic_regression.py @@ -0,0 +1,1022 @@ +"""Ablation study for MIMIC-III Against-Medical-Advice (AMA) discharge prediction. + +This script demonstrates ``AMAPredictionMIMIC3`` with three feature ablations, +trains a ``LogisticRegression`` model on processed samples, and evaluates how +demographic and administrative features affect AMA discharge prediction. +Labels come from ``discharge_location``; inputs follow the task +``input_schema`` / ``output_schema`` (multi_hot, tensor, binary). + +Paper: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." Machine Learning + for Healthcare Conference, PMLR 106:211-235, 2018. + +Ablation configurations tested: + 1. BASELINE: ``feature_keys`` = demographics (gender, insurance) + age + + length of stay (LOS). + 2. 
BASELINE+RACE: adds normalized ethnicity as the ``race`` feature. + 3. BASELINE+RACE+SUBSTANCE: adds ``has_substance_use`` from admission + diagnosis text. + +Results: + For each baseline configuration we report: + - Overall AUROC averaged over N random 60/40 train/test splits + (patient-level ``split_by_patient``). + - Subgroup AUROC by race, age band (Young / Middle / Senior), and + insurance category. + +Usage: + # Default: synthetic exhaustive grid when ``--root`` is omitted + # (``--data-source auto`` with no root -> synthetic). + cd /path/to/PyHealth && \\ + python examples/mimic3_ama_prediction_logistic_regression.py + + # Synthetic random cohort (faster than exhaustive) + cd /path/to/PyHealth && \\ + python examples/mimic3_ama_prediction_logistic_regression.py \\ + --data-source synthetic --synthetic-mode random --patients 200 + + # Full MIMIC-III 1.4 on disk (replace root with your PhysioNet path) + cd /path/to/PyHealth && \\ + python examples/mimic3_ama_prediction_logistic_regression.py \\ + --data-source real --root /path/to/mimic-iii/1.4 \\ + --splits 100 --epochs 10 + +""" + +import argparse +import gzip +import itertools +import tempfile +import time +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +from sklearn.metrics import roc_auc_score + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import LogisticRegression +from pyhealth.tasks import AMAPredictionMIMIC3 +from pyhealth.trainer import Trainer + + +def generate_synthetic_mimic3( + root: str, + n_patients: int = 50, + avg_admissions_per_patient: int = 2, + seed: int = 42, + mode: str = "exhaustive", +) -> None: + """Write gzipped PATIENTS, ADMISSIONS, and ICUSTAYS CSVs for local demos. 
+ + ``mode="exhaustive"`` (default) emits one patient per element of the + Cartesian product of task-relevant factors so every combination appears + at least once: gender × MIMIC ethnicity string × raw insurance × age + band (maps to Young / Middle / Senior) × AMA vs non-AMA discharge × + substance vs non-substance diagnosis text. A few extra rows cover + NEWBORN filtering, EXPIRED, SNF, and missing insurance (``Other``). + + ``mode="random"`` reproduces the legacy stochastic generator (use for + quick tests via ``--synthetic-mode random``). + + Args: + root: Directory to write CSV files to. + n_patients: Used only when ``mode="random"`` (patient count). + avg_admissions_per_patient: Poisson mean per patient (random mode). + seed: RNG seed (random mode only). + mode: ``"exhaustive"`` or ``"random"``. + """ + # Writes PATIENTS.csv.gz / ADMISSIONS.csv.gz / ICUSTAYS.csv.gz under ``root`` + # with columns compatible with ``MIMIC3Dataset`` and ``AMAPredictionMIMIC3``. + root_path = Path(root) + root_path.mkdir(parents=True, exist_ok=True) + + patients_data: List[dict] = [] + admissions_data: List[dict] = [] + icustays_data: List[dict] = [] + + genders = ["M", "F"] + ethnicities = [ + "WHITE", + "BLACK/AFRICAN AMERICAN", + "HISPANIC OR LATINO", + "ASIAN - CHINESE", + "AMERICAN INDIAN/ALASKA NATIVE", + "UNKNOWN/NOT SPECIFIED", + ] + # Raw insurance values (normalize to Public / Private / Self Pay / Other) + insurances_raw: List[Optional[str]] = [ + "Medicare", + "Medicaid", + "Government", + "Private", + "Self Pay", + None, + ] + admission_types = ["EMERGENCY", "URGENT", "NEWBORN", "ELECTIVE"] + discharge_locations = [ + "HOME", + "SKILLED NURSING FACILITY", + "LONG TERM CARE", + "LEFT AGAINST MEDICAL ADVI", + "EXPIRED", + ] + diagnoses_substance = [ + "ALCOHOL WITHDRAWAL", + "OPIOID DEPENDENCE", + "HEROIN OVERDOSE", + "COCAINE INTOXICATION", + "DRUG WITHDRAWAL SEIZURE", + "ETOH ABUSE", + "SUBSTANCE ABUSE", + "OVERDOSE - ACCIDENTAL", + ] + diagnoses_other = [ + 
+        "PNEUMONIA",
+        "ACUTE MYOCARDIAL INFARCTION",
+        "CHEST PAIN",
+        "CONGESTIVE HEART FAILURE",
+        "SEPSIS",
+        "ACUTE KIDNEY INJURY",
+        "ACUTE RESPIRATORY FAILURE",
+        "ASPIRATION",
+    ]
+
+    def write_csv_gz(filename: str, data: List[dict]) -> None:
+        """Serialize ``data`` rows to ``{filename}.gz`` as CSV (gzip).
+
+        Args:
+            filename: Base name without ``.gz`` (e.g. ``"PATIENTS.csv"``).
+            data: List of row dicts matching MIMIC-III column names.
+
+        Returns:
+            None (writes file under ``root_path``).
+        """
+        df = pd.DataFrame(data)
+        filepath = root_path / f"{filename}.gz"
+        with gzip.open(filepath, "wt") as f:
+            df.to_csv(f, index=False)
+        print(f"  Created {filename}.gz ({len(data)} rows)")
+
+    def append_visit(
+        subject_id: int,
+        hadm_id: int,
+        icustay_id: int,
+        *,
+        gender: str,
+        age_years: int,
+        ethnicity: str,
+        insurance_raw: Optional[str],
+        admission_type: str,
+        discharge_loc: str,
+        diagnosis: str,
+        day_offset: int,
+    ) -> int:
+        """Append one synthetic patient row, one admission, and one ICU stay.
+
+        Args:
+            subject_id: MIMIC ``subject_id``.
+            hadm_id: Hospital admission id.
+            icustay_id: ICU stay id; incremented on success.
+            gender: Patient gender (``M`` / ``F``).
+            age_years: Approximate age used to synthesize ``dob``.
+            ethnicity: Raw MIMIC ethnicity string.
+            insurance_raw: Raw insurance (may be ``None`` for ``Other`` path).
+            admission_type: MIMIC ``admission_type`` (e.g. ``NEWBORN``).
+            discharge_loc: ``discharge_location`` (AMA vs non-AMA).
+            diagnosis: Free-text ``diagnosis`` for substance-use flag.
+            day_offset: Day index to spread ``admittime`` values.
+
+        Returns:
+            Next ``icustay_id`` if an ICU row was written; else unchanged id.
+        """
+        # Align DOB with ``admittime`` so ``AMAPredictionMIMIC3`` age (year diff
+        # from PATIENTS.DOB to admission) matches ``age_years``. A fixed DOB
+        # anchor with 2150 admissions would yield ~150y ages capped at 90, so
+        # every row mapped to "Senior (65+)" in subgroup reports.
+ admit_time = datetime(2150, 1, 1) + timedelta(days=day_offset) + discharge_time = admit_time + timedelta(days=7) + dob = admit_time - timedelta(days=int(age_years * 365)) + patients_data.append( + { + "subject_id": subject_id, + "gender": gender, + "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + } + ) + admissions_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance_raw, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 if discharge_loc == "EXPIRED" else 0, + } + ) + icu_intime = admit_time + timedelta(hours=2) + icu_outtime = discharge_time - timedelta(hours=2) + if icu_intime < icu_outtime: + icustays_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + } + ) + return icustay_id + 1 + return icustay_id + + if mode == "exhaustive": + # Ages map to Young (18-44) / Middle (45-64) / Senior (65+). 
+ age_bands = [30, 52, 72] + discharge_ama = ["HOME", "LEFT AGAINST MEDICAL ADVI"] + diagnosis_texts = ["PNEUMONIA", "ALCOHOL WITHDRAWAL"] + combos = itertools.product( + genders, + ethnicities, + insurances_raw, + age_bands, + discharge_ama, + diagnosis_texts, + ) + subject_id = 1 + hadm_id = 100 + icustay_id = 1000 + for idx, (gender, eth, ins_raw, age_y, disch, diag) in enumerate(combos): + icustay_id = append_visit( + subject_id, + hadm_id, + icustay_id, + gender=gender, + age_years=age_y, + ethnicity=eth, + insurance_raw=ins_raw, + admission_type="EMERGENCY", + discharge_loc=disch, + diagnosis=diag, + day_offset=idx % 500, + ) + subject_id += 1 + hadm_id += 1 + + # Extra coverage: SNF discharge, EXPIRED, NEWBORN (skipped by task). + exhaustive_grid_n = ( + len(genders) + * len(ethnicities) + * len(insurances_raw) + * len(age_bands) + * len(discharge_ama) + * len(diagnosis_texts) + ) + extra_rows = ( + ( + "M", + 45, + "WHITE", + "Private", + "EMERGENCY", + "SKILLED NURSING FACILITY", + "SEPSIS", + ), + ( + "F", + 55, + "BLACK/AFRICAN AMERICAN", + "Medicaid", + "EMERGENCY", + "EXPIRED", + "CHEST PAIN", + ), + ( + "M", + 28, + "HISPANIC OR LATINO", + "Private", + "NEWBORN", + "HOME", + "PNEUMONIA", + ), + ) + for k, extra in enumerate(extra_rows): + g, age_y, eth, ins, adm_type, disch, diag = extra + icustay_id = append_visit( + subject_id, + hadm_id, + icustay_id, + gender=g, + age_years=age_y, + ethnicity=eth, + insurance_raw=ins, + admission_type=adm_type, + discharge_loc=disch, + diagnosis=diag, + day_offset=(exhaustive_grid_n + k) % 500, + ) + subject_id += 1 + hadm_id += 1 + + print( + f"Generating exhaustive synthetic MIMIC-III in {root_path} " + f"({len(patients_data)} patients, cross-product + edge rows)...", + ) + elif mode == "random": + np.random.seed(seed) + subject_id = 1 + hadm_id = 100 + icustay_id = 1000 + insurances = ["Medicare", "Medicaid", "Private", "Self Pay", "Government"] + for i in range(n_patients): + gender = genders[i % len(genders)] 
+ ethnicity = ethnicities[i % len(ethnicities)] + insurance = insurances[i % len(insurances)] + + age_at_visit = int( + np.random.choice([25, 45, 65, 85]) + np.random.randint(-5, 5) + ) + + n_admissions = max( + 1, + int(np.random.poisson(avg_admissions_per_patient)), + ) + first_admit = datetime(2150, 1, 1) + dob = first_admit - timedelta(days=int(age_at_visit * 365)) + patients_data.append( + { + "subject_id": subject_id, + "gender": gender, + "dob": dob.strftime("%Y-%m-%d %H:%M:%S"), + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + } + ) + + for j in range(n_admissions): + admit_time = datetime(2150, 1, 1) + timedelta(days=int(j * 100)) + discharge_time = admit_time + timedelta( + days=int(np.random.randint(1, 30)), + ) + + admission_type = admission_types[(i + j) % len(admission_types)] + + if np.random.random() < 0.15: + discharge_loc = "LEFT AGAINST MEDICAL ADVI" + elif np.random.random() < 0.05: + discharge_loc = "EXPIRED" + else: + discharge_loc = discharge_locations[ + (i + j) % (len(discharge_locations) - 2) + ] + + if np.random.random() < 0.2: + diagnosis = diagnoses_substance[ + np.random.randint(0, len(diagnoses_substance)) + ] + else: + diagnosis = diagnoses_other[ + np.random.randint(0, len(diagnoses_other)) + ] + + admissions_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "admission_type": admission_type, + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": insurance, + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": ethnicity, + "edregtime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "edouttime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "diagnosis": diagnosis, + "discharge_location": discharge_loc, + "dischtime": discharge_time.strftime("%Y-%m-%d %H:%M:%S"), + "admittime": admit_time.strftime("%Y-%m-%d %H:%M:%S"), + "hospital_expire_flag": 1 if discharge_loc == "EXPIRED" else 0, + } + ) + + icu_intime = admit_time + timedelta( + 
hours=int(np.random.randint(0, 12)), + ) + icu_outtime = discharge_time - timedelta( + hours=int(np.random.randint(0, 12)), + ) + + if icu_intime < icu_outtime: + icustays_data.append( + { + "subject_id": subject_id, + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": icu_intime.strftime("%Y-%m-%d %H:%M:%S"), + "outtime": icu_outtime.strftime("%Y-%m-%d %H:%M:%S"), + } + ) + icustay_id += 1 + + hadm_id += 1 + + subject_id += 1 + + print(f"Generating random synthetic MIMIC-III in {root_path}...") + else: + raise ValueError(f"Unknown mode {mode!r}; use 'exhaustive' or 'random'.") + + write_csv_gz("PATIENTS.csv", patients_data) + write_csv_gz("ADMISSIONS.csv", admissions_data) + write_csv_gz("ICUSTAYS.csv", icustays_data) + print("Done.") + + +BASELINES = { + "BASELINE": ["demographics", "age", "los"], + "BASELINE+RACE": ["demographics", "age", "los", "race"], + "BASELINE+RACE+SUBSTANCE": [ + "demographics", + "age", + "los", + "race", + "has_substance_use", + ], +} + + +# ------------------------------------------------------------------ +# Helpers -- demographics lookup +# ------------------------------------------------------------------ + + +def _build_demographics_lookup( + dataset: Any, + task: AMAPredictionMIMIC3, +) -> Dict[Tuple[str, str], Dict[str, Any]]: + """Build a post-hoc lookup for subgroup AUROC labels (not model inputs). + + Re-runs ``task`` on each patient from ``dataset`` to recover string + race label, scalar age, and insurance token for each visit. Keys match + batch ``patient_id`` / ``visit_id`` used when aligning predictions. + + Args: + dataset: ``MIMIC3Dataset`` (or compatible) with ``iter_patients()``. + task: ``AMAPredictionMIMIC3`` instance (same as used in ``set_task``). + + Returns: + Map ``(patient_id_str, visit_id_str)`` -> ``{"race", "age", + "insurance"}``. 
Types match raw task outputs (race string without + ``race:`` prefix, age float, insurance category string). + """ + lookup: Dict[Tuple[str, str], Dict[str, Any]] = {} + for patient in dataset.iter_patients(): + for sample in task(patient): + pid = str(sample["patient_id"]) + vid = str(sample["visit_id"]) + race = sample["race"][0].split(":", 1)[1] + age = sample["age"][0] + insurance = "Other" + for t in sample["demographics"]: + if t.startswith("insurance:"): + insurance = t.split(":", 1)[1] + break + lookup[(pid, vid)] = { + "race": race, + "age": age, + "insurance": insurance, + } + return lookup + + +def _age_group(age: float) -> str: + """Map continuous age to Boag-style coarse age band labels. + + Args: + age: Age in years (typically from task sample ``age[0]``). + + Returns: + One of ``"Young (18-44)"``, ``"Middle (45-64)"``, ``"Senior (65+)"``. + """ + if age < 45: + return "Young (18-44)" + if age < 65: + return "Middle (45-64)" + return "Senior (65+)" + + +# ------------------------------------------------------------------ +# Helpers -- inference with demographic labels +# ------------------------------------------------------------------ + + +def _get_predictions( + model: Any, + dataloader: Any, + lookup: Dict[Tuple[str, str], Dict[str, Any]], +) -> Tuple[Any, Any, Dict[str, Any]]: + """Forward pass: collect probabilities, labels, and subgroup tags. + + Args: + model: Trained ``LogisticRegression`` (or compatible ``**batch``). + dataloader: PyHealth dataloader yielding batches with ``patient_id``, + ``visit_id``, and feature tensors. + lookup: Output of :func:`_build_demographics_lookup`. + + Returns: + Tuple ``(y_prob, y_true, groups)`` as ``numpy.ndarray`` vectors for + ``y_prob`` / ``y_true``, and ``groups`` mapping attribute name -> + array of string subgroup labels (Race / Age Group / Insurance). 
+ """ + model.eval() + all_probs, all_labels = [], [] + all_races, all_ages, all_ins = [], [], [] + + with torch.no_grad(): + for batch in dataloader: + output = model(**batch) + all_probs.append(output["y_prob"].detach().cpu()) + all_labels.append(output["y_true"].detach().cpu()) + + pids = batch["patient_id"] + vids = batch["visit_id"] + if isinstance(vids, torch.Tensor): + vids = vids.tolist() + if isinstance(pids, torch.Tensor): + pids = pids.tolist() + + for pid, vid in zip(pids, vids): + info = lookup.get((str(pid), str(vid)), {}) + all_races.append(info.get("race", "Other")) + all_ages.append(_age_group(info.get("age", 0.0))) + all_ins.append(info.get("insurance", "Other")) + + y_prob = torch.cat(all_probs).numpy().ravel() + y_true = torch.cat(all_labels).numpy().ravel() + groups = { + "Race": np.array(all_races), + "Age Group": np.array(all_ages), + "Insurance": np.array(all_ins), + } + return y_prob, y_true, groups + + +# ------------------------------------------------------------------ +# Helpers -- safe metrics +# ------------------------------------------------------------------ + + +def _safe_auroc(y: Any, p: Any) -> float: + """Area under ROC; returns NaN if only one class or sklearn errors. + + Args: + y: Binary labels (1d array-like). + p: Predicted probabilities (1d array-like, same length as ``y``). + + Returns: + AUROC scalar, or ``float("nan")`` when undefined. + """ + if len(np.unique(y)) < 2: + return float("nan") + try: + return roc_auc_score(y, p) + except ValueError: + return float("nan") + + +# ------------------------------------------------------------------ +# Single split +# ------------------------------------------------------------------ + + +def _create_model( + sample_dataset: Any, + feature_keys: List[str], + embedding_dim: int = 128, +) -> LogisticRegression: + """Instantiate ``LogisticRegression`` restricted to ``feature_keys``. 
+ + Args: + sample_dataset: Task output from ``dataset.set_task`` (provides + vocab / feature metadata for embeddings). + feature_keys: Subset of ``AMAPredictionMIMIC3`` input schema keys + (e.g. ``["demographics","age","los"]`` for baseline ablation). + embedding_dim: Linear embedding width per feature block. + + Returns: + Model with ``fc`` resized to ``len(feature_keys) * embedding_dim``. + """ + model = LogisticRegression( + dataset=sample_dataset, + embedding_dim=embedding_dim, + ) + model.feature_keys = list(feature_keys) + output_size = model.get_output_size() + model.fc = torch.nn.Linear( + len(feature_keys) * embedding_dim, + output_size, + ) + return model + + +def _run_single_split( + sample_dataset: Any, + feature_keys: List[str], + lookup: Dict[Tuple[str, str], Dict[str, Any]], + seed: int, + epochs: int, + batch_size: int = 32, +) -> Optional[Dict[str, Any]]: + """Train one ``LogisticRegression`` on a patient-level 60/40 split. + + Args: + sample_dataset: ``SampleDataset`` from ``set_task(AMAPredictionMIMIC3)``. + feature_keys: Model input keys (must exist in each batch). + lookup: Demographics map for subgroup metrics. + seed: RNG seed for ``split_by_patient``. + epochs: SGD epochs for ``Trainer.train``. + batch_size: Minibatch size for train/test loaders. + + Returns: + Dict with ``"auroc"`` (overall) and ``"subgroups"`` nested metrics, + or ``None`` if training fails. + """ + # Patient-level split keeps all visits for a subject in one partition. + train_ds, _, test_ds = split_by_patient( + sample_dataset, + [0.6, 0.0, 0.4], + seed=seed, + ) + train_dl = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + test_dl = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + + # Model only sees tensors for keys in ``feature_keys``; label ``ama``. 
+ model = _create_model(sample_dataset, feature_keys) + trainer = Trainer(model=model) + try: + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=epochs, + monitor=None, + ) + except Exception as exc: + print(f" train failed: {exc}") + return None + + # Subgroup labels use ``lookup`` (not part of the forward batch tensors). + y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) + + overall_auroc = _safe_auroc(y_true, y_prob) + + subgroup = {} + for attr_name, attr_vals in groups.items(): + subgroup[attr_name] = {} + for grp in sorted(set(attr_vals)): + mask = attr_vals == grp + n = int(mask.sum()) + if n < 2: + continue + yt, yp = y_true[mask], y_prob[mask] + subgroup[attr_name][grp] = { + "auroc": _safe_auroc(yt, yp), + "n": n, + } + + return { + "auroc": overall_auroc, + "subgroups": subgroup, + } + + +# ------------------------------------------------------------------ +# Aggregation +# ------------------------------------------------------------------ + + +def _nanmean(lst: List[float]) -> float: + """Mean over finite values; ignores NaNs.""" + v = [x for x in lst if not np.isnan(x)] + return float(np.mean(v)) if v else float("nan") + + +def _nanstd(lst: List[float]) -> float: + """Std dev over finite values; ignores NaNs.""" + v = [x for x in lst if not np.isnan(x)] + return float(np.std(v)) if v else float("nan") + + +def _aggregate( + results: List[Optional[Dict[str, Any]]], +) -> Optional[Dict[str, Any]]: + """Pool per-split metric dicts into mean/std summaries. + + Args: + results: List of outputs from :func:`_run_single_split` (may + contain ``None`` for failed splits). + + Returns: + Nested dict with ``auroc_mean``, ``auroc_std``, and per-subgroup + means/stds; ``None`` if every split failed. 
+ """ + valid = [r for r in results if r is not None] + if not valid: + return None + + agg = { + "n": len(valid), + "auroc_mean": _nanmean([r["auroc"] for r in valid]), + "auroc_std": _nanstd([r["auroc"] for r in valid]), + } + + all_attrs = set() + for r in valid: + all_attrs.update(r["subgroups"].keys()) + + agg["subgroups"] = {} + for attr in sorted(all_attrs): + agg["subgroups"][attr] = {} + all_grps = set() + for r in valid: + if attr in r["subgroups"]: + all_grps.update(r["subgroups"][attr].keys()) + + for grp in sorted(all_grps): + aurocs, ns = [], [] + for r in valid: + m = r["subgroups"].get(attr, {}).get(grp) + if m is None: + continue + aurocs.append(m["auroc"]) + ns.append(m["n"]) + + agg["subgroups"][attr][grp] = { + "auroc_mean": _nanmean(aurocs), + "auroc_std": _nanstd(aurocs), + "n_avg": int(np.mean(ns)) if ns else 0, + } + return agg + + +# ------------------------------------------------------------------ +# Pretty-printing +# ------------------------------------------------------------------ + + +def _fmt(val: float, digits: int = 4) -> str: + """Format ``val`` for console tables; show ``N/A`` for NaN.""" + return "N/A" if np.isnan(val) else f"{val:.{digits}f}" + + +def _print_results( + name: str, + feature_keys: List[str], + agg: Optional[Dict[str, Any]], +) -> None: + """Pretty-print one ablation block (overall and subgroup AUROC). + + Args: + name: Baseline label (e.g. ``"BASELINE+RACE"``). + feature_keys: Feature list used for that run. + agg: Aggregated metrics from :func:`_aggregate`, or ``None``. + """ + w = 70 + print(f"\n{'=' * w}") + print(f" {name} (LogisticRegression)") + print(f" Features: {feature_keys}") + print(f"{'=' * w}") + + if agg is None: + print(" No valid results.\n") + return + + ci_lo = agg["auroc_mean"] - 1.96 * agg["auroc_std"] + ci_hi = agg["auroc_mean"] + 1.96 * agg["auroc_std"] + print(f"\n 1. 
Overall Performance ({agg['n']} splits)") + print( + f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}" + f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})" + ) + + print("\n 2. Subgroup Performance") + for attr, grps in agg["subgroups"].items(): + print(f" {attr}:") + print(f" {'Group':<20} {'AUROC':>15} {'n_avg':>7}") + print(f" {'-' * 42}") + for grp, m in grps.items(): + a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}" + print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}") + + +# ------------------------------------------------------------------ +# Main +# ------------------------------------------------------------------ + + +def main() -> None: + """CLI entry: load data, ``set_task``, build lookup, run all baselines. + + Pipeline: + 1. Resolve synthetic vs real ``root`` and LitData ``cache_dir``. + 2. ``MIMIC3Dataset`` -> ``AMAPredictionMIMIC3`` -> ``SampleDataset``. + 3. Demographics lookup for subgroup AUROC (not passed to the model). + 4. For each entry in ``BASELINES``, repeat ``--splits`` train/eval + loops with the corresponding ``feature_keys`` (task I/O schema + unchanged; model uses a subset of keys). + + Side effects: + Prints logs; may write temp CSV/cache directories. + """ + parser = argparse.ArgumentParser( + description="AMA prediction ablation -- LogisticRegression", + ) + parser.add_argument( + "--data-source", + choices=("auto", "synthetic", "real"), + default="auto", + help="auto: use --root if set, else synthetic. synthetic: always local CSVs. " + "real: require --root to MIMIC-III on disk.", + ) + parser.add_argument( + "--synthetic-mode", + choices=("exhaustive", "random"), + default="exhaustive", + help="exhaustive: full cross-product of demographics×AMA×substance (default). 
" + "random: stochastic demo (--patients applies).", + ) + parser.add_argument( + "--root", + default=None, + help="MIMIC-III root directory (required for --data-source real, or sets " + "auto mode to real when provided).", + ) + parser.add_argument( + "--patients", + type=int, + default=100, + help="Synthetic patient count (random mode only; exhaustive ignores this).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="RNG seed for random synthetic mode.", + ) + parser.add_argument( + "--splits", + type=int, + default=5, + help="Number of random 60/40 splits (default 5 for speed with synthetic data)", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + help="Training epochs per split (default 3 for speed with synthetic data)", + ) + parser.add_argument( + "--dev", + action="store_true", + help="With random synthetic: only 10 patients. Exhaustive grid unchanged.", + ) + args = parser.parse_args() + + if args.data_source == "real" and not args.root: + parser.error("--data-source real requires --root /path/to/mimic-iii/1.4") + + use_synthetic = args.data_source == "synthetic" or ( + args.data_source == "auto" and args.root is None + ) + if args.data_source == "synthetic" and args.root: + print( + "Note: --root is ignored with --data-source synthetic " + "(data is written to a temporary directory).\n", + ) + + cache_dir = tempfile.mkdtemp(prefix="ama_lr_") + + if use_synthetic: + print("[Setup] Generating synthetic MIMIC-III dataset...") + data_dir = tempfile.mkdtemp(prefix="synthetic_mimic3_") + n_patients = ( + 10 if args.dev and args.synthetic_mode == "random" else args.patients + ) + generate_synthetic_mimic3( + data_dir, + n_patients=n_patients, + avg_admissions_per_patient=2, + seed=args.seed, + mode=args.synthetic_mode, + ) + args.root = str(data_dir) + print(f" Synthetic data: {data_dir}") + print(f" Mode: {args.synthetic_mode}\n") + else: + assert args.root is not None + print(f"Using real MIMIC-III from: {args.root}\n") 
+ + print(f"Cache: {cache_dir}") + print(f"Root: {args.root}") + print(f"Splits: {args.splits} | Epochs: {args.epochs}") + + print("\n[1/4] Loading dataset...") + t0 = time.time() + dataset = MIMIC3Dataset( + root=args.root, + tables=[], + cache_dir=cache_dir, + dev=args.dev, + ) + print(f" Loaded in {time.time() - t0:.1f}s") + dataset.stats() + + print("\n[2/4] Applying AMA task...") + task = AMAPredictionMIMIC3() + try: + sample_dataset = dataset.set_task(task) + except ValueError as exc: + if "unique labels" not in str(exc).lower(): + raise + print(f"\n {exc}") + print(" The dataset contains no AMA-positive cases.") + print(" For synthetic data: this is expected if AMA rate is low.") + print(" For synthetic random mode with more patients:") + print(" python examples/mimic3_ama_prediction_logistic_regression.py \\") + print( + " --data-source synthetic --synthetic-mode random --patients 500\n" + ) + print(" For real MIMIC-III:") + print(" python examples/mimic3_ama_prediction_logistic_regression.py \\") + print(" --data-source real --root /path/to/mimic-iii/1.4") + print("\nDone.") + return + print(f" Samples: {len(sample_dataset)}") + + print("\n[3/4] Building demographics lookup...") + t0 = time.time() + lookup = _build_demographics_lookup(dataset, task) + print(f" {len(lookup)} entries in {time.time() - t0:.1f}s") + + print( + f"\n[4/4] Running ablation ({len(BASELINES)} baselines " + f"x {args.splits} splits)...\n" + ) + + t_total = time.time() + for name, feature_keys in BASELINES.items(): + split_results = [] + for i in range(args.splits): + t0 = time.time() + res = _run_single_split( + sample_dataset, + feature_keys, + lookup, + seed=i, + epochs=args.epochs, + ) + elapsed = time.time() - t0 + if res is not None: + print( + f" [{name}] split {i + 1:3d}/{args.splits}: " + f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)" + ) + else: + print(f" [{name}] split {i + 1:3d}/{args.splits}: FAILED") + split_results.append(res) + + agg = _aggregate(split_results) + 
_print_results(name, feature_keys, agg) + + print(f"\nTotal time: {time.time() - t_total:.1f}s") + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/mimic3_ama_prediction_rnn.py b/examples/mimic3_ama_prediction_rnn.py new file mode 100644 index 000000000..1e774cec3 --- /dev/null +++ b/examples/mimic3_ama_prediction_rnn.py @@ -0,0 +1,570 @@ +"""Ablation study for MIMIC-III Against-Medical-Advice (AMA) discharge prediction (RNN). + +This script demonstrates ``AMAPredictionMIMIC3`` with three ``feature_keys`` +ablations (baseline, +race, +substance), training PyHealth's ``RNN`` so the +mapping from task tensors to logits is non-linear. One visit +per sequence yields a compact recurrent block over the selected features. +Labels and task schemas match ``AMAPredictionMIMIC3`` (AMA from discharge +location; ``input_schema`` / ``output_schema`` unchanged). + +Paper: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." + Machine Learning for Healthcare Conference, PMLR, 2018. + +Ablation configurations tested: + 1. BASELINE: demographics + age + LOS. + 2. BASELINE+RACE: adds ``race`` (normalized ethnicity). + 3. BASELINE+RACE+SUBSTANCE: adds ``has_substance_use`` from diagnosis text. + +Results: + For each baseline configuration we report: + - Overall AUROC over N random 60/40 patient-level splits. + - Subgroup AUROC by race, age group, and insurance. + +Synthetic data: + ``generate_synthetic_mimic3`` is imported from + ``examples/mimic3_ama_prediction_logistic_regression.py`` (single shared + implementation of the CSV writer). 
+ +Usage: + # Synthetic exhaustive grid (default when no ``--root``) + cd /path/to/PyHealth && python examples/mimic3_ama_prediction_rnn.py \\ + --data-source synthetic + + # Real MIMIC-III 1.4 (set ``--root`` to your extract path) + cd /path/to/PyHealth && python examples/mimic3_ama_prediction_rnn.py \\ + --data-source real --root /path/to/mimic-iii/1.4 --splits 100 --epochs 10 + +""" + +import argparse +import importlib.util +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from sklearn.metrics import roc_auc_score + +from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient +from pyhealth.models import RNN +from pyhealth.tasks import AMAPredictionMIMIC3 +from pyhealth.trainer import Trainer + +_LR_EXAMPLE = Path(__file__).resolve().parent / ( + "mimic3_ama_prediction_logistic_regression.py" +) +_spec = importlib.util.spec_from_file_location( + "mimic3_ama_lr_example", + _LR_EXAMPLE, +) +_lr_mod = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +_spec.loader.exec_module(_lr_mod) +generate_synthetic_mimic3 = _lr_mod.generate_synthetic_mimic3 + +BASELINES = { + "BASELINE": ["demographics", "age", "los"], + "BASELINE+RACE": ["demographics", "age", "los", "race"], + "BASELINE+RACE+SUBSTANCE": [ + "demographics", + "age", + "los", + "race", + "has_substance_use", + ], +} + + +# ------------------------------------------------------------------ +# Helpers -- demographics lookup +# ------------------------------------------------------------------ + + +def _build_demographics_lookup( + dataset: Any, + task: AMAPredictionMIMIC3, +) -> Dict[Tuple[str, str], Dict[str, Any]]: + """Map each visit to race, age, and insurance for subgroup metrics. + + Args: + dataset: ``MIMIC3Dataset`` with ``iter_patients()``. + task: Same ``AMAPredictionMIMIC3`` used with ``set_task``. 
+ + Returns: + ``(patient_id, visit_id)`` strings -> ``{"race","age","insurance"}``. + """ + lookup: Dict[Tuple[str, str], Dict[str, Any]] = {} + for patient in dataset.iter_patients(): + for sample in task(patient): + pid = str(sample["patient_id"]) + vid = str(sample["visit_id"]) + race = sample["race"][0].split(":", 1)[1] + age = sample["age"][0] + insurance = "Other" + for t in sample["demographics"]: + if t.startswith("insurance:"): + insurance = t.split(":", 1)[1] + break + lookup[(pid, vid)] = { + "race": race, + "age": age, + "insurance": insurance, + } + return lookup + + +def _age_group(age: float) -> str: + """Map age in years to Young / Middle / Senior subgroup labels.""" + if age < 45: + return "Young (18-44)" + if age < 65: + return "Middle (45-64)" + return "Senior (65+)" + + +# ------------------------------------------------------------------ +# Helpers -- inference with demographic labels +# ------------------------------------------------------------------ + + +def _get_predictions( + model: Any, + dataloader: Any, + lookup: Dict[Tuple[str, str], Dict[str, Any]], +) -> Tuple[Any, Any, Dict[str, Any]]: + """Evaluate ``model`` on ``dataloader``; attach subgroup labels. + + Args: + model: Trained ``RNN`` accepting batch kwargs. + dataloader: Test loader with ``patient_id`` / ``visit_id``. + lookup: Demographics map from :func:`_build_demographics_lookup`. + + Returns: + ``(y_prob, y_true, groups)`` numpy arrays / group dict. 
+ """ + model.eval() + all_probs, all_labels = [], [] + all_races, all_ages, all_ins = [], [], [] + + with torch.no_grad(): + for batch in dataloader: + output = model(**batch) + all_probs.append(output["y_prob"].detach().cpu()) + all_labels.append(output["y_true"].detach().cpu()) + + pids = batch["patient_id"] + vids = batch["visit_id"] + if isinstance(vids, torch.Tensor): + vids = vids.tolist() + if isinstance(pids, torch.Tensor): + pids = pids.tolist() + + for pid, vid in zip(pids, vids): + info = lookup.get((str(pid), str(vid)), {}) + all_races.append(info.get("race", "Other")) + all_ages.append(_age_group(info.get("age", 0.0))) + all_ins.append(info.get("insurance", "Other")) + + y_prob = torch.cat(all_probs).numpy().ravel() + y_true = torch.cat(all_labels).numpy().ravel() + groups = { + "Race": np.array(all_races), + "Age Group": np.array(all_ages), + "Insurance": np.array(all_ins), + } + return y_prob, y_true, groups + + +# ------------------------------------------------------------------ +# Helpers -- safe metrics +# ------------------------------------------------------------------ + + +def _safe_auroc(y: Any, p: Any) -> float: + """ROC-AUC with guards for single-class slices.""" + if len(np.unique(y)) < 2: + return float("nan") + try: + return roc_auc_score(y, p) + except ValueError: + return float("nan") + + +# ------------------------------------------------------------------ +# Single split +# ------------------------------------------------------------------ + + +def _create_model( + sample_dataset: Any, + feature_keys: List[str], + embedding_dim: int = 128, + hidden_dim: int = 64, +) -> RNN: + """Build ``RNN`` over ``feature_keys`` (ablation-controlled inputs). + + Args: + sample_dataset: ``SampleDataset`` after ``set_task``. + feature_keys: Keys from ``AMAPredictionMIMIC3.input_schema``. + embedding_dim: Token embedding size. + hidden_dim: RNN hidden size (``fc`` input is ``len(keys)*hidden_dim``). + + Returns: + Configured ``RNN`` instance. 
+ """ + model = RNN( + dataset=sample_dataset, + embedding_dim=embedding_dim, + hidden_dim=hidden_dim, + ) + model.feature_keys = list(feature_keys) + output_size = model.get_output_size() + model.fc = torch.nn.Linear( + len(feature_keys) * hidden_dim, + output_size, + ) + return model + + +def _run_single_split( + sample_dataset: Any, + feature_keys: List[str], + lookup: Dict[Tuple[str, str], Dict[str, Any]], + seed: int, + epochs: int, + batch_size: int = 32, +) -> Optional[Dict[str, Any]]: + """One patient-level split: train RNN, then test + subgroup AUROC. + + Args: + sample_dataset: AMA task samples. + feature_keys: Subset of input schema keys for this ablation. + lookup: Demographics for post-hoc slicing. + seed: Split RNG seed. + epochs: Training epochs. + batch_size: Loader batch size. + + Returns: + Metric dict or ``None`` on training failure. + """ + train_ds, _, test_ds = split_by_patient( + sample_dataset, + [0.6, 0.0, 0.4], + seed=seed, + ) + train_dl = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + test_dl = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + + # ``feature_keys`` selects which task outputs the RNN reads for this run. 
+ model = _create_model(sample_dataset, feature_keys) + trainer = Trainer(model=model) + try: + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=epochs, + monitor=None, + ) + except Exception as exc: + print(f" train failed: {exc}") + return None + + y_prob, y_true, groups = _get_predictions(model, test_dl, lookup) + + overall_auroc = _safe_auroc(y_true, y_prob) + + subgroup = {} + for attr_name, attr_vals in groups.items(): + subgroup[attr_name] = {} + for grp in sorted(set(attr_vals)): + mask = attr_vals == grp + n = int(mask.sum()) + if n < 2: + continue + yt, yp = y_true[mask], y_prob[mask] + subgroup[attr_name][grp] = { + "auroc": _safe_auroc(yt, yp), + "n": n, + } + + return { + "auroc": overall_auroc, + "subgroups": subgroup, + } + + +# ------------------------------------------------------------------ +# Aggregation +# ------------------------------------------------------------------ + + +def _nanmean(lst: List[float]) -> float: + """Mean ignoring NaNs.""" + v = [x for x in lst if not np.isnan(x)] + return float(np.mean(v)) if v else float("nan") + + +def _nanstd(lst: List[float]) -> float: + """Std ignoring NaNs.""" + v = [x for x in lst if not np.isnan(x)] + return float(np.std(v)) if v else float("nan") + + +def _aggregate( + results: List[Optional[Dict[str, Any]]], +) -> Optional[Dict[str, Any]]: + """Mean/std of overall AUROC and per-subgroup metrics across splits.""" + valid = [r for r in results if r is not None] + if not valid: + return None + + agg = { + "n": len(valid), + "auroc_mean": _nanmean([r["auroc"] for r in valid]), + "auroc_std": _nanstd([r["auroc"] for r in valid]), + } + + all_attrs = set() + for r in valid: + all_attrs.update(r["subgroups"].keys()) + + agg["subgroups"] = {} + for attr in sorted(all_attrs): + agg["subgroups"][attr] = {} + all_grps = set() + for r in valid: + if attr in r["subgroups"]: + all_grps.update(r["subgroups"][attr].keys()) + + for grp in sorted(all_grps): + aurocs, ns = [], [] + for r 
in valid:
+                m = r["subgroups"].get(attr, {}).get(grp)
+                if m is None:
+                    continue
+                aurocs.append(m["auroc"])
+                ns.append(m["n"])
+
+            agg["subgroups"][attr][grp] = {
+                "auroc_mean": _nanmean(aurocs),
+                "auroc_std": _nanstd(aurocs),
+                "n_avg": int(np.mean(ns)) if ns else 0,
+            }
+    return agg
+
+
+# ------------------------------------------------------------------
+# Pretty-printing
+# ------------------------------------------------------------------
+
+
+def _fmt(val: float, digits: int = 4) -> str:
+    """Human-readable float or ``N/A`` for NaN."""
+    return "N/A" if np.isnan(val) else f"{val:.{digits}f}"
+
+
+def _print_results(
+    name: str,
+    feature_keys: List[str],
+    agg: Optional[Dict[str, Any]],
+) -> None:
+    """Print ablation summary for RNN runs."""
+    w = 70
+    print(f"\n{'=' * w}")
+    print(f" {name} (RNN hidden_dim=64)")
+    print(f" Features: {feature_keys}")
+    print(f"{'=' * w}")
+
+    if agg is None:
+        print(" No valid results.\n")
+        return
+
+    ci_lo = agg["auroc_mean"] - 1.96 * agg["auroc_std"] / agg["n"] ** 0.5  # 95% CI of the mean uses SEM
+    ci_hi = agg["auroc_mean"] + 1.96 * agg["auroc_std"] / agg["n"] ** 0.5  # (std / sqrt(n)), not raw std
+    print(f"\n 1. Overall Performance ({agg['n']} splits)")
+    print(
+        f" AUROC: {_fmt(agg['auroc_mean'])} +/- {_fmt(agg['auroc_std'])}"
+        f" 95% CI ({_fmt(ci_lo)}, {_fmt(ci_hi)})"
+    )
+
+    print("\n 2. Subgroup Performance")
+    for attr, grps in agg["subgroups"].items():
+        print(f" {attr}:")
+        print(f" {'Group':<20} {'AUROC':>15} {'n_avg':>7}")
+        print(f" {'-' * 42}")
+        for grp, m in grps.items():
+            a_str = f"{_fmt(m['auroc_mean'])}+/-{_fmt(m['auroc_std'])}"
+            print(f" {grp:<20} {a_str:>15} {m['n_avg']:>7}")
+
+
+# ------------------------------------------------------------------
+# Main
+# ------------------------------------------------------------------
+
+
+def main() -> None:
+    """CLI entry: load MIMIC-III (or synthetic CSVs), apply AMA task, train RNN.
+ + Resolves ``--data-source`` / ``--root`` / ``--synthetic-mode``, builds a + ``MIMIC3Dataset`` and ``SampleDataset`` via ``set_task(AMAPredictionMIMIC3)``, + builds a demographics lookup for subgroup metrics, then for each entry in + ``BASELINES`` runs ``--splits`` train/eval loops. Each loop uses the + listed ``feature_keys`` so the ``RNN`` reads only that subset of task + outputs; the task ``input_schema`` / ``output_schema`` are unchanged. + + Side effects: prints progress; creates temporary data and cache dirs. + """ + parser = argparse.ArgumentParser( + description="AMA prediction ablation -- RNN", + ) + parser.add_argument( + "--data-source", + choices=("auto", "synthetic", "real"), + default="auto", + help="auto: use --root if set, else synthetic. synthetic: local CSVs. " + "real: require --root.", + ) + parser.add_argument( + "--synthetic-mode", + choices=("exhaustive", "random"), + default="exhaustive", + help="Synthetic grid: exhaustive cross-product or random; " + "see generate_synthetic_mimic3 docstring in the loaded module.", + ) + parser.add_argument( + "--root", + default=None, + help="MIMIC-III root (required for --data-source real).", + ) + parser.add_argument( + "--patients", type=int, default=100, help="Random synthetic mode only." + ) + parser.add_argument( + "--seed", type=int, default=42, help="RNG seed for random synthetic mode." + ) + parser.add_argument( + "--splits", type=int, default=5, help="Number of random 60/40 splits" + ) + parser.add_argument( + "--epochs", type=int, default=3, help="Training epochs per split" + ) + parser.add_argument( + "--dev", action="store_true", help="Random synthetic: 10 patients only." 
+ ) + args = parser.parse_args() + + if args.data_source == "real" and not args.root: + parser.error("--data-source real requires --root /path/to/mimic-iii/1.4") + + use_synthetic = args.data_source == "synthetic" or ( + args.data_source == "auto" and args.root is None + ) + if args.data_source == "synthetic" and args.root: + print( + "Note: --root ignored with --data-source synthetic.\n", + ) + + cache_dir = tempfile.mkdtemp(prefix="ama_rnn_") + + if use_synthetic: + print("[Setup] Generating synthetic MIMIC-III dataset...") + data_dir = tempfile.mkdtemp(prefix="synthetic_mimic3_") + n_patients = ( + 10 if args.dev and args.synthetic_mode == "random" else args.patients + ) + generate_synthetic_mimic3( + data_dir, + n_patients=n_patients, + avg_admissions_per_patient=2, + seed=args.seed, + mode=args.synthetic_mode, + ) + args.root = str(data_dir) + print(f" Synthetic: {data_dir} mode={args.synthetic_mode}\n") + else: + print(f"Using real MIMIC-III from: {args.root}\n") + + print(f"Cache: {cache_dir}") + print(f"Root: {args.root}") + print(f"Splits: {args.splits} | Epochs: {args.epochs}") + + print("\n[1/4] Loading dataset...") + t0 = time.time() + dataset = MIMIC3Dataset( + root=args.root, + tables=[], + cache_dir=cache_dir, + dev=args.dev, + ) + print(f" Loaded in {time.time() - t0:.1f}s") + dataset.stats() + + print("\n[2/4] Applying AMA task...") + task = AMAPredictionMIMIC3() + try: + sample_dataset = dataset.set_task(task) + except ValueError as exc: + if "unique labels" not in str(exc).lower(): + raise + print(f"\n {exc}") + print(" The dataset contains no AMA-positive cases.") + print(" AMA prevalence is ~2% so small/synthetic data") + print(" often lacks positives. 
Demonstrating the task") + print(" on raw patients instead:\n") + total = 0 + for patient in dataset.iter_patients(): + samples = task(patient) + total += len(samples) + print(f" Task produced {total} samples (all label=0)") + print("\n Re-run with real MIMIC-III:") + print(" python examples/mimic3_ama_prediction_rnn.py \\") + print(" --data-source real --root /path/to/mimic-iii/1.4") + print("\nDone.") + return + print(f" Samples: {len(sample_dataset)}") + + print("\n[3/4] Building demographics lookup...") + t0 = time.time() + lookup = _build_demographics_lookup(dataset, task) + print(f" {len(lookup)} entries in {time.time() - t0:.1f}s") + + print( + f"\n[4/4] Running ablation ({len(BASELINES)} baselines " + f"x {args.splits} splits)...\n" + ) + + t_total = time.time() + for name, feature_keys in BASELINES.items(): + split_results = [] + for i in range(args.splits): + t0 = time.time() + res = _run_single_split( + sample_dataset, + feature_keys, + lookup, + seed=i, + epochs=args.epochs, + ) + elapsed = time.time() - t0 + if res is not None: + print( + f" [{name}] split {i + 1:3d}/{args.splits}: " + f"AUROC={_fmt(res['auroc'])} ({elapsed:.1f}s)" + ) + else: + print(f" [{name}] split {i + 1:3d}/{args.splits}: FAILED") + split_results.append(res) + + agg = _aggregate(split_results) + _print_results(name, feature_keys, agg) + + print(f"\nTotal time: {time.time() - t_total:.1f}s") + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..c4e541054 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,3 +1,4 @@ +from .ama_prediction import AMAPredictionMIMIC3 from .base_task import BaseTask from .benchmark_ehrshot import BenchmarkEHRShot from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction diff --git a/pyhealth/tasks/ama_prediction.py b/pyhealth/tasks/ama_prediction.py new file mode 100644 index 000000000..8865192cd --- /dev/null +++ 
b/pyhealth/tasks/ama_prediction.py @@ -0,0 +1,306 @@ +"""MIMIC-III Against-Medical-Advice (AMA) discharge prediction task module. + +Defines :class:`AMAPredictionMIMIC3`, a :class:`~pyhealth.tasks.base_task.BaseTask` +subclass that emits one sample per eligible ICU admission from MIMIC-III +administrative tables (patients, admissions, icustays). No ICD / procedure / +prescription sequences are required. Module-level helpers normalize race and +insurance strings, parse datetimes, and flag substance-related diagnoses from +free-text ``ADMISSIONS.DIAGNOSIS``. + +Paper: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." Machine Learning + for Healthcare Conference, PMLR 106:211-235, 2018. + +Model-side ablations (task schema is fixed; select ``feature_keys`` on the +model, as in the example scripts): + 1. BASELINE: ``demographics`` (gender + insurance), ``age``, ``los``. + 2. BASELINE+RACE: adds ``race`` (normalized ethnicity tokens). + 3. BASELINE+RACE+SUBSTANCE: adds ``has_substance_use`` (0/1 tensor). + +Task I/O (schemas for ``dataset.set_task``): + - ``input_schema``: ``multi_hot`` demographics and race; ``tensor`` age, + LOS, substance flag. + - ``output_schema``: binary ``ama`` from ``discharge_location``. + +Usage: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks.ama_prediction import AMAPredictionMIMIC3 + >>> dataset = MIMIC3Dataset(root="/path/to/mimic-iii/1.4", tables=[]) + >>> sample_ds = dataset.set_task(AMAPredictionMIMIC3()) + +""" + +import re +from datetime import datetime +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + +_SUBSTANCE_PATTERN = re.compile( + r"alcohol|opioid|opiate|heroin|cocaine|drug|withdrawal" + r"|intoxication|overdose|substance|etoh", + re.IGNORECASE, +) + + +def _normalize_race(ethnicity: Optional[str]) -> str: + """Map MIMIC-III ethnicity strings to the race categories used by + Boag et al. 2018. 
+ + Args: + ethnicity: Raw ethnicity value from MIMIC-III ``admissions`` + table. + + Returns: + One of ``"White"``, ``"Black"``, ``"Hispanic"``, + ``"Asian"``, ``"Native American"``, or ``"Other"``. + """ + if ethnicity is None: + return "Other" + eth = str(ethnicity).upper() + if "HISPANIC" in eth or "SOUTH AMERICAN" in eth: + return "Hispanic" + if "AMERICAN INDIAN" in eth: + return "Native American" + if "ASIAN" in eth: + return "Asian" + if "BLACK" in eth: + return "Black" + if "WHITE" in eth: + return "White" + return "Other" + + +def _normalize_insurance(insurance: Optional[str]) -> str: + """Map MIMIC-III insurance strings to the categories used by + Boag et al. 2018. + + Args: + insurance: Raw insurance value from MIMIC-III ``admissions`` + table. + + Returns: + ``"Public"`` for Medicare/Medicaid/Government, or the + original value (typically ``"Private"`` or ``"Self Pay"``). + """ + if insurance is None: + return "Other" + if insurance in ("Medicare", "Medicaid", "Government"): + return "Public" + return insurance + + +def _safe_parse_datetime(value: Any) -> datetime: + """Parse a datetime string, trying common MIMIC-III formats.""" + if isinstance(value, datetime): + return value + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"): + try: + return datetime.strptime(str(value), fmt) + except (ValueError, TypeError): + continue + raise ValueError(f"Cannot parse datetime: {value!r}") + + +def _has_substance_use(diagnosis: Any) -> int: + """Detect substance-use related admission from the free-text + ``DIAGNOSIS`` field in MIMIC-III ``ADMISSIONS``. + + Args: + diagnosis: Raw ``diagnosis`` string from the admissions table. + + Returns: + 1 if a substance-use keyword is found, 0 otherwise. + """ + if diagnosis is None: + return 0 + return 1 if _SUBSTANCE_PATTERN.search(str(diagnosis)) else 0 + + +class AMAPredictionMIMIC3(BaseTask): + """Predict whether a patient leaves the hospital against medical advice. 
+ + This task reproduces the AMA (Against Medical Advice) discharge + prediction target from Boag et al. 2018, "Racial Disparities and + Mistrust in End-of-Life Care." A positive label indicates that the + patient's ``discharge_location`` is ``"LEFT AGAINST MEDICAL ADVI"`` + in the MIMIC-III admissions table. + + The feature set supports three ablation baselines: + + * **BASELINE** -- ``demographics`` (gender, insurance), + ``age``, and ``los``. Select with + ``feature_keys=["demographics", "age", "los"]``. + + * **BASELINE + RACE** -- adds ``race`` (normalized ethnicity). + Select with + ``feature_keys=["demographics", "age", "los", "race"]``. + + * **BASELINE + RACE + SUBSTANCE** -- adds ``has_substance_use`` + (derived from ``ADMISSIONS.DIAGNOSIS`` free-text field). + Select with + ``feature_keys=["demographics", "age", "los", "race", + "has_substance_use"]``. + + These baselines can be toggled via the model's ``feature_keys`` + parameter without changing the task. + + Only administrative and demographic features are extracted; no + clinical code tables (diagnoses, procedures, prescriptions) are + required. + + Unlike mortality or readmission prediction, the label is a property + of the **current** admission, so patients with only one visit are + eligible. + + **Processor mapping (schemas):** Each value in ``input_schema`` and + ``output_schema`` must be a processor string key understood by + ``dataset.set_task`` (see the Tasks docs processor table). Here, + ``"multi_hot"`` feeds ``MultiHotProcessor`` (token lists for + ``demographics`` and ``race``), ``"tensor"`` feeds ``TensorProcessor`` + (``age``, ``los``, ``has_substance_use``), and ``"binary"`` labels + ``ama`` for the binary label path. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. 
+ + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import AMAPredictionMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=[], + ... ) + >>> task = AMAPredictionMIMIC3() + >>> samples = dataset.set_task(task) + + Reference: + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + 2018. Racial Disparities and Mistrust in End-of-Life Care. In Machine + Learning for Healthcare Conference. PMLR. + """ + + AMA_DISCHARGE_LOCATION: str = "LEFT AGAINST MEDICAL ADVI" + + task_name: str = "AMAPredictionMIMIC3" + input_schema: Dict[str, str] = { + "demographics": "multi_hot", + "age": "tensor", + "los": "tensor", + "race": "multi_hot", + "has_substance_use": "tensor", + } + output_schema: Dict[str, str] = {"ama": "binary"} + + def __init__(self, exclude_newborns: bool = True) -> None: + """Initializes the AMA prediction task. + + Args: + exclude_newborns: If ``True``, admissions whose + ``admission_type`` is ``"NEWBORN"`` are skipped. + Defaults to ``True``. + """ + self.exclude_newborns = exclude_newborns + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Processes a single patient for AMA discharge prediction. + + Each non-newborn admission is emitted as a sample. The binary + label is derived from the admission's ``discharge_location``. + + Args: + patient: A Patient object from ``MIMIC3Dataset``. + + Returns: + A list of sample dictionaries. Each dictionary contains + the features described in the class docstring plus + ``visit_id``, ``patient_id``, and the ``ama`` label. 
+ """ + admissions = patient.get_events(event_type="admissions") + if len(admissions) == 0: + return [] + + patients_events = patient.get_events(event_type="patients") + if len(patients_events) == 0: + return [] + patient_info = patients_events[0] + + gender = getattr(patient_info, "gender", None) + dob_raw = getattr(patient_info, "dob", None) + try: + dob = _safe_parse_datetime(dob_raw) + except (ValueError, TypeError): + dob = None + + samples: List[Dict[str, Any]] = [] + + for admission in admissions: + if self.exclude_newborns: + admission_type = getattr(admission, "admission_type", None) + if admission_type == "NEWBORN": + continue + + # --- Label --- + discharge_location = getattr(admission, "discharge_location", None) + ama_label = 1 if discharge_location == self.AMA_DISCHARGE_LOCATION else 0 + + # --- BASELINE demographics (gender + insurance) --- + insurance = getattr(admission, "insurance", None) + + demo_tokens: List[str] = [] + if gender: + demo_tokens.append(f"gender:{gender}") + demo_tokens.append(f"insurance:{_normalize_insurance(insurance)}") + + # --- Race (separate feature for ablation) --- + ethnicity = getattr(admission, "ethnicity", None) + race_tokens: List[str] = [f"race:{_normalize_race(ethnicity)}"] + + # --- Age (continuous) --- + age_years = 0.0 + if dob is not None: + admit_dt = admission.timestamp + if isinstance(admit_dt, datetime): + age_years = ( + admit_dt.year + - dob.year + - int((admit_dt.month, admit_dt.day) < (dob.month, dob.day)) + ) + age_years = float(min(age_years, 90)) + + # --- LOS (continuous, in days) --- + los_days = 0.0 + dischtime_raw = getattr(admission, "dischtime", None) + if dischtime_raw is not None: + try: + dischtime = _safe_parse_datetime(dischtime_raw) + admit_dt = admission.timestamp + if isinstance(admit_dt, datetime): + los_days = max( + (dischtime - admit_dt).total_seconds() / 86400.0, + 0.0, + ) + except (ValueError, TypeError): + los_days = 0.0 + + # --- Substance use (from ADMISSIONS.DIAGNOSIS) --- + 
diagnosis_text = getattr(admission, "diagnosis", None) + substance = float(_has_substance_use(diagnosis_text)) + + samples.append( + { + "visit_id": admission.hadm_id, + "patient_id": patient.patient_id, + "demographics": demo_tokens, + "age": [age_years], + "los": [los_days], + "race": race_tokens, + "has_substance_use": [substance], + "ama": ama_label, + } + ) + + return samples diff --git a/tests/core/test_mimic3_ama_prediction.py b/tests/core/test_mimic3_ama_prediction.py new file mode 100644 index 000000000..7d975f018 --- /dev/null +++ b/tests/core/test_mimic3_ama_prediction.py @@ -0,0 +1,1422 @@ +"""Test suite for MIMIC-III Against-Medical-Advice (AMA) discharge prediction. + +This is the automated test module for ``AMAPredictionMIMIC3`` +(``pyhealth.tasks.ama_prediction``). It provides a comprehensive set of checks +covering: + + - Task helpers: race and insurance normalization, substance-use + detection from diagnosis text. + - Task contract: ``task_name``, ``input_schema``, ``output_schema``, and + default flags (e.g. newborn filtering). + - Feature engineering on mock patients: AMA vs non-AMA labels, age and + LOS from timestamps, demographics vs separate ``race`` tokens, + ``has_substance_use``, multi-admission behavior, and schema key sets. + - Ablation-oriented checks: baseline feature presence, label correctness, + and absence of clinical code fields in samples. + - Integration: curated five-row gzipped MIMIC-style CSVs with + ``MIMIC3Dataset`` + ``set_task`` + ``LogisticRegression`` forward passes + and short ``Trainer`` smoke runs (example CLI tables are not asserted). + - Synthetic generator sanity: exhaustive grid patient row count. + +Paper (task motivation): + Boag, W.; Suresh, H.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. + "Racial Disparities and Mistrust in End-of-Life Care." MLHC / PMLR, 2018. 
+
+Usage:
+    # From the PyHealth repository root (quiet summary)
+    cd /path/to/PyHealth && python -m unittest tests.core.test_mimic3_ama_prediction -q
+
+    # Verbose per-test output
+    cd /path/to/PyHealth && python -m unittest tests.core.test_mimic3_ama_prediction -v
+
+    # Run this file directly
+    cd /path/to/PyHealth && python tests/core/test_mimic3_ama_prediction.py
+
+"""
+
+import gc
+import gzip
+import importlib.util
+import io
+import shutil
+import sys
+import tempfile
+import unittest
+import warnings
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List
+from unittest.mock import MagicMock
+
+import pandas as pd
+import torch
+
+from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient
+from pyhealth.models import LogisticRegression
+from pyhealth.tasks.ama_prediction import (
+    AMAPredictionMIMIC3,
+    _has_substance_use,
+    _normalize_insurance,
+    _normalize_race,
+)
+from pyhealth.trainer import Trainer
+
+warnings.filterwarnings("ignore", category=ResourceWarning)
+if "ignore::ResourceWarning" not in getattr(sys, "warnoptions", []):
+    sys.warnoptions.append("ignore::ResourceWarning")
+    # NOTE(review): filterwarnings() above already covers this process (and
+    # invalidates the filter cache itself), so the former private
+    # warnings._filters_mutated() call was a no-op; sys.warnoptions is kept
+    # because spawn-based workers presumably re-read it -- TODO confirm.
+
+# Load the shared synthetic-CSV generator from the example script by path
+# (examples/ is not an importable package).
+_EXAMPLE_PATH = (
+    Path(__file__).resolve().parents[2]
+    / "examples"
+    / "mimic3_ama_prediction_logistic_regression.py"
+)
+_spec = importlib.util.spec_from_file_location(
+    "mimic3_ama_prediction_example",
+    _EXAMPLE_PATH,
+)
+_example_mod = importlib.util.module_from_spec(_spec)
+assert _spec.loader is not None
+_spec.loader.exec_module(_example_mod)
+generate_synthetic_mimic3 = _example_mod.generate_synthetic_mimic3
+
+
+# Fixed 5-row MIMIC-III-like slice for integration tests (3 non-AMA, 2 AMA).
+# Race (normalized): 2 White, 2 Black, 1 Hispanic.
+# Age at admit (task uses calendar year from DOB vs admittime): two young
+# (18-44), two middle (45-64), one senior (65+).
+CURATED_SYNTHETIC_N = 5 +CURATED_SYNTHETIC_AMA_NEGATIVE = 3 +CURATED_SYNTHETIC_AMA_POSITIVE = 2 + +_CURATED_MIMIC3_PATIENTS = [ + { + "subject_id": 1, + "gender": "M", + "dob": "2118-01-01 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 2, + "gender": "F", + "dob": "2096-01-02 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 3, + "gender": "M", + "dob": "2098-01-03 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 4, + "gender": "F", + "dob": "2115-01-11 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, + { + "subject_id": 5, + "gender": "M", + "dob": "2079-01-12 00:00:00", + "dod": None, + "dod_hosp": None, + "dod_ssn": None, + "expire_flag": 0, + }, +] + +_CURATED_MIMIC3_ADMISSIONS = [ + { + "subject_id": 1, + "hadm_id": 100, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Private", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "WHITE", + "edregtime": "2150-01-01 00:00:00", + "edouttime": "2150-01-01 00:00:00", + "diagnosis": "PNEUMONIA", + "discharge_location": "HOME", + "dischtime": "2150-01-08 00:00:00", + "admittime": "2150-01-01 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 2, + "hadm_id": 101, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Private", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "WHITE", + "edregtime": "2150-01-02 00:00:00", + "edouttime": "2150-01-02 00:00:00", + "diagnosis": "CHEST PAIN", + "discharge_location": "HOME", + "dischtime": "2150-01-09 00:00:00", + "admittime": "2150-01-02 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 3, + "hadm_id": 102, + "admission_type": "EMERGENCY", + 
"admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Medicaid", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "BLACK/AFRICAN AMERICAN", + "edregtime": "2150-01-03 00:00:00", + "edouttime": "2150-01-03 00:00:00", + "diagnosis": "SEPSIS", + "discharge_location": "HOME", + "dischtime": "2150-01-10 00:00:00", + "admittime": "2150-01-03 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 4, + "hadm_id": 103, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Medicaid", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "BLACK/AFRICAN AMERICAN", + "edregtime": "2150-01-11 00:00:00", + "edouttime": "2150-01-11 00:00:00", + "diagnosis": "PNEUMONIA", + "discharge_location": "LEFT AGAINST MEDICAL ADVI", + "dischtime": "2150-01-18 00:00:00", + "admittime": "2150-01-11 00:00:00", + "hospital_expire_flag": 0, + }, + { + "subject_id": 5, + "hadm_id": 104, + "admission_type": "EMERGENCY", + "admission_location": "EMERGENCY ROOM ADMIT", + "insurance": "Medicare", + "language": "ENGLISH", + "religion": "CHRISTIAN", + "marital_status": "SINGLE", + "ethnicity": "HISPANIC OR LATINO", + "edregtime": "2150-01-12 00:00:00", + "edouttime": "2150-01-12 00:00:00", + "diagnosis": "OPIOID DEPENDENCE", + "discharge_location": "LEFT AGAINST MEDICAL ADVI", + "dischtime": "2150-01-19 00:00:00", + "admittime": "2150-01-12 00:00:00", + "hospital_expire_flag": 0, + }, +] + +_CURATED_MIMIC3_ICUSTAYS = [ + { + "subject_id": 1, + "hadm_id": 100, + "icustay_id": 1000, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-01 02:00:00", + "outtime": "2150-01-07 22:00:00", + }, + { + "subject_id": 2, + "hadm_id": 101, + "icustay_id": 1001, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-02 02:00:00", + "outtime": "2150-01-08 22:00:00", + }, 
+ { + "subject_id": 3, + "hadm_id": 102, + "icustay_id": 1002, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-03 02:00:00", + "outtime": "2150-01-09 22:00:00", + }, + { + "subject_id": 4, + "hadm_id": 103, + "icustay_id": 1003, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-11 02:00:00", + "outtime": "2150-01-17 22:00:00", + }, + { + "subject_id": 5, + "hadm_id": 104, + "icustay_id": 1004, + "first_careunit": "MICU", + "last_careunit": "MICU", + "dbsource": "metavision", + "intime": "2150-01-12 02:00:00", + "outtime": "2150-01-18 22:00:00", + }, +] + + +def _write_curated_synthetic_mimic3_for_tests(root: str) -> None: + """Materialize the fixed 5-row MIMIC-III-like CSV.gz bundle for tests. + + Args: + root: Directory receiving ``PATIENTS.csv.gz``, ``ADMISSIONS.csv.gz``, + ``ICUSTAYS.csv.gz`` (column names match ``MIMIC3Dataset`` ingest). + + Returns: + None. + + Note: + Row mix matches ``CURATED_SYNTHETIC_*`` constants so + ``AMAPredictionMIMIC3`` yields both AMA labels and varied demographics + without loading real MIMIC-III. + """ + root_path = Path(root) + root_path.mkdir(parents=True, exist_ok=True) + for name, rows in ( + ("PATIENTS", _CURATED_MIMIC3_PATIENTS), + ("ADMISSIONS", _CURATED_MIMIC3_ADMISSIONS), + ("ICUSTAYS", _CURATED_MIMIC3_ICUSTAYS), + ): + df = pd.DataFrame(rows) + with gzip.open(root_path / f"{name}.csv.gz", "wt") as f: + df.to_csv(f, index=False) + + +# ------------------------------------------------------------------ +# Module-level shared dataset (loaded once for all integration tests) +# ------------------------------------------------------------------ + +_shared_tmpdir = None +_shared_cache_dir = None +_shared_dataset = None +_shared_sample_dataset = None + + +def setUpModule() -> None: + """Load one shared ``MIMIC3Dataset`` + ``SampleDataset`` for integration tests. 
def tearDownModule() -> None:
    """Release LitData handles and remove temp CSV/cache directories.

    Mirrors ``setUpModule``: closes the shared sample dataset, drops the
    module-level references, and deletes both temp directories so repeated
    test runs never reuse stale caches.
    """
    global _shared_dataset, _shared_sample_dataset
    if _shared_sample_dataset is not None:
        _shared_sample_dataset.close()
        _shared_sample_dataset = None
    _shared_dataset = None
    # First pass: let GC reclaim reader objects that were only reachable
    # through the dataset references dropped above.
    gc.collect()
    # Proactively close lingering ``.ld`` chunk readers to avoid shutdown
    # ``ResourceWarning`` from litdata after temp dirs are removed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for obj in gc.get_objects():
            # Only touch raw file handles whose path points into our own
            # cache dir; closing anything else could break unrelated state.
            if isinstance(obj, io.FileIO) and not obj.closed:
                name = getattr(obj, "name", "")
                if (
                    isinstance(name, str)
                    and _shared_cache_dir
                    and _shared_cache_dir in name
                ):
                    try:
                        obj.close()
                    except Exception:
                        # Best-effort cleanup; a failed close must not
                        # fail the whole module teardown.
                        pass
        # Second pass collects objects finalized by the closes above while
        # warnings are still suppressed.
        gc.collect()
    for d in (_shared_tmpdir, _shared_cache_dir):
        if d:
            shutil.rmtree(d, ignore_errors=True)
+ """ + event = MagicMock() + for key, value in attrs.items(): + setattr(event, key, value) + return event + + +def _build_patient( + patient_id: str, + admissions: List[Dict[str, Any]], + diagnoses: List[Dict[str, Any]], + procedures: List[Dict[str, Any]], + prescriptions: List[Dict[str, Any]], + gender: str = "M", + dob: str = "2100-01-01 00:00:00", +) -> MagicMock: + """Build a mock ``Patient`` whose ``get_events`` mirrors PyHealth filters. + + Args: + patient_id: ``patient.patient_id`` string. + admissions: kwargs for each admission ``MagicMock``. + diagnoses: kwargs for diagnosis events (``hadm_id`` aligned). + procedures: kwargs for procedure events. + prescriptions: kwargs for prescription events. + gender: Demographics token for ``patients`` event. + dob: DOB string feeding age logic in ``AMAPredictionMIMIC3``. + + Returns: + ``MagicMock`` with ``get_events`` routing ``event_type`` to the proper + list and honoring simple ``filters`` tuples on code tables. + + Note: + Output samples follow ``AMAPredictionMIMIC3.input_schema`` / + ``output_schema`` (same keys as real ``set_task`` tensors after + processors run on CSV-backed data). 
+ """ + patient = MagicMock() + patient.patient_id = patient_id + + patient_event = _make_event(gender=gender, dob=dob) + adm_events = [_make_event(**a) for a in admissions] + diag_events = [_make_event(**d) for d in diagnoses] + proc_events = [_make_event(**p) for p in procedures] + rx_events = [_make_event(**r) for r in prescriptions] + + def _get_events(event_type, filters=None, **kwargs): + if event_type == "patients": + return [patient_event] + if event_type == "admissions": + return adm_events + source = { + "diagnoses_icd": diag_events, + "procedures_icd": proc_events, + "prescriptions": rx_events, + }.get(event_type, []) + if filters: + col, op, val = filters[0] + source = [e for e in source if getattr(e, col, None) == val] + return source + + patient.get_events = _get_events + return patient + + +SAMPLE_KEYS = { + "visit_id", + "patient_id", + "demographics", + "age", + "los", + "race", + "has_substance_use", + "ama", +} + + +class TestNormalizeRace(unittest.TestCase): + """Unit tests for the race normalization helper.""" + + def test_white(self): + self.assertEqual(_normalize_race("WHITE"), "White") + self.assertEqual(_normalize_race("WHITE - RUSSIAN"), "White") + + def test_black(self): + self.assertEqual(_normalize_race("BLACK/AFRICAN AMERICAN"), "Black") + + def test_hispanic(self): + self.assertEqual(_normalize_race("HISPANIC OR LATINO"), "Hispanic") + self.assertEqual(_normalize_race("SOUTH AMERICAN"), "Hispanic") + + def test_asian(self): + self.assertEqual(_normalize_race("ASIAN - CHINESE"), "Asian") + + def test_native_american(self): + self.assertEqual( + _normalize_race("AMERICAN INDIAN/ALASKA NATIVE"), + "Native American", + ) + + def test_other(self): + self.assertEqual(_normalize_race("UNKNOWN/NOT SPECIFIED"), "Other") + self.assertEqual(_normalize_race(None), "Other") + + def test_normalize_insurance(self): + self.assertEqual(_normalize_insurance("Medicare"), "Public") + self.assertEqual(_normalize_insurance("Medicaid"), "Public") + 
class TestHasSubstanceUse(unittest.TestCase):
    """Unit tests for the substance-use detection helper.

    ``_has_substance_use`` scans free-text admission diagnosis strings for
    substance-related keywords and returns 1/0.
    """

    def test_alcohol(self):
        """Alcohol-related diagnosis text is flagged."""
        self.assertEqual(_has_substance_use("ALCOHOL WITHDRAWAL"), 1)

    def test_opioid(self):
        """Opioid-related diagnosis text is flagged."""
        self.assertEqual(_has_substance_use("OPIOID DEPENDENCE"), 1)

    def test_heroin(self):
        """Heroin-related diagnosis text is flagged."""
        self.assertEqual(_has_substance_use("HEROIN OVERDOSE"), 1)

    def test_cocaine(self):
        """Cocaine-related diagnosis text is flagged."""
        self.assertEqual(_has_substance_use("COCAINE INTOXICATION"), 1)

    def test_drug_withdrawal(self):
        """Generic drug-withdrawal text is flagged."""
        self.assertEqual(_has_substance_use("DRUG WITHDRAWAL SEIZURE"), 1)

    def test_etoh(self):
        """Clinical shorthand 'ETOH' is flagged."""
        self.assertEqual(_has_substance_use("ETOH ABUSE"), 1)

    def test_substance(self):
        """Explicit 'SUBSTANCE' keyword is flagged."""
        self.assertEqual(_has_substance_use("SUBSTANCE ABUSE"), 1)

    def test_overdose(self):
        """'OVERDOSE' is flagged even without a named substance."""
        self.assertEqual(_has_substance_use("OVERDOSE - ACCIDENTAL"), 1)

    def test_negative(self):
        """Non-substance diagnoses return 0."""
        self.assertEqual(_has_substance_use("PNEUMONIA"), 0)
        self.assertEqual(_has_substance_use("CHEST PAIN"), 0)

    def test_none(self):
        """Missing diagnosis text (None) returns 0, not an error."""
        self.assertEqual(_has_substance_use(None), 0)

    def test_case_insensitive(self):
        """Keyword matching ignores letter case."""
        self.assertEqual(_has_substance_use("alcohol withdrawal"), 1)
        self.assertEqual(_has_substance_use("Heroin Overdose"), 1)
self.assertEqual(schema["has_substance_use"], "tensor") + + def test_output_schema(self): + self.assertEqual(AMAPredictionMIMIC3.output_schema, {"ama": "binary"}) + + def test_defaults(self): + task = AMAPredictionMIMIC3() + self.assertTrue(task.exclude_newborns) + + +class TestAMAPredictionMIMIC3Mock(unittest.TestCase): + """Unit tests using purely synthetic mock patients. + + No real dataset is loaded. Each test builds mock Patient + objects in memory and runs the task callable directly. + All tests complete in milliseconds. + """ + + def setUp(self) -> None: + """Fresh task instance per test (default ``exclude_newborns=True``).""" + self.task = AMAPredictionMIMIC3() + + def _default_admission( + self, + hadm_id: str = "100", + **overrides: Any, + ) -> Dict[str, Any]: + """Admission kwargs for ``_build_patient`` with sane AMA-test defaults. + + Args: + hadm_id: Visit id string stored on the mock admission. + **overrides: Fields to replace (e.g. ``discharge_location=``). + + Returns: + Dict passed to ``_make_event`` via ``_build_patient`` admissions + list; keys mirror post-ingest MIMIC attribute names. 
+ """ + adm = { + "hadm_id": hadm_id, + "admission_type": "EMERGENCY", + "discharge_location": "HOME", + "ethnicity": "WHITE", + "insurance": "Private", + "dischtime": "2150-01-10 14:00:00", + "timestamp": datetime(2150, 1, 1), + "diagnosis": "PNEUMONIA", + } + adm.update(overrides) + return adm + + # ---------------------------------------------------------- + # Label generation + # ---------------------------------------------------------- + + def test_empty_patient(self): + patient = MagicMock() + patient.patient_id = "P0" + patient.get_events = lambda event_type, **kw: [] + self.assertEqual(self.task(patient), []) + + def test_ama_label_positive(self): + """AMA discharge -> label=1.""" + patient = _build_patient( + patient_id="P1", + admissions=[ + self._default_admission( + hadm_id="100", + discharge_location=("LEFT AGAINST MEDICAL ADVI"), + ), + ], + diagnoses=[{"hadm_id": "100", "icd9_code": "4019"}], + procedures=[{"hadm_id": "100", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "100", "drug": "Aspirin"}], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["ama"], 1) + self.assertEqual(samples[0]["visit_id"], "100") + self.assertEqual(samples[0]["patient_id"], "P1") + + def test_ama_label_negative(self): + """Non-AMA discharge -> label=0.""" + patient = _build_patient( + patient_id="P2", + admissions=[ + self._default_admission(hadm_id="200"), + ], + diagnoses=[{"hadm_id": "200", "icd9_code": "25000"}], + procedures=[{"hadm_id": "200", "icd9_code": "3995"}], + prescriptions=[{"hadm_id": "200", "drug": "Metformin"}], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["ama"], 0) + + def test_multiple_admissions_mixed_labels(self): + """Two admissions: one AMA, one not.""" + patient = _build_patient( + patient_id="P3", + admissions=[ + self._default_admission(hadm_id="300"), + self._default_admission( + hadm_id="301", + admission_type="URGENT", + 
discharge_location=("LEFT AGAINST MEDICAL ADVI"), + timestamp=datetime(2150, 6, 1), + dischtime="2150-06-05 10:00:00", + ), + ], + diagnoses=[ + {"hadm_id": "300", "icd9_code": "4019"}, + {"hadm_id": "301", "icd9_code": "30000"}, + ], + procedures=[ + {"hadm_id": "300", "icd9_code": "3893"}, + {"hadm_id": "301", "icd9_code": "9394"}, + ], + prescriptions=[ + {"hadm_id": "300", "drug": "Lisinopril"}, + {"hadm_id": "301", "drug": "Naloxone"}, + ], + ) + samples = self.task(patient) + self.assertEqual(len(samples), 2) + labels = {s["visit_id"]: s["ama"] for s in samples} + self.assertEqual(labels["300"], 0) + self.assertEqual(labels["301"], 1) + + # ---------------------------------------------------------- + # Filtering / edge cases + # ---------------------------------------------------------- + + def test_exclude_newborns(self): + """NEWBORN admissions skipped when flag is True.""" + patient = _build_patient( + patient_id="P7", + admissions=[ + self._default_admission( + hadm_id="700", + admission_type="NEWBORN", + ), + ], + diagnoses=[{"hadm_id": "700", "icd9_code": "V3000"}], + procedures=[{"hadm_id": "700", "icd9_code": "9904"}], + prescriptions=[{"hadm_id": "700", "drug": "Vitamin K"}], + ) + task_ex = AMAPredictionMIMIC3(exclude_newborns=True) + self.assertEqual(task_ex(patient), []) + + task_in = AMAPredictionMIMIC3(exclude_newborns=False) + samples = task_in(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["ama"], 0) + + def test_no_patients_events(self): + """Patient with admissions but no ``patients`` event.""" + patient = MagicMock() + patient.patient_id = "PX" + + def _get(event_type, **kw): + if event_type == "patients": + return [] + if event_type == "admissions": + return [_make_event(hadm_id="1")] + return [] + + patient.get_events = _get + self.assertEqual(self.task(patient), []) + + # ---------------------------------------------------------- + # Feature extraction + # 
    def test_demographics_baseline_tokens(self):
        """BASELINE demographics: gender + insurance, no race.

        Ethnicity is deliberately supplied on the admission but must only
        surface via the separate ``race`` feature, never inside the
        ``demographics`` token list.
        """
        patient = _build_patient(
            patient_id="P10",
            admissions=[
                self._default_admission(
                    hadm_id="1000",
                    ethnicity="BLACK/AFRICAN AMERICAN",
                    insurance="Medicaid",
                    timestamp=datetime(2150, 5, 1),
                ),
            ],
            diagnoses=[{"hadm_id": "1000", "icd9_code": "4019"}],
            procedures=[{"hadm_id": "1000", "icd9_code": "3893"}],
            prescriptions=[{"hadm_id": "1000", "drug": "Aspirin"}],
            gender="F",
        )
        samples = self.task(patient)
        demo = samples[0]["demographics"]
        self.assertIn("gender:F", demo)
        # Medicaid normalizes to the "Public" insurance category.
        self.assertIn("insurance:Public", demo)
        for token in demo:
            self.assertFalse(
                token.startswith("race:"),
                "race must not be in demographics",
            )
    def test_age_calculation(self):
        """Age computed from dob and admission timestamp.

        NOTE(review): DOB 2100-01-01 to admission 2150-06-15 is ~50.45
        calendar years; the exact match against ``[50.0]`` implies the
        task reports whole (truncated) years -- confirm in the task code.
        """
        patient = _build_patient(
            patient_id="P11",
            admissions=[
                self._default_admission(
                    hadm_id="1100",
                    timestamp=datetime(2150, 6, 15),
                    dischtime="2150-06-20 12:00:00",
                ),
            ],
            diagnoses=[{"hadm_id": "1100", "icd9_code": "4019"}],
            procedures=[{"hadm_id": "1100", "icd9_code": "3893"}],
            prescriptions=[{"hadm_id": "1100", "drug": "Aspirin"}],
            dob="2100-01-01 00:00:00",
        )
        samples = self.task(patient)
        self.assertEqual(samples[0]["age"], [50.0])
"3893"}], + prescriptions=[{"hadm_id": "1200", "drug": "Aspirin"}], + dob="1850-01-01 00:00:00", + ) + samples = self.task(patient) + self.assertEqual(samples[0]["age"], [90.0]) + + def test_los_calculation(self): + """LOS in days from admittime to dischtime.""" + patient = _build_patient( + patient_id="P13", + admissions=[ + self._default_admission( + hadm_id="1300", + timestamp=datetime(2150, 3, 1, 8, 0, 0), + dischtime="2150-03-06 08:00:00", + ), + ], + diagnoses=[{"hadm_id": "1300", "icd9_code": "4019"}], + procedures=[{"hadm_id": "1300", "icd9_code": "3893"}], + prescriptions=[{"hadm_id": "1300", "drug": "Aspirin"}], + ) + samples = self.task(patient) + self.assertAlmostEqual(samples[0]["los"][0], 5.0, places=2) + + # ---------------------------------------------------------- + # Multi-patient synthetic "dataset" (2 patients) + # ---------------------------------------------------------- + + def test_two_patient_synthetic_dataset(self): + """End-to-end with 2 synthetic patients, both labels.""" + p1 = _build_patient( + patient_id="S1", + admissions=[ + self._default_admission( + hadm_id="A1", + discharge_location=("LEFT AGAINST MEDICAL ADVI"), + ethnicity="HISPANIC OR LATINO", + insurance="Medicaid", + diagnosis="HEROIN OVERDOSE", + timestamp=datetime(2150, 1, 1), + dischtime="2150-01-03 12:00:00", + ), + ], + diagnoses=[{"hadm_id": "A1", "icd9_code": "96500"}], + procedures=[{"hadm_id": "A1", "icd9_code": "9604"}], + prescriptions=[{"hadm_id": "A1", "drug": "Naloxone"}], + gender="M", + dob="2100-06-15 00:00:00", + ) + p2 = _build_patient( + patient_id="S2", + admissions=[ + self._default_admission( + hadm_id="A2", + discharge_location="HOME", + ethnicity="WHITE", + insurance="Private", + diagnosis="CHEST PAIN", + timestamp=datetime(2150, 3, 10), + dischtime="2150-03-12 08:00:00", + ), + ], + diagnoses=[{"hadm_id": "A2", "icd9_code": "78650"}], + procedures=[{"hadm_id": "A2", "icd9_code": "8856"}], + prescriptions=[{"hadm_id": "A2", "drug": "Aspirin"}], + 
gender="F", + dob="2090-01-01 00:00:00", + ) + + all_samples = self.task(p1) + self.task(p2) + self.assertEqual(len(all_samples), 2) + + s1 = all_samples[0] + self.assertEqual(s1["ama"], 1) + self.assertEqual(s1["race"], ["race:Hispanic"]) + self.assertIn("insurance:Public", s1["demographics"]) + self.assertEqual(s1["has_substance_use"], [1.0]) + self.assertAlmostEqual(s1["age"][0], 49.0, places=0) + self.assertAlmostEqual(s1["los"][0], 2.5, places=1) + + s2 = all_samples[1] + self.assertEqual(s2["ama"], 0) + self.assertEqual(s2["race"], ["race:White"]) + self.assertIn("insurance:Private", s2["demographics"]) + self.assertEqual(s2["has_substance_use"], [0.0]) + self.assertAlmostEqual(s2["age"][0], 60.0, places=0) + + +class TestAMAAblationBaselines(unittest.TestCase): + """Tests for the three ablation study feature baselines. + + These tests verify that each baseline can be used to select + different subsets of features via the model's feature_keys parameter. + """ + + def setUp(self) -> None: + """Build one multi-visit mock patient covering baseline feature keys. + + Samples include all ``AMAPredictionMIMIC3.input_schema`` fields so + tests can reason about ``feature_keys`` subsets the same way models + do after ``set_task``. 
+ """ + self.task = AMAPredictionMIMIC3() + self.patient = _build_patient( + patient_id="ABLATION_TEST", + admissions=[ + { + "hadm_id": "A1", + "admission_type": "EMERGENCY", + "discharge_location": "HOME", + "ethnicity": "HISPANIC OR LATINO", + "insurance": "Medicaid", + "dischtime": "2150-01-10 14:00:00", + "timestamp": datetime(2150, 1, 1), + "diagnosis": "ALCOHOL WITHDRAWAL", + }, + { + "hadm_id": "A2", + "admission_type": "URGENT", + "discharge_location": "LEFT AGAINST MEDICAL ADVI", + "ethnicity": "WHITE", + "insurance": "Private", + "dischtime": "2150-06-05 10:00:00", + "timestamp": datetime(2150, 6, 1), + "diagnosis": "PNEUMONIA", + }, + ], + diagnoses=[ + {"hadm_id": "A1", "icd9_code": "29181"}, + {"hadm_id": "A2", "icd9_code": "486"}, + ], + procedures=[ + {"hadm_id": "A1", "icd9_code": "3893"}, + {"hadm_id": "A2", "icd9_code": "9604"}, + ], + prescriptions=[ + {"hadm_id": "A1", "drug": "Lorazepam"}, + {"hadm_id": "A2", "drug": "Levofloxacin"}, + ], + gender="M", + dob="2100-01-01 00:00:00", + ) + + def test_baseline_features_present(self): + """BASELINE includes demographics, age, los.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + sample = samples[0] + self.assertIn("demographics", sample) + self.assertIn("age", sample) + self.assertIn("los", sample) + self.assertTrue(isinstance(sample["age"][0], float)) + self.assertTrue(isinstance(sample["los"][0], float)) + self.assertTrue(isinstance(sample["demographics"], list)) + + def test_baseline_race_feature_present(self): + """BASELINE + RACE adds race feature.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + for sample in samples: + self.assertIn("race", sample) + self.assertTrue(isinstance(sample["race"], list)) + race_val = sample["race"][0].split(":", 1)[1] + self.assertIn( + race_val, + ["White", "Black", "Hispanic", "Asian", "Native American", "Other"], + ) + + def test_baseline_substance_use_feature_present(self): + 
"""BASELINE + RACE + SUBSTANCE adds has_substance_use.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + for sample in samples: + self.assertIn("has_substance_use", sample) + self.assertTrue(isinstance(sample["has_substance_use"], list)) + self.assertIn(sample["has_substance_use"][0], [0.0, 1.0]) + + def test_substance_use_detection_in_ablation(self): + """Verify substance use detection for ablation patient.""" + samples = self.task(self.patient) + + s1 = next(s for s in samples if s["visit_id"] == "A1") + self.assertEqual(s1["has_substance_use"], [1.0]) + + s2 = next(s for s in samples if s["visit_id"] == "A2") + self.assertEqual(s2["has_substance_use"], [0.0]) + + def test_race_normalization_in_ablation(self): + """Verify race normalization for ablation patient.""" + samples = self.task(self.patient) + + s1 = next(s for s in samples if s["visit_id"] == "A1") + self.assertEqual(s1["race"], ["race:Hispanic"]) + + s2 = next(s for s in samples if s["visit_id"] == "A2") + self.assertEqual(s2["race"], ["race:White"]) + + def test_age_and_los_computed(self): + """Verify age and LOS are computed correctly.""" + samples = self.task(self.patient) + + for sample in samples: + age = sample["age"][0] + los = sample["los"][0] + self.assertAlmostEqual(age, 50.0, places=1) + self.assertGreater(los, 0.0) + + def test_demographics_includes_gender_and_insurance(self): + """BASELINE demographics include gender and insurance.""" + samples = self.task(self.patient) + + for sample in samples: + demo = sample["demographics"] + has_gender = any(t.startswith("gender:") for t in demo) + has_insurance = any(t.startswith("insurance:") for t in demo) + self.assertTrue(has_gender) + self.assertTrue(has_insurance) + + def test_insurance_normalization_in_ablation(self): + """Verify insurance normalization (Medicaid -> Public).""" + samples = self.task(self.patient) + + s1 = next(s for s in samples if s["visit_id"] == "A1") + demo1 = s1["demographics"] + 
self.assertIn("insurance:Public", demo1) + + s2 = next(s for s in samples if s["visit_id"] == "A2") + demo2 = s2["demographics"] + self.assertIn("insurance:Private", demo2) + + def test_label_correctness_in_ablation(self): + """Verify AMA label is correct.""" + samples = self.task(self.patient) + + s1 = next(s for s in samples if s["visit_id"] == "A1") + self.assertEqual(s1["ama"], 0) + + s2 = next(s for s in samples if s["visit_id"] == "A2") + self.assertEqual(s2["ama"], 1) + + def test_baseline_minimal_features(self): + """BASELINE (minimal) has only required keys.""" + samples = self.task(self.patient) + self.assertGreaterEqual(len(samples), 1) + + sample = samples[0] + baseline_keys = { + "demographics", + "age", + "los", + "race", + "has_substance_use", + "visit_id", + "patient_id", + "ama", + } + self.assertEqual(set(sample.keys()), baseline_keys) + + def test_multiple_admissions_all_included(self): + """All non-newborn admissions are included (no filtering).""" + samples = self.task(self.patient) + self.assertEqual(len(samples), 2) + + visit_ids = {s["visit_id"] for s in samples} + self.assertEqual(visit_ids, {"A1", "A2"}) + + def test_ablation_patient_no_clinical_codes(self): + """Ablation samples do not contain clinical code fields.""" + samples = self.task(self.patient) + + for sample in samples: + self.assertNotIn("conditions", sample) + self.assertNotIn("procedures", sample) + self.assertNotIn("drugs", sample) + + +# ------------------------------------------------------------------ +# Integration tests using shared curated 5-row dataset +# ------------------------------------------------------------------ + + +class TestAMAWithSyntheticData(unittest.TestCase): + """AMA task on curated minimal synthetic CSVs (fast pipeline checks).""" + + def test_dataset_loads_successfully(self): + self.assertIsNotNone(_shared_dataset) + self.assertGreater(len(_shared_sample_dataset), 0) + + def test_samples_have_expected_features(self): + sample = 
    def _create_model_with_features(self, feature_keys):
        """Build a ``LogisticRegression`` restricted to ``feature_keys``.

        Args:
            feature_keys: Ordered subset of the task input features the
                ablation configuration should use.

        Returns:
            Model whose ``feature_keys`` and final linear layer match the
            reduced feature set.
        """
        model = LogisticRegression(
            dataset=_shared_sample_dataset,
            embedding_dim=64,
        )
        # Restrict the model to the ablation's feature subset.
        model.feature_keys = list(feature_keys)
        output_size = model.get_output_size()
        # NOTE(review): reads the embedding width from the FIRST feature's
        # layer and applies it to all -- assumes every embedding layer
        # shares the same out_features; confirm against the embedding
        # model implementation.
        embedding_dim = model.embedding_model.embedding_layers[
            feature_keys[0]
        ].out_features
        # Rebuild the classifier head so its input width matches the
        # reduced number of concatenated feature embeddings.
        model.fc = torch.nn.Linear(len(feature_keys) * embedding_dim, output_size)
        return model
test_ds = split_by_patient( + _shared_sample_dataset, [0.8, 0.0, 0.2], seed=0 + ) + test_dl = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model.eval() + with torch.no_grad(): + batch = next(iter(test_dl)) + output = model(**batch) + + self.assertIn("y_prob", output) + self.assertEqual(output["y_prob"].shape[0], len(test_ds)) + + +class TestAMATrainingSpeed(unittest.TestCase): + """Short training runs on tiny synthetic data.""" + + def test_training_completes_quickly(self): + import time + + train_ds, _, test_ds = split_by_patient( + _shared_sample_dataset, [0.6, 0.0, 0.4], seed=0 + ) + train_dl = get_dataloader(train_ds, batch_size=8, shuffle=True) + + model = LogisticRegression( + dataset=_shared_sample_dataset, + embedding_dim=64, + ) + model.feature_keys = ["demographics", "age", "los"] + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + "demographics" + ].out_features + model.fc = torch.nn.Linear(3 * embedding_dim, output_size) + + trainer = Trainer(model=model) + + t0 = time.time() + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=1, + monitor=None, + ) + elapsed = time.time() - t0 + + self.assertGreater(elapsed, 0, "Training should take some time") + + def test_multiple_splits_complete_quickly(self): + for split_seed in range(2): + train_ds, _, _ = split_by_patient( + _shared_sample_dataset, + [0.6, 0.0, 0.4], + seed=split_seed, + ) + train_dl = get_dataloader(train_ds, batch_size=8, shuffle=True) + + model = LogisticRegression( + dataset=_shared_sample_dataset, + embedding_dim=64, + ) + model.feature_keys = ["demographics", "age", "los"] + output_size = model.get_output_size() + embedding_dim = model.embedding_model.embedding_layers[ + "demographics" + ].out_features + model.fc = torch.nn.Linear(3 * embedding_dim, output_size) + + trainer = Trainer(model=model) + trainer.train( + train_dataloader=train_dl, + val_dataloader=None, + epochs=1, + monitor=None, + ) + + 
class TestExhaustiveSyntheticGrid(unittest.TestCase):
    """Sanity-check exhaustive synthetic generator (row counts only)."""

    def test_patient_row_count_matches_cross_product(self):
        """PATIENTS row count equals ``EXHAUSTIVE_PATIENT_ROWS``.

        Generates the exhaustive grid into a throwaway directory and
        counts gzipped CSV data rows (total lines minus the header).
        """
        tmp = tempfile.mkdtemp(prefix="ama_exhaustive_")
        try:
            # NOTE(review): ``n_patients`` is presumably ignored in
            # exhaustive mode (grid size is fixed by the cross product) --
            # confirm against ``generate_synthetic_mimic3``.
            generate_synthetic_mimic3(
                tmp,
                mode="exhaustive",
                seed=0,
                n_patients=1,
            )
            with gzip.open(Path(tmp) / "PATIENTS.csv.gz", "rt") as f:
                lines = f.readlines()
                data_rows = len(lines) - 1  # minus the CSV header line
            self.assertEqual(data_rows, EXHAUSTIVE_PATIENT_ROWS)
        finally:
            # Always remove the temp dir, even if the assertion fails.
            shutil.rmtree(tmp, ignore_errors=True)