From 0a880c75a337e96599399eaebce98a652fb267cc Mon Sep 17 00:00:00 2001 From: Asma Thekkumpate Date: Wed, 24 Dec 2025 10:32:01 -0800 Subject: [PATCH 1/2] Fixed for MCore KVCache QAD --- .../torch/quantization/plugins/megatron.py | 163 +++--------------- .../quantization/plugins/test_megatron.py | 122 ++----------- 2 files changed, 43 insertions(+), 242 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 95e8651aa..3bdacd367 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -38,7 +38,6 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..model_calib import max_calibrate from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper @@ -98,7 +97,9 @@ def quant_module_get_extra_state(self) -> dict: """ extra_state = {} - is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False + is_enabled = any( + isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children() + ) if not is_enabled: return extra_state @@ -222,6 +223,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module): quant_module_get_extra_state, quant_module_set_extra_state, ) + if HAS_TE and isinstance(module, TEDotProductAttention): + # A hack to set the dtype and device for DotProductAttention + # to be used in _QuantTEDotProductAttention.modelopt_post_restore() + _QuantTEDotProductAttention.set_dtype(module, name, model) for name, module in model.named_modules(): if isinstance(module, MegatronModule): @@ -632,152 +637,36 @@ def _setup(self): self.k_bmm_quantizer = TensorQuantizer() self.v_bmm_quantizer = TensorQuantizer() - def _calibrate_quantizers(self): - """Calibrate quantizers with minimal dummy tensors.""" - # Get device and dtype from the parent module's parameters - param = next(iter(self.parameters()), None) - device = param.device if param is not None else torch.device("cuda") - dtype = param.dtype if param is not None else torch.float16 - - # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion - batch_size = 1 - seq_len = 1 - - # Get dimensions from config - num_heads = self.config.num_attention_heads - head_dim = ( - self.config.kv_channels - if hasattr(self.config, "kv_channels") - else self.config.hidden_size // num_heads - ) - - # Determine tensor format (default to sbhd if not specified) - apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) - qkv_format = "bshd" if apply_rope_fusion else "sbhd" - - if qkv_format == "sbhd": - dummy_tensor = torch.randn( - seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype - ) - else: - dummy_tensor = torch.randn( - batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype - ) - - # Calibrate each quantizer - quantizers = [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ] - - for _, quantizer in quantizers: - if quantizer is not None and quantizer.is_enabled(): - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - quantizer.reset_amax() - max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) - def forward(self, query, key, value, *args, **kwargs): - """Apply post-RoPE quantization to KV cache. 
- - TEDotProductAttention receives Q, K, V after RoPE is applied, - so we quantize them directly for KV cache quantization. - """ + """Apply post-RoPE quantization to KV cache.""" # Quantize Q, K, V query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) value = self.v_bmm_quantizer(value) - return super().forward(query, key, value, *args, **kwargs) - def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): - """Create a sharded state dictionary for distributed checkpointing.""" - sharded_state_dict = {} - - # First add non-quantizer parameters - for k, v in self.state_dict(prefix="", keep_vars=True).items(): - if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: - sharded_state_dict[prefix + k] = v - - # Process _amax in bmm_quantizers - for name, quantizer in [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ]: - if hasattr(quantizer, "_amax") and quantizer._amax is not None: - amax_key = f"{prefix}{name}._amax" - sharded_state_dict[amax_key] = quantizer._amax - - # Process other quantizer parameters in bmm_quantizers - quantizer_state_dict = { - k: v - for k, v in self.state_dict(prefix="", keep_vars=True).items() - if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k - } - - if quantizer_state_dict: - sharded_state_dict.update( - **make_sharded_tensors_for_checkpoint( - quantizer_state_dict, prefix, {}, sharded_offsets - ) - ) - - return sharded_state_dict - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - """Handle loading state dict for quantizers.""" - for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: - full_prefix = f"{prefix}{quantizer_name}." - amax_key = f"{prefix}{quantizer_name}._amax" - - # If amax is in state_dict, rename it to the format expected by TensorQuantizer - if amax_key in state_dict: - expected_amax_key = f"{full_prefix}_amax" - state_dict[expected_amax_key] = state_dict.pop(amax_key) - - # Handle other quantizer states - for k in list(state_dict.keys()): - if "_quantizer" in k and "_amax" not in k: - name = k.split(prefix)[-1] if prefix else k - if name in self.state_dict(): - state_dict[k] = state_dict[k].view_as(self.state_dict()[name]) - - super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - def modelopt_post_restore(self, name=""): """Restore quantizer states after model loading.""" - super().modelopt_post_restore(name) - - def _check_unsupported_states(quantizer): - """Check for unsupported quantizer states and warn if found.""" - if not hasattr(quantizer, "state_dict"): - return - - for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: - warnings.warn( - f"Restore of {k} for {name} is not supported. The restore of this layer might be " - f"incorrect. Please implement a custom restore for {k}." 
-                    )
-
-        calibration_needed = False
-
-        for quantizer_name, quantizer in [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
-                continue
-
-            _check_unsupported_states(quantizer)
+        for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]:
+            # TODO: Add support for non-scalar states such as the affine
+            # KVCache bias vector, which is per-head, per-channel
+            assert all(v.numel() == 1 for v in tq.state_dict().values()), (
+                "Only scalar states are supported for KV Cache/BMM Quantizers"
+            )
+        # Should have been set in the `megatron_replace_quant_module_hook`
+        assert hasattr(self, "device") and hasattr(self, "dtype")
+        self.to(device=self.device, dtype=self.dtype)
 
-            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
-                calibration_needed = True
+    @staticmethod
+    def set_dtype(module: "TEDotProductAttention", name, model: torch.nn.Module):
+        """Set the dtype and device for the module from a parameter of its parent module.
 
-        if calibration_needed:
-            self._calibrate_quantizers()
+        DotProductAttention does not have any parameters, so let's get the parameter from the parent module.
+        """
+        parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
+        param = next(iter(parent.parameters()))
+        module.dtype = param.dtype
+        module.device = param.device
 
 
 @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py
index 5b2a8cc0a..57513d7ad 100644
--- a/tests/gpu/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -308,7 +308,7 @@ def _gpt_model_provider(
 
 
 def _test_sharded_state_dict(
-    tmp_path, config, hidden_size, modelopt_version, compress, meta_device, moe_config, rank, size
+    tmp_path, config, hidden_size, modelopt_version, compress, meta_device, model_config, rank, size
 ):
     # Must disable output_layer quantization since output_layer amax cannot be restore via
     # sharded_state_dict. All output_layer quantizers state are removed.
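For reference, the parent-parameter lookup that set_dtype performs above can be reproduced with plain PyTorch. The sketch below is illustrative only and not part of the patch; the toy Layer/Attn modules are hypothetical stand-ins for an MCore transformer layer and TEDotProductAttention, which owns no parameters and therefore borrows dtype and device from the first parameter of its parent submodule so that modelopt_post_restore can later move the bmm quantizers onto the right device.

import torch
import torch.nn as nn


class Attn(nn.Module):
    """Hypothetical stand-in for TEDotProductAttention: owns no parameters."""


class Layer(nn.Module):
    """Hypothetical parent module that owns the surrounding attention weights."""

    def __init__(self):
        super().__init__()
        self.linear_qkv = nn.Linear(8, 24, dtype=torch.bfloat16)
        self.core_attention = Attn()


model = nn.ModuleDict({"decoder": Layer()})
name = "decoder.core_attention"
module = model.get_submodule(name)

# Same lookup as set_dtype: resolve the parent from the dotted module name and
# borrow the dtype/device of its first parameter.
parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
param = next(iter(parent.parameters()))
module.dtype, module.device = param.dtype, param.device
print(module.dtype, module.device)  # torch.bfloat16 cpu, inherited from linear_qkv.weight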
@@ -318,13 +318,13 @@ def _test_sharded_state_dict( mto.conversion.__version__ = modelopt_version mtq.plugins.megatron.__version__ = modelopt_version - tp_size = moe_config.get("tp_size", size) - ep_size = moe_config.get("ep_size", 1) - etp_size = moe_config.get("etp_size", None) - num_moe_experts = moe_config.get("num_moe_experts", None) - moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) - use_te = moe_config.get("use_te", False) - transformer_impl = moe_config.get("transformer_impl", "local") + tp_size = model_config.get("tp_size", size) + ep_size = model_config.get("ep_size", 1) + etp_size = model_config.get("etp_size", None) + num_moe_experts = model_config.get("num_moe_experts", None) + moe_grouped_gemm = model_config.get("moe_grouped_gemm", False) + use_te = model_config.get("use_te", False) + transformer_impl = model_config.get("transformer_impl", "local") initialize_for_megatron( tensor_model_parallel_size=tp_size, @@ -424,8 +424,8 @@ def forward_fn(model): mtq.W4A8_AWQ_BETA_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, - # Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant - # They require TEDotProductAttention which needs transformer_impl="modelopt", not "local" + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, ], ) @pytest.mark.parametrize("compress", [False, True]) @@ -827,100 +827,6 @@ def forward_fn(model): assert output is not None, "Forward pass failed" -def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): - """Helper for testing KV cache quantization with sharded state dict save/load.""" - # Disable output_layer quantization (same as other sharded state dict tests) - config["quant_cfg"]["*output_layer*"] = {"enable": False} - - initialize_for_megatron( - tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED - ) - - # Create GPT models with TEDotProductAttention (transformer_impl="modelopt") - model_ref = get_mcore_gpt_model( - tensor_model_parallel_size=size, - num_layers=2, # At least 2 layers to test multiple attention modules - hidden_size=64, - num_attention_heads=4, - vocab_size=64, - transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention - ).cuda() - - model_test = get_mcore_gpt_model( - tensor_model_parallel_size=size, - num_layers=2, - hidden_size=64, - num_attention_heads=4, - vocab_size=64, - transformer_impl="modelopt", - ).cuda() - - prompt_tokens = torch.randint( - 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) - ).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - - # Quantize the reference model - model_ref = mtq.quantize(model_ref, config, forward_fn) - - # CRITICAL: model_test must also be quantized with the same config - # Otherwise it won't have the KV cache quantizer keys when loading state dict - model_test = mtq.quantize(model_test, config, forward_fn) - - # Verify KV cache quantizers were created - kv_quantizers_found = False - for name, module in model_ref.named_modules(): - if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): - kv_quantizers_found = True - assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" - assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" - - assert kv_quantizers_found, "No KV cache quantizers found in quantized model" - - # Test sharded state dict save/load - sharded_state_dict_test_helper( - tmp_path, - model_ref, - model_test, - forward_fn, - meta_device=False, - version=None, - ) - 
- # Verify KV cache quantizers are restored correctly in model_test - for (name_ref, module_ref), (name_test, module_test) in zip( - model_ref.named_modules(), model_test.named_modules() - ): - if hasattr(module_ref, "k_bmm_quantizer"): - assert hasattr(module_test, "k_bmm_quantizer"), ( - f"K quantizer missing after restore in {name_test}" - ) - assert hasattr(module_test, "v_bmm_quantizer"), ( - f"V quantizer missing after restore in {name_test}" - ) - - # Check that quantizer states match - if hasattr(module_ref.k_bmm_quantizer, "_amax"): - assert hasattr(module_test.k_bmm_quantizer, "_amax"), ( - f"K quantizer _amax missing in {name_test}" - ) - if module_ref.k_bmm_quantizer._amax is not None: - assert torch.allclose( - module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax - ), f"K quantizer _amax mismatch in {name_test}" - - if hasattr(module_ref.v_bmm_quantizer, "_amax"): - assert hasattr(module_test.v_bmm_quantizer, "_amax"), ( - f"V quantizer _amax missing in {name_test}" - ) - if module_ref.v_bmm_quantizer._amax is not None: - assert torch.allclose( - module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax - ), f"V quantizer _amax mismatch in {name_test}" - - @pytest.mark.parametrize( "config", [ @@ -955,9 +861,15 @@ def test_kv_cache_sharded_state_dict(tmp_path, config): preserved across the save/load cycle. """ size = min(2, torch.cuda.device_count()) # Use 2 GPUs if available, else 1 + model_config = { + "transformer_impl": "modelopt", + "use_te": True, + } spawn_multiprocess_job( size=size, - job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), + job=partial( + _test_sharded_state_dict, tmp_path, config, 256, None, False, False, model_config + ), backend="nccl", ) From c8d7c517014652782b31ab123151f7b4a47e0e6c Mon Sep 17 00:00:00 2001 From: Asma Thekkumpate Date: Wed, 24 Dec 2025 13:41:38 -0800 Subject: [PATCH 2/2] updated/cleaned up tests --- tests/_test_utils/torch/megatron/utils.py | 10 ++ .../quantization/plugins/test_megatron.py | 110 +++++++++--------- 2 files changed, 65 insertions(+), 55 deletions(-) diff --git a/tests/_test_utils/torch/megatron/utils.py b/tests/_test_utils/torch/megatron/utils.py index 5ca0cf14c..bb91f83cd 100644 --- a/tests/_test_utils/torch/megatron/utils.py +++ b/tests/_test_utils/torch/megatron/utils.py @@ -214,6 +214,16 @@ def convert_maybe_fp8(v): f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}" ) + # Test backward pass on model_test + model_test.train() + loss = forward_fn(model_test).sum() + loss.backward() + + # Assert that trainable parameters have gradients computed + for name, param in model_test.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient computed" + def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model): """Copy weights from TEGrouped MoE model to sequential MoE model.""" diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 57513d7ad..5426f6a9c 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -52,7 +52,6 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.megatron import _QuantTEMCoreRowParallelLinear -from modelopt.torch.utils.plugins import megatron_prefill try: from megatron.core.extensions.transformer_engine import 
TERowParallelLinear @@ -81,6 +80,22 @@ def get_batch(model, batch_size=2): return input_ids, labels, position_ids, attention_mask, loss_mask +def get_forward(model, batch_size=2): + """Return a forward function with cached batch inputs.""" + input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) + + def forward(model): + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + return forward + + def test_convert_megatron_parallel_linear(distributed_setup_size_1): initialize_for_megatron(seed=SEED) set_seed(SEED) @@ -357,14 +372,8 @@ def _test_sharded_state_dict( transformer_impl=transformer_impl, ) - prompt_tokens = torch.randint( - 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) - ).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - - model_ref = mtq.quantize(model_ref, config, forward_fn) + forward = get_forward(model_ref) + model_ref = mtq.quantize(model_ref, config, forward) if compress: mtq.compress(model_ref) @@ -376,7 +385,7 @@ def forward_fn(model): tmp_path, model_ref, model_test, - forward_fn, + forward, meta_device=meta_device, version=modelopt_version, ) @@ -430,16 +439,27 @@ def forward_fn(model): ) @pytest.mark.parametrize("compress", [False, True]) @pytest.mark.parametrize("meta_device", [False, True]) -def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device): +@pytest.mark.parametrize("transformer_impl", ["local", "modelopt"]) +def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device, transformer_impl): if compress and config is mtq.W4A8_AWQ_BETA_CFG: pytest.skip("W4A8_AWQ_BETA_CFG is not supported for compress") size = torch.cuda.device_count() + model_config = {"transformer_impl": transformer_impl} + if transformer_impl == "modelopt": + model_config["use_te"] = True spawn_multiprocess_job( size=size, job=partial( - _test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device, {} + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + compress, + meta_device, + model_config, ), backend="nccl", ) @@ -534,16 +554,13 @@ def _test_fp8_real_quantize_helper(rank, size): config = mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size) - prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - forward_fn(model) + forward = get_forward(model) + forward(model) # real quant the model cur_mem = get_model_size(model) - real_quant_model = mtq.quantize(model, config, forward_fn) + real_quant_model = mtq.quantize(model, config, forward) mtq.compress(real_quant_model) real_quant_mem = get_model_size(real_quant_model) @@ -551,7 +568,7 @@ def forward_fn(model): assert real_quant_mem < (cur_mem / 2) * 1.1, "Memory after real quantization is not reduced." 
# check forward works after real quantization - forward_fn(real_quant_model) + forward(real_quant_model) assert real_quant_mem < cur_mem @@ -606,12 +623,6 @@ def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, r seed=SEED, ) - # Create input - prompt_tokens = torch.randint(0, 64, (2, 16)).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - # Create TEGrouped MoE model te_grouped_moe_model = _gpt_model_provider( tp_size=tp_size, @@ -622,6 +633,10 @@ def forward_fn(model): use_te=True, num_moe_experts=4, ) + + # Create forward function with cached inputs + forward = get_forward(te_grouped_moe_model) + num_te_grouped_mlp = sum( isinstance(module, TEGroupedMLP) for module in te_grouped_moe_model.modules() ) @@ -649,19 +664,19 @@ def forward_fn(model): copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model) # Compare model outputs before quantization - te_grouped_moe_output = forward_fn(te_grouped_moe_model) - sequential_moe_output = forward_fn(sequential_moe_model) + te_grouped_moe_output = forward(te_grouped_moe_model) + sequential_moe_output = forward(sequential_moe_model) assert torch.allclose(te_grouped_moe_output, sequential_moe_output, atol=1e-6, rtol=1e-6) # Quantize grouped model - mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward) # Quantize non-grouped model - mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward) # Compare model outputs after quantization - te_grouped_moe_quant_output = forward_fn(te_grouped_moe_model) - sequential_moe_quant_output = forward_fn(sequential_moe_model) + te_grouped_moe_quant_output = forward(te_grouped_moe_model) + sequential_moe_quant_output = forward(sequential_moe_model) assert torch.allclose( te_grouped_moe_quant_output, sequential_moe_quant_output, atol=1e-6, rtol=1e-6 ) @@ -716,18 +731,15 @@ def _test_expert_model_parallel_amax_sync( param.data.fill_(const_val) weight_idx += 1 - prompt_tokens = (torch.ones((2, model.max_sequence_length)) * 0.05 + rank * 0.5).cuda().long() - # force all expert routing for module in model.modules(): if isinstance(module, TopKRouter): module.topk = module.num_experts - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) + forward = get_forward(model) # quantize the model - model = mtq.quantize(model, config, forward_fn) + model = mtq.quantize(model, config, forward) # Check initial sync status initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert initial_sync, ( @@ -735,7 +747,7 @@ def forward_fn(model): ) # Test if the amax values are inconsistent when distributed sync is disabled - mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=False) + mtq.model_calib.max_calibrate(model, forward, distributed_sync=False) inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel( model, compare_across_experts=False ) @@ -745,7 +757,7 @@ def forward_fn(model): "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" ) # calibrate the model with distributed sync and test synchronization - mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=True) + mtq.model_calib.max_calibrate(model, forward, distributed_sync=True) for module in model.modules(): if hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() @@ -798,14 
+810,11 @@ def _test_kv_cache_quant_helper(config, rank, size): transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec ).cuda() - # Create dummy input for calibration - prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) + # Create forward function with cached inputs + forward = get_forward(model) # Test KV cache quantization with the given config - quantized_model = mtq.quantize(model, config, forward_fn) + quantized_model = mtq.quantize(model, config, forward) # Find TEDotProductAttention modules and verify they have KV cache quantizers te_attention_found = False @@ -823,7 +832,7 @@ def forward_fn(model): assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model" # Quick smoke test that forward still works - output = forward_fn(quantized_model) + output = forward(quantized_model) assert output is not None, "Forward pass failed" @@ -880,16 +889,7 @@ def test_convert_mcore_te_gpt_model(distributed_setup_size_1): initialize_for_megatron(tensor_model_parallel_size=1, seed=SEED) model = get_mcore_gpt_model(tensor_model_parallel_size=1, transformer_impl="transformer_engine") - input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model) - - def forward(model): - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) + forward = get_forward(model) for name, param in model.named_parameters(): param.requires_grad = True