Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 334 additions & 0 deletions modelopt/torch/quantization/activation_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequential calibration layer patching and activation capture.

This module provides :class:`LayerActivationCollector`, a stateful helper that
patches decoder layers with a skip / run / capture strategy for efficient
layer-by-layer calibration.
"""

import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.nn as nn

from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method


class _EarlyStopForwardError(Exception):
"""Raised to halt the forward pass after capturing layer inputs."""


@dataclass
class _LayerCalibState:
"""Mutable per-layer state used during sequential calibration.

Attached to each decoder layer as ``_seq_calib`` and accessed by the
patched forward to decide skip / run / capture / original behaviour.
"""

mode: str = "original"
name: str = ""
cached_inputs: deque = field(default_factory=deque)
collected_inputs: list = field(default_factory=list)
output_meta: tuple | None = None
Comment on lines +50 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

output_meta only tracks the last batch.

run overwrites a single output_meta on every replayed batch, and skip reuses that one shape for every later batch. With a normal drop_last=False loader or variable sequence lengths, a later pass can synthesize dummy outputs with the wrong dimensions for earlier batches. This needs per-batch metadata that can be replayed in order on each future pass, not one shared slot per layer.

Also applies to: 163-179

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/activation_collector.py` around lines 50 - 52,
the current implementation stores a single output_meta per layer that is
overwritten by the "run" branch of LayerActivationCollector._patched_forward and
reused by its "skip" branch, causing replayed batches to get incorrect shapes;
change output_meta from a single tuple to a per-batch sequence (e.g., a deque or
list) named output_meta_list or output_meta: deque, append the metadata for each
replayed batch inside the "run" branch (use the same indexing order as
cached_inputs/collected_inputs), update the "skip" branch and any replay logic
to consume/peek the corresponding per-batch metadata in FIFO order (maintaining
alignment with cached_inputs entries) and remove or rotate entries consistently
after replays so each replayed batch uses its original per-batch output_meta
rather than a shared slot (also apply the same change to the other occurrence
referenced around lines 163-179).



class LayerActivationCollector:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question: will this work in tensor-, expert-, and pipeline-parallel (TP/EP/PP) scenarios, which we usually use for Megatron-Core (mcore)?

"""Collects layer activations for sequential (layer-by-layer) calibration.

Each decoder layer is patched with a unified forward whose behaviour is
governed by a per-layer :class:`_LayerCalibState`:

* **skip** — return a zero-filled dummy whose shape and type match the
layer's real output (reconstructed from lightweight metadata). No
computation is performed. The correctly shaped dummy ensures un-patched
inter-layer operations in the parent forward (e.g. LayerNorm, tuple
unpacking) do not raise shape or type errors.
* **run** — replay previously captured inputs through the original forward,
ignoring whatever the parent passes in. Only the just-calibrated layer
uses this mode, so its output reflects updated weights.
* **capture** — record ``(args, kwargs)`` and raise
``_EarlyStopForwardError`` to halt the forward pass early.
* **original** — call the original forward unchanged.

Because the *run* layer discards upstream values, skip-layer outputs are
never consumed for real computation.
"""

# Global registry of (predicate, discoverer) pairs. Populated at import time
# by plugins (e.g. huggingface.py). Order matters: entries are tried in
# registration order, and the first one whose predicate matches *and* whose
# discoverer returns a non-None result wins. More specific predicates
# (e.g. Nemotron-H) must therefore be registered before generic ones
# (e.g. homogeneous HF models).
_decoder_layer_support: list[tuple[Any, Any]] = []
# Name of the attribute under which each patched decoder layer carries its
# ``_LayerCalibState``; cleanup uses it to remove that state again.
_LAYER_ATTR = "_seq_calib"

def __init__(self, model: nn.Module):
"""Initialize the collector for the given model."""
self.model = model
self._decoder_layers: nn.ModuleList | None = None
self._layer_to_idx: dict[nn.Module, int] = {}
self._patched = False

# ------------------------------------------------------------------
# Decoder-layer discovery
# ------------------------------------------------------------------

@staticmethod
def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
"""Return decoder layers supported by sequential calibration."""
for is_supported, discoverer in LayerActivationCollector._decoder_layer_support:
if not is_supported(model):
continue
decoder_layers = discoverer(model)
if decoder_layers is not None:
return decoder_layers
return None

@staticmethod
def is_supported(model: nn.Module) -> bool:
"""Whether the model supports decoder-layer sequential calibration."""
return LayerActivationCollector.get_decoder_layers(model) is not None

@classmethod
def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any):
"""Register a (predicate, discoverer) pair for decoder-layer detection."""
entry = (is_supported, discoverer)
if entry not in cls._decoder_layer_support:
cls._decoder_layer_support.append(entry)

# ------------------------------------------------------------------
# Output metadata helpers
# ------------------------------------------------------------------

@staticmethod
def _extract_output_meta(output):
"""Extract lightweight (shape, dtype, device) metadata from a layer output.

Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None).
The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a
zero-filled output with identical shape and type.
"""
if isinstance(output, torch.Tensor):
return ("tensor", output.shape, output.dtype, output.device)
if isinstance(output, tuple):
return (
"tuple",
tuple(LayerActivationCollector._extract_output_meta(o) for o in output),
)
if isinstance(output, list):
return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output])
return ("other", output)

@staticmethod
def _zeros_from_meta(meta):
"""Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``."""
tag = meta[0]
if tag == "tensor":
_, shape, dtype, device = meta
return torch.zeros(shape, dtype=dtype, device=device)
if tag == "tuple":
return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1])
if tag == "list":
return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]]
return copy.deepcopy(meta[1])

# ------------------------------------------------------------------
# Patched forward
# ------------------------------------------------------------------

@staticmethod
def _patched_forward(self, *args, **kwargs):
"""Unified forward bound to every decoder layer during sequential calibration.

