From 1c4a180f5a0cb90289e1fcfeb940131ab47f7848 Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Tue, 28 Nov 2023 02:59:46 -0500 Subject: [PATCH] Add async thread pool for generating diskann cache Signed-off-by: cqy123456 --- knowhere/common/ThreadPool.cpp | 15 +++ knowhere/common/ThreadPool.h | 9 ++ knowhere/index/vector_index/IndexDiskANN.cpp | 14 +- thirdparty/DiskANN/include/pq_flash_index.h | 2 +- thirdparty/DiskANN/src/pq_flash_index.cpp | 134 ++++++++++--------- 5 files changed, 103 insertions(+), 71 deletions(-) diff --git a/knowhere/common/ThreadPool.cpp b/knowhere/common/ThreadPool.cpp index 78796c7c8..879e07ac6 100644 --- a/knowhere/common/ThreadPool.cpp +++ b/knowhere/common/ThreadPool.cpp @@ -60,4 +60,19 @@ ThreadPool::GetGlobalThreadPool() { static auto pool = std::make_shared(global_thread_pool_size_); return pool; } + +std::shared_ptr +ThreadPool::GetGlobalAsyncThreadPool() { + if (global_thread_pool_size_ == 0) { + std::lock_guard lock(global_thread_pool_mutex_); + if (global_thread_pool_size_ == 0) { + global_thread_pool_size_ = std::thread::hardware_concurrency(); + } + } + uint32_t async_thread_pool_size = int(std::ceil(global_thread_pool_size_ / 2.0)); + LOG_KNOWHERE_WARNING_ << "async thread pool size init with thread number:" + << async_thread_pool_size; + static auto async_pool = std::make_shared(async_thread_pool_size); + return async_pool; +} } // namespace knowhere diff --git a/knowhere/common/ThreadPool.h b/knowhere/common/ThreadPool.h index c7fd414bf..91aba7a6d 100644 --- a/knowhere/common/ThreadPool.h +++ b/knowhere/common/ThreadPool.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include @@ -58,6 +59,14 @@ class ThreadPool { static std::shared_ptr GetGlobalThreadPool(); + /** + * @brief Get the global async thread pool of knowhere. + * + * @return ThreadPool& + */ + static std::shared_ptr + GetGlobalAsyncThreadPool(); + class ScopedOmpSetter { int omp_before; public: diff --git a/knowhere/index/vector_index/IndexDiskANN.cpp b/knowhere/index/vector_index/IndexDiskANN.cpp index fb19dcb5f..1ed4bc9a2 100644 --- a/knowhere/index/vector_index/IndexDiskANN.cpp +++ b/knowhere/index/vector_index/IndexDiskANN.cpp @@ -312,15 +312,13 @@ IndexDiskANN::Prepare(const Config& config) { return false; } } else { - pq_flash_index_->set_async_cache_flag(true); - pool_->push([&, cache_num = num_nodes_to_cache, + auto aysnc_pool_ = ThreadPool::GetGlobalAsyncThreadPool(); + + pq_flash_index_->setup_cache_sync_task(); + aysnc_pool_->push([&, cache_num = num_nodes_to_cache, sample_nodes_file = warmup_query_file]() { - try { - pq_flash_index_->generate_cache_list_from_sample_queries( - sample_nodes_file, 15, 6, cache_num); - } catch (const std::exception& e) { - LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what(); - } + pq_flash_index_->generate_cache_list_from_sample_queries( + sample_nodes_file, 15, 6, cache_num); }); } } diff --git a/thirdparty/DiskANN/include/pq_flash_index.h b/thirdparty/DiskANN/include/pq_flash_index.h index 7278fcc58..f1e584d19 100644 --- a/thirdparty/DiskANN/include/pq_flash_index.h +++ b/thirdparty/DiskANN/include/pq_flash_index.h @@ -128,7 +128,7 @@ namespace diskann { DISKANN_DLLEXPORT diskann::Metric get_metric() const noexcept; - DISKANN_DLLEXPORT void set_async_cache_flag(const bool flag); + DISKANN_DLLEXPORT void setup_cache_sync_task(); protected: DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index 213fc303c..da062e4ff 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -307,77 +307,78 @@ namespace diskann { std::string sample_bin, _u64 l_search, _u64 beamwidth, _u64 num_nodes_to_cache) { #endif - auto s = std::chrono::high_resolution_clock::now(); - this->search_counter.store(0); - this->node_visit_counter.clear(); - this->node_visit_counter.resize(this->num_points); - this->count_visited_nodes.store(true); - - for (_u32 i = 0; i < node_visit_counter.size(); i++) { - this->node_visit_counter[i].first = i; - this->node_visit_counter[i].second = 0; - } - - _u64 sample_num, sample_dim, sample_aligned_dim; T * samples; + try { + auto s = std::chrono::high_resolution_clock::now(); + + _u64 sample_num, sample_dim, sample_aligned_dim; + std::stringstream stream; #ifdef EXEC_ENV_OLS - if (files.fileExists(sample_bin)) { - diskann::load_aligned_bin(files, sample_bin, samples, sample_num, - sample_dim, sample_aligned_dim); - } + if (files.fileExists(sample_bin)) { + diskann::load_aligned_bin(files, sample_bin, samples, sample_num, + sample_dim, sample_aligned_dim); + } #else - if (file_exists(sample_bin)) { - diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, - sample_aligned_dim); - } + if (file_exists(sample_bin)) { + diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, + sample_aligned_dim); + } #endif - else { - diskann::cerr << "Sample bin file not found. Not generating cache." - << std::endl; - return; - } - - int64_t tmp_result_ids_64; - float tmp_result_dists; + else { + stream << "Sample bin file not found. Not generating cache." + << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } - auto id = 0; - while (this->search_counter.load() < sample_num && id < sample_num && - !this->semaph.IsWaitting()) { - cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search, - &tmp_result_ids_64, &tmp_result_dists, beamwidth); - id++; - } + int64_t tmp_result_ids_64; + float tmp_result_dists; - if (this->semaph.IsWaitting()) { - this->semaph.Signal(); - return; - } + auto id = 0; + while (this->search_counter.load() < sample_num && id < sample_num && + !this->semaph.IsWaitting()) { + cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search, + &tmp_result_ids_64, &tmp_result_dists, beamwidth); + id++; + } - this->count_visited_nodes.store(false); - std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), - [](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) { - return left.second > right.second; - }); - - std::vector node_list; - node_list.clear(); - node_list.shrink_to_fit(); - node_list.reserve(num_nodes_to_cache); - for (_u64 i = 0; i < num_nodes_to_cache; i++) { - node_list.push_back(this->node_visit_counter[i].first); - } - this->node_visit_counter.clear(); - this->node_visit_counter.shrink_to_fit(); - this->search_counter.store(0); + if (this->semaph.IsWaitting()) { + stream << "pq_flash_index is destoried, async thread should be exit." + << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } - diskann::aligned_free(samples); - this->load_cache_list(node_list); - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s"; + this->count_visited_nodes.store(false); + std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), + [](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) { + return left.second > right.second; + }); + std::vector node_list; + node_list.clear(); + node_list.shrink_to_fit(); + node_list.reserve(num_nodes_to_cache); + for (_u64 i = 0; i < num_nodes_to_cache; i++) { + node_list.push_back(this->node_visit_counter[i].first); + } + this->node_visit_counter.clear(); + this->node_visit_counter.shrink_to_fit(); + this->search_counter.store(0); + + this->load_cache_list(node_list); + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s"; + } catch (const std::exception& e) { + LOG(ERROR) << "DiskANN Exception: " << e.what(); + } this->semaph.Signal(); + // free samples + if (samples != nullptr) { + diskann::aligned_free(samples); + } return; } @@ -1574,8 +1575,17 @@ namespace diskann { } template - void PQFlashIndex::set_async_cache_flag(const bool flag) { - this->async_generate_cache.exchange(flag); + void PQFlashIndex::setup_cache_sync_task() { + this->async_generate_cache.exchange(true); + this->search_counter.store(0); + this->node_visit_counter.clear(); + this->node_visit_counter.resize(this->num_points); + this->count_visited_nodes.store(true); + + for (_u32 i = 0; i < node_visit_counter.size(); i++) { + this->node_visit_counter[i].first = i; + this->node_visit_counter[i].second = 0; + } } #ifdef EXEC_ENV_OLS