Skip to content

feat: Allow for users / kv cache to add aliased I/O for inplace operations#4251

Open
narendasan wants to merge 1 commit into
mainfrom
narendasan/aliased_io
Open

feat: Allow for users / kv cache to add aliased I/O for inplace operations#4251
narendasan wants to merge 1 commit into
mainfrom
narendasan/aliased_io

Conversation

@narendasan

Copy link
Copy Markdown
Collaborator

Description

Adds support for in-place ATen operators by extending the Torch-TensorRT compile pipeline and C++ runtime with aliased input/output bindings. The motivating case is streaming inference with a key/value cache (e.g. autoregressive decoders, ZoomASR): each step writes a single timestep into the cache, and without aliasing every step pays a full cache-size copy at the engine boundary. With aliased I/O the TensorRT engine writes directly into the user's (or module-held) cache storage; no fresh allocation, no post-engine copy.

Two aliased_io "kinds" are tracked so the runtime can reason about provenance:

  • kv_cache_update — TensorRT-enforced via IKVCacheUpdateLayer; reported through ICudaEngine::getAliasedInputTensor.
  • user — Torch-TensorRT-declared; reserved for future expansion if TRT exposes a public non-KV aliasing API.

What this PR does

Pipeline (Python)

  • New slice_scatter and index_copy converters that recognize KV-cache-update patterns (4-D static cache, dim=2, batch=1) and emit IKVCacheUpdateLayer with the output aliased to the cache input. Non-eligible cases fall back to scatter in TRT — no graph break.
  • For index_copy, two disjoint converters (validator-gated KV fast path at ConverterPriority.HIGH + scatter fallback at standard priority) cleanly split the cases.
  • aliased_io plumbed through TRTInterpreterTRTInterpreterResultSerializedInterpreterResultTorchTensorRTModule. The output() step automatically promotes layer outputs that need to be network outputs (KVCacheUpdate requires it) and appends them after user outputs. The user/side-effect boundary is derived at runtime, not stored.

Buffer-style support

  • lift_mutated_buffers lowering pass detects BUFFER_MUTATION patterns (the trailing aten.copy_(get_attr_buffer, ...) that ep.module() emits) and lifts each mutated buffer from get_attr to placeholder so the engine sees it as an input binding.
  • inline_lifted_buffers_into_gm post-compile transform registers the buffers as nn.Module state on the compiled GraphModule and rewrites the lifted placeholders to get_attr reads. The result is a plain fx.GraphModule (no custom wrapper class) that serializes cleanly through torch_tensorrt.save / torch.export.
  • Low-level entry point convert_exported_program_to_serialized_trt_engine gains lift_mutable_buffers: bool = False for power users who want to manage the resulting bindings themselves.

C++ runtime (ABI v9 → v10)

  • Bumped ABI_VERSION to "10"; added ALIASED_IO_IDX to SerializedInfoIndex.
  • serialize_aliased_io / deserialize_aliased_io helpers (wire format: output@input@kind records joined by BINDING_DELIM). Helpers live in runtime_utils.cpp alongside serialize_bindings.
  • TRTEngine constructor reconciles the build-time map against ICudaEngine::getAliasedInputTensor — the engine API is the source of truth for KV-style aliasing.
  • execute_engine records bound input tensors by binding name; for each output binding in aliased_io, binds the same data_ptr as the source input and skips fresh allocation. Pre-allocated outputs are disabled when aliased I/O is present.
  • CUDA Graphs integration: aliased input bindings bypass the persistent-clone path so the engine writes through to user storage; aliased outputs are skipped in the post-exec copy-back loop. Capture + replay both correctly mutate the user's tensor.

Docs + examples

  • docsrc/contributors/inplace_operations.rst — full design doc covering motivation, primitives, pipeline, runtime, serialization format, and known limitations.
  • Three examples under examples/dynamo/:
    • aliased_io_user_inputs.py — caller-owned cache (simplest case)
    • aliased_io_buffers.py — module-owned cache via register_buffer
    • aliased_io_kv_attention.py — realistic single-layer transformer attention block with static KV cache

