Skip to content

Commit

Permalink
fix knowhere-hnsw filter strategy (#711)
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 authored Jul 17, 2024
1 parent 1d759d0 commit 3115ad8
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
metric_hops++;
metric_distance_computations += size;
}
float kAlpha = bitset.filter_ratio() / 2.0f;
float kAlpha = bitset.filter_ratio() * 0.7f;
for (size_t i = 1; i <= size; ++i) {
if (i + 1 <= size) {
prefetchData(list[i + 1]);
Expand Down Expand Up @@ -484,7 +484,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
searchBaseLayerST(tableint ep_id, const void* data_point, size_t ef, std::vector<bool>& visited,
const knowhere::BitsetView& bitset,
const knowhere::feder::hnsw::FederResultUniq& feder_result = nullptr,
IteratorMinHeap* disqualified = nullptr, float accumulative_alpha = 0.0f) const {
IteratorMinHeap* disqualified = nullptr, float accumulative_alpha = 1.0f) const {
if (feder_result != nullptr) {
feder_result->visit_info_.AddLevelVisitRecord(0);
}
Expand Down Expand Up @@ -1440,6 +1440,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
retset =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), visited, bitset, feder_result);
}
// switch to brute-force when insufficient k
if (retset.size() < k) {
return searchKnnBF(query_data, k, bitset);
}
std::vector<std::pair<dist_t, labeltype>> result;
size_t len = std::min(k, retset.size());
result.reserve(len);
Expand Down Expand Up @@ -1468,7 +1472,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
const knowhere::BitsetView& bitset) const {
auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold))
? std::numeric_limits<float>::max()
: 0.0f;
: 1.0f;
std::unique_ptr<int8_t[]> query_data_copy = nullptr;
query_data_copy = std::make_unique<int8_t[]>(data_size_);
std::memcpy(query_data_copy.get(), query_data, data_size_);
Expand Down

0 comments on commit 3115ad8

Please sign in to comment.