
Commit 373d798

EddyLXJ authored and meta-codesync[bot] committed
Do not pass opt type in hstu umia st publish (#5125)
Summary:
Pull Request resolved: #5125
X-link: meta-pytorch/torchrec#3544
X-link: https://github.com/facebookresearch/FBGEMM/pull/2127

Previously, if a kvzch table enabled the PARTIAL_ROWWISE_ADAM opt type, that opt type was passed to all sharders as a fused param, so every sharder initialized its optimizer with PARTIAL_ROWWISE_ADAM, which caused an OOM issue. This diff passes the PARTIAL_ROWWISE_ADAM opt type only to the KVZCH TBE, avoiding the OOM issue.

Reviewed By: steven1327

Differential Revision: D86787539

fbshipit-source-id: eeeed6a449e8ea130fc684b58e6f15e5f7418e3a
1 parent f3d282b commit 373d798
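To make the summary concrete, here is a purely conceptual sketch of the behavior it describes. It uses plain dicts only, no torchrec or FBGEMM APIs, and every name in it is made up for illustration:

# Before: the opt type rode in the fused params shared by every sharder,
# so each one initialized PARTIAL_ROWWISE_ADAM state (extra memory -> OOM).
tables = {"kvzch_table": {"is_kv_zch": True}, "dense_table": {"is_kv_zch": False}}
before = {name: {"optimizer": "PARTIAL_ROWWISE_ADAM"} for name in tables}

# After: the opt type is attached only to the KVZCH TBE (via optimizer_type_for_st);
# the other sharders keep their default optimizer and its smaller state.
after = {
    name: ({"optimizer_type_for_st": "PARTIAL_ROWWISE_ADAM"} if cfg["is_kv_zch"] else {})
    for name, cfg in tables.items()
}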

File tree

2 files changed: +31 / -1 lines


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 7 additions & 1 deletion
@@ -11,7 +11,7 @@
 
 import enum
 from dataclasses import dataclass
-from typing import NamedTuple, Optional
+from typing import FrozenSet, NamedTuple, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -249,6 +249,8 @@ class KVZCHParams(NamedTuple):
     eviction_policy: EvictionPolicy = EvictionPolicy()
     embedding_cache_mode: bool = False
     load_ckpt_without_opt: bool = False
+    optimizer_type_for_st: Optional[str] = None
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
 
     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
@@ -274,6 +276,10 @@ class KVZCHTBEConfig(NamedTuple):
     threshold_calculation_bucket_num: Optional[int] = 1000000  # 1M
     # When true, we only save weight to kvzch backend and not optimizer state.
     load_ckpt_without_opt: bool = False
+    # [DO NOT USE] This is for st publish only, do not set it in your config
+    optimizer_type_for_st: Optional[str] = None
+    # [DO NOT USE] This is for st publish only, do not set it in your config
+    optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
 
 
 class BackendType(enum.IntEnum):
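For orientation, a minimal, hypothetical sketch of how these two new fields might be filled in on the st-publish path. The import paths, the bucket values, and the "momentum1"/"momentum2" state names are assumptions for illustration, not part of this commit:

from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path
from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams
from fbgemm_gpu.split_table_batched_embeddings_ops_training import OptimType  # assumed import path

# Placeholder bucket layout; only the two *_for_st fields matter here.
kv_zch_params = KVZCHParams(
    bucket_offsets=[(0, 1)],
    bucket_sizes=[1024],
    # The optimizer type travels as the enum's plain value...
    optimizer_type_for_st=OptimType.PARTIAL_ROWWISE_ADAM.value,
    # ...and the per-state dtypes as hashable (name, int) pairs; the TBE
    # converts them back with SparseType.from_int (see the training.py change below).
    optimizer_state_dtypes_for_st=frozenset(
        {
            ("momentum1", SparseType.FP32.as_int()),
            ("momentum2", SparseType.FP32.as_int()),
        }
    ),
)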

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 24 additions & 0 deletions
@@ -249,6 +249,30 @@ def __init__(
             assert self.optimizer in [
                 OptimType.EXACT_ROWWISE_ADAGRAD
             ], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
+        if self.is_st_publish:
+            if (
+                # pyre-ignore [16]
+                self.kv_zch_params.optimizer_type_for_st
+                == OptimType.PARTIAL_ROWWISE_ADAM.value
+            ):
+                self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
+                logging.info(
+                    f"Override optimizer type with {self.optimizer=} for st publish"
+                )
+            if (
+                # pyre-ignore [16]
+                self.kv_zch_params.optimizer_state_dtypes_for_st
+                is not None
+            ):
+                optimizer_state_dtypes = {}
+                for k, v in dict(
+                    self.kv_zch_params.optimizer_state_dtypes_for_st
+                ).items():
+                    optimizer_state_dtypes[k] = SparseType.from_int(v)
+                self.optimizer_state_dtypes = optimizer_state_dtypes
+                logging.info(
+                    f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
+                )
 
         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
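A note on the encoding this hunk decodes: optimizer_state_dtypes_for_st is a frozenset of (state_name, dtype_int) pairs rather than a Dict[str, SparseType], presumably so the NamedTuple config stays hashable. A small round-trip sketch, assuming the SparseType import path below and that as_int() is the inverse of the from_int() used above; the state names are placeholders:

from typing import Dict, FrozenSet, Tuple

from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path


def encode_state_dtypes(dtypes: Dict[str, SparseType]) -> FrozenSet[Tuple[str, int]]:
    # Flatten to hashable (name, int) pairs so they can ride inside the NamedTuple config.
    return frozenset((name, dtype.as_int()) for name, dtype in dtypes.items())


def decode_state_dtypes(encoded: FrozenSet[Tuple[str, int]]) -> Dict[str, SparseType]:
    # Mirror of the loop added in training.py: dict(...) then SparseType.from_int per value.
    return {name: SparseType.from_int(v) for name, v in dict(encoded).items()}


# Placeholder state names for PARTIAL_ROWWISE_ADAM-style momenta.
encoded = encode_state_dtypes({"momentum1": SparseType.FP32, "momentum2": SparseType.FP16})
assert decode_state_dtypes(encoded) == {"momentum1": SparseType.FP32, "momentum2": SparseType.FP16}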

0 commit comments
