diff --git a/examples/interpretability/integrated_gradients_benchmark_stagenet.py b/examples/interpretability/integrated_gradients_benchmark_stagenet.py
new file mode 100644
index 000000000..82dbca218
--- /dev/null
+++ b/examples/interpretability/integrated_gradients_benchmark_stagenet.py
@@ -0,0 +1,213 @@
+"""
+Example of using StageNet for mortality prediction on MIMIC-IV with Integrated Gradients.
+
+This example demonstrates:
+1. Loading MIMIC-IV data
+2. Loading existing processors
+3. Applying the MortalityPredictionStageNetMIMIC4 task
+4. Loading a pre-trained StageNet model
+5. Benchmarking model performance on the test set
+6. Computing Integrated Gradients attributions for test samples
+
+Processor Caching:
+    The script loads existing processors from:
+    ../../output/processors/stagenet_mortality_mimic4/
+"""
+
+from pathlib import Path
+
+import torch
+
+from pyhealth.datasets import (
+    MIMIC4Dataset,
+    get_dataloader,
+    load_processors,
+    save_processors,
+    split_by_patient,
+)
+from pyhealth.interpret.methods import IntegratedGradients
+from pyhealth.metrics.interpretability import evaluate_attribution
+from pyhealth.models import StageNet
+from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
+from pyhealth.trainer import Trainer
+
+# Configuration
+CHECKPOINT_PATH = (
+    "/home/johnwu3/projects/PyHealth_Branch_Testing/PyHealth/output/"
+    "20260131-184735/best.ckpt"
+)
+PROCESSOR_DIR = "../../output/processors/stagenet_mortality_mimic4"
+CACHE_DIR = "../../mimic4_stagenet_cache"
+
+
+def main():
+    """Main execution function for StageNet mortality prediction with IG."""
+
+    # STEP 1: Load MIMIC-IV base dataset
+    print("Loading MIMIC-IV dataset...")
+    base_dataset = MIMIC4Dataset(
+        ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
+        ehr_tables=[
+            "patients",
+            "admissions",
+            "diagnoses_icd",
+            "procedures_icd",
+            "labevents",
+        ],
+        num_workers=8,
+        cache_dir="/shared/eng/pyhealth/ig",
+    )
+
+    # STEP 2: Check for existing processors and load/create accordingly
+    processor_dir_path = Path(PROCESSOR_DIR)
+    input_procs_file = processor_dir_path / "input_processors.pkl"
+    output_procs_file = processor_dir_path / "output_processors.pkl"
+
+    input_processors = None
+    output_processors = None
+
+    if input_procs_file.exists() and output_procs_file.exists():
+        # Load existing processors
+        print(f"\n{'='*60}")
+        print("LOADING EXISTING PROCESSORS")
+        print(f"{'='*60}")
+        input_processors, output_processors = load_processors(PROCESSOR_DIR)
+        print(f"✓ Using pre-fitted processors from {PROCESSOR_DIR}")
+    else:
+        # Will create new processors
+        print(f"\n{'='*60}")
+        print("NO EXISTING PROCESSORS FOUND")
+        print(f"{'='*60}")
+        print(f"Will create and save new processors to {PROCESSOR_DIR}")
+
+    # STEP 3: Apply StageNet mortality prediction task
+    print("Applying MortalityPredictionStageNetMIMIC4 task...")
+    sample_dataset = base_dataset.set_task(
+        MortalityPredictionStageNetMIMIC4(),
+        num_workers=8,
+        cache_dir=CACHE_DIR,
+        input_processors=input_processors,
+        output_processors=output_processors,
+    )
+
+    print(f"Total samples: {len(sample_dataset)}")
+
+    # Save processors if they were newly created
+    if input_processors is None and output_processors is None:
+        print(f"\n{'='*60}")
+        print("SAVING NEWLY CREATED PROCESSORS")
+        print(f"{'='*60}")
+        save_processors(sample_dataset, PROCESSOR_DIR)
+        print(f"✓ Processors saved to {PROCESSOR_DIR}")
+
+    # STEP 4: Split dataset
+    train_dataset, val_dataset, test_dataset = split_by_patient(
+        sample_dataset, [0.8, 0.1, 0.1]
+    )
+
print(f"Train samples: {len(train_dataset)}") + print(f"Val samples: {len(val_dataset)}") + print(f"Test samples: {len(test_dataset)}") + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + + # STEP 5: Initialize and train/load model + print("\nInitializing StageNet model...") + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"Model initialized with {num_params} parameters") + + trainer = Trainer( + model=model, + device="cuda:0", + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + # Check if checkpoint exists before loading + checkpoint_path_obj = Path(CHECKPOINT_PATH) + if checkpoint_path_obj.exists(): + print(f"\n{'='*60}") + print("LOADING EXISTING CHECKPOINT") + print(f"{'='*60}") + print(f"Path: {CHECKPOINT_PATH}") + trainer.load_ckpt(CHECKPOINT_PATH) + print("✓ Checkpoint loaded successfully") + else: + print(f"\n{'='*60}") + print("TRAINING NEW MODEL") + print(f"{'='*60}") + print(f"Checkpoint not found at: {CHECKPOINT_PATH}") + print("Training a new model from scratch...") + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=25, + monitor="pr_auc", + optimizer_params={"lr": 1e-5}, + ) + print("\n✓ Training completed") + + # STEP 6: Benchmark model on test set + print("\n" + "=" * 70) + print("BENCHMARKING MODEL ON TEST SET") + print("=" * 70) + + test_metrics = trainer.evaluate(test_loader) + + print("\nTest Set Performance:") + for metric_name, metric_value in test_metrics.items(): + print(f" {metric_name}: {metric_value:.4f}") + + # STEP 7: Compute Integrated Gradients Faithfulness Metrics + print("\n" + "=" * 70) + print("COMPUTING INTEGRATED GRADIENTS FAITHFULNESS METRICS") + print("=" * 70) + + # Initialize Integrated Gradients with 5 steps for faster computation + ig = IntegratedGradients(model, use_embeddings=True, steps=5) + print("✓ Integrated Gradients initialized (steps=5)") + + # Compute sufficiency and comprehensiveness on test set + print("\nEvaluating Integrated Gradients on test set...") + print("This computes sufficiency and comprehensiveness metrics.") + + # Use the functional API to evaluate attribution faithfulness + results = evaluate_attribution( + model, + test_loader, + ig, + metrics=["comprehensiveness", "sufficiency"], + percentages=[25, 50, 99], + ) + + # Print results + print("\n" + "=" * 70) + print("FAITHFULNESS METRICS RESULTS") + print("=" * 70) + + print("\nMetrics (averaged over test set):") + print(f" Comprehensiveness: {results['comprehensiveness']:.6f}") + print(f" Sufficiency: {results['sufficiency']:.6f}") + + print("\nInterpretation:") + print(" • Comprehensiveness: How much does removing the top features") + print(" change the model's prediction?") + print(" (Higher is better)") + print(" • Sufficiency: How much does keeping only the top features") + print(" maintain the model's prediction?") + print(" (Lower is better)") + + +if __name__ == "__main__": + main() diff --git a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py index 65484a2a4..b55a8524d 100644 --- a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py +++ 
diff --git a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py
index 65484a2a4..b55a8524d 100644
--- a/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py
+++ b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py
@@ -289,6 +289,7 @@ def main():
             "procedures_icd",
             "labevents",
         ],
+        num_workers=8,
     )
 
     # STEP 2: Check for existing processors and load/create accordingly
@@ -317,7 +318,7 @@ def main():
     # Apply StageNet mortality prediction task
     sample_dataset = base_dataset.set_task(
         MortalityPredictionStageNetMIMIC4(),
-        num_workers=4,
+        num_workers=8,
         cache_dir="../../mimic4_stagenet_cache",
         input_processors=input_processors,
         output_processors=output_processors,
diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py
index e64285700..dece8eacd 100644
--- a/pyhealth/interpret/methods/integrated_gradients.py
+++ b/pyhealth/interpret/methods/integrated_gradients.py
@@ -165,7 +165,7 @@ class IntegratedGradients(BaseInterpreter):
         ...     )
         """
 
-    def __init__(self, model: BaseModel, use_embeddings: bool = True):
+    def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50):
         """Initialize IntegratedGradients interpreter.
 
         Args:
@@ -176,6 +176,10 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
                 codes. Set to False only for fully continuous input models.
                 When True, the model must implement forward_from_embedding()
                 and have an embedding_model attribute.
+            steps: Default number of interpolation steps for the Riemann
+                approximation of the path integral. Default is 50.
+                Can be overridden in attribute() calls. More steps lead to
+                a better approximation but slower computation.
 
         Raises:
             AssertionError: If use_embeddings=True but model does not
@@ -183,6 +187,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
         """
         super().__init__(model)
         self.use_embeddings = use_embeddings
+        self.steps = steps
 
         # Check model supports forward_from_embedding if needed
        if use_embeddings:
@@ -196,7 +201,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
     def attribute(
         self,
         baseline: Optional[Dict[str, torch.Tensor]] = None,
-        steps: int = 50,
+        steps: Optional[int] = None,
         target_class_idx: Optional[int] = None,
         **data,
     ) -> Dict[str, torch.Tensor]:
@@ -211,8 +216,9 @@ def attribute(
                 - None: Uses small random baseline for all features (default)
                 - Dict[str, torch.Tensor]: Custom baseline for each feature
             steps: Number of steps to use in the Riemann approximation of
-                the integral. More steps lead to better approximation but
-                slower computation. Default is 50.
+                the integral. If None, uses self.steps (set during
+                initialization). More steps lead to better approximation but
+                slower computation.
             target_class_idx: Target class index for attribution
                 computation. If None, uses the predicted class (argmax of
                 model output).
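The constructor-level default introduced above composes with the per-call override exactly as the docstring describes. A sketch of the intended usage (`model` and `batch` as in the earlier example):

    # A coarse 5-step default for cheap, repeated evaluation.
    ig_fast = IntegratedGradients(model, use_embeddings=True, steps=5)
    attrs_fast = ig_fast.attribute(**batch)  # falls back to self.steps == 5

    # An explicit per-call value still takes precedence over the default.
    attrs_precise = ig_fast.attribute(steps=50, **batch)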
@@ -275,6 +281,10 @@ def attribute(
         >>> top_k = torch.topk(torch.abs(condition_attr), k=5)
         >>> print(f"Most important features: {top_k.indices}")
         """
+        # Use instance default if steps not specified
+        if steps is None:
+            steps = self.steps
+
         # Extract feature keys and prepare inputs
         feature_keys = self.model.feature_keys
         inputs = {}
@@ -304,6 +314,38 @@ def attribute(
             label_val = label_val.to(next(self.model.parameters()).device)
             label_data[key] = label_val
 
+        # Determine target class from original input if not specified
+        # This ensures the target class is fixed for all interpolation steps
+        if target_class_idx is None:
+            with torch.no_grad():
+                # Prepare inputs for forward pass
+                forward_inputs = {}
+                for key in inputs:
+                    if time_info and key in time_info:
+                        forward_inputs[key] = (time_info[key], inputs[key])
+                    else:
+                        forward_inputs[key] = inputs[key]
+
+                forward_kwargs = {**label_data} if label_data else {}
+                output = self.model(**forward_inputs, **forward_kwargs)
+                logits = output["logit"]
+
+                # Determine task type
+                output_schema = self.model.dataset.output_schema
+                label_key = list(output_schema.keys())[0]
+                task_mode = output_schema[label_key]
+                is_binary = task_mode == "binary" or (
+                    hasattr(task_mode, "__name__")
+                    and task_mode.__name__ == "BinaryLabelProcessor"
+                )
+
+                # Get predicted class
+                if is_binary:
+                    probs = torch.sigmoid(logits)
+                    target_class_idx = (probs > 0.5).long().squeeze(-1)
+                else:
+                    target_class_idx = torch.argmax(logits, dim=-1)
+
         # Compute integrated gradients with single baseline
         attributions = self._integrated_gradients(
             inputs=inputs,
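The hunk above pins the target class once, from the unperturbed input, so every interpolation step differentiates the same logit; otherwise the argmax could flip along the path and the accumulated gradients would mix targets. The Riemann approximation itself is easy to sanity-check in isolation: for a linear model f(x) = w·x, integrated gradients along the straight-line path reduce exactly to (x - baseline) * w. A self-contained sketch in plain PyTorch, independent of the PyHealth classes:

    import torch

    w = torch.tensor([0.5, -2.0, 1.0])

    def f(x):
        return (w * x).sum()

    x = torch.tensor([1.0, 1.0, 1.0])
    baseline = torch.zeros(3)
    steps = 50

    # Average gradients at evenly spaced points on the baseline -> input path.
    grad_sum = torch.zeros(3)
    for i in range(steps + 1):
        alpha = i / steps
        point = (baseline + alpha * (x - baseline)).requires_grad_(True)
        f(point).backward()
        grad_sum += point.grad

    ig = (x - baseline) * grad_sum / (steps + 1)
    print(ig)                  # ≈ tensor([ 0.5, -2.0,  1.0])
    print((x - baseline) * w)  # closed form; exact for a linear model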
@@ -377,22 +419,25 @@ def _prepare_embeddings_and_baselines(
     def _compute_target_output(
         self,
         logits: torch.Tensor,
-        target_class_idx: Optional[int] = None,
+        target_class_idx: int,
     ) -> torch.Tensor:
         """Compute target output scalar for backpropagation.
 
-        This method determines the target class (if not specified), creates
-        the appropriate one-hot encoding, and computes the scalar output
-        that will be used for computing gradients.
+        This method creates the appropriate one-hot encoding and computes
+        the scalar output that will be used for computing gradients.
 
         Args:
             logits: Model output logits [batch, num_classes] or [batch, 1]
-            target_class_idx: Optional target class index. If None, uses
-                the predicted class (argmax of logits).
+            target_class_idx: Target class index (must not be None).
 
         Returns:
             Scalar tensor representing the target output for backprop.
         """
+        assert target_class_idx is not None, (
+            "target_class_idx must be set before calling _compute_target_output. "
+            "This should be determined in the attribute() method."
+        )
+
         # Determine task type from model's output schema
         output_schema = self.model.dataset.output_schema
         label_key = list(output_schema.keys())[0]
@@ -404,16 +449,8 @@ def _compute_target_output(
             and task_mode.__name__ == "BinaryLabelProcessor"
         )
 
-        # Determine target class
-        if target_class_idx is None:
-            if is_binary:
-                # Binary: if sigmoid(logit) > 0.5, class=1, else class=0
-                probs = torch.sigmoid(logits)
-                tc_idx = (probs > 0.5).long().squeeze(-1)
-            else:
-                # Multiclass: argmax over classes
-                tc_idx = torch.argmax(logits, dim=-1)
-        elif not isinstance(target_class_idx, torch.Tensor):
+        # Convert target_class_idx to a tensor if needed
+        if not isinstance(target_class_idx, torch.Tensor):
             tc_idx = torch.tensor(target_class_idx, device=logits.device)
         else:
             tc_idx = target_class_idx
@@ -455,7 +492,7 @@ def _interpolate_and_compute_gradients(
         target_class_idx: Optional[int] = None,
         time_info: Optional[Dict[str, torch.Tensor]] = None,
         label_data: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Dict[str, list]:
+    ) -> Dict[str, torch.Tensor]:
         """Interpolate between baseline and input, accumulating gradients.
 
         This is the core of the Integrated Gradients algorithm. For each
@@ -463,7 +500,7 @@ def _interpolate_and_compute_gradients(
         1. Creates interpolated embeddings between baseline and input
         2. Runs forward pass through the model
         3. Computes gradients w.r.t. the interpolated embeddings
-        4. Collects gradients for later averaging
+        4. Accumulates gradients using a running sum (memory efficient)
 
         Args:
             input_embeddings: Embedded input tensors for each feature.
@@ -474,10 +511,13 @@ def _interpolate_and_compute_gradients(
             label_data: Optional label data to pass to model.
 
         Returns:
-            Dictionary mapping feature keys to lists of gradients, one
-            gradient tensor per interpolation step.
+            Dictionary mapping feature keys to accumulated gradient tensors
+            (already averaged over steps).
         """
-        all_gradients = {key: [] for key in input_embeddings}
+        # Use a running sum instead of storing all gradients (memory efficient)
+        avg_gradients = {
+            key: torch.zeros_like(emb) for key, emb in input_embeddings.items()
+        }
 
         for step_idx in range(steps + 1):
             alpha = step_idx / steps
@@ -512,20 +552,23 @@ def _interpolate_and_compute_gradients(
             self.model.zero_grad()
             target_output.backward(retain_graph=True)
 
-            # Collect gradients for each feature's embedding
+            # Accumulate gradients using a running sum (memory efficient)
             for key in input_embeddings:
                 emb = interpolated_embeddings[key]
                 if emb.grad is not None:
-                    grad = emb.grad.detach().clone()
-                    all_gradients[key].append(grad)
-                else:
-                    all_gradients[key].append(torch.zeros_like(emb))
+                    # Add to the running sum instead of storing in a list
+                    avg_gradients[key] += emb.grad.detach()
+                # If grad is None, nothing is added (equivalent to zeros)
+
+        # Average the accumulated gradients
+        for key in avg_gradients:
+            avg_gradients[key] /= steps + 1
 
-        return all_gradients
+        return avg_gradients
 
     def _compute_final_attributions(
         self,
-        all_gradients: Dict[str, list],
+        avg_gradients: Dict[str, torch.Tensor],
         input_embeddings: Dict[str, torch.Tensor],
         baseline_embeddings: Dict[str, torch.Tensor],
         input_shapes: Dict[str, tuple],
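The accumulation above replaces the per-step gradient list with a single running-sum buffer per feature, so peak memory no longer grows with the number of steps. The two reductions agree up to floating-point rounding, which a few lines verify (a sketch, independent of the classes in this diff):

    import torch

    torch.manual_seed(0)
    grads = [torch.randn(4, 8) for _ in range(6)]  # e.g. steps=5 -> 6 gradients

    # Reduction A: keep every tensor, then stack and average.
    stacked_mean = torch.stack(grads, dim=0).mean(dim=0)

    # Reduction B: one running-sum buffer, divided once at the end.
    running = torch.zeros(4, 8)
    for g in grads:
        running += g
    running /= len(grads)

    assert torch.allclose(stacked_mean, running, atol=1e-6)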
@@ -533,10 +576,9 @@ def _compute_final_attributions(
         """Compute final integrated gradients and map to input shapes.
 
         This method completes the IG computation by:
-        1. Averaging gradients across interpolation steps
-        2. Applying the IG formula: (input - baseline) * avg_gradient
-        3. Summing over embedding dimension
-        4. Mapping attributions back to original input tensor shapes
+        1. Applying the IG formula: (input - baseline) * avg_gradient
+        2. Summing over embedding dimension
+        3. Mapping attributions back to original input tensor shapes
 
         Important properties of IG attributions:
         - Can be POSITIVE (feature increases prediction) or NEGATIVE
@@ -547,7 +589,7 @@ def _compute_final_attributions(
             from the target class
 
         Args:
-            all_gradients: Dictionary of gradient lists from interpolation.
+            avg_gradients: Dictionary of averaged gradient tensors.
             input_embeddings: Embedded input tensors.
             baseline_embeddings: Baseline embeddings.
             input_shapes: Original input tensor shapes for mapping.
@@ -559,13 +601,9 @@ def _compute_final_attributions(
         integrated_grads = {}
 
         for key in input_embeddings:
-            # Average gradients across interpolation steps (exclude last)
-            stacked_grads = torch.stack(all_gradients[key][:-1], dim=0)
-            avg_grad = torch.mean(stacked_grads, dim=0)
-
             # Apply IG formula: (input_emb - baseline_emb) * avg_gradient
             delta_emb = input_embeddings[key] - baseline_embeddings[key]
-            emb_attribution = delta_emb * avg_grad
+            emb_attribution = delta_emb * avg_gradients[key]
 
             # Sum over embedding dimension to get per-token attribution
             # Handle both 3D [batch, seq, emb] and 4D [batch, seq, tokens, emb]
@@ -622,13 +660,18 @@ def _integrated_gradients_embedding_based(
             inputs: Dictionary of input tensors.
             baseline: Optional baseline tensors.
             steps: Number of interpolation steps.
-            target_class_idx: Target class for attribution.
+            target_class_idx: Target class for attribution (must not be None).
             time_info: Optional time information for temporal models.
             label_data: Optional label data.
 
         Returns:
             Dictionary of attribution tensors matching input shapes.
         """
+        assert target_class_idx is not None, (
+            "target_class_idx must be set before calling _integrated_gradients_embedding_based. "
+            "This should be determined in the attribute() method."
+        )
+
         # Step 1: Embed inputs and create baselines in embedding space
         input_embs, baseline_embs, shapes = self._prepare_embeddings_and_baselines(
             inputs, baseline
@@ -669,13 +712,18 @@ def _integrated_gradients_continuous(
             inputs: Dictionary of input tensors.
             baseline: Optional baseline tensors.
             steps: Number of interpolation steps.
-            target_class_idx: Target class for attribution.
+            target_class_idx: Target class for attribution (must not be None).
             time_info: Optional time information for temporal models.
             label_data: Optional label data.
 
         Returns:
             Dictionary of attribution tensors matching input shapes.
         """
+        assert target_class_idx is not None, (
+            "target_class_idx must be set before calling _integrated_gradients_continuous. "
+            "This should be determined in the attribute() method."
+        )
+
         # Create baseline if not provided
         if baseline is None:
             baseline = {}
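For context on the metrics reported by the benchmark script: comprehensiveness and sufficiency are commonly defined (DeYoung et al., 2020, ERASER) as the drop in the predicted-class probability when the top-attributed features are removed, or when only those features are kept. A schematic sketch of that comparison, with a hypothetical helper name and independent of how evaluate_attribution actually implements the masking:

    import torch

    def faithfulness_scores(
        probs_full: torch.Tensor,         # [batch, C], full input
        probs_without_top: torch.Tensor,  # [batch, C], top-k% features removed
        probs_only_top: torch.Tensor,     # [batch, C], only top-k% features kept
        target: torch.Tensor,             # [batch], predicted class indices
    ):
        p_full = probs_full.gather(1, target.unsqueeze(1)).squeeze(1)
        p_without = probs_without_top.gather(1, target.unsqueeze(1)).squeeze(1)
        p_only = probs_only_top.gather(1, target.unsqueeze(1)).squeeze(1)

        comprehensiveness = (p_full - p_without).mean()  # higher is better
        sufficiency = (p_full - p_only).mean()           # lower is better
        return comprehensiveness.item(), sufficiency.item()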