Skip to content

Commit

Permalink
add trailer for knowhere
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Dec 6, 2023
1 parent 868926e commit 2b87809
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 12 deletions.
32 changes: 22 additions & 10 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "diskann/pq_flash_index.h"
#include "fmt/core.h"
#include "index/diskann/diskann_config.h"
#include "io/trailer.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/dataset.h"
Expand Down Expand Up @@ -191,9 +192,13 @@ TryDiskANNCall(std::function<void()>&& diskann_call) {
}
}

inline std::string
GetDiskAnnTrailerFileName(const std::string& disk_index_filename) {
return disk_index_filename + "_trailer.bin";
}

std::vector<std::string>
GetNecessaryFilenames(const std::string& prefix, const bool need_norm, const bool use_sample_cache,
const bool use_sample_warmup) {
GetNecessaryFilenames(const std::string& prefix, const bool need_norm) {
std::vector<std::string> filenames;
auto pq_pivots_filename = diskann::get_pq_pivots_filename(prefix);
auto disk_index_filename = diskann::get_disk_index_filename(prefix);
Expand All @@ -207,9 +212,6 @@ GetNecessaryFilenames(const std::string& prefix, const bool need_norm, const boo
if (need_norm) {
filenames.push_back(diskann::get_disk_index_max_base_norm_file(disk_index_filename));
}
if (use_sample_cache || use_sample_warmup) {
filenames.push_back(diskann::get_sample_data_filename(prefix));
}
return filenames;
}

Expand All @@ -220,6 +222,7 @@ GetOptionalFilenames(const std::string& prefix) {
filenames.push_back(diskann::get_disk_index_centroids_filename(disk_index_filename));
filenames.push_back(diskann::get_disk_index_medoids_filename(disk_index_filename));
filenames.push_back(diskann::get_cached_nodes_file(prefix));
filenames.push_back(GetDiskAnnTrailerFileName(prefix));
return filenames;
}

Expand All @@ -233,7 +236,7 @@ AnyIndexFileExist(const std::string& index_prefix) {
}
return false;
};
return file_exist(GetNecessaryFilenames(index_prefix, diskann::INNER_PRODUCT, true, true)) ||
return file_exist(GetNecessaryFilenames(index_prefix, diskann::INNER_PRODUCT)) ||
file_exist(GetOptionalFilenames(index_prefix));
}

Expand Down Expand Up @@ -314,8 +317,12 @@ DiskANNIndexNode<T>::Build(const DataSet& dataset, const Config& cfg) {
-1);
}));

auto trailer_file = GetDiskAnnTrailerFileName(index_prefix_);
auto necessary_files = GetNecessaryFilenames(index_prefix_, need_norm);
AddTrailerForFiles(necessary_files, trailer_file, Type(), version_);

