Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ cc_library(
strip_include_prefix = "include",
)

# Weight streaming budget parser: one source file and its public header, with
# the "include" prefix stripped so the header is included as
# "torch_tensorrt/executorch/WeightStreamingBudget.h".
cc_library(
    name = "tensorrt_executorch_weight_streaming_budget",
    srcs = ["src/torch_tensorrt/executorch/WeightStreamingBudget.cpp"],
    hdrs = ["include/torch_tensorrt/executorch/WeightStreamingBudget.h"],
    strip_include_prefix = "include",
)

cc_library(
name = "tensorrt_executorch_binding_names",
hdrs = [
Expand Down Expand Up @@ -127,6 +138,7 @@ cc_library(
deps = [
":tensorrt_executorch_binding_names",
":tensorrt_executorch_blob_header",
":tensorrt_executorch_weight_streaming_budget",
] + select({
":linux_x86_64": [
"@executorch//:executorch_headers",
Expand All @@ -150,6 +162,7 @@ filegroup(
"src/torch_tensorrt/executorch/README.md",
"src/torch_tensorrt/executorch/TensorRTBackend.cpp",
"src/torch_tensorrt/executorch/TensorRTBlobHeader.cpp",
"src/torch_tensorrt/executorch/WeightStreamingBudget.cpp",
],
)

Expand All @@ -169,5 +182,6 @@ filegroup(
"include/torch_tensorrt/executorch/TensorRTBackend.h",
"include/torch_tensorrt/executorch/TensorRTBindingNames.h",
"include/torch_tensorrt/executorch/TensorRTBlobHeader.h",
"include/torch_tensorrt/executorch/WeightStreamingBudget.h",
],
)
34 changes: 34 additions & 0 deletions cpp/include/torch_tensorrt/executorch/WeightStreamingBudget.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <cstddef>
#include <cstdint>

namespace torch_tensorrt::executorch_backend {

// Name of the compile spec entry through which export hands the weight
// streaming budget to the delegate. Keep this in sync with
// WEIGHT_STREAMING_BUDGET_COMPILE_SPEC_KEY on the Python side
// (py/torch_tensorrt/executorch/partitioner.py).
inline constexpr char kWeightStreamingBudgetKey[] = "weight_streaming_budget";

// Outcome of parsing a weight streaming budget value.
//
// A budget value is a non-negative decimal integer giving an explicit GPU
// budget in bytes. There is deliberately no encoding for the automatic budget:
// when no budget spec is present at all, the delegate applies TensorRT's
// automatic budget on its own, mirroring the PyTorch runtimes.
//
// A default-constructed WsBudget is invalid with a zero byte count.
struct WsBudget {
  // false: a value was supplied but could not be parsed (or was negative), and
  //        the caller should reject the program.
  // true:  `bytes` holds the parsed budget.
  bool valid{false};
  // Parsed non-negative budget in bytes; meaningful only when `valid` is true.
  std::int64_t bytes{0};
};

// Parses a budget out of the raw byte range [value, value + nbytes).
//
// The range is NOT assumed to be NUL terminated; only the first `nbytes` bytes
// are read. The returned WsBudget has `valid` set only when parsing succeeded.
WsBudget parse_weight_streaming_budget(const void* value, std::size_t nbytes);

} // namespace torch_tensorrt::executorch_backend
1 change: 1 addition & 0 deletions cpp/src/torch_tensorrt/executorch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ find_package(Threads REQUIRED)
set(_torchtrt_executorch_sources
"${CMAKE_CURRENT_LIST_DIR}/TensorRTBackend.cpp"
"${CMAKE_CURRENT_LIST_DIR}/TensorRTBlobHeader.cpp"
"${CMAKE_CURRENT_LIST_DIR}/WeightStreamingBudget.cpp"
)

add_library(executorch_trt_backend STATIC ${_torchtrt_executorch_sources})
Expand Down
127 changes: 125 additions & 2 deletions cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "torch_tensorrt/executorch/TensorRTBackend.h"
#include "torch_tensorrt/executorch/TensorRTBindingNames.h"
#include "torch_tensorrt/executorch/TensorRTBlobHeader.h"
#include "torch_tensorrt/executorch/WeightStreamingBudget.h"

#include <cstdint>
#include <cstring>
Expand Down Expand Up @@ -225,8 +226,6 @@ Result<DelegateHandle*> TensorRTBackend::init(
BackendInitContext& context,
FreeableBuffer* processed,
ArrayRef<CompileSpec> compile_specs) const {
(void)compile_specs;

TORCHTRT_ET_CHECK_NOT_NULL(processed, Error::InvalidArgument, "TensorRTBackend::init: null processed buffer");
TORCHTRT_ET_CHECK_NOT_NULL(processed->data(), Error::InvalidArgument, "TensorRTBackend::init: null processed buffer");

Expand Down Expand Up @@ -284,6 +283,130 @@ Result<DelegateHandle*> TensorRTBackend::init(
TORCHTRT_ET_CHECK_NOT_NULL(
handle->engine, Error::InvalidProgram, "TensorRTBackend::init: failed to deserialize TensorRT engine");

// Apply the weight streaming budget before the execution context is created
// below: TensorRT forbids changing the budget while a context is active. The
// budget is a non-negative decimal byte count and may come from two places, in
// order of precedence:
// 1. A load-time backend option ("weight_streaming_budget" runtime spec) that
// the caller passes to Module::load(LoadBackendOptionsMap). This lets a
// deployment size the budget for its own GPU without re-exporting.
// 2. The same key baked into the .pte as a compile spec at export, used as a
// default when no load-time option is given (and the only channel for
// loaders that cannot pass backend options yet, e.g. Python/Android).
// When neither is present and the engine supports streaming, we apply
// TensorRT's automatic budget, mirroring what the PyTorch runtimes do on
// deserialize. Negative or malformed values are rejected as InvalidProgram.
WsBudget ws_request;
bool is_explicit = false;

// (1) A load-time runtime spec takes precedence over the baked compile spec.
// The value is a decimal byte string; a non-negative int is also accepted for
// small budgets. A present-but-wrong-type or empty option is handled explicitly
// so a runtime option is never silently dropped. The const char* returned by
// get_runtime_spec points into the caller's LoadBackendOptionsMap storage, which
// outlives init(); we parse it immediately and keep only the int64 result.
const auto ws_runtime = context.get_runtime_spec<const char*>(kWeightStreamingBudgetKey);
if (ws_runtime.ok()) {
const char* const value = ws_runtime.get();
// The option array need not be NUL terminated (the struct is public), so
// bound the scan. An empty value means "unset", so fall through to (2).
constexpr std::size_t kRuntimeBudgetMaxScan = 256;
std::size_t len = 0;
if (value != nullptr) {
while (len < kRuntimeBudgetMaxScan && value[len] != '\0') {
++len;
}
}
if (len > 0) {
ws_request = parse_weight_streaming_budget(value, len);
if (!ws_request.valid) {
ET_LOG(Error, "TensorRTBackend::init: malformed weight_streaming_budget runtime option");
return Error::InvalidProgram;
}
is_explicit = true;
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

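What happens if len == 0? Right now an empty runtime option is silently treated as unset and we fall through to the compile spec. Please log a warning so the user knows the option was ignored.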

} else if (ws_runtime.error() != Error::NotFound) {
// The key is present but stored as a non-string type. Accept a non-negative
// int for convenience (its 32-bit range only covers budgets under 2 GB);
// otherwise reject it so a wrong-typed option is never silently ignored.
const auto ws_runtime_int = context.get_runtime_spec<int>(kWeightStreamingBudgetKey);
if (ws_runtime_int.ok() && ws_runtime_int.get() >= 0) {
ws_request.valid = true;
ws_request.bytes = ws_runtime_int.get();
is_explicit = true;
} else {
ET_LOG(
Error,
"TensorRTBackend::init: weight_streaming_budget runtime option must be a "
"non-negative int or a decimal byte string");
return Error::InvalidProgram;
}
}

// (2) Otherwise fall back to the compile spec baked into the .pte at export.
if (!is_explicit) {
const CompileSpec* ws_spec = nullptr;
for (const auto& spec : compile_specs) {
if (spec.key != nullptr && std::strcmp(spec.key, kWeightStreamingBudgetKey) == 0) {
if (ws_spec != nullptr) {
// The budget must appear at most once; a second match means the spec
// list is inconsistent, so reject the program instead of guessing.
ET_LOG(Error, "TensorRTBackend::init: duplicate weight_streaming_budget compile spec");
return Error::InvalidProgram;
}
ws_spec = &spec;
}
}
if (ws_spec != nullptr) {
ws_request = parse_weight_streaming_budget(ws_spec->value.buffer, ws_spec->value.nbytes);
if (!ws_request.valid) {
ET_LOG(Error, "TensorRTBackend::init: malformed weight_streaming_budget compile spec");
return Error::InvalidProgram;
}
is_explicit = true;
}
}

const int64_t streamable = handle->engine->getStreamableWeightsSize();
if (streamable > 0) {
// getStreamableWeightsSize is > 0 only when the engine was built with
// BuilderFlag::kWEIGHT_STREAMING.
int64_t budget;
if (is_explicit) {
// An explicit budget is a non-negative byte count, clamped to the
// streamable size (TensorRT also caps it, but clamp for a clear log).
budget = ws_request.bytes > streamable ? streamable : ws_request.bytes;
} else {
budget = handle->engine->getWeightStreamingAutomaticBudget();
}
if (!handle->engine->setWeightStreamingBudgetV2(budget)) {
if (!is_explicit && handle->engine->setWeightStreamingBudgetV2(0)) {
// The automatic budget could not be applied; fall back to budget 0, which
// streams all weights (minimum resident memory) and always fits.
ET_LOG(Info, "TensorRTBackend::init: automatic weight streaming budget failed; falling back to budget 0 (stream all weights)");
} else {
ET_LOG(
Error,
"TensorRTBackend::init: setWeightStreamingBudgetV2 failed (requested=%lld%s)",
(long long)budget,
is_explicit ? "" : ", and fallback to 0 also failed");
return Error::InvalidProgram;
}
}
ET_LOG(
Info,
"TensorRTBackend::init: weight streaming budget=%lld streamable=%lld scratch=%lld",
(long long)handle->engine->getWeightStreamingBudgetV2(),
(long long)streamable,
(long long)handle->engine->getWeightStreamingScratchMemorySize());
} else if (is_explicit) {
// A budget was requested but the engine has no streamable weights (it was not
// built with enable_weight_streaming=True, or nothing is streamable). The
// engine is still valid and runs fully resident, so log and continue rather
// than fail; failing here would break mixed multi-engine programs.
ET_LOG(Info, "TensorRTBackend::init: weight_streaming_budget ignored; engine does not support weight streaming (build with enable_weight_streaming=True)");
}

Error err = initialize_engine_io(*handle);
if (err != Error::Ok) {
return err;
Expand Down
35 changes: 35 additions & 0 deletions cpp/src/torch_tensorrt/executorch/WeightStreamingBudget.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "torch_tensorrt/executorch/WeightStreamingBudget.h"

#include <charconv>
#include <cstddef>
#include <cstdint>
#include <system_error>

namespace torch_tensorrt {
namespace executorch_backend {

// Parses a weight streaming budget from the raw byte range
// [value, value + nbytes).
//
// The range must consist solely of ASCII decimal digits whose value fits in
// int64_t. It is not assumed to be NUL terminated, and no byte at or beyond
// value + nbytes is read.
//
// Returns a WsBudget with valid == true and bytes set to the parsed value on
// success. Returns the default WsBudget (valid == false) when value is null,
// nbytes is 0, or the range is anything other than an in-range digit string:
// a sign (including "-0"), whitespace, trailing bytes, an embedded NUL, or a
// value that overflows int64_t.
WsBudget parse_weight_streaming_budget(const void* value, std::size_t nbytes) {
  WsBudget result; // valid == false until the value is fully parsed

  if (value == nullptr || nbytes == 0) {
    return result;
  }

  const char* const first = static_cast<const char*>(value);
  const char* const last = first + nbytes;

  // std::from_chars skips no whitespace and rejects '+', but for signed types
  // it DOES accept a leading '-'. Without this guard "-0" would parse to 0 and
  // be reported as a valid budget. Requiring a leading digit rules out every
  // signed spelling up front, so the parsed value can never be negative.
  if (*first < '0' || *first > '9') {
    return result;
  }

  int64_t parsed = 0;
  const std::from_chars_result fc = std::from_chars(first, last, parsed);
  // ec reports "no digits" / out-of-range; ptr != last means leftover bytes
  // (trailing garbage or an embedded NUL) that were not part of the number.
  if (fc.ec != std::errc() || fc.ptr != last) {
    return result;
  }

  result.valid = true;
  result.bytes = parsed;
  return result;
}

} // namespace executorch_backend
} // namespace torch_tensorrt
Loading
Loading