From 3115ad855416ef4803ea8b5249f7b10222968aa8 Mon Sep 17 00:00:00 2001 From: Min Tian Date: Wed, 17 Jul 2024 15:19:36 +0800 Subject: [PATCH] fix knowhere-hnsw filter strategy (#711) Signed-off-by: min.tian --- thirdparty/hnswlib/hnswlib/hnswalg.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index 8f02b76df..ae8f6c0ae 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -435,7 +435,7 @@ class HierarchicalNSW : public AlgorithmInterface { metric_hops++; metric_distance_computations += size; } - float kAlpha = bitset.filter_ratio() / 2.0f; + float kAlpha = bitset.filter_ratio() * 0.7f; for (size_t i = 1; i <= size; ++i) { if (i + 1 <= size) { prefetchData(list[i + 1]); @@ -484,7 +484,7 @@ class HierarchicalNSW : public AlgorithmInterface { searchBaseLayerST(tableint ep_id, const void* data_point, size_t ef, std::vector& visited, const knowhere::BitsetView& bitset, const knowhere::feder::hnsw::FederResultUniq& feder_result = nullptr, - IteratorMinHeap* disqualified = nullptr, float accumulative_alpha = 0.0f) const { + IteratorMinHeap* disqualified = nullptr, float accumulative_alpha = 1.0f) const { if (feder_result != nullptr) { feder_result->visit_info_.AddLevelVisitRecord(0); } @@ -1440,6 +1440,10 @@ class HierarchicalNSW : public AlgorithmInterface { retset = searchBaseLayerST(currObj, query_data, std::max(ef, k), visited, bitset, feder_result); } + // switch to brute-force when insufficient k + if (retset.size() < k) { + return searchKnnBF(query_data, k, bitset); + } std::vector> result; size_t len = std::min(k, retset.size()); result.reserve(len); @@ -1468,7 +1472,7 @@ class HierarchicalNSW : public AlgorithmInterface { const knowhere::BitsetView& bitset) const { auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold)) ? std::numeric_limits::max() - : 0.0f; + : 1.0f; std::unique_ptr query_data_copy = nullptr; query_data_copy = std::make_unique(data_size_); std::memcpy(query_data_copy.get(), query_data, data_size_);