From 0af1296a499ff9883000c00e7716842d132737e3 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 23 Apr 2026 17:32:45 -0700 Subject: [PATCH 1/3] engine caching for distributed --- py/torch_tensorrt/distributed/__init__.py | 4 + py/torch_tensorrt/distributed/_distributed.py | 93 +++++ py/torch_tensorrt/distributed/_lock.py | 189 +++++++++ .../dynamo/conversion/_conversion.py | 49 ++- .../test_distributed_engine_cache.py | 381 ++++++++++++++++++ 5 files changed, 714 insertions(+), 2 deletions(-) create mode 100644 py/torch_tensorrt/distributed/_lock.py create mode 100644 tests/py/dynamo/distributed/test_distributed_engine_cache.py diff --git a/py/torch_tensorrt/distributed/__init__.py b/py/torch_tensorrt/distributed/__init__.py index fdb95b02b1..15194667ff 100644 --- a/py/torch_tensorrt/distributed/__init__.py +++ b/py/torch_tensorrt/distributed/__init__.py @@ -1,7 +1,11 @@ from torch_tensorrt.distributed._distributed import ( # noqa: F401 distributed_context, + is_distributed_caching_enabled, set_distributed_mode, + signal_distributed_engine_build_complete, + wait_for_distributed_engine_build, ) +from torch_tensorrt.distributed._lock import DistributedFileLock # noqa: F401 from torch_tensorrt.distributed._nccl_utils import ( # noqa: F401 setup_nccl_for_torch_tensorrt, ) diff --git a/py/torch_tensorrt/distributed/_distributed.py b/py/torch_tensorrt/distributed/_distributed.py index 16714e8a1e..b0e0577df0 100644 --- a/py/torch_tensorrt/distributed/_distributed.py +++ b/py/torch_tensorrt/distributed/_distributed.py @@ -205,3 +205,96 @@ def set_distributed_mode(group: Any, module: nn.Module) -> None: seen.add(id(engine)) if getattr(engine, "requires_native_multidevice", False): engine.set_group_name(group_name) + + +def is_distributed_caching_enabled( + is_engine_caching_supported: bool, + cache_built_engines: bool, + reuse_cached_engines: bool, +) -> bool: + """Check if distributed engine cache coordination should be used. + + Returns True when all conditions are met: + - Engine caching is supported (cache exists, refit available, mutable weights) + - User enabled both cache_built_engines and reuse_cached_engines + - Running in a distributed environment with world_size > 1 + + When True, only one rank builds the TRT engine and caches it. + Other ranks wait and load from the shared DiskEngineCache. + """ + return ( + is_engine_caching_supported + and cache_built_engines + and reuse_cached_engines + and dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ) + + +def wait_for_distributed_engine_build( + pull_fn: Any, + cache_dir: str, + hash_val: str, + poll_interval: float = 0.5, + timeout: float = 600.0, +) -> Any: + """Non-building rank: poll for cached engine file, then load from cache. + + Called when this rank failed to acquire the build lock, meaning another + rank is building the engine. Polls the filesystem for the cached engine + file instead of using NCCL collectives (which are unreliable inside + the TRT compilation path due to aot_autograd/CUDA stream conflicts). + + Args: + pull_fn: Zero-arg callable (e.g. functools.partial) that loads the + engine from cache. Returns SerializedInterpreterResult on + hit, None on miss. + cache_dir: Shared engine cache directory path. + hash_val: Engine hash for this compilation. + poll_interval: Seconds between filesystem checks (default 0.5s). + timeout: Maximum seconds to wait before giving up (default 600s). + + Returns: + SerializedInterpreterResult on cache hit, None on timeout. 
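+
+    Example (a sketch mirroring the call site in _conversion.py; the
+    pull_cached_engine arguments shown are the ones passed there)::
+
+        pull_fn = functools.partial(
+            pull_cached_engine,
+            hash_val,
+            module,
+            engine_cache,
+            settings,
+            inputs,
+            symbolic_shape_expressions,
+        )
+        cached = wait_for_distributed_engine_build(
+            pull_fn, engine_cache.engine_cache_dir, hash_val
+        )
+        if cached is not None:
+            return cached  # engine loaded from shared cache; skip local build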
+ """ + import logging + import os + import time + + logger = logging.getLogger(__name__) + + blob_path = os.path.join(cache_dir, hash_val, "blob.bin") + logger.info(f"Polling for cached engine: {blob_path}") + + elapsed = 0.0 + while not os.path.exists(blob_path): + time.sleep(poll_interval) + elapsed += poll_interval + if elapsed >= timeout: + logger.warning( + f"Polling timed out after {timeout:.0f}s — building engine locally" + ) + return None + + logger.info(f"Cached engine found after {elapsed:.1f}s — loading from cache") + cached = pull_fn() + if cached is not None: + return cached + + logger.warning("Cache file exists but pull_cached_engine failed — building locally") + return None + + +def signal_distributed_engine_build_complete(lock: Any) -> None: + """Building rank: release the file lock after caching the engine. + + Called after the building rank has inserted the engine into the shared + cache. Releases the file lock so other ranks' stale lock detection + works correctly. No NCCL collective needed — waiter ranks poll the + filesystem directly. + + Args: + lock: DistributedFileLock instance that was acquired by this rank. + """ + lock.release() diff --git a/py/torch_tensorrt/distributed/_lock.py b/py/torch_tensorrt/distributed/_lock.py new file mode 100644 index 0000000000..32acc407a6 --- /dev/null +++ b/py/torch_tensorrt/distributed/_lock.py @@ -0,0 +1,189 @@ +""" +File-based distributed lock for coordinating work across ranks. + +Used to ensure only one rank performs an expensive operation (e.g., TRT engine +build) while others wait and consume the result from a shared cache. + +The lock file is created atomically via os.O_CREAT | os.O_EXCL — only one +process can succeed. Other processes see FileExistsError and know to wait. + +If the lock holder crashes, the lock file becomes stale. A configurable +timeout detects this: if the lock file's modification time is older than +the timeout, it is treated as stale and forcibly removed so another rank +can proceed. +""" + +import logging +import os +import time + +import torch.distributed as dist + +logger = logging.getLogger(__name__) + +# Default timeout for stale lock detection (seconds). +# TRT engine builds can take several minutes for large models. +_DEFAULT_STALE_TIMEOUT_S = 600 # 10 minutes + + +class DistributedFileLock: + """Atomic file lock for coordinating distributed engine builds. + + Args: + lock_dir: Directory to create lock files in (typically the cache dir). + name: Unique name for this lock (typically the engine hash). + suffix: File suffix for the lock file. + stale_timeout_s: Seconds after which a lock file is considered stale + (holder likely crashed). Set to 0 to disable stale detection. + """ + + def __init__( + self, + lock_dir: str, + name: str, + suffix: str = ".building", + stale_timeout_s: float = _DEFAULT_STALE_TIMEOUT_S, + ) -> None: + self._lock_path = os.path.join(lock_dir, f".{name}{suffix}") + self._acquired = False + self._stale_timeout_s = stale_timeout_s + + @property + def acquired(self) -> bool: + """Whether this instance holds the lock.""" + return self._acquired + + @property + def lock_path(self) -> str: + return self._lock_path + + def _is_stale(self) -> bool: + """Check if an existing lock file is stale (holder likely crashed). + + A lock is stale if its modification time is older than stale_timeout_s. 
+ """ + if self._stale_timeout_s <= 0: + return False + try: + mtime = os.path.getmtime(self._lock_path) + age = time.time() - mtime + if age > self._stale_timeout_s: + logger.warning( + f"Stale build lock detected: {self._lock_path} " + f"(age={age:.0f}s > timeout={self._stale_timeout_s:.0f}s). " + f"Lock holder likely crashed. Removing stale lock." + ) + return True + except FileNotFoundError: + pass + return False + + def _remove_stale(self) -> None: + """Remove a stale lock file so acquire() can succeed.""" + try: + os.remove(self._lock_path) + logger.debug(f"Removed stale lock: {self._lock_path}") + except FileNotFoundError: + pass + + def acquire(self) -> bool: + """Try to acquire the lock atomically. + + If the lock file exists but is stale (older than stale_timeout_s), + it is removed and acquisition is retried once. + + Returns: + True if this process acquired the lock (should do the work). + False if another process already holds it (should wait). + """ + try: + fd = os.open(self._lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + os.close(fd) + self._acquired = True + logger.debug(f"Acquired build lock: {self._lock_path}") + return True + except FileExistsError: + pass + + if self._is_stale(): + self._remove_stale() + try: + fd = os.open(self._lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + os.close(fd) + self._acquired = True + logger.debug( + f"Acquired build lock after stale removal: {self._lock_path}" + ) + return True + except FileExistsError: + pass + + self._acquired = False + logger.debug(f"Build lock already held: {self._lock_path}") + return False + + def release(self) -> None: + """Release the lock by removing the lock file.""" + if self._acquired: + try: + os.remove(self._lock_path) + logger.debug(f"Released build lock: {self._lock_path}") + except FileNotFoundError: + pass + self._acquired = False + + @staticmethod + def cleanup_stale_locks( + lock_dir: str, + suffix: str = ".building", + stale_timeout_s: float = _DEFAULT_STALE_TIMEOUT_S, + ) -> int: + """Remove all stale lock files in a directory. + + Useful for cleaning up after crashes before starting a new run. + + Args: + lock_dir: Directory to scan for lock files. + suffix: Lock file suffix to match. + stale_timeout_s: Age threshold in seconds. + + Returns: + Number of stale lock files removed. 
+ """ + removed = 0 + if not os.path.isdir(lock_dir): + return 0 + now = time.time() + for entry in os.scandir(lock_dir): + if ( + entry.name.startswith(".") + and entry.name.endswith(suffix) + and entry.is_file() + ): + try: + age = now - entry.stat().st_mtime + if age > stale_timeout_s: + os.remove(entry.path) + logger.debug( + f"Cleaned up stale lock: {entry.path} (age={age:.0f}s)" + ) + removed += 1 + except (FileNotFoundError, OSError): + pass + if removed > 0: + logger.info(f"Cleaned up {removed} stale lock file(s) in {lock_dir}") + return removed + + @staticmethod + def barrier() -> None: + """Distributed barrier — all ranks must call this.""" + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + dist.barrier() + + def __enter__(self) -> "DistributedFileLock": + self.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def] + self.release() + return None diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 77e48fe92e..80870ceb43 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,12 +2,20 @@ import io import logging +from functools import partial from typing import Any, Dict, List, NamedTuple, Optional, Sequence +import tensorrt as trt import torch from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input +from torch_tensorrt.distributed._distributed import ( + is_distributed_caching_enabled, + signal_distributed_engine_build_complete, + wait_for_distributed_engine_build, +) +from torch_tensorrt.distributed._lock import DistributedFileLock from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import ( @@ -25,8 +33,6 @@ ) from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -262,6 +268,41 @@ def interpret_module_to_result( if serialized_interpreter_result is not None: # hit the cache return serialized_interpreter_result + # Distributed engine cache coordination: only one rank builds, + # others wait and load from shared cache. + _distributed_caching = is_distributed_caching_enabled( + is_engine_caching_supported, + settings.cache_built_engines, + settings.reuse_cached_engines, + ) + _build_lock = None + + if _distributed_caching: + # is_distributed_caching_enabled guarantees engine_cache and hash_val are set. 
+ assert engine_cache is not None + assert hash_val is not None + _build_lock = DistributedFileLock(engine_cache.engine_cache_dir, hash_val) + if _build_lock.acquire(): + logger.info("Acquired engine build lock — this rank builds") + else: + logger.info("Lock held by another rank — polling for cached engine") + _pull_fn = partial( + pull_cached_engine, + hash_val, + module, + engine_cache, + settings, + inputs, + symbolic_shape_expressions, + ) + cached: Optional[SerializedInterpreterResult] = ( + wait_for_distributed_engine_build( + _pull_fn, engine_cache.engine_cache_dir, hash_val + ) + ) + if cached is not None: + return cached + output_dtypes = infer_module_output_dtypes( module, truncate_double=settings.truncate_double ) @@ -307,6 +348,10 @@ def interpret_module_to_result( hash_val, interpreter_result, engine_cache, settings, inputs ) + # Signal other ranks that the engine is cached and ready + if _build_lock is not None and _build_lock.acquired: + signal_distributed_engine_build_complete(_build_lock) + serialized_engine = interpreter_result.engine.serialize() with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) diff --git a/tests/py/dynamo/distributed/test_distributed_engine_cache.py b/tests/py/dynamo/distributed/test_distributed_engine_cache.py new file mode 100644 index 0000000000..5aa79aa79c --- /dev/null +++ b/tests/py/dynamo/distributed/test_distributed_engine_cache.py @@ -0,0 +1,381 @@ +""" +Distributed engine cache coordination tests. + +Verifies that when multiple ranks compile the same model with engine caching +enabled, only one rank builds the TRT engine and others load from the shared +DiskEngineCache. + +Tests: + 1. DistributedFileLock — acquire, release, stale detection (no GPU) + 2. Multi-rank: one rank builds, other loads from cache (2 GPUs) + +Run single-rank tests (no GPU needed): + cd tests/py/dynamo + pytest distributed/test_distributed_engine_cache.py -v + +Run multi-rank tests (requires 2 GPUs): + pytest distributed/test_distributed_engine_cache.py::TestMultirankDistributedCache -v + +Run via torchrun: + torchrun --nproc_per_node=2 distributed/test_distributed_engine_cache.py --multirank +""" + +from __future__ import annotations + +import os +import sys +import tempfile +import time +import unittest + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_nccl, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import run_tests + +# --------------------------------------------------------------------------- +# Capability checks +# --------------------------------------------------------------------------- + + +def is_nccl_available() -> bool: + try: + return dist.is_nccl_available() + except Exception: + return False + + +def has_nccl_collectives() -> bool: + try: + from torch_tensorrt._features import ENABLED_FEATURES + + return bool(ENABLED_FEATURES.native_trt_collectives) or bool( + ENABLED_FEATURES.trtllm_for_nccl + ) + except Exception: + return False + + +# --------------------------------------------------------------------------- +# Section 1 — DistributedFileLock (no GPU, no dist) +# --------------------------------------------------------------------------- + + +class TestDistributedFileLock(unittest.TestCase): + """Unit tests for the file-based distributed lock.""" + + def setUp(self) -> None: + self.tmp_dir = tempfile.mkdtemp(prefix="trt_lock_test_") + + def test_acquire_returns_true_first_time(self) -> 
None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock = DistributedFileLock(self.tmp_dir, "test_hash") + self.assertTrue(lock.acquire()) + self.assertTrue(lock.acquired) + lock.release() + + def test_acquire_returns_false_when_held(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock1 = DistributedFileLock(self.tmp_dir, "test_hash") + lock2 = DistributedFileLock(self.tmp_dir, "test_hash") + + self.assertTrue(lock1.acquire()) + self.assertFalse(lock2.acquire()) + self.assertFalse(lock2.acquired) + + lock1.release() + + def test_acquire_succeeds_after_release(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock1 = DistributedFileLock(self.tmp_dir, "test_hash") + lock1.acquire() + lock1.release() + + lock2 = DistributedFileLock(self.tmp_dir, "test_hash") + self.assertTrue(lock2.acquire()) + lock2.release() + + def test_release_without_acquire_is_noop(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock = DistributedFileLock(self.tmp_dir, "test_hash") + lock.release() # should not raise + + def test_context_manager_acquires_and_releases(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + with DistributedFileLock(self.tmp_dir, "test_hash") as lock: + self.assertTrue(lock.acquired) + self.assertTrue(os.path.exists(lock.lock_path)) + + # After exit: lock file removed + self.assertFalse(os.path.exists(lock.lock_path)) + + def test_context_manager_releases_on_exception(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + try: + with DistributedFileLock(self.tmp_dir, "test_hash") as lock: + lock_path = lock.lock_path + raise RuntimeError("test error") + except RuntimeError: + pass + + self.assertFalse(os.path.exists(lock_path)) + + def test_different_names_dont_conflict(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock_a = DistributedFileLock(self.tmp_dir, "hash_a") + lock_b = DistributedFileLock(self.tmp_dir, "hash_b") + + self.assertTrue(lock_a.acquire()) + self.assertTrue(lock_b.acquire()) # different name, no conflict + + lock_a.release() + lock_b.release() + + def test_stale_lock_detected_and_reacquired(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + # Create a lock with very short stale timeout + lock1 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=0.1) + lock1.acquire() + # Don't release — simulate crash + + time.sleep(0.2) # wait for it to become stale + + # Another process tries to acquire + lock2 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=0.1) + self.assertTrue(lock2.acquire()) # should detect stale and reacquire + lock2.release() + + def test_non_stale_lock_not_reacquired(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock1 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=600) + lock1.acquire() + + lock2 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=600) + self.assertFalse(lock2.acquire()) # not stale, can't acquire + + lock1.release() + + def test_cleanup_stale_locks(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + # Create stale locks + for i in range(3): + lock = DistributedFileLock(self.tmp_dir, f"hash_{i}", stale_timeout_s=0.1) + lock.acquire() + # Don't release + + time.sleep(0.2) + + removed = DistributedFileLock.cleanup_stale_locks( + self.tmp_dir, 
stale_timeout_s=0.1 + ) + self.assertEqual(removed, 3) + + def test_lock_path_format(self) -> None: + from torch_tensorrt.distributed._lock import DistributedFileLock + + lock = DistributedFileLock("/tmp/cache", "abc123") + self.assertEqual(lock.lock_path, "/tmp/cache/.abc123.building") + + +# --------------------------------------------------------------------------- +# Section 2 — Multi-rank distributed caching test +# --------------------------------------------------------------------------- + + +class SimpleModel(nn.Module): + """Small model for cache coordination testing.""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(32, 64) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(64, 32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(self.relu(self.fc1(x))) + + +def _multirank_distributed_cache_test( + rank: int, world_size: int, device: torch.device, cache_dir: str +) -> None: + """Test that distributed caching coordinates engine builds across ranks. + + Compiles the same model on both ranks with cache_built_engines=True and + reuse_cached_engines=True. All ranks must use the same cache_dir for + lock coordination to work. + + Verifies: + 1. Both ranks produce correct output + 2. Cache directory has engine files + """ + import torch_tensorrt + + torch.manual_seed(42) + model = SimpleModel().eval().to(device) + inp = torch.randn(2, 32, device=device) + + # PyTorch reference + with torch.no_grad(): + ref_output = model(inp) + + # Compile with caching enabled + with torch.no_grad(): + torch._dynamo.reset() + trt_model = torch.compile( + model, + backend="torch_tensorrt", + options={ + "enabled_precisions": {torch.float32}, + "use_python_runtime": False, + "min_block_size": 1, + "cache_built_engines": True, + "reuse_cached_engines": True, + "immutable_weights": False, + "engine_cache_dir": cache_dir, + "engine_cache_size": 1 << 30, # 1GB + }, + ) + trt_output = trt_model(inp) + + # Verify correctness + diff = (ref_output - trt_output).abs().max().item() + assert diff < 0.01, f"Rank {rank}: output mismatch, max diff={diff}" + print(f"[Rank {rank}] Compile + cache OK (max_diff={diff:.6f})", flush=True) + + # Verify cache directory has files + cache_files = os.listdir(cache_dir) + print( + f"[Rank {rank}] Cache dir has {len(cache_files)} entries: {cache_files[:5]}", + flush=True, + ) + assert ( + len(cache_files) > 0 + ), f"Rank {rank}: cache dir is empty — caching didn't work" + + dist.barrier() + + # Second compile — should hit cache on both ranks + torch._dynamo.reset() + trt_model2 = torch.compile( + model, + backend="torch_tensorrt", + options={ + "enabled_precisions": {torch.float32}, + "use_python_runtime": False, + "min_block_size": 1, + "cache_built_engines": True, + "reuse_cached_engines": True, + "immutable_weights": False, + "engine_cache_dir": cache_dir, + "engine_cache_size": 1 << 30, + }, + ) + with torch.no_grad(): + trt_output2 = trt_model2(inp) + + diff2 = (ref_output - trt_output2).abs().max().item() + assert diff2 < 0.01, f"Rank {rank}: cached output mismatch, max diff={diff2}" + print(f"[Rank {rank}] Cache reuse OK (max_diff={diff2:.6f})", flush=True) + + +# --------------------------------------------------------------------------- +# Section 3 — Multi-rank pytest (MultiProcessTestCase) +# --------------------------------------------------------------------------- + + +class TestMultirankDistributedCache(MultiProcessTestCase): + """Distributed engine cache tests as pytest-compatible MultiProcessTestCase.""" + + world_size = 2 
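+    # 2 ranks: one wins the per-hash build lock and builds the engine;
+    # the other polls the shared cache dir and loads the serialized result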
+ + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def _init_dist(self) -> torch.device: + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + store=store, + rank=self.rank, + world_size=self.world_size, + ) + local = self.rank % torch.cuda.device_count() + torch.cuda.set_device(local) + dist.barrier() + return torch.device(f"cuda:{local}") + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_distributed_cache_coordination(self) -> None: + """Both ranks compile same model with caching — output matches reference.""" + device = self._init_dist() + cache_dir = tempfile.mkdtemp(prefix="trt_dist_cache_pytest_") + _multirank_distributed_cache_test(self.rank, self.world_size, device, cache_dir) + + +# --------------------------------------------------------------------------- +# Section 4 — torchrun entry point +# --------------------------------------------------------------------------- + + +def _multirank_setup() -> tuple: + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count())) + torch.cuda.set_device(local_rank) + return rank, world_size, torch.device(f"cuda:{local_rank}") + + +def run_multirank_tests() -> None: + rank, world_size, device = _multirank_setup() + print(f"[Rank {rank}/{world_size}] device={device}", flush=True) + + # All ranks must share the same cache dir — use a deterministic path. + # Clean up from previous runs so the lock coordination is tested fresh. + import shutil + + cache_dir = os.path.join(tempfile.gettempdir(), "trt_dist_cache_test_shared") + if rank == 0: + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + dist.barrier() # wait for rank 0 to clean up before anyone creates it + os.makedirs(cache_dir, exist_ok=True) + + try: + _multirank_distributed_cache_test(rank, world_size, device, cache_dir) + except Exception as e: + print(f"[Rank {rank}] FAIL: {e}", flush=True) + import traceback + + traceback.print_exc() + dist.destroy_process_group() + sys.exit(1) + + dist.barrier() + dist.destroy_process_group() + print(f"[Rank {rank}] All distributed cache tests PASSED.", flush=True) + + +if __name__ == "__main__": + if "--multirank" in sys.argv or "--multinode" in sys.argv: + sys.argv = [a for a in sys.argv if a not in ("--multirank", "--multinode")] + run_multirank_tests() + else: + run_tests() From 2ea1d3cebdef432f52d24ce5d323db1d2b888df9 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 23 Apr 2026 18:47:39 -0700 Subject: [PATCH 2/3] using filelock --- .github/workflows/build-test-linux-x86_64.yml | 4 +- py/torch_tensorrt/distributed/__init__.py | 3 - py/torch_tensorrt/distributed/_distributed.py | 68 ------ py/torch_tensorrt/distributed/_lock.py | 189 ----------------- .../dynamo/conversion/_conversion.py | 58 +++--- .../test_distributed_engine_cache.py | 196 ++---------------- 6 files changed, 44 insertions(+), 474 deletions(-) delete mode 100644 py/torch_tensorrt/distributed/_lock.py diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index 405b03a8e9..1a382477c8 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -611,9 +611,11 @@ jobs: python -m pytest -ra -v --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \ distributed/test_nccl_ops.py \ distributed/test_native_nccl.py \ - 
distributed/test_export_save_load.py + distributed/test_export_save_load.py \ + distributed/test_distributed_engine_cache.py python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_native_nccl.py --multirank python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_export_save_load.py --multirank + python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_distributed_engine_cache.py --multirank popd concurrency: diff --git a/py/torch_tensorrt/distributed/__init__.py b/py/torch_tensorrt/distributed/__init__.py index 15194667ff..ed30f09d89 100644 --- a/py/torch_tensorrt/distributed/__init__.py +++ b/py/torch_tensorrt/distributed/__init__.py @@ -2,10 +2,7 @@ distributed_context, is_distributed_caching_enabled, set_distributed_mode, - signal_distributed_engine_build_complete, - wait_for_distributed_engine_build, ) -from torch_tensorrt.distributed._lock import DistributedFileLock # noqa: F401 from torch_tensorrt.distributed._nccl_utils import ( # noqa: F401 setup_nccl_for_torch_tensorrt, ) diff --git a/py/torch_tensorrt/distributed/_distributed.py b/py/torch_tensorrt/distributed/_distributed.py index b0e0577df0..8c810decca 100644 --- a/py/torch_tensorrt/distributed/_distributed.py +++ b/py/torch_tensorrt/distributed/_distributed.py @@ -230,71 +230,3 @@ def is_distributed_caching_enabled( and dist.is_initialized() and dist.get_world_size() > 1 ) - - -def wait_for_distributed_engine_build( - pull_fn: Any, - cache_dir: str, - hash_val: str, - poll_interval: float = 0.5, - timeout: float = 600.0, -) -> Any: - """Non-building rank: poll for cached engine file, then load from cache. - - Called when this rank failed to acquire the build lock, meaning another - rank is building the engine. Polls the filesystem for the cached engine - file instead of using NCCL collectives (which are unreliable inside - the TRT compilation path due to aot_autograd/CUDA stream conflicts). - - Args: - pull_fn: Zero-arg callable (e.g. functools.partial) that loads the - engine from cache. Returns SerializedInterpreterResult on - hit, None on miss. - cache_dir: Shared engine cache directory path. - hash_val: Engine hash for this compilation. - poll_interval: Seconds between filesystem checks (default 0.5s). - timeout: Maximum seconds to wait before giving up (default 600s). - - Returns: - SerializedInterpreterResult on cache hit, None on timeout. - """ - import logging - import os - import time - - logger = logging.getLogger(__name__) - - blob_path = os.path.join(cache_dir, hash_val, "blob.bin") - logger.info(f"Polling for cached engine: {blob_path}") - - elapsed = 0.0 - while not os.path.exists(blob_path): - time.sleep(poll_interval) - elapsed += poll_interval - if elapsed >= timeout: - logger.warning( - f"Polling timed out after {timeout:.0f}s — building engine locally" - ) - return None - - logger.info(f"Cached engine found after {elapsed:.1f}s — loading from cache") - cached = pull_fn() - if cached is not None: - return cached - - logger.warning("Cache file exists but pull_cached_engine failed — building locally") - return None - - -def signal_distributed_engine_build_complete(lock: Any) -> None: - """Building rank: release the file lock after caching the engine. - - Called after the building rank has inserted the engine into the shared - cache. Releases the file lock so other ranks' stale lock detection - works correctly. No NCCL collective needed — waiter ranks poll the - filesystem directly. 
- - Args: - lock: DistributedFileLock instance that was acquired by this rank. - """ - lock.release() diff --git a/py/torch_tensorrt/distributed/_lock.py b/py/torch_tensorrt/distributed/_lock.py deleted file mode 100644 index 32acc407a6..0000000000 --- a/py/torch_tensorrt/distributed/_lock.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -File-based distributed lock for coordinating work across ranks. - -Used to ensure only one rank performs an expensive operation (e.g., TRT engine -build) while others wait and consume the result from a shared cache. - -The lock file is created atomically via os.O_CREAT | os.O_EXCL — only one -process can succeed. Other processes see FileExistsError and know to wait. - -If the lock holder crashes, the lock file becomes stale. A configurable -timeout detects this: if the lock file's modification time is older than -the timeout, it is treated as stale and forcibly removed so another rank -can proceed. -""" - -import logging -import os -import time - -import torch.distributed as dist - -logger = logging.getLogger(__name__) - -# Default timeout for stale lock detection (seconds). -# TRT engine builds can take several minutes for large models. -_DEFAULT_STALE_TIMEOUT_S = 600 # 10 minutes - - -class DistributedFileLock: - """Atomic file lock for coordinating distributed engine builds. - - Args: - lock_dir: Directory to create lock files in (typically the cache dir). - name: Unique name for this lock (typically the engine hash). - suffix: File suffix for the lock file. - stale_timeout_s: Seconds after which a lock file is considered stale - (holder likely crashed). Set to 0 to disable stale detection. - """ - - def __init__( - self, - lock_dir: str, - name: str, - suffix: str = ".building", - stale_timeout_s: float = _DEFAULT_STALE_TIMEOUT_S, - ) -> None: - self._lock_path = os.path.join(lock_dir, f".{name}{suffix}") - self._acquired = False - self._stale_timeout_s = stale_timeout_s - - @property - def acquired(self) -> bool: - """Whether this instance holds the lock.""" - return self._acquired - - @property - def lock_path(self) -> str: - return self._lock_path - - def _is_stale(self) -> bool: - """Check if an existing lock file is stale (holder likely crashed). - - A lock is stale if its modification time is older than stale_timeout_s. - """ - if self._stale_timeout_s <= 0: - return False - try: - mtime = os.path.getmtime(self._lock_path) - age = time.time() - mtime - if age > self._stale_timeout_s: - logger.warning( - f"Stale build lock detected: {self._lock_path} " - f"(age={age:.0f}s > timeout={self._stale_timeout_s:.0f}s). " - f"Lock holder likely crashed. Removing stale lock." - ) - return True - except FileNotFoundError: - pass - return False - - def _remove_stale(self) -> None: - """Remove a stale lock file so acquire() can succeed.""" - try: - os.remove(self._lock_path) - logger.debug(f"Removed stale lock: {self._lock_path}") - except FileNotFoundError: - pass - - def acquire(self) -> bool: - """Try to acquire the lock atomically. - - If the lock file exists but is stale (older than stale_timeout_s), - it is removed and acquisition is retried once. - - Returns: - True if this process acquired the lock (should do the work). - False if another process already holds it (should wait). 
- """ - try: - fd = os.open(self._lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.close(fd) - self._acquired = True - logger.debug(f"Acquired build lock: {self._lock_path}") - return True - except FileExistsError: - pass - - if self._is_stale(): - self._remove_stale() - try: - fd = os.open(self._lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.close(fd) - self._acquired = True - logger.debug( - f"Acquired build lock after stale removal: {self._lock_path}" - ) - return True - except FileExistsError: - pass - - self._acquired = False - logger.debug(f"Build lock already held: {self._lock_path}") - return False - - def release(self) -> None: - """Release the lock by removing the lock file.""" - if self._acquired: - try: - os.remove(self._lock_path) - logger.debug(f"Released build lock: {self._lock_path}") - except FileNotFoundError: - pass - self._acquired = False - - @staticmethod - def cleanup_stale_locks( - lock_dir: str, - suffix: str = ".building", - stale_timeout_s: float = _DEFAULT_STALE_TIMEOUT_S, - ) -> int: - """Remove all stale lock files in a directory. - - Useful for cleaning up after crashes before starting a new run. - - Args: - lock_dir: Directory to scan for lock files. - suffix: Lock file suffix to match. - stale_timeout_s: Age threshold in seconds. - - Returns: - Number of stale lock files removed. - """ - removed = 0 - if not os.path.isdir(lock_dir): - return 0 - now = time.time() - for entry in os.scandir(lock_dir): - if ( - entry.name.startswith(".") - and entry.name.endswith(suffix) - and entry.is_file() - ): - try: - age = now - entry.stat().st_mtime - if age > stale_timeout_s: - os.remove(entry.path) - logger.debug( - f"Cleaned up stale lock: {entry.path} (age={age:.0f}s)" - ) - removed += 1 - except (FileNotFoundError, OSError): - pass - if removed > 0: - logger.info(f"Cleaned up {removed} stale lock file(s) in {lock_dir}") - return removed - - @staticmethod - def barrier() -> None: - """Distributed barrier — all ranks must call this.""" - if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: - dist.barrier() - - def __enter__(self) -> "DistributedFileLock": - self.acquire() - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def] - self.release() - return None diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 80870ceb43..cdf5c82eea 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,7 +2,6 @@ import io import logging -from functools import partial from typing import Any, Dict, List, NamedTuple, Optional, Sequence import tensorrt as trt @@ -10,12 +9,7 @@ from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input -from torch_tensorrt.distributed._distributed import ( - is_distributed_caching_enabled, - signal_distributed_engine_build_complete, - wait_for_distributed_engine_build, -) -from torch_tensorrt.distributed._lock import DistributedFileLock +from torch_tensorrt.distributed._distributed import is_distributed_caching_enabled from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import ( @@ -275,33 +269,33 @@ def interpret_module_to_result( settings.cache_built_engines, settings.reuse_cached_engines, ) - 
_build_lock = None + _lock: Optional[Any] = None if _distributed_caching: + import os as _os + + from filelock import FileLock + # is_distributed_caching_enabled guarantees engine_cache and hash_val are set. assert engine_cache is not None assert hash_val is not None - _build_lock = DistributedFileLock(engine_cache.engine_cache_dir, hash_val) - if _build_lock.acquire(): - logger.info("Acquired engine build lock — this rank builds") - else: - logger.info("Lock held by another rank — polling for cached engine") - _pull_fn = partial( - pull_cached_engine, - hash_val, - module, - engine_cache, - settings, - inputs, - symbolic_shape_expressions, - ) - cached: Optional[SerializedInterpreterResult] = ( - wait_for_distributed_engine_build( - _pull_fn, engine_cache.engine_cache_dir, hash_val - ) - ) - if cached is not None: - return cached + + _lock_path = _os.path.join(engine_cache.engine_cache_dir, f".{hash_val}.lock") + _lock = FileLock(_lock_path, timeout=600) + _lock.acquire() + + # Check cache again — another rank may have built while we waited + cached = pull_cached_engine( + hash_val, + module, + engine_cache, + settings, + inputs, + symbolic_shape_expressions, + ) + if cached is not None: + _lock.release() + return cached output_dtypes = infer_module_output_dtypes( module, truncate_double=settings.truncate_double @@ -348,9 +342,9 @@ def interpret_module_to_result( hash_val, interpreter_result, engine_cache, settings, inputs ) - # Signal other ranks that the engine is cached and ready - if _build_lock is not None and _build_lock.acquired: - signal_distributed_engine_build_complete(_build_lock) + # Release the filelock so other ranks can proceed + if _distributed_caching and _lock is not None: + _lock.release() serialized_engine = interpreter_result.engine.serialize() with io.BytesIO() as engine_bytes: diff --git a/tests/py/dynamo/distributed/test_distributed_engine_cache.py b/tests/py/dynamo/distributed/test_distributed_engine_cache.py index 5aa79aa79c..03638f577f 100644 --- a/tests/py/dynamo/distributed/test_distributed_engine_cache.py +++ b/tests/py/dynamo/distributed/test_distributed_engine_cache.py @@ -3,15 +3,7 @@ Verifies that when multiple ranks compile the same model with engine caching enabled, only one rank builds the TRT engine and others load from the shared -DiskEngineCache. - -Tests: - 1. DistributedFileLock — acquire, release, stale detection (no GPU) - 2. Multi-rank: one rank builds, other loads from cache (2 GPUs) - -Run single-rank tests (no GPU needed): - cd tests/py/dynamo - pytest distributed/test_distributed_engine_cache.py -v +DiskEngineCache via filelock coordination. 
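+
+Coordination protocol (a sketch of the logic added to _conversion.py; names
+abbreviated and error handling omitted)::
+
+    lock = FileLock(os.path.join(cache_dir, f".{hash_val}.lock"), timeout=600)
+    lock.acquire()
+    cached = pull_cached_engine(...)  # re-check: a peer may have built it
+    if cached is not None:
+        lock.release()
+        return cached
+    # ... build the engine and insert it into the shared cache ...
+    lock.release()
+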
Run multi-rank tests (requires 2 GPUs): pytest distributed/test_distributed_engine_cache.py::TestMultirankDistributedCache -v @@ -25,8 +17,6 @@ import os import sys import tempfile -import time -import unittest import torch import torch.distributed as dist @@ -39,161 +29,7 @@ from torch.testing._internal.common_utils import run_tests # --------------------------------------------------------------------------- -# Capability checks -# --------------------------------------------------------------------------- - - -def is_nccl_available() -> bool: - try: - return dist.is_nccl_available() - except Exception: - return False - - -def has_nccl_collectives() -> bool: - try: - from torch_tensorrt._features import ENABLED_FEATURES - - return bool(ENABLED_FEATURES.native_trt_collectives) or bool( - ENABLED_FEATURES.trtllm_for_nccl - ) - except Exception: - return False - - -# --------------------------------------------------------------------------- -# Section 1 — DistributedFileLock (no GPU, no dist) -# --------------------------------------------------------------------------- - - -class TestDistributedFileLock(unittest.TestCase): - """Unit tests for the file-based distributed lock.""" - - def setUp(self) -> None: - self.tmp_dir = tempfile.mkdtemp(prefix="trt_lock_test_") - - def test_acquire_returns_true_first_time(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock = DistributedFileLock(self.tmp_dir, "test_hash") - self.assertTrue(lock.acquire()) - self.assertTrue(lock.acquired) - lock.release() - - def test_acquire_returns_false_when_held(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock1 = DistributedFileLock(self.tmp_dir, "test_hash") - lock2 = DistributedFileLock(self.tmp_dir, "test_hash") - - self.assertTrue(lock1.acquire()) - self.assertFalse(lock2.acquire()) - self.assertFalse(lock2.acquired) - - lock1.release() - - def test_acquire_succeeds_after_release(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock1 = DistributedFileLock(self.tmp_dir, "test_hash") - lock1.acquire() - lock1.release() - - lock2 = DistributedFileLock(self.tmp_dir, "test_hash") - self.assertTrue(lock2.acquire()) - lock2.release() - - def test_release_without_acquire_is_noop(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock = DistributedFileLock(self.tmp_dir, "test_hash") - lock.release() # should not raise - - def test_context_manager_acquires_and_releases(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - with DistributedFileLock(self.tmp_dir, "test_hash") as lock: - self.assertTrue(lock.acquired) - self.assertTrue(os.path.exists(lock.lock_path)) - - # After exit: lock file removed - self.assertFalse(os.path.exists(lock.lock_path)) - - def test_context_manager_releases_on_exception(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - try: - with DistributedFileLock(self.tmp_dir, "test_hash") as lock: - lock_path = lock.lock_path - raise RuntimeError("test error") - except RuntimeError: - pass - - self.assertFalse(os.path.exists(lock_path)) - - def test_different_names_dont_conflict(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock_a = DistributedFileLock(self.tmp_dir, "hash_a") - lock_b = DistributedFileLock(self.tmp_dir, "hash_b") - - self.assertTrue(lock_a.acquire()) - self.assertTrue(lock_b.acquire()) # different name, no conflict - - 
lock_a.release() - lock_b.release() - - def test_stale_lock_detected_and_reacquired(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - # Create a lock with very short stale timeout - lock1 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=0.1) - lock1.acquire() - # Don't release — simulate crash - - time.sleep(0.2) # wait for it to become stale - - # Another process tries to acquire - lock2 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=0.1) - self.assertTrue(lock2.acquire()) # should detect stale and reacquire - lock2.release() - - def test_non_stale_lock_not_reacquired(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock1 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=600) - lock1.acquire() - - lock2 = DistributedFileLock(self.tmp_dir, "test_hash", stale_timeout_s=600) - self.assertFalse(lock2.acquire()) # not stale, can't acquire - - lock1.release() - - def test_cleanup_stale_locks(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - # Create stale locks - for i in range(3): - lock = DistributedFileLock(self.tmp_dir, f"hash_{i}", stale_timeout_s=0.1) - lock.acquire() - # Don't release - - time.sleep(0.2) - - removed = DistributedFileLock.cleanup_stale_locks( - self.tmp_dir, stale_timeout_s=0.1 - ) - self.assertEqual(removed, 3) - - def test_lock_path_format(self) -> None: - from torch_tensorrt.distributed._lock import DistributedFileLock - - lock = DistributedFileLock("/tmp/cache", "abc123") - self.assertEqual(lock.lock_path, "/tmp/cache/.abc123.building") - - -# --------------------------------------------------------------------------- -# Section 2 — Multi-rank distributed caching test +# Multi-rank distributed caching test # --------------------------------------------------------------------------- @@ -217,11 +53,12 @@ def _multirank_distributed_cache_test( Compiles the same model on both ranks with cache_built_engines=True and reuse_cached_engines=True. All ranks must use the same cache_dir for - lock coordination to work. + filelock coordination to work. Verifies: 1. Both ranks produce correct output 2. Cache directory has engine files + 3. 
Second compile hits cache on both ranks """ import torch_tensorrt @@ -233,7 +70,7 @@ def _multirank_distributed_cache_test( with torch.no_grad(): ref_output = model(inp) - # Compile with caching enabled + # Phase 1: Compile with caching enabled with torch.no_grad(): torch._dynamo.reset() trt_model = torch.compile( @@ -247,29 +84,26 @@ def _multirank_distributed_cache_test( "reuse_cached_engines": True, "immutable_weights": False, "engine_cache_dir": cache_dir, - "engine_cache_size": 1 << 30, # 1GB + "engine_cache_size": 1 << 30, }, ) trt_output = trt_model(inp) - # Verify correctness diff = (ref_output - trt_output).abs().max().item() assert diff < 0.01, f"Rank {rank}: output mismatch, max diff={diff}" print(f"[Rank {rank}] Compile + cache OK (max_diff={diff:.6f})", flush=True) - # Verify cache directory has files - cache_files = os.listdir(cache_dir) + # Verify cache has files (ignore .lock files) + cache_files = [f for f in os.listdir(cache_dir) if not f.endswith(".lock")] print( f"[Rank {rank}] Cache dir has {len(cache_files)} entries: {cache_files[:5]}", flush=True, ) - assert ( - len(cache_files) > 0 - ), f"Rank {rank}: cache dir is empty — caching didn't work" + assert len(cache_files) > 0, f"Rank {rank}: cache dir is empty" dist.barrier() - # Second compile — should hit cache on both ranks + # Phase 2: Second compile — should hit cache on both ranks torch._dynamo.reset() trt_model2 = torch.compile( model, @@ -294,7 +128,7 @@ def _multirank_distributed_cache_test( # --------------------------------------------------------------------------- -# Section 3 — Multi-rank pytest (MultiProcessTestCase) +# Multi-rank pytest (MultiProcessTestCase) # --------------------------------------------------------------------------- @@ -330,7 +164,7 @@ def test_distributed_cache_coordination(self) -> None: # --------------------------------------------------------------------------- -# Section 4 — torchrun entry point +# torchrun entry point # --------------------------------------------------------------------------- @@ -347,15 +181,15 @@ def run_multirank_tests() -> None: rank, world_size, device = _multirank_setup() print(f"[Rank {rank}/{world_size}] device={device}", flush=True) - # All ranks must share the same cache dir — use a deterministic path. - # Clean up from previous runs so the lock coordination is tested fresh. + # All ranks must share the same cache dir. + # Clean up from previous runs so filelock coordination is tested fresh. 
import shutil cache_dir = os.path.join(tempfile.gettempdir(), "trt_dist_cache_test_shared") if rank == 0: if os.path.exists(cache_dir): shutil.rmtree(cache_dir) - dist.barrier() # wait for rank 0 to clean up before anyone creates it + dist.barrier() os.makedirs(cache_dir, exist_ok=True) try: From 94d89b84b448cf4d3c5d02c94b03b91cec9391b8 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 23 Apr 2026 19:46:07 -0700 Subject: [PATCH 3/3] adding TP test --- .../test_distributed_engine_cache.py | 257 +++++++++++++----- 1 file changed, 182 insertions(+), 75 deletions(-) diff --git a/tests/py/dynamo/distributed/test_distributed_engine_cache.py b/tests/py/dynamo/distributed/test_distributed_engine_cache.py index 03638f577f..9f2253d7ef 100644 --- a/tests/py/dynamo/distributed/test_distributed_engine_cache.py +++ b/tests/py/dynamo/distributed/test_distributed_engine_cache.py @@ -15,8 +15,11 @@ from __future__ import annotations import os +import shutil import sys import tempfile +import time +import unittest import torch import torch.distributed as dist @@ -29,12 +32,24 @@ from torch.testing._internal.common_utils import run_tests # --------------------------------------------------------------------------- -# Multi-rank distributed caching test +# Capability checks +# --------------------------------------------------------------------------- + + +def is_nccl_available() -> bool: + try: + return dist.is_nccl_available() + except Exception: + return False + + +# --------------------------------------------------------------------------- +# Models # --------------------------------------------------------------------------- class SimpleModel(nn.Module): - """Small model for cache coordination testing.""" + """Small model for cache coordination testing (no TP).""" def __init__(self): super().__init__() @@ -46,92 +61,166 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc2(self.relu(self.fc1(x))) +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _clean_cache_dir(cache_dir: str, rank: int) -> None: + """Rank 0 cleans cache dir, all ranks wait then recreate.""" + if rank == 0 and os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + dist.barrier() + os.makedirs(cache_dir, exist_ok=True) + + +def _compile_with_cache(model, cache_dir, use_distributed_mode_trace=False): + """Compile model with TRT + engine caching enabled.""" + import torch_tensorrt + + return torch.compile( + model, + backend="torch_tensorrt", + options={ + "enabled_precisions": {torch.float32}, + "use_python_runtime": False, + "min_block_size": 1, + "cache_built_engines": True, + "reuse_cached_engines": True, + "immutable_weights": False, + "use_distributed_mode_trace": use_distributed_mode_trace, + "engine_cache_dir": cache_dir, + "engine_cache_size": 1 << 30, + }, + ) + + +def _assert_cache_has_files(cache_dir: str, rank: int) -> int: + """Assert cache dir has engine files (ignoring .lock files).""" + cache_files = [f for f in os.listdir(cache_dir) if not f.endswith(".lock")] + assert len(cache_files) > 0, f"Rank {rank}: cache dir is empty" + return len(cache_files) + + +# --------------------------------------------------------------------------- +# Test logic +# --------------------------------------------------------------------------- + + def _multirank_distributed_cache_test( rank: int, world_size: int, device: torch.device, cache_dir: str ) -> None: - """Test that distributed 
caching coordinates engine builds across ranks. - - Compiles the same model on both ranks with cache_built_engines=True and - reuse_cached_engines=True. All ranks must use the same cache_dir for - filelock coordination to work. + """Test cache coordination with simple model (same weights on all ranks). - Verifies: - 1. Both ranks produce correct output - 2. Cache directory has engine files - 3. Second compile hits cache on both ranks + Phase 1: First compile — one rank builds, other loads from cache. + Phase 2: Second compile — both ranks load from cache (refit). """ - import torch_tensorrt - torch.manual_seed(42) model = SimpleModel().eval().to(device) inp = torch.randn(2, 32, device=device) - # PyTorch reference with torch.no_grad(): ref_output = model(inp) - # Phase 1: Compile with caching enabled + # Phase 1: compile + cache + torch._dynamo.reset() + trt_model = _compile_with_cache(model, cache_dir) with torch.no_grad(): - torch._dynamo.reset() - trt_model = torch.compile( - model, - backend="torch_tensorrt", - options={ - "enabled_precisions": {torch.float32}, - "use_python_runtime": False, - "min_block_size": 1, - "cache_built_engines": True, - "reuse_cached_engines": True, - "immutable_weights": False, - "engine_cache_dir": cache_dir, - "engine_cache_size": 1 << 30, - }, - ) + t0 = time.time() trt_output = trt_model(inp) + build_time = time.time() - t0 diff = (ref_output - trt_output).abs().max().item() assert diff < 0.01, f"Rank {rank}: output mismatch, max diff={diff}" - print(f"[Rank {rank}] Compile + cache OK (max_diff={diff:.6f})", flush=True) - - # Verify cache has files (ignore .lock files) - cache_files = [f for f in os.listdir(cache_dir) if not f.endswith(".lock")] + n_files = _assert_cache_has_files(cache_dir, rank) print( - f"[Rank {rank}] Cache dir has {len(cache_files)} entries: {cache_files[:5]}", + f"[Rank {rank}] Compile + cache OK " + f"(max_diff={diff:.6f}, build_time={build_time:.2f}s, cache_entries={n_files})", flush=True, ) - assert len(cache_files) > 0, f"Rank {rank}: cache dir is empty" dist.barrier() - # Phase 2: Second compile — should hit cache on both ranks + # Phase 2: cache reuse torch._dynamo.reset() - trt_model2 = torch.compile( - model, - backend="torch_tensorrt", - options={ - "enabled_precisions": {torch.float32}, - "use_python_runtime": False, - "min_block_size": 1, - "cache_built_engines": True, - "reuse_cached_engines": True, - "immutable_weights": False, - "engine_cache_dir": cache_dir, - "engine_cache_size": 1 << 30, - }, - ) + trt_model2 = _compile_with_cache(model, cache_dir) with torch.no_grad(): + t0 = time.time() trt_output2 = trt_model2(inp) + refit_time = time.time() - t0 diff2 = (ref_output - trt_output2).abs().max().item() assert diff2 < 0.01, f"Rank {rank}: cached output mismatch, max diff={diff2}" - print(f"[Rank {rank}] Cache reuse OK (max_diff={diff2:.6f})", flush=True) + print( + f"[Rank {rank}] Cache reuse OK " + f"(max_diff={diff2:.6f}, refit_time={refit_time:.2f}s, " + f"speedup={build_time / max(refit_time, 0.001):.1f}x)", + flush=True, + ) + + +def _multirank_tp_cache_test( + rank: int, world_size: int, device: torch.device, cache_dir: str +) -> None: + """Test cache coordination with TP-sharded model (different weights per rank). + + Each rank holds ColwiseParallel/RowwiseParallel sharded weights. + Engine structure is identical — only weights differ. + One rank builds, other loads from cache and refits with its own weights. 
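+
+    Note: the tp_plan below keys parallel styles by child-module name in
+    the nn.Sequential ("0" is the first Linear, "2" the second), sharding
+    them colwise then rowwise over the 1-D device mesh.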
+ """ + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, + ) + + mesh = init_device_mesh("cuda", (world_size,)) + + torch.manual_seed(42) + model = ( + nn.Sequential( + nn.Linear(64, 128, bias=False), + nn.ReLU(), + nn.Linear(128, 64, bias=False), + ) + .eval() + .to(device) + ) + + inp = torch.randn(2, 64, device=device) + + with torch.no_grad(): + ref_output = model(inp) + + tp_plan = {"0": ColwiseParallel(), "2": RowwiseParallel()} + parallelize_module(model, mesh, tp_plan) + + torch._dynamo.reset() + trt_model = _compile_with_cache(model, cache_dir, use_distributed_mode_trace=True) + with torch.no_grad(): + trt_output = trt_model(inp) + + diff = (ref_output - trt_output).abs().max().item() + assert diff < 0.05, f"Rank {rank}: TP output mismatch, max diff={diff}" + n_files = _assert_cache_has_files(cache_dir, rank) + print( + f"[Rank {rank}] TP compile + cache OK " + f"(max_diff={diff:.6f}, cache_entries={n_files})", + flush=True, + ) # --------------------------------------------------------------------------- -# Multi-rank pytest (MultiProcessTestCase) +# pytest path (MultiProcessTestCase) # --------------------------------------------------------------------------- +@unittest.skipIf(not is_nccl_available(), "NCCL not available") +@unittest.skipIf( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + "Requires at least 2 GPUs", +) class TestMultirankDistributedCache(MultiProcessTestCase): """Distributed engine cache tests as pytest-compatible MultiProcessTestCase.""" @@ -141,7 +230,8 @@ def setUp(self) -> None: super().setUp() self._spawn_processes() - def _init_dist(self) -> torch.device: + def _init_dist(self, cache_dir: str) -> torch.device: + """Initialize dist, clean cache, return device.""" store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( backend="nccl", @@ -151,17 +241,23 @@ def _init_dist(self) -> torch.device: ) local = self.rank % torch.cuda.device_count() torch.cuda.set_device(local) - dist.barrier() + _clean_cache_dir(cache_dir, self.rank) return torch.device(f"cuda:{local}") @requires_nccl() - @skip_if_lt_x_gpu(2) def test_distributed_cache_coordination(self) -> None: """Both ranks compile same model with caching — output matches reference.""" - device = self._init_dist() cache_dir = tempfile.mkdtemp(prefix="trt_dist_cache_pytest_") + device = self._init_dist(cache_dir) _multirank_distributed_cache_test(self.rank, self.world_size, device, cache_dir) + @requires_nccl() + def test_tp_cache_coordination(self) -> None: + """TP-sharded model: one rank builds, other loads from cache + refits.""" + cache_dir = tempfile.mkdtemp(prefix="trt_dist_tp_cache_pytest_") + device = self._init_dist(cache_dir) + _multirank_tp_cache_test(self.rank, self.world_size, device, cache_dir) + # --------------------------------------------------------------------------- # torchrun entry point @@ -178,33 +274,44 @@ def _multirank_setup() -> tuple: def run_multirank_tests() -> None: + """Entry point for --multirank mode (called by torchrun workers).""" rank, world_size, device = _multirank_setup() print(f"[Rank {rank}/{world_size}] device={device}", flush=True) - # All ranks must share the same cache dir. - # Clean up from previous runs so filelock coordination is tested fresh. 
- import shutil + base_cache_dir = os.path.join(tempfile.gettempdir(), "trt_dist_cache_torchrun") - cache_dir = os.path.join(tempfile.gettempdir(), "trt_dist_cache_test_shared") - if rank == 0: - if os.path.exists(cache_dir): - shutil.rmtree(cache_dir) - dist.barrier() - os.makedirs(cache_dir, exist_ok=True) + tests = [ + ( + "simple_model_cache", + base_cache_dir + "_simple", + _multirank_distributed_cache_test, + ), + ("tp_model_cache", base_cache_dir + "_tp", _multirank_tp_cache_test), + ] - try: - _multirank_distributed_cache_test(rank, world_size, device, cache_dir) - except Exception as e: - print(f"[Rank {rank}] FAIL: {e}", flush=True) - import traceback + failed = [] + for name, cache_dir, test_fn in tests: + _clean_cache_dir(cache_dir, rank) + dist.barrier() + try: + test_fn(rank, world_size, device, cache_dir) + except Exception as e: + failed.append((name, str(e))) + print(f"[Rank {rank}] FAIL {name}: {e}", flush=True) + import traceback - traceback.print_exc() - dist.destroy_process_group() - sys.exit(1) + traceback.print_exc() dist.barrier() dist.destroy_process_group() - print(f"[Rank {rank}] All distributed cache tests PASSED.", flush=True) + + if failed: + print(f"[Rank {rank}] {len(failed)} test(s) FAILED:", flush=True) + for name, err in failed: + print(f" - {name}: {err}", flush=True) + sys.exit(1) + else: + print(f"[Rank {rank}] All distributed cache tests PASSED.", flush=True) if __name__ == "__main__":
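
--
Usage sketch: a minimal end-to-end example of the distributed engine caching
flow under a torchrun launch. The module and cache path are illustrative; the
compile options match those exercised in the tests above.

    import os
    import torch
    import torch.distributed as dist
    import torch_tensorrt  # registers the "torch_tensorrt" compile backend

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = MyModel().eval().cuda()          # hypothetical user module
    cache_dir = "/shared/trt_engine_cache"   # must be visible to every rank

    trt_model = torch.compile(
        model,
        backend="torch_tensorrt",
        options={
            "cache_built_engines": True,    # enables one-rank-builds coordination
            "reuse_cached_engines": True,
            "immutable_weights": False,     # engines must be refittable to be cached
            "engine_cache_dir": cache_dir,
        },
    )
    out = trt_model(torch.randn(2, 32, device="cuda"))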