``self`` here is the decoder layer module (bound via ``bind_forward_method``).
All per-layer state is accessed through ``self._seq_calib``.
"""
info: _LayerCalibState = self._seq_calib

if info.mode == "skip":
if info.output_meta is None:
raise RuntimeError(
f"Layer {info.name} is in 'skip' mode but has no output_meta. "
"This indicates a state-machine bug: the layer should have run "
"in 'run' mode (which sets output_meta) before transitioning to 'skip'."
)
return LayerActivationCollector._zeros_from_meta(info.output_meta)

if info.mode == "run":
assert info.cached_inputs, (
f"Layer {info.name} is in 'run' mode but has no cached inputs to replay."
)
real_args, real_kwargs = info.cached_inputs.popleft()
output = self._original_forward(*real_args, **real_kwargs)
info.output_meta = LayerActivationCollector._extract_output_meta(output)
return output

if info.mode == "capture":
info.collected_inputs.append((args, kwargs))
raise _EarlyStopForwardError()

return self._original_forward(*args, **kwargs)

# ------------------------------------------------------------------
# Patch / unpatch lifecycle
# ------------------------------------------------------------------

def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None):
"""Bind the unified forward to every decoder layer and the model. Called once.

Args:
decoder_layers: Pre-resolved decoder layers. If *None*, layers are
discovered via :meth:`get_decoder_layers`.
"""
if decoder_layers is not None:
self._decoder_layers = decoder_layers
else:
self._decoder_layers = self.get_decoder_layers(self.model)
assert self._decoder_layers is not None

self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)}
module_to_name = {m: name for name, m in self.model.named_modules()}

try:
for layer in self._decoder_layers:
layer._seq_calib = _LayerCalibState(
name=module_to_name.get(layer, type(layer).__name__),
)
bind_forward_method(layer, self._patched_forward, "_original_forward")

def _early_stop_forward(module_self, *args, **kwargs):
try:
return module_self._original_forward(*args, **kwargs)
except _EarlyStopForwardError:
return None

bind_forward_method(self.model, _early_stop_forward, "_original_forward")
except Exception:
self._cleanup_layers()
raise

self._patched = True

def _cleanup_layers(self):
"""Best-effort cleanup of any patched layers and model forward."""
if hasattr(self.model, "_original_forward"):
unpatch_forward_method(self.model, "_original_forward")

if self._decoder_layers is not None:
for layer in self._decoder_layers:
if hasattr(layer, "_original_forward"):
unpatch_forward_method(layer, "_original_forward")
if hasattr(layer, self._LAYER_ATTR):
delattr(layer, self._LAYER_ATTR)

def _unpatch_all_layers(self):
"""Restore original forwards and clean up state attributes. Called once."""
if not self._patched:
return
self._cleanup_layers()
self._patched = False

# ------------------------------------------------------------------
# Per-iteration state management
# ------------------------------------------------------------------

def _set_layer_states(self, layer_idx: int):
"""Transition layer modes for the next calibration step.

When calibrating layer *i*, three transitions happen:

