Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,18 @@ def weight_only_quantize(model: nn.Module):
for name, module in model.named_modules():
if module in seen_modules:
continue
for weight_name in weight_attr_names(module):

if isinstance(module, QuantModule):
with enable_weight_access_and_writeback(module, model):
weight_quantizer = getattr(
module, quantizer_attr_names(weight_name).weight_quantizer
)
weight_quantizer(getattr(module, weight_name))
for weight, weight_quantizer in module.iter_weights_for_calibration():
weight_quantizer(weight)
else:
for weight_name in weight_attr_names(module):
with enable_weight_access_and_writeback(module, model):
weight_quantizer = getattr(
module, quantizer_attr_names(weight_name).weight_quantizer
)
weight_quantizer(getattr(module, weight_name))
seen_modules.add(module)


Expand Down
15 changes: 15 additions & 0 deletions modelopt/torch/quantization/nn/modules/quant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,21 @@ def modelopt_post_restore(self, prefix: str = ""):
if isinstance(module, TensorQuantizer):
module.to(non_tq_param_or_buffer.device)

def iter_weights_for_calibration(self):
    """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration.

    The default implementation walks every weight attribute reported by
    :func:`~modelopt.torch.quantization.utils.weight_attr_names` and pairs it
    with its matching quantizer. Subclasses that store weights under
    non-standard attribute names (e.g. ``_QuantTEGroupedLinear`` uses
    ``weight0``, ``weight1``, …) should override this method.
    """
    # Imported locally to avoid a circular import with the utils module.
    from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names

    for name in weight_attr_names(self):
        attr_names = quantizer_attr_names(name)
        yield getattr(self, name), getattr(self, attr_names.weight_quantizer)

def fold_weight(self, keep_attrs: bool = False):
"""Fold the weight for faster eval."""
# Handle all attributes that end with _weight_quantizer
Expand Down
12 changes: 12 additions & 0 deletions modelopt/torch/quantization/plugins/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ def modelopt_post_restore(self, prefix: str = ""):
# Remove self.weight after post_restore.
delattr(self, "weight")

def iter_weights_for_calibration(self):
    """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights.

    Override is needed because ``self.weight`` is removed in ``_setup``, so the
    base-class implementation (which relies on ``weight_attr_names``) would find
    no weights. Here we iterate over ``weight0``, ``weight1``, … directly.
    """
    for idx in range(self.num_gemms):
        grouped_weight = getattr(self, f"weight{idx}", None)
        if grouped_weight is None:
            # Attribute absent (or explicitly None) — nothing to calibrate for this gemm.
            continue
        yield grouped_weight, self.weight_quantizer

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
_assert_te_fp8_enabled()
Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
# the standard weight and quantizer case
weight = getattr(module, "weight", None)
weight_quantizer = getattr(module, "weight_quantizer", None)
if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
yield "weight"

# other weight and quantizer case
Expand Down