Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 56 additions & 139 deletions modelopt/onnx/quantization/autotune/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
- TensorRTPyBenchmark: Uses TensorRT Python API for direct engine profiling
"""

import contextlib
import ctypes
import importlib.util
import os
Expand All @@ -40,6 +39,7 @@
from typing import Any

import numpy as np
import torch

from modelopt.onnx.logging_config import logger
from modelopt.onnx.quantization.ort_utils import _check_for_tensorrt
Expand All @@ -48,13 +48,7 @@
if TRT_AVAILABLE:
import tensorrt as trt

CUDART_AVAILABLE = importlib.util.find_spec("cuda") is not None
if CUDART_AVAILABLE:
try:
from cuda.bindings import runtime as cudart
except ImportError:
with contextlib.suppress(ImportError):
from cuda import cudart # deprecated: prefer cuda.bindings.runtime
TORCH_CUDA_AVAILABLE = torch.cuda.is_available()


def _validate_shape_range(min_shape: list, opt_shape: list, max_shape: list) -> None:
Expand Down Expand Up @@ -337,17 +331,17 @@ def __init__(
engine building. If None, no custom plugins are loaded.

Raises:
ImportError: If tensorrt or cuda-python (cudart) packages are not available.
ImportError: If tensorrt is not installed or if torch is not built with CUDA support.
FileNotFoundError: If a specified plugin library file does not exist.
RuntimeError: If plugin library loading fails.
"""
super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries)

if not TRT_AVAILABLE:
raise ImportError("TensorRT Python API not available. Please install tensorrt package.")
if not CUDART_AVAILABLE or cudart is None:
if not TORCH_CUDA_AVAILABLE:
raise ImportError(
"CUDA Runtime (cudart) not available. Please install cuda-python package: pip install cuda-python"
"PyTorch with CUDA support not available. Please install torch with CUDA: pip install torch"
)

self.trt_logger = trt.Logger(trt.Logger.WARNING)
Expand Down Expand Up @@ -527,107 +521,57 @@ def _build_engine(
del parser, network, config

@staticmethod
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray, Any]:
"""Allocate pinned host memory and return (host_ptr, array view, cuda error).
def _alloc_pinned_host(size: int, dtype: np.dtype) -> tuple[Any, np.ndarray]:
"""Allocate pinned host memory using PyTorch and return (tensor, numpy_view).

Returns:
(host_ptr, arr, err): On success err is cudaSuccess; on failure host_ptr/arr
may be None and err is the CUDA error code.
(host_tensor, arr): Pinned PyTorch tensor and a numpy view over it.
"""
dtype = np.dtype(dtype)
nbytes = size * dtype.itemsize
err, host_ptr = cudart.cudaMallocHost(nbytes)
if err != cudart.cudaError_t.cudaSuccess:
return (None, None, err)
addr = int(host_ptr) if hasattr(host_ptr, "__int__") else host_ptr
try:
ctype = np.ctypeslib.as_ctypes_type(dtype)
arr = np.ctypeslib.as_array((ctype * size).from_address(addr))
except NotImplementedError as e:
# float16/bfloat16 have no ctypes equivalent; use same-size type and view
if dtype.itemsize == 2:
ctype = ctypes.c_uint16
else:
raise TypeError(
f"Pinned host allocation for dtype {dtype} is not supported: "
"no ctypes mapping and no fallback for this itemsize"
) from e
arr = np.ctypeslib.as_array((ctype * size).from_address(addr)).view(dtype)
return (host_ptr, arr, cudart.cudaError_t.cudaSuccess)
torch_dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
host_tensor = torch.empty(int(size), dtype=torch_dtype).pin_memory()
return host_tensor, host_tensor.numpy()

@staticmethod
def _free_buffers(bufs: list[dict]) -> None:
"""Free host and device memory for a list of buffer dicts (host_ptr, device_ptr)."""
for buf in bufs:
if "host_ptr" in buf and buf["host_ptr"] is not None:
cudart.cudaFreeHost(buf["host_ptr"])
if "device_ptr" in buf and buf["device_ptr"] is not None:
cudart.cudaFree(buf["device_ptr"])
"""Release buffer references; PyTorch handles underlying memory deallocation."""
bufs.clear()

def _allocate_buffers(
self,
engine: "trt.ICudaEngine",
context: "trt.IExecutionContext",
) -> tuple[list[dict], list[dict], Any]:
"""Allocate host and device buffers for engine I/O and set tensor addresses.
) -> tuple[list[dict], list[dict]]:
"""Allocate pinned host and device tensors for engine I/O and set tensor addresses.

Args:
engine: Deserialized TensorRT engine.
context: Execution context with tensor shapes set.

Returns:
(inputs, outputs, cuda_error): On success cuda_error is cudaSuccess;
on failure inputs/outputs are empty and cuda_error is the failing CUDA error code.
(inputs, outputs): Lists of buffer dicts containing PyTorch tensors.
"""
inputs: list[dict] = []
outputs: list[dict] = []

for i in range(engine.num_io_tensors):
tensor_name = engine.get_tensor_name(i)
dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
np_dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
shape = context.get_tensor_shape(tensor_name)
size = int(trt.volume(shape))

size = trt.volume(shape)
nbytes = size * np.dtype(dtype).itemsize
host_tensor, host_mem = self._alloc_pinned_host(size, np_dtype)
torch_dtype = torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype
device_tensor = torch.empty(size, dtype=torch_dtype, device="cuda")

err, device_ptr = cudart.cudaMalloc(nbytes)
if err != cudart.cudaError_t.cudaSuccess:
self.logger.error(f"cudaMalloc failed: {err}")
self._free_buffers(inputs + outputs)
return ([], [], err)