Fixes partially #4240 (in-place custom plugins / multiple outputs — addresses the in-place-operator side; plugin-side aliased I/O is explicitly out of scope here).

Type of change

  • New feature (non-breaking for callers who don't opt in to aliased I/O; ABI-breaking for existing engine binaries — older serialized engines fail the version check and need to be rebuilt, consistent with prior ABI bumps).
  • This change requires a documentation update (included).

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR

Test summary

38 new tests across 8 files, all passing:

File Cases Covers
tests/py/dynamo/conversion/test_slice_scatter_aten.py 8 scatter-fallback path (numerical correctness via the standard converter harness)
tests/py/dynamo/runtime/test_aliased_io.py 8 end-to-end aliased I/O (user-input single/paired/streaming + buffer-style + regressions)
tests/py/dynamo/runtime/test_index_copy_kv.py 4 KV fast path + 3 fallback shapes for aten.index_copy
tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py 4 low-level lift_mutable_buffers=True flag round-trip (introspect engine, construct module, run)
tests/py/dynamo/runtime/test_aliased_io_serialization.py 3 torch_tensorrt.save / load round-trip for user-input + buffer-backed + streaming buffer
tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py 3 CUDA Graph capture + replay for both user-input and buffer-backed KV; cudagraphs vs non-cudagraphs parity
tests/py/dynamo/runtime/test_hf_static_cache_xfail.py 1 xfail documenting the current HF + StaticCache gap (asserts known failure mode)
tests/py/dynamo/lowering/test_buffer_lifting.py 9 lift_mutated_buffers + inline_lifted_buffers_into_gm unit tests

Known gaps (documented)

  • Stock HuggingFace decoder LMs with StaticCache don't compile end-to-end yet: torch.export's run_decompositions raises internally on the EP that convert_and_export_with_cache produces. The xfail test asserts the known failure so a future upstream fix surfaces as a test failure. Path forward documented in the design doc.
  • IKVCacheUpdateLayer requires static s_max. Dynamic-sequence-length cache shapes fall through to the scatter path (still correct, no aliasing).

@narendasan narendasan requested a review from apbose May 12, 2026 00:19
@meta-cla meta-cla Bot added the cla signed label May 12, 2026
@github-actions github-actions Bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 12, 2026
@github-actions github-actions Bot requested a review from cehongwang May 12, 2026 00:19
github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

@narendasan narendasan force-pushed the narendasan/aliased_io branch from 354674d to 813e753 Compare May 12, 2026 00:26

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-05-12 00:26:56.728308+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-05-12 00:27:18.993037+00:00
@@ -15,10 +15,11 @@
This file covers the fallback path. To force the fallback regardless of
shape we add a small no-op (``+ 0``) to the cache so it isn't a direct
network input — the converter's "input is a placeholder" check fails and
falls through to scatter.
"""
+
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-05-12 00:26:56.731194+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-05-12 00:27:19.634834+00:00
@@ -21,10 +21,11 @@
* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer
  placeholders into ``get_attr`` reads and registers the buffers as
  module state. The result is a plain ``fx.GraphModule`` that
  serializes via ``torch_tensorrt.save`` without an external wrapper.
"""
+
import inspect

import torch
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-05-12 00:27:21.127481+00:00
@@ -17,10 +17,11 @@
  is already visible on the user's input).

These tests cover capture + replay correctness for both KV-cache patterns
(user-input and buffer-style).
"""
+
import unittest

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-05-12 00:27:21.138494+00:00
@@ -15,10 +15,11 @@
  ``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile
  transform replaces what used to be an external ``BufferThreadingModule``
  wrapper — making the result a plain ``fx.GraphModule`` that exports
  naturally without a custom wrapper class.
"""
+
import tempfile

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-05-12 00:27:21.208126+00:00
@@ -36,10 +36,11 @@
  workaround that skips ``run_decompositions`` for already-decomposed EPs.

When the upstream issues are resolved or those features land, this
xfail test should start passing — flip it to a real test then.
"""
+
import unittest

import torch
import torch_tensorrt

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-05-12 00:27:21.275543+00:00
@@ -19,10 +19,11 @@
* ``TorchTensorRTModule.forward`` filters aliased outputs from the user
  return tuple.
* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and
  ``BufferThreadingModule`` threads buffers through each call.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-05-12 00:27:21.304553+00:00
@@ -13,10 +13,11 @@

These tests verify both paths end-to-end via the C++ runtime: the
fast path mutates in place, the fallback produces correct numerical
results without aliasing.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-05-12 00:26:56.733500+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-05-12 00:27:21.380261+00:00
@@ -14,10 +14,11 @@
3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for
   aliased I/O) with the discovered bindings.
4. Thread the buffer values in on each call and verify in-place
   mutation works (cache state persists across calls).
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine

@narendasan narendasan force-pushed the narendasan/aliased_io branch from 813e753 to bcaf725 Compare May 12, 2026 20:26

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-05-12 20:26:34.855069+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-05-12 20:26:58.441876+00:00
@@ -15,10 +15,11 @@
This file covers the fallback path. To force the fallback regardless of
shape we add a small no-op (``+ 0``) to the cache so it isn't a direct
network input — the converter's "input is a placeholder" check fails and
falls through to scatter.
"""
+
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-05-12 20:26:34.858373+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-05-12 20:26:59.052592+00:00
@@ -21,10 +21,11 @@
* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer
  placeholders into ``get_attr`` reads and registers the buffers as
  module state. The result is a plain ``fx.GraphModule`` that
  serializes via ``torch_tensorrt.save`` without an external wrapper.
"""
+
import inspect

import torch
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-05-12 20:26:34.860665+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-05-12 20:27:00.527955+00:00
@@ -17,10 +17,11 @@
  is already visible on the user's input).

These tests cover capture + replay correctness for both KV-cache patterns
(user-input and buffer-style).
"""
+
import unittest

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-05-12 20:26:34.861069+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-05-12 20:27:00.554620+00:00
@@ -15,10 +15,11 @@
  ``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile
  transform replaces what used to be an external ``BufferThreadingModule``
  wrapper — making the result a plain ``fx.GraphModule`` that exports
  naturally without a custom wrapper class.
"""
+
import tempfile

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-05-12 20:26:34.861069+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-05-12 20:27:00.602080+00:00
@@ -36,10 +36,11 @@
  workaround that skips ``run_decompositions`` for already-decomposed EPs.

When the upstream issues are resolved or those features land, this
xfail test should start passing — flip it to a real test then.
"""
+
import unittest

import torch
import torch_tensorrt

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-05-12 20:26:34.860665+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-05-12 20:27:00.676156+00:00
@@ -19,10 +19,11 @@
* ``TorchTensorRTModule.forward`` filters aliased outputs from the user
  return tuple.
* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and
  ``BufferThreadingModule`` threads buffers through each call.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-05-12 20:26:34.861069+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-05-12 20:27:00.713992+00:00
@@ -13,10 +13,11 @@

These tests verify both paths end-to-end via the C++ runtime: the
fast path mutates in place, the fallback produces correct numerical
results without aliasing.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-05-12 20:26:34.861069+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-05-12 20:27:00.785133+00:00
@@ -14,10 +14,11 @@
3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for
   aliased I/O) with the discovered bindings.
