Skip to content

Commit 2284227

Browse files
committed
Address default index parameters code review issue
1 parent adcd2cb commit 2284227

File tree

6 files changed

+126
-57
lines changed

6 files changed

+126
-57
lines changed

bindings/cpp/include/svs/runtime/api_defs.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <svs/runtime/version.h>
2020

2121
#include <cstdint>
22+
#include <limits>
2223
#include <span>
2324

2425
#ifdef svs_runtime_EXPORTS
@@ -33,6 +34,53 @@ namespace svs {
3334
namespace runtime {
3435
namespace v0 {
3536

37+
class OptionalBool {
38+
enum class Value : int8_t { Undef = -1, True = 1, False = 0 };
39+
Value value_;
40+
41+
public:
42+
constexpr OptionalBool()
43+
: value_(Value::Undef) {}
44+
constexpr OptionalBool(bool b)
45+
: value_(b ? Value::True : Value::False) {}
46+
47+
constexpr bool is_enabled() const { return value_ == Value::True; }
48+
constexpr bool is_disabled() const { return value_ == Value::False; }
49+
constexpr bool is_default() const { return value_ == Value::Undef; }
50+
51+
friend constexpr bool operator==(const OptionalBool& lhs, const OptionalBool& rhs) {
52+
return lhs.value_ == rhs.value_;
53+
}
54+
friend constexpr bool operator!=(const OptionalBool& lhs, const OptionalBool& rhs) {
55+
return lhs.value_ != rhs.value_;
56+
}
57+
};
58+
59+
template <typename T> struct Unspecified;
60+
template <> struct Unspecified<size_t> {
61+
static constexpr size_t value = std::numeric_limits<size_t>::max();
62+
};
63+
template <> struct Unspecified<float> {
64+
static constexpr float value = std::numeric_limits<float>::infinity();
65+
};
66+
template <> struct Unspecified<int> {
67+
static constexpr int value = std::numeric_limits<int>::max();
68+
};
69+
template <> struct Unspecified<bool> {
70+
static constexpr OptionalBool value = {};
71+
};
72+
template <> struct Unspecified<OptionalBool> {
73+
static constexpr OptionalBool value = {};
74+
};
75+
76+
template <typename T> constexpr auto Unspecify() { return Unspecified<T>::value; }
77+
78+
inline bool is_specified(const OptionalBool& value) { return !value.is_default(); }
79+
80+
template <typename T> bool is_specified(const T& value) {
81+
return value != Unspecified<T>::value;
82+
}
83+
3684
enum class MetricType { L2, INNER_PRODUCT };
3785

3886
enum class StorageKind {

bindings/cpp/include/svs/runtime/dynamic_vamana_index.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ struct SVS_RUNTIME_API DynamicVamanaIndex : public VamanaIndex {
4545
size_t dim,
4646
MetricType metric,
4747
StorageKind storage_kind,
48-
const VamanaIndex::BuildParams& params,
49-
const VamanaIndex::SearchParams& default_search_params = {10, 10, 0, 0}
48+
const VamanaIndex::BuildParams& params = {},
49+
const VamanaIndex::SearchParams& default_search_params = {}
5050
) noexcept;
5151

5252
static Status destroy(DynamicVamanaIndex* index) noexcept;
@@ -68,8 +68,8 @@ struct SVS_RUNTIME_API DynamicVamanaIndexLeanVec : public DynamicVamanaIndex {
6868
MetricType metric,
6969
StorageKind storage_kind,
7070
size_t leanvec_dims,
71-
const VamanaIndex::BuildParams& params,
72-
const VamanaIndex::SearchParams& default_search_params = {10, 10, 0, 0}
71+
const VamanaIndex::BuildParams& params = {},
72+
const VamanaIndex::SearchParams& default_search_params = {}
7373
) noexcept;
7474

7575
// Specialization to build LeanVec-based Vamana index with provided training data
@@ -79,8 +79,8 @@ struct SVS_RUNTIME_API DynamicVamanaIndexLeanVec : public DynamicVamanaIndex {
7979
MetricType metric,
8080
StorageKind storage_kind,
8181
const LeanVecTrainingData* training_data,
82-
const VamanaIndex::BuildParams& params,
83-
const VamanaIndex::SearchParams& default_search_params = {10, 10, 0, 0}
82+
const VamanaIndex::BuildParams& params = {},
83+
const VamanaIndex::SearchParams& default_search_params = {}
8484
) noexcept;
8585
};
8686

bindings/cpp/include/svs/runtime/vamana_index.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ struct SVS_RUNTIME_API VamanaIndex {
2929
virtual ~VamanaIndex();
3030

3131
struct BuildParams {
32-
size_t graph_max_degree;
33-
size_t prune_to = 0;
34-
float alpha = 0;
35-
size_t construction_window_size = 40;
36-
size_t max_candidate_pool_size = 200;
37-
bool use_full_search_history = true;
32+
size_t graph_max_degree = Unspecify<size_t>();
33+
size_t prune_to = Unspecify<size_t>();
34+
float alpha = Unspecify<float>();
35+
size_t construction_window_size = Unspecify<size_t>();
36+
size_t max_candidate_pool_size = Unspecify<size_t>();
37+
OptionalBool use_full_search_history = Unspecify<bool>();
3838
};
3939

4040
struct SearchParams {
41-
size_t search_window_size = 10;
42-
size_t search_buffer_capacity = 10;
43-
size_t prefetch_lookahead = 0;
44-
size_t prefetch_step = 0;
41+
size_t search_window_size = Unspecify<size_t>();
42+
size_t search_buffer_capacity = Unspecify<size_t>();
43+
size_t prefetch_lookahead = Unspecify<size_t>();
44+
size_t prefetch_step = Unspecify<size_t>();
4545
};
4646

4747
virtual Status search(

bindings/cpp/src/dynamic_vamana_index_impl.h

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,6 @@ class DynamicVamanaIndexImpl {
5959
"The specified storage kind is not compatible with the "
6060
"DynamicVamanaIndex"};
6161
}
62-
63-
if (build_params_.prune_to == 0) {
64-
build_params_.prune_to = build_params_.graph_max_degree < 4
65-
? build_params_.graph_max_degree
66-
: build_params_.graph_max_degree - 4;
67-
}
68-
if (build_params_.alpha == 0) {
69-
build_params_.alpha = metric == MetricType::L2 ? 1.2f : 0.95f;
70-
}
7162
}
7263

7364
size_t size() const { return impl_ ? impl_->size() : 0; }
@@ -337,13 +328,19 @@ class DynamicVamanaIndexImpl {
337328
protected:
338329
// Utility functions
339330
svs::index::vamana::VamanaBuildParameters vamana_build_parameters() const {
340-
return svs::index::vamana::VamanaBuildParameters{
341-
build_params_.alpha,
342-
build_params_.graph_max_degree,
343-
build_params_.construction_window_size,
344-
build_params_.max_candidate_pool_size,
345-
build_params_.prune_to,
346-
build_params_.use_full_search_history};
331+
svs::index::vamana::VamanaBuildParameters result;
332+
set_if_specified(result.alpha, build_params_.alpha);
333+
set_if_specified(result.graph_max_degree, build_params_.graph_max_degree);
334+
set_if_specified(result.window_size, build_params_.construction_window_size);
335+
set_if_specified(
336+
result.max_candidate_pool_size, build_params_.max_candidate_pool_size
337+
);
338+
set_if_specified(result.prune_to, build_params_.prune_to);
339+
if (is_specified(build_params_.use_full_search_history)) {
340+
result.use_full_search_history =
341+
build_params_.use_full_search_history.is_enabled();
342+
}
343+
return result;
347344
}
348345

349346
svs::index::vamana::VamanaSearchParameters
@@ -352,33 +349,37 @@ class DynamicVamanaIndexImpl {
352349
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
353350
}
354351

355-
auto sp = impl_->get_search_parameters();
356-
357-
auto search_window_size = default_search_params_.search_window_size;
358-
auto search_buffer_capacity = default_search_params_.search_buffer_capacity;
359-
if (default_search_params_.prefetch_lookahead > 0) {
360-
sp = sp.prefetch_lookahead(default_search_params_.prefetch_lookahead);
361-
}
362-
if (default_search_params_.prefetch_step > 0) {
363-
sp = sp.prefetch_step(default_search_params_.prefetch_step);
352+
// Copy default search parameters
353+
auto search_params = default_search_params_;
354+
// Update with user-specified parameters
355+
if (params) {
356+
set_if_specified(search_params.search_window_size, params->search_window_size);
357+
set_if_specified(
358+
search_params.search_buffer_capacity, params->search_buffer_capacity
359+
);
360+
set_if_specified(search_params.prefetch_lookahead, params->prefetch_lookahead);
361+
set_if_specified(search_params.prefetch_step, params->prefetch_step);
364362
}
365363

366-
if (params != nullptr) {
367-
if (params->search_window_size > 0)
368-
search_window_size = params->search_window_size;
369-
if (params->search_buffer_capacity > 0)
370-
search_buffer_capacity = params->search_buffer_capacity;
371-
if (params->prefetch_lookahead > 0) {
372-
sp = sp.prefetch_lookahead(params->prefetch_lookahead);
373-
}
374-
if (params->prefetch_step > 0) {
375-
sp = sp.prefetch_step(params->prefetch_step);
364+
// Get current search parameters from the index
365+
auto result = impl_->get_search_parameters();
366+
// Update with specified parameters
367+
if (is_specified(search_params.search_window_size)) {
368+
if (is_specified(search_params.search_buffer_capacity)) {
369+
result.buffer_config(
370+
{search_params.search_window_size, search_params.search_buffer_capacity}
371+
);
372+
} else {
373+
result.buffer_config(search_params.search_window_size);
376374
}
375+
} else if (is_specified(search_params.search_buffer_capacity)) {
376+
result.buffer_config(search_params.search_buffer_capacity);
377377
}
378378

379-
return impl_->get_search_parameters().buffer_config(
380-
{search_window_size, search_buffer_capacity}
381-
);
379+
set_if_specified(result.prefetch_lookahead_, search_params.prefetch_lookahead);
380+
set_if_specified(result.prefetch_step_, search_params.prefetch_step);
381+
382+
return result;
382383
}
383384

384385
template <typename Tag, typename... StorageArgs>
@@ -447,7 +448,7 @@ class DynamicVamanaIndexImpl {
447448
buffer_config.get_search_window_size(), buffer_config.get_total_capacity()};
448449
metric_type_ = metric;
449450
storage_kind_ = storage_kind;
450-
build_params_ = {
451+
build_params_ = VamanaIndex::BuildParams{
451452
impl_->get_graph_max_degree(),
452453
impl_->get_prune_to(),
453454
impl_->get_alpha(),

bindings/cpp/src/dynamic_vamana_index_leanvec_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct DynamicVamanaIndexLeanVecImpl : public DynamicVamanaIndexImpl {
5454
StorageKind storage_kind,
5555
const LeanVecTrainingDataImpl& training_data,
5656
const VamanaIndex::BuildParams& params,
57-
const VamanaIndex::SearchParams& default_search_params = {10, 10}
57+
const VamanaIndex::SearchParams& default_search_params
5858
)
5959
: DynamicVamanaIndexImpl{dim, metric, storage_kind, params, default_search_params}
6060
, leanvec_dims_{training_data.get_leanvec_dims()}
@@ -68,7 +68,7 @@ struct DynamicVamanaIndexLeanVecImpl : public DynamicVamanaIndexImpl {
6868
StorageKind storage_kind,
6969
size_t leanvec_dims,
7070
const VamanaIndex::BuildParams& params,
71-
const VamanaIndex::SearchParams& default_search_params = {10, 10}
71+
const VamanaIndex::SearchParams& default_search_params
7272
)
7373
: DynamicVamanaIndexImpl{dim, metric, storage_kind, params, default_search_params}
7474
, leanvec_dims_{leanvec_dims}

bindings/cpp/src/svs_runtime_utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,26 @@ inline auto runtime_error_wrapper(Callable&& func) noexcept -> Status {
9696
}
9797
}
9898

99+
inline void set_if_specified(bool& target, const OptionalBool& value) {
100+
if (is_specified(value)) {
101+
target = value.is_enabled();
102+
}
103+
}
104+
105+
template <typename T> void set_if_specified(T& target, const T& value) {
106+
if (is_specified(value)) {
107+
target = value;
108+
}
109+
}
110+
111+
template <typename T> void require_specified(const T& value, const char* name) {
112+
if (!is_specified(value)) {
113+
throw StatusException{
114+
ErrorCode::INVALID_ARGUMENT,
115+
std::string("The parameter '") + name + "' must be specified."};
116+
}
117+
}
118+
99119
namespace storage {
100120

101121
// Consolidated storage kind checks using constexpr functions

0 commit comments

Comments
 (0)