Skip to content

Commit

Permalink
MultiCfIterator Implementations (#12422)
Browse files Browse the repository at this point in the history
Summary:
This PR continues #12153 by implementing the missing `Iterator` APIs - `Seek()`, `SeekForPrev()`, `SeekToLast()`, and `Prev`. A MaxHeap Implementation has been added to handle the reverse direction.

The current implementation does not include upper/lower bounds yet. These will be added in subsequent PRs. The API is still marked as under construction and will be lifted after being added to the stress test.

Please note that changing the iterator direction in the middle of iteration is expensive, as it requires seeking the element in each iterator again in the opposite direction and rebuilding the heap along the way. The first `Next()` after `SeekForPrev()` requires changing the direction under the current implementation. We may optimize this in later PRs.

Pull Request resolved: #12422

Test Plan: The `multi_cf_iterator_test` has been extended to cover the API implementations.

Reviewed By: pdillinger

Differential Revision: D54820754

Pulled By: jaykorean

fbshipit-source-id: 9eb741508df0f7bad598fb8e6bd5cdffc39e81d1
  • Loading branch information
jaykorean authored and facebook-github-bot committed Mar 18, 2024
1 parent 3d5be59 commit db1dea2
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 45 deletions.
73 changes: 57 additions & 16 deletions db/multi_cf_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,95 @@

namespace ROCKSDB_NAMESPACE {

void MultiCfIterator::SeekToFirst() {
void MultiCfIterator::SeekCommon(
const std::function<void(Iterator*)>& child_seek_func,
Direction direction) {
direction_ = direction;
Reset();
int i = 0;
for (auto& cfh_iter_pair : cfh_iter_pairs_) {
auto& cfh = cfh_iter_pair.first;
auto& iter = cfh_iter_pair.second;
iter->SeekToFirst();
child_seek_func(iter.get());
if (iter->Valid()) {
assert(iter->status().ok());
min_heap_.push(MultiCfIteratorInfo{iter.get(), cfh, i});
if (direction_ == kReverse) {
auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
max_heap.push(MultiCfIteratorInfo{iter.get(), cfh, i});
} else {
auto& min_heap = std::get<MultiCfMinHeap>(heap_);
min_heap.push(MultiCfIteratorInfo{iter.get(), cfh, i});
}
} else {
considerStatus(iter->status());
}
++i;
}
}

void MultiCfIterator::Next() {
assert(Valid());
template <typename BinaryHeap>
void MultiCfIterator::AdvanceIterator(
BinaryHeap& heap, const std::function<void(Iterator*)>& advance_func) {
// 1. Keep the top iterator (by popping it from the heap)
// 2. Make sure all others have iterated past the top iterator key slice
// 3. Advance the top iterator, and add it back to the heap if valid
auto top = min_heap_.top();
min_heap_.pop();
if (!min_heap_.empty()) {
auto* current = min_heap_.top().iterator;
auto top = heap.top();
heap.pop();
if (!heap.empty()) {
auto* current = heap.top().iterator;
while (current->Valid() &&
comparator_->Compare(top.iterator->key(), current->key()) == 0) {
assert(current->status().ok());
current->Next();
advance_func(current);
if (current->Valid()) {
min_heap_.replace_top(min_heap_.top());
heap.replace_top(heap.top());
} else {
considerStatus(current->status());
min_heap_.pop();
heap.pop();
}
if (!min_heap_.empty()) {
current = min_heap_.top().iterator;
if (!heap.empty()) {
current = heap.top().iterator;
}
}
}
top.iterator->Next();
advance_func(top.iterator);
if (top.iterator->Valid()) {
assert(top.iterator->status().ok());
min_heap_.push(top);
heap.push(top);
} else {
considerStatus(top.iterator->status());
}
}

void MultiCfIterator::SeekToFirst() {
SeekCommon([](Iterator* iter) { iter->SeekToFirst(); }, kForward);
}
void MultiCfIterator::Seek(const Slice& target) {
SeekCommon([&target](Iterator* iter) { iter->Seek(target); }, kForward);
}
void MultiCfIterator::SeekToLast() {
SeekCommon([](Iterator* iter) { iter->SeekToLast(); }, kReverse);
}
void MultiCfIterator::SeekForPrev(const Slice& target) {
SeekCommon([&target](Iterator* iter) { iter->SeekForPrev(target); },
kReverse);
}

void MultiCfIterator::Next() {
assert(Valid());
if (direction_ != kForward) {
SwitchToDirection(kForward);
}
auto& min_heap = std::get<MultiCfMinHeap>(heap_);
AdvanceIterator(min_heap, [](Iterator* iter) { iter->Next(); });
}
void MultiCfIterator::Prev() {
assert(Valid());
if (direction_ != kReverse) {
SwitchToDirection(kReverse);
}
auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
AdvanceIterator(max_heap, [](Iterator* iter) { iter->Prev(); });
}

} // namespace ROCKSDB_NAMESPACE
117 changes: 91 additions & 26 deletions db/multi_cf_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

#pragma once

#include <variant>

#include "rocksdb/comparator.h"
#include "rocksdb/iterator.h"
#include "rocksdb/options.h"
#include "util/heap.h"
#include "util/overload.h"

namespace ROCKSDB_NAMESPACE {

Expand All @@ -23,7 +26,8 @@ class MultiCfIterator : public Iterator {
const std::vector<ColumnFamilyHandle*>& column_families,
const std::vector<Iterator*>& child_iterators)
: comparator_(comparator),
min_heap_(MultiCfMinHeapItemComparator(comparator_)) {
heap_(MultiCfMinHeap(
MultiCfHeapItemComparator<std::greater<int>>(comparator_))) {
assert(column_families.size() > 0 &&
column_families.size() == child_iterators.size());
cfh_iter_pairs_.reserve(column_families.size());
Expand Down Expand Up @@ -52,11 +56,11 @@ class MultiCfIterator : public Iterator {
int order;
};

class MultiCfMinHeapItemComparator {
template <typename CompareOp>
class MultiCfHeapItemComparator {
public:
explicit MultiCfMinHeapItemComparator(const Comparator* comparator)
explicit MultiCfHeapItemComparator(const Comparator* comparator)
: comparator_(comparator) {}

bool operator()(const MultiCfIteratorInfo& a,
const MultiCfIteratorInfo& b) const {
assert(a.iterator);
Expand All @@ -65,52 +69,113 @@ class MultiCfIterator : public Iterator {
assert(b.iterator->Valid());
int c = comparator_->Compare(a.iterator->key(), b.iterator->key());
assert(c != 0 || a.order != b.order);
return c == 0 ? a.order - b.order > 0 : c > 0;
return c == 0 ? a.order - b.order > 0 : CompareOp()(c, 0);
}

private:
const Comparator* comparator_;
};

const Comparator* comparator_;
using MultiCfMinHeap =
BinaryHeap<MultiCfIteratorInfo, MultiCfMinHeapItemComparator>;
MultiCfMinHeap min_heap_;
// TODO: MaxHeap for Reverse Iteration
BinaryHeap<MultiCfIteratorInfo,
MultiCfHeapItemComparator<std::greater<int>>>;
using MultiCfMaxHeap = BinaryHeap<MultiCfIteratorInfo,
MultiCfHeapItemComparator<std::less<int>>>;

using MultiCfIterHeap = std::variant<MultiCfMinHeap, MultiCfMaxHeap>;

MultiCfIterHeap heap_;

enum Direction : uint8_t { kForward, kReverse };
Direction direction_ = kForward;

// TODO: Lower and Upper bounds

Iterator* current() const {
if (direction_ == kReverse) {
auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
return max_heap.top().iterator;
}
auto& min_heap = std::get<MultiCfMinHeap>(heap_);
return min_heap.top().iterator;
}

Slice key() const override {
assert(Valid());
return min_heap_.top().iterator->key();
return current()->key();
}
Slice value() const override {
assert(Valid());
return current()->value();
}
const WideColumns& columns() const override {
assert(Valid());
return current()->columns();
}

bool Valid() const override {
if (direction_ == kReverse) {
auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
return !max_heap.empty() && status_.ok();
}
auto& min_heap = std::get<MultiCfMinHeap>(heap_);
return !min_heap.empty() && status_.ok();
}
bool Valid() const override { return !min_heap_.empty() && status_.ok(); }

Status status() const override { return status_; }
void considerStatus(Status s) {
if (!s.ok() && status_.ok()) {
status_ = std::move(s);
}
}
void Reset() {
min_heap_.clear();
std::visit(overload{[&](MultiCfMinHeap& min_heap) -> void {
min_heap.clear();
if (direction_ == kReverse) {
InitMaxHeap();
}
},
[&](MultiCfMaxHeap& max_heap) -> void {
max_heap.clear();
if (direction_ == kForward) {
InitMinHeap();
}
}},
heap_);
status_ = Status::OK();
}

void SeekToFirst() override;
void Next() override;

// TODO - Implement these
void Seek(const Slice& /*target*/) override {}
void SeekForPrev(const Slice& /*target*/) override {}
void SeekToLast() override {}
void Prev() override { assert(false); }
Slice value() const override {
assert(Valid());
return min_heap_.top().iterator->value();
void InitMinHeap() {
heap_.emplace<MultiCfMinHeap>(
MultiCfHeapItemComparator<std::greater<int>>(comparator_));
}
const WideColumns& columns() const override {
assert(Valid());
return min_heap_.top().iterator->columns();
void InitMaxHeap() {
heap_.emplace<MultiCfMaxHeap>(
MultiCfHeapItemComparator<std::less<int>>(comparator_));
}

void SwitchToDirection(Direction new_direction) {
assert(direction_ != new_direction);
Slice target = key();
if (new_direction == kForward) {
Seek(target);
} else {
SeekForPrev(target);
}
}

void SeekCommon(const std::function<void(Iterator*)>& child_seek_func,
Direction direction);
template <typename BinaryHeap>
void AdvanceIterator(BinaryHeap& heap,
const std::function<void(Iterator*)>& advance_func);

void SeekToFirst() override;
void SeekToLast() override;
void Seek(const Slice& /*target*/) override;
void SeekForPrev(const Slice& /*target*/) override;
void Next() override;
void Prev() override;
};

} // namespace ROCKSDB_NAMESPACE
Loading

0 comments on commit db1dea2

Please sign in to comment.