From f14f5216d2acd949eec00b83127e23b86e70c9cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Fri, 14 Jun 2024 23:50:44 +0000 Subject: [PATCH 01/11] Add static_multiset::for_each --- .../open_addressing_ref_impl.cuh | 100 ++++++++++++++++++ .../static_multiset/static_multiset_ref.inl | 42 ++++++++ include/cuco/operator.hpp | 6 ++ tests/CMakeLists.txt | 3 +- 4 files changed, 150 insertions(+), 1 deletion(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 876ef65c5..21b438eff 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -962,6 +962,106 @@ class open_addressing_ref_impl { } } + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. 
+ * + * @tparam ProbeKey Input type which is convertible to 'key_type' + + @tparam Callback Callback functor or lambda + * + * @param key The key to search for + * @param callback Function to call on every element found + */ + template + __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent()); + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const window_slots = this->storage_ref_[*probing_iter]; + + for (auto i = 0; i < window_size; ++i) { + switch ( + this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { + case detail::equal_result::EMPTY: { + return; + } + case detail::equal_result::EQUAL: { + callback(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + if constexpr (allows_duplicates) { + continue; + } else { + return; + } + } + default: continue; + } + } + ++probing_iter; + } + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. 
+ * + * @tparam ProbeKey Input type which is convertible to 'key_type' + + @tparam Callback Callback functor or lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback Function to call on every element found + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + Callback callback) const noexcept + { + auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); + + while (true) { + auto const window_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_window_index] = [&]() { + auto res = detail::equal_result::UNEQUAL; + for (auto i = 0; i < window_size; ++i) { + res = this->predicate_.operator()(key, this->extract_key(window_slots[i])); + if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; } + } + // returns dummy index `-1` for UNEQUAL + return window_probing_results{res, -1}; + }(); + + // Find a match for the probe key, thus call the callback with an iterator to the entry + auto const equal = state == detail::equal_result::EQUAL; + if (equal) { + callback(const_iterator{&(*(storage_ref_.data() + *probing_iter))[intra_window_index]}); + } + + if constexpr (not allows_duplicates) { + if (group.any(equal)) { return; } + } + + // Find an empty slot, meaning that the probe key isn't present in the container + auto const empty = state == detail::equal_result::EMPTY; + if (group.any(empty)) { return; } + + ++probing_iter; + } + } + /** * @brief Compares the content of the address `address` (old value) with the `expected` value and, * only if they are the same, sets the content of `address` to `desired`. 
diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index d4fadc9cb..0c1ae97db 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -446,6 +446,48 @@ class operator_impl< } }; +template +class operator_impl< + op::for_each_tag, + static_multiset_ref> { + using base_type = static_multiset_ref; + using ref_type = + static_multiset_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using iterator = typename base_type::iterator; + using const_iterator = typename base_type::const_iterator; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + // TODO docs + template + __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(key, callback); + } + + // TODO docs + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + Callback callback) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, callback); + } +}; + template Date: Sat, 15 Jun 2024 00:18:33 +0000 Subject: [PATCH 02/11] Add unit test --- tests/static_multiset/for_each_test.cu | 138 +++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 tests/static_multiset/for_each_test.cu diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu new file mode 100644 index 000000000..d4f97b76c --- /dev/null +++ b/tests/static_multiset/for_each_test.cu @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +template +CUCO_KERNEL void for_each_check_scalar(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + static_assert(Ref::cg_size == 1, "Scalar test must have cg_size==1"); + auto const loop_stride = cuco::detail::grid_stride(); + auto idx = cuco::detail::global_thread_id(); + + while (idx < n) { + auto const& key = *(first + idx); + std::size_t matches = 0; + ref.for_each(key, [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { matches++; } + }); + if (matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); } + idx += loop_stride; + } +} + +template +CUCO_KERNEL void for_each_check_cooperative(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + auto const loop_stride = cuco::detail::grid_stride() / Ref::cg_size; + auto idx = cuco::detail::global_thread_id() / Ref::cg_size; + ; + + while (idx < n) { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + auto const& key = *(first + idx); + std::size_t thread_matches = 0; + ref.for_each(tile, key, [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { thread_matches++; } + }); + auto const tile_matches = + 
cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); + if (tile_matches != multiplicity and tile.thread_rank() == 0) { + error_counter->fetch_add(1, cuda::memory_order_relaxed); + } + idx += loop_stride; + } +} + +TEMPLATE_TEST_CASE_SIG( + "static_multiset for_each tests", + "", + ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize), + (int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_t num_unique_keys{400}; + constexpr size_t key_multiplicity{5}; + constexpr size_t num_keys{num_unique_keys * key_multiplicity}; + + using probe = std::conditional_t>, + cuco::double_hashing>>; + + auto set = + cuco::static_multiset{num_keys, cuco::empty_key{-1}, {}, probe{}, {}, cuco::storage<2>{}}; + + auto unique_keys_begin = thrust::counting_iterator(0); + auto gen_duplicate_keys = cuda::proclaim_return_type( + [] __device__(auto const& k) { return static_cast(k % num_unique_keys); }); + auto keys_begin = thrust::make_transform_iterator(unique_keys_begin, gen_duplicate_keys); + + set.insert(keys_begin, keys_begin + num_keys); + + using error_counter_type = cuda::atomic; + error_counter_type* error_counter; + CUCO_CUDA_TRY(cudaMallocHost(&error_counter, sizeof(error_counter_type))); + new (error_counter) error_counter_type{0}; + + auto const grid_size = cuco::detail::grid_size(num_unique_keys, CGSize); + auto const block_size = cuco::detail::default_block_size(); + + // test scalar for_each + if constexpr (CGSize == 1) { + for_each_check_scalar<<>>( + set.ref(cuco::for_each), 
unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + } + + // test CG for_each + + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + + CUCO_CUDA_TRY(cudaFreeHost(error_counter)); +} \ No newline at end of file From 099503cb407fa9cbf7f7c9e7e63d954044c354dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Sat, 15 Jun 2024 00:19:08 +0000 Subject: [PATCH 03/11] Fix docstring --- .../cuco/detail/open_addressing/open_addressing_ref_impl.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 21b438eff..92ad4fe3c 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -964,7 +964,7 @@ class open_addressing_ref_impl { /** * @brief Executes a callback on every element in the container with key equivalent to the probe - key. + * key. * * @note Passes an un-incrementable input iterator to the element whose key is equivalent to * `key` to the callback. @@ -1008,7 +1008,7 @@ class open_addressing_ref_impl { /** * @brief Executes a callback on every element in the container with key equivalent to the probe - key. + * key. * * @note Passes an un-incrementable input iterator to the element whose key is equivalent to * `key` to the callback. 
From ca41a48d192e040a4c9ad8a90912eeb01c5403a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Sat, 15 Jun 2024 00:20:11 +0000 Subject: [PATCH 04/11] Remove newline --- tests/static_multiset/for_each_test.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu index d4f97b76c..c8effd4fe 100644 --- a/tests/static_multiset/for_each_test.cu +++ b/tests/static_multiset/for_each_test.cu @@ -130,7 +130,6 @@ TEMPLATE_TEST_CASE_SIG( } // test CG for_each - for_each_check_cooperative<<>>( set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); From e7a8e03c60e9e389511cfc65c809b9569028d559 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Sat, 15 Jun 2024 00:21:50 +0000 Subject: [PATCH 05/11] Add operator docs --- .../static_multiset/static_multiset_ref.inl | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index 0c1ae97db..d34586579 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -467,7 +467,19 @@ class operator_impl< static constexpr auto window_size = base_type::window_size; public: - // TODO docs + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. 
+ * + * @tparam ProbeKey Input type which is convertible to 'key_type' + + @tparam Callback Callback functor or lambda + * + * @param key The key to search for + * @param callback Function to call on every element found + */ template __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept { @@ -476,7 +488,24 @@ class operator_impl< ref_.impl_.for_each(key, callback); } - // TODO docs + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + + @tparam Callback Callback functor or lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback Function to call on every element found + */ template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, From 7053703623c73fcc1acfcaa1fc5eab2b9ef366b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Sat, 15 Jun 2024 00:22:47 +0000 Subject: [PATCH 06/11] Fix unit test --- tests/static_multiset/for_each_test.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu index c8effd4fe..e7ecece1e 100644 --- a/tests/static_multiset/for_each_test.cu +++ b/tests/static_multiset/for_each_test.cu @@ -132,6 +132,8 @@ TEMPLATE_TEST_CASE_SIG( // test CG for_each for_each_check_cooperative<<>>( set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, 
error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); CUCO_CUDA_TRY(cudaFreeHost(error_counter)); } \ No newline at end of file From 18c5f60e0a1808c70f9766eec3d71062ac4acc35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Sat, 15 Jun 2024 01:35:51 +0000 Subject: [PATCH 07/11] Pass callback as universal reference --- .../open_addressing/open_addressing_ref_impl.cuh | 4 ++-- .../detail/static_multiset/static_multiset_ref.inl | 10 ++++++---- tests/static_multiset/for_each_test.cu | 1 + 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 92ad4fe3c..1c77c76be 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -976,7 +976,7 @@ class open_addressing_ref_impl { * @param callback Function to call on every element found */ template - __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept + __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent()); @@ -1027,7 +1027,7 @@ class open_addressing_ref_impl { template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - Callback callback) const noexcept + Callback&& callback) const noexcept { auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index d34586579..78c54b1b4 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ 
b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -22,6 +22,8 @@ #include +#include + namespace cuco { template - __device__ void for_each(ProbeKey const& key, Callback callback) const noexcept + __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(key, callback); + ref_.impl_.for_each(key, std::forward(callback)); } /** @@ -509,11 +511,11 @@ class operator_impl< template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - Callback callback) const noexcept + Callback&& callback) const noexcept { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(group, key, callback); + ref_.impl_.for_each(group, key, std::forward(callback)); } }; diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu index e7ecece1e..b0cb81091 100644 --- a/tests/static_multiset/for_each_test.cu +++ b/tests/static_multiset/for_each_test.cu @@ -72,6 +72,7 @@ CUCO_KERNEL void for_each_check_cooperative(Ref ref, ref.for_each(tile, key, [&] __device__(auto const it) { if (ref.key_eq()(key, *it)) { thread_matches++; } }); + tile.sync(); auto const tile_matches = cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); if (tile_matches != multiplicity and tile.thread_rank() == 0) { From 22dab496e3c1d8291e18c45e8ab717331c5739e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Sat, 15 Jun 2024 01:39:54 +0000 Subject: [PATCH 08/11] Remove unused operator members --- .../cuco/detail/static_multiset/static_multiset_ref.inl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index 78c54b1b4..61d7b1ac6 100644 --- 
a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -460,13 +460,8 @@ class operator_impl< using base_type = static_multiset_ref; using ref_type = static_multiset_ref; - using key_type = typename base_type::key_type; - using value_type = typename base_type::value_type; - using iterator = typename base_type::iterator; - using const_iterator = typename base_type::const_iterator; - static constexpr auto cg_size = base_type::cg_size; - static constexpr auto window_size = base_type::window_size; + static constexpr auto cg_size = base_type::cg_size; public: /** From bd309c154ac11c1a810c002c6c894506217b8493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Tue, 25 Jun 2024 00:51:11 +0000 Subject: [PATCH 09/11] Fix probing logic --- .../open_addressing_ref_impl.cuh | 48 ++++++++----------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 1c77c76be..33d2c5322 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -985,7 +985,7 @@ class open_addressing_ref_impl { // TODO atomic_ref::load if insert operator is present auto const window_slots = this->storage_ref_[*probing_iter]; - for (auto i = 0; i < window_size; ++i) { + for (int32_t i = 0; i < window_size; ++i) { switch ( this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { case detail::equal_result::EMPTY: { @@ -993,11 +993,7 @@ class open_addressing_ref_impl { } case detail::equal_result::EQUAL: { callback(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); - if constexpr (allows_duplicates) { - continue; - } else { - return; - } + continue; } default: continue; } @@ -1029,33 +1025,29 @@ class open_addressing_ref_impl { ProbeKey const& key, 
Callback&& callback) const noexcept { - auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); + auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent()); + bool empty = false; while (true) { - auto const window_slots = storage_ref_[*probing_iter]; + // TODO atomic_ref::load if insert operator is present + auto const window_slots = this->storage_ref_[*probing_iter]; - auto const [state, intra_window_index] = [&]() { - auto res = detail::equal_result::UNEQUAL; - for (auto i = 0; i < window_size; ++i) { - res = this->predicate_.operator()(key, this->extract_key(window_slots[i])); - if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; } + for (int32_t i = 0; i < window_size and !empty; ++i) { + switch ( + this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + continue; + } + default: { + continue; + } } - // returns dummy index `-1` for UNEQUAL - return window_probing_results{res, -1}; - }(); - - // Find a match for the probe key, thus call the callback with an iterator to the entry - auto const equal = state == detail::equal_result::EQUAL; - if (equal) { - callback(const_iterator{&(*(storage_ref_.data() + *probing_iter))[intra_window_index]}); } - - if constexpr (not allows_duplicates) { - if (group.any(equal)) { return; } - } - - // Find an empty slot, meaning that the probe key isn't present in the container - auto const empty = state == detail::equal_result::EMPTY; if (group.any(empty)) { return; } ++probing_iter; From 6f6e5ff8b084044b5b166a451f1931cb27e565c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Tue, 25 Jun 2024 00:55:35 +0000 Subject: [PATCH 10/11] Rename callback to make the usage a bit more clear --- 
.../open_addressing/open_addressing_ref_impl.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 33d2c5322..673b84bf3 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -970,13 +970,13 @@ class open_addressing_ref_impl { * `key` to the callback. * * @tparam ProbeKey Input type which is convertible to 'key_type' - + @tparam Callback Callback functor or lambda + * @tparam CallbackOp Unary callback functor or device lambda * * @param key The key to search for * @param callback Function to call on every element found */ - template - __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback) const noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent()); @@ -1014,16 +1014,16 @@ class open_addressing_ref_impl { * each thread with a match will call the callback with its associated element. 
* * @tparam ProbeKey Input type which is convertible to 'key_type' - + @tparam Callback Callback functor or lambda + * @tparam CallbackOp Unary callback functor or device lambda * * @param group The Cooperative Group used to perform this operation * @param key The key to search for * @param callback Function to call on every element found */ - template + template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - Callback&& callback) const noexcept + CallbackOp&& callback) const noexcept { auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent()); bool empty = false; From 7cd072c26c2524e57f35f6621657ca6f565ef558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= Date: Tue, 25 Jun 2024 01:59:15 +0000 Subject: [PATCH 11/11] Add overload that allows for synchronizing the CG inbetween probing windows (required for shmem bounce buffer flushing during retrieve()) --- .../open_addressing_ref_impl.cuh | 78 +++++++++++++++++-- .../static_multiset/static_multiset_ref.inl | 62 ++++++++++++--- tests/static_multiset/for_each_test.cu | 28 +++++-- 3 files changed, 146 insertions(+), 22 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 673b84bf3..5df8f0110 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -973,10 +973,10 @@ class open_addressing_ref_impl { * @tparam CallbackOp Unary callback functor or device lambda * * @param key The key to search for - * @param callback Function to call on every element found + * @param callback_op Function to call on every element found */ template - __device__ void for_each(ProbeKey const& key, CallbackOp&& callback) const noexcept + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept { 
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent()); @@ -992,7 +992,7 @@ class open_addressing_ref_impl { return; } case detail::equal_result::EQUAL: { - callback(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); continue; } default: continue; @@ -1013,17 +1013,82 @@ class open_addressing_ref_impl { * callback if it finds a matching element. If multiple elements are found within the same group, * each thread with a match will call the callback with its associated element. * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * * @tparam ProbeKey Input type which is convertible to 'key_type' * @tparam CallbackOp Unary callback functor or device lambda * * @param group The Cooperative Group used to perform this operation * @param key The key to search for - * @param callback Function to call on every element found + * @param callback_op Function to call on every element found */ template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - CallbackOp&& callback) const noexcept + CallbackOp&& callback_op) const noexcept + { + auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent()); + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const window_slots = this->storage_ref_[*probing_iter]; + + for (int32_t i = 0; i < window_size and !empty; ++i) { + switch ( + this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + continue; + } + default: { + continue; + } + } + } + if 
(group.any(empty)) { return; } + + ++probing_iter; + } + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key and can additionally perform work that requires synchronizing the Cooperative Group + * performing this operation. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` in between probing steps, where the number of probing steps performed between + * synchronization points is capped by `window_size * cg_size`. The functor will be called right + * after the current probing window has been traversed. 
+ * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * @tparam SyncOp Functor or device lambda which accepts the current `group` object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + * @param sync_op Function that is allowed to synchronize `group` in between probing windows + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept { auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent()); bool empty = false; @@ -1040,7 +1105,7 @@ class open_addressing_ref_impl { continue; } case detail::equal_result::EQUAL: { - callback(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); continue; } default: { @@ -1048,6 +1113,7 @@ class open_addressing_ref_impl { } } } + sync_op(group); if (group.any(empty)) { return; } ++probing_iter; diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index 61d7b1ac6..1cd92d1ee 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -472,17 +472,17 @@ class operator_impl< * `key` to the callback. 
* * @tparam ProbeKey Input type which is convertible to 'key_type' - + @tparam Callback Callback functor or lambda + * @tparam CallbackOp Unary callback functor or device lambda * * @param key The key to search for - * @param callback Function to call on every element found + * @param callback_op Function to call on every element found */ - template - __device__ void for_each(ProbeKey const& key, Callback&& callback) const noexcept + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(key, std::forward(callback)); + ref_.impl_.for_each(key, std::forward(callback_op)); } /** @@ -496,21 +496,63 @@ class operator_impl< * callback if it finds a matching element. If multiple elements are found within the same group, * each thread with a match will call the callback with its associated element. * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key and can additionally perform work that requires synchronizing the Cooperative Group + * performing this operation. 
+ * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` in between probing steps, where the number of probing steps performed between + * synchronization points is capped by `window_size * cg_size`. The functor will be called right + * after the current probing window has been traversed. + * * @tparam ProbeKey Input type which is convertible to 'key_type' - + @tparam Callback Callback functor or lambda + * @tparam CallbackOp Unary callback functor or device lambda + * @tparam SyncOp Functor or device lambda which accepts the current `group` object * * @param group The Cooperative Group used to perform this operation * @param key The key to search for - * @param callback Function to call on every element found + * @param callback_op Function to call on every element found + * @param sync_op Function that is allowed to synchronize `group` in between probing windows */ - template + template __device__ void for_each(cooperative_groups::thread_block_tile const& group, ProbeKey const& key, - Callback&& callback) const noexcept + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(group, key, std::forward(callback)); + ref_.impl_.for_each( + group, key, std::forward(callback_op), std::forward(sync_op)); } }; diff --git a/tests/static_multiset/for_each_test.cu b/tests/static_multiset/for_each_test.cu index b0cb81091..1872586b7 
100644 --- a/tests/static_multiset/for_each_test.cu +++ b/tests/static_multiset/for_each_test.cu @@ -53,7 +53,7 @@ CUCO_KERNEL void for_each_check_scalar(Ref ref, } } -template +template CUCO_KERNEL void for_each_check_cooperative(Ref ref, InputIt first, std::size_t n, @@ -69,10 +69,19 @@ CUCO_KERNEL void for_each_check_cooperative(Ref ref, cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); auto const& key = *(first + idx); std::size_t thread_matches = 0; - ref.for_each(tile, key, [&] __device__(auto const it) { - if (ref.key_eq()(key, *it)) { thread_matches++; } - }); - tile.sync(); + if constexpr (Synced) { + ref.for_each( + tile, + key, + [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { thread_matches++; } + }, + [] __device__(auto const& group) { group.sync(); }); + } else { + ref.for_each(tile, key, [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { thread_matches++; } + }); + } auto const tile_matches = cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); if (tile_matches != multiplicity and tile.thread_rank() == 0) { @@ -131,7 +140,14 @@ TEMPLATE_TEST_CASE_SIG( } // test CG for_each - for_each_check_cooperative<<>>( + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + + // test synchronized CG for_each + for_each_check_cooperative<<>>( set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); CUCO_CUDA_TRY(cudaDeviceSynchronize()); REQUIRE(error_counter->load() == 0);