diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 80a2a6876..32971094a 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1159,12 +1159,15 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): class MseCalibConfig(QuantizeAlgorithmConfig): - """Configuration for per-tensor MSE calibration. + """Configuration for per-tensor MSE calibration (weight quantizers only). - Finds a scale s (via amax a, with s = a / q_max) that minimizes the - reconstruction error of a tensor after uniform Q→DQ: + The MSE search is applied only to weight quantizers. Activation quantizers + keep their max-calibration amax. - s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations} + For weights, finds a scale s (via amax a, with s = a / q_max) that minimizes + the reconstruction error after uniform Q→DQ: + + s* = argmin_s E[(X - DQ(Q(X; s)))^2], X = weights When fp8_scale_sweep is enabled, step_size is ignored. """ @@ -1209,11 +1212,13 @@ class MseCalibConfig(QuantizeAlgorithmConfig): class LocalHessianCalibConfig(QuantizeAlgorithmConfig): - """Configuration for local Hessian-weighted MSE calibration. + """Configuration for local Hessian-weighted MSE calibration (weight quantizers only). - This algorithm uses activation information to optimize per-block scales for weight - quantization. It minimizes the output reconstruction error by weighting the loss - with the local Hessian matrix computed from input activations. + Only weight quantizers are calibrated with this algorithm; activation quantizers + keep their max-calibration amax. This algorithm uses activation information to + optimize per-block scales for weight quantization. It minimizes the output + reconstruction error by weighting the loss with the local Hessian matrix + computed from input activations. The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where: - ``dw = weight - quantized_weight`` (weight reconstruction error per block) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 70f036a8d..7d074c8fc 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -38,7 +38,7 @@ ) from modelopt.torch.utils.perf import get_used_gpu_mem_fraction -from .calib import MseCalibrator, NVFP4MSECalibrator +from .calib import MaxCalibrator, MseCalibrator, NVFP4MSECalibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -288,11 +288,11 @@ def mse_calibrate( stop_multiplier: float = 4.0, fp8_scale_sweep: bool = False, ): - """Calibrate the model using MSE-based amax search. + """Calibrate the model using MSE-based amax search (weight quantizers only). - This calibration method first uses max calibration to get initial amax values, - then searches for better amax values by minimizing the MSE between original - and quantized tensors. + This calibration method first uses max calibration to get initial amax values + for all quantizers, then runs MSE-based amax search only for weight + quantizers. Activation quantizers keep their max-calibration amax. Args: model: Model to be calibrated. @@ -415,9 +415,39 @@ def mse_calibrate( torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) torch.cuda.empty_cache() + replace_mse_calibrators_with_max(model) + # TODO: Sync amax across distributed processes +def replace_mse_calibrators_with_max(model: nn.Module) -> int: + """Replace MseCalibrator and NVFP4MSECalibrator with MaxCalibrator after calibration. + + Call this after :func:`mse_calibrate` or :func:`local_hessian_calibrate` so that + downstream steps (e.g. GPTQ or another calibration pass) can run correctly, since + advanced algorithms (MSE, local_hessian) all start with max calibration and expect + MaxCalibrator on quantizers that still need calibration. + + Args: + model: The calibrated model. + + Returns: + Number of calibrators replaced. + """ + replaced_count = 0 + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer) and not module._disabled: + if hasattr(module, "_calibrator") and module._calibrator is not None: + if isinstance(module._calibrator, (MseCalibrator, NVFP4MSECalibrator)): + module._calibrator = MaxCalibrator( + num_bits=module._num_bits, + axis=module._axis, + unsigned=module._unsigned, + ) + replaced_count += 1 + return replaced_count + + @torch.no_grad() def local_hessian_calibrate( model: nn.Module, @@ -430,7 +460,10 @@ def local_hessian_calibrate( block_size: int = 16, debug: bool = False, ): - """Calibrate the model using local Hessian-weighted MSE search. + """Calibrate the model using local Hessian-weighted MSE search (weight quantizers only). + + Only weight quantizers are calibrated; activation quantizers keep their + max-calibration amax from the initial max calibration pass. Instead of minimizing weight error ``||W - Wq||²``, this minimizes Hessian-weighted error ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = X @ X.T`` approximates output reconstruction @@ -544,15 +577,42 @@ def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: def forward(self, input, *args, **kwargs): """Custom forward that collects activations in cache mode.""" - if LocalHessianHelper.cache_mode and self.hessian_helper.is_enabled: - # Get local tensor from DTensor if applicable - input_local = input.to_local() if hasattr(input, "to_local") else input - self.hessian_helper.accumulate_hessian(input_local) - - # Forward without quantization during caching + # Forward without weight quantization during caching if LocalHessianHelper.cache_mode: self.weight_quantizer.disable() + + # Capture quantized input from the forward pass for Hessian collection + captured_quantized = [None] + original_forward = None + if ( + self.hessian_helper.is_enabled + and hasattr(self, "input_quantizer") + and self.input_quantizer.is_enabled + ): + original_forward = self.input_quantizer.forward + + def capture_forward(input_tensor): + quantized = original_forward(input_tensor) + captured_quantized[0] = ( + quantized.to_local() if hasattr(quantized, "to_local") else quantized + ) + return quantized + + self.input_quantizer.forward = capture_forward + out = self._forward_no_local_hessian(input, *args, **kwargs) + + # Collect Hessian from the quantized input that was used in forward pass + if self.hessian_helper.is_enabled: + if hasattr(self, "input_quantizer") and self.input_quantizer.is_enabled: + self.hessian_helper.accumulate_hessian(captured_quantized[0]) + if original_forward is not None: + self.input_quantizer.forward = original_forward + else: + # No input_quantizer, use raw input + input_local = input.to_local() if hasattr(input, "to_local") else input + self.hessian_helper.accumulate_hessian(input_local) + self.weight_quantizer.enable() return out @@ -699,6 +759,8 @@ def quant_func(x, amax, quantizer=weight_quantizer): for name, module in all_patched_modules: module.hessian_helper.cleanup() + replace_mse_calibrators_with_max(model) + print_rank_0("local_hessian: Calibration complete.")