-
Notifications
You must be signed in to change notification settings - Fork 559
Add/dka task #749
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jhnwu3
wants to merge
10
commits into
master
Choose a base branch
from
add/dka_task
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add/dka task #749
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
81ae4b0
init commit, will need to debug and clean later
jhnwu3 d1a572c
make example simpelr
jhnwu3 1a99bf7
new updates to general population dka prediction
jhnwu3 9190f95
more commits for improving the robustness of task processing
jhnwu3 780c5c4
Merge branch 'master' into add/dka_task
jhnwu3 67551fc
type hint change
Logiquo d120df6
Merge remote-tracking branch 'upstream/master' into add/dka_task
Logiquo c14ddb9
ensure T1DDKAPredictionMIMIC4 does not leak data through visit length
Logiquo 4bca279
Chnage T1DDKA pre_filter to be lazy-eval
Logiquo a10ffae
Fix duplicate labevents column definition
Logiquo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| pyhealth.tasks.dka | ||
| ================== | ||
|
|
||
| .. autoclass:: pyhealth.tasks.dka.DKAPredictionMIMIC4 | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| """Benchmark script for MIMIC-IV mortality prediction with num_workers=4. | ||
|
|
||
| This benchmark measures: | ||
| 1. Time to load base dataset | ||
| 2. Time to process task with num_workers=4 | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| 3. Total processing time | ||
| 4. Cache sizes | ||
| 5. Peak memory usage (with optional memory limit) | ||
| """ | ||
|
|
||
| import time | ||
| import os | ||
| import threading | ||
| from pathlib import Path | ||
| import psutil | ||
| from pyhealth.datasets import MIMIC4Dataset | ||
| from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 | ||
|
|
||
| try: | ||
| import resource | ||
|
|
||
| HAS_RESOURCE = True | ||
| except ImportError: | ||
| HAS_RESOURCE = False | ||
|
|
||
|
|
||
| PEAK_MEM_USAGE = 0 | ||
| SELF_PROC = psutil.Process(os.getpid()) | ||
|
|
||
|
|
||
| def track_mem(): | ||
| """Background thread to track peak memory usage.""" | ||
| global PEAK_MEM_USAGE | ||
| while True: | ||
| m = SELF_PROC.memory_info().rss | ||
| if m > PEAK_MEM_USAGE: | ||
| PEAK_MEM_USAGE = m | ||
| time.sleep(0.1) | ||
|
|
||
|
|
||
| def set_memory_limit(max_memory_gb): | ||
| """Set hard memory limit for the process. | ||
|
|
||
| Args: | ||
| max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB) | ||
|
|
||
| Note: | ||
| If limit is exceeded, the process will raise MemoryError. | ||
| Only works on Unix-like systems (Linux, macOS). | ||
| """ | ||
| if not HAS_RESOURCE: | ||
| print( | ||
| "Warning: resource module not available (Windows?). " | ||
| "Memory limit not enforced." | ||
| ) | ||
| return | ||
|
|
||
| max_memory_bytes = int(max_memory_gb * 1024**3) | ||
| try: | ||
| resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) | ||
| print(f"✓ Memory limit set to {max_memory_gb} GB") | ||
| except Exception as e: | ||
| print(f"Warning: Failed to set memory limit: {e}") | ||
|
|
||
|
|
||
| def get_directory_size(path): | ||
| """Calculate total size of a directory in bytes.""" | ||
| total = 0 | ||
| try: | ||
| for entry in Path(path).rglob("*"): | ||
| if entry.is_file(): | ||
| total += entry.stat().st_size | ||
| except Exception as e: | ||
| print(f"Error calculating size for {path}: {e}") | ||
| return total | ||
|
|
||
|
|
||
| def format_size(size_bytes): | ||
| """Format bytes to human-readable size.""" | ||
| for unit in ["B", "KB", "MB", "GB", "TB"]: | ||
| if size_bytes < 1024.0: | ||
| return f"{size_bytes:.2f} {unit}" | ||
| size_bytes /= 1024.0 | ||
| return f"{size_bytes:.2f} PB" | ||
|
|
||
|
|
||
| def main(): | ||
| """Main benchmark function.""" | ||
| # Configuration | ||
| dev = False # Set to True for development/testing | ||
| enable_memory_limit = False # Set to True to enforce memory limit | ||
| max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True) | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Apply memory limit if enabled | ||
| if enable_memory_limit: | ||
| set_memory_limit(max_memory_gb) | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Start memory tracking thread | ||
| mem_thread = threading.Thread(target=track_mem, daemon=True) | ||
| mem_thread.start() | ||
|
|
||
| print("=" * 80) | ||
| print(f"BENCHMARK: num_workers=4, dev={dev}") | ||
| if enable_memory_limit: | ||
| print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)") | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| print("Memory Limit: None (unrestricted)") | ||
| print("=" * 80) | ||
|
|
||
| # Define cache directories based on dev mode | ||
| cache_root = "/shared/rsaas/pyhealth/" | ||
| if dev: | ||
| cache_root += "_dev" | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Track total time | ||
| total_start = time.time() # STEP 1: Load MIMIC-IV base dataset | ||
| print("\n[1/2] Loading MIMIC-IV base dataset...") | ||
| dataset_start = time.time() | ||
|
|
||
| base_dataset = MIMIC4Dataset( | ||
| ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", | ||
| ehr_tables=[ | ||
| "patients", | ||
| "admissions", | ||
| "diagnoses_icd", | ||
| "procedures_icd", | ||
| "labevents", | ||
| ], | ||
| dev=dev, | ||
| cache_dir=f"{cache_root}/base_dataset", | ||
| ) | ||
|
|
||
| dataset_time = time.time() - dataset_start | ||
| print(f"✓ Dataset loaded in {dataset_time:.2f} seconds") | ||
|
|
||
| # STEP 2: Apply StageNet mortality prediction task with num_workers=12 | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| print("\n[2/2] Applying mortality prediction task (num_workers=12)...") | ||
| task_start = time.time() | ||
|
|
||
| sample_dataset = base_dataset.set_task( | ||
| MortalityPredictionStageNetMIMIC4(), | ||
| num_workers=12, | ||
| cache_dir=f"{cache_root}/task_samples", | ||
| ) | ||
|
|
||
| task_time = time.time() - task_start | ||
| print(f"✓ Task processing completed in {task_time:.2f} seconds") | ||
|
|
||
| # Measure cache sizes | ||
| print("\n[3/3] Measuring cache sizes...") | ||
| base_cache_dir = f"{cache_root}/base_dataset" | ||
| task_cache_dir = f"{cache_root}/task_samples" | ||
|
|
||
| base_cache_size = get_directory_size(base_cache_dir) | ||
| task_cache_size = get_directory_size(task_cache_dir) | ||
| total_cache_size = base_cache_size + task_cache_size | ||
|
|
||
| print(f"✓ Base dataset cache: {format_size(base_cache_size)}") | ||
| print(f"✓ Task samples cache: {format_size(task_cache_size)}") | ||
| print(f"✓ Total cache size: {format_size(total_cache_size)}") | ||
|
|
||
| # Total time and peak memory | ||
| total_time = time.time() - total_start | ||
| peak_mem = PEAK_MEM_USAGE | ||
|
|
||
| # Print summary | ||
| print("\n" + "=" * 80) | ||
| print("BENCHMARK RESULTS") | ||
| print("=" * 80) | ||
| print("Configuration:") | ||
| print(" - num_workers: 4") | ||
Logiquo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| print(f" - dev mode: {dev}") | ||
| print(f" - Total samples: {len(sample_dataset)}") | ||
| print("\nTiming:") | ||
| print(f" - Dataset loading: {dataset_time:.2f}s") | ||
| print(f" - Task processing: {task_time:.2f}s") | ||
| print(f" - Total time: {total_time:.2f}s") | ||
| print("\nCache Sizes:") | ||
| print(f" - Base dataset cache: {format_size(base_cache_size)}") | ||
| print(f" - Task samples cache: {format_size(task_cache_size)}") | ||
| print(f" - Total cache: {format_size(total_cache_size)}") | ||
| print("\nMemory:") | ||
| print(f" - Peak memory usage: {format_size(peak_mem)}") | ||
| print("=" * 80) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| """ | ||
| Example of using StageNet for DKA (Diabetic Ketoacidosis) prediction on MIMIC-IV. | ||
|
|
||
| This example demonstrates: | ||
| 1. Loading MIMIC-IV data with relevant tables for DKA prediction | ||
| 2. Applying the DKAPredictionMIMIC4 task (general population) | ||
| 3. Creating a SampleDataset with StageNet processors | ||
| 4. Training a StageNet model for DKA prediction | ||
|
|
||
| Target Population: | ||
| - ALL patients in MIMIC-IV (no diabetes filtering) | ||
| - Much larger patient pool with more negative samples | ||
| - Label: 1 if patient has ANY DKA diagnosis, 0 otherwise | ||
|
|
||
| Note: For T1DM-specific DKA prediction, see t1dka_mimic4.py | ||
| """ | ||
|
|
||
| import os | ||
| import torch | ||
|
|
||
| from pyhealth.datasets import ( | ||
| MIMIC4Dataset, | ||
| get_dataloader, | ||
| split_by_patient, | ||
| ) | ||
| from pyhealth.datasets.utils import save_processors, load_processors | ||
| from pyhealth.models import StageNet | ||
| from pyhealth.tasks import DKAPredictionMIMIC4 | ||
| from pyhealth.trainer import Trainer | ||
|
|
||
|
|
||
| def main(): | ||
| """Main function to run DKA prediction pipeline on general population.""" | ||
|
|
||
| # Configuration | ||
| MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/" | ||
| DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset" | ||
| TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dka_general_stagenet" | ||
| PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_dka_general_mimic4" | ||
| DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu" | ||
|
|
||
| print("=" * 60) | ||
| print("DKA PREDICTION (GENERAL POPULATION) WITH STAGENET ON MIMIC-IV") | ||
| print("=" * 60) | ||
|
|
||
| # STEP 1: Load MIMIC-IV base dataset | ||
| print("\n=== Step 1: Loading MIMIC-IV Dataset ===") | ||
| base_dataset = MIMIC4Dataset( | ||
| ehr_root=MIMIC4_ROOT, | ||
| ehr_tables=[ | ||
| "admissions", | ||
| "diagnoses_icd", | ||
| "procedures_icd", | ||
| "labevents", | ||
| ], | ||
| cache_dir=DATASET_CACHE_DIR, | ||
| # dev=True, # Uncomment for faster development iteration | ||
| ) | ||
|
|
||
| print("Dataset initialized, proceeding to task processing...") | ||
|
|
||
| # STEP 2: Apply DKA prediction task (general population) | ||
| print("\n=== Step 2: Applying DKA Prediction Task (General Population) ===") | ||
|
|
||
| # Create task with padding for unseen sequences | ||
| # No T1DM filtering - includes ALL patients | ||
| dka_task = DKAPredictionMIMIC4(padding=10) | ||
|
|
||
| print(f"Task: {dka_task.task_name}") | ||
| print(f"Input schema: {list(dka_task.input_schema.keys())}") | ||
| print(f"Output schema: {list(dka_task.output_schema.keys())}") | ||
| print("Note: This includes ALL patients (not just diabetics)") | ||
|
|
||
| # Check for pre-fitted processors | ||
| if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")): | ||
| print("\nLoading pre-fitted processors...") | ||
| input_processors, output_processors = load_processors(PROCESSOR_DIR) | ||
|
|
||
| sample_dataset = base_dataset.set_task( | ||
| dka_task, | ||
| num_workers=4, | ||
| cache_dir=TASK_CACHE_DIR, | ||
| input_processors=input_processors, | ||
| output_processors=output_processors, | ||
| ) | ||
| else: | ||
| print("\nFitting new processors...") | ||
| sample_dataset = base_dataset.set_task( | ||
| dka_task, | ||
| num_workers=4, | ||
| cache_dir=TASK_CACHE_DIR, | ||
| ) | ||
|
|
||
| # Save processors for future runs | ||
| print("Saving processors...") | ||
| os.makedirs(PROCESSOR_DIR, exist_ok=True) | ||
| save_processors(sample_dataset, PROCESSOR_DIR) | ||
|
|
||
| print(f"\nTotal samples: {len(sample_dataset)}") | ||
|
|
||
| # Count label distribution | ||
| label_counts = {0: 0, 1: 0} | ||
| for sample in sample_dataset: | ||
| label_counts[int(sample["label"].item())] += 1 | ||
|
|
||
| print(f"Label distribution:") | ||
| print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)") | ||
| print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)") | ||
|
|
||
| # Inspect a sample | ||
| sample = sample_dataset[0] | ||
| print("\nSample structure:") | ||
| print(f" Patient ID: {sample['patient_id']}") | ||
| print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)") | ||
| print(f" Labs: {sample['labs'][0].shape} (timesteps x features)") | ||
| print(f" Label: {sample['label']}") | ||
|
|
||
| # STEP 3: Split dataset | ||
| print("\n=== Step 3: Splitting Dataset ===") | ||
| train_dataset, val_dataset, test_dataset = split_by_patient( | ||
| sample_dataset, [0.8, 0.1, 0.1] | ||
| ) | ||
|
|
||
| print(f"Train: {len(train_dataset)} samples") | ||
| print(f"Validation: {len(val_dataset)} samples") | ||
| print(f"Test: {len(test_dataset)} samples") | ||
|
|
||
| # Create dataloaders | ||
| train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) | ||
| val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) | ||
| test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) | ||
|
|
||
| # STEP 4: Initialize StageNet model | ||
| print("\n=== Step 4: Initializing StageNet Model ===") | ||
| model = StageNet( | ||
| dataset=sample_dataset, | ||
| embedding_dim=128, | ||
| chunk_size=128, | ||
| levels=3, | ||
| dropout=0.3, | ||
| ) | ||
|
|
||
| num_params = sum(p.numel() for p in model.parameters()) | ||
| print(f"Model parameters: {num_params:,}") | ||
|
|
||
| # STEP 5: Train the model | ||
| print("\n=== Step 5: Training Model ===") | ||
| trainer = Trainer( | ||
| model=model, | ||
| device=DEVICE, | ||
| metrics=["pr_auc", "roc_auc", "accuracy", "f1"], | ||
| ) | ||
|
|
||
| trainer.train( | ||
| train_dataloader=train_loader, | ||
| val_dataloader=val_loader, | ||
| epochs=50, | ||
| monitor="roc_auc", | ||
| optimizer_params={"lr": 1e-5}, | ||
| ) | ||
|
|
||
| # STEP 6: Evaluate on test set | ||
| print("\n=== Step 6: Evaluation ===") | ||
| results = trainer.evaluate(test_loader) | ||
| print("\nTest Results:") | ||
| for metric, value in results.items(): | ||
| print(f" {metric}: {value:.4f}") | ||
|
|
||
| # STEP 7: Inspect model predictions | ||
| print("\n=== Step 7: Sample Predictions ===") | ||
| sample_batch = next(iter(test_loader)) | ||
| with torch.no_grad(): | ||
| output = model(**sample_batch) | ||
|
|
||
| print(f"Predicted probabilities: {output['y_prob'][:5]}") | ||
| print(f"True labels: {output['y_true'][:5]}") | ||
|
|
||
| print("\n" + "=" * 60) | ||
| print("DKA PREDICTION (GENERAL POPULATION) TRAINING COMPLETED!") | ||
| print("=" * 60) | ||
|
|
||
| return results | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.