@@ -12,18 +12,10 @@ using Tensor = at::Tensor;
1212
1313namespace fbgemm_gpu {
1414
15- #ifdef USE_ROCM
16- // The wave size is forced to be 32 on ROCm devices in favor
17- // of granularity losses reduction.
18- constexpr int EMULATED_WARP_SIZE = 32 ;
19- #else
20- constexpr int EMULATED_WARP_SIZE = kWarpSize ;
21- #endif
22-
2315// TODO: Update UNROLL_FACTOR
2416constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1 ;
2517constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
26- GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE ;
18+ GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize ;
2719
2820// GROUP_INDEX_SELECT_COLS_PER_WARP must be a power of two
2921constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -287,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
287279 at::cuda::OptionalCUDAGuard device_guard (device);
288280
289281 // Partition work based on num_work_rows
290- uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE ;
282+ uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize ;
291283 uint32_t max_grid_size =
292284 at::cuda::getCurrentDeviceProperties ()->multiProcessorCount * 8 ;
293285 uint32_t grid_size = std::min (
294286 cuda_calc_xblock_count (total_num_warps, num_warps_per_threadblock),
295287 max_grid_size);
296- dim3 block_size (EMULATED_WARP_SIZE , num_warps_per_threadblock, 1 );
288+ dim3 block_size (kWarpSize , num_warps_per_threadblock, 1 );
297289
298290#define INVOKE_GROUP_INDEX_SELECT_OR_ADD (USE_INDEX_SELECT, USE_VAR_COLS ) \
299291 FBGEMM_LAUNCH_KERNEL ( \
0 commit comments