4. Thread the buffer values in on each call and verify in-place
   mutation works (cache state persists across calls).
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp b/tmp/changes.txt
index a46ad8f..45dbf63 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp
+++ b/tmp/changes.txt
@@ -335,8 +335,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
            std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
      }

-      setup_input_tensors(
-          inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, bound_inputs_by_name);
+      setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, bound_inputs_by_name);
      // Check if input shapes can be inferred.
      int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
      std::vector<char const*> names(io_size);
@@ -494,7 +493,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
    // (validated at engine construction). The bound-inputs map is unused here.
    std::unordered_map<std::string, at::Tensor> bound_inputs_by_name;

-
    { // Input Setup
      std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
      if (compiled_engine->profile_execution) {
ERROR: Some files do not conform to style guidelines

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py	2026-06-10 20:10:37.595181+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py	2026-06-10 20:11:00.737050+00:00
@@ -41,10 +41,11 @@
    Optional[SerializedTensorRTEngineFmt],
    List[str],
    List[str],
]

+
def user_output_count(
    output_binding_names: List[str], aliased_io: Dict[str, Tuple[str, str]]
) -> int:
    """Derive the boundary between user-visible outputs and side-effect
    aliased outputs.
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-06-10 20:10:37.618585+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-06-10 20:11:02.717030+00:00
@@ -15,10 +15,11 @@
This file covers the fallback path. To force the fallback regardless of
shape we add a small no-op (``+ 0``) to the cache so it isn't a direct
network input — the converter's "input is a placeholder" check fails and
falls through to scatter.
"""
+
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-06-10 20:10:37.622167+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-06-10 20:11:03.643555+00:00
@@ -21,10 +21,11 @@
* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer
  placeholders into ``get_attr`` reads and registers the buffers as
  module state. The result is a plain ``fx.GraphModule`` that
  serializes via ``torch_tensorrt.save`` without an external wrapper.
"""
+
import inspect

import torch
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-06-10 20:10:37.624595+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-06-10 20:11:05.387147+00:00
@@ -15,10 +15,11 @@
  ``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile
  transform replaces what used to be an external ``BufferThreadingModule``
  wrapper — making the result a plain ``fx.GraphModule`` that exports
  naturally without a custom wrapper class.
"""
+
import tempfile

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-06-10 20:10:37.624595+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-06-10 20:11:05.513263+00:00
@@ -13,10 +13,11 @@

These tests verify both paths end-to-end via the C++ runtime: the
fast path mutates in place, the fallback produces correct numerical
results without aliasing.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-06-10 20:10:37.624595+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-06-10 20:11:05.532849+00:00
@@ -19,10 +19,11 @@
* ``TorchTensorRTModule.forward`` filters aliased outputs from the user
  return tuple.
* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and
  ``BufferThreadingModule`` threads buffers through each call.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-06-10 20:10:37.624595+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-06-10 20:11:05.670076+00:00
@@ -17,10 +17,11 @@
  is already visible on the user's input).

These tests cover capture + replay correctness for both KV-cache patterns
(user-input and buffer-style).
"""
+
import unittest

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-06-10 20:10:37.624595+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-06-10 20:11:05.709848+00:00
@@ -14,10 +14,11 @@
3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for
   aliased I/O) with the discovered bindings.
4. Thread the buffer values in on each call and verify in-place
   mutation works (cache state persists across calls).
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-06-10 20:10:37.624595+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-06-10 20:11:05.721107+00:00
@@ -36,10 +36,11 @@
  workaround that skips ``run_decompositions`` for already-decomposed EPs.

When the upstream issues are resolved or those features land, this
xfail test should start passing — flip it to a real test then.
"""
+
import unittest

import torch
import torch_tensorrt

@narendasan narendasan force-pushed the narendasan/aliased_io branch from 3afcfd3 to a8fce13 Compare June 10, 2026 20:14

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp b/tmp/changes.txt
index a46ad8f..45dbf63 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp
+++ b/tmp/changes.txt
@@ -335,8 +335,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
            std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
      }

