
Commit 640abed

Joy Zhang authored and meta-codesync[bot] committed
Refine Register fbgemm::sum_reduce_to_one (#5107)
Summary:
Pull Request resolved: #5107
X-link: https://github.com/facebookresearch/FBGEMM/pull/2112

Only fp16 and bf16 are supported for now.

Reviewed By: domiyy
Differential Revision: D86421539
fbshipit-source-id: ed94dd5236395f2b43d09665e8a645b9d48f1b25
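For orientation, here is a minimal sketch of how the operator might be invoked in eager mode. The shapes, device count, and dtype are illustrative assumptions rather than part of this commit, and it assumes fbgemm_gpu is installed so the fbgemm:: ops are registered.

import torch
import fbgemm_gpu  # noqa: F401  # assumed installed; importing it registers the fbgemm:: ops

# Hypothetical inputs: one tensor per GPU, all with the same shape.
# Per the commit message, only fp16 and bf16 are supported for now.
inputs = [
    torch.randn(1024, 256, dtype=torch.float16, device=f"cuda:{i}")
    for i in range(torch.cuda.device_count())
]

# Sum the per-device tensors into a single tensor on the target device.
out = torch.ops.fbgemm.sum_reduce_to_one(inputs, torch.device("cuda:0"))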
1 parent 287dc96 commit 640abed

File tree

1 file changed (+11, -0 lines)

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 11 additions & 0 deletions
@@ -1206,6 +1206,16 @@ def all_to_one_device(
     ]
 
 
+def sum_reduce_to_one(
+    input_tensors: list[Tensor],
+    target_device: torch.device,
+) -> Tensor:
+    torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
+    # All tensors should have the same shape
+    first_tensor = input_tensors[0]
+    return torch.empty_like(first_tensor, device=torch.device("meta"))
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
@@ -1281,6 +1291,7 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
     impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
     impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
     impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
+    impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)
     impl_abstract(
         "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
     )
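The added sum_reduce_to_one function is an abstract (meta) implementation: it only propagates shape and dtype, so tracers and compilers can reason about fbgemm::sum_reduce_to_one without running the real kernel. A rough sketch of how that registration gets exercised under FakeTensorMode, assuming fbgemm_gpu is importable in the environment:

import torch
import fbgemm_gpu  # noqa: F401  # assumed installed; registers the op and its abstract impl
from torch._subclasses.fake_tensor import FakeTensorMode

# No GPU memory is touched here: fake tensors carry only metadata, and the
# abstract impl above supplies the output's shape and dtype.
with FakeTensorMode():
    fake_inputs = [torch.empty(1024, 256, dtype=torch.bfloat16) for _ in range(4)]
    out = torch.ops.fbgemm.sum_reduce_to_one(fake_inputs, torch.device("cuda:0"))
    print(out.shape, out.dtype)  # torch.Size([1024, 256]) torch.bfloat16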
