diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 19260149ae..796b0d3c2d 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -1,6 +1,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_pkg//:pkg.bzl", "pkg_tar") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") + package(default_visibility = ["//visibility:public"]) config_setting( @@ -66,6 +67,7 @@ cc_library( "RTDevice.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", + "TRTRuntimeConfig.cpp", "execute_engine.cpp", "register_jit_hooks.cpp", "runtime.cpp", @@ -75,11 +77,19 @@ cc_library( "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", "runtime.h", ], linkopts = [ "-lstdc++fs", ], + local_defines = select({ + # TensorRT-RTX builds: opt into feature-gated APIs that the runtime layer + # depends on (e.g. IExecutionContext::isStreamCapturable). + ":rtx_win": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + ":rtx_x86_64": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + "//conditions:default": [], + }), deps = [ "//core/plugins:torch_tensorrt_plugins", "//core/util:prelude", @@ -107,6 +117,7 @@ filegroup( "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", "runtime.h", ], visibility = ["//visibility:public"], @@ -121,6 +132,6 @@ pkg_tar( pkg_files( name = "include_pkg_files", srcs = [":include_files"], - visibility = ["//visibility:public"], prefix = "include/torch_tensorrt/core/runtime/", + visibility = ["//visibility:public"], ) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d29daa112b..51efa2388a 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "NvInfer.h" @@ -54,26 +55,28 @@ void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims } TRTEngine::TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) + std::string serialized_metadata, + const ResourceAllocationStrategy resource_allocation_strategy, + TRTRuntimeConfig runtime_cfg) : TRTEngine( "deserialized_trt", - serialized_engine, + std::move(serialized_engine), cuda_device, _in_binding_names, _out_binding_names, target_platform, hardware_compatible, requires_output_allocator, - serialized_metadata, - resource_allocation_strategy) {} + std::move(serialized_metadata), + resource_allocation_strategy, + std::move(runtime_cfg)) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -88,19 +91,22 @@ TRTEngine::TRTEngine(std::vector serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) {} + : ResourceAllocationStrategy::kStatic), + make_runtime_config_from_serialized(serialized_info)) {} TRTEngine::TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) { + std::string serialized_metadata, + const ResourceAllocationStrategy resource_allocation_strategy, + TRTRuntimeConfig runtime_cfg) { + this->runtime_cfg = std::move(runtime_cfg); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -111,7 +117,7 @@ TRTEngine::TRTEngine( auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); - this->serialized_metadata = serialized_metadata; + this->serialized_metadata = std::move(serialized_metadata); this->requires_output_allocator = requires_output_allocator; device_info = most_compatible_device.value(); multi_gpu_device_check(); @@ -119,7 +125,7 @@ TRTEngine::TRTEngine( rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); - name = slugify(mod_name); + name = slugify(std::move(mod_name)); cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); @@ -134,13 +140,7 @@ TRTEngine::TRTEngine( LOG_DEBUG( "Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); - if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); + recreate_execution_context(); // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) cudaMalloc(&empty_tensor_placeholder, 1); @@ -265,6 +265,9 @@ TRTEngine::TRTEngine( TRTEngine::~TRTEngine() { torch::cuda::synchronize(device_info.id); + // Marked noexcept by the type system, so safe to invoke from a destructor without + // explicit try/catch; any I/O error is logged internally. + runtime_cfg.save_runtime_cache(); trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -278,8 +281,7 @@ void TRTEngine::disable_profiling() { torch::cuda::synchronize(device_info.id); profile_execution = false; trt_engine_profiler.reset(); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context"); + recreate_execution_context(); } void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { @@ -376,10 +378,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { trt_engine_profiler.reset(); } bool result = cuda_engine->setWeightStreamingBudgetV2(budget); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK( - (exec_ctx.get() != nullptr), - "Unable to recreate TensorRT execution context after setting new device memory budget"); + recreate_execution_context(); if (profile_execution) { enable_profiling(); } @@ -428,6 +427,7 @@ std::string TRTEngine::to_str() const { ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; + ss << runtime_cfg.to_str(); // clang-format on return ss.str(); } @@ -472,7 +472,14 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), - std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])); + std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]) +#ifdef TRT_MAJOR_RTX + , + std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), + std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), + std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) +#endif + ); } std::vector TRTEngine::serialize() { @@ -497,6 +504,13 @@ std::vector TRTEngine::serialize() { serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; +#ifdef TRT_MAJOR_RTX + serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; + serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( + static_cast>(runtime_cfg.dynamic_shapes_kernel_strategy)); + serialized_info[CUDA_GRAPH_STRATEGY_IDX] = + std::to_string(static_cast>(runtime_cfg.cuda_graph_strategy)); +#endif return serialized_info; } @@ -508,17 +522,44 @@ void TRTEngine::reset_captured_graph() { void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { if (new_strategy != this->resource_allocation_strategy) { this->resource_allocation_strategy = new_strategy; - if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { - LOG_DEBUG("Setting resource allocation strategy to dynamic"); - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - LOG_DEBUG("Setting resource allocation strategy to static"); - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } + LOG_DEBUG( + "Setting resource allocation strategy to " + << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" + : "static")); + recreate_execution_context(); } } +bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { + return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); +} + +void TRTEngine::disable_rtx_native_cudagraphs() { + bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; + runtime_cfg.disable_rtx_native_cudagraphs(name); + if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { + // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx + // so the new strategy takes effect for subsequent enqueueV3 calls. + recreate_execution_context(); + } +} + +void TRTEngine::recreate_execution_context() { + // Flush any kernels the previous execution context may have compiled into the + // runtime cache before creating the replacement. The destructor also saves, but + // doing it here guards against losing compiled kernels across profiling toggles, + // allocator changes, or process kills that happen between allocator changes and + // teardown. No-op on standard TensorRT or when no cache path is configured. + runtime_cfg.save_runtime_cache(); + runtime_cfg.ensure_initialized(cuda_engine.get()); + runtime_cfg.set_execution_context_allocation_strategy( + resource_allocation_strategy == ResourceAllocationStrategy::kDynamic + ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED + : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); + exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get())); + TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 363631863f..6ad5b2a3f2 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -13,6 +13,7 @@ #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" +#include "core/runtime/TRTRuntimeConfig.h" #include "core/util/prelude.h" namespace torch_tensorrt { @@ -30,7 +31,14 @@ using FlattenedState = std::tuple< std::tuple, // requires_output_allocator std::tuple, // serialized metadata std::tuple, // Platform - std::tuple>; // Resource Allocation Strategy + std::tuple // Resource Allocation Strategy +#ifdef TRT_MAJOR_RTX + , + std::tuple, // Runtime Cache Path (TRT-RTX) + std::tuple, // Dynamic Shapes Kernel Strategy (TRT-RTX) + std::tuple // CUDA Graph Strategy (TRT-RTX) +#endif + >; struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -125,31 +133,33 @@ struct TRTEngine : torch::CustomClassHolder { ~TRTEngine(); TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); TRTEngine(std::vector serialized_info); TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; @@ -217,6 +227,24 @@ struct TRTEngine : torch::CustomClassHolder { ResourceAllocationStrategy resource_allocation_strategy = kStatic; void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); + + // All TensorRT-RTX-specific IRuntimeConfig state lives here. On non-RTX builds this + // still owns a shared IRuntimeConfig (so the execution-context allocation strategy is + // applied via the uniform code path) but the RTX-only setters become no-ops. + TRTRuntimeConfig runtime_cfg; + + // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph + // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. + bool is_monolithic_capturable(cudaStream_t stream) const; + + // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when + // an outer stream capture is detected around execute_engine). No-op on non-RTX. + void disable_rtx_native_cudagraphs(); + + private: + // Single entry point that (re)creates exec_ctx. Also creates (once) the IRuntimeConfig + // owned by runtime_cfg and applies all runtime config settings. + void recreate_execution_context(); }; } // namespace runtime diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp new file mode 100644 index 0000000000..0804a0a7fa --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -0,0 +1,245 @@ +#include "core/runtime/TRTRuntimeConfig.h" + +#include +#include +#include +#include +#include + +#include "core/runtime/runtime.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// File-local helpers. Kept out of the header because they are only used by this +// translation unit -- TRTEngine now consumes a TRTRuntimeConfig directly and does not +// need the enum conversion helpers. +namespace { + +[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s) { + switch (s) { + case DynamicShapesKernelStrategy::kLazy: + return "lazy"; + case DynamicShapesKernelStrategy::kEager: + return "eager"; + case DynamicShapesKernelStrategy::kNone: + return "none"; + } + TORCHTRT_CHECK( + false, + "Unexpected DynamicShapesKernelStrategy value: " + << static_cast>(s)); +} + +[[nodiscard]] std::string to_string(CudaGraphStrategyOption s) { + switch (s) { + case CudaGraphStrategyOption::kDisabled: + return "disabled"; + case CudaGraphStrategyOption::kWholeGraphCapture: + return "whole_graph_capture"; + } + TORCHTRT_CHECK( + false, + "Unexpected CudaGraphStrategyOption value: " << static_cast>(s)); +} + +[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( + std::underlying_type_t v) { + TORCHTRT_CHECK( + v >= 0 && v <= 2, + "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); + return static_cast(v); +} + +[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { + TORCHTRT_CHECK( + v >= 0 && v <= 1, + "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); + return static_cast(v); +} + +#ifdef TRT_MAJOR_RTX +// Raw cache I/O helpers. Exception-propagating; the caller wraps in try/catch at the +// TRTRuntimeConfig member level. Kept file-local because the IRuntimeCache type is +// itself TensorRT-RTX-only and tests reach this path through the member wrappers. +void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); + if (!std::filesystem::exists(path)) { + LOG_DEBUG("No existing runtime cache at " << path); + return; + } + std::ifstream f(path, std::ios::binary); + std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + if (buf.empty()) { + return; + } + TORCHTRT_CHECK(cache->deserialize(buf.data(), buf.size()), "IRuntimeCache::deserialize returned false for " << path); + LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); +} + +void save_runtime_cache_impl(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); + auto host_mem = make_trt(cache->serialize()); + if (!host_mem || host_mem->size() == 0) { + return; + } + std::filesystem::path fs_path(path); + if (fs_path.has_parent_path()) { + std::filesystem::create_directories(fs_path.parent_path()); + } + std::filesystem::path tmp_path = fs_path; + tmp_path += ".tmp"; + { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast(host_mem->data()), host_mem->size()); + } + std::filesystem::rename(tmp_path, fs_path); + LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); +} +#endif // TRT_MAJOR_RTX + +} // namespace + +void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { + if (config) { + return; + } + TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); + config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); + +#ifdef TRT_MAJOR_RTX + // Runtime cache -- TRT-RTX only. + if (!runtime_cache_path.empty()) { + runtime_cache = make_trt(config->createRuntimeCache()); + if (runtime_cache.get() == nullptr) { + LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + } else { + try { + load_runtime_cache(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); + } + if (config->setRuntimeCache(*runtime_cache)) { + LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); + } else { + LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); + runtime_cache.reset(); + } + } + } else { + LOG_DEBUG("Runtime cache disabled (no path configured)."); + } + + // Dynamic shapes kernel specialization strategy -- TRT-RTX only. + config->setDynamicShapesKernelSpecializationStrategy( + static_cast(dynamic_shapes_kernel_strategy)); + LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); + + // CUDA graph strategy -- TRT-RTX only. + if (!config->setCudaGraphStrategy( + cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture + ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE + : nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); + } +#endif +} + +void TRTRuntimeConfig::set_execution_context_allocation_strategy( + nvinfer1::ExecutionContextAllocationStrategy strategy) const { + TORCHTRT_ASSERT(config, "TRTRuntimeConfig::config must be initialized before setting allocation strategy"); + config->setExecutionContextAllocationStrategy(strategy); +} + +bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { +#ifdef TRT_MAJOR_RTX + // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled + // strategy is set, or when subgraph cudagraphs are enabled globally. In both cases the + // caller should skip its manual at::cuda::CUDAGraph wrapper because TRT-RTX's internal + // capture would collide with it. + return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; +#else + return false; +#endif +} + +void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std::string& engine_name) noexcept { +#ifdef TRT_MAJOR_RTX + if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " + << engine_name << " for the remainder of its lifetime."); + // Persist any kernels the engine-internal capture has compiled so far; the outer + // capture will run without them otherwise, and we want future reloads to reuse them. + save_runtime_cache(); + cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + if (config && !config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); + } + rtx_native_cudagraphs_disabled = true; +#endif +} + +bool TRTRuntimeConfig::is_monolithic_capturable( + TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, + TORCHTRT_UNUSED cudaStream_t stream) const { +#ifdef TRT_MAJOR_RTX + TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); + // "lazy" kernel specialization swaps specialized kernels in mid-run, which invalidates + // captured graphs. Other strategies (eager/none) are safe when the context reports the + // stream capturable. + return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != DynamicShapesKernelStrategy::kLazy; +#else + return true; +#endif +} + +void TRTRuntimeConfig::save_runtime_cache() noexcept { +#ifdef TRT_MAJOR_RTX + if (!runtime_cache || runtime_cache_path.empty()) { + return; + } + try { + save_runtime_cache_impl(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); + } catch (...) { + LOG_WARNING("Failed to save runtime cache (unknown exception)."); + } +#endif +} + +std::string TRTRuntimeConfig::to_str() const { + std::ostringstream os; + os << "Runtime Cache Path: " << (runtime_cache_path.empty() ? "" : runtime_cache_path) << std::endl; + os << "Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; + os << "CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; + return os.str(); +} + +TRTRuntimeConfig make_runtime_config_from_serialized(TORCHTRT_UNUSED const std::vector& info) { + TRTRuntimeConfig cfg; +#ifdef TRT_MAJOR_RTX + cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; + cfg.dynamic_shapes_kernel_strategy = + to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); + cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); +#endif + return cfg; +} + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { + os << "Runtime cfg {" << std::endl; + os << cfg.to_str(); + os << "}" << std::endl; + return os; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h new file mode 100644 index 0000000000..e964706c2e --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// TensorRT-RTX-only configuration for how shape-specialized kernels are compiled. +enum class DynamicShapesKernelStrategy : int32_t { + kLazy = 0, + kEager = 1, + kNone = 2, +}; + +// TensorRT-RTX-only configuration for how CUDA graph capture/replay is handled. +enum class CudaGraphStrategyOption : int32_t { + kDisabled = 0, + kWholeGraphCapture = 1, +}; + +// Encapsulates the nvinfer1::IRuntimeConfig owned by a TRTEngine along with the +// TensorRT-RTX-specific state (runtime cache, dynamic shapes kernel strategy, native +// CUDA graph strategy). All `#ifdef TRT_MAJOR_RTX` guards live in this file and its +// implementation so callers can treat this struct uniformly between RTX and standard +// TensorRT builds. +struct TRTRuntimeConfig { + // Settings - typically populated from engine deserialization before `ensure_initialized`. + std::string runtime_cache_path = ""; + DynamicShapesKernelStrategy dynamic_shapes_kernel_strategy = DynamicShapesKernelStrategy::kLazy; + CudaGraphStrategyOption cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + + // One-shot: set to true once an outer stream capture has been detected and the + // engine-internal CUDA graph strategy has been disabled for the remainder of the + // owning engine's lifetime. + bool rtx_native_cudagraphs_disabled = false; + + // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized`. + std::shared_ptr config; +#ifdef TRT_MAJOR_RTX + std::shared_ptr runtime_cache; +#endif + + // Construct the IRuntimeConfig once and apply all TRT-RTX-specific settings. Safe to + // call multiple times; only the first call initializes and applies the RTX-only + // setters. On subsequent calls this is a no-op. + void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); + + // Apply (or re-apply) the execution context allocation strategy on the IRuntimeConfig. + // Available on both standard TensorRT and TensorRT-RTX via IRuntimeConfig. + void set_execution_context_allocation_strategy(nvinfer1::ExecutionContextAllocationStrategy strategy) const; + + // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the + // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always + // false on non-RTX builds. + [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const; + + // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream + // capture is detected around execute_engine, so the outer capture can contain the + // kernel launches directly. Saves the runtime cache before recreating the context so + // compiled kernels from the present run are preserved for future reloads. + void disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept; + + // Whether the execution context is safe to include in an outer monolithic capture. + // Non-RTX builds always return true. + [[nodiscard]] bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; + + // Save the runtime cache to disk. Signature is `noexcept` so this is safe from a + // destructor. The underlying file I/O is performed by free functions declared below + // (non-noexcept, exception-leaky for easier testing); this member wraps them and + // swallows any exceptions. + void save_runtime_cache() noexcept; + + // Returns a human-readable summary of the runtime config. + [[nodiscard]] std::string to_str() const; +}; + +// Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the +// RTX-only indices only on RTX builds; standard TRT builds return a default-initialized +// struct. +[[nodiscard]] TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info); + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 553469392b..2a71b7ebd3 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -217,11 +217,28 @@ std::vector execute_engine(std::vector inputs, c10::intr auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); + // effective_cudagraphs controls the manual at::cuda::CUDAGraph path below. On TRT-RTX + // builds the engine-internal runtime owns capture/replay inside enqueueV3 whenever the + // engine has a cuda_graph_strategy set or subgraph cudagraphs are enabled; the struct + // reports that via `uses_internal_capture` so the caller skips its manual wrapper. If + // an outer stream capture is already in progress (e.g. the caller wraps this module in + // CudaGraphsTorchTensorRTModule for whole-graph capture), engine-internal capture would + // collide, so we disable it one-shot here. + bool effective_cudagraphs = cudagraphs_enabled; + if (compiled_engine->runtime_cfg.uses_internal_capture(cudagraphs_enabled)) { + effective_cudagraphs = false; + cudaStreamCaptureStatus capture_status; + cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); + if (capture_status != cudaStreamCaptureStatusNone) { + compiled_engine->disable_rtx_native_cudagraphs(); + } + } + bool shape_changed = _validate_shapes(inputs, compiled_engine); // Whether cudagraphs needs to record the graph on this pass auto result = compiled_engine->runtime_states.set_runtime_states( - cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); + effective_cudagraphs, compiled_engine->use_pre_allocated_outputs, shape_changed); bool need_cudagraphs_record = std::get<0>(result); bool can_use_pre_allocated_outputs = std::get<1>(result); @@ -244,7 +261,8 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues); + setup_input_tensors( + inputs, compiled_engine, effective_cudagraphs, need_cudagraphs_record, inputShapeTensorValues); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -276,7 +294,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); } - if (cudagraphs_enabled) { + if (effective_cudagraphs) { TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress( name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), @@ -316,8 +334,10 @@ std::vector execute_engine(std::vector inputs, c10::intr caller_exec_complete.record(compiled_engine->caller_stream); caller_exec_complete.block(compiled_engine->engine_stream); - if (!cudagraphs_enabled) { - // Direct execution uses the caller buffers directly + if (!effective_cudagraphs) { + // Direct execution uses the caller buffers directly. On TRT-RTX with a + // cuda_graph_strategy set, the engine captures/replays internally during + // this enqueueV3 call. compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); } else { if (need_cudagraphs_record) { @@ -350,7 +370,7 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.record(compiled_engine->engine_stream); trt_exec_complete.block(compiled_engine->caller_stream); - if (cudagraphs_enabled) { + if (effective_cudagraphs) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { outputs[o].copy_(compiled_engine->output_buffers[o], false); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e8f6217a21..ad49890307 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -150,6 +150,11 @@ TORCH_LIBRARY(tensorrt, m) { m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); +#ifdef TRT_MAJOR_RTX + m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); + m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); + m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); +#endif m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index d8f71683d3..70c8aa8119 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -16,7 +16,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "8"; +const std::string ABI_VERSION = "9"; extern bool MULTI_DEVICE_SAFE_MODE; typedef enum { @@ -39,6 +39,11 @@ typedef enum { TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, +#ifdef TRT_MAJOR_RTX + RUNTIME_CACHE_PATH_IDX, + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, + CUDA_GRAPH_STRATEGY_IDX, +#endif SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d04c294ad9..a50d55469f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -113,6 +113,7 @@ def cross_compile_for_windows( enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -356,6 +357,7 @@ def cross_compile_for_windows( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } @@ -487,6 +489,7 @@ def compile( cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -785,6 +788,7 @@ def compile( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } @@ -1192,6 +1196,7 @@ def convert_exported_program_to_serialized_trt_engine( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -1443,6 +1448,7 @@ def convert_exported_program_to_serialized_trt_engine( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 8998479a63..a929b5ea1d 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -70,9 +70,10 @@ ENABLE_RESOURCE_PARTITIONING = False CPU_MEMORY_BUDGET = None DYNAMICALLY_ALLOCATE_RESOURCES = False +DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True -DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 595f9dcb55..eb4f4e07e7 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, + CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -124,6 +125,7 @@ class CompilationSettings: autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines + cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy: "disabled" (default) or "whole_graph_capture" (let TensorRT-RTX manage CUDA graph capture/replay internally). When set and combined with `torch_tensorrt.runtime.set_cudagraphs_mode(True)` on RTX, overrides manual capture. Not used for standard TensorRT. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. """ @@ -189,6 +191,7 @@ class CompilationSettings: enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY decompose_attention: bool = DECOMPOSE_ATTENTION attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d77c0bf39f..79c14ddb9d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -35,6 +35,10 @@ SERIALIZED_METADATA_IDX = -1 # Not implemented TARGET_PLATFORM_IDX = -1 # Not implemented REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented +RESOURCE_ALLOCATION_STRATEGY_IDX = -1 # Not implemented +RUNTIME_CACHE_PATH_IDX = -1 # Not implemented +DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = -1 # Not implemented +CUDA_GRAPH_STRATEGY_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: @@ -53,7 +57,25 @@ RESOURCE_ALLOCATION_STRATEGY_IDX = ( torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() ) # 10 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 11 + if ENABLED_FEATURES.tensorrt_rtx: + RUNTIME_CACHE_PATH_IDX = torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX() # 11 + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( + torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX() + ) # 12 + CUDA_GRAPH_STRATEGY_IDX = torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX() # 13 + SERIALIZATION_LEN = ( + torch.ops.tensorrt.SERIALIZATION_LEN() + ) # 14 (RTX) / 11 (standard) + +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} @for_all_methods(needs_torch_tensorrt_runtime) @@ -145,6 +167,14 @@ def __init__( self.engine = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources + # TensorRT-RTX-only runtime config mirror. The engine-info serialization slots + # only exist on RTX builds (see below), but we validate the strategy names on + # every build so typos are caught regardless of backend. + self.runtime_cache_path = settings.runtime_cache_path + self.dynamic_shapes_kernel_specialization_strategy = ( + settings.dynamic_shapes_kernel_specialization_strategy + ) + self.cuda_graph_strategy = settings.cuda_graph_strategy self.symbolic_shape_expressions = symbolic_shape_expressions if ( @@ -203,6 +233,35 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) + # Validate TensorRT-RTX strategy names on every build so typos are caught + # regardless of backend. The engine-info slots themselves only exist on RTX + # builds and are written below, but the validation is cheap and catches user + # errors early. + if ENABLED_FEATURES.tensorrt_rtx and self.runtime_cache_path is not None: + engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + f"Invalid dynamic_shapes_kernel_specialization_strategy " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " + f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " + f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + ) + if ENABLED_FEATURES.tensorrt_rtx: + engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( + _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + self.dynamic_shapes_kernel_specialization_strategy + ] + ) + engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( + _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] + ) return engine_info diff --git a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py index badfff81ea..e3d438c3ae 100644 --- a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py +++ b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py @@ -3,10 +3,43 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +# Combinations of (strategy, runtime_name, use_python_runtime). Tests use parameterized +# so the strategy sweep runs on both runtimes with a single test body. +_STRATEGY_RUNTIMES = [ + ("lazy_python", "lazy", True), + ("eager_python", "eager", True), + ("none_python", "none", True), + ("lazy_cpp", "lazy", False), + ("eager_cpp", "eager", False), + ("none_cpp", "none", False), +] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +def _compile_with_strategy( + model, inputs, *, use_python_runtime, strategy, enabled_precisions +): + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions=enabled_precisions, + use_python_runtime=use_python_runtime, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, @@ -17,17 +50,18 @@ "torchvision is not installed", ) class TestDynamicShapesKernelStrategyModels(TestCase): - """End-to-end model tests with different kernel specialization strategies.""" + """End-to-end model tests with each strategy across both runtimes.""" - def tearDown(self): - torch._dynamo.reset() + @parameterized.expand(_STRATEGY_RUNTIMES) + def test_resnet18_strategy(self, _name, strategy, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + import torchvision.models as models - def _compile_and_verify(self, model, strategy): + model = models.resnet18(pretrained=True).eval().cuda() input_tensor = torch.randn(4, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile_with_strategy( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 224, 224), opt_shape=(4, 3, 224, 224), @@ -35,10 +69,9 @@ def _compile_and_verify(self, model, strategy): dtype=torch.float32, ) ], + use_python_runtime=use_python_runtime, + strategy=strategy, enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, ) ref_output = model(input_tensor) trt_output = compiled(input_tensor) @@ -46,39 +79,21 @@ def _compile_and_verify(self, model, strategy): self.assertTrue( cos_sim > COSINE_THRESHOLD, f"Cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} " - f"with strategy={strategy}", + f"(strategy={strategy}, python_runtime={use_python_runtime})", ) - def test_resnet18_lazy_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "lazy") - - def test_resnet18_eager_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "eager") - - def test_resnet18_none_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "none") - @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", ) class TestDynamicShapesKernelStrategyDynamic(TestCase): - """Tests kernel specialization strategies with dynamic input shapes.""" + """Tests kernel specialization strategies with dynamic input shapes, both runtimes.""" - def tearDown(self): - torch._dynamo.reset() + @parameterized.expand(_STRATEGY_RUNTIMES) + def test_dynamic_batch_with_strategy(self, _name, strategy, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) - def _test_dynamic_batch_with_strategy(self, strategy): class ConvModel(torch.nn.Module): def __init__(self): super().__init__() @@ -90,10 +105,9 @@ def forward(self, x): model = ConvModel().eval().cuda() - compiled = torchtrt.compile( + compiled = _compile_with_strategy( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 32, 32), opt_shape=(4, 3, 32, 32), @@ -101,32 +115,21 @@ def forward(self, x): dtype=torch.float32, ) ], + use_python_runtime=use_python_runtime, + strategy=strategy, enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, ) for batch_size in (1, 4, 8): - with self.subTest(batch_size=batch_size, strategy=strategy): - input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() - ref_output = model(input_tensor) - trt_output = compiled(input_tensor) - cos_sim = cosine_similarity(ref_output, trt_output) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - f"BS={batch_size}, strategy={strategy}: cosine similarity " - f"{cos_sim} below threshold {COSINE_THRESHOLD}", - ) - - def test_dynamic_batch_lazy(self): - self._test_dynamic_batch_with_strategy("lazy") - - def test_dynamic_batch_eager(self): - self._test_dynamic_batch_with_strategy("eager") - - def test_dynamic_batch_none(self): - self._test_dynamic_batch_with_strategy("none") + input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"BS={batch_size}, strategy={strategy}, python_runtime={use_python_runtime}: " + f"cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", + ) if __name__ == "__main__": diff --git a/tests/py/dynamo/models/test_runtime_cache_models.py b/tests/py/dynamo/models/test_runtime_cache_models.py index aecb2fbaa3..55b11b623e 100644 --- a/tests/py/dynamo/models/test_runtime_cache_models.py +++ b/tests/py/dynamo/models/test_runtime_cache_models.py @@ -8,10 +8,32 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +# Parameterize end-to-end cache tests over both runtime paths. The C++ variant is +# skipped inside the test body when the C++ runtime is not available. +_RUNTIMES = [("python", True), ("cpp", False)] + + +def _compile(model, inputs, *, use_python_runtime, runtime_cache_path): + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": use_python_runtime, + "min_block_size": 1, + "runtime_cache_path": runtime_cache_path, + } + return torchtrt.compile(model, **kwargs) + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, @@ -22,7 +44,7 @@ "torchvision is not installed", ) class TestRuntimeCacheModels(TestCase): - """End-to-end model tests with runtime cache enabled.""" + """End-to-end model tests with runtime cache enabled — both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -32,19 +54,18 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_resnet18_with_runtime_cache(self): + @parameterized.expand(_RUNTIMES) + def test_resnet18_with_runtime_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() input_tensor = torch.randn(1, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, + [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) @@ -57,7 +78,6 @@ def test_resnet18_with_runtime_cache(self): f"ResNet18 cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", ) - # Verify runtime cache is saved on cleanup del compiled gc.collect() self.assertTrue( @@ -65,8 +85,10 @@ def test_resnet18_with_runtime_cache(self): "Runtime cache should be saved after ResNet18 inference", ) - def test_resnet18_cache_reuse(self): - """Compile + infer twice with same cache path. Second run should load cached data.""" + @parameterized.expand(_RUNTIMES) + def test_resnet18_cache_reuse(self, _name, use_python_runtime): + """Compile + infer twice with same cache path. Second run loads cached data.""" + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() @@ -74,16 +96,13 @@ def test_resnet18_cache_reuse(self): ref_output = model(input_tensor) compile_kwargs = { - "ir": "dynamo", "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, + "use_python_runtime": use_python_runtime, "runtime_cache_path": self.cache_path, } # First compilation — cold cache - compiled1 = torchtrt.compile(model, **compile_kwargs) + compiled1 = _compile(model, **compile_kwargs) _ = compiled1(input_tensor) del compiled1 gc.collect() @@ -92,7 +111,7 @@ def test_resnet18_cache_reuse(self): cache_size_1 = os.path.getsize(self.cache_path) # Second compilation — warm cache - compiled2 = torchtrt.compile(model, **compile_kwargs) + compiled2 = _compile(model, **compile_kwargs) output2 = compiled2(input_tensor) cos_sim = cosine_similarity(ref_output, output2) @@ -104,23 +123,21 @@ def test_resnet18_cache_reuse(self): del compiled2 gc.collect() cache_size_2 = os.path.getsize(self.cache_path) - # Cache should exist and be non-empty after both runs self.assertGreater(cache_size_1, 0) self.assertGreater(cache_size_2, 0) - def test_mobilenet_v2_with_runtime_cache(self): + @parameterized.expand(_RUNTIMES) + def test_mobilenet_v2_with_runtime_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = models.mobilenet_v2(pretrained=True).eval().cuda() input_tensor = torch.randn(1, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, + [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) @@ -143,7 +160,7 @@ def test_mobilenet_v2_with_runtime_cache(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheDynamicShapes(TestCase): - """Tests runtime cache with dynamic input shapes.""" + """Tests runtime cache with dynamic input shapes, exercised on both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -153,7 +170,10 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_dynamic_batch_with_cache(self): + @parameterized.expand(_RUNTIMES) + def test_dynamic_batch_with_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + class ConvModel(torch.nn.Module): def __init__(self): super().__init__() @@ -165,10 +185,9 @@ def forward(self, x): model = ConvModel().eval().cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 32, 32), opt_shape=(4, 3, 32, 32), @@ -176,39 +195,28 @@ def forward(self, x): dtype=torch.float32, ) ], - enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) - # Test with batch size 1 - input_bs1 = torch.randn(1, 3, 32, 32).cuda() - ref_bs1 = model(input_bs1) - out_bs1 = compiled(input_bs1) - cos_sim_1 = cosine_similarity(ref_bs1, out_bs1) - self.assertTrue( - cos_sim_1 > COSINE_THRESHOLD, - f"BS=1 cosine similarity {cos_sim_1} below threshold", - ) - - # Test with batch size 4 - input_bs4 = torch.randn(4, 3, 32, 32).cuda() - ref_bs4 = model(input_bs4) - out_bs4 = compiled(input_bs4) - cos_sim_4 = cosine_similarity(ref_bs4, out_bs4) - self.assertTrue( - cos_sim_4 > COSINE_THRESHOLD, - f"BS=4 cosine similarity {cos_sim_4} below threshold", - ) + for batch_size in (1, 4): + input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() + ref_output = model(input_tensor) + out = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, out) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"BS={batch_size} cosine similarity {cos_sim} below threshold", + ) - # Verify cache is saved del compiled gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - def test_cache_valid_across_shapes(self): + @parameterized.expand(_RUNTIMES) + def test_cache_valid_across_shapes(self, _name, use_python_runtime): """Save cache from one shape, load and verify it works with another shape in range.""" + _skip_if_cpp_unavailable(self, use_python_runtime) class SimpleConv(torch.nn.Module): def __init__(self): @@ -221,7 +229,6 @@ def forward(self, x): model = SimpleConv().eval().cuda() compile_kwargs = { - "ir": "dynamo", "inputs": [ torchtrt.Input( min_shape=(1, 3, 16, 16), @@ -230,14 +237,12 @@ def forward(self, x): dtype=torch.float32, ) ], - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, + "use_python_runtime": use_python_runtime, "runtime_cache_path": self.cache_path, } # First run with batch=2 — saves cache - compiled1 = torchtrt.compile(model, **compile_kwargs) + compiled1 = _compile(model, **compile_kwargs) input_bs2 = torch.randn(2, 3, 16, 16).cuda() _ = compiled1(input_bs2) del compiled1 @@ -246,7 +251,7 @@ def forward(self, x): self.assertTrue(os.path.isfile(self.cache_path)) # Second run with batch=3 — loads same cache - compiled2 = torchtrt.compile(model, **compile_kwargs) + compiled2 = _compile(model, **compile_kwargs) input_bs3 = torch.randn(3, 3, 16, 16).cuda() ref_bs3 = model(input_bs3) out_bs3 = compiled2(input_bs3) @@ -273,8 +278,10 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_warmup_timing(self): - """Measure cold vs warm cache inference time. Informational only — no strict pass/fail.""" + @parameterized.expand(_RUNTIMES) + def test_warmup_timing(self, _name, use_python_runtime): + """Measure cold vs warm cache inference time. Informational — no strict assertion.""" + _skip_if_cpp_unavailable(self, use_python_runtime) class MLP(torch.nn.Module): def __init__(self): @@ -290,16 +297,12 @@ def forward(self, x): input_tensor = torch.randn(16, 256).cuda() compile_kwargs = { - "ir": "dynamo", "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, + "use_python_runtime": use_python_runtime, "runtime_cache_path": self.cache_path, } - # Cold cache compilation + inference - compiled1 = torchtrt.compile(model, **compile_kwargs) + compiled1 = _compile(model, **compile_kwargs) torch.cuda.synchronize() start = time.perf_counter() _ = compiled1(input_tensor) @@ -309,19 +312,16 @@ def forward(self, x): gc.collect() torch._dynamo.reset() - # Warm cache compilation + inference - compiled2 = torchtrt.compile(model, **compile_kwargs) + compiled2 = _compile(model, **compile_kwargs) torch.cuda.synchronize() start = time.perf_counter() _ = compiled2(input_tensor) torch.cuda.synchronize() warm_time = time.perf_counter() - start - print(f"\n Cold cache first inference: {cold_time*1000:.1f}ms") - print(f" Warm cache first inference: {warm_time*1000:.1f}ms") - print(f" Speedup: {cold_time/warm_time:.2f}x") - - # No strict assertion — just log for visibility + print(f"\n [{_name}] Cold cache first inference: {cold_time*1000:.1f}ms") + print(f" [{_name}] Warm cache first inference: {warm_time*1000:.1f}ms") + print(f" [{_name}] Speedup: {cold_time/warm_time:.2f}x") self.assertTrue(True, "Timing test completed (informational)") diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index bad67db24c..dc23847870 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -1,5 +1,4 @@ import gc -import logging import os import shutil import tempfile @@ -7,10 +6,11 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH -from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity class SimpleModel(torch.nn.Module): @@ -18,31 +18,50 @@ def forward(self, x): return torch.relu(x) + 1.0 -class TwoLayerModel(torch.nn.Module): +class ConvModel(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(8, 8) + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) def forward(self, x): - return torch.relu(self.linear(x)) + return torch.relu(self.conv(x)) -def _compile_simple(runtime_cache_path=None): - """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs).""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] +def _fresh_conv_model_and_inputs(seed=0): + """Deterministic ConvModel + input pair for end-to-end cache tests on either runtime.""" + torch.manual_seed(seed) + return ConvModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] + + +def _compile(model, inputs, *, use_python_runtime, runtime_cache_path=None): + """Compile `model` through either runtime. Returns the compiled module.""" kwargs = { "ir": "dynamo", "inputs": inputs, "enabled_precisions": {torch.float32}, - "use_python_runtime": True, + "use_python_runtime": use_python_runtime, "min_block_size": 1, } if runtime_cache_path is not None: kwargs["runtime_cache_path"] = runtime_cache_path compiled = torchtrt.compile(model, **kwargs) torch._dynamo.reset() - return compiled, inputs + return compiled + + +def _compile_simple(runtime_cache_path=None): + """Compile the SimpleModel on the Python runtime (used by Python-only setup tests).""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + return ( + _compile( + model, + inputs, + use_python_runtime=True, + runtime_cache_path=runtime_cache_path, + ), + inputs, + ) def _find_python_trt_module(compiled): @@ -51,18 +70,23 @@ def _find_python_trt_module(compiled): PythonTorchTensorRTModule, ) - for name, mod in compiled.named_modules(): + for _name, mod in compiled.named_modules(): if isinstance(mod, PythonTorchTensorRTModule): return mod return None +# Parameterize end-to-end cache persistence tests over both runtime paths. The C++ +# variant is skipped inside the test body when the C++ runtime is not available. +_RUNTIMES = [("python", True), ("cpp", False)] + + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheSetup(TestCase): - """Tests that runtime config and cache are correctly created for RTX.""" + """Python-runtime-only setup checks: the compiled module exposes a live runtime cache.""" def test_runtime_config_created(self): compiled, _ = _compile_simple() @@ -77,7 +101,6 @@ def test_context_created_successfully(self): compiled, inputs = _compile_simple() mod = _find_python_trt_module(compiled) self.assertIsNotNone(mod.context, "execution context should be created") - # Verify inference works output = compiled(*[inp.clone() for inp in inputs]) self.assertEqual(output.shape, inputs[0].shape) @@ -102,7 +125,7 @@ def test_runtime_cache_path_custom(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCachePersistence(TestCase): - """Tests that runtime cache is correctly saved to and loaded from disk.""" + """Load-on-setup / save-on-destructor contract, exercised on both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -111,9 +134,20 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) - def test_cache_saved_on_del(self): - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) - # Run inference to populate the cache + def _skip_if_cpp_unavailable(self, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + self.skipTest("C++ runtime is not available") + + @parameterized.expand(_RUNTIMES) + def test_cache_saved_on_del(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) self.assertFalse( os.path.isfile(self.cache_path), @@ -126,8 +160,16 @@ def test_cache_saved_on_del(self): "Cache file should be created after module cleanup", ) - def test_cache_file_nonempty(self): - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) + @parameterized.expand(_RUNTIMES) + def test_cache_file_nonempty(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) del compiled gc.collect() @@ -137,30 +179,54 @@ def test_cache_file_nonempty(self): "Cache file should have nonzero size", ) - def test_cache_roundtrip(self): - """Compile, infer, save. Then compile again with same cache path and verify correctness.""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] - ref_output = model(*inputs) - - # First compilation — populates and saves cache - compiled1, _ = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled1(*[inp.clone() for inp in inputs]) + @parameterized.expand(_RUNTIMES) + def test_cache_roundtrip(self, _name, use_python_runtime): + """Populate + save, then recompile and confirm correctness against eager output.""" + self._skip_if_cpp_unavailable(use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + with torch.no_grad(): + ref_output = model(*inputs) + + compiled1 = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) + out1 = compiled1(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out1), + COSINE_THRESHOLD, + "First compiled output should match eager", + ) del compiled1 gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - # Second compilation — should load cached data - compiled2, _ = _compile_simple(runtime_cache_path=self.cache_path) - output = compiled2(*[inp.clone() for inp in inputs]) - max_diff = float(torch.max(torch.abs(ref_output - output))) - self.assertAlmostEqual( - max_diff, 0, places=3, msg="Output mismatch after cache roundtrip" + compiled2 = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) + out2 = compiled2(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out2), + COSINE_THRESHOLD, + "Second compiled output (warm cache) should still match eager", ) - def test_save_creates_directory(self): + @parameterized.expand(_RUNTIMES) + def test_save_creates_directory(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - compiled, inputs = _compile_simple(runtime_cache_path=nested_path) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=nested_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) del compiled gc.collect() @@ -175,7 +241,7 @@ def test_save_creates_directory(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheConcurrency(TestCase): - """Tests that file locking works for concurrent access.""" + """Tests that file locking works for concurrent access (Python runtime only).""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -191,7 +257,6 @@ def test_filelock_works(self): del compiled gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - # Verify we can acquire a lock on the same path (no deadlock) from filelock import FileLock lock = FileLock(self.cache_path + ".lock") @@ -201,14 +266,12 @@ def test_filelock_works(self): def test_sequential_save_load(self): """Two modules saving and loading from the same path should not corrupt data.""" - # First module saves compiled1, inputs = _compile_simple(runtime_cache_path=self.cache_path) _ = compiled1(*[inp.clone() for inp in inputs]) del compiled1 gc.collect() size1 = os.path.getsize(self.cache_path) - # Second module saves (overwrites) compiled2, inputs = _compile_simple(runtime_cache_path=self.cache_path) _ = compiled2(*[inp.clone() for inp in inputs]) del compiled2 @@ -227,7 +290,6 @@ class TestTimingCacheSkipped(TestCase): """Tests that timing cache is correctly skipped for RTX builds.""" def setUp(self): - # Clean up any pre-existing timing cache if os.path.isfile(TIMING_CACHE_PATH): os.remove(TIMING_CACHE_PATH) @@ -272,7 +334,6 @@ def test_no_runtime_config_for_standard_trt(self): ) def test_timing_cache_still_created(self): - # Clean up any pre-existing timing cache if os.path.isfile(TIMING_CACHE_PATH): os.remove(TIMING_CACHE_PATH) compiled, inputs = _compile_simple() @@ -283,5 +344,26 @@ def test_timing_cache_still_created(self): ) +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "New serialization indices are registered only on TensorRT-RTX builds", +) +class TestCppSerializationIndices(TestCase): + """Verify the new RTX-only C++ serialization indices are registered by the runtime.""" + + def test_new_indices_registered(self): + self.assertEqual(int(torch.ops.tensorrt.ABI_VERSION()), 9) + self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), 14) + self.assertEqual(int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), 11) + self.assertEqual( + int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), 12 + ) + self.assertEqual(int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), 13) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py new file mode 100644 index 0000000000..8a2968b0d8 --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -0,0 +1,116 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._defaults import CUDA_GRAPH_STRATEGY +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class CudaGraphModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv(x)) + + +def _compile_cpp(strategy): + model = CudaGraphModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + cuda_graph_strategy=strategy, + ) + torch._dynamo.reset() + return compiled, inputs + + +class TestCudaGraphStrategySettings(TestCase): + """Setting-level validation that runs on every build (RTX and non-RTX).""" + + def test_default_value(self): + settings = CompilationSettings() + self.assertEqual(settings.cuda_graph_strategy, CUDA_GRAPH_STRATEGY) + + def test_settable_values(self): + for value in ("disabled", "whole_graph_capture"): + settings = CompilationSettings(cuda_graph_strategy=value) + self.assertEqual(settings.cuda_graph_strategy, value) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy is a TensorRT-RTX feature", +) +class TestCudaGraphStrategyCpp(TestCase): + """End-to-end: compile + infer through the C++ runtime with each strategy.""" + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_disabled(self): + compiled, inputs = _compile_cpp("disabled") + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_whole_graph_capture(self): + compiled, inputs = _compile_cpp("whole_graph_capture") + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_whole_graph_capture_with_subgraph_cudagraphs(self): + """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors.""" + compiled, inputs = _compile_cpp("whole_graph_capture") + torchtrt.runtime.set_cudagraphs_mode(True) + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_repeated_inference(self): + """Repeated inference exercises the RTX-native capture/replay path.""" + compiled, inputs = _compile_cpp("whole_graph_capture") + ref = compiled(*[inp.clone() for inp in inputs]) + for _ in range(4): + out = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(out.shape, ref.shape) + self.assertTrue(torch.isfinite(out).all().item()) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestCudaGraphStrategyInvalidValue(TestCase): + """Invalid strategy names are rejected at engine-packing time.""" + + def test_invalid_strategy_raises(self): + model = CudaGraphModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + cuda_graph_strategy="not_a_real_strategy", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 8c0a12cbdf..d514be86d1 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -2,16 +2,29 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._settings import CompilationSettings +_STRATEGIES = [("lazy",), ("eager",), ("none",)] + class SimpleModel(torch.nn.Module): def forward(self, x): return torch.relu(x) + 1.0 +class DynamicConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) + self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv2(torch.relu(self.conv1(x)))) + + def _compile_simple(**extra_kwargs): """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" model = SimpleModel().eval().cuda() @@ -36,13 +49,34 @@ def _compile_simple(**extra_kwargs): return compiled +def _compile_cpp(strategy): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + + def _find_python_trt_module(compiled): """Walk the compiled graph module to find PythonTorchTensorRTModule instances.""" from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( PythonTorchTensorRTModule, ) - for name, mod in compiled.named_modules(): + for _name, mod in compiled.named_modules(): if isinstance(mod, PythonTorchTensorRTModule): return mod return None @@ -55,6 +89,12 @@ def _find_python_trt_module(compiled): class TestDynamicShapesKernelStrategySetup(TestCase): """Tests that the dynamic shapes kernel specialization strategy is correctly applied.""" + _EXPECTED_ENUM = { + "lazy": "LAZY", + "eager": "EAGER", + "none": "NONE", + } + def test_default_strategy_is_lazy(self): import tensorrt as trt @@ -67,28 +107,21 @@ def test_default_strategy_is_lazy(self): trt.DynamicShapesKernelSpecializationStrategy.LAZY, ) - def test_eager_strategy(self): + @parameterized.expand(_STRATEGIES) + def test_strategy_applied(self, strategy): import tensorrt as trt compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="eager" - ) - mod = _find_python_trt_module(compiled) - self.assertIsNotNone(mod) - self.assertEqual( - mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.EAGER, + dynamic_shapes_kernel_specialization_strategy=strategy ) - - def test_none_strategy(self): - import tensorrt as trt - - compiled = _compile_simple(dynamic_shapes_kernel_specialization_strategy="none") mod = _find_python_trt_module(compiled) self.assertIsNotNone(mod) self.assertEqual( mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.NONE, + getattr( + trt.DynamicShapesKernelSpecializationStrategy, + self._EXPECTED_ENUM[strategy], + ), ) def test_context_created_with_each_strategy(self): @@ -101,7 +134,6 @@ def test_context_created_with_each_strategy(self): self.assertIsNotNone( mod.context, f"Execution context should be created for {strategy}" ) - # Test inference with multiple dynamic batch sizes for bs in (1, 2, 4): output = compiled(torch.randn(bs, 3).cuda()) self.assertEqual(output.shape, (bs, 3)) @@ -137,10 +169,64 @@ def test_setting_ignored_on_non_rtx(self): mod.runtime_config, "runtime_config should be None for standard TRT", ) - # Inference should still work output = compiled(torch.randn(2, 3).cuda()) self.assertEqual(output.shape, (2, 3)) +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel strategy is a TensorRT-RTX feature", +) +class TestDynamicShapesKernelStrategyCpp(TestCase): + """End-to-end: compile + infer through the C++ runtime with each strategy.""" + + @parameterized.expand(_STRATEGIES) + def test_strategy_inference(self, strategy): + compiled = _compile_cpp(strategy) + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_dynamic_shape_with_eager(self): + """Exercise shape changes under eager kernel specialization.""" + compiled = _compile_cpp("eager") + for batch in (1, 2, 3, 4): + x = torch.randn(batch, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (batch, 8, 16, 16)) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestDynamicShapesKernelStrategyCppInvalidValue(TestCase): + """Invalid strategy names are rejected at engine-packing time on the C++ runtime path.""" + + def test_invalid_strategy_raises(self): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", + ) + + if __name__ == "__main__": run_tests()