Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
size_per_element_(0),
data_size_(0),
dist_func_param_(nullptr) {
(void)s; // silence unused variable warnings.
}


Expand Down Expand Up @@ -65,6 +66,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {


Status addPointNoExceptions(const void *datapoint, labeltype label, bool replace_deleted = false) override {
(void)replace_deleted; // silence unused variable warning.
int idx;
{
std::unique_lock<std::mutex> lock(index_lock);
Expand Down Expand Up @@ -113,7 +115,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
dist_t lastdist = std::numeric_limits<dist_t>::max();
for (int i = 0; i < cur_element_count; i++) {
for (size_t i = 0; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist || topResults.size() < k) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
Expand Down
7 changes: 5 additions & 2 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {


HierarchicalNSW(SpaceInterface<dist_t> *s) {
(void)s; // silence unused variable warnings.
}


Expand All @@ -87,6 +88,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
size_t max_elements = 0,
bool allow_replace_deleted = false)
: allow_replace_deleted_(allow_replace_deleted) {
(void)nmslib; // silence unused variable warnings.
loadIndex(location, s, max_elements);
}

Expand Down Expand Up @@ -543,6 +545,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
int level,
bool isUpdate) {
(void)data_point; // silence unused variable warning.
size_t Mcurmax = level ? maxM_ : maxM0_;
getNeighborsByHeuristic2(top_candidates, M_);
if (top_candidates.size() > M_)
Expand Down Expand Up @@ -1291,7 +1294,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
if (cand > max_elements_)
return Status("cand error");
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
if (d < curdist) {
Expand Down Expand Up @@ -1365,7 +1368,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
if (cand > max_elements_)
return Status("cand error");
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

Expand Down
14 changes: 13 additions & 1 deletion hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
#if defined(USE_AVX)
static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
static __int64 xgetbv(unsigned int x) {
return _xgetbv(x);
}
#endif
#else
#include <x86intrin.h>
#include <cpuid.h>
#include <stdint.h>
#if defined(USE_AVX)
static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
Expand All @@ -49,6 +52,7 @@ static uint64_t xgetbv(unsigned int index) {
return ((uint64_t)edx << 32) | eax;
}
#endif
#endif

#if defined(USE_AVX512)
#include <immintrin.h>
Expand All @@ -65,6 +69,7 @@ static uint64_t xgetbv(unsigned int index) {
// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0

#if defined(USE_AVX)
static bool AVXCapable() {
int cpuInfo[4];

Expand Down Expand Up @@ -92,6 +97,7 @@ static bool AVXCapable() {
return HW_AVX && avxSupported;
}

#if defined(USE_AVX512)
static bool AVX512Capable() {
if (!AVXCapable()) return false;

Expand Down Expand Up @@ -121,6 +127,8 @@ static bool AVX512Capable() {
return HW_AVX512F && avx512Supported;
}
#endif
#endif
#endif

#include <queue>
#include <vector>
Expand Down Expand Up @@ -212,7 +220,11 @@ typedef size_t labeltype;
// This can be extended to store state for filtering (e.g. from a std::set)
class BaseFilterFunctor {
public:
virtual bool operator()(hnswlib::labeltype id) { return true; }
virtual bool operator()(hnswlib::labeltype id) {
(void)id; // silence unused variable warning.
return true;
}

virtual ~BaseFilterFunctor() {};
};

Expand Down
1 change: 0 additions & 1 deletion hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const

static float
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN64 TmpRes[16];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
Expand Down
9 changes: 9 additions & 0 deletions hnswlib/stop_condition.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
}

void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
(void)label; // silence unused variable warnings.
DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
if (doc_counter_[doc_id] == 0) {
curr_num_docs_ += 1;
Expand All @@ -173,6 +174,8 @@ class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
}

void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
(void)label; // silence unused variable warnings.
(void)dist;
DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
doc_counter_[doc_id] -= 1;
if (doc_counter_[doc_id] == 0) {
Expand Down Expand Up @@ -232,10 +235,16 @@ class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
}

void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
(void)label; // silence unused variable warnings;
(void)datapoint;
(void)dist;
curr_num_items_ += 1;
}

void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
(void)label; // silence unused variable warnings;
(void)datapoint;
(void)dist;
curr_num_items_ -= 1;
}

Expand Down
73 changes: 41 additions & 32 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "hnswlib.h"
#include <type_traits>
#include <thread>
#include <atomic>
#include <stdlib.h>
Expand Down Expand Up @@ -100,8 +101,8 @@ inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows,
if (buffer.ndim != 2 && buffer.ndim != 1) {
char msg[256];
snprintf(msg, sizeof(msg),
"Input vector data wrong shape. Number of dimensions %d. Data must be a 1D or 2D array.",
buffer.ndim);
"Input vector data wrong shape. Number of dimensions %lld. Data must be a 1D or 2D array.",
(long long)buffer.ndim); // use long long to avoid overflowing an int (%d) from a pybind11::ssize_t.
HNSWLIB_THROW_RUNTIME_ERROR(msg);
}
if (buffer.ndim == 2) {
Expand All @@ -113,19 +114,29 @@ inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows,
}
}

// Quick and dirty implementations of C++20's std::cmp_equal() and friends.
template<typename Left_, typename Right_>
bool safe_unsigned_eq(Left_ l, Right_ r) {
return static_cast<typename std::make_unsigned<Left_>::type>(l) == static_cast<typename std::make_unsigned<Right_>::type>(r);
}

template<typename Left_, typename Right_>
bool safe_unsigned_lte(Left_ l, Right_ r) {
return static_cast<typename std::make_unsigned<Left_>::type>(l) <= static_cast<typename std::make_unsigned<Right_>::type>(r);
}

inline std::vector<size_t> get_input_ids_and_check_shapes(const py::object& ids_, size_t feature_rows) {
std::vector<size_t> ids;
if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
auto ids_numpy = items.request();
// check shapes
if (!((ids_numpy.ndim == 1 && ids_numpy.shape[0] == feature_rows) ||
if (!((ids_numpy.ndim == 1 && safe_unsigned_eq(ids_numpy.shape[0], feature_rows)) ||
(ids_numpy.ndim == 0 && feature_rows == 1))) {
char msg[256];
snprintf(msg, sizeof(msg),
"The input label shape %d does not match the input data vector shape %d",
ids_numpy.ndim, feature_rows);
"The input label shape %lld does not match the input data vector shape %zu",
(long long)ids_numpy.ndim, feature_rows);
HNSWLIB_THROW_RUNTIME_ERROR(msg);
}
// extract data
Expand Down Expand Up @@ -259,11 +270,11 @@ class Index {
size_t rows, features;
get_input_array_shapes(buffer, &rows, &features);

if (features != dim)
if (!safe_unsigned_eq(features, dim))
HNSWLIB_THROW_RUNTIME_ERROR("Wrong dimensionality of the vectors");

// avoid using threads when the number of additions is small:
if (rows <= num_threads * 4) {
if (safe_unsigned_lte(rows, num_threads * 4)) {
num_threads = 1;
}

Expand All @@ -287,6 +298,7 @@ class Index {
py::gil_scoped_release l;
if (normalize == false) {
ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) {
(void)threadId; // silence unused variable warnings.
size_t id = ids.size() ? ids.at(row) : (cur_l + row);
appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted);
});
Expand Down Expand Up @@ -334,9 +346,9 @@ class Index {
if (return_type == "list") {
return py::cast(data);
}
if (return_type == "numpy") {
return py::array_t< data_t, py::array::c_style | py::array::forcecast >(py::cast(data));
}

// Must be numpy if it's not a list.
return py::array_t< data_t, py::array::c_style | py::array::forcecast >(py::cast(data));
}


Expand Down Expand Up @@ -394,19 +406,19 @@ class Index {
}

py::capsule free_when_done_l0(data_level0_npy, [](void* f) {
delete[] f;
delete[] reinterpret_cast<char*>(f);
});
py::capsule free_when_done_lvl(element_levels_npy, [](void* f) {
delete[] f;
delete[] reinterpret_cast<int*>(f);
});
py::capsule free_when_done_lb(label_lookup_key_npy, [](void* f) {
delete[] f;
delete[] reinterpret_cast<hnswlib::labeltype*>(f);
});
py::capsule free_when_done_id(label_lookup_val_npy, [](void* f) {
delete[] f;
delete[] reinterpret_cast<hnswlib::tableint*>(f);
});
py::capsule free_when_done_ll(link_list_npy, [](void* f) {
delete[] f;
delete[] reinterpret_cast<char*>(f);
});

/* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
Expand Down Expand Up @@ -557,11 +569,10 @@ class Index {
auto link_list_npy = d["link_lists"].cast<py::array_t < char, py::array::c_style | py::array::forcecast > >();

for (size_t i = 0; i < appr_alg->cur_element_count; i++) {
if (label_lookup_val_npy.data()[i] < 0) {
HNSWLIB_THROW_RUNTIME_ERROR("Internal id cannot be negative!");
} else {
appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i]));
}
// if (label_lookup_val_npy.data()[i] < 0) { // unnecessary as tableint is unsigned.
// HNSWLIB_THROW_RUNTIME_ERROR("Internal id cannot be negative!");
// }
appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i]));
}

memcpy(appr_alg->element_levels_.data(), element_levels_npy.data(), element_levels_npy.nbytes());
Expand Down Expand Up @@ -630,7 +641,7 @@ class Index {
get_input_array_shapes(buffer, &rows, &features);

// avoid using threads when the number of searches is small:
if (rows <= num_threads * 4) {
if (safe_unsigned_lte(rows, num_threads * 4)) {
num_threads = 1;
}

Expand All @@ -643,6 +654,7 @@ class Index {

if (normalize == false) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
(void)threadId; // silence unused variable warnings.
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)items.data(row), k, p_idFilter);
if (result.size() != k)
Expand All @@ -658,8 +670,6 @@ class Index {
} else {
std::vector<float> norm_array(num_threads * features);
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
float* data = (float*)items.data(row);

size_t start_idx = threadId * dim;
normalize_vector((float*)items.data(row), (norm_array.data() + start_idx));

Expand All @@ -678,10 +688,10 @@ class Index {
}
}
py::capsule free_when_done_l(data_numpy_l, [](void* f) {
delete[] f;
delete[] reinterpret_cast<hnswlib::labeltype*>(f);
});
py::capsule free_when_done_d(data_numpy_d, [](void* f) {
delete[] f;
delete[] reinterpret_cast<dist_t*>(f);
});

return py::make_tuple(
Expand Down Expand Up @@ -807,7 +817,7 @@ class BFIndex {
size_t rows, features;
get_input_array_shapes(buffer, &rows, &features);

if (features != dim)
if (!safe_unsigned_eq(features, dim))
HNSWLIB_THROW_RUNTIME_ERROR("Wrong dimensionality of the vectors");

std::vector<size_t> ids = get_input_ids_and_check_shapes(ids_, rows);
Expand Down Expand Up @@ -839,6 +849,7 @@ class BFIndex {


void loadIndex(const std::string &path_to_index, size_t max_elements) {
(void)max_elements; // silence unused variable warnings.
if (alg) {
std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl;
delete alg;
Expand Down Expand Up @@ -875,6 +886,7 @@ class BFIndex {

if (!normalize) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
(void)threadId; // silence unused variable warnings.
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void*)items.data(row), k, p_idFilter);
if (result.size() != k)
Expand Down Expand Up @@ -909,10 +921,10 @@ class BFIndex {
}

py::capsule free_when_done_l(data_numpy_l, [](void *f) {
delete[] f;
delete[] reinterpret_cast<hnswlib::labeltype*>(f);
});
py::capsule free_when_done_d(data_numpy_d, [](void *f) {
delete[] f;
delete[] reinterpret_cast<dist_t*>(f);
});


Expand All @@ -932,9 +944,7 @@ class BFIndex {
};


PYBIND11_PLUGIN(hnswlib) {
py::module m("hnswlib");

PYBIND11_MODULE(hnswlib, m) {
py::class_<Index<float>>(m, "Index")
.def(py::init(&Index<float>::createFromParams), py::arg("params"))
/* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */
Expand Down Expand Up @@ -1034,5 +1044,4 @@ PYBIND11_PLUGIN(hnswlib) {
.def("get_max_elements", &BFIndex<float>::getMaxElements)
.def("get_current_count", &BFIndex<float>::getCurrentCount)
.def_readwrite("num_threads", &BFIndex<float>::num_threads_default);
return m.ptr();
}
Loading