Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions py/torch_tensorrt/executorch/partitioner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ExecuTorch partitioner: partition by execute_engine nodes.

import logging
from typing import Callable, Dict, List, Optional, Tuple

import torch
Expand All @@ -12,7 +13,12 @@
from executorch.exir.backend.utils import tag_constant_data
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch_tensorrt.executorch.backend import TensorRTBackend
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import DEVICE_IDX
from torch_tensorrt.executorch.backend import (
TensorRTBackend,
_get_engine_info_from_edge_program,
_parse_device_id,
)
from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport

# Key recognized by ExecuTorch's PropagateDevicePass that tags delegate I/O
Expand All @@ -29,6 +35,8 @@
except ImportError:
_TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"

logger = logging.getLogger(__name__)


class TensorRTPartitioner(Partitioner): # type: ignore[misc]
"""Partitions the graph for TensorRT delegation.
Expand All @@ -42,6 +50,11 @@ class TensorRTPartitioner(Partitioner): # type: ignore[misc]
Callers targeting a non-default GPU should pre-populate
``compile_specs`` with the desired ``CompileSpec("target_device",
b"cuda:<index>")`` to override the default.

Note: ``target_device`` is AOT metadata only -- it drives ExecuTorch's
PropagateDevicePass tagging at export time. At runtime the C++ backend
selects the GPU from the device baked into the serialized engine blob,
not from this value.
"""

def __init__(
Expand All @@ -50,21 +63,45 @@ def __init__(
) -> None:
super().__init__()
self.compile_specs = list(compile_specs) if compile_specs else []
# Mirror CudaPartitioner: emit a target_device CompileSpec so that
# ExecuTorch's PropagateDevicePass tags delegate I/O TensorSpecs with
# the correct device, which is then serialized into the .pte's
# extra_tensor_info.device_type field.
if not any(
# Mirror CudaPartitioner: a target_device CompileSpec drives ExecuTorch's
# PropagateDevicePass, which tags delegate I/O TensorSpecs with the device
# and serializes it into the .pte's extra_tensor_info. When the caller pins
# it we use that verbatim; otherwise it is derived per export from the
# engine's real device in partition() (engine nodes are not available here)
# so a cuda:N engine is not mislabeled cuda:0.
self._has_explicit_target_device = any(
s.key == _TARGET_DEVICE_COMPILE_SPEC_KEY for s in self.compile_specs
):
self.compile_specs.append(
CompileSpec(_TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0")
)
)
self.delegation_spec = DelegationSpec(
backend_id=TensorRTBackend.__name__,
compile_specs=self.compile_specs,
)

def _resolve_target_device(self, exported_program: ExportedProgram) -> bytes:
    """Derive the ``target_device`` label from the program's TensorRT engine.

    The engine info comes from the backend's own
    ``_get_engine_info_from_edge_program`` helper, and its ``DEVICE_IDX``
    entry is converted by ``_parse_device_id``, so the label is computed
    from the same data the backend reads.

    Args:
        exported_program: Program whose engine device should be reported.

    Returns:
        ``b"cuda:<index>"`` for the extracted device index. If any step of
        the lookup raises -- e.g. there is no single engine node (zero or
        multiple TRT partitions) or the index is unreadable -- a warning is
        logged and ``b"cuda:0"`` is returned instead. Per-partition
        multi-GPU labeling is left to a follow-up.
    """
    try:
        info = _get_engine_info_from_edge_program(exported_program)
        device_index = _parse_device_id(info[DEVICE_IDX])
        label = f"cuda:{device_index}".encode()
    except Exception as err:
        # Deliberately broad: a failed lookup must degrade to the default
        # label rather than abort the export. The warning keeps a
        # non-default GPU that ends up labeled cuda:0 diagnosable.
        logger.warning(
            "Could not derive target_device from the TensorRT engine (%s); "
            "falling back to cuda:0. A non-default GPU engine may be "
            'mislabeled -- pin it via CompileSpec("target_device", '
            'b"cuda:<index>").',
            err,
        )
        return b"cuda:0"
    return label

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
Expand All @@ -73,12 +110,26 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
)
partition_list = capability_partitioner.propose_partitions()

if self._has_explicit_target_device:
delegation_spec = self.delegation_spec
else:
delegation_spec = DelegationSpec(
backend_id=TensorRTBackend.__name__,
compile_specs=self.compile_specs
+ [
CompileSpec(
_TARGET_DEVICE_COMPILE_SPEC_KEY,
self._resolve_target_device(exported_program),
)
],
)

partition_tags: Dict[str, DelegationSpec] = {}
for partition in partition_list:
tag = f"tensorrt_{partition.id}"
for node in partition.nodes:
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec
partition_tags[tag] = delegation_spec

tag_constant_data(exported_program)

Expand Down
118 changes: 118 additions & 0 deletions tests/py/dynamo/executorch/test_partitioner_target_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from types import SimpleNamespace

import pytest

executorch = pytest.importorskip("executorch.exir")

import torch # noqa: E402
from executorch.exir.backend.compile_spec_schema import CompileSpec # noqa: E402
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: E402
DEVICE_IDX,
ENGINE_IDX,
SERIALIZATION_LEN,
)
from torch_tensorrt.executorch.partitioner import ( # noqa: E402
_TARGET_DEVICE_COMPILE_SPEC_KEY,
TensorRTPartitioner,
)


