
Commit f3d282b

EddyLXJ authored and meta-codesync[bot] committed
st publish mode only load weight (#5116)
Summary: Pull Request resolved: #5116 X-link: meta-pytorch/torchrec#3538 X-link: https://github.com/facebookresearch/FBGEMM/pull/2122

For silvertorch publish, we don't want to load the optimizer state into the backend because the publish host has limited CPU memory. So we load the whole row into the state dict while loading the checkpoint during st publish, then save only the weight into the backend; after that the backend holds only metaheader + weight.

For the first load, the dim must be set to metaheader_dim + emb_dim + optimizer_state_dim, otherwise checkpoint loading throws a size-mismatch error. After the first load, we only need metaheader + weight from the backend for the state dict, so the dim can be set to metaheader_dim + emb_dim.

Reviewed By: emlin

Differential Revision: D85830053

fbshipit-source-id: 0eddbe9e69ea8271e8c77dc0147e87a08f0b3934
1 parent 94088ab commit f3d282b
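
To make the two row widths described in the summary concrete, here is a minimal Python sketch for the backend_return_whole_row case. The concrete numbers (metaheader_dim=8, emb_dim=130, optimizer_state_dim=2) are illustrative assumptions; only the pad4 rounding and the two formulas mirror the code in this commit.

    def pad4(n: int) -> int:
        # Round up to the nearest multiple of 4, as the TBE code does for row widths.
        return (n + 3) // 4 * 4

    metaheader_dim = 8        # already padded
    emb_dim = 130             # weight columns for one table (illustrative)
    optimizer_state_dim = 2   # per-row optimizer state elements (illustrative)

    # First checkpoint load in st publish mode: the state-dict row still carries the
    # optimizer state, so the KVT width must include it or loading hits a size mismatch.
    first_load_width = metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)

    # After the first load, the backend only holds metaheader + weight, so subsequent
    # state_dict reads use the narrower width.
    steady_state_width = metaheader_dim + pad4(emb_dim)

    print(first_load_width, steady_state_width)  # 144 140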

File tree

5 files changed: +72 -15 lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 3 additions & 0 deletions
@@ -248,6 +248,7 @@ class KVZCHParams(NamedTuple):
     backend_return_whole_row: bool = False
     eviction_policy: EvictionPolicy = EvictionPolicy()
     embedding_cache_mode: bool = False
+    load_ckpt_without_opt: bool = False

     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
@@ -271,6 +272,8 @@ class KVZCHTBEConfig(NamedTuple):
     threshold_calculation_bucket_stride: float = 0.2
     # Total number of feature score buckets used for threshold calculation in feature score-based eviction.
     threshold_calculation_bucket_num: Optional[int] = 1000000  # 1M
+    # When true, we only save weight to kvzch backend and not optimizer state.
+    load_ckpt_without_opt: bool = False


 class BackendType(enum.IntEnum):
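
A hedged usage sketch of the new flag: the field name load_ckpt_without_opt comes from the diff above, but the bucket_offsets/bucket_sizes values are placeholders, and this assumes the remaining KVZCHParams fields have defaults; it is not taken from this commit.

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

    # Placeholder bucket layout; load_ckpt_without_opt is the field added by this commit.
    params = KVZCHParams(
        bucket_offsets=[(0, 1)],
        bucket_sizes=[1024],
        load_ckpt_without_opt=True,  # st publish: backend keeps only metaheader + weight
    )
    assert params.load_ckpt_without_opt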

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 49 additions & 9 deletions
@@ -217,8 +217,13 @@ def __init__(
         self.enable_optimizer_offloading: bool = False
         self.backend_return_whole_row: bool = False
         self._embedding_cache_mode: bool = False
+        self.load_ckpt_without_opt: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
+            self.load_ckpt_without_opt = (
+                # pyre-ignore [16]
+                self.kv_zch_params.load_ckpt_without_opt
+            )
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
@@ -1105,7 +1110,9 @@ def cache_row_dim(self) -> int:
         padding to the nearest 4 elements and the optimizer state appended to
         the back of the row
         """
-        if self.enable_optimizer_offloading:
+
+        # For st publish, we only need to load the weight for publishing and bulk eval
+        if self.enable_optimizer_offloading and not self.load_ckpt_without_opt:
             return self.max_D + pad4(
                 # Compute the number of elements of cache_dtype needed to store
                 # the optimizer state
@@ -3092,6 +3099,38 @@ def _may_create_snapshot_for_state_dict(
             self.flush(force=should_flush)
         return snapshot_handle, checkpoint_handle

+    def get_embedding_dim_for_kvt(
+        self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
+    ) -> int:
+        if self.load_ckpt_without_opt:
+            # For silvertorch publish, we don't want to load the optimizer state into the backend due to limited CPU memory on the publish host.
+            # So we load the whole row into the state dict while loading the checkpoint in st publish, then save only the weight into the backend;
+            # after that the backend holds only metaheader + weight.
+            # For the first load, the dim must be metaheader_dim + emb_dim + optimizer_state_dim, otherwise checkpoint loading throws a size-mismatch error.
+            # After the first load, we only need metaheader + weight from the backend for the state dict, so the dim can be metaheader_dim + emb_dim.
+            if is_loading_checkpoint:
+                return (
+                    (
+                        metaheader_dim  # metaheader is already padded
+                        + pad4(emb_dim)
+                        + pad4(self.optimizer_state_dim)
+                    )
+                    if self.backend_return_whole_row
+                    else emb_dim
+                )
+            else:
+                return metaheader_dim + pad4(emb_dim)
+        else:
+            return (
+                (
+                    metaheader_dim  # metaheader is already padded
+                    + pad4(emb_dim)
+                    + pad4(self.optimizer_state_dim)
+                )
+                if self.backend_return_whole_row
+                else emb_dim
+            )
+
     @torch.jit.export
     def split_embedding_weights(
         self,
@@ -3149,6 +3188,7 @@ def split_embedding_weights(

         table_offset = 0
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            is_loading_checkpoint = False
             bucket_ascending_id_tensor = None
             bucket_t = None
             metadata_tensor = None
@@ -3214,6 +3254,7 @@ def split_embedding_weights(
                         dtype=torch.int64,
                     )
                     skip_metadata = True
+                    is_loading_checkpoint = True

                 # self.local_weight_counts[i] = 0  # Reset the count

@@ -3238,14 +3279,8 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    (
-                        (
-                            metaheader_dim  # metaheader is already padded
-                            + pad4(emb_dim)
-                            + pad4(self.optimizer_state_dim)
-                        )
-                        if self.backend_return_whole_row
-                        else emb_dim
+                    self.get_embedding_dim_for_kvt(
+                        metaheader_dim, emb_dim, is_loading_checkpoint
                     ),
                 ],
                 dtype=dtype,
@@ -3257,6 +3292,11 @@ def split_embedding_weights(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
                 checkpoint_handle=checkpoint_handle,
+                only_load_weight=(
+                    True
+                    if self.load_ckpt_without_opt and is_loading_checkpoint
+                    else False
+                ),
             )
             (
                 tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
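
The wiring above reduces to two small decisions: which width to give the KV tensor, and whether to tell the wrapper to write only the weight. The sketch below is a standalone Python paraphrase of get_embedding_dim_for_kvt and of the only_load_weight condition, not the library API; the argument names mirror the diff, and pad4 is redefined locally for self-containment.

    def pad4(n: int) -> int:
        return (n + 3) // 4 * 4

    def kvt_dim(
        metaheader_dim: int,
        emb_dim: int,
        optimizer_state_dim: int,
        backend_return_whole_row: bool,
        load_ckpt_without_opt: bool,
        is_loading_checkpoint: bool,
    ) -> int:
        full = metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)
        if not load_ckpt_without_opt:
            # Normal path: whole row (metaheader + weight + opt) when the backend returns it.
            return full if backend_return_whole_row else emb_dim
        if is_loading_checkpoint:
            # First load in publish mode: keep the optimizer columns so the checkpoint fits.
            return full if backend_return_whole_row else emb_dim
        # Subsequent reads in publish mode: backend only holds metaheader + weight.
        return metaheader_dim + pad4(emb_dim)

    def only_load_weight(load_ckpt_without_opt: bool, is_loading_checkpoint: bool) -> bool:
        # The KVTensorWrapper drops the optimizer columns on write only in this case.
        return load_ckpt_without_opt and is_loading_checkpoint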

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 3 additions & 1 deletion
@@ -65,7 +65,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
       int64_t width_offset = 0,
       const std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
           checkpoint_handle = std::nullopt,
-      bool read_only = false);
+      bool read_only = false,
+      bool only_load_weight = false);

   explicit KVTensorWrapper(const std::string& serialized);

@@ -153,6 +154,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
   int64_t max_D{};
   std::string checkpoint_uuid;
   bool read_only_{};
+  bool only_load_weight_{};
 };

 void to_json(json& j, const KVTensorWrapper& kvt);

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@ KVTensorWrapper::KVTensorWrapper(
     [[maybe_unused]] int64_t width_offset,
     [[maybe_unused]] const std::optional<
         c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>,
-    [[maybe_unused]] bool read_only)
+    [[maybe_unused]] bool read_only,
+    [[maybe_unused]] bool only_load_weight)
     // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
     : shape_(std::move(shape)), row_offset_(row_offset) {
   FBEXCEPTION("Not implemented");

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 15 additions & 4 deletions
@@ -374,12 +374,14 @@ KVTensorWrapper::KVTensorWrapper(
     int64_t width_offset_,
     const std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
         checkpoint_handle,
-    bool read_only)
+    bool read_only,
+    bool only_load_weight)
     : db_(nullptr),
       shape_(std::move(shape)),
       row_offset_(row_offset),
       width_offset_(width_offset_),
-      read_only_(read_only) {
+      read_only_(read_only),
+      only_load_weight_(only_load_weight) {
   CHECK_GE(width_offset_, 0);
   CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported";
   options_ = at::TensorOptions()
@@ -558,7 +560,10 @@ void KVTensorWrapper::set_range(
   CHECK(db_) << "EmbeddingRocksDB must be a valid pointer to call set_range";
   CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported";
   CHECK_TRUE(db_ != nullptr);
-  CHECK_GE(db_->get_max_D() + db_->get_metaheader_width_in_front(), shape_[1]);
+  if (!only_load_weight_) {
+    CHECK_GE(
+        db_->get_max_D() + db_->get_metaheader_width_in_front(), shape_[1]);
+  }

   if (db_->get_backend_return_whole_row()) {
     // backend returns whole row, so we need to replace the first 8 bytes with
@@ -576,6 +581,10 @@ void KVTensorWrapper::set_range(
       db_->get_max_D() + db_->get_metaheader_width_in_front() - weights.size(1);
   if (pad_right == 0) {
     db_->set_range_to_storage(weights, start + row_offset_, length);
+  } else if (pad_right < 0 && only_load_weight_) {
+    int64_t cut_dim = db_->get_max_D() + db_->get_metaheader_width_in_front();
+    at::Tensor new_weights = weights.narrow(1, 0, cut_dim).contiguous();
+    db_->set_range_to_storage(new_weights, start + row_offset_, length);
   } else {
     std::vector<int64_t> padding = {0, pad_right, 0, 0};
     auto padded_weights = torch::constant_pad_nd(weights, padding, 0);
@@ -1080,6 +1089,7 @@ static auto kv_tensor_wrapper =
             int64_t,
             std::optional<
                 c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>,
+            bool,
             bool>(),
         "",
         {torch::arg("shape"),
@@ -1091,7 +1101,8 @@ static auto kv_tensor_wrapper =
          torch::arg("sorted_indices") = std::nullopt,
          torch::arg("width_offset") = 0,
          torch::arg("checkpoint_handle") = std::nullopt,
-         torch::arg("read_only") = false})
+         torch::arg("read_only") = false,
+         torch::arg("only_load_weight") = false})
     .def(
         "set_embedding_rocks_dp_wrapper",
         &KVTensorWrapper::set_embedding_rocks_dp_wrapper,
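
The new set_range branch only matters in publish mode: when the incoming state-dict row is wider than what the backend stores (metaheader + weight), the trailing optimizer columns are cut off before the write. Below is a Python sketch of that narrowing, assuming a hypothetical backend width; in the C++ code the width is db_->get_max_D() + db_->get_metaheader_width_in_front(), and this branch is taken only when pad_right < 0 and only_load_weight_ is set.

    import torch

    # Hypothetical sizes: the backend stores metaheader (8) + padded weight (132) = 140
    # columns, while the checkpoint row still carries 4 padded optimizer columns.
    backend_width = 140
    rows = torch.randn(16, 144)  # rows coming from the checkpoint / state dict

    if rows.size(1) > backend_width:
        # Mirror of the only_load_weight_ branch: drop the trailing optimizer columns
        # so only metaheader + weight is written to storage.
        rows = rows.narrow(1, 0, backend_width).contiguous()

    assert rows.shape == (16, 140)

Rows narrower than the backend width keep the existing behavior and are right-padded with zeros before being written, as in the unchanged else branch.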
