Skip to content

Commit

Permalink
add io trailer for knowhere
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Dec 12, 2023
1 parent 968157c commit 2d65386
Show file tree
Hide file tree
Showing 11 changed files with 346 additions and 15 deletions.
1 change: 1 addition & 0 deletions include/knowhere/expected.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum class Status {
raft_inner_error = 18,
invalid_binary_set = 19,
invalid_instruction_set = 20,
invalid_trailer = 21,
};

inline std::string
Expand Down
37 changes: 27 additions & 10 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "diskann/pq_flash_index.h"
#include "fmt/core.h"
#include "index/diskann/diskann_config.h"
#include "io/trailer.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/dataset.h"
Expand Down Expand Up @@ -189,9 +190,13 @@ TryDiskANNCall(std::function<void()>&& diskann_call) {
}
}

// Builds the name of the DiskANN trailer file by appending the fixed
// "_trailer.bin" suffix to the supplied name. The input string is not modified.
// NOTE(review): the call sites visible in this file pass the index prefix as
// the argument, despite the parameter name — confirm which one is intended.
inline std::string
GetDiskAnnTrailerFileName(const std::string& disk_index_filename) {
    std::string trailer_path(disk_index_filename);
    trailer_path.append("_trailer.bin");
    return trailer_path;
}

std::vector<std::string>
GetNecessaryFilenames(const std::string& prefix, const bool need_norm, const bool use_sample_cache,
const bool use_sample_warmup) {
GetNecessaryFilenames(const std::string& prefix, const bool need_norm) {
std::vector<std::string> filenames;
auto pq_pivots_filename = diskann::get_pq_pivots_filename(prefix);
auto disk_index_filename = diskann::get_disk_index_filename(prefix);
Expand All @@ -205,9 +210,6 @@ GetNecessaryFilenames(const std::string& prefix, const bool need_norm, const boo
if (need_norm) {
filenames.push_back(diskann::get_disk_index_max_base_norm_file(disk_index_filename));
}
if (use_sample_cache || use_sample_warmup) {
filenames.push_back(diskann::get_sample_data_filename(prefix));
}
return filenames;
}

Expand All @@ -218,6 +220,7 @@ GetOptionalFilenames(const std::string& prefix) {
filenames.push_back(diskann::get_disk_index_centroids_filename(disk_index_filename));
filenames.push_back(diskann::get_disk_index_medoids_filename(disk_index_filename));
filenames.push_back(diskann::get_cached_nodes_file(prefix));
filenames.push_back(GetDiskAnnTrailerFileName(prefix));
return filenames;
}

Expand All @@ -231,7 +234,7 @@ AnyIndexFileExist(const std::string& index_prefix) {
}
return false;
};
return file_exist(GetNecessaryFilenames(index_prefix, diskann::INNER_PRODUCT, true, true)) ||
return file_exist(GetNecessaryFilenames(index_prefix, diskann::INNER_PRODUCT)) ||
file_exist(GetOptionalFilenames(index_prefix));
}

Expand Down Expand Up @@ -312,8 +315,16 @@ DiskANNIndexNode<T>::Build(const DataSet& dataset, const Config& cfg) {
-1);
}));

auto trailer_file = GetDiskAnnTrailerFileName(index_prefix_);
auto necessary_files = GetNecessaryFilenames(index_prefix_, need_norm);
auto trailer_status = AddTrailerForFiles(necessary_files, trailer_file, Type(), version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to trailer file.";
return trailer_status;
}

