Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,12 +1159,15 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):


class MseCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for per-tensor MSE calibration.
"""Configuration for per-tensor MSE calibration (weight quantizers only).

Finds a scale s (via amax a, with s = a / q_max) that minimizes the
reconstruction error of a tensor after uniform Q→DQ:
The MSE search is applied only to weight quantizers. Activation quantizers
keep their max-calibration amax.

s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
For weights, finds a scale s (via amax a, with s = a / q_max) that minimizes
the reconstruction error after uniform Q→DQ:

s* = argmin_s E[(X - DQ(Q(X; s)))^2], X = weights

When fp8_scale_sweep is enabled, step_size is ignored.
"""
Expand Down Expand Up @@ -1209,11 +1212,13 @@ class MseCalibConfig(QuantizeAlgorithmConfig):


class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for local Hessian-weighted MSE calibration.
"""Configuration for local Hessian-weighted MSE calibration (weight quantizers only).

This algorithm uses activation information to optimize per-block scales for weight
quantization. It minimizes the output reconstruction error by weighting the loss
with the local Hessian matrix computed from input activations.
Only weight quantizers are calibrated with this algorithm; activation quantizers
keep their max-calibration amax. This algorithm uses activation information to
optimize per-block scales for weight quantization. It minimizes the output
reconstruction error by weighting the loss with the local Hessian matrix
computed from input activations.

The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:
- ``dw = weight - quantized_weight`` (weight reconstruction error per block)
Expand Down
86 changes: 74 additions & 12 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction

from .calib import MseCalibrator, NVFP4MSECalibrator
from .calib import MaxCalibrator, MseCalibrator, NVFP4MSECalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
Expand Down Expand Up @@ -288,11 +288,11 @@ def mse_calibrate(
stop_multiplier: float = 4.0,
fp8_scale_sweep: bool = False,
):
"""Calibrate the model using MSE-based amax search.
"""Calibrate the model using MSE-based amax search (weight quantizers only).

This calibration method first uses max calibration to get initial amax values,
then searches for better amax values by minimizing the MSE between original
and quantized tensors.
This calibration method first uses max calibration to get initial amax values
for all quantizers, then runs MSE-based amax search only for weight
quantizers. Activation quantizers keep their max-calibration amax.

Args:
model: Model to be calibrated.
Expand Down Expand Up @@ -415,9 +415,39 @@ def mse_calibrate(
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()

replace_mse_calibrators_with_max(model)

# TODO: Sync amax across distributed processes


def replace_mse_calibrators_with_max(model: nn.Module) -> int:
    """Replace MseCalibrator and NVFP4MSECalibrator with MaxCalibrator after calibration.

    Call this after :func:`mse_calibrate` or :func:`local_hessian_calibrate` so that
    downstream steps (e.g. GPTQ or another calibration pass) can run correctly, since
    advanced algorithms (MSE, local_hessian) all start with max calibration and expect
    MaxCalibrator on quantizers that still need calibration.

    Args:
        model: The calibrated model.

    Returns:
        Number of calibrators replaced.
    """
    replaced_count = 0
    # Module names are not needed here, so iterate modules() directly instead of
    # named_modules() (the original bound an unused `name`).
    for module in model.modules():
        # getattr(..., None) + isinstance subsumes the hasattr / is-not-None checks:
        # isinstance(None, ...) is False, so missing or unset calibrators are skipped.
        if (
            isinstance(module, TensorQuantizer)
            and not module._disabled
            and isinstance(
                getattr(module, "_calibrator", None), (MseCalibrator, NVFP4MSECalibrator)
            )
        ):
            # Rebuild a MaxCalibrator from the quantizer's own settings so a later
            # calibration pass sees the same state as plain max calibration.
            module._calibrator = MaxCalibrator(
                num_bits=module._num_bits,
                axis=module._axis,
                unsigned=module._unsigned,
            )
            replaced_count += 1
    return replaced_count


@torch.no_grad()
def local_hessian_calibrate(
model: nn.Module,
Expand All @@ -430,7 +460,10 @@ def local_hessian_calibrate(
block_size: int = 16,
debug: bool = False,
):
"""Calibrate the model using local Hessian-weighted MSE search.
"""Calibrate the model using local Hessian-weighted MSE search (weight quantizers only).

Only weight quantizers are calibrated; activation quantizers keep their
max-calibration amax from the initial max calibration pass.

Instead of minimizing weight error ``||W - Wq||²``, this minimizes Hessian-weighted error
``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = X @ X.T`` approximates output reconstruction
Expand Down Expand Up @@ -544,15 +577,42 @@ def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:

def forward(self, input, *args, **kwargs):
"""Custom forward that collects activations in cache mode."""
if LocalHessianHelper.cache_mode and self.hessian_helper.is_enabled:
# Get local tensor from DTensor if applicable
input_local = input.to_local() if hasattr(input, "to_local") else input
self.hessian_helper.accumulate_hessian(input_local)

# Forward without quantization during caching
# Forward without weight quantization during caching
if LocalHessianHelper.cache_mode:
self.weight_quantizer.disable()

# Capture quantized input from the forward pass for Hessian collection
captured_quantized = [None]
original_forward = None
if (
self.hessian_helper.is_enabled
and hasattr(self, "input_quantizer")
and self.input_quantizer.is_enabled
):
original_forward = self.input_quantizer.forward

def capture_forward(input_tensor):
quantized = original_forward(input_tensor)
captured_quantized[0] = (
quantized.to_local() if hasattr(quantized, "to_local") else quantized
)
return quantized

self.input_quantizer.forward = capture_forward

out = self._forward_no_local_hessian(input, *args, **kwargs)

# Collect Hessian from the quantized input that was used in forward pass
if self.hessian_helper.is_enabled:
if hasattr(self, "input_quantizer") and self.input_quantizer.is_enabled:
self.hessian_helper.accumulate_hessian(captured_quantized[0])
if original_forward is not None:
self.input_quantizer.forward = original_forward
else:
# No input_quantizer, use raw input
input_local = input.to_local() if hasattr(input, "to_local") else input
self.hessian_helper.accumulate_hessian(input_local)

Comment on lines +580 to +615
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Consider adding a guard for captured_quantized[0] being None.

The monkey-patching approach to capture quantized inputs is sound and correctly restores the original forward method. However, if capture_forward is never invoked during _forward_no_local_hessian (an unlikely edge case), captured_quantized[0] would remain None, causing accumulate_hessian to fail on input_tensor.reshape().

Consider adding a defensive check:

🛡️ Proposed defensive guard
         # Collect Hessian from the quantized input that was used in forward pass
         if self.hessian_helper.is_enabled:
             if hasattr(self, "input_quantizer") and self.input_quantizer.is_enabled:
-                self.hessian_helper.accumulate_hessian(captured_quantized[0])
+                if captured_quantized[0] is not None:
+                    self.hessian_helper.accumulate_hessian(captured_quantized[0])
                 if original_forward is not None:
                     self.input_quantizer.forward = original_forward
             else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_calib.py` around lines 580 - 615, The
captured_quantized[0] may still be None if capture_forward is not invoked during
_forward_no_local_hessian; add a defensive guard before calling
self.hessian_helper.accumulate_hessian: check captured_quantized[0] is not None
and only then call accumulate_hessian(captured_quantized[0]); if it is None fall
back to the raw input path (use input.to_local() if available) and then restore
self.input_quantizer.forward to original_forward if set; update the block that
follows out = self._forward_no_local_hessian(...) to perform this None-check and
fallback while preserving restoration of the original_forward.

self.weight_quantizer.enable()
return out

Expand Down Expand Up @@ -699,6 +759,8 @@ def quant_func(x, amax, quantizer=weight_quantizer):
for name, module in all_patched_modules:
module.hessian_helper.cleanup()

replace_mse_calibrators_with_max(model)

print_rank_0("local_hessian: Calibration complete.")


Expand Down