diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 70f036a8d..0c9204566 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -70,12 +70,18 @@ def weight_only_quantize(model: nn.Module): for name, module in model.named_modules(): if module in seen_modules: continue - for weight_name in weight_attr_names(module): + + if isinstance(module, QuantModule): with enable_weight_access_and_writeback(module, model): - weight_quantizer = getattr( - module, quantizer_attr_names(weight_name).weight_quantizer - ) - weight_quantizer(getattr(module, weight_name)) + for weight, weight_quantizer in module.iter_weights_for_calibration(): + weight_quantizer(weight) + else: + for weight_name in weight_attr_names(module): + with enable_weight_access_and_writeback(module, model): + weight_quantizer = getattr( + module, quantizer_attr_names(weight_name).weight_quantizer + ) + weight_quantizer(getattr(module, weight_name)) seen_modules.add(module) diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 7d1471207..e698d6203 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -119,6 +119,21 @@ def modelopt_post_restore(self, prefix: str = ""): if isinstance(module, TensorQuantizer): module.to(non_tq_param_or_buffer.device) + def iter_weights_for_calibration(self): + """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration. + + The default implementation iterates over all weights returned by + :func:`~modelopt.torch.quantization.utils.weight_attr_names`. Subclasses that + store weights under non-standard attribute names (e.g. + ``_QuantTEGroupedLinear`` uses ``weight0``, ``weight1``, …) should + override this method.
+ """ + from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names + + for weight_name in weight_attr_names(self): + weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer) + yield getattr(self, weight_name), weight_quantizer + def fold_weight(self, keep_attrs: bool = False): """Fold the weight for faster eval.""" # Handle all attributes that end with _weight_quantizer diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index afc08211f..f0f18e55b 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -151,6 +151,18 @@ def modelopt_post_restore(self, prefix: str = ""): # Remove self.weight after post_restore. delattr(self, "weight") + def iter_weights_for_calibration(self): + """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights. + + Override is needed because ``self.weight`` is removed in ``_setup``, so the + base-class implementation (which relies on ``weight_attr_names``) would find + no weights. Here we iterate over ``weight0``, ``weight1``, … directly.
+ """ + for i in range(self.num_gemms): + weight_i = getattr(self, f"weight{i}", None) + if weight_i is not None: + yield weight_i, self.weight_quantizer + @staticmethod def te_grouped_quantized_linear_fn(package, func_name, self, *args): _assert_te_fp8_enabled() diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 3c0d5e434..ef489e8d3 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -229,7 +229,7 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]": # the standard weight and quantizer case weight = getattr(module, "weight", None) weight_quantizer = getattr(module, "weight_quantizer", None) - if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)): + if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)): yield "weight" # other weight and quantizer case