|
3 | 3 | #include <immintrin.h> |
4 | 4 | #include <omp.h> |
5 | 5 |
|
| 6 | +#include <algorithm> |
6 | 7 | #include <cassert> |
7 | 8 | #include <cmath> |
8 | 9 | #include <cstddef> |
@@ -118,9 +119,7 @@ inline IVF::IVF(size_t n, size_t dim, size_t cluster_num, size_t bits, RotatorTy |
118 | 119 | std::cerr.flush(); |
119 | 120 | exit(1); |
120 | 121 | }; |
121 | | - rotator_ = choose_rotator<float>( |
122 | | - dim, RotatorType::FhtKacRotator, round_up_to_multiple(dim_, 64) |
123 | | - ); |
| 122 | + rotator_ = choose_rotator<float>(dim, type, round_up_to_multiple(dim_, 64)); |
124 | 123 | padded_dim_ = rotator_->size(); |
125 | 124 | /* check size */ |
126 | 125 | assert(padded_dim_ % 64 == 0); |
@@ -332,9 +331,7 @@ inline void IVF::load(const char* filename) { |
332 | 331 | input.read(reinterpret_cast<char*>(&this->ex_bits_), sizeof(size_t)); |
333 | 332 | input.read(reinterpret_cast<char*>(&type_), sizeof(type_)); |
334 | 333 |
|
335 | | - rotator_ = choose_rotator<float>( |
336 | | - dim_, RotatorType::FhtKacRotator, round_up_to_multiple(dim_, 64) |
337 | | - ); |
| 334 | + rotator_ = choose_rotator<float>(dim_, type_, round_up_to_multiple(dim_, 64)); |
338 | 335 | padded_dim_ = rotator_->size(); |
339 | 336 |
|
340 | 337 | /* Load number of vectors of each cluster */ |
@@ -376,12 +373,12 @@ inline void IVF::search( |
376 | 373 | PID* __restrict__ results, |
377 | 374 | bool use_hacc = true |
378 | 375 | ) const { |
379 | | - if (nprobe > num_cluster_) { |
380 | | - nprobe = num_cluster_; |
381 | | - } |
| 376 | + nprobe = std::min(nprobe, num_cluster_); // corner case |
382 | 377 | std::vector<float> rotated_query(padded_dim_); |
383 | 378 | this->rotator_->rotate(query, rotated_query.data()); |
384 | 379 |
|
| 380 | + std::cout << l2norm_sqr(query, dim_) << '\t' << l2norm_sqr(rotated_query.data(), padded_dim_) << '\n'; |
| 381 | + |
385 | 382 | // use initer to get closest nprobe centroids |
386 | 383 | std::vector<AnnCandidate<float>> centroid_dist(nprobe); |
387 | 384 | this->initer_->centroids_distances(rotated_query.data(), nprobe, centroid_dist); |
|
0 commit comments