Skip to content

Commit 6c7dfa8

Browse files
authored
Merge pull request #2 from VectorDB-NTU/hotfix
Deal with the cases where nprobe > num_clusters.
2 parents c16aacc + e4149f8 commit 6c7dfa8

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

rabitqlib/index/ivf/ivf.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class IVF {
101101
void search(const float*, size_t, size_t, PID*, bool) const;
102102

103103
[[nodiscard]] size_t padded_dim() const { return this->padded_dim_; }
104+
105+
[[nodiscard]] size_t num_clusters() const { return this->num_cluster_; }
104106
};
105107

106108
inline IVF::IVF(size_t n, size_t dim, size_t cluster_num, size_t bits, RotatorType type)
@@ -374,6 +376,9 @@ inline void IVF::search(
374376
PID* __restrict__ results,
375377
bool use_hacc = true
376378
) const {
379+
if (nprobe > num_cluster_) {
380+
nprobe = num_cluster_;
381+
}
377382
std::vector<float> rotated_query(padded_dim_);
378383
this->rotator_->rotate(query, rotated_query.data());
379384

sample/ivf_rabitq_querying.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ int main(int argc, char** argv) {
8686
for (size_t r = 0; r < test_round; r++) {
8787
for (size_t l = 0; l < length; ++l) {
8888
size_t nprobe = nprobes[l];
89+
if (nprobe > ivf.num_clusters()) {
90+
std::cout << "nprobe " << nprobe << " is larger than number of clusters, ";
91+
std::cout << "will use nprobe = num_cluster (" << ivf.num_clusters() << ").\n";
92+
}
8993
size_t total_correct = 0;
9094
float total_time = 0;
9195
std::vector<PID> results(topk);

0 commit comments

Comments
 (0)