1 change: 1 addition & 0 deletions docs/api/tasks.rst
@@ -79,6 +79,7 @@ Available Tasks
MIMIC-III ICD-9 Coding <tasks/pyhealth.tasks.MIMIC3ICD9Coding>
Cardiology Detection <tasks/pyhealth.tasks.cardiology_detect>
COVID-19 CXR Classification <tasks/pyhealth.tasks.COVID19CXRClassification>
DKA Prediction (MIMIC-IV) <tasks/pyhealth.tasks.dka>
Drug Recommendation <tasks/pyhealth.tasks.drug_recommendation>
EEG Abnormal <tasks/pyhealth.tasks.EEG_abnormal>
EEG Events <tasks/pyhealth.tasks.EEG_events>
8 changes: 8 additions & 0 deletions docs/api/tasks/pyhealth.tasks.dka.rst
@@ -0,0 +1,8 @@
pyhealth.tasks.dka
==================

.. autoclass:: pyhealth.tasks.dka.DKAPredictionMIMIC4
:members:
:undoc-members:
:show-inheritance:
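
A minimal usage sketch (paths are illustrative; see
``examples/clinical_tasks/dka_mimic4.py`` for the full pipeline):

.. code-block:: python

    from pyhealth.datasets import MIMIC4Dataset
    from pyhealth.tasks import DKAPredictionMIMIC4

    # Load the MIMIC-IV tables the task reads from.
    base_dataset = MIMIC4Dataset(
        ehr_root="/path/to/mimiciv/2.2/",
        ehr_tables=["admissions", "diagnoses_icd", "procedures_icd", "labevents"],
    )
    # Apply the general-population DKA prediction task.
    sample_dataset = base_dataset.set_task(DKAPredictionMIMIC4(padding=10))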

188 changes: 188 additions & 0 deletions examples/benchmark_perf/benchmark_workers_12.py
@@ -0,0 +1,188 @@
"""Benchmark script for MIMIC-IV mortality prediction with num_workers=4.

This benchmark measures:
1. Time to load base dataset
2. Time to process task with num_workers=12
3. Total processing time
4. Cache sizes
5. Peak memory usage (with optional memory limit)
"""

import time
import os
import threading
from pathlib import Path
import psutil
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4

try:
import resource

HAS_RESOURCE = True
except ImportError:
HAS_RESOURCE = False


PEAK_MEM_USAGE = 0
SELF_PROC = psutil.Process(os.getpid())


def track_mem():
"""Background thread to track peak memory usage."""
global PEAK_MEM_USAGE
while True:
m = SELF_PROC.memory_info().rss
if m > PEAK_MEM_USAGE:
PEAK_MEM_USAGE = m
time.sleep(0.1)


def set_memory_limit(max_memory_gb):
"""Set hard memory limit for the process.

Args:
max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB)

Note:
        If the limit is exceeded, the process will raise MemoryError.
Only works on Unix-like systems (Linux, macOS).
"""
if not HAS_RESOURCE:
print(
"Warning: resource module not available (Windows?). "
"Memory limit not enforced."
)
return

max_memory_bytes = int(max_memory_gb * 1024**3)
try:
resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes))
print(f"✓ Memory limit set to {max_memory_gb} GB")
except Exception as e:
print(f"Warning: Failed to set memory limit: {e}")


def get_directory_size(path):
"""Calculate total size of a directory in bytes."""
total = 0
try:
for entry in Path(path).rglob("*"):
if entry.is_file():
total += entry.stat().st_size
except Exception as e:
print(f"Error calculating size for {path}: {e}")
return total


def format_size(size_bytes):
"""Format bytes to human-readable size."""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if size_bytes < 1024.0:
return f"{size_bytes:.2f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.2f} PB"


def main():
"""Main benchmark function."""
# Configuration
dev = False # Set to True for development/testing
enable_memory_limit = False # Set to True to enforce memory limit
max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True)

# Apply memory limit if enabled
if enable_memory_limit:
set_memory_limit(max_memory_gb)

# Start memory tracking thread
mem_thread = threading.Thread(target=track_mem, daemon=True)
mem_thread.start()

print("=" * 80)
print(f"BENCHMARK: num_workers=4, dev={dev}")
if enable_memory_limit:
print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)")
else:
print("Memory Limit: None (unrestricted)")
print("=" * 80)

# Define cache directories based on dev mode
cache_root = "/shared/rsaas/pyhealth/"
if dev:
cache_root += "_dev"

# Track total time
    total_start = time.time()

    # STEP 1: Load MIMIC-IV base dataset
    print("\n[1/3] Loading MIMIC-IV base dataset...")
dataset_start = time.time()

base_dataset = MIMIC4Dataset(
ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
ehr_tables=[
"patients",
"admissions",
"diagnoses_icd",
"procedures_icd",
"labevents",
],
dev=dev,
cache_dir=f"{cache_root}/base_dataset",
)

dataset_time = time.time() - dataset_start
print(f"✓ Dataset loaded in {dataset_time:.2f} seconds")

# STEP 2: Apply StageNet mortality prediction task with num_workers=12
print("\n[2/2] Applying mortality prediction task (num_workers=12)...")
task_start = time.time()

sample_dataset = base_dataset.set_task(
MortalityPredictionStageNetMIMIC4(),
num_workers=12,
cache_dir=f"{cache_root}/task_samples",
)

task_time = time.time() - task_start
print(f"✓ Task processing completed in {task_time:.2f} seconds")

# Measure cache sizes
print("\n[3/3] Measuring cache sizes...")
base_cache_dir = f"{cache_root}/base_dataset"
task_cache_dir = f"{cache_root}/task_samples"

base_cache_size = get_directory_size(base_cache_dir)
task_cache_size = get_directory_size(task_cache_dir)
total_cache_size = base_cache_size + task_cache_size

print(f"✓ Base dataset cache: {format_size(base_cache_size)}")
print(f"✓ Task samples cache: {format_size(task_cache_size)}")
print(f"✓ Total cache size: {format_size(total_cache_size)}")

# Total time and peak memory
total_time = time.time() - total_start
peak_mem = PEAK_MEM_USAGE

# Print summary
print("\n" + "=" * 80)
print("BENCHMARK RESULTS")
print("=" * 80)
print("Configuration:")
print(" - num_workers: 4")
print(f" - dev mode: {dev}")
print(f" - Total samples: {len(sample_dataset)}")
print("\nTiming:")
print(f" - Dataset loading: {dataset_time:.2f}s")
print(f" - Task processing: {task_time:.2f}s")
print(f" - Total time: {total_time:.2f}s")
print("\nCache Sizes:")
print(f" - Base dataset cache: {format_size(base_cache_size)}")
print(f" - Task samples cache: {format_size(task_cache_size)}")
print(f" - Total cache: {format_size(total_cache_size)}")
print("\nMemory:")
print(f" - Peak memory usage: {format_size(peak_mem)}")
print("=" * 80)


if __name__ == "__main__":
main()
186 changes: 186 additions & 0 deletions examples/clinical_tasks/dka_mimic4.py
@@ -0,0 +1,186 @@
"""
Example of using StageNet for DKA (Diabetic Ketoacidosis) prediction on MIMIC-IV.

This example demonstrates:
1. Loading MIMIC-IV data with relevant tables for DKA prediction
2. Applying the DKAPredictionMIMIC4 task (general population)
3. Creating a SampleDataset with StageNet processors
4. Training a StageNet model for DKA prediction

Target Population:
- ALL patients in MIMIC-IV (no diabetes filtering)
- Much larger patient pool with more negative samples
- Label: 1 if patient has ANY DKA diagnosis, 0 otherwise

Note: For T1DM-specific DKA prediction, see t1dka_mimic4.py
"""

import os
import torch

from pyhealth.datasets import (
MIMIC4Dataset,
get_dataloader,
split_by_patient,
)
from pyhealth.datasets.utils import save_processors, load_processors
from pyhealth.models import StageNet
from pyhealth.tasks import DKAPredictionMIMIC4
from pyhealth.trainer import Trainer


def main():
"""Main function to run DKA prediction pipeline on general population."""

# Configuration
MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/"
DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset"
TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dka_general_stagenet"
PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_dka_general_mimic4"
DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu"

print("=" * 60)
print("DKA PREDICTION (GENERAL POPULATION) WITH STAGENET ON MIMIC-IV")
print("=" * 60)

# STEP 1: Load MIMIC-IV base dataset
print("\n=== Step 1: Loading MIMIC-IV Dataset ===")
base_dataset = MIMIC4Dataset(
ehr_root=MIMIC4_ROOT,
ehr_tables=[
"admissions",
"diagnoses_icd",
"procedures_icd",
"labevents",
],
cache_dir=DATASET_CACHE_DIR,
# dev=True, # Uncomment for faster development iteration
)

print("Dataset initialized, proceeding to task processing...")

# STEP 2: Apply DKA prediction task (general population)
print("\n=== Step 2: Applying DKA Prediction Task (General Population) ===")

# Create task with padding for unseen sequences
# No T1DM filtering - includes ALL patients
dka_task = DKAPredictionMIMIC4(padding=10)

print(f"Task: {dka_task.task_name}")
print(f"Input schema: {list(dka_task.input_schema.keys())}")
print(f"Output schema: {list(dka_task.output_schema.keys())}")
print("Note: This includes ALL patients (not just diabetics)")

# Check for pre-fitted processors
if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")):
print("\nLoading pre-fitted processors...")
input_processors, output_processors = load_processors(PROCESSOR_DIR)

sample_dataset = base_dataset.set_task(
dka_task,
num_workers=4,
cache_dir=TASK_CACHE_DIR,
input_processors=input_processors,
output_processors=output_processors,
)
else:
print("\nFitting new processors...")
sample_dataset = base_dataset.set_task(
dka_task,
num_workers=4,
cache_dir=TASK_CACHE_DIR,
)

# Save processors for future runs
print("Saving processors...")
os.makedirs(PROCESSOR_DIR, exist_ok=True)
save_processors(sample_dataset, PROCESSOR_DIR)

print(f"\nTotal samples: {len(sample_dataset)}")

# Count label distribution
label_counts = {0: 0, 1: 0}
for sample in sample_dataset:
label_counts[int(sample["label"].item())] += 1

print(f"Label distribution:")
print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)")
print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)")

# Inspect a sample
sample = sample_dataset[0]
print("\nSample structure:")
print(f" Patient ID: {sample['patient_id']}")
print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)")
print(f" Labs: {sample['labs'][0].shape} (timesteps x features)")
print(f" Label: {sample['label']}")

# STEP 3: Split dataset
print("\n=== Step 3: Splitting Dataset ===")
train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)

print(f"Train: {len(train_dataset)} samples")
print(f"Validation: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

# Create dataloaders
train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# STEP 4: Initialize StageNet model
print("\n=== Step 4: Initializing StageNet Model ===")
model = StageNet(
dataset=sample_dataset,
embedding_dim=128,
chunk_size=128,
levels=3,
dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

# STEP 5: Train the model
print("\n=== Step 5: Training Model ===")
trainer = Trainer(
model=model,
device=DEVICE,
metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=50,
monitor="roc_auc",
optimizer_params={"lr": 1e-5},
)

# STEP 6: Evaluate on test set
print("\n=== Step 6: Evaluation ===")
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
print(f" {metric}: {value:.4f}")

# STEP 7: Inspect model predictions
print("\n=== Step 7: Sample Predictions ===")
sample_batch = next(iter(test_loader))
    model.eval()
    with torch.no_grad():
output = model(**sample_batch)

print(f"Predicted probabilities: {output['y_prob'][:5]}")
print(f"True labels: {output['y_true'][:5]}")

print("\n" + "=" * 60)
print("DKA PREDICTION (GENERAL POPULATION) TRAINING COMPLETED!")
print("=" * 60)

return results


if __name__ == "__main__":
main()