Skip to content

Commit a1b7918

Browse files
Joey Yang authored and facebook-github-bot committed
Extend raw_id_tracker to track ShardedManagedCollisionEmbeddingCollection (#3545)
Summary:
X-link: facebookresearch/FBGEMM#2131
X-link: pytorch/FBGEMM#5128

This diff extends `raw_id_tracker` to support both `ShardedManagedCollisionEmbeddingCollection` and `ShardedManagedCollisionEmbeddingBagCollection`. To avoid code duplication, the diff refactors the initialization and identity-parsing logic into the `BaseEmbedding` class.

Differential Revision: D87018676
1 parent 979f102 commit a1b7918

File tree

3 files changed

+123
-97
lines changed

3 files changed

+123
-97
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 8 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
from math import sqrt
1818
from typing import (
1919
Any,
20-
Callable,
2120
cast,
2221
Dict,
2322
Generic,
@@ -108,22 +107,6 @@
108107
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"
109108

110109

111-
class RawIdTrackerWrapper:
112-
def __init__(
113-
self,
114-
get_indexed_lookups: Callable[
115-
[List[str], Optional[str]],
116-
Dict[str, List[torch.Tensor]],
117-
],
118-
delete: Callable[
119-
[int],
120-
None,
121-
],
122-
) -> None:
123-
self.get_indexed_lookups = get_indexed_lookups
124-
self.delete = delete
125-
126-
127110
def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
128111
# populate res_params, which is used for raw embedding streaming
129112
# here only populates the params available in fused_params and TBE configs
@@ -1711,9 +1694,17 @@ def init_parameters(self) -> None:
17111694
)
17121695

17131696
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
1697+
hash_zch_identities = self._get_hash_zch_identities(features)
1698+
if hash_zch_identities is None:
1699+
return self.emb_module(
1700+
indices=features.values().long(),
1701+
offsets=features.offsets().long(),
1702+
)
1703+
17141704
return self.emb_module(
17151705
indices=features.values().long(),
17161706
offsets=features.offsets().long(),
1707+
hash_zch_identities=hash_zch_identities,
17171708
)
17181709

17191710
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -2551,7 +2542,6 @@ def __init__(
25512542
self._lengths_per_emb: List[int] = []
25522543
self.table_name_to_count: Dict[str, int] = {}
25532544
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
2554-
self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None
25552545

25562546
for idx, table_config in enumerate(self._config.embedding_tables):
25572547
self._local_rows.append(table_config.local_rows)
@@ -2605,62 +2595,6 @@ def init_parameters(self) -> None:
26052595
weight_init_max,
26062596
)
26072597

2608-
def _get_hash_zch_identities(
2609-
self, features: KeyedJaggedTensor
2610-
) -> Optional[torch.Tensor]:
2611-
if self._raw_id_tracker_wrapper is None or not isinstance(
2612-
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
2613-
):
2614-
return None
2615-
2616-
raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
2617-
assert (
2618-
raw_id_tracker_wrapper is not None
2619-
), "self._raw_id_tracker_wrapper should not be None"
2620-
assert hasattr(
2621-
self.emb_module, "res_params"
2622-
), "res_params should exist when raw_id_tracker is enabled"
2623-
res_params: RESParams = self.emb_module.res_params # pyre-ignore[9]
2624-
table_names = res_params.table_names
2625-
2626-
# TODO: get_indexed_lookups() may return multiple IndexedLookup objects
2627-
# across multiple training iterations. Current logic appends raw_ids from
2628-
# all batches sequentially. This may cause misalignment with
2629-
# features.values() which only contains the current batch.
2630-
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
2631-
table_names, self.emb_module.uuid
2632-
)
2633-
2634-
# Build hash_zch_identities by concatenating raw IDs from tracked tables.
2635-
# Output maintains 1-to-1 alignment with features.values().
2636-
# Iterate through table_names explicitly (not raw_ids_dict.values()) to
2637-
# ensure correct ordering, since there is no guarantee on dict ordering.
2638-
#
2639-
# E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
2640-
# where table1 has [feature1, feature2] and table2 has [feature3, feature4]
2641-
# then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
2642-
#
2643-
# TODO: Handle tables without identity tracking. Currently, only tables with
2644-
# raw_ids are included. If some tables lack identity while others have them,
2645-
# padding with -1 may be needed to maintain alignment.
2646-
all_raw_ids = []
2647-
for table_name in table_names:
2648-
if table_name in raw_ids_dict:
2649-
raw_ids_list = raw_ids_dict[table_name]
2650-
for raw_ids in raw_ids_list:
2651-
all_raw_ids.append(raw_ids)
2652-
2653-
if not all_raw_ids:
2654-
return None
2655-
2656-
hash_zch_identities = torch.cat(all_raw_ids)
2657-
assert hash_zch_identities.size(0) == features.values().numel(), (
2658-
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
2659-
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
2660-
)
2661-
2662-
return hash_zch_identities
2663-
26642598
def forward(
26652599
self,
26662600
features: KeyedJaggedTensor,
@@ -2763,22 +2697,6 @@ def named_parameters_by_table(
27632697
for name, param in self._param_per_table.items():
27642698
yield name, param
27652699

2766-
def init_raw_id_tracker(
2767-
self,
2768-
get_indexed_lookups: Callable[
2769-
[List[str], Optional[str]],
2770-
Dict[str, List[torch.Tensor]],
2771-
],
2772-
delete: Callable[
2773-
[int],
2774-
None,
2775-
],
2776-
) -> None:
2777-
if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
2778-
self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
2779-
get_indexed_lookups, delete
2780-
)
2781-
27822700

27832701
class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
27842702
def __init__(

torchrec/distributed/embedding_kernel.py

Lines changed: 103 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,13 @@
1111
import copy
1212
import logging
1313
from collections import defaultdict, OrderedDict
14-
from typing import Any, Dict, List, Optional, Tuple, Union
14+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
1515

1616
import torch
1717
import torch.distributed as dist
18+
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
19+
SplitTableBatchedEmbeddingBagsCodegen,
20+
)
1821
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
1922
PartiallyMaterializedTensor,
2023
)
@@ -39,11 +42,33 @@
3942
logger: logging.Logger = logging.getLogger(__name__)
4043

