Skip to content

Commit a474843

Browse files
committed
fixed cols_per_warp
1 parent 94ccd97 commit a474843

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

fbgemm_gpu/src/sparse_ops/sparse_group_index.cu

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,10 @@ using Tensor = at::Tensor;
1212

1313
namespace fbgemm_gpu {
1414

15-
#ifdef USE_ROCM
16-
// The wave size is forced to be 32 on ROCm devices in order
17-
// to reduce granularity losses.
18-
constexpr int EMULATED_WARP_SIZE = 32;
19-
#else
20-
constexpr int EMULATED_WARP_SIZE = kWarpSize;
21-
#endif
22-
2315
// TODO: Update UNROLL_FACTOR
2416
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
2517
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
26-
GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
18+
GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
2719

2820
// GROUP_INDEX_SELECT_COLS_PER_WARP must be a power of two
2921
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -287,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
287279
at::cuda::OptionalCUDAGuard device_guard(device);
288280

289281
// Partition work based on num_work_rows
290-
uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
282+
uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
291283
uint32_t max_grid_size =
292284
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
293285
uint32_t grid_size = std::min(
294286
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
295287
max_grid_size);
296-
dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
288+
dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
297289

298290
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
299291
FBGEMM_LAUNCH_KERNEL( \

0 commit comments

Comments
 (0)