diff --git a/.gitignore b/.gitignore index ba8ce996..442f7cc0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,7 @@ _arxiv/ # Ignore drafts by default, but keep the SCAR paper folder tracked (it has its own .gitignore). drafts/ -!drafts/LLM_prune/ -!drafts/LLM_prune/** + checkpoints/ results/ logs/ diff --git a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml index 06657ff0..e9b04085 100644 --- a/configs/vision_prune/mobilenetv2_cifar10_unified.yaml +++ b/configs/vision_prune/mobilenetv2_cifar10_unified.yaml @@ -43,6 +43,18 @@ dataset: batch_size: 128 num_workers: 4 +# ----------------------------------------------------------------------------- +# TRAINING (paper-quality CIFAR baselines) +# ----------------------------------------------------------------------------- +training: + enabled: true + epochs: 50 + learning_rate: 0.05 + optimizer: "sgd" + scheduler: "cosine" + momentum: 0.9 + weight_decay: 0.0005 + # ----------------------------------------------------------------------------- # CALIBRATION # ----------------------------------------------------------------------------- @@ -156,14 +168,6 @@ pruning: - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy - # ========================================================================= - # SINGLE METRICS - Prune HIGH (assumes high = unimportant) - # ========================================================================= - - "rq_high" # Prune high RQ (TEST: is high RQ bad?) - - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) - - "synergy_high" # Prune high synergy - - "magnitude_high" # Prune high magnitude (inverse of standard) - # ========================================================================= # COMPOSITE COMBINATIONS # ========================================================================= @@ -195,10 +199,11 @@ pruning: - "composite_pos_red" fine_tune: - enabled: false # Disabled to see pure pruning impact without recovery - epochs: 15 # MobileNet may need more fine-tuning + enabled: true # Enable recovery fine-tuning after pruning (standard for reporting) + epochs: 5 learning_rate: 0.0001 weight_decay: 0.00001 + max_batches: 100 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/configs/vision_prune/resnet18_cifar10_unified.yaml b/configs/vision_prune/resnet18_cifar10_unified.yaml index cf2c4fc8..38723a36 100644 --- a/configs/vision_prune/resnet18_cifar10_unified.yaml +++ b/configs/vision_prune/resnet18_cifar10_unified.yaml @@ -40,6 +40,19 @@ dataset: batch_size: 128 num_workers: 4 +# ----------------------------------------------------------------------------- +# TRAINING (paper-quality CIFAR baselines) +# ----------------------------------------------------------------------------- +# NOTE: This trains/fine-tunes the model on CIFAR-10 before running the metric/cluster/pruning analyses. +training: + enabled: true + epochs: 50 + learning_rate: 0.05 + optimizer: "sgd" + scheduler: "cosine" + momentum: 0.9 + weight_decay: 0.0005 + # ----------------------------------------------------------------------------- # CALIBRATION # ----------------------------------------------------------------------------- @@ -165,14 +178,6 @@ pruning: - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy - # ========================================================================= - # SINGLE METRICS - Prune HIGH (assumes high = unimportant) - # ========================================================================= - - "rq_high" # Prune high RQ (TEST: is high RQ bad?) - - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) - - "synergy_high" # Prune high synergy - - "magnitude_high" # Prune high magnitude (inverse of standard) - # ========================================================================= # COMPOSITE COMBINATIONS # ========================================================================= @@ -205,10 +210,12 @@ pruning: - "composite_pos_red" fine_tune: - enabled: false # Disabled to see pure pruning impact without recovery - epochs: 10 + enabled: true # Enable recovery fine-tuning after pruning (standard for reporting) + epochs: 5 learning_rate: 0.0001 weight_decay: 0.0001 + # Safety cap: limits fine-tune compute so the full method×ratio grid stays feasible on 1 GPU + max_batches: 100 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/configs/vision_prune/resnet50_imagenet100_unified.yaml b/configs/vision_prune/resnet50_imagenet100_unified.yaml index a233ea46..08f0091a 100644 --- a/configs/vision_prune/resnet50_imagenet100_unified.yaml +++ b/configs/vision_prune/resnet50_imagenet100_unified.yaml @@ -45,6 +45,17 @@ dataset: image_size: 224 normalize: true +# ----------------------------------------------------------------------------- +# TRAINING (required: classifier head is replaced for ImageNet-100) +# ----------------------------------------------------------------------------- +training: + enabled: true + epochs: 30 + learning_rate: 0.001 + optimizer: "adam" + scheduler: "cosine" + weight_decay: 0.0001 + # ----------------------------------------------------------------------------- # CALIBRATION # ----------------------------------------------------------------------------- @@ -157,14 +168,6 @@ pruning: - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy - # ========================================================================= - # SINGLE METRICS - Prune HIGH (assumes high = unimportant) - # ========================================================================= - - "rq_high" # Prune high RQ (TEST: is high RQ bad?) - - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) - - "synergy_high" # Prune high synergy - - "magnitude_high" # Prune high magnitude (inverse of standard) - # ========================================================================= # COMPOSITE COMBINATIONS # ========================================================================= @@ -196,10 +199,12 @@ pruning: - "composite_pos_red" fine_tune: - enabled: false # Disabled to see pure pruning impact without recovery + enabled: true # Enable recovery fine-tuning after pruning (standard for reporting) epochs: 5 # Fewer epochs for ImageNet learning_rate: 0.00001 weight_decay: 0.0001 + # Critical for feasibility: fine-tuning every (method,ratio) on ImageNet-100 otherwise explodes runtime. + max_batches: 10 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/configs/vision_prune/vgg16_cifar10_unified.yaml b/configs/vision_prune/vgg16_cifar10_unified.yaml index 481d3580..5afb2ec5 100644 --- a/configs/vision_prune/vgg16_cifar10_unified.yaml +++ b/configs/vision_prune/vgg16_cifar10_unified.yaml @@ -41,6 +41,18 @@ dataset: batch_size: 128 num_workers: 4 +# ----------------------------------------------------------------------------- +# TRAINING (paper-quality CIFAR baselines) +# ----------------------------------------------------------------------------- +training: + enabled: true + epochs: 50 + learning_rate: 0.05 + optimizer: "sgd" + scheduler: "cosine" + momentum: 0.9 + weight_decay: 0.0005 + # ----------------------------------------------------------------------------- # CALIBRATION # ----------------------------------------------------------------------------- @@ -153,14 +165,6 @@ pruning: - "redundancy_low" # Prune low redundancy (MI) - "synergy_low" # Prune low synergy - # ========================================================================= - # SINGLE METRICS - Prune HIGH (assumes high = unimportant) - # ========================================================================= - - "rq_high" # Prune high RQ (TEST: is high RQ bad?) - - "redundancy_high" # Prune high redundancy (TEST: is high corr bad?) - - "synergy_high" # Prune high synergy - - "magnitude_high" # Prune high magnitude (inverse of standard) - # ========================================================================= # COMPOSITE COMBINATIONS # ========================================================================= @@ -192,10 +196,11 @@ pruning: - "composite_pos_red" fine_tune: - enabled: false # Disabled to see pure pruning impact without recovery - epochs: 10 + enabled: true # Enable recovery fine-tuning after pruning (standard for reporting) + epochs: 5 learning_rate: 0.0001 weight_decay: 0.0001 + max_batches: 100 # ----------------------------------------------------------------------------- # EVALUATION (Enhanced for Vision) diff --git a/scripts/run_experiment.py b/scripts/run_experiment.py index 5fa8a8ea..2ae88204 100644 --- a/scripts/run_experiment.py +++ b/scripts/run_experiment.py @@ -29,6 +29,7 @@ import sys from datetime import datetime from pathlib import Path +from typing import Optional import torch import yaml @@ -126,6 +127,18 @@ def _get_nested(obj, key, default): fine_tune_cfg.get("epochs", 10) if isinstance(fine_tune_cfg, dict) else 10) fine_tune_lr = getattr(config, "fine_tune_learning_rate", fine_tune_cfg.get("learning_rate", 0.0001) if isinstance(fine_tune_cfg, dict) else 0.0001) + fine_tune_max_batches = getattr( + config, + "fine_tune_max_batches", + fine_tune_cfg.get("max_batches", None) if isinstance(fine_tune_cfg, dict) else None, + ) + fine_tune_weight_decay = float( + getattr( + config, + "fine_tune_weight_decay", + fine_tune_cfg.get("weight_decay", 0.0) if isinstance(fine_tune_cfg, dict) else 0.0, + ) or 0.0 + ) # Get pruning algorithms/methods pruning_methods = getattr(config, "pruning_strategies", None) or \ @@ -165,10 +178,40 @@ def _get_nested(obj, key, default): fine_tune_after_pruning=fine_tune_enabled, fine_tune_epochs=fine_tune_epochs, fine_tune_lr=fine_tune_lr, + fine_tune_max_batches=fine_tune_max_batches, + fine_tune_weight_decay=fine_tune_weight_decay, output_dir=getattr(config, "experiment_dir", "results/cluster_analysis"), device=getattr(config, "device", "cuda"), seed=getattr(config, "seed", 42), ) + + # Optional: allow sweeping cluster-aware score weights via nested pruning config: + # pruning.cluster_aware.{alpha,beta,gamma,lambda_halo,protect_critical_frac} + if isinstance(pruning_cfg, dict): + ca = pruning_cfg.get("cluster_aware", {}) + if isinstance(ca, dict): + if "alpha" in ca: + setattr(cluster_config, "cluster_aware_alpha", float(ca["alpha"])) + if "beta" in ca: + setattr(cluster_config, "cluster_aware_beta", float(ca["beta"])) + if "gamma" in ca: + setattr(cluster_config, "cluster_aware_gamma", float(ca["gamma"])) + if "lambda_halo" in ca: + setattr(cluster_config, "cluster_aware_lambda_halo", float(ca["lambda_halo"])) + if "protect_critical_frac" in ca: + setattr(cluster_config, "cluster_aware_protect_critical_frac", float(ca["protect_critical_frac"])) + + # Also support the flat ExperimentConfig fields used by our config loader + # (and mapped from unified-style dotted CLI overrides). + for attr in ( + "cluster_aware_alpha", + "cluster_aware_beta", + "cluster_aware_gamma", + "cluster_aware_lambda_halo", + "cluster_aware_protect_critical_frac", + ): + if hasattr(config, attr): + setattr(cluster_config, attr, float(getattr(config, attr))) # Load model model_name = cluster_config.model_name.lower() @@ -217,21 +260,44 @@ def _get_nested(obj, key, default): # Load dataset if "cifar10" in dataset_name: - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), - ]) - root = dataset_cfg.get("root", "./data") if isinstance(dataset_cfg, dict) else "./data" - train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform) - test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform) + mean = (0.4914, 0.4822, 0.4465) + std = (0.2470, 0.2435, 0.2616) + root = ( + (dataset_cfg.get("root") if isinstance(dataset_cfg, dict) else None) + or getattr(config, "data_path", None) + or "./data" + ) + # Use standard CIFAR augmentation when training so baseline accuracies match common reporting. + train_transform = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] + ) + test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform) + test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform) elif "cifar100" in dataset_name: - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), - ]) - root = dataset_cfg.get("root", "./data") if isinstance(dataset_cfg, dict) else "./data" - train_dataset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform) - test_dataset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform) + mean = (0.5071, 0.4867, 0.4408) + std = (0.2675, 0.2565, 0.2761) + root = ( + (dataset_cfg.get("root") if isinstance(dataset_cfg, dict) else None) + or getattr(config, "data_path", None) + or "./data" + ) + train_transform = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] + ) + test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + train_dataset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=train_transform) + test_dataset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=test_transform) elif "imagenet100" in dataset_name: # Expected folder structure: {root}/train/* and {root}/val/* (ImageFolder) root = dataset_cfg.get("root", "./data/imagenet100") if isinstance(dataset_cfg, dict) else "./data/imagenet100" @@ -262,36 +328,48 @@ def _get_nested(obj, key, default): else: raise ValueError(f"Unknown dataset: {dataset_name}") - batch_size = int(dataset_cfg.get("batch_size", getattr(config, "batch_size", 128))) if isinstance(dataset_cfg, dict) else int(getattr(config, "batch_size", 128)) - num_workers = int(dataset_cfg.get("num_workers", 4)) if isinstance(dataset_cfg, dict) else 4 + batch_size = int(getattr(config, "batch_size", 128)) + num_workers = int(getattr(config, "num_workers", 4)) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=num_workers) - # Fine-tune the model on target dataset before experiments - # This is necessary because we replaced the classifier head with random weights - # Get training settings from config (check multiple locations) - training_cfg = _get_nested(config, "training", {}) - extra_cfg = _get_nested(config, "extra", {}) - - # Check in order: training.epochs, extra.pretrain_epochs, config.pretrain_epochs - pretrain_epochs = ( - training_cfg.get("epochs") if isinstance(training_cfg, dict) else None - ) or ( - extra_cfg.get("pretrain_epochs") if isinstance(extra_cfg, dict) else None - ) or getattr(config, "pretrain_epochs", 30) - - pretrain_lr = ( - training_cfg.get("learning_rate") if isinstance(training_cfg, dict) else None - ) or ( - extra_cfg.get("pretrain_lr") if isinstance(extra_cfg, dict) else None - ) or getattr(config, "pretrain_lr", 0.001) - - if needs_training: + # CIFAR-specific stem tweak: using the ImageNet stem (7x7,stride2 + maxpool) + # degrades CIFAR accuracy. Use the standard CIFAR stem and (when pretrained) + # seed weights by center-cropping the 7x7 conv filter. + if ("cifar" in dataset_name) and ("resnet" in model_name): + if hasattr(model, "conv1") and hasattr(model, "maxpool"): + old_conv = model.conv1 + new_conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + try: + if pretrained and hasattr(old_conv, "weight") and old_conv.weight.shape[-1] == 7: + with torch.no_grad(): + new_conv.weight.copy_(old_conv.weight[:, :, 2:5, 2:5]) + except Exception: + pass + model.conv1 = new_conv + model.maxpool = torch.nn.Identity() + + # Train/fine-tune the model on target dataset before experiments. + # If you want a pure "no-training" analysis, provide an explicit checkpoint and set do_train=false. + do_train = bool(getattr(config, "do_train", True)) + if needs_training and not do_train: + raise RuntimeError( + f"Model checkpoint not found and training is disabled (do_train=false). " + f"Provide model.checkpoint/model_checkpoint or enable training in the config." + ) + if needs_training and do_train: model = _finetune_model_for_dataset( - model, train_loader, test_loader, + model, + train_loader, + test_loader, device=cluster_config.device, - epochs=pretrain_epochs, - lr=pretrain_lr, + epochs=int(getattr(config, "training_epochs", 30)), + lr=float(getattr(config, "learning_rate", 1e-3)), + optimizer_name=str(getattr(config, "optimizer", "adam")), + weight_decay=float(getattr(config, "weight_decay", 0.0) or 0.0), + momentum=float(getattr(config, "momentum", 0.9) or 0.9), + scheduler=getattr(config, "scheduler", None), + scheduler_config=getattr(config, "scheduler_config", {}) or {}, ) # Save the trained model checkpoint @@ -317,6 +395,12 @@ def _finetune_model_for_dataset( device: str = "cuda", epochs: int = 20, lr: float = 0.001, + optimizer_name: str = "adam", + weight_decay: float = 1e-4, + momentum: float = 0.9, + scheduler: Optional[str] = "cosine", + scheduler_config: Optional[dict] = None, + max_batches: Optional[int] = None, ) -> torch.nn.Module: """ Fine-tune a pretrained model on the target dataset. @@ -350,8 +434,8 @@ def _finetune_model_for_dataset( total += y.size(0) initial_acc = correct / total - # If already well-trained (>85% accuracy on CIFAR-10), skip fine-tuning - if initial_acc > 0.85: + # If already well-trained, skip fine-tuning (useful when an explicit checkpoint is provided). + if initial_acc > 0.90: logger.info(f"Model already well-trained (accuracy: {initial_acc:.2%}), skipping fine-tuning") return model @@ -368,12 +452,46 @@ def _finetune_model_for_dataset( else: pretrained_params.append(param) - optimizer = optim.Adam([ - {'params': pretrained_params, 'lr': lr * 0.1}, # Lower LR for pretrained - {'params': new_params, 'lr': lr}, # Higher LR for new classifier - ], weight_decay=1e-4) - - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + opt_name = (optimizer_name or "adam").lower() + if opt_name in {"sgd", "momentum", "sgd_momentum"}: + optimizer = optim.SGD( + [ + {"params": pretrained_params, "lr": lr * 0.1}, + {"params": new_params, "lr": lr}, + ], + momentum=float(momentum), + weight_decay=float(weight_decay), + nesterov=True, + ) + elif opt_name in {"adamw"}: + optimizer = optim.AdamW( + [ + {"params": pretrained_params, "lr": lr * 0.1}, + {"params": new_params, "lr": lr}, + ], + weight_decay=float(weight_decay), + ) + else: + optimizer = optim.Adam( + [ + {"params": pretrained_params, "lr": lr * 0.1}, + {"params": new_params, "lr": lr}, + ], + weight_decay=float(weight_decay), + ) + + # Scheduler (optional) + sch_name = (str(scheduler).lower() if scheduler is not None else "none") + scheduler_config = scheduler_config or {} + if sch_name in {"none", "null", ""}: + lr_scheduler = None + elif sch_name in {"step", "steplr"}: + step_size = int(scheduler_config.get("step_size", max(1, epochs // 3))) + gamma = float(scheduler_config.get("gamma", 0.1)) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) + else: + # Default: cosine + lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) criterion = torch.nn.CrossEntropyLoss() best_acc = 0 @@ -383,7 +501,9 @@ def _finetune_model_for_dataset( # Train model.train() train_loss = 0 - for x, y in train_loader: + for bi, (x, y) in enumerate(train_loader): + if max_batches is not None and bi >= int(max_batches): + break x, y = x.to(device), y.to(device) optimizer.zero_grad() out = model(x) @@ -391,8 +511,9 @@ def _finetune_model_for_dataset( loss.backward() optimizer.step() train_loss += loss.item() - - scheduler.step() + + if lr_scheduler is not None: + lr_scheduler.step() # Evaluate model.eval() diff --git a/src/alignment/analysis/visualization/paper_plots.py b/src/alignment/analysis/visualization/paper_plots.py index 3247abee..81471491 100644 --- a/src/alignment/analysis/visualization/paper_plots.py +++ b/src/alignment/analysis/visualization/paper_plots.py @@ -192,7 +192,8 @@ def plot_halo_structure( def plot_supernode_halo_summary( layer_indices: Sequence[int], top_mass_ratios: Sequence[float], - halo_aggregate: Dict[str, Any], + halo_aggregate: Optional[Dict[str, Any]] = None, + halo_per_layer: Optional[Dict[str, Any]] = None, rho: float = 0.01, save_path: Optional[Union[str, Path]] = None, dpi: int = 300, @@ -216,22 +217,211 @@ def plot_supernode_halo_summary( ax.grid(True, alpha=0.25) ax = axes[1] - groups = [("Within-Halo", "halo_halo"), ("Within-Non-Halo", "non_halo"), ("Cross", "cross")] - means = [] - stds = [] - for _, key in groups: - rec = halo_aggregate.get(key) or {} - means.append(float(rec.get("mean", 0.0))) - stds.append(float(rec.get("std", 0.0))) - - x = np.arange(len(groups)) - ax.bar(x, means, yerr=stds, capsize=4, color=["#1f77b4", "#7f8c8d", "#2ecc71"], alpha=0.85) - ax.set_xticks(x) - ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right") - ax.set_ylabel("Redundancy (Gaussian MI, nats)") - ax.set_title("Halo redundancy vs non-halo (avg.)") + groups = [("Within-Halo", "halo_halo", "#1f77b4"), ("Within-Non-Halo", "non_halo", "#7f8c8d"), ("Cross", "cross", "#2ecc71")] + + # Prefer per-layer distributions (much clearer than mean±std when the MI distribution is heavy-tailed). + if isinstance(halo_per_layer, dict) and halo_per_layer: + data = [] + for _, key, _ in groups: + vals: List[float] = [] + for _, rec in halo_per_layer.items(): + if not isinstance(rec, dict): + continue + g = rec.get(key) + if not isinstance(g, dict): + continue + m = g.get("median") + try: + mf = float(m) + except Exception: + continue + if np.isfinite(mf) and mf > 0: + vals.append(mf) + data.append(np.asarray(vals, dtype=np.float64)) + + bp = ax.boxplot( + data, + vert=True, + patch_artist=True, + showfliers=False, + medianprops=dict(color="#2c3e50", linewidth=2), + boxprops=dict(linewidth=1.2, color="#2c3e50"), + whiskerprops=dict(linewidth=1.2, color="#2c3e50"), + capprops=dict(linewidth=1.2, color="#2c3e50"), + ) + # Color the boxes + for patch, (_, _, color) in zip(bp.get("boxes", []), groups): + patch.set_facecolor(color) + patch.set_alpha(0.75) + + # Overlay jittered per-layer medians for transparency + rng = np.random.default_rng(0) + for i, vals in enumerate(data, start=1): + if vals.size == 0: + continue + jitter = rng.normal(loc=0.0, scale=0.05, size=vals.size) + ax.scatter( + np.full(vals.shape, i, dtype=float) + jitter, + vals, + s=14, + alpha=0.35, + color="#2c3e50", + edgecolors="none", + ) + + ax.set_xticks(np.arange(1, len(groups) + 1)) + ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right") + ax.set_ylabel("Redundancy (Gaussian MI, nats)\n(per-layer median)") + ax.set_title("Halo redundancy across layers") + ax.grid(True, alpha=0.25, axis="y") + # MI is positive and often heavy-tailed; log helps readability. + ax.set_yscale("log") + else: + # Fallback: show mean ± 95% CI of the mean (std can be huge for heavy tails). + halo_aggregate = halo_aggregate or {} + means = [] + cis = [] + for _, key, _ in groups: + rec = halo_aggregate.get(key) or {} + mu = float(rec.get("mean", 0.0)) + sd = float(rec.get("std", 0.0)) + n = float(rec.get("count", 0.0) or 0.0) + sem = sd / np.sqrt(n) if n > 1 else 0.0 + means.append(mu) + cis.append(1.96 * sem) + + x = np.arange(len(groups)) + ax.bar(x, means, yerr=cis, capsize=4, color=[g[2] for g in groups], alpha=0.85, edgecolor="none") + ax.set_xticks(x) + ax.set_xticklabels([g[0] for g in groups], rotation=15, ha="right") + ax.set_ylabel("Redundancy (Gaussian MI, nats)\n(mean ± 95% CI)") + ax.set_title("Halo redundancy (aggregate)") + ax.grid(True, alpha=0.25, axis="y") + + plt.tight_layout() + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + + +def plot_supernode_outlier_profile( + layer_indices: Sequence[int], + outlier_ratios: Sequence[float], + z_scores_activation: Sequence[float], + z_scores_loss_proxy: Sequence[float], + z_scores_max_activation: Sequence[float], + rho: float = 0.01, + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, +) -> plt.Figure: + """ + Two-panel plot for supernode outlier strength across depth. + + - Left: activation outlier ratio (supernode mean / population mean), log scale. + - Right: z-scores across layers (activation and loss-proxy); plus max-neuron activation z on a secondary axis. + """ + layers = np.asarray(list(layer_indices), dtype=int) + ratios = np.asarray(list(outlier_ratios), dtype=np.float64) + z_act = np.asarray(list(z_scores_activation), dtype=np.float64) + z_lp = np.asarray(list(z_scores_loss_proxy), dtype=np.float64) + z_max = np.asarray(list(z_scores_max_activation), dtype=np.float64) + + fig, axes = plt.subplots(1, 2, figsize=(12, 4.0)) + + # Panel A: outlier ratio (log) + ax = axes[0] + ax.plot(layers, ratios, "o-", color="#8e44ad", linewidth=2.0, markersize=4, label="Supernode mean / population mean") + ax.set_yscale("log") + ax.axhline(10.0, color="#f39c12", linestyle="--", linewidth=1.8, label="10×") + ax.axhline(100.0, color="#c0392b", linestyle="--", linewidth=1.8, label="100×") + ax.set_xlabel("Layer index") + ax.set_ylabel("Activation outlier ratio (log scale)") + ax.set_title(f"Supernode outlier ratio (top {rho*100:.0f}% by LP)") + ax.grid(True, alpha=0.25, axis="y") + ax.legend(loc="upper right", frameon=True) + + # Panel B: z-scores (dual axis) + ax = axes[1] + ax.plot(layers, z_act, "o-", color="#e67e22", linewidth=2.0, markersize=4, label="Activation z (supernode mean)") + ax.plot(layers, z_lp, "o-", color="#2980b9", linewidth=2.0, markersize=4, label="Loss-proxy z (supernode mean)") + ax.axhline(2.0, color="#7f8c8d", linestyle="--", linewidth=1.5, alpha=0.8) + ax.axhline(3.0, color="#7f8c8d", linestyle="--", linewidth=1.5, alpha=0.8) + ax.set_xlabel("Layer index") + ax.set_ylabel("Z-score (supernode mean)") + ax.set_title("Outlier z-scores across layers") ax.grid(True, alpha=0.25, axis="y") + ax2 = ax.twinx() + ax2.plot(layers, z_max, "^-", color="#2c3e50", linewidth=1.8, markersize=5, label="Activation z (max neuron)") + ax2.set_ylabel("Z-score (max neuron, activation)") + + # Combined legend + h1, l1 = ax.get_legend_handles_labels() + h2, l2 = ax2.get_legend_handles_labels() + ax.legend(h1 + h2, l1 + l2, loc="upper right", frameon=True) + + plt.tight_layout() + if save_path is not None: + _save(fig, save_path, dpi=dpi) + return fig + + +def plot_sparsity_perplexity_curves( + sparsities: Sequence[float], + ppl_by_method: Dict[str, Sequence[Optional[float]]], + baseline_ppl: Optional[float] = None, + save_path: Optional[Union[str, Path]] = None, + dpi: int = 300, +) -> plt.Figure: + """ + Paper-facing plot: perplexity vs structured sparsity for multiple methods. + + Inputs are already "paper-ready" (i.e., only the intended pruning direction, typically low-mode). + """ + xs = np.asarray(list(sparsities), dtype=np.float64) + fig, ax = plt.subplots(figsize=(7.0, 4.2)) + + # Stable ordering for legend + for label in sorted(ppl_by_method.keys()): + ys_raw = ppl_by_method[label] + ys = np.asarray([np.nan if v is None else float(v) for v in ys_raw], dtype=np.float64) + finite = np.isfinite(ys) + if not np.any(finite): + continue + ax.plot(xs[finite], ys[finite], "o-", linewidth=2.0, markersize=5, label=label, alpha=0.9) + + if baseline_ppl is not None: + try: + b = float(baseline_ppl) + if np.isfinite(b): + ax.axhline(b, color="#2c3e50", linestyle=":", linewidth=2.0, label=f"Unpruned ({b:.1f})") + except Exception: + pass + + ax.set_xlabel("Structured FFN channel sparsity", fontsize=11) + ax.set_ylabel("Perplexity (WikiText-2)", fontsize=11) + ax.set_title("Perplexity vs sparsity (low-mode)", fontsize=12, fontweight="bold") + ax.grid(True, alpha=0.25) + ax.legend(loc="upper left", fontsize=9, frameon=True) + + # Use log if the dynamic range is large. + all_vals = [] + for vs in ppl_by_method.values(): + for v in vs: + if v is None: + continue + try: + vf = float(v) + except Exception: + continue + if np.isfinite(vf) and vf > 0: + all_vals.append(vf) + if all_vals: + mn = min(all_vals) + mx = max(all_vals) + if mx / max(mn, 1e-9) > 20: + ax.set_yscale("log") + plt.tight_layout() if save_path is not None: _save(fig, save_path, dpi=dpi) @@ -243,68 +433,125 @@ def plot_scar_schematic( dpi: int = 300, ) -> plt.Figure: """ - Generate a simple schematic of SCAR (supernodes + halos) as a flowchart. + Generate a schematic of SCAR (supernodes + halos) as a flowchart. This is model-agnostic and can be generated during artifact collection. """ - fig = plt.figure(figsize=(12, 4.5)) + # Keep this figure intentionally clean + legible in a 1-column ICML layout. + fig = plt.figure(figsize=(12, 3.8)) ax = fig.add_subplot(111) ax.set_axis_off() ax.set_xlim(0, 1) ax.set_ylim(0, 1) - def box(x, y, w, h, text, fc="#ecf0f1", ec="#2c3e50"): + def box(x, y, w, h, text, fc="#ecf0f1", ec="#2c3e50", lw: float = 1.6): p = FancyBboxPatch( (x, y), w, h, boxstyle="round,pad=0.02,rounding_size=0.02", - linewidth=1.6, + linewidth=lw, edgecolor=ec, facecolor=fc, ) ax.add_patch(p) - ax.text(x + w / 2, y + h / 2, text, ha="center", va="center", fontsize=10) + ax.text(x + w / 2, y + h / 2, text, ha="center", va="center", fontsize=10.5) def arrow(x1, y1, x2, y2, color="#2c3e50"): a = FancyArrowPatch((x1, y1), (x2, y2), arrowstyle="->", linewidth=1.6, color=color, mutation_scale=12) ax.add_patch(a) - # Left: FFN depiction (conceptual) - box(0.02, 0.62, 0.18, 0.26, "FFN layer\n(MLP channels)", fc="#f8f9f9") - ax.text(0.11, 0.71, "u channels", ha="center", va="center", fontsize=9) - # Draw "channels": a few vertical ticks, color some as supernodes/halo - for i, x in enumerate(np.linspace(0.05, 0.19, 9)): - c = "#7f8c8d" - if i in (2, 3): - c = "#c0392b" # supernodes - if i in (5, 6): - c = "#1f77b4" # halo - ax.plot([x, x], [0.66, 0.84], color=c, linewidth=3) - ax.text(0.03, 0.58, "Supernodes (red): high LP\nHalo (blue): high Conn + redundant", fontsize=9, ha="left", va="top") - - # Middle: compute steps - box(0.28, 0.70, 0.20, 0.20, "Calibration\nforward+backward", fc="#fdf2e9", ec="#d35400") - # NOTE: Use \frac{1}{2} (not \frac12) for broad compatibility with matplotlib mathtext. - box(0.52, 0.70, 0.22, 0.20, r"Loss proxy\n$\mathrm{LP}_i=\frac{1}{2}\mathbb{E}[(u_i s_i)^2]$", fc="#fdf2e9", ec="#d35400") - box(0.78, 0.70, 0.20, 0.20, r"Supernodes\n(top-$\rho$ by LP)\nprotect core", fc="#fdebd0", ec="#c0392b") - - arrow(0.20, 0.80, 0.28, 0.80) - arrow(0.48, 0.80, 0.52, 0.80) - arrow(0.74, 0.80, 0.78, 0.80) - - # Bottom: halo + redundancy + pruning - box(0.28, 0.35, 0.22, 0.20, r"Connectivity\n$\mathrm{Conn}_j$ from $|v_j|$ overlap", fc="#e8f6ff", ec="#2980b9") - box(0.54, 0.35, 0.20, 0.20, r"Halo\n(top-$\eta$ non-core by Conn)", fc="#e8f6ff", ec="#2980b9") - box(0.78, 0.35, 0.20, 0.20, r"Redundancy\n$\mathrm{Red}^{\rightarrow\mathcal{M}}$ from $q=u\!\odot\!s$", fc="#eafaf1", ec="#27ae60") - box(0.52, 0.06, 0.46, 0.20, r"Score + prune\n(prune low-$\mathrm{LP}$ first,\nboost halo followers; respect caps)", fc="#f8f9f9", ec="#2c3e50") - - arrow(0.62, 0.70, 0.39, 0.55) - arrow(0.50, 0.45, 0.54, 0.45) - arrow(0.74, 0.45, 0.78, 0.45) - arrow(0.88, 0.35, 0.75, 0.26) - arrow(0.64, 0.35, 0.64, 0.26) - - ax.text(0.02, 0.97, "SCAR schematic (supernodes + halos for structured FFN channel pruning)", fontsize=12, fontweight="bold", ha="left", va="top") + # ------------------------------------------------------------------ + # Column layout + # ------------------------------------------------------------------ + x0 = 0.03 + col_w = 0.22 + gap = 0.035 + y_top = 0.58 + h_top = 0.32 + y_bot = 0.15 + h_bot = 0.30 + + # Colors (match paper narrative) + C_SUP = "#c0392b" # supernodes + C_HALO = "#1f77b4" # halo + C_STEP = "#2c3e50" # neutral + C_CAL = "#d35400" # calibration/loss-proxy compute + + # --- Col 1: Calibration + proxy --- + box(x0, y_top, col_w, h_top, "Calibration\n(tokens)", fc="#fdf2e9", ec=C_CAL) + box( + x0, + y_bot, + col_w, + h_bot, + # NOTE: Use \frac{1}{2} (not \frac12) for broad compatibility with matplotlib mathtext. + r"Loss proxy\n$\mathrm{LP}_i=\frac{1}{2}\,\mathbb{E}[(u_i s_i)^2]$", + fc="#fdf2e9", + ec=C_CAL, + ) + + # Tiny icon: forward/backward arrows + ax.text(x0 + col_w / 2, y_top + 0.07, "fwd + bwd", ha="center", va="center", fontsize=9.5, color=C_STEP) + + # --- Col 2: Supernodes --- + x1 = x0 + col_w + gap + box(x1, y_top, col_w, h_top, r"Supernodes\n(top-$\rho$ by LP)\n\bf protect", fc="#fdebd0", ec=C_SUP) + box(x1, y_bot, col_w, h_bot, "FFN channels\n(sorted by LP)", fc="#f8f9f9", ec=C_STEP) + + # Draw a stylized heavy-tail: a few bars, with 2 red outliers + xs = np.linspace(x1 + 0.03, x1 + col_w - 0.03, 10) + heights = np.array([0.06, 0.05, 0.04, 0.035, 0.03, 0.028, 0.025, 0.022, 0.18, 0.24]) + for i, (xx, hh) in enumerate(zip(xs, heights)): + c = C_SUP if i >= 8 else "#7f8c8d" + ax.plot([xx, xx], [y_bot + 0.06, y_bot + 0.06 + hh], color=c, linewidth=4, solid_capstyle="round") + ax.text(x1 + col_w / 2, y_bot + 0.03, "rare outliers", ha="center", va="center", fontsize=9.0, color=C_STEP) + + # --- Col 3: Halo + redundancy --- + x2 = x1 + col_w + gap + box(x2, y_top, col_w, h_top, r"Halo\n(high Conn to core)", fc="#e8f6ff", ec=C_HALO) + box( + x2, + y_bot, + col_w, + h_bot, + r"Redundancy to core\n$\mathrm{Red}^{\rightarrow\mathcal{M}}_j=\max_{m\in\mathcal{M}} I(q_j;q_m)$", + fc="#eafaf1", + ec="#27ae60", + ) + ax.text(x2 + col_w / 2, y_bot + 0.03, r"$q=u\odot s$", ha="center", va="center", fontsize=9.0, color=C_STEP) + + # --- Col 4: Structured pruning --- + x3 = x2 + col_w + gap + box( + x3, + y_top, + col_w, + h_top, + "Score + prune\n(non-core only)\nlayer caps", + fc="#f8f9f9", + ec=C_STEP, + ) + box(x3, y_bot, col_w, h_bot, r"Result:\nstructured FFN\nchannel sparsity", fc="#f8f9f9", ec=C_STEP) + + # Arrows across columns (top row) + arrow(x0 + col_w, y_top + h_top / 2, x1, y_top + h_top / 2, color=C_STEP) + arrow(x1 + col_w, y_top + h_top / 2, x2, y_top + h_top / 2, color=C_STEP) + arrow(x2 + col_w, y_top + h_top / 2, x3, y_top + h_top / 2, color=C_STEP) + + # Vertical arrows within columns + arrow(x0 + col_w / 2, y_top, x0 + col_w / 2, y_bot + h_bot, color=C_STEP) + arrow(x2 + col_w / 2, y_top, x2 + col_w / 2, y_bot + h_bot, color=C_STEP) + + ax.text( + 0.02, + 0.98, + "SCAR: supernodes + halos for structured FFN channel pruning", + fontsize=12.5, + fontweight="bold", + ha="left", + va="top", + color=C_STEP, + ) plt.tight_layout() if save_path is not None: diff --git a/src/alignment/configs/config_loader.py b/src/alignment/configs/config_loader.py index c4b9bc65..002340e2 100644 --- a/src/alignment/configs/config_loader.py +++ b/src/alignment/configs/config_loader.py @@ -139,6 +139,14 @@ def _convert_unified_to_original(unified: Dict[str, Any]) -> Dict[str, Any]: original["dataset"]["split"] = dataset["split"] if "root" in dataset: original["dataset"]["data_path"] = dataset["root"] + + # ------------------------------------------------------------------------- + # TRAINING + # ------------------------------------------------------------------------- + # Pass through unified `training:` block so downstream flattening can set + # ExperimentConfig.{do_train,training_epochs,learning_rate,optimizer,...}. + if "training" in unified and isinstance(unified["training"], dict): + original["training"] = unified["training"] # ------------------------------------------------------------------------- # CALIBRATION @@ -946,6 +954,10 @@ def _map_nested_to_flat_config(nested_config: Dict[str, Any]) -> Dict[str, Any]: flat_config["fine_tune_epochs"] = fine_tune_block["epochs"] if "learning_rate" in fine_tune_block: flat_config["fine_tune_learning_rate"] = fine_tune_block["learning_rate"] + if "max_batches" in fine_tune_block: + flat_config["fine_tune_max_batches"] = fine_tune_block["max_batches"] + if "weight_decay" in fine_tune_block: + flat_config["fine_tune_weight_decay"] = fine_tune_block["weight_decay"] # Map top-level analysis flags flat_config["do_pruning_experiments"] = pruning_block.get("enabled", nested_config.get("do_pruning_experiments", False)) @@ -1202,9 +1214,41 @@ def load_config_with_overrides( # Apply CLI overrides if cli_args: + # Map "unified-style" dotted CLI keys used by paper SLURM scripts into the + # flat ExperimentConfig namespace produced by load_config(). + # + # Without this mapping, overrides like `metrics.activation_samples=gap` would + # try to index into `config_dict["metrics"]` (a list) and crash, and overrides + # like `pruning.cluster_aware.gamma=...` would create a new top-level `pruning` + # dict (which ExperimentConfig cannot accept). + dotted_key_map = { + # Activation sampling / CNN handling for cluster experiments + "metrics.activation_samples": "activation_samples", + "metrics.spatial_samples_per_image": "spatial_samples_per_image", + "metrics.synergy_target": "synergy_target", + "metrics.synergy_candidate_pool": "synergy_candidate_pool", + "metrics.synergy_num_pairs": "synergy_pairs", + # Clustering + "clustering.n_clusters": "n_clusters", + # Cluster-aware pruning weight sweeps (paper) + "pruning.cluster_aware.alpha": "cluster_aware_alpha", + "pruning.cluster_aware.beta": "cluster_aware_beta", + "pruning.cluster_aware.gamma": "cluster_aware_gamma", + "pruning.cluster_aware.lambda_halo": "cluster_aware_lambda_halo", + "pruning.cluster_aware.protect_critical_frac": "cluster_aware_protect_critical_frac", + # Fine-tuning after pruning + "pruning.fine_tune.enabled": "fine_tune_after_pruning", + "pruning.fine_tune.epochs": "fine_tune_epochs", + "pruning.fine_tune.learning_rate": "fine_tune_learning_rate", + "pruning.fine_tune.max_batches": "fine_tune_max_batches", + "pruning.fine_tune.weight_decay": "fine_tune_weight_decay", + } + for arg in cli_args: if "=" in arg: key, value = arg.split("=", 1) + key = key.strip() + key = dotted_key_map.get(key, key) # Convert value to appropriate type try: # Common CLI convenience: YAML-style booleans/nulls @@ -1221,12 +1265,23 @@ def load_config_with_overrides( pass # Keep as string # Handle nested keys (e.g., "model.hidden_dims=[300,200]") - keys = key.split(".") - target = config_dict - for k in keys[:-1]: - if k not in target: - target[k] = {} - target = target[k] - target[keys[-1]] = value + if "." not in key: + config_dict[key] = value + else: + keys = key.split(".") + target = config_dict + for k in keys[:-1]: + if not isinstance(target, dict): + raise ValueError(f"Cannot apply override '{key}': encountered non-dict target ({type(target)})") + if k not in target or target[k] is None: + target[k] = {} + elif not isinstance(target[k], dict): + raise ValueError( + f"Cannot apply override '{key}': '{k}' is not a dict (got {type(target[k])}). " + f"Use a flat override (e.g., '{dotted_key_map.get(arg.split('=',1)[0].strip(), key)}=...') " + "or override an actual dict field like 'model_config.*' / 'dataset_config.*'." + ) + target = target[k] + target[keys[-1]] = value return ExperimentConfig.from_dict(config_dict) diff --git a/src/alignment/experiments/base.py b/src/alignment/experiments/base.py index de9a6d3e..b7402aee 100644 --- a/src/alignment/experiments/base.py +++ b/src/alignment/experiments/base.py @@ -90,6 +90,26 @@ class ExperimentConfig: # CNN-specific configuration cnn_mode: str = "unfold" # Options: "unfold", "patchwise", "batch_patch_combined" + # --------------------------------------------------------------------- + # Vision / cluster-analysis extras (used by ClusterAnalysisExperiment) + # --------------------------------------------------------------------- + # How to form channel samples from Conv outputs Y[B,C,H,W] + # - "flatten_spatial": treat spatial positions as samples (subsample per image) + # - "gap": global-average-pool per image (one sample per image) + activation_samples: str = "flatten_spatial" + spatial_samples_per_image: int = 16 # used when activation_samples="flatten_spatial" + n_clusters: int = 4 + synergy_target: str = "logit_margin" # logit_margin, correct_logit, logit_pc1 + synergy_candidate_pool: int = 50 + synergy_pairs: int = 10 + + # Cluster-aware pruning score weights (paper sweeps) + cluster_aware_alpha: float = 1.0 + cluster_aware_beta: float = 0.5 + cluster_aware_gamma: float = 0.3 + cluster_aware_lambda_halo: float = 0.5 + cluster_aware_protect_critical_frac: float = 0.3 + # Analysis control flags do_dropout_analysis: bool = False do_eigenfeature_analysis: bool = False @@ -119,6 +139,10 @@ class ExperimentConfig: pruning_min_per_layer: float = 0.0 pruning_max_per_layer: float = 0.95 fine_tune_learning_rate: Optional[float] = None # Will default to learning_rate * 0.1 + # Optional cap for post-pruning fine-tuning speed (useful for ImageNet-scale runs) + # None => use the full training loader each epoch. + fine_tune_max_batches: Optional[int] = None + fine_tune_weight_decay: float = 0.0 alignment_structured_pruning: bool = False # Use structured pruning for alignment cascading_direction: str = "forward" # Direction for cascading pruning dependency_aware_pruning: bool = False # Propagate masks across dependent layers diff --git a/src/alignment/metrics/information/synergy_continuous.py b/src/alignment/metrics/information/synergy_continuous.py index 97798359..8aaa62fc 100644 --- a/src/alignment/metrics/information/synergy_continuous.py +++ b/src/alignment/metrics/information/synergy_continuous.py @@ -104,7 +104,7 @@ def compute( if outputs.ndim > 2: # Conv layer: [B, C, H, W] -> [B, C] via GAP outputs = outputs.mean(dim=(2, 3)) if outputs.ndim == 4 else outputs.reshape(outputs.shape[0], -1) - + # Handle batch mismatch (common when upstream preprocessing unfolds CNN outputs) # If outputs has more samples than logits/labels, aggregate back to per-example activations. # This makes synergy w.r.t. per-example target T well-defined. diff --git a/src/alignment/preprocessing/layer_preprocessing.py b/src/alignment/preprocessing/layer_preprocessing.py index cfc0a65f..a24462de 100644 --- a/src/alignment/preprocessing/layer_preprocessing.py +++ b/src/alignment/preprocessing/layer_preprocessing.py @@ -119,15 +119,15 @@ def _unfold_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: boo if activation.ndim != 4: raise ValueError(f"Expected 4D tensor for Conv2d, got {activation.ndim}D") - b, c, h, w = activation.shape + b, c, h, w = activation.shape if is_input: # Unfold based on the layer's kernel parameters so feature dimension matches weight flattening - unfold_params = self._get_unfold_params(layer) - unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) + unfold_params = self._get_unfold_params(layer) + unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) # [b, features, num_patches] -> [b*num_patches, features] - unfolded = unfolded.transpose(1, 2).contiguous() - return unfolded.view(-1, unfolded.size(2)) + unfolded = unfolded.transpose(1, 2).contiguous() + return unfolded.view(-1, unfolded.size(2)) # Output: treat each spatial location as a sample (node = output channel) # [b, c, h, w] -> [b*h*w, c] @@ -172,13 +172,13 @@ def _patchwise_mode(self, activation: torch.Tensor, layer: nn.Module, is_input: if activation.ndim != 4: raise ValueError(f"Expected 4D tensor for Conv2d, got {activation.ndim}D") - b, c, h, w = activation.shape + b, c, h, w = activation.shape if is_input: # Unfold to get kernel patches - unfold_params = self._get_unfold_params(layer) - unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) - return unfolded # [b, features, patches] + unfold_params = self._get_unfold_params(layer) + unfolded = torch.nn.functional.unfold(activation, kernel_size=layer.kernel_size, **unfold_params) + return unfolded # [b, features, patches] # Output: reshape spatial dims to patches (node = output channel) return activation.reshape(b, c, h * w) # [b, c, patches] @@ -231,9 +231,9 @@ def _get_unfold_params(self, layer: nn.Module) -> Dict[str, Any]: def get_output_shape(self, input_shape: Tuple[int, ...], layer: nn.Module) -> Tuple[int, ...]: """Get expected output shape after preprocessing.""" if isinstance(layer, nn.Conv2d): - if len(input_shape) != 4: + if len(input_shape) != 4: raise ValueError(f"Expected 4D input shape for Conv2d, got {len(input_shape)}D") - b, c, h, w = input_shape + b, c, h, w = input_shape # Output spatial size (PyTorch conv2d formula; floor division) k_h, k_w = layer.kernel_size