9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -94,6 +94,7 @@ def cross_compile_for_windows(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -176,6 +177,7 @@ def cross_compile_for_windows(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -342,6 +344,7 @@ def cross_compile_for_windows(
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -455,6 +458,7 @@ def compile(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -552,6 +556,7 @@ def compile(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -761,6 +766,7 @@ def compile(
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -1176,6 +1182,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -1254,6 +1261,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -1429,6 +1437,7 @@ def convert_exported_program_to_serialized_trt_engine(
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
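For context, a minimal usage sketch of the new flag (hedged: `MyModel` and the input shape are hypothetical, and a TensorRT-RTX build of Torch-TensorRT is assumed):

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical module
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Let TRT-RTX capture/replay CUDA graphs internally instead of the
# manual torch.cuda.CUDAGraph capture performed in forward().
trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    cuda_graph_strategy="whole_graph_capture",
)

torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_model(*inputs)
```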
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -73,6 +73,7 @@
DECOMPOSE_ATTENTION = False
ATTN_BIAS_IS_CAUSAL = True
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"
CUDA_GRAPH_STRATEGY = "disabled"

if platform.system() == "Linux":
import pwd
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -17,6 +17,7 @@
AUTOCAST_MAX_OUTPUT_THRESHOLD,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
CUDA_GRAPH_STRATEGY,
DECOMPOSE_ATTENTION,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
@@ -102,6 +103,7 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning).
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled".
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
@@ -159,6 +161,7 @@ class CompilationSettings:
dynamic_shapes_kernel_specialization_strategy: str = (
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
)
cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
48 changes: 48 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -114,6 +114,53 @@ def __del__(self) -> None:
def set_use_output_allocator(self, enable: bool) -> None:
self.use_output_allocator_outputs = enable

def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None:
Contributor Author:

TRT-RTX would need to avoid the "If your input shapes change between requests, the graph is re-recorded for each new shape" behavior from torch-TRT here in subgraphs mode. TRT-RTX takes care of re-capturing graphs internally if shapes have changed.

https://docs.pytorch.org/TensorRT/tutorials/runtime_opt/cuda_graphs.html

We should add an explicit test to verify this; a sketch follows.
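A hedged sketch of such a test (`MyDynamicModel` and the shapes are illustrative, not from this PR):

```python
import torch
import torch_tensorrt

def test_rtx_native_graphs_survive_shape_change():
    # Hypothetical dynamic-shape model; the point is that with
    # whole_graph_capture TRT-RTX re-captures internally on a shape
    # change, so torch-TRT must not re-record a torch.cuda.CUDAGraph.
    model = MyDynamicModel().eval().cuda()
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[
            torch_tensorrt.Input(
                min_shape=(1, 3, 64, 64),
                opt_shape=(4, 3, 128, 128),
                max_shape=(8, 3, 256, 256),
            )
        ],
        cuda_graph_strategy="whole_graph_capture",
    )
    torch_tensorrt.runtime.set_cudagraphs_mode(True)
    for shape in [(1, 3, 64, 64), (8, 3, 256, 256)]:
        x = torch.randn(*shape, device="cuda")
        torch.testing.assert_close(
            trt_model(x), model(x), rtol=1e-2, atol=1e-2
        )
```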

"""Verify all TRT submodules are monolithically capturable on RTX.

For whole-graph CUDA graph mode with mixed TRT + PyTorch ops,
all TRT engines must be safe for manual stream capture. Raises
RuntimeError if any engine has lazy kernel specialization or
other non-capturable conditions.
"""
from torch_tensorrt._features import ENABLED_FEATURES

if not ENABLED_FEATURES.tensorrt_rtx:
return # non-RTX: no check needed
from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
PythonTorchTensorRTModule,
)

for name, mod in self.compiled_module.named_modules():
if isinstance(mod, PythonTorchTensorRTModule):
if not mod._is_monolithic_capturable(stream):
raise RuntimeError(
f"CUDA graph capture failed: TRT submodule "
f"'{name}' is not monolithically capturable "
f"(lazy kernel specialization or non-capturable "
f"stream). Whole-graph CUDA graph mode with mixed "
f"TRT + PyTorch ops requires all TRT engines to be "
f"capturable. Consider using "
f"cuda_graph_strategy='whole_graph_capture' with "
f"set_cudagraphs_mode(True) instead of "
f"enable_cudagraphs()."
)
# Ensure RTX-native is DISABLED so TRT engines do not
# interfere with the outer monolithic capture
if mod._rtx_native_cudagraphs:
from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
_get_cuda_graph_strategy,
)

mod.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
"disabled"
)
mod.context = mod._create_context()
mod._rtx_native_cudagraphs = False
logger.info(
f"Disabled RTX-native CUDA graphs for '{name}' "
f"(using outer monolithic capture instead)"
)

def forward(
self, *args: Any, **kwargs: Any
) -> torch.Tensor | Tuple[torch.Tensor, ...]:
Expand Down Expand Up @@ -183,6 +230,7 @@ def forward(

with torch.cuda.stream(self._engine_stream):
if need_cudagraphs_record:
self._check_monolithic_capturability(self._engine_stream)
self.cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
self._output_buffers = self.compiled_module(*args, **kwargs)
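For reference, a hedged sketch of the wrapper path this check guards (`MixedModel` is a hypothetical module with graph breaks): `enable_cudagraphs()` wraps the module for manual whole-graph `torch.cuda.CUDAGraph` capture, which is where `_check_monolithic_capturability` runs on RTX.

```python
import torch
import torch_tensorrt

model = MixedModel().eval().cuda()  # hypothetical module with graph breaks
inputs = [torch.randn(2, 8, device="cuda")]
trt_model = torch_tensorrt.compile(model, inputs=inputs)

# Manual capture of the mixed TRT + PyTorch graph; on RTX this raises
# RuntimeError if any TRT engine is not monolithically capturable.
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as cg_model:
    out = cg_model(*inputs)
```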
80 changes: 73 additions & 7 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -36,6 +36,14 @@ def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any:
}.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY)


def _get_cuda_graph_strategy(strategy_str: str) -> Any:
"""Map strategy string to TRT CudaGraphStrategy enum. Only called on RTX builds."""
return {
"disabled": trt.CudaGraphStrategy.DISABLED,
"whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE,
}.get(strategy_str, trt.CudaGraphStrategy.DISABLED)


class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc]
def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
trt.IOutputAllocator.__init__(self)
@@ -241,6 +249,7 @@ def __init__(
self.runtime_config: Any = None
self.runtime_cache: Any = None
self.runtime_cache_path = settings.runtime_cache_path
self._rtx_native_cudagraphs = False

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
@@ -309,6 +318,10 @@ def setup_engine(self) -> None:

if ENABLED_FEATURES.tensorrt_rtx:
self._setup_runtime_config()
self._rtx_native_cudagraphs = (
ENABLED_FEATURES.tensorrt_rtx
Contributor Author:

ENABLED_FEATURES.tensorrt_rtx is already true here, so there is no need to double-check it (a simplified sketch follows the assignment below).

and self.settings.cuda_graph_strategy != "disabled"
)
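Per the comment above, a sketch of the simplified assignment (assumes, as the comment states, that this line sits where `ENABLED_FEATURES.tensorrt_rtx` is already known to be true):

```python
# The feature flag is already established by the enclosing branch.
self._rtx_native_cudagraphs = self.settings.cuda_graph_strategy != "disabled"
```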

self.context = self._create_context()
assert self.context is not None, "Failed to create execution context"
@@ -336,7 +349,10 @@ def setup_engine(self) -> None:
if self.requires_output_allocator:
self.create_output_allocator()

if torch_tensorrt.runtime.get_cudagraphs_mode():
if (
torch_tensorrt.runtime.get_cudagraphs_mode()
and not self._rtx_native_cudagraphs
):
self.cudagraph = torch.cuda.CUDAGraph()

self.is_shape_inference_io = {
@@ -362,6 +378,10 @@ def _setup_runtime_config(self) -> None:
logger.info(
f"Dynamic shapes kernel specialization strategy: {self.settings.dynamic_shapes_kernel_specialization_strategy}"
)
self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
self.settings.cuda_graph_strategy
)
logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}")
self.runtime_cache = self.runtime_config.create_runtime_cache()
self._load_runtime_cache()
self.runtime_config.set_runtime_cache(self.runtime_cache)
@@ -466,6 +486,32 @@ def _reset_captured_graph(self) -> None:
self.cudagraph.reset()
self.cudagraph = None

def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool:
"""Check if manual torch.cuda.CUDAGraph capture is safe for this engine.

Returns False on RTX if the engine has conditions that prevent
manual stream capture (runtime allocation, DDS, lazy kernels).
"""
if not ENABLED_FEATURES.tensorrt_rtx:
return True # non-RTX: assume capturable (existing behavior)
# Check 1: TRT-RTX stream capturability (runtime allocation, DDS, etc.)
if not self.context.is_stream_capturable(stream.cuda_stream):
return False
# Check 2: Lazy kernel specialization would invalidate captured graph
if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy":
return False
return True
Contributor Author:

Refactor to use any(conditions) rather than individual checks; a sketch follows.
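One way the suggested refactor could look (a behavior-preserving sketch, not the committed code):

```python
def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool:
    if not ENABLED_FEATURES.tensorrt_rtx:
        return True  # non-RTX: assume capturable (existing behavior)
    # Any one of these conditions makes manual capture unsafe on RTX.
    return not any(
        (
            # TRT-RTX stream capturability (runtime allocation, DDS, ...)
            not self.context.is_stream_capturable(stream.cuda_stream),
            # Lazy kernel specialization would invalidate a captured graph.
            self.settings.dynamic_shapes_kernel_specialization_strategy
            == "lazy",
        )
    )
```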


def _enable_rtx_native_cudagraphs(self) -> None:
"""Switch to RTX-native CUDA graphs by recreating the execution context."""
if self.runtime_config is not None:
self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
"whole_graph_capture"
)
self.context = self._create_context()
self._rtx_native_cudagraphs = True
logger.info("Switched to TRT-RTX native CUDA graphs")

def __del__(self) -> None:
self._save_runtime_cache()
self._reset_captured_graph()
@@ -559,13 +605,32 @@ def create_output_allocator(self) -> None:

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
# On RTX + SUBGRAPH cudagraphs: always use RTX-native CUDA graphs.
# Manual torch.cuda.CUDAGraph capture is not safe on TRT-RTX because
# lazy kernel specialization can invalidate captured graphs and
# runtime allocation can prevent stream capture.
if ENABLED_FEATURES.tensorrt_rtx and self.cudagraphs_enabled:
if not self._rtx_native_cudagraphs:
logger.warning(
"Manual CUDA graph capture is not guaranteed to work "
"on TRT-RTX (lazy kernel specialization or "
"non-capturable stream). Switching to TRT-RTX native "
"CUDA graphs. Set cuda_graph_strategy="
'"whole_graph_capture" at compile time to avoid '
"this warning."
)
self._enable_rtx_native_cudagraphs()

effective_cudagraphs = (
self.cudagraphs_enabled and not self._rtx_native_cudagraphs
)
shape_changed = self.validate_input_shapes(contiguous_inputs)
(
need_cudagraphs_record,
can_use_pre_allocated_outputs,
need_cudagraphs_reset,
) = self.runtime_states.set_runtime_states(
self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
effective_cudagraphs, self.use_pre_allocated_outputs, shape_changed
)

if need_cudagraphs_reset:
@@ -587,7 +652,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

self.setup_input_tensors(
contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
contiguous_inputs, effective_cudagraphs, need_cudagraphs_record
)

if shape_changed:
Expand Down Expand Up @@ -623,7 +688,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
if need_cudagraphs_record:
self._output_buffers[o] = outputs[o].clone()

if self.cudagraphs_enabled:
if effective_cudagraphs:
self.context.set_tensor_address(
output_name, self._output_buffers[o].data_ptr()
)
@@ -649,7 +714,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
self._engine_stream.wait_stream(self._caller_stream)

with torch.cuda.stream(self._engine_stream):
if self.cudagraphs_enabled:
if effective_cudagraphs:
if need_cudagraphs_record:
self.cudagraph = torch.cuda.CUDAGraph()

@@ -683,7 +748,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
):
self.pre_allocated_outputs = self.create_output_tensors()

if self.cudagraphs_enabled:
if effective_cudagraphs:
for idx, o in enumerate(outputs):
o.copy_(self._output_buffers[idx])

@@ -840,7 +905,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
return run_output_allocator()
else:
logger.debug(
f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}."
f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}"
+ (" (RTX native)" if self._rtx_native_cudagraphs else "")
)
return run_standard_execution()
