diff --git a/.github/workflows/03-macos-linux-build.yml b/.github/workflows/03-macos-linux-build.yml index 1dd1e45c7..e80d63672 100644 --- a/.github/workflows/03-macos-linux-build.yml +++ b/.github/workflows/03-macos-linux-build.yml @@ -54,6 +54,14 @@ jobs: sudo apt-get install -y clang libomp-dev shell: bash + - name: Install AIO + if: runner.os == 'Linux' && runner.arch == 'X64' + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + libaio-dev + shell: bash + - name: Print CPU info if: runner.os == 'Linux' run: lscpu @@ -89,7 +97,6 @@ jobs: pytest \ scikit-build-core \ setuptools_scm - shell: bash - name: Build from source run: | diff --git a/.github/workflows/clang_tidy.yml b/.github/workflows/clang_tidy.yml index 0b595b67c..1d4a7eff1 100644 --- a/.github/workflows/clang_tidy.yml +++ b/.github/workflows/clang_tidy.yml @@ -29,6 +29,14 @@ jobs: sudo apt-get update sudo apt-get install -y clang-tidy=1:18.0-59~exp2 cmake ninja-build libomp-dev + - name: Install AIO + if: runner.os == 'Linux' && runner.arch == 'X64' + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + libaio-dev + shell: bash + - name: Configure CMake and export compile commands run: | cmake -S . 
-B build -G Ninja \ diff --git a/.gitmodules b/.gitmodules index 49ed1920b..af8239611 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,6 @@ [submodule "thirdparty/RaBitQ-Library/RaBitQ-Library-0.1"] path = thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 url = https://github.com/VectorDB-NTU/RaBitQ-Library.git +[submodule "thirdparty/aio/libaio-0.3"] + path = thirdparty/aio/libaio-0.3 + url = https://github.com/yugabyte/libaio.git diff --git a/CMakeLists.txt b/CMakeLists.txt index a33e61e99..8cc69fea0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,6 +109,17 @@ else() endif() message(STATUS "RABITQ_ARCH_FLAG: ${RABITQ_ARCH_FLAG}") +# DiskAnn support (Linux x86_64 only, requires libaio) +if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386" AND NOT ANDROID AND NOT IOS) + set(DISKANN_SUPPORTED ON) + add_definitions(-DDISKANN_SUPPORTED=1) +else() + set(DISKANN_SUPPORTED OFF) + add_definitions(-DDISKANN_SUPPORTED=0) + message(STATUS "DiskAnn support disabled - only supported on Linux x86_64") +endif() +message(STATUS "DISKANN_SUPPORTED: ${DISKANN_SUPPORTED}") + option(USE_OSS_MIRROR "Use OSS mirror for faster third-party downloads" OFF) if(DEFINED ENV{USE_OSS_MIRROR} AND NOT "$ENV{USE_OSS_MIRROR}" STREQUAL "") set(USE_OSS_MIRROR "$ENV{USE_OSS_MIRROR}" CACHE BOOL "Use OSS mirror for faster third-party downloads" FORCE) diff --git a/python/tests/detail/fixture_helper.py b/python/tests/detail/fixture_helper.py index 6082956dc..fb40da198 100644 --- a/python/tests/detail/fixture_helper.py +++ b/python/tests/detail/fixture_helper.py @@ -1,5 +1,13 @@ import pytest import logging +import platform + +DISKANN_SUPPORTED = platform.system() == "Linux" and platform.machine() in ( + "x86_64", + "AMD64", + "i686", + "i386", +) from typing import Any, Generator from zvec.typing import DataType, StatusCode, MetricType, QuantizeType @@ -97,6 +105,12 @@ def full_schema_new(request) -> CollectionSchema: else: nullable, has_index, 
vector_index = True, False, HnswIndexParam() + # Skip DiskAnn tests on unsupported platforms + from zvec.model.param import DiskAnnIndexParam + + if isinstance(vector_index, DiskAnnIndexParam) and not DISKANN_SUPPORTED: + pytest.skip("DiskAnn only supported on Linux x86_64") + scalar_index_param = None vector_index_param = None if has_index: diff --git a/python/tests/detail/test_collection_recall.py b/python/tests/detail/test_collection_recall.py index 25dad12dd..6e14c5d89 100644 --- a/python/tests/detail/test_collection_recall.py +++ b/python/tests/detail/test_collection_recall.py @@ -22,8 +22,10 @@ HnswIndexParam, FlatIndexParam, IVFIndexParam, + DiskAnnIndexParam, HnswQueryParam, IVFQueryParam, + DiskAnnQueryParam, ) from zvec.model.schema import FieldSchema, VectorSchema @@ -179,10 +181,24 @@ def get_ground_truth_map(collection, test_docs, query_vectors_map, metric_type, for field_name, query_vectors in query_vectors_map.items(): ground_truth_map[field_name] = {} + # Support per-field metric type: metric_type can be a dict mapping + # field_name -> MetricType, or a single MetricType applied to all fields. 
+ if isinstance(metric_type, dict): + field_metric = metric_type.get(field_name, MetricType.IP) + else: + field_metric = metric_type + for i, query_vector in enumerate(query_vectors): # Get the ground truth for this query relevant_doc_ids_scores = get_ground_truth_for_vector_query( - collection, query_vector, field_name, test_docs, i, metric_type, k, True + collection, + query_vector, + field_name, + test_docs, + i, + field_metric, + k, + True, ) ground_truth_map[field_name][i] = relevant_doc_ids_scores @@ -292,6 +308,7 @@ class TestRecall: [ (True, True, HnswIndexParam()), (False, True, IVFIndexParam()), + (False, True, DiskAnnIndexParam()), (False, True, FlatIndexParam()), # ——ok ( True, @@ -371,6 +388,19 @@ class TestRecall: use_soar=False, ), ), + ( + True, + True, + DiskAnnIndexParam( + metric_type=MetricType.IP, + max_degree=32, + ), + ), + ( + True, + True, + DiskAnnIndexParam(metric_type=MetricType.L2, max_degree=32), + ), ], indirect=True, ) @@ -388,10 +418,16 @@ def test_recall_with_single_vector_valid_500( ): full_schema_params = request.getfixturevalue("full_schema_new") + # Build per-field metric type map so ground truth uses each field's + # actual index metric (fields may fall back to HnswIndexParam/IP). 
+ field_metric_map = {} for vector_para in full_schema_params.vectors: - if vector_para.name == "vector_fp32_field": - metric_type = vector_para.index_param.metric_type - break + if vector_para.index_param is not None: + field_metric_map[vector_para.name] = vector_para.index_param.metric_type + else: + field_metric_map[vector_para.name] = MetricType.IP + + metric_type = field_metric_map.get("vector_fp32_field", MetricType.IP) multiple_docs = [ generate_doc_recall(i, full_collection_new.schema) for i in range(doc_num) @@ -438,9 +474,13 @@ def test_recall_with_single_vector_valid_500( multiple_docs[i].vectors[field_name] for i in range(query_num) ] - # Get ground truth mapping + # Get ground truth mapping (pass per-field metric map) ground_truth_map = get_ground_truth_map( - full_collection_new, multiple_docs, query_vectors_map, metric_type, top_k + full_collection_new, + multiple_docs, + query_vectors_map, + field_metric_map, + top_k, ) # Validate ground truth mapping structure @@ -479,8 +519,8 @@ def test_recall_with_single_vector_valid_500( print("(recall_at_k_stats:\n") print(recall_at_k_stats) - print("metric_type:") - print(metric_type) + print("field_metric_map:") + print(field_metric_map) # Print Recall@K statistics print(f"Recall@{top_k} using Ground Truth:") for field_name, stats in recall_at_k_stats.items(): @@ -552,7 +592,21 @@ def test_recall_with_single_vector_valid_500( use_soar=True, ), ), - # (True, True, IVFIndexParam(metric_type=MetricType.COSINE, n_list=150, n_iters=15, use_soar=False, )), + ( + True, + True, + DiskAnnIndexParam(metric_type=MetricType.IP, max_degree=32), + ), + ( + True, + True, + DiskAnnIndexParam(metric_type=MetricType.L2, max_degree=32), + ), + ( + True, + True, + DiskAnnIndexParam(metric_type=MetricType.COSINE, max_degree=32), + ), ], indirect=True, ) @@ -571,10 +625,16 @@ def test_recall_with_single_vector_valid_2000( ): full_schema_params = request.getfixturevalue("full_schema_new") + # Build per-field metric type map so 
ground truth uses each field's + # actual index metric (fields may fall back to HnswIndexParam/IP). + field_metric_map = {} for vector_para in full_schema_params.vectors: - if vector_para.name == "vector_fp32_field": - metric_type = vector_para.index_param.metric_type - break + if vector_para.index_param is not None: + field_metric_map[vector_para.name] = vector_para.index_param.metric_type + else: + field_metric_map[vector_para.name] = MetricType.IP + + metric_type = field_metric_map.get("vector_fp32_field", MetricType.IP) multiple_docs = [ generate_doc_recall(i, full_collection_new.schema) for i in range(doc_num) @@ -621,9 +681,13 @@ def test_recall_with_single_vector_valid_2000( multiple_docs[i].vectors[field_name] for i in range(query_num) ] - # Get ground truth mapping + # Get ground truth mapping (pass per-field metric map) ground_truth_map = get_ground_truth_map( - full_collection_new, multiple_docs, query_vectors_map, metric_type, top_k + full_collection_new, + multiple_docs, + query_vectors_map, + field_metric_map, + top_k, ) # Validate ground truth mapping structure @@ -662,8 +726,8 @@ def test_recall_with_single_vector_valid_2000( print("(recall_at_k_stats:\n") print(recall_at_k_stats) - print("metric_type:") - print(metric_type) + print("field_metric_map:") + print(field_metric_map) # Print Recall@K statistics print(f"Recall@{top_k} using Ground Truth:") for field_name, stats in recall_at_k_stats.items(): diff --git a/python/zvec/model/param/__init__.py b/python/zvec/model/param/__init__.py index c613edf52..f0fbc47c6 100644 --- a/python/zvec/model/param/__init__.py +++ b/python/zvec/model/param/__init__.py @@ -17,6 +17,8 @@ AddColumnOption, AlterColumnOption, CollectionOption, + DiskAnnIndexParam, + DiskAnnQueryParam, FlatIndexParam, HnswIndexParam, HnswQueryParam, @@ -33,6 +35,8 @@ "AddColumnOption", "AlterColumnOption", "CollectionOption", + "DiskAnnIndexParam", + "DiskAnnQueryParam", "FlatIndexParam", "HnswIndexParam", "HnswQueryParam", diff --git 
a/src/ailego/algorithm/kmeans.h b/src/ailego/algorithm/kmeans.h index 097b71ecc..8b160e762 100644 --- a/src/ailego/algorithm/kmeans.h +++ b/src/ailego/algorithm/kmeans.h @@ -94,6 +94,7 @@ class Kmc2CentroidsGenerator { auto *centroids = owner->mutable_centroids(); std::mt19937 mt((std::random_device())()); + std::uniform_real_distribution dist(0.0, 1.0); ContainerType benches(cache.dimension()); @@ -1216,4 +1217,4 @@ using NibbleInnerProductKmeans = LloydCluster>; } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/ailego/algorithm/lloyd_cluster.h b/src/ailego/algorithm/lloyd_cluster.h index d889271c2..b130b52c9 100644 --- a/src/ailego/algorithm/lloyd_cluster.h +++ b/src/ailego/algorithm/lloyd_cluster.h @@ -361,4 +361,4 @@ class LloydCluster { }; } // namespace ailego -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/binding/c/CMakeLists.txt b/src/binding/c/CMakeLists.txt index 6533028c7..ef99a6d47 100644 --- a/src/binding/c/CMakeLists.txt +++ b/src/binding/c/CMakeLists.txt @@ -83,6 +83,11 @@ set_target_properties(zvec_c_api PROPERTIES VISIBILITY_INLINES_HIDDEN ON ) +# On Windows, define ZVEC_BUILD_SHARED so that c_api.h uses __declspec(dllexport) +if(MSVC OR WIN32) + target_compile_definitions(zvec_c_api PRIVATE ZVEC_BUILD_SHARED) +endif() + find_package(Threads REQUIRED) # Static linking of C++ standard library is handled in platform-specific sections diff --git a/src/binding/python/CMakeLists.txt b/src/binding/python/CMakeLists.txt index 7e5169176..cc054e2c7 100644 --- a/src/binding/python/CMakeLists.txt +++ b/src/binding/python/CMakeLists.txt @@ -19,25 +19,49 @@ set(SRC_LISTS pybind11_add_module(_zvec ${SRC_LISTS}) if (CMAKE_SYSTEM_NAME STREQUAL "Linux") - target_link_libraries(_zvec PRIVATE - -Wl,--whole-archive - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - -Wl,--no-whole-archive - zvec_db - ) - target_link_options(_zvec PRIVATE - 
"LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/exports.map" - ) + if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|arm") + target_link_libraries(_zvec PRIVATE + -Wl,--whole-archive + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + -Wl,--no-whole-archive + zvec_db + ) + target_link_options(_zvec PRIVATE + "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/exports.map" + ) + else () + target_link_libraries(_zvec PRIVATE + -Wl,--whole-archive + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + -Wl,--no-whole-archive + zvec_db + aio + ) + target_link_options(_zvec PRIVATE + "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/exports.map" + ) + endif() elseif (APPLE) target_link_libraries(_zvec PRIVATE -Wl,-force_load,$ diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 6ad7e1b58..bcd75b86d 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -33,6 +33,8 @@ static std::string index_type_to_string(const IndexType type) { return "HNSW"; case IndexType::HNSW_RABITQ: return "HNSW_RABITQ"; + case IndexType::DISKANN: + return "DISKANN"; default: return "UNDEFINED"; } @@ -643,6 +645,104 @@ Constructs an IVFIndexParam instance. t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast(), t[4].cast()); })); + + // DiskAnnIndexParams + py::class_> + diskann_params(m, "DiskAnnIndexParam", R"pbdoc( +Parameters for configuring an DiskAnn index. + +DiskAnn stores compressed vector in memory and high-definition vector on disk. At query time, +only compressed vector will be loaded into memory. By this way, search memory at runtime is diminished. + +Attributes: + metric_type (MetricType): Distance metric used for similarity computation. + Default is ``MetricType.IP`` (inner product). + max_degree (int):. + list_size (int): . + pq_chunk_num (bool): . + quantize_type (QuantizeType): Optional quantization type for vector + compression (e.g., FP16, INT8). 
 Default is ``QuantizeType.UNDEFINED``. + +Examples: + >>> from zvec.typing import MetricType, QuantizeType + >>> params = DiskAnnIndexParam( + ... metric_type=MetricType.COSINE, + ... max_degree=100, + ... list_size=15, + ... pq_chunk_num=0, + ... quantize_type=QuantizeType.FP16 + ... ) + >>> print(params.max_degree) + 100 +)pbdoc"); + diskann_params + .def(py::init(), + py::arg("metric_type") = MetricType::IP, py::arg("max_degree") = 100, + py::arg("list_size") = 50, py::arg("pq_chunk_num") = 0, + py::arg("quantize_type") = QuantizeType::UNDEFINED, + R"pbdoc( +Constructs a DiskAnnIndexParam instance. + +Args: + metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. + max_degree (int, optional): Maximum out-degree of each node in the graph. + Defaults to 100. + list_size (int, optional): Size of the candidate list used during graph construction. + Defaults to 50. + pq_chunk_num (int, optional): The chunk num of product quantization. + quantize_type (QuantizeType, optional): Vector quantization type. + Defaults to QuantizeType.UNDEFINED. 
+)pbdoc") + .def_property_readonly("max_degree", &DiskAnnIndexParams::max_degree, + "int: max node degree.") + .def_property_readonly("list_size", &DiskAnnIndexParams::list_size, + "int: list size of graph construction") + .def_property_readonly( + "pq_chunk_num", + [](const DiskAnnIndexParams &self) -> int { + return self.pq_chunk_num(); + }, + "int: chunk num of production quantization.") + .def( + "to_dict", + [](const DiskAnnIndexParams &self) -> py::dict { + py::dict dict; + dict["type"] = index_type_to_string(self.type()); + dict["metric_type"] = metric_type_to_string(self.metric_type()); + dict["max_degree"] = self.max_degree(); + dict["list_size"] = self.list_size(); + dict["pq_chunk_num"] = self.pq_chunk_num(); + dict["quantize_type"] = + quantize_type_to_string(self.quantize_type()); + return dict; + }, + "Convert to dictionary with all fields") + .def( + "__repr__", + [](const DiskAnnIndexParams &self) { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"max_degree\":" + std::to_string(self.max_degree()) + + ", \"list_size\":" + std::to_string(self.list_size()) + + ", \"pq_chunk_num\":" + std::to_string(self.pq_chunk_num()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + "}"; + }) + .def(py::pickle( + [](const DiskAnnIndexParams &self) { + return py::make_tuple(self.metric_type(), self.max_degree(), + self.list_size(), self.pq_chunk_num(), + self.quantize_type()); + }, + [](py::tuple t) { + if (t.size() != 5) + throw std::runtime_error("Invalid state for DiskAnnIndexParams"); + return std::make_shared( + t[0].cast(), t[1].cast(), t[2].cast(), + t[3].cast(), t[4].cast()); + })); } void ZVecPyParams::bind_query_params(py::module_ &m) { @@ -884,6 +984,53 @@ Constructs an HnswRabitqQueryParam instance. 
 obj->set_is_using_refiner(t[3].cast()); return obj; })); + + // binding diskann query params + py::class_> + diskann_params(m, "DiskAnnQueryParam", R"pbdoc( +Query parameters for DiskAnn index. + +Attributes: + type (IndexType): Always ``IndexType.DISKANN``. + list_size (int): Size of the candidate list kept during graph search. + Higher values improve recall but increase latency. + Default is 10. + +Examples: + >>> params = DiskAnnQueryParam(list_size=20) + >>> print(params.list_size) + 20 +)pbdoc"); + diskann_params + .def(py::init(), py::arg("list_size") = 10, R"pbdoc( +Constructs a DiskAnnQueryParam instance. + +Args: + list_size (int, optional): candidate list size during graph search. Defaults to 10. +)pbdoc") + .def_property_readonly( + "list_size", + [](const DiskAnnQueryParams &self) -> int { + return self.list_size(); + }, + "int: Candidate list size used during DiskAnn graph search.") + .def("__repr__", + [](const DiskAnnQueryParams &self) -> std::string { + return "{" + "\"type\":" + + index_type_to_string(self.type()) + + ", \"list_size\":" + std::to_string(self.list_size()) + "}"; + }) + .def(py::pickle( + [](const DiskAnnQueryParams &self) { + return py::make_tuple(self.list_size()); + }, + [](py::tuple t) { + if (t.size() != 1) + throw std::runtime_error("Invalid state for DiskAnnQueryParams"); + return std::make_shared(t[0].cast()); + })); } void ZVecPyParams::bind_query_params(py::module_ &m) is not redefined here; continuation: void ZVecPyParams::bind_options(py::module_ &m) { // binding collection options
Including them here causes +# duplicate symbols and missing -laio when test binaries link both zvec_core +# (via zvec_db) and core_knn_diskann. +list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/algorithm/diskann/.*") +if(NOT DISKANN_SUPPORTED) + list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/interface/indexes/diskann_index\\.cc") +endif() + cc_library( NAME zvec_core STATIC STRICT PACKED SRCS ${ALL_CORE_SRCS} diff --git a/src/core/algorithm/CMakeLists.txt b/src/core/algorithm/CMakeLists.txt index 20a459052..d200aea82 100644 --- a/src/core/algorithm/CMakeLists.txt +++ b/src/core/algorithm/CMakeLists.txt @@ -7,6 +7,42 @@ cc_directory(flat_sparse) cc_directory(ivf) cc_directory(hnsw) cc_directory(hnsw_sparse) + +if(DISKANN_SUPPORTED) + message(STATUS "build diskann") + cc_directory(diskann) +else() + message(STATUS "not build diskann") + # Empty stub library for unsupported platforms + file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/diskann_stub.cc + "// Stub implementation for unsupported platforms\n" + "// DiskAnn only supports Linux x86_64\n" + "namespace zvec { namespace core { /* empty namespace for compatibility */ } }\n" + ) + + if(MSVC) + # MSVC: STATIC-only stub to avoid creating an empty DLL with no exports + # (MSVC linker fails to produce an import library when there are zero exports) + cc_library( + NAME core_knn_diskann + STATIC STRICT ALWAYS_LINK + SRCS ${CMAKE_CURRENT_BINARY_DIR}/diskann_stub.cc + LIBS core_framework + INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) + else() + cc_library( + NAME core_knn_diskann + STATIC SHARED STRICT ALWAYS_LINK + SRCS ${CMAKE_CURRENT_BINARY_DIR}/diskann_stub.cc + LIBS core_framework + INCS . 
${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) + endif() +endif() + if(RABITQ_SUPPORTED) message(STATUS "BUILD RABITQ") cc_directory(hnsw_rabitq) @@ -19,12 +55,24 @@ else() "namespace zvec { namespace core { /* empty namespace for compatibility */ } }\n" ) - cc_library( - NAME core_knn_hnsw_rabitq - STATIC SHARED STRICT ALWAYS_LINK - SRCS ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc - LIBS core_framework - INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm - VERSION "${PROXIMA_ZVEC_VERSION}" - ) + if(MSVC) + # MSVC: STATIC-only stub to avoid creating an empty DLL with no exports + cc_library( + NAME core_knn_hnsw_rabitq + STATIC STRICT ALWAYS_LINK + SRCS ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc + LIBS core_framework + INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) + else() + cc_library( + NAME core_knn_hnsw_rabitq + STATIC SHARED STRICT ALWAYS_LINK + SRCS ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc + LIBS core_framework + INCS . 
${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) + endif() endif() diff --git a/src/core/algorithm/cluster/cluster_params.h b/src/core/algorithm/cluster/cluster_params.h index 06cfdad4b..0152153ca 100644 --- a/src/core/algorithm/cluster/cluster_params.h +++ b/src/core/algorithm/cluster/cluster_params.h @@ -203,5 +203,17 @@ static const std::string STRATIFIED_TRAINER_AUTOAUNE = static const std::string STRATIFIED_TRAINER_PARAMS_IN_LEVEL_PREFIX = "proxima.stratified.trainer.cluster_params_in_level_"; +static const std::string MULTI_CHUNK_CLUSTER_COUNT = + "proxima.cluster.multi_chunk_cluster.count"; +static const std::string MULTI_CHUNK_CLUSTER_CHUNK_COUNT = + "proxima.cluster.multi_chunk_cluster.chunk_count"; +static const std::string MULTI_CHUNK_CLUSTER_THREAD_COUNT = + "proxima.cluster.multi_chunk_cluster.thread_count"; +static const std::string MULTI_CHUNK_CLUSTER_EPSILON = + "proxima.cluster.multi_chunk_cluster.epsilon"; +static const std::string MULTI_CHUNK_CLUSTER_MAX_ITERATIONS = + "proxima.cluster.multi_chunk_cluster.max_iterations"; +static const std::string MULTI_CHUNK_CLUSTER_MARKOV_CHAIN_LENGTH = + "proxima.cluster.multi_chunk_cluster.markov_chain_length"; } // namespace core } // namespace zvec diff --git a/src/core/algorithm/cluster/multi_chunk_cluster.cc b/src/core/algorithm/cluster/multi_chunk_cluster.cc new file mode 100644 index 000000000..1e8356fa0 --- /dev/null +++ b/src/core/algorithm/cluster/multi_chunk_cluster.cc @@ -0,0 +1,440 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "multi_chunk_cluster.h" +#include + +namespace zvec { +namespace core { + +bool MultiChunkClusterAlgorithm::is_valid(void) const { + if (!features_ || !features_->count()) { + return false; + } + return true; +} + +bool MultiChunkClusterAlgorithm::check_centroids( + const IndexCluster::CentroidList & /*cents*/) { + return true; +} + +void MultiChunkClusterAlgorithm::init_centroids( + size_t count, IndexCluster::CentroidList *out) { + // Just resize, because the get random centroid step is done by cluster_once + out->resize(count); +} + +int MultiChunkClusterAlgorithm::reset(void) { + features_.reset(); + + return 0; +} + +int MultiChunkClusterAlgorithm::cleanup(void) { + features_.reset(); + + return 0; +} + +void MultiChunkClusterAlgorithm::suggest(uint32_t k) { + cluster_count_ = k; +} + +int MultiChunkClusterAlgorithm::update(const ailego::Params ¶ms) { + this->update_params(params); + // algorithm_->reset(cluster_count_); + return 0; +} + +//! 
MultiChunkCluster +int MultiChunkClusterAlgorithm::update_params(const ailego::Params ¶ms) { + params.get(GENERAL_THREAD_COUNT, &thread_count_); + params.get(GENERAL_CLUSTER_COUNT, &cluster_count_); + + params.get(MULTI_CHUNK_CLUSTER_THREAD_COUNT, &thread_count_); + params.get(MULTI_CHUNK_CLUSTER_COUNT, &cluster_count_); + params.get(MULTI_CHUNK_CLUSTER_CHUNK_COUNT, &chunk_count_); + params.get(MULTI_CHUNK_CLUSTER_EPSILON, &epsilon_); + params.get(MULTI_CHUNK_CLUSTER_MAX_ITERATIONS, &max_iterations_); + params.get(MULTI_CHUNK_CLUSTER_MARKOV_CHAIN_LENGTH, &markov_chain_length_); + + return 0; +} + +int MultiChunkClusterAlgorithm::init_distance_func() { + IndexMetric::Pointer metric_{}; + metric_ = IndexFactory::CreateMetric(meta_.metric_name()); + if (!metric_) { + LOG_ERROR("Create metric %s failed.", meta_.metric_name().c_str()); + return IndexError_Unsupported; + } + + int ret = metric_->init(meta_, meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("IndexMetric init failed wit ret %d.", ret); + return ret; + } + + distance_func_ = metric_->distance_matrix(1, 1); + if (!distance_func_) { + LOG_ERROR("DistanceMatrix function is nullptr."); + return IndexError_Unsupported; + } + return 0; +} + +int MultiChunkClusterAlgorithm::do_chunk() { + size_t large_chunk_count = meta_.dimension() % chunk_count_; + size_t base_chunk_dim_ = meta_.dimension() / chunk_count_; + + chunk_dims_.clear(); + + for (size_t i = 0; i < chunk_count_; ++i) { + if (i < large_chunk_count) { + chunk_dims_.push_back(base_chunk_dim_ + 1); + } else { + chunk_dims_.push_back(base_chunk_dim_); + } + } + + chunk_dim_offsets_.clear(); + chunk_dim_offsets_.push_back(0); + for (size_t i = 1; i < chunk_count_; ++i) { + chunk_dim_offsets_.push_back(chunk_dim_offsets_[i - 1] + + chunk_dims_[i - 1]); + } + chunk_dim_offsets_.push_back(meta_.dimension()); + + return 0; +} + +int MultiChunkClusterAlgorithm::init(const IndexMeta &meta, + const ailego::Params ¶ms) { + meta_ = meta; + + int ret = 
update_params(params); + if (ret != 0) { + return ret; + } + + ret = do_chunk(); + if (ret != 0) { + return ret; + } + + ret = init_distance_func(); + if (ret != 0) { + return ret; + } + + return 0; +} + +int MultiChunkClusterAlgorithm::mount(IndexFeatures::Pointer features) { + if (!features) { + return IndexError_InvalidArgument; + } + + if (!features->is_matched(meta_)) { + return IndexError_Mismatch; + } + + auto data_type = meta_.data_type(); + if (data_type != IndexMeta::DataType::DT_FP32 && + data_type != IndexMeta::DataType::DT_FP16) { + LOG_ERROR("Unsupported meta type %u", data_type); + + return IndexError_Unsupported; + } + + features_ = std::move(features); + + return 0; +} + +//! cluster +int MultiChunkClusterAlgorithm::cluster(IndexThreads::Pointer threads, + IndexCluster::CentroidList ¢s) { + if (chunk_count_ == 0) { + LOG_ERROR("Invalid Chunk Count: %u", chunk_count_); + + return IndexError_InvalidArgument; + } + + if (cluster_count_ == 0) { + LOG_ERROR("Invalid cluster Count: %u", cluster_count_); + + return IndexError_InvalidArgument; + } + + if (!threads) { + threads = std::make_shared(thread_count_, false); + if (!threads) { + return IndexError_NoMemory; + } + } + + auto task_group = threads->make_group(); + if (!task_group) { + LOG_ERROR("Failed to create task group"); + return IndexError_Runtime; + } + + cents.clear(); + cents.resize(chunk_count_ * cluster_count_); + + std::atomic finished{0}; + + for (size_t i = 0; i < threads->count(); ++i) { + task_group->submit( + ailego::Closure::New(this, &MultiChunkClusterAlgorithm::do_cluster, i, + threads->count(), ¢s, &finished)); + } + + while (!task_group->is_finished()) { + std::unique_lock lk(mutex_); + cond_.wait_until(lk, std::chrono::system_clock::now() + + std::chrono::seconds(check_interval_secs_)); + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to cluster while waiting finish"); + return errcode_; + } + LOG_INFO("Finish Chunk Count %zu, Finished Percent %.3f%%", 
finished.load(), + finished.load() * 100.0f / chunk_count_); + } + + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to cluster while waiting finish"); + return errcode_; + } + + task_group->wait_finish(); + + return 0; +} + +//! Classify +int MultiChunkClusterAlgorithm::classify( + IndexThreads::Pointer /*threads*/, IndexCluster::CentroidList & /*cents*/) { + return IndexError_Unsupported; +} + +//! Label +int MultiChunkClusterAlgorithm::label(IndexThreads::Pointer threads, + const IndexCluster::CentroidList ¢s, + std::vector *out) { + if (chunk_count_ == 0) { + LOG_ERROR("Invalid Chunk Count: %u", chunk_count_); + + return IndexError_InvalidArgument; + } + + if (cents.empty()) { + LOG_ERROR("The input centroid's list is empty."); + return IndexError_InvalidArgument; + } + + if (!this->check_centroids(cents)) { + LOG_ERROR("The input centroid's list includes some invalid centroids."); + return IndexError_InvalidArgument; + } + + if (!this->is_valid()) { + LOG_ERROR("The cluster is not ready."); + return IndexError_NoReady; + } + + if (cluster_count_ == 0) { + LOG_ERROR("Invalid cluster Count: %u", cluster_count_); + + return IndexError_InvalidArgument; + } + + if (!threads) { + threads = std::make_shared(thread_count_, false); + if (!threads) { + return IndexError_NoMemory; + } + } + + // threads = std::make_shared(1, false); + + auto task_group = threads->make_group(); + if (!task_group) { + LOG_ERROR("Failed to create task group"); + return IndexError_Runtime; + } + + size_t features_count = features_->count(); + out->resize(features_count * chunk_count_); + + std::atomic finished{0}; + + for (size_t i = 0; i < threads->count(); ++i) { + task_group->submit( + ailego::Closure::New(this, &MultiChunkClusterAlgorithm::do_label, i, + threads->count(), cents, out, &finished)); + } + + while (!task_group->is_finished()) { + std::unique_lock lk(mutex_); + cond_.wait_until(lk, std::chrono::system_clock::now() + + 
std::chrono::seconds(check_interval_secs_)); + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to cluster while waiting finish"); + return errcode_; + } + LOG_INFO("Finish label cnt %zu, finished percent %.3f%%", finished.load(), + finished.load() * 100.0f / features_count); + } + + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to cluster while waiting finish"); + return errcode_; + } + + task_group->wait_finish(); + + return 0; +} + +//! Cluster +int MultiChunkCluster::cluster(IndexThreads::Pointer threads, + IndexCluster::CentroidList ¢s) { + return algorithm_->cluster(std::move(threads), cents); +} + +//! Classify +int MultiChunkCluster::classify(IndexThreads::Pointer threads, + IndexCluster::CentroidList ¢s) { + return algorithm_->classify(std::move(threads), cents); +} + +//! Label +int MultiChunkCluster::label(IndexThreads::Pointer threads, + const IndexCluster::CentroidList ¢s, + std::vector *out) { + return algorithm_->label(std::move(threads), cents, out); +} + +//! Update Cluster +int MultiChunkCluster::update(const ailego::Params ¶ms) { + return algorithm_->update(params); +} + +//! Reset Cluster +int MultiChunkCluster::reset(void) { + return algorithm_->reset(); +} + +//! Cleanup Cluster +int MultiChunkCluster::cleanup(void) { + return algorithm_->cleanup(); +} + +//! 
Suggest dividing to K clusters +void MultiChunkCluster::suggest(uint32_t k) { + algorithm_->suggest(k); +} + +int MultiChunkCluster::mount(IndexFeatures::Pointer feats) { + return algorithm_->mount(feats); +} + +int MultiChunkCluster::init(const IndexMeta &meta, + const ailego::Params ¶ms) { + IndexMeta new_meta(meta.data_type(), meta.dimension()); + + if (meta.metric_name() == "Cosine") { + new_meta.set_dimension(meta.dimension() - 1); + new_meta.set_metric("InnerProduct", 0, ailego::Params()); + } + + auto data_type = new_meta.data_type(); + + if (new_meta.metric_name() == "InnerProduct") { + switch (data_type) { + case IndexMeta::DataType::DT_FP16: { + algorithm_.reset( + new (std::nothrow) + MultiChunkNumericalInnerProductAlgorithm); + break; + } + case IndexMeta::DataType::DT_FP32: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalInnerProductAlgorithm); + break; + } + case IndexMeta::DataType::DT_FP64: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalInnerProductAlgorithm); + break; + } + case IndexMeta::DataType::DT_INT8: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalInnerProductAlgorithm); + break; + } + case IndexMeta::DataType::DT_INT16: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalInnerProductAlgorithm); + break; + } + default: { + LOG_ERROR("Unsupported feature types %d.", data_type); + return IndexError_Mismatch; + } + } + } else { + switch (data_type) { + case IndexMeta::DataType::DT_FP16: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalAlgorithm); + break; + } + case IndexMeta::DataType::DT_FP32: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalAlgorithm); + break; + } + case IndexMeta::DataType::DT_FP64: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalAlgorithm); + break; + } + case IndexMeta::DataType::DT_INT8: { + algorithm_.reset(new (std::nothrow) + MultiChunkNumericalAlgorithm); + break; + } + case IndexMeta::DataType::DT_INT16: { + algorithm_.reset(new 
(std::nothrow) + MultiChunkNumericalAlgorithm); + break; + } + default: { + LOG_ERROR("Unsupported feature types %d.", data_type); + return IndexError_Mismatch; + } + } + } + + algorithm_->init(new_meta, params); + + return 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/cluster/multi_chunk_cluster.h b/src/core/algorithm/cluster/multi_chunk_cluster.h new file mode 100644 index 000000000..25496cded --- /dev/null +++ b/src/core/algorithm/cluster/multi_chunk_cluster.h @@ -0,0 +1,468 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include "cluster_params.h" + +namespace zvec { +namespace core { + +//! MultiChunkClusterAlgorithm +class MultiChunkClusterAlgorithm { + public: + typedef std::shared_ptr Pointer; + + //! Constructor + MultiChunkClusterAlgorithm(void) {} + + //! Destructor + virtual ~MultiChunkClusterAlgorithm(void) {} + + //! Initialize Cluster + int init(const IndexMeta &meta, const ailego::Params ¶ms); + + //! Cleanup Cluster + int cleanup(void); + + //! Reset Cluster + int reset(void); + + //! Update Cluster + int update(const ailego::Params ¶ms); + + //! Suggest dividing to K clusters + void suggest(uint32_t k); + + //! Mount features + int mount(IndexFeatures::Pointer feats); + + //! Cluster + int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); + + //! 
Classify + int classify(IndexThreads::Pointer threads, + IndexCluster::CentroidList ¢s); + + //! Label + int label(IndexThreads::Pointer threads, + const IndexCluster::CentroidList ¢s, + std::vector *out); + + + const std::vector &chunk_dims() const { + return chunk_dims_; + } + + const std::vector &chunk_dim_offsets() const { + return chunk_dim_offsets_; + } + + protected: + //! Check Centroids + bool check_centroids(const IndexCluster::CentroidList ¢s); + + //! Test if it is valid + bool is_valid(void) const; + + //! Do chunk + int do_chunk(); + + //! Update parameters + int update_params(const ailego::Params ¶ms); + + int init_distance_func(); + + //! cluster thread + virtual void do_cluster(size_t idx, size_t chunk_step, + IndexCluster::CentroidList *cents, + std::atomic *finished) = 0; + + //! label thread + virtual void do_label(size_t idx, size_t step, + const IndexCluster::CentroidList ¢s, + std::vector *out, + std::atomic *finished) = 0; + + //! Initialize Centroids + void init_centroids(size_t count, IndexCluster::CentroidList *out); + + private: + constexpr static uint32_t kDefaultLogIntervalSecs = 15U; + + protected: + uint32_t cluster_count_{0u}; + uint32_t thread_count_{0u}; + uint32_t chunk_count_{0u}; + uint32_t max_iterations_{20u}; + bool assumption_free_{false}; + uint32_t markov_chain_length_{32}; + double epsilon_{std::numeric_limits::epsilon()}; + + int errcode_{0}; + std::atomic_bool error_{false}; + uint32_t check_interval_secs_{kDefaultLogIntervalSecs}; + std::mutex mutex_{}; + std::condition_variable cond_{}; + + IndexMeta meta_{}; + IndexFeatures::Pointer features_{}; + + std::vector chunk_dims_; + std::vector chunk_dim_offsets_; + + IndexMetric::MatrixDistance distance_func_{nullptr}; +}; + +/*! Numerical cluster algorithm + */ +template +class MultiChunkNumericalAlgorithm : public MultiChunkClusterAlgorithm { + public: + //! 
Type of value + using ValueType = typename std::remove_cv::type; + + // Check supporting type + static_assert(ailego::IsArithmetic::value, + "ValueType must be arithmetic"); + + //! Constructor + MultiChunkNumericalAlgorithm(void) {} + + //! Destructor + ~MultiChunkNumericalAlgorithm(void) {} + + protected: + //! cluster thread + void do_cluster(size_t idx, size_t chunk_step, + IndexCluster::CentroidList *cents, + std::atomic *finished); + + //! label thread + void do_label(size_t idx, size_t step, + const IndexCluster::CentroidList ¢s, + std::vector *out, std::atomic *finished); +}; + +//! cluster thread +template +void MultiChunkNumericalAlgorithm::do_cluster( + size_t idx, size_t chunk_step, IndexCluster::CentroidList *cents, + std::atomic *finished) { + for (size_t chunk = idx; chunk < chunk_count_; chunk += chunk_step) { + auto chunk_dim = chunk_dims_[chunk]; + + ailego::NumericalKmeans algorithm(cluster_count_, + chunk_dim); + + // mount features into algorithm + auto features_count = features_->count(); + + algorithm.feature_matrix_reserve(features_count); + + for (size_t i = 0; i < features_count; ++i) { + auto vec = reinterpret_cast(features_->element(i)); + algorithm.append(vec + chunk_dim_offsets_[chunk], chunk_dim); + } + + IndexThreads::Pointer local_threads = + std::make_shared(1, false); + if (!local_threads) { + error_ = IndexError_NoMemory; + return; + } + + ailego::Kmc2CentroidsGenerator, + IndexThreads> + cent_gen; + cent_gen.set_chain_length(markov_chain_length_); + cent_gen.set_assumption_free(assumption_free_); + cent_gen(&algorithm, *local_threads); + + double cost = 0.0; + + for (uint32_t i = 0; i < max_iterations_; ++i) { + double old_cost, new_epsilon; + old_cost = cost; + + bool result = algorithm.cluster_once(*local_threads, &cost); + if (result != true) { + LOG_ERROR("(%u) Failed to cluster.", i + 1); + errcode_ = -1; + + return; + } + + new_epsilon = std::abs(cost - old_cost); + if (new_epsilon < epsilon_) { + break; + } + } + + auto 
&chunk_cents = algorithm.centroids(); + + for (size_t i = 0; i < chunk_cents.count(); ++i) { + size_t global_cent_idx = chunk * cluster_count_ + i; + + IndexCluster::Centroid *centroid = &(cents->at(global_cent_idx)); + centroid->set_score(algorithm.context().clusters()[i].cost()); + centroid->set_follows(algorithm.context().clusters()[i].count()); + centroid->set_feature(algorithm.centroids()[i], + chunk_dim * meta_.unit_size()); + } + + LOG_INFO("(%zu) Chunk Done. Clusters Count: %zu, Features: %zu, Cost: %f", + chunk, algorithm.centroids().count(), features_->count(), cost); + + (*finished)++; + } + + return; +} + +//! label thread +template +void MultiChunkNumericalAlgorithm::do_label( + size_t idx, size_t step, const IndexCluster::CentroidList ¢s, + std::vector *out, std::atomic *finished) { + for (size_t id = idx; id < features_->count(); id += step) { + const T *feat = reinterpret_cast(features_->element(id)); + + for (size_t chunk = 0; chunk < chunk_count_; ++chunk) { + size_t chunk_dim_offset = chunk_dim_offsets_[chunk]; + size_t chunk_dim = chunk_dims_[chunk]; + + uint32_t sel_index = 0; + float sel_score = std::numeric_limits::max(); + + for (uint32_t cluster = 0; cluster < cluster_count_; ++cluster) { + float score{0.0}; + + distance_func_(cents[chunk * cluster_count_ + cluster].feature(), + feat + chunk_dim_offset, chunk_dim, &score); + + if (score < sel_score) { + sel_score = score; + sel_index = cluster; + } + } + + (*out)[id * chunk_count_ + chunk] = static_cast(sel_index); + } + + (*finished)++; + } +} + +/*! Inner Product Cluster Algorithm + */ +template +class MultiChunkNumericalInnerProductAlgorithm + : public MultiChunkClusterAlgorithm { + public: + //! Type of value + using ValueType = typename std::remove_cv::type; + + // Check supporting type + static_assert(ailego::IsArithmetic::value, + "ValueType must be arithmetic"); + + //! Constructor + MultiChunkNumericalInnerProductAlgorithm(void) {} + + //! 
Destructor + ~MultiChunkNumericalInnerProductAlgorithm(void) {} + + protected: + //! cluster thread + void do_cluster(size_t idx, size_t chunk_step, + IndexCluster::CentroidList *cents, + std::atomic *finished); + + //! label thread + void do_label(size_t idx, size_t chunk_step, + const IndexCluster::CentroidList ¢s, + std::vector *out, std::atomic *finished); +}; + +//! cluster thread +template +void MultiChunkNumericalInnerProductAlgorithm::do_cluster( + size_t idx, size_t chunk_step, IndexCluster::CentroidList *cents, + std::atomic *finished) { + for (size_t chunk = idx; chunk < chunk_count_; chunk += chunk_step) { + auto chunk_dim = chunk_dims_[chunk]; + + ailego::NumericalInnerProductKmeans algorithm( + cluster_count_, chunk_dim); + + // mount features into algorithm + auto features_count = features_->count(); + + algorithm.feature_matrix_reserve(features_count); + + for (size_t i = 0; i < features_count; ++i) { + auto vec = reinterpret_cast(features_->element(i)); + algorithm.append(vec + chunk_dim_offsets_[chunk], chunk_dim); + } + + IndexThreads::Pointer local_threads = + std::make_shared(1, false); + if (!local_threads) { + error_ = IndexError_NoMemory; + return; + } + + ailego::Kmc2CentroidsGenerator< + ailego::NumericalInnerProductKmeans, IndexThreads> + cent_gen; + cent_gen.set_chain_length(markov_chain_length_); + cent_gen.set_assumption_free(assumption_free_); + cent_gen(&algorithm, *local_threads); + + double cost = 0.0; + + for (uint32_t i = 0; i < max_iterations_; ++i) { + double old_cost, new_epsilon; + old_cost = cost; + + bool result = algorithm.cluster_once(*local_threads, &cost); + if (result != true) { + LOG_ERROR("(%u) Failed to cluster.", i + 1); + errcode_ = -1; + + return; + } + + new_epsilon = std::abs(cost - old_cost); + if (new_epsilon < epsilon_) { + break; + } + } + + auto &chunk_cents = algorithm.centroids(); + + for (size_t i = 0; i < chunk_cents.count(); ++i) { + size_t global_cent_idx = chunk * cluster_count_ + i; + + 
IndexCluster::Centroid *centroid = &(cents->at(global_cent_idx)); + centroid->set_score(algorithm.context().clusters()[i].cost()); + centroid->set_follows(algorithm.context().clusters()[i].count()); + centroid->set_feature(algorithm.centroids()[i], + chunk_dim * meta_.unit_size()); + } + + LOG_INFO("(%zu) Chunk Done. Clusters Count: %zu, Features: %zu, Cost: %f", + chunk, algorithm.centroids().count(), features_->count(), cost); + + (*finished)++; + } + + return; +} + +//! label thread +template +void MultiChunkNumericalInnerProductAlgorithm::do_label( + size_t idx, size_t step, const IndexCluster::CentroidList ¢s, + std::vector *out, std::atomic *finished) { + for (size_t id = idx; id < features_->count(); id += step) { + const T *feat = reinterpret_cast(features_->element(id)); + + for (size_t chunk = 0; chunk < chunk_count_; ++chunk) { + size_t chunk_dim_offset = chunk_dim_offsets_[chunk]; + size_t chunk_dim = chunk_dims_[chunk]; + + uint32_t sel_index = 0; + float sel_score = std::numeric_limits::max(); + + for (uint32_t cluster = 0; cluster < cluster_count_; ++cluster) { + float score{0.0}; + + distance_func_(cents[chunk * cluster_count_ + cluster].feature(), + feat + chunk_dim_offset, chunk_dim, &score); + + if (score < sel_score) { + sel_score = score; + sel_index = cluster; + } + } + + (*out)[id * chunk_count_ + chunk] = static_cast(sel_index); + } + + (*finished)++; + } +} + +//! MultiChunkCluster +class MultiChunkCluster { + public: + std::shared_ptr Pointer; + + //! Constructor + MultiChunkCluster(void) {} + + //! Destructor + ~MultiChunkCluster(void) {} + + //! Initialize Cluster + int init(const IndexMeta &meta, const ailego::Params ¶ms); + + //! Cleanup Cluster + int cleanup(void); + + //! Reset Cluster + int reset(void); + + //! Update Cluster + int update(const ailego::Params ¶ms); + + //! Suggest dividing to K clusters + void suggest(uint32_t k); + + //! Mount features + int mount(IndexFeatures::Pointer feats); + + //! 
Cluster + int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); + + //! Classify + int classify(IndexThreads::Pointer threads, + IndexCluster::CentroidList ¢s); + + //! Label + int label(IndexThreads::Pointer threads, + const IndexCluster::CentroidList ¢s, + std::vector *out); + + const std::vector &chunk_dims() const { + return algorithm_->chunk_dims(); + } + + const std::vector &chunk_dim_offsets() const { + return algorithm_->chunk_dim_offsets(); + } + + protected: + MultiChunkClusterAlgorithm::Pointer algorithm_{}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/CMakeLists.txt b/src/core/algorithm/diskann/CMakeLists.txt new file mode 100644 index 000000000..6f1229b80 --- /dev/null +++ b/src/core/algorithm/diskann/CMakeLists.txt @@ -0,0 +1,28 @@ +## +## Copyright (C) The Software Authors. All rights reserved. +## +## \file CMakeLists.txt +## \author Hechong.xyf +## \date Oct 2019 +## \version 1.0 +## \brief Detail cmake build script +## + +include(${CMAKE_SOURCE_DIR}/cmake/bazel.cmake) + +file(GLOB_RECURSE ALL_SRCS *.cc *.c) + +set(CORE_KNN_DISKANN_LIBS zvec_ailego core_framework core_knn_cluster) + +if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386") + list(APPEND CORE_KNN_DISKANN_LIBS aio) +endif() + +cc_library( + NAME core_knn_diskann + STATIC SHARED STRICT ALWAYS_LINK + SRCS *.cc + LIBS ${CORE_KNN_DISKANN_LIBS} + INCS . 
${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" +) \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_algorithm.cc b/src/core/algorithm/diskann/diskann_algorithm.cc new file mode 100644 index 000000000..e19e5bf1a --- /dev/null +++ b/src/core/algorithm/diskann/diskann_algorithm.cc @@ -0,0 +1,333 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_algorithm.h" +#include +#include +#include "diskann_util.h" + +namespace zvec { +namespace core { + +DiskAnnAlgorithm::DiskAnnAlgorithm(DiskAnnEntity &entity, uint32_t max_degree) + : entity_(entity), max_degree_(max_degree), lock_pool_(kLockCnt) {} + +std::vector DiskAnnAlgorithm::get_init_ids(DiskAnnContext *ctx) { + const auto &entity = ctx->get_entity(); + + std::vector init_ids; + + init_ids.emplace_back(entity.medoid()); + + return init_ids; +} + +int DiskAnnAlgorithm::add_node(diskann_id_t id, DiskAnnContext *ctx) { + const void *vec = entity_.get_vector(id); + + ctx->reset_query(vec); + + std::vector pruned_list; + + int ret = search_neighbor_and_prune(id, pruned_list, ctx); + if (ret != 0) { + return ret; + } + + uint32_t lock_idx = id & kLockMask; + lock_pool_[lock_idx].lock(); + entity_.set_neighbors(id, pruned_list); + lock_pool_[lock_idx].unlock(); + + ret = inter_insert(id, pruned_list, ctx); + + return 0; +} + +int DiskAnnAlgorithm::prune_node(diskann_id_t 
id, DiskAnnContext *ctx) { + DistCalculator &dc = ctx->dist_calculator(); + + auto neighbors = entity_.get_neighbors(id); + + if (neighbors.first > max_degree_) { + std::set dummy_visited; + std::vector dummy_pool(0); + std::vector new_out_neighbors; + + for (size_t i = 0; i < neighbors.first; ++i) { + diskann_id_t node_id = (neighbors.second)[i]; + + auto itr = dummy_visited.find(node_id); + if (itr == dummy_visited.end() && node_id != id) { + float dist = dc.dist(id, node_id); + + dummy_pool.emplace_back(Neighbor(node_id, dist)); + dummy_visited.insert(node_id); + } + } + + prune_neighbors(id, dummy_pool, new_out_neighbors, ctx); + + uint32_t lock_idx = id & kLockMask; + lock_pool_[lock_idx].lock(); + entity_.set_neighbors(id, new_out_neighbors); + lock_pool_[lock_idx].unlock(); + } + + return 0; +} + +int DiskAnnAlgorithm::inter_insert(diskann_id_t id, + std::vector &pruned_list, + DiskAnnContext *ctx) { + DistCalculator &dc = ctx->dist_calculator(); + + for (auto &des : pruned_list) { + std::vector new_neighbors; + bool need_prune = false; + + uint32_t lock_idx = des & kLockMask; + lock_pool_[lock_idx].lock(); + + auto neighbors = entity_.get_neighbors(des); + + bool found = false; + for (size_t i = 0; i < neighbors.first; ++i) { + if ((neighbors.second)[i] == id) { + found = true; + break; + } + } + + if (!found) { + if (neighbors.first < + static_cast(DiskAnnEntity::kDefaultGraphSlackFactor * + max_degree_)) { + entity_.add_neighbor(des, id); + need_prune = false; + } else { + new_neighbors.resize(neighbors.first + 1); + memcpy(&new_neighbors[0], neighbors.second, + sizeof(diskann_id_t) * neighbors.first); + + new_neighbors[neighbors.first] = id; + + need_prune = true; + } + } + + lock_pool_[lock_idx].unlock(); + + if (need_prune) { + std::set new_visited; + std::vector new_pool(0); + + size_t reserve_size = static_cast(std::ceil( + 1.05 * DiskAnnEntity::kDefaultGraphSlackFactor * max_degree_)); + + new_pool.reserve(reserve_size); + + for (auto node_id : 
new_neighbors) { + if (new_visited.find(node_id) == new_visited.end() && node_id != des) { + float dist = dc.dist(des, node_id); + new_pool.emplace_back(Neighbor(node_id, dist)); + new_visited.insert(node_id); + } + } + + std::vector new_pruned_neighbors; + prune_neighbors(des, new_pool, new_pruned_neighbors, ctx); + + lock_idx = des & kLockMask; + lock_pool_[lock_idx].lock(); + entity_.set_neighbors(des, new_pruned_neighbors); + lock_pool_[lock_idx].unlock(); + } + } + + return 0; +} + +int DiskAnnAlgorithm::iterate_to_fixed_point( + const std::vector &init_ids, DiskAnnContext *ctx) { + DistCalculator &dc = ctx->dist_calculator(); + std::vector &expanded_nodes = ctx->expanded_nodes(); + NeighborPriorityQueue &best_list_nodes = ctx->best_list_nodes(); + VisitFilter &visit = ctx->visit_filter(); + + best_list_nodes.reserve(ctx->list_size()); + + for (auto id : init_ids) { + const void *vec = entity_.get_vector(id); + + float distance = dc.dist(vec); + + Neighbor nn = Neighbor(id, distance); + best_list_nodes.insert(nn); + } + + while (best_list_nodes.has_unexpanded_node()) { + auto neighbor = best_list_nodes.closest_unexpanded(); + auto node_id = neighbor.id; + + expanded_nodes.emplace_back(neighbor); + + uint32_t lock_idx = node_id & kLockMask; + + lock_pool_[lock_idx].lock(); + auto neighbors = entity_.get_neighbors(node_id); + + std::vector id_scratch; + + for (size_t i = 0; i < neighbors.first; ++i) { + diskann_id_t neighbor_id = (neighbors.second)[i]; + + if (!visit.visited(neighbor_id)) { + id_scratch.push_back(neighbor_id); + + visit.set_visited(neighbor_id); + } + } + lock_pool_[lock_idx].unlock(); + + for (size_t i = 0; i < id_scratch.size(); ++i) { + diskann_id_t id = id_scratch[i]; + + const void *vec = entity_.get_vector(id); + float dist = dc.dist(vec); + + best_list_nodes.insert(Neighbor(id, dist)); + } + } + + return 0; +} + +int DiskAnnAlgorithm::occlude_list(diskann_id_t id, std::vector &pool, + std::vector &result, + DiskAnnContext *ctx) { + if 
(pool.size() == 0) return 0; + + DistCalculator &dc = ctx->dist_calculator(); + + ailego_assert(std::is_sorted(pool.begin(), pool.end())); + ailego_assert(result.size() == 0); + + if (pool.size() > max_candidate_size_) { + pool.resize(max_candidate_size_); + } + + std::vector &occlude_factor = ctx->occlude_factor(); + + occlude_factor.clear(); + occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); + + float cur_alpha = 1; + while (cur_alpha <= alpha_ && result.size() < max_degree_) { + for (auto iter = pool.begin(); + result.size() < max_degree_ && iter != pool.end(); ++iter) { + if (occlude_factor[iter - pool.begin()] > cur_alpha) { + continue; + } + + occlude_factor[iter - pool.begin()] = std::numeric_limits::max(); + + if (iter->id != id) { + result.push_back(iter->id); + } + + for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++) { + auto t = iter2 - pool.begin(); + if (occlude_factor[t] > alpha_) { + continue; + } + + float djk = dc.dist(iter2->id, iter->id); + + if (true) { + occlude_factor[t] = + (djk == 0) ? 
std::numeric_limits::max() + : std::max(occlude_factor[t], iter2->distance / djk); + } + } + } + cur_alpha *= 1.2f; + } + + return 0; +} + +int DiskAnnAlgorithm::prune_neighbors(diskann_id_t id, + std::vector &pool, + std::vector &pruned_list, + DiskAnnContext *ctx) { + if (pool.size() == 0) { + pruned_list.clear(); + return 0; + } + + std::sort(pool.begin(), pool.end()); + + pruned_list.clear(); + pruned_list.reserve(max_degree_); + + occlude_list(id, pool, pruned_list, ctx); + + ailego_assert(pruned_list.size() <= max_degree_); + + if (saturate_graph_ && alpha_ > 1) { + for (const auto &node : pool) { + if (pruned_list.size() >= max_degree_) { + break; + } + + if ((std::find(pruned_list.begin(), pruned_list.end(), node.id) == + pruned_list.end()) && + node.id != id) { + pruned_list.push_back(node.id); + } + } + } + + return 0; +} + +int DiskAnnAlgorithm::search_neighbor_and_prune( + diskann_id_t id, std::vector &pruned_list, + DiskAnnContext *ctx) { + const std::vector init_ids = get_init_ids(ctx); + + int ret = iterate_to_fixed_point(init_ids, ctx); + if (ret != 0) { + return ret; + } + + auto &pool = ctx->expanded_nodes(); + for (uint32_t i = 0; i < pool.size(); i++) { + if (pool[i].id == id) { + pool.erase(pool.begin() + i); + i--; + } + } + + ret = prune_neighbors(id, pool, pruned_list, ctx); + if (ret != 0) { + return ret; + } + + return 0; +} + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_algorithm.h b/src/core/algorithm/diskann/diskann_algorithm.h new file mode 100644 index 000000000..75c610850 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_algorithm.h @@ -0,0 +1,66 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "diskann_context.h" + +namespace zvec { +namespace core { + +class DiskAnnAlgorithm { + public: + typedef std::unique_ptr UPointer; + + public: + DiskAnnAlgorithm(DiskAnnEntity &entity, uint32_t max_degree); + + public: + int add_node(diskann_id_t id, DiskAnnContext *ctx); + int prune_node(diskann_id_t id, DiskAnnContext *ctx); + + private: + int search_neighbor_and_prune(diskann_id_t id, + std::vector &pruned_list, + DiskAnnContext *ctx); + int iterate_to_fixed_point(const std::vector &init_ids, + DiskAnnContext *ctx); + int prune_neighbors(diskann_id_t id, std::vector &pool, + std::vector &pruned_list, + DiskAnnContext *ctx); + int inter_insert(diskann_id_t id, std::vector &pruned_list, + DiskAnnContext *ctx); + int occlude_list(diskann_id_t id, std::vector &pool, + std::vector &result, DiskAnnContext *ctx); + + std::vector get_init_ids(DiskAnnContext *ctx); + + private: + static constexpr uint32_t kLockCnt{1U << 16}; + static constexpr uint32_t kLockMask{kLockCnt - 1U}; + + DiskAnnEntity &entity_; + + uint32_t max_degree_{DiskAnnEntity::kDefaultMaxDegree}; + uint32_t max_candidate_size_{DiskAnnEntity::kDefaultMaxOcclusionSize}; + + std::vector lock_pool_{}; + + float alpha_{DiskAnnEntity::kDefaultAlpha}; + bool saturate_graph_{true}; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_builder.cc b/src/core/algorithm/diskann/diskann_builder.cc new file mode 100644 index 000000000..d82be13bc --- /dev/null +++ 
b/src/core/algorithm/diskann/diskann_builder.cc @@ -0,0 +1,647 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_builder.h" +#include +#include +#include +#include +#include +#include "diskann_context.h" +#include "diskann_params.h" + +namespace zvec { +namespace core { + +int DiskAnnBuilder::init(const IndexMeta &meta, const ailego::Params ¶ms) { + LOG_INFO("Begin DiskAnnBuilder::init"); + + params.get(PARAM_DISKANN_BUILDER_MAX_DEGREE, &max_degree_); + params.get(PARAM_DISKANN_BUILDER_LIST_SIZE, &list_size_); + params.get(PARAM_DISKANN_BUILDER_THREAD_COUNT, &build_thread_count_); + + if (build_thread_count_ == 0) { + build_thread_count_ = std::thread::hardware_concurrency(); + } + + if (build_thread_count_ > std::thread::hardware_concurrency()) { + LOG_WARN("Build thread count [%s] greater than cpu cores %u", + PARAM_DISKANN_BUILDER_THREAD_COUNT.c_str(), + std::thread::hardware_concurrency()); + } + + if (params.has(PARAM_DISKANN_BUILDER_MAX_PQ_CHUNK_NUM)) { + uint32_t max_pq_chunk_num{0}; + params.get(PARAM_DISKANN_BUILDER_MAX_PQ_CHUNK_NUM, &max_pq_chunk_num); + if (max_pq_chunk_num > meta.dimension()) { + LOG_ERROR( + "PQ Chunk Num larger than dimension, PQ Chunk Num: %d, Dimension: %d", + max_pq_chunk_num, meta.dimension()); + return IndexError_InvalidArgument; + } + + max_pq_chunk_num_ = max_pq_chunk_num; + } + + if (params.has(PARAM_DISKANN_BUILDER_MEMORY_LIMIT)) { + 
params.get(PARAM_DISKANN_BUILDER_MEMORY_LIMIT, &memory_limit_); + if (memory_limit_ <= 0) { + LOG_ERROR("Invalid memory limit: %lf", memory_limit_); + return IndexError_InvalidArgument; + } + + memory_limit_set_ = true; + } + + if (params.has(PARAM_DISKANN_BUILDER_MAX_TRAIN_SAMPLE_COUNT)) { + params.get(PARAM_DISKANN_BUILDER_MAX_TRAIN_SAMPLE_COUNT, + &max_train_sample_count_); + } + + if (params.has(PARAM_DISKANN_BUILDER_TRAIN_SAMPLE_RATIO)) { + params.get(PARAM_DISKANN_BUILDER_TRAIN_SAMPLE_RATIO, &train_sample_ratio_); + } + + raw_meta_ = meta; + + build_meta_ = meta; + if (meta.metric_name() == "InnerProduct") { + build_meta_.set_metric("SquaredEuclidean", 0, ailego::Params()); + } else if (meta.metric_name() == "Cosine") { + build_meta_.set_metric("SquaredEuclidean", 0, ailego::Params()); + + if (meta.data_type() == IndexMeta::DataType::DT_FP32) { + build_meta_.set_dimension(meta.dimension() - 1); + } else { + build_meta_.set_dimension(meta.dimension() - 2); + } + } + + metric_ = IndexFactory::CreateMetric(build_meta_.metric_name()); + if (!metric_) { + LOG_ERROR("CreateMetric failed, name: %s", + build_meta_.metric_name().c_str()); + return IndexError_NoExist; + } + + int ret = metric_->init(build_meta_, build_meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("IndexMeasure init failed, ret=%d", ret); + return ret; + } + + raw_meta_.set_builder("DiskAnnBuilder", DiskAnnEntity::kRevision, params); + + ret = entity_.init(meta, max_degree_, list_size_, memory_limit_, + build_thread_count_); + if (ret != 0) { + return ret; + } + + algo_ = + DiskAnnAlgorithm::UPointer(new DiskAnnAlgorithm(entity_, max_degree_)); + + trainer_ = + DiskAnnPqTrainer::UPointer(new DiskAnnPqTrainer(max_train_sample_count_)); + + state_ = BUILD_STATE_INITED; + + return 0; +} + +int DiskAnnBuilder::cleanup(void) { + LOG_INFO("Begin DiskAnnBuilder::cleanup"); + + LOG_INFO("End DiskAnnBuilder::cleanup"); + + return 0; +} + +int DiskAnnBuilder::calculate_entry_point() { + std::string 
centroid; + + size_t dimension = build_meta_.dimension(); + + if (build_meta_.data_type() != IndexMeta::DataType::DT_FP32 && + build_meta_.data_type() != IndexMeta::DataType::DT_FP16) { + LOG_ERROR("Data type not supported"); + return IndexError_InvalidArgument; + } + + centroid.resize(dimension * sizeof(float)); + + float *centroid_data_ptr = reinterpret_cast(¢roid[0]); + for (size_t i = 0; i < dimension; i++) { + centroid_data_ptr[i] = 0; + } + + switch (build_meta_.data_type()) { + case IndexMeta::DataType::DT_FP32: + for (size_t id = 0; id < entity_.doc_cnt(); id++) { + const float *data_ptr = + reinterpret_cast(entity_.get_vector(id)); + + for (size_t i = 0; i < dimension; i++) { + centroid_data_ptr[i] += data_ptr[i]; + } + } + break; + case IndexMeta::DataType::DT_FP16: + for (size_t id = 0; id < entity_.doc_cnt(); id++) { + const ailego::Float16 *data_ptr = + reinterpret_cast(entity_.get_vector(id)); + + for (size_t i = 0; i < dimension; i++) { + centroid_data_ptr[i] += data_ptr[i]; + } + } + break; + default: + return IndexError_Unsupported; + } + + for (size_t i = 0; i < dimension; i++) { + centroid_data_ptr[i] /= entity_.doc_cnt(); + } + + // compute all to one distance + diskann_id_t medoid_id = kInvalidId; + float min_dist = std::numeric_limits::max(); + + switch (build_meta_.data_type()) { + case IndexMeta::DataType::DT_FP32: + for (size_t id = 0; id < entity_.doc_cnt(); id++) { + const float *data_ptr = + reinterpret_cast(entity_.get_vector(id)); + + float dist = 0; + for (size_t i = 0; i < dimension; i++) { + float diff = (centroid_data_ptr[i] - data_ptr[i]) * + (centroid_data_ptr[i] - data_ptr[i]); + dist += diff; + } + + if (dist < min_dist) { + min_dist = dist; + medoid_id = id; + } + } + break; + case IndexMeta::DataType::DT_FP16: + for (size_t id = 0; id < entity_.doc_cnt(); id++) { + const ailego::Float16 *data_ptr = + reinterpret_cast(entity_.get_vector(id)); + + float dist = 0; + for (size_t i = 0; i < dimension; i++) { + float diff = 
(centroid_data_ptr[i] - data_ptr[i]) * + (centroid_data_ptr[i] - data_ptr[i]); + dist += diff; + } + + if (dist < min_dist) { + min_dist = dist; + medoid_id = id; + } + } + break; + default: + return IndexError_Unsupported; + } + + (*entity_.mutable_medoid()) = medoid_id; + + LOG_INFO("Medroid Calculation Done. ID: %zu", (size_t)medoid_id); + + return 0; +} + +int DiskAnnBuilder::calculate_pq_chunk_num() { + size_t doc_cnt = holder_->count(); + if (doc_cnt == 0) { + LOG_ERROR("Invalid Input. Empty Vecs."); + + return IndexError_InvalidLength; + } + + if (memory_limit_set_) { + size_t memory_limit_bytes = get_memory_in_bytes(memory_limit_); + size_t pq_chunk_num = std::floor(memory_limit_bytes / doc_cnt); + if (pq_chunk_num <= 0) { + LOG_ERROR("Insufficient memory limit for vec, memory: %zu, vec num: %zu", + memory_limit_bytes, doc_cnt); + return IndexError_InvalidArgument; + } + } + + pq_chunk_num_ = + pq_chunk_num_ < max_pq_chunk_num_ ? pq_chunk_num_ : max_pq_chunk_num_; + if (pq_chunk_num_ > build_meta_.dimension()) { + LOG_ERROR("PQ Chunk Num is more than dimension, chunk num: %u, dim: %u", + pq_chunk_num_, build_meta_.dimension()); + return IndexError_InvalidArgument; + } + + if (pq_chunk_num_ == kDefaultPqChunkNum) { + pq_chunk_num_ = build_meta_.dimension() / 2; + LOG_INFO( + "No Chunk Num input. 
Quantizing %u dimension data into %u dimension.", + build_meta_.dimension(), pq_chunk_num_); + } + + LOG_INFO("Quantizing %u dimension data into %u bytes.", + build_meta_.dimension(), pq_chunk_num_); + + return 0; +} + +int DiskAnnBuilder::build_internal(IndexThreads::Pointer threads) { + auto task_group = threads->make_group(); + if (!task_group) { + LOG_ERROR("Failed to create task group"); + return IndexError_Runtime; + } + + std::atomic finished{0}; + for (size_t i = 0; i < threads->count(); ++i) { + task_group->submit(ailego::Closure ::New(this, &DiskAnnBuilder::do_build, i, + threads->count(), &finished)); + } + + while (!task_group->is_finished()) { + std::unique_lock lk(mutex_); + cond_.wait_until(lk, std::chrono::system_clock::now() + + std::chrono::seconds(check_interval_secs_)); + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to build index while waiting finish"); + return errcode_; + } + LOG_INFO("Built cnt %zu, finished percent %.3f%%", (size_t)finished.load(), + finished.load() * 100.0f / entity_.doc_cnt()); + } + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to build index while waiting finish"); + return errcode_; + } + task_group->wait_finish(); + + return 0; +} + +int DiskAnnBuilder::prune_internal(IndexThreads::Pointer threads) { + auto task_group = threads->make_group(); + if (!task_group) { + LOG_ERROR("Failed to create task group"); + return IndexError_Runtime; + } + + std::atomic finished{0}; + for (size_t i = 0; i < threads->count(); ++i) { + task_group->submit(ailego::Closure ::New(this, &DiskAnnBuilder::do_prune, i, + threads->count(), &finished)); + } + + while (!task_group->is_finished()) { + std::unique_lock lk(mutex_); + cond_.wait_until(lk, std::chrono::system_clock::now() + + std::chrono::seconds(check_interval_secs_)); + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to purne index while waiting finish"); + return errcode_; + } + LOG_INFO("Prune cnt %zu, finished percent 
%.3f%%", (size_t)finished.load(), + finished.load() * 100.0f / entity_.doc_cnt()); + } + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to prune index while waiting finish"); + return errcode_; + } + task_group->wait_finish(); + + return 0; +} + +int DiskAnnBuilder::train_quantized_data(IndexThreads::Pointer threads) { + LOG_INFO("Starting Train: Chunk Num: %u", pq_chunk_num_); + + ailego::ElapsedTime timer; + int ret = trainer_->train_quantized_data( + threads, holder_, build_meta_, entity_.pq_full_pivot_data(), + entity_.pq_centroid(), entity_.pq_chunk_offsets(), pq_chunk_num_); + if (ret != 0) { + LOG_ERROR("Train Quantized Data Error, ret=%d", ret); + return ret; + } + + size_t pq_time = timer.milli_seconds(); + LOG_INFO("Train Quantized Data Done, time: %zu ms", pq_time); + + (*entity_.mutable_pq_meta()).full_pivot_data_size = + entity_.pq_full_pivot_data().size(); + (*entity_.mutable_pq_meta()).centroid_data_size = + entity_.pq_centroid().size(); + (*entity_.mutable_pq_meta()).chunk_num = pq_chunk_num_; + + return 0; +} + +int DiskAnnBuilder::generate_quantized_data(IndexThreads::Pointer threads) { + LOG_INFO("Starting PQ Generate: Query Memory Limit: %lf, Chunk Num: %u", + memory_limit_, pq_chunk_num_); + + ailego::ElapsedTime timer; + int ret = trainer_->generate_quantized_data( + threads, holder_, build_meta_, entity_.pq_centroid(), + entity_.block_compressed_data(), pq_chunk_num_); + if (ret != 0) { + LOG_ERROR("Generate Quantized Data Error, ret=%d", ret); + return ret; + } + + size_t pq_time = timer.milli_seconds(); + LOG_INFO("Generate Quantized Data Done, time: %zu ms", pq_time); + + return 0; +} + +void DiskAnnBuilder::do_build(uint64_t idx, size_t step_size, + std::atomic *finished) { + AILEGO_DEFER([&]() { + std::lock_guard latch(mutex_); + cond_.notify_one(); + }); + + DiskAnnContext *ctx = new (std::nothrow) DiskAnnContext( + build_meta_, metric_, + std::shared_ptr(&entity_, [](DiskAnnEntity *) {})); + + if (ailego_unlikely(ctx 
== nullptr)) { + if (!error_.exchange(true)) { + LOG_ERROR("Failed to create context"); + errcode_ = IndexError_NoMemory; + } + return; + } + + ctx->init(DiskAnnContext::kBuilderContext, max_degree_, pq_chunk_num_, + build_meta_.element_size()); + ctx->set_list_size(list_size_); + + DiskAnnContext::Pointer auto_ptr(ctx); + for (uint64_t id = idx; id < entity_.doc_cnt(); id += step_size) { + ctx->reset_query(entity_.get_vector(id)); + int ret = algo_->add_node(id, ctx); + if (ailego_unlikely(ret != 0)) { + if (!error_.exchange(true)) { + LOG_ERROR("DiskAnn graph add node failed"); + errcode_ = ret; + } + return; + } + ctx->clear(); + (*finished)++; + } +} + +void DiskAnnBuilder::do_prune(uint64_t idx, size_t step_size, + std::atomic *finished) { + AILEGO_DEFER([&]() { + std::lock_guard latch(mutex_); + cond_.notify_one(); + }); + + DiskAnnContext *ctx = new (std::nothrow) DiskAnnContext( + build_meta_, metric_, + std::shared_ptr(&entity_, [](DiskAnnEntity *) {})); + + if (ailego_unlikely(ctx == nullptr)) { + if (!error_.exchange(true)) { + LOG_ERROR("Failed to create context"); + errcode_ = IndexError_NoMemory; + } + return; + } + + ctx->init(DiskAnnContext::kBuilderContext, max_degree_, pq_chunk_num_, + build_meta_.element_size()); + ctx->set_list_size(list_size_); + + DiskAnnContext::Pointer auto_ptr(ctx); + for (uint64_t id = idx; id < entity_.doc_cnt(); id += step_size) { + ctx->reset_query(entity_.get_vector(id)); + int ret = algo_->prune_node(id, ctx); + if (ailego_unlikely(ret != 0)) { + if (!error_.exchange(true)) { + LOG_ERROR("DiskAnn graph add node failed"); + errcode_ = ret; + } + return; + } + ctx->clear(); + (*finished)++; + } +} + +int DiskAnnBuilder::train(const IndexTrainer::Pointer & /*trainer*/) { + if (state_ != BUILD_STATE_INITED) { + LOG_ERROR("Init the builder before DiskAnnBuilder::train"); + return IndexError_NoReady; + } + + LOG_INFO("Begin DiskAnnBuilder::train by trainer"); + + stats_.set_trained_count(0UL); + 
stats_.set_trained_costtime(0UL); + state_ = BUILD_STATE_TRAINED; + + LOG_INFO("End DiskAnnBuilder::train by trainer"); + + return 0; +} + +int DiskAnnBuilder::train(IndexThreads::Pointer threads, + IndexHolder::Pointer holder) { + if (state_ != BUILD_STATE_INITED) { + LOG_ERROR("Init the builder before DiskAnnBuilder::train"); + return IndexError_NoReady; + } + + LOG_INFO("Begin DiskAnnBuilder::train"); + + auto start_time = ailego::Monotime::MilliSeconds(); + + holder_ = holder; + + LOG_INFO("Start to calculate chunk num"); + int ret = calculate_pq_chunk_num(); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + if (!threads) { + threads = + std::make_shared(build_thread_count_, false); + if (!threads) { + return IndexError_NoMemory; + } + } + + ret = train_quantized_data(threads); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + stats_.set_trained_count(holder_->count()); + + stats_.set_trained_costtime(ailego::Monotime::MilliSeconds() - start_time); + + state_ = BUILD_STATE_TRAINED; + + holder_.reset(); + + LOG_INFO("End DiskAnnBuilder::train"); + + return 0; +} + +int DiskAnnBuilder::do_norm(const void *data_ptr, std::string *norm_data) { + float norm_pt = std::numeric_limits::epsilon(); + + const float *float_data_ptr = reinterpret_cast(data_ptr); + + norm_data->resize(build_meta_.dimension() * sizeof(float)); + float *output_buf = reinterpret_cast(&((*norm_data)[0])); + + for (uint32_t dim = 0; dim < build_meta_.dimension(); dim++) { + norm_pt += *(float_data_ptr + dim) * *(float_data_ptr + dim); + } + norm_pt = std::sqrt(norm_pt); + + for (uint32_t dim = 0; dim < build_meta_.dimension(); dim++) { + *(output_buf + dim) = *(float_data_ptr + dim) / norm_pt; + } + + return 0; +} + +int DiskAnnBuilder::build(IndexThreads::Pointer threads, + IndexHolder::Pointer holder) { + LOG_INFO("Start DiskAnnBuilder::build"); + + auto start_time = ailego::Monotime::MilliSeconds(); + + holder_ = holder; + + if (!threads) { + threads = + 
std::make_shared(build_thread_count_, false); + if (!threads) { + return IndexError_NoMemory; + } + } + + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator for holder failed"); + return IndexError_Runtime; + } + + int ret = entity_.reserve_space(holder->count()); + + error_ = false; + while (iter->is_valid()) { + ret = entity_.add_vector(iter->key(), iter->data()); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + iter->next(); + } + + LOG_INFO("Finished saving vector"); + + LOG_INFO("Start to calculate entrypoint"); + ret = calculate_entry_point(); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + LOG_INFO("Start to build vamana graph"); + // auto test_threads = std::make_shared(1, + // false); + // ret = build_internal(test_threads); + ret = build_internal(threads); + if (ret != 0) { + return ret; + } + + LOG_INFO("Start final cleanup.."); + // ret = prune_internal(test_threads); + ret = prune_internal(threads); + if (ret != 0) { + return ret; + } + + LOG_INFO("Start to generate quantized data"); + ret = generate_quantized_data(threads); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + state_ = BUILD_STATE_BUILT; + + stats_.set_built_count(entity_.doc_cnt()); + stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time); + + LOG_INFO("End DiskAnnBuilder::build"); + + return 0; +} + +int DiskAnnBuilder::dump(const IndexDumper::Pointer &dumper) { + if (state_ != BUILD_STATE_BUILT) { + LOG_INFO("Build the index before HnswBuilder::dump"); + return IndexError_NoReady; + } + + LOG_INFO("Begin DiskAnnBuilder::dump"); + + raw_meta_.set_searcher("DiskAnnSearcher", 0, ailego::Params()); + auto start_time = ailego::Monotime::MilliSeconds(); + + int ret = IndexHelper::SerializeToDumper(raw_meta_, dumper.get()); + if (ret != 0) { + LOG_ERROR("Failed to serialize meta into dumper."); + return ret; + } + + ret = entity_.dump(holder_, raw_meta_, dumper); + if (ret != 0) { + LOG_ERROR("Index dump failed, ret: 
%u", ret); + + return IndexError_Runtime; + } + + stats_.set_dumped_count(holder_->count()); + stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time); + + LOG_INFO("DiskAnnBuilder::dump"); + + return 0; +} + +INDEX_FACTORY_REGISTER_BUILDER(DiskAnnBuilder); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_builder.h b/src/core/algorithm/diskann/diskann_builder.h new file mode 100644 index 000000000..b8df7f57f --- /dev/null +++ b/src/core/algorithm/diskann/diskann_builder.h @@ -0,0 +1,131 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "diskann_algorithm.h" +#include "diskann_builder_entity.h" +#include "diskann_pq_trainer.h" + +namespace zvec { +namespace core { + +class DiskAnnBuilder : public IndexBuilder { + public: + //! Constructor + DiskAnnBuilder() = default; + + //! Initialize the builder + virtual int init(const IndexMeta &meta, + const ailego::Params ¶ms) override; + + //! Cleanup the builder + virtual int cleanup(void) override; + + //! Train the data + virtual int train(IndexThreads::Pointer threads, + IndexHolder::Pointer holder) override; + + //! Train the data + virtual int train(const IndexTrainer::Pointer &trainer) override; + + //! Build the index + virtual int build(IndexThreads::Pointer threads, + IndexHolder::Pointer holder) override; + + //! 
Dump index into storage + virtual int dump(const IndexDumper::Pointer &dumper) override; + + //! Retrieve statistics + virtual const Stats &stats(void) const override { + return stats_; + } + + int do_norm(const void *data_ptr, std::string *norm_data); + + private: + int train_quantized_data(IndexThreads::Pointer threads); + int generate_quantized_data(IndexThreads::Pointer threads); + int build_internal(IndexThreads::Pointer threads); + int prune_internal(IndexThreads::Pointer threads); + + void do_build(uint64_t idx, size_t step_size, + std::atomic *finished); + + void do_prune(uint64_t idx, size_t step_size, + std::atomic *finished); + + int calculate_entry_point(); + + int calculate_pq_chunk_num(); + + double get_memory_in_bytes(double search_ram_budget) { + return search_ram_budget * 1024 * 1024 * 1024; + } + + private: + enum BUILD_STATE { + BUILD_STATE_INIT = 0, + BUILD_STATE_INITED = 1, + BUILD_STATE_TRAINED = 2, + BUILD_STATE_BUILT = 3 + }; + + constexpr static uint32_t kDefaultLogIntervalSecs = 15U; + constexpr static uint32_t kDefaultListSize = 50U; + constexpr static uint32_t kDefaultMaxDegree = 100U; + constexpr static uint32_t kDefaultPqChunkNum = -1U; + + std::string data_file_; + + uint32_t max_degree_{kDefaultMaxDegree}; + uint32_t list_size_{kDefaultListSize}; + double memory_limit_{0.0}; + bool memory_limit_set_{false}; + uint32_t max_pq_chunk_num_{kDefaultPqChunkNum}; + uint32_t pq_chunk_num_{kDefaultPqChunkNum}; + uint32_t build_thread_count_{0}; + uint32_t max_train_sample_count_{PQTable::kMaxTrainSampleCount}; + double train_sample_ratio_{PQTable::kTrainSampleRatio}; + std::string universal_label_{""}; + std::string codebook_prefix_{""}; + std::string index_path_prefix_{"./diskann"}; + + BUILD_STATE state_; + Stats stats_; + + int errcode_{0}; + std::atomic_bool error_{false}; + + IndexMetric::Pointer metric_{}; + + std::mutex mutex_{}; + std::condition_variable cond_{}; + + IndexMeta raw_meta_; + IndexMeta build_meta_; + + 
DiskAnnBuilderEntity entity_{}; + + IndexHolder::Pointer holder_; + + DiskAnnAlgorithm::UPointer algo_; + DiskAnnPqTrainer::UPointer trainer_; + + uint32_t check_interval_secs_{kDefaultLogIntervalSecs}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_builder_entity.cc b/src/core/algorithm/diskann/diskann_builder_entity.cc new file mode 100644 index 000000000..e023cbc97 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_builder_entity.cc @@ -0,0 +1,656 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_builder_entity.h" +#include +#include "diskann_algorithm.h" +#include "diskann_util.h" + +namespace zvec { +namespace core { + +int DiskAnnBuilderEntity::init(const IndexMeta &meta, uint32_t max_degree, + uint32_t list_size, double memory_limit, + uint32_t build_threads) { + meta_ = meta; + + max_degree_ = max_degree; + list_size_ = list_size; + + memory_limit_ = memory_limit; + + num_threads_ = build_threads; + + max_build_degree_ = max_degree_ * kDefaultGraphSlackFactor; + + neighbor_size_ = sizeof(uint32_t) + max_build_degree_ * sizeof(diskann_id_t); + + return 0; +} + +int DiskAnnBuilderEntity::reserve_space(uint32_t docs) { + vectors_buffer_.reserve(meta_.element_size() * docs); + keys_buffer_.reserve(sizeof(diskann_key_t) * docs); + neighbors_buffer_.reserve(neighbor_size_ * docs); + + return 0; +} + +int DiskAnnBuilderEntity::add_vector(diskann_key_t key, const void *vec) { + vectors_buffer_.append(reinterpret_cast(vec), + meta_.element_size()); + keys_buffer_.append(reinterpret_cast(&key), sizeof(key)); + + uint32_t neighbor_cnt = 0; + std::vector neighbor{max_build_degree_, 0}; + + neighbors_buffer_.append(reinterpret_cast(&neighbor_cnt), + sizeof(uint32_t)); + neighbors_buffer_.append(reinterpret_cast(neighbor.data()), + sizeof(diskann_id_t) * max_build_degree_); + + (*mutable_doc_cnt())++; + + return 0; +} + +const void *DiskAnnBuilderEntity::get_vector(diskann_id_t id) const { + size_t offset = (size_t)id * meta_.element_size(); + return vectors_buffer_.data() + offset; +} + +diskann_key_t DiskAnnBuilderEntity::get_key(diskann_id_t id) const { + size_t offset = (size_t)id * sizeof(diskann_key_t); + + return *( + reinterpret_cast(keys_buffer_.data() + offset)); +} + +//! 
Get vector local id by key +diskann_id_t DiskAnnBuilderEntity::get_id(diskann_key_t /*key*/) const { + LOG_ERROR("DiskAnnBuilderEntity::get_id not implemented."); + return kInvalidId; +} + +std::pair DiskAnnBuilderEntity::get_neighbors( + diskann_id_t id) const { + size_t offset = (size_t)id * neighbor_size_; + + const uint8_t *start_ptr = + reinterpret_cast(neighbors_buffer_.data()) + offset; + + uint32_t neighbor_cnt = *(reinterpret_cast(start_ptr)); + + const diskann_id_t *neighbors = + reinterpret_cast(start_ptr + sizeof(uint32_t)); + + return std::make_pair(neighbor_cnt, neighbors); +} + +int DiskAnnBuilderEntity::set_neighbors( + diskann_id_t id, const std::vector &neighbor_ids) { + size_t offset = (size_t)id * neighbor_size_; + + uint8_t *start_ptr = + reinterpret_cast(&neighbors_buffer_[0]) + offset; + + uint32_t neighbor_cnt = neighbor_ids.size(); + + memcpy(start_ptr + sizeof(uint32_t), neighbor_ids.data(), + sizeof(diskann_id_t) * neighbor_cnt); + memcpy(start_ptr, &neighbor_cnt, sizeof(uint32_t)); + + if (max_observed_degree_ < neighbor_cnt) { + max_observed_degree_ = neighbor_cnt; + } + + return 0; +} + +int DiskAnnBuilderEntity::add_neighbor(diskann_id_t id, + diskann_id_t neighbor_id) { + size_t offset = (size_t)id * neighbor_size_; + + uint8_t *start_ptr = + reinterpret_cast(&neighbors_buffer_[0]) + offset; + + uint32_t neighbor_cnt = *reinterpret_cast(start_ptr); + + memcpy(start_ptr + sizeof(uint32_t) + sizeof(diskann_id_t) * neighbor_cnt, + &neighbor_id, sizeof(diskann_id_t)); + + neighbor_cnt += 1; + + memcpy(start_ptr, &neighbor_cnt, sizeof(uint32_t)); + + if (max_observed_degree_ < neighbor_cnt) { + max_observed_degree_ = neighbor_cnt; + } + + return 0; +} + +int64_t DiskAnnBuilderEntity::dump_segment(const IndexDumper::Pointer &dumper, + const std::string &segment_id, + const void *data, + size_t size) const { + size_t len = dumper->write(data, size); + if (len != size) { + LOG_ERROR("Dump segment %s data failed, expect: %lu, actual: %lu", + 
segment_id.c_str(), size, len); + return IndexError_WriteData; + } + + size_t padding_size = AlignSize(size) - size; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, size %lu", padding_size); + return IndexError_WriteData; + } + } + + uint32_t crc = ailego::Crc32c::Hash(data, size); + int ret = dumper->append(segment_id, size, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); + return ret; + } + + return 0; +} + +int DiskAnnBuilderEntity::dump_pq_meta_segment( + const IndexDumper::Pointer &dumper) const { + uint32_t crc = 0U; + + // write meta + size_t size_pq_meta = dumper->write(&pq_meta_, sizeof(DiskAnnPqMeta)); + if (size_pq_meta != sizeof(DiskAnnPqMeta)) { + LOG_ERROR("Failed to dump PQ meta data, expect: %lu, actual: %lu", + sizeof(DiskAnnPqMeta), size_pq_meta); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(&pq_meta_, sizeof(DiskAnnPqMeta), crc); + + // write full pivot data + size_t size_full_pivot_data = + dumper->write(pq_full_pivot_data_.data(), pq_meta_.full_pivot_data_size); + if (size_full_pivot_data != pq_meta_.full_pivot_data_size) { + LOG_ERROR("Failed to dump full pivot data, expect: %zu, actual: %zu", + (size_t)pq_meta_.full_pivot_data_size, size_full_pivot_data); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(pq_full_pivot_data_.data(), + pq_meta_.full_pivot_data_size, crc); + + // write centroid num + size_t size_centroid = + dumper->write(pq_centroid_.data(), pq_meta_.centroid_data_size); + if (size_centroid != pq_meta_.centroid_data_size) { + LOG_ERROR("Failed to dump centroid num, expect: %zu, actual: %zu", + (size_t)pq_meta_.centroid_data_size, size_centroid); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(pq_centroid_.data(), pq_meta_.centroid_data_size, + crc); + + // write chunk offset + size_t 
size_chunk_offset = dumper->write( + pq_chunk_offsets_.data(), (pq_meta_.chunk_num + 1) * sizeof(uint32_t)); + if (size_chunk_offset != (pq_meta_.chunk_num + 1) * sizeof(uint32_t)) { + LOG_ERROR("Failed to dump centroid num, expect: %zu, actual: %zu", + (size_t)((pq_meta_.chunk_num + 1) * sizeof(uint32_t)), + size_chunk_offset); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(pq_chunk_offsets_.data(), + (pq_meta_.chunk_num + 1) * sizeof(uint32_t), crc); + + // write size + size_t size_total = + size_pq_meta + size_full_pivot_data + size_centroid + size_chunk_offset; + + // write pad + size_t padding_size = AlignSize(size_total) - size_total; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, size %lu", padding_size); + return IndexError_WriteData; + } + } + + int ret = + dumper->append(kDiskAnnPqMetaSegmentId, size_total, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump PQ segment failed, ret %d", ret); + return ret; + } + + return 0; +} + +int DiskAnnBuilderEntity::dump_pq_data_segment( + const IndexDumper::Pointer &dumper) const { + uint64_t doc_cnt = meta_header_.doc_cnt; + uint64_t chunk_num = pq_meta_.chunk_num; + + uint32_t crc = 0U; + + // write pq data + size_t size_total = + dumper->write(block_compressed_data_.data(), doc_cnt * chunk_num); + + if (size_total != doc_cnt * chunk_num) { + LOG_ERROR("Failed to dump block compressed data, expect: %zu, actual: %zu", + (size_t)(doc_cnt * chunk_num), size_total); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(block_compressed_data_.data(), doc_cnt * chunk_num, + crc); + + // write pad + size_t padding_size = AlignSize(size_total) - size_total; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, size %lu", padding_size); + return 
IndexError_WriteData; + } + } + + int ret = + dumper->append(kDiskAnnPqDataSegmentId, size_total, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump PQ data segment failed, ret %d", ret); + return ret; + } + + return 0; +} + +int DiskAnnBuilderEntity::dump_dummy_segment( + const IndexDumper::Pointer &dumper) const { + // to make offset aligned with 4K + size_t dumper_header_size = dumper->size(); + + size_t dummy_size = + DiskAnnUtil::round_up(dumper_header_size, DiskAnnUtil::kSectorSize) - + dumper_header_size; + + if (dummy_size != 0) { + std::string dummpy_data(dummy_size, '\0'); + if (dumper->write(dummpy_data.data(), dummy_size) != dummy_size) { + LOG_ERROR("write dummy failed, size %lu", dummy_size); + return IndexError_WriteData; + } + + int ret = dumper->append(kDiskAnnDummpySegmentId, dummy_size, 0, 0); + if (ret != 0) { + LOG_ERROR("Dump dummy data segment failed, ret %d", ret); + return ret; + } + } + + return 0; +} + +int DiskAnnBuilderEntity::dump_key_segment( + const IndexDumper::Pointer &dumper) const { + //! 
Dump keys + size_t key_segment_size = doc_cnt() * sizeof(diskann_key_t); + int64_t keys_size = dump_segment(dumper, kDiskAnnKeySegmentId, + keys_buffer_.data(), key_segment_size); + if (keys_size < 0) { + return keys_size; + } + + return 0; +} + +int DiskAnnBuilderEntity::dump_key_mapping_segment( + const IndexDumper::Pointer &dumper) const { + std::vector mapping(doc_cnt()); + + const diskann_key_t *keys = reinterpret_cast( + const_cast(keys_buffer_.data())); + + std::iota(mapping.begin(), mapping.end(), 0U); + std::sort(mapping.begin(), mapping.end(), + [&](diskann_id_t i, diskann_id_t j) { return keys[i] < keys[j]; }); + + size_t size = mapping.size() * sizeof(diskann_id_t); + int64_t ret = + dump_segment(dumper, kDiskAnnKeyMappingSegmentId, mapping.data(), size); + + if (ret != 0) { + LOG_ERROR("Dump vectors segment failed"); + + return ret; + } + + return 0; +} + +int DiskAnnBuilderEntity::dump_entrypoint_segment( + const IndexDumper::Pointer &dumper) const { + std::string entrypoint_buffer; + + size_t size = sizeof(uint32_t) + entrypoints_.size() * sizeof(diskann_id_t); + entrypoint_buffer.resize(size); + + uint8_t *buffer_ptr = reinterpret_cast(&entrypoint_buffer[0]); + + uint32_t point_cnt = entrypoints_.size(); + memcpy(buffer_ptr, &point_cnt, sizeof(uint32_t)); + memcpy(buffer_ptr + sizeof(uint32_t), entrypoints_.data(), + entrypoints_.size() * sizeof(diskann_id_t)); + + int64_t ret = dump_segment(dumper, kDiskAnnEntryPointSegmentId, + entrypoint_buffer.data(), size); + + if (ret != 0) { + LOG_ERROR("Dump entrypoint segment failed"); + + return ret; + } + + return 0; +} + +int DiskAnnBuilderEntity::dump(IndexHolder::Pointer holder, IndexMeta &meta, + const IndexDumper::Pointer &dumper) { + uint64_t doc_cnt = holder->count(); + uint64_t max_node_size = + (uint64_t)max_observed_degree_ * sizeof(diskann_id_t) + sizeof(uint32_t) + + meta_.element_size(); + uint64_t node_per_sector = + DiskAnnUtil::kSectorSize / + max_node_size; // 0 if max_node_size > 
DiskAnnUtil::kSectorSize + + std::string node_buf; + node_buf.resize(max_node_size); + + // uint32_t & neighbor_num = *(uint32_t *)(node_buf.data() + + // meta_.element_size()); + diskann_id_t *neighbor_buf = + (diskann_id_t *)(node_buf.data() + (meta_.element_size()) + + sizeof(uint32_t)); + + LOG_INFO( + "Dump Data, medoid: %zu, max node size: %zu, node per sector: %zu, " + "max observed degree: %zu", + (size_t)medoid(), (size_t)max_node_size, (size_t)node_per_sector, + (size_t)max_observed_degree_); + + // write a dummy segment to make data align + int ret = dump_dummy_segment(dumper); + if (ret != 0) { + LOG_ERROR("Dump dummy segment failed"); + + return ret; + } + + // dump data by sector + size_t write_size = 0; + uint32_t crc = 0U; + size_t len = 0; + + // no need to write first sector + // size_t len = dumper->write(sector_buf.data(), DiskAnnUtil::kSectorSize); + // if (len != DiskAnnUtil::kSectorSize) { + // LOG_ERROR("Write Vector Error, write=%zu, expect=%zu", len, + // DiskAnnUtil::kSectorSize); + + // return IndexError_WriteData; + // } + + // write_size += len; + // crc = ailego::Crc32c::Hash(sector_buf.data(), DiskAnnUtil::kSectorSize, + // crc); + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator for holder failed"); + return IndexError_Runtime; + } + + uint64_t index_size = 0; + uint32_t neighbor_num; + if (node_per_sector > 0) { + uint64_t sector_num = + DiskAnnUtil::round_up(doc_cnt, node_per_sector) / node_per_sector; + + diskann_id_t cur_node_id = 0; + + std::string sector_buf; + sector_buf.resize(DiskAnnUtil::kSectorSize); + + for (uint64_t sector = 0; sector < sector_num; sector++) { + if (sector != 0 && sector % 100000 == 0) { + LOG_INFO("Sector #%zu written", (size_t)sector); + } + + memset(&(sector_buf[0]), 0, DiskAnnUtil::kSectorSize); + + for (uint64_t sector_node_id = 0; + sector_node_id < node_per_sector && cur_node_id < doc_cnt; + sector_node_id++) { + memset(&(node_buf[0]), 0, max_node_size); + + auto 
neighbors = get_neighbors(cur_node_id); + neighbor_num = neighbors.first; + + ailego_assert(neighbor_num > 0); + ailego_assert(neighbor_num <= max_observed_degree_); + + memcpy(&(neighbor_buf[0]), neighbors.second, + neighbors.first * sizeof(diskann_id_t)); + + if (iter->is_valid()) { + const void *vec = iter->data(); + memcpy(&(node_buf[0]), vec, meta.element_size()); + + iter->next(); + } else { + return IndexError_Runtime; + } + + // write neighbor num + *(uint32_t *)(node_buf.data() + meta_.element_size()) = neighbor_num; + + // write neighbor buffer + memcpy(&(node_buf[0]) + meta_.element_size() + sizeof(uint32_t), + neighbor_buf, neighbor_num * sizeof(diskann_id_t)); + + // get offset into sector_buf + char *sector_node_buf = §or_buf[sector_node_id * max_node_size]; + + // copy node buf into sector_node_buf + memcpy(sector_node_buf, node_buf.data(), max_node_size); + + cur_node_id++; + } + + // flush sector to disk + len = dumper->write(sector_buf.data(), DiskAnnUtil::kSectorSize); + if (len != DiskAnnUtil::kSectorSize) { + LOG_ERROR("Write Vector Error, write=%zu, expect=%zu", len, + (size_t)DiskAnnUtil::kSectorSize); + + return IndexError_WriteData; + } + write_size += DiskAnnUtil::kSectorSize; + crc = ailego::Crc32c::Hash(sector_buf.data(), DiskAnnUtil::kSectorSize, + crc); + } + + LOG_INFO("Total Sector #%zu written", (size_t)sector_num); + + index_size = sector_num * DiskAnnUtil::kSectorSize; + } else { + // Write multi-sector nodes + std::string multisector_buf; + multisector_buf.resize( + DiskAnnUtil::round_up(max_node_size, DiskAnnUtil::kSectorSize)); + + uint64_t sector_num_per_node = + DiskAnnUtil::div_round_up(max_node_size, DiskAnnUtil::kSectorSize); + + for (uint64_t i = 0; i < doc_cnt; i++) { + if (i != 0 && (i * sector_num_per_node) % 100000 == 0) { + LOG_INFO("Sector # %zu written", (size_t)(i * sector_num_per_node)); + } + + memset(&(multisector_buf[0]), 0, + sector_num_per_node * DiskAnnUtil::kSectorSize); + memset(&(node_buf[0]), 0, 
max_node_size); + + auto neighbors = get_neighbors(i); + neighbor_num = neighbors.first; + + ailego_assert(neighbor_num > 0); + ailego_assert(neighbor_num <= max_observed_degree_); + + // read node's nhood + memcpy((char *)neighbor_buf, neighbors.second, + neighbor_num * sizeof(diskann_id_t)); + + if (iter->is_valid()) { + const void *vec = iter->data(); + memcpy(&(multisector_buf[0]), vec, meta.element_size()); + + iter->next(); + } else { + return IndexError_Runtime; + } + + // write neighbor + *(uint32_t *)(&(multisector_buf[0]) + meta_.element_size()) = + neighbor_num; + + // write nhood next + memcpy(&(multisector_buf[0]) + meta_.element_size() + sizeof(uint32_t), + neighbor_buf, neighbor_num * sizeof(diskann_id_t)); + + // flush sector to disk + len = dumper->write(multisector_buf.data(), + sector_num_per_node * DiskAnnUtil::kSectorSize); + if (len != sector_num_per_node * DiskAnnUtil::kSectorSize) { + LOG_ERROR("Write Vector Error, write=%zu, expect=%zu", len, + (size_t)(sector_num_per_node * DiskAnnUtil::kSectorSize)); + + return IndexError_WriteData; + } + + write_size += sector_num_per_node * DiskAnnUtil::kSectorSize; + + crc = ailego::Crc32c::Hash(multisector_buf.data(), + sector_num_per_node * DiskAnnUtil::kSectorSize, + crc); + } + + LOG_INFO("Total Sector #%zu written", + (size_t)(doc_cnt * sector_num_per_node)); + + index_size = doc_cnt * sector_num_per_node * DiskAnnUtil::kSectorSize; + } + + size_t padding_size = AlignSize(write_size) - write_size; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, size %lu", padding_size); + return IndexError_WriteData; + } + } + + ret = dumper->append(kDiskAnnVectorSegmentId, write_size, 0UL, crc); + if (ret != 0) { + LOG_ERROR("Dump vectors segment failed, ret %d", ret); + return ret; + } + + // dump diskann meta + meta_header_.doc_cnt = doc_cnt; + meta_header_.ndims = meta_.dimension(); + 
meta_header_.medoid = medoid(); + meta_header_.max_node_size = max_node_size; + meta_header_.max_degree = max_observed_degree_; + meta_header_.node_per_sector = node_per_sector; + meta_header_.vamana_frozen_num = 0; + meta_header_.vamana_frozen_loc = medoid(); + meta_header_.append_reorder_data = 0; + meta_header_.index_size = index_size; + + ret = dump_segment(dumper, kDiskAnnMetaSegmentId, &meta_header_, + sizeof(DiskAnnMetaHeader)); + if (ret != 0) { + LOG_ERROR("Dump vectors segment failed"); + + return ret; + } + + // dump pq meta + ret = dump_pq_meta_segment(dumper); + if (ret != 0) { + LOG_ERROR("Dump pq meta segment failed"); + + return ret; + } + + // dump pq data + ret = dump_pq_data_segment(dumper); + if (ret != 0) { + LOG_ERROR("Dump pq data segment failed"); + + return ret; + } + + // dump key + ret = dump_key_segment(dumper); + if (ret != 0) { + LOG_ERROR("Dump key segment failed"); + + return ret; + } + + // dump key mapping + ret = dump_key_mapping_segment(dumper); + if (ret != 0) { + LOG_ERROR("Dump key mapping segment failed"); + + return ret; + } + + // dump entrypint + ret = dump_entrypoint_segment(dumper); + if (ret != 0) { + LOG_ERROR("Dump entrypoint segment failed"); + + return ret; + } + + LOG_INFO("DiskAnn Index File Dumped"); + + return 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_builder_entity.h b/src/core/algorithm/diskann/diskann_builder_entity.h new file mode 100644 index 000000000..d2223ab95 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_builder_entity.h @@ -0,0 +1,107 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +// wrapper class aligned with diskann +class DiskAnnBuilderEntity : public DiskAnnEntity { + public: + using Pointer = std::shared_ptr; + + DiskAnnBuilderEntity() = default; + virtual ~DiskAnnBuilderEntity() = default; + + public: + int add_vector(diskann_key_t key, const void *vec) override; + + std::pair get_neighbors( + diskann_id_t id) const override; + + int set_neighbors(diskann_id_t id, + const std::vector &neighbor_ids) override; + + int add_neighbor(diskann_id_t id, diskann_id_t neighbor_id) override; + + diskann_id_t get_id(diskann_key_t key) const override; + diskann_key_t get_key(diskann_id_t id) const override; + const void *get_vector(diskann_id_t id) const override; + + public: + int init(const IndexMeta &meta, uint32_t max_degree, uint32_t list_size, + double memory_limit, uint32_t build_threads); + + int dump(IndexHolder::Pointer holder, IndexMeta &meta, + const IndexDumper::Pointer &dumper); + + int64_t dump_segment(const IndexDumper::Pointer &dumper, + const std::string &segment_id, const void *data, + size_t size) const; + int dump_dummy_segment(const IndexDumper::Pointer &dumper) const; + int dump_pq_meta_segment(const IndexDumper::Pointer &dumper) const; + int dump_pq_data_segment(const IndexDumper::Pointer &dumper) const; + int dump_key_mapping_segment(const IndexDumper::Pointer &dumper) const; + int dump_entrypoint_segment(const IndexDumper::Pointer &dumper) const; + int dump_key_segment(const IndexDumper::Pointer 
&dumper) const; + + int reserve_space(uint32_t docs); + + std::vector &pq_full_pivot_data() { + return pq_full_pivot_data_; + } + + std::vector &pq_centroid() { + return pq_centroid_; + } + + std::vector &pq_chunk_offsets() { + return pq_chunk_offsets_; + } + + std::vector &block_compressed_data() { + return block_compressed_data_; + } + + private: + uint32_t max_degree_{0}; + uint32_t list_size_{0}; + double memory_limit_{0}; + uint32_t num_threads_{0}; + uint32_t max_build_degree_{0}; + uint32_t max_observed_degree_{0}; + uint32_t neighbor_size_{0}; + + std::string mem_index_file_{""}; + std::string index_path_prefix_{""}; + + std::string vectors_buffer_{}; + std::string keys_buffer_{}; + std::string neighbors_buffer_{}; + std::vector entrypoints_{}; + + IndexMeta meta_; + + std::vector pq_full_pivot_data_; + std::vector pq_centroid_; + std::vector pq_chunk_offsets_; + std::vector block_compressed_data_; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_context.cc b/src/core/algorithm/diskann/diskann_context.cc new file mode 100644 index 000000000..a564b684f --- /dev/null +++ b/src/core/algorithm/diskann/diskann_context.cc @@ -0,0 +1,151 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_context.h" +#include +#include "diskann_pq_table.h" +#include "diskann_util.h" + +namespace zvec { +namespace core { + +DiskAnnContext::DiskAnnContext(const IndexMeta &meta, + const IndexMetric::Pointer &measure, + const DiskAnnEntity::Pointer &entity) + : dc_(entity.get(), measure, meta.dimension()), entity_{entity} {} + +int DiskAnnContext::init(ContextType type, uint32_t graph_degree, + uint32_t pq_chunk_num, uint32_t element_size) { + type_ = type; + element_size_ = element_size; + pq_chunk_num_ = pq_chunk_num; + + DiskAnnUtil::alloc_aligned((void **)&query_, element_size_, 32); + DiskAnnUtil::alloc_aligned((void **)&query_rotated_, element_size_, 32); + + int ret; + switch (type) { + case kBuilderContext: + ret = visit_filter_.init(VisitFilter::ByteMap, entity_->doc_cnt(), + entity_->doc_cnt(), negative_probility_); + if (ret != 0) { + LOG_ERROR("Create filter failed, mode %d", filter_mode_); + return ret; + } + break; + + case kSearcherContext: + ret = visit_filter_.init(filter_mode_, entity_->doc_cnt(), + entity_->doc_cnt(), negative_probility_); + if (ret != 0) { + LOG_ERROR("Create filter failed, mode %d", filter_mode_); + return ret; + } + + DiskAnnUtil::alloc_aligned( + (void **)&pq_table_dist_buffer_, + PQTable::kPQCentroidNum * pq_chunk_num_ * sizeof(float), 256); + DiskAnnUtil::alloc_aligned((void **)&pq_coord_buffer_, + graph_degree * pq_chunk_num_ * sizeof(uint8_t), + 256); + DiskAnnUtil::alloc_aligned((void **)&coord_buffer_, element_size_, 256); + DiskAnnUtil::alloc_aligned( + (void **)§or_buffer_, + DiskAnnUtil::kMaxSectorReadNum * DiskAnnUtil::kSectorSize, + DiskAnnUtil::kSectorSize); + + ret = setup_io_ctx(io_ctx_); + if (ret != 0) { + LOG_ERROR("setup io ctx error, ret=%d", ret); + return ret; + } + break; + + default: + LOG_ERROR("Init context failed"); + return IndexError_Runtime; + } + + return 0; +} + +DiskAnnContext::~DiskAnnContext() { + if (query_) { + free(query_); + } + + if (query_rotated_) { + 
free(query_rotated_); + } + + if (pq_table_dist_buffer_) { + free(pq_table_dist_buffer_); + } + + if (pq_coord_buffer_) { + free(pq_coord_buffer_); + } + + if (coord_buffer_) { + free(coord_buffer_); + } + + if (sector_buffer_) { + free(sector_buffer_); + } + + if (type_ == kSearcherContext) { + destroy_io_ctx(io_ctx_); + } +} + +int DiskAnnContext::update_context(ContextType type, const IndexMeta &meta, + const IndexMetric::Pointer &measure, + const DiskAnnEntity::Pointer &entity, + uint32_t magic_num) { + if (ailego_unlikely(type != type_)) { + LOG_ERROR( + "DiskAnnContext does not support shared by different type, " + "src=%u dst=%u", + type_, type); + return IndexError_Unsupported; + } + + magic_ = kInvalidMgic; + + switch (type) { + case kBuilderContext: + LOG_ERROR("BuildContext does not support update"); + return IndexError_NotImplemented; + + case kSearcherContext: + break; + + case kReducerContext: + break; + + default: + LOG_ERROR("update context failed"); + return IndexError_Runtime; + } + + entity_ = entity; + dc_.update(measure, meta.dimension()); + magic_ = magic_num; + + return 0; +} + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_context.h b/src/core/algorithm/diskann/diskann_context.h new file mode 100644 index 000000000..a3f092f84 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_context.h @@ -0,0 +1,439 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "diskann_dist_calculator.h" +#include "diskann_entity.h" +#include "diskann_file_reader.h" +#include "diskann_visit_filter.h" + +namespace zvec { +namespace core { + +struct SearchStats { + public: + float total_us = 0; + float io_us = 0; + float cpu_us = 0; + uint64_t disk_page_reads = 0; + uint64_t io_num = 0; + uint64_t dist_num = 0; + uint64_t cache_hits = 0; + uint64_t hop_num = 0; +}; + +class DiskAnnContext : public IndexContext { + public: + //! Index Context Pointer + typedef std::unique_ptr Pointer; + + enum ContextType { + kUnknownContext = 0, + kSearcherContext = 1, + kBuilderContext = 2, + kReducerContext = 3 + }; + + //! Construct + DiskAnnContext(const IndexMeta &meta, const IndexMetric::Pointer &measure, + const DiskAnnEntity::Pointer &entity); + + //! Destructor + virtual ~DiskAnnContext(); + + public: + //! Init + int init(ContextType type, uint32_t graph_degree, uint32_t pq_chunk_num, + uint32_t element_size); + + //! Update context, the context may be shared by different searcher/streamer + int update_context(ContextType type, const IndexMeta &meta, + const IndexMetric::Pointer &measure, + const DiskAnnEntity::Pointer &entity, uint32_t magic_num); + + //! Retrieve search result + virtual const IndexDocumentList &result(void) const override { + return results_[0]; + } + + //! Retrieve search result + virtual const IndexDocumentList &result(size_t idx) const override { + return results_[idx]; + } + + //! Retrieve result object for output + virtual IndexDocumentList *mutable_result(size_t idx) override { + ailego_assert_with(idx < results_.size(), "invalid idx"); + return &results_[idx]; + } + + //! Retrieve search group result with index + virtual const IndexGroupDocumentList &group_result(void) const override { + return group_results_[0]; + } + + //! 
Retrieve search group result with index + virtual const IndexGroupDocumentList &group_result( + size_t idx) const override { + return group_results_[idx]; + } + + virtual uint32_t magic(void) const override { + return magic_; + } + + void set_magic(uint32_t magic) { + magic_ = magic; + } + + //! Set mode of debug + virtual void set_debug_mode(bool enable) override { + debug_mode_ = enable; + } + + //! Retrieve mode of debug + virtual bool debug_mode(void) const override { + return debug_mode_; + } + + //! Retrieve string of debug + virtual std::string debug_string(void) const override { + return std::string(""); + } + + //! Update the parameters of context + virtual int update(const ailego::Params & /*params*/) override { + return 0; + } + + inline DistCalculator &dist_calculator() { + return dc_; + } + + public: + //! Set topk of search result + void set_topk(uint32_t val) override { + topk_ = val; + topk_heap_.limit(val); + } + + void set_list_size(uint32_t list_size) { + list_size_ = list_size; + } + + void set_fetch_vector(bool v) override { + fetch_vector_ = v; + } + + //! 
Get topk + inline uint32_t topk() const override { + return topk_; + } + + inline uint32_t list_size() const { + return list_size_; + } + + inline void reset_query(const void *query) { + memcpy(query_, query, element_size_); + memcpy(query_rotated_, query, element_size_); + + dc_.reset_query(query); + } + + inline TopkHeap &topk_heap() { + return topk_heap_; + } + + inline void *query() { + return query_; + } + + inline void *query_rotated() { + return query_rotated_; + } + + inline float *pq_table_dist_buffer() { + return pq_table_dist_buffer_; + } + + inline void *pq_coord_buffer() { + return pq_coord_buffer_; + } + + inline void *coord_buffer() { + return coord_buffer_; + } + + inline void *sector_buffer() { + return sector_buffer_; + } + + inline IOContext &io_ctx() { + return io_ctx_; + } + + inline void resize_results(size_t size) { + if (group_by_search()) { + group_results_.resize(size); + } else { + results_.resize(size); + } + } + + inline bool error() const { + return has_error_; + } + + inline void set_error(bool err) { + has_error_ = err; + } + + inline void clear() { + for (auto &it : results_) { + it.clear(); + } + + best_list_nodes_.clear(); + expanded_nodes_.clear(); + visit_filter_.clear(); + has_error_ = false; + } + + SearchStats &query_stats() { + return query_stats_; + } + + const DiskAnnEntity &get_entity() const { + return *entity_; + } + + NeighborPriorityQueue &best_list_nodes() { + return best_list_nodes_; + } + + std::vector &expanded_nodes() { + return expanded_nodes_; + } + + std::vector &occlude_factor() { + return occlude_factor_; + } + + VisitFilter &visit_filter() { + return visit_filter_; + } + + //! Reset context + void reset(void) override { + set_filter(nullptr); + reset_threshold(); + set_fetch_vector(false); + set_group_params(0, 0); + reset_group_by(); + } + + inline std::map &group_topk_heaps() { + return group_topk_heaps_; + } + + //! Get group topk + inline uint32_t group_topk() const { + return group_topk_; + } + + //! 
Get group num + inline uint32_t group_num() const { + return group_num_; + } + + //! Get if group by search + inline bool group_by_search() { + return group_num_ > 0; + } + + //! Set group params + void set_group_params(uint32_t group_num, uint32_t group_topk) override { + group_num_ = group_num; + group_topk_ = group_topk; + + topk_ = group_topk_ * group_num_; + + topk_heap_.limit(topk_); + + group_topk_heaps_.clear(); + } + + inline void topk_to_result(uint32_t idx) { + if (group_by_search()) { + topk_to_group_result(idx); + } else { + topk_to_single_result(idx); + } + } + + void set_to_result(uint32_t idx, const std::vector &result_ids, + const std::vector &result_dists) { + if (result_ids.size() != result_dists.size()) { + return; + } + + uint32_t size = result_ids.size(); + + for (uint32_t i = 0; i < size; ++i) { + results_[idx].emplace_back(result_ids[i], result_dists[i], 0); + } + } + + inline void topk_to_single_result(uint32_t idx) { + if (ailego_unlikely(topk_heap_.size() == 0)) { + return; + } + + ailego_assert_with(idx < results_.size(), "invalid idx"); + int size = std::min(topk_, static_cast(topk_heap_.size())); + topk_heap_.sort(); + results_[idx].clear(); + + for (int i = 0; i < size; ++i) { + auto info = topk_heap_[i].second; + if (info.dist_ > this->threshold()) { + break; + } + + diskann_id_t id = topk_heap_[i].first; + if (fetch_vector_) { + results_[idx].emplace_back(entity_->get_key(id), info.dist_, id, + info.vec_); + } else { + results_[idx].emplace_back(entity_->get_key(id), info.dist_, id); + } + } + + return; + } + + //! 
Construct result from topk heap, result will be normalized + inline void topk_to_group_result(uint32_t idx) { + ailego_assert_with(idx < group_results_.size(), "invalid idx"); + + group_results_[idx].clear(); + + std::vector> group_topk_list; + std::vector> best_score_in_groups; + for (auto itr = group_topk_heaps_.begin(); itr != group_topk_heaps_.end(); + itr++) { + const std::string &group_id = (*itr).first; + auto &heap = (*itr).second; + heap.sort(); + + if (heap.size() > 0) { + float best_score = heap[0].second.dist_; + best_score_in_groups.push_back(std::make_pair(group_id, best_score)); + } + } + + std::sort(best_score_in_groups.begin(), best_score_in_groups.end(), + [](const std::pair &a, + const std::pair &b) -> int { + return a.second < b.second; + }); + + // truncate to group num + for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size(); + ++i) { + const std::string &group_id = best_score_in_groups[i].first; + + group_topk_list.emplace_back( + std::make_pair(group_id, group_topk_heaps_[group_id])); + } + + group_results_[idx].resize(group_topk_list.size()); + + for (uint32_t i = 0; i < group_topk_list.size(); ++i) { + const std::string &group_id = group_topk_list[i].first; + group_results_[idx][i].set_group_id(group_id); + + uint32_t size = std::min( + group_topk_, static_cast(group_topk_list[i].second.size())); + + for (uint32_t j = 0; j < size; ++j) { + auto info = group_topk_list[i].second[j].second; + if (info.dist_ > this->threshold()) { + break; + } + + diskann_id_t id = group_topk_list[i].second[j].first; + + if (fetch_vector_) { + group_results_[idx][i].mutable_docs()->emplace_back( + entity_->get_key(id), info.dist_, id, info.vec_); + } else { + group_results_[idx][i].mutable_docs()->emplace_back( + entity_->get_key(id), info.dist_, id); + } + } + } + } + + private: + constexpr static uint32_t kInvalidMgic = -1U; + + uint32_t type_{kUnknownContext}; + + DistCalculator dc_; + DiskAnnEntity::Pointer entity_; + + uint32_t topk_{0}; 
+ uint32_t magic_{0U}; + bool debug_mode_{false}; + uint32_t pq_chunk_num_{0}; + uint32_t element_size_{0}; + uint32_t element_rotated_size_{0}; + uint32_t list_size_{0}; + + TopkHeap topk_heap_{}; + + uint32_t group_topk_{0}; + uint32_t group_num_{0}; + std::map group_topk_heaps_{}; + + IOContext io_ctx_{0}; + SearchStats query_stats_; + + float *pq_table_dist_buffer_{nullptr}; + void *pq_coord_buffer_{nullptr}; + void *query_{nullptr}; + void *query_rotated_{nullptr}; + void *coord_buffer_{nullptr}; + void *sector_buffer_{nullptr}; + + std::vector results_{}; + std::vector group_results_{}; + + bool fetch_vector_{false}; + bool has_error_{false}; + + NeighborPriorityQueue best_list_nodes_; + std::vector expanded_nodes_; + std::vector occlude_factor_; + + VisitFilter visit_filter_{}; + uint32_t filter_mode_{VisitFilter::ByteMap}; + float negative_probility_{DiskAnnEntity::kDefaultBFNegativeProbility}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_dist_calculator.h b/src/core/algorithm/diskann/diskann_dist_calculator.h new file mode 100644 index 000000000..e64b6f35e --- /dev/null +++ b/src/core/algorithm/diskann/diskann_dist_calculator.h @@ -0,0 +1,150 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +class DistCalculator { + public: + typedef std::shared_ptr Pointer; + + public: + //! Constructor + DistCalculator(const DiskAnnEntity *entity, + const IndexMetric::Pointer &measure, uint32_t dim) + : entity_(entity), + distance_(measure->distance()), + query_(nullptr), + dim_(dim), + compare_cnt_(0) {} + + void update(const IndexMetric::Pointer &measure, uint32_t dim) { + distance_ = measure->distance(); + dim_ = dim; + } + + inline void update_distance(const IndexMetric::MatrixDistance &distance) { + distance_ = distance; + } + + //! Reset query vector data + inline void reset_query(const void *query) { + error_ = false; + query_ = query; + } + + //! Returns distance + inline dist_t dist(const void *vec_lhs, const void *vec_rhs) { + if (ailego_unlikely(vec_lhs == nullptr || vec_rhs == nullptr)) { + LOG_ERROR("Nullptr of dense vector"); + + error_ = true; + return 0.0f; + } + + float score{0.0f}; + distance_(vec_lhs, vec_rhs, dim_, &score); + + return score; + } + + //! Returns distance between query and vec. 
+ inline dist_t dist(const void *vec) { + compare_cnt_++; + + return dist(vec, query_); + } + + inline dist_t dist(diskann_id_t id) { + compare_cnt_++; + + const void *vec = entity_->get_vector(id); + if (ailego_unlikely(vec == nullptr)) { + LOG_ERROR("Get nullptr vector, id=%u", id); + error_ = true; + return 0.0f; + } + + return dist(vec, query_); + } + + inline dist_t dist(diskann_id_t lhs, diskann_id_t rhs) { + compare_cnt_++; + + const void *vec_lhs = entity_->get_vector(lhs); + if (ailego_unlikely(vec_lhs == nullptr)) { + LOG_ERROR("Get nullptr vector, lhs id=%u", lhs); + error_ = true; + return 0.0f; + } + + const void *vec_rhs = entity_->get_vector(rhs); + if (ailego_unlikely(vec_rhs == nullptr)) { + LOG_ERROR("Get nullptr vector, rhs id=%u", rhs); + error_ = true; + return 0.0f; + } + + return dist(vec_lhs, vec_rhs); + } + + dist_t operator()(const void *vec) { + return dist(vec); + } + + inline void clear() { + compare_cnt_ = 0; + error_ = false; + } + + inline void clear_compare_cnt() { + compare_cnt_ = 0; + } + + inline bool error() const { + return error_; + } + + //! 
Get distances compute times + inline uint32_t compare_cnt() const { + return compare_cnt_; + } + + inline uint32_t dimension() const { + return dim_; + } + + private: + DistCalculator(const DistCalculator &) = delete; + DistCalculator &operator=(const DistCalculator &) = delete; + + private: + const DiskAnnEntity *entity_; + + IndexMetric::MatrixDistance distance_; + const void *query_; + uint32_t dim_; + + uint32_t compare_cnt_; // record distance compute times + bool error_{false}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_entity.cc b/src/core/algorithm/diskann/diskann_entity.cc new file mode 100644 index 000000000..56d63fc8f --- /dev/null +++ b/src/core/algorithm/diskann/diskann_entity.cc @@ -0,0 +1,32 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +const std::string DiskAnnEntity::kDiskAnnVectorSegmentId = "diskann.vector"; +const std::string DiskAnnEntity::kDiskAnnMetaSegmentId = "diskann.meta"; +const std::string DiskAnnEntity::kDiskAnnPqMetaSegmentId = "diskann.pq_meta"; +const std::string DiskAnnEntity::kDiskAnnPqDataSegmentId = "diskann.pq_data"; +const std::string DiskAnnEntity::kDiskAnnDummpySegmentId = "diskann.dummy"; +const std::string DiskAnnEntity::kDiskAnnKeyMappingSegmentId = + "diskann.key_mapping"; +const std::string DiskAnnEntity::kDiskAnnEntryPointSegmentId = + "diskann.entrypoint"; +const std::string DiskAnnEntity::kDiskAnnKeySegmentId = "diskann.key"; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_entity.h b/src/core/algorithm/diskann/diskann_entity.h new file mode 100644 index 000000000..3575f3fb0 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_entity.h @@ -0,0 +1,241 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include + +namespace zvec { +namespace core { + +using dist_t = float; +using diskann_key_t = uint64_t; +using diskann_id_t = uint32_t; + +constexpr diskann_id_t kInvalidId = static_cast(-1); +constexpr diskann_key_t kInvalidKey = static_cast(-1); + +struct VectorInfo { + public: + float dist_; + std::string vec_; + + public: + VectorInfo() = default; + VectorInfo(float dist, const std::string &vec) : dist_{dist}, vec_{vec} {} +}; + +/*! Key Value Vecotr Heap Comparer + */ +struct KeyValueVectorHeapComparer { + //! Function call + bool operator()(const std::pair &lhs, + const std::pair &rhs) const { + return compare_(lhs.second.dist_, rhs.second.dist_); + } + + private: + std::less compare_; +}; + +/*! Key Value Vector Heap + */ +using TopkHeap = ailego::Heap, + KeyValueVectorHeapComparer>; + +struct DiskAnnMetaHeader { + public: + uint64_t doc_cnt; + uint64_t ndims; + uint64_t medoid; + uint64_t max_node_size; + uint64_t max_degree; + uint64_t node_per_sector; + uint64_t vamana_frozen_num; + uint64_t vamana_frozen_loc; + uint64_t append_reorder_data; + uint64_t index_size; + uint8_t reserved[4048]; /// to fill up to 4096 + + DiskAnnMetaHeader() { + clear(); + } + + DiskAnnMetaHeader(const DiskAnnMetaHeader &header) { + memcpy(this, &header, sizeof(header)); + } + + DiskAnnMetaHeader &operator=(const DiskAnnMetaHeader &header) { + memcpy(this, &header, sizeof(header)); + return *this; + } + + void inline reset() { + doc_cnt = 0U; + } + + void inline clear() { + memset(this, 0, sizeof(DiskAnnMetaHeader)); + } +}; + +struct DiskAnnPqMeta { + public: + uint64_t full_pivot_data_size{0}; + uint64_t centroid_data_size{0}; + uint64_t chunk_offsets_size{0}; + uint64_t chunk_num{0}; + uint8_t reserved[128]; + + DiskAnnPqMeta() { + clear(); + } + + DiskAnnPqMeta(const DiskAnnPqMeta &meta) { + memcpy(this, &meta, sizeof(meta)); + } + + DiskAnnPqMeta &operator=(const DiskAnnPqMeta &meta) { + memcpy(this, &meta, sizeof(meta)); + return *this; + } + + void 
inline clear() { + memset(this, 0, sizeof(DiskAnnPqMeta)); + } +}; + +static_assert(sizeof(DiskAnnMetaHeader) % 32 == 0, + "DiskAnnMetaHeader must be aligned with 32 bytes"); + +static_assert(sizeof(DiskAnnPqMeta) % 32 == 0, + "DiskAnnPqMeta must be aligned with 32 bytes"); + +class DiskAnnEntity { + public: + DiskAnnEntity() = default; + virtual ~DiskAnnEntity() = default; + + //! Constructor + DiskAnnEntity(const DiskAnnMetaHeader &meta_header, + const DiskAnnPqMeta &pq_meta) { + meta_header_ = meta_header; + pq_meta_ = pq_meta; + } + + //! DiskAnnEntity Pointerd; + typedef std::shared_ptr Pointer; + + public: + static inline size_t AlignSize(size_t size) { + return (size + 0xFFF) & (~0xFFF); + } + + public: + virtual int add_vector(diskann_key_t /*key*/, const void * /*vec*/) { + return IndexError_NotImplemented; + } + + virtual const void *get_vector(diskann_id_t /*id*/) const { + return nullptr; + } + + virtual std::pair get_neighbors( + diskann_id_t /*id*/) const { + return std::make_pair(0, nullptr); + } + + virtual int set_neighbors( + diskann_id_t /*id*/, const std::vector & /*neighbor_ids*/) { + return IndexError_NotImplemented; + } + + virtual int add_neighbor(diskann_id_t /*id*/, diskann_id_t /*neighbor_id*/) { + return IndexError_NotImplemented; + } + + //! Get node id of primary key + virtual diskann_id_t get_id(diskann_key_t key) const = 0; + + //! 
Get primary key of the node id + virtual diskann_key_t get_key(diskann_id_t id) const = 0; + + public: + uint64_t max_node_size() const { + return meta_header_.max_node_size; + } + + uint64_t medoid() const { + return meta_header_.medoid; + } + + uint64_t *mutable_medoid() { + return &meta_header_.medoid; + } + + uint64_t node_per_sector() const { + return meta_header_.node_per_sector; + } + + uint64_t pq_chunk_num() { + return pq_meta_.chunk_num; + } + + uint64_t doc_cnt() const { + return meta_header_.doc_cnt; + } + + inline uint64_t *mutable_doc_cnt() { + return &meta_header_.doc_cnt; + } + + uint64_t max_degree() { + return meta_header_.max_degree; + } + + DiskAnnPqMeta *mutable_pq_meta() { + return &pq_meta_; + } + + public: + virtual const DiskAnnEntity::Pointer clone() const { + LOG_ERROR("Update neighbors not implemented"); + return DiskAnnEntity::Pointer(); + } + + public: + const static std::string kDiskAnnVectorSegmentId; + const static std::string kDiskAnnMetaSegmentId; + const static std::string kDiskAnnPqMetaSegmentId; + const static std::string kDiskAnnPqDataSegmentId; + const static std::string kDiskAnnDummpySegmentId; + const static std::string kDiskAnnMappingSegmentId; + const static std::string kDiskAnnKeyMappingSegmentId; + const static std::string kDiskAnnEntryPointSegmentId; + const static std::string kDiskAnnKeySegmentId; + + constexpr static float kDefaultBFNegativeProbility = 0.001f; + constexpr static float kDefaultGraphSlackFactor = 1.3f; + constexpr static float kDefaultAlpha = 1.2f; + constexpr static uint32_t kDefaultMaxOcclusionSize = 750; + constexpr static uint32_t kDefaultMaxDegree = 100; + constexpr static uint32_t kDefaultCompressBatchSize = 5000000; + constexpr static uint32_t kRevision = 0U; + + protected: + DiskAnnMetaHeader meta_header_; + DiskAnnPqMeta pq_meta_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_file_reader.cc b/src/core/algorithm/diskann/diskann_file_reader.cc 
new file mode 100644 index 000000000..3326a7400 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_file_reader.cc @@ -0,0 +1,269 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_file_reader.h" +#include +#include +#include +#include +#include + +#define MAX_EVENTS 1024 + +namespace zvec { +namespace core { + +#if (defined(__linux) || defined(__linux__)) +typedef struct io_event io_event_t; +typedef struct iocb iocb_t; +#endif + +int setup_io_ctx(IOContext &ctx) { +#if (defined(__linux) || defined(__linux__)) + int ret = io_setup(MAX_EVENTS, &ctx); + + return ret; +#else + return 0; +#endif +} + +int destroy_io_ctx(IOContext &ctx) { +#if (defined(__linux) || defined(__linux__)) + int ret = io_destroy(ctx); + + return ret; +#else + return 0; +#endif +} + +static int execute_io_pread(int fd, std::vector &read_reqs) { + for (auto &req : read_reqs) { + ssize_t bytes_read = ::pread(fd, req.buf, req.len, req.offset); + if (bytes_read < 0) { + LOG_ERROR("pread failed; errno=%d, %s, offset=%lu, len=%lu", errno, + ::strerror(errno), (unsigned long)req.offset, + (unsigned long)req.len); + return IndexError_Runtime; + } + if ((size_t)bytes_read != req.len) { + LOG_ERROR("pread short read; got=%zd, expected=%lu", bytes_read, + (unsigned long)req.len); + return IndexError_Runtime; + } + } + return 0; +} + +int execute_io(IOContext ctx, int fd, std::vector &read_reqs, + uint64_t n_retries = 0) { +#if 
(defined(__linux) || defined(__linux__)) + uint64_t iters = + DiskAnnUtil::round_up(read_reqs.size(), MAX_EVENTS) / MAX_EVENTS; + + for (uint64_t iter = 0; iter < iters; iter++) { + uint64_t n_ops = std::min((uint64_t)read_reqs.size() - (iter * MAX_EVENTS), + (uint64_t)MAX_EVENTS); + + std::vector cbs(n_ops, nullptr); + std::vector evts(n_ops); + std::vector cb(n_ops); + for (uint64_t j = 0; j < n_ops; j++) { + io_prep_pread(cb.data() + j, fd, read_reqs[j + iter * MAX_EVENTS].buf, + read_reqs[j + iter * MAX_EVENTS].len, + read_reqs[j + iter * MAX_EVENTS].offset); + } + + for (uint64_t i = 0; i < n_ops; i++) { + cbs[i] = cb.data() + i; + } + + size_t n_tries = 0; + while (n_tries <= n_retries) { + int ret = io_submit(ctx, (int64_t)n_ops, cbs.data()); + + if (ret != (int)n_ops) { + LOG_WARN( + "io_submit failed; returned: %d, expected=%lu, errno=%d, %s, " + "falling back to pread", + ret, n_ops, errno, ::strerror(-ret)); + return execute_io_pread(fd, read_reqs); + } else { + ret = io_getevents(ctx, (int64_t)n_ops, (int64_t)n_ops, evts.data(), + nullptr); + if (ret != (int64_t)n_ops) { + LOG_WARN( + "io_getevents failed; returned: %d, expected=%lu, errno=%d, %s, " + "falling back to pread", + ret, n_ops, errno, ::strerror(-ret)); + return execute_io_pread(fd, read_reqs); + } else { + break; + } + } + n_tries++; + } + } + + return 0; +#else + return execute_io_pread(fd, read_reqs); +#endif +} + +LinuxAlignedFileReader::LinuxAlignedFileReader(int file_desc) { + this->file_desc = file_desc; +} + +LinuxAlignedFileReader::LinuxAlignedFileReader() { + this->file_desc = -1; +} + +LinuxAlignedFileReader::~LinuxAlignedFileReader() { + deregister_all_threads(); + if (file_desc >= 0) { + ::close(file_desc); + file_desc = -1; + } +} + +IOContext &LinuxAlignedFileReader::get_ctx() { + std::unique_lock lk(ctx_mut); + if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) { + std::cerr << "bad thread access; returning -1 as io_context_t" << std::endl; + return this->bad_ctx; 
+ } else { + return ctx_map[std::this_thread::get_id()]; + } +} + +void LinuxAlignedFileReader::register_thread() { +#if (defined(__linux) || defined(__linux__)) + auto thread_id = std::this_thread::get_id(); + std::unique_lock lk(ctx_mut); + if (ctx_map.find(thread_id) != ctx_map.end()) { + LOG_ERROR("multiple calls to register_thread from the same thread"); + + return; + } + + IOContext ctx = nullptr; + + int ret = io_setup(MAX_EVENTS, &ctx); + if (ret != 0) { + lk.unlock(); + if (ret == -EAGAIN) { + LOG_ERROR( + "io_setup failed with EAGAIN: Consider increasing " + "/proc/sys/fs/aio-max-nr"); + } else { + LOG_ERROR("io_setup failed; returned: %d, %s", ret, ::strerror(-ret)); + ; + } + } else { + LOG_INFO("allocating ctx: %lu", (uint64_t)ctx); + + ctx_map[thread_id] = ctx; + } + + lk.unlock(); +#endif +} + +void LinuxAlignedFileReader::deregister_thread() { +#if (defined(__linux) || defined(__linux__)) + auto thread_id = std::this_thread::get_id(); + std::unique_lock lk(ctx_mut); + assert(ctx_map.find(thread_id) != ctx_map.end()); + + lk.unlock(); + IOContext ctx = this->get_ctx(); + io_destroy(ctx); + // assert(ret == 0); + lk.lock(); + ctx_map.erase(thread_id); + + LOG_INFO("returned ctx from thread"); + + lk.unlock(); +#endif +} + +void LinuxAlignedFileReader::deregister_all_threads() { +#if (defined(__linux) || defined(__linux__)) + std::unique_lock lk(ctx_mut); + for (auto x = ctx_map.begin(); x != ctx_map.end(); x++) { + IOContext ctx = x->second; + io_destroy(ctx); + } + ctx_map.clear(); +#endif +} + +void LinuxAlignedFileReader::open(const std::string &fname) { + int flags = O_RDONLY; + +#if defined(__linux__) || defined(__linux) + flags |= O_DIRECT | O_LARGEFILE; +#endif + + this->file_desc = ::open(fname.c_str(), flags); + +#if defined(__linux__) || defined(__linux) + // O_DIRECT may not be supported on all filesystems (e.g. tmpfs, overlay). + // Fall back to regular buffered I/O when it fails. 
+ if (this->file_desc == -1) { + LOG_WARN( + "open with O_DIRECT failed for %s (errno=%d: %s), " + "falling back to buffered I/O", + fname.c_str(), errno, ::strerror(errno)); + this->file_desc = ::open(fname.c_str(), O_RDONLY | O_LARGEFILE); + } +#endif + + if (this->file_desc == -1) { + LOG_ERROR("Failed to open file: %s (errno=%d: %s)", fname.c_str(), errno, + ::strerror(errno)); + } + + LOG_INFO("Opened file : %s", fname.c_str()); +} + +void LinuxAlignedFileReader::close() { + if (file_desc >= 0) { + ::close(file_desc); + file_desc = -1; + } +} + +int LinuxAlignedFileReader::read(std::vector &read_reqs, + IOContext &ctx, bool async) { + if (async == true) { + LOG_WARN("Async currently not supported"); + } + + if (this->file_desc == -1) { + LOG_ERROR("Attempt to read from invalid file descriptor"); + return IndexError_Runtime; + } + + int ret = execute_io(ctx, this->file_desc, read_reqs); + + return ret; +} + + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_file_reader.h b/src/core/algorithm/diskann/diskann_file_reader.h new file mode 100644 index 000000000..8099946aa --- /dev/null +++ b/src/core/algorithm/diskann/diskann_file_reader.h @@ -0,0 +1,104 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#define MAX_IO_DEPTH 128 + +#include + +#if (defined(__linux) || defined(__linux__)) +#include +#endif + +#include +#include +#include +#include +#include "diskann_util.h" + +namespace zvec { +namespace core { + +#if (defined(__linux) || defined(__linux__)) +typedef io_context_t IOContext; +#else +typedef uint32_t IOContext; +#endif + +int setup_io_ctx(IOContext &ctx); +int destroy_io_ctx(IOContext &ctx); + +struct AlignedRead { + uint64_t offset; + uint64_t len; + void *buf; + + AlignedRead() : offset(0), len(0), buf(nullptr) {} + + AlignedRead(uint64_t offset, uint64_t len, void *buf) + : offset(offset), len(len), buf(buf) { + ailego_assert(static_cast(offset) % 512 == 0); + ailego_assert(static_cast(len) % 512 == 0); + ailego_assert(reinterpret_cast(buf) % 512 == 0); + } +}; + +class AlignedFileReader { + protected: + std::map ctx_map; + std::mutex ctx_mut; + + public: + virtual IOContext &get_ctx() = 0; + + virtual ~AlignedFileReader() {}; + + virtual void register_thread() = 0; + virtual void deregister_thread() = 0; + virtual void deregister_all_threads() = 0; + + virtual void open(const std::string &fname) = 0; + virtual void close() = 0; + + virtual int read(std::vector &read_reqs, IOContext &ctx, + bool async = false) = 0; +}; + +class LinuxAlignedFileReader : public AlignedFileReader { + private: + uint64_t file_sz; + int file_desc; + + IOContext bad_ctx = (IOContext)-1; + + public: + LinuxAlignedFileReader(); + LinuxAlignedFileReader(int file_desc); + ~LinuxAlignedFileReader(); + + public: + IOContext &get_ctx(); + + void register_thread(); + void deregister_thread(); + void deregister_all_threads(); + void open(const std::string &fname); + void close(); + + int read(std::vector &read_reqs, IOContext &ctx, + bool async = false); +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_holder.cc b/src/core/algorithm/diskann/diskann_holder.cc new file mode 100644 index 000000000..236e4ac21 --- 
/dev/null +++ b/src/core/algorithm/diskann/diskann_holder.cc @@ -0,0 +1,18 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +namespace zvec { +namespace core {} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_holder.h b/src/core/algorithm/diskann/diskann_holder.h new file mode 100644 index 000000000..e37fad6ac --- /dev/null +++ b/src/core/algorithm/diskann/diskann_holder.h @@ -0,0 +1,321 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +struct DiskAnnIndexHolderMeta { + uint32_t element_size_; + uint32_t key_size_; + uint32_t sector_size_; + uint32_t doc_cnt_; + uint8_t reserve_[]; +}; + +class DiskAnnIndexHolder : public IndexHolder { + public: + typedef std::shared_ptr Pointer; + + public: + enum Status { STATUS_UNINITED = 0, STATUS_WRITE = 1, STATUS_READ = 2 }; + + public: + static constexpr uint32_t kDataSectorSize = 128 * 1024; + static constexpr uint32_t kMetaSectorSize = 4096; + + public: + inline static uint32_t get_sector_id(uint32_t id, uint32_t sector_vec_num) { + return id / sector_vec_num; + } + + inline static uint32_t get_sector_offset(uint32_t id, uint32_t sector_vec_num, + uint32_t data_size) { + return (id % sector_vec_num) * data_size; + } + + public: + /*! Random Access Index Holder Iterator + */ + class Iterator : public IndexHolder::Iterator { + public: + //! Index Holder Iterator Pointer + typedef std::unique_ptr Pointer; + + //! Constructor + Iterator(DiskAnnIndexHolder *owner) + : holder_(owner), sector_id_{0}, sector_offset_{0} { + path_ = holder_->path(); + data_size_ = holder_->data_size(); + data_sector_size_ = holder_->data_sector_size(); + meta_sector_size_ = holder_->meta_sector_size(); + + sector_buffer_.resize(data_sector_size_); + + sector_vec_num_ = data_sector_size_ / data_size_; + } + + //! Destructor + virtual ~Iterator(void) { + if (file_.is_open()) { + file_.close(); + } + } + + int init() { + file_.open(path_, std::ios::in); + if (!file_.is_open()) { + LOG_ERROR("file can not create, %s", path_.c_str()); + return IndexError_OpenFile; + } + + file_.seekg(meta_sector_size_); + + read_sector(); + + return 0; + } + + //! Retrieve pointer of data + virtual const void *data(void) const override { + const uint8_t *data_ptr = + reinterpret_cast(sector_buffer_.data()); + return data_ptr + sector_offset_ + sizeof(diskann_key_t); + } + + //! 
Test if the iterator is valid + virtual bool is_valid(void) const override { + return id_ < holder_->count(); + } + + //! Retrieve primary key + virtual uint64_t key(void) const override { + const uint8_t *data_ptr = + reinterpret_cast(sector_buffer_.data()); + uint64_t key = + *reinterpret_cast(data_ptr + sector_offset_); + + return key; + } + + //! Next iterator + virtual void next(void) override { + ++id_; + + uint32_t sector_id = get_sector_id(id_, sector_vec_num_); + if (sector_id > sector_id_) { + file_.seekg(sector_id * data_sector_size_ + meta_sector_size_); + read_sector(); + sector_id_ = sector_id; + } + + sector_offset_ = get_sector_offset(id_, sector_vec_num_, data_size_); + } + + int read_sector() { + file_.read(&((sector_buffer_)[0]), data_sector_size_); + + return 0; + } + + private: + //! Members + DiskAnnIndexHolder *holder_{nullptr}; + std::string path_; + std::ifstream file_; + uint32_t sector_id_{0}; + std::string sector_buffer_; + uint32_t id_{0}; + uint32_t data_size_{0}; + uint32_t sector_offset_{0}; + uint32_t data_sector_size_{0}; + uint32_t meta_sector_size_{0}; + uint32_t sector_vec_num_{0}; + }; + + public: + DiskAnnIndexHolder(IndexMeta &meta, std::string &path) { + path_ = path; + + data_size_ = meta.element_size() + sizeof(diskann_key_t); + dimension_ = meta.dimension(); + type_ = meta.data_type(); + + element_size_ = meta.element_size(); + sector_vec_num_ = data_sector_size_ / data_size_; + padding_size_ = data_sector_size_ - sector_vec_num_ * data_size_; + sector_buffer_.resize(data_sector_size_); + sector_internal_id_ = 0; + } + + ~DiskAnnIndexHolder() { + if (file_.is_open()) { + file_.close(); + } + } + + //! 
Init + int init() { + // file_.open(path, std::ios::in | std::ios::out); + // file_.open(path_, std::ios::in | std::ios::out | std::ios::trunc); + file_.open(path_, std::ios::out | std::ios::trunc); + + if (!file_.is_open()) { + LOG_ERROR("file can not create, %s", path_.c_str()); + return IndexError_OpenFile; + } + + DiskAnnIndexHolderMeta holder_meta; + holder_meta.element_size_ = element_size_; + holder_meta.key_size_ = sizeof(diskann_key_t); + holder_meta.sector_size_ = data_sector_size_; + + std::vector empty_sector; + empty_sector.resize(meta_sector_size_); + + std::memset(&(empty_sector[0]), 0, meta_sector_size_); + std::memcpy(&(empty_sector[0]), &holder_meta, + sizeof(DiskAnnIndexHolderMeta)); + + file_.write(reinterpret_cast(&(empty_sector[0])), + meta_sector_size_); + + status_ = STATUS_WRITE; + + return 0; + } + + int close() { + if (sector_internal_id_ != 0) { + file_.write(reinterpret_cast(&(sector_buffer_[0])), + data_sector_size_); + } + + file_.close(); + + return 0; + } + + //! Retrieve count of elements in holder (-1 indicates unknown) + virtual size_t count(void) const override { + return count_; + } + + //! Retrieve dimension + virtual size_t dimension(void) const override { + return dimension_; + } + + //! Retrieve type information + virtual IndexMeta::DataType data_type(void) const override { + return type_; + } + + //! Retrieve element size in bytes + virtual size_t element_size(void) const override { + return element_size_; + } + + //! Retrieve if it can multi-pass + virtual bool multipass(void) const override { + return true; + } + + //! 
Create a new iterator + virtual IndexHolder::Iterator::Pointer create_iterator(void) override { + DiskAnnIndexHolder::Iterator::Pointer pointer( + new DiskAnnIndexHolder::Iterator(this)); + + if (pointer->init() != 0) { + return nullptr; + } + + return pointer; + } + + int emplace(uint64_t pkey, const void *vec) { + if (status_ != STATUS_WRITE) { + return IndexError_NoReady; + } + + uint8_t *data_ptr = reinterpret_cast(&(sector_buffer_[0])) + + sector_internal_id_ * data_size_; + std::memcpy(data_ptr, &pkey, sizeof(diskann_key_t)); + std::memcpy(data_ptr + sizeof(diskann_key_t), vec, element_size_); + + sector_internal_id_++; + if (sector_internal_id_ >= sector_vec_num_) { + std::vector padding_(padding_size_, 0); + std::memcpy(data_ptr + data_size_, padding_.data(), padding_size_); + + file_.write(reinterpret_cast(&(sector_buffer_[0])), + data_sector_size_); + + sector_internal_id_ = 0; + sector_id_++; + } + + count_++; + + return 0; + } + + uint32_t data_sector_size() { + return data_sector_size_; + } + + uint32_t meta_sector_size() { + return meta_sector_size_; + } + + uint32_t data_size() { + return data_size_; + } + + uint32_t *mutable_sector_id() { + return &sector_id_; + } + + uint32_t sector_id() { + return sector_id_; + } + + std::string &path() { + return path_; + } + + private: + std::string path_; + std::ofstream file_; + uint32_t element_size_{0}; + uint32_t dimension_{0}; + IndexMeta::DataType type_{IndexMeta::DataType::DT_UNDEFINED}; + uint32_t sector_vec_num_{0}; + uint32_t data_size_{0}; + uint32_t padding_size_{0}; + uint32_t meta_sector_size_{DiskAnnIndexHolder::kMetaSectorSize}; + uint32_t data_sector_size_{DiskAnnIndexHolder::kDataSectorSize}; + std::string sector_buffer_; + uint32_t sector_internal_id_{0}; + uint32_t sector_id_{0}; + uint32_t count_{0}; + uint32_t status_{STATUS_UNINITED}; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_indexer.cc 
b/src/core/algorithm/diskann/diskann_indexer.cc new file mode 100644 index 000000000..bc050ee57 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_indexer.cc @@ -0,0 +1,1252 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_indexer.h" +#include +#include +#include +#include +#include + +namespace zvec { +namespace core { + +DiskAnnIndexer::DiskAnnIndexer(const IndexMeta &meta) { + meta_ = meta; +} + +DiskAnnIndexer::~DiskAnnIndexer() { + destroy_io_ctx(init_ctx_); + if (centroid_data_) { + free(centroid_data_); + } + DiskAnnUtil::free_aligned(coord_cache_buf_); +} + +int DiskAnnIndexer::init(DiskAnnSearcherEntity &entity) { + entity_ = &entity; + + auto storage = entity.get_storage(); + auto vector_segment = entity.get_vector_segment(); + + pq_table_ = entity.get_pq_table(); + + index_segment_offset_ = vector_segment->data_offset(); + + bool with_direct_io = true; + if (with_direct_io) { + reader_.reset(new LinuxAlignedFileReader()); + + auto file_path = storage->file_path(); + reader_->open(file_path); + + storage->cleanup(); + } else { + auto file = storage->file(); + reader_.reset(new LinuxAlignedFileReader(file->native_handle())); + } + + int ret = setup_io_ctx(init_ctx_); + if (ret != 0) { + LOG_ERROR("setup io ctx error"); + return ret; + } + + max_node_size_ = entity.max_node_size(); + disk_bytes_per_point_ = meta_.element_size(); + + node_per_sector_ = entity.node_per_sector(); + 
aligned_dim_ = meta_.dimension(); + + pq_chunk_num_ = entity.pq_chunk_num(); + + medoid_ = entity.medoid(); + + entrypints_.push_back(medoid_); + auto &entrypoints = entity.entrypoints(); + for (size_t i = 0; i < entrypoints.size(); ++i) { + entrypints_.push_back(entrypoints[i]); + } + + doc_cnt_ = entity.doc_cnt(); + + max_degree_ = entity.max_degree(); + + sector_num_per_node_ = + DiskAnnUtil::div_round_up(max_node_size_, DiskAnnUtil::kSectorSize); + if (beam_width_ > sector_num_per_node_ * DiskAnnUtil::kMaxSectorReadNum) { + LOG_ERROR("Beamwidth can not be higher than kMaxSectorReadNum"); + + return IndexError_InvalidArgument; + } + + DiskAnnUtil::alloc_aligned((void **)(¢roid_data_), + entrypints_.size() * aligned_dim_ * sizeof(float), + 32); + + use_medroids_data_as_centroids(); + + return 0; +} + +int DiskAnnIndexer::use_medroids_data_as_centroids() { + LOG_INFO("Loading centroid data from medoid vector data"); + + std::vector nodes_to_read; + std::vector medoid_bufs; + std::vector> neighbor_bufs; + + std::vector centroid_buffer; + + size_t dim = meta_.dimension(); + centroid_buffer.resize(dim); + + nodes_to_read.push_back(medoid_); + medoid_bufs.push_back(&(centroid_buffer[0])); + neighbor_bufs.emplace_back(0, nullptr); + + auto read_status = read_nodes(nodes_to_read, medoid_bufs, neighbor_bufs); + + if (read_status[0] == true) { + for (uint32_t i = 0; i < dim; i++) centroid_data_[i] = centroid_buffer[i]; + } else { + LOG_ERROR("Failed to read medoid"); + return IndexError_Runtime; + } + + return 0; +} + +diskann_key_t DiskAnnIndexer::get_key(diskann_id_t id) const { + return entity_->get_key(id); +} + +diskann_id_t DiskAnnIndexer::get_id(diskann_key_t key) const { + return entity_->get_id(key); +} + +std::vector DiskAnnIndexer::read_nodes( + const std::vector &node_ids, + std::vector &coord_buffers, + std::vector> &neighbor_buffers) { + std::vector read_reqs; + std::vector retval(node_ids.size(), true); + + uint8_t *buf = nullptr; + auto sector_num = + 
node_per_sector_ > 0 + ? 1 + : DiskAnnUtil::div_round_up(max_node_size_, DiskAnnUtil::kSectorSize); + DiskAnnUtil::alloc_aligned( + (void **)&buf, node_ids.size() * sector_num * DiskAnnUtil::kSectorSize, + DiskAnnUtil::kSectorSize); + + for (size_t i = 0; i < node_ids.size(); ++i) { + auto node_id = node_ids[i]; + + AlignedRead read; + read.len = sector_num * DiskAnnUtil::kSectorSize; + read.buf = buf + i * sector_num * DiskAnnUtil::kSectorSize; + read.offset = + index_segment_offset_ + + DiskAnnUtil::get_node_sector(node_per_sector_, max_node_size_, + DiskAnnUtil::kSectorSize, node_id) * + DiskAnnUtil::kSectorSize; + read_reqs.push_back(read); + } + + int read_ret = reader_->read(read_reqs, init_ctx_); + if (read_ret != 0) { + LOG_ERROR("read_nodes: reader_->read failed, ret=%d", read_ret); + for (size_t i = 0; i < retval.size(); i++) { + retval[i] = false; + } + DiskAnnUtil::free_aligned(buf); + return retval; + } + + for (uint32_t i = 0; i < read_reqs.size(); i++) { + uint8_t *node_buf = + DiskAnnUtil::offset_to_node(node_per_sector_, max_node_size_, + (uint8_t *)read_reqs[i].buf, node_ids[i]); + + if (coord_buffers[i] != nullptr) { + void *node_coords = node_buf; + memcpy(coord_buffers[i], node_coords, disk_bytes_per_point_); + } + + if (neighbor_buffers[i].second != nullptr) { + uint32_t *node_neighbor = + DiskAnnUtil::offset_to_node_neighbor(node_buf, meta_.element_size()); + uint32_t neighbor_num = *node_neighbor; + + neighbor_buffers[i].first = neighbor_num; + memcpy(neighbor_buffers[i].second, node_neighbor + 1, + neighbor_num * sizeof(diskann_id_t)); + } + } + + DiskAnnUtil::free_aligned(buf); + + return retval; +} + +int DiskAnnIndexer::load_cache_list( + const std::vector &node_list) { + LOG_INFO("Loading the cache list into memory"); + + size_t num_cached_nodes = node_list.size(); + + neighbor_cache_buffer_.resize(num_cached_nodes * (max_degree_ + 1), 0); + + size_t coord_cache_buf_len = num_cached_nodes * aligned_dim_; + 
DiskAnnUtil::alloc_aligned((void **)&coord_cache_buf_, + coord_cache_buf_len * meta_.unit_size(), + 8 * meta_.unit_size()); + + memset(coord_cache_buf_, 0, coord_cache_buf_len * meta_.unit_size()); + + constexpr size_t BLOCK_SIZE = 8; + size_t num_blocks = DiskAnnUtil::div_round_up(num_cached_nodes, BLOCK_SIZE); + for (size_t block = 0; block < num_blocks; block++) { + size_t start_idx = block * BLOCK_SIZE; + size_t end_idx = std::min(num_cached_nodes, (block + 1) * BLOCK_SIZE); + + std::vector nodes_to_read; + std::vector coord_buffers; + std::vector> neighbor_buffers; + for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++) { + nodes_to_read.push_back(node_list[node_idx]); + coord_buffers.push_back(reinterpret_cast(coord_cache_buf_) + + node_idx * meta_.element_size()); + neighbor_buffers.emplace_back( + 0, neighbor_cache_buffer_.data() + node_idx * (max_degree_ + 1)); + } + + auto read_status = + read_nodes(nodes_to_read, coord_buffers, neighbor_buffers); + + for (size_t i = 0; i < read_status.size(); i++) { + if (read_status[i] == true) { + coord_cache_.insert(std::make_pair(nodes_to_read[i], coord_buffers[i])); + neighbor_cache_.insert( + std::make_pair(nodes_to_read[i], neighbor_buffers[i])); + } + } + } + + LOG_INFO("Load Cache List Done"); + + return 0; +} + +void DiskAnnIndexer::cache_bfs_levels(uint64_t num_nodes_to_cache, + std::vector &node_list) { + std::set node_set; + + size_t tenp_cnt = static_cast(std::round(doc_cnt_ * 0.1)); + if (num_nodes_to_cache > tenp_cnt) { + LOG_WARN( + "Reducing nodes to cache from: %zu, to: (10 percent of total nodes: " + "%zu)", + (size_t)num_nodes_to_cache, (size_t)tenp_cnt); + + num_nodes_to_cache = tenp_cnt == 0 ? 
1 : tenp_cnt; + } + + LOG_INFO("Begin to cache %zu Nodes", (size_t)num_nodes_to_cache); + + std::unordered_set cur_level; + std::unordered_set prev_level; + + for (uint64_t iter = 0; + iter < entrypints_.size() && cur_level.size() < num_nodes_to_cache; + iter++) { + cur_level.insert(entrypints_[iter]); + } + + uint64_t level = 1; + uint64_t prev_node_set_size = 0; + while ((node_set.size() + cur_level.size() < num_nodes_to_cache) && + cur_level.size() != 0) { + prev_level.swap(cur_level); + + cur_level.clear(); + + std::vector nodes_to_expand; + nodes_to_expand.reserve(prev_level.size()); + + for (const diskann_id_t &id : prev_level) { + if (node_set.find(id) != node_set.end()) { + continue; + } + + node_set.insert(id); + nodes_to_expand.push_back(id); + } + + std::sort(nodes_to_expand.begin(), nodes_to_expand.end()); + + bool finish_flag = false; + + constexpr uint64_t BLOCK_SIZE = 1024; + uint64_t nblocks = + DiskAnnUtil::div_round_up(nodes_to_expand.size(), BLOCK_SIZE); + for (size_t block = 0; block < nblocks && !finish_flag; block++) { + size_t start = block * BLOCK_SIZE; + size_t end = std::min((uint64_t)((block + 1) * BLOCK_SIZE), + (uint64_t)(nodes_to_expand.size())); + const size_t block_size = end - start; + + std::vector nodes_to_read(nodes_to_expand.begin() + start, + nodes_to_expand.begin() + end); + std::vector coord_buffers(block_size, nullptr); + + std::vector>> + neighbor_buffers; + neighbor_buffers.reserve(block_size); + + for (size_t i = 0; i < block_size; i++) { + neighbor_buffers.emplace_back( + 0, std::vector(max_degree_ + 1)); + } + + std::vector> neighbor_buffers_ptr; + neighbor_buffers_ptr.reserve(block_size); + for (size_t i = 0; i < block_size; i++) { + neighbor_buffers_ptr.emplace_back(neighbor_buffers[i].first, + neighbor_buffers[i].second.data()); + } + + auto read_status = + read_nodes(nodes_to_read, coord_buffers, neighbor_buffers_ptr); + + for (uint32_t i = 0; i < read_status.size(); i++) { + if (read_status[i] == false) { + 
continue; + } else { + neighbor_buffers[i].first = neighbor_buffers_ptr[i].first; + uint32_t neighbor_num = neighbor_buffers[i].first; + diskann_id_t *neighbors = neighbor_buffers[i].second.data(); + + for (uint32_t j = 0; j < neighbor_num && !finish_flag; j++) { + if (node_set.find(neighbors[j]) == node_set.end()) { + cur_level.insert(neighbors[j]); + } + if (cur_level.size() + node_set.size() >= num_nodes_to_cache) { + finish_flag = true; + } + } + } + } + } + + size_t total_size = node_set.size(); + + LOG_INFO("Level: %zu, Cached Size: %zu, Total Cached Size: %zu", + (size_t)level, (size_t)(total_size - prev_node_set_size), + total_size); + + prev_node_set_size = total_size; + level++; + } + + ailego_assert(node_set.size() + cur_level.size() == num_nodes_to_cache || + cur_level.size() == 0); + + node_list.clear(); + node_list.reserve(node_set.size() + cur_level.size()); + + for (auto node : node_set) { + node_list.push_back(node); + } + + for (auto node : cur_level) { + node_list.push_back(node); + } + + size_t total_size = node_list.size(); + LOG_INFO("Level: %zu, Cached Size: %zu, Total Cached Size: %zu", + (size_t)level, (size_t)(total_size - prev_node_set_size), + (size_t)total_size); + + return; +} + +int DiskAnnIndexer::linear_search(DiskAnnContext *ctx) { + auto &stats = ctx->query_stats(); + auto &dc = ctx->dist_calculator(); + auto &topk_heap = ctx->topk_heap(); + + topk_heap.clear(); + + IOContext &io_ctx = ctx->io_ctx(); + void *aligned_query_raw = ctx->query(); + + void *data_buf = reinterpret_cast(ctx->coord_buffer()); + + uint8_t *sector_buffer = reinterpret_cast(ctx->sector_buffer()); + + const uint64_t sector_num_per_node = + node_per_sector_ > 0 + ? 
1 + : DiskAnnUtil::div_round_up(max_node_size_, DiskAnnUtil::kSectorSize); + + ailego::ElapsedTime io_timer; + ailego::ElapsedTime query_timer; + ailego::ElapsedTime cpu_timer; + + uint32_t num_ios = 0; + + std::vector frontier; + frontier.reserve(2 * beam_width_); + + std::vector> frontier_neighbors; + frontier_neighbors.reserve(2 * beam_width_); + + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width_); + + std::vector>> + cached_neighbors; + cached_neighbors.reserve(2 * beam_width_); + + uint64_t sector_buffer_idx = 0; + + diskann_id_t id = 0; + while (id < doc_cnt_) { + while (frontier.size() < beam_width_) { + if (!ctx->filter().is_valid() || !ctx->filter()(get_key(id))) { + auto iter = neighbor_cache_.find(id); + if (iter != neighbor_cache_.end()) { + cached_neighbors.push_back(std::make_pair(id, iter->second)); + stats.cache_hits++; + } else { + frontier.push_back(id); + } + } + + id++; + if (id >= doc_cnt_) { + break; + } + } + + if (!frontier.empty()) { + for (uint64_t i = 0; i < frontier.size(); i++) { + diskann_id_t cur_id = frontier[i]; + + std::pair frontier_neighbor; + frontier_neighbor.first = cur_id; + frontier_neighbor.second = sector_buffer + sector_num_per_node * + sector_buffer_idx * + DiskAnnUtil::kSectorSize; + frontier_neighbors.push_back(frontier_neighbor); + + sector_buffer_idx++; + + frontier_read_reqs.emplace_back( + index_segment_offset_ + + DiskAnnUtil::get_node_sector(node_per_sector_, max_node_size_, + DiskAnnUtil::kSectorSize, cur_id) * + DiskAnnUtil::kSectorSize, + sector_num_per_node * DiskAnnUtil::kSectorSize, + frontier_neighbor.second); + + stats.disk_page_reads++; + stats.io_num++; + num_ios++; + } + + io_timer.reset(); + + int read_ret = reader_->read(frontier_read_reqs, io_ctx); + stats.io_us += io_timer.micro_seconds(); + if (read_ret != 0) { + LOG_ERROR("linear_search: reader_->read failed, ret=%d", read_ret); + ctx->set_error(true); + return IndexError_Runtime; + } + } + + for (auto &cached_neighbor 
: cached_neighbors) { + auto global_cache_iter = coord_cache_.find(cached_neighbor.first); + void *node_fp_coords_copy = global_cache_iter->second; + + float cur_expanded_dist = dc.dist(aligned_query_raw, node_fp_coords_copy); + + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), node_fp_coords_copy, meta_.element_size()); + + topk_heap.emplace(cached_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + } + + for (auto &frontier_neighbor : frontier_neighbors) { + uint8_t *node_disk_buf = DiskAnnUtil::offset_to_node( + node_per_sector_, max_node_size_, frontier_neighbor.second, + frontier_neighbor.first); + + void *node_fp_coords = node_disk_buf; + memcpy(data_buf, node_fp_coords, disk_bytes_per_point_); + + float cur_expanded_dist = dc.dist(aligned_query_raw, data_buf); + + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), data_buf, meta_.element_size()); + + topk_heap.emplace(frontier_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + + stats.cpu_us += cpu_timer.micro_seconds(); + } + + frontier.clear(); + frontier_neighbors.clear(); + frontier_read_reqs.clear(); + cached_neighbors.clear(); + sector_buffer_idx = 0; + } + + stats.total_us += query_timer.micro_seconds(); + + return 0; +} + +int DiskAnnIndexer::keys_search(const std::vector &keys, + DiskAnnContext *ctx) { + auto &stats = ctx->query_stats(); + auto &dc = ctx->dist_calculator(); + auto &topk_heap = ctx->topk_heap(); + + topk_heap.clear(); + + IOContext &io_ctx = ctx->io_ctx(); + void *aligned_query_raw = ctx->query(); + + void *data_buf = reinterpret_cast(ctx->coord_buffer()); + + uint8_t *sector_buffer = reinterpret_cast(ctx->sector_buffer()); + + const uint64_t sector_num_per_node = + node_per_sector_ > 0 + ? 
1 + : DiskAnnUtil::div_round_up(max_node_size_, DiskAnnUtil::kSectorSize); + + ailego::ElapsedTime query_timer; + ailego::ElapsedTime io_timer; + ailego::ElapsedTime cpu_timer; + + uint32_t num_ios = 0; + + std::vector frontier; + frontier.reserve(2 * beam_width_); + + std::vector> frontier_neighbors; + frontier_neighbors.reserve(2 * beam_width_); + + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width_); + + std::vector>> + cached_neighbors; + cached_neighbors.reserve(2 * beam_width_); + + uint64_t sector_buffer_idx = 0; + + size_t idx = 0; + while (idx < keys.size()) { + while (frontier.size() < beam_width_) { + if (!ctx->filter().is_valid() || !ctx->filter()(keys[idx])) { + diskann_id_t id = get_id(keys[idx]); + + auto iter = neighbor_cache_.find(id); + if (iter != neighbor_cache_.end()) { + cached_neighbors.push_back(std::make_pair(id, iter->second)); + stats.cache_hits++; + } else { + frontier.push_back(id); + } + } + + idx++; + if (idx >= keys.size()) { + break; + } + } + + if (!frontier.empty()) { + for (uint64_t i = 0; i < frontier.size(); i++) { + diskann_id_t cur_id = frontier[i]; + + std::pair frontier_neighbor; + frontier_neighbor.first = cur_id; + frontier_neighbor.second = sector_buffer + sector_num_per_node * + sector_buffer_idx * + DiskAnnUtil::kSectorSize; + frontier_neighbors.push_back(frontier_neighbor); + + sector_buffer_idx++; + + frontier_read_reqs.emplace_back( + index_segment_offset_ + + DiskAnnUtil::get_node_sector(node_per_sector_, max_node_size_, + DiskAnnUtil::kSectorSize, cur_id) * + DiskAnnUtil::kSectorSize, + sector_num_per_node * DiskAnnUtil::kSectorSize, + frontier_neighbor.second); + + stats.disk_page_reads++; + stats.io_num++; + num_ios++; + } + + io_timer.reset(); + + int read_ret = reader_->read(frontier_read_reqs, io_ctx); + stats.io_us += io_timer.micro_seconds(); + if (read_ret != 0) { + LOG_ERROR("keys_search: reader_->read failed, ret=%d", read_ret); + ctx->set_error(true); + return 
IndexError_Runtime; + } + } + + for (auto &cached_neighbor : cached_neighbors) { + auto global_cache_iter = coord_cache_.find(cached_neighbor.first); + void *node_fp_coords_copy = global_cache_iter->second; + + float cur_expanded_dist = dc.dist(aligned_query_raw, node_fp_coords_copy); + + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), node_fp_coords_copy, meta_.element_size()); + + topk_heap.emplace(cached_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + } + + for (auto &frontier_neighbor : frontier_neighbors) { + uint8_t *node_disk_buf = DiskAnnUtil::offset_to_node( + node_per_sector_, max_node_size_, frontier_neighbor.second, + frontier_neighbor.first); + + void *node_fp_coords = node_disk_buf; + memcpy(data_buf, node_fp_coords, disk_bytes_per_point_); + + float cur_expanded_dist = dc.dist(aligned_query_raw, data_buf); + + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), data_buf, meta_.element_size()); + + topk_heap.emplace(frontier_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + + stats.cpu_us += cpu_timer.micro_seconds(); + } + + frontier.clear(); + frontier_neighbors.clear(); + frontier_read_reqs.clear(); + cached_neighbors.clear(); + sector_buffer_idx = 0; + } + + stats.total_us += query_timer.micro_seconds(); + + return 0; +} + +int DiskAnnIndexer::get_vector(diskann_id_t id, IndexContext::Pointer &context, + std::string &vector) { + DiskAnnContext *ctx = dynamic_cast(context.get()); + + auto &stats = ctx->query_stats(); + + IOContext &io_ctx = ctx->io_ctx(); + + uint8_t *sector_buffer = reinterpret_cast(ctx->sector_buffer()); + + const uint64_t sector_num_per_node = + node_per_sector_ > 0 + ? 
1 + : DiskAnnUtil::div_round_up(max_node_size_, DiskAnnUtil::kSectorSize); + + ailego::ElapsedTime query_timer; + ailego::ElapsedTime io_timer; + ailego::ElapsedTime cpu_timer; + + uint32_t num_ios = 0; + + std::vector frontier; + frontier.reserve(2 * beam_width_); + + std::vector> frontier_neighbors; + frontier_neighbors.reserve(2 * beam_width_); + + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width_); + + std::vector>> + cached_neighbors; + cached_neighbors.reserve(2 * beam_width_); + + auto iter = neighbor_cache_.find(id); + if (iter != neighbor_cache_.end()) { + void *node_fp_coords_copy = iter->second.second; + + vector.resize(meta_.element_size()); + ::memcpy(&(vector[0]), node_fp_coords_copy, meta_.element_size()); + + return 0; + } else { + std::pair frontier_neighbor; + frontier_neighbor.first = id; + frontier_neighbor.second = sector_buffer; + frontier_neighbors.push_back(frontier_neighbor); + + frontier_read_reqs.emplace_back( + index_segment_offset_ + + DiskAnnUtil::get_node_sector(node_per_sector_, max_node_size_, + DiskAnnUtil::kSectorSize, id) * + DiskAnnUtil::kSectorSize, + sector_num_per_node * DiskAnnUtil::kSectorSize, + frontier_neighbor.second); + + stats.disk_page_reads++; + stats.io_num++; + num_ios++; + + io_timer.reset(); + + reader_->read(frontier_read_reqs, io_ctx); + stats.io_us += io_timer.micro_seconds(); + + uint8_t *node_disk_buf = DiskAnnUtil::offset_to_node( + node_per_sector_, max_node_size_, frontier_neighbor.second, id); + + void *node_fp_coords = node_disk_buf; + + vector.resize(meta_.element_size()); + ::memcpy(&(vector[0]), node_fp_coords, meta_.element_size()); + + stats.cpu_us += cpu_timer.micro_seconds(); + } + + return 0; +} + +int DiskAnnIndexer::knn_search(DiskAnnContext *ctx) { + int ret = cached_beam_search(ctx); + if (ret != 0) { + return ret; + } + + if (ctx->group_by_search()) { + ret = cached_beam_search_by_group(ctx); + if (ret != 0) { + return ret; + } + } + + return 0; +} + +int 
DiskAnnIndexer::cached_beam_search(DiskAnnContext *ctx) { + auto &stats = ctx->query_stats(); + auto &dc = ctx->dist_calculator(); + auto &topk_heap = ctx->topk_heap(); + auto &visit_filter = ctx->visit_filter(); + + topk_heap.clear(); + + IOContext &io_ctx = ctx->io_ctx(); + + void *data_buf = reinterpret_cast(ctx->coord_buffer()); + uint8_t *sector_buffer = reinterpret_cast(ctx->sector_buffer()); + + const uint64_t sector_num_per_node = + node_per_sector_ > 0 + ? 1 + : DiskAnnUtil::div_round_up(max_node_size_, DiskAnnUtil::kSectorSize); + + pq_table_->preprocess_pq_dist_table(ctx->query_rotated(), + ctx->pq_table_dist_buffer()); + + ailego::ElapsedTime query_timer; + ailego::ElapsedTime io_timer; + ailego::ElapsedTime cpu_timer; + + NeighborPriorityQueue candidates; + + candidates.reserve(ctx->list_size()); + + diskann_id_t best_medoid = 0; + float best_dist = (std::numeric_limits::max)(); + for (uint64_t cur_m = 0; cur_m < entrypints_.size(); cur_m++) { + float cur_expanded_dist = + dc.dist(ctx->query(), centroid_data_ + aligned_dim_ * cur_m); + + if (cur_expanded_dist < best_dist) { + best_medoid = entrypints_[cur_m]; + best_dist = cur_expanded_dist; + } + } + + float dist; + pq_table_->compute_dists(1, &best_medoid, pq_chunk_num_, + ctx->pq_table_dist_buffer(), ctx->pq_coord_buffer(), + &dist); + candidates.insert(Neighbor(best_medoid, dist)); + visit_filter.set_visited(best_medoid); + + uint32_t cmps = 0; + uint32_t hops = 0; + uint32_t num_ios = 0; + + std::vector frontier; + frontier.reserve(2 * beam_width_); + + std::vector> frontier_neighbors; + frontier_neighbors.reserve(2 * beam_width_); + + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width_); + + std::vector>> + cached_neighbors; + cached_neighbors.reserve(2 * beam_width_); + + while (candidates.has_unexpanded_node() && num_ios < io_limit_) { + frontier.clear(); + frontier_neighbors.clear(); + frontier_read_reqs.clear(); + cached_neighbors.clear(); + + uint64_t 
sector_buffer_idx = 0; + + uint32_t num_seen = 0; + while (candidates.has_unexpanded_node() && frontier.size() < beam_width_ && + num_seen < beam_width_) { + auto neighbor = candidates.closest_unexpanded(); + num_seen++; + + auto iter = neighbor_cache_.find(neighbor.id); + if (iter != neighbor_cache_.end()) { + cached_neighbors.push_back(std::make_pair(neighbor.id, iter->second)); + stats.cache_hits++; + } else { + frontier.push_back(neighbor.id); + } + } + + if (!frontier.empty()) { + stats.hop_num++; + + for (uint64_t i = 0; i < frontier.size(); i++) { + diskann_id_t cur_id = frontier[i]; + + std::pair frontier_neighbor; + frontier_neighbor.first = cur_id; + frontier_neighbor.second = sector_buffer + sector_num_per_node * + sector_buffer_idx * + DiskAnnUtil::kSectorSize; + frontier_neighbors.push_back(frontier_neighbor); + + sector_buffer_idx++; + + frontier_read_reqs.emplace_back( + index_segment_offset_ + + DiskAnnUtil::get_node_sector(node_per_sector_, max_node_size_, + DiskAnnUtil::kSectorSize, cur_id) * + DiskAnnUtil::kSectorSize, + sector_num_per_node * DiskAnnUtil::kSectorSize, + frontier_neighbor.second); + + stats.disk_page_reads++; + stats.io_num++; + num_ios++; + } + + io_timer.reset(); + + int read_ret = reader_->read(frontier_read_reqs, io_ctx); + stats.io_us += io_timer.micro_seconds(); + if (read_ret != 0) { + LOG_ERROR("cached_beam_search: reader_->read failed, ret=%d", read_ret); + ctx->set_error(true); + return IndexError_Runtime; + } + } + + for (auto &cached_neighbor : cached_neighbors) { + auto global_cache_iter = coord_cache_.find(cached_neighbor.first); + void *node_fp_coords_copy = global_cache_iter->second; + + float cur_expanded_dist = dc.dist(ctx->query(), node_fp_coords_copy); + + if (!ctx->filter().is_valid() || + !ctx->filter()(get_key(cached_neighbor.first))) { + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), node_fp_coords_copy, meta_.element_size()); + + 
topk_heap.emplace(cached_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + } + + uint32_t neighbor_num = cached_neighbor.second.first; + diskann_id_t *node_neighbors = cached_neighbor.second.second; + + cpu_timer.reset(); + + float distances[neighbor_num]; + pq_table_->compute_dists(neighbor_num, node_neighbors, pq_chunk_num_, + ctx->pq_table_dist_buffer(), + ctx->pq_coord_buffer(), distances); + + stats.dist_num += neighbor_num; + stats.cpu_us += cpu_timer.micro_seconds(); + + for (uint64_t m = 0; m < neighbor_num; ++m) { + diskann_id_t id = node_neighbors[m]; + visit_filter.set_visited(id); + cmps++; + + Neighbor nn(id, distances[m]); + candidates.insert(nn); + } + } + + for (auto &frontier_neighbor : frontier_neighbors) { + uint8_t *node_disk_buf = DiskAnnUtil::offset_to_node( + node_per_sector_, max_node_size_, frontier_neighbor.second, + frontier_neighbor.first); + uint32_t *node_buf = DiskAnnUtil::offset_to_node_neighbor( + node_disk_buf, meta_.element_size()); + uint32_t neighbor_num = *node_buf; + + void *node_fp_coords = node_disk_buf; + memcpy(data_buf, node_fp_coords, disk_bytes_per_point_); + + float cur_expanded_dist = dc.dist(ctx->query(), data_buf); + + if (!ctx->filter().is_valid() || + !ctx->filter()(get_key(frontier_neighbor.first))) { + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), data_buf, meta_.element_size()); + + topk_heap.emplace(frontier_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + } + + diskann_id_t *node_neighbors = + reinterpret_cast(node_buf + 1); + + cpu_timer.reset(); + float distances[neighbor_num]; + pq_table_->compute_dists(neighbor_num, node_neighbors, pq_chunk_num_, + ctx->pq_table_dist_buffer(), + ctx->pq_coord_buffer(), distances); + + stats.dist_num += neighbor_num; + stats.cpu_us += cpu_timer.micro_seconds(); + + cpu_timer.reset(); + for (uint64_t m = 0; m < neighbor_num; ++m) { + diskann_id_t id = node_neighbors[m]; + visit_filter.set_visited(id); 
+ cmps++; + stats.dist_num++; + + Neighbor nn(id, distances[m]); + candidates.insert(nn); + } + + stats.cpu_us += cpu_timer.micro_seconds(); + } + + hops++; + } + + stats.total_us += query_timer.micro_seconds(); + + return 0; +} + +int DiskAnnIndexer::cached_beam_search_in_mem(DiskAnnContext * /*ctx*/) { + return IndexError_NotImplemented; +} + +int DiskAnnIndexer::cached_beam_search_by_group(DiskAnnContext *ctx) { + if (!ctx->group_by().is_valid()) { + return 0; + } + + std::function group_by = [&](diskann_id_t id) { + return ctx->group_by()(get_key(id)); + }; + + // devide into groups + auto &topk_heap = ctx->topk_heap(); + auto &visit_filter = ctx->visit_filter(); + + std::map &group_topk_heaps = ctx->group_topk_heaps(); + + for (uint32_t i = 0; i < topk_heap.size(); ++i) { + diskann_id_t id = topk_heap[i].first; + auto info = topk_heap[i].second; + + std::string group_id = group_by(id); + + auto &group_topk_heap = group_topk_heaps[group_id]; + if (group_topk_heap.empty()) { + group_topk_heap.limit(ctx->group_topk()); + } + + topk_heap.emplace(id, info); + } + + // stage 2, expand to reach group num as possible + if (group_topk_heaps.size() < ctx->group_num()) { + NeighborPriorityQueue candidates; + + candidates.reserve(ctx->list_size()); + + for (uint32_t i = 0; i < topk_heap.size(); ++i) { + diskann_id_t id = topk_heap[i].first; + float score = topk_heap[i].second.dist_; + + visit_filter.set_visited(id); + candidates.insert(Neighbor(id, score)); + } + + ailego::ElapsedTime io_timer; + ailego::ElapsedTime query_timer; + ailego::ElapsedTime cpu_timer; + + auto &stats = ctx->query_stats(); + auto &dc = ctx->dist_calculator(); + + IOContext &io_ctx = ctx->io_ctx(); + + void *data_buf = reinterpret_cast(ctx->coord_buffer()); + uint8_t *sector_buffer = reinterpret_cast(ctx->sector_buffer()); + + const uint64_t sector_num_per_node = + node_per_sector_ > 0 ? 
1 + : DiskAnnUtil::div_round_up( + max_node_size_, DiskAnnUtil::kSectorSize); + + pq_table_->preprocess_pq_dist_table(ctx->query_rotated(), + ctx->pq_table_dist_buffer()); + + uint32_t cmps = 0; + uint32_t hops = 0; + uint32_t num_ios = 0; + + std::vector frontier; + frontier.reserve(2 * beam_width_); + std::vector> frontier_neighbors; + frontier_neighbors.reserve(2 * beam_width_); + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width_); + std::vector>> + cached_neighbors; + cached_neighbors.reserve(2 * beam_width_); + + uint64_t sector_buffer_idx; + + while (candidates.has_unexpanded_node() && num_ios < io_limit_) { + frontier.clear(); + frontier_neighbors.clear(); + frontier_read_reqs.clear(); + cached_neighbors.clear(); + sector_buffer_idx = 0; + + uint32_t num_seen = 0; + while (candidates.has_unexpanded_node() && + frontier.size() < beam_width_ && num_seen < beam_width_) { + auto neighbor = candidates.closest_unexpanded(); + num_seen++; + + auto iter = neighbor_cache_.find(neighbor.id); + if (iter != neighbor_cache_.end()) { + cached_neighbors.push_back(std::make_pair(neighbor.id, iter->second)); + stats.cache_hits++; + } else { + frontier.push_back(neighbor.id); + } + } + + if (!frontier.empty()) { + stats.hop_num++; + + for (uint64_t i = 0; i < frontier.size(); i++) { + diskann_id_t cur_id = frontier[i]; + + std::pair frontier_neighbor; + frontier_neighbor.first = cur_id; + frontier_neighbor.second = + sector_buffer + sector_num_per_node * sector_buffer_idx * + DiskAnnUtil::kSectorSize; + frontier_neighbors.push_back(frontier_neighbor); + + sector_buffer_idx++; + + frontier_read_reqs.emplace_back( + index_segment_offset_ + DiskAnnUtil::get_node_sector( + node_per_sector_, max_node_size_, + DiskAnnUtil::kSectorSize, cur_id) * + DiskAnnUtil::kSectorSize, + sector_num_per_node * DiskAnnUtil::kSectorSize, + frontier_neighbor.second); + + stats.disk_page_reads++; + stats.io_num++; + num_ios++; + } + + io_timer.reset(); + + 
reader_->read(frontier_read_reqs, io_ctx); // synchronous IO linux + stats.io_us += io_timer.micro_seconds(); + } + + for (auto &cached_neighbor : cached_neighbors) { + auto global_cache_iter = coord_cache_.find(cached_neighbor.first); + void *node_fp_coords_copy = global_cache_iter->second; + + float cur_expanded_dist = dc.dist(ctx->query(), node_fp_coords_copy); + + if (!ctx->filter().is_valid() || + !ctx->filter()(get_key(cached_neighbor.first))) { + std::string group_id = group_by(cached_neighbor.first); + + auto &group_topk_heap = group_topk_heaps[group_id]; + if (group_topk_heap.empty()) { + group_topk_heap.limit(ctx->group_topk()); + } + + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), node_fp_coords_copy, meta_.element_size()); + + group_topk_heap.emplace_back( + cached_neighbor.first, VectorInfo(cur_expanded_dist, vec_value)); + + if (group_topk_heaps.size() >= ctx->group_num()) { + break; + } + } + + uint64_t neighbor_num = cached_neighbor.second.first; + diskann_id_t *node_neighbors = cached_neighbor.second.second; + + cpu_timer.reset(); + + float distances[neighbor_num]; + pq_table_->compute_dists(neighbor_num, node_neighbors, pq_chunk_num_, + ctx->pq_table_dist_buffer(), + ctx->pq_coord_buffer(), distances); + + stats.dist_num += neighbor_num; + stats.cpu_us += cpu_timer.micro_seconds(); + + for (uint64_t m = 0; m < neighbor_num; ++m) { + diskann_id_t id = node_neighbors[m]; + visit_filter.set_visited(id); + cmps++; + + Neighbor nn(id, distances[m]); + candidates.insert(nn); + } + } + + for (auto &frontier_neighbor : frontier_neighbors) { + uint8_t *node_disk_buf = DiskAnnUtil::offset_to_node( + node_per_sector_, max_node_size_, frontier_neighbor.second, + frontier_neighbor.first); + uint32_t *node_buf = DiskAnnUtil::offset_to_node_neighbor( + node_disk_buf, meta_.element_size()); + uint32_t neighbor_num = *node_buf; + + void *node_fp_coords = node_disk_buf; + memcpy(data_buf, node_fp_coords, 
disk_bytes_per_point_); + + float cur_expanded_dist = dc.dist(ctx->query(), data_buf); + + if (!ctx->filter().is_valid() || + !ctx->filter()(get_key(frontier_neighbor.first))) { + std::string group_id = group_by(frontier_neighbor.first); + + auto &group_topk_heap = group_topk_heaps[group_id]; + if (group_topk_heap.empty()) { + group_topk_heap.limit(ctx->group_topk()); + } + + std::string vec_value; + vec_value.resize(meta_.element_size()); + ::memcpy(&(vec_value[0]), data_buf, meta_.element_size()); + + group_topk_heap.emplace_back( + frontier_neighbor.first, + VectorInfo(cur_expanded_dist, vec_value)); + + if (group_topk_heaps.size() >= ctx->group_num()) { + break; + } + } + + cpu_timer.reset(); + + float distances[neighbor_num]; + diskann_id_t *node_neighbors = + reinterpret_cast(node_buf + 1); + pq_table_->compute_dists(neighbor_num, node_neighbors, pq_chunk_num_, + ctx->pq_table_dist_buffer(), + ctx->pq_coord_buffer(), distances); + + stats.dist_num += neighbor_num; + stats.cpu_us += cpu_timer.micro_seconds(); + + cpu_timer.reset(); + for (uint64_t m = 0; m < neighbor_num; ++m) { + diskann_id_t id = node_neighbors[m]; + visit_filter.set_visited(id); + cmps++; + stats.dist_num++; + + Neighbor nn(id, distances[m]); + candidates.insert(nn); + } + + stats.cpu_us += cpu_timer.micro_seconds(); + } + + hops++; + } + + stats.total_us += query_timer.micro_seconds(); + } + + return 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_indexer.h b/src/core/algorithm/diskann/diskann_indexer.h new file mode 100644 index 000000000..6fe684970 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_indexer.h @@ -0,0 +1,104 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "diskann_context.h" +#include "diskann_file_reader.h" +#include "diskann_pq_table.h" +#include "diskann_searcher_entity.h" +#include "diskann_util.h" + +namespace zvec { +namespace core { + +class DiskAnnIndexer { + public: + typedef std::shared_ptr Pointer; + + public: + DiskAnnIndexer(const IndexMeta &meta); + ~DiskAnnIndexer(); + + public: + int init(DiskAnnSearcherEntity &entity); + int load_cache_list(const std::vector &node_list); + + void cache_bfs_levels(uint64_t num_nodes_to_cache, + std::vector &node_list); + + int cached_beam_search(DiskAnnContext *ctx); + int cached_beam_search_by_group(DiskAnnContext *ctx); + + int cached_beam_search_in_mem(DiskAnnContext *ctx); + + int knn_search(DiskAnnContext *ctx); + int linear_search(DiskAnnContext *ctx); + int keys_search(const std::vector &keys, DiskAnnContext *ctx); + + int get_vector(diskann_id_t id, IndexContext::Pointer &context, + std::string &vector); + + diskann_key_t get_key(diskann_id_t id) const; + diskann_id_t get_id(diskann_key_t key) const; + + std::vector read_nodes( + const std::vector &node_ids, + std::vector &coord_buffers, + std::vector> &nbr_buffers); + + protected: + int use_medroids_data_as_centroids(); + + private: + DiskAnnSearcherEntity *entity_; + + IndexStorage::Pointer storage_{}; + IndexMeta meta_; + + uint32_t max_degree_{0}; + uint32_t node_per_sector_{0}; + uint32_t max_node_size_{0}; + uint64_t pq_chunk_num_{0}; + uint64_t disk_bytes_per_point_{0}; + uint64_t aligned_dim_{0}; + uint64_t index_segment_offset_{0}; + 
uint64_t sector_num_per_node_{0}; + + float *centroid_data_{nullptr}; + + diskann_id_t medoid_; + std::vector entrypints_; + + std::shared_ptr reader_{nullptr}; + + PQTable::Pointer pq_table_; + + IOContext init_ctx_{0}; + + std::vector neighbor_cache_buffer_; + void *coord_cache_buf_{nullptr}; + + std::map coord_cache_; + std::map> neighbor_cache_; + + uint32_t beam_width_{2}; + uint32_t io_limit_{std::numeric_limits::max()}; + + uint64_t doc_cnt_{0}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_params.h b/src/core/algorithm/diskann/diskann_params.h new file mode 100644 index 000000000..1f308626d --- /dev/null +++ b/src/core/algorithm/diskann/diskann_params.h @@ -0,0 +1,53 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include + +namespace zvec { +namespace core { + +static const std::string PARAM_DISKANN_BUILDER_MAX_DEGREE( + "proxima.diskann.builder.max_degree"); +static const std::string PARAM_DISKANN_BUILDER_LIST_SIZE( + "proxima.diskann.builder.list_size"); +static const std::string PARAM_DISKANN_BUILDER_MEMORY_LIMIT( + "proxima.diskann.builder.memory_limit"); +static const std::string PARAM_DISKANN_BUILDER_MEMORY_BUDGET( + "proxima.diskann.builder.memory_budget"); +static const std::string PARAM_DISKANN_BUILDER_DISK_PQ_DIM( + "proxima.diskann.builder.disk_pq_dim"); +static const std::string PARAM_DISKANN_BUILDER_THREAD_COUNT( + "proxima.diskann.builder.thread_count"); +static const std::string PARAM_DISKANN_BUILDER_MAX_TRAIN_SAMPLE_COUNT( + "proxima.diskann.builder.max_train_sample_count"); +static const std::string PARAM_DISKANN_BUILDER_TRAIN_SAMPLE_RATIO( + "proxima.diskann.builder.train_sample_ratio"); +static const std::string PARAM_DISKANN_BUILDER_MAX_PQ_CHUNK_NUM( + "proxima.diskann.builder.max_pq_chunk_num"); + +static const std::string PARAM_DISKANN_SEARCHER_LIST_SIZE( + "proxima.diskann.searcher.list_size"); +static const std::string PARAM_DISKANN_SEARCHER_CACHE_NODE_NUM( + "proxima.diskann.searcher.cache_node_num"); + +static const std::string PARAM_DISKANN_REDUCER_INDEX_NAME( + "proxima.diskann.reducer.index_name"); +static const std::string PARAM_DISKANN_REDUCER_WORKING_PATH( + "proxima.diskann.reducer.working_path"); +static const std::string PARAM_DISKANN_REDUCER_NUM_OF_ADD_THREADS( + "proxima.diskann.reducer.num_of_add_threads"); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_pq_table.cc b/src/core/algorithm/diskann/diskann_pq_table.cc new file mode 100644 index 000000000..0c13e4061 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_pq_table.cc @@ -0,0 +1,117 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not 
use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_pq_table.h" +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +PQTable::PQTable(const IndexMeta &meta, uint32_t chunk_num) + : chunk_num_(chunk_num) { + meta_ = meta; + + if (meta.metric_name() == "Cosine") { + if (meta.data_type() == IndexMeta::DataType::DT_FP32) { + meta_.set_dimension(meta.dimension() - 1); + } else { + meta_.set_dimension(meta.dimension() - 2); + } + } +} + +PQTable::~PQTable() {} + +int PQTable::init(std::vector &full_pivot_data, + std::vector ¢roid, + std::vector &chunk_offsets, + std::vector &pq_data) { + full_pivot_data_ = std::move(full_pivot_data); + centroid_ = std::move(centroid); + chunk_offsets_ = std::move(chunk_offsets); + pq_data_ = std::move(pq_data); + + // alloc and compute transpose + transposed_tables_.resize(kPQCentroidNum * meta_.element_size()); + + uint32_t dim = meta_.dimension(); + uint32_t type = meta_.data_type(); + + switch (type) { + case IndexMeta::DataType::DT_FP32: { + float *transposed_tables_ptr = + reinterpret_cast(&transposed_tables_[0]); + float *full_pivot_data_ptr = + reinterpret_cast(&full_pivot_data_[0]); + for (size_t i = 0; i < kPQCentroidNum; i++) { + for (size_t j = 0; j < dim; j++) { + transposed_tables_ptr[j * kPQCentroidNum + i] = + full_pivot_data_ptr[i * dim + j]; + } + } + break; + } + case IndexMeta::DataType::DT_FP16: { + ailego::Float16 *transposed_tables_ptr = + reinterpret_cast(&transposed_tables_[0]); + ailego::Float16 *full_pivot_data_ptr = + 
reinterpret_cast(&full_pivot_data_[0]); + for (size_t i = 0; i < kPQCentroidNum; i++) { + for (size_t j = 0; j < dim; j++) { + transposed_tables_ptr[j * kPQCentroidNum + i] = + full_pivot_data_ptr[i * dim + j]; + } + } + break; + } + default: + LOG_ERROR("unsupported type, type: %u", type); + return IndexError_Unsupported; + } + + return 0; +} + +void PQTable::aggregate_coords(uint32_t id_num, const diskann_id_t *ids, + const uint8_t *all_coords, size_t dim, + uint8_t *out) { + for (size_t i = 0; i < id_num; i++) { + memcpy(out + i * dim, all_coords + ids[i] * dim, dim * sizeof(uint8_t)); + } +} + +void PQTable::pq_dist_lookup(const uint8_t *pq_ids, size_t id_num, + size_t pq_nchunks, const float *pq_dist_buffer, + float *dists_out) { + ailego_prefetch(dists_out); + ailego_prefetch(pq_ids); + ailego_prefetch(pq_ids + 64); + ailego_prefetch(pq_ids + 128); + + memset(dists_out, 0, id_num * sizeof(float)); + + for (size_t chunk = 0; chunk < pq_nchunks; chunk++) { + const float *chunk_dists = pq_dist_buffer + kPQCentroidNum * chunk; + if (chunk < pq_nchunks - 1) { + ailego_prefetch(chunk_dists + kPQCentroidNum); + } + for (size_t idx = 0; idx < id_num; idx++) { + uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; + dists_out[idx] += chunk_dists[pq_centerid]; + } + } +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_pq_table.h b/src/core/algorithm/diskann/diskann_pq_table.h new file mode 100644 index 000000000..a998c2988 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_pq_table.h @@ -0,0 +1,154 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +class PQTable { + public: + typedef std::shared_ptr Pointer; + + public: + static constexpr uint32_t kPQBitNum = 8; + static constexpr uint32_t kPQCentroidNum = 1 << kPQBitNum; + static constexpr uint32_t kMaxTrainSampleCount = 200000; + static constexpr double kTrainSampleRatio = 1.0; + static constexpr uint32_t kMeanIterNum = 12; + + public: + PQTable(const IndexMeta &meta, uint32_t chunk_num); + virtual ~PQTable(); + + int init(std::vector &table, std::vector ¢roid, + std::vector &chunk_offsets, std::vector &pq_data); + + + template + void compute_distance_table_ip(const T *query_vec, float *dist_vec) { + memset(dist_vec, 0, kPQCentroidNum * chunk_num_ * sizeof(float)); + + const T *transposed_tables_ptr = + reinterpret_cast(transposed_tables_.data()); + // chunk wise distance computation + for (size_t chunk = 0; chunk < chunk_num_; chunk++) { + // sum (q-c)^2 for the dimensions associated with this chunk + float *chunk_dists = dist_vec + (kPQCentroidNum * chunk); + + for (size_t j = chunk_offsets_[chunk]; j < chunk_offsets_[chunk + 1]; + j++) { + const T *centers_dim_vec = &transposed_tables_ptr[kPQCentroidNum * j]; + + for (size_t idx = 0; idx < kPQCentroidNum; idx++) { + float centor_data = centers_dim_vec[idx]; + float query_data = query_vec[j]; + float dim_score = centor_data * query_data; + chunk_dists[idx] += -dim_score; + } + } + } + } + + template + void compute_distance_table(const T *query_vec, float *dist_vec) { + memset(dist_vec, 0, 
kPQCentroidNum * chunk_num_ * sizeof(float)); + + const T *transposed_tables_ptr = + reinterpret_cast(transposed_tables_.data()); + // chunk wise distance computation + for (size_t chunk = 0; chunk < chunk_num_; chunk++) { + // sum (q-c)^2 for the dimensions associated with this chunk + float *chunk_dists = dist_vec + (kPQCentroidNum * chunk); + + for (size_t j = chunk_offsets_[chunk]; j < chunk_offsets_[chunk + 1]; + j++) { + const T *centers_dim_vec = &transposed_tables_ptr[kPQCentroidNum * j]; + + for (size_t idx = 0; idx < kPQCentroidNum; idx++) { + float diff = centers_dim_vec[idx] - query_vec[j]; + chunk_dists[idx] += (diff * diff); + } + } + } + } + + template + void preprocess_query(T *query) { + const T *centroid_ptr = reinterpret_cast(centroid_.data()); + for (size_t i = 0; i < meta_.dimension(); i++) { + query[i] -= centroid_ptr[i]; + } + } + + void aggregate_coords(uint32_t id_num, const diskann_id_t *ids, + const uint8_t *all_coords, size_t dim, uint8_t *out); + + void pq_dist_lookup(const uint8_t *pq_ids, size_t id_num, size_t pq_nchunks, + const float *pq_dist_buffer, float *dists); + + void compute_dists(uint32_t id_num, const diskann_id_t *ids, + uint32_t chunk_num, float *pq_dist_buffer, + void *coord_buffer, float *dists) { + uint8_t *coord_buffer_ptr = reinterpret_cast(coord_buffer); + + aggregate_coords(id_num, ids, this->pq_data(), chunk_num, coord_buffer_ptr); + + pq_dist_lookup(coord_buffer_ptr, id_num, chunk_num, pq_dist_buffer, dists); + + return; + } + + int preprocess_pq_dist_table(void *query_rotated, float *dist_buffer) { + switch (meta_.data_type()) { + case IndexMeta::DataType::DT_FP32: + preprocess_query(reinterpret_cast(query_rotated)); + compute_distance_table(reinterpret_cast(query_rotated), + dist_buffer); + + break; + case IndexMeta::DataType::DT_FP16: + preprocess_query(reinterpret_cast(query_rotated)); + compute_distance_table( + reinterpret_cast(query_rotated), dist_buffer); + break; + default: + LOG_ERROR("Unsupported Type: 
%u", meta_.data_type()); + return IndexError_Unsupported; + } + + return 0; + } + + public: + const uint8_t *pq_data() const { + return pq_data_.data(); + } + + private: + std::vector full_pivot_data_; + std::vector transposed_tables_; + + std::vector centroid_; + std::vector chunk_offsets_; + std::vector pq_data_; + + IndexMeta meta_; + uint64_t chunk_num_{0}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_pq_trainer.cc b/src/core/algorithm/diskann/diskann_pq_trainer.cc new file mode 100644 index 000000000..81b0ffe10 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_pq_trainer.cc @@ -0,0 +1,419 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_pq_trainer.h" +#include "diskann_entity.h" +#include "diskann_util.h" + +namespace zvec { +namespace core { + +DiskAnnPqTrainer::DiskAnnPqTrainer(uint32_t max_train_sample_count) + : max_train_sample_count_{max_train_sample_count} {} + +DiskAnnPqTrainer::~DiskAnnPqTrainer() {} + +int DiskAnnPqTrainer::gen_random_sample(IndexHolder::Pointer holder, + const IndexMeta &meta, + std::string &sample_data, + size_t &sample_size) { + double train_sample_ratio = + max_train_sample_count_ < 1 ? max_train_sample_count_ : 1; + + uint32_t max_train_sample_count = train_sample_ratio * holder->count(); + max_train_sample_count = max_train_sample_count > max_train_sample_count_ + ? 
max_train_sample_count_ + : max_train_sample_count; + + std::vector> sample_vecs; + + // std::random_device rd; + // uint32_t x = rd(); + uint32_t x = 456321; + std::mt19937 gen(x); + std::uniform_real_distribution dist(0, 1); + + uint32_t vec_size = meta.element_size(); + + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator for holder failed"); + return IndexError_Runtime; + } + + size_t sample_count = 0; + while (iter->is_valid() && sample_count < max_train_sample_count) { + float random = dist(gen); + + if (random < train_sample_ratio) { + const void *vec = iter->data(); + + std::vector temp_vec; + temp_vec.resize(vec_size); + + std::memcpy(reinterpret_cast(&temp_vec[0]), vec, vec_size); + + sample_vecs.push_back(std::move(temp_vec)); + + sample_count++; + } + + iter->next(); + } + + sample_size = sample_vecs.size(); + sample_data.reserve(sample_size * vec_size); + + for (size_t i = 0; i < sample_size; i++) { + sample_data.append(reinterpret_cast(sample_vecs[i].data()), + vec_size); + } + + return 0; +} + +template +int DiskAnnPqTrainer::prepare_pq_train_data( + const IndexMeta &meta, size_t num_train, std::string &train_data, + bool use_zero_mean, std::vector ¢roid, + std::shared_ptr &train_features) { + uint32_t dim = meta.dimension(); + uint32_t vec_size = meta.element_size(); + + std::string train_data_processed; + train_data_processed.resize(num_train * vec_size); + + std::memcpy(&(train_data_processed[0]), train_data.data(), + num_train * vec_size); + + // use fp32 to accumulate to avoid overflow + std::vector centroid_temp(dim); + for (uint64_t d = 0; d < dim; d++) { + centroid_temp[d] = 0; + } + + T *train_data_processed_ptr = reinterpret_cast(&train_data_processed[0]); + + if (use_zero_mean) { + for (uint64_t d = 0; d < dim; d++) { + for (uint64_t p = 0; p < num_train; p++) { + centroid_temp[d] += train_data_processed_ptr[p * dim + d]; + } + centroid_temp[d] /= num_train; + } + + for (uint64_t d = 0; d < dim; d++) { + for 
(uint64_t p = 0; p < num_train; p++) { + train_data_processed_ptr[p * dim + d] -= centroid_temp[d]; + } + } + } + + for (size_t i = 0; i < num_train; ++i) { + train_features->emplace(train_data_processed_ptr + i * dim); + } + + // copy the centroid out + centroid.resize(vec_size); + T *centroid_ptr = reinterpret_cast(centroid.data()); + for (uint64_t d = 0; d < dim; d++) { + centroid_ptr[d] = centroid_temp[d]; + } + + return 0; +} + +template +int DiskAnnPqTrainer::convert_pivot_data( + const IndexMeta &meta, uint32_t num_centers, uint32_t pq_chunk_num, + const std::vector &chunk_dims, + const std::vector &chunk_offsets, + IndexCluster::CentroidList ¢roids, + std::vector &full_pivot_data) { + uint32_t dim = meta.dimension(); + uint32_t element_size = meta.element_size(); + + full_pivot_data.resize(num_centers * element_size); + + for (size_t chunk = 0; chunk < pq_chunk_num; ++chunk) { + for (size_t cluster = 0; cluster < num_centers; ++cluster) { + size_t idx = chunk * num_centers + cluster; + + T *pivot_data_ptr = reinterpret_cast(&(full_pivot_data[0])) + + cluster * dim + chunk_offsets[chunk]; + const T *feature_ptr = + reinterpret_cast(centroids[idx].feature()); + for (size_t d = 0; d <= chunk_dims[chunk]; ++d) { + pivot_data_ptr[d] = feature_ptr[d]; + } + } + } + + return 0; +} + +int DiskAnnPqTrainer::train_pq(IndexThreads::Pointer threads, + const IndexMeta &meta, std::string &train_data, + size_t num_train, uint32_t num_centers, + uint32_t pq_chunk_num, uint32_t max_iterations, + bool use_zero_mean, + std::vector &full_pivot_data, + std::vector ¢roid, + std::vector &chunk_offsets) { + uint32_t dim = meta.dimension(); + if (pq_chunk_num > dim) { + LOG_ERROR("Error: number of chunks more than dimension. 
chunk: %u, dim: %u", + pq_chunk_num, dim); + return IndexError_InvalidArgument; + } + + std::shared_ptr train_features( + new CompactIndexFeatures(meta)); + + uint32_t type = meta.data_type(); + + int ret; + switch (type) { + case IndexMeta::DataType::DT_FP32: + ret = prepare_pq_train_data( + meta, num_train, train_data, use_zero_mean, centroid, train_features); + if (ret != 0) { + LOG_ERROR("Failed to prepare pq train data"); + return ret; + } + break; + + case IndexMeta::DataType::DT_FP16: + ret = prepare_pq_train_data( + meta, num_train, train_data, use_zero_mean, centroid, train_features); + if (ret != 0) { + LOG_ERROR("Failed to prepare pq train data"); + return ret; + } + break; + } + + // Do Train + ailego::Params params; + params.set("proxima.cluster.multi_chunk_cluster.count", num_centers); + params.set("proxima.cluster.multi_chunk_cluster.chunk_count", pq_chunk_num); + params.set("proxima.cluster.multi_chunk_cluster.max_iterations", + max_iterations); + + ret = chunk_cluster_.init(meta, params); + if (ret != 0) { + LOG_ERROR("Failed to get chunk cluster"); + return IndexError_InvalidArgument; + } + + ret = chunk_cluster_.mount(train_features); + if (ret != 0) { + LOG_ERROR("Cannot mount train features"); + return ret; + } + + + std::vector labels; + + ret = chunk_cluster_.cluster(threads, cluster_centroids_); + if (ret != 0) { + LOG_ERROR("Failed to cluster"); + return ret; + } + + chunk_offsets = chunk_cluster_.chunk_dim_offsets(); + auto chunk_dims = chunk_cluster_.chunk_dims(); + + switch (type) { + case IndexMeta::DataType::DT_FP32: + ret = convert_pivot_data(meta, num_centers, pq_chunk_num, + chunk_dims, chunk_offsets, + cluster_centroids_, full_pivot_data); + if (ret != 0) { + LOG_ERROR("Failed to convert pivot data"); + return ret; + } + break; + + case IndexMeta::DataType::DT_FP16: + ret = convert_pivot_data( + meta, num_centers, pq_chunk_num, chunk_dims, chunk_offsets, + cluster_centroids_, full_pivot_data); + if (ret != 0) { + LOG_ERROR("Failed 
to convert pivot data"); + return ret; + } + break; + } + + return 0; +} + +int DiskAnnPqTrainer::train_quantized_data( + IndexThreads::Pointer threads, IndexHolder::Pointer holder, + const IndexMeta &meta, std::vector &pq_full_pivot_data, + std::vector &pq_centroid, std::vector &pq_chunk_offsets, + size_t pq_chunk_num) { + size_t train_size; + std::string train_data; + + int ret = gen_random_sample(holder, meta, train_data, train_size); + if (ret != 0) { + LOG_ERROR("Get Random Sample Error, ret: %d", ret); + return ret; + } + + LOG_INFO("Training data with %zu samples loaded.", train_size); + + // bool use_zero_mean = (meta.metric_name() != "InnerProduct" ? true : + // false); + bool use_zero_mean = false; + + ret = train_pq(threads, meta, train_data, train_size, PQTable::kPQCentroidNum, + pq_chunk_num, PQTable::kMeanIterNum, use_zero_mean, + pq_full_pivot_data, pq_centroid, pq_chunk_offsets); + if (ret != 0) { + LOG_ERROR("Train PQ Error, ret: %d", ret); + return ret; + } + + return 0; +} + +int DiskAnnPqTrainer::generate_pq(IndexThreads::Pointer threads, + const IndexMeta &meta, + IndexHolder::Pointer holder, + uint32_t pq_chunk_num, + std::vector ¢roid, + std::vector &block_compressed_data) { + uint32_t type = meta.data_type(); + uint32_t dim = meta.dimension(); + + if (pq_chunk_num > dim) { + LOG_ERROR("Error: number of chunks more than dimension. chunk: %u, dim: %u", + pq_chunk_num, dim); + return IndexError_InvalidArgument; + } + + // Do Label + std::vector labels; + size_t num_vecs = holder->count(); + size_t batch_size = + num_vecs <= compress_batch_size_ ? 
num_vecs : compress_batch_size_; + + std::vector block_compressed_base(batch_size * pq_chunk_num); + + std::memset(&block_compressed_base[0], 0, + batch_size * pq_chunk_num * sizeof(uint32_t)); + + std::vector block_data(batch_size * meta.element_size()); + std::vector block_data_converted(batch_size * meta.element_size()); + + size_t block_num = DiskAnnUtil::div_round_up(num_vecs, batch_size); + + block_compressed_data.resize(num_vecs * pq_chunk_num); + + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator for holder failed"); + return IndexError_Runtime; + } + + for (size_t block = 0; block < block_num; block++) { + size_t start_id = block * batch_size; + size_t end_id = std::min((block + 1) * batch_size, num_vecs); + + size_t cur_block_size = end_id - start_id; + + for (size_t i = 0; i < cur_block_size && iter->is_valid(); i++) { + const void *vec = iter->data(); + std::memcpy( + reinterpret_cast(&block_data[0]) + i * meta.element_size(), + vec, meta.element_size()); + iter->next(); + } + + std::memcpy(block_data_converted.data(), block_data.data(), + cur_block_size * meta.element_size()); + + LOG_INFO("Processing Docs, Range: [%zu, %zu)..", start_id, end_id); + + std::shared_ptr block_features( + new CompactIndexFeatures(meta)); + + switch (type) { + case IndexMeta::DataType::DT_FP32: + DiskAnnUtil::convert_vector_to_residual( + reinterpret_cast(block_data_converted.data()), + cur_block_size, dim, centroid.data()); + break; + case IndexMeta::DataType::DT_FP16: + DiskAnnUtil::convert_vector_to_residual( + reinterpret_cast(block_data_converted.data()), + cur_block_size, dim, centroid.data()); + break; + default: + return IndexError_InvalidArgument; + } + + for (size_t i = 0; i < cur_block_size; i++) { + block_features->emplace(block_data_converted.data() + + i * meta.element_size()); + } + + int ret = chunk_cluster_.mount(block_features); + if (ret != 0) { + LOG_ERROR("Cannot mount block features"); + return ret; + } + + ret = 
chunk_cluster_.label(threads, cluster_centroids_, &labels); + if (ret != 0) { + LOG_ERROR("Failed to label"); + return ret; + } + + std::vector compressed_data(cur_block_size * pq_chunk_num); + + DiskAnnUtil::convert_types_uint32_to_uint8( + labels.data(), compressed_data.data(), cur_block_size, pq_chunk_num); + + memcpy(&(block_compressed_data[0]) + start_id * pq_chunk_num, + compressed_data.data(), cur_block_size * pq_chunk_num); + + LOG_INFO("Generate PQ Data Done."); + } + + return 0; +} + +int DiskAnnPqTrainer::generate_quantized_data( + IndexThreads::Pointer threads, IndexHolder::Pointer holder, + const IndexMeta &meta, std::vector &pq_centroid, + std::vector &block_compressed_data, size_t pq_chunk_num) { + // bool use_zero_mean = (meta.metric_name() != "InnerProduct" ? true : + // false); + + int ret = generate_pq(threads, meta, holder, pq_chunk_num, pq_centroid, + block_compressed_data); + if (ret != 0) { + LOG_ERROR("Generate PQ Error, ret: %d", ret); + return ret; + } + + return 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_pq_trainer.h b/src/core/algorithm/diskann/diskann_pq_trainer.h new file mode 100644 index 000000000..5822d8f9a --- /dev/null +++ b/src/core/algorithm/diskann/diskann_pq_trainer.h @@ -0,0 +1,89 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include "diskann_entity.h" +#include "diskann_pq_table.h" +#include "../cluster/multi_chunk_cluster.h" + +namespace zvec { +namespace core { + +class DiskAnnPqTrainer { + public: + typedef std::unique_ptr UPointer; + + public: + DiskAnnPqTrainer(uint32_t max_train_sample_count); + virtual ~DiskAnnPqTrainer(); + + public: + template + static int prepare_pq_train_data( + const IndexMeta &meta, size_t num_train, std::string &train_data, + bool use_zero_mean, std::vector ¢roid, + std::shared_ptr &train_features); + + template + static int convert_pivot_data(const IndexMeta &meta, uint32_t num_centers, + uint32_t pq_chunk_num, + const std::vector &chunk_dims, + const std::vector &chunk_offsets, + IndexCluster::CentroidList ¢roids, + std::vector &full_pivot_data); + + int gen_random_sample(IndexHolder::Pointer holder, const IndexMeta &meta, + std::string &sample_data, size_t &sample_size); + + int generate_quantized_data(IndexThreads::Pointer threads, + IndexHolder::Pointer holder, + const IndexMeta &meta, + // std::vector &pq_full_pivot_data, + std::vector &pq_centroid, + // std::vector &pq_chunk_offsets, + std::vector &block_compressed_data, + size_t num_pq_chunks); + + int generate_pq(IndexThreads::Pointer threads, const IndexMeta &meta, + IndexHolder::Pointer holder, uint32_t num_pq_chunks, + std::vector ¢roid, + std::vector &block_compressed_data); + + int train_quantized_data(IndexThreads::Pointer threads, + IndexHolder::Pointer holder, const IndexMeta &meta, + std::vector &pq_full_pivot_data, + std::vector &pq_centroid, + std::vector &pq_chunk_offsets, + size_t num_pq_chunks); + + int train_pq(IndexThreads::Pointer threads, const IndexMeta &meta, + std::string &train_data, size_t num_train, uint32_t num_centers, + uint32_t num_pq_chunks, uint32_t max_iterations, + bool use_zero_mean, std::vector &full_pivot_data, + std::vector ¢roid, + std::vector &chunk_offsets); + + private: + static constexpr uint32_t compress_batch_size_{ + 
DiskAnnEntity::kDefaultCompressBatchSize}; + + // pq cluster + MultiChunkCluster chunk_cluster_; + IndexCluster::CentroidList cluster_centroids_; + uint32_t max_train_sample_count_{PQTable::kMaxTrainSampleCount}; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_reducer.cc b/src/core/algorithm/diskann/diskann_reducer.cc new file mode 100644 index 000000000..05c9455f7 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_reducer.cc @@ -0,0 +1,195 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_reducer.h" +#include +#include +#include +#include +#include +#include "diskann_params.h" + +namespace zvec { +namespace core { + +int DiskAnnReducer::init(const ailego::Params ¶ms) { + params.get(PARAM_DISKANN_REDUCER_WORKING_PATH, &working_path_); + if (working_path_.empty()) { + LOG_ERROR("Missing parameter. 
%s", + PARAM_DISKANN_REDUCER_WORKING_PATH.c_str()); + return IndexError_InvalidArgument; + } + + std::string index_name = + params.get_as_string(PARAM_DISKANN_REDUCER_INDEX_NAME); + if (index_name.empty()) { + index_name = std::to_string(std::clock()); + } + + reducer_file_path_ = ailego::StringHelper::Concat( + working_path_, "/", kReducerFileName, index_name); + + holder_file_path_ = ailego::StringHelper::Concat(working_path_, "/", + kHolderFileName, index_name); + + state_ = STATE_INITED; + return 0; +} + +int DiskAnnReducer::cleanup(void) { + return 0; +} + +//! Reduce operator with filter +int DiskAnnReducer::reduce(const IndexFilter &filter) { + if (entities_.empty() || state_ != STATE_FEED) { + LOG_ERROR("No container to reduce, feed first"); + return IndexError_NoReady; + } + + // size_t total_cnt = 0; + // for (auto entity_ : entities_) { + // total_cnt += entity_->doc_cnt(); + // } + + if (use_mem_holder_) { + mem_holder_ = std::make_shared(meta_); + for (auto entity : entities_) { + size_t doc_cnt = entity->doc_cnt(); + for (size_t id = 0; id < doc_cnt; ++id) { + diskann_key_t pkey = entity->get_key(id); + + if (filter.is_valid() && filter(pkey)) { + continue; + } + + const void *vec = entity->get_vector(id); + mem_holder_->emplace(pkey, vec); + } + } + } else { + disk_holder_ = + std::make_shared(meta_, holder_file_path_); + + int ret = disk_holder_->init(); + if (ret != 0) { + LOG_ERROR("DiskAnn Index Holder init failed"); + return ret; + } + + for (auto entity : entities_) { + size_t doc_cnt = entity->doc_cnt(); + for (size_t id = 0; id < doc_cnt; ++id) { + diskann_key_t pkey = entity->get_key(id); + + if (filter.is_valid() && filter(pkey)) { + continue; + } + + const void *vec = entity->get_vector(id); + disk_holder_->emplace(pkey, vec); + } + } + + disk_holder_->close(); + } + + builder_ = IndexFactory::CreateBuilder(kDiskAnnBuilderName); + if (!builder_) { + LOG_ERROR("Create builder failed. 
name[%s]", kDiskAnnBuilderName.c_str()); + return IndexError_Runtime; + } + + if (thread_pool_ == nullptr) { + LOG_ERROR( + "Only support multi-thread mode. Thread pool is not set for reducer."); + return IndexError_NoReady; + } + + LOG_INFO("Start diskann reduce"); + + ailego::ElapsedTime timer; + + auto params = meta_.builder_params(); + + int ret = builder_->init(meta_, params); + if (ret != 0) { + LOG_ERROR("Init proxima streamer failed. ret[%d]", ret); + return ret; + } + + if (use_mem_holder_) { + ret = builder_->train(mem_holder_); + if (ret != 0) { + LOG_ERROR("Diskann builder failed to train. ret[%d]", ret); + return ret; + } + + ret = builder_->build(mem_holder_); + if (ret != 0) { + LOG_ERROR("Diskann builder failed to build. ret[%d]", ret); + return ret; + } + } else { + ret = builder_->train(disk_holder_); + if (ret != 0) { + LOG_ERROR("Diskann builder failed to train. ret[%d]", ret); + return ret; + } + + ret = builder_->build(disk_holder_); + if (ret != 0) { + LOG_ERROR("Diskann builder failed to build. ret[%d]", ret); + return ret; + } + } + + auto &stats = builder_->stats(); + + stats_.set_reduced_costtime(timer.seconds()); + stats_.set_filtered_count(stats.discarded_count()); + + state_ = STATE_REDUCE; + + LOG_INFO("End DiskAnn reduce. cost time: [%zu]s", (size_t)timer.seconds()); + return 0; +} + +//! Dump index by dumper +int DiskAnnReducer::dump(const IndexDumper::Pointer &dumper) { + LOG_INFO("Begin diskann reducer dump"); + + if (state_ != STATE_REDUCE) { + LOG_WARN("Reduce first before dump."); + return IndexError_NoReady; + } + + ailego::ElapsedTime timer; + int ret = builder_->dump(dumper); + if (ret != 0) { + LOG_ERROR("diskann reducer dump failed. 
ret[%d]", ret); + return ret; + } + stats_.set_dumped_costtime(timer.seconds()); + + LOG_INFO("End diskann reducer dump, dump costtime=[%zu]s", + (size_t)(stats_.dumped_costtime())); + + return 0; +} + +INDEX_FACTORY_REGISTER_REDUCER(DiskAnnReducer); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_reducer.h b/src/core/algorithm/diskann/diskann_reducer.h new file mode 100644 index 000000000..b912ea1a8 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_reducer.h @@ -0,0 +1,85 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include "diskann_holder.h" +#include "diskann_reducer_entity.h" + +namespace zvec { +namespace core { + +class DiskAnnReducer : public IndexReducer { + public: + //! Constructor + DiskAnnReducer(void) = default; + + protected: + //! Initialize Reducer + virtual int init(const ailego::Params ¶ms) override; + + //! Cleanup Reducer + virtual int cleanup(void) override; + + //! Feed indexes from containers + // virtual int feed(IndexStorage::Pointer container) override; + + //! Reduce operator (with filter) + virtual int reduce(const IndexFilter &filter) override; + + //! Dump index by dumper + virtual int dump(const IndexDumper::Pointer &dumper) override; + + //! 
Retrieve statistics + virtual const Stats &stats(void) const override { + return stats_; + } + + private: + enum State { + STATE_UNINITED = 0, + STATE_INITED = 1, + STATE_FEED = 2, + STATE_REDUCE = 3 + }; + + std::string working_path_{""}; + + IndexMeta meta_{}; + std::vector entities_{}; + + // bool use_mem_holder_{true}; + bool use_mem_holder_{false}; + RandomAccessIndexHolder::Pointer mem_holder_; + DiskAnnIndexHolder::Pointer disk_holder_; + + IndexBuilder::Pointer builder_{nullptr}; + std::string reducer_file_path_{""}; + std::string holder_file_path_{""}; + + Stats stats_{}; + State state_{STATE_UNINITED}; + + const std::string kDiskAnnBuilderName{"DiskAnnBuilder"}; + const std::string kReducerFileName{"diskann.reducer.builder."}; + const std::string kHolderFileName{"diskann.reducer.holder."}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_reducer_entity.cc b/src/core/algorithm/diskann/diskann_reducer_entity.cc new file mode 100644 index 000000000..4ccfb6d21 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_reducer_entity.cc @@ -0,0 +1,215 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_reducer_entity.h" +#include +#include + +namespace zvec { +namespace core { + +int DiskAnnReducerEntity::load(const IndexStorage::Pointer &container, + bool check_crc) { + container_ = container; + + int ret = load_segments(check_crc); + if (ret != 0) { + return ret; + } + + sector_num_per_node_ = node_per_sector() > 0 + ? 1 + : DiskAnnUtil::div_round_up( + max_node_size(), DiskAnnUtil::kSectorSize); + + loaded_ = true; + + return 0; +} + +int DiskAnnReducerEntity::load_segments(bool /*check_crc*/) { + int ret; + ret = load_header_segment(); + if (ret != 0) { + LOG_ERROR("Load Header Segment Failed, ret = %d", ret); + + return ret; + } + + ret = load_key_segment(); + if (ret != 0) { + LOG_ERROR("Load Key Segment Failed, ret = %d", ret); + + return ret; + } + + ret = load_vector_segment(); + if (ret != 0) { + LOG_ERROR("Load Vector Segment Failed, ret = %d", ret); + + return ret; + } + + return 0; +} + +int DiskAnnReducerEntity::load_header_segment() { + const void *data = nullptr; + meta_segment_ = container_->get(kDiskAnnMetaSegmentId); + if (!meta_segment_ || + meta_segment_->data_size() < sizeof(DiskAnnMetaHeader)) { + LOG_ERROR("Miss or invalid segment %s", kDiskAnnMetaSegmentId.c_str()); + return IndexError_InvalidFormat; + } + if (meta_segment_->read(0, reinterpret_cast(&data), + sizeof(DiskAnnMetaHeader)) != + sizeof(DiskAnnMetaHeader)) { + LOG_ERROR("Read segment %s failed", kDiskAnnMetaSegmentId.c_str()); + return IndexError_ReadData; + } + + ::memcpy(reinterpret_cast(&meta_header_), data, + sizeof(DiskAnnMetaHeader)); + + return 0; +} + +int DiskAnnReducerEntity::load_vector_segment() { + vector_segment_ = container_->get(kDiskAnnVectorSegmentId); + if (!vector_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnVectorSegmentId.c_str()); + return IndexError_InvalidFormat; + } + + return 0; +} + +int DiskAnnReducerEntity::load_key_segment() { + // load key + key_segment_ = 
container_->get(kDiskAnnKeySegmentId); + if (!key_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnKeySegmentId.c_str()); + return IndexError_InvalidFormat; + } + + size_t key_data_len = doc_cnt() * sizeof(key_t); + + // load key mapping + key_mapping_segment_ = container_->get(kDiskAnnKeyMappingSegmentId); + const void *data = nullptr; + if (key_mapping_segment_->read(0, reinterpret_cast(&data), + key_data_len) != key_data_len) { + LOG_ERROR("Read segment %s failed", kDiskAnnKeyMappingSegmentId.c_str()); + return IndexError_ReadData; + } + + key_buffer_.resize(key_data_len); + memcpy(&(key_buffer_[0]), data, key_data_len); + + return 0; +} + +bool DiskAnnReducerEntity::do_crc_check( + std::vector &segments) const { + constexpr size_t blk_size = 4096; + const void *data; + + for (auto &segment : segments) { + size_t offset = 0; + size_t rd_size; + uint32_t crc = 0; + while (offset < segment->data_size()) { + size_t size = std::min(blk_size, segment->data_size() - offset); + if ((rd_size = segment->read(offset, &data, size)) <= 0) { + break; + } + offset += rd_size; + crc = ailego::Crc32c::Hash(data, rd_size, crc); + } + if (crc != segment->data_crc()) { + return false; + } + } + return true; +} + +//! Get vector local id by key +diskann_id_t DiskAnnReducerEntity::get_id(diskann_key_t key) const { + const diskann_id_t *key_mapping_data_ptr = + reinterpret_cast(key_mapping_buffer_.data()); + const diskann_key_t *key_data_ptr = + reinterpret_cast(key_buffer_.data()); + + //! 
Do binary search + diskann_id_t start = 0UL; + diskann_id_t end = doc_cnt(); + diskann_id_t idx = 0u; + while (start < end) { + idx = start + (end - start) / 2; + diskann_id_t local_id = key_mapping_data_ptr[idx]; + + const diskann_key_t local_key = key_data_ptr[local_id]; + + if (local_key < key) { + start = idx + 1; + } else if (local_key > key) { + end = idx; + } else { + return local_id; + } + } + + return kInvalidId; +} + +diskann_key_t DiskAnnReducerEntity::get_key(diskann_id_t id) const { + const void *key; + if (ailego_unlikely(key_segment_->read(id * sizeof(diskann_key_t), &key, + sizeof(diskann_key_t)) != + sizeof(diskann_key_t))) { + LOG_ERROR("Read key from segment failed"); + return kInvalidKey; + } + + return *(reinterpret_cast(key)); +} + +const void *DiskAnnReducerEntity::get_vector(diskann_id_t id) const { + size_t read_size = sector_num_per_node_ * DiskAnnUtil::kSectorSize; + size_t sector_id = DiskAnnUtil::get_node_sector( + node_per_sector(), max_node_size(), DiskAnnUtil::kSectorSize, id); + size_t offset = sector_id * DiskAnnUtil::kSectorSize; + + if (sector_id != sector_id_) { + const void *sector_data; + if (ailego_unlikely(vector_segment_->read(offset, §or_data, + read_size) != read_size)) { + LOG_ERROR("Read vector from segment failed"); + return nullptr; + } + + sector_id_ = sector_id; + sector_buffer_.assign(reinterpret_cast(sector_data), + read_size); + } + + return DiskAnnUtil::offset_to_node_const( + node_per_sector(), max_node_size(), + reinterpret_cast(sector_buffer_.data()), id); +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_reducer_entity.h b/src/core/algorithm/diskann/diskann_reducer_entity.h new file mode 100644 index 000000000..076f40aad --- /dev/null +++ b/src/core/algorithm/diskann/diskann_reducer_entity.h @@ -0,0 +1,69 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include "diskann_entity.h" +#include "diskann_file_reader.h" +#include "diskann_pq_table.h" +#include "diskann_util.h" + +namespace zvec { +namespace core { + +class DiskAnnReducerEntity : public DiskAnnEntity { + public: + using Pointer = std::shared_ptr; + using SegmentPointer = IndexStorage::Segment::Pointer; + + public: + DiskAnnReducerEntity() = default; + virtual ~DiskAnnReducerEntity() = default; + + int load(const IndexStorage::Pointer &container, bool check_crc); + int load_segments(bool check_crc); + int load_header_segment(); + int load_vector_segment(); + int load_key_segment(); + int load_key_mapping_segment(); + + bool do_crc_check(std::vector &segments) const; + + diskann_id_t get_id(diskann_key_t key) const; + diskann_key_t get_key(diskann_id_t id) const; + const void *get_vector(diskann_id_t id) const; + + private: + IndexStorage::Pointer container_{}; + IndexStorage::Segment::Pointer meta_segment_{}; + IndexStorage::Segment::Pointer vector_segment_{}; + IndexStorage::Segment::Pointer key_segment_{}; + IndexStorage::Segment::Pointer key_mapping_segment_{}; + + std::string key_buffer_; + std::string key_mapping_buffer_; + + size_t sector_num_per_node_{0}; + + mutable size_t sector_id_{-1U}; + mutable std::string sector_buffer_; + + bool loaded_{false}; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_searcher.cc b/src/core/algorithm/diskann/diskann_searcher.cc new file mode 
100644 index 000000000..bd5093142 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_searcher.cc @@ -0,0 +1,308 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_searcher.h" +#include "diskann_context.h" +#include "diskann_indexer.h" +#include "diskann_params.h" + +namespace zvec { +namespace core { + +DiskAnnSearcher::DiskAnnSearcher() {} + +DiskAnnSearcher::~DiskAnnSearcher() {} + +int DiskAnnSearcher::init(const ailego::Params &search_params) { + search_params.get(PARAM_DISKANN_SEARCHER_LIST_SIZE, &list_size_); + search_params.get(PARAM_DISKANN_SEARCHER_CACHE_NODE_NUM, &cache_nodes_num_); + return 0; +} + +void DiskAnnSearcher::print_debug_info() {} + +int DiskAnnSearcher::cleanup() { + LOG_INFO("Begin DiskAnnSearcher:cleanup"); + + LOG_INFO("End DiskAnnSearcher:cleanup"); + + return 0; +} + +int DiskAnnSearcher::load(IndexStorage::Pointer storage, + IndexMetric::Pointer measure) { + LOG_INFO("DiskAnnSearcher::load Begin"); + + auto start_time = ailego::Monotime::MilliSeconds(); + + int ret = IndexHelper::DeserializeFromStorage(storage.get(), &meta_); + if (ret != 0) { + LOG_ERROR("Failed to deserialize meta from storage"); + return ret; + } + + ret = entity_.load(meta_, storage); + if (ret != 0) { + LOG_INFO("Searcher Entity Load Failed"); + return ret; + } + + diskann_indexer_ = std::make_shared(meta_); + + int res = diskann_indexer_->init(entity_); + if (res != 0) { + return res; + 
} + + if (cache_nodes_num_ != 0) { + std::vector node_list; + LOG_INFO("Caching %u nodes around medoid(s)", cache_nodes_num_); + + diskann_indexer_->cache_bfs_levels(cache_nodes_num_, node_list); + + diskann_indexer_->load_cache_list(node_list); + + node_list.clear(); + node_list.shrink_to_fit(); + } + + if (measure) { + measure_ = measure; + } else { + measure_ = IndexFactory::CreateMetric(meta_.metric_name()); + if (!measure_) { + LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str()); + return IndexError_NoExist; + } + ret = measure_->init(meta_, meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("IndexMetric init failed, ret=%d", ret); + return ret; + } + if (measure_->query_metric()) { + measure_ = measure_->query_metric(); + } + } + + stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time); + state_ = STATE_LOADED; + + magic_ = IndexContext::GenerateMagic(); + + LOG_INFO("DiskAnnSearcher::load Done"); + + return 0; +} + +int DiskAnnSearcher::unload() { + LOG_INFO("DiskAnnSearcher unload index"); + + state_ = STATE_INITED; + + return 0; +} + +int DiskAnnSearcher::update_context(DiskAnnContext *ctx) const { + const DiskAnnEntity::Pointer entity = entity_.clone(); + if (!entity) { + LOG_ERROR("Failed to clone search context entity"); + return IndexError_Runtime; + } + + return ctx->update_context(DiskAnnContext::kSearcherContext, meta_, measure_, + entity, magic_); +} + +int DiskAnnSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const { + // do search + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + + DiskAnnContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to DiskAnnContext failed"); + return IndexError_Cast; + } + + // Context is pooled per index type. 
When switching between DiskAnn indexes + // with different element sizes (e.g., fp16 vs fp32), the cached context has + // undersized buffers. Recreate it to ensure correct buffer allocations. + if (ctx->magic() != magic_) { + uint32_t saved_topk = ctx->topk(); + context = create_context(); + if (!context) { + LOG_ERROR("Failed to recreate context for current streamer"); + return IndexError_Runtime; + } + ctx = dynamic_cast(context.get()); + ctx->set_topk(saved_topk); + } + + ctx->clear(); + ctx->resize_results(count); + + for (uint32_t i = 0; i < count; i++) { + ctx->reset_query(query); + + diskann_indexer_->knn_search(ctx); + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + ctx->topk_to_result(i); + + query = static_cast(query) + qmeta.element_size(); + } + + return 0; +} + +int DiskAnnSearcher::search_bf_impl(const void *query, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + + DiskAnnContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to DiskAnnContext failed"); + return IndexError_Cast; + } + + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer, recreate it + //! to ensure buffers are correctly sized for this index's parameters. 
+ uint32_t saved_topk = ctx->topk(); + context = create_context(); + if (!context) { + LOG_ERROR("Failed to recreate context for current streamer"); + return IndexError_Runtime; + } + ctx = dynamic_cast(context.get()); + ctx->set_topk(saved_topk); + } + + ctx->clear(); + ctx->resize_results(count); + + for (size_t i = 0; i < count; ++i) { + ctx->reset_query(query); + + diskann_indexer_->linear_search(ctx); + + ctx->topk_to_result(i); + + query = static_cast(query) + qmeta.element_size(); + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + +int DiskAnnSearcher::search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + + DiskAnnContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to DiskAnnContext failed"); + return IndexError_Cast; + } + + if (ailego_unlikely(p_keys.size() != count)) { + LOG_ERROR("The size of p_keys is not equal to count"); + return IndexError_InvalidArgument; + } + + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer, recreate it + //! to ensure buffers are correctly sized for this index's parameters. 
+    uint32_t saved_topk = ctx->topk();
+    context = create_context();
+    if (!context) {
+      LOG_ERROR("Failed to recreate context for current searcher");
+      return IndexError_Runtime;
+    }
+    ctx = dynamic_cast<DiskAnnContext *>(context.get());
+    ctx->set_topk(saved_topk);
+  }
+
+  ctx->clear();
+  ctx->resize_results(count);
+
+  for (size_t i = 0; i < count; ++i) {
+    ctx->reset_query(query);
+
+    diskann_indexer_->keys_search(p_keys[i], ctx);
+
+    ctx->topk_to_result(i);
+
+    query = static_cast(query) + qmeta.element_size();
+  }
+
+  if (ailego_unlikely(ctx->error())) {
+    return IndexError_Runtime;
+  }
+
+  return 0;
+}
+
+int DiskAnnSearcher::get_vector(uint64_t key, Context::Pointer &context,
+                                std::string &vector) const {
+  return diskann_indexer_->get_vector(key, context, vector);
+}
+
+IndexSearcher::Context::Pointer DiskAnnSearcher::create_context() const {
+  const DiskAnnEntity::Pointer search_ctx_entity = entity_.clone();
+  if (!search_ctx_entity) {
+    LOG_ERROR("Failed to create search context entity");
+    return Context::Pointer();
+  }
+
+  DiskAnnContext *ctx =
+      new (std::nothrow) DiskAnnContext(meta_, measure_, search_ctx_entity);
+  if (ailego_unlikely(ctx->init(
+          DiskAnnContext::kSearcherContext, search_ctx_entity->max_degree(),
+          search_ctx_entity->pq_chunk_num(), meta_.element_size())) != 0) {
+    LOG_ERROR("Init DiskAnn Context failed");
+    delete ctx;
+
+    return Context::Pointer();
+  }
+
+  ctx->set_list_size(list_size_);
+  ctx->set_magic(magic_);
+
+  return Context::Pointer(ctx);
+}
+
+INDEX_FACTORY_REGISTER_SEARCHER(DiskAnnSearcher);
+
+} // namespace core
+} // namespace zvec
diff --git a/src/core/algorithm/diskann/diskann_searcher.h b/src/core/algorithm/diskann/diskann_searcher.h
new file mode 100644
index 000000000..99584fa35
--- /dev/null
+++ b/src/core/algorithm/diskann/diskann_searcher.h
@@ -0,0 +1,166 @@
+// Copyright 2025-present the zvec project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with 
the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "diskann_context.h" +#include "diskann_indexer.h" + +class LinuxAlignedFileReader; + +namespace zvec { +namespace core { + +class DiskAnnSearcher : public IndexSearcher { + public: + using ContextPointer = IndexSearcher::Context::Pointer; + + public: + DiskAnnSearcher(void); + ~DiskAnnSearcher(void); + + DiskAnnSearcher(const DiskAnnSearcher &) = delete; + DiskAnnSearcher &operator=(const DiskAnnSearcher &) = delete; + + protected: + //! Initialize Searcher + int init(const ailego::Params ¶ms) override; + + //! Cleanup Searcher + int cleanup(void) override; + + //! Load Index from storage + int load(IndexStorage::Pointer storage, IndexMetric::Pointer metric) override; + + //! Unload index from storage + int unload(void) override; + + //! KNN Search + int search_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_impl(query, qmeta, 1, context); + } + + //! KNN Search + int search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, ContextPointer &context) const override; + + //! Linear Search + int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_impl(query, qmeta, 1, context); + } + + //! Linear Search + int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, ContextPointer &context) const override; + + //! 
Linear search by primary keys + int search_bf_by_p_keys_impl(const void *query, + const std::vector> &p_keys, + const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); + } + + //! Linear search by primary keys + int search_bf_by_p_keys_impl(const void *query, + const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + ContextPointer &context) const override; + + //! Linear search by primary keys + int search_bf_by_p_keys_impl(const void *query, const uint32_t sparse_count, + const uint32_t *sparse_indices, + const void *sparse_query, + const std::vector> &p_keys, + const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_by_p_keys_impl(query, &sparse_count, sparse_indices, + sparse_query, p_keys, qmeta, 1, context); + } + + //! Linear search by primary keys + int search_bf_by_p_keys_impl( + const void * /*query*/, const uint32_t * /*sparse_count*/, + const uint32_t * /*sparse_indices*/, const void * /*sparse_query*/, + const std::vector> & /*p_keys*/, + const IndexQueryMeta & /*qmeta*/, uint32_t /*count*/, + ContextPointer & /*context*/) const override { + return IndexError_NotImplemented; + } + + //! Get vector by key + int get_vector(uint64_t key, Context::Pointer &context, + std::string &vector) const override; + + //! Create a searcher context + ContextPointer create_context() const override; + + //! Create a new iterator + IndexSearcher::Provider::Pointer create_provider(void) const override { + return nullptr; + } + + //! Retrieve statistics + const Stats &stats(void) const override { + return stats_; + } + + //! Retrieve meta of index + const IndexMeta &meta(void) const override { + return meta_; + } + + //! 
Retrieve params of index + const ailego::Params ¶ms(void) const override { + return params_; + } + + void print_debug_info() override; + + private: + template + int search_disk_index(const std::string &query_file, + const uint32_t num_nodes_to_cache, + const uint32_t recall_at, const uint32_t beamwidth); + + //! To share ctx across streamer/searcher, we need to update the context for + //! current streamer/searcher + int update_context(DiskAnnContext *ctx) const; + + private: + enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; + + IndexMetric::Pointer measure_{}; + IndexMeta meta_{}; + ailego::Params params_{}; + + uint32_t list_size_{200}; + uint32_t cache_nodes_num_{0}; + + bool warm_up_{false}; + uint32_t beam_size_{2}; + + DiskAnnIndexer::Pointer diskann_indexer_{nullptr}; + DiskAnnSearcherEntity entity_{}; + + uint32_t magic_{0U}; + + Stats stats_; + State state_{STATE_INIT}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_searcher_entity.cc b/src/core/algorithm/diskann/diskann_searcher_entity.cc new file mode 100644 index 000000000..36a23c600 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_searcher_entity.cc @@ -0,0 +1,443 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_searcher_entity.h" + +namespace zvec { +namespace core { + +const DiskAnnEntity::Pointer DiskAnnSearcherEntity::clone() const { + auto meta_segment = meta_segment_->clone(); + if (ailego_unlikely(!meta_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnMetaSegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + auto pq_meta_segment = pq_meta_segment_->clone(); + if (ailego_unlikely(!pq_meta_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnPqMetaSegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + auto pq_data_segment = pq_data_segment_->clone(); + if (ailego_unlikely(!pq_data_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnPqDataSegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + auto vector_segment = vector_segment_->clone(); + if (ailego_unlikely(!vector_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnVectorSegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + auto key_segment = key_segment_->clone(); + if (ailego_unlikely(!key_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnKeySegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + auto key_mapping_segment = key_mapping_segment_->clone(); + if (ailego_unlikely(!key_mapping_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnKeyMappingSegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + auto entrypoint_segment = entrypoint_segment_->clone(); + if (ailego_unlikely(!entrypoint_segment)) { + LOG_ERROR("clone segment %s failed", kDiskAnnEntryPointSegmentId.c_str()); + return DiskAnnEntity::Pointer(); + } + + DiskAnnSearcherEntity *entity = new (std::nothrow) DiskAnnSearcherEntity( + meta_header_, pq_meta_, meta_segment, pq_meta_segment, pq_data_segment, + vector_segment, key_segment, key_mapping_segment, entrypoint_segment, + num_threads_, list_size_, cache_nodes_num_, warm_up_, beam_size_, meta_, + pq_table_, key_buffer_, key_mapping_buffer_, entrypoints_); + if 
(ailego_unlikely(!entity)) {
+    LOG_ERROR("DiskAnnSearcherEntity new failed");
+  }
+
+  return DiskAnnEntity::Pointer(entity);
+}
+
+int DiskAnnSearcherEntity::load(const IndexMeta &meta,
+                                IndexStorage::Pointer storage) {
+  meta_ = meta;
+
+  storage_ = storage;
+
+  int ret;
+  ret = load_header_segment();
+  if (ret != 0) {
+    LOG_ERROR("Load Header Segment Failed, ret = %d", ret);
+
+    return ret;
+  }
+
+  ret = load_pq_segment();
+  if (ret != 0) {
+    LOG_ERROR("Load PQ Meta Segment Failed, ret = %d", ret);
+
+    return ret;
+  }
+
+  ret = load_key_segment();
+  if (ret != 0) {
+    LOG_ERROR("Load Key Segment Failed, ret = %d", ret);
+
+    return ret;
+  }
+
+  ret = load_key_mapping_segment();
+  if (ret != 0) {
+    LOG_ERROR("Load Key Mapping Segment Failed, ret = %d", ret);
+
+    return ret;
+  }
+
+  ret = load_entrypoint_segment();
+  if (ret != 0) {
+    LOG_WARN("Load EntryPoint Segment Failed, ret = %d", ret);
+
+    return ret;
+  }
+
+  ret = load_vector_segment();
+  if (ret != 0) {
+    LOG_ERROR("Load Vector Segment Failed, ret = %d", ret);
+
+    return ret;
+  }
+
+  return 0;
+}
+
+int DiskAnnSearcherEntity::load_pq_segment() {
+  const void *data = nullptr;
+
+  // load pq meta
+  pq_meta_segment_ = storage_->get(DiskAnnEntity::kDiskAnnPqMetaSegmentId);
+  if (!pq_meta_segment_) {
+    LOG_ERROR("Miss or invalid segment %s",
+              DiskAnnEntity::kDiskAnnPqMetaSegmentId.c_str());
+    return IndexError_InvalidFormat;
+  }
+
+  size_t read_size;
+  size_t offset = 0;
+
+  // 1. read pq meta
+  read_size = pq_meta_segment_->read(offset, &data, sizeof(DiskAnnPqMeta));
+  if (read_size != sizeof(DiskAnnPqMeta)) {
+    LOG_ERROR("Read segment %s failed, expect: %zu, actual: %zu",
+              DiskAnnEntity::kDiskAnnPqMetaSegmentId.c_str(),
+              sizeof(DiskAnnPqMeta), read_size);
+
+    return IndexError_ReadData;
+  }
+
+  memcpy(reinterpret_cast(&pq_meta_), data, sizeof(DiskAnnPqMeta));
+  offset += read_size;
+
+  // 2. 
read full pivot data + std::vector full_pivot_data; + full_pivot_data.resize(pq_meta_.full_pivot_data_size); + + read_size = + pq_meta_segment_->read(offset, &data, pq_meta_.full_pivot_data_size); + if (read_size != pq_meta_.full_pivot_data_size) { + LOG_ERROR("Read segment %s failed, expect: %zu, actual: %zu", + DiskAnnEntity::kDiskAnnPqMetaSegmentId.c_str(), + (size_t)(pq_meta_.full_pivot_data_size), (size_t)read_size); + return IndexError_ReadData; + } + memcpy(&(full_pivot_data[0]), data, read_size); + offset += read_size; + + // 3. read centroid + std::vector centroid; + centroid.resize(pq_meta_.centroid_data_size); + + read_size = + pq_meta_segment_->read(offset, &data, pq_meta_.centroid_data_size); + if (read_size != pq_meta_.centroid_data_size) { + LOG_ERROR("Read segment %s failed, expect: %zu, actual: %zu", + DiskAnnEntity::kDiskAnnPqMetaSegmentId.c_str(), + (size_t)(pq_meta_.centroid_data_size), (size_t)read_size); + return IndexError_ReadData; + } + memcpy(&(centroid[0]), data, read_size); + offset += read_size; + + // 4. 
chunk offset + std::vector chunk_offsets; + chunk_offsets.resize(pq_meta_.chunk_num + 1); + + read_size = pq_meta_segment_->read( + offset, &data, (pq_meta_.chunk_num + 1) * sizeof(uint32_t)); + if (read_size != (pq_meta_.chunk_num + 1) * sizeof(uint32_t)) { + LOG_ERROR("Read segment %s failed, expect: %zu, actual: %zu", + DiskAnnEntity::kDiskAnnPqMetaSegmentId.c_str(), + (size_t)((pq_meta_.chunk_num + 1) * sizeof(uint32_t)), + (size_t)read_size); + return IndexError_ReadData; + } + memcpy(&(chunk_offsets[0]), data, read_size); + + // load pq data + std::vector pq_data; + pq_data_segment_ = storage_->get(DiskAnnEntity::kDiskAnnPqDataSegmentId); + if (!pq_data_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnPqDataSegmentId.c_str()); + return IndexError_InvalidFormat; + } + + pq_data.resize(meta_header_.doc_cnt * pq_meta_.chunk_num); + + void *pq_data_ptr = &pq_data[0]; + read_size = pq_data_segment_->fetch( + 0, pq_data_ptr, meta_header_.doc_cnt * pq_meta_.chunk_num); + + if (read_size != meta_header_.doc_cnt * pq_meta_.chunk_num) { + LOG_ERROR("Read segment %s failed, expect: %zu, actual: %zu", + DiskAnnEntity::kDiskAnnPqMetaSegmentId.c_str(), + (size_t)(meta_header_.doc_cnt * pq_meta_.chunk_num), + (size_t)read_size); + + return IndexError_ReadData; + } + + pq_table_ = std::make_shared(meta_, pq_meta_.chunk_num); + + pq_table_->init(full_pivot_data, centroid, chunk_offsets, pq_data); + + return 0; +} + +int DiskAnnSearcherEntity::load_header_segment() { + const void *data = nullptr; + meta_segment_ = storage_->get(kDiskAnnMetaSegmentId); + if (!meta_segment_ || + meta_segment_->data_size() < sizeof(DiskAnnMetaHeader)) { + LOG_ERROR("Miss or invalid segment %s", kDiskAnnMetaSegmentId.c_str()); + return IndexError_InvalidFormat; + } + if (meta_segment_->read(0, reinterpret_cast(&data), + sizeof(DiskAnnMetaHeader)) != + sizeof(DiskAnnMetaHeader)) { + LOG_ERROR("Read segment %s failed", kDiskAnnMetaSegmentId.c_str()); + return 
IndexError_ReadData; + } + memcpy(reinterpret_cast(&meta_header_), data, + sizeof(DiskAnnMetaHeader)); + + return 0; +} + +int DiskAnnSearcherEntity::load_vector_segment() { + vector_segment_ = storage_->get(kDiskAnnVectorSegmentId); + if (!vector_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnVectorSegmentId.c_str()); + return IndexError_InvalidFormat; + } + + return 0; +} + +int DiskAnnSearcherEntity::load_key_segment() { + // load key + key_segment_ = storage_->get(kDiskAnnKeySegmentId); + if (!key_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnKeySegmentId.c_str()); + return IndexError_InvalidFormat; + } + + size_t key_data_len = doc_cnt() * sizeof(diskann_key_t); + + const void *data = nullptr; + if (key_segment_->read(0, reinterpret_cast(&data), + key_data_len) != key_data_len) { + LOG_ERROR("Read segment %s failed", kDiskAnnKeySegmentId.c_str()); + return IndexError_ReadData; + } + + key_buffer_.resize(key_data_len); + memcpy(&(key_buffer_[0]), data, key_data_len); + + return 0; +} + +int DiskAnnSearcherEntity::load_entrypoint_segment() { + entrypoint_segment_ = storage_->get(kDiskAnnEntryPointSegmentId); + if (!entrypoint_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnEntryPointSegmentId.c_str()); + return IndexError_InvalidFormat; + } + + const void *data = nullptr; + + if (entrypoint_segment_->read(0, reinterpret_cast(&data), + sizeof(uint32_t)) != sizeof(uint32_t)) { + LOG_ERROR("Read segment %s failed", kDiskAnnEntryPointSegmentId.c_str()); + return IndexError_ReadData; + } + + uint32_t entrypoint_cnt = 0; + memcpy(&entrypoint_cnt, data, sizeof(uint32_t)); + + if (entrypoint_cnt != 0) { + size_t entrypoint_data_len = entrypoint_cnt * sizeof(diskann_id_t); + + if (entrypoint_segment_->read(sizeof(uint32_t), + reinterpret_cast(&data), + entrypoint_data_len) != entrypoint_data_len) { + LOG_ERROR("Read segment %s failed", kDiskAnnEntryPointSegmentId.c_str()); + 
return IndexError_ReadData; + } + + entrypoints_.resize(entrypoint_cnt); + memcpy(&(entrypoints_[0]), data, entrypoint_data_len); + } + + return 0; +} + + +int DiskAnnSearcherEntity::load_key_mapping_segment() { + key_mapping_segment_ = storage_->get(kDiskAnnKeyMappingSegmentId); + if (!key_mapping_segment_) { + LOG_ERROR("Miss or invalid segment %s", + DiskAnnEntity::kDiskAnnKeyMappingSegmentId.c_str()); + return IndexError_InvalidFormat; + } + + size_t key_mapping_data_len = doc_cnt() * sizeof(diskann_id_t); + + const void *data = nullptr; + if (key_mapping_segment_->read(0, reinterpret_cast(&data), + key_mapping_data_len) != + key_mapping_data_len) { + LOG_ERROR("Read segment %s failed", kDiskAnnKeyMappingSegmentId.c_str()); + return IndexError_ReadData; + } + + key_mapping_buffer_.resize(key_mapping_data_len); + memcpy(&(key_mapping_buffer_[0]), data, key_mapping_data_len); + + return 0; +} + +//! Get vector local id by key +diskann_id_t DiskAnnSearcherEntity::get_id(diskann_key_t key) const { + const diskann_id_t *key_mapping_data_ptr = + reinterpret_cast(key_mapping_buffer_.data()); + + const diskann_key_t *key_data_ptr = + reinterpret_cast(key_buffer_.data()); + + //! 
Do binary search + diskann_id_t start = 0UL; + diskann_id_t end = doc_cnt(); + diskann_id_t idx = 0u; + while (start < end) { + idx = start + (end - start) / 2; + diskann_id_t local_id = key_mapping_data_ptr[idx]; + + const diskann_key_t local_key = key_data_ptr[local_id]; + + if (local_key < key) { + start = idx + 1; + } else if (local_key > key) { + end = idx; + } else { + return local_id; + } + } + + return kInvalidId; +} + +diskann_key_t DiskAnnSearcherEntity::get_key(diskann_id_t id) const { + const diskann_key_t *key_data_ptr = + reinterpret_cast(key_buffer_.data()); + + return key_data_ptr[id]; +} + +const void *DiskAnnSearcherEntity::get_vector(diskann_id_t id) const { + if (!vector_segment_) { + LOG_ERROR("Vector segment is null"); + return nullptr; + } + + uint64_t sector_offset = + DiskAnnUtil::get_node_sector(node_per_sector(), max_node_size(), + DiskAnnUtil::kSectorSize, id) * + DiskAnnUtil::kSectorSize; + uint64_t within_sector_offset = + (node_per_sector() == 0 ? 0 : (id % node_per_sector()) * max_node_size()); + uint64_t total_offset = sector_offset + within_sector_offset; + + size_t read_size = meta_.element_size(); + const void *vec; + if (ailego_unlikely(vector_segment_->read(total_offset, &vec, read_size) != + read_size)) { + LOG_ERROR("Read vector from segment failed, id: %u, offset: %lu", id, + total_offset); + return nullptr; + } + + return vec; +} + +std::pair DiskAnnSearcherEntity::get_neighbors( + diskann_id_t id) const { + if (!vector_segment_) { + return std::make_pair(0, nullptr); + } + + // size_t vector_segment_offset = vector_segment_->data_offset(); + + uint64_t read_sector_offset = + DiskAnnUtil::get_node_sector(node_per_sector(), max_node_size(), + DiskAnnUtil::kSectorSize, id) * + DiskAnnUtil::kSectorSize; + uint64_t node_vec_offset = + read_sector_offset + + (node_per_sector() == 0 ? 
0 : (id % node_per_sector()) * max_node_size()); + + const void *data; + if (ailego_unlikely( + vector_segment_->read(node_vec_offset, &data, max_node_size()) != + max_node_size())) { + LOG_ERROR("Read neighbors from segment failed"); + return {0, nullptr}; + } + + const uint8_t *data_ptr = reinterpret_cast(data); + const diskann_id_t *node_neighbor = + reinterpret_cast(data_ptr + meta_.element_size()); + + auto neighbor_num = *node_neighbor; + + return std::make_pair(neighbor_num, node_neighbor + 1); +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_searcher_entity.h b/src/core/algorithm/diskann/diskann_searcher_entity.h new file mode 100644 index 000000000..68ab49891 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_searcher_entity.h @@ -0,0 +1,125 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include "diskann_entity.h" +#include "diskann_file_reader.h" +#include "diskann_pq_table.h" + +namespace zvec { +namespace core { + +class DiskAnnSearcherEntity : public DiskAnnEntity { + public: + using Pointer = std::shared_ptr; + using SegmentPointer = IndexStorage::Segment::Pointer; + + public: + DiskAnnSearcherEntity() = default; + virtual ~DiskAnnSearcherEntity() = default; + + public: + virtual const DiskAnnEntity::Pointer clone() const override; + + int load(const IndexMeta &meta, IndexStorage::Pointer storage); + int load_pq_segment(); + int load_header_segment(); + int load_vector_segment(); + int load_key_segment(); + int load_key_mapping_segment(); + int load_entrypoint_segment(); + + PQTable::Pointer get_pq_table() { + return pq_table_; + } + + IndexStorage::Pointer get_storage() { + return storage_; + } + + SegmentPointer get_vector_segment() { + return vector_segment_; + } + + std::vector &entrypoints() { + return entrypoints_; + } + + std::pair get_neighbors( + diskann_id_t id) const override; + + diskann_id_t get_id(diskann_key_t key) const override; + diskann_key_t get_key(diskann_id_t id) const override; + const void *get_vector(diskann_id_t id) const override; + + private: + DiskAnnSearcherEntity( + const DiskAnnMetaHeader &meta_header, const DiskAnnPqMeta &pq_meta, + const SegmentPointer &meta_segment, const SegmentPointer &pq_meta_segment, + const SegmentPointer &pq_data_segment, + const SegmentPointer &vector_segment, const SegmentPointer &key_segment, + const SegmentPointer &key_mapping_segment, + const SegmentPointer &entrypoint_segment, uint32_t num_threads, + uint32_t list_size, uint32_t cache_nodes_num, bool warm_up, + uint32_t beam_size, const IndexMeta meta, PQTable::Pointer pq_table, + const std::string &key_buffer, const std::string &key_mapping_buffer, + const std::vector &entrypoints) + : DiskAnnEntity(meta_header, pq_meta), + meta_segment_(meta_segment), + pq_meta_segment_(pq_meta_segment), + 
pq_data_segment_(pq_data_segment), + vector_segment_(vector_segment), + key_segment_(key_segment), + key_mapping_segment_(key_mapping_segment), + entrypoint_segment_{entrypoint_segment}, + num_threads_{num_threads}, + list_size_{list_size}, + cache_nodes_num_{cache_nodes_num}, + warm_up_{warm_up}, + beam_size_{beam_size}, + meta_{meta}, + pq_table_{pq_table}, + key_buffer_{key_buffer}, + key_mapping_buffer_{key_mapping_buffer}, + entrypoints_{entrypoints} {} + + IndexStorage::Pointer storage_{}; + + SegmentPointer meta_segment_{nullptr}; + SegmentPointer pq_meta_segment_{nullptr}; + SegmentPointer pq_data_segment_{nullptr}; + SegmentPointer vector_segment_{nullptr}; + SegmentPointer key_segment_{nullptr}; + SegmentPointer key_mapping_segment_{nullptr}; + SegmentPointer entrypoint_segment_{nullptr}; + + uint32_t num_threads_{1}; + uint32_t list_size_{200}; + uint32_t cache_nodes_num_{0}; + + bool warm_up_{false}; + uint32_t beam_size_{2}; + + IndexMeta meta_; + + PQTable::Pointer pq_table_; + std::string key_buffer_; + std::string key_mapping_buffer_; + std::vector entrypoints_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_streamer.cc b/src/core/algorithm/diskann/diskann_streamer.cc new file mode 100644 index 000000000..009d3fe2b --- /dev/null +++ b/src/core/algorithm/diskann/diskann_streamer.cc @@ -0,0 +1,326 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_streamer.h" +#include "diskann_context.h" +#include "diskann_indexer.h" +#include "diskann_params.h" + +namespace zvec { +namespace core { + +DiskAnnStreamer::DiskAnnStreamer() {} + +DiskAnnStreamer::~DiskAnnStreamer() {} + +int DiskAnnStreamer::init(const IndexMeta &meta, + const ailego::Params &search_params) { + meta_ = meta; + search_params.get(PARAM_DISKANN_SEARCHER_LIST_SIZE, &list_size_); + search_params.get(PARAM_DISKANN_SEARCHER_CACHE_NODE_NUM, &cache_nodes_num_); + return 0; +} + +void DiskAnnStreamer::print_debug_info() {} + +int DiskAnnStreamer::cleanup() { + LOG_INFO("Begin DiskAnnStreamer:cleanup"); + + LOG_INFO("End DiskAnnStreamer:cleanup"); + + return 0; +} + +int DiskAnnStreamer::open(IndexStorage::Pointer storage) { + LOG_INFO("DiskAnnStreamer::load Begin"); + + auto start_time = ailego::Monotime::MilliSeconds(); + + int ret = IndexHelper::DeserializeFromStorage(storage.get(), &meta_); + if (ret != 0) { + LOG_ERROR("Failed to deserialize meta from storage"); + return ret; + } + + ret = entity_.load(meta_, storage); + if (ret != 0) { + LOG_INFO("Searcher Entity Load Failed"); + return ret; + } + + diskann_indexer_ = std::make_shared(meta_); + + int res = diskann_indexer_->init(entity_); + if (res != 0) { + return res; + } + + if (cache_nodes_num_ != 0) { + std::vector node_list; + LOG_INFO("Caching %u nodes around medoid(s)", cache_nodes_num_); + + diskann_indexer_->cache_bfs_levels(cache_nodes_num_, node_list); + + diskann_indexer_->load_cache_list(node_list); + + node_list.clear(); + node_list.shrink_to_fit(); + } + + measure_ = IndexFactory::CreateMetric(meta_.metric_name()); + if (!measure_) { + LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str()); + return IndexError_NoExist; + } + ret = measure_->init(meta_, meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("IndexMetric init failed, ret=%d", ret); + return ret; + } + if (measure_->query_metric()) { + measure_ = measure_->query_metric(); + } + + 
stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time); + state_ = STATE_LOADED; + + magic_ = IndexContext::GenerateMagic(); + + LOG_INFO("DiskAnnStreamer::load Done"); + + return 0; +} + +int DiskAnnStreamer::unload() { + LOG_INFO("DiskAnnStreamer unload index"); + + state_ = STATE_INITED; + + return 0; +} + +int DiskAnnStreamer::update_context(DiskAnnContext *ctx) const { + const DiskAnnEntity::Pointer entity = entity_.clone(); + if (!entity) { + LOG_ERROR("Failed to clone search context entity"); + return IndexError_Runtime; + } + + return ctx->update_context(DiskAnnContext::kSearcherContext, meta_, measure_, + entity, magic_); +} + +int DiskAnnStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const { + // do search + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + + DiskAnnContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to DiskAnnContext failed"); + return IndexError_Cast; + } + + // Context is pooled per index type. When switching between DiskAnn indexes + // with different element sizes (e.g., fp16 vs fp32), the cached context has + // undersized buffers. Recreate it to ensure correct buffer allocations. 
+ if (ctx->magic() != magic_) { + uint32_t saved_topk = ctx->topk(); + context = create_context(); + if (!context) { + LOG_ERROR("Failed to recreate context for current streamer"); + return IndexError_Runtime; + } + ctx = dynamic_cast(context.get()); + ctx->set_topk(saved_topk); + } + + ctx->clear(); + ctx->resize_results(count); + + for (uint32_t i = 0; i < count; i++) { + ctx->reset_query(query); + + diskann_indexer_->knn_search(ctx); + + ctx->topk_to_result(i); + + query = static_cast(query) + qmeta.element_size(); + } + + return 0; +} + +int DiskAnnStreamer::search_bf_impl(const void *query, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + + DiskAnnContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to DiskAnnContext failed"); + return IndexError_Cast; + } + + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer, recreate it + //! to ensure buffers are correctly sized for this index's parameters. 
+ uint32_t saved_topk = ctx->topk(); + context = create_context(); + if (!context) { + LOG_ERROR("Failed to recreate context for current streamer"); + return IndexError_Runtime; + } + ctx = dynamic_cast(context.get()); + ctx->set_topk(saved_topk); + } + + ctx->clear(); + ctx->resize_results(count); + + for (size_t i = 0; i < count; ++i) { + ctx->reset_query(query); + + diskann_indexer_->linear_search(ctx); + + ctx->topk_to_result(i); + + query = static_cast(query) + qmeta.element_size(); + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + +int DiskAnnStreamer::search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + + DiskAnnContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to DiskAnnContext failed"); + return IndexError_Cast; + } + + if (ailego_unlikely(p_keys.size() != count)) { + LOG_ERROR("The size of p_keys is not equal to count"); + return IndexError_InvalidArgument; + } + + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer, recreate it + //! to ensure buffers are correctly sized for this index's parameters. 
+ uint32_t saved_topk = ctx->topk(); + context = create_context(); + if (!context) { + LOG_ERROR("Failed to recreate context for current streamer"); + return IndexError_Runtime; + } + ctx = dynamic_cast(context.get()); + ctx->set_topk(saved_topk); + } + + ctx->clear(); + ctx->resize_results(count); + + for (size_t i = 0; i < count; ++i) { + ctx->reset_query(query); + + diskann_indexer_->keys_search(p_keys[i], ctx); + + ctx->topk_to_result(i); + + query = static_cast(query) + qmeta.element_size(); + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + +int DiskAnnStreamer::get_vector(uint64_t key, Context::Pointer &context, + std::string &vector) const { + return diskann_indexer_->get_vector(key, context, vector); +} + +const void *DiskAnnStreamer::get_vector_by_id(uint32_t id) const { + // DiskAnn vectors are stored on disk in sector format; + // a const void* access requires sector I/O via create_context + // Return nullptr to indicate this path is not supported. 
+ return nullptr; +} + +int DiskAnnStreamer::get_vector_by_id(const uint32_t id, + IndexStorage::MemoryBlock &block) const { + // Lazily create a reusable context for fetch operations + if (!fetch_ctx_) { + fetch_ctx_ = create_context(); + if (!fetch_ctx_) { + LOG_ERROR("Failed to create context for get_vector_by_id"); + return IndexError_Runtime; + } + } + int ret = diskann_indexer_->get_vector(id, fetch_ctx_, fetch_vector_buffer_); + if (ret != 0) { + LOG_ERROR("Failed to get vector by id: %u", id); + return IndexError_Runtime; + } + block.reset((void *)fetch_vector_buffer_.data()); + return 0; +} + +IndexSearcher::Context::Pointer DiskAnnStreamer::create_context() const { + const DiskAnnEntity::Pointer search_ctx_entity = entity_.clone(); + if (!search_ctx_entity) { + LOG_ERROR("Failed to create search context entity"); + return Context::Pointer(); + } + + DiskAnnContext *ctx = + new (std::nothrow) DiskAnnContext(meta_, measure_, search_ctx_entity); + if (ailego_unlikely(ctx->init( + DiskAnnContext::kSearcherContext, search_ctx_entity->max_degree(), + search_ctx_entity->pq_chunk_num(), meta_.element_size())) != 0) { + LOG_ERROR("Init DiskAnn Context failed"); + delete ctx; + + return Context::Pointer(); + } + + ctx->set_list_size(list_size_); + + return Context::Pointer(ctx); +} + +INDEX_FACTORY_REGISTER_STREAMER(DiskAnnStreamer); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_streamer.h b/src/core/algorithm/diskann/diskann_streamer.h new file mode 100644 index 000000000..6100450e2 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_streamer.h @@ -0,0 +1,180 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "diskann_context.h" +#include "diskann_indexer.h" + +class LinuxAlignedFileReader; + +namespace zvec { +namespace core { + +class DiskAnnStreamer : public IndexStreamer { + public: + using ContextPointer = IndexStreamer::Context::Pointer; + + public: + DiskAnnStreamer(void); + ~DiskAnnStreamer(void); + + DiskAnnStreamer(const DiskAnnStreamer &) = delete; + DiskAnnStreamer &operator=(const DiskAnnStreamer &) = delete; + + protected: + //! Initialize Searcher + int init(const IndexMeta &meta, const ailego::Params ¶ms) override; + + //! Cleanup Searcher + int cleanup(void) override; + + //! Load Index from storage + int open(IndexStorage::Pointer storage) override; + + //! Unload index from storage + int unload(void) override; + + //! KNN Search + int search_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_impl(query, qmeta, 1, context); + } + + //! KNN Search + int search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, ContextPointer &context) const override; + + //! Linear Search + int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_impl(query, qmeta, 1, context); + } + + //! Linear Search + int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, ContextPointer &context) const override; + + //! 
Linear search by primary keys + int search_bf_by_p_keys_impl(const void *query, + const std::vector> &p_keys, + const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); + } + + //! Linear search by primary keys + int search_bf_by_p_keys_impl(const void *query, + const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + ContextPointer &context) const override; + + //! Linear search by primary keys + int search_bf_by_p_keys_impl(const void *query, const uint32_t sparse_count, + const uint32_t *sparse_indices, + const void *sparse_query, + const std::vector> &p_keys, + const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_by_p_keys_impl(query, &sparse_count, sparse_indices, + sparse_query, p_keys, qmeta, 1, context); + } + + //! Linear search by primary keys + int search_bf_by_p_keys_impl( + const void * /*query*/, const uint32_t * /*sparse_count*/, + const uint32_t * /*sparse_indices*/, const void * /*sparse_query*/, + const std::vector> & /*p_keys*/, + const IndexQueryMeta & /*qmeta*/, uint32_t /*count*/, + ContextPointer & /*context*/) const override { + return IndexError_NotImplemented; + } + + //! Get vector by key + int get_vector(uint64_t key, Context::Pointer &context, + std::string &vector) const override; + + //! Fetch vector by id + const void *get_vector_by_id(uint32_t id) const override; + + //! Fetch vector by id into memory block + int get_vector_by_id(const uint32_t id, + IndexStorage::MemoryBlock &block) const override; + + //! Create a searcher context + ContextPointer create_context() const override; + + //! Create a new iterator + IndexSearcher::Provider::Pointer create_provider(void) const override { + return nullptr; + } + + //! Retrieve statistics + const Stats &stats(void) const override { + return stats_; + } + + //! 
Retrieve meta of index + const IndexMeta &meta(void) const override { + return meta_; + } + + virtual int flush(uint64_t /*check_point*/) override { + return 0; + } + + virtual int close(void) override { + return this->unload(); + } + + void print_debug_info() override; + + private: + template + int search_disk_index(const std::string &query_file, + const uint32_t num_nodes_to_cache, + const uint32_t recall_at, const uint32_t beamwidth); + + //! To share ctx across streamer/searcher, we need to update the context for + //! current streamer/searcher + int update_context(DiskAnnContext *ctx) const; + + private: + enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; + + IndexMetric::Pointer measure_{}; + IndexMeta meta_{}; + ailego::Params params_{}; + + uint32_t list_size_{200}; + uint32_t cache_nodes_num_{0}; + + bool warm_up_{false}; + uint32_t beam_size_{2}; + + DiskAnnIndexer::Pointer diskann_indexer_{nullptr}; + DiskAnnSearcherEntity entity_{}; + + // Mutable members for get_vector_by_id (caches context and buffer) + mutable ContextPointer fetch_ctx_{}; + mutable std::string fetch_vector_buffer_; + + uint32_t magic_{0U}; + + Stats stats_; + State state_{STATE_INIT}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_util.h b/src/core/algorithm/diskann/diskann_util.h new file mode 100644 index 000000000..a8adcb5fe --- /dev/null +++ b/src/core/algorithm/diskann/diskann_util.h @@ -0,0 +1,223 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "diskann_entity.h" + +namespace zvec { +namespace core { + +class DiskAnnUtil { + public: + static constexpr uint64_t kSectorSize = 4096; + static constexpr uint64_t kMaxSectorReadNum = 128; + // static constexpr double kSpaceForCachedNodeInGB = 0.25; + // static constexpr double kThresholdForCacheInGB = 1.0; + + public: + static inline size_t div_round_up(size_t x, size_t y) { + return (x / y + (x % y != 0)); + } + + static inline size_t round_up(size_t x, size_t y) { + return div_round_up(x, y) * y; + } + + static inline void alloc_aligned(void **ptr, size_t size, size_t align) { + *ptr = ::aligned_alloc(align, size); + } + + static inline void free_aligned(void *ptr) { + if (ptr == nullptr) { + return; + } + + free(ptr); + } + + template + static inline void convert_vector_to_residual(T *data, uint32_t blocksize_, + size_t dim, void *centroid) { + const T *centroid_ptr = reinterpret_cast(centroid); + for (size_t i = 0; i < blocksize_; i++) { + for (uint64_t d = 0; d < dim; d++) { + float data_float = data[i * dim + d]; + data_float -= centroid_ptr[d]; + data[i * dim + d] = data_float; + } + } + } + + static inline void convert_types_uint32_to_uint8(const uint32_t *src, + uint8_t *dest, size_t npts, + size_t dim) { + for (size_t i = 0; i < npts; i++) { + for (size_t j = 0; j < dim; j++) { + dest[i * dim + j] = src[i * dim + j]; + } + } + } + + static inline uint64_t get_node_sector(uint32_t node_per_sector, + uint32_t max_nodesize_, + uint32_t sectorsize_, + diskann_id_t node_id) { + return (node_per_sector > 0 + ? 
node_id / node_per_sector + : node_id * div_round_up(max_nodesize_, sectorsize_)); + } + + static inline uint32_t *offset_to_node_neighbor(uint8_t *node_buf, + uint32_t elementsize_) { + return (uint32_t *)(node_buf + elementsize_); + } + + static inline uint8_t *offset_to_node(uint32_t node_per_sector, + uint32_t max_nodesize_, + uint8_t *sector_buf, + diskann_id_t node_id) { + return sector_buf + (node_per_sector == 0 + ? 0 + : (node_id % node_per_sector) * max_nodesize_); + } + + static inline const uint8_t *offset_to_node_const(uint32_t node_per_sector, + uint32_t max_nodesize_, + const uint8_t *sector_buf, + diskann_id_t node_id) { + return sector_buf + (node_per_sector == 0 + ? 0 + : (node_id % node_per_sector) * max_nodesize_); + } +}; + +//! Neighbor +struct Neighbor { + public: + Neighbor() = default; + + Neighbor(diskann_id_t id, float distance) + : id{id}, distance{distance}, expanded(false) {} + + inline bool operator<(const Neighbor &other) const { + return distance < other.distance || + (distance == other.distance && id < other.id); + } + + inline bool operator==(const Neighbor &other) const { + return (id == other.id); + } + + public: + diskann_id_t id; + float distance; + bool expanded; +}; + +//! 
NeighborPriorityQueue +class NeighborPriorityQueue { + public: + NeighborPriorityQueue() : size_(0), capacity_(0), cur_(0) {} + + explicit NeighborPriorityQueue(size_t capacity) + : size_(0), capacity_(capacity), cur_(0), data_(capacity + 1) {} + + void insert(const Neighbor &nbr) { + if (size_ == capacity_ && data_[size_ - 1] < nbr) { + return; + } + + size_t low = 0, high = size_; + while (low < high) { + size_t mid = (low + high) >> 1; + if (nbr < data_[mid]) { + high = mid; + } else if (data_[mid].id == nbr.id) { + return; + } else { + low = mid + 1; + } + } + + if (low < capacity_) { + std::memmove(&data_[low + 1], &data_[low], + (size_ - low) * sizeof(Neighbor)); + } + + data_[low] = {nbr.id, nbr.distance}; + if (size_ < capacity_) { + size_++; + } + + if (low < cur_) { + cur_ = low; + } + } + + Neighbor closest_unexpanded() { + data_[cur_].expanded = true; + size_t pre = cur_; + while (cur_ < size_ && data_[cur_].expanded) { + cur_++; + } + return data_[pre]; + } + + bool has_unexpanded_node() const { + return cur_ < size_; + } + + size_t size() const { + return size_; + } + + size_t capacity() const { + return capacity_; + } + + void reserve(size_t capacity) { + if (capacity + 1 > data_.size()) { + data_.resize(capacity + 1); + } + capacity_ = capacity; + } + + Neighbor &operator[](size_t i) { + return data_[i]; + } + + Neighbor operator[](size_t i) const { + return data_[i]; + } + + void sort() { + std::sort(data_.begin(), data_.begin() + size_); + } + + void clear() { + size_ = 0; + cur_ = 0; + } + + private: + size_t size_; + size_t capacity_; + size_t cur_; + std::vector data_; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/diskann/diskann_vecs_reader.h b/src/core/algorithm/diskann/diskann_vecs_reader.h new file mode 100644 index 000000000..0d80933e7 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_vecs_reader.h @@ -0,0 +1,320 @@ +// Copyright 2025-present the zvec project +// +// 
Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include + +namespace zvec { +namespace core { + +enum VecsBitMapIndex { + BITMAP_INDEX_KEY = 0, + BITMAP_INDEX_DENSE = 1, + BITMAP_INDEX_SPARSE = 2, + BITMAP_INDEX_PARTITION = 3, + BITMAP_INDEX_TAGLIST = 4 +}; + +#pragma pack(4) +struct VecsHeaderDetect { + uint64_t num_vecs; + uint16_t reserved; // in v0, it's meta size + uint16_t version; // 0 for v0 format, 1 for new format + uint8_t meta_buf[0]; +}; +#pragma pack() + +#pragma pack(4) +struct VecsHeader { + uint64_t num_vecs; + uint32_t meta_size; + uint8_t meta_buf[0]; +}; +#pragma pack() + +#pragma pack(4) +struct VecsHeaderV1 { + uint64_t num_vecs; + uint16_t meta_size_v1; + uint16_t version; + uint32_t meta_size; + uint64_t bitmap; // set for data section + uint64_t key_offset; // offset for key + uint64_t key_size; // size for key + uint64_t dense_offset; // offset for dense + uint64_t dense_size; // size for dense + uint64_t sparse_offset; // offset for sparse + uint64_t sparse_size; // size for sparse + uint64_t partition_offset; // offset for partition + uint64_t partition_size; // size for partition + uint64_t taglist_offset; // offset for taglist + uint64_t taglist_size; // size for taglist + uint8_t meta_buf[0]; +}; +#pragma pack() + +class DiskAnnVecsReader { + public: + DiskAnnVecsReader() + : mmap_file_(), + index_meta_(), + num_vecs_(0), + vector_base_(nullptr), + key_base_(nullptr), + 
sparse_base_meta_{nullptr}, + sparse_base_data_{nullptr}, + partition_base_{nullptr}, + taglist_base_meta_{nullptr}, + taglist_base_data_{nullptr}, + taglist_size_{0} {} + + void set_measure(const std::string &name, const IndexParams ¶ms) { + index_meta_.set_measure(name, 0, params); + } + + bool load(const std::string &fname) { + return load(fname.c_str()); + } + + bool load(const char *fname) { + if (!fname) { + std::cerr << "Load fname is nullptr" << std::endl; + return false; + } + if (!mmap_file_.open(fname, true)) { + std::cerr << "Open file error: " << fname << std::endl; + return false; + } + if (mmap_file_.size() < sizeof(VecsHeaderDetect)) { + std::cerr << "File size is too small: " << mmap_file_.size() << std::endl; + return false; + } + + const VecsHeaderDetect *header = + reinterpret_cast(mmap_file_.region()); + + if (header->version == 0) { + return load_v0(); + } else if (header->version == 1) { + return load_v1(); + } + + std::cerr << "Can not recognize version: " << header->version << std::endl; + + return false; + } + + bool load_v0() { + const VecsHeader *header = + reinterpret_cast(mmap_file_.region()); + // check + num_vecs_ = header->num_vecs; + + // deserialize + bool bret = index_meta_.deserialize(&header->meta_buf, header->meta_size); + if (!bret) { + std::cerr << "deserialize index meta error." << std::endl; + return false; + } + + if (!index_meta_.hybrid_vector()) { + if ((mmap_file_.size() - sizeof(*header) - header->meta_size) % + num_vecs_ != + 0) { + std::cerr << "input file foramt check error." 
<< std::endl; + return false; + } + } + + if (!index_meta_.hybrid_vector()) { + vector_base_ = + reinterpret_cast(header + 1) + header->meta_size; + key_base_ = reinterpret_cast( + vector_base_ + num_vecs_ * index_meta_.element_size()); + } else { + vector_base_ = + reinterpret_cast(header + 1) + header->meta_size; + key_base_ = reinterpret_cast( + vector_base_ + num_vecs_ * index_meta_.element_size()); + sparse_base_meta_ = reinterpret_cast(key_base_ + num_vecs_); + sparse_base_data_ = reinterpret_cast( + sparse_base_meta_ + num_vecs_ * sizeof(uint64_t)); + } + + return true; + } + + bool load_v1() { + const VecsHeaderV1 *header = + reinterpret_cast(mmap_file_.region()); + // check + num_vecs_ = header->num_vecs; + + // deserialize + bool bret = index_meta_.deserialize(&header->meta_buf, header->meta_size); + if (!bret) { + std::cerr << "deserialize index meta error." << std::endl; + return false; + } + + const char *data_base_ptr = + reinterpret_cast(header + 1) + header->meta_size; + + vector_base_ = reinterpret_cast(data_base_ptr); + key_base_ = reinterpret_cast( + vector_base_ + num_vecs_ * index_meta_.element_size()); + + if (header->sparse_offset != -1LLU) { + sparse_base_meta_ = data_base_ptr + header->sparse_offset; + sparse_base_data_ = sparse_base_meta_ + num_vecs_ * sizeof(uint64_t); + } + + if (header->partition_offset != -1LLU) { + partition_base_ = reinterpret_cast( + data_base_ptr + header->partition_offset); + } + + if (header->taglist_offset != -1LLU) { + taglist_base_meta_ = data_base_ptr + header->taglist_offset; + taglist_base_data_ = taglist_base_meta_ + num_vecs_; + taglist_size_ = header->taglist_size; + } + + return true; + } + + size_t num_vecs() const { + return num_vecs_; + } + + const void *vector_base() const { + return vector_base_; + } + + const uint64_t *key_base() const { + return key_base_; + } + + const IndexMeta &index_meta() const { + return index_meta_; + } + + uint64_t get_key(size_t index) const { + return key_base_[index]; 
+ } + + const void *get_vector(size_t index) const { + return vector_base_ + index * index_meta_.element_size(); + } + + uint32_t get_sparse_count(size_t index) const { + if (index_meta_.hybrid_vector()) { + auto sparse_data_meta = sparse_base_meta_ + index * sizeof(uint64_t); + uint64_t sparse_offset = *((uint64_t *)sparse_data_meta); + uint32_t sparse_count = + *((uint32_t *)(sparse_base_data_ + sparse_offset)); + + return sparse_count; + } + + return 0; + } + + const uint32_t *get_sparse_indices(size_t index) const { + if (index_meta_.hybrid_vector()) { + auto sparse_data_meta = sparse_base_meta_ + index * sizeof(uint64_t); + uint64_t sparse_offset = *((uint64_t *)sparse_data_meta); + uint32_t *sparse_indices = + (uint32_t *)(sparse_base_data_ + sparse_offset + sizeof(uint32_t)); + + return sparse_indices; + } + + return nullptr; + } + + const void *get_sparse_data(size_t index) const { + if (index_meta_.hybrid_vector()) { + auto sparse_data_meta = sparse_base_meta_ + index * sizeof(uint64_t); + uint64_t sparse_offset = *((uint64_t *)sparse_data_meta); + uint32_t sparse_count = + *((uint32_t *)(sparse_base_data_ + sparse_offset)); + void *sparse_data = + (uint32_t *)(sparse_base_data_ + sparse_offset + sizeof(uint32_t) + + sparse_count * sizeof(uint32_t)); + + return sparse_data; + } + + return nullptr; + } + + size_t get_total_sparse_count(void) const { + size_t total_sparse_count = 0; + for (size_t i = 0; i < num_vecs_; ++i) { + total_sparse_count += get_sparse_count(i); + } + + return total_sparse_count; + } + + bool has_taglist(void) const { + return taglist_base_meta_ != nullptr; + } + + uint64_t get_taglist_count(size_t index) const { + if (!taglist_base_data_ || !taglist_base_meta_) { + return 0; + } + + uint64_t taglist_count = *reinterpret_cast( + taglist_base_data_ + taglist_base_meta_[index]); + return taglist_count; + } + + const uint64_t *get_taglist(size_t index) const { + if (!taglist_base_data_ || !taglist_base_meta_) { + return nullptr; + } + + 
return reinterpret_cast(taglist_base_data_ + + taglist_base_meta_[index]) + + 1; + } + + const void *get_taglist_data(size_t &size) const { + size = taglist_size_; + + return taglist_base_meta_; + } + + private: + ailego::MMapFile mmap_file_; + IndexMeta index_meta_; + size_t num_vecs_; + const char *vector_base_; + const uint64_t *key_base_; + const char *sparse_base_meta_; + const char *sparse_base_data_; + const uint32_t *partition_base_; + const char *taglist_base_meta_; + const char *taglist_base_data_; + uint64_t taglist_size_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/diskann/diskann_visit_filter.h b/src/core/algorithm/diskann/diskann_visit_filter.h new file mode 100644 index 000000000..b0b879955 --- /dev/null +++ b/src/core/algorithm/diskann/diskann_visit_filter.h @@ -0,0 +1,419 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace zvec { +namespace core { + +struct VisitFilterHeader { + VisitFilterHeader() : maxDocCnt(0), maxScanNum(0) {} + uint64_t maxDocCnt; + uint64_t maxScanNum; +}; + +constexpr int PROXIMA_DISKANN_VISITFILTER_CUSTOM_PARAMS_INDEX_NEGPROB = 0; + +class VisitBloomFilter { + public: + static constexpr int mode = 1; + + static constexpr int N = 5; + struct Context { + Context() + : mt(std::chrono::system_clock::now().time_since_epoch().count()) {}; + VisitFilterHeader h; + std::mt19937 mt; + ailego::BloomFilter *filter{nullptr}; + int offset[N] = {0}; + }; +#define BLOOM_FILTER_HASH_BITS_OFFSETS(i) \ + i + c->offset[0], i + c->offset[1], i + c->offset[2], i + c->offset[3], \ + i + c->offset[4] + + VisitBloomFilter() = delete; + + inline static void set_visited(Context *c, id_t idx) { + c->filter->force_insert(BLOOM_FILTER_HASH_BITS_OFFSETS(idx)); + return; + } + + inline static void *get_visited(Context *, id_t) { + // TODO + return nullptr; + } + + inline static bool visited(Context *c, id_t idx) { + return c->filter->has(BLOOM_FILTER_HASH_BITS_OFFSETS(idx)); + } + + inline static int set_max_scan_num(Context *c, uint64_t maxScanNum) { + if (maxScanNum == c->h.maxScanNum) { + return 0; + } + c->h.maxScanNum = maxScanNum; + if (c->filter->reset(maxScanNum, c->filter->probability()) != 0) { + LOG_ERROR("reset BloomFilter failed"); + return IndexError_Runtime; + } + genRandomHashBits(c); + return 0; + } + + inline static void clear(Context *c) { + c->filter->clear(); + return; + } + + inline static bool reset(Context *c, uint64_t maxDocCnt, + uint64_t max_scan_num) { + if (ailego_unlikely(maxDocCnt > c->h.maxDocCnt || + max_scan_num > c->h.maxScanNum)) { + // Create a new one, if failed, we can reuse the old one + auto filter = new (std::nothrow) ailego::BloomFilter( + max_scan_num, c->filter->probability()); + if (ailego_unlikely(filter == nullptr)) 
{ + LOG_ERROR("reset bloomfilter failed, maxScanNum %zu prob %f", + (size_t)max_scan_num, c->filter->probability()); + c->filter->clear(); + return false; + } + + delete c->filter; + c->filter = filter; + c->h.maxScanNum = max_scan_num; + c->h.maxDocCnt = maxDocCnt; + genRandomHashBits(c); + } + return true; + } + + inline static void genRandomHashBits(Context *c) { + std::uniform_int_distribution dt(0, c->h.maxDocCnt); + for (size_t i = 0; i < sizeof(c->offset) / sizeof(c->offset[0]); ++i) { + int r = dt(c->mt); + size_t j = 0; + do { // gen distinct number + for (j = 0; j < i; ++j) { + if (c->offset[j] == r) { + r = dt(c->mt); + break; + } + } + } while (j < i); + c->offset[i] = r; + } + std::sort(c->offset, c->offset + N); + } + + template + static int init(Context *, void **ctx, uint64_t maxDocCnt, + uint64_t maxScanNum, std::tuple &&tpl) { + Context *c = new (std::nothrow) Context; + if (c == nullptr) { + LOG_ERROR("New memory in initVisitBitMap failed"); + return IndexError_NoMemory; + } + c->h.maxDocCnt = maxDocCnt; + c->h.maxScanNum = maxScanNum; + float p = + std::get(tpl); + c->filter = new (std::nothrow) + ailego::BloomFilter(maxScanNum, p); + if (c->filter == nullptr) { + LOG_ERROR("New BloomFilter failed, reuse old one"); + return IndexError_NoMemory; + } + genRandomHashBits(c); + *ctx = c; + return 0; + } + + inline static void destroy(Context *c) { + delete c->filter; + delete c; + } +#undef BLOOM_FILTER_HASH_BITS_OFFSETS +}; // end of VisitBloomFilter + +class VisitBitMap { + public: + static constexpr int mode = 2; + + struct Context { + VisitFilterHeader h; + ailego::BitsetHelper bitset; + char *buf{nullptr}; + }; + + VisitBitMap() = delete; + + inline static void set_visited(Context *c, id_t idx) { + c->bitset.set(idx); + return; + } + + inline static void *get_visited(Context *c, id_t idx) { + return &c->buf[idx >> 3]; + } + + inline static bool visited(Context *c, id_t idx) { + return c->bitset.test(idx); + } + + inline static int 
set_max_scan_num(Context *c, uint64_t maxScanNum) { + c->h.maxScanNum = maxScanNum; + return 0; + } + + inline static void clear(Context *c) { + c->bitset.clear(); + return; + } + + inline static bool reset(Context *c, uint64_t maxDocCnt, + uint64_t maxScanNum) { + if (ailego_unlikely(maxDocCnt > c->h.maxDocCnt || + maxScanNum > c->h.maxScanNum)) { + uint64_t len = ((maxDocCnt + 31) >> 5) << 2; // round to uint32_t + auto buf = new (std::nothrow) char[len]; + if (buf == nullptr) { + LOG_ERROR("New memory in initVisitBitMap failed"); + c->bitset.clear(); + return false; + } + + c->h.maxDocCnt = maxDocCnt; + c->h.maxScanNum = maxScanNum; + delete[] c->buf; + c->buf = buf; + memset(c->buf, 0, len); + c->bitset.mount(c->buf, len); + } + return true; + } + + template + static int init(Context *, void **ctx, uint64_t maxDocCnt, + uint64_t maxScanNum, std::tuple &&tpl) { + (void)tpl; // unsed warning + Context *c = new (std::nothrow) Context; + if (c == nullptr) { + LOG_ERROR("New memory in initVisitBitMap failed"); + return IndexError_NoMemory; + } + c->h.maxDocCnt = maxDocCnt; + c->h.maxScanNum = maxScanNum; + uint64_t len = ((maxDocCnt + 31) >> 5) << 2; // round to uint32_t + c->buf = new (std::nothrow) char[len]; + if (c->buf == nullptr) { + LOG_ERROR("New memory in initVisitBitMap failed, reuse old one"); + delete c; + return IndexError_NoMemory; + } + memset(c->buf, 0, len); + c->bitset.mount(c->buf, len); + *ctx = c; + return 0; + } + + inline static void destroy(Context *c) { + delete[] c->buf; + delete c; + } +}; // end of VisitBitMap + +class VisitByteMap { + public: + static constexpr int mode = 3; + struct Context { + VisitFilterHeader h; + uint8_t curNum{0}; + uint8_t *arr{nullptr}; + }; + + VisitByteMap() = delete; + + inline static void set_visited(Context *c, id_t idx) { + c->arr[idx] = c->curNum; + return; + } + + inline static void *get_visited(Context *c, id_t idx) { + return c->arr + idx; + } + + inline static bool visited(Context *c, id_t idx) { + 
return c->arr[idx] == c->curNum; + } + + inline static int set_max_scan_num(Context *c, uint64_t maxScanNum) { + c->h.maxScanNum = maxScanNum; + return 0; + } + + inline static void clear(Context *c) { + c->curNum++; + if (c->curNum == 0) { + memset(c->arr, 0, c->h.maxDocCnt * sizeof(uint8_t)); + c->curNum = 1; + } + return; + } + + inline static bool reset(Context *c, uint64_t maxDocCnt, + uint64_t maxScanNum) { + if (ailego_unlikely(maxDocCnt > c->h.maxDocCnt || + maxScanNum > c->h.maxScanNum)) { + auto arr = new (std::nothrow) uint8_t[maxDocCnt]; + if (arr != nullptr) { + memset(arr, 0, maxDocCnt * sizeof(uint8_t)); + c->curNum = 1; + c->h.maxDocCnt = maxDocCnt; + c->h.maxScanNum = maxScanNum; + delete[] c->arr; + c->arr = arr; + return true; + } + LOG_ERROR("New memory in initVisitByteMap failed, reuse old one"); + } + return true; + } + + template + static int init(Context *, void **ctx, uint64_t maxDocCnt, + uint64_t maxScanNum, std::tuple &&tpl) { + (void)tpl; // unsed warning + Context *c = new (std::nothrow) Context; + if (c == nullptr) { + LOG_ERROR("New memory in initVisitByteMap failed"); + return IndexError_NoMemory; + } + c->h.maxDocCnt = maxDocCnt; + c->h.maxScanNum = maxScanNum; + c->arr = new (std::nothrow) uint8_t[maxDocCnt]; + if (c->arr == nullptr) { + LOG_ERROR("New memory in initVisitByteMap failed"); + delete c; + return IndexError_NoMemory; + } + memset(c->arr, 0, maxDocCnt * sizeof(uint8_t)); + c->curNum = 1; + *ctx = c; + return 0; + } + + inline static void destroy(Context *c) { + delete[] c->arr; + delete c; + } +}; // end of VisitByteMap + + +#define PROXIMA_DISKANN_VISITFILTER_SWITCH_CASE(cls, impl, ctx, ...) \ + case cls::mode: \ + return cls::impl(static_cast(ctx), ##__VA_ARGS__); + +#define PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(impl, ...) 
\ + switch (mode_) { \ + PROXIMA_DISKANN_VISITFILTER_SWITCH_CASE(VisitBloomFilter, impl, ctx_, \ + ##__VA_ARGS__) \ + PROXIMA_DISKANN_VISITFILTER_SWITCH_CASE(VisitBitMap, impl, ctx_, \ + ##__VA_ARGS__) \ + PROXIMA_DISKANN_VISITFILTER_SWITCH_CASE(VisitByteMap, impl, ctx_, \ + ##__VA_ARGS__) \ + } + + +// visit list will be called with high frequency, +// so using switch instead of std::function or virtual class +// funtion point, lambda, virtual class all cannot be inlined +class VisitFilter { + public: + enum Mode { + Default = 0, + BloomFilter = VisitBloomFilter::mode, + BitMap = VisitBitMap::mode, + ByteMap = VisitByteMap::mode + }; + + VisitFilter() : mode_(0), ctx_(nullptr) {}; + + inline bool visited(id_t idx) { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(visited, idx); + return true; // place holder + } + + inline void set_visited(id_t idx) { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(set_visited, idx); + } + + inline void *get_visited(id_t idx) { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(get_visited, idx); + return nullptr; // place holder + } + + inline int set_max_scan_num(id_t idx) { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(set_max_scan_num, idx); + return 0; // place holder + } + + inline void clear() { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(clear); + } + + inline bool reset(uint64_t maxDocCnt, uint64_t maxScanNum) { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(reset, maxDocCnt, maxScanNum); + return true; + } + + inline void destroy() { + if (ctx_ != nullptr) { + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(destroy); + } + } + + int init(int mode, uint64_t maxDocCnt, uint64_t maxScanNum, + float negativeProbility) { + mode_ = mode; + PROXIMA_DISKANN_VISITFILTER_CALL_IMPL(init, &ctx_, maxDocCnt, maxScanNum, + std::make_tuple(negativeProbility)); + return 0; // place holder + } + + int get_mode(void) const { + return mode_; + } + + + private: + VisitFilter(const VisitFilter &) = delete; + VisitFilter &operator=(const VisitFilter &) = delete; + + int mode_{0U}; // custom data 
for each method + void *ctx_{nullptr}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/interface/index_factory.cc b/src/core/interface/index_factory.cc index 0d8157282..dc599c368 100644 --- a/src/core/interface/index_factory.cc +++ b/src/core/interface/index_factory.cc @@ -47,6 +47,10 @@ Index::Pointer IndexFactory::CreateAndInitIndex(const BaseIndexParam ¶m) { ptr = std::make_shared(); } else if (param.index_type == IndexType::kHNSWRabitq) { ptr = std::make_shared(); +#if DISKANN_SUPPORTED + } else if (param.index_type == IndexType::kDiskAnn) { + ptr = std::make_shared(); +#endif } else { LOG_ERROR("Unsupported index type: "); return nullptr; diff --git a/src/core/interface/index_param.cc b/src/core/interface/index_param.cc index e21256ba3..4857b4e49 100644 --- a/src/core/interface/index_param.cc +++ b/src/core/interface/index_param.cc @@ -174,6 +174,27 @@ ailego::JsonObject HNSWRabitqIndexParam::SerializeToJsonObject( return json_obj; } + +bool DiskAnnIndexParam::DeserializeFromJsonObject( + const ailego::JsonObject &json_obj) { + if (!BaseIndexParam::DeserializeFromJsonObject(json_obj)) { + return false; + } + + if (index_type != IndexType::kDiskAnn) { + LOG_ERROR("index_type is not DiskAnn"); + return false; + } + + return true; +} + +ailego::JsonObject DiskAnnIndexParam::SerializeToJsonObject( + bool omit_empty_value) const { + auto json_obj = BaseIndexParam::SerializeToJsonObject(omit_empty_value); + return json_obj; +} + ailego::JsonObject QuantizerParam::SerializeToJsonObject( bool omit_empty_value) const { ailego::JsonObject json_obj; diff --git a/src/core/interface/indexes/diskann_index.cc b/src/core/interface/indexes/diskann_index.cc new file mode 100644 index 000000000..32760a509 --- /dev/null +++ b/src/core/interface/indexes/diskann_index.cc @@ -0,0 +1,276 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with 
the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "algorithm/diskann/diskann_params.h" + +namespace zvec::core_interface { + +static constexpr uint64_t kInvalidKey = std::numeric_limits::max(); + +int DiskAnnIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { + param_ = dynamic_cast(param); + + if (is_sparse_) { + LOG_ERROR("Failed to create streamer. Sparse is not Supported."); + return core::IndexError_Unsupported; + } + + param_ = dynamic_cast(param); + param_.max_degree = std::max(1, std::min(100, param_.max_degree)); + param_.list_size = std::max(10, std::min(100, param_.list_size)); + param_.pq_chunk_num = std::max(1, std::min(1024, param_.pq_chunk_num)); + + proxima_index_params_.set(core::PARAM_DISKANN_BUILDER_MAX_DEGREE, + param_.max_degree); + proxima_index_params_.set(core::PARAM_DISKANN_BUILDER_LIST_SIZE, + param_.list_size); + proxima_index_params_.set(core::PARAM_DISKANN_BUILDER_MAX_PQ_CHUNK_NUM, + param_.pq_chunk_num); + + builder_ = core::IndexFactory::CreateBuilder("DiskAnnBuilder"); + streamer_ = core::IndexFactory::CreateStreamer("DiskAnnStreamer"); + + if (ailego_unlikely(!builder_)) { + LOG_ERROR("Failed to create builder"); + return core::IndexError_Runtime; + } + + if (ailego_unlikely(!streamer_)) { + LOG_ERROR("Failed to create streamer"); + return core::IndexError_Runtime; + } + + IndexMeta real_meta; + if (converter_) { + real_meta = converter_->meta(); + } else { + real_meta = proxima_index_meta_; + } + + if (ailego_unlikely(builder_->init(real_meta, proxima_index_params_) != 0)) { + LOG_ERROR("Failed to init 
builder"); + return core::IndexError_Runtime; + } + if (ailego_unlikely(streamer_->init(real_meta, proxima_index_params_) != 0)) { + LOG_ERROR("Failed to init streamer"); + return core::IndexError_Runtime; + } + + return 0; +} + +int DiskAnnIndex::Open(const std::string &file_path, + StorageOptions storage_options) { + ailego::Params storage_params; + file_path_ = file_path; + is_read_only_ = storage_options.read_only; + switch (storage_options.type) { + // case StorageOptions::StorageType::kDisk: + case StorageOptions::StorageType::kMMAP: { + storage_ = core::IndexFactory::CreateStorage("FileReadStorage"); + if (storage_ == nullptr) { + LOG_ERROR("Failed to create FileReadStorage"); + return core::IndexError_Runtime; + } + int ret = storage_->init(storage_params); + if (ret != 0) { + LOG_ERROR("Failed to init FileReadStorage, path: %s, err: %s", + file_path_.c_str(), core::IndexError::What(ret)); + return ret; + } + break; + } + default: { + LOG_ERROR("Unsupported storage type"); + return core::IndexError_Unsupported; + } + } + + if (!storage_options.create_new) { + // read_options.create_new + int ret = storage_->open(file_path_, false); + if (ret != 0) { + LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path_.c_str(), + core::IndexError::What(ret)); + return core::IndexError_Runtime; + } + if (streamer_ == nullptr || streamer_->open(storage_) != 0) { + LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); + return core::IndexError_Runtime; + } + is_trained_ = true; + } + is_open_ = true; + return 0; +} + +int DiskAnnIndex::GenerateHolder() { + if (param_.data_type == DataType::DT_FP16) { + auto holder = + std::make_shared>( + param_.dimension); + for (auto doc : doc_cache_) { + ailego::NumericalVector vec(doc.second); + if (doc.first == kInvalidKey) { + continue; + } + if (!holder->emplace(doc.first, vec)) { + LOG_ERROR("Failed to add vector"); + return core::IndexError_Runtime; + } + } + holder_ = holder; + } else if (param_.data_type 
== DataType::DT_FP32) { + auto holder = + std::make_shared>( + param_.dimension); + for (auto doc : doc_cache_) { + ailego::NumericalVector vec(doc.second); + if (doc.first == kInvalidKey) { + continue; + } + if (!holder->emplace(doc.first, vec)) { + LOG_ERROR("Failed to add vector"); + return core::IndexError_Runtime; + } + } + holder_ = holder; + } else if (param_.data_type == DataType::DT_INT8) { + auto holder = + std::make_shared>( + param_.dimension); + for (auto doc : doc_cache_) { + ailego::NumericalVector vec(doc.second); + if (doc.first == kInvalidKey) { + continue; + } + if (!holder->emplace(doc.first, vec)) { + LOG_ERROR("Failed to add vector"); + return core::IndexError_Runtime; + } + } + holder_ = holder; + } else { + LOG_ERROR("data_type is not support"); + return core::IndexError_Runtime; + } + if (converter_) { + core::IndexConverter::TrainAndTransform(converter_, holder_); + holder_ = converter_->result(); + } + return 0; +} + +int DiskAnnIndex::Add(const VectorData &vector, uint32_t doc_id) { + if (is_trained_) { + LOG_ERROR("this diskann index is trained"); + return core::IndexError_Runtime; + } + if (!std::holds_alternative(vector.vector)) { + LOG_ERROR("Invalid vector data"); + return core::IndexError_Runtime; + } + const DenseVector &dense_vector = std::get(vector.vector); + std::string out_vector_buffer = std::string( + static_cast(dense_vector.data), + input_vector_meta_.dimension() * input_vector_meta_.unit_size()); + + std::lock_guard lock(mutex_); + while (doc_cache_.size() <= doc_id) { + std::string fake_data( + input_vector_meta_.dimension() * input_vector_meta_.unit_size(), 0); + doc_cache_.push_back(std::make_pair(kInvalidKey, fake_data)); + } + doc_cache_[doc_id] = std::make_pair(doc_id, out_vector_buffer); + return 0; +} + +int DiskAnnIndex::Train() { + GenerateHolder(); + builder_->train(holder_); + builder_->build(holder_); + auto dumper = core::IndexFactory::CreateDumper("FileDumper"); + + dumper->create(file_path_); + 
builder_->dump(dumper); + dumper->close(); + int ret = storage_->open(file_path_, false); + if (ret != 0) { + LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path_.c_str(), + core::IndexError::What(ret)); + return core::IndexError_Runtime; + } + if (streamer_ == nullptr || streamer_->open(storage_) != 0) { + LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); + return core::IndexError_Runtime; + } + is_trained_ = true; + return 0; +} + +int DiskAnnIndex::_dense_fetch(const uint32_t doc_id, + VectorDataBuffer *vector_data_buffer) { + if (is_trained_) { + return Index::_dense_fetch(doc_id, vector_data_buffer); + } else { + DenseVectorBuffer dense_vector_buffer; + std::string &out_vector_buffer = dense_vector_buffer.data; + out_vector_buffer = doc_cache_[doc_id].second; + vector_data_buffer->vector_buffer = std::move(dense_vector_buffer); + return 0; + } +} + +int DiskAnnIndex::_prepare_for_search( + const VectorData & /*query*/, + const BaseIndexQueryParam::Pointer &search_param, + core::IndexContext::Pointer &context) { + const auto &diskann_search_param = + std::dynamic_pointer_cast(search_param); + + context->set_topk(diskann_search_param->topk); + + return 0; +} + +int DiskAnnIndex::Merge(const std::vector &indexes, + const IndexFilter &filter, + const MergeOptions &options) { + int pre_ret = Index::Merge(indexes, filter, options); + if (pre_ret != 0) { + return pre_ret; + } + auto dumper = core::IndexFactory::CreateDumper("FileDumper"); + + dumper->create(file_path_); + builder_->dump(dumper); + dumper->close(); + int ret = storage_->open(file_path_, false); + if (ret != 0) { + LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path_.c_str(), + core::IndexError::What(ret)); + return core::IndexError_Runtime; + } + if (streamer_ == nullptr || streamer_->open(storage_) != 0) { + LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); + return core::IndexError_Runtime; + } + is_trained_ = true; + return 0; +} + +} // 
namespace zvec::core_interface \ No newline at end of file diff --git a/src/core/utility/file_dumper.cc b/src/core/utility/file_dumper.cc index b53fc09a3..1866f9352 100644 --- a/src/core/utility/file_dumper.cc +++ b/src/core/utility/file_dumper.cc @@ -99,6 +99,11 @@ struct FileDumper : public IndexDumper { return packer_.magic(); } + //! Retrieve size of index + size_t size(void) const override { + return file_.size(); + } + protected: //! Close index file bool close_index(void) { diff --git a/src/core/utility/file_read_storage.cc b/src/core/utility/file_read_storage.cc index 3459df0ef..b27ba7b62 100644 --- a/src/core/utility/file_read_storage.cc +++ b/src/core/utility/file_read_storage.cc @@ -82,6 +82,11 @@ class FileReadStorage : public IndexStorage { return region_size_; } + //! Retrieve offset of data + size_t data_offset(void) const override { + return data_offset_; + } + //! Fetch data from segment (with own buffer) size_t fetch(size_t offset, void *buf, size_t len) const override { if (ailego_unlikely(offset + len > region_size_)) { @@ -285,7 +290,7 @@ class FileReadStorage : public IndexStorage { } int flush(void) override { - return IndexError_NotImplemented; + return 0; } int append(const std::string & /*id*/, size_t /*size*/) override { @@ -385,6 +390,15 @@ class FileReadStorage : public IndexStorage { return magic_; } + //! Retrieve file ptr if has + std::shared_ptr file(void) const { + return file_ptr_; + } + + std::string file_path(void) const { + return file_path_; + } + protected: //! 
Open a index file static inline std::shared_ptr OpenFile(const std::string &path, diff --git a/src/core/utility/mmap_file_read_storage.cc b/src/core/utility/mmap_file_read_storage.cc index a1a2c92a9..411e46c33 100644 --- a/src/core/utility/mmap_file_read_storage.cc +++ b/src/core/utility/mmap_file_read_storage.cc @@ -39,6 +39,7 @@ class MMapFileReadStorage : public IndexStorage { : data_ptr_(reinterpret_cast(file_ptr->region()) + offset + segment.data_offset()), data_size_(segment.data_size()), + data_offset_(offset + segment.data_offset()), padding_size_(segment.padding_size()), region_size_(segment.data_size() + segment.padding_size()), data_crc_(segment.data_crc()), @@ -66,6 +67,11 @@ class MMapFileReadStorage : public IndexStorage { return region_size_; } + //! Retrieve offset of data + size_t data_offset(void) const override { + return data_offset_; + } + //! Fetch data from segment (with own buffer) size_t fetch(size_t offset, void *buf, size_t len) const override { if (ailego_unlikely(offset + len > region_size_)) { @@ -130,6 +136,7 @@ class MMapFileReadStorage : public IndexStorage { private: const uint8_t *data_ptr_{nullptr}; size_t data_size_{0u}; + size_t data_offset_{0}; size_t padding_size_{0u}; size_t region_size_{0u}; uint32_t data_crc_{0u}; diff --git a/src/db/index/column/vector_column/engine_helper.hpp b/src/db/index/column/vector_column/engine_helper.hpp index e7e323f0c..458bd2076 100644 --- a/src/db/index/column/vector_column/engine_helper.hpp +++ b/src/db/index/column/vector_column/engine_helper.hpp @@ -202,6 +202,18 @@ class ProximaEngineHelper { } return std::move(ivf_query_param); } + + case IndexType::DISKANN: { + auto diskann_query_param_result = + _build_common_query_param( + query_params); + if (!diskann_query_param_result.has_value()) { + return tl::make_unexpected(Status::InvalidArgument( + "failed to build query param: " + + diskann_query_param_result.error().message())); + } + return std::move(diskann_query_param_result.value()); + } 
default: return tl::make_unexpected(Status::InvalidArgument("not supported")); } @@ -395,6 +407,24 @@ class ProximaEngineHelper { return index_param_builder->Build(); } + case IndexType::DISKANN: { + auto index_param_builder_result = + _build_common_index_param( + field_schema); + if (!index_param_builder_result.has_value()) { + return tl::make_unexpected(Status::InvalidArgument( + "failed to build index param: " + + index_param_builder_result.error().message())); + } + auto index_param_builder = index_param_builder_result.value(); + + // auto db_index_params = dynamic_cast( + // field_schema.index_params().get()); + + return index_param_builder->Build(); + } + default: return tl::make_unexpected(Status::InvalidArgument("not supported")); } diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index bcb401d5f..41cfc1859 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -1156,7 +1156,7 @@ struct Doc::ValueEqual { const std::vector &b) const { if (a.size() != b.size()) return false; for (size_t i = 0; i < a.size(); ++i) - if (std::fabs(a[i] - b[i]) >= 1e-4f) return false; + if (std::fabs(a[i] - b[i]) >= 5e-4f) return false; return true; } diff --git a/src/db/index/common/proto_converter.cc b/src/db/index/common/proto_converter.cc index 46eb93f5a..b4b0e752c 100644 --- a/src/db/index/common/proto_converter.cc +++ b/src/db/index/common/proto_converter.cc @@ -116,6 +116,28 @@ proto::InvertIndexParams ProtoConverter::ToPb(const InvertIndexParams *params) { return params_pb; } +// DiskAnnIndexParams +DiskAnnIndexParams::OPtr ProtoConverter::FromPb( + const proto::DiskAnnIndexParams ¶ms_pb) { + return std::make_shared( + MetricTypeCodeBook::Get(params_pb.base().metric_type()), + params_pb.max_degree(), params_pb.list_size(), params_pb.pq_chunk_num(), + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); +} + +proto::DiskAnnIndexParams ProtoConverter::ToPb( + const DiskAnnIndexParams *params) { + proto::DiskAnnIndexParams 
params_pb; + params_pb.mutable_base()->set_metric_type( + MetricTypeCodeBook::Get(params->metric_type())); + params_pb.mutable_base()->set_quantize_type( + QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.set_max_degree(params->max_degree()); + params_pb.set_list_size(params->list_size()); + params_pb.set_pq_chunk_num(params->pq_chunk_num()); + return params_pb; +} + // FieldSchema FieldSchema::Ptr ProtoConverter::FromPb(const proto::FieldSchema &schema_pb) { auto schema = std::make_shared(); @@ -185,6 +207,8 @@ IndexParams::Ptr ProtoConverter::FromPb(const proto::IndexParams ¶ms_pb) { return ProtoConverter::FromPb(params_pb.flat()); } else if (params_pb.has_hnsw_rabitq()) { return ProtoConverter::FromPb(params_pb.hnsw_rabitq()); + } else if (params_pb.has_diskann()) { + return ProtoConverter::FromPb(params_pb.diskann()); } return nullptr; @@ -246,6 +270,15 @@ proto::IndexParams ProtoConverter::ToPb(const IndexParams *params) { params_pb.mutable_hnsw_rabitq()->CopyFrom( ProtoConverter::ToPb(hnsw_rabitq_params)); } + break; + } + case IndexType::DISKANN: { + auto diskann_params = dynamic_cast(params); + if (diskann_params) { + params_pb.mutable_diskann()->CopyFrom( + ProtoConverter::ToPb(diskann_params)); + } + break; } default: break; diff --git a/src/db/index/common/proto_converter.h b/src/db/index/common/proto_converter.h index ad96007a4..dc6a09621 100644 --- a/src/db/index/common/proto_converter.h +++ b/src/db/index/common/proto_converter.h @@ -43,6 +43,11 @@ struct ProtoConverter { const proto::InvertIndexParams ¶ms_pb); static proto::InvertIndexParams ToPb(const InvertIndexParams *params); + // DiskAnnIndexParams + static DiskAnnIndexParams::OPtr FromPb( + const proto::DiskAnnIndexParams ¶ms_pb); + static proto::DiskAnnIndexParams ToPb(const DiskAnnIndexParams *params); + // IndexParams static IndexParams::Ptr FromPb(const proto::IndexParams ¶ms_pb); static proto::IndexParams ToPb(const IndexParams *params); diff --git 
a/src/db/index/common/schema.cc b/src/db/index/common/schema.cc index 971de61d6..45f76df54 100644 --- a/src/db/index/common/schema.cc +++ b/src/db/index/common/schema.cc @@ -54,7 +54,8 @@ std::unordered_set support_sparse_vector_type = { }; std::unordered_set support_dense_vector_index = { - IndexType::FLAT, IndexType::HNSW, IndexType::HNSW_RABITQ, IndexType::IVF}; + IndexType::FLAT, IndexType::HNSW, IndexType::HNSW_RABITQ, IndexType::IVF, + IndexType::DISKANN}; std::unordered_set support_sparse_vector_index = {IndexType::FLAT, IndexType::HNSW}; diff --git a/src/db/index/common/type_helper.h b/src/db/index/common/type_helper.h index 33d2ee344..ed520fe1b 100644 --- a/src/db/index/common/type_helper.h +++ b/src/db/index/common/type_helper.h @@ -35,6 +35,8 @@ struct IndexTypeCodeBook { return IndexType::IVF; case proto::IT_INVERT: return IndexType::INVERT; + case proto::IT_DISKANN: + return IndexType::DISKANN; default: break; } @@ -54,6 +56,8 @@ struct IndexTypeCodeBook { return proto::IT_IVF; case IndexType::INVERT: return proto::IT_INVERT; + case IndexType::DISKANN: + return proto::IT_DISKANN; default: break; } @@ -71,6 +75,8 @@ struct IndexTypeCodeBook { return "FLAT"; case IndexType::IVF: return "IVF"; + case IndexType::DISKANN: + return "DISKANN"; case IndexType::INVERT: return "INVERT"; default: diff --git a/src/db/proto/zvec.proto b/src/db/proto/zvec.proto index 3c9d33319..1294e89c1 100644 --- a/src/db/proto/zvec.proto +++ b/src/db/proto/zvec.proto @@ -58,6 +58,8 @@ enum IndexType { IT_FLAT = 3; // Proxima HNSW RABITQ Index IT_HNSW_RABITQ = 4; + // Proxima DiskAnn Index + IT_DISKANN = 5; // Invert Index IT_INVERT = 10; }; @@ -112,6 +114,13 @@ message IVFIndexParams { bool use_soar = 4; } +message DiskAnnIndexParams { + BaseIndexParams base = 1; + int32 max_degree = 2; + int32 list_size = 3; + int32 pq_chunk_num = 4; +} + message IndexParams { oneof params { InvertIndexParams invert = 1; @@ -119,6 +128,7 @@ message IndexParams { FlatIndexParams flat = 3; 
IVFIndexParams ivf = 4; HnswRabitqIndexParams hnsw_rabitq = 5; + DiskAnnIndexParams diskann = 6; }; }; diff --git a/src/include/zvec/core/framework/index_document.h b/src/include/zvec/core/framework/index_document.h index aa74e6064..686488e18 100644 --- a/src/include/zvec/core/framework/index_document.h +++ b/src/include/zvec/core/framework/index_document.h @@ -98,6 +98,23 @@ class IndexDocument { vector_(vector), sparse_doc_(std::move(sparse_doc)) {} + //! Constructor + IndexDocument(uint64_t k, float v, uint32_t i, std::string vector_string) + : key_(k), + score_(v), + index_(i), + vector_string_(std::move(vector_string)) {} + + //! Constructor + IndexDocument(uint64_t k, float v, uint32_t i, std::string vector_string, + IndexSparseDocument sparse_doc) + : key_(k), + score_(v), + index_(i), + vector_{nullptr}, + vector_string_(std::move(vector_string)), + sparse_doc_(std::move(sparse_doc)) {} + IndexDocument(uint64_t k, float v, uint32_t i, IndexStorage::MemoryBlock vec_block, IndexSparseDocument sparse_doc) @@ -116,6 +133,7 @@ class IndexDocument { score_(rhs.score_), index_(rhs.index_), vector_(rhs.vector_), + vector_string_(std::move(rhs.vector_string_)), sparse_doc_{rhs.sparse_doc_} { if (rhs.has_vec_mem_block_) { vec_mem_block_ = rhs.vec_mem_block_; @@ -130,6 +148,7 @@ class IndexDocument { score_ = rhs.score_; index_ = rhs.index_; vector_ = rhs.vector_; + vector_string_ = rhs.vector_string_; if (rhs.has_vec_mem_block_) { vec_mem_block_ = rhs.vec_mem_block_; has_vec_mem_block_ = true; @@ -169,6 +188,11 @@ class IndexDocument { return vector_; } + //! Retrieve vec string + const std::string &vector_string() const { + return vector_string_; + } + //! 
Retrieve vec const IndexSparseDocument &sparse_doc() const { return sparse_doc_; @@ -210,6 +234,7 @@ class IndexDocument { float score_{0.0f}; uint32_t index_{0u}; const void *vector_{nullptr}; + std::string vector_string_{}; bool has_vec_mem_block_{false}; mutable IndexStorage::MemoryBlock vec_mem_block_{}; IndexSparseDocument sparse_doc_{}; diff --git a/src/include/zvec/core/framework/index_dumper.h b/src/include/zvec/core/framework/index_dumper.h index a638adcf9..53b9b5374 100644 --- a/src/include/zvec/core/framework/index_dumper.h +++ b/src/include/zvec/core/framework/index_dumper.h @@ -51,6 +51,11 @@ class IndexDumper : public IndexModule { //! Retrieve magic number of index virtual uint32_t magic(void) const = 0; + + //! Retrieve size + virtual size_t size(void) const { + return 0; + } }; /*! Index Segment Dumper diff --git a/src/include/zvec/core/framework/index_runner.h b/src/include/zvec/core/framework/index_runner.h index 72d39e979..5406bf41e 100644 --- a/src/include/zvec/core/framework/index_runner.h +++ b/src/include/zvec/core/framework/index_runner.h @@ -447,6 +447,12 @@ class IndexRunner : public IndexModule { return nullptr; } + //! Get vector by key + virtual int get_vector(uint64_t /*key*/, Context::Pointer & /*context*/, + std::string & /*vector*/) const { + return IndexError_NotImplemented; + } + virtual int get_vector_by_id(const uint32_t /*id*/, IndexStorage::MemoryBlock & /*block*/) const { return IndexError_NotImplemented; diff --git a/src/include/zvec/core/framework/index_segment_storage.h b/src/include/zvec/core/framework/index_segment_storage.h index 82b316d1b..56c01b40c 100644 --- a/src/include/zvec/core/framework/index_segment_storage.h +++ b/src/include/zvec/core/framework/index_segment_storage.h @@ -108,6 +108,11 @@ class IndexSegmentStorage : public IndexStorage { return IndexError_NotImplemented; } + //! 
Retrieve offset of data + virtual size_t data_offset(void) const { + return 0; + } + void update_data_crc(uint32_t) override { return; } diff --git a/src/include/zvec/core/framework/index_storage.h b/src/include/zvec/core/framework/index_storage.h index 8273004a3..808abbf8b 100644 --- a/src/include/zvec/core/framework/index_storage.h +++ b/src/include/zvec/core/framework/index_storage.h @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -184,6 +185,11 @@ class IndexStorage : public IndexModule { //! Retrieve size of data virtual size_t data_size(void) const = 0; + //! Retrieve offset of data + virtual size_t data_offset(void) const { + return 0; + } + //! Retrieve crc of data virtual uint32_t data_crc(void) const = 0; @@ -264,6 +270,15 @@ class IndexStorage : public IndexModule { virtual bool isHugePage(void) const { return false; } + + //! Retrieve file ptr if has + virtual std::shared_ptr file(void) const { + return nullptr; + } + + virtual std::string file_path(void) const { + return ""; + } }; } // namespace core diff --git a/src/include/zvec/core/interface/constants.h b/src/include/zvec/core/interface/constants.h index 1bc61dce6..402ad2cd6 100644 --- a/src/include/zvec/core/interface/constants.h +++ b/src/include/zvec/core/interface/constants.h @@ -26,5 +26,8 @@ constexpr static uint32_t kDefaultHnswEfSearch = 300; constexpr const uint32_t kDefaultRabitqTotalBits = 7; constexpr const uint32_t kDefaultRabitqNumClusters = 16; +constexpr const uint32_t kDefaultDiskAnnMaxDegree = 100; +constexpr const uint32_t kDefaultDiskAnnListSize = 200; +constexpr const uint32_t kDefaultDiskAnnPqChunkNum = 16; } // namespace zvec::core_interface \ No newline at end of file diff --git a/src/include/zvec/core/interface/index.h b/src/include/zvec/core/interface/index.h index 8634e3904..0e95087d1 100644 --- a/src/include/zvec/core/interface/index.h +++ b/src/include/zvec/core/interface/index.h @@ -316,5 +316,37 @@ class HNSWRabitqIndex : public Index { 
HNSWRabitqIndexParam param_{}; }; +class DiskAnnIndex : public Index { + public: + DiskAnnIndex() = default; + + protected: + virtual int CreateAndInitStreamer(const BaseIndexParam ¶m) override; + + virtual int _prepare_for_search( + const VectorData &query, const BaseIndexQueryParam::Pointer &search_param, + core::IndexContext::Pointer &context) override; + + virtual int Add(const VectorData &vector, uint32_t doc_id) override; + + virtual int Train() override; + + virtual int Open(const std::string &file_path, + StorageOptions storage_options) override; + + virtual int _dense_fetch(const uint32_t doc_id, + VectorDataBuffer *vector_data_buffer) override; + virtual int Merge(const std::vector &indexes, + const IndexFilter &filter, + const MergeOptions &options) override; + int GenerateHolder(); + + private: + DiskAnnIndexParam param_{}; + std::mutex mutex_{}; + std::vector> doc_cache_; + core::IndexHolder::Pointer holder_{}; + std::string file_path_; +}; } // namespace zvec::core_interface diff --git a/src/include/zvec/core/interface/index_param.h b/src/include/zvec/core/interface/index_param.h index 0d7bf3017..aeb364700 100644 --- a/src/include/zvec/core/interface/index_param.h +++ b/src/include/zvec/core/interface/index_param.h @@ -63,6 +63,7 @@ enum class IndexType { kIVF, // it's actual a two-layer index kHNSW, kHNSWRabitq, + kDiskAnn, }; enum class IVFSearchMethod { kBF, kHNSW }; @@ -214,6 +215,14 @@ struct IVFQueryParam : public BaseIndexQueryParam { } }; +struct DiskAnnQueryParam : public BaseIndexQueryParam { + using Pointer = std::shared_ptr; + + BaseIndexQueryParam::Pointer Clone() const override { + return std::make_shared(*this); + } +}; + // --- Construction Parameters --- // template class BaseIndexParam : public SerializableBase { @@ -362,4 +371,27 @@ struct HNSWRabitqIndexParam : public BaseIndexParam { bool omit_empty_value = false) const override; }; +struct DiskAnnIndexParam : public BaseIndexParam { + using Pointer = std::shared_ptr; + + int 
max_degree = kDefaultDiskAnnMaxDegree; + int list_size = kDefaultDiskAnnListSize; + int pq_chunk_num = kDefaultDiskAnnPqChunkNum; + + // Constructors with delegation + DiskAnnIndexParam() : BaseIndexParam(IndexType::kDiskAnn) {} + + DiskAnnIndexParam(MetricType metric, int dim, int max_degree, int list_size, + int pq_chunk_num) + : BaseIndexParam(IndexType::kDiskAnn, metric, dim), + max_degree(max_degree), + list_size(list_size), + pq_chunk_num(pq_chunk_num) {} + + protected: + bool DeserializeFromJsonObject(const ailego::JsonObject &json_obj) override; + ailego::JsonObject SerializeToJsonObject( + bool omit_empty_value = false) const override; +}; + } // namespace zvec::core_interface \ No newline at end of file diff --git a/src/include/zvec/core/interface/index_param_builders.h b/src/include/zvec/core/interface/index_param_builders.h index e22ecb392..7baed42d8 100644 --- a/src/include/zvec/core/interface/index_param_builders.h +++ b/src/include/zvec/core/interface/index_param_builders.h @@ -191,6 +191,23 @@ class HNSWRabitqIndexParamBuilder } }; +class DiskAnnIndexParamBuilder + : public BaseIndexParamBuilder { + public: + DiskAnnIndexParamBuilder() = default; + DiskAnnIndexParamBuilder &WithMaxDegree(int max_degree) { + param->max_degree = max_degree; + return *this; + } + DiskAnnIndexParamBuilder &WithPqChunkNum(int pq_chunk_num) { + param->pq_chunk_num = pq_chunk_num; + return *this; + } + std::shared_ptr Build() override { + return param; + } +}; // class CompositeIndexParamBuilder : public // BaseIndexParamBuilder // { public: diff --git a/src/include/zvec/db/index_params.h b/src/include/zvec/db/index_params.h index fcccf080d..b8519b04c 100644 --- a/src/include/zvec/db/index_params.h +++ b/src/include/zvec/db/index_params.h @@ -46,7 +46,8 @@ class IndexParams { bool is_vector_index_type() const { return type_ == IndexType::FLAT || type_ == IndexType::HNSW || - type_ == IndexType::HNSW_RABITQ || type_ == IndexType::IVF; + type_ == IndexType::HNSW_RABITQ || 
type_ == IndexType::IVF || + type_ == IndexType::DISKANN; } IndexType type() const { @@ -428,4 +429,76 @@ class IVFIndexParams : public VectorIndexParams { bool use_soar_; }; +class DiskAnnIndexParams : public VectorIndexParams { + public: + DiskAnnIndexParams(MetricType metric_type, int max_degree = 100, + int list_size = 50, int pq_chunk_num = 0, + QuantizeType quantize_type = QuantizeType::UNDEFINED) + : VectorIndexParams(IndexType::DISKANN, metric_type, quantize_type), + max_degree_{max_degree}, + list_size_{list_size}, + pq_chunk_num_{pq_chunk_num} {} + + using OPtr = std::shared_ptr; + + public: + Ptr clone() const override { + return std::make_shared( + metric_type_, max_degree_, list_size_, pq_chunk_num_, quantize_type_); + } + + std::string to_string() const override { + auto base_str = vector_index_params_to_string("DiskAnnIndexParams", + metric_type_, quantize_type_); + std::ostringstream oss; + oss << base_str << ",max_degree:" << max_degree_ + << ",list_size:" << list_size_ << ", pq_chunk_num:" << pq_chunk_num_ + << "}"; + return oss.str(); + } + + int max_degree() const { + return max_degree_; + } + + void set_max_degree(int max_degree) { + max_degree_ = max_degree; + } + + int list_size() const { + return list_size_; + } + + void set_list_size(int list_size) { + list_size_ = list_size; + } + + int pq_chunk_num() const { + return pq_chunk_num_; + } + + void pq_chunk_num(int pq_chunk_num) { + pq_chunk_num_ = pq_chunk_num; + } + + bool operator==(const IndexParams &other) const override { + return type() == other.type() && + metric_type() == + static_cast(other).metric_type() && + max_degree_ == + static_cast(other).max_degree_ && + list_size_ == + static_cast(other).list_size_ && + pq_chunk_num_ == + static_cast(other).pq_chunk_num_ && + quantize_type() == + static_cast(other).quantize_type(); + } + + private: + int max_degree_; + int list_size_; + int pq_chunk_num_; +}; + } // namespace zvec \ No newline at end of file diff --git 
a/src/include/zvec/db/query_params.h b/src/include/zvec/db/query_params.h index ba62dab9c..a4e06c841 100644 --- a/src/include/zvec/db/query_params.h +++ b/src/include/zvec/db/query_params.h @@ -172,4 +172,24 @@ class FlatQueryParams : public QueryParams { float scale_factor_{10}; }; +class DiskAnnQueryParams : public QueryParams { + public: + DiskAnnQueryParams(int list_size = 300) : QueryParams(IndexType::DISKANN) { + set_list_size(list_size); + } + + virtual ~DiskAnnQueryParams() = default; + + int list_size() const { + return list_size_; + } + + void set_list_size(int list_size) { + list_size_ = list_size; + } + + private: + int list_size_; +}; + } // namespace zvec \ No newline at end of file diff --git a/src/include/zvec/db/type.h b/src/include/zvec/db/type.h index 1578f81d8..2b1d6b79a 100644 --- a/src/include/zvec/db/type.h +++ b/src/include/zvec/db/type.h @@ -26,6 +26,7 @@ enum class IndexType : uint32_t { IVF = 2, FLAT = 3, HNSW_RABITQ = 4, + DISKANN = 5, INVERT = 10, }; diff --git a/tests/core/algorithm/CMakeLists.txt b/tests/core/algorithm/CMakeLists.txt index 9ef1ec2a0..07c3687e6 100644 --- a/tests/core/algorithm/CMakeLists.txt +++ b/tests/core/algorithm/CMakeLists.txt @@ -7,6 +7,14 @@ cc_directories(flat_sparse) cc_directories(ivf) cc_directories(hnsw) cc_directories(hnsw_sparse) + +if(DISKANN_SUPPORTED) + message(STATUS "build diskann tests") + cc_directory(diskann) +else() + message(STATUS "skip diskann tests (unsupported platform)") +endif() + if(RABITQ_SUPPORTED) -cc_directories(hnsw_rabitq) + cc_directories(hnsw_rabitq) endif() diff --git a/tests/core/algorithm/cluster/multi_chunk_cluster_test.cc b/tests/core/algorithm/cluster/multi_chunk_cluster_test.cc new file mode 100644 index 000000000..8a3b6c96b --- /dev/null +++ b/tests/core/algorithm/cluster/multi_chunk_cluster_test.cc @@ -0,0 +1,221 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cluster/multi_chunk_cluster.h" +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_framework.h" + +using namespace zvec::core; +using namespace zvec::ailego; +using namespace zvec::ailego; +TEST(MultiChunkCluster, General) { + const uint32_t count = 10000u; + const uint32_t dimension = 960u; + const uint32_t chunk_count = 480u; + const uint32_t cluster_count = 256u; + // const uint32_t thread_count = 4; + const uint32_t thread_count = 16; + + IndexMeta index_meta; + + index_meta.set_meta(IndexMeta::DataType::DT_FP32, dimension); + index_meta.set_metric("SquaredEuclidean", 0, Params()); + + std::shared_ptr features( + new CompactIndexFeatures(index_meta)); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(0.0, 5.0); + + for (uint32_t i = 0; i < count; ++i) { + std::vector vec(dimension); + for (size_t j = 0; j < dimension; ++j) { + vec[j] = dist(gen); + } + features->emplace(vec.data()); + } + + // Create a Kmeans cluster + MultiChunkCluster cluster = MultiChunkCluster(); + + Params params; + params.set("proxima.cluster.multi_chunk_cluster.count", cluster_count); + params.set("proxima.cluster.multi_chunk_cluster.chunk_count", chunk_count); + params.set("proxima.cluster.multi_chunk_cluster.thread_count", thread_count); + + ASSERT_EQ(0, cluster.init(index_meta, params)); + ASSERT_EQ(0, cluster.mount(features)); + + IndexCluster::CentroidList centroids; + std::vector labels; + + ASSERT_EQ(0, cluster.cluster(nullptr, centroids)); + + for 
(size_t chunk = 0; chunk < chunk_count; ++chunk) { + for (size_t cluster = 0; cluster < cluster_count; ++cluster) { + size_t idx = chunk * cluster_count + cluster; + const auto ¢ = centroids[idx]; + const auto &vec = cent.vector(); + + std::cout << "chunk: " << chunk << ", cluster: " << cluster + << ", dim: " << vec.size() << ", count: " << cent.follows() + << " (" << cent.score() << ") { " << vec[0] << "," << vec[1] + << " }" << std::endl; + ASSERT_EQ(0u, cent.similars().size()); + } + } + + ASSERT_EQ(0, cluster.label(nullptr, centroids, &labels)); +} + +TEST(MultiChunkCluster, TestChunk) { + const uint32_t count = 10000u; + const uint32_t dimension = 95; + const uint32_t chunk_count = 20u; + const uint32_t cluster_count = 256u; + // const uint32_t thread_count = 4; + const uint32_t thread_count = 16; + + IndexMeta index_meta; + + index_meta.set_meta(IndexMeta::DataType::DT_FP32, dimension); + index_meta.set_metric("SquaredEuclidean", 0, Params()); + + std::shared_ptr features( + new CompactIndexFeatures(index_meta)); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(0.0, 5.0); + + for (uint32_t i = 0; i < count; ++i) { + std::vector vec(dimension); + for (size_t j = 0; j < dimension; ++j) { + vec[j] = dist(gen); + } + features->emplace(vec.data()); + } + + // Create a Kmeans cluster + MultiChunkCluster cluster = MultiChunkCluster(); + + Params params; + params.set("proxima.cluster.multi_chunk_cluster.count", cluster_count); + params.set("proxima.cluster.multi_chunk_cluster.chunk_count", chunk_count); + params.set("proxima.cluster.multi_chunk_cluster.thread_count", thread_count); + + ASSERT_EQ(0, cluster.init(index_meta, params)); + ASSERT_EQ(0, cluster.mount(features)); + + IndexCluster::CentroidList centroids; + std::vector labels; + + ASSERT_EQ(0, cluster.cluster(nullptr, centroids)); + + for (size_t chunk = 0; chunk < chunk_count; ++chunk) { + for (size_t cluster = 0; cluster < cluster_count; ++cluster) { + size_t idx = 
chunk * cluster_count + cluster; + const auto ¢ = centroids[idx]; + const auto &vec = cent.vector(); + + std::cout << "chunk: " << chunk << ", cluster: " << cluster + << ", dim: " << vec.size() << ", count: " << cent.follows() + << " (" << cent.score() << ") { " << vec[0] << "," << vec[1] + << " }" << std::endl; + ASSERT_EQ(0u, cent.similars().size()); + } + } + + ASSERT_EQ(0, cluster.label(nullptr, centroids, &labels)); +} + +TEST(MultiChunkCluster, General_InnerProduct) { + const uint32_t count = 50000u; + const uint32_t dimension = 96u; + const uint32_t chunk_count = 12u; + const uint32_t cluster_count = 16u; + const uint32_t chain_length = 0; + const uint32_t thread_count = 16; + + IndexMeta index_meta; + + index_meta.set_meta(IndexMeta::DataType::DT_FP32, dimension); + index_meta.set_metric("InnerProduct", 0, Params()); + + std::shared_ptr features( + new CompactIndexFeatures(index_meta)); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(-1.0, 1.0); + + // do normalize + for (uint32_t i = 0; i < count; ++i) { + std::vector vec(dimension); + + float norm = 0; + for (size_t j = 0; j < dimension; ++j) { + vec[j] = dist(gen); + + norm += vec[j] * vec[j]; + } + norm = sqrt(norm); + + for (size_t j = 0; j < dimension; ++j) { + vec[j] /= norm; + } + + features->emplace(vec.data()); + } + + // Create a Kmeans cluster + MultiChunkCluster cluster = MultiChunkCluster(); + + Params params; + params.set("proxima.cluster.multi_chunk_cluster.count", cluster_count); + params.set("proxima.cluster.multi_chunk_cluster.chunk_count", chunk_count); + params.set("proxima.cluster.multi_chunk_cluster.thread_count", thread_count); + params.set("proxima.cluster.multi_chunk_cluster.markov_chain_length", + chain_length); + + ASSERT_EQ(0, cluster.init(index_meta, params)); + ASSERT_EQ(0, cluster.mount(features)); + + IndexCluster::CentroidList centroids; + std::vector labels; + + ASSERT_EQ(0, cluster.cluster(nullptr, centroids)); + + for (size_t 
chunk = 0; chunk < chunk_count; ++chunk) { + for (size_t cluster = 0; cluster < cluster_count; ++cluster) { + size_t idx = chunk * cluster_count + cluster; + const auto ¢ = centroids[idx]; + const auto &vec = cent.vector(); + + std::cout << "chunk: " << chunk << ", cluster: " << cluster + << ", dim: " << vec.size() << ", count: " << cent.follows() + << " (" << cent.score() << ") { " << vec[0] << ", " << vec[1] + << ", " << vec[2] << ", ... , " << vec[vec.size() - 2] << ", " + << vec[vec.size() - 1] << " }" << std::endl; + ASSERT_EQ(0u, cent.similars().size()); + } + } + + ASSERT_EQ(0, cluster.label(nullptr, centroids, &labels)); +} diff --git a/tests/core/algorithm/diskann/CMakeLists.txt b/tests/core/algorithm/diskann/CMakeLists.txt new file mode 100644 index 000000000..3f2a093e2 --- /dev/null +++ b/tests/core/algorithm/diskann/CMakeLists.txt @@ -0,0 +1,14 @@ +include(${CMAKE_SOURCE_DIR}/cmake/bazel.cmake) + +file(GLOB_RECURSE ALL_TEST_SRCS *_test.cc) + +foreach(CC_SRCS ${ALL_TEST_SRCS}) + get_filename_component(CC_TARGET ${CC_SRCS} NAME_WE) + cc_gtest( + NAME ${CC_TARGET} + STRICT + LIBS zvec_ailego core_framework core_utility core_metric core_quantizer core_knn_diskann + SRCS ${CC_SRCS} + INCS . ${CMAKE_SOURCE_DIR}/src/core ${CMAKE_SOURCE_DIR}/src/core/algorithm/diskann + ) +endforeach() \ No newline at end of file diff --git a/tests/core/algorithm/diskann/diskann_builder_test.cc b/tests/core/algorithm/diskann/diskann_builder_test.cc new file mode 100644 index 000000000..d9d2d64e6 --- /dev/null +++ b/tests/core/algorithm/diskann/diskann_builder_test.cc @@ -0,0 +1,100 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "diskann_builder.h" +#include +#include +#include +#include +#include +#include +#include +#include "diskann_holder.h" + +using namespace zvec::core; +using namespace zvec::ailego; +using namespace std; + +constexpr size_t static dim = 64; + +class DiskAnnBuilderTest : public testing::Test { + protected: + void SetUp(void); + void TearDown(void); + + static std::string _dir; + static shared_ptr _index_meta_ptr; +}; + +std::string DiskAnnBuilderTest::_dir("DiskAnnBuilderTest"); +shared_ptr DiskAnnBuilderTest::_index_meta_ptr; + +void DiskAnnBuilderTest::SetUp(void) { + LoggerBroker::SetLevel(Logger::LEVEL_INFO); + + _index_meta_ptr.reset(new (nothrow) + IndexMeta(IndexMeta::DataType::DT_FP32, dim)); + _index_meta_ptr->set_metric("SquaredEuclidean", 0, Params()); +} + +void DiskAnnBuilderTest::TearDown(void) { + char cmdBuf[100]; + snprintf(cmdBuf, 100, "rm -rf %s", _dir.c_str()); + system(cmdBuf); +} + +TEST_F(DiskAnnBuilderTest, TestGeneral) { + IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 50); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + params.set("proxima.diskann.builder.threads", 4); + + 
ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestGeneral"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); +} diff --git a/tests/core/algorithm/diskann/diskann_searcher_test.cc b/tests/core/algorithm/diskann/diskann_searcher_test.cc new file mode 100644 index 000000000..5b39400cd --- /dev/null +++ b/tests/core/algorithm/diskann/diskann_searcher_test.cc @@ -0,0 +1,816 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "diskann_searcher.h" +#include +#include +#include +#include +#include +#include +#include +#include "diskann_holder.h" +#include "diskann_params.h" + +using namespace zvec::core; +using namespace zvec::ailego; +using namespace std; + +constexpr size_t static dim = 64; + +class DiskAnnSearcherTest : public testing::Test { + protected: + void SetUp(void); + void TearDown(void); + + static std::string _dir; + static shared_ptr _index_meta_ptr; +}; + +std::string DiskAnnSearcherTest::_dir("DiskAnnSearcherTest/"); +shared_ptr DiskAnnSearcherTest::_index_meta_ptr; + +void DiskAnnSearcherTest::SetUp(void) { + LoggerBroker::SetLevel(Logger::LEVEL_INFO); + + _index_meta_ptr.reset(new (nothrow) + IndexMeta(IndexMeta::DataType::DT_FP32, dim)); + _index_meta_ptr->set_metric("SquaredEuclidean", 0, Params()); +} + +void DiskAnnSearcherTest::TearDown(void) { + char cmdBuf[100]; + snprintf(cmdBuf, 100, "rm -rf %s", _dir.c_str()); + system(cmdBuf); +} + +TEST_F(DiskAnnSearcherTest, TestGeneral) { + IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 300); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + params.set("proxima.diskann.builder.threads", 4); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestGeneral"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto 
&stats = builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); + + // test searcher + IndexSearcher::Pointer searcher = + IndexFactory::CreateSearcher("DiskAnnSearcher"); + ASSERT_TRUE(searcher != nullptr); + + Params search_params; + search_params.set("proxima.diskann.searcher.list_size", 500); + + ASSERT_EQ(0, searcher->init(search_params)); + + auto storage = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_EQ(0, storage->open(path, false)); + ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); + auto ctx = searcher->create_context(); + ASSERT_TRUE(!!ctx); + + auto linearCtx = searcher->create_context(); + auto linearByPKeysCtx = searcher->create_context(); + auto knnCtx = searcher->create_context(); + + ASSERT_TRUE(!!linearCtx); + ASSERT_TRUE(!!linearByPKeysCtx); + ASSERT_TRUE(!!knnCtx); + + NumericalVector vec(dim); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); + size_t topk = 200; + uint64_t knnTotalTime = 0; + uint64_t linearTotalTime = 0; + int totalHits = 0; + int totalCnts = 0; + int topk1Hits = 0; + linearCtx->set_topk(topk); + linearByPKeysCtx->set_topk(topk); + knnCtx->set_topk(topk); + + // do linear search test + { + float query[dim]; + for (size_t i = 0; i < dim; ++i) { + query[i] = 3.1f; + } + ASSERT_EQ(0, searcher->search_bf_impl(query, qmeta, linearCtx)); + auto &linearResult = linearCtx->result(); + ASSERT_EQ(3UL, linearResult[0].key()); + ASSERT_EQ(4UL, linearResult[1].key()); + ASSERT_EQ(2UL, linearResult[2].key()); + ASSERT_EQ(5UL, linearResult[3].key()); + ASSERT_EQ(1UL, linearResult[4].key()); + ASSERT_EQ(6UL, linearResult[5].key()); + ASSERT_EQ(0UL, linearResult[6].key()); + ASSERT_EQ(7UL, linearResult[7].key()); + for (size_t i = 8; i < topk; ++i) { + ASSERT_EQ(i, 
linearResult[i].key()); + } + } + + // do linear search by p_keys test + std::vector> p_keys; + p_keys.resize(1); + p_keys[0] = {8, 9, 10, 11, 3, 2, 1, 0}; + { + float query[dim]; + for (size_t i = 0; i < dim; ++i) { + query[i] = 3.1f; + } + + ASSERT_EQ(0, searcher->search_bf_by_p_keys_impl(query, p_keys, qmeta, + linearByPKeysCtx)); + auto &linearByPKeysResult = linearByPKeysCtx->result(); + ASSERT_EQ(8, linearByPKeysResult.size()); + ASSERT_EQ(3UL, linearByPKeysResult[0].key()); + ASSERT_EQ(2UL, linearByPKeysResult[1].key()); + ASSERT_EQ(1UL, linearByPKeysResult[2].key()); + ASSERT_EQ(0UL, linearByPKeysResult[3].key()); + ASSERT_EQ(8UL, linearByPKeysResult[4].key()); + ASSERT_EQ(9UL, linearByPKeysResult[5].key()); + ASSERT_EQ(10UL, linearByPKeysResult[6].key()); + ASSERT_EQ(11UL, linearByPKeysResult[7].key()); + } + + size_t step = 500; + for (size_t i = 0; i < doc_cnt; i += step) { + for (size_t j = 0; j < dim; ++j) { + vec[j] = i + 0.1f; + } + auto t1 = Realtime::MicroSeconds(); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); + auto t2 = Realtime::MicroSeconds(); + + ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); + auto t3 = Realtime::MicroSeconds(); + knnTotalTime += t2 - t1; + linearTotalTime += t3 - t2; + + auto &knnResult = knnCtx->result(); + // TODO: check + topk1Hits += i == knnResult[0].key(); + + auto &linearResult = linearCtx->result(); + ASSERT_EQ(topk, linearResult.size()); + ASSERT_EQ(i, linearResult[0].key()); + + for (size_t k = 0; k < topk; ++k) { + totalCnts++; + for (size_t j = 0; j < topk; ++j) { + if (linearResult[j].key() == knnResult[k].key()) { + totalHits++; + break; + } + } + } + } + + float recall = totalHits * step * step * 1.0f / totalCnts; + float topk1Recall = topk1Hits * step * 1.0f / doc_cnt; + float cost = linearTotalTime * 1.0f / knnTotalTime; + + EXPECT_GT(recall, 0.90f); + EXPECT_GT(topk1Recall, 0.80f); + EXPECT_GT(cost, 2.0f); +} + +TEST_F(DiskAnnSearcherTest, TestNodeCache) { + 
IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 300); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + params.set("proxima.diskann.builder.threads", 4); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestNodeCache"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); + + // test searcher + IndexSearcher::Pointer searcher = + IndexFactory::CreateSearcher("DiskAnnSearcher"); + ASSERT_TRUE(searcher != nullptr); + + Params search_params; + search_params.set("proxima.diskann.searcher.cache_node_num", 32); + search_params.set("proxima.diskann.searcher.list_size", 500); + + ASSERT_EQ(0, searcher->init(search_params)); + + auto storage = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_EQ(0, storage->open(path, false)); + ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); + auto ctx = searcher->create_context(); + ASSERT_TRUE(!!ctx); + + auto linearCtx = searcher->create_context(); + auto linearByPKeysCtx = searcher->create_context(); + auto knnCtx = 
searcher->create_context(); + + ASSERT_TRUE(!!linearCtx); + ASSERT_TRUE(!!linearByPKeysCtx); + ASSERT_TRUE(!!knnCtx); + + NumericalVector vec(dim); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); + size_t topk = 200; + uint64_t knnTotalTime = 0; + uint64_t linearTotalTime = 0; + int totalHits = 0; + int totalCnts = 0; + int topk1Hits = 0; + linearCtx->set_topk(topk); + linearByPKeysCtx->set_topk(topk); + knnCtx->set_topk(topk); + + size_t step = 500; + for (size_t i = 0; i < doc_cnt; i += step) { + for (size_t j = 0; j < dim; ++j) { + vec[j] = i + 0.1f; + } + auto t1 = Realtime::MicroSeconds(); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); + auto t2 = Realtime::MicroSeconds(); + + ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); + auto t3 = Realtime::MicroSeconds(); + knnTotalTime += t2 - t1; + linearTotalTime += t3 - t2; + + auto &knnResult = knnCtx->result(); + // TODO: check + topk1Hits += i == knnResult[0].key(); + + auto &linearResult = linearCtx->result(); + ASSERT_EQ(topk, linearResult.size()); + ASSERT_EQ(i, linearResult[0].key()); + + for (size_t k = 0; k < topk; ++k) { + totalCnts++; + for (size_t j = 0; j < topk; ++j) { + if (linearResult[j].key() == knnResult[k].key()) { + totalHits++; + break; + } + } + } + } + + float recall = totalHits * step * step * 1.0f / totalCnts; + float topk1Recall = topk1Hits * step * 1.0f / doc_cnt; + float cost = linearTotalTime * 1.0f / knnTotalTime; + + EXPECT_GT(recall, 0.90f); + EXPECT_GT(topk1Recall, 0.80f); + EXPECT_GT(cost, 2.0f); +} + +TEST_F(DiskAnnSearcherTest, TestFilter) { + IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + 
params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 300); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + params.set("proxima.diskann.builder.threads", 4); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestFilter"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); + + // test searcher + IndexSearcher::Pointer searcher = + IndexFactory::CreateSearcher("DiskAnnSearcher"); + ASSERT_TRUE(searcher != nullptr); + + Params search_params; + search_params.set("proxima.diskann.searcher.cache_node_num", 32); + search_params.set("proxima.diskann.searcher.list_size", 500); + + ASSERT_EQ(0, searcher->init(search_params)); + + auto storage = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_EQ(0, storage->open(path, false)); + ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); + auto ctx = searcher->create_context(); + ASSERT_TRUE(!!ctx); + + auto linearCtx = searcher->create_context(); + auto linearByPKeysCtx = searcher->create_context(); + auto knnCtx = searcher->create_context(); + + ASSERT_TRUE(!!linearCtx); + ASSERT_TRUE(!!linearByPKeysCtx); + ASSERT_TRUE(!!knnCtx); + + NumericalVector vec(dim); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); + + size_t topk = 200; + linearCtx->set_topk(topk); + linearByPKeysCtx->set_topk(topk); + knnCtx->set_topk(topk); + + size_t key = 50; + for (size_t j = 0; j < dim; ++j) { + 
vec[j] = key + 0.1f; + } + + // no filter + { + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); + + auto &knnResult = knnCtx->result(); + ASSERT_EQ(topk, knnResult.size()); + ASSERT_EQ(50UL, knnResult[0].key()); + ASSERT_EQ(51UL, knnResult[1].key()); + ASSERT_EQ(49UL, knnResult[2].key()); + + ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); + + auto &linearResult = linearCtx->result(); + ASSERT_EQ(topk, linearResult.size()); + ASSERT_EQ(50UL, linearResult[0].key()); + ASSERT_EQ(51UL, linearResult[1].key()); + ASSERT_EQ(49UL, linearResult[2].key()); + } + + // with filter + { + auto filterFunc = [](uint64_t key) { + if (key == 50UL || key == 51UL || key == 49UL) { + return true; + } + return false; + }; + + + knnCtx->set_filter(filterFunc); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); + + auto &knnResult = knnCtx->result(); + ASSERT_EQ(topk, knnResult.size()); + ASSERT_EQ(52UL, knnResult[0].key()); + ASSERT_EQ(48UL, knnResult[1].key()); + ASSERT_EQ(53UL, knnResult[2].key()); + + linearCtx->set_filter(filterFunc); + ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); + + auto &linearResult = linearCtx->result(); + ASSERT_EQ(topk, linearResult.size()); + ASSERT_EQ(52UL, linearResult[0].key()); + ASSERT_EQ(48UL, linearResult[1].key()); + ASSERT_EQ(53UL, linearResult[2].key()); + } +} + +TEST_F(DiskAnnSearcherTest, TestGroup) { + IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = i / 10.0; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 300); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + 
params.set("proxima.diskann.builder.threads", 4); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestGroup"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); + + // test searcher + IndexSearcher::Pointer searcher = + IndexFactory::CreateSearcher("DiskAnnSearcher"); + ASSERT_TRUE(searcher != nullptr); + + Params search_params; + search_params.set("proxima.diskann.searcher.list_size", 500); + + ASSERT_EQ(0, searcher->init(search_params)); + + auto storage = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_EQ(0, storage->open(path, false)); + ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); + auto ctx = searcher->create_context(); + ASSERT_TRUE(!!ctx); + + NumericalVector vec(dim); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); + size_t group_topk = 20; + uint64_t total_time = 0; + + auto groupbyFunc = [](uint64_t key) { + uint32_t group_id = key / 10 % 10; + + // std::cout << "key: " << key << ", group id: " << group_id << std::endl; + + return std::string("g_") + std::to_string(group_id); + }; + + size_t group_num = 5; + + ctx->set_group_params(group_num, group_topk); + ctx->set_group_by(groupbyFunc); + + size_t query_value = doc_cnt / 2; + for (size_t j = 0; j < dim; ++j) { + vec[j] = query_value / 10 + 0.1f; + } + + auto t1 = Realtime::MicroSeconds(); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); + auto t2 = Realtime::MicroSeconds(); + + total_time += t2 
- t1; + + auto &group_result = ctx->group_result(); + + for (uint32_t i = 0; i < group_result.size(); ++i) { + const std::string &group_id = group_result[i].group_id(); + auto &result = group_result[i].docs(); + + ASSERT_GT(result.size(), 0); + std::cout << "Group ID: " << group_id << std::endl; + + for (uint32_t j = 0; j < result.size(); ++j) { + std::cout << "\tKey: " << result[j].key() << std::fixed + << std::setprecision(3) << ", Score: " << result[j].score() + << std::endl; + } + } + +#if 0 + // do linear search by p_keys test + auto groupbyFuncLinear = [](uint64_t key) { + uint32_t group_id = key % 10; + + return std::string("g_") + std::to_string(group_id); + }; + + auto linear_pk_ctx = searcher->create_context(); + + linear_pk_ctx->set_group_params(group_num, group_topk); + linear_pk_ctx->set_group_by(groupbyFuncLinear); + + std::vector> p_keys; + p_keys.resize(1); + p_keys[0] = {4, 3, 2, 1, 5, 6, 7, 8, 9, 10}; + + ASSERT_EQ(0, searcher->search_bf_by_p_keys_impl(vec.data(), p_keys, qmeta, + linear_pk_ctx)); + auto &linear_by_pkeys_group_result = linear_pk_ctx->group_result(); + ASSERT_EQ(linear_by_pkeys_group_result.size(), group_num); + + for (uint32_t i = 0; i < linear_by_pkeys_group_result.size(); ++i) { + const std::string &group_id = linear_by_pkeys_group_result[i].group_id(); + auto &result = linear_by_pkeys_group_result[i].docs(); + + ASSERT_GT(result.size(), 0); + std::cout << "Group ID: " << group_id << std::endl; + + for (uint32_t j = 0; j < result.size(); ++j) { + std::cout << "\tKey: " << result[j].key() << std::fixed + << std::setprecision(3) << ", Score: " << result[j].score() + << std::endl; + } + + ASSERT_EQ(10 - i, result[0].key()); + } +#endif +} + +TEST_F(DiskAnnSearcherTest, TestFetchVector) { + IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); 
+ for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 300); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + params.set("proxima.diskann.builder.threads", 4); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestFetchVector"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); + + // test searcher + IndexSearcher::Pointer searcher = + IndexFactory::CreateSearcher("DiskAnnSearcher"); + ASSERT_TRUE(searcher != nullptr); + + Params search_params; + search_params.set("proxima.diskann.searcher.list_size", 500); + + ASSERT_EQ(0, searcher->init(search_params)); + + auto storage = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_EQ(0, storage->open(path, false)); + ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); + + size_t query_cnt = 20U; + auto linearCtx = searcher->create_context(); + auto knnCtx = searcher->create_context(); + auto linearByPKeysCtx = searcher->create_context(); + knnCtx->set_fetch_vector(true); + + for (size_t i = 0; i < doc_cnt; i += doc_cnt / 10) { + std::string vec_value; + ASSERT_EQ(0, searcher->get_vector(i, linearCtx, vec_value)); + + float vector_value = *(const float *)(vec_value.data()); + ASSERT_EQ(vector_value, i); + } + + size_t topk = 200; + linearCtx->set_topk(topk); + 
knnCtx->set_topk(topk); + uint64_t knnTotalTime = 0; + uint64_t linearTotalTime = 0; + + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); + + NumericalVector vec(dim); + for (size_t i = 0; i < query_cnt; i++) { + for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + + auto t1 = Realtime::MicroSeconds(); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, knnCtx)); + auto t2 = Realtime::MicroSeconds(); + ASSERT_EQ(0, searcher->search_bf_impl(vec.data(), qmeta, linearCtx)); + auto t3 = Realtime::MicroSeconds(); + knnTotalTime += t2 - t1; + linearTotalTime += t3 - t2; + + auto &knnResult = knnCtx->result(); + ASSERT_EQ(topk, knnResult.size()); + + auto &linearResult = linearCtx->result(); + ASSERT_EQ(topk, linearResult.size()); + ASSERT_EQ(i, linearResult[0].key()); + + ASSERT_NE(knnResult[0].vector_string(), ""); + float vector_value = *((float *)(knnResult[0].vector_string().data())); + ASSERT_EQ(vector_value, i); + } +} + +TEST_F(DiskAnnSearcherTest, TestRnnSearch) { + IndexBuilder::Pointer builder = IndexFactory::CreateBuilder("DiskAnnBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = i; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + Params params; + + params.set("proxima.diskann.builder.max_degree", 32); + params.set("proxima.diskann.builder.list_size", 300); + params.set("proxima.diskann.builder.max_pq_chunk_num", 32); + params.set("proxima.diskann.builder.threads", 4); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestRnnSearch"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = 
builder->stats(); + ASSERT_EQ(doc_cnt, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_GT(stats.trained_costtime(), 0UL); + ASSERT_GT(stats.built_costtime(), 0UL); + + // test searcher + IndexSearcher::Pointer searcher = + IndexFactory::CreateSearcher("DiskAnnSearcher"); + ASSERT_TRUE(searcher != nullptr); + + Params search_params; + search_params.set("proxima.diskann.searcher.list_size", 500); + + ASSERT_EQ(0, searcher->init(search_params)); + + auto storage = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_EQ(0, storage->open(path, false)); + ASSERT_EQ(0, searcher->load(storage, IndexMetric::Pointer())); + + auto ctx = searcher->create_context(); + ASSERT_TRUE(!!ctx); + + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = 0.0; + } + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dim); + size_t topk = 50; + ctx->set_topk(topk); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); + auto &results = ctx->result(); + ASSERT_EQ(topk, results.size()); + + float radius = results[topk / 2].score(); + ctx->set_threshold(radius); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); + ASSERT_GT(topk, results.size()); + for (size_t k = 0; k < results.size(); ++k) { + ASSERT_GE(radius, results[k].score()); + } + + // Test Reset Threshold + ctx->reset_threshold(); + ASSERT_EQ(0, searcher->search_impl(vec.data(), qmeta, ctx)); + ASSERT_EQ(topk, results.size()); + ASSERT_LT(radius, results[topk - 1].score()); +} diff --git a/tests/db/CMakeLists.txt b/tests/db/CMakeLists.txt index 3ea846706..8659108ee 100644 --- a/tests/db/CMakeLists.txt +++ b/tests/db/CMakeLists.txt @@ -34,6 +34,7 @@ foreach(CC_SRCS ${ALL_TEST_SRCS}) core_knn_hnsw_rabitq core_knn_hnsw_sparse core_knn_ivf + core_knn_diskann core_mix_reducer core_metric core_utility diff --git a/tests/db/collection_test.cc b/tests/db/collection_test.cc index 
d50701dc3..29ec83ef2 100644 --- a/tests/db/collection_test.cc +++ b/tests/db/collection_test.cc @@ -4180,6 +4180,81 @@ TEST_F(CollectionTest, Feature_Optimize_HNSW_RABITQ) { } #endif +#if DISKANN_SUPPORTED +TEST_F(CollectionTest, Feature_Optimize_DiskAnn) { + auto func = [](MetricType metric_type, int concurrency) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 10000; + + auto schema = std::make_shared("diskann_demo"); + schema->set_max_doc_count_per_segment(MAX_DOC_COUNT_PER_SEGMENT); + + auto diskann_params = std::make_shared(metric_type); + schema->add_field(std::make_shared( + "dense_fp32", DataType::VECTOR_FP32, 128, false, diskann_params)); + + auto options = CollectionOptions{false, true, 64 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc( + col_path, *schema, options, 0, doc_count, false); + + auto check_doc = [&]() { + for (int i = 0; i < doc_count; i++) { + auto expect_doc = TestHelper::CreateDoc(i, *schema); + auto result = collection->Fetch({expect_doc.pk()}); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value().size(), 1); + ASSERT_EQ(result.value().count(expect_doc.pk()), 1); + auto doc = result.value()[expect_doc.pk()]; + ASSERT_NE(doc, nullptr); + if (*doc != expect_doc) { + std::cout << " doc:" << doc->to_detail_string() << std::endl; + std::cout << "expect_doc:" << expect_doc.to_detail_string() + << std::endl; + } + ASSERT_EQ(*doc, expect_doc); + } + }; + + check_doc(); + std::cout << "check success 1" << std::endl; + + ASSERT_TRUE(collection->Flush().ok()); + auto stats = collection->Stats().value(); + ASSERT_EQ(stats.doc_count, doc_count); + ASSERT_EQ(stats.index_completeness["dense_fp32"], 0); + + auto s = collection->Optimize(OptimizeOptions{concurrency}); + if (!s.ok()) { + std::cout << s.message() << std::endl; + } + ASSERT_TRUE(s.ok()); + + stats = collection->Stats().value(); + ASSERT_EQ(stats.doc_count, doc_count); + ASSERT_EQ(stats.index_completeness["dense_fp32"], 1); + + // check_doc(); + 
std::cout << "check success 2" << std::endl; + + collection.reset(); + auto result = Collection::Open(col_path, options); + ASSERT_TRUE(result.has_value()); + collection = std::move(result.value()); + + // check_doc(); + std::cout << "check success 3" << std::endl; + }; + + func(MetricType::L2, 0); + func(MetricType::L2, 4); + func(MetricType::IP, 0); + func(MetricType::IP, 4); + func(MetricType::COSINE, 0); + func(MetricType::COSINE, 4); +} +#endif + // **** CORNER CASES **** // TEST_F(CollectionTest, CornerCase_CreateAndOpen) { // Collection::CreateAndOpen diff --git a/tools/core/CMakeLists.txt b/tools/core/CMakeLists.txt index c36b26409..f3623a476 100644 --- a/tools/core/CMakeLists.txt +++ b/tools/core/CMakeLists.txt @@ -14,7 +14,7 @@ cc_binary( STRICT PACKED SRCS local_builder.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface core_knn_diskann ) cc_binary( @@ -22,7 +22,7 @@ cc_binary( STRICT PACKED SRCS recall.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface core_knn_diskann ) cc_binary( @@ -30,24 +30,24 @@ cc_binary( STRICT PACKED SRCS bench.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS 
gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface core_knn_diskann ) cc_binary( NAME recall_original STRICT PACKED - SRCS recall_original.cc flow.cc + SRCS recall_original.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface core_knn_diskann ) cc_binary( NAME bench_original STRICT PACKED - SRCS bench_original.cc flow.cc + SRCS bench_original.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface core_knn_diskann ) cc_binary( @@ -55,5 +55,5 @@ cc_binary( STRICT PACKED SRCS local_builder_original.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw 
core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface core_knn_diskann )