Support a custom comparison operator in DeviceReduce::ArgMin#8285
Support a custom comparison operator in DeviceReduce::ArgMin#8285bernhardmgruber wants to merge 4 commits intoNVIDIA:mainfrom
DeviceReduce::ArgMin#8285Conversation
DeviceReduce::ArgMinDeviceReduce::ArgMin
| // TODO(bgruber): this constraint is not accurate, since the implementation will compare the value types of | ||
| // ExtremumOutIteratorT, which is wrong IMO | ||
| ::cuda::std::enable_if_t<::cuda::std::indirectly_comparable<InputIteratorT, InputIteratorT, CompareOpT>, int> = 0> |
There was a problem hiding this comment.
Instead of InputIteratorT we should use non_void_value_t<ExtremumOutIteratorT, it_value_t<InputIteratorT>>, but that just "feels" wrong here. But this is what the implementation does. What do the reviewers think?
I think the implementation should actually be changed to compare the input values, not the converted ones.
😬 CI Workflow Results🟥 Finished in 1h 57m: Pass: 59%/249 | Total: 7d 12h | Max: 1h 27m | Hits: 61%/126688See results here. |
NaderAlAwar
left a comment
There was a problem hiding this comment.
Suggestion: the issue being closed mentions ArgMax as well in the title, but this PR only appears to add public custom-comparator overloads and
test coverage for ArgMin. The internal refactor is more general, but DeviceReduce::ArgMax still seems to expose only the old no-comparator API. I would either create a separate issue for ArgMax or expose the custom comparator overload as well.
Fixes: #6123