From fff21afe47c7156c92a56a4ee51a321ea0fde79b Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 6 Mar 2026 13:49:54 -0500 Subject: [PATCH 1/6] Replace CUDA memory management from CUDART to Torch Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- .../onnx/quantization/autotune/benchmark.py | 195 ++++++------------ 1 file changed, 58 insertions(+), 137 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index 7613a119a..b90ad55ea 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -26,7 +26,6 @@ - TensorRTPyBenchmark: Uses TensorRT Python API for direct engine profiling """ -import contextlib import ctypes import importlib.util import os @@ -48,13 +47,13 @@ if TRT_AVAILABLE: import tensorrt as trt -CUDART_AVAILABLE = importlib.util.find_spec("cuda") is not None -if CUDART_AVAILABLE: - try: - from cuda.bindings import runtime as cudart - except ImportError: - with contextlib.suppress(ImportError): - from cuda import cudart # deprecated: prefer cuda.bindings.runtime +TORCH_CUDA_AVAILABLE = False +try: + import torch + + TORCH_CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + pass def _validate_shape_range(min_shape: list, opt_shape: list, max_shape: list) -> None: @@ -345,9 +344,9 @@ def __init__( if not TRT_AVAILABLE: raise ImportError("TensorRT Python API not available. Please install tensorrt package.") - if not CUDART_AVAILABLE or cudart is None: + if not TORCH_CUDA_AVAILABLE: raise ImportError( - "CUDA Runtime (cudart) not available. Please install cuda-python package: pip install cuda-python" + "PyTorch with CUDA support not available. Please install torch with CUDA: pip install torch" ) self.trt_logger = trt.Logger(trt.Logger.WARNING) @@ -527,107 +526,57 @@ def _build_engine( del parser, network, config @staticmethod - def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray, Any]: - """Allocate pinned host memory and return (host_ptr, array view, cuda error). + def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray]: + """Allocate pinned host memory using PyTorch and return (tensor, numpy_view). Returns: - (host_ptr, arr, err): On success err is cudaSuccess; on failure host_ptr/arr - may be None and err is the CUDA error code. + (host_tensor, arr): Pinned PyTorch tensor and a numpy view over it. """ - dtype = np.dtype(dtype) - nbytes = size * dtype.itemsize - err, host_ptr = cudart.cudaMallocHost(nbytes) - if err != cudart.cudaError_t.cudaSuccess: - return (None, None, err) - addr = int(host_ptr) if hasattr(host_ptr, "__int__") else host_ptr - try: - ctype = np.ctypeslib.as_ctypes_type(dtype) - arr = np.ctypeslib.as_array((ctype * size).from_address(addr)) - except NotImplementedError as e: - # float16/bfloat16 have no ctypes equivalent; use same-size type and view - if dtype.itemsize == 2: - ctype = ctypes.c_uint16 - else: - raise TypeError( - f"Pinned host allocation for dtype {dtype} is not supported: " - "no ctypes mapping and no fallback for this itemsize" - ) from e - arr = np.ctypeslib.as_array((ctype * size).from_address(addr)).view(dtype) - return (host_ptr, arr, cudart.cudaError_t.cudaSuccess) + torch_dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype + host_tensor = torch.empty(int(size), dtype=torch_dtype).pin_memory() + return host_tensor, host_tensor.numpy() @staticmethod def _free_buffers(bufs: list[dict]) -> None: - """Free host and device memory for a list of buffer dicts (host_ptr, device_ptr).""" - for buf in bufs: - if "host_ptr" in buf and buf["host_ptr"] is not None: - cudart.cudaFreeHost(buf["host_ptr"]) - if "device_ptr" in buf and buf["device_ptr"] is not None: - cudart.cudaFree(buf["device_ptr"]) + """Release buffer references; PyTorch handles underlying memory deallocation.""" + bufs.clear() def _allocate_buffers( self, engine: "trt.ICudaEngine", context: "trt.IExecutionContext", - ) -> tuple[list[dict], list[dict], Any]: - """Allocate host and device buffers for engine I/O and set tensor addresses. + ) -> tuple[list[dict], list[dict]]: + """Allocate pinned host and device tensors for engine I/O and set tensor addresses. Args: engine: Deserialized TensorRT engine. context: Execution context with tensor shapes set. Returns: - (inputs, outputs, cuda_error): On success cuda_error is cudaSuccess; - on failure inputs/outputs are empty and cuda_error is the failing CUDA error code. + (inputs, outputs): Lists of buffer dicts containing PyTorch tensors. """ inputs: list[dict] = [] outputs: list[dict] = [] for i in range(engine.num_io_tensors): tensor_name = engine.get_tensor_name(i) - dtype = trt.nptype(engine.get_tensor_dtype(tensor_name)) + np_dtype = trt.nptype(engine.get_tensor_dtype(tensor_name)) shape = context.get_tensor_shape(tensor_name) + size = int(trt.volume(shape)) - size = trt.volume(shape) - nbytes = size * np.dtype(dtype).itemsize + host_tensor, host_mem = self._alloc_pinned_host(size, np_dtype) + torch_dtype = torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype + device_tensor = torch.empty(size, dtype=torch_dtype, device="cuda") - err, device_ptr = cudart.cudaMalloc(nbytes) - if err != cudart.cudaError_t.cudaSuccess: - self.logger.error(f"cudaMalloc failed: {err}") - self._free_buffers(inputs + outputs) - return ([], [], err) - - host_ptr, host_mem, err = self._alloc_pinned_host(size, dtype) - if err != cudart.cudaError_t.cudaSuccess: - self.logger.error(f"cudaMallocHost failed: {err}") - cudart.cudaFree(device_ptr) - self._free_buffers(inputs + outputs) - return ([], [], err) + context.set_tensor_address(tensor_name, device_tensor.data_ptr()) if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: - np.copyto(host_mem, np.random.randn(size).astype(dtype)) - inputs.append( - { - "host_ptr": host_ptr, - "host": host_mem, - "device_ptr": device_ptr, - "nbytes": nbytes, - "name": tensor_name, - } - ) + np.copyto(host_mem, np.random.randn(size).astype(np_dtype)) + inputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name}) else: - outputs.append( - { - "host_ptr": host_ptr, - "host": host_mem, - "device_ptr": device_ptr, - "nbytes": nbytes, - "name": tensor_name, - } - ) - - context.set_tensor_address(tensor_name, int(device_ptr)) + outputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name}) - return (inputs, outputs, cudart.cudaError_t.cudaSuccess) + return (inputs, outputs) def _setup_execution_context( self, serialized_engine: bytes @@ -652,55 +601,44 @@ def _run_warmup( context: "trt.IExecutionContext", inputs: list[dict], outputs: list[dict], - stream_handle: Any, + stream: "torch.cuda.Stream", ) -> None: """Run warmup iterations to stabilize GPU state and cache.""" - h2d = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice - d2h = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost self.logger.debug(f"Running {self.warmup_runs} warmup iterations...") - for _ in range(self.warmup_runs): - for inp in inputs: - cudart.cudaMemcpyAsync( - inp["device_ptr"], inp["host_ptr"], inp["nbytes"], h2d, stream_handle - ) - context.execute_async_v3(stream_handle) - for out in outputs: - cudart.cudaMemcpyAsync( - out["host_ptr"], out["device_ptr"], out["nbytes"], d2h, stream_handle - ) - cudart.cudaStreamSynchronize(stream_handle) + with torch.cuda.stream(stream): + for _ in range(self.warmup_runs): + for inp in inputs: + inp["device"].copy_(inp["host"], non_blocking=True) + context.execute_async_v3(stream.cuda_stream) + for out in outputs: + out["host"].copy_(out["device"], non_blocking=True) + stream.synchronize() def _run_timing( self, context: "trt.IExecutionContext", inputs: list[dict], outputs: list[dict], - stream_handle: Any, + stream: "torch.cuda.Stream", ) -> np.ndarray: """Run timing iterations and return per-run latencies in milliseconds.""" - h2d = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice - d2h = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost self.logger.debug(f"Running {self.timing_runs} timing iterations...") latencies = [] - for _ in range(self.timing_runs): - for inp in inputs: - cudart.cudaMemcpyAsync( - inp["device_ptr"], inp["host_ptr"], inp["nbytes"], h2d, stream_handle - ) + with torch.cuda.stream(stream): + for _ in range(self.timing_runs): + for inp in inputs: + inp["device"].copy_(inp["host"], non_blocking=True) - cudart.cudaStreamSynchronize(stream_handle) - start = time.perf_counter() - context.execute_async_v3(stream_handle) - cudart.cudaStreamSynchronize(stream_handle) - end = time.perf_counter() + stream.synchronize() + start = time.perf_counter() + context.execute_async_v3(stream.cuda_stream) + stream.synchronize() + end = time.perf_counter() - latency_ms = (end - start) * 1000.0 - latencies.append(latency_ms) + latencies.append((end - start) * 1000.0) - for out in outputs: - cudart.cudaMemcpyAsync( - out["host_ptr"], out["device_ptr"], out["nbytes"], d2h, stream_handle - ) + for out in outputs: + out["host"].copy_(out["device"], non_blocking=True) return np.array(latencies) @@ -721,7 +659,7 @@ def run( Measured median latency in milliseconds, or float("inf") on any error (e.g. build failure, deserialization failure, buffer/stream allocation failure). """ - serialized_engine = engine = context = stream_handle = None + serialized_engine = engine = context = stream = None inputs, outputs = [], [] try: @@ -733,19 +671,11 @@ def run( if engine is None or context is None: return float("inf") - inputs, outputs, alloc_err = self._allocate_buffers(engine, context) - if alloc_err != cudart.cudaError_t.cudaSuccess: - self.logger.error(f"Buffer allocation failed: {alloc_err}") - return float("inf") - - err, sh = cudart.cudaStreamCreate() - if err != cudart.cudaError_t.cudaSuccess: - self.logger.error(f"cudaStreamCreate failed: {err}") - return float("inf") - stream_handle = sh + inputs, outputs = self._allocate_buffers(engine, context) + stream = torch.cuda.Stream() - self._run_warmup(context, inputs, outputs, stream_handle) - latencies = self._run_timing(context, inputs, outputs, stream_handle) + self._run_warmup(context, inputs, outputs, stream) + latencies = self._run_timing(context, inputs, outputs, stream) median_latency = float(np.median(latencies)) mean_latency = float(np.mean(latencies)) @@ -789,16 +719,7 @@ def run( finally: try: self._free_buffers(inputs + outputs) - if stream_handle is not None: - cudart.cudaStreamDestroy(stream_handle) - del ( - inputs, - outputs, - stream_handle, - context, - engine, - serialized_engine, - ) + del inputs, outputs, stream, context, engine, serialized_engine except Exception as cleanup_error: self.logger.warning(f"Error during cleanup: {cleanup_error}") From 2864c6e1e48a5bf24b80f0c2de8a4d553a58ce72 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 6 Mar 2026 14:29:12 -0500 Subject: [PATCH 2/6] Added benchmark unittests Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- .../quantization/autotune/test_benchmark.py | 344 ++++++++++++++++-- 1 file changed, 323 insertions(+), 21 deletions(-) diff --git a/tests/gpu/onnx/quantization/autotune/test_benchmark.py b/tests/gpu/onnx/quantization/autotune/test_benchmark.py index 7e7875010..e739a20e3 100644 --- a/tests/gpu/onnx/quantization/autotune/test_benchmark.py +++ b/tests/gpu/onnx/quantization/autotune/test_benchmark.py @@ -13,17 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""GPU tests for autotune Benchmark (TrtExecBenchmark, TensorRTPyBenchmark).""" +"""Tests for autotune benchmark (TrtExecBenchmark, TensorRTPyBenchmark). + +Covers: +- Pure-Python logic: _validate_shape_range, _free_buffers, __init__ import guards. +- Mocked subprocess: TrtExecBenchmark.run error paths and latency parsing. +- PyTorch CUDA management: pinned host allocation, stream creation, H2D/D2H copy + loops in _run_warmup and _run_timing (mocked TRT context). Requires CUDA. +- Full integration: TensorRTPyBenchmark and TrtExecBenchmark end-to-end. + Requires TensorRT and trtexec respectively. +""" import contextlib import os import shutil import tempfile +from unittest.mock import MagicMock, patch +import numpy as np import pytest +import torch from _test_utils.onnx.quantization.autotune.models import _create_simple_conv_onnx_model +import modelopt.onnx.quantization.autotune.benchmark as bm from modelopt.onnx.quantization.autotune import TensorRTPyBenchmark, TrtExecBenchmark +from modelopt.onnx.quantization.autotune.benchmark import _validate_shape_range + +_requires_cuda = pytest.mark.skipif(not bm.TORCH_CUDA_AVAILABLE, reason="CUDA not available") + + +# --- fixtures --- @pytest.fixture @@ -41,19 +60,307 @@ def simple_conv_model_path(simple_conv_model_bytes, tmp_path): return str(path) +@pytest.fixture +def trtexec_bench(tmp_path): + return TrtExecBenchmark( + timing_cache_file=str(tmp_path / "cache.bin"), + warmup_runs=1, + timing_runs=2, + ) + + +# --- helpers --- + + +def _make_bench(warmup_runs=2, timing_runs=3): + """Instantiate TensorRTPyBenchmark without triggering __init__ guards.""" + bench = bm.TensorRTPyBenchmark.__new__(bm.TensorRTPyBenchmark) + bench.warmup_runs = warmup_runs + bench.timing_runs = timing_runs + bench.logger = MagicMock() + return bench + + +def _make_buffers(size=4): + """Return (inputs, outputs) using real pinned host + CUDA device tensors.""" + host_in = torch.ones(size).pin_memory() + device_in = torch.zeros(size, device="cuda") + host_out = torch.zeros(size).pin_memory() + device_out = torch.ones(size, device="cuda") * 2.0 + inputs = [{"host": host_in, "device": device_in, "name": "x"}] + outputs = [{"host": host_out, "device": device_out, "name": "y"}] + return inputs, outputs + + +# --- _validate_shape_range --- + + +def test_validate_shape_range_valid(): + _validate_shape_range([1, 1], [2, 2], [4, 4]) + + +def test_validate_shape_range_equal_bounds(): + _validate_shape_range([2, 3], [2, 3], [2, 3]) + + +@pytest.mark.parametrize( + ("min_s", "opt_s", "max_s"), + [ + ([1, 1], [2, 2], [4]), + ([1], [2, 2], [4, 4]), + ], +) +def test_validate_shape_range_length_mismatch(min_s, opt_s, max_s): + with pytest.raises(ValueError, match="same length"): + _validate_shape_range(min_s, opt_s, max_s) + + +@pytest.mark.parametrize( + ("min_s", "opt_s", "max_s"), + [ + ([3], [2], [4]), # min > opt + ([1], [5], [4]), # opt > max + ([5], [3], [2]), # both violated + ], +) +def test_validate_shape_range_invalid_order(min_s, opt_s, max_s): + with pytest.raises(ValueError, match="min <= opt <= max"): + _validate_shape_range(min_s, opt_s, max_s) + + +# --- TensorRTPyBenchmark._free_buffers --- + + +def test_free_buffers_clears_list(): + bufs = [{"host": object(), "device": object(), "name": "x"}] + bm.TensorRTPyBenchmark._free_buffers(bufs) + assert bufs == [] + + +def test_free_buffers_empty_list(): + bufs = [] + bm.TensorRTPyBenchmark._free_buffers(bufs) + assert bufs == [] + + +# --- TensorRTPyBenchmark.__init__ import guards --- + + +def test_tensorrt_py_benchmark_raises_without_trt(): + with patch.object(bm, "TRT_AVAILABLE", False), pytest.raises(ImportError, match="TensorRT"): + bm.TensorRTPyBenchmark() + + +def test_tensorrt_py_benchmark_raises_without_torch_cuda(): + # TRT guard passes; TORCH_CUDA guard fires before any trt symbol is used. + with ( + patch.object(bm, "TRT_AVAILABLE", True), + patch.object(bm, "TORCH_CUDA_AVAILABLE", False), + pytest.raises(ImportError, match="PyTorch"), + ): + bm.TensorRTPyBenchmark() + + +# --- TrtExecBenchmark.run (mocked) --- + + +def test_trtexec_run_returns_inf_on_nonzero_returncode(trtexec_bench, tmp_path): + model_path = str(tmp_path / "model.onnx") + (tmp_path / "model.onnx").write_bytes(b"") + + mock_result = MagicMock() + mock_result.returncode = 1 + mock_result.stderr = "engine build error" + mock_result.stdout = "" + + with patch("subprocess.run", return_value=mock_result): + assert trtexec_bench.run(model_path) == float("inf") + + +def test_trtexec_run_returns_inf_when_latency_not_parsed(trtexec_bench, tmp_path): + model_path = str(tmp_path / "model.onnx") + (tmp_path / "model.onnx").write_bytes(b"") + + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "Build complete. No latency line here." + mock_result.stderr = "" + + with patch("subprocess.run", return_value=mock_result): + assert trtexec_bench.run(model_path) == float("inf") + + +def test_trtexec_run_returns_parsed_latency(trtexec_bench, tmp_path): + model_path = str(tmp_path / "model.onnx") + (tmp_path / "model.onnx").write_bytes(b"") + + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "[I] Latency: min = 2.50 ms, max = 4.00 ms, median = 3.14 ms" + mock_result.stderr = "" + + with patch("subprocess.run", return_value=mock_result): + assert trtexec_bench.run(model_path) == pytest.approx(3.14) + + +def test_trtexec_run_returns_inf_when_binary_not_found(trtexec_bench, tmp_path): + model_path = str(tmp_path / "model.onnx") + (tmp_path / "model.onnx").write_bytes(b"") + + with patch("subprocess.run", side_effect=FileNotFoundError): + assert trtexec_bench.run(model_path) == float("inf") + + +def test_trtexec_run_accepts_bytes_input(trtexec_bench): + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "[I] Latency: min = 4.00 ms, max = 6.00 ms, median = 5.00 ms" + mock_result.stderr = "" + + with patch("subprocess.run", return_value=mock_result): + assert trtexec_bench.run(b"fake onnx bytes") == pytest.approx(5.0) + + +# --- TensorRTPyBenchmark._alloc_pinned_host --- + + +@_requires_cuda +def test_alloc_pinned_host_returns_pinned_tensor_and_numpy_view(): + size = 16 + host_tensor, arr = bm.TensorRTPyBenchmark._alloc_pinned_host(size, np.float32) + + assert host_tensor.is_pinned() + assert isinstance(arr, np.ndarray) + assert arr.shape == (size,) + assert arr.dtype == np.float32 + + # Tensor and numpy array share the same memory. + host_tensor[0] = 42.0 + assert arr[0] == pytest.approx(42.0) + + +@_requires_cuda +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.int8, np.int32]) +def test_alloc_pinned_host_dtype(dtype): + host_tensor, arr = bm.TensorRTPyBenchmark._alloc_pinned_host(8, dtype) + assert host_tensor.is_pinned() + assert arr.dtype == dtype + assert arr.shape == (8,) + + +# --- torch.cuda.Stream integration points --- + + +@_requires_cuda +def test_cuda_stream_handle_is_integer(): + """stream.cuda_stream must be an int — this is passed directly to TRT execute_async_v3.""" + stream = torch.cuda.Stream() + assert isinstance(stream.cuda_stream, int) + + +# --- TensorRTPyBenchmark._run_warmup --- + + +@_requires_cuda +def test_run_warmup_calls_execute_async_v3_correct_times(): + bench = _make_bench(warmup_runs=3) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + bench._run_warmup(context, inputs, outputs, stream) + + assert context.execute_async_v3.call_count == 3 + + +@_requires_cuda +def test_run_warmup_passes_stream_handle_to_trt(): + bench = _make_bench(warmup_runs=1) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + bench._run_warmup(context, inputs, outputs, stream) + + context.execute_async_v3.assert_called_once_with(stream.cuda_stream) + + +@_requires_cuda +def test_run_warmup_copies_host_to_device(): + bench = _make_bench(warmup_runs=1) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + bench._run_warmup(context, inputs, outputs, stream) + torch.cuda.synchronize() + + # host_in was all-ones; device_in should now match. + assert torch.allclose(inputs[0]["device"].cpu(), inputs[0]["host"]) + + +# --- TensorRTPyBenchmark._run_timing --- + + +@_requires_cuda +def test_run_timing_returns_correct_number_of_latencies(): + bench = _make_bench(timing_runs=4) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + latencies = bench._run_timing(context, inputs, outputs, stream) + + assert len(latencies) == 4 + + +@_requires_cuda +def test_run_timing_latencies_are_non_negative(): + bench = _make_bench(timing_runs=3) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + latencies = bench._run_timing(context, inputs, outputs, stream) + + assert all(lat >= 0.0 for lat in latencies) + + +@_requires_cuda +def test_run_timing_calls_execute_async_v3_correct_times(): + bench = _make_bench(timing_runs=3) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + bench._run_timing(context, inputs, outputs, stream) + + assert context.execute_async_v3.call_count == 3 + + +@_requires_cuda +def test_run_timing_passes_stream_handle_to_trt(): + bench = _make_bench(timing_runs=1) + inputs, outputs = _make_buffers() + context = MagicMock() + stream = torch.cuda.Stream() + + bench._run_timing(context, inputs, outputs, stream) + + context.execute_async_v3.assert_called_once_with(stream.cuda_stream) + + +# --- TensorRTPyBenchmark (integration) --- + + class TestTensorRTPyBenchmark: - """Tests for TensorRTPyBenchmark (TensorRT Python API + cudart).""" + """End-to-end tests for TensorRTPyBenchmark. Requires TensorRT and CUDA.""" @pytest.fixture(autouse=True) - def _require_tensorrt_and_cudart(self): + def _require_tensorrt(self): pytest.importorskip("tensorrt") - try: - from cuda.bindings import runtime # noqa: F401 - except ImportError: - try: - from cuda import cudart # noqa: F401 # deprecated: prefer cuda.bindings.runtime - except ImportError: - pytest.skip("cuda-python (cudart) not available", allow_module_level=False) + if not bm.TORCH_CUDA_AVAILABLE: + pytest.skip("CUDA not available") def test_run_with_bytes(self, simple_conv_model_bytes): """TensorRTPyBenchmark accepts model bytes and returns finite latency.""" @@ -79,8 +386,11 @@ def test_callable(self, simple_conv_model_bytes): assert latency_ms > 0 +# --- TrtExecBenchmark (integration) --- + + class TestTrtExecBenchmark: - """Tests for TrtExecBenchmark (trtexec CLI).""" + """End-to-end tests for TrtExecBenchmark. Requires trtexec in PATH.""" @pytest.fixture(autouse=True) def _require_trtexec(self): @@ -92,11 +402,7 @@ def test_run_with_path(self, simple_conv_model_path): with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as f: cache_path = f.name try: - benchmark = TrtExecBenchmark( - timing_cache_file=cache_path, - warmup_runs=1, - timing_runs=2, - ) + benchmark = TrtExecBenchmark(timing_cache_file=cache_path, warmup_runs=1, timing_runs=2) latency_ms = benchmark.run(simple_conv_model_path) assert isinstance(latency_ms, float) assert latency_ms > 0 @@ -110,11 +416,7 @@ def test_run_with_bytes(self, simple_conv_model_bytes): with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as f: cache_path = f.name try: - benchmark = TrtExecBenchmark( - timing_cache_file=cache_path, - warmup_runs=1, - timing_runs=2, - ) + benchmark = TrtExecBenchmark(timing_cache_file=cache_path, warmup_runs=1, timing_runs=2) latency_ms = benchmark.run(simple_conv_model_bytes) assert isinstance(latency_ms, float) assert latency_ms > 0 From 9e1400dbb38afd277360b24f0463248a4f0a6404 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:44:11 -0500 Subject: [PATCH 3/6] Simplify torch import Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/autotune/benchmark.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index b90ad55ea..342f3cadc 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -39,6 +39,7 @@ from typing import Any import numpy as np +import torch from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.ort_utils import _check_for_tensorrt @@ -47,13 +48,7 @@ if TRT_AVAILABLE: import tensorrt as trt -TORCH_CUDA_AVAILABLE = False -try: - import torch - - TORCH_CUDA_AVAILABLE = torch.cuda.is_available() -except ImportError: - pass +TORCH_CUDA_AVAILABLE = torch.cuda.is_available() def _validate_shape_range(min_shape: list, opt_shape: list, max_shape: list) -> None: From febb5bd856ecb13f967b0854d489f844ec894672 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:44:37 -0500 Subject: [PATCH 4/6] Remove requires_cuda check from tests Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- .../onnx/quantization/autotune/test_benchmark.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/gpu/onnx/quantization/autotune/test_benchmark.py b/tests/gpu/onnx/quantization/autotune/test_benchmark.py index e739a20e3..c630baff6 100644 --- a/tests/gpu/onnx/quantization/autotune/test_benchmark.py +++ b/tests/gpu/onnx/quantization/autotune/test_benchmark.py @@ -39,9 +39,6 @@ from modelopt.onnx.quantization.autotune import TensorRTPyBenchmark, TrtExecBenchmark from modelopt.onnx.quantization.autotune.benchmark import _validate_shape_range -_requires_cuda = pytest.mark.skipif(not bm.TORCH_CUDA_AVAILABLE, reason="CUDA not available") - - # --- fixtures --- @@ -224,7 +221,6 @@ def test_trtexec_run_accepts_bytes_input(trtexec_bench): # --- TensorRTPyBenchmark._alloc_pinned_host --- -@_requires_cuda def test_alloc_pinned_host_returns_pinned_tensor_and_numpy_view(): size = 16 host_tensor, arr = bm.TensorRTPyBenchmark._alloc_pinned_host(size, np.float32) @@ -239,7 +235,6 @@ def test_alloc_pinned_host_returns_pinned_tensor_and_numpy_view(): assert arr[0] == pytest.approx(42.0) -@_requires_cuda @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.int8, np.int32]) def test_alloc_pinned_host_dtype(dtype): host_tensor, arr = bm.TensorRTPyBenchmark._alloc_pinned_host(8, dtype) @@ -251,7 +246,6 @@ def test_alloc_pinned_host_dtype(dtype): # --- torch.cuda.Stream integration points --- -@_requires_cuda def test_cuda_stream_handle_is_integer(): """stream.cuda_stream must be an int — this is passed directly to TRT execute_async_v3.""" stream = torch.cuda.Stream() @@ -261,7 +255,6 @@ def test_cuda_stream_handle_is_integer(): # --- TensorRTPyBenchmark._run_warmup --- -@_requires_cuda def test_run_warmup_calls_execute_async_v3_correct_times(): bench = _make_bench(warmup_runs=3) inputs, outputs = _make_buffers() @@ -273,7 +266,6 @@ def test_run_warmup_calls_execute_async_v3_correct_times(): assert context.execute_async_v3.call_count == 3 -@_requires_cuda def test_run_warmup_passes_stream_handle_to_trt(): bench = _make_bench(warmup_runs=1) inputs, outputs = _make_buffers() @@ -285,7 +277,6 @@ def test_run_warmup_passes_stream_handle_to_trt(): context.execute_async_v3.assert_called_once_with(stream.cuda_stream) -@_requires_cuda def test_run_warmup_copies_host_to_device(): bench = _make_bench(warmup_runs=1) inputs, outputs = _make_buffers() @@ -302,7 +293,6 @@ def test_run_warmup_copies_host_to_device(): # --- TensorRTPyBenchmark._run_timing --- -@_requires_cuda def test_run_timing_returns_correct_number_of_latencies(): bench = _make_bench(timing_runs=4) inputs, outputs = _make_buffers() @@ -314,7 +304,6 @@ def test_run_timing_returns_correct_number_of_latencies(): assert len(latencies) == 4 -@_requires_cuda def test_run_timing_latencies_are_non_negative(): bench = _make_bench(timing_runs=3) inputs, outputs = _make_buffers() @@ -326,7 +315,6 @@ def test_run_timing_latencies_are_non_negative(): assert all(lat >= 0.0 for lat in latencies) -@_requires_cuda def test_run_timing_calls_execute_async_v3_correct_times(): bench = _make_bench(timing_runs=3) inputs, outputs = _make_buffers() @@ -338,7 +326,6 @@ def test_run_timing_calls_execute_async_v3_correct_times(): assert context.execute_async_v3.call_count == 3 -@_requires_cuda def test_run_timing_passes_stream_handle_to_trt(): bench = _make_bench(timing_runs=1) inputs, outputs = _make_buffers() From d05634a5d99003ebb03629137bfc0a941b0ccee6 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:48:13 -0500 Subject: [PATCH 5/6] Resolve copilot comments Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/autotune/benchmark.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py index 342f3cadc..6278eb43e 100644 --- a/modelopt/onnx/quantization/autotune/benchmark.py +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -331,7 +331,7 @@ def __init__( engine building. If None, no custom plugins are loaded. Raises: - ImportError: If tensorrt or cuda-python (cudart) packages are not available. + ImportError: If tensorrt is not installed or if torch is not built with CUDA support. FileNotFoundError: If a specified plugin library file does not exist. RuntimeError: If plugin library loading fails. """ @@ -713,7 +713,8 @@ def run( return float("inf") finally: try: - self._free_buffers(inputs + outputs) + self._free_buffers(inputs) + self._free_buffers(outputs) del inputs, outputs, stream, context, engine, serialized_engine except Exception as cleanup_error: self.logger.warning(f"Error during cleanup: {cleanup_error}") From ee9ba37e3cd45e8c662da0d6cffb47585e750ce8 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:13:32 -0400 Subject: [PATCH 6/6] Rename _require_trt_and_torch_cuda function Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/gpu/onnx/quantization/autotune/test_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu/onnx/quantization/autotune/test_benchmark.py b/tests/gpu/onnx/quantization/autotune/test_benchmark.py index c630baff6..925d45fff 100644 --- a/tests/gpu/onnx/quantization/autotune/test_benchmark.py +++ b/tests/gpu/onnx/quantization/autotune/test_benchmark.py @@ -344,7 +344,7 @@ class TestTensorRTPyBenchmark: """End-to-end tests for TensorRTPyBenchmark. Requires TensorRT and CUDA.""" @pytest.fixture(autouse=True) - def _require_tensorrt(self): + def _require_tensorrt_and_torch_cuda(self): pytest.importorskip("tensorrt") if not bm.TORCH_CUDA_AVAILABLE: pytest.skip("CUDA not available")