88#define AVX512_ARGSORT_64BIT
99
1010#include " avx512-64bit-common.h"
11- #include " avx512-common-argsort.h"
1211#include " avx512-64bit-keyvalue-networks.hpp"
12+ #include " avx512-common-argsort.h"
1313
/*
 * NaN-aware argselect built on std::nth_element.
 *
 * Rearranges the index range [arg + left, arg + right) so that arg[k] holds
 * the index it would hold if that range were fully ordered by the values
 * arr[arg[i]], with every index before position k referring to a value that
 * does not order after it.
 *
 * arr   : values being compared; it is only read, never reordered.
 * arg   : indices into arr; this is the array that gets permuted.
 * k     : absolute position inside arg (not relative to left) to select.
 * left  : first position of the range in arg (inclusive).
 * right : end position of the range in arg (exclusive, used as the end
 *         iterator).
 *
 * Ordering: non-NaN values compare with operator<; a NaN is ordered after
 * every non-NaN value and two NaNs are treated as equivalent.
 */
template <typename T>
void std_argselect_withnan(
        T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
{
    std::nth_element(arg + left,
                     arg + k,
                     arg + right,
                     [arr](int64_t a, int64_t b) -> bool {
                         // A NaN on the left never orders before anything,
                         // including another NaN (keeps the ordering strict).
                         if (std::isnan(arr[a])) { return false; }
                         // Left is a number, right is NaN: numbers go first.
                         if (std::isnan(arr[b])) { return true; }
                         return arr[a] < arr[b];
                     });
}
2633
27-
/*
 * NaN-aware argsort using std::sort.
 *
 * Sorts the index range [arg + left, arg + right) in ascending order of the
 * values arr[arg[i]].
 *
 * arr   : values being compared; it is only read, never reordered.
 * arg   : indices into arr; this is the array that gets sorted.
 * left  : first position of the range in arg (inclusive).
 * right : end position of the range in arg (exclusive, used as the end
 *         iterator).
 *
 * Ordering: non-NaN values compare with operator<; a NaN is ordered after
 * every non-NaN value and two NaNs are treated as equivalent, so indices of
 * NaN entries end up at the back of the sorted range.
 */
template <typename T>
void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
{
    std::sort(arg + left,
              arg + right,
              // The comparator arguments are named a/b so they do not shadow
              // the enclosing function's left/right range bounds.
              [arr](int64_t a, int64_t b) -> bool {
                  if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {
                      return arr[a] < arr[b];
                  }
                  else if (std::isnan(arr[a])) {
                      // NaN never orders before anything (NaN included).
                      return false;
                  }
                  else {
                      // a is a number and b is NaN: numbers go first.
                      return true;
                  }
              });
}
4052
@@ -325,13 +337,15 @@ static void argselect_64bit_(type_t *arr,
325337 int64_t pivot_index = partition_avx512_unrolled<vtype, 4 >(
326338 arr, arg, left, right + 1 , pivot, &smallest, &biggest);
327339 if ((pivot != smallest) && (pos < pivot_index))
328- argselect_64bit_<vtype>(arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
340+ argselect_64bit_<vtype>(
341+ arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
329342 else if ((pivot != biggest) && (pos >= pivot_index))
330- argselect_64bit_<vtype>(arr, arg, pos, pivot_index, right, max_iters - 1 );
343+ argselect_64bit_<vtype>(
344+ arr, arg, pos, pivot_index, right, max_iters - 1 );
331345}
332346
333347template <typename vtype, typename type_t >
334- bool has_nan (type_t * arr, int64_t arrsize)
348+ bool has_nan (type_t * arr, int64_t arrsize)
335349{
336350 using opmask_t = typename vtype::opmask_t ;
337351 using zmm_t = typename vtype::zmm_t ;
@@ -346,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize)
346360 else {
347361 in = vtype::loadu (arr);
348362 }
349- opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
363+ opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
350364 arr += vtype::numlanes;
351365 arrsize -= vtype::numlanes;
352366 if (nanmask != 0x00 ) {
@@ -357,10 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize)
357371 return found_nan;
358372}
359373
360-
361374/* argsort methods for 32-bit and 64-bit dtypes */
362375template <typename T>
363- void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
376+ void avx512_argsort (T * arr, int64_t *arg, int64_t arrsize)
364377{
365378 if (arrsize > 1 ) {
366379 argsort_64bit_<zmm_vector<T>>(
@@ -369,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
369382}
370383
371384template <>
372- void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
385+ void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
373386{
374387 if (arrsize > 1 ) {
375388 if (has_nan<zmm_vector<double >>(arr, arrsize)) {
@@ -382,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
382395 }
383396}
384397
385-
386398template <>
387- void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
399+ void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
388400{
389401 if (arrsize > 1 ) {
390402 argsort_64bit_<ymm_vector<int32_t >>(
@@ -393,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
393405}
394406
395407template <>
396- void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
408+ void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
397409{
398410 if (arrsize > 1 ) {
399411 argsort_64bit_<ymm_vector<uint32_t >>(
@@ -402,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
402414}
403415
404416template <>
405- void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
417+ void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
406418{
407419 if (arrsize > 1 ) {
408420 if (has_nan<ymm_vector<float >>(arr, arrsize)) {
@@ -416,7 +428,7 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
416428}
417429
418430template <typename T>
419- std::vector<int64_t > avx512_argsort (T* arr, int64_t arrsize)
431+ std::vector<int64_t > avx512_argsort (T * arr, int64_t arrsize)
420432{
421433 std::vector<int64_t > indices (arrsize);
422434 std::iota (indices.begin (), indices.end (), 0 );
@@ -426,7 +438,7 @@ std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
426438
427439/* argselect methods for 32-bit and 64-bit dtypes */
428440template <typename T>
429- void avx512_argselect (T* arr, int64_t *arg, int64_t k, int64_t arrsize)
441+ void avx512_argselect (T * arr, int64_t *arg, int64_t k, int64_t arrsize)
430442{
431443 if (arrsize > 1 ) {
432444 argselect_64bit_<zmm_vector<T>>(
@@ -435,7 +447,7 @@ void avx512_argselect(T* arr, int64_t *arg, int64_t k, int64_t arrsize)
435447}
436448
437449template <>
438- void avx512_argselect (double * arr, int64_t *arg, int64_t k, int64_t arrsize)
450+ void avx512_argselect (double * arr, int64_t *arg, int64_t k, int64_t arrsize)
439451{
440452 if (arrsize > 1 ) {
441453 if (has_nan<zmm_vector<double >>(arr, arrsize)) {
@@ -449,7 +461,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
449461}
450462
451463template <>
452- void avx512_argselect (int32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
464+ void avx512_argselect (int32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
453465{
454466 if (arrsize > 1 ) {
455467 argselect_64bit_<ymm_vector<int32_t >>(
@@ -458,7 +470,7 @@ void avx512_argselect(int32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
458470}
459471
460472template <>
461- void avx512_argselect (uint32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
473+ void avx512_argselect (uint32_t * arr, int64_t *arg, int64_t k, int64_t arrsize)
462474{
463475 if (arrsize > 1 ) {
464476 argselect_64bit_<ymm_vector<uint32_t >>(
@@ -467,7 +479,7 @@ void avx512_argselect(uint32_t* arr, int64_t *arg, int64_t k, int64_t arrsize)
467479}
468480
469481template <>
470- void avx512_argselect (float * arr, int64_t *arg, int64_t k, int64_t arrsize)
482+ void avx512_argselect (float * arr, int64_t *arg, int64_t k, int64_t arrsize)
471483{
472484 if (arrsize > 1 ) {
473485 if (has_nan<ymm_vector<float >>(arr, arrsize)) {
@@ -481,7 +493,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
481493}
482494
483495template <typename T>
484- std::vector<int64_t > avx512_argselect (T* arr, int64_t k, int64_t arrsize)
496+ std::vector<int64_t > avx512_argselect (T * arr, int64_t k, int64_t arrsize)
485497{
486498 std::vector<int64_t > indices (arrsize);
487499 std::iota (indices.begin (), indices.end (), 0 );
0 commit comments