Skip to content

Commit

Permalink
fix knowhere-hnsw filter strategy
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 committed Jul 17, 2024
1 parent 1d759d0 commit b620498
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
metric_hops++;
metric_distance_computations += size;
}
float kAlpha = bitset.filter_ratio() / 2.0f;
float kAlpha = bitset.filter_ratio() * 0.7;
for (size_t i = 1; i <= size; ++i) {
if (i + 1 <= size) {
prefetchData(list[i + 1]);
Expand Down Expand Up @@ -484,7 +484,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
searchBaseLayerST(tableint ep_id, const void* data_point, size_t ef, std::vector<bool>& visited,
const knowhere::BitsetView& bitset,
const knowhere::feder::hnsw::FederResultUniq& feder_result = nullptr,
IteratorMinHeap* disqualified = nullptr, float accumulative_alpha = 0.0f) const {
IteratorMinHeap* disqualified = nullptr, float accumulative_alpha = 1.0f) const {
if (feder_result != nullptr) {
feder_result->visit_info_.AddLevelVisitRecord(0);
}
Expand Down Expand Up @@ -1418,14 +1418,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}

// do bruteforce search when delete rate high
const size_t filtered_out_num = bitset.count();
const size_t valid_count = cur_element_count - filtered_out_num;
if (!bitset.empty()) {
const size_t filtered_out_num = bitset.count();
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
double ratio = ((double)filtered_out_num) / bitset.size();
knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio);
#endif
if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) ||
k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) {
k >= valid_count * kHnswSearchBFTopkThreshold) {
return searchKnnBF(query_data, k, bitset);
}
}
Expand All @@ -1442,6 +1443,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
std::vector<std::pair<dist_t, labeltype>> result;
size_t len = std::min(k, retset.size());
// switch to brute-force when insufficient k
if (len < valid_count) {
return searchKnnBF(query_data, k, bitset);
}
result.reserve(len);
if constexpr (sq_enabled && has_raw_data) {
knowhere::ResultMaxHeap<dist_t, labeltype> max_heap(len);
Expand All @@ -1468,7 +1473,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
const knowhere::BitsetView& bitset) const {
auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold))
? std::numeric_limits<float>::max()
: 0.0f;
: 1.0f;
std::unique_ptr<int8_t[]> query_data_copy = nullptr;
query_data_copy = std::make_unique<int8_t[]>(data_size_);
std::memcpy(query_data_copy.get(), query_data, data_size_);
Expand Down

0 comments on commit b620498

Please sign in to comment.