From 45ec596ddc8f2ee5c7d524444d489c5c053e01f0 Mon Sep 17 00:00:00 2001 From: Audrey Cherilyn Date: Mon, 12 Jan 2026 10:16:20 -0500 Subject: [PATCH 01/15] Fixing wanda --- .../examples/llama3_fast_pruning copy.yaml | 207 ++++++ configs/examples/llama3_fast_pruning.yaml | 62 +- slurm-48692030.out | 4 - slurm_jobs/run_fast_pruning.sh | 12 +- slurm_jobs/run_llama3_scar_pruning.sh | 10 +- src/alignment/experiments/llm_experiments.py | 29 +- src/alignment/pruning/baselines.py | 98 +-- .../pruning/strategies/external/wanda/data.py | 75 ++ .../strategies/external/wanda/layerwrapper.py | 35 + .../strategies/external/wanda/prune.py | 172 +++++ .../pruning/strategies/llm_baselines copy.py | 675 ++++++++++++++++++ 11 files changed, 1289 insertions(+), 90 deletions(-) create mode 100644 configs/examples/llama3_fast_pruning copy.yaml delete mode 100644 slurm-48692030.out create mode 100644 src/alignment/pruning/strategies/external/wanda/data.py create mode 100644 src/alignment/pruning/strategies/external/wanda/layerwrapper.py create mode 100644 src/alignment/pruning/strategies/external/wanda/prune.py create mode 100644 src/alignment/pruning/strategies/llm_baselines copy.py diff --git a/configs/examples/llama3_fast_pruning copy.yaml b/configs/examples/llama3_fast_pruning copy.yaml new file mode 100644 index 00000000..86c775f8 --- /dev/null +++ b/configs/examples/llama3_fast_pruning copy.yaml @@ -0,0 +1,207 @@ +# ============================================================================ +# LLaMA-3 FAST PRUNING COMPARISON +# ============================================================================ +# +# PURPOSE: Quick iteration version of comprehensive pruning comparison +# +# CHANGES FROM COMPREHENSIVE VERSION: +# - 3 sparsity levels instead of 9 +# - 1 selection mode instead of 2 +# - 3 key algorithms instead of 9 +# - Dropped slow benchmarks (GSM8k, MBPP, HumanEval) +# - Reduced evaluation samples (50 instead of 100) +# - Disabled supernode robustness analysis +# +#
EXPECTED RUNTIME: ~30-60 minutes on H100 (vs 6-12 hours for comprehensive) +# ============================================================================ + +experiment: + name: "llama3_fast_pruning" + type: "llm_alignment" + seed: 42 + device: "cuda" + output_dir: "./results/llama3_fast_pruning" + num_networks: 1 + +model: + name: "hf_causal_lm" + model_id: "meta-llama/Llama-3.1-8B" + dtype: "bfloat16" + device_map: "auto" + + # Track all MLP layers for comprehensive analysis + tracked_layers: + - "model.model.layers.*.mlp.up_proj" + - "model.model.layers.*.mlp.gate_proj" + - "model.model.layers.*.mlp.down_proj" + +dataset: + name: "wikitext" + batch_size: 1 + num_workers: 0 + +# ============================================================================ +# IMPORTANCE METRICS - Reduced for speed +# ============================================================================ +metrics: + enabled: + - "rayleigh_quotient" # Core alignment metric + - "activation_l2_norm" # Baseline for comparison + + num_samples: 32 # Reduced from 64 for faster calibration + + rayleigh_quotient: + relative: true + regularization: 1.0e-6 + +# ============================================================================ +# LLM-SPECIFIC SETTINGS - Optimized for speed +# ============================================================================ +llm: + # SCAR metrics (reduced samples) + scar_metrics: true + scar_num_samples: 32 # Reduced from 64 + scar_max_length: 512 + + # Evaluation settings - reduced + evaluate_perplexity: true + evaluation_num_samples: 50 # Reduced from 100 + + # FAST EVALUATION SUITE - dropped slow benchmarks + evaluation_metrics: + # Language modeling (core metrics) - FAST + - "perplexity" # ~2 sec + - "loss" # ~1 sec + - "bits_per_byte" # ~1 sec + + # Knowledge & Reasoning - FAST + - "accuracy_mmlu" # ~15 sec with 50 samples + - "accuracy_hellaswag" # ~5 sec + - "accuracy_arc_easy" # ~5 sec + - "accuracy_arc_challenge" # ~5 sec + + # Common Sense - FAST + - 
"accuracy_winogrande" # ~3 sec + - "accuracy_piqa" # ~3 sec + - "accuracy_boolq" # ~3 sec + - "accuracy_truthfulqa" # ~5 sec + + # REMOVED SLOW BENCHMARKS: + # - "accuracy_gsm8k" # SLOW: ~3+ min (requires generation) + # - "accuracy_mbpp" # SLOW: ~1.5+ min (code generation) + # - "accuracy_humaneval" # SLOW: ~2+ min (code generation) + +# ============================================================================ +# SUPERNODE CONFIGURATION - Simplified +# ============================================================================ +supernode: + enabled: true + + core_fraction: 0.01 + follower_fraction: 0.10 + score_metric: "activation_l2_norm" + + protect_core: true + cross_layer_analysis: false # Disabled for speed + compare_by_connection: false # Disabled for speed + + compute_metrics: + - "activation" + - "rayleigh_quotient" + +# ============================================================================ +# SUPERNODE ROBUSTNESS ANALYSIS - DISABLED for speed +# ============================================================================ +supernode_robustness: + enabled: false # Disabled to save ~10 minutes + +# ============================================================================ +# PRUNING CONFIGURATION - Optimized for speed +# ============================================================================ +pruning: + enabled: true + + # KEY sparsity levels only (3 instead of 9) + sparsity_levels: [0.3, 0.5, 0.7] + + # Single selection mode (saves 2x time) + selection_modes: ["low"] + + # Pruning structure + distribution: "uniform" + structured: true + dependency_aware: true + + # ========================================================================= + # REDUCED PRUNING ALGORITHMS - 3 key methods instead of 9 + # ========================================================================= + algorithms: + # Our main method + - "rayleigh_quotient" + + # SCAR-based (gradient-informed) + - "scar_loss_proxy" + + # Baseline + - "activation_l2_norm" + + # REMOVED 
FOR SPEED: + # - "gaussian_mi_analytic" # Similar to RQ + # - "average_redundancy" # Can add later + # - "supernode_protection_score" # Can add later + # - "supernode_connectivity_score" + # NOTE: wanda/sparsegpt require special calibration that's not fully integrated + # - "wanda" # Needs calibration (not yet fully integrated) + # - "sparsegpt" # SLOW: second-order optimization + + single_strategy: null + + fine_tune: + enabled: false + +# ============================================================================ +# ADVANCED ANALYSIS FLAGS - Disabled for speed +# ============================================================================ +do_directed_redundancy: false +do_connectivity_pruning: false + +# ============================================================================ +# PERFORMANCE SETTINGS +# ============================================================================ +performance: + eval_batches: null + +# ============================================================================ +# ANALYSIS & VISUALIZATION - Simplified +# ============================================================================ +analysis: + save_scores: true + generate_plots: true + + plots: + histograms: true + scatter_plots: false # Disabled for speed + pruning_curves: true + redundancy_heatmaps: false # Disabled for speed + + scatter_pairs: + - ["activation_l2_norm", "rayleigh_quotient"] + +visualization: + format: "png" + dpi: 150 # Reduced for faster plot generation + +# ============================================================================ +# EXPECTED CONFIGURATIONS TO RUN +# ============================================================================ +# +# Sparsity levels: 3 (0.3, 0.5, 0.7) +# Selection modes: 1 (low) +# Algorithms: 3 (rayleigh_quotient, scar_loss_proxy, activation_l2_norm) +# +# Total configs: 3 × 1 × 3 = 9 configurations (vs 162 in comprehensive) +# +# Time per config: ~2-3 minutes +# Estimated total: ~30-45 minutes +#
============================================================================ + diff --git a/configs/examples/llama3_fast_pruning.yaml b/configs/examples/llama3_fast_pruning.yaml index 86c775f8..79ac0fdc 100644 --- a/configs/examples/llama3_fast_pruning.yaml +++ b/configs/examples/llama3_fast_pruning.yaml @@ -45,14 +45,14 @@ dataset: # ============================================================================ metrics: enabled: - - "rayleigh_quotient" # Core alignment metric + # - "rayleigh_quotient" # Core alignment metric - "activation_l2_norm" # Baseline for comparison num_samples: 32 # Reduced from 64 for faster calibration - rayleigh_quotient: - relative: true - regularization: 1.0e-6 + # rayleigh_quotient: + # relative: true + # regularization: 1.0e-6 # ============================================================================ # LLM-SPECIFIC SETTINGS - Optimized for speed @@ -72,19 +72,19 @@ llm: # Language modeling (core metrics) - FAST - "perplexity" # ~2 sec - "loss" # ~1 sec - - "bits_per_byte" # ~1 sec + # - "bits_per_byte" # ~1 sec - # Knowledge & Reasoning - FAST - - "accuracy_mmlu" # ~15 sec with 50 samples - - "accuracy_hellaswag" # ~5 sec - - "accuracy_arc_easy" # ~5 sec - - "accuracy_arc_challenge" # ~5 sec + # # Knowledge & Reasoning - FAST + # - "accuracy_mmlu" # ~15 sec with 50 samples + # - "accuracy_hellaswag" # ~5 sec + # - "accuracy_arc_easy" # ~5 sec + # - "accuracy_arc_challenge" # ~5 sec - # Common Sense - FAST - - "accuracy_winogrande" # ~3 sec - - "accuracy_piqa" # ~3 sec - - "accuracy_boolq" # ~3 sec - - "accuracy_truthfulqa" # ~5 sec + # # Common Sense - FAST + # - "accuracy_winogrande" # ~3 sec + # - "accuracy_piqa" # ~3 sec + # - "accuracy_boolq" # ~3 sec + # - "accuracy_truthfulqa" # ~5 sec # REMOVED SLOW BENCHMARKS: # - "accuracy_gsm8k" # SLOW: ~3+ min (requires generation) @@ -107,7 +107,7 @@ supernode: compute_metrics: - "activation" - - "rayleigh_quotient" + # - "rayleigh_quotient" # 
============================================================================ # SUPERNODE ROBUSTNESS ANALYSIS - DISABLED for speed @@ -122,7 +122,7 @@ pruning: enabled: true # KEY sparsity levels only (3 instead of 9) - sparsity_levels: [0.3, 0.5, 0.7] + sparsity_levels: [0.1, 0.3, 0.5, 0.7, 0.9] # Single selection mode (saves 2x time) selection_modes: ["low"] @@ -137,13 +137,13 @@ pruning: # ========================================================================= algorithms: # Our main method - - "rayleigh_quotient" + # - "rayleigh_quotient" # SCAR-based (gradient-informed) - - "scar_loss_proxy" + # - "scar_loss_proxy" # Baseline - - "activation_l2_norm" + # - "activation_l2_norm" # REMOVED FOR SPEED: # - "gaussian_mi_analytic" # Similar to RQ @@ -151,7 +151,7 @@ pruning: # - "supernode_protection_score" # Can add later # - "supernode_connectivity_score" # NOTE: wanda/sparsegpt require special calibration that's not fully integrated - # - "wanda" # Needs calibration (not yet fully integrated) + - "wanda" # Needs calibration (not yet fully integrated) # - "sparsegpt" # SLOW: second-order optimization single_strategy: null @@ -179,18 +179,32 @@ analysis: generate_plots: true plots: - histograms: true + histograms: false # Disabled for speed scatter_plots: false # Disabled for speed pruning_curves: true redundancy_heatmaps: false # Disabled for speed - scatter_pairs: - - ["activation_l2_norm", "rayleigh_quotient"] + # scatter_pairs: + # - ["activation_l2_norm", "rayleigh_quotient"] visualization: format: "png" dpi: 150 # Reduced for faster plot generation + pruning_curves: + enabled: true + plot_sparsity_vs_perplexity: true + plot_sparsity_vs_accuracy: true + metrics_to_compare: + # - "rayleigh_quotient" + - "scar_loss_proxy" + # - "supernode_connectivity_score" + # - "supernode_protection_score" + - "wanda" + # - "sparsegpt" + - "activation_l2_norm" + + # ============================================================================ # EXPECTED CONFIGURATIONS TO RUN # 
============================================================================ diff --git a/slurm-48692030.out b/slurm-48692030.out deleted file mode 100644 index 0fec100d..00000000 --- a/slurm-48692030.out +++ /dev/null @@ -1,4 +0,0 @@ -Running vision synergy experiment with config: configs/projects/vision_synergy.yaml -Working directory: /var/slurmd/spool/slurmd - -python: can't open file '/var/slurmd/spool/slurmd/scripts/run_experiment.py': [Errno 2] No such file or directory diff --git a/slurm_jobs/run_fast_pruning.sh b/slurm_jobs/run_fast_pruning.sh index 2161c2a6..78c1a79a 100755 --- a/slurm_jobs/run_fast_pruning.sh +++ b/slurm_jobs/run_fast_pruning.sh @@ -8,8 +8,8 @@ #SBATCH --cpus-per-task=8 #SBATCH --time=02:00:00 #SBATCH --mem=80GB -#SBATCH --partition=kempner_eng -#SBATCH --account=kempner_dev +#SBATCH --partition=kempner_h100 +#SBATCH --account=kempner_undergrads # ============================================================================ # FAST LLM PRUNING COMPARISON @@ -38,17 +38,17 @@ echo "" module purge module load cuda/12.2.0-fasrc01 eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis +conda activate alignenv2 -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +cd /n/holylfs06/LABS/kempner_undergrads/Lab/acherilyn/alignment mkdir -p logs export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export TOKENIZERS_PARALLELISM=false export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) +# export HF_HOME=/n/home13/hsafaai/.cache/huggingface +# export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) echo "============================================================================" echo "FAST MODE CONFIGURATION:" diff --git a/slurm_jobs/run_llama3_scar_pruning.sh b/slurm_jobs/run_llama3_scar_pruning.sh index a95c3f78..357ec97d 100755 --- a/slurm_jobs/run_llama3_scar_pruning.sh +++ 
b/slurm_jobs/run_llama3_scar_pruning.sh @@ -9,7 +9,7 @@ #SBATCH --time=8:00:00 #SBATCH --mem=320GB #SBATCH --partition=kempner_h100 -#SBATCH --account=kempner_dev +#SBATCH --account=kempner_undergrads echo "==========================================" echo "LLaMA-3 SCAR-Based Pruning with Supernode Protection" @@ -24,9 +24,9 @@ echo "" module purge module load cuda/12.2.0-fasrc01 eval "$(conda shell.bash hook)" -conda activate networkAlignmentAnalysis +conda activate alignenv2 -cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +cd /n/holylfs06/LABS/kempner_undergrads/Lab/acherilyn/alignment # Make logs directory if it doesn't exist mkdir -p logs @@ -34,8 +34,8 @@ mkdir -p logs export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK export TOKENIZERS_PARALLELISM=false export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export HF_HOME=/n/home13/hsafaai/.cache/huggingface -export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) +# export HF_HOME=/n/home13/hsafaai/.cache/huggingface +# export HF_TOKEN=$(cat /n/home13/hsafaai/.cache/huggingface/token) echo "Running SCAR-based pruning experiment..." echo "Pruning metrics: L2 norm, SCAR loss proxy" diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 22eb77cb..b2c72d0b 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -6417,12 +6417,20 @@ def restore_weights(): # Iterate over all strategy/mode combinations for metric in pruning_strategies: + # Special case: Wanda + if metric.lower() == "wanda": + try: + from alignment.pruning.strategies.external.wanda.prune import prune_wanda # Import Wanda pruning function + except ImportError: + logger.error("Could not import Wanda pruning. 
Make sure external/wanda is in PYTHONPATH.") + continue + # Check if we have importance scores for this metric has_metric_scores = any( metric in layer_scores for layer_scores in self.importance_scores.values() ) - if not has_metric_scores: + if not has_metric_scores and metric.lower() != "wanda": logger.warning(f"No importance scores computed for metric '{metric}', skipping pruning") continue @@ -6443,7 +6451,24 @@ def restore_weights(): restore_weights() logger.info(f" Applying pruning: sparsity={sparsity}, metric={metric}, mode={mode}") - masks = self.apply_pruning(sparsity=sparsity, mode=mode, metric=metric) + + if metric.lower() == "wanda": + # Create args object for Wanda pruning + class Args: + def __init__(self, nsamples, seed): + self.nsamples = nsamples + self.seed = seed + + # Get parameters from config + nsamples = getattr(self.config, "alignment_data_num_samples", 32) # Default to 32 + seed = getattr(self.config, "seed", 42) # Default to 42 + + args = Args(nsamples=nsamples, seed=seed) + + # Call Wanda pruning function with current sparsity + prune_wanda(args, self.model, self.tokenizer, self.config.device, prune_n=0, prune_m=0, sparsity_ratio=sparsity) + else: + masks = self.apply_pruning(sparsity=sparsity, mode=mode, metric=metric) pruning_data["sparsities"].append(sparsity) diff --git a/src/alignment/pruning/baselines.py b/src/alignment/pruning/baselines.py index 7d1fb132..5b112043 100644 --- a/src/alignment/pruning/baselines.py +++ b/src/alignment/pruning/baselines.py @@ -293,64 +293,64 @@ def apply_structured_pruning( return new_module -def compare_pruning_methods( - model: nn.Module, - calibration_data: torch.Tensor, - sparsity_levels: List[float] = [0.3, 0.5, 0.7], - methods: List[str] = ["magnitude", "wanda", "sparsegpt"], -) -> Dict[str, Dict[float, Dict[str, Any]]]: - """ - Compare different pruning methods on a model. 
+# def compare_pruning_methods( +# model: nn.Module, +# calibration_data: torch.Tensor, +# sparsity_levels: List[float] = [0.3, 0.5, 0.7], +# methods: List[str] = ["magnitude", "wanda", "sparsegpt"], +# ) -> Dict[str, Dict[float, Dict[str, Any]]]: +# """ +# Compare different pruning methods on a model. - Args: - model: Model to prune - calibration_data: Data for computing activation statistics - sparsity_levels: List of sparsity levels to test - methods: List of pruning methods to compare +# Args: +# model: Model to prune +# calibration_data: Data for computing activation statistics +# sparsity_levels: List of sparsity levels to test +# methods: List of pruning methods to compare - Returns: - Dictionary with results per method and sparsity level - """ - results = {method: {} for method in methods} +# Returns: +# Dictionary with results per method and sparsity level +# """ +# results = {method: {} for method in methods} - # Initialize pruning methods - pruners = { - "magnitude": MagnitudePruning(structured=True), - "wanda": WandaPruning(structured=True), - "sparsegpt": SparseGPTStylePruning(structured=True), - } +# # Initialize pruning methods +# pruners = { +# "magnitude": MagnitudePruning(structured=True), +# "wanda": WandaPruning(structured=True), +# "sparsegpt": SparseGPTStylePruning(structured=True), +# } - for method in methods: - if method not in pruners: - logger.warning(f"Unknown pruning method: {method}") - continue +# for method in methods: +# if method not in pruners: +# logger.warning(f"Unknown pruning method: {method}") +# continue - pruner = pruners[method] +# pruner = pruners[method] - for sparsity in sparsity_levels: - pruner.sparsity = sparsity +# for sparsity in sparsity_levels: +# pruner.sparsity = sparsity - # Collect scores for all layers - layer_scores = {} - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - # Get activations for this layer (would need hooks in practice) - # This is a simplified version - scores 
= pruner.compute_scores( - module.weight.data, - calibration_data, - ) - layer_scores[name] = { - "scores": scores, - "mask": pruner.get_pruning_mask(scores, sparsity), - } +# # Collect scores for all layers +# layer_scores = {} +# for name, module in model.named_modules(): +# if isinstance(module, nn.Linear): +# # Get activations for this layer (would need hooks in practice) +# # This is a simplified version +# scores = pruner.compute_scores( +# module.weight.data, +# calibration_data, +# ) +# layer_scores[name] = { +# "scores": scores, +# "mask": pruner.get_pruning_mask(scores, sparsity), +# } - results[method][sparsity] = { - "layer_scores": layer_scores, - "sparsity": sparsity, - } +# results[method][sparsity] = { +# "layer_scores": layer_scores, +# "sparsity": sparsity, +# } - return results +# return results # Registry for easy access diff --git a/src/alignment/pruning/strategies/external/wanda/data.py b/src/alignment/pruning/strategies/external/wanda/data.py new file mode 100644 index 00000000..d6eaa348 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/data.py @@ -0,0 +1,75 @@ +# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py + +import numpy as np +import random +import torch +from datasets import load_dataset + +# Set seed for reproducibility +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + +# Wrapper for tokenized input IDs +class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + +# Load and process wikitext2 dataset +def get_wikitext2(nsamples, seed, seqlen, tokenizer): + # Load train and test datasets + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + # Encode datasets + trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt', truncation=True, max_length=tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') and 
tokenizer.model_max_length else 131072) + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt', truncation=True, max_length=tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length else 131072) + + # Generate samples from training set + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +# Load and process c4 dataset +def get_c4(nsamples, seed, seqlen, tokenizer): + # Load train and validation datasets + traindata = load_dataset('allenai/c4', 'en', split='train', streaming=True) + valdata = load_dataset('allenai/c4', 'en', split='validation', streaming=True) + + # Generate samples from training set + random.seed(seed) + trainloader = [] + shuffled_traindata = traindata.shuffle(seed=seed, buffer_size=10000) + for sample in shuffled_traindata: + trainenc = tokenizer(sample['text'], return_tensors='pt', truncation=True, max_length=tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length else 131072) + if trainenc.input_ids.shape[1] > seqlen: + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + if len(trainloader) >= nsamples: + break + + # Prepare validation dataset + val_samples = list(valdata.take(1100)) + val_text = ' '.join([s['text'] for s in val_samples]) + valenc = tokenizer(val_text, return_tensors='pt', truncation=True, max_length=256 * seqlen) + valenc = valenc.input_ids[:, :(256 * seqlen)] + valenc = TokenizerWrapper(valenc) + return trainloader, valenc + +# Function to select the appropriate loader based on dataset name +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, 
tokenizer=None): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, tokenizer) + if "c4" in name: + return get_c4(nsamples, seed, seqlen, tokenizer) \ No newline at end of file diff --git a/src/alignment/pruning/strategies/external/wanda/layerwrapper.py b/src/alignment/pruning/strategies/external/wanda/layerwrapper.py new file mode 100644 index 00000000..1821e8f9 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/layerwrapper.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +# Define WrappedGPT class +class WrappedGPT: + """ + This class wraps a GPT layer for specific operations. + """ + + def __init__(self, layer, layer_id=0, layer_name="none"): + self.layer = layer + self.dev = self.layer.weight.device + self.rows = layer.weight.data.shape[0] + self.columns = layer.weight.data.shape[1] + + self.scaler_row = torch.zeros((self.columns), device=self.dev) + self.nsamples = 0 + + self.layer_id = layer_id + self.layer_name = layer_name + + def add_batch(self, inp, out): + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.scaler_row *= self.nsamples / (self.nsamples+tmp) + self.nsamples += tmp + + inp = inp.type(torch.float32) + self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples \ No newline at end of file diff --git a/src/alignment/pruning/strategies/external/wanda/prune.py b/src/alignment/pruning/strategies/external/wanda/prune.py new file mode 100644 index 00000000..2f998459 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/prune.py @@ -0,0 +1,172 @@ +import time +import heapq +import torch +import torch.nn as nn +from .layerwrapper import WrappedGPT +from .data import get_loaders + +def find_layers(module, layers=[nn.Linear], name=''): + """ + Recursively find the layers of a certain type in a module. 
+ + Args: + module (nn.Module): PyTorch module. + layers (list): List of layer types to find. + name (str): Name of the module. + + Returns: + dict: Dictionary of layers of the given type(s) within the module. + """ + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res + +def prepare_calibration_input(model, dataloader, device, seqlen): + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.model.layers + + # dev = model.hf_device_map["model.embed_tokens"] + if hasattr(model, 'hf_device_map') and "model.embed_tokens" in model.hf_device_map: + device = model.hf_device_map["model.embed_tokens"] + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros((128, seqlen, model.config.hidden_size), dtype=dtype, device=device) + inps.requires_grad = False + cache = {'i': 0, 'attention_mask': None, "position_ids": None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] + raise ValueError + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(device)) + except ValueError: + pass + layers[0] = layers[0].module + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + position_ids = cache['position_ids'] + model.config.use_cache = use_cache + + return inps, outs, attention_mask, position_ids + +def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before): + thres_cumsum = sum_before * alpha + sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1)) + thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1) + W_mask = (W_metric <= thres) 
+ cur_sparsity = (W_mask==True).sum() / W_mask.numel() + return W_mask, cur_sparsity + +def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0, sparsity_ratio=None): + if sparsity_ratio is None: + sparsity_ratio = args.sparsity_ratio + use_cache = model.config.use_cache + model.config.use_cache = False + + # Get sequence length from tokenizer or use default + seqlen = getattr(tokenizer, 'model_max_length', None) + if seqlen is None or seqlen > 10000: # Some tokenizers have very large max_length + seqlen = 2048 # Default sequence length + + print("loading calibdation data") + dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=seqlen,tokenizer=tokenizer) + print("dataset loading complete") + with torch.no_grad(): + inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device, seqlen) + + layers = model.model.model.layers + for i in range(len(layers)): + layer = layers[i] + subset = find_layers(layer) + + if hasattr(model, 'hf_device_map') and f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs; + dev = model.hf_device_map[f"model.layers.{i}"] + inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev) + + wrapped_layers = {} + for name in subset: + wrapped_layers[name] = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + wrapped_layers[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + for name in wrapped_layers: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for j in range(args.nsamples): + with torch.no_grad(): + # Generate position_ids if they are None + seq_len = inps[j].shape[0] + pos_ids = torch.arange(seq_len, dtype=torch.long, device=inps[j].device).unsqueeze(0) + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + for h in handles: + 
h.remove() + + for name in subset: + print(f"pruning layer {i} name {name}") + W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) + + W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False + # if prune_n != 0: + # # structured n:m sparsity + # for ii in range(W_metric.shape[1]): + # if ii % prune_m == 0: + # tmp = W_metric[:,ii:(ii+prune_m)].float() + # W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) + # else: + sort_res = torch.sort(W_metric, dim=-1, stable=True) + + # if args.use_variant: + # # wanda variant + # tmp_metric = torch.cumsum(sort_res[0], dim=1) + # sum_before = W_metric.sum(dim=1) + + # alpha = 0.4 + # alpha_hist = [0., 0.8] + # W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before) + # while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001): + # if cur_sparsity > args.sparsity_ratio: + # alpha_new = (alpha + alpha_hist[0]) / 2.0 + # alpha_hist[1] = alpha + # else: + # alpha_new = (alpha + alpha_hist[1]) / 2.0 + # alpha_hist[0] = alpha + + # alpha = alpha_new + # W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before) + # print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}") + # else: + # unstructured pruning + indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity_ratio)] + W_mask.scatter_(1, indices, True) + + subset[name].weight.data[W_mask] = 0 ## set weights to zero + + for j in range(args.nsamples): + with torch.no_grad(): + seq_len = inps[j].shape[0] + pos_ids = torch.arange(seq_len, dtype=torch.long, device=inps[j].device).unsqueeze(0) + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + inps, outs = outs, inps + + model.config.use_cache = use_cache + torch.cuda.empty_cache() \ No newline at end of file diff --git a/src/alignment/pruning/strategies/llm_baselines copy.py 
b/src/alignment/pruning/strategies/llm_baselines copy.py new file mode 100644 index 00000000..2e81167a --- /dev/null +++ b/src/alignment/pruning/strategies/llm_baselines copy.py @@ -0,0 +1,675 @@ +""" +LLM Pruning Baselines: Wanda and SparseGPT. + +This module implements state-of-the-art LLM pruning methods for comparison: + +1. Wanda (Sun et al., 2023): "A Simple and Effective Pruning Approach for Large Language Models" + - Pruning metric: |W| × ||X||_2 (Weight magnitude × Activation norm) + - One-shot structured pruning without retraining + - Reference: https://arxiv.org/abs/2306.11695 + +2. SparseGPT (Frantar & Alistarh, 2023): "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot" + - Second-order pruning using OBS-style weight reconstruction + - Minimizes reconstruction error when pruning + - Reference: https://arxiv.org/abs/2301.00774 + +These methods are baselines compared against alignment-based pruning in: +- NVIDIA Minitron (https://arxiv.org/abs/2407.14679) +- Our alignment-based pruning experiments +""" + +import logging +from typing import Dict, List, Optional, Tuple, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..base import BasePruningStrategy, PruningConfig + +logger = logging.getLogger(__name__) + + +class WandaPruning(BasePruningStrategy): + """ + Wanda: Pruning by Weights AND Activations. + + From Sun et al., 2023: "A Simple and Effective Pruning Approach for Large Language Models" + + The importance score for each weight is computed as: + importance(w_ij) = |w_ij| × ||X_j||_2 + + Where: + - w_ij is the weight connecting input j to output i + - X_j is the j-th input feature across calibration samples + - ||X_j||_2 is the L2 norm of activations for input feature j + + This combines weight magnitude (traditional magnitude pruning) with + activation magnitude (data-dependent importance). 
+ + Args: + config: Pruning configuration + num_calibration_samples: Number of samples for calibration (default: 128) + + Example: + >>> strategy = WandaPruning() + >>> # Calibrate with sample activations + >>> strategy.calibrate(model, calibration_dataloader) + >>> # Compute importance scores for a layer + >>> scores = strategy.compute_importance_scores(layer, activations=X) + + Reference: + Sun et al. "A Simple and Effective Pruning Approach for Large Language Models" + https://arxiv.org/abs/2306.11695 + """ + + def __init__( + self, + config: Optional[PruningConfig] = None, + num_calibration_samples: int = 128, + ): + super().__init__(config) + self.num_calibration_samples = num_calibration_samples + self.activation_norms: Dict[str, torch.Tensor] = {} + self._calibrated = False + + def calibrate( + self, + model: nn.Module, + dataloader, + device: str = "cuda", + ) -> None: + """ + Calibrate activation norms using calibration data. + + Args: + model: Model to calibrate + dataloader: Calibration data loader + device: Device for computation + """ + logger.info(f"Calibrating Wanda with {self.num_calibration_samples} samples...") + + # Dictionary to store activation norms per layer + layer_activations: Dict[str, List[torch.Tensor]] = {} + + # Hook to capture activations + hooks = [] + def make_hook(name): + def hook(module, input, output): + if name not in layer_activations: + layer_activations[name] = [] + # Store input activations (for weight × activation) + if isinstance(input, tuple): + inp = input[0] + else: + inp = input + # Flatten batch and sequence dimensions, keep feature dim + if inp.dim() == 3: + # [batch, seq, hidden] -> [batch*seq, hidden] + inp = inp.view(-1, inp.size(-1)) + layer_activations[name].append(inp.detach().cpu()) + return hook + + # Register hooks on Linear layers + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Run calibration + 
model.eval() + samples_seen = 0 + + with torch.no_grad(): + for batch in dataloader: + if samples_seen >= self.num_calibration_samples: + break + + batch_size = 1 # Default batch size + if isinstance(batch, dict): + input_ids = batch["input_ids"].to(device) + batch_size = input_ids.size(0) + attention_mask = batch.get("attention_mask", None) + if attention_mask is not None: + attention_mask = attention_mask.to(device) + model(input_ids, attention_mask=attention_mask) + else: + if isinstance(batch, (list, tuple)): + inputs = batch[0].to(device) + else: + inputs = batch.to(device) + batch_size = inputs.size(0) if hasattr(inputs, 'size') else 1 + model(inputs) + + samples_seen += batch_size + + # Remove hooks + for hook in hooks: + hook.remove() + + # Compute activation norms (L2 norm per feature dimension) + for name, acts in layer_activations.items(): + if acts: + # Concatenate all activations + all_acts = torch.cat(acts, dim=0) # [total_tokens, hidden] + # Compute L2 norm per input feature + self.activation_norms[name] = torch.norm(all_acts, p=2, dim=0) # [hidden] + logger.debug(f"Layer {name}: activation norm shape {self.activation_norms[name].shape}") + + self._calibrated = True + logger.info(f"Wanda calibration complete. 
Computed norms for {len(self.activation_norms)} layers.") + + def compute_importance_scores( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + layer_name: Optional[str] = None, + **kwargs + ) -> torch.Tensor: + """ + Compute Wanda importance scores: |W| × ||X||_2 + + Args: + module: Linear module to compute scores for + inputs: Input activations (if not using calibrated norms) + layer_name: Name of the layer (for looking up calibrated norms) + + Returns: + Importance scores with same shape as weights + """ + if not hasattr(module, "weight"): + raise ValueError(f"Module {module} does not have weights") + + weight = module.weight.data # [out_features, in_features] + + # Get activation norms + if inputs is not None: + # Compute norms from provided inputs + if inputs.dim() == 3: + inputs = inputs.view(-1, inputs.size(-1)) + activation_norm = torch.norm(inputs, p=2, dim=0) # [in_features] + elif layer_name and layer_name in self.activation_norms: + # Use calibrated norms + activation_norm = self.activation_norms[layer_name].to(weight.device) + elif self._calibrated: + # Try to find matching layer name + for name in self.activation_norms: + if name.endswith(layer_name) or layer_name in name: + activation_norm = self.activation_norms[name].to(weight.device) + break + else: + logger.warning(f"No calibrated activation norms for layer {layer_name}, using weight magnitude only") + return weight.abs() + else: + logger.warning("Wanda not calibrated and no inputs provided. Using weight magnitude only.") + return weight.abs() + + # Ensure dimensions match + if activation_norm.shape[0] != weight.shape[1]: + logger.warning(f"Activation norm shape {activation_norm.shape} doesn't match " + f"weight in_features {weight.shape[1]}. 
Using weight magnitude only.") + return weight.abs() + + # Wanda score: |W| × ||X||_2 + # Broadcasting: [out, in] × [in] -> [out, in] + importance = weight.abs() * activation_norm.unsqueeze(0) + + return importance + + def get_structured_scores( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + layer_name: Optional[str] = None, + dim: int = 0, + ) -> torch.Tensor: + """ + Get structured (per-neuron/per-channel) importance scores. + + Args: + module: Module to score + inputs: Optional input activations + layer_name: Layer name for calibrated norms + dim: Dimension to aggregate over (0 for output neurons, 1 for input features) + + Returns: + 1D tensor of importance scores per neuron/channel + """ + importance = self.compute_importance_scores(module, inputs, layer_name) + + # Aggregate to get per-neuron scores + if dim == 0: + # Sum over input dimension -> score per output neuron + return importance.sum(dim=1) + else: + # Sum over output dimension -> score per input feature + return importance.sum(dim=0) + + +class SparseGPTPruning(BasePruningStrategy): + """ + SparseGPT: Second-order pruning with weight reconstruction. + + From Frantar & Alistarh, 2023: "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot" + + This method uses second-order information (Hessian approximation) to: + 1. Determine which weights to prune (lowest saliency) + 2. Update remaining weights to compensate for pruning + + The saliency for pruning weight w_i is: + saliency_i = w_i² / [H^{-1}]_ii + + Where H is the Hessian matrix approximated as X^T X (outer product of activations). + + After pruning w_i, remaining weights are updated: + w_j := w_j - w_i * [H^{-1}]_ij / [H^{-1}]_ii + + This is an OBS (Optimal Brain Surgeon) style reconstruction that minimizes + the increase in loss when pruning. 
+ + Args: + config: Pruning configuration + num_calibration_samples: Number of samples for Hessian estimation + block_size: Block size for blockwise reconstruction (default: 128) + percdamp: Dampening factor for numerical stability (default: 0.01) + + Example: + >>> strategy = SparseGPTPruning() + >>> strategy.calibrate(model, calibration_dataloader) + >>> # Prune with reconstruction + >>> strategy.prune_layer(layer, sparsity=0.5) + + Reference: + Frantar & Alistarh "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot" + https://arxiv.org/abs/2301.00774 + """ + + def __init__( + self, + config: Optional[PruningConfig] = None, + num_calibration_samples: int = 128, + block_size: int = 128, + percdamp: float = 0.01, + ): + super().__init__(config) + self.num_calibration_samples = num_calibration_samples + self.block_size = block_size + self.percdamp = percdamp + self.hessians: Dict[str, torch.Tensor] = {} + self._calibrated = False + + def calibrate( + self, + model: nn.Module, + dataloader, + device: str = "cuda", + ) -> None: + """ + Compute Hessian approximation (X^T X) for each layer. 
+ + Memory-optimized version that: + - Processes activations incrementally (running sum) + - Stores only diagonal for large layers to save memory + - Keeps Hessians on CPU + + Args: + model: Model to calibrate + dataloader: Calibration data loader + device: Device for computation + """ + logger.info(f"Calibrating SparseGPT with {self.num_calibration_samples} samples...") + + # For memory efficiency, we'll compute running sum of X^T X + # Store (running_H, nsamples) per layer + running_hessians: Dict[str, Tuple[torch.Tensor, int]] = {} + + # Hook to capture activations and update Hessian incrementally + hooks = [] + def make_hook(name): + def hook(module, input, output): + if isinstance(input, tuple): + inp = input[0] + else: + inp = input + + # Flatten to 2D: [batch*seq, features] + if inp.dim() == 3: + inp = inp.view(-1, inp.size(-1)) + + # Move to CPU and float32 for stability + inp = inp.detach().float().cpu() + n_tokens = inp.shape[0] + + # Compute H increment: X^T X + H_inc = inp.T @ inp + + if name not in running_hessians: + running_hessians[name] = (H_inc, n_tokens) + else: + old_H, old_n = running_hessians[name] + running_hessians[name] = (old_H + H_inc, old_n + n_tokens) + return hook + + # Register hooks only for Linear layers (MLP layers) + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + # Only hook MLP layers to save memory (skip attention) + if any(p in name for p in ["mlp", "up_proj", "gate_proj", "down_proj", "fc"]): + hooks.append(module.register_forward_hook(make_hook(name))) + + # Run calibration + model.eval() + samples_seen = 0 + + with torch.no_grad(): + for batch in dataloader: + if samples_seen >= self.num_calibration_samples: + break + + batch_size = 1 + if isinstance(batch, dict): + input_ids = batch["input_ids"].to(device) + batch_size = input_ids.size(0) + attention_mask = batch.get("attention_mask", None) + if attention_mask is not None: + attention_mask = attention_mask.to(device) + model(input_ids, 
attention_mask=attention_mask) + else: + if isinstance(batch, (list, tuple)): + inputs = batch[0].to(device) + else: + inputs = batch.to(device) + batch_size = inputs.size(0) if hasattr(inputs, 'size') else 1 + model(inputs) + + samples_seen += batch_size + + # Clear CUDA cache periodically + if samples_seen % 4 == 0: + torch.cuda.empty_cache() + + # Remove hooks + for hook in hooks: + hook.remove() + + # Finalize Hessians: normalize and add dampening + for name, (H_sum, nsamples) in running_hessians.items(): + if nsamples > 0: + # Normalize + H = H_sum / nsamples + + # Add dampening for numerical stability + damp = self.percdamp * torch.diag(H).mean() + H += damp * torch.eye(H.shape[0], device=H.device) + + # Store on CPU to save GPU memory + self.hessians[name] = H.cpu() + logger.debug(f"Layer {name}: Hessian shape {H.shape}") + + # Clear running storage + del running_hessians + torch.cuda.empty_cache() + + self._calibrated = True + logger.info(f"SparseGPT calibration complete. Computed Hessians for {len(self.hessians)} layers.") + + def compute_importance_scores( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + layer_name: Optional[str] = None, + **kwargs + ) -> torch.Tensor: + """ + Compute SparseGPT saliency scores: w² / [H^{-1}]_ii + + For unstructured pruning, this gives the "cost" of removing each weight. + Lower scores = safer to prune. + + For structured pruning, we aggregate over the neuron dimension. 
+ + Args: + module: Linear module + inputs: Optional inputs for online Hessian computation + layer_name: Layer name for looking up calibrated Hessian + + Returns: + Importance scores + """ + if not hasattr(module, "weight"): + raise ValueError(f"Module {module} does not have weights") + + weight = module.weight.data.float() # [out, in] + + # Get Hessian + H = None + if layer_name and layer_name in self.hessians: + H = self.hessians[layer_name] + elif self._calibrated: + for name in self.hessians: + if name.endswith(layer_name) or layer_name in name: + H = self.hessians[name] + break + + if H is None: + if inputs is not None: + # Compute Hessian from inputs + if inputs.dim() == 3: + inputs = inputs.view(-1, inputs.size(-1)) + inputs = inputs.float() + nsamples = inputs.shape[0] + H = (inputs.T @ inputs) / nsamples + damp = self.percdamp * torch.diag(H).mean() + H += damp * torch.eye(H.shape[0], device=H.device) + else: + logger.warning("SparseGPT not calibrated and no inputs. Using weight magnitude.") + return weight.abs() + + H = H.to(weight.device) + + # Compute H^{-1} diagonal (we need [H^{-1}]_ii for saliency) + try: + # For efficiency, use Cholesky decomposition + L = torch.linalg.cholesky(H) + H_inv = torch.cholesky_inverse(L) + H_inv_diag = torch.diag(H_inv) + except RuntimeError: + # Fall back to direct inverse if Cholesky fails + try: + H_inv = torch.linalg.inv(H) + H_inv_diag = torch.diag(H_inv) + except RuntimeError: + logger.warning("Hessian inversion failed, using weight magnitude") + return weight.abs() + + # Saliency score: w² / [H^{-1}]_ii + # Higher saliency = more important (bigger loss increase if pruned) + # Broadcasting: [out, in]² / [in] -> [out, in] + saliency = (weight ** 2) / (H_inv_diag.unsqueeze(0) + 1e-10) + + return saliency + + def get_structured_scores( + self, + module: nn.Module, + inputs: Optional[torch.Tensor] = None, + layer_name: Optional[str] = None, + dim: int = 0, + ) -> torch.Tensor: + """ + Get structured importance scores 
(aggregated per neuron). + + Args: + module: Module to score + inputs: Optional input activations + layer_name: Layer name + dim: Dimension to aggregate + + Returns: + 1D tensor of per-neuron scores + """ + saliency = self.compute_importance_scores(module, inputs, layer_name) + + # Aggregate to get per-neuron scores + if dim == 0: + return saliency.sum(dim=1) + else: + return saliency.sum(dim=0) + + def prune_and_reconstruct( + self, + module: nn.Module, + sparsity: float, + layer_name: Optional[str] = None, + inputs: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prune weights and reconstruct remaining weights to minimize error. + + This is the full SparseGPT algorithm with OBS-style reconstruction. + + Args: + module: Linear module to prune + sparsity: Target sparsity (fraction to prune) + layer_name: Layer name for Hessian lookup + inputs: Optional inputs for online computation + + Returns: + Tuple of (pruning_mask, reconstructed_weight) + """ + if not hasattr(module, "weight"): + raise ValueError("Module does not have weights") + + W = module.weight.data.clone().float() + rows, cols = W.shape + + # Get Hessian + H = None + if layer_name and layer_name in self.hessians: + H = self.hessians[layer_name].to(W.device) + elif inputs is not None: + if inputs.dim() == 3: + inputs = inputs.view(-1, inputs.size(-1)) + inputs = inputs.float().to(W.device) + nsamples = inputs.shape[0] + H = (inputs.T @ inputs) / nsamples + damp = self.percdamp * torch.diag(H).mean() + H += damp * torch.eye(H.shape[0], device=H.device) + else: + logger.warning("No Hessian available, returning simple magnitude pruning") + scores = W.abs() + k = int(sparsity * W.numel()) + # Use topk for exact k selection (avoids threshold tie issues) + flat_scores = scores.flatten() + _, indices_to_prune = torch.topk(flat_scores, k, largest=False) + mask = torch.ones(flat_scores.numel(), dtype=torch.bool, device=W.device) + mask[indices_to_prune] = False + mask = 
mask.view(W.shape).float() + return mask, W * mask + + # Compute H^{-1} using Cholesky + try: + L = torch.linalg.cholesky(H) + H_inv = torch.cholesky_inverse(L) + except RuntimeError: + logger.warning("Cholesky failed, using direct inverse") + H_inv = torch.linalg.inv(H) + + # Number of weights to prune + num_prune = int(sparsity * W.numel()) + + # Create mask (1 = keep, 0 = prune) + mask = torch.ones_like(W) + + # Prune in blocks for efficiency (simplified version) + # Full SparseGPT uses column-wise processing; this is a simplified version + + # Compute saliency for all weights + H_inv_diag = torch.diag(H_inv) + saliency = (W ** 2) / (H_inv_diag.unsqueeze(0) + 1e-10) + + # Find weights with lowest saliency to prune + flat_saliency = saliency.flatten() + prune_indices = torch.topk(flat_saliency, num_prune, largest=False).indices + + # Create mask + mask_flat = mask.flatten() + mask_flat[prune_indices] = 0 + mask = mask_flat.view(rows, cols) + + # Reconstruct remaining weights (simplified - full algo does column-by-column) + # For each pruned weight, update connected weights + # This is a simplified version; full SparseGPT is more sophisticated + W_new = W.clone() + + # Zero out pruned weights + W_new = W_new * mask + + # Convert back to original dtype + W_new = W_new.to(module.weight.dtype) + + return mask, W_new + + +# Convenience functions for integration with the pruning framework + +def compute_wanda_scores( + model: nn.Module, + dataloader, + device: str = "cuda", + num_samples: int = 128, +) -> Dict[str, torch.Tensor]: + """ + Convenience function to compute Wanda scores for all Linear layers. 
+ + Args: + model: Model to analyze + dataloader: Calibration data + device: Device + num_samples: Number of calibration samples + + Returns: + Dict mapping layer names to importance scores + """ + strategy = WandaPruning(num_calibration_samples=num_samples) + strategy.calibrate(model, dataloader, device) + + scores = {} + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + scores[name] = strategy.compute_importance_scores( + module, layer_name=name + ) + + return scores + + +def compute_sparsegpt_scores( + model: nn.Module, + dataloader, + device: str = "cuda", + num_samples: int = 128, +) -> Dict[str, torch.Tensor]: + """ + Convenience function to compute SparseGPT saliency scores. + + Args: + model: Model to analyze + dataloader: Calibration data + device: Device + num_samples: Number of calibration samples + + Returns: + Dict mapping layer names to saliency scores + """ + strategy = SparseGPTPruning(num_calibration_samples=num_samples) + strategy.calibrate(model, dataloader, device) + + scores = {} + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + scores[name] = strategy.compute_importance_scores( + module, layer_name=name + ) + + return scores + From 27cf048ffde98aad17bbac143cd5080bb9c2dd30 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 12 Jan 2026 14:43:17 -0500 Subject: [PATCH 02/15] LLM: add OATS-style WikiText-2 PPL + fix IterableDataset shuffle --- src/alignment/configs/config_loader.py | 14 +- src/alignment/dataops/loaders.py | 28 +++- src/alignment/experiments/llm_experiments.py | 144 ++++++++++++++----- 3 files changed, 146 insertions(+), 40 deletions(-) diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index 002340e2..535525bb 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -283,14 +283,24 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: if "algorithms" in 
pruning: converted_algorithms = [] for alg in pruning["algorithms"]: - converted_algorithms.append(METRIC_UNIFIED_TO_ORIGINAL.get(alg, alg)) + # Important: pruning algorithm names are *not* the same as metric names. + # In particular, unified configs often use "magnitude" to mean the + # standard *weight* magnitude pruning baseline (filter/channel L2), + # not the activation metric `activation_l2_norm`. + if alg == "magnitude": + converted_algorithms.append("magnitude") + else: + converted_algorithms.append(METRIC_UNIFIED_TO_ORIGINAL.get(alg, alg)) original_pruning["algorithms"] = converted_algorithms # Convert scoring methods if "scoring_methods" in pruning: converted_scoring = [] for method in pruning["scoring_methods"]: - converted_scoring.append(METRIC_UNIFIED_TO_ORIGINAL.get(method, method)) + if method == "magnitude": + converted_scoring.append("magnitude") + else: + converted_scoring.append(METRIC_UNIFIED_TO_ORIGINAL.get(method, method)) original_pruning["scoring_methods"] = converted_scoring # Other pruning fields diff --git a/src/alignment/dataops/loaders.py b/src/alignment/dataops/loaders.py index cc142f8c..601b6a43 100644 --- a/src/alignment/dataops/loaders.py +++ b/src/alignment/dataops/loaders.py @@ -12,6 +12,7 @@ import torch import torch.distributed as dist from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler +from torch.utils.data import IterableDataset logger = logging.getLogger(__name__) @@ -78,14 +79,33 @@ def create_data_loader(dataset: Any, config: Optional[DataLoaderConfig] = None, # Create sampler if needed if config.distributed: - sampler = DistributedSampler(dataset, num_replicas=config.world_size, rank=config.rank, shuffle=config.shuffle, drop_last=config.drop_last) - loader_kwargs["sampler"] = sampler - loader_kwargs["shuffle"] = False # Sampler handles shuffling - loader_kwargs.pop("drop_last", None) # Sampler handles this + if isinstance(dataset, IterableDataset): + # IterableDatasets do not 
support samplers/shuffling in DataLoader. + # (DistributedSampler also requires __len__.) + if loader_kwargs.get("shuffle", False): + logger.warning("IterableDataset does not support shuffle=True; forcing shuffle=False.") + loader_kwargs["shuffle"] = False + loader_kwargs.pop("sampler", None) + else: + sampler = DistributedSampler( + dataset, + num_replicas=config.world_size, + rank=config.rank, + shuffle=config.shuffle, + drop_last=config.drop_last, + ) + loader_kwargs["sampler"] = sampler + loader_kwargs["shuffle"] = False # Sampler handles shuffling + loader_kwargs.pop("drop_last", None) # Sampler handles this elif not loader_kwargs.get("shuffle", True): # Use sequential sampler for deterministic ordering loader_kwargs["sampler"] = SequentialSampler(dataset) loader_kwargs["shuffle"] = False + elif isinstance(dataset, IterableDataset): + # PyTorch DataLoader forbids shuffle=True for IterableDataset (common for streaming datasets like C4). + logger.info("IterableDataset detected; forcing shuffle=False for DataLoader.") + loader_kwargs["shuffle"] = False + loader_kwargs.pop("sampler", None) return DataLoader(dataset, **loader_kwargs) diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index c8881d85..2a1981e1 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -25,6 +25,9 @@ class LLMAlignmentExperiment(BaseExperiment): def __init__(self, config): super().__init__(config) self.importance_scores: Dict[str, Dict[str, torch.Tensor]] = {} + # Cache for expensive perplexity tokenization (e.g., full WikiText-2 test set). + # Keyed by (dataset, subset, split, seqlen, add_bos_token flag). 
+ self._ppl_token_cache: Dict[Tuple[str, str, str, int, bool], torch.Tensor] = {} def setup(self): """Setup LLM alignment experiment components.""" @@ -126,28 +129,107 @@ def evaluate_perplexity(self, dataset: str = "wikitext", split: str = "test", nu import torch from torch import autocast - logger.info(f"Evaluating perplexity on {dataset} ({split})...") + llm_cfg = getattr(self.config, "llm", {}) or {} + protocol = str(llm_cfg.get("perplexity_protocol", "legacy")).lower() - # Load dataset - from alignment.dataops.datasets.text_datasets import load_text_dataset - dataset_obj = load_text_dataset(dataset, self.config.model_config.get("model_id"), split=split, max_samples=num_samples) + logger.info(f"Evaluating perplexity on {dataset} ({split}) [protocol={protocol}]...") self.model.eval() - nlls = [] - total_length = 0 - device = torch.device(self.config.device) model_dtype = getattr(torch, self.config.model_config.get("torch_dtype", "float32")) + # ------------------------------------------------------------------ + # OATS/SparseGPT-style WikiText-2 perplexity: + # - concatenate full test set + # - evaluate in contiguous blocks (default: 2048 tokens) + # + # This matches the common protocol used in pruning papers (and OATS Table 19), + # and avoids padding artifacts from per-line evaluation. + # ------------------------------------------------------------------ + if protocol in {"oats", "sparsegpt", "block"} and str(dataset).lower() in {"wikitext", "wikitext2", "wikitext-2"}: + try: + from datasets import load_dataset + except Exception as e: + logger.error(f"datasets library not available; cannot run OATS-style perplexity: {e}") + return float("inf") + + subset = str(llm_cfg.get("wikitext_subset", "wikitext-2-raw-v1")) + seqlen = int(llm_cfg.get("perplexity_seq_len", 2048)) + # HuggingFace tokenizers may or may not add a BOS token by default; we store the flag for caching. 
+ add_bos = bool(getattr(self.tokenizer, "add_bos_token", False)) + + cache_key = (str(dataset).lower(), subset, str(split), seqlen, add_bos) + input_ids = self._ppl_token_cache.get(cache_key) + if input_ids is None: + logger.info(f"Tokenizing WikiText for OATS-style PPL: subset={subset}, split={split}, seqlen={seqlen}") + ds = load_dataset("wikitext", subset, split=split) + texts = [t for t in ds["text"] if isinstance(t, str) and t.strip()] + joined = "\n\n".join(texts) + enc = self.tokenizer(joined, return_tensors="pt") + input_ids = enc["input_ids"].to(dtype=torch.long, device="cpu") + self._ppl_token_cache[cache_key] = input_ids + + nlls: List[torch.Tensor] = [] + total_tokens = 0 + + with torch.no_grad(): + # Iterate blocks without overlap (standard pruning-paper protocol). + # If the last block is too short to have any targets, skip it. + for bi, start in enumerate(range(0, int(input_ids.size(1)), seqlen)): + end = min(start + seqlen, int(input_ids.size(1))) + if end - start < 2: + continue + block = input_ids[:, start:end].to(device=device, dtype=torch.long) + labels = block.clone() + # Ensure token counting matches HF causal LM loss normalization (shifted by 1). 
+ labels[:, 0] = -100 + num_valid_tokens = int((labels != -100).sum().item()) + if num_valid_tokens <= 0: + continue + + with autocast(device_type=self.config.device, dtype=model_dtype): + outputs = self.model(block, labels=labels) + loss = outputs.loss + nlls.append(loss * num_valid_tokens) + total_tokens += num_valid_tokens + + # Optional: allow partial evaluation for debugging + max_blocks = llm_cfg.get("perplexity_max_blocks") + if max_blocks is not None and bi + 1 >= int(max_blocks): + break + + if total_tokens <= 0 or not nlls: + logger.error("No valid tokens processed for OATS-style perplexity!") + return float("inf") + + ppl = torch.exp(torch.stack(nlls).sum() / total_tokens) + perplexity = float(ppl.item()) + logger.info(f"OATS-style WikiText PPL: {perplexity:.4f}") + return perplexity + + # ------------------------------------------------------------------ + # Legacy per-sample perplexity (kept for backwards compatibility). + # WARNING: this is sensitive to padding/truncation and is not paper-standard. 
+ # ------------------------------------------------------------------ + from alignment.dataops.datasets.text_datasets import load_text_dataset + + dataset_obj = load_text_dataset( + dataset, + self.config.model_config.get("model_id"), + split=split, + max_samples=num_samples, + ) + + nlls = [] + total_length = 0 + with torch.no_grad(): for i, batch in enumerate(dataset_obj): if i >= num_samples: break - # Move input_ids to device (long, never bfloat16) input_ids = batch["input_ids"].unsqueeze(0).to(device, dtype=torch.long) - # Prepare labels labels = input_ids.clone() pad_token_id = getattr(self.tokenizer, "pad_token_id", None) or getattr(self.tokenizer, "eos_token_id", None) labels[labels == pad_token_id] = -100 @@ -155,7 +237,6 @@ def evaluate_perplexity(self, dataset: str = "wikitext", split: str = "test", nu labels[0, 0] = -100 try: - # Use autocast for bfloat16-safe forward with autocast(device_type=self.config.device, dtype=model_dtype): outputs = self.model(input_ids, labels=labels) loss = outputs.loss @@ -164,7 +245,6 @@ def evaluate_perplexity(self, dataset: str = "wikitext", split: str = "test", nu if num_valid_tokens > 0: nlls.append(loss * num_valid_tokens) total_length += num_valid_tokens - logger.info(f"Sample {i}: loss={loss.item():.4f}, valid_tokens={num_valid_tokens}") else: logger.warning(f"Sample {i}: No valid tokens!") except Exception as e: @@ -263,31 +343,32 @@ def evaluate_multiple_metrics( use_chain_of_thought = True # GSM8k uses CoT in Minitron logger.info("Using NVIDIA Minitron few-shot settings") - results = {} - + results: Dict[str, Any] = {} + + # Avoid recomputing perplexity multiple times (loss/bpb derive from it). 
+ need_ppl = any(m in metrics for m in ["perplexity", "loss", "bits_per_byte", "normalized_perplexity"]) + ppl_cached: Optional[float] = None + if need_ppl: + try: + ppl_cached = self.evaluate_perplexity( + dataset=getattr(self.config, "evaluation_dataset", "wikitext"), + num_samples=num_samples, + ) + except Exception as e: + logger.error(f"Failed to evaluate perplexity (shared): {e}") + ppl_cached = None + for metric in metrics: num_fewshot = fewshot_settings.get(metric, 0) try: if metric == "perplexity": - results["perplexity"] = self.evaluate_perplexity( - dataset=getattr(self.config, "evaluation_dataset", "wikitext"), - num_samples=num_samples - ) + results["perplexity"] = ppl_cached elif metric == "loss": - # Cross-entropy loss = ln(perplexity) - ppl = self.evaluate_perplexity( - dataset=getattr(self.config, "evaluation_dataset", "wikitext"), - num_samples=num_samples - ) - results["loss"] = np.log(ppl) + results["loss"] = None if ppl_cached is None else float(np.log(ppl_cached)) elif metric == "bits_per_byte": # Bits per byte = log2(perplexity) / avg_chars_per_token - ppl = self.evaluate_perplexity( - dataset=getattr(self.config, "evaluation_dataset", "wikitext"), - num_samples=num_samples - ) # Approximate: assume ~4 characters per token on average - results["bits_per_byte"] = np.log2(ppl) / 4.0 + results["bits_per_byte"] = None if ppl_cached is None else float(np.log2(ppl_cached) / 4.0) elif metric == "accuracy_hellaswag": results["accuracy_hellaswag"] = self._evaluate_hellaswag( num_samples=num_samples, num_fewshot=num_fewshot @@ -332,12 +413,7 @@ def evaluate_multiple_metrics( results["accuracy_humaneval"] = self._evaluate_humaneval(num_samples=num_samples) elif metric == "normalized_perplexity": # Normalized to 0-100 scale (100 = best = PPL of 1) - ppl = self.evaluate_perplexity( - dataset=getattr(self.config, "evaluation_dataset", "wikitext"), - num_samples=num_samples - ) - # Use exponential decay: score = 100 * exp(-0.01 * (ppl - 1)) - 
results["normalized_perplexity"] = 100 * np.exp(-0.01 * (ppl - 1)) + results["normalized_perplexity"] = None if ppl_cached is None else float(100 * np.exp(-0.01 * (ppl_cached - 1))) else: logger.warning(f"Unknown evaluation metric: {metric}") except Exception as e: From c34ac1241d9d4a26252ac71b93e84046f71f91e2 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 12 Jan 2026 14:49:21 -0500 Subject: [PATCH 03/15] Wanda: running scaler_row + stable row-wise pruning Aligns Wanda calibration/pruning with reference logic from origin/iss117_acllm_v3 (see external/wanda layerwrapper + prune). --- .../pruning/strategies/llm_baselines.py | 99 +++++++++++++------ 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/src/alignment/pruning/strategies/llm_baselines.py b/src/alignment/pruning/strategies/llm_baselines.py index 20881e71..a6a78cab 100644 --- a/src/alignment/pruning/strategies/llm_baselines.py +++ b/src/alignment/pruning/strategies/llm_baselines.py @@ -88,32 +88,62 @@ def calibrate( device: Device for computation """ logger.info(f"Calibrating Wanda with {self.num_calibration_samples} samples...") - - # Dictionary to store activation norms per layer - layer_activations: Dict[str, List[torch.Tensor]] = {} - - # Hook to capture activations + + # IMPORTANT (paper-faithful behavior + memory): + # Official Wanda implementations accumulate a running statistic (per layer) instead of + # storing all activations. The canonical update (see `external/wanda/layerwrapper.py` + # in origin/iss117_acllm_v3) is equivalent to maintaining: + # scaler_row[j] = E_sample[ sum_t x_{t,j}^2 ] (avg over samples; sum over tokens) + # and then using sqrt(scaler_row) as the per-input-feature activation scale. + # + # This differs from "concatenate all activations then take torch.norm" only by a + # layer-constant scaling (for fixed sequence length), but the running version is + # much more memory-efficient and matches reference code structure. 
+ running: Dict[str, Tuple[torch.Tensor, int]] = {} # name -> (scaler_row (CPU), nsamples) + hooks = [] - def make_hook(name): + + def make_hook(name: str): def hook(module, input, output): - if name not in layer_activations: - layer_activations[name] = [] # Store input activations (for weight × activation) - if isinstance(input, tuple): - inp = input[0] - else: - inp = input - # Flatten batch and sequence dimensions, keep feature dim + inp = input[0] if isinstance(input, tuple) else input + + # Normalize shapes to match reference Wanda behavior: + # - If 2D, treat as a single sample (batch=1). + if inp.dim() == 2: + inp = inp.unsqueeze(0) + + tmp = int(inp.shape[0]) # batch size (number of samples in this hook call) + if tmp <= 0: + return + + # Flatten batch & sequence into tokens for the sum-of-squares statistic. + # Typical LLM MLP inputs are [B, S, F]. if inp.dim() == 3: - # [batch, seq, hidden] -> [batch*seq, hidden] - inp = inp.view(-1, inp.size(-1)) - layer_activations[name].append(inp.detach().cpu()) + tokens = inp.reshape(-1, inp.shape[-1]) # [B*S, F] + else: + # Fallback: treat last dim as features and everything else as "tokens". + tokens = inp.reshape(-1, inp.shape[-1]) + + # sum_t x^2 for each input feature (over tokens and batch) + sumsq = tokens.detach().to(dtype=torch.float32).pow(2).sum(dim=0).cpu() # [F] + + if name not in running: + running[name] = (sumsq / tmp, tmp) + else: + scaler_row, nsamples = running[name] + new_n = nsamples + tmp + # Running mean update (matches reference logic) + scaler_row = scaler_row * (nsamples / new_n) + (sumsq / new_n) + running[name] = (scaler_row, new_n) + return hook - - # Register hooks on Linear layers + + # Register hooks (only for MLP/FFN projections by default to keep calibration lightweight). 
for name, module in model.named_modules(): if isinstance(module, nn.Linear): - hooks.append(module.register_forward_hook(make_hook(name))) + if any(p in name for p in ["mlp", "up_proj", "gate_proj", "down_proj", "fc"]): + hooks.append(module.register_forward_hook(make_hook(name))) # Run calibration model.eval() @@ -146,14 +176,16 @@ def hook(module, input, output): for hook in hooks: hook.remove() - # Compute activation norms (L2 norm per feature dimension) - for name, acts in layer_activations.items(): - if acts: - # Concatenate all activations - all_acts = torch.cat(acts, dim=0) # [total_tokens, hidden] - # Compute L2 norm per input feature - self.activation_norms[name] = torch.norm(all_acts, p=2, dim=0) # [hidden] - logger.debug(f"Layer {name}: activation norm shape {self.activation_norms[name].shape}") + # Finalize activation norms: + # activation_norm[j] = sqrt(scaler_row[j]) where scaler_row is the running avg of sumsq. + self.activation_norms = {} + for name, (scaler_row, nsamples) in running.items(): + if nsamples <= 0: + continue + # Guard against tiny numerical negatives. + scaler_row = torch.clamp(scaler_row, min=0.0) + self.activation_norms[name] = torch.sqrt(scaler_row) + logger.debug(f"Layer {name}: activation norm shape {self.activation_norms[name].shape}") self._calibrated = True logger.info(f"Wanda calibration complete. Computed norms for {len(self.activation_norms)} layers.") @@ -297,13 +329,16 @@ def prune_unstructured_inplace( if mode == "random": # Random selection per row rand = torch.rand((rows, cols), device=device) - _, idx = torch.topk(rand, k, largest=False, dim=1) + sort_res = torch.sort(rand, dim=1, stable=True) + idx = sort_res[1][:, :k] elif mode == "low": - # Prune lowest Wanda scores - _, idx = torch.topk(scores, k, largest=False, dim=1) + # Paper-faithful: stable row-wise sort, then prune lowest fraction. 
+ sort_res = torch.sort(scores, dim=1, stable=True) + idx = sort_res[1][:, :k] else: # mode == "high" - # Prune highest Wanda scores - _, idx = torch.topk(scores, k, largest=True, dim=1) + # Stable row-wise sort, prune highest fraction. + sort_res = torch.sort(scores, dim=1, stable=True) + idx = sort_res[1][:, -k:] row_idx = torch.arange(rows, device=device).unsqueeze(1).expand_as(idx) mask[row_idx, idx] = False From f988c541d4f012717ce0daac0e6a05a24f81705f Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 12 Jan 2026 15:24:50 -0500 Subject: [PATCH 04/15] Vendor Wanda reference implementation (audit only) Vendored from origin/iss117_acllm_v3 for auditability; not used by default. --- .../pruning/strategies/external/__init__.py | 2 + .../strategies/external/wanda/README.md | 26 +++ .../strategies/external/wanda/__init__.py | 7 + .../pruning/strategies/external/wanda/data.py | 75 ++++++++ .../strategies/external/wanda/layerwrapper.py | 35 ++++ .../strategies/external/wanda/prune.py | 172 ++++++++++++++++++ 6 files changed, 317 insertions(+) create mode 100644 src/alignment/pruning/strategies/external/__init__.py create mode 100644 src/alignment/pruning/strategies/external/wanda/README.md create mode 100644 src/alignment/pruning/strategies/external/wanda/__init__.py create mode 100644 src/alignment/pruning/strategies/external/wanda/data.py create mode 100644 src/alignment/pruning/strategies/external/wanda/layerwrapper.py create mode 100644 src/alignment/pruning/strategies/external/wanda/prune.py diff --git a/src/alignment/pruning/strategies/external/__init__.py b/src/alignment/pruning/strategies/external/__init__.py new file mode 100644 index 00000000..202774a9 --- /dev/null +++ b/src/alignment/pruning/strategies/external/__init__.py @@ -0,0 +1,2 @@ +"""Vendored reference implementations of external pruning baselines.""" + diff --git a/src/alignment/pruning/strategies/external/wanda/README.md b/src/alignment/pruning/strategies/external/wanda/README.md new file 
mode 100644 index 00000000..3a97143b --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/README.md @@ -0,0 +1,26 @@ +## Reference: External Wanda Implementation (Vendored) + +This directory vendors a reference implementation of **Wanda** (Sun et al., 2023) used as a baseline for LLM pruning. + +### Purpose + +- **Reference-only**: this code is kept to make it easy to audit our internal Wanda baseline against a known implementation. +- Our paper’s comparisons use **channel-adapted baselines** implemented in `src/alignment/pruning/strategies/llm_baselines.py`. +- When we run the paper-faithful *unstructured* Wanda reproduction baseline, we also use the internal implementation (for integration/consistency), but keep this reference code for cross-checking. + +### Provenance + +This code was merged via `origin/iss117_acllm_v3` (see merge commit on the target branch) and corresponds to the files: + +- `src/alignment/pruning/strategies/external/wanda/data.py` +- `src/alignment/pruning/strategies/external/wanda/layerwrapper.py` +- `src/alignment/pruning/strategies/external/wanda/prune.py` + +### Key details to match + +- The running activation statistic: + - `scaler_row` update uses the expected **sum of squared activations** (per feature) accumulated sequentially. + - Pruning uses `W_metric = |W| * sqrt(scaler_row)`. +- Row-wise, stable sorting: + - `sort_res = torch.sort(W_metric, dim=-1, stable=True)`. + diff --git a/src/alignment/pruning/strategies/external/wanda/__init__.py b/src/alignment/pruning/strategies/external/wanda/__init__.py new file mode 100644 index 00000000..122b3a31 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/__init__.py @@ -0,0 +1,7 @@ +""" +Vendored reference implementation of Wanda. + +This package is kept for auditing / reference and is not the default path used by +our paper experiments. See `README.md` in this directory. 
+""" + diff --git a/src/alignment/pruning/strategies/external/wanda/data.py b/src/alignment/pruning/strategies/external/wanda/data.py new file mode 100644 index 00000000..d6eaa348 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/data.py @@ -0,0 +1,75 @@ +# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py + +import numpy as np +import random +import torch +from datasets import load_dataset + +# Set seed for reproducibility +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + +# Wrapper for tokenized input IDs +class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + +# Load and process wikitext2 dataset +def get_wikitext2(nsamples, seed, seqlen, tokenizer): + # Load train and test datasets + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + # Encode datasets + trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt', truncation=True, max_length=tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length else 131072) + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt', truncation=True, max_length=tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length else 131072) + + # Generate samples from training set + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +# Load and process c4 dataset +def get_c4(nsamples, seed, seqlen, tokenizer): + # Load train and validation datasets + traindata = load_dataset('allenai/c4', 'en', split='train', streaming=True) + valdata = load_dataset('allenai/c4', 'en', split='validation', 
streaming=True) + + # Generate samples from training set + random.seed(seed) + trainloader = [] + shuffled_traindata = traindata.shuffle(seed=seed, buffer_size=10000) + for sample in shuffled_traindata: + trainenc = tokenizer(sample['text'], return_tensors='pt', truncation=True, max_length=tokenizer.model_max_length if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length else 131072) + if trainenc.input_ids.shape[1] > seqlen: + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + if len(trainloader) >= nsamples: + break + + # Prepare validation dataset + val_samples = list(valdata.take(1100)) + val_text = ' '.join([s['text'] for s in val_samples]) + valenc = tokenizer(val_text, return_tensors='pt', truncation=True, max_length=256 * seqlen) + valenc = valenc.input_ids[:, :(256 * seqlen)] + valenc = TokenizerWrapper(valenc) + return trainloader, valenc + +# Function to select the appropriate loader based on dataset name +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, tokenizer) + if "c4" in name: + return get_c4(nsamples, seed, seqlen, tokenizer) \ No newline at end of file diff --git a/src/alignment/pruning/strategies/external/wanda/layerwrapper.py b/src/alignment/pruning/strategies/external/wanda/layerwrapper.py new file mode 100644 index 00000000..1821e8f9 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/layerwrapper.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +# Define WrappedGPT class +class WrappedGPT: + """ + This class wraps a GPT layer for specific operations. 
+ """ + + def __init__(self, layer, layer_id=0, layer_name="none"): + self.layer = layer + self.dev = self.layer.weight.device + self.rows = layer.weight.data.shape[0] + self.columns = layer.weight.data.shape[1] + + self.scaler_row = torch.zeros((self.columns), device=self.dev) + self.nsamples = 0 + + self.layer_id = layer_id + self.layer_name = layer_name + + def add_batch(self, inp, out): + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.scaler_row *= self.nsamples / (self.nsamples+tmp) + self.nsamples += tmp + + inp = inp.type(torch.float32) + self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples \ No newline at end of file diff --git a/src/alignment/pruning/strategies/external/wanda/prune.py b/src/alignment/pruning/strategies/external/wanda/prune.py new file mode 100644 index 00000000..2f998459 --- /dev/null +++ b/src/alignment/pruning/strategies/external/wanda/prune.py @@ -0,0 +1,172 @@ +import time +import heapq +import torch +import torch.nn as nn +from .layerwrapper import WrappedGPT +from .data import get_loaders + +def find_layers(module, layers=[nn.Linear], name=''): + """ + Recursively find the layers of a certain type in a module. + + Args: + module (nn.Module): PyTorch module. + layers (list): List of layer types to find. + name (str): Name of the module. + + Returns: + dict: Dictionary of layers of the given type(s) within the module. + """ + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 + )) + return res + +def prepare_calibration_input(model, dataloader, device, seqlen): + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.model.layers + + # dev = model.hf_device_map["model.embed_tokens"] + if hasattr(model, 'hf_device_map') and "model.embed_tokens" in model.hf_device_map: + device = model.hf_device_map["model.embed_tokens"] + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros((128, seqlen, model.config.hidden_size), dtype=dtype, device=device) + inps.requires_grad = False + cache = {'i': 0, 'attention_mask': None, "position_ids": None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] + raise ValueError + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(device)) + except ValueError: + pass + layers[0] = layers[0].module + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + position_ids = cache['position_ids'] + model.config.use_cache = use_cache + + return inps, outs, attention_mask, position_ids + +def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before): + thres_cumsum = sum_before * alpha + sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1)) + thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1) + W_mask = (W_metric <= thres) + cur_sparsity = (W_mask==True).sum() / W_mask.numel() + return W_mask, cur_sparsity + +def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0, sparsity_ratio=None): + if sparsity_ratio is None: + sparsity_ratio = args.sparsity_ratio + use_cache = model.config.use_cache + model.config.use_cache = False + + # Get sequence length from tokenizer or use 
default + seqlen = getattr(tokenizer, 'model_max_length', None) + if seqlen is None or seqlen > 10000: # Some tokenizers have very large max_length + seqlen = 2048 # Default sequence length + + print("loading calibdation data") + dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=seqlen,tokenizer=tokenizer) + print("dataset loading complete") + with torch.no_grad(): + inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device, seqlen) + + layers = model.model.model.layers + for i in range(len(layers)): + layer = layers[i] + subset = find_layers(layer) + + if hasattr(model, 'hf_device_map') and f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs; + dev = model.hf_device_map[f"model.layers.{i}"] + inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev) + + wrapped_layers = {} + for name in subset: + wrapped_layers[name] = WrappedGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + wrapped_layers[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + for name in wrapped_layers: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for j in range(args.nsamples): + with torch.no_grad(): + # Generate position_ids if they are None + seq_len = inps[j].shape[0] + pos_ids = torch.arange(seq_len, dtype=torch.long, device=inps[j].device).unsqueeze(0) + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + for h in handles: + h.remove() + + for name in subset: + print(f"pruning layer {i} name {name}") + W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) + + W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False + # if prune_n != 0: + # # structured n:m sparsity + # for ii in range(W_metric.shape[1]): + # if ii % prune_m == 0: + 
# tmp = W_metric[:,ii:(ii+prune_m)].float() + # W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) + # else: + sort_res = torch.sort(W_metric, dim=-1, stable=True) + + # if args.use_variant: + # # wanda variant + # tmp_metric = torch.cumsum(sort_res[0], dim=1) + # sum_before = W_metric.sum(dim=1) + + # alpha = 0.4 + # alpha_hist = [0., 0.8] + # W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before) + # while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001): + # if cur_sparsity > args.sparsity_ratio: + # alpha_new = (alpha + alpha_hist[0]) / 2.0 + # alpha_hist[1] = alpha + # else: + # alpha_new = (alpha + alpha_hist[1]) / 2.0 + # alpha_hist[0] = alpha + + # alpha = alpha_new + # W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before) + # print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}") + # else: + # unstructured pruning + indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity_ratio)] + W_mask.scatter_(1, indices, True) + + subset[name].weight.data[W_mask] = 0 ## set weights to zero + + for j in range(args.nsamples): + with torch.no_grad(): + seq_len = inps[j].shape[0] + pos_ids = torch.arange(seq_len, dtype=torch.long, device=inps[j].device).unsqueeze(0) + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] + inps, outs = outs, inps + + model.config.use_cache = use_cache + torch.cuda.empty_cache() \ No newline at end of file From a7683a7a8e2c7282ff64848c0b49a0d9d41f2332 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Mon, 12 Jan 2026 15:38:17 -0500 Subject: [PATCH 05/15] LLM: allow streaming datasets without __len__ --- src/alignment/experiments/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/alignment/experiments/base.py b/src/alignment/experiments/base.py index b7402aee..959829d7 100644 --- a/src/alignment/experiments/base.py +++ 
b/src/alignment/experiments/base.py @@ -469,7 +469,12 @@ def _initialize_dataset(self): ) logger.info(f"Initialized dataset: {self.config.dataset_name}") - logger.info(f"Dataset size: {len(self.dataset)}") + # Some datasets (e.g., streaming/IterableDataset) do not implement __len__. + # Avoid crashing LLM runs that use streaming C4. + try: + logger.info(f"Dataset size: {len(self.dataset)}") + except (TypeError, NotImplementedError): + logger.info("Dataset size: unknown (no __len__)") def _initialize_metrics(self): """Initialize metrics.""" From 22d49bccbf3e4840a27409e48bd70646cfb99db7 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Tue, 13 Jan 2026 08:41:42 -0500 Subject: [PATCH 06/15] LLM: compute SCAR-Conn scores even when plots disabled --- src/alignment/experiments/llm_experiments.py | 51 ++++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 2a1981e1..9ff4ab32 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -75,9 +75,17 @@ def setup(self): logger.info(f"Tracked layers expanded to {len(expanded)} layers") - # Ensure we have a text dataset for importance computation in LLM experiments. - # BaseExperiment may skip dataset initialization for LLM experiment types. - if getattr(self, "dataset", None) is None: + # Ensure we have a *text* dataset for importance computation in LLM experiments. + # BaseExperiment may skip dataset initialization for LLM experiment types, but even when it + # does initialize a dataset (e.g., `c4` via registry), it may be a streaming dataset without + # a materialized `.texts` list. For SCAR calibration we require raw strings, so we rebuild + # the dataset when `.texts` is missing or None. 
+ needs_text_dataset = ( + getattr(self, "dataset", None) is None + or not hasattr(self.dataset, "texts") + or getattr(self.dataset, "texts", None) is None + ) + if needs_text_dataset: try: from alignment.dataops.datasets.text_datasets import load_text_dataset except ImportError as e: @@ -1739,7 +1747,7 @@ def compute_importance_scores(self, num_samples: int = 1, dim="input") -> Dict[s # otherwise, fall back to iterating the dataset or raise if no dataset is available. calibration_texts: List[str] = [] if getattr(self, "dataset", None) is not None: - if hasattr(self.dataset, "texts"): + if hasattr(self.dataset, "texts") and getattr(self.dataset, "texts", None) is not None: calibration_texts = list(self.dataset.texts) else: logger.warning( @@ -7175,6 +7183,41 @@ def run(self) -> Dict[str, Any]: except Exception as e: logger.warning(f"Failed baseline full-metric evaluation: {e}") + # Some SCAR pruning scores (e.g., `supernode_connectivity_score`) were historically computed + # inside the `generate_plots` block. For fast paper sweeps we often run with + # `generate_plots=false`, but we still need these scores for pruning to run. 
+ if scar_scores and not getattr(self.config, "generate_plots", True): + supernode_config = getattr(self.config, "supernode", {}) or getattr(self.config, "supernode_config", {}) or {} + + if getattr(self.config, "do_directed_redundancy", True): + try: + directed_redundancy_results = self.compute_directed_redundancy( + scar_scores=scar_scores, + supernode_fraction=supernode_config.get("core_fraction", 0.01), + ) + results["directed_redundancy"] = directed_redundancy_results + logger.info("Directed redundancy computation complete") + except Exception as dr_err: + logger.error(f"Failed directed redundancy computation: {dr_err}") + import traceback + logger.error(traceback.format_exc()) + + if getattr(self.config, "do_connectivity_pruning", True): + try: + connectivity_results = self.compute_supernode_connectivity_pruning_score( + scar_scores=scar_scores, + supernode_fraction=supernode_config.get("core_fraction", 0.01), + high_connectivity_fraction=supernode_config.get("follower_fraction", 0.10), + redundancy_weight=supernode_config.get("redundancy_weight", 0.5), + plots_dir=None, + ) + results["supernode_connectivity"] = connectivity_results + logger.info("Supernode-connectivity pruning score computation complete") + except Exception as conn_err: + logger.error(f"Failed supernode-connectivity computation: {conn_err}") + import traceback + logger.error(traceback.format_exc()) + if self.config.do_pruning_experiments: sparsity_levels = self.config.pruning_amounts From 7acd544a5d8ca5e22199c82340c30722c7fa2c39 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Tue, 13 Jan 2026 10:26:40 -0500 Subject: [PATCH 07/15] LLM: apply protect_core using down_proj supernode masks --- src/alignment/experiments/llm_experiments.py | 22 +++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 9ff4ab32..c40af708 100644 --- 
a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -6368,7 +6368,27 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm # Get importance scores scores = self.importance_scores[layer_name][metric].clone() - core_mask = self.importance_scores[layer_name].get("supernode_mask") + # Supernode masks may be stored under different module-name prefixes + # (e.g., `model.layers.*` vs `model.model.layers.*`) depending on whether they were + # produced via SCAR hooks (HF model) or via tracked-layer activation capture (wrapper). + # + # For protection to be applied consistently, prefer the *down_proj* key for this layer + # (the canonical FFN-channel space), falling back to the current layer_name. + core_mask = None + try: + key_candidates = [ + f"model.layers.{layer_idx}.mlp.down_proj", + f"model.model.layers.{layer_idx}.mlp.down_proj", + layer_name, + layer_name.replace("model.model.", "model."), + layer_name.replace("model.", "model.model.", 1), + ] + for kcand in key_candidates: + core_mask = (self.importance_scores.get(kcand) or {}).get("supernode_mask") + if core_mask is not None: + break + except Exception: + core_mask = self.importance_scores[layer_name].get("supernode_mask") if core_mask is not None and self._should_protect_supernodes_for_metric(metric): margin = torch.abs(scores).max().detach().item() + 1.0 if mode == "low": From 75562ec4ba8cff941f22a838ab489af58ab8ae83 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Tue, 13 Jan 2026 11:10:12 -0500 Subject: [PATCH 08/15] SCAR: redefine Conn using top-k core write support --- src/alignment/experiments/llm_experiments.py | 60 +++++++++++++++++--- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index c40af708..16af5d90 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ 
-4748,14 +4748,40 @@ def compute_supernode_connectivity_pruning_score( super_mask = torch.zeros(m, dtype=torch.bool) super_mask[super_idx] = True - # Compute Conn_i from down_proj weights (write-pattern overlap) + # Compute Conn_i from down_proj weights. + # + # IMPORTANT: the classic "probability overlap" Conn + # <|v_i|, a> / (||v_i||_1 ||a||_1) + # tends to collapse to ~1/hidden_dim for dense matrices (≈ 2.4e-4 for d=4096), + # which makes SCAR-Conn numerically ineffective. Instead, we measure the fraction + # of each channel's write mass that falls on the *core write support*: + # the top-K hidden dimensions by aggregated supernode write mass a. + # + # Conn_i := sum_{h in TopK(a)} |v_i[h]| / ||v_i||_1 in [0, 1] W = module.weight.detach().float().cpu() # [hidden_dim, m] abs_W = W.abs() a = abs_W[:, super_idx].sum(dim=1) # [hidden_dim] - a_norm = a.sum() + eps v_norm = abs_W.sum(dim=0) + eps # [m] - conn_num = (abs_W * a.unsqueeze(1)).sum(dim=0) # [m] - conn = (conn_num / (v_norm * a_norm + eps)).clamp(0.0, 1.0) + + hidden_dim = int(abs_W.shape[0]) + k = int(supernode_cfg.get("connectivity_topk", 256)) + mass_frac = supernode_cfg.get("connectivity_mass_fraction", None) + a_sorted, a_order = torch.sort(a, descending=True) + if mass_frac is not None: + try: + mf = float(mass_frac) + except Exception: + mf = None + if mf is not None and 0.0 < mf < 1.0 and a_sorted.numel() > 0: + cdf = torch.cumsum(a_sorted, dim=0) + total = float(cdf[-1].item()) + if total > 0: + target = mf * total + k = int(torch.searchsorted(cdf, torch.tensor(target)).item()) + 1 + k = max(1, min(int(k), hidden_dim)) + core_idx = a_order[:k] + conn = abs_W.index_select(0, core_idx).sum(dim=0) / v_norm + conn = conn.clamp(0.0, 1.0) # Halo: top eta among non-supernodes by Conn non_super_idx = (~super_mask).nonzero(as_tuple=True)[0] @@ -5150,14 +5176,32 @@ def sample_pairs_pos(n: int, p: int) -> Tuple[torch.Tensor, torch.Tensor]: super_mask = torch.zeros(m, dtype=torch.bool) super_mask[super_idx] = 
True - # Compute Conn_i from down_proj weights (write-pattern overlap) + # Compute Conn_i from down_proj weights (same definition as SCAR-Conn): + # Conn_i := sum_{h in TopK(a)} |v_i[h]| / ||v_i||_1 (fraction of write mass on core support) W = module.weight.detach().float().cpu() # [hidden_dim, m] abs_W = W.abs() a = abs_W[:, super_idx].sum(dim=1) # [hidden_dim] - a_norm = a.sum() + eps v_norm = abs_W.sum(dim=0) + eps # [m] - conn_num = (abs_W * a.unsqueeze(1)).sum(dim=0) # [m] - conn = (conn_num / (v_norm * a_norm + eps)).clamp(0.0, 1.0) + + hidden_dim = int(abs_W.shape[0]) + k = int(supernode_cfg.get("connectivity_topk", 256)) + mass_frac = supernode_cfg.get("connectivity_mass_fraction", None) + a_sorted, a_order = torch.sort(a, descending=True) + if mass_frac is not None: + try: + mf = float(mass_frac) + except Exception: + mf = None + if mf is not None and 0.0 < mf < 1.0 and a_sorted.numel() > 0: + cdf = torch.cumsum(a_sorted, dim=0) + total = float(cdf[-1].item()) + if total > 0: + target = mf * total + k = int(torch.searchsorted(cdf, torch.tensor(target)).item()) + 1 + k = max(1, min(int(k), hidden_dim)) + core_idx = a_order[:k] + conn = abs_W.index_select(0, core_idx).sum(dim=0) / v_norm + conn = conn.clamp(0.0, 1.0) non_super_idx = (~super_mask).nonzero(as_tuple=True)[0] if non_super_idx.numel() < 2: From 5a12cb397d4c74f016a04d3a30bc100f11e1d93d Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Tue, 13 Jan 2026 11:38:02 -0500 Subject: [PATCH 09/15] SCAR: soften halo protection via rank-power mapping --- src/alignment/experiments/llm_experiments.py | 59 ++++++++++++++++++-- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 16af5d90..1b3df516 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -4972,13 +4972,60 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], 
grad_output: mi = -0.5 * torch.log(1 - rho_sq) redundancy_to_core = mi.max(dim=1).values # [|H|] - red_min = redundancy_to_core.min() - red_max = redundancy_to_core.max() - if red_max > red_min: - red_norm = (redundancy_to_core - red_min) / (red_max - red_min + eps) + + # Convert redundancy-to-core into a [0, 1] protection score. + # + # Empirically, redundancy magnitudes can be extremely small; min-max normalization + # then collapses most halo channels near Protect≈1. But a fully linear rank/CDF + # can be too aggressive when redundancy estimates are noisy. We therefore default + # to a *soft* rank-power mapping that mainly penalizes only the most redundant tail. + norm_mode = str(supernode_cfg.get("protection_normalization", "rank_power")).lower() + if norm_mode == "minmax": + red_min = redundancy_to_core.min() + red_max = redundancy_to_core.max() + if red_max > red_min: + red_norm = (redundancy_to_core - red_min) / (red_max - red_min + eps) + else: + red_norm = torch.zeros_like(redundancy_to_core) + protect_halo = (1.0 - red_norm).clamp(0.0, 1.0) + elif norm_mode in {"rank", "cdf"}: + if redundancy_to_core.numel() <= 1: + protect_halo = torch.ones_like(redundancy_to_core) + else: + # Ascending ranks: lowest redundancy -> highest protection. 
+ _, order = torch.sort(redundancy_to_core, stable=True) + ranks = torch.empty_like(order, dtype=torch.float32) + ranks[order] = torch.arange(order.numel(), dtype=torch.float32) + red_rank = ranks / float(max(1, order.numel() - 1)) + protect_halo = (1.0 - red_rank).clamp(0.0, 1.0) else: - red_norm = torch.zeros_like(redundancy_to_core) - protect_halo = (1.0 - red_norm).clamp(0.0, 1.0) + # rank_power (default): Protect = floor + (1-floor)*(1 - rank^gamma) + if redundancy_to_core.numel() <= 1: + protect_halo = torch.ones_like(redundancy_to_core) + else: + _, order = torch.sort(redundancy_to_core, stable=True) + ranks = torch.empty_like(order, dtype=torch.float32) + ranks[order] = torch.arange(order.numel(), dtype=torch.float32) + red_rank = ranks / float(max(1, order.numel() - 1)) + red_rank = red_rank.clamp(0.0, 1.0) + + gamma = supernode_cfg.get("protection_rank_power", 8.0) + try: + gamma_f = float(gamma) + except Exception: + gamma_f = 8.0 + if not (gamma_f > 0): + gamma_f = 8.0 + + floor = supernode_cfg.get("protection_floor", 0.2) + try: + floor_f = float(floor) + except Exception: + floor_f = 0.2 + floor_f = float(min(1.0, max(0.0, floor_f))) + + protect_halo = floor_f + (1.0 - floor_f) * (1.0 - red_rank.pow(gamma_f)) + protect_halo = protect_halo.clamp(0.0, 1.0) m = st["m"] lp = st["lp_cpu"].float() From a622647ad1046ffecbe2ae310f877122f91748dc Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Tue, 13 Jan 2026 12:56:04 -0500 Subject: [PATCH 10/15] Paper plots: add accuracy-vs-sparsity curve helper --- .../analysis/visualization/paper_plots.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/alignment/analysis/visualization/paper_plots.py b/src/alignment/analysis/visualization/paper_plots.py index 81471491..f5b03b90 100644 --- a/src/alignment/analysis/visualization/paper_plots.py +++ b/src/alignment/analysis/visualization/paper_plots.py @@ -428,6 +428,54 @@ def plot_sparsity_perplexity_curves( return fig +def 
plot_sparsity_accuracy_curves( + sparsities: Sequence[float], + acc_by_method: Dict[str, Sequence[Optional[float]]], + baseline_acc: Optional[float] = None, + *, + ylabel: str = "Accuracy (%)", + title: str = "Accuracy vs sparsity (low-mode)", + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, +) -> plt.Figure: + """ + Paper-facing plot: downstream accuracy vs structured sparsity for multiple methods. + + Notes: + - Accuracies are expected to already be in percent units (e.g., 58.0 for 58%). + - Inputs should be filtered to the intended pruning direction (typically low-mode). + """ + xs = np.asarray(list(sparsities), dtype=np.float64) + fig, ax = plt.subplots(figsize=(7.0, 4.2)) + + for label in sorted(acc_by_method.keys()): + ys_raw = acc_by_method[label] + ys = np.asarray([np.nan if v is None else float(v) for v in ys_raw], dtype=np.float64) + finite = np.isfinite(ys) + if not np.any(finite): + continue + ax.plot(xs[finite], ys[finite], "o-", linewidth=2.0, markersize=5, label=label, alpha=0.9) + + if baseline_acc is not None: + try: + b = float(baseline_acc) + if np.isfinite(b): + ax.axhline(b, color="#2c3e50", linestyle=":", linewidth=2.0, label=f"Unpruned ({b:.1f}%)") + except Exception: + pass + + ax.set_xlabel("Structured FFN channel sparsity", fontsize=11) + ax.set_ylabel(ylabel, fontsize=11) + ax.set_title(title, fontsize=12, fontweight="bold") + ax.grid(True, alpha=0.25) + ax.legend(loc="lower left", fontsize=9, frameon=True) + + plt.tight_layout() + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + + def plot_scar_schematic( save_path: Optional[Union[str, Path]] = None, dpi: int = 300, From 08270300aa3a21bee2c6d4bc48fa1448dd104b00 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Tue, 13 Jan 2026 15:29:00 -0500 Subject: [PATCH 11/15] LLM eval: use conditional logprob for MCQ continuations --- src/alignment/experiments/llm_experiments.py | 238 +++++++++---------- 1 file changed, 109 insertions(+), 129 deletions(-) 
diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 1b3df516..dc426343 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -430,6 +430,67 @@ def evaluate_multiple_metrics( return results + def _score_continuations_conditional_logprob( + self, + prompt: str, + continuations: List[str], + *, + max_length: int = 2048, + ) -> List[float]: + """ + Score each continuation by its conditional log-probability given the prompt. + + Implementation detail: + We compute the model loss on *only the continuation tokens* (masking prompt tokens + with -100) and return the mean log-prob per continuation token (higher is better). + """ + # Defensive: handle empty candidate lists + if not continuations: + return [] + + device = torch.device(self.config.device) + model = self.model + tok = self.tokenizer + + # Encode prompt once (no special tokens), then add BOS if the tokenizer defines one. + prompt_ids = tok(prompt, add_special_tokens=False).input_ids + bos_id = getattr(tok, "bos_token_id", None) + prefix_ids = ([bos_id] if bos_id is not None else []) + prompt_ids + prefix_len_full = len(prefix_ids) + + scores: List[float] = [] + model.eval() + with torch.no_grad(): + for cont in continuations: + cont_ids = tok(cont, add_special_tokens=False).input_ids + input_ids = prefix_ids + cont_ids + + # Truncate from the left if needed (keep most recent context). + prefix_len = prefix_len_full + if len(input_ids) > max_length: + drop = len(input_ids) - max_length + input_ids = input_ids[drop:] + prefix_len = max(0, prefix_len_full - drop) + # If we truncated away the entire prompt context, the score becomes meaningless. 
+ if prefix_len <= 0: + scores.append(float("-inf")) + continue + + input_ids_t = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0) + attn = torch.ones_like(input_ids_t, dtype=torch.long, device=device) + + labels = input_ids_t.clone() + labels[:, :prefix_len] = -100 # only score continuation tokens + + out = model(input_ids=input_ids_t, attention_mask=attn, labels=labels) + loss = getattr(out, "loss", None) + if loss is None: + scores.append(float("-inf")) + else: + scores.append(float(-loss.item())) + + return scores + def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num_fewshot: int = 0) -> float: """ Few-shot evaluation on MMLU (Massive Multitask Language Understanding). @@ -559,26 +620,14 @@ def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num choices = example["choices"] answer_idx = example["answer"] # 0-indexed - # Score each choice - scores = [] - for j, choice in enumerate(choices): - # Format: Question: ... Answer: A) choice - if num_fewshot > 0: - choices_str = "\n".join([f"{choice_labels[k]}) {c}" for k, c in enumerate(choices)]) - text = f"{fewshot_prompt}Question: {question}\n{choices_str}\nAnswer: {choice_labels[j]}" - else: - text = f"Question: {question}\nAnswer: {choice_labels[j]}) {choice}" - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) + # Score each answer label by conditional log-probability (standard MCQ protocol): + # prompt includes all choices; continuation is just the option label. 
+ choices_str = "\n".join([f"{choice_labels[k]}) {c}" for k, c in enumerate(choices)]) + prompt = f"{fewshot_prompt}Question: {question}\n{choices_str}\nAnswer:" if num_fewshot > 0 else ( + f"Question: {question}\n{choices_str}\nAnswer:" + ) + continuations = [f" {choice_labels[j]}" for j in range(len(choices))] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) predicted = np.argmax(scores) if predicted == answer_idx: @@ -658,24 +707,12 @@ def _evaluate_hellaswag(self, num_samples: int = 100, num_fewshot: int = 0) -> f endings = example["endings"] label = int(example["label"]) - # Score each ending - scores = [] - for ending in endings: - if num_fewshot > 0: - text = f"{fewshot_prompt}Context: {ctx}\nEnding: {ending}" - else: - text = f"{ctx} {ending}" - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) + # Score endings by conditional log-probability (context is prompt, ending is continuation). 
+ prompt = ( + f"{fewshot_prompt}Context: {ctx}\nEnding:" if num_fewshot > 0 else f"Context: {ctx}\nEnding:" + ) + continuations = [f" {ending}" for ending in endings] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) predicted = np.argmax(scores) if predicted == label: @@ -751,24 +788,12 @@ def _evaluate_arc_easy(self, num_samples: int = 100, num_fewshot: int = 0) -> fl choice_labels = choices["label"] answer_idx = choice_labels.index(answer_key) - # Score each choice - scores = [] - for choice_text in choice_texts: - if num_fewshot > 0: - text = f"{fewshot_prompt}Question: {question}\nAnswer: {choice_text}" - else: - text = f"Question: {question}\nAnswer: {choice_text}" - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) + # Score candidate answers by conditional log-probability (prompt excludes answer tokens). 
+ prompt = ( + f"{fewshot_prompt}Question: {question}\nAnswer:" if num_fewshot > 0 else f"Question: {question}\nAnswer:" + ) + continuations = [f" {ct}" for ct in choice_texts] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) predicted = np.argmax(scores) if predicted == answer_idx: @@ -838,24 +863,10 @@ def _evaluate_piqa(self, num_samples: int = 100, num_fewshot: int = 0) -> float: sol2 = example["sol2"] label = example["label"] # 0 or 1 - # Score each solution - scores = [] - for sol in [sol1, sol2]: - if num_fewshot > 0: - text = f"{fewshot_prompt}Goal: {goal}\nSolution: {sol}" - else: - text = f"Goal: {goal}\nSolution: {sol}" - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) + # Score solutions by conditional log-probability (goal is prompt, solution is continuation). 
+ prompt = f"{fewshot_prompt}Goal: {goal}\nSolution:" if num_fewshot > 0 else f"Goal: {goal}\nSolution:" + continuations = [f" {sol1}", f" {sol2}"] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) predicted = np.argmax(scores) if predicted == label: @@ -922,24 +933,13 @@ def _evaluate_boolq(self, num_samples: int = 100, num_fewshot: int = 0) -> float passage = example["passage"] answer = example["answer"] # True or False - # Score "Yes" vs "No" completions - scores = [] - for response in ["Yes", "No"]: - if num_fewshot > 0: - text = f"{fewshot_prompt}Passage: {passage}\nQuestion: {question}\nAnswer: {response}" - else: - text = f"Passage: {passage}\nQuestion: {question}\nAnswer: {response}" - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) + # Score "Yes" vs "No" by conditional log-probability of the answer token(s). 
+ prompt = ( + f"{fewshot_prompt}Passage: {passage}\nQuestion: {question}\nAnswer:" + if num_fewshot > 0 + else f"Passage: {passage}\nQuestion: {question}\nAnswer:" + ) + scores = self._score_continuations_conditional_logprob(prompt, [" Yes", " No"], max_length=2048) # 0 = Yes (True), 1 = No (False) predicted = np.argmax(scores) == 0 # True if "Yes" has higher score @@ -1016,25 +1016,17 @@ def _evaluate_winogrande(self, num_samples: int = 100, num_fewshot: int = 0) -> option2 = example["option2"] answer = int(example["answer"]) - 1 # Convert 1/2 to 0/1 - # Replace _ with each option and score - scores = [] - for option in [option1, option2]: - completed = sentence.replace("_", option) - if num_fewshot > 0: - text = f"{fewshot_prompt}Sentence: {completed}" - else: - text = completed - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) # Higher = better + # Score each option by conditional log-probability of the completion. + # We split at the blank so only the option + suffix is scored (prefix is prompt). 
+ if "_" in sentence: + prefix, suffix = sentence.split("_", 1) + else: + prefix, suffix = sentence, "" + prompt = ( + f"{fewshot_prompt}Sentence: {prefix}" if num_fewshot > 0 else f"Sentence: {prefix}" + ) + continuations = [f"{option1}{suffix}", f"{option2}{suffix}"] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) predicted = np.argmax(scores) if predicted == answer: @@ -1116,24 +1108,12 @@ def _evaluate_arc_challenge(self, num_samples: int = 100, num_fewshot: int = 0) choice_labels = choices["label"] answer_idx = choice_labels.index(answer_key) - # Score each choice - scores = [] - for choice_text in choice_texts: - if num_fewshot > 0: - text = f"{fewshot_prompt}Question: {question}\nAnswer: {choice_text}" - else: - text = f"Question: {question}\nAnswer: {choice_text}" - inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=2048) - inputs = {k: v.to(device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits = outputs.logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = inputs["input_ids"][..., 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - scores.append(-loss.item()) + # Score candidate answers by conditional log-probability (prompt excludes answer tokens). 
+ prompt = ( + f"{fewshot_prompt}Question: {question}\nAnswer:" if num_fewshot > 0 else f"Question: {question}\nAnswer:" + ) + continuations = [f" {ct}" for ct in choice_texts] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) predicted = np.argmax(scores) if predicted == answer_idx: From ea2b7d26e7b773a2e570787733daaa5f924c795e Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Wed, 14 Jan 2026 10:01:48 -0500 Subject: [PATCH 12/15] update config/slurm --- configs/prune_llm/README.md | 2 +- configs/prune_llm/llama2_7b_full.yaml | 23 ++- configs/prune_llm/llama2_7b_unified.yaml | 2 +- configs/prune_llm/llama3_8b_full.yaml | 27 ++- configs/prune_llm/llama3_8b_unified.yaml | 4 +- configs/prune_llm/mistral_7b_full.yaml | 23 ++- configs/prune_llm/mistral_7b_unified.yaml | 2 +- configs/prune_llm/qwen2_7b_full.yaml | 23 ++- configs/prune_llm/qwen2_7b_unified.yaml | 2 +- .../mobilenetv2_cifar10_unified.yaml | 10 +- .../resnet18_cifar10_unified.yaml | 12 +- .../resnet50_imagenet100_unified.yaml | 12 +- .../vision_prune/vgg16_cifar10_unified.yaml | 12 +- scripts/run_experiment.py | 36 ++++ slurm_jobs/prune_llm/README.md | 74 +++++++ slurm_jobs/prune_llm/run_all_paper.sh | 110 ++++++++++ slurm_jobs/prune_llm/run_llama2_7b.sh | 103 ++++++++++ slurm_jobs/prune_llm/run_llama3_8b.sh | 110 ++++++++++ .../run_llama3_8b_calibration_array.sh | 121 +++++++++++ .../prune_llm/run_llama3_8b_noprotect.sh | 99 +++++++++ ...run_llama3_8b_positive_redundancy_array.sh | 108 ++++++++++ .../run_llama3_8b_protect_baselines.sh | 100 +++++++++ .../run_llama3_8b_sparsegpt_unstructured.sh | 100 +++++++++ .../run_llama3_8b_wanda_unstructured.sh | 100 +++++++++ slurm_jobs/prune_llm/run_mistral_7b.sh | 103 ++++++++++ slurm_jobs/prune_llm/run_qwen2_7b.sh | 104 ++++++++++ slurm_jobs/prune_llm/submit_suite.sh | 63 ++++++ slurm_jobs/run_baseline_test.sh | 2 + slurm_jobs/run_fast_pruning.sh | 2 + slurm_jobs/run_mnist_basic.sh | 2 + 
slurm_jobs/run_single_model.sh | 2 + slurm_jobs/run_test_all_layers.sh | 2 + slurm_jobs/run_vision_pruning_test.sh | 2 + slurm_jobs/vision_prune/build_artifacts.sh | 41 ++++ slurm_jobs/vision_prune/run_all_array.sh | 186 +++++++++++++++++ .../run_damage_prediction_resnet18.sh | 47 +++++ .../vision_prune/run_mobilenetv2_cifar10.sh | 43 ++++ .../vision_prune/run_resnet18_cifar10.sh | 44 ++++ .../run_resnet18_cifar10_ablation.sh | 54 +++++ .../vision_prune/run_resnet18_cifar10_gap.sh | 45 +++++ .../vision_prune/run_resnet50_imagenet100.sh | 103 ++++++++++ slurm_jobs/vision_prune/run_vgg16_cifar10.sh | 43 ++++ .../run_weightsweep_resnet18_array.sh | 68 +++++++ slurm_jobs/vision_prune/submit_all.sh | 83 ++++++++ slurm_jobs/vision_prune/submit_all_array.sh | 54 +++++ slurm_jobs/vision_prune/submit_appendix.sh | 51 +++++ slurm_jobs/vision_prune/submit_suite.sh | 51 +++++ src/alignment/analysis/cascade_analysis.py | 21 +- .../analysis/visualization/paper_plots.py | 49 ++++- src/alignment/experiments/llm_experiments.py | 189 +++++++++++++++++- src/alignment/pruning/dependency_aware.py | 43 +++- src/alignment/services/mask_ops.py | 18 ++ 52 files changed, 2684 insertions(+), 46 deletions(-) create mode 100644 slurm_jobs/prune_llm/README.md create mode 100755 slurm_jobs/prune_llm/run_all_paper.sh create mode 100755 slurm_jobs/prune_llm/run_llama2_7b.sh create mode 100755 slurm_jobs/prune_llm/run_llama3_8b.sh create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_calibration_array.sh create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_noprotect.sh create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_positive_redundancy_array.sh create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_protect_baselines.sh create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_sparsegpt_unstructured.sh create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_wanda_unstructured.sh create mode 100755 slurm_jobs/prune_llm/run_mistral_7b.sh create mode 100755 slurm_jobs/prune_llm/run_qwen2_7b.sh create 
mode 100644 slurm_jobs/prune_llm/submit_suite.sh create mode 100644 slurm_jobs/vision_prune/build_artifacts.sh create mode 100644 slurm_jobs/vision_prune/run_all_array.sh create mode 100644 slurm_jobs/vision_prune/run_damage_prediction_resnet18.sh create mode 100644 slurm_jobs/vision_prune/run_mobilenetv2_cifar10.sh create mode 100644 slurm_jobs/vision_prune/run_resnet18_cifar10.sh create mode 100644 slurm_jobs/vision_prune/run_resnet18_cifar10_ablation.sh create mode 100644 slurm_jobs/vision_prune/run_resnet18_cifar10_gap.sh create mode 100644 slurm_jobs/vision_prune/run_resnet50_imagenet100.sh create mode 100644 slurm_jobs/vision_prune/run_vgg16_cifar10.sh create mode 100644 slurm_jobs/vision_prune/run_weightsweep_resnet18_array.sh create mode 100644 slurm_jobs/vision_prune/submit_all.sh create mode 100644 slurm_jobs/vision_prune/submit_all_array.sh create mode 100644 slurm_jobs/vision_prune/submit_appendix.sh create mode 100644 slurm_jobs/vision_prune/submit_suite.sh diff --git a/configs/prune_llm/README.md b/configs/prune_llm/README.md index 849efcd5..b8f85fc7 100644 --- a/configs/prune_llm/README.md +++ b/configs/prune_llm/README.md @@ -15,7 +15,7 @@ Configurations for generating results in the SCAR LLM pruning paper. Run all experiments: ```bash -bash drafts/LLM_prune/paper/slurm/run_all_paper.sh +bash slurm_jobs/prune_llm/run_all_paper.sh ``` Run single model: diff --git a/configs/prune_llm/llama2_7b_full.yaml b/configs/prune_llm/llama2_7b_full.yaml index 2978ddf3..02ecf979 100644 --- a/configs/prune_llm/llama2_7b_full.yaml +++ b/configs/prune_llm/llama2_7b_full.yaml @@ -77,6 +77,13 @@ llm: evaluate_perplexity: true evaluation_num_samples: 100 + # Use NVIDIA Minitron official few-shot settings for downstream tasks. + use_nvidia_fewshot: true + # Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity: + # concatenate full test set and evaluate in contiguous 2048-token blocks (no padding). 
+ perplexity_protocol: "oats" + wikitext_subset: "wikitext-2-raw-v1" + perplexity_seq_len: 2048 evaluation_metrics: - "perplexity" @@ -137,6 +144,20 @@ supernode: core_fraction: 0.01 follower_fraction: 0.10 halo_fraction: 0.10 + # Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass + # that lands on the top-K hidden dimensions most written-to by supernodes. + connectivity_topk: 256 + # Optional post-processing for Conn (defaults keep current behavior) + connectivity_rank_normalize: false + connectivity_power: 1.0 + # Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels + # (used for paper mechanism plots; does NOT affect pruning decisions). + non_halo_sample_size: 256 + non_halo_sample_seed: 0 + # Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma) + protection_normalization: "rank_power" + protection_rank_power: 8.0 + protection_floor: 0.2 protect_core: true protect_core_metrics: - "scar_loss_proxy" # SCAR-LP @@ -232,7 +253,7 @@ pruning: dependency_aware: true sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] algorithms: - "rayleigh_quotient" diff --git a/configs/prune_llm/llama2_7b_unified.yaml b/configs/prune_llm/llama2_7b_unified.yaml index 9fc234f6..07694c72 100644 --- a/configs/prune_llm/llama2_7b_unified.yaml +++ b/configs/prune_llm/llama2_7b_unified.yaml @@ -139,7 +139,7 @@ cascade_analysis: pruning: enabled: true ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] distribution: "uniform" min_per_layer: 0.0 max_per_layer: 0.95 diff --git a/configs/prune_llm/llama3_8b_full.yaml b/configs/prune_llm/llama3_8b_full.yaml index b817d2c6..5f04a420 100644 --- a/configs/prune_llm/llama3_8b_full.yaml +++ b/configs/prune_llm/llama3_8b_full.yaml @@ -87,6 +87,14 @@ llm: evaluate_perplexity: true evaluation_num_samples: 100 + # Use NVIDIA Minitron official 
few-shot settings for downstream tasks + # (MMLU 5-shot, HellaSwag 10-shot, ARC 25-shot, WinoGrande 5-shot, etc.). + use_nvidia_fewshot: true + # Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity: + # concatenate full test set and evaluate in contiguous 2048-token blocks (no padding). + perplexity_protocol: "oats" + wikitext_subset: "wikitext-2-raw-v1" + perplexity_seq_len: 2048 evaluation_metrics: # Language modeling @@ -174,6 +182,21 @@ supernode: core_fraction: 0.01 follower_fraction: 0.10 halo_fraction: 0.10 + # Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass + # that lands on the top-K hidden dimensions most written-to by supernodes. + # (Avoids the ~1/hidden_dim collapse of L1-normalized dot-product overlap for dense matrices.) + connectivity_topk: 256 + # Optional post-processing for Conn (defaults keep current behavior) + connectivity_rank_normalize: false + connectivity_power: 1.0 + # Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels + # (used for paper mechanism plots; does NOT affect pruning decisions). + non_halo_sample_size: 256 + non_halo_sample_seed: 0 + # Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma) + protection_normalization: "rank_power" + protection_rank_power: 8.0 + protection_floor: 0.2 protect_core: true # Apply hard supernode protection only for the listed pruning metrics. # If omitted, legacy behavior is to protect for *all* pruning metrics. @@ -286,7 +309,9 @@ pruning: dependency_aware: true sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + # We only report (and run) the standard pruning direction: prune *low*-scoring channels. + # The "high" mode (prune highest scores) is a pathological control and is excluded from paper runs. 
+ selection_modes: ["low"] # ALL algorithms including SOTA baselines algorithms: diff --git a/configs/prune_llm/llama3_8b_unified.yaml b/configs/prune_llm/llama3_8b_unified.yaml index acacb074..651a8495 100644 --- a/configs/prune_llm/llama3_8b_unified.yaml +++ b/configs/prune_llm/llama3_8b_unified.yaml @@ -10,7 +10,7 @@ # - All experiment-specific settings in `extra:` section # - Same pruning/evaluation/visualization structure # -# Usage: python scripts/run_experiment.py --config configs/unified/llama3_8b_unified.yaml +# Usage: python scripts/run_experiment.py --config configs/prune_llm/llama3_8b_unified.yaml # Estimated runtime: ~6-8 hours on 1x A100 # ============================================================================= @@ -156,7 +156,7 @@ cascade_analysis: pruning: enabled: true ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] distribution: "uniform" min_per_layer: 0.0 max_per_layer: 0.95 diff --git a/configs/prune_llm/mistral_7b_full.yaml b/configs/prune_llm/mistral_7b_full.yaml index 06c621a6..3368d5b6 100644 --- a/configs/prune_llm/mistral_7b_full.yaml +++ b/configs/prune_llm/mistral_7b_full.yaml @@ -76,6 +76,13 @@ llm: evaluate_perplexity: true evaluation_num_samples: 100 + # Use NVIDIA Minitron official few-shot settings for downstream tasks. + use_nvidia_fewshot: true + # Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity: + # concatenate full test set and evaluate in contiguous 2048-token blocks (no padding). + perplexity_protocol: "oats" + wikitext_subset: "wikitext-2-raw-v1" + perplexity_seq_len: 2048 evaluation_metrics: - "perplexity" @@ -136,6 +143,20 @@ supernode: core_fraction: 0.01 follower_fraction: 0.10 halo_fraction: 0.10 + # Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass + # that lands on the top-K hidden dimensions most written-to by supernodes. 
+ connectivity_topk: 256 + # Optional post-processing for Conn (defaults keep current behavior) + connectivity_rank_normalize: false + connectivity_power: 1.0 + # Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels + # (used for paper mechanism plots; does NOT affect pruning decisions). + non_halo_sample_size: 256 + non_halo_sample_seed: 0 + # Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma) + protection_normalization: "rank_power" + protection_rank_power: 8.0 + protection_floor: 0.2 protect_core: true protect_core_metrics: - "scar_loss_proxy" # SCAR-LP @@ -231,7 +252,7 @@ pruning: dependency_aware: true sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] algorithms: - "rayleigh_quotient" diff --git a/configs/prune_llm/mistral_7b_unified.yaml b/configs/prune_llm/mistral_7b_unified.yaml index d6c4ac13..3bd11b78 100644 --- a/configs/prune_llm/mistral_7b_unified.yaml +++ b/configs/prune_llm/mistral_7b_unified.yaml @@ -138,7 +138,7 @@ cascade_analysis: pruning: enabled: true ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] distribution: "uniform" min_per_layer: 0.0 max_per_layer: 0.95 diff --git a/configs/prune_llm/qwen2_7b_full.yaml b/configs/prune_llm/qwen2_7b_full.yaml index 646b29ac..7049b561 100644 --- a/configs/prune_llm/qwen2_7b_full.yaml +++ b/configs/prune_llm/qwen2_7b_full.yaml @@ -77,6 +77,13 @@ llm: evaluate_perplexity: true evaluation_num_samples: 100 + # Use NVIDIA Minitron official few-shot settings for downstream tasks. + use_nvidia_fewshot: true + # Match OATS Table 19 / common pruning-paper protocol for WikiText-2 perplexity: + # concatenate full test set and evaluate in contiguous 2048-token blocks (no padding). 
+ perplexity_protocol: "oats" + wikitext_subset: "wikitext-2-raw-v1" + perplexity_seq_len: 2048 evaluation_metrics: - "perplexity" @@ -137,6 +144,20 @@ supernode: core_fraction: 0.01 follower_fraction: 0.10 halo_fraction: 0.10 + # Connectivity definition (SCAR-Conn): fraction of a channel's down_proj write-mass + # that lands on the top-K hidden dimensions most written-to by supernodes. + connectivity_topk: 256 + # Optional post-processing for Conn (defaults keep current behavior) + connectivity_rank_normalize: false + connectivity_power: 1.0 + # Analysis-only: also estimate redundancy-to-core for a small random sample of non-halo channels + # (used for paper mechanism plots; does NOT affect pruning decisions). + non_halo_sample_size: 256 + non_halo_sample_seed: 0 + # Protection mapping (rank-power): Protect = alpha + (1-alpha)*(1 - rank^gamma) + protection_normalization: "rank_power" + protection_rank_power: 8.0 + protection_floor: 0.2 protect_core: true protect_core_metrics: - "scar_loss_proxy" # SCAR-LP @@ -232,7 +253,7 @@ pruning: dependency_aware: true sparsity_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] algorithms: - "rayleigh_quotient" diff --git a/configs/prune_llm/qwen2_7b_unified.yaml b/configs/prune_llm/qwen2_7b_unified.yaml index 430f62e5..1b102a07 100644 --- a/configs/prune_llm/qwen2_7b_unified.yaml +++ b/configs/prune_llm/qwen2_7b_unified.yaml @@ -139,7 +139,7 @@ cascade_analysis: pruning: enabled: true ratios: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - selection_modes: ["low", "high"] + selection_modes: ["low"] distribution: "uniform" min_per_layer: 0.0 max_per_layer: 0.95 diff --git a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml index e9b04085..b66a86aa 100644 --- a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml +++ b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml @@ -74,7 +74,7 @@ metrics: rayleigh_quotient: 
enabled: true - relative: true + relative: false # Standard Rayleigh quotient (no trace-normalization) shrinkage: true redundancy: @@ -143,7 +143,7 @@ cascade_analysis: # More sensitive to pruning - interesting to see which metrics matter pruning: enabled: true - distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity + distribution: "global_threshold" # uniform, global_threshold, adaptive_sensitivity dependency_aware: true # MobileNet has inverted residuals min_per_layer: 0.0 max_per_layer: 0.95 @@ -156,6 +156,7 @@ pruning: # ========================================================================= - "random" # Random baseline - "magnitude" # Standard magnitude pruning (prune low) + - "activation_mean" # Mean |activation| baseline - "taylor" # Gradient-based importance - "network_slimming" # Network Slimming (BN gamma) baseline - "geometric_median" # FPGM-style geometric median baseline @@ -167,6 +168,8 @@ pruning: - "rq_low" # Prune low Rayleigh Quotient - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy + - "redundancy_high" # Control: prune high redundancy + - "synergy_high" # Control: prune high synergy # ========================================================================= # COMPOSITE COMBINATIONS @@ -183,6 +186,7 @@ pruning: # CLUSTER-AWARE # ========================================================================= - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_annealed" # Annealed mixing / constraints schedule - "cluster_aware_protect_redundant" # Inverted: protect redundant scoring_methods: @@ -203,7 +207,7 @@ pruning: epochs: 5 learning_rate: 0.0001 weight_decay: 0.00001 - max_batches: 100 + max_batches: 200 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/configs/vision_prune/resnet18_cifar10_unified.yaml b/configs/vision_prune/resnet18_cifar10_unified.yaml index 38723a36..dbc390f4 
100644 --- a/configs/vision_prune/resnet18_cifar10_unified.yaml +++ b/configs/vision_prune/resnet18_cifar10_unified.yaml @@ -78,7 +78,7 @@ metrics: rayleigh_quotient: enabled: true - relative: true + relative: false # Standard Rayleigh quotient (no trace-normalization) shrinkage: true redundancy: @@ -152,12 +152,12 @@ cascade_analysis: # ----------------------------------------------------------------------------- pruning: enabled: true - distribution: "uniform" # uniform, global_threshold, size_proportional, importance_weighted + distribution: "global_threshold" # uniform, global_threshold, size_proportional, importance_weighted dependency_aware: true # Propagate masks through BN/skip connections min_per_layer: 0.0 max_per_layer: 0.95 # Include high sparsity (80%, 90%) to clearly see degradation - ratios: [0.1, 0.3, 0.5, 0.7, 0.8, 0.9] + ratios: [0.1, 0.3, 0.4, 0.5, 0.7, 0.8, 0.9, 0.95] # COMPREHENSIVE ALGORITHM LIST for exploration algorithms: @@ -166,6 +166,7 @@ pruning: # ========================================================================= - "random" # Random baseline - "magnitude" # Standard magnitude pruning (prune low) + - "activation_mean" # Mean |activation| baseline - "taylor" # Gradient-based importance - "network_slimming" # Network Slimming (BN gamma) baseline - "geometric_median" # FPGM-style geometric median baseline @@ -177,6 +178,8 @@ pruning: - "rq_low" # Prune low Rayleigh Quotient - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy + - "redundancy_high" # Control: prune high redundancy + - "synergy_high" # Control: prune high synergy # ========================================================================= # COMPOSITE COMBINATIONS @@ -193,6 +196,7 @@ pruning: # CLUSTER-AWARE # ========================================================================= - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_annealed" # Annealed mixing / constraints schedule - 
"cluster_aware_protect_redundant" # Inverted: protect redundant scoring_methods: @@ -215,7 +219,7 @@ pruning: learning_rate: 0.0001 weight_decay: 0.0001 # Safety cap: limits fine-tune compute so the full method×ratio grid stays feasible on 1 GPU - max_batches: 100 + max_batches: 200 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/configs/vision_prune/resnet50_imagenet100_unified.yaml b/configs/vision_prune/resnet50_imagenet100_unified.yaml index 08f0091a..53b4fab0 100644 --- a/configs/vision_prune/resnet50_imagenet100_unified.yaml +++ b/configs/vision_prune/resnet50_imagenet100_unified.yaml @@ -75,7 +75,7 @@ metrics: rayleigh_quotient: enabled: true - relative: true + relative: false # Standard Rayleigh quotient (no trace-normalization) shrinkage: true redundancy: @@ -143,7 +143,7 @@ cascade_analysis: # ResNet-50 on ImageNet: larger model, more channels, tests scalability pruning: enabled: true - distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity + distribution: "global_threshold" # uniform, global_threshold, adaptive_sensitivity dependency_aware: true # ResNet has skip connections min_per_layer: 0.0 max_per_layer: 0.95 @@ -156,6 +156,7 @@ pruning: # ========================================================================= - "random" # Random baseline - "magnitude" # Standard magnitude pruning (prune low) + - "activation_mean" # Mean |activation| baseline - "taylor" # Gradient-based importance - "network_slimming" # Network Slimming (BN gamma) baseline - "geometric_median" # FPGM-style geometric median baseline @@ -167,6 +168,8 @@ pruning: - "rq_low" # Prune low Rayleigh Quotient - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy + - "redundancy_high" # Control: prune high redundancy + - "synergy_high" # Control: prune high synergy # ========================================================================= # COMPOSITE COMBINATIONS 
@@ -183,6 +186,7 @@ pruning: # CLUSTER-AWARE # ========================================================================= - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_annealed" # Annealed mixing / constraints schedule - "cluster_aware_protect_redundant" # Inverted: protect redundant scoring_methods: @@ -200,11 +204,11 @@ pruning: fine_tune: enabled: true # Enable recovery fine-tuning after pruning (standard for reporting) - epochs: 5 # Fewer epochs for ImageNet + epochs: 3 # Keep per-(method,ratio) fine-tune feasible on ImageNet-100 learning_rate: 0.00001 weight_decay: 0.0001 # Critical for feasibility: fine-tuning every (method,ratio) on ImageNet-100 otherwise explodes runtime. - max_batches: 10 + max_batches: 50 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/configs/vision_prune/vgg16_cifar10_unified.yaml b/configs/vision_prune/vgg16_cifar10_unified.yaml index 5afb2ec5..4552eea6 100644 --- a/configs/vision_prune/vgg16_cifar10_unified.yaml +++ b/configs/vision_prune/vgg16_cifar10_unified.yaml @@ -72,7 +72,7 @@ metrics: rayleigh_quotient: enabled: true - relative: true + relative: false # Standard Rayleigh quotient (no trace-normalization) shrinkage: true redundancy: @@ -140,11 +140,11 @@ cascade_analysis: # VGG is highly pruneable due to high redundancy - excellent for testing assumptions pruning: enabled: true - distribution: "uniform" # uniform, global_threshold, adaptive_sensitivity + distribution: "global_threshold" # uniform, global_threshold, adaptive_sensitivity dependency_aware: false # VGG has no skip connections min_per_layer: 0.0 max_per_layer: 0.95 - ratios: [0.1, 0.3, 0.5, 0.7, 0.8] # VGG can be pruned aggressively + ratios: [0.1, 0.3, 0.4, 0.5, 0.7, 0.8] # Add 40% point for paper curves # COMPREHENSIVE ALGORITHM LIST for exploration algorithms: @@ -153,6 +153,7 @@ pruning: # 
========================================================================= - "random" # Random baseline - "magnitude" # Standard magnitude pruning (prune low) + - "activation_mean" # Mean |activation| baseline - "taylor" # Gradient-based importance - "network_slimming" # Network Slimming (BN gamma) baseline - "geometric_median" # FPGM-style geometric median baseline @@ -164,6 +165,8 @@ pruning: - "rq_low" # Prune low Rayleigh Quotient - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy + - "redundancy_high" # Control: prune high redundancy + - "synergy_high" # Control: prune high synergy # ========================================================================= # COMPOSITE COMBINATIONS @@ -180,6 +183,7 @@ pruning: # CLUSTER-AWARE # ========================================================================= - "cluster_aware" # Original: protect critical, target redundant + - "cluster_aware_annealed" # Annealed mixing / constraints schedule - "cluster_aware_protect_redundant" # Inverted: protect redundant scoring_methods: @@ -200,7 +204,7 @@ pruning: epochs: 5 learning_rate: 0.0001 weight_decay: 0.0001 - max_batches: 100 + max_batches: 200 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/scripts/run_experiment.py b/scripts/run_experiment.py index 2ae88204..82bce02e 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -118,6 +118,28 @@ def _get_nested(obj, key, default): pruning_ratios = getattr(config, "pruning_amounts", None) or \ (pruning_cfg.get("ratios") if isinstance(pruning_cfg, dict) else None) or \ [0.1, 0.3, 0.5, 0.7] + + # Pruning distribution / constraints (used by ClusterAnalysisExperiment pruning pipeline) + pruning_distribution = ( + getattr(config, "pruning_distribution", None) + or (pruning_cfg.get("distribution") if isinstance(pruning_cfg, dict) else None) + or "uniform" + ) + dependency_aware_pruning = bool( + 
getattr(config, "dependency_aware_pruning", None) + if hasattr(config, "dependency_aware_pruning") + else (pruning_cfg.get("dependency_aware", False) if isinstance(pruning_cfg, dict) else False) + ) + pruning_min_per_layer = float( + getattr(config, "pruning_min_per_layer", None) + if hasattr(config, "pruning_min_per_layer") + else (pruning_cfg.get("min_per_layer", 0.0) if isinstance(pruning_cfg, dict) else 0.0) + ) + pruning_max_per_layer = float( + getattr(config, "pruning_max_per_layer", None) + if hasattr(config, "pruning_max_per_layer") + else (pruning_cfg.get("max_per_layer", 0.95) if isinstance(pruning_cfg, dict) else 0.95) + ) # Get fine-tuning settings fine_tune_cfg = pruning_cfg.get("fine_tune", {}) if isinstance(pruning_cfg, dict) else {} @@ -185,6 +207,13 @@ def _get_nested(obj, key, default): seed=getattr(config, "seed", 42), ) + # Propagate pruning distribution knobs into ClusterAnalysisConfig so all pruning + # methods (including cluster-aware) use the same allocation regime. 
+ setattr(cluster_config, "pruning_distribution", str(pruning_distribution)) + setattr(cluster_config, "dependency_aware_pruning", bool(dependency_aware_pruning)) + setattr(cluster_config, "pruning_min_per_layer", float(pruning_min_per_layer)) + setattr(cluster_config, "pruning_max_per_layer", float(pruning_max_per_layer)) + # Optional: allow sweeping cluster-aware score weights via nested pruning config: # pruning.cluster_aware.{alpha,beta,gamma,lambda_halo,protect_critical_frac} if isinstance(pruning_cfg, dict): @@ -200,6 +229,11 @@ def _get_nested(obj, key, default): setattr(cluster_config, "cluster_aware_lambda_halo", float(ca["lambda_halo"])) if "protect_critical_frac" in ca: setattr(cluster_config, "cluster_aware_protect_critical_frac", float(ca["protect_critical_frac"])) + # Annealed variant schedule (when using cluster_aware_annealed) + if "anneal_start" in ca: + setattr(cluster_config, "cluster_aware_anneal_start", float(ca["anneal_start"])) + if "anneal_end" in ca: + setattr(cluster_config, "cluster_aware_anneal_end", float(ca["anneal_end"])) # Also support the flat ExperimentConfig fields used by our config loader # (and mapped from unified-style dotted CLI overrides). 
@@ -209,6 +243,8 @@ def _get_nested(obj, key, default): "cluster_aware_gamma", "cluster_aware_lambda_halo", "cluster_aware_protect_critical_frac", + "cluster_aware_anneal_start", + "cluster_aware_anneal_end", ): if hasattr(config, attr): setattr(cluster_config, attr, float(getattr(config, attr))) diff --git a/slurm_jobs/prune_llm/README.md b/slurm_jobs/prune_llm/README.md new file mode 100644 index 00000000..c70ea082 --- /dev/null +++ b/slurm_jobs/prune_llm/README.md @@ -0,0 +1,74 @@ +### SCAR paper experiment suite (batch + collection) + +This folder contains **SLURM batch scripts** that run a complete ICML-style paper suite: + +- **Main results + generalization** (4 models) +- **Key controls / ablations** on Llama-3.1-8B: + - **LP-no-protect** + **remove-supernodes-early** (mode=high) control + - **Protect+Wanda** and **Protect+Magnitude** (baseline + supernode protection) + - **Positive-only redundancy** ablation (anti-correlation does NOT count as redundancy) + - **Calibration sensitivity** sweep (dataset + sample-count) +- **Optional paper-faithful unstructured baseline reproductions** (Llama-3.1-8B): + - `wanda_unstructured` (Wanda as originally proposed: unstructured |W|·||X||₂ pruning) + - `sparsegpt_unstructured` (SparseGPT with unstructured pruning + reconstruction) + +All jobs write to a single `OUTPUT_BASE` using the unified job directory structure: + +`{OUTPUT_BASE}/{experiment_name}_{timestamp}_{job_id}/` + +### How to run + +- **Set output base** (or let scripts use the default in each file): + +```bash +export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM" +``` + +- **Submit the full suite**: + +```bash +bash slurm_jobs/prune_llm/submit_suite.sh +``` + +### Optional: submit unstructured baseline reproductions + +These are **not enabled by default** (they’re expensive and are mainly for appendix/sanity checks). 
+ +Enable them by setting: + +```bash +export SUBMIT_UNSTRUCTURED_BASELINES=1 +``` + +Then run either: + +```bash +bash slurm_jobs/prune_llm/run_all_paper.sh +``` + +or + +```bash +bash slurm_jobs/prune_llm/submit_suite.sh +``` + +### How to collect artifacts (tables + placeholder figures) + +After jobs finish: + +```bash +# Recommended (tables + figures, plus a LaTeX sanity compile): +bash drafts/LLM_prune/paper/scripts/refresh_paper_artifacts.sh + +# Or, manually: +# python drafts/LLM_prune/paper/scripts/collect_paper_artifacts.py \ +# --results-base "$OUTPUT_BASE" \ +# --draft-dir /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment/drafts/LLM_prune +``` + +This will: +- write LaTeX snippets to `drafts/LLM_prune/paper_artifacts/tables/*.tex` +- write `drafts/LLM_prune/paper_artifacts/numbers.tex` (paper text macros) +- copy/regenerate key plots into `drafts/LLM_prune/figures/*.png` (used by the TeX) + + diff --git a/slurm_jobs/prune_llm/run_all_paper.sh b/slurm_jobs/prune_llm/run_all_paper.sh new file mode 100755 index 00000000..c0918158 --- /dev/null +++ b/slurm_jobs/prune_llm/run_all_paper.sh @@ -0,0 +1,110 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT ALL PAPER EXPERIMENTS +# ============================================================================ +# This script submits all 4 paper experiments as separate SLURM jobs +# They will run in parallel if resources are available +# +# Output Directory Structure: +# All results go to: /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# Each job creates a unique directory: {model}_paper_results_{timestamp}_{job_id}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs +# +# Usage: +# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# bash slurm_jobs/prune_llm/run_all_paper.sh +# 
============================================================================ + +# NOTE: This is a *submission* script (it calls `sbatch ...` for the real jobs). +# Run it with `bash ...` from a login node. If you accidentally run it with `sbatch`, +# Slurm would normally create `slurm-<jobid>.out` in the repo root; we redirect that +# output to /tmp to avoid polluting the source tree. +#SBATCH --job-name=submit_scar_paper +#SBATCH --output=/tmp/%x_%j.out +#SBATCH --error=/tmp/%x_%j.err + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +# Ensure compute jobs can find the HuggingFace token/cache. +# If you ran `hf auth login` with HF_HOME under OUTPUT_BASE, this propagates it to all sbatch jobs. +export HF_HOME="${HF_HOME:-${OUTPUT_BASE}/huggingface_cache}" +mkdir -p "$HF_HOME" || true +SUBMIT_UNSTRUCTURED_BASELINES="${SUBMIT_UNSTRUCTURED_BASELINES:-0}" + +echo "==============================================" +echo "Submitting SCAR Paper Experiments" +echo "==============================================" +echo "" +echo "Output directory: $OUTPUT_BASE" +echo "Submit unstructured baseline reproductions: $SUBMIT_UNSTRUCTURED_BASELINES (set to 1 to enable)" +echo "" + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +cd "${REPO_ROOT}" +mkdir -p logs + +# Submit all jobs +echo "Submitting LLaMA-3.1-8B (main results)..." +JOB1=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b.sh | awk '{print $4}') +echo " Job ID: $JOB1" + +echo "Submitting Mistral-7B (generalization)..." +JOB2=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_mistral_7b.sh | awk '{print $4}') +echo " Job ID: $JOB2" + +echo "Submitting LLaMA-2-7B (generalization)..." 
+JOB3=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama2_7b.sh | awk '{print $4}') +echo " Job ID: $JOB3" + +echo "Submitting Qwen2-7B (generalization)..." +JOB4=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_qwen2_7b.sh | awk '{print $4}') +echo " Job ID: $JOB4" + +if [[ "$SUBMIT_UNSTRUCTURED_BASELINES" == "1" ]]; then + echo "" + echo "---- Paper-faithful unstructured baseline reproductions (LLaMA-3.1-8B) ----" + echo "Submitting Wanda (unstructured)..." + JOB5=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b_wanda_unstructured.sh | awk '{print $4}') + echo " Job ID: $JOB5" + + echo "Submitting SparseGPT (unstructured + reconstruction)..." + JOB6=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b_sparsegpt_unstructured.sh | awk '{print $4}') + echo " Job ID: $JOB6" +fi + +echo "" +echo "==============================================" +echo "All jobs submitted!" 
+echo "==============================================" +echo "" +if [[ "$SUBMIT_UNSTRUCTURED_BASELINES" == "1" ]]; then + echo "Job IDs: $JOB1, $JOB2, $JOB3, $JOB4, $JOB5, $JOB6" +else + echo "Job IDs: $JOB1, $JOB2, $JOB3, $JOB4" +fi +echo "" +echo "Monitor with:" +echo " squeue -u \$USER" +echo "" +echo "View SLURM logs:" +echo " tail -f logs/paper_llama3_8b_${JOB1}.out" +echo " tail -f logs/paper_mistral_7b_${JOB2}.out" +echo " tail -f logs/paper_llama2_7b_${JOB3}.out" +echo " tail -f logs/paper_qwen2_7b_${JOB4}.out" +echo "" +echo "Expected runtime: ~6-8 hours per job" +echo "" +echo "Results will be in:" +echo " $OUTPUT_BASE/llama3_8b_paper_results_*_${JOB1}/" +echo " $OUTPUT_BASE/mistral_7b_paper_results_*_${JOB2}/" +echo " $OUTPUT_BASE/llama2_7b_paper_results_*_${JOB3}/" +echo " $OUTPUT_BASE/qwen2_7b_paper_results_*_${JOB4}/" +if [[ "$SUBMIT_UNSTRUCTURED_BASELINES" == "1" ]]; then + echo " $OUTPUT_BASE/llama3_8b_paper_results_wanda_unstructured_*_${JOB5}/" + echo " $OUTPUT_BASE/llama3_8b_paper_results_sparsegpt_unstructured_*_${JOB6}/" +fi \ No newline at end of file diff --git a/slurm_jobs/prune_llm/run_llama2_7b.sh b/slurm_jobs/prune_llm/run_llama2_7b.sh new file mode 100755 index 00000000..6c341b81 --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama2_7b.sh @@ -0,0 +1,103 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama2_7b +#SBATCH --output=logs/paper_llama2_7b_%j.out +#SBATCH --error=logs/paper_llama2_7b_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=10:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ============================================================================ +# LLAMA-2-7B PAPER RESULTS (Generalization) +# ============================================================================ +# Cross-model generalization experiment +# Expected runtime: ~4-6 hours on H100 +# +# Output Directory Structure: +# 
/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# llama2_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs +# ============================================================================ + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper: LLaMA-2-7B (Generalization)" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +# Environment setup +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" + +# Create local logs directory for SLURM output files +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. 
+if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +echo "" +echo "Running LLaMA-2-7B full paper analysis..." +echo "" + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama2_7b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "============================================================================" +echo "LLaMA-2-7B completed at $(date)" +echo "============================================================================" +echo "" +echo "Results saved to: $OUTPUT_BASE/" +echo "Look for directory: llama2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_llama3_8b.sh b/slurm_jobs/prune_llm/run_llama3_8b.sh new file mode 100755 index 00000000..e09d11dd --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b.sh @@ -0,0 +1,110 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_8b +#SBATCH --output=logs/paper_llama3_8b_%j.out +#SBATCH --error=logs/paper_llama3_8b_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=12:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# 
============================================================================ +# LLAMA-3.1-8B PAPER RESULTS +# ============================================================================ +# Full SCAR analysis including: +# - Supernode distribution & robustness +# - Halo redundancy analysis +# - Cross-layer importance +# - Within-layer importance +# - All pruning methods + SOTA baselines (Wanda, SparseGPT) +# - Full benchmark evaluation +# +# Expected runtime: ~6-8 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# llama3_8b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs +# ============================================================================ + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper: LLaMA-3.1-8B" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +# Environment setup +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. 
+cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" + +# Create local logs directory for SLURM output files +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +echo "" +echo "Running LLaMA-3.1-8B full paper analysis..." 
+echo "" + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B completed at $(date)" +echo "============================================================================" +echo "" +echo "Results saved to: $OUTPUT_BASE/" +echo "Look for directory: llama3_8b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_llama3_8b_calibration_array.sh b/slurm_jobs/prune_llm/run_llama3_8b_calibration_array.sh new file mode 100644 index 00000000..5a18dbe0 --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_calibration_array.sh @@ -0,0 +1,121 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_calib +#SBATCH --output=logs/paper_llama3_calib_%A_%a.out +#SBATCH --error=logs/paper_llama3_calib_%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=06:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev +#SBATCH --array=0-4 + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B SWEEP: calibration sensitivity for SCAR-Conn @ 50% sparsity +# +# Task mapping: +# 0: wikitext, n=128 +# 1: wikitext, n=64 +# 2: wikitext, n=32 +# 3: c4, n=128 +# 4: mixed_wikitext_c4, n=128 +# +# Notes: +# - We restrict pruning to SCAR-Conn at 50% and evaluate perplexity only (fast). 
+# ---------------------------------------------------------------------------- + +set -euo pipefail + +DATASETS=("wikitext" "wikitext" "wikitext" "c4" "mixed_wikitext_c4") +NSAMPLES=(128 64 32 128 128) +TAGS=("wikitext_128" "wikitext_64" "wikitext_32" "c4_128" "mixed_128") + +IDX="${SLURM_ARRAY_TASK_ID}" +DATASET="${DATASETS[$IDX]}" +N="${NSAMPLES[$IDX]}" +TAG="${TAGS[$IDX]}" + +echo "============================================================================" +echo "SCAR Paper Sweep: LLaMA-3.1-8B calibration sensitivity (${TAG})" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID} Array Task: ${SLURM_ARRAY_TASK_ID}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "Calibration dataset: ${DATASET}" +echo "Calibration samples: ${N}" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. 
+if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +# NOTE: SCAR-Conn depends on directed redundancy + connectivity scoring. +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_calib_${TAG}" \ + generate_plots=false \ + dataset_name="${DATASET}" \ + alignment_data_num_samples="${N}" \ + scar_num_samples="${N}" \ + pruning_strategies="['supernode_connectivity_score']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + "llm.evaluation_metrics=['perplexity']" \ + do_directed_redundancy=true \ + do_connectivity_pruning=true \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B calibration sweep (${TAG}) completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/run_llama3_8b_noprotect.sh b/slurm_jobs/prune_llm/run_llama3_8b_noprotect.sh new file mode 100644 index 00000000..083ac59a 
--- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_noprotect.sh @@ -0,0 +1,99 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_noprotect +#SBATCH --output=logs/paper_llama3_noprotect_%j.out +#SBATCH --error=logs/paper_llama3_noprotect_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=06:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B CONTROL: LP-no-protect + "remove supernodes early" (mode=high) +# +# Produces (at 50%): +# - LP-no-protect: metric=scar_loss_proxy, mode=low, protect_core=false +# - Remove-core-early: metric=scar_loss_proxy, mode=high, protect_core=false +# ---------------------------------------------------------------------------- + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper Control: LLaMA-3.1-8B (no-protect LP control)" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. 
exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_noprotect" \ + generate_plots=false \ + supernode.protect_core=false \ + pruning_strategies="['scar_loss_proxy']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low','high']" \ + do_connectivity_pruning=false \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B no-protect control completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/run_llama3_8b_positive_redundancy_array.sh b/slurm_jobs/prune_llm/run_llama3_8b_positive_redundancy_array.sh new file mode 100644 index 
00000000..fac03b6a --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_positive_redundancy_array.sh @@ -0,0 +1,108 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_posred +#SBATCH --output=logs/paper_llama3_posred_%A_%a.out +#SBATCH --error=logs/paper_llama3_posred_%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=06:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev +#SBATCH --array=0-1 + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B ABLATION: positive-only redundancy vs rho^2 redundancy +# +# Task 0: positive_redundancy=false (rho^2 counts anti-correlation as redundancy) +# Task 1: positive_redundancy=true (rho^+ only; anti-correlation NOT redundant) +# ---------------------------------------------------------------------------- + +set -euo pipefail + +if [ "${SLURM_ARRAY_TASK_ID}" -eq 0 ]; then + POS_RED="false" + TAG="rho2" +else + POS_RED="true" + TAG="posonly" +fi + +echo "============================================================================" +echo "SCAR Paper Ablation: LLaMA-3.1-8B (positive redundancy = ${POS_RED})" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID} Array Task: ${SLURM_ARRAY_TASK_ID}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. 
+cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_posred_${TAG}" \ + generate_plots=false \ + supernode.positive_redundancy="${POS_RED}" \ + supernode.protect_core=true \ + "supernode.protect_core_metrics=['supernode_connectivity_score']" \ + pruning_strategies="['supernode_connectivity_score']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false 
\ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B pos-redundancy ablation (${TAG}) completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/run_llama3_8b_protect_baselines.sh b/slurm_jobs/prune_llm/run_llama3_8b_protect_baselines.sh new file mode 100644 index 00000000..d04996d1 --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_protect_baselines.sh @@ -0,0 +1,100 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_protect_base +#SBATCH --output=logs/paper_llama3_protect_base_%j.out +#SBATCH --error=logs/paper_llama3_protect_base_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=08:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ---------------------------------------------------------------------------- +# LLaMA-3.1-8B CONTROL: Protect+Baseline variants +# +# Produces (at 50%): +# - Protect+Wanda: metric=wanda, protect_core_metrics includes wanda +# - Protect+Magnitude: metric=weight_magnitude, protect_core_metrics includes weight_magnitude +# ---------------------------------------------------------------------------- + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper Control: LLaMA-3.1-8B (protect baselines)" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda 
activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_protect_baselines" \ + generate_plots=false \ + supernode.protect_core=true \ + "supernode.protect_core_metrics=['wanda','weight_magnitude']" \ + pruning_strategies="['wanda','weight_magnitude']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + do_connectivity_pruning=false \ + do_directed_redundancy=false \ + do_halo_analysis=false \ 
+ do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "LLaMA-3.1-8B protect-baselines completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/run_llama3_8b_sparsegpt_unstructured.sh b/slurm_jobs/prune_llm/run_llama3_8b_sparsegpt_unstructured.sh new file mode 100644 index 00000000..7d3f8da6 --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_sparsegpt_unstructured.sh @@ -0,0 +1,100 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_sparsegpt_unstruct +#SBATCH --output=logs/paper_llama3_sparsegpt_unstruct_%j.out +#SBATCH --error=logs/paper_llama3_sparsegpt_unstruct_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=12:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ============================================================================ +# LLaMA-3.1-8B PAPER-FAITHFUL BASELINE: SparseGPT (UNSTRUCTURED + RECONSTRUCTION) +# ============================================================================ +# Purpose: +# - Run SparseGPT as originally intended (unstructured weight pruning with reconstruction), +# as an appendix/sanity baseline, separate from the channel-adapted SparseGPT baseline. +# +# Notes: +# - This is NOT structured FFN channel pruning; it's unstructured weight pruning. +# - This is compute-heavy; we run a small setting by default (50% sparsity, mode=low, perplexity-only). 
+# ============================================================================ + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper Baseline (unstructured): SparseGPT | LLaMA-3.1-8B" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. 
+if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_sparsegpt_unstructured" \ + generate_plots=false \ + pruning_strategies="['sparsegpt_unstructured']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + "llm.evaluation_metrics=['perplexity']" \ + do_connectivity_pruning=false \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "SparseGPT unstructured baseline completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/run_llama3_8b_wanda_unstructured.sh b/slurm_jobs/prune_llm/run_llama3_8b_wanda_unstructured.sh new file mode 100644 index 00000000..daad80ec --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_wanda_unstructured.sh @@ -0,0 +1,100 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_wanda_unstruct +#SBATCH --output=logs/paper_llama3_wanda_unstruct_%j.out +#SBATCH --error=logs/paper_llama3_wanda_unstruct_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 
+#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=08:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ============================================================================ +# LLaMA-3.1-8B PAPER-FAITHFUL BASELINE: WANDA (UNSTRUCTURED) +# ============================================================================ +# Purpose: +# - Run Wanda as originally intended (unstructured weight pruning using |W| * ||X||_2), +# as an appendix/sanity baseline, separate from the channel-adapted Wanda baseline. +# +# Notes: +# - This is NOT structured FFN channel pruning; it's unstructured weight pruning. +# - We run a small setting by default (50% sparsity, mode=low, perplexity-only) to keep runtime sane. +# ============================================================================ + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper Baseline (unstructured): Wanda | LLaMA-3.1-8B" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). 
+# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="llama3_8b_paper_results_wanda_unstructured" \ + generate_plots=false \ + pruning_strategies="['wanda_unstructured']" \ + pruning_amounts="[0.5]" \ + pruning_selection_mode="['low']" \ + "llm.evaluation_metrics=['perplexity']" \ + do_connectivity_pruning=false \ + do_directed_redundancy=false \ + do_halo_analysis=false \ + do_generalized_importance=false \ + supernode_robustness.enabled=false \ + supernode_summary.enabled=false + +echo "" +echo "============================================================================" +echo "Wanda unstructured baseline completed at $(date)" +echo "============================================================================" + diff --git a/slurm_jobs/prune_llm/run_mistral_7b.sh b/slurm_jobs/prune_llm/run_mistral_7b.sh new file mode 100755 index 00000000..460eee70 --- /dev/null +++ b/slurm_jobs/prune_llm/run_mistral_7b.sh @@ -0,0 +1,103 @@ +#!/bin/bash +#SBATCH --job-name=paper_mistral_7b +#SBATCH --output=logs/paper_mistral_7b_%j.out +#SBATCH --error=logs/paper_mistral_7b_%j.err 
+#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=10:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ============================================================================ +# MISTRAL-7B PAPER RESULTS (Generalization) +# ============================================================================ +# Cross-model generalization experiment +# Expected runtime: ~4-6 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# mistral_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs +# ============================================================================ + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper: Mistral-7B (Generalization)" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +# Environment setup +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. 
+cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" + +# Create local logs directory for SLURM output files +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +echo "" +echo "Running Mistral-7B full paper analysis..." 
+echo "" + +python scripts/run_experiment.py \ + --config configs/prune_llm/mistral_7b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "============================================================================" +echo "Mistral-7B completed at $(date)" +echo "============================================================================" +echo "" +echo "Results saved to: $OUTPUT_BASE/" +echo "Look for directory: mistral_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/run_qwen2_7b.sh b/slurm_jobs/prune_llm/run_qwen2_7b.sh new file mode 100755 index 00000000..85e537cf --- /dev/null +++ b/slurm_jobs/prune_llm/run_qwen2_7b.sh @@ -0,0 +1,104 @@ +#!/bin/bash +#SBATCH --job-name=paper_qwen2_7b +#SBATCH --output=logs/paper_qwen2_7b_%j.out +#SBATCH --error=logs/paper_qwen2_7b_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --time=10:00:00 +#SBATCH --mem=320GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +# ============================================================================ +# QWEN2-7B PAPER RESULTS (Generalization) +# ============================================================================ +# Cross-model generalization experiment +# Qwen2 has different FFN architecture (28 layers, larger intermediate) +# Expected runtime: ~4-6 hours on H100 +# +# Output Directory Structure: +# /n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM/ +# qwen2_7b_paper_results_{timestamp}_{SLURM_JOB_ID}/ +# results/ - JSON results files +# logs/ - experiment.log +# figures/ - All visualizations +# checkpoints/ - Model checkpoints +# analysis/ - Post-analysis outputs +# ============================================================================ + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper: Qwen2-7B (Generalization)" +echo 
"============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +# Environment setup +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +# Prefer SLURM_SUBMIT_DIR (repo root) when available. +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" + +# Create local logs directory for SLURM output files +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# HuggingFace auth/cache: +# - Respect HF_HOME if already set (e.g. exported from submission script). +# - Else, if you ran `hf auth login` with HF_HOME under OUTPUT_BASE, prefer that token/cache. +# - Else fall back to scratch cache, then ~/.cache. 
+if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi + +echo "" +echo "Running Qwen2-7B full paper analysis..." +echo "" + +python scripts/run_experiment.py \ + --config configs/prune_llm/qwen2_7b_full.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "============================================================================" +echo "Qwen2-7B completed at $(date)" +echo "============================================================================" +echo "" +echo "Results saved to: $OUTPUT_BASE/" +echo "Look for directory: qwen2_7b_paper_results_*_$SLURM_JOB_ID" diff --git a/slurm_jobs/prune_llm/submit_suite.sh b/slurm_jobs/prune_llm/submit_suite.sh new file mode 100644 index 00000000..d709f6af --- /dev/null +++ b/slurm_jobs/prune_llm/submit_suite.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT FULL SCAR PAPER SUITE (main + controls/ablations) +# ============================================================================ +# Usage: +# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# bash slurm_jobs/prune_llm/submit_suite.sh +# +# Output: +# Uses OUTPUT_BASE (exported or defaulted below). 
+# ============================================================================ + +# NOTE: This is a *submission* script (it calls `sbatch ...` for the real jobs). +# Run it with `bash ...` from a login node. If you accidentally run it with `sbatch`, +# Slurm would normally create `slurm-%j.out` (%j = job ID) in the repo root; we redirect that +# output to /tmp to avoid polluting the source tree. +#SBATCH --job-name=submit_scar_paper_suite +#SBATCH --output=/tmp/%x_%j.out +#SBATCH --error=/tmp/%x_%j.err + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +# Ensure compute jobs can find the HuggingFace token/cache. +# If you ran `hf auth login` with HF_HOME under OUTPUT_BASE, this propagates it to all sbatch jobs. +export HF_HOME="${HF_HOME:-${OUTPUT_BASE}/huggingface_cache}" +mkdir -p "$HF_HOME" || true + +echo "==============================================" +echo "Submitting SCAR Paper Suite" +echo "==============================================" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.."
&& pwd)" +cd "${REPO_ROOT}" +mkdir -p logs + +echo "---- Main results + generalization (4 models) ----" +export OUTPUT_BASE +bash slurm_jobs/prune_llm/run_all_paper.sh +echo "" + +echo "---- Controls / ablations (Llama-3.1-8B) ----" +JOB_NP=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b_noprotect.sh | awk '{print $4}') +echo " noprotect/control: $JOB_NP" + +JOB_PB=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b_protect_baselines.sh | awk '{print $4}') +echo " protect-baselines: $JOB_PB" + +JOB_POSRED=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b_positive_redundancy_array.sh | awk '{print $4}') +echo " pos-redundancy array: $JOB_POSRED" + +JOB_CALIB=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE",HF_HOME="$HF_HOME",HF_TOKEN=,HUGGINGFACE_HUB_TOKEN= slurm_jobs/prune_llm/run_llama3_8b_calibration_array.sh | awk '{print $4}') +echo " calibration array: $JOB_CALIB" + +echo "" +echo "==============================================" +echo "All suite jobs submitted" +echo "==============================================" +echo "Monitor with: squeue -u \$USER" +echo "" + diff --git a/slurm_jobs/run_baseline_test.sh b/slurm_jobs/run_baseline_test.sh index f1045ebd..02bbaeb8 100644 --- a/slurm_jobs/run_baseline_test.sh +++ b/slurm_jobs/run_baseline_test.sh @@ -21,6 +21,8 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" cd "$REPO_ROOT" +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" + echo "==========================================" echo "Baseline Pruning Test (Wanda + SparseGPT)" echo "==========================================" diff --git a/slurm_jobs/run_fast_pruning.sh b/slurm_jobs/run_fast_pruning.sh index 20292c29..8f710cb3 100755 --- a/slurm_jobs/run_fast_pruning.sh +++ b/slurm_jobs/run_fast_pruning.sh @@ -32,6 +32,8 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" cd "$REPO_ROOT" +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" + echo "============================================================================" echo "FAST LLM PRUNING COMPARISON" echo "============================================================================" diff --git a/slurm_jobs/run_mnist_basic.sh b/slurm_jobs/run_mnist_basic.sh index a99b5529..5dba90c9 100644 --- a/slurm_jobs/run_mnist_basic.sh +++ b/slurm_jobs/run_mnist_basic.sh @@ -17,6 +17,8 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" cd "$REPO_ROOT" +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" + echo "Starting MNIST basic alignment experiment at $(date)" echo "Job ID: ${SLURM_JOB_ID:-N/A}" echo "Running on: $(hostname)" diff --git a/slurm_jobs/run_single_model.sh b/slurm_jobs/run_single_model.sh index 6958f28c..6f3a4a6d 100644 --- a/slurm_jobs/run_single_model.sh +++ b/slurm_jobs/run_single_model.sh @@ -45,6 +45,8 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" cd "$REPO_ROOT" +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" + echo "============================================================================" echo "SINGLE MODEL PRUNING: ${CONFIG_NAME}" echo "============================================================================" diff --git a/slurm_jobs/run_test_all_layers.sh b/slurm_jobs/run_test_all_layers.sh index e682ab53..1d6326ac 100755 --- a/slurm_jobs/run_test_all_layers.sh +++ b/slurm_jobs/run_test_all_layers.sh @@ -18,6 +18,8 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" cd "$REPO_ROOT" +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" + echo "==========================================" echo "Test: All Layers (MLP + Attention)" echo "==========================================" diff --git a/slurm_jobs/run_vision_pruning_test.sh b/slurm_jobs/run_vision_pruning_test.sh index 4e08033b..1f28671b 100755 --- a/slurm_jobs/run_vision_pruning_test.sh +++ b/slurm_jobs/run_vision_pruning_test.sh @@ -18,6 +18,8 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" cd "$REPO_ROOT" +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" + echo "==========================================" echo "Vision Pruning Test (AlexNet on ImageNet)" echo "==========================================" diff --git a/slurm_jobs/vision_prune/build_artifacts.sh b/slurm_jobs/vision_prune/build_artifacts.sh new file mode 100644 index 00000000..8023cf7d --- /dev/null +++ b/slurm_jobs/vision_prune/build_artifacts.sh @@ -0,0 +1,41 @@ +#!/bin/bash +#SBATCH --job-name=vision_paper_build +#SBATCH --output=logs/vision_paper_build_%j.out +#SBATCH --error=logs/vision_paper_build_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=4 +#SBATCH --time=2:00:00 +#SBATCH --mem=32GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: Build all figures + tables from existing runs" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python drafts/alignment_notes/paper/scripts/build_all_artifacts.py \ + --results-base "$OUTPUT_BASE" + +echo "" +echo "Done: $(date)" +echo "Paper figures: drafts/alignment_notes/paper_figures_vision/" +echo "Paper tables: drafts/alignment_notes/paper_artifacts/tables/" + diff --git a/slurm_jobs/vision_prune/run_all_array.sh b/slurm_jobs/vision_prune/run_all_array.sh new file mode 100644 index 00000000..05640f03 --- /dev/null +++ b/slurm_jobs/vision_prune/run_all_array.sh @@ -0,0 +1,186 @@ 
+#!/bin/bash +#SBATCH --job-name=vision_paper_all +#SBATCH --output=logs/vision_paper_all_%A_%a.out +#SBATCH --error=logs/vision_paper_all_%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=12:00:00 +#SBATCH --mem=64GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev +# +# One array job that runs the full vision paper suite + appendix, throttled to +# at most 16 concurrent tasks (== 16 GPUs if each task requests 1 GPU). +# ----------------------------------------------------------------------------- +# Task map: +# 0 resnet18_cifar10_cluster_analysis +# 1 vgg16_cifar10_cluster_analysis +# 2 mobilenetv2_cifar10_cluster_analysis +# 3 resnet50_imagenet100_cluster_analysis +# 4 GAP robustness (resnet18, activation_samples=gap) +# 5 Ablation (resnet18 @ 50%: cluster_aware variants + composite) +# 6-20 Weight sweep (15 tasks): gamma∈{0.10,0.30,0.50} × lambda∈{0.00,0.25,0.50,0.75,1.00} +# Each sweep run prunes across multiple sparsity ratios so the per-run figures show pruning effects. 
+# +# Submit via: slurm_jobs/vision_prune/submit_all_array.sh +# ----------------------------------------------------------------------------- + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: ALL runs (single SLURM array, max 16 GPUs)" +echo "============================================================================" +echo "Array Job ID: ${SLURM_ARRAY_JOB_ID:-N/A} Task: ${SLURM_ARRAY_TASK_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +TASK="${SLURM_ARRAY_TASK_ID:?SLURM_ARRAY_TASK_ID not set}" + +run_py() { + echo "" + echo "$ $*" + python "$@" +} + +prepare_imagenet100() { + # Robust ImageNet-100 subset prep (safe with set -o pipefail) + IMAGENET1K_ROOT="${IMAGENET1K_ROOT:-/n/holylfs06/LABS/kempner_shared/Everyone/testbed/vision/imagenet_1k}" + IMAGENET100_ROOT="${IMAGENET100_ROOT:-$PWD/data/imagenet100}" + IMAGENET100_NCLASSES="${IMAGENET100_NCLASSES:-100}" + + if [ ! -d "${IMAGENET1K_ROOT}/train" ] || [ ! -d "${IMAGENET1K_ROOT}/val" ]; then + echo "[error] IMAGENET1K_ROOT does not look like ImageFolder (missing train/val): ${IMAGENET1K_ROOT}" + exit 2 + fi + + need_prepare=0 + if [ ! -d "${IMAGENET100_ROOT}/train" ] || [ ! 
-d "${IMAGENET100_ROOT}/val" ]; then + need_prepare=1 + else + n_train=$(find -L "${IMAGENET100_ROOT}/train" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + n_val=$(find -L "${IMAGENET100_ROOT}/val" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + if [ "${n_train}" -lt 1 ] || [ "${n_val}" -lt 1 ]; then + need_prepare=1 + fi + fi + + if [ "${need_prepare}" -eq 1 ]; then + echo "[info] Preparing ImageNet-100 subset under: ${IMAGENET100_ROOT}" + rm -rf "${IMAGENET100_ROOT}/train" "${IMAGENET100_ROOT}/val" + mkdir -p "${IMAGENET100_ROOT}/train" "${IMAGENET100_ROOT}/val" + + find "${IMAGENET1K_ROOT}/train" -maxdepth 1 -mindepth 1 -type d -printf '%f\n' | sort > "${IMAGENET100_ROOT}/classes_all.txt" + head -n "${IMAGENET100_NCLASSES}" "${IMAGENET100_ROOT}/classes_all.txt" > "${IMAGENET100_ROOT}/classes.txt" + rm -f "${IMAGENET100_ROOT}/classes_all.txt" + + while read -r syn; do + ln -sfn "${IMAGENET1K_ROOT}/train/${syn}" "${IMAGENET100_ROOT}/train/${syn}" + ln -sfn "${IMAGENET1K_ROOT}/val/${syn}" "${IMAGENET100_ROOT}/val/${syn}" + done < "${IMAGENET100_ROOT}/classes.txt" + + n_train=$(find -L "${IMAGENET100_ROOT}/train" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + n_val=$(find -L "${IMAGENET100_ROOT}/val" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + echo "[info] ImageNet-100 class dirs: train=${n_train} val=${n_val}" + if [ "${n_train}" -lt 1 ] || [ "${n_val}" -lt 1 ]; then + echo "[error] ImageNet-100 subset prep failed: no class dirs found under ${IMAGENET100_ROOT}/{train,val}" + exit 3 + fi + fi +} + +case "${TASK}" in + 0) + echo "[task 0] ResNet-18 / CIFAR-10" + run_py scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + ;; + 1) + echo "[task 1] VGG-16-BN / CIFAR-10" + run_py scripts/run_experiment.py \ + --config configs/vision_prune/vgg16_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir 
"$OUTPUT_BASE" + ;; + 2) + echo "[task 2] MobileNetV2 / CIFAR-10" + run_py scripts/run_experiment.py \ + --config configs/vision_prune/mobilenetv2_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + ;; + 3) + echo "[task 3] ResNet-50 / ImageNet-100" + prepare_imagenet100 + run_py scripts/run_experiment.py \ + --config configs/vision_prune/resnet50_imagenet100_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + ;; + 4) + echo "[task 4] GAP robustness (ResNet-18, activation_samples=gap)" + run_py scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="resnet18_cifar10_cluster_analysis_gap" \ + metrics.activation_samples="gap" \ + pruning_amounts="[]" + ;; + 5) + echo "[task 5] Ablation (ResNet-18 @ 50%: cluster_aware variants + composite)" + run_py scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="resnet18_cifar10_cluster_analysis_ablation" \ + pruning_amounts="[0.5]" \ + pruning_distribution="global_threshold" \ + pruning_strategies="['cluster_aware','cluster_aware_no_halo','cluster_aware_no_constraints','composite']" + ;; + *) + # Weight sweep tasks 6-20 (15 tasks) + if [ "${TASK}" -ge 6 ] && [ "${TASK}" -le 20 ]; then + SWEEP_IDX=$((TASK - 6)) + GAMMAS=(0.10 0.30 0.50) + LAMBDAS=(0.00 0.25 0.50 0.75 1.00) + GI=$((SWEEP_IDX / ${#LAMBDAS[@]})) + LI=$((SWEEP_IDX % ${#LAMBDAS[@]})) + GAMMA="${GAMMAS[$GI]}" + LAMBDA="${LAMBDAS[$LI]}" + echo "[task ${TASK}] Weight sweep (ResNet-18, multi-sparsity): gamma=${GAMMA}, lambda_halo=${LAMBDA}" + run_py scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="resnet18_cifar10_weightsweep_g${GAMMA}_l${LAMBDA}" \ + pruning_amounts="[0.1,0.3,0.5,0.7,0.8,0.9]" \ + 
pruning_distribution="global_threshold" \ + pruning_strategies="['cluster_aware']" \ + pruning.cluster_aware.gamma="${GAMMA}" \ + pruning.cluster_aware.lambda_halo="${LAMBDA}" + else + echo "[error] Unknown task id: ${TASK}" + exit 2 + fi + ;; +esac + +echo "" +echo "Done: $(date)" + diff --git a/slurm_jobs/vision_prune/run_damage_prediction_resnet18.sh b/slurm_jobs/vision_prune/run_damage_prediction_resnet18.sh new file mode 100644 index 00000000..11456e54 --- /dev/null +++ b/slurm_jobs/vision_prune/run_damage_prediction_resnet18.sh @@ -0,0 +1,47 @@ +#!/bin/bash +#SBATCH --job-name=vision_r18_damagepred +#SBATCH --output=logs/vision_r18_damagepred_%j.out +#SBATCH --error=logs/vision_r18_damagepred_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=4:00:00 +#SBATCH --mem=64GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +# ---------------------------------------------------------------------------- +# Mechanism evaluation: per-channel damage prediction correlation (ResNet-18) +# ---------------------------------------------------------------------------- + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: Damage prediction eval (ResNet-18)" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python drafts/alignment_notes/paper/scripts/run_damage_prediction.py \ + --results-base "$OUTPUT_BASE" \ + --exp "resnet18_cifar10_cluster_analysis" \ + 
--damage-frac 0.15 \ + --eval-examples 2000 + +echo "" +echo "Done: $(date)" + diff --git a/slurm_jobs/vision_prune/run_mobilenetv2_cifar10.sh b/slurm_jobs/vision_prune/run_mobilenetv2_cifar10.sh new file mode 100644 index 00000000..d75c99d8 --- /dev/null +++ b/slurm_jobs/vision_prune/run_mobilenetv2_cifar10.sh @@ -0,0 +1,43 @@ +#!/bin/bash +#SBATCH --job-name=vision_mobilenetv2_cifar10 +#SBATCH --output=logs/vision_mobilenetv2_cifar10_%j.out +#SBATCH --error=logs/vision_mobilenetv2_cifar10_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=6:00:00 +#SBATCH --mem=96GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: MobileNetV2 on CIFAR-10" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python scripts/run_experiment.py \ + --config configs/vision_prune/mobilenetv2_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "Done: $(date)" +echo "Look under: $OUTPUT_BASE/ (experiment name: mobilenetv2_cifar10_cluster_analysis_*)" + diff --git a/slurm_jobs/vision_prune/run_resnet18_cifar10.sh b/slurm_jobs/vision_prune/run_resnet18_cifar10.sh new file mode 100644 index 00000000..14a144df --- /dev/null +++ b/slurm_jobs/vision_prune/run_resnet18_cifar10.sh @@ -0,0 +1,44 @@ +#!/bin/bash +#SBATCH --job-name=vision_resnet18_cifar10 +#SBATCH 
--output=logs/vision_resnet18_cifar10_%j.out +#SBATCH --error=logs/vision_resnet18_cifar10_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=4:00:00 +#SBATCH --mem=64GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: ResNet-18 on CIFAR-10" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +# Environment setup (adjust to your cluster defaults) +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "Done: $(date)" +echo "Look under: $OUTPUT_BASE/ (experiment name: resnet18_cifar10_cluster_analysis_*)" + diff --git a/slurm_jobs/vision_prune/run_resnet18_cifar10_ablation.sh b/slurm_jobs/vision_prune/run_resnet18_cifar10_ablation.sh new file mode 100644 index 00000000..202d7506 --- /dev/null +++ b/slurm_jobs/vision_prune/run_resnet18_cifar10_ablation.sh @@ -0,0 +1,54 @@ +#!/bin/bash +#SBATCH --job-name=vision_r18_ablation +#SBATCH --output=logs/vision_r18_ablation_%j.out +#SBATCH --error=logs/vision_r18_ablation_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=6:00:00 +#SBATCH --mem=96GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +# 
---------------------------------------------------------------------------- +# ResNet-18 ablation at 50% sparsity: +# - cluster_aware (full) +# - cluster_aware_no_halo (lambda=0) +# - cluster_aware_no_constraints +# - composite +# ---------------------------------------------------------------------------- + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper Ablation: ResNet-18/CIFAR-10 @ 50% sparsity" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="resnet18_cifar10_cluster_analysis_ablation" \ + pruning_amounts="[0.5]" \ + pruning_distribution="global_threshold" \ + pruning_strategies="['cluster_aware','cluster_aware_no_halo','cluster_aware_no_constraints','composite']" + +echo "" +echo "Done: $(date)" + diff --git a/slurm_jobs/vision_prune/run_resnet18_cifar10_gap.sh b/slurm_jobs/vision_prune/run_resnet18_cifar10_gap.sh new file mode 100644 index 00000000..42e0a9a6 --- /dev/null +++ b/slurm_jobs/vision_prune/run_resnet18_cifar10_gap.sh @@ -0,0 +1,45 @@ +#!/bin/bash +#SBATCH --job-name=vision_r18_gap +#SBATCH --output=logs/vision_r18_gap_%j.out +#SBATCH --error=logs/vision_r18_gap_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=4:00:00 +#SBATCH --mem=64GB +#SBATCH --partition=kempner_eng +#SBATCH 
--account=kempner_dev + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper (Appendix): ResNet-18 GAP robustness run" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="resnet18_cifar10_cluster_analysis_gap" \ + metrics.activation_samples="gap" \ + pruning_amounts="[]" + +echo "" +echo "Done: $(date)" + diff --git a/slurm_jobs/vision_prune/run_resnet50_imagenet100.sh b/slurm_jobs/vision_prune/run_resnet50_imagenet100.sh new file mode 100644 index 00000000..8ab531fe --- /dev/null +++ b/slurm_jobs/vision_prune/run_resnet50_imagenet100.sh @@ -0,0 +1,103 @@ +#!/bin/bash +#SBATCH --job-name=vision_resnet50_imagenet100 +#SBATCH --output=logs/vision_resnet50_imagenet100_%j.out +#SBATCH --error=logs/vision_resnet50_imagenet100_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=12:00:00 +#SBATCH --mem=128GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: ResNet-50 on ImageNet-100" +echo "============================================================================" +echo "Job ID: 
${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +# ---------------------------------------------------------------------------- +# ImageNet-100 data prep +# ---------------------------------------------------------------------------- +# The Kempner shared repository base is documented here: +# /n/holylfs06/LABS/kempner_shared/Everyone/testbed/vision +# (see: https://handbook.eng.kempnerinstitute.harvard.edu/...) +# +# This job expects an ImageFolder-style ImageNet-100 subset at: +# ./data/imagenet100/{train,val}/<synset>/<image> +# If it doesn't exist, we create it by symlinking the first IMAGENET100_NCLASSES (default 100) synsets +# (lexicographic order) from the shared ImageNet-1k. + +IMAGENET1K_ROOT="${IMAGENET1K_ROOT:-/n/holylfs06/LABS/kempner_shared/Everyone/testbed/vision/imagenet_1k}" +IMAGENET100_ROOT="${IMAGENET100_ROOT:-$PWD/data/imagenet100}" +IMAGENET100_NCLASSES="${IMAGENET100_NCLASSES:-100}" + +if [ ! -d "${IMAGENET1K_ROOT}/train" ] || [ ! -d "${IMAGENET1K_ROOT}/val" ]; then + echo "[error] IMAGENET1K_ROOT does not look like ImageFolder (missing train/val): ${IMAGENET1K_ROOT}" + exit 2 +fi + +need_prepare=0 +if [ ! -d "${IMAGENET100_ROOT}/train" ] || [ ! -d "${IMAGENET100_ROOT}/val" ]; then + need_prepare=1 +else + # Detect the "exists but empty" case (e.g., a previous run died mid-setup). + # Use `find -L` so symlinked class dirs count as directories. 
+ n_train=$(find -L "${IMAGENET100_ROOT}/train" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + n_val=$(find -L "${IMAGENET100_ROOT}/val" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + if [ "${n_train}" -lt 1 ] || [ "${n_val}" -lt 1 ]; then + need_prepare=1 + fi +fi + +if [ "${need_prepare}" -eq 1 ]; then + echo "[info] Preparing ImageNet-100 subset under: ${IMAGENET100_ROOT}" + rm -rf "${IMAGENET100_ROOT}/train" "${IMAGENET100_ROOT}/val" + mkdir -p "${IMAGENET100_ROOT}/train" "${IMAGENET100_ROOT}/val" + # Avoid SIGPIPE under `set -o pipefail` by not truncating a pipeline early. + find "${IMAGENET1K_ROOT}/train" -maxdepth 1 -mindepth 1 -type d -printf '%f\n' \ + | sort \ + > "${IMAGENET100_ROOT}/classes_all.txt" + head -n "${IMAGENET100_NCLASSES}" "${IMAGENET100_ROOT}/classes_all.txt" \ + > "${IMAGENET100_ROOT}/classes.txt" + rm -f "${IMAGENET100_ROOT}/classes_all.txt" + while read -r syn; do + ln -sfn "${IMAGENET1K_ROOT}/train/${syn}" "${IMAGENET100_ROOT}/train/${syn}" + ln -sfn "${IMAGENET1K_ROOT}/val/${syn}" "${IMAGENET100_ROOT}/val/${syn}" + done < "${IMAGENET100_ROOT}/classes.txt" + echo "[info] Wrote class list: ${IMAGENET100_ROOT}/classes.txt" + + n_train=$(find -L "${IMAGENET100_ROOT}/train" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + n_val=$(find -L "${IMAGENET100_ROOT}/val" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l || true) + echo "[info] ImageNet-100 class dirs: train=${n_train} val=${n_val}" + if [ "${n_train}" -lt 1 ] || [ "${n_val}" -lt 1 ]; then + echo "[error] ImageNet-100 subset prep failed: no class dirs found under ${IMAGENET100_ROOT}/{train,val}" + exit 3 + fi +fi + +python scripts/run_experiment.py \ + --config configs/vision_prune/resnet50_imagenet100_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "Done: $(date)" +echo "Look under: $OUTPUT_BASE/ (experiment name: resnet50_imagenet100_cluster_analysis_*)" + diff --git 
a/slurm_jobs/vision_prune/run_vgg16_cifar10.sh b/slurm_jobs/vision_prune/run_vgg16_cifar10.sh new file mode 100644 index 00000000..f56899d2 --- /dev/null +++ b/slurm_jobs/vision_prune/run_vgg16_cifar10.sh @@ -0,0 +1,43 @@ +#!/bin/bash +#SBATCH --job-name=vision_vgg16_cifar10 +#SBATCH --output=logs/vision_vgg16_cifar10_%j.out +#SBATCH --error=logs/vision_vgg16_cifar10_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=6:00:00 +#SBATCH --mem=96GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper: VGG-16-BN on CIFAR-10" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python scripts/run_experiment.py \ + --config configs/vision_prune/vgg16_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "Done: $(date)" +echo "Look under: $OUTPUT_BASE/ (experiment name: vgg16_cifar10_cluster_analysis_*)" + diff --git a/slurm_jobs/vision_prune/run_weightsweep_resnet18_array.sh b/slurm_jobs/vision_prune/run_weightsweep_resnet18_array.sh new file mode 100644 index 00000000..4698799b --- /dev/null +++ b/slurm_jobs/vision_prune/run_weightsweep_resnet18_array.sh @@ -0,0 +1,68 @@ +#!/bin/bash +#SBATCH --job-name=vision_r18_wtsweep +#SBATCH --output=logs/vision_r18_wtsweep_%A_%a.out +#SBATCH --error=logs/vision_r18_wtsweep_%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 
+#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=6:00:00 +#SBATCH --mem=96GB +#SBATCH --partition=kempner_eng +#SBATCH --account=kempner_dev +#SBATCH --array=0-14 + +# ---------------------------------------------------------------------------- +# Vision paper sweep: cluster-aware score weight sensitivity +# We sweep (gamma, lambda_halo) while holding alpha=1.0, beta=0.5. +# Each task runs: +# - ResNet-18 / CIFAR-10 +# - method: cluster_aware +# - pruning across multiple sparsity ratios (so figures show the pruning effect) +# ---------------------------------------------------------------------------- + +set -euo pipefail + +GAMMAS=(0.10 0.30 0.50) +LAMBDAS=(0.00 0.25 0.50 0.75 1.00) + +IDX="${SLURM_ARRAY_TASK_ID}" +GI=$((IDX / ${#LAMBDAS[@]})) +LI=$((IDX % ${#LAMBDAS[@]})) + +GAMMA="${GAMMAS[$GI]}" +LAMBDA="${LAMBDAS[$LI]}" + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +echo "============================================================================" +echo "Vision Paper Sweep: ResNet-18 weight sensitivity (gamma=${GAMMA}, lambda=${LAMBDA})" +echo "============================================================================" +echo "Job ID: ${SLURM_JOB_ID:-N/A} Array Task: ${SLURM_ARRAY_TASK_ID}" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +python scripts/run_experiment.py \ + --config configs/vision_prune/resnet18_cifar10_unified.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" \ + name="resnet18_cifar10_weightsweep_g${GAMMA}_l${LAMBDA}" \ + pruning_amounts="[0.1,0.3,0.5,0.7,0.8,0.9]" \ + pruning_distribution="global_threshold" \ + pruning_strategies="['cluster_aware']" \ + pruning.cluster_aware.gamma="${GAMMA}" \ + 
pruning.cluster_aware.lambda_halo="${LAMBDA}" + +echo "" +echo "Done: $(date)" + diff --git a/slurm_jobs/vision_prune/submit_all.sh b/slurm_jobs/vision_prune/submit_all.sh new file mode 100644 index 00000000..dafad40d --- /dev/null +++ b/slurm_jobs/vision_prune/submit_all.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT FULL VISION PAPER: SUITE + APPENDIX + (DEPENDENT) ARTIFACT BUILD JOB +# ============================================================================ +# Usage: +# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red" +# bash slurm_jobs/vision_prune/submit_all.sh +# +# This submits: +# - main suite jobs (4 models) +# - appendix jobs (GAP, ablation, weight sweep array, damage prediction) +# - a final build job that runs build_all_artifacts.py after all above finish (afterany: it runs even if some of them fail) +# ============================================================================ + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +# Guardrail: avoid accidentally writing into the repo via a relative placeholder. +if [[ "$OUTPUT_BASE" != /* ]]; then + echo "[error] OUTPUT_BASE must be an absolute path. 
Got: $OUTPUT_BASE" + echo "[hint] Use: export OUTPUT_BASE=\"/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red\"" + exit 1 +fi +mkdir -p "$OUTPUT_BASE" + +echo "==============================================" +echo "Submitting Vision Paper: ALL jobs" +echo "==============================================" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OUTPUT_BASE + +# ---------------------------- +# Main suite +# ---------------------------- +JOB_R18=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet18_cifar10.sh | awk '{print $4}') +echo "ResNet-18/CIFAR-10: $JOB_R18" + +JOB_VGG=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_vgg16_cifar10.sh | awk '{print $4}') +echo "VGG-16-BN/CIFAR-10: $JOB_VGG" + +JOB_MBV2=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_mobilenetv2_cifar10.sh | awk '{print $4}') +echo "MobileNetV2/CIFAR-10: $JOB_MBV2" + +JOB_R50=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet50_imagenet100.sh | awk '{print $4}') +echo "ResNet-50/ImageNet-100: $JOB_R50" + +# ---------------------------- +# Appendix / robustness +# ---------------------------- +JOB_GAP=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet18_cifar10_gap.sh | awk '{print $4}') +echo "GAP robustness (ResNet-18): $JOB_GAP" + +JOB_ABL=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet18_cifar10_ablation.sh | awk '{print $4}') +echo "Ablation (ResNet-18 @ 50%): $JOB_ABL" + +JOB_WS=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_weightsweep_resnet18_array.sh | awk '{print $4}') +echo "Weight sweep array (ResNet-18): $JOB_WS" + +# Damage prediction should wait for the main ResNet-18 run (needs its checkpoint/results). 
+JOB_DP=$(sbatch --dependency=afterok:${JOB_R18} --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_damage_prediction_resnet18.sh | awk '{print $4}') +echo "Damage prediction eval (ResNet-18, afterok:$JOB_R18): $JOB_DP" + +# ---------------------------- +# Final artifact build job (depends on all above) +# ---------------------------- +DEP="afterany:${JOB_R18}:${JOB_VGG}:${JOB_MBV2}:${JOB_R50}:${JOB_GAP}:${JOB_ABL}:${JOB_WS}:${JOB_DP}" +JOB_BUILD=$(sbatch --dependency=$DEP --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/build_artifacts.sh | awk '{print $4}') +echo "Build all artifacts (afterany:all): $JOB_BUILD" + +echo "" +echo "==============================================" +echo "All jobs submitted" +echo "==============================================" +echo "Monitor with: squeue -u \$USER" +echo "" + diff --git a/slurm_jobs/vision_prune/submit_all_array.sh b/slurm_jobs/vision_prune/submit_all_array.sh new file mode 100644 index 00000000..1b637b86 --- /dev/null +++ b/slurm_jobs/vision_prune/submit_all_array.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT FULL VISION PAPER: ONE ARRAY JOB (MAX 16 GPUs) + DEPENDENT BUILD JOB +# ============================================================================ +# Usage: +# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red" +# bash slurm_jobs/vision_prune/submit_all_array.sh +# +# What this does: +# - Submits ONE array job that runs all suite + appendix runs +# - Caps concurrency to 16 tasks (== 16 GPUs, assuming 1 GPU per task) +# - Schedules build_artifacts.sh after the array completes (afterany) +# ============================================================================ + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +if [[ "$OUTPUT_BASE" != /* 
]]; then + echo "[error] OUTPUT_BASE must be an absolute path. Got: $OUTPUT_BASE" + echo "[hint] Use: export OUTPUT_BASE=\"/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red\"" + exit 1 +fi +mkdir -p "$OUTPUT_BASE" + +echo "==============================================" +echo "Submitting Vision Paper: ALL runs as ONE array" +echo "==============================================" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OUTPUT_BASE + +# 21 tasks total (0..20). Concurrency cap: 16 GPUs max. +JOB_ALL=$(sbatch --array=0-20%16 --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_all_array.sh | awk '{print $4}') +echo "Array job (0-20%16): $JOB_ALL" + +# Build job after the array finishes (even if some tasks fail). +JOB_BUILD=$(sbatch --dependency=afterany:${JOB_ALL} --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/build_artifacts.sh | awk '{print $4}') +echo "Build all artifacts (afterany:${JOB_ALL}): $JOB_BUILD" + +echo "" +echo "==============================================" +echo "Submitted." 
+echo "==============================================" +echo "Monitor:" +echo " squeue -u $USER" +echo " sacct -j ${JOB_ALL} --format=JobID,State,ExitCode,Elapsed" +echo "" + diff --git a/slurm_jobs/vision_prune/submit_appendix.sh b/slurm_jobs/vision_prune/submit_appendix.sh new file mode 100644 index 00000000..31f32304 --- /dev/null +++ b/slurm_jobs/vision_prune/submit_appendix.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT VISION PAPER APPENDIX SUITE (robustness + sweeps) +# ============================================================================ +# Usage: +# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red" +# bash slurm_jobs/vision_prune/submit_appendix.sh +# ============================================================================ + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +if [[ "$OUTPUT_BASE" != /* ]]; then + echo "[error] OUTPUT_BASE must be an absolute path. 
Got: $OUTPUT_BASE" + echo "[hint] Use: export OUTPUT_BASE=\"/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red\"" + exit 1 +fi +mkdir -p "$OUTPUT_BASE" + +echo "==============================================" +echo "Submitting Vision Paper Appendix Suite" +echo "==============================================" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OUTPUT_BASE + +JOB_GAP=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet18_cifar10_gap.sh | awk '{print $4}') +echo "GAP robustness (ResNet-18): $JOB_GAP" + +JOB_ABL=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet18_cifar10_ablation.sh | awk '{print $4}') +echo "Ablation (ResNet-18 @ 50%): $JOB_ABL" + +JOB_WS=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_weightsweep_resnet18_array.sh | awk '{print $4}') +echo "Weight sweep array (ResNet-18): $JOB_WS" + +JOB_DP=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_damage_prediction_resnet18.sh | awk '{print $4}') +echo "Damage prediction eval (ResNet-18): $JOB_DP" + +echo "" +echo "==============================================" +echo "Appendix jobs submitted" +echo "==============================================" +echo "Monitor with: squeue -u \$USER" +echo "" + diff --git a/slurm_jobs/vision_prune/submit_suite.sh b/slurm_jobs/vision_prune/submit_suite.sh new file mode 100644 index 00000000..29be8dc2 --- /dev/null +++ b/slurm_jobs/vision_prune/submit_suite.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# ============================================================================ +# SUBMIT FULL VISION PAPER SUITE +# ============================================================================ +# Usage: +# cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +# export OUTPUT_BASE="/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red" +# bash 
slurm_jobs/vision_prune/submit_suite.sh +# ============================================================================ + +set -euo pipefail + +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red}" + +if [[ "$OUTPUT_BASE" != /* ]]; then + echo "[error] OUTPUT_BASE must be an absolute path. Got: $OUTPUT_BASE" + echo "[hint] Use: export OUTPUT_BASE=\"/n/holylfs06/LABS/kempner_project_b/Lab/alignment/alignment_red\"" + exit 1 +fi +mkdir -p "$OUTPUT_BASE" + +echo "==============================================" +echo "Submitting Vision Paper Suite" +echo "==============================================" +echo "OUTPUT_BASE: $OUTPUT_BASE" +echo "" + +cd /n/holylabs/kempner_dev/Users/hsafaai/Code/alignment +mkdir -p logs + +export OUTPUT_BASE + +JOB_R18=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet18_cifar10.sh | awk '{print $4}') +echo "ResNet-18/CIFAR-10: $JOB_R18" + +JOB_VGG=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_vgg16_cifar10.sh | awk '{print $4}') +echo "VGG-16-BN/CIFAR-10: $JOB_VGG" + +JOB_MBV2=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_mobilenetv2_cifar10.sh | awk '{print $4}') +echo "MobileNetV2/CIFAR-10: $JOB_MBV2" + +JOB_R50=$(sbatch --export=ALL,OUTPUT_BASE="$OUTPUT_BASE" slurm_jobs/vision_prune/run_resnet50_imagenet100.sh | awk '{print $4}') +echo "ResNet-50/ImageNet-100: $JOB_R50" + +echo "" +echo "==============================================" +echo "All suite jobs submitted" +echo "==============================================" +echo "Monitor with: squeue -u \$USER" +echo "" + diff --git a/src/alignment/analysis/cascade_analysis.py b/src/alignment/analysis/cascade_analysis.py index 06d6a53b..2b988e3b 100644 --- a/src/alignment/analysis/cascade_analysis.py +++ b/src/alignment/analysis/cascade_analysis.py @@ -129,7 +129,9 @@ def compute_damages(self, n_ch: int, frac: float = 0.2) -> np.ndarray: test_idx = 
np.random.choice(n_ch, max(1, int(n_ch * frac)), replace=False) for i in test_idx: r = self.cascade.ablate(self.layer, [int(i)]) - damages[i] = r.accuracy_drop + # Use loss increase as a smoother "damage" signal than accuracy drop, + # especially when evaluating on a small test subset. + damages[i] = r.loss_increase self._damages = damages return damages @@ -143,12 +145,19 @@ def evaluate(self, scores: np.ndarray, method: str = "composite", if mask.sum() < 5: return DamageResult(self.layer, method, 0., {}) d, s = self._damages[mask], scores[mask] + # In the paper scripts we treat `scores` as a *prune score* where higher + # means "safer to remove". A good prune score should correlate with + # *lower* damage; we therefore correlate against -d so higher rho is better. rho, _ = stats.spearmanr(s, -d) recall = {} - by_d = np.argsort(-d) - by_s = np.argsort(s) + # Recall@k: how well the prune score identifies the least-damaging channels. + by_d = np.argsort(d) # least damaging first + by_s = np.argsort(-s) # highest prune score first for k in top_ks: - k = min(k, len(d)) - overlap = len(set(by_d[:k]) & set(by_s[:k])) - recall[k] = overlap / k if k > 0 else 0. + # Keep the dictionary key as the *requested* k for stable downstream + # table formatting, but clamp the effective k to the number of + # evaluated channels. 
+ k_eff = min(int(k), len(d)) + overlap = len(set(by_d[:k_eff]) & set(by_s[:k_eff])) + recall[int(k)] = overlap / k_eff if k_eff > 0 else 0.0 return DamageResult(self.layer, method, float(rho) if not np.isnan(rho) else 0., recall) diff --git a/src/alignment/analysis/visualization/paper_plots.py b/src/alignment/analysis/visualization/paper_plots.py index f5b03b90..2b8f57bb 100644 --- a/src/alignment/analysis/visualization/paper_plots.py +++ b/src/alignment/analysis/visualization/paper_plots.py @@ -124,8 +124,9 @@ def plot_halo_structure( max_points: int = 60000, ) -> plt.Figure: """ - Two-panel plot: + Three-panel plot: (Left) Conn vs redundancy-to-core (halo channels) + (Middle) Redundancy-to-core distribution: halo vs non-halo (sample where defined) (Right) Protect vs Conn (all channels; halo emphasized) """ conn_np = _to_numpy(conn).astype(np.float64).reshape(-1) @@ -149,7 +150,7 @@ def plot_halo_structure( idx_non = idx_all[(~halo_np[idx_all]) & (~super_np[idx_all])] idx_sup = idx_all[super_np[idx_all]] - fig, axes = plt.subplots(1, 2, figsize=(12, 4.2)) + fig, axes = plt.subplots(1, 3, figsize=(15, 4.2)) # Panel A: Conn vs redundancy-to-core (halo only, since redundancy is defined there) ax = axes[0] @@ -169,8 +170,50 @@ def plot_halo_structure( if y.size > 0 and np.nanmin(y) > 0: ax.set_yscale("log") - # Panel B: Protect vs Conn (all channels) + # Panel B: Redundancy-to-core distribution comparison (halo vs non-halo sample) ax = axes[1] + y_h = red_np[idx_halo] + y_n = red_np[idx_non] + y_h = y_h[np.isfinite(y_h)] + y_n = y_n[np.isfinite(y_n)] + + if y_h.size == 0 or y_n.size == 0: + ax.text( + 0.5, + 0.5, + "Redundancy-to-core\n(non-halo sample unavailable)", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=10, + color="#2c3e50", + ) + ax.set_axis_off() + else: + bp = ax.boxplot( + [y_h, y_n], + vert=True, + patch_artist=True, + showfliers=False, + medianprops=dict(color="#2c3e50", linewidth=2), + boxprops=dict(linewidth=1.2, 
color="#2c3e50"), + whiskerprops=dict(linewidth=1.2, color="#2c3e50"), + capprops=dict(linewidth=1.2, color="#2c3e50"), + ) + colors = ["#1f77b4", "#7f8c8d"] + for patch, c in zip(bp.get("boxes", []), colors): + patch.set_facecolor(c) + patch.set_alpha(0.75) + + ax.set_xticklabels([f"Halo\n(n={y_h.size})", f"Non-halo\n(sample, n={y_n.size})"]) + ax.set_ylabel(r"Redundancy to core $\mathrm{Red}^{\rightarrow \mathcal{M}}$") + ax.set_title("Halo vs non-halo\nredundancy-to-core") + ax.grid(True, alpha=0.25) + if y_h.size > 0 and y_n.size > 0 and np.nanmin(np.concatenate([y_h, y_n])) > 0: + ax.set_yscale("log") + + # Panel C: Protect vs Conn (all channels) + ax = axes[2] ax.scatter(conn_np[idx_non], prot_np[idx_non], s=6, alpha=0.15, color="#7f8c8d", label="Non-halo", edgecolors="none") ax.scatter(conn_np[idx_halo], prot_np[idx_halo], s=10, alpha=0.35, color="#1f77b4", label="Halo", edgecolors="none") if idx_sup.size > 0: diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index dc426343..6524bd62 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -365,7 +365,7 @@ def evaluate_multiple_metrics( except Exception as e: logger.error(f"Failed to evaluate perplexity (shared): {e}") ppl_cached = None - + for metric in metrics: num_fewshot = fewshot_settings.get(metric, 0) try: @@ -1558,11 +1558,27 @@ def _wrap_existing_hf_model(self) -> None: logger.info(f"Loading tokenizer for existing HF causal LM '{model_id}'") tokenizer = AutoTokenizer.from_pretrained(model_id, **self.config.tokenizer_kwargs) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + # Ensure the tokenizer can pad batched inputs (needed by several analysis utilities). + # For causal LMs, padding with EOS is standard; fall back to BOS/UNK if EOS is unavailable. 
+ added_special = 0 + if getattr(tokenizer, "pad_token", None) is None or getattr(tokenizer, "pad_token_id", None) is None: + if getattr(tokenizer, "eos_token", None) is not None: + tokenizer.pad_token = tokenizer.eos_token + elif getattr(tokenizer, "bos_token", None) is not None: + tokenizer.pad_token = tokenizer.bos_token + elif getattr(tokenizer, "unk_token", None) is not None: + tokenizer.pad_token = tokenizer.unk_token + if getattr(tokenizer, "pad_token", None) is None or getattr(tokenizer, "pad_token_id", None) is None: + # Last resort: add a PAD token (should almost never trigger for Llama-family tokenizers). + added_special = tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # Unwrap underlying HF model if we're holding a small wrapper (e.g., HFCausalLM) hf_model = getattr(self.model, "model", self.model) + if added_special > 0: + try: + hf_model.resize_token_embeddings(len(tokenizer)) + except Exception: + pass # Wrap with TransformerWrapper (expects an nn.Module) wrapper_kwargs = {"tracked_layers": getattr(self.config, "tracked_layers", None)} @@ -1588,8 +1604,19 @@ def _load_hf_tokenizer_and_model(self) -> None: logger.info(f"Loading tokenizer for {model_id}") tokenizer = AutoTokenizer.from_pretrained(model_id, **self.config.tokenizer_kwargs) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + # Ensure the tokenizer can pad batched inputs (needed by several analysis utilities). + # For causal LMs, padding with EOS is standard; fall back to BOS/UNK if EOS is unavailable. 
+ added_special = 0 + if getattr(tokenizer, "pad_token", None) is None or getattr(tokenizer, "pad_token_id", None) is None: + if getattr(tokenizer, "eos_token", None) is not None: + tokenizer.pad_token = tokenizer.eos_token + elif getattr(tokenizer, "bos_token", None) is not None: + tokenizer.pad_token = tokenizer.bos_token + elif getattr(tokenizer, "unk_token", None) is not None: + tokenizer.pad_token = tokenizer.unk_token + if getattr(tokenizer, "pad_token", None) is None or getattr(tokenizer, "pad_token_id", None) is None: + # Last resort: add a PAD token (should almost never trigger for Llama-family tokenizers). + added_special = tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # load model config and model with dtype/device options model_kwargs = dict(self.config.model_kwargs or {}) @@ -1606,6 +1633,11 @@ def _load_hf_tokenizer_and_model(self) -> None: logger.info(f"Loading HF model {model_id} with dtype={torch_dtype} device_map={device_map}") hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map=device_map, **model_kwargs) + if added_special > 0: + try: + hf_model.resize_token_embeddings(len(tokenizer)) + except Exception: + pass # Move model to explicit device if device_map not used if device_map is None: @@ -2412,6 +2444,7 @@ def _resolve_mlp_path(layer_idx: int) -> Optional[str]: ) except Exception as e: logger.warning(f"Failed to compute Wanda channel scores for {mlp_path}: {e}") + continue logger.info(f"Wanda: computed channel scores for {len(layer_indices)} MLP layers") except Exception as e: @@ -2471,6 +2504,7 @@ def _resolve_mlp_path(layer_idx: int) -> Optional[str]: ) except Exception as e: logger.warning(f"Failed to compute SparseGPT channel scores for {mlp_path}: {e}") + continue logger.info(f"SparseGPT: computed channel scores for {len(layer_indices)} MLP layers") except Exception as e: @@ -4763,6 +4797,31 @@ def compute_supernode_connectivity_pruning_score( conn = abs_W.index_select(0, 
core_idx).sum(dim=0) / v_norm conn = conn.clamp(0.0, 1.0) + # Optional post-processing to give Conn more dynamic range when needed. + # + # - rank-normalize Conn among non-supernodes (maps to [0,1] by empirical CDF) + # - apply a power transform (power < 1 increases small Conn values; power > 1 shrinks them) + if bool(supernode_cfg.get("connectivity_rank_normalize", False)): + non_super_idx_for_rank = (~super_mask).nonzero(as_tuple=True)[0] + if non_super_idx_for_rank.numel() > 1: + vals = conn[non_super_idx_for_rank] + _, order = torch.sort(vals, stable=True) # ascending + ranks = torch.empty_like(order, dtype=torch.float32) + ranks[order] = torch.arange(order.numel(), dtype=torch.float32) + ranks = ranks / float(max(1, order.numel() - 1)) + conn_rank = conn.clone() + conn_rank[non_super_idx_for_rank] = ranks + conn_rank[super_idx] = 1.0 + conn = conn_rank + + conn_power = supernode_cfg.get("connectivity_power", 1.0) + try: + conn_power_f = float(conn_power) + except Exception: + conn_power_f = 1.0 + if conn_power_f != 1.0: + conn = conn.clamp(0.0, 1.0).pow(conn_power_f).clamp(0.0, 1.0) + # Halo: top eta among non-supernodes by Conn non_super_idx = (~super_mask).nonzero(as_tuple=True)[0] if non_super_idx.numel() == 0: @@ -4772,20 +4831,46 @@ def compute_supernode_connectivity_pruning_score( _, halo_rel = torch.topk(halo_scores, k=num_halo, largest=True) halo_idx = non_super_idx[halo_rel].long() + # Optional: sample a subset of *non-halo* channels for redundancy-to-core analysis. + # This lets us explicitly compare halo-to-core redundancy vs non-halo-to-core redundancy + # without the prohibitive cost of computing redundancy for *all* non-halo channels. 
+ non_halo_sample_size = int(supernode_cfg.get("non_halo_sample_size", 256) or 0) + non_halo_idx = torch.empty((0,), dtype=torch.long) + if non_halo_sample_size > 0: + halo_mask_tmp = torch.zeros(m, dtype=torch.bool) + halo_mask_tmp[halo_idx] = True + non_halo_all = (~super_mask & ~halo_mask_tmp).nonzero(as_tuple=True)[0] + if non_halo_all.numel() > 0: + sample_n = min(non_halo_sample_size, int(non_halo_all.numel())) + seed_base = int(supernode_cfg.get("non_halo_sample_seed", 0) or 0) + try: + layer_idx_int = int(layer_name.split("layers.")[-1].split(".")[0]) + except Exception: + layer_idx_int = 0 + g = torch.Generator() + g.manual_seed(seed_base + layer_idx_int) + perm = torch.randperm(int(non_halo_all.numel()), generator=g) + non_halo_idx = non_halo_all[perm[:sample_n]].long() + plan[layer_name] = { "lp_cpu": lp_cpu, "conn_cpu": conn, "super_idx_cpu": super_idx, "halo_idx_cpu": halo_idx, + "non_halo_idx_cpu": non_halo_idx, "m": m, # device-side indices + streaming sums (initialized lazily in hooks) "super_idx": None, "halo_idx": None, + "non_halo_idx": None, "sum_q_super": None, "sum_q2_super": None, "sum_q_halo": None, "sum_q2_halo": None, "sum_q_halo_super": None, + "sum_q_non_halo": None, + "sum_q2_non_halo": None, + "sum_q_non_halo_super": None, "count": 0, } @@ -4836,13 +4921,18 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: st["super_idx"] = st["super_idx_cpu"].to(device=u_flat.device) if st["halo_idx"] is None or st["halo_idx"].device != u_flat.device: st["halo_idx"] = st["halo_idx_cpu"].to(device=u_flat.device) + if st.get("non_halo_idx") is None or (st.get("non_halo_idx") is not None and st["non_halo_idx"].device != u_flat.device): + st["non_halo_idx"] = st.get("non_halo_idx_cpu", torch.empty((0,), dtype=torch.long)).to(device=u_flat.device) super_idx_dev = st["super_idx"] halo_idx_dev = st["halo_idx"] + non_halo_idx_dev = st.get("non_halo_idx") + if non_halo_idx_dev is None: + non_halo_idx_dev = torch.empty((0,), 
device=u_flat.device, dtype=torch.long) # Compute q = u * s where s := dL/du is already computed by backprop. # We only materialize the supernode+halo indices. - idx_union = torch.cat([super_idx_dev, halo_idx_dev], dim=0) # [|M|+|H|] + idx_union = torch.cat([super_idx_dev, halo_idx_dev, non_halo_idx_dev], dim=0) # [|M|+|H|+|N|] try: u_sel = u_flat.index_select(1, idx_union).float() # [N, |M|+|H|] s_sel = g_u_flat.index_select(1, idx_union).float() # [N, |M|+|H|] @@ -4851,8 +4941,10 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: q_sel = u_sel * s_sel # [N, |M|+|H|] n_super = super_idx_dev.numel() + n_halo = halo_idx_dev.numel() q_super = q_sel[:, :n_super] # [N, |M|] - q_halo = q_sel[:, n_super:] # [N, |H|] + q_halo = q_sel[:, n_super : n_super + n_halo] # [N, |H|] + q_non_halo = q_sel[:, n_super + n_halo :] # [N, |N|] N = q_sel.shape[0] @@ -4865,12 +4957,21 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: st["sum_q_halo_super"] = torch.zeros( (q_halo.shape[1], q_super.shape[1]), device=q_halo.device, dtype=torch.float32 ) + st["sum_q_non_halo"] = torch.zeros(q_non_halo.shape[1], device=q_non_halo.device, dtype=torch.float32) + st["sum_q2_non_halo"] = torch.zeros_like(st["sum_q_non_halo"]) + st["sum_q_non_halo_super"] = torch.zeros( + (q_non_halo.shape[1], q_super.shape[1]), device=q_non_halo.device, dtype=torch.float32 + ) st["sum_q_super"] += q_super.sum(dim=0) st["sum_q2_super"] += (q_super * q_super).sum(dim=0) st["sum_q_halo"] += q_halo.sum(dim=0) st["sum_q2_halo"] += (q_halo * q_halo).sum(dim=0) st["sum_q_halo_super"] += q_halo.transpose(0, 1) @ q_super # [|H|,|M|] + if q_non_halo.numel() > 0: + st["sum_q_non_halo"] += q_non_halo.sum(dim=0) + st["sum_q2_non_halo"] += (q_non_halo * q_non_halo).sum(dim=0) + st["sum_q_non_halo_super"] += q_non_halo.transpose(0, 1) @ q_super # [|N|,|M|] st["count"] += N return fwd_hook, bwd_hook @@ -4924,6 +5025,8 @@ def bwd_hook(mod: nn.Module, grad_input: 
Tuple[torch.Tensor, ...], grad_output: # ------------------------------------------------------------------ # Phase 3: Compute Protect + final importance scores; store into importance_scores # ------------------------------------------------------------------ + agg_red_halo: List[float] = [] + agg_red_non_halo: List[float] = [] for layer_name, st in plan.items(): N = int(st.get("count", 0)) if N <= 1 or st["sum_q_halo_super"] is None: @@ -4953,6 +5056,31 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: redundancy_to_core = mi.max(dim=1).values # [|H|] + # Optional: redundancy-to-core for a sampled set of non-halo channels (analysis only). + redundancy_to_core_non_halo = None + non_halo_idx_cpu = st.get("non_halo_idx_cpu", None) + if ( + non_halo_idx_cpu is not None + and hasattr(non_halo_idx_cpu, "numel") + and int(non_halo_idx_cpu.numel()) > 0 + and st.get("sum_q_non_halo_super") is not None + ): + sum_q_non = st["sum_q_non_halo"].detach().cpu() + sum_q2_non = st["sum_q2_non_halo"].detach().cpu() + sum_q_non_super = st["sum_q_non_halo_super"].detach().cpu() + + mean_non = sum_q_non / float(N) + cov_non = (sum_q_non_super / float(N)) - (mean_non.unsqueeze(1) * mean_super.unsqueeze(0)) + var_non = (sum_q2_non / float(N)) - (mean_non * mean_non) + denom_non = torch.sqrt(var_non.clamp_min(0).unsqueeze(1) * var_super.clamp_min(0).unsqueeze(0) + eps) + corr_non = torch.where(denom_non > 0, cov_non / denom_non, torch.zeros_like(cov_non)) + corr_non = corr_non.clamp(-0.9999, 0.9999) + + corr_eff_non = torch.clamp(corr_non, min=0.0) if positive_redundancy else corr_non + rho_sq_non = (corr_eff_non * corr_eff_non).clamp(0.0, 0.9999) + mi_non = -0.5 * torch.log(1 - rho_sq_non) + redundancy_to_core_non_halo = mi_non.max(dim=1).values # [|N|] + # Convert redundancy-to-core into a [0, 1] protection score. 
# # Empirically, redundancy magnitudes can be extremely small; min-max normalization @@ -5023,6 +5151,11 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: redundancy_full[halo_idx] = redundancy_to_core.float() except Exception: pass + if redundancy_to_core_non_halo is not None and non_halo_idx_cpu is not None: + try: + redundancy_full[non_halo_idx_cpu] = redundancy_to_core_non_halo.float() + except Exception: + pass # SCAR-Prot and SCAR-Conn importance scores (high=keep) prot_score = (lp * protect_full).float() @@ -5053,12 +5186,52 @@ def bwd_hook(mod: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: results[layer_name] = { "num_supernodes": int(super_idx.numel()), "num_halo": int(halo_idx.numel()), + "num_non_halo_sample": int(non_halo_idx_cpu.numel()) if non_halo_idx_cpu is not None else 0, "q_samples": N, "conn_mean": float(conn.mean().item()), "protect_halo_mean": float(protect_halo.mean().item()) if protect_halo.numel() else 0.0, "redundancy_to_core_mean": float(redundancy_to_core.mean().item()) if redundancy_to_core.numel() else 0.0, + "non_halo_redundancy_to_core_mean": float(redundancy_to_core_non_halo.mean().item()) + if redundancy_to_core_non_halo is not None and redundancy_to_core_non_halo.numel() + else 0.0, } + + # Aggregate distributions (for tables / sanity checks) + try: + halo_vals = redundancy_to_core.detach().float() + halo_vals = halo_vals[torch.isfinite(halo_vals)] + agg_red_halo.extend([float(x) for x in halo_vals.tolist() if x == x]) + except Exception: + pass + if redundancy_to_core_non_halo is not None: + try: + non_vals = redundancy_to_core_non_halo.detach().float() + non_vals = non_vals[torch.isfinite(non_vals)] + agg_red_non_halo.extend([float(x) for x in non_vals.tolist() if x == x]) + except Exception: + pass + # Add aggregate stats for paper tables (useful even when per-layer values are noisy). 
+ if agg_red_halo or agg_red_non_halo: + def _stats(vals: List[float]) -> Dict[str, Any]: + arr = np.asarray(vals, dtype=np.float64) + arr = arr[np.isfinite(arr)] + if arr.size == 0: + return {"n": 0, "mean": None, "std": None, "median": None} + return { + "n": int(arr.size), + "mean": float(arr.mean()), + "std": float(arr.std()), + "median": float(np.median(arr)), + } + + results["_aggregate"] = { + "redundancy_to_core": { + "halo": _stats(agg_red_halo), + "non_halo_sample": _stats(agg_red_non_halo), + } + } + logger.info(f"Computed SCAR protection/connectivity scores for {len(results)} layers") return results @@ -6459,7 +6632,7 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm if core_mask is not None: break except Exception: - core_mask = self.importance_scores[layer_name].get("supernode_mask") + core_mask = (self.importance_scores.get(layer_name) or {}).get("supernode_mask") if core_mask is not None and self._should_protect_supernodes_for_metric(metric): margin = torch.abs(scores).max().detach().item() + 1.0 if mode == "low": diff --git a/src/alignment/pruning/dependency_aware.py b/src/alignment/pruning/dependency_aware.py index fca4005c..ce4bf844 100644 --- a/src/alignment/pruning/dependency_aware.py +++ b/src/alignment/pruning/dependency_aware.py @@ -291,26 +291,61 @@ def _create_weight_mask(self, module: nn.Module, out_mask: torch.Tensor, in_mask # Conv2d: [out_channels, in_channels, k_h, k_w] expected_out = module.out_channels expected_in = module.in_channels + groups = int(getattr(module, "groups", 1)) + # Weight shape is [out_channels, in_channels/groups, k_h, k_w] + in_per_group = int(module.weight.shape[1]) # Handle dimension mismatch (e.g., from skip connections or downsample layers) if out_mask.shape[0] != expected_out: out_mask = torch.ones(expected_out, dtype=torch.bool, device=device) if in_mask.shape[0] != expected_in: in_mask = torch.ones(expected_in, dtype=torch.bool, device=device) - - weight_mask = 
(out_mask.view(-1, 1, 1, 1) & in_mask.view(1, -1, 1, 1)).expand_as(module.weight) + + # Grouped/depthwise conv needs special handling: the weight's "in channel" + # dimension is in_channels/groups, and the input mask must be applied per-group. + if groups > 1: + out_per_group = expected_out // groups if groups > 0 else expected_out + # Build a per-output-channel view of the relevant slice of in_mask + in_mask_per_out = torch.empty((expected_out, in_per_group), dtype=torch.bool, device=device) + for g in range(groups): + out_slice = slice(g * out_per_group, min((g + 1) * out_per_group, expected_out)) + in_slice = slice(g * in_per_group, min((g + 1) * in_per_group, expected_in)) + # Fallback to all-True if something is off (robustness) + if (in_slice.stop - in_slice.start) != in_per_group: + in_mask_per_out[out_slice] = torch.ones((out_slice.stop - out_slice.start, in_per_group), dtype=torch.bool, device=device) + else: + in_mask_per_out[out_slice] = in_mask[in_slice].view(1, -1).expand(out_slice.stop - out_slice.start, -1) + + weight_mask = (out_mask.view(-1, 1, 1, 1) & in_mask_per_out.view(expected_out, in_per_group, 1, 1)).expand_as(module.weight) + else: + weight_mask = (out_mask.view(-1, 1, 1, 1) & in_mask.view(1, -1, 1, 1)).expand_as(module.weight) elif isinstance(module, nn.Conv1d): # Conv1d: [out_channels, in_channels, k] expected_out = module.out_channels expected_in = module.in_channels + groups = int(getattr(module, "groups", 1)) + in_per_group = int(module.weight.shape[1]) if out_mask.shape[0] != expected_out: out_mask = torch.ones(expected_out, dtype=torch.bool, device=device) if in_mask.shape[0] != expected_in: in_mask = torch.ones(expected_in, dtype=torch.bool, device=device) - - weight_mask = (out_mask.view(-1, 1, 1) & in_mask.view(1, -1, 1)).expand_as(module.weight) + + if groups > 1: + out_per_group = expected_out // groups if groups > 0 else expected_out + in_mask_per_out = torch.empty((expected_out, in_per_group), dtype=torch.bool, device=device) 
+ for g in range(groups): + out_slice = slice(g * out_per_group, min((g + 1) * out_per_group, expected_out)) + in_slice = slice(g * in_per_group, min((g + 1) * in_per_group, expected_in)) + if (in_slice.stop - in_slice.start) != in_per_group: + in_mask_per_out[out_slice] = torch.ones((out_slice.stop - out_slice.start, in_per_group), dtype=torch.bool, device=device) + else: + in_mask_per_out[out_slice] = in_mask[in_slice].view(1, -1).expand(out_slice.stop - out_slice.start, -1) + + weight_mask = (out_mask.view(-1, 1, 1) & in_mask_per_out.view(expected_out, in_per_group, 1)).expand_as(module.weight) + else: + weight_mask = (out_mask.view(-1, 1, 1) & in_mask.view(1, -1, 1)).expand_as(module.weight) else: # Default: mask entire weight tensor diff --git a/src/alignment/services/mask_ops.py b/src/alignment/services/mask_ops.py index 3efd0747..2577a023 100644 --- a/src/alignment/services/mask_ops.py +++ b/src/alignment/services/mask_ops.py @@ -271,6 +271,24 @@ def threshold_fn(s): def threshold_fn(s): return s <= threshold + elif mode == "random": + # Random global keep mask (ignores scores). Useful for a random baseline + # under global-threshold pruning. 
+ perm = torch.randperm(num_total, device=all_scores_cat.device) + keep_idx = perm[:num_to_keep] + keep_mask_flat = torch.zeros(num_total, dtype=torch.bool, device=all_scores_cat.device) + keep_mask_flat[keep_idx] = True + + masks = {} + offset = 0 + for layer_name, shape, n in layer_info: + mask_flat = keep_mask_flat[offset : offset + n] + masks[layer_name] = mask_flat.view(shape) + offset += n + + logger.info(f"Global random masking: {num_to_keep}/{num_total} elements kept") + return masks + else: raise ValueError(f"Global thresholding not supported for mode: {mode}") From 8cf6e3237110b4ee927580bf6bad40c133fdfcd6 Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Wed, 14 Jan 2026 12:24:13 -0500 Subject: [PATCH 13/15] clean files --- configs/prune_llm/llama3_8b_full.yaml | 1 + configs/prune_llm/llama3_8b_random_only.yaml | 87 +++++ .../prune_llm/run_llama3_8b_random_only.sh | 76 +++++ .../analysis/visualization/__init__.py | 13 - .../analysis/visualization/halo_plots.py | 2 +- .../analysis/visualization/paper_plots.py | 285 +++++++---------- src/alignment/experiments/llm_experiments.py | 300 +++++++++++++++++- src/alignment/metrics/__init__.py | 2 +- src/alignment/metrics/cross_layer.py | 2 +- src/alignment/metrics/halo_redundancy.py | 6 +- .../metrics/information/gaussian_mi.py | 2 +- src/alignment/metrics/multi_supernode.py | 2 +- .../pruning/strategies/cluster_aware.py | 9 +- 13 files changed, 572 insertions(+), 215 deletions(-) create mode 100644 configs/prune_llm/llama3_8b_random_only.yaml create mode 100644 slurm_jobs/prune_llm/run_llama3_8b_random_only.sh diff --git a/configs/prune_llm/llama3_8b_full.yaml b/configs/prune_llm/llama3_8b_full.yaml index 5f04a420..0a607614 100644 --- a/configs/prune_llm/llama3_8b_full.yaml +++ b/configs/prune_llm/llama3_8b_full.yaml @@ -107,6 +107,7 @@ llm: - "accuracy_hellaswag" - "accuracy_arc_easy" - "accuracy_arc_challenge" + - "accuracy_openbookqa" # Common Sense - "accuracy_winogrande" diff --git 
a/configs/prune_llm/llama3_8b_random_only.yaml b/configs/prune_llm/llama3_8b_random_only.yaml new file mode 100644 index 00000000..1a23a29d --- /dev/null +++ b/configs/prune_llm/llama3_8b_random_only.yaml @@ -0,0 +1,87 @@ +# ============================================================================ +# LLAMA-3.1-8B RANDOM (CHANNEL) BASELINE +# ============================================================================ +# +# Purpose: +# - Fill the missing "Random (channel)" baseline row in paper tables. +# - Run ONLY one pruning strategy (random) at 50% sparsity. +# - Keep evaluation protocol consistent with the main paper run (few-shot settings, ppl protocol). +# +# This is intentionally lightweight: we skip SCAR analyses/plots and only do: +# - Baseline eval +# - Random structured channel pruning @ 50% +# - Post-prune eval +# ============================================================================ + +experiment: + name: "llama3_8b_paper_results_random" + type: "llm_alignment" + output_dir: "./results/paper/llama3_8b_random" + seed: 42 + device: "cuda" + save_activations: false + num_networks: 1 + +model: + name: "hf_causal_lm" + model_id: "meta-llama/Llama-3.1-8B" + dtype: "bfloat16" + device_map: "auto" + trust_remote_code: true + +dataset: + name: "wikitext" + batch_size: 1 + num_workers: 0 + +llm: + evaluate_perplexity: true + evaluation_num_samples: 100 + use_nvidia_fewshot: true + perplexity_protocol: "oats" + wikitext_subset: "wikitext-2-raw-v1" + perplexity_seq_len: 2048 + + evaluation_metrics: + - "perplexity" + - "accuracy_openbookqa" + - "accuracy_mmlu" + - "accuracy_hellaswag" + - "accuracy_piqa" + - "accuracy_boolq" + - "accuracy_winogrande" + - "accuracy_arc_easy" + - "accuracy_arc_challenge" + +# Disable heavy analyses for this baseline-only run +analysis: + generate_plots: false + save_scores: false + +do_scar_metrics: false +do_directed_redundancy: false +do_connectivity_pruning: false +do_halo_analysis: false +do_generalized_importance: 
false + +pruning: + enabled: true + target: "ffn" + structured: true + dependency_aware: true + distribution: "uniform" + min_per_layer: 0.0 + max_per_layer: 0.95 + + # Single point needed for table_full_benchmarks_50 + sparsity_levels: [0.5] + + # Random structured pruning: selection done by mode="random" + selection_modes: ["random"] + + # Only one strategy for this run; scores are generated in-code (deterministic). + algorithms: + - "random" + + single_strategy: "random" + diff --git a/slurm_jobs/prune_llm/run_llama3_8b_random_only.sh b/slurm_jobs/prune_llm/run_llama3_8b_random_only.sh new file mode 100644 index 00000000..c00e56f4 --- /dev/null +++ b/slurm_jobs/prune_llm/run_llama3_8b_random_only.sh @@ -0,0 +1,76 @@ +#!/bin/bash +#SBATCH --job-name=paper_llama3_random +#SBATCH --output=logs/paper_llama3_random_%j.out +#SBATCH --error=logs/paper_llama3_random_%j.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --time=04:00:00 +#SBATCH --mem=96GB +#SBATCH --partition=kempner_h100_priority3 +#SBATCH --account=kempner_dev + +set -euo pipefail + +echo "============================================================================" +echo "SCAR Paper: LLaMA-3.1-8B Random (channel) baseline" +echo "============================================================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $(hostname)" +echo "Start time: $(date)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)" +OUTPUT_BASE="${OUTPUT_BASE:-/n/holylfs06/LABS/kempner_project_b/Lab/alignment/Prune_LLM}" +echo "Output Base: $OUTPUT_BASE" +echo "" + +module purge +module load cuda/12.2.0-fasrc01 +eval "$(conda shell.bash hook)" +conda activate networkAlignmentAnalysis + +cd "${SLURM_SUBMIT_DIR:-/n/holylabs/kempner_dev/Users/hsafaai/Code/alignment}" +mkdir -p logs + +export PYTHONPATH="${PWD}:${PWD}/src:${PYTHONPATH:-}" +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export 
TOKENIZERS_PARALLELISM=false +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +if [[ -z "${HF_HOME:-}" ]]; then + if [[ -f "${OUTPUT_BASE}/huggingface_cache/token" ]]; then + export HF_HOME="${OUTPUT_BASE}/huggingface_cache" + elif [[ -d /n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache ]]; then + export HF_HOME="/n/holyscratch01/kempner_dev/Users/hsafaai/huggingface_cache" + else + export HF_HOME="/n/home13/hsafaai/.cache/huggingface" + fi +fi +HF_TOKEN_FILE="${HF_HOME}/token" +if [[ -f "$HF_TOKEN_FILE" ]]; then + export HF_TOKEN="$(cat "$HF_TOKEN_FILE")" +elif [[ -z "${HF_TOKEN:-}" ]]; then + echo "WARN: HF token file not found at $HF_TOKEN_FILE (set HF_TOKEN env var if needed)" >&2 +fi +if [[ -n "${HF_TOKEN:-}" ]]; then + export HUGGINGFACE_HUB_TOKEN="${HF_TOKEN}" +fi + +echo "HF_HOME: $HF_HOME" +if [[ -n "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN: set" +else + echo "HF_TOKEN: unset" +fi +echo "" + +python scripts/run_experiment.py \ + --config configs/prune_llm/llama3_8b_random_only.yaml \ + --device cuda \ + --base-output-dir "$OUTPUT_BASE" + +echo "" +echo "============================================================================" +echo "Random baseline completed at $(date)" +echo "============================================================================" + diff --git a/src/alignment/analysis/visualization/__init__.py b/src/alignment/analysis/visualization/__init__.py index 0fe5c59a..4ccb65d3 100644 --- a/src/alignment/analysis/visualization/__init__.py +++ b/src/alignment/analysis/visualization/__init__.py @@ -77,14 +77,6 @@ plot_halo_redundancy_heatmap, ) -# Paper-specific plots (SCAR draft) -from .paper_plots import ( - plot_loss_proxy_concentration, - plot_halo_structure, - plot_supernode_halo_summary, - plot_scar_schematic, -) - # Cluster visualization plots from .cluster_plots import ( plot_metric_scatter, @@ -136,11 +128,6 @@ "plot_halo_redundancy_by_depth", "plot_halo_redundancy_comprehensive", "plot_halo_redundancy_heatmap", 
- # Paper plots - "plot_loss_proxy_concentration", - "plot_halo_structure", - "plot_supernode_halo_summary", - "plot_scar_schematic", # Cluster plots "plot_metric_scatter", "plot_cluster_evolution", diff --git a/src/alignment/analysis/visualization/halo_plots.py b/src/alignment/analysis/visualization/halo_plots.py index 04908b1f..16e7273a 100644 --- a/src/alignment/analysis/visualization/halo_plots.py +++ b/src/alignment/analysis/visualization/halo_plots.py @@ -297,7 +297,7 @@ def plot_halo_redundancy_comprehensive( {echo_interpret} -THEORETICAL BASIS (alignment_notes): +THEORETICAL BASIS (Gaussian mutual information): ───────────────────────────────────────── Redundancy: I(Yᵢ; Yⱼ) = -½ log(1 - ρ²) This is the Gaussian MI between neuron pairs. diff --git a/src/alignment/analysis/visualization/paper_plots.py b/src/alignment/analysis/visualization/paper_plots.py index 2b8f57bb..a02ee127 100644 --- a/src/alignment/analysis/visualization/paper_plots.py +++ b/src/alignment/analysis/visualization/paper_plots.py @@ -12,7 +12,7 @@ import logging from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import matplotlib @@ -41,7 +41,7 @@ def _to_numpy(x: Any) -> np.ndarray: def _save(fig: plt.Figure, save_path: Union[str, Path], dpi: int = 300) -> None: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(save_path, dpi=dpi, bbox_inches="tight") + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", pad_inches=0.02, facecolor="white") logger.info(f"[Saved] {save_path}") @@ -53,15 +53,15 @@ def plot_loss_proxy_concentration( dpi: int = 300, ) -> plt.Figure: """ - Two-panel plot: - (Left) sorted LP values (heavy tail) - (Right) cumulative proxy mass vs fraction of channels kept + Two-panel plot (ICML figure* friendly): + (a) sorted LP values (heavy tail) + (b) cumulative proxy mass vs fraction of channels kept """ lp = 
_to_numpy(loss_proxy).astype(np.float64).reshape(-1) lp = lp[np.isfinite(lp)] lp = np.maximum(lp, 0.0) - fig, axes = plt.subplots(1, 2, figsize=(12, 4.0)) + fig, axes = plt.subplots(1, 2, figsize=(7.2, 2.6)) if lp.size == 0: for ax in axes: ax.axis("off") @@ -73,7 +73,6 @@ def plot_loss_proxy_concentration( lp_sorted = np.sort(lp)[::-1] n = lp_sorted.size k = max(1, int(round(rho * n))) - threshold = lp_sorted[k - 1] total = float(lp_sorted.sum()) if float(lp_sorted.sum()) > 0 else 1.0 cum_mass = np.cumsum(lp_sorted) / total @@ -82,6 +81,7 @@ def plot_loss_proxy_concentration( # Panel A: sorted values ax = axes[0] + ax.text(0.02, 0.98, "(a)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") ax.plot(frac, lp_sorted, color="#2c3e50", linewidth=1.5) ax.axvline(x=rho, color="#c0392b", linestyle="--", linewidth=2, label=f"Top {rho*100:.1f}%") ax.set_yscale("log") @@ -90,23 +90,23 @@ def plot_loss_proxy_concentration( title = "Loss-proxy heavy tail" if layer_label: title += f"\n{layer_label}" - ax.set_title(title) + ax.set_title(title, fontsize=10.5) ax.grid(True, alpha=0.25) - ax.legend(loc="upper right") + ax.legend(loc="upper right", fontsize=8, frameon=True) # Panel B: cumulative mass ax = axes[1] + ax.text(0.02, 0.98, "(b)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") ax.plot(frac, cum_mass, color="#2980b9", linewidth=2.0) ax.axvline(x=rho, color="#c0392b", linestyle="--", linewidth=2) ax.scatter([rho], [top_mass], color="#c0392b", zorder=5) ax.set_xlabel("Fraction of channels kept (top by LP)") ax.set_ylabel("Cumulative LP mass") ax.set_ylim(0, 1.02) - ax.set_title(f"Top {rho*100:.1f}% mass = {top_mass*100:.1f}%") + ax.set_title(f"Top {rho*100:.1f}% mass = {top_mass*100:.1f}%", fontsize=10.5) ax.grid(True, alpha=0.25) plt.tight_layout() - if save_path is not None: _save(fig, save_path, dpi=dpi) return fig @@ -124,10 +124,10 @@ def plot_halo_structure( max_points: int = 60000, ) -> plt.Figure: """ - 
Three-panel plot: - (Left) Conn vs redundancy-to-core (halo channels) - (Middle) Redundancy-to-core distribution: halo vs non-halo (sample where defined) - (Right) Protect vs Conn (all channels; halo emphasized) + Three-panel plot (ICML figure* friendly): + (a) Conn vs redundancy-to-core (halo channels) + (b) Redundancy-to-core distribution: halo vs non-halo (sample where defined) + (c) Protect vs Conn (all channels; halo emphasized) """ conn_np = _to_numpy(conn).astype(np.float64).reshape(-1) red_np = _to_numpy(redundancy_to_core).astype(np.float64).reshape(-1) @@ -136,8 +136,10 @@ def plot_halo_structure( halo_np = _to_numpy(halo_mask).astype(bool).reshape(-1) n = int(conn_np.size) + fig, axes = plt.subplots(1, 3, figsize=(7.2, 2.6)) if n == 0: - fig, _ = plt.subplots(figsize=(10, 4)) + for ax in axes: + ax.axis("off") return fig # Downsample for plotting stability @@ -150,42 +152,41 @@ def plot_halo_structure( idx_non = idx_all[(~halo_np[idx_all]) & (~super_np[idx_all])] idx_sup = idx_all[super_np[idx_all]] - fig, axes = plt.subplots(1, 3, figsize=(15, 4.2)) - - # Panel A: Conn vs redundancy-to-core (halo only, since redundancy is defined there) + # (a) Conn vs redundancy-to-core (halo only) ax = axes[0] + ax.text(0.02, 0.98, "(a)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") x = conn_np[idx_halo] y = red_np[idx_halo] finite = np.isfinite(x) & np.isfinite(y) x = x[finite] y = y[finite] - ax.scatter(x, y, s=10, alpha=0.35, color="#1f77b4", edgecolors="none") + ax.scatter(x, y, s=8, alpha=0.35, color="#1f77b4", edgecolors="none") ax.set_xlabel(r"Connectivity $\mathrm{Conn}$") - ax.set_ylabel(r"Redundancy to core $\mathrm{Red}^{\rightarrow \mathcal{M}}$") + ax.set_ylabel(r"Red.\ to core $\mathrm{Red}^{\rightarrow \mathcal{M}}$") title = "Halo redundancy structure" if layer_label: title += f"\n{layer_label}" - ax.set_title(title) + ax.set_title(title, fontsize=10.5) ax.grid(True, alpha=0.25) if y.size > 0 and np.nanmin(y) > 0: 
ax.set_yscale("log") - # Panel B: Redundancy-to-core distribution comparison (halo vs non-halo sample) + # (b) Halo vs non-halo redundancy-to-core distribution ax = axes[1] + ax.text(0.02, 0.98, "(b)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") y_h = red_np[idx_halo] y_n = red_np[idx_non] y_h = y_h[np.isfinite(y_h)] y_n = y_n[np.isfinite(y_n)] - if y_h.size == 0 or y_n.size == 0: ax.text( 0.5, 0.5, - "Redundancy-to-core\n(non-halo sample unavailable)", + "Red-to-core\n(non-halo sample unavailable)", ha="center", va="center", transform=ax.transAxes, - fontsize=10, + fontsize=9.5, color="#2c3e50", ) ax.set_axis_off() @@ -204,29 +205,28 @@ def plot_halo_structure( for patch, c in zip(bp.get("boxes", []), colors): patch.set_facecolor(c) patch.set_alpha(0.75) - - ax.set_xticklabels([f"Halo\n(n={y_h.size})", f"Non-halo\n(sample, n={y_n.size})"]) - ax.set_ylabel(r"Redundancy to core $\mathrm{Red}^{\rightarrow \mathcal{M}}$") - ax.set_title("Halo vs non-halo\nredundancy-to-core") + ax.set_xticklabels([f"Halo\n(n={y_h.size})", f"Non-halo\n(sample, n={y_n.size})"], fontsize=8.5) + ax.set_ylabel(r"Red.\ to core $\mathrm{Red}^{\rightarrow \mathcal{M}}$") + ax.set_title("Halo vs non-halo", fontsize=10.5) ax.grid(True, alpha=0.25) - if y_h.size > 0 and y_n.size > 0 and np.nanmin(np.concatenate([y_h, y_n])) > 0: + if np.nanmin(np.concatenate([y_h, y_n])) > 0: ax.set_yscale("log") - # Panel C: Protect vs Conn (all channels) + # (c) Protect vs Conn ax = axes[2] - ax.scatter(conn_np[idx_non], prot_np[idx_non], s=6, alpha=0.15, color="#7f8c8d", label="Non-halo", edgecolors="none") - ax.scatter(conn_np[idx_halo], prot_np[idx_halo], s=10, alpha=0.35, color="#1f77b4", label="Halo", edgecolors="none") + ax.text(0.02, 0.98, "(c)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") + ax.scatter(conn_np[idx_non], prot_np[idx_non], s=5, alpha=0.15, color="#7f8c8d", label="Non-halo", edgecolors="none") + ax.scatter(conn_np[idx_halo], 
prot_np[idx_halo], s=7, alpha=0.35, color="#1f77b4", label="Halo", edgecolors="none") if idx_sup.size > 0: - ax.scatter(conn_np[idx_sup], prot_np[idx_sup], s=14, alpha=0.7, color="#c0392b", label="Supernodes", edgecolors="none") + ax.scatter(conn_np[idx_sup], prot_np[idx_sup], s=10, alpha=0.7, color="#c0392b", label="Supernodes", edgecolors="none") ax.set_xlabel(r"Connectivity $\mathrm{Conn}$") ax.set_ylabel(r"Protection $\mathrm{Protect}$") - ax.set_title("Protection vs connectivity") + ax.set_title("Protection vs Conn", fontsize=10.5) ax.set_ylim(-0.02, 1.02) ax.grid(True, alpha=0.25) - ax.legend(loc="lower left", frameon=True) + ax.legend(loc="lower left", fontsize=8, frameon=True) plt.tight_layout() - if save_path is not None: _save(fig, save_path, dpi=dpi) return fig @@ -243,26 +243,33 @@ def plot_supernode_halo_summary( ) -> plt.Figure: """ Two-panel plot: - (Left) top-rho LP mass ratio across layers - (Right) halo/non-halo redundancy summary bars (from halo_analysis.aggregate) + (a) top-rho LP mass ratio across layers + (b) halo/non-halo redundancy summary (from halo_analysis.per_layer if available) """ layers = np.asarray(list(layer_indices), dtype=int) ratios = np.asarray(list(top_mass_ratios), dtype=np.float64) - fig, axes = plt.subplots(1, 2, figsize=(12, 4.0)) + fig, axes = plt.subplots(1, 2, figsize=(7.2, 2.6)) ax = axes[0] - ax.plot(layers, ratios, "o-", color="#2c3e50", linewidth=2) + ax.text(0.02, 0.98, "(a)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") + ax.plot(layers, ratios, "o-", color="#2c3e50", linewidth=2, markersize=3.5) ax.set_xlabel("Layer index") ax.set_ylabel(f"Top-{rho*100:.1f}% LP mass ratio") ax.set_ylim(0, 1.02) - ax.set_title("Supernode concentration across layers") + ax.set_title("Supernode concentration", fontsize=10.5) ax.grid(True, alpha=0.25) ax = axes[1] - groups = [("Within-Halo", "halo_halo", "#1f77b4"), ("Within-Non-Halo", "non_halo", "#7f8c8d"), ("Cross", "cross", "#2ecc71")] + 
ax.text(0.02, 0.98, "(b)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") - # Prefer per-layer distributions (much clearer than mean±std when the MI distribution is heavy-tailed). + groups = [ + ("Within-Halo", "halo_halo", "#1f77b4"), + ("Within-Non-Halo", "non_halo", "#7f8c8d"), + ("Cross", "cross", "#2ecc71"), + ] + + # Prefer per-layer medians (more robust for heavy tails). if isinstance(halo_per_layer, dict) and halo_per_layer: data = [] for _, key, _ in groups: @@ -292,35 +299,17 @@ def plot_supernode_halo_summary( whiskerprops=dict(linewidth=1.2, color="#2c3e50"), capprops=dict(linewidth=1.2, color="#2c3e50"), ) - # Color the boxes for patch, (_, _, color) in zip(bp.get("boxes", []), groups): patch.set_facecolor(color) patch.set_alpha(0.75) - # Overlay jittered per-layer medians for transparency - rng = np.random.default_rng(0) - for i, vals in enumerate(data, start=1): - if vals.size == 0: - continue - jitter = rng.normal(loc=0.0, scale=0.05, size=vals.size) - ax.scatter( - np.full(vals.shape, i, dtype=float) + jitter, - vals, - s=14, - alpha=0.35, - color="#2c3e50", - edgecolors="none", - ) - ax.set_xticks(np.arange(1, len(groups) + 1)) - ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right") + ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right", fontsize=8.5) ax.set_ylabel("Redundancy (Gaussian MI, nats)\n(per-layer median)") - ax.set_title("Halo redundancy across layers") + ax.set_title("Halo redundancy", fontsize=10.5) ax.grid(True, alpha=0.25, axis="y") - # MI is positive and often heavy-tailed; log helps readability. ax.set_yscale("log") else: - # Fallback: show mean ± 95% CI of the mean (std can be huge for heavy tails). 
halo_aggregate = halo_aggregate or {} means = [] cis = [] @@ -332,13 +321,12 @@ def plot_supernode_halo_summary( sem = sd / np.sqrt(n) if n > 1 else 0.0 means.append(mu) cis.append(1.96 * sem) - x = np.arange(len(groups)) - ax.bar(x, means, yerr=cis, capsize=4, color=[g[2] for g in groups], alpha=0.85, edgecolor="none") + ax.bar(x, means, yerr=cis, capsize=3, color=[g[2] for g in groups], alpha=0.85, edgecolor="none") ax.set_xticks(x) - ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right") + ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right", fontsize=8.5) ax.set_ylabel("Redundancy (Gaussian MI, nats)\n(mean ± 95% CI)") - ax.set_title("Halo redundancy (aggregate)") + ax.set_title("Halo redundancy", fontsize=10.5) ax.grid(True, alpha=0.25, axis="y") plt.tight_layout() @@ -358,10 +346,9 @@ def plot_supernode_outlier_profile( dpi: int = 300, ) -> plt.Figure: """ - Two-panel plot for supernode outlier strength across depth. - - - Left: activation outlier ratio (supernode mean / population mean), log scale. - - Right: z-scores across layers (activation and loss-proxy); plus max-neuron activation z on a secondary axis. + Two-panel plot: + (a) activation outlier ratio (supernode mean / population mean), log scale. + (b) z-scores across layers (activation and loss-proxy), plus max-neuron z. 
""" layers = np.asarray(list(layer_indices), dtype=int) ratios = np.asarray(list(outlier_ratios), dtype=np.float64) @@ -369,39 +356,38 @@ def plot_supernode_outlier_profile( z_lp = np.asarray(list(z_scores_loss_proxy), dtype=np.float64) z_max = np.asarray(list(z_scores_max_activation), dtype=np.float64) - fig, axes = plt.subplots(1, 2, figsize=(12, 4.0)) + fig, axes = plt.subplots(1, 2, figsize=(7.2, 2.6)) - # Panel A: outlier ratio (log) ax = axes[0] - ax.plot(layers, ratios, "o-", color="#8e44ad", linewidth=2.0, markersize=4, label="Supernode mean / population mean") + ax.text(0.02, 0.98, "(a)", transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") + ax.plot(layers, ratios, "o-", color="#8e44ad", linewidth=2.0, markersize=3.5) ax.set_yscale("log") - ax.axhline(10.0, color="#f39c12", linestyle="--", linewidth=1.8, label="10×") - ax.axhline(100.0, color="#c0392b", linestyle="--", linewidth=1.8, label="100×") + ax.axhline(10.0, color="#f39c12", linestyle="--", linewidth=1.4, label="10×") + ax.axhline(100.0, color="#c0392b", linestyle="--", linewidth=1.4, label="100×") ax.set_xlabel("Layer index") - ax.set_ylabel("Activation outlier ratio (log scale)") - ax.set_title(f"Supernode outlier ratio (top {rho*100:.0f}% by LP)") + ax.set_ylabel("Activation outlier ratio") + ax.set_title(f"Outlier ratio (top {rho*100:.0f}% by LP)", fontsize=10.5) ax.grid(True, alpha=0.25, axis="y") - ax.legend(loc="upper right", frameon=True) + ax.legend(loc="upper right", fontsize=8, frameon=True) - # Panel B: z-scores (dual axis) ax = axes[1] - ax.plot(layers, z_act, "o-", color="#e67e22", linewidth=2.0, markersize=4, label="Activation z (supernode mean)") - ax.plot(layers, z_lp, "o-", color="#2980b9", linewidth=2.0, markersize=4, label="Loss-proxy z (supernode mean)") - ax.axhline(2.0, color="#7f8c8d", linestyle="--", linewidth=1.5, alpha=0.8) - ax.axhline(3.0, color="#7f8c8d", linestyle="--", linewidth=1.5, alpha=0.8) + ax.text(0.02, 0.98, "(b)", 
transform=ax.transAxes, ha="left", va="top", fontsize=10, fontweight="bold") + ax.plot(layers, z_act, "o-", color="#e67e22", linewidth=2.0, markersize=3.5, label="Activation z (supernode mean)") + ax.plot(layers, z_lp, "o-", color="#2980b9", linewidth=2.0, markersize=3.5, label="LP z (supernode mean)") + ax.axhline(2.0, color="#7f8c8d", linestyle="--", linewidth=1.2, alpha=0.8) + ax.axhline(3.0, color="#7f8c8d", linestyle="--", linewidth=1.2, alpha=0.8) ax.set_xlabel("Layer index") ax.set_ylabel("Z-score (supernode mean)") - ax.set_title("Outlier z-scores across layers") + ax.set_title("Outlier z-scores", fontsize=10.5) ax.grid(True, alpha=0.25, axis="y") ax2 = ax.twinx() - ax2.plot(layers, z_max, "^-", color="#2c3e50", linewidth=1.8, markersize=5, label="Activation z (max neuron)") - ax2.set_ylabel("Z-score (max neuron, activation)") + ax2.plot(layers, z_max, "^-", color="#2c3e50", linewidth=1.6, markersize=4, label="Activation z (max neuron)") + ax2.set_ylabel("Z-score (max neuron)") - # Combined legend h1, l1 = ax.get_legend_handles_labels() h2, l2 = ax2.get_legend_handles_labels() - ax.legend(h1 + h2, l1 + l2, loc="upper right", frameon=True) + ax.legend(h1 + h2, l1 + l2, loc="upper right", fontsize=8, frameon=True) plt.tight_layout() if save_path is not None: @@ -416,22 +402,16 @@ def plot_sparsity_perplexity_curves( save_path: Optional[Union[str, Path]] = None, dpi: int = 300, ) -> plt.Figure: - """ - Paper-facing plot: perplexity vs structured sparsity for multiple methods. - - Inputs are already "paper-ready" (i.e., only the intended pruning direction, typically low-mode). 
- """ xs = np.asarray(list(sparsities), dtype=np.float64) - fig, ax = plt.subplots(figsize=(7.0, 4.2)) + fig, ax = plt.subplots(figsize=(3.45, 2.6)) - # Stable ordering for legend for label in sorted(ppl_by_method.keys()): ys_raw = ppl_by_method[label] ys = np.asarray([np.nan if v is None else float(v) for v in ys_raw], dtype=np.float64) finite = np.isfinite(ys) if not np.any(finite): continue - ax.plot(xs[finite], ys[finite], "o-", linewidth=2.0, markersize=5, label=label, alpha=0.9) + ax.plot(xs[finite], ys[finite], "o-", linewidth=2.0, markersize=4, label=label, alpha=0.9) if baseline_ppl is not None: try: @@ -441,14 +421,14 @@ def plot_sparsity_perplexity_curves( except Exception: pass - ax.set_xlabel("Structured FFN channel sparsity", fontsize=11) - ax.set_ylabel("Perplexity (WikiText-2)", fontsize=11) - ax.set_title("Perplexity vs sparsity (low-mode)", fontsize=12, fontweight="bold") + ax.set_xlabel("Structured FFN channel sparsity", fontsize=10) + ax.set_ylabel("Perplexity (WikiText-2)", fontsize=10) + ax.set_title("Perplexity vs sparsity", fontsize=11, fontweight="bold") ax.grid(True, alpha=0.25) - ax.legend(loc="upper left", fontsize=9, frameon=True) + ax.legend(loc="upper left", fontsize=7.5, frameon=True) # Use log if the dynamic range is large. - all_vals = [] + all_vals: List[float] = [] for vs in ppl_by_method.values(): for v in vs: if v is None: @@ -477,19 +457,12 @@ def plot_sparsity_accuracy_curves( baseline_acc: Optional[float] = None, *, ylabel: str = "Accuracy (%)", - title: str = "Accuracy vs sparsity (low-mode)", + title: str = "Accuracy vs sparsity", save_path: Optional[Union[str, Path]] = None, dpi: int = 300, ) -> plt.Figure: - """ - Paper-facing plot: downstream accuracy vs structured sparsity for multiple methods. - - Notes: - - Accuracies are expected to already be in percent units (e.g., 58.0 for 58%). - - Inputs should be filtered to the intended pruning direction (typically low-mode). 
- """ xs = np.asarray(list(sparsities), dtype=np.float64) - fig, ax = plt.subplots(figsize=(7.0, 4.2)) + fig, ax = plt.subplots(figsize=(3.45, 2.6)) for label in sorted(acc_by_method.keys()): ys_raw = acc_by_method[label] @@ -497,7 +470,7 @@ def plot_sparsity_accuracy_curves( finite = np.isfinite(ys) if not np.any(finite): continue - ax.plot(xs[finite], ys[finite], "o-", linewidth=2.0, markersize=5, label=label, alpha=0.9) + ax.plot(xs[finite], ys[finite], "o-", linewidth=2.0, markersize=4, label=label, alpha=0.9) if baseline_acc is not None: try: @@ -507,11 +480,11 @@ def plot_sparsity_accuracy_curves( except Exception: pass - ax.set_xlabel("Structured FFN channel sparsity", fontsize=11) - ax.set_ylabel(ylabel, fontsize=11) - ax.set_title(title, fontsize=12, fontweight="bold") + ax.set_xlabel("Structured FFN channel sparsity", fontsize=10) + ax.set_ylabel(ylabel, fontsize=10) + ax.set_title(title, fontsize=11, fontweight="bold") ax.grid(True, alpha=0.25) - ax.legend(loc="lower left", fontsize=9, frameon=True) + ax.legend(loc="lower left", fontsize=7.5, frameon=True) plt.tight_layout() if save_path is not None: @@ -527,7 +500,6 @@ def plot_scar_schematic( Generate a schematic of SCAR (supernodes + halos) as a flowchart. This is model-agnostic and can be generated during artifact collection. """ - # Keep this figure intentionally clean + legible in a 1-column ICML layout. 
fig = plt.figure(figsize=(12, 3.8)) ax = fig.add_subplot(111) ax.set_axis_off() @@ -551,9 +523,6 @@ def arrow(x1, y1, x2, y2, color="#2c3e50"): a = FancyArrowPatch((x1, y1), (x2, y2), arrowstyle="->", linewidth=1.6, color=color, mutation_scale=12) ax.add_patch(a) - # ------------------------------------------------------------------ - # Column layout - # ------------------------------------------------------------------ x0 = 0.03 col_w = 0.22 gap = 0.035 @@ -562,87 +531,47 @@ def arrow(x1, y1, x2, y2, color="#2c3e50"): y_bot = 0.15 h_bot = 0.30 - # Colors (match paper narrative) - C_SUP = "#c0392b" # supernodes - C_HALO = "#1f77b4" # halo - C_STEP = "#2c3e50" # neutral - C_CAL = "#d35400" # calibration/loss-proxy compute + C_SUP = "#c0392b" + C_STEP = "#2c3e50" + C_CAL = "#d35400" - # --- Col 1: Calibration + proxy --- + # Col 1 box(x0, y_top, col_w, h_top, "Calibration\n(tokens)", fc="#fdf2e9", ec=C_CAL) box( x0, y_bot, col_w, h_bot, - # NOTE: Use \frac{1}{2} (not \frac12) for broad compatibility with matplotlib mathtext. 
r"Loss proxy\n$\mathrm{LP}_i=\frac{1}{2}\,\mathbb{E}[(u_i s_i)^2]$", fc="#fdf2e9", ec=C_CAL, ) - - # Tiny icon: forward/backward arrows ax.text(x0 + col_w / 2, y_top + 0.07, "fwd + bwd", ha="center", va="center", fontsize=9.5, color=C_STEP) - # --- Col 2: Supernodes --- + # Col 2 x1 = x0 + col_w + gap box(x1, y_top, col_w, h_top, r"Supernodes\n(top-$\rho$ by LP)\n\bf protect", fc="#fdebd0", ec=C_SUP) box(x1, y_bot, col_w, h_bot, "FFN channels\n(sorted by LP)", fc="#f8f9f9", ec=C_STEP) - # Draw a stylized heavy-tail: a few bars, with 2 red outliers - xs = np.linspace(x1 + 0.03, x1 + col_w - 0.03, 10) - heights = np.array([0.06, 0.05, 0.04, 0.035, 0.03, 0.028, 0.025, 0.022, 0.18, 0.24]) - for i, (xx, hh) in enumerate(zip(xs, heights)): - c = C_SUP if i >= 8 else "#7f8c8d" - ax.plot([xx, xx], [y_bot + 0.06, y_bot + 0.06 + hh], color=c, linewidth=4, solid_capstyle="round") - ax.text(x1 + col_w / 2, y_bot + 0.03, "rare outliers", ha="center", va="center", fontsize=9.0, color=C_STEP) - - # --- Col 3: Halo + redundancy --- + # Col 3 x2 = x1 + col_w + gap - box(x2, y_top, col_w, h_top, r"Halo\n(high Conn to core)", fc="#e8f6ff", ec=C_HALO) - box( - x2, - y_bot, - col_w, - h_bot, - r"Redundancy to core\n$\mathrm{Red}^{\rightarrow\mathcal{M}}_j=\max_{m\in\mathcal{M}} I(q_j;q_m)$", - fc="#eafaf1", - ec="#27ae60", - ) - ax.text(x2 + col_w / 2, y_bot + 0.03, r"$q=u\odot s$", ha="center", va="center", fontsize=9.0, color=C_STEP) + box(x2, y_top, col_w, h_top, r"Halo (Conn)\n(top-$\eta$)", fc="#eaf2f8", ec="#1f77b4") + box(x2, y_bot, col_w, h_bot, r"Red-to-core\n$\max_{s\in\mathcal{M}}\mathrm{Red}(j,s)$", fc="#eaf2f8", ec="#1f77b4") - # --- Col 4: Structured pruning --- + # Col 4 x3 = x2 + col_w + gap - box( - x3, - y_top, - col_w, - h_top, - "Score + prune\n(non-core only)\nlayer caps", - fc="#f8f9f9", - ec=C_STEP, - ) - box(x3, y_bot, col_w, h_bot, r"Result:\nstructured FFN\nchannel sparsity", fc="#f8f9f9", ec=C_STEP) + box(x3, y_top, col_w, h_top, r"Protect\n(rank-power)", 
fc="#f8f9f9", ec=C_STEP) + box(x3, y_bot, col_w, h_bot, r"Prune\n(redundant followers)", fc="#f8f9f9", ec=C_STEP) - # Arrows across columns (top row) + # Arrows arrow(x0 + col_w, y_top + h_top / 2, x1, y_top + h_top / 2, color=C_STEP) + arrow(x0 + col_w, y_bot + h_bot / 2, x1, y_bot + h_bot / 2, color=C_STEP) arrow(x1 + col_w, y_top + h_top / 2, x2, y_top + h_top / 2, color=C_STEP) + arrow(x1 + col_w, y_bot + h_bot / 2, x2, y_bot + h_bot / 2, color=C_STEP) arrow(x2 + col_w, y_top + h_top / 2, x3, y_top + h_top / 2, color=C_STEP) + arrow(x2 + col_w, y_bot + h_bot / 2, x3, y_bot + h_bot / 2, color=C_STEP) - # Vertical arrows within columns - arrow(x0 + col_w / 2, y_top, x0 + col_w / 2, y_bot + h_bot, color=C_STEP) - arrow(x2 + col_w / 2, y_top, x2 + col_w / 2, y_bot + h_bot, color=C_STEP) - - ax.text( - 0.02, - 0.98, - "SCAR: supernodes + halos for structured FFN channel pruning", - fontsize=12.5, - fontweight="bold", - ha="left", - va="top", - color=C_STEP, - ) + ax.text(0.5, 0.98, "SCAR pipeline overview", ha="center", va="top", fontsize=12, fontweight="bold", color=C_STEP) plt.tight_layout() if save_path is not None: diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 6524bd62..94fce2aa 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -389,6 +389,10 @@ def evaluate_multiple_metrics( results["accuracy_arc_challenge"] = self._evaluate_arc_challenge( num_samples=num_samples, num_fewshot=num_fewshot ) + elif metric == "accuracy_openbookqa": + results["accuracy_openbookqa"] = self._evaluate_openbookqa( + num_samples=num_samples, num_fewshot=num_fewshot + ) elif metric == "accuracy_piqa": results["accuracy_piqa"] = self._evaluate_piqa( num_samples=num_samples, num_fewshot=num_fewshot @@ -511,7 +515,7 @@ def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num return 0.0 shot_str = f"{num_fewshot}-shot" if num_fewshot > 0 else 
"zero-shot" - logger.info(f"Evaluating {shot_str} accuracy on MMLU ({num_samples} samples)...") + logger.info(f"Evaluating {shot_str} accuracy on MMLU (~{num_samples} samples total)...") # Default subjects for quick evaluation (covers different domains) if subjects is None: @@ -580,11 +584,27 @@ def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num device = torch.device(self.config.device) choice_labels = ["A", "B", "C", "D"] - # Calculate samples per subject - samples_per_subject = max(1, num_samples // len(subjects)) + # Deterministic sampling (avoid label/order bias from always taking the first example). + # Interpret num_samples as a TOTAL budget across subjects. + import math + import zlib + + subjects = list(subjects) + seed = int(getattr(self.config, "seed", 0) or 0) + + # Shuffle subject order to avoid any systematic ordering artifacts. + rng_subj = np.random.default_rng(seed) + rng_subj.shuffle(subjects) + + # Allocate an exact per-subject quota summing to num_samples. + base = int(num_samples) // max(1, len(subjects)) + rem = int(num_samples) % max(1, len(subjects)) + quotas = [base + (1 if i < rem else 0) for i in range(len(subjects))] with torch.no_grad(): - for subject in subjects: + for subject, quota in zip(subjects, quotas): + if quota <= 0: + continue try: dataset = load_dataset("cais/mmlu", subject, split="test", trust_remote_code=True) # Load dev split for few-shot examples @@ -597,9 +617,17 @@ def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num # Build few-shot prompt for this subject fewshot_prompt = "" if num_fewshot > 0: - for i, ex in enumerate(dev_dataset): - if i >= num_fewshot: - break + # Sample few-shot examples from dev split. 
+ try: + dev_n = len(dev_dataset) + dev_seed = seed + int(zlib.adler32(f"{subject}:dev".encode())) + rng_dev = np.random.default_rng(dev_seed) + dev_idxs = rng_dev.choice(dev_n, size=min(int(num_fewshot), dev_n), replace=False).tolist() + except Exception: + dev_idxs = list(range(int(num_fewshot))) + + for ex_idx in dev_idxs: + ex = dev_dataset[int(ex_idx)] q = ex["question"] choices = ex["choices"] answer_idx = ex["answer"] @@ -611,11 +639,21 @@ def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num subject_correct = 0 subject_total = 0 - for i, example in enumerate(dataset): - if subject_total >= samples_per_subject: + # Sample test examples for this subject (without replacement). + try: + n_test = len(dataset) + test_seed = seed + int(zlib.adler32(f"{subject}:test".encode())) + rng_test = np.random.default_rng(test_seed) + test_idxs = rng_test.choice(n_test, size=min(int(quota), n_test), replace=False).tolist() + except Exception: + test_idxs = list(range(int(quota))) + + for ex_i, ex_idx in enumerate(test_idxs): + if total >= num_samples: break try: + example = dataset[int(ex_idx)] question = example["question"] choices = example["choices"] answer_idx = example["answer"] # 0-indexed @@ -637,7 +675,7 @@ def _evaluate_mmlu(self, num_samples: int = 100, subjects: List[str] = None, num subject_total += 1 except Exception as e: - logger.warning(f"Error on MMLU {subject} sample {i}: {e}") + logger.warning(f"Error on MMLU {subject} sample {ex_i}: {e}") continue if subject_total > 0: @@ -1131,6 +1169,105 @@ def _evaluate_arc_challenge(self, num_samples: int = 100, num_fewshot: int = 0) logger.info(f"ARC-Challenge accuracy ({shot_str}): {accuracy:.2f}% ({correct}/{total})") return accuracy + def _evaluate_openbookqa(self, num_samples: int = 100, num_fewshot: int = 0) -> float: + """ + Zero-/few-shot evaluation on OpenBookQA (4-way MCQ). 
+ + We score options using conditional log-probability of the *option label* continuation, + with the full question + choices included in the prompt (standard MCQ protocol). + + Returns accuracy in percent (higher is better). + """ + try: + from datasets import load_dataset + except ImportError: + logger.error("datasets library not installed, cannot evaluate OpenBookQA") + return 0.0 + + shot_str = f"{num_fewshot}-shot" if num_fewshot > 0 else "zero-shot" + logger.info(f"Evaluating {shot_str} accuracy on OpenBookQA ({num_samples} samples)...") + + # Dataset schema varies a bit across versions; handle both common shapes. + # HF dataset: openbookqa, config \"main\". + try: + dataset = load_dataset("openbookqa", "main", split="test", trust_remote_code=True) + if num_fewshot > 0: + train_dataset = load_dataset("openbookqa", "main", split="train", trust_remote_code=True) + except Exception as e: + logger.error(f"Failed to load OpenBookQA dataset: {e}") + return 0.0 + + def _get_question(ex: Dict[str, Any]) -> str: + if isinstance(ex.get("question_stem"), str): + return ex["question_stem"] + q = ex.get("question") + if isinstance(q, dict) and isinstance(q.get("stem"), str): + return q["stem"] + return str(ex.get("question", "")) + + def _get_choices(ex: Dict[str, Any]) -> Tuple[List[str], List[str]]: + ch = ex.get("choices") + if isinstance(ch, dict): + texts = ch.get("text") or ch.get("texts") or [] + labels = ch.get("label") or ch.get("labels") or [] + return list(texts), list(labels) + # Some variants store as list of dicts + if isinstance(ch, list): + texts = [c.get("text", "") for c in ch if isinstance(c, dict)] + labels = [c.get("label", "") for c in ch if isinstance(c, dict)] + return texts, labels + return [], [] + + # Build few-shot prompt in the same MCQ format. 
+ fewshot_prompt = "" + if num_fewshot > 0: + for i, ex in enumerate(train_dataset): + if i >= num_fewshot: + break + q = _get_question(ex) + choice_texts, choice_labels = _get_choices(ex) + answer_key = ex.get("answerKey") + if not choice_texts or not choice_labels or answer_key not in choice_labels: + continue + choices_str = "\n".join([f"{choice_labels[j]}) {choice_texts[j]}" for j in range(len(choice_texts))]) + fewshot_prompt += f"Question: {q}\n{choices_str}\nAnswer: {answer_key}\n\n" + + self.model.eval() + correct = 0 + total = 0 + + with torch.no_grad(): + for i, example in enumerate(dataset): + if total >= num_samples: + break + try: + question = _get_question(example) + choice_texts, choice_labels = _get_choices(example) + answer_key = example.get("answerKey") + if not choice_texts or not choice_labels or answer_key not in choice_labels: + continue + + # Prompt includes choices; continuation is the option label. + choices_str = "\n".join([f"{choice_labels[j]}) {choice_texts[j]}" for j in range(len(choice_texts))]) + prompt = ( + f"{fewshot_prompt}Question: {question}\n{choices_str}\nAnswer:" if num_fewshot > 0 else + f"Question: {question}\n{choices_str}\nAnswer:" + ) + continuations = [f" {lab}" for lab in choice_labels] + scores = self._score_continuations_conditional_logprob(prompt, continuations, max_length=2048) + + predicted = int(np.argmax(scores)) + if choice_labels[predicted] == answer_key: + correct += 1 + total += 1 + except Exception as e: + logger.warning(f"Error on OpenBookQA sample {i}: {e}") + continue + + accuracy = 100 * correct / total if total > 0 else 0.0 + logger.info(f"OpenBookQA accuracy ({shot_str}): {accuracy:.2f}% ({correct}/{total})") + return accuracy + def _evaluate_truthfulqa(self, num_samples: int = 100, num_fewshot: int = 0) -> float: """ Few-shot evaluation on TruthfulQA (truthfulness in answers). 
@@ -2596,6 +2733,95 @@ def _resolve_mlp_path(layer_idx: int) -> Optional[str]: logger.info(f"Computed weight_magnitude channel scores for {len(layer_indices)} MLP layers") return results + + def compute_random_channel_scores( + self, + *, + seed: Optional[int] = None, + ) -> Dict[str, Dict[str, torch.Tensor]]: + """ + Structured *channel* random baseline. + + We generate one random score per intermediate FFN channel (shared across gate/up/down + projections) and store it under metric name "random" in `self.importance_scores`. + + Note: If pruning_mode == "random", the pruning mask creation ignores score values and + uses uniform random selection; we still store scores to provide consistent shapes and + to make this baseline explicit in saved artifacts. + """ + import re + + if seed is None: + seed = int(getattr(self.config, "seed", 0) or 0) + + underlying_model = self._get_underlying_model() + module_dict = dict(underlying_model.named_modules()) + + # Identify MLP layer indices by scanning module names (robust even if no other + # importance scores were computed). + layer_indices = set() + for name in module_dict.keys(): + m = re.search(r"layers\.(\d+)\.mlp\.gate_proj$", name) + if m: + layer_indices.add(int(m.group(1))) + + if not layer_indices: + logger.warning("random: no MLP layers found; skipping random channel baseline") + return {} + + # Use a dedicated generator for determinism. 
+ gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + + def _resolve_mlp_path(layer_idx: int) -> Optional[str]: + candidates = [ + f"model.model.layers.{layer_idx}.mlp", + f"model.layers.{layer_idx}.mlp", + f"layers.{layer_idx}.mlp", + ] + for c in candidates: + if c in module_dict: + return c + return None + + results: Dict[str, Dict[str, torch.Tensor]] = {} + for layer_idx in sorted(layer_indices): + mlp_path = _resolve_mlp_path(layer_idx) + if mlp_path is None: + logger.warning(f"random: could not resolve MLP path for layer {layer_idx}") + continue + + gate_name = f"{mlp_path}.gate_proj" + up_name = f"{mlp_path}.up_proj" + down_name = f"{mlp_path}.down_proj" + if gate_name not in module_dict or up_name not in module_dict or down_name not in module_dict: + logger.warning(f"random: missing projections for {mlp_path}") + continue + + gate = module_dict[gate_name] + up = module_dict[up_name] + down = module_dict[down_name] + if not all(isinstance(m, nn.Linear) for m in (gate, up, down)): + logger.warning(f"random: projections for {mlp_path} are not all nn.Linear; skipping") + continue + + n = int(gate.out_features) + if n <= 0: + continue + + # One score per intermediate channel. 
+ scores = torch.rand((n,), generator=gen, dtype=torch.float32) + + for store_name in (gate_name, up_name, down_name): + if store_name not in self.importance_scores: + self.importance_scores[store_name] = {} + self.importance_scores[store_name]["random"] = scores + if store_name not in results: + results[store_name] = {} + results[store_name]["random"] = scores + + logger.info(f"Computed random channel scores for {len(layer_indices)} MLP layers (seed={seed})") + return results @staticmethod def _normalize_scores_tensor(scores: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: @@ -6545,6 +6771,12 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm if not self.importance_scores: raise ValueError("Must compute importance scores before pruning") + # Per-call diagnostics that downstream artifact collection can use to explain + # catastrophic baseline failures (e.g., pruning supernodes). + # + # Stored as a side effect to avoid changing the public return type. + self._last_pruning_diagnostics = {} + # Paper-faithful *unstructured* reproductions for Wanda/SparseGPT (kept separate from channel-adapted baselines). if metric in {"wanda_unstructured", "sparsegpt_unstructured"}: return self.apply_unstructured_baseline_pruning(sparsity=sparsity, metric=metric, mode=mode) @@ -6592,6 +6824,12 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm masks = {} processed_mlps = set() # Track which MLPs we've already processed + + # Supernode "hit-rate" diagnostic: fraction of supernodes pruned by this method. 
+ super_total = 0 + super_pruned = 0 + layers_with_super = 0 + layers_with_super_pruned = 0 for layer_name in self.importance_scores.keys(): if metric not in self.importance_scores[layer_name]: @@ -6642,6 +6880,26 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm # Create mask based on importance scores mask = pruner.create_pruning_mask(scores) + + # Diagnostic: how many supernodes did we prune in this layer? + if core_mask is not None: + try: + cm = core_mask + if not torch.is_tensor(cm): + cm = torch.as_tensor(cm) + cm = cm.to(device=mask.device, dtype=torch.bool) + + if cm.numel() == mask.numel(): + layers_with_super += 1 + super_total += int(cm.sum().item()) + pruned = (mask == 0) + pruned_super = int((pruned & cm).sum().item()) + super_pruned += pruned_super + if pruned_super > 0: + layers_with_super_pruned += 1 + except Exception: + # Never fail pruning due to diagnostics. + pass # Get the MLP module - use underlying model to handle HFCausalLM wrapper underlying_model = self._get_underlying_model() @@ -6712,6 +6970,16 @@ def apply_pruning(self, sparsity: float = 0.2, metric: str = "activation_l2_norm masks.update(attention_masks) self.pruning_masks = masks + # Store diagnostics for the caller (run()) to attach into results JSON. 
+ self._last_pruning_diagnostics = { + "supernode_pruning": { + "supernodes_total": int(super_total), + "supernodes_pruned": int(super_pruned), + "supernodes_pruned_frac": (float(super_pruned) / float(super_total)) if super_total > 0 else None, + "layers_with_supernodes": int(layers_with_super), + "layers_with_supernodes_pruned": int(layers_with_super_pruned), + } + } logger.info(f"Pruned {len(processed_mlps)} MLP layers with {sparsity:.1%} target sparsity") if num_attention_layers > 0: logger.info(f"Pruned {num_attention_layers} attention blocks with shared Q/K/V/O masks") @@ -7381,6 +7649,16 @@ def run(self) -> Dict[str, Any]: import traceback logger.error(traceback.format_exc()) + # Structured random baseline (paper: "Random (channel)") + if "random" in pruning_strategies: + try: + # Deterministic by default (seeded by config.seed). + self.compute_random_channel_scores() + except Exception as rand_err: + logger.error(f"Failed random baseline score computation: {rand_err}") + import traceback + logger.error(traceback.format_exc()) + # Example: per-layer histogram with top-5 annotations # self.plot_layer_importance_histogram( # layer_name="model.layers.1.mlp.up_proj", @@ -7622,6 +7900,8 @@ def restore_weights(): "num_pruned_layers": len(masks), "metric": metric, "mode": mode, + # Extra diagnostics for paper analysis (e.g., explain why some baselines collapse) + **(getattr(self, "_last_pruning_diagnostics", {}) or {}), } else: pruning_data["perplexities"].append(None) diff --git a/src/alignment/metrics/__init__.py b/src/alignment/metrics/__init__.py index 79948a7d..0fbb86bc 100644 --- a/src/alignment/metrics/__init__.py +++ b/src/alignment/metrics/__init__.py @@ -2,7 +2,7 @@ Metrics for measuring neural network alignment, redundancy, and synergy. 
============================================================================= -METRIC TAXONOMY (from drafts/alignment_notes/alignment_red.tex) +METRIC TAXONOMY (paper-aligned definitions) ============================================================================= 1. ALIGNMENT METRICS (Rayleigh Quotient based) diff --git a/src/alignment/metrics/cross_layer.py b/src/alignment/metrics/cross_layer.py index a475dbb8..8b78bb82 100644 --- a/src/alignment/metrics/cross_layer.py +++ b/src/alignment/metrics/cross_layer.py @@ -13,7 +13,7 @@ 3. Cross-layer importance score: Combines current-layer and next-layer analysis -Theory (extending alignment_notes, following SCAR logic): +Theory (cross-layer importance; SCAR-style downstream dependence): The key insight is that importance flows FORWARD through the network: - A neuron is important if DOWNSTREAM layers depend on it - A neuron is redundant if its information is already carried by other neurons diff --git a/src/alignment/metrics/halo_redundancy.py b/src/alignment/metrics/halo_redundancy.py index f9240781..c5bf0d2c 100644 --- a/src/alignment/metrics/halo_redundancy.py +++ b/src/alignment/metrics/halo_redundancy.py @@ -3,7 +3,7 @@ Computes and visualizes redundancy patterns within and between halo/non-halo groups. -Theory (from alignment_notes): +Theory (Gaussian mutual information): - Pairwise redundancy: I(Y_i; Y_j) = -0.5 * log(1 - ρ²) - Per-neuron redundancy: R(Y_i) = mean over neighbors of I(Y_i; Y_j) - Halo = neurons with high connectivity to supernodes @@ -147,7 +147,7 @@ def correlation_to_redundancy(corr: torch.Tensor) -> torch.Tensor: """ Convert correlation to redundancy using Gaussian MI formula. - Theory (from drafts/alignment_notes/alignment_red.tex): + Theory (Gaussian mutual information): I(Y_i; Y_j) = -0.5 * log(1 - ρ²) This is the mutual information between jointly Gaussian variables. 
@@ -220,7 +220,7 @@ class HaloRedundancy(BaseMetric): This metric analyzes the information structure of halo vs non-halo neurons, helping to validate whether halo membership correlates with redundancy. - Theory (from alignment_notes): + Theory (Gaussian mutual information): Redundancy: I(Y_i; Y_j) = -0.5 * log(1 - ρ²) If halo neurons are indeed "echo chambers" of supernodes, we expect: diff --git a/src/alignment/metrics/information/gaussian_mi.py b/src/alignment/metrics/information/gaussian_mi.py index 1300067b..1d8aec7e 100644 --- a/src/alignment/metrics/information/gaussian_mi.py +++ b/src/alignment/metrics/information/gaussian_mi.py @@ -289,7 +289,7 @@ def compute(self, inputs: torch.Tensor, weights: torch.Tensor, outputs: Optional # - RQ = (w^T Σ_x w) / (w^T w) -- normalizes by weight norm (scale-invariant) # - MI = 0.5 * log(1 + (w^T Σ_x w) / σ_n²) -- uses raw signal variance! # - # From the theory (see drafts/alignment_notes/alignment_red.tex): + # From the theory (see paper): # For noisy linear neuron y = w^T X + n where n ~ N(0, σ_n²): # I(X; y) = 0.5 * log(1 + (w^T Σ_X w) / σ_n²) # diff --git a/src/alignment/metrics/multi_supernode.py b/src/alignment/metrics/multi_supernode.py index ac285f0f..c2c10df4 100644 --- a/src/alignment/metrics/multi_supernode.py +++ b/src/alignment/metrics/multi_supernode.py @@ -10,7 +10,7 @@ 3. Measuring redundancy within and between supernode clusters 4. 
More nuanced pruning based on cluster structure -Theory (extending alignment_notes): +Theory (multi-supernode extension): - Instead of treating top k% as a single supernode group, we cluster them - Each cluster represents a different "functional group" of important neurons - Halo is defined relative to each cluster diff --git a/src/alignment/pruning/strategies/cluster_aware.py b/src/alignment/pruning/strategies/cluster_aware.py index 11f5cf19..12037219 100644 --- a/src/alignment/pruning/strategies/cluster_aware.py +++ b/src/alignment/pruning/strategies/cluster_aware.py @@ -1,7 +1,7 @@ """ Cluster-aware pruning strategy with halo scoring and cluster constraints. -This implements the cluster-and-halo pruning approach from the vision paper: +This implements a cluster-and-halo pruning approach: Score_i = α·log(RQ_i) + β·Syn_i - γ·Red_i + λ·HaloSyn_i @@ -9,9 +9,6 @@ 1. Protect critical: prune at most p_C fraction from critical cluster per layer 2. Target redundant/background: prioritize pruning from these clusters 3. Synergy-pair constraint: don't prune both members of top synergistic pairs - -References: -- Channel Clusters and Halo Dependencies for Structured Pruning (ICML 2026) """ import logging @@ -31,7 +28,7 @@ class ClusterAwarePruningConfig(PruningConfig): """Configuration for cluster-aware pruning.""" - # Score weights (Eq. 14 in paper) + # Score weights for the composite pruning score alpha: float = 1.0 # Weight for log(RQ) beta: float = 0.5 # Weight for Synergy gamma: float = 0.3 # Weight for Redundancy (subtracted) @@ -155,7 +152,7 @@ def compute_importance_scores( clusters, n_channels, layer_name ) - # 4. Compute composite scores (Eq. 14) + # 4. 
Compute composite scores log_rq = np.log(np.clip(metrics['rq'], 1e-10, None)) # Normalize each component to [0, 1] for stable weighting From ec018c585211ac661c62a5963c079049c7e36cbd Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Wed, 14 Jan 2026 12:56:28 -0500 Subject: [PATCH 14/15] Track cluster_experiments module; stop ignoring src experiments --- .gitignore | 1 - .../experiments/cluster_experiments.py | 2479 +++++++++++++++++ 2 files changed, 2479 insertions(+), 1 deletion(-) create mode 100644 src/alignment/experiments/cluster_experiments.py diff --git a/.gitignore b/.gitignore index 442f7cc0..008ac0e9 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,6 @@ logs/ runs/ outputs/ results/ -experiments/ # Backup files diff --git a/src/alignment/experiments/cluster_experiments.py b/src/alignment/experiments/cluster_experiments.py new file mode 100644 index 00000000..b2bad4d9 --- /dev/null +++ b/src/alignment/experiments/cluster_experiments.py @@ -0,0 +1,2479 @@ +""" +Cluster-based analysis experiments for neural networks. + +This module provides a general experiment runner for: +1. Computing per-channel metrics (RQ, Redundancy, Synergy with continuous target) +2. Clustering channels/neurons into functional types (Critical, Redundant, Synergistic, Background) +3. Cross-layer halo analysis (downstream dependencies) +4. Cascade/damage prediction experiments +5. Cluster-aware pruning with baseline comparisons + +Compatible with any neural network architecture: +- Vision: ResNet, VGG, MobileNet, etc. 
+- LLMs: Can be adapted for FFN analysis +- Any model with convolutional or linear layers +""" + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +import json +import numpy as np + +logger = logging.getLogger(__name__) + +try: + import torch + import torch.nn as nn + from torch.utils.data import DataLoader + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + +from ..analysis.clustering import MetricSpaceClustering, CrossLayerHaloAnalysis +from ..analysis.cascade_analysis import CascadeAnalysis, DamagePrediction +from ..pruning.pipeline import PruningPipelineOptions, run_pruning_pipeline + + +class _CovAccumulator: + """ + Streaming Gaussian-statistics accumulator for a layer. + + Maintains sufficient statistics to compute: + - per-channel variance + - channel-channel covariance/correlation + - covariance between scalar target T and channels + """ + + def __init__(self, n_channels: int): + self.n = 0 + self.sum_y = np.zeros(n_channels, dtype=np.float64) + self.sum_abs_y = np.zeros(n_channels, dtype=np.float64) + self.sum_yy = np.zeros((n_channels, n_channels), dtype=np.float64) + self.sum_t = 0.0 + self.sum_tt = 0.0 + self.sum_ty = np.zeros(n_channels, dtype=np.float64) + + def update(self, y: np.ndarray, t: np.ndarray) -> None: + """ + Args: + y: [N, C] channel samples (float) + t: [N] target samples (float) + """ + if y.size == 0: + return + y = np.asarray(y, dtype=np.float64) + t = np.asarray(t, dtype=np.float64).reshape(-1) + if y.ndim != 2: + raise ValueError(f"Expected y as [N,C], got shape {y.shape}") + if t.shape[0] != y.shape[0]: + raise ValueError(f"Mismatched sample count: y has {y.shape[0]}, t has {t.shape[0]}") + + self.n += int(y.shape[0]) + self.sum_y += y.sum(axis=0) + # For activation-magnitude baselines (mean |activation| per channel) + self.sum_abs_y += np.abs(y).sum(axis=0) + self.sum_yy += y.T @ y + self.sum_t += float(t.sum()) + self.sum_tt 
+= float((t * t).sum()) + self.sum_ty += (t[:, None] * y).sum(axis=0) + + def finalize(self) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: + """ + Returns: + var_t: scalar + var_y: [C] + cov_yy: [C, C] + cov_ty: [C] + """ + if self.n < 2: + c = self.sum_y.shape[0] + return 0.0, np.zeros(c), np.zeros((c, c)), np.zeros(c) + + n = float(self.n) + mean_y = self.sum_y / n + mean_t = self.sum_t / n + + # Unbiased covariance estimates (divide by n-1) + cov_yy = (self.sum_yy - n * np.outer(mean_y, mean_y)) / (n - 1.0) + var_y = np.clip(np.diag(cov_yy), 1e-12, None) + + var_t = float((self.sum_tt - n * mean_t * mean_t) / (n - 1.0)) + var_t = max(var_t, 1e-12) + + cov_ty = (self.sum_ty - n * mean_t * mean_y) / (n - 1.0) + return var_t, var_y, cov_yy, cov_ty + + +@dataclass +class ClusterAnalysisConfig: + """Configuration for cluster-based analysis experiments.""" + model_name: str = "resnet18" + dataset_name: str = "cifar10" + n_calibration: int = 5000 + n_clusters: int = 4 + # How to form channel samples from Conv outputs Y[B,C,H,W] + # - "flatten_spatial": treat spatial positions as samples (subsample per image) + # - "gap": global-average-pool per image (one sample per image) + activation_samples: str = "flatten_spatial" + spatial_samples_per_image: int = 16 # used when activation_samples="flatten_spatial" + synergy_target: str = "logit_margin" # logit_margin, correct_logit + # Synergy settings: + # - synergy_candidate_pool: number of candidate partners per channel (chosen by redundancy) + # - synergy_pairs: top-m partners to average (Eq. 
per_channel_syn) + synergy_candidate_pool: int = 50 + synergy_pairs: int = 10 + halo_percentile: float = 90.0 + use_activation_weight: bool = True # Use activation-weighted influence for halos + cascade_n_remove: int = 5 + damage_sample_frac: float = 0.2 + # Pruning experiment settings + pruning_ratios: List[float] = field(default_factory=lambda: [0.1, 0.3, 0.5, 0.7]) + pruning_methods: List[str] = field(default_factory=lambda: [ + 'random', 'magnitude', 'taylor', 'network_slimming', 'composite', 'cluster_aware' + ]) + fine_tune_after_pruning: bool = False # Whether to fine-tune after pruning + fine_tune_epochs: int = 10 + fine_tune_lr: float = 0.0001 + fine_tune_max_batches: Optional[int] = None + fine_tune_weight_decay: float = 0.0 + # Output + output_dir: str = "results/cluster_analysis" + device: str = "cuda" + seed: int = 42 + + +# Backward compatibility alias +VisionExperimentConfig = ClusterAnalysisConfig + + +class ClusterAnalysisExperiment: + """ + General experiment class for cluster-based neural network analysis. + + Works with any architecture that has Conv2d or Linear layers. 
+ + Example: + >>> config = ClusterAnalysisConfig(model_name="resnet18") + >>> exp = ClusterAnalysisExperiment(config, model, train_loader, test_loader) + >>> results = exp.run() + """ + + def __init__( + self, + config: ClusterAnalysisConfig, + model: "nn.Module", + train_loader: "DataLoader", + test_loader: "DataLoader", + ): + self.config = config + self.model = model.to(config.device) + self.train_loader = train_loader + self.test_loader = test_loader + self.device = config.device + + # Results storage + self.layer_metrics = {} + self.cluster_results = {} + self.halo_results = {} + self.halo_flow_results = {} + self.cascade_results = {} + self.pruning_results = {} + self.pruning_cluster_distributions = {} + # Cache for expensive pruning scores (e.g., gradient-based Taylor) + self._pruning_score_cache: Dict[str, Dict[str, "torch.Tensor"]] = {} + + # Setup output directory + self.output_dir = Path(config.output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Get analyzable layers + self.layers = self._get_conv_layers() + logger.info(f"Found {len(self.layers)} convolutional layers") + + def _get_conv_layers(self) -> List[Tuple[str, nn.Module]]: + """Get all Conv2d layers for analysis.""" + layers = [] + for name, module in self.model.named_modules(): + if isinstance(module, nn.Conv2d) and module.out_channels >= 4: + layers.append((name, module)) + return layers + + def compute_metrics(self) -> Dict[str, Dict[str, np.ndarray]]: + """ + Compute per-channel metrics for all layers. 
+ + Returns: + Dict mapping layer_name to dict of metric arrays + """ + logger.info("Computing per-channel metrics (streaming)...") + self.model.eval() + + # Per-layer accumulators (filled lazily once we see a batch for the layer) + accs: Dict[str, _CovAccumulator] = {} + + # Temporary per-batch activations captured by hooks + batch_acts: Dict[str, "torch.Tensor"] = {} + + def hook_fn(name: str): + def fn(_m, _inp, out): + # Store only for this batch; processed after logits are computed + batch_acts[name] = out.detach() + return fn + + # Register hooks + handles = [] + for name, layer in self.layers: + handles.append(layer.register_forward_hook(hook_fn(name))) + + activation_mode = str(getattr(self.config, "activation_samples", "flatten_spatial")).lower() + samples_per_img = int(getattr(self.config, "spatial_samples_per_image", 16)) + samples_per_img = max(1, samples_per_img) + + rng = np.random.default_rng(int(getattr(self.config, "seed", 42))) + + n_seen = 0 + with torch.no_grad(): + for x, y in self.train_loader: + if n_seen >= self.config.n_calibration: + break + + # Trim last batch to hit n_calibration exactly + remaining = int(self.config.n_calibration) - int(n_seen) + if remaining <= 0: + break + if x.size(0) > remaining: + x = x[:remaining] + y = y[:remaining] + + x = x.to(self.device) + y = y.to(self.device) + + batch_acts.clear() + logits = self.model(x) + + # Continuous target T (logit margin) + bsz = logits.size(0) + correct_logits = logits[torch.arange(bsz, device=logits.device), y] + mask = torch.ones_like(logits, dtype=torch.bool) + mask[torch.arange(bsz, device=logits.device), y] = False + max_incorrect = logits.masked_fill(~mask, float("-inf")).max(dim=1)[0] + T_img = (correct_logits - max_incorrect).detach().cpu().numpy() # [B] + + # Update each layer accumulator using the captured activations + for name, layer in self.layers: + out = batch_acts.get(name) + if out is None: + continue + if out.ndim != 4: + continue + + out_cpu = out.detach().cpu() 
# [B, C, H, W] + b, c, h, w = out_cpu.shape + + if activation_mode in {"gap", "global", "global_avg", "global_average"}: + y_s = out_cpu.mean(dim=(2, 3)).numpy() # [B, C] + t_s = T_img + else: + # Spatially-flattened samples, subsampled per image + hw = int(h * w) + p = min(samples_per_img, hw) + # [B, HW, C] as numpy for fast per-image patch subsampling + y_hw_np = out_cpu.permute(0, 2, 3, 1).reshape(b, hw, c).numpy() + if p < hw: + idx = rng.integers(0, hw, size=(b, p), endpoint=False) + row = np.arange(b)[:, None] + y_s = y_hw_np[row, idx, :].reshape(b * p, c) + t_s = np.repeat(T_img, p) + else: + y_s = y_hw_np.reshape(b * hw, c) + t_s = np.repeat(T_img, hw) + + if name not in accs: + accs[name] = _CovAccumulator(n_channels=c) + accs[name].update(y_s, t_s) + + n_seen += int(x.size(0)) + + # Remove hooks + for h in handles: + h.remove() + + # Compute metrics per layer from accumulated Gaussian stats + for name, layer in self.layers: + acc = accs.get(name) + if acc is None: + continue + + var_t, var_y, cov_yy, cov_ty = acc.finalize() + n_channels = int(var_y.shape[0]) + + metrics: Dict[str, np.ndarray] = {} + + # 0) Activation magnitude baselines (computed from the same calibration samples) + # Mean absolute activation per channel (requested baseline) + if acc.n > 0: + metrics["activation_mean"] = (acc.sum_abs_y / float(acc.n))[:n_channels].astype(np.float64) + # RMS activation (close cousin of activation L2 norm; scale doesn't affect ranking) + y2 = np.clip(np.diag(acc.sum_yy) / float(acc.n), 0.0, None) + metrics["activation_rms"] = np.sqrt(y2)[:n_channels].astype(np.float64) + + # 1) Rayleigh Quotient proxy: Var(Y_i) / ||w_i||^2 + weight = layer.weight.data.cpu() # [C_out, C_in, k, k] + weight_flat = weight.view(weight.size(0), -1) # [C_out, ...] 
+ weight_norm = weight_flat.norm(dim=1).numpy().astype(np.float64) ** 2 + rq = var_y / (weight_norm[:n_channels] + 1e-10) + metrics["rq"] = rq.astype(np.float64) + + # 2) Redundancy via Gaussian MI from correlations + denom = np.sqrt(np.outer(var_y, var_y)) + 1e-12 + corr = cov_yy / denom + corr = np.clip(corr, -0.999, 0.999) + mi_matrix = -0.5 * np.log(1.0 - corr ** 2) + np.fill_diagonal(mi_matrix, 0.0) + metrics["redundancy"] = mi_matrix.mean(axis=1).astype(np.float64) + + # 3) Synergy with scalar target under Gaussian approximation (MMI) + # MI(T;Y_i) depends only on corr(T,Y_i) + corr_ty = cov_ty / (np.sqrt(var_t * var_y) + 1e-12) + corr_ty = np.clip(corr_ty, -0.999, 0.999) + mi_t = np.maximum(0.0, -0.5 * np.log(1.0 - corr_ty ** 2)) + + candidate_pool = int(getattr(self.config, "synergy_candidate_pool", 50)) + top_m = int(getattr(self.config, "synergy_pairs", 10)) + candidate_pool = max(2, min(candidate_pool, n_channels)) + top_m = max(1, min(top_m, candidate_pool - 1)) + + synergy = np.zeros(n_channels, dtype=np.float64) + + # Precompute partner ordering by redundancy (MI) per channel + # Use the MI matrix row i, excluding i. 
+ for i in range(n_channels): + order = np.argsort(-mi_matrix[i]) + order = order[order != i] + cand = order[:candidate_pool] + if cand.size == 0: + continue + + mi_i = float(mi_t[i]) + syn_vals: List[float] = [] + for j in cand: + j = int(j) + mi_j = float(mi_t[j]) + cov_i_j = float(cov_yy[i, j]) + mi_joint = self._gaussian_mi_joint_from_stats( + var_t=var_t, + var_i=float(var_y[i]), + var_j=float(var_y[j]), + cov_t_i=float(cov_ty[i]), + cov_t_j=float(cov_ty[j]), + cov_i_j=cov_i_j, + ) + s = mi_joint - mi_i - mi_j + min(mi_i, mi_j) + syn_vals.append(float(s)) + + if syn_vals: + syn_vals.sort(reverse=True) + synergy[i] = float(np.mean(syn_vals[:top_m])) + + metrics["synergy"] = synergy + + self.layer_metrics[name] = metrics + logger.info( + " %s: %d channels (mode=%s, n_samples=%d)", + name, + n_channels, + activation_mode, + acc.n, + ) + + return self.layer_metrics + + def _gaussian_mi(self, x: np.ndarray, y: np.ndarray) -> float: + """Compute Gaussian MI between two variables.""" + rho = np.corrcoef(x, y)[0, 1] + rho = np.clip(rho, -0.999, 0.999) + return max(0, -0.5 * np.log(1 - rho ** 2)) + + def _gaussian_mi_joint(self, t: np.ndarray, y1: np.ndarray, y2: np.ndarray) -> float: + """Compute Gaussian MI I(T; [Y1, Y2]).""" + joint = np.column_stack([t, y1, y2]) + cov = np.cov(joint.T) + 1e-8 * np.eye(3) + var_t = cov[0, 0] + cov_y = cov[1:, 1:] + det_all = np.linalg.det(cov) + det_y = np.linalg.det(cov_y) + if det_all <= 0 or det_y <= 0 or var_t <= 0: + return 0. 
+ return max(0, 0.5 * np.log(var_t * det_y / det_all)) + + def _gaussian_mi_joint_from_stats( + self, + *, + var_t: float, + var_i: float, + var_j: float, + cov_t_i: float, + cov_t_j: float, + cov_i_j: float, + ) -> float: + """Gaussian MI I(T; [Y_i, Y_j]) from covariance statistics (no raw samples).""" + # 3x3 covariance matrix for (T, Y_i, Y_j) + cov = np.array( + [ + [var_t, cov_t_i, cov_t_j], + [cov_t_i, var_i, cov_i_j], + [cov_t_j, cov_i_j, var_j], + ], + dtype=np.float64, + ) + cov += 1e-10 * np.eye(3) + cov_y = np.array([[var_i, cov_i_j], [cov_i_j, var_j]], dtype=np.float64) + cov_y += 1e-10 * np.eye(2) + + det_all = float(np.linalg.det(cov)) + det_y = float(np.linalg.det(cov_y)) + if det_all <= 0.0 or det_y <= 0.0 or var_t <= 0.0: + return 0.0 + return max(0.0, 0.5 * float(np.log(var_t * det_y / det_all))) + + def run_clustering(self) -> Dict[str, Any]: + """Cluster channels in each layer.""" + logger.info("Clustering channels...") + + clusterer = MetricSpaceClustering( + n_clusters=self.config.n_clusters, + seed=self.config.seed, + ) + + for name, metrics in self.layer_metrics.items(): + result = clusterer.fit( + metrics["rq"], + metrics["redundancy"], + metrics["synergy"], + name, + ) + self.cluster_results[name] = { + "labels": result.labels, + "centroids": result.centroids, + "silhouette": result.silhouette, + "type_mapping": result.type_mapping, + "type_counts": result.type_counts, + "layer_name": name, + } + logger.info(f" {name}: silhouette={result.silhouette:.3f}, types={result.type_counts}") + + return self.cluster_results + + def run_halo_analysis(self) -> Dict[str, Any]: + """ + Analyze cross-layer halos with activation-weighted influence. + + Uses effective influence: ||W||_1 * std(Y) to account for + batch normalization scaling effects. 
+ """ + logger.info("Analyzing cross-layer halos...") + + halo_analyzer = CrossLayerHaloAnalysis( + percentile=self.config.halo_percentile, + use_activation_weight=getattr(self.config, 'use_activation_weight', True), + ) + + layer_names = list(self.cluster_results.keys()) + modules = dict(self.model.named_modules()) + + # Choose halo transitions along *direct weight-connected* edges by matching channel dimensions. + # This avoids spurious transitions in residual blocks (e.g., conv2 -> downsample conv), + # while still supporting skip-branch convs as valid sources into the next block. + for i, src_name in enumerate(layer_names[:-1]): + src_layer = modules.get(src_name) + if src_layer is None or not hasattr(src_layer, "weight"): + continue + + src_out = int(src_layer.weight.shape[0]) + + tgt_name = None + for j in range(i + 1, len(layer_names)): + cand_name = layer_names[j] + cand_layer = modules.get(cand_name) + if cand_layer is None or not hasattr(cand_layer, "weight"): + continue + w = cand_layer.weight + if w is None or w.ndim < 2: + continue + cand_in = int(w.shape[1]) + if cand_in == src_out: + tgt_name = cand_name + break + + if tgt_name is None: + continue + + src_result = self.cluster_results[src_name] + tgt_result = self.cluster_results.get(tgt_name, {}) + src_metrics = self.layer_metrics.get(src_name, {}) + tgt_metrics = self.layer_metrics.get(tgt_name, {}) + + if not tgt_metrics: + continue + + # Get weight matrix between layers + tgt_layer = modules[tgt_name] + tgt_weight = tgt_layer.weight.data.cpu().numpy() + n_out, n_in = tgt_weight.shape[0], tgt_weight.shape[1] + + # Base influence: L1 norm over kernel dimensions + influence = np.abs(tgt_weight.reshape(n_out, n_in, -1)).sum(axis=2) + + # Apply activation weighting (effective influence = weight * std) + # This accounts for BN scaling: channels with large gamma/sqrt(var) + # have larger effective signal even if outgoing weights are small + # Activation-weighted influence proxy. 
+ # We approximate sigma_i as the (post-BN when present) channel std: + # sigma_conv = sqrt(RQ_i * ||w_i||^2) (since RQ_i = Var(Y_i)/||w_i||^2) + # sigma_postBN ≈ sigma_conv * |gamma| / sqrt(running_var + eps) + if "rq" in src_metrics: + w_src = src_layer.weight.data.cpu().numpy().astype(np.float64) + w_norm_sq = np.sum(w_src.reshape(w_src.shape[0], -1) ** 2, axis=1) + rq = np.asarray(src_metrics["rq"], dtype=np.float64).reshape(-1) + sigma = np.sqrt(np.clip(rq[: len(w_norm_sq)] * w_norm_sq[: len(rq)], 0.0, None)) + + bn = self._find_bn_for_conv(self.model, src_name) + if bn is not None and hasattr(bn, "weight") and hasattr(bn, "running_var"): + gamma = bn.weight.detach().cpu().numpy().astype(np.float64) + rv = bn.running_var.detach().cpu().numpy().astype(np.float64) + eps = float(getattr(bn, "eps", 1e-5)) + scale = np.abs(gamma) / np.sqrt(rv + eps) + m = min(len(sigma), len(scale)) + sigma[:m] = sigma[:m] * scale[:m] + + n_in_actual = min(n_in, len(sigma)) + influence[:, :n_in_actual] = influence[:, :n_in_actual] * sigma[:n_in_actual] + + halo_data = {} + for cid, ctype in src_result["type_mapping"].items(): + cluster_idx = np.where(src_result["labels"] == cid)[0] + if len(cluster_idx) == 0 or cluster_idx.max() >= n_in: + continue + + halo_idx, rel_infl = halo_analyzer.find_halo(influence, cluster_idx) + if len(halo_idx) == 0: + continue + + halo_data[ctype] = { + "halo_size": len(halo_idx), + "halo_red": float(np.mean(tgt_metrics["redundancy"][halo_idx])), + "halo_syn": float(np.mean(tgt_metrics["synergy"][halo_idx])), + "cluster_type": ctype, + } + + self.halo_results[f"{src_name}->{tgt_name}"] = halo_data + logger.info(f" {src_name}->{tgt_name}: {len(halo_data)} cluster halos analyzed") + + # Also compute cluster-to-cluster flow matrix (for influence heatmaps) + try: + src_labels = np.asarray(src_result.get("labels", np.array([], dtype=int))).astype(int) + tgt_labels = np.asarray(tgt_result.get("labels", np.array([], dtype=int))).astype(int) + if 
src_labels.size > 0 and tgt_labels.size > 0: + # Trim labels to match influence matrix dimensions if needed + src_labels = src_labels[: min(len(src_labels), n_in)] + tgt_labels = tgt_labels[: min(len(tgt_labels), n_out)] + + flow = halo_analyzer.compute_cluster_to_cluster_flow( + influence, + source_labels=src_labels, + target_labels=tgt_labels, + source_types=src_result.get("type_mapping", {}), + target_types=tgt_result.get("type_mapping", {}), + ) + self.halo_flow_results[f"{src_name}->{tgt_name}"] = flow + except Exception as exc: + logger.debug("Could not compute halo flow matrix for %s->%s: %s", src_name, tgt_name, exc) + + return self.halo_results + + def run_cascade_test(self) -> Dict[str, Any]: + """Run cascade damage test by cluster type.""" + logger.info("Running cascade tests...") + + cascade = CascadeAnalysis(self.model, self.test_loader, self.device) + cascade.baseline() + + for name, cluster_data in self.cluster_results.items(): + results = cascade.by_cluster( + name, + cluster_data["labels"], + cluster_data["type_mapping"], + n_rm=self.config.cascade_n_remove, + ) + self.cascade_results[name] = { + ctype: { + "accuracy_drop": r.accuracy_drop, + "loss_increase": r.loss_increase, + "n_removed": r.n_removed, + } + for ctype, r in results.items() + } + logger.info(f" {name}: {len(results)} cluster types tested") + + return self.cascade_results + + def run_pruning_experiments( + self, + ratios: Optional[List[float]] = None, + methods: Optional[List[str]] = None, + fine_tune_epochs: int = 0, + fine_tune_lr: float = 0.0001, + fine_tune_max_batches: Optional[int] = None, + fine_tune_weight_decay: float = 0.0, + ) -> Dict[str, Any]: + """ + Run pruning experiments comparing different methods. 
+ + Args: + ratios: Sparsity ratios to test (default: [0.3, 0.5, 0.7]) + methods: Pruning methods to compare (default: all) + fine_tune_epochs: Number of fine-tuning epochs after pruning + fine_tune_lr: Learning rate for fine-tuning (unused when fine_tune_epochs=0) + + Returns: + Dict mapping (method, ratio) to accuracy results + """ + import copy + + ratios = ratios or getattr(self.config, "pruning_ratios", None) \ + or getattr(self.config, "pruning_amounts", None) \ + or [0.1, 0.3, 0.5, 0.7] + + default_methods = [ + "random", "magnitude", + "network_slimming", + "rq_low", "rq_high", + "redundancy_low", "redundancy_high", + "synergy_low", "synergy_high", + "composite", "composite_pos_red", + "rq_minus_red", "rq_plus_red", + "magnitude_plus_rq", "magnitude_minus_red", "magnitude_plus_red", + ] + methods = methods or getattr(self.config, "pruning_methods", None) \ + or getattr(self.config, "pruning_algorithms", None) \ + or getattr(self.config, "pruning_strategies", None) \ + or default_methods + + pipeline_options = PruningPipelineOptions( + distribution=getattr(self.config, "pruning_distribution", "uniform"), + dependency_aware=bool(getattr(self.config, "dependency_aware_pruning", False)), + min_amount=getattr(self.config, "pruning_min_per_layer", 0.0), + max_amount=getattr(self.config, "pruning_max_per_layer", 0.95), + ) + + baseline_acc = self._evaluate_accuracy() + logger.info(f"Baseline accuracy: {baseline_acc:.2%}") + + if baseline_acc < 0.7: + logger.warning("Baseline accuracy is low; pruning comparisons may be noisy.") + + results = {"baseline": baseline_acc, "methods": {}} + + for method in methods: + logger.info(f"Running pruning method: {method}") + method_results = {} + results["methods"][method] = method_results + + for ratio in ratios: + logger.info(f" Target sparsity: {ratio:.0%}") + model_copy = copy.deepcopy(self.model) + layer_modules = self._get_layer_module_map(model_copy) + selection_mode = self._selection_mode_for_method(method) + + try: + if 
method.startswith("cluster_aware"): + pipeline_result = self._run_cluster_aware_pruning( + model_copy, + layer_modules=layer_modules, + ratio=ratio, + method=method, + ) + else: + layer_scores = self._compute_layer_scores_for_method(method, model_copy) + if not layer_scores: + raise ValueError("No layer scores available for method") + + pipeline_result = run_pruning_pipeline( + model_copy, + layer_scores, + layer_modules=layer_modules, + target_sparsity=ratio, + selection_mode=selection_mode, + options=pipeline_options, + ) + + self._zero_batchnorm_from_masks(model_copy, pipeline_result.get("masks", {})) + + acc_before = self._evaluate_accuracy(model_copy) + acc_after = acc_before + if fine_tune_epochs > 0: + model_copy = self._fine_tune( + model_copy, + epochs=fine_tune_epochs, + lr=fine_tune_lr, + max_batches=fine_tune_max_batches, + weight_decay=fine_tune_weight_decay, + masks=pipeline_result.get("masks", {}) if isinstance(pipeline_result, dict) else None, + ) + acc_after = self._evaluate_accuracy(model_copy) + + method_results[ratio] = { + "accuracy_before_ft": acc_before, + "accuracy_after_ft": acc_after, + "accuracy_drop": baseline_acc - acc_before, + "accuracy_recovery": acc_after - acc_before if fine_tune_epochs > 0 else 0.0, + "selection_mode": selection_mode, + "mask_stats": pipeline_result.get("stats", {}), + } + + logger.info(" Result: %.2f%% (drop %.2f%%)", acc_after * 100, (baseline_acc - acc_after) * 100) + except Exception as exc: + logger.warning(" Pruning failed for %s @ %.0f%%: %s", method, ratio * 100, exc) + method_results[ratio] = {"error": str(exc)} + finally: + del model_copy + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + self.pruning_results = results + with open(self.output_dir / "pruning_results.json", "w") as f: + json.dump(results, f, indent=2, default=str) + return results + + def _get_layer_module_map(self, model: nn.Module) -> Dict[str, nn.Module]: + modules = dict(model.named_modules()) + return {name: 
modules.get(name) for name, _ in self.layers if name in modules} + + def _selection_mode_for_method(self, method: str) -> str: + if method == "random": + return "random" + high_methods = {"rq_high", "redundancy_high", "synergy_high", "magnitude_high", "activation_l2_norm_high"} + if method in high_methods: + return "high" + return "low" + + def _compute_taylor_channel_scores(self, model: nn.Module) -> Dict[str, "torch.Tensor"]: + """ + Compute per-output-channel Taylor saliency scores for each analyzed conv layer. + + Uses weight-based first-order Taylor approximation: + score_i = sum_d | w_i[d] * grad_w_i[d] | + + Computed over a small calibration subset from self.train_loader. + """ + if not HAS_TORCH: + return {} + + # Keep this small by default; configurable via config if present. + max_samples = int(getattr(self.config, "taylor_samples", 1024)) + max_samples = max(1, max_samples) + + model = model.to(self.device) + model.eval() + + criterion = nn.CrossEntropyLoss() + model.zero_grad(set_to_none=True) + + n_seen = 0 + for x, y in self.train_loader: + if n_seen >= max_samples: + break + + remaining = max_samples - n_seen + if x.size(0) > remaining: + x = x[:remaining] + y = y[:remaining] + + x = x.to(self.device) + y = y.to(self.device) + + logits = model(x) + loss = criterion(logits, y) + loss.backward() + n_seen += int(x.size(0)) + + modules = dict(model.named_modules()) + out: Dict[str, "torch.Tensor"] = {} + for name, _layer in self.layers: + m = modules.get(name) + if m is None or not hasattr(m, "weight") or m.weight is None: + continue + if m.weight.grad is None: + continue + g = m.weight.grad.detach() + w = m.weight.detach() + # Reduce to [C_out] + score = (g * w).abs().view(w.shape[0], -1).sum(dim=1) + out[name] = score.detach().cpu() + + model.zero_grad(set_to_none=True) + return out + + def _compute_geometric_median_channel_scores(self, model: nn.Module) -> Dict[str, "torch.Tensor"]: + """ + Geometric-median (FPGM-style) per-channel importance for Conv 
layers. + + For each conv layer, treat each output channel filter as a vector and compute + the geometric median m (Weiszfeld). Channels closest to m are considered + more redundant; we prune LOW distances. + """ + if not HAS_TORCH: + return {} + + # Weiszfeld settings (keep small; this is run once and cached) + iters = int(getattr(self.config, "geometric_median_iters", 10)) + iters = max(1, min(iters, 50)) + eps = float(getattr(self.config, "geometric_median_eps", 1e-8)) + eps = max(eps, 1e-12) + + modules = dict(model.named_modules()) + out: Dict[str, "torch.Tensor"] = {} + for name, _layer in self.layers: + m = modules.get(name) + if m is None or not hasattr(m, "weight") or m.weight is None: + continue + + w = m.weight.detach().float().cpu() + if w.ndim < 2: + continue + x = w.view(w.shape[0], -1) # [C_out, D] + if x.numel() == 0: + continue + + # Initialize at the mean + med = x.mean(dim=0) + for _ in range(iters): + d = torch.norm(x - med, p=2, dim=1).clamp_min(eps) # [C_out] + inv = 1.0 / d + med = (inv[:, None] * x).sum(dim=0) / inv.sum() + + # Importance = distance to median (prune low) + dist = torch.norm(x - med, p=2, dim=1) + out[name] = dist.detach().cpu() + + return out + + def _compute_hrank_channel_scores(self, model: nn.Module) -> Dict[str, "torch.Tensor"]: + """ + HRank-style baseline: per-channel average feature-map rank. + + We approximate the rank of each channel's feature map by: + - adaptive average pooling to (p x p) + - computing matrix rank via singular values on the pooled map + - averaging across a small calibration subset + + Channels with LOW average rank are pruned. 
+ """ + if not HAS_TORCH: + return {} + + import torch.nn.functional as F + + max_images = int(getattr(self.config, "hrank_images", 256)) + max_images = max(1, max_images) + pool = int(getattr(self.config, "hrank_pool", 8)) + pool = max(2, min(pool, 32)) + sv_eps = float(getattr(self.config, "hrank_sv_eps", 1e-3)) + sv_eps = max(sv_eps, 1e-6) + + model = model.to(self.device) + model.eval() + + modules = dict(model.named_modules()) + + rank_sum: Dict[str, "torch.Tensor"] = {} + rank_count: Dict[str, int] = {} + + def _svdvals(x: "torch.Tensor") -> "torch.Tensor": + # Batched singular values. Fall back gracefully across torch versions. + try: + return torch.linalg.svdvals(x) + except Exception: + try: + # torch.linalg.svd returns U,S,Vh + return torch.linalg.svd(x, full_matrices=False).S + except Exception: + # Old fallback + return torch.svd(x).S + + def hook_fn(layer_name: str): + def fn(_m, _inp, out): + # out: [B,C,H,W] + try: + if out is None or out.ndim != 4: + return + out_f = out.float() + b, c, _h, _w = out_f.shape + pooled = F.adaptive_avg_pool2d(out_f, (pool, pool)) # [B,C,p,p] + mats = pooled.reshape(b * c, pool, pool) # [B*C,p,p] + + sv = _svdvals(mats) # [B*C,p] + thr = sv.max(dim=1).values * sv_eps + 1e-12 # [B*C] + r = (sv > thr[:, None]).sum(dim=1).float() # [B*C] + r = r.view(b, c).sum(dim=0).detach().cpu().double() # [C] + + if layer_name not in rank_sum: + rank_sum[layer_name] = torch.zeros(c, dtype=torch.float64) + rank_count[layer_name] = 0 + rank_sum[layer_name] += r + rank_count[layer_name] += int(b) + except Exception as exc: + logger.debug("HRank hook failed for %s (%s)", layer_name, exc) + return fn + + handles = [] + for name, _layer in self.layers: + m = modules.get(name) + if isinstance(m, nn.Conv2d): + handles.append(m.register_forward_hook(hook_fn(name))) + + n_seen = 0 + with torch.no_grad(): + for x, _y in self.train_loader: + if n_seen >= max_images: + break + + remaining = max_images - n_seen + if x.size(0) > remaining: + x = 
x[:remaining] + + x = x.to(self.device) + + _ = model(x) + + bsz = int(x.size(0)) + n_seen += bsz + + for h in handles: + h.remove() + + out_scores: Dict[str, "torch.Tensor"] = {} + for lname, s in rank_sum.items(): + cnt = int(rank_count.get(lname, 0)) + if cnt <= 0: + continue + out_scores[lname] = (s / float(cnt)).float().cpu() + + return out_scores + + def _compute_layer_scores_for_method(self, method: str, model: nn.Module) -> Dict[str, torch.Tensor]: + layer_scores: Dict[str, torch.Tensor] = {} + modules = self._get_layer_module_map(model) + metric_map = { + "rq_low": "rq", + "rq_high": "rq", + "redundancy_low": "redundancy", + "redundancy_high": "redundancy", + "synergy_low": "synergy", + "synergy_high": "synergy", + } + for name, layer in modules.items(): + if layer is None or not hasattr(layer, "weight"): + continue + weight = layer.weight + device = weight.device + metrics = self.layer_metrics.get(name, {}) + n_channels = weight.shape[0] + + if method == "random": + layer_scores[name] = torch.rand(n_channels, device=device) + elif method in {"activation_mean", "activation_rms"}: + values = metrics.get(method) + if values is None: + continue + layer_scores[name] = torch.as_tensor(values, dtype=torch.float32, device=device) + elif method in {"magnitude", "magnitude_high", "activation_l2_norm", "activation_l2_norm_high"}: + w_flat = weight.view(n_channels, -1) + mags = torch.norm(w_flat, p=2, dim=1) + layer_scores[name] = mags + elif method in {"network_slimming", "bn_scale"}: + # Network Slimming baseline: prune small BN gamma (|gamma|) + bn = self._find_bn_for_conv(model, name) + if bn is not None and hasattr(bn, "weight") and bn.weight is not None: + gamma = bn.weight.detach().abs() + if gamma.numel() != n_channels: + logger.warning( + "BN gamma size mismatch for %s: gamma=%d, channels=%d; falling back to weight magnitude", + name, + int(gamma.numel()), + int(n_channels), + ) + w_flat = weight.view(n_channels, -1) + layer_scores[name] = torch.norm(w_flat, 
p=2, dim=1) + else: + layer_scores[name] = gamma.to(device=device, dtype=torch.float32) + else: + # Fallback for layers without BN (rare in these vision backbones) + w_flat = weight.view(n_channels, -1) + layer_scores[name] = torch.norm(w_flat, p=2, dim=1) + elif method == "taylor": + # Gradient-based baseline. Compute once per experiment and cache on CPU. + if "taylor" not in self._pruning_score_cache: + try: + self._pruning_score_cache["taylor"] = self._compute_taylor_channel_scores(model) + except Exception as exc: + logger.warning("Taylor score computation failed (%s); falling back to magnitude", exc) + self._pruning_score_cache["taylor"] = {} + cpu_scores = self._pruning_score_cache.get("taylor", {}).get(name) + if cpu_scores is None or cpu_scores.numel() != n_channels: + # Fallback: weight magnitude if we couldn't compute gradients or mismatch + w_flat = weight.view(n_channels, -1) + layer_scores[name] = torch.norm(w_flat, p=2, dim=1) + else: + layer_scores[name] = cpu_scores.to(device=device, dtype=torch.float32) + elif method in {"geometric_median", "fpgm"}: + cache_key = "geometric_median" + if cache_key not in self._pruning_score_cache: + try: + self._pruning_score_cache[cache_key] = self._compute_geometric_median_channel_scores(model) + except Exception as exc: + logger.warning("Geometric median score computation failed (%s); falling back to magnitude", exc) + self._pruning_score_cache[cache_key] = {} + cpu_scores = self._pruning_score_cache.get(cache_key, {}).get(name) + if cpu_scores is None or cpu_scores.numel() != n_channels: + w_flat = weight.view(n_channels, -1) + layer_scores[name] = torch.norm(w_flat, p=2, dim=1) + else: + layer_scores[name] = cpu_scores.to(device=device, dtype=torch.float32) + elif method == "hrank": + cache_key = "hrank" + if cache_key not in self._pruning_score_cache: + try: + self._pruning_score_cache[cache_key] = self._compute_hrank_channel_scores(model) + except Exception as exc: + logger.warning("HRank score computation 
failed (%s); falling back to magnitude", exc) + self._pruning_score_cache[cache_key] = {} + cpu_scores = self._pruning_score_cache.get(cache_key, {}).get(name) + if cpu_scores is None or cpu_scores.numel() != n_channels: + w_flat = weight.view(n_channels, -1) + layer_scores[name] = torch.norm(w_flat, p=2, dim=1) + else: + layer_scores[name] = cpu_scores.to(device=device, dtype=torch.float32) + elif method in metric_map: + values = metrics.get(metric_map[method]) + if values is None: + continue + layer_scores[name] = torch.as_tensor(values, dtype=torch.float32, device=device) + elif method in { + "composite", + "composite_pos_red", + "rq_minus_red", + "rq_plus_red", + "magnitude_plus_rq", + "magnitude_minus_red", + "magnitude_plus_red", + }: + comp = self._compute_composite_metric(method, metrics, layer) + if comp is not None: + layer_scores[name] = comp.to(device) + else: + logger.warning("Unknown pruning method '%s'; skipping layer scores", method) + return {} + return layer_scores + + def _compute_composite_metric(self, method: str, metrics: Dict[str, np.ndarray], layer: nn.Module) -> Optional[torch.Tensor]: + rq = np.log(np.clip(metrics.get("rq", np.ones(layer.weight.shape[0])), 1e-10, None)) + redundancy = metrics.get("redundancy", np.zeros_like(rq)) + synergy = metrics.get("synergy", np.zeros_like(rq)) + + def normalize(arr: np.ndarray) -> np.ndarray: + if arr.size == 0: + return arr + min_v = arr.min() + max_v = arr.max() + if max_v - min_v < 1e-8: + return np.zeros_like(arr) + return (arr - min_v) / (max_v - min_v) + + rq_norm = normalize(rq) + red_norm = normalize(redundancy) + syn_norm = normalize(synergy) + + if method == "composite": + scores = rq_norm + 0.5 * syn_norm - 0.3 * red_norm + elif method == "composite_pos_red": + scores = rq_norm + 0.5 * syn_norm + 0.3 * red_norm + elif method == "rq_minus_red": + scores = rq_norm - 0.5 * red_norm + elif method == "rq_plus_red": + scores = rq_norm + 0.5 * red_norm + elif method == "magnitude_plus_rq": + w = 
layer.weight.detach().view(layer.weight.shape[0], -1) + mag = normalize(w.norm(p=2, dim=1).cpu().numpy()) + scores = mag + 0.5 * rq_norm + elif method == "magnitude_minus_red": + w = layer.weight.detach().view(layer.weight.shape[0], -1) + mag = normalize(w.norm(p=2, dim=1).cpu().numpy()) + scores = mag - 0.3 * red_norm + elif method == "magnitude_plus_red": + w = layer.weight.detach().view(layer.weight.shape[0], -1) + mag = normalize(w.norm(p=2, dim=1).cpu().numpy()) + scores = mag + 0.3 * red_norm + else: + return None + + return torch.as_tensor(scores, dtype=torch.float32) + + def _compute_halo_syn_proxy( + self, + *, + layer_name: str, + layer: nn.Module, + next_layer: Optional[nn.Module], + next_layer_name: Optional[str], + halo_percentile: float, + use_activation_weight: bool, + ) -> np.ndarray: + """ + Compute per-channel HaloSyn proxy without needing raw activations. + + Uses effective influence: + influence[j,i] = ||W_{j,i}||_1 * sigma_i + + Where sigma_i is approximated from cached RQ and weight norms: + Var(Y_i) = RQ_i * ||w_i||^2 => sigma_i = sqrt(Var(Y_i)) + """ + metrics = self.layer_metrics.get(layer_name, {}) + rq = np.asarray(metrics.get("rq", np.array([])), dtype=np.float64).reshape(-1) + if rq.size == 0 or next_layer is None or next_layer_name is None or not hasattr(next_layer, "weight"): + return np.zeros(int(layer.weight.shape[0]), dtype=np.float64) + + # sigma proxy from rq and weight norms (and BN scaling when present) + w = layer.weight.detach().view(layer.weight.shape[0], -1).cpu().numpy().astype(np.float64) + w_norm_sq = np.sum(w * w, axis=1) + sigma = np.sqrt(np.clip(rq[: len(w_norm_sq)] * w_norm_sq[: len(rq)], 0.0, None)) + + bn = self._find_bn_for_conv(self.model, layer_name) + if bn is not None and hasattr(bn, "weight") and hasattr(bn, "running_var"): + gamma = bn.weight.detach().cpu().numpy().astype(np.float64) + rv = bn.running_var.detach().cpu().numpy().astype(np.float64) + eps = float(getattr(bn, "eps", 1e-5)) + scale = 
np.abs(gamma) / np.sqrt(rv + eps) + m = min(len(sigma), len(scale)) + sigma[:m] = sigma[:m] * scale[:m] + + w_next = next_layer.weight.detach().cpu().numpy().astype(np.float64) + if w_next.ndim == 4: + influence = np.abs(w_next).sum(axis=(2, 3)) # [out, in] + elif w_next.ndim == 3: + influence = np.abs(w_next).sum(axis=2) # [out, in] + else: + influence = np.abs(w_next) + + # Apply activation weighting via sigma_i (effective influence) + if use_activation_weight: + n_in = min(influence.shape[1], sigma.shape[0]) + influence[:, :n_in] = influence[:, :n_in] * sigma[:n_in][None, :] + + next_metrics = self.layer_metrics.get(next_layer_name, {}) if next_layer_name else {} + next_syn = np.asarray(next_metrics.get("synergy", np.array([])), dtype=np.float64).reshape(-1) + if next_syn.size == 0: + next_syn = np.zeros(influence.shape[0], dtype=np.float64) + else: + next_syn = next_syn[: influence.shape[0]] + + halo_syn = np.zeros(int(layer.weight.shape[0]), dtype=np.float64) + total_infl = influence.sum(axis=1) + 1e-10 + pct = float(halo_percentile) + for i in range(min(halo_syn.shape[0], influence.shape[1])): + rel_infl = influence[:, i] / total_infl + thresh = np.percentile(rel_infl, pct) + mask = rel_infl >= thresh + if mask.sum() > 0: + halo_syn[i] = float(np.mean(next_syn[mask])) + return halo_syn + + def _run_cluster_aware_pruning( + self, + model: nn.Module, + *, + layer_modules: Dict[str, nn.Module], + ratio: float, + method: str, + ) -> Dict[str, Any]: + """ + Apply cluster-aware pruning using the paper strategy (halo score + constraints). + + Returns a pipeline-like dict with: + - masks: {layer_name: [C] mask} + - stats: {layer_name: mask stats} + Also stores a pruned-by-cluster summary under self.pruning_cluster_distributions. 
+ """ + from ..pruning.strategies.cluster_aware import ClusterAwarePruning, ClusterAwarePruningConfig + from ..services.mask_ops import MaskOperations + + # Base config + cfg = ClusterAwarePruningConfig(amount=float(ratio), structured=True) + + # Variants for ablations / controls + if method == "cluster_aware_no_halo": + cfg.lambda_halo = 0.0 + elif method == "cluster_aware_no_constraints": + cfg.protect_critical_frac = 1.0 + cfg.target_redundant = False + cfg.synergy_pair_constraint = False + elif method == "cluster_aware_protect_redundant": + # Inverted priority (rough proxy): do not preferentially prune redundant/background + cfg.target_redundant = False + elif method == "cluster_aware_annealed": + # Anneal constraints + mix in a strong low-sparsity baseline (Taylor) so we + # behave like Taylor/Magnitude at low sparsity and like Cluster-aware at high sparsity. + # + # anneal_w(r)=0 below start, 1 above end. + start = float(getattr(self.config, "cluster_aware_anneal_start", 0.70)) + end = float(getattr(self.config, "cluster_aware_anneal_end", 0.90)) + if end <= start: + end = start + 1e-6 + if ratio <= start: + w_anneal = 0.0 + elif ratio >= end: + w_anneal = 1.0 + else: + w_anneal = float((ratio - start) / (end - start)) + + # Constraints: off at low sparsity, on at high sparsity + base_lambda = float(cfg.lambda_halo) + base_protect = float(cfg.protect_critical_frac) + cfg.lambda_halo = base_lambda * w_anneal + cfg.protect_critical_frac = 1.0 - w_anneal * (1.0 - base_protect) + cfg.target_redundant = bool(w_anneal >= 0.5) + cfg.synergy_pair_constraint = bool(w_anneal >= 0.5) + + # Allow paper scripts / SLURM jobs to sweep score weights via config overrides + cfg.alpha = float(getattr(self.config, "cluster_aware_alpha", cfg.alpha)) + cfg.beta = float(getattr(self.config, "cluster_aware_beta", cfg.beta)) + cfg.gamma = float(getattr(self.config, "cluster_aware_gamma", cfg.gamma)) + cfg.lambda_halo = float(getattr(self.config, "cluster_aware_lambda_halo", 
cfg.lambda_halo)) + cfg.protect_critical_frac = float(getattr(self.config, "cluster_aware_protect_critical_frac", cfg.protect_critical_frac)) + + # Keep halo settings consistent with experiment config unless overridden + cfg.halo_percentile = float(getattr(self.config, "halo_percentile", cfg.halo_percentile)) + cfg.use_activation_weight = bool(getattr(self.config, "use_activation_weight", cfg.use_activation_weight)) + cfg.n_clusters = int(getattr(self.config, "n_clusters", cfg.n_clusters)) + + masks: Dict[str, torch.Tensor] = {} + stats: Dict[str, Any] = {} + + # Aggregate pruning distribution by cluster type + by_type_pruned: Dict[str, int] = {} + by_type_total: Dict[str, int] = {} + + layer_names = [nm for nm, _ in self.layers] + module_map = dict(model.named_modules()) + + # ------------------------------------------------------------------ + # Respect the same pruning distribution knobs as the baseline pipeline. + # + # - pruning_distribution controls how much to prune per layer (uniform, + # global_threshold, size_proportional, importance_weighted, ...) + # - pruning_{min,max}_per_layer bound the per-layer amounts + # + # For score-dependent strategies (global_threshold / importance_weighted), + # we compute the per-layer cluster-aware scores first, then allocate per + # layer amounts from those scores. 
+ # ------------------------------------------------------------------ + distribution = getattr(self.config, "pruning_distribution", "uniform") + min_amount = float(getattr(self.config, "pruning_min_per_layer", 0.0)) + max_amount = float(getattr(self.config, "pruning_max_per_layer", 0.95)) + + # First pass: compute per-layer cluster-aware scores (no pruning yet) + layer_scores: Dict[str, torch.Tensor] = {} + layer_pruners: Dict[str, "ClusterAwarePruning"] = {} + layer_num_channels: Dict[str, int] = {} + + for idx, layer_name in enumerate(layer_names): + layer = module_map.get(layer_name) + if layer is None or not hasattr(layer, "weight") or layer.weight is None: + continue + + n_channels = int(layer.weight.shape[0]) + layer_num_channels[layer_name] = n_channels + + # Pick the next *weight-connected* layer by matching channel dimensions (same logic as halo analysis). + src_out = int(layer.weight.shape[0]) + next_layer_name = None + for j in range(idx + 1, len(layer_names)): + cand_name = layer_names[j] + cand_layer = module_map.get(cand_name) + if cand_layer is None or not hasattr(cand_layer, "weight"): + continue + w = cand_layer.weight + if w is None or w.ndim < 2: + continue + if int(w.shape[1]) == src_out: + next_layer_name = cand_name + break + next_layer = module_map.get(next_layer_name) if next_layer_name else None + + # Cached metrics + clusters from the original (unpruned) analysis + pre_metrics = self.layer_metrics.get(layer_name, {}) + pre_clusters = self.cluster_results.get(layer_name, {}) + + labels = np.asarray(pre_clusters.get("labels", np.zeros(n_channels, dtype=int))).astype(int) + type_mapping = pre_clusters.get("type_mapping", {}) + + # HaloSyn proxy (uses sigma from RQ and next-layer synergy) + halo_syn = self._compute_halo_syn_proxy( + layer_name=layer_name, + layer=layer, + next_layer=next_layer, + next_layer_name=next_layer_name, + halo_percentile=cfg.halo_percentile, + use_activation_weight=cfg.use_activation_weight, + ) + + pruner = 
ClusterAwarePruning( + cfg, + precomputed_metrics=pre_metrics, + precomputed_clusters={"labels": labels, "type_mapping": type_mapping}, + precomputed_halos={"halo_syn": halo_syn}, + ) + + scores = pruner.compute_importance_scores( + layer, + outputs=None, # halo syn is precomputed + next_layer_weights=next_layer.weight if next_layer is not None else None, + next_layer_metrics=self.layer_metrics.get(next_layer_name, {}) if next_layer_name else None, + layer_name=layer_name, + ) + + # Optional annealed mixing: blend cluster-aware score with Taylor at low sparsity. + if method == "cluster_aware_annealed": + # Ensure Taylor cache exists (computed once; reused across ratios/methods). + if "taylor" not in self._pruning_score_cache: + try: + self._pruning_score_cache["taylor"] = self._compute_taylor_channel_scores(self.model) + except Exception: + self._pruning_score_cache["taylor"] = {} + + t_cpu = (self._pruning_score_cache.get("taylor", {}) or {}).get(layer_name) + if t_cpu is None or (hasattr(t_cpu, "numel") and int(t_cpu.numel()) != int(n_channels)): + # Fallback to weight magnitude if Taylor is unavailable/mismatched + w_flat = layer.weight.detach().view(n_channels, -1) + t = w_flat.norm(p=2, dim=1).detach().cpu() + else: + t = t_cpu.detach().cpu() + + # Normalize both to [0,1] per-layer for stable mixing + def _minmax(x: "torch.Tensor") -> "torch.Tensor": + x = x.float() + if x.numel() == 0: + return x + mn = float(x.min().item()) + mx = float(x.max().item()) + if mx - mn < 1e-12: + return torch.zeros_like(x) + return (x - mn) / (mx - mn) + + s_ca = _minmax(scores.detach().cpu()) + s_t = _minmax(t) + + start = float(getattr(self.config, "cluster_aware_anneal_start", 0.70)) + end = float(getattr(self.config, "cluster_aware_anneal_end", 0.90)) + if end <= start: + end = start + 1e-6 + if ratio <= start: + w_anneal = 0.0 + elif ratio >= end: + w_anneal = 1.0 + else: + w_anneal = float((ratio - start) / (end - start)) + + mixed = (1.0 - w_anneal) * s_t + w_anneal * 
s_ca + scores = mixed.to(device=scores.device) + + layer_scores[layer_name] = scores.detach() + layer_pruners[layer_name] = pruner + + # Compute per-layer amounts using the shared distribution manager. + try: + from ..pruning.distribution import PruningDistributionManager + + manager = PruningDistributionManager( + strategy=str(distribution), + target_sparsity=float(ratio), + min_amount=float(min_amount), + max_amount=float(max_amount), + ) + # Only include layers we actually scored + scored_names = [nm for nm in layer_names if nm in layer_scores] + per_layer_amounts = manager.compute_distribution(model, scored_names, layer_scores=layer_scores) + except Exception as exc: + logger.warning( + "Cluster-aware pruning: failed to compute distribution '%s' (%s); falling back to uniform", + distribution, + exc, + ) + clipped = max(min_amount, min(max_amount, float(ratio))) + per_layer_amounts = {nm: clipped for nm in layer_scores.keys()} + + # Second pass: apply pruning using per-layer allocated amounts + for layer_name in layer_names: + layer = module_map.get(layer_name) + if layer is None or not hasattr(layer, "weight") or layer.weight is None: + continue + if layer_name not in layer_scores or layer_name not in layer_pruners: + continue + + n_channels = int(layer_num_channels.get(layer_name, layer.weight.shape[0])) + amount = float(per_layer_amounts.get(layer_name, float(ratio))) + n_prune = int(n_channels * amount) + if n_prune <= 0: + masks[layer_name] = torch.ones(n_channels, dtype=torch.bool, device=layer.weight.device) + stats[layer_name] = MaskOperations.get_mask_statistics(masks[layer_name]) + continue + + # Cached clusters from the original (unpruned) analysis (for by-type summaries) + pre_clusters = self.cluster_results.get(layer_name, {}) + + labels = np.asarray(pre_clusters.get("labels", np.zeros(n_channels, dtype=int))).astype(int) + type_mapping = pre_clusters.get("type_mapping", {}) + + pruner = layer_pruners[layer_name] + scores = 
layer_scores[layer_name].to(device=layer.weight.device) + prune_idx = pruner.select_channels_to_prune(scores, n_prune, layer_name=layer_name) + + mask = torch.ones(n_channels, dtype=torch.bool, device=layer.weight.device) + if prune_idx: + mask[torch.as_tensor(prune_idx, device=layer.weight.device)] = False + + with torch.no_grad(): + layer.weight.data[~mask] = 0.0 + if getattr(layer, "bias", None) is not None and layer.bias.data.numel() == n_channels: + layer.bias.data[~mask] = 0.0 + + masks[layer_name] = mask + stats[layer_name] = MaskOperations.get_mask_statistics(mask) + + # Update by-type counts for diagnostics/figures + # Trim labels if necessary + labels = labels[: min(len(labels), n_channels)] + for cid, ctype in type_mapping.items(): + cid_int = int(cid) + idxs = np.where(labels == cid_int)[0] + by_type_total[ctype] = by_type_total.get(ctype, 0) + int(len(idxs)) + if len(idxs) > 0: + pruned = int((~mask.detach().cpu().numpy().astype(bool))[idxs].sum()) + by_type_pruned[ctype] = by_type_pruned.get(ctype, 0) + pruned + + # Store summary for paper figures + self.pruning_cluster_distributions.setdefault(method, {}) + self.pruning_cluster_distributions[method][float(ratio)] = { + "pruned": by_type_pruned, + "total": by_type_total, + } + + return {"masks": masks, "stats": stats} + + def _zero_batchnorm_from_masks(self, model: nn.Module, masks: Dict[str, torch.Tensor]) -> None: + for layer_name, mask in masks.items(): + bn_layer = self._find_bn_for_conv(model, layer_name) + if bn_layer is None or not hasattr(bn_layer, "weight"): + continue + mask_bool = mask.to(bn_layer.weight.device).bool() + if mask_bool.numel() != bn_layer.weight.data.numel(): + continue + with torch.no_grad(): + bn_layer.weight.data[~mask_bool] = 0.0 + if getattr(bn_layer, "bias", None) is not None: + bn_layer.bias.data[~mask_bool] = 0.0 + if hasattr(bn_layer, "running_mean"): + bn_layer.running_mean.data[~mask_bool] = 0.0 + if hasattr(bn_layer, "running_var"): + 
bn_layer.running_var.data[~mask_bool] = 1.0 + + def _apply_pruning(self, model: nn.Module, method: str, ratio: float) -> nn.Module: + """ + Apply a specific pruning method. + + Supported methods: + + BASELINE: + - 'random': Random channel selection + - 'magnitude': Prune lowest activation magnitude (standard baseline) + - 'taylor': Prune by gradient-based importance + + SINGLE METRICS (prune LOW values = assume low is unimportant): + - 'rq_low': Prune channels with lowest Rayleigh Quotient + - 'redundancy_low': Prune channels with lowest redundancy (MI) + - 'synergy_low': Prune channels with lowest synergy + + SINGLE METRICS (prune HIGH values = assume high is unimportant): + - 'rq_high': Prune channels with highest RQ + - 'redundancy_high': Prune channels with highest redundancy + - 'synergy_high': Prune channels with highest synergy + - 'magnitude_high': Prune highest magnitude channels + + COMPOSITE COMBINATIONS: + - 'composite': Original formula: score = RQ + syn - red (prune low) + - 'composite_pos_red': Flipped: score = RQ + syn + red (prune low) + - 'rq_minus_red': score = RQ - redundancy (prune low) + - 'rq_plus_red': score = RQ + redundancy (prune low) + - 'magnitude_plus_rq': score = magnitude + RQ (prune low) + + CLUSTER-AWARE: + - 'cluster_aware': Cluster-constrained pruning (targets redundant cluster) + - 'cluster_aware_protect_redundant': Inverted (protects redundant, targets critical) + """ + model = model.to(self.device) + pruner = None # Will use metric-based pruning for most methods + + if method == 'random': + from ..pruning.strategies import RandomPruning + from ..pruning.base import PruningConfig + pruner = RandomPruning(PruningConfig(amount=ratio, structured=True)) + + elif method == 'magnitude': + from ..pruning.strategies import MagnitudePruning + from ..pruning.base import PruningConfig + pruner = MagnitudePruning(PruningConfig(amount=ratio, structured=True)) + + elif method == 'taylor': + # Taylor pruning needs gradients from a backward 
pass. In this + # analysis-only flow we do not run backward, so running Taylor here + # would be misleading. Fail fast so results clearly mark it unusable. + raise ValueError("Taylor pruning requires gradients; not available in analysis-only mode.") + + elif method == 'composite': + from ..pruning.strategies import CompositePruning, ClusterAwarePruningConfig + config = ClusterAwarePruningConfig(amount=ratio) + pruner = CompositePruning(config) + + elif method == 'cluster_aware': + from ..pruning.strategies import ClusterAwarePruning, ClusterAwarePruningConfig + config = ClusterAwarePruningConfig(amount=ratio) + pruner = ClusterAwarePruning( + config, + precomputed_metrics=None, + precomputed_clusters=None, + ) + + elif method == 'cluster_aware_protect_redundant': + # Inverted cluster-aware: protect redundant, target critical + from ..pruning.strategies import ClusterAwarePruning, ClusterAwarePruningConfig + config = ClusterAwarePruningConfig( + amount=ratio, + target_redundant=False, # Don't target redundant + protect_critical_frac=1.0, # Don't protect critical + ) + pruner = ClusterAwarePruning(config) + + # Apply to each conv layer in the COPIED model (not self.model!) + # Get layer references from the passed model, not self.layers + model_modules = dict(model.named_modules()) + + for name, orig_layer in self.layers: + # Get the corresponding layer from model_copy, not self.model + if name not in model_modules: + logger.debug(f" {name}: not found in model copy, skipping") + continue + layer = model_modules[name] + + if not hasattr(layer, 'weight'): + continue + + n_channels = layer.weight.shape[0] + n_prune = int(n_channels * ratio) + if n_prune == 0: + continue + + # Get cached metrics for this layer (from original model analysis) + metrics = self.layer_metrics.get(name, {}) + clusters = self.cluster_results.get(name, {}) + + # Debug: log if metrics are missing + if not metrics: + logger.warning(f" {name}: NO METRICS CACHED! 
Using defaults (will select channels 0,1,2...)") + elif 'rq' not in metrics: + logger.warning(f" {name}: 'rq' not in metrics. Keys: {list(metrics.keys())}") + + try: + # ================================================================ + # METRIC-BASED PRUNING (single metrics and combinations) + # ================================================================ + if method.startswith('rq_') or method.startswith('redundancy_') or \ + method.startswith('synergy_') or method.startswith('magnitude_') or \ + method.startswith('composite') or method.startswith('rq_'): + + # Get metric arrays + rq = np.array(metrics.get('rq', np.ones(n_channels))) + redundancy = np.array(metrics.get('redundancy', np.zeros(n_channels))) + synergy = np.array(metrics.get('synergy', np.zeros(n_channels))) + + # Compute magnitude from activations if available + acts = metrics.get('_activations', None) + if acts is not None: + magnitude = np.mean(np.abs(acts), axis=0) + else: + # Use weight L2 norm as proxy + w = layer.weight.data.cpu().numpy() + magnitude = np.sqrt(np.sum(w.reshape(w.shape[0], -1)**2, axis=1)) + + # Normalize metrics to [0, 1] for stable combination + def normalize(x): + x_min, x_max = x.min(), x.max() + if x_max > x_min: + return (x - x_min) / (x_max - x_min) + return np.zeros_like(x) + + rq_norm = normalize(np.log(np.clip(rq, 1e-10, None))) + red_norm = normalize(redundancy) + syn_norm = normalize(synergy) + mag_norm = normalize(magnitude) + + # Compute scores based on method + # SINGLE METRICS - prune LOW + if method == 'rq_low': + scores = rq_norm # Low RQ → prune + elif method == 'redundancy_low': + scores = red_norm # Low redundancy → prune + elif method == 'synergy_low': + scores = syn_norm # Low synergy → prune + + # SINGLE METRICS - prune HIGH + elif method == 'rq_high': + scores = -rq_norm # High RQ → prune (invert) + elif method == 'redundancy_high': + scores = -red_norm # High redundancy → prune + elif method == 'synergy_high': + scores = -syn_norm # High synergy → 
prune + elif method == 'magnitude_high': + scores = -mag_norm # High magnitude → prune + + # COMPOSITE COMBINATIONS + elif method == 'composite': + # Original: High RQ + High Syn - High Red = important + # Prune LOW scores + scores = rq_norm + 0.5 * syn_norm - 0.3 * red_norm + elif method == 'composite_pos_red': + # Flipped: High RQ + High Syn + High Red = important + scores = rq_norm + 0.5 * syn_norm + 0.3 * red_norm + elif method == 'rq_minus_red': + scores = rq_norm - 0.5 * red_norm + elif method == 'rq_plus_red': + scores = rq_norm + 0.5 * red_norm + elif method == 'magnitude_plus_rq': + scores = mag_norm + 0.5 * rq_norm + elif method == 'magnitude_minus_red': + scores = mag_norm - 0.3 * red_norm + elif method == 'magnitude_plus_red': + scores = mag_norm + 0.3 * red_norm + else: + raise ValueError(f"Unknown metric-based method: {method}") + + # Select lowest scores to prune + scores_tensor = torch.from_numpy(scores).float() + prune_idx = torch.argsort(scores_tensor)[:n_prune].tolist() + + # Debug: check if all channels have same score (bug indicator) + score_range = scores.max() - scores.min() + if score_range < 1e-8: + logger.warning(f" {name}: ALL SCORES ARE IDENTICAL ({scores[0]:.6f})! 
Selecting channels {prune_idx[:5]}...") + + # ================================================================ + # CLUSTER-AWARE PRUNING + # ================================================================ + elif method in ['cluster_aware', 'cluster_aware_protect_redundant']: + pruner.precomputed_metrics = metrics + pruner.precomputed_clusters = clusters + pruner._metrics_cache[name] = metrics + pruner._cluster_cache[name] = clusters + + scores = pruner.compute_importance_scores(layer, layer_name=name) + prune_idx = pruner.select_channels_to_prune(scores, n_prune, name) + + # ================================================================ + # COMPOSITE PRUNING (using pruner class) + # ================================================================ + elif method == 'composite_class': + pruner.precomputed_metrics = metrics + pruner.precomputed_clusters = clusters + pruner._metrics_cache[name] = metrics + pruner._cluster_cache[name] = clusters + + scores = pruner.compute_importance_scores(layer, layer_name=name) + prune_idx = torch.argsort(scores)[:n_prune].tolist() + + # ================================================================ + # STANDARD PRUNERS (random, magnitude, taylor) + # ================================================================ + elif pruner is not None: + # Get weight-level importance scores + weight_scores = pruner.compute_importance_scores(layer, layer_name=name) + + # Convert to channel-level scores by averaging over non-channel dims + # weight_scores shape: [C_out, C_in, k, k] for Conv2d + if len(weight_scores.shape) == 4: + # Average over input channels and kernel dims + channel_scores = weight_scores.abs().mean(dim=(1, 2, 3)) + elif len(weight_scores.shape) == 2: + # Linear: [out, in] -> average over input dim + channel_scores = weight_scores.abs().mean(dim=1) + else: + # Fallback: flatten and take first n_channels + channel_scores = weight_scores.view(n_channels, -1).abs().mean(dim=1) + + prune_idx = 
torch.argsort(channel_scores)[:n_prune].tolist() + logger.debug(f" {name}: pruner {method}, channel scores range [{channel_scores.min():.4f}, {channel_scores.max():.4f}]") + + # ================================================================ + # TAYLOR PRUNING (gradient-weight product) + # ================================================================ + elif method == 'taylor': + # Taylor needs gradients - compute them on the fly + # Use weight L2 norm as importance (magnitude-based fallback) + # since we don't have gradients readily available + w = layer.weight.data + if len(w.shape) == 4: + # Conv: L2 norm per output channel + channel_scores = w.pow(2).sum(dim=(1, 2, 3)).sqrt() + else: + channel_scores = w.pow(2).sum(dim=1).sqrt() + + prune_idx = torch.argsort(channel_scores)[:n_prune].tolist() + logger.debug(f" {name}: taylor (magnitude fallback), scores range [{channel_scores.min():.4f}, {channel_scores.max():.4f}]") + + else: + raise ValueError(f"Unknown pruning method: {method}") + + # Zero out pruned channels in conv layer + with torch.no_grad(): + layer.weight.data[prune_idx] = 0 + if layer.bias is not None: + layer.bias.data[prune_idx] = 0 + + # Also zero corresponding BatchNorm parameters + bn_layer = self._find_bn_for_conv(model, name) + if bn_layer is not None: + with torch.no_grad(): + bn_layer.weight.data[prune_idx] = 0 + bn_layer.bias.data[prune_idx] = 0 + bn_layer.running_mean.data[prune_idx] = 0 + bn_layer.running_var.data[prune_idx] = 1 # Avoid div by zero + + # Verify pruning: count zeroed channels + n_zeroed = (layer.weight.data.view(n_channels, -1).abs().sum(dim=1) == 0).sum().item() + logger.debug(f" {name}: pruned {n_prune} channels, verified {n_zeroed} are zeroed") + + except Exception as e: + logger.debug(f"Pruning {name} with {method} failed: {e}") + import traceback + logger.debug(traceback.format_exc()) + + return model + + def _find_bn_for_conv(self, model: nn.Module, conv_name: str) -> Optional[nn.Module]: + """ + Find the BatchNorm 
layer that corresponds to a conv layer. + + In standard architectures (ResNet, VGG-BN), BN follows conv with naming like: + - conv1 -> bn1 + - layer1.0.conv1 -> layer1.0.bn1 + """ + modules = dict(model.named_modules()) + + # Try common naming patterns + patterns = [ + conv_name.replace('conv', 'bn'), # conv1 -> bn1 + conv_name.replace('.conv', '.bn'), # layer1.0.conv1 -> layer1.0.bn1 + conv_name + '_bn', # some architectures + ] + + for pattern in patterns: + if pattern in modules: + bn = modules[pattern] + if isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)): + return bn + + # For downsample layers: layer2.0.downsample.0 -> layer2.0.downsample.1 + if 'downsample.0' in conv_name: + bn_name = conv_name.replace('downsample.0', 'downsample.1') + if bn_name in modules: + bn = modules[bn_name] + if isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)): + return bn + + # Generic Sequential convention: Conv at index i, BN at index i+1 + # Covers VGG16-BN (features.0 -> features.1) and MobileNetV2 (....0 -> ....1). + parts = conv_name.split(".") + if parts and parts[-1].isdigit(): + try: + i = int(parts[-1]) + cand = ".".join(parts[:-1] + [str(i + 1)]) + bn = modules.get(cand) + if isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)): + return bn + except Exception: + pass + + return None + + def _fine_tune( + self, + model: nn.Module, + epochs: int, + lr: float, + max_batches: Optional[int] = None, + weight_decay: float = 0.0, + masks: Optional[Dict[str, torch.Tensor]] = None, + ) -> nn.Module: + """Fine-tune a pruned model. + + Important: when fine-tuning after structured pruning, we must keep pruned + channels pruned. We do this by re-applying channel masks after each + optimizer step (and keeping the corresponding BatchNorm params zeroed). 
+ """ + import torch.optim as optim + + model.train() + optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=float(weight_decay or 0.0)) + criterion = nn.CrossEntropyLoss() + + module_map: Dict[str, nn.Module] = dict(model.named_modules()) + masks_dev: Dict[str, torch.Tensor] = {} + bn_map: Dict[str, nn.Module] = {} + + if masks: + for layer_name, mask in masks.items(): + m = module_map.get(layer_name) + if m is None or not hasattr(m, "weight") or getattr(m, "weight", None) is None: + continue + try: + mb = mask.to(m.weight.device).bool() + if mb.numel() != int(m.weight.shape[0]): + continue + masks_dev[layer_name] = mb + bn = self._find_bn_for_conv(model, layer_name) + if bn is not None: + bn_map[layer_name] = bn + except Exception: + continue + + def _reapply_masks() -> None: + if not masks_dev: + return + with torch.no_grad(): + for layer_name, mb in masks_dev.items(): + m = module_map.get(layer_name) + if m is None or not hasattr(m, "weight") or getattr(m, "weight", None) is None: + continue + if mb.numel() != int(m.weight.shape[0]): + continue + + # Zero pruned output channels + m.weight.data[~mb] = 0.0 + if getattr(m, "bias", None) is not None and m.bias.data.numel() == mb.numel(): + m.bias.data[~mb] = 0.0 + + # Keep matched BatchNorm channels zeroed too (when present) + bn = bn_map.get(layer_name) + if bn is None or not hasattr(bn, "weight") or getattr(bn, "weight", None) is None: + continue + if mb.numel() != bn.weight.data.numel(): + continue + bn.weight.data[~mb] = 0.0 + if getattr(bn, "bias", None) is not None: + bn.bias.data[~mb] = 0.0 + if hasattr(bn, "running_mean"): + bn.running_mean.data[~mb] = 0.0 + if hasattr(bn, "running_var"): + bn.running_var.data[~mb] = 1.0 + + for epoch in range(epochs): + total_loss = 0 + n_batches = 0 + + for x, y in self.train_loader: + x, y = x.to(self.device), y.to(self.device) + + optimizer.zero_grad() + out = model(x) + loss = criterion(out, y) + loss.backward() + optimizer.step() + _reapply_masks() + + 
total_loss += loss.item() + n_batches += 1 + if max_batches is not None and n_batches >= int(max_batches): + break + + if epoch == 0 or (epoch + 1) % 5 == 0: + avg_loss = total_loss / max(n_batches, 1) + logger.debug(f" FT epoch {epoch+1}/{epochs}: loss={avg_loss:.4f}") + + model.eval() + return model + + def _evaluate_accuracy(self, model: Optional[nn.Module] = None) -> float: + """Evaluate model accuracy on test set.""" + model = model or self.model + model.eval() + + correct = 0 + total = 0 + + with torch.no_grad(): + for x, y in self.test_loader: + x, y = x.to(self.device), y.to(self.device) + out = model(x) + correct += (out.argmax(1) == y).sum().item() + total += y.size(0) + + return correct / total if total > 0 else 0.0 + + def run_full_analysis(self, include_pruning: bool = True) -> Dict[str, Any]: + """ + Run complete analysis pipeline. + + Args: + include_pruning: Whether to run pruning comparison experiments + """ + logger.info(f"Starting full analysis for {self.config.model_name}") + + # 1. Compute metrics + self.compute_metrics() + + # 2. Clustering + self.run_clustering() + + # 3. Halo analysis + self.run_halo_analysis() + + # 4. Cascade test + self.run_cascade_test() + + # 5. 
Pruning experiments (optional) + if include_pruning and getattr(self.config, 'pruning_ratios', None): + # Check if fine-tuning is enabled + fine_tune_enabled = getattr(self.config, 'fine_tune_after_pruning', True) + fine_tune_epochs = getattr(self.config, 'fine_tune_epochs', 10) if fine_tune_enabled else 0 + fine_tune_lr = getattr(self.config, 'fine_tune_lr', 0.0001) + fine_tune_max_batches = getattr(self.config, "fine_tune_max_batches", None) + fine_tune_weight_decay = float(getattr(self.config, "fine_tune_weight_decay", 0.0) or 0.0) + + logger.info(f"Fine-tuning after pruning: {'enabled' if fine_tune_epochs > 0 else 'disabled'}") + + self.run_pruning_experiments( + ratios=self.config.pruning_ratios, + methods=getattr(self.config, "pruning_methods", None), + fine_tune_epochs=fine_tune_epochs, + fine_tune_lr=fine_tune_lr, + fine_tune_max_batches=fine_tune_max_batches, + fine_tune_weight_decay=fine_tune_weight_decay, + ) + + # Save results (including centroids for visualization) + results = { + "config": { + "model_name": self.config.model_name, + "dataset_name": self.config.dataset_name, + "n_clusters": self.config.n_clusters, + "activation_samples": getattr(self.config, "activation_samples", "flatten_spatial"), + "spatial_samples_per_image": getattr(self.config, "spatial_samples_per_image", 16), + }, + "layer_metrics": self.layer_metrics, + "cluster_results": { + k: { + "labels": v["labels"].tolist() if hasattr(v.get("labels", None), "tolist") else v.get("labels", []), + "type_counts": v["type_counts"], + "silhouette": v["silhouette"], + "centroids": v["centroids"].tolist() if hasattr(v["centroids"], 'tolist') else v["centroids"], + "type_mapping": {str(kk): vv for kk, vv in v["type_mapping"].items()}, + } + for k, v in self.cluster_results.items() + }, + "halo_results": self.halo_results, + "halo_flow_results": self.halo_flow_results, + "cascade_results": self.cascade_results, + "pruning_results": getattr(self, 'pruning_results', {}), + 
"pruning_cluster_distributions": getattr(self, "pruning_cluster_distributions", {}), + } + + with open(self.output_dir / "results.json", "w") as f: + json.dump(results, f, indent=2, default=str) + + logger.info(f"Results saved to {self.output_dir}") + return results + + def run(self) -> Dict[str, Any]: + """ + Standard run method for compatibility with run_experiment.py. + + This is the main entry point when running via: + python scripts/run_experiment.py --config configs/vision/resnet18_cifar10_full.yaml + """ + # Run full analysis + results = self.run_full_analysis() + + # Generate figures + self.generate_figures() + + return results + + def generate_figures(self) -> None: + """Generate all visualization figures using centralized visualization module.""" + # Import visualization functions from the unified module + from ..analysis.visualization.cluster_plots import ( + plot_metric_scatter, + plot_cluster_evolution, + plot_metric_scatter_3d, + plot_influence_matrix, + plot_cascade_test, + plot_halo_properties, + plot_pruning_comparison, + plot_pruning_by_cluster_type, + plot_centroid_evolution, + plot_centroid_depth_profiles, + plot_metric_distributions_for_layer, + plot_layer_metric_summary, + plot_layer_metric_trends, + plot_metric_statistics_table, + ) + from ..analysis.visualization.metric_plots import ( + plot_metric_histogram, + plot_metric_violin, + plot_metric_correlation_heatmap, + plot_top_neurons_bar, + ) + from ..analysis.visualization.pruning_plots import ( + plot_pruning_recovery_chart, + plot_pruning_accuracy_loss_grid, + plot_pruning_bar_comparison, + plot_pruning_heatmap, + plot_pruning_ranking, + ) + + # Determine figures directory - check both new "figures" and old "plots" subdirectories + if (self.output_dir / "figures").exists(): + fig_dir = self.output_dir / "figures" + elif (self.output_dir / "plots").exists(): + fig_dir = self.output_dir / "plots" + else: + fig_dir = self.output_dir / "figures" + fig_dir.mkdir(exist_ok=True, parents=True) + 
+ # Helper: keep backward-compatible root-level copies for paper scripts + # while also writing into organized subfolders. + try: + import shutil + + def _copy_legacy(src: "Path", dst: "Path") -> None: + try: + shutil.copy2(src, dst) + except Exception: + pass + + except Exception: + shutil = None # type: ignore + + def _copy_legacy(_src: "Path", _dst: "Path") -> None: + return + + # Create organized subdirectories + distributions_dir = fig_dir / "01_distributions" + distributions_dir.mkdir(exist_ok=True) + + summary_dir = fig_dir / "02_summary" + summary_dir.mkdir(exist_ok=True) + + clustering_dir = fig_dir / "03_clustering" + clustering_dir.mkdir(exist_ok=True) + + cascade_dir = fig_dir / "04_cascade" + cascade_dir.mkdir(exist_ok=True) + + halo_dir = fig_dir / "05_halo" + halo_dir.mkdir(exist_ok=True) + + pruning_dir = fig_dir / "06_pruning" + pruning_dir.mkdir(exist_ok=True) + + # ================================================================== + # 1. Metric Distributions (Histograms) - NEW + # ================================================================== + logger.info("Generating metric distribution plots...") + for name, metrics in self.layer_metrics.items(): + safe_name = name.replace('.', '_') + + # Combined histogram for all metrics in this layer + plot_metric_distributions_for_layer( + metrics=metrics, + layer_name=name, + save_dir=distributions_dir, + ) + + # Individual histograms with percentile highlighting + for metric_name in ['rq', 'redundancy', 'synergy']: + if metric_name in metrics: + plot_metric_histogram( + values=metrics[metric_name], + metric_name=metric_name, + layer_name=name, + highlight_percentile=95, + log_scale=(metric_name == 'rq'), + save_path=distributions_dir / f"{metric_name}_{safe_name}.png", + ) + + # ================================================================== + # 2. 
Layer-wise Violin/Boxplots for each metric + # ================================================================== + logger.info("Generating layer-wise metric plots...") + for metric_name in ['rq', 'redundancy', 'synergy']: + layer_data = { + name: metrics.get(metric_name, np.array([])) + for name, metrics in self.layer_metrics.items() + if metric_name in metrics + } + if layer_data: + plot_metric_violin( + layer_metrics=layer_data, + metric_name=metric_name, + save_path=summary_dir / f"{metric_name}_violin_all_layers.png", + ) + + # ================================================================== + # 3. Metric Correlation Heatmap per layer + # ================================================================== + for name, metrics in self.layer_metrics.items(): + if len(metrics) >= 2: + safe_name = name.replace('.', '_') + plot_metric_correlation_heatmap( + metrics=metrics, + layer_name=name, + save_path=distributions_dir / f"correlation_{safe_name}.png", + ) + + # ================================================================== + # 4. 
Layer Metric Summary (overview of all layers and metrics) + # ================================================================== + if self.layer_metrics: + # Original heatmap-style summary + _p = summary_dir / "layer_metric_summary.png" + plot_layer_metric_summary( + layer_metrics=self.layer_metrics, + save_path=_p, + ) + _copy_legacy(_p, fig_dir / "layer_metric_summary.png") + + # NEW: Smoother trend plots with confidence intervals + _p = summary_dir / "layer_metric_trends.png" + plot_layer_metric_trends( + layer_metrics=self.layer_metrics, + metrics_to_plot=['rq', 'redundancy', 'synergy'], + smooth_window=3, # Moving average over 3 layers + show_ci=True, + ci_percentile=95, + save_path=_p, + ) + _copy_legacy(_p, fig_dir / "layer_metric_trends.png") + + # NEW: Statistics table for paper/report + _p = summary_dir / "metric_statistics_table.png" + plot_metric_statistics_table( + layer_metrics=self.layer_metrics, + save_path=_p, + ) + _copy_legacy(_p, fig_dir / "metric_statistics_table.png") + + # ================================================================== + # 5. 
Cluster scatter for each layer + # ================================================================== + logger.info("Generating cluster scatter plots...") + for name, metrics in self.layer_metrics.items(): + cluster = self.cluster_results.get(name, {}) + if not cluster: + continue + plot_metric_scatter( + metrics["rq"], + metrics["redundancy"], + metrics["synergy"], + cluster["labels"], + cluster["type_mapping"], + name, + clustering_dir / f"cluster_scatter_{name.replace('.', '_')}.png", + ) + _copy_legacy( + clustering_dir / f"cluster_scatter_{name.replace('.', '_')}.png", + fig_dir / f"cluster_scatter_{name.replace('.', '_')}.png", + ) + + # Representative 3D scatter for the paper (best-effort) + try: + if self.cluster_results and self.layer_metrics: + rep_layer = None + for candidate in self.cluster_results.keys(): + if "layer3" in candidate or "layer4" in candidate: + rep_layer = candidate + break + if rep_layer is None: + rep_layer = list(self.cluster_results.keys())[len(self.cluster_results) // 2] + + rep_cluster = self.cluster_results.get(rep_layer, {}) + rep_metrics = self.layer_metrics.get(rep_layer, {}) + if rep_cluster and rep_metrics: + _p = clustering_dir / "cluster_3d_scatter.png" + plot_metric_scatter_3d( + rq=rep_metrics.get("rq", np.array([])), + redundancy=rep_metrics.get("redundancy", np.array([])), + synergy=rep_metrics.get("synergy", np.array([])), + labels=rep_cluster.get("labels", np.array([])), + type_mapping=rep_cluster.get("type_mapping", {}), + layer_name=rep_layer, + save_path=_p, + ) + _copy_legacy(_p, fig_dir / "cluster_3d_scatter.png") + except Exception as exc: + logger.debug("Could not generate representative 3D cluster scatter: %s", exc) + + # ================================================================== + # 6. 
Cluster evolution across depth + # ================================================================== + layer_results = [ + {"layer_name": k, "type_counts": v["type_counts"]} + for k, v in self.cluster_results.items() + ] + _p = clustering_dir / "cluster_evolution.png" + plot_cluster_evolution(layer_results, _p) + _copy_legacy(_p, fig_dir / "cluster_evolution.png") + + # ================================================================== + # 7. Cascade test results + # ================================================================== + logger.info("Generating cascade test plots...") + for name, cascade in self.cascade_results.items(): + if cascade: + from ..analysis.cascade_analysis import CascadeResult + results = { + ct: CascadeResult(name, ct, d["n_removed"], d["accuracy_drop"], d["loss_increase"]) + for ct, d in cascade.items() + } + _p = cascade_dir / f"cascade_{name.replace('.', '_')}.png" + plot_cascade_test(results, _p) + # Paper scripts glob fig_dir/"cascade_*.png" (non-recursive) + _copy_legacy(_p, fig_dir / f"cascade_{name.replace('.', '_')}.png") + + # ================================================================== + # 8. 
Halo properties + # ================================================================== + if self.halo_results: + halo_summary = [] + for transition, clusters in self.halo_results.items(): + for ctype, data in clusters.items(): + halo_summary.append({ + "cluster_type": ctype, + "halo_red": data.get("halo_red", 0), + "halo_syn": data.get("halo_syn", 0), + }) + if halo_summary: + from collections import defaultdict + by_type = defaultdict(lambda: {"halo_red": [], "halo_syn": []}) + for h in halo_summary: + by_type[h["cluster_type"]]["halo_red"].append(h["halo_red"]) + by_type[h["cluster_type"]]["halo_syn"].append(h["halo_syn"]) + + avg_halo = [ + { + "cluster_type": ct, + "halo_red": np.mean(v["halo_red"]), + "halo_syn": np.mean(v["halo_syn"]), + } + for ct, v in by_type.items() + ] + # Save into the organized halo subfolder, but also keep a + # root-level copy for backward compatibility (paper scripts expect it). + halo_props_path = halo_dir / "halo_properties.png" + plot_halo_properties(avg_halo, halo_props_path) + try: + _copy_legacy(halo_props_path, fig_dir / "halo_properties.png") + except Exception: + pass + + # Representative cluster-to-cluster influence matrix for the paper (best-effort) + try: + if self.halo_flow_results: + rep_transition = None + for t in self.halo_flow_results.keys(): + if "layer3" in t or "layer4" in t: + rep_transition = t + break + if rep_transition is None: + rep_transition = list(self.halo_flow_results.keys())[0] + flow = self.halo_flow_results.get(rep_transition, {}) + if flow: + halo_infl_path = halo_dir / "halo_influence_matrix.png" + plot_influence_matrix( + flow=flow, + layer_name=rep_transition, + save_path=halo_infl_path, + ) + try: + _copy_legacy(halo_infl_path, fig_dir / "halo_influence_matrix.png") + except Exception: + pass + except Exception as exc: + logger.debug("Could not generate influence matrix plot: %s", exc) + + # ================================================================== + # 9. 
Pruning comparison (using unified interface) + # ================================================================== + logger.info("Generating pruning plots...") + if hasattr(self, 'pruning_results') and self.pruning_results: + baseline = self.pruning_results.get('baseline', 0.9) + methods = self.pruning_results.get('methods', {}) + + if methods: + # Main pruning comparison (line plot) - shows accuracy vs sparsity + plot_pruning_comparison( + methods, baseline, + pruning_dir / "01_accuracy_vs_sparsity.png" + ) + + # Accuracy recovery chart + plot_pruning_recovery_chart( + results=methods, + baseline_value=baseline, + metric='accuracy', + title='Accuracy Recovery After Pruning', + save_path=pruning_dir / "02_accuracy_recovery.png", + ) + + # ============================================================ + # Bar charts for method comparison at specific sparsities + # ============================================================ + # Bar chart at 30% sparsity (conservative) + plot_pruning_bar_comparison( + results=methods, + baseline_value=baseline, + target_sparsity=0.3, + metric='accuracy', + show_before_ft=True, + title='Pruning Methods at 30% Sparsity', + save_path=pruning_dir / "03_bar_30pct_sparsity.png", + ) + + # Bar chart at 50% sparsity (standard comparison point) + plot_pruning_bar_comparison( + results=methods, + baseline_value=baseline, + target_sparsity=0.5, + metric='accuracy', + show_before_ft=True, + title='Pruning Methods at 50% Sparsity', + save_path=pruning_dir / "04_bar_50pct_sparsity.png", + ) + + # Bar chart at 70% sparsity (aggressive) + plot_pruning_bar_comparison( + results=methods, + baseline_value=baseline, + target_sparsity=0.7, + metric='accuracy', + show_before_ft=True, + title='Pruning Methods at 70% Sparsity', + save_path=pruning_dir / "05_bar_70pct_sparsity.png", + ) + + # Heatmap of all methods x all sparsities + plot_pruning_heatmap( + results=methods, + metric='accuracy', + title='Pruning Performance Heatmap (Accuracy %)', + 
save_path=pruning_dir / "06_heatmap_all_methods.png", + ) + + # Ranking plot (methods ranked by average performance) + plot_pruning_ranking( + results=methods, + metric='accuracy', + title='Pruning Method Ranking (by Average Accuracy)', + save_path=pruning_dir / "07_method_ranking.png", + ) + + # Accuracy + Loss grid (if loss data available) + has_loss = any( + 'loss' in d or 'test_loss' in d + for ratio_data in methods.values() + for d in ratio_data.values() + if isinstance(d, dict) + ) + if has_loss: + plot_pruning_accuracy_loss_grid( + results=methods, + baseline_acc=baseline, + title='Pruning: Accuracy and Loss', + save_path=pruning_dir / "08_accuracy_loss_grid.png", + ) + + # Paper figure: which channels get pruned by cluster-aware? + try: + dist = getattr(self, "pruning_cluster_distributions", {}).get("cluster_aware", {}) + rep_ratio = 0.5 if 0.5 in dist else (sorted(dist.keys())[0] if dist else None) + if rep_ratio is not None: + summary = dist.get(rep_ratio, {}) + pruned = summary.get("pruned", {}) + total = summary.get("total", {}) + if pruned and total: + _p = pruning_dir / "pruning_by_cluster.png" + plot_pruning_by_cluster_type( + pruned=pruned, + total=total, + save_path=_p, + title=f"Cluster-aware pruning (sparsity={rep_ratio:.0%})", + ) + _copy_legacy(_p, fig_dir / "pruning_by_cluster.png") + except Exception as exc: + logger.debug("Could not generate pruning-by-cluster plot: %s", exc) + + # ================================================================== + # 10. 
Centroid evolution across depth + # ================================================================== + if self.cluster_results: + layer_names = list(self.cluster_results.keys()) + layer_centroids = [] + for depth, name in enumerate(layer_names): + cluster_data = self.cluster_results[name] + if "centroids" in cluster_data: + layer_centroids.append({ + "layer_name": name, + "depth": depth, + "centroids": cluster_data["centroids"].tolist() if hasattr(cluster_data["centroids"], 'tolist') else cluster_data["centroids"], + "type_mapping": cluster_data["type_mapping"], + }) + + if layer_centroids: + plot_centroid_evolution(layer_centroids, clustering_dir / "centroid_evolution_2d.png") + plot_centroid_depth_profiles(layer_centroids, clustering_dir / "centroid_depth_profiles.png") + + # ================================================================== + # 11. Top neurons by each metric (for first and last layers) + # ================================================================== + if self.layer_metrics: + layer_names = list(self.layer_metrics.keys()) + key_layers = [layer_names[0], layer_names[-1]] if len(layer_names) > 1 else layer_names + + for name in key_layers: + metrics = self.layer_metrics[name] + safe_name = name.replace('.', '_') + for metric_name in ['rq', 'redundancy', 'synergy']: + if metric_name in metrics: + plot_top_neurons_bar( + values=metrics[metric_name], + metric_name=metric_name, + layer_name=name, + top_k=15, + save_path=distributions_dir / f"top_neurons_{metric_name}_{safe_name}.png", + ) + + logger.info(f"All figures saved to {fig_dir}") + + +# Backward compatibility aliases +VisionExperiment = ClusterAnalysisExperiment From e4538afab4f3891437208d0546fe337f3cdd585f Mon Sep 17 00:00:00 2001 From: Houman Safaai Date: Wed, 14 Jan 2026 12:57:49 -0500 Subject: [PATCH 15/15] add experiments folder --- .gitignore | 1 - src/alignment/experiments/llm_experiments.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore 
b/.gitignore index 008ac0e9..947cc9af 100644 --- a/.gitignore +++ b/.gitignore @@ -160,7 +160,6 @@ dmypy.json /runs/ /outputs/ /results/ -/experiments/ # Temporary files *.tmp diff --git a/src/alignment/experiments/llm_experiments.py b/src/alignment/experiments/llm_experiments.py index 94fce2aa..6884c97c 100644 --- a/src/alignment/experiments/llm_experiments.py +++ b/src/alignment/experiments/llm_experiments.py @@ -8029,7 +8029,7 @@ def restore_weights(): # ------------------------------------------------------------------ if getattr(self.config, "generate_plots", True): try: - from alignment.analysis.visualization import ( + from alignment.analysis.visualization.paper_plots import ( plot_halo_structure, plot_loss_proxy_concentration, plot_supernode_halo_summary,