@@ -6,30 +6,25 @@ namespace refactor::kernel::cuda {
 
     __global__ static void whereKernel(
         unsigned long long n,
-        unsigned int const *strides,
-        bool const *c,
-        uint8_t const *x,
-        uint8_t const *y,
-        uint8_t *output,
+        unsigned int const *__restrict__ strides,
+        bool const *__restrict__ c,
+        uint8_t const *__restrict__ x,
+        uint8_t const *__restrict__ y,
+        uint8_t *__restrict__ output,
         unsigned int rank,
         unsigned int eleSize) {
-        extern __shared__ unsigned int shared[];
-        for (auto i = threadIdx.x; i < rank * 4; i += blockDim.x) {
-            shared[i] = strides[i];
-        }
-        __syncthreads();
         for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
                   step = blockDim.x * gridDim.x;
              tid < n;
              tid += step) {
             auto ic = 0u, ix = 0u, iy = 0u, rem = tid;
             for (auto j = 0u; j < rank; ++j) {
-                auto dim = shared + 4 * j;
-                auto quot = rem / dim[3];
-                rem %= dim[3];
-                ic += quot * dim[0];
-                ix += quot * dim[1];
-                iy += quot * dim[2];
+                auto dim = strides + 4 * j;
+                auto quot = rem / __ldg(dim + 3);
+                rem %= __ldg(dim + 3);
+                ic += quot * __ldg(dim + 0);
+                ix += quot * __ldg(dim + 1);
+                iy += quot * __ldg(dim + 2);
             }
 
             optimizedMemcpy(output + tid * eleSize,
@@ -52,7 +47,7 @@ namespace refactor::kernel::cuda {
         whereKernel<<<
             params.gridSize,
             params.blockSize,
-            rank * sizeof(unsigned int) * 4,
+            0,
             reinterpret_cast<cudaStream_t>(params.stream)>>>(
             params.n,
             strides,
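
The diff drops the shared-memory staging of the per-dimension stride table (and its `__syncthreads`), marks the kernel pointers `__restrict__`, and reads the strides through the read-only data cache with `__ldg`; accordingly, the dynamic shared-memory argument of the launch becomes `0`. Below is a minimal, self-contained sketch of the same pattern, not the project's actual kernel: names such as `copyStrided` and the two-entry-per-dimension stride table are illustrative assumptions.

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ static void copyStrided(
    unsigned long long n,
    unsigned int const *__restrict__ strides,  // 2 entries per dim: {srcStride, dstStride}
    float const *__restrict__ src,
    float *__restrict__ dst,
    unsigned int rank) {
    // Grid-stride loop over the flattened destination.
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < n;
         tid += step) {
        auto i = 0u, rem = tid;
        for (auto j = 0u; j < rank; ++j) {
            auto dim = strides + 2 * j;
            // __ldg loads through the read-only data cache, so no shared-memory
            // copy of the table is needed and the launch uses 0 dynamic shared memory.
            auto quot = rem / __ldg(dim + 1);  // divide by dstStride
            rem %= __ldg(dim + 1);
            i += quot * __ldg(dim + 0);        // accumulate srcStride
        }
        dst[tid] = src[i];
    }
}

int main() {
    // Transpose a 2x3 row-major matrix into 3x2: per-dim table {srcStride, dstStride}.
    unsigned int strides[] = {1, 2, 3, 1};
    std::vector<float> src{0, 1, 2, 3, 4, 5}, dst(6);

    float *dSrc, *dDst;
    unsigned int *dStrides;
    cudaMalloc(&dSrc, 6 * sizeof(float));
    cudaMalloc(&dDst, 6 * sizeof(float));
    cudaMalloc(&dStrides, sizeof(strides));
    cudaMemcpy(dSrc, src.data(), 6 * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dStrides, strides, sizeof(strides), cudaMemcpyHostToDevice);

    // Third launch parameter (dynamic shared memory) is 0, as in the diff above.
    copyStrided<<<1, 256, 0>>>(6, dStrides, dSrc, dDst, 2);
    cudaMemcpy(dst.data(), dDst, 6 * sizeof(float), cudaMemcpyDeviceToHost);

    for (auto v : dst) printf("%g ", v);  // expect: 0 3 1 4 2 5
    printf("\n");
    cudaFree(dSrc); cudaFree(dDst); cudaFree(dStrides);
    return 0;
}
```

Since the stride table is small and read by every thread, routing it through the read-only cache avoids the per-block copy loop and barrier while keeping the loads cached, which is the trade-off this change makes.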