Skip to content

Commit

Permalink
allow empty sparse row (zilliztech#704)
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian authored and [email protected] committed Jul 23, 2024
1 parent a5b0d76 commit 00ca273
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 21 deletions.
6 changes: 3 additions & 3 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,9 @@ class InvertedIndex : public BaseInvertedIndex<T> {
if constexpr (bm25) {
row_sum += val;
}
// Skip values close enough to zero (which contribute little to
// the total IP score).
if (drop_during_build_ && fabs(val) < value_threshold_) {
// Skip values equal to or close enough to zero (which contribute
// little to the total IP score).
if (val == 0 || (drop_during_build_ && fabs(val) < value_threshold_)) {
continue;
}
if (inverted_lut_.find(idx) == inverted_lut_.end()) {
Expand Down
105 changes: 105 additions & 0 deletions tests/ut/test_sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,108 @@ TEST_CASE("Test Mem Sparse Index GetVectorByIds", "[float metrics]") {
}
}
}

// Checks that sparse indexes can be built from, and searched against, data
// that contains empty rows: one row with no entries at all and one row whose
// only entry has value 0.0f. The case is run for every generated combination
// of metric (IP, BM25) and index type (inverted index, WAND).
TEST_CASE("Test Mem Sparse Index Handle Empty Vector", "[float metrics]") {
// Base data: row 0 is a regular row, rows 1 and 2 are the empty variants.
std::vector<std::map<int32_t, float>> base_data = {{{1, 1.1f}, {2, 2.2f}, {6, 3.3f}},
// explicitly empty row
{},
// implicitly empty row
{{5, 0.0f}}};
auto dim = 7;
const auto train_ds = GenSparseDataSet(base_data, dim);

auto topk = 5;

auto metric = GENERATE(knowhere::metric::IP, knowhere::metric::BM25);
auto version = GenTestVersionList();

// Builds the config used for Build/Search; the BM25 parameters are set for
// both generated metrics.
auto base_gen = [=, dim = dim]() {
knowhere::Json json;
json[knowhere::meta::DIM] = dim;
json[knowhere::meta::METRIC_TYPE] = metric;
json[knowhere::meta::TOPK] = topk;
json[knowhere::meta::BM25_K1] = 1.2;
json[knowhere::meta::BM25_B] = 0.75;
json[knowhere::meta::BM25_AVGDL] = 100;
return json;
};

// Both index types share the same config generator.
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
std::make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX, base_gen),
std::make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND, base_gen),
}));

std::vector<std::map<int32_t, float>> query_data = {// q0 should find doc 0 only
{{1, 1.1f}},
// q1 and q2 should find no neighbor
{{5, 1.1f}},
{}};
const auto query_ds = GenSparseDataSet(query_data, dim);

auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
// The config is dumped to a string (captured for failure diagnostics) and
// parsed back into the JSON object that is passed to the index.
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);

// Serialize and deserialize into the same index object, so the sections
// below run against the deserialized index.
knowhere::BinarySet bs;
REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);

// Expects the first result id to be 0 and every remaining id within the
// first nq * k slots to be -1.
// NOTE(review): relies on GetRows()/GetDim() describing the layout of the id
// buffer; confirm this also holds for range-search result datasets.
auto check_result = [&](const knowhere::DataSet& ds) {
auto nq = ds.GetRows();
auto k = ds.GetDim();
auto* ids = ds.GetIds();
REQUIRE(ids[0] == 0);
for (auto i = 1; i < nq * k; ++i) {
REQUIRE(ids[i] == -1);
}
};