4144

45+
class RawIdTrackerWrapper:
46+
__slots__ = ("get_indexed_lookups", "delete")
47+
48+
def __init__(
49+
self,
50+
get_indexed_lookups: Callable[
51+
[List[str], Optional[str]],
52+
Dict[str, List[torch.Tensor]],
53+
],
54+
delete: Callable[
55+
[int],
56+
None,
57+
],
58+
) -> None:
59+
self.get_indexed_lookups = get_indexed_lookups
60+
self.delete = delete
61+
62+
4263
class BaseEmbedding(abc.ABC, nn.Module):
4364
"""
4465
Abstract base class for grouped `nn.Embedding` and `nn.EmbeddingBag`
4566
"""
4667

68+
def __init__(self) -> None:
69+
super().__init__()
70+
self._raw_id_tracker_wrapper: Optional[RawIdTrackerWrapper] = None
71+
4772
@abc.abstractmethod
4873
def forward(
4974
self,
@@ -62,6 +87,83 @@ def forward(
6287
def config(self) -> GroupedEmbeddingConfig:
6388
pass
6489

90+
def init_raw_id_tracker(
91+
self,
92+
get_indexed_lookups: Callable[
93+
[List[str], Optional[str]],
94+
Dict[str, List[torch.Tensor]],
95+
],
96+
delete: Callable[
97+
[int],
98+
None,
99+
],
100+
) -> None:
101+
"""
102+
Initialize raw ID tracker for hash-based zero collision handling.
103+
104+
Args:
105+
get_indexed_lookups: Callable to retrieve indexed lookups
106+
delete: Callable to delete tracked data
107+
"""
108+
if isinstance(self._emb_module, SplitTableBatchedEmbeddingBagsCodegen):
109+
self._raw_id_tracker_wrapper = RawIdTrackerWrapper(
110+
get_indexed_lookups, delete
111+
)
112+
113+
def _get_hash_zch_identities(
114+
self, features: KeyedJaggedTensor
115+
) -> Optional[torch.Tensor]:
116+
if self._raw_id_tracker_wrapper is None or not isinstance(
117+
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
118+
):
119+
return None
120+
121+
emb_module = cast(SplitTableBatchedEmbeddingBagsCodegen, self.emb_module)
122+
raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
123+
assert (
124+
raw_id_tracker_wrapper is not None
125+
), "self._raw_id_tracker_wrapper should not be None"
126+
table_names = emb_module.table_names
127+
if not table_names:
128+
return None
129+
130+
# TODO: get_indexed_lookups() may return multiple IndexedLookup objects
131+
# across multiple training iterations. Current logic appends raw_ids from
132+
# all batches sequentially. This may cause misalignment with
133+
# features.values() which only contains the current batch.
134+
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
135+
table_names, emb_module.uuid
136+
)
137+
138+
# Build hash_zch_identities by concatenating raw IDs from tracked tables.
139+
# Output maintains 1-to-1 alignment with features.values().
140+
# Iterate through table_names explicitly (not raw_ids_dict.values()) to
141+
# ensure correct ordering, since there is no guarantee on dict ordering.
142+
#
143+
# E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
144+
# where table1 has [feature1, feature2] and table2 has [feature3, feature4]
145+
# then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
146+
#
147+
# TODO: Handle tables without identity tracking. Currently, only tables with
148+
# raw_ids are included. If some tables lack identity while others have them,
149+
# padding with -1 may be needed to maintain alignment.
150+
all_raw_ids = []
151+
for table_name in table_names:
152+
if table_name in raw_ids_dict:
153+
raw_ids_list = raw_ids_dict[table_name]
154+
for raw_ids in raw_ids_list:
155+
all_raw_ids.append(raw_ids)
156+
157+
if not all_raw_ids:
158+
return None
159+
160+
hash_zch_identities = torch.cat(all_raw_ids)
161+
assert hash_zch_identities.size(0) == features.values().numel(), (
162+
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
163+
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
164+
)
165+
return hash_zch_identities
166+
65167

66168
def create_virtual_table_local_metadata(
67169
local_metadata: ShardMetadata,

torchrec/distributed/model_tracker/trackers/raw_id_tracker.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,18 +19,24 @@
1919
KeyedJaggedTensor,
2020
ShardedEmbeddingTable,
2121
)
22+
from torchrec.distributed.mc_embedding import ShardedManagedCollisionEmbeddingCollection
2223
from torchrec.distributed.mc_embeddingbag import (
2324
ShardedManagedCollisionEmbeddingBagCollection,
2425
)
2526
from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection
2627
from torchrec.distributed.model_tracker.delta_store import RawIdTrackerStore
2728

2829
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
29-
from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows
30+
from torchrec.distributed.model_tracker.types import UniqueRows
3031

3132
logger: logging.Logger = logging.getLogger(__name__)
3233

33-
SUPPORTED_MODULES = (ShardedManagedCollisionCollection,)
34+
SUPPORTED_TRACKING_MODULES = (ShardedManagedCollisionCollection,)
35+
36+
MANAGED_COLLISION_WRAPPER_MODULES = (
37+
ShardedManagedCollisionEmbeddingCollection,
38+
ShardedManagedCollisionEmbeddingBagCollection,
39+
)
3440

3541

3642
class RawIdTracker(ModelDeltaTracker):
@@ -49,7 +55,7 @@ def __init__(
4955
self.curr_batch_idx: int = 0
5056
self.curr_compact_index: int = 0
5157

52-
# from module FQN to SUPPORTED_MODULES
58+
# from module FQN to SUPPORTED_TRACKING_MODULES
5359
self.tracked_modules: Dict[str, nn.Module] = {}
5460
self.table_to_fqn: Dict[str, str] = {}
5561
self.feature_to_fqn: Dict[str, str] = {}
@@ -124,7 +130,7 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
124130
if self._should_skip_fqn(fqn):
125131
continue
126132
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
127-
if isinstance(named_module, SUPPORTED_MODULES):
133+
if isinstance(named_module, SUPPORTED_TRACKING_MODULES):
128134
should_track_module = True
129135
for table_name, config in named_module._table_name_to_config.items():
130136
for fqn_to_skip in self._fqns_to_skip:
@@ -227,15 +233,15 @@ def _clean_fqn_fn(self, fqn: str) -> str:
227233
def _validate_and_init_tracker_fns(self) -> None:
228234
"To validate the mode is supported for the given module"
229235
for module in self.tracked_modules.values():
230-
if isinstance(module, SUPPORTED_MODULES):
236+
if isinstance(module, SUPPORTED_TRACKING_MODULES):
231237
# register post lookup function
232238
module.register_post_lookup_tracker_fn(self.record_lookup)
233239

234240
def _init_tbe_tracker_wrapper(self, module: nn.Module) -> None:
235241
for fqn, named_module in self._model.named_modules():
236242
if self._should_skip_fqn(fqn):
237243
continue
238-
if isinstance(named_module, ShardedManagedCollisionEmbeddingBagCollection):
244+
if isinstance(named_module, MANAGED_COLLISION_WRAPPER_MODULES):
239245
for lookup in named_module._embedding_module._lookups:
240246
# pyre-ignore
241247
for emb in lookup._emb_modules:

0 commit comments

Comments
 (0)