From fbbf71fecdc52d6d59cfb5aa22d4ab03f8186763 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 01:06:29 +0000 Subject: [PATCH 01/11] sequential flow Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils.py | 3 ++- modelopt/torch/utils/network.py | 32 ---------------------------- 2 files changed, 2 insertions(+), 33 deletions(-) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 3c0d5e434..2702c0089 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -33,6 +33,8 @@ if TYPE_CHECKING: from collections.abc import Generator + from modelopt.torch.opt.searcher import ForwardLoop + __all__ = [ "EXPORT_MODE", "convert_quantization_axis_to_reduce_axis", @@ -867,5 +869,4 @@ def _early_stop_forward(self, *args, **kwargs): finally: self._unpatch_and_cleanup_layer(layer) unpatch_forward_method(self.model, "_original_forward") - return inputs diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index 21c096db2..afabdf0ac 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -635,35 +635,3 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str): setattr(module, "forward", getattr(module, orig_forward_cache_name)) delattr(module, orig_forward_cache_name) - -def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None: - """Detect the decoder layers from a model for sequential calibration. - - This temporary decoder-layer detection heuristic will be replaced with a more robust solution - that also supports FSDP/DDP models. - """ - if granularity != "decoder": - raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.") - - # HuggingFace transformers pattern: model.model.layers - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - - # Megatron/MCore pattern: model.decoder.layers - if hasattr(model, "decoder") and hasattr(model.decoder, "layers"): - return model.decoder.layers - - # Direct layers attribute (some models) - if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList): - return model.layers - - # GPT-style: model.transformer.h - if hasattr(model, "transformer") and hasattr(model.transformer, "h"): - return model.transformer.h - - # Nemotron Super/Nano - if hasattr(model, "backbone") and hasattr(model.backbone, "layers"): - return model.backbone.layers - - print("No decoder layers found for model, returning None") - return None From 92d227e4efa39385d448c299d9ba9831768b9284 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 01:49:33 +0000 Subject: [PATCH 02/11] clean up Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 2702c0089..72fe3aea2 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -869,4 +869,5 @@ def _early_stop_forward(self, *args, **kwargs): finally: self._unpatch_and_cleanup_layer(layer) unpatch_forward_method(self.model, "_original_forward") + return inputs From f1cd26f05b7cc9a21b27230dba5be6db73596863 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 24 Feb 2026 22:43:26 +0000 Subject: [PATCH 03/11] Modular/Plugin based sequential calib Signed-off-by: realAsma --- modelopt/torch/quantization/model_calib.py | 21 +- .../torch/quantization/plugins/huggingface.py | 59 ++++- modelopt/torch/quantization/utils.py | 60 ++++- modelopt/torch/utils/network.py | 1 - .../quantization/plugins/test_huggingface.py | 27 +++ tests/unit/torch/quantization/test_calib.py | 188 ++++++++++++++++ tests/unit/torch/quantization/test_utils.py | 212 ++++++++++++++++++ 7 files changed, 551 insertions(+), 17 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 70f036a8d..c0b70c396 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -31,11 +31,7 @@ from modelopt.torch.quantization.utils import LayerActivationCollector from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState -from modelopt.torch.utils.network import ( - bind_forward_method, - get_decoder_layers, - unpatch_forward_method, -) +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method from modelopt.torch.utils.perf import get_used_gpu_mem_fraction from .calib import MseCalibrator, NVFP4MSECalibrator @@ -1844,20 +1840,17 @@ def sequential_calibrate( if forward_loop is None: raise ValueError("forward_loop must not be None for sequential calibration.") - transformer_layers = get_decoder_layers(model) - if transformer_layers is None: - raise ValueError( - "Could not find transformer layers in model'. " - "Sequential calibration requires a model with identifiable transformer layers." - ) + transformer_layers = LayerActivationCollector.get_decoder_layers(model) + assert transformer_layers is not None print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + if len(transformer_layers) == 0: + return - gettr = LayerActivationCollector(model) + input_getter = LayerActivationCollector(model) for layer in transformer_layers: - # Get updated input activations to the current layer - layer_inputs = gettr.get_input_activations(layer, forward_loop) + layer_inputs = input_getter.get_input_activations(layer, forward_loop) # Define a forward loop for the current layer def _layer_forward_loop(m, _inputs=layer_inputs): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 318646c39..9020c3b91 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -56,7 +56,7 @@ else: weight_dequant = None -from ..utils import replace_function +from ..utils import LayerActivationCollector, replace_function from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin @@ -1179,6 +1179,55 @@ def _is_supported_hf_model(model): return isinstance(model, tuple(supported_models)) +def is_homogenous_hf_model(model: nn.Module) -> bool: + decoder_layers = get_homogeneous_hf_decoder_layers(model) + if decoder_layers is None or len(decoder_layers) == 0: + return False + layer_classes = {type(layer) for layer in decoder_layers} + return len(layer_classes) == 1 + + +def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + if not _is_supported_hf_model(model): + return None + + if hasattr(model, "model") and hasattr(model.model, "layers"): + return model.model.layers + + return None + + +def build_hf_homogenous_next_layer_inputs_hook(model: nn.Module): + def _extract_hidden_states(layer_output): + if isinstance(layer_output, tuple): + return layer_output[0] + if isinstance(layer_output, dict): + if "hidden_states" in layer_output: + return layer_output["hidden_states"] + return layer_output + + def _build_next_layer_inputs_hook(prev_layer, cached_inputs): + next_inputs = [] + for args, kwargs in cached_inputs: + prev_output = prev_layer(*args, **kwargs) + hidden_states = _extract_hidden_states(prev_output) + if len(args) >= 1: + next_args = (hidden_states, *args[1:]) + next_kwargs = kwargs + elif "hidden_states" in kwargs: + next_args = args + next_kwargs = dict(kwargs) + next_kwargs["hidden_states"] = hidden_states + else: + raise ValueError( + "Unable to build next-layer inputs without hidden_states in args/kwargs." + ) + next_inputs.append((next_args, next_kwargs)) + return next_inputs + + return _build_next_layer_inputs_hook + + @contextmanager def setup_model_for_gradient_checkpointing(model: nn.Module): use_cache = None @@ -1228,6 +1277,14 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): _is_param_grad_enabled_for_auto_quantize, ) +LayerActivationCollector.register_decoder_layer_support( + is_homogenous_hf_model, get_homogeneous_hf_decoder_layers +) + +LayerActivationCollector.register_next_layer_input_support( + is_homogenous_hf_model, build_hf_homogenous_next_layer_inputs_hook +) + CUSTOM_MODEL_PLUGINS.update( [ register_falcon_linears_on_the_fly, diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 72fe3aea2..4dce60598 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -823,8 +823,43 @@ class LayerActivationCollector: patching layers to capture inputs/outputs during forward passes """ + _next_layer_input_support: list[tuple[Any, Any]] = [] + _decoder_layer_support: list[tuple[Any, Any]] = [] + def __init__(self, model: nn.Module): self.model = model + self._previous_layer = None + self._previous_layer_inputs = None + + @staticmethod + def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + """Return decoder layers supported by sequential calibration.""" + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: + if not is_supported(model): + continue + decoder_layers = discoverer(model) + if decoder_layers is not None: + return decoder_layers + return None + + @staticmethod + def is_supported(model: nn.Module) -> bool: + """Whether the model supports decoder-layer sequential calibration.""" + return LayerActivationCollector.get_decoder_layers(model) is not None + + @classmethod + def register_next_layer_input_support( + cls, is_supported: Any, build_next_layer_inputs_hook: Any + ): + entry = (is_supported, build_next_layer_inputs_hook) + if entry not in cls._next_layer_input_support: + cls._next_layer_input_support.append(entry) + + @classmethod + def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): + entry = (is_supported, discoverer) + if entry not in cls._decoder_layer_support: + cls._decoder_layer_support.append(entry) @staticmethod def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False): @@ -852,8 +887,15 @@ def _unpatch_and_cleanup_layer(layer: torch.nn.Module): if hasattr(layer, "inputs"): del layer.inputs + def _resolve_next_layer_inputs_hook(self): + for is_supported, build_next_layer_inputs_hook in self._next_layer_input_support: + if not is_supported(self.model): + continue + return build_next_layer_inputs_hook(self.model) + return None + @torch.no_grad() - def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + def _collect_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: # Wrap model forward to catch _EarlyStopForward per-batch def _early_stop_forward(self, *args, **kwargs): try: @@ -871,3 +913,19 @@ def _early_stop_forward(self, *args, **kwargs): unpatch_forward_method(self.model, "_original_forward") return inputs + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + is_first_layer = self._previous_layer is None or self._previous_layer_inputs is None + if is_first_layer: + inputs = self._collect_input_activations(layer, forward_loop) + else: + next_layer_inputs_hook = self._resolve_next_layer_inputs_hook() + if next_layer_inputs_hook is None: + inputs = self._collect_input_activations(layer, forward_loop) + else: + inputs = next_layer_inputs_hook(self._previous_layer, self._previous_layer_inputs) + + self._previous_layer = layer + self._previous_layer_inputs = inputs + return inputs diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index afabdf0ac..b54332375 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -634,4 +634,3 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str): with temporarily_remove_accelerate_hook(module): setattr(module, "forward", getattr(module, orig_forward_cache_name)) delattr(module, orig_forward_cache_name) - diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 6b934a32c..043d6d6aa 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -23,6 +23,7 @@ from _test_utils.torch.misc import set_seed from _test_utils.torch.transformers_models import ( create_tiny_llama_dir, + get_tiny_gpt_oss, get_tiny_llama, get_tiny_qwen3_moe, tf_modelopt_state_and_output_tester, @@ -30,6 +31,11 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.nn import QuantLinear, QuantModuleRegistry +from modelopt.torch.quantization.plugins.huggingface import ( + get_homogeneous_hf_decoder_layers, + is_homogenous_hf_model, +) +from modelopt.torch.quantization.utils import LayerActivationCollector pytest.importorskip("transformers") @@ -199,3 +205,24 @@ def test_quantized_transformers_save_restore(tmp_path, model_cls, quant_config): model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model") tf_modelopt_state_and_output_tester(model_ref, model_test) + + +def test_is_homogenous_hf_model_llama(): + model = get_tiny_llama() + assert is_homogenous_hf_model(model) + + +def test_is_homogenous_hf_model_gpt_oss(): + model = get_tiny_gpt_oss(num_hidden_layers=1) + assert is_homogenous_hf_model(model) + + +def test_hf_decoder_discoverer_registration_path(): + model = get_tiny_llama() + assert any( + is_supported is is_homogenous_hf_model and discoverer is get_homogeneous_hf_decoder_layers + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support + ) + assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers( + model + ) diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index 7bc78c40e..f534cd6bb 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -17,6 +17,7 @@ from functools import partial +import pytest import torch import torch.nn as nn from _test_utils.torch.quantization.quantize_common import get_awq_config @@ -26,6 +27,7 @@ from modelopt.torch.quantization.model_calib import ( apply_pre_quant_scale_and_smooth, disable_pre_quant_scale_and_resmooth, + sequential_calibrate, ) from modelopt.torch.quantization.nn import TensorQuantizer @@ -375,3 +377,189 @@ def test_svdquant_lora_weights(): module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a ) assert lora_residual.shape == module.weight.shape + + +def test_sequential_calibrate_support_gate(): + class _UnsupportedModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 4, bias=False) + + def forward(self, x): + return self.linear(x) + + model = _UnsupportedModel() + + with ( + torch.no_grad(), + pytest.raises(ValueError, match="Sequential calibration requires a model"), + ): + sequential_calibrate( + model, + forward_loop=lambda m: m(torch.randn(2, 4)), + calib_func=lambda layer, loop: loop(layer), + ) + + +def test_sequential_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): + from modelopt.torch.quantization.utils import LayerActivationCollector + + class _ToyLayer(nn.Module): + def __init__(self, scale: float, bias: float): + super().__init__() + self.scale = scale + self.bias = bias + + def forward(self, hidden_states): + return hidden_states * self.scale + self.bias + + class _ToyDecoder(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList( + [ + _ToyLayer(scale=2.0, bias=1.0), + _ToyLayer(scale=0.5, bias=3.0), + _ToyLayer(scale=1.0, bias=-2.0), + ] + ) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + model = _ToyDecoder() + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + batches = [ + torch.tensor([[1.0, 2.0]]), + torch.tensor([[3.0, 4.0]]), + ] + + forward_loop_calls = 0 + + def _forward_loop(m): + nonlocal forward_loop_calls + forward_loop_calls += 1 + for batch in batches: + m(batch) + + observed_layer_inputs = [] + + def _calib_func(layer, layer_forward_loop): + captured = [] + + def _pre_hook(_module, args): + captured.append(args[0].clone()) + + handle = layer.register_forward_pre_hook(_pre_hook) + try: + layer_forward_loop(layer) + finally: + handle.remove() + observed_layer_inputs.append(captured) + + sequential_calibrate(model, _forward_loop, _calib_func) + + assert forward_loop_calls == len(model.layers) + assert len(observed_layer_inputs) == len(model.layers) + for layer_inputs in observed_layer_inputs: + assert len(layer_inputs) == len(batches) + + expected_layer_0 = batches + expected_layer_1 = [model.layers[0](batch) for batch in batches] + expected_layer_2 = [model.layers[1](batch) for batch in expected_layer_1] + + for observed, expected in zip(observed_layer_inputs[0], expected_layer_0): + assert torch.allclose(observed, expected) + for observed, expected in zip(observed_layer_inputs[1], expected_layer_1): + assert torch.allclose(observed, expected) + for observed, expected in zip(observed_layer_inputs[2], expected_layer_2): + assert torch.allclose(observed, expected) + + +def test_sequential_calibrate_uses_next_layer_hook_without_replaying_full_model(monkeypatch): + from modelopt.torch.quantization.utils import LayerActivationCollector + + class _ToyLayer(nn.Module): + def __init__(self, scale: float, bias: float): + super().__init__() + self.scale = scale + self.bias = bias + + def forward(self, hidden_states): + return hidden_states * self.scale + self.bias + + class _ToyDecoder(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList( + [ + _ToyLayer(scale=2.0, bias=1.0), + _ToyLayer(scale=0.5, bias=3.0), + _ToyLayer(scale=1.0, bias=-2.0), + ] + ) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + model = _ToyDecoder() + batches = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])] + forward_loop_calls = 0 + + def _forward_loop(m): + nonlocal forward_loop_calls + forward_loop_calls += 1 + for batch in batches: + m(batch) + + def _supported(_model): + return True + + def _build_hook(_model): + def _hook(prev_layer, cached_inputs): + next_inputs = [] + for args, kwargs in cached_inputs: + hidden_states = prev_layer(*args, **kwargs) + next_inputs.append(((hidden_states, *args[1:]), kwargs)) + return next_inputs + + return _hook + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + monkeypatch.setattr( + LayerActivationCollector, + "_next_layer_input_support", + [(_supported, _build_hook)], + ) + + observed_layer_inputs = [] + + def _calib_func(layer, layer_forward_loop): + captured = [] + + def _pre_hook(_module, args): + captured.append(args[0].clone()) + + handle = layer.register_forward_pre_hook(_pre_hook) + try: + layer_forward_loop(layer) + finally: + handle.remove() + observed_layer_inputs.append(captured) + + sequential_calibrate(model, _forward_loop, _calib_func) + + assert forward_loop_calls == 1 + assert len(observed_layer_inputs) == len(model.layers) diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index a88501192..3dc0d93e3 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -17,11 +17,21 @@ import torch from modelopt.torch.quantization.utils import ( + LayerActivationCollector, convert_quantization_axis_to_reduce_axis, reduce_block_amax, ) +def _build_next_inputs(prev_layer, cached_inputs): + next_inputs = [] + for args, kwargs in cached_inputs: + prev_output = prev_layer(*args, **kwargs) + hidden_states = prev_output[0] if isinstance(prev_output, tuple) else prev_output + next_inputs.append(((hidden_states, *args[1:]), kwargs)) + return next_inputs + + @pytest.mark.parametrize( ("block_sizes", "test_input", "expected_scales"), [ @@ -101,3 +111,205 @@ def test_convert_quantization_axis_to_reduce_axis(shape, quant_axis, expected_re assert reduced.shape == tuple(expected_shape), ( f"Reduction result shape {reduced.shape} doesn't match expected {tuple(expected_shape)}" ) + + +def test_layer_activation_collector_support_api(monkeypatch): + class _SupportedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Identity()]) + + class _UnsupportedModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + supported = _SupportedModel() + unsupported = _UnsupportedModel() + + def _supports_layers(model): + return hasattr(model, "layers") + + def _discover_layers(model): + return model.layers + + monkeypatch.setattr(LayerActivationCollector, "_decoder_layer_support", []) + LayerActivationCollector.register_decoder_layer_support(_supports_layers, _discover_layers) + + assert LayerActivationCollector.is_supported(supported) + assert LayerActivationCollector.get_decoder_layers(supported) is supported.layers + assert not LayerActivationCollector.is_supported(unsupported) + assert LayerActivationCollector.get_decoder_layers(unsupported) is None + + +def test_layer_activation_collector_decoder_discoverer_resolution_order(monkeypatch): + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Identity()]) + + calls = {"first": 0, "second": 0} + + def _supported(_model): + return True + + def _first_discoverer(_model): + calls["first"] += 1 + + def _second_discoverer(model): + calls["second"] += 1 + return model.layers + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(_supported, _first_discoverer), (_supported, _second_discoverer)], + ) + + model = _Model() + resolved = LayerActivationCollector.get_decoder_layers(model) + assert resolved is model.layers + assert calls["first"] == 1 + assert calls["second"] == 1 + + +def test_layer_activation_collector_decoder_discoverer_no_match(monkeypatch): + class _Model(torch.nn.Module): + pass + + def _unsupported(_model): + return False + + def _discoverer(_model): + return torch.nn.ModuleList([torch.nn.Identity()]) + + monkeypatch.setattr( + LayerActivationCollector, "_decoder_layer_support", [(_unsupported, _discoverer)] + ) + + model = _Model() + assert LayerActivationCollector.get_decoder_layers(model) is None + assert not LayerActivationCollector.is_supported(model) + + +def test_layer_activation_collector_decoder_discoverer_dedup(monkeypatch): + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Identity()]) + + def _supported(model): + return hasattr(model, "layers") + + def _discoverer(model): + return model.layers + + monkeypatch.setattr(LayerActivationCollector, "_decoder_layer_support", []) + LayerActivationCollector.register_decoder_layer_support(_supported, _discoverer) + LayerActivationCollector.register_decoder_layer_support(_supported, _discoverer) + + assert len(LayerActivationCollector._decoder_layer_support) == 1 + + +def test_layer_activation_collector_uses_first_matching_next_layer_hook(monkeypatch): + class _ToyLayer(torch.nn.Module): + def forward(self, hidden_states, attention_mask=None): + return hidden_states + 1.0, attention_mask + + class _ToyDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_ToyLayer(), _ToyLayer()]) + + def forward(self, hidden_states, attention_mask=None): + for layer in self.layers: + hidden_states, _ = layer(hidden_states, attention_mask=attention_mask) + return hidden_states + + model = _ToyDecoder() + collector = LayerActivationCollector(model) + called = {"first": 0, "second": 0} + + def _unsupported(_model): + return False + + def _supported(_model): + return True + + def _build_first_hook(_model): + def _first_hook(prev_layer, cached_inputs): + called["first"] += 1 + return _build_next_inputs(prev_layer, cached_inputs) + + return _first_hook + + def _build_second_hook(_model): + def _second_hook(prev_layer, cached_inputs): + called["second"] += 1 + return _build_next_inputs(prev_layer, cached_inputs) + + return _second_hook + + monkeypatch.setattr( + LayerActivationCollector, + "_next_layer_input_support", + [ + (_unsupported, _build_first_hook), + (_supported, _build_second_hook), + (_supported, _build_first_hook), + ], + ) + + batches = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])] + + def _forward_loop(m): + for batch in batches: + m(batch) + + first_inputs = collector.get_input_activations(model.layers[0], _forward_loop) + second_inputs = collector.get_input_activations(model.layers[1], _forward_loop) + + assert called["first"] == 0 + assert called["second"] == 1 + assert len(second_inputs) == len(first_inputs) + assert isinstance(second_inputs[0][0], tuple) + assert isinstance(second_inputs[0][1], dict) + + +def test_layer_activation_collector_falls_back_to_collection_without_matching_hook(monkeypatch): + class _ToyLayer(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states + 1.0 + + class _ToyDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_ToyLayer(), _ToyLayer()]) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + model = _ToyDecoder() + collector = LayerActivationCollector(model) + collect_calls = {"count": 0} + original_collect = LayerActivationCollector._collect_input_activations + + def _spy_collect(self, layer, forward_loop): + collect_calls["count"] += 1 + return original_collect(self, layer, forward_loop) + + monkeypatch.setattr(LayerActivationCollector, "_next_layer_input_support", []) + monkeypatch.setattr(LayerActivationCollector, "_collect_input_activations", _spy_collect) + + batches = [torch.tensor([[1.0, 2.0]])] + + def _forward_loop(m): + for batch in batches: + m(batch) + + collector.get_input_activations(model.layers[0], _forward_loop) + collector.get_input_activations(model.layers[1], _forward_loop) + + assert collect_calls["count"] == 2 From 4e4d7c3201d4508b40ffb9c1875a234f3d3e11df Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:31:05 +0000 Subject: [PATCH 04/11] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- .../torch/quantization/plugins/huggingface.py | 106 ++++++++++++++++-- modelopt/torch/quantization/utils.py | 24 ++-- 2 files changed, 115 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 9020c3b91..4dd05ec4c 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -1179,7 +1179,25 @@ def _is_supported_hf_model(model): return isinstance(model, tuple(supported_models)) +def is_nemotron_h_model(model: nn.Module) -> bool: + return get_nemotron_h_decoder_layers(model) is not None + + +def get_nemotron_h_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + if not _is_supported_hf_model(model): + return None + + if hasattr(model, "model") and hasattr(model.model, "layers"): + layers = model.model.layers + if len(layers) > 0 and hasattr(layers[0], "block_type"): + return layers + + return None + + def is_homogenous_hf_model(model: nn.Module) -> bool: + if is_nemotron_h_model(model): + return False decoder_layers = get_homogeneous_hf_decoder_layers(model) if decoder_layers is None or len(decoder_layers) == 0: return False @@ -1197,15 +1215,16 @@ def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None: return None -def build_hf_homogenous_next_layer_inputs_hook(model: nn.Module): - def _extract_hidden_states(layer_output): - if isinstance(layer_output, tuple): - return layer_output[0] - if isinstance(layer_output, dict): - if "hidden_states" in layer_output: - return layer_output["hidden_states"] - return layer_output +def _extract_hidden_states(layer_output): + if isinstance(layer_output, tuple): + return layer_output[0] + if isinstance(layer_output, dict): + if "hidden_states" in layer_output: + return layer_output["hidden_states"] + return layer_output + +def build_hf_homogenous_next_layer_inputs_hook(model: nn.Module): def _build_next_layer_inputs_hook(prev_layer, cached_inputs): next_inputs = [] for args, kwargs in cached_inputs: @@ -1228,6 +1247,69 @@ def _build_next_layer_inputs_hook(prev_layer, cached_inputs): return _build_next_layer_inputs_hook +def build_nemotron_h_next_layer_inputs_hook(model): + """Build a hook that propagates hidden_states and reconstructs per-block-type masks. + + Captures the original attention_mask via a forward pre-hook on model.model, then + reconstructs the correct mask for each layer's block_type using create_causal_mask + (for attention) or _update_mamba_mask (for mamba). + + Returns (hook, handle) where handle must be removed after first-layer collection. + """ + inner_model = model.model + layers = inner_model.layers + next_block_type_for = {layers[i]: layers[i + 1].block_type for i in range(len(layers) - 1)} + + update_mamba_mask = getattr(inner_model, "_update_mamba_mask", None) + + try: + from transformers.masking_utils import create_causal_mask + except ImportError: + create_causal_mask = None + + cached_original_masks = [] + + def _capture_attention_mask(module, args, kwargs): + cached_original_masks.append(kwargs.get("attention_mask")) + + handle = inner_model.register_forward_pre_hook(_capture_attention_mask, with_kwargs=True) + + base_hook = build_hf_homogenous_next_layer_inputs_hook(model) + + def hook(prev_layer, cached_inputs): + next_inputs = base_hook(prev_layer, cached_inputs) + + next_block_type = next_block_type_for.get(prev_layer) + if next_block_type is None: + return next_inputs + + for i, (args, kwargs) in enumerate(next_inputs): + original_mask = cached_original_masks[i] if i < len(cached_original_masks) else None + + if next_block_type == "mamba" and update_mamba_mask is not None: + mask = update_mamba_mask(original_mask, kwargs.get("cache_position")) + elif next_block_type == "attention" and create_causal_mask is not None: + hidden_states = args[0] if args else kwargs["hidden_states"] + mask = create_causal_mask( + config=inner_model.config, + input_embeds=hidden_states, + attention_mask=original_mask, + cache_position=kwargs.get("cache_position"), + past_key_values=kwargs.get("past_key_values"), + position_ids=kwargs.get("position_ids"), + ) + else: + mask = None + + next_kwargs = dict(kwargs) + next_kwargs["attention_mask"] = mask + next_inputs[i] = (args, next_kwargs) + + return next_inputs + + return hook, handle + + @contextmanager def setup_model_for_gradient_checkpointing(model: nn.Module): use_cache = None @@ -1277,10 +1359,18 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): _is_param_grad_enabled_for_auto_quantize, ) +LayerActivationCollector.register_decoder_layer_support( + is_nemotron_h_model, get_nemotron_h_decoder_layers +) + LayerActivationCollector.register_decoder_layer_support( is_homogenous_hf_model, get_homogeneous_hf_decoder_layers ) +LayerActivationCollector.register_next_layer_input_support( + is_nemotron_h_model, build_nemotron_h_next_layer_inputs_hook +) + LayerActivationCollector.register_next_layer_input_support( is_homogenous_hf_model, build_hf_homogenous_next_layer_inputs_hook ) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 4dce60598..f3e94bbf8 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -830,6 +830,7 @@ def __init__(self, model: nn.Module): self.model = model self._previous_layer = None self._previous_layer_inputs = None + self._next_layer_inputs_hook = None @staticmethod def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: @@ -888,11 +889,19 @@ def _unpatch_and_cleanup_layer(layer: torch.nn.Module): del layer.inputs def _resolve_next_layer_inputs_hook(self): + """Resolve the next-layer inputs hook from the registry. + + Returns (hook, handle) where handle is an optional RemovableHandle for + pre-collection cleanup. If the factory returns just a hook, handle is None. + """ for is_supported, build_next_layer_inputs_hook in self._next_layer_input_support: if not is_supported(self.model): continue - return build_next_layer_inputs_hook(self.model) - return None + result = build_next_layer_inputs_hook(self.model) + if isinstance(result, tuple): + return result + return result, None + return None, None @torch.no_grad() def _collect_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: @@ -918,13 +927,14 @@ def _early_stop_forward(self, *args, **kwargs): def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: is_first_layer = self._previous_layer is None or self._previous_layer_inputs is None if is_first_layer: + self._next_layer_inputs_hook, handle = self._resolve_next_layer_inputs_hook() + inputs = self._collect_input_activations(layer, forward_loop) + if handle is not None: + handle.remove() + elif self._next_layer_inputs_hook is None: inputs = self._collect_input_activations(layer, forward_loop) else: - next_layer_inputs_hook = self._resolve_next_layer_inputs_hook() - if next_layer_inputs_hook is None: - inputs = self._collect_input_activations(layer, forward_loop) - else: - inputs = next_layer_inputs_hook(self._previous_layer, self._previous_layer_inputs) + inputs = self._next_layer_inputs_hook(self._previous_layer, self._previous_layer_inputs) self._previous_layer = layer self._previous_layer_inputs = inputs From d822d929d67c4514f02751dd52b17c9c18bd72a2 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 5 Mar 2026 19:01:56 +0000 Subject: [PATCH 05/11] initial e2e tested sequential calibrate refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 60 ++++-- .../torch/quantization/plugins/huggingface.py | 107 +-------- modelopt/torch/quantization/utils.py | 203 +++++++++++------- tests/unit/torch/quantization/test_calib.py | 65 +++--- tests/unit/torch/quantization/test_utils.py | 150 ++++++------- 5 files changed, 275 insertions(+), 310 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c0b70c396..ecbbc24e8 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -127,10 +127,32 @@ def max_calibrate( forward_loop(model) finish_stats_collection(model) - # Sync input_quantizer amax across local experts within each rank (for SequentialMLP) + # Sync amax across local experts within each rank (for SequentialMLP and HuggingFace MoE) for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() + elif hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + + for name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + # Get the initial amax from max calibration + initial_amax = module._amax.clone().detach() + + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + + if is_nvfp4_static: + # Compute and set global_amax + global_amax = reduce_amax(initial_amax, axis=None) + + # Convert to NVFP4StaticQuantizer in-place + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if not distributed_sync: return @@ -1836,10 +1858,20 @@ def sequential_calibrate( calib_func: Callable, **calib_kwargs, ): - """Sequential calibration - a sequential layer-by-layer calibration algorithm.""" + """Sequential calibration - a sequential layer-by-layer calibration algorithm. + + Runs the full model forward per layer but patches decoder layers with a + skip / run / capture strategy so that inter-layer logic in parent modules + (e.g. mask construction) executes naturally without model-specific hooks. + """ if forward_loop is None: raise ValueError("forward_loop must not be None for sequential calibration.") + if not LayerActivationCollector.is_supported(model): + raise ValueError( + "Could not find transformer layers in model. " + "Sequential calibration requires a model with identifiable transformer layers." + ) transformer_layers = LayerActivationCollector.get_decoder_layers(model) assert transformer_layers is not None @@ -1848,16 +1880,20 @@ def sequential_calibrate( return input_getter = LayerActivationCollector(model) + input_getter._patch_all_layers() + + try: + for layer_idx, layer in enumerate(transformer_layers): + print_rank_0(f"Calibrating layer {layer_idx}") + layer_inputs = input_getter.get_input_activations(layer, forward_loop) - for layer in transformer_layers: - layer_inputs = input_getter.get_input_activations(layer, forward_loop) + def _layer_forward_loop(m, _inputs=layer_inputs): + for args, kwargs_input in _inputs: + m(*args, **kwargs_input) - # Define a forward loop for the current layer - def _layer_forward_loop(m, _inputs=layer_inputs): - for args, kwargs_input in _inputs: - m(*args, **kwargs_input) + calib_func(layer, _layer_forward_loop, **calib_kwargs) - # Call calibration function - calib_func(layer, _layer_forward_loop, **calib_kwargs) - del layer_inputs - torch.cuda.empty_cache() + del layer_inputs + torch.cuda.empty_cache() + finally: + input_getter._unpatch_all_layers() diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 4dd05ec4c..c08fb7258 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -1187,8 +1187,8 @@ def get_nemotron_h_decoder_layers(model: nn.Module) -> nn.ModuleList | None: if not _is_supported_hf_model(model): return None - if hasattr(model, "model") and hasattr(model.model, "layers"): - layers = model.model.layers + if hasattr(model, "backbone") and hasattr(model.backbone, "layers"): + layers = model.backbone.layers if len(layers) > 0 and hasattr(layers[0], "block_type"): return layers @@ -1215,101 +1215,6 @@ def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None: return None -def _extract_hidden_states(layer_output): - if isinstance(layer_output, tuple): - return layer_output[0] - if isinstance(layer_output, dict): - if "hidden_states" in layer_output: - return layer_output["hidden_states"] - return layer_output - - -def build_hf_homogenous_next_layer_inputs_hook(model: nn.Module): - def _build_next_layer_inputs_hook(prev_layer, cached_inputs): - next_inputs = [] - for args, kwargs in cached_inputs: - prev_output = prev_layer(*args, **kwargs) - hidden_states = _extract_hidden_states(prev_output) - if len(args) >= 1: - next_args = (hidden_states, *args[1:]) - next_kwargs = kwargs - elif "hidden_states" in kwargs: - next_args = args - next_kwargs = dict(kwargs) - next_kwargs["hidden_states"] = hidden_states - else: - raise ValueError( - "Unable to build next-layer inputs without hidden_states in args/kwargs." - ) - next_inputs.append((next_args, next_kwargs)) - return next_inputs - - return _build_next_layer_inputs_hook - - -def build_nemotron_h_next_layer_inputs_hook(model): - """Build a hook that propagates hidden_states and reconstructs per-block-type masks. - - Captures the original attention_mask via a forward pre-hook on model.model, then - reconstructs the correct mask for each layer's block_type using create_causal_mask - (for attention) or _update_mamba_mask (for mamba). - - Returns (hook, handle) where handle must be removed after first-layer collection. - """ - inner_model = model.model - layers = inner_model.layers - next_block_type_for = {layers[i]: layers[i + 1].block_type for i in range(len(layers) - 1)} - - update_mamba_mask = getattr(inner_model, "_update_mamba_mask", None) - - try: - from transformers.masking_utils import create_causal_mask - except ImportError: - create_causal_mask = None - - cached_original_masks = [] - - def _capture_attention_mask(module, args, kwargs): - cached_original_masks.append(kwargs.get("attention_mask")) - - handle = inner_model.register_forward_pre_hook(_capture_attention_mask, with_kwargs=True) - - base_hook = build_hf_homogenous_next_layer_inputs_hook(model) - - def hook(prev_layer, cached_inputs): - next_inputs = base_hook(prev_layer, cached_inputs) - - next_block_type = next_block_type_for.get(prev_layer) - if next_block_type is None: - return next_inputs - - for i, (args, kwargs) in enumerate(next_inputs): - original_mask = cached_original_masks[i] if i < len(cached_original_masks) else None - - if next_block_type == "mamba" and update_mamba_mask is not None: - mask = update_mamba_mask(original_mask, kwargs.get("cache_position")) - elif next_block_type == "attention" and create_causal_mask is not None: - hidden_states = args[0] if args else kwargs["hidden_states"] - mask = create_causal_mask( - config=inner_model.config, - input_embeds=hidden_states, - attention_mask=original_mask, - cache_position=kwargs.get("cache_position"), - past_key_values=kwargs.get("past_key_values"), - position_ids=kwargs.get("position_ids"), - ) - else: - mask = None - - next_kwargs = dict(kwargs) - next_kwargs["attention_mask"] = mask - next_inputs[i] = (args, next_kwargs) - - return next_inputs - - return hook, handle - - @contextmanager def setup_model_for_gradient_checkpointing(model: nn.Module): use_cache = None @@ -1367,14 +1272,6 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): is_homogenous_hf_model, get_homogeneous_hf_decoder_layers ) -LayerActivationCollector.register_next_layer_input_support( - is_nemotron_h_model, build_nemotron_h_next_layer_inputs_hook -) - -LayerActivationCollector.register_next_layer_input_support( - is_homogenous_hf_model, build_hf_homogenous_next_layer_inputs_hook -) - CUSTOM_MODEL_PLUGINS.update( [ register_falcon_linears_on_the_fly, diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index f3e94bbf8..c0d1e91b3 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -817,20 +817,34 @@ class _EarlyStopForwardError(Exception): class LayerActivationCollector: - """Helper class for collecting layer activations during forward passes. + """Helper class for collecting layer activations during sequential calibration. - This class allows for sequential layer calibration by - patching layers to capture inputs/outputs during forward passes + Uses a "skip / run / capture" strategy: each decoder layer is patched once with a + unified forward that checks ``_seq_calib_state`` to decide its behaviour. + + * **skip** -- return a lightweight dummy (no computation, no cache). + * **run** -- ignore the parent-provided input, use cached inputs from a prior + capture step, and execute the real forward. Only the just-calibrated layer + is in this state, so it reflects updated weights. + * **capture** -- record ``(args, kwargs)`` and raise ``_EarlyStopForwardError``. + * **passthrough** -- call the original forward unchanged. + + Because the *run* layer ignores upstream values, skip layers never need to + produce meaningful outputs. Memory overhead is O(B) -- only one layer's + captured inputs are kept at a time. """ - _next_layer_input_support: list[tuple[Any, Any]] = [] _decoder_layer_support: list[tuple[Any, Any]] = [] def __init__(self, model: nn.Module): self.model = model - self._previous_layer = None - self._previous_layer_inputs = None - self._next_layer_inputs_hook = None + self._decoder_layers: nn.ModuleList | None = None + self._layer_to_idx: dict[nn.Module, int] = {} + self._patched = False + + # ------------------------------------------------------------------ + # Decoder-layer discovery (unchanged public API) + # ------------------------------------------------------------------ @staticmethod def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: @@ -848,94 +862,129 @@ def is_supported(model: nn.Module) -> bool: """Whether the model supports decoder-layer sequential calibration.""" return LayerActivationCollector.get_decoder_layers(model) is not None - @classmethod - def register_next_layer_input_support( - cls, is_supported: Any, build_next_layer_inputs_hook: Any - ): - entry = (is_supported, build_next_layer_inputs_hook) - if entry not in cls._next_layer_input_support: - cls._next_layer_input_support.append(entry) - @classmethod def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): entry = (is_supported, discoverer) if entry not in cls._decoder_layer_support: cls._decoder_layer_support.append(entry) - @staticmethod - def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False): - """Patch a layer to collect inputs during forward passes.""" - - def _forward_w_data_collection(self, *args, **kwargs): - # Note: 'self' refers to the patched layer. - assert len(args) >= 1, ( - f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs" - ) - # Only collect the inputs to the layer - self.inputs.append((args, kwargs)) - if stop_after_collection: - raise _EarlyStopForwardError() # Stop the forward pass after collection + # ------------------------------------------------------------------ + # Unified patched forward + # ------------------------------------------------------------------ + @staticmethod + def _patched_forward(self, *args, **kwargs): + """Single forward bound to every decoder layer during sequential calibration.""" + state = self._seq_calib_state + if state == "skip": + print_rank_0(f"Skipping layer {self.name}") + return args[0] if args else next(iter(kwargs.values())) + elif state == "run": + print_rank_0(f"Running layer {self.name}") + real_args, real_kwargs = self._seq_calib_cached_inputs.pop(0) + return self._original_forward(*real_args, **real_kwargs) + elif state == "capture": + print_rank_0(f"Capturing layer {self.name}") + self._seq_calib_collected_inputs.append((args, kwargs)) + raise _EarlyStopForwardError() + else: return self._original_forward(*args, **kwargs) - bind_forward_method(layer, _forward_w_data_collection, "_original_forward") - layer.inputs = [] + # ------------------------------------------------------------------ + # Patch / unpatch lifecycle + # ------------------------------------------------------------------ - @staticmethod - def _unpatch_and_cleanup_layer(layer: torch.nn.Module): - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if hasattr(layer, "inputs"): - del layer.inputs - - def _resolve_next_layer_inputs_hook(self): - """Resolve the next-layer inputs hook from the registry. + def _patch_all_layers(self): + """Bind the unified forward to every decoder layer and the model. Called once.""" + self._decoder_layers = self.get_decoder_layers(self.model) + assert self._decoder_layers is not None + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} + module_to_name = {m: name for name, m in self.model.named_modules()} - Returns (hook, handle) where handle is an optional RemovableHandle for - pre-collection cleanup. If the factory returns just a hook, handle is None. - """ - for is_supported, build_next_layer_inputs_hook in self._next_layer_input_support: - if not is_supported(self.model): - continue - result = build_next_layer_inputs_hook(self.model) - if isinstance(result, tuple): - return result - return result, None - return None, None + for layer in self._decoder_layers: + layer._seq_calib_state = "passthrough" + layer._seq_calib_cached_inputs = [] + layer._seq_calib_collected_inputs = [] + layer._seq_calib_name = module_to_name.get(layer, type(layer).__name__) + bind_forward_method(layer, self._patched_forward, "_original_forward") - @torch.no_grad() - def _collect_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - # Wrap model forward to catch _EarlyStopForward per-batch def _early_stop_forward(self, *args, **kwargs): try: return self._original_forward(*args, **kwargs) except _EarlyStopForwardError: - return None # Stop propagation but allow next batch - - try: - bind_forward_method(self.model, _early_stop_forward, "_original_forward") - self._patch_and_initialize_layer(layer, stop_after_collection=True) - forward_loop(self.model) - inputs = layer.inputs.copy() - finally: - self._unpatch_and_cleanup_layer(layer) - unpatch_forward_method(self.model, "_original_forward") - - return inputs + return None + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + self._patched = True + + def _unpatch_all_layers(self): + """Restore original forwards and clean up state attributes. Called once.""" + if not self._patched: + return + assert self._decoder_layers is not None + + unpatch_forward_method(self.model, "_original_forward") + + for layer in self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + for attr in ( + "_seq_calib_state", + "_seq_calib_cached_inputs", + "_seq_calib_collected_inputs", + "_seq_calib_name", + ): + if hasattr(layer, attr): + delattr(layer, attr) + + self._patched = False + + # ------------------------------------------------------------------ + # Per-iteration state management + # ------------------------------------------------------------------ + + def _set_layer_states(self, layer_idx: int): + """Update only the affected layer states (O(1) per call). + + Layers are processed in sequential order, so only three transitions + can happen: the layer two back becomes ``"skip"``, the previous layer + becomes ``"run"``, and the current layer becomes ``"capture"``. + """ + assert self._decoder_layers is not None + if layer_idx > 1: + two_back = self._decoder_layers[layer_idx - 2] + two_back._seq_calib_state = "skip" + two_back._seq_calib_cached_inputs = [] + + if layer_idx > 0: + prev = self._decoder_layers[layer_idx - 1] + prev._seq_calib_state = "run" + prev._seq_calib_cached_inputs = prev._seq_calib_collected_inputs + prev._seq_calib_collected_inputs = [] + + cur = self._decoder_layers[layer_idx] + cur._seq_calib_state = "capture" + cur._seq_calib_collected_inputs = [] + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ @torch.no_grad() def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - is_first_layer = self._previous_layer is None or self._previous_layer_inputs is None - if is_first_layer: - self._next_layer_inputs_hook, handle = self._resolve_next_layer_inputs_hook() - inputs = self._collect_input_activations(layer, forward_loop) - if handle is not None: - handle.remove() - elif self._next_layer_inputs_hook is None: - inputs = self._collect_input_activations(layer, forward_loop) - else: - inputs = self._next_layer_inputs_hook(self._previous_layer, self._previous_layer_inputs) + """Collect input activations for *layer* by running a full model forward. - self._previous_layer = layer - self._previous_layer_inputs = inputs + Layers before the target are skipped or re-run (if just calibrated), the + target layer captures its inputs, and an early-stop prevents unnecessary + computation beyond the target. + """ + layer_idx = self._layer_to_idx[layer] + self._set_layer_states(layer_idx) + print_rank_0(f"Getting input activations for layer {layer_idx}") + forward_loop(self.model) + inputs = list(layer._seq_calib_collected_inputs) + # After calibration this layer will be "run" (next iter) then "skip" (all + # subsequent). For the interim period where calib_func calls the layer + # directly, passthrough lets the original forward execute normally. + layer._seq_calib_state = "passthrough" return inputs diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index f534cd6bb..98fc0a915 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -482,67 +482,45 @@ def _pre_hook(_module, args): assert torch.allclose(observed, expected) -def test_sequential_calibrate_uses_next_layer_hook_without_replaying_full_model(monkeypatch): +def test_sequential_calibrate_handles_inter_layer_logic(monkeypatch): + """Verify that parent-level inter-layer logic (e.g. mask selection) works correctly.""" from modelopt.torch.quantization.utils import LayerActivationCollector class _ToyLayer(nn.Module): - def __init__(self, scale: float, bias: float): + def __init__(self, scale: float): super().__init__() self.scale = scale - self.bias = bias - def forward(self, hidden_states): - return hidden_states * self.scale + self.bias + def forward(self, hidden_states, mask=None): + if mask is not None: + hidden_states = hidden_states * mask + return hidden_states * self.scale class _ToyDecoder(nn.Module): def __init__(self): super().__init__() self.layers = nn.ModuleList( - [ - _ToyLayer(scale=2.0, bias=1.0), - _ToyLayer(scale=0.5, bias=3.0), - _ToyLayer(scale=1.0, bias=-2.0), - ] + [_ToyLayer(scale=2.0), _ToyLayer(scale=0.5), _ToyLayer(scale=3.0)] ) + self.masks = [1.0, 0.5, 2.0] def forward(self, hidden_states): - for layer in self.layers: - hidden_states = layer(hidden_states) + for layer, mask_val in zip(self.layers, self.masks): + mask = torch.full_like(hidden_states, mask_val) + hidden_states = layer(hidden_states, mask=mask) return hidden_states model = _ToyDecoder() - batches = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])] - forward_loop_calls = 0 - - def _forward_loop(m): - nonlocal forward_loop_calls - forward_loop_calls += 1 - for batch in batches: - m(batch) - - def _supported(_model): - return True - - def _build_hook(_model): - def _hook(prev_layer, cached_inputs): - next_inputs = [] - for args, kwargs in cached_inputs: - hidden_states = prev_layer(*args, **kwargs) - next_inputs.append(((hidden_states, *args[1:]), kwargs)) - return next_inputs - - return _hook - monkeypatch.setattr( LayerActivationCollector, "_decoder_layer_support", [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], ) - monkeypatch.setattr( - LayerActivationCollector, - "_next_layer_input_support", - [(_supported, _build_hook)], - ) + batches = [torch.tensor([[1.0, 2.0]])] + + def _forward_loop(m): + for batch in batches: + m(batch) observed_layer_inputs = [] @@ -561,5 +539,10 @@ def _pre_hook(_module, args): sequential_calibrate(model, _forward_loop, _calib_func) - assert forward_loop_calls == 1 - assert len(observed_layer_inputs) == len(model.layers) + assert len(observed_layer_inputs) == 3 + # Layer 0 gets raw batch + assert torch.allclose(observed_layer_inputs[0][0], batches[0]) + # Layer 1 gets output of layer 0 (batch * mask0 * scale0 = [1,2] * 1.0 * 2.0 = [2,4]) + assert torch.allclose(observed_layer_inputs[1][0], torch.tensor([[2.0, 4.0]])) + # Layer 2 gets output of layer 1 (prev * mask1 * scale1 = [2,4] * 0.5 * 0.5 = [0.5,1.0]) + assert torch.allclose(observed_layer_inputs[2][0], torch.tensor([[0.5, 1.0]])) diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index 3dc0d93e3..6e31bdad4 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -23,15 +23,6 @@ ) -def _build_next_inputs(prev_layer, cached_inputs): - next_inputs = [] - for args, kwargs in cached_inputs: - prev_output = prev_layer(*args, **kwargs) - hidden_states = prev_output[0] if isinstance(prev_output, tuple) else prev_output - next_inputs.append(((hidden_states, *args[1:]), kwargs)) - return next_inputs - - @pytest.mark.parametrize( ("block_sizes", "test_input", "expected_scales"), [ @@ -211,105 +202,114 @@ def _discoverer(model): assert len(LayerActivationCollector._decoder_layer_support) == 1 -def test_layer_activation_collector_uses_first_matching_next_layer_hook(monkeypatch): +def test_layer_activation_collector_skip_forward_captures_correct_inputs(monkeypatch): + """The skip/run/capture strategy produces the same inputs as a plain forward.""" + class _ToyLayer(torch.nn.Module): - def forward(self, hidden_states, attention_mask=None): - return hidden_states + 1.0, attention_mask + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, hidden_states): + return hidden_states * self.scale class _ToyDecoder(torch.nn.Module): def __init__(self): super().__init__() - self.layers = torch.nn.ModuleList([_ToyLayer(), _ToyLayer()]) + self.layers = torch.nn.ModuleList([_ToyLayer(2.0), _ToyLayer(0.5), _ToyLayer(3.0)]) - def forward(self, hidden_states, attention_mask=None): + def forward(self, hidden_states): for layer in self.layers: - hidden_states, _ = layer(hidden_states, attention_mask=attention_mask) + hidden_states = layer(hidden_states) return hidden_states - model = _ToyDecoder() - collector = LayerActivationCollector(model) - called = {"first": 0, "second": 0} - - def _unsupported(_model): - return False - - def _supported(_model): - return True - - def _build_first_hook(_model): - def _first_hook(prev_layer, cached_inputs): - called["first"] += 1 - return _build_next_inputs(prev_layer, cached_inputs) - - return _first_hook - - def _build_second_hook(_model): - def _second_hook(prev_layer, cached_inputs): - called["second"] += 1 - return _build_next_inputs(prev_layer, cached_inputs) - - return _second_hook - monkeypatch.setattr( LayerActivationCollector, - "_next_layer_input_support", - [ - (_unsupported, _build_first_hook), - (_supported, _build_second_hook), - (_supported, _build_first_hook), - ], + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], ) + model = _ToyDecoder() batches = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])] def _forward_loop(m): - for batch in batches: - m(batch) + for b in batches: + m(b) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + inp0 = collector.get_input_activations(model.layers[0], _forward_loop) + inp1 = collector.get_input_activations(model.layers[1], _forward_loop) + inp2 = collector.get_input_activations(model.layers[2], _forward_loop) + finally: + collector._unpatch_all_layers() - first_inputs = collector.get_input_activations(model.layers[0], _forward_loop) - second_inputs = collector.get_input_activations(model.layers[1], _forward_loop) + expected_0 = batches + expected_1 = [model.layers[0](b) for b in batches] + expected_2 = [model.layers[1](b) for b in expected_1] - assert called["first"] == 0 - assert called["second"] == 1 - assert len(second_inputs) == len(first_inputs) - assert isinstance(second_inputs[0][0], tuple) - assert isinstance(second_inputs[0][1], dict) + for (args, _kw), exp in zip(inp0, expected_0): + assert torch.allclose(args[0], exp) + for (args, _kw), exp in zip(inp1, expected_1): + assert torch.allclose(args[0], exp) + for (args, _kw), exp in zip(inp2, expected_2): + assert torch.allclose(args[0], exp) -def test_layer_activation_collector_falls_back_to_collection_without_matching_hook(monkeypatch): +def test_layer_activation_collector_run_uses_cached_inputs_not_parent(monkeypatch): + """Verify that the 'run' layer uses cached inputs, not garbage from skip layers.""" + + call_log = [] + class _ToyLayer(torch.nn.Module): + def __init__(self, name, bias): + super().__init__() + self.layer_name = name + self.bias = bias + def forward(self, hidden_states): - return hidden_states + 1.0 + call_log.append(self.layer_name) + return hidden_states + self.bias class _ToyDecoder(torch.nn.Module): def __init__(self): super().__init__() - self.layers = torch.nn.ModuleList([_ToyLayer(), _ToyLayer()]) + self.layers = torch.nn.ModuleList( + [_ToyLayer("L0", 1.0), _ToyLayer("L1", 2.0), _ToyLayer("L2", 3.0)] + ) def forward(self, hidden_states): for layer in self.layers: hidden_states = layer(hidden_states) return hidden_states - model = _ToyDecoder() - collector = LayerActivationCollector(model) - collect_calls = {"count": 0} - original_collect = LayerActivationCollector._collect_input_activations - - def _spy_collect(self, layer, forward_loop): - collect_calls["count"] += 1 - return original_collect(self, layer, forward_loop) - - monkeypatch.setattr(LayerActivationCollector, "_next_layer_input_support", []) - monkeypatch.setattr(LayerActivationCollector, "_collect_input_activations", _spy_collect) + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) - batches = [torch.tensor([[1.0, 2.0]])] + model = _ToyDecoder() + batches = [torch.tensor([[10.0]])] def _forward_loop(m): - for batch in batches: - m(batch) + for b in batches: + m(b) - collector.get_input_activations(model.layers[0], _forward_loop) - collector.get_input_activations(model.layers[1], _forward_loop) - - assert collect_calls["count"] == 2 + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + collector.get_input_activations(model.layers[0], _forward_loop) + call_log.clear() + collector.get_input_activations(model.layers[1], _forward_loop) + call_log_for_layer1 = list(call_log) + call_log.clear() + inp2 = collector.get_input_activations(model.layers[2], _forward_loop) + finally: + collector._unpatch_all_layers() + + assert "L0" in call_log_for_layer1 + assert "L1" not in call_log_for_layer1 + + assert torch.allclose(inp2[0][0][0], torch.tensor([[10.0 + 1.0 + 2.0]])) From 7f72422654969fa2b55e2e8634a65f4e36b2624d Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 5 Mar 2026 21:16:33 +0000 Subject: [PATCH 06/11] added meta data caching Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 1 + modelopt/torch/quantization/utils.py | 187 ++++++++++++++------- 2 files changed, 128 insertions(+), 60 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ecbbc24e8..6b2a3154e 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1880,6 +1880,7 @@ def sequential_calibrate( return input_getter = LayerActivationCollector(model) + # Patch all transformer layers with state aware module forward input_getter._patch_all_layers() try: diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index c0d1e91b3..e5b568e57 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -17,6 +17,7 @@ from collections import namedtuple from contextlib import ExitStack, contextmanager, nullcontext +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch @@ -813,28 +814,48 @@ def update_quant_cfg_with_kv_cache_quant( class _EarlyStopForwardError(Exception): - """Error to stop the forward pass after collection.""" + """Raised to halt the forward pass after capturing layer inputs.""" -class LayerActivationCollector: - """Helper class for collecting layer activations during sequential calibration. +@dataclass +class _LayerCalibState: + """Mutable per-layer state used during sequential calibration. + + Attached to each decoder layer as ``_seq_calib`` and accessed by the + patched forward to decide skip / run / capture / passthrough behaviour. + """ - Uses a "skip / run / capture" strategy: each decoder layer is patched once with a - unified forward that checks ``_seq_calib_state`` to decide its behaviour. + mode: str = "passthrough" + name: str = "" + cached_inputs: list = field(default_factory=list) + collected_inputs: list = field(default_factory=list) + output_meta: tuple | None = None - * **skip** -- return a lightweight dummy (no computation, no cache). - * **run** -- ignore the parent-provided input, use cached inputs from a prior - capture step, and execute the real forward. Only the just-calibrated layer - is in this state, so it reflects updated weights. - * **capture** -- record ``(args, kwargs)`` and raise ``_EarlyStopForwardError``. - * **passthrough** -- call the original forward unchanged. - Because the *run* layer ignores upstream values, skip layers never need to - produce meaningful outputs. Memory overhead is O(B) -- only one layer's - captured inputs are kept at a time. +class LayerActivationCollector: + """Collects layer activations for sequential (layer-by-layer) calibration. + + Each decoder layer is patched with a unified forward whose behaviour is + governed by a per-layer :class:`_LayerCalibState`: + + * **skip** — return a zero-filled dummy whose shape and type match the + layer's real output (reconstructed from lightweight metadata). No + computation is performed. The correctly shaped dummy ensures un-patched + inter-layer operations in the parent forward (e.g. LayerNorm, tuple + unpacking) do not raise shape or type errors. + * **run** — replay previously captured inputs through the original forward, + ignoring whatever the parent passes in. Only the just-calibrated layer + uses this mode, so its output reflects updated weights. + * **capture** — record ``(args, kwargs)`` and raise + ``_EarlyStopForwardError`` to halt the forward pass early. + * **passthrough** — call the original forward unchanged. + + Because the *run* layer discards upstream values, skip-layer outputs are + never consumed for real computation. """ _decoder_layer_support: list[tuple[Any, Any]] = [] + _LAYER_ATTR = "_seq_calib" def __init__(self, model: nn.Module): self.model = model @@ -843,7 +864,7 @@ def __init__(self, model: nn.Module): self._patched = False # ------------------------------------------------------------------ - # Decoder-layer discovery (unchanged public API) + # Decoder-layer discovery # ------------------------------------------------------------------ @staticmethod @@ -869,26 +890,74 @@ def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): cls._decoder_layer_support.append(entry) # ------------------------------------------------------------------ - # Unified patched forward + # Output metadata helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_output_meta(output): + """Extract lightweight (shape, dtype, device) metadata from a layer output. + + Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). + The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a + zero-filled output with identical shape and type. + """ + if isinstance(output, torch.Tensor): + return ("tensor", output.shape, output.dtype, output.device) + if isinstance(output, tuple): + return ( + "tuple", + tuple(LayerActivationCollector._extract_output_meta(o) for o in output), + ) + if isinstance(output, list): + return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) + return ("other", output) + + @staticmethod + def _zeros_from_meta(meta): + """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, device = meta + return torch.zeros(shape, dtype=dtype, device=device) + if tag == "tuple": + return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) + if tag == "list": + return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] + return meta[1] + + # ------------------------------------------------------------------ + # Patched forward # ------------------------------------------------------------------ @staticmethod def _patched_forward(self, *args, **kwargs): - """Single forward bound to every decoder layer during sequential calibration.""" - state = self._seq_calib_state - if state == "skip": - print_rank_0(f"Skipping layer {self.name}") + """Unified forward bound to every decoder layer during sequential calibration. + + ``self`` here is the decoder layer module (bound via ``bind_forward_method``). + All per-layer state is accessed through ``self._seq_calib``. + """ + info: _LayerCalibState = self._seq_calib + + if info.mode == "skip": + print_rank_0(f"Skipping layer {info.name}") + if info.output_meta is not None: + return LayerActivationCollector._zeros_from_meta(info.output_meta) + print_rank_0(f"Warning: No output metadata found for layer {info.name}") return args[0] if args else next(iter(kwargs.values())) - elif state == "run": - print_rank_0(f"Running layer {self.name}") - real_args, real_kwargs = self._seq_calib_cached_inputs.pop(0) - return self._original_forward(*real_args, **real_kwargs) - elif state == "capture": - print_rank_0(f"Capturing layer {self.name}") - self._seq_calib_collected_inputs.append((args, kwargs)) + + if info.mode == "run": + print_rank_0(f"Running layer {info.name}") + real_args, real_kwargs = info.cached_inputs.pop(0) + output = self._original_forward(*real_args, **real_kwargs) + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output + + if info.mode == "capture": + print_rank_0(f"Capturing layer {info.name}") + info.collected_inputs.append((args, kwargs)) raise _EarlyStopForwardError() - else: - return self._original_forward(*args, **kwargs) + + return self._original_forward(*args, **kwargs) # ------------------------------------------------------------------ # Patch / unpatch lifecycle @@ -902,10 +971,9 @@ def _patch_all_layers(self): module_to_name = {m: name for name, m in self.model.named_modules()} for layer in self._decoder_layers: - layer._seq_calib_state = "passthrough" - layer._seq_calib_cached_inputs = [] - layer._seq_calib_collected_inputs = [] - layer._seq_calib_name = module_to_name.get(layer, type(layer).__name__) + layer._seq_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) bind_forward_method(layer, self._patched_forward, "_original_forward") def _early_stop_forward(self, *args, **kwargs): @@ -928,14 +996,8 @@ def _unpatch_all_layers(self): for layer in self._decoder_layers: if hasattr(layer, "_original_forward"): unpatch_forward_method(layer, "_original_forward") - for attr in ( - "_seq_calib_state", - "_seq_calib_cached_inputs", - "_seq_calib_collected_inputs", - "_seq_calib_name", - ): - if hasattr(layer, attr): - delattr(layer, attr) + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) self._patched = False @@ -944,27 +1006,30 @@ def _unpatch_all_layers(self): # ------------------------------------------------------------------ def _set_layer_states(self, layer_idx: int): - """Update only the affected layer states (O(1) per call). + """Transition layer modes for the next calibration step. - Layers are processed in sequential order, so only three transitions - can happen: the layer two back becomes ``"skip"``, the previous layer - becomes ``"run"``, and the current layer becomes ``"capture"``. + When calibrating layer *i*, three transitions happen: + + * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). + * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). + * Layer ``i`` → **capture** (record inputs, then early-stop). """ assert self._decoder_layers is not None + if layer_idx > 1: - two_back = self._decoder_layers[layer_idx - 2] - two_back._seq_calib_state = "skip" - two_back._seq_calib_cached_inputs = [] + done = self._decoder_layers[layer_idx - 2]._seq_calib + done.mode = "skip" + done.cached_inputs = [] if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1] - prev._seq_calib_state = "run" - prev._seq_calib_cached_inputs = prev._seq_calib_collected_inputs - prev._seq_calib_collected_inputs = [] + prev = self._decoder_layers[layer_idx - 1]._seq_calib + prev.mode = "run" + prev.cached_inputs = prev.collected_inputs + prev.collected_inputs = [] - cur = self._decoder_layers[layer_idx] - cur._seq_calib_state = "capture" - cur._seq_calib_collected_inputs = [] + cur = self._decoder_layers[layer_idx]._seq_calib + cur.mode = "capture" + cur.collected_inputs = [] # ------------------------------------------------------------------ # Public API @@ -982,9 +1047,11 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo self._set_layer_states(layer_idx) print_rank_0(f"Getting input activations for layer {layer_idx}") forward_loop(self.model) - inputs = list(layer._seq_calib_collected_inputs) - # After calibration this layer will be "run" (next iter) then "skip" (all - # subsequent). For the interim period where calib_func calls the layer - # directly, passthrough lets the original forward execute normally. - layer._seq_calib_state = "passthrough" + + info = layer._seq_calib + inputs = list(info.collected_inputs) + # After capture, set to passthrough so calib_func can call the layer's + # original forward directly. The layer will transition to run → skip + # in subsequent iterations via _set_layer_states. + info.mode = "passthrough" return inputs From 2f72d6d544e73e38e360b84e71a4bc8bee143c1e Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:57:57 +0000 Subject: [PATCH 07/11] added logging and unit tests Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 3 - modelopt/torch/quantization/utils.py | 36 ++- .../quantization/test_sequential_calibrate.py | 304 ++++++++++++++++++ 3 files changed, 327 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6b2a3154e..fe14e364d 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1864,9 +1864,6 @@ def sequential_calibrate( skip / run / capture strategy so that inter-layer logic in parent modules (e.g. mask construction) executes naturally without model-specific hooks. """ - if forward_loop is None: - raise ValueError("forward_loop must not be None for sequential calibration.") - if not LayerActivationCollector.is_supported(model): raise ValueError( "Could not find transformer layers in model. " diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index e5b568e57..3f1af580c 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -34,8 +34,6 @@ if TYPE_CHECKING: from collections.abc import Generator - from modelopt.torch.opt.searcher import ForwardLoop - __all__ = [ "EXPORT_MODE", "convert_quantization_axis_to_reduce_axis", @@ -822,10 +820,10 @@ class _LayerCalibState: """Mutable per-layer state used during sequential calibration. Attached to each decoder layer as ``_seq_calib`` and accessed by the - patched forward to decide skip / run / capture / passthrough behaviour. + patched forward to decide skip / run / capture / original behaviour. """ - mode: str = "passthrough" + mode: str = "original" name: str = "" cached_inputs: list = field(default_factory=list) collected_inputs: list = field(default_factory=list) @@ -848,7 +846,7 @@ class LayerActivationCollector: uses this mode, so its output reflects updated weights. * **capture** — record ``(args, kwargs)`` and raise ``_EarlyStopForwardError`` to halt the forward pass early. - * **passthrough** — call the original forward unchanged. + * **original** — call the original forward unchanged. Because the *run* layer discards upstream values, skip-layer outputs are never consumed for real computation. @@ -939,21 +937,21 @@ def _patched_forward(self, *args, **kwargs): info: _LayerCalibState = self._seq_calib if info.mode == "skip": - print_rank_0(f"Skipping layer {info.name}") if info.output_meta is not None: return LayerActivationCollector._zeros_from_meta(info.output_meta) - print_rank_0(f"Warning: No output metadata found for layer {info.name}") + print_rank_0(f"Layer {info.name} is in 'skip' mode but has no output meta to return") return args[0] if args else next(iter(kwargs.values())) if info.mode == "run": - print_rank_0(f"Running layer {info.name}") + assert info.cached_inputs, ( + f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." + ) real_args, real_kwargs = info.cached_inputs.pop(0) output = self._original_forward(*real_args, **real_kwargs) info.output_meta = LayerActivationCollector._extract_output_meta(output) return output if info.mode == "capture": - print_rank_0(f"Capturing layer {info.name}") info.collected_inputs.append((args, kwargs)) raise _EarlyStopForwardError() @@ -1031,6 +1029,18 @@ def _set_layer_states(self, layer_idx: int): cur.mode = "capture" cur.collected_inputs = [] + def _log_layer_summary(self, layer_idx: int): + """Log a one-line summary of layer modes for the current calibration step.""" + assert self._decoder_layers is not None + n = len(self._decoder_layers) + groups: dict[str, list[int]] = {} + for i, layer in enumerate(self._decoder_layers): + mode = layer._seq_calib.mode + if mode in ("skip", "run", "capture"): + groups.setdefault(mode, []).append(i) + parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] + print_rank_0(f"Calibrating layer {layer_idx}/{n} | {' | '.join(parts)}") + # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ @@ -1045,13 +1055,13 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo """ layer_idx = self._layer_to_idx[layer] self._set_layer_states(layer_idx) - print_rank_0(f"Getting input activations for layer {layer_idx}") + self._log_layer_summary(layer_idx) forward_loop(self.model) info = layer._seq_calib inputs = list(info.collected_inputs) - # After capture, set to passthrough so calib_func can call the layer's - # original forward directly. The layer will transition to run → skip + # After capture, set to original so calib_func can call the layer's + # real forward directly. The layer will transition to run → skip # in subsequent iterations via _set_layer_states. - info.mode = "passthrough" + info.mode = "original" return inputs diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_sequential_calibrate.py index 3b6b166be..0cba247af 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_sequential_calibrate.py @@ -354,3 +354,307 @@ def counting_forward(*args, **kw): for count in replay_counts: assert count == 0 + + +# --------------------------------------------------------------------------- +# Skip / run / capture path verification tests +# --------------------------------------------------------------------------- + + +class _TupleReturningBlock(nn.Module): + """Decoder layer that returns a tuple, mimicking HuggingFace decoder layers.""" + + def __init__(self, dim=16): + super().__init__() + self.linear = nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return (self.linear(x), None) + + +class _TupleUnpackingModel(nn.Module): + """Parent model that unpacks layer outputs as tuples. + + This would crash with a naive skip that returns a bare tensor. + """ + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x, _ = layer(x) + return x + + +class _InterLayerNormModel(nn.Module): + """Model with LayerNorm between decoder layers (not inside them).""" + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(n_layers)]) + + def forward(self, x): + for norm, layer in zip(self.norms, self.layers): + x = norm(x) + x, _ = layer(x) + return x + + +def _register_test_discoverer(monkeypatch): + """Register a simple discoverer that finds model.layers on any model.""" + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + +def test_skip_output_preserves_tuple_structure(monkeypatch): + """Skip layers must return a tuple when the real layer returns a tuple. + + Without this, the parent's ``x, _ = layer(x)`` unpacking would crash. + """ + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(3)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in model.layers: + inputs = collector.get_input_activations(layer, forward_loop) + assert len(inputs) == len(data) + finally: + collector._unpatch_all_layers() + + +def test_skip_output_preserves_shape_with_inter_layer_norm(monkeypatch): + """Skip outputs must have correct shape for un-patched LayerNorm between layers.""" + _register_test_discoverer(monkeypatch) + model = _InterLayerNormModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(3)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in model.layers: + inputs = collector.get_input_activations(layer, forward_loop) + assert len(inputs) == len(data) + finally: + collector._unpatch_all_layers() + + +def test_run_layer_populates_output_meta(monkeypatch): + """After a layer executes in 'run' mode, its output_meta must be set.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + # Layer 0 starts as capture — no output_meta yet + collector.get_input_activations(model.layers[0], forward_loop) + assert model.layers[0]._seq_calib.output_meta is None + + # Calibrating layer 1 puts layer 0 into run, which sets output_meta + collector.get_input_activations(model.layers[1], forward_loop) + meta = model.layers[0]._seq_calib.output_meta + assert meta is not None + assert meta[0] == "tuple", "Tuple-returning layer should produce tuple metadata" + finally: + collector._unpatch_all_layers() + + +def test_run_layer_consumes_cached_inputs(monkeypatch): + """The run layer must pop all cached inputs during the forward loop.""" + _register_test_discoverer(monkeypatch) + n_batches = 4 + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16) for _ in range(n_batches)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + collector.get_input_activations(model.layers[0], forward_loop) + collector.get_input_activations(model.layers[1], forward_loop) + + # Before calibrating layer 2, layer 1 transitions to run. + # Its cached_inputs should be populated from collected_inputs. + collector._set_layer_states(2) + assert len(model.layers[1]._seq_calib.cached_inputs) == n_batches + + # After the forward loop, all cached inputs should be consumed + forward_loop(model) + assert len(model.layers[1]._seq_calib.cached_inputs) == 0 + finally: + collector._unpatch_all_layers() + + +def test_capture_layer_collects_all_batches(monkeypatch): + """The capture layer must record one entry per batch in the forward loop.""" + _register_test_discoverer(monkeypatch) + n_batches = 5 + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16) for _ in range(n_batches)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + inputs = collector.get_input_activations(model.layers[0], forward_loop) + assert len(inputs) == n_batches + + inputs = collector.get_input_activations(model.layers[2], forward_loop) + assert len(inputs) == n_batches + finally: + collector._unpatch_all_layers() + + +def test_mode_transitions_across_calibration_steps(monkeypatch): + """Verify mode transitions follow the skip/run/capture pattern at each step.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + + def modes(): + return [model.layers[i]._seq_calib.mode for i in range(5)] + + collector._set_layer_states(0) + assert modes() == ["capture", "original", "original", "original", "original"] + + collector._set_layer_states(1) + assert modes() == ["run", "capture", "original", "original", "original"] + + collector._set_layer_states(2) + assert modes() == ["skip", "run", "capture", "original", "original"] + + collector._set_layer_states(3) + assert modes() == ["skip", "skip", "run", "capture", "original"] + + collector._set_layer_states(4) + assert modes() == ["skip", "skip", "skip", "run", "capture"] + finally: + collector._unpatch_all_layers() + + +def test_run_asserts_on_empty_cached_inputs(monkeypatch): + """A layer in 'run' mode with no cached inputs must raise AssertionError.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=2, dim=16) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + model.layers[0]._seq_calib.mode = "run" + model.layers[0]._seq_calib.cached_inputs = [] + + with pytest.raises(AssertionError, match="no cached inputs to replay"): + model(torch.randn(2, 16)) + finally: + collector._unpatch_all_layers() + + +def test_cleanup_removes_seq_calib_attr(monkeypatch): + """After unpatch, no layer should have the _seq_calib attribute.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=3, dim=16) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + for layer in model.layers: + collector.get_input_activations(layer, forward_loop) + collector._unpatch_all_layers() + + for i, layer in enumerate(model.layers): + assert not hasattr(layer, "_seq_calib"), f"Layer {i} still has _seq_calib after cleanup" + assert not hasattr(layer, "_original_forward"), ( + f"Layer {i} still has _original_forward after cleanup" + ) + assert not hasattr(model, "_original_forward") + + +def test_skip_output_meta_not_shared_across_heterogeneous_layers(monkeypatch): + """Each layer stores its own output_meta, supporting heterogeneous architectures.""" + + class _SmallBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8, bias=False) + + def forward(self, x): + return (self.linear(x), None, torch.zeros(1)) + + class _BigBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8, bias=False) + + def forward(self, x): + return (self.linear(x),) + + class _HeterogeneousModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([_SmallBlock(), _BigBlock(), _SmallBlock()]) + + def forward(self, x): + for layer in self.layers: + out = layer(x) + x = out[0] + return x + + _register_test_discoverer(monkeypatch) + model = _HeterogeneousModel() + data = [torch.randn(2, 8)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in model.layers: + collector.get_input_activations(layer, forward_loop) + + # After full calibration, layers 0 and 1 have been through 'run' and have output_meta + meta_0 = model.layers[0]._seq_calib.output_meta + meta_1 = model.layers[1]._seq_calib.output_meta + assert meta_0 is not None + assert meta_1 is not None + # SmallBlock returns 3-element tuple, BigBlock returns 1-element tuple + assert len(meta_0[1]) == 3 + assert len(meta_1[1]) == 1 + finally: + collector._unpatch_all_layers() From 96ccdadf094b0faeced5a3e202c5481d49eb8453 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 5 Mar 2026 23:01:42 +0000 Subject: [PATCH 08/11] removed stray changes Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 24 +--------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index fe14e364d..cd2c575b6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -127,32 +127,10 @@ def max_calibrate( forward_loop(model) finish_stats_collection(model) - # Sync amax across local experts within each rank (for SequentialMLP and HuggingFace MoE) + # Sync input_quantizer amax across local experts within each rank (for SequentialMLP) for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() - elif hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - - for name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - # Get the initial amax from max calibration - initial_amax = module._amax.clone().detach() - - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - - if is_nvfp4_static: - # Compute and set global_amax - global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if not distributed_sync: return From b25d5850b5cd1a9954b3f7e2c41de8f26b4e75c3 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:58:12 +0000 Subject: [PATCH 09/11] AI PR comments addressed Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils.py | 80 +++++++----- .../quantization/test_sequential_calibrate.py | 114 +++++++++++------- 2 files changed, 117 insertions(+), 77 deletions(-) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 3f1af580c..4aa5d98ba 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -15,7 +15,8 @@ """Quantization utilities.""" -from collections import namedtuple +import copy +from collections import deque, namedtuple from contextlib import ExitStack, contextmanager, nullcontext from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -825,7 +826,7 @@ class _LayerCalibState: mode: str = "original" name: str = "" - cached_inputs: list = field(default_factory=list) + cached_inputs: deque = field(default_factory=deque) collected_inputs: list = field(default_factory=list) output_meta: tuple | None = None @@ -921,7 +922,7 @@ def _zeros_from_meta(meta): return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) if tag == "list": return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] - return meta[1] + return copy.deepcopy(meta[1]) # ------------------------------------------------------------------ # Patched forward @@ -937,16 +938,19 @@ def _patched_forward(self, *args, **kwargs): info: _LayerCalibState = self._seq_calib if info.mode == "skip": - if info.output_meta is not None: - return LayerActivationCollector._zeros_from_meta(info.output_meta) - print_rank_0(f"Layer {info.name} is in 'skip' mode but has no output meta to return") - return args[0] if args else next(iter(kwargs.values())) + if info.output_meta is None: + raise RuntimeError( + f"Layer {info.name} is in 'skip' mode but has no output_meta. " + "This indicates a state-machine bug: the layer should have run " + "in 'run' mode (which sets output_meta) before transitioning to 'skip'." + ) + return LayerActivationCollector._zeros_from_meta(info.output_meta) if info.mode == "run": assert info.cached_inputs, ( f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." ) - real_args, real_kwargs = info.cached_inputs.pop(0) + real_args, real_kwargs = info.cached_inputs.popleft() output = self._original_forward(*real_args, **real_kwargs) info.output_meta = LayerActivationCollector._extract_output_meta(output) return output @@ -968,35 +972,43 @@ def _patch_all_layers(self): self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} module_to_name = {m: name for name, m in self.model.named_modules()} - for layer in self._decoder_layers: - layer._seq_calib = _LayerCalibState( - name=module_to_name.get(layer, type(layer).__name__), - ) - bind_forward_method(layer, self._patched_forward, "_original_forward") + try: + for layer in self._decoder_layers: + layer._seq_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) + bind_forward_method(layer, self._patched_forward, "_original_forward") + + def _early_stop_forward(self, *args, **kwargs): + try: + return self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None - def _early_stop_forward(self, *args, **kwargs): - try: - return self._original_forward(*args, **kwargs) - except _EarlyStopForwardError: - return None + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + except Exception: + self._cleanup_layers() + raise - bind_forward_method(self.model, _early_stop_forward, "_original_forward") self._patched = True + def _cleanup_layers(self): + """Best-effort cleanup of any patched layers and model forward.""" + if hasattr(self.model, "_original_forward"): + unpatch_forward_method(self.model, "_original_forward") + + if self._decoder_layers is not None: + for layer in self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) + def _unpatch_all_layers(self): """Restore original forwards and clean up state attributes. Called once.""" if not self._patched: return - assert self._decoder_layers is not None - - unpatch_forward_method(self.model, "_original_forward") - - for layer in self._decoder_layers: - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if hasattr(layer, self._LAYER_ATTR): - delattr(layer, self._LAYER_ATTR) - + self._cleanup_layers() self._patched = False # ------------------------------------------------------------------ @@ -1017,12 +1029,12 @@ def _set_layer_states(self, layer_idx: int): if layer_idx > 1: done = self._decoder_layers[layer_idx - 2]._seq_calib done.mode = "skip" - done.cached_inputs = [] + done.cached_inputs = deque() if layer_idx > 0: prev = self._decoder_layers[layer_idx - 1]._seq_calib prev.mode = "run" - prev.cached_inputs = prev.collected_inputs + prev.cached_inputs = deque(prev.collected_inputs) prev.collected_inputs = [] cur = self._decoder_layers[layer_idx]._seq_calib @@ -1052,7 +1064,13 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo Layers before the target are skipped or re-run (if just calibrated), the target layer captures its inputs, and an early-stop prevents unnecessary computation beyond the target. + + :meth:`_patch_all_layers` must be called before this method. """ + if not self._patched: + raise RuntimeError( + "get_input_activations() requires _patch_all_layers() to be called first." + ) layer_idx = self._layer_to_idx[layer] self._set_layer_states(layer_idx) self._log_layer_summary(layer_idx) diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_sequential_calibrate.py index 0cba247af..2fca85a75 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_sequential_calibrate.py @@ -15,6 +15,8 @@ """Unit tests for sequential_calibrate and LayerActivationCollector.""" +from collections import deque + import pytest import torch import torch.nn as nn @@ -98,7 +100,17 @@ def _run_forward(model, data): # LayerActivationCollector tests -def test_collector_collects_correct_number_of_inputs(): +def _register_test_discoverer(monkeypatch): + """Register a simple discoverer that finds model.layers on any model.""" + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + +def test_collector_collects_correct_number_of_inputs(monkeypatch): + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) @@ -108,12 +120,17 @@ def forward_loop(m): for d in data: m(d) - inputs = collector.get_input_activations(model.layers[0], forward_loop) - assert len(inputs) == 3 + collector._patch_all_layers() + try: + inputs = collector.get_input_activations(model.layers[0], forward_loop) + assert len(inputs) == 3 + finally: + collector._unpatch_all_layers() -def test_collector_activations_match_expected(): +def test_collector_activations_match_expected(monkeypatch): """First layer should receive the raw input data.""" + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) @@ -123,13 +140,18 @@ def forward_loop(m): for d in data: m(d) - inputs = collector.get_input_activations(model.layers[0], forward_loop) - args, kwargs = inputs[0] - assert torch.allclose(args[0], data[0]) + collector._patch_all_layers() + try: + inputs = collector.get_input_activations(model.layers[0], forward_loop) + args, kwargs = inputs[0] + assert torch.allclose(args[0], data[0]) + finally: + collector._unpatch_all_layers() -def test_collector_second_layer_receives_transformed_input(): +def test_collector_second_layer_receives_transformed_input(monkeypatch): """Second layer should receive first layer's output, not raw input.""" + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) @@ -139,53 +161,55 @@ def forward_loop(m): m(x) expected = model.layers[0](x) - inputs = collector.get_input_activations(model.layers[1], forward_loop) - args, _ = inputs[0] - assert torch.allclose(args[0], expected) + collector._patch_all_layers() + try: + collector.get_input_activations(model.layers[0], forward_loop) + inputs = collector.get_input_activations(model.layers[1], forward_loop) + args, _ = inputs[0] + assert torch.allclose(args[0], expected) + finally: + collector._unpatch_all_layers() -def test_collector_forward_is_restored_after_collection(): + +def test_collector_forward_is_restored_after_collection(monkeypatch): + _register_test_discoverer(monkeypatch) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) def forward_loop(m): m(torch.randn(2, 8)) + collector._patch_all_layers() collector.get_input_activations(model.layers[0], forward_loop) + collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "inputs") + assert not hasattr(model.layers[0], "_seq_calib") assert not hasattr(model.layers[0], "_original_forward") -def test_collector_cleanup_on_forward_loop_error(): +def test_collector_cleanup_on_forward_loop_error(monkeypatch): """Patching should be cleaned up even if forward_loop raises.""" + _register_test_discoverer(monkeypatch) model = _SimpleTwoLayerModel(dim=8) collector = LayerActivationCollector(model) def bad_forward_loop(m): raise RuntimeError("intentional error") - with pytest.raises(RuntimeError, match="intentional error"): - collector.get_input_activations(model.layers[0], bad_forward_loop) + collector._patch_all_layers() + try: + with pytest.raises(RuntimeError, match="intentional error"): + collector.get_input_activations(model.layers[0], bad_forward_loop) + finally: + collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "inputs") + assert not hasattr(model.layers[0], "_seq_calib") # sequential_calibrate tests - - -def test_seq_calib_raises_on_none_forward_loop(): - model, _ = _make_model_and_data(n_layers=2) - with pytest.raises(ValueError, match="forward_loop must not be None"): - sequential_calibrate( - model, - forward_loop=None, - calib_func=lambda *a, **kw: None, - ) - - def test_seq_calib_raises_on_unrecognized_model(): model = _FlatMLP() with pytest.raises(ValueError, match="Could not find transformer layers"): @@ -196,7 +220,8 @@ def test_seq_calib_raises_on_unrecognized_model(): ) -def test_seq_calib_func_called_per_layer(): +def test_seq_calib_func_called_per_layer(monkeypatch): + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=4) call_count = [0] @@ -212,7 +237,8 @@ def counting_calib(layer, forward_loop, **kwargs): assert call_count[0] == 4 -def test_seq_calib_func_receives_correct_layer(): +def test_seq_calib_func_receives_correct_layer(monkeypatch): + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) called_layers = [] @@ -229,7 +255,8 @@ def track_layers(layer, forward_loop, **kwargs): assert called_layers[i] is layer -def test_seq_calib_kwargs_forwarded(): +def test_seq_calib_kwargs_forwarded(monkeypatch): + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) received_kwargs = [] @@ -250,8 +277,9 @@ def capture_kwargs(layer, forward_loop, **kwargs): assert kw["method"] == "max" -def test_seq_calib_layer_forward_loop_runs_all_batches(): +def test_seq_calib_layer_forward_loop_runs_all_batches(monkeypatch): """The per-layer forward loop passed to calib_func should replay all batches.""" + _register_test_discoverer(monkeypatch) n_batches = 5 model, data = _make_model_and_data(n_layers=2, n_batches=n_batches) batch_counts = [] @@ -279,8 +307,9 @@ def counting_forward(*args, **kw): assert count == n_batches -def test_seq_calib_does_not_alter_weights(): +def test_seq_calib_does_not_alter_weights(monkeypatch): """sequential_calibrate itself should not modify model weights.""" + _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) weights_before = {n: p.clone() for n, p in model.named_parameters()} @@ -294,8 +323,9 @@ def test_seq_calib_does_not_alter_weights(): assert torch.equal(p, weights_before[n]), f"Weight {n} was modified" -def test_seq_calib_activations_update_across_layers(): +def test_seq_calib_activations_update_across_layers(monkeypatch): """Subsequent layers should see activations transformed by prior layers.""" + _register_test_discoverer(monkeypatch) torch.manual_seed(0) model = _SimpleTransformerModel(n_layers=2, dim=16) tokens = [torch.randint(0, 32, (2, 4))] @@ -328,8 +358,9 @@ def capture_forward(*args, **kw): ) -def test_seq_calib_empty_forward_loop(): +def test_seq_calib_empty_forward_loop(monkeypatch): """If forward_loop feeds no data, calib_func still gets called with an empty replay.""" + _register_test_discoverer(monkeypatch) model = _SimpleTransformerModel(n_layers=2, dim=16) replay_counts = [] @@ -403,15 +434,6 @@ def forward(self, x): return x -def _register_test_discoverer(monkeypatch): - """Register a simple discoverer that finds model.layers on any model.""" - monkeypatch.setattr( - LayerActivationCollector, - "_decoder_layer_support", - [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], - ) - - def test_skip_output_preserves_tuple_structure(monkeypatch): """Skip layers must return a tuple when the real layer returns a tuple. @@ -572,7 +594,7 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector._patch_all_layers() try: model.layers[0]._seq_calib.mode = "run" - model.layers[0]._seq_calib.cached_inputs = [] + model.layers[0]._seq_calib.cached_inputs = deque() with pytest.raises(AssertionError, match="no cached inputs to replay"): model(torch.randn(2, 16)) From 50e31f0daa014dd6a1c478735b48eac749869858 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:19:17 +0000 Subject: [PATCH 10/11] moved LayerActivationCollector to new file Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- .../quantization/activation_collector.py | 308 ++++++++++++++++++ modelopt/torch/quantization/model_calib.py | 2 +- .../torch/quantization/plugins/huggingface.py | 3 +- modelopt/torch/quantization/utils.py | 280 +--------------- .../quantization/plugins/test_huggingface.py | 2 +- tests/unit/torch/quantization/test_calib.py | 4 +- .../quantization/test_sequential_calibrate.py | 2 +- tests/unit/torch/quantization/test_utils.py | 2 +- 8 files changed, 318 insertions(+), 285 deletions(-) create mode 100644 modelopt/torch/quantization/activation_collector.py diff --git a/modelopt/torch/quantization/activation_collector.py b/modelopt/torch/quantization/activation_collector.py new file mode 100644 index 000000000..dbeeb57f7 --- /dev/null +++ b/modelopt/torch/quantization/activation_collector.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sequential calibration layer patching and activation capture. + +This module provides :class:`LayerActivationCollector`, a stateful helper that +patches decoder layers with a skip / run / capture strategy for efficient +layer-by-layer calibration. +""" + +import copy +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.nn as nn + +from modelopt.torch.opt.searcher import ForwardLoop +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method + + +class _EarlyStopForwardError(Exception): + """Raised to halt the forward pass after capturing layer inputs.""" + + +@dataclass +class _LayerCalibState: + """Mutable per-layer state used during sequential calibration. + + Attached to each decoder layer as ``_seq_calib`` and accessed by the + patched forward to decide skip / run / capture / original behaviour. + """ + + mode: str = "original" + name: str = "" + cached_inputs: deque = field(default_factory=deque) + collected_inputs: list = field(default_factory=list) + output_meta: tuple | None = None + + +class LayerActivationCollector: + """Collects layer activations for sequential (layer-by-layer) calibration. + + Each decoder layer is patched with a unified forward whose behaviour is + governed by a per-layer :class:`_LayerCalibState`: + + * **skip** — return a zero-filled dummy whose shape and type match the + layer's real output (reconstructed from lightweight metadata). No + computation is performed. The correctly shaped dummy ensures un-patched + inter-layer operations in the parent forward (e.g. LayerNorm, tuple + unpacking) do not raise shape or type errors. + * **run** — replay previously captured inputs through the original forward, + ignoring whatever the parent passes in. Only the just-calibrated layer + uses this mode, so its output reflects updated weights. + * **capture** — record ``(args, kwargs)`` and raise + ``_EarlyStopForwardError`` to halt the forward pass early. + * **original** — call the original forward unchanged. + + Because the *run* layer discards upstream values, skip-layer outputs are + never consumed for real computation. + """ + + _decoder_layer_support: list[tuple[Any, Any]] = [] + _LAYER_ATTR = "_seq_calib" + + def __init__(self, model: nn.Module): + """Initialize the collector for the given model.""" + self.model = model + self._decoder_layers: nn.ModuleList | None = None + self._layer_to_idx: dict[nn.Module, int] = {} + self._patched = False + + # ------------------------------------------------------------------ + # Decoder-layer discovery + # ------------------------------------------------------------------ + + @staticmethod + def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + """Return decoder layers supported by sequential calibration.""" + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: + if not is_supported(model): + continue + decoder_layers = discoverer(model) + if decoder_layers is not None: + return decoder_layers + return None + + @staticmethod + def is_supported(model: nn.Module) -> bool: + """Whether the model supports decoder-layer sequential calibration.""" + return LayerActivationCollector.get_decoder_layers(model) is not None + + @classmethod + def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): + """Register a (predicate, discoverer) pair for decoder-layer detection.""" + entry = (is_supported, discoverer) + if entry not in cls._decoder_layer_support: + cls._decoder_layer_support.append(entry) + + # ------------------------------------------------------------------ + # Output metadata helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_output_meta(output): + """Extract lightweight (shape, dtype, device) metadata from a layer output. + + Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). + The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a + zero-filled output with identical shape and type. + """ + if isinstance(output, torch.Tensor): + return ("tensor", output.shape, output.dtype, output.device) + if isinstance(output, tuple): + return ( + "tuple", + tuple(LayerActivationCollector._extract_output_meta(o) for o in output), + ) + if isinstance(output, list): + return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) + return ("other", output) + + @staticmethod + def _zeros_from_meta(meta): + """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, device = meta + return torch.zeros(shape, dtype=dtype, device=device) + if tag == "tuple": + return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) + if tag == "list": + return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] + return copy.deepcopy(meta[1]) + + # ------------------------------------------------------------------ + # Patched forward + # ------------------------------------------------------------------ + + @staticmethod + def _patched_forward(self, *args, **kwargs): + """Unified forward bound to every decoder layer during sequential calibration. + + ``self`` here is the decoder layer module (bound via ``bind_forward_method``). + All per-layer state is accessed through ``self._seq_calib``. + """ + info: _LayerCalibState = self._seq_calib + + if info.mode == "skip": + if info.output_meta is None: + raise RuntimeError( + f"Layer {info.name} is in 'skip' mode but has no output_meta. " + "This indicates a state-machine bug: the layer should have run " + "in 'run' mode (which sets output_meta) before transitioning to 'skip'." + ) + return LayerActivationCollector._zeros_from_meta(info.output_meta) + + if info.mode == "run": + assert info.cached_inputs, ( + f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." + ) + real_args, real_kwargs = info.cached_inputs.popleft() + output = self._original_forward(*real_args, **real_kwargs) + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output + + if info.mode == "capture": + info.collected_inputs.append((args, kwargs)) + raise _EarlyStopForwardError() + + return self._original_forward(*args, **kwargs) + + # ------------------------------------------------------------------ + # Patch / unpatch lifecycle + # ------------------------------------------------------------------ + + def _patch_all_layers(self): + """Bind the unified forward to every decoder layer and the model. Called once.""" + self._decoder_layers = self.get_decoder_layers(self.model) + assert self._decoder_layers is not None + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} + module_to_name = {m: name for name, m in self.model.named_modules()} + + try: + for layer in self._decoder_layers: + layer._seq_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) + bind_forward_method(layer, self._patched_forward, "_original_forward") + + def _early_stop_forward(self, *args, **kwargs): + try: + return self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + except Exception: + self._cleanup_layers() + raise + + self._patched = True + + def _cleanup_layers(self): + """Best-effort cleanup of any patched layers and model forward.""" + if hasattr(self.model, "_original_forward"): + unpatch_forward_method(self.model, "_original_forward") + + if self._decoder_layers is not None: + for layer in self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) + + def _unpatch_all_layers(self): + """Restore original forwards and clean up state attributes. Called once.""" + if not self._patched: + return + self._cleanup_layers() + self._patched = False + + # ------------------------------------------------------------------ + # Per-iteration state management + # ------------------------------------------------------------------ + + def _set_layer_states(self, layer_idx: int): + """Transition layer modes for the next calibration step. + + When calibrating layer *i*, three transitions happen: + + * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). + * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). + * Layer ``i`` → **capture** (record inputs, then early-stop). + """ + assert self._decoder_layers is not None + + if layer_idx > 1: + done = self._decoder_layers[layer_idx - 2]._seq_calib + done.mode = "skip" + done.cached_inputs = deque() + + if layer_idx > 0: + prev = self._decoder_layers[layer_idx - 1]._seq_calib + prev.mode = "run" + prev.cached_inputs = deque(prev.collected_inputs) + prev.collected_inputs = [] + + cur = self._decoder_layers[layer_idx]._seq_calib + cur.mode = "capture" + cur.collected_inputs = [] + + def _log_layer_summary(self, layer_idx: int): + """Log a one-line summary of layer modes for the current calibration step.""" + assert self._decoder_layers is not None + n = len(self._decoder_layers) + groups: dict[str, list[int]] = {} + for i, layer in enumerate(self._decoder_layers): + mode = layer._seq_calib.mode + if mode in ("skip", "run", "capture"): + groups.setdefault(mode, []).append(i) + parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] + print_rank_0(f"Calibrating layer {layer_idx}/{n} | {' | '.join(parts)}") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + """Collect input activations for *layer* by running a full model forward. + + Layers before the target are skipped or re-run (if just calibrated), the + target layer captures its inputs, and an early-stop prevents unnecessary + computation beyond the target. + + :meth:`_patch_all_layers` must be called before this method. + """ + if not self._patched: + raise RuntimeError( + "get_input_activations() requires _patch_all_layers() to be called first." + ) + layer_idx = self._layer_to_idx[layer] + self._set_layer_states(layer_idx) + self._log_layer_summary(layer_idx) + forward_loop(self.model) + + info = layer._seq_calib + inputs = list(info.collected_inputs) + # After capture, set to original so calib_func can call the layer's + # real forward directly. The layer will transition to run → skip + # in subsequent iterations via _set_layer_states. + info.mode = "original" + return inputs diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index cd2c575b6..59fde1138 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -28,7 +28,7 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.quantization.utils import LayerActivationCollector +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index c08fb7258..a14119938 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -56,7 +56,8 @@ else: weight_dequant = None -from ..utils import LayerActivationCollector, replace_function +from ..activation_collector import LayerActivationCollector +from ..utils import replace_function from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 4aa5d98ba..d5cdd8a47 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -15,10 +15,8 @@ """Quantization utilities.""" -import copy -from collections import deque, namedtuple +from collections import namedtuple from contextlib import ExitStack, contextmanager, nullcontext -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch @@ -28,9 +26,8 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate -from modelopt.torch.opt.searcher import ForwardLoop +from modelopt.torch.quantization.activation_collector import LayerActivationCollector # noqa: F401 from modelopt.torch.utils import get_unwrapped_name, print_rank_0 -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method if TYPE_CHECKING: from collections.abc import Generator @@ -810,276 +807,3 @@ def update_quant_cfg_with_kv_cache_quant( quant_cfg["algorithm"] = "max" print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") return quant_cfg - - -class _EarlyStopForwardError(Exception): - """Raised to halt the forward pass after capturing layer inputs.""" - - -@dataclass -class _LayerCalibState: - """Mutable per-layer state used during sequential calibration. - - Attached to each decoder layer as ``_seq_calib`` and accessed by the - patched forward to decide skip / run / capture / original behaviour. - """ - - mode: str = "original" - name: str = "" - cached_inputs: deque = field(default_factory=deque) - collected_inputs: list = field(default_factory=list) - output_meta: tuple | None = None - - -class LayerActivationCollector: - """Collects layer activations for sequential (layer-by-layer) calibration. - - Each decoder layer is patched with a unified forward whose behaviour is - governed by a per-layer :class:`_LayerCalibState`: - - * **skip** — return a zero-filled dummy whose shape and type match the - layer's real output (reconstructed from lightweight metadata). No - computation is performed. The correctly shaped dummy ensures un-patched - inter-layer operations in the parent forward (e.g. LayerNorm, tuple - unpacking) do not raise shape or type errors. - * **run** — replay previously captured inputs through the original forward, - ignoring whatever the parent passes in. Only the just-calibrated layer - uses this mode, so its output reflects updated weights. - * **capture** — record ``(args, kwargs)`` and raise - ``_EarlyStopForwardError`` to halt the forward pass early. - * **original** — call the original forward unchanged. - - Because the *run* layer discards upstream values, skip-layer outputs are - never consumed for real computation. - """ - - _decoder_layer_support: list[tuple[Any, Any]] = [] - _LAYER_ATTR = "_seq_calib" - - def __init__(self, model: nn.Module): - self.model = model - self._decoder_layers: nn.ModuleList | None = None - self._layer_to_idx: dict[nn.Module, int] = {} - self._patched = False - - # ------------------------------------------------------------------ - # Decoder-layer discovery - # ------------------------------------------------------------------ - - @staticmethod - def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: - """Return decoder layers supported by sequential calibration.""" - for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: - if not is_supported(model): - continue - decoder_layers = discoverer(model) - if decoder_layers is not None: - return decoder_layers - return None - - @staticmethod - def is_supported(model: nn.Module) -> bool: - """Whether the model supports decoder-layer sequential calibration.""" - return LayerActivationCollector.get_decoder_layers(model) is not None - - @classmethod - def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): - entry = (is_supported, discoverer) - if entry not in cls._decoder_layer_support: - cls._decoder_layer_support.append(entry) - - # ------------------------------------------------------------------ - # Output metadata helpers - # ------------------------------------------------------------------ - - @staticmethod - def _extract_output_meta(output): - """Extract lightweight (shape, dtype, device) metadata from a layer output. - - Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). - The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a - zero-filled output with identical shape and type. - """ - if isinstance(output, torch.Tensor): - return ("tensor", output.shape, output.dtype, output.device) - if isinstance(output, tuple): - return ( - "tuple", - tuple(LayerActivationCollector._extract_output_meta(o) for o in output), - ) - if isinstance(output, list): - return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) - return ("other", output) - - @staticmethod - def _zeros_from_meta(meta): - """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" - tag = meta[0] - if tag == "tensor": - _, shape, dtype, device = meta - return torch.zeros(shape, dtype=dtype, device=device) - if tag == "tuple": - return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) - if tag == "list": - return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] - return copy.deepcopy(meta[1]) - - # ------------------------------------------------------------------ - # Patched forward - # ------------------------------------------------------------------ - - @staticmethod - def _patched_forward(self, *args, **kwargs): - """Unified forward bound to every decoder layer during sequential calibration. - - ``self`` here is the decoder layer module (bound via ``bind_forward_method``). - All per-layer state is accessed through ``self._seq_calib``. - """ - info: _LayerCalibState = self._seq_calib - - if info.mode == "skip": - if info.output_meta is None: - raise RuntimeError( - f"Layer {info.name} is in 'skip' mode but has no output_meta. " - "This indicates a state-machine bug: the layer should have run " - "in 'run' mode (which sets output_meta) before transitioning to 'skip'." - ) - return LayerActivationCollector._zeros_from_meta(info.output_meta) - - if info.mode == "run": - assert info.cached_inputs, ( - f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." - ) - real_args, real_kwargs = info.cached_inputs.popleft() - output = self._original_forward(*real_args, **real_kwargs) - info.output_meta = LayerActivationCollector._extract_output_meta(output) - return output - - if info.mode == "capture": - info.collected_inputs.append((args, kwargs)) - raise _EarlyStopForwardError() - - return self._original_forward(*args, **kwargs) - - # ------------------------------------------------------------------ - # Patch / unpatch lifecycle - # ------------------------------------------------------------------ - - def _patch_all_layers(self): - """Bind the unified forward to every decoder layer and the model. Called once.""" - self._decoder_layers = self.get_decoder_layers(self.model) - assert self._decoder_layers is not None - self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} - module_to_name = {m: name for name, m in self.model.named_modules()} - - try: - for layer in self._decoder_layers: - layer._seq_calib = _LayerCalibState( - name=module_to_name.get(layer, type(layer).__name__), - ) - bind_forward_method(layer, self._patched_forward, "_original_forward") - - def _early_stop_forward(self, *args, **kwargs): - try: - return self._original_forward(*args, **kwargs) - except _EarlyStopForwardError: - return None - - bind_forward_method(self.model, _early_stop_forward, "_original_forward") - except Exception: - self._cleanup_layers() - raise - - self._patched = True - - def _cleanup_layers(self): - """Best-effort cleanup of any patched layers and model forward.""" - if hasattr(self.model, "_original_forward"): - unpatch_forward_method(self.model, "_original_forward") - - if self._decoder_layers is not None: - for layer in self._decoder_layers: - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if hasattr(layer, self._LAYER_ATTR): - delattr(layer, self._LAYER_ATTR) - - def _unpatch_all_layers(self): - """Restore original forwards and clean up state attributes. Called once.""" - if not self._patched: - return - self._cleanup_layers() - self._patched = False - - # ------------------------------------------------------------------ - # Per-iteration state management - # ------------------------------------------------------------------ - - def _set_layer_states(self, layer_idx: int): - """Transition layer modes for the next calibration step. - - When calibrating layer *i*, three transitions happen: - - * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). - * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). - * Layer ``i`` → **capture** (record inputs, then early-stop). - """ - assert self._decoder_layers is not None - - if layer_idx > 1: - done = self._decoder_layers[layer_idx - 2]._seq_calib - done.mode = "skip" - done.cached_inputs = deque() - - if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1]._seq_calib - prev.mode = "run" - prev.cached_inputs = deque(prev.collected_inputs) - prev.collected_inputs = [] - - cur = self._decoder_layers[layer_idx]._seq_calib - cur.mode = "capture" - cur.collected_inputs = [] - - def _log_layer_summary(self, layer_idx: int): - """Log a one-line summary of layer modes for the current calibration step.""" - assert self._decoder_layers is not None - n = len(self._decoder_layers) - groups: dict[str, list[int]] = {} - for i, layer in enumerate(self._decoder_layers): - mode = layer._seq_calib.mode - if mode in ("skip", "run", "capture"): - groups.setdefault(mode, []).append(i) - parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] - print_rank_0(f"Calibrating layer {layer_idx}/{n} | {' | '.join(parts)}") - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - """Collect input activations for *layer* by running a full model forward. - - Layers before the target are skipped or re-run (if just calibrated), the - target layer captures its inputs, and an early-stop prevents unnecessary - computation beyond the target. - - :meth:`_patch_all_layers` must be called before this method. - """ - if not self._patched: - raise RuntimeError( - "get_input_activations() requires _patch_all_layers() to be called first." - ) - layer_idx = self._layer_to_idx[layer] - self._set_layer_states(layer_idx) - self._log_layer_summary(layer_idx) - forward_loop(self.model) - - info = layer._seq_calib - inputs = list(info.collected_inputs) - # After capture, set to original so calib_func can call the layer's - # real forward directly. The layer will transition to run → skip - # in subsequent iterations via _set_layer_states. - info.mode = "original" - return inputs diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 043d6d6aa..1b6a4a84c 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -30,12 +30,12 @@ ) import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.quantization.nn import QuantLinear, QuantModuleRegistry from modelopt.torch.quantization.plugins.huggingface import ( get_homogeneous_hf_decoder_layers, is_homogenous_hf_model, ) -from modelopt.torch.quantization.utils import LayerActivationCollector pytest.importorskip("transformers") diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index 98fc0a915..c868b8bfb 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -402,7 +402,7 @@ def forward(self, x): def test_sequential_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): - from modelopt.torch.quantization.utils import LayerActivationCollector + from modelopt.torch.quantization.activation_collector import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float, bias: float): @@ -484,7 +484,7 @@ def _pre_hook(_module, args): def test_sequential_calibrate_handles_inter_layer_logic(monkeypatch): """Verify that parent-level inter-layer logic (e.g. mask selection) works correctly.""" - from modelopt.torch.quantization.utils import LayerActivationCollector + from modelopt.torch.quantization.activation_collector import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float): diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_sequential_calibrate.py index 2fca85a75..ab099b365 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_sequential_calibrate.py @@ -21,8 +21,8 @@ import torch import torch.nn as nn +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils import LayerActivationCollector class _DecoderBlock(nn.Module): diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index 6e31bdad4..f49c575b7 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -16,8 +16,8 @@ import pytest import torch +from modelopt.torch.quantization.activation_collector import LayerActivationCollector from modelopt.torch.quantization.utils import ( - LayerActivationCollector, convert_quantization_axis_to_reduce_axis, reduce_block_amax, ) From 5cf716ab1851e3e01c426fd454eb3397d2a3579c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 11 Mar 2026 20:22:24 +0000 Subject: [PATCH 11/11] Claude review addressed Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- .../quantization/activation_collector.py | 40 +++++++++++--- modelopt/torch/quantization/model_calib.py | 10 ++-- .../torch/quantization/plugins/huggingface.py | 7 ++- modelopt/torch/quantization/utils.py | 4 +- .../quantization/plugins/test_huggingface.py | 12 ++--- .../quantization/test_sequential_calibrate.py | 52 +++++++++++++++++++ 6 files changed, 102 insertions(+), 23 deletions(-) diff --git a/modelopt/torch/quantization/activation_collector.py b/modelopt/torch/quantization/activation_collector.py index dbeeb57f7..0d060296b 100644 --- a/modelopt/torch/quantization/activation_collector.py +++ b/modelopt/torch/quantization/activation_collector.py @@ -74,6 +74,10 @@ class LayerActivationCollector: never consumed for real computation. """ + # Global registry of (predicate, discoverer) pairs. Populated at import time + # by plugins (e.g. huggingface.py). Order matters: the first matching entry wins, + # so more specific predicates (e.g. Nemotron-H) must be registered before + # generic ones (e.g. homogeneous HF models). _decoder_layer_support: list[tuple[Any, Any]] = [] _LAYER_ATTR = "_seq_calib" @@ -188,10 +192,19 @@ def _patched_forward(self, *args, **kwargs): # Patch / unpatch lifecycle # ------------------------------------------------------------------ - def _patch_all_layers(self): - """Bind the unified forward to every decoder layer and the model. Called once.""" - self._decoder_layers = self.get_decoder_layers(self.model) + def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + """Bind the unified forward to every decoder layer and the model. Called once. + + Args: + decoder_layers: Pre-resolved decoder layers. If *None*, layers are + discovered via :meth:`get_decoder_layers`. + """ + if decoder_layers is not None: + self._decoder_layers = decoder_layers + else: + self._decoder_layers = self.get_decoder_layers(self.model) assert self._decoder_layers is not None + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} module_to_name = {m: name for name, m in self.model.named_modules()} @@ -202,9 +215,9 @@ def _patch_all_layers(self): ) bind_forward_method(layer, self._patched_forward, "_original_forward") - def _early_stop_forward(self, *args, **kwargs): + def _early_stop_forward(module_self, *args, **kwargs): try: - return self._original_forward(*args, **kwargs) + return module_self._original_forward(*args, **kwargs) except _EarlyStopForwardError: return None @@ -252,7 +265,9 @@ def _set_layer_states(self, layer_idx: int): if layer_idx > 1: done = self._decoder_layers[layer_idx - 2]._seq_calib done.mode = "skip" - done.cached_inputs = deque() + # output_meta is intentionally kept: skip mode needs it to produce + # correctly shaped zero-filled outputs for the parent forward. + done.cached_inputs.clear() if layer_idx > 0: prev = self._decoder_layers[layer_idx - 1]._seq_calib @@ -289,6 +304,10 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo computation beyond the target. :meth:`_patch_all_layers` must be called before this method. + + Note: the model forward returns ``None`` for every batch during capture + (because ``_EarlyStopForwardError`` short-circuits the forward pass). + Callers should not rely on the model's return value within *forward_loop*. """ if not self._patched: raise RuntimeError( @@ -297,9 +316,16 @@ def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoo layer_idx = self._layer_to_idx[layer] self._set_layer_states(layer_idx) self._log_layer_summary(layer_idx) - forward_loop(self.model) info = layer._seq_calib + try: + forward_loop(self.model) + except Exception: + # Reset the current layer so subsequent calls don't see stale state. + info.mode = "original" + info.collected_inputs = [] + raise + inputs = list(info.collected_inputs) # After capture, set to original so calib_func can call the layer's # real forward directly. The layer will transition to run → skip diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 59fde1138..513b66fdb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1842,21 +1842,17 @@ def sequential_calibrate( skip / run / capture strategy so that inter-layer logic in parent modules (e.g. mask construction) executes naturally without model-specific hooks. """ - if not LayerActivationCollector.is_supported(model): + transformer_layers = LayerActivationCollector.get_decoder_layers(model) + if transformer_layers is None or len(transformer_layers) == 0: raise ValueError( "Could not find transformer layers in model. " "Sequential calibration requires a model with identifiable transformer layers." ) - transformer_layers = LayerActivationCollector.get_decoder_layers(model) - assert transformer_layers is not None print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") - if len(transformer_layers) == 0: - return input_getter = LayerActivationCollector(model) - # Patch all transformer layers with state aware module forward - input_getter._patch_all_layers() + input_getter._patch_all_layers(decoder_layers=transformer_layers) try: for layer_idx, layer in enumerate(transformer_layers): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a14119938..4f1e65229 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -1196,7 +1196,7 @@ def get_nemotron_h_decoder_layers(model: nn.Module) -> nn.ModuleList | None: return None -def is_homogenous_hf_model(model: nn.Module) -> bool: +def is_homogeneous_hf_model(model: nn.Module) -> bool: if is_nemotron_h_model(model): return False decoder_layers = get_homogeneous_hf_decoder_layers(model) @@ -1265,12 +1265,15 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): _is_param_grad_enabled_for_auto_quantize, ) +# Order matters: more specific predicates must be registered first because +# the first matching entry wins. Nemotron-H must precede the generic +# homogeneous HF discoverer (which explicitly rejects Nemotron-H). LayerActivationCollector.register_decoder_layer_support( is_nemotron_h_model, get_nemotron_h_decoder_layers ) LayerActivationCollector.register_decoder_layer_support( - is_homogenous_hf_model, get_homogeneous_hf_decoder_layers + is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers ) CUSTOM_MODEL_PLUGINS.update( diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index d5cdd8a47..9216a89a1 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -26,7 +26,9 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate -from modelopt.torch.quantization.activation_collector import LayerActivationCollector # noqa: F401 +from modelopt.torch.quantization.activation_collector import ( + LayerActivationCollector, # noqa: F401 # re-export +) from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 1b6a4a84c..67e82c629 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -34,7 +34,7 @@ from modelopt.torch.quantization.nn import QuantLinear, QuantModuleRegistry from modelopt.torch.quantization.plugins.huggingface import ( get_homogeneous_hf_decoder_layers, - is_homogenous_hf_model, + is_homogeneous_hf_model, ) pytest.importorskip("transformers") @@ -207,20 +207,20 @@ def test_quantized_transformers_save_restore(tmp_path, model_cls, quant_config): tf_modelopt_state_and_output_tester(model_ref, model_test) -def test_is_homogenous_hf_model_llama(): +def test_is_homogeneous_hf_model_llama(): model = get_tiny_llama() - assert is_homogenous_hf_model(model) + assert is_homogeneous_hf_model(model) -def test_is_homogenous_hf_model_gpt_oss(): +def test_is_homogeneous_hf_model_gpt_oss(): model = get_tiny_gpt_oss(num_hidden_layers=1) - assert is_homogenous_hf_model(model) + assert is_homogeneous_hf_model(model) def test_hf_decoder_discoverer_registration_path(): model = get_tiny_llama() assert any( - is_supported is is_homogenous_hf_model and discoverer is get_homogeneous_hf_decoder_layers + is_supported is is_homogeneous_hf_model and discoverer is get_homogeneous_hf_decoder_layers for is_supported, discoverer in LayerActivationCollector._decoder_layer_support ) assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers( diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_sequential_calibrate.py index ab099b365..ad9a0e58f 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_sequential_calibrate.py @@ -680,3 +680,55 @@ def forward_loop(m): assert len(meta_1[1]) == 1 finally: collector._unpatch_all_layers() + + +def test_run_layer_reflects_weight_updates(monkeypatch): + """After calib_func modifies weights, the next layer should see updated activations.""" + _register_test_discoverer(monkeypatch) + torch.manual_seed(0) + dim = 8 + + class _ScaleLayer(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return x * self.weight + + class _TwoScaleModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([_ScaleLayer(), _ScaleLayer()]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + model = _TwoScaleModel() + x = torch.randn(2, dim) + + activations_before_weight_update = model.layers[0](x).clone() + + def forward_loop(m): + m(x) + + def weight_doubling_calib(layer, layer_forward_loop, **kwargs): + with torch.no_grad(): + layer.weight.mul_(2.0) + layer_forward_loop(layer) + + sequential_calibrate( + model, + forward_loop=forward_loop, + calib_func=weight_doubling_calib, + ) + + # Layer 0's weight was doubled by calib_func. When collecting inputs + # for layer 1, the run-mode replay of layer 0 should use the updated + # weight, so layer 1 should have received 2x the original activations. + expected = activations_before_weight_update * 2.0 + # Verify by running model.layers[0] with its updated weights + actual = model.layers[0](x) + assert torch.allclose(actual, expected)