3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.fx
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.types import TRTNetwork

@@ -25,6 +27,7 @@ class ConversionContext:
requires_native_multidevice: bool = False
weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)
current_node: Optional[torch.fx.Node] = field(default=None)

def record_weight(self, name: str, weight: torch.Tensor) -> None:
"""
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -841,6 +841,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
self.ctx.requires_native_multidevice = True
_LOGGER.debug(f"{target} requires native multi-device support")

self.ctx.current_node = self._cur_node
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
else:
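Taken together, the ConversionContext change and the assignment above let converters see the FX node currently being converted. A minimal sketch of how a converter can use this (illustrative only; example_converter is a hypothetical name, not part of this diff):

def example_converter(ctx, target, args, kwargs, name):
    # The interpreter now records the node being converted on the context, so
    # converters can read annotations left on node.meta by lowering passes.
    node = ctx.current_node
    fp8_scale = node.meta.get("_fp8_softmax_scale") if node is not None else None
    if fp8_scale is not None:
        pass  # configure FP8 normalization on the corresponding TRT layer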
67 changes: 64 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -1,6 +1,8 @@
import logging
import math
from typing import Optional, Tuple, Union

import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
@@ -16,6 +18,36 @@

_LOGGER: logging.Logger = logging.getLogger(__name__)

# FP8 E4M3 max representable magnitude. Softmax output is bounded to [0, 1],
# so 1/448 saturates exactly at 1.0 and is data-independent (no calibration needed).
_FP8_E4M3_MAX = 448.0
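# Worked check (assuming the usual TRT quantize convention q = round(x / scale)):
#   scale = 1 / 448 maps the softmax maximum 1.0 to 1.0 / (1/448) = 448, exactly the
#   E4M3 maximum, so the full [0, 1] softmax range is representable with no clipping.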


def _maybe_set_fp8_softmax(
ctx: ConversionContext,
name: str,
attention_layer: trt.IAttention,
) -> bool:
"""Set FP8 softmax normalization quantization on the IAttention layer if the current
node was annotated with a softmax FP8 scale by the fp8_attention_softmax lowering pass.

Returns True if FP8 normalization was configured (caller must set decomposable=False).
"""
if ctx.current_node is None:
return False
scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
if scale_val is None:
return False
scale_tensor = get_trt_tensor(
ctx,
torch.tensor(scale_val, dtype=torch.float32),
name + "_softmax_fp8_scale",
dtype=torch.float32,

Contributor: dtype needs to match the pre-quant QKV dtype, otherwise TRT compilation will fail on some platforms.

Collaborator (Author): Do you know where we can fetch this info?

Contributor: In 7f0d61c I pulled the attention layer's output tensor's dtype.

)
attention_layer.normalization_quantize_to_type = trt.DataType.FP8
attention_layer.normalization_quantize_scale = scale_tensor
return True
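
On the dtype thread above: a hedged sketch of the direction it points to (the actual change landed in 7f0d61c, which is not shown in this hunk; the lines below are an assumption, not the merged code):

    # Derive the scale constant's dtype from the attention layer's output tensor so it
    # matches the pre-quant QKV dtype instead of hard-coding float32:
    out_dtype = attention_layer.get_output(0).dtype  # tensorrt.DataType of the attention output
    # ...then map out_dtype to the corresponding torch dtype and pass it to get_trt_tensor.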


def tril(
ctx: ConversionContext,
@@ -164,6 +196,18 @@ def scaled_dot_product_attention(
Returns:
TRTTensor: Attention output tensor with shape [batch, heads, seq_len, head_dim]
"""
# When FP8 softmax normalization is active (modelopt FP8 MHA pattern), TRT's
# FP8 MHA fusion requires the Q/DQ output to feed IAttention via a single
# same-dtype Mul; any HALF<->FLOAT cast inserted by the default dynamic
# 1/sqrt(D) computation breaks the fusion. Use a static same-dtype scalar
# scale computed from the concrete head_dim.
fp8_norm_active = (
ctx.current_node is not None
and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
)
if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
scale = 1.0 / math.sqrt(query.shape[-1])
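# e.g. head_dim = 64 gives scale = 1 / sqrt(64) = 0.125, baked in as one constant, so no
# shape/sqrt/cast subgraph sits between the Q/DQ output and the IAttention input.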

if scale is None:
# 1 / math.sqrt(query.size(-1))
q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -256,7 +300,8 @@ def scaled_dot_product_attention(

if mask_tensor is not None:
attention_layer.mask = mask_tensor
attention_layer.decomposable = True
fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
attention_layer.decomposable = not fp8_norm
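# (With FP8 normalization configured, TRT must keep the attention fused into its FP8
# MHA kernel, so the layer is marked non-decomposable; otherwise the previous
# decomposable=True behavior is kept.)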
attention_output = attention_layer.get_output(0)
return attention_output

@@ -284,6 +329,13 @@ def scaled_dot_product_flash_attention(
Optional[TRTTensor],
Optional[TRTTensor],
]:
fp8_norm_active = (
ctx.current_node is not None
and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
)
if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
scale = 1.0 / math.sqrt(query.shape[-1])

if scale is None:
# 1 / math.sqrt(query.size(-1))
q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -314,7 +366,8 @@ def scaled_dot_product_flash_attention(
)
assert attention_layer is not None, "attention layer is None"

attention_layer.decomposable = True
fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
attention_layer.decomposable = not fp8_norm

attention_output = attention_layer.get_output(0)
return attention_output, None, None, None, 0.0, 0.0, None, None, None
@@ -334,6 +387,13 @@ def scaled_dot_product_efficient_attention(
is_causal: bool = False,
scale: Optional[float] = None,
) -> Tuple[TRTTensor, Optional[TRTTensor], Optional[TRTTensor], Optional[TRTTensor]]:
fp8_norm_active = (
ctx.current_node is not None
and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
)
if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
scale = 1.0 / math.sqrt(query.shape[-1])

if scale is None:
# 1 / math.sqrt(query.size(-1))
q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -450,7 +510,8 @@ def scaled_dot_product_efficient_attention(
if mask_tensor is not None:
attention_layer.mask = mask_tensor

attention_layer.decomposable = True
fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
attention_layer.decomposable = not fp8_norm

attention_output = attention_layer.get_output(0)
return attention_output, None, None, None
@@ -10,10 +10,12 @@
trace_intermediate_node_outputs,
)

from .annotate_fp8_sdpa import annotate_fp8_sdpa
from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .force_causal_efficient_attention import force_causal_efficient_attention
from .fuse_prims_broadcast import fuse_prims_broadcast
from .insert_fp8_softmax_qdq import insert_fp8_softmax_qdq
from .pass_manager import DynamoPassManager
from .remove_assert_nodes import remove_assert_nodes
from .remove_detach import remove_detach
@@ -41,6 +43,8 @@
remove_num_users_is_0_nodes,
complex_graph_detection,
force_causal_efficient_attention,
annotate_fp8_sdpa,
insert_fp8_softmax_qdq,
]

if not is_tegra_platform():
76 changes: 76 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/annotate_fp8_sdpa.py
@@ -0,0 +1,76 @@
import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings

logger = logging.getLogger(__name__)

# FP8 E4M3 max. Softmax output is bounded to [0, 1], so 1/448 saturates at 1.0 exactly
# and is data-independent (no calibration required for the softmax output scale).
_FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0

_SDPA_TARGETS = {
torch.ops.aten.scaled_dot_product_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
}


def _is_fp8_quantize_op(node: torch.fx.Node) -> bool:
"""Return True when node is a tensorrt.quantize_op with FP8 dtype (exponent_bits=4)."""
if node.op != "call_function":
return False
try:
if node.target != torch.ops.tensorrt.quantize_op.default:
return False
except AttributeError:
return False
# args: (input, amax, num_bits, exponent_bits, ...)
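# e.g. a matching node (values illustrative): quantize_op(q, q_amax, 8, 4, ...),
# where num_bits=8 and exponent_bits=4 identify FP8 E4M3.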
args = node.args
return len(args) >= 4 and args[2] == 8 and args[3] == 4


def annotate_fp8_sdpa(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Annotate SDPA nodes whose Q, K, V inputs are all FP8-quantized.

Detects the pattern emitted by modelopt when an attention module is
registered via ``register_attention_for_kv_quant``, which wraps the
Q, K, V arguments to ``F.scaled_dot_product_attention`` with
``q_bmm_quantizer``, ``k_bmm_quantizer``, ``v_bmm_quantizer``:

q_fp8 = quantize_op(q, amax_q, num_bits=8, exponent_bits=4, ...)
k_fp8 = quantize_op(k, amax_k, num_bits=8, exponent_bits=4, ...)
v_fp8 = quantize_op(v, amax_v, num_bits=8, exponent_bits=4, ...)
out = scaled_dot_product_attention(q_fp8, k_fp8, v_fp8, ...)

When all three inputs match this pattern the pass sets
``node.meta["_fp8_softmax_scale"] = 1/448`` on the SDPA node so the
attention converter can set ``IAttention.normalization_quantize_to_type
= FP8`` and ``IAttention.normalization_quantize_scale``, which TRT
requires to fuse into the ``_gemm_mha_v2`` FP8 MHA kernel.
"""
changed = False
for node in gm.graph.nodes:
if node.op != "call_function" or node.target not in _SDPA_TARGETS:
continue
if len(node.args) < 3:
continue
q_node, k_node, v_node = node.args[0], node.args[1], node.args[2]
if not all(
isinstance(n, torch.fx.Node) and _is_fp8_quantize_op(n)
for n in (q_node, k_node, v_node)
):
continue
node.meta["_fp8_softmax_scale"] = _FP8_E4M3_SOFTMAX_SCALE
changed = True
logger.debug(
f"Annotated SDPA node {node.name} with FP8 softmax scale "
f"{_FP8_E4M3_SOFTMAX_SCALE} (Q/K/V inputs are FP8-quantized)"
)

if changed:
logger.debug("FP8 SDPA softmax annotation complete")
return gm
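
A minimal way to sanity-check the annotation after lowering (illustrative; assumes gm is a GraphModule that has already gone through this pass):

annotated = [
    n
    for n in gm.graph.nodes
    if n.op == "call_function" and n.meta.get("_fp8_softmax_scale") is not None
]
print(f"{len(annotated)} SDPA node(s) annotated for FP8 softmax quantization")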