// Config for the brute-force search; it carries the same metric, top-k and
// BM25 values as base_gen but no DIM entry.
const knowhere::Json conf = {
{knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, {knowhere::meta::BM25_K1, 1.2},
{knowhere::meta::BM25_B, 0.75}, {knowhere::meta::BM25_AVGDL, 100},
};

SECTION("Test Search") {
// Brute-force and index results are held to the same expectations.
auto bf_res = knowhere::BruteForce::SearchSparse(train_ds, query_ds, conf, nullptr);
REQUIRE(bf_res.has_value());
check_result(*bf_res.value());

auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}

SECTION("Test RangeSearch") {
json[knowhere::meta::RADIUS] = 0.0f;
json[knowhere::meta::RANGE_FILTER] = 10000.0f;

// Here the brute-force call receives the same json (including the range
// parameters) as the index call.
auto bf_res =
knowhere::BruteForce::RangeSearch<knowhere::sparse::SparseRow<float>>(train_ds, query_ds, json, nullptr);
REQUIRE(bf_res.has_value());
check_result(*bf_res.value());

auto results = idx.RangeSearch(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}

SECTION("Test GetVectorByIds") {
// NOTE(review): ids lists three entries, but only the first two are
// requested and compared (GenIdsDataSet(2, ids) and the loop bound of 2).
// Row 2 holds a single zero value; presumably it is left out because the
// index does not store zero values, so its size would differ from the
// training row. Confirm this is intentional.
std::vector<int64_t> ids = {0, 1, 2};
auto results = idx.GetVectorByIds(GenIdsDataSet(2, ids));
REQUIRE(results.has_value());
auto xb = (knowhere::sparse::SparseRow<float>*)train_ds->GetTensor();
auto res_data = (knowhere::sparse::SparseRow<float>*)results.value()->GetTensor();
// Each returned row must match the training row element by element.
for (int i = 0; i < 2; ++i) {
const auto& truth_row = xb[i];
const auto& res_row = res_data[i];
REQUIRE(truth_row.size() == res_row.size());
for (size_t j = 0; j < truth_row.size(); ++j) {
REQUIRE(truth_row[j] == res_row[j]);
}
}
}
}
42 changes: 24 additions & 18 deletions tests/ut/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,29 @@ GenTestVersionList() {
return GENERATE(as<int32_t>{}, knowhere::Version::GetCurrentVersion().VersionNumber());
}

// Generate a sparse dataset from explicit per-row {index -> value} maps.
// data holds one map per row; cols is passed through as the dataset dimension.
// Rows whose map is empty keep their value-initialized SparseRow.
inline knowhere::DataSetPtr
GenSparseDataSet(const std::vector<std::map<int32_t, float>>& data, int32_t cols) {
    int32_t rows = data.size();
    auto tensor = std::make_unique<knowhere::sparse::SparseRow<float>[]>(rows);

    for (int32_t r = 0; r < rows; ++r) {
        const auto& entries = data[r];
        if (!entries.empty()) {
            // Copy the map entries into the row in the map's iteration order.
            knowhere::sparse::SparseRow<float> filled(entries.size());
            size_t pos = 0;
            for (const auto& [idx, val] : entries) {
                filled.set_at(pos, idx, val);
                ++pos;
            }
            tensor[r] = std::move(filled);
        }
    }

    // The raw array is released to the dataset, which is then flagged as its
    // owner and as holding sparse data.
    auto ds = knowhere::GenDataSet(rows, cols, tensor.release());
    ds->SetIsOwner(true);
    ds->SetIsSparse(true);
    return ds;
}

// Generate a sparse dataset with given sparsity.
inline knowhere::DataSetPtr
GenSparseDataSet(int32_t rows, int32_t cols, float sparsity, int seed = 42) {
Expand All @@ -338,22 +361,5 @@ GenSparseDataSet(int32_t rows, int32_t cols, float sparsity, int seed = 42) {
data[row][col] = val;
}

auto tensor = std::make_unique<knowhere::sparse::SparseRow<float>[]>(rows);

for (int32_t i = 0; i < rows; ++i) {
if (data[i].size() == 0) {
continue;
}
knowhere::sparse::SparseRow<float> row(data[i].size());
size_t j = 0;
for (auto& [idx, val] : data[i]) {
row.set_at(j++, idx, val);
}
tensor[i] = std::move(row);
}

auto ds = knowhere::GenDataSet(rows, cols, tensor.release());
ds->SetIsOwner(true);
ds->SetIsSparse(true);
return ds;
return GenSparseDataSet(data, cols);
}

0 comments on commit 00ca273

Please sign in to comment.