Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ Available Tasks
Mortality Prediction (StageNet MIMIC-IV) <tasks/pyhealth.tasks.mortality_prediction_stagenet_mimic4>
Patient Linkage (MIMIC-III) <tasks/pyhealth.tasks.patient_linkage_mimic3_fn>
Readmission Prediction <tasks/pyhealth.tasks.readmission_prediction>
ER-Specific Readmission (MIMIC-IV) <tasks/pyhealth.tasks.mimic4_er_readmission>
Sleep Staging <tasks/pyhealth.tasks.sleep_staging>
Sleep Staging (SleepEDF) <tasks/pyhealth.tasks.SleepStagingSleepEDF>
Temple University EEG Tasks <tasks/pyhealth.tasks.temple_university_EEG_tasks>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.mimic4_er_readmission.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.mimic4_er_readmission
======================================

.. autoclass:: pyhealth.tasks.mimic4_er_readmission.ERReadmissionMIMIC4
:members:
:undoc-members:
:show-inheritance:
137 changes: 137 additions & 0 deletions examples/mimic4_er_readmission_retain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Name: Ranjithkumar Rajendran
NetID: rr54
Paper: KEEP (CHIL 2025) — Elhussein et al.

Ablation 1 — Task Comparison.
Compares standard inpatient readmission
(ReadmissionPredictionMIMIC4) vs ER-specific
readmission (ERReadmissionMIMIC4) using RETAIN.
"""
from pyhealth.datasets import (
MIMIC4EHRDataset,
split_by_patient,
get_dataloader,
)
from pyhealth.tasks import (
ERReadmissionMIMIC4,
ReadmissionPredictionMIMIC4,
)
from pyhealth.models import RETAIN
from pyhealth.trainer import Trainer
import math


def _fmt(v):
"""Format a metric, showing 'n/a' for NaN."""
return "n/a" if math.isnan(v) else f"{v:.4f}"


def _print_metrics(name, m):
"""Print ROC-AUC and PR-AUC for a model."""
print(f"{name} ROC-AUC: {_fmt(m['roc_auc'])}")
print(f"{name} PR-AUC : {_fmt(m['pr_auc'])}")


def main():
"""Run the Task-Comparison ablation."""
print("Loading Dataset ...")
# Point this to your MIMIC-IV root directory.
# e.g. "/content/drive/MyDrive/mimic-iv/2.2"
dataset = MIMIC4EHRDataset(
root="/path/to/mimic-iv-2.2",
tables=[
"diagnoses_icd",
"procedures_icd",
"prescriptions",
],
dev=True,
)

# --- Task 1: Standard Inpatient Readmission ----
print("\n[Ablation] Task 1: Standard Readmission")
ds_std = dataset.set_task(
ReadmissionPredictionMIMIC4()
)

# --- Task 2: ER-Specific Readmission -----------
print("\n[Ablation] Task 2: ER Readmission")
ds_er = dataset.set_task(ERReadmissionMIMIC4())

print(f"\nStandard samples : {len(ds_std)}")
print(f"ER-Specific samples: {len(ds_er)}")

# --- Initialise models -------------------------
print("\nInitializing RETAIN on both cohorts ...")
model_std = RETAIN(dataset=ds_std)
print(" -> Standard task: OK")
model_er = RETAIN(dataset=ds_er)
print(" -> ER task : OK")

# --- Split + Dataloaders -----------------------
print("\n--- Splitting data ---")
tr_s, va_s, te_s = split_by_patient(
ds_std, [0.8, 0.1, 0.1]
)
tr_e, va_e, te_e = split_by_patient(
ds_er, [0.8, 0.1, 0.1]
)

if len(va_s) == 0 or len(va_e) == 0:
print(
"Val set is empty (tiny synthetic data).\n"
"Pipeline verified — skipping Trainer."
)
return

dl = get_dataloader # alias for brevity
tr_l_s = dl(tr_s, batch_size=64, shuffle=True)
va_l_s = dl(va_s, batch_size=64, shuffle=False)
te_l_s = dl(te_s, batch_size=64, shuffle=False)

tr_l_e = dl(tr_e, batch_size=64, shuffle=True)
va_l_e = dl(va_e, batch_size=64, shuffle=False)
te_l_e = dl(te_e, batch_size=64, shuffle=False)

# --- Train Standard ----------------------------
print("\n--- Training: Standard Readmission ---")
t_std = Trainer(model=model_std)
t_std.train(
train_dataloader=tr_l_s,
val_dataloader=va_l_s,
epochs=10,
monitor="pr_auc",
)
m_std = t_std.evaluate(te_l_s)
_print_metrics("Standard", m_std)

# --- Train ER ----------------------------------
print("\n--- Training: ER Readmission ---")
t_er = Trainer(model=model_er)
t_er.train(
train_dataloader=tr_l_e,
val_dataloader=va_l_e,
epochs=10,
monitor="pr_auc",
)
m_er = t_er.evaluate(te_l_e)
_print_metrics("ER", m_er)

# --- Compare -----------------------------------
s = m_std["pr_auc"]
e = m_er["pr_auc"]
if math.isnan(s) or math.isnan(e):
print("\nAblation note: PR-AUC undefined "
"on this tiny split (expected).")
else:
d = s - e
print(
f"\nAblation result: ER cohort "
f"PR-AUC is {d * 100:.2f}% "
f"{'lower' if d > 0 else 'higher'} "
f"than standard."
)


if __name__ == "__main__":
main()
116 changes: 116 additions & 0 deletions examples/mimic4_er_readmission_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Name: Ranjithkumar Rajendran
NetID: rr54
Paper: KEEP (CHIL 2025) — Elhussein et al.

Ablation 2 — Architecture Comparison.
Compares Transformer vs RETAIN on the new
ERReadmissionMIMIC4 task.
"""
from pyhealth.datasets import (
MIMIC4EHRDataset,
split_by_patient,
get_dataloader,
)
from pyhealth.tasks import ERReadmissionMIMIC4
from pyhealth.models import Transformer, RETAIN
from pyhealth.trainer import Trainer
import math


