
Conversation

@jenchen13
Contributor

@jenchen13 jenchen13 commented Dec 10, 2025

What does this PR do?

Type of change: Bug fix

Fixes a bug when resuming training from a KV-cache-quantized checkpoint by writing the extra state for core_attention into the checkpoint.

12: [rank12]:   File "/opt/megatron-lm/megatron/training/checkpointing.py", line 924, in _load_global_dist_base_checkpoint
12: [rank12]:     state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness)
12: [rank12]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
115: [rank115]:   File "/usr/local/lib/python3.12/dist-packages/modelopt/torch/distill/distillation_model.py", line 207, in load_state_dict
115: [rank115]:     return super().load_state_dict(state_dict, *args, **kwargs)
115: [rank115]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
115: [rank115]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 2593, in load_state_dict
115: [rank115]:     raise RuntimeError(
115: [rank115]: RuntimeError: Error(s) in loading state_dict for DistillMambaModel:
115: [rank115]:     Unexpected key(s) in state_dict: "decoder.layers.5.self_attention.core_attention.k_bmm_quantizer._amax"

Overview: ?

Usage

# Add a code snippet demonstrating how to use this
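
A minimal sketch of the quantize → save → resume round-trip this fix targets, using a toy torch module and ModelOpt's generic mto.save/mto.restore helpers as stand-ins; the real flow in the traceback above goes through Megatron dist checkpointing and an attention module with k/v bmm quantizers. The config choice, file name, and calibration loop are illustrative assumptions, not taken from this PR.

# Illustrative sketch only (not from this PR): quantize, save, and restore
# into a fresh, un-modified model. The actual bug concerns Megatron
# core_attention modules whose k/v bmm quantizer extra state is now written
# to the checkpoint.
import torch
import torch.nn as nn

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

def forward_fn(m):
    # One calibration pass over random data so the quantizers collect amax.
    m(torch.randn(8, 16))

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_fn)

# Save ModelOpt state + weights, then restore onto an un-modified model.
mto.save(model, "quantized_ckpt.pt")
fresh = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
restored = mto.restore(fresh, "quantized_ckpt.pt")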

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
@jenchen13 jenchen13 requested a review from a team as a code owner December 10, 2025 23:20
@jenchen13 jenchen13 requested a review from realAsma December 10, 2025 23:20
@copy-pr-bot

copy-pr-bot bot commented Dec 10, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@codecov

codecov bot commented Dec 10, 2025

Codecov Report

❌ Patch coverage is 20.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.71%. Comparing base (53a2dde) to head (24baf8e).
⚠️ Report is 44 commits behind head on main.

Files with missing lines Patch % Lines
.../torch/quantization/nn/modules/tensor_quantizer.py 33.33% 2 Missing ⚠️
modelopt/torch/quantization/tensor_quant.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #673      +/-   ##
==========================================
+ Coverage   74.50%   74.71%   +0.21%     
==========================================
  Files         183      192       +9     
  Lines       18400    18941     +541     
==========================================
+ Hits        13709    14152     +443     
- Misses       4691     4789      +98     

☔ View full report in Codecov by Sentry.

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
Signed-off-by: jenchen13 <jennifchen@nvidia.com>
@jenchen13
Contributor Author

jenchen13 commented Dec 15, 2025

TODO: fix the unit test to work while commenting out quantization of the test model (https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/gpu/torch/quantization/plugins/test_megatron.py#L870)
TODO: also add an E2E unit test on a tiny Llama

@jenchen13
Contributor Author

TODO: currently kv_scale keys are not being exported, which means that if the model is served it will use the default KV scale of 1. Fix this bug (a quick check is sketched below).
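
A hedged sketch of how the missing export could be checked; the helper below is hypothetical (not a ModelOpt API), and the key substrings come from the quantizer names in the traceback above.

# Hypothetical helper (not part of ModelOpt): list KV-cache quantizer entries
# in a state dict. If an exported checkpoint has none of these keys, serving
# would fall back to the default KV scale of 1.0 described above.
def find_kv_quantizer_keys(state_dict):
    return [
        key
        for key in state_dict
        if "k_bmm_quantizer" in key or "v_bmm_quantizer" in key
    ]

# Example on an already-quantized model:
#   assert find_kv_quantizer_keys(model.state_dict()), "KV scales missing from export"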

@jenchen13
Contributor Author

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
@jenchen13 jenchen13 changed the title from "Write extra state for KV quantizer" to "Fix KV cache quantization bugs" on Dec 22, 2025
Signed-off-by: jenchen13 <jennifchen@nvidia.com>
Comment on lines +98 to +102
weight_quantizer_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
# TODO is checking just k enough?
k_bmm_quantizer_enabled = self.k_bmm_quantizer.is_enabled if hasattr(self, "k_bmm_quantizer") else False
v_bmm_quantizer_enabled = self.v_bmm_quantizer.is_enabled if hasattr(self, "v_bmm_quantizer") else False
is_enabled = weight_quantizer_enabled or k_bmm_quantizer_enabled or v_bmm_quantizer_enabled
Contributor


why not do:

Suggested change
weight_quantizer_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
# TODO is checking just k enough?
k_bmm_quantizer_enabled = self.k_bmm_quantizer.is_enabled if hasattr(self, "k_bmm_quantizer") else False
v_bmm_quantizer_enabled = self.v_bmm_quantizer.is_enabled if hasattr(self, "v_bmm_quantizer") else False
is_enabled = weight_quantizer_enabled or k_bmm_quantizer_enabled or v_bmm_quantizer_enabled
is_enabled = any(isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children())

Comment on lines +677 to +679
query = materialize_if_needed(query)
key = materialize_if_needed(key)
value = materialize_if_needed(value)
Contributor


Do we need this if we are calling inputs = inputs.contiguous() in TensorQuantize forward?

Comment on lines 838 to -843
model_ref = mtq.quantize(model_ref, config, forward_fn)

# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
model_test = mtq.quantize(model_test, config, forward_fn)

Contributor


@kaix-nv this is an incorrect unit test. It completely breaks the ModelOpt resume workflow (resume requires a model that has not been modified by ModelOpt).
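
A sketch of the resume pattern described in the comment above, assuming the test's model_ref, model_test, config, and forward_fn variables and ModelOpt's state helpers; this is not the exact test code in the PR.

# Sketch only: quantize the reference model, then rebuild the quantizer
# modules on an un-modified model_test from model_ref's ModelOpt state
# before loading the state dict, instead of calling mtq.quantize twice.
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

model_ref = mtq.quantize(model_ref, config, forward_fn)

model_test = mto.restore_from_modelopt_state(model_test, mto.modelopt_state(model_ref))
model_test.load_state_dict(model_ref.state_dict())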

