modelopt/torch/quantization/plugins/megatron.py: 163 changes (26 additions, 137 deletions)
@@ -38,7 +38,6 @@
)
from modelopt.torch.utils.distributed import ParallelState

from ..model_calib import max_calibrate
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
@@ -98,7 +97,9 @@ def quant_module_get_extra_state(self) -> dict:
"""
extra_state = {}

is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
is_enabled = any(
isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children()
)

if not is_enabled:
return extra_state
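A minimal, self-contained sketch (not the modelopt source; FakeTensorQuantizer and FakeKVCacheAttention are illustrative stand-ins) of why the enablement check now walks every child quantizer: a KV-cache attention module carries only q/k/v bmm quantizers and no weight_quantizer, so the old weight-only check would always report it as disabled and skip its extra state.

import torch.nn as nn


class FakeTensorQuantizer(nn.Module):
    """Illustrative stand-in for TensorQuantizer with an is_enabled flag."""

    def __init__(self, enabled: bool = True):
        super().__init__()
        self.is_enabled = enabled


class FakeKVCacheAttention(nn.Module):
    """Illustrative stand-in: bmm quantizers for the KV cache, but no weight_quantizer."""

    def __init__(self):
        super().__init__()
        self.q_bmm_quantizer = FakeTensorQuantizer()
        self.k_bmm_quantizer = FakeTensorQuantizer()
        self.v_bmm_quantizer = FakeTensorQuantizer()


module = FakeKVCacheAttention()

# Old check: only weight_quantizer counts, so this module looks disabled.
old_enabled = module.weight_quantizer.is_enabled if hasattr(module, "weight_quantizer") else False
# New check: any enabled child quantizer counts, so its extra state is saved.
new_enabled = any(
    isinstance(child, FakeTensorQuantizer) and child.is_enabled for child in module.children()
)
print(old_enabled, new_enabled)  # False True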
@@ -222,6 +223,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module):
quant_module_get_extra_state,
quant_module_set_extra_state,
)
if HAS_TE and isinstance(module, TEDotProductAttention):
# A hack to set the dtype and device for DotProductAttention
# to be used in _QuantTEDotProductAttention.modelopt_post_restore()
_QuantTEDotProductAttention.set_dtype(module, name, model)

for name, module in model.named_modules():
if isinstance(module, MegatronModule):
@@ -632,152 +637,36 @@ def _setup(self):
self.k_bmm_quantizer = TensorQuantizer()
self.v_bmm_quantizer = TensorQuantizer()

def _calibrate_quantizers(self):
"""Calibrate quantizers with minimal dummy tensors."""
# Get device and dtype from the parent module's parameters
param = next(iter(self.parameters()), None)
device = param.device if param is not None else torch.device("cuda")
dtype = param.dtype if param is not None else torch.float16

# TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
batch_size = 1
seq_len = 1

# Get dimensions from config
num_heads = self.config.num_attention_heads
head_dim = (
self.config.kv_channels
if hasattr(self.config, "kv_channels")
else self.config.hidden_size // num_heads
)

# Determine tensor format (default to sbhd if not specified)
apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
qkv_format = "bshd" if apply_rope_fusion else "sbhd"

if qkv_format == "sbhd":
dummy_tensor = torch.randn(
seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
)
else:
dummy_tensor = torch.randn(
batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
)

# Calibrate each quantizer
quantizers = [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]

for _, quantizer in quantizers:
if quantizer is not None and quantizer.is_enabled():
if not hasattr(quantizer, "_amax") or quantizer._amax is None:
quantizer.reset_amax()
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

def forward(self, query, key, value, *args, **kwargs):
"""Apply post-RoPE quantization to KV cache.

TEDotProductAttention receives Q, K, V after RoPE is applied,
so we quantize them directly for KV cache quantization.
"""
"""Apply post-RoPE quantization to KV cache."""
# Quantize Q, K, V
query = self.q_bmm_quantizer(query)
key = self.k_bmm_quantizer(key)
value = self.v_bmm_quantizer(value)

return super().forward(query, key, value, *args, **kwargs)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Create a sharded state dictionary for distributed checkpointing."""
sharded_state_dict = {}

# First add non-quantizer parameters
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
sharded_state_dict[prefix + k] = v

# Process _amax in bmm_quantizers
for name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if hasattr(quantizer, "_amax") and quantizer._amax is not None:
amax_key = f"{prefix}{name}._amax"
sharded_state_dict[amax_key] = quantizer._amax

# Process other quantizer parameters in bmm_quantizers
quantizer_state_dict = {
k: v
for k, v in self.state_dict(prefix="", keep_vars=True).items()
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}

if quantizer_state_dict:
sharded_state_dict.update(
**make_sharded_tensors_for_checkpoint(
quantizer_state_dict, prefix, {}, sharded_offsets
)
)

return sharded_state_dict

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Handle loading state dict for quantizers."""
for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
full_prefix = f"{prefix}{quantizer_name}."
amax_key = f"{prefix}{quantizer_name}._amax"

# If amax is in state_dict, rename it to the format expected by TensorQuantizer
if amax_key in state_dict:
expected_amax_key = f"{full_prefix}_amax"
state_dict[expected_amax_key] = state_dict.pop(amax_key)

# Handle other quantizer states
for k in list(state_dict.keys()):
if "_quantizer" in k and "_amax" not in k:
name = k.split(prefix)[-1] if prefix else k
if name in self.state_dict():
state_dict[k] = state_dict[k].view_as(self.state_dict()[name])

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

def modelopt_post_restore(self, name=""):
"""Restore quantizer states after model loading."""
super().modelopt_post_restore(name)

def _check_unsupported_states(quantizer):
"""Check for unsupported quantizer states and warn if found."""
if not hasattr(quantizer, "state_dict"):
return

for k in quantizer.state_dict():
if k not in ["_amax", "_pre_quant_scale"]:
warnings.warn(
f"Restore of {k} for {name} is not supported. The restore of this layer might be "
f"incorrect. Please implement a custom restore for {k}."
)

calibration_needed = False

for quantizer_name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
continue

_check_unsupported_states(quantizer)
for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]:
# TODO: Add support for non-scalar states such as
# Affine KVCache bias vector which is per head per channel
assert all(v.numel() == 1 for v in tq.state_dict().values()), (
"Only scalar states are KV Cache/BMM Quantizers"
)
# Should have been set in the `megatron_replace_quant_module_hook`
assert hasattr(self, "device") and hasattr(self, "dtype")
self.to(device=self.device, dtype=self.dtype)

if not hasattr(quantizer, "_amax") or quantizer._amax is None:
calibration_needed = True
@staticmethod
def set_dtype(module: "TEDotProductAttention", name, model: torch.nn.Module):
"""Set the dtype for the module from any parameter in the model.

if calibration_needed:
self._calibrate_quantizers()
DotProductAttention does not have any parameters of its own, so the dtype and device are taken from the parent module.
"""
parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
param = next(iter(parent.parameters()))
module.dtype = param.dtype
module.device = param.device


@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
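A minimal sketch of the dtype/device propagation idea behind set_dtype, assuming only standard PyTorch (stash_dtype_and_device and the toy model are illustrative, not the modelopt API): a parameter-less module such as DotProductAttention borrows dtype and device from the first parameter of its parent so that modelopt_post_restore can later move its quantizer states to the right place.

import torch.nn as nn


def stash_dtype_and_device(module: nn.Module, name: str, model: nn.Module) -> None:
    """Copy dtype/device from the parent of a parameter-less submodule onto that submodule."""
    parent_name = name.rsplit(".", 1)[0] if "." in name else ""
    parent = model.get_submodule(parent_name) if parent_name else model
    param = next(iter(parent.parameters()))  # assumes the parent owns at least one parameter
    module.dtype, module.device = param.dtype, param.device


# Toy usage: the nn.Identity at index 1 has no parameters of its own.
model = nn.Sequential(nn.Linear(4, 4), nn.Identity())
stash_dtype_and_device(model[1], "1", model)
print(model[1].dtype, model[1].device)  # torch.float32 cpu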
tests/_test_utils/torch/megatron/utils.py: 10 changes (10 additions, 0 deletions)
@@ -214,6 +214,16 @@ def convert_maybe_fp8(v):
f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}"
)

# Test backward pass on model_test
model_test.train()
loss = forward_fn(model_test).sum()
loss.backward()

# Assert that trainable parameters have gradients computed
for name, param in model_test.named_parameters():
if param.requires_grad:
assert param.grad is not None, f"Parameter {name} has no gradient computed"
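
A standalone sketch of the same gradient-coverage check, assuming plain PyTorch (assert_all_trainable_params_have_grads is an illustrative helper, not part of the test utilities): run one forward/backward pass and fail if any trainable parameter did not receive a gradient.

import torch
import torch.nn as nn


def assert_all_trainable_params_have_grads(model: nn.Module) -> None:
    """Fail if any trainable parameter is missing a gradient after backward()."""
    missing = [name for name, p in model.named_parameters() if p.requires_grad and p.grad is None]
    assert not missing, f"Parameters with no gradient computed: {missing}"


# Toy usage mirroring the test above: forward, reduce to a scalar loss, backward, then check.
model = nn.Linear(8, 2)
model.train()
loss = model(torch.randn(4, 8)).sum()
loss.backward()
assert_all_trainable_params_have_grads(model)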


def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model):
"""Copy weights from TEGrouped MoE model to sequential MoE model."""