tests: add quantized attention tests (SDPA + eager) with MHA fusion checks to annotate target nodes #4204
yizhuoz004 wants to merge 3 commits into pytorch:main
Conversation
Hi @yizhuoz004! Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
yizhuoz004 force-pushed from 51b4f6e to 023def9
yizhuoz004 force-pushed from 809fc3a to 9d065d5
Adds tests/py/dynamo/hlo/test_quantized_attention.py covering FP8, INT8, and NVFP4 bmm quantization via modelopt PTQ.

Test coverage:
- test_static: fixed shapes, causal/non-causal, LLM-realistic configs (Qwen2.5, Llama-3.2 style)
- test_dynamic_batch / test_dynamic_seq: dynamic dims including decode (seq=1) and prefill
- test_edge_cases: single head, non-pow2 head_dim, large batch+heads, causal prefill
- test_gqa: GQA/MQA with separate Q and KV head counts

Precision verification (_assert_quantized): locates the fused MHA kernel (_gemm_mha_v2 / _gemv_mha_v1) in the serialized TRT engine via IEngineInspector and asserts that the layers preceding it carry the CastMulCast QDQ pattern, confirming that the kernel's inputs are quantized rather than full-precision. Falls back to a broader CastMulCast search for GQA/MQA configs, where TRT decomposes attention into separate matmuls.
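For illustration, here is a minimal sketch of the _assert_quantized idea described above: deserialize the engine, dump layer info via IEngineInspector, locate the fused MHA kernel by name, and require a CastMulCast pattern on the layers feeding it. The helper name, the JSON handling, and the lookback window are assumptions, not the PR's actual code:

```python
import json

import tensorrt as trt

MHA_KERNEL_PREFIXES = ("_gemm_mha_v2", "_gemv_mha_v1")


def assert_quantized(serialized_engine: bytes) -> None:
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    inspector = engine.create_engine_inspector()
    info = json.loads(
        inspector.get_engine_information(trt.LayerInformationFormat.JSON)
    )

    # With DETAILED profiling verbosity each entry in "Layers" is a dict;
    # with lower verbosity it may be a bare name string, so normalize.
    names = [
        layer["Name"] if isinstance(layer, dict) else layer
        for layer in info["Layers"]
    ]

    # Locate the fused MHA kernel by name prefix.
    mha_idx = next(
        (i for i, n in enumerate(names) if any(p in n for p in MHA_KERNEL_PREFIXES)),
        None,
    )
    assert mha_idx is not None, "no fused MHA kernel found in engine"

    # Fused Q/DQ ops surface as a CastMulCast pattern in the names of the
    # layers feeding the kernel; the window of 3 here is an arbitrary choice.
    preceding = names[max(0, mha_idx - 3) : mha_idx]
    assert any("CastMulCast" in n for n in preceding), (
        "fused MHA kernel inputs do not appear to be quantized"
    )
```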
tests: add quantized attention tests (SDPA + eager) with MHA fusion checks to annotate target nodes
yizhuoz004 force-pushed from 9d065d5 to 323519e
Description
Adds test_quantized_attention.py covering FP8 and INT8 PTQ via modelopt for both SDPA-based (VanillaAttention, GQAAttention) and hand-rolled eager attention (EagerAttention, mirroring HF ViT) patterns.
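For context, a rough sketch of the two attention patterns under test and the modelopt PTQ entry point; the class and helper names here are illustrative and may differ from the PR's exact modules:

```python
import torch
import torch.nn.functional as F
import modelopt.torch.quantization as mtq


class VanillaAttention(torch.nn.Module):
    """SDPA path: lowers through TRT's IAttentionLayer / fused MHA route."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=False)


class EagerAttention(torch.nn.Module):
    """Hand-rolled matmul + softmax + matmul, mirroring HF ViT attention."""

    def forward(self, q, k, v):
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        return torch.softmax(scores, dim=-1) @ v


def ptq(model, calib_batches, cfg=mtq.FP8_DEFAULT_CFG):
    # modelopt PTQ: mtq.quantize inserts quantizer modules into the model
    # and calibrates their scales by running the given forward loop.
    def forward_loop(m):
        for q, k, v in calib_batches:
            m(q, k, v)

    return mtq.quantize(model, cfg, forward_loop)
```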
Test coverage:
_QuantAttentionMixin (SDPA, IAttentionLayer path):
- test_static: fixed shapes, causal/non-causal, LLM-realistic configs
- test_dynamic_batch / test_dynamic_seq: dynamic dims including decode (seq=1)
- test_edge_cases: single head, non-pow2 head_dim, causal prefill
- test_gqa: GQA/MQA with separate Q and KV head counts
- test_mha_kernel_precision: @expectedfailure (#4167) — MHA inputs are Half rather than FP8/INT8; normalization_quantize_scale is not set in torch-trt
_EagerAttentionMixin (hand-rolled matmul+softmax+matmul):
- test_static: ViT-realistic shapes
- test_dynamic_seq: dynamic seq covering ViT patch-count range
- test_mha_kernel_precision: @expectedfailure (#4200) — TRT fuses into _gemm_mha_v2 but selects a Half tactic; quantizer scales don't reach the fused kernel boundary
The MHA fusion check (_assert_mha_fused) matches both the _gemm_mha (prefill) and _gemv_mha (decode) kernel prefixes; see the sketch below.
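A minimal sketch of that check, assuming the layer names are pulled from the engine inspector as in the earlier snippet; the helper name follows the description above, but the body is an assumption:

```python
import re

# Matches both _gemm_mha (prefill) and _gemv_mha (decode) kernel prefixes.
MHA_PATTERN = re.compile(r"_gem[mv]_mha")


def assert_mha_fused(layer_names: list[str]) -> None:
    assert any(MHA_PATTERN.search(name) for name in layer_names), (
        "attention was not fused into a TRT MHA kernel"
    )
```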