-
Notifications
You must be signed in to change notification settings - Fork 195
[BUG] Fix CAGRA search recall with a graph built by NN Descent #819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
92b70f9
cac426b
4982692
04cde70
0cfd92d
211eb5d
c55c01a
d9dfe5f
d70e92a
73a9769
4dcb4dd
f0d93ab
669ae25
a787da1
dc6a3f3
8fccf58
88638fc
6d84fc1
ed11f70
751c13b
48cc3db
c603d3d
d5fea47
0ecb878
e53df73
e8c9dea
663116f
8c5b8eb
dfc239d
47e4c8f
025d8c7
228824a
129dfdd
fdf7f69
3ee213d
5c76b43
e7efe71
0cf5a29
a623bd2
a13b72e
1995529
876f17e
3ef14a1
0668816
c191c29
bce9a3a
0faa5a6
88584d2
390b21d
73d6311
7dddecb
5f22011
2255584
5386cab
a690d2a
819b621
b4d348c
b116082
afcab42
420698d
820e7ad
1d1532c
4373f1f
8f4726a
9350ddb
d64652f
96c6594
b7ebda8
456d7a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -39,6 +39,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <mma.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <limits> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <numeric> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <optional> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <queue> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <random> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1166,8 +1167,14 @@ GnndGraph<Index_t>::GnndGraph(raft::resources const& res, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h_list_sizes_old{raft::make_pinned_vector<int2, size_t>(res, nrow)} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // node_degree must be a multiple of segment_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert(node_degree % segment_size == 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert(internal_node_degree % segment_size == 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RAFT_EXPECTS(node_degree % segment_size == 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "node_degree (%u) %% segment_size (%u) == 0", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint32_t>(node_degree), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint32_t>(segment_size)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RAFT_EXPECTS(internal_node_degree % segment_size == 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "internal_node_degree (%u) %% segment_size (%u) == 0", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint32_t>(internal_node_degree), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint32_t>(segment_size)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_segments = node_degree / segment_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // To save the CPU memory, graph should be allocated by external function | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1198,38 +1205,46 @@ void GnndGraph<Index_t>::sample_graph_new(InternalID_t<Index_t>* new_neighbors, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Initialize the graph with random neighbors and apply the segmentation rule. Split the neighbor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // list into num_segments segments. A neighbor with index v is placed into segment (v % | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // num_segments). The details are in Sec 4.3 in H Wang et.al. "Fast k-NN Graph Construction by GPU | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // based NN-Descent". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename Index_t> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void GnndGraph<Index_t>::init_random_graph() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t seg_idx = 0; seg_idx < static_cast<size_t>(num_segments); seg_idx++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // random sequence (range: 0~nrow) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // segment_x stores neighbors which id % num_segments == x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<Index_t> rand_seq((nrow + num_segments - 1) / num_segments); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::iota(rand_seq.begin(), rand_seq.end(), 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto gen = std::default_random_engine{seg_idx}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::shuffle(rand_seq.begin(), rand_seq.end(), gen); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const auto extended_nrows = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raft::round_up_safe(static_cast<uint32_t>(nrow), static_cast<uint32_t>(num_segments)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (uint32_t seg_id = 0; seg_id < static_cast<uint32_t>(num_segments); seg_id++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const auto actual_segment_size = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::min(static_cast<uint64_t>(segment_size), node_degree - seg_id * segment_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint64_t stride = nrow / segment_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while (std::gcd(extended_nrows, stride) != 1 || std::gcd(actual_segment_size, stride) != 1) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stride++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma omp parallel for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < nrow; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t base_idx = i * node_degree + seg_idx * segment_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto h_neighbor_list = h_graph + base_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto h_dist_list = h_dists.data_handle() + base_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t idx = base_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t self_in_this_seg = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t j = 0; j < static_cast<size_t>(segment_size); j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ((size_t)id == i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self_in_this_seg = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (uint64_t i = 0; i < nrow; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Generate a starting index. The node ((i + 1) % nrow) will be included in the neighbor list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // of node i. This rule guarantees the connectivity of the graph. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint64_t id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ((i + 1) % num_segments == seg_id) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id = i + 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (id >= nrow) { id = seg_id; } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id = (i + 1) * num_segments + seg_id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (uint32_t j = 0; j < actual_segment_size; j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while (id >= nrow || id == i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id = (id + stride * num_segments) % extended_nrows; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h_neighbor_list[j].id_with_flag() = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| j < (rand_seq.size() - self_in_this_seg) && size_t(id) < nrow | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ? id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : std::numeric_limits<Index_t>::max(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h_dist_list[j] = std::numeric_limits<DistData_t>::max(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const auto store_index = i * node_degree + seg_id * segment_size + j; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h_graph[store_index].id_with_flag() = id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| h_dists.data_handle()[store_index] = std::numeric_limits<DistData_t>::max(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id = (id + num_segments * stride) % nrow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1238
to
1248
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Confirm update_graph and sample_graph rely on `id % num_segments` to address segment slots.
rg -nP -C3 '%\s*num_segments' cpp/src/neighbors/detail/nn_descent.cuhRepository: rapidsai/cuvs Length of output: 1085 🏁 Script executed: #!/bin/bash
# Look at the full context around lines 1224-1234
sed -n '1200,1250p' cpp/src/neighbors/detail/nn_descent.cuhRepository: rapidsai/cuvs Length of output: 1927 🏁 Script executed: #!/bin/bash
# Find extended_nrows calculation
rg -nP 'extended_nrows\s*=' cpp/src/neighbors/detail/nn_descent.cuhRepository: rapidsai/cuvs Length of output: 92 🏁 Script executed: #!/bin/bash
# Find stride calculation and how it relates to num_segments, nrow
rg -nP 'stride\s*=' cpp/src/neighbors/detail/nn_descent.cuhRepository: rapidsai/cuvs Length of output: 105 🏁 Script executed: #!/bin/bash
# Find update_graph and sample_graph implementations to confirm they use id % num_segments
rg -nP -A10 'void update_graph\(|void sample_graph\(' cpp/src/neighbors/detail/nn_descent.cuh | head -100Repository: rapidsai/cuvs Length of output: 39 🏁 Script executed: #!/bin/bash
# Search for where h_graph is read, especially in update_graph and sample_graph
rg -nP 'update_graph|sample_graph' cpp/src/neighbors/detail/nn_descent.cuh | head -20Repository: rapidsai/cuvs Length of output: 678 🏁 Script executed: #!/bin/bash
# Find where the segment-based indexing is used - look for how segments are accessed
rg -nP -B5 -A5 'seg_idx.*segment_size|segment_id.*segment_size' cpp/src/neighbors/detail/nn_descent.cuhRepository: rapidsai/cuvs Length of output: 781 🏁 Script executed: #!/bin/bash
# Look at the actual fix more carefully - check if storing values in [nrow, extended_nrows) would break anything
rg -nP 'h_graph\[.*\]' cpp/src/neighbors/detail/nn_descent.cuh | head -30Repository: rapidsai/cuvs Length of output: 334 🏁 Script executed: #!/bin/bash
# Check if any code assumes id < nrow or validates boundaries when reading h_graph
rg -nP -B3 -A3 'id\(\)|\.id\(\)' cpp/src/neighbors/detail/nn_descent.cuh | grep -A3 -B3 'nrow\|>= nrow\|< nrow' | head -50Repository: rapidsai/cuvs Length of output: 756 🏁 Script executed: #!/bin/bash
# Look at the full sample_graph implementation to see how it uses h_graph entries
sed -n '1240,1285p' cpp/src/neighbors/detail/nn_descent.cuhRepository: rapidsai/cuvs Length of output: 1838 Fix modulus mismatch in post-store advance to preserve the segmentation invariant. The inner Concrete example: Use Suggested fix- id = (id + num_segments * stride) % nrow;
+ id = (id + num_segments * stride) % extended_nrows;📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1330,7 +1345,6 @@ void GnndGraph<Index_t>::clear() | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename Index_t> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| GnndGraph<Index_t>::~GnndGraph() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert(h_graph == nullptr); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename Data_t, typename Index_t> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1712,17 +1726,47 @@ void GNND<Data_t, Index_t>::build(Data_t* data, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Copy the output graph while removing duplicates. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma omp parallel for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < (size_t)nrow_; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t j = 0; j < build_config_.node_degree; j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t idx = i * graph_.node_degree + j; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int id = graph_.h_graph[idx].id(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (id < static_cast<int>(nrow_)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| graph_shrink_buffer[i * build_config_.node_degree + j] = id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| graph_shrink_buffer[i * build_config_.node_degree + j] = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cuvs::neighbors::detail::device::xorshift64(idx) % nrow_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto output_neighbor_list_ptr = graph_shrink_buffer + i * build_config_.node_degree; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t out_j = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Copy neighbor list while removing duplicates. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t in_j = 0; in_j < build_config_.node_degree; in_j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t idx = graph_.h_graph[i * graph_.node_degree + in_j].id(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool dup = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t exi_j = 0; exi_j < out_j; exi_j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (static_cast<decltype(idx)>(output_neighbor_list_ptr[exi_j]) == idx || i == idx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dup = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!dup) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_neighbor_list_ptr[out_j] = idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_j++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Fill with random nodes if the length of the filled neighbor list is less than the degree. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t j = out_j; j < build_config_.node_degree; j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint64_t rnd = static_cast<uint64_t>(i * build_config_.node_degree + j + 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint64_t idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool dup = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| do { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rnd = cuvs::neighbors::detail::device::xorshift64(rnd); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx = rnd % nrow_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dup = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t exi_j = 0; exi_j < j; exi_j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (static_cast<decltype(idx)>(output_neighbor_list_ptr[exi_j]) == idx || i == idx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dup = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } while (dup); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_neighbor_list_ptr[j] = static_cast<int>(idx); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| graph_.h_graph = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1757,6 +1801,18 @@ void build(raft::resources const& res, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto int_graph = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raft::make_host_matrix<int, int64_t, raft::row_major>(dataset.extent(0), extended_graph_degree); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // When the graph will be a complete graph, output it without NND process for better performance. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (static_cast<size_t>(dataset.extent(0) - 1) == graph_degree && (!params.return_distances)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto graph = idx.graph().data_handle(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma omp parallel for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < static_cast<size_t>(dataset.extent(0)); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t j = 0; j < graph_degree; j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| graph[i * graph_degree + j] = (i + j + 1) % dataset.extent(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| GNND<const T, int> nnd(res, build_config); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (idx.distances().has_value() || !params.return_distances) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.