Skip to content

Commit

Permalink
Run clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jan 14, 2025
1 parent fce9d4b commit 9b54987
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 63 deletions.
15 changes: 8 additions & 7 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,19 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
recursiveInfo->nodeProjectionList = std::move(nodeProjectionList);
recursiveInfo->relProjectionList = std::move(relProjectionList);

recursiveInfo->pathNodeIDsExpr = createInvisibleVariable("pathNodeIDs",
LogicalType::LIST(LogicalType::INTERNAL_ID()));
recursiveInfo->pathEdgeIDsExpr = createInvisibleVariable("pathEdgeIDs",
LogicalType::LIST(LogicalType::INTERNAL_ID()));
recursiveInfo->pathEdgeDirectionsExpr = createInvisibleVariable("pathEdgeDirections",
LogicalType::LIST(LogicalType::BOOL()));
recursiveInfo->pathNodeIDsExpr =
createInvisibleVariable("pathNodeIDs", LogicalType::LIST(LogicalType::INTERNAL_ID()));
recursiveInfo->pathEdgeIDsExpr =
createInvisibleVariable("pathEdgeIDs", LogicalType::LIST(LogicalType::INTERNAL_ID()));
recursiveInfo->pathEdgeDirectionsExpr =
createInvisibleVariable("pathEdgeDirections", LogicalType::LIST(LogicalType::BOOL()));

if (relPattern.getRelType() == QueryRelType::WEIGHTED_SHORTEST) {
auto propertyExpr = expressionBinder.bindNodeOrRelPropertyExpression(*rel,
recursivePatternInfo->weightPropertyName);
recursiveInfo->weightPropertyExpr = propertyExpr;
recursiveInfo->weightOutputExpr = createVariable(parsedName + "_weight", propertyExpr->getDataType());
recursiveInfo->weightOutputExpr =
createVariable(parsedName + "_weight", propertyExpr->getDataType());
}

