@@ -348,113 +348,57 @@ static void argselect_64bit_(type_t *arr,
348348template <typename T>
349349void avx512_argsort (T *arr, int64_t *arg, int64_t arrsize)
350350{
351+ using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
352+ ymm_vector<T>,
353+ zmm_vector<T>>::type;
351354 if (arrsize > 1 ) {
352- argsort_64bit_<zmm_vector<T>>(
353- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
354- }
355- }
356-
357- template <>
358- void avx512_argsort (double *arr, int64_t *arg, int64_t arrsize)
359- {
360- if (arrsize > 1 ) {
361- if (has_nan<zmm_vector<double >>(arr, arrsize)) {
362- std_argsort_withnan (arr, arg, 0 , arrsize);
355+ if constexpr (std::is_floating_point_v<T>) {
356+ if (has_nan<vectype>(arr, arrsize)) {
357+ std_argsort_withnan (arr, arg, 0 , arrsize);
358+ return ;
359+ }
363360 }
364- else {
365- argsort_64bit_<zmm_vector<double >>(
366- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
367- }
368- }
369- }
370-
371- template <>
372- void avx512_argsort (int32_t *arr, int64_t *arg, int64_t arrsize)
373- {
374- if (arrsize > 1 ) {
375- argsort_64bit_<ymm_vector<int32_t >>(
361+ argsort_64bit_<vectype>(
376362 arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
377363 }
378364}
379365
380- template <>
381- void avx512_argsort (uint32_t *arr, int64_t *arg, int64_t arrsize)
382- {
383- if (arrsize > 1 ) {
384- argsort_64bit_<ymm_vector<uint32_t >>(
385- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
386- }
387- }
388-
389- template <>
390- void avx512_argsort (float *arr, int64_t *arg, int64_t arrsize)
366+ template <typename T>
367+ std::vector<int64_t > avx512_argsort (T *arr, int64_t arrsize)
391368{
392- if (arrsize > 1 ) {
393- if (has_nan<ymm_vector<float >>(arr, arrsize)) {
394- std_argsort_withnan (arr, arg, 0 , arrsize);
395- }
396- else {
397- argsort_64bit_<ymm_vector<float >>(
398- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
399- }
400- }
369+ std::vector<int64_t > indices (arrsize);
370+ std::iota (indices.begin (), indices.end (), 0 );
371+ avx512_argsort<T>(arr, indices.data (), arrsize);
372+ return indices;
401373}
402374
403375/* argselect methods for 32-bit and 64-bit dtypes */
404376template <typename T>
405377void avx512_argselect (T *arr, int64_t *arg, int64_t k, int64_t arrsize)
406378{
407- if (arrsize > 1 ) {
408- argselect_64bit_<zmm_vector<T>>(
409- arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
410- }
411- }
379+ using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
380+ ymm_vector<T>,
381+ zmm_vector<T>>::type;
412382
413- template <>
414- void avx512_argselect (double *arr, int64_t *arg, int64_t k, int64_t arrsize)
415- {
416383 if (arrsize > 1 ) {
417- if (has_nan<zmm_vector<double >>(arr, arrsize)) {
418- std_argselect_withnan (arr, arg, k, 0 , arrsize);
419- }
420- else {
421- argselect_64bit_<zmm_vector<double >>(
422- arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
384+ if constexpr (std::is_floating_point_v<T>) {
385+ if (has_nan<vectype>(arr, arrsize)) {
386+ std_argselect_withnan (arr, arg, k, 0 , arrsize);
387+ return ;
388+ }
423389 }
424- }
425- }
426-
427- template <>
428- void avx512_argselect (int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
429- {
430- if (arrsize > 1 ) {
431- argselect_64bit_<ymm_vector<int32_t >>(
390+ argselect_64bit_<vectype>(
432391 arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
433392 }
434393}
435394
436- template <>
437- void avx512_argselect (uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
438- {
439- if (arrsize > 1 ) {
440- argselect_64bit_<ymm_vector<uint32_t >>(
441- arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
442- }
443- }
444-
445- template <>
446- void avx512_argselect (float *arr, int64_t *arg, int64_t k, int64_t arrsize)
395+ template <typename T>
396+ std::vector<int64_t > avx512_argselect (T *arr, int64_t k, int64_t arrsize)
447397{
448- if (arrsize > 1 ) {
449- if (has_nan<ymm_vector<float >>(arr, arrsize)) {
450- std_argselect_withnan (arr, arg, k, 0 , arrsize);
451- }
452- else {
453- argselect_64bit_<ymm_vector<float >>(
454- arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
455- }
456- }
398+ std::vector<int64_t > indices (arrsize);
399+ std::iota (indices.begin (), indices.end (), 0 );
400+ avx512_argselect<T>(arr, indices.data (), k, arrsize);
401+ return indices;
457402}
458403
459-
460404#endif // AVX512_ARGSORT_64BIT
0 commit comments