diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index aeda1aa1e4..d2ec9e702d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -696,10 +696,17 @@ def device_memory_budget(self) -> Any: def device_memory_budget(self, budget_bytes: int) -> None: if budget_bytes < 0: budget_bytes = self.streamable_device_memory_budget + # The weight streaming budget cannot be modified while an execution + # context is active, so release the current context first, then update + # the budget and recreate it (mirrors the C++ runtime's + # set_device_memory_budget). + self.context = None self.cuda_engine.weight_streaming_budget_v2 = budget_bytes if self.cuda_engine.weight_streaming_budget_v2 != budget_bytes: logger.error(f"Failed to set weight streaming budget to {budget_bytes}") self.context = self._create_execution_context() + if self._profile_execution: + self.enable_profiling() self.runtime_states.context_changed = True def reset_captured_graph(self) -> None: @@ -882,11 +889,12 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: ): # Captured CUDA graph was recorded against the old stream. self.runtime_states.context_changed = True - return caller_on_default + return bool(caller_on_default) def _execute_standard( self, contiguous_inputs: List[torch.Tensor] ) -> torch.Tensor | Tuple[torch.Tensor, ...]: + cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() if ( ENABLED_FEATURES.tensorrt_rtx @@ -913,6 +921,9 @@ def _execute_standard( # cudagraph recapture (set_runtime_states consumes and resets the # flag). caller_on_default = self._prepare_streams(contiguous_inputs) + engine_stream = self._engine_stream + caller_stream = self._caller_stream + assert engine_stream is not None and caller_stream is not None shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, @@ -970,8 +981,8 @@ def _execute_standard( with self._profile_section("TRTEngine:TensorRTRuntime"): if caller_on_default: - self._engine_stream.wait_stream(self._caller_stream) - with torch.cuda.stream(self._engine_stream): + engine_stream.wait_stream(caller_stream) + with torch.cuda.stream(engine_stream): if self.resource_allocation_strategy: self._dynamic_workspace = torch.empty( self.cuda_engine.device_memory_size_v2, @@ -985,22 +996,18 @@ def _execute_standard( self.cudagraph = torch.cuda.CUDAGraph() if self._profile_execution: self.cudagraph.enable_debug_mode() - with torch.cuda.graph( - self.cudagraph, stream=self._engine_stream - ): - self.context.execute_async_v3( - self._engine_stream.cuda_stream - ) + with torch.cuda.graph(self.cudagraph, stream=engine_stream): + self.context.execute_async_v3(engine_stream.cuda_stream) if self._profile_execution: self.cudagraph.debug_dump( f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot" ) self.cudagraph.replay() # type: ignore[union-attr] else: - self.context.execute_async_v3(self._engine_stream.cuda_stream) + self.context.execute_async_v3(engine_stream.cuda_stream) if caller_on_default: - self._caller_stream.wait_stream(self._engine_stream) + caller_stream.wait_stream(engine_stream) if self.use_pre_allocated_outputs and ( self.output_tensors_are_unowned @@ -1040,14 +1047,17 @@ def _execute_output_allocator( ) caller_on_default = self._prepare_streams(contiguous_inputs) + engine_stream = self._engine_stream + caller_stream = self._caller_stream + assert engine_stream is not None and caller_stream is not None with self._profile_section("TRTEngine:TensorRTRuntime"): if caller_on_default: - self._engine_stream.wait_stream(self._caller_stream) - with torch.cuda.stream(self._engine_stream): - self.context.execute_async_v3(self._engine_stream.cuda_stream) + engine_stream.wait_stream(caller_stream) + with torch.cuda.stream(engine_stream): + self.context.execute_async_v3(engine_stream.cuda_stream) if caller_on_default: - self._caller_stream.wait_stream(self._engine_stream) + caller_stream.wait_stream(engine_stream) outputs = [] assert self.output_allocator is not None