Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ TRTEngine::TRTEngine(
num_io = std::make_pair(inputs_size, outputs);
}

// Reconstruct optimization-profile info (count + per-profile ranges) from the
// TRT API so multi-profile selection works for any loaded engine.
this->setup_optimization_profiles();

#ifndef NDEBUG
this->enable_profiling();
#endif
Expand Down Expand Up @@ -512,6 +516,97 @@ void TRTEngine::reset_captured_graph() {
cudagraph.reset();
}

void TRTEngine::setup_optimization_profiles() {
num_optimization_profiles = cuda_engine->getNbOptimizationProfiles();
profile_dim_ranges.clear();
is_shape_inference_io.clear();
for (const auto& name : in_binding_names) {
is_shape_inference_io[name] = cuda_engine->isShapeInferenceIO(name.c_str());
}
if (num_optimization_profiles <= 1) {
return;
}
// name -> [dim] -> [(min, max), ...] (one entry per optimization-profile index).
for (int64_t p = 0; p < num_optimization_profiles; ++p) {
for (const auto& name : in_binding_names) {
if (is_shape_inference_io[name]) {
continue;
}
auto dmin =
cuda_engine->getProfileShape(name.c_str(), static_cast<int32_t>(p), nvinfer1::OptProfileSelector::kMIN);
auto dmax =
cuda_engine->getProfileShape(name.c_str(), static_cast<int32_t>(p), nvinfer1::OptProfileSelector::kMAX);
auto& dims = profile_dim_ranges[name];
if (dims.empty()) {
dims.resize(dmin.nbDims);
}
for (int d = 0; d < dmin.nbDims; ++d) {
dims[d].push_back(std::make_pair(dmin.d[d], dmax.d[d]));
}
}
}
}

void TRTEngine::set_active_profile(int64_t profile_index) {
if (num_optimization_profiles <= 1) {
return;
}
if (profile_index == active_profile_index) {
return;
}
auto stream = c10::cuda::getCurrentCUDAStream(device_info.id);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with the green context pr?

// setOptimizationProfileAsync returns false for an out-of-range index; the
// index is validated upstream in TorchTensorRTModule.resolve_profile_index.
TORCHTRT_CHECK(
exec_ctx->setOptimizationProfileAsync(static_cast<int32_t>(profile_index), stream.stream()),
"Failed to switch to optimization profile index " << profile_index);
stream.synchronize();
active_profile_index = profile_index;
// A profile switch invalidates any captured CUDA graph and changes the
// context state, so force re-record / shape re-inference on the next call.
runtime_states.context_changed = true;
reset_captured_graph();
shape_key = "None";
LOG_DEBUG("Switched to optimization profile index " << profile_index);
}

int64_t TRTEngine::auto_select_profile(const std::vector<at::Tensor>& inputs) {
// Lazy selection: scan profiles in index order and return the first one whose
// [min, max] ranges contain every input shape.
for (int64_t p = 0; p < num_optimization_profiles; ++p) {
bool fits = true;
for (size_t i = 0; i < in_binding_names.size() && fits; ++i) {
const auto& name = in_binding_names[i];
if (i >= inputs.size() || is_shape_inference_io[name]) {
continue;
}
auto ranges_it = profile_dim_ranges.find(name);
if (ranges_it == profile_dim_ranges.end()) {
continue;
}
const auto& dims = ranges_it->second;
auto sizes = inputs[i].sizes();
for (size_t d = 0; d < sizes.size(); ++d) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we cache only what the dynamic dimension is for each profile and its ranges? Then we dont need to search mostly static dims

if (d < dims.size()) {
int64_t lo = dims[d][p].first;
int64_t hi = dims[d][p].second;
if (!(lo <= sizes[d] && sizes[d] <= hi)) {
fits = false;
break;
}
}
}
}
if (fits) {
return p;
}
}
TORCHTRT_THROW_ERROR(
"No optimization profile matches the input shapes. Fix the input shapes or pin a profile "
"explicitly via optimization_profile(module, index).");
return 0; // unreachable
}

