
Commit 7826ec9

garroud authored and meta-codesync[bot] committed
shortcut for merge_pooled_embedding (#5147)
Summary:
Pull Request resolved: #5147
X-link: https://github.com/facebookresearch/FBGEMM/pull/2146

When all of the input embeddings are already on the target device, we can simply use cat as a shortcut. This avoids the unnecessary cross-device synchronization incurred by the current implementation.

Reviewed By: yyetim

Differential Revision: D87306514

fbshipit-source-id: 71298220bf12b0fba384ce76146824b2bb094e2c
1 parent 44f943c commit 7826ec9
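
For context, here is a minimal usage sketch of the operator this commit optimizes. The Python entry point and signature shown (torch.ops.fbgemm.merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim)) are assumptions about how fbgemm_gpu exposes the op, not something introduced by this diff. When every input tensor already lives on the target device, the new fast path lets the call reduce to a plain concatenation.

    # Hypothetical usage sketch, not part of this commit. Assumes the op is
    # registered as fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings,
    # int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor.
    # Requires a CUDA-enabled build of fbgemm_gpu and at least one GPU.
    import torch
    import fbgemm_gpu  # noqa: F401  (loads the fbgemm operator library)

    device = torch.device("cuda:0")
    # All inputs already sit on the target device, so the shortcut added in this
    # commit applies and the op can return at::cat(pooled_embeddings, cat_dim).
    pooled = [torch.randn(8, 16, dtype=torch.float16, device=device) for _ in range(3)]
    merged = torch.ops.fbgemm.merge_pooled_embeddings(pooled, 8, device, 1)
    assert merged.shape == (8, 48)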

File tree: 2 files changed (+22, -2 lines)

fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp

Lines changed: 14 additions & 0 deletions
@@ -688,6 +688,20 @@ Tensor merge_pooled_embeddings(
   at::cuda::OptionalCUDAGuard g;

   at::Device out_device = target_device;
+
+  // if target_device is the same as input devices, we can directly call
+  // cat
+  bool is_same_device = true;
+  for (const auto& t : pooled_embeddings) {
+    if (t.device() != target_device) {
+      is_same_device = false;
+      break;
+    }
+  }
+  if (is_same_device) {
+    return at::cat(pooled_embeddings, cat_dim);
+  }
+
   if (target_device.is_cuda()) {
     init_p2p_access();
     g.set_device(target_device);

fbgemm_gpu/test/merge_pooled_embeddings_test.py

Lines changed: 8 additions & 2 deletions
@@ -68,6 +68,7 @@ class MergePooledEmbeddingsTest(unittest.TestCase):
         non_default_stream=st.booleans(),
         r=st.randoms(use_true_random=False),
         dim=st.integers(min_value=0, max_value=1),
+        source_from_same_device=st.booleans(),
     )
     # Can instantiate 8 contexts which takes a long time.
     @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
@@ -81,14 +82,19 @@ def test_merge(
         # pyre-fixme[2]: Parameter must be annotated.
         r,
         dim: int,
+        source_from_same_device: bool,
     ) -> None:
         dst_device = r.randint(0, num_gpus - 1)
         torch.cuda.set_device(dst_device)
         ad_ds = [embedding_dimension * ads_tables for _ in range(num_gpus)]
         batch_indices = torch.zeros(num_ads).long().cuda()
         pooled_ad_embeddings = [
-            torch.randn(
-                num_ads, ad_d, dtype=torch.float16, device=torch.device(f"cuda:{i}")
+            (
+                torch.randn(num_ads, ad_d, dtype=torch.float16, device=dst_device)
+                if source_from_same_device
+                else torch.randn(
+                    num_ads, ad_d, dtype=torch.float16, device=torch.device(f"cuda:{i}")
+                )
             )
             for i, ad_d in enumerate(ad_ds)
         ]
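
The test change above lets the same-device case be drawn as an input. As a hedged illustration of the property that case should exercise, the sketch below checks that when every input tensor is already on the destination device, the merged result matches a plain torch.cat. The operator call mirrors the assumed signature from the sketch earlier and is not part of this diff.

    # Hypothetical equivalence check, not part of this commit; assumes the same
    # torch.ops.fbgemm.merge_pooled_embeddings signature as sketched above and
    # a CUDA-enabled fbgemm_gpu build.
    import torch
    import fbgemm_gpu  # noqa: F401

    dst = torch.device("cuda:0")
    parts = [torch.randn(4, d, dtype=torch.float16, device=dst) for d in (8, 16, 32)]
    merged = torch.ops.fbgemm.merge_pooled_embeddings(parts, 4, dst, 1)
    # With every input already on dst, the fast path is just a concatenation.
    torch.testing.assert_close(merged, torch.cat(parts, dim=1))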
