Skip to content

Commit 0b2f5e1

Browse files
zhaojuanmaometa-codesync[bot]
authored andcommitted
disable initializer when disable random init (#5182)
Summary: Pull Request resolved: #5182 X-link: https://github.com/facebookresearch/FBGEMM/pull/2178 no need to initialize initializers when random init is disabled, it could save cpu memory significantly Reviewed By: q10 Differential Revision: D86874544 fbshipit-source-id: 5a0ee722e0c40ec280372582ac6238057e4e5fa0
1 parent 0fa6299 commit 0b2f5e1

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,14 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
422422
float uniform_init_upper,
423423
int64_t row_storage_bitwidth,
424424
bool disable_random_init) {
425+
// if disable random init, disable random init for all shards for saving cpu
426+
// memory
427+
disable_random_init_ = disable_random_init;
428+
if (disable_random_init_) {
429+
LOG(INFO) << "disable random init for all shards";
430+
return;
431+
}
432+
LOG(INFO) << "enable random init for all shards";
425433
for (auto i = 0; i < num_shards; ++i) {
426434
auto* gen = at::check_generator<at::CPUGeneratorImpl>(
427435
at::detail::getDefaultCPUGenerator());
@@ -436,7 +444,6 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
436444
row_storage_bitwidth));
437445
}
438446
}
439-
disable_random_init_ = disable_random_init;
440447
}
441448

442449
void maybe_evict() override {
@@ -1251,19 +1258,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
12511258
CHECK_EQ(key_indices.size(), keys.size());
12521259
CHECK_EQ(key_indices.size(), cfs.size());
12531260

1254-
const auto& init_storage =
1255-
initializers_[shard]->row_storage_;
1256-
// Sanity check
1257-
TORCH_CHECK(
1258-
init_storage.scalar_type() ==
1259-
weights.scalar_type(),
1260-
"init_storage (",
1261-
toString(init_storage.scalar_type()),
1262-
") and weights scalar (",
1263-
toString(weights.scalar_type()),
1264-
") types mismatch");
1265-
auto row_storage_data_ptr =
1266-
init_storage.data_ptr<value_t>();
1261+
value_t* row_storage_data_ptr = nullptr;
1262+
if (!disable_random_init_) {
1263+
const auto& init_storage =
1264+
initializers_[shard]->row_storage_;
1265+
// Sanity check
1266+
TORCH_CHECK(
1267+
init_storage.scalar_type() ==
1268+
weights.scalar_type(),
1269+
"init_storage (",
1270+
toString(init_storage.scalar_type()),
1271+
") and weights scalar (",
1272+
toString(weights.scalar_type()),
1273+
") types mismatch");
1274+
row_storage_data_ptr =
1275+
init_storage.data_ptr<value_t>();
1276+
}
12671277
if (use_iterator) {
12681278
ssd_get_weights_iterator(
12691279
keys,

0 commit comments

Comments
 (0)