void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
if (new_strategy != this->resource_allocation_strategy) {
this->resource_allocation_strategy = new_strategy;
Expand Down
27 changes: 27 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,33 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;

// --- Multiple optimization profiles ---
// State and helpers mirror the Python runtime (TRTEngine in _TRTEngine.py) so
// the C++ and Python runtimes are interchangeable: the same attribute and
// method names are exposed via torchbind in register_jit_hooks.cpp
// (``num_optimization_profiles``, ``_active_profile_index``,
// ``_auto_select_profiles``, ``set_active_profile``). Index validation lives
// in the runtime-agnostic TorchTensorRTModule.resolve_profile_index.
int64_t num_optimization_profiles = 1; // cuda_engine->getNbOptimizationProfiles()
int64_t active_profile_index = 0; // profile currently loaded in exec_ctx
bool auto_select_profiles = false; // opt-in shape-based selection (per call)
// input name -> [dim index] -> per-profile [min, max]; cached from the TRT
// API. The dim axis is a dense vector indexed by dimension.
std::unordered_map<std::string, std::vector<std::vector<std::pair<int64_t, int64_t>>>> profile_dim_ranges;
std::unordered_map<std::string, bool> is_shape_inference_io;

// Cache profile count + per-profile dim ranges purely from the TRT API
// (getNbOptimizationProfiles / getProfileShape) so selection works for any
// loaded engine with no extra serialized metadata.
void setup_optimization_profiles();
// Switch the active TRT optimization profile (idempotent).
void set_active_profile(int64_t profile_index);
// Lazy / first-working: first profile whose [min, max] fits all input shapes.
// Called internally from the execute_engine run paths (guarded by
// num_optimization_profiles > 1 && auto_select_profiles); manual pins are
// applied eagerly via set_active_profile.
int64_t auto_select_profile(const std::vector<at::Tensor>& inputs);

// Single placeholder buffer for empty tensor inputs (allocated once, reused)
void* empty_tensor_placeholder = nullptr;

Expand Down
12 changes: 12 additions & 0 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

auto run_standard_execution = [&]() {
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
// Auto-select the optimization profile from input shapes before validating
// shapes, so a profile switch's context_changed flag and shape_key reset are
// observed below. Only auto-selection runs per call; manual pins are applied
// eagerly via set_active_profile.
if (compiled_engine->num_optimization_profiles > 1 && compiled_engine->auto_select_profiles) {
compiled_engine->set_active_profile(compiled_engine->auto_select_profile(inputs));
}
bool shape_changed = _validate_shapes(inputs, compiled_engine);

auto current_device_id = inputs.size() > 0 ? inputs[0].device().index() : at::cuda::current_device();
Expand Down Expand Up @@ -401,6 +408,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
};

auto run_output_allocator = [&]() {
// Auto-select the optimization profile from input shapes before binding
// them. Only auto-selection runs per call; manual pins are applied eagerly.
if (compiled_engine->num_optimization_profiles > 1 && compiled_engine->auto_select_profiles) {
compiled_engine->set_active_profile(compiled_engine->auto_select_profile(inputs));
}
{ // Input Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
if (compiled_engine->profile_execution) {
Expand Down
7 changes: 7 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)
.def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned)
// Multiple optimization profiles. Names match the Python runtime
// (_TRTEngine.py) so both runtimes are interchangeable behind
// TorchTensorRTModule / the optimization_profile context manager.
.def("set_active_profile", &TRTEngine::set_active_profile)
.def_readonly("num_optimization_profiles", &TRTEngine::num_optimization_profiles)
.def_readonly("_active_profile_index", &TRTEngine::active_profile_index)
.def_readwrite("_auto_select_profiles", &TRTEngine::auto_select_profiles)
.def(
"use_dynamically_allocated_resources",
[](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
Expand Down
6 changes: 5 additions & 1 deletion docsrc/tutorials/runtime_opt/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ Runtime Optimization
=====================

Optimize inference throughput and latency: CUDA Graphs for kernel-replay,
pre-allocated output buffers, and choosing the Python vs C++ TRT execution path.
pre-allocated output buffers, multiple optimization profiles for distinct shape
regimes (e.g. LLM prefill/decode), and choosing the Python vs C++ TRT execution
path.

.. toctree::
:maxdepth: 1

cuda_graphs
Example: Torch Export with Cudagraphs <../_rendered_examples/dynamo/torch_export_cudagraphs>
Example: Pre-allocated output buffer <../_rendered_examples/dynamo/pre_allocated_output_example>
multi_optimization_profiles
Example: Multiple optimization profiles (prefill/decode) <../_rendered_examples/dynamo/multi_optimization_profiles>
Python vs C++ runtime <python_runtime>
164 changes: 164 additions & 0 deletions docsrc/tutorials/runtime_opt/multi_optimization_profiles.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
.. _multi_optimization_profiles_tutorial:

Multiple Optimization Profiles (Prefill / Decode)
=================================================

TensorRT tunes kernels for the **optimization profile** of an engine: a
``[min, opt, max]`` range for every dynamic input dimension. Kernels are tuned at
the ``opt`` point, so a single profile can only be optimal for one shape.

Many models, however, run in several distinct shape *regimes* that share the same
weights. The canonical case is an autoregressive LLM:

* **prefill** -- the prompt is processed in one shot, so ``seq`` is large, and
* **decode** -- tokens are generated one at a time, so ``seq == 1``.

With a single dynamic range ``seq in [1, max]`` you must pick one ``opt``. Tuning
for the long prefill length leaves **decode** -- the latency-critical, most
frequently executed phase -- running on kernels chosen for a sequence length it
never sees.

Torch-TensorRT lets you declare **multiple optimization profiles** on a single
input and select the active one at runtime. The engine is built once and each
profile is tuned independently.

Declaring profiles
------------------

Pass an ordered list of ``{"min", "opt", "max"}`` dicts to
:class:`torch_tensorrt.Input` via ``profiles``. The **list index** is the
optimization-profile index you select at runtime.

.. code-block:: python

import torch
import torch_tensorrt

DECODE_IDX, PREFILL_IDX = 0, 1

profiled_input = torch_tensorrt.Input(
dtype=torch.int64,
profiles=[
# index 0 -> decode: seq pinned to 1 (a fully static profile)
{"min": (1, 1), "opt": (1, 1), "max": (1, 1)},
# index 1 -> prefill: seq in [1, 512], tuned at 256
{"min": (1, 1), "opt": (1, 256), "max": (1, 512)},
],
)

``profiles`` is mutually exclusive with the single-shape ``min_shape`` /
``opt_shape`` / ``max_shape`` (and ``shape``) arguments.

The **union envelope**
~~~~~~~~~~~~~~~~~~~~~~~~

``torch.export`` traces a model over one ``[min, opt, max]`` range, so
``Input`` automatically derives the **union envelope** of all profiles
(elementwise ``min`` of every ``min`` and ``max`` of every ``max``; ``opt`` is
taken from the first profile). Each declared profile is a subset of this
envelope. You export over the envelope and the individual profiles become the
per-profile TensorRT tunings:

.. code-block:: python

print(profiled_input.shape["min_shape"]) # (1, 1)
print(profiled_input.shape["max_shape"]) # (1, 512)

Compile
-------

Export once over the union range, then compile as usual. Every input that
declares ``profiles`` must declare the **same number** of profiles; static
inputs (or dynamic inputs without ``profiles``) reuse their single shape in every
profile.

.. code-block:: python

seq = torch.export.Dim("seq", min=1, max=512)
exported = torch.export.export(model, (example_ids,), dynamic_shapes=({1: seq},))

trt_model = torch_tensorrt.dynamo.compile(
exported,
arg_inputs=[profiled_input],
enabled_precisions={torch.float16},
min_block_size=1,
)

Selecting a profile at runtime
------------------------------

Selection is **manual by default**. Use the
:func:`torch_tensorrt.runtime.optimization_profile` context manager to pin a
profile by index for the duration of a ``with`` block; the prior state is saved
on enter and restored on exit, so blocks nest cleanly.

.. code-block:: python

from torch_tensorrt.runtime import optimization_profile

with optimization_profile(trt_model, DECODE_IDX):
logits = trt_model(decode_ids) # seq == 1

with optimization_profile(trt_model, PREFILL_IDX):
logits = trt_model(prefill_ids) # seq == 256

Pass ``"auto"`` to let Torch-TensorRT choose from the input shapes. Auto-selection
is **lazy / first-working**: it scans profiles in index order and uses the first
whose ``[min, max]`` contains the input. Order matters when profiles overlap --
declaring ``decode`` first lets it win the ``seq == 1`` overlap:

.. code-block:: python

with optimization_profile(trt_model, "auto"):
trt_model(decode_ids) # seq == 1 -> index 0 (decode) accepts -> decode
trt_model(prefill_ids) # seq == 256 -> index 0 rejects -> index 1 (prefill)

Profiles, graph breaks, and serialization
-----------------------------------------

* **Graph breaks**: when a model is partitioned into several TensorRT engines,
every engine carries the same number of profiles. Torch-TensorRT propagates the
per-profile bounds across the break, evaluating any *derived* dynamic dimension
(e.g. a ``reshape`` that turns ``seq`` into ``16 * seq``) through to the
downstream engine, so runtime selection stays consistent for the whole module.
* **Serialization / runtimes**: profile state is reconstructed from the TensorRT
API on load (``getNbOptimizationProfiles`` / ``getProfileShape``), so a
serialized engine keeps its profiles with no extra metadata. The same
``optimization_profile`` API drives both the C++ and Python runtimes, which
remain interchangeable.

Why it helps: a worked latency example
--------------------------------------

The example :ref:`multi_optimization_profiles` compiles ``google/gemma-3-1b-it``
twice -- once with a single profile (tuned at the prefill length) and once with
separate decode/prefill profiles -- then compares per-call latency. The
multi-profile engine dedicates a **static** profile (``seq`` pinned to 1) to
decode, letting TensorRT specialize that path (measured on an NVIDIA A40, FP16):

.. code-block:: text

Per-call latency (ms), batch=1
regime single-profile multi-profile speedup
--------------------------------------------------------------
decode (seq=1) 5.232 4.597 1.14x
prefill (seq=128) 7.152 7.534 0.95x

Prefill is essentially unchanged (both engines tune it at the same ``opt``),
while decode -- the regime executed once per generated token -- is faster. Exact
numbers depend on the model and GPU; the takeaway is that one engine can be tuned
well for *both* regimes instead of compromising on a single ``opt`` shape.

.. note::

Because the model has two dynamic inputs (``input_ids`` and ``position_ids``),
the example passes one profiled ``Input`` for each, both declaring the same
profiles. The HuggingFace attention path also needs a TensorRT-friendly SDPA
lowering (``tools/llm/torchtrt_ext/register_sdpa``), and ``gemma-3-1b-it`` is a
gated model requiring Hugging Face authentication.

.. seealso::

- Runnable example: :ref:`multi_optimization_profiles`
- :class:`torch_tensorrt.Input`
- :func:`torch_tensorrt.runtime.optimization_profile`
Loading
Loading