diff --git a/modelopt/torch/quantization/activation_collector.py b/modelopt/torch/quantization/activation_collector.py new file mode 100644 index 000000000..0d060296b --- /dev/null +++ b/modelopt/torch/quantization/activation_collector.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sequential calibration layer patching and activation capture. + +This module provides :class:`LayerActivationCollector`, a stateful helper that +patches decoder layers with a skip / run / capture strategy for efficient +layer-by-layer calibration. +""" + +import copy +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.nn as nn + +from modelopt.torch.opt.searcher import ForwardLoop +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method + + +class _EarlyStopForwardError(Exception): + """Raised to halt the forward pass after capturing layer inputs.""" + + +@dataclass +class _LayerCalibState: + """Mutable per-layer state used during sequential calibration. + + Attached to each decoder layer as ``_seq_calib`` and accessed by the + patched forward to decide skip / run / capture / original behaviour. 
+ """ + + mode: str = "original" + name: str = "" + cached_inputs: deque = field(default_factory=deque) + collected_inputs: list = field(default_factory=list) + output_meta: tuple | None = None + + +class LayerActivationCollector: + """Collects layer activations for sequential (layer-by-layer) calibration. + + Each decoder layer is patched with a unified forward whose behaviour is + governed by a per-layer :class:`_LayerCalibState`: + + * **skip** — return a zero-filled dummy whose shape and type match the + layer's real output (reconstructed from lightweight metadata). No + computation is performed. The correctly shaped dummy ensures un-patched + inter-layer operations in the parent forward (e.g. LayerNorm, tuple + unpacking) do not raise shape or type errors. + * **run** — replay previously captured inputs through the original forward, + ignoring whatever the parent passes in. Only the just-calibrated layer + uses this mode, so its output reflects updated weights. + * **capture** — record ``(args, kwargs)`` and raise + ``_EarlyStopForwardError`` to halt the forward pass early. + * **original** — call the original forward unchanged. + + Because the *run* layer discards upstream values, skip-layer outputs are + never consumed for real computation. + """ + + # Global registry of (predicate, discoverer) pairs. Populated at import time + # by plugins (e.g. huggingface.py). Order matters: the first matching entry wins, + # so more specific predicates (e.g. Nemotron-H) must be registered before + # generic ones (e.g. homogeneous HF models). 
+ _decoder_layer_support: list[tuple[Any, Any]] = [] + _LAYER_ATTR = "_seq_calib" + + def __init__(self, model: nn.Module): + """Initialize the collector for the given model.""" + self.model = model + self._decoder_layers: nn.ModuleList | None = None + self._layer_to_idx: dict[nn.Module, int] = {} + self._patched = False + + # ------------------------------------------------------------------ + # Decoder-layer discovery + # ------------------------------------------------------------------ + + @staticmethod + def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + """Return decoder layers supported by sequential calibration.""" + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: + if not is_supported(model): + continue + decoder_layers = discoverer(model) + if decoder_layers is not None: + return decoder_layers + return None + + @staticmethod + def is_supported(model: nn.Module) -> bool: + """Whether the model supports decoder-layer sequential calibration.""" + return LayerActivationCollector.get_decoder_layers(model) is not None + + @classmethod + def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): + """Register a (predicate, discoverer) pair for decoder-layer detection.""" + entry = (is_supported, discoverer) + if entry not in cls._decoder_layer_support: + cls._decoder_layer_support.append(entry) + + # ------------------------------------------------------------------ + # Output metadata helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_output_meta(output): + """Extract lightweight (shape, dtype, device) metadata from a layer output. + + Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). + The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a + zero-filled output with identical shape and type. 
+ """ + if isinstance(output, torch.Tensor): + return ("tensor", output.shape, output.dtype, output.device) + if isinstance(output, tuple): + return ( + "tuple", + tuple(LayerActivationCollector._extract_output_meta(o) for o in output), + ) + if isinstance(output, list): + return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) + return ("other", output) + + @staticmethod + def _zeros_from_meta(meta): + """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, device = meta + return torch.zeros(shape, dtype=dtype, device=device) + if tag == "tuple": + return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) + if tag == "list": + return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] + return copy.deepcopy(meta[1]) + + # ------------------------------------------------------------------ + # Patched forward + # ------------------------------------------------------------------ + + @staticmethod + def _patched_forward(self, *args, **kwargs): + """Unified forward bound to every decoder layer during sequential calibration. + + ``self`` here is the decoder layer module (bound via ``bind_forward_method``). + All per-layer state is accessed through ``self._seq_calib``. + """ + info: _LayerCalibState = self._seq_calib + + if info.mode == "skip": + if info.output_meta is None: + raise RuntimeError( + f"Layer {info.name} is in 'skip' mode but has no output_meta. " + "This indicates a state-machine bug: the layer should have run " + "in 'run' mode (which sets output_meta) before transitioning to 'skip'." + ) + return LayerActivationCollector._zeros_from_meta(info.output_meta) + + if info.mode == "run": + assert info.cached_inputs, ( + f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." 
+ ) + real_args, real_kwargs = info.cached_inputs.popleft() + output = self._original_forward(*real_args, **real_kwargs) + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output + + if info.mode == "capture": + info.collected_inputs.append((args, kwargs)) + raise _EarlyStopForwardError() + + return self._original_forward(*args, **kwargs) + + # ------------------------------------------------------------------ + # Patch / unpatch lifecycle + # ------------------------------------------------------------------ + + def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + """Bind the unified forward to every decoder layer and the model. Called once. + + Args: + decoder_layers: Pre-resolved decoder layers. If *None*, layers are + discovered via :meth:`get_decoder_layers`. + """ + if decoder_layers is not None: + self._decoder_layers = decoder_layers + else: + self._decoder_layers = self.get_decoder_layers(self.model) + assert self._decoder_layers is not None + + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} + module_to_name = {m: name for name, m in self.model.named_modules()} + + try: + for layer in self._decoder_layers: + layer._seq_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) + bind_forward_method(layer, self._patched_forward, "_original_forward") + + def _early_stop_forward(module_self, *args, **kwargs): + try: + return module_self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + except Exception: + self._cleanup_layers() + raise + + self._patched = True + + def _cleanup_layers(self): + """Best-effort cleanup of any patched layers and model forward.""" + if hasattr(self.model, "_original_forward"): + unpatch_forward_method(self.model, "_original_forward") + + if self._decoder_layers is not None: + for layer in 
self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) + + def _unpatch_all_layers(self): + """Restore original forwards and clean up state attributes. Called once.""" + if not self._patched: + return + self._cleanup_layers() + self._patched = False + + # ------------------------------------------------------------------ + # Per-iteration state management + # ------------------------------------------------------------------ + + def _set_layer_states(self, layer_idx: int): + """Transition layer modes for the next calibration step. + + When calibrating layer *i*, three transitions happen: + + * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). + * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). + * Layer ``i`` → **capture** (record inputs, then early-stop). + """ + assert self._decoder_layers is not None + + if layer_idx > 1: + done = self._decoder_layers[layer_idx - 2]._seq_calib + done.mode = "skip" + # output_meta is intentionally kept: skip mode needs it to produce + # correctly shaped zero-filled outputs for the parent forward. 
+ done.cached_inputs.clear() + + if layer_idx > 0: + prev = self._decoder_layers[layer_idx - 1]._seq_calib + prev.mode = "run" + prev.cached_inputs = deque(prev.collected_inputs) + prev.collected_inputs = [] + + cur = self._decoder_layers[layer_idx]._seq_calib + cur.mode = "capture" + cur.collected_inputs = [] + + def _log_layer_summary(self, layer_idx: int): + """Log a one-line summary of layer modes for the current calibration step.""" + assert self._decoder_layers is not None + n = len(self._decoder_layers) + groups: dict[str, list[int]] = {} + for i, layer in enumerate(self._decoder_layers): + mode = layer._seq_calib.mode + if mode in ("skip", "run", "capture"): + groups.setdefault(mode, []).append(i) + parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] + print_rank_0(f"Calibrating layer {layer_idx}/{n} | {' | '.join(parts)}") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + """Collect input activations for *layer* by running a full model forward. + + Layers before the target are skipped or re-run (if just calibrated), the + target layer captures its inputs, and an early-stop prevents unnecessary + computation beyond the target. + + :meth:`_patch_all_layers` must be called before this method. + + Note: the model forward returns ``None`` for every batch during capture + (because ``_EarlyStopForwardError`` short-circuits the forward pass). + Callers should not rely on the model's return value within *forward_loop*. + """ + if not self._patched: + raise RuntimeError( + "get_input_activations() requires _patch_all_layers() to be called first." 
+ ) + layer_idx = self._layer_to_idx[layer] + self._set_layer_states(layer_idx) + self._log_layer_summary(layer_idx) + + info = layer._seq_calib + try: + forward_loop(self.model) + except Exception: + # Reset the current layer so subsequent calls don't see stale state. + info.mode = "original" + info.collected_inputs = [] + raise + + inputs = list(info.collected_inputs) + # After capture, set to original so calib_func can call the layer's + # real forward directly. The layer will transition to run → skip + # in subsequent iterations via _set_layer_states. + info.mode = "original" + return inputs diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 70f036a8d..513b66fdb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -28,14 +28,10 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.quantization.utils import LayerActivationCollector +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState -from modelopt.torch.utils.network import ( - bind_forward_method, - get_decoder_layers, - unpatch_forward_method, -) +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method from modelopt.torch.utils.perf import get_used_gpu_mem_fraction from .calib import MseCalibrator, NVFP4MSECalibrator @@ -1840,31 +1836,36 @@ def sequential_calibrate( calib_func: Callable, **calib_kwargs, ): - """Sequential calibration - a sequential layer-by-layer calibration algorithm.""" - if forward_loop is None: - raise ValueError("forward_loop must not be None for sequential calibration.") + """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
- transformer_layers = get_decoder_layers(model) - if transformer_layers is None: + Runs the full model forward per layer but patches decoder layers with a + skip / run / capture strategy so that inter-layer logic in parent modules + (e.g. mask construction) executes naturally without model-specific hooks. + """ + transformer_layers = LayerActivationCollector.get_decoder_layers(model) + if transformer_layers is None or len(transformer_layers) == 0: raise ValueError( - "Could not find transformer layers in model'. " + "Could not find transformer layers in model. " "Sequential calibration requires a model with identifiable transformer layers." ) print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") - gettr = LayerActivationCollector(model) + input_getter = LayerActivationCollector(model) + input_getter._patch_all_layers(decoder_layers=transformer_layers) - for layer in transformer_layers: - # Get updated input activations to the current layer - layer_inputs = gettr.get_input_activations(layer, forward_loop) + try: + for layer_idx, layer in enumerate(transformer_layers): + print_rank_0(f"Calibrating layer {layer_idx}") + layer_inputs = input_getter.get_input_activations(layer, forward_loop) - # Define a forward loop for the current layer - def _layer_forward_loop(m, _inputs=layer_inputs): - for args, kwargs_input in _inputs: - m(*args, **kwargs_input) + def _layer_forward_loop(m, _inputs=layer_inputs): + for args, kwargs_input in _inputs: + m(*args, **kwargs_input) - # Call calibration function - calib_func(layer, _layer_forward_loop, **calib_kwargs) - del layer_inputs - torch.cuda.empty_cache() + calib_func(layer, _layer_forward_loop, **calib_kwargs) + + del layer_inputs + torch.cuda.empty_cache() + finally: + input_getter._unpatch_all_layers() diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 318646c39..4f1e65229 100644 --- 
def is_nemotron_h_model(model: nn.Module) -> bool:
    """Return True when Nemotron-H style decoder layers can be found on *model*."""
    nemotron_layers = get_nemotron_h_decoder_layers(model)
    return nemotron_layers is not None


def get_nemotron_h_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
    """Return ``model.backbone.layers`` if its first layer exposes ``block_type``.

    Returns ``None`` for unsupported models, for models without
    ``backbone.layers``, and for an empty layer list.
    """
    if not _is_supported_hf_model(model):
        return None

    if not (hasattr(model, "backbone") and hasattr(model.backbone, "layers")):
        return None

    candidate = model.backbone.layers
    # The ``block_type`` attribute on the first layer is what marks the layout.
    if len(candidate) == 0 or not hasattr(candidate[0], "block_type"):
        return None
    return candidate


def is_homogeneous_hf_model(model: nn.Module) -> bool:
    """Return True when every layer in ``model.model.layers`` has the same class.

    Nemotron-H models are rejected up front; an absent or empty layer list
    also yields ``False``.
    """
    if is_nemotron_h_model(model):
        return False

    found = get_homogeneous_hf_decoder_layers(model)
    if found is None or len(found) == 0:
        return False

    distinct_types = set(map(type, found))
    return len(distinct_types) == 1


def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
    """Return ``model.model.layers`` for a supported HF model, else ``None``."""
    if not _is_supported_hf_model(model):
        return None

    has_layers = hasattr(model, "model") and hasattr(model.model, "layers")
    return model.model.layers if has_layers else None
+LayerActivationCollector.register_decoder_layer_support( + is_nemotron_h_model, get_nemotron_h_decoder_layers +) + +LayerActivationCollector.register_decoder_layer_support( + is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers +) + CUSTOM_MODEL_PLUGINS.update( [ register_falcon_linears_on_the_fly, diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 3c0d5e434..9216a89a1 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -26,9 +26,10 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate -from modelopt.torch.opt.searcher import ForwardLoop +from modelopt.torch.quantization.activation_collector import ( + LayerActivationCollector, # noqa: F401 # re-export +) from modelopt.torch.utils import get_unwrapped_name, print_rank_0 -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method if TYPE_CHECKING: from collections.abc import Generator @@ -808,64 +809,3 @@ def update_quant_cfg_with_kv_cache_quant( quant_cfg["algorithm"] = "max" print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") return quant_cfg - - -class _EarlyStopForwardError(Exception): - """Error to stop the forward pass after collection.""" - - -class LayerActivationCollector: - """Helper class for collecting layer activations during forward passes. - - This class allows for sequential layer calibration by - patching layers to capture inputs/outputs during forward passes - """ - - def __init__(self, model: nn.Module): - self.model = model - - @staticmethod - def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False): - """Patch a layer to collect inputs during forward passes.""" - - def _forward_w_data_collection(self, *args, **kwargs): - # Note: 'self' refers to the patched layer. 
- assert len(args) >= 1, ( - f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs" - ) - # Only collect the inputs to the layer - self.inputs.append((args, kwargs)) - if stop_after_collection: - raise _EarlyStopForwardError() # Stop the forward pass after collection - - return self._original_forward(*args, **kwargs) - - bind_forward_method(layer, _forward_w_data_collection, "_original_forward") - layer.inputs = [] - - @staticmethod - def _unpatch_and_cleanup_layer(layer: torch.nn.Module): - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if hasattr(layer, "inputs"): - del layer.inputs - - @torch.no_grad() - def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - # Wrap model forward to catch _EarlyStopForward per-batch - def _early_stop_forward(self, *args, **kwargs): - try: - return self._original_forward(*args, **kwargs) - except _EarlyStopForwardError: - return None # Stop propagation but allow next batch - - try: - bind_forward_method(self.model, _early_stop_forward, "_original_forward") - self._patch_and_initialize_layer(layer, stop_after_collection=True) - forward_loop(self.model) - inputs = layer.inputs.copy() - finally: - self._unpatch_and_cleanup_layer(layer) - unpatch_forward_method(self.model, "_original_forward") - - return inputs diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index 21c096db2..b54332375 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -634,36 +634,3 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str): with temporarily_remove_accelerate_hook(module): setattr(module, "forward", getattr(module, orig_forward_cache_name)) delattr(module, orig_forward_cache_name) - - -def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None: - """Detect the decoder layers from a model for sequential 
calibration. - - This temporary decoder-layer detection heuristic will be replaced with a more robust solution - that also supports FSDP/DDP models. - """ - if granularity != "decoder": - raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.") - - # HuggingFace transformers pattern: model.model.layers - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - - # Megatron/MCore pattern: model.decoder.layers - if hasattr(model, "decoder") and hasattr(model.decoder, "layers"): - return model.decoder.layers - - # Direct layers attribute (some models) - if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList): - return model.layers - - # GPT-style: model.transformer.h - if hasattr(model, "transformer") and hasattr(model.transformer, "h"): - return model.transformer.h - - # Nemotron Super/Nano - if hasattr(model, "backbone") and hasattr(model.backbone, "layers"): - return model.backbone.layers - - print("No decoder layers found for model, returning None") - return None diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 6b934a32c..67e82c629 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -23,13 +23,19 @@ from _test_utils.torch.misc import set_seed from _test_utils.torch.transformers_models import ( create_tiny_llama_dir, + get_tiny_gpt_oss, get_tiny_llama, get_tiny_qwen3_moe, tf_modelopt_state_and_output_tester, ) import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.quantization.nn import QuantLinear, QuantModuleRegistry +from modelopt.torch.quantization.plugins.huggingface import ( + get_homogeneous_hf_decoder_layers, + is_homogeneous_hf_model, +) pytest.importorskip("transformers") @@ -199,3 +205,24 @@ def 
def test_sequential_calibrate_support_gate():
    """sequential_calibrate must reject a model with no discoverable decoder layers."""

    class _UnsupportedModel(nn.Module):
        # A single linear layer: nothing that looks like a stack of decoder layers.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4, bias=False)

        def forward(self, x):
            return self.linear(x)

    def _run_one_batch(m):
        return m(torch.randn(2, 4))

    def _replay_on_layer(layer, loop):
        return loop(layer)

    model = _UnsupportedModel()

    with torch.no_grad():
        with pytest.raises(ValueError, match="Sequential calibration requires a model"):
            sequential_calibrate(
                model,
                forward_loop=_run_one_batch,
                calib_func=_replay_on_layer,
            )
def test_sequential_calibrate_handles_inter_layer_logic(monkeypatch):
    """Verify that parent-level inter-layer logic (e.g. mask selection) works correctly."""
    from modelopt.torch.quantization.activation_collector import LayerActivationCollector

    class _ToyLayer(nn.Module):
        def __init__(self, scale: float):
            super().__init__()
            self.scale = scale

        def forward(self, hidden_states, mask=None):
            masked = hidden_states if mask is None else hidden_states * mask
            return masked * self.scale

    class _ToyDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList(
                [_ToyLayer(scale=2.0), _ToyLayer(scale=0.5), _ToyLayer(scale=3.0)]
            )
            self.masks = [1.0, 0.5, 2.0]

        def forward(self, hidden_states):
            # The mask is built by the parent between layers, not by the layers themselves.
            for layer, mask_val in zip(self.layers, self.masks):
                mask = torch.full_like(hidden_states, mask_val)
                hidden_states = layer(hidden_states, mask=mask)
            return hidden_states

    model = _ToyDecoder()
    monkeypatch.setattr(
        LayerActivationCollector,
        "_decoder_layer_support",
        [(lambda m: hasattr(m, "layers"), lambda m: m.layers)],
    )
    batches = [torch.tensor([[1.0, 2.0]])]

    def _forward_loop(m):
        for batch in batches:
            m(batch)

    observed_layer_inputs = []

    def _calib_func(layer, layer_forward_loop):
        seen = []

        def _record_first_arg(_module, args):
            seen.append(args[0].clone())

        hook = layer.register_forward_pre_hook(_record_first_arg)
        try:
            layer_forward_loop(layer)
        finally:
            hook.remove()
        observed_layer_inputs.append(seen)

    sequential_calibrate(model, _forward_loop, _calib_func)

    assert len(observed_layer_inputs) == 3

    expected_first_inputs = [
        # Layer 0 gets raw batch
        batches[0],
        # Layer 1 gets output of layer 0 (batch * mask0 * scale0 = [1,2] * 1.0 * 2.0 = [2,4])
        torch.tensor([[2.0, 4.0]]),
        # Layer 2 gets output of layer 1 (prev * mask1 * scale1 = [2,4] * 0.5 * 0.5 = [0.5,1.0])
        torch.tensor([[0.5, 1.0]]),
    ]
    for layer_idx, expected in enumerate(expected_first_inputs):
        assert torch.allclose(observed_layer_inputs[layer_idx][0], expected)
collector.get_input_activations(model.layers[0], forward_loop) - assert len(inputs) == 3 + collector._patch_all_layers() + try: + inputs = collector.get_input_activations(model.layers[0], forward_loop) + assert len(inputs) == 3 + finally: + collector._unpatch_all_layers() -def test_collector_activations_match_expected(): +def test_collector_activations_match_expected(monkeypatch): """First layer should receive the raw input data.""" + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) @@ -123,13 +140,18 @@ def forward_loop(m): for d in data: m(d) - inputs = collector.get_input_activations(model.layers[0], forward_loop) - args, kwargs = inputs[0] - assert torch.allclose(args[0], data[0]) + collector._patch_all_layers() + try: + inputs = collector.get_input_activations(model.layers[0], forward_loop) + args, kwargs = inputs[0] + assert torch.allclose(args[0], data[0]) + finally: + collector._unpatch_all_layers() -def test_collector_second_layer_receives_transformed_input(): +def test_collector_second_layer_receives_transformed_input(monkeypatch): """Second layer should receive first layer's output, not raw input.""" + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) @@ -139,53 +161,55 @@ def forward_loop(m): m(x) expected = model.layers[0](x) - inputs = collector.get_input_activations(model.layers[1], forward_loop) - args, _ = inputs[0] - assert torch.allclose(args[0], expected) + + collector._patch_all_layers() + try: + collector.get_input_activations(model.layers[0], forward_loop) + inputs = collector.get_input_activations(model.layers[1], forward_loop) + args, _ = inputs[0] + assert torch.allclose(args[0], expected) + finally: + collector._unpatch_all_layers() -def test_collector_forward_is_restored_after_collection(): +def test_collector_forward_is_restored_after_collection(monkeypatch): 
+ _register_test_discoverer(monkeypatch) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) def forward_loop(m): m(torch.randn(2, 8)) + collector._patch_all_layers() collector.get_input_activations(model.layers[0], forward_loop) + collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "inputs") + assert not hasattr(model.layers[0], "_seq_calib") assert not hasattr(model.layers[0], "_original_forward") -def test_collector_cleanup_on_forward_loop_error(): +def test_collector_cleanup_on_forward_loop_error(monkeypatch): """Patching should be cleaned up even if forward_loop raises.""" + _register_test_discoverer(monkeypatch) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) def bad_forward_loop(m): raise RuntimeError("intentional error") - with pytest.raises(RuntimeError, match="intentional error"): - collector.get_input_activations(model.layers[0], bad_forward_loop) + collector._patch_all_layers() + try: + with pytest.raises(RuntimeError, match="intentional error"): + collector.get_input_activations(model.layers[0], bad_forward_loop) + finally: + collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "inputs") + assert not hasattr(model.layers[0], "_seq_calib") # sequential_calibrate tests - - -def test_seq_calib_raises_on_none_forward_loop(): - model, _ = _make_model_and_data(n_layers=2) - with pytest.raises(ValueError, match="forward_loop must not be None"): - sequential_calibrate( - model, - forward_loop=None, - calib_func=lambda *a, **kw: None, - ) - - def test_seq_calib_raises_on_unrecognized_model(): model = _FlatMLP() with pytest.raises(ValueError, match="Could not find transformer layers"): @@ -196,7 +220,8 @@ def test_seq_calib_raises_on_unrecognized_model(): ) -def test_seq_calib_func_called_per_layer(): +def test_seq_calib_func_called_per_layer(monkeypatch): + 
_register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=4) call_count = [0] @@ -212,7 +237,8 @@ def counting_calib(layer, forward_loop, **kwargs): assert call_count[0] == 4 -def test_seq_calib_func_receives_correct_layer(): +def test_seq_calib_func_receives_correct_layer(monkeypatch): + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) called_layers = [] @@ -229,7 +255,8 @@ def track_layers(layer, forward_loop, **kwargs): assert called_layers[i] is layer -def test_seq_calib_kwargs_forwarded(): +def test_seq_calib_kwargs_forwarded(monkeypatch): + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) received_kwargs = [] @@ -250,8 +277,9 @@ def capture_kwargs(layer, forward_loop, **kwargs): assert kw["method"] == "max" -def test_seq_calib_layer_forward_loop_runs_all_batches(): +def test_seq_calib_layer_forward_loop_runs_all_batches(monkeypatch): """The per-layer forward loop passed to calib_func should replay all batches.""" + _register_test_discoverer(monkeypatch) n_batches = 5 model, data = _make_model_and_data(n_layers=2, n_batches=n_batches) batch_counts = [] @@ -279,8 +307,9 @@ def counting_forward(*args, **kw): assert count == n_batches -def test_seq_calib_does_not_alter_weights(): +def test_seq_calib_does_not_alter_weights(monkeypatch): """sequential_calibrate itself should not modify model weights.""" + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) weights_before = {n: p.clone() for n, p in model.named_parameters()} @@ -294,8 +323,9 @@ def test_seq_calib_does_not_alter_weights(): assert torch.equal(p, weights_before[n]), f"Weight {n} was modified" -def test_seq_calib_activations_update_across_layers(): +def test_seq_calib_activations_update_across_layers(monkeypatch): """Subsequent layers should see activations transformed by prior layers.""" + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = 
_SimpleTransformerModel(n_layers=2, dim=16) tokens = [torch.randint(0, 32, (2, 4))] @@ -328,8 +358,9 @@ def capture_forward(*args, **kw): ) -def test_seq_calib_empty_forward_loop(): +def test_seq_calib_empty_forward_loop(monkeypatch): """If forward_loop feeds no data, calib_func still gets called with an empty replay.""" + _register_test_discoverer(monkeypatch) model = _SimpleTransformerModel(n_layers=2, dim=16) replay_counts = [] @@ -354,3 +385,350 @@ def counting_forward(*args, **kw): for count in replay_counts: assert count == 0 + + +# --------------------------------------------------------------------------- +# Skip / run / capture path verification tests +# --------------------------------------------------------------------------- + + +class _TupleReturningBlock(nn.Module): + """Decoder layer that returns a tuple, mimicking HuggingFace decoder layers.""" + + def __init__(self, dim=16): + super().__init__() + self.linear = nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return (self.linear(x), None) + + +class _TupleUnpackingModel(nn.Module): + """Parent model that unpacks layer outputs as tuples. + + This would crash with a naive skip that returns a bare tensor. 
+ """ + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x, _ = layer(x) + return x + + +class _InterLayerNormModel(nn.Module): + """Model with LayerNorm between decoder layers (not inside them).""" + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(n_layers)]) + + def forward(self, x): + for norm, layer in zip(self.norms, self.layers): + x = norm(x) + x, _ = layer(x) + return x + + +def test_skip_output_preserves_tuple_structure(monkeypatch): + """Skip layers must return a tuple when the real layer returns a tuple. + + Without this, the parent's ``x, _ = layer(x)`` unpacking would crash. + """ + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(3)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in model.layers: + inputs = collector.get_input_activations(layer, forward_loop) + assert len(inputs) == len(data) + finally: + collector._unpatch_all_layers() + + +def test_skip_output_preserves_shape_with_inter_layer_norm(monkeypatch): + """Skip outputs must have correct shape for un-patched LayerNorm between layers.""" + _register_test_discoverer(monkeypatch) + model = _InterLayerNormModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(3)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in model.layers: + inputs = collector.get_input_activations(layer, forward_loop) + assert len(inputs) == len(data) + finally: + collector._unpatch_all_layers() + + +def 
test_run_layer_populates_output_meta(monkeypatch): + """After a layer executes in 'run' mode, its output_meta must be set.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + # Layer 0 starts as capture — no output_meta yet + collector.get_input_activations(model.layers[0], forward_loop) + assert model.layers[0]._seq_calib.output_meta is None + + # Calibrating layer 1 puts layer 0 into run, which sets output_meta + collector.get_input_activations(model.layers[1], forward_loop) + meta = model.layers[0]._seq_calib.output_meta + assert meta is not None + assert meta[0] == "tuple", "Tuple-returning layer should produce tuple metadata" + finally: + collector._unpatch_all_layers() + + +def test_run_layer_consumes_cached_inputs(monkeypatch): + """The run layer must pop all cached inputs during the forward loop.""" + _register_test_discoverer(monkeypatch) + n_batches = 4 + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16) for _ in range(n_batches)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + collector.get_input_activations(model.layers[0], forward_loop) + collector.get_input_activations(model.layers[1], forward_loop) + + # Before calibrating layer 2, layer 1 transitions to run. + # Its cached_inputs should be populated from collected_inputs. 
+ collector._set_layer_states(2) + assert len(model.layers[1]._seq_calib.cached_inputs) == n_batches + + # After the forward loop, all cached inputs should be consumed + forward_loop(model) + assert len(model.layers[1]._seq_calib.cached_inputs) == 0 + finally: + collector._unpatch_all_layers() + + +def test_capture_layer_collects_all_batches(monkeypatch): + """The capture layer must record one entry per batch in the forward loop.""" + _register_test_discoverer(monkeypatch) + n_batches = 5 + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16) for _ in range(n_batches)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + inputs = collector.get_input_activations(model.layers[0], forward_loop) + assert len(inputs) == n_batches + + inputs = collector.get_input_activations(model.layers[2], forward_loop) + assert len(inputs) == n_batches + finally: + collector._unpatch_all_layers() + + +def test_mode_transitions_across_calibration_steps(monkeypatch): + """Verify mode transitions follow the skip/run/capture pattern at each step.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + + def modes(): + return [model.layers[i]._seq_calib.mode for i in range(5)] + + collector._set_layer_states(0) + assert modes() == ["capture", "original", "original", "original", "original"] + + collector._set_layer_states(1) + assert modes() == ["run", "capture", "original", "original", "original"] + + collector._set_layer_states(2) + assert modes() == ["skip", "run", "capture", "original", "original"] + + collector._set_layer_states(3) + assert modes() == ["skip", "skip", "run", "capture", "original"] + + collector._set_layer_states(4) + assert modes() == ["skip", "skip", "skip", "run", "capture"] + finally: + collector._unpatch_all_layers() + + +def 
test_run_asserts_on_empty_cached_inputs(monkeypatch): + """A layer in 'run' mode with no cached inputs must raise AssertionError.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=2, dim=16) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + model.layers[0]._seq_calib.mode = "run" + model.layers[0]._seq_calib.cached_inputs = deque() + + with pytest.raises(AssertionError, match="no cached inputs to replay"): + model(torch.randn(2, 16)) + finally: + collector._unpatch_all_layers() + + +def test_cleanup_removes_seq_calib_attr(monkeypatch): + """After unpatch, no layer should have the _seq_calib attribute.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + for layer in model.layers: + collector.get_input_activations(layer, forward_loop) + collector._unpatch_all_layers() + + for i, layer in enumerate(model.layers): + assert not hasattr(layer, "_seq_calib"), f"Layer {i} still has _seq_calib after cleanup" + assert not hasattr(layer, "_original_forward"), ( + f"Layer {i} still has _original_forward after cleanup" + ) + assert not hasattr(model, "_original_forward") + + +def test_skip_output_meta_not_shared_across_heterogeneous_layers(monkeypatch): + """Each layer stores its own output_meta, supporting heterogeneous architectures.""" + + class _SmallBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8, bias=False) + + def forward(self, x): + return (self.linear(x), None, torch.zeros(1)) + + class _BigBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8, bias=False) + + def forward(self, x): + return (self.linear(x),) + + class _HeterogeneousModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = 
nn.ModuleList([_SmallBlock(), _BigBlock(), _SmallBlock()]) + + def forward(self, x): + for layer in self.layers: + out = layer(x) + x = out[0] + return x + + _register_test_discoverer(monkeypatch) + model = _HeterogeneousModel() + data = [torch.randn(2, 8)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in model.layers: + collector.get_input_activations(layer, forward_loop) + + # After full calibration, layers 0 and 1 have been through 'run' and have output_meta + meta_0 = model.layers[0]._seq_calib.output_meta + meta_1 = model.layers[1]._seq_calib.output_meta + assert meta_0 is not None + assert meta_1 is not None + # SmallBlock returns 3-element tuple, BigBlock returns 1-element tuple + assert len(meta_0[1]) == 3 + assert len(meta_1[1]) == 1 + finally: + collector._unpatch_all_layers() + + +def test_run_layer_reflects_weight_updates(monkeypatch): + """After calib_func modifies weights, the next layer should see updated activations.""" + _register_test_discoverer(monkeypatch) + torch.manual_seed(0) + dim = 8 + + class _ScaleLayer(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return x * self.weight + + class _TwoScaleModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([_ScaleLayer(), _ScaleLayer()]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + model = _TwoScaleModel() + x = torch.randn(2, dim) + + activations_before_weight_update = model.layers[0](x).clone() + + def forward_loop(m): + m(x) + + def weight_doubling_calib(layer, layer_forward_loop, **kwargs): + with torch.no_grad(): + layer.weight.mul_(2.0) + layer_forward_loop(layer) + + sequential_calibrate( + model, + forward_loop=forward_loop, + calib_func=weight_doubling_calib, + ) + + # Layer 0's weight was doubled by calib_func. 
When collecting inputs + # for layer 1, the run-mode replay of layer 0 should use the updated + # weight, so layer 1 should have received 2x the original activations. + expected = activations_before_weight_update * 2.0 + # Verify by running model.layers[0] with its updated weights + actual = model.layers[0](x) + assert torch.allclose(actual, expected) diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index a88501192..f49c575b7 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -16,6 +16,7 @@ import pytest import torch +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.quantization.utils import ( convert_quantization_axis_to_reduce_axis, reduce_block_amax, @@ -101,3 +102,214 @@ def test_convert_quantization_axis_to_reduce_axis(shape, quant_axis, expected_re assert reduced.shape == tuple(expected_shape), ( f"Reduction result shape {reduced.shape} doesn't match expected {tuple(expected_shape)}" ) + + +def test_layer_activation_collector_support_api(monkeypatch): + class _SupportedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Identity()]) + + class _UnsupportedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + supported = _SupportedModel() + unsupported = _UnsupportedModel() + + def _supports_layers(model): + return hasattr(model, "layers") + + def _discover_layers(model): + return model.layers + + monkeypatch.setattr(LayerActivationCollector, "_decoder_layer_support", []) + LayerActivationCollector.register_decoder_layer_support(_supports_layers, _discover_layers) + + assert LayerActivationCollector.is_supported(supported) + assert LayerActivationCollector.get_decoder_layers(supported) is supported.layers + assert not LayerActivationCollector.is_supported(unsupported) + assert 
LayerActivationCollector.get_decoder_layers(unsupported) is None + + +def test_layer_activation_collector_decoder_discoverer_resolution_order(monkeypatch): + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Identity()]) + + calls = {"first": 0, "second": 0} + + def _supported(_model): + return True + + def _first_discoverer(_model): + calls["first"] += 1 + + def _second_discoverer(model): + calls["second"] += 1 + return model.layers + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(_supported, _first_discoverer), (_supported, _second_discoverer)], + ) + + model = _Model() + resolved = LayerActivationCollector.get_decoder_layers(model) + assert resolved is model.layers + assert calls["first"] == 1 + assert calls["second"] == 1 + + +def test_layer_activation_collector_decoder_discoverer_no_match(monkeypatch): + class _Model(torch.nn.Module): + pass + + def _unsupported(_model): + return False + + def _discoverer(_model): + return torch.nn.ModuleList([torch.nn.Identity()]) + + monkeypatch.setattr( + LayerActivationCollector, "_decoder_layer_support", [(_unsupported, _discoverer)] + ) + + model = _Model() + assert LayerActivationCollector.get_decoder_layers(model) is None + assert not LayerActivationCollector.is_supported(model) + + +def test_layer_activation_collector_decoder_discoverer_dedup(monkeypatch): + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Identity()]) + + def _supported(model): + return hasattr(model, "layers") + + def _discoverer(model): + return model.layers + + monkeypatch.setattr(LayerActivationCollector, "_decoder_layer_support", []) + LayerActivationCollector.register_decoder_layer_support(_supported, _discoverer) + LayerActivationCollector.register_decoder_layer_support(_supported, _discoverer) + + assert len(LayerActivationCollector._decoder_layer_support) == 1 + + 
+def test_layer_activation_collector_skip_forward_captures_correct_inputs(monkeypatch): + """The skip/run/capture strategy produces the same inputs as a plain forward.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, hidden_states): + return hidden_states * self.scale + + class _ToyDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_ToyLayer(2.0), _ToyLayer(0.5), _ToyLayer(3.0)]) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + model = _ToyDecoder() + batches = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])] + + def _forward_loop(m): + for b in batches: + m(b) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + inp0 = collector.get_input_activations(model.layers[0], _forward_loop) + inp1 = collector.get_input_activations(model.layers[1], _forward_loop) + inp2 = collector.get_input_activations(model.layers[2], _forward_loop) + finally: + collector._unpatch_all_layers() + + expected_0 = batches + expected_1 = [model.layers[0](b) for b in batches] + expected_2 = [model.layers[1](b) for b in expected_1] + + for (args, _kw), exp in zip(inp0, expected_0): + assert torch.allclose(args[0], exp) + for (args, _kw), exp in zip(inp1, expected_1): + assert torch.allclose(args[0], exp) + for (args, _kw), exp in zip(inp2, expected_2): + assert torch.allclose(args[0], exp) + + +def test_layer_activation_collector_run_uses_cached_inputs_not_parent(monkeypatch): + """Verify that the 'run' layer uses cached inputs, not garbage from skip layers.""" + + call_log = [] + + class _ToyLayer(torch.nn.Module): + def __init__(self, name, bias): + super().__init__() + 
self.layer_name = name + self.bias = bias + + def forward(self, hidden_states): + call_log.append(self.layer_name) + return hidden_states + self.bias + + class _ToyDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList( + [_ToyLayer("L0", 1.0), _ToyLayer("L1", 2.0), _ToyLayer("L2", 3.0)] + ) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + model = _ToyDecoder() + batches = [torch.tensor([[10.0]])] + + def _forward_loop(m): + for b in batches: + m(b) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + collector.get_input_activations(model.layers[0], _forward_loop) + call_log.clear() + collector.get_input_activations(model.layers[1], _forward_loop) + call_log_for_layer1 = list(call_log) + call_log.clear() + inp2 = collector.get_input_activations(model.layers[2], _forward_loop) + finally: + collector._unpatch_all_layers() + + assert "L0" in call_log_for_layer1 + assert "L1" not in call_log_for_layer1 + + assert torch.allclose(inp2[0][0][0], torch.tensor([[10.0 + 1.0 + 2.0]]))