-
Notifications
You must be signed in to change notification settings - Fork 394
IAttention FP8 #4209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
narendasan
wants to merge
2
commits into
main
Choose a base branch
from
narendasan/quantization_fixes
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
IAttention FP8 #4209
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
py/torch_tensorrt/dynamo/lowering/passes/annotate_fp8_sdpa.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| import logging | ||
|
|
||
| import torch | ||
| from torch_tensorrt.dynamo._settings import CompilationSettings | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # FP8 E4M3 max. Softmax output is bounded to [0, 1], so 1/448 saturates at 1.0 exactly | ||
| # and is data-independent (no calibration required for the softmax output scale). | ||
| _FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0 | ||
|
|
||
| _SDPA_TARGETS = { | ||
| torch.ops.aten.scaled_dot_product_attention.default, | ||
| torch.ops.aten._scaled_dot_product_flash_attention.default, | ||
| torch.ops.aten._scaled_dot_product_efficient_attention.default, | ||
| torch.ops.aten._scaled_dot_product_cudnn_attention.default, | ||
| } | ||
|
|
||
|
|
||
| def _is_fp8_quantize_op(node: torch.fx.Node) -> bool: | ||
| """Return True when node is a tensorrt.quantize_op with FP8 dtype (exponent_bits=4).""" | ||
| if node.op != "call_function": | ||
| return False | ||
| try: | ||
| if node.target != torch.ops.tensorrt.quantize_op.default: | ||
| return False | ||
| except AttributeError: | ||
| return False | ||
| # args: (input, amax, num_bits, exponent_bits, ...) | ||
| args = node.args | ||
| return len(args) >= 4 and args[2] == 8 and args[3] == 4 | ||
|
|
||
|
|
||
| def annotate_fp8_sdpa( | ||
| gm: torch.fx.GraphModule, settings: CompilationSettings | ||
| ) -> torch.fx.GraphModule: | ||
| """Annotate SDPA nodes whose Q, K, V inputs are all FP8-quantized. | ||
|
|
||
| Detects the pattern emitted by modelopt when an attention module is | ||
| registered via ``register_attention_for_kv_quant``, which wraps the | ||
| Q, K, V arguments to ``F.scaled_dot_product_attention`` with | ||
| ``q_bmm_quantizer``, ``k_bmm_quantizer``, ``v_bmm_quantizer``: | ||
|
|
||
| q_fp8 = quantize_op(q, amax_q, num_bits=8, exponent_bits=4, ...) | ||
| k_fp8 = quantize_op(k, amax_k, num_bits=8, exponent_bits=4, ...) | ||
| v_fp8 = quantize_op(v, amax_v, num_bits=8, exponent_bits=4, ...) | ||
| out = scaled_dot_product_attention(q_fp8, k_fp8, v_fp8, ...) | ||
|
|
||
| When all three inputs match this pattern the pass sets | ||
| ``node.meta["_fp8_softmax_scale"] = 1/448`` on the SDPA node so the | ||
| attention converter can set ``IAttention.normalization_quantize_to_type | ||
| = FP8`` and ``IAttention.normalization_quantize_scale``, which TRT | ||
| requires to fuse into the ``_gemm_mha_v2`` FP8 MHA kernel. | ||
| """ | ||
| changed = False | ||
| for node in gm.graph.nodes: | ||
| if node.op != "call_function" or node.target not in _SDPA_TARGETS: | ||
| continue | ||
| if len(node.args) < 3: | ||
| continue | ||
| q_node, k_node, v_node = node.args[0], node.args[1], node.args[2] | ||
| if not all( | ||
| isinstance(n, torch.fx.Node) and _is_fp8_quantize_op(n) | ||
| for n in (q_node, k_node, v_node) | ||
| ): | ||
| continue | ||
| node.meta["_fp8_softmax_scale"] = _FP8_E4M3_SOFTMAX_SCALE | ||
| changed = True | ||
| logger.debug( | ||
| f"Annotated SDPA node {node.name} with FP8 softmax scale " | ||
| f"{_FP8_E4M3_SOFTMAX_SCALE} (Q/K/V inputs are FP8-quantized)" | ||
| ) | ||
|
|
||
| if changed: | ||
| logger.debug("FP8 SDPA softmax annotation complete") | ||
| return gm |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtype needs to match the pre-quant QKV dtype. otherwise TRT compilatio will fail on some platforms
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know where we can fetch this info?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
7f0d61c I pulled the attention layer's output tensor's dtype.