diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml
index 405b03a8e9..1a382477c8 100644
--- a/.github/workflows/build-test-linux-x86_64.yml
+++ b/.github/workflows/build-test-linux-x86_64.yml
@@ -611,9 +611,11 @@ jobs:
           python -m pytest -ra -v --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \
             distributed/test_nccl_ops.py \
             distributed/test_native_nccl.py \
-            distributed/test_export_save_load.py
+            distributed/test_export_save_load.py \
+            distributed/test_distributed_engine_cache.py
           python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_native_nccl.py --multirank
           python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_export_save_load.py --multirank
+          python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_distributed_engine_cache.py --multirank
           popd
 
 concurrency:
diff --git a/py/torch_tensorrt/distributed/__init__.py b/py/torch_tensorrt/distributed/__init__.py
index fdb95b02b1..ed30f09d89 100644
--- a/py/torch_tensorrt/distributed/__init__.py
+++ b/py/torch_tensorrt/distributed/__init__.py
@@ -1,5 +1,6 @@
 from torch_tensorrt.distributed._distributed import (  # noqa: F401
     distributed_context,
+    is_distributed_caching_enabled,
     set_distributed_mode,
 )
 from torch_tensorrt.distributed._nccl_utils import (  # noqa: F401
diff --git a/py/torch_tensorrt/distributed/_distributed.py b/py/torch_tensorrt/distributed/_distributed.py
index 16714e8a1e..8c810decca 100644
--- a/py/torch_tensorrt/distributed/_distributed.py
+++ b/py/torch_tensorrt/distributed/_distributed.py
@@ -205,3 +205,28 @@ def set_distributed_mode(group: Any, module: nn.Module) -> None:
             seen.add(id(engine))
             if getattr(engine, "requires_native_multidevice", False):
                 engine.set_group_name(group_name)
+
+
+def is_distributed_caching_enabled(
+    is_engine_caching_supported: bool,
+    cache_built_engines: bool,
+    reuse_cached_engines: bool,
+) -> bool:
+    """Check if distributed engine cache coordination should be used.
+
+    Returns True when all conditions are met:
+    - Engine caching is supported (cache exists, refit available, mutable weights)
+    - User enabled both cache_built_engines and reuse_cached_engines
+    - Running in a distributed environment with world_size > 1
+
+    When True, only one rank builds the TRT engine and caches it.
+    Other ranks wait and load from the shared DiskEngineCache.
+ """ + return ( + is_engine_caching_supported + and cache_built_engines + and reuse_cached_engines + and dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 77e48fe92e..cdf5c82eea 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,10 +4,12 @@ import logging from typing import Any, Dict, List, NamedTuple, Optional, Sequence +import tensorrt as trt import torch from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input +from torch_tensorrt.distributed._distributed import is_distributed_caching_enabled from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import ( @@ -25,8 +27,6 @@ ) from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -262,6 +262,41 @@ def interpret_module_to_result( if serialized_interpreter_result is not None: # hit the cache return serialized_interpreter_result + # Distributed engine cache coordination: only one rank builds, + # others wait and load from shared cache. + _distributed_caching = is_distributed_caching_enabled( + is_engine_caching_supported, + settings.cache_built_engines, + settings.reuse_cached_engines, + ) + _lock: Optional[Any] = None + + if _distributed_caching: + import os as _os + + from filelock import FileLock + + # is_distributed_caching_enabled guarantees engine_cache and hash_val are set. + assert engine_cache is not None + assert hash_val is not None + + _lock_path = _os.path.join(engine_cache.engine_cache_dir, f".{hash_val}.lock") + _lock = FileLock(_lock_path, timeout=600) + _lock.acquire() + + # Check cache again — another rank may have built while we waited + cached = pull_cached_engine( + hash_val, + module, + engine_cache, + settings, + inputs, + symbolic_shape_expressions, + ) + if cached is not None: + _lock.release() + return cached + output_dtypes = infer_module_output_dtypes( module, truncate_double=settings.truncate_double ) @@ -307,6 +342,10 @@ def interpret_module_to_result( hash_val, interpreter_result, engine_cache, settings, inputs ) + # Release the filelock so other ranks can proceed + if _distributed_caching and _lock is not None: + _lock.release() + serialized_engine = interpreter_result.engine.serialize() with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) diff --git a/tests/py/dynamo/distributed/test_distributed_engine_cache.py b/tests/py/dynamo/distributed/test_distributed_engine_cache.py new file mode 100644 index 0000000000..9f2253d7ef --- /dev/null +++ b/tests/py/dynamo/distributed/test_distributed_engine_cache.py @@ -0,0 +1,322 @@ +""" +Distributed engine cache coordination tests. + +Verifies that when multiple ranks compile the same model with engine caching +enabled, only one rank builds the TRT engine and others load from the shared +DiskEngineCache via filelock coordination. 
+
+Run multi-rank tests (requires 2 GPUs):
+    pytest distributed/test_distributed_engine_cache.py::TestMultirankDistributedCache -v
+
+Run via torchrun:
+    torchrun --nproc_per_node=2 distributed/test_distributed_engine_cache.py --multirank
+"""
+
+from __future__ import annotations
+
+import os
+import shutil
+import sys
+import tempfile
+import time
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    requires_nccl,
+    skip_if_lt_x_gpu,
+)
+from torch.testing._internal.common_utils import run_tests
+
+# ---------------------------------------------------------------------------
+# Capability checks
+# ---------------------------------------------------------------------------
+
+
+def is_nccl_available() -> bool:
+    try:
+        return dist.is_nccl_available()
+    except Exception:
+        return False
+
+
+# ---------------------------------------------------------------------------
+# Models
+# ---------------------------------------------------------------------------
+
+
+class SimpleModel(nn.Module):
+    """Small model for cache coordination testing (no TP)."""
+
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(32, 64)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(64, 32)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.fc2(self.relu(self.fc1(x)))
+
+
+# ---------------------------------------------------------------------------
+# Shared helpers
+# ---------------------------------------------------------------------------
+
+
+def _clean_cache_dir(cache_dir: str, rank: int) -> None:
+    """Rank 0 cleans the cache dir; all ranks wait, then recreate it."""
+    if rank == 0 and os.path.exists(cache_dir):
+        shutil.rmtree(cache_dir)
+    dist.barrier()
+    os.makedirs(cache_dir, exist_ok=True)
+
+
+def _compile_with_cache(model, cache_dir, use_distributed_mode_trace=False):
+    """Compile the model with TRT + engine caching enabled."""
+    import torch_tensorrt
+
+    return torch.compile(
+        model,
+        backend="torch_tensorrt",
+        options={
+            "enabled_precisions": {torch.float32},
+            "use_python_runtime": False,
+            "min_block_size": 1,
+            "cache_built_engines": True,
+            "reuse_cached_engines": True,
+            "immutable_weights": False,
+            "use_distributed_mode_trace": use_distributed_mode_trace,
+            "engine_cache_dir": cache_dir,
+            "engine_cache_size": 1 << 30,
+        },
+    )
+
+
+def _assert_cache_has_files(cache_dir: str, rank: int) -> int:
+    """Assert that the cache dir has engine files (ignoring .lock files)."""
+    cache_files = [f for f in os.listdir(cache_dir) if not f.endswith(".lock")]
+    assert len(cache_files) > 0, f"Rank {rank}: cache dir is empty"
+    return len(cache_files)
+
+
+# ---------------------------------------------------------------------------
+# Test logic
+# ---------------------------------------------------------------------------
+
+
+def _multirank_distributed_cache_test(
+    rank: int, world_size: int, device: torch.device, cache_dir: str
+) -> None:
+    """Test cache coordination with a simple model (same weights on all ranks).
+
+    Phase 1: First compile — one rank builds, the other loads from cache.
+    Phase 2: Second compile — both ranks load from cache (refit).
+ """ + torch.manual_seed(42) + model = SimpleModel().eval().to(device) + inp = torch.randn(2, 32, device=device) + + with torch.no_grad(): + ref_output = model(inp) + + # Phase 1: compile + cache + torch._dynamo.reset() + trt_model = _compile_with_cache(model, cache_dir) + with torch.no_grad(): + t0 = time.time() + trt_output = trt_model(inp) + build_time = time.time() - t0 + + diff = (ref_output - trt_output).abs().max().item() + assert diff < 0.01, f"Rank {rank}: output mismatch, max diff={diff}" + n_files = _assert_cache_has_files(cache_dir, rank) + print( + f"[Rank {rank}] Compile + cache OK " + f"(max_diff={diff:.6f}, build_time={build_time:.2f}s, cache_entries={n_files})", + flush=True, + ) + + dist.barrier() + + # Phase 2: cache reuse + torch._dynamo.reset() + trt_model2 = _compile_with_cache(model, cache_dir) + with torch.no_grad(): + t0 = time.time() + trt_output2 = trt_model2(inp) + refit_time = time.time() - t0 + + diff2 = (ref_output - trt_output2).abs().max().item() + assert diff2 < 0.01, f"Rank {rank}: cached output mismatch, max diff={diff2}" + print( + f"[Rank {rank}] Cache reuse OK " + f"(max_diff={diff2:.6f}, refit_time={refit_time:.2f}s, " + f"speedup={build_time / max(refit_time, 0.001):.1f}x)", + flush=True, + ) + + +def _multirank_tp_cache_test( + rank: int, world_size: int, device: torch.device, cache_dir: str +) -> None: + """Test cache coordination with TP-sharded model (different weights per rank). + + Each rank holds ColwiseParallel/RowwiseParallel sharded weights. + Engine structure is identical — only weights differ. + One rank builds, other loads from cache and refits with its own weights. + """ + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, + ) + + mesh = init_device_mesh("cuda", (world_size,)) + + torch.manual_seed(42) + model = ( + nn.Sequential( + nn.Linear(64, 128, bias=False), + nn.ReLU(), + nn.Linear(128, 64, bias=False), + ) + .eval() + .to(device) + ) + + inp = torch.randn(2, 64, device=device) + + with torch.no_grad(): + ref_output = model(inp) + + tp_plan = {"0": ColwiseParallel(), "2": RowwiseParallel()} + parallelize_module(model, mesh, tp_plan) + + torch._dynamo.reset() + trt_model = _compile_with_cache(model, cache_dir, use_distributed_mode_trace=True) + with torch.no_grad(): + trt_output = trt_model(inp) + + diff = (ref_output - trt_output).abs().max().item() + assert diff < 0.05, f"Rank {rank}: TP output mismatch, max diff={diff}" + n_files = _assert_cache_has_files(cache_dir, rank) + print( + f"[Rank {rank}] TP compile + cache OK " + f"(max_diff={diff:.6f}, cache_entries={n_files})", + flush=True, + ) + + +# --------------------------------------------------------------------------- +# pytest path (MultiProcessTestCase) +# --------------------------------------------------------------------------- + + +@unittest.skipIf(not is_nccl_available(), "NCCL not available") +@unittest.skipIf( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + "Requires at least 2 GPUs", +) +class TestMultirankDistributedCache(MultiProcessTestCase): + """Distributed engine cache tests as pytest-compatible MultiProcessTestCase.""" + + world_size = 2 + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def _init_dist(self, cache_dir: str) -> torch.device: + """Initialize dist, clean cache, return device.""" + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", 
+            store=store,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        local = self.rank % torch.cuda.device_count()
+        torch.cuda.set_device(local)
+        _clean_cache_dir(cache_dir, self.rank)
+        return torch.device(f"cuda:{local}")
+
+    @requires_nccl()
+    def test_distributed_cache_coordination(self) -> None:
+        """Both ranks compile the same model with caching — outputs match the reference."""
+        cache_dir = tempfile.mkdtemp(prefix="trt_dist_cache_pytest_")
+        device = self._init_dist(cache_dir)
+        _multirank_distributed_cache_test(self.rank, self.world_size, device, cache_dir)
+
+    @requires_nccl()
+    def test_tp_cache_coordination(self) -> None:
+        """TP-sharded model: one rank builds, the other loads from cache and refits."""
+        cache_dir = tempfile.mkdtemp(prefix="trt_dist_tp_cache_pytest_")
+        device = self._init_dist(cache_dir)
+        _multirank_tp_cache_test(self.rank, self.world_size, device, cache_dir)
+
+
+# ---------------------------------------------------------------------------
+# torchrun entry point
+# ---------------------------------------------------------------------------
+
+
+def _multirank_setup() -> tuple:
+    dist.init_process_group(backend="nccl")
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
+    torch.cuda.set_device(local_rank)
+    return rank, world_size, torch.device(f"cuda:{local_rank}")
+
+
+def run_multirank_tests() -> None:
+    """Entry point for --multirank mode (called by torchrun workers)."""
+    rank, world_size, device = _multirank_setup()
+    print(f"[Rank {rank}/{world_size}] device={device}", flush=True)
+
+    base_cache_dir = os.path.join(tempfile.gettempdir(), "trt_dist_cache_torchrun")
+
+    tests = [
+        (
+            "simple_model_cache",
+            base_cache_dir + "_simple",
+            _multirank_distributed_cache_test,
+        ),
+        ("tp_model_cache", base_cache_dir + "_tp", _multirank_tp_cache_test),
+    ]
+
+    failed = []
+    for name, cache_dir, test_fn in tests:
+        _clean_cache_dir(cache_dir, rank)
+        dist.barrier()
+        try:
+            test_fn(rank, world_size, device, cache_dir)
+        except Exception as e:
+            failed.append((name, str(e)))
+            print(f"[Rank {rank}] FAIL {name}: {e}", flush=True)
+            import traceback
+
+            traceback.print_exc()
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+    if failed:
+        print(f"[Rank {rank}] {len(failed)} test(s) FAILED:", flush=True)
+        for name, err in failed:
+            print(f"  - {name}: {err}", flush=True)
+        sys.exit(1)
+    else:
+        print(f"[Rank {rank}] All distributed cache tests PASSED.", flush=True)
+
+
+if __name__ == "__main__":
+    if "--multirank" in sys.argv or "--multinode" in sys.argv:
+        sys.argv = [a for a in sys.argv if a not in ("--multirank", "--multinode")]
+        run_multirank_tests()
+    else:
+        run_tests()
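
Note on the coordination pattern: the change to interpret_module_to_result follows a build-once-per-hash flow: check the cache, take a per-hash FileLock inside the cache directory, re-check the cache, build and insert on a miss, then release the lock. Below is a minimal standalone sketch of that pattern under stated assumptions; the helper name get_or_build, the blob and lock file paths, and the build_fn callable are illustrative only and are not part of the torch_tensorrt API. It assumes only that the filelock package is installed.

    # Minimal sketch of the build-once-per-hash coordination pattern (not the
    # library's API; names and paths below are hypothetical).
    import os

    from filelock import FileLock


    def get_or_build(cache_dir: str, hash_val: str, build_fn):
        """Return cached bytes for hash_val, building them at most once across processes."""
        os.makedirs(cache_dir, exist_ok=True)
        blob_path = os.path.join(cache_dir, f"{hash_val}.bin")
        lock_path = os.path.join(cache_dir, f".{hash_val}.lock")

        # Fast path: another process already built and cached this entry.
        if os.path.exists(blob_path):
            with open(blob_path, "rb") as f:
                return f.read()

        # Serialize builders of the same hash; the timeout guards against a stuck holder.
        with FileLock(lock_path, timeout=600):
            # Re-check: a process that held the lock before us may have built it.
            if os.path.exists(blob_path):
                with open(blob_path, "rb") as f:
                    return f.read()
            blob = build_fn()  # the expensive build runs in exactly one process
            with open(blob_path, "wb") as f:
                f.write(blob)
            return blob


    if __name__ == "__main__":
        # Toy usage: repeated calls with the same hash reuse the first build.
        print(get_or_build("/tmp/toy_cache", "abc123", lambda: b"engine-bytes"))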