* Layer ``i - 2`` → **skip** (fully done, free its cached inputs).
* Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights).
* Layer ``i`` → **capture** (record inputs, then early-stop).
"""
assert self._decoder_layers is not None

if layer_idx > 1:
done = self._decoder_layers[layer_idx - 2]._seq_calib
done.mode = "skip"
# output_meta is intentionally kept: skip mode needs it to produce
# correctly shaped zero-filled outputs for the parent forward.
done.cached_inputs.clear()

if layer_idx > 0:
prev = self._decoder_layers[layer_idx - 1]._seq_calib
prev.mode = "run"
prev.cached_inputs = deque(prev.collected_inputs)
prev.collected_inputs = []

cur = self._decoder_layers[layer_idx]._seq_calib
cur.mode = "capture"
cur.collected_inputs = []
Comment on lines +254 to +280
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fail fast when the sequential state machine falls out of sync.

These transitions assume the previous pass both captured inputs and replayed once. If a caller skips a decoder layer, or forward_loop() never reaches the target, you only discover it later when a skipped layer hits the RuntimeError on Line 165. Please validate done.output_meta / prev.collected_inputs before switching modes, and reject an empty capture right after forward_loop() so the failure is reported at the source.

Suggested guardrails
 def _set_layer_states(self, layer_idx: int):
     assert self._decoder_layers is not None

     if layer_idx > 1:
         done = self._decoder_layers[layer_idx - 2]._seq_calib
+        if done.output_meta is None:
+            raise RuntimeError(
+                f"Layer {done.name} cannot enter 'skip' before a successful replay."
+            )
         done.mode = "skip"
         done.cached_inputs = deque()

     if layer_idx > 0:
         prev = self._decoder_layers[layer_idx - 1]._seq_calib
+        if not prev.collected_inputs:
+            raise RuntimeError(
+                f"Layer {prev.name} has no captured inputs to replay."
+            )
         prev.mode = "run"
         prev.cached_inputs = deque(prev.collected_inputs)
         prev.collected_inputs = []
         forward_loop(self.model)

         info = layer._seq_calib
         inputs = list(info.collected_inputs)
         # After capture, set to original so calib_func can call the layer's
         # real forward directly.  The layer will transition to run → skip
         # in subsequent iterations via _set_layer_states.
         info.mode = "original"
+        if not inputs:
+            raise RuntimeError(
+                f"Layer {info.name} did not capture any inputs during forward_loop()."
+            )
         return inputs

Also applies to: 293-307

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/activation_collector.py` around lines 241 - 265,
The state-transition code in _set_layer_states assumes prior stages
captured/replayed successfully; add fail-fast checks before changing modes:
verify done.output_meta is present (non-empty) before setting done.mode="skip"
and verify prev.collected_inputs is non-empty before setting prev.mode="run" and
copying to prev.cached_inputs; if either check fails, raise a clear RuntimeError
indicating the specific layer index and missing data (use layer_idx and
references to self._decoder_layers[...] ._seq_calib). After forward_loop()
completes, also validate that the current layer's collected_inputs is non-empty
and raise a descriptive error immediately if the capture is empty. Apply the
same guardrails to the analogous transition block later in the file (the other
spot that manipulates ._seq_calib, .collected_inputs and .cached_inputs).


def _log_layer_summary(self, layer_idx: int):
    """Print one line listing which layer indices are in skip, run and capture mode.

    Layers in any other mode are left out, and a mode with no layers is
    omitted from the line.
    """
    assert self._decoder_layers is not None
    total = len(self._decoder_layers)
    tracked = ("skip", "run", "capture")
    by_mode: dict[str, list[int]] = {mode: [] for mode in tracked}
    for idx, layer in enumerate(self._decoder_layers):
        bucket = by_mode.get(layer._seq_calib.mode)
        if bucket is not None:
            bucket.append(idx)
    parts = [f"{mode}: {by_mode[mode]}" for mode in tracked if by_mode[mode]]
    print_rank_0(f"Calibrating layer {layer_idx}/{total} | {' | '.join(parts)}")

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

@torch.no_grad()
def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
    """Collect input activations for *layer* by running a full model forward.

    Layers before the target are skipped or re-run (if just calibrated), the
    target layer captures its inputs, and an early-stop prevents unnecessary
    computation beyond the target.

    :meth:`_patch_all_layers` must be called before this method.

    Note: the model forward returns ``None`` for every batch during capture
    (because ``_EarlyStopForwardError`` short-circuits the forward pass).
    Callers should not rely on the model's return value within *forward_loop*.

    Args:
        layer: The decoder layer whose inputs should be captured.
        forward_loop: Callable invoked as ``forward_loop(self.model)``.

    Returns:
        A list of the ``(args, kwargs)`` pairs the layer was called with, one
        per captured call.

    Raises:
        RuntimeError: If the layers have not been patched, or if
            *forward_loop* finished without the target layer capturing input.
    """
    if not self._patched:
        raise RuntimeError(
            "get_input_activations() requires _patch_all_layers() to be called first."
        )
    layer_idx = self._layer_to_idx[layer]
    self._set_layer_states(layer_idx)
    self._log_layer_summary(layer_idx)

    info = layer._seq_calib
    try:
        forward_loop(self.model)
    except Exception:
        # Reset the current layer so subsequent calls don't see stale state.
        info.mode = "original"
        info.collected_inputs = []
        raise

    # After capture, set to original so calib_func can call the layer's
    # real forward directly.  The layer will transition to run → skip
    # in subsequent iterations via _set_layer_states.
    info.mode = "original"

    # Report an empty capture here, at the source, rather than letting a later
    # step fail (or calibrate on no data) far from the cause.
    if not info.collected_inputs:
        raise RuntimeError(
            f"Layer {layer_idx} ({info.name!r}) did not capture any inputs during "
            "forward_loop(). The forward loop produced no batches or never reached "
            "this layer."
        )
    return list(info.collected_inputs)
Loading
Loading