-      setup_input_tensors(
-          inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, bound_inputs_by_name);
+      setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, bound_inputs_by_name);
      // Check if input shapes can be inferred.
      int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
      std::vector<char const*> names(io_size);
@@ -494,7 +493,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
    // (validated at engine construction). The bound-inputs map is unused here.
    std::unordered_map<std::string, at::Tensor> bound_inputs_by_name;

-
    { // Input Setup
      std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
      if (compiled_engine->profile_execution) {
ERROR: Some files do not conform to style guidelines

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py	2026-06-10 20:14:52.123795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py	2026-06-10 20:15:12.601131+00:00
@@ -41,10 +41,11 @@
    Optional[SerializedTensorRTEngineFmt],
    List[str],
    List[str],
]

+
def user_output_count(
    output_binding_names: List[str], aliased_io: Dict[str, Tuple[str, str]]
) -> int:
    """Derive the boundary between user-visible outputs and side-effect
    aliased outputs.
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-06-10 20:14:52.147711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_slice_scatter_aten.py	2026-06-10 20:15:14.572201+00:00
@@ -15,10 +15,11 @@
This file covers the fallback path. To force the fallback regardless of
shape we add a small no-op (``+ 0``) to the cache so it isn't a direct
network input — the converter's "input is a placeholder" check fails and
falls through to scatter.
"""
+
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-06-10 20:14:52.151132+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_buffer_lifting.py	2026-06-10 20:15:15.343322+00:00
@@ -21,10 +21,11 @@
* ``inline_lifted_buffers_into_gm`` rewrites the lifted-buffer
  placeholders into ``get_attr`` reads and registers the buffers as
  module state. The result is a plain ``fx.GraphModule`` that
  serializes via ``torch_tensorrt.save`` without an external wrapper.
"""
+
import inspect

import torch
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-06-10 20:14:52.153711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_cudagraphs.py	2026-06-10 20:15:16.838431+00:00
@@ -17,10 +17,11 @@
  is already visible on the user's input).

These tests cover capture + replay correctness for both KV-cache patterns
(user-input and buffer-style).
"""
+
import unittest

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-06-10 20:14:52.153711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io_serialization.py	2026-06-10 20:15:16.847891+00:00
@@ -15,10 +15,11 @@
  ``torch.export``. The ``inline_lifted_buffers_into_gm`` post-compile
  transform replaces what used to be an external ``BufferThreadingModule``
  wrapper — making the result a plain ``fx.GraphModule`` that exports
  naturally without a custom wrapper class.
"""
+
import tempfile

import torch
import torch_tensorrt
from torch.export import export
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-06-10 20:14:52.153711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_hf_static_cache_xfail.py	2026-06-10 20:15:16.911013+00:00
@@ -36,10 +36,11 @@
  workaround that skips ``run_decompositions`` for already-decomposed EPs.

When the upstream issues are resolved or those features land, this
xfail test should start passing — flip it to a real test then.
"""
+
import unittest

import torch
import torch_tensorrt

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-06-10 20:14:52.153711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_aliased_io.py	2026-06-10 20:15:16.974958+00:00
@@ -19,10 +19,11 @@
* ``TorchTensorRTModule.forward`` filters aliased outputs from the user
  return tuple.
* For buffer-style models, ``lift_mutated_buffers`` rewrites the EP and
  ``BufferThreadingModule`` threads buffers through each call.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-06-10 20:14:52.153711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_index_copy_kv.py	2026-06-10 20:15:17.016423+00:00
@@ -13,10 +13,11 @@

These tests verify both paths end-to-end via the C++ runtime: the
fast path mutates in place, the fallback produces correct numerical
results without aliasing.
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-06-10 20:14:52.153711+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/runtime/test_lift_mutable_buffers_api.py	2026-06-10 20:15:17.363571+00:00
@@ -14,10 +14,11 @@
3. Construct a ``TorchTensorRTModule`` (C++ runtime — required for
   aliased I/O) with the discovered bindings.
4. Thread the buffer values in on each call and verify in-place
   mutation works (cache state persists across calls).
"""
+
import torch
import torch_tensorrt
from torch.export import export
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine

@cehongwang

Copy link
Copy Markdown
Collaborator

Anything we could do to avoid manually resetting the KV cache before every run?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: runtime component: tests Issues re: Tests documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants