diff --git a/modelopt/torch/quantization/plugins/attention.py b/modelopt/torch/quantization/plugins/attention.py
index 2113edea8a..9719f760bf 100644
--- a/modelopt/torch/quantization/plugins/attention.py
+++ b/modelopt/torch/quantization/plugins/attention.py
@@ -72,11 +72,48 @@ def is_sdpa(node):
 
     def is_bin_matmul(node):
         return isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult)
 
-    def patch(node, quantizer_names, transpose=False):
-        for index, quantizer_name in enumerate(quantizer_names):
+    def collect_attention_nodes(node):
+        """Collect attention operators in runtime evaluation order.
+
+        ``ast.walk`` traverses breadth-first, which visits an outer matmul before the inner
+        q/k score matmul in nested attention expressions. Visiting children first preserves the
+        execution order for both nested expressions and sequential assignments.
+        """
+        bmm_nodes = []
+        sdpa_nodes = []
+        bin_matmul_nodes = []
+
+        def visit(current_node):
+            for child in ast.iter_child_nodes(current_node):
+                visit(child)
+            if is_bmm(current_node):
+                bmm_nodes.append(current_node)
+            if is_sdpa(current_node):
+                sdpa_nodes.append(current_node)
+            if is_bin_matmul(current_node):
+                bin_matmul_nodes.append(current_node)
+
+        visit(node)
+        return bmm_nodes, sdpa_nodes, bin_matmul_nodes
+
+    def get_operand_indices(node, num_operands):
+        if (
+            isinstance(node, ast.Call)
+            and isinstance(node.func, ast.Attribute)
+            and node.func.attr == "baddbmm"
+            and num_operands == 2
+        ):
+            return (1, 2)
+        return tuple(range(num_operands))
+
+    def patch(node, quantizer_names, transpose_quantizers=()):
+        for index, quantizer_name in zip(
+            get_operand_indices(node, len(quantizer_names)), quantizer_names
+        ):
             if quantizer_name is None:
                 continue
             arg = node.args[index]
+            transpose = quantizer_name in transpose_quantizers
             if not transpose:
                 node.args[index] = ast.Call(
@@ -158,20 +195,11 @@ def patch_binop(node, quantizer_names, transpose=False):
             )
         node.right = quant_arg
 
-    nodes = list(ast.walk(head))
-    org_class_name = nodes[1].name  # type: ignore[attr-defined]
-    new_class_name = nodes[1].name = "_Quant" + nodes[1].name  # type: ignore[attr-defined]
-
-    bmm_nodes = []
-    sdpa_nodes = []
-    bin_matmul_nodes = []
-    for node in ast.walk(head):
-        if is_bmm(node):
-            bmm_nodes.append(node)
-        if is_sdpa(node):
-            sdpa_nodes.append(node)
-        if is_bin_matmul(node):
-            bin_matmul_nodes.append(node)
+    class_def = next(node for node in head.body if isinstance(node, ast.ClassDef))
+    org_class_name = class_def.name
+    new_class_name = class_def.name = "_Quant" + class_def.name
+
+    bmm_nodes, sdpa_nodes, bin_matmul_nodes = collect_attention_nodes(head)
     if len(bmm_nodes) != 2 and len(sdpa_nodes) != 1 and len(bin_matmul_nodes) != 2:
         print(f"Expect 2 bmm/matmul op in the {org_class_name}, found {len(bmm_nodes)}")
         print(f"Or expect 1 sdpa op in the {org_class_name}, found {len(sdpa_nodes)}")
@@ -180,22 +208,23 @@ def patch_binop(node, quantizer_names, transpose=False):
         return False
     if len(bmm_nodes) == 2:
-        # transpose k cache here to enable per-token quantization
-        # without transpose, the quantization will be per-channel, i.e.,
-        # self.k_bmm_quantizer(key_states.transpose(-1, -2))
-        # after transpose, the quantization will be per-token, i.e.,
-        # self.k_bmm_quantizer(key_states.transpose(-1, -2).transpose(-1, -2)).transpose(-1, -2)
-        # removing the additional transpose is doable but not trivial
-        patch(bmm_nodes[0], quantizer_names=(None, "v_bmm_quantizer"))
-        patch(bmm_nodes[1], quantizer_names=("q_bmm_quantizer", "k_bmm_quantizer"), transpose=True)
+        # The first matmul computes attention scores from q and k, while the second one combines
+        # attention probabilities with v. The transpose wrapper keeps the key quantizer on the
+        # original cache layout so per-token quantization still works when the matmul expects k^T.
+        patch(
+            bmm_nodes[0],
+            quantizer_names=("q_bmm_quantizer", "k_bmm_quantizer"),
+            transpose_quantizers=("k_bmm_quantizer",),
+        )
+        patch(bmm_nodes[1], quantizer_names=(None, "v_bmm_quantizer"))
         print("Patching 2 BMM/Matmul operators with quantizers")
 
     if len(bin_matmul_nodes) == 2:
         patch_binop(
-            bin_matmul_nodes[1],
+            bin_matmul_nodes[0],
             quantizer_names=("q_bmm_quantizer", "k_bmm_quantizer"),
             transpose=True,
         )
-        patch_binop(bin_matmul_nodes[0], quantizer_names=(None, "v_bmm_quantizer"))
+        patch_binop(bin_matmul_nodes[1], quantizer_names=(None, "v_bmm_quantizer"))
         print("Patching 2 @ operators with quantizers")
 
     if len(sdpa_nodes) == 1:
diff --git a/tests/unit/torch/quantization/plugins/test_attention_quant.py b/tests/unit/torch/quantization/plugins/test_attention_quant.py
index 30947f3a06..1dead11f01 100644
--- a/tests/unit/torch/quantization/plugins/test_attention_quant.py
+++ b/tests/unit/torch/quantization/plugins/test_attention_quant.py
@@ -60,6 +60,37 @@ def forward(self, hidden_states, **kwargs):
         return F.scaled_dot_product_attention(q, k, v), None
 
 
+class SequentialMatmulAttention(nn.Module):
+    def forward(self, q, k, v):
+        scores = torch.matmul(q, k.transpose(-2, -1))
+        probs = torch.softmax(scores, dim=-1)
+        return torch.matmul(probs, v), None
+
+
+class SequentialBMMAttention(nn.Module):
+    def forward(self, q, k, v):
+        scores = torch.bmm(q, k.transpose(-2, -1))
+        probs = torch.softmax(scores, dim=-1)
+        return torch.bmm(probs, v), None
+
+
+class SequentialBinMatmulAttention(nn.Module):
+    def forward(self, q, k, v):
+        scores = q @ k.transpose(-2, -1)
+        probs = scores.softmax(dim=-1)
+        return probs @ v, None
+
+
+class RecordingIdentityQuantizer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.inputs = []
+
+    def forward(self, tensor):
+        self.inputs.append(tensor.detach().clone())
+        return tensor
+
+
 kv_cache_config = {
     "quant_cfg": [
         {"quantizer_name": "*[kv]_bmm_quantizer", "cfg": {"num_bits": 4}, "enable": True},
@@ -159,6 +190,44 @@ def test_kv_quant_bert():
     assert output.end_logits is not None
 
 
+@pytest.mark.parametrize(
+    "attn_cls",
+    [SequentialMatmulAttention, SequentialBMMAttention, SequentialBinMatmulAttention],
+)
+def test_kv_quant_sequential_attention_wiring(attn_cls):
+    q = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) / 100
+    k = (torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + 100) / 100
+    v = (torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + 200) / 100
+
+    original_attention = attn_cls()
+    quant_attention = attn_cls()
+
+    assert mtq.plugins.register_attention_for_kv_quant(attn_cls)
+
+    try:
+        mtq.replace_quant_module(quant_attention)
+
+        q_bmm_quantizer = RecordingIdentityQuantizer()
+        k_bmm_quantizer = RecordingIdentityQuantizer()
+        v_bmm_quantizer = RecordingIdentityQuantizer()
+        quant_attention.q_bmm_quantizer = q_bmm_quantizer
+        quant_attention.k_bmm_quantizer = k_bmm_quantizer
+        quant_attention.v_bmm_quantizer = v_bmm_quantizer
+
+        expected, _ = original_attention(q, k, v)
+        actual, _ = quant_attention(q, k, v)
+
+        torch.testing.assert_close(actual, expected)
+        assert len(q_bmm_quantizer.inputs) == 1
+        assert len(k_bmm_quantizer.inputs) == 1
+        assert len(v_bmm_quantizer.inputs) == 1
+        torch.testing.assert_close(q_bmm_quantizer.inputs[0], q)
+        torch.testing.assert_close(k_bmm_quantizer.inputs[0], k)
+        torch.testing.assert_close(v_bmm_quantizer.inputs[0], v)
+    finally:
+        mtq.unregister(attn_cls)
+
+
 @pytest.mark.skipif(kitchen is None, reason="kitchen is not installed.")
 def test_kitchen_fa():
     batch_size = 2
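
Note (reviewer addition, not part of the patch): a minimal standalone sketch of the traversal-order point made in the collect_attention_nodes docstring. The source string and the is_matmul helper are illustrative only. For a nested attention expression, ast.walk's breadth-first order reaches the outer probs @ v matmul before the inner q @ k^T score matmul, while a children-first (post-order) visit yields the inner matmul first, matching runtime order.

import ast

src = "out = torch.matmul(torch.softmax(torch.matmul(q, k.transpose(-2, -1)), dim=-1), v)"
tree = ast.parse(src)

def is_matmul(node):
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and node.func.attr == "matmul"
    )

# Breadth-first: the outer matmul (probs @ v) is encountered first.
walk_order = [ast.unparse(n) for n in ast.walk(tree) if is_matmul(n)]

# Children-first: the inner q @ k^T score matmul is collected first,
# matching the order in which the operations actually execute.
post_order = []

def visit(node):
    for child in ast.iter_child_nodes(node):
        visit(child)
    if is_matmul(node):
        post_order.append(ast.unparse(node))

visit(tree)

print(walk_order[0])  # torch.matmul(torch.softmax(...), v) -- outer matmul
print(post_order[0])  # torch.matmul(q, k.transpose(-2, -1)) -- inner score matmul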
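
Note (reviewer addition, not part of the patch): a hedged sketch of the layout argument behind the transpose wrapper. The shapes and the identity quantizer below are stand-ins, not the real TensorQuantizer: the key quantizer observes k in its original (batch, tokens, head_dim) cache layout, and the extra transpose pair restores the k^T operand, so the score matmul result is unchanged.

import torch

q = torch.randn(2, 3, 4)  # (batch, tokens, head_dim)
k = torch.randn(2, 3, 4)

def k_bmm_quantizer(x):
    # Stand-in for the real k quantizer: only checks the layout it observes.
    assert x.shape == (2, 3, 4)  # original cache layout, tokens on dim -2
    return x

# Shape of the expression the transpose-aware patch produces for the score matmul:
# quantize k in its original layout, then transpose back to the k^T the matmul expects.
scores_patched = torch.matmul(
    q, k_bmm_quantizer(k.transpose(-2, -1).transpose(-1, -2)).transpose(-1, -2)
)
scores_plain = torch.matmul(q, k.transpose(-2, -1))
torch.testing.assert_close(scores_patched, scores_plain)

This mirrors the per-token quantization comment in the patched bmm branch: without the wrapper the quantizer would only ever see the already-transposed k^T operand.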