-
Notifications
You must be signed in to change notification settings - Fork 394
feat: add TRT-RTX native CUDA graph support #4187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
tp5uiuc
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
tp5uiuc:feat/trtrtx-cudagraphs
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,6 +36,14 @@ def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: | |
| }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) | ||
|
|
||
|
|
||
| def _get_cuda_graph_strategy(strategy_str: str) -> Any: | ||
| """Map strategy string to TRT CudaGraphStrategy enum. Only called on RTX builds.""" | ||
| return { | ||
| "disabled": trt.CudaGraphStrategy.DISABLED, | ||
| "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, | ||
| }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) | ||
|
|
||
|
|
||
| class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] | ||
| def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: | ||
| trt.IOutputAllocator.__init__(self) | ||
|
|
@@ -241,6 +249,7 @@ def __init__( | |
| self.runtime_config: Any = None | ||
| self.runtime_cache: Any = None | ||
| self.runtime_cache_path = settings.runtime_cache_path | ||
| self._rtx_native_cudagraphs = False | ||
|
|
||
| if self.serialized_engine is not None and not self.settings.lazy_engine_init: | ||
| self.setup_engine() | ||
|
|
@@ -309,6 +318,10 @@ def setup_engine(self) -> None: | |
|
|
||
| if ENABLED_FEATURES.tensorrt_rtx: | ||
| self._setup_runtime_config() | ||
| self._rtx_native_cudagraphs = ( | ||
| ENABLED_FEATURES.tensorrt_rtx | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ENABLED_FEATURES.tensorrt_rtx is already true, don't need to double check. |
||
| and self.settings.cuda_graph_strategy != "disabled" | ||
| ) | ||
|
|
||
| self.context = self._create_context() | ||
| assert self.context is not None, "Failed to create execution context" | ||
|
|
@@ -336,7 +349,10 @@ def setup_engine(self) -> None: | |
| if self.requires_output_allocator: | ||
| self.create_output_allocator() | ||
|
|
||
| if torch_tensorrt.runtime.get_cudagraphs_mode(): | ||
| if ( | ||
| torch_tensorrt.runtime.get_cudagraphs_mode() | ||
| and not self._rtx_native_cudagraphs | ||
| ): | ||
| self.cudagraph = torch.cuda.CUDAGraph() | ||
|
|
||
| self.is_shape_inference_io = { | ||
|
|
@@ -362,6 +378,10 @@ def _setup_runtime_config(self) -> None: | |
| logger.info( | ||
| f"Dynamic shapes kernel specialization strategy: {self.settings.dynamic_shapes_kernel_specialization_strategy}" | ||
| ) | ||
| self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( | ||
| self.settings.cuda_graph_strategy | ||
| ) | ||
| logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}") | ||
| self.runtime_cache = self.runtime_config.create_runtime_cache() | ||
| self._load_runtime_cache() | ||
| self.runtime_config.set_runtime_cache(self.runtime_cache) | ||
|
|
@@ -466,6 +486,32 @@ def _reset_captured_graph(self) -> None: | |
| self.cudagraph.reset() | ||
| self.cudagraph = None | ||
|
|
||
| def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: | ||
| """Check if manual torch.cuda.CUDAGraph capture is safe for this engine. | ||
|
|
||
| Returns False on RTX if the engine has conditions that prevent | ||
| manual stream capture (runtime allocation, DDS, lazy kernels). | ||
| """ | ||
| if not ENABLED_FEATURES.tensorrt_rtx: | ||
| return True # non-RTX: assume capturable (existing behavior) | ||
| # Check 1: TRT-RTX stream capturability (runtime allocation, DDS, etc.) | ||
| if not self.context.is_stream_capturable(stream.cuda_stream): | ||
| return False | ||
| # Check 2: Lazy kernel specialization would invalidate captured graph | ||
| if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy": | ||
| return False | ||
| return True | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor to use any(conditions) rather than individual checks. |
||
|
|
||
| def _enable_rtx_native_cudagraphs(self) -> None: | ||
| """Switch to RTX-native CUDA graphs by recreating the execution context.""" | ||
| if self.runtime_config is not None: | ||
| self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( | ||
| "whole_graph_capture" | ||
| ) | ||
| self.context = self._create_context() | ||
| self._rtx_native_cudagraphs = True | ||
| logger.info("Switched to TRT-RTX native CUDA graphs") | ||
|
|
||
| def __del__(self) -> None: | ||
| self._save_runtime_cache() | ||
| self._reset_captured_graph() | ||
|
|
@@ -559,13 +605,32 @@ def create_output_allocator(self) -> None: | |
|
|
||
| def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: | ||
| def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | ||
| # On RTX + SUBGRAPH cudagraphs: always use RTX-native CUDA graphs. | ||
| # Manual torch.cuda.CUDAGraph capture is not safe on TRT-RTX because | ||
| # lazy kernel specialization can invalidate captured graphs and | ||
| # runtime allocation can prevent stream capture. | ||
| if ENABLED_FEATURES.tensorrt_rtx and self.cudagraphs_enabled: | ||
| if not self._rtx_native_cudagraphs: | ||
| logger.warning( | ||
| "Manual CUDA graph capture is not guaranteed to work " | ||
| "on TRT-RTX (lazy kernel specialization or " | ||
| "non-capturable stream). Switching to TRT-RTX native " | ||
| "CUDA graphs. Set cuda_graph_strategy=" | ||
| '"whole_graph_capture" at compile time to avoid ' | ||
| "this warning." | ||
| ) | ||
| self._enable_rtx_native_cudagraphs() | ||
|
|
||
| effective_cudagraphs = ( | ||
| self.cudagraphs_enabled and not self._rtx_native_cudagraphs | ||
| ) | ||
| shape_changed = self.validate_input_shapes(contiguous_inputs) | ||
| ( | ||
| need_cudagraphs_record, | ||
| can_use_pre_allocated_outputs, | ||
| need_cudagraphs_reset, | ||
| ) = self.runtime_states.set_runtime_states( | ||
| self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed | ||
| effective_cudagraphs, self.use_pre_allocated_outputs, shape_changed | ||
| ) | ||
|
|
||
| if need_cudagraphs_reset: | ||
|
|
@@ -587,7 +652,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
| ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." | ||
|
|
||
| self.setup_input_tensors( | ||
| contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record | ||
| contiguous_inputs, effective_cudagraphs, need_cudagraphs_record | ||
| ) | ||
|
|
||
| if shape_changed: | ||
|
|
@@ -623,7 +688,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
| if need_cudagraphs_record: | ||
| self._output_buffers[o] = outputs[o].clone() | ||
|
|
||
| if self.cudagraphs_enabled: | ||
| if effective_cudagraphs: | ||
| self.context.set_tensor_address( | ||
| output_name, self._output_buffers[o].data_ptr() | ||
| ) | ||
|
|
@@ -649,7 +714,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
| self._engine_stream.wait_stream(self._caller_stream) | ||
|
|
||
| with torch.cuda.stream(self._engine_stream): | ||
| if self.cudagraphs_enabled: | ||
| if effective_cudagraphs: | ||
| if need_cudagraphs_record: | ||
| self.cudagraph = torch.cuda.CUDAGraph() | ||
|
|
||
|
|
@@ -683,7 +748,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
| ): | ||
| self.pre_allocated_outputs = self.create_output_tensors() | ||
|
|
||
| if self.cudagraphs_enabled: | ||
| if effective_cudagraphs: | ||
| for idx, o in enumerate(outputs): | ||
| o.copy_(self._output_buffers[idx]) | ||
|
|
||
|
|
@@ -840,7 +905,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
| return run_output_allocator() | ||
| else: | ||
| logger.debug( | ||
| f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}." | ||
| f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}" | ||
| + (" (RTX native)" if self._rtx_native_cudagraphs else "") | ||
| ) | ||
| return run_standard_execution() | ||
|
|
||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TRT-RTX would need to avoid the
"If your input shapes change between requests, the graph is re-recorded for each new shape. "
behavior from torch-TRT here in subgraphs mode. TRT-RTX takes care of re-capturing graphs internally if shapes have changed.
https://docs.pytorch.org/TensorRT/tutorials/runtime_opt/cuda_graphs.html
We should add an explicit test to verify this.