# A realistic single-engine edge program so partition() runs the *real*
# _get_engine_info_from_edge_program / _parse_device_id path. That is what
# guards "an engine node is present and its device is extractable at partition()
# time" -- a monkeypatched extractor would not. Mirrors the mocked edge programs
# in tests/py/dynamo/executorch/test_backend.py.
class _SchemaTarget:
def __init__(self, name):
self._schema = SimpleNamespace(name=name)


def _engine_node(device_id):
    """Build a mock ``call_function`` node carrying TensorRT engine info.

    The node's ``args`` are ``["x"]`` followed by ``SERIALIZATION_LEN``
    slots. Every slot is an empty string except ``ENGINE_IDX`` (a small
    uint8 tensor) and ``DEVICE_IDX`` (the given ``device_id``).
    """
    slots = ["" for _ in range(SERIALIZATION_LEN)]
    slots[ENGINE_IDX] = torch.frombuffer(bytearray(b"engine"), dtype=torch.uint8)
    slots[DEVICE_IDX] = device_id

    placeholder = _SchemaTarget("tensorrt::no_op_placeholder_for_execute_engine")
    node_args = (["x"], *slots)
    return SimpleNamespace(
        op="call_function",
        target=placeholder,
        args=node_args,
        name="trt_node",
    )


def _edge_program(*nodes):
return SimpleNamespace(
graph_module=SimpleNamespace(graph=SimpleNamespace(nodes=list(nodes))),
constants={},
)


class _FakeCapabilityPartitioner:
def __init__(self, *args, **kwargs):
pass

def propose_partitions(self):
return [SimpleNamespace(id=1, nodes=[SimpleNamespace(meta={})])]


@pytest.fixture(autouse=True)
def _stub_partition_internals(monkeypatch):
    """Swap out the partitioner module's graph-dependent collaborators.

    ``CapabilityBasedPartitioner`` and ``tag_constant_data`` both need a
    real fx GraphModule, so each is replaced with a stand-in. The
    engine-info extraction under test is left untouched and still runs for
    real against the mocked node.
    """
    replacements = (
        (
            "torch_tensorrt.executorch.partitioner.CapabilityBasedPartitioner",
            _FakeCapabilityPartitioner,
        ),
        (
            "torch_tensorrt.executorch.partitioner.tag_constant_data",
            lambda exported_program: None,
        ),
    )
    for dotted_name, stand_in in replacements:
        monkeypatch.setattr(dotted_name, stand_in)


def _target_device(result):
    """Return the target-device compile-spec value for tag ``tensorrt_1``.

    Looks up ``result.partition_tags["tensorrt_1"]`` and returns the value
    of its first compile spec keyed by ``_TARGET_DEVICE_COMPILE_SPEC_KEY``,
    or ``None`` when no such spec exists.
    """
    compile_specs = result.partition_tags["tensorrt_1"].compile_specs
    matching_values = (
        entry.value
        for entry in compile_specs
        if entry.key == _TARGET_DEVICE_COMPILE_SPEC_KEY
    )
    return next(matching_values, None)


@pytest.mark.unit
def test_target_device_derived_for_default_gpu():
    """A single engine node with device "0" yields ``b"cuda:0"``."""
    partitioner = TensorRTPartitioner()
    program = _edge_program(_engine_node("0"))
    outcome = partitioner.partition(program)
    assert _target_device(outcome) == b"cuda:0"


@pytest.mark.unit
def test_target_device_derived_for_nonzero_gpu():
    """A single engine node with device "1" yields ``b"cuda:1"``."""
    # Regression guard for the bug being fixed: a cuda:1 engine must not be
    # labeled cuda:0.
    partitioner = TensorRTPartitioner()
    program = _edge_program(_engine_node("1"))
    outcome = partitioner.partition(program)
    assert _target_device(outcome) == b"cuda:1"


@pytest.mark.unit
def test_target_device_falls_back_to_cuda0_on_multiple_engines():
    """Two engine nodes in one program fall back to ``b"cuda:0"``."""
    # More than one engine node is expected to make the real extraction
    # raise, which must trigger the documented cuda:0 fallback.
    partitioner = TensorRTPartitioner()
    first_engine = _engine_node("1")
    second_engine = _engine_node("2")
    outcome = partitioner.partition(_edge_program(first_engine, second_engine))
    assert _target_device(outcome) == b"cuda:0"


@pytest.mark.unit
def test_target_device_falls_back_to_cuda0_on_malformed_graph():
    """A node with an unexpected shape falls back to ``b"cuda:0"``."""
    # The node's target has no attributes, so the real extraction is
    # expected to raise. The broad except must turn that into the cuda:0
    # fallback instead of aborting the export.
    malformed = SimpleNamespace(
        op="call_function",
        target=SimpleNamespace(),
        name="x",
    )
    partitioner = TensorRTPartitioner()
    outcome = partitioner.partition(_edge_program(malformed))
    assert _target_device(outcome) == b"cuda:0"


@pytest.mark.unit
def test_explicit_target_device_used_verbatim():
    """A caller-pinned ``target_device`` spec wins over the engine's device."""
    # The engine reports device "0" but the caller pinned cuda:3. The pin
    # must be returned unchanged; with an explicit spec the partitioner
    # does not run the engine-device extraction at all.
    pinned = CompileSpec(_TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:3")
    partitioner = TensorRTPartitioner(compile_specs=[pinned])
    program = _edge_program(_engine_node("0"))
    outcome = partitioner.partition(program)
    assert _target_device(outcome) == b"cuda:3"
Loading