diff --git a/include/knowhere/comp/knowhere_check.h b/include/knowhere/comp/knowhere_check.h index 15734dda5..6a82f34d4 100644 --- a/include/knowhere/comp/knowhere_check.h +++ b/include/knowhere/comp/knowhere_check.h @@ -20,7 +20,7 @@ namespace knowhere { namespace KnowhereCheck { static bool IndexTypeAndDataTypeCheck(const std::string& index_name, VecType data_type) { - auto& static_index_table = IndexFactory::StaticIndexTableInstance(); + auto& static_index_table = std::get<0>(IndexFactory::StaticIndexTableInstance()); auto key = std::pair(index_name, data_type); if (static_index_table.find(key) != static_index_table.end()) { return true; @@ -28,6 +28,16 @@ IndexTypeAndDataTypeCheck(const std::string& index_name, VecType data_type) { return false; } } + +static bool +SuppportMmapIndexTypeCheck(const std::string& index_name) { + auto& mmap_index_table = std::get<1>(IndexFactory::StaticIndexTableInstance()); + if (mmap_index_table.find(index_name) != mmap_index_table.end()) { + return true; + } else { + return false; + } +} } // namespace KnowhereCheck } // namespace knowhere diff --git a/include/knowhere/index/index_factory.h b/include/knowhere/index/index_factory.h index f66cfe9a3..501904f69 100644 --- a/include/knowhere/index/index_factory.h +++ b/include/knowhere/index/index_factory.h @@ -31,7 +31,7 @@ class IndexFactory { Register(const std::string& name, std::function(const int32_t&, const Object&)> func); static IndexFactory& Instance(); - typedef std::set> GlobalIndexTable; + typedef std::tuple>, std::set> GlobalIndexTable; static GlobalIndexTable& StaticIndexTableInstance(); @@ -77,11 +77,11 @@ class IndexFactory { std::make_unique::type>>(version, object), thread_size)); \ }, \ data_type) -#define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(name, index_table) \ - static int name = []() -> int { \ - auto& static_index_table = IndexFactory::StaticIndexTableInstance(); \ - static_index_table.insert(index_table.begin(), index_table.end()); \ - return 0; \ +#define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(table_index, name, index_table) \ + static int name = []() -> int { \ + auto& static_index_table = std::get(IndexFactory::StaticIndexTableInstance()); \ + static_index_table.insert(index_table.begin(), index_table.end()); \ + return 0; \ }(); } // namespace knowhere diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 275bdaf07..34a646b63 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -74,6 +74,55 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_SPARSE_INVERTED_INDEX, VecType::VECTOR_SPARSE_FLOAT}, {IndexEnum::INDEX_SPARSE_WAND, VecType::VECTOR_SPARSE_FLOAT}, }; -KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(KNOWHERE_STATIC_INDEX, legal_knowhere_index) + +static std::set legal_support_mmap_knowhere_index = { + // binary ivf + IndexEnum::INDEX_FAISS_BIN_IDMAP, + IndexEnum::INDEX_FAISS_BIN_IVFFLAT, + // ivf + IndexEnum::INDEX_FAISS_IDMAP, + IndexEnum::INDEX_FAISS_IDMAP, + IndexEnum::INDEX_FAISS_IDMAP, + + IndexEnum::INDEX_FAISS_IVFFLAT, + IndexEnum::INDEX_FAISS_IVFFLAT, + IndexEnum::INDEX_FAISS_IVFFLAT, + + IndexEnum::INDEX_FAISS_IVFPQ, + IndexEnum::INDEX_FAISS_IVFPQ, + IndexEnum::INDEX_FAISS_IVFPQ, + + IndexEnum::INDEX_FAISS_SCANN, + IndexEnum::INDEX_FAISS_SCANN, + IndexEnum::INDEX_FAISS_SCANN, + + IndexEnum::INDEX_FAISS_IVFSQ8, + IndexEnum::INDEX_FAISS_IVFSQ8, + IndexEnum::INDEX_FAISS_IVFSQ8, + + IndexEnum::INDEX_FAISS_IVFSQ_CC, + IndexEnum::INDEX_FAISS_IVFSQ_CC, + IndexEnum::INDEX_FAISS_IVFSQ_CC, + + // hnsw + IndexEnum::INDEX_HNSW, + IndexEnum::INDEX_HNSW, + IndexEnum::INDEX_HNSW, + + IndexEnum::INDEX_HNSW_SQ8, + IndexEnum::INDEX_HNSW_SQ8, + IndexEnum::INDEX_HNSW_SQ8, + + IndexEnum::INDEX_HNSW_SQ8_REFINE, + IndexEnum::INDEX_HNSW_SQ8_REFINE, + IndexEnum::INDEX_HNSW_SQ8_REFINE, + // sparse index + IndexEnum::INDEX_SPARSE_INVERTED_INDEX, + IndexEnum::INDEX_SPARSE_WAND, + +}; +KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(0, KNOWHERE_STATIC_INDEX, legal_knowhere_index) +KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(1, KNOWHERE_SUPPORT_MMAP_INDEX, legal_support_mmap_knowhere_index) + } // namespace knowhere #endif /* INDEX_TABLE_H */ diff --git a/tests/ut/test_index_check.cc b/tests/ut/test_index_check.cc index fb8dc059c..6830a722e 100644 --- a/tests/ut/test_index_check.cc +++ b/tests/ut/test_index_check.cc @@ -46,3 +46,19 @@ TEST_CASE("Test index and data type check", "[IndexCheckTest]") { knowhere::VecType::VECTOR_SPARSE_FLOAT) == false); } } + +TEST_CASE("Test support mmap index", "[IndexCheckTest]") { + SECTION("Test valid") { + REQUIRE(knowhere::KnowhereCheck::SuppportMmapIndexTypeCheck(knowhere::IndexEnum::INDEX_HNSW) == true); + REQUIRE(knowhere::KnowhereCheck::SuppportMmapIndexTypeCheck(knowhere::IndexEnum::INDEX_SPARSE_WAND) == true); + REQUIRE(knowhere::KnowhereCheck::SuppportMmapIndexTypeCheck(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX) == + true); +#ifndef KNOWHERE_WITH_CARDINAL + REQUIRE(knowhere::KnowhereCheck::SuppportMmapIndexTypeCheck(knowhere::IndexEnum::INDEX_DISKANN) == false); +#else + REQUIRE(knowhere::KnowhereCheck::SuppportMmapIndexTypeCheck(knowhere::IndexEnum::INDEX_DISKANN) == true); +#endif + REQUIRE(knowhere::KnowhereCheck::SuppportMmapIndexTypeCheck(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) == + true); + } +}