host_ptr, host_mem, err = self._alloc_pinned_host(size, dtype)
if err != cudart.cudaError_t.cudaSuccess:
self.logger.error(f"cudaMallocHost failed: {err}")
cudart.cudaFree(device_ptr)
self._free_buffers(inputs + outputs)
return ([], [], err)
context.set_tensor_address(tensor_name, device_tensor.data_ptr())

if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
np.copyto(host_mem, np.random.randn(size).astype(dtype))
inputs.append(
{
"host_ptr": host_ptr,
"host": host_mem,
"device_ptr": device_ptr,
"nbytes": nbytes,
"name": tensor_name,
}
)
np.copyto(host_mem, np.random.randn(size).astype(np_dtype))
inputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name})
else:
outputs.append(
{
"host_ptr": host_ptr,
"host": host_mem,
"device_ptr": device_ptr,
"nbytes": nbytes,
"name": tensor_name,
}
)

context.set_tensor_address(tensor_name, int(device_ptr))
outputs.append({"host": host_tensor, "device": device_tensor, "name": tensor_name})

return (inputs, outputs, cudart.cudaError_t.cudaSuccess)
return (inputs, outputs)

def _setup_execution_context(
self, serialized_engine: bytes
Expand All @@ -652,55 +596,44 @@ def _run_warmup(
context: "trt.IExecutionContext",
inputs: list[dict],
outputs: list[dict],
stream_handle: Any,
stream: "torch.cuda.Stream",
) -> None:
"""Run warmup iterations to stabilize GPU state and cache."""
h2d = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
d2h = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
self.logger.debug(f"Running {self.warmup_runs} warmup iterations...")
for _ in range(self.warmup_runs):
for inp in inputs:
cudart.cudaMemcpyAsync(
inp["device_ptr"], inp["host_ptr"], inp["nbytes"], h2d, stream_handle
)
context.execute_async_v3(stream_handle)
for out in outputs:
cudart.cudaMemcpyAsync(
out["host_ptr"], out["device_ptr"], out["nbytes"], d2h, stream_handle
)
cudart.cudaStreamSynchronize(stream_handle)
with torch.cuda.stream(stream):
for _ in range(self.warmup_runs):
for inp in inputs:
inp["device"].copy_(inp["host"], non_blocking=True)
context.execute_async_v3(stream.cuda_stream)
for out in outputs:
out["host"].copy_(out["device"], non_blocking=True)
stream.synchronize()

def _run_timing(
self,
context: "trt.IExecutionContext",
inputs: list[dict],
outputs: list[dict],
stream_handle: Any,
stream: "torch.cuda.Stream",
) -> np.ndarray:
"""Run timing iterations and return per-run latencies in milliseconds."""
h2d = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
d2h = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
self.logger.debug(f"Running {self.timing_runs} timing iterations...")
latencies = []
for _ in range(self.timing_runs):
for inp in inputs:
cudart.cudaMemcpyAsync(
inp["device_ptr"], inp["host_ptr"], inp["nbytes"], h2d, stream_handle
)
with torch.cuda.stream(stream):
for _ in range(self.timing_runs):
for inp in inputs:
inp["device"].copy_(inp["host"], non_blocking=True)

cudart.cudaStreamSynchronize(stream_handle)
start = time.perf_counter()
context.execute_async_v3(stream_handle)
cudart.cudaStreamSynchronize(stream_handle)
end = time.perf_counter()
stream.synchronize()
start = time.perf_counter()
context.execute_async_v3(stream.cuda_stream)
stream.synchronize()
end = time.perf_counter()

latency_ms = (end - start) * 1000.0
latencies.append(latency_ms)
latencies.append((end - start) * 1000.0)

for out in outputs:
cudart.cudaMemcpyAsync(
out["host_ptr"], out["device_ptr"], out["nbytes"], d2h, stream_handle
)
for out in outputs:
out["host"].copy_(out["device"], non_blocking=True)

return np.array(latencies)

Expand All @@ -721,7 +654,7 @@ def run(
Measured median latency in milliseconds, or float("inf") on any error
(e.g. build failure, deserialization failure, buffer/stream allocation failure).
"""
serialized_engine = engine = context = stream_handle = None
serialized_engine = engine = context = stream = None
inputs, outputs = [], []

try:
Expand All @@ -733,19 +666,11 @@ def run(
if engine is None or context is None:
return float("inf")

inputs, outputs, alloc_err = self._allocate_buffers(engine, context)
if alloc_err != cudart.cudaError_t.cudaSuccess:
self.logger.error(f"Buffer allocation failed: {alloc_err}")
return float("inf")
inputs, outputs = self._allocate_buffers(engine, context)
stream = torch.cuda.Stream()

err, sh = cudart.cudaStreamCreate()
if err != cudart.cudaError_t.cudaSuccess:
self.logger.error(f"cudaStreamCreate failed: {err}")
return float("inf")
stream_handle = sh

self._run_warmup(context, inputs, outputs, stream_handle)
latencies = self._run_timing(context, inputs, outputs, stream_handle)
self._run_warmup(context, inputs, outputs, stream)
latencies = self._run_timing(context, inputs, outputs, stream)

median_latency = float(np.median(latencies))
mean_latency = float(np.mean(latencies))
Expand Down Expand Up @@ -788,17 +713,9 @@ def run(
return float("inf")
finally:
try:
self._free_buffers(inputs + outputs)
if stream_handle is not None:
cudart.cudaStreamDestroy(stream_handle)
del (
inputs,
outputs,
stream_handle,
context,
engine,
serialized_engine,
)
self._free_buffers(inputs)
self._free_buffers(outputs)
del inputs, outputs, stream, context, engine, serialized_engine
except Exception as cleanup_error:
self.logger.warning(f"Error during cleanup: {cleanup_error}")

Expand Down
Loading