diff --git a/.gitignore b/.gitignore index 131c540e7..cbf21c314 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ tags compile_commands.json .python-version .vscode +.cache # Python related files __pycache__/ diff --git a/bindings/python/include/svs/python/ivf_loader.h b/bindings/python/include/svs/python/ivf_loader.h new file mode 100644 index 000000000..031d5d9ab --- /dev/null +++ b/bindings/python/include/svs/python/ivf_loader.h @@ -0,0 +1,195 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// svs +#include "svs/core/distance.h" +#include "svs/index/ivf/data_traits.h" +#include "svs/lib/datatype.h" +#include "svs/lib/exception.h" +#include "svs/lib/float16.h" +#include "svs/lib/saveload.h" + +// toml +#include + +// stl +#include +#include + +namespace svs::python::ivf_loader { + +/// +/// @brief Rebind an allocator to a different type. +/// +template +using RebindAlloc = typename std::allocator_traits::template rebind_alloc; + +/// +/// @brief Read and parse the data type configuration from a saved IVF index. +/// +/// @param config_path Path to the configuration directory +/// @return The parsed DataTypeConfig +/// +inline svs::index::ivf::DataTypeConfig read_data_type_config(const std::string& config_path +) { + auto config_file = std::filesystem::path(config_path) / svs::lib::config_file_name; + auto table = toml::parse_file(config_file.string()); + + // The data_type_config is nested inside "object" section + auto object_node = table["object"]; + if (!object_node) { + throw ANNEXCEPTION("Config file missing 'object' section."); + } + auto* object_table = object_node.as_table(); + if (!object_table) { + throw ANNEXCEPTION("'object' section is not a table."); + } + + // Get the data_type_config section from object + auto data_type_node = (*object_table)["data_type_config"]; + if (!data_type_node) { + throw ANNEXCEPTION("Config file missing 'data_type_config' section."); + } + + // Convert to table and create ContextFreeLoadTable + auto* data_type_table = data_type_node.as_table(); + if (!data_type_table) { + throw ANNEXCEPTION("data_type_config is not a table"); + } + auto ctx_free = svs::lib::ContextFreeLoadTable(*data_type_table); + return svs::index::ivf::DataTypeConfig::load(ctx_free); +} + +/// +/// @brief Generic loader function template for IVF index assembly. +/// +/// This template reduces boilerplate by providing a generic loader that can be +/// instantiated with different centroid and data types. +/// +/// @tparam IndexType The IVF index type (svs::IVF or svs::DynamicIVF) +/// @tparam CentroidType The centroid type (Float16 or BFloat16) +/// @tparam DataType The data type for the clusters +/// +template +IndexType load_typed( + const std::string& config_path, + const std::string& data_path, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads +) { + return IndexType::template assemble( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); +} + +/// +/// @brief Loader for uncompressed IVF data with type dispatch. +/// +/// Dispatches to the appropriate loader based on element type and centroid type. +/// +/// @tparam IndexType The type of index to return +/// @tparam DataContainer The data container template (SimpleData or BlockedData) +/// @tparam Allocator The allocator type for the data (will be rebound to element type) +/// +template < + typename IndexType, + template + class DataContainer, + typename Allocator> +IndexType load_uncompressed_with_dispatch( + const std::string& config_path, + const std::string& data_path, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads, + const svs::index::ivf::DataTypeConfig& data_config +) { + bool is_f16_centroids = (data_config.centroid_type == svs::DataType::float16); + bool is_f16_data = (data_config.element_type == svs::DataType::float16); + + // Dispatch based on data type and centroid type combinations + // Rebind the allocator to the appropriate element type + if (is_f16_data) { + using ReboundAlloc = RebindAlloc; + using DataType = DataContainer; + if (is_f16_centroids) { + return load_typed( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); + } else { + return load_typed( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); + } + } else { + using ReboundAlloc = RebindAlloc; + using DataType = DataContainer; + if (is_f16_centroids) { + return load_typed( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); + } else { + return load_typed( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); + } + } +} + +/// +/// @brief Generic IVF index loader with type dispatch based on saved configuration. +/// +/// @tparam IndexType The type of index to return (svs::IVF or svs::DynamicIVF) +/// @tparam DataContainer The data container template (SimpleData or BlockedData) +/// @tparam Allocator The allocator type for uncompressed data +/// +template < + typename IndexType, + template + class DataContainer, + typename Allocator> +IndexType load_index_auto( + const std::string& config_path, + const std::string& data_path, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads +) { + auto data_config = read_data_type_config(config_path); + + // Dispatch based on schema - only uncompressed supported in public repo + if (data_config.schema == "uncompressed_data") { + return load_uncompressed_with_dispatch( + config_path, + data_path, + distance_type, + num_threads, + intra_query_threads, + data_config + ); + } + + throw ANNEXCEPTION( + "Unknown or unsupported data type schema: ", + data_config.schema, + ". Only uncompressed data is supported in the public repository. " + ); +} + +} // namespace svs::python::ivf_loader diff --git a/bindings/python/src/dynamic_ivf.cpp b/bindings/python/src/dynamic_ivf.cpp index 52950ad38..784ee5f8a 100644 --- a/bindings/python/src/dynamic_ivf.cpp +++ b/bindings/python/src/dynamic_ivf.cpp @@ -19,12 +19,18 @@ #include "svs/python/common.h" #include "svs/python/core.h" #include "svs/python/ivf.h" +#include "svs/python/ivf_loader.h" #include "svs/python/manager.h" // svs +#include "svs/index/ivf/data_traits.h" #include "svs/lib/dispatcher.h" +#include "svs/lib/saveload.h" #include "svs/orchestrators/dynamic_ivf.h" +// toml +#include + // pybind #include #include @@ -34,6 +40,7 @@ #include // stl +#include #include ///// @@ -342,6 +349,20 @@ void save_index( index.save(config_path, data_dir); } +// Load with auto-detection from saved config using common template dispatcher +svs::DynamicIVF load_index_auto( + const std::string& config_path, + const std::string& data_path, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads = 1 +) { + return svs::python::ivf_loader:: + load_index_auto( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); +} + void wrap(py::module& m) { std::string name = "DynamicIVF"; py::class_ dynamic_ivf( @@ -530,6 +551,37 @@ It is the caller's responsibility to ensure that no existing data will be overwritten when saving the index to this directory. )" ); + + // Loading + dynamic_ivf.def_static( + "load", + &load_index_auto, + py::arg("config_directory"), + py::arg("data_directory"), + py::arg("distance") = svs::L2, + py::arg("num_threads") = 1, + py::arg("intra_query_threads") = 1, + R"( +Load a saved DynamicIVF index from disk. + +The data type (uncompressed with float32 or float16) and centroid type (bfloat16 or float16) +are automatically detected from the saved configuration file. + +Args: + config_directory: Directory where index configuration was saved. + data_directory: Directory where the dataset was saved. + distance: The distance function to use. + num_threads: The number of threads to use for queries. + intra_query_threads: Number of threads for intra-query parallelism (default: 1). + +Returns: + A loaded DynamicIVF index ready for searching and modifications. + +Note: + This method auto-detects the data type from the saved configuration. + The index must have been saved with a version that includes data type information. + )" + ); } } // namespace svs::python::dynamic_ivf diff --git a/bindings/python/src/ivf.cpp b/bindings/python/src/ivf.cpp index 7d231c998..0ee08ec50 100644 --- a/bindings/python/src/ivf.cpp +++ b/bindings/python/src/ivf.cpp @@ -19,6 +19,7 @@ #include "svs/python/common.h" #include "svs/python/core.h" #include "svs/python/dispatch.h" +#include "svs/python/ivf_loader.h" #include "svs/python/manager.h" // pybind11 @@ -27,13 +28,18 @@ // svs #include "svs/core/data/simple.h" #include "svs/core/distance.h" +#include "svs/index/ivf/data_traits.h" #include "svs/lib/array.h" #include "svs/lib/datatype.h" #include "svs/lib/dispatcher.h" #include "svs/lib/float16.h" #include "svs/lib/meta.h" +#include "svs/lib/saveload.h" #include "svs/orchestrators/ivf.h" +// toml +#include + // pybind #include #include @@ -41,6 +47,7 @@ // stl #include +#include #include #include #include @@ -532,6 +539,27 @@ auto load_clustering(const std::string& clustering_path, size_t num_threads = 1) } } +// Save the IVF index to directories +void save_index( + svs::IVF& index, const std::string& config_path, const std::string& data_dir +) { + index.save(config_path, data_dir); +} + +// Load with auto-detection from saved config using common template dispatcher +svs::IVF load_index( + const std::string& config_path, + const std::string& data_path, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads = 1 +) { + return svs::python::ivf_loader:: + load_index_auto( + config_path, data_path, distance_type, num_threads, intra_query_threads + ); +} + } // namespace detail void wrap(py::module& m) { @@ -630,6 +658,60 @@ void wrap(py::module& m) { // IVF Specific Extensions. add_interface(ivf); + // Index Saving. + ivf.def( + "save", + &detail::save_index, + py::arg("config_directory"), + py::arg("data_directory"), + R"( +Save a constructed index to disk (useful following index construction). + +Args: + config_directory: Directory where index configuration information will be saved. + data_directory: Directory where the dataset will be saved. + +Note: All directories should be separate to avoid accidental name collision with any +auxiliary files that are needed when saving the various components of the index. + +If the directory does not exist, it will be created if its parent exists. + +It is the caller's responsibility to ensure that no existing data will be +overwritten when saving the index to this directory. + )" + ); + + // Index Loading. + ivf.def_static( + "load", + &detail::load_index, + py::arg("config_directory"), + py::arg("data_directory"), + py::arg("distance") = svs::L2, + py::arg("num_threads") = 1, + py::arg("intra_query_threads") = 1, + R"( +Load a saved IVF index from disk. + +The data type (uncompressed with float32 or float16) and centroid type (bfloat16 or float16) +are automatically detected from the saved configuration file. + +Args: + config_directory: Directory where index configuration was saved. + data_directory: Directory where the dataset was saved. + distance: The distance function to use. + num_threads: The number of threads to use for queries. + intra_query_threads: Number of threads for intra-query parallelism (default: 1). + +Returns: + A loaded IVF index ready for searching. + +Note: + This method auto-detects the data type from the saved configuration. + The index must have been saved with a version that includes data type information. + )" + ); + // Reconstruction. // add_reconstruct_interface(ivf); diff --git a/bindings/python/tests/test_dynamic_ivf.py b/bindings/python/tests/test_dynamic_ivf.py new file mode 100644 index 000000000..45749f979 --- /dev/null +++ b/bindings/python/tests/test_dynamic_ivf.py @@ -0,0 +1,253 @@ +# Copyright 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tests for the Dynamic IVF index with save/load auto-detection. +import unittest +import os +import numpy as np +from tempfile import TemporaryDirectory + +import svs + +# Local dependencies +from .common import \ + test_data_svs, \ + test_data_vecs, \ + test_data_dims, \ + test_queries, \ + test_groundtruth_l2, \ + test_number_of_vectors + +from .dynamic import ReferenceDataset + + +class DynamicIVFTester(unittest.TestCase): + """ + Test building, adding, deleting points from the dynamic IVF index. + Tests include save/load with auto-detection for uncompressed data types. + """ + + def id_check(self, index, ids): + """Check that the index contains exactly the given IDs.""" + # Check that every id in `ids` is in the index. + for this_id in ids: + self.assertTrue(index.has_id(this_id)) + + # Check that every id in the index is in `ids` + all_ids = index.all_ids() + for this_id in all_ids: + self.assertTrue(this_id in ids) + + def recall_check( + self, + index: svs.DynamicIVF, + reference: ReferenceDataset, + num_neighbors: int, + expected_recall: float, + recall_delta: float, + ): + """Check recall and test save/reload functionality with auto-detection.""" + gt = reference.ground_truth(num_neighbors) + I, D = index.search(reference.queries, num_neighbors) + recall = svs.k_recall_at(gt, I, num_neighbors, num_neighbors) + print(f" Recall: {recall}") + self.assertTrue(recall < expected_recall + recall_delta) + self.assertTrue(recall > expected_recall - recall_delta) + + # Make sure saving and reloading work with auto-detection. + with TemporaryDirectory() as tempdir: + configdir = os.path.join(tempdir, "config") + datadir = os.path.join(tempdir, "data") + index.save(configdir, datadir) + + # Load with auto-detection - should detect data type automatically + reloaded = svs.DynamicIVF.load( + config_directory = configdir, + data_directory = datadir, + distance = svs.DistanceType.L2, + num_threads = 2, + ) + + # Set the same search parameters as the original index + reloaded.search_parameters = index.search_parameters + + # Get recall with the same search parameters. + I, D = reloaded.search(reference.queries, num_neighbors) + reloaded_recall = svs.k_recall_at(gt, I, num_neighbors, num_neighbors) + + print(f" Reloaded Recall: {reloaded_recall}") + self.assertTrue(reloaded_recall < expected_recall + recall_delta) + self.assertTrue(reloaded_recall > expected_recall - recall_delta) + + def _build_clustering(self, data_loader, num_threads): + """Build IVF clustering from a data loader.""" + build_params = svs.IVFBuildParameters( + num_centroids = 64, + minibatch_size = 128, + num_iterations = 10, + is_hierarchical = False, + training_fraction = 0.8, + hierarchical_level1_clusters = 0, + seed = 42, + ) + + return svs.Clustering.build( + build_parameters = build_params, + data_loader = data_loader, + distance = svs.DistanceType.L2, + num_threads = num_threads, + ) + + def test_uncompressed(self): + """Test DynamicIVF with uncompressed float32 data and auto-detection on reload.""" + num_threads = 2 + num_neighbors = 10 + expected_recall = 0.65 + expected_recall_delta = 0.20 + + reference = ReferenceDataset(num_threads = num_threads) + data, ids = reference.new_ids(5000) + + with TemporaryDirectory() as tempdir: + data_file = os.path.join(tempdir, "data.fvecs") + svs.write_vecs(data, data_file) + + data_loader = svs.VectorDataLoader( + data_file, + svs.DataType.float32, + dims = data.shape[1] + ) + + clustering = self._build_clustering(data_loader, num_threads) + + # Assemble DynamicIVF from clustering + index = svs.DynamicIVF.assemble_from_clustering( + clustering = clustering, + data_loader = data_loader, + ids = ids, + distance = svs.DistanceType.L2, + num_threads = num_threads, + ) + + print(f"Testing uncompressed: {index.experimental_backend_string}") + + # Set search parameters + search_params = svs.IVFSearchParameters(n_probes = 20, k_reorder = 100) + index.search_parameters = search_params + + # Perform an ID check + self.id_check(index, reference.ids()) + + # Groundtruth Check with save/reload auto-detection + print("Initial uncompressed") + self.recall_check( + index, reference, num_neighbors, expected_recall, expected_recall_delta + ) + + # Add and delete some vectors + (add_data, add_ids) = reference.new_ids(1000) + index.add(add_data, add_ids) + print("After add") + self.id_check(index, reference.ids()) + self.recall_check( + index, reference, num_neighbors, expected_recall, expected_recall_delta + ) + + delete_ids = reference.remove_ids(1000) + index.delete(delete_ids) + print("After delete") + self.id_check(index, reference.ids()) + self.recall_check( + index, reference, num_neighbors, expected_recall, expected_recall_delta + ) + + def test_build_from_loader(self): + """Test building DynamicIVF using a VectorDataLoader and explicit IDs.""" + num_threads = 2 + + loader = svs.VectorDataLoader(test_data_svs, svs.DataType.float32, dims = test_data_dims) + + # Sequential IDs + ids = np.arange(test_number_of_vectors, dtype = np.uint64) + + # Build IVF clustering + build_params = svs.IVFBuildParameters( + num_centroids = 128, + minibatch_size = 128, + num_iterations = 10, + is_hierarchical = False, + training_fraction = 0.8, + hierarchical_level1_clusters = 0, + seed = 42, + ) + + clustering = svs.Clustering.build( + build_parameters = build_params, + data_loader = loader, + distance = svs.DistanceType.L2, + num_threads = num_threads, + ) + + # Assemble DynamicIVF from clustering + index = svs.DynamicIVF.assemble_from_clustering( + clustering = clustering, + data_loader = loader, + ids = ids, + distance = svs.DistanceType.L2, + num_threads = num_threads, + ) + + # Basic invariants + self.assertEqual(index.size, test_number_of_vectors) + self.assertEqual(index.dimensions, test_data_dims) + self.assertTrue(index.has_id(0)) + self.assertTrue(index.has_id(test_number_of_vectors - 1)) + + # Search test + queries = svs.read_vecs(test_queries) + groundtruth = svs.read_vecs(test_groundtruth_l2) + k = 10 + + search_params = svs.IVFSearchParameters(n_probes = 30, k_reorder = 200) + index.search_parameters = search_params + + I, D = index.search(queries, k) + self.assertEqual(I.shape[1], k) + recall = svs.k_recall_at(groundtruth, I, k, k) + print(f"Build from loader recall: {recall}") + self.assertTrue(0.5 < recall <= 1.0) + + # Test save and load with auto-detection + with TemporaryDirectory() as tempdir: + configdir = os.path.join(tempdir, "config") + datadir = os.path.join(tempdir, "data") + index.save(configdir, datadir) + + # Reload from saved directories - auto-detect data type + reloaded = svs.DynamicIVF.load( + config_directory = configdir, + data_directory = datadir, + distance = svs.DistanceType.L2, + num_threads = num_threads + ) + + self.assertEqual(reloaded.size, test_number_of_vectors) + self.assertEqual(reloaded.dimensions, test_data_dims) + + # Set search parameters and verify recall + reloaded.search_parameters = search_params + I, D = reloaded.search(queries, k) + reloaded_recall = svs.k_recall_at(groundtruth, I, k, k) + print(f"Reloaded recall: {reloaded_recall}") + self.assertTrue(0.5 < reloaded_recall <= 1.0) diff --git a/bindings/python/tests/test_ivf.py b/bindings/python/tests/test_ivf.py index f8c416f22..f3253a5af 100644 --- a/bindings/python/tests/test_ivf.py +++ b/bindings/python/tests/test_ivf.py @@ -149,7 +149,6 @@ def _test_basic_inner( matcher, num_threads: int, skip_thread_test: bool = False, - first_iter: bool = False, test_single_query: bool = False, ): # Make sure that the number of threads is propagated correctly. @@ -192,9 +191,9 @@ def _test_basic_inner( if test_single_query: self._test_single_query(ivf, queries) - def _test_basic(self, loader, matcher, first_iter: bool = False): + def _test_basic(self, loader, matcher, test_single_query: bool = False): num_threads = 2 - print("Assemble from file") + print(f"Assemble from file. Data loader type: {matcher.kind}") ivf = svs.IVF.assemble_from_file( clustering_path = test_ivf_clustering, data_loader = loader, @@ -205,8 +204,7 @@ def _test_basic(self, loader, matcher, first_iter: bool = False): print(f"Testing: {ivf.experimental_backend_string}") self._test_basic_inner(ivf, matcher, num_threads, skip_thread_test = False, - first_iter = first_iter, - test_single_query = first_iter, + test_single_query = test_single_query, ) print("Load and Assemble from clustering") @@ -220,10 +218,32 @@ def _test_basic(self, loader, matcher, first_iter: bool = False): print(f"Testing: {ivf.experimental_backend_string}") self._test_basic_inner(ivf, matcher, num_threads, skip_thread_test = False, - first_iter = first_iter, - test_single_query = first_iter, + test_single_query = test_single_query, ) + # Test saving and reloading for all data types + print(f"Testing save and load for {matcher.kind}") + with TemporaryDirectory() as tempdir: + configdir = os.path.join(tempdir, "config") + datadir = os.path.join(tempdir, "data") + ivf.save(configdir, datadir) + + # Reload from saved directories - data type auto-detected from config + reloaded = svs.IVF.load( + config_directory = configdir, + data_directory = datadir, + distance = svs.DistanceType.L2, + num_threads = num_threads + ) + + print(f"Testing reloaded: {reloaded.experimental_backend_string}") + self._test_basic_inner( + reloaded, + matcher, + num_threads, + skip_thread_test = True, + ) + def test_basic(self): # Load the index from files. default_loader = svs.VectorDataLoader( @@ -231,9 +251,21 @@ def test_basic(self): ) self._setup(default_loader) - # Standard tests + # Standard tests - all data types now support save/load + is_first = True for loader, matcher in self.loader_and_matcher: - self._test_basic(loader, matcher) + self._test_basic(loader, matcher, test_single_query=is_first) + is_first = False + + # Test with float16 data loader + data = svs.read_vecs(test_data_vecs) + data_f16 = data.astype('float16') + with TemporaryDirectory() as tempdir: + hvecs_path = os.path.join(tempdir, "data_f16.hvecs") + svs.write_vecs(data_f16, hvecs_path) + loader_f16 = svs.VectorDataLoader(hvecs_path, svs.DataType.float16) + matcher_f16 = UncompressedMatcher("float32") + self._test_basic(loader_f16, matcher_f16) def _groundtruth_map(self): return { @@ -246,7 +278,8 @@ def _test_build( self, loader, distance: svs.DistanceType, - matcher + matcher, + epsilon: float = 0.005 ): num_threads = 2 distance_map = self._distance_map() @@ -300,14 +333,27 @@ def _test_build( recall = svs.k_recall_at(get_test_set(groundtruth, nq), results[0], k, k) print(f"Recall = {recall}, Expected = {expected_recall}") if not DEBUG: - self.assertTrue(isapprox(recall, expected_recall, epsilon = 0.005)) + self.assertTrue(isapprox(recall, expected_recall, epsilon = epsilon)) def test_build(self): # Build directly from data - queries = svs.read_vecs(test_queries) + data = svs.read_vecs(test_data_vecs) - # Build from file loader + # Build from file loader with float32 loader = svs.VectorDataLoader(test_data_svs, svs.DataType.float32) matcher = UncompressedMatcher("bfloat16") self._test_build(loader, svs.DistanceType.L2, matcher) self._test_build(loader, svs.DistanceType.MIP, matcher) + + # Build using float16 + data_f16 = data.astype('float16') + with TemporaryDirectory() as tempdir: + # Save float16 data to hvecs format + hvecs_path = os.path.join(tempdir, "data_f16.hvecs") + svs.write_vecs(data_f16, hvecs_path) + + # Build from file loader with float16 + # Use larger epsilon since float16 has different precision than bfloat16 + loader_f16 = svs.VectorDataLoader(hvecs_path, svs.DataType.float16) + self._test_build(loader_f16, svs.DistanceType.L2, matcher, epsilon = 0.015) + self._test_build(loader_f16, svs.DistanceType.MIP, matcher, epsilon = 0.015) diff --git a/examples/python/example_ivf.py b/examples/python/example_ivf.py index 63f80dd62..7cfe1ef82 100644 --- a/examples/python/example_ivf.py +++ b/examples/python/example_ivf.py @@ -134,9 +134,38 @@ def main(): print(f" ✓ Clustering saved to '{clustering_path}'") # [save-clustering] + # [save-index] + # Save the assembled IVF index to disk + print("\n9. Saving the IVF index...") + config_dir = os.path.join(test_data_dir, "index_config") + data_dir = os.path.join(test_data_dir, "index_data") + index.save(config_dir, data_dir) + print(f" ✓ Index saved to:") + print(f" Config: {config_dir}") + print(f" Data: {data_dir}") + # [save-index] + + # [load-index] + # Reload the saved index + print("\n10. Reloading saved index...") + reloaded_index = svs.IVF.load( + config_directory = config_dir, + data_directory = data_dir, + distance = svs.DistanceType.L2, + num_threads = 4, + ) + print(f" ✓ Index reloaded with {reloaded_index.size} vectors") + + # Verify the reloaded index works correctly + reloaded_index.search_parameters = search_params + I_reloaded, D_reloaded = reloaded_index.search(queries, num_neighbors) + recall_reloaded = svs.k_recall_at(groundtruth, I_reloaded, num_neighbors, num_neighbors) + print(f" ✓ Recall@{num_neighbors}: {recall_reloaded:.4f}") + # [load-index] + # [load-and-assemble] # Load clustering and assemble a new index - print("\n9. Loading clustering and assembling new index...") + print("\n11. Loading clustering and assembling new index...") loaded_clustering = svs.Clustering.load_clustering(clustering_path) new_index = svs.IVF.assemble_from_clustering( @@ -151,7 +180,7 @@ def main(): # [assemble-from-file] # Or directly assemble from file - print("\n10. Assembling index directly from clustering file...") + print("\n12. Assembling index directly from clustering file...") index_from_file = svs.IVF.assemble_from_file( clustering_path = clustering_path, data_loader = data_loader, @@ -164,7 +193,7 @@ def main(): # [search-verification] # Verify both indices produce the same results - print("\n11. Verifying search results consistency...") + print("\n13. Verifying search results consistency...") index_from_file.search_parameters = search_params I2, D2 = index_from_file.search(queries, num_neighbors) recall2 = svs.k_recall_at(groundtruth, I2, num_neighbors, num_neighbors) @@ -178,7 +207,7 @@ def main(): # [tune-search-parameters] # Experiment with different search parameters - print("\n12. Tuning search parameters...") + print("\n14. Tuning search parameters...") for n_probes in [5, 10, 20]: search_params.n_probes = n_probes index.search_parameters = search_params diff --git a/examples/python/example_ivf_dynamic.py b/examples/python/example_ivf_dynamic.py index 605ff9ecc..698f3e9e6 100644 --- a/examples/python/example_ivf_dynamic.py +++ b/examples/python/example_ivf_dynamic.py @@ -209,12 +209,26 @@ def main(): # [save-index] # [load-index] - # Note: DynamicIVF.load() is being implemented for easier reload - # For now, the index has been successfully saved and can be accessed at: - print("\n12. Index saved successfully!") - print(f" ✓ Config: {config_dir}") - print(f" ✓ Data: {data_dir}") - print(f" Note: load() API coming soon for simplified reload") + # Reload the saved index + print("\n12. Reloading saved index...") + reloaded_index = svs.DynamicIVF.load( + config_directory = config_dir, + data_directory = data_dir, + distance = svs.DistanceType.L2, + num_threads = 4, + ) + print(f" ✓ Index reloaded with {reloaded_index.size} vectors") + print(f" ✓ Index dimensions: {reloaded_index.dimensions}") + + # Verify the reloaded index works correctly + reloaded_index.search_parameters = search_params + I_reloaded, D_reloaded = reloaded_index.search(queries, num_neighbors) + recall_reloaded = svs.k_recall_at(groundtruth, I_reloaded, num_neighbors, num_neighbors) + print(f" ✓ Recall@{num_neighbors}: {recall_reloaded:.4f}") + + # Verify ID consistency + all_ids_reloaded = reloaded_index.all_ids() + print(f" ✓ Reloaded index contains {len(all_ids_reloaded)} unique IDs") # [load-index] # [get-all-ids] diff --git a/include/svs/extensions/ivf/scalar.h b/include/svs/extensions/ivf/scalar.h index 4732a3fa3..cc2a611f9 100644 --- a/include/svs/extensions/ivf/scalar.h +++ b/include/svs/extensions/ivf/scalar.h @@ -14,6 +14,9 @@ * limitations under the License. */ +#pragma once + +#include "svs/index/ivf/data_traits.h" #include "svs/index/ivf/extensions.h" #include "svs/quantization/scalar/scalar.h" @@ -63,3 +66,28 @@ auto svs_invoke( } } // namespace svs::quantization::scalar + +///// +///// DataTypeTraits specialization for Scalar Quantization datasets +///// + +namespace svs::index::ivf { + +/// @brief Specialization of DataTypeTraits for SQDataset +/// +/// This enables automatic save/load of Scalar Quantization data type information +/// in IVF indices. +template +struct DataTypeTraits> { + using Data = quantization::scalar::SQDataset; + + static DataTypeConfig get_config() { + DataTypeConfig config; + config.schema = + std::string(quantization::scalar::scalar_quantization_serialization_schema); + config.element_type = datatype_v; + return config; + } +}; + +} // namespace svs::index::ivf diff --git a/include/svs/index/ivf/clustering.h b/include/svs/index/ivf/clustering.h index 93a36526c..45aa50473 100644 --- a/include/svs/index/ivf/clustering.h +++ b/include/svs/index/ivf/clustering.h @@ -316,6 +316,9 @@ class DenseClusteredDataset { using index_type = I; using data_type = Data; + // Default constructor for use in load() + DenseClusteredDataset() = default; + // Constructor from clustering (for building from existing data) template DenseClusteredDataset( @@ -343,12 +346,12 @@ class DenseClusteredDataset { } // Constructor for empty clusters (for assembly/dynamic operations) - template - DenseClusteredDataset(size_t num_clusters, size_t dimensions, const Alloc& allocator) + // Note: This constructor creates empty clusters with the given dimensionality + DenseClusteredDataset(size_t num_clusters, size_t dimensions) : clusters_{} { clusters_.reserve(num_clusters); for (size_t i = 0; i < num_clusters; ++i) { - clusters_.emplace_back(Data(0, dimensions, allocator), std::vector()); + clusters_.emplace_back(Data(0, dimensions), std::vector()); } } @@ -389,6 +392,211 @@ class DenseClusteredDataset { // View cluster data (mutable) - for dynamic IVF operations Data& view_cluster(size_t cluster) { return clusters_[cluster].view_cluster(); } + // Get the dimensions of the data + size_t dimensions() const { + if (clusters_.empty()) { + return 0; + } + return clusters_[0].data_.dimensions(); + } + + ///// Saving and Loading ///// + + static constexpr lib::Version save_version{0, 0, 0}; + static constexpr std::string_view serialization_schema = "ivf_dense_clustered_dataset"; + + /// @brief Save the DenseClusteredDataset to disk. + /// + /// Saves all cluster data using the existing save mechanisms for each data type + /// (SimpleData, LVQ, LeanVec, etc.), then archives everything into a single file. + /// + /// File format: + /// - clusters_archive.bin: Archive containing all cluster data directories + /// - ids.bin: Concatenated binary IDs for all clusters + /// - Config contains: cluster_sizes array, ids_offsets array + /// + /// @param ctx The save context providing directory and naming utilities + /// @return SaveTable containing metadata for reloading + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto num_clusters = size(); + auto dims = dimensions(); + + // Compute cluster sizes and ID offsets + std::vector cluster_sizes(num_clusters); + std::vector ids_offsets(num_clusters + 1); + + size_t ids_offset = 0; + for (size_t i = 0; i < num_clusters; ++i) { + cluster_sizes[i] = clusters_[i].size(); + ids_offsets[i] = ids_offset; + ids_offset += cluster_sizes[i] * sizeof(I); + } + ids_offsets[num_clusters] = ids_offset; + + // Create a temporary directory for cluster data + lib::UniqueTempDirectory tempdir{"svs_ivf_clusters_save"}; + auto clusters_temp_dir = tempdir.get() / "clusters"; + std::filesystem::create_directories(clusters_temp_dir); + + // Save each cluster's data using lib::save_to_disk + // This uses the existing save mechanism for each data type (LVQ, LeanVec, etc.) + for (size_t i = 0; i < num_clusters; ++i) { + auto cluster_dir = clusters_temp_dir / fmt::format("cluster_{}", i); + lib::save_to_disk(clusters_[i].data_, cluster_dir); + } + + // Archive all cluster data into a single file + auto archive_path = ctx.get_directory() / "clusters_archive.bin"; + { + std::ofstream archive_stream(archive_path, std::ios::binary); + if (!archive_stream) { + throw ANNEXCEPTION("Failed to open {} for writing!", archive_path); + } + lib::DirectoryArchiver::pack(clusters_temp_dir, archive_stream); + } + + // Write all cluster IDs to a single file + auto ids_path = ctx.get_directory() / "ids.bin"; + { + auto stream = lib::open_write(ids_path); + for (size_t i = 0; i < num_clusters; ++i) { + if (!clusters_[i].ids_.empty()) { + lib::write_binary(stream, clusters_[i].ids_); + } + } + } + + // Serialize offset arrays to binary files for efficiency + auto cluster_sizes_path = ctx.get_directory() / "cluster_sizes.bin"; + { + auto stream = lib::open_write(cluster_sizes_path); + lib::write_binary(stream, cluster_sizes); + } + + auto ids_offsets_path = ctx.get_directory() / "ids_offsets.bin"; + { + auto stream = lib::open_write(ids_offsets_path); + lib::write_binary(stream, ids_offsets); + } + + return lib::SaveTable( + serialization_schema, + save_version, + {{"num_clusters", lib::save(num_clusters)}, + {"dimensions", lib::save(dims)}, + {"prefetch_offset", lib::save(prefetch_offset_)}, + {"index_type", lib::save(datatype_v)}, + {"clusters_archive_file", lib::save(std::string("clusters_archive.bin"))}, + {"ids_file", lib::save(std::string("ids.bin"))}, + {"cluster_sizes_file", lib::save(std::string("cluster_sizes.bin"))}, + {"ids_offsets_file", lib::save(std::string("ids_offsets.bin"))}, + {"total_ids_bytes", lib::save(ids_offset)}} + ); + } + + /// @brief Check if a saved file is compatible with this loader + static bool check_load_compatibility(std::string_view schema, lib::Version version) { + return schema == serialization_schema && version <= save_version; + } + + /// @brief Load a DenseClusteredDataset from disk. + /// + /// Loads from the archive-based format where cluster data is stored + /// using the native save/load mechanism for each data type. + /// + /// @tparam Pool Thread pool type for parallel loading + /// @tparam Allocator Allocator type for cluster data (optional) + /// @param table The load table containing saved metadata + /// @param threadpool Thread pool for parallel operations (unused, kept for API + /// consistency) + /// @param allocator Optional allocator for cluster data. For blocked data types, + /// this controls the block size. If not provided, default allocator is used. + /// @return Loaded DenseClusteredDataset + template + static DenseClusteredDataset load( + const lib::LoadTable& table, + Pool& SVS_UNUSED(threadpool), + const Allocator& allocator = Allocator{} + ) { + auto num_clusters = lib::load_at(table, "num_clusters"); + // Note: "dimensions" field is saved for validation but not used during load + // since each cluster's data type determines its own dimensions + [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); + auto prefetch_offset = lib::load_at(table, "prefetch_offset"); + + // Verify index type matches + auto saved_index_type = lib::load_at(table, "index_type"); + if (saved_index_type != datatype_v) { + throw ANNEXCEPTION( + "DenseClusteredDataset was saved using index type {} but we're trying to " + "reload it using {}!", + saved_index_type, + datatype_v + ); + } + + auto base_dir = table.context().get_directory(); + + // Load offset arrays from binary files + std::vector cluster_sizes(num_clusters); + std::vector ids_offsets(num_clusters + 1); + + { + auto stream = lib::open_read(base_dir / "cluster_sizes.bin"); + lib::read_binary(stream, cluster_sizes); + } + { + auto stream = lib::open_read(base_dir / "ids_offsets.bin"); + lib::read_binary(stream, ids_offsets); + } + + // Create a temporary directory to unpack the clusters archive + lib::UniqueTempDirectory tempdir{"svs_ivf_clusters_load"}; + auto clusters_temp_dir = tempdir.get() / "clusters"; + std::filesystem::create_directories(clusters_temp_dir); + + // Unpack the clusters archive + { + std::ifstream archive_stream( + base_dir / "clusters_archive.bin", std::ios::binary + ); + if (!archive_stream) { + throw ANNEXCEPTION( + "Failed to open {} for reading!", base_dir / "clusters_archive.bin" + ); + } + lib::DirectoryArchiver::unpack(archive_stream, clusters_temp_dir); + } + + // Create result dataset with default constructor + DenseClusteredDataset result; + result.prefetch_offset_ = prefetch_offset; + result.clusters_.reserve(num_clusters); + + // Load IDs file for reading + auto ids_stream = lib::open_read(base_dir / "ids.bin"); + + // Load each cluster's data and ids together + for (size_t i = 0; i < num_clusters; ++i) { + // Load cluster data with provided allocator + auto cluster_dir = clusters_temp_dir / fmt::format("cluster_{}", i); + auto cluster_data = lib::load_from_disk(cluster_dir, allocator); + + // Load cluster IDs + size_t cluster_size = cluster_sizes[i]; + std::vector cluster_ids(cluster_size); + if (cluster_size > 0) { + ids_stream.seekg(static_cast(ids_offsets[i])); + lib::read_binary(ids_stream, cluster_ids); + } + + // Construct cluster with both data and ids + result.clusters_.emplace_back(std::move(cluster_data), std::move(cluster_ids)); + } + + return result; + } + private: std::vector> clusters_; size_t prefetch_offset_ = 8; diff --git a/include/svs/index/ivf/common.h b/include/svs/index/ivf/common.h index e914778d6..feedddadc 100644 --- a/include/svs/index/ivf/common.h +++ b/include/svs/index/ivf/common.h @@ -42,6 +42,17 @@ // Common definitions. namespace svs::index::ivf { +/// Helper trait to check if a distance type behaves like IP (inner product) +template +inline constexpr bool is_ip_like_v = + std::is_same_v, distance::DistanceIP> || + std::is_same_v, distance::DistanceCosineSimilarity>; + +/// Helper trait to check if a distance type is L2 +template +inline constexpr bool is_l2_v = + std::is_same_v, distance::DistanceL2>; + // Small epsilon value used for floating-point comparisons to avoid precision // issues. The value 1/1024 (approximately 0.0009765625) is chosen as a reasonable // threshold for numerical stability in algorithms such as k-means clustering, where exact @@ -371,11 +382,19 @@ void centroid_assignment( using DataType = typename Data::element_type; using CentroidType = T; - // Convert data to match centroid type if necessary - data::SimpleData data_conv; + // Convert data to match centroid type if necessary, otherwise use original data + [[maybe_unused]] data::SimpleData data_conv; if constexpr (!std::is_same_v) { data_conv = convert_data(data, threadpool); } + const auto& matmul_data = [&]() -> const auto& { + if constexpr (!std::is_same_v) { + return data_conv; + } else { + return data; + } + } + (); auto generate_assignments = timer.push_back("generate assignments"); threads::parallel_for( @@ -383,28 +402,15 @@ void centroid_assignment( threads::StaticPartition{batch_range.size()}, [&](auto indices, auto /*tid*/) { auto range = threads::UnitRange(indices); - if constexpr (!std::is_same_v) { - compute_matmul( - data_conv.get_datum(range.start()).data(), - centroids.data(), - matmul_results.get_datum(range.start()).data(), - range.size(), - centroids.size(), - data.dimensions() - ); - } else { - compute_matmul( - data.get_datum(range.start()).data(), - centroids.data(), - matmul_results.get_datum(range.start()).data(), - range.size(), - centroids.size(), - data.dimensions() - ); - } - if constexpr (std::is_same_v< - std::remove_cvref_t, - distance::DistanceIP>) { + compute_matmul( + matmul_data.get_datum(range.start()).data(), + centroids.data(), + matmul_results.get_datum(range.start()).data(), + range.size(), + centroids.size(), + matmul_data.dimensions() + ); + if constexpr (is_ip_like_v) { for (auto i : indices) { auto nearest = type_traits::sentinel_v, std::greater<>>; @@ -414,9 +420,7 @@ void centroid_assignment( } assignments[batch_range.start() + i] = nearest.id(); } - } else if constexpr (std::is_same_v< - std::remove_cvref_t, - distance::DistanceL2>) { + } else if constexpr (is_l2_v) { for (auto i : indices) { auto nearest = type_traits::sentinel_v, std::less<>>; auto dists = matmul_results.get_datum(i); @@ -428,7 +432,11 @@ void centroid_assignment( assignments[batch_range.start() + i] = nearest.id(); } } else { - throw ANNEXCEPTION("Only L2 and MIP distances supported in IVF build!"); + // Compile-time error for unsupported distance types + static_assert( + sizeof(Distance) == 0, + "Only L2, MIP, and Cosine distances are supported in IVF build!" + ); } } ); @@ -565,13 +573,13 @@ auto kmeans_training( auto training_timer = timer.push_back("Kmeans training"); data::SimpleData centroids_fp32 = convert_data(centroids, threadpool); - if constexpr (std::is_same_v, distance::DistanceIP>) { + if constexpr (is_ip_like_v) { normalize_centroids(centroids_fp32, threadpool, timer); } auto assignments = std::vector(data.size()); std::vector data_norm; - if constexpr (std::is_same_v, distance::DistanceL2>) { + if constexpr (is_l2_v) { generate_norms(data, data_norm, threadpool); } std::vector centroids_norm; @@ -580,7 +588,7 @@ auto kmeans_training( auto iter_timer = timer.push_back("iteration"); auto batchsize = parameters.minibatch_size_; auto num_batches = lib::div_round_up(data.size(), batchsize); - if constexpr (std::is_same_v, distance::DistanceL2>) { + if constexpr (is_l2_v) { generate_norms(centroids_fp32, centroids_norm, threadpool); } @@ -613,7 +621,7 @@ auto kmeans_training( centroid_split(data, centroids_fp32, counts, rng, threadpool, timer); - if constexpr (std::is_same_v, distance::DistanceIP>) { + if constexpr (is_ip_like_v) { normalize_centroids(centroids_fp32, threadpool, timer); } } @@ -727,7 +735,7 @@ data::SimpleData init_centroids( template std::vector maybe_compute_norms(const Data& data, Pool& threadpool) { std::vector norms; - if constexpr (std::is_same_v, distance::DistanceL2>) { + if constexpr (is_l2_v) { generate_norms(data, norms, threadpool); } return norms; @@ -750,7 +758,7 @@ std::vector> group_assignments( /// @tparam BuildType The numeric type used for matrix operations (float, Float16, BFloat16) /// @tparam Data The dataset type /// @tparam Centroids The centroids dataset type -/// @tparam Distance The distance metric type (DistanceIP or DistanceL2) +/// @tparam Distance The distance metric type /// @tparam Pool The thread pool type /// @tparam I The integer type for cluster indices /// @@ -854,7 +862,7 @@ void search_centroids( ) { unsigned int count = 0; buffer.clear(); - if constexpr (std::is_same_v, distance::DistanceIP>) { + if constexpr (is_ip_like_v) { for (size_t j = 0; j < num_threads; j++) { auto distance = matmul_results[j].get_datum(query_id); for (size_t k = 0; k < distance.size(); k++) { @@ -862,7 +870,7 @@ void search_centroids( count++; } } - } else if constexpr (std::is_same_v, distance::DistanceL2>) { + } else if constexpr (is_l2_v) { float query_norm = distance::norm_square(query); for (size_t j = 0; j < num_threads; j++) { auto distance = matmul_results[j].get_datum(query_id); @@ -873,7 +881,9 @@ void search_centroids( } } } else { - throw ANNEXCEPTION("Only L2 and MIP distances supported in IVF search!"); + static_assert( + sizeof(Dist) == 0, "Only L2, MIP, and Cosine distances supported in IVF search!" + ); } } diff --git a/include/svs/index/ivf/data_traits.h b/include/svs/index/ivf/data_traits.h new file mode 100644 index 000000000..38f4a87ff --- /dev/null +++ b/include/svs/index/ivf/data_traits.h @@ -0,0 +1,105 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "svs/core/data/simple.h" +#include "svs/lib/saveload.h" + +#include +#include + +namespace svs::index::ivf { + +/// @brief Data type configuration for IVF save/load +/// +/// This struct holds information about the data type stored in an IVF index, +/// allowing automatic reconstruction of the appropriate loader during load. +/// +/// The base implementation supports uncompressed data (fp32, fp16, bf16). +/// Extensions (e.g., LVQ, LeanVec) can be added by including additional +/// specialization headers that specialize DataTypeTraits for their types. +struct DataTypeConfig { + // Schema identifier (e.g., "uncompressed_data", "one_level_lvq_dataset", + // "leanvec_dataset") + std::string schema; + + // For uncompressed data: element type + DataType element_type = DataType::undef; + + // Centroid type (bfloat16 or float16) - saved separately to match centroid storage + DataType centroid_type = DataType::bfloat16; + + // For LVQ: compression parameters + size_t primary_bits = 0; + size_t residual_bits = 0; + std::string strategy; // "sequential" or "turbo" + + // For LeanVec: dimensionality and encoding kinds + std::string primary_kind; // "float32", "float16", "lvq4", "lvq8" + std::string secondary_kind; // "float32", "float16", "lvq4", "lvq8" + size_t leanvec_dims = 0; + + // Serialization + static constexpr std::string_view serialization_schema = "ivf_data_type_config"; + static constexpr lib::Version save_version{0, 0, 0}; + + lib::SaveTable save() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"schema", lib::save(schema)}, + {"element_type", lib::save(element_type)}, + {"centroid_type", lib::save(centroid_type)}, + {"primary_bits", lib::save(primary_bits)}, + {"residual_bits", lib::save(residual_bits)}, + {"strategy", lib::save(strategy)}, + {"primary_kind", lib::save(primary_kind)}, + {"secondary_kind", lib::save(secondary_kind)}, + {"leanvec_dims", lib::save(leanvec_dims)}} + ); + } + + static DataTypeConfig load(const lib::ContextFreeLoadTable& table) { + DataTypeConfig config; + config.schema = lib::load_at(table, "schema"); + config.element_type = lib::load_at(table, "element_type"); + config.centroid_type = lib::load_at(table, "centroid_type"); + config.primary_bits = lib::load_at(table, "primary_bits"); + config.residual_bits = lib::load_at(table, "residual_bits"); + config.strategy = lib::load_at(table, "strategy"); + config.primary_kind = lib::load_at(table, "primary_kind"); + config.secondary_kind = lib::load_at(table, "secondary_kind"); + config.leanvec_dims = lib::load_at(table, "leanvec_dims"); + return config; + } +}; + +/// @brief Trait to extract DataTypeConfig from a data type +/// +/// Default implementation for uncompressed SimpleData. +/// Specializations for LVQ/LeanVec are provided in svs/extensions/ivf/lvq.h +/// and svs/extensions/ivf/leanvec.h respectively. +template struct DataTypeTraits { + static DataTypeConfig get_config() { + DataTypeConfig config; + config.schema = "uncompressed_data"; + config.element_type = datatype_v; + return config; + } +}; + +} // namespace svs::index::ivf diff --git a/include/svs/index/ivf/dynamic_ivf.h b/include/svs/index/ivf/dynamic_ivf.h index f6b8dbb09..9aefbfcc4 100644 --- a/include/svs/index/ivf/dynamic_ivf.h +++ b/include/svs/index/ivf/dynamic_ivf.h @@ -207,7 +207,8 @@ class DynamicIVFIndex { , logger_{std::move(logger)} { // Initialize metadata structures based on cluster contents size_t total_size = 0; - for (const auto& cluster : clusters_) { + for (size_t cluster_idx = 0; cluster_idx < clusters_.size(); ++cluster_idx) { + const auto& cluster = clusters_[cluster_idx]; for (size_t pos = 0; pos < cluster.ids_.size(); ++pos) { total_size = std::max(total_size, static_cast(cluster.ids_[pos]) + 1); @@ -725,10 +726,21 @@ class DynamicIVFIndex { // Compact before saving to remove empty slots compact(); + // Create directories + std::filesystem::create_directories(config_directory); + std::filesystem::create_directories(data_directory); + auto clusters_dir = data_directory / "clusters"; + std::filesystem::create_directories(clusters_dir); + + // Get data type configuration for automatic loader construction during load + auto data_type_config = DataTypeTraits::get_config(); + // Set the centroid type from the Centroids template parameter + data_type_config.centroid_type = datatype_v; + // Save configuration lib::save_to_disk( lib::SaveOverride([&](const lib::SaveContext& ctx) { - return lib::SaveTable( + auto table = lib::SaveTable( "dynamic_ivf_config", save_version, { @@ -737,20 +749,18 @@ class DynamicIVFIndex { {"num_clusters", lib::save(clusters_.size())}, } ); + // Insert nested table for data type config + table.insert("data_type_config", lib::save(data_type_config)); + return table; }), config_directory ); - // Save centroids and cluster data + // Save centroids lib::save_to_disk(centroids_, data_directory / "centroids"); - for (size_t i = 0; i < clusters_.size(); ++i) { - auto cluster_path = data_directory / fmt::format("cluster_{}", i); - lib::save_to_disk(clusters_[i].data_, cluster_path); - - auto ids_path = data_directory / fmt::format("cluster_ids_{}", i); - lib::save_to_disk(clusters_[i].ids_, ids_path); - } + // Save clustered dataset + lib::save_to_disk(clusters_, clusters_dir); } private: @@ -786,7 +796,7 @@ class DynamicIVFIndex { } void initialize_distance_metadata() { - if constexpr (std::is_same_v, distance::DistanceL2>) { + if constexpr (is_l2_v) { centroids_norm_.reserve(centroids_.size()); for (size_t i = 0; i < centroids_.size(); ++i) { centroids_norm_.push_back(distance::norm_square(centroids_.get_datum(i))); @@ -1063,4 +1073,100 @@ auto assemble_dynamic_from_clustering( ); } +/// @brief Load a saved DynamicIVFIndex from disk +/// +/// This function loads a previously saved DynamicIVFIndex, including centroids, +/// clustered dataset, and ID translation table. +/// +/// @tparam CentroidType Element type of centroids (e.g., float, Float16) +/// @tparam DataType The full type of cluster data (e.g., BlockedData) +/// @tparam Distance Distance metric type +/// @tparam ThreadpoolProto Thread pool prototype type +/// +/// @param config_path Path to the saved index configuration directory +/// @param data_path Path to the saved data directory (centroids and clusters) +/// @param distance Distance metric for searching +/// @param threadpool_proto Thread pool for parallel processing +/// @param intra_query_thread_count Number of threads for intra-query parallelism (default: +/// 1) +/// @param logger Logger for logging customization +/// +/// @return Fully constructed DynamicIVFIndex ready for searching and modifications +/// +template < + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> +auto load_dynamic_ivf_index( + const std::filesystem::path& config_path, + const std::filesystem::path& data_path, + Distance distance, + ThreadpoolProto threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() +) { + // Initialize timer for performance tracking + auto timer = lib::Timer(); + auto load_timer = timer.push_back("Total loading time"); + + // Initialize thread pool + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + + // Load configuration to get translator + auto config_timer = timer.push_back("Loading configuration"); + auto serialized = lib::begin_deserialization(config_path); + auto table = serialized.cast(); + auto translator = lib::load_at(table, "translation"); + config_timer.finish(); + + // Load centroids + auto centroids_timer = timer.push_back("Loading centroids"); + using centroids_type = data::SimpleData; + auto centroids = lib::load_from_disk(data_path / "centroids"); + centroids_timer.finish(); + + // Define cluster types - use lib_blocked_alloc_data_type pattern for proper allocator + // This uses lib::Allocator instead of potentially HugepageAllocator + using I = uint32_t; + using blocked_data_type = typename DataType::lib_blocked_alloc_data_type; + using cluster_type = DenseClusteredDataset; + + auto clusters_timer = timer.push_back("Loading clusters"); + + auto clusters_dir = data_path / "clusters"; + + // Use a small block size for IVF clusters (1MB instead of 1GB default) + // This prevents excessive memory allocation when loading many clusters + auto blocking_params = data::BlockingParameters{ + .blocksize_bytes = lib::PowerOfTwo(20) // 2^20 = 1MB + }; + using allocator_type = typename blocked_data_type::allocator_type; + auto blocked_allocator = + allocator_type(blocking_params, typename allocator_type::allocator_type()); + + auto dense_clusters = + lib::load_from_disk(clusters_dir, threadpool, blocked_allocator); + clusters_timer.finish(); + + // Create the index with the translator constructor + auto index_timer = timer.push_back("Index construction"); + auto index = + DynamicIVFIndex( + std::move(centroids), + std::move(dense_clusters), + std::move(translator), + std::move(distance), + std::move(threadpool), + intra_query_thread_count, + logger + ); + index_timer.finish(); + + load_timer.finish(); + svs::logging::debug(logger, "{}", timer); + + return index; +} + } // namespace svs::index::ivf diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index b4a23b0f6..dc983921d 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -21,6 +21,7 @@ #include "svs/core/loading.h" #include "svs/core/query_result.h" #include "svs/index/ivf/clustering.h" +#include "svs/index/ivf/data_traits.h" #include "svs/index/ivf/extensions.h" #include "svs/index/ivf/hierarchical_kmeans.h" #include "svs/index/ivf/kmeans.h" @@ -442,6 +443,69 @@ class IVFIndex { ); } + ///// Saving ///// + + /// @brief Indicates that the index supports saving + static constexpr bool supports_saving = true; + + /// @brief Save version for the IVF index + static constexpr lib::Version save_version = lib::Version(0, 0, 0); + + /// @brief Serialization schema identifier + static constexpr std::string_view serialization_schema = "ivf_index"; + + /// @brief Save the IVF index to disk + /// + /// This saves all components needed to reconstruct the index: + /// - Centroids + /// - Clustered dataset (DenseClusteredDataset) + /// - Configuration (number of clusters) + /// + /// @param config_directory Directory where the index configuration will be saved. + /// @param data_directory Directory where the centroids and cluster data will be saved. + /// + /// Each directory may be created as a side-effect of this method call provided that + /// the parent directory exists. + /// + void save( + const std::filesystem::path& config_directory, + const std::filesystem::path& data_directory + ) const { + // Create directories if they don't exist + auto centroids_dir = data_directory / "centroids"; + auto clusters_dir = data_directory / "clusters"; + std::filesystem::create_directories(config_directory); + std::filesystem::create_directories(centroids_dir); + std::filesystem::create_directories(clusters_dir); + + // Get data type configuration for automatic loader construction during load + auto data_type_config = DataTypeTraits::get_config(); + // Set the centroid type from the Centroids template parameter + data_type_config.centroid_type = datatype_v; + + // Save configuration + lib::save_to_disk( + lib::SaveOverride([&]() { + auto table = lib::SaveTable( + serialization_schema, + save_version, + {{"name", lib::save(name())}, + {"num_clusters", lib::save(num_clusters())}} + ); + // Insert nested table for data type config + table.insert("data_type_config", lib::save(data_type_config)); + return table; + }), + config_directory + ); + + // Save centroids + lib::save_to_disk(centroids_, centroids_dir); + + // Save clustered dataset + lib::save_to_disk(cluster_, clusters_dir); + } + private: ///// Core Components ///// Centroids centroids_; @@ -506,7 +570,7 @@ class IVFIndex { void initialize_distance_metadata() { // Precalculate centroid norms for L2 distance - if constexpr (std::is_same_v, distance::DistanceL2>) { + if constexpr (is_l2_v) { centroids_norm_.reserve(centroids_.size()); for (size_t i = 0; i < centroids_.size(); i++) { centroids_norm_.push_back(distance::norm_square(centroids_.get_datum(i))); @@ -764,4 +828,76 @@ auto assemble_from_file( ); } +/// @brief Load a saved IVF index from disk +/// +/// This function loads a previously saved IVF index, including centroids and +/// clustered dataset. Unlike assemble_from_clustering which requires the original +/// data, this function loads the pre-built index directly. +/// +/// @tparam CentroidType Element type of centroids (e.g., float, Float16) +/// @tparam DataType The element type of cluster data (e.g., float) +/// @tparam Distance Distance metric type +/// @tparam ThreadpoolProto Thread pool prototype type +/// +/// @param config_path Path to the saved index configuration directory +/// @param data_path Path to the saved data directory (centroids and clusters) +/// @param distance Distance metric for searching +/// @param threadpool_proto Thread pool for parallel processing +/// @param intra_query_thread_count Number of threads for intra-query parallelism (default: +/// 1) +/// @param logger Logger for logging customization +/// +/// @return Fully constructed IVF index ready for searching +/// +template < + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> +auto load_ivf_index( + const std::filesystem::path& SVS_UNUSED(config_path), + const std::filesystem::path& data_path, + Distance distance, + ThreadpoolProto threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() +) { + // Initialize timer for performance tracking + auto timer = lib::Timer(); + auto load_timer = timer.push_back("Total loading time"); + + // Initialize thread pool + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + + // Load centroids + auto centroids_timer = timer.push_back("Loading centroids"); + using centroids_type = data::SimpleData; + auto centroids = lib::load_from_disk(data_path / "centroids"); + centroids_timer.finish(); + + // Load clustered dataset + auto clusters_timer = timer.push_back("Loading clusters"); + using data_type = typename DataType::lib_alloc_data_type; + using cluster_type = DenseClusteredDataset; + auto clusters = lib::load_from_disk(data_path / "clusters", threadpool); + clusters_timer.finish(); + + // Construct the IVF index + auto index_timer = timer.push_back("Index construction"); + auto index = IVFIndex( + std::move(centroids), + std::move(clusters), + std::move(distance), + std::move(threadpool), + intra_query_thread_count, + logger + ); + index_timer.finish(); + + load_timer.finish(); + svs::logging::debug(logger, "{}", timer); + + return index; +} + } // namespace svs::index::ivf diff --git a/include/svs/orchestrators/dynamic_ivf.h b/include/svs/orchestrators/dynamic_ivf.h index de1322b41..0b3592ab6 100644 --- a/include/svs/orchestrators/dynamic_ivf.h +++ b/include/svs/orchestrators/dynamic_ivf.h @@ -57,6 +57,8 @@ class DynamicIVFInterface : public IVFInterface { const std::filesystem::path& config_directory, const std::filesystem::path& data_directory ) = 0; + + virtual void save(std::ostream& stream) = 0; }; template @@ -119,6 +121,16 @@ class DynamicIVFImpl : public IVFImpl { ) override { impl().save(config_directory, data_directory); } + + void save(std::ostream& stream) override { + lib::UniqueTempDirectory tempdir{"svs_dynamic_ivf_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + save(config_dir, data_dir); + lib::DirectoryArchiver::pack(tempdir, stream); + } }; // Forward Declarations. @@ -209,6 +221,18 @@ class DynamicIVF : public manager::IndexManager { impl_->save(config_directory, data_directory); } + /// + /// @brief Save the DynamicIVF index to a stream. + /// + /// @param stream Output stream to save the index to. + /// + /// The index is saved in a binary format that can be loaded using the + /// stream-based ``assemble`` method. + /// + /// @sa assemble + /// + void save(std::ostream& stream) const { impl_->save(stream); } + ///// Distance template double get_distance(size_t id, const Query& query) const { // Create AnonymousArray from the query @@ -315,6 +339,121 @@ class DynamicIVF : public manager::IndexManager { intra_query_threads ); } + + /// + /// @brief Load a saved DynamicIVF index from disk. + /// + /// This method restores a DynamicIVF index that was previously saved using `save()`. + /// + /// @tparam QueryTypes The query types supported by the returned index. + /// @tparam CentroidType Element type of centroids (e.g., float, BFloat16). + /// @tparam DataType Full cluster data type (e.g., BlockedData). + /// + /// @param config_path Path to the saved configuration directory. + /// @param data_path Path to the saved data directory (centroids and clusters). + /// @param distance Distance metric for searching. + /// @param threadpool_proto Thread pool prototype for parallel processing. + /// @param intra_query_threads Number of threads for intra-query parallelism. + /// + /// @return A fully constructed DynamicIVF ready for searching and modifications. + /// + /// @sa save, assemble_from_file + /// + template < + manager::QueryTypeDefinition QueryTypes, + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadPoolProto> + static DynamicIVF assemble( + const std::filesystem::path& config_path, + const std::filesystem::path& data_path, + Distance distance, + ThreadPoolProto threadpool_proto, + size_t intra_query_threads = 1 + ) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return DynamicIVF( + AssembleTag(), + manager::as_typelist{}, + index::ivf::load_dynamic_ivf_index( + config_path, + data_path, + std::move(distance_function), + std::move(threadpool), + intra_query_threads + ) + ); + }); + } else { + return DynamicIVF( + AssembleTag(), + manager::as_typelist{}, + index::ivf::load_dynamic_ivf_index( + config_path, + data_path, + distance, + std::move(threadpool), + intra_query_threads + ) + ); + } + } + + /// + /// @brief Load a DynamicIVF index from a stream. + /// + /// @tparam QueryTypes The query types supported by the returned index. + /// @tparam CentroidType Element type of centroids (e.g., float, BFloat16). + /// @tparam DataType Full cluster data type (e.g., SimpleData). + /// + /// @param stream Input stream to load the index from. + /// @param distance Distance metric for searching. + /// @param threadpool_proto Thread pool prototype for parallel processing. + /// @param intra_query_threads Number of threads for intra-query parallelism. + /// + /// @return A fully constructed DynamicIVF ready for searching and modifications. + /// + /// @sa save + /// + template < + manager::QueryTypeDefinition QueryTypes, + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadPoolProto> + static DynamicIVF assemble( + std::istream& stream, + Distance distance, + ThreadPoolProto threadpool_proto, + size_t intra_query_threads = 1 + ) { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_dynamic_ivf_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid DynamicIVF index archive: missing config directory!" + ); + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid DynamicIVF index archive: missing data directory!"); + } + + return assemble( + config_path, + data_path, + distance, + std::move(threadpool_proto), + intra_query_threads + ); + } }; } // namespace svs diff --git a/include/svs/orchestrators/ivf.h b/include/svs/orchestrators/ivf.h index b45ab7edb..b020a8100 100644 --- a/include/svs/orchestrators/ivf.h +++ b/include/svs/orchestrators/ivf.h @@ -36,6 +36,13 @@ class IVFInterface { virtual IVFIterator batch_iterator( svs::AnonymousArray<1> query, size_t extra_search_buffer_capacity = 0 ) = 0; + + ///// Saving + virtual void save( + const std::filesystem::path& config_dir, const std::filesystem::path& data_dir + ) = 0; + + virtual void save(std::ostream& stream) = 0; }; template @@ -95,6 +102,31 @@ class IVFImpl : public manager::ManagerImpl { } ); } + + ///// Saving + void save( + const std::filesystem::path& config_dir, const std::filesystem::path& data_dir + ) override { + if constexpr (Impl::supports_saving) { + impl().save(config_dir, data_dir); + } else { + throw ANNEXCEPTION("The current IVF backend doesn't support saving!"); + } + } + + void save(std::ostream& stream) override { + if constexpr (Impl::supports_saving) { + lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + save(config_dir, data_dir); + lib::DirectoryArchiver::pack(tempdir, stream); + } else { + throw ANNEXCEPTION("The current IVF backend doesn't support saving!"); + } + } }; ///// @@ -150,6 +182,37 @@ class IVF : public manager::IndexManager { ); } + ///// Saving + /// + /// @brief Save the IVF index to disk. + /// + /// @param config_directory Directory where the index configuration will be saved. + /// @param data_directory Directory where the centroids and cluster data will be saved. + /// + /// Each directory may be created as a side-effect of this method call provided that + /// the parent directory exists. + /// + /// @sa assemble + /// + void save( + const std::filesystem::path& config_directory, + const std::filesystem::path& data_directory + ) { + impl_->save(config_directory, data_directory); + } + + /// + /// @brief Save the IVF index to a stream. + /// + /// @param stream Output stream to save the index to. + /// + /// The index is saved in a binary format that can be loaded using the + /// stream-based ``assemble`` method. + /// + /// @sa assemble + /// + void save(std::ostream& stream) const { impl_->save(stream); } + ///// Assembling template < manager::QueryTypeDefinition QueryTypes, @@ -223,6 +286,123 @@ class IVF : public manager::IndexManager { ); } + /// + /// @brief Load an IVF Index from a previously saved index. + /// + /// @tparam QueryTypes The element types of queries that will be used when requesting + /// searches over the index. Can be a single type or a ``svs::lib::Types``. + /// @tparam CentroidType The element type of the centroids. + /// @tparam DataType The element type of the cluster data. + /// + /// @param config_path Path to the directory where the index configuration was saved. + /// This corresponds to the ``config_directory`` argument of ``svs::IVF::save``. + /// @param data_path Path to the directory where the centroids and cluster data were + /// saved. This corresponds to the ``data_directory`` argument of + /// ``svs::IVF::save``. + /// @param distance The distance functor or ``svs::DistanceType`` enum to use for + /// similarity search computations. + /// @param threadpool_proto Precursor for the thread pool to use. Can either be an + /// acceptable thread pool instance or an integer specifying the number of threads + /// to use. + /// @param intra_query_threads Number of threads for intra-query parallelism. + /// + /// @sa save, assemble_from_file + /// + template < + manager::QueryTypeDefinition QueryTypes, + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> + static IVF assemble( + const std::filesystem::path& config_path, + const std::filesystem::path& data_path, + const Distance& distance, + ThreadpoolProto threadpool_proto, + size_t intra_query_threads = 1 + ) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + config_path, + data_path, + std::move(distance_function), + std::move(threadpool), + intra_query_threads + ) + ); + }); + } else { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + config_path, + data_path, + distance, + std::move(threadpool), + intra_query_threads + ) + ); + } + } + + /// + /// @brief Load an IVF Index from a stream. + /// + /// @tparam QueryTypes The element types of queries that will be used when requesting + /// searches over the index. Can be a single type or a ``svs::lib::Types``. + /// @tparam CentroidType The element type of the centroids. + /// @tparam DataType The element type of the cluster data. + /// + /// @param stream Input stream to load the index from. + /// @param distance The distance functor or ``svs::DistanceType`` enum to use for + /// similarity search computations. + /// @param threadpool_proto Precursor for the thread pool to use. + /// @param intra_query_threads Number of threads for intra-query parallelism. + /// + /// @sa save + /// + template < + manager::QueryTypeDefinition QueryTypes, + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> + static IVF assemble( + std::istream& stream, + const Distance& distance, + ThreadpoolProto threadpool_proto, + size_t intra_query_threads = 1 + ) { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); + } + + return assemble( + config_path, + data_path, + distance, + std::move(threadpool_proto), + intra_query_threads + ); + } + ///// Building template static auto build_clustering( diff --git a/tests/integration/ivf/dynamic_scalar.cpp b/tests/integration/ivf/dynamic_scalar.cpp index df6a761fa..37258629d 100644 --- a/tests/integration/ivf/dynamic_scalar.cpp +++ b/tests/integration/ivf/dynamic_scalar.cpp @@ -33,9 +33,14 @@ #include "fmt/core.h" // stl +#include +#include #include #include +// tests +#include "tests/utils/utils.h" + namespace sc = svs::quantization::scalar; namespace { @@ -225,3 +230,254 @@ CATCH_TEST_CASE( CATCH_SECTION("int8 stress test") { test_dynamic_ivf_scalar_stress(distance); } } + +CATCH_TEST_CASE("Dynamic IVF Save and Load", "[integration][dynamic_ivf][saveload]") { + namespace ivf = svs::index::ivf; + + size_t num_threads = 2; + size_t intra_query_threads = 1; + + auto distance = svs::DistanceL2(); + + // Load test dataset - use uncompressed float data for this test since scalar + // quantized data doesn't support all the operations needed for save/load + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + auto gt = test_dataset::groundtruth_euclidean(); + + // Build clustering on data + auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, data, distance, threadpool, false + ); + + // Generate external IDs for the data + std::vector ids(data.size()); + std::iota(ids.begin(), ids.end(), 0); + + auto index = svs::DynamicIVF::assemble_from_clustering( + std::move(clustering), + data, + ids, + distance, + svs::threads::as_threadpool(num_threads), + intra_query_threads + ); + + CATCH_REQUIRE(index.size() == data.size()); + CATCH_REQUIRE(index.dimensions() == EXTENT); + + // Set search parameters + auto search_params = ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS); + index.set_search_parameters(search_params); + + // Run search on original index + auto original_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search( + original_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + auto original_recall = + svs::k_recall_at_n(gt, original_results, NUM_NEIGHBORS, NUM_NEIGHBORS); + CATCH_REQUIRE(original_recall > 0.9); + + // Prepare temp directory for save/load tests + auto temp_dir = svs_test::temp_directory(); + svs_test::prepare_temp_directory(); + + // Lambda to verify loaded index + auto verify_loaded_index = [&](svs::DynamicIVF& loaded_index) { + // Verify the loaded index has correct properties + CATCH_REQUIRE(loaded_index.size() == data.size()); + CATCH_REQUIRE(loaded_index.dimensions() == EXTENT); + + // Note: Search parameters are not persisted during save/load, + // so we set them again for the loaded index + loaded_index.set_search_parameters(search_params); + + // Run search on loaded index - should produce same results + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search( + loaded_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + auto loaded_recall = + svs::k_recall_at_n(gt, loaded_results, NUM_NEIGHBORS, NUM_NEIGHBORS); + CATCH_REQUIRE(loaded_recall > 0.9); + + // Verify the results are similar + CATCH_REQUIRE(std::abs(original_recall - loaded_recall) < 0.01); + }; + + CATCH_SECTION("Directory-based save/load") { + auto config_dir = temp_dir / "config"; + auto data_dir = temp_dir / "data"; + + // Save the index to directories + index.save(config_dir, data_dir); + + // Load the index from directories + auto loaded_index = + svs::DynamicIVF::assemble>( + config_dir, data_dir, distance, num_threads, intra_query_threads + ); + + verify_loaded_index(loaded_index); + } + + CATCH_SECTION("Stream-based save/load") { + auto file = temp_dir / "dynamic_ivf_index.bin"; + + // Save the index to a stream + { + std::ofstream file_ostream(file, std::ios::binary); + CATCH_REQUIRE(file_ostream.good()); + index.save(file_ostream); + file_ostream.close(); + } + + // Load the index from the stream + std::ifstream file_istream(file, std::ios::binary); + CATCH_REQUIRE(file_istream.good()); + auto loaded_index = + svs::DynamicIVF::assemble>( + file_istream, distance, num_threads, intra_query_threads + ); + + verify_loaded_index(loaded_index); + } +} + +CATCH_TEST_CASE( + "Dynamic IVF SQDataset Save and Load", "[integration][dynamic_ivf][scalar][saveload]" +) { + namespace ivf = svs::index::ivf; + + size_t num_threads = 2; + size_t intra_query_threads = 1; + + auto distance = svs::DistanceL2(); + + // Load test dataset + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + auto gt = test_dataset::groundtruth_euclidean(); + + // Build clustering on uncompressed data + auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, data, distance, threadpool, false + ); + + // Compress the data with Scalar Quantization + auto compressed_data = sc::SQDataset::compress(data); + + // Generate external IDs for the data + std::vector ids(data.size()); + std::iota(ids.begin(), ids.end(), 0); + + auto index = svs::DynamicIVF::assemble_from_clustering( + std::move(clustering), + compressed_data, + ids, + distance, + svs::threads::as_threadpool(num_threads), + intra_query_threads + ); + + CATCH_REQUIRE(index.size() == data.size()); + CATCH_REQUIRE(index.dimensions() == EXTENT); + + // Set search parameters + auto search_params = ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS); + index.set_search_parameters(search_params); + + // Run search on original index + auto original_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search( + original_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + auto original_recall = + svs::k_recall_at_n(gt, original_results, NUM_NEIGHBORS, NUM_NEIGHBORS); + CATCH_REQUIRE(original_recall > 0.9); + + // Prepare temp directory for save/load tests + auto temp_dir = svs_test::temp_directory(); + svs_test::prepare_temp_directory(); + + // Lambda to verify loaded index + auto verify_loaded_index = [&](svs::DynamicIVF& loaded_index) { + // Verify the loaded index has correct properties + CATCH_REQUIRE(loaded_index.size() == data.size()); + CATCH_REQUIRE(loaded_index.dimensions() == EXTENT); + + // Set search parameters for loaded index + loaded_index.set_search_parameters(search_params); + + // Run search on loaded index - should produce same results + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search( + loaded_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + auto loaded_recall = + svs::k_recall_at_n(gt, loaded_results, NUM_NEIGHBORS, NUM_NEIGHBORS); + CATCH_REQUIRE(loaded_recall > 0.9); + + // Verify the results are similar (within 1% tolerance) + CATCH_REQUIRE(std::abs(original_recall - loaded_recall) < 0.01); + }; + + using SQData = sc::SQDataset; + + CATCH_SECTION("Directory-based save/load") { + auto config_dir = temp_dir / "config"; + auto data_dir = temp_dir / "data"; + + // Save the index to directories + index.save(config_dir, data_dir); + + // Load the index from directories + auto loaded_index = svs::DynamicIVF::assemble( + config_dir, data_dir, distance, num_threads, intra_query_threads + ); + + verify_loaded_index(loaded_index); + } + + CATCH_SECTION("Stream-based save/load") { + auto file = temp_dir / "dynamic_ivf_sq_index.bin"; + + // Save the index to a stream + { + std::ofstream file_ostream(file, std::ios::binary); + CATCH_REQUIRE(file_ostream.good()); + index.save(file_ostream); + file_ostream.close(); + } + + // Load the index from the stream + std::ifstream file_istream(file, std::ios::binary); + CATCH_REQUIRE(file_istream.good()); + auto loaded_index = svs::DynamicIVF::assemble( + file_istream, distance, num_threads, intra_query_threads + ); + + verify_loaded_index(loaded_index); + } +} diff --git a/tests/integration/ivf/index_build.cpp b/tests/integration/ivf/index_build.cpp index 3c6a11d90..c22b50443 100644 --- a/tests/integration/ivf/index_build.cpp +++ b/tests/integration/ivf/index_build.cpp @@ -215,3 +215,56 @@ CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf][train_only]") test_build_train_only(svs::DistanceL2()); test_build_train_only(svs::DistanceIP()); } + +// Test Cosine distance (uses IP internally for clustering) +CATCH_TEST_CASE("IVF Build/Clustering Cosine", "[integration][build][ivf][cosine]") { + const auto queries = svs::data::SimpleData::load(test_dataset::query_file()); + CATCH_REQUIRE(svs_test::prepare_temp_directory()); + size_t num_threads = 2; + size_t num_inner_threads = 1; + + // Use simple build parameters for Cosine test + svs::index::ivf::IVFBuildParameters parameters; + parameters.num_centroids_ = 50; + parameters.num_iterations_ = 5; + parameters.is_hierarchical_ = true; + parameters.training_fraction_ = 0.5; + + // Build index with Cosine distance + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto clustering = svs::IVF::build_clustering( + parameters, data, svs::DistanceCosineSimilarity(), num_threads + ); + + auto index = svs::IVF::assemble_from_clustering( + std::move(clustering), + std::move(data), + svs::DistanceCosineSimilarity(), + num_threads, + num_inner_threads + ); + + // Verify index was built correctly + CATCH_REQUIRE(index.size() == test_dataset::VECTORS_IN_DATA_SET); + CATCH_REQUIRE(index.dimensions() == test_dataset::NUM_DIMENSIONS); + + // Test search with different n_probes values + auto groundtruth = test_dataset::load_groundtruth(svs::DistanceType::Cosine); + svs::index::ivf::IVFSearchParameters search_params; + + // Test with n_probes = 10 + search_params.n_probes_ = 10; + index.set_search_parameters(search_params); + auto results = index.search(queries, 10); + double recall = svs::k_recall_at_n(groundtruth, results, 10, 10); + fmt::print("Cosine - n_probes: {}, Recall@10: {}\n", search_params.n_probes_, recall); + CATCH_REQUIRE(recall > 0.3); // Basic sanity check + + // Test with n_probes = 30 + search_params.n_probes_ = 30; + index.set_search_parameters(search_params); + results = index.search(queries, 10); + recall = svs::k_recall_at_n(groundtruth, results, 10, 10); + fmt::print("Cosine - n_probes: {}, Recall@10: {}\n", search_params.n_probes_, recall); + CATCH_REQUIRE(recall > 0.6); // Higher n_probes should give better recall +} diff --git a/tests/integration/ivf/index_search.cpp b/tests/integration/ivf/index_search.cpp index 7de26e2e7..a262b411d 100644 --- a/tests/integration/ivf/index_search.cpp +++ b/tests/integration/ivf/index_search.cpp @@ -269,3 +269,106 @@ CATCH_TEST_CASE( CATCH_REQUIRE(error_count == 0); CATCH_REQUIRE(success_count == NUM_TEST_THREADS * CALLS_PER_THREAD); } + +CATCH_TEST_CASE("IVF Save and Load", "[integration][ivf][saveload]") { + namespace ivf = svs::index::ivf; + + auto datafile = test_dataset::data_svs_file(); + auto queries = test_dataset::queries(); + auto gt_l2 = test_dataset::groundtruth_euclidean(); + auto dist_l2 = svs::distance::DistanceL2(); + + auto data = svs::data::SimpleData::load(datafile); + + // Find the expected results for this dataset. + auto expected_result = test_dataset::ivf::expected_search_results( + svs::distance_type_v, + svsbenchmark::Uncompressed(svs::DataType::float32) + ); + + size_t num_threads = 2; + size_t intra_query_threads = 1; + + // Build and run the original index + // Note: The pre-built clustering uses BFloat16 centroids, so we use that as the + // second template parameter + auto index = svs::IVF::assemble_from_file( + test_dataset::clustering_directory(), + data, + dist_l2, + num_threads, + intra_query_threads + ); + + CATCH_REQUIRE(index.size() == test_dataset::VECTORS_IN_DATA_SET); + CATCH_REQUIRE(index.dimensions() == test_dataset::NUM_DIMENSIONS); + + // Run search on original index to verify it works + run_search(index, queries, gt_l2, expected_result.config_and_recall_); + + // Set some search parameters to verify they're saved + ivf::IVFSearchParameters params; + params.n_probes_ = 5; + params.k_reorder_ = 2.0; + index.set_search_parameters(params); + + // Prepare temp directory for save/load tests + auto temp_dir = svs_test::temp_directory(); + svs_test::prepare_temp_directory(); + + // Lambda to verify loaded index + auto verify_loaded_index = [&](svs::IVF& loaded_index) { + // Verify the loaded index has correct properties + CATCH_REQUIRE(loaded_index.size() == test_dataset::VECTORS_IN_DATA_SET); + CATCH_REQUIRE(loaded_index.dimensions() == test_dataset::NUM_DIMENSIONS); + + // Search parameters are not persisted (they are runtime configurations) + // Set them on the loaded index before searching + loaded_index.set_search_parameters(params); + + // Run search on loaded index - should produce same results + run_search(loaded_index, queries, gt_l2, expected_result.config_and_recall_); + }; + + CATCH_SECTION("Directory-based save/load") { + auto config_dir = temp_dir / "config"; + auto data_dir = temp_dir / "data"; + + // Save the index to directories + index.save(config_dir, data_dir); + + // Load the index back + // Centroids were saved as BFloat16 (from the original clustering), but cluster + // data is float (from our input data) + using DataType = + svs::data::SimpleData>; + auto loaded_index = svs::IVF::assemble( + config_dir, data_dir, dist_l2, num_threads, intra_query_threads + ); + + verify_loaded_index(loaded_index); + } + + CATCH_SECTION("Stream-based save/load") { + auto file = temp_dir / "ivf_index.bin"; + + // Save the index to a stream + { + std::ofstream file_ostream(file, std::ios::binary); + CATCH_REQUIRE(file_ostream.good()); + index.save(file_ostream); + file_ostream.close(); + } + + // Load the index from the stream + std::ifstream file_istream(file, std::ios::binary); + CATCH_REQUIRE(file_istream.good()); + using DataType = + svs::data::SimpleData>; + auto loaded_index = svs::IVF::assemble( + file_istream, dist_l2, num_threads, intra_query_threads + ); + + verify_loaded_index(loaded_index); + } +} diff --git a/tests/svs/index/ivf/dynamic_ivf.cpp b/tests/svs/index/ivf/dynamic_ivf.cpp index 89c6308de..9ae9cfd6a 100644 --- a/tests/svs/index/ivf/dynamic_ivf.cpp +++ b/tests/svs/index/ivf/dynamic_ivf.cpp @@ -37,6 +37,8 @@ #include #include #include +#include +#include #include #include @@ -1135,3 +1137,151 @@ CATCH_TEST_CASE("Dynamic IVF Single Search", "[ivf][dynamic_ivf][single_search]" } } } + +CATCH_TEST_CASE("Dynamic IVF Save and Load", "[dynamic_ivf][saveload]") { + const size_t num_threads = 4; + + // Load data and queries + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + // Build clustering + auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, + svs::lib::Lazy([&data]() { return data; }), + Distance(), + threadpool, + false + ); + + // Create initial indices for all data points + std::vector initial_indices; + for (size_t c = 0; c < clustering.size(); ++c) { + for (auto idx : clustering.cluster(c)) { + initial_indices.push_back(idx); + } + } + + // Create the dynamic IVF index using DenseClusteredDataset + auto centroids = clustering.centroids(); + using DataType = svs::data::SimpleData; + auto dense_clusters = + svs::index::ivf::DenseClusteredDataset( + clustering, data, threadpool, svs::lib::Allocator() + ); + + auto threadpool_for_index = svs::threads::as_threadpool(num_threads); + using IndexType = svs::index::ivf::DynamicIVFIndex< + decltype(centroids), + decltype(dense_clusters), + Distance, + decltype(threadpool_for_index)>; + + auto index = IndexType( + std::move(centroids), + std::move(dense_clusters), + initial_indices, + Distance(), + std::move(threadpool_for_index), + 1 // intra_query_threads + ); + + // Configure search parameters + auto search_params = svs::index::ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS); + index.set_search_parameters(search_params); + + // Perform initial search to get baseline results + auto original_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search( + original_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + // Create temporary directories for saving + auto temp_dir = svs_test::temp_directory(); + svs_test::prepare_temp_directory(); + auto config_dir = temp_dir / "config"; + auto data_dir = temp_dir / "data"; + + // Save the index + index.save(config_dir, data_dir); + + // Verify saved files exist + CATCH_REQUIRE(std::filesystem::exists(config_dir)); + CATCH_REQUIRE(std::filesystem::exists(data_dir / "centroids")); + // Verify format files exist in clusters/ subdirectory + // DenseClusteredDataset saves: clusters_archive.bin, ids.bin, cluster_sizes.bin, + // ids_offsets.bin + CATCH_REQUIRE(std::filesystem::exists(data_dir / "clusters" / "clusters_archive.bin")); + CATCH_REQUIRE(std::filesystem::exists(data_dir / "clusters" / "ids.bin")); + CATCH_REQUIRE(std::filesystem::exists(data_dir / "clusters" / "cluster_sizes.bin")); + CATCH_REQUIRE(std::filesystem::exists(data_dir / "clusters" / "ids_offsets.bin")); + + // Load the index back using the load function + auto loaded_index = svs::index::ivf::load_dynamic_ivf_index( + config_dir, + data_dir, + Distance(), + svs::threads::as_threadpool(num_threads), + 1 // intra_query_threads + ); + + // Set search parameters on loaded index + loaded_index.set_search_parameters(search_params); + + // Perform search on loaded index + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search( + loaded_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + // Verify results match + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t k = 0; k < NUM_NEIGHBORS; ++k) { + CATCH_REQUIRE(original_results.index(q, k) == loaded_results.index(q, k)); + CATCH_REQUIRE(original_results.distance(q, k) == loaded_results.distance(q, k)); + } + } + + // Verify index properties are preserved + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.num_clusters() == index.num_clusters()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + // Test that dynamic operations still work after loading + // Delete some points + std::vector ids_to_delete; + for (size_t i = 0; i < 10 && i < initial_indices.size(); ++i) { + ids_to_delete.push_back(initial_indices[i]); + } + size_t deleted = loaded_index.delete_entries(ids_to_delete); + CATCH_REQUIRE(deleted == ids_to_delete.size()); + CATCH_REQUIRE(loaded_index.size() == index.size() - deleted); + + // Compact and verify + loaded_index.compact(1000); + + // Search should still work after modifications + loaded_index.search( + loaded_results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + search_params + ); + + // Verify we still get valid results + size_t valid_results = 0; + for (size_t i = 0; i < loaded_results.n_queries(); ++i) { + if (loaded_results.index(i, 0) != std::numeric_limits::max()) { + valid_results++; + } + } + CATCH_REQUIRE(valid_results > 0); +} diff --git a/tests/svs/index/ivf/index.cpp b/tests/svs/index/ivf/index.cpp index 39c264707..6724dc3b3 100644 --- a/tests/svs/index/ivf/index.cpp +++ b/tests/svs/index/ivf/index.cpp @@ -19,6 +19,7 @@ // tests #include "tests/utils/test_dataset.h" +#include "tests/utils/utils.h" // catch #include "catch2/catch_test_macros.hpp" @@ -28,6 +29,7 @@ #include "svs/core/distance.h" #include "svs/index/ivf/clustering.h" #include "svs/index/ivf/hierarchical_kmeans.h" +#include "svs/lib/saveload.h" // stl #include @@ -170,3 +172,170 @@ CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") { } } } + +CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") { + namespace ivf = svs::index::ivf; + + // Load test data + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + size_t num_clusters = 10; + size_t num_threads = 2; + size_t num_inner_threads = 2; + auto distance = svs::distance::DistanceL2(); + + // Build clustering + auto build_params = ivf::IVFBuildParameters(num_clusters, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = + ivf::build_clustering(build_params, data, distance, threadpool, false); + + // Create clustered dataset + auto centroids = clustering.centroids(); + using Idx = uint32_t; + auto cluster = ivf::DenseClusteredDataset( + clustering, data, threadpool, svs::lib::Allocator() + ); + + // Build IVF index + auto threadpool_for_index = svs::threads::as_threadpool(num_threads); + using IndexType = ivf::IVFIndex< + decltype(centroids), + decltype(cluster), + decltype(distance), + decltype(threadpool_for_index)>; + + auto index = IndexType( + std::move(centroids), + std::move(cluster), + distance, + std::move(threadpool_for_index), + num_inner_threads + ); + + // Get search results before saving + auto search_params = ivf::IVFSearchParameters(); + search_params.n_probes_ = 5; + search_params.k_reorder_ = 100; + size_t num_neighbors = 10; + + auto batch_queries = svs::data::ConstSimpleDataView( + queries.data(), queries.size(), queries.dimensions() + ); + auto original_results = svs::QueryResult(queries.size(), num_neighbors); + index.search(original_results.view(), batch_queries, search_params); + + CATCH_SECTION("Save and load IVF index") { + // Prepare temp directory + auto tempdir = svs_test::prepare_temp_directory_v2(); + auto config_dir = tempdir / "config"; + auto data_dir = tempdir / "data"; + + // Save the index + index.save(config_dir, data_dir); + + // Verify files exist + CATCH_REQUIRE(std::filesystem::exists(config_dir)); + CATCH_REQUIRE(std::filesystem::exists(data_dir / "centroids")); + CATCH_REQUIRE(std::filesystem::exists(data_dir / "clusters")); + + // Load the index + using DataType = + svs::data::SimpleData>; + auto loaded_index = ivf::load_ivf_index( + config_dir, + data_dir, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + // Verify index properties + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.num_clusters() == index.num_clusters()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + // Search with loaded index + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), batch_queries, search_params); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == original_results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + + // Cleanup + svs_test::cleanup_temp_directory(); + } + + CATCH_SECTION("Save and load DenseClusteredDataset") { + // Prepare temp directory + auto tempdir = svs_test::prepare_temp_directory_v2(); + + // Re-create clustering and dense clusters for this section + auto section_clustering = + ivf::build_clustering(build_params, data, distance, threadpool, false); + auto section_centroids = section_clustering.centroids(); + auto dense_clusters = + ivf::DenseClusteredDataset( + section_clustering, data, threadpool, svs::lib::Allocator() + ); + + // Save the dense clusters + svs::lib::save_to_disk(dense_clusters, tempdir); + + // Verify config file exists + CATCH_REQUIRE(std::filesystem::exists(tempdir / "svs_config.toml")); + + // Verify saved format: should have clusters_archive.bin, ids.bin, and offset files + CATCH_REQUIRE(std::filesystem::exists(tempdir / "clusters_archive.bin")); + CATCH_REQUIRE(std::filesystem::exists(tempdir / "ids.bin")); + CATCH_REQUIRE(std::filesystem::exists(tempdir / "cluster_sizes.bin")); + CATCH_REQUIRE(std::filesystem::exists(tempdir / "ids_offsets.bin")); + + // Load the dense clusters + auto loaded_clusters = svs::lib::load_from_disk< + ivf::DenseClusteredDataset>( + tempdir, threadpool + ); + + // Verify properties + CATCH_REQUIRE(loaded_clusters.size() == dense_clusters.size()); + CATCH_REQUIRE(loaded_clusters.dimensions() == dense_clusters.dimensions()); + CATCH_REQUIRE( + loaded_clusters.get_prefetch_offset() == dense_clusters.get_prefetch_offset() + ); + + // Verify cluster contents + for (size_t c = 0; c < dense_clusters.size(); ++c) { + auto& orig_cluster = dense_clusters[c]; + auto& loaded_cluster = loaded_clusters[c]; + + CATCH_REQUIRE(orig_cluster.size() == loaded_cluster.size()); + + // Verify data and IDs match + for (size_t i = 0; i < orig_cluster.size(); ++i) { + CATCH_REQUIRE(orig_cluster.ids_[i] == loaded_cluster.ids_[i]); + + // Verify data values + auto orig_datum = orig_cluster.get_datum(i); + auto loaded_datum = loaded_cluster.get_datum(i); + for (size_t d = 0; d < data.dimensions(); ++d) { + CATCH_REQUIRE( + orig_datum[d] == Catch::Approx(loaded_datum[d]).epsilon(1e-6) + ); + } + } + } + + // Cleanup + svs_test::cleanup_temp_directory(); + } +}