From 89443d8bca4ed82e4e8a8378831c34a7052fa983 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 4 May 2026 21:50:58 -0700 Subject: [PATCH 1/5] dynamic shape arg --- py/torch_tensorrt/_compile.py | 36 ++- py/torch_tensorrt/dynamo/_tracer.py | 11 +- .../dynamo/models/test_shared_dynamic_dim.py | 250 ++++++++++++++++++ 3 files changed, 289 insertions(+), 8 deletions(-) create mode 100644 tests/py/dynamo/models/test_shared_dynamic_dim.py diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index f052954efb..f01e95eed9 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -7,7 +7,18 @@ import platform import warnings from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) import torch from torch_tensorrt._enums import dtype @@ -191,6 +202,7 @@ def compile( arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[Dict[str, Any]] = None, enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, + dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> ( torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any] @@ -226,6 +238,14 @@ def compile( kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path) + dynamic_shapes (Any): Optional ``dynamic_shapes`` dict (or list / nested + structure) forwarded to ``torch.export.export``. Supply this to share a + ``Dim`` across multiple inputs (e.g. when ``input_ids`` and ``attention_mask`` + must have the same batch size at runtime). When omitted, dynamic shapes are + auto-inferred from per-input ``min_shape``/``max_shape`` and **each input gets + its own independent symbol** -- which fails ``torch.export``'s constraint + check for models that broadcast across these axes. Only consulted when + ``module`` is an ``nn.Module`` (ignored for ``ExportedProgram``). **kwargs: Additional settings for the specific requested strategy (See submodules for more info) Returns: @@ -296,7 +316,7 @@ def _fx_input_interface( return compiled_fx_module elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs - if arg_inputs is None and inputs is None: + if arg_inputs is None and inputs is None and not kwarg_inputs: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") elif arg_inputs is not None and inputs is not None: @@ -311,8 +331,10 @@ def _fx_input_interface( from torch_tensorrt.dynamo.utils import prepare_inputs - if not isinstance(arg_inputs, collections.abc.Sequence): - arg_inputs = [arg_inputs] # type: ignore + if arg_inputs is None: + arg_inputs = [] + elif not isinstance(arg_inputs, collections.abc.Sequence): + arg_inputs = [arg_inputs] torchtrt_arg_inputs = prepare_inputs(arg_inputs) torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs) @@ -324,6 +346,7 @@ def _fx_input_interface( module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, + dynamic_shapes=dynamic_shapes, **kwargs, ) trt_graph_module = dynamo_compile( @@ -830,8 +853,13 @@ def _all_are_input_objects(obj: Any) -> bool: f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}" ) +<<<<<<< HEAD arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore[arg-type] kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore[assignment] +======= + arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) + kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) +>>>>>>> 7fa5d838 (dynamic shape arg) else: # Mixed case: some inputs are Tensors, some are Input objects diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 0595c6a8f9..e36c3b5240 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -19,6 +19,7 @@ def trace( *, arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, + dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -65,7 +66,7 @@ def trace( raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = inputs or arg_inputs + arg_inputs = inputs if inputs is not None else arg_inputs if kwarg_inputs is None: kwarg_inputs = {} @@ -73,9 +74,11 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - # Constructing dynamic shape list as a nested dict - dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) - dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + if dynamic_shapes is None: + # Auto-inferred dims are independent per input; pass dynamic_shapes + # explicitly to share a Dim across inputs. + dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) + dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) exp_program = export( mod, tuple(torch_arg_inputs), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py new file mode 100644 index 0000000000..1f281e6851 --- /dev/null +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -0,0 +1,250 @@ +# type: ignore +""" +Tests for the ``dynamic_shapes=`` passthrough kwarg on ``torch_tensorrt.compile``. + +Background: when a model takes multiple inputs whose dynamic axes must be +**equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask`` +both shaped ``[B, S]``), the legacy auto-inference path in +``dynamo/_tracer.py`` mints an *independent* ``Dim`` per input. ``torch.export`` +then fails its constraint check for any forward() that broadcasts across those +axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising +``ConstraintViolationError``. + +These tests exercise the new ``dynamic_shapes=`` passthrough that lets the +caller supply a shared ``Dim`` directly to ``torch_tensorrt.compile`` -- +mirroring the ``torch.export.export(dynamic_shapes=...)`` signature -- so the +shared-batch case compiles end to end without the caller having to pre-export +the module themselves. +""" +import unittest + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch.export import Dim +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +class _SharedBatchEncoder(nn.Module): + """HF-style encoder stand-in: two int64 inputs sharing the batch axis. + + The ``embed(input_ids) * mask.unsqueeze(-1)`` broadcast forces + ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship the + auto-inference path cannot express. + """ + + def __init__(self, vocab: int = 1024, hidden: int = 32): + super().__init__() + self.embed = nn.Embedding(vocab, hidden) + self.proj = nn.Linear(hidden, hidden) + + def forward(self, input_ids, attention_mask): + x = self.embed(input_ids) + mask = attention_mask.unsqueeze(-1).to(x.dtype) + return self.proj(x * mask) + + +def _kwarg_inputs(seq: int = 16, batch_min: int = 1, batch_max: int = 4): + return { + "input_ids": torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name="input_ids", + ), + "attention_mask": torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name="attention_mask", + ), + } + + +@pytest.mark.unit +@pytest.mark.critical +def test_dynamic_shapes_passthrough_with_shared_batch_dim(): + """With ``dynamic_shapes={..: {0: batch}, ..: {0: batch}}`` (one shared + ``Dim``), compile succeeds and the engine matches the eager model.""" + model = _SharedBatchEncoder().eval().cuda() + + batch = Dim("batch", min=1, max=4) + dynamic_shapes = { + "input_ids": {0: batch}, + "attention_mask": {0: batch}, + } + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + kwarg_inputs=_kwarg_inputs(), + dynamic_shapes=dynamic_shapes, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + # Sample at the optimization shape and at a smaller batch within the range. + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(input_ids=ids, attention_mask=mask) + out = trt_mod(input_ids=ids, attention_mask=mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Shared-batch encoder out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_dynamic_shapes_passthrough_positional_tuple_form(): + """``torch.export`` also accepts ``dynamic_shapes`` as a tuple matching the + positional-args order. Verify the passthrough handles that form too.""" + model = _SharedBatchEncoder().eval().cuda() + + batch = Dim("batch", min=1, max=4) + seq = 16 + positional_inputs = [ + torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + ), + torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="attention_mask", + ), + ] + # Tuple form: one entry per positional arg, in declaration order. + dynamic_shapes = ({0: batch}, {0: batch}) + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + dynamic_shapes=dynamic_shapes, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, mask) + out = trt_mod(ids, mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Tuple-form dynamic_shapes out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_dynamic_shapes_passthrough_mixed_args_and_kwargs(): + """One positional input, one kwarg input, sharing a batch ``Dim``. Uses the + unified dict-by-parameter-name form, which spans both positional and keyword + parameters.""" + model = _SharedBatchEncoder().eval().cuda() + + batch = Dim("batch", min=1, max=4) + seq = 16 + + # input_ids passed positionally, attention_mask as a kwarg. + positional_inputs = [ + torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + ), + ] + kwarg_inputs = { + "attention_mask": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="attention_mask", + ), + } + dynamic_shapes = { + "input_ids": {0: batch}, + "attention_mask": {0: batch}, + } + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + kwarg_inputs=kwarg_inputs, + dynamic_shapes=dynamic_shapes, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, attention_mask=mask) + out = trt_mod(ids, attention_mask=mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Mixed args/kwargs out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_dynamic_shapes_default_path_unchanged_for_static_inputs(): + """Sanity check: when ``dynamic_shapes=None`` and inputs are fully static, + behavior is unchanged from the legacy path.""" + + class StaticModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8) + + def forward(self, x): + return self.linear(x) + + model = StaticModel().eval().cuda() + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(shape=(2, 8), dtype=torch.float32, name="x")], + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + x = torch.randn((2, 8), device="cuda") + with torch.no_grad(): + ref = model(x) + out = trt_mod(x) + assertions.assertTrue(cosine_similarity(ref, out) > COSINE_THRESHOLD) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From b781ae455d0a41dffb8a3f242b7a938065c15353 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 3 Jun 2026 23:37:22 -0700 Subject: [PATCH 2/5] shared dynamic dims across inputs via Inputs --- py/torch_tensorrt/_Input.py | 47 +++++ py/torch_tensorrt/_compile.py | 32 +-- py/torch_tensorrt/dynamo/_tracer.py | 86 ++++++-- .../dynamo/models/test_shared_dynamic_dim.py | 197 ++++++++---------- 4 files changed, 221 insertions(+), 141 deletions(-) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index b547afb278..27ee329f69 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -51,6 +51,9 @@ class _ShapeMode(Enum): torch_tensor: torch.Tensor = None name: str = "" is_shape_tensor: bool = False + name_dims: Dict[int, str] = ( + {} + ) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``). def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -162,11 +165,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "if you try to run inference with empty tensor inputs." ) + if "name_dims" in kwargs and kwargs["name_dims"]: + self.name_dims = Input._parse_name_dims( + kwargs["name_dims"], self.shape + ) + else: raise ValueError( f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) + if kwargs.get("name_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: + raise ValueError( + "name_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " + "it has no meaning for a statically shaped Input." + ) + if "dtype" in kwargs: self.dtype = dtype._from(kwargs["dtype"]) @@ -261,6 +275,39 @@ def equivalent_spec(a: Input, b: Input) -> bool: ] return all(checks) + @staticmethod + def _parse_name_dims( + name_dims: Any, shape: Dict[str, Tuple[int, ...]] + ) -> Dict[int, str]: + """Validate and normalize the ``name_dims`` mapping ({axis: name}). + + Each named axis must be a valid index into the shape and must be + genuinely dynamic (``min != max``); a static axis cannot vary, so naming + it for cross-input sharing is a user error. + """ + if not isinstance(name_dims, dict): + raise TypeError( + f"name_dims must be a dict of {{axis_index: name}}, got {type(name_dims)}" + ) + rank = len(shape["min_shape"]) + parsed: Dict[int, str] = {} + for axis, dim_name in name_dims.items(): + if not isinstance(axis, int) or not (0 <= axis < rank): + raise ValueError( + f"name_dims key {axis!r} is not a valid axis index for an input of rank {rank}" + ) + if not isinstance(dim_name, str) or not dim_name: + raise ValueError( + f"name_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" + ) + if shape["min_shape"][axis] == shape["max_shape"][axis]: + raise ValueError( + f"Axis {axis} named '{dim_name}' is static " + f"(min == max == {shape['min_shape'][axis]}); only dynamic axes can be named." + ) + parsed[axis] = dim_name + return parsed + @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index f01e95eed9..447dbc56b0 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -17,7 +17,6 @@ Set, Tuple, Union, - cast, ) import torch @@ -63,6 +62,7 @@ ) from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._tracer import ( + build_dim_registry, get_dynamic_shapes_args, get_dynamic_shapes_kwargs, ) @@ -202,7 +202,6 @@ def compile( arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[Dict[str, Any]] = None, enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, - dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> ( torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any] @@ -238,14 +237,6 @@ def compile( kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path) - dynamic_shapes (Any): Optional ``dynamic_shapes`` dict (or list / nested - structure) forwarded to ``torch.export.export``. Supply this to share a - ``Dim`` across multiple inputs (e.g. when ``input_ids`` and ``attention_mask`` - must have the same batch size at runtime). When omitted, dynamic shapes are - auto-inferred from per-input ``min_shape``/``max_shape`` and **each input gets - its own independent symbol** -- which fails ``torch.export``'s constraint - check for models that broadcast across these axes. Only consulted when - ``module`` is an ``nn.Module`` (ignored for ``ExportedProgram``). **kwargs: Additional settings for the specific requested strategy (See submodules for more info) Returns: @@ -324,7 +315,7 @@ def _fx_input_interface( "'arg_inputs' and 'inputs' should not be used at the same time." ) if inputs is not None: - arg_inputs = inputs + arg_inputs = inputs # type: ignore[assignment] if kwarg_inputs is None: kwarg_inputs = {} @@ -346,7 +337,6 @@ def _fx_input_interface( module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, - dynamic_shapes=dynamic_shapes, **kwargs, ) trt_graph_module = dynamo_compile( @@ -424,7 +414,7 @@ def cross_compile_for_windows( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = inputs or arg_inputs + arg_inputs = inputs or arg_inputs # type: ignore[assignment] if kwarg_inputs is None: kwarg_inputs = {} @@ -524,7 +514,7 @@ def convert_method_to_trt_engine( raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = arg_inputs or inputs + arg_inputs = arg_inputs or inputs # type: ignore[assignment] module_type = _parse_module_type(module) target_ir = _get_target_fe(module_type, ir) @@ -844,8 +834,13 @@ def _all_are_input_objects(obj: Any) -> bool: "The explicit dynamic_shapes parameter takes precedence and Input shape specifications will be ignored." ) else: - inferred_dynamic_shapes = get_dynamic_shapes_args(module, arg_inputs) - inferred_dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) + inferred_dynamic_shapes = get_dynamic_shapes_args( + module, arg_inputs, dim_registry + ) + inferred_dynamic_shapes.update( + get_dynamic_shapes_kwargs(kwarg_inputs, dim_registry) + ) if inferred_dynamic_shapes is not None: dynamic_shapes = inferred_dynamic_shapes @@ -853,13 +848,8 @@ def _all_are_input_objects(obj: Any) -> bool: f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}" ) -<<<<<<< HEAD - arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore[arg-type] - kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore[assignment] -======= arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) ->>>>>>> 7fa5d838 (dynamic shape arg) else: # Mixed case: some inputs are Tensors, some are Input objects diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index e36c3b5240..3cd809e296 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -19,7 +19,6 @@ def trace( *, arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, - dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -74,11 +73,12 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - if dynamic_shapes is None: - # Auto-inferred dims are independent per input; pass dynamic_shapes - # explicitly to share a Dim across inputs. - dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) - dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + # Build dynamic shapes from the Input objects. Inputs carrying name_dims + # share a Dim across inputs via the registry; the rest get an independent + # per-input Dim. + dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) + dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs, dim_registry) + dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs, dim_registry)) exp_program = export( mod, tuple(torch_arg_inputs), @@ -90,49 +90,109 @@ def trace( return exp_program -def get_dynamic_shapes_kwargs(inputs: Any) -> Union[dict[str, Any], list[Any]]: +def _collect_inputs(obj: Any) -> list[Input]: + """Flatten an arg/kwarg input structure into a list of Input objects.""" + if isinstance(obj, Input): + return [obj] + elif isinstance(obj, dict): + collected: list[Input] = [] + for v in obj.values(): + collected.extend(_collect_inputs(v)) + return collected + elif isinstance(obj, (list, tuple)): + collected = [] + for v in obj: + collected.extend(_collect_inputs(v)) + return collected + return [] + + +def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: + """Build a ``{name: torch.export.Dim}`` registry from Input.name_dims. + + The same name appearing on multiple inputs yields a single shared ``Dim`` + instance, so ``torch.export`` treats those axes as one symbol. Conflicting + (min, max) ranges for the same name are rejected. + """ + registry: dict[str, Any] = {} + bounds: dict[str, tuple[int, int]] = {} + for inp in _collect_inputs(arg_inputs) + _collect_inputs(kwarg_inputs): + name_dims = getattr(inp, "name_dims", None) + if not name_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: + continue + assert isinstance(inp.shape, dict) + min_shape = inp.shape["min_shape"] + max_shape = inp.shape["max_shape"] + for axis, dim_name in name_dims.items(): + lo, hi = int(min_shape[axis]), int(max_shape[axis]) + if dim_name in bounds: + if bounds[dim_name] != (lo, hi): + raise ValueError( + f"Dimension name '{dim_name}' is used with conflicting ranges " + f"{bounds[dim_name]} and {(lo, hi)}. A shared named dimension " + f"must have identical (min, max) on every input that uses it." + ) + else: + bounds[dim_name] = (lo, hi) + registry[dim_name] = Dim(dim_name, min=lo, max=hi) + return registry + + +def get_dynamic_shapes_kwargs( + inputs: Any, dim_registry: Optional[dict[str, Any]] = None +) -> Union[dict[str, Any], list[Any]]: if isinstance(inputs, dict): dynamic_shapes_kwarg = {} for k, v in inputs.items(): - dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v) + dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v, dim_registry) return dynamic_shapes_kwarg elif isinstance(inputs, Input): - return get_dynamic_shapes(inputs) + return get_dynamic_shapes(inputs, dim_registry) elif isinstance(inputs, (list, tuple)): dynamic_shapes = [] for input in inputs: - dynamic_shapes.append(get_dynamic_shapes(input)) + dynamic_shapes.append(get_dynamic_shapes(input, dim_registry)) return dynamic_shapes raise TypeError(f"Unknown type {type(inputs)}.") -def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]: +def get_dynamic_shapes_args( + mod: torch.nn.Module, inputs: Any, dim_registry: Optional[dict[str, Any]] = None +) -> dict[str, Any]: # dynamic_shape is a dict and cannot work without keys. Here we use position argument name # in forward function as the name args = list(signature(mod.forward).parameters.keys()) dynamic_shapes = {} for input, input_name in zip(inputs, args[: len(inputs)]): - dynamic_shapes[input_name] = get_dynamic_shapes(input) + dynamic_shapes[input_name] = get_dynamic_shapes(input, dim_registry) return dynamic_shapes -def get_dynamic_shapes(input: Input) -> dict[Any, Any]: +def get_dynamic_shapes( + input: Input, dim_registry: Optional[dict[str, Any]] = None +) -> dict[Any, Any]: if not isinstance(input, Input): # If the input is torch.Tensor, no dynamic is needed. Return empty dict return {} else: dynamic_dims = {} if input.shape_mode == Input._ShapeMode.DYNAMIC: + assert isinstance(input.shape, dict) min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] + name_dims = getattr(input, "name_dims", None) or {} assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue + elif dim_registry is not None and dim in name_dims: + # Named axis: reuse the shared Dim so axes with the same + # name across inputs become a single exported symbol. + dynamic_dims[dim] = dim_registry[name_dims[dim]] else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index 1f281e6851..25ff7e292f 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -1,28 +1,27 @@ # type: ignore """ -Tests for the ``dynamic_shapes=`` passthrough kwarg on ``torch_tensorrt.compile``. +Tests for sharing a dynamic dimension across inputs via ``Input(name_dims=...)``. Background: when a model takes multiple inputs whose dynamic axes must be **equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask`` -both shaped ``[B, S]``), the legacy auto-inference path in -``dynamo/_tracer.py`` mints an *independent* ``Dim`` per input. ``torch.export`` -then fails its constraint check for any forward() that broadcasts across those -axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising -``ConstraintViolationError``. - -These tests exercise the new ``dynamic_shapes=`` passthrough that lets the -caller supply a shared ``Dim`` directly to ``torch_tensorrt.compile`` -- -mirroring the ``torch.export.export(dynamic_shapes=...)`` signature -- so the -shared-batch case compiles end to end without the caller having to pre-export -the module themselves. +both shaped ``[B, S]``), naming each axis independently makes ``torch.export`` +mint an *independent* ``Dim`` per input. ``torch.export`` then fails its +constraint check for any forward() that broadcasts across those axes (here: +``embed(input_ids) * mask.unsqueeze(-1)``), raising ``ConstraintViolationError``. + +``Input(name_dims={axis: name})`` lets the caller tag a dynamic axis with a +name; the same name across inputs is exported as a single shared ``Dim``. All +the dynamic-shape intent lives on the ``Input`` objects -- no separate +``dynamic_shapes`` argument and no ``torch.export`` knowledge required at the +call site. """ + import unittest import pytest import torch import torch.nn as nn import torch_tensorrt as torchtrt -from torch.export import Dim from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -32,8 +31,8 @@ class _SharedBatchEncoder(nn.Module): """HF-style encoder stand-in: two int64 inputs sharing the batch axis. The ``embed(input_ids) * mask.unsqueeze(-1)`` broadcast forces - ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship the - auto-inference path cannot express. + ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship a shared + named dimension expresses. """ def __init__(self, vocab: int = 1024, hidden: int = 32): @@ -47,43 +46,34 @@ def forward(self, input_ids, attention_mask): return self.proj(x * mask) -def _kwarg_inputs(seq: int = 16, batch_min: int = 1, batch_max: int = 4): - return { - "input_ids": torchtrt.Input( - min_shape=(batch_min, seq), - opt_shape=(batch_max, seq), - max_shape=(batch_max, seq), - dtype=torch.int64, - name="input_ids", - ), - "attention_mask": torchtrt.Input( - min_shape=(batch_min, seq), - opt_shape=(batch_max, seq), - max_shape=(batch_max, seq), - dtype=torch.int64, - name="attention_mask", - ), - } +def _named_input(name: str, seq: int = 16, batch_min: int = 1, batch_max: int = 4): + """A dynamic int64 Input whose batch axis (0) is named "B" for sharing.""" + return torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name=name, + name_dims={0: "B"}, + ) @pytest.mark.unit @pytest.mark.critical -def test_dynamic_shapes_passthrough_with_shared_batch_dim(): - """With ``dynamic_shapes={..: {0: batch}, ..: {0: batch}}`` (one shared - ``Dim``), compile succeeds and the engine matches the eager model.""" +def test_name_dims_shared_batch_kwarg_inputs(): + """Shared batch axis declared via ``Input(name_dims={0: "B"})`` on both + kwarg inputs -- same name => one exported symbol; engine matches eager.""" model = _SharedBatchEncoder().eval().cuda() - batch = Dim("batch", min=1, max=4) - dynamic_shapes = { - "input_ids": {0: batch}, - "attention_mask": {0: batch}, + kwarg_inputs = { + "input_ids": _named_input("input_ids"), + "attention_mask": _named_input("attention_mask"), } trt_mod = torchtrt.compile( model, ir="dynamo", - kwarg_inputs=_kwarg_inputs(), - dynamic_shapes=dynamic_shapes, + kwarg_inputs=kwarg_inputs, min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, @@ -101,50 +91,32 @@ def test_dynamic_shapes_passthrough_with_shared_batch_dim(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Shared-batch encoder out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"name_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_dynamic_shapes_passthrough_positional_tuple_form(): - """``torch.export`` also accepts ``dynamic_shapes`` as a tuple matching the - positional-args order. Verify the passthrough handles that form too.""" +def test_name_dims_shared_batch_positional_inputs(): + """Same feature with positional ``inputs=[...]`` instead of kwargs.""" model = _SharedBatchEncoder().eval().cuda() - batch = Dim("batch", min=1, max=4) - seq = 16 positional_inputs = [ - torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="input_ids", - ), - torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="attention_mask", - ), + _named_input("input_ids"), + _named_input("attention_mask"), ] - # Tuple form: one entry per positional arg, in declaration order. - dynamic_shapes = ({0: batch}, {0: batch}) trt_mod = torchtrt.compile( model, ir="dynamo", inputs=positional_inputs, - dynamic_shapes=dynamic_shapes, min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, ) for bs in (4, 2): - ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") - mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") with torch.no_grad(): ref = model(ids, mask) @@ -153,58 +125,28 @@ def test_dynamic_shapes_passthrough_positional_tuple_form(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Tuple-form dynamic_shapes out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"name_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_dynamic_shapes_passthrough_mixed_args_and_kwargs(): - """One positional input, one kwarg input, sharing a batch ``Dim``. Uses the - unified dict-by-parameter-name form, which spans both positional and keyword - parameters.""" +def test_name_dims_shared_batch_mixed_args_and_kwargs(): + """input_ids passed positionally, attention_mask as a kwarg; both share "B".""" model = _SharedBatchEncoder().eval().cuda() - batch = Dim("batch", min=1, max=4) - seq = 16 - - # input_ids passed positionally, attention_mask as a kwarg. - positional_inputs = [ - torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="input_ids", - ), - ] - kwarg_inputs = { - "attention_mask": torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="attention_mask", - ), - } - dynamic_shapes = { - "input_ids": {0: batch}, - "attention_mask": {0: batch}, - } - trt_mod = torchtrt.compile( model, ir="dynamo", - inputs=positional_inputs, - kwarg_inputs=kwarg_inputs, - dynamic_shapes=dynamic_shapes, + inputs=[_named_input("input_ids")], + kwarg_inputs={"attention_mask": _named_input("attention_mask")}, min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, ) for bs in (4, 2): - ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") - mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") with torch.no_grad(): ref = model(ids, attention_mask=mask) @@ -213,14 +155,55 @@ def test_dynamic_shapes_passthrough_mixed_args_and_kwargs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Mixed args/kwargs out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"name_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_name_dims_conflicting_ranges_raises(): + """Same name with different (min, max) across inputs is a user error.""" + from torch_tensorrt.dynamo._tracer import build_dim_registry + + seq = 16 + inputs = { + "input_ids": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + name_dims={0: "B"}, + ), + "attention_mask": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(8, seq), + max_shape=(8, seq), + dtype=torch.int64, + name="attention_mask", + name_dims={0: "B"}, + ), + } + with assertions.assertRaises(ValueError): + build_dim_registry((), inputs) + + +@pytest.mark.unit +def test_name_dims_rejected_on_static_axis(): + """Naming a static axis (min == max) is rejected at Input construction.""" + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=(1, 16), + opt_shape=(1, 16), + max_shape=(1, 16), + dtype=torch.int64, + name="x", + name_dims={0: "B"}, ) @pytest.mark.unit -def test_dynamic_shapes_default_path_unchanged_for_static_inputs(): - """Sanity check: when ``dynamic_shapes=None`` and inputs are fully static, - behavior is unchanged from the legacy path.""" +def test_default_path_unchanged_for_static_inputs(): + """Sanity check: a fully static input with no name_dims is unchanged.""" class StaticModel(nn.Module): def __init__(self): From e1cff6e8678e548b273845c379589f94611344b3 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 4 Jun 2026 10:58:59 -0700 Subject: [PATCH 3/5] adding testcase --- tests/py/dynamo/models/test_shared_dynamic_dim.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index 25ff7e292f..b0f7c94d9b 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -201,6 +201,20 @@ def test_name_dims_rejected_on_static_axis(): ) +@pytest.mark.unit +def test_name_dims_rejected_on_out_of_range_axis(): + """An axis index outside the input's rank is rejected at construction.""" + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(4, 16), + dtype=torch.int64, + name="x", + name_dims={5: "B"}, # rank is 2; axis 5 does not exist + ) + + @pytest.mark.unit def test_default_path_unchanged_for_static_inputs(): """Sanity check: a fully static input with no name_dims is unchanged.""" From 896857ba5c31c31212c4b73501969ff0bf3d9ed2 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 8 Jun 2026 13:02:12 -0700 Subject: [PATCH 4/5] replacing named_dims with shared_dims --- py/torch_tensorrt/_Input.py | 28 +++++++-------- py/torch_tensorrt/dynamo/_tracer.py | 16 ++++----- .../dynamo/models/test_shared_dynamic_dim.py | 36 +++++++++---------- 3 files changed, 40 insertions(+), 40 deletions(-) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 27ee329f69..4e078daac4 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -51,7 +51,7 @@ class _ShapeMode(Enum): torch_tensor: torch.Tensor = None name: str = "" is_shape_tensor: bool = False - name_dims: Dict[int, str] = ( + shared_dims: Dict[int, str] = ( {} ) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``). @@ -165,9 +165,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "if you try to run inference with empty tensor inputs." ) - if "name_dims" in kwargs and kwargs["name_dims"]: - self.name_dims = Input._parse_name_dims( - kwargs["name_dims"], self.shape + if "shared_dims" in kwargs and kwargs["shared_dims"]: + self.shared_dims = Input._parse_shared_dims( + kwargs["shared_dims"], self.shape ) else: @@ -175,9 +175,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) - if kwargs.get("name_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: + if kwargs.get("shared_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: raise ValueError( - "name_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " + "shared_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " "it has no meaning for a statically shaped Input." ) @@ -276,29 +276,29 @@ def equivalent_spec(a: Input, b: Input) -> bool: return all(checks) @staticmethod - def _parse_name_dims( - name_dims: Any, shape: Dict[str, Tuple[int, ...]] + def _parse_shared_dims( + shared_dims: Any, shape: Dict[str, Tuple[int, ...]] ) -> Dict[int, str]: - """Validate and normalize the ``name_dims`` mapping ({axis: name}). + """Validate and normalize the ``shared_dims`` mapping ({axis: name}). Each named axis must be a valid index into the shape and must be genuinely dynamic (``min != max``); a static axis cannot vary, so naming it for cross-input sharing is a user error. """ - if not isinstance(name_dims, dict): + if not isinstance(shared_dims, dict): raise TypeError( - f"name_dims must be a dict of {{axis_index: name}}, got {type(name_dims)}" + f"shared_dims must be a dict of {{axis_index: name}}, got {type(shared_dims)}" ) rank = len(shape["min_shape"]) parsed: Dict[int, str] = {} - for axis, dim_name in name_dims.items(): + for axis, dim_name in shared_dims.items(): if not isinstance(axis, int) or not (0 <= axis < rank): raise ValueError( - f"name_dims key {axis!r} is not a valid axis index for an input of rank {rank}" + f"shared_dims key {axis!r} is not a valid axis index for an input of rank {rank}" ) if not isinstance(dim_name, str) or not dim_name: raise ValueError( - f"name_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" + f"shared_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" ) if shape["min_shape"][axis] == shape["max_shape"][axis]: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 3cd809e296..399559347b 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -73,7 +73,7 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - # Build dynamic shapes from the Input objects. Inputs carrying name_dims + # Build dynamic shapes from the Input objects. Inputs carrying shared_dims # share a Dim across inputs via the registry; the rest get an independent # per-input Dim. dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) @@ -108,7 +108,7 @@ def _collect_inputs(obj: Any) -> list[Input]: def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: - """Build a ``{name: torch.export.Dim}`` registry from Input.name_dims. + """Build a ``{name: torch.export.Dim}`` registry from Input.shared_dims. The same name appearing on multiple inputs yields a single shared ``Dim`` instance, so ``torch.export`` treats those axes as one symbol. Conflicting @@ -117,13 +117,13 @@ def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: registry: dict[str, Any] = {} bounds: dict[str, tuple[int, int]] = {} for inp in _collect_inputs(arg_inputs) + _collect_inputs(kwarg_inputs): - name_dims = getattr(inp, "name_dims", None) - if not name_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: + shared_dims = getattr(inp, "shared_dims", None) + if not shared_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: continue assert isinstance(inp.shape, dict) min_shape = inp.shape["min_shape"] max_shape = inp.shape["max_shape"] - for axis, dim_name in name_dims.items(): + for axis, dim_name in shared_dims.items(): lo, hi = int(min_shape[axis]), int(max_shape[axis]) if dim_name in bounds: if bounds[dim_name] != (lo, hi): @@ -184,15 +184,15 @@ def get_dynamic_shapes( min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] - name_dims = getattr(input, "name_dims", None) or {} + shared_dims = getattr(input, "shared_dims", None) or {} assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue - elif dim_registry is not None and dim in name_dims: + elif dim_registry is not None and dim in shared_dims: # Named axis: reuse the shared Dim so axes with the same # name across inputs become a single exported symbol. - dynamic_dims[dim] = dim_registry[name_dims[dim]] + dynamic_dims[dim] = dim_registry[shared_dims[dim]] else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index b0f7c94d9b..9d1ac28fbb 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -1,6 +1,6 @@ # type: ignore """ -Tests for sharing a dynamic dimension across inputs via ``Input(name_dims=...)``. +Tests for sharing a dynamic dimension across inputs via ``Input(shared_dims=...)``. Background: when a model takes multiple inputs whose dynamic axes must be **equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask`` @@ -9,7 +9,7 @@ constraint check for any forward() that broadcasts across those axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising ``ConstraintViolationError``. -``Input(name_dims={axis: name})`` lets the caller tag a dynamic axis with a +``Input(shared_dims={axis: name})`` lets the caller tag a dynamic axis with a name; the same name across inputs is exported as a single shared ``Dim``. All the dynamic-shape intent lives on the ``Input`` objects -- no separate ``dynamic_shapes`` argument and no ``torch.export`` knowledge required at the @@ -54,14 +54,14 @@ def _named_input(name: str, seq: int = 16, batch_min: int = 1, batch_max: int = max_shape=(batch_max, seq), dtype=torch.int64, name=name, - name_dims={0: "B"}, + shared_dims={0: "B"}, ) @pytest.mark.unit @pytest.mark.critical -def test_name_dims_shared_batch_kwarg_inputs(): - """Shared batch axis declared via ``Input(name_dims={0: "B"})`` on both +def test_shared_dims_shared_batch_kwarg_inputs(): + """Shared batch axis declared via ``Input(shared_dims={0: "B"})`` on both kwarg inputs -- same name => one exported symbol; engine matches eager.""" model = _SharedBatchEncoder().eval().cuda() @@ -91,12 +91,12 @@ def test_name_dims_shared_batch_kwarg_inputs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"name_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"shared_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_name_dims_shared_batch_positional_inputs(): +def test_shared_dims_shared_batch_positional_inputs(): """Same feature with positional ``inputs=[...]`` instead of kwargs.""" model = _SharedBatchEncoder().eval().cuda() @@ -125,12 +125,12 @@ def test_name_dims_shared_batch_positional_inputs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"name_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"shared_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_name_dims_shared_batch_mixed_args_and_kwargs(): +def test_shared_dims_shared_batch_mixed_args_and_kwargs(): """input_ids passed positionally, attention_mask as a kwarg; both share "B".""" model = _SharedBatchEncoder().eval().cuda() @@ -155,12 +155,12 @@ def test_name_dims_shared_batch_mixed_args_and_kwargs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"name_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"shared_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_name_dims_conflicting_ranges_raises(): +def test_shared_dims_conflicting_ranges_raises(): """Same name with different (min, max) across inputs is a user error.""" from torch_tensorrt.dynamo._tracer import build_dim_registry @@ -172,7 +172,7 @@ def test_name_dims_conflicting_ranges_raises(): max_shape=(4, seq), dtype=torch.int64, name="input_ids", - name_dims={0: "B"}, + shared_dims={0: "B"}, ), "attention_mask": torchtrt.Input( min_shape=(1, seq), @@ -180,7 +180,7 @@ def test_name_dims_conflicting_ranges_raises(): max_shape=(8, seq), dtype=torch.int64, name="attention_mask", - name_dims={0: "B"}, + shared_dims={0: "B"}, ), } with assertions.assertRaises(ValueError): @@ -188,7 +188,7 @@ def test_name_dims_conflicting_ranges_raises(): @pytest.mark.unit -def test_name_dims_rejected_on_static_axis(): +def test_shared_dims_rejected_on_static_axis(): """Naming a static axis (min == max) is rejected at Input construction.""" with assertions.assertRaises(ValueError): torchtrt.Input( @@ -197,12 +197,12 @@ def test_name_dims_rejected_on_static_axis(): max_shape=(1, 16), dtype=torch.int64, name="x", - name_dims={0: "B"}, + shared_dims={0: "B"}, ) @pytest.mark.unit -def test_name_dims_rejected_on_out_of_range_axis(): +def test_shared_dims_rejected_on_out_of_range_axis(): """An axis index outside the input's rank is rejected at construction.""" with assertions.assertRaises(ValueError): torchtrt.Input( @@ -211,13 +211,13 @@ def test_name_dims_rejected_on_out_of_range_axis(): max_shape=(4, 16), dtype=torch.int64, name="x", - name_dims={5: "B"}, # rank is 2; axis 5 does not exist + shared_dims={5: "B"}, # rank is 2; axis 5 does not exist ) @pytest.mark.unit def test_default_path_unchanged_for_static_inputs(): - """Sanity check: a fully static input with no name_dims is unchanged.""" + """Sanity check: a fully static input with no shared_dims is unchanged.""" class StaticModel(nn.Module): def __init__(self): From a0eeae7b672f72c5e00a7fbdaba27e4e127e5550 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 8 Jun 2026 21:32:19 +0000 Subject: [PATCH 5/5] feat: multiple optimization profiles for disjoint input shape regimes Add support for defining N optimization profiles at compile time via the list-based ``Input.profiles`` API and selecting the active profile at runtime (manual pin by index, or opt-in shape-based auto-selection). - AOT (torch.export) compile path builds one TRT optimization profile per declared profile index; submodules inherit the profile count via propagation across graph breaks. - Python and C++ runtimes expose a matching primitive engine API (set_active_profile / num_optimization_profiles / _active_profile_index / _auto_select_profiles) so the two runtimes remain interchangeable. - Profile selection is exposed through the optimization_profile context manager; auto-selection uses lazy (first-fitting) profile selection. - Backward compatible: engines without declared profiles keep the historical single-profile (dynamic) / no-profile (static) behavior. Includes an example and runtime tests covering dynamic submodule inputs. --- core/runtime/TRTEngine.cpp | 95 ++++ core/runtime/TRTEngine.h | 27 ++ core/runtime/execute_engine.cpp | 12 + core/runtime/register_jit_hooks.cpp | 7 + docsrc/tutorials/runtime_opt/index.rst | 6 +- .../multi_optimization_profiles.rst | 164 +++++++ .../dynamo/multi_optimization_profiles.py | 265 +++++++++++ py/torch_tensorrt/_Input.py | 144 +++++- py/torch_tensorrt/dynamo/_compiler.py | 30 +- .../dynamo/conversion/_TRTInterpreter.py | 109 ++++- .../dynamo/partitioning/__init__.py | 1 + .../dynamo/partitioning/common.py | 186 +++++++- .../dynamo/runtime/_TRTEngine.py | 118 ++++- .../dynamo/runtime/_TorchTensorRTModule.py | 70 +++ py/torch_tensorrt/dynamo/utils.py | 64 +++ py/torch_tensorrt/runtime/__init__.py | 1 + .../runtime/_optimization_profile.py | 95 ++++ .../test_multi_optimization_profiles.py | 426 ++++++++++++++++++ 18 files changed, 1779 insertions(+), 41 deletions(-) create mode 100644 docsrc/tutorials/runtime_opt/multi_optimization_profiles.rst create mode 100644 examples/dynamo/multi_optimization_profiles.py create mode 100644 py/torch_tensorrt/runtime/_optimization_profile.py create mode 100644 tests/py/dynamo/runtime/test_multi_optimization_profiles.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 6fbb9c60f0..831e347cba 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -251,6 +251,10 @@ TRTEngine::TRTEngine( num_io = std::make_pair(inputs_size, outputs); } + // Reconstruct optimization-profile info (count + per-profile ranges) from the + // TRT API so multi-profile selection works for any loaded engine. + this->setup_optimization_profiles(); + #ifndef NDEBUG this->enable_profiling(); #endif @@ -512,6 +516,97 @@ void TRTEngine::reset_captured_graph() { cudagraph.reset(); } +void TRTEngine::setup_optimization_profiles() { + num_optimization_profiles = cuda_engine->getNbOptimizationProfiles(); + profile_dim_ranges.clear(); + is_shape_inference_io.clear(); + for (const auto& name : in_binding_names) { + is_shape_inference_io[name] = cuda_engine->isShapeInferenceIO(name.c_str()); + } + if (num_optimization_profiles <= 1) { + return; + } + // name -> [dim] -> [(min, max), ...] (one entry per optimization-profile index). + for (int64_t p = 0; p < num_optimization_profiles; ++p) { + for (const auto& name : in_binding_names) { + if (is_shape_inference_io[name]) { + continue; + } + auto dmin = + cuda_engine->getProfileShape(name.c_str(), static_cast(p), nvinfer1::OptProfileSelector::kMIN); + auto dmax = + cuda_engine->getProfileShape(name.c_str(), static_cast(p), nvinfer1::OptProfileSelector::kMAX); + auto& dims = profile_dim_ranges[name]; + if (dims.empty()) { + dims.resize(dmin.nbDims); + } + for (int d = 0; d < dmin.nbDims; ++d) { + dims[d].push_back(std::make_pair(dmin.d[d], dmax.d[d])); + } + } + } +} + +void TRTEngine::set_active_profile(int64_t profile_index) { + if (num_optimization_profiles <= 1) { + return; + } + if (profile_index == active_profile_index) { + return; + } + auto stream = c10::cuda::getCurrentCUDAStream(device_info.id); + // setOptimizationProfileAsync returns false for an out-of-range index; the + // index is validated upstream in TorchTensorRTModule.resolve_profile_index. + TORCHTRT_CHECK( + exec_ctx->setOptimizationProfileAsync(static_cast(profile_index), stream.stream()), + "Failed to switch to optimization profile index " << profile_index); + stream.synchronize(); + active_profile_index = profile_index; + // A profile switch invalidates any captured CUDA graph and changes the + // context state, so force re-record / shape re-inference on the next call. + runtime_states.context_changed = true; + reset_captured_graph(); + shape_key = "None"; + LOG_DEBUG("Switched to optimization profile index " << profile_index); +} + +int64_t TRTEngine::auto_select_profile(const std::vector& inputs) { + // Lazy selection: scan profiles in index order and return the first one whose + // [min, max] ranges contain every input shape. + for (int64_t p = 0; p < num_optimization_profiles; ++p) { + bool fits = true; + for (size_t i = 0; i < in_binding_names.size() && fits; ++i) { + const auto& name = in_binding_names[i]; + if (i >= inputs.size() || is_shape_inference_io[name]) { + continue; + } + auto ranges_it = profile_dim_ranges.find(name); + if (ranges_it == profile_dim_ranges.end()) { + continue; + } + const auto& dims = ranges_it->second; + auto sizes = inputs[i].sizes(); + for (size_t d = 0; d < sizes.size(); ++d) { + if (d < dims.size()) { + int64_t lo = dims[d][p].first; + int64_t hi = dims[d][p].second; + if (!(lo <= sizes[d] && sizes[d] <= hi)) { + fits = false; + break; + } + } + } + } + if (fits) { + return p; + } + } + TORCHTRT_THROW_ERROR( + "No optimization profile matches the input shapes. Fix the input shapes or pin a profile " + "explicitly via optimization_profile(module, index)."); + return 0; // unreachable +} + void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { if (new_strategy != this->resource_allocation_strategy) { this->resource_allocation_strategy = new_strategy; diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index c6d06dfb40..a9dbe426e4 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -218,6 +218,33 @@ struct TRTEngine : torch::CustomClassHolder { bool use_pre_allocated_outputs = false; std::vector pre_allocated_outputs; + // --- Multiple optimization profiles --- + // State and helpers mirror the Python runtime (TRTEngine in _TRTEngine.py) so + // the C++ and Python runtimes are interchangeable: the same attribute and + // method names are exposed via torchbind in register_jit_hooks.cpp + // (``num_optimization_profiles``, ``_active_profile_index``, + // ``_auto_select_profiles``, ``set_active_profile``). Index validation lives + // in the runtime-agnostic TorchTensorRTModule.resolve_profile_index. + int64_t num_optimization_profiles = 1; // cuda_engine->getNbOptimizationProfiles() + int64_t active_profile_index = 0; // profile currently loaded in exec_ctx + bool auto_select_profiles = false; // opt-in shape-based selection (per call) + // input name -> [dim index] -> per-profile [min, max]; cached from the TRT + // API. The dim axis is a dense vector indexed by dimension. + std::unordered_map>>> profile_dim_ranges; + std::unordered_map is_shape_inference_io; + + // Cache profile count + per-profile dim ranges purely from the TRT API + // (getNbOptimizationProfiles / getProfileShape) so selection works for any + // loaded engine with no extra serialized metadata. + void setup_optimization_profiles(); + // Switch the active TRT optimization profile (idempotent). + void set_active_profile(int64_t profile_index); + // Lazy / first-working: first profile whose [min, max] fits all input shapes. + // Called internally from the execute_engine run paths (guarded by + // num_optimization_profiles > 1 && auto_select_profiles); manual pins are + // applied eagerly via set_active_profile. + int64_t auto_select_profile(const std::vector& inputs); + // Single placeholder buffer for empty tensor inputs (allocated once, reused) void* empty_tensor_placeholder = nullptr; diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 6a070db3cf..d0973c465b 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -241,6 +241,13 @@ std::vector execute_engine(std::vector inputs, c10::intr auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); + // Auto-select the optimization profile from input shapes before validating + // shapes, so a profile switch's context_changed flag and shape_key reset are + // observed below. Only auto-selection runs per call; manual pins are applied + // eagerly via set_active_profile. + if (compiled_engine->num_optimization_profiles > 1 && compiled_engine->auto_select_profiles) { + compiled_engine->set_active_profile(compiled_engine->auto_select_profile(inputs)); + } bool shape_changed = _validate_shapes(inputs, compiled_engine); auto current_device_id = inputs.size() > 0 ? inputs[0].device().index() : at::cuda::current_device(); @@ -401,6 +408,11 @@ std::vector execute_engine(std::vector inputs, c10::intr }; auto run_output_allocator = [&]() { + // Auto-select the optimization profile from input shapes before binding + // them. Only auto-selection runs per call; manual pins are applied eagerly. + if (compiled_engine->num_optimization_profiles > 1 && compiled_engine->auto_select_profiles) { + compiled_engine->set_active_profile(compiled_engine->auto_select_profile(inputs)); + } { // Input Setup std::unique_ptr input_profiler_guard; if (compiled_engine->profile_execution) { diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 7eae8bfb91..4685bc25dd 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -40,6 +40,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned) .def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned) + // Multiple optimization profiles. Names match the Python runtime + // (_TRTEngine.py) so both runtimes are interchangeable behind + // TorchTensorRTModule / the optimization_profile context manager. + .def("set_active_profile", &TRTEngine::set_active_profile) + .def_readonly("num_optimization_profiles", &TRTEngine::num_optimization_profiles) + .def_readonly("_active_profile_index", &TRTEngine::active_profile_index) + .def_readwrite("_auto_select_profiles", &TRTEngine::auto_select_profiles) .def( "use_dynamically_allocated_resources", [](const c10::intrusive_ptr& self, bool dynamic) -> void { diff --git a/docsrc/tutorials/runtime_opt/index.rst b/docsrc/tutorials/runtime_opt/index.rst index 007bd1f645..2e451a5365 100644 --- a/docsrc/tutorials/runtime_opt/index.rst +++ b/docsrc/tutorials/runtime_opt/index.rst @@ -2,7 +2,9 @@ Runtime Optimization ===================== Optimize inference throughput and latency: CUDA Graphs for kernel-replay, -pre-allocated output buffers, and choosing the Python vs C++ TRT execution path. +pre-allocated output buffers, multiple optimization profiles for distinct shape +regimes (e.g. LLM prefill/decode), and choosing the Python vs C++ TRT execution +path. .. toctree:: :maxdepth: 1 @@ -10,4 +12,6 @@ pre-allocated output buffers, and choosing the Python vs C++ TRT execution path. cuda_graphs Example: Torch Export with Cudagraphs <../_rendered_examples/dynamo/torch_export_cudagraphs> Example: Pre-allocated output buffer <../_rendered_examples/dynamo/pre_allocated_output_example> + multi_optimization_profiles + Example: Multiple optimization profiles (prefill/decode) <../_rendered_examples/dynamo/multi_optimization_profiles> Python vs C++ runtime diff --git a/docsrc/tutorials/runtime_opt/multi_optimization_profiles.rst b/docsrc/tutorials/runtime_opt/multi_optimization_profiles.rst new file mode 100644 index 0000000000..7922a90a17 --- /dev/null +++ b/docsrc/tutorials/runtime_opt/multi_optimization_profiles.rst @@ -0,0 +1,164 @@ +.. _multi_optimization_profiles_tutorial: + +Multiple Optimization Profiles (Prefill / Decode) +================================================= + +TensorRT tunes kernels for the **optimization profile** of an engine: a +``[min, opt, max]`` range for every dynamic input dimension. Kernels are tuned at +the ``opt`` point, so a single profile can only be optimal for one shape. + +Many models, however, run in several distinct shape *regimes* that share the same +weights. The canonical case is an autoregressive LLM: + +* **prefill** -- the prompt is processed in one shot, so ``seq`` is large, and +* **decode** -- tokens are generated one at a time, so ``seq == 1``. + +With a single dynamic range ``seq in [1, max]`` you must pick one ``opt``. Tuning +for the long prefill length leaves **decode** -- the latency-critical, most +frequently executed phase -- running on kernels chosen for a sequence length it +never sees. + +Torch-TensorRT lets you declare **multiple optimization profiles** on a single +input and select the active one at runtime. The engine is built once and each +profile is tuned independently. + +Declaring profiles +------------------ + +Pass an ordered list of ``{"min", "opt", "max"}`` dicts to +:class:`torch_tensorrt.Input` via ``profiles``. The **list index** is the +optimization-profile index you select at runtime. + +.. code-block:: python + + import torch + import torch_tensorrt + + DECODE_IDX, PREFILL_IDX = 0, 1 + + profiled_input = torch_tensorrt.Input( + dtype=torch.int64, + profiles=[ + # index 0 -> decode: seq pinned to 1 (a fully static profile) + {"min": (1, 1), "opt": (1, 1), "max": (1, 1)}, + # index 1 -> prefill: seq in [1, 512], tuned at 256 + {"min": (1, 1), "opt": (1, 256), "max": (1, 512)}, + ], + ) + +``profiles`` is mutually exclusive with the single-shape ``min_shape`` / +``opt_shape`` / ``max_shape`` (and ``shape``) arguments. + +The **union envelope** +~~~~~~~~~~~~~~~~~~~~~~~~ + +``torch.export`` traces a model over one ``[min, opt, max]`` range, so +``Input`` automatically derives the **union envelope** of all profiles +(elementwise ``min`` of every ``min`` and ``max`` of every ``max``; ``opt`` is +taken from the first profile). Each declared profile is a subset of this +envelope. You export over the envelope and the individual profiles become the +per-profile TensorRT tunings: + +.. code-block:: python + + print(profiled_input.shape["min_shape"]) # (1, 1) + print(profiled_input.shape["max_shape"]) # (1, 512) + +Compile +------- + +Export once over the union range, then compile as usual. Every input that +declares ``profiles`` must declare the **same number** of profiles; static +inputs (or dynamic inputs without ``profiles``) reuse their single shape in every +profile. + +.. code-block:: python + + seq = torch.export.Dim("seq", min=1, max=512) + exported = torch.export.export(model, (example_ids,), dynamic_shapes=({1: seq},)) + + trt_model = torch_tensorrt.dynamo.compile( + exported, + arg_inputs=[profiled_input], + enabled_precisions={torch.float16}, + min_block_size=1, + ) + +Selecting a profile at runtime +------------------------------ + +Selection is **manual by default**. Use the +:func:`torch_tensorrt.runtime.optimization_profile` context manager to pin a +profile by index for the duration of a ``with`` block; the prior state is saved +on enter and restored on exit, so blocks nest cleanly. + +.. code-block:: python + + from torch_tensorrt.runtime import optimization_profile + + with optimization_profile(trt_model, DECODE_IDX): + logits = trt_model(decode_ids) # seq == 1 + + with optimization_profile(trt_model, PREFILL_IDX): + logits = trt_model(prefill_ids) # seq == 256 + +Pass ``"auto"`` to let Torch-TensorRT choose from the input shapes. Auto-selection +is **lazy / first-working**: it scans profiles in index order and uses the first +whose ``[min, max]`` contains the input. Order matters when profiles overlap -- +declaring ``decode`` first lets it win the ``seq == 1`` overlap: + +.. code-block:: python + + with optimization_profile(trt_model, "auto"): + trt_model(decode_ids) # seq == 1 -> index 0 (decode) accepts -> decode + trt_model(prefill_ids) # seq == 256 -> index 0 rejects -> index 1 (prefill) + +Profiles, graph breaks, and serialization +----------------------------------------- + +* **Graph breaks**: when a model is partitioned into several TensorRT engines, + every engine carries the same number of profiles. Torch-TensorRT propagates the + per-profile bounds across the break, evaluating any *derived* dynamic dimension + (e.g. a ``reshape`` that turns ``seq`` into ``16 * seq``) through to the + downstream engine, so runtime selection stays consistent for the whole module. +* **Serialization / runtimes**: profile state is reconstructed from the TensorRT + API on load (``getNbOptimizationProfiles`` / ``getProfileShape``), so a + serialized engine keeps its profiles with no extra metadata. The same + ``optimization_profile`` API drives both the C++ and Python runtimes, which + remain interchangeable. + +Why it helps: a worked latency example +-------------------------------------- + +The example :ref:`multi_optimization_profiles` compiles ``google/gemma-3-1b-it`` +twice -- once with a single profile (tuned at the prefill length) and once with +separate decode/prefill profiles -- then compares per-call latency. The +multi-profile engine dedicates a **static** profile (``seq`` pinned to 1) to +decode, letting TensorRT specialize that path (measured on an NVIDIA A40, FP16): + +.. code-block:: text + + Per-call latency (ms), batch=1 + regime single-profile multi-profile speedup + -------------------------------------------------------------- + decode (seq=1) 5.232 4.597 1.14x + prefill (seq=128) 7.152 7.534 0.95x + +Prefill is essentially unchanged (both engines tune it at the same ``opt``), +while decode -- the regime executed once per generated token -- is faster. Exact +numbers depend on the model and GPU; the takeaway is that one engine can be tuned +well for *both* regimes instead of compromising on a single ``opt`` shape. + +.. note:: + + Because the model has two dynamic inputs (``input_ids`` and ``position_ids``), + the example passes one profiled ``Input`` for each, both declaring the same + profiles. The HuggingFace attention path also needs a TensorRT-friendly SDPA + lowering (``tools/llm/torchtrt_ext/register_sdpa``), and ``gemma-3-1b-it`` is a + gated model requiring Hugging Face authentication. + +.. seealso:: + + - Runnable example: :ref:`multi_optimization_profiles` + - :class:`torch_tensorrt.Input` + - :func:`torch_tensorrt.runtime.optimization_profile` diff --git a/examples/dynamo/multi_optimization_profiles.py b/examples/dynamo/multi_optimization_profiles.py new file mode 100644 index 0000000000..f94dbbc860 --- /dev/null +++ b/examples/dynamo/multi_optimization_profiles.py @@ -0,0 +1,265 @@ +""" +.. _multi_optimization_profiles: + +Multiple Optimization Profiles: Prefill vs Decode for Gemma-3 +============================================================= + +Autoregressive LLMs run in two very different shape *regimes* that share one set +of weights (and ideally one engine): + +- **prefill**: the prompt is processed in one shot, so the sequence length + ``seq`` is large, and +- **decode**: tokens are generated one at a time, so ``seq == 1``. + +A single dynamic range ``seq in [1, max]`` works, but TensorRT can only tune +kernels for **one** ``opt`` point. Tuning for the prefill length leaves decode +(the latency-critical, most-frequently-executed phase) running on kernels picked +for a sequence it never sees. + +``torch_tensorrt.Input(profiles=[...])`` declares **N optimization profiles** on +a single input. The engine is built once (a single ``torch.export`` over the +union of all profiles), each profile gets its own TensorRT kernel tuning, and you +select the active profile per call (by index, or ``"auto"``). + +This example compiles `google/gemma-3-1b-it +`_ **twice** -- once with a single +profile and once with separate prefill/decode profiles -- and compares the decode +and prefill latency of the two engines. + +.. note:: + + ``google/gemma-3-1b-it`` is a **gated** model: you must accept its license on + the Hugging Face Hub and authenticate (``hf auth login`` or the ``HF_TOKEN`` + environment variable) before running this example. It downloads ~2 GB of + weights on first use and requires a CUDA GPU. + +.. note:: + + This uses the Ahead-Of-Time (AOT) ``torch.export`` + ``dynamo.compile`` path. + Runtime profile selection works with whichever TensorRT runtime (C++ or + Python) the installed Torch-TensorRT build provides. +""" + +# %% +# Imports and Setup +# ^^^^^^^^^^^^^^^^^^ +# +# The HuggingFace attention path needs a TensorRT-friendly SDPA lowering. The +# reusable LLM helpers ``register_sdpa`` (a Gemma-3-specific SDPA pass) and +# ``export_llm`` live under ``tools/llm`` in the Torch-TensorRT repo, so we add +# that directory to ``sys.path``. + +import sys +import timeit +from pathlib import Path + +import torch +import torch_tensorrt + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT / "tools" / "llm")) + +MODEL_ID = "google/gemma-3-1b-it" +DEVICE = torch.device("cuda:0") + +# The two regimes we benchmark. +MAX_SEQ = 256 # largest prompt the engine must support +PREFILL_SEQ = 128 +DECODE_SEQ = 1 +DECODE_IDX, PREFILL_IDX = 0, 1 + + +# %% +# Load the Model +# ^^^^^^^^^^^^^^ +# +# Load with ``use_cache=False`` (this example recomputes over the full sequence +# rather than using a KV cache, which keeps the export simple) and the ``sdpa`` +# attention implementation, then register the Gemma-3 SDPA lowering pass. +def load_model(): + from transformers import AutoModelForCausalLM + + with torch.no_grad(): + model = ( + AutoModelForCausalLM.from_pretrained( + MODEL_ID, + use_cache=False, + attn_implementation="sdpa", + ignore_mismatched_sizes=True, + ) + .eval() + .cuda() + .to(torch.float16) + ) + from torchtrt_ext import register_sdpa + + register_sdpa.enable_sdpa_converter(MODEL_ID, model.config) + return model + + +try: + model = load_model() +except Exception as e: # gated/no-auth/no-GPU environments (e.g. CI docs build) + print(f"Skipping example: could not load {MODEL_ID} ({type(e).__name__}: {e}).") + print("Accept the license and authenticate (hf auth login / HF_TOKEN) to run.") + sys.exit(0) + + +def make_inputs(seq_len: int): + ids = torch.randint(1, 10000, (1, seq_len), dtype=torch.int64, device=DEVICE) + position_ids = torch.arange(seq_len, device=DEVICE).unsqueeze(0) + return ids, position_ids + + +# %% +# Declaring the Optimization Profiles +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# ``profiles`` is an ordered list; the list index is the optimization-profile +# index used at runtime. Both model inputs (``input_ids`` and ``position_ids``) +# are dynamic over ``seq``, so each gets a profiled ``Input`` with identical +# profiles: +# +# - index ``0`` -> **decode**: ``seq`` pinned to 1 (a fully static profile) +# - index ``1`` -> **prefill**: ``seq`` in ``[1, MAX_SEQ]``, tuned at ``PREFILL_SEQ`` +# +# Profile order matters for auto-selection: the profiles overlap at ``seq == 1`` +# and auto-selection picks the *first* profile whose ``[min, max]`` accepts the +# input, so declaring ``decode`` first lets it win the ``seq == 1`` overlap. +profiles = [ + {"min": (1, 1), "opt": (1, 1), "max": (1, 1)}, # decode + {"min": (1, 1), "opt": (1, PREFILL_SEQ), "max": (1, MAX_SEQ)}, # prefill +] +multi_profile_inputs = [ + torch_tensorrt.Input(dtype=torch.int64, profiles=profiles), # input_ids + torch_tensorrt.Input(dtype=torch.int64, profiles=profiles), # position_ids +] + +# %% +# Export Once, Compile Twice +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# ``export_llm`` traces the model over a dynamic ``seq`` range. We reuse the +# exported program for both the single-profile baseline (tuned at the prefill +# length, the conventional choice) and the multi-profile engine. +from utils import export_llm # noqa: E402 + +example_ids, _ = make_inputs(PREFILL_SEQ) +with torch.inference_mode(): + exported = export_llm(model, example_ids, min_seq_len=1, max_seq_len=MAX_SEQ) + +# ``offload_module_to_cpu`` must stay False here: it is currently incompatible +# with the multi-profile ``Input(profiles=...)`` path (CPU/CUDA device mismatch). +common = dict( + use_fp32_acc=True, + disable_tf32=True, + offload_module_to_cpu=False, + min_block_size=1, + require_full_compilation=True, + device=DEVICE, +) + +print("Compiling single-profile engine (tuned at prefill length) ...") +bench_ids, bench_pos = make_inputs(PREFILL_SEQ) +trt_single = torch_tensorrt.dynamo.compile( + exported, inputs=[bench_ids, bench_pos], **common +) + +print("Compiling multi-profile engine (decode + prefill) ...") +trt_multi = torch_tensorrt.dynamo.compile( + exported, arg_inputs=multi_profile_inputs, **common +) + + +# %% +# Correctness +# ^^^^^^^^^^^ +# +# FP16 logits over Gemma's 262K-token vocabulary are noisy, so we compare the +# *predicted token* (argmax) rather than raw logits. +def logits(out): + return (out.logits if hasattr(out, "logits") else out).float() + + +from torch_tensorrt.runtime import optimization_profile # noqa: E402 + +decode_ids, decode_pos = make_inputs(DECODE_SEQ) +prefill_ids, prefill_pos = make_inputs(PREFILL_SEQ) + +with torch.inference_mode(): + ref_decode = logits(model(decode_ids, position_ids=decode_pos)) + ref_prefill = logits(model(prefill_ids, position_ids=prefill_pos)) + + with optimization_profile(trt_multi, DECODE_IDX): + trt_decode = logits(trt_multi(decode_ids, decode_pos)) + with optimization_profile(trt_multi, PREFILL_IDX): + trt_prefill = logits(trt_multi(prefill_ids, prefill_pos)) + + +def top1_match(a, b): + return (a.argmax(-1) == b.argmax(-1)).float().mean().item() + + +print(f"decode top-1 token match vs eager: {top1_match(trt_decode, ref_decode):.1%}") +print(f"prefill top-1 token match vs eager: {top1_match(trt_prefill, ref_prefill):.1%}") + + +# %% +# Latency Comparison +# ^^^^^^^^^^^^^^^^^^^ +# +# Time each regime on each engine. For the multi-profile engine we pin the +# matching profile around the loop (the realistic serving pattern). We report the +# min over several rounds to reduce noise. +def benchmark(run, iters: int = 50, warmup: int = 20, rounds: int = 3) -> float: + for _ in range(warmup): + run() + torch.cuda.synchronize() + best = float("inf") + for _ in range(rounds): + start = timeit.default_timer() + for _ in range(iters): + run() + torch.cuda.synchronize() + best = min(best, (timeit.default_timer() - start) / iters * 1000) # ms/call + return best + + +with torch.inference_mode(): + single_decode = benchmark(lambda: trt_single(decode_ids, decode_pos)) + single_prefill = benchmark(lambda: trt_single(prefill_ids, prefill_pos)) + with optimization_profile(trt_multi, DECODE_IDX): + multi_decode = benchmark(lambda: trt_multi(decode_ids, decode_pos)) + with optimization_profile(trt_multi, PREFILL_IDX): + multi_prefill = benchmark(lambda: trt_multi(prefill_ids, prefill_pos)) + +# %% +# Results. Decode is the win: the multi-profile engine dedicates a *static* +# profile (``seq`` pinned to 1) to decode, so TensorRT specializes that path +# instead of serving it from kernels tuned for the long prefill length. Prefill +# is unchanged (both engines tune it at the same ``opt``). +print("\nPer-call latency (ms), batch=1") +print(f"{'regime':<20}{'single-profile':>16}{'multi-profile':>16}{'speedup':>10}") +print("-" * 62) +print( + f"{f'decode (seq={DECODE_SEQ})':<20}{single_decode:>16.3f}" + f"{multi_decode:>16.3f}{single_decode / multi_decode:>9.2f}x" +) +print( + f"{f'prefill (seq={PREFILL_SEQ})':<20}{single_prefill:>16.3f}" + f"{multi_prefill:>16.3f}{single_prefill / multi_prefill:>9.2f}x" +) + +# %% +# Summary +# ^^^^^^^ +# +# - Declare ``N`` profiles on an ``Input`` with ``profiles=[{min, opt, max}, ...]`` +# (one per dynamic model input -- here ``input_ids`` and ``position_ids``). +# - One export + one engine; each profile gets its own TensorRT kernel tuning. +# - Select at runtime by **index** (``optimization_profile(m, i)``) or let +# ``"auto"`` pick the first profile that fits the input shapes. +# - Dedicating a static ``seq == 1`` profile to decode lets TensorRT tune that +# latency-critical path independently of the prefill length. + +print("Done.") diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 4e078daac4..d0621ad330 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -54,6 +54,15 @@ class _ShapeMode(Enum): shared_dims: Dict[int, str] = ( {} ) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``). + #: Optional ordered optimization profiles for multi-profile engines. A list + #: of dicts, one per profile; the list index is the TRT optimization-profile + #: index used to select the profile at runtime. Each entry has ``min`` / + #: ``opt`` / ``max`` shape tuples. ``None`` for the default zero/one-profile + #: behavior. ``profiles`` may be combined with ``shared_dims``: the profiles + #: define the per-profile TRT ranges while ``shared_dims`` names dynamic axes + #: on the union envelope so they export as one shared ``torch.export.Dim`` + #: across inputs. + profiles: Optional[List[Dict[str, Tuple[int, ...]]]] = None def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -82,8 +91,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW """ - # Compatibility code for switching over from InputTensorSpec - if "shape" in kwargs and "shape_ranges" in kwargs: + # Multi optimization profile support. ``profiles`` is validated and + # translated into the union ``min_shape`` / ``opt_shape`` / ``max_shape`` + # kwargs so the regular dynamic-shape path below constructs the Input + # unchanged. + if kwargs.get("profiles") is not None: + self._init_from_profiles(args, kwargs) + + # Compatibility code for switching over from InputTensorSpec. + # Mutually exclusive with `profiles` above (which translates into + # min/opt/max kwargs and forbids `shape`/`shape_ranges`). + elif "shape" in kwargs and "shape_ranges" in kwargs: assert ( len(kwargs["shape_ranges"]) == 1 and len(kwargs["shape_ranges"][0]) == 3 ) @@ -217,6 +235,114 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if "name" in kwargs: self.name = kwargs["name"] + def _init_from_profiles(self, args: Any, kwargs: Any) -> None: + """Validate ``profiles`` and translate them into union min/opt/max kwargs. + + ``profiles`` (``kwargs["profiles"]``) is an ordered list of dicts (one per + profile); the list index is the TRT optimization-profile index selected at + runtime. Each entry has ``min`` / ``opt`` / ``max`` shape tuples. This + stores the normalized list on ``self.profiles`` and fills ``kwargs`` with + the elementwise union envelope (the range ``torch.export`` must cover), so + the regular dynamic-shape path in ``__init__`` builds the Input. + + ``shared_dims`` is intentionally *not* mutually exclusive with + ``profiles``: profiles describe the per-profile TRT shape ranges, whereas + ``shared_dims`` names dynamic axes so they export as a single shared + ``torch.export.Dim`` across inputs. The two compose -- ``shared_dims`` is + left in ``kwargs`` and validated against the union envelope by the regular + dynamic-shape path (a named axis must be dynamic in the union). + """ + profiles = kwargs["profiles"] + # ``profiles`` is the only *shape* specifier when used (``shared_dims`` is + # an axis-naming modifier, not a shape, so it is allowed alongside). + if len(args) != 0: + raise ValueError( + "Found both a positional shape argument and `profiles`. " + "class Input expects `profiles` to be the only shape specifier when used." + ) + if any( + k in kwargs + for k in ["shape", "min_shape", "opt_shape", "max_shape", "shape_ranges"] + ): + raise ValueError( + "`profiles` is mutually exclusive with `shape` / `min_shape` / " + "`opt_shape` / `max_shape`. Specify only one shape declaration on an Input." + ) + if not isinstance(profiles, (list, tuple)) or len(profiles) == 0: + raise ValueError( + "`profiles` must be a non-empty list of dicts, each with keys " + "'min', 'opt', 'max'. The list index is the optimization-profile " + "index used to select the profile at runtime." + ) + + normalized: List[Dict[str, Tuple[int, ...]]] = [] + rank: Optional[int] = None + for i, prof in enumerate(profiles): + if not isinstance(prof, dict) or not all( + k in prof for k in ("min", "opt", "max") + ): + raise ValueError( + f"Profile at index {i} must be a dict with keys 'min', 'opt', 'max'" + ) + + for field_name in ("min", "opt", "max"): + if not Input._supported_input_size_type(prof[field_name]): + raise TypeError( + f"Profile at index {i} field '{field_name}' must be a List, " + f"tuple or torch.Size, found {type(prof[field_name])}" + ) + + min_shape = tuple(prof["min"]) + opt_shape = tuple(prof["opt"]) + max_shape = tuple(prof["max"]) + + if not (len(min_shape) == len(opt_shape) == len(max_shape)): + raise ValueError( + f"Profile at index {i} min/opt/max shapes must have the same number " + f"of dimensions, found {len(min_shape)}/{len(opt_shape)}/{len(max_shape)}" + ) + if rank is None: + rank = len(min_shape) + elif rank != len(min_shape): + raise ValueError( + "All profiles on an Input must have the same number of dimensions. " + f"Profile at index {i} has {len(min_shape)}, expected {rank}." + ) + + for d in range(len(min_shape)): + # No min=0 (and TRT requires >= 1). + if min_shape[d] < 1: + raise ValueError( + f"Profile at index {i} min_shape[{d}]={min_shape[d]} is invalid; " + "every dimension must have min >= 1 (min=0 is not supported)." + ) + if not (min_shape[d] <= opt_shape[d] <= max_shape[d]): + raise ValueError( + f"Profile at index {i} requires min <= opt <= max element-wise. " + f"Got min={min_shape[d]}, opt={opt_shape[d]}, max={max_shape[d]} at dim {d}." + ) + + normalized.append( + { + "min": min_shape, + "opt": opt_shape, + "max": max_shape, + } + ) + + self.profiles = normalized + + # Derive the export envelope: elementwise union over every profile. opt + # is taken from the first declared profile (the shape export will trace / + # specialize at). The regular dynamic-shape path in ``__init__`` consumes + # these to set ``self.shape`` and ``shape_mode``. + assert rank is not None + union_min = [min(p["min"][d] for p in normalized) for d in range(rank)] + union_max = [max(p["max"][d] for p in normalized) for d in range(rank)] + kwargs["min_shape"] = tuple(union_min) + kwargs["opt_shape"] = tuple(normalized[0]["opt"]) + kwargs["max_shape"] = tuple(union_max) + def __str__(self) -> str: if self.shape_mode == Input._ShapeMode.STATIC: return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format( @@ -228,10 +354,21 @@ def __str__(self) -> str: ) elif self.shape_mode == Input._ShapeMode.DYNAMIC: if isinstance(self.shape, dict): - return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format( + profiles_str = ( + ", profiles={}".format( + [ + {k: tuple(v) for k, v in prof.items()} + for prof in self.profiles + ] + ) + if self.profiles + else "" + ) + return "Input(min_shape={}, opt_shape={}, max_shape={}{}, dtype={}, format={}, domain=[{}, {}))".format( self.shape["min_shape"], self.shape["opt_shape"], self.shape["max_shape"], + profiles_str, str(self.dtype), str(self.format), str(self.tensor_domain[0]), @@ -259,6 +396,7 @@ def equivalent_spec(a: Input, b: Input) -> bool: a.shape["min_shape"] == b.shape["min_shape"], a.shape["opt_shape"] == b.shape["opt_shape"], a.shape["max_shape"] == b.shape["max_shape"], + a.profiles == b.profiles, a.dtype == b.dtype, a.format == b.format, a.low_tensor_domain_incl == b.low_tensor_domain_incl, diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 62808a9498..3d6de6efb2 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -52,6 +52,7 @@ prepare_inputs, to_torch_device, to_torch_tensorrt_device, + validate_optimization_profiles, ) logger = logging.getLogger(__name__) @@ -759,6 +760,7 @@ def compile( } logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) + logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( @@ -1052,6 +1054,24 @@ def preserve_module_specs( submodule_node_dict[node.name] = node preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) + + # Multi-profile propagation: build the map from export source + # symbols to per-profile bounds once, from the top-level inputs, then reuse + # it to attach the same profiles (by index) to every TRT submodule's inputs. + top_level_inputs: List[Input] = list(sample_arg_inputs) + if isinstance(sample_kwarg_inputs, dict): + top_level_inputs.extend(sample_kwarg_inputs.values()) + num_profiles = validate_optimization_profiles(top_level_inputs) + profile_source_bounds = None + if num_profiles: + logger.info( + f"Building engine(s) with {num_profiles} " + "optimization profiles (selected by index)" + ) + profile_source_bounds = partitioning.build_profile_source_bounds( + partitioned_module, top_level_inputs, num_profiles + ) + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -1109,8 +1129,14 @@ def preserve_module_specs( ] ) - # Get the submodule inputs for min, opt, max shapes of the graph inputs - submodule_inputs = partitioning.construct_submodule_inputs(submodule) + # Get the submodule inputs for min, opt, max shapes of the graph inputs. + # With multi-profile compile, propagate the profiles (by index) to each + # submodule input by symbolic substitution. + submodule_inputs = partitioning.construct_submodule_inputs( + submodule, + profile_source_bounds=profile_source_bounds, + num_profiles=num_profiles, + ) assert submodule_inputs is not None diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 1b7982f074..3d715bd6a9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -53,6 +53,7 @@ deallocate_module, get_cpu_memory_usage, to_torch_device, + validate_optimization_profiles, ) from torch_tensorrt.logging import TRT_LOGGER @@ -120,12 +121,28 @@ def __init__( + "\n".join(f"{i}" for i in missing_ops) ) + # Optimization profiles. Profiles are an ordered list on + # ``Input.profiles``; profile index i is built from each input's + # ``profiles[i]``. The count is derived from the input specs (submodule + # inputs inherit the same number of profiles via propagation). It is + # ``0`` when no input declares ``profiles``, in which case we fall back + # to the historical behavior: a single profile for dynamic inputs and + # none for fully static engines. + self.optimization_profile_count: int = validate_optimization_profiles( + input_specs + ) + has_dynamic_input = any( + input_spec.shape_mode == Input._ShapeMode.DYNAMIC + for input_spec in input_specs + ) + num_profiles = ( + self.optimization_profile_count + if self.optimization_profile_count + else (1 if has_dynamic_input else 0) + ) self.optimization_profiles: Optional[List[trt.IOptimizationProfile]] = ( - [self.builder.create_optimization_profile()] - if any( - input_spec.shape_mode == Input._ShapeMode.DYNAMIC - for input_spec in input_specs - ) + [self.builder.create_optimization_profile() for _ in range(num_profiles)] + if num_profiles > 0 else None ) @@ -697,6 +714,36 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node: return trt_node + def _per_profile_shapes( + self, current_input: Input + ) -> List[Tuple[Sequence[int], Sequence[int], Sequence[int]]]: + """Return ``(min, opt, max)`` shapes for each optimization profile index. + + - If multiple profiles are active and the input declares ``profiles``, + each entry uses the input's range at that profile index. + - If multiple profiles are active but the input has none (static-like + input reused across regimes), every entry repeats the union range. + - If no profiles are active, there is a single entry (the union range) — + the historical single-profile behavior. + """ + assert isinstance(current_input.shape, dict) + union = ( + current_input.shape["min_shape"], + current_input.shape["opt_shape"], + current_input.shape["max_shape"], + ) + if not self.optimization_profile_count: + return [union] + + result: List[Tuple[Sequence[int], Sequence[int], Sequence[int]]] = [] + for i in range(self.optimization_profile_count): + if current_input.profiles and i < len(current_input.profiles): + prof = current_input.profiles[i] + result.append((prof["min"], prof["opt"], prof["max"])) + else: + result.append(union) + return result + def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: self._input_names.append(target) current_input = self.input_specs[self.input_specs_iter] @@ -706,27 +753,43 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: if current_input.shape_mode == Input._ShapeMode.DYNAMIC: assert isinstance(current_input.shape, dict) shape = [] - min_shape = current_input.shape["min_shape"] - opt_shape = current_input.shape["opt_shape"] - max_shape = current_input.shape["max_shape"] - # TODO: Does not support disjoint optimization profiles? assert self.optimization_profiles is not None - assert len(min_shape) == len(opt_shape) == len(max_shape) + + # Build the per-profile (min, opt, max) shapes for this input. Each + # TRT optimization profile index gets the input's corresponding + # profile range; inputs without profiles (e.g. static tensors reused + # across regimes) repeat their single union range in every profile. + per_profile_shapes = self._per_profile_shapes(current_input) + + for profile_idx, opt_profile in enumerate(self.optimization_profiles): + min_shape, opt_shape, max_shape = per_profile_shapes[profile_idx] + assert len(min_shape) == len(opt_shape) == len(max_shape) + if current_input.is_shape_tensor: + # For shape_tensors, min/opt/max_shapes correspond to actual + # values of the shapes provided during runtime. + opt_profile.set_shape_input(target, min_shape, opt_shape, max_shape) + else: + opt_profile.set_shape(target, min_shape, opt_shape, max_shape) + + # The INetwork input shape uses the union envelope to mark which + # dims are dynamic (-1). A dim is static only if it is identical + # across every profile's min/opt/max. + union_min = current_input.shape["min_shape"] + union_opt = current_input.shape["opt_shape"] + union_max = current_input.shape["max_shape"] if current_input.is_shape_tensor: - # For shape_tensors, min/opt/max_shapes correspond to actual values - # of the shapes provided during runtime - self.optimization_profiles[0].set_shape_input( - target, min_shape, opt_shape, max_shape - ) - shape.append(len(opt_shape)) + shape.append(len(union_opt)) else: - self.optimization_profiles[0].set_shape( - target, min_shape, opt_shape, max_shape - ) - - for i in range(len(min_shape)): - if min_shape[i] == opt_shape[i] == max_shape[i]: - shape.append(min_shape[i]) + for i in range(len(union_min)): + dim_is_static = all( + per_profile_shapes[p][0][i] + == per_profile_shapes[p][1][i] + == per_profile_shapes[p][2][i] + == per_profile_shapes[0][0][i] + for p in range(len(per_profile_shapes)) + ) + if dim_is_static: + shape.append(union_min[i]) else: # -1 to represent the dynamic dimension shape.append(DYNAMIC_DIM) diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py index 4ef0c271d1..c1e256a4be 100644 --- a/py/torch_tensorrt/dynamo/partitioning/__init__.py +++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py @@ -2,6 +2,7 @@ from ._global_partitioner import partition as global_partition from ._hierarchical_partitioner import hierarchical_adjacency_partition from .common import ( + build_profile_source_bounds, construct_submodule_inputs, get_graph_converter_support, run_shape_analysis, diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 27d8974784..099becbadf 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -1,30 +1,101 @@ import logging -from typing import Any, Dict, Optional, Sequence, Set, Tuple +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple import torch from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily - from torch_tensorrt._Input import Input from torch_tensorrt.dynamo.utils import ( COMPLEX_TO_REAL_DTYPE, contains_sym_int, extract_var_range_info, + extract_var_range_info_for_profile, ) logger = logging.getLogger(__name__) +# Per-profile source-symbol bounds, indexed by optimization-profile index: +# [ {symbol_name: {"min": int, "opt": int, "max": int}}, ... ] +# Built at the top level from ``Input.profiles`` and the export symbols, then +# propagated to submodule inputs by substituting into their SymInt expressions. +ProfileSourceBounds = List[Dict[str, Dict[str, int]]] + + +def _build_submodule_profiles( + input_shape: torch.Size, + union_min: Sequence[int], + union_opt: Sequence[int], + union_max: Sequence[int], + profile_source_bounds: ProfileSourceBounds, + num_profiles: int, +) -> Optional[List[Dict[str, Tuple[int, ...]]]]: + """Evaluate per-profile min/opt/max for a (possibly symbolic) submodule shape. + + For each profile (by index), substitute the profile's source-symbol values + into each SymInt dim and evaluate. Static dims and dims whose + symbols cannot be resolved fall back to the union value. Returns an ordered + list of profile dicts (one per profile index). + """ + profiles: List[Dict[str, Tuple[int, ...]]] = [] + for i in range(num_profiles): + sym_bounds = profile_source_bounds[i] if i < len(profile_source_bounds) else {} + p_min: List[int] = [] + p_opt: List[int] = [] + p_max: List[int] = [] + for d, dim in enumerate(input_shape): + if isinstance(dim, torch.SymInt): + try: + p_min.append( + extract_var_range_info_for_profile( + dim, {s: b["min"] for s, b in sym_bounds.items()} + ) + ) + p_opt.append( + extract_var_range_info_for_profile( + dim, {s: b["opt"] for s, b in sym_bounds.items()} + ) + ) + p_max.append( + extract_var_range_info_for_profile( + dim, {s: b["max"] for s, b in sym_bounds.items()} + ) + ) + except KeyError: + # Symbol(s) not present in this profile's source bounds; + # fall back to the union range for this dim. + p_min.append(int(union_min[d])) + p_opt.append(int(union_opt[d])) + p_max.append(int(union_max[d])) + else: + p_min.append(int(dim)) + p_opt.append(int(dim)) + p_max.append(int(dim)) + profiles.append( + { + "min": tuple(p_min), + "opt": tuple(p_opt), + "max": tuple(p_max), + } + ) + return profiles + def construct_dynamic_input( input_shape: torch.Size, input_dtype: torch.dtype, name: str = "", is_shape_tensor: bool = False, + profile_source_bounds: Optional[ProfileSourceBounds] = None, + num_profiles: int = 0, ) -> Input: """ Constructs a torch_tensorrt.Input based on a symbolic input Args: input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values) + profile_source_bounds: Optional per-profile source-symbol bounds used to + propagate optimization profiles to this (intermediate) input. + num_profiles: Number of profiles to emit when ``profile_source_bounds`` + is provided. Returns: A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input. """ @@ -70,6 +141,34 @@ def construct_dynamic_input( opt_shape.append(dim) max_shape.append(dim) + # Multi-profile propagation: emit profiles for this intermediate input by + # substituting source-symbol values into its SymInt dims. + if profile_source_bounds and num_profiles: + profiles = _build_submodule_profiles( + input_shape, + min_shape, + opt_shape, + max_shape, + profile_source_bounds, + num_profiles, + ) + if profiles is not None: + try: + return Input( + profiles=profiles, + dtype=input_dtype, + name=name, + is_shape_tensor=is_shape_tensor, + ) + except (ValueError, TypeError) as e: + # Non-affine / non-monotonic expressions can yield invalid + # per-profile bounds; fall back to the single union profile. + logger.warning( + f"Could not propagate optimization profiles to submodule input " + f"'{name}' (shape {input_shape}): {e}. Falling back to the union " + "range for this input." + ) + return Input( min_shape=min_shape, opt_shape=opt_shape, @@ -85,6 +184,8 @@ def get_input( dtype: torch.dtype, name: str = "", is_shape_tensor: bool = False, + profile_source_bounds: Optional[ProfileSourceBounds] = None, + num_profiles: int = 0, ) -> Input: """ Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs @@ -106,6 +207,8 @@ def get_input( dtype, name=name, is_shape_tensor=is_shape_tensor, + profile_source_bounds=profile_source_bounds, + num_profiles=num_profiles, ) else: return Input( @@ -113,12 +216,22 @@ def get_input( ) -def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: +def construct_submodule_inputs( + module: torch.fx.GraphModule, + profile_source_bounds: Optional[ProfileSourceBounds] = None, + num_profiles: int = 0, +) -> Sequence[Input]: """ Construct torch_tensorrt Inputs based on the module inputs. The module inputs will have meta data which has the shape and dtype info Args: module: Input FX GraphModule + profile_source_bounds: Optional per-profile source-symbol bounds. When + provided (multi-profile compile), each dynamic submodule input gets + ``profiles`` derived by substituting source symbols into its + symbolic shape. + num_profiles: Number of profiles corresponding to + ``profile_source_bounds``. Returns: Sequence of torch_tensorrt.Input's representing inputs to given module """ @@ -134,7 +247,13 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: if isinstance(input_meta, (FakeTensor, torch.Tensor)): input_shape = input_meta.size() torchtrt_inputs.append( - get_input(input_shape, input_meta.dtype, name=input.name) + get_input( + input_shape, + input_meta.dtype, + name=input.name, + profile_source_bounds=profile_source_bounds, + num_profiles=num_profiles, + ) ) elif isinstance(input_meta, torch.SymInt): # Assuming sym_integers | shape inputs always have torch.int64 dtype @@ -144,6 +263,8 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: torch.int64, name=input.name, is_shape_tensor=True, + profile_source_bounds=profile_source_bounds, + num_profiles=num_profiles, ) ) elif isinstance(input_meta, torch.SymFloat): @@ -178,6 +299,63 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: return torchtrt_inputs +def build_profile_source_bounds( + module: torch.fx.GraphModule, + top_level_inputs: Sequence[Input], + num_profiles: int, +) -> ProfileSourceBounds: + """Map export source symbols to per-profile bounds from top-level inputs. + + For each top-level placeholder whose ``Input`` declares ``profiles``, read the + export ``SymInt`` for each dynamic dim and record, per profile, the + ``min`` / ``opt`` / ``max`` value of the corresponding source symbol. The + result feeds :func:`construct_submodule_inputs` so intermediate engines + inherit the same profiles (by index) via symbolic substitution. + + Args: + module: Top-level (partitioned) GraphModule whose placeholders carry the + export symbolic shapes. + top_level_inputs: Ordered top-level Inputs (arg inputs followed by kwarg + inputs), aligned with the module placeholders. + num_profiles: Number of optimization profiles. + Returns: + A list indexed by optimization-profile index: + ``[{symbol_name: {"min": int, "opt": int, "max": int}}, ...]`` + """ + bounds: ProfileSourceBounds = [{} for _ in range(num_profiles)] + if not num_profiles: + return bounds + + placeholders = [n for n in module.graph.nodes if n.op == "placeholder"] + with unset_fake_temporarily(): + for ph, inp in zip(placeholders, top_level_inputs): + profiles = getattr(inp, "profiles", None) + if not profiles: + continue + if not ph.meta or "val" not in ph.meta: + continue + val = ph.meta["val"] + if not isinstance(val, (FakeTensor, torch.Tensor)): + continue + for d, dim in enumerate(val.size()): + if not isinstance(dim, torch.SymInt): + continue + expr = dim.node.expr + # Top-level dynamic dims map directly to a single source symbol. + if not getattr(expr, "is_symbol", False): + continue + sym_name = expr.name + for i, prof in enumerate(profiles): + if i >= len(bounds) or d >= len(prof["min"]): + continue + bounds[i][sym_name] = { + "min": int(prof["min"][d]), + "opt": int(prof["opt"][d]), + "max": int(prof["max"][d]), + } + return bounds + + def run_shape_analysis( parent_module: torch.fx.GraphModule, inputs: Sequence[Input], diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index aeda1aa1e4..351dc649c8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -249,6 +249,13 @@ def __init__( # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None + # Multiple optimization profiles. Manual selection by default: + # ``_active_profile_index`` is the profile currently loaded in the TRT + # context (default 0, reused across calls). ``_auto_select_profiles`` + # opts into shape-based selection, re-evaluated on every forward. + self._active_profile_index = 0 + self._auto_select_profiles = False + self._load_serialized_info(serialized_info) self._setup_engine() @@ -313,6 +320,9 @@ def __setstate__(self, state: Any) -> None: # forward pass via setup_nccl_comm(). self._nccl_comm = None + self._active_profile_index = 0 + self._auto_select_profiles = False + serialized_info = list(state[0]) engine_field = serialized_info[ENGINE_IDX] if isinstance(engine_field, str): @@ -499,6 +509,7 @@ def _setup_engine(self) -> None: input_name: self.cuda_engine.is_shape_inference_io(input_name) for input_name in self.in_binding_names } + self._setup_optimization_profiles() if self.requires_output_allocator: self.create_output_allocator() @@ -614,6 +625,87 @@ def _enable_rtx_native_cudagraphs(self) -> None: self._rtx_native_cudagraphs = True logger.info("Switched to TRT-RTX native CUDA graphs") + def _setup_optimization_profiles(self) -> None: + """Cache per-profile shape ranges from the TRT API. + + Rebuilds the profile bounds via ``get_tensor_profile_shape`` so that + runtime profile selection works for engines compiled in-process, loaded + from cache, or deserialized from disk — no new serialization fields. + Populates: + + - ``_profile_dim_ranges``: ``name -> [dim] -> [(min, max), ...]`` (the dim + axis is a dense list indexed by dimension; one ``(min, max)`` tuple per + optimization-profile index), used for auto-selection containment. + """ + self.num_optimization_profiles = self.cuda_engine.num_optimization_profiles + self._profile_dim_ranges: Dict[str, List[List[Tuple[int, int]]]] = {} + + if self.num_optimization_profiles <= 1: + return + + for p in range(self.num_optimization_profiles): + for name in self.in_binding_names: + if self.is_shape_inference_io.get(name, False): + continue + rmin, _, rmax = self.cuda_engine.get_tensor_profile_shape(name, p) + dims = self._profile_dim_ranges.setdefault(name, []) + if not dims: + dims.extend([] for _ in range(len(rmin))) + for d, (lo, hi) in enumerate(zip(rmin, rmax)): + dims[d].append((int(lo), int(hi))) + + # --- optimization profile selection --- + + def set_active_profile(self, profile_index: int) -> None: + """Make ``profile_index`` the active TRT optimization profile (idempotent).""" + if self.num_optimization_profiles <= 1: + return + if profile_index == self._active_profile_index: + return + stream = self._engine_stream or torch.cuda.current_stream(self._target_device) + self.context.set_optimization_profile_async(profile_index, stream.cuda_stream) + stream.synchronize() + self._active_profile_index = profile_index + # A profile switch invalidates any captured CUDA graph and changes the + # context state, so force re-record / shape re-inference next run. + self.runtime_states.context_changed = True + self.reset_captured_graph() + self.shape_key = None + logger.debug(f"Switched to optimization profile index {profile_index}") + + def _auto_select_profile(self, inputs: Sequence[torch.Tensor]) -> int: + """Select an optimization profile from input shapes. + + Lazy selection: scan profiles in index order and return the **first** one + whose ``[min, max]`` ranges contain every input shape. Overlapping + profiles therefore resolve to the lowest matching index; pin manually via + ``optimization_profile(module, index)`` to force a specific profile. + """ + for p in range(self.num_optimization_profiles): + fits = True + for i, name in enumerate(self.in_binding_names): + if i >= len(inputs) or self.is_shape_inference_io.get(name, False): + continue + shape = tuple(int(s) for s in inputs[i].shape) + ranges = self._profile_dim_ranges.get(name, []) + for d, extent in enumerate(shape): + if d < len(ranges): + lo, hi = ranges[d][p] + if not (lo <= extent <= hi): + fits = False + break + if not fits: + break + if fits: + return p + + raise RuntimeError( + "No optimization profile matches the input shapes " + f"{[tuple(t.shape) for t in inputs]}. Cached profile ranges: " + f"{self._profile_dim_ranges}. Fix the input shapes or pin a profile " + "explicitly via optimization_profile(module, index)." + ) + # --- distributed / NCCL --- @property @@ -882,7 +974,7 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: ): # Captured CUDA graph was recorded against the old stream. self.runtime_states.context_changed = True - return caller_on_default + return bool(caller_on_default) def _execute_standard( self, contiguous_inputs: List[torch.Tensor] @@ -913,6 +1005,12 @@ def _execute_standard( # cudagraph recapture (set_runtime_states consumes and resets the # flag). caller_on_default = self._prepare_streams(contiguous_inputs) + # Auto-select the optimization profile from input shapes before + # validating shapes, so set_active_profile's context_changed flag is + # consumed by set_runtime_states below. Manual pins are applied eagerly + # in set_optimization_profile, so only auto needs per-call selection. + if self.num_optimization_profiles > 1 and self._auto_select_profiles: + self.set_active_profile(self._auto_select_profile(contiguous_inputs)) shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, @@ -970,7 +1068,7 @@ def _execute_standard( with self._profile_section("TRTEngine:TensorRTRuntime"): if caller_on_default: - self._engine_stream.wait_stream(self._caller_stream) + self._engine_stream.wait_stream(self._caller_stream) # type: ignore[union-attr] with torch.cuda.stream(self._engine_stream): if self.resource_allocation_strategy: self._dynamic_workspace = torch.empty( @@ -989,7 +1087,7 @@ def _execute_standard( self.cudagraph, stream=self._engine_stream ): self.context.execute_async_v3( - self._engine_stream.cuda_stream + self._engine_stream.cuda_stream # type: ignore[union-attr] ) if self._profile_execution: self.cudagraph.debug_dump( @@ -997,10 +1095,10 @@ def _execute_standard( ) self.cudagraph.replay() # type: ignore[union-attr] else: - self.context.execute_async_v3(self._engine_stream.cuda_stream) + self.context.execute_async_v3(self._engine_stream.cuda_stream) # type: ignore[union-attr] if caller_on_default: - self._caller_stream.wait_stream(self._engine_stream) + self._caller_stream.wait_stream(self._engine_stream) # type: ignore[union-attr] if self.use_pre_allocated_outputs and ( self.output_tensors_are_unowned @@ -1026,6 +1124,10 @@ def _execute_output_allocator( "incompatible runtime modes. Please disable one of the two." ) + # Only auto-selection needs per-call work; manual pins are applied eagerly. + if self.num_optimization_profiles > 1 and self._auto_select_profiles: + self.set_active_profile(self._auto_select_profile(contiguous_inputs)) + with self._profile_section("TRTEngine:ProcessInputs"): self.setup_input_tensors(contiguous_inputs, False, False) @@ -1043,11 +1145,11 @@ def _execute_output_allocator( with self._profile_section("TRTEngine:TensorRTRuntime"): if caller_on_default: - self._engine_stream.wait_stream(self._caller_stream) + self._engine_stream.wait_stream(self._caller_stream) # type: ignore[union-attr] with torch.cuda.stream(self._engine_stream): - self.context.execute_async_v3(self._engine_stream.cuda_stream) + self.context.execute_async_v3(self._engine_stream.cuda_stream) # type: ignore[union-attr] if caller_on_default: - self._caller_stream.wait_stream(self._engine_stream) + self._caller_stream.wait_stream(self._engine_stream) # type: ignore[union-attr] outputs = [] assert self.output_allocator is not None diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1d83bd646f..1179fd2d6c 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -405,6 +405,76 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.output_binding_names = state[3] self.target_device = self._resolve_target_device() + def set_optimization_profile(self, profile: Optional[Any]) -> None: + """Select the active TRT optimization profile for this engine. + + ``profile`` may be a profile index (``int``), the string ``"auto"`` to + enable shape-based auto-selection, or ``None`` to clear any pin / auto + setting and reset to the default profile (index 0). + See :func:`torch_tensorrt.runtime.optimization_profile`. + """ + if self.engine is None: + self.setup_engine() + assert self.engine is not None + engine = self.engine + # Drive the primitive engine API (set_active_profile / _auto_select_profiles) + # that both the Python and C++ runtimes expose under the same names, so + # this works regardless of the active runtime. + if not hasattr(engine, "set_active_profile"): + raise RuntimeError( + "This engine does not support optimization profile selection." + ) + if profile is None: + engine._auto_select_profiles = False + engine.set_active_profile(0) + return + if isinstance(profile, str) and profile == "auto": + engine._auto_select_profiles = True + return + # Validate the profile index. ``num_optimization_profiles`` is exposed by + # both the Python and C++ runtimes under the same name. + if isinstance(profile, bool) or not isinstance(profile, int): + raise TypeError( + f"Optimization profile must be an integer index, got {type(profile)}" + ) + num_profiles = engine.num_optimization_profiles + if not (0 <= profile < num_profiles): + raise ValueError( + f"Optimization profile index {profile} out of range " + f"[0, {num_profiles})" + ) + engine._auto_select_profiles = False + engine.set_active_profile(profile) + + def get_optimization_profile_state(self) -> Optional[Tuple[Any, ...]]: + """Return the engine's current ``(auto, active)`` profile state. + + Used by the :func:`optimization_profile` context manager to save/restore + state across a ``with`` block. Returns ``None`` if unsupported. + """ + if self.engine is None: + return None + engine = self.engine + if not hasattr(engine, "_active_profile_index"): + return None + return ( + engine._auto_select_profiles, + engine._active_profile_index, + ) + + def restore_optimization_profile_state( + self, state: Optional[Tuple[Any, ...]] + ) -> None: + """Restore profile state captured by :meth:`get_optimization_profile_state`.""" + if state is None or self.engine is None: + return + engine = self.engine + if not hasattr(engine, "set_active_profile"): + return + auto, active = state + engine._auto_select_profiles = auto + engine.set_active_profile(active) + def set_pre_allocated_outputs(self, enable: bool) -> None: self.get_engine().use_pre_allocated_outputs = enable diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 33595f4709..9511aaef9f 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -187,6 +187,7 @@ def get_torch_tensor( if input.is_shape_tensor: # TODO: All the shape tensors we've encountered so far are plain integers. # Validate this assumption on more models. + assert isinstance(input.shape, dict) return input.shape["opt_shape"][0] if len(mode) > 0: @@ -451,6 +452,69 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional return min_max_opt +def validate_optimization_profiles(input_specs: Sequence[Input]) -> int: + """Validate multi-profile inputs and return the number of profiles. + + ``Input.profiles`` is an ordered list; the list index is the TRT + optimization-profile index selected at runtime. Every dynamic input that + declares ``profiles`` must declare the *same number* of profiles. + Static inputs (or dynamic inputs without profiles) are allowed and reuse + their single shape in every profile. + + Returns the number of *declared* optimization profiles, i.e. ``0`` when no + input declares ``profiles`` (the multi-profile feature is unused and the + actual profile count is decided downstream: ``0`` for fully static engines, + ``1`` for dynamic ones). + """ + num_profiles = 0 + for inp in input_specs: + profiles = getattr(inp, "profiles", None) + if not profiles: + continue + if num_profiles == 0: + num_profiles = len(profiles) + elif len(profiles) != num_profiles: + raise ValueError( + "All inputs declaring optimization profiles must declare the same " + f"number of profiles, found both {num_profiles} and {len(profiles)}." + ) + return num_profiles + + +def extract_var_range_info_for_profile( + symbolic_integer: torch.SymInt, + symbol_values: Dict[str, int], +) -> int: + """Evaluate a symbolic dimension at a single profile corner. + + ``symbol_values`` maps source symbol name (e.g. ``"s0"``) to the concrete + integer value of that symbol at one corner (min / opt / max) of a profile. + The intermediate ``SymInt`` expression (e.g. ``s0/4``) is evaluated by + substituting those values. Shape ops in export produce affine expressions + that are monotonic in each source symbol, so per-corner substitution is + exact. + """ + node = symbolic_integer.node + expr = node.expr + # A fully-static expression may already be a Python/sympy integer. + if isinstance(expr, int): + return int(expr) + free_symbols: Any = getattr(expr, "free_symbols", set()) + substitution = { + sym: symbol_values[sym.name] + for sym in free_symbols + if getattr(sym, "name", None) in symbol_values + } + evaluated = expr.xreplace(substitution) if substitution else expr + if getattr(evaluated, "free_symbols", set()): + # A free symbol was missing from symbol_values; caller should fall back. + raise KeyError( + f"Could not fully evaluate symbolic dim {expr} with symbol values " + f"{symbol_values} (unresolved free symbols remain)." + ) + return int(evaluated) + + def unwrap_tensor_shape( tensor: Union[torch.Tensor, FakeTensor, torch.SymInt], mode: Optional[str] = "" ) -> Sequence[Union[Optional[int], Tuple[Optional[int], Optional[int]]]]: diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 7283ca0f33..9054f841e8 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -8,6 +8,7 @@ set_cudagraphs_mode, ) from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode +from torch_tensorrt.runtime._optimization_profile import optimization_profile from torch_tensorrt.runtime._output_allocator import enable_output_allocator from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs from torch_tensorrt.runtime._weight_streaming import weight_streaming diff --git a/py/torch_tensorrt/runtime/_optimization_profile.py b/py/torch_tensorrt/runtime/_optimization_profile.py new file mode 100644 index 0000000000..e80981499d --- /dev/null +++ b/py/torch_tensorrt/runtime/_optimization_profile.py @@ -0,0 +1,95 @@ +"""Runtime optimization-profile selection for multi-profile TensorRT engines. + +Profile selection is **manual by default**: pin a profile by its integer index +for a ``with`` span. Pass ``"auto"`` to opt into shape-based auto-selection for +that span. State is saved on enter and restored on exit (stack semantics), so +nested ``with`` blocks compose. + +Example:: + + from torch_tensorrt.runtime import optimization_profile + + # profiles=[prefill, decode] -> index 1 is decode + with optimization_profile(trt_gm, 1): + out = trt_gm(inputs_embeds=embeds, past_key_values=kv) + + with optimization_profile(trt_gm, "auto"): + out = trt_gm(x) +""" + +from __future__ import annotations + +import logging +from typing import Any, List, Optional, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def _collect_trt_modules(module: Any) -> List[Any]: + """Return all TorchTensorRTModule instances reachable from ``module``.""" + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + + trt_modules: List[Any] = [] + if isinstance(module, TorchTensorRTModule): + trt_modules.append(module) + elif isinstance(module, torch.nn.Module): + for submodule in module.modules(): + if isinstance(submodule, TorchTensorRTModule): + trt_modules.append(submodule) + return trt_modules + + +class _OptimizationProfileContext: + """Context manager that pins (or auto-selects) an optimization profile. + + Applies the requested profile to every TensorRT submodule on enter, and + restores each submodule's previous profile state on exit. + """ + + def __init__(self, module: Any, profile: Optional[Any]) -> None: + self.module = module + self.profile = profile + self._saved_state: List[Tuple[Any, Optional[Tuple[Any, ...]]]] = [] + + def __enter__(self) -> Any: + trt_modules = _collect_trt_modules(self.module) + if not trt_modules: + logger.warning( + "optimization_profile() found no TensorRT submodules to configure." + ) + for trt_module in trt_modules: + saved = trt_module.get_optimization_profile_state() + self._saved_state.append((trt_module, saved)) + trt_module.set_optimization_profile(self.profile) + return self.module + + def __exit__(self, *exc: Any) -> None: + # restore_* re-applies all of (pinned, auto, active) captured on enter, + # including the active TRT profile index switched inside the block. + for trt_module, saved in reversed(self._saved_state): + try: + trt_module.restore_optimization_profile_state(saved) + except Exception as e: # pragma: no cover - defensive restore + logger.warning( + f"Failed to restore optimization profile state on exit: {e}" + ) + self._saved_state.clear() + + +def optimization_profile( + module: Any, profile: Optional[Any] +) -> _OptimizationProfileContext: + """Select the active TensorRT optimization profile for a ``with`` span. + + Args: + module: A compiled ``GraphModule`` (or ``TorchTensorRTModule``) containing + one or more TensorRT engines. + profile: Profile index (``int``), the string ``"auto"`` to enable + shape-based auto-selection, or ``None`` to clear. + + Returns: + A context manager. Pinned/auto state is restored on exit. + """ + return _OptimizationProfileContext(module, profile) diff --git a/tests/py/dynamo/runtime/test_multi_optimization_profiles.py b/tests/py/dynamo/runtime/test_multi_optimization_profiles.py new file mode 100644 index 0000000000..ce8f0f31ed --- /dev/null +++ b/tests/py/dynamo/runtime/test_multi_optimization_profiles.py @@ -0,0 +1,426 @@ +import copy +import unittest + +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule +from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine +from torch_tensorrt.runtime import optimization_profile + +# Profiles are an ordered list; the list index is the optimization-profile +# index selected at runtime. Order is meaningful for lazy auto-selection: the +# decode profile ([1, 1]) and prefill profile ([1, 64]) overlap at seq=1, so we +# declare decode FIRST (index 0) to make it win the overlap (first-working). +DECODE_IDX = 0 +PREFILL_IDX = 1 + + +class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(16, 32) + self.l2 = torch.nn.Linear(32, 16) + + def forward(self, x): + return self.l2(torch.relu(self.l1(x))) + + +def _make_profiles_input(): + return Input( + name="x", + dtype=torch.float16, + profiles=[ + {"min": (4, 1, 16), "opt": (4, 1, 16), "max": (4, 1, 16)}, # decode + {"min": (4, 1, 16), "opt": (4, 48, 16), "max": (4, 64, 16)}, # prefill + ], + ) + + +def _compile_mlp(model, **kwargs): + inp = _make_profiles_input() + example = torch.randn(4, 48, 16, dtype=torch.float16, device="cuda") + ep = torch.export.export( + model, + (example,), + dynamic_shapes=({1: torch.export.Dim("seq", min=1, max=64)},), + ) + return torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[inp], + min_block_size=1, + enabled_precisions={torch.float16}, + **kwargs, + ) + + +class TestInputProfilesValidation(TestCase): + def test_union_envelope(self): + inp = _make_profiles_input() + self.assertEqual(inp.shape["min_shape"], (4, 1, 16)) + self.assertEqual(inp.shape["max_shape"], (4, 64, 16)) + # opt is taken from the first declared profile (decode at index 0) + self.assertEqual(inp.shape["opt_shape"], (4, 1, 16)) + self.assertEqual(len(inp.profiles), 2) + self.assertEqual(inp.profiles[DECODE_IDX]["max"], (4, 1, 16)) + self.assertEqual(inp.profiles[PREFILL_IDX]["opt"], (4, 48, 16)) + + def test_min_zero_rejected(self): + with self.assertRaises(ValueError): + Input(profiles=[{"min": (0, 1), "opt": (1, 1), "max": (2, 1)}]) + + def test_min_opt_max_ordering(self): + with self.assertRaises(ValueError): + Input(profiles=[{"min": (4, 1), "opt": (4, 8), "max": (4, 2)}]) + + def test_empty_profiles_rejected(self): + with self.assertRaises(ValueError): + Input(profiles=[]) + + def test_mutual_exclusion(self): + with self.assertRaises(ValueError): + Input( + shape=(1, 2), + profiles=[{"min": (1,), "opt": (1,), "max": (1,)}], + ) + + def test_str_includes_profiles(self): + inp = _make_profiles_input() + self.assertIn("profiles=", str(inp)) + + def test_profiles_with_shared_dims(self): + # ``profiles`` and ``shared_dims`` compose: profiles set the per-profile + # ranges while shared_dims names the dynamic axis (validated against the + # union envelope) for cross-input symbol sharing. + inp = Input( + name="input_ids", + profiles=[ + {"min": (1, 1), "opt": (1, 1), "max": (1, 1)}, # decode + {"min": (1, 1), "opt": (1, 128), "max": (1, 512)}, # prefill + ], + shared_dims={1: "seq"}, + ) + self.assertEqual(len(inp.profiles), 2) + self.assertEqual(inp.shared_dims, {1: "seq"}) + # union envelope marks the shared axis dynamic (1..512) + self.assertEqual(inp.shape["min_shape"], (1, 1)) + self.assertEqual(inp.shape["max_shape"], (1, 512)) + + def test_shared_dims_on_static_union_axis_rejected(self): + # Axis 0 has min == max across every profile, so it is static in the + # union; naming it for sharing is a user error. + with self.assertRaises(ValueError): + Input( + profiles=[ + {"min": (1, 1), "opt": (1, 8), "max": (1, 16)}, + {"min": (1, 1), "opt": (1, 4), "max": (1, 32)}, + ], + shared_dims={0: "batch"}, + ) + + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") +class TestMultiProfileRuntime(TestCase): + """Runtime behavior of multi-profile engines. + + The runtime (C++ or Python ``TRTEngine``) is selected automatically by the + Torch-TensorRT build, so these tests drive the runtime-agnostic engine API + (``num_optimization_profiles`` / ``set_active_profile`` / + ``_active_profile_index`` / ``_auto_select_profiles``) exposed identically by + both runtimes via the ``optimization_profile`` context manager. + """ + + def setUp(self): + self.model = MLP().eval().cuda().half() + self.trt_gm = _compile_mlp(self.model) + self.decode_in = torch.randn(4, 1, 16, dtype=torch.float16, device="cuda") + self.prefill_in = torch.randn(4, 32, 16, dtype=torch.float16, device="cuda") + + def _trt_engines(self, module): + return [ + m.engine for m in module.modules() if isinstance(m, TorchTensorRTModule) + ] + + def test_two_profiles_built(self): + engines = self._trt_engines(self.trt_gm) + self.assertGreaterEqual(len(engines), 1) + for e in engines: + self.assertEqual(e.num_optimization_profiles, 2) + + def test_manual_pin_by_decode_index(self): + ref = self.model(self.decode_in) + with optimization_profile(self.trt_gm, DECODE_IDX): + out = self.trt_gm(self.decode_in) + self.assertEqual(tuple(out.shape), (4, 1, 16)) + self.assertTrue(torch.allclose(out, ref, atol=1e-2)) + + def test_manual_pin_by_prefill_index(self): + with optimization_profile(self.trt_gm, PREFILL_IDX): + out = self.trt_gm(self.prefill_in) + self.assertEqual(tuple(out.shape), (4, 32, 16)) + + def test_auto_selection_decode_and_prefill(self): + ref_d = self.model(self.decode_in) + ref_p = self.model(self.prefill_in) + with optimization_profile(self.trt_gm, "auto"): + out_d = self.trt_gm(self.decode_in) + out_p = self.trt_gm(self.prefill_in) + self.assertTrue(torch.allclose(out_d, ref_d, atol=1e-2)) + self.assertTrue(torch.allclose(out_p, ref_p, atol=1e-2)) + + def test_auto_selection_is_lazy_first_working(self): + # decode (idx 0) and prefill (idx 1) both accept seq=1; lazy selection + # must pick the first that fits (decode). seq=32 only fits prefill. + engines = self._trt_engines(self.trt_gm) + with optimization_profile(self.trt_gm, "auto"): + self.trt_gm(self.decode_in) + for e in engines: + self.assertEqual(e._active_profile_index, DECODE_IDX) + self.trt_gm(self.prefill_in) + for e in engines: + self.assertEqual(e._active_profile_index, PREFILL_IDX) + + def test_manual_pin_persists_across_calls(self): + # A manual pin (no "auto") must stick: the engine keeps the pinned + # profile across invocations instead of re-selecting per call. + engines = self._trt_engines(self.trt_gm) + with optimization_profile(self.trt_gm, PREFILL_IDX): + self.trt_gm(self.prefill_in) + for e in engines: + self.assertEqual(e._active_profile_index, PREFILL_IDX) + self.assertFalse(e._auto_select_profiles) + self.trt_gm(self.prefill_in) + for e in engines: + self.assertEqual(e._active_profile_index, PREFILL_IDX) + + def test_profile_state_restored_after_context(self): + # Leaving the context manager restores the engine's prior profile state. + engines = self._trt_engines(self.trt_gm) + before = [(e._auto_select_profiles, e._active_profile_index) for e in engines] + with optimization_profile(self.trt_gm, PREFILL_IDX): + self.trt_gm(self.prefill_in) + after = [(e._auto_select_profiles, e._active_profile_index) for e in engines] + self.assertEqual(before, after) + + def test_out_of_range_index_raises(self): + with self.assertRaises(ValueError): + with optimization_profile(self.trt_gm, 99): + self.trt_gm(self.decode_in) + + def test_non_int_profile_raises(self): + with self.assertRaises(TypeError): + with optimization_profile(self.trt_gm, "decode"): + self.trt_gm(self.decode_in) + + def test_serialization_round_trip_preserves_profiles(self): + trt_gm2 = copy.deepcopy(self.trt_gm) + for e in self._trt_engines(trt_gm2): + self.assertEqual(e.num_optimization_profiles, 2) + with optimization_profile(trt_gm2, DECODE_IDX): + out = trt_gm2(self.decode_in) + self.assertEqual(tuple(out.shape), (4, 1, 16)) + + def _check_reconstructed_profile_state(self, engine): + # A Python ``TRTEngine`` reconstructed with NO optimization-profile + # metadata must rebuild its profile count and per-profile [min, max] dim + # ranges purely from the TensorRT API (getNbOptimizationProfiles / + # getProfileShape). This white-box check reads the Python runtime's + # internal ``_profile_dim_ranges`` / ``_auto_select_profile`` (not part + # of the user-facing API; the C++ runtime is covered behaviorally below). + self.assertEqual(engine.num_optimization_profiles, 2) + + ranges = engine._profile_dim_ranges + name = next(iter(ranges)) + # Ranges for the dynamic seq axis (dim 1). tuple() normalizes the Python + # runtime (tuples) and C++ runtime (lists) to compare uniformly. + seq_ranges = ranges[name][1] + self.assertEqual(tuple(seq_ranges[DECODE_IDX]), (1, 1)) + self.assertEqual(tuple(seq_ranges[PREFILL_IDX]), (1, 64)) + + # seq=1 -> decode (first profile that fits), seq=32 -> prefill. + self.assertEqual(engine._auto_select_profile([self.decode_in]), DECODE_IDX) + self.assertEqual(engine._auto_select_profile([self.prefill_in]), PREFILL_IDX) + + def test_profiles_restored_from_trt_api_without_metadata(self): + # Python runtime: rebuild a fresh ``TRTEngine`` straight from the + # serialized layout (engine bytes + binding names + device only), + # simulating loading an engine that carries no profile metadata. + src = self._trt_engines(self.trt_gm)[0] + if not isinstance(src, TRTEngine): + self.skipTest("Python TRTEngine-specific construction path") + fresh = TRTEngine(list(src.serialized_info)) + self._check_reconstructed_profile_state(fresh) + + def test_profiles_restored_from_trt_api_without_metadata_cpp(self): + # C++ runtime: a deep-copied module round-trips the engine through + # serialize/deserialize, which is the "load with no profile metadata" + # path; ``setup_optimization_profiles()`` (called from the C++ ctor) + # rebuilds everything from the TRT API. Verified behaviorally through the + # user-facing API only (profile count + shape-based auto-selection); the + # internal dim-range introspection is intentionally not exposed to users. + src = self._trt_engines(self.trt_gm)[0] + if isinstance(src, TRTEngine): + self.skipTest("C++ runtime-specific test") + trt_gm2 = copy.deepcopy(self.trt_gm) + engines = self._trt_engines(trt_gm2) + for e in engines: + self.assertEqual(e.num_optimization_profiles, 2) + # Auto-selection relies on the dim ranges rebuilt from the TRT API: seq=1 + # must resolve to decode, seq>1 to prefill. + with optimization_profile(trt_gm2, "auto"): + trt_gm2(self.decode_in) + for e in engines: + self.assertEqual(e._active_profile_index, DECODE_IDX) + trt_gm2(self.prefill_in) + for e in engines: + self.assertEqual(e._active_profile_index, PREFILL_IDX) + + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") +class TestMultiProfileSharedDims(TestCase): + """``profiles`` combined with ``shared_dims``. + + Two inputs share a dynamic ``seq`` axis (one exported ``Dim``) while each + declares the same prefill/decode profiles. The shared dim feeds the export + envelope; the profiles feed the per-profile TRT ranges. + """ + + class TwoInput(torch.nn.Module): + def forward(self, a, b): + return torch.relu(a * b + a) + + def _shared_inputs(self): + # prefill at index 0 so the export opt (taken from profile[0]) is > 1 and + # the seq axis is not specialized to a constant during tracing. + prof = [ + {"min": (1, 1), "opt": (1, 128), "max": (1, 512)}, # prefill + {"min": (1, 1), "opt": (1, 1), "max": (1, 16)}, # decode + ] + a = Input(profiles=prof, shared_dims={1: "seq"}, dtype=torch.float32, name="a") + b = Input(profiles=prof, shared_dims={1: "seq"}, dtype=torch.float32, name="b") + return a, b + + def test_shared_axis_is_single_export_symbol(self): + # Both inputs' dynamic seq axis must trace to the *same* export symbol. + model = self.TwoInput().eval().cuda() + a, b = self._shared_inputs() + ep = torch_tensorrt.dynamo.trace(model, arg_inputs=[a, b]) + placeholders = [n for n in ep.graph.nodes if n.op == "placeholder"] + symbols = [] + for ph in placeholders[:2]: + val = ph.meta["val"] + dim = val.shape[1] + self.assertIsInstance(dim, torch.SymInt) + symbols.append(dim.node.expr.name) + self.assertEqual(symbols[0], symbols[1]) + + def _trt_engines(self, module): + return [ + m.engine for m in module.modules() if isinstance(m, TorchTensorRTModule) + ] + + def test_compile_and_run_across_profiles(self): + model = self.TwoInput().eval().cuda() + a, b = self._shared_inputs() + ep = torch_tensorrt.dynamo.trace(model, arg_inputs=[a, b]) + trt_gm = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[a, b], + min_block_size=1, + enabled_precisions={torch.float32}, + ) + for e in self._trt_engines(trt_gm): + self.assertEqual(e.num_optimization_profiles, 2) + + # seq=1 fits the decode profile, seq=256 fits prefill. + for seq in (1, 8, 256): + x = torch.randn(1, seq, device="cuda") + y = torch.randn(1, seq, device="cuda") + with optimization_profile(trt_gm, "auto"): + out = trt_gm(x, y) + self.assertTrue(torch.allclose(out, model(x, y), atol=1e-2)) + + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") +class TestMultiProfileGraphBreak(TestCase): + """Profile propagation across graph breaks (build-selected runtime).""" + + def _trt_engines(self, module): + return [ + m.engine for m in module.modules() if isinstance(m, TorchTensorRTModule) + ] + + def test_submodule_profile_propagation(self): + class GraphBreakMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(16, 32) + self.l2 = torch.nn.Linear(32, 16) + + def forward(self, x): + a = self.l1(x) + b = torch.sin(a) # forced to run in torch -> graph break + return self.l2(torch.relu(b)) + + model = GraphBreakMLP().eval().cuda().half() + trt_gm = _compile_mlp(model, torch_executed_ops={"torch.ops.aten.sin.default"}) + + engines = self._trt_engines(trt_gm) + # Expect more than one TRT engine (graph break) and every engine has + # the two profiles propagated. + self.assertGreaterEqual(len(engines), 2) + for e in engines: + self.assertEqual(e.num_optimization_profiles, 2) + + decode_in = torch.randn(4, 1, 16, dtype=torch.float16, device="cuda") + prefill_in = torch.randn(4, 32, 16, dtype=torch.float16, device="cuda") + with optimization_profile(trt_gm, "auto"): + out_d = trt_gm(decode_in) + out_p = trt_gm(prefill_in) + self.assertTrue(torch.allclose(out_d, model(decode_in), atol=1e-2)) + self.assertTrue(torch.allclose(out_p, model(prefill_in), atol=1e-2)) + + def test_reshaped_dynamic_submodule_input(self): + # The reshape makes the post-graph-break submodule input dim a *derived* + # symbolic expression of the source seq symbol (16 * seq), not the source + # symbol itself. This exercises per-profile bound evaluation of derived + # dynamic dims when propagating profiles to intermediate submodules. + class ReshapeGraphBreakMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 32) + self.fc2 = torch.nn.Linear(32, 32) + + def forward(self, x): # x: (4, seq, 16) + x = self.fc1(x.reshape(-1, 4)) # (16 * seq, 32): derived dynamic dim + x = torch.relu(x) # forced to run in torch -> graph break + return self.fc2(x) + + model = ReshapeGraphBreakMLP().eval().cuda().half() + trt_gm = _compile_mlp(model, torch_executed_ops={"torch.ops.aten.relu.default"}) + + engines = self._trt_engines(trt_gm) + # Graph break -> multiple TRT engines; every engine (including the one + # fed the reshaped derived-dynamic tensor) carries both profiles. + self.assertGreaterEqual(len(engines), 2) + for e in engines: + self.assertEqual(e.num_optimization_profiles, 2) + + decode_in = torch.randn(4, 1, 16, dtype=torch.float16, device="cuda") + prefill_in = torch.randn(4, 32, 16, dtype=torch.float16, device="cuda") + with optimization_profile(trt_gm, "auto"): + out_d = trt_gm(decode_in) + # seq=1 -> derived dim 16*1=16; both engines auto-select decode + for e in engines: + self.assertEqual(e._active_profile_index, DECODE_IDX) + out_p = trt_gm(prefill_in) + # seq=32 -> derived dim 16*32=512; both engines auto-select prefill + for e in engines: + self.assertEqual(e._active_profile_index, PREFILL_IDX) + self.assertTrue(torch.allclose(out_d, model(decode_in), atol=1e-2)) + self.assertTrue(torch.allclose(out_p, model(prefill_in), atol=1e-2)) + + +if __name__ == "__main__": + run_tests()