Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/build-test-linux-x86_64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,11 @@ jobs:
python -m pytest -ra -v --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \
distributed/test_nccl_ops.py \
distributed/test_native_nccl.py \
distributed/test_export_save_load.py
distributed/test_export_save_load.py \
distributed/test_distributed_engine_cache.py
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_native_nccl.py --multirank
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_export_save_load.py --multirank
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_distributed_engine_cache.py --multirank
popd

concurrency:
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from torch_tensorrt.distributed._distributed import ( # noqa: F401
distributed_context,
is_distributed_caching_enabled,
set_distributed_mode,
)
from torch_tensorrt.distributed._nccl_utils import ( # noqa: F401
Expand Down
25 changes: 25 additions & 0 deletions py/torch_tensorrt/distributed/_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,28 @@ def set_distributed_mode(group: Any, module: nn.Module) -> None:
seen.add(id(engine))
if getattr(engine, "requires_native_multidevice", False):
engine.set_group_name(group_name)


def is_distributed_caching_enabled(
    is_engine_caching_supported: bool,
    cache_built_engines: bool,
    reuse_cached_engines: bool,
) -> bool:
    """Decide whether distributed engine-cache coordination applies.

    All of the following must hold:
    - Engine caching is supported (cache exists, refit available, mutable weights)
    - The user enabled both ``cache_built_engines`` and ``reuse_cached_engines``
    - The process is part of an initialized distributed job with world_size > 1

    When this returns True, a single rank builds the TRT engine and caches it,
    while the remaining ranks wait and load it from the shared DiskEngineCache.
    """
    # Caching must be both supported and fully opted into by the user.
    caching_requested = (
        is_engine_caching_supported and cache_built_engines and reuse_cached_engines
    )
    if not caching_requested:
        return False
    # Coordination is only meaningful inside an initialized process group.
    if not dist.is_available() or not dist.is_initialized():
        return False
    # A single-rank job has nothing to coordinate.
    return dist.get_world_size() > 1
43 changes: 41 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import logging
from typing import Any, Dict, List, NamedTuple, Optional, Sequence

import tensorrt as trt
import torch
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.distributed._distributed import is_distributed_caching_enabled
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import (
Expand All @@ -25,8 +27,6 @@
)
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -262,6 +262,41 @@ def interpret_module_to_result(
if serialized_interpreter_result is not None: # hit the cache
return serialized_interpreter_result

# Distributed engine cache coordination: only one rank builds,
# others wait and load from shared cache.
_distributed_caching = is_distributed_caching_enabled(
is_engine_caching_supported,
settings.cache_built_engines,
settings.reuse_cached_engines,
)
_lock: Optional[Any] = None

if _distributed_caching:
import os as _os

from filelock import FileLock

# is_distributed_caching_enabled guarantees engine_cache and hash_val are set.
assert engine_cache is not None
assert hash_val is not None

_lock_path = _os.path.join(engine_cache.engine_cache_dir, f".{hash_val}.lock")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should build a locking facility into the core BaseEngineCache class so that cache backends other than the disk cache behave the same way. For DiskEngineCache a FileLock is likely the right choice, but a different backend — say, S3 buckets — might need a different locking mechanism.

_lock = FileLock(_lock_path, timeout=600)
_lock.acquire()

# Check cache again — another rank may have built while we waited
cached = pull_cached_engine(
hash_val,
module,
engine_cache,
settings,
inputs,
symbolic_shape_expressions,
)
if cached is not None:
_lock.release()
return cached

output_dtypes = infer_module_output_dtypes(
module, truncate_double=settings.truncate_double
)
Expand Down Expand Up @@ -307,6 +342,10 @@ def interpret_module_to_result(
hash_val, interpreter_result, engine_cache, settings, inputs
)

# Release the filelock so other ranks can proceed
if _distributed_caching and _lock is not None:
_lock.release()

serialized_engine = interpreter_result.engine.serialize()
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
Expand Down
Loading
Loading