Fix KV cache quantization bugs #673
base: main
First changed file:
@@ -85,7 +85,7 @@ def real_quant_module_get_extra_state(self) -> dict:
 def quant_module_get_extra_state(self) -> dict:
     """Populating the extra_state when state_dict() is called.

-    quantizer_state, real_quantizer_state, and q_tensor_state are usually stored
+    quantizer_state, real_quantizer_state, and q_tensor_state used to be stored (before 0.29)
     with in the modelopt_state metadata where the keys are the full module name. The issue
     is that NeMo-MCore model's full module name can change
     if pipeline-parallelism (PP) and expert-parallelism (EP)
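Background for the docstring change above: per-module extra state returned by `nn.Module.get_extra_state()` is saved under the module's own prefix in the state dict, so it stays attached to the module even when global module names shift across PP/EP layouts, unlike a globally keyed modelopt_state blob. A minimal plain-PyTorch sketch of that mechanism (toy classes, not ModelOpt code):

```python
import torch.nn as nn


class WithExtraState(nn.Module):
    """Toy module that round-trips a dict through the extra-state hooks."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def get_extra_state(self):
        # Called by state_dict(); stored under "<module prefix>._extra_state".
        return {"quantizer_state": {"enabled": True}}

    def set_extra_state(self, state):
        # Called by load_state_dict() with whatever get_extra_state() produced.
        self.restored = state


model = nn.Sequential(WithExtraState())
sd = model.state_dict()
assert "0._extra_state" in sd  # keyed relative to the module, not a global name
model.load_state_dict(sd)
assert model[0].restored["quantizer_state"]["enabled"]
```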
@@ -95,7 +95,11 @@ def quant_module_get_extra_state(self) -> dict:
     """
     extra_state = {}

-    is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
+    weight_quantizer_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
+    # TODO is checking just k enough?
+    k_bmm_quantizer_enabled = self.k_bmm_quantizer.is_enabled if hasattr(self, "k_bmm_quantizer") else False
+    v_bmm_quantizer_enabled = self.v_bmm_quantizer.is_enabled if hasattr(self, "v_bmm_quantizer") else False
+    is_enabled = weight_quantizer_enabled or k_bmm_quantizer_enabled or v_bmm_quantizer_enabled

     if not is_enabled:
         return extra_state
@@ -109,7 +113,6 @@ def quant_module_get_extra_state(self) -> dict:
     # Handle real_quantizer_state and q_tensor_state
     extra_state.update(real_quant_module_get_extra_state(self))

     return extra_state
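The enabled-check change above matters because a KV-cache-only quantization config leaves attention modules with k/v BMM quantizers but no `weight_quantizer`, so the old check reported the module as disabled and its extra state was silently dropped. A self-contained sketch with stand-in classes (not the real TensorQuantizer):

```python
class FakeQuantizer:
    """Stand-in for TensorQuantizer; is_enabled is a plain attribute here."""

    def __init__(self, enabled=True):
        self.is_enabled = enabled


class KVCacheOnlyAttention:
    """Attention module quantized with a KV-cache-only config: no weight_quantizer."""

    def __init__(self):
        self.k_bmm_quantizer = FakeQuantizer()
        self.v_bmm_quantizer = FakeQuantizer()


module = KVCacheOnlyAttention()

# Old check: only weight_quantizer is consulted, so this module looks disabled
# and quant_module_get_extra_state() would return an empty dict.
old_enabled = module.weight_quantizer.is_enabled if hasattr(module, "weight_quantizer") else False

# New check: any of weight / k_bmm / v_bmm quantizers being enabled is enough.
new_enabled = any(
    getattr(module, name).is_enabled
    for name in ("weight_quantizer", "k_bmm_quantizer", "v_bmm_quantizer")
    if hasattr(module, name)
)

assert old_enabled is False and new_enabled is True
```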
@@ -652,7 +655,7 @@ def _calibrate_quantizers(self):
         ]

         for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled():
+            if quantizer is not None and quantizer.is_enabled:
                 if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                     quantizer.reset_amax()
                 max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
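The `is_enabled()` to `is_enabled` changes in this hunk (and in `_check_unsupported_states` below) suggest `is_enabled` is exposed as a property or attribute rather than a method, so calling the returned bool raises a TypeError. A standalone sketch of that failure mode, using a hypothetical quantizer class rather than the real TensorQuantizer:

```python
class Quantizer:
    """Hypothetical stand-in whose is_enabled is a property returning a bool."""

    def __init__(self, enabled=True):
        self._enabled = enabled

    @property
    def is_enabled(self):
        return self._enabled


q = Quantizer()
assert q.is_enabled  # correct: plain property access

try:
    q.is_enabled()  # buggy spelling: attempts to call the returned bool
except TypeError as exc:
    print(exc)  # "'bool' object is not callable"
```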
@@ -663,6 +666,18 @@ def forward(self, query, key, value, *args, **kwargs):
         TEDotProductAttention receives Q, K, V after RoPE is applied,
         so we quantize them directly for KV cache quantization.
         """
+        # Ensure tensors are contiguous before quantization
+        # This is a safety measure for potential non-contiguous tensor views
+        # from TE or Megatron operations with tensor parallelism
+        def materialize_if_needed(tensor):
+            if tensor is not None and hasattr(tensor, 'is_contiguous') and not tensor.is_contiguous():
+                return tensor.contiguous()
+            return tensor
+
+        query = materialize_if_needed(query)
+        key = materialize_if_needed(key)
+        value = materialize_if_needed(value)
+
Review comment on lines +677 to +679 (Contributor): Do we need this if we are calling …
         # Quantize Q, K, V
         query = self.q_bmm_quantizer(query)
         key = self.k_bmm_quantizer(key)
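As background for the review question on the added `materialize_if_needed` helper: `torch.Tensor.contiguous()` already returns the tensor itself when it is contiguous, so the explicit `is_contiguous()` guard avoids only a cheap no-op call, not an extra copy. A small sketch:

```python
import torch

x = torch.randn(4, 8)
assert x.contiguous() is x  # already contiguous: contiguous() returns self, no copy

y = x.t()  # transposed view shares storage but is not contiguous
assert not y.is_contiguous()

z = y.contiguous()  # materializes a contiguous copy
assert z.is_contiguous() and z is not y
```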
@@ -709,7 +724,7 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         """Handle loading state dict for quantizers."""
         for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
             full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
+            amax_key = f"{full_prefix}_amax"

             # If amax is in state_dict, rename it to the format expected by TensorQuantizer
             if amax_key in state_dict:
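Note that with `full_prefix = f"{prefix}{quantizer_name}."`, the old and new `amax_key` expressions build exactly the same string, so this hunk reads as a de-duplication refactor rather than a behavior change. The quick check below uses made-up prefix values:

```python
prefix, quantizer_name = "decoder.layers.0.self_attention.", "k_bmm_quantizer"  # illustrative only
full_prefix = f"{prefix}{quantizer_name}."

assert f"{prefix}{quantizer_name}._amax" == f"{full_prefix}_amax"
print(f"{full_prefix}_amax")  # decoder.layers.0.self_attention.k_bmm_quantizer._amax
```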
@@ -748,7 +763,7 @@ def _check_unsupported_states(quantizer):
             ("k_bmm_quantizer", self.k_bmm_quantizer),
             ("v_bmm_quantizer", self.v_bmm_quantizer),
         ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
                 continue

             _check_unsupported_states(quantizer)
Second changed file (unit test):
@@ -836,11 +836,7 @@ def forward_fn(model):
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)

-    # CRITICAL: model_test must also be quantized with the same config
-    # Otherwise it won't have the KV cache quantizer keys when loading state dict
-    model_test = mtq.quantize(model_test, config, forward_fn)
-
Review comment on lines 838 to -843 (Contributor): @kaix-nv this is an incorrect unit test. It completely breaks the ModelOpt resume workflow (that is, resume requires a model that has not been modified by ModelOpt).
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():

@@ -851,6 +847,10 @@ def forward_fn(model):
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

+    # CRITICAL: model_test must also be quantized with the same config
+    # Otherwise it won't have the KV cache quantizer keys when loading state dict
+    # model_test = mtq.quantize(model_test, config, forward_fn)
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
Review comment (Contributor): why not do: …
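The suggestion after "why not do:" is cut off in this view. For context on the resume workflow the earlier comment refers to, here is a hedged sketch of restoring the quantizer structure onto an unmodified `model_test` instead of quantizing it a second time. It reuses the test's own `model_ref`, `model_test`, `config`, and `forward_fn`, and the `modelopt.torch.opt` state APIs named here are from memory and should be checked against the installed ModelOpt version:

```python
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

# Quantize only the reference model, as the test already does.
model_ref = mtq.quantize(model_ref, config, forward_fn)

# Capture the ModelOpt state (quantizer structure and config) from the quantized model...
modelopt_state = mto.modelopt_state(model_ref)

# ...and restore it onto the untouched model_test, which recreates the
# q/k/v_bmm_quantizer modules so the sharded state dict keys match at load time.
model_test = mto.restore_from_modelopt_state(model_test, modelopt_state)
```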