queryRel->setRecursiveInfo(std::move(recursiveInfo));
Expand Down
12 changes: 7 additions & 5 deletions src/function/gds/all_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ struct AllSPDestinationOutputs : public RJOutputs {
public:
AllSPDestinationOutputs(nodeID_t sourceNodeID, std::shared_ptr<PathLengths> pathLengths,
std::shared_ptr<PathMultiplicities> multiplicities)
: RJOutputs{sourceNodeID}, pathLengths{std::move(pathLengths)}, multiplicities{std::move(multiplicities)} {}
: RJOutputs{sourceNodeID}, pathLengths{std::move(pathLengths)},
multiplicities{std::move(multiplicities)} {}

void initRJFromSource(nodeID_t source) override {
multiplicities->pinTargetTable(source.tableID);
Expand Down Expand Up @@ -123,8 +124,9 @@ class AllSPDestinationsOutputWriter : public RJOutputWriter {
private:
bool skipInternal(nodeID_t dstNodeID) const override {
auto outputs = rjOutputs->ptrCast<AllSPDestinationOutputs>();
return dstNodeID == outputs->sourceNodeID || outputs->pathLengths->getMaskValueFromCurFrontier(
dstNodeID.offset) == PathLengths::UNVISITED;
return dstNodeID == outputs->sourceNodeID ||
outputs->pathLengths->getMaskValueFromCurFrontier(dstNodeID.offset) ==
PathLengths::UNVISITED;
}

private:
Expand Down Expand Up @@ -245,8 +247,8 @@ class AllSPDestinationsAlgorithm final : public SPAlgorithm {
auto outputWriter = std::make_unique<AllSPDestinationsOutputWriter>(clientContext,
output.get(), sharedState->getOutputNodeMaskMap());
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(frontier);
auto edgeCompute = std::make_unique<AllSPDestinationsEdgeCompute>(frontierPair.get(),
multiplicities);
auto edgeCompute =
std::make_unique<AllSPDestinationsEdgeCompute>(frontierPair.get(), multiplicities);
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
sharedState->getOutputNodeMaskMap(), std::move(output), std::move(outputWriter));
}
Expand Down
9 changes: 4 additions & 5 deletions src/function/gds/gds_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "function/gds/gds_utils.h"

#include "binder/expression/property_expression.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/task_system/task_scheduler.h"
#include "function/gds/gds_task.h"
#include "graph/graph.h"
#include "graph/graph_entry.h"
#include "main/settings.h"
#include "binder/expression/property_expression.h"
#include "catalog/catalog_entry/table_catalog_entry.h"

using namespace kuzu::common;
using namespace kuzu::function;
Expand All @@ -33,8 +33,7 @@ void GDSComputeState::beginFrontierComputeBetweenTables(common::table_id_t currT
void GDSUtils::scheduleFrontierTask(catalog::TableCatalogEntry* fromEntry,
catalog::TableCatalogEntry* toEntry, catalog::TableCatalogEntry* relEntry, graph::Graph* graph,
ExtendDirection extendDirection, GDSComputeState& gdsComputeState,
processor::ExecutionContext* context, uint64_t numThreads,
const std::string& propertyToScan) {
processor::ExecutionContext* context, uint64_t numThreads, const std::string& propertyToScan) {
auto clientContext = context->clientContext;
auto transaction = clientContext->getTransaction();
auto info = FrontierTaskInfo(fromEntry, toEntry, relEntry, graph, extendDirection,
Expand Down Expand Up @@ -96,7 +95,7 @@ void GDSUtils::runFrontiersUntilConvergence(processor::ExecutionContext* context
compState.beginFrontierComputeBetweenTables(toEntry->getTableID(),
fromEntry->getTableID());
scheduleFrontierTask(toEntry, fromEntry, relEntry, graph, ExtendDirection::BWD,
compState, context, numThreads, propertyToScan);
compState, context, numThreads, propertyToScan);
} break;
default:
KU_UNREACHABLE;
Expand Down
5 changes: 3 additions & 2 deletions src/function/gds/rec_joins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/property_expression.h"
#include "common/exception/interrupt.h"
#include "common/exception/runtime.h"
#include "common/task_system/progress_bar.h"
Expand All @@ -12,7 +13,6 @@
#include "storage/buffer_manager/memory_manager.h"
#include "storage/local_storage/local_node_table.h"
#include "storage/local_storage/local_storage.h"
#include "binder/expression/property_expression.h"

using namespace kuzu::binder;
using namespace kuzu::common;
Expand Down Expand Up @@ -212,7 +212,8 @@ void RJAlgorithm::exec(processor::ExecutionContext* context) {
rjCompState.initSource(sourceNodeID);
auto rjBindData = bindData->ptrCast<RJBindData>();
GDSUtils::runFrontiersUntilConvergence(context, rjCompState, graph,
rjBindData->extendDirection, rjBindData->upperBound, rjBindData->weightPropertyName);
rjBindData->extendDirection, rjBindData->upperBound,
rjBindData->weightPropertyName);
auto vertexCompute =
std::make_unique<RJVertexCompute>(clientContext->getMemoryManager(),
sharedState.get(), rjCompState.outputWriter->copy());
Expand Down
20 changes: 10 additions & 10 deletions src/function/gds/single_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class SingleSPDestinationsOutputWriter : public RJOutputWriter {
lengthVector = createVector(LogicalType::UINT16(), context->getMemoryManager());
}

void write(FactorizedTable& fTable, nodeID_t dstNodeID,
GDSOutputCounter* counter) override {
void write(FactorizedTable& fTable, nodeID_t dstNodeID, GDSOutputCounter* counter) override {
auto outputs = rjOutputs->ptrCast<SingleSPDestinationOutputs>();
auto length = outputs->pathLengths->getMaskValueFromCurFrontier(dstNodeID.offset);
dstNodeIDVector->setValue<nodeID_t>(0, dstNodeID);
Expand All @@ -53,14 +52,16 @@ class SingleSPDestinationsOutputWriter : public RJOutputWriter {
}

std::unique_ptr<RJOutputWriter> copy() override {
return std::make_unique<SingleSPDestinationsOutputWriter>(context, rjOutputs, outputNodeMask);
return std::make_unique<SingleSPDestinationsOutputWriter>(context, rjOutputs,
outputNodeMask);
}

private:
bool skipInternal(nodeID_t dstNodeID) const override {
auto outputs = rjOutputs->ptrCast<SingleSPDestinationOutputs>();
return dstNodeID == outputs->sourceNodeID || outputs->pathLengths->getMaskValueFromCurFrontier(
dstNodeID.offset) == PathLengths::UNVISITED;
return dstNodeID == outputs->sourceNodeID ||
outputs->pathLengths->getMaskValueFromCurFrontier(dstNodeID.offset) ==
PathLengths::UNVISITED;
}

private:
Expand All @@ -72,8 +73,7 @@ class SingleSPDestinationsEdgeCompute : public SPEdgeCompute {
explicit SingleSPDestinationsEdgeCompute(SinglePathLengthsFrontierPair* frontierPair)
: SPEdgeCompute{frontierPair} {};

std::vector<nodeID_t> edgeCompute(nodeID_t, NbrScanState::Chunk& resultChunk,
bool) override {
std::vector<nodeID_t> edgeCompute(nodeID_t, NbrScanState::Chunk& resultChunk, bool) override {
std::vector<nodeID_t> activeNodes;
resultChunk.forEach([&](auto nbrNode, auto) {
if (frontierPair->getPathLengths()->getMaskValueFromNextFrontier(nbrNode.offset) ==
Expand Down Expand Up @@ -135,7 +135,7 @@ class SingleSPDestinationsAlgorithm : public SPAlgorithm {
: SPAlgorithm{other} {}

expression_vector getResultColumns(Binder*) const override {
auto columns = getBaseResultColumns();
auto columns = getBaseResultColumns();
columns.push_back(bindData->ptrCast<RJBindData>()->lengthExpr);
return columns;
}
Expand All @@ -149,8 +149,8 @@ class SingleSPDestinationsAlgorithm : public SPAlgorithm {
auto clientContext = context->clientContext;
auto frontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto output = std::make_unique<SingleSPDestinationOutputs>(sourceNodeID, frontier);
auto outputWriter = std::make_unique<SingleSPDestinationsOutputWriter>(clientContext, output.get(),
sharedState->getOutputNodeMaskMap());
auto outputWriter = std::make_unique<SingleSPDestinationsOutputWriter>(clientContext,
output.get(), sharedState->getOutputNodeMaskMap());
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(output->pathLengths);
auto edgeCompute = std::make_unique<SingleSPDestinationsEdgeCompute>(frontierPair.get());
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
Expand Down
57 changes: 29 additions & 28 deletions src/function/gds/weighted_shortest_paths.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "function/gds_function.h"
#include "binder/binder.h"
#include "function/gds/gds_function_collection.h"
#include "function/gds/rec_joins.h"
#include "function/gds_function.h"
#include "main/client_context.h"
#include "processor/execution_context.h"
#include "binder/binder.h"

using namespace kuzu::common;
using namespace kuzu::storage;
Expand All @@ -21,9 +21,7 @@ class Weights {

void pinTable(table_id_t tableID) { weights = weightsMap.getData(tableID); }

T getWeight(offset_t offset) {
return weights[offset].load(std::memory_order_relaxed);
}
T getWeight(offset_t offset) { return weights[offset].load(std::memory_order_relaxed); }

void setWeight(offset_t offset, T val) {
weights[offset].store(val, std::memory_order_relaxed);
Expand Down Expand Up @@ -59,9 +57,11 @@ class Weights {
template<typename T>
class DestinationsEdgeCompute : public EdgeCompute {
public:
explicit DestinationsEdgeCompute(std::shared_ptr<Weights<T>> weights) : weights{std::move(weights)} {}
explicit DestinationsEdgeCompute(std::shared_ptr<Weights<T>> weights)
: weights{std::move(weights)} {}

std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk &chunk, bool) override {
std::vector<nodeID_t> edgeCompute(nodeID_t boundNodeID, graph::NbrScanState::Chunk& chunk,
bool) override {
std::vector<nodeID_t> result;
chunk.forEach<T>([&](auto nbrNodeID, auto /* edgeID */, auto weight) {
if (weights->update(boundNodeID.offset, nbrNodeID.offset, weight)) {
Expand All @@ -86,9 +86,7 @@ struct DestinationOutputs : public RJOutputs {
DestinationOutputs(nodeID_t sourceNodeID, std::shared_ptr<Weights<T>> weights)
: RJOutputs{sourceNodeID}, weights{std::move(weights)} {}

void initRJFromSource(common::nodeID_t a) override {
weights->setWeight(a.offset, 0);
}
void initRJFromSource(common::nodeID_t a) override { weights->setWeight(a.offset, 0); }

void beginFrontierComputeBetweenTables(table_id_t, table_id_t nextFrontierTableID) override {
weights->pinTable(nextFrontierTableID);
Expand All @@ -105,12 +103,12 @@ class DestinationsOutputWriter : public RJOutputWriter {
DestinationsOutputWriter(main::ClientContext* context, RJOutputs* rjOutputs,
processor::NodeOffsetMaskMap* outputNodeMask, std::shared_ptr<Weights<T>> weights,
const LogicalType& weightType)
: RJOutputWriter{context, rjOutputs, outputNodeMask},
weights{std::move(weights)} {
: RJOutputWriter{context, rjOutputs, outputNodeMask}, weights{std::move(weights)} {
weightVector = createVector(weightType, context->getMemoryManager());
}

void write(processor::FactorizedTable &fTable, nodeID_t dstNodeID, processor::GDSOutputCounter *counter) override {
void write(processor::FactorizedTable& fTable, nodeID_t dstNodeID,
processor::GDSOutputCounter* counter) override {
dstNodeIDVector->setValue<nodeID_t>(0, dstNodeID);
auto weight = weights->getWeight(dstNodeID.offset);
weightVector->setValue<T>(0, weight);
Expand All @@ -121,7 +119,8 @@ class DestinationsOutputWriter : public RJOutputWriter {
}

std::unique_ptr<RJOutputWriter> copy() override {
return std::make_unique<DestinationsOutputWriter<T>>(context, rjOutputs, outputNodeMask, weights, weightVector->dataType);
return std::make_unique<DestinationsOutputWriter<T>>(context, rjOutputs, outputNodeMask,
weights, weightVector->dataType);
}

private:
Expand All @@ -141,7 +140,7 @@ class WeightedSPDestinationsAlgorithm : public SPAlgorithm {
WeightedSPDestinationsAlgorithm(const WeightedSPDestinationsAlgorithm& other)
: SPAlgorithm{other} {}

binder::expression_vector getResultColumns(binder::Binder *) const override {
binder::expression_vector getResultColumns(binder::Binder*) const override {
auto columns = getBaseResultColumns();
columns.push_back(bindData->ptrCast<RJBindData>()->weightOutputExpr);
return columns;
Expand Down Expand Up @@ -192,28 +191,30 @@ class WeightedSPDestinationsAlgorithm : public SPAlgorithm {
default:
break;
}
throw RuntimeException(stringFormat("{} weight type is not supported for weighted shortest path.", dataType.toString()));
throw RuntimeException(stringFormat(
"{} weight type is not supported for weighted shortest path.", dataType.toString()));
}

RJCompState getRJCompState(processor::ExecutionContext *context, nodeID_t sourceNodeID) override {
RJCompState getRJCompState(processor::ExecutionContext* context,
nodeID_t sourceNodeID) override {
auto clientContext = context->clientContext;
auto numNodes = sharedState->graph->getNumNodesMap(clientContext->getTransaction());
auto curFrontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto nextFrontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto frontierPair = std::make_unique<DoublePathLengthsFrontierPair>(curFrontier, nextFrontier);
auto frontierPair =
std::make_unique<DoublePathLengthsFrontierPair>(curFrontier, nextFrontier);
auto rjBindData = bindData->ptrCast<RJBindData>();
std::unique_ptr<EdgeCompute> edgeCompute;
std::unique_ptr<RJOutputs> outputs;
std::unique_ptr<RJOutputWriter> outputWriter;
auto& dataType = rjBindData->weightOutputExpr->getDataType();
visit(dataType,
[&]<typename T>(T) {
auto weight = std::make_shared<Weights<T>>(numNodes, clientContext->getMemoryManager());
edgeCompute = std::make_unique<DestinationsEdgeCompute<T>>(weight);
outputs = std::make_unique<DestinationOutputs<T>>(sourceNodeID, weight);
outputWriter = std::make_unique<DestinationsOutputWriter<T>>(clientContext,
outputs.get(), sharedState->getOutputNodeMaskMap(), weight, dataType);
});
visit(dataType, [&]<typename T>(T) {
auto weight = std::make_shared<Weights<T>>(numNodes, clientContext->getMemoryManager());
edgeCompute = std::make_unique<DestinationsEdgeCompute<T>>(weight);
outputs = std::make_unique<DestinationOutputs<T>>(sourceNodeID, weight);
outputWriter = std::make_unique<DestinationsOutputWriter<T>>(clientContext,
outputs.get(), sharedState->getOutputNodeMaskMap(), weight, dataType);
});
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
sharedState->getOutputNodeMaskMap(), std::move(outputs), std::move(outputWriter));
}
Expand All @@ -231,5 +232,5 @@ GDSFunction WeightedSPDestinationsFunction::getFunction() {
return GDSFunction(name, std::move(params), std::move(algo));
}

}
}
} // namespace function
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/include/function/gds/gds_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

#include <optional>

#include "binder/expression/expression.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/enums/extend_direction.h"
#include "common/types/types.h"
#include "binder/expression/expression.h"

namespace kuzu {
namespace processor {
Expand Down
12 changes: 8 additions & 4 deletions src/parser/transform/transform_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ RelPattern Transformer::transformRelationshipPattern(
relType = QueryRelType::ALL_SHORTEST;
} else if (recursiveType->WSHORTEST()) {
relType = QueryRelType::WEIGHTED_SHORTEST;
recursiveInfo.weightPropertyName = transformPropertyKeyName(*recursiveType->oC_PropertyKeyName());
recursiveInfo.weightPropertyName =
transformPropertyKeyName(*recursiveType->oC_PropertyKeyName());
} else if (recursiveDetail->kU_RecursiveType()->SHORTEST()) {
relType = QueryRelType::SHORTEST;
} else if (recursiveDetail->kU_RecursiveType()->TRAIL()) {
Expand Down Expand Up @@ -147,13 +148,16 @@ RelPattern Transformer::transformRelationshipPattern(
if (!comprehension->kU_RecursiveProjectionItems().empty()) {
recursiveInfo.hasProjection = true;
KU_ASSERT(comprehension->kU_RecursiveProjectionItems().size() == 2);
auto relProjectionList = comprehension->kU_RecursiveProjectionItems(0)->oC_ProjectionItems();
auto relProjectionList =
comprehension->kU_RecursiveProjectionItems(0)->oC_ProjectionItems();
if (relProjectionList) {
recursiveInfo.relProjectionList = transformProjectionItems(*relProjectionList);
}
auto nodeProjectionList = comprehension->kU_RecursiveProjectionItems(1)->oC_ProjectionItems();
auto nodeProjectionList =
comprehension->kU_RecursiveProjectionItems(1)->oC_ProjectionItems();
if (nodeProjectionList) {
recursiveInfo.nodeProjectionList = transformProjectionItems(*nodeProjectionList);
recursiveInfo.nodeProjectionList =
transformProjectionItems(*nodeProjectionList);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/planner/plan/append_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ void Planner::appendRecursiveExtendAsGDS(const std::shared_ptr<NodeExpression>&
bindData->pathNodeIDsExpr = recursiveInfo->pathNodeIDsExpr;
bindData->pathEdgeIDsExpr = recursiveInfo->pathEdgeIDsExpr;
if (recursiveInfo->weightPropertyExpr != nullptr) {
bindData->weightPropertyName = recursiveInfo->weightPropertyExpr->ptrCast<PropertyExpression>()->getPropertyName();
bindData->weightPropertyName =
recursiveInfo->weightPropertyExpr->ptrCast<PropertyExpression>()->getPropertyName();
bindData->weightOutputExpr = recursiveInfo->weightOutputExpr;
}
gdsFunction.gds->setBindData(std::move(bindData));
Expand Down

0 comments on commit 9b54987

Please sign in to comment.