From b47cf7dd08cfe75215a82eced3ab2f38c1196c5d Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 7 Nov 2025 21:54:01 +0000 Subject: [PATCH 1/5] opt group_index_select_or_add_2d_kernel --- .../src/sparse_ops/sparse_group_index.cu | 251 ++++++++++++++---- fbgemm_gpu/test/sparse/index_select_test.py | 30 ++- 2 files changed, 234 insertions(+), 47 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 96c57cde68..c05c99de13 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -51,59 +51,218 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; - int32_t num_cols = 0; - int32_t warps_per_row = 0; - - if constexpr (!USE_VAR_COLS) { - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - } + // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. + if (USE_INDEX_SELECT) { + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id, member_warp_id, num_cols, warps_per_row; + if (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / kWarpSize]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; + } else { + // All columns are the same + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; - for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; - warp_id < total_num_warps; - warp_id += gridDim.x * blockDim.y) { - int32_t member_id = 0; - int32_t member_warp_id = 0; - if constexpr (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; - if (threadIdx.x == 0) { - binary_search_range( - &member_ids[threadIdx.y], - warp_offsets_group + 1, - warp_id, - group_size); + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } - syncwarp(); - member_id = member_ids[threadIdx.y]; - num_cols = num_cols_group[member_id]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; - member_warp_id = warp_id - warp_offsets_group[member_id]; - } else { - // All columns are the same - member_id = warp_id / (warps_per_row * num_work_rows); - member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + - (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; + } else { + // Cache a handful of scatter destinations per warp so we can merge + // consecutive updates that hit the same index before touching global memory. + constexpr int kCacheSlots = 2; + index_t cached_idx[kCacheSlots]; + scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR]; + bool cached_valid[kCacheSlots]; #pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + for (int slot = 0; slot < kCacheSlots; ++slot) { + cached_valid[slot] = false; + } + int32_t active_member_id = -1; + int32_t active_num_cols = 0; + int32_t active_col_offset = -1; + scalar_t* active_input_base = nullptr; + scalar_t* active_output_base = nullptr; + index_t* active_indices = nullptr; + + auto flush_cache = [&](scalar_t* out_base, + int32_t num_cols, + int32_t col_offset) { + if (!out_base) { + return; + } +#pragma unroll + for (int slot = 0; slot < kCacheSlots; ++slot) { + if (!cached_valid[slot]) { + continue; + } + const int64_t row_offset = + static_cast(cached_idx[slot]) * num_cols; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = col_offset + j; + if (col >= num_cols) { + break; + } + gpuAtomicAddNoReturn( + out_base + row_offset + col, cached_vals[slot][j]); + } + cached_valid[slot] = false; + } + }; + + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id, member_warp_id, num_cols, warps_per_row; + if (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / kWarpSize]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[row * num_cols + i]); + // All columns are the same + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + const int64_t row = member_warp_id / warps_per_row; + const int32_t col_offset = + static_cast(((member_warp_id % warps_per_row) + << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR)); + + const bool member_changed = member_id != active_member_id; + const bool num_cols_changed = + member_changed ? false : (num_cols != active_num_cols); + const bool col_changed = + member_changed ? false : (col_offset != active_col_offset); + if (member_changed || num_cols_changed || col_changed) { + flush_cache(active_output_base, active_num_cols, active_col_offset); + active_member_id = member_id; + active_num_cols = num_cols; + active_col_offset = col_offset; + active_input_base = + reinterpret_cast(input_ptrs[member_id]); + active_output_base = + reinterpret_cast(output_ptrs[member_id]); + active_indices = + reinterpret_cast(indices_ptrs[member_id]); + } + + if (col_offset >= active_num_cols) { + continue; + } + + const index_t idx = active_indices[row]; + + scalar_t local_vals[UNROLL_FACTOR]; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + local_vals[j] = static_cast(0); + } + const int64_t input_offset = + static_cast(row) * active_num_cols + active_col_offset; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = active_col_offset + j; + if (col >= active_num_cols) { + break; + } + local_vals[j] = active_input_base[input_offset + j]; + } + + bool appended = false; +#pragma unroll + for (int slot = 0; slot < kCacheSlots; ++slot) { + if (cached_valid[slot] && cached_idx[slot] == idx) { +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = active_col_offset + j; + if (col >= active_num_cols) { + break; + } + cached_vals[slot][j] += local_vals[j]; + } + appended = true; + break; + } + } + + if (!appended) { + int slot_to_use = -1; +#pragma unroll + for (int slot = 0; slot < kCacheSlots; ++slot) { + if (!cached_valid[slot]) { + slot_to_use = slot; + break; + } + } + if (slot_to_use == -1) { + slot_to_use = 0; + const int64_t row_offset = + static_cast(cached_idx[slot_to_use]) * + active_num_cols; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + const int32_t col = active_col_offset + j; + if (col >= active_num_cols) { + break; + } + gpuAtomicAddNoReturn( + active_output_base + row_offset + col, + cached_vals[slot_to_use][j]); + } + cached_valid[slot_to_use] = false; + } + + cached_idx[slot_to_use] = idx; +#pragma unroll + for (int j = 0; j < UNROLL_FACTOR; ++j) { + cached_vals[slot_to_use][j] = local_vals[j]; + } + cached_valid[slot_to_use] = true; } } + + flush_cache(active_output_base, active_num_cols, active_col_offset); } } diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index 6c61b77bf8..ff4b264e05 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 + #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -239,6 +239,34 @@ def compare_tensor_groups( {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {}, ) + @unittest.skipIf(not gpu_available, "CUDA not available") + def test_group_index_select_dim0_duplicate_gradients(self) -> None: + device = torch.device("cuda") + dtype = torch.float + + num_rows = 4 + num_cols = 9 + indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device) + + input_tensor = torch.randn( + (num_rows, num_cols), dtype=dtype, device=device + ).requires_grad_(True) + + output_group = torch.ops.fbgemm.group_index_select_dim0( + [input_tensor], [indices] + ) + output = output_group[0] + + grad = torch.arange( + output.numel(), dtype=dtype, device=device + ).view_as(output) + output.backward(grad) + + ref_grad = torch.zeros_like(input_tensor) + ref_grad.index_add_(0, indices, grad) + + torch.testing.assert_close(input_tensor.grad, ref_grad) + @given( num_inputs=st.integers(0, 100), max_input_rows=st.integers(2, 32), From 94ccd9707e66fcfd869346869235f3a2469b569d Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Thu, 13 Nov 2025 18:50:50 +0000 Subject: [PATCH 2/5] remove group index ut --- fbgemm_gpu/test/sparse/index_select_test.py | 30 +-------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index ff4b264e05..6c61b77bf8 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -1,4 +1,4 @@ - #!/usr/bin/env python3 +#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -239,34 +239,6 @@ def compare_tensor_groups( {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {}, ) - @unittest.skipIf(not gpu_available, "CUDA not available") - def test_group_index_select_dim0_duplicate_gradients(self) -> None: - device = torch.device("cuda") - dtype = torch.float - - num_rows = 4 - num_cols = 9 - indices = torch.tensor([0, 1, 2, 1, 0, 2], dtype=torch.long, device=device) - - input_tensor = torch.randn( - (num_rows, num_cols), dtype=dtype, device=device - ).requires_grad_(True) - - output_group = torch.ops.fbgemm.group_index_select_dim0( - [input_tensor], [indices] - ) - output = output_group[0] - - grad = torch.arange( - output.numel(), dtype=dtype, device=device - ).view_as(output) - output.backward(grad) - - ref_grad = torch.zeros_like(input_tensor) - ref_grad.index_add_(0, indices, grad) - - torch.testing.assert_close(input_tensor.grad, ref_grad) - @given( num_inputs=st.integers(0, 100), max_input_rows=st.integers(2, 32), From a474843aaa9c5d51a509be2644892c875f1ee240 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Thu, 13 Nov 2025 18:53:06 +0000 Subject: [PATCH 3/5] fixed cols_per_warp --- fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index c05c99de13..d9f8b823e2 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -12,18 +12,10 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { -#ifdef USE_ROCM -// The wave size is forced to be 32 on ROCm devices in favor -// of granularity losses reduction. -constexpr int EMULATED_WARP_SIZE = 32; -#else -constexpr int EMULATED_WARP_SIZE = kWarpSize; -#endif - // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; + GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = @@ -287,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; + uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); + dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ FBGEMM_LAUNCH_KERNEL( \ From 76b963d05727f035d06362c2065764f4224d8b30 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Wed, 19 Nov 2025 05:26:33 +0000 Subject: [PATCH 4/5] added cuda implementation and rocm guards --- .../src/sparse_ops/sparse_group_index.cu | 139 ++++++++++++------ 1 file changed, 98 insertions(+), 41 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index d9f8b823e2..d633608fa5 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -11,19 +11,21 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +namespace { + +#ifdef USE_ROCM +constexpr int kGroupIndexWarpSize = kWarpSize; +#else +constexpr int kGroupIndexWarpSize = kWarpSize; +#endif -// TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; - -// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two + GROUP_INDEX_SELECT_UNROLL_FACTOR * kGroupIndexWarpSize; constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = log2_calc::value; -int get_group_index_select_cols_per_warp() { - return GROUP_INDEX_SELECT_COLS_PER_WARP; -} +#ifdef USE_ROCM template < typename index_t, @@ -40,17 +42,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t* indices_ptrs, const int64_t* warp_offsets_group, const int32_t* num_cols_group, - const int64_t num_work_rows, // number of rows to work on per member + const int64_t num_work_rows, const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; - // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch. if (USE_INDEX_SELECT) { for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { int32_t member_id, member_warp_id, num_cols, warps_per_row; if (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / kWarpSize]; + __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize]; if (threadIdx.x == 0) { binary_search_range( &member_ids[threadIdx.y], @@ -64,7 +65,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; } else { - // All columns are the same num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); @@ -78,7 +78,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( reinterpret_cast(input_ptrs[member_id]) + col_offset; scalar_t* output = reinterpret_cast(output_ptrs[member_id]) + col_offset; - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); const index_t idx = indices[row]; #pragma unroll @@ -87,8 +86,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } } else { - // Cache a handful of scatter destinations per warp so we can merge - // consecutive updates that hit the same index before touching global memory. constexpr int kCacheSlots = 2; index_t cached_idx[kCacheSlots]; scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR]; @@ -135,7 +132,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warp_id += gridDim.x * blockDim.y) { int32_t member_id, member_warp_id, num_cols, warps_per_row; if (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / kWarpSize]; + __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize]; if (threadIdx.x == 0) { binary_search_range( &member_ids[threadIdx.y], @@ -149,7 +146,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_warp_id = warp_id - warp_offsets_group[member_id]; } else { - // All columns are the same num_cols = num_cols_group[0]; warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); @@ -258,6 +254,88 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( } } +#else // !USE_ROCM + +template < + typename index_t, + typename scalar_t, + bool USE_INDEX_SELECT, + bool USE_VAR_COLS, + int UNROLL_FACTOR, + int COLS_PER_WARP, + int LOG_COLS_PER_WARP> +__global__ +__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( + const int64_t* input_ptrs, + const int64_t* output_ptrs, + const int64_t* indices_ptrs, + const int64_t* warp_offsets_group, + const int32_t* num_cols_group, + const int64_t num_work_rows, + const int64_t group_size) { + const auto total_num_warps = warp_offsets_group[group_size]; + int32_t num_cols = 0; + int32_t warps_per_row = 0; + + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } + + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_id < total_num_warps; + warp_id += gridDim.x * blockDim.y) { + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize]; + if (threadIdx.x == 0) { + binary_search_range( + &member_ids[threadIdx.y], + warp_offsets_group + 1, + warp_id, + group_size); + } + syncwarp(); + member_id = member_ids[threadIdx.y]; + num_cols = num_cols_group[member_id]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + member_warp_id = warp_id - warp_offsets_group[member_id]; + } else { + member_id = warp_id / (warps_per_row * num_work_rows); + member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); + } + const auto row = member_warp_id / warps_per_row; + const auto col_offset = + ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + (threadIdx.x * UNROLL_FACTOR); + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; + + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const index_t idx = indices[row]; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + if constexpr (USE_INDEX_SELECT) { + output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + } else { + gpuAtomicAddNoReturn( + &output[idx * num_cols + i], input[row * num_cols + i]); + } + } + } +} + +#endif // USE_ROCM + +} // namespace + +int get_group_index_select_cols_per_warp() { + return GROUP_INDEX_SELECT_COLS_PER_WARP; +} + DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t* input_ptrs, const int64_t* output_ptrs, @@ -278,36 +356,15 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); - // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock = kMaxThreads / kGroupIndexWarpSize; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); - -#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ - FBGEMM_LAUNCH_KERNEL( \ - (group_index_select_or_add_2d_kernel< \ - index_t, \ - scalar_t, \ - USE_INDEX_SELECT, \ - USE_VAR_COLS, \ - GROUP_INDEX_SELECT_UNROLL_FACTOR, \ - GROUP_INDEX_SELECT_COLS_PER_WARP, \ - GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ - grid_size, \ - block_size, \ - 0, \ - at::cuda::getCurrentCUDAStream(), \ - input_ptrs, \ - output_ptrs, \ - indices_ptrs, \ - warp_offsets_group, \ - num_cols_group, \ - num_work_rows, \ - group_size) + dim3 block_size(kGroupIndexWarpSize, num_warps_per_threadblock, 1); + +#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT_FLAG, USE_VAR_COLS_FLAG) FBGEMM_LAUNCH_KERNEL( (group_index_select_or_add_2d_kernel< index_t, scalar_t, USE_INDEX_SELECT_FLAG, USE_VAR_COLS_FLAG, GROUP_INDEX_SELECT_UNROLL_FACTOR, GROUP_INDEX_SELECT_COLS_PER_WARP, GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), grid_size, block_size, 0, at::cuda::getCurrentCUDAStream(), input_ptrs, output_ptrs, indices_ptrs, warp_offsets_group, num_cols_group, num_work_rows, group_size) AT_DISPATCH_INDEX_TYPES( indices_scalar_type, "group_index_select_2d_wrapper_1", [&] { From 1063c0c3fd8df70a822e4d7ec8d4ca6d08b85f4f Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Wed, 19 Nov 2025 18:12:17 +0000 Subject: [PATCH 5/5] removed redundant kGroupIndexWarpSize assign --- fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index d633608fa5..a7741bd9af 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -13,12 +13,7 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { namespace { -#ifdef USE_ROCM -constexpr int kGroupIndexWarpSize = kWarpSize; -#else constexpr int kGroupIndexWarpSize = kWarpSize; -#endif - constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = GROUP_INDEX_SELECT_UNROLL_FACTOR * kGroupIndexWarpSize;