88#define AVX512_ARGSORT_64BIT
99
1010#include " avx512-64bit-common.h"
11- #include " avx512-common-argsort.h"
1211#include " avx512-64bit-keyvalue-networks.hpp"
12+ #include " avx512-common-argsort.h"
13+
14+ template <typename T>
15+ void std_argselect_withnan (
16+ T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
17+ {
18+ std::nth_element (arg + left,
19+ arg + k,
20+ arg + right,
21+ [arr](int64_t a, int64_t b) -> bool {
22+ if ((!std::isnan (arr[a])) && (!std::isnan (arr[b]))) {
23+ return arr[a] < arr[b];
24+ }
25+ else if (std::isnan (arr[a])) {
26+ return false ;
27+ }
28+ else {
29+ return true ;
30+ }
31+ });
32+ }
1333
1434/* argsort using std::sort */
1535template <typename T>
@@ -18,9 +38,15 @@ void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
1838 std::sort (arg + left,
1939 arg + right,
2040 [arr](int64_t left, int64_t right) -> bool {
21- if ((!std::isnan (arr[left])) && (!std::isnan (arr[right]))) {return arr[left] < arr[right];}
22- else if (std::isnan (arr[left])) {return false ;}
23- else {return true ;}
41+ if ((!std::isnan (arr[left])) && (!std::isnan (arr[right]))) {
42+ return arr[left] < arr[right];
43+ }
44+ else if (std::isnan (arr[left])) {
45+ return false ;
46+ }
47+ else {
48+ return true ;
49+ }
2450 });
2551}
2652
@@ -284,7 +310,42 @@ inline void argsort_64bit_(type_t *arr,
284310}
285311
286312template <typename vtype, typename type_t >
287- bool has_nan (type_t * arr, int64_t arrsize)
313+ static void argselect_64bit_ (type_t *arr,
314+ int64_t *arg,
315+ int64_t pos,
316+ int64_t left,
317+ int64_t right,
318+ int64_t max_iters)
319+ {
320+ /*
321+ * Resort to std::sort if quicksort isnt making any progress
322+ */
323+ if (max_iters <= 0 ) {
324+ std_argsort (arr, arg, left, right + 1 );
325+ return ;
326+ }
327+ /*
328+ * Base case: use bitonic networks to sort arrays <= 64
329+ */
330+ if (right + 1 - left <= 64 ) {
331+ argsort_64_64bit<vtype>(arr, arg + left, (int32_t )(right + 1 - left));
332+ return ;
333+ }
334+ type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
335+ type_t smallest = vtype::type_max ();
336+ type_t biggest = vtype::type_min ();
337+ int64_t pivot_index = partition_avx512_unrolled<vtype, 4 >(
338+ arr, arg, left, right + 1 , pivot, &smallest, &biggest);
339+ if ((pivot != smallest) && (pos < pivot_index))
340+ argselect_64bit_<vtype>(
341+ arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
342+ else if ((pivot != biggest) && (pos >= pivot_index))
343+ argselect_64bit_<vtype>(
344+ arr, arg, pos, pivot_index, right, max_iters - 1 );
345+ }
346+
347+ template <typename vtype, typename type_t >
348+ bool has_nan (type_t *arr, int64_t arrsize)
288349{
289350 using opmask_t = typename vtype::opmask_t ;
290351 using zmm_t = typename vtype::zmm_t ;
@@ -299,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize)
299360 else {
300361 in = vtype::loadu (arr);
301362 }
302- opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
363+ opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
303364 arr += vtype::numlanes;
304365 arrsize -= vtype::numlanes;
305366 if (nanmask != 0x00 ) {
@@ -310,8 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize)
310371 return found_nan;
311372}
312373
374+ /* argsort methods for 32-bit and 64-bit dtypes */
313375template <typename T>
314- void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
376+ void avx512_argsort (T * arr, int64_t *arg, int64_t arrsize)
315377{
316378 if (arrsize > 1 ) {
317379 argsort_64bit_<zmm_vector<T>>(
@@ -320,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
320382}
321383
322384template <>
323- void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
385+ void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
324386{
325387 if (arrsize > 1 ) {
326388 if (has_nan<zmm_vector<double >>(arr, arrsize)) {
@@ -333,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
333395 }
334396}
335397
336-
337398template <>
338- void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
399+ void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
339400{
340401 if (arrsize > 1 ) {
341402 argsort_64bit_<ymm_vector<int32_t >>(
@@ -344,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
344405}
345406
346407template <>
347- void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
408+ void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
348409{
349410 if (arrsize > 1 ) {
350411 argsort_64bit_<ymm_vector<uint32_t >>(
@@ -353,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
353414}
354415
355416template <>
356- void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
417+ void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
357418{
358419 if (arrsize > 1 ) {
359420 if (has_nan<ymm_vector<float >>(arr, arrsize)) {
@@ -367,12 +428,77 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
367428}
368429
369430template <typename T>
370- std::vector<int64_t > avx512_argsort (T* arr, int64_t arrsize)
431+ std::vector<int64_t > avx512_argsort (T * arr, int64_t arrsize)
371432{
372433 std::vector<int64_t > indices (arrsize);
373434 std::iota (indices.begin (), indices.end (), 0 );
374435 avx512_argsort<T>(arr, indices.data (), arrsize);
375436 return indices;
376437}
377438
439+ /* argselect methods for 32-bit and 64-bit dtypes */
440+ template <typename T>
441+ void avx512_argselect (T *arr, int64_t *arg, int64_t k, int64_t arrsize)
442+ {
443+ if (arrsize > 1 ) {
444+ argselect_64bit_<zmm_vector<T>>(
445+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
446+ }
447+ }
448+
449+ template <>
450+ void avx512_argselect (double *arr, int64_t *arg, int64_t k, int64_t arrsize)
451+ {
452+ if (arrsize > 1 ) {
453+ if (has_nan<zmm_vector<double >>(arr, arrsize)) {
454+ std_argselect_withnan (arr, arg, k, 0 , arrsize);
455+ }
456+ else {
457+ argselect_64bit_<zmm_vector<double >>(
458+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
459+ }
460+ }
461+ }
462+
463+ template <>
464+ void avx512_argselect (int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
465+ {
466+ if (arrsize > 1 ) {
467+ argselect_64bit_<ymm_vector<int32_t >>(
468+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
469+ }
470+ }
471+
472+ template <>
473+ void avx512_argselect (uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
474+ {
475+ if (arrsize > 1 ) {
476+ argselect_64bit_<ymm_vector<uint32_t >>(
477+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
478+ }
479+ }
480+
481+ template <>
482+ void avx512_argselect (float *arr, int64_t *arg, int64_t k, int64_t arrsize)
483+ {
484+ if (arrsize > 1 ) {
485+ if (has_nan<ymm_vector<float >>(arr, arrsize)) {
486+ std_argselect_withnan (arr, arg, k, 0 , arrsize);
487+ }
488+ else {
489+ argselect_64bit_<ymm_vector<float >>(
490+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
491+ }
492+ }
493+ }
494+
495+ template <typename T>
496+ std::vector<int64_t > avx512_argselect (T *arr, int64_t k, int64_t arrsize)
497+ {
498+ std::vector<int64_t > indices (arrsize);
499+ std::iota (indices.begin (), indices.end (), 0 );
500+ avx512_argselect<T>(arr, indices.data (), k, arrsize);
501+ return indices;
502+ }
503+
378504#endif // AVX512_ARGSORT_64BIT
0 commit comments