[Minor] Improve local hessian and mse calibration #976
Conversation
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
📝 Walkthrough
The changes update quantization calibration behavior so that MSE calibration is applied exclusively to weight quantizers, while activation quantizers retain their max-calibration amax values. A new utility function replaces MSE calibrators with max calibrators after the calibration steps, and docstrings clarify this weight-only calibration scope across the MSE and Hessian calibration configurations.
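The utility described in the walkthrough (swapping MSE calibrators for max calibrators on activation quantizers after calibration) can be sketched roughly as follows. This is a hedged illustration only: `MaxCalibrator`, `MseCalibrator`, `FakeQuantizer`, and `restrict_mse_to_weights` are stand-in names defined here, not ModelOpt's actual API.

```python
# Illustrative sketch of the weight-only MSE calibration policy.
# All class and function names below are hypothetical stand-ins.

class MaxCalibrator:
    """Tracks the running max of absolute values (amax)."""

    def __init__(self):
        self.amax = 0.0

    def collect(self, values):
        self.amax = max(self.amax, max(abs(v) for v in values))


class MseCalibrator(MaxCalibrator):
    """Placeholder for a search-based MSE calibrator."""


class FakeQuantizer:
    def __init__(self, is_weight):
        self.is_weight = is_weight
        self.calibrator = MseCalibrator()


def restrict_mse_to_weights(quantizers):
    """After calibration, keep MSE only on weight quantizers;
    activation quantizers fall back to plain max calibration,
    preserving the amax already collected."""
    for q in quantizers:
        if not q.is_weight and isinstance(q.calibrator, MseCalibrator):
            amax = q.calibrator.amax          # keep the max-calibrated amax
            q.calibrator = MaxCalibrator()
            q.calibrator.amax = amax
    return quantizers


qs = [FakeQuantizer(is_weight=True), FakeQuantizer(is_weight=False)]
qs[1].calibrator.collect([-3.0, 2.0])
restrict_mse_to_weights(qs)
print(type(qs[0].calibrator).__name__)  # MseCalibrator
print(type(qs[1].calibrator).__name__)  # MaxCalibrator
print(qs[1].calibrator.amax)            # 3.0
```

The key design point mirrored here is that the swap is one-directional: weight quantizers keep their MSE calibrators, while activation quantizers lose theirs but keep the collected amax.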
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.
❌ Failed checks (1 error)
✅ Passed checks (3 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 580-615: The captured_quantized[0] may still be None if
capture_forward is not invoked during _forward_no_local_hessian; add a defensive
guard before calling self.hessian_helper.accumulate_hessian: check
captured_quantized[0] is not None and only then call
accumulate_hessian(captured_quantized[0]); if it is None fall back to the raw
input path (use input.to_local() if available) and then restore
self.input_quantizer.forward to original_forward if set; update the block that
follows out = self._forward_no_local_hessian(...) to perform this None-check and
fallback while preserving restoration of the original_forward.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c0dbb264-f19a-4f5d-a597-2d4860ebdde3
📒 Files selected for processing (2)
- modelopt/torch/quantization/config.py
- modelopt/torch/quantization/model_calib.py
The code under review in `modelopt/torch/quantization/model_calib.py`:

```python
# Forward without weight quantization during caching
if LocalHessianHelper.cache_mode:
    self.weight_quantizer.disable()

# Capture quantized input from the forward pass for Hessian collection
captured_quantized = [None]
original_forward = None
if (
    self.hessian_helper.is_enabled
    and hasattr(self, "input_quantizer")
    and self.input_quantizer.is_enabled
):
    original_forward = self.input_quantizer.forward

    def capture_forward(input_tensor):
        quantized = original_forward(input_tensor)
        captured_quantized[0] = (
            quantized.to_local() if hasattr(quantized, "to_local") else quantized
        )
        return quantized

    self.input_quantizer.forward = capture_forward

out = self._forward_no_local_hessian(input, *args, **kwargs)

# Collect Hessian from the quantized input that was used in forward pass
if self.hessian_helper.is_enabled:
    if hasattr(self, "input_quantizer") and self.input_quantizer.is_enabled:
        self.hessian_helper.accumulate_hessian(captured_quantized[0])
        if original_forward is not None:
            self.input_quantizer.forward = original_forward
    else:
        # No input_quantizer, use raw input
        input_local = input.to_local() if hasattr(input, "to_local") else input
        self.hessian_helper.accumulate_hessian(input_local)
```
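The capture pattern in the code above (temporarily wrapping a quantizer's forward to record its output, then restoring it) can be demonstrated in isolation. This is a minimal, self-contained sketch with stand-in classes, not the ModelOpt implementation.

```python
# Minimal demo of the forward-capture monkey-patch pattern:
# wrap forward, run the layer, read the captured value, restore.

class InputQuantizer:
    def forward(self, x):
        return [round(v) for v in x]  # stand-in for quantization


class Layer:
    def __init__(self):
        self.input_quantizer = InputQuantizer()

    def _forward(self, x):
        return sum(self.input_quantizer.forward(x))


layer = Layer()
captured = [None]                     # mutable cell, as in the snippet above
original_forward = layer.input_quantizer.forward


def capture_forward(x):
    quantized = original_forward(x)
    captured[0] = quantized           # stash the quantized input
    return quantized


layer.input_quantizer.forward = capture_forward
out = layer._forward([0.4, 1.6])
layer.input_quantizer.forward = original_forward  # always restore

print(captured[0], out)  # [0, 2] 2
```

The single-element list acts as a mutable cell so the closure can write a result visible outside itself; restoring the original forward afterward keeps the patch scoped to one forward pass.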
Consider adding a guard for captured_quantized[0] being None.
The monkey-patching approach to capture quantized inputs is sound and correctly restores the original forward method. However, if capture_forward is never invoked during _forward_no_local_hessian (an unlikely edge case), captured_quantized[0] would remain None, causing accumulate_hessian to fail on input_tensor.reshape().
Consider adding a defensive check:
🛡️ Proposed defensive guard

```diff
 # Collect Hessian from the quantized input that was used in forward pass
 if self.hessian_helper.is_enabled:
     if hasattr(self, "input_quantizer") and self.input_quantizer.is_enabled:
-        self.hessian_helper.accumulate_hessian(captured_quantized[0])
+        if captured_quantized[0] is not None:
+            self.hessian_helper.accumulate_hessian(captured_quantized[0])
         if original_forward is not None:
             self.input_quantizer.forward = original_forward
     else:
```
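The guard-plus-fallback logic suggested by the reviewer can be exercised on its own. This is a hedged sketch; `accumulate` and `collect` are stand-ins defined here, not `accumulate_hessian` itself.

```python
# Demo of the suggested None-guard: use the captured value only when
# capture actually fired, otherwise fall back to the raw input.

seen = []


def accumulate(tensor):
    """Stand-in for hessian_helper.accumulate_hessian."""
    seen.append(tensor)


def collect(captured, raw_input, quantizer_enabled):
    if quantizer_enabled:
        if captured[0] is not None:
            accumulate(captured[0])
        else:
            accumulate(raw_input)     # fallback: capture never fired
    else:
        accumulate(raw_input)         # no input quantizer at all


collect([None], "raw", quantizer_enabled=True)        # fallback path
collect(["quantized"], "raw", quantizer_enabled=True)  # normal path
print(seen)  # ['raw', 'quantized']
```

Either way, the original forward must still be restored after collection, which is why the reviewer stresses preserving the `original_forward` restoration alongside the new guard.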
Codecov Report
❌ Patch coverage is
Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #976      +/-   ##
==========================================
- Coverage   72.12%   72.09%   -0.04%
==========================================
  Files         209      209
  Lines       23628    23652      +24
==========================================
+ Hits        17042    17052      +10
- Misses       6586     6600      +14
```

☔ View full report in Codecov by Sentry.
What does this PR do?
Type of change: Bug fix
Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True, usingtorch.load(..., weights_only=True), avoidingpickle, etc.).Additional Information
Summary by CodeRabbit

Refactor
- MSE calibration is now applied only to weight quantizers; activation quantizers retain their max-calibration amax values.

Documentation
- Docstrings clarify the weight-only calibration scope for the MSE and Hessian calibration configurations.