// Add file to the file manager
for (auto& filename : GetNecessaryFilenames(index_prefix_, need_norm, true, true)) {
for (auto& filename : necessary_files) {
if (!AddFile(filename)) {
LOG_KNOWHERE_ERROR_ << "Failed to add file " << filename << ".";
return Status::disk_file_error;
Expand Down Expand Up @@ -360,9 +371,8 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
}();

// Load file from file manager.
for (auto& filename : GetNecessaryFilenames(
index_prefix_, need_norm, prep_conf.search_cache_budget_gb.value() > 0 && !prep_conf.use_bfs_cache.value(),
prep_conf.warm_up.value())) {
auto necessary_files = GetNecessaryFilenames(index_prefix_, need_norm);
for (auto& filename : necessary_files) {
if (!LoadFile(filename)) {
return Status::disk_file_error;
}
Expand All @@ -378,6 +388,13 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
}
}

auto trailer_file = GetDiskAnnTrailerFileName(index_prefix_);
auto trailer_stat = CheckTrailerForFiles(necessary_files, trailer_file, Type());
if (trailer_stat != Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to check diskann trailer.";
return trailer_stat;
}

// set thread pool
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();

Expand Down
6 changes: 6 additions & 0 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "faiss/index_io.h"
#include "index/flat/flat_config.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/factory.h"
Expand Down Expand Up @@ -304,6 +305,11 @@ class FlatIndexNode : public IndexNode {
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
faiss::write_index_binary(index_.get(), &writer);
}
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
return Status::success;
Expand Down
6 changes: 6 additions & 0 deletions src/index/gpu/flat_gpu/flat_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "index/flat_gpu/flat_gpu_config.h"
#include "index/gpu/gpu_res_mgr.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/factory.h"
#include "knowhere/log.h"

Expand Down Expand Up @@ -120,6 +121,11 @@ class GpuFlatIndexNode : public IndexNode {
MemoryIOWriter writer;
// Serialize() is called after Add(), at this time index_ is CPU index actually
faiss::write_index(index_.get(), &writer);
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
} catch (const std::exception& e) {
Expand Down
9 changes: 8 additions & 1 deletion src/index/gpu/ivf_gpu/ivf_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "index/gpu/gpu_res_mgr.h"
#include "index/ivf_gpu/ivf_gpu_config.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/factory.h"
#include "knowhere/log.h"
Expand Down Expand Up @@ -68,7 +69,7 @@ class GpuIvfIndexNode : public IndexNode {

auto metric = Str2FaissMetricType(ivf_gpu_cfg.metric_type);
if (!metric.has_value()) {
LOG_KNOWHERE_ERROR_ << "unsupported metric type: " << ivf_gpu_cfg.metric_type;
LOG_KNOWHERE_WARNING_ << "please check metric value: " << ivf_gpu_cfg.metric_type;
return metric.error();
}

Expand Down Expand Up @@ -191,6 +192,12 @@ class GpuIvfIndexNode : public IndexNode {
faiss::write_index(host_index, &writer);
delete host_index;
}
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}

std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
} catch (std::exception& e) {
Expand Down
12 changes: 10 additions & 2 deletions src/index/gpu_raft/gpu_raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include "index/gpu_raft/gpu_raft_cagra_config.h"
#include "index/gpu_raft/gpu_raft_ivf_flat_config.h"
#include "index/gpu_raft/gpu_raft_ivf_pq_config.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/expected.h"
#include "knowhere/factory.h"
Expand Down Expand Up @@ -170,9 +172,15 @@ struct GpuRaftIndexNode : public IndexNode {
os.flush();
}
if (result == Status::success) {
std::shared_ptr<uint8_t[]> index_binary(new (std::nothrow) uint8_t[buf.str().size()]);
memcpy(index_binary.get(), buf.str().c_str(), buf.str().size());
MemoryIOWriter writer;
writer.write(buf.str().c_str(), buf.str().size());
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}
binset.Append(this->Type(), index_binary, buf.str().size());
writer.close();
}
return result;
}
Expand Down
7 changes: 6 additions & 1 deletion src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "hnswlib/hnswalg.h"
#include "hnswlib/hnswlib.h"
#include "index/hnsw/hnsw_config.h"
#include "io/trailer.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/comp/time_recorder.h"
Expand Down Expand Up @@ -393,6 +394,11 @@ class HnswIndexNode : public IndexNode {
try {
MemoryIOWriter writer;
index_->saveIndex(writer);
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
} catch (std::exception& e) {
Expand All @@ -415,7 +421,6 @@ class HnswIndexNode : public IndexNode {
}

MemoryIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<float>(space);
index_->loadIndex(reader);
Expand Down
11 changes: 11 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "faiss/index_io.h"
#include "index/ivf/ivf_config.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/dataset.h"
Expand Down Expand Up @@ -824,6 +825,11 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
} else {
faiss::write_index(index_.get(), &writer);
}
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
return Status::success;
Expand All @@ -846,6 +852,11 @@ IvfIndexNode<faiss::IndexIVFFlat>::Serialize(BinarySet& binset) const {
faiss::write_index(index_.get(), &writer);
LOG_KNOWHERE_INFO_ << "write IVF_FLAT, file size " << writer.tellg();
}
auto trailer_status = AddTrailerForMemoryIO(writer, Type(), this->version_);
if (trailer_status != Status::success) {
LOG_KNOWHERE_ERROR_ << "fail to append trailer.";
return trailer_status;
}
std::shared_ptr<uint8_t[]> index_data_ptr(writer.data());
binset.Append(Type(), index_data_ptr, writer.tellg());

Expand Down
11 changes: 10 additions & 1 deletion src/io/memory_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#pragma once

#include <faiss/impl/io.h>

namespace knowhere {
Expand Down Expand Up @@ -136,10 +135,20 @@ struct MemoryIOReader : public faiss::IOReader {
return rp_;
}

// Returns the value of the total_ member.
// NOTE(review): presumably the byte length of the buffer this reader wraps — confirm against the constructor.
size_t
size() {
return total_;
}

// Sets rp_ back to 0.
// NOTE(review): rp_ looks like the current read offset, so this would rewind the reader to the start — confirm.
void
reset() {
rp_ = 0;
}

// Sets rp_ to the given offset.
// NOTE(review): the offset is stored as-is; this method does not compare it with total_,
// so callers presumably must keep it within the buffer — confirm how subsequent reads
// behave when rp_ exceeds total_.
void
seekg(const size_t offset) {
rp_ = offset;
}
};

} // namespace knowhere
Loading

0 comments on commit 2d65386

Please sign in to comment.