Skip to content

Commit

Permalink
Extend the test coverage of FaissIVFIndex
Browse files Browse the repository at this point in the history
Summary: The patch adds a new unit test for `FaissIVFIndex` that compares its results with a regular in-memory FAISS index. Specifically, it trains two identical IVF indices using the same training vectors, passes the ownership of one to `FaissIVFIndex`, adds the same set of database vectors to both, and then queries them using the same query vectors (with a variety of values for number of neighbors and number of probes).

Differential Revision: D68233815
  • Loading branch information
ltamasi authored and facebook-github-bot committed Jan 15, 2025
1 parent b333358 commit e74c44c
Showing 1 changed file with 139 additions and 0 deletions.
139 changes: 139 additions & 0 deletions utilities/secondary_index/faiss_ivf_index_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,145 @@ TEST(FaissIVFIndexTest, Basic) {
}
}

// Checks that FaissIVFIndex yields the same neighbors, in the same order, as a
// plain in-memory FAISS IVF index that was trained on and populated with
// identical data, across several neighbor/probe count combinations.
TEST(FaissIVFIndexTest, Compare) {
  // Train two copies of the same index; hand over one to FaissIVFIndex and use
  // the other one as a baseline for comparison
  constexpr size_t dim = 128;

  // The IVF indices below are given raw pointers to the quantizers, so the
  // quantizers remain owned by this scope.
  auto quantizer_cmp = std::make_unique<faiss::IndexFlatL2>(dim);
  auto quantizer = std::make_unique<faiss::IndexFlatL2>(dim);

  constexpr size_t num_lists = 16;
  auto index_cmp = std::make_unique<faiss::IndexIVFFlat>(quantizer_cmp.get(),
                                                         dim, num_lists);
  auto index =
      std::make_unique<faiss::IndexIVFFlat>(quantizer.get(), dim, num_lists);

  {
    // Both indices are trained on the same buffer (fixed seed)
    constexpr faiss::idx_t num_train = 1024;
    std::vector<float> embeddings_train(dim * num_train);
    faiss::float_rand(embeddings_train.data(), dim * num_train, 42);

    index_cmp->train(num_train, embeddings_train.data());
    index->train(num_train, embeddings_train.data());
  }

  const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test");
  EXPECT_OK(DestroyDB(db_name, Options()));

  Options options;
  options.create_if_missing = true;

  // Ownership of `index` is passed to FaissIVFIndex here; only `index_cmp` is
  // used directly from this point on.
  TransactionDBOptions txn_db_options;
  txn_db_options.secondary_indices.emplace_back(std::make_shared<FaissIVFIndex>(
      std::move(index), kDefaultWideColumnName.ToString()));

  TransactionDB* db = nullptr;
  ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db));

  std::unique_ptr<TransactionDB> db_guard(db);

  ColumnFamilyOptions cf1_opts;
  ColumnFamilyHandle* cfh1 = nullptr;
  ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1));
  std::unique_ptr<ColumnFamilyHandle> cfh1_guard(cfh1);

  ColumnFamilyOptions cf2_opts;
  ColumnFamilyHandle* cfh2 = nullptr;
  ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2));
  std::unique_ptr<ColumnFamilyHandle> cfh2_guard(cfh2);

  // cf1 is registered as the primary column family and cf2 as the secondary
  // one for the index
  const auto& secondary_index = txn_db_options.secondary_indices.back();
  secondary_index->SetPrimaryColumnFamily(cfh1);
  secondary_index->SetSecondaryColumnFamily(cfh2);

  // Add the same set of database vectors to both indices
  {
    constexpr faiss::idx_t num_db = 4096;
    std::vector<float> embeddings_db(dim * num_db);
    faiss::float_rand(embeddings_db.data(), dim * num_db, 123);

    for (faiss::idx_t i = 0; i < num_db; ++i) {
      const float* const embedding = embeddings_db.data() + i * dim;

      index_cmp->add(1, embedding);

      // The primary key is the decimal form of the loop counter; get_id()
      // below parses it back for comparison against the baseline ids. The
      // value written is the raw bytes of the embedding.
      const std::string primary_key = std::to_string(i);
      ASSERT_OK(db->Put(WriteOptions(), cfh1, primary_key,
                        Slice(reinterpret_cast<const char*>(embedding),
                              dim * sizeof(float))));
    }
  }

  // Search both indices with the same set of query vectors and make sure the
  // results match
  {
    constexpr faiss::idx_t num_query = 32;
    std::vector<float> embeddings_query(dim * num_query);
    faiss::float_rand(embeddings_query.data(), dim * num_query, 456);

    for (size_t neighbors : {1, 2, 4}) {
      for (size_t probes : {1, 2, 4}) {
        std::unique_ptr<Iterator> underlying_it(
            db->NewIterator(ReadOptions(), cfh2));

        SecondaryIndexReadOptions read_options;
        read_options.similarity_search_neighbors = neighbors;
        read_options.similarity_search_probes = probes;

        std::unique_ptr<Iterator> it = secondary_index->NewIterator(
            read_options, std::move(underlying_it));

        // Parses the iterator's current key as a decimal id; yields -1 if the
        // key cannot be parsed
        auto get_id = [&]() -> faiss::idx_t {
          Slice key = it->key();
          faiss::idx_t id = -1;

          if (std::from_chars(key.data(), key.data() + key.size(), id).ec !=
              std::errc()) {
            return -1;
          }

          return id;
        };

        // The baseline search parameters depend only on `probes`, so set them
        // up once per combination rather than once per query
        faiss::SearchParametersIVF params;
        params.nprobe = probes;

        for (faiss::idx_t i = 0; i < num_query; ++i) {
          const float* const embedding = embeddings_query.data() + i * dim;

          std::vector<float> distances(neighbors, 0.0f);
          std::vector<faiss::idx_t> ids(neighbors, -1);

          index_cmp->search(1, embedding, neighbors, distances.data(),
                            ids.data(), &params);

          // Count the baseline results; ids start out as -1, so the first -1
          // is treated as the end of the valid entries
          size_t num_found_cmp = 0;
          for (faiss::idx_t id : ids) {
            if (id == -1) {
              break;
            }

            ++num_found_cmp;
          }

          // The seek target is the raw bytes of the query vector
          size_t num_found = 0;
          for (it->Seek(Slice(reinterpret_cast<const char*>(embedding),
                              dim * sizeof(float)));
               it->Valid(); it->Next()) {
            // Fail cleanly instead of reading past the end of `ids` if the
            // iterator yields more entries than the number of neighbors
            // requested
            ASSERT_LT(num_found, ids.size());
            ASSERT_EQ(get_id(), ids[num_found]);

            ++num_found;
          }

          ASSERT_OK(it->status());
          ASSERT_EQ(num_found, num_found_cmp);
        }
      }
    }
  }
}

} // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
Expand Down

0 comments on commit e74c44c

Please sign in to comment.