// Add file to the file manager
for (auto& filename : GetNecessaryFilenames(index_prefix_, need_norm, true, true)) {
for (auto& filename : necessary_files) {
if (!AddFile(filename)) {
LOG_KNOWHERE_ERROR_ << "Failed to add file " << filename << ".";
return Status::disk_file_error;
Expand Down Expand Up @@ -362,9 +369,8 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
}();

// Load file from file manager.
for (auto& filename : GetNecessaryFilenames(
index_prefix_, need_norm, prep_conf.search_cache_budget_gb.value() > 0 && !prep_conf.use_bfs_cache.value(),
prep_conf.warm_up.value())) {
auto necessary_files = GetNecessaryFilenames(index_prefix_, need_norm);
for (auto& filename : necessary_files) {
if (!LoadFile(filename)) {
return Status::disk_file_error;
}
Expand All @@ -380,6 +386,12 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
}
}

auto trailer_file = GetDiskAnnTrailerFileName(index_prefix_);
if (!CheckTrailerForFiles(necessary_files, trailer_file, Type())) {
LOG_KNOWHERE_ERROR_ << "Failed to check diskann trailer.";
return Status::disk_file_error;
}

// set thread pool
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();

Expand Down
2 changes: 2 additions & 0 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "faiss/index_io.h"
#include "index/flat/flat_config.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/factory.h"
Expand Down Expand Up @@ -304,6 +305,7 @@ class FlatIndexNode : public IndexNode {
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
faiss::write_index_binary(index_.get(), &writer);
}
AddTrailerForMemoryIO(writer, Type(), this->version_);
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
return Status::success;
Expand Down
3 changes: 2 additions & 1 deletion src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "hnswlib/hnswalg.h"
#include "hnswlib/hnswlib.h"
#include "index/hnsw/hnsw_config.h"
#include "io/trailer.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/comp/time_recorder.h"
Expand Down Expand Up @@ -389,6 +390,7 @@ class HnswIndexNode : public IndexNode {
try {
MemoryIOWriter writer;
index_->saveIndex(writer);
AddTrailerForMemoryIO(writer, Type(), version_);
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
} catch (std::exception& e) {
Expand All @@ -411,7 +413,6 @@ class HnswIndexNode : public IndexNode {
}

MemoryIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<float>(space);
index_->loadIndex(reader);
Expand Down
3 changes: 3 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "faiss/index_io.h"
#include "index/ivf/ivf_config.h"
#include "io/memory_io.h"
#include "io/trailer.h"
#include "knowhere/bitsetview_idselector.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/dataset.h"
Expand Down Expand Up @@ -824,6 +825,7 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
} else {
faiss::write_index(index_.get(), &writer);
}
AddTrailerForMemoryIO(writer, Type(), this->version_);
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
return Status::success;
Expand All @@ -846,6 +848,7 @@ IvfIndexNode<faiss::IndexIVFFlat>::Serialize(BinarySet& binset) const {
faiss::write_index(index_.get(), &writer);
LOG_KNOWHERE_INFO_ << "write IVF_FLAT, file size " << writer.tellg();
}
AddTrailerForMemoryIO(writer, Type(), version_);
std::shared_ptr<uint8_t[]> index_data_ptr(writer.data());
binset.Append(Type(), index_data_ptr, writer.tellg());

Expand Down
12 changes: 11 additions & 1 deletion src/io/memory_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#pragma once

#include <faiss/impl/io.h>

namespace knowhere {
Expand Down Expand Up @@ -71,6 +70,7 @@ getSwappedBytes(char C) {

#endif

// MemoryIOwriter and MemoryIOreader it not thread safe
struct MemoryIOWriter : public faiss::IOWriter {
uint8_t* data_ = nullptr;
size_t total_ = 0;
Expand Down Expand Up @@ -136,10 +136,20 @@ struct MemoryIOReader : public faiss::IOReader {
return rp_;
}

size_t
size() {
return total_;
}

void
reset() {
rp_ = 0;
}

void
seekg(const size_t offset) {
rp_ = offset;
}
};

} // namespace knowhere
156 changes: 156 additions & 0 deletions src/io/trailer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

#include "io/trailer.h"

#include <cstring>
#include <fstream>

namespace {
static constexpr size_t kBlockSize = 4096;
uint32_t
CalculateCheckSum(const uint8_t* data, int64_t size) {
uint32_t checksum = 0;
for (auto i = 0; i < size; i++) {
checksum ^= data[i]; // xor
}
return checksum;
}

uint32_t
GetFilesCheckSum(std::vector<std::string> files) {
uint32_t checksum = 0;
auto buffer = std::shared_ptr<uint8_t[]>(new uint8_t[kBlockSize]);
for (auto& file_name : files) {
std::ifstream reader(file_name.c_str(), std::ios::binary);
if (!reader) {
LOG_KNOWHERE_WARNING_ << file_name << "not exist, skip calculate check sum for this file.";
continue;
}
while (reader.read((char*)buffer.get(), kBlockSize)) {
std::streamsize read_size = reader.gcount();
checksum ^= CalculateCheckSum(buffer.get(), read_size);
}
}
return checksum;
}
} // namespace

namespace knowhere {
void
AddTrailerForMemoryIO(MemoryIOWriter& writer, const std::string& name, const Version& version) {
auto trailer_ptr = std::make_unique<Trailer>();
auto size = writer.tellg();
trailer_ptr->SetIndexBinarySize(size);
trailer_ptr->SetCheckSum(CalculateCheckSum(writer.data(), size));
trailer_ptr->SetVersion(version.VersionNumber());
trailer_ptr->SetIndexName(name);
writer.write(trailer_ptr->bytes, KNOWHERE_TRAILER_SIZE);
}

bool
CheckTrailerForMemoryIO(MemoryIOReader& reader, const std::string& name) {
uint64_t bin_size = TRAILER_OFFSET(reader.size());
if (bin_size < 0) {
LOG_KNOWHERE_WARNING_ << "The binary is too small and assume no Trailer, pass Trailer check.";
return true;
}

auto trailer_ptr = std::make_unique<Trailer>();
auto pre_rp = reader.tellg();
reader.seekg(bin_size);
reader.read(trailer_ptr.get(), KNOWHERE_TRAILER_SIZE);
reader.seekg(pre_rp);

if (!trailer_ptr->TrailerValidCheck()) {
LOG_KNOWHERE_WARNING_ << "Trailer not exist in Binary.";
return true;
}

auto version = Version(trailer_ptr->GetVersion());
if (!Version::VersionSupport(version)) {
LOG_KNOWHERE_ERROR_ << "Index version(" << version.VersionNumber() << ") is not supported, Trailer check fail.";
return false;
}

if (trailer_ptr->GetIndexName() != name) {
LOG_KNOWHERE_ERROR_ << "Index type or data type is not correct(" << name << ").";
return false;
}

if (trailer_ptr->GetIndexBinarySize() != bin_size) {
LOG_KNOWHERE_ERROR_ << "The size of index binary is not correct.";
return false;
}

if (CalculateCheckSum(reader.data(), bin_size) != trailer_ptr->GetCheckSum()) {
LOG_KNOWHERE_ERROR_ << "Binary checksum check fail.";
return false;
}
LOG_KNOWHERE_INFO_ << "Index Trailer check succeed.";
return true;
}

void
AddTrailerForFiles(const std::vector<std::string>& files, const std::string& trailer_file, const std::string& name,
const Version& version) {
auto trailer_ptr = std::make_unique<Trailer>();
trailer_ptr->SetCheckSum(GetFilesCheckSum(files));
trailer_ptr->SetVersion(version.VersionNumber());
trailer_ptr->SetIndexName(name);

std::ofstream writer(trailer_file.c_str(), std::ios::binary);
writer.write((char*)trailer_ptr->bytes, KNOWHERE_TRAILER_SIZE);
writer.close();
}

bool
CheckTrailerForFiles(const std::vector<std::string>& files, const std::string& trailer_file, const std::string& name) {
std::ifstream reader(trailer_file.c_str(), std::ios::binary);
if (!reader) {
LOG_KNOWHERE_WARNING_ << "Trailer file not exist.";
return true;
}
reader.seekg(0, std::ios::end);
auto fsize = reader.tellg();

reader.seekg(0, std::ios::beg);
if (fsize != KNOWHERE_TRAILER_SIZE) {
LOG_KNOWHERE_ERROR_ << "Trailer size (" << fsize << ")not correct.";
return false;
}
auto trailer_ptr = std::make_unique<Trailer>();
reader.read((char*)trailer_ptr.get()->bytes, KNOWHERE_TRAILER_SIZE);
if (!trailer_ptr->TrailerValidCheck()) {
LOG_KNOWHERE_WARNING_ << "Trailer flag not right.";
return false;
}

auto version = Version(trailer_ptr->GetVersion());
if (!Version::VersionSupport(version)) {
LOG_KNOWHERE_ERROR_ << "Index version(" << version.VersionNumber() << ") is not supported, Trailer check fail.";
return false;
}

if (trailer_ptr->GetIndexName() != name) {
LOG_KNOWHERE_ERROR_ << "Index type or data type is not correct(" << name << ").";
return false;
}

if (GetFilesCheckSum(files) != trailer_ptr->GetCheckSum()) {
LOG_KNOWHERE_ERROR_ << "Files checksum check fail.";
return false;
}
LOG_KNOWHERE_INFO_ << "Index Trailer check succeed.";
return true;
}

} // namespace knowhere
Loading

0 comments on commit 2b87809

Please sign in to comment.