Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions py/torch_tensorrt/dynamo/runtime/_TRTEngine.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,10 +696,17 @@ def device_memory_budget(self) -> Any:
def device_memory_budget(self, budget_bytes: int) -> None:
    """Set the weight streaming budget and rebuild the execution context.

    Args:
        budget_bytes: Requested budget in bytes. A negative value selects
            ``self.streamable_device_memory_budget`` instead.

    A budget that does not read back as requested is logged as an error.
    Any exception raised while updating the budget propagates to the
    caller, but only after the execution context has been recreated.
    """
    if budget_bytes < 0:
        budget_bytes = self.streamable_device_memory_budget
    # The weight streaming budget cannot be modified while an execution
    # context is active, so release the current context first, then update
    # the budget and recreate it (mirrors the C++ runtime's
    # set_device_memory_budget).
    self.context = None
    try:
        self.cuda_engine.weight_streaming_budget_v2 = budget_bytes
        if self.cuda_engine.weight_streaming_budget_v2 != budget_bytes:
            logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
    finally:
        # Recreate the context even when the update raises; otherwise the
        # engine would be left with ``self.context = None``.
        self.context = self._create_execution_context()
        if self._profile_execution:
            self.enable_profiling()
        # The context object was replaced, so flag the change for the
        # runtime state tracking.
        self.runtime_states.context_changed = True

def reset_captured_graph(self) -> None:
Expand Down Expand Up @@ -882,11 +889,12 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool:
):
# Captured CUDA graph was recorded against the old stream.
self.runtime_states.context_changed = True
return caller_on_default
return bool(caller_on_default)

def _execute_standard(
self, contiguous_inputs: List[torch.Tensor]
) -> torch.Tensor | Tuple[torch.Tensor, ...]:

cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
if (
ENABLED_FEATURES.tensorrt_rtx
Expand All @@ -913,6 +921,9 @@ def _execute_standard(
# cudagraph recapture (set_runtime_states consumes and resets the
# flag).
caller_on_default = self._prepare_streams(contiguous_inputs)
engine_stream = self._engine_stream
caller_stream = self._caller_stream
assert engine_stream is not None and caller_stream is not None
shape_changed = self.validate_input_shapes(contiguous_inputs)
(
need_cudagraphs_record,
Expand Down Expand Up @@ -970,8 +981,8 @@ def _execute_standard(

with self._profile_section("TRTEngine:TensorRTRuntime"):
if caller_on_default:
self._engine_stream.wait_stream(self._caller_stream)
with torch.cuda.stream(self._engine_stream):
engine_stream.wait_stream(caller_stream)
with torch.cuda.stream(engine_stream):
if self.resource_allocation_strategy:
self._dynamic_workspace = torch.empty(
self.cuda_engine.device_memory_size_v2,
Expand All @@ -985,22 +996,18 @@ def _execute_standard(
self.cudagraph = torch.cuda.CUDAGraph()
if self._profile_execution:
self.cudagraph.enable_debug_mode()
with torch.cuda.graph(
self.cudagraph, stream=self._engine_stream
):
self.context.execute_async_v3(
self._engine_stream.cuda_stream
)
with torch.cuda.graph(self.cudagraph, stream=engine_stream):
self.context.execute_async_v3(engine_stream.cuda_stream)
if self._profile_execution:
self.cudagraph.debug_dump(
f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot"
)
self.cudagraph.replay() # type: ignore[union-attr]
else:
self.context.execute_async_v3(self._engine_stream.cuda_stream)
self.context.execute_async_v3(engine_stream.cuda_stream)

if caller_on_default:
self._caller_stream.wait_stream(self._engine_stream)
caller_stream.wait_stream(engine_stream)

if self.use_pre_allocated_outputs and (
self.output_tensors_are_unowned
Expand Down Expand Up @@ -1040,14 +1047,17 @@ def _execute_output_allocator(
)

caller_on_default = self._prepare_streams(contiguous_inputs)
engine_stream = self._engine_stream
caller_stream = self._caller_stream
assert engine_stream is not None and caller_stream is not None

with self._profile_section("TRTEngine:TensorRTRuntime"):
if caller_on_default:
self._engine_stream.wait_stream(self._caller_stream)
with torch.cuda.stream(self._engine_stream):
self.context.execute_async_v3(self._engine_stream.cuda_stream)
engine_stream.wait_stream(caller_stream)
with torch.cuda.stream(engine_stream):
self.context.execute_async_v3(engine_stream.cuda_stream)
if caller_on_default:
self._caller_stream.wait_stream(self._engine_stream)
caller_stream.wait_stream(engine_stream)

outputs = []
assert self.output_allocator is not None
Expand Down
Loading