From e289caa39c16254212f0d8ebbcb6e11d784d2277 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Thu, 20 Nov 2025 14:15:37 -0800 Subject: [PATCH] Fully Sharded 2D Parallelism (#3558) Summary: **This diff introduces Fully Sharded 2D Parallelism in TorchRec. It brings forth significant memory (50%+) savings by sharding embedding tables when they are not in use.** After the embedding lookup, the embedding table is further sharded across the data parallel dimension until it is needed in the backward pass. This allows model layers after the embedding lookup to have more memory headroom. Enabling further scaling of the dense architecture. **Practically speaking, this saves 50%+ embedding memory per GPU which account for upwards of 10GB of memory saving on large models.** The peak memory during this step becomes, ```O(shard + shard/num_replication)```, which then leads to an embedding memory of ```O(shard/num_replication)``` after the lookup step. The memory free and collective communications are done in a overhead free manner by maximizing computation and communication collectives through asynchronous handling on multiple streams. With Fully Sharded 2D, the embedding weight synchronization has to happen every step or trained batches are lost across ranks. We use an asynchronous reduce scatter after the embedding lookup step. We are able to fully overlap this collective with compute to expose no additional overhead. A new awaitable is introduced, ```ReduceScatterResizeAwaitable``` under the Fully Sharded path that is called with SDD output_dist all to all. This awaitable ```wait()```s on the async reduce scatter and calls the ```resize()``` operation on the embedding memory ensuring no race conditions. Users can enable fully sharded 2D through, a new arg `ShardingStrategy` ``` DMPCollection(..., sharding_strategy=ShardingStrategy.FULLY_SHARDED) ``` This is part of our work to create an overhead free 2D parallel which will allow us to use it for every model. Remaining work from this diff is to launch an async all gather in the backward pass, making planner aware of such memory savings, and integrate this work with per module 2D Differential Revision: D82253387 --- .../distributed/batched_embedding_kernel.py | 364 ++++++++++++++++++ torchrec/distributed/embedding.py | 14 + torchrec/distributed/embedding_lookup.py | 84 +++- torchrec/distributed/embeddingbag.py | 14 + torchrec/distributed/model_parallel.py | 39 +- .../sharding/cw_sequence_sharding.py | 1 + torchrec/distributed/sharding/cw_sharding.py | 1 + .../sharding/dp_sequence_sharding.py | 1 + torchrec/distributed/sharding/dp_sharding.py | 1 + .../distributed/sharding/grid_sharding.py | 1 + .../sharding/rw_sequence_sharding.py | 1 + torchrec/distributed/sharding/rw_sharding.py | 1 + .../sharding/tw_sequence_sharding.py | 1 + torchrec/distributed/sharding/tw_sharding.py | 1 + .../distributed/sharding/twrw_sharding.py | 1 + torchrec/distributed/types.py | 18 +- 16 files changed, 512 insertions(+), 31 deletions(-) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 02c191b6e..5e5fe806b 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -13,10 +13,12 @@ import itertools import logging import tempfile +from collections import defaultdict, OrderedDict from dataclasses import dataclass from math import sqrt from typing import ( Any, + Callable, cast, Dict, FrozenSet, @@ -74,9 +76,12 @@ from torchrec.distributed.model_tracker.types import IndexedLookup from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import ( + LazyAwaitable, Shard, ShardedTensor, ShardedTensorMetadata, + ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, TensorProperties, @@ -108,6 +113,61 @@ ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming" +class ReduceScatterResizeAwaitable(LazyAwaitable[torch.Tensor]): + """ + Awaitable that packages async reduce scatter work with a deferred resize operation. + The resize happens only when wait() is called, allowing maximum overlap with computation. + + This enables allows us to ensure we're 1) calling wait() on the async operation until the last + possible moment and 2) avoid race conditions with tensor memory with the collecitve and resize op. + """ + + def __init__( + self, + async_work: Optional[dist.Work], + async_event: Optional[torch.cuda.Event], + async_stream: Optional[torch.cuda.Stream], + unsharded_param: torch.Tensor, + shard_buf: torch.Tensor, + resize_callback: Callable[[], None], + ) -> None: + """ + Args: + async_work: The async reduce scatter work handle + async_event: CUDA event to synchronize streams + async_stream: The communication stream + unsharded_param: The original unsharded parameter tensor + shard_buf: The buffer containing the sharded result + resize_callback: Callback to perform resize operation (called on wait()) + """ + super().__init__() + self._async_work = async_work + self._async_event = async_event + self._async_stream = async_stream + self._unsharded_param = unsharded_param + self._shard_buf = shard_buf + self._resize_callback = resize_callback + self._completed = False + + def _wait_impl(self) -> torch.Tensor: + """ + Wait for the async reduce scatter to complete, then perform the resize operation. + This is where the deferred resize actually happens. + """ + if self._completed: + return self._shard_buf + + if self._async_event is not None: + torch.cuda.current_stream().wait_event(self._async_event) + if self._async_work is not None: + self._async_work.wait() + + self._resize_callback() + + self._completed = True + return self._shard_buf + + def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]: # populate res_params, which is used for raw embedding streaming # here only populates the params available in fused_params and TBE configs @@ -2467,6 +2527,157 @@ def purge(self) -> None: self._emb_module.reset_cache_states() +class ShardedBatchedFusedEmbedding(BatchedFusedEmbedding): + """ + Hybird Sharded Version of BatchedFusedEmbedding. + + This is used with DMPCollection when ShardingStrategy.HYBRID is enabled. + """ + + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + env: Optional[ShardingEnv] = None, + ) -> None: + super().__init__(config, pg, device) + + assert isinstance( + env, ShardingEnv2D + ), "env is required for ShardedBatchedFusedEmbeddingBag" + self._env: ShardingEnv2D = env + + self.weights_sharded = False + # pyre-ignore[8] + self._original_shape: torch.Size = self._emb_module.weights_dev.shape + # pyre-ignore[8] + self._unsharded_param: torch.Tensor = self._emb_module.weights_dev + self._stash_nbytes: int = ( + self._emb_module.weights_dev.untyped_storage().nbytes() # pyre-ignore[29] + ) + self._shard_buf_nbytes: int = 0 + self.shard_buf: Optional[torch.Tensor] = None + + self._async_stream: torch.cuda.Stream = torch.cuda.Stream( + device=self._emb_module.weights_dev.device + ) + self._async_work: Optional[dist.Work] = None + self._async_event: Optional[torch.cuda.Event] = None + self._rs_awaitable: Optional[ReduceScatterResizeAwaitable] = None + + self.register_full_backward_pre_hook( + self._hybird_sharded_backward_hook, # pyre-ignore[6] + ) + + def _all_gather_table_weights(self) -> None: + if not self.weights_sharded: + return + self._wait_on_reduce_scatter() + self._unsharded_param.untyped_storage().resize_(self._stash_nbytes) + + dist.all_gather_into_tensor( + output_tensor=self._unsharded_param, + input_tensor=self.shard_buf, + group=self._env.replica_pg, + async_op=False, + ) + # pyre-ignore[16] + self._emb_module.weights_dev = self._unsharded_param + # pyre-ignore[16] + self.shard_buf.untyped_storage().resize_(0) + self.weights_sharded = False + + def _hybird_sharded_backward_hook( + self, module: nn.Module, grad_input: List[torch.Tensor] + ) -> None: + self._all_gather_table_weights() + + def _wait_on_reduce_scatter(self) -> None: + """ + Ensure the post embedding lookup reduce scatter is finished before backward. + + Ideally, backward does not need to wait on RS, as we will be all gathering the shards + in the backward pass. + + Now uses the awaitable mechanism to defer resize until needed. + """ + # pyre-ignore[16] + self._rs_awaitable.wait() + self._rs_awaitable = None + + def get_rs_awaitable(self) -> Optional[ReduceScatterResizeAwaitable]: + """ + Get the current reduce scatter awaitable. + This can be used by higher-level modules to compose awaitables. + """ + return self._rs_awaitable + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + embs = super().forward(features) + self._async_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._async_stream): + self._rs_awaitable = self._reduce_scatter_weights_async() + return embs + + def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable: + """ + Launch async reduce scatter but defer the resize operation. + Returns an awaitable that will perform resize on wait(). + + This allows the resize operation to be deferred until the result is actually needed, + maximizing overlap between async communication and computation. + """ + with torch.no_grad(): + self.weights_sharded = True + num_groups = self._env.num_sharding_groups() + + # pyre-ignore[29] + total_size = self._emb_module.weights_dev.numel() + shard_size = total_size // num_groups + + if self.shard_buf is None: + self.shard_buf = torch.empty( + shard_size, + # pyre-ignore[6] + dtype=self._emb_module.weights_dev.dtype, + # pyre-ignore[6] + device=self._emb_module.weights_dev.device, + ) + # pyre-ignore[16] + self._shard_buf_nbytes = self.shard_buf.untyped_storage().nbytes() + else: + self.shard_buf.untyped_storage().resize_(self._shard_buf_nbytes) + + # pyre-ignore[29] + input_tensor = self._emb_module.weights_dev.contiguous() + + self._async_work = dist.reduce_scatter_tensor( + output=self.shard_buf, + input=input_tensor, + op=dist.ReduceOp.AVG, + group=self._env.replica_pg, + async_op=True, + ) + + self._async_event = torch.cuda.Event(enable_timing=False, blocking=False) + # pyre-ignore[16] + self._async_event.record(self._async_stream) + + def resize_callback() -> None: + self._emb_module.weights_dev.untyped_storage().resize_(0) # pyre-ignore[29] + self._emb_module.weights_dev = self.shard_buf # pyre-ignore[16] + + return ReduceScatterResizeAwaitable( + async_work=self._async_work, + async_event=self._async_event, + async_stream=self._async_stream, + unsharded_param=self._unsharded_param, + shard_buf=self.shard_buf, + resize_callback=resize_callback, + ) + + class BatchedDenseEmbedding(BaseBatchedEmbedding[torch.Tensor]): def __init__( self, @@ -3238,6 +3449,7 @@ def __init__( pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, sharding_type: Optional[ShardingType] = None, + env: Optional[ShardingEnv] = None, ) -> None: super().__init__(config, pg, device, sharding_type) @@ -3356,6 +3568,158 @@ def purge(self) -> None: self._emb_module.reset_cache_states() +class ShardedBatchedFusedEmbeddingBag(BatchedFusedEmbeddingBag): + """ + Hybird Sharded Version of BatchedFusedEmbeddingBag. + + This is used with DMPCollection when ShardingStrategy.HYBRID is enabled. + """ + + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, + env: Optional[ShardingEnv] = None, + ) -> None: + super().__init__(config, pg, device, sharding_type) + + assert isinstance( + env, ShardingEnv2D + ), "env is required for ShardedBatchedFusedEmbeddingBag" + self._env: ShardingEnv2D = env + + self.weights_sharded = False + # pyre-ignore[8] + self._original_shape: torch.Size = self._emb_module.weights_dev.shape + # pyre-ignore[8] + self._unsharded_param: torch.Tensor = self._emb_module.weights_dev + self._stash_nbytes: int = ( + self._emb_module.weights_dev.untyped_storage().nbytes() # pyre-ignore[29] + ) + self._shard_buf_nbytes: int = 0 + self.shard_buf: Optional[torch.Tensor] = None + + self._async_stream: torch.cuda.Stream = torch.cuda.Stream( + device=self._emb_module.weights_dev.device + ) + self._async_work: Optional[dist.Work] = None + self._async_event: Optional[torch.cuda.Event] = None + self._rs_awaitable: Optional[ReduceScatterResizeAwaitable] = None + + self.register_full_backward_pre_hook( + self._hybird_sharded_backward_hook, # pyre-ignore[6] + ) + + def _all_gather_table_weights(self) -> None: + if not self.weights_sharded: + return + self._wait_on_reduce_scatter() + self._unsharded_param.untyped_storage().resize_(self._stash_nbytes) + + dist.all_gather_into_tensor( + output_tensor=self._unsharded_param, + input_tensor=self.shard_buf, + group=self._env.replica_pg, + async_op=False, + ) + # pyre-ignore[16] + self._emb_module.weights_dev = self._unsharded_param + # pyre-ignore[16] + self.shard_buf.untyped_storage().resize_(0) + self.weights_sharded = False + + def _hybird_sharded_backward_hook( + self, module: nn.Module, grad_input: List[torch.Tensor] + ) -> None: + self._all_gather_table_weights() + + def _wait_on_reduce_scatter(self) -> None: + """ + Ensure the post embedding lookup reduce scatter is finished before backward. + + Ideally, backward does not need to wait on RS, as we will be all gathering the shards + in the backward pass. + + Now uses the awaitable mechanism to defer resize until needed. + """ + # pyre-ignore[16] + self._rs_awaitable.wait() + self._rs_awaitable = None + + def get_rs_awaitable(self) -> Optional[ReduceScatterResizeAwaitable]: + """ + Get the current reduce scatter awaitable. + This can be used by higher-level modules to compose awaitables. + """ + return self._rs_awaitable + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + embs = super().forward(features) + self._async_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._async_stream): + self._rs_awaitable = self._reduce_scatter_weights_async() + return embs + + def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable: + """ + Launch async reduce scatter but defer the resize operation. + Returns an awaitable that will perform resize on wait(). + + This allows the resize operation to be deferred until the result is actually needed, + maximizing overlap between async communication and computation. + """ + with torch.no_grad(): + self.weights_sharded = True + num_groups = self._env.num_sharding_groups() + + # pyre-ignore[29] + total_size = self._emb_module.weights_dev.numel() + shard_size = total_size // num_groups + + if self.shard_buf is None: + self.shard_buf = torch.empty( + shard_size, + # pyre-ignore[6] + dtype=self._emb_module.weights_dev.dtype, + # pyre-ignore[6] + device=self._emb_module.weights_dev.device, + ) + # pyre-ignore[16] + self._shard_buf_nbytes = self.shard_buf.untyped_storage().nbytes() + else: + self.shard_buf.untyped_storage().resize_(self._shard_buf_nbytes) + + # pyre-ignore[29] + input_tensor = self._emb_module.weights_dev.contiguous() + + self._async_work = dist.reduce_scatter_tensor( + output=self.shard_buf, + input=input_tensor, + op=dist.ReduceOp.AVG, + group=self._env.replica_pg, + async_op=True, + ) + + self._async_event = torch.cuda.Event(enable_timing=False, blocking=False) + # pyre-ignore[16] + self._async_event.record(self._async_stream) + + def resize_callback() -> None: + self._emb_module.weights_dev.untyped_storage().resize_(0) # pyre-ignore[29] + self._emb_module.weights_dev = self.shard_buf # pyre-ignore[16] + + return ReduceScatterResizeAwaitable( + async_work=self._async_work, + async_event=self._async_event, + async_stream=self._async_stream, + unsharded_param=self._unsharded_param, + shard_buf=self.shard_buf, + resize_callback=resize_callback, + ) + + class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor]): def __init__( self, diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 5e9a96ca0..84d3ef659 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -343,6 +343,7 @@ def __init__( module_fqn: Optional[str] = None, sharding_types: Optional[List[str]] = None, use_gather_select: bool = False, + resize_awaitables: Optional[List[Awaitable[torch.Tensor]]] = None, ) -> None: super().__init__() self._awaitables_per_sharding = awaitables_per_sharding @@ -354,6 +355,7 @@ def __init__( self._module_fqn = module_fqn self._sharding_types = sharding_types self._use_gather_select = use_gather_select + self._resize_awaitables = resize_awaitables def _wait_impl(self) -> Dict[str, JaggedTensor]: jt_dict: Dict[str, JaggedTensor] = {} @@ -398,6 +400,12 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]: use_gather_select=self._use_gather_select, ) ) + + # free memory and resize + # pyre-ignore[16] + for awaitable in self._resize_awaitables: + awaitable.wait() + return jt_dict @@ -1588,6 +1596,8 @@ def compute_and_output_dist( ) -> LazyAwaitable[Dict[str, JaggedTensor]]: awaitables_per_sharding: List[Awaitable[torch.Tensor]] = [] features_before_all2all_per_sharding: List[KeyedJaggedTensor] = [] + resize_awaitables = [] + for lookup, odist, features, sharding_ctx, sharding_type in zip( self._lookups, self._output_dists, @@ -1604,6 +1614,9 @@ def compute_and_output_dist( EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type ): embs = lookup(features) + if hasattr(lookup, "get_resize_awaitables"): + # pyre-ignore[29] + resize_awaitables.extend(lookup.get_resize_awaitables()) if self.post_lookup_tracker_fn is not None: self.post_lookup_tracker_fn(features, embs, self, None) @@ -1631,6 +1644,7 @@ def compute_and_output_dist( module_fqn=self._module_fqn, sharding_types=list(self._sharding_type_to_sharding.keys()), use_gather_select=self._use_gather_select, + resize_awaitables=resize_awaitables, ) def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int: diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 9f3ce69c7..2d14e4cde 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -39,6 +39,8 @@ BatchedFusedEmbeddingBag, KeyValueEmbedding, KeyValueEmbeddingBag, + ShardedBatchedFusedEmbedding, + ShardedBatchedFusedEmbeddingBag, ZeroCollisionEmbeddingCache, ZeroCollisionKeyValueEmbedding, ZeroCollisionKeyValueEmbeddingBag, @@ -65,7 +67,15 @@ QuantBatchedEmbedding, QuantBatchedEmbeddingBag, ) -from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType +from torchrec.distributed.types import ( + LazyAwaitable, + rank_device, + ShardedTensor, + ShardingEnv, + ShardingEnv2D, + ShardingStrategy, + ShardingType, +) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor logger: logging.Logger = logging.getLogger(__name__) @@ -185,12 +195,15 @@ def __init__( grouped_configs: List[GroupedEmbeddingConfig], pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, + env: Optional[ShardingEnv] = None, ) -> None: super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() self._need_prefetch: bool = False for config in grouped_configs: - self._emb_modules.append(self._create_embedding_kernel(config, pg, device)) + self._emb_modules.append( + self._create_embedding_kernel(config, pg, device, env) + ) self._feature_splits: List[int] = [] for config in grouped_configs: @@ -218,6 +231,7 @@ def _create_embedding_kernel( config: GroupedEmbeddingConfig, pg: Optional[dist.ProcessGroup], device: Optional[torch.device], + env: Optional[ShardingEnv] = None, ) -> BaseEmbedding: for table in config.embedding_tables: if ( @@ -234,11 +248,20 @@ def _create_embedding_kernel( device=device, ) elif config.compute_kernel == EmbeddingComputeKernel.FUSED: - return BatchedFusedEmbedding( - config=config, - pg=pg, - device=device, - ) + if ( + env + and isinstance(env, ShardingEnv2D) + and env.sharding_strategy == ShardingStrategy.FULLY_SHARDED + ): + return ShardedBatchedFusedEmbedding( + config=config, pg=pg, device=device, env=env + ) + else: + return BatchedFusedEmbedding( + config=config, + pg=pg, + device=device, + ) elif config.compute_kernel == EmbeddingComputeKernel.KEY_VALUE: return KeyValueEmbedding( config=config, @@ -329,6 +352,14 @@ def forward( return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor) + def get_resize_awaitables(self) -> List[LazyAwaitable[torch.Tensor]]: + # TODO - we can probably do some smart grouping to make this more efficient + return [ + emb_module.get_rs_awaitable() # pyre-ignore[29] + for emb_module in self._emb_modules + if hasattr(emb_module, "get_rs_awaitable") + ] + # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. def state_dict( self, @@ -512,12 +543,14 @@ def __init__( feature_processor: Optional[BaseGroupedFeatureProcessor] = None, scale_weight_gradients: bool = True, sharding_type: Optional[ShardingType] = None, + env: Optional[ShardingEnv] = None, ) -> None: super().__init__() + self._env = env self._emb_modules: nn.ModuleList = nn.ModuleList() for config in grouped_configs: self._emb_modules.append( - self._create_embedding_kernel(config, device, pg, sharding_type) + self._create_embedding_kernel(config, device, pg, sharding_type, env) ) self._feature_splits: List[int] = [] @@ -555,6 +588,7 @@ def _create_embedding_kernel( device: Optional[torch.device], pg: Optional[dist.ProcessGroup], sharding_type: Optional[ShardingType], + env: Optional[ShardingEnv], ) -> BaseEmbedding: if config.compute_kernel == EmbeddingComputeKernel.DENSE: return BatchedDenseEmbeddingBag( @@ -564,12 +598,26 @@ def _create_embedding_kernel( sharding_type=sharding_type, ) elif config.compute_kernel == EmbeddingComputeKernel.FUSED: - return BatchedFusedEmbeddingBag( - config=config, - pg=pg, - device=device, - sharding_type=sharding_type, - ) + if ( + env + and isinstance(env, ShardingEnv2D) + and env.sharding_strategy == ShardingStrategy.FULLY_SHARDED + ): + return ShardedBatchedFusedEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + env=env, + ) + else: + return BatchedFusedEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + env=env, + ) elif config.compute_kernel in { EmbeddingComputeKernel.KEY_VALUE, }: @@ -744,6 +792,14 @@ def forward( dim=1, ) + def get_resize_awaitables(self) -> List[LazyAwaitable[torch.Tensor]]: + # TODO - we can probably do some smart grouping to make this more efficient + return [ + emb_module.get_rs_awaitable() # pyre-ignore[29] + for emb_module in self._emb_modules + if hasattr(emb_module, "get_rs_awaitable") + ] + # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. def state_dict( self, diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 754b9e6fa..a7fd36ca6 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -407,6 +407,7 @@ def __init__( embedding_names: List[str], module_fqn: Optional[str] = None, sharding_types: Optional[List[str]] = None, + resize_awaitables: Optional[List[Awaitable[torch.Tensor]]] = None, ) -> None: super().__init__() self._awaitables = awaitables @@ -414,6 +415,7 @@ def __init__( self._embedding_names = embedding_names self._module_fqn = module_fqn self._sharding_types = sharding_types + self._resize_awaitables = resize_awaitables def _wait_impl(self) -> KeyedTensor: embeddings = [] @@ -425,6 +427,12 @@ def _wait_impl(self) -> KeyedTensor: ): embeddings.append(w.wait()) + # free memory and resize + if self._resize_awaitables is not None: + # pyre-ignore[16] + for awaitable in self._resize_awaitables: + awaitable.wait() + return construct_output_kt( embeddings=embeddings, embedding_names=self._embedding_names, @@ -1655,6 +1663,7 @@ def compute_and_output_dist( """ batch_size_per_feature_pre_a2a = [] awaitables = [] + resize_awaitables = [] # No usage of zip for dynamo for i in range(len(self._lookups)): @@ -1669,7 +1678,11 @@ def compute_and_output_dist( self._module_fqn, sharding_type, ): + # with fully sharded 2D enabled, it returns an awaitable for the reduce scatter and resize operation embs = lookup(features) + if hasattr(lookup, "get_resize_awaitables"): + # pyre-ignore[29] + resize_awaitables.extend(lookup.get_resize_awaitables()) if self.post_lookup_tracker_fn is not None: self.post_lookup_tracker_fn(features, embs, self, None) @@ -1710,6 +1723,7 @@ def compute_and_output_dist( embedding_names=self._embedding_names, module_fqn=self._module_fqn, sharding_types=self._sharding_types, + resize_awaitables=resize_awaitables, ) # register callback if there are features that need mean pooling diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index e17e03727..ccca503e5 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -54,6 +54,7 @@ ShardingEnv, ShardingEnv2D, ShardingPlan, + ShardingStrategy, ) from torchrec.distributed.utils import ( add_prefix_to_state_dict, @@ -885,6 +886,7 @@ def __init__( world_size: int, sharding_group_size: int, global_pg: dist.ProcessGroup, + sharding_strategy: ShardingStrategy = ShardingStrategy.DEFAULT, node_group_size: Optional[int] = None, sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, init_data_parallel: bool = True, @@ -974,9 +976,11 @@ def __init__( default_env = ShardingEnv2D( global_pg=self._pg, sharding_pg=self._ctxs[0].sharding_pg, + replica_pg=self._ctxs[0].replica_pg, device_mesh=self._ctxs[0].device_mesh, node_group_size=node_group_size, use_inter_host_allreduce=self._ctxs[0].use_inter_host_allreduce, + sharding_strategy=sharding_strategy, ) super().__init__( # type: ignore[misc] @@ -1014,6 +1018,7 @@ def _shard_modules_impl( env = ShardingEnv2D( global_pg=self._pg, sharding_pg=ctx.sharding_pg, + replica_pg=ctx.replica_pg, device_mesh=ctx.device_mesh, node_group_size=ctx.sharding_group_size, use_inter_host_allreduce=ctx.use_inter_host_allreduce, @@ -1062,15 +1067,14 @@ def sync(self, include_optimizer_state: bool = True) -> None: def _sync( self, replica_pg: dist.ProcessGroup, - modules_to_sync: List[nn.Module], + modules_to_sync: List[Tuple[nn.Module, nn.Module]], include_optimizer_state: bool = True, ) -> None: assert replica_pg is not None, "replica_pg is not initialized!" all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list) - for emb_kernel in modules_to_sync: - # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. - for w in emb_kernel.split_embedding_weights(): + for emb_kernel, _ in modules_to_sync: + for w in emb_kernel.split_embedding_weights(): # pyre-ignore[29] all_weights_by_dtype[w.dtype].append(w) opts = None @@ -1085,7 +1089,7 @@ def _sync( optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = ( defaultdict(list) ) - for emb_kernel in modules_to_sync: + for emb_kernel, _ in modules_to_sync: # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. optimizer_states = emb_kernel.get_optimizer_state() for state in optimizer_states: @@ -1267,46 +1271,49 @@ def _group_sharded_modules( # Group leftover embedding kernels, with respect to default context # pyre-ignore[9] modules_to_skip: List[nn.Module] = [c.sharded_module for c in contexts[1:]] - sharded_modules: List[nn.Module] = [] + sharded_modules: List[Tuple[nn.Module, nn.Module]] = [] def _find_sharded_modules( module: nn.Module, + prev_module: nn.Module, ) -> None: if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen): - sharded_modules.append(module) + sharded_modules.append((module, prev_module)) if not isinstance( module, tuple(modules_to_skip) # pyre-ignore[6] ) and hasattr(module, "_lookups"): for lookup in module._lookups: # pyre-ignore[29] - _find_sharded_modules(lookup) + _find_sharded_modules(lookup, module) for _, child in module.named_children(): - _find_sharded_modules(child) + _find_sharded_modules(child, module) - _find_sharded_modules(self._dmp_wrapped_module) + # pyre-ignore[6] + _find_sharded_modules(self._dmp_wrapped_module, None) contexts[0].modules_to_sync = sharded_modules def _group_sharded_module( self, sharded_module: nn.Module, - ) -> List[nn.Module]: + ) -> List[Tuple[nn.Module, nn.Module]]: # Traverse module and find all sharded module kernels matching the sharded module # Post init DMP, save the embedding kernels - sharded_modules: List[nn.Module] = [] + sharded_modules: List[Tuple[nn.Module, nn.Module]] = [] def _find_sharded_modules( module: nn.Module, + prev_module: nn.Module, ) -> None: if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen): - sharded_modules.append(module) + sharded_modules.append((module, prev_module)) if isinstance(module, sharded_module): # pyre-ignore[6] for lookup in module._lookups: # pyre-ignore[29] - _find_sharded_modules(lookup) + _find_sharded_modules(lookup, module) for _, child in module.named_children(): - _find_sharded_modules(child) + _find_sharded_modules(child, module) - _find_sharded_modules(self._dmp_wrapped_module) + _find_sharded_modules(self._dmp_wrapped_module, None) return sharded_modules @property diff --git a/torchrec/distributed/sharding/cw_sequence_sharding.py b/torchrec/distributed/sharding/cw_sequence_sharding.py index 643e1d815..6f0fa4a71 100644 --- a/torchrec/distributed/sharding/cw_sequence_sharding.py +++ b/torchrec/distributed/sharding/cw_sequence_sharding.py @@ -68,6 +68,7 @@ def create_lookup( grouped_configs=self._grouped_embedding_configs, pg=self._pg, device=device if device is not None else self._device, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index 90a2e6bef..e777798d4 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -287,6 +287,7 @@ def create_lookup( pg=self._pg, device=device if device is not None else self._device, feature_processor=feature_processor, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/dp_sequence_sharding.py b/torchrec/distributed/sharding/dp_sequence_sharding.py index 2ad11b247..aaf4752b9 100644 --- a/torchrec/distributed/sharding/dp_sequence_sharding.py +++ b/torchrec/distributed/sharding/dp_sequence_sharding.py @@ -80,6 +80,7 @@ def create_lookup( grouped_configs=self._grouped_embedding_configs, pg=self._env.process_group, device=device if device is not None else self._device, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/dp_sharding.py b/torchrec/distributed/sharding/dp_sharding.py index 6ffb52e4c..47f9f7732 100644 --- a/torchrec/distributed/sharding/dp_sharding.py +++ b/torchrec/distributed/sharding/dp_sharding.py @@ -213,6 +213,7 @@ def create_lookup( # For data parallel we need to turn always gradient scaling in for weights # because get_gradient_scaling from comm_ops only affects model_parallel tables, not DP scale_weight_gradients=False, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py index 88edbbe87..87984a977 100644 --- a/torchrec/distributed/sharding/grid_sharding.py +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -508,6 +508,7 @@ def create_lookup( device=device if device is not None else self._device, feature_processor=feature_processor, sharding_type=ShardingType.TABLE_ROW_WISE, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 1ca53d40b..7e3dbd133 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -157,6 +157,7 @@ def create_lookup( grouped_configs=self._grouped_embedding_configs, pg=self._pg, device=device if device is not None else self._device, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index f6ea83798..8c051d00b 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -678,6 +678,7 @@ def create_lookup( device=device if device is not None else self._device, feature_processor=feature_processor, sharding_type=ShardingType.ROW_WISE, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/tw_sequence_sharding.py b/torchrec/distributed/sharding/tw_sequence_sharding.py index 1a7b517bc..d5a77422a 100644 --- a/torchrec/distributed/sharding/tw_sequence_sharding.py +++ b/torchrec/distributed/sharding/tw_sequence_sharding.py @@ -141,6 +141,7 @@ def create_lookup( grouped_configs=self._grouped_embedding_configs, pg=self._pg, device=device if device is not None else self._device, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index 7b1f211ea..a3f637201 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -435,6 +435,7 @@ def create_lookup( pg=self._pg, device=device if device is not None else self._device, feature_processor=feature_processor, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index 86b5da9fc..7823ac82c 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -676,6 +676,7 @@ def create_lookup( device=device if device is not None else self._device, feature_processor=feature_processor, sharding_type=ShardingType.TABLE_ROW_WISE, + env=self._env, ) def create_output_dist( diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 0144b208d..604bc2978 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -921,6 +921,16 @@ def from_local(cls, world_size: int, rank: int) -> "ShardingEnv": return cls(world_size, rank, None) +class ShardingStrategy(Enum): + """ + Sharding strategy for DMPCollection. + """ + + DEFAULT = "default" + PER_MODULE = "per_module" + FULLY_SHARDED = "fully_sharded" + + class ShardingEnv2D(ShardingEnv): """ Creates a sharding environment for 2D parallelism, enables usage of 2D parallelism in sharding @@ -945,10 +955,12 @@ class ShardingEnv2D(ShardingEnv): def __init__( self, sharding_pg: dist.ProcessGroup, + replica_pg: dist.ProcessGroup, global_pg: dist.ProcessGroup, device_mesh: DeviceMesh, node_group_size: Optional[int] = None, use_inter_host_allreduce: bool = False, + sharding_strategy: ShardingStrategy = ShardingStrategy.DEFAULT, ) -> None: assert device_mesh.ndim == 2, "DeviceMesh must be two dimensional!" self.world_size: int = dist.get_world_size(sharding_pg) @@ -959,10 +971,12 @@ def __init__( global_pg # to keep consistent naming between ShardingEnv and ShardingEnv2D ) self.sharding_pg: dist.ProcessGroup = sharding_pg + self.replica_pg: dist.ProcessGroup = replica_pg self.device_mesh: DeviceMesh = device_mesh self.node_group_size: Optional[int] = node_group_size self.output_dtensor: bool = True self.use_inter_host_allreduce: bool = use_inter_host_allreduce + self.sharding_strategy: ShardingStrategy = sharding_strategy def num_sharding_groups(self) -> int: """ @@ -1396,5 +1410,7 @@ class DMPCollectionContext(DMPCollectionConfig): device_mesh: "DeviceMesh" = field(init=False) sharding_pg: "dist.ProcessGroup" = field(init=False) replica_pg: "dist.ProcessGroup" = field(init=False) - modules_to_sync: List[nn.Module] = field(init=False, default_factory=list) + modules_to_sync: List[Tuple[nn.Module, nn.Module]] = field( + init=False, default_factory=list + ) sharded_module: Optional[nn.Module] = field(init=False, default=None)