From 44cb5a85b433225e2247364ad3d5fc69f66c3361 Mon Sep 17 00:00:00 2001
From: Joey Yang
Date: Mon, 17 Nov 2025 15:42:44 -0800
Subject: [PATCH] Extend raw_id_tracker to track
 ShardedManagedCollisionEmbeddingCollection (#3545)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2131

X-link: https://github.com/pytorch/FBGEMM/pull/5128

This diff extends `raw_id_tracker` to support both
`ShardedManagedCollisionEmbeddingCollection` and
`ShardedManagedCollisionEmbeddingBagCollection`. To avoid code duplication,
it refactors the tracker initialization and identity-parsing logic into the
shared `BaseEmbedding` class.

Reviewed By: chouxi, aliafzal

Differential Revision: D87018676
---
 .../distributed/batched_embedding_kernel.py   |  98 ++---------------
 torchrec/distributed/embedding_kernel.py      | 104 +++++++++++++++++-
 .../model_tracker/trackers/raw_id_tracker.py  |  18 ++-
 3 files changed, 123 insertions(+), 97 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index f0d4cbabe..02c191b6e 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -17,7 +17,6 @@
 from math import sqrt
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     FrozenSet,
@@ -109,22 +108,6 @@
 ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"


-class RawIdTrackerWrapper:
-    def __init__(
-        self,
-        get_indexed_lookups: Callable[
-            [List[str], Optional[str]],
-            Dict[str, List[torch.Tensor]],
-        ],
-        delete: Callable[
-            [int],
-            None,
-        ],
-    ) -> None:
-        self.get_indexed_lookups = get_indexed_lookups
-        self.delete = delete
-
-
 def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
     # populate res_params, which is used for raw embedding streaming
     # here only populates the params available in fused_params and TBE configs
@@ -1723,9 +1706,17 @@ def init_parameters(self) -> None:
         )

     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        hash_zch_identities = self._get_hash_zch_identities(features)
+        if hash_zch_identities is None:
+            return self.emb_module(
+                indices=features.values().long(),
+                offsets=features.offsets().long(),
+            )
+
         return self.emb_module(
             indices=features.values().long(),
             offsets=features.offsets().long(),
+            hash_zch_identities=hash_zch_identities,
         )

     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -2563,7 +2554,6 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
-        self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None

         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -2617,62 +2607,6 @@ def init_parameters(self) -> None:
                 weight_init_max,
             )

-    def _get_hash_zch_identities(
-        self, features: KeyedJaggedTensor
-    ) -> Optional[torch.Tensor]:
-        if self._raw_id_tracker_wrapper is None or not isinstance(
-            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
-        ):
-            return None
-
-        raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
-        assert (
-            raw_id_tracker_wrapper is not None
-        ), "self._raw_id_tracker_wrapper should not be None"
-        assert hasattr(
-            self.emb_module, "res_params"
-        ), "res_params should exist when raw_id_tracker is enabled"
-        res_params: RESParams = self.emb_module.res_params  # pyre-ignore[9]
-        table_names = res_params.table_names
-
-        # TODO: get_indexed_lookups() may return multiple IndexedLookup objects
-        # across multiple training iterations. Current logic appends raw_ids from
-        # all batches sequentially. This may cause misalignment with
-        # features.values() which only contains the current batch.
-        raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
-            table_names, self.emb_module.uuid
-        )
-
-        # Build hash_zch_identities by concatenating raw IDs from tracked tables.
-        # Output maintains 1-to-1 alignment with features.values().
-        # Iterate through table_names explicitly (not raw_ids_dict.values()) to
-        # ensure correct ordering, since there is no guarantee on dict ordering.
-        #
-        # E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
-        # where table1 has [feature1, feature2] and table2 has [feature3, feature4]
-        # then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
-        #
-        # TODO: Handle tables without identity tracking. Currently, only tables with
-        # raw_ids are included. If some tables lack identity while others have them,
-        # padding with -1 may be needed to maintain alignment.
-        all_raw_ids = []
-        for table_name in table_names:
-            if table_name in raw_ids_dict:
-                raw_ids_list = raw_ids_dict[table_name]
-                for raw_ids in raw_ids_list:
-                    all_raw_ids.append(raw_ids)
-
-        if not all_raw_ids:
-            return None
-
-        hash_zch_identities = torch.cat(all_raw_ids)
-        assert hash_zch_identities.size(0) == features.values().numel(), (
-            f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
-            f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
-        )
-
-        return hash_zch_identities
-
     def forward(
         self,
         features: KeyedJaggedTensor,
@@ -2775,22 +2709,6 @@ def named_parameters_by_table(
         for name, param in self._param_per_table.items():
             yield name, param

-    def init_raw_id_tracker(
-        self,
-        get_indexed_lookups: Callable[
-            [List[str], Optional[str]],
-            Dict[str, List[torch.Tensor]],
-        ],
-        delete: Callable[
-            [int],
-            None,
-        ],
-    ) -> None:
-        if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
-            self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
-                get_indexed_lookups, delete
-            )
-

 class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
     def __init__(
diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index 09a2c0375..79013c85f 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -11,10 +11,13 @@
 import copy
 import logging
 from collections import defaultdict, OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
     PartiallyMaterializedTensor,
 )
@@ -39,11 +42,33 @@
 logger: logging.Logger = logging.getLogger(__name__)


+class RawIdTrackerWrapper:
+    __slots__ = ("get_indexed_lookups", "delete")
+
+    def __init__(
+        self,
+        get_indexed_lookups: Callable[
+            [List[str], Optional[str]],
+            Dict[str, List[torch.Tensor]],
+        ],
+        delete: Callable[
+            [int],
+            None,
+        ],
+    ) -> None:
+        self.get_indexed_lookups = get_indexed_lookups
+        self.delete = delete
+
+
 class BaseEmbedding(abc.ABC, nn.Module):
     """
     Abstract base class for grouped `nn.Embedding` and `nn.EmbeddingBag`
     """

+    def __init__(self) -> None:
+        super().__init__()
+        self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None
+
     @abc.abstractmethod
     def forward(
         self,
@@ -62,6 +87,83 @@ def forward(
     def config(self) -> GroupedEmbeddingConfig:
         pass

+    def init_raw_id_tracker(
+        self,
+        get_indexed_lookups: Callable[
+            [List[str], Optional[str]],
+            Dict[str, List[torch.Tensor]],
+        ],
+        delete: Callable[
+            [int],
+            None,
+        ],
+    ) -> None:
+        """
+        Initialize the raw ID tracker for hash-based zero-collision handling.
+
+        Args:
+            get_indexed_lookups: Callable used to retrieve indexed lookups.
+            delete: Callable used to delete tracked data.
+        """
+        if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
+            self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
+                get_indexed_lookups, delete
+            )
+
+    def _get_hash_zch_identities(
+        self, features: KeyedJaggedTensor
+    ) -> Optional[torch.Tensor]:
+        if self._raw_id_tracker_wrapper is None or not isinstance(
+            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
+        ):
+            return None
+
+        emb_module = cast(SplitTableBatchedEmbeddingBagsCodegen, self.emb_module)
+        raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
+        assert (
+            raw_id_tracker_wrapper is not None
+        ), "self._raw_id_tracker_wrapper should not be None"
+        table_names = emb_module.table_names
+        if not table_names:
+            return None
+
+        # TODO: get_indexed_lookups() may return multiple IndexedLookup objects
+        # across multiple training iterations. The current logic appends raw_ids
+        # from all batches sequentially, which may cause misalignment with
+        # features.values(), which only contains the current batch.
+        raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
+            table_names, emb_module.uuid
+        )
+
+        # Build hash_zch_identities by concatenating raw IDs from tracked tables.
+        # The output maintains 1-to-1 alignment with features.values().
+        # Iterate through table_names explicitly (not raw_ids_dict.values()) to
+        # ensure correct ordering, since there is no guarantee on dict ordering.
+        #
+        # E.g. if features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...],
+        # where table1 has [feature1, feature2] and table2 has [feature3, feature4],
+        # then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...].
+        #
+        # TODO: Handle tables without identity tracking. Currently, only tables
+        # with raw_ids are included. If some tables lack identities while others
+        # have them, padding with -1 may be needed to maintain alignment.
+        all_raw_ids = []
+        for table_name in table_names:
+            if table_name in raw_ids_dict:
+                raw_ids_list = raw_ids_dict[table_name]
+                for raw_ids in raw_ids_list:
+                    all_raw_ids.append(raw_ids)
+
+        if not all_raw_ids:
+            return None
+
+        hash_zch_identities = torch.cat(all_raw_ids)
+        assert hash_zch_identities.size(0) == features.values().numel(), (
+            f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
+            f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
+        )
+
+        return hash_zch_identities
+

 def create_virtual_table_local_metadata(
     local_metadata: ShardMetadata,
diff --git a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
index b64792912..0db0e6b50 100644
--- a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
+++ b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
@@ -19,6 +19,7 @@
     KeyedJaggedTensor,
     ShardedEmbeddingTable,
 )
+from torchrec.distributed.mc_embedding import ShardedManagedCollisionEmbeddingCollection
 from torchrec.distributed.mc_embeddingbag import (
     ShardedManagedCollisionEmbeddingBagCollection,
 )
@@ -26,11 +27,16 @@
 from torchrec.distributed.model_tracker.delta_store import RawIdTrackerStore
 from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
-from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows
+from torchrec.distributed.model_tracker.types import UniqueRows

 logger: logging.Logger = logging.getLogger(__name__)

-SUPPORTED_MODULES = (ShardedManagedCollisionCollection,)
+SUPPORTED_TRACKING_MODULES = (ShardedManagedCollisionCollection,)
+
+MANAGED_COLLISION_WRAPPER_MODULES = (
+    ShardedManagedCollisionEmbeddingCollection,
+    ShardedManagedCollisionEmbeddingBagCollection,
+)


 class RawIdTracker(ModelDeltaTracker):
@@ -49,7 +55,7 @@ def __init__(
         self.curr_batch_idx: int = 0
         self.curr_compact_index: int = 0

-        # from module FQN to SUPPORTED_MODULES
+        # from module FQN to SUPPORTED_TRACKING_MODULES
         self.tracked_modules: Dict[str, nn.Module] = {}
         self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
@@ -124,7 +130,7 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
             if self._should_skip_fqn(fqn):
                 continue
             # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
-            if isinstance(named_module, SUPPORTED_MODULES):
+            if isinstance(named_module, SUPPORTED_TRACKING_MODULES):
                 should_track_module = True
                 for table_name, config in named_module._table_name_to_config.items():
                     for fqn_to_skip in self._fqns_to_skip:
@@ -227,7 +233,7 @@ def _clean_fqn_fn(self, fqn: str) -> str:
     def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
-            if isinstance(module, SUPPORTED_MODULES):
+            if isinstance(module, SUPPORTED_TRACKING_MODULES):
                 # register post lookup function
                 module.register_post_lookup_tracker_fn(self.record_lookup)

@@ -235,7 +241,7 @@ def _init_tbe_tracker_wrapper(self, module: nn.Module) -> None:
         for fqn, named_module in self._model.named_modules():
             if self._should_skip_fqn(fqn):
                 continue
-            if isinstance(named_module, ShardedManagedCollisionEmbeddingBagCollection):
+            if isinstance(named_module, MANAGED_COLLISION_WRAPPER_MODULES):
                 for lookup in named_module._embedding_module._lookups:
                     # pyre-ignore
                     for emb in lookup._emb_modules:
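
A note on the alignment contract enforced by the new `BaseEmbedding._get_hash_zch_identities` (embedding_kernel.py hunk above): raw IDs are concatenated per table, in `table_names` order, so that row i of `hash_zch_identities` corresponds to element i of `features.values()`. The sketch below restates that logic as a standalone example; the helper name `build_hash_zch_identities`, the table names, and the ID values are all made up for illustration and are not torchrec APIs.

```python
import torch
from typing import Dict, List, Optional

def build_hash_zch_identities(
    table_names: List[str],
    raw_ids_dict: Dict[str, List[torch.Tensor]],
    num_feature_values: int,
) -> Optional[torch.Tensor]:
    # Iterate table_names (not raw_ids_dict.values()) so the output follows
    # the table order that produced features.values().
    all_raw_ids: List[torch.Tensor] = [
        raw_ids
        for table_name in table_names
        for raw_ids in raw_ids_dict.get(table_name, [])
    ]
    if not all_raw_ids:
        return None
    identities = torch.cat(all_raw_ids)
    # 1-to-1 alignment: one identity per value in features.values().
    assert identities.size(0) == num_feature_values
    return identities

# table1 hosts feature1/feature2, table2 hosts feature3/feature4;
# features.values() would hold four values in that table order.
raw_ids_dict = {
    "table2": [torch.tensor([201, 202])],  # dict order deliberately scrambled
    "table1": [torch.tensor([101, 102])],
}
print(build_hash_zch_identities(["table1", "table2"], raw_ids_dict, 4))
# tensor([101, 102, 201, 202])
```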
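
On the refactor itself: because `init_raw_id_tracker` installs the wrapper only when the kernel's `_emb_module` is a `SplitTableBatchedEmbeddingBagsCodegen`, a tracker can call it unconditionally on every embedding kernel it discovers; for non-TBE kernels the call is a no-op and `_get_hash_zch_identities` keeps returning None. A toy sketch of that gate, using stub classes rather than the real torchrec/FBGEMM types:

```python
from typing import Any, Callable, Optional

class StubTBE:
    """Stand-in for SplitTableBatchedEmbeddingBagsCodegen."""

class StubKernel:
    """Stand-in for a BaseEmbedding subclass."""

    def __init__(self, emb_module: Any) -> None:
        self._emb_module = emb_module
        self._raw_id_tracker_wrapper: Optional[tuple] = None

    def init_raw_id_tracker(
        self,
        get_indexed_lookups: Callable[..., Any],
        delete: Callable[[int], None],
    ) -> None:
        # Only TBE-backed kernels participate; for anything else this is a
        # no-op, so the caller needs no per-kernel type checks.
        if isinstance(self._emb_module, StubTBE):
            self._raw_id_tracker_wrapper = (get_indexed_lookups, delete)

kernels = [StubKernel(StubTBE()), StubKernel(object())]
for kernel in kernels:
    kernel.init_raw_id_tracker(lambda names, uuid: {}, lambda idx: None)
print([k._raw_id_tracker_wrapper is not None for k in kernels])  # [True, False]
```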
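
Finally, the forward() hunk in batched_embedding_kernel.py takes two call paths rather than always forwarding `hash_zch_identities=None`; presumably this keeps kernels whose call signature lacks that keyword working unchanged, though the patch itself does not state the motivation. A generic sketch of the conditional-kwarg pattern, where `legacy_kernel` is a hypothetical stand-in for `self.emb_module`:

```python
import torch
from typing import Optional

def legacy_kernel(indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Hypothetical kernel whose signature has no hash_zch_identities kwarg.
    return indices.float()

def forward_like(
    indices: torch.Tensor,
    offsets: torch.Tensor,
    identities: Optional[torch.Tensor],
) -> torch.Tensor:
    # Omit the kwarg entirely when there is nothing to pass; a kernel like
    # legacy_kernel would raise TypeError if it ever received the kwarg.
    if identities is None:
        return legacy_kernel(indices=indices, offsets=offsets)
    return legacy_kernel(
        indices=indices, offsets=offsets, hash_zch_identities=identities
    )

print(forward_like(torch.tensor([1, 2]), torch.tensor([0, 2]), None))
```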