
fix: preserve q/k/v quantizer mapping in AST attention patching #1307

Open

Brumbelow wants to merge 1 commit into NVIDIA:main from Brumbelow:fix/issue-1064-kv-attention-ast-ordering

Conversation


@Brumbelow Brumbelow commented Apr 21, 2026

Summary

Preserve q/k/v quantizer wiring when register_attention_for_kv_quant() patches AST-generated attention wrappers.

Motivation

The old AST patching logic relied on breadth-first ast.walk() order, which can visit nested and sequential attention matmuls in a different order than runtime evaluation. That could attach q/k/v quantizers to the wrong operands.
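
The mismatch is easy to reproduce; the following minimal sketch (illustrative only, not code from this PR) shows ast.walk() yielding nested matmul calls in the reverse of their evaluation order:

```python
import ast

# Nested attention-style expression: at runtime the inner (q @ k^T) matmul
# executes first, then the outer (scores @ v) matmul.
src = "out = torch.matmul(torch.matmul(q, k.transpose(-2, -1)), v)"
tree = ast.parse(src)

calls = [ast.unparse(n) for n in ast.walk(tree) if isinstance(n, ast.Call)]
for c in calls:
    print(c)
# ast.walk is breadth-first, so it yields the outer matmul before the
# inner one -- the reverse of evaluation order, which is what could
# mis-wire the q/k/v quantizers.
```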

Changes

  • switch attention matmul collection to deterministic post-order traversal (see the sketch after this list)
  • patch the first matmul as q/k score computation and the second as attention/value aggregation
  • keep the transpose wrapper only on the key operand for per-token KV-cache quantization
  • add sequential unit coverage for torch.matmul, torch.bmm, and @
  • assert that q, k, and v quantizers see the expected tensors while preserving forward outputs
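
A minimal sketch of the post-order idea, using a hypothetical collect_matmul_calls() helper (the real collect_attention_nodes() in attention.py also handles torch.bmm, scaled_dot_product_attention, and the @ operator, i.e. ast.BinOp with ast.MatMult):

```python
import ast

def collect_matmul_calls(forward_def: ast.FunctionDef) -> list[ast.AST]:
    """Collect matmul-like calls child-first so the result matches execution order.

    Hypothetical sketch of the approach, not the actual implementation.
    """
    nodes: list[ast.AST] = []

    def visit(node: ast.AST) -> None:
        for child in ast.iter_child_nodes(node):
            visit(child)  # recurse into children first: post-order
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr in ("matmul", "bmm", "baddbmm")
        ):
            nodes.append(node)

    visit(forward_def)
    # Children are visited before parents, so the first entry is the
    # innermost (q/k score) matmul and the second the score/v matmul.
    return nodes
```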

Testing

Run with:

  • python -m pytest tests/unit/torch/quantization/plugins/test_attention_quant.py
  • python -m pytest tests/unit/torch/quantization/test_quantize_replace.py
  • pre-commit run --all-files

Checklist

  • Backward compatible
  • Followed guidance, no copied code.
  • Added tests
  • No docs changes (no API changes)

Additional information:
Closes #1064.

Summary by CodeRabbit

  • New Features

    • Improved attention quantization: more accurate operand selection and deterministic ordering for applying quantizers, yielding more reliable instrumentation and transpose handling.
  • Tests

    • Added parametrized tests covering multiple sequential attention variants.
    • Introduced a recording identity quantizer to verify each q/k/v is quantized once and that quantized outputs match originals exactly.

@Brumbelow Brumbelow requested a review from a team as a code owner April 21, 2026 02:43
@Brumbelow Brumbelow requested a review from sychen52 April 21, 2026 02:43

copy-pr-bot Bot commented Apr 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Reworks attention quantizer wiring: collects attention nodes in child-first (execution-aligned) AST order, adds operand-index helper and selective transpose quantizer wrapping, makes class renaming deterministic, and updates bmm/bin-matmul quantizer application order. Tests add sequential-attention modules and a recording quantizer to validate wiring.
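
As a rough illustration of the selective transpose wrapping (the class name and layout assumptions below are hypothetical, not the actual implementation in attention.py):

```python
import torch
from torch import nn

class TransposeQuantizer(nn.Module):
    """Hypothetical wrapper: transpose the key before quantizing, then undo.

    Sketch of the 'selective transpose wrapper' idea for per-token KV-cache
    quantization; the real wrapper may differ.
    """
    def __init__(self, quantizer: nn.Module):
        super().__init__()
        self.quantizer = quantizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inside q @ k.transpose(-2, -1) the key arrives already transposed;
        # transpose back so quantization is applied per token, then restore
        # the layout the matmul expects.
        return self.quantizer(x.transpose(-2, -1)).transpose(-2, -1)
```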

Changes

Attention Plugin Core Logic — modelopt/torch/quantization/plugins/attention.py

  • Replaced breadth-first AST discovery with collect_attention_nodes() (child-first/post-order), collecting bmm, scaled_dot_product_attention, and MatMult nodes in execution-aligned order.
  • Added get_operand_indices(), which special-cases baddbmm with num_operands == 2 to quantize operands (1, 2).
  • Changed patch() to use transpose_quantizers (selective transpose wrappers per operand) instead of a single transpose flag.
  • Made class renaming deterministic by using the first ast.ClassDef in head.body.
  • Adjusted operand/transpose ordering for the len(bmm_nodes) == 2 and len(bin_matmul_nodes) == 2 cases to match the new collection order.

Attention Quantization Tests — tests/unit/torch/quantization/plugins/test_attention_quant.py

  • Added three sequential attention test classes (SequentialMatmulAttention, SequentialBMMAttention, SequentialBinMatmulAttention) that expose (q, k, v) -> (output, None) paths via matmul/bmm/@.
  • Added RecordingIdentityQuantizer to capture cloned inputs.
  • Added the parametrized test test_kv_quant_sequential_attention_wiring, which registers/unregisters KV quant for each class, injects recording quantizers, and asserts exact output equality plus one invocation per q, k, and v quantizer.
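
A minimal sketch of the recording quantizer idea (the actual helper in test_attention_quant.py may differ):

```python
import torch
from torch import nn

class RecordingIdentityQuantizer(nn.Module):
    """Identity quantizer that records a clone of every input it sees."""

    def __init__(self):
        super().__init__()
        self.recorded: list[torch.Tensor] = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.recorded.append(x.detach().clone())  # capture without mutating
        return x  # identity: quantized output must equal the original
```

The test can then assert len(quantizer.recorded) == 1 for each of q, k, and v, and that the patched forward output exactly matches the unpatched reference.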

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 18.75%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title accurately describes the main change: fixing the preservation of q/k/v quantizer mapping in AST attention patching, which directly addresses the root cause of issue #1064.
  • Linked Issues Check — ✅ Passed: the changes directly address issue #1064 by replacing breadth-first ast.walk() with deterministic post-order traversal, ensuring q/k/v quantizers attach to the correct operands and preserving proper tensor mapping and transpose operations.
  • Out of Scope Changes Check — ✅ Passed: all changes are scoped to fixing the quantizer mapping issue: core logic in attention.py, test coverage for sequential attention variants, and quantizer wiring validation. No unrelated changes detected.
  • Security Anti-Patterns — ✅ Passed: security review found no torch.load without weights_only, numpy.load with allow_pickle=True, trust_remote_code=True, eval/exec on external inputs, nosec bypasses, or problematic dependencies in the modified files.


Signed-off-by: Andrew Brumbelow <andrewbrumbelow@gmail.com>
@Brumbelow Brumbelow force-pushed the fix/issue-1064-kv-attention-ast-ordering branch from 3cccc41 to 45966b3 on April 24, 2026 at 22:58

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `modelopt/torch/quantization/plugins/attention.py`:
- Around line 75-97: collect_attention_nodes currently walks the entire class AST, which can pick up matmuls in helper methods. Restrict discovery to the forward() method: locate the ast.FunctionDef named "forward" in the provided node and run the child-first visit only on that function's body instead of the whole class node. Apply the same change to the similar walker around lines 198-202 so that both collectors traverse only the forward() AST and avoid picking up helpers or nested classes.
- Around line 99-117: The special-case logic in get_operand_indices, and its use in patch, assumes baddbmm operands are positional and indexes node.args directly, which raises IndexError for keyword-only calls. Update patch (and/or get_operand_indices) to resolve operands from both node.args and node.keywords: for each expected index returned by get_operand_indices, first try node.args[index]; if the index is out of range, look up the corresponding keyword in node.keywords (matching the arg names "batch1"/"batch2" for baddbmm) and use that value. Alternatively, validate before indexing and raise a clear exception naming baddbmm and the missing operand.
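
For reference, minimal sketches of the two suggested fixes; find_forward and resolve_operand are hypothetical names, not the actual implementation:

```python
import ast

def find_forward(class_def: ast.ClassDef) -> ast.FunctionDef | None:
    """Restrict discovery to forward(): return its FunctionDef, if present."""
    for item in class_def.body:
        if isinstance(item, ast.FunctionDef) and item.name == "forward":
            return item
    return None

# baddbmm(input, batch1, batch2): the quantized operands are indices 1 and 2,
# but callers may pass them as keywords rather than positionally.
_BADDBMM_KEYWORDS = {1: "batch1", 2: "batch2"}

def resolve_operand(node: ast.Call, index: int) -> ast.expr:
    """Resolve an operand from node.args, falling back to node.keywords."""
    if index < len(node.args):
        return node.args[index]
    name = _BADDBMM_KEYWORDS.get(index)
    for kw in node.keywords:
        if kw.arg == name:
            return kw.value
    raise ValueError(f"baddbmm call is missing operand {index} ({name!r})")
```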

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: fc4e9160-3850-4010-8e66-61d8cc459d9f

📥 Commits

Reviewing files that changed from the base of the PR and between 3cccc41 and 45966b3.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/attention.py
  • tests/unit/torch/quantization/plugins/test_attention_quant.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/torch/quantization/plugins/test_attention_quant.py

Comment threads (2): modelopt/torch/quantization/plugins/attention.py


Development

Successfully merging this pull request may close these issues:

  • Bug for register_attention_for_kv_quant (#1064)