From fc2997e9426b6e0c206f9161d4bba779f1cb5085 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Wed, 10 Dec 2025 15:15:48 -0800 Subject: [PATCH 1/5] write extra state for kv quantizer Signed-off-by: jenchen13 --- modelopt/torch/quantization/plugins/megatron.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index a33f715cf..859ea354e 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,6 +15,7 @@ """Support quantization for megatron linear layers.""" +import pickle import logging import warnings from typing import Any @@ -25,6 +26,7 @@ import megatron.core.transformer.moe.experts as megatron_moe import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch +from megatron.training import print_rank_0 from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule @@ -85,7 +87,7 @@ def real_quant_module_get_extra_state(self) -> dict: def quant_module_get_extra_state(self) -> dict: """Populating the extra_state when state_dict() is called. - quantizer_state, real_quantizer_state, and q_tensor_state are usually stored + quantizer_state, real_quantizer_state, and q_tensor_state used to be stored (before 0.29) with in the modelopt_state metadata where the keys are the full module name. The issue is that NeMo-MCore model's full module name can change if pipeline-parallelism (PP) and expert-parallelism (EP) @@ -95,7 +97,11 @@ def quant_module_get_extra_state(self) -> dict: """ extra_state = {} - is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False + weight_quantizer_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False + # TODO is checking just k enough? + k_bmm_quantizer_enabled = self.k_bmm_quantizer.is_enabled if hasattr(self, "k_bmm_quantizer") else False + v_bmm_quantizer_enabled = self.v_bmm_quantizer.is_enabled if hasattr(self, "v_bmm_quantizer") else False + is_enabled = weight_quantizer_enabled or k_bmm_quantizer_enabled or v_bmm_quantizer_enabled if not is_enabled: return extra_state @@ -109,7 +115,6 @@ def quant_module_get_extra_state(self) -> dict: # Handle real_quantizer_state and q_tensor_state extra_state.update(real_quant_module_get_extra_state(self)) - return extra_state @@ -678,6 +683,9 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): for k, v in self.state_dict(prefix="", keep_vars=True).items(): if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: sharded_state_dict[prefix + k] = v + print_rank_0(f"sharded_state_dict should have extra_state: {sharded_state_dict}") + tmp_state = sharded_state_dict["decoder.layers.5.self_attention.core_attention._extra_state"] + print_rank_0(f"unserialized extra_state: {pickle.loads(tmp_state.detach().cpu().numpy().tobytes())}") # Process _amax in bmm_quantizers for name, quantizer in [ @@ -709,7 +717,7 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): """Handle loading state dict for quantizers.""" for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: full_prefix = f"{prefix}{quantizer_name}." - amax_key = f"{prefix}{quantizer_name}._amax" + amax_key = f"{full_prefix}_amax" # If amax is in state_dict, rename it to the format expected by TensorQuantizer if amax_key in state_dict: From a0d24ebfeb21bc094941a029899e2b233a6cf105 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 11 Dec 2025 07:42:13 -0800 Subject: [PATCH 2/5] comment out debug Signed-off-by: jenchen13 --- modelopt/torch/quantization/plugins/megatron.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 859ea354e..d68647c3d 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -684,8 +684,8 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: sharded_state_dict[prefix + k] = v print_rank_0(f"sharded_state_dict should have extra_state: {sharded_state_dict}") - tmp_state = sharded_state_dict["decoder.layers.5.self_attention.core_attention._extra_state"] - print_rank_0(f"unserialized extra_state: {pickle.loads(tmp_state.detach().cpu().numpy().tobytes())}") + #tmp_state = sharded_state_dict["decoder.layers.5.self_attention.core_attention._extra_state"] + #print_rank_0(f"unserialized extra_state: {pickle.loads(tmp_state.detach().cpu().numpy().tobytes())}") # Process _amax in bmm_quantizers for name, quantizer in [ From 7e9f57f11a1a9a34493b5e46082377c97a518611 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Thu, 11 Dec 2025 12:40:38 -0800 Subject: [PATCH 3/5] cleanup Signed-off-by: jenchen13 --- modelopt/torch/quantization/plugins/megatron.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index d68647c3d..355c26455 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -26,7 +26,6 @@ import megatron.core.transformer.moe.experts as megatron_moe import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch -from megatron.training import print_rank_0 from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule @@ -657,7 +656,7 @@ def _calibrate_quantizers(self): ] for _, quantizer in quantizers: - if quantizer is not None and quantizer.is_enabled(): + if quantizer is not None and quantizer.is_enabled: if not hasattr(quantizer, "_amax") or quantizer._amax is None: quantizer.reset_amax() max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) @@ -683,9 +682,6 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): for k, v in self.state_dict(prefix="", keep_vars=True).items(): if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: sharded_state_dict[prefix + k] = v - print_rank_0(f"sharded_state_dict should have extra_state: {sharded_state_dict}") - #tmp_state = sharded_state_dict["decoder.layers.5.self_attention.core_attention._extra_state"] - #print_rank_0(f"unserialized extra_state: {pickle.loads(tmp_state.detach().cpu().numpy().tobytes())}") # Process _amax in bmm_quantizers for name, quantizer in [ @@ -756,7 +752,7 @@ def _check_unsupported_states(quantizer): ("k_bmm_quantizer", self.k_bmm_quantizer), ("v_bmm_quantizer", self.v_bmm_quantizer), ]: - if not hasattr(self, quantizer_name) or not quantizer.is_enabled(): + if not hasattr(self, quantizer_name) or not quantizer.is_enabled: continue _check_unsupported_states(quantizer) From 8e71c94e22ea5eb4a98055b77b21bb20b7591344 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 22 Dec 2025 10:55:52 -0800 Subject: [PATCH 4/5] fixes for kv cache cuda error Signed-off-by: jenchen13 --- .../nn/modules/tensor_quantizer.py | 7 +++++-- .../torch/quantization/plugins/megatron.py | 12 +++++++++++ modelopt/torch/quantization/tensor_quant.py | 21 ++++++++++++++++++- .../quantization/plugins/test_megatron.py | 10 ++++----- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 0d7876149..ccaa8c8b2 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -931,8 +931,11 @@ def forward(self, inputs): if self._if_quant: # Check if the input tensor is contiguous # Non-contiguous tensors will generate incorrect FP4 quantization results - if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): - inputs.data = inputs.data.contiguous() + # DISABLED: This check causes illegal memory access in distributed training + # The tensor appears to be corrupted upstream, before reaching the quantizer + # TODO: Investigate tensor corruption in attention mechanism + # if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): + # inputs = inputs.contiguous() if self.fake_quant: outputs = self._fake_quantize(inputs) elif not self._dequantize: diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 355c26455..5bbe4af18 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -667,6 +667,18 @@ def forward(self, query, key, value, *args, **kwargs): TEDotProductAttention receives Q, K, V after RoPE is applied, so we quantize them directly for KV cache quantization. """ + # Ensure tensors are contiguous before quantization + # This is a safety measure for potential non-contiguous tensor views + # from TE or Megatron operations with tensor parallelism + def materialize_if_needed(tensor): + if tensor is not None and hasattr(tensor, 'is_contiguous') and not tensor.is_contiguous(): + return tensor.contiguous() + return tensor + + query = materialize_if_needed(query) + key = materialize_if_needed(key) + value = materialize_if_needed(value) + # Quantize Q, K, V query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 5f69e3999..315a9da18 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -43,14 +43,29 @@ def _fp8_eager(x, amax=None): + """Eager mode implementation of FP8 E4M3 fake quantization. + + Args: + x: Input tensor. + amax: Absolute max value for scaling. If None, only dtype conversion is performed. + + Returns: + Fake-quantized tensor in original dtype. + """ dtype = x.dtype + if amax is not None: scale = 448.0 / (amax.to(torch.float32)) scale_inv = 1 / scale x = x.to(torch.float32) * scale + # Clamp to FP8 E4M3 range to prevent NaN/Inf during conversion + x = torch.clamp(x, min=-448.0, max=448.0) + x = x.to(torch.float8_e4m3fn) + if amax is not None: x = x.to(torch.float32) * scale_inv + return x.to(dtype) @@ -76,7 +91,11 @@ def scaled_e4m3_impl( return fp8_eager(inputs, amax) cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False) - if cuda_ext_fp8 is None: + # NOTE: CUDA extension disabled due to bug with GQA/MQA (singleton KV head dimension) + # and tensor parallelism. The fake_e4m3fy() kernel produces corrupted output for + # tensors with shape [seq_len, 1, head_dim] when TP > 1. + # Using eager fallback until kernel is fixed. + if cuda_ext_fp8 is None or True: # Force eager fallback return fp8_eager(inputs, amax) with torch.cuda.device( diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 2993749b1..01a5a994d 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -836,11 +836,7 @@ def forward_fn(model): # Quantize the reference model model_ref = mtq.quantize(model_ref, config, forward_fn) - - # CRITICAL: model_test must also be quantized with the same config - # Otherwise it won't have the KV cache quantizer keys when loading state dict - model_test = mtq.quantize(model_test, config, forward_fn) - + # Verify KV cache quantizers were created kv_quantizers_found = False for name, module in model_ref.named_modules(): @@ -851,6 +847,10 @@ def forward_fn(model): assert kv_quantizers_found, "No KV cache quantizers found in quantized model" + # CRITICAL: model_test must also be quantized with the same config + # Otherwise it won't have the KV cache quantizer keys when loading state dict + # model_test = mtq.quantize(model_test, config, forward_fn) + # Test sharded state dict save/load sharded_state_dict_test_helper( tmp_path, From 24baf8ebbf873418e4c2e3dd76a94becb4ca1707 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Mon, 22 Dec 2025 11:42:20 -0800 Subject: [PATCH 5/5] uncomment out contiguous Signed-off-by: jenchen13 --- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 4 ++-- modelopt/torch/quantization/plugins/megatron.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ccaa8c8b2..825d02f47 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -934,8 +934,8 @@ def forward(self, inputs): # DISABLED: This check causes illegal memory access in distributed training # The tensor appears to be corrupted upstream, before reaching the quantizer # TODO: Investigate tensor corruption in attention mechanism - # if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): - # inputs = inputs.contiguous() + if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): + inputs = inputs.contiguous() if self.fake_quant: outputs = self._fake_quantize(inputs) elif not self._dequantize: diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 5bbe4af18..b1069f0bb 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,7 +15,6 @@ """Support quantization for megatron linear layers.""" -import pickle import logging import warnings from typing import Any