From b47cf7dd08cfe75215a82eced3ab2f38c1196c5d Mon Sep 17 00:00:00 2001
From: Li Li
Date: Fri, 7 Nov 2025 21:54:01 +0000
Subject: [PATCH 1/5] Optimize group_index_select_or_add_2d_kernel
---
.../src/sparse_ops/sparse_group_index.cu | 251 ++++++++++++++----
fbgemm_gpu/test/sparse/index_select_test.py | 30 ++-
2 files changed, 234 insertions(+), 47 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 96c57cde68..c05c99de13 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -51,59 +51,218 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t num_work_rows, // number of rows to work on per member
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
- int32_t num_cols = 0;
- int32_t warps_per_row = 0;
-
- if constexpr (!USE_VAR_COLS) {
- num_cols = num_cols_group[0];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- }
+ // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
+ if (USE_INDEX_SELECT) {
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id, member_warp_id, num_cols, warps_per_row;
+ if (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / kWarpSize];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
+ } else {
+ // All columns are the same
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_id = warp_id / (warps_per_row * num_work_rows);
+ member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+ }
+ const auto row = member_warp_id / warps_per_row;
+ const auto col_offset =
+ ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+ (threadIdx.x * UNROLL_FACTOR);
+ scalar_t* input =
+          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+ scalar_t* output =
+          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
- for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
- warp_id < total_num_warps;
- warp_id += gridDim.x * blockDim.y) {
- int32_t member_id = 0;
- int32_t member_warp_id = 0;
- if constexpr (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
- if (threadIdx.x == 0) {
- binary_search_range(
- &member_ids[threadIdx.y],
- warp_offsets_group + 1,
- warp_id,
- group_size);
+      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+ const index_t idx = indices[row];
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+ output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
}
- syncwarp();
- member_id = member_ids[threadIdx.y];
- num_cols = num_cols_group[member_id];
- warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
- member_warp_id = warp_id - warp_offsets_group[member_id];
- } else {
- // All columns are the same
- member_id = warp_id / (warps_per_row * num_work_rows);
- member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
}
- const auto row = member_warp_id / warps_per_row;
- const auto col_offset =
- ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
- (threadIdx.x * UNROLL_FACTOR);
- scalar_t* input =
-        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
- scalar_t* output =
-        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
- const index_t idx = indices[row];
+ } else {
+ // Cache a handful of scatter destinations per warp so we can merge
+ // consecutive updates that hit the same index before touching global memory.
+ constexpr int kCacheSlots = 2;
+ index_t cached_idx[kCacheSlots];
+ scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
+ bool cached_valid[kCacheSlots];
#pragma unroll
- for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
- // Compile time conditional
- if constexpr (USE_INDEX_SELECT) {
- output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ cached_valid[slot] = false;
+ }
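+    // State tracking which member and column tile the cache slots currently target.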
+ int32_t active_member_id = -1;
+ int32_t active_num_cols = 0;
+ int32_t active_col_offset = -1;
+ scalar_t* active_input_base = nullptr;
+ scalar_t* active_output_base = nullptr;
+ index_t* active_indices = nullptr;
+
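+    // Drain every valid cache slot into global memory with atomic adds, then mark the cache empty.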
+ auto flush_cache = [&](scalar_t* out_base,
+ int32_t num_cols,
+ int32_t col_offset) {
+ if (!out_base) {
+ return;
+ }
+#pragma unroll
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ if (!cached_valid[slot]) {
+ continue;
+ }
+ const int64_t row_offset =
+              static_cast<int64_t>(cached_idx[slot]) * num_cols;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = col_offset + j;
+ if (col >= num_cols) {
+ break;
+ }
+ gpuAtomicAddNoReturn(
+ out_base + row_offset + col, cached_vals[slot][j]);
+ }
+ cached_valid[slot] = false;
+ }
+ };
+
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id, member_warp_id, num_cols, warps_per_row;
+ if (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / kWarpSize];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
- gpuAtomicAddNoReturn(
- &output[idx * num_cols + i], input[row * num_cols + i]);
+ // All columns are the same
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_id = warp_id / (warps_per_row * num_work_rows);
+ member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+ }
+ const int64_t row = member_warp_id / warps_per_row;
+ const int32_t col_offset =
+          static_cast<int32_t>(((member_warp_id % warps_per_row)
+ << LOG_COLS_PER_WARP) +
+ (threadIdx.x * UNROLL_FACTOR));
+
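+      // A different member or column tile invalidates the cached slots, so flush them first.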
+ const bool member_changed = member_id != active_member_id;
+ const bool num_cols_changed =
+ member_changed ? false : (num_cols != active_num_cols);
+ const bool col_changed =
+ member_changed ? false : (col_offset != active_col_offset);
+ if (member_changed || num_cols_changed || col_changed) {
+ flush_cache(active_output_base, active_num_cols, active_col_offset);
+ active_member_id = member_id;
+ active_num_cols = num_cols;
+ active_col_offset = col_offset;
+ active_input_base =
+            reinterpret_cast<scalar_t*>(input_ptrs[member_id]);
+ active_output_base =
+            reinterpret_cast<scalar_t*>(output_ptrs[member_id]);
+ active_indices =
+            reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+ }
+
+ if (col_offset >= active_num_cols) {
+ continue;
+ }
+
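+      // Output row this iteration scatters into; repeats across iterations are what the cache merges.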
+ const index_t idx = active_indices[row];
+
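+      // Stage this input row's columns in registers before consulting the cache.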
+ scalar_t local_vals[UNROLL_FACTOR];
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+        local_vals[j] = static_cast<scalar_t>(0);
+ }
+ const int64_t input_offset =
+          static_cast<int64_t>(row) * active_num_cols + active_col_offset;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = active_col_offset + j;
+ if (col >= active_num_cols) {
+ break;
+ }
+ local_vals[j] = active_input_base[input_offset + j];
+ }
+
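+      // Try to fold the update into a slot that already targets the same output row.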
+ bool appended = false;
+#pragma unroll
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ if (cached_valid[slot] && cached_idx[slot] == idx) {
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = active_col_offset + j;
+ if (col >= active_num_cols) {
+ break;
+ }
+ cached_vals[slot][j] += local_vals[j];
+ }
+ appended = true;
+ break;
+ }
+ }
+
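+      // No matching slot: take a free one, or evict slot 0 to global memory when the cache is full.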
+ if (!appended) {
+ int slot_to_use = -1;
+#pragma unroll
+ for (int slot = 0; slot < kCacheSlots; ++slot) {
+ if (!cached_valid[slot]) {
+ slot_to_use = slot;
+ break;
+ }
+ }
+ if (slot_to_use == -1) {
+ slot_to_use = 0;
+ const int64_t row_offset =
+            static_cast<int64_t>(cached_idx[slot_to_use]) *
+ active_num_cols;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ const int32_t col = active_col_offset + j;
+ if (col >= active_num_cols) {
+ break;
+ }
+ gpuAtomicAddNoReturn(
+ active_output_base + row_offset + col,
+ cached_vals[slot_to_use][j]);
+ }
+ cached_valid[slot_to_use] = false;
+ }
+
+ cached_idx[slot_to_use] = idx;
+#pragma unroll
+ for (int j = 0; j < UNROLL_FACTOR; ++j) {
+ cached_vals[slot_to_use][j] = local_vals[j];
+ }
+ cached_valid[slot_to_use] = true;
}
}
+
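+    // Write out anything still cached once the grid-stride loop is done.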
+ flush_cache(active_output_base, active_num_cols, active_col_offset);
}
}
diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py
index 6c61b77bf8..ff4b264e05 100644
--- a/fbgemm_gpu/test/sparse/index_select_test.py
+++ b/fbgemm_gpu/test/sparse/index_select_test.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+ #!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -239,6 +239,34 @@ def compare_tensor_groups(
{"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {},
)
+ @unittest.skipIf(not gpu_available, "CUDA not available")
+ def test_group_index_select_dim0_duplicate_gradients(self) -> None:
+ device = torch.device("cuda")
+ dtype = torch.float
+
+ num_rows = 4
+ num_cols = 9
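+        # Repeated indices force backward to accumulate gradients into the same input rows.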
+ indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device)
+
+ input_tensor = torch.randn(
+ (num_rows, num_cols), dtype=dtype, device=device
+ ).requires_grad_(True)
+
+ output_group = torch.ops.fbgemm.group_index_select_dim0(
+ [input_tensor], [indices]
+ )
+ output = output_group[0]
+
+ grad = torch.arange(
+ output.numel(), dtype=dtype, device=device
+ ).view_as(output)
+ output.backward(grad)
+
+ ref_grad = torch.zeros_like(input_tensor)
+ ref_grad.index_add_(0, indices, grad)
+
+ torch.testing.assert_close(input_tensor.grad, ref_grad)
+
@given(
num_inputs=st.integers(0, 100),
max_input_rows=st.integers(2, 32),
From 94ccd9707e66fcfd869346869235f3a2469b569d Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Thu, 13 Nov 2025 18:50:50 +0000
Subject: [PATCH 2/5] Remove group_index_select unit test
---
fbgemm_gpu/test/sparse/index_select_test.py | 30 +--------------------
1 file changed, 1 insertion(+), 29 deletions(-)
diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py
index ff4b264e05..6c61b77bf8 100644
--- a/fbgemm_gpu/test/sparse/index_select_test.py
+++ b/fbgemm_gpu/test/sparse/index_select_test.py
@@ -1,4 +1,4 @@
- #!/usr/bin/env python3
+#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -239,34 +239,6 @@ def compare_tensor_groups(
{"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {},
)
- @unittest.skipIf(not gpu_available, "CUDA not available")
- def test_group_index_select_dim0_duplicate_gradients(self) -> None:
- device = torch.device("cuda")
- dtype = torch.float
-
- num_rows = 4
- num_cols = 9
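-        # Repeated indices force backward to accumulate gradients into the same input rows.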
- indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device)
-
- input_tensor = torch.randn(
- (num_rows, num_cols), dtype=dtype, device=device
- ).requires_grad_(True)
-
- output_group = torch.ops.fbgemm.group_index_select_dim0(
- [input_tensor], [indices]
- )
- output = output_group[0]
-
- grad = torch.arange(
- output.numel(), dtype=dtype, device=device
- ).view_as(output)
- output.backward(grad)
-
- ref_grad = torch.zeros_like(input_tensor)
- ref_grad.index_add_(0, indices, grad)
-
- torch.testing.assert_close(input_tensor.grad, ref_grad)
-
@given(
num_inputs=st.integers(0, 100),
max_input_rows=st.integers(2, 32),
From a474843aaa9c5d51a509be2644892c875f1ee240 Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Thu, 13 Nov 2025 18:53:06 +0000
Subject: [PATCH 3/5] Fix COLS_PER_WARP to use kWarpSize
---
fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 14 +++-----------
1 file changed, 3 insertions(+), 11 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index c05c99de13..d9f8b823e2 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -12,18 +12,10 @@ using Tensor = at::Tensor;
namespace fbgemm_gpu {
-#ifdef USE_ROCM
-// The wave size is forced to be 32 on ROCm devices in favor
-// of granularity losses reduction.
-constexpr int EMULATED_WARP_SIZE = 32;
-#else
-constexpr int EMULATED_WARP_SIZE = kWarpSize;
-#endif
-
// TODO: Update UNROLL_FACTOR
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
- GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -287,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
at::cuda::OptionalCUDAGuard device_guard(device);
// Partition work based on num_work_rows
- uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
+ uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
max_grid_size);
- dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
+ dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
FBGEMM_LAUNCH_KERNEL( \
From 76b963d05727f035d06362c2065764f4224d8b30 Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Wed, 19 Nov 2025 05:26:33 +0000
Subject: [PATCH 4/5] Add CUDA implementation and ROCm guards
---
.../src/sparse_ops/sparse_group_index.cu | 139 ++++++++++++------
1 file changed, 98 insertions(+), 41 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index d9f8b823e2..d633608fa5 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -11,19 +11,21 @@
using Tensor = at::Tensor;
namespace fbgemm_gpu {
+namespace {
+
+#ifdef USE_ROCM
+constexpr int kGroupIndexWarpSize = kWarpSize;
+#else
+constexpr int kGroupIndexWarpSize = kWarpSize;
+#endif
-// TODO: Update UNROLL_FACTOR
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
- GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
-
-// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
+ GROUP_INDEX_SELECT_UNROLL_FACTOR * kGroupIndexWarpSize;
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
     log2_calc<GROUP_INDEX_SELECT_COLS_PER_WARP>::value;
-int get_group_index_select_cols_per_warp() {
- return GROUP_INDEX_SELECT_COLS_PER_WARP;
-}
+#ifdef USE_ROCM
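+// ROCm path: kernel variant whose add branch uses the per-thread scatter cache.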
template <
typename index_t,
@@ -40,17 +42,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t* indices_ptrs,
const int64_t* warp_offsets_group,
const int32_t* num_cols_group,
- const int64_t num_work_rows, // number of rows to work on per member
+ const int64_t num_work_rows,
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
- // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
if (USE_INDEX_SELECT) {
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
int32_t member_id, member_warp_id, num_cols, warps_per_row;
if (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / kWarpSize];
+ __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
if (threadIdx.x == 0) {
binary_search_range(
&member_ids[threadIdx.y],
@@ -64,7 +65,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
- // All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_id = warp_id / (warps_per_row * num_work_rows);
@@ -78,7 +78,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
           reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
scalar_t* output =
           reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
       index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
const index_t idx = indices[row];
#pragma unroll
@@ -87,8 +86,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
}
} else {
- // Cache a handful of scatter destinations per warp so we can merge
- // consecutive updates that hit the same index before touching global memory.
constexpr int kCacheSlots = 2;
index_t cached_idx[kCacheSlots];
scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
@@ -135,7 +132,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warp_id += gridDim.x * blockDim.y) {
int32_t member_id, member_warp_id, num_cols, warps_per_row;
if (USE_VAR_COLS) {
- __shared__ int member_ids[kMaxThreads / kWarpSize];
+ __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
if (threadIdx.x == 0) {
binary_search_range(
&member_ids[threadIdx.y],
@@ -149,7 +146,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
- // All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_id = warp_id / (warps_per_row * num_work_rows);
@@ -258,6 +254,88 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
}
+#else // !USE_ROCM
+
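+// Non-ROCm path: original kernel without the per-thread scatter cache.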
+template <
+ typename index_t,
+ typename scalar_t,
+ bool USE_INDEX_SELECT,
+ bool USE_VAR_COLS,
+ int UNROLL_FACTOR,
+ int COLS_PER_WARP,
+ int LOG_COLS_PER_WARP>
+__global__
+__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
+ const int64_t* input_ptrs,
+ const int64_t* output_ptrs,
+ const int64_t* indices_ptrs,
+ const int64_t* warp_offsets_group,
+ const int32_t* num_cols_group,
+ const int64_t num_work_rows,
+ const int64_t group_size) {
+ const auto total_num_warps = warp_offsets_group[group_size];
+ int32_t num_cols = 0;
+ int32_t warps_per_row = 0;
+
+ if constexpr (!USE_VAR_COLS) {
+ num_cols = num_cols_group[0];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ }
+
+ for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+ warp_id < total_num_warps;
+ warp_id += gridDim.x * blockDim.y) {
+ int32_t member_id = 0;
+ int32_t member_warp_id = 0;
+ if constexpr (USE_VAR_COLS) {
+ __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
+ if (threadIdx.x == 0) {
+ binary_search_range(
+ &member_ids[threadIdx.y],
+ warp_offsets_group + 1,
+ warp_id,
+ group_size);
+ }
+ syncwarp();
+ member_id = member_ids[threadIdx.y];
+ num_cols = num_cols_group[member_id];
+ warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+ member_warp_id = warp_id - warp_offsets_group[member_id];
+ } else {
+ member_id = warp_id / (warps_per_row * num_work_rows);
+ member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+ }
+ const auto row = member_warp_id / warps_per_row;
+ const auto col_offset =
+ ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+ (threadIdx.x * UNROLL_FACTOR);
+ scalar_t* input =
+      reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+ scalar_t* output =
+      reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+  index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+ const index_t idx = indices[row];
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+ if constexpr (USE_INDEX_SELECT) {
+ output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+ } else {
+ gpuAtomicAddNoReturn(
+ &output[idx * num_cols + i], input[row * num_cols + i]);
+ }
+ }
+ }
+}
+
+#endif // USE_ROCM
+
+} // namespace
+
+int get_group_index_select_cols_per_warp() {
+ return GROUP_INDEX_SELECT_COLS_PER_WARP;
+}
+
DLL_PUBLIC void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
const int64_t* output_ptrs,
@@ -278,36 +356,15 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
at::cuda::OptionalCUDAGuard device_guard(device);
- // Partition work based on num_work_rows
- uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+ uint32_t num_warps_per_threadblock = kMaxThreads / kGroupIndexWarpSize;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
max_grid_size);
- dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
-
-#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
- FBGEMM_LAUNCH_KERNEL( \
- (group_index_select_or_add_2d_kernel< \
- index_t, \
- scalar_t, \
- USE_INDEX_SELECT, \
- USE_VAR_COLS, \
- GROUP_INDEX_SELECT_UNROLL_FACTOR, \
- GROUP_INDEX_SELECT_COLS_PER_WARP, \
- GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \
- grid_size, \
- block_size, \
- 0, \
- at::cuda::getCurrentCUDAStream(), \
- input_ptrs, \
- output_ptrs, \
- indices_ptrs, \
- warp_offsets_group, \
- num_cols_group, \
- num_work_rows, \
- group_size)
+ dim3 block_size(kGroupIndexWarpSize, num_warps_per_threadblock, 1);
+
+#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT_FLAG, USE_VAR_COLS_FLAG) \
+  FBGEMM_LAUNCH_KERNEL(                                                            \
+      (group_index_select_or_add_2d_kernel<                                        \
+          index_t,                                                                 \
+          scalar_t,                                                                \
+          USE_INDEX_SELECT_FLAG,                                                   \
+          USE_VAR_COLS_FLAG,                                                       \
+          GROUP_INDEX_SELECT_UNROLL_FACTOR,                                        \
+          GROUP_INDEX_SELECT_COLS_PER_WARP,                                        \
+          GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>),                                  \
+      grid_size,                                                                   \
+      block_size,                                                                  \
+      0,                                                                           \
+      at::cuda::getCurrentCUDAStream(),                                            \
+      input_ptrs,                                                                  \
+      output_ptrs,                                                                 \
+      indices_ptrs,                                                                \
+      warp_offsets_group,                                                          \
+      num_cols_group,                                                              \
+      num_work_rows,                                                               \
+      group_size)
AT_DISPATCH_INDEX_TYPES(
indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
From 1063c0c3fd8df70a822e4d7ec8d4ca6d08b85f4f Mon Sep 17 00:00:00 2001
From: Shreyashri Biswas
Date: Wed, 19 Nov 2025 18:12:17 +0000
Subject: [PATCH 5/5] Remove redundant kGroupIndexWarpSize assignment
---
fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 5 -----
1 file changed, 5 deletions(-)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index d633608fa5..a7741bd9af 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -13,12 +13,7 @@ using Tensor = at::Tensor;
namespace fbgemm_gpu {
namespace {
-#ifdef USE_ROCM
-constexpr int kGroupIndexWarpSize = kWarpSize;
-#else
constexpr int kGroupIndexWarpSize = kWarpSize;
-#endif
-
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
GROUP_INDEX_SELECT_UNROLL_FACTOR * kGroupIndexWarpSize;