From 8a75086b934181cb1e188c76b7f8ad142ecf9a61 Mon Sep 17 00:00:00 2001 From: Shanfeng Pang Date: Thu, 3 Jul 2025 16:00:28 +0800 Subject: [PATCH 1/2] Integrate SIMDe to support cross-platform SIMD operations. --- .gitignore | 1 + .gitmodules | 3 + CMakeLists.txt | 24 +- cmake/cpu_features.cmake | 54 +++ contrib/CMakeLists.txt | 6 + contrib/README.md | 306 +++++++++++++++ contrib/dataset.h | 497 ++++++++++++++++++++++++ contrib/dataset_example.cpp | 119 ++++++ contrib/test_normalization.cpp | 83 ++++ example.sh | 17 +- rabitqlib/fastscan/fastscan.hpp | 128 ++---- rabitqlib/fastscan/highacc_fastscan.hpp | 105 +++-- rabitqlib/quantization/pack_excode.hpp | 149 +++---- rabitqlib/third/simde | 1 + rabitqlib/utils/rotator.hpp | 52 +-- rabitqlib/utils/simde-utils.hpp | 69 ++++ rabitqlib/utils/space.hpp | 434 ++++++++++----------- rabitqlib/utils/warmup_space.hpp | 34 +- sample/ivf_rabitq_querying.cpp | 20 +- tests/CMakeLists.txt | 19 + tests/test_init.cpp | 5 + tests/test_ivf_rabitq.cpp | 102 +++++ 22 files changed, 1700 insertions(+), 528 deletions(-) create mode 100644 .gitmodules create mode 100644 cmake/cpu_features.cmake create mode 100644 contrib/CMakeLists.txt create mode 100644 contrib/README.md create mode 100644 contrib/dataset.h create mode 100644 contrib/dataset_example.cpp create mode 100644 contrib/test_normalization.cpp create mode 160000 rabitqlib/third/simde create mode 100644 rabitqlib/utils/simde-utils.hpp create mode 100644 tests/CMakeLists.txt create mode 100644 tests/test_init.cpp create mode 100644 tests/test_ivf_rabitq.cpp diff --git a/.gitignore b/.gitignore index 6c2a39d..b7df886 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ data/ *.pyc .clang-* +.vscode/ # macOS garbages .DS_Store diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..928ca9a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "rabitqlib/third/simde"] + path = rabitqlib/third/simde + url = https://github.com/simd-everywhere/simde.git diff --git a/CMakeLists.txt b/CMakeLists.txt index c3c5927..bae94af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,11 +5,31 @@ project(RaBitQLib LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +include(cmake/cpu_features.cmake) + +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + message(STATUS "Building in debug mode") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g") +elseif(CMAKE_BUILD_TYPE STREQUAL "Release") + message(STATUS "Building in release mode") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O2") +endif() include_directories(${PROJECT_SOURCE_DIR}/rabitqlib) +include_directories(${PROJECT_SOURCE_DIR}/rabitqlib/third/simde) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -SET(CMAKE_CXX_FLAGS "-Wall -Ofast -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions") +SET(CMAKE_CXX_FLAGS "-Wall -Ofast -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions -w") +SET(CMAKE_C_FLAGS "-Wall -Ofast -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions -w") + +add_subdirectory(sample) + +option(BUILD_TESTS "Build tests" ON) -add_subdirectory(sample) \ No newline at end of file +if (BUILD_TESTS) + add_subdirectory(contrib) + add_subdirectory(tests) +endif() \ No newline at end of file diff --git a/cmake/cpu_features.cmake b/cmake/cpu_features.cmake new file mode 100644 index 0000000..3b75eb3 --- /dev/null +++ b/cmake/cpu_features.cmake @@ -0,0 +1,54 @@ +# check if the cpu supports avx512 
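+# Note: detection below parses `lscpu` output, so it only takes effect on
+# x86_64 Linux hosts; for each reported extension (avx512f, avx2, avx,
+# sse4_2, sse4_1) the matching -m compiler flag is appended. On other
+# processors or systems no extra flags are added and SIMDe's portable
+# fallbacks are used instead.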
+if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + execute_process(COMMAND lscpu OUTPUT_VARIABLE CPU_INFO) + string(REGEX MATCH "avx512f" AVX512_SUPPORT "${CPU_INFO}") + if(AVX512_SUPPORT) + set(AVX512_SUPPORT ON) + endif() + + string(REGEX MATCH "avx2" AVX2_SUPPORT "${CPU_INFO}") + if(AVX2_SUPPORT) + set(AVX2_SUPPORT ON) + endif() + + string(REGEX MATCH "avx" AVX_SUPPORT "${CPU_INFO}") + if(AVX_SUPPORT) + set(AVX_SUPPORT ON) + endif() + + string(REGEX MATCH "sse4_2" SSE4_2_SUPPORT "${CPU_INFO}") + if(SSE4_2_SUPPORT) + set(SSE4_2_SUPPORT ON) + endif() + + string(REGEX MATCH "sse4_1" SSE4_1_SUPPORT "${CPU_INFO}") + if(SSE4_1_SUPPORT) + set(SSE4_1_SUPPORT ON) + endif() +endif() + + +if(AVX512_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f") +endif() + +if(AVX2_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2") +endif() + +if(AVX_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx") +endif() + +if(SSE4_2_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.2") +endif() + +if(SSE4_1_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") +endif() \ No newline at end of file diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt new file mode 100644 index 0000000..08e7087 --- /dev/null +++ b/contrib/CMakeLists.txt @@ -0,0 +1,6 @@ +if (BUILD_TESTS) + add_executable(dataset_example dataset_example.cpp) + # target_link_libraries(dataset_example PRIVATE rabitqlib) + target_include_directories(dataset_example PRIVATE ${PROJECT_SOURCE_DIR}/contrib) + target_include_directories(dataset_example PRIVATE ${PROJECT_SOURCE_DIR}/) +endif() \ No newline at end of file diff --git a/contrib/README.md b/contrib/README.md new file mode 100644 index 0000000..d688088 --- /dev/null +++ b/contrib/README.md @@ -0,0 +1,306 @@ +# Vector Dataset Generator + +This module provides a fully functional vector dataset generator specifically designed for building IVF indexes in the RaBitQ-Library. + +## Features + +- **Multiple Dataset Types**: Supports various generation methods such as random datasets, Gaussian mixture models, spherical clusters, etc. +- **Flexible Parameter Configuration**: Customizable parameters including number of data points, dimension, number of clusters, number of queries, etc. +- **Data Quality Assurance**: Built-in data validation and statistics computation +- **Batch Query Generation**: Automatically generates 10% of data points as normalized query vectors +- **Easy Integration**: Simple interfaces for seamless integration with the `IVF::construct` function + +## Main Class and Methods + +### `Dataset` Class + +The `Dataset` class is now an object-oriented design that encapsulates all dataset information and provides convenient access methods. + +```cpp +class Dataset { +public: + enum class DatasetType { + Random, + GaussianMixture, + SphericalCluster + }; + + // Constructor + Dataset( + size_t num_points, // Number of data points + size_t dim, // Vector dimension + size_t num_clusters, // Number of clusters + size_t num_queries, // Number of query vectors + DatasetType type, // Dataset generation type + uint32_t seed = 42 // Random seed + ); +}; +``` + +### Dataset Generation Types + +#### 1. 
Random Dataset (`DatasetType::Random`) + +```cpp +auto dataset = Dataset( + 10000, // num_points + 128, // dim + 10, // num_clusters + 1000, // num_queries + Dataset::DatasetType::Random, // type + 42 // seed +); +``` + +**Features**: +- Cluster centers are randomly distributed in space +- Data points are generated around centroids with Gaussian noise +- Suitable for basic functionality testing + +#### 2. Gaussian Mixture Model Dataset (`DatasetType::GaussianMixture`) + +```cpp +auto dataset = Dataset( + 50000, // num_points + 256, // dim + 20, // num_clusters + 5000, // num_queries + Dataset::DatasetType::GaussianMixture, // type + 42 // seed +); +``` + +**Features**: +- Centroids uniformly distributed in a hypercube +- Each cluster has its own standard deviation +- Data points are evenly distributed across clusters +- Closer to real-world cluster distribution + +#### 3. Spherical Cluster Dataset (`DatasetType::SphericalCluster`) + +```cpp +auto dataset = Dataset( + 8000, // num_points + 64, // dim + 8, // num_clusters + 800, // num_queries + Dataset::DatasetType::SphericalCluster, // type + 42 // seed +); +``` + +**Features**: +- Cluster centers randomly distributed in space +- Data points are located in spherical regions around centroids +- Different clusters may have different radii +- Useful for testing spatial distribution algorithms + +### Data Access Methods + +#### Basic Information + +```cpp +size_t num_points = dataset.get_num_points(); +size_t dim = dataset.get_dim(); +size_t num_clusters = dataset.get_num_clusters(); +``` + +#### Data Pointers + +```cpp +const float* data_ptr = dataset.get_data_ptr(); +const float* centroids_ptr = dataset.get_centroids_ptr(); +const PID* cluster_ids_ptr = dataset.get_cluster_ids_ptr(); +const float* queries_ptr = dataset.get_queries_ptr(); +``` + +#### Data Validation + +```cpp +bool is_valid = dataset.validate_dataset(); +``` + +#### Dataset Statistics + +```cpp +std::string stats = dataset.get_dataset_stats(); +std::cout << stats << std::endl; +``` + +## Usage Example + +### Basic Usage + +```cpp +#include "dataset.h" + +using namespace rabitqlib::test; + +int main() { + // Generate dataset with 10% queries + Dataset dataset(10000, 128, 10, 1000, Dataset::DatasetType::GaussianMixture); + + // Validate dataset + if (!dataset.validate_dataset()) { + std::cerr << "Dataset generation failed!" 
<< std::endl; + return -1; + } + + // Access data + const float* data = dataset.get_data_ptr(); + const float* centroids = dataset.get_centroids_ptr(); + const PID* cluster_ids = dataset.get_cluster_ids_ptr(); + const float* queries = dataset.get_queries_ptr(); + + // Print statistics + std::cout << dataset.get_dataset_stats() << std::endl; + + return 0; +} +``` + +### Usage with IVF Index + +```cpp +#include "dataset.h" +#include + +using namespace rabitqlib::test; + +int main() { + // Generate test dataset + Dataset dataset(50000, 256, 20, 5000, Dataset::DatasetType::GaussianMixture); + + // Access data pointer + const float* data_ptr = dataset.get_data_ptr(); + const float* queries_ptr = dataset.get_queries_ptr(); + + // Create IVF index + rabitqlib::index::ivf::IVF ivf_index; + + // Build the index (adjust according to the actual IVF::construct interface) + // ivf_index.construct(data_ptr, dataset.get_num_points(), dataset.get_dim(), ...); + + // Test queries + for (size_t i = 0; i < dataset.get_num_queries(); ++i) { + const float* query = queries_ptr + i * dataset.get_dim(); + // Perform search with query + // auto results = ivf_index.search(query, k); + } + + return 0; +} +``` + +### Batch Query Testing + +```cpp +#include "dataset.h" + +using namespace rabitqlib::test; + +int main() { + Dataset dataset(10000, 128, 10, 1000, Dataset::DatasetType::SphericalCluster); + + const float* queries = dataset.get_queries_ptr(); + size_t num_queries = dataset.get_num_queries(); + size_t dim = dataset.get_dim(); + + // Test each query vector + for (size_t i = 0; i < num_queries; ++i) { + const float* query = queries + i * dim; + + // Verify query vector is normalized + float norm = 0.0f; + for (size_t j = 0; j < dim; ++j) { + norm += query[j] * query[j]; + } + norm = std::sqrt(norm); + + std::cout << "Query " << i << " norm: " << norm << std::endl; + } + + return 0; +} +``` + +### Data Quality Analysis + +```cpp +auto dataset = Dataset(8000, 64, 8, 800, Dataset::DatasetType::SphericalCluster); + +// Analyze value range +const float* data = dataset.get_data_ptr(); +float min_val = data[0], max_val = data[0]; +for (size_t i = 1; i < dataset.get_num_points() * dataset.get_dim(); ++i) { + min_val = std::min(min_val, data[i]); + max_val = std::max(max_val, data[i]); +} + +std::cout << "Value range: [" << min_val << ", " << max_val << "]" << std::endl; + +// Analyze cluster distribution +std::vector cluster_counts(dataset.get_num_clusters(), 0); +const PID* cluster_ids = dataset.get_cluster_ids_ptr(); +for (size_t i = 0; i < dataset.get_num_points(); ++i) { + cluster_counts[cluster_ids[i]]++; +} + +for (size_t i = 0; i < dataset.get_num_clusters(); ++i) { + std::cout << "Cluster " << i << ": " << cluster_counts[i] << " points" << std::endl; +} + +// Analyze query distribution +const float* queries = dataset.get_queries_ptr(); +std::cout << "Number of queries: " << dataset.get_num_queries() << std::endl; +``` + +## Parameter Recommendations + +### Number of Data Points +- **Small-scale test**: 1,000 - 10,000 points +- **Medium-scale**: 10,000 - 100,000 points +- **Large-scale test**: 100,000+ points + +### Vector Dimension +- **Low-dimensional**: 64 - 128 dimensions +- **Medium-dimensional**: 128 - 512 dimensions +- **High-dimensional**: 512+ dimensions + +### Number of Clusters +- **Default**: `num_points / 1000` +- **Dense clustering**: `num_points / 500` +- **Sparse clustering**: `num_points / 2000` + +### Number of Queries +- **Recommended**: `num_points / 10` (10% of data points) +- **Light 
testing**: `num_points / 20` (5% of data points) +- **Heavy testing**: `num_points / 5` (20% of data points) + +## Notes + +1. **Memory Usage**: Large datasets can consume a lot of memory. Ensure your system has enough RAM. +2. **Random Seed**: Use a fixed seed for reproducible results. +3. **Data Validation**: Always call `validate_dataset()` before using the data. +4. **Cluster IDs**: Cluster IDs start from 0, in the range [0, num_clusters - 1] +5. **Query Vectors**: All query vectors are automatically normalized to unit length +6. **Data Layout**: All vectors (data, centroids, queries) are stored in row-major format + +## Compilation and Execution + +Ensure your project has correctly configured dependencies for RaBitQ-Library. Then compile and run the example: + +```bash +g++ -std=c++17 -O3 -mavx2 dataset_example.cpp -o dataset_example +./dataset_example +``` + +## Extension Ideas + +You can extend this dataset generator with the following features: + +1. **Add New Dataset Types**: Implement new generation methods by adding new enum values and corresponding generation functions +2. **Custom Distributions**: Modify distribution parameters in existing methods +3. **Data Export**: Add functionality to save datasets to files +4. **Visualization**: Add visualization tools (for low-dimensional data) +5. **Query Generation Strategies**: Implement different query generation strategies (e.g., based on cluster centroids, edge cases, etc.) diff --git a/contrib/dataset.h b/contrib/dataset.h new file mode 100644 index 0000000..2cf4478 --- /dev/null +++ b/contrib/dataset.h @@ -0,0 +1,497 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rabitqlib { + +/** + * @brief calculate L2 distance between two vectors, copied from hnswlib/space_l2.h + * @param pVect1 pointer to the first vector + * @param pVect2 pointer to the second vector + * @param qty_ptr pointer to the dimension of the vectors + * @return L2 distance + */ +static float +L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + float res = 0; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; + res += t * t; + } + return (res); +} + +namespace test { + +using data_type = rabitqlib::RowMajorArray; + + +/** + * @brief normalize a vector + * @param vec vector to normalize + * @param dim dimension of the vector + */ +void normalize_vector(float* vec, size_t dim) { + float norm = 0.0f; + for (size_t i = 0; i < dim; ++i) { + norm += vec[i] * vec[i]; + } + norm = std::sqrt(norm); + if (norm > 1e-6f) { + for (size_t i = 0; i < dim; ++i) { + vec[i] /= norm; + } + } +} + +class Dataset { +public: + enum class DatasetType { + Random, + GaussianMixture, + SphericalCluster + }; + + Dataset( + size_t num_points, + size_t dim, + size_t num_clusters, + size_t num_queries, + DatasetType type = DatasetType::Random, + uint32_t seed = 42 + ) + : data(num_points * dim) + , centroids(num_clusters * dim) + , cluster_ids(num_points) + , queries(num_queries * dim) + , num_points(num_points) + , dim(dim) + , num_clusters(num_clusters) + , num_queries(num_queries) + , seed(seed) { + switch (type) { + case DatasetType::Random: + generate_random_dataset(); + break; + case DatasetType::GaussianMixture: + generate_gaussian_mixture_dataset(); + break; + case DatasetType::SphericalCluster: + generate_spherical_cluster_dataset(); + break; + 
default: + throw std::invalid_argument("Invalid dataset type"); + } + } + ~Dataset() = default; + + /** + * @brief get number of points + * @return number of points + */ + size_t get_num_points() { + return num_points; + } + + /** + * @brief get dimension + * @return dimension + */ + size_t get_dim() { + return dim; + } + + /** + * @brief get number of clusters + * @return number of clusters + */ + size_t get_num_clusters() { + return num_clusters; + } + + /** + * @brief get data pointer + * @return data pointer + */ + const float* get_data_ptr() { + return data.data(); + } + + /** + * @brief get centroids pointer + * @return centroids pointer + */ + const float* get_centroids_ptr() { + return centroids.data(); + } + + /** + * @brief get cluster ids pointer + * @return cluster ids pointer + */ + const PID* get_cluster_ids_ptr() { + return cluster_ids.data(); + } + + /** + * @brief get queries pointer + * @return queries pointer + */ + const float* get_queries_ptr() { + return queries.data(); + } + + /** + * @brief get number of queries + * @return number of queries + */ + size_t get_num_queries() { + return num_queries; + } + + /** + * @brief validate dataset quality + * @return whether valid + */ + bool validate_dataset() { + if (data.empty() || centroids.empty() || cluster_ids.empty()) { + return false; + } + + if (data.size() != num_points * dim) { + return false; + } + + if (centroids.size() != num_clusters * dim) { + return false; + } + + if (cluster_ids.size() != num_points) { + return false; + } + + if (queries.size() != num_queries * dim) { + return false; + } + + // check if cluster ids are in valid range + for (PID cluster_id : cluster_ids) { + if (cluster_id >= static_cast(num_clusters)) { + return false; + } + } + + return true; + } + + /** + * @brief calculate dataset statistics + * @return statistics string + */ + std::string get_dataset_stats() { + std::string stats = "Dataset Statistics:\n"; + stats += " Points: " + std::to_string(num_points) + "\n"; + stats += " Dimension: " + std::to_string(dim) + "\n"; + stats += " Clusters: " + std::to_string(num_clusters) + "\n"; + + // calculate number of points per cluster + std::vector cluster_counts(num_clusters, 0); + for (PID cluster_id : cluster_ids) { + cluster_counts[cluster_id]++; + } + + stats += " Points per cluster:\n"; + for (size_t i = 0; i < num_clusters; ++i) { + stats += " Cluster " + std::to_string(i) + ": " + + std::to_string(cluster_counts[i]) + "\n"; + } + + return stats; + } + + /** + * @brief get results + * @param metric_type metric type, currently only support L2 + * @param topk top k results + * @return results + */ + std::unordered_map> get_results(MetricType metric_type = METRIC_L2, size_t topk = 10) { + auto cmp = [](const std::pair& a, const std::pair& b) { + return a.second < b.second; + }; + std::priority_queue, std::vector>, decltype(cmp)> pq(cmp); + std::unordered_map> results; + for (size_t i = 0; i < num_queries; i++) { + for (size_t j = 0; j < num_points; j++) { + float distance = L2Sqr(data.data() + j * dim, queries.data() + i * dim, &dim); + if (pq.size() < topk) { + pq.push(std::make_pair(j, distance)); + } else { + if (pq.top().second > distance) { + pq.pop(); + pq.push(std::make_pair(j, distance)); + } + } + } + auto candidate = std::set(); + while (!pq.empty()) { + candidate.insert(pq.top().first); + pq.pop(); + } + results.insert(std::make_pair(i, candidate)); + } + return results; + } + +private: + /** + * @brief generate random dataset + */ + void generate_random_dataset() { + if 
(num_clusters == 0) { + num_clusters = std::max(1UL, num_points / 1000); // default number of clusters + } + + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, 1.0f); + + // generate centroids + centroids.resize(num_clusters * dim); + for (size_t i = 0; i < num_clusters; ++i) { + for (size_t j = 0; j < dim; ++j) { + centroids[i * dim + j] = dist(gen) * 2.0f; // centroids distribution + } + } + + // generate data points + data.resize(num_points * dim); + cluster_ids.resize(num_points); + + std::uniform_int_distribution cluster_dist(0, num_clusters - 1); + std::normal_distribution noise_dist(0.0f, 0.3f); // add noise + + for (size_t i = 0; i < num_points; ++i) { + size_t cluster_id = cluster_dist(gen); + cluster_ids[i] = static_cast(cluster_id); + + // generate data points based on centroids + for (size_t j = 0; j < dim; ++j) { + float centroid_val = centroids[cluster_id * dim + j]; + data[i * dim + j] = centroid_val + noise_dist(gen); + } + + // normalize the vector + float* vec_ptr = &data[i * dim]; + normalize_vector(vec_ptr, dim); + } + + // normalize centroids + for (size_t i = 0; i < num_clusters; ++i) { + float* centroid_ptr = ¢roids[i * dim]; + normalize_vector(centroid_ptr, dim); + } + + // generate batch queries (10% of num_points, at least 1) + num_queries = std::max(1, num_points / 10); + queries.resize(num_queries * dim); + std::normal_distribution query_dist(0.0f, 1.0f); + for (size_t i = 0; i < num_queries; ++i) { + float* qptr = &queries[i * dim]; + for (size_t j = 0; j < dim; ++j) { + qptr[j] = query_dist(gen); + } + normalize_vector(qptr, dim); + } + + } + + /** + * @brief generate gaussian mixture dataset + */ + void generate_gaussian_mixture_dataset() { + if (num_clusters == 0) { + num_clusters = std::max(1UL, num_points / 1000); + } + + std::mt19937 gen(seed); + + // generate centroids (uniformly distributed in hypercube) + centroids.resize(num_clusters * dim); + std::uniform_real_distribution centroid_dist(-5.0f, 5.0f); + + for (size_t i = 0; i < num_clusters; ++i) { + for (size_t j = 0; j < dim; ++j) { + centroids[i * dim + j] = centroid_dist(gen); + } + } + + // set different standard deviations for each cluster + std::vector cluster_std(num_clusters); + std::uniform_real_distribution std_dist(0.5f, 2.0f); + for (size_t i = 0; i < num_clusters; ++i) { + cluster_std[i] = std_dist(gen); + } + + // generate data points + data.resize(num_points * dim); + cluster_ids.resize(num_points); + + // number of points per cluster + std::vector points_per_cluster(num_clusters, num_points / num_clusters); + size_t remaining = num_points % num_clusters; + for (size_t i = 0; i < remaining; ++i) { + points_per_cluster[i]++; + } + + size_t point_idx = 0; + for (size_t cluster_id = 0; cluster_id < num_clusters; ++cluster_id) { + std::normal_distribution noise_dist(0.0f, cluster_std[cluster_id]); + + for (size_t p = 0; p < points_per_cluster[cluster_id]; ++p) { + cluster_ids[point_idx] = static_cast(cluster_id); + + for (size_t j = 0; j < dim; ++j) { + float centroid_val = centroids[cluster_id * dim + j]; + data[point_idx * dim + j] = centroid_val + noise_dist(gen); + } + + // normalize the vector + float* vec_ptr = &data[point_idx * dim]; + normalize_vector(vec_ptr, dim); + + point_idx++; + } + } + + // normalize centroids + for (size_t i = 0; i < num_clusters; ++i) { + float* centroid_ptr = ¢roids[i * dim]; + normalize_vector(centroid_ptr, dim); + } + + // generate batch queries (10% of num_points, at least 1) + num_queries = std::max(1, num_points / 10); + 
queries.resize(num_queries * dim); + std::normal_distribution query_dist(0.0f, 1.0f); + for (size_t i = 0; i < num_queries; ++i) { + float* qptr = &queries[i * dim]; + for (size_t j = 0; j < dim; ++j) { + qptr[j] = query_dist(gen); + } + normalize_vector(qptr, dim); + } + } + + /** + * @brief generate spherical cluster dataset + */ + void generate_spherical_cluster_dataset() { + if (num_clusters == 0) { + num_clusters = std::max(1UL, num_points / 1000); + } + + std::mt19937 gen(seed); + + // generate centroids + centroids.resize(num_clusters * dim); + std::uniform_real_distribution centroid_dist(-10.0f, 10.0f); + + for (size_t i = 0; i < num_clusters; ++i) { + for (size_t j = 0; j < dim; ++j) { + centroids[i * dim + j] = centroid_dist(gen); + } + } + + // generate data points + data.resize(num_points * dim); + cluster_ids.resize(num_points); + + std::vector points_per_cluster(num_clusters, num_points / num_clusters); + size_t remaining = num_points % num_clusters; + for (size_t i = 0; i < remaining; ++i) { + points_per_cluster[i]++; + } + + size_t point_idx = 0; + for (size_t cluster_id = 0; cluster_id < num_clusters; ++cluster_id) { + float radius = 1.0f + (cluster_id % 3) * 0.5f; // different cluster radii + + for (size_t p = 0; p < points_per_cluster[cluster_id]; ++p) { + cluster_ids[point_idx] = static_cast(cluster_id); + + // generate random points on the sphere + std::vector direction(dim); + float norm = 0.0f; + + std::normal_distribution normal_dist(0.0f, 1.0f); + for (size_t j = 0; j < dim; ++j) { + direction[j] = normal_dist(gen); + norm += direction[j] * direction[j]; + } + norm = std::sqrt(norm); + + // normalize and add radius + std::uniform_real_distribution radius_dist(0.0f, radius); + float r = radius_dist(gen); + + for (size_t j = 0; j < dim; ++j) { + float centroid_val = centroids[cluster_id * dim + j]; + data[point_idx * dim + j] = centroid_val + (direction[j] / norm) * r; + } + + // normalize the vector + float* vec_ptr = &data[point_idx * dim]; + normalize_vector(vec_ptr, dim); + + point_idx++; + } + } + + // normalize centroids + for (size_t i = 0; i < num_clusters; ++i) { + float* centroid_ptr = ¢roids[i * dim]; + normalize_vector(centroid_ptr, dim); + } + + // generate batch queries (10% of num_points, at least 1) + num_queries = std::max(1, num_points / 10); + queries.resize(num_queries * dim); + std::normal_distribution query_dist(0.0f, 1.0f); + for (size_t i = 0; i < num_queries; ++i) { + float* qptr = &queries[i * dim]; + for (size_t j = 0; j < dim; ++j) { + qptr[j] = query_dist(gen); + } + normalize_vector(qptr, dim); + } + } + + +private: + std::vector data; // raw data + std::vector centroids; // centroids + std::vector cluster_ids; // cluster ids + std::vector queries; // batch queries + size_t num_points; // number of points + size_t dim; // dimension + size_t num_clusters; // number of clusters + size_t num_queries; // number of queries + uint32_t seed; // random seed +}; + +} // namespace test + +} // namespace rabitqlib \ No newline at end of file diff --git a/contrib/dataset_example.cpp b/contrib/dataset_example.cpp new file mode 100644 index 0000000..f380321 --- /dev/null +++ b/contrib/dataset_example.cpp @@ -0,0 +1,119 @@ +#include "dataset.h" +#include +#include +#include + +using namespace rabitqlib::test; +using namespace rabitqlib; + +int main() { + // case1: generate random dataset + std::cout << "=== generate random dataset ===" << std::endl; + auto random_dataset = Dataset(10000, 128, 10, 42, Dataset::DatasetType::Random); + + if 
(random_dataset.validate_dataset()) { + std::cout << "random dataset generated successfully!" << std::endl; + std::cout << random_dataset.get_dataset_stats() << std::endl; + } + + // case2: generate gaussian mixture dataset + std::cout << "\n=== generate gaussian mixture dataset ===" << std::endl; + auto gmm_dataset = Dataset(15000, 256, 15, 123, Dataset::DatasetType::GaussianMixture); + + if (gmm_dataset.validate_dataset()) { + std::cout << "gaussian mixture dataset generated successfully!" << std::endl; + std::cout << gmm_dataset.get_dataset_stats() << std::endl; + } + + // case3: generate spherical cluster dataset + std::cout << "\n=== generate spherical cluster dataset ===" << std::endl; + auto spherical_dataset = Dataset(8000, 64, 8, 456, Dataset::DatasetType::SphericalCluster); + + if (spherical_dataset.validate_dataset()) { + std::cout << "spherical cluster dataset generated successfully!" << std::endl; + std::cout << spherical_dataset.get_dataset_stats() << std::endl; + } + + // case4: use IVF::construct function + std::cout << "\n=== use IVF::construct function ===" << std::endl; + + // generate test dataset + auto test_dataset = Dataset(5000, 128, 5, 789, Dataset::DatasetType::GaussianMixture); + + // get data pointer + const float* data_ptr = test_dataset.get_data_ptr(); + const float* centroids_ptr = test_dataset.get_centroids_ptr(); + const PID* cluster_ids_ptr = test_dataset.get_cluster_ids_ptr(); + + std::cout << "dataset info:" << std::endl; + std::cout << " num_points: " << test_dataset.get_num_points() << std::endl; + std::cout << " dim: " << test_dataset.get_dim() << std::endl; + std::cout << " num_clusters: " << test_dataset.get_num_clusters() << std::endl; + + // IVF::construct(data_ptr, test_dataset.num_points, test_dataset.dim, ...); + + // case5: validate data quality and normalization + std::cout << "\n=== validate data quality and normalization ===" << std::endl; + + // check data range + float min_val = data_ptr[0], max_val = data_ptr[0]; + for (size_t i = 1; i < test_dataset.get_num_points() * test_dataset.get_dim(); ++i) { + min_val = std::min(min_val, data_ptr[i]); + max_val = std::max(max_val, data_ptr[i]); + } + + std::cout << "data range: [" << min_val << ", " << max_val << "]" << std::endl; + + // verify normalization + std::cout << "verifying normalization..." << std::endl; + bool all_normalized = true; + for (size_t i = 0; i < test_dataset.get_num_points(); ++i) { + float norm = 0.0f; + for (size_t j = 0; j < test_dataset.get_dim(); ++j) { + norm += data_ptr[i * test_dataset.get_dim() + j] * data_ptr[i * test_dataset.get_dim() + j]; + } + norm = std::sqrt(norm); + + if (std::abs(norm - 1.0f) > 1e-5f) { + std::cout << "Warning: Vector " << i << " not normalized, norm = " << norm << std::endl; + all_normalized = false; + } + } + + if (all_normalized) { + std::cout << "All vectors are properly normalized!" << std::endl; + } + + // verify centroid normalization + std::cout << "verifying centroid normalization..." 
<< std::endl; + bool centroids_normalized = true; + for (size_t i = 0; i < test_dataset.get_num_clusters(); ++i) { + float norm = 0.0f; + for (size_t j = 0; j < test_dataset.get_dim(); ++j) { + norm += centroids_ptr[i * test_dataset.get_dim() + j] * centroids_ptr[i * test_dataset.get_dim() + j]; + } + norm = std::sqrt(norm); + + if (std::abs(norm - 1.0f) > 1e-5f) { + std::cout << "Warning: Centroid " << i << " not normalized, norm = " << norm << std::endl; + centroids_normalized = false; + } + } + + if (centroids_normalized) { + std::cout << "All centroids are properly normalized!" << std::endl; + } + + // check cluster distribution + std::vector cluster_counts(test_dataset.get_num_clusters(), 0); + for (size_t i = 0; i < test_dataset.get_num_points(); ++i) { + cluster_counts[cluster_ids_ptr[i]]++; + } + + std::cout << "cluster distribution:" << std::endl; + for (size_t i = 0; i < test_dataset.get_num_clusters(); ++i) { + std::cout << " cluster " << i << ": " << cluster_counts[i] << " points" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/contrib/test_normalization.cpp b/contrib/test_normalization.cpp new file mode 100644 index 0000000..1b2f3dd --- /dev/null +++ b/contrib/test_normalization.cpp @@ -0,0 +1,83 @@ +#include "dataset.h" +#include +#include + +using namespace rabitqlib::test; + +bool verify_normalization(const float* data, size_t num_points, size_t dim, const std::string& name) { + std::cout << "Verifying " << name << " normalization..." << std::endl; + + bool all_normalized = true; + float min_norm = 1.0f, max_norm = 1.0f; + + for (size_t i = 0; i < num_points; ++i) { + float norm = 0.0f; + for (size_t j = 0; j < dim; ++j) { + norm += data[i * dim + j] * data[i * dim + j]; + } + norm = std::sqrt(norm); + + min_norm = std::min(min_norm, norm); + max_norm = std::max(max_norm, norm); + + if (std::abs(norm - 1.0f) > 1e-5f) { + std::cout << " Warning: Vector " << i << " not normalized, norm = " << norm << std::endl; + all_normalized = false; + } + } + + std::cout << " Norm range: [" << min_norm << ", " << max_norm << "]" << std::endl; + + if (all_normalized) { + std::cout << " ✓ All vectors are properly normalized!" << std::endl; + } else { + std::cout << " ✗ Some vectors are not normalized!" << std::endl; + } + + return all_normalized; +} + +int main() { + std::cout << "=== Testing Dataset Normalization ===" << std::endl; + + // Test 1: Random dataset + std::cout << "\n1. Testing random dataset..." << std::endl; + auto random_dataset = Dataset::generate_random_dataset(1000, 128, 5, 42); + const float* random_data = Dataset::get_data_ptr(random_dataset); + const float* random_centroids = Dataset::get_centroids_ptr(random_dataset); + + bool random_ok = verify_normalization(random_data, random_dataset.num_points, random_dataset.dim, "random data"); + bool random_centroids_ok = verify_normalization(random_centroids, random_dataset.num_clusters, random_dataset.dim, "random centroids"); + + // Test 2: Gaussian mixture dataset + std::cout << "\n2. Testing gaussian mixture dataset..." 
<< std::endl; + auto gmm_dataset = Dataset::generate_gaussian_mixture_dataset(1000, 256, 8, 123); + const float* gmm_data = Dataset::get_data_ptr(gmm_dataset); + const float* gmm_centroids = Dataset::get_centroids_ptr(gmm_dataset); + + bool gmm_ok = verify_normalization(gmm_data, gmm_dataset.num_points, gmm_dataset.dim, "GMM data"); + bool gmm_centroids_ok = verify_normalization(gmm_centroids, gmm_dataset.num_clusters, gmm_dataset.dim, "GMM centroids"); + + // Test 3: Spherical cluster dataset + std::cout << "\n3. Testing spherical cluster dataset..." << std::endl; + auto spherical_dataset = Dataset::generate_spherical_cluster_dataset(1000, 64, 6, 456); + const float* spherical_data = Dataset::get_data_ptr(spherical_dataset); + const float* spherical_centroids = Dataset::get_centroids_ptr(spherical_dataset); + + bool spherical_ok = verify_normalization(spherical_data, spherical_dataset.num_points, spherical_dataset.dim, "spherical data"); + bool spherical_centroids_ok = verify_normalization(spherical_centroids, spherical_dataset.num_clusters, spherical_dataset.dim, "spherical centroids"); + + // Summary + std::cout << "\n=== Summary ===" << std::endl; + std::cout << "Random dataset: " << (random_ok && random_centroids_ok ? "✓ PASS" : "✗ FAIL") << std::endl; + std::cout << "GMM dataset: " << (gmm_ok && gmm_centroids_ok ? "✓ PASS" : "✗ FAIL") << std::endl; + std::cout << "Spherical dataset: " << (spherical_ok && spherical_centroids_ok ? "✓ PASS" : "✗ FAIL") << std::endl; + + bool all_passed = random_ok && random_centroids_ok && + gmm_ok && gmm_centroids_ok && + spherical_ok && spherical_centroids_ok; + + std::cout << "\nOverall result: " << (all_passed ? "✓ ALL TESTS PASSED" : "✗ SOME TESTS FAILED") << std::endl; + + return all_passed ? 0 : 1; +} \ No newline at end of file diff --git a/example.sh b/example.sh index fa38680..0828e15 100755 --- a/example.sh +++ b/example.sh @@ -2,11 +2,18 @@ mkdir build bin cd build cmake .. -make - -# Download the dataset -wget -P ./data/gist ftp://ftp.irisa.fr/local/texmex/corpus/gist.tar.gz -tar -xzvf ./data/gist/gist.tar.gz -C ./data/gist +make -j 10 +cd .. + +if [ ! -d ./data/gist ]; then + mkdir -p ./data/gist +fi + +if [ ! 
-f ./data/gist/gist_base.fvecs ]; then + # download the dataset + wget -P ./data/gist ftp://ftp.irisa.fr/local/texmex/corpus/gist.tar.gz + tar -xzvf ./data/gist/gist.tar.gz -C ./data/gist +fi # indexing and querying for symqg ./bin/symqg_indexing ./data/gist/gist_base.fvecs 32 400 ./data/gist/symqg_32.index diff --git a/rabitqlib/fastscan/fastscan.hpp b/rabitqlib/fastscan/fastscan.hpp index 77c7a0f..82f02c1 100644 --- a/rabitqlib/fastscan/fastscan.hpp +++ b/rabitqlib/fastscan/fastscan.hpp @@ -3,7 +3,8 @@ #pragma once -#include +// #include +#include #include #include @@ -113,44 +114,44 @@ inline void accumulate( size_t dim ) { size_t code_length = dim << 2; -#if defined(__AVX512F__) - __m512i c; - __m512i lo; - __m512i hi; - __m512i lut; - __m512i res_lo; - __m512i res_hi; - - const __m512i lo_mask = _mm512_set1_epi8(0x0f); - __m512i accu0 = _mm512_setzero_si512(); - __m512i accu1 = _mm512_setzero_si512(); - __m512i accu2 = _mm512_setzero_si512(); - __m512i accu3 = _mm512_setzero_si512(); + + simde__m512i c; + simde__m512i lo; + simde__m512i hi; + simde__m512i lut; + simde__m512i res_lo; + simde__m512i res_hi; + + const simde__m512i lo_mask = simde_mm512_set1_epi8(0x0f); + simde__m512i accu0 = simde_mm512_setzero_si512(); + simde__m512i accu1 = simde_mm512_setzero_si512(); + simde__m512i accu2 = simde_mm512_setzero_si512(); + simde__m512i accu3 = simde_mm512_setzero_si512(); // ! here, we assume the code_length is a multiple of 64, thus the dim must be a // ! multiple of 16 for (size_t i = 0; i < code_length; i += 64) { - c = _mm512_loadu_si512(&codes[i]); - lut = _mm512_loadu_si512(&lp_table[i]); - lo = _mm512_and_si512(c, lo_mask); // code of vector 0 to 15 - hi = _mm512_and_si512(_mm512_srli_epi16(c, 4), lo_mask); // code of vector 16 to 31 + c = simde_mm512_loadu_si512(&codes[i]); + lut = simde_mm512_loadu_si512(&lp_table[i]); + lo = simde_mm512_and_si512(c, lo_mask); // code of vector 0 to 15 + hi = simde_mm512_and_si512(simde_mm512_srli_epi16(c, 4), lo_mask); - res_lo = _mm512_shuffle_epi8(lut, lo); // get the target value in lookup table - res_hi = _mm512_shuffle_epi8(lut, hi); + res_lo = simde_mm512_shuffle_epi8(lut, lo); // get the target value in lookup table + res_hi = simde_mm512_shuffle_epi8(lut, hi); // since values in lookup table are represented as i8, we add them as i16 to avoid // overflow. 
Since the data order is 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, // 7, 15, accu0 accumulates for vec 8 to 15 (the upper 8 bits need to be updated // since they stored useless info of vec 0 to 7) accu1 accumulates for vec 0 to 7 // similar for accu2 and accu3 - accu0 = _mm512_add_epi16(accu0, res_lo); - accu1 = _mm512_add_epi16(accu1, _mm512_srli_epi16(res_lo, 8)); - accu2 = _mm512_add_epi16(accu2, res_hi); - accu3 = _mm512_add_epi16(accu3, _mm512_srli_epi16(res_hi, 8)); + accu0 = simde_mm512_add_epi16(accu0, res_lo); + accu1 = simde_mm512_add_epi16(accu1, simde_mm512_srli_epi16(res_lo, 8)); + accu2 = simde_mm512_add_epi16(accu2, res_hi); + accu3 = simde_mm512_add_epi16(accu3, simde_mm512_srli_epi16(res_hi, 8)); } // remove the influence of upper 8 bits for accu0 and accu2 - accu0 = _mm512_sub_epi16(accu0, _mm512_slli_epi16(accu1, 8)); - accu2 = _mm512_sub_epi16(accu2, _mm512_slli_epi16(accu3, 8)); + accu0 = simde_mm512_sub_epi16(accu0, simde_mm512_slli_epi16(accu1, 8)); + accu2 = simde_mm512_sub_epi16(accu2, simde_mm512_slli_epi16(accu3, 8)); // At this point, we already have the correct accumulating result (accu0: 8-15, accu1: // 0-7, accu2: 16-23, accu3: 24-31), but we still need to write them back to RAM. Also, @@ -158,75 +159,20 @@ inline void accumulate( // final results. 512/16=32, so we can use one __m512i to contain all results. The // following codes are designed for this purpose. For detailed information, please check // the SIMD documentation. - __m512i ret1 = _mm512_add_epi16( - _mm512_mask_blend_epi64(0b11110000, accu0, accu1), - _mm512_shuffle_i64x2(accu0, accu1, 0b01001110) + simde__m512i ret1 = simde_mm512_add_epi16( + simde_mm512_mask_blend_epi64(0b11110000, accu0, accu1), + simde_mm512_shuffle_i64x2(accu0, accu1, 0b01001110) ); - __m512i ret2 = _mm512_add_epi16( - _mm512_mask_blend_epi64(0b11110000, accu2, accu3), - _mm512_shuffle_i64x2(accu2, accu3, 0b01001110) + simde__m512i ret2 = simde_mm512_add_epi16( + simde_mm512_mask_blend_epi64(0b11110000, accu2, accu3), + simde_mm512_shuffle_i64x2(accu2, accu3, 0b01001110) ); - __m512i ret = _mm512_setzero_si512(); - - ret = _mm512_add_epi16(ret, _mm512_shuffle_i64x2(ret1, ret2, 0b10001000)); - ret = _mm512_add_epi16(ret, _mm512_shuffle_i64x2(ret1, ret2, 0b11011101)); + simde__m512i ret = simde_mm512_setzero_si512(); - _mm512_storeu_si512(result, ret); + ret = simde_mm512_add_epi16(ret, simde_mm512_shuffle_i64x2(ret1, ret2, 0b10001000)); + ret = simde_mm512_add_epi16(ret, simde_mm512_shuffle_i64x2(ret1, ret2, 0b11011101)); -#elif defined(__AVX2__) - __m256i c, lo, hi, lut, res_lo, res_hi; - - __m256i low_mask = _mm256_set1_epi8(0xf); - __m256i accu0 = _mm256_setzero_si256(); - __m256i accu1 = _mm256_setzero_si256(); - __m256i accu2 = _mm256_setzero_si256(); - __m256i accu3 = _mm256_setzero_si256(); - - for (size_t i = 0; i < code_length; i += 64) { - c = _mm256_loadu_si256((__m256i*)&codes[i]); - lut = _mm256_loadu_si256((__m256i*)&lp_table[i]); - lo = _mm256_and_si256(c, low_mask); - hi = _mm256_and_si256(_mm256_srli_epi16(c, 4), low_mask); - - res_lo = _mm256_shuffle_epi8(lut, lo); - res_hi = _mm256_shuffle_epi8(lut, hi); - - accu0 = _mm256_add_epi16(accu0, res_lo); - accu1 = _mm256_add_epi16(accu1, _mm256_srli_epi16(res_lo, 8)); - accu2 = _mm256_add_epi16(accu2, res_hi); - accu3 = _mm256_add_epi16(accu3, _mm256_srli_epi16(res_hi, 8)); - - c = _mm256_loadu_si256((__m256i*)&codes[i + 32]); - lut = _mm256_loadu_si256((__m256i*)&lp_table[i + 32]); - lo = _mm256_and_si256(c, low_mask); - hi = 
_mm256_and_si256(_mm256_srli_epi16(c, 4), low_mask); - - res_lo = _mm256_shuffle_epi8(lut, lo); - res_hi = _mm256_shuffle_epi8(lut, hi); - - accu0 = _mm256_add_epi16(accu0, res_lo); - accu1 = _mm256_add_epi16(accu1, _mm256_srli_epi16(res_lo, 8)); - accu2 = _mm256_add_epi16(accu2, res_hi); - accu3 = _mm256_add_epi16(accu3, _mm256_srli_epi16(res_hi, 8)); - } - - accu0 = _mm256_sub_epi16(accu0, _mm256_slli_epi16(accu1, 8)); - __m256i dis0 = _mm256_add_epi16( - _mm256_permute2f128_si256(accu0, accu1, 0x21), - _mm256_blend_epi32(accu0, accu1, 0xF0) - ); - _mm256_storeu_si256((__m256i*)result, dis0); - - accu2 = _mm256_sub_epi16(accu2, _mm256_slli_epi16(accu3, 8)); - __m256i dis1 = _mm256_add_epi16( - _mm256_permute2f128_si256(accu2, accu3, 0x21), - _mm256_blend_epi32(accu2, accu3, 0xF0) - ); - _mm256_storeu_si256((__m256i*)&result[16], dis1); -#else - std::cerr << "no avx simd supported!\n"; - exit(1); -#endif + simde_mm512_storeu_si512(result, ret); } // pack lookup table for fastscan, for each 4 dim, we have 16 (2^4) different results diff --git a/rabitqlib/fastscan/highacc_fastscan.hpp b/rabitqlib/fastscan/highacc_fastscan.hpp index a4657b0..110948e 100644 --- a/rabitqlib/fastscan/highacc_fastscan.hpp +++ b/rabitqlib/fastscan/highacc_fastscan.hpp @@ -1,6 +1,7 @@ #pragma once -#include +#include +#include "../utils/simde-utils.hpp" #include #include @@ -19,15 +20,8 @@ inline void transfer_lut_hacc(const uint16_t* lut, size_t dim, uint8_t* hc_lut) size_t num_codebook = dim >> 2; for (size_t i = 0; i < num_codebook; i++) { - // avx2 - 256, avx512 - 512 -#if defined(__AVX512F__) + constexpr size_t kRegBits = 512; -#elif defined(__AVX2__) - constexpr size_t kRegBits = 256; -#else - static_assert(false, "At least requried AVX2 for using fastscan\n"); - exit(1); -#endif constexpr size_t kLaneBits = 128; constexpr size_t kByteBits = 8; @@ -39,21 +33,12 @@ inline void transfer_lut_hacc(const uint16_t* lut, size_t dim, uint8_t* hc_lut) hc_lut + (i / kLutPerIter * kCodePerIter) + ((i % kLutPerIter) * kCodePerLine); uint8_t* fill_hi = fill_lo + (kRegBits / kByteBits); -#if defined(__AVX512F__) - __m512i tmp = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(lut)); - __m128i lo = _mm512_cvtepi32_epi8(tmp); - __m128i hi = _mm512_cvtepi32_epi8(_mm512_srli_epi32(tmp, 8)); - _mm_store_si128(reinterpret_cast<__m128i*>(fill_lo), lo); - _mm_store_si128(reinterpret_cast<__m128i*>(fill_hi), hi); -#else - for (size_t j = 0; j < 16; ++j) { - int tmp = lut[j]; - uint8_t lo = static_cast(tmp); - uint8_t hi = static_cast(tmp >> 8); - fill_lo[j] = lo; - fill_hi[j] = hi; - } -#endif + simde__m512i tmp = simde_mm512_cvtepi16_epi32(simde_mm256_loadu_epi16(lut)); + simde__m128i lo = simde_mm512_cvtepi32_epi8(tmp); + simde__m128i hi = simde_mm512_cvtepi32_epi8(simde_mm512_srli_epi32(tmp, 8)); + simde_mm_store_si128(reinterpret_cast(fill_lo), lo); + simde_mm_store_si128(reinterpret_cast(fill_hi), hi); + lut += 16; } } @@ -64,12 +49,12 @@ inline void accumulate_hacc( int32_t* accu_res, size_t dim ) { - __m512i low_mask = _mm512_set1_epi8(0xf); - __m512i accu[2][4]; + simde__m512i low_mask = simde_mm512_set1_epi8(0xf); + simde__m512i accu[2][4]; for (auto& a : accu) { for (auto& reg : a) { - reg = _mm512_setzero_si512(); + reg = simde_mm512_setzero_si512(); } } @@ -77,24 +62,24 @@ inline void accumulate_hacc( // std::cerr << "FastScan YES!" 
<< std::endl; for (size_t m = 0; m < num_codebook; m += 4) { - __m512i c = _mm512_loadu_si512(codes); - __m512i lo = _mm512_and_si512(c, low_mask); - __m512i hi = _mm512_and_si512(_mm512_srli_epi16(c, 4), low_mask); + simde__m512i c = simde_mm512_loadu_si512(codes); + simde__m512i lo = simde_mm512_and_si512(c, low_mask); + simde__m512i hi = simde_mm512_and_si512(simde_mm512_srli_epi16(c, 4), low_mask); // accumulate lower & upper results respectively // accu[0][0-3] for lower 8-bit result // accu[1][0-3] for upper 8-bit result for (auto& i : accu) { - __m512i lut = _mm512_loadu_si512(hc_lut); + simde__m512i lut = simde_mm512_loadu_si512(hc_lut); - __m512i res_lo = _mm512_shuffle_epi8(lut, lo); - __m512i res_hi = _mm512_shuffle_epi8(lut, hi); + simde__m512i res_lo = simde_mm512_shuffle_epi8(lut, lo); + simde__m512i res_hi = simde_mm512_shuffle_epi8(lut, hi); - i[0] = _mm512_add_epi16(i[0], res_lo); - i[1] = _mm512_add_epi16(i[1], _mm512_srli_epi16(res_lo, 8)); + i[0] = simde_mm512_add_epi16(i[0], res_lo); + i[1] = simde_mm512_add_epi16(i[1], simde_mm512_srli_epi16(res_lo, 8)); - i[2] = _mm512_add_epi16(i[2], res_hi); - i[3] = _mm512_add_epi16(i[3], _mm512_srli_epi16(res_hi, 8)); + i[2] = simde_mm512_add_epi16(i[2], res_hi); + i[3] = simde_mm512_add_epi16(i[3], simde_mm512_srli_epi16(res_hi, 8)); hc_lut += 64; } @@ -103,44 +88,44 @@ inline void accumulate_hacc( // std::cerr << "FastScan YES!" << std::endl; - __m512i res[2]; - __m512i dis0[2]; - __m512i dis1[2]; + simde__m512i res[2]; + simde__m512i dis0[2]; + simde__m512i dis1[2]; for (size_t i = 0; i < 2; ++i) { - __m256i tmp0 = _mm256_add_epi16( - _mm512_castsi512_si256(accu[i][0]), _mm512_extracti64x4_epi64(accu[i][0], 1) + simde__m256i tmp0 = simde_mm256_add_epi16( + simde_mm512_castsi512_si256(accu[i][0]), simde_mm512_extracti64x4_epi64(accu[i][0], 1) ); - __m256i tmp1 = _mm256_add_epi16( - _mm512_castsi512_si256(accu[i][1]), _mm512_extracti64x4_epi64(accu[i][1], 1) + simde__m256i tmp1 = simde_mm256_add_epi16( + simde_mm512_castsi512_si256(accu[i][1]), simde_mm512_extracti64x4_epi64(accu[i][1], 1) ); - tmp0 = _mm256_sub_epi16(tmp0, _mm256_slli_epi16(tmp1, 8)); + tmp0 = simde_mm256_sub_epi16(tmp0, simde_mm256_slli_epi16(tmp1, 8)); - dis0[i] = _mm512_add_epi32( - _mm512_cvtepu16_epi32(_mm256_permute2f128_si256(tmp0, tmp1, 0x21)), - _mm512_cvtepu16_epi32(_mm256_blend_epi32(tmp0, tmp1, 0xF0)) + dis0[i] = simde_mm512_add_epi32( + simde_mm512_cvtepu16_epi32(simde_mm256_permute2f128_si256(tmp0, tmp1, 0x21)), + simde_mm512_cvtepu16_epi32(simde_mm256_blend_epi32(tmp0, tmp1, 0xF0)) ); - __m256i tmp2 = _mm256_add_epi16( - _mm512_castsi512_si256(accu[i][2]), _mm512_extracti64x4_epi64(accu[i][2], 1) + simde__m256i tmp2 = simde_mm256_add_epi16( + simde_mm512_castsi512_si256(accu[i][2]), simde_mm512_extracti64x4_epi64(accu[i][2], 1) ); - __m256i tmp3 = _mm256_add_epi16( - _mm512_castsi512_si256(accu[i][3]), _mm512_extracti64x4_epi64(accu[i][3], 1) + simde__m256i tmp3 = simde_mm256_add_epi16( + simde_mm512_castsi512_si256(accu[i][3]), simde_mm512_extracti64x4_epi64(accu[i][3], 1) ); - tmp2 = _mm256_sub_epi16(tmp2, _mm256_slli_epi16(tmp3, 8)); + tmp2 = simde_mm256_sub_epi16(tmp2, simde_mm256_slli_epi16(tmp3, 8)); - dis1[i] = _mm512_add_epi32( - _mm512_cvtepu16_epi32(_mm256_permute2f128_si256(tmp2, tmp3, 0x21)), - _mm512_cvtepu16_epi32(_mm256_blend_epi32(tmp2, tmp3, 0xF0)) + dis1[i] = simde_mm512_add_epi32( + simde_mm512_cvtepu16_epi32(simde_mm256_permute2f128_si256(tmp2, tmp3, 0x21)), + simde_mm512_cvtepu16_epi32(simde_mm256_blend_epi32(tmp2, tmp3, 
0xF0)) ); } // shift res of high, add res of low res[0] = - _mm512_add_epi32(dis0[0], _mm512_slli_epi32(dis0[1], 8)); // res for vec 0 to 15 + simde_mm512_add_epi32(dis0[0], simde_mm512_slli_epi32(dis0[1], 8)); // res for vec 0 to 15 res[1] = - _mm512_add_epi32(dis1[0], _mm512_slli_epi32(dis1[1], 8)); // res for vec 16 to 31 + simde_mm512_add_epi32(dis1[0], simde_mm512_slli_epi32(dis1[1], 8)); // res for vec 16 to 31 - _mm512_storeu_epi32(accu_res, res[0]); - _mm512_storeu_epi32(accu_res + 16, res[1]); + simde_mm512_storeu_epi32(accu_res, res[0]); + simde_mm512_storeu_epi32(accu_res + 16, res[1]); } } // namespace rabitqlib::fastscan \ No newline at end of file diff --git a/rabitqlib/quantization/pack_excode.hpp b/rabitqlib/quantization/pack_excode.hpp index c790c7a..d5eb52c 100644 --- a/rabitqlib/quantization/pack_excode.hpp +++ b/rabitqlib/quantization/pack_excode.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -11,7 +11,6 @@ namespace rabitqlib::quant::rabitq_impl::ex_bits { inline void packing_1bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { -#if defined(__AVX512F__) // ! require dim % 16 == 0 for (size_t j = 0; j < dim; j += 16) { uint16_t code = 0; @@ -23,14 +22,9 @@ inline void packing_1bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_raw += 16; o_compact += 2; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_2bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { -#if defined(__AVX512F__) // ! require dim % 16 == 0 for (size_t j = 0; j < dim; j += 16) { // pack 16 2-bit codes into int32 @@ -47,39 +41,34 @@ inline void packing_2bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_raw += 16; o_compact += 4; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_3bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { -#if defined(__AVX512F__) // ! 
require dim % 64 == 0 - const __m128i mask = _mm_set1_epi8(0b11); + const simde__m128i mask = simde_mm_set1_epi8(0b11); for (size_t d = 0; d < dim; d += 64) { // split 3-bit codes into 2 bits and 1 bit // for 2-bit part, compact it like 2-bit code // for 1-bit part, compact 64 1-bit code into a int64 - __m128i vec_00_to_15 = _mm_loadu_si128(reinterpret_cast(o_raw)); - __m128i vec_16_to_31 = - _mm_loadu_si128(reinterpret_cast(o_raw + 16)); - __m128i vec_32_to_47 = - _mm_loadu_si128(reinterpret_cast(o_raw + 32)); - __m128i vec_48_to_63 = - _mm_loadu_si128(reinterpret_cast(o_raw + 48)); - - vec_00_to_15 = _mm_and_si128(vec_00_to_15, mask); - vec_16_to_31 = _mm_slli_epi16(_mm_and_si128(vec_16_to_31, mask), 2); - vec_32_to_47 = _mm_slli_epi16(_mm_and_si128(vec_32_to_47, mask), 4); - vec_48_to_63 = _mm_slli_epi16(_mm_and_si128(vec_48_to_63, mask), 6); - - __m128i compact2 = _mm_or_si128( - _mm_or_si128(vec_00_to_15, vec_16_to_31), - _mm_or_si128(vec_32_to_47, vec_48_to_63) + simde__m128i vec_00_to_15 = simde_mm_loadu_si128(reinterpret_cast(o_raw)); + simde__m128i vec_16_to_31 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 16)); + simde__m128i vec_32_to_47 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 32)); + simde__m128i vec_48_to_63 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 48)); + + vec_00_to_15 = simde_mm_and_si128(vec_00_to_15, mask); + vec_16_to_31 = simde_mm_slli_epi16(simde_mm_and_si128(vec_16_to_31, mask), 2); + vec_32_to_47 = simde_mm_slli_epi16(simde_mm_and_si128(vec_32_to_47, mask), 4); + vec_48_to_63 = simde_mm_slli_epi16(simde_mm_and_si128(vec_48_to_63, mask), 6); + + simde__m128i compact2 = simde_mm_or_si128( + simde_mm_or_si128(vec_00_to_15, vec_16_to_31), + simde_mm_or_si128(vec_32_to_47, vec_48_to_63) ); - _mm_storeu_si128(reinterpret_cast<__m128i*>(o_compact), compact2); + simde_mm_storeu_si128(reinterpret_cast(o_compact), compact2); o_compact += 16; // from lower to upper, each bit in each byte represents vec00 to vec07, @@ -95,16 +84,11 @@ inline void packing_3bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_raw += 64; o_compact += 8; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_4bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { // although this part only requries SSE, computing inner product for this orgnization // requires AVX512F, similar for remaining functions -#if defined(__AVX512F__) // ! require dim % 16 == 0 for (size_t j = 0; j < dim; j += 16) { // pack 16 4-bit codes into uint64 @@ -120,35 +104,30 @@ inline void packing_4bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_raw += 16; o_compact += 8; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_5bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { -#if defined(__AVX512F__) // ! 
require dim % 64 == 0 - const __m128i mask = _mm_set1_epi8(0b1111); + const simde__m128i mask = simde_mm_set1_epi8(0b1111); for (size_t j = 0; j < dim; j += 64) { - __m128i vec_00_to_15 = _mm_loadu_si128(reinterpret_cast(o_raw)); - __m128i vec_16_to_31 = - _mm_loadu_si128(reinterpret_cast(o_raw + 16)); - __m128i vec_32_to_47 = - _mm_loadu_si128(reinterpret_cast(o_raw + 32)); - __m128i vec_48_to_63 = - _mm_loadu_si128(reinterpret_cast(o_raw + 48)); + simde__m128i vec_00_to_15 = simde_mm_loadu_si128(reinterpret_cast(o_raw)); + simde__m128i vec_16_to_31 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 16)); + simde__m128i vec_32_to_47 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 32)); + simde__m128i vec_48_to_63 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 48)); - vec_00_to_15 = _mm_and_si128(vec_00_to_15, mask); - vec_16_to_31 = _mm_slli_epi16(_mm_and_si128(vec_16_to_31, mask), 4); - vec_32_to_47 = _mm_and_si128(vec_32_to_47, mask); - vec_48_to_63 = _mm_slli_epi16(_mm_and_si128(vec_48_to_63, mask), 4); + vec_00_to_15 = simde_mm_and_si128(vec_00_to_15, mask); + vec_16_to_31 = simde_mm_slli_epi16(simde_mm_and_si128(vec_16_to_31, mask), 4); + vec_32_to_47 = simde_mm_and_si128(vec_32_to_47, mask); + vec_48_to_63 = simde_mm_slli_epi16(simde_mm_and_si128(vec_48_to_63, mask), 4); - __m128i compact4_1 = _mm_or_si128(vec_00_to_15, vec_16_to_31); - __m128i compact4_2 = _mm_or_si128(vec_32_to_47, vec_48_to_63); + simde__m128i compact4_1 = simde_mm_or_si128(vec_00_to_15, vec_16_to_31); + simde__m128i compact4_2 = simde_mm_or_si128(vec_32_to_47, vec_48_to_63); - _mm_storeu_si128(reinterpret_cast<__m128i*>(o_compact), compact4_1); - _mm_storeu_si128(reinterpret_cast<__m128i*>(o_compact + 16), compact4_2); + simde_mm_storeu_si128(reinterpret_cast(o_compact), compact4_1); + simde_mm_storeu_si128(reinterpret_cast(o_compact + 16), compact4_2); o_compact += 32; @@ -165,14 +144,9 @@ inline void packing_5bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_raw += 64; o_compact += 8; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_6bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { -#if defined(__AVX512F__) constexpr int64_t kMask4 = 0x0f0f0f0f0f0f0f0f; constexpr int32_t kMask2 = 0x30303030; for (size_t j = 0; j < dim; j += 16) { @@ -198,44 +172,39 @@ inline void packing_6bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_raw += 16; o_compact += 4; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_7bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { -#if defined(__AVX512F__) // for vec00 to vec47, split code into 6 + 1 // for vec48 to vec63, split code into 2 + 2 + 2 + 1 - const __m128i mask2 = _mm_set1_epi8(0b11000000); - const __m128i mask6 = _mm_set1_epi8(0b00111111); + const simde__m128i mask2 = simde_mm_set1_epi8(0b11000000); + const simde__m128i mask6 = simde_mm_set1_epi8(0b00111111); for (size_t d = 0; d < dim; d += 64) { - __m128i vec_00_to_15 = _mm_loadu_si128(reinterpret_cast(o_raw)); - __m128i vec_16_to_31 = - _mm_loadu_si128(reinterpret_cast(o_raw + 16)); - __m128i vec_32_to_47 = - _mm_loadu_si128(reinterpret_cast(o_raw + 32)); - __m128i vec_48_to_63 = - _mm_loadu_si128(reinterpret_cast(o_raw + 48)); - - __m128i compact = _mm_or_si128( - _mm_and_si128(vec_00_to_15, mask6), - _mm_and_si128(_mm_slli_epi16(vec_48_to_63, 6), mask2) + simde__m128i vec_00_to_15 = 
simde_mm_loadu_si128(reinterpret_cast(o_raw)); + simde__m128i vec_16_to_31 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 16)); + simde__m128i vec_32_to_47 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 32)); + simde__m128i vec_48_to_63 = + simde_mm_loadu_si128(reinterpret_cast(o_raw + 48)); + + simde__m128i compact = simde_mm_or_si128( + simde_mm_and_si128(vec_00_to_15, mask6), + simde_mm_and_si128(simde_mm_slli_epi16(vec_48_to_63, 6), mask2) ); - _mm_storeu_si128(reinterpret_cast<__m128i*>(o_compact), compact); + simde_mm_storeu_si128(reinterpret_cast(o_compact), compact); - compact = _mm_or_si128( - _mm_and_si128(vec_16_to_31, mask6), - _mm_and_si128(_mm_slli_epi16(vec_48_to_63, 4), mask2) + compact = simde_mm_or_si128( + simde_mm_and_si128(vec_16_to_31, mask6), + simde_mm_and_si128(simde_mm_slli_epi16(vec_48_to_63, 4), mask2) ); - _mm_storeu_si128(reinterpret_cast<__m128i*>(o_compact + 16), compact); + simde_mm_storeu_si128(reinterpret_cast(o_compact + 16), compact); - compact = _mm_or_si128( - _mm_and_si128(vec_32_to_47, mask6), - _mm_and_si128(_mm_slli_epi16(vec_48_to_63, 2), mask2) + compact = simde_mm_or_si128( + simde_mm_and_si128(vec_32_to_47, mask6), + simde_mm_and_si128(simde_mm_slli_epi16(vec_48_to_63, 2), mask2) ); - _mm_storeu_si128(reinterpret_cast<__m128i*>(o_compact + 32), compact); + simde_mm_storeu_si128(reinterpret_cast(o_compact + 32), compact); o_compact += 48; int64_t top_bit = 0; @@ -249,10 +218,6 @@ inline void packing_7bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t o_compact += 8; o_raw += 64; } -#else - std::cerr << "Current only support AVX512F only for packing excode\n" << std::flush; - exit(1); -#endif } inline void packing_8bit_excode(const uint8_t* o_raw, uint8_t* o_compact, size_t dim) { diff --git a/rabitqlib/third/simde b/rabitqlib/third/simde new file mode 160000 index 0000000..71fd833 --- /dev/null +++ b/rabitqlib/third/simde @@ -0,0 +1 @@ +Subproject commit 71fd833d9666141edcd1d3c109a80e228303d8d7 diff --git a/rabitqlib/utils/rotator.hpp b/rabitqlib/utils/rotator.hpp index 4e19c18..c6a6871 100644 --- a/rabitqlib/utils/rotator.hpp +++ b/rabitqlib/utils/rotator.hpp @@ -111,33 +111,33 @@ static inline void flip_sign(const uint8_t* flip, float* data, size_t dim) { std::memcpy(&mask_bits, &flip[i / 8], sizeof(mask_bits)); // Split into four 16-bit mask segments - const __mmask16 mask0 = _cvtu32_mask16(static_cast(mask_bits & 0xFFFF)); - const __mmask16 mask1 = - _cvtu32_mask16(static_cast((mask_bits >> 16) & 0xFFFF)); - const __mmask16 mask2 = - _cvtu32_mask16(static_cast((mask_bits >> 32) & 0xFFFF)); - const __mmask16 mask3 = - _cvtu32_mask16(static_cast((mask_bits >> 48) & 0xFFFF)); + const simde__mmask16 mask0 = simde_cvtu32_mask16(static_cast(mask_bits & 0xFFFF)); + const simde__mmask16 mask1 = + simde_cvtu32_mask16(static_cast((mask_bits >> 16) & 0xFFFF)); + const simde__mmask16 mask2 = + simde_cvtu32_mask16(static_cast((mask_bits >> 32) & 0xFFFF)); + const simde__mmask16 mask3 = + simde_cvtu32_mask16(static_cast((mask_bits >> 48) & 0xFFFF)); // Prepare sign-flip constant - const __m512 sign_flip = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000)); + const simde__m512 sign_flip = simde_mm512_castsi512_ps(simde_mm512_set1_epi32(0x80000000)); // Process 16 floats at a time with each mask segment - __m512 vec0 = _mm512_loadu_ps(&data[i]); - vec0 = _mm512_mask_xor_ps(vec0, mask0, vec0, sign_flip); - _mm512_storeu_ps(&data[i], vec0); + simde__m512 vec0 = simde_mm512_loadu_ps(&data[i]); + vec0 = simde_mm512_mask_xor_ps(vec0, mask0, vec0, 
sign_flip); + simde_mm512_storeu_ps(&data[i], vec0); - __m512 vec1 = _mm512_loadu_ps(&data[i + 16]); - vec1 = _mm512_mask_xor_ps(vec1, mask1, vec1, sign_flip); - _mm512_storeu_ps(&data[i + 16], vec1); + simde__m512 vec1 = simde_mm512_loadu_ps(&data[i + 16]); + vec1 = simde_mm512_mask_xor_ps(vec1, mask1, vec1, sign_flip); + simde_mm512_storeu_ps(&data[i + 16], vec1); - __m512 vec2 = _mm512_loadu_ps(&data[i + 32]); - vec2 = _mm512_mask_xor_ps(vec2, mask2, vec2, sign_flip); - _mm512_storeu_ps(&data[i + 32], vec2); + simde__m512 vec2 = simde_mm512_loadu_ps(&data[i + 32]); + vec2 = simde_mm512_mask_xor_ps(vec2, mask2, vec2, sign_flip); + simde_mm512_storeu_ps(&data[i + 32], vec2); - __m512 vec3 = _mm512_loadu_ps(&data[i + 48]); - vec3 = _mm512_mask_xor_ps(vec3, mask3, vec3, sign_flip); - _mm512_storeu_ps(&data[i + 48], vec3); + simde__m512 vec3 = simde_mm512_loadu_ps(&data[i + 48]); + vec3 = simde_mm512_mask_xor_ps(vec3, mask3, vec3, sign_flip); + simde_mm512_storeu_ps(&data[i + 48], vec3); } } @@ -224,14 +224,14 @@ class FhtKacRotator : public Rotator { static void kacs_walk(float* data, size_t len) { // ! len % 32 == 0; for (size_t i = 0; i < len / 2; i += 16) { - __m512 x = _mm512_loadu_ps(&data[i]); - __m512 y = _mm512_loadu_ps(&data[i + (len / 2)]); + simde__m512 x = simde_mm512_loadu_ps(&data[i]); + simde__m512 y = simde_mm512_loadu_ps(&data[i + (len / 2)]); - __m512 new_x = _mm512_add_ps(x, y); - __m512 new_y = _mm512_sub_ps(x, y); + simde__m512 new_x = simde_mm512_add_ps(x, y); + simde__m512 new_y = simde_mm512_sub_ps(x, y); - _mm512_storeu_ps(&data[i], new_x); - _mm512_storeu_ps(&data[i + (len / 2)], new_y); + simde_mm512_storeu_ps(&data[i], new_x); + simde_mm512_storeu_ps(&data[i + (len / 2)], new_y); } } diff --git a/rabitqlib/utils/simde-utils.hpp b/rabitqlib/utils/simde-utils.hpp new file mode 100644 index 0000000..7e1ed48 --- /dev/null +++ b/rabitqlib/utils/simde-utils.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include +#include + +inline float simde_mm512_reduce_add_ps(simde__m512 sum) { + simde__m256 low = simde_mm512_castps512_ps256(sum); + simde__m256 high = simde_mm512_extractf32x8_ps(sum, 1); + simde__m256 sum256 = simde_mm256_add_ps(low, high); + + simde__m128 lo128 = simde_mm256_castps256_ps128(sum256); + simde__m128 hi128 = simde_mm256_extractf128_ps(sum256, 1); + simde__m128 sum128 = simde_mm_add_ps(lo128, hi128); + sum128 = simde_mm_hadd_ps(sum128, sum128); + sum128 = simde_mm_hadd_ps(sum128, sum128); + return simde_mm_cvtss_f32(sum128); +} + +inline simde__m512i simde_mm512_cvtepu8_epi32(simde__m128i a) { + alignas(16) uint8_t vals_u8[16]; + simde_mm_storeu_si128(reinterpret_cast(vals_u8), a); + + int32_t vals_i32[16]; + for (int i = 0; i < 16; ++i) { + vals_i32[i] = static_cast(vals_u8[i]); // zero-extend + } + + return simde_mm512_loadu_epi32(vals_i32); +} + +inline simde__m512i simde_mm512_cvtepi8_epi32(simde__m128i a) { + alignas(16) int8_t vals_i8[16]; + simde_mm_storeu_si128((simde__m128i*)vals_i8, a); + + int32_t vals_i32[16]; + for (int i = 0; i < 16; ++i) { + vals_i32[i] = (int32_t)vals_i8[i]; // sign extend + } + + return simde_mm512_loadu_epi32(vals_i32); +} + +inline simde__mmask16 simde_cvtu32_mask16(unsigned int a) { + return (simde__mmask16)(a & 0xFFFFu); +} + +inline simde__m128i simde_mm512_cvtepi32_epi8(simde__m512i a) { + alignas(64) int32_t tmp32[16]; + simde_mm512_storeu_epi32(tmp32, a); + + alignas(16) int8_t tmp8[16]; + for (int i = 0; i < 16; ++i) { + tmp8[i] = (int8_t)(tmp32[i]); + } + + return simde_mm_loadu_epi8(tmp8); +} + +inline simde__m256i 
simde_mm512_cvtepi32_epi16(simde__m512i a) { + alignas(64) int32_t input[16]; + simde_mm512_storeu_epi32(input, a); + + alignas(32) int16_t output[16]; + for (int i = 0; i < 16; ++i) { + output[i] = (int16_t)(input[i]); + } + + return simde_mm256_loadu_epi16(output); +} \ No newline at end of file diff --git a/rabitqlib/utils/space.hpp b/rabitqlib/utils/space.hpp index 3cd9327..c76a1a8 100644 --- a/rabitqlib/utils/space.hpp +++ b/rabitqlib/utils/space.hpp @@ -1,7 +1,8 @@ #pragma once -#include -#include +#include +#include +#include "simde-utils.hpp" #include #include @@ -51,24 +52,20 @@ inline void scalar_quantize_optimized( float lo, float delta ) { -#if defined(__AVX512F__) size_t mul16 = dim - (dim & 0b1111); size_t i = 0; float one_over_delta = 1 / delta; - auto lo512 = _mm512_set1_ps(lo); - auto od512 = _mm512_set1_ps(one_over_delta); + auto lo512 = simde_mm512_set1_ps(lo); + auto od512 = simde_mm512_set1_ps(one_over_delta); for (; i < mul16; i += 16) { - auto cur = _mm512_loadu_ps(&vec0[i]); - cur = _mm512_mul_ps(_mm512_sub_ps(cur, lo512), od512); - auto i8 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(cur)); - _mm_storeu_epi8(&result[i], i8); + auto cur = simde_mm512_loadu_ps(&vec0[i]); + cur = simde_mm512_mul_ps(simde_mm512_sub_ps(cur, lo512), od512); + auto i8 = simde_mm512_cvtepi32_epi8(simde_mm512_cvtps_epi32(cur)); + simde_mm_storeu_si128(reinterpret_cast(&result[i]), i8); } for (; i < dim; ++i) { result[i] = static_cast(std::round((vec0[i] - lo) * one_over_delta)); } -#else - scalar_quantize_normal(result, vec0, dim, lo, delta); -#endif } template <> @@ -79,39 +76,20 @@ inline void scalar_quantize_optimized( float lo, float delta ) { -#if defined(__AVX512F__) size_t mul16 = dim - (dim & 0b1111); size_t i = 0; float one_over_delta = 1 / delta; - auto lo512 = _mm512_set1_ps(lo); - auto ow512 = _mm512_set1_ps(one_over_delta); + auto lo512 = simde_mm512_set1_ps(lo); + auto ow512 = simde_mm512_set1_ps(one_over_delta); for (; i < mul16; i += 16) { - auto cur = _mm512_loadu_ps(&vec0[i]); - cur = _mm512_mul_ps(_mm512_sub_ps(cur, lo512), ow512); - auto i16 = _mm512_cvtepi32_epi16(_mm512_cvtps_epi32(cur)); - _mm256_storeu_epi16(&result[i], i16); - } - for (; i < dim; ++i) { - result[i] = static_cast(std::round((vec0[i] - lo) * one_over_delta)); - } -#elif defined(__AVX2__) - size_t mul8 = dim - (dim & 0b111); - size_t i = 0; - float one_over_delta = 1 / delta; - auto lo256 = _mm256_set1_ps(lo); - auto ow256 = _mm256_set1_ps(one_over_delta); - for (; i < mul8; i += 8) { - auto cur = _mm256_loadu_ps(&vec0[i]); - cur = _mm256_mul_ps(_mm256_sub_ps(cur, lo256), ow256); - auto i16 = _mm256_cvtepi32_epi16(_mm256_cvtps_epi32(cur)); - _mm_storeu_epi16(&result[i], i16); + auto cur = simde_mm512_loadu_ps(&vec0[i]); + cur = simde_mm512_mul_ps(simde_mm512_sub_ps(cur, lo512), ow512); + auto i16 = simde_mm512_cvtepi32_epi16(simde_mm512_cvtps_epi32(cur)); + simde_mm256_storeu_si256(reinterpret_cast(&result[i]), i16); } for (; i < dim; ++i) { result[i] = static_cast(std::round((vec0[i] - lo) * one_over_delta)); } -#else - scalar_quantize_normal(result, vec0, dim, lo, delta); -#endif } } // namespace scalar_impl @@ -275,280 +253,280 @@ inline float ip16_fxu1_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { float result = 0; - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); for (size_t i = 0; i < dim; i += 16) { - __mmask16 mask = *reinterpret_cast(compact_code); - __m512 q = _mm512_loadu_ps(query); + simde__mmask16 mask = 
*reinterpret_cast(compact_code); + simde__m512 q = simde_mm512_loadu_ps(query); - sum = _mm512_add_ps(_mm512_maskz_mov_ps(mask, q), sum); + sum = simde_mm512_add_ps(simde_mm512_maskz_mov_ps(mask, q), sum); compact_code += 2; query += 16; } - result = _mm512_reduce_add_ps(sum); + result = simde_mm512_reduce_add_ps(sum); return result; } inline float ip16_fxu2_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); float result = 0; - const __m128i mask = _mm_set1_epi8(0b00000011); + const simde__m128i mask = simde_mm_set1_epi8(0b00000011); for (size_t i = 0; i < dim; i += 16) { int32_t compact = *reinterpret_cast(compact_code); - __m128i code = _mm_set_epi32(compact >> 6, compact >> 4, compact >> 2, compact); - code = _mm_and_si128(code, mask); + simde__m128i code = simde_mm_set_epi32(compact >> 6, compact >> 4, compact >> 2, compact); + code = simde_mm_and_si128(code, mask); - __m512 cf = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(code)); + simde__m512 cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepi8_epi32(code)); - __m512 q = _mm512_loadu_ps(&query[i]); - sum = _mm512_fmadd_ps(cf, q, sum); + simde__m512 q = simde_mm512_loadu_ps(&query[i]); + sum = simde_mm512_fmadd_ps(cf, q, sum); compact_code += 4; } - result = _mm512_reduce_add_ps(sum); + result = simde_mm512_reduce_add_ps(sum); return result; } inline float ip64_fxu3_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); - const __m128i mask = _mm_set1_epi8(0b11); - const __m128i top_mask = _mm_set1_epi8(0b100); + const simde__m128i mask = simde_mm_set1_epi8(0b11); + const simde__m128i top_mask = simde_mm_set1_epi8(0b100); for (size_t i = 0; i < dim; i += 64) { - __m128i compact2 = _mm_loadu_si128(reinterpret_cast(compact_code)); + simde__m128i compact2 = simde_mm_loadu_si128(reinterpret_cast(compact_code)); compact_code += 16; int64_t top_bit = *reinterpret_cast(compact_code); compact_code += 8; - __m128i vec_00_to_15 = _mm_and_si128(compact2, mask); - __m128i vec_16_to_31 = _mm_and_si128(_mm_srli_epi16(compact2, 2), mask); - __m128i vec_32_to_47 = _mm_and_si128(_mm_srli_epi16(compact2, 4), mask); - __m128i vec_48_to_63 = _mm_and_si128(_mm_srli_epi16(compact2, 6), mask); + simde__m128i vec_00_to_15 = simde_mm_and_si128(compact2, mask); + simde__m128i vec_16_to_31 = simde_mm_and_si128(simde_mm_srli_epi16(compact2, 2), mask); + simde__m128i vec_32_to_47 = simde_mm_and_si128(simde_mm_srli_epi16(compact2, 4), mask); + simde__m128i vec_48_to_63 = simde_mm_and_si128(simde_mm_srli_epi16(compact2, 6), mask); - __m128i top_00_to_15 = - _mm_and_si128(_mm_set_epi64x(top_bit << 1, top_bit << 2), top_mask); - __m128i top_16_to_31 = - _mm_and_si128(_mm_set_epi64x(top_bit >> 1, top_bit >> 0), top_mask); - __m128i top_32_to_47 = - _mm_and_si128(_mm_set_epi64x(top_bit >> 3, top_bit >> 2), top_mask); - __m128i top_48_to_63 = - _mm_and_si128(_mm_set_epi64x(top_bit >> 5, top_bit >> 4), top_mask); + simde__m128i top_00_to_15 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit << 1, top_bit << 2), top_mask); + simde__m128i top_16_to_31 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit >> 1, top_bit >> 0), top_mask); + simde__m128i top_32_to_47 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit >> 3, top_bit >> 2), top_mask); + simde__m128i top_48_to_63 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit >> 5, 
top_bit >> 4), top_mask); - vec_00_to_15 = _mm_or_si128(top_00_to_15, vec_00_to_15); - vec_16_to_31 = _mm_or_si128(top_16_to_31, vec_16_to_31); - vec_32_to_47 = _mm_or_si128(top_32_to_47, vec_32_to_47); - vec_48_to_63 = _mm_or_si128(top_48_to_63, vec_48_to_63); + vec_00_to_15 = simde_mm_or_si128(top_00_to_15, vec_00_to_15); + vec_16_to_31 = simde_mm_or_si128(top_16_to_31, vec_16_to_31); + vec_32_to_47 = simde_mm_or_si128(top_32_to_47, vec_32_to_47); + vec_48_to_63 = simde_mm_or_si128(top_48_to_63, vec_48_to_63); - __m512 q; - __m512 cf; + simde__m512 q; + simde__m512 cf; - q = _mm512_loadu_ps(&query[i]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_00_to_15)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_00_to_15)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 16]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_16_to_31)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 16]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_16_to_31)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 32]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_32_to_47)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 32]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_32_to_47)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 48]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_48_to_63)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 48]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_48_to_63)); + sum = simde_mm512_fmadd_ps(q, cf, sum); } - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } inline float ip16_fxu4_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { constexpr int64_t kMask = 0x0f0f0f0f0f0f0f0f; - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); for (size_t i = 0; i < dim; i += 16) { int64_t compact = *reinterpret_cast(compact_code); int64_t code0 = compact & kMask; int64_t code1 = (compact >> 4) & kMask; - __m128i c8 = _mm_set_epi64x(code1, code0); - __m512 cf = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(c8)); + simde__m128i c8 = simde_mm_set_epi64x(code1, code0); + simde__m512 cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepi8_epi32(c8)); - __m512 q = _mm512_loadu_ps(&query[i]); - sum = _mm512_fmadd_ps(cf, q, sum); + simde__m512 q = simde_mm512_loadu_ps(&query[i]); + sum = simde_mm512_fmadd_ps(cf, q, sum); compact_code += 8; } - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } inline float ip64_fxu5_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); - const __m128i mask = _mm_set1_epi8(0b1111); - const __m128i top_mask = _mm_set1_epi8(0b10000); + const simde__m128i mask = simde_mm_set1_epi8(0b1111); + const simde__m128i top_mask = simde_mm_set1_epi8(0b10000); for (size_t i = 0; i < dim; i += 64) { - __m128i compact4_1 = - _mm_loadu_si128(reinterpret_cast(compact_code)); - __m128i compact4_2 = - _mm_loadu_si128(reinterpret_cast(compact_code + 16)); + simde__m128i compact4_1 = + simde_mm_loadu_si128(reinterpret_cast(compact_code)); + simde__m128i compact4_2 = + simde_mm_loadu_si128(reinterpret_cast(compact_code + 16)); compact_code 
+= 32; int64_t top_bit = *reinterpret_cast(compact_code); compact_code += 8; - __m128i vec_00_to_15 = _mm_and_si128(compact4_1, mask); - __m128i vec_16_to_31 = _mm_and_si128(_mm_srli_epi16(compact4_1, 4), mask); - __m128i vec_32_to_47 = _mm_and_si128(compact4_2, mask); - __m128i vec_48_to_63 = _mm_and_si128(_mm_srli_epi16(compact4_2, 4), mask); + simde__m128i vec_00_to_15 = simde_mm_and_si128(compact4_1, mask); + simde__m128i vec_16_to_31 = simde_mm_and_si128(simde_mm_srli_epi16(compact4_1, 4), mask); + simde__m128i vec_32_to_47 = simde_mm_and_si128(compact4_2, mask); + simde__m128i vec_48_to_63 = simde_mm_and_si128(simde_mm_srli_epi16(compact4_2, 4), mask); - __m128i top_00_to_15 = - _mm_and_si128(_mm_set_epi64x(top_bit << 3, top_bit << 4), top_mask); - __m128i top_16_to_31 = - _mm_and_si128(_mm_set_epi64x(top_bit << 1, top_bit << 2), top_mask); - __m128i top_32_to_47 = - _mm_and_si128(_mm_set_epi64x(top_bit >> 1, top_bit >> 0), top_mask); - __m128i top_48_to_63 = - _mm_and_si128(_mm_set_epi64x(top_bit >> 3, top_bit >> 2), top_mask); + simde__m128i top_00_to_15 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit << 3, top_bit << 4), top_mask); + simde__m128i top_16_to_31 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit << 1, top_bit << 2), top_mask); + simde__m128i top_32_to_47 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit >> 1, top_bit >> 0), top_mask); + simde__m128i top_48_to_63 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit >> 3, top_bit >> 2), top_mask); - vec_00_to_15 = _mm_or_si128(top_00_to_15, vec_00_to_15); - vec_16_to_31 = _mm_or_si128(top_16_to_31, vec_16_to_31); - vec_32_to_47 = _mm_or_si128(top_32_to_47, vec_32_to_47); - vec_48_to_63 = _mm_or_si128(top_48_to_63, vec_48_to_63); + vec_00_to_15 = simde_mm_or_si128(top_00_to_15, vec_00_to_15); + vec_16_to_31 = simde_mm_or_si128(top_16_to_31, vec_16_to_31); + vec_32_to_47 = simde_mm_or_si128(top_32_to_47, vec_32_to_47); + vec_48_to_63 = simde_mm_or_si128(top_48_to_63, vec_48_to_63); - __m512 q; - __m512 cf; + simde__m512 q; + simde__m512 cf; - q = _mm512_loadu_ps(&query[i]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_00_to_15)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_00_to_15)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 16]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_16_to_31)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 16]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_16_to_31)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 32]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_32_to_47)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 32]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_32_to_47)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 48]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_48_to_63)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 48]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_48_to_63)); + sum = simde_mm512_fmadd_ps(q, cf, sum); } - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } inline float ip16_fxu6_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { constexpr int64_t kMask4 = 0x0f0f0f0f0f0f0f0f; - __m512 sum = _mm512_setzero_ps(); - const __m128i mask2 = 
_mm_set1_epi8(0b00110000); + simde__m512 sum = simde_mm512_setzero_ps(); + const simde__m128i mask2 = simde_mm_set1_epi8(0b00110000); for (size_t i = 0; i < dim; i += 16) { int64_t compact4 = *reinterpret_cast(compact_code); int64_t code4_0 = compact4 & kMask4; int64_t code4_1 = (compact4 >> 4) & kMask4; - __m128i c4 = _mm_set_epi64x(code4_1, code4_0); // lower 4 + simde__m128i c4 = simde_mm_set_epi64x(code4_1, code4_0); // lower 4 compact_code += 8; int32_t compact2 = *reinterpret_cast(compact_code); - __m128i c2 = _mm_set_epi32(compact2 >> 2, compact2, compact2 << 2, compact2 << 4); - c2 = _mm_and_si128(c2, mask2); + simde__m128i c2 = simde_mm_set_epi32(compact2 >> 2, compact2, compact2 << 2, compact2 << 4); + c2 = simde_mm_and_si128(c2, mask2); - __m128i c6 = _mm_or_si128(c2, c4); + simde__m128i c6 = simde_mm_or_si128(c2, c4); - __m512 cf = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(c6)); + simde__m512 cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepi8_epi32(c6)); - __m512 q = _mm512_loadu_ps(&query[i]); - sum = _mm512_fmadd_ps(cf, q, sum); + simde__m512 q = simde_mm512_loadu_ps(&query[i]); + sum = simde_mm512_fmadd_ps(cf, q, sum); compact_code += 4; } - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } inline float ip64_fxu7_avx512( const float* __restrict__ query, const uint8_t* __restrict__ compact_code, size_t dim ) { - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); - const __m128i mask6 = _mm_set1_epi8(0b00111111); - const __m128i mask2 = _mm_set1_epi8(0b11000000); - const __m128i top_mask = _mm_set1_epi8(0b1000000); + const simde__m128i mask6 = simde_mm_set1_epi8(0b00111111); + const simde__m128i mask2 = simde_mm_set1_epi8(0b11000000); + const simde__m128i top_mask = simde_mm_set1_epi8(0b1000000); for (size_t i = 0; i < dim; i += 64) { - __m128i cpt1 = _mm_loadu_si128(reinterpret_cast(compact_code)); - __m128i cpt2 = _mm_loadu_si128(reinterpret_cast(compact_code + 16)); - __m128i cpt3 = _mm_loadu_si128(reinterpret_cast(compact_code + 32)); + simde__m128i cpt1 = simde_mm_loadu_si128(reinterpret_cast(compact_code)); + simde__m128i cpt2 = simde_mm_loadu_si128(reinterpret_cast(compact_code + 16)); + simde__m128i cpt3 = simde_mm_loadu_si128(reinterpret_cast(compact_code + 32)); compact_code += 48; - __m128i vec_00_to_15 = _mm_and_si128(cpt1, mask6); - __m128i vec_16_to_31 = _mm_and_si128(cpt2, mask6); - __m128i vec_32_to_47 = _mm_and_si128(cpt3, mask6); - __m128i vec_48_to_63 = _mm_or_si128( - _mm_or_si128( - _mm_srli_epi16(_mm_and_si128(cpt1, mask2), 6), - _mm_srli_epi16(_mm_and_si128(cpt2, mask2), 4) + simde__m128i vec_00_to_15 = simde_mm_and_si128(cpt1, mask6); + simde__m128i vec_16_to_31 = simde_mm_and_si128(cpt2, mask6); + simde__m128i vec_32_to_47 = simde_mm_and_si128(cpt3, mask6); + simde__m128i vec_48_to_63 = simde_mm_or_si128( + simde_mm_or_si128( + simde_mm_srli_epi16(simde_mm_and_si128(cpt1, mask2), 6), + simde_mm_srli_epi16(simde_mm_and_si128(cpt2, mask2), 4) ), - _mm_srli_epi16(_mm_and_si128(cpt3, mask2), 2) + simde_mm_srli_epi16(simde_mm_and_si128(cpt3, mask2), 2) ); int64_t top_bit = *reinterpret_cast(compact_code); compact_code += 8; - __m128i top_00_to_15 = - _mm_and_si128(_mm_set_epi64x(top_bit << 5, top_bit << 6), top_mask); - __m128i top_16_to_31 = - _mm_and_si128(_mm_set_epi64x(top_bit << 3, top_bit << 4), top_mask); - __m128i top_32_to_47 = - _mm_and_si128(_mm_set_epi64x(top_bit << 1, top_bit << 2), top_mask); - __m128i top_48_to_63 = - _mm_and_si128(_mm_set_epi64x(top_bit >> 1, top_bit << 0), top_mask); + 
simde__m128i top_00_to_15 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit << 5, top_bit << 6), top_mask); + simde__m128i top_16_to_31 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit << 3, top_bit << 4), top_mask); + simde__m128i top_32_to_47 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit << 1, top_bit << 2), top_mask); + simde__m128i top_48_to_63 = + simde_mm_and_si128(simde_mm_set_epi64x(top_bit >> 1, top_bit << 0), top_mask); - vec_00_to_15 = _mm_or_si128(top_00_to_15, vec_00_to_15); - vec_16_to_31 = _mm_or_si128(top_16_to_31, vec_16_to_31); - vec_32_to_47 = _mm_or_si128(top_32_to_47, vec_32_to_47); - vec_48_to_63 = _mm_or_si128(top_48_to_63, vec_48_to_63); + vec_00_to_15 = simde_mm_or_si128(top_00_to_15, vec_00_to_15); + vec_16_to_31 = simde_mm_or_si128(top_16_to_31, vec_16_to_31); + vec_32_to_47 = simde_mm_or_si128(top_32_to_47, vec_32_to_47); + vec_48_to_63 = simde_mm_or_si128(top_48_to_63, vec_48_to_63); - __m512 q; - __m512 cf; + simde__m512 q; + simde__m512 cf; - q = _mm512_loadu_ps(&query[i]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_00_to_15)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_00_to_15)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 16]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_16_to_31)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 16]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_16_to_31)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 32]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_32_to_47)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 32]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_32_to_47)); + sum = simde_mm512_fmadd_ps(q, cf, sum); - q = _mm512_loadu_ps(&query[i + 48]); - cf = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(vec_48_to_63)); - sum = _mm512_fmadd_ps(q, cf, sum); + q = simde_mm512_loadu_ps(&query[i + 48]); + cf = simde_mm512_cvtepi32_ps(simde_mm512_cvtepu8_epi32(vec_48_to_63)); + sum = simde_mm512_fmadd_ps(q, cf, sum); } - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } // inner product between float type and int type vectors @@ -620,14 +598,14 @@ inline void transpose_bin( ) { // 512 / 16 = 32 for (size_t i = 0; i < padded_dim; i += 32) { - __m512i v = _mm512_loadu_si512(q); - v = _mm512_slli_epi32(v, (16 - b_query)); + simde__m512i v = simde_mm512_loadu_si512(q); + v = simde_mm512_slli_epi32(v, (16 - b_query)); for (size_t j = 0; j < b_query; ++j) { - uint32_t v1 = _mm512_movepi16_mask(v); // get most significant bit + uint32_t v1 = simde_mm512_movepi16_mask(v); // get most significant bit v1 = reverse_bits(v1); tq[((b_query - j - 1) * (padded_dim / 64)) + (i / 64)] |= (static_cast(v1) << ((i / 32 % 2 == 0) ? 
32 : 0)); - v = _mm512_add_epi32(v, v); + v = simde_mm512_add_epi32(v, v); } q += 32; } @@ -638,16 +616,16 @@ static inline void new_transpose_bin( ) { // 512 / 16 = 32 for (size_t i = 0; i < padded_dim; i += 64) { - __m512i vec_00_to_31 = _mm512_loadu_si512(q); - __m512i vec_32_to_63 = _mm512_loadu_si512(q + 32); + simde__m512i vec_00_to_31 = simde_mm512_loadu_si512(q); + simde__m512i vec_32_to_63 = simde_mm512_loadu_si512(q + 32); // the first (16 - b_query) bits are empty - vec_00_to_31 = _mm512_slli_epi32(vec_00_to_31, (16 - b_query)); - vec_32_to_63 = _mm512_slli_epi32(vec_32_to_63, (16 - b_query)); + vec_00_to_31 = simde_mm512_slli_epi32(vec_00_to_31, (16 - b_query)); + vec_32_to_63 = simde_mm512_slli_epi32(vec_32_to_63, (16 - b_query)); for (size_t j = 0; j < b_query; ++j) { - uint32_t v0 = _mm512_movepi16_mask(vec_00_to_31); // get most significant bit - uint32_t v1 = _mm512_movepi16_mask(vec_32_to_63); // get most significant bit + uint32_t v0 = simde_mm512_movepi16_mask(vec_00_to_31); // get most significant bit + uint32_t v1 = simde_mm512_movepi16_mask(vec_32_to_63); // get most significant bit // [TODO: remove all reverse_bits] v0 = reverse_bits(v0); v1 = reverse_bits(v1); @@ -655,8 +633,8 @@ static inline void new_transpose_bin( tq[b_query - j - 1] = v; - vec_00_to_31 = _mm512_slli_epi16(vec_00_to_31, 1); - vec_32_to_63 = _mm512_slli_epi16(vec_32_to_63, 1); + vec_00_to_31 = simde_mm512_slli_epi16(vec_00_to_31, 1); + vec_32_to_63 = simde_mm512_slli_epi16(vec_32_to_63, 1); } tq += b_query; q += 64; @@ -668,33 +646,33 @@ inline float mask_ip_x0_q_old(const float* query, const uint64_t* data, size_t p const auto* it_data = data; const auto* it_query = query; - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); for (size_t i = 0; i < num_blk; ++i) { uint64_t bits = reverse_bits_u64(*it_data); - __mmask16 mask0 = static_cast<__mmask16>(bits >> 00); // for q[0..15] - __mmask16 mask1 = static_cast<__mmask16>(bits >> 16); // for q[16..31] - __mmask16 mask2 = static_cast<__mmask16>(bits >> 32); // for q[32..47] - __mmask16 mask3 = static_cast<__mmask16>(bits >> 48); // for q[48..63] - - __m512 q0 = _mm512_loadu_ps(it_query); - __m512 q1 = _mm512_loadu_ps(it_query + 16); - __m512 q2 = _mm512_loadu_ps(it_query + 32); - __m512 q3 = _mm512_loadu_ps(it_query + 48); - - __m512 masked0 = _mm512_maskz_mov_ps(mask0, q0); - __m512 masked1 = _mm512_maskz_mov_ps(mask1, q1); - __m512 masked2 = _mm512_maskz_mov_ps(mask2, q2); - __m512 masked3 = _mm512_maskz_mov_ps(mask3, q3); - - sum = _mm512_add_ps(sum, masked0); - sum = _mm512_add_ps(sum, masked1); - sum = _mm512_add_ps(sum, masked2); - sum = _mm512_add_ps(sum, masked3); + simde__mmask16 mask0 = static_cast(bits >> 00); // for q[0..15] + simde__mmask16 mask1 = static_cast(bits >> 16); // for q[16..31] + simde__mmask16 mask2 = static_cast(bits >> 32); // for q[32..47] + simde__mmask16 mask3 = static_cast(bits >> 48); // for q[48..63] + + simde__m512 q0 = simde_mm512_loadu_ps(it_query); + simde__m512 q1 = simde_mm512_loadu_ps(it_query + 16); + simde__m512 q2 = simde_mm512_loadu_ps(it_query + 32); + simde__m512 q3 = simde_mm512_loadu_ps(it_query + 48); + + simde__m512 masked0 = simde_mm512_maskz_mov_ps(mask0, q0); + simde__m512 masked1 = simde_mm512_maskz_mov_ps(mask1, q1); + simde__m512 masked2 = simde_mm512_maskz_mov_ps(mask2, q2); + simde__m512 masked3 = simde_mm512_maskz_mov_ps(mask3, q3); + + sum = simde_mm512_add_ps(sum, masked0); + sum = simde_mm512_add_ps(sum, masked1); + sum = simde_mm512_add_ps(sum, masked2); + 
sum = simde_mm512_add_ps(sum, masked3); it_data++; it_query += 64; } - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } inline float mask_ip_x0_q(const float* query, const uint64_t* data, size_t padded_dim) { @@ -707,24 +685,24 @@ inline float mask_ip_x0_q(const float* query, const uint64_t* data, size_t padde // __m512 sum2 = _mm512_setzero_ps(); // __m512 sum3 = _mm512_setzero_ps(); - __m512 sum = _mm512_setzero_ps(); + simde__m512 sum = simde_mm512_setzero_ps(); for (size_t i = 0; i < num_blk; ++i) { uint64_t bits = reverse_bits_u64(*it_data); - __mmask16 mask0 = static_cast<__mmask16>(bits); - __mmask16 mask1 = static_cast<__mmask16>(bits >> 16); - __mmask16 mask2 = static_cast<__mmask16>(bits >> 32); - __mmask16 mask3 = static_cast<__mmask16>(bits >> 48); + simde__mmask16 mask0 = static_cast(bits); + simde__mmask16 mask1 = static_cast(bits >> 16); + simde__mmask16 mask2 = static_cast(bits >> 32); + simde__mmask16 mask3 = static_cast(bits >> 48); - __m512 masked0 = _mm512_maskz_loadu_ps(mask0, it_query); - __m512 masked1 = _mm512_maskz_loadu_ps(mask1, it_query + 16); - __m512 masked2 = _mm512_maskz_loadu_ps(mask2, it_query + 32); - __m512 masked3 = _mm512_maskz_loadu_ps(mask3, it_query + 48); + simde__m512 masked0 = simde_mm512_maskz_loadu_ps(mask0, it_query); + simde__m512 masked1 = simde_mm512_maskz_loadu_ps(mask1, it_query + 16); + simde__m512 masked2 = simde_mm512_maskz_loadu_ps(mask2, it_query + 32); + simde__m512 masked3 = simde_mm512_maskz_loadu_ps(mask3, it_query + 48); - sum = _mm512_add_ps(sum, masked0); - sum = _mm512_add_ps(sum, masked1); - sum = _mm512_add_ps(sum, masked2); - sum = _mm512_add_ps(sum, masked3); + sum = simde_mm512_add_ps(sum, masked0); + sum = simde_mm512_add_ps(sum, masked1); + sum = simde_mm512_add_ps(sum, masked2); + sum = simde_mm512_add_ps(sum, masked3); // _mm_prefetch(reinterpret_cast(it_query + 128), _MM_HINT_T1); @@ -733,7 +711,7 @@ inline float mask_ip_x0_q(const float* query, const uint64_t* data, size_t padde } // __m512 sum = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3)); - return _mm512_reduce_add_ps(sum); + return simde_mm512_reduce_add_ps(sum); } inline float ip_x0_q( diff --git a/rabitqlib/utils/warmup_space.hpp b/rabitqlib/utils/warmup_space.hpp index dcb0320..8be442f 100644 --- a/rabitqlib/utils/warmup_space.hpp +++ b/rabitqlib/utils/warmup_space.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -24,22 +24,22 @@ inline float warmup_ip_x0_q( size_t vec_end = (num_blk / vec_width) * vec_width; // Vector accumulators (each holds 8 64-bit lanes) - __m512i ip_vec = _mm512_setzero_si512( + simde__m512i ip_vec = simde_mm512_setzero_si512( ); // will accumulate weighted popcount intersections per block - __m512i ppc_vec = _mm512_setzero_si512(); // will accumulate popcounts of data blocks + simde__m512i ppc_vec = simde_mm512_setzero_si512(); // will accumulate popcounts of data blocks // Loop over blocks in batches of 8 for (size_t i = 0; i < vec_end; i += vec_width) { // Load eight 64-bit data blocks into x_vec. - __m512i x_vec = _mm512_loadu_si512(reinterpret_cast(data + i)); + simde__m512i x_vec = simde_mm512_loadu_si512(reinterpret_cast(data + i)); // Compute popcount for each 64-bit block in x_vec using the AVX512 VPOPCNTDQ // instruction. (Ensure you compile with the proper flags for VPOPCNTDQ.) 
- __m512i popcnt_x_vec = _mm512_popcnt_epi64(x_vec); - ppc_vec = _mm512_add_epi64(ppc_vec, popcnt_x_vec); + simde__m512i popcnt_x_vec = simde_mm512_popcnt_epi64(x_vec); + ppc_vec = simde_mm512_add_epi64(ppc_vec, popcnt_x_vec); // For accumulating the weighted popcounts per block. - __m512i block_ip = _mm512_setzero_si512(); + simde__m512i block_ip = simde_mm512_setzero_si512(); // Process each query component (b_query is a compile-time constant, and is small). for (uint32_t j = 0; j < b_query; j++) { @@ -51,33 +51,33 @@ inline float warmup_ip_x0_q( indices[k] = ((i + k) * b_query + j); } // Load indices from memory. - __m512i index_vec = _mm512_loadu_si512(indices); + simde__m512i index_vec = simde_mm512_loadu_si512(indices); // Gather 8 query words with a scale of 8 (since query is an array of 64-bit // integers). - __m512i q_vec = _mm512_i64gather_epi64(index_vec, query, 8); + simde__m512i q_vec = simde_mm512_i64gather_epi64(index_vec, static_cast(query), 8); // Compute bitwise AND of data blocks and corresponding query words. - __m512i and_vec = _mm512_and_si512(x_vec, q_vec); + simde__m512i and_vec = simde_mm512_and_si512(x_vec, q_vec); // Compute popcount on each lane. - __m512i popcnt_and = _mm512_popcnt_epi64(and_vec); + simde__m512i popcnt_and = simde_mm512_popcnt_epi64(and_vec); // Multiply by the weighting factor (1 << j) for this query position. const uint64_t shift = 1ULL << j; - __m512i shift_vec = _mm512_set1_epi64(shift); - __m512i weighted = _mm512_mullo_epi64(popcnt_and, shift_vec); + simde__m512i shift_vec = simde_mm512_set1_epi64(shift); + simde__m512i weighted = simde_mm512_mullo_epi64(popcnt_and, shift_vec); // Accumulate weighted popcounts for these blocks. - block_ip = _mm512_add_epi64(block_ip, weighted); + block_ip = simde_mm512_add_epi64(block_ip, weighted); } // Add the block's query-weighted popcount to the overall ip vector. - ip_vec = _mm512_add_epi64(ip_vec, block_ip); + ip_vec = simde_mm512_add_epi64(ip_vec, block_ip); } // Horizontally reduce the vector accumulators. 
uint64_t ip_arr[vec_width]; uint64_t ppc_arr[vec_width]; - _mm512_storeu_si512(reinterpret_cast<__m512i*>(ip_arr), ip_vec); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(ppc_arr), ppc_vec); + simde_mm512_storeu_si512(reinterpret_cast(ip_arr), ip_vec); + simde_mm512_storeu_si512(reinterpret_cast(ppc_arr), ppc_vec); for (size_t k = 0; k < vec_width; k++) { ip_scalar += ip_arr[k]; diff --git a/sample/ivf_rabitq_querying.cpp b/sample/ivf_rabitq_querying.cpp index 38265e9..f11664a 100644 --- a/sample/ivf_rabitq_querying.cpp +++ b/sample/ivf_rabitq_querying.cpp @@ -16,11 +16,12 @@ static std::vector get_nprobes( const index_type& ivf, const std::vector& all_nprobes, data_type& query, - gt_type& gt + gt_type& gt, + bool use_hacc ); static size_t topk = 100; -static size_t test_round = 5; +static size_t test_round = 3; int main(int argc, char** argv) { if (argc < 4) { @@ -56,6 +57,7 @@ int main(int argc, char** argv) { index_type ivf; ivf.load(index_file); + std::cout << "Load index done" << std::endl; std::vector all_nprobes; all_nprobes.push_back(5); for (size_t i = 10; i < 200; i += 10) { @@ -74,18 +76,20 @@ int main(int argc, char** argv) { all_nprobes.push_back(6000); all_nprobes.push_back(10000); all_nprobes.push_back(15000); - + std::cout << "all_nprobes size: " << all_nprobes.size() << std::endl; rabitqlib::StopW stopw; - auto nprobes = get_nprobes(ivf, all_nprobes, query, gt); + auto nprobes = get_nprobes(ivf, all_nprobes, query, gt, use_hacc); size_t length = nprobes.size(); std::vector> all_qps(test_round, std::vector(length)); std::vector> all_recall(test_round, std::vector(length)); - + std::cout << "all_qps size: " << all_qps.size() << std::endl; + std::cout << "all_recall size: " << all_recall.size() << std::endl; for (size_t r = 0; r < test_round; r++) { for (size_t l = 0; l < length; ++l) { size_t nprobe = nprobes[l]; + std::cout << "Current nprobe: " << nprobe << std::endl; if (nprobe > ivf.num_clusters()) { std::cout << "nprobe " << nprobe << " is larger than number of clusters, "; std::cout << "will use nprobe = num_cluster (" << ivf.num_clusters() << ").\n"; @@ -135,7 +139,8 @@ static std::vector get_nprobes( const index_type& ivf, const std::vector& all_nprobes, data_type& query, - gt_type& gt + gt_type& gt, + bool use_hacc ) { size_t nq = query.rows(); size_t total_count = topk * nq; @@ -148,7 +153,7 @@ static std::vector get_nprobes( size_t total_correct = 0; std::vector results(topk); for (size_t i = 0; i < nq; i++) { - ivf.search(&query(i, 0), topk, nprobe, results.data()); + ivf.search(&query(i, 0), topk, nprobe, results.data(), use_hacc); for (size_t j = 0; j < topk; j++) { for (size_t k = 0; k < topk; k++) { if (gt(i, k) == results[j]) { @@ -159,6 +164,7 @@ static std::vector get_nprobes( } } float recall = static_cast(total_correct) / static_cast(total_count); + std::cout << "nprobe: " << nprobe << " recall: " << recall << " old_recall: " << old_recall << std::endl; if (recall > 0.997 || recall - old_recall < 1e-5) { break; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..6e50706 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_policy(SET CMP0135 NEW) + +file(GLOB TEST_FILES "*.cpp") + +include(FetchContent) +FetchContent_Declare(googletest + URL "https://github.com/google/googletest/archive/release-1.12.1.tar.gz") +set(BUILD_GMOCK CACHE BOOL OFF) +set(INSTALL_GTEST CACHE BOOL OFF) +FetchContent_MakeAvailable(googletest) + + +add_executable(${PROJECT_NAME}_test ${TEST_FILES}) 
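+# Link GoogleTest and expose the repository root so the tests can include
+# in-tree headers (rabitqlib/..., contrib/...) by their project-relative paths.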
+target_link_libraries(${PROJECT_NAME}_test PRIVATE gtest gtest_main) +target_include_directories(${PROJECT_NAME}_test PRIVATE ${PROJECT_SOURCE_DIR}) + +include(GoogleTest) +gtest_discover_tests(${PROJECT_NAME}_test) + diff --git a/tests/test_init.cpp b/tests/test_init.cpp new file mode 100644 index 0000000..fa4dbae --- /dev/null +++ b/tests/test_init.cpp @@ -0,0 +1,5 @@ +#include + +TEST(FastScan, TestFastScan) { + EXPECT_EQ(1, 1); +} \ No newline at end of file diff --git a/tests/test_ivf_rabitq.cpp b/tests/test_ivf_rabitq.cpp new file mode 100644 index 0000000..9a4493e --- /dev/null +++ b/tests/test_ivf_rabitq.cpp @@ -0,0 +1,102 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include + +using PID = rabitqlib::PID; +using ivf_index = rabitqlib::ivf::IVF; +using IvfIndexPtr = std::unique_ptr; +using data_type = rabitqlib::RowMajorArray; +using gt_type = rabitqlib::RowMajorArray; +using DatasetPtr = std::unique_ptr; + +static constexpr size_t n_clusters = 128; + + +DatasetPtr create_dataset(size_t num_points, size_t dim, size_t num_centroids, rabitqlib::test::Dataset::DatasetType type) { + return std::make_unique(num_points, dim, num_centroids, 42, type); +} + +IvfIndexPtr create_ivf_index(rabitqlib::test::Dataset & dataset, size_t total_bits) { + auto index = std::make_unique(dataset.get_num_points(), dataset.get_dim(), dataset.get_num_clusters(), total_bits); + index->construct(dataset.get_data_ptr(), dataset.get_centroids_ptr(), dataset.get_cluster_ids_ptr(), false); + return index; +} + +float calculate_recall(std::vector> & hacc_results, std::unordered_map> & results) { + float recall = 0; + for (size_t i = 0; i < hacc_results.size(); i++) { + size_t correct_count = 0; + if (results.count(i) > 0) { + for (auto & pid : hacc_results[i]) { + if (results[i].count(pid) > 0) { + correct_count++; + } + } + recall += static_cast(correct_count) / static_cast(results[i].size()); + } + else { + recall += 0; + } + } + return recall / static_cast(hacc_results.size()); +} + +void test_hacc_multi_bits(size_t num_points, size_t dim, size_t num_centroids, size_t total_bits, size_t topk = 100) { + auto dataset = create_dataset(num_points, dim, num_centroids, rabitqlib::test::Dataset::DatasetType::Random); + auto ivf_index = create_ivf_index(*dataset, total_bits); + EXPECT_EQ(ivf_index->num_clusters(), dataset->get_num_clusters()); + EXPECT_EQ(ivf_index->padded_dim(), dataset->get_dim()); + + auto results = dataset->get_results(rabitqlib::METRIC_L2, topk); + + auto queries = dataset->get_queries_ptr(); + auto hacc_results = std::vector>(dataset->get_num_queries(), std::vector(topk)); + auto no_hacc_results = std::vector>(dataset->get_num_queries(), std::vector(topk)); + for (size_t i = 0; i < dataset->get_num_queries(); i++) { + ivf_index->search(queries + i * dim, topk, dataset->get_num_clusters() / 2, hacc_results[i].data(), true); + ivf_index->search(queries + i * dim, topk, dataset->get_num_clusters() / 2, no_hacc_results[i].data(), false); + } + auto hacc_recall = calculate_recall(hacc_results, results); + auto no_hacc_recall = calculate_recall(no_hacc_results, results); + EXPECT_GT(hacc_recall, 0.9); + EXPECT_GT(no_hacc_recall, 0.9); +} + +TEST(IvfIndexTest, lut_hacc_ex_1bits) { + test_hacc_multi_bits(5000, 128, n_clusters, 2); +} + +TEST(IvfIndexTest, lut_hacc_ex_2bits) { + test_hacc_multi_bits(5000, 128, n_clusters, 3); +} + +TEST(IvfIndexTest, lut_hacc_ex_3bits) { + test_hacc_multi_bits(5000, 128, n_clusters, 4); +} + +TEST(IvfIndexTest, lut_hacc_ex_4bits) { + 
test_hacc_multi_bits(5000, 128, n_clusters, 5); +} + +TEST(IvfIndexTest, lut_hacc_ex_5bits) { + test_hacc_multi_bits(8000, 128, n_clusters, 6); +} + +TEST(IvfIndexTest, lut_hacc_ex_6bits) { + test_hacc_multi_bits(10000, 128, n_clusters, 7); +} + +TEST(IvfIndexTest, lut_hacc_ex_7bits) { + test_hacc_multi_bits(20000, 128, n_clusters, 8); +} + +TEST(IvfIndexTest, lut_hacc_ex_8bits) { + test_hacc_multi_bits(100000, 128, n_clusters, 9); +} From b0330311b5483c22713ab1a20862bf9bb08723f0 Mon Sep 17 00:00:00 2001 From: YiJustin Date: Tue, 30 Sep 2025 10:54:24 +0800 Subject: [PATCH 2/2] fix cmake flags --- CMakeLists.txt | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bae94af..9ee81f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,12 +9,12 @@ include(cmake/cpu_features.cmake) if (CMAKE_BUILD_TYPE STREQUAL "Debug") message(STATUS "Building in debug mode") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -O0") elseif(CMAKE_BUILD_TYPE STREQUAL "Release") message(STATUS "Building in release mode") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O2") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Ofast") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Ofast") endif() include_directories(${PROJECT_SOURCE_DIR}/rabitqlib) @@ -22,8 +22,11 @@ include_directories(${PROJECT_SOURCE_DIR}/rabitqlib/third/simde) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -SET(CMAKE_CXX_FLAGS "-Wall -Ofast -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions -w") -SET(CMAKE_C_FLAGS "-Wall -Ofast -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions -w") +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions -w") +SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions -w") + +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}") add_subdirectory(sample)
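
The kernels in this patch are translated to SIMDe one-for-one, so a quick way to sanity-check the migration on a machine without AVX-512 is to compile a tiny program against the vendored submodule. The sketch below is an illustration only, not part of the patch: it assumes SIMDe is checked out at rabitqlib/third/simde with its repository root on the include path (as the CMake changes arrange), and it deliberately avoids `_mm512_reduce_add_ps` so it does not depend on the helpers added in rabitqlib/utils/simde-utils.hpp.

```cpp
// Minimal sanity-check sketch for the SIMDe migration (assumptions noted above).
#include <simde/x86/avx512.h>

#include <cstdio>
#include <vector>

// Scalar reference: plain inner product.
static float ip_scalar(const float* a, const float* b, size_t n) {
    float s = 0.0F;
    for (size_t i = 0; i < n; ++i) s += a[i] * b[i];
    return s;
}

// SIMDe version, structured like the AVX-512 kernels in space.hpp. Every
// simde_* call compiles to the native instruction where available and is
// emulated (narrower vectors or scalar code) elsewhere.
static float ip_simde(const float* a, const float* b, size_t n) {
    simde__m512 acc = simde_mm512_setzero_ps();
    size_t i = 0;
    for (; i + 16 <= n; i += 16) {
        simde__m512 va = simde_mm512_loadu_ps(a + i);
        simde__m512 vb = simde_mm512_loadu_ps(b + i);
        acc = simde_mm512_fmadd_ps(va, vb, acc);
    }
    // Reduce by storing the 16 lanes and summing them, so the sketch does not
    // rely on _mm512_reduce_add_ps (a compiler helper that older SIMDe
    // revisions may not expose).
    float lanes[16];
    simde_mm512_storeu_ps(lanes, acc);
    float s = 0.0F;
    for (float v : lanes) s += v;
    for (; i < n; ++i) s += a[i] * b[i];  // scalar tail for dim % 16 != 0
    return s;
}

int main() {
    const size_t dim = 128;
    std::vector<float> a(dim);
    std::vector<float> b(dim);
    for (size_t i = 0; i < dim; ++i) {
        a[i] = static_cast<float>(i) * 0.25F;
        b[i] = 1.0F - static_cast<float>(i) * 0.01F;
    }
    std::printf("scalar=%f simde=%f\n",
                ip_scalar(a.data(), b.data(), dim),
                ip_simde(a.data(), b.data(), dim));
    return 0;
}
```

On hosts with AVX-512 this compiles to the same instructions as the original intrinsics; on other targets (including ARM) SIMDe emulates the 512-bit operations, which is what allows the translation units above to drop their `#if defined(__AVX512F__)` guards.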