fix: preserve q/k/v quantizer mapping in AST attention patching #1307
Brumbelow wants to merge 1 commit into NVIDIA:main from
Conversation
📝 Walkthrough

Reworks attention quantizer wiring: collects attention nodes in child-first (execution-aligned) AST order, adds an operand-index helper and selective transpose quantizer wrapping, makes class renaming deterministic, and updates bmm/bin-matmul quantizer application order. Tests add sequential-attention modules and a recording quantizer to validate wiring.
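The recording quantizer mentioned above could look roughly like the following identity module; the class name and details here are illustrative, not the test's actual code:

```python
import torch
import torch.nn as nn

class RecordingQuantizer(nn.Module):
    """Identity module that records every tensor shape it sees (hypothetical)."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name
        self.seen = []  # shapes of tensors that passed through this quantizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Log the call without altering the tensor, so a test can assert
        # which operands reached which quantizer instance.
        self.seen.append(x.shape)
        return x
```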
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 5 passed | ❌ 1 failed (1 warning)
Signed-off-by: Andrew Brumbelow <andrewbrumbelow@gmail.com>
Force-pushed from 3cccc41 to 45966b3
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/plugins/attention.py`:
- Around lines 75-97: collect_attention_nodes currently walks the entire class AST, which can pick up matmuls in helper methods. Restrict discovery to the forward() method: locate the ast.FunctionDef named "forward" in the provided node and run the child-first visit only on that function's body instead of the whole class node. Apply the same change to the similar walker around lines 198-202, so that both collectors traverse only the forward() AST and avoid picking up helpers or nested classes.
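A minimal sketch of that restriction, assuming the class source is reachable via inspect; `iter_forward_calls` and its structure are hypothetical, not the actual helpers in attention.py:

```python
import ast
import inspect
import textwrap

def iter_forward_calls(cls):
    """Yield Call nodes from forward() only, children before parents (hypothetical)."""
    tree = ast.parse(textwrap.dedent(inspect.getsource(cls)))
    class_def = tree.body[0]
    # Restrict discovery to forward(), so matmuls in helper methods or
    # nested classes are never collected.
    forward = next(
        n
        for n in class_def.body
        if isinstance(n, ast.FunctionDef) and n.name == "forward"
    )

    def visit(node):
        # Child-first (post-order) traversal: operands are yielded before
        # the expressions that consume them, matching evaluation order.
        for child in ast.iter_child_nodes(node):
            yield from visit(child)
        if isinstance(node, ast.Call):
            yield node

    yield from visit(forward)
```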
- Around lines 99-117: The special-case logic in get_operand_indices, and its use in patch, assumes baddbmm operands are positional and indexes node.args directly, which raises IndexError for keyword-only calls. Update patch (and/or get_operand_indices) to resolve operands from both node.args and node.keywords: for each expected index returned by get_operand_indices, first try node.args[index], and if the index is out of range, look up the corresponding keyword in node.keywords (matching the arg names "batch1"/"batch2" for baddbmm) and use that value. Alternatively, validate before indexing and raise a clear exception naming baddbmm and the missing operand.
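A hedged sketch of the args/keywords fallback; `resolve_operand` is a hypothetical helper name, not the PR's actual code:

```python
import ast

def resolve_operand(node: ast.Call, index: int, kw_name: str) -> ast.expr:
    """Return the operand at a positional index, falling back to its keyword."""
    if index < len(node.args):
        return node.args[index]
    for kw in node.keywords:
        if kw.arg == kw_name:
            return kw.value
    # Fail loudly instead of letting node.args[index] raise a bare IndexError.
    raise ValueError(
        f"baddbmm call is missing operand {kw_name!r} (positional index {index})"
    )

# For torch.baddbmm(input, batch1, batch2), the matmul operands would be
# resolved as, e.g.:
#   q = resolve_operand(call, 1, "batch1")
#   k = resolve_operand(call, 2, "batch2")
```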
ℹ️ Review info
⚙️ Run configuration
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: fc4e9160-3850-4010-8e66-61d8cc459d9f
📒 Files selected for processing (2)
- modelopt/torch/quantization/plugins/attention.py
- tests/unit/torch/quantization/plugins/test_attention_quant.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/torch/quantization/plugins/test_attention_quant.py
Summary
Preserve q/k/v quantizer wiring when `register_attention_for_kv_quant()` patches AST-generated attention wrappers.

Motivation
The old AST patching logic relied on breadth-first `ast.walk()` order, which can visit nested and sequential attention matmuls in a different order than runtime evaluation. That could attach q/k/v quantizers to the wrong operands.
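A small standalone illustration of the ordering difference (not code from the PR):

```python
import ast

src = "out = torch.matmul(torch.matmul(q, k), v)"
tree = ast.parse(src)

# ast.walk() is breadth-first: it yields the outer matmul before the inner
# one, even though the inner matmul executes first at runtime.
bfs_calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
assert ast.unparse(bfs_calls[0]) == "torch.matmul(torch.matmul(q, k), v)"

def post_order(node):
    for child in ast.iter_child_nodes(node):
        yield from post_order(child)
    yield node

# A child-first traversal yields the inner matmul first, matching runtime
# evaluation, so q/k/v quantizers attach to the intended operands.
dfs_calls = [n for n in post_order(tree) if isinstance(n, ast.Call)]
assert ast.unparse(dfs_calls[0]) == "torch.matmul(q, k)"
```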
Changes

- `torch.matmul`, `torch.bmm`, and the `@` matmul operator

Testing
Run with:
python -m pytest tests/unit/torch/quantization/plugins/test_attention_quant.py
python -m pytest tests/unit/torch/quantization/test_quantize_replace.py
pre-commit run --all-files

Checklist
Additional information:
Closes #1064.
Summary by CodeRabbit
New Features
Tests