98 changes: 8 additions & 90 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -17,7 +17,6 @@
from math import sqrt
from typing import (
Any,
Callable,
cast,
Dict,
FrozenSet,
@@ -109,22 +108,6 @@
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"


class RawIdTrackerWrapper:
def __init__(
self,
get_indexed_lookups: Callable[
[List[str], Optional[str]],
Dict[str, List[torch.Tensor]],
],
delete: Callable[
[int],
None,
],
) -> None:
self.get_indexed_lookups = get_indexed_lookups
self.delete = delete


def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
# populate res_params, which is used for raw embedding streaming
# here only populates the params available in fused_params and TBE configs
@@ -1723,9 +1706,17 @@ def init_parameters(self) -> None:
)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
hash_zch_identities = self._get_hash_zch_identities(features)
if hash_zch_identities is None:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
)

return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
hash_zch_identities=hash_zch_identities,
)

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -2563,7 +2554,6 @@ def __init__(
self._lengths_per_emb: List[int] = []
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None

for idx, table_config in enumerate(self._config.embedding_tables):
self._local_rows.append(table_config.local_rows)
@@ -2617,62 +2607,6 @@ def init_parameters(self) -> None:
weight_init_max,
)

def _get_hash_zch_identities(
self, features: KeyedJaggedTensor
) -> Optional[torch.Tensor]:
if self._raw_id_tracker_wrapper is None or not isinstance(
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
):
return None

raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
assert (
raw_id_tracker_wrapper is not None
), "self._raw_id_tracker_wrapper should not be None"
assert hasattr(
self.emb_module, "res_params"
), "res_params should exist when raw_id_tracker is enabled"
res_params: RESParams = self.emb_module.res_params # pyre-ignore[9]
table_names = res_params.table_names

# TODO: get_indexed_lookups() may return multiple IndexedLookup objects
# across multiple training iterations. Current logic appends raw_ids from
# all batches sequentially. This may cause misalignment with
# features.values() which only contains the current batch.
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
table_names, self.emb_module.uuid
)

# Build hash_zch_identities by concatenating raw IDs from tracked tables.
# Output maintains 1-to-1 alignment with features.values().
# Iterate through table_names explicitly (not raw_ids_dict.values()) to
# ensure correct ordering, since there is no guarantee on dict ordering.
#
# E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
# where table1 has [feature1, feature2] and table2 has [feature3, feature4]
# then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
#
# TODO: Handle tables without identity tracking. Currently, only tables with
# raw_ids are included. If some tables lack identity while others have them,
# padding with -1 may be needed to maintain alignment.
all_raw_ids = []
for table_name in table_names:
if table_name in raw_ids_dict:
raw_ids_list = raw_ids_dict[table_name]
for raw_ids in raw_ids_list:
all_raw_ids.append(raw_ids)

if not all_raw_ids:
return None

hash_zch_identities = torch.cat(all_raw_ids)
assert hash_zch_identities.size(0) == features.values().numel(), (
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
)

return hash_zch_identities

def forward(
self,
features: KeyedJaggedTensor,
@@ -2775,22 +2709,6 @@ def named_parameters_by_table(
for name, param in self._param_per_table.items():
yield name, param

def init_raw_id_tracker(
self,
get_indexed_lookups: Callable[
[List[str], Optional[str]],
Dict[str, List[torch.Tensor]],
],
delete: Callable[
[int],
None,
],
) -> None:
if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
get_indexed_lookups, delete
)


class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
def __init__(
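For context, a minimal sketch of the callable contract that `RawIdTrackerWrapper` (now defined in embedding_kernel.py, see the next file) wraps and that `init_raw_id_tracker` expects. The `_ToyRawIdStore` class, its method names, and the sample values are assumptions for illustration only, not part of this PR:

from typing import Dict, List, Optional, Tuple

import torch


class _ToyRawIdStore:
    """Hypothetical in-memory store illustrating the two tracker callables."""

    def __init__(self) -> None:
        # table_name -> list of (batch_idx, raw_ids) entries
        self._per_table: Dict[str, List[Tuple[int, torch.Tensor]]] = {}

    def record(self, table_name: str, batch_idx: int, raw_ids: torch.Tensor) -> None:
        self._per_table.setdefault(table_name, []).append((batch_idx, raw_ids))

    def get_indexed_lookups(
        self, table_names: List[str], emb_uuid: Optional[str] = None
    ) -> Dict[str, List[torch.Tensor]]:
        # Matches Callable[[List[str], Optional[str]], Dict[str, List[torch.Tensor]]];
        # emb_uuid is ignored in this toy version.
        return {
            name: [ids for _, ids in self._per_table[name]]
            for name in table_names
            if name in self._per_table
        }

    def delete(self, up_to_batch_idx: int) -> None:
        # Matches Callable[[int], None]: drop entries from batches before the cutoff.
        for name, entries in self._per_table.items():
            self._per_table[name] = [(b, ids) for b, ids in entries if b >= up_to_batch_idx]


store = _ToyRawIdStore()
store.record("table_0", 0, torch.tensor([101, 102]))
# An embedding kernel would then be handed the bound methods, e.g.
# emb.init_raw_id_tracker(store.get_indexed_lookups, store.delete)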
104 changes: 103 additions & 1 deletion torchrec/distributed/embedding_kernel.py
@@ -11,10 +11,13 @@
import copy
import logging
from collections import defaultdict, OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
PartiallyMaterializedTensor,
)
@@ -39,11 +42,33 @@
logger: logging.Logger = logging.getLogger(__name__)


class RawIdTrackerWrapper:
__slots__ = ("get_indexed_lookups", "delete")

def __init__(
self,
get_indexed_lookups: Callable[
[List[str], Optional[str]],
Dict[str, List[torch.Tensor]],
],
delete: Callable[
[int],
None,
],
) -> None:
self.get_indexed_lookups = get_indexed_lookups
self.delete = delete


class BaseEmbedding(abc.ABC, nn.Module):
"""
Abstract base class for grouped `nn.Embedding` and `nn.EmbeddingBag`
"""

def __init__(self) -> None:
super().__init__()
self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None

@abc.abstractmethod
def forward(
self,
@@ -62,6 +87,83 @@ def forward(
def config(self) -> GroupedEmbeddingConfig:
pass

def init_raw_id_tracker(
self,
get_indexed_lookups: Callable[
[List[str], Optional[str]],
Dict[str, List[torch.Tensor]],
],
delete: Callable[
[int],
None,
],
) -> None:
"""
Initialize the raw ID tracker used for hash-based zero-collision (ZCH) identity handling.

Args:
get_indexed_lookups: Callable to retrieve indexed lookups
delete: Callable to delete tracked data
"""
if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
get_indexed_lookups, delete
)

def _get_hash_zch_identities(
self, features: KeyedJaggedTensor
) -> Optional[torch.Tensor]:
if self._raw_id_tracker_wrapper is None or not isinstance(
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
):
return None

emb_module = cast(SplitTableBatchedEmbeddingBagsCodegen, self.emb_module)
raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
assert (
raw_id_tracker_wrapper is not None
), "self._raw_id_tracker_wrapper should not be None"
table_names = emb_module.table_names
if not table_names:
return None

# TODO: get_indexed_lookups() may return multiple IndexedLookup objects
# across multiple training iterations. Current logic appends raw_ids from
# all batches sequentially. This may cause misalignment with
# features.values() which only contains the current batch.
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
table_names, emb_module.uuid
)

# Build hash_zch_identities by concatenating raw IDs from tracked tables.
# Output maintains 1-to-1 alignment with features.values().
# Iterate through table_names explicitly (not raw_ids_dict.values()) to
# ensure correct ordering, since there is no guarantee on dict ordering.
#
# E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
# where table1 has [feature1, feature2] and table2 has [feature3, feature4]
# then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
#
# TODO: Handle tables without identity tracking. Currently, only tables with
# raw_ids are included. If some tables lack identity while others have them,
# padding with -1 may be needed to maintain alignment.
all_raw_ids = []
for table_name in table_names:
if table_name in raw_ids_dict:
raw_ids_list = raw_ids_dict[table_name]
for raw_ids in raw_ids_list:
all_raw_ids.append(raw_ids)

if not all_raw_ids:
return None

hash_zch_identities = torch.cat(all_raw_ids)
assert hash_zch_identities.size(0) == features.values().numel(), (
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
)
return hash_zch_identities


def create_virtual_table_local_metadata(
local_metadata: ShardMetadata,
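For illustration, a self-contained sketch of the alignment rule that `_get_hash_zch_identities` enforces when concatenating tracked raw IDs: rows are concatenated in `table_names` order so that row i of the result corresponds to element i of `features.values()`. The table names, identity values, and feature values below are made up for the example:

import torch

# Assumed layout: table_1 owns feature values [11, 12], table_2 owns [21, 22].
table_names = ["table_1", "table_2"]
raw_ids_dict = {
    "table_1": [torch.tensor([1001, 1002])],  # raw identities for table_1's batch
    "table_2": [torch.tensor([2001, 2002])],  # raw identities for table_2's batch
}
feature_values = torch.tensor([11, 12, 21, 22])  # stand-in for features.values()

all_raw_ids = []
for table_name in table_names:  # iterate table_names, not dict order
    for raw_ids in raw_ids_dict.get(table_name, []):
        all_raw_ids.append(raw_ids)

hash_zch_identities = torch.cat(all_raw_ids)
assert hash_zch_identities.size(0) == feature_values.numel()
# hash_zch_identities == tensor([1001, 1002, 2001, 2002]),
# aligned 1-to-1 with feature_values.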
18 changes: 12 additions & 6 deletions torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
@@ -19,18 +19,24 @@
KeyedJaggedTensor,
ShardedEmbeddingTable,
)
from torchrec.distributed.mc_embedding import ShardedManagedCollisionEmbeddingCollection
from torchrec.distributed.mc_embeddingbag import (
ShardedManagedCollisionEmbeddingBagCollection,
)
from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection
from torchrec.distributed.model_tracker.delta_store import RawIdTrackerStore

from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows
from torchrec.distributed.model_tracker.types import UniqueRows

logger: logging.Logger = logging.getLogger(__name__)

SUPPORTED_MODULES = (ShardedManagedCollisionCollection,)
SUPPORTED_TRACKING_MODULES = (ShardedManagedCollisionCollection,)

MANAGED_COLLISION_WRAPPER_MODULES = (
ShardedManagedCollisionEmbeddingCollection,
ShardedManagedCollisionEmbeddingBagCollection,
)


class RawIdTracker(ModelDeltaTracker):
Expand All @@ -49,7 +55,7 @@ def __init__(
self.curr_batch_idx: int = 0
self.curr_compact_index: int = 0

# from module FQN to SUPPORTED_MODULES
# from module FQN to SUPPORTED_TRACKING_MODULES
self.tracked_modules: Dict[str, nn.Module] = {}
self.table_to_fqn: Dict[str, str] = {}
self.feature_to_fqn: Dict[str, str] = {}
@@ -124,7 +130,7 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
if self._should_skip_fqn(fqn):
continue
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
if isinstance(named_module, SUPPORTED_MODULES):
if isinstance(named_module, SUPPORTED_TRACKING_MODULES):
should_track_module = True
for table_name, config in named_module._table_name_to_config.items():
for fqn_to_skip in self._fqns_to_skip:
@@ -227,15 +233,15 @@ def _clean_fqn_fn(self, fqn: str) -> str:
def _validate_and_init_tracker_fns(self) -> None:
"To validate the mode is supported for the given module"
for module in self.tracked_modules.values():
if isinstance(module, SUPPORTED_MODULES):
if isinstance(module, SUPPORTED_TRACKING_MODULES):
# register post lookup function
module.register_post_lookup_tracker_fn(self.record_lookup)

def _init_tbe_tracker_wrapper(self, module: nn.Module) -> None:
for fqn, named_module in self._model.named_modules():
if self._should_skip_fqn(fqn):
continue
if isinstance(named_module, ShardedManagedCollisionEmbeddingBagCollection):
if isinstance(named_module, MANAGED_COLLISION_WRAPPER_MODULES):
for lookup in named_module._embedding_module._lookups:
# pyre-ignore
for emb in lookup._emb_modules:
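For illustration, a sketch of the wiring pattern that `_init_tbe_tracker_wrapper` follows: walk the sharded model, find the managed-collision wrapper modules, and hand each underlying embedding kernel the tracker's callables. The helper name `_wire_raw_id_tracker`, its signature, and the exact call on each kernel are assumptions made for the sketch:

from typing import Callable, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn


def _wire_raw_id_tracker(
    model: nn.Module,
    get_indexed_lookups: Callable[[List[str], Optional[str]], Dict[str, List[torch.Tensor]]],
    delete: Callable[[int], None],
    wrapper_types: Tuple[Type[nn.Module], ...],
) -> None:
    # Mirrors the traversal shown in the hunk above: wrapper module ->
    # _embedding_module._lookups -> _emb_modules. The call on each kernel is a
    # no-op unless it wraps SplitTableBatchedEmbeddingBagsCodegen (see
    # BaseEmbedding.init_raw_id_tracker in the previous file).
    for _fqn, named_module in model.named_modules():
        if not isinstance(named_module, wrapper_types):
            continue
        for lookup in named_module._embedding_module._lookups:
            for emb in lookup._emb_modules:
                emb.init_raw_id_tracker(get_indexed_lookups, delete)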