@@ -931,8 +931,11 @@ def forward(self, inputs):
         if self._if_quant:
             # Check if the input tensor is contiguous
             # Non-contiguous tensors will generate incorrect FP4 quantization results
+            # DISABLED: This check causes illegal memory access in distributed training
+            # The tensor appears to be corrupted upstream, before reaching the quantizer
+            # TODO: Investigate tensor corruption in attention mechanism
             if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
-                inputs.data = inputs.data.contiguous()
+                inputs = inputs.contiguous()
             if self.fake_quant:
                 outputs = self._fake_quantize(inputs)
             elif not self._dequantize:
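For context on the contiguity guard above (and the similar materialize_if_needed helper added to megatron.py further down), here is a minimal standalone illustration, not part of the patch, of how a strided Q/K/V view ends up non-contiguous and why .contiguous() materializes a fresh buffer; the shapes are made up for the example:

import torch

# A packed QKV activation [batch, seq, 3, head_dim]: slicing out K and swapping
# batch/seq yields a strided view whose memory layout is no longer row-major.
qkv = torch.randn(2, 16, 3, 64)
key = qkv[:, :, 1, :].transpose(0, 1)

print(key.is_contiguous())  # False: quantization kernels typically assume a dense layout
key = key.contiguous()      # copies the view into a freshly allocated, densely packed tensor
print(key.is_contiguous())  # True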
27 changes: 21 additions & 6 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -85,7 +85,7 @@ def real_quant_module_get_extra_state(self) -> dict:
 def quant_module_get_extra_state(self) -> dict:
     """Populating the extra_state when state_dict() is called.

-    quantizer_state, real_quantizer_state, and q_tensor_state are usually stored
+    quantizer_state, real_quantizer_state, and q_tensor_state used to be stored (before 0.29)
     with in the modelopt_state metadata where the keys are the full module name. The issue
     is that NeMo-MCore model's full module name can change
     if pipeline-parallelism (PP) and expert-parallelism (EP)
@@ -95,7 +95,11 @@ def quant_module_get_extra_state(self) -> dict:
     """
     extra_state = {}

-    is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
+    weight_quantizer_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
+    # TODO is checking just k enough?
+    k_bmm_quantizer_enabled = self.k_bmm_quantizer.is_enabled if hasattr(self, "k_bmm_quantizer") else False
+    v_bmm_quantizer_enabled = self.v_bmm_quantizer.is_enabled if hasattr(self, "v_bmm_quantizer") else False
+    is_enabled = weight_quantizer_enabled or k_bmm_quantizer_enabled or v_bmm_quantizer_enabled
Comment on lines +98 to +102 (Contributor):

why not do:

Suggested change:
-    weight_quantizer_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
-    # TODO is checking just k enough?
-    k_bmm_quantizer_enabled = self.k_bmm_quantizer.is_enabled if hasattr(self, "k_bmm_quantizer") else False
-    v_bmm_quantizer_enabled = self.v_bmm_quantizer.is_enabled if hasattr(self, "v_bmm_quantizer") else False
-    is_enabled = weight_quantizer_enabled or k_bmm_quantizer_enabled or v_bmm_quantizer_enabled
+    is_enabled = any(isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children())


     if not is_enabled:
         return extra_state
@@ -109,7 +113,6 @@ def quant_module_get_extra_state(self) -> dict:

     # Handle real_quantizer_state and q_tensor_state
     extra_state.update(real_quant_module_get_extra_state(self))
-
     return extra_state

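Background on the extra_state approach in this function: PyTorch's nn.Module serializes whatever get_extra_state returns under the module's own state_dict prefix, which is why attaching quantizer_state here survives module renames under PP/EP, unlike keys stored by full module name in the modelopt_state metadata. A small sketch using the standard nn.Module API; the module and payload below are made up for illustration:

import torch
from torch import nn

class ModuleWithExtraState(nn.Module):
    """Hypothetical module showing how extra_state rides along with state_dict()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self._quantizer_state = {"enabled": True, "num_bits": 8}  # illustrative payload

    def get_extra_state(self):
        # Stored under "<module prefix>._extra_state", so it moves with the module
        # even if the module's full name changes (e.g. after PP/EP resharding).
        return self._quantizer_state

    def set_extra_state(self, state):
        self._quantizer_state = state

m = ModuleWithExtraState()
state = m.state_dict()
print([k for k in state if k.endswith("_extra_state")])  # ['_extra_state']
m.load_state_dict(state)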
@@ -652,7 +655,7 @@ def _calibrate_quantizers(self):
         ]

         for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled():
+            if quantizer is not None and quantizer.is_enabled:
                 if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                     quantizer.reset_amax()
                 max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
@@ -663,6 +666,18 @@ def forward(self, query, key, value, *args, **kwargs):
         TEDotProductAttention receives Q, K, V after RoPE is applied,
         so we quantize them directly for KV cache quantization.
         """
+        # Ensure tensors are contiguous before quantization
+        # This is a safety measure for potential non-contiguous tensor views
+        # from TE or Megatron operations with tensor parallelism
+        def materialize_if_needed(tensor):
+            if tensor is not None and hasattr(tensor, 'is_contiguous') and not tensor.is_contiguous():
+                return tensor.contiguous()
+            return tensor
+
+        query = materialize_if_needed(query)
+        key = materialize_if_needed(key)
+        value = materialize_if_needed(value)
+
Comment on lines +677 to +679 (Contributor):

Do we need this if we are calling inputs = inputs.contiguous() in TensorQuantizer's forward?


         # Quantize Q, K, V
         query = self.q_bmm_quantizer(query)
         key = self.k_bmm_quantizer(key)
@@ -709,7 +724,7 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         """Handle loading state dict for quantizers."""
         for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
             full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
+            amax_key = f"{full_prefix}_amax"

             # If amax is in state_dict, rename it to the format expected by TensorQuantizer
             if amax_key in state_dict:
@@ -748,7 +763,7 @@ def _check_unsupported_states(quantizer):
             ("k_bmm_quantizer", self.k_bmm_quantizer),
             ("v_bmm_quantizer", self.v_bmm_quantizer),
         ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
                 continue

             _check_unsupported_states(quantizer)
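Aside on the two is_enabled fixes above: the patch treats is_enabled as a property or plain attribute rather than a callable, so the old quantizer.is_enabled() call style would fail at runtime. A minimal sketch with a hypothetical stand-in class (not the real TensorQuantizer) showing the failure mode:

class QuantizerStub:
    """Hypothetical stand-in that exposes is_enabled as a property."""

    def __init__(self, enabled: bool = True):
        self._enabled = enabled

    @property
    def is_enabled(self) -> bool:
        return self._enabled

q = QuantizerStub()
print(q.is_enabled)  # True: plain attribute access, as in the corrected code
try:
    q.is_enabled()   # the pre-fix call style
except TypeError as err:
    print(err)       # 'bool' object is not callable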
21 changes: 20 additions & 1 deletion modelopt/torch/quantization/tensor_quant.py
@@ -43,14 +43,29 @@


 def _fp8_eager(x, amax=None):
+    """Eager mode implementation of FP8 E4M3 fake quantization.
+
+    Args:
+        x: Input tensor.
+        amax: Absolute max value for scaling. If None, only dtype conversion is performed.
+
+    Returns:
+        Fake-quantized tensor in original dtype.
+    """
     dtype = x.dtype
+
     if amax is not None:
         scale = 448.0 / (amax.to(torch.float32))
         scale_inv = 1 / scale
         x = x.to(torch.float32) * scale
+        # Clamp to FP8 E4M3 range to prevent NaN/Inf during conversion
+        x = torch.clamp(x, min=-448.0, max=448.0)
+
     x = x.to(torch.float8_e4m3fn)
+
     if amax is not None:
         x = x.to(torch.float32) * scale_inv
+
     return x.to(dtype)

@@ -76,7 +91,11 @@ def scaled_e4m3_impl(
         return fp8_eager(inputs, amax)

     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
-    if cuda_ext_fp8 is None:
+    # NOTE: CUDA extension disabled due to bug with GQA/MQA (singleton KV head dimension)
+    # and tensor parallelism. The fake_e4m3fy() kernel produces corrupted output for
+    # tensors with shape [seq_len, 1, head_dim] when TP > 1.
+    # Using eager fallback until kernel is fixed.
+    if cuda_ext_fp8 is None or True:  # Force eager fallback
         return fp8_eager(inputs, amax)

     with torch.cuda.device(
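To make the eager path above concrete, a small self-contained sketch of the scale, clamp, cast, rescale round-trip; it assumes a PyTorch build with float8_e4m3fn support, and the function name and sample values are illustrative rather than part of the patch:

import torch

def fp8_e4m3_fake_quant(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Map |x| == amax onto 448.0, the largest finite E4M3 value.
    scale = 448.0 / amax.to(torch.float32)
    y = (x.to(torch.float32) * scale).clamp(-448.0, 448.0)
    # Round-tripping through FP8 introduces the quantization error; then undo the scaling.
    y = y.to(torch.float8_e4m3fn).to(torch.float32) / scale
    return y.to(x.dtype)

x = torch.randn(4, 8)
x_q = fp8_e4m3_fake_quant(x, x.abs().amax())
print((x - x_q).abs().max())  # small, bounded quantization error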
10 changes: 5 additions & 5 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -836,11 +836,7 @@ def forward_fn(model):

     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)

-    # CRITICAL: model_test must also be quantized with the same config
-    # Otherwise it won't have the KV cache quantizer keys when loading state dict
-    model_test = mtq.quantize(model_test, config, forward_fn)
-
Comment on lines 838 to -843 (Contributor):

@kaix-nv this is an incorrect unit test. This completely breaks the modelopt resume workflow (that is, resume requires a model that ModelOpt has not already modified).


     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
@@ -851,6 +847,10 @@ def forward_fn(model):

assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
# model_test = mtq.quantize(model_test, config, forward_fn)

# Test sharded state dict save/load
sharded_state_dict_test_helper(
tmp_path,