From a996b5451bbfae986715c7b1a5a45be3c373c7e6 Mon Sep 17 00:00:00 2001 From: shoumikhin Date: Wed, 10 Jun 2026 16:45:24 -0700 Subject: [PATCH] Add TensorRT weight streaming support to the ExecuTorch delegate Apply a TensorRT weight streaming budget in the delegate init(), after the engine is deserialized and before the execution context is created (the budget cannot be changed once a context exists). When the engine was built with weight streaming, the delegate applies TensorRT's automatic budget by default, mirroring the PyTorch runtimes. An explicit budget can be set two ways, in order of precedence: a load-time ExecuTorch backend option ("weight_streaming_budget" runtime spec passed to Module::load), or the same key baked into the .pte at export via torch_tensorrt.save(output_format="executorch", weight_streaming_budget=N). The load-time option lets a deployment size the budget for its own GPU without re-exporting; the export-time value is the default and the only channel for loaders that cannot pass backend options yet (Python, Android). The value is a non-negative decimal byte count, string-encoded on the wire. Non-streamable engines make no budget call, so existing programs are unchanged. The one intended behavior change is that a streamable engine now applies the automatic budget on load, which enables running models whose weights exceed GPU memory. Ref #4334 --- cpp/BUILD | 14 ++ .../executorch/WeightStreamingBudget.h | 34 +++ .../torch_tensorrt/executorch/CMakeLists.txt | 1 + .../executorch/TensorRTBackend.cpp | 127 ++++++++++- .../executorch/WeightStreamingBudget.cpp | 35 +++ py/torch_tensorrt/_compile.py | 163 +++++++++++--- py/torch_tensorrt/executorch/partitioner.py | 11 +- tests/cpp/BUILD | 1 + tests/cpp/executorch/BUILD | 10 + ...est_executorch_weight_streaming_budget.cpp | 109 +++++++++ .../test_weight_streaming_budget.py | 206 ++++++++++++++++++ 11 files changed, 679 insertions(+), 32 deletions(-) create mode 100644 cpp/include/torch_tensorrt/executorch/WeightStreamingBudget.h create mode 100644 cpp/src/torch_tensorrt/executorch/WeightStreamingBudget.cpp create mode 100644 tests/cpp/executorch/test_executorch_weight_streaming_budget.cpp create mode 100644 tests/py/dynamo/executorch/test_weight_streaming_budget.py diff --git a/cpp/BUILD b/cpp/BUILD index 23509b9164..8d74717444 100644 --- a/cpp/BUILD +++ b/cpp/BUILD @@ -96,6 +96,17 @@ cc_library( strip_include_prefix = "include", ) +cc_library( + name = "tensorrt_executorch_weight_streaming_budget", + srcs = [ + "src/torch_tensorrt/executorch/WeightStreamingBudget.cpp", + ], + hdrs = [ + "include/torch_tensorrt/executorch/WeightStreamingBudget.h", + ], + strip_include_prefix = "include", +) + cc_library( name = "tensorrt_executorch_binding_names", hdrs = [ @@ -127,6 +138,7 @@ cc_library( deps = [ ":tensorrt_executorch_binding_names", ":tensorrt_executorch_blob_header", + ":tensorrt_executorch_weight_streaming_budget", ] + select({ ":linux_x86_64": [ "@executorch//:executorch_headers", @@ -150,6 +162,7 @@ filegroup( "src/torch_tensorrt/executorch/README.md", "src/torch_tensorrt/executorch/TensorRTBackend.cpp", "src/torch_tensorrt/executorch/TensorRTBlobHeader.cpp", + "src/torch_tensorrt/executorch/WeightStreamingBudget.cpp", ], ) @@ -169,5 +182,6 @@ filegroup( "include/torch_tensorrt/executorch/TensorRTBackend.h", "include/torch_tensorrt/executorch/TensorRTBindingNames.h", "include/torch_tensorrt/executorch/TensorRTBlobHeader.h", + "include/torch_tensorrt/executorch/WeightStreamingBudget.h", ], ) diff --git a/cpp/include/torch_tensorrt/executorch/WeightStreamingBudget.h b/cpp/include/torch_tensorrt/executorch/WeightStreamingBudget.h new file mode 100644 index 0000000000..8fc457542b --- /dev/null +++ b/cpp/include/torch_tensorrt/executorch/WeightStreamingBudget.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +namespace torch_tensorrt { +namespace executorch_backend { + +// Compile spec key that carries the weight streaming budget from export into the +// delegate. Must match WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY on the Python +// side (py/torch_tensorrt/executorch/partitioner.py). +inline constexpr char kWeightStreamingBudgetKey[] = "weight_streaming_budget"; + +// Result of parsing a weight streaming budget compile spec value. +// +// The value is a non-negative decimal integer: an explicit GPU budget in bytes. +// The automatic budget is intentionally not encoded here: when no budget spec is +// present, the delegate applies TensorRT's automatic budget itself, mirroring the +// PyTorch runtimes. +// +// valid == false -> a value was present but could not be parsed (or was +// negative); the caller should reject the program. +// valid == true -> bytes holds the parsed non-negative byte budget. +struct WsBudget { + bool valid = false; + int64_t bytes = 0; +}; + +// Parses a budget from a raw byte range. The value is NOT assumed to be NUL +// terminated; only the first nbytes bytes are read. +WsBudget parse_weight_streaming_budget(const void* value, std::size_t nbytes); + +} // namespace executorch_backend +} // namespace torch_tensorrt diff --git a/cpp/src/torch_tensorrt/executorch/CMakeLists.txt b/cpp/src/torch_tensorrt/executorch/CMakeLists.txt index 1b0546b545..09f71b31ca 100644 --- a/cpp/src/torch_tensorrt/executorch/CMakeLists.txt +++ b/cpp/src/torch_tensorrt/executorch/CMakeLists.txt @@ -23,6 +23,7 @@ find_package(Threads REQUIRED) set(_torchtrt_executorch_sources "${CMAKE_CURRENT_LIST_DIR}/TensorRTBackend.cpp" "${CMAKE_CURRENT_LIST_DIR}/TensorRTBlobHeader.cpp" + "${CMAKE_CURRENT_LIST_DIR}/WeightStreamingBudget.cpp" ) add_library(executorch_trt_backend STATIC ${_torchtrt_executorch_sources}) diff --git a/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp b/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp index b2e3b08232..8c914f6385 100644 --- a/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp +++ b/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp @@ -8,6 +8,7 @@ #include "torch_tensorrt/executorch/TensorRTBackend.h" #include "torch_tensorrt/executorch/TensorRTBindingNames.h" #include "torch_tensorrt/executorch/TensorRTBlobHeader.h" +#include "torch_tensorrt/executorch/WeightStreamingBudget.h" #include #include @@ -225,8 +226,6 @@ Result TensorRTBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef compile_specs) const { - (void)compile_specs; - TORCHTRT_ET_CHECK_NOT_NULL(processed, Error::InvalidArgument, "TensorRTBackend::init: null processed buffer"); TORCHTRT_ET_CHECK_NOT_NULL(processed->data(), Error::InvalidArgument, "TensorRTBackend::init: null processed buffer"); @@ -284,6 +283,130 @@ Result TensorRTBackend::init( TORCHTRT_ET_CHECK_NOT_NULL( handle->engine, Error::InvalidProgram, "TensorRTBackend::init: failed to deserialize TensorRT engine"); + // Apply the weight streaming budget before the execution context is created + // below: TensorRT forbids changing the budget while a context is active. The + // budget is a non-negative decimal byte count and may come from two places, in + // order of precedence: + // 1. A load-time backend option ("weight_streaming_budget" runtime spec) that + // the caller passes to Module::load(LoadBackendOptionsMap). This lets a + // deployment size the budget for its own GPU without re-exporting. + // 2. The same key baked into the .pte as a compile spec at export, used as a + // default when no load-time option is given (and the only channel for + // loaders that cannot pass backend options yet, e.g. Python/Android). + // When neither is present and the engine supports streaming, we apply + // TensorRT's automatic budget, mirroring what the PyTorch runtimes do on + // deserialize. Negative or malformed values are rejected as InvalidProgram. + WsBudget ws_request; + bool is_explicit = false; + + // (1) A load-time runtime spec takes precedence over the baked compile spec. + // The value is a decimal byte string; a non-negative int is also accepted for + // small budgets. A present-but-wrong-type or empty option is handled explicitly + // so a runtime option is never silently dropped. The const char* returned by + // get_runtime_spec points into the caller's LoadBackendOptionsMap storage, which + // outlives init(); we parse it immediately and keep only the int64 result. + const auto ws_runtime = context.get_runtime_spec(kWeightStreamingBudgetKey); + if (ws_runtime.ok()) { + const char* const value = ws_runtime.get(); + // The option array need not be NUL terminated (the struct is public), so + // bound the scan. An empty value means "unset", so fall through to (2). + constexpr std::size_t kRuntimeBudgetMaxScan = 256; + std::size_t len = 0; + if (value != nullptr) { + while (len < kRuntimeBudgetMaxScan && value[len] != '\0') { + ++len; + } + } + if (len > 0) { + ws_request = parse_weight_streaming_budget(value, len); + if (!ws_request.valid) { + ET_LOG(Error, "TensorRTBackend::init: malformed weight_streaming_budget runtime option"); + return Error::InvalidProgram; + } + is_explicit = true; + } + } else if (ws_runtime.error() != Error::NotFound) { + // The key is present but stored as a non-string type. Accept a non-negative + // int for convenience (its 32-bit range only covers budgets under 2 GB); + // otherwise reject it so a wrong-typed option is never silently ignored. + const auto ws_runtime_int = context.get_runtime_spec(kWeightStreamingBudgetKey); + if (ws_runtime_int.ok() && ws_runtime_int.get() >= 0) { + ws_request.valid = true; + ws_request.bytes = ws_runtime_int.get(); + is_explicit = true; + } else { + ET_LOG( + Error, + "TensorRTBackend::init: weight_streaming_budget runtime option must be a " + "non-negative int or a decimal byte string"); + return Error::InvalidProgram; + } + } + + // (2) Otherwise fall back to the compile spec baked into the .pte at export. + if (!is_explicit) { + const CompileSpec* ws_spec = nullptr; + for (const auto& spec : compile_specs) { + if (spec.key != nullptr && std::strcmp(spec.key, kWeightStreamingBudgetKey) == 0) { + if (ws_spec != nullptr) { + // The budget must appear at most once; a second match means the spec + // list is inconsistent, so reject the program instead of guessing. + ET_LOG(Error, "TensorRTBackend::init: duplicate weight_streaming_budget compile spec"); + return Error::InvalidProgram; + } + ws_spec = &spec; + } + } + if (ws_spec != nullptr) { + ws_request = parse_weight_streaming_budget(ws_spec->value.buffer, ws_spec->value.nbytes); + if (!ws_request.valid) { + ET_LOG(Error, "TensorRTBackend::init: malformed weight_streaming_budget compile spec"); + return Error::InvalidProgram; + } + is_explicit = true; + } + } + + const int64_t streamable = handle->engine->getStreamableWeightsSize(); + if (streamable > 0) { + // getStreamableWeightsSize is > 0 only when the engine was built with + // BuilderFlag::kWEIGHT_STREAMING. + int64_t budget; + if (is_explicit) { + // An explicit budget is a non-negative byte count, clamped to the + // streamable size (TensorRT also caps it, but clamp for a clear log). + budget = ws_request.bytes > streamable ? streamable : ws_request.bytes; + } else { + budget = handle->engine->getWeightStreamingAutomaticBudget(); + } + if (!handle->engine->setWeightStreamingBudgetV2(budget)) { + if (!is_explicit && handle->engine->setWeightStreamingBudgetV2(0)) { + // The automatic budget could not be applied; fall back to budget 0, which + // streams all weights (minimum resident memory) and always fits. + ET_LOG(Info, "TensorRTBackend::init: automatic weight streaming budget failed; falling back to budget 0 (stream all weights)"); + } else { + ET_LOG( + Error, + "TensorRTBackend::init: setWeightStreamingBudgetV2 failed (requested=%lld%s)", + (long long)budget, + is_explicit ? "" : ", and fallback to 0 also failed"); + return Error::InvalidProgram; + } + } + ET_LOG( + Info, + "TensorRTBackend::init: weight streaming budget=%lld streamable=%lld scratch=%lld", + (long long)handle->engine->getWeightStreamingBudgetV2(), + (long long)streamable, + (long long)handle->engine->getWeightStreamingScratchMemorySize()); + } else if (is_explicit) { + // A budget was requested but the engine has no streamable weights (it was not + // built with enable_weight_streaming=True, or nothing is streamable). The + // engine is still valid and runs fully resident, so log and continue rather + // than fail; failing here would break mixed multi-engine programs. + ET_LOG(Info, "TensorRTBackend::init: weight_streaming_budget ignored; engine does not support weight streaming (build with enable_weight_streaming=True)"); + } + Error err = initialize_engine_io(*handle); if (err != Error::Ok) { return err; diff --git a/cpp/src/torch_tensorrt/executorch/WeightStreamingBudget.cpp b/cpp/src/torch_tensorrt/executorch/WeightStreamingBudget.cpp new file mode 100644 index 0000000000..3bf9ef405f --- /dev/null +++ b/cpp/src/torch_tensorrt/executorch/WeightStreamingBudget.cpp @@ -0,0 +1,35 @@ +#include "torch_tensorrt/executorch/WeightStreamingBudget.h" + +#include +#include +#include +#include + +namespace torch_tensorrt { +namespace executorch_backend { + +WsBudget parse_weight_streaming_budget(const void* value, std::size_t nbytes) { + WsBudget result; // valid == false until the value is fully parsed + + if (value == nullptr || nbytes == 0) { + return result; + } + // The value is a non-negative decimal byte budget and is not NUL terminated. + // std::from_chars consumes only ASCII digits (no leading whitespace, sign, or + // base prefix), so a leftover byte (trailing garbage or an embedded NUL), an + // out-of-range value, or a negative leaves the result invalid. + const char* const first = static_cast(value); + const char* const last = first + nbytes; + int64_t parsed = 0; + const std::from_chars_result fc = std::from_chars(first, last, parsed); + if (fc.ec != std::errc() || fc.ptr != last || parsed < 0) { + return result; + } + + result.valid = true; + result.bytes = parsed; + return result; +} + +} // namespace executorch_backend +} // namespace torch_tensorrt diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index f052954efb..3d09b06a06 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -30,32 +30,26 @@ if ENABLED_FEATURES.torchscript_frontend: import torch_tensorrt.ts - from torch_tensorrt.ts._compiler import compile as torchscript_compile from torch_tensorrt.ts._compiler import ( + compile as torchscript_compile, convert_method_to_trt_engine as ts_convert_method_to_trt_engine, ) if ENABLED_FEATURES.dynamo_frontend: from torch.export import ExportedProgram - from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._compiler import ( + compile as dynamo_compile, convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine, - ) - from torch_tensorrt.dynamo._compiler import ( cross_compile_for_windows as dynamo_cross_compile_for_windows, - ) - from torch_tensorrt.dynamo._compiler import ( load_cross_compiled_exported_program as dynamo_load_cross_compiled_exported_program, - ) - from torch_tensorrt.dynamo._compiler import ( save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program, ) from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._tracer import ( get_dynamic_shapes_args, get_dynamic_shapes_kwargs, + trace as dynamo_trace, ) - from torch_tensorrt.dynamo._tracer import trace as dynamo_trace from torch_tensorrt.dynamo.utils import get_torch_inputs logger = logging.getLogger(__name__) @@ -728,11 +722,26 @@ def save( - If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes parameter takes precedence. - kwargs: Additional format-specific kwargs. ``partitioners=`` and - ``compile_specs=`` are only used with ``output_format="executorch"``; - otherwise they are ignored with a warning. Pass + kwargs: Additional format-specific kwargs. ``partitioners=``, + ``compile_specs=``, and ``weight_streaming_budget=`` are only used + with ``output_format="executorch"``; otherwise they are ignored + with a warning. Pass ``compile_specs=[CompileSpec("target_device", b"cuda:")]`` to override the default target device (``cuda:0``). + ``weight_streaming_budget`` controls the TensorRT weight streaming + budget for the delegate: ``None`` (the default) lets the delegate + pick an automatic budget at load, and a non-negative integer sets + an explicit GPU budget in bytes. For a program with multiple + TensorRT engines an explicit byte budget applies to each engine + independently, so prefer ``None`` there. Only an engine built with + ``enable_weight_streaming=True`` honors the budget at runtime. This + value is an export-time default; it can be overridden at load with a + ``weight_streaming_budget`` backend option so the same ``.pte`` can + be sized for different GPUs without re-exporting. That load-time + override currently works only from runtimes that pass ExecuTorch + backend options (the C++ ``Module`` API and iOS); the ExecuTorch + Python and Android loaders do not expose it yet and use this + export-time value. """ if isinstance(module, CudaGraphsTorchTensorRTModule): module = module.compiled_module @@ -759,6 +768,7 @@ def save( executorch_partitioners = kwargs.pop("partitioners", None) executorch_compile_specs = kwargs.pop("compile_specs", None) + executorch_weight_streaming_budget = kwargs.pop("weight_streaming_budget", None) if output_format not in accepted_formats: raise ValueError( @@ -770,6 +780,16 @@ def save( "with executorch.exir. Install with: pip install " "\"executorch\" to use output_format='executorch'." ) + # Recognized executorch options are popped above; any other leftover kwarg + # for output_format='executorch' is a typo (the executorch path forwards no + # other kwargs), so fail loudly instead of silently ignoring it. + if output_format == "executorch" and kwargs: + raise TypeError( + "save() received unexpected keyword argument(s) for " + f"output_format='executorch': {sorted(kwargs)}. Supported executorch " + "options are 'partitioners', 'compile_specs', and " + "'weight_streaming_budget'." + ) def _all_are_input_objects(obj: Any) -> bool: """Recursively check if all elements in nested collections are Input objects.""" @@ -880,6 +900,16 @@ def _extract_tensor(obj: Any) -> Any: "compile_specs= is only used with output_format='executorch' and will " f"be ignored for output_format='{output_format}'." ) + if executorch_weight_streaming_budget is not None and output_format != "executorch": + logger.warning( + "weight_streaming_budget= is only used with output_format='executorch' " + f"and will be ignored for output_format='{output_format}'." + ) + if executorch_weight_streaming_budget is not None and output_format == "executorch": + # Validate eagerly so an invalid budget (non-int or out-of-range) fails + # fast, before any expensive export. It is normalized again later when + # building the compile_specs. + _normalize_weight_streaming_budget(executorch_weight_streaming_budget) if output_format == "aot_inductor" and platform.system() != "Linux": raise ValueError( f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format" @@ -948,6 +978,7 @@ def _extract_tensor(obj: Any) -> Any: file_path, partitioners=executorch_partitioners, compile_specs=executorch_compile_specs, + weight_streaming_budget=executorch_weight_streaming_budget, ) else: raise RuntimeError( @@ -1013,6 +1044,7 @@ def _extract_tensor(obj: Any) -> Any: file_path, partitioners=executorch_partitioners, compile_specs=executorch_compile_specs, + weight_streaming_budget=executorch_weight_streaming_budget, ) else: raise RuntimeError( @@ -1099,6 +1131,7 @@ def _extract_tensor(obj: Any) -> Any: file_path, partitioners=executorch_partitioners, compile_specs=executorch_compile_specs, + weight_streaming_budget=executorch_weight_streaming_budget, ) else: raise RuntimeError( @@ -1233,9 +1266,7 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: # Use FX's unique-attr-name helper so re-export passes (which may # invoke this rewriter multiple times on the same `gm`) don't # silently overwrite earlier engine buffers. - from torch.fx.experimental.const_fold import ( - get_unique_attr_name_in_module, - ) + from torch.fx.experimental.const_fold import get_unique_attr_name_in_module buffer_name = get_unique_attr_name_in_module(gm, "_trt_engine_0") gm.register_buffer(buffer_name, engine_tensor, persistent=True) @@ -1296,6 +1327,80 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: return exp_program +WEIGHT_STREAMING_BUDGET_MAX_BYTES = 2**63 + + +def _normalize_weight_streaming_budget( + weight_streaming_budget: Any, +) -> Optional[bytes]: + """Validate the budget and encode it as a compile-spec value. + + ``None`` (the default) means automatic: the delegate picks the budget at load. + A non-negative integer is an explicit GPU budget in bytes. Returns the ASCII + bytes to store in the CompileSpec, or ``None`` when no budget was supplied. + """ + if weight_streaming_budget is None: + return None + # bool is an int subclass, so reject it explicitly along with non-ints. + if isinstance(weight_streaming_budget, bool) or not isinstance( + weight_streaming_budget, int + ): + raise TypeError( + "weight_streaming_budget must be a non-negative int (number of bytes) " + f"or None for automatic, got {type(weight_streaming_budget).__name__}." + ) + if ( + weight_streaming_budget < 0 + or weight_streaming_budget >= WEIGHT_STREAMING_BUDGET_MAX_BYTES + ): + raise ValueError( + "weight_streaming_budget must be in [0, 2**63), got " + f"{weight_streaming_budget}." + ) + return str(weight_streaming_budget).encode("ascii") + + +def _resolve_executorch_compile_specs( + exp_program: Any, + caller_compile_specs: Sequence[Any], + weight_streaming_budget: Any, +) -> List[Any]: + """Resolve the compile_specs passed to TensorRTPartitioner. + + Appends the explicit weight streaming budget (from the save() kwarg) as a + CompileSpec. When no budget is given nothing is added and the delegate applies + TensorRT's automatic budget itself for engines built with weight streaming. + Caller-provided compile_specs are forwarded unchanged. + """ + from executorch.exir.backend.compile_spec_schema import CompileSpec + from torch_tensorrt.executorch.partitioner import ( + WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY, + ) + + specs = list(caller_compile_specs) + if any( + getattr(spec, "key", None) == WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY + for spec in specs + ): + raise ValueError( + f"Do not pass a CompileSpec('{WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY}', " + "...) in compile_specs; use the weight_streaming_budget argument of " + "save() instead." + ) + spec_value = _normalize_weight_streaming_budget(weight_streaming_budget) + if spec_value is None: + return specs + + if _count_executorch_engine_nodes(exp_program) > 1: + logger.warning( + "weight_streaming_budget is an explicit byte budget but the program " + "contains multiple TensorRT engines; it is applied to each engine " + "independently. Pass None to size each engine's budget automatically." + ) + specs.append(CompileSpec(WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY, spec_value)) + return specs + + def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None: """Save an ExportedProgram (with TensorRT execute_engine nodes) as an ExecuTorch .pte file. @@ -1318,16 +1423,12 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None "\"executorch\" to use output_format='executorch'." ) import torch_tensorrt.dynamo.runtime.meta_ops.register_meta_ops # noqa: F401 - from torch_tensorrt.executorch import ( - TensorRTPartitioner, - get_edge_compile_config, - ) + from torch_tensorrt.executorch import get_edge_compile_config, TensorRTPartitioner extra_partitioners = kwargs.get("partitioners") or [] if not isinstance(extra_partitioners, (list, tuple)): raise TypeError( - "partitioners must be a list or tuple when using " - "output_format='executorch'" + "partitioners must be a list or tuple when using output_format='executorch'" ) # Forward any caller-provided compile_specs to TensorRTPartitioner so users # can override the default target_device ("cuda:0") by passing e.g. @@ -1339,9 +1440,17 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None "compile_specs must be a list or tuple when using " "output_format='executorch'" ) - partitioners = [ - TensorRTPartitioner(compile_specs=list(executorch_compile_specs)) - ] + list(extra_partitioners) + # Resolve the weight streaming budget into the compile_specs from the save() + # kwarg. If none is given nothing is added and the delegate applies TensorRT's + # automatic budget itself for engines built with weight streaming. + resolved_compile_specs = _resolve_executorch_compile_specs( + exp_program, + list(executorch_compile_specs), + kwargs.get("weight_streaming_budget"), + ) + partitioners = [TensorRTPartitioner(compile_specs=resolved_compile_specs)] + list( + extra_partitioners + ) engine_count = _count_executorch_engine_nodes(exp_program) if engine_count > 1: @@ -1377,14 +1486,10 @@ def _normalize_engine_constants_to_python(exp_program: "ExportedProgram") -> Non import base64 from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ENGINE_IDX - from torch_tensorrt.dynamo.runtime._TRTEngine import ( - EngineSerializer, - TRTEngine, - ) + from torch_tensorrt.dynamo.runtime._TRTEngine import EngineSerializer, TRTEngine for fqn, constant in list(exp_program.constants.items()): if isinstance(constant, (torch._C.ScriptObject, TRTEngine)): - state = constant.__getstate__() if len(state) == 2 and ( state[1] == "TRTEngine" diff --git a/py/torch_tensorrt/executorch/partitioner.py b/py/torch_tensorrt/executorch/partitioner.py index c7637a90b7..bba89c5538 100644 --- a/py/torch_tensorrt/executorch/partitioner.py +++ b/py/torch_tensorrt/executorch/partitioner.py @@ -15,9 +15,9 @@ from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import DEVICE_IDX from torch_tensorrt.executorch.backend import ( - TensorRTBackend, _get_engine_info_from_edge_program, _parse_device_id, + TensorRTBackend, ) from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport @@ -35,6 +35,15 @@ except ImportError: _TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device" +# Compile spec key that carries the TensorRT weight streaming budget into the +# delegate. Must match kWeightStreamingBudgetKey on the C++ side +# (cpp/include/torch_tensorrt/executorch/WeightStreamingBudget.h). The value is a +# non-negative decimal integer of bytes (ASCII). The key is absent for the +# automatic budget, which the delegate applies itself for streamable engines. +# The delegate also reads this same key as a load-time backend option (runtime +# spec), which takes precedence over this baked value when provided at load. +WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY = "weight_streaming_budget" + logger = logging.getLogger(__name__) diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD index aa42a07058..e2af4b8c9c 100644 --- a/tests/cpp/BUILD +++ b/tests/cpp/BUILD @@ -66,6 +66,7 @@ test_suite( tests = [ "//tests/cpp/executorch:test_executorch_binding_names", "//tests/cpp/executorch:test_executorch_blob_header", + "//tests/cpp/executorch:test_executorch_weight_streaming_budget", ], ) diff --git a/tests/cpp/executorch/BUILD b/tests/cpp/executorch/BUILD index aee13cbcfd..5d1b664659 100644 --- a/tests/cpp/executorch/BUILD +++ b/tests/cpp/executorch/BUILD @@ -7,6 +7,7 @@ test_suite( tests = [ ":test_executorch_binding_names", ":test_executorch_blob_header", + ":test_executorch_weight_streaming_budget", ], ) @@ -27,3 +28,12 @@ cc_test( "@googletest//:gtest_main", ], ) + +cc_test( + name = "test_executorch_weight_streaming_budget", + srcs = ["test_executorch_weight_streaming_budget.cpp"], + deps = [ + "//cpp:tensorrt_executorch_weight_streaming_budget", + "@googletest//:gtest_main", + ], +) diff --git a/tests/cpp/executorch/test_executorch_weight_streaming_budget.cpp b/tests/cpp/executorch/test_executorch_weight_streaming_budget.cpp new file mode 100644 index 0000000000..08edc3ec05 --- /dev/null +++ b/tests/cpp/executorch/test_executorch_weight_streaming_budget.cpp @@ -0,0 +1,109 @@ +#include "torch_tensorrt/executorch/WeightStreamingBudget.h" + +#include "gtest/gtest.h" + +#include +#include +#include + +namespace torch_tensorrt { +namespace executorch_backend { +namespace { + +// The compile spec value is a raw byte range; parse a C string literal by length. +WsBudget parse(const char* s) { + return parse_weight_streaming_budget(s, std::strlen(s)); +} + +TEST(ExecuTorchWeightStreamingBudget, ParsesZeroBytes) { + const WsBudget budget = parse("0"); + EXPECT_TRUE(budget.valid); + EXPECT_EQ(budget.bytes, 0); +} + +TEST(ExecuTorchWeightStreamingBudget, ParsesLargeByteCount) { + const WsBudget budget = parse("8589934592"); + EXPECT_TRUE(budget.valid); + EXPECT_EQ(budget.bytes, 8589934592LL); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsNegative) { + // The budget is a non-negative byte count; negatives are rejected. + EXPECT_FALSE(parse("-1").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, ParsesInt64Max) { + const WsBudget budget = parse("9223372036854775807"); + EXPECT_TRUE(budget.valid); + EXPECT_EQ(budget.bytes, std::numeric_limits::max()); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsLargeNegative) { + EXPECT_FALSE(parse("-9223372036854775808").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsOverflow) { + // One past the int64_t maximum, so from_chars reports result_out_of_range. + EXPECT_FALSE(parse("9223372036854775808").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsEmpty) { + EXPECT_FALSE(parse("").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsGarbage) { + EXPECT_FALSE(parse("garbage").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsTrailingNonNumeric) { + EXPECT_FALSE(parse("12x").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsInternalWhitespace) { + EXPECT_FALSE(parse("12 34").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsTrailingWhitespace) { + EXPECT_FALSE(parse("42 ").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, ParsesNonNulTerminatedBuffer) { + // The compile spec value is a raw byte range, not a C string. Place the value + // inside a larger buffer with no terminator after it and parse only its bytes. + const char raw[] = {'1', '2', '8', 'X', 'Y', 'Z'}; + const WsBudget budget = parse_weight_streaming_budget(raw, 3); + EXPECT_TRUE(budget.valid); + EXPECT_EQ(budget.bytes, 128); +} + +TEST(ExecuTorchWeightStreamingBudget, ZeroLengthBufferIsMalformed) { + EXPECT_FALSE(parse_weight_streaming_budget("0", 0).valid); +} + +TEST(ExecuTorchWeightStreamingBudget, ParsesLeadingZeros) { + // Leading zeros are accepted as long as the value fits in int64. + const WsBudget budget = parse("0000000001"); + EXPECT_TRUE(budget.valid); + EXPECT_EQ(budget.bytes, 1); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsEmbeddedNul) { + // The compile spec value is a raw byte range, so an interior NUL must not let + // the value parse as just the bytes before it. + const char raw[] = {'1', '2', '3', '\0', 'x'}; + EXPECT_FALSE(parse_weight_streaming_budget(raw, sizeof(raw)).valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsLeadingWhitespace) { + // from_chars does not skip leading whitespace, so " 42" is rejected. + EXPECT_FALSE(parse(" 42").valid); +} + +TEST(ExecuTorchWeightStreamingBudget, RejectsLeadingPlus) { + // from_chars does not accept a leading plus sign. + EXPECT_FALSE(parse("+42").valid); +} + +} // namespace +} // namespace executorch_backend +} // namespace torch_tensorrt diff --git a/tests/py/dynamo/executorch/test_weight_streaming_budget.py b/tests/py/dynamo/executorch/test_weight_streaming_budget.py new file mode 100644 index 0000000000..75123041b8 --- /dev/null +++ b/tests/py/dynamo/executorch/test_weight_streaming_budget.py @@ -0,0 +1,206 @@ +"""CPU-only tests for the ExecuTorch weight streaming budget option. + +These exercise the export-time plumbing: budget validation and the compile spec +carried into the delegate. The automatic default is applied by the C++ delegate +at load time, so it is not covered here. +""" + +import logging +from types import SimpleNamespace + +import pytest + +executorch = pytest.importorskip("executorch.exir") + +import torch # noqa: E402 +from torch_tensorrt._compile import ( # noqa: E402 + _normalize_weight_streaming_budget, + _resolve_executorch_compile_specs, + save, +) +from torch_tensorrt.executorch.partitioner import ( # noqa: E402 + WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY, +) + +_KEY = WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY + + +def _budget_spec(specs): + for spec in specs: + if spec.key == _KEY: + return spec + return None + + +@pytest.fixture +def patch_engine_count(monkeypatch): + """Patch the engine-node count so the resolver runs without a real program.""" + + def _apply(count=1): + monkeypatch.setattr( + "torch_tensorrt._compile._count_executorch_engine_nodes", + lambda exp_program: count, + ) + + return _apply + + +# --------------------------------------------------------------------------- +# _normalize_weight_streaming_budget +# --------------------------------------------------------------------------- +@pytest.mark.unit +@pytest.mark.parametrize( + "value,expected", + [ + (None, None), + (0, b"0"), + (8589934592, b"8589934592"), + ], +) +def test_normalize_valid(value, expected): + assert _normalize_weight_streaming_budget(value) == expected + + +@pytest.mark.unit +@pytest.mark.parametrize("value", [-1, -(2**63), 2**63, 2**63 + 5]) +def test_normalize_out_of_range_raises(value): + with pytest.raises(ValueError): + _normalize_weight_streaming_budget(value) + + +@pytest.mark.unit +@pytest.mark.parametrize("value", ["auto", "disabled", "1024"]) +def test_normalize_string_raises(value): + # Strings are not accepted; the budget is a non-negative int (or None). + with pytest.raises(TypeError): + _normalize_weight_streaming_budget(value) + + +@pytest.mark.unit +@pytest.mark.parametrize("value", [True, False]) +def test_normalize_bool_raises(value): + with pytest.raises(TypeError): + _normalize_weight_streaming_budget(value) + + +@pytest.mark.unit +def test_normalize_float_raises(): + with pytest.raises(TypeError): + _normalize_weight_streaming_budget(1.5) + + +# --------------------------------------------------------------------------- +# _resolve_executorch_compile_specs +# --------------------------------------------------------------------------- +@pytest.mark.unit +@pytest.mark.parametrize("budget,expected", [(0, b"0"), (8589934592, b"8589934592")]) +def test_kwarg_injects_compile_spec(patch_engine_count, budget, expected): + patch_engine_count(1) + specs = _resolve_executorch_compile_specs(SimpleNamespace(), [], budget) + spec = _budget_spec(specs) + assert spec is not None + assert spec.value == expected + + +@pytest.mark.unit +def test_kwarg_spec_lands_on_delegation_spec(patch_engine_count): + from torch_tensorrt.executorch.partitioner import TensorRTPartitioner + + patch_engine_count(1) + specs = _resolve_executorch_compile_specs(SimpleNamespace(), [], 8589934592) + partitioner = TensorRTPartitioner(compile_specs=specs) + spec = _budget_spec(partitioner.delegation_spec.compile_specs) + assert spec is not None + assert spec.value == b"8589934592" + + +@pytest.mark.unit +def test_no_spec_injected_without_budget(): + # No budget: nothing is injected. The delegate applies the automatic budget + # itself for streaming-built engines. + specs = _resolve_executorch_compile_specs(SimpleNamespace(), [], None) + assert _budget_spec(specs) is None + + +@pytest.mark.unit +def test_caller_compile_specs_passed_through(): + # Non-budget caller compile_specs are forwarded unchanged. + sentinel = SimpleNamespace(key="target_device", value=b"cuda:1") + specs = _resolve_executorch_compile_specs(SimpleNamespace(), [sentinel], None) + assert sentinel in specs + assert _budget_spec(specs) is None + + +@pytest.mark.unit +def test_caller_budget_spec_in_compile_specs_raises(): + # The budget must come from the kwarg, not a manually-pinned compile spec. + spec = SimpleNamespace(key=_KEY, value=b"4096") + with pytest.raises(ValueError): + _resolve_executorch_compile_specs(SimpleNamespace(), [spec], None) + + +# --------------------------------------------------------------------------- +# Multi-engine warning +# --------------------------------------------------------------------------- +@pytest.mark.unit +def test_multi_engine_explicit_warns(patch_engine_count, caplog): + patch_engine_count(2) + with caplog.at_level(logging.WARNING, logger="torch_tensorrt._compile"): + _resolve_executorch_compile_specs(SimpleNamespace(), [], 4096) + assert "multiple TensorRT engines" in caplog.text + + +@pytest.mark.unit +def test_multi_engine_none_does_not_warn(patch_engine_count, caplog): + patch_engine_count(2) + with caplog.at_level(logging.WARNING, logger="torch_tensorrt._compile"): + specs = _resolve_executorch_compile_specs(SimpleNamespace(), [], None) + assert _budget_spec(specs) is None + assert "multiple TensorRT engines" not in caplog.text + + +# --------------------------------------------------------------------------- +# save() entry-point guards +# --------------------------------------------------------------------------- +@pytest.mark.unit +def test_save_rejects_bool_budget(tmp_path): + with pytest.raises(TypeError): + save( + torch.nn.Linear(1, 1), + str(tmp_path / "model.pte"), + output_format="executorch", + weight_streaming_budget=True, + ) + + +@pytest.mark.unit +def test_save_rejects_string_budget(tmp_path): + with pytest.raises(TypeError): + save( + torch.nn.Linear(1, 1), + str(tmp_path / "model.pte"), + output_format="executorch", + weight_streaming_budget="auto", + ) + + +@pytest.mark.unit +def test_save_rejects_negative_budget(tmp_path): + with pytest.raises(ValueError): + save( + torch.nn.Linear(1, 1), + str(tmp_path / "model.pte"), + output_format="executorch", + weight_streaming_budget=-1, + ) + + +@pytest.mark.unit +def test_save_rejects_unknown_executorch_kwarg(tmp_path): + with pytest.raises(TypeError, match="unexpected keyword argument"): + save( + torch.nn.Linear(1, 1), + str(tmp_path / "model.pte"), + output_format="executorch", + weight_streaming_budgett=4096, + )