diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index e06e1d684..060c78737 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -962,6 +962,164 @@ class open_addressing_ref_impl { } } + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = this->probing_scheme_(key, this->storage_ref_.window_extent()); + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const window_slots = this->storage_ref_[*probing_iter]; + + for (int32_t i = 0; i < window_size; ++i) { + switch ( + this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { + case detail::equal_result::EMPTY: { + return; + } + case detail::equal_result::EQUAL: { + callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + continue; + } + default: continue; + } + } + ++probing_iter; + } + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent()); + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const window_slots = this->storage_ref_[*probing_iter]; + + for (int32_t i = 0; i < window_size and !empty; ++i) { + switch ( + this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + continue; + } + default: { + continue; + } + } + } + if (group.any(empty)) { return; } + + ++probing_iter; + } + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key and can additionally perform work that requires synchronizing the Cooperative Group + * performing this operation. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` inbetween probing steps, where the number of probing steps performed between + * synchronization points is capped by `window_size * cg_size`. The functor will be called right + * after the current probing window has been traversed. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * @tparam SyncOp Functor or device lambda which accepts the current `group` object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + * @param sync_op Function that is allowed to synchronize `group` inbetween probing windows + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept + { + auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.window_extent()); + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const window_slots = this->storage_ref_[*probing_iter]; + + for (int32_t i = 0; i < window_size and !empty; ++i) { + switch ( + this->predicate_.operator()(key, this->extract_key(window_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]}); + continue; + } + default: { + continue; + } + } + } + sync_op(group); + if (group.any(empty)) { return; } + + ++probing_iter; + } + } + /** * @brief Compares the content of the address `address` (old value) with the `expected` value and, * only if they are the same, sets the content of `address` to `desired`. diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index d4fadc9cb..1cd92d1ee 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -22,6 +22,8 @@ #include +#include + namespace cuco { template +class operator_impl< + op::for_each_tag, + static_multiset_ref> { + using base_type = static_multiset_ref; + using ref_type = + static_multiset_ref; + + static constexpr auto cg_size = base_type::cg_size; + + public: + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(key, std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key and can additionally perform work that requires synchronizing the Cooperative Group + * performing this operation. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` inbetween probing steps, where the number of probing steps performed between + * synchronization points is capped by `window_size * cg_size`. The functor will be called right + * after the current probing window has been traversed. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam CallbackOp Unary callback functor or device lambda + * @tparam SyncOp Functor or device lambda which accepts the current `group` object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + * @param sync_op Function that is allowed to synchronize `group` inbetween probing windows + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each( + group, key, std::forward(callback_op), std::forward(sync_op)); + } +}; + template + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +template +CUCO_KERNEL void for_each_check_scalar(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + static_assert(Ref::cg_size == 1, "Scalar test must have cg_size==1"); + auto const loop_stride = cuco::detail::grid_stride(); + auto idx = cuco::detail::global_thread_id(); + + while (idx < n) { + auto const& key = *(first + idx); + std::size_t matches = 0; + ref.for_each(key, [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { matches++; } + }); + if (matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); } + idx += loop_stride; + } +} + +template +CUCO_KERNEL void for_each_check_cooperative(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + auto const loop_stride = cuco::detail::grid_stride() / Ref::cg_size; + auto idx = cuco::detail::global_thread_id() / Ref::cg_size; + ; + + while (idx < n) { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + auto const& key = *(first + idx); + std::size_t thread_matches = 0; + if constexpr (Synced) { + ref.for_each( + tile, + key, + [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { thread_matches++; } + }, + [] __device__(auto const& group) { group.sync(); }); + } else { + ref.for_each(tile, key, [&] __device__(auto const it) { + if (ref.key_eq()(key, *it)) { thread_matches++; } + }); + } + auto const tile_matches = + cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); + if (tile_matches != multiplicity and tile.thread_rank() == 0) { + error_counter->fetch_add(1, cuda::memory_order_relaxed); + } + idx += loop_stride; + } +} + +TEMPLATE_TEST_CASE_SIG( + "static_multiset for_each tests", + "", + ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize), + (int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_t num_unique_keys{400}; + constexpr size_t key_multiplicity{5}; + constexpr size_t num_keys{num_unique_keys * key_multiplicity}; + + using probe = std::conditional_t>, + cuco::double_hashing>>; + + auto set = + cuco::static_multiset{num_keys, cuco::empty_key{-1}, {}, probe{}, {}, cuco::storage<2>{}}; + + auto unique_keys_begin = thrust::counting_iterator(0); + auto gen_duplicate_keys = cuda::proclaim_return_type( + [] __device__(auto const& k) { return static_cast(k % num_unique_keys); }); + auto keys_begin = thrust::make_transform_iterator(unique_keys_begin, gen_duplicate_keys); + + set.insert(keys_begin, keys_begin + num_keys); + + using error_counter_type = cuda::atomic; + error_counter_type* error_counter; + CUCO_CUDA_TRY(cudaMallocHost(&error_counter, sizeof(error_counter_type))); + new (error_counter) error_counter_type{0}; + + auto const grid_size = cuco::detail::grid_size(num_unique_keys, CGSize); + auto const block_size = cuco::detail::default_block_size(); + + // test scalar for_each + if constexpr (CGSize == 1) { + for_each_check_scalar<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + } + + // test CG for_each + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + + // test synchronized CG for_each + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + + CUCO_CUDA_TRY(cudaFreeHost(error_counter)); +} \ No newline at end of file