def _fmt(v):
"""Format a metric, showing 'n/a' for NaN."""
return "n/a" if math.isnan(v) else f"{v:.4f}"


def _print_metrics(name, m):
"""Print ROC-AUC and PR-AUC for a model."""
print(f"{name} ROC-AUC: {_fmt(m['roc_auc'])}")
print(f"{name} PR-AUC : {_fmt(m['pr_auc'])}")


def main():
"""Run the Architecture-Comparison ablation."""
print("Loading Dataset ...")
# Point this to your MIMIC-IV root directory.
dataset = MIMIC4EHRDataset(
root="/path/to/mimic-iv-2.2",
tables=["diagnoses_icd"],
dev=True,
)

print("\nApplying ER-Specific Readmission Task ...")
ds_er = dataset.set_task(ERReadmissionMIMIC4())
print(f"ER samples: {len(ds_er)}")

# --- Initialise both architectures -------------
print("\n[Ablation] Architecture 1: RETAIN")
model_ret = RETAIN(dataset=ds_er)
print(" -> RETAIN OK")

print("\n[Ablation] Architecture 2: Transformer")
model_tfm = Transformer(dataset=ds_er)
print(" -> Transformer OK")

# --- Split + Dataloaders -----------------------
print("\n--- Splitting data ---")
tr, va, te = split_by_patient(
ds_er, [0.8, 0.1, 0.1]
)

if len(va) == 0:
print(
"Val set is empty (tiny synthetic data).\n"
"Pipeline verified — skipping Trainer."
)
return

dl = get_dataloader
tr_l = dl(tr, batch_size=64, shuffle=True)
va_l = dl(va, batch_size=64, shuffle=False)
te_l = dl(te, batch_size=64, shuffle=False)

# --- Train RETAIN ------------------------------
print("\n--- Training: RETAIN ---")
t_ret = Trainer(model=model_ret)
t_ret.train(
train_dataloader=tr_l,
val_dataloader=va_l,
epochs=10,
monitor="pr_auc",
)
m_ret = t_ret.evaluate(te_l)
_print_metrics("RETAIN", m_ret)

# --- Train Transformer -------------------------
print("\n--- Training: Transformer ---")
t_tfm = Trainer(model=model_tfm)
t_tfm.train(
train_dataloader=tr_l,
val_dataloader=va_l,
epochs=10,
monitor="pr_auc",
)
m_tfm = t_tfm.evaluate(te_l)
_print_metrics("Transformer", m_tfm)

# --- Compare -----------------------------------
r = m_ret["pr_auc"]
t = m_tfm["pr_auc"]
if math.isnan(r) or math.isnan(t):
print(
"\nAblation note: PR-AUC undefined "
"on this tiny split (expected)."
)
else:
d = t - r
print(
f"\nAblation result: Transformer "
f"PR-AUC is {d * 100:.2f}% "
f"{'higher' if d > 0 else 'lower'}"
f" than RETAIN."
)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4
from .medical_coding import MIMIC3ICD9Coding
from .medical_transcriptions_classification import MedicalTranscriptionsClassification
from .mimic4_er_readmission import ERReadmissionMIMIC4
from .mortality_prediction import (
MortalityPredictionEICU,
MortalityPredictionEICU2,
Expand Down
Loading